[tor-commits] [stem/master] Make requesting for descriptor content asynchronous

atagar at torproject.org atagar at torproject.org
Thu Jul 16 01:28:59 UTC 2020


commit 841e2105147177f5959987c8bec1179dc94a59b3
Author: Illia Volochii <illia.volochii at gmail.com>
Date:   Thu May 14 00:08:59 2020 +0300

    Make requesting for descriptor content asynchronous
---
 stem/client/__init__.py         |  91 +++++++++++++-------------
 stem/descriptor/remote.py       | 138 ++++++++++++++++++++++++++++------------
 stem/util/test_tools.py         |   4 +-
 test/integ/client/connection.py |  38 +++++++----
 test/unit/descriptor/remote.py  |  24 +++++--
 5 files changed, 187 insertions(+), 108 deletions(-)

diff --git a/stem/client/__init__.py b/stem/client/__init__.py
index 639f118f..941f0ee7 100644
--- a/stem/client/__init__.py
+++ b/stem/client/__init__.py
@@ -26,7 +26,6 @@ a wrapper for :class:`~stem.socket.RelaySocket`, much the same way as
 """
 
 import hashlib
-import threading
 
 import stem
 import stem.client.cell
@@ -71,11 +70,10 @@ class Relay(object):
     self.link_protocol = LinkProtocol(link_protocol)
     self._orport = orport
     self._orport_buffer = b''  # unread bytes
-    self._orport_lock = threading.RLock()
-    self._circuits = {}  # type: Dict[int, stem.client.Circuit]
+    self._circuits = {}
 
   @staticmethod
-  def connect(address: str, port: int, link_protocols: Sequence['stem.client.datatype.LinkProtocol'] = DEFAULT_LINK_PROTOCOLS) -> 'stem.client.Relay':  # type: ignore
+  async def connect(address: str, port: int, link_protocols: Sequence['stem.client.datatype.LinkProtocol'] = DEFAULT_LINK_PROTOCOLS) -> 'stem.client.Relay':  # type: ignore
     """
     Establishes a connection with the given ORPort.
 
@@ -97,8 +95,9 @@ class Relay(object):
 
     try:
       conn = stem.socket.RelaySocket(address, port)
+      await conn.connect()
     except stem.SocketError as exc:
-      if 'Connection refused' in str(exc):
+      if 'Connect call failed' in str(exc):
         raise stem.SocketError("Failed to connect to %s:%i. Maybe it isn't an ORPort?" % (address, port))
 
       # If not an ORPort (for instance, mistakenly connecting to a ControlPort
@@ -122,21 +121,21 @@ class Relay(object):
     #   first VERSIONS cell, always have CIRCID_LEN == 2 for backward
     #   compatibility.
 
-    conn.send(stem.client.cell.VersionsCell(link_protocols).pack(2))  # type: ignore
-    response = conn.recv()
+    await conn.send(stem.client.cell.VersionsCell(link_protocols).pack(2))  # type: ignore
+    response = await conn.recv()
 
     # Link negotiation ends right away if we lack a common protocol
     # version. (#25139)
 
     if not response:
-      conn.close()
+      await conn.close()
       raise stem.SocketError('Unable to establish a common link protocol with %s:%i' % (address, port))
 
     versions_reply = stem.client.cell.Cell.pop(response, 2)[0]  # type: stem.client.cell.VersionsCell # type: ignore
     common_protocols = set(link_protocols).intersection(versions_reply.versions)
 
     if not common_protocols:
-      conn.close()
+      await conn.close()
       raise stem.SocketError('Unable to find a common link protocol. We support %s but %s:%i supports %s.' % (', '.join(map(str, link_protocols)), address, port, ', '.join(map(str, versions_reply.versions))))
 
     # Establishing connections requires sending a NETINFO, but including our
@@ -144,14 +143,14 @@ class Relay(object):
     # where it would help.
 
     link_protocol = max(common_protocols)
-    conn.send(stem.client.cell.NetinfoCell(relay_addr, []).pack(link_protocol))
+    await conn.send(stem.client.cell.NetinfoCell(relay_addr, []).pack(link_protocol))
 
     return Relay(conn, link_protocol)
 
-  def _recv_bytes(self) -> bytes:
-    return self._recv(True)  # type: ignore
+  async def _recv_bytes(self) -> bytes:
+    return await self._recv(True)  # type: ignore
 
-  def _recv(self, raw: bool = False) -> 'stem.client.cell.Cell':
+  async def _recv(self, raw: bool = False) -> 'stem.client.cell.Cell':
     """
     Reads the next cell from our ORPort. If none is present this blocks
     until one is available.
@@ -161,13 +160,13 @@ class Relay(object):
     :returns: next :class:`~stem.client.cell.Cell`
     """
 
-    with self._orport_lock:
+    async with self._orport_lock:
       # cells begin with [circ_id][cell_type][...]
 
       circ_id_size = self.link_protocol.circ_id_size.size
 
       while len(self._orport_buffer) < (circ_id_size + CELL_TYPE_SIZE.size):
-        self._orport_buffer += self._orport.recv()  # read until we know the cell type
+        self._orport_buffer += await self._orport.recv()  # read until we know the cell type
 
       cell_type = Cell.by_value(CELL_TYPE_SIZE.pop(self._orport_buffer[circ_id_size:])[0])
 
@@ -177,13 +176,13 @@ class Relay(object):
         # variable length, our next field is the payload size
 
         while len(self._orport_buffer) < (circ_id_size + CELL_TYPE_SIZE.size + FIXED_PAYLOAD_LEN):
-          self._orport_buffer += self._orport.recv()  # read until we know the cell size
+          self._orport_buffer += await self._orport.recv()  # read until we know the cell size
 
         payload_len = PAYLOAD_LEN_SIZE.pop(self._orport_buffer[circ_id_size + CELL_TYPE_SIZE.size:])[0]
         cell_size = circ_id_size + CELL_TYPE_SIZE.size + payload_len
 
       while len(self._orport_buffer) < cell_size:
-        self._orport_buffer += self._orport.recv()  # read until we have the full cell
+        self._orport_buffer += await self._orport.recv()  # read until we have the full cell
 
       if raw:
         content, self._orport_buffer = split(self._orport_buffer, cell_size)
@@ -192,7 +191,7 @@ class Relay(object):
         cell, self._orport_buffer = Cell.pop(self._orport_buffer, self.link_protocol)
         return cell
 
-  def _msg(self, cell: 'stem.client.cell.Cell') -> Iterator['stem.client.cell.Cell']:
+  async def _msg(self, cell: 'stem.client.cell.Cell') -> Iterator['stem.client.cell.Cell']:
     """
     Sends a cell on the ORPort and provides the response we receive in reply.
 
@@ -219,9 +218,9 @@ class Relay(object):
 
     # TODO: why is this an iterator?
 
-    self._orport.recv(timeout = 0)  # discard unread data
-    self._orport.send(cell.pack(self.link_protocol))
-    response = self._orport.recv(timeout = 1)
+    await self._orport.recv(timeout = 0)  # discard unread data
+    await self._orport.send(cell.pack(self.link_protocol))
+    response = await self._orport.recv(timeout = 1)
     yield stem.client.cell.Cell.pop(response, self.link_protocol)[0]
 
   def is_alive(self) -> bool:
@@ -246,27 +245,27 @@ class Relay(object):
 
     return self._orport.connection_time()
 
-  def close(self) -> None:
+  async def close(self) -> None:
     """
     Closes our socket connection. This is a pass-through for our socket's
     :func:`~stem.socket.BaseSocket.close` method.
     """
 
-    with self._orport_lock:
-      return self._orport.close()
+    async with self._orport_lock:
+      return await self._orport.close()
 
-  def create_circuit(self) -> 'stem.client.Circuit':
+  async def create_circuit(self) -> 'stem.client.Circuit':
     """
     Establishes a new circuit.
     """
 
-    with self._orport_lock:
+    async with self._orport_lock:
       circ_id = max(self._circuits) + 1 if self._circuits else self.link_protocol.first_circ_id
 
       create_fast_cell = stem.client.cell.CreateFastCell(circ_id)
       created_fast_cell = None
 
-      for cell in self._msg(create_fast_cell):
+      async for cell in self._msg(create_fast_cell):
         if isinstance(cell, stem.client.cell.CreatedFastCell):
           created_fast_cell = cell
           break
@@ -284,16 +283,16 @@ class Relay(object):
 
       return circ
 
-  def __iter__(self) -> Iterator['stem.client.Circuit']:
-    with self._orport_lock:
+  async def __aiter__(self) -> Iterator['stem.client.Circuit']:
+    async with self._orport_lock:
       for circ in self._circuits.values():
         yield circ
 
-  def __enter__(self) -> 'stem.client.Relay':
+  async def __aenter__(self) -> 'stem.client.Relay':
     return self
 
-  def __exit__(self, exit_type: Optional[Type[BaseException]], value: Optional[BaseException], traceback: Optional[TracebackType]) -> None:
-    self.close()
+  async def __aexit__(self, exit_type: Optional[Type[BaseException]], value: Optional[BaseException], traceback: Optional[TracebackType]) -> None:
+    await self.close()
 
 
 class Circuit(object):
@@ -327,7 +326,7 @@ class Circuit(object):
     self.forward_key = Cipher(algorithms.AES(kdf.forward_key), ctr, default_backend()).encryptor()
     self.backward_key = Cipher(algorithms.AES(kdf.backward_key), ctr, default_backend()).decryptor()
 
-  def directory(self, request: str, stream_id: int = 0) -> bytes:
+  async def directory(self, request: str, stream_id: int = 0) -> bytes:
     """
     Request descriptors from the relay.
 
@@ -337,9 +336,9 @@ class Circuit(object):
     :returns: **str** with the requested descriptor data
     """
 
-    with self.relay._orport_lock:
-      self._send(RelayCommand.BEGIN_DIR, stream_id = stream_id)
-      self._send(RelayCommand.DATA, request, stream_id = stream_id)
+    async with self.relay._orport_lock:
+      await self._send(RelayCommand.BEGIN_DIR, stream_id = stream_id)
+      await self._send(RelayCommand.DATA, request, stream_id = stream_id)
 
       response = []  # type: List[stem.client.cell.RelayCell]
 
@@ -347,7 +346,7 @@ class Circuit(object):
         # Decrypt relay cells received in response. Our digest/key only
         # updates when handled successfully.
 
-        encrypted_cell = self.relay._recv_bytes()
+        encrypted_cell = await self.relay._recv_bytes()
 
         decrypted_cell, backward_key, backward_digest = stem.client.cell.RelayCell.decrypt(self.relay.link_protocol, encrypted_cell, self.backward_key, self.backward_digest)
 
@@ -362,7 +361,7 @@ class Circuit(object):
         else:
           response.append(decrypted_cell)
 
-  def _send(self, command: 'stem.client.datatype.RelayCommand', data: Union[bytes, str] = b'', stream_id: int = 0) -> None:
+  async def _send(self, command: 'stem.client.datatype.RelayCommand', data: Union[bytes, str] = b'', stream_id: int = 0) -> None:
     """
     Sends a message over the circuit.
 
@@ -371,24 +370,24 @@ class Circuit(object):
     :param stream_id: specific stream this concerns
     """
 
-    with self.relay._orport_lock:
+    async with self.relay._orport_lock:
       # Encrypt and send the cell. Our digest/key only updates if the cell is
       # successfully sent.
 
       cell = stem.client.cell.RelayCell(self.id, command, data, stream_id = stream_id)
       payload, forward_key, forward_digest = cell.encrypt(self.relay.link_protocol, self.forward_key, self.forward_digest)
-      self.relay._orport.send(payload)
+      await self.relay._orport.send(payload)
 
       self.forward_digest = forward_digest
       self.forward_key = forward_key
 
-  def close(self) -> None:
-    with self.relay._orport_lock:
-      self.relay._orport.send(stem.client.cell.DestroyCell(self.id).pack(self.relay.link_protocol))
+  async def close(self) -> None:
+    async with self.relay._orport_lock:
+      await self.relay._orport.send(stem.client.cell.DestroyCell(self.id).pack(self.relay.link_protocol))
       del self.relay._circuits[self.id]
 
-  def __enter__(self) -> 'stem.client.Circuit':
+  async def __aenter__(self) -> 'stem.client.Circuit':
     return self
 
-  def __exit__(self, exit_type: Optional[Type[BaseException]], value: Optional[BaseException], traceback: Optional[TracebackType]) -> None:
-    self.close()
+  async def __aexit__(self, exit_type: Optional[Type[BaseException]], value: Optional[BaseException], traceback: Optional[TracebackType]) -> None:
+    await self.close()
diff --git a/stem/descriptor/remote.py b/stem/descriptor/remote.py
index e90c4442..eca846ee 100644
--- a/stem/descriptor/remote.py
+++ b/stem/descriptor/remote.py
@@ -83,6 +83,8 @@ content. For example...
   hashes.
 """
 
+import asyncio
+import functools
 import io
 import random
 import socket
@@ -93,6 +95,7 @@ import urllib.request
 
 import stem
 import stem.client
+import stem.control
 import stem.descriptor
 import stem.descriptor.networkstatus
 import stem.directory
@@ -227,7 +230,7 @@ def get_detached_signatures(**query_args: Any) -> 'stem.descriptor.remote.Query'
   return get_instance().get_detached_signatures(**query_args)
 
 
-class Query(object):
+class AsyncQuery(object):
   """
   Asynchronous request for descriptor content from a directory authority or
   mirror. These can either be made through the
@@ -427,32 +430,27 @@ class Query(object):
     self.reply_headers = None  # type: Optional[Dict[str, str]]
     self.kwargs = kwargs
 
-    self._downloader_thread = None  # type: Optional[threading.Thread]
-    self._downloader_thread_lock = threading.RLock()
+    self._downloader_task = None
+    self._downloader_lock = threading.RLock()
+
+    self._asyncio_loop = asyncio.get_event_loop()
 
     if start:
       self.start()
 
     if block:
-      self.run(True)
+      self._asyncio_loop.create_task(self.run(True))
 
   def start(self) -> None:
     """
     Starts downloading the scriptors if we haven't started already.
     """
 
-    with self._downloader_thread_lock:
-      if self._downloader_thread is None:
-        self._downloader_thread = threading.Thread(
-          name = 'Descriptor query',
-          target = self._download_descriptors,
-          args = (self.retries, self.timeout)
-        )
-
-        self._downloader_thread.setDaemon(True)
-        self._downloader_thread.start()
+    with self._downloader_lock:
+      if self._downloader_task is None:
+        self._downloader_task = self._asyncio_loop.create_task(self._download_descriptors(self.retries, self.timeout))
 
-  def run(self, suppress: bool = False) -> List['stem.descriptor.Descriptor']:
+  async def run(self, suppress: bool = False) -> List['stem.descriptor.Descriptor']:
     """
     Blocks until our request is complete then provides the descriptors. If we
     haven't yet started our request then this does so.
@@ -470,12 +468,12 @@ class Query(object):
         * :class:`~stem.DownloadFailed` if our request fails
     """
 
-    return list(self._run(suppress))
+    return [desc async for desc in self._run(suppress)]
 
-  def _run(self, suppress: bool) -> Iterator[stem.descriptor.Descriptor]:
-    with self._downloader_thread_lock:
+  async def _run(self, suppress: bool) -> Iterator[stem.descriptor.Descriptor]:
+    with self._downloader_lock:
       self.start()
-      self._downloader_thread.join()
+      await self._downloader_task
 
       if self.error:
         if suppress:
@@ -508,8 +506,8 @@ class Query(object):
 
           raise self.error
 
-  def __iter__(self) -> Iterator[stem.descriptor.Descriptor]:
-    for desc in self._run(True):
+  async def __aiter__(self) -> Iterator[stem.descriptor.Descriptor]:
+    async for desc in self._run(True):
       yield desc
 
   def _pick_endpoint(self, use_authority: bool = False) -> stem.Endpoint:
@@ -530,18 +528,18 @@ class Query(object):
     else:
       return random.choice(self.endpoints)
 
-  def _download_descriptors(self, retries: int, timeout: Optional[float]) -> None:
+  async def _download_descriptors(self, retries: int, timeout: Optional[float]) -> None:
     try:
       self.start_time = time.time()
       endpoint = self._pick_endpoint(use_authority = retries == 0 and self.fall_back_to_authority)
 
       if isinstance(endpoint, stem.ORPort):
         downloaded_from = 'ORPort %s:%s (resource %s)' % (endpoint.address, endpoint.port, self.resource)
-        self.content, self.reply_headers = _download_from_orport(endpoint, self.compression, self.resource)
+        self.content, self.reply_headers = await _download_from_orport(endpoint, self.compression, self.resource)
       elif isinstance(endpoint, stem.DirPort):
         self.download_url = 'http://%s:%i/%s' % (endpoint.address, endpoint.port, self.resource.lstrip('/'))
         downloaded_from = self.download_url
-        self.content, self.reply_headers = _download_from_dirport(self.download_url, self.compression, timeout)
+        self.content, self.reply_headers = await _download_from_dirport(self.download_url, self.compression, timeout)
       else:
         raise ValueError("BUG: endpoints can only be ORPorts or DirPorts, '%s' was a %s" % (endpoint, type(endpoint).__name__))
 
@@ -555,7 +553,7 @@ class Query(object):
 
       if retries > 0 and (timeout is None or timeout > 0):
         log.debug("Unable to download descriptors from '%s' (%i retries remaining): %s" % (self.download_url, retries, exc))
-        return self._download_descriptors(retries - 1, timeout)
+        return await self._download_descriptors(retries - 1, timeout)
       else:
         log.debug("Unable to download descriptors from '%s': %s" % (self.download_url, exc))
         self.error = exc
@@ -563,6 +561,64 @@ class Query(object):
       self.is_done = True
 
 
+class Query(stem.util.AsyncClassWrapper):
+  def __init__(self, resource, descriptor_type = None, endpoints = None, compression = (Compression.GZIP,), retries = 2, fall_back_to_authority = False, timeout = None, start = True, block = False, validate = False, document_handler = stem.descriptor.DocumentHandler.ENTRIES, **kwargs):
+    self._thread_for_wrapped_class = stem.util.ThreadForWrappedAsyncClass()
+    self._thread_for_wrapped_class.start()
+    self._wrapped_instance = self._init_async_class(
+      AsyncQuery,
+      resource,
+      descriptor_type,
+      endpoints,
+      compression,
+      retries,
+      fall_back_to_authority,
+      timeout,
+      start,
+      block,
+      validate,
+      document_handler,
+      **kwargs,
+    )
+
+  def start(self):
+    return self._call_async_method_soon('start')
+
+  def run(self, suppress = False):
+    return self._execute_async_method('run', suppress)
+
+  def __iter__(self):
+    for desc in self._execute_async_generator_method('__aiter__'):
+      yield desc
+
+  # Add public attributes of `AsyncQuery` as properties.
+  for attr in (
+    'descriptor_type',
+    'endpoints',
+    'resource',
+    'compression',
+    'retries',
+    'fall_back_to_authority',
+    'content',
+    'error',
+    'is_done',
+    'download_url',
+    'start_time',
+    'timeout',
+    'runtime',
+    'validate',
+    'document_handler',
+    'reply_headers',
+    'kwargs',
+  ):
+    locals()[attr] = property(
+      functools.partial(
+        lambda self, attr_name: getattr(self._wrapped_instance, attr_name),
+        attr_name=attr,
+      ),
+    )
+
+
 class DescriptorDownloader(object):
   """
   Configurable class that issues :class:`~stem.descriptor.remote.Query`
@@ -925,7 +981,7 @@ class DescriptorDownloader(object):
     return Query(resource, **args)
 
 
-def _download_from_orport(endpoint: stem.ORPort, compression: Sequence[stem.descriptor._Compression], resource: str) -> Tuple[bytes, Dict[str, str]]:
+async def _download_from_orport(endpoint: stem.ORPort, compression: Sequence[stem.descriptor._Compression], resource: str) -> Tuple[bytes, Dict[str, str]]:
   """
   Downloads descriptors from the given orport. Payload is just like an http
   response (headers and all)...
@@ -956,15 +1012,15 @@ def _download_from_orport(endpoint: stem.ORPort, compression: Sequence[stem.desc
 
   link_protocols = endpoint.link_protocols if endpoint.link_protocols else [3]
 
-  with stem.client.Relay.connect(endpoint.address, endpoint.port, link_protocols) as relay:
-    with relay.create_circuit() as circ:
+  async with await stem.client.Relay.connect(endpoint.address, endpoint.port, link_protocols) as relay:
+    async with await relay.create_circuit() as circ:
       request = '\r\n'.join((
         'GET %s HTTP/1.0' % resource,
         'Accept-Encoding: %s' % ', '.join(map(lambda c: c.encoding, compression)),
         'User-Agent: %s' % stem.USER_AGENT,
       )) + '\r\n\r\n'
 
-      response = circ.directory(request, stream_id = 1)
+      response = await circ.directory(request, stream_id = 1)
       first_line, data = response.split(b'\r\n', 1)
       header_data, body_data = data.split(b'\r\n\r\n', 1)
 
@@ -983,7 +1039,7 @@ def _download_from_orport(endpoint: stem.ORPort, compression: Sequence[stem.desc
       return _decompress(body_data, headers.get('Content-Encoding')), headers
 
 
-def _download_from_dirport(url: str, compression: Sequence[stem.descriptor._Compression], timeout: Optional[float]) -> Tuple[bytes, Dict[str, str]]:
+async def _download_from_dirport(url: str, compression: Sequence[stem.descriptor._Compression], timeout: Optional[float]) -> Tuple[bytes, Dict[str, str]]:
   """
   Downloads descriptors from the given url.
 
@@ -998,17 +1054,19 @@ def _download_from_dirport(url: str, compression: Sequence[stem.descriptor._Comp
     * :class:`~stem.DownloadFailed` if our request fails
   """
 
+  # TODO: use an asynchronous solution for the HTTP request.
+  request = urllib.request.Request(
+    url,
+    headers = {
+      'Accept-Encoding': ', '.join(map(lambda c: c.encoding, compression)),
+      'User-Agent': stem.USER_AGENT,
+    }
+  )
+  get_response = functools.partial(urllib.request.urlopen, request, timeout = timeout)
+
+  loop = asyncio.get_event_loop()
   try:
-    response = urllib.request.urlopen(
-      urllib.request.Request(
-        url,
-        headers = {
-          'Accept-Encoding': ', '.join(map(lambda c: c.encoding, compression)),
-          'User-Agent': stem.USER_AGENT,
-        }
-      ),
-      timeout = timeout,
-    )
+    response = await loop.run_in_executor(None, get_response)
   except socket.timeout as exc:
     raise stem.DownloadTimeout(url, exc, sys.exc_info()[2], timeout)
   except:
diff --git a/stem/util/test_tools.py b/stem/util/test_tools.py
index f3c736a1..455f3da3 100644
--- a/stem/util/test_tools.py
+++ b/stem/util/test_tools.py
@@ -251,7 +251,7 @@ class TimedTestRunner(unittest.TextTestRunner):
           TEST_RUNTIMES[self.id()] = time.time() - start_time
           return result
 
-        def assertRaisesWith(self, exc_type: Type[Exception], exc_msg: str, func: Callable, *args: Any, **kwargs: Any) -> None:
+        def assertRaisesWith(self, exc_type: Type[Exception], exc_msg: str, *args: Any, **kwargs: Any) -> None:
           """
           Asserts the given invokation raises the expected excepiton. This is
           similar to unittest's assertRaises and assertRaisesRegexp, but checks
@@ -262,7 +262,7 @@ class TimedTestRunner(unittest.TextTestRunner):
           vended API then please let us know.
           """
 
-          return self.assertRaisesRegexp(exc_type, '^%s$' % re.escape(exc_msg), func, *args, **kwargs)
+          return self.assertRaisesRegexp(exc_type, '^%s$' % re.escape(exc_msg), *args, **kwargs)
 
         def id(self) -> str:
           return '%s.%s.%s' % (original_type.__module__, original_type.__name__, self._testMethodName)
diff --git a/test/integ/client/connection.py b/test/integ/client/connection.py
index 2294a07d..316d54ba 100644
--- a/test/integ/client/connection.py
+++ b/test/integ/client/connection.py
@@ -9,46 +9,57 @@ import stem
 import test.runner
 
 from stem.client import Relay
+from stem.util.test_tools import async_test
 
 
 class TestConnection(unittest.TestCase):
-  def test_invalid_arguments(self):
+  @async_test
+  async def test_invalid_arguments(self):
     """
     Provide invalid arguments to Relay.connect().
     """
 
-    self.assertRaisesWith(ValueError, "'nope' isn't an IPv4 or IPv6 address", Relay.connect, 'nope', 80)
-    self.assertRaisesWith(ValueError, "'-54' isn't a valid port", Relay.connect, '127.0.0.1', -54)
-    self.assertRaisesWith(ValueError, "Connection can't be established without a link protocol.", Relay.connect, '127.0.0.1', 54, [])
+    with self.assertRaisesWith(ValueError, "'nope' isn't an IPv4 or IPv6 address"):
+      await Relay.connect('nope', 80)
+    with self.assertRaisesWith(ValueError, "'-54' isn't a valid port"):
+      await Relay.connect('127.0.0.1', -54)
+    with self.assertRaisesWith(ValueError, "Connection can't be established without a link protocol."):
+      await Relay.connect('127.0.0.1', 54, [])
 
-  def test_not_orport(self):
+  @async_test
+  async def test_not_orport(self):
     """
     Attempt to connect to an ORPort that doesn't exist.
     """
 
-    self.assertRaisesWith(stem.SocketError, "Failed to connect to 127.0.0.1:1587. Maybe it isn't an ORPort?", Relay.connect, '127.0.0.1', 1587)
+    with self.assertRaisesWith(stem.SocketError, "Failed to connect to 127.0.0.1:1587. Maybe it isn't an ORPort?"):
+      await Relay.connect('127.0.0.1', 1587)
 
     # connect to our ControlPort like it's an ORPort
 
     if test.runner.Torrc.PORT in test.runner.get_runner().get_options():
-      self.assertRaisesWith(stem.SocketError, "Failed to SSL authenticate to 127.0.0.1:1111. Maybe it isn't an ORPort?", Relay.connect, '127.0.0.1', test.runner.CONTROL_PORT)
+      with self.assertRaisesWith(stem.SocketError, "Failed to SSL authenticate to 127.0.0.1:1111. Maybe it isn't an ORPort?"):
+        await Relay.connect('127.0.0.1', test.runner.CONTROL_PORT)
 
-  def test_no_common_link_protocol(self):
+  @async_test
+  async def test_no_common_link_protocol(self):
     """
     Connection without a commonly accepted link protocol version.
     """
 
     for link_protocol in (1, 2, 6, 20):
-      self.assertRaisesWith(stem.SocketError, 'Unable to establish a common link protocol with 127.0.0.1:1113', Relay.connect, '127.0.0.1', test.runner.ORPORT, [link_protocol])
+      with self.assertRaisesWith(stem.SocketError, 'Unable to establish a common link protocol with 127.0.0.1:1113'):
+        await Relay.connect('127.0.0.1', test.runner.ORPORT, [link_protocol])
 
-  def test_connection_time(self):
+  @async_test
+  async def test_connection_time(self):
     """
     Checks duration we've been connected.
     """
 
     before = time.time()
 
-    with Relay.connect('127.0.0.1', test.runner.ORPORT) as conn:
+    async with await Relay.connect('127.0.0.1', test.runner.ORPORT) as conn:
       connection_time = conn.connection_time()
       self.assertTrue(time.time() >= connection_time >= before)
       time.sleep(0.02)
@@ -57,10 +68,11 @@ class TestConnection(unittest.TestCase):
     self.assertFalse(conn.is_alive())
     self.assertTrue(conn.connection_time() >= connection_time + 0.02)
 
-  def test_established(self):
+  @async_test
+  async def test_established(self):
     """
     Successfully establish ORPort connection.
     """
 
-    conn = Relay.connect('127.0.0.1', test.runner.ORPORT)
+    conn = await Relay.connect('127.0.0.1', test.runner.ORPORT)
     self.assertTrue(int(conn.link_protocol) in (4, 5))
diff --git a/test/unit/descriptor/remote.py b/test/unit/descriptor/remote.py
index e57da92b..33ee57fb 100644
--- a/test/unit/descriptor/remote.py
+++ b/test/unit/descriptor/remote.py
@@ -13,9 +13,10 @@ import stem.descriptor.remote
 import stem.util.str_tools
 import test.require
 
-from unittest.mock import patch, Mock, MagicMock
+from unittest.mock import patch, Mock
 
 from stem.descriptor.remote import Compression
+from stem.util.test_tools import coro_func_returning_value
 from test.unit.descriptor import read_resource
 
 TEST_RESOURCE = '/tor/server/fp/9695DFC35FFEB861329B9F1AB04C46397020CE31'
@@ -78,11 +79,20 @@ def _orport_mock(data, encoding = 'identity', response_code_header = None):
     cell.data = hunk
     cells.append(cell)
 
-  connect_mock = MagicMock()
-  relay_mock = connect_mock().__enter__()
-  circ_mock = relay_mock.create_circuit().__enter__()
-  circ_mock.directory.return_value = data
-  return connect_mock
+  class AsyncMock(Mock):
+    async def __aenter__(self):
+      return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+      return
+
+  circ_mock = AsyncMock()
+  circ_mock.directory.side_effect = coro_func_returning_value(data)
+
+  relay_mock = AsyncMock()
+  relay_mock.create_circuit.side_effect = coro_func_returning_value(circ_mock)
+
+  return coro_func_returning_value(relay_mock)
 
 
 def _dirport_mock(data, encoding = 'identity'):
@@ -294,7 +304,7 @@ class TestDescriptorDownloader(unittest.TestCase):
       skip_crypto_validation = not test.require.CRYPTOGRAPHY_AVAILABLE,
     )
 
-    self.assertEqual(stem.DirPort('128.31.0.39', 9131), query._pick_endpoint())
+    self.assertEqual(stem.DirPort('128.31.0.39', 9131), query._wrapped_instance._pick_endpoint())
 
     descriptors = list(query)
     self.assertEqual(1, len(descriptors))





More information about the tor-commits mailing list