commit dc93ee7257b9a0eed1a4632dec3c7b16a65e9782 Author: Damian Johnson atagar@torproject.org Date: Thu Jun 25 16:13:54 2020 -0700
Use Synchronous for Controller
Finally migrating our Controller class from Illia's AsyncClassWrapper to our Synchronous mixin.
Benefits are...
* Class no longer requires a synchronous and asynchronous copy.
* Controller can be implemented as a fully asynchronous class, while still functioning in synchronous contexts.
Downside is...
* Python type checkers (like mypy) only recognize our Controller as an asynchronous class, producing false positives for synchronous users. --- run_tests.py | 4 + stem/connection.py | 9 +- stem/control.py | 579 +++++----------------- stem/descriptor/remote.py | 8 +- stem/interpreter/__init__.py | 2 +- stem/interpreter/commands.py | 4 +- stem/util/__init__.py | 43 +- stem/util/test_tools.py | 3 + test/integ/connection/authentication.py | 13 +- test/integ/control/controller.py | 818 ++++++++++++++++++-------------- test/runner.py | 64 ++- test/settings.cfg | 8 + test/unit/control/controller.py | 119 +++-- test/unit/descriptor/remote.py | 2 +- 14 files changed, 715 insertions(+), 961 deletions(-)
diff --git a/run_tests.py b/run_tests.py index c738f9b8..5218008f 100755 --- a/run_tests.py +++ b/run_tests.py @@ -259,6 +259,10 @@ def main(): # 2.7 or later because before that test results didn't have a 'skipped' # attribute.
+ # TODO: handling of earlier python versions is no longer necessary here + # TODO: this invokes all asynchronous tests, even if we have a --test or + # --exclude-test argument + skipped_tests = 0
if args.run_integ: diff --git a/stem/connection.py b/stem/connection.py index 86d32d7f..8495da2a 100644 --- a/stem/connection.py +++ b/stem/connection.py @@ -89,7 +89,7 @@ fine-grained control over the authentication process. For instance... ::
connect - Simple method for getting authenticated control connection for synchronous usage. - async_connect - Simple method for getting authenticated control connection for asynchronous usage. + async_connect - Simple method for getting authenticated control connection for asynchronous usage.
authenticate - Main method for authenticating to a control socket authenticate_none - Authenticates to an open control socket @@ -292,7 +292,7 @@ def connect(control_port: Tuple[str, Union[str, int]] = ('127.0.0.1', 'default') raise
-async def connect_async(control_port: Tuple[str, Union[str, int]] = ('127.0.0.1', 'default'), control_socket: str = '/var/run/tor/control', password: Optional[str] = None, password_prompt: bool = False, chroot_path: Optional[str] = None, controller: Type[stem.control.BaseController] = stem.control.AsyncController) -> Any: +async def connect_async(control_port: Tuple[str, Union[str, int]] = ('127.0.0.1', 'default'), control_socket: str = '/var/run/tor/control', password: Optional[str] = None, password_prompt: bool = False, chroot_path: Optional[str] = None, controller: Type[stem.control.BaseController] = stem.control.Controller) -> Any: """ Convenience function for quickly getting a control connection for asynchronous usage. This is very handy for debugging or CLI setup, handling @@ -364,6 +364,7 @@ async def _connect_async(control_port: Tuple[str, Union[str, int]], control_sock control_connection = _connection_for_default_port(address) else: control_connection = stem.socket.ControlPort(address, int(port)) + await control_connection.connect() except stem.SocketError as exc: error_msg = CONNECT_MESSAGES['unable_to_use_port'].format(address = address, port = port, error = exc) @@ -405,9 +406,7 @@ async def _connect_auth(control_socket: stem.socket.ControlSocket, password: str
if controller is None: return control_socket - elif issubclass(controller, stem.control.BaseController) or issubclass(controller, stem.control.Controller): - # TODO: Controller no longer extends BaseController (we'll probably change that) - + else: return controller(control_socket, is_authenticated = True) except IncorrectSocketType: if isinstance(control_socket, stem.socket.ControlPort): diff --git a/stem/control.py b/stem/control.py index 47ddaa35..7b90eed0 100644 --- a/stem/control.py +++ b/stem/control.py @@ -269,9 +269,9 @@ import stem.util.tor_tools import stem.version
from stem import UNDEFINED, CircStatus, Signal -from stem.util import log +from stem.util import Synchronous, log from types import TracebackType -from typing import Any, AsyncIterator, Awaitable, Callable, Dict, Iterator, List, Mapping, Optional, Sequence, Tuple, Type, Union +from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union
# When closing the controller we attempt to finish processing enqueued events, # but if it takes longer than this we terminate. @@ -553,56 +553,7 @@ def event_description(event: str) -> str: return EVENT_DESCRIPTIONS.get(event.lower())
-class _BaseControllerSocketMixin: - _socket: stem.socket.ControlSocket - - def is_alive(self) -> bool: - """ - Checks if our socket is currently connected. This is a pass-through for our - socket's :func:`~stem.socket.BaseSocket.is_alive` method. - - :returns: **bool** that's **True** if our socket is connected and **False** otherwise - """ - - return self._socket.is_alive() - - def is_localhost(self) -> bool: - """ - Returns if the connection is for the local system or not. - - .. versionadded:: 1.3.0 - - :returns: **bool** that's **True** if the connection is for the local host and **False** otherwise - """ - - return self._socket.is_localhost() - - def connection_time(self) -> float: - """ - Provides the unix timestamp for when our socket was either connected or - disconnected. That is to say, the time we connected if we're currently - connected and the time we disconnected if we're not connected. - - .. versionadded:: 1.3.0 - - :returns: **float** for when we last connected or disconnected, zero if - we've never connected - """ - - return self._socket.connection_time() - - def get_socket(self) -> stem.socket.ControlSocket: - """ - Provides the socket used to speak with the tor process. Communicating with - the socket directly isn't advised since it may confuse this controller. - - :returns: :class:`~stem.socket.ControlSocket` we're communicating with - """ - - return self._socket - - -class BaseController(_BaseControllerSocketMixin): +class BaseController(Synchronous): """ Controller for the tor process. This is a minimal base class for other controllers, providing basic process communication and event listing. Don't @@ -619,21 +570,13 @@ class BaseController(_BaseControllerSocketMixin): """
def __init__(self, control_socket: stem.socket.ControlSocket, is_authenticated: bool = False) -> None: - self._socket = control_socket + super(BaseController, self).__init__()
- self._asyncio_loop = asyncio.get_event_loop() - - self._msg_lock = asyncio.Lock() + self._socket = control_socket
self._status_listeners = [] # type: List[Tuple[Callable[[stem.control.BaseController, stem.control.State, float], None], bool]] # tuples of the form (callback, spawn_thread) self._status_listeners_lock = threading.RLock()
- # queues where incoming messages are directed - self._reply_queue = asyncio.Queue() # type: asyncio.Queue[Union[stem.response.ControlMessage, stem.ControllerError]] - self._event_queue = asyncio.Queue() # type: asyncio.Queue[stem.response.ControlMessage] - - self._event_notice = asyncio.Event() - # saves our socket's prior _connect() and _close() methods so they can be # called along with ours
@@ -650,11 +593,22 @@ class BaseController(_BaseControllerSocketMixin):
self._reader_loop_task = None # type: Optional[asyncio.Task] self._event_loop_task = None # type: Optional[asyncio.Task] + if self._socket.is_alive(): self._create_loop_tasks()
if is_authenticated: - self._asyncio_loop.create_task(self._post_authentication()) + self._loop.create_task(self._post_authentication()) + + def __ainit__(self) -> None: + self._msg_lock = asyncio.Lock() + + # queues where incoming messages are directed + + self._reply_queue = asyncio.Queue() # type: asyncio.Queue[Union[stem.response.ControlMessage, stem.ControllerError]] + self._event_queue = asyncio.Queue() # type: asyncio.Queue[stem.response.ControlMessage] + + self._event_notice = asyncio.Event()
async def msg(self, message: str) -> stem.response.ControlMessage: """ @@ -736,8 +690,43 @@ class BaseController(_BaseControllerSocketMixin): # provide an assurance to the caller that when we raise a SocketClosed # exception we are shut down afterward for realz.
- await self.close() - raise + await self.close() + raise + + def is_alive(self) -> bool: + """ + Checks if our socket is currently connected. This is a pass-through for our + socket's :func:`~stem.socket.BaseSocket.is_alive` method. + + :returns: **bool** that's **True** if our socket is connected and **False** otherwise + """ + + return self._socket.is_alive() + + def is_localhost(self) -> bool: + """ + Returns if the connection is for the local system or not. + + .. versionadded:: 1.3.0 + + :returns: **bool** that's **True** if the connection is for the local host and **False** otherwise + """ + + return self._socket.is_localhost() + + def connection_time(self) -> float: + """ + Provides the unix timestamp for when our socket was either connected or + disconnected. That is to say, the time we connected if we're currently + connected and the time we disconnected if we're not connected. + + .. versionadded:: 1.3.0 + + :returns: **float** for when we last connected or disconnected, zero if + we've never connected + """ + + return self._socket.connection_time()
def is_authenticated(self) -> bool: """ @@ -778,6 +767,18 @@ class BaseController(_BaseControllerSocketMixin): if t.is_alive() and threading.current_thread() != t: t.join()
+ self.stop() + + def get_socket(self) -> stem.socket.ControlSocket: + """ + Provides the socket used to speak with the tor process. Communicating with + the socket directly isn't advised since it may confuse this controller. + + :returns: :class:`~stem.socket.ControlSocket` we're communicating with + """ + + return self._socket + def get_latest_heartbeat(self) -> float: """ Provides the unix timestamp for when we last heard from tor. This is zero @@ -858,7 +859,7 @@ class BaseController(_BaseControllerSocketMixin):
async def _connect(self) -> None: self._create_loop_tasks() - await self._notify_status_listeners(State.INIT, acquire_send_lock=False) + await self._notify_status_listeners(State.INIT, acquire_send_lock = False) await self._socket_connect() self._is_authenticated = False
@@ -874,13 +875,14 @@ class BaseController(_BaseControllerSocketMixin): self._reader_loop_task = None event_loop_task = self._event_loop_task self._event_loop_task = None + if reader_loop_task and self.is_alive(): await reader_loop_task + if event_loop_task: await event_loop_task
- await self._notify_status_listeners(State.CLOSED, acquire_send_lock=False) - + await self._notify_status_listeners(State.CLOSED, acquire_send_lock = False) await self._socket_close()
async def _post_authentication(self) -> None: @@ -899,6 +901,7 @@ class BaseController(_BaseControllerSocketMixin): # need to have it to ensure it doesn't change beneath us.
send_lock = self._socket._get_send_lock() + try: if acquire_send_lock: await send_lock.acquire() @@ -944,8 +947,8 @@ class BaseController(_BaseControllerSocketMixin): them if we're restarted. """
- self._reader_loop_task = self._asyncio_loop.create_task(self._reader_loop()) - self._event_loop_task = self._asyncio_loop.create_task(self._event_loop()) + self._reader_loop_task = self._loop.create_task(self._reader_loop()) + self._event_loop_task = self._loop.create_task(self._event_loop())
async def _reader_loop(self) -> None: """ @@ -1011,21 +1014,24 @@ class BaseController(_BaseControllerSocketMixin): self._event_notice.clear()
-class AsyncController(BaseController): +class Controller(BaseController): """ Connection with Tor's control socket. This is built on top of the BaseController and provides a more user friendly API for library users. """
- @classmethod - def from_port(cls, address: str = '127.0.0.1', port: Union[int, str] = 'default') -> 'AsyncController': + @staticmethod + def from_port(address: str = '127.0.0.1', port: Union[int, str] = 'default') -> 'stem.control.Controller': """ - Constructs a :class:`~stem.socket.ControlPort` based AsyncController. + Constructs a :class:`~stem.socket.ControlPort` based Controller.
If the **port** is **'default'** then this checks on both 9051 (default for relays) and 9151 (default for the Tor Browser). This default may change in the future.
+ .. versionchanged:: 1.5.0 + Use both port 9051 and 9151 by default. + :param address: ip address of the controller :param port: port number of the controller
@@ -1034,13 +1040,31 @@ class AsyncController(BaseController): :raises: :class:`stem.SocketError` if we're unable to establish a connection """
- control_socket = _init_control_port(address, port) - return cls(control_socket) + import stem.connection + + if not stem.util.connection.is_valid_ipv4_address(address): + raise ValueError('Invalid IP address: %s' % address) + elif port != 'default' and not stem.util.connection.is_valid_port(port): + raise ValueError('Invalid port: %s' % port) + + if port == 'default': + control_port = stem.connection._connection_for_default_port(address) + else: + control_port = stem.socket.ControlPort(address, int(port)) + + controller = Controller(control_port)
- @classmethod - def from_socket_file(cls, path: str = '/var/run/tor/control') -> 'AsyncController': + try: + controller.connect() + return controller + except: + controller.stop() + raise + + @staticmethod + def from_socket_file(path: str = '/var/run/tor/control') -> 'stem.control.Controller': """ - Constructs a :class:`~stem.socket.ControlSocketFile` based AsyncController. + Constructs a :class:`~stem.socket.ControlSocketFile` based Controller.
:param path: path where the control socket is located
@@ -1049,8 +1073,15 @@ class AsyncController(BaseController): :raises: :class:`stem.SocketError` if we're unable to establish a connection """
- control_socket = _init_control_socket_file(path) - return cls(control_socket) + control_socket = stem.socket.ControlSocketFile(path) + controller = Controller(control_socket) + + try: + controller.connect() + return controller + except: + controller.stop() + raise
def __init__(self, control_socket: stem.socket.ControlSocket, is_authenticated: bool = False) -> None: self._is_caching_enabled = True @@ -1062,13 +1093,12 @@ class AsyncController(BaseController): # mapping of event types to their listeners
self._event_listeners = {} # type: Dict[stem.control.EventType, List[Callable[[stem.response.events.Event], Union[None, Awaitable[None]]]]] - self._event_listeners_lock = asyncio.Lock() self._enabled_features = [] # type: List[str]
self._last_address_exc = None # type: Optional[BaseException] self._last_fingerprint_exc = None # type: Optional[BaseException]
- super(AsyncController, self).__init__(control_socket, is_authenticated) + super(Controller, self).__init__(control_socket, is_authenticated)
async def _sighup_listener(event: stem.response.events.SignalEvent) -> None: if event.signal == Signal.RELOAD: @@ -1101,11 +1131,16 @@ class AsyncController(BaseController): self.add_event_listener(_address_changed_listener, EventType.STATUS_SERVER), )
- self._asyncio_loop.create_task(_add_event_listeners()) + self._loop.create_task(_add_event_listeners()) + + def __ainit__(self): + super(Controller, self).__ainit__() + + self._event_listeners_lock = asyncio.Lock()
async def close(self) -> None: self.clear_cache() - await super(AsyncController, self).close() + await super(Controller, self).close()
async def authenticate(self, *args: Any, **kwargs: Any) -> None: """ @@ -1186,7 +1221,7 @@ class AsyncController(BaseController): raise stem.ProtocolError('Tor geoip database is unavailable') elif param == 'address' and self._last_address_exc: raise self._last_address_exc # we already know we can't resolve an address - elif param == 'fingerprint' and self._last_fingerprint_exc and self.get_conf('ORPort', None) is None: + elif param == 'fingerprint' and self._last_fingerprint_exc and await self.get_conf('ORPort', None) is None: raise self._last_fingerprint_exc # we already know we're not a relay
# check for cached results @@ -2082,7 +2117,6 @@ class AsyncController(BaseController): request += ' ' + ' '.join(['SERVER=%s' % s for s in servers])
response = stem.response._convert_to_single_line(await self.msg(request)) - stem.response.convert('SINGLELINE', response)
if not response.is_ok(): raise stem.ProtocolError('HSFETCH returned unexpected response code: %s' % response.code) @@ -3778,7 +3812,7 @@ class AsyncController(BaseController): await self.msg('DROPGUARDS')
async def _post_authentication(self) -> None: - await super(AsyncController, self)._post_authentication() + await super(Controller, self)._post_authentication()
# try to re-attach event listeners to the new instance
@@ -3834,9 +3868,10 @@ class AsyncController(BaseController): if listener_type == event_type: for listener in event_listeners: try: - potential_coroutine = listener(event) - if asyncio.iscoroutine(potential_coroutine): - await potential_coroutine + listener_call = listener(event) + + if asyncio.iscoroutine(listener_call): + await listener_call except Exception as exc: log.warn('Event listener raised an uncaught exception (%s): %s' % (exc, event))
@@ -3883,346 +3918,6 @@ class AsyncController(BaseController): return (set_events, failed_events)
-def _set_doc_from_async_controller(func: Callable) -> Callable: - func.__doc__ = getattr(AsyncController, func.__name__).__doc__ - return func - - -class Controller(_BaseControllerSocketMixin, stem.util.AsyncClassWrapper): - """ - Connection with Tor's control socket. This wraps - :class:`~stem.control.AsyncController` to provide a synchronous - interface and for backwards compatibility. - """ - - @classmethod - def from_port(cls, address: str = '127.0.0.1', port: Union[int, str] = 'default') -> 'Controller': - """ - Constructs a :class:`~stem.socket.ControlPort` based Controller. - - If the **port** is **'default'** then this checks on both 9051 (default - for relays) and 9151 (default for the Tor Browser). This default may change - in the future. - - .. versionchanged:: 1.5.0 - Use both port 9051 and 9151 by default. - - :param address: ip address of the controller - :param port: port number of the controller - - :returns: :class:`~stem.control.Controller` attached to the given port - - :raises: :class:`stem.SocketError` if we're unable to establish a connection - """ - - control_socket = _init_control_port(address, port) - controller = cls(control_socket) - controller.connect() - return controller - - @classmethod - def from_socket_file(cls, path: str = '/var/run/tor/control') -> 'Controller': - """ - Constructs a :class:`~stem.socket.ControlSocketFile` based Controller. 
- - :param str path: path where the control socket is located - - :returns: :class:`~stem.control.Controller` attached to the given socket file - - :raises: :class:`stem.SocketError` if we're unable to establish a connection - """ - - control_socket = _init_control_socket_file(path) - controller = cls(control_socket) - controller.connect() - return controller - - def __init__( - self, - control_socket: stem.socket.ControlSocket, - is_authenticated: bool = False, - ) -> None: - # if within an asyncio context use its loop, otherwise spawn our own - - try: - self._loop = asyncio.get_running_loop() - self._loop_thread = threading.current_thread() - except RuntimeError: - self._loop = asyncio.new_event_loop() - self._loop_thread = threading.Thread(target = self._loop.run_forever, name = 'asyncio') - self._loop_thread.setDaemon(True) - self._loop_thread.start() - - self._wrapped_instance: AsyncController = self._init_async_class(AsyncController, control_socket, is_authenticated) # type: ignore - self._socket = self._wrapped_instance._socket - - @_set_doc_from_async_controller - def msg(self, message: str) -> stem.response.ControlMessage: - return self._execute_async_method('msg', message) - - @_set_doc_from_async_controller - def is_authenticated(self) -> bool: - return self._wrapped_instance.is_authenticated() - - @_set_doc_from_async_controller - def connect(self) -> None: - self._execute_async_method('connect') - - @_set_doc_from_async_controller - def reconnect(self, *args: Any, **kwargs: Any) -> None: - self._execute_async_method('reconnect', *args, **kwargs) - - @_set_doc_from_async_controller - def close(self) -> None: - self._execute_async_method('close') - - @_set_doc_from_async_controller - def get_latest_heartbeat(self) -> float: - return self._wrapped_instance.get_latest_heartbeat() - - @_set_doc_from_async_controller - def add_status_listener(self, callback: Callable[['stem.control.BaseController', 'stem.control.State', float], None], spawn: bool = True) -> 
None: - self._wrapped_instance.add_status_listener(callback, spawn) - - @_set_doc_from_async_controller - def remove_status_listener(self, callback: Callable[['stem.control.Controller', 'stem.control.State', float], None]) -> bool: - return self._wrapped_instance.remove_status_listener(callback) - - @_set_doc_from_async_controller - def authenticate(self, *args: Any, **kwargs: Any) -> None: - self._execute_async_method('authenticate', *args, **kwargs) - - @_set_doc_from_async_controller - def get_info(self, params: Union[str, Sequence[str]], default: Any = UNDEFINED, get_bytes: bool = False) -> Union[str, Dict[str, str]]: - return self._execute_async_method('get_info', params, default, get_bytes) - - @_set_doc_from_async_controller - def get_version(self, default: Any = UNDEFINED) -> stem.version.Version: - return self._execute_async_method('get_version', default) - - @_set_doc_from_async_controller - def get_exit_policy(self, default: Any = UNDEFINED) -> stem.exit_policy.ExitPolicy: - return self._execute_async_method('get_exit_policy', default) - - @_set_doc_from_async_controller - def get_ports(self, listener_type: 'stem.control.Listener', default: Any = UNDEFINED) -> Sequence[int]: - return self._execute_async_method('get_ports', listener_type, default) - - @_set_doc_from_async_controller - def get_listeners(self, listener_type: 'stem.control.Listener', default: Any = UNDEFINED) -> Sequence[Tuple[str, int]]: - return self._execute_async_method('get_listeners', listener_type, default) - - @_set_doc_from_async_controller - def get_accounting_stats(self, default: Any = UNDEFINED) -> 'stem.control.AccountingStats': - return self._execute_async_method('get_accounting_stats', default) - - @_set_doc_from_async_controller - def get_protocolinfo(self, default: Any = UNDEFINED) -> stem.response.protocolinfo.ProtocolInfoResponse: - return self._execute_async_method('get_protocolinfo', default) - - @_set_doc_from_async_controller - def get_user(self, default: Any = 
UNDEFINED) -> str: - return self._execute_async_method('get_user', default) - - @_set_doc_from_async_controller - def get_pid(self, default: Any = UNDEFINED) -> int: - return self._execute_async_method('get_pid', default) - - @_set_doc_from_async_controller - def get_start_time(self, default: Any = UNDEFINED) -> float: - return self._execute_async_method('get_start_time', default) - - @_set_doc_from_async_controller - def get_uptime(self, default: Any = UNDEFINED) -> float: - return self._execute_async_method('get_uptime', default) - - @_set_doc_from_async_controller - def is_user_traffic_allowed(self) -> 'stem.control.UserTrafficAllowed': - return self._execute_async_method('is_user_traffic_allowed') - - @_set_doc_from_async_controller - def get_microdescriptor(self, relay: Optional[str] = None, default: Any = UNDEFINED) -> stem.descriptor.microdescriptor.Microdescriptor: - return self._execute_async_method('get_microdescriptor', relay, default) - - @_set_doc_from_async_controller - def get_microdescriptors(self, default: Any = UNDEFINED) -> Iterator[stem.descriptor.microdescriptor.Microdescriptor]: - return self._execute_async_generator_method('get_microdescriptors', default) - - @_set_doc_from_async_controller - def get_server_descriptor(self, relay: Optional[str] = None, default: Any = UNDEFINED) -> stem.descriptor.server_descriptor.RelayDescriptor: - return self._execute_async_method('get_server_descriptor', relay, default) - - @_set_doc_from_async_controller - def get_server_descriptors(self, default: Any = UNDEFINED) -> Iterator[stem.descriptor.server_descriptor.RelayDescriptor]: - return self._execute_async_generator_method('get_server_descriptors', default) - - @_set_doc_from_async_controller - def get_network_status(self, relay: Optional[str] = None, default: Any = UNDEFINED) -> stem.descriptor.router_status_entry.RouterStatusEntryV3: - return self._execute_async_method('get_network_status', relay, default) - - @_set_doc_from_async_controller - def 
get_network_statuses(self, default: Any = UNDEFINED) -> Iterator[stem.descriptor.router_status_entry.RouterStatusEntryV3]: - return self._execute_async_generator_method('get_network_statuses', default) - - @_set_doc_from_async_controller - def get_hidden_service_descriptor(self, address: str, default: Any = UNDEFINED, servers: Optional[Sequence[str]] = None, await_result: bool = True, timeout: Optional[float] = None) -> stem.descriptor.hidden_service.HiddenServiceDescriptorV2: - return self._execute_async_method('get_hidden_service_descriptor', address, default, servers, await_result, timeout) - - @_set_doc_from_async_controller - def get_conf(self, param: str, default: Any = UNDEFINED, multiple: bool = False) -> Union[str, Sequence[str]]: - return self._execute_async_method('get_conf', param, default, multiple) - - @_set_doc_from_async_controller - def get_conf_map(self, params: Union[str, Sequence[str]], default: Any = UNDEFINED, multiple: bool = True) -> Dict[str, Union[str, Sequence[str]]]: - return self._execute_async_method('get_conf_map', params, default, multiple) - - @_set_doc_from_async_controller - def is_set(self, param: str, default: Any = UNDEFINED) -> bool: - return self._execute_async_method('is_set', param, default) - - @_set_doc_from_async_controller - def set_conf(self, param: str, value: Union[str, Sequence[str]]) -> None: - self._execute_async_method('set_conf', param, value) - - @_set_doc_from_async_controller - def reset_conf(self, *params: str) -> None: - self._execute_async_method('reset_conf', *params) - - @_set_doc_from_async_controller - def set_options(self, params: Union[Mapping[str, Union[str, Sequence[str]]], Sequence[Tuple[str, Union[str, Sequence[str]]]]], reset: bool = False) -> None: - self._execute_async_method('set_options', params, reset) - - @_set_doc_from_async_controller - def get_hidden_service_conf(self, default: Any = UNDEFINED) -> Dict[str, Any]: - return self._execute_async_method('get_hidden_service_conf', default) - 
- @_set_doc_from_async_controller - def set_hidden_service_conf(self, conf: Mapping[str, Any]) -> None: - self._execute_async_method('set_hidden_service_conf', conf) - - @_set_doc_from_async_controller - def create_hidden_service(self, path: str, port: int, target_address: Optional[str] = None, target_port: Optional[int] = None, auth_type: Optional[str] = None, client_names: Optional[Sequence[str]] = None) -> 'stem.control.CreateHiddenServiceOutput': - return self._execute_async_method('create_hidden_service', path, port, target_address, target_port, auth_type, client_names) - - @_set_doc_from_async_controller - def remove_hidden_service(self, path: str, port: Optional[int] = None) -> bool: - return self._execute_async_method('remove_hidden_service', path, port) - - @_set_doc_from_async_controller - def list_ephemeral_hidden_services(self, default: Any = UNDEFINED, our_services: bool = True, detached: bool = False) -> Sequence[str]: - return self._execute_async_method('list_ephemeral_hidden_services', default, our_services, detached) - - @_set_doc_from_async_controller - def create_ephemeral_hidden_service(self, ports: Union[int, Sequence[int], Mapping[int, str]], key_type: str = 'NEW', key_content: str = 'BEST', discard_key: bool = False, detached: bool = False, await_publication: bool = False, timeout: Optional[float] = None, basic_auth: Optional[Mapping[str, str]] = None, max_streams: Optional[int] = None) -> stem.response.add_onion.AddOnionResponse: - return self._execute_async_method('create_ephemeral_hidden_service', ports, key_type, key_content, discard_key, detached, await_publication, timeout, basic_auth, max_streams) - - @_set_doc_from_async_controller - def remove_ephemeral_hidden_service(self, service_id: str) -> bool: - return self._execute_async_method('remove_ephemeral_hidden_service', service_id) - - @_set_doc_from_async_controller - def add_event_listener(self, listener: Callable[[stem.response.events.Event], Union[None, Awaitable[None]]], *events: 
'stem.control.EventType') -> None: - self._execute_async_method('add_event_listener', listener, *events) - - @_set_doc_from_async_controller - def remove_event_listener(self, listener: Callable[[stem.response.events.Event], Union[None, Awaitable[None]]]) -> None: - self._execute_async_method('remove_event_listener', listener) - - @_set_doc_from_async_controller - def is_caching_enabled(self) -> bool: - return self._wrapped_instance.is_caching_enabled() - - @_set_doc_from_async_controller - def set_caching(self, enabled: bool) -> None: - self._wrapped_instance.set_caching(enabled) - - @_set_doc_from_async_controller - def clear_cache(self) -> None: - self._wrapped_instance.clear_cache() - - @_set_doc_from_async_controller - def load_conf(self, configtext: str) -> None: - self._execute_async_method('load_conf', configtext) - - @_set_doc_from_async_controller - def save_conf(self, force: bool = False) -> None: - return self._execute_async_method('save_conf', force) - - @_set_doc_from_async_controller - def is_feature_enabled(self, feature: str) -> bool: - return self._wrapped_instance.is_feature_enabled(feature) - - @_set_doc_from_async_controller - def enable_feature(self, features: Union[str, Sequence[str]]) -> None: - self._wrapped_instance.enable_feature(features) - - @_set_doc_from_async_controller - def get_circuit(self, circuit_id: int, default: Any = UNDEFINED) -> stem.response.events.CircuitEvent: - return self._execute_async_method('get_circuit', circuit_id, default) - - @_set_doc_from_async_controller - def get_circuits(self, default: Any = UNDEFINED) -> List[stem.response.events.CircuitEvent]: - return self._execute_async_method('get_circuits', default) - - @_set_doc_from_async_controller - def new_circuit(self, path: Union[None, str, Sequence[str]] = None, purpose: str = 'general', await_build: bool = False, timeout: Optional[float] = None) -> str: - return self._execute_async_method('new_circuit', path, purpose, await_build, timeout) - - 
@_set_doc_from_async_controller - def extend_circuit(self, circuit_id: str = '0', path: Union[None, str, Sequence[str]] = None, purpose: str = 'general', await_build: bool = False, timeout: Optional[float] = None) -> str: - return self._execute_async_method('extend_circuit', circuit_id, path, purpose, await_build, timeout) - - @_set_doc_from_async_controller - def repurpose_circuit(self, circuit_id: str, purpose: str) -> None: - self._execute_async_method('repurpose_circuit', circuit_id, purpose) - - @_set_doc_from_async_controller - def close_circuit(self, circuit_id: str, flag: str = '') -> None: - self._execute_async_method('close_circuit', circuit_id, flag) - - @_set_doc_from_async_controller - def get_streams(self, default: Any = UNDEFINED) -> List[stem.response.events.StreamEvent]: - return self._execute_async_method('get_streams', default) - - @_set_doc_from_async_controller - def attach_stream(self, stream_id: str, circuit_id: str, exiting_hop: Optional[int] = None) -> None: - self._execute_async_method('attach_stream', stream_id, circuit_id, exiting_hop) - - @_set_doc_from_async_controller - def close_stream(self, stream_id: str, reason: stem.RelayEndReason = stem.RelayEndReason.MISC, flag: str = '') -> None: - self._execute_async_method('close_stream', stream_id, reason, flag) - - @_set_doc_from_async_controller - def signal(self, signal: stem.Signal) -> None: - self._execute_async_method('signal', signal) - - @_set_doc_from_async_controller - def is_newnym_available(self) -> bool: - return self._wrapped_instance.is_newnym_available() - - @_set_doc_from_async_controller - def get_newnym_wait(self) -> float: - return self._wrapped_instance.get_newnym_wait() - - @_set_doc_from_async_controller - def get_effective_rate(self, default: Any = UNDEFINED, burst: bool = False) -> int: - return self._execute_async_method('get_effective_rate', default, burst) - - @_set_doc_from_async_controller - def map_address(self, mapping: Mapping[str, str]) -> Dict[str, str]: - 
return self._execute_async_method('map_address', mapping) - - @_set_doc_from_async_controller - def drop_guards(self) -> None: - self._execute_async_method('drop_guards') - - def __enter__(self) -> 'stem.control.Controller': - return self - - def __exit__(self, exit_type: Optional[Type[BaseException]], value: Optional[BaseException], traceback: Optional[TracebackType]) -> None: - self.close() - - def _parse_circ_path(path: str) -> Sequence[Tuple[str, str]]: """ Parses a circuit path as a list of **(fingerprint, nickname)** tuples. Tor @@ -4342,26 +4037,6 @@ async def _get_with_timeout(event_queue: asyncio.Queue, timeout: Optional[float] time_left = None
try: - return await asyncio.wait_for(event_queue.get(), timeout=time_left) + return await asyncio.wait_for(event_queue.get(), timeout = time_left) except asyncio.TimeoutError: raise stem.Timeout('Reached our %0.1f second timeout' % timeout) - - -def _init_control_port(address: str, port: Union[int, str]) -> stem.socket.ControlPort: - import stem.connection - - if not stem.util.connection.is_valid_ipv4_address(address): - raise ValueError('Invalid IP address: %s' % address) - elif port != 'default' and not stem.util.connection.is_valid_port(port): - raise ValueError('Invalid port: %s' % port) - - if port == 'default': - control_port = stem.connection._connection_for_default_port(address) - else: - control_port = stem.socket.ControlPort(address, int(port)) - - return control_port - - -def _init_control_socket_file(path: str) -> stem.socket.ControlSocketFile: - return stem.socket.ControlSocketFile(path) diff --git a/stem/descriptor/remote.py b/stem/descriptor/remote.py index 942d81e9..3428f0d2 100644 --- a/stem/descriptor/remote.py +++ b/stem/descriptor/remote.py @@ -445,13 +445,13 @@ class Query(Synchronous): loop = asyncio.get_running_loop() self._downloader_task = loop.create_task(self._download_descriptors(self.retries, self.timeout))
- async def run(self, suppress: bool = False, close: bool = True) -> List['stem.descriptor.Descriptor']: + async def run(self, suppress: bool = False, stop: bool = True) -> List['stem.descriptor.Descriptor']: """ Blocks until our request is complete then provides the descriptors. If we haven't yet started our request then this does so.
:param suppress: avoids raising exceptions if **True** - :param close: terminates the resources backing this query if **True**, + :param stop: terminates the resources backing this query if **True**, further method calls will raise a RuntimeError
:returns: list for the requested :class:`~stem.descriptor.__init__.Descriptor` instances @@ -465,14 +465,14 @@ class Query(Synchronous): * :class:`~stem.DownloadFailed` if our request fails """
- # TODO: We should replace our 'close' argument with a new API design prior + # TODO: We should replace our 'stop' argument with a new API design prior # to release. Self-destructing this object by default for synchronous users # is quite a step backward, but is acceptable as we iterate on this.
try: return [desc async for desc in self._run(suppress)] finally: - if close: + if stop: self._loop.call_soon_threadsafe(self._loop.stop)
async def _run(self, suppress: bool) -> AsyncIterator[stem.descriptor.Descriptor]: diff --git a/stem/interpreter/__init__.py b/stem/interpreter/__init__.py index 370b9aa6..872e7f6f 100644 --- a/stem/interpreter/__init__.py +++ b/stem/interpreter/__init__.py @@ -127,7 +127,7 @@ def main() -> None: async def handle_event(event_message: stem.response.ControlMessage) -> None: print(format(str(event_message), *STANDARD_OUTPUT))
- controller._wrapped_instance._handle_event = handle_event # type: ignore + controller._handle_event = handle_event # type: ignore
if sys.stdout.isatty(): events = args.run_cmd.upper().split(' ', 1)[1] diff --git a/stem/interpreter/commands.py b/stem/interpreter/commands.py index 99f1219d..b04fcc85 100644 --- a/stem/interpreter/commands.py +++ b/stem/interpreter/commands.py @@ -128,7 +128,7 @@ class ControlInterpreter(code.InteractiveConsole): # Intercept events our controller hears about at a pretty low level since # the user will likely be requesting them by direct 'SETEVENTS' calls.
- handle_event_real = self._controller._wrapped_instance._handle_event + handle_event_real = self._controller._handle_event
async def handle_event_wrapper(event_message: stem.response.ControlMessage) -> None: await handle_event_real(event_message) @@ -139,7 +139,7 @@ class ControlInterpreter(code.InteractiveConsole):
# type check disabled due to https://github.com/python/mypy/issues/708
- self._controller._wrapped_instance._handle_event = handle_event_wrapper # type: ignore + self._controller._handle_event = handle_event_wrapper # type: ignore
def get_events(self, *event_types: stem.control.EventType) -> List[stem.response.events.Event]: events = list(self._received_events) diff --git a/stem/util/__init__.py b/stem/util/__init__.py index d780a0de..de946fd9 100644 --- a/stem/util/__init__.py +++ b/stem/util/__init__.py @@ -13,7 +13,6 @@ import threading import typing import unittest.mock
-from concurrent.futures import Future from types import TracebackType from typing import Any, AsyncIterator, Iterator, Optional, Type, Union
@@ -211,6 +210,7 @@ class Synchronous(object): self._no_op = Synchronous.is_asyncio_context()
if self._no_op: + self._loop = asyncio.get_running_loop() self.__ainit__() # this is already an asyncio context else: # Run coroutines through our loop. This calls methods by name rather than @@ -361,44 +361,3 @@ class Synchronous(object):
def __exit__(self, exit_type: Optional[Type[BaseException]], value: Optional[BaseException], traceback: Optional[TracebackType]): return self._run_async_method('__aexit__', exit_type, value, traceback) - - -class AsyncClassWrapper: - _loop: asyncio.AbstractEventLoop - _loop_thread: threading.Thread - _wrapped_instance: type - - def _init_async_class(self, async_class: Type, *args: Any, **kwargs: Any) -> Any: - # The asynchronous class should be initialized in the thread where - # its methods will be executed. - if self._loop_thread != threading.current_thread(): - async def init(): - return async_class(*args, **kwargs) - - return asyncio.run_coroutine_threadsafe(init(), self._loop).result() - - return async_class(*args, **kwargs) - - def _call_async_method_soon(self, method_name: str, *args: Any, **kwargs: Any) -> Future: - return asyncio.run_coroutine_threadsafe( - getattr(self._wrapped_instance, method_name)(*args, **kwargs), - self._loop, - ) - - def _execute_async_method(self, method_name: str, *args: Any, **kwargs: Any) -> Any: - return self._call_async_method_soon(method_name, *args, **kwargs).result() - - def _execute_async_generator_method(self, method_name: str, *args: Any, **kwargs: Any) -> Iterator: - async def convert_async_generator(generator: AsyncIterator) -> Iterator: - return iter([d async for d in generator]) - - return asyncio.run_coroutine_threadsafe( - convert_async_generator( - getattr(self._wrapped_instance, method_name)(*args, **kwargs), - ), - self._loop, - ).result() - - def __del__(self) -> None: - self._loop.call_soon_threadsafe(self._loop.stop) - self._loop_thread.join() diff --git a/stem/util/test_tools.py b/stem/util/test_tools.py index 67133195..ac9f8b88 100644 --- a/stem/util/test_tools.py +++ b/stem/util/test_tools.py @@ -696,11 +696,14 @@ def async_test(func: Callable) -> Callable: @functools.wraps(func) def wrapper(*args: Any, **kwargs: Any) -> Any: loop = asyncio.new_event_loop() + try: result = 
loop.run_until_complete(func(*args, **kwargs)) finally: loop.close() + return result + return wrapper
diff --git a/test/integ/connection/authentication.py b/test/integ/connection/authentication.py index 042d3939..bbca5e43 100644 --- a/test/integ/connection/authentication.py +++ b/test/integ/connection/authentication.py @@ -3,7 +3,6 @@ Integration tests for authenticating to the control socket via stem.connection.authenticate* functions. """
-import asyncio import os import unittest
@@ -121,11 +120,8 @@ class TestAuthenticate(unittest.TestCase):
runner = test.runner.get_runner()
- with runner.get_tor_controller(False) as controller: - asyncio.run_coroutine_threadsafe( - stem.connection.authenticate(controller._wrapped_instance, test.runner.CONTROL_PASSWORD, runner.get_chroot()), - controller._loop, - ).result() + async with await runner.get_tor_controller(False) as controller: + await stem.connection.authenticate(controller, test.runner.CONTROL_PASSWORD, runner.get_chroot()) await test.runner.exercise_controller(self, controller)
@test.require.controller @@ -276,7 +272,8 @@ class TestAuthenticate(unittest.TestCase): await self._check_auth(auth_type, auth_value)
@test.require.controller - def test_wrong_password_with_controller(self): + @async_test + async def test_wrong_password_with_controller(self): """ We ran into a race condition where providing the wrong password to the Controller caused inconsistent responses. Checking for that... @@ -290,7 +287,7 @@ class TestAuthenticate(unittest.TestCase): self.skipTest('(requires only password auth)')
for i in range(10): - with runner.get_tor_controller(False) as controller: + async with await runner.get_tor_controller(False) as controller: with self.assertRaises(stem.connection.IncorrectPassword): - controller.authenticate('wrong_password') + await controller.authenticate('wrong_password')
diff --git a/test/integ/control/controller.py b/test/integ/control/controller.py index 7853d407..33b62ea7 100644 --- a/test/integ/control/controller.py +++ b/test/integ/control/controller.py @@ -7,7 +7,6 @@ import os import shutil import socket import tempfile -import threading import time import unittest
@@ -38,24 +37,25 @@ TEST_ROUTER_STATUS_ENTRY = None class TestController(unittest.TestCase): @test.require.only_run_once @test.require.controller - def test_missing_capabilities(self): + @async_test + async def test_missing_capabilities(self): """ Check to see if tor supports any events, signals, or features that we don't. """
- with test.runner.get_runner().get_tor_controller() as controller: - for event in controller.get_info('events/names').split(): + async with await test.runner.get_runner().get_tor_controller() as controller: + for event in (await controller.get_info('events/names')).split(): if event not in EventType: test.register_new_capability('Event', event)
- for signal in controller.get_info('signal/names').split(): + for signal in (await controller.get_info('signal/names')).split(): if signal not in Signal: test.register_new_capability('Signal', signal)
# new features should simply be added to enable_feature()'s docs
- for feature in controller.get_info('features/names').split(): + for feature in (await controller.get_info('features/names')).split(): if feature not in ('EXTENDED_EVENTS', 'VERBOSE_NAMES'): test.register_new_capability('Feature', feature)
@@ -88,7 +88,7 @@ class TestController(unittest.TestCase): Checks that a notificiation listener is... well, notified of SIGHUPs. """
- with test.runner.get_runner().get_tor_controller() as controller: + async with await test.runner.get_runner().get_tor_controller() as controller: received_events = []
def status_listener(my_controller, state, timestamp): @@ -97,7 +97,7 @@ class TestController(unittest.TestCase): controller.add_status_listener(status_listener)
before = time.time() - controller.signal(Signal.HUP) + await controller.signal(Signal.HUP)
# I really hate adding a sleep here, but signal() is non-blocking. while len(received_events) == 0: @@ -112,20 +112,21 @@ class TestController(unittest.TestCase):
state_controller, state_type, state_timestamp = received_events[0]
- self.assertEqual(controller._wrapped_instance, state_controller) + self.assertEqual(controller, state_controller) self.assertEqual(State.RESET, state_type) self.assertTrue(state_timestamp > before and state_timestamp < after)
- controller.reset_conf('__OwningControllerProcess') + await controller.reset_conf('__OwningControllerProcess')
@test.require.controller - def test_event_handling(self): + @async_test + async def test_event_handling(self): """ Add a couple listeners for various events and make sure that they receive them. Then remove the listeners. """
- event_notice1, event_notice2 = threading.Event(), threading.Event() + event_notice1, event_notice2 = asyncio.Event(), asyncio.Event() event_buffer1, event_buffer2 = [], []
def listener1(event): @@ -138,30 +139,30 @@ class TestController(unittest.TestCase):
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller: - controller.add_event_listener(listener1, EventType.CONF_CHANGED) - controller.add_event_listener(listener2, EventType.CONF_CHANGED, EventType.DEBUG) + async with await runner.get_tor_controller() as controller: + await controller.add_event_listener(listener1, EventType.CONF_CHANGED) + await controller.add_event_listener(listener2, EventType.CONF_CHANGED, EventType.DEBUG)
# The NodeFamily is a harmless option we can toggle - controller.set_conf('NodeFamily', 'FD4CC275C5AA4D27A487C6CA29097900F85E2C33') + await controller.set_conf('NodeFamily', 'FD4CC275C5AA4D27A487C6CA29097900F85E2C33')
# Wait for the event. Assert that we get it within 10 seconds - event_notice1.wait(10) + await asyncio.wait_for(event_notice1.wait(), timeout = 10) self.assertEqual(len(event_buffer1), 1) event_notice1.clear()
- event_notice2.wait(10) + await asyncio.wait_for(event_notice2.wait(), timeout = 10) self.assertTrue(len(event_buffer2) >= 1) event_notice2.clear()
# Checking that a listener's no longer called after being removed.
- controller.remove_event_listener(listener2) + await controller.remove_event_listener(listener2)
buffer2_size = len(event_buffer2)
- controller.set_conf('NodeFamily', 'A82F7EFDB570F6BC801805D0328D30A99403C401') - event_notice1.wait(10) + await controller.set_conf('NodeFamily', 'A82F7EFDB570F6BC801805D0328D30A99403C401') + await asyncio.wait_for(event_notice1.wait(), timeout = 10) self.assertEqual(len(event_buffer1), 2) event_notice1.clear()
@@ -174,16 +175,17 @@ class TestController(unittest.TestCase):
self.assertTrue(isinstance(event, stem.response.events.ConfChangedEvent))
- controller.reset_conf('NodeFamily') + await controller.reset_conf('NodeFamily')
@test.require.controller - def test_reattaching_listeners(self): + @async_test + async def test_reattaching_listeners(self): """ Checks that event listeners are re-attached when a controller disconnects then reconnects to tor. """
- event_notice = threading.Event() + event_notice = asyncio.Event() event_buffer = []
def listener(event): @@ -192,79 +194,85 @@ class TestController(unittest.TestCase):
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller: - controller.add_event_listener(listener, EventType.CONF_CHANGED) + async with await runner.get_tor_controller() as controller: + await controller.add_event_listener(listener, EventType.CONF_CHANGED)
# trigger an event
- controller.set_conf('NodeFamily', 'FD4CC275C5AA4D27A487C6CA29097900F85E2C33') - event_notice.wait(4) + await controller.set_conf('NodeFamily', 'FD4CC275C5AA4D27A487C6CA29097900F85E2C33') + await asyncio.wait_for(event_notice.wait(), timeout = 4) self.assertTrue(len(event_buffer) >= 1)
# disconnect, then reconnect and check that we get events again
- controller.close() + await controller.close() event_notice.clear() event_buffer = []
- controller.connect() - controller.authenticate(password = test.runner.CONTROL_PASSWORD) + await controller.connect() + await controller.authenticate(password = test.runner.CONTROL_PASSWORD) self.assertTrue(len(event_buffer) == 0) - controller.set_conf('NodeFamily', 'A82F7EFDB570F6BC801805D0328D30A99403C401') + await controller.set_conf('NodeFamily', 'A82F7EFDB570F6BC801805D0328D30A99403C401')
- event_notice.wait(4) + await asyncio.wait_for(event_notice.wait(), timeout = 4) self.assertTrue(len(event_buffer) >= 1)
- controller.reset_conf('NodeFamily') + await controller.reset_conf('NodeFamily')
@test.require.controller - def test_getinfo(self): + @async_test + async def test_getinfo(self): """ Exercises GETINFO with valid and invalid queries. """
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller: + async with await runner.get_tor_controller() as controller: # successful single query
torrc_path = runner.get_torrc_path() - self.assertEqual(torrc_path, controller.get_info('config-file')) - self.assertEqual(torrc_path, controller.get_info('config-file', 'ho hum')) + self.assertEqual(torrc_path, await controller.get_info('config-file')) + self.assertEqual(torrc_path, await controller.get_info('config-file', 'ho hum'))
expected = {'config-file': torrc_path} - self.assertEqual(expected, controller.get_info(['config-file'])) - self.assertEqual(expected, controller.get_info(['config-file'], 'ho hum')) + self.assertEqual(expected, await controller.get_info(['config-file'])) + self.assertEqual(expected, await controller.get_info(['config-file'], 'ho hum'))
# successful batch query, we don't know the values so just checking for # the keys
getinfo_params = set(['version', 'config-file', 'config/names']) - self.assertEqual(getinfo_params, set(controller.get_info(['version', 'config-file', 'config/names']).keys())) + self.assertEqual(getinfo_params, set((await controller.get_info(['version', 'config-file', 'config/names'])).keys()))
# non-existant option
- self.assertRaises(stem.ControllerError, controller.get_info, 'blarg') - self.assertEqual('ho hum', controller.get_info('blarg', 'ho hum')) + with self.assertRaises(stem.ControllerError): + await controller.get_info('blarg') + + self.assertEqual('ho hum', await controller.get_info('blarg', 'ho hum'))
# empty input
- self.assertRaises(stem.ControllerError, controller.get_info, '') - self.assertEqual('ho hum', controller.get_info('', 'ho hum')) + with self.assertRaises(stem.ControllerError): + await controller.get_info('') + + self.assertEqual('ho hum', await controller.get_info('', 'ho hum'))
- self.assertEqual({}, controller.get_info([])) - self.assertEqual({}, controller.get_info([], {})) + self.assertEqual({}, await controller.get_info([])) + self.assertEqual({}, await controller.get_info([], {}))
@test.require.controller - def test_getinfo_freshrelaydescs(self): + @async_test + async def test_getinfo_freshrelaydescs(self): """ Exercises 'GETINFO status/fresh-relay-descs'. """
- with test.runner.get_runner().get_tor_controller() as controller: - response = controller.get_info('status/fresh-relay-descs') + async with await test.runner.get_runner().get_tor_controller() as controller: + response = await controller.get_info('status/fresh-relay-descs') div = response.find('\nextra-info ') - nickname = controller.get_conf('Nickname') + nickname = await controller.get_conf('Nickname')
if div == -1: self.fail('GETINFO response should have both a server and extrainfo descriptor:\n%s' % response) @@ -274,44 +282,47 @@ class TestController(unittest.TestCase):
self.assertEqual(nickname, server_desc.nickname) self.assertEqual(nickname, extrainfo_desc.nickname) - self.assertEqual(controller.get_info('address'), server_desc.address) + self.assertEqual(await controller.get_info('address'), server_desc.address) self.assertEqual(test.runner.ORPORT, server_desc.or_port)
@test.require.controller @test.require.online - def test_getinfo_dir_status(self): + @async_test + async def test_getinfo_dir_status(self): """ Exercise 'GETINFO dir/status-vote/*'. """
- with test.runner.get_runner().get_tor_controller() as controller: - consensus = controller.get_info('dir/status-vote/current/consensus') + async with await test.runner.get_runner().get_tor_controller() as controller: + consensus = await controller.get_info('dir/status-vote/current/consensus') self.assertTrue('moria1' in consensus, 'moria1 not found in the consensus')
if test.tor_version() >= stem.version.Version('0.4.3.1-alpha'): - microdescs = controller.get_info('dir/status-vote/current/consensus-microdesc') + microdescs = await controller.get_info('dir/status-vote/current/consensus-microdesc') self.assertTrue('moria1' in microdescs, 'moria1 not found in the microdescriptor consensus')
@test.require.controller - def test_get_version(self): + @async_test + async def test_get_version(self): """ Test that the convenient method get_version() works. """
- with test.runner.get_runner().get_tor_controller() as controller: - version = controller.get_version() + async with await test.runner.get_runner().get_tor_controller() as controller: + version = await controller.get_version() self.assertTrue(isinstance(version, stem.version.Version)) self.assertEqual(version, test.tor_version())
@test.require.controller - def test_get_exit_policy(self): + @async_test + async def test_get_exit_policy(self): """ Sanity test for get_exit_policy(). Our 'ExitRelay 0' torrc entry causes us to have a simple reject-all policy. """
- with test.runner.get_runner().get_tor_controller() as controller: - self.assertEqual(ExitPolicy('reject *:*'), controller.get_exit_policy()) + async with await test.runner.get_runner().get_tor_controller() as controller: + self.assertEqual(ExitPolicy('reject *:*'), await controller.get_exit_policy())
@test.require.controller @async_test @@ -322,20 +333,21 @@ class TestController(unittest.TestCase):
runner = test.runner.get_runner()
- with runner.get_tor_controller(False) as controller: - controller.authenticate(test.runner.CONTROL_PASSWORD) + async with await runner.get_tor_controller(False) as controller: + await controller.authenticate(test.runner.CONTROL_PASSWORD) await test.runner.exercise_controller(self, controller)
@test.require.controller - def test_protocolinfo(self): + @async_test + async def test_protocolinfo(self): """ Test that the convenient method protocolinfo() works. """
runner = test.runner.get_runner()
- with runner.get_tor_controller(False) as controller: - protocolinfo = controller.get_protocolinfo() + async with await runner.get_tor_controller(False) as controller: + protocolinfo = await controller.get_protocolinfo() self.assertTrue(isinstance(protocolinfo, stem.response.protocolinfo.ProtocolInfoResponse))
# Doing a sanity test on the ProtocolInfoResponse instance returned. @@ -355,14 +367,15 @@ class TestController(unittest.TestCase): self.assertEqual(tuple(auth_methods), protocolinfo.auth_methods)
@test.require.controller - def test_getconf(self): + @async_test + async def test_getconf(self): """ Exercises GETCONF with valid and invalid queries. """
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller: + async with await runner.get_tor_controller() as controller: control_socket = controller.get_socket()
if isinstance(control_socket, stem.socket.ControlPort): @@ -373,79 +386,89 @@ class TestController(unittest.TestCase): config_key = 'ControlSocket'
# successful single query - self.assertEqual(connection_value, controller.get_conf(config_key)) - self.assertEqual(connection_value, controller.get_conf(config_key, 'la-di-dah')) + self.assertEqual(connection_value, await controller.get_conf(config_key)) + self.assertEqual(connection_value, await controller.get_conf(config_key, 'la-di-dah'))
# succeessful batch query expected = {config_key: [connection_value]} - self.assertEqual(expected, controller.get_conf_map([config_key])) - self.assertEqual(expected, controller.get_conf_map([config_key], 'la-di-dah')) + self.assertEqual(expected, await controller.get_conf_map([config_key])) + self.assertEqual(expected, await controller.get_conf_map([config_key], 'la-di-dah'))
request_params = ['ControlPORT', 'dirport', 'datadirectory'] - reply_params = controller.get_conf_map(request_params, multiple=False).keys() + reply_params = (await controller.get_conf_map(request_params, multiple=False)).keys() self.assertEqual(set(request_params), set(reply_params))
# queries an option that is unset
- self.assertEqual(None, controller.get_conf('HTTPSProxy')) - self.assertEqual('la-di-dah', controller.get_conf('HTTPSProxy', 'la-di-dah')) - self.assertEqual([], controller.get_conf('HTTPSProxy', [], multiple = True)) + self.assertEqual(None, await controller.get_conf('HTTPSProxy')) + self.assertEqual('la-di-dah', await controller.get_conf('HTTPSProxy', 'la-di-dah')) + self.assertEqual([], await controller.get_conf('HTTPSProxy', [], multiple = True))
# non-existant option(s) - self.assertRaises(stem.InvalidArguments, controller.get_conf, 'blarg') - self.assertEqual('la-di-dah', controller.get_conf('blarg', 'la-di-dah')) - self.assertRaises(stem.InvalidArguments, controller.get_conf_map, 'blarg') - self.assertEqual({'blarg': 'la-di-dah'}, controller.get_conf_map('blarg', 'la-di-dah'))
- self.assertRaises(stem.InvalidRequest, controller.get_conf_map, ['blarg', 'huadf'], multiple = True) - self.assertEqual({'erfusdj': 'la-di-dah', 'afiafj': 'la-di-dah'}, controller.get_conf_map(['erfusdj', 'afiafj'], 'la-di-dah', multiple = True)) + with self.assertRaises(stem.InvalidArguments): + await controller.get_conf('blarg') + + self.assertEqual('la-di-dah', await controller.get_conf('blarg', 'la-di-dah')) + + with self.assertRaises(stem.InvalidArguments): + await controller.get_conf_map('blarg') + + self.assertEqual({'blarg': 'la-di-dah'}, await controller.get_conf_map('blarg', 'la-di-dah')) + + with self.assertRaises(stem.InvalidRequest): + await controller.get_conf_map(['blarg', 'huadf'], multiple = True) + + self.assertEqual({'erfusdj': 'la-di-dah', 'afiafj': 'la-di-dah'}, await controller.get_conf_map(['erfusdj', 'afiafj'], 'la-di-dah', multiple = True))
# multivalue configuration keys nodefamilies = [('abc', 'xyz', 'pqrs'), ('mno', 'tuv', 'wxyz')] - controller.msg('SETCONF %s' % ' '.join(['nodefamily="' + ','.join(x) + '"' for x in nodefamilies])) - self.assertEqual([','.join(n) for n in nodefamilies], controller.get_conf('nodefamily', multiple = True)) - controller.msg('RESETCONF NodeFamily') + await controller.msg('SETCONF %s' % ' '.join(['nodefamily="' + ','.join(x) + '"' for x in nodefamilies])) + self.assertEqual([','.join(n) for n in nodefamilies], await controller.get_conf('nodefamily', multiple = True)) + await controller.msg('RESETCONF NodeFamily')
# empty input - self.assertEqual(None, controller.get_conf('')) - self.assertEqual({}, controller.get_conf_map([])) - self.assertEqual({}, controller.get_conf_map([''])) - self.assertEqual(None, controller.get_conf(' ')) - self.assertEqual({}, controller.get_conf_map([' ', ' '])) + self.assertEqual(None, await controller.get_conf('')) + self.assertEqual({}, await controller.get_conf_map([])) + self.assertEqual({}, await controller.get_conf_map([''])) + self.assertEqual(None, await controller.get_conf(' ')) + self.assertEqual({}, await controller.get_conf_map([' ', ' ']))
- self.assertEqual('la-di-dah', controller.get_conf('', 'la-di-dah')) - self.assertEqual({}, controller.get_conf_map('', 'la-di-dah')) - self.assertEqual({}, controller.get_conf_map([], 'la-di-dah')) + self.assertEqual('la-di-dah', await controller.get_conf('', 'la-di-dah')) + self.assertEqual({}, await controller.get_conf_map('', 'la-di-dah')) + self.assertEqual({}, await controller.get_conf_map([], 'la-di-dah'))
@test.require.controller - def test_is_set(self): + @async_test + async def test_is_set(self): """ Exercises our is_set() method. """
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller: - custom_options = controller._execute_async_method('_get_custom_options') + async with await runner.get_tor_controller() as controller: + custom_options = await controller._get_custom_options() self.assertTrue('ControlPort' in custom_options or 'ControlSocket' in custom_options) self.assertEqual('1', custom_options['DownloadExtraInfo']) self.assertEqual('1112', custom_options['SocksPort'])
- self.assertTrue(controller.is_set('DownloadExtraInfo')) - self.assertTrue(controller.is_set('SocksPort')) - self.assertFalse(controller.is_set('CellStatistics')) - self.assertFalse(controller.is_set('ConnLimit')) + self.assertTrue(await controller.is_set('DownloadExtraInfo')) + self.assertTrue(await controller.is_set('SocksPort')) + self.assertFalse(await controller.is_set('CellStatistics')) + self.assertFalse(await controller.is_set('ConnLimit'))
# check we update when setting and resetting values
- controller.set_conf('ConnLimit', '1005') - self.assertTrue(controller.is_set('ConnLimit')) + await controller.set_conf('ConnLimit', '1005') + self.assertTrue(await controller.is_set('ConnLimit'))
- controller.reset_conf('ConnLimit') - self.assertFalse(controller.is_set('ConnLimit')) + await controller.reset_conf('ConnLimit') + self.assertFalse(await controller.is_set('ConnLimit'))
@test.require.controller - def test_hidden_services_conf(self): + @async_test + async def test_hidden_services_conf(self): """ Exercises the hidden service family of methods (get_hidden_service_conf, set_hidden_service_conf, create_hidden_service, and remove_hidden_service). @@ -459,16 +482,16 @@ class TestController(unittest.TestCase): service3_path = os.path.join(test_dir, 'test_hidden_service3') service4_path = os.path.join(test_dir, 'test_hidden_service4')
- with runner.get_tor_controller() as controller: + async with await runner.get_tor_controller() as controller: try: # initially we shouldn't be running any hidden services
- self.assertEqual({}, controller.get_hidden_service_conf()) + self.assertEqual({}, await controller.get_hidden_service_conf())
# try setting a blank config, shouldn't have any impact
- controller.set_hidden_service_conf({}) - self.assertEqual({}, controller.get_hidden_service_conf()) + await controller.set_hidden_service_conf({}) + self.assertEqual({}, await controller.get_hidden_service_conf())
# create a hidden service
@@ -491,58 +514,58 @@ class TestController(unittest.TestCase): }, }
- controller.set_hidden_service_conf(initialconf) - self.assertEqual(initialconf, controller.get_hidden_service_conf()) + await controller.set_hidden_service_conf(initialconf) + self.assertEqual(initialconf, await controller.get_hidden_service_conf())
# add already existing services, with/without explicit target
- self.assertEqual(None, controller.create_hidden_service(service1_path, 8020)) - self.assertEqual(None, controller.create_hidden_service(service1_path, 8021, target_port = 8021)) - self.assertEqual(initialconf, controller.get_hidden_service_conf()) + self.assertEqual(None, await controller.create_hidden_service(service1_path, 8020)) + self.assertEqual(None, await controller.create_hidden_service(service1_path, 8021, target_port = 8021)) + self.assertEqual(initialconf, await controller.get_hidden_service_conf())
# add a new service, with/without explicit target
hs_path = os.path.join(os.getcwd(), service3_path) - hs_address1 = controller.create_hidden_service(hs_path, 8888).hostname - hs_address2 = controller.create_hidden_service(hs_path, 8989, target_port = 8021).hostname + hs_address1 = (await controller.create_hidden_service(hs_path, 8888)).hostname + hs_address2 = (await controller.create_hidden_service(hs_path, 8989, target_port = 8021)).hostname
self.assertEqual(hs_address1, hs_address2) self.assertTrue(hs_address1.endswith('.onion'))
- conf = controller.get_hidden_service_conf() + conf = await controller.get_hidden_service_conf() self.assertEqual(3, len(conf)) self.assertEqual(2, len(conf[hs_path]['HiddenServicePort']))
# remove a hidden service, the service dir should still be there
- controller.remove_hidden_service(hs_path, 8888) - self.assertEqual(3, len(controller.get_hidden_service_conf())) + await controller.remove_hidden_service(hs_path, 8888) + self.assertEqual(3, len(await controller.get_hidden_service_conf()))
# remove a service completely, it should now be gone
- controller.remove_hidden_service(hs_path, 8989) - self.assertEqual(2, len(controller.get_hidden_service_conf())) + await controller.remove_hidden_service(hs_path, 8989) + self.assertEqual(2, len(await controller.get_hidden_service_conf()))
# add a new service, this time with client authentication
hs_path = os.path.join(os.getcwd(), service4_path) - hs_attributes = controller.create_hidden_service(hs_path, 8888, auth_type = 'basic', client_names = ['c1', 'c2']) + hs_attributes = await controller.create_hidden_service(hs_path, 8888, auth_type = 'basic', client_names = ['c1', 'c2'])
self.assertEqual(2, len(hs_attributes.hostname.splitlines())) self.assertEqual(2, len(hs_attributes.hostname_for_client)) self.assertTrue(hs_attributes.hostname_for_client['c1'].endswith('.onion')) self.assertTrue(hs_attributes.hostname_for_client['c2'].endswith('.onion'))
- conf = controller.get_hidden_service_conf() + conf = await controller.get_hidden_service_conf() self.assertEqual(3, len(conf)) self.assertEqual(1, len(conf[hs_path]['HiddenServicePort']))
# remove a hidden service
- controller.remove_hidden_service(hs_path, 8888) - self.assertEqual(2, len(controller.get_hidden_service_conf())) + await controller.remove_hidden_service(hs_path, 8888) + self.assertEqual(2, len(await controller.get_hidden_service_conf())) finally: - controller.set_hidden_service_conf({}) # drop hidden services created during the test + await controller.set_hidden_service_conf({}) # drop hidden services created during the test
# clean up the hidden service directories created as part of this test
@@ -553,47 +576,50 @@ class TestController(unittest.TestCase): pass
@test.require.controller - def test_without_ephemeral_hidden_services(self): + @async_test + async def test_without_ephemeral_hidden_services(self): """ Exercises ephemeral hidden service methods when none are present. """
- with test.runner.get_runner().get_tor_controller() as controller: - self.assertEqual([], controller.list_ephemeral_hidden_services()) - self.assertEqual([], controller.list_ephemeral_hidden_services(detached = True)) - self.assertEqual(False, controller.remove_ephemeral_hidden_service('gfzprpioee3hoppz')) + async with await test.runner.get_runner().get_tor_controller() as controller: + self.assertEqual([], await controller.list_ephemeral_hidden_services()) + self.assertEqual([], await controller.list_ephemeral_hidden_services(detached = True)) + self.assertEqual(False, await controller.remove_ephemeral_hidden_service('gfzprpioee3hoppz'))
@test.require.controller - def test_with_invalid_ephemeral_hidden_service_port(self): - with test.runner.get_runner().get_tor_controller() as controller: + @async_test + async def test_with_invalid_ephemeral_hidden_service_port(self): + async with await test.runner.get_runner().get_tor_controller() as controller: for ports in (4567890, [4567, 4567890], {4567: '-:4567'}): - exc_msg = "ADD_ONION response didn't have an OK status: Invalid VIRTPORT/TARGET" - self.assertRaisesWith(stem.ProtocolError, exc_msg, controller.create_ephemeral_hidden_service, ports) + with self.assertRaisesWith(stem.ProtocolError, "ADD_ONION response didn't have an OK status: Invalid VIRTPORT/TARGET"): + await controller.create_ephemeral_hidden_service(ports)
@test.require.controller - def test_ephemeral_hidden_services_v2(self): + @async_test + async def test_ephemeral_hidden_services_v2(self): """ Exercises creating v2 ephemeral hidden services. """
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller: - response = controller.create_ephemeral_hidden_service(4567, key_content = 'RSA1024') - self.assertEqual([response.service_id], controller.list_ephemeral_hidden_services()) + async with await runner.get_tor_controller() as controller: + response = await controller.create_ephemeral_hidden_service(4567, key_content = 'RSA1024') + self.assertEqual([response.service_id], await controller.list_ephemeral_hidden_services()) self.assertTrue(response.private_key is not None) self.assertEqual('RSA1024', response.private_key_type) self.assertEqual({}, response.client_auth)
# drop the service
- self.assertEqual(True, controller.remove_ephemeral_hidden_service(response.service_id)) - self.assertEqual([], controller.list_ephemeral_hidden_services()) + self.assertEqual(True, await controller.remove_ephemeral_hidden_service(response.service_id)) + self.assertEqual([], await controller.list_ephemeral_hidden_services())
# recreate the service with the same private key
- recreate_response = controller.create_ephemeral_hidden_service(4567, key_type = response.private_key_type, key_content = response.private_key) - self.assertEqual([response.service_id], controller.list_ephemeral_hidden_services()) + recreate_response = await controller.create_ephemeral_hidden_service(4567, key_type = response.private_key_type, key_content = response.private_key) + self.assertEqual([response.service_id], await controller.list_ephemeral_hidden_services()) self.assertEqual(response.service_id, recreate_response.service_id)
# the response only includes the private key when making a new one @@ -603,41 +629,42 @@ class TestController(unittest.TestCase):
# create a service where we never see the private key
- response = controller.create_ephemeral_hidden_service(4568, key_content = 'RSA1024', discard_key = True) - self.assertTrue(response.service_id in controller.list_ephemeral_hidden_services()) + response = await controller.create_ephemeral_hidden_service(4568, key_content = 'RSA1024', discard_key = True) + self.assertTrue(response.service_id in await controller.list_ephemeral_hidden_services()) self.assertEqual(None, response.private_key) self.assertEqual(None, response.private_key_type)
# other controllers shouldn't be able to see these hidden services
- with runner.get_tor_controller() as second_controller: - self.assertEqual(2, len(controller.list_ephemeral_hidden_services())) - self.assertEqual(0, len(second_controller.list_ephemeral_hidden_services())) + async with await runner.get_tor_controller() as second_controller: + self.assertEqual(2, len(await controller.list_ephemeral_hidden_services())) + self.assertEqual(0, len(await second_controller.list_ephemeral_hidden_services()))
@test.require.controller - def test_ephemeral_hidden_services_v3(self): + @async_test + async def test_ephemeral_hidden_services_v3(self): """ Exercises creating v3 ephemeral hidden services. """
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller: - response = controller.create_ephemeral_hidden_service(4567, key_content = 'ED25519-V3') - self.assertEqual([response.service_id], controller.list_ephemeral_hidden_services()) + async with await runner.get_tor_controller() as controller: + response = await controller.create_ephemeral_hidden_service(4567, key_content = 'ED25519-V3') + self.assertEqual([response.service_id], await controller.list_ephemeral_hidden_services()) self.assertTrue(response.private_key is not None) self.assertEqual('ED25519-V3', response.private_key_type) self.assertEqual({}, response.client_auth)
# drop the service
- self.assertEqual(True, controller.remove_ephemeral_hidden_service(response.service_id)) - self.assertEqual([], controller.list_ephemeral_hidden_services()) + self.assertEqual(True, await controller.remove_ephemeral_hidden_service(response.service_id)) + self.assertEqual([], await controller.list_ephemeral_hidden_services())
# recreate the service with the same private key
- recreate_response = controller.create_ephemeral_hidden_service(4567, key_type = response.private_key_type, key_content = response.private_key) - self.assertEqual([response.service_id], controller.list_ephemeral_hidden_services()) + recreate_response = await controller.create_ephemeral_hidden_service(4567, key_type = response.private_key_type, key_content = response.private_key) + self.assertEqual([response.service_id], await controller.list_ephemeral_hidden_services()) self.assertEqual(response.service_id, recreate_response.service_id)
# the response only includes the private key when making a new one @@ -647,38 +674,40 @@ class TestController(unittest.TestCase):
# create a service where we never see the private key
- response = controller.create_ephemeral_hidden_service(4568, key_content = 'ED25519-V3', discard_key = True) - self.assertTrue(response.service_id in controller.list_ephemeral_hidden_services()) + response = await controller.create_ephemeral_hidden_service(4568, key_content = 'ED25519-V3', discard_key = True) + self.assertTrue(response.service_id in await controller.list_ephemeral_hidden_services()) self.assertEqual(None, response.private_key) self.assertEqual(None, response.private_key_type)
# other controllers shouldn't be able to see these hidden services
- with runner.get_tor_controller() as second_controller: - self.assertEqual(2, len(controller.list_ephemeral_hidden_services())) - self.assertEqual(0, len(second_controller.list_ephemeral_hidden_services())) + async with await runner.get_tor_controller() as second_controller: + self.assertEqual(2, len(await controller.list_ephemeral_hidden_services())) + self.assertEqual(0, len(await second_controller.list_ephemeral_hidden_services()))
@test.require.controller - def test_with_ephemeral_hidden_services_basic_auth(self): + @async_test + async def test_with_ephemeral_hidden_services_basic_auth(self): """ Exercises creating ephemeral hidden services that uses basic authentication. """
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller: - response = controller.create_ephemeral_hidden_service(4567, key_content = 'RSA1024', basic_auth = {'alice': 'nKwfvVPmTNr2k2pG0pzV4g', 'bob': None}) - self.assertEqual([response.service_id], controller.list_ephemeral_hidden_services()) + async with await runner.get_tor_controller() as controller: + response = await controller.create_ephemeral_hidden_service(4567, key_content = 'RSA1024', basic_auth = {'alice': 'nKwfvVPmTNr2k2pG0pzV4g', 'bob': None}) + self.assertEqual([response.service_id], await controller.list_ephemeral_hidden_services()) self.assertTrue(response.private_key is not None) self.assertEqual(['bob'], list(response.client_auth.keys())) # newly created credentials were only created for bob
# drop the service
- self.assertEqual(True, controller.remove_ephemeral_hidden_service(response.service_id)) - self.assertEqual([], controller.list_ephemeral_hidden_services()) + self.assertEqual(True, await controller.remove_ephemeral_hidden_service(response.service_id)) + self.assertEqual([], await controller.list_ephemeral_hidden_services())
@test.require.controller - def test_with_ephemeral_hidden_services_basic_auth_no_credentials(self): + @async_test + async def test_with_ephemeral_hidden_services_basic_auth_no_credentials(self): """ Exercises creating ephemeral hidden services when attempting to use basic auth but not including any credentials. @@ -686,12 +715,13 @@ class TestController(unittest.TestCase):
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller: - exc_msg = "ADD_ONION response didn't have an OK status: No auth clients specified" - self.assertRaisesWith(stem.ProtocolError, exc_msg, controller.create_ephemeral_hidden_service, 4567, basic_auth = {}) + async with await runner.get_tor_controller() as controller: + with self.assertRaisesWith(stem.ProtocolError, "ADD_ONION response didn't have an OK status: No auth clients specified"): + await controller.create_ephemeral_hidden_service(4567, basic_auth = {})
@test.require.controller - def test_with_detached_ephemeral_hidden_services(self): + @async_test + async def test_with_detached_ephemeral_hidden_services(self): """ Exercises creating detached ephemeral hidden services and methods when they're present. @@ -699,34 +729,35 @@ class TestController(unittest.TestCase):
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller: - response = controller.create_ephemeral_hidden_service(4567, detached = True) - self.assertEqual([], controller.list_ephemeral_hidden_services()) - self.assertEqual([response.service_id], controller.list_ephemeral_hidden_services(detached = True)) + async with await runner.get_tor_controller() as controller: + response = await controller.create_ephemeral_hidden_service(4567, detached = True) + self.assertEqual([], await controller.list_ephemeral_hidden_services()) + self.assertEqual([response.service_id], await controller.list_ephemeral_hidden_services(detached = True))
# drop and recreate the service
- self.assertEqual(True, controller.remove_ephemeral_hidden_service(response.service_id)) - self.assertEqual([], controller.list_ephemeral_hidden_services(detached = True)) - controller.create_ephemeral_hidden_service(4567, key_type = response.private_key_type, key_content = response.private_key, detached = True) - self.assertEqual([response.service_id], controller.list_ephemeral_hidden_services(detached = True)) + self.assertEqual(True, await controller.remove_ephemeral_hidden_service(response.service_id)) + self.assertEqual([], await controller.list_ephemeral_hidden_services(detached = True)) + await controller.create_ephemeral_hidden_service(4567, key_type = response.private_key_type, key_content = response.private_key, detached = True) + self.assertEqual([response.service_id], await controller.list_ephemeral_hidden_services(detached = True))
# other controllers should be able to see this service, and drop it
- with runner.get_tor_controller() as second_controller: - self.assertEqual([response.service_id], second_controller.list_ephemeral_hidden_services(detached = True)) - self.assertEqual(True, second_controller.remove_ephemeral_hidden_service(response.service_id)) - self.assertEqual([], controller.list_ephemeral_hidden_services(detached = True)) + async with await runner.get_tor_controller() as second_controller: + self.assertEqual([response.service_id], await second_controller.list_ephemeral_hidden_services(detached = True)) + self.assertEqual(True, await second_controller.remove_ephemeral_hidden_service(response.service_id)) + self.assertEqual([], await controller.list_ephemeral_hidden_services(detached = True))
# recreate the service and confirms that it outlives this controller
- response = second_controller.create_ephemeral_hidden_service(4567, detached = True) + response = await second_controller.create_ephemeral_hidden_service(4567, detached = True)
- self.assertEqual([response.service_id], controller.list_ephemeral_hidden_services(detached = True)) - controller.remove_ephemeral_hidden_service(response.service_id) + self.assertEqual([response.service_id], await controller.list_ephemeral_hidden_services(detached = True)) + await controller.remove_ephemeral_hidden_service(response.service_id)
@test.require.controller - def test_rejecting_unanonymous_hidden_services_creation(self): + @async_test + async def test_rejecting_unanonymous_hidden_services_creation(self): """ Attempt to create a non-anonymous hidden service despite not setting HiddenServiceSingleHopMode and HiddenServiceNonAnonymousMode. @@ -734,11 +765,12 @@ class TestController(unittest.TestCase):
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller: - self.assertEqual('Tor is in anonymous hidden service mode', str(controller.msg('ADD_ONION NEW:BEST Flags=NonAnonymous Port=4567'))) + async with await runner.get_tor_controller() as controller: + self.assertEqual('Tor is in anonymous hidden service mode', str(await controller.msg('ADD_ONION NEW:BEST Flags=NonAnonymous Port=4567')))
@test.require.controller - def test_set_conf(self): + @async_test + async def test_set_conf(self): """ Exercises set_conf(), reset_conf(), and set_options() methods with valid and invalid requests. @@ -748,42 +780,42 @@ class TestController(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdir:
- with runner.get_tor_controller() as controller: + async with await runner.get_tor_controller() as controller: try: # successfully set a single option - connlimit = int(controller.get_conf('ConnLimit')) - controller.set_conf('connlimit', str(connlimit - 1)) - self.assertEqual(connlimit - 1, int(controller.get_conf('ConnLimit'))) + connlimit = int(await controller.get_conf('ConnLimit')) + await controller.set_conf('connlimit', str(connlimit - 1)) + self.assertEqual(connlimit - 1, int(await controller.get_conf('ConnLimit')))
# successfully set a single list option exit_policy = ['accept *:7777', 'reject *:*'] - controller.set_conf('ExitPolicy', exit_policy) - self.assertEqual(exit_policy, controller.get_conf('ExitPolicy', multiple = True)) + await controller.set_conf('ExitPolicy', exit_policy) + self.assertEqual(exit_policy, await controller.get_conf('ExitPolicy', multiple = True))
# fail to set a single option try: - controller.set_conf('invalidkeyboo', 'abcde') + await controller.set_conf('invalidkeyboo', 'abcde') self.fail() except stem.InvalidArguments as exc: self.assertEqual(['invalidkeyboo'], exc.arguments)
# resets configuration parameters - controller.reset_conf('ConnLimit', 'ExitPolicy') - self.assertEqual(connlimit, int(controller.get_conf('ConnLimit'))) - self.assertEqual(None, controller.get_conf('ExitPolicy')) + await controller.reset_conf('ConnLimit', 'ExitPolicy') + self.assertEqual(connlimit, int(await controller.get_conf('ConnLimit'))) + self.assertEqual(None, await controller.get_conf('ExitPolicy'))
# successfully sets multiple config options - controller.set_options({ + await controller.set_options({ 'connlimit': str(connlimit - 2), 'contactinfo': 'stem@testing', })
- self.assertEqual(connlimit - 2, int(controller.get_conf('ConnLimit'))) - self.assertEqual('stem@testing', controller.get_conf('contactinfo')) + self.assertEqual(connlimit - 2, int(await controller.get_conf('ConnLimit'))) + self.assertEqual('stem@testing', await controller.get_conf('contactinfo'))
# fail to set multiple config options try: - controller.set_options({ + await controller.set_options({ 'contactinfo': 'stem@testing', 'bombay': 'vadapav', }) @@ -792,17 +824,17 @@ class TestController(unittest.TestCase): self.assertEqual(['bombay'], exc.arguments)
# context-sensitive keys (the only retched things for which order matters) - controller.set_options(( + await controller.set_options(( ('HiddenServiceDir', tmpdir), ('HiddenServicePort', '17234 127.0.0.1:17235'), ))
- self.assertEqual(tmpdir, controller.get_conf('HiddenServiceDir')) - self.assertEqual('17234 127.0.0.1:17235', controller.get_conf('HiddenServicePort')) + self.assertEqual(tmpdir, await controller.get_conf('HiddenServiceDir')) + self.assertEqual('17234 127.0.0.1:17235', await controller.get_conf('HiddenServicePort')) finally: # reverts configuration changes
- controller.set_options(( + await controller.set_options(( ('ExitPolicy', 'reject *:*'), ('ConnLimit', None), ('ContactInfo', None), @@ -811,47 +843,53 @@ class TestController(unittest.TestCase): ), reset = True)
@test.require.controller - def test_set_conf_for_usebridges(self): + @async_test + async def test_set_conf_for_usebridges(self): """ Ensure we can set UseBridges=1 and also set a Bridge. This is a tor regression check (:trac:`31945`). """
- with test.runner.get_runner().get_tor_controller() as controller: - orport = controller.get_conf('ORPort') + async with await test.runner.get_runner().get_tor_controller() as controller: + orport = await controller.get_conf('ORPort')
try: - controller.set_conf('ORPort', '0') # ensure we're not a relay so UseBridges is usabe - controller.set_options([('UseBridges', '1'), ('Bridge', '127.0.0.1:9999')]) - self.assertEqual('127.0.0.1:9999', controller.get_conf('Bridge')) + await controller.set_conf('ORPort', '0') # ensure we're not a relay so UseBridges is usabe + await controller.set_options([('UseBridges', '1'), ('Bridge', '127.0.0.1:9999')]) + self.assertEqual('127.0.0.1:9999', await controller.get_conf('Bridge')) finally: # reverts configuration changes
- controller.set_options(( + await controller.set_options(( ('ORPort', orport), ('UseBridges', None), ('Bridge', None), ), reset = True)
@test.require.controller - def test_set_conf_when_immutable(self): + @async_test + async def test_set_conf_when_immutable(self): """ Issue a SETCONF for tor options that cannot be changed while running. """
- with test.runner.get_runner().get_tor_controller() as controller: - self.assertRaisesWith(stem.InvalidArguments, "DisableAllSwap cannot be changed while tor's running", controller.set_conf, 'DisableAllSwap', '1') - self.assertRaisesWith(stem.InvalidArguments, "DisableAllSwap, User cannot be changed while tor's running", controller.set_options, {'User': 'atagar', 'DisableAllSwap': '1'}) + async with await test.runner.get_runner().get_tor_controller() as controller: + with self.assertRaisesWith(stem.InvalidArguments, "DisableAllSwap cannot be changed while tor's running"): + await controller.set_conf('DisableAllSwap', '1') + + with self.assertRaisesWith(stem.InvalidArguments, "DisableAllSwap, User cannot be changed while tor's running"): + await controller.set_options({'User': 'atagar', 'DisableAllSwap': '1'})
@test.require.controller - def test_loadconf(self): + @async_test + async def test_loadconf(self): """ Exercises Controller.load_conf with valid and invalid requests. """
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller: + async with await runner.get_tor_controller() as controller: oldconf = runner.get_torrc_contents()
try: @@ -863,98 +901,105 @@ class TestController(unittest.TestCase): # ("/home/atagar/Desktop/stem/test/data"->"/home/atagar/.tor") is not # allowed.
- self.assertRaises(stem.InvalidRequest, controller.load_conf, 'ContactInfo confloaded') + with self.assertRaises(stem.InvalidRequest): + await controller.load_conf('ContactInfo confloaded')
try: - controller.load_conf('Blahblah blah') + await controller.load_conf('Blahblah blah') self.fail() except stem.InvalidArguments as exc: self.assertEqual(['Blahblah'], exc.arguments)
# valid config
- controller.load_conf(runner.get_torrc_contents() + '\nContactInfo confloaded\n') - self.assertEqual('confloaded', controller.get_conf('ContactInfo')) + await controller.load_conf(runner.get_torrc_contents() + '\nContactInfo confloaded\n') + self.assertEqual('confloaded', await controller.get_conf('ContactInfo')) finally: # reload original valid config - controller.load_conf(oldconf) - controller.reset_conf('__OwningControllerProcess') + await controller.load_conf(oldconf) + await controller.reset_conf('__OwningControllerProcess')
@test.require.controller - def test_saveconf(self): + @async_test + async def test_saveconf(self): runner = test.runner.get_runner()
# only testing for success, since we need to run out of disk space to test # for failure - with runner.get_tor_controller() as controller: + async with await runner.get_tor_controller() as controller: oldconf = runner.get_torrc_contents()
try: - controller.set_conf('ContactInfo', 'confsaved') - controller.save_conf() + await controller.set_conf('ContactInfo', 'confsaved') + await controller.save_conf()
with open(runner.get_torrc_path()) as torrcfile: self.assertTrue('\nContactInfo confsaved\n' in torrcfile.read()) finally: - controller.load_conf(oldconf) - controller.save_conf() - controller.reset_conf('__OwningControllerProcess') + await controller.load_conf(oldconf) + await controller.save_conf() + await controller.reset_conf('__OwningControllerProcess')
@test.require.controller - def test_get_ports(self): + @async_test + async def test_get_ports(self): """ Test Controller.get_ports against a running tor instance. """
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller: - self.assertEqual([test.runner.ORPORT], controller.get_ports(Listener.OR)) - self.assertEqual([], controller.get_ports(Listener.DIR)) - self.assertEqual([test.runner.SOCKS_PORT], controller.get_ports(Listener.SOCKS)) - self.assertEqual([], controller.get_ports(Listener.TRANS)) - self.assertEqual([], controller.get_ports(Listener.NATD)) - self.assertEqual([], controller.get_ports(Listener.DNS)) + async with await runner.get_tor_controller() as controller: + self.assertEqual([test.runner.ORPORT], await controller.get_ports(Listener.OR)) + self.assertEqual([], await controller.get_ports(Listener.DIR)) + self.assertEqual([test.runner.SOCKS_PORT], await controller.get_ports(Listener.SOCKS)) + self.assertEqual([], await controller.get_ports(Listener.TRANS)) + self.assertEqual([], await controller.get_ports(Listener.NATD)) + self.assertEqual([], await controller.get_ports(Listener.DNS))
if test.runner.Torrc.PORT in runner.get_options(): - self.assertEqual([test.runner.CONTROL_PORT], controller.get_ports(Listener.CONTROL)) + self.assertEqual([test.runner.CONTROL_PORT], await controller.get_ports(Listener.CONTROL)) else: - self.assertEqual([], controller.get_ports(Listener.CONTROL)) + self.assertEqual([], await controller.get_ports(Listener.CONTROL))
@test.require.controller - def test_get_listeners(self): + @async_test + async def test_get_listeners(self): """ Test Controller.get_listeners against a running tor instance. """
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller: - self.assertEqual([('0.0.0.0', test.runner.ORPORT)], controller.get_listeners(Listener.OR)) - self.assertEqual([], controller.get_listeners(Listener.DIR)) - self.assertEqual([('127.0.0.1', test.runner.SOCKS_PORT)], controller.get_listeners(Listener.SOCKS)) - self.assertEqual([], controller.get_listeners(Listener.TRANS)) - self.assertEqual([], controller.get_listeners(Listener.NATD)) - self.assertEqual([], controller.get_listeners(Listener.DNS)) + async with await runner.get_tor_controller() as controller: + self.assertEqual([('0.0.0.0', test.runner.ORPORT)], await controller.get_listeners(Listener.OR)) + self.assertEqual([], await controller.get_listeners(Listener.DIR)) + self.assertEqual([('127.0.0.1', test.runner.SOCKS_PORT)], await controller.get_listeners(Listener.SOCKS)) + self.assertEqual([], await controller.get_listeners(Listener.TRANS)) + self.assertEqual([], await controller.get_listeners(Listener.NATD)) + self.assertEqual([], await controller.get_listeners(Listener.DNS))
if test.runner.Torrc.PORT in runner.get_options(): - self.assertEqual([('127.0.0.1', test.runner.CONTROL_PORT)], controller.get_listeners(Listener.CONTROL)) + self.assertEqual([('127.0.0.1', test.runner.CONTROL_PORT)], await controller.get_listeners(Listener.CONTROL)) else: - self.assertEqual([], controller.get_listeners(Listener.CONTROL)) + self.assertEqual([], await controller.get_listeners(Listener.CONTROL))
@test.require.controller @test.require.online @test.require.version(stem.version.Version('0.1.2.2-alpha')) - def test_enable_feature(self): + @async_test + async def test_enable_feature(self): """ Test Controller.enable_feature with valid and invalid inputs. """
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller: + async with await runner.get_tor_controller() as controller: self.assertTrue(controller.is_feature_enabled('VERBOSE_NAMES')) - self.assertRaises(stem.InvalidArguments, controller.enable_feature, ['NOT', 'A', 'FEATURE']) + + with self.assertRaises(stem.InvalidArguments): + await controller.enable_feature(['NOT', 'A', 'FEATURE'])
try: controller.enable_feature(['NOT', 'A', 'FEATURE']) @@ -964,58 +1009,70 @@ class TestController(unittest.TestCase): self.fail()
@test.require.controller - def test_signal(self): + @async_test + async def test_signal(self): """ Test controller.signal with valid and invalid signals. """
- with test.runner.get_runner().get_tor_controller() as controller: + async with await test.runner.get_runner().get_tor_controller() as controller: # valid signal - controller.signal('CLEARDNSCACHE') + await controller.signal('CLEARDNSCACHE')
# invalid signals - self.assertRaises(stem.InvalidArguments, controller.signal, 'FOOBAR') + + with self.assertRaises(stem.InvalidArguments): + await controller.signal('FOOBAR')
@test.require.controller - def test_newnym_availability(self): + @async_test + async def test_newnym_availability(self): """ Test the is_newnym_available and get_newnym_wait methods. """
- with test.runner.get_runner().get_tor_controller() as controller: + async with await test.runner.get_runner().get_tor_controller() as controller: self.assertEqual(True, controller.is_newnym_available()) self.assertEqual(0.0, controller.get_newnym_wait())
- controller.signal(stem.Signal.NEWNYM) + await controller.signal(stem.Signal.NEWNYM)
self.assertEqual(False, controller.is_newnym_available()) self.assertTrue(controller.get_newnym_wait() > 9.0)
@test.require.controller @test.require.online - def test_extendcircuit(self): - with test.runner.get_runner().get_tor_controller() as controller: + @async_test + async def test_extendcircuit(self): + async with await test.runner.get_runner().get_tor_controller() as controller: circuit_id = controller.extend_circuit('0')
# check if our circuit was created + self.assertNotEqual(None, controller.get_circuit(circuit_id, None)) circuit_id = controller.new_circuit() self.assertNotEqual(None, controller.get_circuit(circuit_id, None))
- self.assertRaises(stem.InvalidRequest, controller.extend_circuit, 'foo') - self.assertRaises(stem.InvalidRequest, controller.extend_circuit, '0', 'thisroutershouldntexistbecausestemexists!@##$%#') - self.assertRaises(stem.InvalidRequest, controller.extend_circuit, '0', 'thisroutershouldntexistbecausestemexists!@##$%#', 'foo') + with self.assertRaises(stem.InvalidRequest): + await controller.extend_circuit('foo') + + with self.assertRaises(stem.InvalidRequest): + await controller.extend_circuit('0', 'thisroutershouldntexistbecausestemexists!@##$%#') + + with self.assertRaises(stem.InvalidRequest): + await controller.extend_circuit('0', 'thisroutershouldntexistbecausestemexists!@##$%#', 'foo')
@test.require.controller @test.require.online - def test_repurpose_circuit(self): + @async_test + async def test_repurpose_circuit(self): """ Tests Controller.repurpose_circuit with valid and invalid input. """
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller: + async with await runner.get_tor_controller() as controller: circ_id = controller.new_circuit() controller.repurpose_circuit(circ_id, 'CONTROLLER') circuit = controller.get_circuit(circ_id) @@ -1025,38 +1082,47 @@ class TestController(unittest.TestCase): circuit = controller.get_circuit(circ_id) self.assertTrue(circuit.purpose == 'GENERAL')
- self.assertRaises(stem.InvalidRequest, controller.repurpose_circuit, 'f934h9f3h4', 'fooo') - self.assertRaises(stem.InvalidRequest, controller.repurpose_circuit, '4', 'fooo') + with self.assertRaises(stem.InvalidRequest): + await controller.repurpose_circuit('f934h9f3h4', 'fooo') + + with self.assertRaises(stem.InvalidRequest): + await controller.repurpose_circuit('4', 'fooo')
@test.require.controller @test.require.online - def test_close_circuit(self): + @async_test + async def test_close_circuit(self): """ Tests Controller.close_circuit with valid and invalid input. """
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller: + async with await runner.get_tor_controller() as controller: circuit_id = controller.new_circuit() controller.close_circuit(circuit_id) - circuit_output = controller.get_info('circuit-status') + circuit_output = await controller.get_info('circuit-status') circ = [x.split()[0] for x in circuit_output.splitlines()] self.assertFalse(circuit_id in circ)
circuit_id = controller.new_circuit() controller.close_circuit(circuit_id, 'IfUnused') - circuit_output = controller.get_info('circuit-status') + circuit_output = await controller.get_info('circuit-status') circ = [x.split()[0] for x in circuit_output.splitlines()] self.assertFalse(circuit_id in circ)
circuit_id = controller.new_circuit() - self.assertRaises(stem.InvalidArguments, controller.close_circuit, circuit_id + '1024') - self.assertRaises(stem.InvalidRequest, controller.close_circuit, '') + + with self.assertRaises(stem.InvalidArguments): + await controller.close_circuit(circuit_id + '1024') + + with self.assertRaises(stem.InvalidRequest): + await controller.close_circuit('')
@test.require.controller @test.require.online - def test_get_streams(self): + @async_test + async def test_get_streams(self): """ Tests Controller.get_streams(). """ @@ -1065,9 +1131,11 @@ class TestController(unittest.TestCase): port = 443
runner = test.runner.get_runner() - with runner.get_tor_controller() as controller: + + async with await runner.get_tor_controller() as controller: # we only need one proxy port, so take the first - socks_listener = controller.get_listeners(Listener.SOCKS)[0] + + socks_listener = (await controller.get_listeners(Listener.SOCKS))[0]
with test.network.Socks(socks_listener) as s: s.settimeout(30) @@ -1081,17 +1149,18 @@ class TestController(unittest.TestCase):
@test.require.controller @test.require.online - def test_close_stream(self): + @async_test + async def test_close_stream(self): """ Tests Controller.close_stream with valid and invalid input. """
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller: + async with await runner.get_tor_controller() as controller: # use the first socks listener
- socks_listener = controller.get_listeners(Listener.SOCKS)[0] + socks_listener = (await controller.get_listeners(Listener.SOCKS))[0]
with test.network.Socks(socks_listener) as s: s.settimeout(30) @@ -1116,16 +1185,18 @@ class TestController(unittest.TestCase):
# unknown stream
- self.assertRaises(stem.InvalidArguments, controller.close_stream, 'blarg') + with self.assertRaises(stem.InvalidArguments): + await controller.close_stream('blarg')
@test.require.controller @test.require.online - def test_mapaddress(self): + @async_test + async def test_mapaddress(self): self.skipTest('(https://trac.torproject.org/projects/tor/ticket/25611)') runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller: - controller.map_address({'1.2.1.2': 'ifconfig.me'}) + async with await runner.get_tor_controller() as controller: + await controller.map_address({'1.2.1.2': 'ifconfig.me'})
s = None response = None @@ -1136,7 +1207,7 @@ class TestController(unittest.TestCase): try: s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) s.settimeout(30) - s.connect(('127.0.0.1', int(controller.get_conf('SocksPort')))) + s.connect(('127.0.0.1', int(await controller.get_conf('SocksPort')))) test.network.negotiate_socks(s, '1.2.1.2', 80) s.sendall(stem.util.str_tools._to_bytes(test.network.IP_REQUEST)) # make the http request for the ip address response = s.recv(1000) @@ -1158,14 +1229,15 @@ class TestController(unittest.TestCase): self.assertTrue(stem.util.connection.is_valid_ipv4_address(stem.util.str_tools._to_unicode(ip_addr)), "'%s' isn't an address" % ip_addr)
@test.require.controller - def test_mapaddress_offline(self): + @async_test + async def test_mapaddress_offline(self): runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller: + async with await runner.get_tor_controller() as controller: # try mapping one element, ensuring results are as expected
map1 = {'1.2.1.2': 'ifconfig.me'} - x = controller.map_address(map1) + x = await controller.map_address(map1) self.assertEqual(x, map1)
# try mapping two elements, ensuring results are as expected @@ -1173,17 +1245,18 @@ class TestController(unittest.TestCase): map2 = {'1.2.3.4': 'foobar.example.com', '1.2.3.5': 'barfuzz.example.com'}
- x = controller.map_address(map2) + x = await controller.map_address(map2) self.assertEqual(x, map2)
# try mapping zero elements
- self.assertRaises(stem.InvalidRequest, controller.map_address, {}) + with self.assertRaises(stem.InvalidRequest): + await controller.map_address({})
# try a virtual mapping to IPv4, the default virtualaddressrange is 127.192.0.0/10
map3 = {'0.0.0.0': 'quux'} - x = controller.map_address(map3) + x = await controller.map_address(map3) self.assertEquals(len(x), 1) addr1, target = list(x.items())[0]
@@ -1193,15 +1266,15 @@ class TestController(unittest.TestCase): # try a virtual mapping to IPv6, the default IPv6 virtualaddressrange is FE80::/10
map4 = {'::': 'quibble'} - x = controller.map_address(map4) + x = await controller.map_address(map4) self.assertEquals(len(x), 1) addr2, target = list(x.items())[0]
self.assertTrue(addr2.startswith('[fe'), '%s did not start with [fe.' % addr2) self.assertEquals(target, 'quibble')
- def address_mappings(addr_type): - response = controller.get_info(['address-mappings/%s' % addr_type]) + async def address_mappings(addr_type): + response = await controller.get_info(['address-mappings/%s' % addr_type]) result = {}
for line in response['address-mappings/%s' % addr_type].splitlines(): @@ -1218,7 +1291,7 @@ class TestController(unittest.TestCase): '1.2.3.5': 'barfuzz.example.com', addr1: 'quux', addr2: 'quibble', - }, address_mappings('control')) + }, await address_mappings('control'))
# ask for a list of all the address mappings
@@ -1228,29 +1301,40 @@ class TestController(unittest.TestCase): '1.2.3.5': 'barfuzz.example.com', addr1: 'quux', addr2: 'quibble', - }, address_mappings('all')) + }, await address_mappings('all'))
# Now ask for a list of only the mappings configured with the # configuration. Ours should not be there.
- self.assertEquals({}, address_mappings('config')) + self.assertEquals({}, await address_mappings('config'))
@test.require.controller @test.require.online - def test_get_microdescriptor(self): + @async_test + async def test_get_microdescriptor(self): """ Basic checks for get_microdescriptor(). """
- with test.runner.get_runner().get_tor_controller() as controller: + async with await test.runner.get_runner().get_tor_controller() as controller: # we should balk at invalid content - self.assertRaises(ValueError, controller.get_microdescriptor, '') - self.assertRaises(ValueError, controller.get_microdescriptor, 5) - self.assertRaises(ValueError, controller.get_microdescriptor, 'z' * 30) + + with self.assertRaises(ValueError): + await controller.get_microdescriptor('') + + with self.assertRaises(ValueError): + await controller.get_microdescriptor(5) + + with self.assertRaises(ValueError): + await controller.get_microdescriptor('z' * 30)
# try with a relay that doesn't exist - self.assertRaises(stem.ControllerError, controller.get_microdescriptor, 'blargg') - self.assertRaises(stem.ControllerError, controller.get_microdescriptor, '5' * 40) + + with self.assertRaises(stem.ControllerError): + await controller.get_microdescriptor('blargg') + + with self.assertRaises(stem.ControllerError): + await controller.get_microdescriptor('5' * 40)
test_relay = self._get_router_status_entry(controller)
@@ -1261,7 +1345,8 @@ class TestController(unittest.TestCase):
@test.require.controller @test.require.online - def test_get_microdescriptors(self): + @async_test + async def test_get_microdescriptors(self): """ Fetches a few descriptors via the get_microdescriptors() method. """ @@ -1271,7 +1356,7 @@ class TestController(unittest.TestCase): if not os.path.exists(runner.get_test_dir('cached-microdescs')): self.skipTest('(no cached microdescriptors)')
- with runner.get_tor_controller() as controller: + async with await runner.get_tor_controller() as controller: count = 0
for desc in controller.get_microdescriptors(): @@ -1283,22 +1368,33 @@ class TestController(unittest.TestCase):
@test.require.controller @test.require.online - def test_get_server_descriptor(self): + @async_test + async def test_get_server_descriptor(self): """ Basic checks for get_server_descriptor(). """
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller: + async with await runner.get_tor_controller() as controller: # we should balk at invalid content - self.assertRaises(ValueError, controller.get_server_descriptor, '') - self.assertRaises(ValueError, controller.get_server_descriptor, 5) - self.assertRaises(ValueError, controller.get_server_descriptor, 'z' * 30) + + with self.assertRaises(ValueError): + await controller.get_server_descriptor('') + + with self.assertRaises(ValueError): + await controller.get_server_descriptor(5) + + with self.assertRaises(ValueError): + await controller.get_server_descriptor('z' * 30)
# try with a relay that doesn't exist - self.assertRaises(stem.ControllerError, controller.get_server_descriptor, 'blargg') - self.assertRaises(stem.ControllerError, controller.get_server_descriptor, '5' * 40) + + with self.assertRaises(stem.ControllerError): + await controller.get_server_descriptor('blargg') + + with self.assertRaises(stem.ControllerError): + await controller.get_server_descriptor('5' * 40)
test_relay = self._get_router_status_entry(controller)
@@ -1309,14 +1405,15 @@ class TestController(unittest.TestCase):
@test.require.controller @test.require.online - def test_get_server_descriptors(self): + @async_test + async def test_get_server_descriptors(self): """ Fetches a few descriptors via the get_server_descriptors() method. """
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller: + async with await runner.get_tor_controller() as controller: count = 0
for desc in controller.get_server_descriptors(): @@ -1334,20 +1431,31 @@ class TestController(unittest.TestCase):
@test.require.controller @test.require.online - def test_get_network_status(self): + @async_test + async def test_get_network_status(self): """ Basic checks for get_network_status(). """
- with test.runner.get_runner().get_tor_controller() as controller: + async with await test.runner.get_runner().get_tor_controller() as controller: # we should balk at invalid content - self.assertRaises(ValueError, controller.get_network_status, '') - self.assertRaises(ValueError, controller.get_network_status, 5) - self.assertRaises(ValueError, controller.get_network_status, 'z' * 30) + + with self.assertRaises(ValueError): + await controller.get_network_status('') + + with self.assertRaises(ValueError): + await controller.get_network_status(5) + + with self.assertRaises(ValueError): + await controller.get_network_status('z' * 30)
# try with a relay that doesn't exist - self.assertRaises(stem.ControllerError, controller.get_network_status, 'blargg') - self.assertRaises(stem.ControllerError, controller.get_network_status, '5' * 40) + + with self.assertRaises(stem.ControllerError): + await controller.get_network_status('blargg') + + with self.assertRaises(stem.ControllerError): + await controller.get_network_status('5' * 40)
test_relay = self._get_router_status_entry(controller)
@@ -1358,14 +1466,15 @@ class TestController(unittest.TestCase):
@test.require.controller @test.require.online - def test_get_network_statuses(self): + @async_test + async def test_get_network_statuses(self): """ Fetches a few descriptors via the get_network_statuses() method. """
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller: + async with await runner.get_tor_controller() as controller: count = 0
for desc in controller.get_network_statuses(): @@ -1381,14 +1490,15 @@ class TestController(unittest.TestCase):
@test.require.controller @test.require.online - def test_get_hidden_service_descriptor(self): + @async_test + async def test_get_hidden_service_descriptor(self): """ Fetches a few descriptors via the get_hidden_service_descriptor() method. """
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller: + async with await runner.get_tor_controller() as controller: # fetch the descriptor for DuckDuckGo
desc = controller.get_hidden_service_descriptor('3g2upl4pq6kufc4m.onion') @@ -1396,8 +1506,8 @@ class TestController(unittest.TestCase):
# try to fetch something that doesn't exist
- exc_msg = 'No running hidden service at m4cfuk6qp4lpu2g3.onion' - self.assertRaisesWith(stem.DescriptorUnavailable, exc_msg, controller.get_hidden_service_descriptor, 'm4cfuk6qp4lpu2g3') + with self.assertRaisesWith(stem.DescriptorUnavailable, 'No running hidden service at m4cfuk6qp4lpu2g3.onion'): + await controller.get_hidden_service_descriptor('m4cfuk6qp4lpu2g3')
# ... but shouldn't fail if we have a default argument or aren't awaiting the descriptor
@@ -1406,7 +1516,8 @@ class TestController(unittest.TestCase):
@test.require.controller @test.require.online - def test_attachstream(self): + @async_test + async def test_attachstream(self): host = socket.gethostbyname('www.torproject.org') port = 80
@@ -1416,15 +1527,16 @@ class TestController(unittest.TestCase): if stream.status == 'NEW' and circuit_id: controller.attach_stream(stream.id, circuit_id)
- with test.runner.get_runner().get_tor_controller() as controller: + async with await test.runner.get_runner().get_tor_controller() as controller: # try 10 times to build a circuit we can connect through + for i in range(10): - controller.add_event_listener(handle_streamcreated, stem.control.EventType.STREAM) - controller.set_conf('__LeaveStreamsUnattached', '1') + await controller.add_event_listener(handle_streamcreated, stem.control.EventType.STREAM) + await controller.set_conf('__LeaveStreamsUnattached', '1')
try: circuit_id = controller.new_circuit(await_build = True) - socks_listener = controller.get_listeners(Listener.SOCKS)[0] + socks_listener = (await controller.get_listeners(Listener.SOCKS))[0]
with test.network.Socks(socks_listener) as s: s.settimeout(30) @@ -1435,7 +1547,7 @@ class TestController(unittest.TestCase): continue finally: controller.remove_event_listener(handle_streamcreated) - controller.reset_conf('__LeaveStreamsUnattached') + await controller.reset_conf('__LeaveStreamsUnattached')
our_stream = [stream for stream in streams if stream.target_address == host][0]
@@ -1446,38 +1558,40 @@ class TestController(unittest.TestCase):
@test.require.controller @test.require.online - def test_get_circuits(self): + @async_test + async def test_get_circuits(self): """ Fetches circuits via the get_circuits() method. """
- with test.runner.get_runner().get_tor_controller() as controller: + async with await test.runner.get_runner().get_tor_controller() as controller: new_circ = controller.new_circuit() circuits = controller.get_circuits() self.assertTrue(new_circ in [circ.id for circ in circuits])
@test.require.controller - def test_transition_to_relay(self): + @async_test + async def test_transition_to_relay(self): """ Transitions Tor to turn into a relay, then back to a client. This helps to catch transition issues such as the one cited in :trac:`14901`. """
- with test.runner.get_runner().get_tor_controller() as controller: + async with await test.runner.get_runner().get_tor_controller() as controller: try: - controller.reset_conf('OrPort', 'DisableNetwork') - self.assertEqual(None, controller.get_conf('OrPort')) + await controller.reset_conf('OrPort', 'DisableNetwork') + self.assertEqual(None, await controller.get_conf('OrPort'))
# DisableNetwork ensures no port is actually opened - controller.set_options({'OrPort': '9090', 'DisableNetwork': '1'}) + await controller.set_options({'OrPort': '9090', 'DisableNetwork': '1'})
# TODO once tor 0.2.7.x exists, test that we can generate a descriptor on demand.
- self.assertEqual('9090', controller.get_conf('OrPort')) - controller.reset_conf('OrPort', 'DisableNetwork') - self.assertEqual(None, controller.get_conf('OrPort')) + self.assertEqual('9090', await controller.get_conf('OrPort')) + await controller.reset_conf('OrPort', 'DisableNetwork') + self.assertEqual(None, await controller.get_conf('OrPort')) finally: - controller.set_conf('OrPort', str(test.runner.ORPORT)) + await controller.set_conf('OrPort', str(test.runner.ORPORT))
def _get_router_status_entry(self, controller): """ diff --git a/test/runner.py b/test/runner.py index b132b8f5..a02fb769 100644 --- a/test/runner.py +++ b/test/runner.py @@ -88,7 +88,7 @@ class TorInaccessable(Exception):
async def exercise_controller(test_case, controller): - """with await test.runner.get_runner().get_tor_socket + """ Checks that we can now use the socket by issuing a 'GETINFO config-file' query. Controller can be either a :class:`stem.socket.ControlSocket` or :class:`stem.control.BaseController`. @@ -102,11 +102,10 @@ async def exercise_controller(test_case, controller):
if isinstance(controller, stem.socket.ControlSocket): await controller.send('GETINFO config-file') + config_file_response = await controller.recv() else: - config_file_response = controller.msg('GETINFO config-file') - if asyncio.iscoroutine(config_file_response): - config_file_response = await config_file_response + config_file_response = await controller.msg('GETINFO config-file')
test_case.assertEqual('config-file=%s\nOK' % torrc_path, str(config_file_response))
@@ -261,9 +260,19 @@ class Runner(object): stem.socket.recv_message = _chroot_recv_message
if self.is_accessible(): - self._owner_controller = stem.control.Controller(self._get_unconnected_socket(), False) - self._owner_controller.connect() - self._authenticate_controller(self._owner_controller) + # TODO: refactor so owner controller is less convoluted + + loop = asyncio.new_event_loop() + + self._owner_controller_thread = threading.Thread( + name = 'owning_controller', + target = loop.run_forever, + daemon = True, + ) + + self._owner_controller_thread.start() + + self._owner_controller = asyncio.run_coroutine_threadsafe(self.get_tor_controller(True), loop).result()
if test.Target.RELATIVE in self.attribute_targets: os.chdir(original_cwd) # revert our cwd back to normal @@ -279,7 +288,9 @@ class Runner(object): println('Shutting down tor... ', STATUS, NO_NL)
if self._owner_controller: - self._owner_controller.close() + asyncio.run_coroutine_threadsafe(self._owner_controller.close(), self._owner_controller._loop).result() + self._owner_controller._loop.call_soon_threadsafe(self._owner_controller._loop.stop) + self._owner_controller_thread.join() self._owner_controller = None
if self._tor_process: @@ -445,16 +456,6 @@ class Runner(object): tor_process = self._get('_tor_process') return tor_process.pid
- def _get_unconnected_socket(self): - if Torrc.PORT in self._custom_opts: - control_socket = stem.socket.ControlPort(port = CONTROL_PORT) - elif Torrc.SOCKET in self._custom_opts: - control_socket = stem.socket.ControlSocketFile(CONTROL_SOCKET_PATH) - else: - raise TorInaccessable('Unable to connect to tor') - - return control_socket - async def get_tor_socket(self, authenticate = True): """ Provides a socket connected to our tor test instance. @@ -466,7 +467,13 @@ class Runner(object): :raises: :class:`test.runner.TorInaccessable` if tor can't be connected to """
- control_socket = self._get_unconnected_socket() + if Torrc.PORT in self._custom_opts: + control_socket = stem.socket.ControlPort(port = CONTROL_PORT) + elif Torrc.SOCKET in self._custom_opts: + control_socket = stem.socket.ControlSocketFile(CONTROL_SOCKET_PATH) + else: + raise TorInaccessable('Unable to connect to tor') + await control_socket.connect()
if authenticate: @@ -474,10 +481,7 @@ class Runner(object):
return control_socket
- def _authenticate_controller(self, controller): - controller.authenticate(password=CONTROL_PASSWORD, chroot_path=self.get_chroot()) - - def get_tor_controller(self, authenticate = True): + async def get_tor_controller(self, authenticate = True): """ Provides a controller connected to our tor test instance.
@@ -488,19 +492,11 @@ class Runner(object): :raises: :class: `test.runner.TorInaccessable` if tor can't be connected to """
- loop = asyncio.new_event_loop() - loop_thread = threading.Thread(target = loop.run_forever, name = 'get_tor_controller') - loop_thread.setDaemon(True) - loop_thread.start() - - async def wrapped_get_controller(): - control_socket = await self.get_tor_socket(False) - return stem.control.Controller(control_socket) - - controller = asyncio.run_coroutine_threadsafe(wrapped_get_controller(), loop).result() + control_socket = await self.get_tor_socket(False) + controller = stem.control.Controller(control_socket)
if authenticate: - self._authenticate_controller(controller) + await controller.authenticate(password = CONTROL_PASSWORD, chroot_path = self.get_chroot())
return controller
diff --git a/test/settings.cfg b/test/settings.cfg index 70bdd069..ef543a18 100644 --- a/test/settings.cfg +++ b/test/settings.cfg @@ -235,6 +235,14 @@ mypy.ignore stem/descriptor/remote.py => Return type "Coroutine[Any, Any, None]"
mypy.ignore * => "Descriptor" has no attribute "*
+# Metaprogramming false positive for our close method. + +mypy.ignore stem/control.py => Return type "Coroutine[Any, Any, None]" of "close" * + +# Interpreter uses a synchronous controller, which can cause false positives. + +mypy.ignore stem/interpreter/commands.py => "Coroutine[Any, Any, ControlMessage]" has no attribute "* + # Test modules we want to run. Modules are roughly ordered by the dependencies # so the lowest level tests come first. This is because a problem in say, # controller message parsing, will cause all higher level tests to fail too. diff --git a/test/unit/control/controller.py b/test/unit/control/controller.py index 84fcdfed..6c33da6b 100644 --- a/test/unit/control/controller.py +++ b/test/unit/control/controller.py @@ -21,11 +21,7 @@ from stem import ControllerError, DescriptorUnavailable, InvalidArguments, Inval from stem.control import MALFORMED_EVENTS, _parse_circ_path, Listener, Controller, EventType from stem.response import ControlMessage from stem.exit_policy import ExitPolicy -from stem.util.test_tools import ( - async_test, - coro_func_raising_exc, - coro_func_returning_value, -) +from stem.util.test_tools import coro_func_raising_exc, coro_func_returning_value
NS_DESC = 'r %s %s u5lTXJKGsLKufRLnSyVqT7TdGYw 2012-12-30 22:02:49 77.223.43.54 9001 0\ns Fast Named Running Stable Valid\nw Bandwidth=75' TEST_TIMESTAMP = 12345 @@ -44,7 +40,6 @@ class TestControl(unittest.TestCase):
with patch('stem.control.BaseController.msg', Mock(side_effect = coro_func_returning_value(None))): self.controller = Controller(socket) - self.async_controller = self.controller._wrapped_instance
self.circ_listener = Mock() self.controller.add_event_listener(self.circ_listener, EventType.CIRC) @@ -69,24 +64,23 @@ class TestControl(unittest.TestCase): for event in stem.control.EventType: self.assertTrue(stem.control.event_description(event) is not None)
- @patch('stem.control.AsyncController.msg') + @patch('stem.control.Controller.msg') def test_get_info(self, msg_mock): message = ControlMessage.from_str('250-hello=hi right back!\r\n250 OK\r\n', 'GETINFO') msg_mock.side_effect = coro_func_returning_value(message) self.assertEqual('hi right back!', self.controller.get_info('hello'))
- @patch('stem.control.AsyncController.msg') - @async_test - async def test_get_info_address_caching(self, msg_mock): + @patch('stem.control.Controller.msg') + def test_get_info_address_caching(self, msg_mock): def set_message(*args): message = ControlMessage.from_str(*args) msg_mock.side_effect = coro_func_returning_value(message)
set_message('551 Address unknown\r\n')
- self.assertEqual(None, self.async_controller._last_address_exc) + self.assertEqual(None, self.controller._last_address_exc) self.assertRaisesWith(stem.OperationFailed, 'Address unknown', self.controller.get_info, 'address') - self.assertEqual('Address unknown', str(self.async_controller._last_address_exc)) + self.assertEqual('Address unknown', str(self.controller._last_address_exc)) self.assertEqual(1, msg_mock.call_count)
# now that we have a cached failure we should provide that back @@ -98,26 +92,26 @@ class TestControl(unittest.TestCase):
set_message('250-address=17.2.89.80\r\n250 OK\r\n', 'GETINFO') self.assertRaisesWith(stem.OperationFailed, 'Address unknown', self.controller.get_info, 'address') - await self.async_controller._handle_event(ControlMessage.from_str('650 STATUS_SERVER NOTICE EXTERNAL_ADDRESS ADDRESS=17.2.89.80 METHOD=DIRSERV\r\n')) + self.controller._handle_event(ControlMessage.from_str('650 STATUS_SERVER NOTICE EXTERNAL_ADDRESS ADDRESS=17.2.89.80 METHOD=DIRSERV\r\n')) self.assertEqual('17.2.89.80', self.controller.get_info('address'))
# invalidates the cache, transitioning from one address to another
set_message('250-address=80.89.2.17\r\n250 OK\r\n', 'GETINFO') self.assertEqual('17.2.89.80', self.controller.get_info('address')) - await self.async_controller._handle_event(ControlMessage.from_str('650 STATUS_SERVER NOTICE EXTERNAL_ADDRESS ADDRESS=80.89.2.17 METHOD=DIRSERV\r\n')) + self.controller._handle_event(ControlMessage.from_str('650 STATUS_SERVER NOTICE EXTERNAL_ADDRESS ADDRESS=80.89.2.17 METHOD=DIRSERV\r\n')) self.assertEqual('80.89.2.17', self.controller.get_info('address'))
- @patch('stem.control.AsyncController.msg') - @patch('stem.control.AsyncController.get_conf') + @patch('stem.control.Controller.msg') + @patch('stem.control.Controller.get_conf') def test_get_info_without_fingerprint(self, get_conf_mock, msg_mock): message = ControlMessage.from_str('551 Not running in server mode\r\n') msg_mock.side_effect = coro_func_returning_value(message) - get_conf_mock.return_value = None + get_conf_mock.side_effect = coro_func_returning_value(None)
- self.assertEqual(None, self.async_controller._last_fingerprint_exc) + self.assertEqual(None, self.controller._last_fingerprint_exc) self.assertRaisesWith(stem.OperationFailed, 'Not running in server mode', self.controller.get_info, 'fingerprint') - self.assertEqual('Not running in server mode', str(self.async_controller._last_fingerprint_exc)) + self.assertEqual('Not running in server mode', str(self.controller._last_fingerprint_exc)) self.assertEqual(1, msg_mock.call_count)
# now that we have a cached failure we should provide that back @@ -127,11 +121,11 @@ class TestControl(unittest.TestCase):
# ... but if we become a relay we'll call it again
- get_conf_mock.return_value = '443' + get_conf_mock.side_effect = coro_func_returning_value('443') self.assertRaisesWith(stem.OperationFailed, 'Not running in server mode', self.controller.get_info, 'fingerprint') self.assertEqual(2, msg_mock.call_count)
- @patch('stem.control.AsyncController.get_info') + @patch('stem.control.Controller.get_info') def test_get_version(self, get_info_mock): """ Exercises the get_version() method. @@ -155,7 +149,7 @@ class TestControl(unittest.TestCase): self.assertEqual(version_2_1_object, self.controller.get_version())
# Turn off caching. - self.async_controller._is_caching_enabled = False + self.controller._is_caching_enabled = False # Return a version without caching, so it will be the new version. self.assertEqual(version_2_2_object, self.controller.get_version())
@@ -184,13 +178,13 @@ class TestControl(unittest.TestCase): # Turn caching back on before we leave. self.controller._is_caching_enabled = True
- @patch('stem.control.AsyncController.get_info') + @patch('stem.control.Controller.get_info') def test_get_exit_policy(self, get_info_mock): """ Exercises the get_exit_policy() method. """
- async def get_info_mock_side_effect(param, default = None): + async def get_info_mock_side_effect(self, param, default = None): return { 'exit-policy/full': 'reject *:25,reject *:119,reject *:135-139,reject *:445,reject *:563,reject *:1214,reject *:4661-4666,reject *:6346-6429,reject *:6699,reject *:6881-6999,accept *:*', }[param] @@ -213,8 +207,8 @@ class TestControl(unittest.TestCase):
self.assertEqual(str(expected), str(self.controller.get_exit_policy()))
- @patch('stem.control.AsyncController.get_info') - @patch('stem.control.AsyncController.get_conf') + @patch('stem.control.Controller.get_info') + @patch('stem.control.Controller.get_conf') def test_get_ports(self, get_conf_mock, get_info_mock): """ Exercises the get_ports() and get_listeners() methods. @@ -225,7 +219,7 @@ class TestControl(unittest.TestCase):
get_info_mock.side_effect = coro_func_raising_exc(InvalidArguments)
- async def get_conf_mock_side_effect(param, *args, **kwargs): + async def get_conf_mock_side_effect(self, param, *args, **kwargs): return { 'ControlPort': '9050', 'ControlListenAddress': ['127.0.0.1'], @@ -239,7 +233,7 @@ class TestControl(unittest.TestCase):
# non-local addresss
- async def get_conf_mock_side_effect(param, *args, **kwargs): + async def get_conf_mock_side_effect(self, param, *args, **kwargs): return { 'ControlPort': '9050', 'ControlListenAddress': ['27.4.4.1'], @@ -290,14 +284,14 @@ class TestControl(unittest.TestCase): self.assertEqual([], self.controller.get_listeners(Listener.CONTROL)) self.assertEqual([], self.controller.get_ports(Listener.CONTROL))
- @patch('stem.control.AsyncController.get_info') + @patch('stem.control.Controller.get_info') @patch('time.time', Mock(return_value = 1410723598.276578)) def test_get_accounting_stats(self, get_info_mock): """ Exercises the get_accounting_stats() method. """
- async def get_info_mock_side_effect(param, **kwargs): + async def get_info_mock_side_effect(self, param, **kwargs): return { 'accounting/enabled': '1', 'accounting/hibernating': 'awake', @@ -358,6 +352,7 @@ class TestControl(unittest.TestCase): self.assertRaises(ProtocolError, self.controller.get_protocolinfo)
@patch('stem.socket.ControlSocket.is_localhost', Mock(return_value = False)) + @patch('stem.control.Controller.get_info', Mock(side_effect = coro_func_returning_value(None))) def test_get_user_remote(self): """ Exercise the get_user() method for a non-local socket. @@ -367,7 +362,7 @@ class TestControl(unittest.TestCase): self.assertEqual(123, self.controller.get_user(123))
@patch('stem.socket.ControlSocket.is_localhost', Mock(return_value = True)) - @patch('stem.control.AsyncController.get_info', Mock(side_effect = coro_func_returning_value('atagar'))) + @patch('stem.control.Controller.get_info', Mock(side_effect = coro_func_returning_value('atagar'))) def test_get_user_by_getinfo(self): """ Exercise the get_user() resolution via its getinfo option. @@ -376,7 +371,8 @@ class TestControl(unittest.TestCase): self.assertEqual('atagar', self.controller.get_user())
@patch('stem.socket.ControlSocket.is_localhost', Mock(return_value = True)) - @patch('stem.util.system.pid_by_name', Mock(return_value = 432)) + @patch('stem.control.Controller.get_info', Mock(side_effect = coro_func_returning_value(None))) + @patch('stem.control.Controller.get_pid', Mock(side_effect = coro_func_returning_value(432))) @patch('stem.util.system.user', Mock(return_value = 'atagar')) def test_get_user_by_system(self): """ @@ -386,6 +382,7 @@ class TestControl(unittest.TestCase): self.assertEqual('atagar', self.controller.get_user())
@patch('stem.socket.ControlSocket.is_localhost', Mock(return_value = False)) + @patch('stem.control.Controller.get_info', Mock(side_effect = coro_func_returning_value(None))) def test_get_pid_remote(self): """ Exercise the get_pid() method for a non-local socket. @@ -395,7 +392,7 @@ class TestControl(unittest.TestCase): self.assertEqual(123, self.controller.get_pid(123))
@patch('stem.socket.ControlSocket.is_localhost', Mock(return_value = True)) - @patch('stem.control.AsyncController.get_info', Mock(side_effect = coro_func_returning_value('321'))) + @patch('stem.control.Controller.get_info', Mock(side_effect = coro_func_returning_value('321'))) def test_get_pid_by_getinfo(self): """ Exercise the get_pid() resolution via its getinfo option. @@ -404,7 +401,8 @@ class TestControl(unittest.TestCase): self.assertEqual(321, self.controller.get_pid())
@patch('stem.socket.ControlSocket.is_localhost', Mock(return_value = True)) - @patch('stem.control.AsyncController.get_conf') + @patch('stem.control.Controller.get_info', Mock(side_effect = coro_func_returning_value(None))) + @patch('stem.control.Controller.get_conf') @patch('stem.control.open', create = True) def test_get_pid_by_pid_file(self, open_mock, get_conf_mock): """ @@ -418,6 +416,8 @@ class TestControl(unittest.TestCase): open_mock.assert_called_once_with('/tmp/pid_file')
@patch('stem.socket.ControlSocket.is_localhost', Mock(return_value = True)) + @patch('stem.control.Controller.get_info', Mock(side_effect = coro_func_returning_value(None))) + @patch('stem.control.Controller.get_conf', Mock(side_effect = coro_func_returning_value(None))) @patch('stem.util.system.pid_by_name', Mock(return_value = 432)) def test_get_pid_by_name(self): """ @@ -426,9 +426,9 @@ class TestControl(unittest.TestCase):
self.assertEqual(432, self.controller.get_pid())
- @patch('stem.control.AsyncController.get_version', Mock(side_effect = coro_func_returning_value(stem.version.Version('0.5.0.14')))) + @patch('stem.control.Controller.get_version', Mock(side_effect = coro_func_returning_value(stem.version.Version('0.5.0.14')))) @patch('stem.socket.ControlSocket.is_localhost', Mock(return_value = False)) - @patch('stem.control.AsyncController.get_info') + @patch('stem.control.Controller.get_info') @patch('time.time', Mock(return_value = 1000.0)) def test_get_uptime_by_getinfo(self, getinfo_mock): """ @@ -443,8 +443,9 @@ class TestControl(unittest.TestCase): self.assertRaisesWith(ValueError, "'GETINFO uptime' did not provide a valid numeric response: abc", self.controller.get_uptime)
@patch('stem.socket.ControlSocket.is_localhost', Mock(return_value = True)) - @patch('stem.control.AsyncController.get_version', Mock(side_effect = coro_func_returning_value(stem.version.Version('0.1.0.14')))) - @patch('stem.control.AsyncController.get_pid', Mock(side_effect = coro_func_returning_value('12'))) + @patch('stem.control.Controller.get_info', Mock(side_effect = coro_func_returning_value(None))) + @patch('stem.control.Controller.get_version', Mock(side_effect = coro_func_returning_value(stem.version.Version('0.1.0.14')))) + @patch('stem.control.Controller.get_pid', Mock(side_effect = coro_func_returning_value('12'))) @patch('stem.util.system.start_time', Mock(return_value = 5000.0)) @patch('time.time', Mock(return_value = 5200.0)) def test_get_uptime_by_process(self): @@ -454,7 +455,7 @@ class TestControl(unittest.TestCase):
self.assertEqual(200.0, self.controller.get_uptime())
- @patch('stem.control.AsyncController.get_info') + @patch('stem.control.Controller.get_info') def test_get_network_status_for_ourselves(self, get_info_mock): """ Exercises the get_network_status() method for getting our own relay. @@ -472,7 +473,7 @@ class TestControl(unittest.TestCase):
desc = NS_DESC % ('moria1', '/96bKo4soysolMgKn5Hex2nyFSY')
- async def get_info_mock_side_effect(param, **kwargs): + async def get_info_mock_side_effect(self, param, **kwargs): return { 'fingerprint': '9695DFC35FFEB861329B9F1AB04C46397020CE31', 'ns/id/9695DFC35FFEB861329B9F1AB04C46397020CE31': desc, @@ -482,7 +483,7 @@ class TestControl(unittest.TestCase):
self.assertEqual(stem.descriptor.router_status_entry.RouterStatusEntryV3(desc), self.controller.get_network_status())
- @patch('stem.control.AsyncController.get_info') + @patch('stem.control.Controller.get_info') def test_get_network_status_when_unavailable(self, get_info_mock): """ Exercises the get_network_status() method. @@ -494,7 +495,7 @@ class TestControl(unittest.TestCase): exc_msg = "Tor was unable to provide the descriptor for '5AC9C5AA75BA1F18D8459B326B4B8111A856D290'" self.assertRaisesWith(DescriptorUnavailable, exc_msg, self.controller.get_network_status, '5AC9C5AA75BA1F18D8459B326B4B8111A856D290')
- @patch('stem.control.AsyncController.get_info') + @patch('stem.control.Controller.get_info') def test_get_network_status(self, get_info_mock): """ Exercises the get_network_status() method. @@ -540,16 +541,14 @@ class TestControl(unittest.TestCase):
self.assertRaises(InvalidArguments, self.controller.get_network_status, nickname)
- @patch('stem.control.AsyncController.is_authenticated', Mock(return_value = True)) - @patch('stem.control.AsyncController._attach_listeners') - @patch('stem.control.AsyncController.get_version') - def test_add_event_listener(self, get_version_mock, attach_listeners_mock): + @patch('stem.control.Controller.is_authenticated', Mock(return_value = True)) + @patch('stem.control.Controller._attach_listeners', Mock(side_effect = coro_func_returning_value(([], [])))) + @patch('stem.control.Controller.get_version') + def test_add_event_listener(self, get_version_mock): """ Exercises the add_event_listener and remove_event_listener methods. """
- attach_listeners_mock.side_effect = coro_func_returning_value(([], [])) - def set_version(version_str): version = stem.version.Version(version_str) get_version_mock.side_effect = coro_func_returning_value(version) @@ -621,10 +620,10 @@ class TestControl(unittest.TestCase): self._emit_event(BW_EVENT) self.bw_listener.assert_called_once_with(BW_EVENT)
- @patch('stem.control.AsyncController.get_version', Mock(side_effect = coro_func_returning_value(stem.version.Version('0.5.0.14')))) - @patch('stem.control.AsyncController.msg', Mock(side_effect = coro_func_returning_value(ControlMessage.from_str('250 OK\r\n')))) - @patch('stem.control.AsyncController.add_event_listener', Mock(side_effect = coro_func_returning_value(None))) - @patch('stem.control.AsyncController.remove_event_listener', Mock(side_effect = coro_func_returning_value(None))) + @patch('stem.control.Controller.get_version', Mock(side_effect = coro_func_returning_value(stem.version.Version('0.5.0.14')))) + @patch('stem.control.Controller.msg', Mock(side_effect = coro_func_returning_value(ControlMessage.from_str('250 OK\r\n')))) + @patch('stem.control.Controller.add_event_listener', Mock(side_effect = coro_func_returning_value(None))) + @patch('stem.control.Controller.remove_event_listener', Mock(side_effect = coro_func_returning_value(None))) def test_timeout(self): """ Methods that have an 'await' argument also have an optional timeout. Check @@ -648,7 +647,7 @@ class TestControl(unittest.TestCase): response = ''.join(['%s\r\n' % ' '.join(entry) for entry in valid_streams]) get_info_mock = Mock(side_effect = coro_func_returning_value(response))
- with patch('stem.control.AsyncController.get_info', get_info_mock): + with patch('stem.control.Controller.get_info', get_info_mock): streams = self.controller.get_streams() self.assertEqual(len(valid_streams), len(streams))
@@ -669,7 +668,7 @@ class TestControl(unittest.TestCase): response = stem.response.ControlMessage.from_str('555 Connection is not managed by controller.\r\n') msg_mock = Mock(side_effect = coro_func_returning_value(response))
- with patch('stem.control.AsyncController.msg', msg_mock): + with patch('stem.control.Controller.msg', msg_mock): self.assertRaises(UnsatisfiableRequest, self.controller.attach_stream, 'stream_id', 'circ_id')
def test_parse_circ_path(self): @@ -712,7 +711,7 @@ class TestControl(unittest.TestCase): for test_input in malformed_inputs: self.assertRaises(ProtocolError, _parse_circ_path, test_input)
- @patch('stem.control.AsyncController.get_conf') + @patch('stem.control.Controller.get_conf') def test_get_effective_rate(self, get_conf_mock): """ Exercise the get_effective_rate() method. @@ -720,7 +719,7 @@ class TestControl(unittest.TestCase):
# check default if nothing was set
- async def get_conf_mock_side_effect(param, *args, **kwargs): + async def get_conf_mock_side_effect(self, param, *args, **kwargs): return { 'BandwidthRate': '1073741824', 'BandwidthBurst': '1073741824', @@ -749,19 +748,19 @@ class TestControl(unittest.TestCase): # with its work is to join on the thread.
with patch('time.time', Mock(return_value = TEST_TIMESTAMP)): - with patch('stem.control.AsyncController.is_alive') as is_alive_mock: + with patch('stem.control.Controller.is_alive') as is_alive_mock: is_alive_mock.return_value = True loop = self.controller._loop - asyncio.run_coroutine_threadsafe(self.async_controller._event_loop(), loop) + asyncio.run_coroutine_threadsafe(Controller._event_loop(self.controller), loop)
try: # Converting an event back into an uncast ControlMessage, then feeding it # into our controller's event queue.
uncast_event = ControlMessage.from_str(event.raw_content()) - event_queue = self.async_controller._event_queue + event_queue = self.controller._event_queue asyncio.run_coroutine_threadsafe(event_queue.put(uncast_event), loop).result() asyncio.run_coroutine_threadsafe(event_queue.join(), loop).result() # block until the event is consumed finally: is_alive_mock.return_value = False - asyncio.run_coroutine_threadsafe(self.async_controller._close(), loop).result() + self.controller._close() diff --git a/test/unit/descriptor/remote.py b/test/unit/descriptor/remote.py index bb6f554c..3facd6a5 100644 --- a/test/unit/descriptor/remote.py +++ b/test/unit/descriptor/remote.py @@ -135,7 +135,7 @@ class TestDescriptorDownloader(unittest.TestCase): def test_reply_header_data(self): query = stem.descriptor.remote.get_server_descriptors('9695DFC35FFEB861329B9F1AB04C46397020CE31', start = False) self.assertEqual(None, query.reply_headers) # initially we don't have a reply - query.run(close = False) + query.run(stop = False)
self.assertEqual('Fri, 13 Apr 2018 16:35:50 GMT', query.reply_headers.get('Date')) self.assertEqual('application/octet-stream', query.reply_headers.get('Content-Type'))