[tor-commits] [stem/master] Correct rebase discrepancies

atagar at torproject.org atagar at torproject.org
Thu Jul 16 01:29:00 UTC 2020


commit 18a3280d9cb27a81bb01d8e964449adda3dc734e
Author: Damian Johnson <atagar at torproject.org>
Date:   Mon May 18 14:43:18 2020 -0700

    Correct rebase discrepancies
    
    To make Illia's branch cleanly mergeable I rebased onto our present master.
    Manually resolving the conflicts produced a slightly different result than
    he had. This delta makes us perfectly match his commit c788fd8.
---
 stem/client/__init__.py                 | 11 +++---
 stem/connection.py                      | 25 ++++++-------
 stem/control.py                         | 63 ++++++++++++++++++---------------
 stem/descriptor/remote.py               | 16 ++++-----
 stem/interpreter/__init__.py            |  4 +--
 stem/interpreter/autocomplete.py        |  8 ++---
 stem/interpreter/commands.py            | 12 +++----
 stem/interpreter/help.py                |  9 ++---
 stem/response/__init__.py               |  2 +-
 stem/socket.py                          | 45 +++++++++++------------
 test/integ/connection/authentication.py |  6 +++-
 test/unit/control/controller.py         |  6 ++--
 12 files changed, 105 insertions(+), 102 deletions(-)

diff --git a/stem/client/__init__.py b/stem/client/__init__.py
index 941f0ee7..8ea7b3c1 100644
--- a/stem/client/__init__.py
+++ b/stem/client/__init__.py
@@ -33,7 +33,7 @@ import stem.socket
 import stem.util.connection
 
 from types import TracebackType
-from typing import Dict, Iterator, List, Optional, Sequence, Type, Union
+from typing import AsyncIterator, Dict, List, Optional, Sequence, Type, Union
 
 from stem.client.cell import (
   CELL_TYPE_SIZE,
@@ -70,7 +70,8 @@ class Relay(object):
     self.link_protocol = LinkProtocol(link_protocol)
     self._orport = orport
     self._orport_buffer = b''  # unread bytes
-    self._circuits = {}
+    self._orport_lock = stem.util.CombinedReentrantAndAsyncioLock()
+    self._circuits = {}  # type: Dict[int, stem.client.Circuit]
 
   @staticmethod
   async def connect(address: str, port: int, link_protocols: Sequence['stem.client.datatype.LinkProtocol'] = DEFAULT_LINK_PROTOCOLS) -> 'stem.client.Relay':  # type: ignore
@@ -191,7 +192,7 @@ class Relay(object):
         cell, self._orport_buffer = Cell.pop(self._orport_buffer, self.link_protocol)
         return cell
 
-  async def _msg(self, cell: 'stem.client.cell.Cell') -> Iterator['stem.client.cell.Cell']:
+  async def _msg(self, cell: 'stem.client.cell.Cell') -> AsyncIterator['stem.client.cell.Cell']:
     """
     Sends a cell on the ORPort and provides the response we receive in reply.
 
@@ -283,7 +284,7 @@ class Relay(object):
 
       return circ
 
-  async def __aiter__(self) -> Iterator['stem.client.Circuit']:
+  async def __aiter__(self) -> AsyncIterator['stem.client.Circuit']:
     async with self._orport_lock:
       for circ in self._circuits.values():
         yield circ
@@ -381,7 +382,7 @@ class Circuit(object):
       self.forward_digest = forward_digest
       self.forward_key = forward_key
 
-  async def close(self)- > None:
+  async def close(self) -> None:
     async with self.relay._orport_lock:
       await self.relay._orport.send(stem.client.cell.DestroyCell(self.id).pack(self.relay.link_protocol))
       del self.relay._circuits[self.id]
diff --git a/stem/connection.py b/stem/connection.py
index de76a345..213ba010 100644
--- a/stem/connection.py
+++ b/stem/connection.py
@@ -159,7 +159,7 @@ import stem.util.str_tools
 import stem.util.system
 import stem.version
 
-from typing import Any, List, Optional, Sequence, Tuple, Type, Union
+from typing import Any, cast, List, Optional, Sequence, Tuple, Type, Union
 from stem.util import log
 
 AuthMethod = stem.util.enum.Enum('NONE', 'PASSWORD', 'COOKIE', 'SAFECOOKIE', 'UNKNOWN')
@@ -227,7 +227,7 @@ COMMON_TOR_COMMANDS = (
 )
 
 
-def connect(control_port: Tuple[str, Union[str, int]] = ('127.0.0.1', 'default'), control_socket: str = '/var/run/tor/control', password: Optional[str] = None, password_prompt: bool = False, chroot_path: Optional[str] = None, controller: Type = stem.control.Controller) -> Any:
+def connect(control_port: Tuple[str, Union[str, int]] = ('127.0.0.1', 'default'), control_socket: str = '/var/run/tor/control', password: Optional[str] = None, password_prompt: bool = False, chroot_path: Optional[str] = None, controller: Type[stem.control.Controller] = stem.control.Controller) -> Any:
   """
   Convenience function for quickly getting a control connection for synchronous
   usage. This is very handy for debugging or CLI setup, handling setup and
@@ -269,7 +269,7 @@ def connect(control_port: Tuple[str, Union[str, int]] = ('127.0.0.1', 'default')
   # TODO: change this function's API so we can provide a concrete type
 
   if controller is None or not issubclass(controller, stem.control.Controller):
-    raise ValueError('Controller should be a stem.control.BaseController subclass.')
+    raise ValueError('Controller should be a stem.control.Controller subclass.')
 
   async_controller_thread = stem.util.ThreadForWrappedAsyncClass()
   async_controller_thread.start()
@@ -326,7 +326,7 @@ async def connect_async(control_port: Tuple[str, Union[str, int]] = ('127.0.0.1'
   return await _connect_async(control_port, control_socket, password, password_prompt, chroot_path, controller)
 
 
-async def _connect_async(control_port, control_socket, password, password_prompt, chroot_path, controller):
+async def _connect_async(control_port: Tuple[str, Union[str, int]], control_socket: str, password: Optional[str], password_prompt: bool, chroot_path: Optional[str], controller: Type[Union[stem.control.BaseController, stem.control.Controller]]) -> Any:
   if control_port is None and control_socket is None:
     raise ValueError('Neither a control port nor control socket were provided. Nothing to connect to.')
   elif control_port:
@@ -377,7 +377,7 @@ async def _connect_async(control_port, control_socket, password, password_prompt
   return await _connect_auth(control_connection, password, password_prompt, chroot_path, controller)
 
 
-async def _connect_auth(control_socket: stem.socket.ControlSocket, password: str, password_prompt: bool, chroot_path: str, controller: Optional[Type[stem.control.BaseController]]) -> Any:
+async def _connect_auth(control_socket: stem.socket.ControlSocket, password: str, password_prompt: bool, chroot_path: str, controller: Optional[Type[Union[stem.control.BaseController, stem.control.Controller]]]) -> Any:
   """
   Helper for the connect_* functions that authenticates the socket and
   constructs the controller.
@@ -402,7 +402,7 @@ async def _connect_auth(control_socket: stem.socket.ControlSocket, password: str
     elif issubclass(controller, stem.control.BaseController):
       return controller(control_socket, is_authenticated = True)
     elif issubclass(controller, stem.control.Controller):
-      return controller(control_socket, is_authenticated = True, started_async_controller_thread = threading.current_thread())
+      return controller(control_socket, is_authenticated = True, started_async_controller_thread = cast(stem.util.ThreadForWrappedAsyncClass, threading.current_thread()))
   except IncorrectSocketType:
     if isinstance(control_socket, stem.socket.ControlPort):
       print(CONNECT_MESSAGES['wrong_port_type'].format(port = control_socket.port))
@@ -995,7 +995,7 @@ async def authenticate_safecookie(controller: Union[stem.control.BaseController,
     auth_response = await _msg(controller, 'AUTHENTICATE %s' % stem.util.str_tools._to_unicode(binascii.b2a_hex(client_hash)))
   except stem.ControllerError as exc:
     try:
-      controller.connect()
+      await controller.connect()
     except:
       pass
 
@@ -1007,7 +1007,7 @@ async def authenticate_safecookie(controller: Union[stem.control.BaseController,
   # if we got anything but an OK response then err
   if not auth_response.is_ok():
     try:
-      controller.connect()
+      await controller.connect()
     except:
       pass
 
@@ -1051,9 +1051,7 @@ async def get_protocolinfo(controller: Union[stem.control.BaseController, stem.s
   # next followed by authentication. Transparently reconnect if that happens.
 
   if not protocolinfo_response or str(protocolinfo_response) == 'Authentication required.':
-    potential_coroutine = controller.connect()
-    if asyncio.iscoroutine(potential_coroutine):
-      await potential_coroutine
+    await controller.connect()
 
     try:
       protocolinfo_response = await _msg(controller, 'PROTOCOLINFO 1')
@@ -1074,10 +1072,7 @@ async def _msg(controller: Union[stem.control.BaseController, stem.socket.Contro
     await controller.send(message)
     return await controller.recv()
   else:
-    message = controller.msg(message)
-    if asyncio.iscoroutine(message):
-      message = await message
-    return message
+    return await controller.msg(message)
 
 
 def _connection_for_default_port(address: str) -> stem.socket.ControlPort:
diff --git a/stem/control.py b/stem/control.py
index 6e15c16c..84f8f39b 100644
--- a/stem/control.py
+++ b/stem/control.py
@@ -271,7 +271,7 @@ import stem.version
 from stem import UNDEFINED, CircStatus, Signal
 from stem.util import log
 from types import TracebackType
-from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Sequence, Tuple, Type, Union
+from typing import Any, AsyncIterator, Awaitable, Callable, Dict, Iterator, List, Mapping, Optional, Sequence, Tuple, Type, Union
 
 # When closing the controller we attempt to finish processing enqueued events,
 # but if it takes longer than this we terminate.
@@ -554,6 +554,8 @@ def event_description(event: str) -> str:
 
 
 class _BaseControllerSocketMixin:
+  _socket: stem.socket.ControlSocket
+
   def is_alive(self) -> bool:
     """
     Checks if our socket is currently connected. This is a pass-through for our
@@ -589,7 +591,7 @@ class _BaseControllerSocketMixin:
 
     return self._socket.connection_time()
 
-  def get_socket(self):
+  def get_socket(self) -> stem.socket.ControlSocket:
     """
     Provides the socket used to speak with the tor process. Communicating with
     the socket directly isn't advised since it may confuse this controller.
@@ -839,7 +841,7 @@ class BaseController(_BaseControllerSocketMixin):
   async def __aenter__(self) -> 'stem.control.BaseController':
     return self
 
-  await def __aexit__(self, exit_type: Optional[Type[BaseException]], value: Optional[BaseException], traceback: Optional[TracebackType]) -> None:
+  async def __aexit__(self, exit_type: Optional[Type[BaseException]], value: Optional[BaseException], traceback: Optional[TracebackType]) -> None:
     await self.close()
 
   async def _handle_event(self, event_message: stem.response.ControlMessage) -> None:
@@ -997,7 +999,7 @@ class AsyncController(BaseController):
   """
 
   @classmethod
-  def from_port(cls: Type, address: str = '127.0.0.1', port: Union[int, str] = 'default') -> 'stem.control.AsyncController':
+  def from_port(cls, address: str = '127.0.0.1', port: Union[int, str] = 'default') -> 'AsyncController':
     """
     Constructs a :class:`~stem.socket.ControlPort` based AsyncController.
 
@@ -1017,7 +1019,7 @@ class AsyncController(BaseController):
     return cls(control_socket)
 
   @classmethod
-  def from_socket_file(cls: Type, path: str = '/var/run/tor/control') -> 'stem.control.AsyncController':
+  def from_socket_file(cls, path: str = '/var/run/tor/control') -> 'AsyncController':
     """
     Constructs a :class:`~stem.socket.ControlSocketFile` based AsyncController.
 
@@ -1189,8 +1191,8 @@ class AsyncController(BaseController):
         return list(reply.values())[0]
 
     try:
-      response = stem.response._convert_to_getinfo(await self.msg('GETINFO %s' % ' '.join(params)))
-      response._assert_matches(params)
+      response = stem.response._convert_to_getinfo(await self.msg('GETINFO %s' % ' '.join(param_set)))
+      response._assert_matches(param_set)
 
       # usually we want unicode values under python 3.x
 
@@ -1765,7 +1767,7 @@ class AsyncController(BaseController):
     return stem.descriptor.microdescriptor.Microdescriptor(desc_content)
 
   @with_default(yields = True)
-  async def get_microdescriptors(self, default: Any = UNDEFINED) -> Iterator[stem.descriptor.microdescriptor.Microdescriptor]:
+  async def get_microdescriptors(self, default: Any = UNDEFINED) -> AsyncIterator[stem.descriptor.microdescriptor.Microdescriptor]:
     """
     get_microdescriptors(default = UNDEFINED)
 
@@ -1859,7 +1861,7 @@ class AsyncController(BaseController):
     return stem.descriptor.server_descriptor.RelayDescriptor(desc_content)
 
   @with_default(yields = True)
-  async def get_server_descriptors(self, default: Any = UNDEFINED) -> Iterator[stem.descriptor.server_descriptor.RelayDescriptor]:
+  async def get_server_descriptors(self, default: Any = UNDEFINED) -> AsyncIterator[stem.descriptor.server_descriptor.RelayDescriptor]:
     """
     get_server_descriptors(default = UNDEFINED)
 
@@ -1954,7 +1956,7 @@ class AsyncController(BaseController):
     return stem.descriptor.router_status_entry.RouterStatusEntryV3(desc_content)
 
   @with_default(yields = True)
-  async def get_network_statuses(self, default: Any = UNDEFINED) -> Iterator[stem.descriptor.router_status_entry.RouterStatusEntryV3]:
+  async def get_network_statuses(self, default: Any = UNDEFINED) -> AsyncIterator[stem.descriptor.router_status_entry.RouterStatusEntryV3]:
     """
     get_network_statuses(default = UNDEFINED)
 
@@ -2061,6 +2063,7 @@ class AsyncController(BaseController):
         request += ' ' + ' '.join(['SERVER=%s' % s for s in servers])
 
       response = stem.response._convert_to_single_line(await self.msg(request))
+      stem.response.convert('SINGLELINE', response)
 
       if not response.is_ok():
         raise stem.ProtocolError('HSFETCH returned unexpected response code: %s' % response.code)
@@ -2153,7 +2156,7 @@ class AsyncController(BaseController):
   async def _get_conf_multiple(self, param: str, default: Any = UNDEFINED) -> List[str]:
     return await self.get_conf(param, default, multiple = True)  # type: ignore
 
-  await def get_conf_map(self, params: Union[str, Sequence[str]], default: Any = UNDEFINED, multiple: bool = True) -> Dict[str, Union[str, Sequence[str]]]:
+  async def get_conf_map(self, params: Union[str, Sequence[str]], default: Any = UNDEFINED, multiple: bool = True) -> Dict[str, Union[str, Sequence[str]]]:
     """
     get_conf_map(params, default = UNDEFINED, multiple = True)
 
@@ -2971,7 +2974,7 @@ class AsyncController(BaseController):
         else:
           request += ' ClientAuth=%s' % client_name
 
-    response = stem.response._convert_to_add_onion(await self.msg(request))
+    response = stem.response._convert_to_add_onion(stem.response._convert_to_add_onion(await self.msg(request)))
 
     if await_publication:
       # We should receive five UPLOAD events, followed by up to another five
@@ -3024,7 +3027,7 @@ class AsyncController(BaseController):
     else:
       raise stem.ProtocolError('DEL_ONION returned unexpected response code: %s' % response.code)
 
-  async def add_event_listener(self, listener: Callable[[stem.response.events.Event], None], *events: 'stem.control.EventType') -> None:
+  async def add_event_listener(self, listener: Callable[[stem.response.events.Event], Union[None, Awaitable[None]]], *events: 'stem.control.EventType') -> None:
     """
     Directs further tor controller events to a given function. The function is
     expected to take a single argument, which is a
@@ -3082,7 +3085,7 @@ class AsyncController(BaseController):
       if failed_events:
         raise stem.ProtocolError('SETEVENTS rejected %s' % ', '.join(failed_events))
 
-  async def remove_event_listener(self, listener: Callable[[stem.response.events.Event], None]) -> None:
+  async def remove_event_listener(self, listener: Callable[[stem.response.events.Event], Union[None, Awaitable[None]]]) -> None:
     """
     Stops a listener from being notified of further tor events.
 
@@ -3253,7 +3256,7 @@ class AsyncController(BaseController):
     :raises: :class:`stem.ControllerError` if the call fails
     """
 
-    response = stem.response._convert_to_single_line(async self.msg('LOADCONF\n%s' % configtext))
+    response = stem.response._convert_to_single_line(await self.msg('LOADCONF\n%s' % configtext))
 
     if response.code in ('552', '553'):
       if response.code == '552' and response.message.startswith('Invalid config file: Failed to parse/validate config: Unknown option'):
@@ -3379,7 +3382,7 @@ class AsyncController(BaseController):
     response = await self.get_info('circuit-status')
 
     for circ in response.splitlines():
-      circ_message = stem.response._convert_to_event(await stem.socket.recv_message(io.BytesIO(stem.util.str_tools._to_bytes('650 CIRC %s\r\n' % circ))))
+      circ_message = stem.response._convert_to_event(stem.socket.recv_message_from_bytes_io(io.BytesIO(stem.util.str_tools._to_bytes('650 CIRC %s\r\n' % circ))))
       circuits.append(circ_message)  # type: ignore
 
     return circuits
@@ -3563,7 +3566,7 @@ class AsyncController(BaseController):
     response = await self.get_info('stream-status')
 
     for stream in response.splitlines():
-      message = stem.response._convert_to_event(await stem.socket.recv_message(io.BytesIO(stem.util.str_tools._to_bytes('650 STREAM %s\r\n' % stream))))
+      message = stem.response._convert_to_event(stem.socket.recv_message_from_bytes_io(io.BytesIO(stem.util.str_tools._to_bytes('650 STREAM %s\r\n' % stream))))
       streams.append(message)  # type: ignore
 
     return streams
@@ -3744,7 +3747,7 @@ class AsyncController(BaseController):
     response = await self.msg('MAPADDRESS %s' % mapaddress_arg)
     return stem.response._convert_to_mapaddress(response).entries
 
-  await def drop_guards(self) -> None:
+  async def drop_guards(self) -> None:
     """
     Drops our present guard nodes and picks a new set.
 
@@ -3812,7 +3815,7 @@ class AsyncController(BaseController):
         if listener_type == event_type:
           for listener in event_listeners:
             try:
-              potential_coroutine = listener(event_message)
+              potential_coroutine = listener(event)
               if asyncio.iscoroutine(potential_coroutine):
                 await potential_coroutine
             except Exception as exc:
@@ -3874,7 +3877,7 @@ class Controller(_BaseControllerSocketMixin, stem.util.AsyncClassWrapper):
   """
 
   @classmethod
-  def from_port(cls: Type, address: str = '127.0.0.1', port: Union[int, str] = 'default') -> 'stem.control.Controller':
+  def from_port(cls, address: str = '127.0.0.1', port: Union[int, str] = 'default') -> 'Controller':
     """
     Constructs a :class:`~stem.socket.ControlPort` based Controller.
 
@@ -3885,8 +3888,8 @@ class Controller(_BaseControllerSocketMixin, stem.util.AsyncClassWrapper):
     .. versionchanged:: 1.5.0
        Use both port 9051 and 9151 by default.
 
-    :param str address: ip address of the controller
-    :param int port: port number of the controller
+    :param address: ip address of the controller
+    :param port: port number of the controller
 
     :returns: :class:`~stem.control.Controller` attached to the given port
 
@@ -3899,7 +3902,7 @@ class Controller(_BaseControllerSocketMixin, stem.util.AsyncClassWrapper):
     return controller
 
   @classmethod
-  def from_socket_file(cls: Type, path: str = '/var/run/tor/control') -> 'stem.control.Controller':
+  def from_socket_file(cls, path: str = '/var/run/tor/control') -> 'Controller':
     """
     Constructs a :class:`~stem.socket.ControlSocketFile` based Controller.
 
@@ -3915,15 +3918,19 @@ class Controller(_BaseControllerSocketMixin, stem.util.AsyncClassWrapper):
     controller.connect()
     return controller
 
-  def __init__(self, control_socket: 'stem.socket.ControlSocket', is_authenticated: bool = False, started_async_controller_thread: Optional['threading.Thread'] = None) -> None:
-  def __init__(self, control_socket, is_authenticated = False, started_async_controller_thread = None):
+  def __init__(
+      self,
+      control_socket: stem.socket.ControlSocket,
+      is_authenticated: bool = False,
+      started_async_controller_thread: stem.util.ThreadForWrappedAsyncClass = None,
+  ) -> None:
     if started_async_controller_thread:
       self._thread_for_wrapped_class = started_async_controller_thread
     else:
       self._thread_for_wrapped_class = stem.util.ThreadForWrappedAsyncClass()
       self._thread_for_wrapped_class.start()
 
-    self._wrapped_instance: AsyncController = self._init_async_class(AsyncController, control_socket, is_authenticated)
+    self._wrapped_instance: AsyncController = self._init_async_class(AsyncController, control_socket, is_authenticated)  # type: ignore
     self._socket = self._wrapped_instance._socket
 
   @_set_doc_from_async_controller
@@ -3956,7 +3963,7 @@ class Controller(_BaseControllerSocketMixin, stem.util.AsyncClassWrapper):
 
   @_set_doc_from_async_controller
   def remove_status_listener(self, callback: Callable[['stem.control.Controller', 'stem.control.State', float], None]) -> bool:
-    self._wrapped_instance.remove_status_listener(callback)
+    return self._wrapped_instance.remove_status_listener(callback)
 
   @_set_doc_from_async_controller
   def authenticate(self, *args: Any, **kwargs: Any) -> None:
@@ -4306,7 +4313,7 @@ def _case_insensitive_lookup(entries: Union[Sequence[str], Mapping[str, Any]], k
   raise ValueError("key '%s' doesn't exist in dict: %s" % (key, entries))
 
 
-async def _get_with_timeout(event_queue: queue.Queue, timeout: float, start_time: float) -> Any:
+async def _get_with_timeout(event_queue: asyncio.Queue, timeout: Optional[float], start_time: float) -> Any:
   """
   Pulls an item from a queue with a given timeout.
   """
diff --git a/stem/descriptor/remote.py b/stem/descriptor/remote.py
index d7a833a4..e7ccaa24 100644
--- a/stem/descriptor/remote.py
+++ b/stem/descriptor/remote.py
@@ -104,7 +104,7 @@ import stem.util.tor_tools
 
 from stem.descriptor import Compression
 from stem.util import log, str_tools
-from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union
+from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Sequence, Tuple, Union
 
 # Tor has a limited number of descriptors we can fetch explicitly by their
 # fingerprint or hashes due to a limit on the url length by squid proxies.
@@ -392,7 +392,7 @@ class AsyncQuery(object):
     self.reply_headers = None  # type: Optional[Dict[str, str]]
     self.kwargs = kwargs
 
-    self._downloader_task = None
+    self._downloader_task = None  # type: Optional[asyncio.Task]
     self._downloader_lock = threading.RLock()
 
     self._asyncio_loop = asyncio.get_event_loop()
@@ -401,7 +401,7 @@ class AsyncQuery(object):
       self.start()
 
     if block:
-      self._asyncio_loop.create_task(self.run(True))
+      self.run(True)
 
   def start(self) -> None:
     """
@@ -432,7 +432,7 @@ class AsyncQuery(object):
 
     return [desc async for desc in self._run(suppress)]
 
-  async def _run(self, suppress: bool) -> Iterator[stem.descriptor.Descriptor]:
+  async def _run(self, suppress: bool) -> AsyncIterator[stem.descriptor.Descriptor]:
     with self._downloader_lock:
       self.start()
       await self._downloader_task
@@ -468,7 +468,7 @@ class AsyncQuery(object):
 
           raise self.error
 
-  async def __aiter__(self) -> Iterator[stem.descriptor.Descriptor]:
+  async def __aiter__(self) -> AsyncIterator[stem.descriptor.Descriptor]:
     async for desc in self._run(True):
       yield desc
 
@@ -665,7 +665,7 @@ class Query(stem.util.AsyncClassWrapper):
   def __init__(self, resource: str, descriptor_type: Optional[str] = None, endpoints: Optional[Sequence[stem.Endpoint]] = None, compression: Union[stem.descriptor._Compression, Sequence[stem.descriptor._Compression]] = (Compression.GZIP,), retries: int = 2, fall_back_to_authority: bool = False, timeout: Optional[float] = None, start: bool = True, block: bool = False, validate: bool = False, document_handler: stem.descriptor.DocumentHandler = stem.descriptor.DocumentHandler.ENTRIES, **kwargs: Any) -> None:
     self._thread_for_wrapped_class = stem.util.ThreadForWrappedAsyncClass()
     self._thread_for_wrapped_class.start()
-    self._wrapped_instance: AsyncQuery = self._init_async_class(
+    self._wrapped_instance: AsyncQuery = self._init_async_class(  # type: ignore
       AsyncQuery,
       resource,
       descriptor_type,
@@ -688,7 +688,7 @@ class Query(stem.util.AsyncClassWrapper):
 
     self._call_async_method_soon('start')
 
-  def run(self, suppress: bool = False) -> List['stem.descriptor.Descriptor']:
+  def run(self, suppress = False) -> List['stem.descriptor.Descriptor']:
     """
     Blocks until our request is complete then provides the descriptors. If we
     haven't yet started our request then this does so.
@@ -708,7 +708,7 @@ class Query(stem.util.AsyncClassWrapper):
 
     return self._execute_async_method('run', suppress)
 
-  def __iter__(self):
+  def __iter__(self) -> Iterator[stem.descriptor.Descriptor]:
     for desc in self._execute_async_generator_method('__aiter__'):
       yield desc
 
diff --git a/stem/interpreter/__init__.py b/stem/interpreter/__init__.py
index ae064a0a..370b9aa6 100644
--- a/stem/interpreter/__init__.py
+++ b/stem/interpreter/__init__.py
@@ -124,10 +124,10 @@ def main() -> None:
 
     if args.run_cmd:
       if args.run_cmd.upper().startswith('SETEVENTS '):
-        async def handle_event(event_message):
+        async def handle_event(event_message: stem.response.ControlMessage) -> None:
           print(format(str(event_message), *STANDARD_OUTPUT))
 
-        controller._wrapped_instance._handle_event = handle_event
+        controller._wrapped_instance._handle_event = handle_event  # type: ignore
 
         if sys.stdout.isatty():
           events = args.run_cmd.upper().split(' ', 1)[1]
diff --git a/stem/interpreter/autocomplete.py b/stem/interpreter/autocomplete.py
index 54642472..ed51fd3d 100644
--- a/stem/interpreter/autocomplete.py
+++ b/stem/interpreter/autocomplete.py
@@ -11,7 +11,7 @@ import stem.control
 import stem.util.conf
 
 from stem.interpreter import uses_settings
-from typing import List, Optional
+from typing import cast, List, Optional
 
 
 @uses_settings
@@ -28,7 +28,7 @@ def _get_commands(controller: stem.control.Controller, config: stem.util.conf.Co
   # GETINFO commands. Lines are of the form '[option] -- [description]'. This
   # strips '*' from options that accept values.
 
-  results = controller.get_info('info/names', None)
+  results = cast(str, controller.get_info('info/names', None))
 
   if results:
     for line in results.splitlines():
@@ -40,7 +40,7 @@ def _get_commands(controller: stem.control.Controller, config: stem.util.conf.Co
   # GETCONF, SETCONF, and RESETCONF commands. Lines are of the form
   # '[option] [type]'.
 
-  results = controller.get_info('config/names', None)
+  results = cast(str, controller.get_info('config/names', None))
 
   if results:
     for line in results.splitlines():
@@ -62,7 +62,7 @@ def _get_commands(controller: stem.control.Controller, config: stem.util.conf.Co
   )
 
   for prefix, getinfo_cmd in options:
-    results = controller.get_info(getinfo_cmd, None)
+    results = cast(str, controller.get_info(getinfo_cmd, None))
 
     if results:
       commands += [prefix + value for value in results.split()]
diff --git a/stem/interpreter/commands.py b/stem/interpreter/commands.py
index edbcca70..99f1219d 100644
--- a/stem/interpreter/commands.py
+++ b/stem/interpreter/commands.py
@@ -21,7 +21,7 @@ import stem.util.tor_tools
 
 from stem.interpreter import STANDARD_OUTPUT, BOLD_OUTPUT, ERROR_OUTPUT, uses_settings, msg
 from stem.util.term import format
-from typing import Iterator, List, TextIO
+from typing import cast, Iterator, List, TextIO
 
 MAX_EVENTS = 100
 
@@ -45,7 +45,7 @@ def _get_fingerprint(arg: str, controller: stem.control.Controller) -> str:
 
   if not arg:
     try:
-      return controller.get_info('fingerprint')
+      return cast(str, controller.get_info('fingerprint'))
     except:
       raise ValueError("We aren't a relay, no information to provide")
   elif stem.util.tor_tools.is_valid_fingerprint(arg):
@@ -132,14 +132,14 @@ class ControlInterpreter(code.InteractiveConsole):
 
     async def handle_event_wrapper(event_message: stem.response.ControlMessage) -> None:
       await handle_event_real(event_message)
-      self._received_events.insert(0, event_message)
+      self._received_events.insert(0, event_message)  # type: ignore
 
       if len(self._received_events) > MAX_EVENTS:
         self._received_events.pop()
 
     # type check disabled due to https://github.com/python/mypy/issues/708
 
-    self._controller._wrapped_instance._handle_event = handle_event_wrapper
+    self._controller._wrapped_instance._handle_event = handle_event_wrapper  # type: ignore
 
   def get_events(self, *event_types: stem.control.EventType) -> List[stem.response.events.Event]:
     events = list(self._received_events)
@@ -207,7 +207,7 @@ class ControlInterpreter(code.InteractiveConsole):
     extrainfo_desc_query = downloader.get_extrainfo_descriptors(fingerprint)
 
     for desc in server_desc_query:
-      server_desc = desc
+      server_desc = cast(stem.descriptor.server_descriptor.RelayDescriptor, desc)
 
     for desc in extrainfo_desc_query:
       extrainfo_desc = desc
@@ -220,7 +220,7 @@ class ControlInterpreter(code.InteractiveConsole):
       pass
 
     try:
-      address_extrainfo.append(self._controller.get_info('ip-to-country/%s' % ns_desc.address))
+      address_extrainfo.append(cast(str, self._controller.get_info('ip-to-country/%s' % ns_desc.address)))
     except:
       pass
 
diff --git a/stem/interpreter/help.py b/stem/interpreter/help.py
index 14b46e35..f2bbbafd 100644
--- a/stem/interpreter/help.py
+++ b/stem/interpreter/help.py
@@ -6,6 +6,7 @@ Provides our /help responses.
 """
 
 import functools
+from typing import cast
 
 import stem.control
 import stem.util.conf
@@ -74,7 +75,7 @@ def _response(controller: stem.control.Controller, arg: str, config: stem.util.c
   output += '\n'
 
   if arg == 'GETINFO':
-    results = controller.get_info('info/names', None)
+    results = cast(str, controller.get_info('info/names', None))
 
     if results:
       for line in results.splitlines():
@@ -84,7 +85,7 @@ def _response(controller: stem.control.Controller, arg: str, config: stem.util.c
           output += format('%-33s' % opt, *BOLD_OUTPUT)
           output += format(' - %s' % summary, *STANDARD_OUTPUT) + '\n'
   elif arg == 'GETCONF':
-    results = controller.get_info('config/names', None)
+    results = cast(str, controller.get_info('config/names', None))
 
     if results:
       options = [opt.split(' ', 1)[0] for opt in results.splitlines()]
@@ -103,7 +104,7 @@ def _response(controller: stem.control.Controller, arg: str, config: stem.util.c
       output += format('%-15s' % signal, *BOLD_OUTPUT)
       output += format(' - %s' % summary, *STANDARD_OUTPUT) + '\n'
   elif arg == 'SETEVENTS':
-    results = controller.get_info('events/names', None)
+    results = cast(str, controller.get_info('events/names', None))
 
     if results:
       entries = results.split()
@@ -118,7 +119,7 @@ def _response(controller: stem.control.Controller, arg: str, config: stem.util.c
 
         output += format(line.rstrip(), *STANDARD_OUTPUT) + '\n'
   elif arg == 'USEFEATURE':
-    results = controller.get_info('features/names', None)
+    results = cast(str, controller.get_info('features/names', None))
 
     if results:
       output += format(results, *STANDARD_OUTPUT) + '\n'
diff --git a/stem/response/__init__.py b/stem/response/__init__.py
index 52dc74e4..2e251144 100644
--- a/stem/response/__init__.py
+++ b/stem/response/__init__.py
@@ -202,7 +202,7 @@ class ControlMessage(object):
 
       content = re.sub(b'([\r]?)\n', b'\r\n', content)
 
-    msg = stem.socket.recv_message_from_bytes_io(io.BytesIO(content), arrived_at = kwargs.pop('arrived_at', None))
+    msg = stem.socket.recv_message_from_bytes_io(io.BytesIO(stem.util.str_tools._to_bytes(content)), arrived_at = kwargs.pop('arrived_at', None))
 
     if msg_type is not None:
       convert(msg_type, msg, **kwargs)
diff --git a/stem/socket.py b/stem/socket.py
index 0feae831..ff99c5b1 100644
--- a/stem/socket.py
+++ b/stem/socket.py
@@ -85,6 +85,7 @@ import asyncio
 import re
 import socket
 import ssl
+import sys
 import threading
 import time
 
@@ -93,7 +94,7 @@ import stem.util.str_tools
 
 from stem.util import log
 from types import TracebackType
-from typing import BinaryIO, Callable, List, Optional, Tuple, Type, Union, overload
+from typing import Awaitable, BinaryIO, Callable, List, Optional, Tuple, Type, Union, overload
 
 MESSAGE_PREFIX = re.compile(b'^[a-zA-Z0-9]{3}[-+ ]')
 ERROR_MSG = 'Error while receiving a control message (%s): %s'
@@ -109,8 +110,8 @@ class BaseSocket(object):
   """
 
   def __init__(self) -> None:
-    self._reader = None
-    self._writer = None
+    self._reader = None  # type: Optional[asyncio.StreamReader]
+    self._writer = None  # type: Optional[asyncio.StreamWriter]
     self._is_alive = False
     self._connection_time = 0.0  # time when we last connected or disconnected
 
@@ -209,7 +210,7 @@ class BaseSocket(object):
       if self._writer:
         self._writer.close()
         # `StreamWriter.wait_closed` was added in Python 3.7.
-        if hasattr(self._writer, 'wait_closed'):
+        if sys.version_info >= (3, 7):
           await self._writer.wait_closed()
 
       self._reader = None
@@ -220,7 +221,7 @@ class BaseSocket(object):
       if is_change:
         await self._close()
 
-  async def _send(self, message: Union[bytes, str], handler: Callable[[Union[socket.socket, ssl.SSLSocket], BinaryIO, Union[bytes, str]], None]) -> None:
+  async def _send(self, message: Union[bytes, str], handler: Callable[[asyncio.StreamWriter, Union[bytes, str]], Awaitable[None]]) -> None:
     """
     Send message in a thread safe manner.
     """
@@ -241,11 +242,11 @@ class BaseSocket(object):
         raise
 
   @overload
-  async def _recv(self, handler: Callable[[ssl.SSLSocket, BinaryIO], bytes]) -> bytes:
+  async def _recv(self, handler: Callable[[asyncio.StreamReader], Awaitable[bytes]]) -> bytes:
     ...
 
   @overload
-  async def _recv(self, handler: Callable[[socket.socket, BinaryIO], stem.response.ControlMessage]) -> stem.response.ControlMessage:
+  async def _recv(self, handler: Callable[[asyncio.StreamReader], Awaitable[stem.response.ControlMessage]]) -> stem.response.ControlMessage:
     ...
 
   async def _recv(self, handler):
@@ -303,7 +304,7 @@ class BaseSocket(object):
     return self
 
   async def __aexit__(self, exit_type: Optional[Type[BaseException]], value: Optional[BaseException], traceback: Optional[TracebackType]):
-    self.close()
+    await self.close()
 
   async def _connect(self) -> None:
     """
@@ -320,12 +321,6 @@ class BaseSocket(object):
     pass
 
   async def _open_connection(self) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]:
-    """
-    Constructs and connects new socket. This is implemented by subclasses.
-
-    :returns: **tuple** with our reader and writer streams
-    """
-
     raise NotImplementedError('Unsupported Operation: this should be implemented by the BaseSocket subclass')
 
 
@@ -380,7 +375,7 @@ class RelaySocket(BaseSocket):
       * :class:`stem.SocketClosed` if the socket closes before we receive a complete message
     """
 
-    async def wrapped_recv(s: ssl.SSLSocket, sf: BinaryIO) -> bytes:
+    async def wrapped_recv(reader: asyncio.StreamReader) -> Optional[bytes]:
       read_coroutine = reader.read(1024)
       if timeout is None:
         return await read_coroutine
@@ -524,7 +519,7 @@ async def send_message(writer: asyncio.StreamWriter, message: Union[bytes, str],
     <line 3>\\r\\n
     .\\r\\n
 
-  :param writer: stream derived from the control socket
+  :param writer: writer object
   :param message: message to be sent on the control socket
   :param raw: leaves the message formatting untouched, passing it to the
     socket as-is
@@ -591,9 +586,7 @@ async def recv_message(reader: asyncio.StreamReader, arrived_at: Optional[float]
 
   while True:
     try:
-      line = reader.readline()
-      if asyncio.iscoroutine(line):
-        line = await line
+      line = await reader.readline()
     except AttributeError:
       # if the control_file has been closed then we will receive:
       # AttributeError: 'NoneType' object has no attribute 'recv'
@@ -693,7 +686,7 @@ async def recv_message(reader: asyncio.StreamReader, arrived_at: Optional[float]
       raise stem.ProtocolError("Unrecognized divider type '%s': %s" % (divider, stem.util.str_tools._to_unicode(line)))
 
 
-def recv_message_from_bytes_io(reader: asyncio.StreamReader, arrived_at: Optional[float] = None) -> 'stem.response.ControlMessage':
+def recv_message_from_bytes_io(reader: BinaryIO, arrived_at: Optional[float] = None) -> stem.response.ControlMessage:
   """
   Pulls from an I/O stream until we either have a complete message or
   encounter a problem.
@@ -708,7 +701,9 @@ def recv_message_from_bytes_io(reader: asyncio.StreamReader, arrived_at: Optiona
       a complete message
   """
 
-  parsed_content, raw_content, first_line = None, None, True
+  parsed_content = []  # type: List[Tuple[str, str, bytes]]
+  raw_content = bytearray()
+  first_line = True
 
   while True:
     try:
@@ -739,10 +734,10 @@ def recv_message_from_bytes_io(reader: asyncio.StreamReader, arrived_at: Optiona
       log.info(ERROR_MSG % ('SocketClosed', 'empty socket content'))
       raise stem.SocketClosed('Received empty socket content.')
     elif not MESSAGE_PREFIX.match(line):
-      log.info(ERROR_MSG % ('ProtocolError', 'malformed status code/divider, "%s"' % log.escape(line)))
+      log.info(ERROR_MSG % ('ProtocolError', 'malformed status code/divider, "%s"' % log.escape(line.decode('utf-8'))))
       raise stem.ProtocolError('Badly formatted reply line: beginning is malformed')
     elif not line.endswith(b'\r\n'):
-      log.info(ERROR_MSG % ('ProtocolError', 'no CRLF linebreak, "%s"' % log.escape(line)))
+      log.info(ERROR_MSG % ('ProtocolError', 'no CRLF linebreak, "%s"' % log.escape(line.decode('utf-8'))))
       raise stem.ProtocolError('All lines should end with CRLF')
 
     status_code, divider, content = line[:3], line[3:4], line[4:-2]  # strip CRLF off content
@@ -781,11 +776,11 @@ def recv_message_from_bytes_io(reader: asyncio.StreamReader, arrived_at: Optiona
           line = reader.readline()
           raw_content += line
         except socket.error as exc:
-          log.info(ERROR_MSG % ('SocketClosed', 'received an exception while mid-way through a data reply (exception: "%s", read content: "%s")' % (exc, log.escape(bytes(raw_content)))))
+          log.info(ERROR_MSG % ('SocketClosed', 'received an exception while mid-way through a data reply (exception: "%s", read content: "%s")' % (exc, log.escape(bytes(raw_content).decode('utf-8')))))
           raise stem.SocketClosed(exc)
 
         if not line.endswith(b'\r\n'):
-          log.info(ERROR_MSG % ('ProtocolError', 'CRLF linebreaks missing from a data reply, "%s"' % log.escape(bytes(raw_content))))
+          log.info(ERROR_MSG % ('ProtocolError', 'CRLF linebreaks missing from a data reply, "%s"' % log.escape(bytes(raw_content).decode('utf-8'))))
           raise stem.ProtocolError('All lines should end with CRLF')
         elif line == b'.\r\n':
           break  # data block termination
diff --git a/test/integ/connection/authentication.py b/test/integ/connection/authentication.py
index 683e555f..3eaae8d9 100644
--- a/test/integ/connection/authentication.py
+++ b/test/integ/connection/authentication.py
@@ -3,6 +3,7 @@ Integration tests for authenticating to the control socket via
 stem.connection.authenticate* functions.
 """
 
+import asyncio
 import os
 import unittest
 
@@ -121,7 +122,10 @@ class TestAuthenticate(unittest.TestCase):
     runner = test.runner.get_runner()
 
     with await runner.get_tor_controller(False) as controller:
-      await stem.connection.authenticate(controller, test.runner.CONTROL_PASSWORD, runner.get_chroot())
+      asyncio.run_coroutine_threadsafe(
+        stem.connection.authenticate(controller._wrapped_instance, test.runner.CONTROL_PASSWORD, runner.get_chroot()),
+        controller._thread_for_wrapped_class.loop,
+      ).result()
       await test.runner.exercise_controller(self, controller)
 
   @test.require.controller
diff --git a/test/unit/control/controller.py b/test/unit/control/controller.py
index a11aba45..e4b11788 100644
--- a/test/unit/control/controller.py
+++ b/test/unit/control/controller.py
@@ -222,7 +222,7 @@ class TestControl(unittest.TestCase):
 
     get_info_mock.side_effect = coro_func_raising_exc(InvalidArguments)
 
-    async def get_conf_mock_side_effect(param, **kwargs):
+    async def get_conf_mock_side_effect(param, *args, **kwargs):
       return {
         'ControlPort': '9050',
         'ControlListenAddress': ['127.0.0.1'],
@@ -236,7 +236,7 @@ class TestControl(unittest.TestCase):
 
     # non-local addresss
 
-    async def get_conf_mock_side_effect(param, **kwargs):
+    async def get_conf_mock_side_effect(param, *args, **kwargs):
       return {
         'ControlPort': '9050',
         'ControlListenAddress': ['27.4.4.1'],
@@ -717,7 +717,7 @@ class TestControl(unittest.TestCase):
 
     # check default if nothing was set
 
-    async def get_conf_mock_side_effect(param, **kwargs):
+    async def get_conf_mock_side_effect(param, *args, **kwargs):
       return {
         'BandwidthRate': '1073741824',
         'BandwidthBurst': '1073741824',





More information about the tor-commits mailing list