[tor-commits] [stem/master] Synchronous class mockability

atagar at torproject.org atagar at torproject.org
Thu Jul 16 01:29:00 UTC 2020


commit ef1e41ebce0aa1bb5fde9064410402bee9887451
Author: Damian Johnson <atagar at torproject.org>
Date:   Wed Jul 8 17:12:39 2020 -0700

    Synchronous class mockability
    
    The meta-programming behind our Synchronous class doesn't play well with test
    mocks. This commit handles that, and tests the permutations I can think of.
---
 stem/util/__init__.py         | 49 +++++++++++++++++++++++++++-----------
 test/unit/util/synchronous.py | 55 +++++++++++++++++++++++++++++++++++++++++--
 2 files changed, 88 insertions(+), 16 deletions(-)

diff --git a/stem/util/__init__.py b/stem/util/__init__.py
index 54f90376..15840f56 100644
--- a/stem/util/__init__.py
+++ b/stem/util/__init__.py
@@ -10,9 +10,11 @@ import datetime
 import functools
 import inspect
 import threading
+import unittest.mock
+
 from concurrent.futures import Future
 
-from typing import Any, AsyncIterator, Callable, Iterator, Optional, Type, Union
+from typing import Any, AsyncIterator, Iterator, Optional, Type, Union
 
 __all__ = [
   'conf',
@@ -213,19 +215,11 @@ class Synchronous(object):
 
       # call any coroutines through this loop
 
-      def call_async(func: Callable, *args: Any, **kwargs: Any) -> Any:
-        if Synchronous.is_asyncio_context():
-          return func(*args, **kwargs)
-
-        with self._loop_thread_lock:
-          if not self._loop_thread.is_alive():
-            raise RuntimeError('%s has been stopped' % type(self).__name__)
-
-          return asyncio.run_coroutine_threadsafe(func(*args, **kwargs), self._loop).result()
-
-      for method_name, func in inspect.getmembers(self, predicate = inspect.ismethod):
-        if inspect.iscoroutinefunction(func):
-          setattr(self, method_name, functools.partial(call_async, func))
+      for name, func in inspect.getmembers(self):
+        if isinstance(func, unittest.mock.Mock) and inspect.iscoroutinefunction(func.side_effect):
+          setattr(self, name, functools.partial(self._call_async_method, name))
+        elif inspect.ismethod(func) and inspect.iscoroutinefunction(func):
+          setattr(self, name, functools.partial(self._call_async_method, name))
 
       asyncio.run_coroutine_threadsafe(asyncio.coroutine(self.__ainit__)(), self._loop).result()
 
@@ -312,6 +306,33 @@ class Synchronous(object):
     except RuntimeError:
       return False
 
+  def _call_async_method(self, method_name: str, *args: Any, **kwargs: Any) -> Any:
+    """
+    Run this async method from either a synchronous or asynchronous context.
+
+    :param method_name: name of the method to invoke
+    :param args: positional arguments
+    :param kwargs: keyword arguments
+
+    :returns: method's return value
+
+    :raises: **AttributeError** if this method doesn't exist
+    """
+
+    # Retrieving methods by name (rather than keeping a reference) so runtime
+    # replacements like test mocks work.
+
+    func = getattr(type(self), method_name)
+
+    if Synchronous.is_asyncio_context():
+      return func(self, *args, **kwargs)
+
+    with self._loop_thread_lock:
+      if self._loop_thread and not self._loop_thread.is_alive():
+        raise RuntimeError('%s has been closed' % type(self).__name__)
+
+      return asyncio.run_coroutine_threadsafe(func(self, *args, **kwargs), self._loop).result()
+
   def __iter__(self) -> Iterator:
     async def convert_generator(generator: AsyncIterator) -> Iterator:
       return iter([d async for d in generator])
diff --git a/test/unit/util/synchronous.py b/test/unit/util/synchronous.py
index dd27c3c6..5b38a7b5 100644
--- a/test/unit/util/synchronous.py
+++ b/test/unit/util/synchronous.py
@@ -6,9 +6,10 @@ import asyncio
 import io
 import unittest
 
-from unittest.mock import patch
+from unittest.mock import patch, Mock
 
 from stem.util import Synchronous
+from stem.util.test_tools import coro_func_returning_value
 
 EXAMPLE_OUTPUT = """\
 hello from a synchronous context
@@ -20,6 +21,8 @@ class Example(Synchronous):
   async def hello(self):
     return 'hello'
 
+  def sync_hello(self):
+    return 'hello'
 
 class TestSynchronous(unittest.TestCase):
   @patch('sys.stdout', new_callable = io.StringIO)
@@ -45,7 +48,7 @@ class TestSynchronous(unittest.TestCase):
 
   def test_ainit(self):
     """
-    Check that our constructor runs __ainit__ if present.
+    Check that our constructor runs __ainit__ when present.
     """
 
     class AinitDemo(Synchronous):
@@ -96,3 +99,51 @@ class TestSynchronous(unittest.TestCase):
     instance.start()
     self.assertEqual('hello', instance.hello())
     instance.stop()
+
+  def test_asynchronous_mockability(self):
+    """
+    Check that method mocks are respected.
+    """
+
+    # mock prior to construction
+
+    with patch('test.unit.util.synchronous.Example.hello', Mock(side_effect = coro_func_returning_value('mocked hello'))):
+      instance = Example()
+      self.assertEqual('mocked hello', instance.hello())
+
+    self.assertEqual('hello', instance.hello())  # mock should now be reverted
+    instance.stop()
+
+    # mock after construction
+
+    instance = Example()
+
+    with patch('test.unit.util.synchronous.Example.hello', Mock(side_effect = coro_func_returning_value('mocked hello'))):
+      self.assertEqual('mocked hello', instance.hello())
+
+    self.assertEqual('hello', instance.hello())
+    instance.stop()
+
+  def test_synchronous_mockability(self):
+    """
+    Ensure we do not disrupt non-asynchronous method mocks.
+    """
+
+    # mock prior to construction
+
+    with patch('test.unit.util.synchronous.Example.sync_hello', Mock(return_value = 'mocked hello')):
+      instance = Example()
+      self.assertEqual('mocked hello', instance.sync_hello())
+
+    self.assertEqual('hello', instance.sync_hello())  # mock should now be reverted
+    instance.stop()
+
+    # mock after construction
+
+    instance = Example()
+
+    with patch('test.unit.util.synchronous.Example.sync_hello', Mock(return_value = 'mocked hello')):
+      self.assertEqual('mocked hello', instance.sync_hello())
+
+    self.assertEqual('hello', instance.sync_hello())
+    instance.stop()





More information about the tor-commits mailing list