[tor-commits] [stem/master] Rewrite descriptor downloading

atagar at torproject.org atagar at torproject.org
Thu Jul 16 01:29:00 UTC 2020


commit 448060eabed41b3bad22cc5b0a5b5494f2793816
Author: Damian Johnson <atagar at torproject.org>
Date:   Mon Jun 15 16:23:40 2020 -0700

    Rewrite descriptor downloading
    
    Using run_in_executor() here has a couple of issues...
    
      1. Executor threads aren't cleaned up. Running our tests with the '--all'
         argument concludes with...
    
         Threads lingering after test run:
           <_MainThread(MainThread, started 140249831520000)>
           <Thread(ThreadPoolExecutor-0_0, started daemon 140249689769728)>
           <Thread(ThreadPoolExecutor-0_1, started daemon 140249606911744)>
           <Thread(ThreadPoolExecutor-0_2, started daemon 140249586980608)>
           <Thread(ThreadPoolExecutor-0_3, started daemon 140249578587904)>
           <Thread(ThreadPoolExecutor-0_4, started daemon 140249570195200)>
          ...
    
      2. Asyncio has its own IO. Wrapping urllib within an executor is easy,
         but loses asyncio benefits such as imposing timeouts through
         asyncio.wait_for().
    
         Urllib marshals and parses HTTP headers, but we already do that
         for ORPort requests, so using a raw asyncio connection actually
         lets us deduplicate some code.
    
Deduplication greatly simplifies testing in that we can mock _download_from()
rather than the raw connection. However, I couldn't adapt our timeout test:
asyncio's wait_for() works in practice, but does not trigger when mocked.
---
 stem/descriptor/remote.py      | 229 ++++++++++++++++-------------------------
 test/unit/descriptor/remote.py | 183 ++++++++------------------------
 2 files changed, 133 insertions(+), 279 deletions(-)

diff --git a/stem/descriptor/remote.py b/stem/descriptor/remote.py
index c23ab7a9..f1ce79db 100644
--- a/stem/descriptor/remote.py
+++ b/stem/descriptor/remote.py
@@ -84,14 +84,11 @@ content. For example...
 """
 
 import asyncio
-import functools
 import io
 import random
-import socket
 import sys
 import threading
 import time
-import urllib.request
 
 import stem
 import stem.client
@@ -313,7 +310,7 @@ class AsyncQuery(object):
   :var bool is_done: flag that indicates if our request has finished
 
   :var float start_time: unix timestamp when we first started running
-  :var http.client.HTTPMessage reply_headers: headers provided in the response,
+  :var dict reply_headers: headers provided in the response,
     **None** if we haven't yet made our request
   :var float runtime: time our query took, this is **None** if it's not yet
     finished
@@ -330,13 +327,9 @@ class AsyncQuery(object):
   :var float timeout: duration before we'll time out our request
   :var str download_url: last url used to download the descriptor, this is
     unset until we've actually made a download attempt
-
-  :param start: start making the request when constructed (default is **True**)
-  :param block: only return after the request has been completed, this is
-    the same as running **query.run(True)** (default is **False**)
   """
 
-  def __init__(self, resource: str, descriptor_type: Optional[str] = None, endpoints: Optional[Sequence[stem.Endpoint]] = None, compression: Union[stem.descriptor._Compression, Sequence[stem.descriptor._Compression]] = (Compression.GZIP,), retries: int = 2, fall_back_to_authority: bool = False, timeout: Optional[float] = None, start: bool = True, block: bool = False, validate: bool = False, document_handler: stem.descriptor.DocumentHandler = stem.descriptor.DocumentHandler.ENTRIES, **kwargs: Any) -> None:
+  def __init__(self, resource: str, descriptor_type: Optional[str] = None, endpoints: Optional[Sequence[stem.Endpoint]] = None, compression: Union[stem.descriptor._Compression, Sequence[stem.descriptor._Compression]] = (Compression.GZIP,), retries: int = 2, fall_back_to_authority: bool = False, timeout: Optional[float] = None, validate: bool = False, document_handler: stem.descriptor.DocumentHandler = stem.descriptor.DocumentHandler.ENTRIES, **kwargs: Any) -> None:
     if not resource.startswith('/'):
       raise ValueError("Resources should start with a '/': %s" % resource)
 
@@ -395,22 +388,15 @@ class AsyncQuery(object):
     self._downloader_task = None  # type: Optional[asyncio.Task]
     self._downloader_lock = threading.RLock()
 
-    self._asyncio_loop = asyncio.get_event_loop()
-
-    if start:
-      self.start()
-
-    if block:
-      self.run(True)
-
-  def start(self) -> None:
+  async def start(self) -> None:
     """
     Starts downloading the scriptors if we haven't started already.
     """
 
     with self._downloader_lock:
       if self._downloader_task is None:
-        self._downloader_task = self._asyncio_loop.create_task(self._download_descriptors(self.retries, self.timeout))
+        loop = asyncio.get_running_loop()
+        self._downloader_task = loop.create_task(self._download_descriptors(self.retries, self.timeout))
 
   async def run(self, suppress: bool = False) -> List['stem.descriptor.Descriptor']:
     """
@@ -434,7 +420,7 @@ class AsyncQuery(object):
 
   async def _run(self, suppress: bool) -> AsyncIterator[stem.descriptor.Descriptor]:
     with self._downloader_lock:
-      self.start()
+      await self.start()
       await self._downloader_task
 
       if self.error:
@@ -491,36 +477,71 @@ class AsyncQuery(object):
       return random.choice(self.endpoints)
 
   async def _download_descriptors(self, retries: int, timeout: Optional[float]) -> None:
-    try:
-      self.start_time = time.time()
+    self.start_time = time.time()
+
+    retries = self.retries
+    time_remaining = self.timeout
+
+    while True:
       endpoint = self._pick_endpoint(use_authority = retries == 0 and self.fall_back_to_authority)
 
       if isinstance(endpoint, stem.ORPort):
         downloaded_from = 'ORPort %s:%s (resource %s)' % (endpoint.address, endpoint.port, self.resource)
-        self.content, self.reply_headers = await _download_from_orport(endpoint, self.compression, self.resource)
       elif isinstance(endpoint, stem.DirPort):
-        self.download_url = 'http://%s:%i/%s' % (endpoint.address, endpoint.port, self.resource.lstrip('/'))
-        downloaded_from = self.download_url
-        self.content, self.reply_headers = await _download_from_dirport(self.download_url, self.compression, timeout)
+        downloaded_from = 'http://%s:%i/%s' % (endpoint.address, endpoint.port, self.resource.lstrip('/'))
+        self.download_url = downloaded_from
       else:
         raise ValueError("BUG: endpoints can only be ORPorts or DirPorts, '%s' was a %s" % (endpoint, type(endpoint).__name__))
 
-      self.runtime = time.time() - self.start_time
-      log.trace('Descriptors retrieved from %s in %0.2fs' % (downloaded_from, self.runtime))
-    except:
-      exc = sys.exc_info()[1]
+      try:
+        response = await asyncio.wait_for(self._download_from(endpoint), time_remaining)
+        self.content, self.reply_headers = _http_body_and_headers(response)
+
+        self.is_done = True
+        self.runtime = time.time() - self.start_time
+
+        log.trace('Descriptors retrieved from %s in %0.2fs' % (downloaded_from, self.runtime))
+        return
+      except asyncio.TimeoutError as exc:
+        self.is_done = True
+        self.error = stem.DownloadTimeout(downloaded_from, exc, sys.exc_info()[2], self.timeout)
+        return
+      except:
+        exception = sys.exc_info()[1]
+        retries -= 1
+
+        if time_remaining is not None:
+          time_remaining -= time.time() - self.start_time
+
+        if retries > 0:
+          log.debug("Failed to download descriptors from '%s' (%i retries remaining): %s" % (downloaded_from, retries, exception))
+        else:
+          log.debug("Failed to download descriptors from '%s': %s" % (self.download_url, exception))
+
+          self.is_done = True
+          self.error = exception
+          return
 
-      if timeout is not None:
-        timeout -= time.time() - self.start_time
+  async def _download_from(self, endpoint: stem.Endpoint) -> bytes:
+    http_request = '\r\n'.join((
+      'GET %s HTTP/1.0' % self.resource,
+      'Accept-Encoding: %s' % ', '.join(map(lambda c: c.encoding, self.compression)),
+      'User-Agent: %s' % stem.USER_AGENT,
+    )) + '\r\n\r\n'
 
-      if retries > 0 and (timeout is None or timeout > 0):
-        log.debug("Unable to download descriptors from '%s' (%i retries remaining): %s" % (self.download_url, retries, exc))
-        return await self._download_descriptors(retries - 1, timeout)
-      else:
-        log.debug("Unable to download descriptors from '%s': %s" % (self.download_url, exc))
-        self.error = exc
-    finally:
-      self.is_done = True
+    if isinstance(endpoint, stem.ORPort):
+      link_protocols = endpoint.link_protocols if endpoint.link_protocols else [3]
+
+      async with await stem.client.Relay.connect(endpoint.address, endpoint.port, link_protocols) as relay:
+        async with await relay.create_circuit() as circ:
+          return await circ.directory(http_request, stream_id = 1)
+    elif isinstance(endpoint, stem.DirPort):
+      reader, writer = await asyncio.open_connection(endpoint.address, endpoint.port)
+      writer.write(str_tools._to_bytes(http_request))
+
+      return await reader.read()
+    else:
+      raise ValueError("BUG: endpoints can only be ORPorts or DirPorts, '%s' was a %s" % (endpoint, type(endpoint).__name__))
 
 
 class Query(stem.util.AsyncClassWrapper):
@@ -663,8 +684,8 @@ class Query(stem.util.AsyncClassWrapper):
   """
 
   def __init__(self, resource: str, descriptor_type: Optional[str] = None, endpoints: Optional[Sequence[stem.Endpoint]] = None, compression: Union[stem.descriptor._Compression, Sequence[stem.descriptor._Compression]] = (Compression.GZIP,), retries: int = 2, fall_back_to_authority: bool = False, timeout: Optional[float] = None, start: bool = True, block: bool = False, validate: bool = False, document_handler: stem.descriptor.DocumentHandler = stem.descriptor.DocumentHandler.ENTRIES, **kwargs: Any) -> None:
-    self._loop = asyncio.get_event_loop()
-    self._loop_thread = threading.Thread(target = self._loop.run_forever, name = 'asyncio')
+    self._loop = asyncio.new_event_loop()
+    self._loop_thread = threading.Thread(target = self._loop.run_forever, name = 'query asyncio')
     self._loop_thread.setDaemon(True)
     self._loop_thread.start()
 
@@ -677,19 +698,23 @@ class Query(stem.util.AsyncClassWrapper):
       retries,
       fall_back_to_authority,
       timeout,
-      start,
-      block,
       validate,
       document_handler,
       **kwargs,
     )
 
+    if start:
+      self.start()
+
+    if block:
+      self.run(True)
+
   def start(self) -> None:
     """
     Starts downloading the scriptors if we haven't started already.
     """
 
-    self._call_async_method_soon('start')
+    self._execute_async_method('start')
 
   def run(self, suppress = False) -> List['stem.descriptor.Descriptor']:
     """
@@ -1146,10 +1171,9 @@ class DescriptorDownloader(object):
     return Query(resource, **args)
 
 
-async def _download_from_orport(endpoint: stem.ORPort, compression: Sequence[stem.descriptor._Compression], resource: str) -> Tuple[bytes, Dict[str, str]]:
+def _http_body_and_headers(data: bytes) -> Tuple[bytes, Dict[str, str]]:
   """
-  Downloads descriptors from the given orport. Payload is just like an http
-  response (headers and all)...
+  Parse the headers and decompressed body from a HTTP response, such as...
 
   ::
 
@@ -1164,112 +1188,41 @@ async def _download_from_orport(endpoint: stem.ORPort, compression: Sequence[ste
     identity-ed25519
     ... rest of the descriptor content...
 
-  :param endpoint: endpoint to download from
-  :param compression: compression methods for the request
-  :param resource: descriptor resource to download
+  :param data: HTTP response
 
-  :returns: two value tuple of the form (data, reply_headers)
+  :returns: **tuple** with the decompressed data and headers
 
   :raises:
-    * :class:`stem.ProtocolError` if not a valid descriptor response
-    * :class:`stem.SocketError` if unable to establish a connection
-  """
-
-  link_protocols = endpoint.link_protocols if endpoint.link_protocols else [3]
-
-  async with await stem.client.Relay.connect(endpoint.address, endpoint.port, link_protocols) as relay:
-    async with await relay.create_circuit() as circ:
-      request = '\r\n'.join((
-        'GET %s HTTP/1.0' % resource,
-        'Accept-Encoding: %s' % ', '.join(map(lambda c: c.encoding, compression)),
-        'User-Agent: %s' % stem.USER_AGENT,
-      )) + '\r\n\r\n'
-
-      response = await circ.directory(request, stream_id = 1)
-      first_line, data = response.split(b'\r\n', 1)
-      header_data, body_data = data.split(b'\r\n\r\n', 1)
-
-      if not first_line.startswith(b'HTTP/1.0 2'):
-        raise stem.ProtocolError("Response should begin with HTTP success, but was '%s'" % str_tools._to_unicode(first_line))
-
-      headers = {}
-
-      for line in str_tools._to_unicode(header_data).splitlines():
-        if ': ' not in line:
-          raise stem.ProtocolError("'%s' is not a HTTP header:\n\n%s" % (line, header_data.decode('utf-8')))
-
-        key, value = line.split(': ', 1)
-        headers[key] = value
-
-      return _decompress(body_data, headers.get('Content-Encoding')), headers
-
-
-async def _download_from_dirport(url: str, compression: Sequence[stem.descriptor._Compression], timeout: Optional[float]) -> Tuple[bytes, Dict[str, str]]:
-  """
-  Downloads descriptors from the given url.
-
-  :param url: dirport url from which to download from
-  :param compression: compression methods for the request
-  :param timeout: duration before we'll time out our request
-
-  :returns: two value tuple of the form (data, reply_headers)
-
-  :raises:
-    * :class:`~stem.DownloadTimeout` if our request timed out
-    * :class:`~stem.DownloadFailed` if our request fails
-  """
-
-  # TODO: use an asyncronous solution for the HTTP request.
-  request = urllib.request.Request(
-    url,
-    headers = {
-      'Accept-Encoding': ', '.join(map(lambda c: c.encoding, compression)),
-      'User-Agent': stem.USER_AGENT,
-    }
-  )
-  get_response = functools.partial(urllib.request.urlopen, request, timeout = timeout)
-
-  loop = asyncio.get_event_loop()
-  try:
-    response = await loop.run_in_executor(None, get_response)
-  except socket.timeout as exc:
-    raise stem.DownloadTimeout(url, exc, sys.exc_info()[2], timeout)
-  except:
-    exception, stacktrace = sys.exc_info()[1:3]
-    raise stem.DownloadFailed(url, exception, stacktrace)
-
-  return _decompress(response.read(), response.headers.get('Content-Encoding')), response.headers
-
-
-def _decompress(data: bytes, encoding: str) -> bytes:
+    * **stem.ProtocolError** if response was unsuccessful or malformed
+    * **ValueError** if encoding is unrecognized
+    * **ImportError** if missing the decompression module
   """
-  Decompresses descriptor data.
 
-  Tor doesn't include compression headers. As such when using gzip we
-  need to include '32' for automatic header detection...
+  first_line, data = data.split(b'\r\n', 1)
+  header_data, body_data = data.split(b'\r\n\r\n', 1)
 
-    https://stackoverflow.com/questions/3122145/zlib-error-error-3-while-decompressing-incorrect-header-check/22310760#22310760
+  if not first_line.startswith(b'HTTP/1.0 2'):
+    raise stem.ProtocolError("Response should begin with HTTP success, but was '%s'" % str_tools._to_unicode(first_line))
 
-  ... and with zstd we need to use the streaming API.
+  headers = {}
 
-  :param data: data we received
-  :param encoding: 'Content-Encoding' header of the response
+  for line in str_tools._to_unicode(header_data).splitlines():
+    if ': ' not in line:
+      raise stem.ProtocolError("'%s' is not a HTTP header:\n\n%s" % (line, header_data.decode('utf-8')))
 
-  :returns: **bytes** with the decompressed data
+    key, value = line.split(': ', 1)
+    headers[key] = value
 
-  :raises:
-    * **ValueError** if encoding is unrecognized
-    * **ImportError** if missing the decompression module
-  """
+  encoding = headers.get('Content-Encoding')
 
   if encoding == 'deflate':
-    return stem.descriptor.Compression.GZIP.decompress(data)
+    return stem.descriptor.Compression.GZIP.decompress(body_data), headers
 
   for compression in stem.descriptor.Compression:
     if encoding == compression.encoding:
-      return compression.decompress(data)
+      return compression.decompress(body_data), headers
 
-  raise ValueError("'%s' isn't a recognized type of encoding" % encoding)
+  raise ValueError("'%s' is an unrecognized encoding" % encoding)
 
 
 def _guess_descriptor_type(resource: str) -> str:
diff --git a/test/unit/descriptor/remote.py b/test/unit/descriptor/remote.py
index 33ee57fb..797bc8a3 100644
--- a/test/unit/descriptor/remote.py
+++ b/test/unit/descriptor/remote.py
@@ -2,9 +2,6 @@
 Unit tests for stem.descriptor.remote.
 """
 
-import http.client
-import socket
-import time
 import unittest
 
 import stem
@@ -67,47 +64,13 @@ HEADER = '\r\n'.join([
 ])
 
 
-def _orport_mock(data, encoding = 'identity', response_code_header = None):
+def mock_download(descriptor, encoding = 'identity', response_code_header = None):
   if response_code_header is None:
     response_code_header = b'HTTP/1.0 200 OK\r\n'
 
-  data = response_code_header + stem.util.str_tools._to_bytes(HEADER % encoding) + b'\r\n\r\n' + data
-  cells = []
+  data = response_code_header + stem.util.str_tools._to_bytes(HEADER % encoding) + b'\r\n\r\n' + descriptor
 
-  for hunk in [data[i:i + 50] for i in range(0, len(data), 50)]:
-    cell = Mock()
-    cell.data = hunk
-    cells.append(cell)
-
-  class AsyncMock(Mock):
-    async def __aenter__(self):
-      return self
-
-    async def __aexit__(self, exc_type, exc_val, exc_tb):
-      return
-
-  circ_mock = AsyncMock()
-  circ_mock.directory.side_effect = coro_func_returning_value(data)
-
-  relay_mock = AsyncMock()
-  relay_mock.create_circuit.side_effect = coro_func_returning_value(circ_mock)
-
-  return coro_func_returning_value(relay_mock)
-
-
-def _dirport_mock(data, encoding = 'identity'):
-  dirport_mock = Mock()
-  dirport_mock().read.return_value = data
-
-  headers = http.client.HTTPMessage()
-
-  for line in HEADER.splitlines():
-    key, value = line.split(': ', 1)
-    headers.add_header(key, encoding if key == 'Content-Encoding' else value)
-
-  dirport_mock().headers = headers
-
-  return dirport_mock
+  return patch('stem.descriptor.remote.AsyncQuery._download_from', Mock(side_effect = coro_func_returning_value(data)))
 
 
 class TestDescriptorDownloader(unittest.TestCase):
@@ -115,10 +78,10 @@ class TestDescriptorDownloader(unittest.TestCase):
     # prevent our mocks from impacting other tests
     stem.descriptor.remote.SINGLETON_DOWNLOADER = None
 
-  @patch('stem.client.Relay.connect', _orport_mock(TEST_DESCRIPTOR))
-  def test_using_orport(self):
+  @mock_download(TEST_DESCRIPTOR)
+  def test_download(self):
     """
-    Download a descriptor through the ORPort.
+    Simply download and parse a descriptor.
     """
 
     reply = stem.descriptor.remote.their_server_descriptor(
@@ -128,10 +91,16 @@ class TestDescriptorDownloader(unittest.TestCase):
     )
 
     self.assertEqual(1, len(list(reply)))
-    self.assertEqual('moria1', list(reply)[0].nickname)
     self.assertEqual(5, len(reply.reply_headers))
 
-  def test_orport_response_code_headers(self):
+    desc = list(reply)[0]
+
+    self.assertEqual('moria1', desc.nickname)
+    self.assertEqual('128.31.0.34', desc.address)
+    self.assertEqual('9695DFC35FFEB861329B9F1AB04C46397020CE31', desc.fingerprint)
+    self.assertEqual(TEST_DESCRIPTOR, desc.get_bytes())
+
+  def test_response_header_code(self):
     """
     When successful Tor provides a '200 OK' status, but we should accept other 2xx
     response codes, reason text, and recognize HTTP errors.
@@ -144,14 +113,14 @@ class TestDescriptorDownloader(unittest.TestCase):
     )
 
     for header in response_code_headers:
-      with patch('stem.client.Relay.connect', _orport_mock(TEST_DESCRIPTOR, response_code_header = header)):
+      with mock_download(TEST_DESCRIPTOR, response_code_header = header):
         stem.descriptor.remote.their_server_descriptor(
           endpoints = [stem.ORPort('12.34.56.78', 1100)],
           validate = True,
           skip_crypto_validation = not test.require.CRYPTOGRAPHY_AVAILABLE,
         ).run()
 
-    with patch('stem.client.Relay.connect', _orport_mock(TEST_DESCRIPTOR, response_code_header = b'HTTP/1.0 500 Kaboom\r\n')):
+    with mock_download(TEST_DESCRIPTOR, response_code_header = b'HTTP/1.0 500 Kaboom\r\n'):
       request = stem.descriptor.remote.their_server_descriptor(
         endpoints = [stem.ORPort('12.34.56.78', 1100)],
         validate = True,
@@ -160,28 +129,32 @@ class TestDescriptorDownloader(unittest.TestCase):
 
       self.assertRaisesRegexp(stem.ProtocolError, "^Response should begin with HTTP success, but was 'HTTP/1.0 500 Kaboom'", request.run)
 
-  @patch('urllib.request.urlopen', _dirport_mock(TEST_DESCRIPTOR))
-  def test_using_dirport(self):
-    """
-    Download a descriptor through the DirPort.
-    """
+  @mock_download(TEST_DESCRIPTOR)
+  def test_reply_header_data(self):
+    query = stem.descriptor.remote.get_server_descriptors('9695DFC35FFEB861329B9F1AB04C46397020CE31', start = False)
+    self.assertEqual(None, query.reply_headers)  # initially we don't have a reply
+    query.run()
 
-    reply = stem.descriptor.remote.their_server_descriptor(
-      endpoints = [stem.DirPort('12.34.56.78', 1100)],
-      validate = True,
-      skip_crypto_validation = not test.require.CRYPTOGRAPHY_AVAILABLE,
-    )
+    self.assertEqual('Fri, 13 Apr 2018 16:35:50 GMT', query.reply_headers.get('Date'))
+    self.assertEqual('application/octet-stream', query.reply_headers.get('Content-Type'))
+    self.assertEqual('97.103.17.56', query.reply_headers.get('X-Your-Address-Is'))
+    self.assertEqual('no-cache', query.reply_headers.get('Pragma'))
+    self.assertEqual('identity', query.reply_headers.get('Content-Encoding'))
 
-    self.assertEqual(1, len(list(reply)))
-    self.assertEqual('moria1', list(reply)[0].nickname)
-    self.assertEqual(5, len(reply.reply_headers))
+    # request a header that isn't present
+    self.assertEqual(None, query.reply_headers.get('no-such-header'))
+    self.assertEqual('default', query.reply_headers.get('no-such-header', 'default'))
+
+    descriptors = list(query)
+    self.assertEqual(1, len(descriptors))
+    self.assertEqual('moria1', descriptors[0].nickname)
 
   def test_gzip_url_override(self):
     query = stem.descriptor.remote.Query(TEST_RESOURCE + '.z', compression = Compression.PLAINTEXT, start = False)
     self.assertEqual([stem.descriptor.Compression.GZIP], query.compression)
     self.assertEqual(TEST_RESOURCE, query.resource)
 
-  @patch('urllib.request.urlopen', _dirport_mock(read_resource('compressed_identity'), encoding = 'identity'))
+  @mock_download(read_resource('compressed_identity'), encoding = 'identity')
   def test_compression_plaintext(self):
     """
     Download a plaintext descriptor.
@@ -197,7 +170,7 @@ class TestDescriptorDownloader(unittest.TestCase):
     self.assertEqual(1, len(descriptors))
     self.assertEqual('moria1', descriptors[0].nickname)
 
-  @patch('urllib.request.urlopen', _dirport_mock(read_resource('compressed_gzip'), encoding = 'gzip'))
+  @mock_download(read_resource('compressed_gzip'), encoding = 'gzip')
   def test_compression_gzip(self):
     """
     Download a gip compressed descriptor.
@@ -213,7 +186,7 @@ class TestDescriptorDownloader(unittest.TestCase):
     self.assertEqual(1, len(descriptors))
     self.assertEqual('moria1', descriptors[0].nickname)
 
-  @patch('urllib.request.urlopen', _dirport_mock(read_resource('compressed_zstd'), encoding = 'x-zstd'))
+  @mock_download(read_resource('compressed_zstd'), encoding = 'x-zstd')
   def test_compression_zstd(self):
     """
     Download a zstd compressed descriptor.
@@ -231,7 +204,7 @@ class TestDescriptorDownloader(unittest.TestCase):
     self.assertEqual(1, len(descriptors))
     self.assertEqual('moria1', descriptors[0].nickname)
 
-  @patch('urllib.request.urlopen', _dirport_mock(read_resource('compressed_lzma'), encoding = 'x-tor-lzma'))
+  @mock_download(read_resource('compressed_lzma'), encoding = 'x-tor-lzma')
   def test_compression_lzma(self):
     """
     Download a lzma compressed descriptor.
@@ -249,8 +222,8 @@ class TestDescriptorDownloader(unittest.TestCase):
     self.assertEqual(1, len(descriptors))
     self.assertEqual('moria1', descriptors[0].nickname)
 
-  @patch('urllib.request.urlopen')
-  def test_each_getter(self, dirport_mock):
+  @mock_download(TEST_DESCRIPTOR)
+  def test_each_getter(self):
     """
     Surface level exercising of each getter method for downloading descriptors.
     """
@@ -266,57 +239,8 @@ class TestDescriptorDownloader(unittest.TestCase):
     downloader.get_bandwidth_file()
     downloader.get_detached_signatures()
 
-  @patch('urllib.request.urlopen', _dirport_mock(TEST_DESCRIPTOR))
-  def test_reply_headers(self):
-    query = stem.descriptor.remote.get_server_descriptors('9695DFC35FFEB861329B9F1AB04C46397020CE31', start = False)
-    self.assertEqual(None, query.reply_headers)  # initially we don't have a reply
-    query.run()
-
-    self.assertEqual('Fri, 13 Apr 2018 16:35:50 GMT', query.reply_headers.get('date'))
-    self.assertEqual('application/octet-stream', query.reply_headers.get('content-type'))
-    self.assertEqual('97.103.17.56', query.reply_headers.get('x-your-address-is'))
-    self.assertEqual('no-cache', query.reply_headers.get('pragma'))
-    self.assertEqual('identity', query.reply_headers.get('content-encoding'))
-
-    # getting headers should be case insensitive
-    self.assertEqual('identity', query.reply_headers.get('CoNtEnT-ENCODING'))
-
-    # request a header that isn't present
-    self.assertEqual(None, query.reply_headers.get('no-such-header'))
-    self.assertEqual('default', query.reply_headers.get('no-such-header', 'default'))
-
-    descriptors = list(query)
-    self.assertEqual(1, len(descriptors))
-    self.assertEqual('moria1', descriptors[0].nickname)
-
-  @patch('urllib.request.urlopen', _dirport_mock(TEST_DESCRIPTOR))
-  def test_query_download(self):
-    """
-    Check Query functionality when we successfully download a descriptor.
-    """
-
-    query = stem.descriptor.remote.Query(
-      TEST_RESOURCE,
-      'server-descriptor 1.0',
-      endpoints = [stem.DirPort('128.31.0.39', 9131)],
-      compression = Compression.PLAINTEXT,
-      validate = True,
-      skip_crypto_validation = not test.require.CRYPTOGRAPHY_AVAILABLE,
-    )
-
-    self.assertEqual(stem.DirPort('128.31.0.39', 9131), query._wrapped_instance._pick_endpoint())
-
-    descriptors = list(query)
-    self.assertEqual(1, len(descriptors))
-    desc = descriptors[0]
-
-    self.assertEqual('moria1', desc.nickname)
-    self.assertEqual('128.31.0.34', desc.address)
-    self.assertEqual('9695DFC35FFEB861329B9F1AB04C46397020CE31', desc.fingerprint)
-    self.assertEqual(TEST_DESCRIPTOR, desc.get_bytes())
-
-  @patch('urllib.request.urlopen', _dirport_mock(b'some malformed stuff'))
-  def test_query_with_malformed_content(self):
+  @mock_download(b'some malformed stuff')
+  def test_malformed_content(self):
     """
     Query with malformed descriptor content.
     """
@@ -340,29 +264,6 @@ class TestDescriptorDownloader(unittest.TestCase):
 
     self.assertRaises(ValueError, query.run)
 
-  @patch('urllib.request.urlopen')
-  def test_query_with_timeout(self, dirport_mock):
-    def urlopen_call(*args, **kwargs):
-      time.sleep(0.06)
-      raise socket.timeout('connection timed out')
-
-    dirport_mock.side_effect = urlopen_call
-
-    query = stem.descriptor.remote.Query(
-      TEST_RESOURCE,
-      'server-descriptor 1.0',
-      endpoints = [stem.DirPort('128.31.0.39', 9131)],
-      fall_back_to_authority = False,
-      timeout = 0.1,
-      validate = True,
-    )
-
-    # After two requests we'll have reached our total permissable timeout.
-    # It would be nice to check that we don't make a third, but this
-    # assertion has proved unreliable so only checking for the exception.
-
-    self.assertRaises(stem.DownloadTimeout, query.run)
-
   def test_query_with_invalid_endpoints(self):
     invalid_endpoints = {
       'hello': "'h' is a str.",
@@ -375,7 +276,7 @@ class TestDescriptorDownloader(unittest.TestCase):
       expected_error = 'Endpoints must be an stem.ORPort or stem.DirPort. ' + error_suffix
       self.assertRaisesWith(ValueError, expected_error, stem.descriptor.remote.Query, TEST_RESOURCE, 'server-descriptor 1.0', endpoints = endpoints)
 
-  @patch('urllib.request.urlopen', _dirport_mock(TEST_DESCRIPTOR))
+  @mock_download(TEST_DESCRIPTOR)
   def test_can_iterate_multiple_times(self):
     query = stem.descriptor.remote.Query(
       TEST_RESOURCE,





More information about the tor-commits mailing list