tor-commits
Threads by month
- ----- 2025 -----
- May
- April
- March
- February
- January
- ----- 2024 -----
- December
- November
- October
- September
- August
- July
- June
- May
- April
- March
- February
- January
- ----- 2023 -----
- December
- November
- October
- September
- August
- July
- June
- May
- April
- March
- February
- January
- ----- 2022 -----
- December
- November
- October
- September
- August
- July
- June
- May
- April
- March
- February
- January
- ----- 2021 -----
- December
- November
- October
- September
- August
- July
- June
- May
- April
- March
- February
- January
- ----- 2020 -----
- December
- November
- October
- September
- August
- July
- June
- May
- April
- March
- February
- January
- ----- 2019 -----
- December
- November
- October
- September
- August
- July
- June
- May
- April
- March
- February
- January
- ----- 2018 -----
- December
- November
- October
- September
- August
- July
- June
- May
- April
- March
- February
- January
- ----- 2017 -----
- December
- November
- October
- September
- August
- July
- June
- May
- April
- March
- February
- January
- ----- 2016 -----
- December
- November
- October
- September
- August
- July
- June
- May
- April
- March
- February
- January
- ----- 2015 -----
- December
- November
- October
- September
- August
- July
- June
- May
- April
- March
- February
- January
- ----- 2014 -----
- December
- November
- October
- September
- August
- July
- June
- May
- April
- March
- February
- January
- ----- 2013 -----
- December
- November
- October
- September
- August
- July
- June
- May
- April
- March
- February
- January
- ----- 2012 -----
- December
- November
- October
- September
- August
- July
- June
- May
- April
- March
- February
- January
- ----- 2011 -----
- December
- November
- October
- September
- August
- July
- June
- May
- April
- March
- February
July 2020
- 17 participants
- 2100 discussions
commit 675f49fcbe6dc1a52d10215c07adff56001faa70
Author: Damian Johnson <atagar(a)torproject.org>
Date: Sat May 30 17:21:57 2020 -0700
Drop ThreadForWrappedAsyncClass
To call an asynchronous function we require a loop and thread for it to run
within...
def call(my_async_function):
    """
    Runs a coroutine to completion on a dedicated event loop thread and
    returns its result.

    A fresh loop is created (rather than asyncio.get_event_loop(), which is
    deprecated outside a running loop) and is always stopped, joined and
    closed afterward so no 'lingering thread' is left behind.

    :param my_async_function: coroutine object to run

    :returns: value returned by the coroutine

    :raises: whatever exception the coroutine raises
    """

    loop = asyncio.new_event_loop()
    loop_thread = threading.Thread(target = loop.run_forever, name = 'asyncio', daemon = True)
    loop_thread.start()

    try:
        # Block until the coroutine finishes *before* stopping the loop.
        # Scheduling loop.stop() right after submission can win the race, in
        # which case the coroutine never completes and its future never
        # resolves.

        return asyncio.run_coroutine_threadsafe(my_async_function, loop).result()
    finally:
        loop.call_soon_threadsafe(loop.stop)
        loop_thread.join()
        loop.close()
ThreadForWrappedAsyncClass bundled these together, but I found it more
confusing than helpful. These threads failed to clean themselves up,
causing 'lingering thread' notifications when we run our tests.
---
stem/connection.py | 30 ++++++++++++++++++------------
stem/control.py | 21 ++++++++++-----------
stem/descriptor/remote.py | 7 +++++--
stem/util/__init__.py | 33 +++++++++++----------------------
test/integ/connection/authentication.py | 2 +-
test/runner.py | 18 +++++++++---------
test/unit/control/controller.py | 2 +-
7 files changed, 55 insertions(+), 58 deletions(-)
diff --git a/stem/connection.py b/stem/connection.py
index 213ba010..86d32d7f 100644
--- a/stem/connection.py
+++ b/stem/connection.py
@@ -159,7 +159,7 @@ import stem.util.str_tools
import stem.util.system
import stem.version
-from typing import Any, cast, List, Optional, Sequence, Tuple, Type, Union
+from typing import Any, List, Optional, Sequence, Tuple, Type, Union
from stem.util import log
AuthMethod = stem.util.enum.Enum('NONE', 'PASSWORD', 'COOKIE', 'SAFECOOKIE', 'UNKNOWN')
@@ -271,18 +271,24 @@ def connect(control_port: Tuple[str, Union[str, int]] = ('127.0.0.1', 'default')
if controller is None or not issubclass(controller, stem.control.Controller):
raise ValueError('Controller should be a stem.control.Controller subclass.')
- async_controller_thread = stem.util.ThreadForWrappedAsyncClass()
- async_controller_thread.start()
+ loop = asyncio.new_event_loop()
+ loop_thread = threading.Thread(target = loop.run_forever, name = 'asyncio')
+ loop_thread.setDaemon(True)
+ loop_thread.start()
- connect_coroutine = _connect_async(control_port, control_socket, password, password_prompt, chroot_path, controller)
try:
- connection = asyncio.run_coroutine_threadsafe(connect_coroutine, async_controller_thread.loop).result()
- if connection is None and async_controller_thread.is_alive():
- async_controller_thread.join()
+ connection = asyncio.run_coroutine_threadsafe(_connect_async(control_port, control_socket, password, password_prompt, chroot_path, controller), loop).result()
+
+ if connection is None and loop_thread.is_alive():
+ loop.call_soon_threadsafe(loop.stop)
+ loop_thread.join()
+
return connection
except:
- if async_controller_thread.is_alive():
- async_controller_thread.join()
+ if loop_thread.is_alive():
+ loop.call_soon_threadsafe(loop.stop)
+ loop_thread.join()
+
raise
@@ -399,10 +405,10 @@ async def _connect_auth(control_socket: stem.socket.ControlSocket, password: str
if controller is None:
return control_socket
- elif issubclass(controller, stem.control.BaseController):
+ elif issubclass(controller, stem.control.BaseController) or issubclass(controller, stem.control.Controller):
+ # TODO: Controller no longer extends BaseController (we'll probably change that)
+
return controller(control_socket, is_authenticated = True)
- elif issubclass(controller, stem.control.Controller):
- return controller(control_socket, is_authenticated = True, started_async_controller_thread = cast(stem.util.ThreadForWrappedAsyncClass, threading.current_thread()))
except IncorrectSocketType:
if isinstance(control_socket, stem.socket.ControlPort):
print(CONNECT_MESSAGES['wrong_port_type'].format(port = control_socket.port))
diff --git a/stem/control.py b/stem/control.py
index 084976ad..47ddaa35 100644
--- a/stem/control.py
+++ b/stem/control.py
@@ -3941,13 +3941,17 @@ class Controller(_BaseControllerSocketMixin, stem.util.AsyncClassWrapper):
self,
control_socket: stem.socket.ControlSocket,
is_authenticated: bool = False,
- started_async_controller_thread: stem.util.ThreadForWrappedAsyncClass = None,
) -> None:
- if started_async_controller_thread:
- self._thread_for_wrapped_class = started_async_controller_thread
- else:
- self._thread_for_wrapped_class = stem.util.ThreadForWrappedAsyncClass()
- self._thread_for_wrapped_class.start()
+ # if within an asyncio context use its loop, otherwise spawn our own
+
+ try:
+ self._loop = asyncio.get_running_loop()
+ self._loop_thread = threading.current_thread()
+ except RuntimeError:
+ self._loop = asyncio.new_event_loop()
+ self._loop_thread = threading.Thread(target = self._loop.run_forever, name = 'asyncio')
+ self._loop_thread.setDaemon(True)
+ self._loop_thread.start()
self._wrapped_instance: AsyncController = self._init_async_class(AsyncController, control_socket, is_authenticated) # type: ignore
self._socket = self._wrapped_instance._socket
@@ -4212,11 +4216,6 @@ class Controller(_BaseControllerSocketMixin, stem.util.AsyncClassWrapper):
def drop_guards(self) -> None:
self._execute_async_method('drop_guards')
- def __del__(self) -> None:
- loop = self._thread_for_wrapped_class.loop
- if loop.is_running():
- loop.call_soon_threadsafe(loop.stop)
-
def __enter__(self) -> 'stem.control.Controller':
return self
diff --git a/stem/descriptor/remote.py b/stem/descriptor/remote.py
index e7ccaa24..c23ab7a9 100644
--- a/stem/descriptor/remote.py
+++ b/stem/descriptor/remote.py
@@ -663,8 +663,11 @@ class Query(stem.util.AsyncClassWrapper):
"""
def __init__(self, resource: str, descriptor_type: Optional[str] = None, endpoints: Optional[Sequence[stem.Endpoint]] = None, compression: Union[stem.descriptor._Compression, Sequence[stem.descriptor._Compression]] = (Compression.GZIP,), retries: int = 2, fall_back_to_authority: bool = False, timeout: Optional[float] = None, start: bool = True, block: bool = False, validate: bool = False, document_handler: stem.descriptor.DocumentHandler = stem.descriptor.DocumentHandler.ENTRIES, **kwargs: Any) -> None:
- self._thread_for_wrapped_class = stem.util.ThreadForWrappedAsyncClass()
- self._thread_for_wrapped_class.start()
+ self._loop = asyncio.get_event_loop()
+ self._loop_thread = threading.Thread(target = self._loop.run_forever, name = 'asyncio')
+ self._loop_thread.setDaemon(True)
+ self._loop_thread.start()
+
self._wrapped_instance: AsyncQuery = self._init_async_class( # type: ignore
AsyncQuery,
resource,
diff --git a/stem/util/__init__.py b/stem/util/__init__.py
index a90aa7ac..25282b99 100644
--- a/stem/util/__init__.py
+++ b/stem/util/__init__.py
@@ -10,7 +10,7 @@ import datetime
import threading
from concurrent.futures import Future
-from typing import Any, AsyncIterator, Iterator, Optional, Type, Union
+from typing import Any, AsyncIterator, Iterator, Type, Union
__all__ = [
'conf',
@@ -144,41 +144,26 @@ def _hash_attr(obj: Any, *attributes: str, **kwargs: Any):
return my_hash
-class ThreadForWrappedAsyncClass(threading.Thread):
- def __init__(self, *args: Any, **kwargs: Any) -> None:
- super().__init__(*args, *kwargs)
- self.loop = asyncio.new_event_loop()
- self.setDaemon(True)
-
- def run(self) -> None:
- self.loop.run_forever()
-
- def join(self, timeout: Optional[float] = None) -> None:
- self.loop.call_soon_threadsafe(self.loop.stop)
- super().join(timeout)
- self.loop.close()
-
-
class AsyncClassWrapper:
- _thread_for_wrapped_class: ThreadForWrappedAsyncClass
+ _loop: asyncio.AbstractEventLoop
+ _loop_thread: threading.Thread
_wrapped_instance: type
def _init_async_class(self, async_class: Type, *args: Any, **kwargs: Any) -> Any:
- thread = self._thread_for_wrapped_class
# The asynchronous class should be initialized in the thread where
# its methods will be executed.
- if thread != threading.current_thread():
+ if self._loop_thread != threading.current_thread():
async def init():
return async_class(*args, **kwargs)
- return asyncio.run_coroutine_threadsafe(init(), thread.loop).result()
+ return asyncio.run_coroutine_threadsafe(init(), self._loop).result()
return async_class(*args, **kwargs)
def _call_async_method_soon(self, method_name: str, *args: Any, **kwargs: Any) -> Future:
return asyncio.run_coroutine_threadsafe(
getattr(self._wrapped_instance, method_name)(*args, **kwargs),
- self._thread_for_wrapped_class.loop,
+ self._loop,
)
def _execute_async_method(self, method_name: str, *args: Any, **kwargs: Any) -> Any:
@@ -192,5 +177,9 @@ class AsyncClassWrapper:
convert_async_generator(
getattr(self._wrapped_instance, method_name)(*args, **kwargs),
),
- self._thread_for_wrapped_class.loop,
+ self._loop,
).result()
+
+ def __del__(self) -> None:
+ self._loop.call_soon_threadsafe(self._loop.stop)
+ self._loop_thread.join()
diff --git a/test/integ/connection/authentication.py b/test/integ/connection/authentication.py
index 3709c275..042d3939 100644
--- a/test/integ/connection/authentication.py
+++ b/test/integ/connection/authentication.py
@@ -124,7 +124,7 @@ class TestAuthenticate(unittest.TestCase):
with runner.get_tor_controller(False) as controller:
asyncio.run_coroutine_threadsafe(
stem.connection.authenticate(controller._wrapped_instance, test.runner.CONTROL_PASSWORD, runner.get_chroot()),
- controller._thread_for_wrapped_class.loop,
+ controller._loop,
).result()
await test.runner.exercise_controller(self, controller)
diff --git a/test/runner.py b/test/runner.py
index 4f237552..b132b8f5 100644
--- a/test/runner.py
+++ b/test/runner.py
@@ -488,16 +488,16 @@ class Runner(object):
:raises: :class: `test.runner.TorInaccessable` if tor can't be connected to
"""
- async_controller_thread = stem.util.ThreadForWrappedAsyncClass()
- async_controller_thread.start()
+ loop = asyncio.new_event_loop()
+ loop_thread = threading.Thread(target = loop.run_forever, name = 'get_tor_controller')
+ loop_thread.setDaemon(True)
+ loop_thread.start()
- try:
- control_socket = asyncio.run_coroutine_threadsafe(self.get_tor_socket(False), async_controller_thread.loop).result()
- controller = stem.control.Controller(control_socket, started_async_controller_thread = async_controller_thread)
- except Exception:
- if async_controller_thread.is_alive():
- async_controller_thread.join()
- raise
+ async def wrapped_get_controller():
+ control_socket = await self.get_tor_socket(False)
+ return stem.control.Controller(control_socket)
+
+ controller = asyncio.run_coroutine_threadsafe(wrapped_get_controller(), loop).result()
if authenticate:
self._authenticate_controller(controller)
diff --git a/test/unit/control/controller.py b/test/unit/control/controller.py
index 850918f7..84fcdfed 100644
--- a/test/unit/control/controller.py
+++ b/test/unit/control/controller.py
@@ -751,7 +751,7 @@ class TestControl(unittest.TestCase):
with patch('time.time', Mock(return_value = TEST_TIMESTAMP)):
with patch('stem.control.AsyncController.is_alive') as is_alive_mock:
is_alive_mock.return_value = True
- loop = self.controller._thread_for_wrapped_class.loop
+ loop = self.controller._loop
asyncio.run_coroutine_threadsafe(self.async_controller._event_loop(), loop)
try:
1
0

[stem/master] Switch to asyncio locks in socket classes to make them usable too
by atagar@torproject.org 16 Jul '20
by atagar@torproject.org 16 Jul '20
16 Jul '20
commit cad4b7204b7bbec52f3eb5b04811f332a12aa85d
Author: Illia Volochii <illia.volochii(a)gmail.com>
Date: Sun May 24 02:23:16 2020 +0300
Switch to asyncio locks in socket classes to make them usable too
---
stem/control.py | 18 +++++++----
stem/socket.py | 93 ++++++++++++++++++++++++++++-----------------------------
2 files changed, 58 insertions(+), 53 deletions(-)
diff --git a/stem/control.py b/stem/control.py
index 6ca6c23b..293d4bd3 100644
--- a/stem/control.py
+++ b/stem/control.py
@@ -858,7 +858,7 @@ class BaseController(_BaseControllerSocketMixin):
async def _connect(self) -> None:
self._create_loop_tasks()
- self._notify_status_listeners(State.INIT)
+ await self._notify_status_listeners(State.INIT, acquire_send_lock=False)
await self._socket_connect()
self._is_authenticated = False
@@ -879,7 +879,7 @@ class BaseController(_BaseControllerSocketMixin):
if event_loop_task:
await event_loop_task
- self._notify_status_listeners(State.CLOSED)
+ await self._notify_status_listeners(State.CLOSED, acquire_send_lock=False)
await self._socket_close()
@@ -888,7 +888,7 @@ class BaseController(_BaseControllerSocketMixin):
self._is_authenticated = True
- def _notify_status_listeners(self, state: 'stem.control.State') -> None:
+ async def _notify_status_listeners(self, state: 'stem.control.State', acquire_send_lock: bool = True) -> None:
"""
Informs our status listeners that a state change occurred.
@@ -898,7 +898,10 @@ class BaseController(_BaseControllerSocketMixin):
# Any changes to our is_alive() state happen under the send lock, so we
# need to have it to ensure it doesn't change beneath us.
- with self._socket._get_send_lock():
+ send_lock = self._socket._get_send_lock()
+ try:
+ if acquire_send_lock:
+ await send_lock.acquire()
with self._status_listeners_lock:
# States imply that our socket is either alive or not, which may not
# hold true when multiple events occur in quick succession. For
@@ -931,6 +934,9 @@ class BaseController(_BaseControllerSocketMixin):
self._state_change_threads.append(notice_thread)
else:
listener(self, state, change_timestamp)
+ finally:
+ if acquire_send_lock:
+ send_lock.release()
def _create_loop_tasks(self) -> None:
"""
@@ -1064,10 +1070,10 @@ class AsyncController(BaseController):
super(AsyncController, self).__init__(control_socket, is_authenticated)
- def _sighup_listener(event: stem.response.events.SignalEvent) -> None:
+ async def _sighup_listener(event: stem.response.events.SignalEvent) -> None:
if event.signal == Signal.RELOAD:
self.clear_cache()
- self._notify_status_listeners(State.RESET)
+ await self._notify_status_listeners(State.RESET)
def _confchanged_listener(event: stem.response.events.ConfChangedEvent) -> None:
if self.is_caching_enabled():
diff --git a/stem/socket.py b/stem/socket.py
index ff99c5b1..a0e5bf55 100644
--- a/stem/socket.py
+++ b/stem/socket.py
@@ -86,7 +86,6 @@ import re
import socket
import ssl
import sys
-import threading
import time
import stem.response
@@ -119,8 +118,14 @@ class BaseSocket(object):
# so prevents deadlock where we block writes because we're waiting to read
# a message that isn't coming.
- self._send_lock = threading.RLock()
- self._recv_lock = threading.RLock()
+ # The class is often initialized in a thread with an event loop different
+ # from one where it will be used. The asyncio locks are bound to the loop
+ # running in a thread where they are initialized. Therefore, we are
+ # creating them in _get_send_lock and _get_recv_lock when they are used the
+ # first time.
+
+ self._send_lock = None # type: Optional[asyncio.Lock]
+ self._recv_lock = None # type: Optional[asyncio.Lock]
def is_alive(self) -> bool:
"""
@@ -173,15 +178,15 @@ class BaseSocket(object):
:raises: :class:`stem.SocketError` if unable to make a socket
"""
- with self._send_lock:
+ async with self._get_send_lock():
# Closes the socket if we're currently attached to one. Once we're no
# longer alive it'll be safe to acquire the recv lock because recv()
# calls no longer block (raising SocketClosed instead).
if self.is_alive():
- await self.close()
+ await self._close_wo_send_lock()
- with self._recv_lock:
+ async with self._get_recv_lock():
self._reader, self._writer = await self._open_connection()
self._is_alive = True
self._connection_time = time.time()
@@ -201,32 +206,35 @@ class BaseSocket(object):
Shuts down the socket. If it's already closed then this is a no-op.
"""
- with self._send_lock:
- # Function is idempotent with one exception: we notify _close() if this
- # is causing our is_alive() state to change.
+ async with self._get_send_lock():
+ await self._close_wo_send_lock()
- is_change = self.is_alive()
+ async def _close_wo_send_lock(self) -> None:
+ # Function is idempotent with one exception: we notify _close() if this
+ # is causing our is_alive() state to change.
- if self._writer:
- self._writer.close()
- # `StreamWriter.wait_closed` was added in Python 3.7.
- if sys.version_info >= (3, 7):
- await self._writer.wait_closed()
+ is_change = self.is_alive()
- self._reader = None
- self._writer = None
- self._is_alive = False
- self._connection_time = time.time()
+ if self._writer:
+ self._writer.close()
+ # `StreamWriter.wait_closed` was added in Python 3.7.
+ if sys.version_info >= (3, 7):
+ await self._writer.wait_closed()
+
+ self._reader = None
+ self._writer = None
+ self._is_alive = False
+ self._connection_time = time.time()
- if is_change:
- await self._close()
+ if is_change:
+ await self._close()
async def _send(self, message: Union[bytes, str], handler: Callable[[asyncio.StreamWriter, Union[bytes, str]], Awaitable[None]]) -> None:
"""
Send message in a thread safe manner.
"""
- with self._send_lock:
+ async with self._get_send_lock():
try:
if not self.is_alive():
raise stem.SocketClosed()
@@ -237,7 +245,7 @@ class BaseSocket(object):
# everything down
if self.is_alive():
- await self.close()
+ await self._close_wo_send_lock()
raise
@@ -254,8 +262,8 @@ class BaseSocket(object):
Receives a message in a thread safe manner.
"""
- with self._recv_lock:
- try:
+ try:
+ async with self._get_recv_lock():
# makes a temporary reference to the _reader because connect()
# and close() may set or unset it
@@ -265,41 +273,32 @@ class BaseSocket(object):
raise stem.SocketClosed()
return await handler(my_reader)
- except stem.SocketClosed:
- # If recv_message raises a SocketClosed then we should properly shut
- # everything down. However, there's a couple cases where this will
- # cause deadlock...
- #
- # * This SocketClosed was *caused by* a close() call, which is joining
- # on our thread.
- #
- # * A send() call that's currently in flight is about to call close(),
- # also attempting to join on us.
- #
- # To resolve this we make a non-blocking call to acquire the send lock.
- # If we get it then great, we can close safely. If not then one of the
- # above are in progress and we leave the close to them.
-
- if self.is_alive():
- if self._send_lock.acquire(False):
- await self.close()
- self._send_lock.release()
+ except stem.SocketClosed:
+ if self.is_alive():
+ await self.close()
- raise
+ raise
- def _get_send_lock(self) -> threading.RLock:
+ def _get_send_lock(self) -> asyncio.Lock:
"""
The send lock is useful to classes that interact with us at a deep level
because it's used to lock :func:`stem.socket.ControlSocket.connect` /
:func:`stem.socket.BaseSocket.close`, and by extension our
:func:`stem.socket.BaseSocket.is_alive` state changes.
- :returns: **threading.RLock** that governs sending messages to our socket
+ :returns: **asyncio.Lock** that governs sending messages to our socket
and state changes
"""
+ if self._send_lock is None:
+ self._send_lock = asyncio.Lock()
return self._send_lock
+ def _get_recv_lock(self) -> asyncio.Lock:
+ if self._recv_lock is None:
+ self._recv_lock = asyncio.Lock()
+ return self._recv_lock
+
async def __aenter__(self) -> 'stem.socket.BaseSocket':
return self
1
0
commit b313a4211e08382afdf7dc51a99a181e1b54cbb1
Author: Illia Volochii <illia.volochii(a)gmail.com>
Date: Tue May 26 22:12:04 2020 +0300
Get rid of `_recv_lock`
---
stem/socket.py | 53 ++++++++++++++++++++---------------------------------
1 file changed, 20 insertions(+), 33 deletions(-)
diff --git a/stem/socket.py b/stem/socket.py
index a0e5bf55..8de13ba0 100644
--- a/stem/socket.py
+++ b/stem/socket.py
@@ -114,18 +114,12 @@ class BaseSocket(object):
self._is_alive = False
self._connection_time = 0.0 # time when we last connected or disconnected
- # Tracks sending and receiving separately. This should be safe, and doing
- # so prevents deadlock where we block writes because we're waiting to read
- # a message that isn't coming.
-
# The class is often initialized in a thread with an event loop different
- # from one where it will be used. The asyncio locks are bound to the loop
- # running in a thread where they are initialized. Therefore, we are
- # creating them in _get_send_lock and _get_recv_lock when they are used the
- # first time.
+ # from one where it will be used. The asyncio lock is bound to the loop
+ # running in a thread where it is initialized. Therefore, we are creating
+ # it in _get_send_lock when it is used the first time.
self._send_lock = None # type: Optional[asyncio.Lock]
- self._recv_lock = None # type: Optional[asyncio.Lock]
def is_alive(self) -> bool:
"""
@@ -186,20 +180,19 @@ class BaseSocket(object):
if self.is_alive():
await self._close_wo_send_lock()
- async with self._get_recv_lock():
- self._reader, self._writer = await self._open_connection()
- self._is_alive = True
- self._connection_time = time.time()
+ self._reader, self._writer = await self._open_connection()
+ self._is_alive = True
+ self._connection_time = time.time()
- # It's possible for this to have a transient failure...
- # SocketError: [Errno 4] Interrupted system call
- #
- # It's safe to retry, so give it another try if it fails.
+ # It's possible for this to have a transient failure...
+ # SocketError: [Errno 4] Interrupted system call
+ #
+ # It's safe to retry, so give it another try if it fails.
- try:
- await self._connect()
- except stem.SocketError:
- await self._connect() # single retry
+ try:
+ await self._connect()
+ except stem.SocketError:
+ await self._connect() # single retry
async def close(self) -> None:
"""
@@ -263,16 +256,15 @@ class BaseSocket(object):
"""
try:
- async with self._get_recv_lock():
- # makes a temporary reference to the _reader because connect()
- # and close() may set or unset it
+ # makes a temporary reference to the _reader because connect()
+ # and close() may set or unset it
- my_reader = self._reader
+ my_reader = self._reader
- if not my_reader:
- raise stem.SocketClosed()
+ if not my_reader:
+ raise stem.SocketClosed()
- return await handler(my_reader)
+ return await handler(my_reader)
except stem.SocketClosed:
if self.is_alive():
await self.close()
@@ -294,11 +286,6 @@ class BaseSocket(object):
self._send_lock = asyncio.Lock()
return self._send_lock
- def _get_recv_lock(self) -> asyncio.Lock:
- if self._recv_lock is None:
- self._recv_lock = asyncio.Lock()
- return self._recv_lock
-
async def __aenter__(self) -> 'stem.socket.BaseSocket':
return self
1
0

[stem/master] Replace `CombinedReentrantAndAsyncioLock` with the plain `asyncio.Lock`
by atagar@torproject.org 16 Jul '20
by atagar@torproject.org 16 Jul '20
16 Jul '20
commit 6be7d88f9e6bf82e5ae20813e6294c6862ea58c6
Author: Illia Volochii <illia.volochii(a)gmail.com>
Date: Sun May 24 02:29:32 2020 +0300
Replace `CombinedReentrantAndAsyncioLock` with the plain `asyncio.Lock`
`CombinedReentrantAndAsyncioLock` cannot be used in multiple threads anyway.
---
stem/client/__init__.py | 3 ++-
stem/control.py | 4 ++--
stem/util/__init__.py | 29 -----------------------------
3 files changed, 4 insertions(+), 32 deletions(-)
diff --git a/stem/client/__init__.py b/stem/client/__init__.py
index 8ea7b3c1..8c8da923 100644
--- a/stem/client/__init__.py
+++ b/stem/client/__init__.py
@@ -25,6 +25,7 @@ a wrapper for :class:`~stem.socket.RelaySocket`, much the same way as
+- close - closes this circuit
"""
+import asyncio
import hashlib
import stem
@@ -70,7 +71,7 @@ class Relay(object):
self.link_protocol = LinkProtocol(link_protocol)
self._orport = orport
self._orport_buffer = b'' # unread bytes
- self._orport_lock = stem.util.CombinedReentrantAndAsyncioLock()
+ self._orport_lock = asyncio.Lock()
self._circuits = {} # type: Dict[int, stem.client.Circuit]
@staticmethod
diff --git a/stem/control.py b/stem/control.py
index 293d4bd3..084976ad 100644
--- a/stem/control.py
+++ b/stem/control.py
@@ -623,7 +623,7 @@ class BaseController(_BaseControllerSocketMixin):
self._asyncio_loop = asyncio.get_event_loop()
- self._msg_lock = stem.util.CombinedReentrantAndAsyncioLock()
+ self._msg_lock = asyncio.Lock()
self._status_listeners = [] # type: List[Tuple[Callable[[stem.control.BaseController, stem.control.State, float], None], bool]] # tuples of the form (callback, spawn_thread)
self._status_listeners_lock = threading.RLock()
@@ -1062,7 +1062,7 @@ class AsyncController(BaseController):
# mapping of event types to their listeners
self._event_listeners = {} # type: Dict[stem.control.EventType, List[Callable[[stem.response.events.Event], Union[None, Awaitable[None]]]]]
- self._event_listeners_lock = stem.util.CombinedReentrantAndAsyncioLock()
+ self._event_listeners_lock = asyncio.Lock()
self._enabled_features = [] # type: List[str]
self._last_address_exc = None # type: Optional[BaseException]
diff --git a/stem/util/__init__.py b/stem/util/__init__.py
index 7c53730c..a90aa7ac 100644
--- a/stem/util/__init__.py
+++ b/stem/util/__init__.py
@@ -10,7 +10,6 @@ import datetime
import threading
from concurrent.futures import Future
-from types import TracebackType
from typing import Any, AsyncIterator, Iterator, Optional, Type, Union
__all__ = [
@@ -145,34 +144,6 @@ def _hash_attr(obj: Any, *attributes: str, **kwargs: Any):
return my_hash
-class CombinedReentrantAndAsyncioLock:
- """
- Lock that combines thread-safe reentrant and not thread-safe asyncio locks.
- """
-
- __slots__ = ('_r_lock', '_async_lock')
-
- def __init__(self) -> None:
- self._r_lock = threading.RLock()
- self._async_lock = asyncio.Lock()
-
- async def acquire(self) -> bool:
- await self._async_lock.acquire()
- self._r_lock.acquire()
- return True
-
- def release(self) -> None:
- self._r_lock.release()
- self._async_lock.release()
-
- async def __aenter__(self) -> 'CombinedReentrantAndAsyncioLock':
- await self.acquire()
- return self
-
- async def __aexit__(self, exit_type: Optional[Type[BaseException]], value: Optional[BaseException], traceback: Optional[TracebackType]) -> None:
- self.release()
-
-
class ThreadForWrappedAsyncClass(threading.Thread):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, *kwargs)
1
0

16 Jul '20
commit 5406561385cb2e5274ddda732b6e1d6fb014a2c5
Author: Damian Johnson <atagar(a)torproject.org>
Date: Sun Jun 28 14:33:40 2020 -0700
Constructor method with an async context
Many asyncio classes can only be constructed within a running loop. We can't
presume that our __init__() has that, so adding an __ainit__() method that
will.
---
stem/util/__init__.py | 95 +++++++++++++++++++++++++++----------------
test/unit/util/synchronous.py | 31 ++++++++++++++
2 files changed, 92 insertions(+), 34 deletions(-)
diff --git a/stem/util/__init__.py b/stem/util/__init__.py
index e8ef361e..5a8f95e0 100644
--- a/stem/util/__init__.py
+++ b/stem/util/__init__.py
@@ -150,7 +150,7 @@ def _hash_attr(obj: Any, *attributes: str, **kwargs: Any) -> int:
class Synchronous(object):
"""
- Mixin that lets a class be called from both synchronous and asynchronous
+ Mixin that lets a class run within both synchronous and asynchronous
contexts.
::
@@ -172,39 +172,70 @@ class Synchronous(object):
sync_demo()
asyncio.run(async_demo())
+ Our async methods always run within a loop. For asyncio users this class has
+ no effect, but otherwise we transparently create an async context to run
+ within.
+
+ Class initialization and any non-async methods should assume they're running
+ within a synchronous context. If our class supplies an **__ainit__()**
+ method it is invoked within our loop during initialization...
+
+ ::
+
+ class Example(Synchronous):
+ def __init__(self):
+ super(Example, self).__init__()
+
+ # Synchronous part of our initialization. Avoid anything
+ # that must run within an asyncio loop.
+
+ def __ainit__(self):
+ # Asynchronous part of our initialization. You can call
+ # asyncio.get_running_loop(), and construct objects that
+ # require it (like asyncio.Queue and asyncio.Lock).
+
Users are responsible for calling :func:`~stem.util.Synchronous.close` when
finished to clean up underlying resources.
"""
def __init__(self) -> None:
- self._loop = asyncio.new_event_loop()
- self._loop_lock = threading.RLock()
- self._loop_thread = threading.Thread(
- name = '%s asyncio' % type(self).__name__,
- target = self._loop.run_forever,
- daemon = True,
- )
+ ainit_func = getattr(self, '__ainit__', None)
+
+ if Synchronous.is_asyncio_context():
+ self._loop = asyncio.get_running_loop()
+ self._loop_thread = None
+
+ if ainit_func:
+ ainit_func()
+ else:
+ self._loop = asyncio.new_event_loop()
+ self._loop_thread = threading.Thread(
+ name = '%s asyncio' % type(self).__name__,
+ target = self._loop.run_forever,
+ daemon = True,
+ )
- self._is_closed = False
+ self._loop_thread.start()
- # overwrite asynchronous class methods with instance methods that can be
- # called from either context
+ # call any coroutines through this loop
- def wrap(func: Callable, *args: Any, **kwargs: Any) -> Any:
- if Synchronous.is_asyncio_context():
- return func(*args, **kwargs)
- else:
- with self._loop_lock:
- if self._is_closed:
- raise RuntimeError('%s has been closed' % type(self).__name__)
- elif not self._loop_thread.is_alive():
- self._loop_thread.start()
+ def call_async(func: Callable, *args: Any, **kwargs: Any) -> Any:
+ if Synchronous.is_asyncio_context():
+ return func(*args, **kwargs)
+ elif not self._loop_thread.is_alive():
+ raise RuntimeError('%s has been closed' % type(self).__name__)
- return asyncio.run_coroutine_threadsafe(func(*args, **kwargs), self._loop).result()
+ return asyncio.run_coroutine_threadsafe(func(*args, **kwargs), self._loop).result()
- for method_name, func in inspect.getmembers(self, predicate = inspect.ismethod):
- if inspect.iscoroutinefunction(func):
- setattr(self, method_name, functools.partial(wrap, func))
+ for method_name, func in inspect.getmembers(self, predicate = inspect.ismethod):
+ if inspect.iscoroutinefunction(func):
+ setattr(self, method_name, functools.partial(call_async, func))
+
+ if ainit_func:
+ async def call_ainit():
+ ainit_func()
+
+ asyncio.run_coroutine_threadsafe(call_ainit(), self._loop).result()
def close(self) -> None:
"""
@@ -213,12 +244,9 @@ class Synchronous(object):
**RuntimeError**.
"""
- with self._loop_lock:
- if self._loop_thread.is_alive():
- self._loop.call_soon_threadsafe(self._loop.stop)
- self._loop_thread.join()
-
- self._is_closed = True
+ if self._loop_thread and self._loop_thread.is_alive():
+ self._loop.call_soon_threadsafe(self._loop.stop)
+ self._loop_thread.join()
@staticmethod
def is_asyncio_context() -> bool:
@@ -235,14 +263,13 @@ class Synchronous(object):
return False
def __iter__(self) -> Iterator:
- async def convert_async_generator(generator: AsyncIterator) -> Iterator:
+ async def convert_generator(generator: AsyncIterator) -> Iterator:
return iter([d async for d in generator])
- iter_func = getattr(self, '__aiter__')
+ iter_func = getattr(self, '__aiter__', None)
if iter_func:
- with self._loop_lock:
- return asyncio.run_coroutine_threadsafe(convert_async_generator(iter_func()), self._loop).result()
+ return asyncio.run_coroutine_threadsafe(convert_generator(iter_func()), self._loop).result()
else:
raise TypeError("'%s' object is not iterable" % type(self).__name__)
diff --git a/test/unit/util/synchronous.py b/test/unit/util/synchronous.py
index 26dad98d..22271ffd 100644
--- a/test/unit/util/synchronous.py
+++ b/test/unit/util/synchronous.py
@@ -24,6 +24,10 @@ class Example(Synchronous):
class TestSynchronous(unittest.TestCase):
@patch('sys.stdout', new_callable = io.StringIO)
def test_example(self, stdout_mock):
+ """
+ Run the example from our pydoc.
+ """
+
def sync_demo():
instance = Example()
print('%s from a synchronous context' % instance.hello())
@@ -39,7 +43,34 @@ class TestSynchronous(unittest.TestCase):
self.assertEqual(EXAMPLE_OUTPUT, stdout_mock.getvalue())
+ def test_ainit(self):
+ """
+ Check that our constructor runs __ainit__ if present.
+ """
+
+ class AinitDemo(Synchronous):
+ def __init__(self):
+ super(AinitDemo, self).__init__()
+
+ def __ainit__(self):
+ self.ainit_loop = asyncio.get_running_loop()
+
+ def sync_demo():
+ instance = AinitDemo()
+ self.assertTrue(hasattr(instance, 'ainit_loop'))
+
+ async def async_demo():
+ instance = AinitDemo()
+ self.assertTrue(hasattr(instance, 'ainit_loop'))
+
+ sync_demo()
+ asyncio.run(async_demo())
+
def test_after_close(self):
+ """
+ Check that closed instances raise a RuntimeError to synchronous callers.
+ """
+
# close a used instance
instance = Example()
1
0
commit 448060eabed41b3bad22cc5b0a5b5494f2793816
Author: Damian Johnson <atagar(a)torproject.org>
Date: Mon Jun 15 16:23:40 2020 -0700
Rewrite descriptor downloading
Using run_in_executor() here has a couple of issues...
1. Executor threads aren't cleaned up. Running our tests with the '--all'
argument concludes with...
Threads lingering after test run:
<_MainThread(MainThread, started 140249831520000)>
<Thread(ThreadPoolExecutor-0_0, started daemon 140249689769728)>
<Thread(ThreadPoolExecutor-0_1, started daemon 140249606911744)>
<Thread(ThreadPoolExecutor-0_2, started daemon 140249586980608)>
<Thread(ThreadPoolExecutor-0_3, started daemon 140249578587904)>
<Thread(ThreadPoolExecutor-0_4, started daemon 140249570195200)>
...
2. Asyncio has its own IO. Wrapping urllib within an executor is easy,
but loses asyncio benefits such as imposing timeouts through
asyncio.wait_for().
Urllib marshals and parses HTTP headers, but we already do that
for ORPort requests, so using a raw asyncio connection actually
lets us deduplicate some code.
Deduplication greatly simplifies testing in that we can mock _download_from()
rather than the raw connection. However, I couldn't adapt our timeout test.
Asyncio's wait_for() works in practice, but not when the download is mocked.
---
stem/descriptor/remote.py | 229 ++++++++++++++++-------------------------
test/unit/descriptor/remote.py | 183 ++++++++------------------------
2 files changed, 133 insertions(+), 279 deletions(-)
diff --git a/stem/descriptor/remote.py b/stem/descriptor/remote.py
index c23ab7a9..f1ce79db 100644
--- a/stem/descriptor/remote.py
+++ b/stem/descriptor/remote.py
@@ -84,14 +84,11 @@ content. For example...
"""
import asyncio
-import functools
import io
import random
-import socket
import sys
import threading
import time
-import urllib.request
import stem
import stem.client
@@ -313,7 +310,7 @@ class AsyncQuery(object):
:var bool is_done: flag that indicates if our request has finished
:var float start_time: unix timestamp when we first started running
- :var http.client.HTTPMessage reply_headers: headers provided in the response,
+ :var dict reply_headers: headers provided in the response,
**None** if we haven't yet made our request
:var float runtime: time our query took, this is **None** if it's not yet
finished
@@ -330,13 +327,9 @@ class AsyncQuery(object):
:var float timeout: duration before we'll time out our request
:var str download_url: last url used to download the descriptor, this is
unset until we've actually made a download attempt
-
- :param start: start making the request when constructed (default is **True**)
- :param block: only return after the request has been completed, this is
- the same as running **query.run(True)** (default is **False**)
"""
- def __init__(self, resource: str, descriptor_type: Optional[str] = None, endpoints: Optional[Sequence[stem.Endpoint]] = None, compression: Union[stem.descriptor._Compression, Sequence[stem.descriptor._Compression]] = (Compression.GZIP,), retries: int = 2, fall_back_to_authority: bool = False, timeout: Optional[float] = None, start: bool = True, block: bool = False, validate: bool = False, document_handler: stem.descriptor.DocumentHandler = stem.descriptor.DocumentHandler.ENTRIES, **kwargs: Any) -> None:
+ def __init__(self, resource: str, descriptor_type: Optional[str] = None, endpoints: Optional[Sequence[stem.Endpoint]] = None, compression: Union[stem.descriptor._Compression, Sequence[stem.descriptor._Compression]] = (Compression.GZIP,), retries: int = 2, fall_back_to_authority: bool = False, timeout: Optional[float] = None, validate: bool = False, document_handler: stem.descriptor.DocumentHandler = stem.descriptor.DocumentHandler.ENTRIES, **kwargs: Any) -> None:
if not resource.startswith('/'):
raise ValueError("Resources should start with a '/': %s" % resource)
@@ -395,22 +388,15 @@ class AsyncQuery(object):
self._downloader_task = None # type: Optional[asyncio.Task]
self._downloader_lock = threading.RLock()
- self._asyncio_loop = asyncio.get_event_loop()
-
- if start:
- self.start()
-
- if block:
- self.run(True)
-
- def start(self) -> None:
+ async def start(self) -> None:
"""
Starts downloading the scriptors if we haven't started already.
"""
with self._downloader_lock:
if self._downloader_task is None:
- self._downloader_task = self._asyncio_loop.create_task(self._download_descriptors(self.retries, self.timeout))
+ loop = asyncio.get_running_loop()
+ self._downloader_task = loop.create_task(self._download_descriptors(self.retries, self.timeout))
async def run(self, suppress: bool = False) -> List['stem.descriptor.Descriptor']:
"""
@@ -434,7 +420,7 @@ class AsyncQuery(object):
async def _run(self, suppress: bool) -> AsyncIterator[stem.descriptor.Descriptor]:
with self._downloader_lock:
- self.start()
+ await self.start()
await self._downloader_task
if self.error:
@@ -491,36 +477,71 @@ class AsyncQuery(object):
return random.choice(self.endpoints)
async def _download_descriptors(self, retries: int, timeout: Optional[float]) -> None:
- try:
- self.start_time = time.time()
+ self.start_time = time.time()
+
+ retries = self.retries
+ time_remaining = self.timeout
+
+ while True:
endpoint = self._pick_endpoint(use_authority = retries == 0 and self.fall_back_to_authority)
if isinstance(endpoint, stem.ORPort):
downloaded_from = 'ORPort %s:%s (resource %s)' % (endpoint.address, endpoint.port, self.resource)
- self.content, self.reply_headers = await _download_from_orport(endpoint, self.compression, self.resource)
elif isinstance(endpoint, stem.DirPort):
- self.download_url = 'http://%s:%i/%s' % (endpoint.address, endpoint.port, self.resource.lstrip('/'))
- downloaded_from = self.download_url
- self.content, self.reply_headers = await _download_from_dirport(self.download_url, self.compression, timeout)
+ downloaded_from = 'http://%s:%i/%s' % (endpoint.address, endpoint.port, self.resource.lstrip('/'))
+ self.download_url = downloaded_from
else:
raise ValueError("BUG: endpoints can only be ORPorts or DirPorts, '%s' was a %s" % (endpoint, type(endpoint).__name__))
- self.runtime = time.time() - self.start_time
- log.trace('Descriptors retrieved from %s in %0.2fs' % (downloaded_from, self.runtime))
- except:
- exc = sys.exc_info()[1]
+ try:
+ response = await asyncio.wait_for(self._download_from(endpoint), time_remaining)
+ self.content, self.reply_headers = _http_body_and_headers(response)
+
+ self.is_done = True
+ self.runtime = time.time() - self.start_time
+
+ log.trace('Descriptors retrieved from %s in %0.2fs' % (downloaded_from, self.runtime))
+ return
+ except asyncio.TimeoutError as exc:
+ self.is_done = True
+ self.error = stem.DownloadTimeout(downloaded_from, exc, sys.exc_info()[2], self.timeout)
+ return
+ except:
+ exception = sys.exc_info()[1]
+ retries -= 1
+
+ if time_remaining is not None:
+ time_remaining -= time.time() - self.start_time
+
+ if retries > 0:
+ log.debug("Failed to download descriptors from '%s' (%i retries remaining): %s" % (downloaded_from, retries, exception))
+ else:
+ log.debug("Failed to download descriptors from '%s': %s" % (self.download_url, exception))
+
+ self.is_done = True
+ self.error = exception
+ return
- if timeout is not None:
- timeout -= time.time() - self.start_time
+ async def _download_from(self, endpoint: stem.Endpoint) -> bytes:
+ http_request = '\r\n'.join((
+ 'GET %s HTTP/1.0' % self.resource,
+ 'Accept-Encoding: %s' % ', '.join(map(lambda c: c.encoding, self.compression)),
+ 'User-Agent: %s' % stem.USER_AGENT,
+ )) + '\r\n\r\n'
- if retries > 0 and (timeout is None or timeout > 0):
- log.debug("Unable to download descriptors from '%s' (%i retries remaining): %s" % (self.download_url, retries, exc))
- return await self._download_descriptors(retries - 1, timeout)
- else:
- log.debug("Unable to download descriptors from '%s': %s" % (self.download_url, exc))
- self.error = exc
- finally:
- self.is_done = True
+ if isinstance(endpoint, stem.ORPort):
+ link_protocols = endpoint.link_protocols if endpoint.link_protocols else [3]
+
+ async with await stem.client.Relay.connect(endpoint.address, endpoint.port, link_protocols) as relay:
+ async with await relay.create_circuit() as circ:
+ return await circ.directory(http_request, stream_id = 1)
+ elif isinstance(endpoint, stem.DirPort):
+ reader, writer = await asyncio.open_connection(endpoint.address, endpoint.port)
+ writer.write(str_tools._to_bytes(http_request))
+
+ return await reader.read()
+ else:
+ raise ValueError("BUG: endpoints can only be ORPorts or DirPorts, '%s' was a %s" % (endpoint, type(endpoint).__name__))
class Query(stem.util.AsyncClassWrapper):
@@ -663,8 +684,8 @@ class Query(stem.util.AsyncClassWrapper):
"""
def __init__(self, resource: str, descriptor_type: Optional[str] = None, endpoints: Optional[Sequence[stem.Endpoint]] = None, compression: Union[stem.descriptor._Compression, Sequence[stem.descriptor._Compression]] = (Compression.GZIP,), retries: int = 2, fall_back_to_authority: bool = False, timeout: Optional[float] = None, start: bool = True, block: bool = False, validate: bool = False, document_handler: stem.descriptor.DocumentHandler = stem.descriptor.DocumentHandler.ENTRIES, **kwargs: Any) -> None:
- self._loop = asyncio.get_event_loop()
- self._loop_thread = threading.Thread(target = self._loop.run_forever, name = 'asyncio')
+ self._loop = asyncio.new_event_loop()
+ self._loop_thread = threading.Thread(target = self._loop.run_forever, name = 'query asyncio')
self._loop_thread.setDaemon(True)
self._loop_thread.start()
@@ -677,19 +698,23 @@ class Query(stem.util.AsyncClassWrapper):
retries,
fall_back_to_authority,
timeout,
- start,
- block,
validate,
document_handler,
**kwargs,
)
+ if start:
+ self.start()
+
+ if block:
+ self.run(True)
+
def start(self) -> None:
"""
Starts downloading the scriptors if we haven't started already.
"""
- self._call_async_method_soon('start')
+ self._execute_async_method('start')
def run(self, suppress = False) -> List['stem.descriptor.Descriptor']:
"""
@@ -1146,10 +1171,9 @@ class DescriptorDownloader(object):
return Query(resource, **args)
-async def _download_from_orport(endpoint: stem.ORPort, compression: Sequence[stem.descriptor._Compression], resource: str) -> Tuple[bytes, Dict[str, str]]:
+def _http_body_and_headers(data: bytes) -> Tuple[bytes, Dict[str, str]]:
"""
- Downloads descriptors from the given orport. Payload is just like an http
- response (headers and all)...
+ Parse the headers and decompressed body from a HTTP response, such as...
::
@@ -1164,112 +1188,41 @@ async def _download_from_orport(endpoint: stem.ORPort, compression: Sequence[ste
identity-ed25519
... rest of the descriptor content...
- :param endpoint: endpoint to download from
- :param compression: compression methods for the request
- :param resource: descriptor resource to download
+ :param data: HTTP response
- :returns: two value tuple of the form (data, reply_headers)
+ :returns: **tuple** with the decompressed data and headers
:raises:
- * :class:`stem.ProtocolError` if not a valid descriptor response
- * :class:`stem.SocketError` if unable to establish a connection
- """
-
- link_protocols = endpoint.link_protocols if endpoint.link_protocols else [3]
-
- async with await stem.client.Relay.connect(endpoint.address, endpoint.port, link_protocols) as relay:
- async with await relay.create_circuit() as circ:
- request = '\r\n'.join((
- 'GET %s HTTP/1.0' % resource,
- 'Accept-Encoding: %s' % ', '.join(map(lambda c: c.encoding, compression)),
- 'User-Agent: %s' % stem.USER_AGENT,
- )) + '\r\n\r\n'
-
- response = await circ.directory(request, stream_id = 1)
- first_line, data = response.split(b'\r\n', 1)
- header_data, body_data = data.split(b'\r\n\r\n', 1)
-
- if not first_line.startswith(b'HTTP/1.0 2'):
- raise stem.ProtocolError("Response should begin with HTTP success, but was '%s'" % str_tools._to_unicode(first_line))
-
- headers = {}
-
- for line in str_tools._to_unicode(header_data).splitlines():
- if ': ' not in line:
- raise stem.ProtocolError("'%s' is not a HTTP header:\n\n%s" % (line, header_data.decode('utf-8')))
-
- key, value = line.split(': ', 1)
- headers[key] = value
-
- return _decompress(body_data, headers.get('Content-Encoding')), headers
-
-
-async def _download_from_dirport(url: str, compression: Sequence[stem.descriptor._Compression], timeout: Optional[float]) -> Tuple[bytes, Dict[str, str]]:
- """
- Downloads descriptors from the given url.
-
- :param url: dirport url from which to download from
- :param compression: compression methods for the request
- :param timeout: duration before we'll time out our request
-
- :returns: two value tuple of the form (data, reply_headers)
-
- :raises:
- * :class:`~stem.DownloadTimeout` if our request timed out
- * :class:`~stem.DownloadFailed` if our request fails
- """
-
- # TODO: use an asyncronous solution for the HTTP request.
- request = urllib.request.Request(
- url,
- headers = {
- 'Accept-Encoding': ', '.join(map(lambda c: c.encoding, compression)),
- 'User-Agent': stem.USER_AGENT,
- }
- )
- get_response = functools.partial(urllib.request.urlopen, request, timeout = timeout)
-
- loop = asyncio.get_event_loop()
- try:
- response = await loop.run_in_executor(None, get_response)
- except socket.timeout as exc:
- raise stem.DownloadTimeout(url, exc, sys.exc_info()[2], timeout)
- except:
- exception, stacktrace = sys.exc_info()[1:3]
- raise stem.DownloadFailed(url, exception, stacktrace)
-
- return _decompress(response.read(), response.headers.get('Content-Encoding')), response.headers
-
-
-def _decompress(data: bytes, encoding: str) -> bytes:
+ * **stem.ProtocolError** if response was unsuccessful or malformed
+ * **ValueError** if encoding is unrecognized
+ * **ImportError** if missing the decompression module
"""
- Decompresses descriptor data.
- Tor doesn't include compression headers. As such when using gzip we
- need to include '32' for automatic header detection...
+ first_line, data = data.split(b'\r\n', 1)
+ header_data, body_data = data.split(b'\r\n\r\n', 1)
- https://stackoverflow.com/questions/3122145/zlib-error-error-3-while-decomp…
+ if not first_line.startswith(b'HTTP/1.0 2'):
+ raise stem.ProtocolError("Response should begin with HTTP success, but was '%s'" % str_tools._to_unicode(first_line))
- ... and with zstd we need to use the streaming API.
+ headers = {}
- :param data: data we received
- :param encoding: 'Content-Encoding' header of the response
+ for line in str_tools._to_unicode(header_data).splitlines():
+ if ': ' not in line:
+ raise stem.ProtocolError("'%s' is not a HTTP header:\n\n%s" % (line, header_data.decode('utf-8')))
- :returns: **bytes** with the decompressed data
+ key, value = line.split(': ', 1)
+ headers[key] = value
- :raises:
- * **ValueError** if encoding is unrecognized
- * **ImportError** if missing the decompression module
- """
+ encoding = headers.get('Content-Encoding')
if encoding == 'deflate':
- return stem.descriptor.Compression.GZIP.decompress(data)
+ return stem.descriptor.Compression.GZIP.decompress(body_data), headers
for compression in stem.descriptor.Compression:
if encoding == compression.encoding:
- return compression.decompress(data)
+ return compression.decompress(body_data), headers
- raise ValueError("'%s' isn't a recognized type of encoding" % encoding)
+ raise ValueError("'%s' is an unrecognized encoding" % encoding)
def _guess_descriptor_type(resource: str) -> str:
diff --git a/test/unit/descriptor/remote.py b/test/unit/descriptor/remote.py
index 33ee57fb..797bc8a3 100644
--- a/test/unit/descriptor/remote.py
+++ b/test/unit/descriptor/remote.py
@@ -2,9 +2,6 @@
Unit tests for stem.descriptor.remote.
"""
-import http.client
-import socket
-import time
import unittest
import stem
@@ -67,47 +64,13 @@ HEADER = '\r\n'.join([
])
-def _orport_mock(data, encoding = 'identity', response_code_header = None):
+def mock_download(descriptor, encoding = 'identity', response_code_header = None):
if response_code_header is None:
response_code_header = b'HTTP/1.0 200 OK\r\n'
- data = response_code_header + stem.util.str_tools._to_bytes(HEADER % encoding) + b'\r\n\r\n' + data
- cells = []
+ data = response_code_header + stem.util.str_tools._to_bytes(HEADER % encoding) + b'\r\n\r\n' + descriptor
- for hunk in [data[i:i + 50] for i in range(0, len(data), 50)]:
- cell = Mock()
- cell.data = hunk
- cells.append(cell)
-
- class AsyncMock(Mock):
- async def __aenter__(self):
- return self
-
- async def __aexit__(self, exc_type, exc_val, exc_tb):
- return
-
- circ_mock = AsyncMock()
- circ_mock.directory.side_effect = coro_func_returning_value(data)
-
- relay_mock = AsyncMock()
- relay_mock.create_circuit.side_effect = coro_func_returning_value(circ_mock)
-
- return coro_func_returning_value(relay_mock)
-
-
-def _dirport_mock(data, encoding = 'identity'):
- dirport_mock = Mock()
- dirport_mock().read.return_value = data
-
- headers = http.client.HTTPMessage()
-
- for line in HEADER.splitlines():
- key, value = line.split(': ', 1)
- headers.add_header(key, encoding if key == 'Content-Encoding' else value)
-
- dirport_mock().headers = headers
-
- return dirport_mock
+ return patch('stem.descriptor.remote.AsyncQuery._download_from', Mock(side_effect = coro_func_returning_value(data)))
class TestDescriptorDownloader(unittest.TestCase):
@@ -115,10 +78,10 @@ class TestDescriptorDownloader(unittest.TestCase):
# prevent our mocks from impacting other tests
stem.descriptor.remote.SINGLETON_DOWNLOADER = None
- @patch('stem.client.Relay.connect', _orport_mock(TEST_DESCRIPTOR))
- def test_using_orport(self):
+ @mock_download(TEST_DESCRIPTOR)
+ def test_download(self):
"""
- Download a descriptor through the ORPort.
+ Simply download and parse a descriptor.
"""
reply = stem.descriptor.remote.their_server_descriptor(
@@ -128,10 +91,16 @@ class TestDescriptorDownloader(unittest.TestCase):
)
self.assertEqual(1, len(list(reply)))
- self.assertEqual('moria1', list(reply)[0].nickname)
self.assertEqual(5, len(reply.reply_headers))
- def test_orport_response_code_headers(self):
+ desc = list(reply)[0]
+
+ self.assertEqual('moria1', desc.nickname)
+ self.assertEqual('128.31.0.34', desc.address)
+ self.assertEqual('9695DFC35FFEB861329B9F1AB04C46397020CE31', desc.fingerprint)
+ self.assertEqual(TEST_DESCRIPTOR, desc.get_bytes())
+
+ def test_response_header_code(self):
"""
When successful Tor provides a '200 OK' status, but we should accept other 2xx
response codes, reason text, and recognize HTTP errors.
@@ -144,14 +113,14 @@ class TestDescriptorDownloader(unittest.TestCase):
)
for header in response_code_headers:
- with patch('stem.client.Relay.connect', _orport_mock(TEST_DESCRIPTOR, response_code_header = header)):
+ with mock_download(TEST_DESCRIPTOR, response_code_header = header):
stem.descriptor.remote.their_server_descriptor(
endpoints = [stem.ORPort('12.34.56.78', 1100)],
validate = True,
skip_crypto_validation = not test.require.CRYPTOGRAPHY_AVAILABLE,
).run()
- with patch('stem.client.Relay.connect', _orport_mock(TEST_DESCRIPTOR, response_code_header = b'HTTP/1.0 500 Kaboom\r\n')):
+ with mock_download(TEST_DESCRIPTOR, response_code_header = b'HTTP/1.0 500 Kaboom\r\n'):
request = stem.descriptor.remote.their_server_descriptor(
endpoints = [stem.ORPort('12.34.56.78', 1100)],
validate = True,
@@ -160,28 +129,32 @@ class TestDescriptorDownloader(unittest.TestCase):
self.assertRaisesRegexp(stem.ProtocolError, "^Response should begin with HTTP success, but was 'HTTP/1.0 500 Kaboom'", request.run)
- @patch('urllib.request.urlopen', _dirport_mock(TEST_DESCRIPTOR))
- def test_using_dirport(self):
- """
- Download a descriptor through the DirPort.
- """
+ @mock_download(TEST_DESCRIPTOR)
+ def test_reply_header_data(self):
+ query = stem.descriptor.remote.get_server_descriptors('9695DFC35FFEB861329B9F1AB04C46397020CE31', start = False)
+ self.assertEqual(None, query.reply_headers) # initially we don't have a reply
+ query.run()
- reply = stem.descriptor.remote.their_server_descriptor(
- endpoints = [stem.DirPort('12.34.56.78', 1100)],
- validate = True,
- skip_crypto_validation = not test.require.CRYPTOGRAPHY_AVAILABLE,
- )
+ self.assertEqual('Fri, 13 Apr 2018 16:35:50 GMT', query.reply_headers.get('Date'))
+ self.assertEqual('application/octet-stream', query.reply_headers.get('Content-Type'))
+ self.assertEqual('97.103.17.56', query.reply_headers.get('X-Your-Address-Is'))
+ self.assertEqual('no-cache', query.reply_headers.get('Pragma'))
+ self.assertEqual('identity', query.reply_headers.get('Content-Encoding'))
- self.assertEqual(1, len(list(reply)))
- self.assertEqual('moria1', list(reply)[0].nickname)
- self.assertEqual(5, len(reply.reply_headers))
+ # request a header that isn't present
+ self.assertEqual(None, query.reply_headers.get('no-such-header'))
+ self.assertEqual('default', query.reply_headers.get('no-such-header', 'default'))
+
+ descriptors = list(query)
+ self.assertEqual(1, len(descriptors))
+ self.assertEqual('moria1', descriptors[0].nickname)
def test_gzip_url_override(self):
query = stem.descriptor.remote.Query(TEST_RESOURCE + '.z', compression = Compression.PLAINTEXT, start = False)
self.assertEqual([stem.descriptor.Compression.GZIP], query.compression)
self.assertEqual(TEST_RESOURCE, query.resource)
- @patch('urllib.request.urlopen', _dirport_mock(read_resource('compressed_identity'), encoding = 'identity'))
+ @mock_download(read_resource('compressed_identity'), encoding = 'identity')
def test_compression_plaintext(self):
"""
Download a plaintext descriptor.
@@ -197,7 +170,7 @@ class TestDescriptorDownloader(unittest.TestCase):
self.assertEqual(1, len(descriptors))
self.assertEqual('moria1', descriptors[0].nickname)
- @patch('urllib.request.urlopen', _dirport_mock(read_resource('compressed_gzip'), encoding = 'gzip'))
+ @mock_download(read_resource('compressed_gzip'), encoding = 'gzip')
def test_compression_gzip(self):
"""
Download a gip compressed descriptor.
@@ -213,7 +186,7 @@ class TestDescriptorDownloader(unittest.TestCase):
self.assertEqual(1, len(descriptors))
self.assertEqual('moria1', descriptors[0].nickname)
- @patch('urllib.request.urlopen', _dirport_mock(read_resource('compressed_zstd'), encoding = 'x-zstd'))
+ @mock_download(read_resource('compressed_zstd'), encoding = 'x-zstd')
def test_compression_zstd(self):
"""
Download a zstd compressed descriptor.
@@ -231,7 +204,7 @@ class TestDescriptorDownloader(unittest.TestCase):
self.assertEqual(1, len(descriptors))
self.assertEqual('moria1', descriptors[0].nickname)
- @patch('urllib.request.urlopen', _dirport_mock(read_resource('compressed_lzma'), encoding = 'x-tor-lzma'))
+ @mock_download(read_resource('compressed_lzma'), encoding = 'x-tor-lzma')
def test_compression_lzma(self):
"""
Download a lzma compressed descriptor.
@@ -249,8 +222,8 @@ class TestDescriptorDownloader(unittest.TestCase):
self.assertEqual(1, len(descriptors))
self.assertEqual('moria1', descriptors[0].nickname)
- @patch('urllib.request.urlopen')
- def test_each_getter(self, dirport_mock):
+ @mock_download(TEST_DESCRIPTOR)
+ def test_each_getter(self):
"""
Surface level exercising of each getter method for downloading descriptors.
"""
@@ -266,57 +239,8 @@ class TestDescriptorDownloader(unittest.TestCase):
downloader.get_bandwidth_file()
downloader.get_detached_signatures()
- @patch('urllib.request.urlopen', _dirport_mock(TEST_DESCRIPTOR))
- def test_reply_headers(self):
- query = stem.descriptor.remote.get_server_descriptors('9695DFC35FFEB861329B9F1AB04C46397020CE31', start = False)
- self.assertEqual(None, query.reply_headers) # initially we don't have a reply
- query.run()
-
- self.assertEqual('Fri, 13 Apr 2018 16:35:50 GMT', query.reply_headers.get('date'))
- self.assertEqual('application/octet-stream', query.reply_headers.get('content-type'))
- self.assertEqual('97.103.17.56', query.reply_headers.get('x-your-address-is'))
- self.assertEqual('no-cache', query.reply_headers.get('pragma'))
- self.assertEqual('identity', query.reply_headers.get('content-encoding'))
-
- # getting headers should be case insensitive
- self.assertEqual('identity', query.reply_headers.get('CoNtEnT-ENCODING'))
-
- # request a header that isn't present
- self.assertEqual(None, query.reply_headers.get('no-such-header'))
- self.assertEqual('default', query.reply_headers.get('no-such-header', 'default'))
-
- descriptors = list(query)
- self.assertEqual(1, len(descriptors))
- self.assertEqual('moria1', descriptors[0].nickname)
-
- @patch('urllib.request.urlopen', _dirport_mock(TEST_DESCRIPTOR))
- def test_query_download(self):
- """
- Check Query functionality when we successfully download a descriptor.
- """
-
- query = stem.descriptor.remote.Query(
- TEST_RESOURCE,
- 'server-descriptor 1.0',
- endpoints = [stem.DirPort('128.31.0.39', 9131)],
- compression = Compression.PLAINTEXT,
- validate = True,
- skip_crypto_validation = not test.require.CRYPTOGRAPHY_AVAILABLE,
- )
-
- self.assertEqual(stem.DirPort('128.31.0.39', 9131), query._wrapped_instance._pick_endpoint())
-
- descriptors = list(query)
- self.assertEqual(1, len(descriptors))
- desc = descriptors[0]
-
- self.assertEqual('moria1', desc.nickname)
- self.assertEqual('128.31.0.34', desc.address)
- self.assertEqual('9695DFC35FFEB861329B9F1AB04C46397020CE31', desc.fingerprint)
- self.assertEqual(TEST_DESCRIPTOR, desc.get_bytes())
-
- @patch('urllib.request.urlopen', _dirport_mock(b'some malformed stuff'))
- def test_query_with_malformed_content(self):
+ @mock_download(b'some malformed stuff')
+ def test_malformed_content(self):
"""
Query with malformed descriptor content.
"""
@@ -340,29 +264,6 @@ class TestDescriptorDownloader(unittest.TestCase):
self.assertRaises(ValueError, query.run)
- @patch('urllib.request.urlopen')
- def test_query_with_timeout(self, dirport_mock):
- def urlopen_call(*args, **kwargs):
- time.sleep(0.06)
- raise socket.timeout('connection timed out')
-
- dirport_mock.side_effect = urlopen_call
-
- query = stem.descriptor.remote.Query(
- TEST_RESOURCE,
- 'server-descriptor 1.0',
- endpoints = [stem.DirPort('128.31.0.39', 9131)],
- fall_back_to_authority = False,
- timeout = 0.1,
- validate = True,
- )
-
- # After two requests we'll have reached our total permissable timeout.
- # It would be nice to check that we don't make a third, but this
- # assertion has proved unreliable so only checking for the exception.
-
- self.assertRaises(stem.DownloadTimeout, query.run)
-
def test_query_with_invalid_endpoints(self):
invalid_endpoints = {
'hello': "'h' is a str.",
@@ -375,7 +276,7 @@ class TestDescriptorDownloader(unittest.TestCase):
expected_error = 'Endpoints must be an stem.ORPort or stem.DirPort. ' + error_suffix
self.assertRaisesWith(ValueError, expected_error, stem.descriptor.remote.Query, TEST_RESOURCE, 'server-descriptor 1.0', endpoints = endpoints)
- @patch('urllib.request.urlopen', _dirport_mock(TEST_DESCRIPTOR))
+ @mock_download(TEST_DESCRIPTOR)
def test_can_iterate_multiple_times(self):
query = stem.descriptor.remote.Query(
TEST_RESOURCE,
1
0

[stem/master] Start awaiting finishing of the loop tasks while closing controllers
by atagar@torproject.org 16 Jul '20
by atagar@torproject.org 16 Jul '20
16 Jul '20
commit f9de9a9612d639337090715e0b84d44129a0288a
Author: Illia Volochii <illia.volochii(a)gmail.com>
Date: Sun May 24 01:18:31 2020 +0300
Start awaiting finishing of the loop tasks while closing controllers
---
stem/control.py | 17 ++++++++++++++---
test/integ/control/base_controller.py | 11 ++++++++---
2 files changed, 22 insertions(+), 6 deletions(-)
diff --git a/stem/control.py b/stem/control.py
index c26da351..6ca6c23b 100644
--- a/stem/control.py
+++ b/stem/control.py
@@ -648,6 +648,8 @@ class BaseController(_BaseControllerSocketMixin):
self._state_change_threads = [] # type: List[threading.Thread] # threads we've spawned to notify of state changes
+ self._reader_loop_task = None # type: Optional[asyncio.Task]
+ self._event_loop_task = None # type: Optional[asyncio.Task]
if self._socket.is_alive():
self._create_loop_tasks()
@@ -868,6 +870,15 @@ class BaseController(_BaseControllerSocketMixin):
self._event_notice.set()
self._is_authenticated = False
+ reader_loop_task = self._reader_loop_task
+ self._reader_loop_task = None
+ event_loop_task = self._event_loop_task
+ self._event_loop_task = None
+ if reader_loop_task and self.is_alive():
+ await reader_loop_task
+ if event_loop_task:
+ await event_loop_task
+
self._notify_status_listeners(State.CLOSED)
await self._socket_close()
@@ -923,12 +934,12 @@ class BaseController(_BaseControllerSocketMixin):
def _create_loop_tasks(self) -> None:
"""
- Initializes daemon threads. Threads can't be reused so we need to recreate
+ Initializes asyncio tasks. Tasks can't be reused so we need to recreate
them if we're restarted.
"""
- for coroutine in (self._reader_loop(), self._event_loop()):
- self._asyncio_loop.create_task(coroutine)
+ self._reader_loop_task = self._asyncio_loop.create_task(self._reader_loop())
+ self._event_loop_task = self._asyncio_loop.create_task(self._event_loop())
async def _reader_loop(self) -> None:
"""
diff --git a/test/integ/control/base_controller.py b/test/integ/control/base_controller.py
index ac5f1e56..8fc5f1a2 100644
--- a/test/integ/control/base_controller.py
+++ b/test/integ/control/base_controller.py
@@ -161,9 +161,14 @@ class TestBaseController(unittest.TestCase):
await controller.msg('SETEVENTS')
await controller.msg('RESETCONF NodeFamily')
- await controller.close()
- controller.receive_notice.set()
- await asyncio.sleep(0)
+ # We need to set the receive notice and shut down the controller
+ # concurrently because the controller will block on the event handling,
+ # which in turn is currently blocking on the receive_notice.
+
+ async def set_receive_notice():
+ controller.receive_notice.set()
+
+ await asyncio.gather(controller.close(), set_receive_notice())
self.assertTrue(len(controller.received_events) >= 2)
1
0

[stem/master] Make `test.runner.Runner.get_tor_controller` synchronous
by atagar@torproject.org 16 Jul '20
by atagar@torproject.org 16 Jul '20
16 Jul '20
commit a7fbfadee6e11c270a37e93c4d67363bbbcd6629
Author: Illia Volochii <illia.volochii(a)gmail.com>
Date: Thu May 21 23:25:38 2020 +0300
Make `test.runner.Runner.get_tor_controller` synchronous
---
test/integ/connection/authentication.py | 9 +-
test/integ/control/controller.py | 245 +++++++++++++-------------------
test/integ/manual.py | 6 +-
test/runner.py | 2 +-
4 files changed, 106 insertions(+), 156 deletions(-)
diff --git a/test/integ/connection/authentication.py b/test/integ/connection/authentication.py
index 3eaae8d9..3709c275 100644
--- a/test/integ/connection/authentication.py
+++ b/test/integ/connection/authentication.py
@@ -121,7 +121,7 @@ class TestAuthenticate(unittest.TestCase):
runner = test.runner.get_runner()
- with await runner.get_tor_controller(False) as controller:
+ with runner.get_tor_controller(False) as controller:
asyncio.run_coroutine_threadsafe(
stem.connection.authenticate(controller._wrapped_instance, test.runner.CONTROL_PASSWORD, runner.get_chroot()),
controller._thread_for_wrapped_class.loop,
@@ -276,8 +276,7 @@ class TestAuthenticate(unittest.TestCase):
await self._check_auth(auth_type, auth_value)
@test.require.controller
- @async_test
- async def test_wrong_password_with_controller(self):
+ def test_wrong_password_with_controller(self):
"""
We ran into a race condition where providing the wrong password to the
Controller caused inconsistent responses. Checking for that...
@@ -291,9 +290,9 @@ class TestAuthenticate(unittest.TestCase):
self.skipTest('(requires only password auth)')
for i in range(10):
- with await runner.get_tor_controller(False) as controller:
+ with runner.get_tor_controller(False) as controller:
with self.assertRaises(stem.connection.IncorrectPassword):
- await controller.authenticate('wrong_password')
+ controller.authenticate('wrong_password')
@test.require.controller
@async_test
diff --git a/test/integ/control/controller.py b/test/integ/control/controller.py
index b1772f34..7853d407 100644
--- a/test/integ/control/controller.py
+++ b/test/integ/control/controller.py
@@ -38,14 +38,13 @@ TEST_ROUTER_STATUS_ENTRY = None
class TestController(unittest.TestCase):
@test.require.only_run_once
@test.require.controller
- @async_test
- async def test_missing_capabilities(self):
+ def test_missing_capabilities(self):
"""
Check to see if tor supports any events, signals, or features that we
don't.
"""
- with await test.runner.get_runner().get_tor_controller() as controller:
+ with test.runner.get_runner().get_tor_controller() as controller:
for event in controller.get_info('events/names').split():
if event not in EventType:
test.register_new_capability('Event', event)
@@ -89,7 +88,7 @@ class TestController(unittest.TestCase):
Checks that a notificiation listener is... well, notified of SIGHUPs.
"""
- with await test.runner.get_runner().get_tor_controller() as controller:
+ with test.runner.get_runner().get_tor_controller() as controller:
received_events = []
def status_listener(my_controller, state, timestamp):
@@ -120,8 +119,7 @@ class TestController(unittest.TestCase):
controller.reset_conf('__OwningControllerProcess')
@test.require.controller
- @async_test
- async def test_event_handling(self):
+ def test_event_handling(self):
"""
Add a couple listeners for various events and make sure that they receive
them. Then remove the listeners.
@@ -140,7 +138,7 @@ class TestController(unittest.TestCase):
runner = test.runner.get_runner()
- with await runner.get_tor_controller() as controller:
+ with runner.get_tor_controller() as controller:
controller.add_event_listener(listener1, EventType.CONF_CHANGED)
controller.add_event_listener(listener2, EventType.CONF_CHANGED, EventType.DEBUG)
@@ -179,8 +177,7 @@ class TestController(unittest.TestCase):
controller.reset_conf('NodeFamily')
@test.require.controller
- @async_test
- async def test_reattaching_listeners(self):
+ def test_reattaching_listeners(self):
"""
Checks that event listeners are re-attached when a controller disconnects
then reconnects to tor.
@@ -195,7 +192,7 @@ class TestController(unittest.TestCase):
runner = test.runner.get_runner()
- with await runner.get_tor_controller() as controller:
+ with runner.get_tor_controller() as controller:
controller.add_event_listener(listener, EventType.CONF_CHANGED)
# trigger an event
@@ -221,15 +218,14 @@ class TestController(unittest.TestCase):
controller.reset_conf('NodeFamily')
@test.require.controller
- @async_test
- async def test_getinfo(self):
+ def test_getinfo(self):
"""
Exercises GETINFO with valid and invalid queries.
"""
runner = test.runner.get_runner()
- with await runner.get_tor_controller() as controller:
+ with runner.get_tor_controller() as controller:
# successful single query
torrc_path = runner.get_torrc_path()
@@ -260,13 +256,12 @@ class TestController(unittest.TestCase):
self.assertEqual({}, controller.get_info([], {}))
@test.require.controller
- @async_test
- async def test_getinfo_freshrelaydescs(self):
+ def test_getinfo_freshrelaydescs(self):
"""
Exercises 'GETINFO status/fresh-relay-descs'.
"""
- with await test.runner.get_runner().get_tor_controller() as controller:
+ with test.runner.get_runner().get_tor_controller() as controller:
response = controller.get_info('status/fresh-relay-descs')
div = response.find('\nextra-info ')
nickname = controller.get_conf('Nickname')
@@ -284,13 +279,12 @@ class TestController(unittest.TestCase):
@test.require.controller
@test.require.online
- @async_test
- async def test_getinfo_dir_status(self):
+ def test_getinfo_dir_status(self):
"""
Exercise 'GETINFO dir/status-vote/*'.
"""
- with await test.runner.get_runner().get_tor_controller() as controller:
+ with test.runner.get_runner().get_tor_controller() as controller:
consensus = controller.get_info('dir/status-vote/current/consensus')
self.assertTrue('moria1' in consensus, 'moria1 not found in the consensus')
@@ -299,26 +293,24 @@ class TestController(unittest.TestCase):
self.assertTrue('moria1' in microdescs, 'moria1 not found in the microdescriptor consensus')
@test.require.controller
- @async_test
- async def test_get_version(self):
+ def test_get_version(self):
"""
Test that the convenient method get_version() works.
"""
- with await test.runner.get_runner().get_tor_controller() as controller:
+ with test.runner.get_runner().get_tor_controller() as controller:
version = controller.get_version()
self.assertTrue(isinstance(version, stem.version.Version))
self.assertEqual(version, test.tor_version())
@test.require.controller
- @async_test
- async def test_get_exit_policy(self):
+ def test_get_exit_policy(self):
"""
Sanity test for get_exit_policy(). Our 'ExitRelay 0' torrc entry causes us
to have a simple reject-all policy.
"""
- with await test.runner.get_runner().get_tor_controller() as controller:
+ with test.runner.get_runner().get_tor_controller() as controller:
self.assertEqual(ExitPolicy('reject *:*'), controller.get_exit_policy())
@test.require.controller
@@ -330,20 +322,19 @@ class TestController(unittest.TestCase):
runner = test.runner.get_runner()
- with await runner.get_tor_controller(False) as controller:
+ with runner.get_tor_controller(False) as controller:
controller.authenticate(test.runner.CONTROL_PASSWORD)
await test.runner.exercise_controller(self, controller)
@test.require.controller
- @async_test
- async def test_protocolinfo(self):
+ def test_protocolinfo(self):
"""
Test that the convenient method protocolinfo() works.
"""
runner = test.runner.get_runner()
- with await runner.get_tor_controller(False) as controller:
+ with runner.get_tor_controller(False) as controller:
protocolinfo = controller.get_protocolinfo()
self.assertTrue(isinstance(protocolinfo, stem.response.protocolinfo.ProtocolInfoResponse))
@@ -364,15 +355,14 @@ class TestController(unittest.TestCase):
self.assertEqual(tuple(auth_methods), protocolinfo.auth_methods)
@test.require.controller
- @async_test
- async def test_getconf(self):
+ def test_getconf(self):
"""
Exercises GETCONF with valid and invalid queries.
"""
runner = test.runner.get_runner()
- with await runner.get_tor_controller() as controller:
+ with runner.get_tor_controller() as controller:
control_socket = controller.get_socket()
if isinstance(control_socket, stem.socket.ControlPort):
@@ -428,15 +418,14 @@ class TestController(unittest.TestCase):
self.assertEqual({}, controller.get_conf_map([], 'la-di-dah'))
@test.require.controller
- @async_test
- async def test_is_set(self):
+ def test_is_set(self):
"""
Exercises our is_set() method.
"""
runner = test.runner.get_runner()
- with await runner.get_tor_controller() as controller:
+ with runner.get_tor_controller() as controller:
custom_options = controller._execute_async_method('_get_custom_options')
self.assertTrue('ControlPort' in custom_options or 'ControlSocket' in custom_options)
self.assertEqual('1', custom_options['DownloadExtraInfo'])
@@ -456,8 +445,7 @@ class TestController(unittest.TestCase):
self.assertFalse(controller.is_set('ConnLimit'))
@test.require.controller
- @async_test
- async def test_hidden_services_conf(self):
+ def test_hidden_services_conf(self):
"""
Exercises the hidden service family of methods (get_hidden_service_conf,
set_hidden_service_conf, create_hidden_service, and remove_hidden_service).
@@ -471,7 +459,7 @@ class TestController(unittest.TestCase):
service3_path = os.path.join(test_dir, 'test_hidden_service3')
service4_path = os.path.join(test_dir, 'test_hidden_service4')
- with await runner.get_tor_controller() as controller:
+ with runner.get_tor_controller() as controller:
try:
# initially we shouldn't be running any hidden services
@@ -565,35 +553,32 @@ class TestController(unittest.TestCase):
pass
@test.require.controller
- @async_test
- async def test_without_ephemeral_hidden_services(self):
+ def test_without_ephemeral_hidden_services(self):
"""
Exercises ephemeral hidden service methods when none are present.
"""
- with await test.runner.get_runner().get_tor_controller() as controller:
+ with test.runner.get_runner().get_tor_controller() as controller:
self.assertEqual([], controller.list_ephemeral_hidden_services())
self.assertEqual([], controller.list_ephemeral_hidden_services(detached = True))
self.assertEqual(False, controller.remove_ephemeral_hidden_service('gfzprpioee3hoppz'))
@test.require.controller
- @async_test
- async def test_with_invalid_ephemeral_hidden_service_port(self):
- with await test.runner.get_runner().get_tor_controller() as controller:
+ def test_with_invalid_ephemeral_hidden_service_port(self):
+ with test.runner.get_runner().get_tor_controller() as controller:
for ports in (4567890, [4567, 4567890], {4567: '-:4567'}):
exc_msg = "ADD_ONION response didn't have an OK status: Invalid VIRTPORT/TARGET"
self.assertRaisesWith(stem.ProtocolError, exc_msg, controller.create_ephemeral_hidden_service, ports)
@test.require.controller
- @async_test
- async def test_ephemeral_hidden_services_v2(self):
+ def test_ephemeral_hidden_services_v2(self):
"""
Exercises creating v2 ephemeral hidden services.
"""
runner = test.runner.get_runner()
- with await runner.get_tor_controller() as controller:
+ with runner.get_tor_controller() as controller:
response = controller.create_ephemeral_hidden_service(4567, key_content = 'RSA1024')
self.assertEqual([response.service_id], controller.list_ephemeral_hidden_services())
self.assertTrue(response.private_key is not None)
@@ -625,20 +610,19 @@ class TestController(unittest.TestCase):
# other controllers shouldn't be able to see these hidden services
- with await runner.get_tor_controller() as second_controller:
+ with runner.get_tor_controller() as second_controller:
self.assertEqual(2, len(controller.list_ephemeral_hidden_services()))
self.assertEqual(0, len(second_controller.list_ephemeral_hidden_services()))
@test.require.controller
- @async_test
- async def test_ephemeral_hidden_services_v3(self):
+ def test_ephemeral_hidden_services_v3(self):
"""
Exercises creating v3 ephemeral hidden services.
"""
runner = test.runner.get_runner()
- with await runner.get_tor_controller() as controller:
+ with runner.get_tor_controller() as controller:
response = controller.create_ephemeral_hidden_service(4567, key_content = 'ED25519-V3')
self.assertEqual([response.service_id], controller.list_ephemeral_hidden_services())
self.assertTrue(response.private_key is not None)
@@ -670,20 +654,19 @@ class TestController(unittest.TestCase):
# other controllers shouldn't be able to see these hidden services
- with await runner.get_tor_controller() as second_controller:
+ with runner.get_tor_controller() as second_controller:
self.assertEqual(2, len(controller.list_ephemeral_hidden_services()))
self.assertEqual(0, len(second_controller.list_ephemeral_hidden_services()))
@test.require.controller
- @async_test
- async def test_with_ephemeral_hidden_services_basic_auth(self):
+ def test_with_ephemeral_hidden_services_basic_auth(self):
"""
Exercises creating ephemeral hidden services that uses basic authentication.
"""
runner = test.runner.get_runner()
- with await runner.get_tor_controller() as controller:
+ with runner.get_tor_controller() as controller:
response = controller.create_ephemeral_hidden_service(4567, key_content = 'RSA1024', basic_auth = {'alice': 'nKwfvVPmTNr2k2pG0pzV4g', 'bob': None})
self.assertEqual([response.service_id], controller.list_ephemeral_hidden_services())
self.assertTrue(response.private_key is not None)
@@ -695,8 +678,7 @@ class TestController(unittest.TestCase):
self.assertEqual([], controller.list_ephemeral_hidden_services())
@test.require.controller
- @async_test
- async def test_with_ephemeral_hidden_services_basic_auth_no_credentials(self):
+ def test_with_ephemeral_hidden_services_basic_auth_no_credentials(self):
"""
Exercises creating ephemeral hidden services when attempting to use basic
auth but not including any credentials.
@@ -704,13 +686,12 @@ class TestController(unittest.TestCase):
runner = test.runner.get_runner()
- with await runner.get_tor_controller() as controller:
+ with runner.get_tor_controller() as controller:
exc_msg = "ADD_ONION response didn't have an OK status: No auth clients specified"
self.assertRaisesWith(stem.ProtocolError, exc_msg, controller.create_ephemeral_hidden_service, 4567, basic_auth = {})
@test.require.controller
- @async_test
- async def test_with_detached_ephemeral_hidden_services(self):
+ def test_with_detached_ephemeral_hidden_services(self):
"""
Exercises creating detached ephemeral hidden services and methods when
they're present.
@@ -718,7 +699,7 @@ class TestController(unittest.TestCase):
runner = test.runner.get_runner()
- with await runner.get_tor_controller() as controller:
+ with runner.get_tor_controller() as controller:
response = controller.create_ephemeral_hidden_service(4567, detached = True)
self.assertEqual([], controller.list_ephemeral_hidden_services())
self.assertEqual([response.service_id], controller.list_ephemeral_hidden_services(detached = True))
@@ -732,7 +713,7 @@ class TestController(unittest.TestCase):
# other controllers should be able to see this service, and drop it
- with await runner.get_tor_controller() as second_controller:
+ with runner.get_tor_controller() as second_controller:
self.assertEqual([response.service_id], second_controller.list_ephemeral_hidden_services(detached = True))
self.assertEqual(True, second_controller.remove_ephemeral_hidden_service(response.service_id))
self.assertEqual([], controller.list_ephemeral_hidden_services(detached = True))
@@ -745,8 +726,7 @@ class TestController(unittest.TestCase):
controller.remove_ephemeral_hidden_service(response.service_id)
@test.require.controller
- @async_test
- async def test_rejecting_unanonymous_hidden_services_creation(self):
+ def test_rejecting_unanonymous_hidden_services_creation(self):
"""
Attempt to create a non-anonymous hidden service despite not setting
HiddenServiceSingleHopMode and HiddenServiceNonAnonymousMode.
@@ -754,12 +734,11 @@ class TestController(unittest.TestCase):
runner = test.runner.get_runner()
- with await runner.get_tor_controller() as controller:
+ with runner.get_tor_controller() as controller:
self.assertEqual('Tor is in anonymous hidden service mode', str(controller.msg('ADD_ONION NEW:BEST Flags=NonAnonymous Port=4567')))
@test.require.controller
- @async_test
- async def test_set_conf(self):
+ def test_set_conf(self):
"""
Exercises set_conf(), reset_conf(), and set_options() methods with valid
and invalid requests.
@@ -769,7 +748,7 @@ class TestController(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdir:
- with await runner.get_tor_controller() as controller:
+ with runner.get_tor_controller() as controller:
try:
# successfully set a single option
connlimit = int(controller.get_conf('ConnLimit'))
@@ -832,14 +811,13 @@ class TestController(unittest.TestCase):
), reset = True)
@test.require.controller
- @async_test
- async def test_set_conf_for_usebridges(self):
+ def test_set_conf_for_usebridges(self):
"""
Ensure we can set UseBridges=1 and also set a Bridge. This is a tor
regression check (:trac:`31945`).
"""
- with await test.runner.get_runner().get_tor_controller() as controller:
+ with test.runner.get_runner().get_tor_controller() as controller:
orport = controller.get_conf('ORPort')
try:
@@ -856,26 +834,24 @@ class TestController(unittest.TestCase):
), reset = True)
@test.require.controller
- @async_test
- async def test_set_conf_when_immutable(self):
+ def test_set_conf_when_immutable(self):
"""
Issue a SETCONF for tor options that cannot be changed while running.
"""
- with await test.runner.get_runner().get_tor_controller() as controller:
+ with test.runner.get_runner().get_tor_controller() as controller:
self.assertRaisesWith(stem.InvalidArguments, "DisableAllSwap cannot be changed while tor's running", controller.set_conf, 'DisableAllSwap', '1')
self.assertRaisesWith(stem.InvalidArguments, "DisableAllSwap, User cannot be changed while tor's running", controller.set_options, {'User': 'atagar', 'DisableAllSwap': '1'})
@test.require.controller
- @async_test
- async def test_loadconf(self):
+ def test_loadconf(self):
"""
Exercises Controller.load_conf with valid and invalid requests.
"""
runner = test.runner.get_runner()
- with await runner.get_tor_controller() as controller:
+ with runner.get_tor_controller() as controller:
oldconf = runner.get_torrc_contents()
try:
@@ -905,13 +881,12 @@ class TestController(unittest.TestCase):
controller.reset_conf('__OwningControllerProcess')
@test.require.controller
- @async_test
- async def test_saveconf(self):
+ def test_saveconf(self):
runner = test.runner.get_runner()
# only testing for success, since we need to run out of disk space to test
# for failure
- with await runner.get_tor_controller() as controller:
+ with runner.get_tor_controller() as controller:
oldconf = runner.get_torrc_contents()
try:
@@ -926,15 +901,14 @@ class TestController(unittest.TestCase):
controller.reset_conf('__OwningControllerProcess')
@test.require.controller
- @async_test
- async def test_get_ports(self):
+ def test_get_ports(self):
"""
Test Controller.get_ports against a running tor instance.
"""
runner = test.runner.get_runner()
- with await runner.get_tor_controller() as controller:
+ with runner.get_tor_controller() as controller:
self.assertEqual([test.runner.ORPORT], controller.get_ports(Listener.OR))
self.assertEqual([], controller.get_ports(Listener.DIR))
self.assertEqual([test.runner.SOCKS_PORT], controller.get_ports(Listener.SOCKS))
@@ -948,15 +922,14 @@ class TestController(unittest.TestCase):
self.assertEqual([], controller.get_ports(Listener.CONTROL))
@test.require.controller
- @async_test
- async def test_get_listeners(self):
+ def test_get_listeners(self):
"""
Test Controller.get_listeners against a running tor instance.
"""
runner = test.runner.get_runner()
- with await runner.get_tor_controller() as controller:
+ with runner.get_tor_controller() as controller:
self.assertEqual([('0.0.0.0', test.runner.ORPORT)], controller.get_listeners(Listener.OR))
self.assertEqual([], controller.get_listeners(Listener.DIR))
self.assertEqual([('127.0.0.1', test.runner.SOCKS_PORT)], controller.get_listeners(Listener.SOCKS))
@@ -972,15 +945,14 @@ class TestController(unittest.TestCase):
@test.require.controller
@test.require.online
@test.require.version(stem.version.Version('0.1.2.2-alpha'))
- @async_test
- async def test_enable_feature(self):
+ def test_enable_feature(self):
"""
Test Controller.enable_feature with valid and invalid inputs.
"""
runner = test.runner.get_runner()
- with await runner.get_tor_controller() as controller:
+ with runner.get_tor_controller() as controller:
self.assertTrue(controller.is_feature_enabled('VERBOSE_NAMES'))
self.assertRaises(stem.InvalidArguments, controller.enable_feature, ['NOT', 'A', 'FEATURE'])
@@ -992,13 +964,12 @@ class TestController(unittest.TestCase):
self.fail()
@test.require.controller
- @async_test
- async def test_signal(self):
+ def test_signal(self):
"""
Test controller.signal with valid and invalid signals.
"""
- with await test.runner.get_runner().get_tor_controller() as controller:
+ with test.runner.get_runner().get_tor_controller() as controller:
# valid signal
controller.signal('CLEARDNSCACHE')
@@ -1006,13 +977,12 @@ class TestController(unittest.TestCase):
self.assertRaises(stem.InvalidArguments, controller.signal, 'FOOBAR')
@test.require.controller
- @async_test
- async def test_newnym_availability(self):
+ def test_newnym_availability(self):
"""
Test the is_newnym_available and get_newnym_wait methods.
"""
- with await test.runner.get_runner().get_tor_controller() as controller:
+ with test.runner.get_runner().get_tor_controller() as controller:
self.assertEqual(True, controller.is_newnym_available())
self.assertEqual(0.0, controller.get_newnym_wait())
@@ -1023,9 +993,8 @@ class TestController(unittest.TestCase):
@test.require.controller
@test.require.online
- @async_test
- async def test_extendcircuit(self):
- with await test.runner.get_runner().get_tor_controller() as controller:
+ def test_extendcircuit(self):
+ with test.runner.get_runner().get_tor_controller() as controller:
circuit_id = controller.extend_circuit('0')
# check if our circuit was created
@@ -1039,15 +1008,14 @@ class TestController(unittest.TestCase):
@test.require.controller
@test.require.online
- @async_test
- async def test_repurpose_circuit(self):
+ def test_repurpose_circuit(self):
"""
Tests Controller.repurpose_circuit with valid and invalid input.
"""
runner = test.runner.get_runner()
- with await runner.get_tor_controller() as controller:
+ with runner.get_tor_controller() as controller:
circ_id = controller.new_circuit()
controller.repurpose_circuit(circ_id, 'CONTROLLER')
circuit = controller.get_circuit(circ_id)
@@ -1062,15 +1030,14 @@ class TestController(unittest.TestCase):
@test.require.controller
@test.require.online
- @async_test
- async def test_close_circuit(self):
+ def test_close_circuit(self):
"""
Tests Controller.close_circuit with valid and invalid input.
"""
runner = test.runner.get_runner()
- with await runner.get_tor_controller() as controller:
+ with runner.get_tor_controller() as controller:
circuit_id = controller.new_circuit()
controller.close_circuit(circuit_id)
circuit_output = controller.get_info('circuit-status')
@@ -1089,8 +1056,7 @@ class TestController(unittest.TestCase):
@test.require.controller
@test.require.online
- @async_test
- async def test_get_streams(self):
+ def test_get_streams(self):
"""
Tests Controller.get_streams().
"""
@@ -1099,7 +1065,7 @@ class TestController(unittest.TestCase):
port = 443
runner = test.runner.get_runner()
- with await runner.get_tor_controller() as controller:
+ with runner.get_tor_controller() as controller:
# we only need one proxy port, so take the first
socks_listener = controller.get_listeners(Listener.SOCKS)[0]
@@ -1115,15 +1081,14 @@ class TestController(unittest.TestCase):
@test.require.controller
@test.require.online
- @async_test
- async def test_close_stream(self):
+ def test_close_stream(self):
"""
Tests Controller.close_stream with valid and invalid input.
"""
runner = test.runner.get_runner()
- with await runner.get_tor_controller() as controller:
+ with runner.get_tor_controller() as controller:
# use the first socks listener
socks_listener = controller.get_listeners(Listener.SOCKS)[0]
@@ -1155,12 +1120,11 @@ class TestController(unittest.TestCase):
@test.require.controller
@test.require.online
- @async_test
- async def test_mapaddress(self):
+ def test_mapaddress(self):
self.skipTest('(https://trac.torproject.org/projects/tor/ticket/25611)')
runner = test.runner.get_runner()
- with await runner.get_tor_controller() as controller:
+ with runner.get_tor_controller() as controller:
controller.map_address({'1.2.1.2': 'ifconfig.me'})
s = None
@@ -1194,11 +1158,10 @@ class TestController(unittest.TestCase):
self.assertTrue(stem.util.connection.is_valid_ipv4_address(stem.util.str_tools._to_unicode(ip_addr)), "'%s' isn't an address" % ip_addr)
@test.require.controller
- @async_test
- async def test_mapaddress_offline(self):
+ def test_mapaddress_offline(self):
runner = test.runner.get_runner()
- with await runner.get_tor_controller() as controller:
+ with runner.get_tor_controller() as controller:
# try mapping one element, ensuring results are as expected
map1 = {'1.2.1.2': 'ifconfig.me'}
@@ -1274,13 +1237,12 @@ class TestController(unittest.TestCase):
@test.require.controller
@test.require.online
- @async_test
- async def test_get_microdescriptor(self):
+ def test_get_microdescriptor(self):
"""
Basic checks for get_microdescriptor().
"""
- with await test.runner.get_runner().get_tor_controller() as controller:
+ with test.runner.get_runner().get_tor_controller() as controller:
# we should balk at invalid content
self.assertRaises(ValueError, controller.get_microdescriptor, '')
self.assertRaises(ValueError, controller.get_microdescriptor, 5)
@@ -1299,8 +1261,7 @@ class TestController(unittest.TestCase):
@test.require.controller
@test.require.online
- @async_test
- async def test_get_microdescriptors(self):
+ def test_get_microdescriptors(self):
"""
Fetches a few descriptors via the get_microdescriptors() method.
"""
@@ -1310,7 +1271,7 @@ class TestController(unittest.TestCase):
if not os.path.exists(runner.get_test_dir('cached-microdescs')):
self.skipTest('(no cached microdescriptors)')
- with await runner.get_tor_controller() as controller:
+ with runner.get_tor_controller() as controller:
count = 0
for desc in controller.get_microdescriptors():
@@ -1322,15 +1283,14 @@ class TestController(unittest.TestCase):
@test.require.controller
@test.require.online
- @async_test
- async def test_get_server_descriptor(self):
+ def test_get_server_descriptor(self):
"""
Basic checks for get_server_descriptor().
"""
runner = test.runner.get_runner()
- with await runner.get_tor_controller() as controller:
+ with runner.get_tor_controller() as controller:
# we should balk at invalid content
self.assertRaises(ValueError, controller.get_server_descriptor, '')
self.assertRaises(ValueError, controller.get_server_descriptor, 5)
@@ -1349,15 +1309,14 @@ class TestController(unittest.TestCase):
@test.require.controller
@test.require.online
- @async_test
- async def test_get_server_descriptors(self):
+ def test_get_server_descriptors(self):
"""
Fetches a few descriptors via the get_server_descriptors() method.
"""
runner = test.runner.get_runner()
- with await runner.get_tor_controller() as controller:
+ with runner.get_tor_controller() as controller:
count = 0
for desc in controller.get_server_descriptors():
@@ -1375,13 +1334,12 @@ class TestController(unittest.TestCase):
@test.require.controller
@test.require.online
- @async_test
- async def test_get_network_status(self):
+ def test_get_network_status(self):
"""
Basic checks for get_network_status().
"""
- with await test.runner.get_runner().get_tor_controller() as controller:
+ with test.runner.get_runner().get_tor_controller() as controller:
# we should balk at invalid content
self.assertRaises(ValueError, controller.get_network_status, '')
self.assertRaises(ValueError, controller.get_network_status, 5)
@@ -1400,15 +1358,14 @@ class TestController(unittest.TestCase):
@test.require.controller
@test.require.online
- @async_test
- async def test_get_network_statuses(self):
+ def test_get_network_statuses(self):
"""
Fetches a few descriptors via the get_network_statuses() method.
"""
runner = test.runner.get_runner()
- with await runner.get_tor_controller() as controller:
+ with runner.get_tor_controller() as controller:
count = 0
for desc in controller.get_network_statuses():
@@ -1424,15 +1381,14 @@ class TestController(unittest.TestCase):
@test.require.controller
@test.require.online
- @async_test
- async def test_get_hidden_service_descriptor(self):
+ def test_get_hidden_service_descriptor(self):
"""
Fetches a few descriptors via the get_hidden_service_descriptor() method.
"""
runner = test.runner.get_runner()
- with await runner.get_tor_controller() as controller:
+ with runner.get_tor_controller() as controller:
# fetch the descriptor for DuckDuckGo
desc = controller.get_hidden_service_descriptor('3g2upl4pq6kufc4m.onion')
@@ -1450,8 +1406,7 @@ class TestController(unittest.TestCase):
@test.require.controller
@test.require.online
- @async_test
- async def test_attachstream(self):
+ def test_attachstream(self):
host = socket.gethostbyname('www.torproject.org')
port = 80
@@ -1461,7 +1416,7 @@ class TestController(unittest.TestCase):
if stream.status == 'NEW' and circuit_id:
controller.attach_stream(stream.id, circuit_id)
- with await test.runner.get_runner().get_tor_controller() as controller:
+ with test.runner.get_runner().get_tor_controller() as controller:
# try 10 times to build a circuit we can connect through
for i in range(10):
controller.add_event_listener(handle_streamcreated, stem.control.EventType.STREAM)
@@ -1491,26 +1446,24 @@ class TestController(unittest.TestCase):
@test.require.controller
@test.require.online
- @async_test
- async def test_get_circuits(self):
+ def test_get_circuits(self):
"""
Fetches circuits via the get_circuits() method.
"""
- with await test.runner.get_runner().get_tor_controller() as controller:
+ with test.runner.get_runner().get_tor_controller() as controller:
new_circ = controller.new_circuit()
circuits = controller.get_circuits()
self.assertTrue(new_circ in [circ.id for circ in circuits])
@test.require.controller
- @async_test
- async def test_transition_to_relay(self):
+ def test_transition_to_relay(self):
"""
Transitions Tor to turn into a relay, then back to a client. This helps to
catch transition issues such as the one cited in :trac:`14901`.
"""
- with await test.runner.get_runner().get_tor_controller() as controller:
+ with test.runner.get_runner().get_tor_controller() as controller:
try:
controller.reset_conf('OrPort', 'DisableNetwork')
self.assertEqual(None, controller.get_conf('OrPort'))
diff --git a/test/integ/manual.py b/test/integ/manual.py
index d285c758..3e721ac6 100644
--- a/test/integ/manual.py
+++ b/test/integ/manual.py
@@ -14,7 +14,6 @@ import test
import test.runner
from stem.manual import Category
-from stem.util.test_tools import async_test
EXPECTED_CATEGORIES = set([
'NAME',
@@ -217,15 +216,14 @@ class TestManual(unittest.TestCase):
self.assertEqual(['tor - The second-generation onion router'], categories['NAME'])
self.assertEqual(['tor [OPTION value]...'], categories['SYNOPSIS'])
- @async_test
- async def test_has_all_tor_config_options(self):
+ def test_has_all_tor_config_options(self):
"""
Check that all the configuration options tor supports are in the man page.
"""
self.requires_downloaded_manual()
- with await test.runner.get_runner().get_tor_controller() as controller:
+ with test.runner.get_runner().get_tor_controller() as controller:
config_options_in_tor = set([line.split()[0] for line in controller.get_info('config/names').splitlines() if line.split()[1] != 'Virtual'])
# options starting with an underscore are hidden by convention
diff --git a/test/runner.py b/test/runner.py
index 189a2d7b..4f237552 100644
--- a/test/runner.py
+++ b/test/runner.py
@@ -477,7 +477,7 @@ class Runner(object):
def _authenticate_controller(self, controller):
controller.authenticate(password=CONTROL_PASSWORD, chroot_path=self.get_chroot())
- async def get_tor_controller(self, authenticate = True):
+ def get_tor_controller(self, authenticate = True):
"""
Provides a controller connected to our tor test instance.
1
0
commit ef06783b96ee38230b41cc08c14032d7127f867b
Author: Damian Johnson <atagar(a)torproject.org>
Date: Tue Jul 7 16:03:11 2020 -0700
Investigate an async ainit method
Our __ainit__ method should be asynchronous, but I guess that's not to be.
Documenting the issues we encountered.
---
stem/util/__init__.py | 52 +++++++++++++++++++++++++++++++++++++++++++--------
1 file changed, 44 insertions(+), 8 deletions(-)
diff --git a/stem/util/__init__.py b/stem/util/__init__.py
index 5a8f95e0..cddce755 100644
--- a/stem/util/__init__.py
+++ b/stem/util/__init__.py
@@ -199,14 +199,11 @@ class Synchronous(object):
"""
def __init__(self) -> None:
- ainit_func = getattr(self, '__ainit__', None)
-
if Synchronous.is_asyncio_context():
self._loop = asyncio.get_running_loop()
self._loop_thread = None
- if ainit_func:
- ainit_func()
+ self.__ainit__()
else:
self._loop = asyncio.new_event_loop()
self._loop_thread = threading.Thread(
@@ -231,11 +228,50 @@ class Synchronous(object):
if inspect.iscoroutinefunction(func):
setattr(self, method_name, functools.partial(call_async, func))
- if ainit_func:
- async def call_ainit():
- ainit_func()
+ asyncio.run_coroutine_threadsafe(asyncio.coroutine(self.__ainit__)(), self._loop).result()
+
+ def __ainit__(self):
+ """
+ Implicitly called during construction. This method is assured to have an
+ asyncio loop during its execution.
+ """
- asyncio.run_coroutine_threadsafe(call_ainit(), self._loop).result()
+ # This method should be async (so 'await' works), but apparently that
+ # is not possible.
+ #
+ # When our object is constructed our __init__() can be called from a
+ # synchronous or asynchronous context. If synchronous, it's trivial to
+ # run an asynchronous variant of this method because we fully control
+ # the execution of our loop...
+ #
+ # asyncio.run_coroutine_threadsafe(self.__ainit__(), self._loop).result()
+ #
+ # However, when constructed from an asynchronous context the above will
+ # likely hang because our loop is already processing a task (namely,
+ # whatever is constructing us). While we can schedule tasks, we cannot
+ # invoke it during our construction.
+ #
+ # Finally, when this method is simple we could directly invoke it...
+ #
+ # class Synchronous(object):
+ # def __init__(self):
+ # if Synchronous.is_asyncio_context():
+ # try:
+ # self.__ainit__().send(None)
+ # except StopIteration:
+ # pass
+ # else:
+ # asyncio.run_coroutine_threadsafe(self.__ainit__(), self._loop).result()
+ #
+ # async def __ainit__(self):
+ # # asynchronous construction
+ #
+ # However, this breaks if any 'await' suspends our execution. For more
+ # information see...
+ #
+ # https://stackoverflow.com/questions/52783605/how-to-run-a-coroutine-outside…
+
+ pass
def close(self) -> None:
"""
1
0
commit 69be99b4aaa0ebdb038266103a7c9e748e38ef3b
Author: Damian Johnson <atagar(a)torproject.org>
Date: Wed Jun 24 17:47:23 2020 -0700
Use Synchronous for Query
Time to use our mixin in practice. Good news is that it works and *greatly*
deduplicates our code, but it's not all sunshine and ponies...
* Query users now must call close(). This is a significant hassle in terms of
usability, and must be fixed prior to release. However, it'll require some
API adjustments.
* Mypy's type checks assume that Synchronous users are calling Coroutines,
causing false positives when objects are used in a synchronous fashion.
---
stem/descriptor/remote.py | 351 ++++++++---------------------------------
stem/util/__init__.py | 28 +++-
test/unit/descriptor/remote.py | 61 ++++---
3 files changed, 134 insertions(+), 306 deletions(-)
diff --git a/stem/descriptor/remote.py b/stem/descriptor/remote.py
index f1ce79db..ad6a02ab 100644
--- a/stem/descriptor/remote.py
+++ b/stem/descriptor/remote.py
@@ -92,7 +92,6 @@ import time
import stem
import stem.client
-import stem.control
import stem.descriptor
import stem.descriptor.networkstatus
import stem.directory
@@ -100,8 +99,8 @@ import stem.util.enum
import stem.util.tor_tools
from stem.descriptor import Compression
-from stem.util import log, str_tools
-from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Sequence, Tuple, Union
+from stem.util import Synchronous, log, str_tools
+from typing import Any, AsyncIterator, Dict, List, Optional, Sequence, Tuple, Union
# Tor has a limited number of descriptors we can fetch explicitly by their
# fingerprint or hashes due to a limit on the url length by squid proxies.
@@ -227,7 +226,7 @@ def get_detached_signatures(**query_args: Any) -> 'stem.descriptor.remote.Query'
return get_instance().get_detached_signatures(**query_args)
-class AsyncQuery(object):
+class Query(Synchronous):
"""
Asynchronous request for descriptor content from a directory authority or
mirror. These can either be made through the
@@ -235,18 +234,18 @@ class AsyncQuery(object):
advanced usage.
To block on the response and get results either call
- :func:`~stem.descriptor.remote.AsyncQuery.run` or iterate over the Query. The
- :func:`~stem.descriptor.remote.AsyncQuery.run` method pass along any errors
- that arise...
+ :func:`~stem.descriptor.remote.Query.run` or iterate over the Query. The
+ :func:`~stem.descriptor.remote.Query.run` method pass along any errors that
+ arise...
::
- from stem.descriptor.remote import AsyncQuery
+ from stem.descriptor.remote import Query
print('Current relays:')
try:
- for desc in await AsyncQuery('/tor/server/all', 'server-descriptor 1.0').run():
+ for desc in await Query('/tor/server/all', 'server-descriptor 1.0').run():
print(desc.fingerprint)
except Exception as exc:
print('Unable to retrieve the server descriptors: %s' % exc)
@@ -257,7 +256,7 @@ class AsyncQuery(object):
print('Current relays:')
- async for desc in AsyncQuery('/tor/server/all', 'server-descriptor 1.0'):
+ async for desc in Query('/tor/server/all', 'server-descriptor 1.0'):
print(desc.fingerprint)
In either case exceptions are available via our 'error' attribute.
@@ -290,6 +289,39 @@ class AsyncQuery(object):
For legacy reasons if our resource has a '.z' suffix then our **compression**
argument is overwritten with Compression.GZIP.
+ .. versionchanged:: 1.7.0
+ Added support for downloading from ORPorts.
+
+ .. versionchanged:: 1.7.0
+ Added the compression argument.
+
+ .. versionchanged:: 1.7.0
+ Added the reply_headers attribute.
+
+ The class this provides changed between Python versions. In python2
+ this was called httplib.HTTPMessage, whereas in python3 the class was
+ renamed to http.client.HTTPMessage.
+
+ .. versionchanged:: 1.7.0
+ Avoid downloading from tor26. This directory authority throttles its
+ DirPort to such an extent that requests either time out or take on the
+ order of minutes.
+
+ .. versionchanged:: 1.7.0
+ Avoid downloading from Bifroest. This is the bridge authority so it
+ doesn't vote in the consensus, and apparently times out frequently.
+
+ .. versionchanged:: 1.8.0
+ Serge has replaced Bifroest as our bridge authority. Avoiding descriptor
+ downloads from it instead.
+
+ .. versionchanged:: 1.8.0
+ Defaulting to gzip compression rather than plaintext downloads.
+
+ .. versionchanged:: 1.8.0
+ Using :class:`~stem.descriptor.__init__.Compression` for our compression
+ argument.
+
:var str resource: resource being fetched, such as '/tor/server/all'
:var str descriptor_type: type of descriptors being fetched (for options see
:func:`~stem.descriptor.__init__.parse_file`), this is guessed from the
@@ -327,9 +359,15 @@ class AsyncQuery(object):
:var float timeout: duration before we'll time out our request
:var str download_url: last url used to download the descriptor, this is
unset until we've actually made a download attempt
+
+ :param start: start making the request when constructed (default is **True**)
+ :param block: only return after the request has been completed, this is
+ the same as running **query.run(True)** (default is **False**)
"""
- def __init__(self, resource: str, descriptor_type: Optional[str] = None, endpoints: Optional[Sequence[stem.Endpoint]] = None, compression: Union[stem.descriptor._Compression, Sequence[stem.descriptor._Compression]] = (Compression.GZIP,), retries: int = 2, fall_back_to_authority: bool = False, timeout: Optional[float] = None, validate: bool = False, document_handler: stem.descriptor.DocumentHandler = stem.descriptor.DocumentHandler.ENTRIES, **kwargs: Any) -> None:
+ def __init__(self, resource: str, descriptor_type: Optional[str] = None, endpoints: Optional[Sequence[stem.Endpoint]] = None, compression: Union[stem.descriptor._Compression, Sequence[stem.descriptor._Compression]] = (Compression.GZIP,), retries: int = 2, fall_back_to_authority: bool = False, timeout: Optional[float] = None, start: bool = True, block: bool = False, validate: bool = False, document_handler: stem.descriptor.DocumentHandler = stem.descriptor.DocumentHandler.ENTRIES, **kwargs: Any) -> None:
+ super(Query, self).__init__()
+
if not resource.startswith('/'):
raise ValueError("Resources should start with a '/': %s" % resource)
@@ -388,6 +426,12 @@ class AsyncQuery(object):
self._downloader_task = None # type: Optional[asyncio.Task]
self._downloader_lock = threading.RLock()
+ if start:
+ self.start()
+
+ if block:
+ self.run(True)
+
async def start(self) -> None:
"""
Starts downloading the scriptors if we haven't started already.
@@ -398,12 +442,14 @@ class AsyncQuery(object):
loop = asyncio.get_running_loop()
self._downloader_task = loop.create_task(self._download_descriptors(self.retries, self.timeout))
- async def run(self, suppress: bool = False) -> List['stem.descriptor.Descriptor']:
+ async def run(self, suppress: bool = False, close: bool = True) -> List['stem.descriptor.Descriptor']:
"""
Blocks until our request is complete then provides the descriptors. If we
haven't yet started our request then this does so.
:param suppress: avoids raising exceptions if **True**
+ :param close: terminates the resources backing this query if **True**,
+ further method calls will raise a RuntimeError
:returns: list for the requested :class:`~stem.descriptor.__init__.Descriptor` instances
@@ -416,7 +462,15 @@ class AsyncQuery(object):
* :class:`~stem.DownloadFailed` if our request fails
"""
- return [desc async for desc in self._run(suppress)]
+ # TODO: We should replace our 'close' argument with a new API design prior
+ # to release. Self-destructing this object by default for synchronous users
+ # is quite a step backward, but is acceptable as we iterate on this.
+
+ try:
+ return [desc async for desc in self._run(suppress)]
+ finally:
+ if close:
+ self._loop.call_soon_threadsafe(self._loop.stop)
async def _run(self, suppress: bool) -> AsyncIterator[stem.descriptor.Descriptor]:
with self._downloader_lock:
@@ -544,271 +598,6 @@ class AsyncQuery(object):
raise ValueError("BUG: endpoints can only be ORPorts or DirPorts, '%s' was a %s" % (endpoint, type(endpoint).__name__))
-class Query(stem.util.AsyncClassWrapper):
- """
- Asynchronous request for descriptor content from a directory authority or
- mirror. These can either be made through the
- :class:`~stem.descriptor.remote.DescriptorDownloader` or directly for more
- advanced usage.
-
- To block on the response and get results either call
- :func:`~stem.descriptor.remote.Query.run` or iterate over the Query. The
- :func:`~stem.descriptor.remote.Query.run` method pass along any errors that
- arise...
-
- ::
-
- from stem.descriptor.remote import Query
-
- print('Current relays:')
-
- try:
- for desc in Query('/tor/server/all', 'server-descriptor 1.0').run():
- print(desc.fingerprint)
- except Exception as exc:
- print('Unable to retrieve the server descriptors: %s' % exc)
-
- ... while iterating fails silently...
-
- ::
-
- print('Current relays:')
-
- for desc in Query('/tor/server/all', 'server-descriptor 1.0'):
- print(desc.fingerprint)
-
- In either case exceptions are available via our 'error' attribute.
-
- Tor provides quite a few different descriptor resources via its directory
- protocol (see section 4.2 and later of the `dir-spec
- <https://gitweb.torproject.org/torspec.git/tree/dir-spec.txt>`_).
- Commonly useful ones include...
-
- =============================================== ===========
- Resource Description
- =============================================== ===========
- /tor/server/all all present server descriptors
- /tor/server/fp/<fp1>+<fp2>+<fp3> server descriptors with the given fingerprints
- /tor/extra/all all present extrainfo descriptors
- /tor/extra/fp/<fp1>+<fp2>+<fp3> extrainfo descriptors with the given fingerprints
- /tor/micro/d/<hash1>-<hash2> microdescriptors with the given hashes
- /tor/status-vote/current/consensus present consensus
- /tor/status-vote/current/consensus-microdesc present microdescriptor consensus
- /tor/status-vote/next/bandwidth bandwidth authority heuristics for the next consenus
- /tor/status-vote/next/consensus-signatures detached signature, used for making the next consenus
- /tor/keys/all key certificates for the authorities
- /tor/keys/fp/<v3ident1>+<v3ident2> key certificates for specific authorities
- =============================================== ===========
-
- **ZSTD** compression requires `zstandard
- <https://pypi.org/project/zstandard/>`_, and **LZMA** requires the `lzma
- module <https://docs.python.org/3/library/lzma.html>`_.
-
- For legacy reasons if our resource has a '.z' suffix then our **compression**
- argument is overwritten with Compression.GZIP.
-
- .. versionchanged:: 1.7.0
- Added support for downloading from ORPorts.
-
- .. versionchanged:: 1.7.0
- Added the compression argument.
-
- .. versionchanged:: 1.7.0
- Added the reply_headers attribute.
-
- The class this provides changed between Python versions. In python2
- this was called httplib.HTTPMessage, whereas in python3 the class was
- renamed to http.client.HTTPMessage.
-
- .. versionchanged:: 1.7.0
- Avoid downloading from tor26. This directory authority throttles its
- DirPort to such an extent that requests either time out or take on the
- order of minutes.
-
- .. versionchanged:: 1.7.0
- Avoid downloading from Bifroest. This is the bridge authority so it
- doesn't vote in the consensus, and apparently times out frequently.
-
- .. versionchanged:: 1.8.0
- Serge has replaced Bifroest as our bridge authority. Avoiding descriptor
- downloads from it instead.
-
- .. versionchanged:: 1.8.0
- Defaulting to gzip compression rather than plaintext downloads.
-
- .. versionchanged:: 1.8.0
- Using :class:`~stem.descriptor.__init__.Compression` for our compression
- argument.
-
- :var str resource: resource being fetched, such as '/tor/server/all'
- :var str descriptor_type: type of descriptors being fetched (for options see
- :func:`~stem.descriptor.__init__.parse_file`), this is guessed from the
- resource if **None**
-
- :var list endpoints: :class:`~stem.DirPort` or :class:`~stem.ORPort` of the
- authority or mirror we're querying, this uses authorities if undefined
- :var list compression: list of :data:`stem.descriptor.Compression`
- we're willing to accept, when none are mutually supported downloads fall
- back to Compression.PLAINTEXT
- :var int retries: number of times to attempt the request if downloading it
- fails
- :var bool fall_back_to_authority: when retrying request issues the last
- request to a directory authority if **True**
-
- :var str content: downloaded descriptor content
- :var Exception error: exception if a problem occured
- :var bool is_done: flag that indicates if our request has finished
-
- :var float start_time: unix timestamp when we first started running
- :var http.client.HTTPMessage reply_headers: headers provided in the response,
- **None** if we haven't yet made our request
- :var float runtime: time our query took, this is **None** if it's not yet
- finished
-
- :var bool validate: checks the validity of the descriptor's content if
- **True**, skips these checks otherwise
- :var stem.descriptor.__init__.DocumentHandler document_handler: method in
- which to parse a :class:`~stem.descriptor.networkstatus.NetworkStatusDocument`
- :var dict kwargs: additional arguments for the descriptor constructor
-
- Following are only applicable when downloading from a
- :class:`~stem.DirPort`...
-
- :var float timeout: duration before we'll time out our request
- :var str download_url: last url used to download the descriptor, this is
- unset until we've actually made a download attempt
-
- :param start: start making the request when constructed (default is **True**)
- :param block: only return after the request has been completed, this is
- the same as running **query.run(True)** (default is **False**)
- """
-
- def __init__(self, resource: str, descriptor_type: Optional[str] = None, endpoints: Optional[Sequence[stem.Endpoint]] = None, compression: Union[stem.descriptor._Compression, Sequence[stem.descriptor._Compression]] = (Compression.GZIP,), retries: int = 2, fall_back_to_authority: bool = False, timeout: Optional[float] = None, start: bool = True, block: bool = False, validate: bool = False, document_handler: stem.descriptor.DocumentHandler = stem.descriptor.DocumentHandler.ENTRIES, **kwargs: Any) -> None:
- self._loop = asyncio.new_event_loop()
- self._loop_thread = threading.Thread(target = self._loop.run_forever, name = 'query asyncio')
- self._loop_thread.setDaemon(True)
- self._loop_thread.start()
-
- self._wrapped_instance: AsyncQuery = self._init_async_class( # type: ignore
- AsyncQuery,
- resource,
- descriptor_type,
- endpoints,
- compression,
- retries,
- fall_back_to_authority,
- timeout,
- validate,
- document_handler,
- **kwargs,
- )
-
- if start:
- self.start()
-
- if block:
- self.run(True)
-
- def start(self) -> None:
- """
- Starts downloading the scriptors if we haven't started already.
- """
-
- self._execute_async_method('start')
-
- def run(self, suppress = False) -> List['stem.descriptor.Descriptor']:
- """
- Blocks until our request is complete then provides the descriptors. If we
- haven't yet started our request then this does so.
-
- :param suppress: avoids raising exceptions if **True**
-
- :returns: list for the requested :class:`~stem.descriptor.__init__.Descriptor` instances
-
- :raises:
- Using the iterator can fail with the following if **suppress** is
- **False**...
-
- * **ValueError** if the descriptor contents is malformed
- * :class:`~stem.DownloadTimeout` if our request timed out
- * :class:`~stem.DownloadFailed` if our request fails
- """
-
- return self._execute_async_method('run', suppress)
-
- def __iter__(self) -> Iterator[stem.descriptor.Descriptor]:
- for desc in self._execute_async_generator_method('__aiter__'):
- yield desc
-
- @property
- def descriptor_type(self) -> str:
- return self._wrapped_instance.descriptor_type
-
- @property
- def endpoints(self) -> List[Union[stem.ORPort, stem.DirPort]]:
- return self._wrapped_instance.endpoints
-
- @property
- def resource(self) -> str:
- return self._wrapped_instance.resource
-
- @property
- def compression(self) -> List[stem.descriptor._Compression]:
- return self._wrapped_instance.compression
-
- @property
- def retries(self) -> int:
- return self._wrapped_instance.retries
-
- @property
- def fall_back_to_authority(self) -> bool:
- return self._wrapped_instance.fall_back_to_authority
-
- @property
- def content(self) -> Optional[bytes]:
- return self._wrapped_instance.content
-
- @property
- def error(self) -> Optional[BaseException]:
- return self._wrapped_instance.error
-
- @property
- def is_done(self) -> bool:
- return self._wrapped_instance.is_done
-
- @property
- def download_url(self) -> Optional[str]:
- return self._wrapped_instance.download_url
-
- @property
- def start_time(self) -> Optional[float]:
- return self._wrapped_instance.start_time
-
- @property
- def timeout(self) -> Optional[float]:
- return self._wrapped_instance.timeout
-
- @property
- def runtime(self) -> Optional[float]:
- return self._wrapped_instance.runtime
-
- @property
- def validate(self) -> bool:
- return self._wrapped_instance.validate
-
- @property
- def document_handler(self) -> stem.descriptor.DocumentHandler:
- return self._wrapped_instance.document_handler
-
- @property
- def reply_headers(self) -> Optional[Dict[str, str]]:
- return self._wrapped_instance.reply_headers
-
- @property
- def kwargs(self) -> Dict[str, Any]:
- return self._wrapped_instance.kwargs
-
-
class DescriptorDownloader(object):
"""
Configurable class that issues :class:`~stem.descriptor.remote.Query`
@@ -848,7 +637,7 @@ class DescriptorDownloader(object):
directories = [auth for auth in stem.directory.Authority.from_cache().values() if auth.nickname not in DIR_PORT_BLACKLIST]
new_endpoints = set([stem.DirPort(directory.address, directory.dir_port) for directory in directories])
- consensus = list(self.get_consensus(document_handler = stem.descriptor.DocumentHandler.DOCUMENT).run())[0]
+ consensus = list(self.get_consensus(document_handler = stem.descriptor.DocumentHandler.DOCUMENT).run())[0] # type: ignore
for desc in consensus.routers.values():
if stem.Flag.V2DIR in desc.flags and desc.dir_port:
@@ -858,7 +647,7 @@ class DescriptorDownloader(object):
self._endpoints = list(new_endpoints)
- return consensus # type: ignore
+ return consensus
def their_server_descriptor(self, **query_args: Any) -> 'stem.descriptor.remote.Query':
"""
@@ -1013,7 +802,7 @@ class DescriptorDownloader(object):
# authority key certificates
if consensus_query.validate and consensus_query.document_handler == stem.descriptor.DocumentHandler.DOCUMENT:
- consensus = list(consensus_query.run())[0]
+ consensus = list(consensus_query.run())[0] # type: ignore
key_certs = self.get_key_certificates(**query_args).run()
try:
diff --git a/stem/util/__init__.py b/stem/util/__init__.py
index 72239273..e8ef361e 100644
--- a/stem/util/__init__.py
+++ b/stem/util/__init__.py
@@ -12,7 +12,7 @@ import inspect
import threading
from concurrent.futures import Future
-from typing import Any, AsyncIterator, Iterator, Type, Union
+from typing import Any, AsyncIterator, Callable, Iterator, Type, Union
__all__ = [
'conf',
@@ -116,7 +116,7 @@ def _pubkey_bytes(key: Union['cryptography.hazmat.primitives.asymmetric.ed25519.
raise ValueError('Key must be a string or cryptographic public/private key (was %s)' % type(key).__name__)
-def _hash_attr(obj: Any, *attributes: str, **kwargs: Any):
+def _hash_attr(obj: Any, *attributes: str, **kwargs: Any) -> int:
"""
Provide a hash value for the given set of attributes.
@@ -124,6 +124,8 @@ def _hash_attr(obj: Any, *attributes: str, **kwargs: Any):
:param attributes: attribute names to take into account
:param cache: persists hash in a '_cached_hash' object attribute
:param parent: include parent's hash value
+
+ :returns: **int** object hash
"""
is_cached = kwargs.get('cache', False)
@@ -174,11 +176,11 @@ class Synchronous(object):
finished to clean up underlying resources.
"""
- def __init__(self):
+ def __init__(self) -> None:
self._loop = asyncio.new_event_loop()
self._loop_lock = threading.RLock()
self._loop_thread = threading.Thread(
- name = '%s asyncio' % self.__class__.__name__,
+ name = '%s asyncio' % type(self).__name__,
target = self._loop.run_forever,
daemon = True,
)
@@ -188,7 +190,7 @@ class Synchronous(object):
# overwrite asynchronous class methods with instance methods that can be
# called from either context
- def wrap(func, *args, **kwargs):
+ def wrap(func: Callable, *args: Any, **kwargs: Any) -> Any:
if Synchronous.is_asyncio_context():
return func(*args, **kwargs)
else:
@@ -204,7 +206,7 @@ class Synchronous(object):
if inspect.iscoroutinefunction(func):
setattr(self, method_name, functools.partial(wrap, func))
- def close(self):
+ def close(self) -> None:
"""
Terminate resources that permits this from being callable from synchronous
contexts. Once called any further synchronous invocations will fail with a
@@ -219,7 +221,7 @@ class Synchronous(object):
self._is_closed = True
@staticmethod
- def is_asyncio_context():
+ def is_asyncio_context() -> bool:
"""
Check if running within a synchronous or asynchronous context.
@@ -232,6 +234,18 @@ class Synchronous(object):
except RuntimeError:
return False
+ def __iter__(self) -> Iterator:
+ async def convert_async_generator(generator: AsyncIterator) -> Iterator:
+ return iter([d async for d in generator])
+
+ iter_func = getattr(self, '__aiter__')
+
+ if iter_func:
+ with self._loop_lock:
+ return asyncio.run_coroutine_threadsafe(convert_async_generator(iter_func()), self._loop).result()
+ else:
+ raise TypeError("'%s' object is not iterable" % type(self).__name__)
+
class AsyncClassWrapper:
_loop: asyncio.AbstractEventLoop
diff --git a/test/unit/descriptor/remote.py b/test/unit/descriptor/remote.py
index 797bc8a3..1fd2aaf9 100644
--- a/test/unit/descriptor/remote.py
+++ b/test/unit/descriptor/remote.py
@@ -70,7 +70,7 @@ def mock_download(descriptor, encoding = 'identity', response_code_header = None
data = response_code_header + stem.util.str_tools._to_bytes(HEADER % encoding) + b'\r\n\r\n' + descriptor
- return patch('stem.descriptor.remote.AsyncQuery._download_from', Mock(side_effect = coro_func_returning_value(data)))
+ return patch('stem.descriptor.remote.Query._download_from', Mock(side_effect = coro_func_returning_value(data)))
class TestDescriptorDownloader(unittest.TestCase):
@@ -100,6 +100,8 @@ class TestDescriptorDownloader(unittest.TestCase):
self.assertEqual('9695DFC35FFEB861329B9F1AB04C46397020CE31', desc.fingerprint)
self.assertEqual(TEST_DESCRIPTOR, desc.get_bytes())
+ reply.close()
+
def test_response_header_code(self):
"""
When successful Tor provides a '200 OK' status, but we should accept other 2xx
@@ -133,7 +135,7 @@ class TestDescriptorDownloader(unittest.TestCase):
def test_reply_header_data(self):
query = stem.descriptor.remote.get_server_descriptors('9695DFC35FFEB861329B9F1AB04C46397020CE31', start = False)
self.assertEqual(None, query.reply_headers) # initially we don't have a reply
- query.run()
+ query.run(close = False)
self.assertEqual('Fri, 13 Apr 2018 16:35:50 GMT', query.reply_headers.get('Date'))
self.assertEqual('application/octet-stream', query.reply_headers.get('Content-Type'))
@@ -148,11 +150,13 @@ class TestDescriptorDownloader(unittest.TestCase):
descriptors = list(query)
self.assertEqual(1, len(descriptors))
self.assertEqual('moria1', descriptors[0].nickname)
+ query.close()
def test_gzip_url_override(self):
query = stem.descriptor.remote.Query(TEST_RESOURCE + '.z', compression = Compression.PLAINTEXT, start = False)
self.assertEqual([stem.descriptor.Compression.GZIP], query.compression)
self.assertEqual(TEST_RESOURCE, query.resource)
+ query.close()
@mock_download(read_resource('compressed_identity'), encoding = 'identity')
def test_compression_plaintext(self):
@@ -160,12 +164,15 @@ class TestDescriptorDownloader(unittest.TestCase):
Download a plaintext descriptor.
"""
- descriptors = list(stem.descriptor.remote.get_server_descriptors(
+ query = stem.descriptor.remote.get_server_descriptors(
'9695DFC35FFEB861329B9F1AB04C46397020CE31',
compression = Compression.PLAINTEXT,
validate = True,
skip_crypto_validation = not test.require.CRYPTOGRAPHY_AVAILABLE,
- ))
+ )
+
+ descriptors = list(query)
+ query.close()
self.assertEqual(1, len(descriptors))
self.assertEqual('moria1', descriptors[0].nickname)
@@ -176,12 +183,15 @@ class TestDescriptorDownloader(unittest.TestCase):
Download a gip compressed descriptor.
"""
- descriptors = list(stem.descriptor.remote.get_server_descriptors(
+ query = stem.descriptor.remote.get_server_descriptors(
'9695DFC35FFEB861329B9F1AB04C46397020CE31',
compression = Compression.GZIP,
validate = True,
skip_crypto_validation = not test.require.CRYPTOGRAPHY_AVAILABLE,
- ))
+ )
+
+ descriptors = list(query)
+ query.close()
self.assertEqual(1, len(descriptors))
self.assertEqual('moria1', descriptors[0].nickname)
@@ -195,11 +205,14 @@ class TestDescriptorDownloader(unittest.TestCase):
if not Compression.ZSTD.available:
self.skipTest('(requires zstd module)')
- descriptors = list(stem.descriptor.remote.get_server_descriptors(
+ query = stem.descriptor.remote.get_server_descriptors(
'9695DFC35FFEB861329B9F1AB04C46397020CE31',
compression = Compression.ZSTD,
validate = True,
- ))
+ )
+
+ descriptors = list(query)
+ query.close()
self.assertEqual(1, len(descriptors))
self.assertEqual('moria1', descriptors[0].nickname)
@@ -213,11 +226,14 @@ class TestDescriptorDownloader(unittest.TestCase):
if not Compression.LZMA.available:
self.skipTest('(requires lzma module)')
- descriptors = list(stem.descriptor.remote.get_server_descriptors(
+ query = stem.descriptor.remote.get_server_descriptors(
'9695DFC35FFEB861329B9F1AB04C46397020CE31',
compression = Compression.LZMA,
validate = True,
- ))
+ )
+
+ descriptors = list(query)
+ query.close()
self.assertEqual(1, len(descriptors))
self.assertEqual('moria1', descriptors[0].nickname)
@@ -228,16 +244,21 @@ class TestDescriptorDownloader(unittest.TestCase):
Surface level exercising of each getter method for downloading descriptors.
"""
+ queries = []
+
downloader = stem.descriptor.remote.get_instance()
- downloader.get_server_descriptors()
- downloader.get_extrainfo_descriptors()
- downloader.get_microdescriptors('test-hash')
- downloader.get_consensus()
- downloader.get_vote(stem.directory.Authority.from_cache()['moria1'])
- downloader.get_key_certificates()
- downloader.get_bandwidth_file()
- downloader.get_detached_signatures()
+ queries.append(downloader.get_server_descriptors())
+ queries.append(downloader.get_extrainfo_descriptors())
+ queries.append(downloader.get_microdescriptors('test-hash'))
+ queries.append(downloader.get_consensus())
+ queries.append(downloader.get_vote(stem.directory.Authority.from_cache()['moria1']))
+ queries.append(downloader.get_key_certificates())
+ queries.append(downloader.get_bandwidth_file())
+ queries.append(downloader.get_detached_signatures())
+
+ for query in queries:
+ query.close()
@mock_download(b'some malformed stuff')
def test_malformed_content(self):
@@ -264,6 +285,8 @@ class TestDescriptorDownloader(unittest.TestCase):
self.assertRaises(ValueError, query.run)
+ query.close()
+
def test_query_with_invalid_endpoints(self):
invalid_endpoints = {
'hello': "'h' is a str.",
@@ -292,3 +315,5 @@ class TestDescriptorDownloader(unittest.TestCase):
self.assertEqual(1, len(list(query)))
self.assertEqual(1, len(list(query)))
self.assertEqual(1, len(list(query)))
+
+ query.close()
1
0