[tor-commits] [stem/master] Use Synchronous for Query

atagar at torproject.org atagar at torproject.org
Thu Jul 16 01:29:00 UTC 2020


commit 69be99b4aaa0ebdb038266103a7c9e748e38ef3b
Author: Damian Johnson <atagar at torproject.org>
Date:   Wed Jun 24 17:47:23 2020 -0700

    Use Synchronous for Query
    
    Time to use our mixin in practice. Good news is that it works and *greatly*
    deduplicates our code, but it's not all sunshine and ponies...
    
      * Query users now must call close(). This is a significant hassle in terms of
        usability, and must be fixed prior to release. However, it'll require some
        API adjustments.
    
      * Mypy's type checks assume that Synchronous users are calling Coroutines,
        causing false positives when objects are used in a synchronous fashion.
---
 stem/descriptor/remote.py      | 351 ++++++++---------------------------------
 stem/util/__init__.py          |  28 +++-
 test/unit/descriptor/remote.py |  61 ++++---
 3 files changed, 134 insertions(+), 306 deletions(-)

diff --git a/stem/descriptor/remote.py b/stem/descriptor/remote.py
index f1ce79db..ad6a02ab 100644
--- a/stem/descriptor/remote.py
+++ b/stem/descriptor/remote.py
@@ -92,7 +92,6 @@ import time
 
 import stem
 import stem.client
-import stem.control
 import stem.descriptor
 import stem.descriptor.networkstatus
 import stem.directory
@@ -100,8 +99,8 @@ import stem.util.enum
 import stem.util.tor_tools
 
 from stem.descriptor import Compression
-from stem.util import log, str_tools
-from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Sequence, Tuple, Union
+from stem.util import Synchronous, log, str_tools
+from typing import Any, AsyncIterator, Dict, List, Optional, Sequence, Tuple, Union
 
 # Tor has a limited number of descriptors we can fetch explicitly by their
 # fingerprint or hashes due to a limit on the url length by squid proxies.
@@ -227,7 +226,7 @@ def get_detached_signatures(**query_args: Any) -> 'stem.descriptor.remote.Query'
   return get_instance().get_detached_signatures(**query_args)
 
 
-class AsyncQuery(object):
+class Query(Synchronous):
   """
   Asynchronous request for descriptor content from a directory authority or
   mirror. These can either be made through the
@@ -235,18 +234,18 @@ class AsyncQuery(object):
   advanced usage.
 
   To block on the response and get results either call
-  :func:`~stem.descriptor.remote.AsyncQuery.run` or iterate over the Query. The
-  :func:`~stem.descriptor.remote.AsyncQuery.run` method pass along any errors
-    that arise...
+  :func:`~stem.descriptor.remote.Query.run` or iterate over the Query. The
+  :func:`~stem.descriptor.remote.Query.run` method pass along any errors that
+  arise...
 
   ::
 
-    from stem.descriptor.remote import AsyncQuery
+    from stem.descriptor.remote import Query
 
     print('Current relays:')
 
     try:
-      for desc in await AsyncQuery('/tor/server/all', 'server-descriptor 1.0').run():
+      for desc in await Query('/tor/server/all', 'server-descriptor 1.0').run():
         print(desc.fingerprint)
     except Exception as exc:
       print('Unable to retrieve the server descriptors: %s' % exc)
@@ -257,7 +256,7 @@ class AsyncQuery(object):
 
     print('Current relays:')
 
-    async for desc in AsyncQuery('/tor/server/all', 'server-descriptor 1.0'):
+    async for desc in Query('/tor/server/all', 'server-descriptor 1.0'):
       print(desc.fingerprint)
 
   In either case exceptions are available via our 'error' attribute.
@@ -290,6 +289,39 @@ class AsyncQuery(object):
   For legacy reasons if our resource has a '.z' suffix then our **compression**
   argument is overwritten with Compression.GZIP.
 
+  .. versionchanged:: 1.7.0
+     Added support for downloading from ORPorts.
+
+  .. versionchanged:: 1.7.0
+     Added the compression argument.
+
+  .. versionchanged:: 1.7.0
+     Added the reply_headers attribute.
+
+     The class this provides changed between Python versions. In python2
+     this was called httplib.HTTPMessage, whereas in python3 the class was
+     renamed to http.client.HTTPMessage.
+
+  .. versionchanged:: 1.7.0
+     Avoid downloading from tor26. This directory authority throttles its
+     DirPort to such an extent that requests either time out or take on the
+     order of minutes.
+
+  .. versionchanged:: 1.7.0
+     Avoid downloading from Bifroest. This is the bridge authority so it
+     doesn't vote in the consensus, and apparently times out frequently.
+
+  .. versionchanged:: 1.8.0
+     Serge has replaced Bifroest as our bridge authority. Avoiding descriptor
+     downloads from it instead.
+
+  .. versionchanged:: 1.8.0
+     Defaulting to gzip compression rather than plaintext downloads.
+
+  .. versionchanged:: 1.8.0
+     Using :class:`~stem.descriptor.__init__.Compression` for our compression
+     argument.
+
   :var str resource: resource being fetched, such as '/tor/server/all'
   :var str descriptor_type: type of descriptors being fetched (for options see
     :func:`~stem.descriptor.__init__.parse_file`), this is guessed from the
@@ -327,9 +359,15 @@ class AsyncQuery(object):
   :var float timeout: duration before we'll time out our request
   :var str download_url: last url used to download the descriptor, this is
     unset until we've actually made a download attempt
+
+  :param start: start making the request when constructed (default is **True**)
+  :param block: only return after the request has been completed, this is
+    the same as running **query.run(True)** (default is **False**)
   """
 
-  def __init__(self, resource: str, descriptor_type: Optional[str] = None, endpoints: Optional[Sequence[stem.Endpoint]] = None, compression: Union[stem.descriptor._Compression, Sequence[stem.descriptor._Compression]] = (Compression.GZIP,), retries: int = 2, fall_back_to_authority: bool = False, timeout: Optional[float] = None, validate: bool = False, document_handler: stem.descriptor.DocumentHandler = stem.descriptor.DocumentHandler.ENTRIES, **kwargs: Any) -> None:
+  def __init__(self, resource: str, descriptor_type: Optional[str] = None, endpoints: Optional[Sequence[stem.Endpoint]] = None, compression: Union[stem.descriptor._Compression, Sequence[stem.descriptor._Compression]] = (Compression.GZIP,), retries: int = 2, fall_back_to_authority: bool = False, timeout: Optional[float] = None, start: bool = True, block: bool = False, validate: bool = False, document_handler: stem.descriptor.DocumentHandler = stem.descriptor.DocumentHandler.ENTRIES, **kwargs: Any) -> None:
+    super(Query, self).__init__()
+
     if not resource.startswith('/'):
       raise ValueError("Resources should start with a '/': %s" % resource)
 
@@ -388,6 +426,12 @@ class AsyncQuery(object):
     self._downloader_task = None  # type: Optional[asyncio.Task]
     self._downloader_lock = threading.RLock()
 
+    if start:
+      self.start()
+
+    if block:
+      self.run(True)
+
   async def start(self) -> None:
     """
     Starts downloading the scriptors if we haven't started already.
@@ -398,12 +442,14 @@ class AsyncQuery(object):
         loop = asyncio.get_running_loop()
         self._downloader_task = loop.create_task(self._download_descriptors(self.retries, self.timeout))
 
-  async def run(self, suppress: bool = False) -> List['stem.descriptor.Descriptor']:
+  async def run(self, suppress: bool = False, close: bool = True) -> List['stem.descriptor.Descriptor']:
     """
     Blocks until our request is complete then provides the descriptors. If we
     haven't yet started our request then this does so.
 
     :param suppress: avoids raising exceptions if **True**
+    :param close: terminates the resources backing this query if **True**,
+      further method calls will raise a RuntimeError
 
     :returns: list for the requested :class:`~stem.descriptor.__init__.Descriptor` instances
 
@@ -416,7 +462,15 @@ class AsyncQuery(object):
         * :class:`~stem.DownloadFailed` if our request fails
     """
 
-    return [desc async for desc in self._run(suppress)]
+    # TODO: We should replace our 'close' argument with a new API design prior
+    # to release. Self-destructing this object by default for synchronous users
+    # is quite a step backward, but is acceptable as we iterate on this.
+
+    try:
+      return [desc async for desc in self._run(suppress)]
+    finally:
+      if close:
+        self._loop.call_soon_threadsafe(self._loop.stop)
 
   async def _run(self, suppress: bool) -> AsyncIterator[stem.descriptor.Descriptor]:
     with self._downloader_lock:
@@ -544,271 +598,6 @@ class AsyncQuery(object):
       raise ValueError("BUG: endpoints can only be ORPorts or DirPorts, '%s' was a %s" % (endpoint, type(endpoint).__name__))
 
 
-class Query(stem.util.AsyncClassWrapper):
-  """
-  Asynchronous request for descriptor content from a directory authority or
-  mirror. These can either be made through the
-  :class:`~stem.descriptor.remote.DescriptorDownloader` or directly for more
-  advanced usage.
-
-  To block on the response and get results either call
-  :func:`~stem.descriptor.remote.Query.run` or iterate over the Query. The
-  :func:`~stem.descriptor.remote.Query.run` method pass along any errors that
-  arise...
-
-  ::
-
-    from stem.descriptor.remote import Query
-
-    print('Current relays:')
-
-    try:
-      for desc in Query('/tor/server/all', 'server-descriptor 1.0').run():
-        print(desc.fingerprint)
-    except Exception as exc:
-      print('Unable to retrieve the server descriptors: %s' % exc)
-
-  ... while iterating fails silently...
-
-  ::
-
-    print('Current relays:')
-
-    for desc in Query('/tor/server/all', 'server-descriptor 1.0'):
-      print(desc.fingerprint)
-
-  In either case exceptions are available via our 'error' attribute.
-
-  Tor provides quite a few different descriptor resources via its directory
-  protocol (see section 4.2 and later of the `dir-spec
-  <https://gitweb.torproject.org/torspec.git/tree/dir-spec.txt>`_).
-  Commonly useful ones include...
-
-  =============================================== ===========
-  Resource                                        Description
-  =============================================== ===========
-  /tor/server/all                                 all present server descriptors
-  /tor/server/fp/<fp1>+<fp2>+<fp3>                server descriptors with the given fingerprints
-  /tor/extra/all                                  all present extrainfo descriptors
-  /tor/extra/fp/<fp1>+<fp2>+<fp3>                 extrainfo descriptors with the given fingerprints
-  /tor/micro/d/<hash1>-<hash2>                    microdescriptors with the given hashes
-  /tor/status-vote/current/consensus              present consensus
-  /tor/status-vote/current/consensus-microdesc    present microdescriptor consensus
-  /tor/status-vote/next/bandwidth                 bandwidth authority heuristics for the next consenus
-  /tor/status-vote/next/consensus-signatures      detached signature, used for making the next consenus
-  /tor/keys/all                                   key certificates for the authorities
-  /tor/keys/fp/<v3ident1>+<v3ident2>              key certificates for specific authorities
-  =============================================== ===========
-
-  **ZSTD** compression requires `zstandard
-  <https://pypi.org/project/zstandard/>`_, and **LZMA** requires the `lzma
-  module <https://docs.python.org/3/library/lzma.html>`_.
-
-  For legacy reasons if our resource has a '.z' suffix then our **compression**
-  argument is overwritten with Compression.GZIP.
-
-  .. versionchanged:: 1.7.0
-     Added support for downloading from ORPorts.
-
-  .. versionchanged:: 1.7.0
-     Added the compression argument.
-
-  .. versionchanged:: 1.7.0
-     Added the reply_headers attribute.
-
-     The class this provides changed between Python versions. In python2
-     this was called httplib.HTTPMessage, whereas in python3 the class was
-     renamed to http.client.HTTPMessage.
-
-  .. versionchanged:: 1.7.0
-     Avoid downloading from tor26. This directory authority throttles its
-     DirPort to such an extent that requests either time out or take on the
-     order of minutes.
-
-  .. versionchanged:: 1.7.0
-     Avoid downloading from Bifroest. This is the bridge authority so it
-     doesn't vote in the consensus, and apparently times out frequently.
-
-  .. versionchanged:: 1.8.0
-     Serge has replaced Bifroest as our bridge authority. Avoiding descriptor
-     downloads from it instead.
-
-  .. versionchanged:: 1.8.0
-     Defaulting to gzip compression rather than plaintext downloads.
-
-  .. versionchanged:: 1.8.0
-     Using :class:`~stem.descriptor.__init__.Compression` for our compression
-     argument.
-
-  :var str resource: resource being fetched, such as '/tor/server/all'
-  :var str descriptor_type: type of descriptors being fetched (for options see
-    :func:`~stem.descriptor.__init__.parse_file`), this is guessed from the
-    resource if **None**
-
-  :var list endpoints: :class:`~stem.DirPort` or :class:`~stem.ORPort` of the
-    authority or mirror we're querying, this uses authorities if undefined
-  :var list compression: list of :data:`stem.descriptor.Compression`
-    we're willing to accept, when none are mutually supported downloads fall
-    back to Compression.PLAINTEXT
-  :var int retries: number of times to attempt the request if downloading it
-    fails
-  :var bool fall_back_to_authority: when retrying request issues the last
-    request to a directory authority if **True**
-
-  :var str content: downloaded descriptor content
-  :var Exception error: exception if a problem occured
-  :var bool is_done: flag that indicates if our request has finished
-
-  :var float start_time: unix timestamp when we first started running
-  :var http.client.HTTPMessage reply_headers: headers provided in the response,
-    **None** if we haven't yet made our request
-  :var float runtime: time our query took, this is **None** if it's not yet
-    finished
-
-  :var bool validate: checks the validity of the descriptor's content if
-    **True**, skips these checks otherwise
-  :var stem.descriptor.__init__.DocumentHandler document_handler: method in
-    which to parse a :class:`~stem.descriptor.networkstatus.NetworkStatusDocument`
-  :var dict kwargs: additional arguments for the descriptor constructor
-
-  Following are only applicable when downloading from a
-  :class:`~stem.DirPort`...
-
-  :var float timeout: duration before we'll time out our request
-  :var str download_url: last url used to download the descriptor, this is
-    unset until we've actually made a download attempt
-
-  :param start: start making the request when constructed (default is **True**)
-  :param block: only return after the request has been completed, this is
-    the same as running **query.run(True)** (default is **False**)
-  """
-
-  def __init__(self, resource: str, descriptor_type: Optional[str] = None, endpoints: Optional[Sequence[stem.Endpoint]] = None, compression: Union[stem.descriptor._Compression, Sequence[stem.descriptor._Compression]] = (Compression.GZIP,), retries: int = 2, fall_back_to_authority: bool = False, timeout: Optional[float] = None, start: bool = True, block: bool = False, validate: bool = False, document_handler: stem.descriptor.DocumentHandler = stem.descriptor.DocumentHandler.ENTRIES, **kwargs: Any) -> None:
-    self._loop = asyncio.new_event_loop()
-    self._loop_thread = threading.Thread(target = self._loop.run_forever, name = 'query asyncio')
-    self._loop_thread.setDaemon(True)
-    self._loop_thread.start()
-
-    self._wrapped_instance: AsyncQuery = self._init_async_class(  # type: ignore
-      AsyncQuery,
-      resource,
-      descriptor_type,
-      endpoints,
-      compression,
-      retries,
-      fall_back_to_authority,
-      timeout,
-      validate,
-      document_handler,
-      **kwargs,
-    )
-
-    if start:
-      self.start()
-
-    if block:
-      self.run(True)
-
-  def start(self) -> None:
-    """
-    Starts downloading the scriptors if we haven't started already.
-    """
-
-    self._execute_async_method('start')
-
-  def run(self, suppress = False) -> List['stem.descriptor.Descriptor']:
-    """
-    Blocks until our request is complete then provides the descriptors. If we
-    haven't yet started our request then this does so.
-
-    :param suppress: avoids raising exceptions if **True**
-
-    :returns: list for the requested :class:`~stem.descriptor.__init__.Descriptor` instances
-
-    :raises:
-      Using the iterator can fail with the following if **suppress** is
-      **False**...
-
-        * **ValueError** if the descriptor contents is malformed
-        * :class:`~stem.DownloadTimeout` if our request timed out
-        * :class:`~stem.DownloadFailed` if our request fails
-    """
-
-    return self._execute_async_method('run', suppress)
-
-  def __iter__(self) -> Iterator[stem.descriptor.Descriptor]:
-    for desc in self._execute_async_generator_method('__aiter__'):
-      yield desc
-
-  @property
-  def descriptor_type(self) -> str:
-    return self._wrapped_instance.descriptor_type
-
-  @property
-  def endpoints(self) -> List[Union[stem.ORPort, stem.DirPort]]:
-    return self._wrapped_instance.endpoints
-
-  @property
-  def resource(self) -> str:
-    return self._wrapped_instance.resource
-
-  @property
-  def compression(self) -> List[stem.descriptor._Compression]:
-    return self._wrapped_instance.compression
-
-  @property
-  def retries(self) -> int:
-    return self._wrapped_instance.retries
-
-  @property
-  def fall_back_to_authority(self) -> bool:
-    return self._wrapped_instance.fall_back_to_authority
-
-  @property
-  def content(self) -> Optional[bytes]:
-    return self._wrapped_instance.content
-
-  @property
-  def error(self) -> Optional[BaseException]:
-    return self._wrapped_instance.error
-
-  @property
-  def is_done(self) -> bool:
-    return self._wrapped_instance.is_done
-
-  @property
-  def download_url(self) -> Optional[str]:
-    return self._wrapped_instance.download_url
-
-  @property
-  def start_time(self) -> Optional[float]:
-    return self._wrapped_instance.start_time
-
-  @property
-  def timeout(self) -> Optional[float]:
-    return self._wrapped_instance.timeout
-
-  @property
-  def runtime(self) -> Optional[float]:
-    return self._wrapped_instance.runtime
-
-  @property
-  def validate(self) -> bool:
-    return self._wrapped_instance.validate
-
-  @property
-  def document_handler(self) -> stem.descriptor.DocumentHandler:
-    return self._wrapped_instance.document_handler
-
-  @property
-  def reply_headers(self) -> Optional[Dict[str, str]]:
-    return self._wrapped_instance.reply_headers
-
-  @property
-  def kwargs(self) -> Dict[str, Any]:
-    return self._wrapped_instance.kwargs
-
-
 class DescriptorDownloader(object):
   """
   Configurable class that issues :class:`~stem.descriptor.remote.Query`
@@ -848,7 +637,7 @@ class DescriptorDownloader(object):
     directories = [auth for auth in stem.directory.Authority.from_cache().values() if auth.nickname not in DIR_PORT_BLACKLIST]
     new_endpoints = set([stem.DirPort(directory.address, directory.dir_port) for directory in directories])
 
-    consensus = list(self.get_consensus(document_handler = stem.descriptor.DocumentHandler.DOCUMENT).run())[0]
+    consensus = list(self.get_consensus(document_handler = stem.descriptor.DocumentHandler.DOCUMENT).run())[0]  # type: ignore
 
     for desc in consensus.routers.values():
       if stem.Flag.V2DIR in desc.flags and desc.dir_port:
@@ -858,7 +647,7 @@ class DescriptorDownloader(object):
 
     self._endpoints = list(new_endpoints)
 
-    return consensus  # type: ignore
+    return consensus
 
   def their_server_descriptor(self, **query_args: Any) -> 'stem.descriptor.remote.Query':
     """
@@ -1013,7 +802,7 @@ class DescriptorDownloader(object):
     # authority key certificates
 
     if consensus_query.validate and consensus_query.document_handler == stem.descriptor.DocumentHandler.DOCUMENT:
-      consensus = list(consensus_query.run())[0]
+      consensus = list(consensus_query.run())[0]  # type: ignore
       key_certs = self.get_key_certificates(**query_args).run()
 
       try:
diff --git a/stem/util/__init__.py b/stem/util/__init__.py
index 72239273..e8ef361e 100644
--- a/stem/util/__init__.py
+++ b/stem/util/__init__.py
@@ -12,7 +12,7 @@ import inspect
 import threading
 from concurrent.futures import Future
 
-from typing import Any, AsyncIterator, Iterator, Type, Union
+from typing import Any, AsyncIterator, Callable, Iterator, Type, Union
 
 __all__ = [
   'conf',
@@ -116,7 +116,7 @@ def _pubkey_bytes(key: Union['cryptography.hazmat.primitives.asymmetric.ed25519.
     raise ValueError('Key must be a string or cryptographic public/private key (was %s)' % type(key).__name__)
 
 
-def _hash_attr(obj: Any, *attributes: str, **kwargs: Any):
+def _hash_attr(obj: Any, *attributes: str, **kwargs: Any) -> int:
   """
   Provide a hash value for the given set of attributes.
 
@@ -124,6 +124,8 @@ def _hash_attr(obj: Any, *attributes: str, **kwargs: Any):
   :param attributes: attribute names to take into account
   :param cache: persists hash in a '_cached_hash' object attribute
   :param parent: include parent's hash value
+
+  :returns: **int** object hash
   """
 
   is_cached = kwargs.get('cache', False)
@@ -174,11 +176,11 @@ class Synchronous(object):
   finished to clean up underlying resources.
   """
 
-  def __init__(self):
+  def __init__(self) -> None:
     self._loop = asyncio.new_event_loop()
     self._loop_lock = threading.RLock()
     self._loop_thread = threading.Thread(
-      name = '%s asyncio' % self.__class__.__name__,
+      name = '%s asyncio' % type(self).__name__,
       target = self._loop.run_forever,
       daemon = True,
     )
@@ -188,7 +190,7 @@ class Synchronous(object):
     # overwrite asynchronous class methods with instance methods that can be
     # called from either context
 
-    def wrap(func, *args, **kwargs):
+    def wrap(func: Callable, *args: Any, **kwargs: Any) -> Any:
       if Synchronous.is_asyncio_context():
         return func(*args, **kwargs)
       else:
@@ -204,7 +206,7 @@ class Synchronous(object):
       if inspect.iscoroutinefunction(func):
         setattr(self, method_name, functools.partial(wrap, func))
 
-  def close(self):
+  def close(self) -> None:
     """
     Terminate resources that permits this from being callable from synchronous
     contexts. Once called any further synchronous invocations will fail with a
@@ -219,7 +221,7 @@ class Synchronous(object):
       self._is_closed = True
 
   @staticmethod
-  def is_asyncio_context():
+  def is_asyncio_context() -> bool:
     """
     Check if running within a synchronous or asynchronous context.
 
@@ -232,6 +234,31 @@ class Synchronous(object):
     except RuntimeError:
       return False
 
+  def __iter__(self) -> Iterator:
+    """
+    Provides our asynchronous iterator's content from a synchronous context.
+    Everything from **__aiter__()** is gathered on our event loop into a list
+    before we return an iterator over it.
+
+    :returns: **Iterator** over the items our **__aiter__()** provides
+
+    :raises: **TypeError** if this object lacks an **__aiter__()** method
+    """
+
+    async def convert_async_generator(generator: AsyncIterator) -> Iterator:
+      return iter([d async for d in generator])
+
+    # Default to None so objects without __aiter__ reach our TypeError below
+    # rather than raising an AttributeError from getattr.
+
+    iter_func = getattr(self, '__aiter__', None)
+
+    if iter_func:
+      with self._loop_lock:
+        return asyncio.run_coroutine_threadsafe(convert_async_generator(iter_func()), self._loop).result()
+    else:
+      raise TypeError("'%s' object is not iterable" % type(self).__name__)
+
 
 class AsyncClassWrapper:
   _loop: asyncio.AbstractEventLoop
diff --git a/test/unit/descriptor/remote.py b/test/unit/descriptor/remote.py
index 797bc8a3..1fd2aaf9 100644
--- a/test/unit/descriptor/remote.py
+++ b/test/unit/descriptor/remote.py
@@ -70,7 +70,7 @@ def mock_download(descriptor, encoding = 'identity', response_code_header = None
 
   data = response_code_header + stem.util.str_tools._to_bytes(HEADER % encoding) + b'\r\n\r\n' + descriptor
 
-  return patch('stem.descriptor.remote.AsyncQuery._download_from', Mock(side_effect = coro_func_returning_value(data)))
+  return patch('stem.descriptor.remote.Query._download_from', Mock(side_effect = coro_func_returning_value(data)))
 
 
 class TestDescriptorDownloader(unittest.TestCase):
@@ -100,6 +100,8 @@ class TestDescriptorDownloader(unittest.TestCase):
     self.assertEqual('9695DFC35FFEB861329B9F1AB04C46397020CE31', desc.fingerprint)
     self.assertEqual(TEST_DESCRIPTOR, desc.get_bytes())
 
+    reply.close()
+
   def test_response_header_code(self):
     """
     When successful Tor provides a '200 OK' status, but we should accept other 2xx
@@ -133,7 +135,7 @@ class TestDescriptorDownloader(unittest.TestCase):
   def test_reply_header_data(self):
     query = stem.descriptor.remote.get_server_descriptors('9695DFC35FFEB861329B9F1AB04C46397020CE31', start = False)
     self.assertEqual(None, query.reply_headers)  # initially we don't have a reply
-    query.run()
+    query.run(close = False)
 
     self.assertEqual('Fri, 13 Apr 2018 16:35:50 GMT', query.reply_headers.get('Date'))
     self.assertEqual('application/octet-stream', query.reply_headers.get('Content-Type'))
@@ -148,11 +150,13 @@ class TestDescriptorDownloader(unittest.TestCase):
     descriptors = list(query)
     self.assertEqual(1, len(descriptors))
     self.assertEqual('moria1', descriptors[0].nickname)
+    query.close()
 
   def test_gzip_url_override(self):
     query = stem.descriptor.remote.Query(TEST_RESOURCE + '.z', compression = Compression.PLAINTEXT, start = False)
     self.assertEqual([stem.descriptor.Compression.GZIP], query.compression)
     self.assertEqual(TEST_RESOURCE, query.resource)
+    query.close()
 
   @mock_download(read_resource('compressed_identity'), encoding = 'identity')
   def test_compression_plaintext(self):
@@ -160,12 +164,15 @@ class TestDescriptorDownloader(unittest.TestCase):
     Download a plaintext descriptor.
     """
 
-    descriptors = list(stem.descriptor.remote.get_server_descriptors(
+    query = stem.descriptor.remote.get_server_descriptors(
       '9695DFC35FFEB861329B9F1AB04C46397020CE31',
       compression = Compression.PLAINTEXT,
       validate = True,
       skip_crypto_validation = not test.require.CRYPTOGRAPHY_AVAILABLE,
-    ))
+    )
+
+    descriptors = list(query)
+    query.close()
 
     self.assertEqual(1, len(descriptors))
     self.assertEqual('moria1', descriptors[0].nickname)
@@ -176,12 +183,15 @@ class TestDescriptorDownloader(unittest.TestCase):
     Download a gip compressed descriptor.
     """
 
-    descriptors = list(stem.descriptor.remote.get_server_descriptors(
+    query = stem.descriptor.remote.get_server_descriptors(
       '9695DFC35FFEB861329B9F1AB04C46397020CE31',
       compression = Compression.GZIP,
       validate = True,
       skip_crypto_validation = not test.require.CRYPTOGRAPHY_AVAILABLE,
-    ))
+    )
+
+    descriptors = list(query)
+    query.close()
 
     self.assertEqual(1, len(descriptors))
     self.assertEqual('moria1', descriptors[0].nickname)
@@ -195,11 +205,14 @@ class TestDescriptorDownloader(unittest.TestCase):
     if not Compression.ZSTD.available:
       self.skipTest('(requires zstd module)')
 
-    descriptors = list(stem.descriptor.remote.get_server_descriptors(
+    query = stem.descriptor.remote.get_server_descriptors(
       '9695DFC35FFEB861329B9F1AB04C46397020CE31',
       compression = Compression.ZSTD,
       validate = True,
-    ))
+    )
+
+    descriptors = list(query)
+    query.close()
 
     self.assertEqual(1, len(descriptors))
     self.assertEqual('moria1', descriptors[0].nickname)
@@ -213,11 +226,14 @@ class TestDescriptorDownloader(unittest.TestCase):
     if not Compression.LZMA.available:
       self.skipTest('(requires lzma module)')
 
-    descriptors = list(stem.descriptor.remote.get_server_descriptors(
+    query = stem.descriptor.remote.get_server_descriptors(
       '9695DFC35FFEB861329B9F1AB04C46397020CE31',
       compression = Compression.LZMA,
       validate = True,
-    ))
+    )
+
+    descriptors = list(query)
+    query.close()
 
     self.assertEqual(1, len(descriptors))
     self.assertEqual('moria1', descriptors[0].nickname)
@@ -228,16 +244,21 @@ class TestDescriptorDownloader(unittest.TestCase):
     Surface level exercising of each getter method for downloading descriptors.
     """
 
+    queries = []
+
     downloader = stem.descriptor.remote.get_instance()
 
-    downloader.get_server_descriptors()
-    downloader.get_extrainfo_descriptors()
-    downloader.get_microdescriptors('test-hash')
-    downloader.get_consensus()
-    downloader.get_vote(stem.directory.Authority.from_cache()['moria1'])
-    downloader.get_key_certificates()
-    downloader.get_bandwidth_file()
-    downloader.get_detached_signatures()
+    queries.append(downloader.get_server_descriptors())
+    queries.append(downloader.get_extrainfo_descriptors())
+    queries.append(downloader.get_microdescriptors('test-hash'))
+    queries.append(downloader.get_consensus())
+    queries.append(downloader.get_vote(stem.directory.Authority.from_cache()['moria1']))
+    queries.append(downloader.get_key_certificates())
+    queries.append(downloader.get_bandwidth_file())
+    queries.append(downloader.get_detached_signatures())
+
+    for query in queries:
+      query.close()
 
   @mock_download(b'some malformed stuff')
   def test_malformed_content(self):
@@ -264,6 +285,8 @@ class TestDescriptorDownloader(unittest.TestCase):
 
     self.assertRaises(ValueError, query.run)
 
+    query.close()
+
   def test_query_with_invalid_endpoints(self):
     invalid_endpoints = {
       'hello': "'h' is a str.",
@@ -292,3 +315,5 @@ class TestDescriptorDownloader(unittest.TestCase):
     self.assertEqual(1, len(list(query)))
     self.assertEqual(1, len(list(query)))
     self.assertEqual(1, len(list(query)))
+
+    query.close()





More information about the tor-commits mailing list