[tor-commits] [stem/master] Constructor method with an async context

atagar at torproject.org atagar at torproject.org
Thu Jul 16 01:29:00 UTC 2020


commit 5406561385cb2e5274ddda732b6e1d6fb014a2c5
Author: Damian Johnson <atagar at torproject.org>
Date:   Sun Jun 28 14:33:40 2020 -0700

    Constructor method with an async context
    
    Many asyncio classes can only be constructed within a running loop. We can't
    presume that our __init__() runs within one, so this adds an __ainit__()
    method that does.
---
 stem/util/__init__.py         | 95 +++++++++++++++++++++++++++----------------
 test/unit/util/synchronous.py | 31 ++++++++++++++
 2 files changed, 92 insertions(+), 34 deletions(-)

diff --git a/stem/util/__init__.py b/stem/util/__init__.py
index e8ef361e..5a8f95e0 100644
--- a/stem/util/__init__.py
+++ b/stem/util/__init__.py
@@ -150,7 +150,7 @@ def _hash_attr(obj: Any, *attributes: str, **kwargs: Any) -> int:
 
 class Synchronous(object):
   """
-  Mixin that lets a class be called from both synchronous and asynchronous
+  Mixin that lets a class run within both synchronous and asynchronous
   contexts.
 
   ::
@@ -172,39 +172,70 @@ class Synchronous(object):
     sync_demo()
     asyncio.run(async_demo())
 
+  Our async methods always run within a loop. For asyncio users this class has
+  no effect, but otherwise we transparently create an async context to run
+  within.
+
+  Class initialization and any non-async methods should assume they're running
+  within a synchronous context. If our class supplies an **__ainit__()**
+  method it is invoked within our loop during initialization...
+
+  ::
+
+    class Example(Synchronous):
+      def __init__(self):
+        super(Example, self).__init__()
+
+        # Synchronous part of our initialization. Avoid anything
+        # that must run within an asyncio loop.
+
+      def __ainit__(self):
+        # Asynchronous part of our initialization. You can call
+        # asyncio.get_running_loop(), and construct objects that
+        # require it (like asyncio.Queue and asyncio.Lock).
+
   Users are responsible for calling :func:`~stem.util.Synchronous.close` when
   finished to clean up underlying resources.
   """
 
   def __init__(self) -> None:
-    self._loop = asyncio.new_event_loop()
-    self._loop_lock = threading.RLock()
-    self._loop_thread = threading.Thread(
-      name = '%s asyncio' % type(self).__name__,
-      target = self._loop.run_forever,
-      daemon = True,
-    )
+    ainit_func = getattr(self, '__ainit__', None)
+
+    if Synchronous.is_asyncio_context():
+      self._loop = asyncio.get_running_loop()
+      self._loop_thread = None
+
+      if ainit_func:
+        ainit_func()
+    else:
+      self._loop = asyncio.new_event_loop()
+      self._loop_thread = threading.Thread(
+        name = '%s asyncio' % type(self).__name__,
+        target = self._loop.run_forever,
+        daemon = True,
+      )
 
-    self._is_closed = False
+      self._loop_thread.start()
 
-    # overwrite asynchronous class methods with instance methods that can be
-    # called from either context
+      # call any coroutines through this loop
 
-    def wrap(func: Callable, *args: Any, **kwargs: Any) -> Any:
-      if Synchronous.is_asyncio_context():
-        return func(*args, **kwargs)
-      else:
-        with self._loop_lock:
-          if self._is_closed:
-            raise RuntimeError('%s has been closed' % type(self).__name__)
-          elif not self._loop_thread.is_alive():
-            self._loop_thread.start()
+      def call_async(func: Callable, *args: Any, **kwargs: Any) -> Any:
+        if Synchronous.is_asyncio_context():
+          return func(*args, **kwargs)
+        elif not self._loop_thread.is_alive():
+          raise RuntimeError('%s has been closed' % type(self).__name__)
 
-          return asyncio.run_coroutine_threadsafe(func(*args, **kwargs), self._loop).result()
+        return asyncio.run_coroutine_threadsafe(func(*args, **kwargs), self._loop).result()
 
-    for method_name, func in inspect.getmembers(self, predicate = inspect.ismethod):
-      if inspect.iscoroutinefunction(func):
-        setattr(self, method_name, functools.partial(wrap, func))
+      for method_name, func in inspect.getmembers(self, predicate = inspect.ismethod):
+        if inspect.iscoroutinefunction(func):
+          setattr(self, method_name, functools.partial(call_async, func))
+
+      if ainit_func:
+        async def call_ainit():
+          ainit_func()
+
+        asyncio.run_coroutine_threadsafe(call_ainit(), self._loop).result()
 
   def close(self) -> None:
     """
@@ -213,12 +244,9 @@ class Synchronous(object):
     **RuntimeError**.
     """
 
-    with self._loop_lock:
-      if self._loop_thread.is_alive():
-        self._loop.call_soon_threadsafe(self._loop.stop)
-        self._loop_thread.join()
-
-      self._is_closed = True
+    if self._loop_thread and self._loop_thread.is_alive():
+      self._loop.call_soon_threadsafe(self._loop.stop)
+      self._loop_thread.join()
 
   @staticmethod
   def is_asyncio_context() -> bool:
@@ -235,14 +263,13 @@ class Synchronous(object):
       return False
 
   def __iter__(self) -> Iterator:
-    async def convert_async_generator(generator: AsyncIterator) -> Iterator:
+    async def convert_generator(generator: AsyncIterator) -> Iterator:
       return iter([d async for d in generator])
 
-    iter_func = getattr(self, '__aiter__')
+    iter_func = getattr(self, '__aiter__', None)
 
     if iter_func:
-      with self._loop_lock:
-        return asyncio.run_coroutine_threadsafe(convert_async_generator(iter_func()), self._loop).result()
+      return asyncio.run_coroutine_threadsafe(convert_generator(iter_func()), self._loop).result()
     else:
       raise TypeError("'%s' object is not iterable" % type(self).__name__)
 
diff --git a/test/unit/util/synchronous.py b/test/unit/util/synchronous.py
index 26dad98d..22271ffd 100644
--- a/test/unit/util/synchronous.py
+++ b/test/unit/util/synchronous.py
@@ -24,6 +24,10 @@ class Example(Synchronous):
 class TestSynchronous(unittest.TestCase):
   @patch('sys.stdout', new_callable = io.StringIO)
   def test_example(self, stdout_mock):
+    """
+    Run the example from our pydoc.
+    """
+
     def sync_demo():
       instance = Example()
       print('%s from a synchronous context' % instance.hello())
@@ -39,7 +43,34 @@ class TestSynchronous(unittest.TestCase):
 
     self.assertEqual(EXAMPLE_OUTPUT, stdout_mock.getvalue())
 
+  def test_ainit(self):
+    """
+    Check that our constructor runs __ainit__ if present.
+    """
+
+    class AinitDemo(Synchronous):
+      def __init__(self):
+        super(AinitDemo, self).__init__()
+
+      def __ainit__(self):
+        self.ainit_loop = asyncio.get_running_loop()
+
+    def sync_demo():
+      instance = AinitDemo()
+      self.assertTrue(hasattr(instance, 'ainit_loop'))
+
+    async def async_demo():
+      instance = AinitDemo()
+      self.assertTrue(hasattr(instance, 'ainit_loop'))
+
+    sync_demo()
+    asyncio.run(async_demo())
+
   def test_after_close(self):
+    """
+    Check that closed instances raise a RuntimeError to synchronous callers.
+    """
+
     # close a used instance
 
     instance = Example()





More information about the tor-commits mailing list