[tor-commits] [stem/master] Compression support for ORPort descriptor downloads

atagar at torproject.org atagar at torproject.org
Sun Apr 29 01:30:17 UTC 2018


commit 5a875ed329ddb119c2b6787e051571fb7323f622
Author: Damian Johnson <atagar at torproject.org>
Date:   Sat Apr 28 16:47:42 2018 -0700

    Compression support for ORPort descriptor downloads
    
    Compression works the same regardless of whether we download from an ORPort
    or DirPort. Also including a couple python3 compatibility fixes for circuit
    construction.
---
 stem/client/__init__.py        |  4 +-
 stem/descriptor/remote.py      | 88 +++++++++++++++++++++++++++---------------
 test/unit/descriptor/remote.py | 10 ++---
 3 files changed, 62 insertions(+), 40 deletions(-)

diff --git a/stem/client/__init__.py b/stem/client/__init__.py
index aa4aa274..6e25f748 100644
--- a/stem/client/__init__.py
+++ b/stem/client/__init__.py
@@ -168,7 +168,7 @@ class Relay(object):
       if not created_fast_cells:
         raise ValueError('We should get a CREATED_FAST response from a CREATE_FAST request')
 
-      created_fast_cell = created_fast_cells[0]
+      created_fast_cell = list(created_fast_cells)[0]
       kdf = KDF.from_value(create_fast_cell.key_material + created_fast_cell.key_material)
 
       if created_fast_cell.derivative_key != kdf.key_hash:
@@ -211,7 +211,7 @@ class Circuit(object):
     from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
     from cryptography.hazmat.backends import default_backend
 
-    ctr = modes.CTR(ZERO * (algorithms.AES.block_size / 8))
+    ctr = modes.CTR(ZERO * (algorithms.AES.block_size // 8))
 
     self.relay = relay
     self.id = circ_id
diff --git a/stem/descriptor/remote.py b/stem/descriptor/remote.py
index e745ce20..d6832eaa 100644
--- a/stem/descriptor/remote.py
+++ b/stem/descriptor/remote.py
@@ -254,7 +254,7 @@ def get_consensus(authority_v3ident = None, microdescriptor = False, **query_arg
   return get_instance().get_consensus(authority_v3ident, microdescriptor, **query_args)
 
 
-def _download_from_orport(endpoint, resource):
+def _download_from_orport(endpoint, compression, resource):
   """
   Downloads descriptors from the given orport. Payload is just like an http
   response (headers and all)...
@@ -273,6 +273,7 @@ def _download_from_orport(endpoint, resource):
     ... rest of the descriptor content...
 
   :param stem.ORPort endpoint: endpoint to download from
+  :param list compression: compression methods for the request
   :param str resource: descriptor resource to download
 
   :returns: two value tuple of the form (data, reply_headers)
@@ -286,26 +287,30 @@ def _download_from_orport(endpoint, resource):
 
   with stem.client.Relay.connect(endpoint.address, endpoint.port, link_protocols) as relay:
     with relay.create_circuit() as circ:
+      request = '\r\n'.join((
+        'GET %s HTTP/1.0' % resource,
+        'Accept-Encoding: %s' % ', '.join(compression),
+        'User-Agent: Stem/%s' % stem.__version__,
+      )) + '\r\n\r\n'
+
       circ.send('RELAY_BEGIN_DIR', stream_id = 1)
-      lines = b''.join([cell.data for cell in circ.send('RELAY_DATA', 'GET %s HTTP/1.0\r\n\r\n' % resource, stream_id = 1)]).splitlines()
-      first_line = lines.pop(0)
+      response = b''.join([cell.data for cell in circ.send('RELAY_DATA', request, stream_id = 1)])
+      first_line, data = response.split(b'\r\n', 1)
+      header_data, data = data.split(b'\r\n\r\n', 1)
 
-      if first_line != 'HTTP/1.0 200 OK':
+      if first_line != b'HTTP/1.0 200 OK':
         raise stem.ProtocolError("Response should begin with HTTP success, but was '%s'" % first_line)
 
       headers = {}
-      next_line = lines.pop(0)
 
-      while next_line:
-        if ': ' not in next_line:
-          raise stem.ProtocolError("'%s' is not a HTTP header:\n\n%s" % next_line)
+      for line in str_tools._to_unicode(header_data).splitlines():
+        if ': ' not in line:
+          raise stem.ProtocolError("'%s' is not a HTTP header:\n\n%s" % line)
 
-        key, value = next_line.split(': ', 1)
+        key, value = line.split(': ', 1)
         headers[key] = value
 
-        next_line = lines.pop(0)
-
-      return '\n'.join(lines), headers
+      return _decompress(data, headers.get('Content-Encoding')), headers
 
 
 def _download_from_dirport(url, compression, timeout):
@@ -334,29 +339,49 @@ def _download_from_dirport(url, compression, timeout):
     timeout = timeout,
   )
 
-  data = response.read()
-  encoding = response.headers.get('Content-Encoding')
+  return _decompress(response.read(), response.headers.get('Content-Encoding')), response.headers
 
-  # Tor doesn't include compression headers. As such when using gzip we
-  # need to include '32' for automatic header detection...
-  #
-  #   https://stackoverflow.com/questions/3122145/zlib-error-error-3-while-decompressing-incorrect-header-check/22310760#22310760
-  #
-  # ... and with zstd we need to use the streaming API.
 
-  if encoding in (Compression.GZIP, 'deflate'):
-    data = zlib.decompress(data, zlib.MAX_WBITS | 32)
-  elif encoding == Compression.ZSTD and ZSTD_SUPPORTED:
+def _decompress(data, encoding):
+  """
+  Decompresses descriptor data.
+
+  Tor doesn't include compression headers. As such when using gzip we
+  need to include '32' for automatic header detection...
+
+    https://stackoverflow.com/questions/3122145/zlib-error-error-3-while-decompressing-incorrect-header-check/22310760#22310760
+
+  ... and with zstd we need to use the streaming API.
+
+  :param bytes data: data we received
+  :param str encoding: 'Content-Encoding' header of the response
+
+  :raises:
+    * **ValueError** if encoding is unrecognized
+    * **ImportError** if missing the decompression module
+  """
+
+  if encoding == Compression.PLAINTEXT:
+    return data.strip()
+  elif encoding in (Compression.GZIP, 'deflate'):
+    return zlib.decompress(data, zlib.MAX_WBITS | 32).strip()
+  elif encoding == Compression.ZSTD:
+    if not ZSTD_SUPPORTED:
+      raise ImportError('Decompressing zstd data requires https://pypi.python.org/pypi/zstandard')
+
     output_buffer = io.BytesIO()
 
     with zstd.ZstdDecompressor().write_to(output_buffer) as decompressor:
       decompressor.write(data)
 
-    data = output_buffer.getvalue()
-  elif encoding == Compression.LZMA and LZMA_SUPPORTED:
-    data = lzma.decompress(data)
+    return output_buffer.getvalue().strip()
+  elif encoding == Compression.LZMA:
+    if not LZMA_SUPPORTED:
+      raise ImportError('Decompressing lzma data requires https://docs.python.org/3/library/lzma.html')
 
-  return data.strip(), response.headers
+    return lzma.decompress(data).strip()
+  else:
+    raise ValueError("'%s' isn't a recognized type of encoding" % encoding)
 
 
 def _guess_descriptor_type(resource):
@@ -476,6 +501,9 @@ class Query(object):
 
   :var list endpoints: :class:`~stem.DirPort` or :class:`~stem.ORPort` of the
     authority or mirror we're querying, this uses authorities if undefined
+  :var list compression: list of :data:`stem.descriptor.remote.Compression`
+    we're willing to accept, when none are mutually supported downloads fall
+    back to Compression.PLAINTEXT
   :var int retries: number of times to attempt the request if downloading it
     fails
   :var bool fall_back_to_authority: when retrying request issues the last
@@ -500,11 +528,7 @@ class Query(object):
   Following are only applicable when downloading from a
   :class:`~stem.DirPort`...
 
-  :var list compression: list of :data:`stem.descriptor.remote.Compression`
-    we're willing to accept, when none are mutually supported downloads fall
-    back to Compression.PLAINTEXT
   :var float timeout: duration before we'll time out our request
-
   :var str download_url: last url used to download the descriptor, this is
     unset until we've actually made a download attempt
 
@@ -683,7 +707,7 @@ class Query(object):
       endpoint = self._pick_endpoint(use_authority = retries == 0 and self.fall_back_to_authority)
 
       if isinstance(endpoint, stem.ORPort):
-        self.content, self.reply_headers = _download_from_orport(endpoint, self.resource)
+        self.content, self.reply_headers = _download_from_orport(endpoint, self.compression, self.resource)
       elif isinstance(endpoint, stem.DirPort):
         self.download_url = 'http://%s:%i/%s' % (endpoint.address, endpoint.port, self.resource.lstrip('/'))
         self.content, self.reply_headers = _download_from_dirport(self.download_url, self.compression, timeout)
diff --git a/test/unit/descriptor/remote.py b/test/unit/descriptor/remote.py
index 040d4c6c..753681e0 100644
--- a/test/unit/descriptor/remote.py
+++ b/test/unit/descriptor/remote.py
@@ -12,6 +12,7 @@ import unittest
 import stem.descriptor.remote
 import stem.prereq
 import stem.util.conf
+import stem.util.str_tools
 
 from stem.descriptor.remote import Compression
 from test.unit.descriptor import read_resource
@@ -126,10 +127,9 @@ HEADER = '\r\n'.join([
   'Content-Encoding: %s',
 ])
 
-ORPORT_DESCRIPTOR = 'HTTP/1.0 200 OK\n' + HEADER + '\n\n' + TEST_DESCRIPTOR
 
-
-def _orport_mock(data):
+def _orport_mock(data, encoding = 'identity'):
+  data = b'HTTP/1.0 200 OK\r\n' + stem.util.str_tools._to_bytes(HEADER % encoding) + b'\r\n\r\n' + data
   cells = []
 
   for hunk in [data[i:i + 50] for i in range(0, len(data), 50)]:
@@ -167,7 +167,7 @@ class TestDescriptorDownloader(unittest.TestCase):
     # prevent our mocks from impacting other tests
     stem.descriptor.remote.SINGLETON_DOWNLOADER = None
 
-  @patch('stem.client.Relay.connect', _orport_mock(ORPORT_DESCRIPTOR))
+  @patch('stem.client.Relay.connect', _orport_mock(TEST_DESCRIPTOR))
   def test_using_orport(self):
     """
     Download a descriptor through the ORPort.
@@ -175,7 +175,6 @@ class TestDescriptorDownloader(unittest.TestCase):
 
     reply = stem.descriptor.remote.their_server_descriptor(
       endpoints = [stem.ORPort('12.34.56.78', 1100)],
-      fall_back_to_authority = False,
       validate = True,
     )
 
@@ -191,7 +190,6 @@ class TestDescriptorDownloader(unittest.TestCase):
 
     reply = stem.descriptor.remote.their_server_descriptor(
       endpoints = [stem.DirPort('12.34.56.78', 1100)],
-      fall_back_to_authority = False,
       validate = True,
     )
 





More information about the tor-commits mailing list