commit 841e2105147177f5959987c8bec1179dc94a59b3
Author: Illia Volochii <illia.volochii(a)gmail.com>
Date: Thu May 14 00:08:59 2020 +0300
Make requesting for descriptor content asynchronous
---
stem/client/__init__.py | 91 +++++++++++++-------------
stem/descriptor/remote.py | 138 ++++++++++++++++++++++++++++------------
stem/util/test_tools.py | 4 +-
test/integ/client/connection.py | 38 +++++++----
test/unit/descriptor/remote.py | 24 +++++--
5 files changed, 187 insertions(+), 108 deletions(-)
diff --git a/stem/client/__init__.py b/stem/client/__init__.py
index 639f118f..941f0ee7 100644
--- a/stem/client/__init__.py
+++ b/stem/client/__init__.py
@@ -26,7 +26,6 @@ a wrapper for :class:`~stem.socket.RelaySocket`, much the same way as
"""
import hashlib
-import threading
import stem
import stem.client.cell
@@ -71,11 +70,10 @@ class Relay(object):
self.link_protocol = LinkProtocol(link_protocol)
self._orport = orport
self._orport_buffer = b'' # unread bytes
- self._orport_lock = threading.RLock()
- self._circuits = {} # type: Dict[int, stem.client.Circuit]
+ self._circuits = {}
@staticmethod
- def connect(address: str, port: int, link_protocols: Sequence['stem.client.datatype.LinkProtocol'] = DEFAULT_LINK_PROTOCOLS) -> 'stem.client.Relay': # type: ignore
+ async def connect(address: str, port: int, link_protocols: Sequence['stem.client.datatype.LinkProtocol'] = DEFAULT_LINK_PROTOCOLS) -> 'stem.client.Relay': # type: ignore
"""
Establishes a connection with the given ORPort.
@@ -97,8 +95,9 @@ class Relay(object):
try:
conn = stem.socket.RelaySocket(address, port)
+ await conn.connect()
except stem.SocketError as exc:
- if 'Connection refused' in str(exc):
+ if 'Connect call failed' in str(exc):
raise stem.SocketError("Failed to connect to %s:%i. Maybe it isn't an ORPort?" % (address, port))
# If not an ORPort (for instance, mistakenly connecting to a ControlPort
@@ -122,21 +121,21 @@ class Relay(object):
# first VERSIONS cell, always have CIRCID_LEN == 2 for backward
# compatibility.
- conn.send(stem.client.cell.VersionsCell(link_protocols).pack(2)) # type: ignore
- response = conn.recv()
+ await conn.send(stem.client.cell.VersionsCell(link_protocols).pack(2)) # type: ignore
+ response = await conn.recv()
# Link negotiation ends right away if we lack a common protocol
# version. (#25139)
if not response:
- conn.close()
+ await conn.close()
raise stem.SocketError('Unable to establish a common link protocol with %s:%i' % (address, port))
versions_reply = stem.client.cell.Cell.pop(response, 2)[0] # type: stem.client.cell.VersionsCell # type: ignore
common_protocols = set(link_protocols).intersection(versions_reply.versions)
if not common_protocols:
- conn.close()
+ await conn.close()
raise stem.SocketError('Unable to find a common link protocol. We support %s but %s:%i supports %s.' % (', '.join(map(str, link_protocols)), address, port, ', '.join(map(str, versions_reply.versions))))
# Establishing connections requires sending a NETINFO, but including our
@@ -144,14 +143,14 @@ class Relay(object):
# where it would help.
link_protocol = max(common_protocols)
- conn.send(stem.client.cell.NetinfoCell(relay_addr, []).pack(link_protocol))
+ await conn.send(stem.client.cell.NetinfoCell(relay_addr, []).pack(link_protocol))
return Relay(conn, link_protocol)
- def _recv_bytes(self) -> bytes:
- return self._recv(True) # type: ignore
+ async def _recv_bytes(self) -> bytes:
+ return await self._recv(True) # type: ignore
- def _recv(self, raw: bool = False) -> 'stem.client.cell.Cell':
+ async def _recv(self, raw: bool = False) -> 'stem.client.cell.Cell':
"""
Reads the next cell from our ORPort. If none is present this blocks
until one is available.
@@ -161,13 +160,13 @@ class Relay(object):
:returns: next :class:`~stem.client.cell.Cell`
"""
- with self._orport_lock:
+ async with self._orport_lock:
# cells begin with [circ_id][cell_type][...]
circ_id_size = self.link_protocol.circ_id_size.size
while len(self._orport_buffer) < (circ_id_size + CELL_TYPE_SIZE.size):
- self._orport_buffer += self._orport.recv() # read until we know the cell type
+ self._orport_buffer += await self._orport.recv() # read until we know the cell type
cell_type = Cell.by_value(CELL_TYPE_SIZE.pop(self._orport_buffer[circ_id_size:])[0])
@@ -177,13 +176,13 @@ class Relay(object):
# variable length, our next field is the payload size
while len(self._orport_buffer) < (circ_id_size + CELL_TYPE_SIZE.size + FIXED_PAYLOAD_LEN):
- self._orport_buffer += self._orport.recv() # read until we know the cell size
+ self._orport_buffer += await self._orport.recv() # read until we know the cell size
payload_len = PAYLOAD_LEN_SIZE.pop(self._orport_buffer[circ_id_size + CELL_TYPE_SIZE.size:])[0]
cell_size = circ_id_size + CELL_TYPE_SIZE.size + payload_len
while len(self._orport_buffer) < cell_size:
- self._orport_buffer += self._orport.recv() # read until we have the full cell
+ self._orport_buffer += await self._orport.recv() # read until we have the full cell
if raw:
content, self._orport_buffer = split(self._orport_buffer, cell_size)
@@ -192,7 +191,7 @@ class Relay(object):
cell, self._orport_buffer = Cell.pop(self._orport_buffer, self.link_protocol)
return cell
- def _msg(self, cell: 'stem.client.cell.Cell') -> Iterator['stem.client.cell.Cell']:
+ async def _msg(self, cell: 'stem.client.cell.Cell') -> Iterator['stem.client.cell.Cell']:
"""
Sends a cell on the ORPort and provides the response we receive in reply.
@@ -219,9 +218,9 @@ class Relay(object):
# TODO: why is this an iterator?
- self._orport.recv(timeout = 0) # discard unread data
- self._orport.send(cell.pack(self.link_protocol))
- response = self._orport.recv(timeout = 1)
+ await self._orport.recv(timeout = 0) # discard unread data
+ await self._orport.send(cell.pack(self.link_protocol))
+ response = await self._orport.recv(timeout = 1)
yield stem.client.cell.Cell.pop(response, self.link_protocol)[0]
def is_alive(self) -> bool:
@@ -246,27 +245,27 @@ class Relay(object):
return self._orport.connection_time()
- def close(self) -> None:
+ async def close(self) -> None:
"""
Closes our socket connection. This is a pass-through for our socket's
:func:`~stem.socket.BaseSocket.close` method.
"""
- with self._orport_lock:
- return self._orport.close()
+ async with self._orport_lock:
+ return await self._orport.close()
- def create_circuit(self) -> 'stem.client.Circuit':
+ async def create_circuit(self) -> 'stem.client.Circuit':
"""
Establishes a new circuit.
"""
- with self._orport_lock:
+ async with self._orport_lock:
circ_id = max(self._circuits) + 1 if self._circuits else self.link_protocol.first_circ_id
create_fast_cell = stem.client.cell.CreateFastCell(circ_id)
created_fast_cell = None
- for cell in self._msg(create_fast_cell):
+ async for cell in self._msg(create_fast_cell):
if isinstance(cell, stem.client.cell.CreatedFastCell):
created_fast_cell = cell
break
@@ -284,16 +283,16 @@ class Relay(object):
return circ
- def __iter__(self) -> Iterator['stem.client.Circuit']:
- with self._orport_lock:
+ async def __aiter__(self) -> Iterator['stem.client.Circuit']:
+ async with self._orport_lock:
for circ in self._circuits.values():
yield circ
- def __enter__(self) -> 'stem.client.Relay':
+ async def __aenter__(self) -> 'stem.client.Relay':
return self
- def __exit__(self, exit_type: Optional[Type[BaseException]], value: Optional[BaseException], traceback: Optional[TracebackType]) -> None:
- self.close()
+ async def __aexit__(self, exit_type: Optional[Type[BaseException]], value: Optional[BaseException], traceback: Optional[TracebackType]) -> None:
+ await self.close()
class Circuit(object):
@@ -327,7 +326,7 @@ class Circuit(object):
self.forward_key = Cipher(algorithms.AES(kdf.forward_key), ctr, default_backend()).encryptor()
self.backward_key = Cipher(algorithms.AES(kdf.backward_key), ctr, default_backend()).decryptor()
- def directory(self, request: str, stream_id: int = 0) -> bytes:
+ async def directory(self, request: str, stream_id: int = 0) -> bytes:
"""
Request descriptors from the relay.
@@ -337,9 +336,9 @@ class Circuit(object):
:returns: **str** with the requested descriptor data
"""
- with self.relay._orport_lock:
- self._send(RelayCommand.BEGIN_DIR, stream_id = stream_id)
- self._send(RelayCommand.DATA, request, stream_id = stream_id)
+ async with self.relay._orport_lock:
+ await self._send(RelayCommand.BEGIN_DIR, stream_id = stream_id)
+ await self._send(RelayCommand.DATA, request, stream_id = stream_id)
response = [] # type: List[stem.client.cell.RelayCell]
@@ -347,7 +346,7 @@ class Circuit(object):
# Decrypt relay cells received in response. Our digest/key only
# updates when handled successfully.
- encrypted_cell = self.relay._recv_bytes()
+ encrypted_cell = await self.relay._recv_bytes()
decrypted_cell, backward_key, backward_digest = stem.client.cell.RelayCell.decrypt(self.relay.link_protocol, encrypted_cell, self.backward_key, self.backward_digest)
@@ -362,7 +361,7 @@ class Circuit(object):
else:
response.append(decrypted_cell)
- def _send(self, command: 'stem.client.datatype.RelayCommand', data: Union[bytes, str] = b'', stream_id: int = 0) -> None:
+ async def _send(self, command: 'stem.client.datatype.RelayCommand', data: Union[bytes, str] = b'', stream_id: int = 0) -> None:
"""
Sends a message over the circuit.
@@ -371,24 +370,24 @@ class Circuit(object):
:param stream_id: specific stream this concerns
"""
- with self.relay._orport_lock:
+ async with self.relay._orport_lock:
# Encrypt and send the cell. Our digest/key only updates if the cell is
# successfully sent.
cell = stem.client.cell.RelayCell(self.id, command, data, stream_id = stream_id)
payload, forward_key, forward_digest = cell.encrypt(self.relay.link_protocol, self.forward_key, self.forward_digest)
- self.relay._orport.send(payload)
+ await self.relay._orport.send(payload)
self.forward_digest = forward_digest
self.forward_key = forward_key
- def close(self) -> None:
- with self.relay._orport_lock:
- self.relay._orport.send(stem.client.cell.DestroyCell(self.id).pack(self.relay.link_protocol))
+ async def close(self) -> None:
+ async with self.relay._orport_lock:
+ await self.relay._orport.send(stem.client.cell.DestroyCell(self.id).pack(self.relay.link_protocol))
del self.relay._circuits[self.id]
- def __enter__(self) -> 'stem.client.Circuit':
+ async def __aenter__(self) -> 'stem.client.Circuit':
return self
- def __exit__(self, exit_type: Optional[Type[BaseException]], value: Optional[BaseException], traceback: Optional[TracebackType]) -> None:
- self.close()
+ async def __aexit__(self, exit_type: Optional[Type[BaseException]], value: Optional[BaseException], traceback: Optional[TracebackType]) -> None:
+ await self.close()
diff --git a/stem/descriptor/remote.py b/stem/descriptor/remote.py
index e90c4442..eca846ee 100644
--- a/stem/descriptor/remote.py
+++ b/stem/descriptor/remote.py
@@ -83,6 +83,8 @@ content. For example...
hashes.
"""
+import asyncio
+import functools
import io
import random
import socket
@@ -93,6 +95,7 @@ import urllib.request
import stem
import stem.client
+import stem.control
import stem.descriptor
import stem.descriptor.networkstatus
import stem.directory
@@ -227,7 +230,7 @@ def get_detached_signatures(**query_args: Any) -> 'stem.descriptor.remote.Query'
return get_instance().get_detached_signatures(**query_args)
-class Query(object):
+class AsyncQuery(object):
"""
Asynchronous request for descriptor content from a directory authority or
mirror. These can either be made through the
@@ -427,32 +430,27 @@ class Query(object):
self.reply_headers = None # type: Optional[Dict[str, str]]
self.kwargs = kwargs
- self._downloader_thread = None # type: Optional[threading.Thread]
- self._downloader_thread_lock = threading.RLock()
+ self._downloader_task = None
+ self._downloader_lock = threading.RLock()
+
+ self._asyncio_loop = asyncio.get_event_loop()
if start:
self.start()
if block:
- self.run(True)
+ self._asyncio_loop.create_task(self.run(True))
def start(self) -> None:
"""
Starts downloading the scriptors if we haven't started already.
"""
- with self._downloader_thread_lock:
- if self._downloader_thread is None:
- self._downloader_thread = threading.Thread(
- name = 'Descriptor query',
- target = self._download_descriptors,
- args = (self.retries, self.timeout)
- )
-
- self._downloader_thread.setDaemon(True)
- self._downloader_thread.start()
+ with self._downloader_lock:
+ if self._downloader_task is None:
+ self._downloader_task = self._asyncio_loop.create_task(self._download_descriptors(self.retries, self.timeout))
- def run(self, suppress: bool = False) -> List['stem.descriptor.Descriptor']:
+ async def run(self, suppress: bool = False) -> List['stem.descriptor.Descriptor']:
"""
Blocks until our request is complete then provides the descriptors. If we
haven't yet started our request then this does so.
@@ -470,12 +468,12 @@ class Query(object):
* :class:`~stem.DownloadFailed` if our request fails
"""
- return list(self._run(suppress))
+ return [desc async for desc in self._run(suppress)]
- def _run(self, suppress: bool) -> Iterator[stem.descriptor.Descriptor]:
- with self._downloader_thread_lock:
+ async def _run(self, suppress: bool) -> Iterator[stem.descriptor.Descriptor]:
+ with self._downloader_lock:
self.start()
- self._downloader_thread.join()
+ await self._downloader_task
if self.error:
if suppress:
@@ -508,8 +506,8 @@ class Query(object):
raise self.error
- def __iter__(self) -> Iterator[stem.descriptor.Descriptor]:
- for desc in self._run(True):
+ async def __aiter__(self) -> Iterator[stem.descriptor.Descriptor]:
+ async for desc in self._run(True):
yield desc
def _pick_endpoint(self, use_authority: bool = False) -> stem.Endpoint:
@@ -530,18 +528,18 @@ class Query(object):
else:
return random.choice(self.endpoints)
- def _download_descriptors(self, retries: int, timeout: Optional[float]) -> None:
+ async def _download_descriptors(self, retries: int, timeout: Optional[float]) -> None:
try:
self.start_time = time.time()
endpoint = self._pick_endpoint(use_authority = retries == 0 and self.fall_back_to_authority)
if isinstance(endpoint, stem.ORPort):
downloaded_from = 'ORPort %s:%s (resource %s)' % (endpoint.address, endpoint.port, self.resource)
- self.content, self.reply_headers = _download_from_orport(endpoint, self.compression, self.resource)
+ self.content, self.reply_headers = await _download_from_orport(endpoint, self.compression, self.resource)
elif isinstance(endpoint, stem.DirPort):
self.download_url = 'http://%s:%i/%s' % (endpoint.address, endpoint.port, self.resource.lstrip('/'))
downloaded_from = self.download_url
- self.content, self.reply_headers = _download_from_dirport(self.download_url, self.compression, timeout)
+ self.content, self.reply_headers = await _download_from_dirport(self.download_url, self.compression, timeout)
else:
raise ValueError("BUG: endpoints can only be ORPorts or DirPorts, '%s' was a %s" % (endpoint, type(endpoint).__name__))
@@ -555,7 +553,7 @@ class Query(object):
if retries > 0 and (timeout is None or timeout > 0):
log.debug("Unable to download descriptors from '%s' (%i retries remaining): %s" % (self.download_url, retries, exc))
- return self._download_descriptors(retries - 1, timeout)
+ return await self._download_descriptors(retries - 1, timeout)
else:
log.debug("Unable to download descriptors from '%s': %s" % (self.download_url, exc))
self.error = exc
@@ -563,6 +561,64 @@ class Query(object):
self.is_done = True
+class Query(stem.util.AsyncClassWrapper):
+ def __init__(self, resource, descriptor_type = None, endpoints = None, compression = (Compression.GZIP,), retries = 2, fall_back_to_authority = False, timeout = None, start = True, block = False, validate = False, document_handler = stem.descriptor.DocumentHandler.ENTRIES, **kwargs):
+ self._thread_for_wrapped_class = stem.util.ThreadForWrappedAsyncClass()
+ self._thread_for_wrapped_class.start()
+ self._wrapped_instance = self._init_async_class(
+ AsyncQuery,
+ resource,
+ descriptor_type,
+ endpoints,
+ compression,
+ retries,
+ fall_back_to_authority,
+ timeout,
+ start,
+ block,
+ validate,
+ document_handler,
+ **kwargs,
+ )
+
+ def start(self):
+ return self._call_async_method_soon('start')
+
+ def run(self, suppress = False):
+ return self._execute_async_method('run', suppress)
+
+ def __iter__(self):
+ for desc in self._execute_async_generator_method('__aiter__'):
+ yield desc
+
+ # Add public attributes of `AsyncQuery` as properties.
+ for attr in (
+ 'descriptor_type',
+ 'endpoints',
+ 'resource',
+ 'compression',
+ 'retries',
+ 'fall_back_to_authority',
+ 'content',
+ 'error',
+ 'is_done',
+ 'download_url',
+ 'start_time',
+ 'timeout',
+ 'runtime',
+ 'validate',
+ 'document_handler',
+ 'reply_headers',
+ 'kwargs',
+ ):
+ locals()[attr] = property(
+ functools.partial(
+ lambda self, attr_name: getattr(self._wrapped_instance, attr_name),
+ attr_name=attr,
+ ),
+ )
+
+
class DescriptorDownloader(object):
"""
Configurable class that issues :class:`~stem.descriptor.remote.Query`
@@ -925,7 +981,7 @@ class DescriptorDownloader(object):
return Query(resource, **args)
-def _download_from_orport(endpoint: stem.ORPort, compression: Sequence[stem.descriptor._Compression], resource: str) -> Tuple[bytes, Dict[str, str]]:
+async def _download_from_orport(endpoint: stem.ORPort, compression: Sequence[stem.descriptor._Compression], resource: str) -> Tuple[bytes, Dict[str, str]]:
"""
Downloads descriptors from the given orport. Payload is just like an http
response (headers and all)...
@@ -956,15 +1012,15 @@ def _download_from_orport(endpoint: stem.ORPort, compression: Sequence[stem.desc
link_protocols = endpoint.link_protocols if endpoint.link_protocols else [3]
- with stem.client.Relay.connect(endpoint.address, endpoint.port, link_protocols) as relay:
- with relay.create_circuit() as circ:
+ async with await stem.client.Relay.connect(endpoint.address, endpoint.port, link_protocols) as relay:
+ async with await relay.create_circuit() as circ:
request = '\r\n'.join((
'GET %s HTTP/1.0' % resource,
'Accept-Encoding: %s' % ', '.join(map(lambda c: c.encoding, compression)),
'User-Agent: %s' % stem.USER_AGENT,
)) + '\r\n\r\n'
- response = circ.directory(request, stream_id = 1)
+ response = await circ.directory(request, stream_id = 1)
first_line, data = response.split(b'\r\n', 1)
header_data, body_data = data.split(b'\r\n\r\n', 1)
@@ -983,7 +1039,7 @@ def _download_from_orport(endpoint: stem.ORPort, compression: Sequence[stem.desc
return _decompress(body_data, headers.get('Content-Encoding')), headers
-def _download_from_dirport(url: str, compression: Sequence[stem.descriptor._Compression], timeout: Optional[float]) -> Tuple[bytes, Dict[str, str]]:
+async def _download_from_dirport(url: str, compression: Sequence[stem.descriptor._Compression], timeout: Optional[float]) -> Tuple[bytes, Dict[str, str]]:
"""
Downloads descriptors from the given url.
@@ -998,17 +1054,19 @@ def _download_from_dirport(url: str, compression: Sequence[stem.descriptor._Comp
* :class:`~stem.DownloadFailed` if our request fails
"""
+ # TODO: use an asynchronous solution for the HTTP request.
+ request = urllib.request.Request(
+ url,
+ headers = {
+ 'Accept-Encoding': ', '.join(map(lambda c: c.encoding, compression)),
+ 'User-Agent': stem.USER_AGENT,
+ }
+ )
+ get_response = functools.partial(urllib.request.urlopen, request, timeout = timeout)
+
+ loop = asyncio.get_event_loop()
try:
- response = urllib.request.urlopen(
- urllib.request.Request(
- url,
- headers = {
- 'Accept-Encoding': ', '.join(map(lambda c: c.encoding, compression)),
- 'User-Agent': stem.USER_AGENT,
- }
- ),
- timeout = timeout,
- )
+ response = await loop.run_in_executor(None, get_response)
except socket.timeout as exc:
raise stem.DownloadTimeout(url, exc, sys.exc_info()[2], timeout)
except:
diff --git a/stem/util/test_tools.py b/stem/util/test_tools.py
index f3c736a1..455f3da3 100644
--- a/stem/util/test_tools.py
+++ b/stem/util/test_tools.py
@@ -251,7 +251,7 @@ class TimedTestRunner(unittest.TextTestRunner):
TEST_RUNTIMES[self.id()] = time.time() - start_time
return result
- def assertRaisesWith(self, exc_type: Type[Exception], exc_msg: str, func: Callable, *args: Any, **kwargs: Any) -> None:
+ def assertRaisesWith(self, exc_type: Type[Exception], exc_msg: str, *args: Any, **kwargs: Any) -> None:
"""
Asserts the given invokation raises the expected excepiton. This is
similar to unittest's assertRaises and assertRaisesRegexp, but checks
@@ -262,7 +262,7 @@ class TimedTestRunner(unittest.TextTestRunner):
vended API then please let us know.
"""
- return self.assertRaisesRegexp(exc_type, '^%s$' % re.escape(exc_msg), func, *args, **kwargs)
+ return self.assertRaisesRegexp(exc_type, '^%s$' % re.escape(exc_msg), *args, **kwargs)
def id(self) -> str:
return '%s.%s.%s' % (original_type.__module__, original_type.__name__, self._testMethodName)
diff --git a/test/integ/client/connection.py b/test/integ/client/connection.py
index 2294a07d..316d54ba 100644
--- a/test/integ/client/connection.py
+++ b/test/integ/client/connection.py
@@ -9,46 +9,57 @@ import stem
import test.runner
from stem.client import Relay
+from stem.util.test_tools import async_test
class TestConnection(unittest.TestCase):
- def test_invalid_arguments(self):
+ @async_test
+ async def test_invalid_arguments(self):
"""
Provide invalid arguments to Relay.connect().
"""
- self.assertRaisesWith(ValueError, "'nope' isn't an IPv4 or IPv6 address", Relay.connect, 'nope', 80)
- self.assertRaisesWith(ValueError, "'-54' isn't a valid port", Relay.connect, '127.0.0.1', -54)
- self.assertRaisesWith(ValueError, "Connection can't be established without a link protocol.", Relay.connect, '127.0.0.1', 54, [])
+ with self.assertRaisesWith(ValueError, "'nope' isn't an IPv4 or IPv6 address"):
+ await Relay.connect('nope', 80)
+ with self.assertRaisesWith(ValueError, "'-54' isn't a valid port"):
+ await Relay.connect('127.0.0.1', -54)
+ with self.assertRaisesWith(ValueError, "Connection can't be established without a link protocol."):
+ await Relay.connect('127.0.0.1', 54, [])
- def test_not_orport(self):
+ @async_test
+ async def test_not_orport(self):
"""
Attempt to connect to an ORPort that doesn't exist.
"""
- self.assertRaisesWith(stem.SocketError, "Failed to connect to 127.0.0.1:1587. Maybe it isn't an ORPort?", Relay.connect, '127.0.0.1', 1587)
+ with self.assertRaisesWith(stem.SocketError, "Failed to connect to 127.0.0.1:1587. Maybe it isn't an ORPort?"):
+ await Relay.connect('127.0.0.1', 1587)
# connect to our ControlPort like it's an ORPort
if test.runner.Torrc.PORT in test.runner.get_runner().get_options():
- self.assertRaisesWith(stem.SocketError, "Failed to SSL authenticate to 127.0.0.1:1111. Maybe it isn't an ORPort?", Relay.connect, '127.0.0.1', test.runner.CONTROL_PORT)
+ with self.assertRaisesWith(stem.SocketError, "Failed to SSL authenticate to 127.0.0.1:1111. Maybe it isn't an ORPort?"):
+ await Relay.connect('127.0.0.1', test.runner.CONTROL_PORT)
- def test_no_common_link_protocol(self):
+ @async_test
+ async def test_no_common_link_protocol(self):
"""
Connection without a commonly accepted link protocol version.
"""
for link_protocol in (1, 2, 6, 20):
- self.assertRaisesWith(stem.SocketError, 'Unable to establish a common link protocol with 127.0.0.1:1113', Relay.connect, '127.0.0.1', test.runner.ORPORT, [link_protocol])
+ with self.assertRaisesWith(stem.SocketError, 'Unable to establish a common link protocol with 127.0.0.1:1113'):
+ await Relay.connect('127.0.0.1', test.runner.ORPORT, [link_protocol])
- def test_connection_time(self):
+ @async_test
+ async def test_connection_time(self):
"""
Checks duration we've been connected.
"""
before = time.time()
- with Relay.connect('127.0.0.1', test.runner.ORPORT) as conn:
+ async with await Relay.connect('127.0.0.1', test.runner.ORPORT) as conn:
connection_time = conn.connection_time()
self.assertTrue(time.time() >= connection_time >= before)
time.sleep(0.02)
@@ -57,10 +68,11 @@ class TestConnection(unittest.TestCase):
self.assertFalse(conn.is_alive())
self.assertTrue(conn.connection_time() >= connection_time + 0.02)
- def test_established(self):
+ @async_test
+ async def test_established(self):
"""
Successfully establish ORPort connection.
"""
- conn = Relay.connect('127.0.0.1', test.runner.ORPORT)
+ conn = await Relay.connect('127.0.0.1', test.runner.ORPORT)
self.assertTrue(int(conn.link_protocol) in (4, 5))
diff --git a/test/unit/descriptor/remote.py b/test/unit/descriptor/remote.py
index e57da92b..33ee57fb 100644
--- a/test/unit/descriptor/remote.py
+++ b/test/unit/descriptor/remote.py
@@ -13,9 +13,10 @@ import stem.descriptor.remote
import stem.util.str_tools
import test.require
-from unittest.mock import patch, Mock, MagicMock
+from unittest.mock import patch, Mock
from stem.descriptor.remote import Compression
+from stem.util.test_tools import coro_func_returning_value
from test.unit.descriptor import read_resource
TEST_RESOURCE = '/tor/server/fp/9695DFC35FFEB861329B9F1AB04C46397020CE31'
@@ -78,11 +79,20 @@ def _orport_mock(data, encoding = 'identity', response_code_header = None):
cell.data = hunk
cells.append(cell)
- connect_mock = MagicMock()
- relay_mock = connect_mock().__enter__()
- circ_mock = relay_mock.create_circuit().__enter__()
- circ_mock.directory.return_value = data
- return connect_mock
+ class AsyncMock(Mock):
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
+ return
+
+ circ_mock = AsyncMock()
+ circ_mock.directory.side_effect = coro_func_returning_value(data)
+
+ relay_mock = AsyncMock()
+ relay_mock.create_circuit.side_effect = coro_func_returning_value(circ_mock)
+
+ return coro_func_returning_value(relay_mock)
def _dirport_mock(data, encoding = 'identity'):
@@ -294,7 +304,7 @@ class TestDescriptorDownloader(unittest.TestCase):
skip_crypto_validation = not test.require.CRYPTOGRAPHY_AVAILABLE,
)
- self.assertEqual(stem.DirPort('128.31.0.39', 9131), query._pick_endpoint())
+ self.assertEqual(stem.DirPort('128.31.0.39', 9131), query._wrapped_instance._pick_endpoint())
descriptors = list(query)
self.assertEqual(1, len(descriptors))