commit 5a875ed329ddb119c2b6787e051571fb7323f622 Author: Damian Johnson atagar@torproject.org Date: Sat Apr 28 16:47:42 2018 -0700
Compression support for ORPort descriptor downloads
Compression works the same regardless of whether we download from an ORPort or DirPort. Also including a couple of python3 compatibility fixes for circuit construction. --- stem/client/__init__.py | 4 +- stem/descriptor/remote.py | 88 +++++++++++++++++++++++++++--------------- test/unit/descriptor/remote.py | 10 ++--- 3 files changed, 62 insertions(+), 40 deletions(-)
diff --git a/stem/client/__init__.py b/stem/client/__init__.py index aa4aa274..6e25f748 100644 --- a/stem/client/__init__.py +++ b/stem/client/__init__.py @@ -168,7 +168,7 @@ class Relay(object): if not created_fast_cells: raise ValueError('We should get a CREATED_FAST response from a CREATE_FAST request')
- created_fast_cell = created_fast_cells[0] + created_fast_cell = list(created_fast_cells)[0] kdf = KDF.from_value(create_fast_cell.key_material + created_fast_cell.key_material)
if created_fast_cell.derivative_key != kdf.key_hash: @@ -211,7 +211,7 @@ class Circuit(object): from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from cryptography.hazmat.backends import default_backend
- ctr = modes.CTR(ZERO * (algorithms.AES.block_size / 8)) + ctr = modes.CTR(ZERO * (algorithms.AES.block_size // 8))
self.relay = relay self.id = circ_id diff --git a/stem/descriptor/remote.py b/stem/descriptor/remote.py index e745ce20..d6832eaa 100644 --- a/stem/descriptor/remote.py +++ b/stem/descriptor/remote.py @@ -254,7 +254,7 @@ def get_consensus(authority_v3ident = None, microdescriptor = False, **query_arg return get_instance().get_consensus(authority_v3ident, microdescriptor, **query_args)
-def _download_from_orport(endpoint, resource): +def _download_from_orport(endpoint, compression, resource): """ Downloads descriptors from the given orport. Payload is just like an http response (headers and all)... @@ -273,6 +273,7 @@ def _download_from_orport(endpoint, resource): ... rest of the descriptor content...
:param stem.ORPort endpoint: endpoint to download from + :param list compression: compression methods for the request :param str resource: descriptor resource to download
:returns: two value tuple of the form (data, reply_headers) @@ -286,26 +287,30 @@ def _download_from_orport(endpoint, resource):
with stem.client.Relay.connect(endpoint.address, endpoint.port, link_protocols) as relay: with relay.create_circuit() as circ: + request = '\r\n'.join(( + 'GET %s HTTP/1.0' % resource, + 'Accept-Encoding: %s' % ', '.join(compression), + 'User-Agent: Stem/%s' % stem.__version__, + )) + '\r\n\r\n' + circ.send('RELAY_BEGIN_DIR', stream_id = 1) - lines = b''.join([cell.data for cell in circ.send('RELAY_DATA', 'GET %s HTTP/1.0\r\n\r\n' % resource, stream_id = 1)]).splitlines() - first_line = lines.pop(0) + response = b''.join([cell.data for cell in circ.send('RELAY_DATA', request, stream_id = 1)]) + first_line, data = response.split(b'\r\n', 1) + header_data, data = data.split(b'\r\n\r\n', 1)
- if first_line != 'HTTP/1.0 200 OK': + if first_line != b'HTTP/1.0 200 OK': raise stem.ProtocolError("Response should begin with HTTP success, but was '%s'" % first_line)
headers = {} - next_line = lines.pop(0)
- while next_line: - if ': ' not in next_line: - raise stem.ProtocolError("'%s' is not a HTTP header:\n\n%s" % next_line) + for line in str_tools._to_unicode(header_data).splitlines(): + if ': ' not in line: + raise stem.ProtocolError("'%s' is not a HTTP header:\n\n%s" % line)
- key, value = next_line.split(': ', 1) + key, value = line.split(': ', 1) headers[key] = value
- next_line = lines.pop(0) - - return '\n'.join(lines), headers + return _decompress(data, headers.get('Content-Encoding')), headers
def _download_from_dirport(url, compression, timeout): @@ -334,29 +339,49 @@ def _download_from_dirport(url, compression, timeout): timeout = timeout, )
- data = response.read() - encoding = response.headers.get('Content-Encoding') + return _decompress(response.read(), response.headers.get('Content-Encoding')), response.headers
- # Tor doesn't include compression headers. As such when using gzip we - # need to include '32' for automatic header detection... - # - # https://stackoverflow.com/questions/3122145/zlib-error-error-3-while-decompr... - # - # ... and with zstd we need to use the streaming API.
- if encoding in (Compression.GZIP, 'deflate'): - data = zlib.decompress(data, zlib.MAX_WBITS | 32) - elif encoding == Compression.ZSTD and ZSTD_SUPPORTED: +def _decompress(data, encoding): + """ + Decompresses descriptor data. + + Tor doesn't include compression headers. As such when using gzip we + need to include '32' for automatic header detection... + + https://stackoverflow.com/questions/3122145/zlib-error-error-3-while-decompr... + + ... and with zstd we need to use the streaming API. + + :param bytes data: data we received + :param str encoding: 'Content-Encoding' header of the response + + :raises: + * **ValueError** if encoding is unrecognized + * **ImportError** if missing the decompression module + """ + + if encoding == Compression.PLAINTEXT: + return data.strip() + elif encoding in (Compression.GZIP, 'deflate'): + return zlib.decompress(data, zlib.MAX_WBITS | 32).strip() + elif encoding == Compression.ZSTD: + if not ZSTD_SUPPORTED: + raise ImportError('Decompressing zstd data requires https://pypi.python.org/pypi/zstandard') + output_buffer = io.BytesIO()
with zstd.ZstdDecompressor().write_to(output_buffer) as decompressor: decompressor.write(data)
- data = output_buffer.getvalue() - elif encoding == Compression.LZMA and LZMA_SUPPORTED: - data = lzma.decompress(data) + return output_buffer.getvalue().strip() + elif encoding == Compression.LZMA: + if not LZMA_SUPPORTED: + raise ImportError('Decompressing lzma data requires https://docs.python.org/3/library/lzma.html')
- return data.strip(), response.headers + return lzma.decompress(data).strip() + else: + raise ValueError("'%s' isn't a recognized type of encoding" % encoding)
def _guess_descriptor_type(resource): @@ -476,6 +501,9 @@ class Query(object):
:var list endpoints: :class:`~stem.DirPort` or :class:`~stem.ORPort` of the authority or mirror we're querying, this uses authorities if undefined + :var list compression: list of :data:`stem.descriptor.remote.Compression` + we're willing to accept, when none are mutually supported downloads fall + back to Compression.PLAINTEXT :var int retries: number of times to attempt the request if downloading it fails :var bool fall_back_to_authority: when retrying request issues the last @@ -500,11 +528,7 @@ class Query(object): Following are only applicable when downloading from a :class:`~stem.DirPort`...
- :var list compression: list of :data:`stem.descriptor.remote.Compression` - we're willing to accept, when none are mutually supported downloads fall - back to Compression.PLAINTEXT :var float timeout: duration before we'll time out our request - :var str download_url: last url used to download the descriptor, this is unset until we've actually made a download attempt
@@ -683,7 +707,7 @@ class Query(object): endpoint = self._pick_endpoint(use_authority = retries == 0 and self.fall_back_to_authority)
if isinstance(endpoint, stem.ORPort): - self.content, self.reply_headers = _download_from_orport(endpoint, self.resource) + self.content, self.reply_headers = _download_from_orport(endpoint, self.compression, self.resource) elif isinstance(endpoint, stem.DirPort): self.download_url = 'http://%s:%i/%s' % (endpoint.address, endpoint.port, self.resource.lstrip('/')) self.content, self.reply_headers = _download_from_dirport(self.download_url, self.compression, timeout) diff --git a/test/unit/descriptor/remote.py b/test/unit/descriptor/remote.py index 040d4c6c..753681e0 100644 --- a/test/unit/descriptor/remote.py +++ b/test/unit/descriptor/remote.py @@ -12,6 +12,7 @@ import unittest import stem.descriptor.remote import stem.prereq import stem.util.conf +import stem.util.str_tools
from stem.descriptor.remote import Compression from test.unit.descriptor import read_resource @@ -126,10 +127,9 @@ HEADER = '\r\n'.join([ 'Content-Encoding: %s', ])
-ORPORT_DESCRIPTOR = 'HTTP/1.0 200 OK\n' + HEADER + '\n\n' + TEST_DESCRIPTOR
- -def _orport_mock(data): +def _orport_mock(data, encoding = 'identity'): + data = b'HTTP/1.0 200 OK\r\n' + stem.util.str_tools._to_bytes(HEADER % encoding) + b'\r\n\r\n' + data cells = []
for hunk in [data[i:i + 50] for i in range(0, len(data), 50)]: @@ -167,7 +167,7 @@ class TestDescriptorDownloader(unittest.TestCase): # prevent our mocks from impacting other tests stem.descriptor.remote.SINGLETON_DOWNLOADER = None
- @patch('stem.client.Relay.connect', _orport_mock(ORPORT_DESCRIPTOR)) + @patch('stem.client.Relay.connect', _orport_mock(TEST_DESCRIPTOR)) def test_using_orport(self): """ Download a descriptor through the ORPort. @@ -175,7 +175,6 @@ class TestDescriptorDownloader(unittest.TestCase):
reply = stem.descriptor.remote.their_server_descriptor( endpoints = [stem.ORPort('12.34.56.78', 1100)], - fall_back_to_authority = False, validate = True, )
@@ -191,7 +190,6 @@ class TestDescriptorDownloader(unittest.TestCase):
reply = stem.descriptor.remote.their_server_descriptor( endpoints = [stem.DirPort('12.34.56.78', 1100)], - fall_back_to_authority = False, validate = True, )