commit 448060eabed41b3bad22cc5b0a5b5494f2793816
Author: Damian Johnson <atagar@torproject.org>
Date: Mon Jun 15 16:23:40 2020 -0700
Rewrite descriptor downloading
Using run_in_executor() here has a couple of issues...
1. Executor threads aren't cleaned up. Running our tests with the '--all' argument concludes with...
Threads lingering after test run: <_MainThread(MainThread, started 140249831520000)> <Thread(ThreadPoolExecutor-0_0, started daemon 140249689769728)> <Thread(ThreadPoolExecutor-0_1, started daemon 140249606911744)> <Thread(ThreadPoolExecutor-0_2, started daemon 140249586980608)> <Thread(ThreadPoolExecutor-0_3, started daemon 140249578587904)> <Thread(ThreadPoolExecutor-0_4, started daemon 140249570195200)> ...
2. Asyncio has its own I/O. Wrapping urllib within an executor is easy, but it forfeits asyncio benefits such as imposing timeouts through asyncio.wait_for().
Urllib marshals and parses HTTP headers, but we already do that for ORPort requests, so using a raw asyncio connection actually lets us deduplicate some code.
Deduplication greatly simplifies testing, since we can mock _download_from() rather than the raw connection. However, I couldn't adapt our timeout test: asyncio's wait_for() works in practice, but I couldn't get it to work when mocked.
---
stem/descriptor/remote.py | 229 ++++++++++++++++------------------------- 
test/unit/descriptor/remote.py | 183 ++++++++------------------------ 
2 files changed, 133 insertions(+), 279 deletions(-)
diff --git a/stem/descriptor/remote.py b/stem/descriptor/remote.py index c23ab7a9..f1ce79db 100644 --- a/stem/descriptor/remote.py +++ b/stem/descriptor/remote.py @@ -84,14 +84,11 @@ content. For example... """
import asyncio -import functools import io import random -import socket import sys import threading import time -import urllib.request
import stem import stem.client @@ -313,7 +310,7 @@ class AsyncQuery(object): :var bool is_done: flag that indicates if our request has finished
:var float start_time: unix timestamp when we first started running - :var http.client.HTTPMessage reply_headers: headers provided in the response, + :var dict reply_headers: headers provided in the response, **None** if we haven't yet made our request :var float runtime: time our query took, this is **None** if it's not yet finished @@ -330,13 +327,9 @@ class AsyncQuery(object): :var float timeout: duration before we'll time out our request :var str download_url: last url used to download the descriptor, this is unset until we've actually made a download attempt - - :param start: start making the request when constructed (default is **True**) - :param block: only return after the request has been completed, this is - the same as running **query.run(True)** (default is **False**) """
- def __init__(self, resource: str, descriptor_type: Optional[str] = None, endpoints: Optional[Sequence[stem.Endpoint]] = None, compression: Union[stem.descriptor._Compression, Sequence[stem.descriptor._Compression]] = (Compression.GZIP,), retries: int = 2, fall_back_to_authority: bool = False, timeout: Optional[float] = None, start: bool = True, block: bool = False, validate: bool = False, document_handler: stem.descriptor.DocumentHandler = stem.descriptor.DocumentHandler.ENTRIES, **kwargs: Any) -> None: + def __init__(self, resource: str, descriptor_type: Optional[str] = None, endpoints: Optional[Sequence[stem.Endpoint]] = None, compression: Union[stem.descriptor._Compression, Sequence[stem.descriptor._Compression]] = (Compression.GZIP,), retries: int = 2, fall_back_to_authority: bool = False, timeout: Optional[float] = None, validate: bool = False, document_handler: stem.descriptor.DocumentHandler = stem.descriptor.DocumentHandler.ENTRIES, **kwargs: Any) -> None: if not resource.startswith('/'): raise ValueError("Resources should start with a '/': %s" % resource)
@@ -395,22 +388,15 @@ class AsyncQuery(object): self._downloader_task = None # type: Optional[asyncio.Task] self._downloader_lock = threading.RLock()
- self._asyncio_loop = asyncio.get_event_loop() - - if start: - self.start() - - if block: - self.run(True) - - def start(self) -> None: + async def start(self) -> None: """ Starts downloading the scriptors if we haven't started already. """
with self._downloader_lock: if self._downloader_task is None: - self._downloader_task = self._asyncio_loop.create_task(self._download_descriptors(self.retries, self.timeout)) + loop = asyncio.get_running_loop() + self._downloader_task = loop.create_task(self._download_descriptors(self.retries, self.timeout))
async def run(self, suppress: bool = False) -> List['stem.descriptor.Descriptor']: """ @@ -434,7 +420,7 @@ class AsyncQuery(object):
async def _run(self, suppress: bool) -> AsyncIterator[stem.descriptor.Descriptor]: with self._downloader_lock: - self.start() + await self.start() await self._downloader_task
if self.error: @@ -491,36 +477,71 @@ class AsyncQuery(object): return random.choice(self.endpoints)
async def _download_descriptors(self, retries: int, timeout: Optional[float]) -> None: - try: - self.start_time = time.time() + self.start_time = time.time() + + retries = self.retries + time_remaining = self.timeout + + while True: endpoint = self._pick_endpoint(use_authority = retries == 0 and self.fall_back_to_authority)
if isinstance(endpoint, stem.ORPort): downloaded_from = 'ORPort %s:%s (resource %s)' % (endpoint.address, endpoint.port, self.resource) - self.content, self.reply_headers = await _download_from_orport(endpoint, self.compression, self.resource) elif isinstance(endpoint, stem.DirPort): - self.download_url = 'http://%s:%i/%s' % (endpoint.address, endpoint.port, self.resource.lstrip('/')) - downloaded_from = self.download_url - self.content, self.reply_headers = await _download_from_dirport(self.download_url, self.compression, timeout) + downloaded_from = 'http://%s:%i/%s' % (endpoint.address, endpoint.port, self.resource.lstrip('/')) + self.download_url = downloaded_from else: raise ValueError("BUG: endpoints can only be ORPorts or DirPorts, '%s' was a %s" % (endpoint, type(endpoint).__name__))
- self.runtime = time.time() - self.start_time - log.trace('Descriptors retrieved from %s in %0.2fs' % (downloaded_from, self.runtime)) - except: - exc = sys.exc_info()[1] + try: + response = await asyncio.wait_for(self._download_from(endpoint), time_remaining) + self.content, self.reply_headers = _http_body_and_headers(response) + + self.is_done = True + self.runtime = time.time() - self.start_time + + log.trace('Descriptors retrieved from %s in %0.2fs' % (downloaded_from, self.runtime)) + return + except asyncio.TimeoutError as exc: + self.is_done = True + self.error = stem.DownloadTimeout(downloaded_from, exc, sys.exc_info()[2], self.timeout) + return + except: + exception = sys.exc_info()[1] + retries -= 1 + + if time_remaining is not None: + time_remaining -= time.time() - self.start_time + + if retries > 0: + log.debug("Failed to download descriptors from '%s' (%i retries remaining): %s" % (downloaded_from, retries, exception)) + else: + log.debug("Failed to download descriptors from '%s': %s" % (self.download_url, exception)) + + self.is_done = True + self.error = exception + return
- if timeout is not None: - timeout -= time.time() - self.start_time + async def _download_from(self, endpoint: stem.Endpoint) -> bytes: + http_request = '\r\n'.join(( + 'GET %s HTTP/1.0' % self.resource, + 'Accept-Encoding: %s' % ', '.join(map(lambda c: c.encoding, self.compression)), + 'User-Agent: %s' % stem.USER_AGENT, + )) + '\r\n\r\n'
- if retries > 0 and (timeout is None or timeout > 0): - log.debug("Unable to download descriptors from '%s' (%i retries remaining): %s" % (self.download_url, retries, exc)) - return await self._download_descriptors(retries - 1, timeout) - else: - log.debug("Unable to download descriptors from '%s': %s" % (self.download_url, exc)) - self.error = exc - finally: - self.is_done = True + if isinstance(endpoint, stem.ORPort): + link_protocols = endpoint.link_protocols if endpoint.link_protocols else [3] + + async with await stem.client.Relay.connect(endpoint.address, endpoint.port, link_protocols) as relay: + async with await relay.create_circuit() as circ: + return await circ.directory(http_request, stream_id = 1) + elif isinstance(endpoint, stem.DirPort): + reader, writer = await asyncio.open_connection(endpoint.address, endpoint.port) + writer.write(str_tools._to_bytes(http_request)) + + return await reader.read() + else: + raise ValueError("BUG: endpoints can only be ORPorts or DirPorts, '%s' was a %s" % (endpoint, type(endpoint).__name__))
class Query(stem.util.AsyncClassWrapper): @@ -663,8 +684,8 @@ class Query(stem.util.AsyncClassWrapper): """
def __init__(self, resource: str, descriptor_type: Optional[str] = None, endpoints: Optional[Sequence[stem.Endpoint]] = None, compression: Union[stem.descriptor._Compression, Sequence[stem.descriptor._Compression]] = (Compression.GZIP,), retries: int = 2, fall_back_to_authority: bool = False, timeout: Optional[float] = None, start: bool = True, block: bool = False, validate: bool = False, document_handler: stem.descriptor.DocumentHandler = stem.descriptor.DocumentHandler.ENTRIES, **kwargs: Any) -> None: - self._loop = asyncio.get_event_loop() - self._loop_thread = threading.Thread(target = self._loop.run_forever, name = 'asyncio') + self._loop = asyncio.new_event_loop() + self._loop_thread = threading.Thread(target = self._loop.run_forever, name = 'query asyncio') self._loop_thread.setDaemon(True) self._loop_thread.start()
@@ -677,19 +698,23 @@ class Query(stem.util.AsyncClassWrapper): retries, fall_back_to_authority, timeout, - start, - block, validate, document_handler, **kwargs, )
+ if start: + self.start() + + if block: + self.run(True) + def start(self) -> None: """ Starts downloading the scriptors if we haven't started already. """
- self._call_async_method_soon('start') + self._execute_async_method('start')
def run(self, suppress = False) -> List['stem.descriptor.Descriptor']: """ @@ -1146,10 +1171,9 @@ class DescriptorDownloader(object): return Query(resource, **args)
-async def _download_from_orport(endpoint: stem.ORPort, compression: Sequence[stem.descriptor._Compression], resource: str) -> Tuple[bytes, Dict[str, str]]: +def _http_body_and_headers(data: bytes) -> Tuple[bytes, Dict[str, str]]: """ - Downloads descriptors from the given orport. Payload is just like an http - response (headers and all)... + Parse the headers and decompressed body from a HTTP response, such as...
::
@@ -1164,112 +1188,41 @@ async def _download_from_orport(endpoint: stem.ORPort, compression: Sequence[ste identity-ed25519 ... rest of the descriptor content...
- :param endpoint: endpoint to download from - :param compression: compression methods for the request - :param resource: descriptor resource to download + :param data: HTTP response
- :returns: two value tuple of the form (data, reply_headers) + :returns: **tuple** with the decompressed data and headers
:raises: - * :class:`stem.ProtocolError` if not a valid descriptor response - * :class:`stem.SocketError` if unable to establish a connection - """ - - link_protocols = endpoint.link_protocols if endpoint.link_protocols else [3] - - async with await stem.client.Relay.connect(endpoint.address, endpoint.port, link_protocols) as relay: - async with await relay.create_circuit() as circ: - request = '\r\n'.join(( - 'GET %s HTTP/1.0' % resource, - 'Accept-Encoding: %s' % ', '.join(map(lambda c: c.encoding, compression)), - 'User-Agent: %s' % stem.USER_AGENT, - )) + '\r\n\r\n' - - response = await circ.directory(request, stream_id = 1) - first_line, data = response.split(b'\r\n', 1) - header_data, body_data = data.split(b'\r\n\r\n', 1) - - if not first_line.startswith(b'HTTP/1.0 2'): - raise stem.ProtocolError("Response should begin with HTTP success, but was '%s'" % str_tools._to_unicode(first_line)) - - headers = {} - - for line in str_tools._to_unicode(header_data).splitlines(): - if ': ' not in line: - raise stem.ProtocolError("'%s' is not a HTTP header:\n\n%s" % (line, header_data.decode('utf-8'))) - - key, value = line.split(': ', 1) - headers[key] = value - - return _decompress(body_data, headers.get('Content-Encoding')), headers - - -async def _download_from_dirport(url: str, compression: Sequence[stem.descriptor._Compression], timeout: Optional[float]) -> Tuple[bytes, Dict[str, str]]: - """ - Downloads descriptors from the given url. - - :param url: dirport url from which to download from - :param compression: compression methods for the request - :param timeout: duration before we'll time out our request - - :returns: two value tuple of the form (data, reply_headers) - - :raises: - * :class:`~stem.DownloadTimeout` if our request timed out - * :class:`~stem.DownloadFailed` if our request fails - """ - - # TODO: use an asyncronous solution for the HTTP request. 
- request = urllib.request.Request( - url, - headers = { - 'Accept-Encoding': ', '.join(map(lambda c: c.encoding, compression)), - 'User-Agent': stem.USER_AGENT, - } - ) - get_response = functools.partial(urllib.request.urlopen, request, timeout = timeout) - - loop = asyncio.get_event_loop() - try: - response = await loop.run_in_executor(None, get_response) - except socket.timeout as exc: - raise stem.DownloadTimeout(url, exc, sys.exc_info()[2], timeout) - except: - exception, stacktrace = sys.exc_info()[1:3] - raise stem.DownloadFailed(url, exception, stacktrace) - - return _decompress(response.read(), response.headers.get('Content-Encoding')), response.headers - - -def _decompress(data: bytes, encoding: str) -> bytes: + * **stem.ProtocolError** if response was unsuccessful or malformed + * **ValueError** if encoding is unrecognized + * **ImportError** if missing the decompression module """ - Decompresses descriptor data.
- Tor doesn't include compression headers. As such when using gzip we - need to include '32' for automatic header detection... + first_line, data = data.split(b'\r\n', 1) + header_data, body_data = data.split(b'\r\n\r\n', 1)
- https://stackoverflow.com/questions/3122145/zlib-error-error-3-while-decompr... + if not first_line.startswith(b'HTTP/1.0 2'): + raise stem.ProtocolError("Response should begin with HTTP success, but was '%s'" % str_tools._to_unicode(first_line))
- ... and with zstd we need to use the streaming API. + headers = {}
- :param data: data we received - :param encoding: 'Content-Encoding' header of the response + for line in str_tools._to_unicode(header_data).splitlines(): + if ': ' not in line: + raise stem.ProtocolError("'%s' is not a HTTP header:\n\n%s" % (line, header_data.decode('utf-8')))
- :returns: **bytes** with the decompressed data + key, value = line.split(': ', 1) + headers[key] = value
- :raises: - * **ValueError** if encoding is unrecognized - * **ImportError** if missing the decompression module - """ + encoding = headers.get('Content-Encoding')
if encoding == 'deflate': - return stem.descriptor.Compression.GZIP.decompress(data) + return stem.descriptor.Compression.GZIP.decompress(body_data), headers
for compression in stem.descriptor.Compression: if encoding == compression.encoding: - return compression.decompress(data) + return compression.decompress(body_data), headers
- raise ValueError("'%s' isn't a recognized type of encoding" % encoding) + raise ValueError("'%s' is an unrecognized encoding" % encoding)
def _guess_descriptor_type(resource: str) -> str: diff --git a/test/unit/descriptor/remote.py b/test/unit/descriptor/remote.py index 33ee57fb..797bc8a3 100644 --- a/test/unit/descriptor/remote.py +++ b/test/unit/descriptor/remote.py @@ -2,9 +2,6 @@ Unit tests for stem.descriptor.remote. """
-import http.client -import socket -import time import unittest
import stem @@ -67,47 +64,13 @@ HEADER = '\r\n'.join([ ])
-def _orport_mock(data, encoding = 'identity', response_code_header = None): +def mock_download(descriptor, encoding = 'identity', response_code_header = None): if response_code_header is None: response_code_header = b'HTTP/1.0 200 OK\r\n'
- data = response_code_header + stem.util.str_tools._to_bytes(HEADER % encoding) + b'\r\n\r\n' + data - cells = [] + data = response_code_header + stem.util.str_tools._to_bytes(HEADER % encoding) + b'\r\n\r\n' + descriptor
- for hunk in [data[i:i + 50] for i in range(0, len(data), 50)]: - cell = Mock() - cell.data = hunk - cells.append(cell) - - class AsyncMock(Mock): - async def __aenter__(self): - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - return - - circ_mock = AsyncMock() - circ_mock.directory.side_effect = coro_func_returning_value(data) - - relay_mock = AsyncMock() - relay_mock.create_circuit.side_effect = coro_func_returning_value(circ_mock) - - return coro_func_returning_value(relay_mock) - - -def _dirport_mock(data, encoding = 'identity'): - dirport_mock = Mock() - dirport_mock().read.return_value = data - - headers = http.client.HTTPMessage() - - for line in HEADER.splitlines(): - key, value = line.split(': ', 1) - headers.add_header(key, encoding if key == 'Content-Encoding' else value) - - dirport_mock().headers = headers - - return dirport_mock + return patch('stem.descriptor.remote.AsyncQuery._download_from', Mock(side_effect = coro_func_returning_value(data)))
class TestDescriptorDownloader(unittest.TestCase): @@ -115,10 +78,10 @@ class TestDescriptorDownloader(unittest.TestCase): # prevent our mocks from impacting other tests stem.descriptor.remote.SINGLETON_DOWNLOADER = None
- @patch('stem.client.Relay.connect', _orport_mock(TEST_DESCRIPTOR)) - def test_using_orport(self): + @mock_download(TEST_DESCRIPTOR) + def test_download(self): """ - Download a descriptor through the ORPort. + Simply download and parse a descriptor. """
reply = stem.descriptor.remote.their_server_descriptor( @@ -128,10 +91,16 @@ class TestDescriptorDownloader(unittest.TestCase): )
self.assertEqual(1, len(list(reply))) - self.assertEqual('moria1', list(reply)[0].nickname) self.assertEqual(5, len(reply.reply_headers))
- def test_orport_response_code_headers(self): + desc = list(reply)[0] + + self.assertEqual('moria1', desc.nickname) + self.assertEqual('128.31.0.34', desc.address) + self.assertEqual('9695DFC35FFEB861329B9F1AB04C46397020CE31', desc.fingerprint) + self.assertEqual(TEST_DESCRIPTOR, desc.get_bytes()) + + def test_response_header_code(self): """ When successful Tor provides a '200 OK' status, but we should accept other 2xx response codes, reason text, and recognize HTTP errors. @@ -144,14 +113,14 @@ class TestDescriptorDownloader(unittest.TestCase): )
for header in response_code_headers: - with patch('stem.client.Relay.connect', _orport_mock(TEST_DESCRIPTOR, response_code_header = header)): + with mock_download(TEST_DESCRIPTOR, response_code_header = header): stem.descriptor.remote.their_server_descriptor( endpoints = [stem.ORPort('12.34.56.78', 1100)], validate = True, skip_crypto_validation = not test.require.CRYPTOGRAPHY_AVAILABLE, ).run()
- with patch('stem.client.Relay.connect', _orport_mock(TEST_DESCRIPTOR, response_code_header = b'HTTP/1.0 500 Kaboom\r\n')): + with mock_download(TEST_DESCRIPTOR, response_code_header = b'HTTP/1.0 500 Kaboom\r\n'): request = stem.descriptor.remote.their_server_descriptor( endpoints = [stem.ORPort('12.34.56.78', 1100)], validate = True, @@ -160,28 +129,32 @@ class TestDescriptorDownloader(unittest.TestCase):
self.assertRaisesRegexp(stem.ProtocolError, "^Response should begin with HTTP success, but was 'HTTP/1.0 500 Kaboom'", request.run)
- @patch('urllib.request.urlopen', _dirport_mock(TEST_DESCRIPTOR)) - def test_using_dirport(self): - """ - Download a descriptor through the DirPort. - """ + @mock_download(TEST_DESCRIPTOR) + def test_reply_header_data(self): + query = stem.descriptor.remote.get_server_descriptors('9695DFC35FFEB861329B9F1AB04C46397020CE31', start = False) + self.assertEqual(None, query.reply_headers) # initially we don't have a reply + query.run()
- reply = stem.descriptor.remote.their_server_descriptor( - endpoints = [stem.DirPort('12.34.56.78', 1100)], - validate = True, - skip_crypto_validation = not test.require.CRYPTOGRAPHY_AVAILABLE, - ) + self.assertEqual('Fri, 13 Apr 2018 16:35:50 GMT', query.reply_headers.get('Date')) + self.assertEqual('application/octet-stream', query.reply_headers.get('Content-Type')) + self.assertEqual('97.103.17.56', query.reply_headers.get('X-Your-Address-Is')) + self.assertEqual('no-cache', query.reply_headers.get('Pragma')) + self.assertEqual('identity', query.reply_headers.get('Content-Encoding'))
- self.assertEqual(1, len(list(reply))) - self.assertEqual('moria1', list(reply)[0].nickname) - self.assertEqual(5, len(reply.reply_headers)) + # request a header that isn't present + self.assertEqual(None, query.reply_headers.get('no-such-header')) + self.assertEqual('default', query.reply_headers.get('no-such-header', 'default')) + + descriptors = list(query) + self.assertEqual(1, len(descriptors)) + self.assertEqual('moria1', descriptors[0].nickname)
def test_gzip_url_override(self): query = stem.descriptor.remote.Query(TEST_RESOURCE + '.z', compression = Compression.PLAINTEXT, start = False) self.assertEqual([stem.descriptor.Compression.GZIP], query.compression) self.assertEqual(TEST_RESOURCE, query.resource)
- @patch('urllib.request.urlopen', _dirport_mock(read_resource('compressed_identity'), encoding = 'identity')) + @mock_download(read_resource('compressed_identity'), encoding = 'identity') def test_compression_plaintext(self): """ Download a plaintext descriptor. @@ -197,7 +170,7 @@ class TestDescriptorDownloader(unittest.TestCase): self.assertEqual(1, len(descriptors)) self.assertEqual('moria1', descriptors[0].nickname)
- @patch('urllib.request.urlopen', _dirport_mock(read_resource('compressed_gzip'), encoding = 'gzip')) + @mock_download(read_resource('compressed_gzip'), encoding = 'gzip') def test_compression_gzip(self): """ Download a gip compressed descriptor. @@ -213,7 +186,7 @@ class TestDescriptorDownloader(unittest.TestCase): self.assertEqual(1, len(descriptors)) self.assertEqual('moria1', descriptors[0].nickname)
- @patch('urllib.request.urlopen', _dirport_mock(read_resource('compressed_zstd'), encoding = 'x-zstd')) + @mock_download(read_resource('compressed_zstd'), encoding = 'x-zstd') def test_compression_zstd(self): """ Download a zstd compressed descriptor. @@ -231,7 +204,7 @@ class TestDescriptorDownloader(unittest.TestCase): self.assertEqual(1, len(descriptors)) self.assertEqual('moria1', descriptors[0].nickname)
- @patch('urllib.request.urlopen', _dirport_mock(read_resource('compressed_lzma'), encoding = 'x-tor-lzma')) + @mock_download(read_resource('compressed_lzma'), encoding = 'x-tor-lzma') def test_compression_lzma(self): """ Download a lzma compressed descriptor. @@ -249,8 +222,8 @@ class TestDescriptorDownloader(unittest.TestCase): self.assertEqual(1, len(descriptors)) self.assertEqual('moria1', descriptors[0].nickname)
- @patch('urllib.request.urlopen') - def test_each_getter(self, dirport_mock): + @mock_download(TEST_DESCRIPTOR) + def test_each_getter(self): """ Surface level exercising of each getter method for downloading descriptors. """ @@ -266,57 +239,8 @@ class TestDescriptorDownloader(unittest.TestCase): downloader.get_bandwidth_file() downloader.get_detached_signatures()
- @patch('urllib.request.urlopen', _dirport_mock(TEST_DESCRIPTOR)) - def test_reply_headers(self): - query = stem.descriptor.remote.get_server_descriptors('9695DFC35FFEB861329B9F1AB04C46397020CE31', start = False) - self.assertEqual(None, query.reply_headers) # initially we don't have a reply - query.run() - - self.assertEqual('Fri, 13 Apr 2018 16:35:50 GMT', query.reply_headers.get('date')) - self.assertEqual('application/octet-stream', query.reply_headers.get('content-type')) - self.assertEqual('97.103.17.56', query.reply_headers.get('x-your-address-is')) - self.assertEqual('no-cache', query.reply_headers.get('pragma')) - self.assertEqual('identity', query.reply_headers.get('content-encoding')) - - # getting headers should be case insensitive - self.assertEqual('identity', query.reply_headers.get('CoNtEnT-ENCODING')) - - # request a header that isn't present - self.assertEqual(None, query.reply_headers.get('no-such-header')) - self.assertEqual('default', query.reply_headers.get('no-such-header', 'default')) - - descriptors = list(query) - self.assertEqual(1, len(descriptors)) - self.assertEqual('moria1', descriptors[0].nickname) - - @patch('urllib.request.urlopen', _dirport_mock(TEST_DESCRIPTOR)) - def test_query_download(self): - """ - Check Query functionality when we successfully download a descriptor. 
- """ - - query = stem.descriptor.remote.Query( - TEST_RESOURCE, - 'server-descriptor 1.0', - endpoints = [stem.DirPort('128.31.0.39', 9131)], - compression = Compression.PLAINTEXT, - validate = True, - skip_crypto_validation = not test.require.CRYPTOGRAPHY_AVAILABLE, - ) - - self.assertEqual(stem.DirPort('128.31.0.39', 9131), query._wrapped_instance._pick_endpoint()) - - descriptors = list(query) - self.assertEqual(1, len(descriptors)) - desc = descriptors[0] - - self.assertEqual('moria1', desc.nickname) - self.assertEqual('128.31.0.34', desc.address) - self.assertEqual('9695DFC35FFEB861329B9F1AB04C46397020CE31', desc.fingerprint) - self.assertEqual(TEST_DESCRIPTOR, desc.get_bytes()) - - @patch('urllib.request.urlopen', _dirport_mock(b'some malformed stuff')) - def test_query_with_malformed_content(self): + @mock_download(b'some malformed stuff') + def test_malformed_content(self): """ Query with malformed descriptor content. """ @@ -340,29 +264,6 @@ class TestDescriptorDownloader(unittest.TestCase):
self.assertRaises(ValueError, query.run)
- @patch('urllib.request.urlopen') - def test_query_with_timeout(self, dirport_mock): - def urlopen_call(*args, **kwargs): - time.sleep(0.06) - raise socket.timeout('connection timed out') - - dirport_mock.side_effect = urlopen_call - - query = stem.descriptor.remote.Query( - TEST_RESOURCE, - 'server-descriptor 1.0', - endpoints = [stem.DirPort('128.31.0.39', 9131)], - fall_back_to_authority = False, - timeout = 0.1, - validate = True, - ) - - # After two requests we'll have reached our total permissable timeout. - # It would be nice to check that we don't make a third, but this - # assertion has proved unreliable so only checking for the exception. - - self.assertRaises(stem.DownloadTimeout, query.run) - def test_query_with_invalid_endpoints(self): invalid_endpoints = { 'hello': "'h' is a str.", @@ -375,7 +276,7 @@ class TestDescriptorDownloader(unittest.TestCase): expected_error = 'Endpoints must be an stem.ORPort or stem.DirPort. ' + error_suffix self.assertRaisesWith(ValueError, expected_error, stem.descriptor.remote.Query, TEST_RESOURCE, 'server-descriptor 1.0', endpoints = endpoints)
- @patch('urllib.request.urlopen', _dirport_mock(TEST_DESCRIPTOR)) + @mock_download(TEST_DESCRIPTOR) def test_can_iterate_multiple_times(self): query = stem.descriptor.remote.Query( TEST_RESOURCE,