commit 79e8c1b47c63dfd49b62ec47c1b7902f51b06a83 Author: Illia Volochii illia.volochii@gmail.com Date: Thu May 14 00:06:49 2020 +0300
Prepare for creating and wrapping one more asynchronous class --- stem/connection.py | 2 +- stem/control.py | 108 ++++++++------------------------------- stem/interpreter/__init__.py | 2 +- stem/interpreter/commands.py | 5 +- stem/util/__init__.py | 81 +++++++++++++++++++++++++++++ test/integ/control/controller.py | 2 +- test/runner.py | 2 +- test/unit/control/controller.py | 4 +- 8 files changed, 110 insertions(+), 96 deletions(-)
diff --git a/stem/connection.py b/stem/connection.py index c44fddb1..8f57f3b3 100644 --- a/stem/connection.py +++ b/stem/connection.py @@ -257,7 +257,7 @@ def connect(control_port: Tuple[str, Union[str, int]] = ('127.0.0.1', 'default') if controller is None or not issubclass(controller, stem.control.Controller): raise ValueError('Controller should be a stem.control.BaseController subclass.')
- async_controller_thread = stem.control._AsyncControllerThread() + async_controller_thread = stem.util.ThreadForWrappedAsyncClass() async_controller_thread.start()
connect_coroutine = _connect_async(control_port, control_socket, password, password_prompt, chroot_path, controller) diff --git a/stem/control.py b/stem/control.py index 6de671b6..1488621a 100644 --- a/stem/control.py +++ b/stem/control.py @@ -553,29 +553,6 @@ def event_description(event: str) -> str: return EVENT_DESCRIPTIONS.get(event.lower())
-class _MsgLock: - __slots__ = ('_r_lock', '_async_lock') - - def __init__(self): - self._r_lock = threading.RLock() - self._async_lock = asyncio.Lock() - - async def acquire(self): - await self._async_lock.acquire() - self._r_lock.acquire() - - def release(self): - self._r_lock.release() - self._async_lock.release() - - async def __aenter__(self): - await self.acquire() - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - self.release() - - class _BaseControllerSocketMixin: def is_alive(self) -> bool: """ @@ -644,7 +621,7 @@ class BaseController(_BaseControllerSocketMixin):
self._asyncio_loop = asyncio.get_event_loop()
- self._msg_lock = _MsgLock() + self._msg_lock = stem.util.CombinedReentrantAndAsyncioLock()
self._status_listeners = [] # type: List[Tuple[Callable[[stem.control.BaseController, stem.control.State, float], None], bool]] # tuples of the form (callback, spawn_thread) self._status_listeners_lock = threading.RLock() @@ -3901,22 +3878,7 @@ class AsyncController(_ControllerClassMethodMixin, BaseController): return (set_events, failed_events)
-class _AsyncControllerThread(threading.Thread): - def __init__(self, *args, **kwargs): - super().__init__(*args, *kwargs) - self.loop = asyncio.new_event_loop() - self.setDaemon(True) - - def run(self): - self.loop.run_forever() - - def join(self, timeout = None): - self.loop.call_soon_threadsafe(self.loop.stop) - super().join(timeout) - self.loop.close() - - -class Controller(_ControllerClassMethodMixin, _BaseControllerSocketMixin): +class Controller(_ControllerClassMethodMixin, _BaseControllerSocketMixin, stem.util.AsyncClassWrapper): @classmethod def from_port(cls: Type, address: str = '127.0.0.1', port: Union[int, str] = 'default') -> 'stem.control.Controller': instance = super().from_port(address, port) @@ -3932,48 +3894,19 @@ class Controller(_ControllerClassMethodMixin, _BaseControllerSocketMixin): def __init__(self, control_socket: 'stem.socket.ControlSocket', is_authenticated: bool = False, started_async_controller_thread: Optional['threading.Thread'] = None) -> None: def __init__(self, control_socket, is_authenticated = False, started_async_controller_thread = None): if started_async_controller_thread: - self._async_controller_thread = started_async_controller_thread + self._thread_for_wrapped_class = started_async_controller_thread else: - self._async_controller_thread = _AsyncControllerThread() - self._async_controller_thread.start() - self._asyncio_loop = self._async_controller_thread.loop - - self._async_controller = self._init_async_controller(control_socket, is_authenticated) - self._socket = self._async_controller._socket - - def _init_async_controller(self, control_socket: 'stem.socket.ControlSocket', is_authenticated: bool) -> 'stem.control.AsyncController': - # The asynchronous controller should be initialized in the thread where its - # methods will be executed. 
- if self._async_controller_thread != threading.current_thread(): - async def init_async_controller() -> 'stem.control.AsyncController': - return AsyncController(control_socket, is_authenticated) - - return asyncio.run_coroutine_threadsafe(init_async_controller(), self._asyncio_loop).result() - - return AsyncController(control_socket, is_authenticated) - - def _execute_async_method(self, method_name: str, *args: Any, **kwargs: Any) -> Any: - return asyncio.run_coroutine_threadsafe( - getattr(self._async_controller, method_name)(*args, **kwargs), - self._asyncio_loop, - ).result() - - def _execute_async_generator_method(self, method_name: str, *args: Any, **kwargs: Any) -> Any: - async def convert_async_generator(generator): - return iter([d async for d in generator]) + self._thread_for_wrapped_class = stem.util.ThreadForWrappedAsyncClass() + self._thread_for_wrapped_class.start()
- return asyncio.run_coroutine_threadsafe( - convert_async_generator( - getattr(self._async_controller, method_name)(*args, **kwargs), - ), - self._asyncio_loop, - ).result() + self._wrapped_instance = self._init_async_class(AsyncController, control_socket, is_authenticated) + self._socket = self._wrapped_instance._socket
def msg(self, message: str) -> stem.response.ControlMessage: return self._execute_async_method('msg', message)
def is_authenticated(self) -> bool: - return self._async_controller.is_authenticated() + return self._wrapped_instance.is_authenticated()
def connect(self) -> None: self._execute_async_method('connect') @@ -3985,13 +3918,13 @@ class Controller(_ControllerClassMethodMixin, _BaseControllerSocketMixin): self._execute_async_method('close')
def get_latest_heartbeat(self) -> float: - return self._async_controller.get_latest_heartbeat() + return self._wrapped_instance.get_latest_heartbeat()
def add_status_listener(self, callback: Callable[['stem.control.BaseController', 'stem.control.State', float], None], spawn: bool = True) -> None: - self._async_controller.add_status_listener(callback, spawn) + self._wrapped_instance.add_status_listener(callback, spawn)
def remove_status_listener(self, callback: Callable[['stem.control.Controller', 'stem.control.State', float], None]) -> bool: - self._async_controller.remove_status_listener(callback) + self._wrapped_instance.remove_status_listener(callback)
def authenticate(self, *args: Any, **kwargs: Any) -> None: self._execute_async_method('authenticate', *args, **kwargs) @@ -4099,13 +4032,13 @@ class Controller(_ControllerClassMethodMixin, _BaseControllerSocketMixin): self._execute_async_method('remove_event_listener', listener)
def is_caching_enabled(self) -> bool: - return self._async_controller.is_caching_enabled() + return self._wrapped_instance.is_caching_enabled()
def set_caching(self, enabled: bool) -> None: - self._async_controller.set_caching(enabled) + self._wrapped_instance.set_caching(enabled)
def clear_cache(self) -> None: - self._async_controller.clear_cache() + self._wrapped_instance.clear_cache()
def load_conf(self, configtext: str) -> None: self._execute_async_method('load_conf', configtext) @@ -4114,10 +4047,10 @@ class Controller(_ControllerClassMethodMixin, _BaseControllerSocketMixin): return self._execute_async_method('save_conf', force)
def is_feature_enabled(self, feature: str) -> bool: - return self._async_controller.is_feature_enabled(feature) + return self._wrapped_instance.is_feature_enabled(feature)
def enable_feature(self, features: Union[str, Sequence[str]]) -> None: - self._async_controller.enable_feature(features) + self._wrapped_instance.enable_feature(features)
def get_circuit(self, circuit_id: int, default: Any = UNDEFINED) -> stem.response.events.CircuitEvent: return self._execute_async_method('get_circuit', circuit_id, default) @@ -4150,10 +4083,10 @@ class Controller(_ControllerClassMethodMixin, _BaseControllerSocketMixin): self._execute_async_method('signal', signal)
def is_newnym_available(self) -> bool: - return self._async_controller.is_newnym_available() + return self._wrapped_instance.is_newnym_available()
def get_newnym_wait(self) -> float: - return self._async_controller.get_newnym_wait() + return self._wrapped_instance.get_newnym_wait()
def get_effective_rate(self, default: Any = UNDEFINED, burst: bool = False) -> int: return self._execute_async_method('get_effective_rate', default, burst) @@ -4165,8 +4098,9 @@ class Controller(_ControllerClassMethodMixin, _BaseControllerSocketMixin): self._execute_async_method('drop_guards')
def __del__(self) -> None: - if self._asyncio_loop.is_running(): - self._asyncio_loop.call_soon_threadsafe(self._asyncio_loop.stop) + loop = self._thread_for_wrapped_class.loop + if loop.is_running(): + loop.call_soon_threadsafe(loop.stop)
def __enter__(self) -> 'stem.control.Controller': return self diff --git a/stem/interpreter/__init__.py b/stem/interpreter/__init__.py index 07353d44..ae064a0a 100644 --- a/stem/interpreter/__init__.py +++ b/stem/interpreter/__init__.py @@ -127,7 +127,7 @@ def main() -> None: async def handle_event(event_message): print(format(str(event_message), *STANDARD_OUTPUT))
- controller._async_controller._handle_event = handle_event + controller._wrapped_instance._handle_event = handle_event
if sys.stdout.isatty(): events = args.run_cmd.upper().split(' ', 1)[1] diff --git a/stem/interpreter/commands.py b/stem/interpreter/commands.py index 0d262ab5..edbcca70 100644 --- a/stem/interpreter/commands.py +++ b/stem/interpreter/commands.py @@ -128,7 +128,7 @@ class ControlInterpreter(code.InteractiveConsole): # Intercept events our controller hears about at a pretty low level since # the user will likely be requesting them by direct 'SETEVENTS' calls.
- handle_event_real = self._controller._async_controller._handle_event + handle_event_real = self._controller._wrapped_instance._handle_event
async def handle_event_wrapper(event_message: stem.response.ControlMessage) -> None: await handle_event_real(event_message) @@ -139,8 +139,7 @@ class ControlInterpreter(code.InteractiveConsole):
# type check disabled due to https://github.com/python/mypy/issues/708
- self._controller._async_controller._handle_event = handle_event_wrapper - self._controller._handle_event = handle_event_wrapper # type: ignore + self._controller._wrapped_instance._handle_event = handle_event_wrapper
def get_events(self, *event_types: stem.control.EventType) -> List[stem.response.events.Event]: events = list(self._received_events) diff --git a/stem/util/__init__.py b/stem/util/__init__.py index e4fa3ca8..a230cfbd 100644 --- a/stem/util/__init__.py +++ b/stem/util/__init__.py @@ -5,7 +5,9 @@ Utility functions used by the stem library. """
+import asyncio import datetime +import threading
from typing import Any, Union
@@ -139,3 +141,124 @@ def _hash_attr(obj: Any, *attributes: str, **kwargs: Any):
     setattr(obj, '_cached_hash', my_hash)
 
   return my_hash
+
+
+class CombinedReentrantAndAsyncioLock:
+  """
+  Lock that combines thread-safe reentrant and not thread-safe asyncio locks.
+  """
+
+  __slots__ = ('_r_lock', '_async_lock')
+
+  def __init__(self):
+    self._r_lock = threading.RLock()
+    self._async_lock = asyncio.Lock()
+
+  async def acquire(self):
+    await self._async_lock.acquire()
+    self._r_lock.acquire()
+
+  def release(self):
+    self._r_lock.release()
+    self._async_lock.release()
+
+  async def __aenter__(self):
+    await self.acquire()
+    return self
+
+  async def __aexit__(self, exc_type, exc_val, exc_tb):
+    self.release()
+
+
+class ThreadForWrappedAsyncClass(threading.Thread):
+  """
+  Daemon thread whose :func:`run` drives its own asyncio event loop
+  (:attr:`loop`) until :func:`join` stops it.
+  """
+
+  def __init__(self, *args, **kwargs):
+    super().__init__(*args, **kwargs)
+    self.loop = asyncio.new_event_loop()
+    self.daemon = True
+
+  def run(self):
+    self.loop.run_forever()
+
+  def join(self, timeout = None):
+    self.loop.call_soon_threadsafe(self.loop.stop)
+    super().join(timeout)
+
+    # if the timeout expired the loop may still be running, and closing a
+    # running loop raises a RuntimeError
+    if not self.is_alive():
+      self.loop.close()
+
+
+class AsyncClassWrapper:
+  """
+  Mixin for synchronous classes that delegate to an instance of an
+  asynchronous class. Subclasses set **_thread_for_wrapped_class** to a
+  started :class:`~stem.util.ThreadForWrappedAsyncClass` and
+  **_wrapped_instance** to the asynchronous instance, usually via
+  :func:`_init_async_class`.
+
+  :func:`_execute_async_method` and :func:`_execute_async_generator_method`
+  block until the loop finishes the call, so they must not be invoked from
+  the wrapped thread itself (that would deadlock).
+  """
+
+  _thread_for_wrapped_class: ThreadForWrappedAsyncClass
+  _wrapped_instance: Any  # instance of the wrapped asynchronous class
+
+  def _init_async_class(self, async_class, *args, **kwargs):
+    """
+    Constructs **async_class** within the wrapped thread, blocking until it's
+    created.
+    """
+
+    thread = self._thread_for_wrapped_class
+    # The asynchronous class should be initialized in the thread where
+    # its methods will be executed.
+    if thread != threading.current_thread():
+      async def init():
+        return async_class(*args, **kwargs)
+
+      return asyncio.run_coroutine_threadsafe(init(), thread.loop).result()
+
+    return async_class(*args, **kwargs)
+
+  def _call_async_method_soon(self, method_name, *args, **kwargs):
+    """
+    Schedules a coroutine method of the wrapped instance on the wrapped
+    thread's loop without waiting for it.
+
+    :returns: :class:`concurrent.futures.Future` for the method's result
+    """
+
+    return asyncio.run_coroutine_threadsafe(
+      getattr(self._wrapped_instance, method_name)(*args, **kwargs),
+      self._thread_for_wrapped_class.loop,
+    )
+
+  def _execute_async_method(self, method_name, *args, **kwargs):
+    """
+    Runs a coroutine method of the wrapped instance and blocks for its result.
+    """
+
+    return self._call_async_method_soon(method_name, *args, **kwargs).result()
+
+  def _execute_async_generator_method(self, method_name, *args, **kwargs):
+    """
+    Runs an async generator method of the wrapped instance, collecting all of
+    its values and returning an iterator over them.
+    """
+
+    async def convert_async_generator(generator):
+      return iter([d async for d in generator])
+
+    return asyncio.run_coroutine_threadsafe(
+      convert_async_generator(
+        getattr(self._wrapped_instance, method_name)(*args, **kwargs),
+      ),
+      self._thread_for_wrapped_class.loop,
+    ).result()
diff --git a/test/integ/control/controller.py b/test/integ/control/controller.py
index 73e71fa4..b1772f34 100644
--- a/test/integ/control/controller.py
+++ b/test/integ/control/controller.py
@@ -113,7 +113,7 @@ class TestController(unittest.TestCase):
state_controller, state_type, state_timestamp = received_events[0]
- self.assertEqual(controller._async_controller, state_controller) + self.assertEqual(controller._wrapped_instance, state_controller) self.assertEqual(State.RESET, state_type) self.assertTrue(state_timestamp > before and state_timestamp < after)
diff --git a/test/runner.py b/test/runner.py index 4a38e824..189a2d7b 100644 --- a/test/runner.py +++ b/test/runner.py @@ -488,7 +488,7 @@ class Runner(object): :raises: :class: `test.runner.TorInaccessable` if tor can't be connected to """
- async_controller_thread = stem.control._AsyncControllerThread() + async_controller_thread = stem.util.ThreadForWrappedAsyncClass() async_controller_thread.start()
try: diff --git a/test/unit/control/controller.py b/test/unit/control/controller.py index e8ef4787..a11aba45 100644 --- a/test/unit/control/controller.py +++ b/test/unit/control/controller.py @@ -44,7 +44,7 @@ class TestControl(unittest.TestCase):
with patch('stem.control.BaseController.msg', Mock(side_effect = coro_func_returning_value(None))): self.controller = Controller(socket) - self.async_controller = self.controller._async_controller + self.async_controller = self.controller._wrapped_instance
self.circ_listener = Mock() self.controller.add_event_listener(self.circ_listener, EventType.CIRC) @@ -748,7 +748,7 @@ class TestControl(unittest.TestCase): with patch('time.time', Mock(return_value = TEST_TIMESTAMP)): with patch('stem.control.AsyncController.is_alive') as is_alive_mock: is_alive_mock.return_value = True - loop = self.controller._asyncio_loop + loop = self.controller._thread_for_wrapped_class.loop asyncio.run_coroutine_threadsafe(self.async_controller._event_loop(), loop)
try:
tor-commits@lists.torproject.org