commit 675f49fcbe6dc1a52d10215c07adff56001faa70
Author: Damian Johnson <atagar(a)torproject.org>
Date: Sat May 30 17:21:57 2020 -0700
Drop ThreadForWrappedAsyncClass
To call an asynchronous function from synchronous code we require an event
loop and a thread for that loop to run within...
def call(my_async_function):
  """
  Runs a coroutine to completion from synchronous code.

  A dedicated event loop is run within a daemon thread for the duration of
  this call, then stopped, joined, and closed.

  :param my_async_function: coroutine object to be run

  :returns: value the coroutine returned

  :raises: any exception the coroutine raised
  """

  # Use a fresh loop rather than asyncio.get_event_loop(). The latter depends
  # on the calling thread already having a usable loop.

  loop = asyncio.new_event_loop()
  loop_thread = threading.Thread(target = loop.run_forever, name = 'asyncio', daemon = True)
  loop_thread.start()

  try:
    # Wait for the coroutine *before* stopping the loop. If the stop were
    # scheduled first the loop could exit before our coroutine is stepped,
    # leaving its future forever pending.

    return asyncio.run_coroutine_threadsafe(my_async_function, loop).result()
  finally:
    loop.call_soon_threadsafe(loop.stop)
    loop_thread.join()
    loop.close()
ThreadForWrappedAsyncClass bundled these together, but I found it more
confusing than helpful. These threads failed to clean themselves up,
causing 'lingering thread' notifications when we run our tests.
---
stem/connection.py | 30 ++++++++++++++++++------------
stem/control.py | 21 ++++++++++-----------
stem/descriptor/remote.py | 7 +++++--
stem/util/__init__.py | 33 +++++++++++----------------------
test/integ/connection/authentication.py | 2 +-
test/runner.py | 18 +++++++++---------
test/unit/control/controller.py | 2 +-
7 files changed, 55 insertions(+), 58 deletions(-)
diff --git a/stem/connection.py b/stem/connection.py
index 213ba010..86d32d7f 100644
--- a/stem/connection.py
+++ b/stem/connection.py
@@ -159,7 +159,7 @@ import stem.util.str_tools
import stem.util.system
import stem.version
-from typing import Any, cast, List, Optional, Sequence, Tuple, Type, Union
+from typing import Any, List, Optional, Sequence, Tuple, Type, Union
from stem.util import log
AuthMethod = stem.util.enum.Enum('NONE', 'PASSWORD', 'COOKIE', 'SAFECOOKIE', 'UNKNOWN')
@@ -271,18 +271,24 @@ def connect(control_port: Tuple[str, Union[str, int]] = ('127.0.0.1', 'default')
if controller is None or not issubclass(controller, stem.control.Controller):
raise ValueError('Controller should be a stem.control.Controller subclass.')
- async_controller_thread = stem.util.ThreadForWrappedAsyncClass()
- async_controller_thread.start()
+ loop = asyncio.new_event_loop()
+ loop_thread = threading.Thread(target = loop.run_forever, name = 'asyncio')
+ loop_thread.setDaemon(True)
+ loop_thread.start()
- connect_coroutine = _connect_async(control_port, control_socket, password, password_prompt, chroot_path, controller)
try:
- connection = asyncio.run_coroutine_threadsafe(connect_coroutine, async_controller_thread.loop).result()
- if connection is None and async_controller_thread.is_alive():
- async_controller_thread.join()
+ connection = asyncio.run_coroutine_threadsafe(_connect_async(control_port, control_socket, password, password_prompt, chroot_path, controller), loop).result()
+
+ if connection is None and loop_thread.is_alive():
+ loop.call_soon_threadsafe(loop.stop)
+ loop_thread.join()
+
return connection
except:
- if async_controller_thread.is_alive():
- async_controller_thread.join()
+ if loop_thread.is_alive():
+ loop.call_soon_threadsafe(loop.stop)
+ loop_thread.join()
+
raise
@@ -399,10 +405,10 @@ async def _connect_auth(control_socket: stem.socket.ControlSocket, password: str
if controller is None:
return control_socket
- elif issubclass(controller, stem.control.BaseController):
+ elif issubclass(controller, stem.control.BaseController) or issubclass(controller, stem.control.Controller):
+ # TODO: Controller no longer extends BaseController (we'll probably change that)
+
return controller(control_socket, is_authenticated = True)
- elif issubclass(controller, stem.control.Controller):
- return controller(control_socket, is_authenticated = True, started_async_controller_thread = cast(stem.util.ThreadForWrappedAsyncClass, threading.current_thread()))
except IncorrectSocketType:
if isinstance(control_socket, stem.socket.ControlPort):
print(CONNECT_MESSAGES['wrong_port_type'].format(port = control_socket.port))
diff --git a/stem/control.py b/stem/control.py
index 084976ad..47ddaa35 100644
--- a/stem/control.py
+++ b/stem/control.py
@@ -3941,13 +3941,17 @@ class Controller(_BaseControllerSocketMixin, stem.util.AsyncClassWrapper):
self,
control_socket: stem.socket.ControlSocket,
is_authenticated: bool = False,
- started_async_controller_thread: stem.util.ThreadForWrappedAsyncClass = None,
) -> None:
- if started_async_controller_thread:
- self._thread_for_wrapped_class = started_async_controller_thread
- else:
- self._thread_for_wrapped_class = stem.util.ThreadForWrappedAsyncClass()
- self._thread_for_wrapped_class.start()
+ # if within an asyncio context use its loop, otherwise spawn our own
+
+ try:
+ self._loop = asyncio.get_running_loop()
+ self._loop_thread = threading.current_thread()
+ except RuntimeError:
+ self._loop = asyncio.new_event_loop()
+ self._loop_thread = threading.Thread(target = self._loop.run_forever, name = 'asyncio')
+ self._loop_thread.setDaemon(True)
+ self._loop_thread.start()
self._wrapped_instance: AsyncController = self._init_async_class(AsyncController, control_socket, is_authenticated) # type: ignore
self._socket = self._wrapped_instance._socket
@@ -4212,11 +4216,6 @@ class Controller(_BaseControllerSocketMixin, stem.util.AsyncClassWrapper):
def drop_guards(self) -> None:
self._execute_async_method('drop_guards')
- def __del__(self) -> None:
- loop = self._thread_for_wrapped_class.loop
- if loop.is_running():
- loop.call_soon_threadsafe(loop.stop)
-
def __enter__(self) -> 'stem.control.Controller':
return self
diff --git a/stem/descriptor/remote.py b/stem/descriptor/remote.py
index e7ccaa24..c23ab7a9 100644
--- a/stem/descriptor/remote.py
+++ b/stem/descriptor/remote.py
@@ -663,8 +663,11 @@ class Query(stem.util.AsyncClassWrapper):
"""
def __init__(self, resource: str, descriptor_type: Optional[str] = None, endpoints: Optional[Sequence[stem.Endpoint]] = None, compression: Union[stem.descriptor._Compression, Sequence[stem.descriptor._Compression]] = (Compression.GZIP,), retries: int = 2, fall_back_to_authority: bool = False, timeout: Optional[float] = None, start: bool = True, block: bool = False, validate: bool = False, document_handler: stem.descriptor.DocumentHandler = stem.descriptor.DocumentHandler.ENTRIES, **kwargs: Any) -> None:
- self._thread_for_wrapped_class = stem.util.ThreadForWrappedAsyncClass()
- self._thread_for_wrapped_class.start()
+ self._loop = asyncio.get_event_loop()
+ self._loop_thread = threading.Thread(target = self._loop.run_forever, name = 'asyncio')
+ self._loop_thread.setDaemon(True)
+ self._loop_thread.start()
+
self._wrapped_instance: AsyncQuery = self._init_async_class( # type: ignore
AsyncQuery,
resource,
diff --git a/stem/util/__init__.py b/stem/util/__init__.py
index a90aa7ac..25282b99 100644
--- a/stem/util/__init__.py
+++ b/stem/util/__init__.py
@@ -10,7 +10,7 @@ import datetime
import threading
from concurrent.futures import Future
-from typing import Any, AsyncIterator, Iterator, Optional, Type, Union
+from typing import Any, AsyncIterator, Iterator, Type, Union
__all__ = [
'conf',
@@ -144,41 +144,26 @@ def _hash_attr(obj: Any, *attributes: str, **kwargs: Any):
return my_hash
-class ThreadForWrappedAsyncClass(threading.Thread):
- def __init__(self, *args: Any, **kwargs: Any) -> None:
- super().__init__(*args, *kwargs)
- self.loop = asyncio.new_event_loop()
- self.setDaemon(True)
-
- def run(self) -> None:
- self.loop.run_forever()
-
- def join(self, timeout: Optional[float] = None) -> None:
- self.loop.call_soon_threadsafe(self.loop.stop)
- super().join(timeout)
- self.loop.close()
-
-
class AsyncClassWrapper:
- _thread_for_wrapped_class: ThreadForWrappedAsyncClass
+ _loop: asyncio.AbstractEventLoop
+ _loop_thread: threading.Thread
_wrapped_instance: type
def _init_async_class(self, async_class: Type, *args: Any, **kwargs: Any) -> Any:
- thread = self._thread_for_wrapped_class
# The asynchronous class should be initialized in the thread where
# its methods will be executed.
- if thread != threading.current_thread():
+ if self._loop_thread != threading.current_thread():
async def init():
return async_class(*args, **kwargs)
- return asyncio.run_coroutine_threadsafe(init(), thread.loop).result()
+ return asyncio.run_coroutine_threadsafe(init(), self._loop).result()
return async_class(*args, **kwargs)
def _call_async_method_soon(self, method_name: str, *args: Any, **kwargs: Any) -> Future:
return asyncio.run_coroutine_threadsafe(
getattr(self._wrapped_instance, method_name)(*args, **kwargs),
- self._thread_for_wrapped_class.loop,
+ self._loop,
)
def _execute_async_method(self, method_name: str, *args: Any, **kwargs: Any) -> Any:
@@ -192,5 +177,9 @@ class AsyncClassWrapper:
convert_async_generator(
getattr(self._wrapped_instance, method_name)(*args, **kwargs),
),
- self._thread_for_wrapped_class.loop,
+ self._loop,
).result()
+
+ def __del__(self) -> None:
+ self._loop.call_soon_threadsafe(self._loop.stop)
+ self._loop_thread.join()
diff --git a/test/integ/connection/authentication.py b/test/integ/connection/authentication.py
index 3709c275..042d3939 100644
--- a/test/integ/connection/authentication.py
+++ b/test/integ/connection/authentication.py
@@ -124,7 +124,7 @@ class TestAuthenticate(unittest.TestCase):
with runner.get_tor_controller(False) as controller:
asyncio.run_coroutine_threadsafe(
stem.connection.authenticate(controller._wrapped_instance, test.runner.CONTROL_PASSWORD, runner.get_chroot()),
- controller._thread_for_wrapped_class.loop,
+ controller._loop,
).result()
await test.runner.exercise_controller(self, controller)
diff --git a/test/runner.py b/test/runner.py
index 4f237552..b132b8f5 100644
--- a/test/runner.py
+++ b/test/runner.py
@@ -488,16 +488,16 @@ class Runner(object):
:raises: :class: `test.runner.TorInaccessable` if tor can't be connected to
"""
- async_controller_thread = stem.util.ThreadForWrappedAsyncClass()
- async_controller_thread.start()
+ loop = asyncio.new_event_loop()
+ loop_thread = threading.Thread(target = loop.run_forever, name = 'get_tor_controller')
+ loop_thread.setDaemon(True)
+ loop_thread.start()
- try:
- control_socket = asyncio.run_coroutine_threadsafe(self.get_tor_socket(False), async_controller_thread.loop).result()
- controller = stem.control.Controller(control_socket, started_async_controller_thread = async_controller_thread)
- except Exception:
- if async_controller_thread.is_alive():
- async_controller_thread.join()
- raise
+ async def wrapped_get_controller():
+ control_socket = await self.get_tor_socket(False)
+ return stem.control.Controller(control_socket)
+
+ controller = asyncio.run_coroutine_threadsafe(wrapped_get_controller(), loop).result()
if authenticate:
self._authenticate_controller(controller)
diff --git a/test/unit/control/controller.py b/test/unit/control/controller.py
index 850918f7..84fcdfed 100644
--- a/test/unit/control/controller.py
+++ b/test/unit/control/controller.py
@@ -751,7 +751,7 @@ class TestControl(unittest.TestCase):
with patch('time.time', Mock(return_value = TEST_TIMESTAMP)):
with patch('stem.control.AsyncController.is_alive') as is_alive_mock:
is_alive_mock.return_value = True
- loop = self.controller._thread_for_wrapped_class.loop
+ loop = self.controller._loop
asyncio.run_coroutine_threadsafe(self.async_controller._event_loop(), loop)
try: