commit 8b539f2facdd86d07e76b5bf5daa379bf0d3d2ba
Author: Damian Johnson <atagar@torproject.org>
Date:   Thu Jul 9 17:29:58 2020 -0700
Synchronous context management
Make our Synchronous class support 'with' statements (delegating to __aenter__ and __aexit__), and tidy up both its implementation and tests.
---
 stem/util/__init__.py         |  79 +++++++++------
 test/unit/util/synchronous.py | 223 +++++++++++++++++++++++++++++-------------
 2 files changed, 205 insertions(+), 97 deletions(-)
diff --git a/stem/util/__init__.py b/stem/util/__init__.py index 15840f56..c147a5a4 100644 --- a/stem/util/__init__.py +++ b/stem/util/__init__.py @@ -10,10 +10,11 @@ import datetime import functools import inspect import threading +import typing import unittest.mock
from concurrent.futures import Future - +from types import TracebackType from typing import Any, AsyncIterator, Iterator, Optional, Type, Union
__all__ = [ @@ -201,26 +202,31 @@ class Synchronous(object): """
def __init__(self) -> None: + self._loop = None # type: Optional[asyncio.AbstractEventLoop] self._loop_thread = None # type: Optional[threading.Thread] self._loop_thread_lock = threading.RLock()
- if Synchronous.is_asyncio_context(): - self._loop = asyncio.get_running_loop() + # this class is a no-op when created from an asyncio context
- self.__ainit__() - else: - self._loop = asyncio.new_event_loop() + self._no_op = Synchronous.is_asyncio_context()
+ if not self._no_op: + self._loop = asyncio.new_event_loop() Synchronous.start(self)
- # call any coroutines through this loop + # call any coroutines through our loop
for name, func in inspect.getmembers(self): - if isinstance(func, unittest.mock.Mock) and inspect.iscoroutinefunction(func.side_effect): - setattr(self, name, functools.partial(self._call_async_method, name)) + if name in ('__aiter__', '__aenter__', '__aexit__'): + pass # async object methods with synchronous counterparts + elif isinstance(func, unittest.mock.Mock) and inspect.iscoroutinefunction(func.side_effect): + setattr(self, name, functools.partial(self._run_async_method, name)) elif inspect.ismethod(func) and inspect.iscoroutinefunction(func): - setattr(self, name, functools.partial(self._call_async_method, name)) + setattr(self, name, functools.partial(self._run_async_method, name))
+ if self._no_op: + self.__ainit__() # this is already an asyncio context + else: asyncio.run_coroutine_threadsafe(asyncio.coroutine(self.__ainit__)(), self._loop).result()
def __ainit__(self): @@ -272,13 +278,14 @@ class Synchronous(object): """
with self._loop_thread_lock: - self._loop_thread = threading.Thread( - name = '%s asyncio' % type(self).__name__, - target = self._loop.run_forever, - daemon = True, - ) + if not self._no_op and self._loop_thread is None: + self._loop_thread = threading.Thread( + name = '%s asyncio' % type(self).__name__, + target = self._loop.run_forever, + daemon = True, + )
- self._loop_thread.start() + self._loop_thread.start()
def stop(self) -> None: """ @@ -288,9 +295,13 @@ class Synchronous(object): """
with self._loop_thread_lock: - if self._loop_thread and self._loop_thread.is_alive(): + if not self._no_op and self._loop_thread is not None: self._loop.call_soon_threadsafe(self._loop.stop) - self._loop_thread.join() + + if threading.current_thread() != self._loop_thread: + self._loop_thread.join() + + self._loop_thread = None
@staticmethod def is_asyncio_context() -> bool: @@ -306,7 +317,7 @@ class Synchronous(object): except RuntimeError: return False
- def _call_async_method(self, method_name: str, *args: Any, **kwargs: Any) -> Any: + def _run_async_method(self, method_name: str, *args: Any, **kwargs: Any) -> Any: """ Run this async method from either a synchronous or asynchronous context.
@@ -324,25 +335,33 @@ class Synchronous(object):
func = getattr(type(self), method_name)
- if Synchronous.is_asyncio_context(): + if self._no_op or Synchronous.is_asyncio_context(): return func(self, *args, **kwargs)
with self._loop_thread_lock: - if self._loop_thread and not self._loop_thread.is_alive(): - raise RuntimeError('%s has been closed' % type(self).__name__) + if self._loop_thread is None: + raise RuntimeError('%s has been stopped' % type(self).__name__) + elif not func: + raise TypeError("'%s' does not have a %s method" % (type(self).__name__, method_name)) + + # convert iterator if indicated by this method's name or type hint + + if method_name == '__aiter__' or (inspect.ismethod(func) and typing.get_type_hints(func).get('return') == AsyncIterator): + async def convert_generator(generator: AsyncIterator) -> Iterator: + return iter([d async for d in generator])
- return asyncio.run_coroutine_threadsafe(func(self, *args, **kwargs), self._loop).result() + return asyncio.run_coroutine_threadsafe(convert_generator(func(self, *args, **kwargs)), self._loop).result() + else: + return asyncio.run_coroutine_threadsafe(func(self, *args, **kwargs), self._loop).result()
def __iter__(self) -> Iterator: - async def convert_generator(generator: AsyncIterator) -> Iterator: - return iter([d async for d in generator]) + return self._run_async_method('__aiter__')
- iter_func = getattr(self, '__aiter__', None) + def __enter__(self): + return self._run_async_method('__aenter__')
- if iter_func: - return asyncio.run_coroutine_threadsafe(convert_generator(iter_func()), self._loop).result() - else: - raise TypeError("'%s' object is not iterable" % type(self).__name__) + def __exit__(self, exit_type: Optional[Type[BaseException]], value: Optional[BaseException], traceback: Optional[TracebackType]): + return self._run_async_method('__aexit__', exit_type, value, traceback)
class AsyncClassWrapper: diff --git a/test/unit/util/synchronous.py b/test/unit/util/synchronous.py index 5b38a7b5..d4429f3e 100644 --- a/test/unit/util/synchronous.py +++ b/test/unit/util/synchronous.py @@ -17,12 +17,34 @@ hello from an asynchronous context """
-class Example(Synchronous): - async def hello(self): - return 'hello' +class Demo(Synchronous): + def __init__(self): + super(Demo, self).__init__() + + self.called_enter = False + self.called_exit = False + + def __ainit__(self): + self.ainit_loop = asyncio.get_running_loop() + + async def async_method(self): + return 'async call' + + def sync_method(self): + return 'sync call' + + async def __aiter__(self): + for i in range(3): + yield i + + async def __aenter__(self): + self.called_enter = True + return self + + async def __aexit__(self, exit_type, value, traceback): + self.called_exit = True + return
- def sync_hello(self): - return 'hello'
class TestSynchronous(unittest.TestCase): @patch('sys.stdout', new_callable = io.StringIO) @@ -31,6 +53,10 @@ class TestSynchronous(unittest.TestCase): Run the example from our pydoc. """
+ class Example(Synchronous): + async def hello(self): + return 'hello' + def sync_demo(): instance = Example() print('%s from a synchronous context' % instance.hello()) @@ -48,102 +74,165 @@ class TestSynchronous(unittest.TestCase):
def test_ainit(self): """ - Check that our constructor runs __ainit__ when present. + Check that construction runs __ainit__ with a loop when present. """
- class AinitDemo(Synchronous): - def __init__(self): - super(AinitDemo, self).__init__() - - def __ainit__(self): - self.ainit_loop = asyncio.get_running_loop() - - def sync_demo(): - instance = AinitDemo() - self.assertTrue(hasattr(instance, 'ainit_loop')) + def sync_test(): + instance = Demo() + self.assertTrue(isinstance(instance.ainit_loop, asyncio.AbstractEventLoop)) + instance.stop()
- async def async_demo(): - instance = AinitDemo() - self.assertTrue(hasattr(instance, 'ainit_loop')) + async def async_test(): + instance = Demo() + self.assertTrue(isinstance(instance.ainit_loop, asyncio.AbstractEventLoop)) + instance.stop()
- sync_demo() - asyncio.run(async_demo()) + sync_test() + asyncio.run(async_test())
- def test_after_stop(self): + def test_stop(self): """ - Check that stopped instances raise a RuntimeError to synchronous callers. + Synchronous callers should receive a RuntimeError when stopped. """
- # stop a used instance + def sync_test(): + instance = Demo() + self.assertEqual('async call', instance.async_method()) + instance.stop() + + self.assertRaises(RuntimeError, instance.async_method) + + # synchronous methods still work
- instance = Example() - self.assertEqual('hello', instance.hello()) - instance.stop() - self.assertRaises(RuntimeError, instance.hello) + self.assertEqual('sync call', instance.sync_method())
- # stop an unused instance + async def async_test(): + instance = Demo() + self.assertEqual('async call', await instance.async_method()) + instance.stop() + + # stop has no effect on async users + + self.assertEqual('async call', await instance.async_method())
- instance = Example() - instance.stop() - self.assertRaises(RuntimeError, instance.hello) + sync_test() + asyncio.run(async_test())
def test_resuming(self): """ Resume a previously stopped instance. """
- instance = Example() - self.assertEqual('hello', instance.hello()) - instance.stop() - self.assertRaises(RuntimeError, instance.hello) - instance.start() - self.assertEqual('hello', instance.hello()) - instance.stop() + def sync_test(): + instance = Demo() + self.assertEqual('async call', instance.async_method()) + instance.stop() + + self.assertRaises(RuntimeError, instance.async_method)
- def test_asynchronous_mockability(self): + instance.start() + self.assertEqual('async call', instance.async_method()) + instance.stop() + + async def async_test(): + instance = Demo() + self.assertEqual('async call', await instance.async_method()) + instance.stop() + + # start has no effect on async users + + instance.start() + self.assertEqual('async call', await instance.async_method()) + instance.stop() + + sync_test() + asyncio.run(async_test()) + + def test_iteration(self): """ - Check that method mocks are respected. + Check that we can iterate in both contexts. """
- # mock prior to construction + def sync_test(): + instance = Demo() + result = []
- with patch('test.unit.util.synchronous.Example.hello', Mock(side_effect = coro_func_returning_value('mocked hello'))): - instance = Example() - self.assertEqual('mocked hello', instance.hello()) + for val in instance: + result.append(val)
- self.assertEqual('hello', instance.hello()) # mock should now be reverted - instance.stop() + self.assertEqual([0, 1, 2], result) + instance.stop()
- # mock after construction + async def async_test(): + instance = Demo() + result = []
- instance = Example() + async for val in instance: + result.append(val)
- with patch('test.unit.util.synchronous.Example.hello', Mock(side_effect = coro_func_returning_value('mocked hello'))): - self.assertEqual('mocked hello', instance.hello()) + self.assertEqual([0, 1, 2], result) + instance.stop()
- self.assertEqual('hello', instance.hello()) - instance.stop() + sync_test() + asyncio.run(async_test())
- def test_synchronous_mockability(self): + def test_context_management(self): """ - Ensure we do not disrupt non-asynchronous method mocks. + Exercise context management via 'with' statements. """
- # mock prior to construction + def sync_test(): + instance = Demo()
- with patch('test.unit.util.synchronous.Example.sync_hello', Mock(return_value = 'mocked hello')): - instance = Example() - self.assertEqual('mocked hello', instance.sync_hello()) + self.assertFalse(instance.called_enter) + self.assertFalse(instance.called_exit) + + with instance: + self.assertTrue(instance.called_enter) + self.assertFalse(instance.called_exit) + + self.assertTrue(instance.called_enter) + self.assertTrue(instance.called_exit) + + async def async_test(): + instance = Demo() + + self.assertFalse(instance.called_enter) + self.assertFalse(instance.called_exit) + + async with instance: + self.assertTrue(instance.called_enter) + self.assertFalse(instance.called_exit) + + self.assertTrue(instance.called_enter) + self.assertTrue(instance.called_exit) + + sync_test() + asyncio.run(async_test()) + + def test_mockability(self): + """ + Check that method mocks are respected for both previously constructed + instances and those made after the mock. + """ + + pre_constructed = Demo() + + with patch('test.unit.util.synchronous.Demo.async_method', Mock(side_effect = coro_func_returning_value('mocked call'))): + post_constructed = Demo() + + self.assertEqual('mocked call', pre_constructed.async_method()) + self.assertEqual('mocked call', post_constructed.async_method())
- self.assertEqual('hello', instance.sync_hello()) # mock should now be reverted - instance.stop() + self.assertEqual('async call', pre_constructed.async_method()) + self.assertEqual('async call', post_constructed.async_method())
- # mock after construction + # synchronous methods are unaffected
- instance = Example() + with patch('test.unit.util.synchronous.Demo.sync_method', Mock(return_value = 'mocked call')): + self.assertEqual('mocked call', pre_constructed.sync_method())
- with patch('test.unit.util.synchronous.Example.sync_hello', Mock(return_value = 'mocked hello')): - self.assertEqual('mocked hello', instance.sync_hello()) + self.assertEqual('sync call', pre_constructed.sync_method())
- self.assertEqual('hello', instance.sync_hello()) - instance.stop() + pre_constructed.stop() + post_constructed.stop()
tor-commits@lists.torproject.org