[tor-commits] [stem/master] Helper for downloading from dirports

atagar at torproject.org atagar at torproject.org
Tue Apr 24 19:41:39 UTC 2018


commit 3b5343e8c35044a1c05af09d7153b7f259e6d979
Author: Damian Johnson <atagar at torproject.org>
Date:   Sat Apr 21 12:40:39 2018 -0700

    Helper for downloading from dirports
    
    Our _download_descriptors() method has grown pretty big. Downloading and
    decompressing DirPort responses is a good candidate for a little helper.
---
 stem/__init__.py               |  13 +++++
 stem/descriptor/remote.py      | 107 +++++++++++++++++++++++------------------
 test/unit/descriptor/remote.py |   3 +-
 test/unit/endpoint.py          |  14 ++++++
 4 files changed, 89 insertions(+), 48 deletions(-)

diff --git a/stem/__init__.py b/stem/__init__.py
index c8c43f65..c7432d86 100644
--- a/stem/__init__.py
+++ b/stem/__init__.py
@@ -476,6 +476,7 @@ Library for working with the tor process.
   ================= ===========
 """
 
+import stem.util
 import stem.util.enum
 
 __version__ = '1.6.0-dev'
@@ -637,6 +638,15 @@ class Endpoint(object):
     self.address = address
     self.port = int(port)
 
+  def __hash__(self):
+    return stem.util._hash_attr(self, 'address', 'port')
+
+  def __eq__(self, other):
+    return hash(self) == hash(other) if isinstance(other, Endpoint) else False
+
+  def __ne__(self, other):
+    return not self == other
+
 
 class ORPort(Endpoint):
   """
@@ -649,6 +659,9 @@ class ORPort(Endpoint):
     super(ORPort, self).__init__(address, port)
     self.link_protocols = link_protocols
 
+  def __hash__(self):
+    return super(ORPort, self).__hash__() + stem.util._hash_attr(self, 'link_protocols', 'port') * 10
+
 
 class DirPort(Endpoint):
   """
diff --git a/stem/descriptor/remote.py b/stem/descriptor/remote.py
index 5b942787..8835e8d9 100644
--- a/stem/descriptor/remote.py
+++ b/stem/descriptor/remote.py
@@ -236,6 +236,57 @@ def get_consensus(authority_v3ident = None, microdescriptor = False, **query_arg
   return get_instance().get_consensus(authority_v3ident, microdescriptor, **query_args)
 
 
+def _download_from_dirport(url, compression, timeout):
+  """
+  Downloads descriptors from the given url.
+
+  :param str url: dirport url from which to download
+  :param list compression: compression methods for the request
+  :param float timeout: duration before we'll time out our request
+
+  :returns: two value tuple of the form (data, reply_headers)
+
+  :raises:
+    * **socket.timeout** if our request timed out
+    * **urllib2.URLError** for most request failures
+  """
+
+  response = urllib.urlopen(
+    urllib.Request(
+      url,
+      headers = {
+        'Accept-Encoding': ', '.join(compression),
+        'User-Agent': 'Stem/%s' % stem.__version__,
+      }
+    ),
+    timeout = timeout,
+  )
+
+  data = response.read()
+  encoding = response.headers.get('Content-Encoding')
+
+  # Tor doesn't include compression headers. As such when using gzip we
+  # need to include '32' for automatic header detection...
+  #
+  #   https://stackoverflow.com/questions/3122145/zlib-error-error-3-while-decompressing-incorrect-header-check/22310760#22310760
+  #
+  # ... and with zstd we need to use the streaming API.
+
+  if encoding in (Compression.GZIP, 'deflate'):
+    data = zlib.decompress(data, zlib.MAX_WBITS | 32)
+  elif encoding == Compression.ZSTD and ZSTD_SUPPORTED:
+    output_buffer = io.BytesIO()
+
+    with zstd.ZstdDecompressor().write_to(output_buffer) as decompressor:
+      decompressor.write(data)
+
+    data = output_buffer.getvalue()
+  elif encoding == Compression.LZMA and LZMA_SUPPORTED:
+    data = lzma.decompress(data)
+
+  return data.strip(), response.headers
+
+
 def _guess_descriptor_type(resource):
   # Attempts to determine the descriptor type based on the resource url. This
   # raises a ValueError if the resource isn't recognized.
@@ -411,7 +462,7 @@ class Query(object):
     if endpoints:
       for endpoint in endpoints:
         if isinstance(endpoint, tuple) and len(endpoint) == 2:
-          self.endpoints.append(stem.DirPort(endpoint[0], endpoint[1]))
+          self.endpoints.append(stem.DirPort(endpoint[0], endpoint[1]))  # TODO: remove this in stem 2.0
         elif isinstance(endpoint, (stem.ORPort, stem.DirPort)):
           self.endpoints.append(endpoint)
         else:
@@ -524,10 +575,10 @@ class Query(object):
     for desc in self._run(True):
       yield desc
 
-  def _pick_url(self, use_authority = False):
+  def _pick_endpoint(self, use_authority = False):
     """
-    Provides a url that can be queried. If we have multiple endpoints then one
-    will be picked randomly.
+    Provides an endpoint to query. If we have multiple endpoints then one
+    is picked at random.
 
     :param bool use_authority: ignores our endpoints and uses a directory
       authority instead
@@ -539,54 +590,18 @@ class Query(object):
       directories = get_authorities().values()
 
       picked = random.choice(list(directories))
-      address, dirport = picked.address, picked.dir_port
+      return stem.DirPort(picked.address, picked.dir_port)
     else:
-      picked = random.choice(self.endpoints)
-      address, dirport = picked.address, picked.port
-
-    return 'http://%s:%i/%s' % (address, dirport, self.resource.lstrip('/'))
+      return random.choice(self.endpoints)
 
   def _download_descriptors(self, retries, timeout):
     try:
       use_authority = retries == 0 and self.fall_back_to_authority
-      self.download_url = self._pick_url(use_authority)
-      self.start_time = time.time()
-
-      response = urllib.urlopen(
-        urllib.Request(
-          self.download_url,
-          headers = {
-            'Accept-Encoding': ', '.join(self.compression),
-            'User-Agent': 'Stem/%s' % stem.__version__,
-          }
-        ),
-        timeout = timeout,
-      )
+      endpoint = self._pick_endpoint(use_authority)
+      self.download_url = 'http://%s:%i/%s' % (endpoint.address, endpoint.port, self.resource.lstrip('/'))
 
-      data = response.read()
-      encoding = response.headers.get('Content-Encoding')
-
-      # Tor doesn't include compression headers. As such when using gzip we
-      # need to include '32' for automatic header detection...
-      #
-      #   https://stackoverflow.com/questions/3122145/zlib-error-error-3-while-decompressing-incorrect-header-check/22310760#22310760
-      #
-      # ... and with zstd we need to use the streaming API.
-
-      if encoding in (Compression.GZIP, 'deflate'):
-        data = zlib.decompress(data, zlib.MAX_WBITS | 32)
-      elif encoding == Compression.ZSTD and ZSTD_SUPPORTED:
-        output_buffer = io.BytesIO()
-
-        with zstd.ZstdDecompressor().write_to(output_buffer) as decompressor:
-          decompressor.write(data)
-
-        data = output_buffer.getvalue()
-      elif encoding == Compression.LZMA and LZMA_SUPPORTED:
-        data = lzma.decompress(data)
-
-      self.content = data.strip()
-      self.reply_headers = response.headers
+      self.start_time = time.time()
+      self.content, self.reply_headers = _download_from_dirport(self.download_url, self.compression, timeout)
       self.runtime = time.time() - self.start_time
       log.trace("Descriptors retrieved from '%s' in %0.2fs" % (self.download_url, self.runtime))
     except:
diff --git a/test/unit/descriptor/remote.py b/test/unit/descriptor/remote.py
index deb10060..b2cfa7c1 100644
--- a/test/unit/descriptor/remote.py
+++ b/test/unit/descriptor/remote.py
@@ -282,8 +282,7 @@ class TestDescriptorDownloader(unittest.TestCase):
       validate = True,
     )
 
-    expeced_url = 'http://128.31.0.39:9131' + TEST_RESOURCE
-    self.assertEqual(expeced_url, query._pick_url())
+    self.assertEqual(stem.DirPort('128.31.0.39', 9131), query._pick_endpoint())
 
     descriptors = list(query)
     self.assertEqual(1, len(descriptors))
diff --git a/test/unit/endpoint.py b/test/unit/endpoint.py
index ad710a22..397323a7 100644
--- a/test/unit/endpoint.py
+++ b/test/unit/endpoint.py
@@ -30,3 +30,17 @@ class TestEndpoint(unittest.TestCase):
     self.assertRaises(ValueError, stem.DirPort, 'hello', 80)
     self.assertRaises(ValueError, stem.DirPort, -5, 80)
     self.assertRaises(ValueError, stem.DirPort, None, 80)
+
+  def test_equality(self):
+    self.assertTrue(stem.ORPort('12.34.56.78', 80) == stem.ORPort('12.34.56.78', 80))
+    self.assertTrue(stem.ORPort('12.34.56.78', 80, [1, 2, 3]) == stem.ORPort('12.34.56.78', 80, [1, 2, 3]))
+    self.assertFalse(stem.ORPort('12.34.56.78', 80) == stem.ORPort('12.34.56.88', 80))
+    self.assertFalse(stem.ORPort('12.34.56.78', 80) == stem.ORPort('12.34.56.78', 443))
+    self.assertFalse(stem.ORPort('12.34.56.78', 80, [2, 3]) == stem.ORPort('12.34.56.78', 80, [1, 2, 3]))
+
+    self.assertTrue(stem.DirPort('12.34.56.78', 80) == stem.DirPort('12.34.56.78', 80))
+    self.assertFalse(stem.DirPort('12.34.56.78', 80) == stem.DirPort('12.34.56.88', 80))
+    self.assertFalse(stem.DirPort('12.34.56.78', 80) == stem.DirPort('12.34.56.78', 443))
+
+    self.assertFalse(stem.ORPort('12.34.56.78', 80) == stem.DirPort('12.34.56.78', 80))
+    self.assertFalse(stem.DirPort('12.34.56.78', 80) == stem.ORPort('12.34.56.78', 80))





More information about the tor-commits mailing list