commit 7ce8a5e090fc95bfb874299d61c824638d5242f4
Author: Damian Johnson <atagar@torproject.org>
Date:   Sat Nov 7 17:18:53 2020 -0800
Remove Query's Synchronous usage
First step to remove our asyncio metaprogramming...
https://github.com/torproject/stem/issues/77
Our Query class now provides a run method for synchronous users, and run_async for asyncio. This also adds a stop method that can cancel our download.
---
 stem/descriptor/remote.py      | 114 ++++++++++++++++++++++++++++++++---------
 test/unit/descriptor/remote.py |  53 ++++++++++++++-----
 2 files changed, 130 insertions(+), 37 deletions(-)
diff --git a/stem/descriptor/remote.py b/stem/descriptor/remote.py index 136b9d15..50b3065c 100644 --- a/stem/descriptor/remote.py +++ b/stem/descriptor/remote.py @@ -100,8 +100,7 @@ import stem.util.tor_tools
from stem.descriptor import Compression from stem.util import log, str_tools -from stem.util.asyncio import Synchronous -from typing import Any, AsyncIterator, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Sequence, Tuple, Union
# Tor has a limited number of descriptors we can fetch explicitly by their # fingerprint or hashes due to a limit on the url length by squid proxies. @@ -227,7 +226,7 @@ def get_detached_signatures(**query_args: Any) -> 'stem.descriptor.remote.Query' return get_instance().get_detached_signatures(**query_args)
-class Query(Synchronous): +class Query(object): """ Asynchronous request for descriptor content from a directory authority or mirror. These can either be made through the @@ -369,7 +368,6 @@ class Query(Synchronous): super(Query, self).__init__()
if not resource.startswith('/'): - self.stop() raise ValueError("Resources should start with a '/': %s" % resource)
if resource.endswith('.z'): @@ -380,7 +378,6 @@ class Query(Synchronous): elif isinstance(compression, stem.descriptor._Compression): compression = [compression] # caller provided only a single option else: - self.stop() raise ValueError('Compression should be a list of stem.descriptor.Compression, was %s (%s)' % (compression, type(compression).__name__))
if Compression.ZSTD in compression and not Compression.ZSTD.available: @@ -404,7 +401,6 @@ class Query(Synchronous): if isinstance(endpoint, (stem.ORPort, stem.DirPort)): self.endpoints.append(endpoint) else: - self.stop() raise ValueError("Endpoints must be an stem.ORPort or stem.DirPort. '%s' is a %s." % (endpoint, type(endpoint).__name__))
self.resource = resource @@ -428,6 +424,12 @@ class Query(Synchronous): self._downloader_task = None # type: Optional[asyncio.Task] self._downloader_lock = threading.RLock()
+ # background thread if outside an asyncio context + + self._loop = None # type: Optional[asyncio.AbstractEventLoop] + self._loop_thread = None # type: Optional[threading.Thread] + self._loop_lock = threading.RLock() + if start: self.start()
@@ -441,9 +443,38 @@ class Query(Synchronous):
with self._downloader_lock: if self._downloader_task is None: - self._downloader_task = self._loop.create_task(Query._download_descriptors(self, self.retries, self.timeout)) + with self._loop_lock: + if self._loop is None: + try: + self._loop = asyncio.get_running_loop() + except RuntimeError: + self._loop = asyncio.new_event_loop() + self._loop_thread = threading.Thread( + name = 'stem.descriptor.remote query', + target = self._loop.run_forever, + daemon = True, + ) + + self._loop_thread.start() + + self._downloader_task = self._loop.create_task(self._download_descriptors(self.retries, self.timeout)) + + def stop(self) -> None: + """ + Aborts our download if it's in progress, and cleans up underlying + resources. + """
- async def run(self, suppress: bool = False) -> List['stem.descriptor.Descriptor']: + with self._downloader_lock: + if self._downloader_task and not self._downloader_task.done(): + self._downloader_task.cancel() + + with self._loop_lock: + if self._loop_thread and self._loop_thread.is_alive(): + self._loop.call_soon_threadsafe(self._loop.stop) + self._loop_thread.join() + + def run(self, suppress: bool = False) -> List['stem.descriptor.Descriptor']: """ Blocks until our request is complete then provides the descriptors. If we haven't yet started our request then this does so. @@ -461,12 +492,43 @@ class Query(Synchronous): * :class:`~stem.DownloadFailed` if our request fails """
- try: - return [desc async for desc in self._run(suppress)] - finally: - self.stop() + if not self.downloaded and not self.error: + with self._loop_lock: + if self._loop is None: + self.start() + + async def run_wrapper(): + return [desc async for desc in self.run_async(suppress = True)] + + asyncio.run_coroutine_threadsafe(run_wrapper(), self._loop).result() + + self.stop() + + if self.error: + if suppress: + return [] + + raise self.error + else: + return list(self.downloaded) + + async def run_async(self, suppress: bool) -> AsyncIterator[stem.descriptor.Descriptor]: + """ + Asynchronous counterpart of :func:`stem.descriptor.remote.Query.run` + + :param suppress: avoids raising exceptions if **True** + + :returns: iterator for the requested :class:`~stem.descriptor.__init__.Descriptor` instances + + :raises: + Using the iterator can fail with the following if **suppress** is + **False**... + + * **ValueError** if the descriptor contents is malformed + * :class:`~stem.DownloadTimeout` if our request timed out + * :class:`~stem.DownloadFailed` if our request fails + """
- async def _run(self, suppress: bool) -> AsyncIterator[stem.descriptor.Descriptor]: with self._downloader_lock: if not self.downloaded and not self.error: if not self._downloader_task: @@ -477,17 +539,21 @@ class Query(Synchronous): except Exception as exc: self.error = exc
- if self.error: - if suppress: - return + if self.error: + if suppress: + return
- raise self.error - else: - for desc in self.downloaded: - yield desc + raise self.error + else: + for desc in self.downloaded: + yield desc + + def __iter__(self) -> Iterator[stem.descriptor.Descriptor]: + for desc in self.run(True): + yield desc
async def __aiter__(self) -> AsyncIterator[stem.descriptor.Descriptor]: - async for desc in self._run(True): + async for desc in self.run_async(True): yield desc
def _pick_endpoint(self, use_authority: bool = False) -> stem.Endpoint: @@ -620,7 +686,7 @@ class DescriptorDownloader(object): directories = [auth for auth in stem.directory.Authority.from_cache().values() if auth.nickname not in DIR_PORT_BLACKLIST] new_endpoints = set([stem.DirPort(directory.address, directory.dir_port) for directory in directories])
- consensus = list(self.get_consensus(document_handler = stem.descriptor.DocumentHandler.DOCUMENT).run())[0] # type: ignore + consensus = list(self.get_consensus(document_handler = stem.descriptor.DocumentHandler.DOCUMENT).run())[0]
for desc in consensus.routers.values(): if stem.Flag.V2DIR in desc.flags and desc.dir_port: @@ -630,7 +696,7 @@ class DescriptorDownloader(object):
self._endpoints = list(new_endpoints)
- return consensus + return consensus # type: ignore
def their_server_descriptor(self, **query_args: Any) -> 'stem.descriptor.remote.Query': """ @@ -785,7 +851,7 @@ class DescriptorDownloader(object): # authority key certificates
if consensus_query.validate and consensus_query.document_handler == stem.descriptor.DocumentHandler.DOCUMENT: - consensus = list(consensus_query.run())[0] # type: ignore + consensus = list(consensus_query.run())[0] key_certs = self.get_key_certificates(**query_args).run()
try: diff --git a/test/unit/descriptor/remote.py b/test/unit/descriptor/remote.py index 58c7276a..8635d6bd 100644 --- a/test/unit/descriptor/remote.py +++ b/test/unit/descriptor/remote.py @@ -2,6 +2,7 @@ Unit tests for stem.descriptor.remote. """
+import time import unittest
import stem @@ -87,12 +88,50 @@ class TestDescriptorDownloader(unittest.TestCase):
query = stem.descriptor.remote.get_server_descriptors('9695DFC35FFEB861329B9F1AB04C46397020CE31', start = False) self.assertTrue(query._downloader_task is None) - query.stop()
query = stem.descriptor.remote.get_server_descriptors('9695DFC35FFEB861329B9F1AB04C46397020CE31', start = True) self.assertTrue(query._downloader_task is not None) query.stop()
+ def test_stop(self): + """ + Stop a complete, in-process, and unstarted query. + """ + + # stop a completed query + + with mock_download(TEST_DESCRIPTOR): + query = stem.descriptor.remote.get_server_descriptors('9695DFC35FFEB861329B9F1AB04C46397020CE31') + self.assertTrue(query._loop_thread.is_alive()) + + query.run() # complete the query + self.assertFalse(query._loop_thread.is_alive()) + self.assertFalse(query._downloader_task.cancelled()) + + query.stop() # nothing to do + self.assertFalse(query._loop_thread.is_alive()) + self.assertFalse(query._downloader_task.cancelled()) + + # stop an in-process query + + def pause(*args): + time.sleep(5) + + with patch('stem.descriptor.remote.Query._download_from', Mock(side_effect = pause)): + query = stem.descriptor.remote.get_server_descriptors('9695DFC35FFEB861329B9F1AB04C46397020CE31') + + query.stop() # terminates in-process query + self.assertFalse(query._loop_thread.is_alive()) + self.assertTrue(query._downloader_task.cancelled()) + + # stop an unstarted query + + query = stem.descriptor.remote.get_server_descriptors('9695DFC35FFEB861329B9F1AB04C46397020CE31', start = False) + + query.stop() # nothing to do + self.assertTrue(query._loop_thread is None) + self.assertTrue(query._downloader_task is None) + @mock_download(TEST_DESCRIPTOR) def test_download(self): """ @@ -115,8 +154,6 @@ class TestDescriptorDownloader(unittest.TestCase): self.assertEqual('9695DFC35FFEB861329B9F1AB04C46397020CE31', desc.fingerprint) self.assertEqual(TEST_DESCRIPTOR.rstrip(), desc.get_bytes())
- reply.stop() - def test_response_header_code(self): """ When successful Tor provides a '200 OK' status, but we should accept other 2xx @@ -165,13 +202,11 @@ class TestDescriptorDownloader(unittest.TestCase): descriptors = list(query) self.assertEqual(1, len(descriptors)) self.assertEqual('moria1', descriptors[0].nickname) - query.stop()
def test_gzip_url_override(self): query = stem.descriptor.remote.Query(TEST_RESOURCE + '.z', compression = Compression.PLAINTEXT, start = False) self.assertEqual([stem.descriptor.Compression.GZIP], query.compression) self.assertEqual(TEST_RESOURCE, query.resource) - query.stop()
@mock_download(read_resource('compressed_identity'), encoding = 'identity') def test_compression_plaintext(self): @@ -187,7 +222,6 @@ class TestDescriptorDownloader(unittest.TestCase): )
descriptors = list(query) - query.stop()
self.assertEqual(1, len(descriptors)) self.assertEqual('moria1', descriptors[0].nickname) @@ -206,7 +240,6 @@ class TestDescriptorDownloader(unittest.TestCase): )
descriptors = list(query) - query.stop()
self.assertEqual(1, len(descriptors)) self.assertEqual('moria1', descriptors[0].nickname) @@ -227,7 +260,6 @@ class TestDescriptorDownloader(unittest.TestCase): )
descriptors = list(query) - query.stop()
self.assertEqual(1, len(descriptors)) self.assertEqual('moria1', descriptors[0].nickname) @@ -248,7 +280,6 @@ class TestDescriptorDownloader(unittest.TestCase): )
descriptors = list(query) - query.stop()
self.assertEqual(1, len(descriptors)) self.assertEqual('moria1', descriptors[0].nickname) @@ -300,8 +331,6 @@ class TestDescriptorDownloader(unittest.TestCase):
self.assertRaises(ValueError, query.run)
- query.stop() - def test_query_with_invalid_endpoints(self): invalid_endpoints = { 'hello': "'h' is a str.", @@ -330,5 +359,3 @@ class TestDescriptorDownloader(unittest.TestCase): self.assertEqual(1, len(list(query))) self.assertEqual(1, len(list(query))) self.assertEqual(1, len(list(query))) - - query.stop()
tor-commits@lists.torproject.org