commit 8b539f2facdd86d07e76b5bf5daa379bf0d3d2ba
Author: Damian Johnson <atagar(a)torproject.org>
Date: Thu Jul 9 17:29:58 2020 -0700
Synchronous context management
Make our class handle 'with' statements, and tidy up both its implementation
and tests.
---
stem/util/__init__.py | 79 +++++++++------
test/unit/util/synchronous.py | 223 +++++++++++++++++++++++++++++-------------
2 files changed, 205 insertions(+), 97 deletions(-)
diff --git a/stem/util/__init__.py b/stem/util/__init__.py
index 15840f56..c147a5a4 100644
--- a/stem/util/__init__.py
+++ b/stem/util/__init__.py
@@ -10,10 +10,11 @@ import datetime
import functools
import inspect
import threading
+import typing
import unittest.mock
from concurrent.futures import Future
-
+from types import TracebackType
from typing import Any, AsyncIterator, Iterator, Optional, Type, Union
__all__ = [
@@ -201,26 +202,31 @@ class Synchronous(object):
"""
def __init__(self) -> None:
+ self._loop = None # type: Optional[asyncio.AbstractEventLoop]
self._loop_thread = None # type: Optional[threading.Thread]
self._loop_thread_lock = threading.RLock()
- if Synchronous.is_asyncio_context():
- self._loop = asyncio.get_running_loop()
+ # this class is a no-op when created from an asyncio context
- self.__ainit__()
- else:
- self._loop = asyncio.new_event_loop()
+ self._no_op = Synchronous.is_asyncio_context()
+ if not self._no_op:
+ self._loop = asyncio.new_event_loop()
Synchronous.start(self)
- # call any coroutines through this loop
+ # call any coroutines through our loop
for name, func in inspect.getmembers(self):
- if isinstance(func, unittest.mock.Mock) and inspect.iscoroutinefunction(func.side_effect):
- setattr(self, name, functools.partial(self._call_async_method, name))
+ if name in ('__aiter__', '__aenter__', '__aexit__'):
+ pass # async object methods with synchronous counterparts
+ elif isinstance(func, unittest.mock.Mock) and inspect.iscoroutinefunction(func.side_effect):
+ setattr(self, name, functools.partial(self._run_async_method, name))
elif inspect.ismethod(func) and inspect.iscoroutinefunction(func):
- setattr(self, name, functools.partial(self._call_async_method, name))
+ setattr(self, name, functools.partial(self._run_async_method, name))
+ if self._no_op:
+ self.__ainit__() # this is already an asyncio context
+ else:
asyncio.run_coroutine_threadsafe(asyncio.coroutine(self.__ainit__)(), self._loop).result()
def __ainit__(self):
@@ -272,13 +278,14 @@ class Synchronous(object):
"""
with self._loop_thread_lock:
- self._loop_thread = threading.Thread(
- name = '%s asyncio' % type(self).__name__,
- target = self._loop.run_forever,
- daemon = True,
- )
+ if not self._no_op and self._loop_thread is None:
+ self._loop_thread = threading.Thread(
+ name = '%s asyncio' % type(self).__name__,
+ target = self._loop.run_forever,
+ daemon = True,
+ )
- self._loop_thread.start()
+ self._loop_thread.start()
def stop(self) -> None:
"""
@@ -288,9 +295,13 @@ class Synchronous(object):
"""
with self._loop_thread_lock:
- if self._loop_thread and self._loop_thread.is_alive():
+ if not self._no_op and self._loop_thread is not None:
self._loop.call_soon_threadsafe(self._loop.stop)
- self._loop_thread.join()
+
+ if threading.current_thread() != self._loop_thread:
+ self._loop_thread.join()
+
+ self._loop_thread = None
@staticmethod
def is_asyncio_context() -> bool:
@@ -306,7 +317,7 @@ class Synchronous(object):
except RuntimeError:
return False
- def _call_async_method(self, method_name: str, *args: Any, **kwargs: Any) -> Any:
+ def _run_async_method(self, method_name: str, *args: Any, **kwargs: Any) -> Any:
"""
Run this async method from either a synchronous or asynchronous context.
@@ -324,25 +335,33 @@ class Synchronous(object):
func = getattr(type(self), method_name)
- if Synchronous.is_asyncio_context():
+ if self._no_op or Synchronous.is_asyncio_context():
return func(self, *args, **kwargs)
with self._loop_thread_lock:
- if self._loop_thread and not self._loop_thread.is_alive():
- raise RuntimeError('%s has been closed' % type(self).__name__)
+ if self._loop_thread is None:
+ raise RuntimeError('%s has been stopped' % type(self).__name__)
+ elif not func:
+ raise TypeError("'%s' does not have a %s method" % (type(self).__name__, method_name))
+
+ # convert iterator if indicated by this method's name or type hint
+
+ if method_name == '__aiter__' or (inspect.ismethod(func) and typing.get_type_hints(func).get('return') == AsyncIterator):
+ async def convert_generator(generator: AsyncIterator) -> Iterator:
+ return iter([d async for d in generator])
- return asyncio.run_coroutine_threadsafe(func(self, *args, **kwargs), self._loop).result()
+ return asyncio.run_coroutine_threadsafe(convert_generator(func(self, *args, **kwargs)), self._loop).result()
+ else:
+ return asyncio.run_coroutine_threadsafe(func(self, *args, **kwargs), self._loop).result()
def __iter__(self) -> Iterator:
- async def convert_generator(generator: AsyncIterator) -> Iterator:
- return iter([d async for d in generator])
+ return self._run_async_method('__aiter__')
- iter_func = getattr(self, '__aiter__', None)
+ def __enter__(self):
+ return self._run_async_method('__aenter__')
- if iter_func:
- return asyncio.run_coroutine_threadsafe(convert_generator(iter_func()), self._loop).result()
- else:
- raise TypeError("'%s' object is not iterable" % type(self).__name__)
+ def __exit__(self, exit_type: Optional[Type[BaseException]], value: Optional[BaseException], traceback: Optional[TracebackType]):
+ return self._run_async_method('__aexit__', exit_type, value, traceback)
class AsyncClassWrapper:
diff --git a/test/unit/util/synchronous.py b/test/unit/util/synchronous.py
index 5b38a7b5..d4429f3e 100644
--- a/test/unit/util/synchronous.py
+++ b/test/unit/util/synchronous.py
@@ -17,12 +17,34 @@ hello from an asynchronous context
"""
-class Example(Synchronous):
- async def hello(self):
- return 'hello'
+class Demo(Synchronous):
+ def __init__(self):
+ super(Demo, self).__init__()
+
+ self.called_enter = False
+ self.called_exit = False
+
+ def __ainit__(self):
+ self.ainit_loop = asyncio.get_running_loop()
+
+ async def async_method(self):
+ return 'async call'
+
+ def sync_method(self):
+ return 'sync call'
+
+ async def __aiter__(self):
+ for i in range(3):
+ yield i
+
+ async def __aenter__(self):
+ self.called_enter = True
+ return self
+
+ async def __aexit__(self, exit_type, value, traceback):
+ self.called_exit = True
+ return
- def sync_hello(self):
- return 'hello'
class TestSynchronous(unittest.TestCase):
@patch('sys.stdout', new_callable = io.StringIO)
@@ -31,6 +53,10 @@ class TestSynchronous(unittest.TestCase):
Run the example from our pydoc.
"""
+ class Example(Synchronous):
+ async def hello(self):
+ return 'hello'
+
def sync_demo():
instance = Example()
print('%s from a synchronous context' % instance.hello())
@@ -48,102 +74,165 @@ class TestSynchronous(unittest.TestCase):
def test_ainit(self):
"""
- Check that our constructor runs __ainit__ when present.
+ Check that construction runs __ainit__ with a loop when present.
"""
- class AinitDemo(Synchronous):
- def __init__(self):
- super(AinitDemo, self).__init__()
-
- def __ainit__(self):
- self.ainit_loop = asyncio.get_running_loop()
-
- def sync_demo():
- instance = AinitDemo()
- self.assertTrue(hasattr(instance, 'ainit_loop'))
+ def sync_test():
+ instance = Demo()
+ self.assertTrue(isinstance(instance.ainit_loop, asyncio.AbstractEventLoop))
+ instance.stop()
- async def async_demo():
- instance = AinitDemo()
- self.assertTrue(hasattr(instance, 'ainit_loop'))
+ async def async_test():
+ instance = Demo()
+ self.assertTrue(isinstance(instance.ainit_loop, asyncio.AbstractEventLoop))
+ instance.stop()
- sync_demo()
- asyncio.run(async_demo())
+ sync_test()
+ asyncio.run(async_test())
- def test_after_stop(self):
+ def test_stop(self):
"""
- Check that stopped instances raise a RuntimeError to synchronous callers.
+ Synchronous callers should receive a RuntimeError when stopped.
"""
- # stop a used instance
+ def sync_test():
+ instance = Demo()
+ self.assertEqual('async call', instance.async_method())
+ instance.stop()
+
+ self.assertRaises(RuntimeError, instance.async_method)
+
+ # synchronous methods still work
- instance = Example()
- self.assertEqual('hello', instance.hello())
- instance.stop()
- self.assertRaises(RuntimeError, instance.hello)
+ self.assertEqual('sync call', instance.sync_method())
- # stop an unused instance
+ async def async_test():
+ instance = Demo()
+ self.assertEqual('async call', await instance.async_method())
+ instance.stop()
+
+ # stop has no effect on async users
+
+ self.assertEqual('async call', await instance.async_method())
- instance = Example()
- instance.stop()
- self.assertRaises(RuntimeError, instance.hello)
+ sync_test()
+ asyncio.run(async_test())
def test_resuming(self):
"""
Resume a previously stopped instance.
"""
- instance = Example()
- self.assertEqual('hello', instance.hello())
- instance.stop()
- self.assertRaises(RuntimeError, instance.hello)
- instance.start()
- self.assertEqual('hello', instance.hello())
- instance.stop()
+ def sync_test():
+ instance = Demo()
+ self.assertEqual('async call', instance.async_method())
+ instance.stop()
+
+ self.assertRaises(RuntimeError, instance.async_method)
- def test_asynchronous_mockability(self):
+ instance.start()
+ self.assertEqual('async call', instance.async_method())
+ instance.stop()
+
+ async def async_test():
+ instance = Demo()
+ self.assertEqual('async call', await instance.async_method())
+ instance.stop()
+
+ # start has no effect on async users
+
+ instance.start()
+ self.assertEqual('async call', await instance.async_method())
+ instance.stop()
+
+ sync_test()
+ asyncio.run(async_test())
+
+ def test_iteration(self):
"""
- Check that method mocks are respected.
+ Check that we can iterate in both contexts.
"""
- # mock prior to construction
+ def sync_test():
+ instance = Demo()
+ result = []
- with patch('test.unit.util.synchronous.Example.hello', Mock(side_effect = coro_func_returning_value('mocked hello'))):
- instance = Example()
- self.assertEqual('mocked hello', instance.hello())
+ for val in instance:
+ result.append(val)
- self.assertEqual('hello', instance.hello()) # mock should now be reverted
- instance.stop()
+ self.assertEqual([0, 1, 2], result)
+ instance.stop()
- # mock after construction
+ async def async_test():
+ instance = Demo()
+ result = []
- instance = Example()
+ async for val in instance:
+ result.append(val)
- with patch('test.unit.util.synchronous.Example.hello', Mock(side_effect = coro_func_returning_value('mocked hello'))):
- self.assertEqual('mocked hello', instance.hello())
+ self.assertEqual([0, 1, 2], result)
+ instance.stop()
- self.assertEqual('hello', instance.hello())
- instance.stop()
+ sync_test()
+ asyncio.run(async_test())
- def test_synchronous_mockability(self):
+ def test_context_management(self):
"""
- Ensure we do not disrupt non-asynchronous method mocks.
+ Exercise context management via 'with' statements.
"""
- # mock prior to construction
+ def sync_test():
+ instance = Demo()
- with patch('test.unit.util.synchronous.Example.sync_hello', Mock(return_value = 'mocked hello')):
- instance = Example()
- self.assertEqual('mocked hello', instance.sync_hello())
+ self.assertFalse(instance.called_enter)
+ self.assertFalse(instance.called_exit)
+
+ with instance:
+ self.assertTrue(instance.called_enter)
+ self.assertFalse(instance.called_exit)
+
+ self.assertTrue(instance.called_enter)
+ self.assertTrue(instance.called_exit)
+
+ async def async_test():
+ instance = Demo()
+
+ self.assertFalse(instance.called_enter)
+ self.assertFalse(instance.called_exit)
+
+ async with instance:
+ self.assertTrue(instance.called_enter)
+ self.assertFalse(instance.called_exit)
+
+ self.assertTrue(instance.called_enter)
+ self.assertTrue(instance.called_exit)
+
+ sync_test()
+ asyncio.run(async_test())
+
+ def test_mockability(self):
+ """
+ Check that method mocks are respected for both previously constructed
+ instances and those made after the mock.
+ """
+
+ pre_constructed = Demo()
+
+ with patch('test.unit.util.synchronous.Demo.async_method', Mock(side_effect = coro_func_returning_value('mocked call'))):
+ post_constructed = Demo()
+
+ self.assertEqual('mocked call', pre_constructed.async_method())
+ self.assertEqual('mocked call', post_constructed.async_method())
- self.assertEqual('hello', instance.sync_hello()) # mock should now be reverted
- instance.stop()
+ self.assertEqual('async call', pre_constructed.async_method())
+ self.assertEqual('async call', post_constructed.async_method())
- # mock after construction
+ # synchronous methods are unaffected
- instance = Example()
+ with patch('test.unit.util.synchronous.Demo.sync_method', Mock(return_value = 'mocked call')):
+ self.assertEqual('mocked call', pre_constructed.sync_method())
- with patch('test.unit.util.synchronous.Example.sync_hello', Mock(return_value = 'mocked hello')):
- self.assertEqual('mocked hello', instance.sync_hello())
+ self.assertEqual('sync call', pre_constructed.sync_method())
- self.assertEqual('hello', instance.sync_hello())
- instance.stop()
+ pre_constructed.stop()
+ post_constructed.stop()