[tor-commits] [stem/master] Make Synchronous class resumable

atagar at torproject.org atagar at torproject.org
Thu Jul 16 01:29:00 UTC 2020


commit 9f71ce9b21d8c710025a440e69920006e37eeb88
Author: Damian Johnson <atagar at torproject.org>
Date:   Tue Jul 7 18:38:06 2020 -0700

    Make Synchronous class resumable
    
    Our Controller needs to start and stop along with its connect/close methods,
    so this class must be resumable.
---
 stem/util/__init__.py          | 50 +++++++++++++++++++++++++++---------------
 test/settings.cfg              |  2 ++
 test/unit/descriptor/remote.py | 20 ++++++++---------
 test/unit/util/synchronous.py  | 29 +++++++++++++++++-------
 4 files changed, 65 insertions(+), 36 deletions(-)

diff --git a/stem/util/__init__.py b/stem/util/__init__.py
index cddce755..54f90376 100644
--- a/stem/util/__init__.py
+++ b/stem/util/__init__.py
@@ -12,7 +12,7 @@ import inspect
 import threading
 from concurrent.futures import Future
 
-from typing import Any, AsyncIterator, Callable, Iterator, Type, Union
+from typing import Any, AsyncIterator, Callable, Iterator, Optional, Type, Union
 
 __all__ = [
   'conf',
@@ -162,12 +162,12 @@ class Synchronous(object):
     def sync_demo():
       instance = Example()
       print('%s from a synchronous context' % instance.hello())
-      instance.close()
+      instance.stop()
 
     async def async_demo():
       instance = Example()
       print('%s from an asynchronous context' % await instance.hello())
-      instance.close()
+      instance.stop()
 
     sync_demo()
     asyncio.run(async_demo())
@@ -194,35 +194,34 @@ class Synchronous(object):
         # asyncio.get_running_loop(), and construct objects that
         # require it (like asyncio.Queue and asyncio.Lock).
 
-  Users are responsible for calling :func:`~stem.util.Synchronous.close` when
+  Users are responsible for calling :func:`~stem.util.Synchronous.stop` when
   finished to clean up underlying resources.
   """
 
   def __init__(self) -> None:
+    self._loop_thread = None  # type: Optional[threading.Thread]
+    self._loop_thread_lock = threading.RLock()
+
     if Synchronous.is_asyncio_context():
       self._loop = asyncio.get_running_loop()
-      self._loop_thread = None
 
       self.__ainit__()
     else:
       self._loop = asyncio.new_event_loop()
-      self._loop_thread = threading.Thread(
-        name = '%s asyncio' % type(self).__name__,
-        target = self._loop.run_forever,
-        daemon = True,
-      )
 
-      self._loop_thread.start()
+      Synchronous.start(self)
 
       # call any coroutines through this loop
 
       def call_async(func: Callable, *args: Any, **kwargs: Any) -> Any:
         if Synchronous.is_asyncio_context():
           return func(*args, **kwargs)
-        elif not self._loop_thread.is_alive():
-          raise RuntimeError('%s has been closed' % type(self).__name__)
 
-        return asyncio.run_coroutine_threadsafe(func(*args, **kwargs), self._loop).result()
+        with self._loop_thread_lock:
+          if not self._loop_thread.is_alive():
+            raise RuntimeError('%s has been stopped' % type(self).__name__)
+
+          return asyncio.run_coroutine_threadsafe(func(*args, **kwargs), self._loop).result()
 
       for method_name, func in inspect.getmembers(self, predicate = inspect.ismethod):
         if inspect.iscoroutinefunction(func):
@@ -273,16 +272,31 @@ class Synchronous(object):
 
     pass
 
-  def close(self) -> None:
+  def start(self) -> None:
+    """
+    Initiate resources to make this object callable from synchronous contexts.
+    """
+
+    with self._loop_thread_lock:
+      self._loop_thread = threading.Thread(
+        name = '%s asyncio' % type(self).__name__,
+        target = self._loop.run_forever,
+        daemon = True,
+      )
+
+      self._loop_thread.start()
+
+  def stop(self) -> None:
     """
     Terminate resources that permits this from being callable from synchronous
     contexts. Once called any further synchronous invocations will fail with a
     **RuntimeError**.
     """
 
-    if self._loop_thread and self._loop_thread.is_alive():
-      self._loop.call_soon_threadsafe(self._loop.stop)
-      self._loop_thread.join()
+    with self._loop_thread_lock:
+      if self._loop_thread and self._loop_thread.is_alive():
+        self._loop.call_soon_threadsafe(self._loop.stop)
+        self._loop_thread.join()
 
   @staticmethod
   def is_asyncio_context() -> bool:
diff --git a/test/settings.cfg b/test/settings.cfg
index fcef5ec1..70bdd069 100644
--- a/test/settings.cfg
+++ b/test/settings.cfg
@@ -229,6 +229,8 @@ mypy.ignore * => "_IntegerEnum" has no attribute *
 mypy.ignore * => See https://mypy.readthedocs.io/en/latest/common_issues.html*
 mypy.ignore * => *is not valid as a type*
 
+mypy.ignore stem/descriptor/remote.py => Return type "Coroutine[Any, Any, None]" of "start" *
+
 # Metaprogramming prevents mypy from determining descriptor attributes.
 
 mypy.ignore * => "Descriptor" has no attribute "*
diff --git a/test/unit/descriptor/remote.py b/test/unit/descriptor/remote.py
index 1fd2aaf9..bb6f554c 100644
--- a/test/unit/descriptor/remote.py
+++ b/test/unit/descriptor/remote.py
@@ -100,7 +100,7 @@ class TestDescriptorDownloader(unittest.TestCase):
     self.assertEqual('9695DFC35FFEB861329B9F1AB04C46397020CE31', desc.fingerprint)
     self.assertEqual(TEST_DESCRIPTOR, desc.get_bytes())
 
-    reply.close()
+    reply.stop()
 
   def test_response_header_code(self):
     """
@@ -150,13 +150,13 @@ class TestDescriptorDownloader(unittest.TestCase):
     descriptors = list(query)
     self.assertEqual(1, len(descriptors))
     self.assertEqual('moria1', descriptors[0].nickname)
-    query.close()
+    query.stop()
 
   def test_gzip_url_override(self):
     query = stem.descriptor.remote.Query(TEST_RESOURCE + '.z', compression = Compression.PLAINTEXT, start = False)
     self.assertEqual([stem.descriptor.Compression.GZIP], query.compression)
     self.assertEqual(TEST_RESOURCE, query.resource)
-    query.close()
+    query.stop()
 
   @mock_download(read_resource('compressed_identity'), encoding = 'identity')
   def test_compression_plaintext(self):
@@ -172,7 +172,7 @@ class TestDescriptorDownloader(unittest.TestCase):
     )
 
     descriptors = list(query)
-    query.close()
+    query.stop()
 
     self.assertEqual(1, len(descriptors))
     self.assertEqual('moria1', descriptors[0].nickname)
@@ -191,7 +191,7 @@ class TestDescriptorDownloader(unittest.TestCase):
     )
 
     descriptors = list(query)
-    query.close()
+    query.stop()
 
     self.assertEqual(1, len(descriptors))
     self.assertEqual('moria1', descriptors[0].nickname)
@@ -212,7 +212,7 @@ class TestDescriptorDownloader(unittest.TestCase):
     )
 
     descriptors = list(query)
-    query.close()
+    query.stop()
 
     self.assertEqual(1, len(descriptors))
     self.assertEqual('moria1', descriptors[0].nickname)
@@ -233,7 +233,7 @@ class TestDescriptorDownloader(unittest.TestCase):
     )
 
     descriptors = list(query)
-    query.close()
+    query.stop()
 
     self.assertEqual(1, len(descriptors))
     self.assertEqual('moria1', descriptors[0].nickname)
@@ -258,7 +258,7 @@ class TestDescriptorDownloader(unittest.TestCase):
     queries.append(downloader.get_detached_signatures())
 
     for query in queries:
-      query.close()
+      query.stop()
 
   @mock_download(b'some malformed stuff')
   def test_malformed_content(self):
@@ -285,7 +285,7 @@ class TestDescriptorDownloader(unittest.TestCase):
 
     self.assertRaises(ValueError, query.run)
 
-    query.close()
+    query.stop()
 
   def test_query_with_invalid_endpoints(self):
     invalid_endpoints = {
@@ -316,4 +316,4 @@ class TestDescriptorDownloader(unittest.TestCase):
     self.assertEqual(1, len(list(query)))
     self.assertEqual(1, len(list(query)))
 
-    query.close()
+    query.stop()
diff --git a/test/unit/util/synchronous.py b/test/unit/util/synchronous.py
index 22271ffd..dd27c3c6 100644
--- a/test/unit/util/synchronous.py
+++ b/test/unit/util/synchronous.py
@@ -31,12 +31,12 @@ class TestSynchronous(unittest.TestCase):
     def sync_demo():
       instance = Example()
       print('%s from a synchronous context' % instance.hello())
-      instance.close()
+      instance.stop()
 
     async def async_demo():
       instance = Example()
       print('%s from an asynchronous context' % await instance.hello())
-      instance.close()
+      instance.stop()
 
     sync_demo()
     asyncio.run(async_demo())
@@ -66,20 +66,33 @@ class TestSynchronous(unittest.TestCase):
     sync_demo()
     asyncio.run(async_demo())
 
-  def test_after_close(self):
+  def test_after_stop(self):
     """
-    Check that closed instances raise a RuntimeError to synchronous callers.
+    Check that stopped instances raise a RuntimeError to synchronous callers.
     """
 
-    # close a used instance
+    # stop a used instance
 
     instance = Example()
     self.assertEqual('hello', instance.hello())
-    instance.close()
+    instance.stop()
     self.assertRaises(RuntimeError, instance.hello)
 
-    # close an unused instance
+    # stop an unused instance
 
     instance = Example()
-    instance.close()
+    instance.stop()
     self.assertRaises(RuntimeError, instance.hello)
+
+  def test_resuming(self):
+    """
+    Resume a previously stopped instance.
+    """
+
+    instance = Example()
+    self.assertEqual('hello', instance.hello())
+    instance.stop()
+    self.assertRaises(RuntimeError, instance.hello)
+    instance.start()
+    self.assertEqual('hello', instance.hello())
+    instance.stop()





More information about the tor-commits mailing list