[tor-commits] [stem/master] Drop ThreadForWrappedAsyncClass

atagar at torproject.org atagar at torproject.org
Thu Jul 16 01:29:00 UTC 2020


commit 675f49fcbe6dc1a52d10215c07adff56001faa70
Author: Damian Johnson <atagar at torproject.org>
Date:   Sat May 30 17:21:57 2020 -0700

    Drop ThreadForWrappedAsyncClass
    
    To call an asynchronous function from synchronous code we need an event loop
    and a thread for that loop to run within...
    
      def call(my_async_function):
        """
        Runs a coroutine to completion on a dedicated event loop thread, then
        tears that loop and thread down.

        :param my_async_function: coroutine object to run

        :returns: the coroutine's return value

        :raises: whatever the coroutine raises
        """

        loop = asyncio.new_event_loop()
        loop_thread = threading.Thread(target = loop.run_forever, name = 'asyncio', daemon = True)
        loop_thread.start()

        try:
          # Block until the coroutine finishes *before* stopping the loop.
          # Queuing loop.stop right after scheduling makes run_forever() exit
          # before the coroutine ever runs, leaving a future that never resolves.

          return asyncio.run_coroutine_threadsafe(my_async_function, loop).result()
        finally:
          loop.call_soon_threadsafe(loop.stop)
          loop_thread.join()
          loop.close()
    
    ThreadForWrappedAsyncClass bundled these together, but I found it more
    confusing than helpful. Its threads also failed to clean themselves up,
    causing 'lingering thread' warnings when we ran our tests.
---
 stem/connection.py                      | 30 ++++++++++++++++++------------
 stem/control.py                         | 21 ++++++++++-----------
 stem/descriptor/remote.py               |  7 +++++--
 stem/util/__init__.py                   | 33 +++++++++++----------------------
 test/integ/connection/authentication.py |  2 +-
 test/runner.py                          | 18 +++++++++---------
 test/unit/control/controller.py         |  2 +-
 7 files changed, 55 insertions(+), 58 deletions(-)

diff --git a/stem/connection.py b/stem/connection.py
index 213ba010..86d32d7f 100644
--- a/stem/connection.py
+++ b/stem/connection.py
@@ -159,7 +159,7 @@ import stem.util.str_tools
 import stem.util.system
 import stem.version
 
-from typing import Any, cast, List, Optional, Sequence, Tuple, Type, Union
+from typing import Any, List, Optional, Sequence, Tuple, Type, Union
 from stem.util import log
 
 AuthMethod = stem.util.enum.Enum('NONE', 'PASSWORD', 'COOKIE', 'SAFECOOKIE', 'UNKNOWN')
@@ -271,18 +271,24 @@ def connect(control_port: Tuple[str, Union[str, int]] = ('127.0.0.1', 'default')
   if controller is None or not issubclass(controller, stem.control.Controller):
     raise ValueError('Controller should be a stem.control.Controller subclass.')
 
-  async_controller_thread = stem.util.ThreadForWrappedAsyncClass()
-  async_controller_thread.start()
+  loop = asyncio.new_event_loop()
+  loop_thread = threading.Thread(target = loop.run_forever, name = 'asyncio')
+  loop_thread.setDaemon(True)
+  loop_thread.start()
 
-  connect_coroutine = _connect_async(control_port, control_socket, password, password_prompt, chroot_path, controller)
   try:
-    connection = asyncio.run_coroutine_threadsafe(connect_coroutine, async_controller_thread.loop).result()
-    if connection is None and async_controller_thread.is_alive():
-      async_controller_thread.join()
+    connection = asyncio.run_coroutine_threadsafe(_connect_async(control_port, control_socket, password, password_prompt, chroot_path, controller), loop).result()
+
+    if connection is None and loop_thread.is_alive():
+      loop.call_soon_threadsafe(loop.stop)
+      loop_thread.join()
+
     return connection
   except:
-    if async_controller_thread.is_alive():
-      async_controller_thread.join()
+    if loop_thread.is_alive():
+      loop.call_soon_threadsafe(loop.stop)
+      loop_thread.join()
+
     raise
 
 
@@ -399,10 +405,10 @@ async def _connect_auth(control_socket: stem.socket.ControlSocket, password: str
 
     if controller is None:
       return control_socket
-    elif issubclass(controller, stem.control.BaseController):
+    elif issubclass(controller, stem.control.BaseController) or issubclass(controller, stem.control.Controller):
+      # TODO: Controller no longer extends BaseController (we'll probably change that)
+
       return controller(control_socket, is_authenticated = True)
-    elif issubclass(controller, stem.control.Controller):
-      return controller(control_socket, is_authenticated = True, started_async_controller_thread = cast(stem.util.ThreadForWrappedAsyncClass, threading.current_thread()))
   except IncorrectSocketType:
     if isinstance(control_socket, stem.socket.ControlPort):
       print(CONNECT_MESSAGES['wrong_port_type'].format(port = control_socket.port))
diff --git a/stem/control.py b/stem/control.py
index 084976ad..47ddaa35 100644
--- a/stem/control.py
+++ b/stem/control.py
@@ -3941,13 +3941,17 @@ class Controller(_BaseControllerSocketMixin, stem.util.AsyncClassWrapper):
       self,
       control_socket: stem.socket.ControlSocket,
       is_authenticated: bool = False,
-      started_async_controller_thread: stem.util.ThreadForWrappedAsyncClass = None,
   ) -> None:
-    if started_async_controller_thread:
-      self._thread_for_wrapped_class = started_async_controller_thread
-    else:
-      self._thread_for_wrapped_class = stem.util.ThreadForWrappedAsyncClass()
-      self._thread_for_wrapped_class.start()
+    # if within an asyncio context use its loop, otherwise spawn our own
+
+    try:
+      self._loop = asyncio.get_running_loop()
+      self._loop_thread = threading.current_thread()
+    except RuntimeError:
+      self._loop = asyncio.new_event_loop()
+      self._loop_thread = threading.Thread(target = self._loop.run_forever, name = 'asyncio')
+      self._loop_thread.setDaemon(True)
+      self._loop_thread.start()
 
     self._wrapped_instance: AsyncController = self._init_async_class(AsyncController, control_socket, is_authenticated)  # type: ignore
     self._socket = self._wrapped_instance._socket
@@ -4212,11 +4216,6 @@ class Controller(_BaseControllerSocketMixin, stem.util.AsyncClassWrapper):
   def drop_guards(self) -> None:
     self._execute_async_method('drop_guards')
 
-  def __del__(self) -> None:
-    loop = self._thread_for_wrapped_class.loop
-    if loop.is_running():
-      loop.call_soon_threadsafe(loop.stop)
-
   def __enter__(self) -> 'stem.control.Controller':
     return self
 
diff --git a/stem/descriptor/remote.py b/stem/descriptor/remote.py
index e7ccaa24..c23ab7a9 100644
--- a/stem/descriptor/remote.py
+++ b/stem/descriptor/remote.py
@@ -663,8 +663,11 @@ class Query(stem.util.AsyncClassWrapper):
   """
 
   def __init__(self, resource: str, descriptor_type: Optional[str] = None, endpoints: Optional[Sequence[stem.Endpoint]] = None, compression: Union[stem.descriptor._Compression, Sequence[stem.descriptor._Compression]] = (Compression.GZIP,), retries: int = 2, fall_back_to_authority: bool = False, timeout: Optional[float] = None, start: bool = True, block: bool = False, validate: bool = False, document_handler: stem.descriptor.DocumentHandler = stem.descriptor.DocumentHandler.ENTRIES, **kwargs: Any) -> None:
-    self._thread_for_wrapped_class = stem.util.ThreadForWrappedAsyncClass()
-    self._thread_for_wrapped_class.start()
+    self._loop = asyncio.get_event_loop()
+    self._loop_thread = threading.Thread(target = self._loop.run_forever, name = 'asyncio')
+    self._loop_thread.setDaemon(True)
+    self._loop_thread.start()
+
     self._wrapped_instance: AsyncQuery = self._init_async_class(  # type: ignore
       AsyncQuery,
       resource,
diff --git a/stem/util/__init__.py b/stem/util/__init__.py
index a90aa7ac..25282b99 100644
--- a/stem/util/__init__.py
+++ b/stem/util/__init__.py
@@ -10,7 +10,7 @@ import datetime
 import threading
 from concurrent.futures import Future
 
-from typing import Any, AsyncIterator, Iterator, Optional, Type, Union
+from typing import Any, AsyncIterator, Iterator, Type, Union
 
 __all__ = [
   'conf',
@@ -144,41 +144,26 @@ def _hash_attr(obj: Any, *attributes: str, **kwargs: Any):
   return my_hash
 
 
-class ThreadForWrappedAsyncClass(threading.Thread):
-  def __init__(self, *args: Any, **kwargs: Any) -> None:
-    super().__init__(*args, *kwargs)
-    self.loop = asyncio.new_event_loop()
-    self.setDaemon(True)
-
-  def run(self) -> None:
-    self.loop.run_forever()
-
-  def join(self, timeout: Optional[float] = None) -> None:
-    self.loop.call_soon_threadsafe(self.loop.stop)
-    super().join(timeout)
-    self.loop.close()
-
-
 class AsyncClassWrapper:
-  _thread_for_wrapped_class: ThreadForWrappedAsyncClass
+  _loop: asyncio.AbstractEventLoop
+  _loop_thread: threading.Thread
   _wrapped_instance: type
 
   def _init_async_class(self, async_class: Type, *args: Any, **kwargs: Any) -> Any:
-    thread = self._thread_for_wrapped_class
     # The asynchronous class should be initialized in the thread where
     # its methods will be executed.
-    if thread != threading.current_thread():
+    if self._loop_thread != threading.current_thread():
       async def init():
         return async_class(*args, **kwargs)
 
-      return asyncio.run_coroutine_threadsafe(init(), thread.loop).result()
+      return asyncio.run_coroutine_threadsafe(init(), self._loop).result()
 
     return async_class(*args, **kwargs)
 
   def _call_async_method_soon(self, method_name: str, *args: Any, **kwargs: Any) -> Future:
     return asyncio.run_coroutine_threadsafe(
       getattr(self._wrapped_instance, method_name)(*args, **kwargs),
-      self._thread_for_wrapped_class.loop,
+      self._loop,
     )
 
   def _execute_async_method(self, method_name: str, *args: Any, **kwargs: Any) -> Any:
@@ -192,5 +177,9 @@ class AsyncClassWrapper:
       convert_async_generator(
         getattr(self._wrapped_instance, method_name)(*args, **kwargs),
       ),
-      self._thread_for_wrapped_class.loop,
+      self._loop,
     ).result()
+
+  def __del__(self) -> None:
+    self._loop.call_soon_threadsafe(self._loop.stop)
+    self._loop_thread.join()
diff --git a/test/integ/connection/authentication.py b/test/integ/connection/authentication.py
index 3709c275..042d3939 100644
--- a/test/integ/connection/authentication.py
+++ b/test/integ/connection/authentication.py
@@ -124,7 +124,7 @@ class TestAuthenticate(unittest.TestCase):
     with runner.get_tor_controller(False) as controller:
       asyncio.run_coroutine_threadsafe(
         stem.connection.authenticate(controller._wrapped_instance, test.runner.CONTROL_PASSWORD, runner.get_chroot()),
-        controller._thread_for_wrapped_class.loop,
+        controller._loop,
       ).result()
       await test.runner.exercise_controller(self, controller)
 
diff --git a/test/runner.py b/test/runner.py
index 4f237552..b132b8f5 100644
--- a/test/runner.py
+++ b/test/runner.py
@@ -488,16 +488,16 @@ class Runner(object):
     :raises: :class: `test.runner.TorInaccessable` if tor can't be connected to
     """
 
-    async_controller_thread = stem.util.ThreadForWrappedAsyncClass()
-    async_controller_thread.start()
+    loop = asyncio.new_event_loop()
+    loop_thread = threading.Thread(target = loop.run_forever, name = 'get_tor_controller')
+    loop_thread.setDaemon(True)
+    loop_thread.start()
 
-    try:
-      control_socket = asyncio.run_coroutine_threadsafe(self.get_tor_socket(False), async_controller_thread.loop).result()
-      controller = stem.control.Controller(control_socket, started_async_controller_thread = async_controller_thread)
-    except Exception:
-      if async_controller_thread.is_alive():
-        async_controller_thread.join()
-      raise
+    async def wrapped_get_controller():
+      control_socket = await self.get_tor_socket(False)
+      return stem.control.Controller(control_socket)
+
+    controller = asyncio.run_coroutine_threadsafe(wrapped_get_controller(), loop).result()
 
     if authenticate:
       self._authenticate_controller(controller)
diff --git a/test/unit/control/controller.py b/test/unit/control/controller.py
index 850918f7..84fcdfed 100644
--- a/test/unit/control/controller.py
+++ b/test/unit/control/controller.py
@@ -751,7 +751,7 @@ class TestControl(unittest.TestCase):
     with patch('time.time', Mock(return_value = TEST_TIMESTAMP)):
       with patch('stem.control.AsyncController.is_alive') as is_alive_mock:
         is_alive_mock.return_value = True
-        loop = self.controller._thread_for_wrapped_class.loop
+        loop = self.controller._loop
         asyncio.run_coroutine_threadsafe(self.async_controller._event_loop(), loop)
 
         try:





More information about the tor-commits mailing list