[tor-commits] [stem/master] Prepare for creating and wrapping one more asynchronous class

atagar at torproject.org atagar at torproject.org
Thu Jul 16 01:28:59 UTC 2020


commit 79e8c1b47c63dfd49b62ec47c1b7902f51b06a83
Author: Illia Volochii <illia.volochii at gmail.com>
Date:   Thu May 14 00:06:49 2020 +0300

    Prepare for creating and wrapping one more asynchronous class
---
 stem/connection.py               |   2 +-
 stem/control.py                  | 108 ++++++++-------------------------------
 stem/interpreter/__init__.py     |   2 +-
 stem/interpreter/commands.py     |   5 +-
 stem/util/__init__.py            |  81 +++++++++++++++++++++++++++++
 test/integ/control/controller.py |   2 +-
 test/runner.py                   |   2 +-
 test/unit/control/controller.py  |   4 +-
 8 files changed, 110 insertions(+), 96 deletions(-)

diff --git a/stem/connection.py b/stem/connection.py
index c44fddb1..8f57f3b3 100644
--- a/stem/connection.py
+++ b/stem/connection.py
@@ -257,7 +257,7 @@ def connect(control_port: Tuple[str, Union[str, int]] = ('127.0.0.1', 'default')
   if controller is None or not issubclass(controller, stem.control.Controller):
     raise ValueError('Controller should be a stem.control.BaseController subclass.')
 
-  async_controller_thread = stem.control._AsyncControllerThread()
+  async_controller_thread = stem.util.ThreadForWrappedAsyncClass()
   async_controller_thread.start()
 
   connect_coroutine = _connect_async(control_port, control_socket, password, password_prompt, chroot_path, controller)
diff --git a/stem/control.py b/stem/control.py
index 6de671b6..1488621a 100644
--- a/stem/control.py
+++ b/stem/control.py
@@ -553,29 +553,6 @@ def event_description(event: str) -> str:
   return EVENT_DESCRIPTIONS.get(event.lower())
 
 
-class _MsgLock:
-  __slots__ = ('_r_lock', '_async_lock')
-
-  def __init__(self):
-    self._r_lock = threading.RLock()
-    self._async_lock = asyncio.Lock()
-
-  async def acquire(self):
-    await self._async_lock.acquire()
-    self._r_lock.acquire()
-
-  def release(self):
-    self._r_lock.release()
-    self._async_lock.release()
-
-  async def __aenter__(self):
-    await self.acquire()
-    return self
-
-  async def __aexit__(self, exc_type, exc_val, exc_tb):
-    self.release()
-
-
 class _BaseControllerSocketMixin:
   def is_alive(self) -> bool:
     """
@@ -644,7 +621,7 @@ class BaseController(_BaseControllerSocketMixin):
 
     self._asyncio_loop = asyncio.get_event_loop()
 
-    self._msg_lock = _MsgLock()
+    self._msg_lock = stem.util.CombinedReentrantAndAsyncioLock()
 
     self._status_listeners = []  # type: List[Tuple[Callable[[stem.control.BaseController, stem.control.State, float], None], bool]] # tuples of the form (callback, spawn_thread)
     self._status_listeners_lock = threading.RLock()
@@ -3901,22 +3878,7 @@ class AsyncController(_ControllerClassMethodMixin, BaseController):
     return (set_events, failed_events)
 
 
-class _AsyncControllerThread(threading.Thread):
-  def __init__(self, *args, **kwargs):
-    super().__init__(*args, *kwargs)
-    self.loop = asyncio.new_event_loop()
-    self.setDaemon(True)
-
-  def run(self):
-    self.loop.run_forever()
-
-  def join(self, timeout = None):
-    self.loop.call_soon_threadsafe(self.loop.stop)
-    super().join(timeout)
-    self.loop.close()
-
-
-class Controller(_ControllerClassMethodMixin, _BaseControllerSocketMixin):
+class Controller(_ControllerClassMethodMixin, _BaseControllerSocketMixin, stem.util.AsyncClassWrapper):
   @classmethod
   def from_port(cls: Type, address: str = '127.0.0.1', port: Union[int, str] = 'default') -> 'stem.control.Controller':
     instance = super().from_port(address, port)
@@ -3932,48 +3894,19 @@ class Controller(_ControllerClassMethodMixin, _BaseControllerSocketMixin):
   def __init__(self, control_socket: 'stem.socket.ControlSocket', is_authenticated: bool = False, started_async_controller_thread: Optional['threading.Thread'] = None) -> None:
   def __init__(self, control_socket, is_authenticated = False, started_async_controller_thread = None):
     if started_async_controller_thread:
-      self._async_controller_thread = started_async_controller_thread
+      self._thread_for_wrapped_class = started_async_controller_thread
     else:
-      self._async_controller_thread = _AsyncControllerThread()
-      self._async_controller_thread.start()
-    self._asyncio_loop = self._async_controller_thread.loop
-
-    self._async_controller = self._init_async_controller(control_socket, is_authenticated)
-    self._socket = self._async_controller._socket
-
-  def _init_async_controller(self, control_socket: 'stem.socket.ControlSocket', is_authenticated: bool) -> 'stem.control.AsyncController':
-    # The asynchronous controller should be initialized in the thread where its
-    # methods will be executed.
-    if self._async_controller_thread != threading.current_thread():
-      async def init_async_controller() -> 'stem.control.AsyncController':
-        return AsyncController(control_socket, is_authenticated)
-
-      return asyncio.run_coroutine_threadsafe(init_async_controller(), self._asyncio_loop).result()
-
-    return AsyncController(control_socket, is_authenticated)
-
-  def _execute_async_method(self, method_name: str, *args: Any, **kwargs: Any) -> Any:
-    return asyncio.run_coroutine_threadsafe(
-      getattr(self._async_controller, method_name)(*args, **kwargs),
-      self._asyncio_loop,
-    ).result()
-
-  def _execute_async_generator_method(self, method_name: str, *args: Any, **kwargs: Any) -> Any:
-    async def convert_async_generator(generator):
-      return iter([d async for d in generator])
+      self._thread_for_wrapped_class = stem.util.ThreadForWrappedAsyncClass()
+      self._thread_for_wrapped_class.start()
 
-    return asyncio.run_coroutine_threadsafe(
-      convert_async_generator(
-        getattr(self._async_controller, method_name)(*args, **kwargs),
-      ),
-      self._asyncio_loop,
-    ).result()
+    self._wrapped_instance = self._init_async_class(AsyncController, control_socket, is_authenticated)
+    self._socket = self._wrapped_instance._socket
 
   def msg(self, message: str) -> stem.response.ControlMessage:
     return self._execute_async_method('msg', message)
 
   def is_authenticated(self) -> bool:
-    return self._async_controller.is_authenticated()
+    return self._wrapped_instance.is_authenticated()
 
   def connect(self) -> None:
     self._execute_async_method('connect')
@@ -3985,13 +3918,13 @@ class Controller(_ControllerClassMethodMixin, _BaseControllerSocketMixin):
     self._execute_async_method('close')
 
   def get_latest_heartbeat(self) -> float:
-    return self._async_controller.get_latest_heartbeat()
+    return self._wrapped_instance.get_latest_heartbeat()
 
   def add_status_listener(self, callback: Callable[['stem.control.BaseController', 'stem.control.State', float], None], spawn: bool = True) -> None:
-    self._async_controller.add_status_listener(callback, spawn)
+    self._wrapped_instance.add_status_listener(callback, spawn)
 
   def remove_status_listener(self, callback: Callable[['stem.control.Controller', 'stem.control.State', float], None]) -> bool:
-    self._async_controller.remove_status_listener(callback)
+    self._wrapped_instance.remove_status_listener(callback)
 
   def authenticate(self, *args: Any, **kwargs: Any) -> None:
     self._execute_async_method('authenticate', *args, **kwargs)
@@ -4099,13 +4032,13 @@ class Controller(_ControllerClassMethodMixin, _BaseControllerSocketMixin):
     self._execute_async_method('remove_event_listener', listener)
 
   def is_caching_enabled(self) -> bool:
-    return self._async_controller.is_caching_enabled()
+    return self._wrapped_instance.is_caching_enabled()
 
   def set_caching(self, enabled: bool) -> None:
-    self._async_controller.set_caching(enabled)
+    self._wrapped_instance.set_caching(enabled)
 
   def clear_cache(self) -> None:
-    self._async_controller.clear_cache()
+    self._wrapped_instance.clear_cache()
 
   def load_conf(self, configtext: str) -> None:
     self._execute_async_method('load_conf', configtext)
@@ -4114,10 +4047,10 @@ class Controller(_ControllerClassMethodMixin, _BaseControllerSocketMixin):
     return self._execute_async_method('save_conf', force)
 
   def is_feature_enabled(self, feature: str) -> bool:
-    return self._async_controller.is_feature_enabled(feature)
+    return self._wrapped_instance.is_feature_enabled(feature)
 
   def enable_feature(self, features: Union[str, Sequence[str]]) -> None:
-    self._async_controller.enable_feature(features)
+    self._wrapped_instance.enable_feature(features)
 
   def get_circuit(self, circuit_id: int, default: Any = UNDEFINED) -> stem.response.events.CircuitEvent:
     return self._execute_async_method('get_circuit', circuit_id, default)
@@ -4150,10 +4083,10 @@ class Controller(_ControllerClassMethodMixin, _BaseControllerSocketMixin):
     self._execute_async_method('signal', signal)
 
   def is_newnym_available(self) -> bool:
-    return self._async_controller.is_newnym_available()
+    return self._wrapped_instance.is_newnym_available()
 
   def get_newnym_wait(self) -> float:
-    return self._async_controller.get_newnym_wait()
+    return self._wrapped_instance.get_newnym_wait()
 
   def get_effective_rate(self, default: Any = UNDEFINED, burst: bool = False) -> int:
     return self._execute_async_method('get_effective_rate', default, burst)
@@ -4165,8 +4098,9 @@ class Controller(_ControllerClassMethodMixin, _BaseControllerSocketMixin):
     self._execute_async_method('drop_guards')
 
   def __del__(self) -> None:
-    if self._asyncio_loop.is_running():
-      self._asyncio_loop.call_soon_threadsafe(self._asyncio_loop.stop)
+    loop = self._thread_for_wrapped_class.loop
+    if loop.is_running():
+      loop.call_soon_threadsafe(loop.stop)
 
   def __enter__(self) -> 'stem.control.Controller':
     return self
diff --git a/stem/interpreter/__init__.py b/stem/interpreter/__init__.py
index 07353d44..ae064a0a 100644
--- a/stem/interpreter/__init__.py
+++ b/stem/interpreter/__init__.py
@@ -127,7 +127,7 @@ def main() -> None:
         async def handle_event(event_message):
           print(format(str(event_message), *STANDARD_OUTPUT))
 
-        controller._async_controller._handle_event = handle_event
+        controller._wrapped_instance._handle_event = handle_event
 
         if sys.stdout.isatty():
           events = args.run_cmd.upper().split(' ', 1)[1]
diff --git a/stem/interpreter/commands.py b/stem/interpreter/commands.py
index 0d262ab5..edbcca70 100644
--- a/stem/interpreter/commands.py
+++ b/stem/interpreter/commands.py
@@ -128,7 +128,7 @@ class ControlInterpreter(code.InteractiveConsole):
     # Intercept events our controller hears about at a pretty low level since
     # the user will likely be requesting them by direct 'SETEVENTS' calls.
 
-    handle_event_real = self._controller._async_controller._handle_event
+    handle_event_real = self._controller._wrapped_instance._handle_event
 
     async def handle_event_wrapper(event_message: stem.response.ControlMessage) -> None:
       await handle_event_real(event_message)
@@ -139,8 +139,7 @@ class ControlInterpreter(code.InteractiveConsole):
 
     # type check disabled due to https://github.com/python/mypy/issues/708
 
-    self._controller._async_controller._handle_event = handle_event_wrapper
-    self._controller._handle_event = handle_event_wrapper  # type: ignore
+    self._controller._wrapped_instance._handle_event = handle_event_wrapper
 
   def get_events(self, *event_types: stem.control.EventType) -> List[stem.response.events.Event]:
     events = list(self._received_events)
diff --git a/stem/util/__init__.py b/stem/util/__init__.py
index e4fa3ca8..a230cfbd 100644
--- a/stem/util/__init__.py
+++ b/stem/util/__init__.py
@@ -5,7 +5,9 @@
 Utility functions used by the stem library.
 """
 
+import asyncio
 import datetime
+import threading
 
 from typing import Any, Union
 
@@ -139,3 +141,82 @@ def _hash_attr(obj: Any, *attributes: str, **kwargs: Any):
     setattr(obj, '_cached_hash', my_hash)
 
   return my_hash
+
+
+class CombinedReentrantAndAsyncioLock:
+  """
+  Holds a thread-safe threading.RLock and a non-thread-safe asyncio.Lock together.
+  """
+
+  __slots__ = ('_r_lock', '_async_lock')
+
+  def __init__(self):
+    self._r_lock, self._async_lock = threading.RLock(), asyncio.Lock()
+
+  async def acquire(self):
+    # awaiting yields to other coroutines; the RLock acquire below may block the thread
+    await self._async_lock.acquire()
+    self._r_lock.acquire()
+
+  def release(self):  # undoes acquire() in reverse order
+    self._r_lock.release()
+    self._async_lock.release()
+
+  async def __aenter__(self):
+    await self.acquire()
+    return self
+
+  async def __aexit__(self, exc_type, exc_val, exc_tb):
+    self.release()
+
+
+class ThreadForWrappedAsyncClass(threading.Thread):  # daemon thread that runs its own asyncio loop
+  def __init__(self, *args, **kwargs):
+    super().__init__(*args, **kwargs)  # forward keyword arguments (e.g. name) unchanged
+    self.loop = asyncio.new_event_loop()  # created here, driven by run() on this thread
+    self.daemon = True  # same effect as setDaemon(True), which is deprecated since Python 3.10
+
+  def run(self):
+    self.loop.run_forever()
+
+  def join(self, timeout = None):
+    self.loop.call_soon_threadsafe(self.loop.stop)  # make run_forever() return on the loop thread
+    super().join(timeout)
+    self.loop.close()  # NOTE(review): raises RuntimeError if timeout elapses before run() returns
+
+
+class AsyncClassWrapper:
+  """Mixin that runs the methods of a wrapped asynchronous instance on a loop thread."""
+
+  _thread_for_wrapped_class: ThreadForWrappedAsyncClass  # its loop executes the wrapped coroutines
+  _wrapped_instance: Any  # an instance of the wrapped class (not a type), assigned by subclasses
+
+  def _init_async_class(self, async_class, *args, **kwargs):
+    thread = self._thread_for_wrapped_class
+    # The asynchronous class should be initialized in the thread where
+    # its methods will be executed.
+    if thread is not threading.current_thread():
+      async def init():
+        return async_class(*args, **kwargs)
+
+      return asyncio.run_coroutine_threadsafe(init(), thread.loop).result()
+
+    return async_class(*args, **kwargs)
+
+  def _call_async_method_soon(self, method_name, *args, **kwargs):
+    # Schedules the call on the wrapped loop; returns a concurrent.futures.Future.
+    coroutine = getattr(self._wrapped_instance, method_name)(*args, **kwargs)
+    return asyncio.run_coroutine_threadsafe(coroutine, self._thread_for_wrapped_class.loop)
+
+  def _execute_async_method(self, method_name, *args, **kwargs):
+    return self._call_async_method_soon(method_name, *args, **kwargs).result()
+
+  def _execute_async_generator_method(self, method_name, *args, **kwargs):
+    # Drain the async generator on the wrapped loop and hand back a plain iterator.
+    async def convert_async_generator(generator):
+      return iter([d async for d in generator])
+
+    generator = getattr(self._wrapped_instance, method_name)(*args, **kwargs)
+    coroutine = convert_async_generator(generator)
+    loop = self._thread_for_wrapped_class.loop
+    return asyncio.run_coroutine_threadsafe(coroutine, loop).result()
diff --git a/test/integ/control/controller.py b/test/integ/control/controller.py
index 73e71fa4..b1772f34 100644
--- a/test/integ/control/controller.py
+++ b/test/integ/control/controller.py
@@ -113,7 +113,7 @@ class TestController(unittest.TestCase):
 
       state_controller, state_type, state_timestamp = received_events[0]
 
-      self.assertEqual(controller._async_controller, state_controller)
+      self.assertEqual(controller._wrapped_instance, state_controller)
       self.assertEqual(State.RESET, state_type)
       self.assertTrue(state_timestamp > before and state_timestamp < after)
 
diff --git a/test/runner.py b/test/runner.py
index 4a38e824..189a2d7b 100644
--- a/test/runner.py
+++ b/test/runner.py
@@ -488,7 +488,7 @@ class Runner(object):
     :raises: :class: `test.runner.TorInaccessable` if tor can't be connected to
     """
 
-    async_controller_thread = stem.control._AsyncControllerThread()
+    async_controller_thread = stem.util.ThreadForWrappedAsyncClass()
     async_controller_thread.start()
 
     try:
diff --git a/test/unit/control/controller.py b/test/unit/control/controller.py
index e8ef4787..a11aba45 100644
--- a/test/unit/control/controller.py
+++ b/test/unit/control/controller.py
@@ -44,7 +44,7 @@ class TestControl(unittest.TestCase):
 
     with patch('stem.control.BaseController.msg', Mock(side_effect = coro_func_returning_value(None))):
       self.controller = Controller(socket)
-      self.async_controller = self.controller._async_controller
+      self.async_controller = self.controller._wrapped_instance
 
       self.circ_listener = Mock()
       self.controller.add_event_listener(self.circ_listener, EventType.CIRC)
@@ -748,7 +748,7 @@ class TestControl(unittest.TestCase):
     with patch('time.time', Mock(return_value = TEST_TIMESTAMP)):
       with patch('stem.control.AsyncController.is_alive') as is_alive_mock:
         is_alive_mock.return_value = True
-        loop = self.controller._asyncio_loop
+        loop = self.controller._thread_for_wrapped_class.loop
         asyncio.run_coroutine_threadsafe(self.async_controller._event_loop(), loop)
 
         try:





More information about the tor-commits mailing list