commit 5406561385cb2e5274ddda732b6e1d6fb014a2c5 Author: Damian Johnson atagar@torproject.org Date: Sun Jun 28 14:33:40 2020 -0700
Constructor method with an async context
Many asyncio classes can only be constructed within a running loop. We can't presume that our __init__() runs within one, so we're adding an __ainit__() method that will. --- stem/util/__init__.py | 95 +++++++++++++++++++++++++++---------------- test/unit/util/synchronous.py | 31 ++++++++++++++ 2 files changed, 92 insertions(+), 34 deletions(-)
diff --git a/stem/util/__init__.py b/stem/util/__init__.py index e8ef361e..5a8f95e0 100644 --- a/stem/util/__init__.py +++ b/stem/util/__init__.py @@ -150,7 +150,7 @@ def _hash_attr(obj: Any, *attributes: str, **kwargs: Any) -> int:
class Synchronous(object): """ - Mixin that lets a class be called from both synchronous and asynchronous + Mixin that lets a class run within both synchronous and asynchronous contexts.
:: @@ -172,39 +172,70 @@ class Synchronous(object): sync_demo() asyncio.run(async_demo())
+ Our async methods always run within a loop. For asyncio users this class has + no effect, but otherwise we transparently create an async context to run + within. + + Class initialization and any non-async methods should assume they're running + within a synchronous context. If our class supplies an **__ainit__()** + method it is invoked within our loop during initialization... + + :: + + class Example(Synchronous): + def __init__(self): + super(Example, self).__init__() + + # Synchronous part of our initialization. Avoid anything + # that must run within an asyncio loop. + + def __ainit__(self): + # Asynchronous part of our initialization. You can call + # asyncio.get_running_loop(), and construct objects that + # require it (like asyncio.Queue and asyncio.Lock). + Users are responsible for calling :func:`~stem.util.Synchronous.close` when finished to clean up underlying resources. """
def __init__(self) -> None: - self._loop = asyncio.new_event_loop() - self._loop_lock = threading.RLock() - self._loop_thread = threading.Thread( - name = '%s asyncio' % type(self).__name__, - target = self._loop.run_forever, - daemon = True, - ) + ainit_func = getattr(self, '__ainit__', None) + + if Synchronous.is_asyncio_context(): + self._loop = asyncio.get_running_loop() + self._loop_thread = None + + if ainit_func: + ainit_func() + else: + self._loop = asyncio.new_event_loop() + self._loop_thread = threading.Thread( + name = '%s asyncio' % type(self).__name__, + target = self._loop.run_forever, + daemon = True, + )
- self._is_closed = False + self._loop_thread.start()
- # overwrite asynchronous class methods with instance methods that can be - # called from either context + # call any coroutines through this loop
- def wrap(func: Callable, *args: Any, **kwargs: Any) -> Any: - if Synchronous.is_asyncio_context(): - return func(*args, **kwargs) - else: - with self._loop_lock: - if self._is_closed: - raise RuntimeError('%s has been closed' % type(self).__name__) - elif not self._loop_thread.is_alive(): - self._loop_thread.start() + def call_async(func: Callable, *args: Any, **kwargs: Any) -> Any: + if Synchronous.is_asyncio_context(): + return func(*args, **kwargs) + elif not self._loop_thread.is_alive(): + raise RuntimeError('%s has been closed' % type(self).__name__)
- return asyncio.run_coroutine_threadsafe(func(*args, **kwargs), self._loop).result() + return asyncio.run_coroutine_threadsafe(func(*args, **kwargs), self._loop).result()
- for method_name, func in inspect.getmembers(self, predicate = inspect.ismethod): - if inspect.iscoroutinefunction(func): - setattr(self, method_name, functools.partial(wrap, func)) + for method_name, func in inspect.getmembers(self, predicate = inspect.ismethod): + if inspect.iscoroutinefunction(func): + setattr(self, method_name, functools.partial(call_async, func)) + + if ainit_func: + async def call_ainit(): + ainit_func() + + asyncio.run_coroutine_threadsafe(call_ainit(), self._loop).result()
def close(self) -> None: """ @@ -213,12 +244,9 @@ class Synchronous(object): **RuntimeError**. """
- with self._loop_lock: - if self._loop_thread.is_alive(): - self._loop.call_soon_threadsafe(self._loop.stop) - self._loop_thread.join() - - self._is_closed = True + if self._loop_thread and self._loop_thread.is_alive(): + self._loop.call_soon_threadsafe(self._loop.stop) + self._loop_thread.join()
@staticmethod def is_asyncio_context() -> bool: @@ -235,14 +263,13 @@ class Synchronous(object): return False
def __iter__(self) -> Iterator: - async def convert_async_generator(generator: AsyncIterator) -> Iterator: + async def convert_generator(generator: AsyncIterator) -> Iterator: return iter([d async for d in generator])
- iter_func = getattr(self, '__aiter__') + iter_func = getattr(self, '__aiter__', None)
if iter_func: - with self._loop_lock: - return asyncio.run_coroutine_threadsafe(convert_async_generator(iter_func()), self._loop).result() + return asyncio.run_coroutine_threadsafe(convert_generator(iter_func()), self._loop).result() else: raise TypeError("'%s' object is not iterable" % type(self).__name__)
diff --git a/test/unit/util/synchronous.py b/test/unit/util/synchronous.py index 26dad98d..22271ffd 100644 --- a/test/unit/util/synchronous.py +++ b/test/unit/util/synchronous.py @@ -24,6 +24,10 @@ class Example(Synchronous): class TestSynchronous(unittest.TestCase): @patch('sys.stdout', new_callable = io.StringIO) def test_example(self, stdout_mock): + """ + Run the example from our pydoc. + """ + def sync_demo(): instance = Example() print('%s from a synchronous context' % instance.hello()) @@ -39,7 +43,34 @@ class TestSynchronous(unittest.TestCase):
self.assertEqual(EXAMPLE_OUTPUT, stdout_mock.getvalue())
+ def test_ainit(self): + """ + Check that our constructor runs __ainit__ if present. + """ + + class AinitDemo(Synchronous): + def __init__(self): + super(AinitDemo, self).__init__() + + def __ainit__(self): + self.ainit_loop = asyncio.get_running_loop() + + def sync_demo(): + instance = AinitDemo() + self.assertTrue(hasattr(instance, 'ainit_loop')) + + async def async_demo(): + instance = AinitDemo() + self.assertTrue(hasattr(instance, 'ainit_loop')) + + sync_demo() + asyncio.run(async_demo()) + def test_after_close(self): + """ + Check that closed instances raise a RuntimeError to synchronous callers. + """ + # close a used instance
instance = Example()
tor-commits@lists.torproject.org