[tor-commits] [stem/master] Remove Query's Synchronous usage

atagar at torproject.org
Sun Nov 8 01:24:38 UTC 2020


commit 7ce8a5e090fc95bfb874299d61c824638d5242f4
Author: Damian Johnson <atagar at torproject.org>
Date:   Sat Nov 7 17:18:53 2020 -0800

    Remove Query's Synchronous usage
    
    First step to remove our asyncio metaprogramming...
    
      https://github.com/torproject/stem/issues/77
    
    Our Query class now provides a run() method for synchronous users and
    run_async() for asyncio callers. This commit also adds a stop() method that can cancel our download.
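    
    For example, synchronous and asyncio callers might now look roughly like
    the following (an illustrative sketch only; the fingerprint is borrowed
    from the unit tests):
    
      import asyncio
      import stem.descriptor.remote
      
      # Synchronous callers block on run(), which starts the request if needed.
      
      query = stem.descriptor.remote.get_server_descriptors('9695DFC35FFEB861329B9F1AB04C46397020CE31')
      
      for desc in query.run():
        print(desc.fingerprint)
      
      # asyncio callers iterate over run_async() instead, and stop() cancels a
      # download that is still in flight.
      
      async def fetch():
        query = stem.descriptor.remote.get_server_descriptors('9695DFC35FFEB861329B9F1AB04C46397020CE31')
      
        try:
          return [desc async for desc in query.run_async(suppress = False)]
        finally:
          query.stop()
      
      asyncio.run(fetch())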
---
 stem/descriptor/remote.py      | 114 ++++++++++++++++++++++++++++++++---------
 test/unit/descriptor/remote.py |  53 ++++++++++++++-----
 2 files changed, 130 insertions(+), 37 deletions(-)

diff --git a/stem/descriptor/remote.py b/stem/descriptor/remote.py
index 136b9d15..50b3065c 100644
--- a/stem/descriptor/remote.py
+++ b/stem/descriptor/remote.py
@@ -100,8 +100,7 @@ import stem.util.tor_tools
 
 from stem.descriptor import Compression
 from stem.util import log, str_tools
-from stem.util.asyncio import Synchronous
-from typing import Any, AsyncIterator, Dict, List, Optional, Sequence, Tuple, Union
+from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Sequence, Tuple, Union
 
 # Tor has a limited number of descriptors we can fetch explicitly by their
 # fingerprint or hashes due to a limit on the url length by squid proxies.
@@ -227,7 +226,7 @@ def get_detached_signatures(**query_args: Any) -> 'stem.descriptor.remote.Query'
   return get_instance().get_detached_signatures(**query_args)
 
 
-class Query(Synchronous):
+class Query(object):
   """
   Asynchronous request for descriptor content from a directory authority or
   mirror. These can either be made through the
@@ -369,7 +368,6 @@ class Query(Synchronous):
     super(Query, self).__init__()
 
     if not resource.startswith('/'):
-      self.stop()
       raise ValueError("Resources should start with a '/': %s" % resource)
 
     if resource.endswith('.z'):
@@ -380,7 +378,6 @@ class Query(Synchronous):
     elif isinstance(compression, stem.descriptor._Compression):
       compression = [compression]  # caller provided only a single option
     else:
-      self.stop()
       raise ValueError('Compression should be a list of stem.descriptor.Compression, was %s (%s)' % (compression, type(compression).__name__))
 
     if Compression.ZSTD in compression and not Compression.ZSTD.available:
@@ -404,7 +401,6 @@ class Query(Synchronous):
         if isinstance(endpoint, (stem.ORPort, stem.DirPort)):
           self.endpoints.append(endpoint)
         else:
-          self.stop()
           raise ValueError("Endpoints must be an stem.ORPort or stem.DirPort. '%s' is a %s." % (endpoint, type(endpoint).__name__))
 
     self.resource = resource
@@ -428,6 +424,12 @@ class Query(Synchronous):
     self._downloader_task = None  # type: Optional[asyncio.Task]
     self._downloader_lock = threading.RLock()
 
+    # background thread if outside an asyncio context
+
+    self._loop = None  # type: Optional[asyncio.AbstractEventLoop]
+    self._loop_thread = None  # type: Optional[threading.Thread]
+    self._loop_lock = threading.RLock()
+
     if start:
       self.start()
 
@@ -441,9 +443,38 @@ class Query(Synchronous):
 
     with self._downloader_lock:
       if self._downloader_task is None:
-        self._downloader_task = self._loop.create_task(Query._download_descriptors(self, self.retries, self.timeout))
+        with self._loop_lock:
+          if self._loop is None:
+            try:
+              self._loop = asyncio.get_running_loop()
+            except RuntimeError:
+              self._loop = asyncio.new_event_loop()
+              self._loop_thread = threading.Thread(
+                name = 'stem.descriptor.remote query',
+                target = self._loop.run_forever,
+                daemon = True,
+              )
+
+              self._loop_thread.start()
+
+        self._downloader_task = self._loop.create_task(self._download_descriptors(self.retries, self.timeout))
+
+  def stop(self) -> None:
+    """
+    Aborts our download if it's in progress, and cleans up underlying
+    resources.
+    """
 
-  async def run(self, suppress: bool = False) -> List['stem.descriptor.Descriptor']:
+    with self._downloader_lock:
+      if self._downloader_task and not self._downloader_task.done():
+        self._downloader_task.cancel()
+
+    with self._loop_lock:
+      if self._loop_thread and self._loop_thread.is_alive():
+        self._loop.call_soon_threadsafe(self._loop.stop)
+        self._loop_thread.join()
+
+  def run(self, suppress: bool = False) -> List['stem.descriptor.Descriptor']:
     """
     Blocks until our request is complete then provides the descriptors. If we
     haven't yet started our request then this does so.
@@ -461,12 +492,43 @@ class Query(Synchronous):
         * :class:`~stem.DownloadFailed` if our request fails
     """
 
-    try:
-      return [desc async for desc in self._run(suppress)]
-    finally:
-      self.stop()
+    if not self.downloaded and not self.error:
+      with self._loop_lock:
+        if self._loop is None:
+          self.start()
+
+        async def run_wrapper():
+          return [desc async for desc in self.run_async(suppress = True)]
+
+        asyncio.run_coroutine_threadsafe(run_wrapper(), self._loop).result()
+
+        self.stop()
+
+    if self.error:
+      if suppress:
+        return []
+
+      raise self.error
+    else:
+      return list(self.downloaded)
+
+  async def run_async(self, suppress: bool) -> AsyncIterator[stem.descriptor.Descriptor]:
+    """
+    Asynchronous counterpart of :func:`stem.descriptor.remote.Query.run`
+
+    :param suppress: avoids raising exceptions if **True**
+
+    :returns: iterator for the requested :class:`~stem.descriptor.__init__.Descriptor` instances
+
+    :raises:
+      Using the iterator can fail with the following if **suppress** is
+      **False**...
+
+        * **ValueError** if the descriptor contents is malformed
+        * :class:`~stem.DownloadTimeout` if our request timed out
+        * :class:`~stem.DownloadFailed` if our request fails
+    """
 
-  async def _run(self, suppress: bool) -> AsyncIterator[stem.descriptor.Descriptor]:
     with self._downloader_lock:
       if not self.downloaded and not self.error:
         if not self._downloader_task:
@@ -477,17 +539,21 @@ class Query(Synchronous):
         except Exception as exc:
           self.error = exc
 
-      if self.error:
-        if suppress:
-          return
+    if self.error:
+      if suppress:
+        return
 
-        raise self.error
-      else:
-        for desc in self.downloaded:
-          yield desc
+      raise self.error
+    else:
+      for desc in self.downloaded:
+        yield desc
+
+  def __iter__(self) -> Iterator[stem.descriptor.Descriptor]:
+    for desc in self.run(True):
+      yield desc
 
   async def __aiter__(self) -> AsyncIterator[stem.descriptor.Descriptor]:
-    async for desc in self._run(True):
+    async for desc in self.run_async(True):
       yield desc
 
   def _pick_endpoint(self, use_authority: bool = False) -> stem.Endpoint:
@@ -620,7 +686,7 @@ class DescriptorDownloader(object):
     directories = [auth for auth in stem.directory.Authority.from_cache().values() if auth.nickname not in DIR_PORT_BLACKLIST]
     new_endpoints = set([stem.DirPort(directory.address, directory.dir_port) for directory in directories])
 
-    consensus = list(self.get_consensus(document_handler = stem.descriptor.DocumentHandler.DOCUMENT).run())[0]  # type: ignore
+    consensus = list(self.get_consensus(document_handler = stem.descriptor.DocumentHandler.DOCUMENT).run())[0]
 
     for desc in consensus.routers.values():
       if stem.Flag.V2DIR in desc.flags and desc.dir_port:
@@ -630,7 +696,7 @@ class DescriptorDownloader(object):
 
     self._endpoints = list(new_endpoints)
 
-    return consensus
+    return consensus  # type: ignore
 
   def their_server_descriptor(self, **query_args: Any) -> 'stem.descriptor.remote.Query':
     """
@@ -785,7 +851,7 @@ class DescriptorDownloader(object):
     # authority key certificates
 
     if consensus_query.validate and consensus_query.document_handler == stem.descriptor.DocumentHandler.DOCUMENT:
-      consensus = list(consensus_query.run())[0]  # type: ignore
+      consensus = list(consensus_query.run())[0]
       key_certs = self.get_key_certificates(**query_args).run()
 
       try:
diff --git a/test/unit/descriptor/remote.py b/test/unit/descriptor/remote.py
index 58c7276a..8635d6bd 100644
--- a/test/unit/descriptor/remote.py
+++ b/test/unit/descriptor/remote.py
@@ -2,6 +2,7 @@
 Unit tests for stem.descriptor.remote.
 """
 
+import time
 import unittest
 
 import stem
@@ -87,12 +88,50 @@ class TestDescriptorDownloader(unittest.TestCase):
 
     query = stem.descriptor.remote.get_server_descriptors('9695DFC35FFEB861329B9F1AB04C46397020CE31', start = False)
     self.assertTrue(query._downloader_task is None)
-    query.stop()
 
     query = stem.descriptor.remote.get_server_descriptors('9695DFC35FFEB861329B9F1AB04C46397020CE31', start = True)
     self.assertTrue(query._downloader_task is not None)
     query.stop()
 
+  def test_stop(self):
+    """
+    Stop a complete, in-process, and unstarted query.
+    """
+
+    # stop a completed query
+
+    with mock_download(TEST_DESCRIPTOR):
+      query = stem.descriptor.remote.get_server_descriptors('9695DFC35FFEB861329B9F1AB04C46397020CE31')
+      self.assertTrue(query._loop_thread.is_alive())
+
+      query.run()  # complete the query
+      self.assertFalse(query._loop_thread.is_alive())
+      self.assertFalse(query._downloader_task.cancelled())
+
+      query.stop()  # nothing to do
+      self.assertFalse(query._loop_thread.is_alive())
+      self.assertFalse(query._downloader_task.cancelled())
+
+    # stop an in-process query
+
+    def pause(*args):
+      time.sleep(5)
+
+    with patch('stem.descriptor.remote.Query._download_from', Mock(side_effect = pause)):
+      query = stem.descriptor.remote.get_server_descriptors('9695DFC35FFEB861329B9F1AB04C46397020CE31')
+
+      query.stop()  # terminates in-process query
+      self.assertFalse(query._loop_thread.is_alive())
+      self.assertTrue(query._downloader_task.cancelled())
+
+    # stop an unstarted query
+
+    query = stem.descriptor.remote.get_server_descriptors('9695DFC35FFEB861329B9F1AB04C46397020CE31', start = False)
+
+    query.stop()  # nothing to do
+    self.assertTrue(query._loop_thread is None)
+    self.assertTrue(query._downloader_task is None)
+
   @mock_download(TEST_DESCRIPTOR)
   def test_download(self):
     """
@@ -115,8 +154,6 @@ class TestDescriptorDownloader(unittest.TestCase):
     self.assertEqual('9695DFC35FFEB861329B9F1AB04C46397020CE31', desc.fingerprint)
     self.assertEqual(TEST_DESCRIPTOR.rstrip(), desc.get_bytes())
 
-    reply.stop()
-
   def test_response_header_code(self):
     """
     When successful Tor provides a '200 OK' status, but we should accept other 2xx
@@ -165,13 +202,11 @@ class TestDescriptorDownloader(unittest.TestCase):
     descriptors = list(query)
     self.assertEqual(1, len(descriptors))
     self.assertEqual('moria1', descriptors[0].nickname)
-    query.stop()
 
   def test_gzip_url_override(self):
     query = stem.descriptor.remote.Query(TEST_RESOURCE + '.z', compression = Compression.PLAINTEXT, start = False)
     self.assertEqual([stem.descriptor.Compression.GZIP], query.compression)
     self.assertEqual(TEST_RESOURCE, query.resource)
-    query.stop()
 
   @mock_download(read_resource('compressed_identity'), encoding = 'identity')
   def test_compression_plaintext(self):
@@ -187,7 +222,6 @@ class TestDescriptorDownloader(unittest.TestCase):
     )
 
     descriptors = list(query)
-    query.stop()
 
     self.assertEqual(1, len(descriptors))
     self.assertEqual('moria1', descriptors[0].nickname)
@@ -206,7 +240,6 @@ class TestDescriptorDownloader(unittest.TestCase):
     )
 
     descriptors = list(query)
-    query.stop()
 
     self.assertEqual(1, len(descriptors))
     self.assertEqual('moria1', descriptors[0].nickname)
@@ -227,7 +260,6 @@ class TestDescriptorDownloader(unittest.TestCase):
     )
 
     descriptors = list(query)
-    query.stop()
 
     self.assertEqual(1, len(descriptors))
     self.assertEqual('moria1', descriptors[0].nickname)
@@ -248,7 +280,6 @@ class TestDescriptorDownloader(unittest.TestCase):
     )
 
     descriptors = list(query)
-    query.stop()
 
     self.assertEqual(1, len(descriptors))
     self.assertEqual('moria1', descriptors[0].nickname)
@@ -300,8 +331,6 @@ class TestDescriptorDownloader(unittest.TestCase):
 
     self.assertRaises(ValueError, query.run)
 
-    query.stop()
-
   def test_query_with_invalid_endpoints(self):
     invalid_endpoints = {
       'hello': "'h' is a str.",
@@ -330,5 +359,3 @@ class TestDescriptorDownloader(unittest.TestCase):
     self.assertEqual(1, len(list(query)))
     self.assertEqual(1, len(list(query)))
     self.assertEqual(1, len(list(query)))
-
-    query.stop()
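
The new start()/run()/stop() plumbing boils down to hosting an event loop in a
daemon thread when no asyncio loop is already running. A minimal, self-contained
sketch of that pattern (placeholder coroutine, not stem's actual downloader):

    import asyncio
    import threading

    async def download():
      await asyncio.sleep(0.1)  # stand-in for the real descriptor fetch
      return 'descriptor content'

    # start(): spin up a private loop in a background daemon thread
    loop = asyncio.new_event_loop()
    thread = threading.Thread(target = loop.run_forever, daemon = True)
    thread.start()

    try:
      # run(): submit the coroutine to that loop and block on its result
      print(asyncio.run_coroutine_threadsafe(download(), loop).result())
    finally:
      # stop(): halt the loop from our thread and wait for it to wind down
      loop.call_soon_threadsafe(loop.stop)
      thread.join()
      loop.close()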


