commit 675f49fcbe6dc1a52d10215c07adff56001faa70 Author: Damian Johnson <atagar@torproject.org> Date: Sat May 30 17:21:57 2020 -0700
Drop ThreadForWrappedAsyncClass
To call an asynchronous function we need an event loop, and a thread for that loop to run within. For example:
def call(my_async_function):
  """
  Runs a coroutine to completion on a dedicated event loop thread and provides
  its result. The loop and its thread are torn down before returning, so no
  'lingering thread' is left behind.

  :param my_async_function: coroutine object to run (for instance
    **my_func()** rather than **my_func**)

  :returns: value returned by the coroutine

  :raises: any exception raised by the coroutine
  """

  # a fresh loop, rather than asyncio.get_event_loop(), so we never hijack a
  # loop that belongs to the calling thread

  loop = asyncio.new_event_loop()
  loop_thread = threading.Thread(target = loop.run_forever, name = 'asyncio', daemon = True)
  loop_thread.start()

  try:
    # wait for the coroutine to finish *before* stopping the loop, otherwise
    # the stop request can be processed before the coroutine ever runs

    return asyncio.run_coroutine_threadsafe(my_async_function, loop).result()
  finally:
    # clean up even when the coroutine raises

    loop.call_soon_threadsafe(loop.stop)
    loop_thread.join()
    loop.close()
ThreadForWrappedAsyncClass bundled these together, but I found it more confusing than helpful. These threads failed to clean themselves up, causing 'lingering thread' notifications when we run our tests. --- stem/connection.py | 30 ++++++++++++++++++------------ stem/control.py | 21 ++++++++++----------- stem/descriptor/remote.py | 7 +++++-- stem/util/__init__.py | 33 +++++++++++---------------------- test/integ/connection/authentication.py | 2 +- test/runner.py | 18 +++++++++--------- test/unit/control/controller.py | 2 +- 7 files changed, 55 insertions(+), 58 deletions(-)
diff --git a/stem/connection.py b/stem/connection.py index 213ba010..86d32d7f 100644 --- a/stem/connection.py +++ b/stem/connection.py @@ -159,7 +159,7 @@ import stem.util.str_tools import stem.util.system import stem.version
-from typing import Any, cast, List, Optional, Sequence, Tuple, Type, Union +from typing import Any, List, Optional, Sequence, Tuple, Type, Union from stem.util import log
AuthMethod = stem.util.enum.Enum('NONE', 'PASSWORD', 'COOKIE', 'SAFECOOKIE', 'UNKNOWN') @@ -271,18 +271,24 @@ def connect(control_port: Tuple[str, Union[str, int]] = ('127.0.0.1', 'default') if controller is None or not issubclass(controller, stem.control.Controller): raise ValueError('Controller should be a stem.control.Controller subclass.')
- async_controller_thread = stem.util.ThreadForWrappedAsyncClass() - async_controller_thread.start() + loop = asyncio.new_event_loop() + loop_thread = threading.Thread(target = loop.run_forever, name = 'asyncio') + loop_thread.setDaemon(True) + loop_thread.start()
- connect_coroutine = _connect_async(control_port, control_socket, password, password_prompt, chroot_path, controller) try: - connection = asyncio.run_coroutine_threadsafe(connect_coroutine, async_controller_thread.loop).result() - if connection is None and async_controller_thread.is_alive(): - async_controller_thread.join() + connection = asyncio.run_coroutine_threadsafe(_connect_async(control_port, control_socket, password, password_prompt, chroot_path, controller), loop).result() + + if connection is None and loop_thread.is_alive(): + loop.call_soon_threadsafe(loop.stop) + loop_thread.join() + return connection except: - if async_controller_thread.is_alive(): - async_controller_thread.join() + if loop_thread.is_alive(): + loop.call_soon_threadsafe(loop.stop) + loop_thread.join() + raise
@@ -399,10 +405,10 @@ async def _connect_auth(control_socket: stem.socket.ControlSocket, password: str
if controller is None: return control_socket - elif issubclass(controller, stem.control.BaseController): + elif issubclass(controller, stem.control.BaseController) or issubclass(controller, stem.control.Controller): + # TODO: Controller no longer extends BaseController (we'll probably change that) + return controller(control_socket, is_authenticated = True) - elif issubclass(controller, stem.control.Controller): - return controller(control_socket, is_authenticated = True, started_async_controller_thread = cast(stem.util.ThreadForWrappedAsyncClass, threading.current_thread())) except IncorrectSocketType: if isinstance(control_socket, stem.socket.ControlPort): print(CONNECT_MESSAGES['wrong_port_type'].format(port = control_socket.port)) diff --git a/stem/control.py b/stem/control.py index 084976ad..47ddaa35 100644 --- a/stem/control.py +++ b/stem/control.py @@ -3941,13 +3941,17 @@ class Controller(_BaseControllerSocketMixin, stem.util.AsyncClassWrapper): self, control_socket: stem.socket.ControlSocket, is_authenticated: bool = False, - started_async_controller_thread: stem.util.ThreadForWrappedAsyncClass = None, ) -> None: - if started_async_controller_thread: - self._thread_for_wrapped_class = started_async_controller_thread - else: - self._thread_for_wrapped_class = stem.util.ThreadForWrappedAsyncClass() - self._thread_for_wrapped_class.start() + # if within an asyncio context use its loop, otherwise spawn our own + + try: + self._loop = asyncio.get_running_loop() + self._loop_thread = threading.current_thread() + except RuntimeError: + self._loop = asyncio.new_event_loop() + self._loop_thread = threading.Thread(target = self._loop.run_forever, name = 'asyncio') + self._loop_thread.setDaemon(True) + self._loop_thread.start()
self._wrapped_instance: AsyncController = self._init_async_class(AsyncController, control_socket, is_authenticated) # type: ignore self._socket = self._wrapped_instance._socket @@ -4212,11 +4216,6 @@ class Controller(_BaseControllerSocketMixin, stem.util.AsyncClassWrapper): def drop_guards(self) -> None: self._execute_async_method('drop_guards')
- def __del__(self) -> None: - loop = self._thread_for_wrapped_class.loop - if loop.is_running(): - loop.call_soon_threadsafe(loop.stop) - def __enter__(self) -> 'stem.control.Controller': return self
diff --git a/stem/descriptor/remote.py b/stem/descriptor/remote.py index e7ccaa24..c23ab7a9 100644 --- a/stem/descriptor/remote.py +++ b/stem/descriptor/remote.py @@ -663,8 +663,11 @@ class Query(stem.util.AsyncClassWrapper): """
def __init__(self, resource: str, descriptor_type: Optional[str] = None, endpoints: Optional[Sequence[stem.Endpoint]] = None, compression: Union[stem.descriptor._Compression, Sequence[stem.descriptor._Compression]] = (Compression.GZIP,), retries: int = 2, fall_back_to_authority: bool = False, timeout: Optional[float] = None, start: bool = True, block: bool = False, validate: bool = False, document_handler: stem.descriptor.DocumentHandler = stem.descriptor.DocumentHandler.ENTRIES, **kwargs: Any) -> None: - self._thread_for_wrapped_class = stem.util.ThreadForWrappedAsyncClass() - self._thread_for_wrapped_class.start() + self._loop = asyncio.get_event_loop() + self._loop_thread = threading.Thread(target = self._loop.run_forever, name = 'asyncio') + self._loop_thread.setDaemon(True) + self._loop_thread.start() + self._wrapped_instance: AsyncQuery = self._init_async_class( # type: ignore AsyncQuery, resource, diff --git a/stem/util/__init__.py b/stem/util/__init__.py index a90aa7ac..25282b99 100644 --- a/stem/util/__init__.py +++ b/stem/util/__init__.py @@ -10,7 +10,7 @@ import datetime import threading from concurrent.futures import Future
-from typing import Any, AsyncIterator, Iterator, Optional, Type, Union +from typing import Any, AsyncIterator, Iterator, Type, Union
__all__ = [ 'conf', @@ -144,41 +144,26 @@ def _hash_attr(obj: Any, *attributes: str, **kwargs: Any): return my_hash
-class ThreadForWrappedAsyncClass(threading.Thread): - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, *kwargs) - self.loop = asyncio.new_event_loop() - self.setDaemon(True) - - def run(self) -> None: - self.loop.run_forever() - - def join(self, timeout: Optional[float] = None) -> None: - self.loop.call_soon_threadsafe(self.loop.stop) - super().join(timeout) - self.loop.close() - - class AsyncClassWrapper: - _thread_for_wrapped_class: ThreadForWrappedAsyncClass + _loop: asyncio.AbstractEventLoop + _loop_thread: threading.Thread _wrapped_instance: type
def _init_async_class(self, async_class: Type, *args: Any, **kwargs: Any) -> Any: - thread = self._thread_for_wrapped_class # The asynchronous class should be initialized in the thread where # its methods will be executed. - if thread != threading.current_thread(): + if self._loop_thread != threading.current_thread(): async def init(): return async_class(*args, **kwargs)
- return asyncio.run_coroutine_threadsafe(init(), thread.loop).result() + return asyncio.run_coroutine_threadsafe(init(), self._loop).result()
return async_class(*args, **kwargs)
def _call_async_method_soon(self, method_name: str, *args: Any, **kwargs: Any) -> Future: return asyncio.run_coroutine_threadsafe( getattr(self._wrapped_instance, method_name)(*args, **kwargs), - self._thread_for_wrapped_class.loop, + self._loop, )
def _execute_async_method(self, method_name: str, *args: Any, **kwargs: Any) -> Any: @@ -192,5 +177,9 @@ class AsyncClassWrapper: convert_async_generator( getattr(self._wrapped_instance, method_name)(*args, **kwargs), ), - self._thread_for_wrapped_class.loop, + self._loop, ).result() + + def __del__(self) -> None: + self._loop.call_soon_threadsafe(self._loop.stop) + self._loop_thread.join() diff --git a/test/integ/connection/authentication.py b/test/integ/connection/authentication.py index 3709c275..042d3939 100644 --- a/test/integ/connection/authentication.py +++ b/test/integ/connection/authentication.py @@ -124,7 +124,7 @@ class TestAuthenticate(unittest.TestCase): with runner.get_tor_controller(False) as controller: asyncio.run_coroutine_threadsafe( stem.connection.authenticate(controller._wrapped_instance, test.runner.CONTROL_PASSWORD, runner.get_chroot()), - controller._thread_for_wrapped_class.loop, + controller._loop, ).result() await test.runner.exercise_controller(self, controller)
diff --git a/test/runner.py b/test/runner.py index 4f237552..b132b8f5 100644 --- a/test/runner.py +++ b/test/runner.py @@ -488,16 +488,16 @@ class Runner(object): :raises: :class: `test.runner.TorInaccessable` if tor can't be connected to """
- async_controller_thread = stem.util.ThreadForWrappedAsyncClass() - async_controller_thread.start() + loop = asyncio.new_event_loop() + loop_thread = threading.Thread(target = loop.run_forever, name = 'get_tor_controller') + loop_thread.setDaemon(True) + loop_thread.start()
- try: - control_socket = asyncio.run_coroutine_threadsafe(self.get_tor_socket(False), async_controller_thread.loop).result() - controller = stem.control.Controller(control_socket, started_async_controller_thread = async_controller_thread) - except Exception: - if async_controller_thread.is_alive(): - async_controller_thread.join() - raise + async def wrapped_get_controller(): + control_socket = await self.get_tor_socket(False) + return stem.control.Controller(control_socket) + + controller = asyncio.run_coroutine_threadsafe(wrapped_get_controller(), loop).result()
if authenticate: self._authenticate_controller(controller) diff --git a/test/unit/control/controller.py b/test/unit/control/controller.py index 850918f7..84fcdfed 100644 --- a/test/unit/control/controller.py +++ b/test/unit/control/controller.py @@ -751,7 +751,7 @@ class TestControl(unittest.TestCase): with patch('time.time', Mock(return_value = TEST_TIMESTAMP)): with patch('stem.control.AsyncController.is_alive') as is_alive_mock: is_alive_mock.return_value = True - loop = self.controller._thread_for_wrapped_class.loop + loop = self.controller._loop asyncio.run_coroutine_threadsafe(self.async_controller._event_loop(), loop)
try:
tor-commits@lists.torproject.org