[tor-commits] [stem/master] Use Compression class in stem.descriptor.remote

atagar at torproject.org atagar at torproject.org
Sat Aug 17 20:44:26 UTC 2019


commit 0c84d3cf4dcedd3629e6bb51c607240b1b2257b9
Author: Damian Johnson <atagar at torproject.org>
Date:   Wed Jun 19 15:55:39 2019 -0700

    Use Compression class in stem.descriptor.remote
    
    Now that our new Compression class is in place, use it to deduplicate the
    remote module's decompression code.
---
 stem/descriptor/__init__.py    |  9 ++++--
 stem/descriptor/remote.py      | 67 ++++++++++++++++++++++++------------------
 test/unit/descriptor/remote.py | 11 +++----
 3 files changed, 52 insertions(+), 35 deletions(-)

diff --git a/stem/descriptor/__init__.py b/stem/descriptor/__init__.py
index 0b3fda91..2c69aef0 100644
--- a/stem/descriptor/__init__.py
+++ b/stem/descriptor/__init__.py
@@ -225,7 +225,12 @@ class _Compression(object):
     """
 
     if not self.available:
-      raise ImportError("'%s' decompression module is unavailable" % self._module_name)
+      if self.name == 'zstd':
+        raise ImportError('Decompressing zstd data requires https://pypi.org/project/zstandard/')
+      elif self.name == 'lzma':
+        raise ImportError('Decompressing lzma data requires https://docs.python.org/3/library/lzma.html')
+      else:
+        raise ImportError("'%s' decompression module is unavailable" % self._module_name)
 
     return self._decompression_func(self._module, content)
 
@@ -247,7 +252,7 @@ Compression = stem.util.enum.Enum(
   ('GZIP', _Compression('gzip', 'zlib', 'gzip', '.gz', lambda module, content: module.decompress(content, module.MAX_WBITS | 32))),
   ('BZ2', _Compression('bzip2', 'bz2', 'bzip2', '.bz2', lambda module, content: module.decompress(content))),
   ('LZMA', _Compression('lzma', 'lzma', 'x-tor-lzma', '.xz', lambda module, content: module.decompress(content))),
-  ('ZSTD', _Compression('zstd', 'zstd', 'zstd', '.zst', _zstd_decompress)),
+  ('ZSTD', _Compression('zstd', 'zstd', 'x-zstd', '.zst', _zstd_decompress)),
 )
 
 
diff --git a/stem/descriptor/remote.py b/stem/descriptor/remote.py
index 15f46070..f715d743 100644
--- a/stem/descriptor/remote.py
+++ b/stem/descriptor/remote.py
@@ -100,7 +100,6 @@ import random
 import sys
 import threading
 import time
-import zlib
 
 import stem
 import stem.client
@@ -119,6 +118,8 @@ try:
 except ImportError:
   import urllib2 as urllib
 
+# TODO: remove in stem 2.x, replaced with stem.descriptor.Compression
+
 Compression = stem.util.enum.Enum(
   ('PLAINTEXT', 'identity'),
   ('GZIP', 'gzip'),  # can also be 'deflate'
@@ -126,6 +127,13 @@ Compression = stem.util.enum.Enum(
   ('LZMA', 'x-tor-lzma'),
 )
 
+COMPRESSION_MIGRATION = {
+  'identity': stem.descriptor.Compression.PLAINTEXT,
+  'gzip': stem.descriptor.Compression.GZIP,
+  'x-zstd': stem.descriptor.Compression.ZSTD,
+  'x-tor-lzma': stem.descriptor.Compression.LZMA,
+}
+
 # Tor has a limited number of descriptors we can fetch explicitly by their
 # fingerprint or hashes due to a limit on the url length by squid proxies.
 
@@ -364,6 +372,11 @@ class Query(object):
   .. versionchanged:: 1.8.0
      Defaulting to gzip compression rather than plaintext downloads.
 
+  .. versionchanged:: 1.8.0
+     Using :class:`~stem.descriptor.__init__.Compression` for our compression
+     argument, usage of strings or this module's Compression enum is deprecated
+     and will be removed in stem 2.x.
+
   :var str resource: resource being fetched, such as '/tor/server/all'
   :var str descriptor_type: type of descriptors being fetched (for options see
     :func:`~stem.descriptor.__init__.parse_file`), this is guessed from the
@@ -371,7 +384,7 @@ class Query(object):
 
   :var list endpoints: :class:`~stem.DirPort` or :class:`~stem.ORPort` of the
     authority or mirror we're querying, this uses authorities if undefined
-  :var list compression: list of :data:`stem.descriptor.remote.Compression`
+  :var list compression: list of :data:`stem.descriptor.Compression`
     we're willing to accept, when none are mutually supported downloads fall
     back to Compression.PLAINTEXT
   :var int retries: number of times to attempt the request if downloading it
@@ -429,6 +442,19 @@ class Query(object):
       if not compression:
         compression = [Compression.PLAINTEXT]
 
+    # TODO: Normalize from our old compression enum to
+    # stem.descriptor.Compression. This will get removed in Stem 2.x.
+
+    new_compression = []
+
+    for legacy_compression in compression:
+      if isinstance(legacy_compression, stem.descriptor._Compression):
+        new_compression.append(legacy_compression)
+      elif legacy_compression in COMPRESSION_MIGRATION:
+        new_compression.append(COMPRESSION_MIGRATION[legacy_compression])
+      else:
+        raise ValueError("'%s' (%s) is not a recognized type of compression" % (legacy_compression, type(legacy_compression).__name__))
+
     if descriptor_type:
       self.descriptor_type = descriptor_type
     else:
@@ -446,7 +472,7 @@ class Query(object):
           raise ValueError("Endpoints must be an stem.ORPort, stem.DirPort, or two value tuple. '%s' is a %s." % (endpoint, type(endpoint).__name__))
 
     self.resource = resource
-    self.compression = compression
+    self.compression = new_compression
     self.retries = retries
     self.fall_back_to_authority = fall_back_to_authority
 
@@ -1009,7 +1035,7 @@ def _download_from_orport(endpoint, compression, resource):
     with relay.create_circuit() as circ:
       request = '\r\n'.join((
         'GET %s HTTP/1.0' % resource,
-        'Accept-Encoding: %s' % ', '.join(compression),
+        'Accept-Encoding: %s' % ', '.join(map(lambda c: c.encoding, compression)),
         'User-Agent: %s' % stem.USER_AGENT,
       )) + '\r\n\r\n'
 
@@ -1051,7 +1077,7 @@ def _download_from_dirport(url, compression, timeout):
     urllib.Request(
       url,
       headers = {
-        'Accept-Encoding': ', '.join(compression),
+        'Accept-Encoding': ', '.join(map(lambda c: c.encoding, compression)),
         'User-Agent': stem.USER_AGENT,
       }
     ),
@@ -1080,29 +1106,14 @@ def _decompress(data, encoding):
     * **ImportError** if missing the decompression module
   """
 
-  if encoding == Compression.PLAINTEXT:
-    return data
-  elif encoding in (Compression.GZIP, 'deflate'):
-    return zlib.decompress(data, zlib.MAX_WBITS | 32)
-  elif encoding == Compression.ZSTD:
-    if not stem.prereq.is_zstd_available():
-      raise ImportError('Decompressing zstd data requires https://pypi.org/project/zstandard/')
-
-    import zstd
-    output_buffer = io.BytesIO()
-
-    with zstd.ZstdDecompressor().write_to(output_buffer) as decompressor:
-      decompressor.write(data)
-
-    return output_buffer.getvalue()
-  elif encoding == Compression.LZMA:
-    if not stem.prereq.is_lzma_available():
-      raise ImportError('Decompressing lzma data requires https://docs.python.org/3/library/lzma.html')
-
-    import lzma
-    return lzma.decompress(data)
-  else:
-    raise ValueError("'%s' isn't a recognized type of encoding" % encoding)
+  if encoding == 'deflate':
+    return stem.descriptor.Compression.GZIP.decompress(data)
+
+  for compression in stem.descriptor.Compression:
+    if encoding == compression.encoding:
+      return compression.decompress(data)
+
+  raise ValueError("'%s' isn't a recognized type of encoding" % encoding)
 
 
 def _guess_descriptor_type(resource):
diff --git a/test/unit/descriptor/remote.py b/test/unit/descriptor/remote.py
index 8c5e835b..6dbaf43e 100644
--- a/test/unit/descriptor/remote.py
+++ b/test/unit/descriptor/remote.py
@@ -8,6 +8,7 @@ import time
 import unittest
 
 import stem
+import stem.descriptor
 import stem.descriptor.remote
 import stem.prereq
 import stem.util.str_tools
@@ -181,26 +182,26 @@ class TestDescriptorDownloader(unittest.TestCase):
 
   def test_gzip_url_override(self):
     query = stem.descriptor.remote.Query(TEST_RESOURCE + '.z', compression = Compression.PLAINTEXT, start = False)
-    self.assertEqual([Compression.GZIP], query.compression)
+    self.assertEqual([stem.descriptor.Compression.GZIP], query.compression)
     self.assertEqual(TEST_RESOURCE, query.resource)
 
   def test_zstd_support_check(self):
     with patch('stem.prereq.is_zstd_available', Mock(return_value = True)):
       query = stem.descriptor.remote.Query(TEST_RESOURCE, compression = Compression.ZSTD, start = False)
-      self.assertEqual([Compression.ZSTD], query.compression)
+      self.assertEqual([stem.descriptor.Compression.ZSTD], query.compression)
 
     with patch('stem.prereq.is_zstd_available', Mock(return_value = False)):
       query = stem.descriptor.remote.Query(TEST_RESOURCE, compression = Compression.ZSTD, start = False)
-      self.assertEqual([Compression.PLAINTEXT], query.compression)
+      self.assertEqual([stem.descriptor.Compression.PLAINTEXT], query.compression)
 
   def test_lzma_support_check(self):
     with patch('stem.prereq.is_lzma_available', Mock(return_value = True)):
       query = stem.descriptor.remote.Query(TEST_RESOURCE, compression = Compression.LZMA, start = False)
-      self.assertEqual([Compression.LZMA], query.compression)
+      self.assertEqual([stem.descriptor.Compression.LZMA], query.compression)
 
     with patch('stem.prereq.is_lzma_available', Mock(return_value = False)):
       query = stem.descriptor.remote.Query(TEST_RESOURCE, compression = Compression.LZMA, start = False)
-      self.assertEqual([Compression.PLAINTEXT], query.compression)
+      self.assertEqual([stem.descriptor.Compression.PLAINTEXT], query.compression)
 
   @patch(URL_OPEN, _dirport_mock(read_resource('compressed_identity'), encoding = 'identity'))
   def test_compression_plaintext(self):





More information about the tor-commits mailing list