commit 3b5343e8c35044a1c05af09d7153b7f259e6d979 Author: Damian Johnson atagar@torproject.org Date: Sat Apr 21 12:40:39 2018 -0700
Helper for downloading from dirports
Our _download_descriptors() has grown pretty big. Downloading and decompressing DirPort responses is a good candidate for a little helper. --- stem/__init__.py | 13 +++++ stem/descriptor/remote.py | 107 +++++++++++++++++++++++------------------ test/unit/descriptor/remote.py | 3 +- test/unit/endpoint.py | 14 ++++++ 4 files changed, 89 insertions(+), 48 deletions(-)
diff --git a/stem/__init__.py b/stem/__init__.py index c8c43f65..c7432d86 100644 --- a/stem/__init__.py +++ b/stem/__init__.py @@ -476,6 +476,7 @@ Library for working with the tor process. ================= =========== """
+import stem.util import stem.util.enum
__version__ = '1.6.0-dev' @@ -637,6 +638,15 @@ class Endpoint(object): self.address = address self.port = int(port)
+ def __hash__(self): + return stem.util._hash_attr(self, 'address', 'port') + + def __eq__(self, other): + return hash(self) == hash(other) if isinstance(other, Endpoint) else False + + def __ne__(self, other): + return not self == other +
class ORPort(Endpoint): """ @@ -649,6 +659,9 @@ class ORPort(Endpoint): super(ORPort, self).__init__(address, port) self.link_protocols = link_protocols
+ def __hash__(self): + return super(ORPort, self).__hash__() + stem.util._hash_attr(self, 'link_protocols', 'port') * 10 +
class DirPort(Endpoint): """ diff --git a/stem/descriptor/remote.py b/stem/descriptor/remote.py index 5b942787..8835e8d9 100644 --- a/stem/descriptor/remote.py +++ b/stem/descriptor/remote.py @@ -236,6 +236,57 @@ def get_consensus(authority_v3ident = None, microdescriptor = False, **query_arg return get_instance().get_consensus(authority_v3ident, microdescriptor, **query_args)
+def _download_from_dirport(url, compression, timeout): + """ + Downloads descriptors from the given url. + + :param str url: dirport url from which to download from + :param list compression: compression methods for the request + :param float timeout: duration before we'll time out our request + + :returns: two value tuple of the form (data, reply_headers) + + :raises: + * **socket.timeout** if our request timed out + * **urllib2.URLError** for most request failures + """ + + response = urllib.urlopen( + urllib.Request( + url, + headers = { + 'Accept-Encoding': ', '.join(compression), + 'User-Agent': 'Stem/%s' % stem.__version__, + } + ), + timeout = timeout, + ) + + data = response.read() + encoding = response.headers.get('Content-Encoding') + + # Tor doesn't include compression headers. As such when using gzip we + # need to include '32' for automatic header detection... + # + # https://stackoverflow.com/questions/3122145/zlib-error-error-3-while-decompr... + # + # ... and with zstd we need to use the streaming API. + + if encoding in (Compression.GZIP, 'deflate'): + data = zlib.decompress(data, zlib.MAX_WBITS | 32) + elif encoding == Compression.ZSTD and ZSTD_SUPPORTED: + output_buffer = io.BytesIO() + + with zstd.ZstdDecompressor().write_to(output_buffer) as decompressor: + decompressor.write(data) + + data = output_buffer.getvalue() + elif encoding == Compression.LZMA and LZMA_SUPPORTED: + data = lzma.decompress(data) + + return data.strip(), response.headers + + def _guess_descriptor_type(resource): # Attempts to determine the descriptor type based on the resource url. This # raises a ValueError if the resource isn't recognized. 
@@ -411,7 +462,7 @@ class Query(object): if endpoints: for endpoint in endpoints: if isinstance(endpoint, tuple) and len(endpoint) == 2: - self.endpoints.append(stem.DirPort(endpoint[0], endpoint[1])) + self.endpoints.append(stem.DirPort(endpoint[0], endpoint[1])) # TODO: remove this in stem 2.0 elif isinstance(endpoint, (stem.ORPort, stem.DirPort)): self.endpoints.append(endpoint) else: @@ -524,10 +575,10 @@ class Query(object): for desc in self._run(True): yield desc
- def _pick_url(self, use_authority = False): + def _pick_endpoint(self, use_authority = False): """ - Provides a url that can be queried. If we have multiple endpoints then one - will be picked randomly. + Provides an endpoint to query. If we have multiple endpoints then one + is picked at random.
:param bool use_authority: ignores our endpoints and uses a directory authority instead @@ -539,54 +590,18 @@ class Query(object): directories = get_authorities().values()
picked = random.choice(list(directories)) - address, dirport = picked.address, picked.dir_port + return stem.DirPort(picked.address, picked.dir_port) else: - picked = random.choice(self.endpoints) - address, dirport = picked.address, picked.port - - return 'http://%s:%i/%s' % (address, dirport, self.resource.lstrip('/')) + return random.choice(self.endpoints)
def _download_descriptors(self, retries, timeout): try: use_authority = retries == 0 and self.fall_back_to_authority - self.download_url = self._pick_url(use_authority) - self.start_time = time.time() - - response = urllib.urlopen( - urllib.Request( - self.download_url, - headers = { - 'Accept-Encoding': ', '.join(self.compression), - 'User-Agent': 'Stem/%s' % stem.__version__, - } - ), - timeout = timeout, - ) + endpoint = self._pick_endpoint(use_authority) + self.download_url = 'http://%s:%i/%s' % (endpoint.address, endpoint.port, self.resource.lstrip('/'))
- data = response.read() - encoding = response.headers.get('Content-Encoding') - - # Tor doesn't include compression headers. As such when using gzip we - # need to include '32' for automatic header detection... - # - # https://stackoverflow.com/questions/3122145/zlib-error-error-3-while-decompr... - # - # ... and with zstd we need to use the streaming API. - - if encoding in (Compression.GZIP, 'deflate'): - data = zlib.decompress(data, zlib.MAX_WBITS | 32) - elif encoding == Compression.ZSTD and ZSTD_SUPPORTED: - output_buffer = io.BytesIO() - - with zstd.ZstdDecompressor().write_to(output_buffer) as decompressor: - decompressor.write(data) - - data = output_buffer.getvalue() - elif encoding == Compression.LZMA and LZMA_SUPPORTED: - data = lzma.decompress(data) - - self.content = data.strip() - self.reply_headers = response.headers + self.start_time = time.time() + self.content, self.reply_headers = _download_from_dirport(self.download_url, self.compression, timeout) self.runtime = time.time() - self.start_time log.trace("Descriptors retrieved from '%s' in %0.2fs" % (self.download_url, self.runtime)) except: diff --git a/test/unit/descriptor/remote.py b/test/unit/descriptor/remote.py index deb10060..b2cfa7c1 100644 --- a/test/unit/descriptor/remote.py +++ b/test/unit/descriptor/remote.py @@ -282,8 +282,7 @@ class TestDescriptorDownloader(unittest.TestCase): validate = True, )
- expeced_url = 'http://128.31.0.39:9131' + TEST_RESOURCE - self.assertEqual(expeced_url, query._pick_url()) + self.assertEqual(stem.DirPort('128.31.0.39', 9131), query._pick_endpoint())
descriptors = list(query) self.assertEqual(1, len(descriptors)) diff --git a/test/unit/endpoint.py b/test/unit/endpoint.py index ad710a22..397323a7 100644 --- a/test/unit/endpoint.py +++ b/test/unit/endpoint.py @@ -30,3 +30,17 @@ class TestEndpoint(unittest.TestCase): self.assertRaises(ValueError, stem.DirPort, 'hello', 80) self.assertRaises(ValueError, stem.DirPort, -5, 80) self.assertRaises(ValueError, stem.DirPort, None, 80) + + def test_equality(self): + self.assertTrue(stem.ORPort('12.34.56.78', 80) == stem.ORPort('12.34.56.78', 80)) + self.assertTrue(stem.ORPort('12.34.56.78', 80, [1, 2, 3]) == stem.ORPort('12.34.56.78', 80, [1, 2, 3])) + self.assertFalse(stem.ORPort('12.34.56.78', 80) == stem.ORPort('12.34.56.88', 80)) + self.assertFalse(stem.ORPort('12.34.56.78', 80) == stem.ORPort('12.34.56.78', 443)) + self.assertFalse(stem.ORPort('12.34.56.78', 80, [2, 3]) == stem.ORPort('12.34.56.78', 80, [1, 2, 3])) + + self.assertTrue(stem.DirPort('12.34.56.78', 80) == stem.DirPort('12.34.56.78', 80)) + self.assertFalse(stem.DirPort('12.34.56.78', 80) == stem.DirPort('12.34.56.88', 80)) + self.assertFalse(stem.DirPort('12.34.56.78', 80) == stem.DirPort('12.34.56.78', 443)) + + self.assertFalse(stem.ORPort('12.34.56.78', 80) == stem.DirPort('12.34.56.78', 80)) + self.assertFalse(stem.DirPort('12.34.56.78', 80) == stem.ORPort('12.34.56.78', 80))
tor-commits@lists.torproject.org