commit 841e2105147177f5959987c8bec1179dc94a59b3 Author: Illia Volochii illia.volochii@gmail.com Date: Thu May 14 00:08:59 2020 +0300
Make requests for descriptor content asynchronous --- stem/client/__init__.py | 91 +++++++++++++------------- stem/descriptor/remote.py | 138 ++++++++++++++++++++++++++++------------ stem/util/test_tools.py | 4 +- test/integ/client/connection.py | 38 +++++++---- test/unit/descriptor/remote.py | 24 +++++-- 5 files changed, 187 insertions(+), 108 deletions(-)
diff --git a/stem/client/__init__.py b/stem/client/__init__.py index 639f118f..941f0ee7 100644 --- a/stem/client/__init__.py +++ b/stem/client/__init__.py @@ -26,7 +26,6 @@ a wrapper for :class:`~stem.socket.RelaySocket`, much the same way as """
import hashlib -import threading
import stem import stem.client.cell @@ -71,11 +70,10 @@ class Relay(object): self.link_protocol = LinkProtocol(link_protocol) self._orport = orport self._orport_buffer = b'' # unread bytes - self._orport_lock = threading.RLock() - self._circuits = {} # type: Dict[int, stem.client.Circuit] + self._circuits = {} # type: Dict[int, stem.client.Circuit] # NOTE(review): self._orport_lock is no longer assigned, but _recv(), close(), create_circuit(), __aiter__() and the Circuit methods still do 'async with self._orport_lock'; its replacement must be asyncio-compatible AND reentrant, since Circuit.directory() holds it while awaiting _send() and _recv_bytes(), which acquire it again
@staticmethod - def connect(address: str, port: int, link_protocols: Sequence['stem.client.datatype.LinkProtocol'] = DEFAULT_LINK_PROTOCOLS) -> 'stem.client.Relay': # type: ignore + async def connect(address: str, port: int, link_protocols: Sequence['stem.client.datatype.LinkProtocol'] = DEFAULT_LINK_PROTOCOLS) -> 'stem.client.Relay': # type: ignore """ Establishes a connection with the given ORPort.
@@ -97,8 +95,9 @@ class Relay(object):
try: conn = stem.socket.RelaySocket(address, port) + await conn.connect() except stem.SocketError as exc: - if 'Connection refused' in str(exc): + if 'Connect call failed' in str(exc): raise stem.SocketError("Failed to connect to %s:%i. Maybe it isn't an ORPort?" % (address, port))
# If not an ORPort (for instance, mistakenly connecting to a ControlPort @@ -122,21 +121,21 @@ class Relay(object): # first VERSIONS cell, always have CIRCID_LEN == 2 for backward # compatibility.
- conn.send(stem.client.cell.VersionsCell(link_protocols).pack(2)) # type: ignore - response = conn.recv() + await conn.send(stem.client.cell.VersionsCell(link_protocols).pack(2)) # type: ignore + response = await conn.recv() # the reply is parsed below as the relay's VERSIONS cell
# Link negotiation ends right away if we lack a common protocol # version. (#25139)
if not response: - conn.close() + await conn.close() raise stem.SocketError('Unable to establish a common link protocol with %s:%i' % (address, port))
versions_reply = stem.client.cell.Cell.pop(response, 2)[0] # type: stem.client.cell.VersionsCell # type: ignore common_protocols = set(link_protocols).intersection(versions_reply.versions)
if not common_protocols: - conn.close() + await conn.close() raise stem.SocketError('Unable to find a common link protocol. We support %s but %s:%i supports %s.' % (', '.join(map(str, link_protocols)), address, port, ', '.join(map(str, versions_reply.versions))))
# Establishing connections requires sending a NETINFO, but including our @@ -144,14 +143,14 @@ class Relay(object): # where it would help.
link_protocol = max(common_protocols) - conn.send(stem.client.cell.NetinfoCell(relay_addr, []).pack(link_protocol)) + await conn.send(stem.client.cell.NetinfoCell(relay_addr, []).pack(link_protocol)) # no reply is awaited for NETINFO; the Relay is built straight away with the negotiated protocol
return Relay(conn, link_protocol)
- def _recv_bytes(self) -> bytes: - return self._recv(True) # type: ignore + async def _recv_bytes(self) -> bytes: + return await self._recv(True) # type: ignore # NOTE(review): _recv() acquires self._orport_lock, and Circuit.directory() already holds that lock when it calls this, so the lock must be reentrant across awaits
- def _recv(self, raw: bool = False) -> 'stem.client.cell.Cell': + async def _recv(self, raw: bool = False) -> 'stem.client.cell.Cell': """ Reads the next cell from our ORPort. If none is present this blocks until one is available. @@ -161,13 +160,13 @@ class Relay(object): :returns: next :class:`~stem.client.cell.Cell` """
- with self._orport_lock: + async with self._orport_lock: # cells begin with [circ_id][cell_type][...]
circ_id_size = self.link_protocol.circ_id_size.size
while len(self._orport_buffer) < (circ_id_size + CELL_TYPE_SIZE.size): - self._orport_buffer += self._orport.recv() # read until we know the cell type + self._orport_buffer += await self._orport.recv() # read until we know the cell type
cell_type = Cell.by_value(CELL_TYPE_SIZE.pop(self._orport_buffer[circ_id_size:])[0])
@@ -177,13 +176,13 @@ class Relay(object): # variable length, our next field is the payload size
while len(self._orport_buffer) < (circ_id_size + CELL_TYPE_SIZE.size + FIXED_PAYLOAD_LEN): - self._orport_buffer += self._orport.recv() # read until we know the cell size + self._orport_buffer += await self._orport.recv() # read until we know the cell size
payload_len = PAYLOAD_LEN_SIZE.pop(self._orport_buffer[circ_id_size + CELL_TYPE_SIZE.size:])[0] cell_size = circ_id_size + CELL_TYPE_SIZE.size + payload_len
while len(self._orport_buffer) < cell_size: - self._orport_buffer += self._orport.recv() # read until we have the full cell + self._orport_buffer += await self._orport.recv() # read until we have the full cell
if raw: content, self._orport_buffer = split(self._orport_buffer, cell_size) @@ -192,7 +191,7 @@ class Relay(object): cell, self._orport_buffer = Cell.pop(self._orport_buffer, self.link_protocol) return cell
- def _msg(self, cell: 'stem.client.cell.Cell') -> Iterator['stem.client.cell.Cell']: + async def _msg(self, cell: 'stem.client.cell.Cell') -> Iterator['stem.client.cell.Cell']: """ Sends a cell on the ORPort and provides the response we receive in reply.
@@ -219,9 +218,9 @@ class Relay(object):
# TODO: why is this an iterator?
- self._orport.recv(timeout = 0) # discard unread data - self._orport.send(cell.pack(self.link_protocol)) - response = self._orport.recv(timeout = 1) + await self._orport.recv(timeout = 0) # discard unread data + await self._orport.send(cell.pack(self.link_protocol)) + response = await self._orport.recv(timeout = 1) yield stem.client.cell.Cell.pop(response, self.link_protocol)[0]
def is_alive(self) -> bool: @@ -246,27 +245,27 @@ class Relay(object):
return self._orport.connection_time()
- def close(self) -> None: + async def close(self) -> None: """ Closes our socket connection. This is a pass-through for our socket's :func:`~stem.socket.BaseSocket.close` method. """
- with self._orport_lock: - return self._orport.close() + async with self._orport_lock: + return await self._orport.close()
- def create_circuit(self) -> 'stem.client.Circuit': + async def create_circuit(self) -> 'stem.client.Circuit': """ Establishes a new circuit. """
- with self._orport_lock: + async with self._orport_lock: circ_id = max(self._circuits) + 1 if self._circuits else self.link_protocol.first_circ_id
create_fast_cell = stem.client.cell.CreateFastCell(circ_id) created_fast_cell = None
- for cell in self._msg(create_fast_cell): + async for cell in self._msg(create_fast_cell): if isinstance(cell, stem.client.cell.CreatedFastCell): created_fast_cell = cell break @@ -284,16 +283,16 @@ class Relay(object):
return circ
- def __iter__(self) -> Iterator['stem.client.Circuit']: - with self._orport_lock: + async def __aiter__(self) -> Iterator['stem.client.Circuit']: # NOTE(review): like _msg(), this is now an async generator, so AsyncIterator rather than Iterator is the accurate return annotation + async with self._orport_lock: for circ in self._circuits.values(): yield circ
- def __enter__(self) -> 'stem.client.Relay': + async def __aenter__(self) -> 'stem.client.Relay': return self
- def __exit__(self, exit_type: Optional[Type[BaseException]], value: Optional[BaseException], traceback: Optional[TracebackType]) -> None: - self.close() + async def __aexit__(self, exit_type: Optional[Type[BaseException]], value: Optional[BaseException], traceback: Optional[TracebackType]) -> None: + await self.close()
class Circuit(object): @@ -327,7 +326,7 @@ class Circuit(object): self.forward_key = Cipher(algorithms.AES(kdf.forward_key), ctr, default_backend()).encryptor() self.backward_key = Cipher(algorithms.AES(kdf.backward_key), ctr, default_backend()).decryptor()
- def directory(self, request: str, stream_id: int = 0) -> bytes: + async def directory(self, request: str, stream_id: int = 0) -> bytes: """ Request descriptors from the relay.
@@ -337,9 +336,9 @@ class Circuit(object): :returns: **str** with the requested descriptor data """
- with self.relay._orport_lock: - self._send(RelayCommand.BEGIN_DIR, stream_id = stream_id) - self._send(RelayCommand.DATA, request, stream_id = stream_id) + async with self.relay._orport_lock: + await self._send(RelayCommand.BEGIN_DIR, stream_id = stream_id) + await self._send(RelayCommand.DATA, request, stream_id = stream_id) # NOTE(review): _send() acquires self.relay._orport_lock again while this method holds it, which deadlocks with a non-reentrant lock such as asyncio.Lock
response = [] # type: List[stem.client.cell.RelayCell]
@@ -347,7 +346,7 @@ class Circuit(object): # Decrypt relay cells received in response. Our digest/key only # updates when handled successfully.
- encrypted_cell = self.relay._recv_bytes() + encrypted_cell = await self.relay._recv_bytes() # _recv_bytes() -> _recv() also takes self.relay._orport_lock, already held above
decrypted_cell, backward_key, backward_digest = stem.client.cell.RelayCell.decrypt(self.relay.link_protocol, encrypted_cell, self.backward_key, self.backward_digest)
@@ -362,7 +361,7 @@ class Circuit(object): else: response.append(decrypted_cell)
- def _send(self, command: 'stem.client.datatype.RelayCommand', data: Union[bytes, str] = b'', stream_id: int = 0) -> None: + async def _send(self, command: 'stem.client.datatype.RelayCommand', data: Union[bytes, str] = b'', stream_id: int = 0) -> None: """ Sends a message over the circuit.
@@ -371,24 +370,24 @@ class Circuit(object): :param stream_id: specific stream this concerns """
- with self.relay._orport_lock: + async with self.relay._orport_lock: # Encrypt and send the cell. Our digest/key only updates if the cell is # successfully sent.
cell = stem.client.cell.RelayCell(self.id, command, data, stream_id = stream_id) payload, forward_key, forward_digest = cell.encrypt(self.relay.link_protocol, self.forward_key, self.forward_digest) - self.relay._orport.send(payload) + await self.relay._orport.send(payload)
self.forward_digest = forward_digest self.forward_key = forward_key
- def close(self) -> None: - with self.relay._orport_lock: - self.relay._orport.send(stem.client.cell.DestroyCell(self.id).pack(self.relay.link_protocol)) + async def close(self) -> None: # sends DESTROY for this circuit, then drops it from the relay's circuit map + async with self.relay._orport_lock: + await self.relay._orport.send(stem.client.cell.DestroyCell(self.id).pack(self.relay.link_protocol)) del self.relay._circuits[self.id]
- def __enter__(self) -> 'stem.client.Circuit': + async def __aenter__(self) -> 'stem.client.Circuit': return self
- def __exit__(self, exit_type: Optional[Type[BaseException]], value: Optional[BaseException], traceback: Optional[TracebackType]) -> None: - self.close() + async def __aexit__(self, exit_type: Optional[Type[BaseException]], value: Optional[BaseException], traceback: Optional[TracebackType]) -> None: + await self.close() diff --git a/stem/descriptor/remote.py b/stem/descriptor/remote.py index e90c4442..eca846ee 100644 --- a/stem/descriptor/remote.py +++ b/stem/descriptor/remote.py @@ -83,6 +83,8 @@ content. For example... hashes. """
+import asyncio +import functools import io import random import socket @@ -93,6 +95,7 @@ import urllib.request
import stem import stem.client +import stem.control import stem.descriptor import stem.descriptor.networkstatus import stem.directory @@ -227,7 +230,7 @@ def get_detached_signatures(**query_args: Any) -> 'stem.descriptor.remote.Query' return get_instance().get_detached_signatures(**query_args)
-class Query(object): +class AsyncQuery(object): """ Asynchronous request for descriptor content from a directory authority or mirror. These can either be made through the @@ -427,32 +430,27 @@ class Query(object): self.reply_headers = None # type: Optional[Dict[str, str]] self.kwargs = kwargs
- self._downloader_thread = None # type: Optional[threading.Thread] - self._downloader_thread_lock = threading.RLock() + self._downloader_task = None # type: Optional[asyncio.Task] + self._downloader_lock = threading.RLock() # NOTE(review): a thread lock that _run() holds across 'await self._downloader_task'; RLock ownership is per thread, so other coroutines on this loop's thread re-enter it rather than being serialized + + self._asyncio_loop = asyncio.get_event_loop() # loop that start() schedules the download task on
if start: self.start()
if block: - self.run(True) + self._asyncio_loop.create_task(self.run(True)) # NOTE(review): this only schedules run() and discards the task reference, so block=True no longer guarantees the download has finished when construction returns - TODO confirm how callers should get blocking behavior
def start(self) -> None: """ Starts downloading the scriptors if we haven't started already. """
- with self._downloader_thread_lock: - if self._downloader_thread is None: - self._downloader_thread = threading.Thread( - name = 'Descriptor query', - target = self._download_descriptors, - args = (self.retries, self.timeout) - ) - - self._downloader_thread.setDaemon(True) - self._downloader_thread.start() + with self._downloader_lock: + if self._downloader_task is None: + self._downloader_task = self._asyncio_loop.create_task(self._download_descriptors(self.retries, self.timeout))
- def run(self, suppress: bool = False) -> List['stem.descriptor.Descriptor']: + async def run(self, suppress: bool = False) -> List['stem.descriptor.Descriptor']: """ Blocks until our request is complete then provides the descriptors. If we haven't yet started our request then this does so. @@ -470,12 +468,12 @@ class Query(object): * :class:`~stem.DownloadFailed` if our request fails """
- return list(self._run(suppress)) + return [desc async for desc in self._run(suppress)]
- def _run(self, suppress: bool) -> Iterator[stem.descriptor.Descriptor]: - with self._downloader_thread_lock: + async def _run(self, suppress: bool) -> Iterator[stem.descriptor.Descriptor]: - with self._downloader_lock: self.start() - self._downloader_thread.join() + await self._downloader_task
if self.error: if suppress: @@ -508,8 +506,8 @@ class Query(object):
raise self.error
- def __iter__(self) -> Iterator[stem.descriptor.Descriptor]: - for desc in self._run(True): + async def __aiter__(self) -> Iterator[stem.descriptor.Descriptor]: # NOTE(review): this and _run() are async generators now, so AsyncIterator is the accurate return annotation + async for desc in self._run(True): yield desc
def _pick_endpoint(self, use_authority: bool = False) -> stem.Endpoint: @@ -530,18 +528,18 @@ class Query(object): else: return random.choice(self.endpoints)
- def _download_descriptors(self, retries: int, timeout: Optional[float]) -> None: + async def _download_descriptors(self, retries: int, timeout: Optional[float]) -> None: try: self.start_time = time.time() endpoint = self._pick_endpoint(use_authority = retries == 0 and self.fall_back_to_authority)
if isinstance(endpoint, stem.ORPort): downloaded_from = 'ORPort %s:%s (resource %s)' % (endpoint.address, endpoint.port, self.resource) - self.content, self.reply_headers = _download_from_orport(endpoint, self.compression, self.resource) + self.content, self.reply_headers = await _download_from_orport(endpoint, self.compression, self.resource) elif isinstance(endpoint, stem.DirPort): self.download_url = 'http://%s:%i/%s' % (endpoint.address, endpoint.port, self.resource.lstrip('/')) downloaded_from = self.download_url - self.content, self.reply_headers = _download_from_dirport(self.download_url, self.compression, timeout) + self.content, self.reply_headers = await _download_from_dirport(self.download_url, self.compression, timeout) else: raise ValueError("BUG: endpoints can only be ORPorts or DirPorts, '%s' was a %s" % (endpoint, type(endpoint).__name__))
@@ -555,7 +553,7 @@ class Query(object):
if retries > 0 and (timeout is None or timeout > 0): log.debug("Unable to download descriptors from '%s' (%i retries remaining): %s" % (self.download_url, retries, exc)) - return self._download_descriptors(retries - 1, timeout) + return await self._download_descriptors(retries - 1, timeout) else: log.debug("Unable to download descriptors from '%s': %s" % (self.download_url, exc)) self.error = exc @@ -563,6 +561,64 @@ class Query(object): self.is_done = True
+class Query(stem.util.AsyncClassWrapper): # Synchronous wrapper around AsyncQuery; its methods forward to the AsyncQuery built by _init_async_class(), presumably on the event loop of the helper thread started in __init__ (see stem.util) + def __init__(self, resource, descriptor_type = None, endpoints = None, compression = (Compression.GZIP,), retries = 2, fall_back_to_authority = False, timeout = None, start = True, block = False, validate = False, document_handler = stem.descriptor.DocumentHandler.ENTRIES, **kwargs): + self._thread_for_wrapped_class = stem.util.ThreadForWrappedAsyncClass() + self._thread_for_wrapped_class.start() + self._wrapped_instance = self._init_async_class( + AsyncQuery, + resource, + descriptor_type, + endpoints, + compression, + retries, + fall_back_to_authority, + timeout, + start, + block, # NOTE(review): AsyncQuery only schedules run() as a task when block is set, so unless _init_async_class() waits for it, Query(block = True) returns before the download finishes; calling self.run(True) here when block is set would restore the old blocking behavior - TODO confirm + validate, + document_handler, + **kwargs, + ) + + def start(self): + return self._call_async_method_soon('start') + + def run(self, suppress = False): + return self._execute_async_method('run', suppress) + + def __iter__(self): + for desc in self._execute_async_generator_method('__aiter__'): + yield desc + + # Add public attributes of `AsyncQuery` as properties. + for attr in ( + 'descriptor_type', + 'endpoints', + 'resource', + 'compression', + 'retries', + 'fall_back_to_authority', + 'content', + 'error', + 'is_done', + 'download_url', + 'start_time', + 'timeout', + 'runtime', + 'validate', + 'document_handler', + 'reply_headers', + 'kwargs', + ): + locals()[attr] = property( # NOTE(review): writing through locals() relies on CPython exposing the class-body namespace dict, and the loop leaves 'attr' behind as a class attribute (Query.attr) + functools.partial( + lambda self, attr_name: getattr(self._wrapped_instance, attr_name), + attr_name=attr, + ), + ) + + class DescriptorDownloader(object): """ Configurable class that issues :class:`~stem.descriptor.remote.Query` @@ -925,7 +981,7 @@ class DescriptorDownloader(object): return Query(resource, **args)
-def _download_from_orport(endpoint: stem.ORPort, compression: Sequence[stem.descriptor._Compression], resource: str) -> Tuple[bytes, Dict[str, str]]: +async def _download_from_orport(endpoint: stem.ORPort, compression: Sequence[stem.descriptor._Compression], resource: str) -> Tuple[bytes, Dict[str, str]]: """ Downloads descriptors from the given orport. Payload is just like an http response (headers and all)... @@ -956,15 +1012,15 @@ def _download_from_orport(endpoint: stem.ORPort, compression: Sequence[stem.desc
link_protocols = endpoint.link_protocols if endpoint.link_protocols else [3]
- with stem.client.Relay.connect(endpoint.address, endpoint.port, link_protocols) as relay: - with relay.create_circuit() as circ: + async with await stem.client.Relay.connect(endpoint.address, endpoint.port, link_protocols) as relay: + async with await relay.create_circuit() as circ: request = '\r\n'.join(( 'GET %s HTTP/1.0' % resource, 'Accept-Encoding: %s' % ', '.join(map(lambda c: c.encoding, compression)), 'User-Agent: %s' % stem.USER_AGENT, )) + '\r\n\r\n'
- response = circ.directory(request, stream_id = 1) + response = await circ.directory(request, stream_id = 1) first_line, data = response.split(b'\r\n', 1) header_data, body_data = data.split(b'\r\n\r\n', 1)
@@ -983,7 +1039,7 @@ def _download_from_orport(endpoint: stem.ORPort, compression: Sequence[stem.desc return _decompress(body_data, headers.get('Content-Encoding')), headers
-def _download_from_dirport(url: str, compression: Sequence[stem.descriptor._Compression], timeout: Optional[float]) -> Tuple[bytes, Dict[str, str]]: +async def _download_from_dirport(url: str, compression: Sequence[stem.descriptor._Compression], timeout: Optional[float]) -> Tuple[bytes, Dict[str, str]]: """ Downloads descriptors from the given url.
@@ -998,17 +1054,19 @@ def _download_from_dirport(url: str, compression: Sequence[stem.descriptor._Comp * :class:`~stem.DownloadFailed` if our request fails """
+ # TODO: use an asynchronous solution for the HTTP request. + request = urllib.request.Request( + url, + headers = { + 'Accept-Encoding': ', '.join(map(lambda c: c.encoding, compression)), + 'User-Agent': stem.USER_AGENT, + } + ) + get_response = functools.partial(urllib.request.urlopen, request, timeout = timeout) # blocking urlopen() call, run on the loop's default executor below so it does not stall the event loop + + loop = asyncio.get_event_loop() try: - response = urllib.request.urlopen( - urllib.request.Request( - url, - headers = { - 'Accept-Encoding': ', '.join(map(lambda c: c.encoding, compression)), - 'User-Agent': stem.USER_AGENT, - } - ), - timeout = timeout, - ) + response = await loop.run_in_executor(None, get_response) except socket.timeout as exc: raise stem.DownloadTimeout(url, exc, sys.exc_info()[2], timeout) except: diff --git a/stem/util/test_tools.py b/stem/util/test_tools.py index f3c736a1..455f3da3 100644 --- a/stem/util/test_tools.py +++ b/stem/util/test_tools.py @@ -251,7 +251,7 @@ class TimedTestRunner(unittest.TextTestRunner): TEST_RUNTIMES[self.id()] = time.time() - start_time return result
- def assertRaisesWith(self, exc_type: Type[Exception], exc_msg: str, func: Callable, *args: Any, **kwargs: Any) -> None: + def assertRaisesWith(self, exc_type: Type[Exception], exc_msg: str, *args: Any, **kwargs: Any) -> Any: """ Asserts the given invokation raises the expected excepiton. This is similar to unittest's assertRaises and assertRaisesRegexp, but checks @@ -262,7 +262,7 @@ class TimedTestRunner(unittest.TextTestRunner): vended API then please let us know. """
- return self.assertRaisesRegexp(exc_type, '^%s$' % re.escape(exc_msg), func, *args, **kwargs) + return self.assertRaisesRegex(exc_type, '^%s$' % re.escape(exc_msg), *args, **kwargs)
def id(self) -> str: return '%s.%s.%s' % (original_type.__module__, original_type.__name__, self._testMethodName) diff --git a/test/integ/client/connection.py b/test/integ/client/connection.py index 2294a07d..316d54ba 100644 --- a/test/integ/client/connection.py +++ b/test/integ/client/connection.py @@ -9,46 +9,57 @@ import stem import test.runner
from stem.client import Relay +from stem.util.test_tools import async_test # assumed to drive each 'async def' test below to completion on an event loop; confirm in stem.util.test_tools
class TestConnection(unittest.TestCase): - def test_invalid_arguments(self): + @async_test + async def test_invalid_arguments(self): """ Provide invalid arguments to Relay.connect(). """
- self.assertRaisesWith(ValueError, "'nope' isn't an IPv4 or IPv6 address", Relay.connect, 'nope', 80) - self.assertRaisesWith(ValueError, "'-54' isn't a valid port", Relay.connect, '127.0.0.1', -54) - self.assertRaisesWith(ValueError, "Connection can't be established without a link protocol.", Relay.connect, '127.0.0.1', 54, []) + with self.assertRaisesWith(ValueError, "'nope' isn't an IPv4 or IPv6 address"): + await Relay.connect('nope', 80) + with self.assertRaisesWith(ValueError, "'-54' isn't a valid port"): + await Relay.connect('127.0.0.1', -54) + with self.assertRaisesWith(ValueError, "Connection can't be established without a link protocol."): + await Relay.connect('127.0.0.1', 54, []) # called without a callable, assertRaisesWith() returns unittest's regex assertion context manager, so the awaited call is checked inside the 'with'
- def test_not_orport(self): + @async_test + async def test_not_orport(self): """ Attempt to connect to an ORPort that doesn't exist. """
- self.assertRaisesWith(stem.SocketError, "Failed to connect to 127.0.0.1:1587. Maybe it isn't an ORPort?", Relay.connect, '127.0.0.1', 1587) + with self.assertRaisesWith(stem.SocketError, "Failed to connect to 127.0.0.1:1587. Maybe it isn't an ORPort?"): + await Relay.connect('127.0.0.1', 1587)
# connect to our ControlPort like it's an ORPort
if test.runner.Torrc.PORT in test.runner.get_runner().get_options(): - self.assertRaisesWith(stem.SocketError, "Failed to SSL authenticate to 127.0.0.1:1111. Maybe it isn't an ORPort?", Relay.connect, '127.0.0.1', test.runner.CONTROL_PORT) + with self.assertRaisesWith(stem.SocketError, "Failed to SSL authenticate to 127.0.0.1:1111. Maybe it isn't an ORPort?"): + await Relay.connect('127.0.0.1', test.runner.CONTROL_PORT)
- def test_no_common_link_protocol(self): + @async_test + async def test_no_common_link_protocol(self): """ Connection without a commonly accepted link protocol version. """
for link_protocol in (1, 2, 6, 20): - self.assertRaisesWith(stem.SocketError, 'Unable to establish a common link protocol with 127.0.0.1:1113', Relay.connect, '127.0.0.1', test.runner.ORPORT, [link_protocol]) + with self.assertRaisesWith(stem.SocketError, 'Unable to establish a common link protocol with 127.0.0.1:1113'): + await Relay.connect('127.0.0.1', test.runner.ORPORT, [link_protocol])
- def test_connection_time(self): + @async_test + async def test_connection_time(self): """ Checks duration we've been connected. """
before = time.time()
- with Relay.connect('127.0.0.1', test.runner.ORPORT) as conn: + async with await Relay.connect('127.0.0.1', test.runner.ORPORT) as conn: connection_time = conn.connection_time() self.assertTrue(time.time() >= connection_time >= before) time.sleep(0.02) @@ -57,10 +68,11 @@ class TestConnection(unittest.TestCase): self.assertFalse(conn.is_alive()) self.assertTrue(conn.connection_time() >= connection_time + 0.02)
- def test_established(self): + @async_test + async def test_established(self): """ Successfully establish ORPort connection. """
- conn = Relay.connect('127.0.0.1', test.runner.ORPORT) + conn = await Relay.connect('127.0.0.1', test.runner.ORPORT) self.assertTrue(int(conn.link_protocol) in (4, 5)) diff --git a/test/unit/descriptor/remote.py b/test/unit/descriptor/remote.py index e57da92b..33ee57fb 100644 --- a/test/unit/descriptor/remote.py +++ b/test/unit/descriptor/remote.py @@ -13,9 +13,10 @@ import stem.descriptor.remote import stem.util.str_tools import test.require
-from unittest.mock import patch, Mock, MagicMock +from unittest.mock import patch, Mock
from stem.descriptor.remote import Compression +from stem.util.test_tools import coro_func_returning_value from test.unit.descriptor import read_resource
TEST_RESOURCE = '/tor/server/fp/9695DFC35FFEB861329B9F1AB04C46397020CE31' @@ -78,11 +79,20 @@ def _orport_mock(data, encoding = 'identity', response_code_header = None): cell.data = hunk cells.append(cell)
- connect_mock = MagicMock() - relay_mock = connect_mock().__enter__() - circ_mock = relay_mock.create_circuit().__enter__() - circ_mock.directory.return_value = data - return connect_mock + class AsyncMock(Mock): # Mock usable as an 'async with' target that yields itself; a local class, unrelated to unittest.mock.AsyncMock despite the shared name + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + return + + circ_mock = AsyncMock() + circ_mock.directory.side_effect = coro_func_returning_value(data) + + relay_mock = AsyncMock() + relay_mock.create_circuit.side_effect = coro_func_returning_value(circ_mock) + + return coro_func_returning_value(relay_mock) # presumably patched in place of the coroutine function Relay.connect - confirm at call sites
def _dirport_mock(data, encoding = 'identity'): @@ -294,7 +304,7 @@ class TestDescriptorDownloader(unittest.TestCase): skip_crypto_validation = not test.require.CRYPTOGRAPHY_AVAILABLE, )
- self.assertEqual(stem.DirPort('128.31.0.39', 9131), query._pick_endpoint()) + self.assertEqual(stem.DirPort('128.31.0.39', 9131), query._wrapped_instance._pick_endpoint())
descriptors = list(query) self.assertEqual(1, len(descriptors))