commit 9f71ce9b21d8c710025a440e69920006e37eeb88 Author: Damian Johnson <atagar@torproject.org> Date: Tue Jul 7 18:38:06 2020 -0700
Make Synchronous class resumable
Our Controller needs to start and stop along with its connect/close methods, so this class needs to be resumable. --- stem/util/__init__.py | 50 +++++++++++++++++++++++++++--------------- test/settings.cfg | 2 ++ test/unit/descriptor/remote.py | 20 ++++++++--------- test/unit/util/synchronous.py | 29 +++++++++++++++++------- 4 files changed, 65 insertions(+), 36 deletions(-)
diff --git a/stem/util/__init__.py b/stem/util/__init__.py index cddce755..54f90376 100644 --- a/stem/util/__init__.py +++ b/stem/util/__init__.py @@ -12,7 +12,7 @@ import inspect import threading from concurrent.futures import Future
-from typing import Any, AsyncIterator, Callable, Iterator, Type, Union +from typing import Any, AsyncIterator, Callable, Iterator, Optional, Type, Union
__all__ = [ 'conf', @@ -162,12 +162,12 @@ class Synchronous(object): def sync_demo(): instance = Example() print('%s from a synchronous context' % instance.hello()) - instance.close() + instance.stop()
async def async_demo(): instance = Example() print('%s from an asynchronous context' % await instance.hello()) - instance.close() + instance.stop()
sync_demo() asyncio.run(async_demo()) @@ -194,35 +194,34 @@ class Synchronous(object): # asyncio.get_running_loop(), and construct objects that # require it (like asyncio.Queue and asyncio.Lock).
- Users are responsible for calling :func:`~stem.util.Synchronous.close` when + Users are responsible for calling :func:`~stem.util.Synchronous.stop` when finished to clean up underlying resources. """
def __init__(self) -> None: + self._loop_thread = None # type: Optional[threading.Thread] + self._loop_thread_lock = threading.RLock() + if Synchronous.is_asyncio_context(): self._loop = asyncio.get_running_loop() - self._loop_thread = None
self.__ainit__() else: self._loop = asyncio.new_event_loop() - self._loop_thread = threading.Thread( - name = '%s asyncio' % type(self).__name__, - target = self._loop.run_forever, - daemon = True, - )
- self._loop_thread.start() + Synchronous.start(self)
# call any coroutines through this loop
def call_async(func: Callable, *args: Any, **kwargs: Any) -> Any: if Synchronous.is_asyncio_context(): return func(*args, **kwargs) - elif not self._loop_thread.is_alive(): - raise RuntimeError('%s has been closed' % type(self).__name__)
- return asyncio.run_coroutine_threadsafe(func(*args, **kwargs), self._loop).result() + with self._loop_thread_lock: + if not self._loop_thread.is_alive(): + raise RuntimeError('%s has been stopped' % type(self).__name__) + + return asyncio.run_coroutine_threadsafe(func(*args, **kwargs), self._loop).result()
for method_name, func in inspect.getmembers(self, predicate = inspect.ismethod): if inspect.iscoroutinefunction(func): @@ -273,16 +272,31 @@ class Synchronous(object):
pass
- def close(self) -> None: + def start(self) -> None: + """ + Initiate resources to make this object callable from synchronous contexts. + """ + + with self._loop_thread_lock: + self._loop_thread = threading.Thread( + name = '%s asyncio' % type(self).__name__, + target = self._loop.run_forever, + daemon = True, + ) + + self._loop_thread.start() + + def stop(self) -> None: """ Terminate resources that permits this from being callable from synchronous contexts. Once called any further synchronous invocations will fail with a **RuntimeError**. """
- if self._loop_thread and self._loop_thread.is_alive(): - self._loop.call_soon_threadsafe(self._loop.stop) - self._loop_thread.join() + with self._loop_thread_lock: + if self._loop_thread and self._loop_thread.is_alive(): + self._loop.call_soon_threadsafe(self._loop.stop) + self._loop_thread.join()
@staticmethod def is_asyncio_context() -> bool: diff --git a/test/settings.cfg b/test/settings.cfg index fcef5ec1..70bdd069 100644 --- a/test/settings.cfg +++ b/test/settings.cfg @@ -229,6 +229,8 @@ mypy.ignore * => "_IntegerEnum" has no attribute * mypy.ignore * => See https://mypy.readthedocs.io/en/latest/common_issues.html* mypy.ignore * => *is not valid as a type*
+mypy.ignore stem/descriptor/remote.py => Return type "Coroutine[Any, Any, None]" of "start" * + # Metaprogramming prevents mypy from determining descriptor attributes.
mypy.ignore * => "Descriptor" has no attribute "* diff --git a/test/unit/descriptor/remote.py b/test/unit/descriptor/remote.py index 1fd2aaf9..bb6f554c 100644 --- a/test/unit/descriptor/remote.py +++ b/test/unit/descriptor/remote.py @@ -100,7 +100,7 @@ class TestDescriptorDownloader(unittest.TestCase): self.assertEqual('9695DFC35FFEB861329B9F1AB04C46397020CE31', desc.fingerprint) self.assertEqual(TEST_DESCRIPTOR, desc.get_bytes())
- reply.close() + reply.stop()
def test_response_header_code(self): """ @@ -150,13 +150,13 @@ class TestDescriptorDownloader(unittest.TestCase): descriptors = list(query) self.assertEqual(1, len(descriptors)) self.assertEqual('moria1', descriptors[0].nickname) - query.close() + query.stop()
def test_gzip_url_override(self): query = stem.descriptor.remote.Query(TEST_RESOURCE + '.z', compression = Compression.PLAINTEXT, start = False) self.assertEqual([stem.descriptor.Compression.GZIP], query.compression) self.assertEqual(TEST_RESOURCE, query.resource) - query.close() + query.stop()
@mock_download(read_resource('compressed_identity'), encoding = 'identity') def test_compression_plaintext(self): @@ -172,7 +172,7 @@ class TestDescriptorDownloader(unittest.TestCase): )
descriptors = list(query) - query.close() + query.stop()
self.assertEqual(1, len(descriptors)) self.assertEqual('moria1', descriptors[0].nickname) @@ -191,7 +191,7 @@ class TestDescriptorDownloader(unittest.TestCase): )
descriptors = list(query) - query.close() + query.stop()
self.assertEqual(1, len(descriptors)) self.assertEqual('moria1', descriptors[0].nickname) @@ -212,7 +212,7 @@ class TestDescriptorDownloader(unittest.TestCase): )
descriptors = list(query) - query.close() + query.stop()
self.assertEqual(1, len(descriptors)) self.assertEqual('moria1', descriptors[0].nickname) @@ -233,7 +233,7 @@ class TestDescriptorDownloader(unittest.TestCase): )
descriptors = list(query) - query.close() + query.stop()
self.assertEqual(1, len(descriptors)) self.assertEqual('moria1', descriptors[0].nickname) @@ -258,7 +258,7 @@ class TestDescriptorDownloader(unittest.TestCase): queries.append(downloader.get_detached_signatures())
for query in queries: - query.close() + query.stop()
@mock_download(b'some malformed stuff') def test_malformed_content(self): @@ -285,7 +285,7 @@ class TestDescriptorDownloader(unittest.TestCase):
self.assertRaises(ValueError, query.run)
- query.close() + query.stop()
def test_query_with_invalid_endpoints(self): invalid_endpoints = { @@ -316,4 +316,4 @@ class TestDescriptorDownloader(unittest.TestCase): self.assertEqual(1, len(list(query))) self.assertEqual(1, len(list(query)))
- query.close() + query.stop() diff --git a/test/unit/util/synchronous.py b/test/unit/util/synchronous.py index 22271ffd..dd27c3c6 100644 --- a/test/unit/util/synchronous.py +++ b/test/unit/util/synchronous.py @@ -31,12 +31,12 @@ class TestSynchronous(unittest.TestCase): def sync_demo(): instance = Example() print('%s from a synchronous context' % instance.hello()) - instance.close() + instance.stop()
async def async_demo(): instance = Example() print('%s from an asynchronous context' % await instance.hello()) - instance.close() + instance.stop()
sync_demo() asyncio.run(async_demo()) @@ -66,20 +66,33 @@ class TestSynchronous(unittest.TestCase): sync_demo() asyncio.run(async_demo())
- def test_after_close(self): + def test_after_stop(self): """ - Check that closed instances raise a RuntimeError to synchronous callers. + Check that stopped instances raise a RuntimeError to synchronous callers. """
- # close a used instance + # stop a used instance
instance = Example() self.assertEqual('hello', instance.hello()) - instance.close() + instance.stop() self.assertRaises(RuntimeError, instance.hello)
- # close an unused instance + # stop an unused instance
instance = Example() - instance.close() + instance.stop() self.assertRaises(RuntimeError, instance.hello) + + def test_resuming(self): + """ + Resume a previously stopped instance. + """ + + instance = Example() + self.assertEqual('hello', instance.hello()) + instance.stop() + self.assertRaises(RuntimeError, instance.hello) + instance.start() + self.assertEqual('hello', instance.hello()) + instance.stop()
tor-commits@lists.torproject.org