commit 18a3280d9cb27a81bb01d8e964449adda3dc734e Author: Damian Johnson <atagar@torproject.org> Date: Mon May 18 14:43:18 2020 -0700
Correct rebase discrepancies
To make Illia's branch cleanly mergeable, I rebased it onto our present master. Manually resolving the conflicts produced a slightly different result than he had. This delta makes us exactly match his commit c788fd8. --- stem/client/__init__.py | 11 +++--- stem/connection.py | 25 ++++++------- stem/control.py | 63 ++++++++++++++++++--------------- stem/descriptor/remote.py | 16 ++++----- stem/interpreter/__init__.py | 4 +-- stem/interpreter/autocomplete.py | 8 ++--- stem/interpreter/commands.py | 12 +++---- stem/interpreter/help.py | 9 ++--- stem/response/__init__.py | 2 +- stem/socket.py | 45 +++++++++++------------ test/integ/connection/authentication.py | 6 +++- test/unit/control/controller.py | 6 ++-- 12 files changed, 105 insertions(+), 102 deletions(-)
diff --git a/stem/client/__init__.py b/stem/client/__init__.py index 941f0ee7..8ea7b3c1 100644 --- a/stem/client/__init__.py +++ b/stem/client/__init__.py @@ -33,7 +33,7 @@ import stem.socket import stem.util.connection
from types import TracebackType -from typing import Dict, Iterator, List, Optional, Sequence, Type, Union +from typing import AsyncIterator, Dict, List, Optional, Sequence, Type, Union
from stem.client.cell import ( CELL_TYPE_SIZE, @@ -70,7 +70,8 @@ class Relay(object): self.link_protocol = LinkProtocol(link_protocol) self._orport = orport self._orport_buffer = b'' # unread bytes - self._circuits = {} + self._orport_lock = stem.util.CombinedReentrantAndAsyncioLock() + self._circuits = {} # type: Dict[int, stem.client.Circuit]
@staticmethod async def connect(address: str, port: int, link_protocols: Sequence['stem.client.datatype.LinkProtocol'] = DEFAULT_LINK_PROTOCOLS) -> 'stem.client.Relay': # type: ignore @@ -191,7 +192,7 @@ class Relay(object): cell, self._orport_buffer = Cell.pop(self._orport_buffer, self.link_protocol) return cell
- async def _msg(self, cell: 'stem.client.cell.Cell') -> Iterator['stem.client.cell.Cell']: + async def _msg(self, cell: 'stem.client.cell.Cell') -> AsyncIterator['stem.client.cell.Cell']: """ Sends a cell on the ORPort and provides the response we receive in reply.
@@ -283,7 +284,7 @@ class Relay(object):
return circ
- async def __aiter__(self) -> Iterator['stem.client.Circuit']: + async def __aiter__(self) -> AsyncIterator['stem.client.Circuit']: async with self._orport_lock: for circ in self._circuits.values(): yield circ @@ -381,7 +382,7 @@ class Circuit(object): self.forward_digest = forward_digest self.forward_key = forward_key
- async def close(self)- > None: + async def close(self) -> None: async with self.relay._orport_lock: await self.relay._orport.send(stem.client.cell.DestroyCell(self.id).pack(self.relay.link_protocol)) del self.relay._circuits[self.id] diff --git a/stem/connection.py b/stem/connection.py index de76a345..213ba010 100644 --- a/stem/connection.py +++ b/stem/connection.py @@ -159,7 +159,7 @@ import stem.util.str_tools import stem.util.system import stem.version
-from typing import Any, List, Optional, Sequence, Tuple, Type, Union +from typing import Any, cast, List, Optional, Sequence, Tuple, Type, Union from stem.util import log
AuthMethod = stem.util.enum.Enum('NONE', 'PASSWORD', 'COOKIE', 'SAFECOOKIE', 'UNKNOWN') @@ -227,7 +227,7 @@ COMMON_TOR_COMMANDS = ( )
-def connect(control_port: Tuple[str, Union[str, int]] = ('127.0.0.1', 'default'), control_socket: str = '/var/run/tor/control', password: Optional[str] = None, password_prompt: bool = False, chroot_path: Optional[str] = None, controller: Type = stem.control.Controller) -> Any: +def connect(control_port: Tuple[str, Union[str, int]] = ('127.0.0.1', 'default'), control_socket: str = '/var/run/tor/control', password: Optional[str] = None, password_prompt: bool = False, chroot_path: Optional[str] = None, controller: Type[stem.control.Controller] = stem.control.Controller) -> Any: """ Convenience function for quickly getting a control connection for synchronous usage. This is very handy for debugging or CLI setup, handling setup and @@ -269,7 +269,7 @@ def connect(control_port: Tuple[str, Union[str, int]] = ('127.0.0.1', 'default') # TODO: change this function's API so we can provide a concrete type
if controller is None or not issubclass(controller, stem.control.Controller): - raise ValueError('Controller should be a stem.control.BaseController subclass.') + raise ValueError('Controller should be a stem.control.Controller subclass.')
async_controller_thread = stem.util.ThreadForWrappedAsyncClass() async_controller_thread.start() @@ -326,7 +326,7 @@ async def connect_async(control_port: Tuple[str, Union[str, int]] = ('127.0.0.1' return await _connect_async(control_port, control_socket, password, password_prompt, chroot_path, controller)
-async def _connect_async(control_port, control_socket, password, password_prompt, chroot_path, controller): +async def _connect_async(control_port: Tuple[str, Union[str, int]], control_socket: str, password: Optional[str], password_prompt: bool, chroot_path: Optional[str], controller: Type[Union[stem.control.BaseController, stem.control.Controller]]) -> Any: if control_port is None and control_socket is None: raise ValueError('Neither a control port nor control socket were provided. Nothing to connect to.') elif control_port: @@ -377,7 +377,7 @@ async def _connect_async(control_port, control_socket, password, password_prompt return await _connect_auth(control_connection, password, password_prompt, chroot_path, controller)
-async def _connect_auth(control_socket: stem.socket.ControlSocket, password: str, password_prompt: bool, chroot_path: str, controller: Optional[Type[stem.control.BaseController]]) -> Any: +async def _connect_auth(control_socket: stem.socket.ControlSocket, password: str, password_prompt: bool, chroot_path: str, controller: Optional[Type[Union[stem.control.BaseController, stem.control.Controller]]]) -> Any: """ Helper for the connect_* functions that authenticates the socket and constructs the controller. @@ -402,7 +402,7 @@ async def _connect_auth(control_socket: stem.socket.ControlSocket, password: str elif issubclass(controller, stem.control.BaseController): return controller(control_socket, is_authenticated = True) elif issubclass(controller, stem.control.Controller): - return controller(control_socket, is_authenticated = True, started_async_controller_thread = threading.current_thread()) + return controller(control_socket, is_authenticated = True, started_async_controller_thread = cast(stem.util.ThreadForWrappedAsyncClass, threading.current_thread())) except IncorrectSocketType: if isinstance(control_socket, stem.socket.ControlPort): print(CONNECT_MESSAGES['wrong_port_type'].format(port = control_socket.port)) @@ -995,7 +995,7 @@ async def authenticate_safecookie(controller: Union[stem.control.BaseController, auth_response = await _msg(controller, 'AUTHENTICATE %s' % stem.util.str_tools._to_unicode(binascii.b2a_hex(client_hash))) except stem.ControllerError as exc: try: - controller.connect() + await controller.connect() except: pass
@@ -1007,7 +1007,7 @@ async def authenticate_safecookie(controller: Union[stem.control.BaseController, # if we got anything but an OK response then err if not auth_response.is_ok(): try: - controller.connect() + await controller.connect() except: pass
@@ -1051,9 +1051,7 @@ async def get_protocolinfo(controller: Union[stem.control.BaseController, stem.s # next followed by authentication. Transparently reconnect if that happens.
if not protocolinfo_response or str(protocolinfo_response) == 'Authentication required.': - potential_coroutine = controller.connect() - if asyncio.iscoroutine(potential_coroutine): - await potential_coroutine + await controller.connect()
try: protocolinfo_response = await _msg(controller, 'PROTOCOLINFO 1') @@ -1074,10 +1072,7 @@ async def _msg(controller: Union[stem.control.BaseController, stem.socket.Contro await controller.send(message) return await controller.recv() else: - message = controller.msg(message) - if asyncio.iscoroutine(message): - message = await message - return message + return await controller.msg(message)
def _connection_for_default_port(address: str) -> stem.socket.ControlPort: diff --git a/stem/control.py b/stem/control.py index 6e15c16c..84f8f39b 100644 --- a/stem/control.py +++ b/stem/control.py @@ -271,7 +271,7 @@ import stem.version from stem import UNDEFINED, CircStatus, Signal from stem.util import log from types import TracebackType -from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Sequence, Tuple, Type, Union +from typing import Any, AsyncIterator, Awaitable, Callable, Dict, Iterator, List, Mapping, Optional, Sequence, Tuple, Type, Union
# When closing the controller we attempt to finish processing enqueued events, # but if it takes longer than this we terminate. @@ -554,6 +554,8 @@ def event_description(event: str) -> str:
class _BaseControllerSocketMixin: + _socket: stem.socket.ControlSocket + def is_alive(self) -> bool: """ Checks if our socket is currently connected. This is a pass-through for our @@ -589,7 +591,7 @@ class _BaseControllerSocketMixin:
return self._socket.connection_time()
- def get_socket(self): + def get_socket(self) -> stem.socket.ControlSocket: """ Provides the socket used to speak with the tor process. Communicating with the socket directly isn't advised since it may confuse this controller. @@ -839,7 +841,7 @@ class BaseController(_BaseControllerSocketMixin): async def __aenter__(self) -> 'stem.control.BaseController': return self
- await def __aexit__(self, exit_type: Optional[Type[BaseException]], value: Optional[BaseException], traceback: Optional[TracebackType]) -> None: + async def __aexit__(self, exit_type: Optional[Type[BaseException]], value: Optional[BaseException], traceback: Optional[TracebackType]) -> None: await self.close()
async def _handle_event(self, event_message: stem.response.ControlMessage) -> None: @@ -997,7 +999,7 @@ class AsyncController(BaseController): """
@classmethod - def from_port(cls: Type, address: str = '127.0.0.1', port: Union[int, str] = 'default') -> 'stem.control.AsyncController': + def from_port(cls, address: str = '127.0.0.1', port: Union[int, str] = 'default') -> 'AsyncController': """ Constructs a :class:`~stem.socket.ControlPort` based AsyncController.
@@ -1017,7 +1019,7 @@ class AsyncController(BaseController): return cls(control_socket)
@classmethod - def from_socket_file(cls: Type, path: str = '/var/run/tor/control') -> 'stem.control.AsyncController': + def from_socket_file(cls, path: str = '/var/run/tor/control') -> 'AsyncController': """ Constructs a :class:`~stem.socket.ControlSocketFile` based AsyncController.
@@ -1189,8 +1191,8 @@ class AsyncController(BaseController): return list(reply.values())[0]
try: - response = stem.response._convert_to_getinfo(await self.msg('GETINFO %s' % ' '.join(params))) - response._assert_matches(params) + response = stem.response._convert_to_getinfo(await self.msg('GETINFO %s' % ' '.join(param_set))) + response._assert_matches(param_set)
# usually we want unicode values under python 3.x
@@ -1765,7 +1767,7 @@ class AsyncController(BaseController): return stem.descriptor.microdescriptor.Microdescriptor(desc_content)
@with_default(yields = True) - async def get_microdescriptors(self, default: Any = UNDEFINED) -> Iterator[stem.descriptor.microdescriptor.Microdescriptor]: + async def get_microdescriptors(self, default: Any = UNDEFINED) -> AsyncIterator[stem.descriptor.microdescriptor.Microdescriptor]: """ get_microdescriptors(default = UNDEFINED)
@@ -1859,7 +1861,7 @@ class AsyncController(BaseController): return stem.descriptor.server_descriptor.RelayDescriptor(desc_content)
@with_default(yields = True) - async def get_server_descriptors(self, default: Any = UNDEFINED) -> Iterator[stem.descriptor.server_descriptor.RelayDescriptor]: + async def get_server_descriptors(self, default: Any = UNDEFINED) -> AsyncIterator[stem.descriptor.server_descriptor.RelayDescriptor]: """ get_server_descriptors(default = UNDEFINED)
@@ -1954,7 +1956,7 @@ class AsyncController(BaseController): return stem.descriptor.router_status_entry.RouterStatusEntryV3(desc_content)
@with_default(yields = True) - async def get_network_statuses(self, default: Any = UNDEFINED) -> Iterator[stem.descriptor.router_status_entry.RouterStatusEntryV3]: + async def get_network_statuses(self, default: Any = UNDEFINED) -> AsyncIterator[stem.descriptor.router_status_entry.RouterStatusEntryV3]: """ get_network_statuses(default = UNDEFINED)
@@ -2061,6 +2063,7 @@ class AsyncController(BaseController): request += ' ' + ' '.join(['SERVER=%s' % s for s in servers])
response = stem.response._convert_to_single_line(await self.msg(request)) + stem.response.convert('SINGLELINE', response)
if not response.is_ok(): raise stem.ProtocolError('HSFETCH returned unexpected response code: %s' % response.code) @@ -2153,7 +2156,7 @@ class AsyncController(BaseController): async def _get_conf_multiple(self, param: str, default: Any = UNDEFINED) -> List[str]: return await self.get_conf(param, default, multiple = True) # type: ignore
- await def get_conf_map(self, params: Union[str, Sequence[str]], default: Any = UNDEFINED, multiple: bool = True) -> Dict[str, Union[str, Sequence[str]]]: + async def get_conf_map(self, params: Union[str, Sequence[str]], default: Any = UNDEFINED, multiple: bool = True) -> Dict[str, Union[str, Sequence[str]]]: """ get_conf_map(params, default = UNDEFINED, multiple = True)
@@ -2971,7 +2974,7 @@ class AsyncController(BaseController): else: request += ' ClientAuth=%s' % client_name
- response = stem.response._convert_to_add_onion(await self.msg(request)) + response = stem.response._convert_to_add_onion(stem.response._convert_to_add_onion(await self.msg(request)))
if await_publication: # We should receive five UPLOAD events, followed by up to another five @@ -3024,7 +3027,7 @@ class AsyncController(BaseController): else: raise stem.ProtocolError('DEL_ONION returned unexpected response code: %s' % response.code)
- async def add_event_listener(self, listener: Callable[[stem.response.events.Event], None], *events: 'stem.control.EventType') -> None: + async def add_event_listener(self, listener: Callable[[stem.response.events.Event], Union[None, Awaitable[None]]], *events: 'stem.control.EventType') -> None: """ Directs further tor controller events to a given function. The function is expected to take a single argument, which is a @@ -3082,7 +3085,7 @@ class AsyncController(BaseController): if failed_events: raise stem.ProtocolError('SETEVENTS rejected %s' % ', '.join(failed_events))
- async def remove_event_listener(self, listener: Callable[[stem.response.events.Event], None]) -> None: + async def remove_event_listener(self, listener: Callable[[stem.response.events.Event], Union[None, Awaitable[None]]]) -> None: """ Stops a listener from being notified of further tor events.
@@ -3253,7 +3256,7 @@ class AsyncController(BaseController): :raises: :class:`stem.ControllerError` if the call fails """
- response = stem.response._convert_to_single_line(async self.msg('LOADCONF\n%s' % configtext)) + response = stem.response._convert_to_single_line(await self.msg('LOADCONF\n%s' % configtext))
if response.code in ('552', '553'): if response.code == '552' and response.message.startswith('Invalid config file: Failed to parse/validate config: Unknown option'): @@ -3379,7 +3382,7 @@ class AsyncController(BaseController): response = await self.get_info('circuit-status')
for circ in response.splitlines(): - circ_message = stem.response._convert_to_event(await stem.socket.recv_message(io.BytesIO(stem.util.str_tools._to_bytes('650 CIRC %s\r\n' % circ)))) + circ_message = stem.response._convert_to_event(stem.socket.recv_message_from_bytes_io(io.BytesIO(stem.util.str_tools._to_bytes('650 CIRC %s\r\n' % circ)))) circuits.append(circ_message) # type: ignore
return circuits @@ -3563,7 +3566,7 @@ class AsyncController(BaseController): response = await self.get_info('stream-status')
for stream in response.splitlines(): - message = stem.response._convert_to_event(await stem.socket.recv_message(io.BytesIO(stem.util.str_tools._to_bytes('650 STREAM %s\r\n' % stream)))) + message = stem.response._convert_to_event(stem.socket.recv_message_from_bytes_io(io.BytesIO(stem.util.str_tools._to_bytes('650 STREAM %s\r\n' % stream)))) streams.append(message) # type: ignore
return streams @@ -3744,7 +3747,7 @@ class AsyncController(BaseController): response = await self.msg('MAPADDRESS %s' % mapaddress_arg) return stem.response._convert_to_mapaddress(response).entries
- await def drop_guards(self) -> None: + async def drop_guards(self) -> None: """ Drops our present guard nodes and picks a new set.
@@ -3812,7 +3815,7 @@ class AsyncController(BaseController): if listener_type == event_type: for listener in event_listeners: try: - potential_coroutine = listener(event_message) + potential_coroutine = listener(event) if asyncio.iscoroutine(potential_coroutine): await potential_coroutine except Exception as exc: @@ -3874,7 +3877,7 @@ class Controller(_BaseControllerSocketMixin, stem.util.AsyncClassWrapper): """
@classmethod - def from_port(cls: Type, address: str = '127.0.0.1', port: Union[int, str] = 'default') -> 'stem.control.Controller': + def from_port(cls, address: str = '127.0.0.1', port: Union[int, str] = 'default') -> 'Controller': """ Constructs a :class:`~stem.socket.ControlPort` based Controller.
@@ -3885,8 +3888,8 @@ class Controller(_BaseControllerSocketMixin, stem.util.AsyncClassWrapper): .. versionchanged:: 1.5.0 Use both port 9051 and 9151 by default.
- :param str address: ip address of the controller - :param int port: port number of the controller + :param address: ip address of the controller + :param port: port number of the controller
:returns: :class:`~stem.control.Controller` attached to the given port
@@ -3899,7 +3902,7 @@ class Controller(_BaseControllerSocketMixin, stem.util.AsyncClassWrapper): return controller
@classmethod - def from_socket_file(cls: Type, path: str = '/var/run/tor/control') -> 'stem.control.Controller': + def from_socket_file(cls, path: str = '/var/run/tor/control') -> 'Controller': """ Constructs a :class:`~stem.socket.ControlSocketFile` based Controller.
@@ -3915,15 +3918,19 @@ class Controller(_BaseControllerSocketMixin, stem.util.AsyncClassWrapper): controller.connect() return controller
- def __init__(self, control_socket: 'stem.socket.ControlSocket', is_authenticated: bool = False, started_async_controller_thread: Optional['threading.Thread'] = None) -> None: - def __init__(self, control_socket, is_authenticated = False, started_async_controller_thread = None): + def __init__( + self, + control_socket: stem.socket.ControlSocket, + is_authenticated: bool = False, + started_async_controller_thread: stem.util.ThreadForWrappedAsyncClass = None, + ) -> None: if started_async_controller_thread: self._thread_for_wrapped_class = started_async_controller_thread else: self._thread_for_wrapped_class = stem.util.ThreadForWrappedAsyncClass() self._thread_for_wrapped_class.start()
- self._wrapped_instance: AsyncController = self._init_async_class(AsyncController, control_socket, is_authenticated) + self._wrapped_instance: AsyncController = self._init_async_class(AsyncController, control_socket, is_authenticated) # type: ignore self._socket = self._wrapped_instance._socket
@_set_doc_from_async_controller @@ -3956,7 +3963,7 @@ class Controller(_BaseControllerSocketMixin, stem.util.AsyncClassWrapper):
@_set_doc_from_async_controller def remove_status_listener(self, callback: Callable[['stem.control.Controller', 'stem.control.State', float], None]) -> bool: - self._wrapped_instance.remove_status_listener(callback) + return self._wrapped_instance.remove_status_listener(callback)
@_set_doc_from_async_controller def authenticate(self, *args: Any, **kwargs: Any) -> None: @@ -4306,7 +4313,7 @@ def _case_insensitive_lookup(entries: Union[Sequence[str], Mapping[str, Any]], k raise ValueError("key '%s' doesn't exist in dict: %s" % (key, entries))
-async def _get_with_timeout(event_queue: queue.Queue, timeout: float, start_time: float) -> Any: +async def _get_with_timeout(event_queue: asyncio.Queue, timeout: Optional[float], start_time: float) -> Any: """ Pulls an item from a queue with a given timeout. """ diff --git a/stem/descriptor/remote.py b/stem/descriptor/remote.py index d7a833a4..e7ccaa24 100644 --- a/stem/descriptor/remote.py +++ b/stem/descriptor/remote.py @@ -104,7 +104,7 @@ import stem.util.tor_tools
from stem.descriptor import Compression from stem.util import log, str_tools -from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union +from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Sequence, Tuple, Union
# Tor has a limited number of descriptors we can fetch explicitly by their # fingerprint or hashes due to a limit on the url length by squid proxies. @@ -392,7 +392,7 @@ class AsyncQuery(object): self.reply_headers = None # type: Optional[Dict[str, str]] self.kwargs = kwargs
- self._downloader_task = None + self._downloader_task = None # type: Optional[asyncio.Task] self._downloader_lock = threading.RLock()
self._asyncio_loop = asyncio.get_event_loop() @@ -401,7 +401,7 @@ class AsyncQuery(object): self.start()
if block: - self._asyncio_loop.create_task(self.run(True)) + self.run(True)
def start(self) -> None: """ @@ -432,7 +432,7 @@ class AsyncQuery(object):
return [desc async for desc in self._run(suppress)]
- async def _run(self, suppress: bool) -> Iterator[stem.descriptor.Descriptor]: + async def _run(self, suppress: bool) -> AsyncIterator[stem.descriptor.Descriptor]: with self._downloader_lock: self.start() await self._downloader_task @@ -468,7 +468,7 @@ class AsyncQuery(object):
raise self.error
- async def __aiter__(self) -> Iterator[stem.descriptor.Descriptor]: + async def __aiter__(self) -> AsyncIterator[stem.descriptor.Descriptor]: async for desc in self._run(True): yield desc
@@ -665,7 +665,7 @@ class Query(stem.util.AsyncClassWrapper): def __init__(self, resource: str, descriptor_type: Optional[str] = None, endpoints: Optional[Sequence[stem.Endpoint]] = None, compression: Union[stem.descriptor._Compression, Sequence[stem.descriptor._Compression]] = (Compression.GZIP,), retries: int = 2, fall_back_to_authority: bool = False, timeout: Optional[float] = None, start: bool = True, block: bool = False, validate: bool = False, document_handler: stem.descriptor.DocumentHandler = stem.descriptor.DocumentHandler.ENTRIES, **kwargs: Any) -> None: self._thread_for_wrapped_class = stem.util.ThreadForWrappedAsyncClass() self._thread_for_wrapped_class.start() - self._wrapped_instance: AsyncQuery = self._init_async_class( + self._wrapped_instance: AsyncQuery = self._init_async_class( # type: ignore AsyncQuery, resource, descriptor_type, @@ -688,7 +688,7 @@ class Query(stem.util.AsyncClassWrapper):
self._call_async_method_soon('start')
- def run(self, suppress: bool = False) -> List['stem.descriptor.Descriptor']: + def run(self, suppress = False) -> List['stem.descriptor.Descriptor']: """ Blocks until our request is complete then provides the descriptors. If we haven't yet started our request then this does so. @@ -708,7 +708,7 @@ class Query(stem.util.AsyncClassWrapper):
return self._execute_async_method('run', suppress)
- def __iter__(self): + def __iter__(self) -> Iterator[stem.descriptor.Descriptor]: for desc in self._execute_async_generator_method('__aiter__'): yield desc
diff --git a/stem/interpreter/__init__.py b/stem/interpreter/__init__.py index ae064a0a..370b9aa6 100644 --- a/stem/interpreter/__init__.py +++ b/stem/interpreter/__init__.py @@ -124,10 +124,10 @@ def main() -> None:
if args.run_cmd: if args.run_cmd.upper().startswith('SETEVENTS '): - async def handle_event(event_message): + async def handle_event(event_message: stem.response.ControlMessage) -> None: print(format(str(event_message), *STANDARD_OUTPUT))
- controller._wrapped_instance._handle_event = handle_event + controller._wrapped_instance._handle_event = handle_event # type: ignore
if sys.stdout.isatty(): events = args.run_cmd.upper().split(' ', 1)[1] diff --git a/stem/interpreter/autocomplete.py b/stem/interpreter/autocomplete.py index 54642472..ed51fd3d 100644 --- a/stem/interpreter/autocomplete.py +++ b/stem/interpreter/autocomplete.py @@ -11,7 +11,7 @@ import stem.control import stem.util.conf
from stem.interpreter import uses_settings -from typing import List, Optional +from typing import cast, List, Optional
@uses_settings @@ -28,7 +28,7 @@ def _get_commands(controller: stem.control.Controller, config: stem.util.conf.Co # GETINFO commands. Lines are of the form '[option] -- [description]'. This # strips '*' from options that accept values.
- results = controller.get_info('info/names', None) + results = cast(str, controller.get_info('info/names', None))
if results: for line in results.splitlines(): @@ -40,7 +40,7 @@ def _get_commands(controller: stem.control.Controller, config: stem.util.conf.Co # GETCONF, SETCONF, and RESETCONF commands. Lines are of the form # '[option] [type]'.
- results = controller.get_info('config/names', None) + results = cast(str, controller.get_info('config/names', None))
if results: for line in results.splitlines(): @@ -62,7 +62,7 @@ def _get_commands(controller: stem.control.Controller, config: stem.util.conf.Co )
for prefix, getinfo_cmd in options: - results = controller.get_info(getinfo_cmd, None) + results = cast(str, controller.get_info(getinfo_cmd, None))
if results: commands += [prefix + value for value in results.split()] diff --git a/stem/interpreter/commands.py b/stem/interpreter/commands.py index edbcca70..99f1219d 100644 --- a/stem/interpreter/commands.py +++ b/stem/interpreter/commands.py @@ -21,7 +21,7 @@ import stem.util.tor_tools
from stem.interpreter import STANDARD_OUTPUT, BOLD_OUTPUT, ERROR_OUTPUT, uses_settings, msg from stem.util.term import format -from typing import Iterator, List, TextIO +from typing import cast, Iterator, List, TextIO
MAX_EVENTS = 100
@@ -45,7 +45,7 @@ def _get_fingerprint(arg: str, controller: stem.control.Controller) -> str:
if not arg: try: - return controller.get_info('fingerprint') + return cast(str, controller.get_info('fingerprint')) except: raise ValueError("We aren't a relay, no information to provide") elif stem.util.tor_tools.is_valid_fingerprint(arg): @@ -132,14 +132,14 @@ class ControlInterpreter(code.InteractiveConsole):
async def handle_event_wrapper(event_message: stem.response.ControlMessage) -> None: await handle_event_real(event_message) - self._received_events.insert(0, event_message) + self._received_events.insert(0, event_message) # type: ignore
if len(self._received_events) > MAX_EVENTS: self._received_events.pop()
# type check disabled due to https://github.com/python/mypy/issues/708
- self._controller._wrapped_instance._handle_event = handle_event_wrapper + self._controller._wrapped_instance._handle_event = handle_event_wrapper # type: ignore
def get_events(self, *event_types: stem.control.EventType) -> List[stem.response.events.Event]: events = list(self._received_events) @@ -207,7 +207,7 @@ class ControlInterpreter(code.InteractiveConsole): extrainfo_desc_query = downloader.get_extrainfo_descriptors(fingerprint)
for desc in server_desc_query: - server_desc = desc + server_desc = cast(stem.descriptor.server_descriptor.RelayDescriptor, desc)
for desc in extrainfo_desc_query: extrainfo_desc = desc @@ -220,7 +220,7 @@ class ControlInterpreter(code.InteractiveConsole): pass
try: - address_extrainfo.append(self._controller.get_info('ip-to-country/%s' % ns_desc.address)) + address_extrainfo.append(cast(str, self._controller.get_info('ip-to-country/%s' % ns_desc.address))) except: pass
diff --git a/stem/interpreter/help.py b/stem/interpreter/help.py index 14b46e35..f2bbbafd 100644 --- a/stem/interpreter/help.py +++ b/stem/interpreter/help.py @@ -6,6 +6,7 @@ Provides our /help responses. """
import functools +from typing import cast
import stem.control import stem.util.conf @@ -74,7 +75,7 @@ def _response(controller: stem.control.Controller, arg: str, config: stem.util.c output += '\n'
if arg == 'GETINFO': - results = controller.get_info('info/names', None) + results = cast(str, controller.get_info('info/names', None))
if results: for line in results.splitlines(): @@ -84,7 +85,7 @@ def _response(controller: stem.control.Controller, arg: str, config: stem.util.c output += format('%-33s' % opt, *BOLD_OUTPUT) output += format(' - %s' % summary, *STANDARD_OUTPUT) + '\n' elif arg == 'GETCONF': - results = controller.get_info('config/names', None) + results = cast(str, controller.get_info('config/names', None))
if results: options = [opt.split(' ', 1)[0] for opt in results.splitlines()] @@ -103,7 +104,7 @@ def _response(controller: stem.control.Controller, arg: str, config: stem.util.c output += format('%-15s' % signal, *BOLD_OUTPUT) output += format(' - %s' % summary, *STANDARD_OUTPUT) + '\n' elif arg == 'SETEVENTS': - results = controller.get_info('events/names', None) + results = cast(str, controller.get_info('events/names', None))
if results: entries = results.split() @@ -118,7 +119,7 @@ def _response(controller: stem.control.Controller, arg: str, config: stem.util.c
output += format(line.rstrip(), *STANDARD_OUTPUT) + '\n' elif arg == 'USEFEATURE': - results = controller.get_info('features/names', None) + results = cast(str, controller.get_info('features/names', None))
if results: output += format(results, *STANDARD_OUTPUT) + '\n' diff --git a/stem/response/__init__.py b/stem/response/__init__.py index 52dc74e4..2e251144 100644 --- a/stem/response/__init__.py +++ b/stem/response/__init__.py @@ -202,7 +202,7 @@ class ControlMessage(object):
content = re.sub(b'([\r]?)\n', b'\r\n', content)
- msg = stem.socket.recv_message_from_bytes_io(io.BytesIO(content), arrived_at = kwargs.pop('arrived_at', None)) + msg = stem.socket.recv_message_from_bytes_io(io.BytesIO(stem.util.str_tools._to_bytes(content)), arrived_at = kwargs.pop('arrived_at', None))
if msg_type is not None: convert(msg_type, msg, **kwargs) diff --git a/stem/socket.py b/stem/socket.py index 0feae831..ff99c5b1 100644 --- a/stem/socket.py +++ b/stem/socket.py @@ -85,6 +85,7 @@ import asyncio import re import socket import ssl +import sys import threading import time
@@ -93,7 +94,7 @@ import stem.util.str_tools
from stem.util import log from types import TracebackType -from typing import BinaryIO, Callable, List, Optional, Tuple, Type, Union, overload +from typing import Awaitable, BinaryIO, Callable, List, Optional, Tuple, Type, Union, overload
MESSAGE_PREFIX = re.compile(b'^[a-zA-Z0-9]{3}[-+ ]') ERROR_MSG = 'Error while receiving a control message (%s): %s' @@ -109,8 +110,8 @@ class BaseSocket(object): """
def __init__(self) -> None: - self._reader = None - self._writer = None + self._reader = None # type: Optional[asyncio.StreamReader] + self._writer = None # type: Optional[asyncio.StreamWriter] self._is_alive = False self._connection_time = 0.0 # time when we last connected or disconnected
@@ -209,7 +210,7 @@ class BaseSocket(object): if self._writer: self._writer.close() # `StreamWriter.wait_closed` was added in Python 3.7. - if hasattr(self._writer, 'wait_closed'): + if sys.version_info >= (3, 7): await self._writer.wait_closed()
self._reader = None @@ -220,7 +221,7 @@ class BaseSocket(object): if is_change: await self._close()
- async def _send(self, message: Union[bytes, str], handler: Callable[[Union[socket.socket, ssl.SSLSocket], BinaryIO, Union[bytes, str]], None]) -> None: + async def _send(self, message: Union[bytes, str], handler: Callable[[asyncio.StreamWriter, Union[bytes, str]], Awaitable[None]]) -> None: """ Send message in a thread safe manner. """ @@ -241,11 +242,11 @@ class BaseSocket(object): raise
@overload - async def _recv(self, handler: Callable[[ssl.SSLSocket, BinaryIO], bytes]) -> bytes: + async def _recv(self, handler: Callable[[asyncio.StreamReader], Awaitable[bytes]]) -> bytes: ...
@overload - async def _recv(self, handler: Callable[[socket.socket, BinaryIO], stem.response.ControlMessage]) -> stem.response.ControlMessage: + async def _recv(self, handler: Callable[[asyncio.StreamReader], Awaitable[stem.response.ControlMessage]]) -> stem.response.ControlMessage: ...
async def _recv(self, handler): @@ -303,7 +304,7 @@ class BaseSocket(object): return self
async def __aexit__(self, exit_type: Optional[Type[BaseException]], value: Optional[BaseException], traceback: Optional[TracebackType]): - self.close() + await self.close()
async def _connect(self) -> None: """ @@ -320,12 +321,6 @@ class BaseSocket(object): pass
async def _open_connection(self) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]: - """ - Constructs and connects new socket. This is implemented by subclasses. - - :returns: **tuple** with our reader and writer streams - """ - raise NotImplementedError('Unsupported Operation: this should be implemented by the BaseSocket subclass')
@@ -380,7 +375,7 @@ class RelaySocket(BaseSocket): * :class:`stem.SocketClosed` if the socket closes before we receive a complete message """
- async def wrapped_recv(s: ssl.SSLSocket, sf: BinaryIO) -> bytes: + async def wrapped_recv(reader: asyncio.StreamReader) -> Optional[bytes]: read_coroutine = reader.read(1024) if timeout is None: return await read_coroutine @@ -524,7 +519,7 @@ async def send_message(writer: asyncio.StreamWriter, message: Union[bytes, str], <line 3>\r\n .\r\n
- :param writer: stream derived from the control socket + :param writer: writer object :param message: message to be sent on the control socket :param raw: leaves the message formatting untouched, passing it to the socket as-is @@ -591,9 +586,7 @@ async def recv_message(reader: asyncio.StreamReader, arrived_at: Optional[float]
while True: try: - line = reader.readline() - if asyncio.iscoroutine(line): - line = await line + line = await reader.readline() except AttributeError: # if the control_file has been closed then we will receive: # AttributeError: 'NoneType' object has no attribute 'recv' @@ -693,7 +686,7 @@ async def recv_message(reader: asyncio.StreamReader, arrived_at: Optional[float] raise stem.ProtocolError("Unrecognized divider type '%s': %s" % (divider, stem.util.str_tools._to_unicode(line)))
-def recv_message_from_bytes_io(reader: asyncio.StreamReader, arrived_at: Optional[float] = None) -> 'stem.response.ControlMessage': +def recv_message_from_bytes_io(reader: BinaryIO, arrived_at: Optional[float] = None) -> stem.response.ControlMessage: """ Pulls from an I/O stream until we either have a complete message or encounter a problem. @@ -708,7 +701,9 @@ def recv_message_from_bytes_io(reader: asyncio.StreamReader, arrived_at: Optiona a complete message """
- parsed_content, raw_content, first_line = None, None, True + parsed_content = [] # type: List[Tuple[str, str, bytes]] + raw_content = bytearray() + first_line = True
while True: try: @@ -739,10 +734,10 @@ def recv_message_from_bytes_io(reader: asyncio.StreamReader, arrived_at: Optiona log.info(ERROR_MSG % ('SocketClosed', 'empty socket content')) raise stem.SocketClosed('Received empty socket content.') elif not MESSAGE_PREFIX.match(line): - log.info(ERROR_MSG % ('ProtocolError', 'malformed status code/divider, "%s"' % log.escape(line))) + log.info(ERROR_MSG % ('ProtocolError', 'malformed status code/divider, "%s"' % log.escape(line.decode('utf-8')))) raise stem.ProtocolError('Badly formatted reply line: beginning is malformed') elif not line.endswith(b'\r\n'): - log.info(ERROR_MSG % ('ProtocolError', 'no CRLF linebreak, "%s"' % log.escape(line))) + log.info(ERROR_MSG % ('ProtocolError', 'no CRLF linebreak, "%s"' % log.escape(line.decode('utf-8')))) raise stem.ProtocolError('All lines should end with CRLF')
status_code, divider, content = line[:3], line[3:4], line[4:-2] # strip CRLF off content @@ -781,11 +776,11 @@ def recv_message_from_bytes_io(reader: asyncio.StreamReader, arrived_at: Optiona line = reader.readline() raw_content += line except socket.error as exc: - log.info(ERROR_MSG % ('SocketClosed', 'received an exception while mid-way through a data reply (exception: "%s", read content: "%s")' % (exc, log.escape(bytes(raw_content))))) + log.info(ERROR_MSG % ('SocketClosed', 'received an exception while mid-way through a data reply (exception: "%s", read content: "%s")' % (exc, log.escape(bytes(raw_content).decode('utf-8'))))) raise stem.SocketClosed(exc)
if not line.endswith(b'\r\n'): - log.info(ERROR_MSG % ('ProtocolError', 'CRLF linebreaks missing from a data reply, "%s"' % log.escape(bytes(raw_content)))) + log.info(ERROR_MSG % ('ProtocolError', 'CRLF linebreaks missing from a data reply, "%s"' % log.escape(bytes(raw_content).decode('utf-8')))) raise stem.ProtocolError('All lines should end with CRLF') elif line == b'.\r\n': break # data block termination diff --git a/test/integ/connection/authentication.py b/test/integ/connection/authentication.py index 683e555f..3eaae8d9 100644 --- a/test/integ/connection/authentication.py +++ b/test/integ/connection/authentication.py @@ -3,6 +3,7 @@ Integration tests for authenticating to the control socket via stem.connection.authenticate* functions. """
+import asyncio import os import unittest
@@ -121,7 +122,10 @@ class TestAuthenticate(unittest.TestCase): runner = test.runner.get_runner()
with await runner.get_tor_controller(False) as controller: - await stem.connection.authenticate(controller, test.runner.CONTROL_PASSWORD, runner.get_chroot()) + asyncio.run_coroutine_threadsafe( + stem.connection.authenticate(controller._wrapped_instance, test.runner.CONTROL_PASSWORD, runner.get_chroot()), + controller._thread_for_wrapped_class.loop, + ).result() await test.runner.exercise_controller(self, controller)
@test.require.controller diff --git a/test/unit/control/controller.py b/test/unit/control/controller.py index a11aba45..e4b11788 100644 --- a/test/unit/control/controller.py +++ b/test/unit/control/controller.py @@ -222,7 +222,7 @@ class TestControl(unittest.TestCase):
get_info_mock.side_effect = coro_func_raising_exc(InvalidArguments)
- async def get_conf_mock_side_effect(param, **kwargs): + async def get_conf_mock_side_effect(param, *args, **kwargs): return { 'ControlPort': '9050', 'ControlListenAddress': ['127.0.0.1'], @@ -236,7 +236,7 @@ class TestControl(unittest.TestCase):
# non-local addresss
- async def get_conf_mock_side_effect(param, **kwargs): + async def get_conf_mock_side_effect(param, *args, **kwargs): return { 'ControlPort': '9050', 'ControlListenAddress': ['27.4.4.1'], @@ -717,7 +717,7 @@ class TestControl(unittest.TestCase):
# check default if nothing was set
- async def get_conf_mock_side_effect(param, **kwargs): + async def get_conf_mock_side_effect(param, *args, **kwargs): return { 'BandwidthRate': '1073741824', 'BandwidthBurst': '1073741824',