commit ef1e41ebce0aa1bb5fde9064410402bee9887451 Author: Damian Johnson <atagar@torproject.org> Date: Wed Jul 8 17:12:39 2020 -0700
Synchronous class mockability
The meta-programming behind our Synchronous class doesn't play well with test mocks. This change handles that, and adds tests for every permutation I can think of. --- stem/util/__init__.py | 49 +++++++++++++++++++++++++++----------- test/unit/util/synchronous.py | 55 +++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 88 insertions(+), 16 deletions(-)
diff --git a/stem/util/__init__.py b/stem/util/__init__.py index 54f90376..15840f56 100644 --- a/stem/util/__init__.py +++ b/stem/util/__init__.py @@ -10,9 +10,11 @@ import datetime import functools import inspect import threading +import unittest.mock + from concurrent.futures import Future
-from typing import Any, AsyncIterator, Callable, Iterator, Optional, Type, Union +from typing import Any, AsyncIterator, Iterator, Optional, Type, Union
__all__ = [ 'conf', @@ -213,19 +215,11 @@ class Synchronous(object):
# call any coroutines through this loop
- def call_async(func: Callable, *args: Any, **kwargs: Any) -> Any: - if Synchronous.is_asyncio_context(): - return func(*args, **kwargs) - - with self._loop_thread_lock: - if not self._loop_thread.is_alive(): - raise RuntimeError('%s has been stopped' % type(self).__name__) - - return asyncio.run_coroutine_threadsafe(func(*args, **kwargs), self._loop).result() - - for method_name, func in inspect.getmembers(self, predicate = inspect.ismethod): - if inspect.iscoroutinefunction(func): - setattr(self, method_name, functools.partial(call_async, func)) + for name, func in inspect.getmembers(self): + if isinstance(func, unittest.mock.Mock) and inspect.iscoroutinefunction(func.side_effect): + setattr(self, name, functools.partial(self._call_async_method, name)) + elif inspect.ismethod(func) and inspect.iscoroutinefunction(func): + setattr(self, name, functools.partial(self._call_async_method, name))
asyncio.run_coroutine_threadsafe(asyncio.coroutine(self.__ainit__)(), self._loop).result()
@@ -312,6 +306,33 @@ class Synchronous(object): except RuntimeError: return False
+ def _call_async_method(self, method_name: str, *args: Any, **kwargs: Any) -> Any: + """ + Run this async method from either a synchronous or asynchronous context. + + :param method_name: name of the method to invoke + :param args: positional arguments + :param kwargs: keyword arguments + + :returns: method's return value + + :raises: **AttributeError** if this method doesn't exist + """ + + # Retrieving methods by name (rather than keeping a reference) so runtime + # replacements like test mocks work. + + func = getattr(type(self), method_name) + + if Synchronous.is_asyncio_context(): + return func(self, *args, **kwargs) + + with self._loop_thread_lock: + if self._loop_thread and not self._loop_thread.is_alive(): + raise RuntimeError('%s has been closed' % type(self).__name__) + + return asyncio.run_coroutine_threadsafe(func(self, *args, **kwargs), self._loop).result() + def __iter__(self) -> Iterator: async def convert_generator(generator: AsyncIterator) -> Iterator: return iter([d async for d in generator]) diff --git a/test/unit/util/synchronous.py b/test/unit/util/synchronous.py index dd27c3c6..5b38a7b5 100644 --- a/test/unit/util/synchronous.py +++ b/test/unit/util/synchronous.py @@ -6,9 +6,10 @@ import asyncio import io import unittest
-from unittest.mock import patch +from unittest.mock import patch, Mock
from stem.util import Synchronous +from stem.util.test_tools import coro_func_returning_value
EXAMPLE_OUTPUT = """\ hello from a synchronous context @@ -20,6 +21,8 @@ class Example(Synchronous): async def hello(self): return 'hello'
+ def sync_hello(self): + return 'hello'
class TestSynchronous(unittest.TestCase): @patch('sys.stdout', new_callable = io.StringIO) @@ -45,7 +48,7 @@ class TestSynchronous(unittest.TestCase):
def test_ainit(self): """ - Check that our constructor runs __ainit__ if present. + Check that our constructor runs __ainit__ when present. """
class AinitDemo(Synchronous): @@ -96,3 +99,51 @@ class TestSynchronous(unittest.TestCase): instance.start() self.assertEqual('hello', instance.hello()) instance.stop() + + def test_asynchronous_mockability(self): + """ + Check that method mocks are respected. + """ + + # mock prior to construction + + with patch('test.unit.util.synchronous.Example.hello', Mock(side_effect = coro_func_returning_value('mocked hello'))): + instance = Example() + self.assertEqual('mocked hello', instance.hello()) + + self.assertEqual('hello', instance.hello()) # mock should now be reverted + instance.stop() + + # mock after construction + + instance = Example() + + with patch('test.unit.util.synchronous.Example.hello', Mock(side_effect = coro_func_returning_value('mocked hello'))): + self.assertEqual('mocked hello', instance.hello()) + + self.assertEqual('hello', instance.hello()) + instance.stop() + + def test_synchronous_mockability(self): + """ + Ensure we do not disrupt non-asynchronous method mocks. + """ + + # mock prior to construction + + with patch('test.unit.util.synchronous.Example.sync_hello', Mock(return_value = 'mocked hello')): + instance = Example() + self.assertEqual('mocked hello', instance.sync_hello()) + + self.assertEqual('hello', instance.sync_hello()) # mock should now be reverted + instance.stop() + + # mock after construction + + instance = Example() + + with patch('test.unit.util.synchronous.Example.sync_hello', Mock(return_value = 'mocked hello')): + self.assertEqual('mocked hello', instance.sync_hello()) + + self.assertEqual('hello', instance.sync_hello()) + instance.stop()
tor-commits@lists.torproject.org