commit 076f89dfb4fd4a156ae32fde8c78c531385162e8 Author: Damian Johnson atagar@torproject.org Date: Tue Apr 14 18:43:20 2020 -0700
Fix mypy issues
Correcting the issues brought up by mypy. Some bug fixes and reconfiguring. Significant changes include...
* Refactored command-line parsing (stem/interpreter/arguments.py and test/arguments.py) to be a typing.NamedTuple derived class.
* Temporarily ignoring all mypy warnings related to our enum class. Python 3.x added its own enum that I'd like to swap us over to but that will be a separate project. --- run_tests.py | 6 +- stem/__init__.py | 12 +- stem/client/__init__.py | 46 ++++--- stem/client/cell.py | 52 +++---- stem/client/datatype.py | 91 +++++++------ stem/connection.py | 41 +++--- stem/control.py | 228 +++++++++++++++---------------- stem/descriptor/__init__.py | 231 ++++++++++++++++--------------- stem/descriptor/bandwidth_file.py | 34 +++-- stem/descriptor/certificate.py | 20 +-- stem/descriptor/collector.py | 82 +++++------ stem/descriptor/extrainfo_descriptor.py | 60 ++++---- stem/descriptor/hidden_service.py | 118 ++++++++-------- stem/descriptor/microdescriptor.py | 20 +-- stem/descriptor/networkstatus.py | 136 ++++++++++--------- stem/descriptor/remote.py | 48 +++---- stem/descriptor/router_status_entry.py | 61 +++++---- stem/descriptor/server_descriptor.py | 69 +++++----- stem/descriptor/tordnsel.py | 30 ++-- stem/directory.py | 32 ++--- stem/exit_policy.py | 60 ++++---- stem/interpreter/__init__.py | 12 +- stem/interpreter/arguments.py | 177 ++++++++++++------------ stem/interpreter/autocomplete.py | 11 +- stem/interpreter/commands.py | 28 ++-- stem/interpreter/help.py | 13 +- stem/manual.py | 37 ++--- stem/process.py | 6 +- stem/response/__init__.py | 104 ++++++++++---- stem/response/events.py | 234 ++++++++++++++++++++++++++++---- stem/response/getconf.py | 4 +- stem/response/getinfo.py | 12 +- stem/response/protocolinfo.py | 7 +- stem/socket.py | 55 +++++--- stem/util/__init__.py | 6 +- stem/util/conf.py | 26 ++-- stem/util/connection.py | 41 +++--- stem/util/enum.py | 31 +++-- stem/util/log.py | 3 +- stem/util/proc.py | 12 +- stem/util/str_tools.py | 32 +++-- stem/util/system.py | 48 ++++--- stem/util/term.py | 12 +- stem/util/test_tools.py | 85 ++++++------ stem/version.py | 11 +- test/arguments.py | 233 ++++++++++++++++--------------- test/mypy.ini | 6 + test/settings.cfg 
| 19 +++ test/task.py | 2 +- test/unit/client/address.py | 2 +- test/unit/control/controller.py | 6 +- test/unit/descriptor/bandwidth_file.py | 3 +- test/unit/interpreter/arguments.py | 32 ++--- test/unit/util/proc.py | 23 ++-- 54 files changed, 1608 insertions(+), 1202 deletions(-)
diff --git a/run_tests.py b/run_tests.py index 2ea07dab..fd46211f 100755 --- a/run_tests.py +++ b/run_tests.py @@ -194,7 +194,7 @@ def main(): test_config.load(os.environ['STEM_TEST_CONFIG'])
try: - args = test.arguments.parse(sys.argv[1:]) + args = test.arguments.Arguments.parse(sys.argv[1:]) test.task.TOR_VERSION.args = (args.tor_path,) test.output.SUPPRESS_STDOUT = args.quiet except ValueError as exc: @@ -202,7 +202,7 @@ def main(): sys.exit(1)
if args.print_help: - println(test.arguments.get_help()) + println(test.arguments.Arguments.get_help()) sys.exit() elif not args.run_unit and not args.run_integ: println('Nothing to run (for usage provide --help)\n') @@ -383,7 +383,7 @@ def _print_static_issues(static_check_issues): if static_check_issues: println('STATIC CHECKS', STATUS)
- for file_path in static_check_issues: + for file_path in sorted(static_check_issues): println('* %s' % file_path, STATUS)
# Make a dict of line numbers to its issues. This is so we can both sort diff --git a/stem/__init__.py b/stem/__init__.py index c0efab19..ce8d70a9 100644 --- a/stem/__init__.py +++ b/stem/__init__.py @@ -567,7 +567,7 @@ __all__ = [ ]
# Constant that we use by default for our User-Agent when downloading descriptors -stem.USER_AGENT = 'Stem/%s' % __version__ +USER_AGENT = 'Stem/%s' % __version__
# Constant to indicate an undefined argument default. Usually we'd use None for # this, but users will commonly provide None as the argument so need something @@ -612,7 +612,7 @@ class ORPort(Endpoint): :var list link_protocols: link protocol version we're willing to establish """
- def __init__(self, address: str, port: int, link_protocols: Optional[Sequence[int]] = None) -> None: + def __init__(self, address: str, port: int, link_protocols: Optional[Sequence['stem.client.datatype.LinkProtocol']] = None) -> None: # type: ignore super(ORPort, self).__init__(address, port) self.link_protocols = link_protocols
@@ -644,6 +644,8 @@ class OperationFailed(ControllerError): message """
+ # TODO: should the code be an int instead? + def __init__(self, code: Optional[str] = None, message: Optional[str] = None) -> None: super(ControllerError, self).__init__(message) self.code = code @@ -663,7 +665,7 @@ class CircuitExtensionFailed(UnsatisfiableRequest): :var stem.response.events.CircuitEvent circ: response notifying us of the failure """
- def __init__(self, message: str, circ: Optional['stem.response.events.CircuitEvent'] = None) -> None: + def __init__(self, message: str, circ: Optional['stem.response.events.CircuitEvent'] = None) -> None: # type: ignore super(CircuitExtensionFailed, self).__init__(message = message) self.circ = circ
@@ -775,7 +777,7 @@ class DownloadTimeout(DownloadFailed): .. versionadded:: 1.8.0 """
- def __init__(self, url: str, error: BaseException, stacktrace: Any, timeout: int): + def __init__(self, url: str, error: BaseException, stacktrace: Any, timeout: float): message = 'Failed to download from %s: %0.1f second timeout reached' % (url, timeout) super(DownloadTimeout, self).__init__(url, error, stacktrace, message)
@@ -919,7 +921,7 @@ StreamStatus = stem.util.enum.UppercaseEnum( )
# StreamClosureReason is a superset of RelayEndReason -StreamClosureReason = stem.util.enum.UppercaseEnum(*(RelayEndReason.keys() + [ +StreamClosureReason = stem.util.enum.UppercaseEnum(*(RelayEndReason.keys() + [ # type: ignore 'END', 'PRIVATE_ADDR', ])) diff --git a/stem/client/__init__.py b/stem/client/__init__.py index 2972985d..8726bdbf 100644 --- a/stem/client/__init__.py +++ b/stem/client/__init__.py @@ -34,11 +34,12 @@ import stem.socket import stem.util.connection
from types import TracebackType -from typing import Iterator, Optional, Tuple, Type +from typing import Dict, Iterator, List, Optional, Sequence, Type, Union
from stem.client.cell import ( CELL_TYPE_SIZE, FIXED_PAYLOAD_LEN, + PAYLOAD_LEN_SIZE, Cell, )
@@ -66,15 +67,15 @@ class Relay(object): :var int link_protocol: link protocol version we established """
- def __init__(self, orport: int, link_protocol: int) -> None: + def __init__(self, orport: stem.socket.RelaySocket, link_protocol: int) -> None: self.link_protocol = LinkProtocol(link_protocol) self._orport = orport self._orport_buffer = b'' # unread bytes self._orport_lock = threading.RLock() - self._circuits = {} + self._circuits = {} # type: Dict[int, stem.client.Circuit]
@staticmethod - def connect(address: str, port: int, link_protocols: Tuple[int] = DEFAULT_LINK_PROTOCOLS) -> None: + def connect(address: str, port: int, link_protocols: Sequence['stem.client.datatype.LinkProtocol'] = DEFAULT_LINK_PROTOCOLS) -> 'stem.client.Relay': # type: ignore """ Establishes a connection with the given ORPort.
@@ -121,7 +122,7 @@ class Relay(object): # first VERSIONS cell, always have CIRCID_LEN == 2 for backward # compatibility.
- conn.send(stem.client.cell.VersionsCell(link_protocols).pack(2)) + conn.send(stem.client.cell.VersionsCell(link_protocols).pack(2)) # type: ignore response = conn.recv()
# Link negotiation ends right away if we lack a common protocol @@ -131,12 +132,12 @@ class Relay(object): conn.close() raise stem.SocketError('Unable to establish a common link protocol with %s:%i' % (address, port))
- versions_reply = stem.client.cell.Cell.pop(response, 2)[0] + versions_reply = stem.client.cell.Cell.pop(response, 2)[0] # type: stem.client.cell.VersionsCell # type: ignore common_protocols = set(link_protocols).intersection(versions_reply.versions)
if not common_protocols: conn.close() - raise stem.SocketError('Unable to find a common link protocol. We support %s but %s:%i supports %s.' % (', '.join(link_protocols), address, port, ', '.join(versions_reply.versions))) + raise stem.SocketError('Unable to find a common link protocol. We support %s but %s:%i supports %s.' % (', '.join(map(str, link_protocols)), address, port, ', '.join(map(str, versions_reply.versions))))
# Establishing connections requires sending a NETINFO, but including our # address is optional. We can revisit including it when we have a usecase @@ -147,7 +148,10 @@ class Relay(object):
return Relay(conn, link_protocol)
- def _recv(self, raw: bool = False) -> None: + def _recv_bytes(self) -> bytes: + return self._recv(True) # type: ignore + + def _recv(self, raw: bool = False) -> 'stem.client.cell.Cell': """ Reads the next cell from our ORPort. If none is present this blocks until one is available. @@ -172,18 +176,18 @@ class Relay(object): else: # variable length, our next field is the payload size
- while len(self._orport_buffer) < (circ_id_size + CELL_TYPE_SIZE.size + FIXED_PAYLOAD_LEN.size): + while len(self._orport_buffer) < (circ_id_size + CELL_TYPE_SIZE.size + FIXED_PAYLOAD_LEN): self._orport_buffer += self._orport.recv() # read until we know the cell size
- payload_len = FIXED_PAYLOAD_LEN.pop(self._orport_buffer[circ_id_size + CELL_TYPE_SIZE.size:])[0] - cell_size = circ_id_size + CELL_TYPE_SIZE.size + FIXED_PAYLOAD_LEN.size + payload_len + payload_len = PAYLOAD_LEN_SIZE.pop(self._orport_buffer[circ_id_size + CELL_TYPE_SIZE.size:])[0] + cell_size = circ_id_size + CELL_TYPE_SIZE.size + payload_len
while len(self._orport_buffer) < cell_size: self._orport_buffer += self._orport.recv() # read until we have the full cell
if raw: content, self._orport_buffer = split(self._orport_buffer, cell_size) - return content + return content # type: ignore else: cell, self._orport_buffer = Cell.pop(self._orport_buffer, self.link_protocol) return cell @@ -213,12 +217,12 @@ class Relay(object): :returns: **generator** with the cells received in reply """
+ # TODO: why is this an iterator? + self._orport.recv(timeout = 0) # discard unread data self._orport.send(cell.pack(self.link_protocol)) response = self._orport.recv(timeout = 1) - - for received_cell in stem.client.cell.Cell.pop(response, self.link_protocol): - yield received_cell + yield stem.client.cell.Cell.pop(response, self.link_protocol)[0]
def is_alive(self) -> bool: """ @@ -251,7 +255,7 @@ class Relay(object): with self._orport_lock: return self._orport.close()
- def create_circuit(self) -> None: + def create_circuit(self) -> 'stem.client.Circuit': """ Establishes a new circuit. """ @@ -314,7 +318,7 @@ class Circuit(object): except ImportError: raise ImportError('Circuit construction requires the cryptography module')
- ctr = modes.CTR(ZERO * (algorithms.AES.block_size // 8)) + ctr = modes.CTR(ZERO * (algorithms.AES.block_size // 8)) # type: ignore
self.relay = relay self.id = circ_id @@ -323,7 +327,7 @@ class Circuit(object): self.forward_key = Cipher(algorithms.AES(kdf.forward_key), ctr, default_backend()).encryptor() self.backward_key = Cipher(algorithms.AES(kdf.backward_key), ctr, default_backend()).decryptor()
- def directory(self, request: str, stream_id: int = 0) -> str: + def directory(self, request: str, stream_id: int = 0) -> bytes: """ Request descriptors from the relay.
@@ -337,13 +341,13 @@ class Circuit(object): self._send(RelayCommand.BEGIN_DIR, stream_id = stream_id) self._send(RelayCommand.DATA, request, stream_id = stream_id)
- response = [] + response = [] # type: List[stem.client.cell.RelayCell]
while True: # Decrypt relay cells received in response. Our digest/key only # updates when handled successfully.
- encrypted_cell = self.relay._recv(raw = True) + encrypted_cell = self.relay._recv_bytes()
decrypted_cell, backward_key, backward_digest = stem.client.cell.RelayCell.decrypt(self.relay.link_protocol, encrypted_cell, self.backward_key, self.backward_digest)
@@ -358,7 +362,7 @@ class Circuit(object): else: response.append(decrypted_cell)
- def _send(self, command: 'stem.client.datatype.RelayCommand', data: bytes = b'', stream_id: int = 0) -> None: + def _send(self, command: 'stem.client.datatype.RelayCommand', data: Union[bytes, str] = b'', stream_id: int = 0) -> None: """ Sends a message over the circuit.
diff --git a/stem/client/cell.py b/stem/client/cell.py index ef445a64..c88ba716 100644 --- a/stem/client/cell.py +++ b/stem/client/cell.py @@ -49,12 +49,12 @@ from stem import UNDEFINED from stem.client.datatype import HASH_LEN, ZERO, LinkProtocol, Address, Certificate, CloseReason, RelayCommand, Size, split from stem.util import datetime_to_unix, str_tools
-from typing import Any, Sequence, Tuple, Type +from typing import Any, Iterator, List, Optional, Sequence, Tuple, Type, Union
FIXED_PAYLOAD_LEN = 509 # PAYLOAD_LEN, per tor-spec section 0.2 AUTH_CHALLENGE_SIZE = 32
-CELL_TYPE_SIZE = Size.CHAR +CELL_TYPE_SIZE = Size.CHAR # type: stem.client.datatype.Size PAYLOAD_LEN_SIZE = Size.SHORT RELAY_DIGEST_SIZE = Size.LONG
@@ -138,11 +138,11 @@ class Cell(object):
raise ValueError("'%s' isn't a valid cell value" % value)
- def pack(self, link_protocol): + def pack(self, link_protocol: 'stem.client.datatype.LinkProtocol') -> bytes: raise NotImplementedError('Packing not yet implemented for %s cells' % type(self).NAME)
@staticmethod - def unpack(content: bytes, link_protocol: 'stem.client.datatype.LinkProtocol') -> 'stem.client.cell.Cell': + def unpack(content: bytes, link_protocol: 'stem.client.datatype.LinkProtocol') -> Iterator['stem.client.cell.Cell']: """ Unpacks all cells from a response.
@@ -193,7 +193,7 @@ class Cell(object): return cls._unpack(payload, circ_id, link_protocol), content
@classmethod - def _pack(cls: Type['stem.client.cell.Cell'], link_protocol: 'stem.client.datatype.LinkProtocol', payload: bytes, unused: bytes = b'', circ_id: int = None) -> bytes: + def _pack(cls: Type['stem.client.cell.Cell'], link_protocol: 'stem.client.datatype.LinkProtocol', payload: bytes, unused: bytes = b'', circ_id: Optional[int] = None) -> bytes: """ Provides bytes that can be used on the wire for these cell attributes. Format of a properly packed cell depends on if it's fixed or variable @@ -292,7 +292,7 @@ class PaddingCell(Cell): VALUE = 0 IS_FIXED_SIZE = True
- def __init__(self, payload: bytes = None) -> None: + def __init__(self, payload: Optional[bytes] = None) -> None: if not payload: payload = os.urandom(FIXED_PAYLOAD_LEN) elif len(payload) != FIXED_PAYLOAD_LEN: @@ -317,8 +317,8 @@ class CreateCell(CircuitCell): VALUE = 1 IS_FIXED_SIZE = True
- def __init__(self) -> None: - super(CreateCell, self).__init__() # TODO: implement + def __init__(self, circ_id: int, unused: bytes = b'') -> None: + super(CreateCell, self).__init__(circ_id, unused) # TODO: implement
class CreatedCell(CircuitCell): @@ -326,8 +326,8 @@ class CreatedCell(CircuitCell): VALUE = 2 IS_FIXED_SIZE = True
- def __init__(self) -> None: - super(CreatedCell, self).__init__() # TODO: implement + def __init__(self, circ_id: int, unused: bytes = b'') -> None: + super(CreatedCell, self).__init__(circ_id, unused) # TODO: implement
class RelayCell(CircuitCell): @@ -352,13 +352,13 @@ class RelayCell(CircuitCell): VALUE = 3 IS_FIXED_SIZE = True
- def __init__(self, circ_id: int, command, data: bytes, digest: int = 0, stream_id: int = 0, recognized: int = 0, unused: bytes = b'') -> None: + def __init__(self, circ_id: int, command, data: Union[bytes, str], digest: Union[int, bytes, str, 'hashlib._HASH'] = 0, stream_id: int = 0, recognized: int = 0, unused: bytes = b'') -> None: # type: ignore if 'hash' in str(type(digest)).lower(): # Unfortunately hashlib generates from a dynamic private class so # isinstance() isn't such a great option. With python2/python3 the # name is 'hashlib.HASH' whereas PyPy calls it just 'HASH' or 'Hash'.
- digest_packed = digest.digest()[:RELAY_DIGEST_SIZE.size] + digest_packed = digest.digest()[:RELAY_DIGEST_SIZE.size] # type: ignore digest = RELAY_DIGEST_SIZE.unpack(digest_packed) elif isinstance(digest, (bytes, str)): digest_packed = digest[:RELAY_DIGEST_SIZE.size] @@ -393,7 +393,7 @@ class RelayCell(CircuitCell): return RelayCell._pack(link_protocol, bytes(payload), self.unused, self.circ_id)
@staticmethod - def decrypt(link_protocol: 'stem.client.datatype.LinkProtocol', content: bytes, key: 'cryptography.hazmat.primitives.ciphers.CipherContext', digest: 'hashlib.HASH') -> Tuple['stem.client.cell.RelayCell', 'cryptography.hazmat.primitives.ciphers.CipherContext', 'hashlib.HASH']: + def decrypt(link_protocol: 'stem.client.datatype.LinkProtocol', content: bytes, key: 'cryptography.hazmat.primitives.ciphers.CipherContext', digest: 'hashlib._HASH') -> Tuple['stem.client.cell.RelayCell', 'cryptography.hazmat.primitives.ciphers.CipherContext', 'hashlib._HASH']: # type: ignore """ Decrypts content as a relay cell addressed to us. This provides back a tuple of the form... @@ -447,7 +447,7 @@ class RelayCell(CircuitCell):
return cell, new_key, new_digest
- def encrypt(self, link_protocol: 'stem.client.datatype.LinkProtocol', key: 'cryptography.hazmat.primitives.ciphers.CipherContext', digest: 'hashlib.HASH') -> Tuple[bytes, 'cryptography.hazmat.primitives.ciphers.CipherContext', 'hashlib.HASH']: + def encrypt(self, link_protocol: 'stem.client.datatype.LinkProtocol', key: 'cryptography.hazmat.primitives.ciphers.CipherContext', digest: 'hashlib._HASH') -> Tuple[bytes, 'cryptography.hazmat.primitives.ciphers.CipherContext', 'hashlib._HASH']: # type: ignore """ Encrypts our cell content to be sent with the given key. This provides back a tuple of the form... @@ -540,7 +540,7 @@ class CreateFastCell(CircuitCell): VALUE = 5 IS_FIXED_SIZE = True
- def __init__(self, circ_id: int, key_material: bytes = None, unused: bytes = b'') -> None: + def __init__(self, circ_id: int, key_material: Optional[bytes] = None, unused: bytes = b'') -> None: if not key_material: key_material = os.urandom(HASH_LEN) elif len(key_material) != HASH_LEN: @@ -577,7 +577,7 @@ class CreatedFastCell(CircuitCell): VALUE = 6 IS_FIXED_SIZE = True
- def __init__(self, circ_id: int, derivative_key: bytes, key_material: bytes = None, unused: bytes = b'') -> None: + def __init__(self, circ_id: int, derivative_key: bytes, key_material: Optional[bytes] = None, unused: bytes = b'') -> None: if not key_material: key_material = os.urandom(HASH_LEN) elif len(key_material) != HASH_LEN: @@ -594,7 +594,7 @@ class CreatedFastCell(CircuitCell): return CreatedFastCell._pack(link_protocol, self.key_material + self.derivative_key, self.unused, self.circ_id)
@classmethod - def _unpack(cls, content: bytes, circ_id: int, link_protocol: 'stem.client.datatype.LinkProtocol') -> 'stem.client.cell.CreateFastCell': + def _unpack(cls, content: bytes, circ_id: int, link_protocol: 'stem.client.datatype.LinkProtocol') -> 'stem.client.cell.CreatedFastCell': if len(content) < HASH_LEN * 2: raise ValueError('Key material and derivatived key should be %i bytes, but was %i' % (HASH_LEN * 2, len(content)))
@@ -653,7 +653,7 @@ class NetinfoCell(Cell): VALUE = 8 IS_FIXED_SIZE = True
- def __init__(self, receiver_address: 'stem.client.datatype.Address', sender_addresses: Sequence['stem.client.datatype.Address'], timestamp: datetime.datetime = None, unused: bytes = b'') -> None: + def __init__(self, receiver_address: 'stem.client.datatype.Address', sender_addresses: Sequence['stem.client.datatype.Address'], timestamp: Optional[datetime.datetime] = None, unused: bytes = b'') -> None: super(NetinfoCell, self).__init__(unused) self.timestamp = timestamp if timestamp else datetime.datetime.now() self.receiver_address = receiver_address @@ -693,8 +693,8 @@ class RelayEarlyCell(CircuitCell): VALUE = 9 IS_FIXED_SIZE = True
- def __init__(self) -> None: - super(RelayEarlyCell, self).__init__() # TODO: implement + def __init__(self, circ_id: int, unused: bytes = b'') -> None: + super(RelayEarlyCell, self).__init__(circ_id, unused) # TODO: implement
class Create2Cell(CircuitCell): @@ -702,8 +702,8 @@ class Create2Cell(CircuitCell): VALUE = 10 IS_FIXED_SIZE = True
- def __init__(self) -> None: - super(Create2Cell, self).__init__() # TODO: implement + def __init__(self, circ_id: int, unused: bytes = b'') -> None: + super(Create2Cell, self).__init__(circ_id, unused) # TODO: implement
class Created2Cell(Cell): @@ -735,7 +735,7 @@ class VPaddingCell(Cell): VALUE = 128 IS_FIXED_SIZE = False
- def __init__(self, size: int = None, payload: bytes = None) -> None: + def __init__(self, size: Optional[int] = None, payload: Optional[bytes] = None) -> None: if size is None and payload is None: raise ValueError('VPaddingCell constructor must specify payload or size') elif size is not None and size < 0: @@ -768,7 +768,7 @@ class CertsCell(Cell): VALUE = 129 IS_FIXED_SIZE = False
- def __init__(self, certs: Sequence['stem.client.Certificate'], unused: bytes = b'') -> None: + def __init__(self, certs: Sequence['stem.client.datatype.Certificate'], unused: bytes = b'') -> None: super(CertsCell, self).__init__(unused) self.certificates = certs
@@ -778,7 +778,7 @@ class CertsCell(Cell): @classmethod def _unpack(cls, content: bytes, circ_id: int, link_protocol: 'stem.client.datatype.LinkProtocol') -> 'stem.client.cell.CertsCell': cert_count, content = Size.CHAR.pop(content) - certs = [] + certs = [] # type: List[stem.client.datatype.Certificate]
for i in range(cert_count): if not content: @@ -806,7 +806,7 @@ class AuthChallengeCell(Cell): VALUE = 130 IS_FIXED_SIZE = False
- def __init__(self, methods: Sequence[int], challenge: bytes = None, unused: bytes = b'') -> None: + def __init__(self, methods: Sequence[int], challenge: Optional[bytes] = None, unused: bytes = b'') -> None: if not challenge: challenge = os.urandom(AUTH_CHALLENGE_SIZE) elif len(challenge) != AUTH_CHALLENGE_SIZE: diff --git a/stem/client/datatype.py b/stem/client/datatype.py index 8d8ae7fb..acc9ec34 100644 --- a/stem/client/datatype.py +++ b/stem/client/datatype.py @@ -144,7 +144,7 @@ import stem.util import stem.util.connection import stem.util.enum
-from typing import Any, Tuple, Type, Union +from typing import Any, Optional, Tuple, Union
ZERO = b'\x00' HASH_LEN = 20 @@ -157,17 +157,17 @@ class _IntegerEnum(stem.util.enum.Enum): **UNKNOWN** value for integer values that lack a mapping. """
- def __init__(self, *args: Tuple[str, int]) -> None: + def __init__(self, *args: Union[Tuple[str, int], Tuple[str, str, int]]) -> None: self._enum_to_int = {} self._int_to_enum = {} parent_args = []
for entry in args: if len(entry) == 2: - enum, int_val = entry + enum, int_val = entry # type: ignore str_val = enum elif len(entry) == 3: - enum, str_val, int_val = entry + enum, str_val, int_val = entry # type: ignore else: raise ValueError('IntegerEnums can only be constructed with two or three value tuples: %s' % repr(entry))
@@ -272,19 +272,16 @@ class LinkProtocol(int): from a range that's determined by our link protocol. """
- def __new__(cls: Type['stem.client.datatype.LinkProtocol'], version: int) -> 'stem.client.datatype.LinkProtocol': - if isinstance(version, LinkProtocol): - return version # already a LinkProtocol + def __new__(self, version: int) -> 'stem.client.datatype.LinkProtocol': + return int.__new__(self, version) # type: ignore
- protocol = int.__new__(cls, version) - protocol.version = version - protocol.circ_id_size = Size.LONG if version > 3 else Size.SHORT - protocol.first_circ_id = 0x80000000 if version > 3 else 0x01 + def __init__(self, version: int) -> None: + self.version = version + self.circ_id_size = Size.LONG if version > 3 else Size.SHORT + self.first_circ_id = 0x80000000 if version > 3 else 0x01
- cell_header_size = protocol.circ_id_size.size + 1 # circuit id (2 or 4 bytes) + command (1 byte) - protocol.fixed_cell_length = cell_header_size + stem.client.cell.FIXED_PAYLOAD_LEN - - return protocol + cell_header_size = self.circ_id_size.size + 1 # circuit id (2 or 4 bytes) + command (1 byte) + self.fixed_cell_length = cell_header_size + stem.client.cell.FIXED_PAYLOAD_LEN
def __hash__(self) -> int: # All LinkProtocol attributes can be derived from our version, so that's @@ -380,6 +377,11 @@ class Size(Field): ==================== =========== """
+ CHAR = None # type: Optional[stem.client.datatype.Size] + SHORT = None # type: Optional[stem.client.datatype.Size] + LONG = None # type: Optional[stem.client.datatype.Size] + LONG_LONG = None # type: Optional[stem.client.datatype.Size] + def __init__(self, name: str, size: int) -> None: self.name = name self.size = size @@ -388,7 +390,7 @@ class Size(Field): def pop(packed: bytes) -> Tuple[int, bytes]: raise NotImplementedError("Use our constant's unpack() and pop() instead")
- def pack(self, content: int) -> bytes: + def pack(self, content: int) -> bytes: # type: ignore try: return content.to_bytes(self.size, 'big') except: @@ -399,13 +401,13 @@ class Size(Field): else: raise
- def unpack(self, packed: bytes) -> int: + def unpack(self, packed: bytes) -> int: # type: ignore if self.size != len(packed): raise ValueError('%s is the wrong size for a %s field' % (repr(packed), self.name))
return int.from_bytes(packed, 'big')
- def pop(self, packed: bytes) -> Tuple[int, bytes]: + def pop(self, packed: bytes) -> Tuple[int, bytes]: # type: ignore to_unpack, remainder = split(packed, self.size)
return self.unpack(to_unpack), remainder @@ -420,48 +422,53 @@ class Address(Field):
:var stem.client.AddrType type: address type :var int type_int: integer value of the address type - :var unicode value: address value + :var str value: address value :var bytes value_bin: encoded address value """
- def __init__(self, value: str, addr_type: Union[int, 'stem.client.datatype.AddrType'] = None) -> None: + def __init__(self, value: Union[bytes, str], addr_type: Union[int, 'stem.client.datatype.AddrType'] = None) -> None: if addr_type is None: - if stem.util.connection.is_valid_ipv4_address(value): + if stem.util.connection.is_valid_ipv4_address(value): # type: ignore addr_type = AddrType.IPv4 - elif stem.util.connection.is_valid_ipv6_address(value): + elif stem.util.connection.is_valid_ipv6_address(value): # type: ignore addr_type = AddrType.IPv6 else: - raise ValueError("'%s' isn't an IPv4 or IPv6 address" % value) + raise ValueError("'%s' isn't an IPv4 or IPv6 address" % stem.util.str_tools._to_unicode(value)) + + value_bytes = stem.util.str_tools._to_bytes(value) + + self.value = None # type: Optional[str] + self.value_bin = None # type: Optional[bytes]
self.type, self.type_int = AddrType.get(addr_type)
if self.type == AddrType.IPv4: - if stem.util.connection.is_valid_ipv4_address(value): - self.value = value - self.value_bin = b''.join([Size.CHAR.pack(int(v)) for v in value.split('.')]) + if stem.util.connection.is_valid_ipv4_address(value_bytes): # type: ignore + self.value = stem.util.str_tools._to_unicode(value_bytes) + self.value_bin = b''.join([Size.CHAR.pack(int(v)) for v in value_bytes.split(b'.')]) else: - if len(value) != 4: + if len(value_bytes) != 4: raise ValueError('Packed IPv4 addresses should be four bytes, but was: %s' % repr(value))
- self.value = _unpack_ipv4_address(value) - self.value_bin = value + self.value = _unpack_ipv4_address(value_bytes) + self.value_bin = value_bytes elif self.type == AddrType.IPv6: - if stem.util.connection.is_valid_ipv6_address(value): - self.value = stem.util.connection.expand_ipv6_address(value).lower() + if stem.util.connection.is_valid_ipv6_address(value_bytes): # type: ignore + self.value = stem.util.connection.expand_ipv6_address(value_bytes).lower() # type: ignore self.value_bin = b''.join([Size.SHORT.pack(int(v, 16)) for v in self.value.split(':')]) else: - if len(value) != 16: + if len(value_bytes) != 16: raise ValueError('Packed IPv6 addresses should be sixteen bytes, but was: %s' % repr(value))
- self.value = _unpack_ipv6_address(value) - self.value_bin = value + self.value = _unpack_ipv6_address(value_bytes) + self.value_bin = value_bytes else: # The spec doesn't really tell us what form to expect errors to be. For # now just leaving the value unset so we can fill it in later when we # know what would be most useful.
self.value = None - self.value_bin = value + self.value_bin = value_bytes
def pack(self) -> bytes: cell = bytearray() @@ -471,7 +478,7 @@ class Address(Field): return bytes(cell)
@staticmethod - def pop(content) -> Tuple['stem.client.datatype.Address', bytes]: + def pop(content: bytes) -> Tuple['stem.client.datatype.Address', bytes]: addr_type, content = Size.CHAR.pop(content) addr_length, content = Size.CHAR.pop(content)
@@ -590,7 +597,7 @@ class LinkByIPv4(LinkSpecifier): @staticmethod def unpack(value: bytes) -> 'stem.client.datatype.LinkByIPv4': if len(value) != 6: - raise ValueError('IPv4 link specifiers should be six bytes, but was %i instead: %s' % (len(value), binascii.hexlify(value))) + raise ValueError('IPv4 link specifiers should be six bytes, but was %i instead: %s' % (len(value), stem.util.str_tools._to_unicode(binascii.hexlify(value))))
addr, port = split(value, 4) return LinkByIPv4(_unpack_ipv4_address(addr), Size.SHORT.unpack(port)) @@ -615,7 +622,7 @@ class LinkByIPv6(LinkSpecifier): @staticmethod def unpack(value: bytes) -> 'stem.client.datatype.LinkByIPv6': if len(value) != 18: - raise ValueError('IPv6 link specifiers should be eighteen bytes, but was %i instead: %s' % (len(value), binascii.hexlify(value))) + raise ValueError('IPv6 link specifiers should be eighteen bytes, but was %i instead: %s' % (len(value), stem.util.str_tools._to_unicode(binascii.hexlify(value))))
addr, port = split(value, 16) return LinkByIPv6(_unpack_ipv6_address(addr), Size.SHORT.unpack(port)) @@ -634,7 +641,7 @@ class LinkByFingerprint(LinkSpecifier): super(LinkByFingerprint, self).__init__(2, value)
if len(value) != 20: - raise ValueError('Fingerprint link specifiers should be twenty bytes, but was %i instead: %s' % (len(value), binascii.hexlify(value))) + raise ValueError('Fingerprint link specifiers should be twenty bytes, but was %i instead: %s' % (len(value), stem.util.str_tools._to_unicode(binascii.hexlify(value))))
self.fingerprint = stem.util.str_tools._to_unicode(value)
@@ -652,7 +659,7 @@ class LinkByEd25519(LinkSpecifier): super(LinkByEd25519, self).__init__(3, value)
if len(value) != 32: - raise ValueError('Fingerprint link specifiers should be thirty two bytes, but was %i instead: %s' % (len(value), binascii.hexlify(value))) + raise ValueError('Fingerprint link specifiers should be thirty two bytes, but was %i instead: %s' % (len(value), stem.util.str_tools._to_unicode(binascii.hexlify(value))))
self.fingerprint = stem.util.str_tools._to_unicode(value)
@@ -695,7 +702,7 @@ def _pack_ipv4_address(address: str) -> bytes: return b''.join([Size.CHAR.pack(int(v)) for v in address.split('.')])
-def _unpack_ipv4_address(value: str) -> bytes: +def _unpack_ipv4_address(value: bytes) -> str: return '.'.join([str(Size.CHAR.unpack(value[i:i + 1])) for i in range(4)])
@@ -703,7 +710,7 @@ def _pack_ipv6_address(address: str) -> bytes: return b''.join([Size.SHORT.pack(int(v, 16)) for v in address.split(':')])
-def _unpack_ipv6_address(value: str) -> bytes: +def _unpack_ipv6_address(value: bytes) -> str: return ':'.join(['%04x' % Size.SHORT.unpack(value[i * 2:(i + 1) * 2]) for i in range(8)])
diff --git a/stem/connection.py b/stem/connection.py index 3d3eb3ee..ff950a0c 100644 --- a/stem/connection.py +++ b/stem/connection.py @@ -143,7 +143,7 @@ import stem.util.str_tools import stem.util.system import stem.version
-from typing import Any, Optional, Sequence, Tuple, Type, Union +from typing import Any, List, Optional, Sequence, Tuple, Type, Union from stem.util import log
AuthMethod = stem.util.enum.Enum('NONE', 'PASSWORD', 'COOKIE', 'SAFECOOKIE', 'UNKNOWN') @@ -211,7 +211,7 @@ COMMON_TOR_COMMANDS = ( )
-def connect(control_port: Tuple[str, int] = ('127.0.0.1', 'default'), control_socket: str = '/var/run/tor/control', password: Optional[str] = None, password_prompt: bool = False, chroot_path: Optional[str] = None, controller: type = stem.control.Controller) -> Union[stem.control.BaseController, stem.socket.ControlSocket]: +def connect(control_port: Tuple[str, Union[str, int]] = ('127.0.0.1', 'default'), control_socket: str = '/var/run/tor/control', password: Optional[str] = None, password_prompt: bool = False, chroot_path: Optional[str] = None, controller: Type = stem.control.Controller) -> Any: """ Convenience function for quickly getting a control connection. This is very handy for debugging or CLI setup, handling setup and prompting for a password @@ -250,6 +250,8 @@ def connect(control_port: Tuple[str, int] = ('127.0.0.1', 'default'), control_so **control_port** and **control_socket** are **None** """
+ # TODO: change this function's API so we can provide a concrete type + if control_port is None and control_socket is None: raise ValueError('Neither a control port nor control socket were provided. Nothing to connect to.') elif control_port: @@ -260,7 +262,8 @@ def connect(control_port: Tuple[str, int] = ('127.0.0.1', 'default'), control_so elif control_port[1] != 'default' and not stem.util.connection.is_valid_port(control_port[1]): raise ValueError("'%s' isn't a valid port" % control_port[1])
- control_connection, error_msg = None, '' + control_connection = None # type: Optional[stem.socket.ControlSocket] + error_msg = ''
if control_socket: if os.path.exists(control_socket): @@ -297,7 +300,7 @@ def connect(control_port: Tuple[str, int] = ('127.0.0.1', 'default'), control_so return _connect_auth(control_connection, password, password_prompt, chroot_path, controller)
-def _connect_auth(control_socket: stem.socket.ControlSocket, password: str, password_prompt: bool, chroot_path: str, controller: Union[Type[stem.control.BaseController], Type[stem.socket.ControlSocket]]) -> Union[stem.control.BaseController, stem.socket.ControlSocket]: +def _connect_auth(control_socket: stem.socket.ControlSocket, password: str, password_prompt: bool, chroot_path: str, controller: Optional[Type[stem.control.BaseController]]) -> Any: """ Helper for the connect_* functions that authenticates the socket and constructs the controller. @@ -363,7 +366,7 @@ def _connect_auth(control_socket: stem.socket.ControlSocket, password: str, pass return None
-def authenticate(controller: Any, password: Optional[str] = None, chroot_path: Optional[str] = None, protocolinfo_response: Optional[stem.response.protocolinfo.ProtocolInfoResponse] = None) -> None: +def authenticate(controller: Union[stem.control.BaseController, stem.socket.ControlSocket], password: Optional[str] = None, chroot_path: Optional[str] = None, protocolinfo_response: Optional[stem.response.protocolinfo.ProtocolInfoResponse] = None) -> None: """ Authenticates to a control socket using the information provided by a PROTOCOLINFO response. In practice this will often be all we need to @@ -481,7 +484,7 @@ def authenticate(controller: Any, password: Optional[str] = None, chroot_path: O raise AuthenticationFailure('socket connection failed (%s)' % exc)
auth_methods = list(protocolinfo_response.auth_methods) - auth_exceptions = [] + auth_exceptions = [] # type: List[stem.connection.AuthenticationFailure]
if len(auth_methods) == 0: raise NoAuthMethods('our PROTOCOLINFO response did not have any methods for authenticating') @@ -846,10 +849,11 @@ def authenticate_safecookie(controller: Union[stem.control.BaseController, stem.
cookie_data = _read_cookie(cookie_path, True) client_nonce = os.urandom(32) + authchallenge_response = None # type: Optional[stem.response.authchallenge.AuthChallengeResponse]
try: client_nonce_hex = stem.util.str_tools._to_unicode(binascii.b2a_hex(client_nonce)) - authchallenge_response = _msg(controller, 'AUTHCHALLENGE SAFECOOKIE %s' % client_nonce_hex) + authchallenge_response = _msg(controller, 'AUTHCHALLENGE SAFECOOKIE %s' % client_nonce_hex) # type: ignore
if not authchallenge_response.is_ok(): try: @@ -862,13 +866,18 @@ def authenticate_safecookie(controller: Union[stem.control.BaseController, stem. if 'Authentication required.' in authchallenge_response_str: raise AuthChallengeUnsupported("SAFECOOKIE authentication isn't supported", cookie_path) elif 'AUTHCHALLENGE only supports' in authchallenge_response_str: - raise UnrecognizedAuthChallengeMethod(authchallenge_response_str, cookie_path) + # TODO: This code path has been broken for years. Do we still need it? + # If so, what should authchallenge_method be? + + authchallenge_method = None + + raise UnrecognizedAuthChallengeMethod(authchallenge_response_str, cookie_path, authchallenge_method) elif 'Invalid base16 client nonce' in authchallenge_response_str: raise InvalidClientNonce(authchallenge_response_str, cookie_path) elif 'Cookie authentication is disabled' in authchallenge_response_str: raise CookieAuthRejected(authchallenge_response_str, cookie_path, True) else: - raise AuthChallengeFailed(authchallenge_response, cookie_path) + raise AuthChallengeFailed(authchallenge_response_str, cookie_path) except stem.ControllerError as exc: try: controller.connect() @@ -878,7 +887,7 @@ def authenticate_safecookie(controller: Union[stem.control.BaseController, stem. if not suppress_ctl_errors: raise else: - raise AuthChallengeFailed('Socket failed (%s)' % exc, cookie_path, True) + raise AuthChallengeFailed('Socket failed (%s)' % exc, cookie_path)
try: stem.response.convert('AUTHCHALLENGE', authchallenge_response) @@ -970,7 +979,7 @@ def get_protocolinfo(controller: Union[stem.control.BaseController, stem.socket. raise stem.SocketError(exc)
stem.response.convert('PROTOCOLINFO', protocolinfo_response) - return protocolinfo_response + return protocolinfo_response # type: ignore
def _msg(controller: Union[stem.control.BaseController, stem.socket.ControlSocket], message: str) -> stem.response.ControlMessage: @@ -1008,7 +1017,7 @@ def _connection_for_default_port(address: str) -> stem.socket.ControlPort: raise exc
-def _read_cookie(cookie_path: str, is_safecookie: bool) -> str: +def _read_cookie(cookie_path: str, is_safecookie: bool) -> bytes: """ Provides the contents of a given cookie file.
@@ -1016,7 +1025,7 @@ def _read_cookie(cookie_path: str, is_safecookie: bool) -> str: :param bool is_safecookie: **True** if this was for SAFECOOKIE authentication, **False** if for COOKIE
- :returns: **str** with the cookie file content + :returns: **bytes** with the cookie file content
:raises: * :class:`stem.connection.UnreadableCookieFile` if the cookie file is @@ -1052,12 +1061,12 @@ def _read_cookie(cookie_path: str, is_safecookie: bool) -> str: raise UnreadableCookieFile(exc_msg, cookie_path, is_safecookie)
-def _hmac_sha256(key: str, msg: str) -> bytes: +def _hmac_sha256(key: bytes, msg: bytes) -> bytes: """ Generates a sha256 digest using the given key and message.
- :param str key: starting key for the hash - :param str msg: message to be hashed + :param bytes key: starting key for the hash + :param bytes msg: message to be hashed
:returns: sha256 digest of msg as bytes, hashed using the given key """ diff --git a/stem/control.py b/stem/control.py index ec4ba54e..626b2b3e 100644 --- a/stem/control.py +++ b/stem/control.py @@ -271,7 +271,7 @@ import stem.version from stem import UNDEFINED, CircStatus, Signal from stem.util import log from types import TracebackType -from typing import Any, Callable, Dict, Iterator, Mapping, Optional, Sequence, Tuple, Type, Union +from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Sequence, Tuple, Type, Union
# When closing the controller we attempt to finish processing enqueued events, # but if it takes longer than this we terminate. @@ -404,7 +404,7 @@ SERVER_DESCRIPTORS_UNSUPPORTED = "Tor is currently not configured to retrieve \ server descriptors. As of Tor version 0.2.3.25 it downloads microdescriptors \ instead unless you set 'UseMicrodescriptors 0' in your torrc."
-EVENT_DESCRIPTIONS = None +EVENT_DESCRIPTIONS = None # type: Optional[Dict[str, str]]
class AccountingStats(collections.namedtuple('AccountingStats', ['retrieved', 'status', 'interval_end', 'time_until_reset', 'read_bytes', 'read_bytes_left', 'read_limit', 'written_bytes', 'write_bytes_left', 'write_limit'])): @@ -518,7 +518,7 @@ def event_description(event: str) -> str:
try: config.load(config_path) - EVENT_DESCRIPTIONS = dict([(key.lower()[18:], config.get_value(key)) for key in config.keys() if key.startswith('event.description.')]) + EVENT_DESCRIPTIONS = dict([(key.lower()[18:], config.get_value(key)) for key in config.keys() if key.startswith('event.description.')]) # type: ignore except Exception as exc: log.warn("BUG: stem failed to load its internal manual information from '%s': %s" % (config_path, exc)) return None @@ -546,19 +546,19 @@ class BaseController(object): self._socket = control_socket self._msg_lock = threading.RLock()
- self._status_listeners = [] # tuples of the form (callback, spawn_thread) + self._status_listeners = [] # type: List[Tuple[Callable[[stem.control.BaseController, stem.control.State, float], None], bool]] # tuples of the form (callback, spawn_thread) self._status_listeners_lock = threading.RLock()
# queues where incoming messages are directed - self._reply_queue = queue.Queue() - self._event_queue = queue.Queue() + self._reply_queue = queue.Queue() # type: queue.Queue[Union[stem.response.ControlMessage, stem.ControllerError]] + self._event_queue = queue.Queue() # type: queue.Queue[stem.response.ControlMessage]
# thread to continually pull from the control socket - self._reader_thread = None + self._reader_thread = None # type: Optional[threading.Thread]
# thread to pull from the _event_queue and call handle_event self._event_notice = threading.Event() - self._event_thread = None + self._event_thread = None # type: Optional[threading.Thread]
# saves our socket's prior _connect() and _close() methods so they can be # called along with ours @@ -566,13 +566,13 @@ class BaseController(object): self._socket_connect = self._socket._connect self._socket_close = self._socket._close
- self._socket._connect = self._connect - self._socket._close = self._close + self._socket._connect = self._connect # type: ignore + self._socket._close = self._close # type: ignore
self._last_heartbeat = 0.0 # timestamp for when we last heard from tor self._is_authenticated = False
- self._state_change_threads = [] # threads we've spawned to notify of state changes + self._state_change_threads = [] # type: List[threading.Thread] # threads we've spawned to notify of state changes
if self._socket.is_alive(): self._launch_threads() @@ -757,7 +757,7 @@ class BaseController(object):
return self._last_heartbeat
- def add_status_listener(self, callback: Callable[['stem.control.Controller', 'stem.control.State', float], None], spawn: bool = True) -> None: + def add_status_listener(self, callback: Callable[['stem.control.BaseController', 'stem.control.State', float], None], spawn: bool = True) -> None: """ Notifies a given function when the state of our socket changes. Functions are expected to be of the form... @@ -986,7 +986,7 @@ class Controller(BaseController): """
@staticmethod - def from_port(address: str = '127.0.0.1', port: int = 'default') -> 'stem.control.Controller': + def from_port(address: str = '127.0.0.1', port: Union[int, str] = 'default') -> 'stem.control.Controller': """ Constructs a :class:`~stem.socket.ControlPort` based Controller.
@@ -1015,7 +1015,7 @@ class Controller(BaseController): if port == 'default': control_port = stem.connection._connection_for_default_port(address) else: - control_port = stem.socket.ControlPort(address, port) + control_port = stem.socket.ControlPort(address, int(port))
return Controller(control_port)
@@ -1036,33 +1036,33 @@ class Controller(BaseController):
def __init__(self, control_socket: stem.socket.ControlSocket, is_authenticated: bool = False) -> None: self._is_caching_enabled = True - self._request_cache = {} + self._request_cache = {} # type: Dict[str, Any] self._last_newnym = 0.0
self._cache_lock = threading.RLock()
# mapping of event types to their listeners
- self._event_listeners = {} + self._event_listeners = {} # type: Dict[stem.control.EventType, List[Callable[[stem.response.events.Event], None]]] self._event_listeners_lock = threading.RLock() - self._enabled_features = [] + self._enabled_features = [] # type: List[str]
- self._last_address_exc = None - self._last_fingerprint_exc = None + self._last_address_exc = None # type: Optional[BaseException] + self._last_fingerprint_exc = None # type: Optional[BaseException]
super(Controller, self).__init__(control_socket, is_authenticated)
- def _sighup_listener(event: stem.response.events.Event) -> None: + def _sighup_listener(event: stem.response.events.SignalEvent) -> None: if event.signal == Signal.RELOAD: self.clear_cache() self._notify_status_listeners(State.RESET)
- self.add_event_listener(_sighup_listener, EventType.SIGNAL) + self.add_event_listener(_sighup_listener, EventType.SIGNAL) # type: ignore
- def _confchanged_listener(event: stem.response.events.Event) -> None: + def _confchanged_listener(event: stem.response.events.ConfChangedEvent) -> None: if self.is_caching_enabled(): to_cache_changed = dict((k.lower(), v) for k, v in event.changed.items()) - to_cache_unset = dict((k.lower(), []) for k in event.unset) # [] represents None value in cache + to_cache_unset = dict((k.lower(), []) for k in event.unset) # type: Dict[str, List[str]] # [] represents None value in cache
to_cache = {} to_cache.update(to_cache_changed) @@ -1072,15 +1072,15 @@ class Controller(BaseController):
self._confchanged_cache_invalidation(to_cache)
- self.add_event_listener(_confchanged_listener, EventType.CONF_CHANGED) + self.add_event_listener(_confchanged_listener, EventType.CONF_CHANGED) # type: ignore
- def _address_changed_listener(event: stem.response.events.Event) -> None: + def _address_changed_listener(event: stem.response.events.StatusEvent) -> None: if event.action in ('EXTERNAL_ADDRESS', 'DNS_USELESS'): self._set_cache({'exit_policy': None}) self._set_cache({'address': None}, 'getinfo') self._last_address_exc = None
- self.add_event_listener(_address_changed_listener, EventType.STATUS_SERVER) + self.add_event_listener(_address_changed_listener, EventType.STATUS_SERVER) # type: ignore
def close(self) -> None: self.clear_cache() @@ -1152,15 +1152,15 @@ class Controller(BaseController):
if isinstance(params, (bytes, str)): is_multiple = False - params = set([params]) + param_set = set([params]) else: if not params: return {}
is_multiple = True - params = set(params) + param_set = set(params)
- for param in params: + for param in param_set: if param.startswith('ip-to-country/') and param != 'ip-to-country/0.0.0.0' and self.get_info('ip-to-country/ipv4-available', '0') != '1': raise stem.ProtocolError('Tor geoip database is unavailable') elif param == 'address' and self._last_address_exc: @@ -1170,16 +1170,16 @@ class Controller(BaseController):
# check for cached results
- from_cache = [param.lower() for param in params] + from_cache = [param.lower() for param in param_set] cached_results = self._get_cache_map(from_cache, 'getinfo')
for key in cached_results: - user_expected_key = _case_insensitive_lookup(params, key) + user_expected_key = _case_insensitive_lookup(param_set, key) reply[user_expected_key] = cached_results[key] - params.remove(user_expected_key) + param_set.remove(user_expected_key)
# if everything was cached then short circuit making the query - if not params: + if not param_set: if LOG_CACHE_FETCHES: log.trace('GETINFO %s (cache fetch)' % ' '.join(reply.keys()))
@@ -1189,14 +1189,13 @@ class Controller(BaseController): return list(reply.values())[0]
try: - response = self.msg('GETINFO %s' % ' '.join(params)) - stem.response.convert('GETINFO', response) - response._assert_matches(params) + response = stem.response._convert_to_getinfo(self.msg('GETINFO %s' % ' '.join(param_set))) + response._assert_matches(param_set)
# usually we want unicode values under python 3.x
if not get_bytes: - response.entries = dict((k, stem.util.str_tools._to_unicode(v)) for (k, v) in response.entries.items()) + response.entries = dict((k, stem.util.str_tools._to_unicode(v)) for (k, v) in response.entries.items()) # type: ignore
reply.update(response.entries)
@@ -1213,26 +1212,26 @@ class Controller(BaseController):
self._set_cache(to_cache, 'getinfo')
- if 'address' in params: + if 'address' in param_set: self._last_address_exc = None
- if 'fingerprint' in params: + if 'fingerprint' in param_set: self._last_fingerprint_exc = None
- log.debug('GETINFO %s (runtime: %0.4f)' % (' '.join(params), time.time() - start_time)) + log.debug('GETINFO %s (runtime: %0.4f)' % (' '.join(param_set), time.time() - start_time))
if is_multiple: return reply else: return list(reply.values())[0] except stem.ControllerError as exc: - if 'address' in params: + if 'address' in param_set: self._last_address_exc = exc
- if 'fingerprint' in params: + if 'fingerprint' in param_set: self._last_fingerprint_exc = exc
- log.debug('GETINFO %s (failed: %s)' % (' '.join(params), exc)) + log.debug('GETINFO %s (failed: %s)' % (' '.join(param_set), exc)) raise
@with_default() @@ -1363,7 +1362,7 @@ class Controller(BaseController):
if listeners is None: proxy_addrs = [] - query = 'net/listeners/%s' % listener_type.lower() + query = 'net/listeners/%s' % str(listener_type).lower()
try: for listener in self.get_info(query).split(): @@ -1413,7 +1412,7 @@ class Controller(BaseController): Listener.CONTROL: 'ControlListenAddress', }[listener_type]
- port_value = self.get_conf(port_option).split()[0] + port_value = self._get_conf_single(port_option).split()[0]
for listener in self.get_conf(listener_option, multiple = True): if ':' in listener: @@ -1571,7 +1570,7 @@ class Controller(BaseController): pid = int(getinfo_pid)
if not pid and self.is_localhost(): - pid_file_path = self.get_conf('PidFile', None) + pid_file_path = self._get_conf_single('PidFile', None)
if pid_file_path is not None: with open(pid_file_path) as pid_file: @@ -1666,7 +1665,7 @@ class Controller(BaseController):
return time.time() - self.get_start_time()
- def is_user_traffic_allowed(self) -> bool: + def is_user_traffic_allowed(self) -> 'stem.control.UserTrafficAllowed': """ Checks if we're likely to service direct user traffic. This essentially boils down to... @@ -1687,7 +1686,7 @@ class Controller(BaseController):
.. versionadded:: 1.5.0
- :returns: :class:`~stem.cotroller.UserTrafficAllowed` with **inbound** and + :returns: :class:`~stem.control.UserTrafficAllowed` with **inbound** and **outbound** boolean attributes to indicate if we're likely servicing direct user traffic """ @@ -1860,7 +1859,7 @@ class Controller(BaseController): return stem.descriptor.server_descriptor.RelayDescriptor(desc_content)
@with_default(yields = True) - def get_server_descriptors(self, default: Any = UNDEFINED) -> stem.descriptor.server_descriptor.RelayDescriptor: + def get_server_descriptors(self, default: Any = UNDEFINED) -> Iterator[stem.descriptor.server_descriptor.RelayDescriptor]: """ get_server_descriptors(default = UNDEFINED)
@@ -1893,7 +1892,7 @@ class Controller(BaseController): raise stem.DescriptorUnavailable('Descriptor information is unavailable, tor might still be downloading it')
for desc in stem.descriptor.server_descriptor._parse_file(io.BytesIO(desc_content)): - yield desc + yield desc # type: ignore
@with_default() def get_network_status(self, relay: Optional[str] = None, default: Any = UNDEFINED) -> stem.descriptor.router_status_entry.RouterStatusEntryV3: @@ -1989,7 +1988,7 @@ class Controller(BaseController): )
for desc in desc_iterator: - yield desc + yield desc # type: ignore
@with_default() def get_hidden_service_descriptor(self, address: str, default: Any = UNDEFINED, servers: Optional[Sequence[str]] = None, await_result: bool = True, timeout: Optional[float] = None) -> stem.descriptor.hidden_service.HiddenServiceDescriptorV2: @@ -2035,8 +2034,12 @@ class Controller(BaseController): if not stem.util.tor_tools.is_valid_hidden_service_address(address): raise ValueError("'%s.onion' isn't a valid hidden service address" % address)
- hs_desc_queue, hs_desc_listener = queue.Queue(), None - hs_desc_content_queue, hs_desc_content_listener = queue.Queue(), None + hs_desc_queue = queue.Queue() # type: queue.Queue[stem.response.events.Event] + hs_desc_listener = None + + hs_desc_content_queue = queue.Queue() # type: queue.Queue[stem.response.events.Event] + hs_desc_content_listener = None + start_time = time.time()
if await_result: @@ -2055,8 +2058,7 @@ class Controller(BaseController): if servers: request += ' ' + ' '.join(['SERVER=%s' % s for s in servers])
- response = self.msg(request) - stem.response.convert('SINGLELINE', response) + response = stem.response._convert_to_single_line(self.msg(request))
if not response.is_ok(): raise stem.ProtocolError('HSFETCH returned unexpected response code: %s' % response.code) @@ -2137,6 +2139,14 @@ class Controller(BaseController): entries = self.get_conf_map(param, default, multiple) return _case_insensitive_lookup(entries, param, default)
+ # TODO: temporary aliases until we have better type support in our API + + def _get_conf_single(self, param: str, default: Any = UNDEFINED) -> str: + return self.get_conf(param, default) # type: ignore + + def _get_conf_multiple(self, param: str, default: Any = UNDEFINED) -> List[str]: + return self.get_conf(param, default, multiple = True) # type: ignore + def get_conf_map(self, params: Union[str, Sequence[str]], default: Any = UNDEFINED, multiple: bool = True) -> Dict[str, Union[str, Sequence[str]]]: """ get_conf_map(params, default = UNDEFINED, multiple = True) @@ -2218,8 +2228,7 @@ class Controller(BaseController): return self._get_conf_dict_to_response(reply, default, multiple)
try: - response = self.msg('GETCONF %s' % ' '.join(lookup_params)) - stem.response.convert('GETCONF', response) + response = stem.response._convert_to_getconf(self.msg('GETCONF %s' % ' '.join(lookup_params))) reply.update(response.entries)
if self.is_caching_enabled(): @@ -2414,8 +2423,7 @@ class Controller(BaseController): raise ValueError('Cannot set %s to %s since the value was a %s but we only accept strings' % (param, value, type(value).__name__))
query = ' '.join(query_comp) - response = self.msg(query) - stem.response.convert('SINGLELINE', response) + response = stem.response._convert_to_single_line(self.msg(query))
if response.is_ok(): log.debug('%s (runtime: %0.4f)' % (query, time.time() - start_time)) @@ -2489,15 +2497,14 @@ class Controller(BaseController): start_time = time.time()
try: - response = self.msg('GETCONF HiddenServiceOptions') - stem.response.convert('GETCONF', response) + response = stem.response._convert_to_getconf(self.msg('GETCONF HiddenServiceOptions')) log.debug('GETCONF HiddenServiceOptions (runtime: %0.4f)' % (time.time() - start_time)) except stem.ControllerError as exc: log.debug('GETCONF HiddenServiceOptions (failed: %s)' % exc) raise
- service_dir_map = collections.OrderedDict() + service_dir_map = collections.OrderedDict() # type: collections.OrderedDict[str, Any] directory = None
for status_code, divider, content in response.content(): @@ -2603,7 +2610,7 @@ class Controller(BaseController):
self.set_options(hidden_service_options)
- def create_hidden_service(self, path: str, port: int, target_address: Optional[str] = None, target_port: Optional[int] = None, auth_type: Optional[str] = None, client_names: Optional[Sequence[str]] = None) -> 'stem.cotroller.CreateHiddenServiceOutput': + def create_hidden_service(self, path: str, port: int, target_address: Optional[str] = None, target_port: Optional[int] = None, auth_type: Optional[str] = None, client_names: Optional[Sequence[str]] = None) -> 'stem.control.CreateHiddenServiceOutput': """ Create a new hidden service. If the directory is already present, a new port is added. @@ -2629,7 +2636,7 @@ class Controller(BaseController): :param str auth_type: authentication type: basic, stealth or None to disable auth :param list client_names: client names (1-16 characters "A-Za-z0-9+-_")
- :returns: :class:`~stem.cotroller.CreateHiddenServiceOutput` if we create + :returns: :class:`~stem.control.CreateHiddenServiceOutput` if we create or update a hidden service, **None** otherwise
:raises: :class:`stem.ControllerError` if the call fails @@ -2905,7 +2912,8 @@ class Controller(BaseController): * :class:`stem.Timeout` if **timeout** was reached """
- hs_desc_queue, hs_desc_listener = queue.Queue(), None + hs_desc_queue = queue.Queue() # type: queue.Queue[stem.response.events.Event] + hs_desc_listener = None start_time = time.time()
if await_publication: @@ -2957,8 +2965,7 @@ class Controller(BaseController): else: request += ' ClientAuth=%s' % client_name
- response = self.msg(request) - stem.response.convert('ADD_ONION', response) + response = stem.response._convert_to_add_onion(self.msg(request))
if await_publication: # We should receive five UPLOAD events, followed by up to another five @@ -3002,8 +3009,7 @@ class Controller(BaseController): :raises: :class:`stem.ControllerError` if the call fails """
- response = self.msg('DEL_ONION %s' % service_id) - stem.response.convert('SINGLELINE', response) + response = stem.response._convert_to_single_line(self.msg('DEL_ONION %s' % service_id))
if response.is_ok(): return True @@ -3056,7 +3062,7 @@ class Controller(BaseController): event_type = stem.response.events.EVENT_TYPE_TO_CLASS.get(event_type)
if event_type and (self.get_version() < event_type._VERSION_ADDED): - raise stem.InvalidRequest(552, '%s event requires Tor version %s or later' % (event_type, event_type._VERSION_ADDED)) + raise stem.InvalidRequest('552', '%s event requires Tor version %s or later' % (event_type, event_type._VERSION_ADDED))
for event_type in events: self._event_listeners.setdefault(event_type, []).append(listener) @@ -3135,7 +3141,7 @@ class Controller(BaseController):
return cached_values
- def _set_cache(self, params: Mapping[str, Any], namespace: Optional[str] = None) -> None: + def _set_cache(self, params: Dict[str, Any], namespace: Optional[str] = None) -> None: """ Sets the given request cache entries. If the new cache value is **None** then it is removed from our cache. @@ -3241,8 +3247,7 @@ class Controller(BaseController): :raises: :class:`stem.ControllerError` if the call fails """
- response = self.msg('LOADCONF\n%s' % configtext) - stem.response.convert('SINGLELINE', response) + response = stem.response._convert_to_single_line(self.msg('LOADCONF\n%s' % configtext))
if response.code in ('552', '553'): if response.code == '552' and response.message.startswith('Invalid config file: Failed to parse/validate config: Unknown option'): @@ -3267,11 +3272,10 @@ class Controller(BaseController): the configuration file """
- response = self.msg('SAVECONF FORCE' if force else 'SAVECONF') - stem.response.convert('SINGLELINE', response) + response = stem.response._convert_to_single_line(self.msg('SAVECONF FORCE' if force else 'SAVECONF'))
if response.is_ok(): - return True + pass elif response.code == '551': raise stem.OperationFailed(response.code, response.message) else: @@ -3311,8 +3315,7 @@ class Controller(BaseController): if isinstance(features, (bytes, str)): features = [features]
- response = self.msg('USEFEATURE %s' % ' '.join(features)) - stem.response.convert('SINGLELINE', response) + response = stem.response._convert_to_single_line(self.msg('USEFEATURE %s' % ' '.join(features)))
if not response.is_ok(): if response.code == '552': @@ -3353,7 +3356,7 @@ class Controller(BaseController): raise ValueError("Tor currently does not have a circuit with the id of '%s'" % circuit_id)
@with_default() - def get_circuits(self, default: Any = UNDEFINED) -> Sequence[stem.response.events.CircuitEvent]: + def get_circuits(self, default: Any = UNDEFINED) -> List[stem.response.events.CircuitEvent]: """ get_circuits(default = UNDEFINED)
@@ -3366,13 +3369,12 @@ class Controller(BaseController): :raises: :class:`stem.ControllerError` if the call fails and no default was provided """
- circuits = [] + circuits = [] # type: List[stem.response.events.CircuitEvent] response = self.get_info('circuit-status')
for circ in response.splitlines(): - circ_message = stem.socket.recv_message(io.BytesIO(stem.util.str_tools._to_bytes('650 CIRC %s\r\n' % circ))) - stem.response.convert('EVENT', circ_message) - circuits.append(circ_message) + circ_message = stem.response._convert_to_event(stem.socket.recv_message(io.BytesIO(stem.util.str_tools._to_bytes('650 CIRC %s\r\n' % circ)))) + circuits.append(circ_message) # type: ignore
return circuits
@@ -3442,7 +3444,8 @@ class Controller(BaseController): # to build. This is icky, but we can't reliably do this via polling since # we then can't get the failure if it can't be created.
- circ_queue, circ_listener = queue.Queue(), None + circ_queue = queue.Queue() # type: queue.Queue[stem.response.events.Event] + circ_listener = None start_time = time.time()
if await_build: @@ -3463,8 +3466,7 @@ class Controller(BaseController): if purpose: args.append('purpose=%s' % purpose)
- response = self.msg('EXTENDCIRCUIT %s' % ' '.join(args)) - stem.response.convert('SINGLELINE', response) + response = stem.response._convert_to_single_line(self.msg('EXTENDCIRCUIT %s' % ' '.join(args)))
if response.code in ('512', '552'): raise stem.InvalidRequest(response.code, response.message) @@ -3505,8 +3507,7 @@ class Controller(BaseController): :raises: :class:`stem.InvalidArguments` if the circuit doesn't exist or if the purpose was invalid """
- response = self.msg('SETCIRCUITPURPOSE %s purpose=%s' % (circuit_id, purpose)) - stem.response.convert('SINGLELINE', response) + response = stem.response._convert_to_single_line(self.msg('SETCIRCUITPURPOSE %s purpose=%s' % (circuit_id, purpose)))
if not response.is_ok(): if response.code == '552': @@ -3527,8 +3528,7 @@ class Controller(BaseController): * :class:`stem.InvalidRequest` if not enough information is provided """
- response = self.msg('CLOSECIRCUIT %s %s' % (circuit_id, flag)) - stem.response.convert('SINGLELINE', response) + response = stem.response._convert_to_single_line(self.msg('CLOSECIRCUIT %s %s' % (circuit_id, flag)))
if not response.is_ok(): if response.code in ('512', '552'): @@ -3539,7 +3539,7 @@ class Controller(BaseController): raise stem.ProtocolError('CLOSECIRCUIT returned unexpected response code: %s' % response.code)
@with_default() - def get_streams(self, default: Any = UNDEFINED) -> Sequence[stem.response.events.StreamEvent]: + def get_streams(self, default: Any = UNDEFINED) -> List[stem.response.events.StreamEvent]: """ get_streams(default = UNDEFINED)
@@ -3553,13 +3553,12 @@ class Controller(BaseController): provided """
- streams = [] + streams = [] # type: List[stem.response.events.StreamEvent] response = self.get_info('stream-status')
for stream in response.splitlines(): - message = stem.socket.recv_message(io.BytesIO(stem.util.str_tools._to_bytes('650 STREAM %s\r\n' % stream))) - stem.response.convert('EVENT', message) - streams.append(message) + message = stem.response._convert_to_event(stem.socket.recv_message(io.BytesIO(stem.util.str_tools._to_bytes('650 STREAM %s\r\n' % stream)))) + streams.append(message) # type: ignore
return streams
@@ -3585,8 +3584,7 @@ class Controller(BaseController): if exiting_hop: query += ' HOP=%s' % exiting_hop
- response = self.msg(query) - stem.response.convert('SINGLELINE', response) + response = stem.response._convert_to_single_line(self.msg(query))
if not response.is_ok(): if response.code == '552': @@ -3614,8 +3612,7 @@ class Controller(BaseController): # there's a single value offset between RelayEndReason.index_of() and the # value that tor expects since tor's value starts with the index of one
- response = self.msg('CLOSESTREAM %s %s %s' % (stream_id, stem.RelayEndReason.index_of(reason) + 1, flag)) - stem.response.convert('SINGLELINE', response) + response = stem.response._convert_to_single_line(self.msg('CLOSESTREAM %s %s %s' % (stream_id, stem.RelayEndReason.index_of(reason) + 1, flag)))
if not response.is_ok(): if response.code in ('512', '552'): @@ -3638,8 +3635,7 @@ class Controller(BaseController): * :class:`stem.InvalidArguments` if signal provided wasn't recognized """
- response = self.msg('SIGNAL %s' % signal) - stem.response.convert('SINGLELINE', response) + response = stem.response._convert_to_single_line(self.msg('SIGNAL %s' % signal))
if response.is_ok(): if signal == stem.Signal.NEWNYM: @@ -3703,14 +3699,14 @@ class Controller(BaseController): """
if not burst: - attributes = ('BandwidthRate', 'RelayBandwidthRate', 'MaxAdvertisedBandwidth') + attributes = ['BandwidthRate', 'RelayBandwidthRate', 'MaxAdvertisedBandwidth'] else: - attributes = ('BandwidthBurst', 'RelayBandwidthBurst') + attributes = ['BandwidthBurst', 'RelayBandwidthBurst']
value = None
for attr in attributes: - attr_value = int(self.get_conf(attr)) + attr_value = int(self._get_conf_single(attr))
if attr_value == 0 and attr.startswith('Relay'): continue # RelayBandwidthRate and RelayBandwidthBurst default to zero @@ -3740,9 +3736,7 @@ class Controller(BaseController):
mapaddress_arg = ' '.join(['%s=%s' % (k, v) for (k, v) in list(mapping.items())]) response = self.msg('MAPADDRESS %s' % mapaddress_arg) - stem.response.convert('MAPADDRESS', response) - - return response.entries + return stem.response._convert_to_mapaddress(response).entries
def drop_guards(self) -> None: """ @@ -3779,8 +3773,7 @@ class Controller(BaseController): owning_pid = self.get_conf('__OwningControllerProcess', None)
if owning_pid == str(os.getpid()) and self.is_localhost(): - response = self.msg('TAKEOWNERSHIP') - stem.response.convert('SINGLELINE', response) + response = stem.response._convert_to_single_line(self.msg('TAKEOWNERSHIP'))
if response.is_ok(): # Now that tor is tracking our ownership of the process via the control @@ -3793,11 +3786,18 @@ class Controller(BaseController): else: log.warn('We were unable assert ownership of tor through TAKEOWNERSHIP, despite being configured to be the owning process through __OwningControllerProcess. (%s)' % response)
- def _handle_event(self, event_message: str) -> None: + def _handle_event(self, event_message: stem.response.ControlMessage) -> None: + event = None # type: Optional[stem.response.events.Event] + try: - stem.response.convert('EVENT', event_message) - event_type = event_message.type + event = stem.response._convert_to_event(event_message) + event_type = event.type except stem.ProtocolError as exc: + # TODO: We should change this so malformed events convert to the base + # Event class, so we don't provide raw ControlMessages to listeners. + + event = event_message # type: ignore + log.error('Tor sent a malformed event (%s): %s' % (exc, event_message)) event_type = MALFORMED_EVENTS
@@ -3806,9 +3806,9 @@ class Controller(BaseController): if listener_type == event_type: for listener in event_listeners: try: - listener(event_message) + listener(event) except Exception as exc: - log.warn('Event listener raised an uncaught exception (%s): %s' % (exc, event_message)) + log.warn('Event listener raised an uncaught exception (%s): %s' % (exc, event))
def _attach_listeners(self) -> Tuple[Sequence[str], Sequence[str]]: """ diff --git a/stem/descriptor/__init__.py b/stem/descriptor/__init__.py index 9c769749..477e15e9 100644 --- a/stem/descriptor/__init__.py +++ b/stem/descriptor/__init__.py @@ -108,6 +108,7 @@ import base64 import codecs import collections import copy +import hashlib import io import os import random @@ -120,7 +121,7 @@ import stem.util.enum import stem.util.str_tools import stem.util.system
-from typing import Any, BinaryIO, Callable, Dict, Iterator, Mapping, Optional, Sequence, Tuple, Type +from typing import Any, BinaryIO, Callable, Dict, IO, Iterator, List, Mapping, Optional, Sequence, Tuple, Type, Union
__all__ = [ 'bandwidth_file', @@ -152,7 +153,7 @@ KEYWORD_LINE = re.compile('^([%s]+)(?:[%s]+(.*))?$' % (KEYWORD_CHAR, WHITESPACE) SPECIFIC_KEYWORD_LINE = '^(%%s)(?:[%s]+(.*))?$' % WHITESPACE PGP_BLOCK_START = re.compile('^-----BEGIN ([%s%s]+)-----$' % (KEYWORD_CHAR, WHITESPACE)) PGP_BLOCK_END = '-----END %s-----' -EMPTY_COLLECTION = ([], {}, set()) +EMPTY_COLLECTION = ([], {}, set()) # type: ignore
DIGEST_TYPE_INFO = b'\x00\x01' DIGEST_PADDING = b'\xFF' @@ -164,6 +165,8 @@ skFtXhOHHqTRN4GPPrZsAIUOQGzQtGb66IQgT4tO/pj+P6QmSCCdTfhvGfgTCsC+ WPi4Fl2qryzTb3QO5r5x7T8OsG2IBUET1bLQzmtbC560SYR49IvVAgMBAAE= """
+ENTRY_TYPE = Dict[str, List[Tuple[str, str, str]]] + DigestHash = stem.util.enum.UppercaseEnum( 'SHA1', 'SHA256', @@ -194,7 +197,7 @@ class _Compression(object): .. versionadded:: 1.8.0 """
- def __init__(self, name: str, module: Optional[str], encoding: str, extension: str, decompression_func: Callable[[Any, str], bytes]) -> None: + def __init__(self, name: str, module: Optional[str], encoding: str, extension: str, decompression_func: Callable[[Any, bytes], bytes]) -> None: if module is None: self._module = None self.available = True @@ -256,7 +259,7 @@ class _Compression(object): return self._name
-def _zstd_decompress(module: Any, content: str) -> bytes: +def _zstd_decompress(module: Any, content: bytes) -> bytes: output_buffer = io.BytesIO()
with module.ZstdDecompressor().write_to(output_buffer) as decompressor: @@ -304,7 +307,7 @@ class SigningKey(collections.namedtuple('SigningKey', ['private', 'public', 'pub """
-def parse_file(descriptor_file: BinaryIO, descriptor_type: str = None, validate: bool = False, document_handler: 'stem.descriptor.DocumentHandler' = DocumentHandler.ENTRIES, normalize_newlines: Optional[bool] = None, **kwargs: Any) -> Iterator['stem.descriptor.Descriptor']: +def parse_file(descriptor_file: Union[str, BinaryIO, tarfile.TarFile, IO[bytes]], descriptor_type: str = None, validate: bool = False, document_handler: 'stem.descriptor.DocumentHandler' = DocumentHandler.ENTRIES, normalize_newlines: Optional[bool] = None, **kwargs: Any) -> Iterator['stem.descriptor.Descriptor']: """ Simple function to read the descriptor contents from a file, providing an iterator for its :class:`~stem.descriptor.__init__.Descriptor` contents. @@ -372,7 +375,7 @@ def parse_file(descriptor_file: BinaryIO, descriptor_type: str = None, validate:
# Delegate to a helper if this is a path or tarfile.
- handler = None + handler = None # type: Callable
if isinstance(descriptor_file, (bytes, str)): if stem.util.system.is_tarfile(descriptor_file): @@ -388,7 +391,7 @@ def parse_file(descriptor_file: BinaryIO, descriptor_type: str = None, validate:
return
- if not descriptor_file.seekable(): + if not descriptor_file.seekable(): # type: ignore raise IOError(UNSEEKABLE_MSG)
# The tor descriptor specifications do not provide a reliable method for @@ -397,19 +400,19 @@ def parse_file(descriptor_file: BinaryIO, descriptor_type: str = None, validate: # by an annotation on their first line... # https://trac.torproject.org/5651
- initial_position = descriptor_file.tell() - first_line = stem.util.str_tools._to_unicode(descriptor_file.readline().strip()) + initial_position = descriptor_file.tell() # type: ignore + first_line = stem.util.str_tools._to_unicode(descriptor_file.readline().strip()) # type: ignore metrics_header_match = re.match('^@type (\S+) (\d+).(\d+)$', first_line)
if not metrics_header_match: - descriptor_file.seek(initial_position) + descriptor_file.seek(initial_position) # type: ignore
descriptor_path = getattr(descriptor_file, 'name', None) - filename = '<undefined>' if descriptor_path is None else os.path.basename(descriptor_file.name) + filename = '<undefined>' if descriptor_path is None else os.path.basename(descriptor_file.name) # type: str # type: ignore
def parse(descriptor_file: BinaryIO) -> Iterator['stem.descriptor.Descriptor']: if normalize_newlines: - descriptor_file = NewlineNormalizer(descriptor_file) + descriptor_file = NewlineNormalizer(descriptor_file) # type: ignore
if descriptor_type is not None: descriptor_type_match = re.match('^(\S+) (\d+).(\d+)$', descriptor_type) @@ -428,7 +431,7 @@ def parse_file(descriptor_file: BinaryIO, descriptor_type: str = None, validate: # Cached descriptor handling. These contain multiple descriptors per file.
if normalize_newlines is None and stem.util.system.is_windows(): - descriptor_file = NewlineNormalizer(descriptor_file) + descriptor_file = NewlineNormalizer(descriptor_file) # type: ignore
if filename == 'cached-descriptors' or filename == 'cached-descriptors.new': return stem.descriptor.server_descriptor._parse_file(descriptor_file, validate = validate, **kwargs) @@ -441,29 +444,29 @@ def parse_file(descriptor_file: BinaryIO, descriptor_type: str = None, validate: elif filename == 'cached-microdesc-consensus': return stem.descriptor.networkstatus._parse_file(descriptor_file, is_microdescriptor = True, validate = validate, document_handler = document_handler, **kwargs) else: - raise TypeError("Unable to determine the descriptor's type. filename: '%s', first line: '%s'" % (filename, first_line)) + raise TypeError("Unable to determine the descriptor's type. filename: '%s', first line: '%s'" % (filename, stem.util.str_tools._to_unicode(first_line)))
- for desc in parse(descriptor_file): + for desc in parse(descriptor_file): # type: ignore if descriptor_path is not None: desc._set_path(os.path.abspath(descriptor_path))
yield desc
-def _parse_file_for_path(descriptor_file: BinaryIO, *args: Any, **kwargs: Any) -> Iterator['stem.descriptor.Descriptor']: +def _parse_file_for_path(descriptor_file: str, *args: Any, **kwargs: Any) -> Iterator['stem.descriptor.Descriptor']: with open(descriptor_file, 'rb') as desc_file: for desc in parse_file(desc_file, *args, **kwargs): yield desc
-def _parse_file_for_tar_path(descriptor_file: BinaryIO, *args: Any, **kwargs: Any) -> Iterator['stem.descriptor.Descriptor']: +def _parse_file_for_tar_path(descriptor_file: str, *args: Any, **kwargs: Any) -> Iterator['stem.descriptor.Descriptor']: with tarfile.open(descriptor_file) as tar_file: for desc in parse_file(tar_file, *args, **kwargs): desc._set_path(os.path.abspath(descriptor_file)) yield desc
-def _parse_file_for_tarfile(descriptor_file: BinaryIO, *args: Any, **kwargs: Any) -> Iterator['stem.descriptor.Descriptor']: +def _parse_file_for_tarfile(descriptor_file: tarfile.TarFile, *args: Any, **kwargs: Any) -> Iterator['stem.descriptor.Descriptor']: for tar_entry in descriptor_file: if tar_entry.isfile(): entry = descriptor_file.extractfile(tar_entry) @@ -479,10 +482,14 @@ def _parse_file_for_tarfile(descriptor_file: BinaryIO, *args: Any, **kwargs: Any entry.close()
-def _parse_metrics_file(descriptor_type: Type['stem.descriptor.Descriptor'], major_version: int, minor_version: int, descriptor_file: BinaryIO, validate: bool, document_handler: 'stem.descriptor.DocumentHandler', **kwargs: Any) -> Iterator['stem.descriptor.Descriptor']: +def _parse_metrics_file(descriptor_type: str, major_version: int, minor_version: int, descriptor_file: BinaryIO, validate: bool, document_handler: 'stem.descriptor.DocumentHandler', **kwargs: Any) -> Iterator['stem.descriptor.Descriptor']: # Parses descriptor files from metrics, yielding individual descriptors. This # throws a TypeError if the descriptor_type or version isn't recognized.
+ desc = None # type: Optional[Any] + desc_type = None # type: Optional[Type[stem.descriptor.Descriptor]] + document_type = None # type: Optional[Type] + if descriptor_type == stem.descriptor.server_descriptor.RelayDescriptor.TYPE_ANNOTATION_NAME and major_version == 1: for desc in stem.descriptor.server_descriptor._parse_file(descriptor_file, is_bridge = False, validate = validate, **kwargs): yield desc @@ -507,7 +514,7 @@ def _parse_metrics_file(descriptor_type: Type['stem.descriptor.Descriptor'], maj for desc in stem.descriptor.networkstatus._parse_file(descriptor_file, document_type, validate = validate, document_handler = document_handler, **kwargs): yield desc elif descriptor_type == stem.descriptor.networkstatus.KeyCertificate.TYPE_ANNOTATION_NAME and major_version == 1: - for desc in stem.descriptor.networkstatus._parse_file_key_certs(descriptor_file, validate = validate, **kwargs): + for desc in stem.descriptor.networkstatus._parse_file_key_certs(descriptor_file, validate = validate): yield desc elif descriptor_type in ('network-status-consensus-3', 'network-status-vote-3') and major_version == 1: document_type = stem.descriptor.networkstatus.NetworkStatusDocumentV3 @@ -549,7 +556,7 @@ def _parse_metrics_file(descriptor_type: Type['stem.descriptor.Descriptor'], maj raise TypeError("Unrecognized metrics descriptor format. type: '%s', version: '%i.%i'" % (descriptor_type, major_version, minor_version))
-def _descriptor_content(attr: Mapping[str, str] = None, exclude: Sequence[str] = (), header_template: Sequence[str] = (), footer_template: Sequence[str] = ()) -> bytes: +def _descriptor_content(attr: Mapping[str, str] = None, exclude: Sequence[str] = (), header_template: Sequence[Tuple[str, Optional[str]]] = (), footer_template: Sequence[Tuple[str, Optional[str]]] = ()) -> bytes: """ Constructs a minimal descriptor with the given attributes. The content we provide back is of the form... @@ -586,8 +593,9 @@ def _descriptor_content(attr: Mapping[str, str] = None, exclude: Sequence[str] = :returns: bytes with the requested descriptor content """
- header_content, footer_content = [], [] - attr = {} if attr is None else collections.OrderedDict(attr) # shallow copy since we're destructive + header_content = [] # type: List[str] + footer_content = [] # type: List[str] + attr = {} if attr is None else collections.OrderedDict(attr) # type: Dict[str, str] # shallow copy since we're destructive
for content, template in ((header_content, header_template), (footer_content, footer_template)): @@ -621,28 +629,28 @@ def _descriptor_content(attr: Mapping[str, str] = None, exclude: Sequence[str] = return stem.util.str_tools._to_bytes('\n'.join(header_content + remainder + footer_content))
-def _value(line: str, entries: Dict[str, Sequence[str]]) -> str: +def _value(line: str, entries: ENTRY_TYPE) -> str: return entries[line][0][0]
-def _values(line: str, entries: Dict[str, Sequence[str]]) -> Sequence[str]: +def _values(line: str, entries: ENTRY_TYPE) -> Sequence[str]: return [entry[0] for entry in entries[line]]
-def _parse_simple_line(keyword: str, attribute: str, func: Callable[[str], str] = None) -> Callable[['stem.descriptor.Descriptor', Dict[str, Sequence[str]]], None]: - def _parse(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_simple_line(keyword: str, attribute: str, func: Optional[Callable[[str], Any]] = None) -> Callable[['stem.descriptor.Descriptor', ENTRY_TYPE], None]: + def _parse(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: value = _value(keyword, entries) setattr(descriptor, attribute, func(value) if func else value)
return _parse
-def _parse_if_present(keyword: str, attribute: str) -> Callable[['stem.descriptor.Descriptor', Dict[str, Sequence[str]]], None]: +def _parse_if_present(keyword: str, attribute: str) -> Callable[['stem.descriptor.Descriptor', ENTRY_TYPE], None]: return lambda descriptor, entries: setattr(descriptor, attribute, keyword in entries)
-def _parse_bytes_line(keyword: str, attribute: str) -> Callable[['stem.descriptor.Descriptor', Dict[str, Sequence[str]]], None]: - def _parse(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_bytes_line(keyword: str, attribute: str) -> Callable[['stem.descriptor.Descriptor', ENTRY_TYPE], None]: + def _parse(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: line_match = re.search(stem.util.str_tools._to_bytes('^(opt )?%s(?:[%s]+(.*))?$' % (keyword, WHITESPACE)), descriptor.get_bytes(), re.MULTILINE) result = None
@@ -655,8 +663,8 @@ def _parse_bytes_line(keyword: str, attribute: str) -> Callable[['stem.descripto return _parse
-def _parse_int_line(keyword: str, attribute: str, allow_negative: bool = True) -> Callable[['stem.descriptor.Descriptor', Dict[str, Sequence[str]]], None]: - def _parse(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_int_line(keyword: str, attribute: str, allow_negative: bool = True) -> Callable[['stem.descriptor.Descriptor', ENTRY_TYPE], None]: + def _parse(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: value = _value(keyword, entries)
try: @@ -672,10 +680,10 @@ def _parse_int_line(keyword: str, attribute: str, allow_negative: bool = True) - return _parse
-def _parse_timestamp_line(keyword: str, attribute: str) -> Callable[['stem.descriptor.Descriptor', Dict[str, Sequence[str]]], None]: +def _parse_timestamp_line(keyword: str, attribute: str) -> Callable[['stem.descriptor.Descriptor', ENTRY_TYPE], None]: # "<keyword>" YYYY-MM-DD HH:MM:SS
- def _parse(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: + def _parse(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: value = _value(keyword, entries)
try: @@ -686,10 +694,10 @@ def _parse_timestamp_line(keyword: str, attribute: str) -> Callable[['stem.descr return _parse
-def _parse_forty_character_hex(keyword: str, attribute: str) -> Callable[['stem.descriptor.Descriptor', Dict[str, Sequence[str]]], None]: +def _parse_forty_character_hex(keyword: str, attribute: str) -> Callable[['stem.descriptor.Descriptor', ENTRY_TYPE], None]: # format of fingerprints, sha1 digests, etc
- def _parse(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: + def _parse(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: value = _value(keyword, entries)
if not stem.util.tor_tools.is_hex_digits(value, 40): @@ -700,15 +708,15 @@ def _parse_forty_character_hex(keyword: str, attribute: str) -> Callable[['stem. return _parse
-def _parse_protocol_line(keyword: str, attribute: str) -> Callable[['stem.descriptor.Descriptor', Dict[str, Sequence[str]]], None]: - def _parse(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_protocol_line(keyword: str, attribute: str) -> Callable[['stem.descriptor.Descriptor', ENTRY_TYPE], None]: + def _parse(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # parses 'protocol' entries like: Cons=1-2 Desc=1-2 DirCache=1 HSDir=1
value = _value(keyword, entries) protocols = collections.OrderedDict()
for k, v in _mappings_for(keyword, value): - versions = [] + versions = [] # type: List[int]
if not v: continue @@ -731,8 +739,8 @@ def _parse_protocol_line(keyword: str, attribute: str) -> Callable[['stem.descri return _parse
-def _parse_key_block(keyword: str, attribute: str, expected_block_type: str, value_attribute: Optional[str] = None) -> Callable[['stem.descriptor.Descriptor', Dict[str, Sequence[str]]], None]: - def _parse(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_key_block(keyword: str, attribute: str, expected_block_type: str, value_attribute: Optional[str] = None) -> Callable[['stem.descriptor.Descriptor', ENTRY_TYPE], None]: + def _parse(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: value, block_type, block_contents = entries[keyword][0]
if not block_contents or block_type != expected_block_type: @@ -788,7 +796,7 @@ def _copy(default: Any) -> Any: return copy.copy(default)
-def _encode_digest(hash_value: bytes, encoding: 'stem.descriptor.DigestEncoding') -> str: +def _encode_digest(hash_value: 'hashlib._HASH', encoding: 'stem.descriptor.DigestEncoding') -> Union[str, 'hashlib._HASH']: # type: ignore """ Encodes a hash value with the given HashEncoding. """ @@ -810,21 +818,21 @@ class Descriptor(object): Common parent for all types of descriptors. """
- ATTRIBUTES = {} # mapping of 'attribute' => (default_value, parsing_function) - PARSER_FOR_LINE = {} # line keyword to its associated parsing function - TYPE_ANNOTATION_NAME = None + ATTRIBUTES = {} # type: Dict[str, Tuple[Any, Callable[['stem.descriptor.Descriptor', ENTRY_TYPE], None]]] # mapping of 'attribute' => (default_value, parsing_function) + PARSER_FOR_LINE = {} # type: Dict[str, Callable[['stem.descriptor.Descriptor', ENTRY_TYPE], None]] # line keyword to its associated parsing function + TYPE_ANNOTATION_NAME = None # type: Optional[str]
- def __init__(self, contents, lazy_load = False): - self._path = None - self._archive_path = None + def __init__(self, contents: bytes, lazy_load: bool = False) -> None: + self._path = None # type: Optional[str] + self._archive_path = None # type: Optional[str] self._raw_contents = contents self._lazy_loading = lazy_load - self._entries = {} - self._hash = None - self._unrecognized_lines = [] + self._entries = {} # type: ENTRY_TYPE + self._hash = None # type: Optional[int] + self._unrecognized_lines = [] # type: List[str]
@classmethod - def from_str(cls, content, **kwargs): + def from_str(cls, content: str, **kwargs: Any) -> Union['stem.descriptor.Descriptor', List['stem.descriptor.Descriptor']]: """ Provides a :class:`~stem.descriptor.__init__.Descriptor` for the given content.
@@ -873,7 +881,7 @@ class Descriptor(object): raise ValueError("Descriptor.from_str() expected a single descriptor, but had %i instead. Please include 'multiple = True' if you want a list of results instead." % len(results))
@classmethod - def content(cls, attr = None, exclude = ()): + def content(cls, attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> bytes: """ Creates descriptor content with the given attributes. Mandatory fields are filled with dummy information unless data is supplied. This doesn't yet @@ -885,7 +893,7 @@ class Descriptor(object): :param list exclude: mandatory keywords to exclude from the descriptor, this results in an invalid descriptor
- :returns: **str** with the content of a descriptor + :returns: **bytes** with the content of a descriptor
:raises: * **ImportError** if cryptography is unavailable and sign is True @@ -895,7 +903,7 @@ class Descriptor(object): raise NotImplementedError("The create and content methods haven't been implemented for %s" % cls.__name__)
@classmethod - def create(cls, attr = None, exclude = (), validate = True): + def create(cls, attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), validate: bool = True) -> 'stem.descriptor.Descriptor': """ Creates a descriptor with the given attributes. Mandatory fields are filled with dummy information unless data is supplied. This doesn't yet create a @@ -917,9 +925,9 @@ class Descriptor(object): * **NotImplementedError** if not implemented for this descriptor type """
- return cls(cls.content(attr, exclude), validate = validate) + return cls(cls.content(attr, exclude), validate = validate) # type: ignore
- def type_annotation(self): + def type_annotation(self) -> 'stem.descriptor.TypeAnnotation': """ Provides the `Tor metrics annotation https://metrics.torproject.org/collector.html#relay-descriptors`_ of this @@ -941,7 +949,7 @@ class Descriptor(object): else: raise NotImplementedError('%s does not have a @type annotation' % type(self).__name__)
- def get_path(self): + def get_path(self) -> str: """ Provides the absolute path that we loaded this descriptor from.
@@ -950,7 +958,7 @@ class Descriptor(object):
return self._path
- def get_archive_path(self): + def get_archive_path(self) -> str: """ If this descriptor came from an archive then provides its path within the archive. This is only set if the descriptor was read by @@ -962,7 +970,7 @@ class Descriptor(object):
return self._archive_path
- def get_bytes(self): + def get_bytes(self) -> bytes: """ Provides the ASCII **bytes** of the descriptor. This only differs from **str()** if you're running python 3.x, in which case **str()** provides a @@ -973,7 +981,7 @@ class Descriptor(object):
return stem.util.str_tools._to_bytes(self._raw_contents)
- def get_unrecognized_lines(self): + def get_unrecognized_lines(self) -> List[str]: """ Provides a list of lines that were either ignored or had data that we did not know how to process. This is most common due to new descriptor fields @@ -989,7 +997,7 @@ class Descriptor(object):
return list(self._unrecognized_lines)
- def _parse(self, entries, validate, parser_for_line = None): + def _parse(self, entries: ENTRY_TYPE, validate: bool, parser_for_line: Optional[Dict[str, Callable]] = None) -> None: """ Parses a series of 'keyword => (value, pgp block)' mappings and applies them as attributes. @@ -1020,16 +1028,16 @@ class Descriptor(object): if validate: raise
- def _set_path(self, path): + def _set_path(self, path: str) -> None: self._path = path
- def _set_archive_path(self, path): + def _set_archive_path(self, path: str) -> None: self._archive_path = path
- def _name(self, is_plural = False): + def _name(self, is_plural: bool = False) -> str: return str(type(self))
- def _digest_for_signature(self, signing_key, signature): + def _digest_for_signature(self, signing_key: str, signature: str) -> str: """ Provides the signed digest we should have given this key and signature.
@@ -1091,13 +1099,15 @@ class Descriptor(object): digest_hex = codecs.encode(decrypted_bytes[seperator_index + 1:], 'hex_codec') return stem.util.str_tools._to_unicode(digest_hex.upper())
- def _content_range(self, start = None, end = None): + def _content_range(self, start: Optional[Union[str, bytes]] = None, end: Optional[Union[str, bytes]] = None) -> bytes: """ Provides the descriptor content inclusively between two substrings.
:param bytes start: start of the content range to get :param bytes end: end of the content range to get
+ :returns: **bytes** within the given range + :raises: ValueError if either the start or end substring are not within our content """
@@ -1108,24 +1118,24 @@ class Descriptor(object): start_index = content.find(stem.util.str_tools._to_bytes(start))
if start_index == -1: - raise ValueError("'%s' is not present within our descriptor content" % start) + raise ValueError("'%s' is not present within our descriptor content" % stem.util.str_tools._to_unicode(start))
if end is not None: end_index = content.find(stem.util.str_tools._to_bytes(end), start_index)
if end_index == -1: - raise ValueError("'%s' is not present within our descriptor content" % end) + raise ValueError("'%s' is not present within our descriptor content" % stem.util.str_tools._to_unicode(end))
end_index += len(end) # make the ending index inclusive
return content[start_index:end_index]
- def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: # We can't use standard hasattr() since it calls this function, recursing. # Doing so works since it stops recursing after several dozen iterations # (not sure why), but horrible in terms of performance.
- def has_attr(attr): + def has_attr(attr: str) -> bool: try: super(Descriptor, self).__getattribute__(attr) return True @@ -1156,31 +1166,31 @@ class Descriptor(object):
return super(Descriptor, self).__getattribute__(name)
- def __str__(self): + def __str__(self) -> str: return stem.util.str_tools._to_unicode(self._raw_contents)
- def _compare(self, other, method): + def _compare(self, other: Any, method: Callable[[Any, Any], bool]) -> bool: if type(self) != type(other): return False
return method(str(self).strip(), str(other).strip())
- def __hash__(self): + def __hash__(self) -> int: if self._hash is None: self._hash = hash(str(self).strip())
return self._hash
- def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return self._compare(other, lambda s, o: s == o)
- def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self == other
- def __lt__(self, other): + def __lt__(self, other: Any) -> bool: return self._compare(other, lambda s, o: s < o)
- def __le__(self, other): + def __le__(self, other: Any) -> bool: return self._compare(other, lambda s, o: s <= o)
@@ -1189,27 +1199,31 @@ class NewlineNormalizer(object): File wrapper that normalizes CRLF line endings. """
- def __init__(self, wrapped_file): + def __init__(self, wrapped_file: BinaryIO) -> None: self._wrapped_file = wrapped_file self.name = getattr(wrapped_file, 'name', None)
- def read(self, *args): + def read(self, *args: Any) -> bytes: return self._wrapped_file.read(*args).replace(b'\r\n', b'\n')
- def readline(self, *args): + def readline(self, *args: Any) -> bytes: return self._wrapped_file.readline(*args).replace(b'\r\n', b'\n')
- def readlines(self, *args): + def readlines(self, *args: Any) -> List[bytes]: return [line.rstrip(b'\r') for line in self._wrapped_file.readlines(*args)]
- def seek(self, *args): + def seek(self, *args: Any) -> int: return self._wrapped_file.seek(*args)
- def tell(self, *args): + def tell(self, *args: Any) -> int: return self._wrapped_file.tell(*args)
-def _read_until_keywords(keywords, descriptor_file, inclusive = False, ignore_first = False, skip = False, end_position = None, include_ending_keyword = False): +def _read_until_keywords(keywords: Union[str, Sequence[str]], descriptor_file: BinaryIO, inclusive: bool = False, ignore_first: bool = False, skip: bool = False, end_position: Optional[int] = None) -> List[bytes]: + return _read_until_keywords_with_ending_keyword(keywords, descriptor_file, inclusive, ignore_first, skip, end_position, include_ending_keyword = False) # type: ignore + + +def _read_until_keywords_with_ending_keyword(keywords: Union[str, Sequence[str]], descriptor_file: BinaryIO, inclusive: bool = False, ignore_first: bool = False, skip: bool = False, end_position: Optional[int] = None, include_ending_keyword: bool = False) -> Tuple[List[bytes], str]: """ Reads from the descriptor file until we get to one of the given keywords or reach the end of the file. @@ -1228,7 +1242,7 @@ def _read_until_keywords(keywords, descriptor_file, inclusive = False, ignore_fi **True** """
- content = None if skip else [] + content = None if skip else [] # type: Optional[List[bytes]] ending_keyword = None
if isinstance(keywords, (bytes, str)): @@ -1270,10 +1284,10 @@ def _read_until_keywords(keywords, descriptor_file, inclusive = False, ignore_fi if include_ending_keyword: return (content, ending_keyword) else: - return content + return content # type: ignore
-def _bytes_for_block(content): +def _bytes_for_block(content: str) -> bytes: """ Provides the base64 decoded content of a pgp-style block.
@@ -1291,7 +1305,7 @@ def _bytes_for_block(content): return base64.b64decode(stem.util.str_tools._to_bytes(content))
-def _get_pseudo_pgp_block(remaining_contents): +def _get_pseudo_pgp_block(remaining_contents: List[str]) -> Tuple[str, str]: """ Checks if given contents begins with a pseudo-Open-PGP-style block and, if so, pops it off and provides it back to the caller. @@ -1311,7 +1325,7 @@ def _get_pseudo_pgp_block(remaining_contents):
if block_match: block_type = block_match.groups()[0] - block_lines = [] + block_lines = [] # type: List[str] end_line = PGP_BLOCK_END % block_type
while True: @@ -1327,7 +1341,7 @@ def _get_pseudo_pgp_block(remaining_contents): return None
-def create_signing_key(private_key = None): +def create_signing_key(private_key: Optional['cryptography.hazmat.backends.openssl.rsa._RSAPrivateKey'] = None) -> 'stem.descriptor.SigningKey': # type: ignore """ Serializes a signing key if we have one. Otherwise this creates a new signing key we can use to create descriptors. @@ -1363,11 +1377,11 @@ def create_signing_key(private_key = None): # # https://github.com/pyca/cryptography/issues/3713
- def no_op(*args, **kwargs): + def no_op(*args: Any, **kwargs: Any) -> int: return 1
- private_key._backend._lib.EVP_PKEY_CTX_set_signature_md = no_op - private_key._backend.openssl_assert = no_op + private_key._backend._lib.EVP_PKEY_CTX_set_signature_md = no_op # type: ignore + private_key._backend.openssl_assert = no_op # type: ignore
public_key = private_key.public_key() public_digest = b'\n' + public_key.public_bytes( @@ -1378,7 +1392,7 @@ def create_signing_key(private_key = None): return SigningKey(private_key, public_key, public_digest)
-def _append_router_signature(content, private_key): +def _append_router_signature(content: bytes, private_key: 'cryptography.hazmat.backends.openssl.rsa._RSAPrivateKey') -> bytes: # type: ignore """ Appends a router signature to a server or extrainfo descriptor.
@@ -1399,23 +1413,23 @@ def _append_router_signature(content, private_key): return content + b'\n'.join([b'-----BEGIN SIGNATURE-----'] + stem.util.str_tools._split_by_length(signature, 64) + [b'-----END SIGNATURE-----\n'])
-def _random_nickname(): +def _random_nickname() -> str: return ('Unnamed%i' % random.randint(0, 100000000000000))[:19]
-def _random_fingerprint(): +def _random_fingerprint() -> str: return ('%040x' % random.randrange(16 ** 40)).upper()
-def _random_ipv4_address(): +def _random_ipv4_address() -> str: return '%i.%i.%i.%i' % (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
-def _random_date(): +def _random_date() -> str: return '%i-%02i-%02i %02i:%02i:%02i' % (random.randint(2000, 2015), random.randint(1, 12), random.randint(1, 20), random.randint(0, 23), random.randint(0, 59), random.randint(0, 59))
-def _random_crypto_blob(block_type = None): +def _random_crypto_blob(block_type: Optional[str] = None) -> str: """ Provides a random string that can be used for crypto blocks. """ @@ -1429,7 +1443,11 @@ def _random_crypto_blob(block_type = None): return crypto_blob
-def _descriptor_components(raw_contents, validate, extra_keywords = (), non_ascii_fields = ()): +def _descriptor_components(raw_contents: bytes, validate: bool, non_ascii_fields: Sequence[str] = ()) -> ENTRY_TYPE: + return _descriptor_components_with_extra(raw_contents, validate, (), non_ascii_fields) # type: ignore + + +def _descriptor_components_with_extra(raw_contents: bytes, validate: bool, extra_keywords: Sequence[str] = (), non_ascii_fields: Sequence[str] = ()) -> Tuple[ENTRY_TYPE, List[str]]: """ Initial breakup of the server descriptor contents to make parsing easier.
@@ -1443,7 +1461,7 @@ def _descriptor_components(raw_contents, validate, extra_keywords = (), non_asci entries because this influences the resulting exit policy, but for everything else in server descriptors the order does not matter.
- :param str raw_contents: descriptor content provided by the relay + :param bytes raw_contents: descriptor content provided by the relay :param bool validate: checks the validity of the descriptor's content if True, skips these checks otherwise :param list extra_keywords: entity keywords to put into a separate listing @@ -1456,12 +1474,9 @@ def _descriptor_components(raw_contents, validate, extra_keywords = (), non_asci value tuple, the second being a list of those entries. """
- if isinstance(raw_contents, bytes): - raw_contents = stem.util.str_tools._to_unicode(raw_contents) - - entries = collections.OrderedDict() + entries = collections.OrderedDict() # type: ENTRY_TYPE extra_entries = [] # entries with a keyword in extra_keywords - remaining_lines = raw_contents.split('\n') + remaining_lines = stem.util.str_tools._to_unicode(raw_contents).split('\n')
while remaining_lines: line = remaining_lines.pop(0) @@ -1525,7 +1540,7 @@ def _descriptor_components(raw_contents, validate, extra_keywords = (), non_asci if extra_keywords: return entries, extra_entries else: - return entries + return entries # type: ignore
# importing at the end to avoid circular dependencies on our Descriptor class diff --git a/stem/descriptor/bandwidth_file.py b/stem/descriptor/bandwidth_file.py index 49df3173..f1f0b1e2 100644 --- a/stem/descriptor/bandwidth_file.py +++ b/stem/descriptor/bandwidth_file.py @@ -21,9 +21,10 @@ import time
import stem.util.str_tools
-from typing import Any, BinaryIO, Dict, Iterator, Mapping, Optional, Sequence, Type +from typing import Any, BinaryIO, Callable, Dict, Iterator, List, Mapping, Optional, Sequence, Tuple, Type
from stem.descriptor import ( + ENTRY_TYPE, _mappings_for, Descriptor, ) @@ -168,11 +169,14 @@ def _parse_file(descriptor_file: BinaryIO, validate: bool = False, **kwargs: Any * **IOError** if the file can't be read """
- yield BandwidthFile(descriptor_file.read(), validate, **kwargs) + if kwargs: + raise ValueError('BUG: keyword arguments unused by bandwidth files')
+ yield BandwidthFile(descriptor_file.read(), validate)
-def _parse_header(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: - header = collections.OrderedDict() + +def _parse_header(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: + header = collections.OrderedDict() # type: collections.OrderedDict[str, str] content = io.BytesIO(descriptor.get_bytes())
content.readline() # skip the first line, which should be the timestamp @@ -197,7 +201,7 @@ def _parse_header(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, S if key == 'version': version_index = index else: - raise ValueError("Header expected to be key=value pairs, but had '%s'" % line) + raise ValueError("Header expected to be key=value pairs, but had '%s'" % stem.util.str_tools._to_unicode(line))
index += 1
@@ -216,16 +220,16 @@ def _parse_header(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, S raise ValueError("The 'version' header must be in the second position")
-def _parse_timestamp(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_timestamp(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: first_line = io.BytesIO(descriptor.get_bytes()).readline().strip()
if first_line.isdigit(): descriptor.timestamp = datetime.datetime.utcfromtimestamp(int(first_line)) else: - raise ValueError("First line should be a unix timestamp, but was '%s'" % first_line) + raise ValueError("First line should be a unix timestamp, but was '%s'" % stem.util.str_tools._to_unicode(first_line))
-def _parse_body(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_body(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # In version 1.0.0 the body is everything after the first line. Otherwise # it's everything after the header's divider.
@@ -239,13 +243,13 @@ def _parse_body(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Seq
measurements = {}
- for line in content.readlines(): - line = stem.util.str_tools._to_unicode(line.strip()) + for line_bytes in content.readlines(): + line = stem.util.str_tools._to_unicode(line_bytes.strip()) attr = dict(_mappings_for('measurement', line)) fingerprint = attr.get('node_id', '').lstrip('$') # bwauths prefix fingerprints with '$'
if not fingerprint: - raise ValueError("Every meaurement must include 'node_id': %s" % line) + raise ValueError("Every measurement must include 'node_id': %s" % stem.util.str_tools._to_unicode(line)) elif fingerprint in measurements: raise ValueError('Relay %s is listed multiple times. It should only be present once.' % fingerprint)
@@ -298,12 +302,12 @@ class BandwidthFile(Descriptor): 'timestamp': (None, _parse_timestamp), 'header': ({}, _parse_header), 'measurements': ({}, _parse_body), - } + } # type: Dict[str, Tuple[Any, Callable[['stem.descriptor.Descriptor', ENTRY_TYPE], None]]]
ATTRIBUTES.update(dict([(k, (None, _parse_header)) for k in HEADER_ATTR.keys()]))
@classmethod - def content(cls: Type['stem.descriptor.bandwidth_file.BandwidthFile'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> str: + def content(cls: Type['stem.descriptor.bandwidth_file.BandwidthFile'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> bytes: """ Creates descriptor content with the given attributes. This descriptor type differs somewhat from others and treats our attr/exclude attributes as @@ -328,7 +332,7 @@ class BandwidthFile(Descriptor):
header = collections.OrderedDict(attr) if attr is not None else collections.OrderedDict() timestamp = header.pop('timestamp', str(int(time.time()))) - content = header.pop('content', []) + content = header.pop('content', []) # type: List[str] # type: ignore version = header.get('version', HEADER_DEFAULT.get('version'))
lines = [] @@ -354,7 +358,7 @@ class BandwidthFile(Descriptor):
return b'\n'.join(lines)
- def __init__(self, raw_content: str, validate: bool = False) -> None: + def __init__(self, raw_content: bytes, validate: bool = False) -> None: super(BandwidthFile, self).__init__(raw_content, lazy_load = not validate)
if validate: diff --git a/stem/descriptor/certificate.py b/stem/descriptor/certificate.py index 6956a60f..bc09be2d 100644 --- a/stem/descriptor/certificate.py +++ b/stem/descriptor/certificate.py @@ -64,7 +64,8 @@ import stem.util.enum import stem.util.str_tools
from stem.client.datatype import CertType, Field, Size, split -from typing import Callable, Dict, Optional, Sequence, Tuple, Union +from stem.descriptor import ENTRY_TYPE +from typing import Callable, List, Optional, Sequence, Tuple, Union
ED25519_KEY_LENGTH = 32 ED25519_HEADER_LENGTH = 40 @@ -218,7 +219,7 @@ class Ed25519Certificate(object): return stem.util.str_tools._to_unicode(encoded)
@staticmethod - def _from_descriptor(keyword: str, attribute: str) -> Callable[['stem.descriptor.Descriptor', Dict[str, Sequence[str]]], None]: + def _from_descriptor(keyword: str, attribute: str) -> Callable[['stem.descriptor.Descriptor', ENTRY_TYPE], None]: def _parse(descriptor, entries): value, block_type, block_contents = entries[keyword][0]
@@ -253,7 +254,7 @@ class Ed25519CertificateV1(Ed25519Certificate): is unavailable """
- def __init__(self, cert_type: Optional['stem.client.datatype.CertType'] = None, expiration: Optional[datetime.datetime] = None, key_type: Optional[int] = None, key: Optional[bytes] = None, extensions: Optional[Sequence['stem.descriptor.certificate.Ed25519Extension']] = None, signature: Optional[bytes] = None, signing_key: Optional['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey'] = None) -> None: + def __init__(self, cert_type: Optional['stem.client.datatype.CertType'] = None, expiration: Optional[datetime.datetime] = None, key_type: Optional[int] = None, key: Optional[bytes] = None, extensions: Optional[Sequence['stem.descriptor.certificate.Ed25519Extension']] = None, signature: Optional[bytes] = None, signing_key: Optional['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey'] = None) -> None: # type: ignore super(Ed25519CertificateV1, self).__init__(1)
if cert_type is None: @@ -261,12 +262,15 @@ class Ed25519CertificateV1(Ed25519Certificate): elif key is None: raise ValueError('Certificate key is required')
+ self.type = None # type: Optional[stem.client.datatype.CertType] + self.type_int = None # type: Optional[int] + self.type, self.type_int = CertType.get(cert_type) - self.expiration = expiration if expiration else datetime.datetime.utcnow() + datetime.timedelta(hours = DEFAULT_EXPIRATION_HOURS) - self.key_type = key_type if key_type else 1 - self.key = stem.util._pubkey_bytes(key) - self.extensions = extensions if extensions else [] - self.signature = signature + self.expiration = expiration if expiration else datetime.datetime.utcnow() + datetime.timedelta(hours = DEFAULT_EXPIRATION_HOURS) # type: datetime.datetime + self.key_type = key_type if key_type else 1 # type: int + self.key = stem.util._pubkey_bytes(key) # type: bytes + self.extensions = list(extensions) if extensions else [] # type: List[stem.descriptor.certificate.Ed25519Extension] + self.signature = signature # type: Optional[bytes]
if signing_key: calculated_sig = signing_key.sign(self.pack()) diff --git a/stem/descriptor/collector.py b/stem/descriptor/collector.py index 1f1b1e95..9749dadb 100644 --- a/stem/descriptor/collector.py +++ b/stem/descriptor/collector.py @@ -63,7 +63,7 @@ import stem.util.connection import stem.util.str_tools
from stem.descriptor import Compression, DocumentHandler -from typing import Any, Dict, Iterator, Optional, Sequence, Tuple, Union +from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
COLLECTOR_URL = 'https://collector.torproject.org/' REFRESH_INDEX_RATE = 3600 # get new index if cached copy is an hour old @@ -93,7 +93,7 @@ def get_instance() -> 'stem.descriptor.collector.CollecTor': return SINGLETON_COLLECTOR
-def get_server_descriptors(start: datetime.datetime = None, end: datetime.datetime = None, cache_to: Optional[str] = None, bridge: bool = False, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator['stem.descriptor.server_descriptor.RelayDescriptor']: +def get_server_descriptors(start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, cache_to: Optional[str] = None, bridge: bool = False, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator[stem.descriptor.server_descriptor.RelayDescriptor]: """ Shorthand for :func:`~stem.descriptor.collector.CollecTor.get_server_descriptors` @@ -104,7 +104,7 @@ def get_server_descriptors(start: datetime.datetime = None, end: datetime.dateti yield desc
-def get_extrainfo_descriptors(start: datetime.datetime = None, end: datetime.datetime = None, cache_to: Optional[str] = None, bridge: bool = False, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator['stem.descriptor.extrainfo_descriptor.RelayExtraInfoDescriptor']: +def get_extrainfo_descriptors(start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, cache_to: Optional[str] = None, bridge: bool = False, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator[stem.descriptor.extrainfo_descriptor.RelayExtraInfoDescriptor]: """ Shorthand for :func:`~stem.descriptor.collector.CollecTor.get_extrainfo_descriptors` @@ -115,7 +115,7 @@ def get_extrainfo_descriptors(start: datetime.datetime = None, end: datetime.dat yield desc
-def get_microdescriptors(start: datetime.datetime = None, end: datetime.datetime = None, cache_to: Optional[str] = None, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator['stem.descriptor.microdescriptor.Microdescriptor']: +def get_microdescriptors(start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, cache_to: Optional[str] = None, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator[stem.descriptor.microdescriptor.Microdescriptor]: """ Shorthand for :func:`~stem.descriptor.collector.CollecTor.get_microdescriptors` @@ -126,7 +126,7 @@ def get_microdescriptors(start: datetime.datetime = None, end: datetime.datetime yield desc
-def get_consensus(start: datetime.datetime = None, end: datetime.datetime = None, cache_to: Optional[str] = None, document_handler: 'stem.descriptor.DocumentHandler' = DocumentHandler.ENTRIES, version: int = 3, microdescriptor: bool = False, bridge: bool = False, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator['stem.descriptor.router_status_entry.RouterStatusEntry']: +def get_consensus(start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, cache_to: Optional[str] = None, document_handler: stem.descriptor.DocumentHandler = DocumentHandler.ENTRIES, version: int = 3, microdescriptor: bool = False, bridge: bool = False, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator[stem.descriptor.router_status_entry.RouterStatusEntry]: """ Shorthand for :func:`~stem.descriptor.collector.CollecTor.get_consensus` @@ -137,7 +137,7 @@ def get_consensus(start: datetime.datetime = None, end: datetime.datetime = None yield desc
-def get_key_certificates(start: datetime.datetime = None, end: datetime.datetime = None, cache_to: Optional[str] = None, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator['stem.descriptor.networkstatus.KeyCertificate']: +def get_key_certificates(start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, cache_to: Optional[str] = None, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator[stem.descriptor.networkstatus.KeyCertificate]: """ Shorthand for :func:`~stem.descriptor.collector.CollecTor.get_key_certificates` @@ -148,7 +148,7 @@ def get_key_certificates(start: datetime.datetime = None, end: datetime.datetime yield desc
-def get_bandwidth_files(start: datetime.datetime = None, end: datetime.datetime = None, cache_to: Optional[str] = None, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator['stem.descriptor.bandwidth_file.BandwidthFile']: +def get_bandwidth_files(start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, cache_to: Optional[str] = None, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator[stem.descriptor.bandwidth_file.BandwidthFile]: """ Shorthand for :func:`~stem.descriptor.collector.CollecTor.get_bandwidth_files` @@ -159,7 +159,7 @@ def get_bandwidth_files(start: datetime.datetime = None, end: datetime.datetime yield desc
-def get_exit_lists(start: datetime.datetime = None, end: datetime.datetime = None, cache_to: Optional[str] = None, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator['stem.descriptor.tordnsel.TorDNSEL']: +def get_exit_lists(start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, cache_to: Optional[str] = None, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator[stem.descriptor.tordnsel.TorDNSEL]: """ Shorthand for :func:`~stem.descriptor.collector.CollecTor.get_exit_lists` @@ -188,14 +188,14 @@ class File(object): :var datetime last_modified: when the file was last modified """
- def __init__(self, path: str, types: Tuple[str], size: int, sha256: str, first_published: datetime.datetime, last_published: datetime.datetime, last_modified: datetime.datetime) -> None: + def __init__(self, path: str, types: Tuple[str, ...], size: int, sha256: str, first_published: str, last_published: str, last_modified: str) -> None: self.path = path self.types = tuple(types) if types else () self.compression = File._guess_compression(path) self.size = size self.sha256 = sha256 self.last_modified = datetime.datetime.strptime(last_modified, '%Y-%m-%d %H:%M') - self._downloaded_to = None # location we last downloaded to + self._downloaded_to = None # type: Optional[str] # location we last downloaded to
# Most descriptor types have publication time fields, but microdescriptors # don't because these files lack timestamps to parse. @@ -206,7 +206,7 @@ class File(object): else: self.start, self.end = File._guess_time_range(path)
- def read(self, directory: Optional[str] = None, descriptor_type: Optional[str] = None, start: datetime.datetime = None, end: datetime.datetime = None, document_handler: 'stem.descriptor.DocumentHandler' = DocumentHandler.ENTRIES, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator['stem.descriptor.Descriptor']: + def read(self, directory: Optional[str] = None, descriptor_type: Optional[str] = None, start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, document_handler: stem.descriptor.DocumentHandler = DocumentHandler.ENTRIES, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator[stem.descriptor.Descriptor]: """ Provides descriptors from this archive. Descriptors are downloaded or read from disk as follows... @@ -325,8 +325,8 @@ class File(object): # check if this file already exists with the correct checksum
if os.path.exists(path): - with open(path) as prior_file: - expected_hash = binascii.hexlify(base64.b64decode(self.sha256)) + with open(path, 'rb') as prior_file: + expected_hash = binascii.hexlify(base64.b64decode(self.sha256)).decode('utf-8') actual_hash = hashlib.sha256(prior_file.read()).hexdigest()
if expected_hash == actual_hash: @@ -346,7 +346,7 @@ class File(object): return path
@staticmethod - def _guess_compression(path) -> 'stem.descriptor.Compression': + def _guess_compression(path: str) -> stem.descriptor._Compression: """ Determine file comprssion from CollecTor's filename. """ @@ -358,7 +358,7 @@ class File(object): return Compression.PLAINTEXT
@staticmethod - def _guess_time_range(path) -> Tuple[datetime.datetime, datetime.datetime]: + def _guess_time_range(path: str) -> Tuple[datetime.datetime, datetime.datetime]: """ Attemt to determine the (start, end) time range from CollecTor's filename. This provides (None, None) if this cannot be determined. @@ -404,10 +404,10 @@ class CollecTor(object): self.timeout = timeout
self._cached_index = None - self._cached_files = None - self._cached_index_at = 0 + self._cached_files = None # type: Optional[List[File]] + self._cached_index_at = 0.0
- def get_server_descriptors(self, start: datetime.datetime = None, end: datetime.datetime = None, cache_to: Optional[str] = None, bridge: bool = False, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator['stem.descriptor.server_descriptor.RelayDescriptor']: + def get_server_descriptors(self, start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, cache_to: Optional[str] = None, bridge: bool = False, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator[stem.descriptor.server_descriptor.RelayDescriptor]: """ Provides server descriptors published during the given time range, sorted oldest to newest. @@ -432,9 +432,9 @@ class CollecTor(object):
for f in self.files(desc_type, start, end): for desc in f.read(cache_to, desc_type, start, end, timeout = timeout, retries = retries): - yield desc + yield desc # type: ignore
- def get_extrainfo_descriptors(self, start: datetime.datetime = None, end: datetime.datetime = None, cache_to: Optional[str] = None, bridge: bool = False, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator['stem.descriptor.extrainfo_descriptor.RelayExtraInfoDescriptor']: + def get_extrainfo_descriptors(self, start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, cache_to: Optional[str] = None, bridge: bool = False, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator[stem.descriptor.extrainfo_descriptor.RelayExtraInfoDescriptor]: """ Provides extrainfo descriptors published during the given time range, sorted oldest to newest. @@ -459,9 +459,9 @@ class CollecTor(object):
for f in self.files(desc_type, start, end): for desc in f.read(cache_to, desc_type, start, end, timeout = timeout, retries = retries): - yield desc + yield desc # type: ignore
- def get_microdescriptors(self, start: datetime.datetime = None, end: datetime.datetime = None, cache_to: Optional[str] = None, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator['stem.descriptor.microdescriptor.Microdescriptor']: + def get_microdescriptors(self, start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, cache_to: Optional[str] = None, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator[stem.descriptor.microdescriptor.Microdescriptor]: """ Provides microdescriptors estimated to be published during the given time range, sorted oldest to newest. Unlike server/extrainfo descriptors, @@ -493,9 +493,9 @@ class CollecTor(object):
for f in self.files('microdescriptor', start, end): for desc in f.read(cache_to, 'microdescriptor', start, end, timeout = timeout, retries = retries): - yield desc + yield desc # type: ignore
- def get_consensus(self, start: datetime.datetime = None, end: datetime.datetime = None, cache_to: Optional[str] = None, document_handler: 'stem.descriptor.DocumentHandler' = DocumentHandler.ENTRIES, version: int = 3, microdescriptor: bool = False, bridge: bool = False, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator['stem.descriptor.router_status_entry.RouterStatusEntry']: + def get_consensus(self, start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, cache_to: Optional[str] = None, document_handler: stem.descriptor.DocumentHandler = DocumentHandler.ENTRIES, version: int = 3, microdescriptor: bool = False, bridge: bool = False, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator[stem.descriptor.router_status_entry.RouterStatusEntry]: """ Provides consensus router status entries published during the given time range, sorted oldest to newest. @@ -537,9 +537,9 @@ class CollecTor(object):
for f in self.files(desc_type, start, end): for desc in f.read(cache_to, desc_type, start, end, document_handler, timeout = timeout, retries = retries): - yield desc + yield desc # type: ignore
- def get_key_certificates(self, start: datetime.datetime = None, end: datetime.datetime = None, cache_to: Optional[str] = None, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator['stem.descriptor.networkstatus.KeyCertificate']: + def get_key_certificates(self, start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, cache_to: Optional[str] = None, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator[stem.descriptor.networkstatus.KeyCertificate]: """ Directory authority key certificates for the given time range, sorted oldest to newest. @@ -561,9 +561,9 @@ class CollecTor(object):
for f in self.files('dir-key-certificate-3', start, end): for desc in f.read(cache_to, 'dir-key-certificate-3', start, end, timeout = timeout, retries = retries): - yield desc + yield desc # type: ignore
- def get_bandwidth_files(self, start: datetime.datetime = None, end: datetime.datetime = None, cache_to: Optional[str] = None, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator['stem.descriptor.bandwidth_file.BandwidthFile']: + def get_bandwidth_files(self, start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, cache_to: Optional[str] = None, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator[stem.descriptor.bandwidth_file.BandwidthFile]: """ Bandwidth authority heuristics for the given time range, sorted oldest to newest. @@ -585,9 +585,9 @@ class CollecTor(object):
for f in self.files('bandwidth-file', start, end): for desc in f.read(cache_to, 'bandwidth-file', start, end, timeout = timeout, retries = retries): - yield desc + yield desc # type: ignore
- def get_exit_lists(self, start: datetime.datetime = None, end: datetime.datetime = None, cache_to: Optional[str] = None, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator['stem.descriptor.tordnsel.TorDNSEL']: + def get_exit_lists(self, start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, cache_to: Optional[str] = None, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator[stem.descriptor.tordnsel.TorDNSEL]: """ `TorDNSEL exit lists https://www.torproject.org/projects/tordnsel.html.en`_ for the given time range, sorted oldest to newest. @@ -609,9 +609,9 @@ class CollecTor(object):
for f in self.files('tordnsel', start, end): for desc in f.read(cache_to, 'tordnsel', start, end, timeout = timeout, retries = retries): - yield desc + yield desc # type: ignore
- def index(self, compression: Union[str, 'descriptor.Compression'] = 'best') -> Dict[str, Any]: + def index(self, compression: Union[str, stem.descriptor._Compression] = 'best') -> Dict[str, Any]: """ Provides the archives available in CollecTor.
@@ -632,21 +632,25 @@ class CollecTor(object): if compression == 'best': for option in (Compression.LZMA, Compression.BZ2, Compression.GZIP, Compression.PLAINTEXT): if option.available: - compression = option + compression_enum = option break elif compression is None: - compression = Compression.PLAINTEXT + compression_enum = Compression.PLAINTEXT + elif isinstance(compression, stem.descriptor._Compression): + compression_enum = compression + else: + raise ValueError('compression must be a descriptor.Compression, was %s (%s)' % (compression, type(compression).__name__))
- extension = compression.extension if compression != Compression.PLAINTEXT else '' + extension = compression_enum.extension if compression_enum != Compression.PLAINTEXT else '' url = COLLECTOR_URL + 'index/index.json' + extension - response = compression.decompress(stem.util.connection.download(url, self.timeout, self.retries)) + response = compression_enum.decompress(stem.util.connection.download(url, self.timeout, self.retries))
self._cached_index = json.loads(stem.util.str_tools._to_unicode(response)) self._cached_index_at = time.time()
return self._cached_index
- def files(self, descriptor_type: str = None, start: datetime.datetime = None, end: datetime.datetime = None) -> Sequence['stem.descriptor.collector.File']: + def files(self, descriptor_type: Optional[str] = None, start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None) -> List['stem.descriptor.collector.File']: """ Provides files CollecTor presently has, sorted oldest to newest.
@@ -681,7 +685,7 @@ class CollecTor(object): return matches
@staticmethod - def _files(val: str, path: Sequence[str]) -> Sequence['stem.descriptor.collector.File']: + def _files(val: Dict[str, Any], path: List[str]) -> List['stem.descriptor.collector.File']: """ Recursively provies files within the index.
@@ -698,7 +702,7 @@ class CollecTor(object):
for k, v in val.items(): if k == 'files': - for attr in v: + for attr in v: # Dict[str, str] file_path = '/'.join(path + [attr.get('path')]) files.append(File(file_path, attr.get('types'), attr.get('size'), attr.get('sha256'), attr.get('first_published'), attr.get('last_published'), attr.get('last_modified'))) elif k == 'directories': diff --git a/stem/descriptor/extrainfo_descriptor.py b/stem/descriptor/extrainfo_descriptor.py index 6aca3c29..cd9467d1 100644 --- a/stem/descriptor/extrainfo_descriptor.py +++ b/stem/descriptor/extrainfo_descriptor.py @@ -76,9 +76,10 @@ import stem.util.connection import stem.util.enum import stem.util.str_tools
-from typing import Any, BinaryIO, Dict, Iterator, Mapping, Optional, Sequence, Tuple, Type, Union +from typing import Any, BinaryIO, Callable, Dict, Iterator, Mapping, Optional, Sequence, Tuple, Type, Union
from stem.descriptor import ( + ENTRY_TYPE, PGP_BLOCK_END, Descriptor, DigestHash, @@ -184,6 +185,9 @@ def _parse_file(descriptor_file: BinaryIO, is_bridge = False, validate = False, * **IOError** if the file can't be read """
+ if kwargs: + raise ValueError('BUG: keyword arguments unused by extrainfo descriptors') + while True: if not is_bridge: extrainfo_content = _read_until_keywords('router-signature', descriptor_file) @@ -200,9 +204,9 @@ def _parse_file(descriptor_file: BinaryIO, is_bridge = False, validate = False, extrainfo_content = extrainfo_content[1:]
if is_bridge: - yield BridgeExtraInfoDescriptor(bytes.join(b'', extrainfo_content), validate, **kwargs) + yield BridgeExtraInfoDescriptor(bytes.join(b'', extrainfo_content), validate) else: - yield RelayExtraInfoDescriptor(bytes.join(b'', extrainfo_content), validate, **kwargs) + yield RelayExtraInfoDescriptor(bytes.join(b'', extrainfo_content), validate) else: break # done parsing file
@@ -241,7 +245,7 @@ def _parse_timestamp_and_interval(keyword: str, content: str) -> Tuple[datetime. raise ValueError("%s line's timestamp wasn't parsable: %s" % (keyword, line))
-def _parse_extra_info_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_extra_info_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "extra-info" Nickname Fingerprint
value = _value('extra-info', entries) @@ -258,7 +262,7 @@ def _parse_extra_info_line(descriptor: 'stem.descriptor.Descriptor', entries: Di descriptor.fingerprint = extra_info_comp[1]
-def _parse_transport_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_transport_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "transport" transportname address:port [arglist] # Everything after the transportname is scrubbed in published bridge # descriptors, so we'll never see it in practice. @@ -304,7 +308,7 @@ def _parse_transport_line(descriptor: 'stem.descriptor.Descriptor', entries: Dic descriptor.transport = transports
-def _parse_padding_counts_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_padding_counts_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "padding-counts" YYYY-MM-DD HH:MM:SS (NSEC s) key=val key=val...
value = _value('padding-counts', entries) @@ -319,7 +323,7 @@ def _parse_padding_counts_line(descriptor: 'stem.descriptor.Descriptor', entries setattr(descriptor, 'padding_counts', counts)
-def _parse_dirreq_line(keyword: str, recognized_counts_attr: str, unrecognized_counts_attr: str, descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_dirreq_line(keyword: str, recognized_counts_attr: str, unrecognized_counts_attr: str, descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: value = _value(keyword, entries)
recognized_counts = {} @@ -343,7 +347,7 @@ def _parse_dirreq_line(keyword: str, recognized_counts_attr: str, unrecognized_c setattr(descriptor, unrecognized_counts_attr, unrecognized_counts)
-def _parse_dirreq_share_line(keyword: str, attribute: str, descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_dirreq_share_line(keyword: str, attribute: str, descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: value = _value(keyword, entries)
if not value.endswith('%'): @@ -356,7 +360,7 @@ def _parse_dirreq_share_line(keyword: str, attribute: str, descriptor: 'stem.des setattr(descriptor, attribute, float(value[:-1]) / 100)
-def _parse_cell_line(keyword: str, attribute: str, descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_cell_line(keyword: str, attribute: str, descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "<keyword>" num,...,num
value = _value(keyword, entries) @@ -378,7 +382,7 @@ def _parse_cell_line(keyword: str, attribute: str, descriptor: 'stem.descriptor. raise exc
-def _parse_timestamp_and_interval_line(keyword: str, end_attribute: str, interval_attribute: str, descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_timestamp_and_interval_line(keyword: str, end_attribute: str, interval_attribute: str, descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "<keyword>" YYYY-MM-DD HH:MM:SS (NSEC s)
timestamp, interval, _ = _parse_timestamp_and_interval(keyword, _value(keyword, entries)) @@ -386,7 +390,7 @@ def _parse_timestamp_and_interval_line(keyword: str, end_attribute: str, interva setattr(descriptor, interval_attribute, interval)
-def _parse_conn_bi_direct_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_conn_bi_direct_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "conn-bi-direct" YYYY-MM-DD HH:MM:SS (NSEC s) BELOW,READ,WRITE,BOTH
value = _value('conn-bi-direct', entries) @@ -404,7 +408,7 @@ def _parse_conn_bi_direct_line(descriptor: 'stem.descriptor.Descriptor', entries descriptor.conn_bi_direct_both = int(stats[3])
-def _parse_history_line(keyword: str, end_attribute: str, interval_attribute: str, values_attribute: str, descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_history_line(keyword: str, end_attribute: str, interval_attribute: str, values_attribute: str, descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "<keyword>" YYYY-MM-DD HH:MM:SS (NSEC s) NUM,NUM,NUM,NUM,NUM...
value = _value(keyword, entries) @@ -422,7 +426,7 @@ def _parse_history_line(keyword: str, end_attribute: str, interval_attribute: st setattr(descriptor, values_attribute, history_values)
-def _parse_port_count_line(keyword: str, attribute: str, descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_port_count_line(keyword: str, attribute: str, descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "<keyword>" port=N,port=N,...
value, port_mappings = _value(keyword, entries), {} @@ -431,13 +435,13 @@ def _parse_port_count_line(keyword: str, attribute: str, descriptor: 'stem.descr if (port != 'other' and not stem.util.connection.is_valid_port(port)) or not stat.isdigit(): raise ValueError('Entries in %s line should only be PORT=N entries: %s %s' % (keyword, keyword, value))
- port = int(port) if port.isdigit() else port + port = int(port) if port.isdigit() else port # type: ignore # this can be an int or 'other' port_mappings[port] = int(stat)
setattr(descriptor, attribute, port_mappings)
-def _parse_geoip_to_count_line(keyword: str, attribute: str, descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_geoip_to_count_line(keyword: str, attribute: str, descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "<keyword>" CC=N,CC=N,... # # The maxmind geoip (https://www.maxmind.com/app/iso3166) has numeric @@ -457,7 +461,7 @@ def _parse_geoip_to_count_line(keyword: str, attribute: str, descriptor: 'stem.d setattr(descriptor, attribute, locale_usage)
-def _parse_bridge_ip_versions_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_bridge_ip_versions_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: value, ip_versions = _value('bridge-ip-versions', entries), {}
for protocol, count in _mappings_for('bridge-ip-versions', value, divider = ','): @@ -469,7 +473,7 @@ def _parse_bridge_ip_versions_line(descriptor: 'stem.descriptor.Descriptor', ent descriptor.ip_versions = ip_versions
-def _parse_bridge_ip_transports_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_bridge_ip_transports_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: value, ip_transports = _value('bridge-ip-transports', entries), {}
for protocol, count in _mappings_for('bridge-ip-transports', value, divider = ','): @@ -481,7 +485,7 @@ def _parse_bridge_ip_transports_line(descriptor: 'stem.descriptor.Descriptor', e descriptor.ip_transports = ip_transports
-def _parse_hs_stats(keyword: str, stat_attribute: str, extra_attribute: str, descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_hs_stats(keyword: str, stat_attribute: str, extra_attribute: str, descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "<keyword>" num key=val key=val...
value, stat, extra = _value(keyword, entries), None, {} @@ -768,7 +772,7 @@ class ExtraInfoDescriptor(Descriptor):
'ip_versions': (None, _parse_bridge_ip_versions_line), 'ip_transports': (None, _parse_bridge_ip_transports_line), - } + } # type: Dict[str, Tuple[Any, Callable[['stem.descriptor.Descriptor', ENTRY_TYPE], None]]]
PARSER_FOR_LINE = { 'extra-info': _parse_extra_info_line, @@ -817,7 +821,7 @@ class ExtraInfoDescriptor(Descriptor): 'bridge-ip-transports': _parse_bridge_ip_transports_line, }
- def __init__(self, raw_contents: str, validate: bool = False) -> None: + def __init__(self, raw_contents: bytes, validate: bool = False) -> None: """ Extra-info descriptor constructor. By default this validates the descriptor's content as it's parsed. This validation can be disabled to @@ -854,7 +858,7 @@ class ExtraInfoDescriptor(Descriptor): else: self._entries = entries
- def digest(self, hash_type: 'stem.descriptor.DigestHash' = DigestHash.SHA1, encoding: 'stem.descriptor.DigestEncoding' = DigestEncoding.HEX) -> Union[str, 'hashlib.HASH']: + def digest(self, hash_type: 'stem.descriptor.DigestHash' = DigestHash.SHA1, encoding: 'stem.descriptor.DigestEncoding' = DigestEncoding.HEX) -> Union[str, 'hashlib._HASH']: # type: ignore """ Digest of this descriptor's content. These are referenced by...
@@ -879,7 +883,7 @@ class ExtraInfoDescriptor(Descriptor):
raise NotImplementedError('Unsupported Operation: this should be implemented by the ExtraInfoDescriptor subclass')
- def _required_fields(self) -> Tuple[str]: + def _required_fields(self) -> Tuple[str, ...]: return REQUIRED_FIELDS
def _first_keyword(self) -> str: @@ -920,7 +924,7 @@ class RelayExtraInfoDescriptor(ExtraInfoDescriptor): })
@classmethod - def content(cls: Type['stem.descriptor.extrainfo.RelayExtraInfoDescriptor'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), sign: bool = False, signing_key: Optional['stem.descriptor.SigningKey'] = None) -> str: + def content(cls: Type['stem.descriptor.extrainfo_descriptor.RelayExtraInfoDescriptor'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), sign: bool = False, signing_key: Optional['stem.descriptor.SigningKey'] = None) -> bytes: base_header = ( ('extra-info', '%s %s' % (_random_nickname(), _random_fingerprint())), ('published', _random_date()), @@ -941,11 +945,11 @@ class RelayExtraInfoDescriptor(ExtraInfoDescriptor): ))
@classmethod - def create(cls: Type['stem.descriptor.extrainfo.RelayExtraInfoDescriptor'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), validate: bool = True, sign: bool = False, signing_key: Optional['stem.descriptor.SigningKey'] = None) -> 'stem.descriptor.extrainfo.RelayExtraInfoDescriptor': + def create(cls: Type['stem.descriptor.extrainfo_descriptor.RelayExtraInfoDescriptor'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), validate: bool = True, sign: bool = False, signing_key: Optional['stem.descriptor.SigningKey'] = None) -> 'stem.descriptor.extrainfo_descriptor.RelayExtraInfoDescriptor': return cls(cls.content(attr, exclude, sign, signing_key), validate = validate)
@functools.lru_cache() - def digest(self, hash_type: 'stem.descriptor.DigestHash' = DigestHash.SHA1, encoding: 'stem.descriptor.DigestEncoding' = DigestEncoding.HEX) -> Union[str, 'hashlib.HASH']: + def digest(self, hash_type: 'stem.descriptor.DigestHash' = DigestHash.SHA1, encoding: 'stem.descriptor.DigestEncoding' = DigestEncoding.HEX) -> Union[str, 'hashlib._HASH']: # type: ignore if hash_type == DigestHash.SHA1: # our digest is calculated from everything except our signature
@@ -989,7 +993,7 @@ class BridgeExtraInfoDescriptor(ExtraInfoDescriptor): })
@classmethod - def content(cls: Type['stem.descriptor.extrainfo.BridgeExtraInfoDescriptor'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> str: + def content(cls: Type['stem.descriptor.extrainfo_descriptor.BridgeExtraInfoDescriptor'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> bytes: return _descriptor_content(attr, exclude, ( ('extra-info', 'ec2bridgereaac65a3 %s' % _random_fingerprint()), ('published', _random_date()), @@ -997,7 +1001,7 @@ class BridgeExtraInfoDescriptor(ExtraInfoDescriptor): ('router-digest', _random_fingerprint()), ))
- def digest(self, hash_type: 'stem.descriptor.DigestHash' = DigestHash.SHA1, encoding: 'stem.descriptor.DigestEncoding' = DigestEncoding.HEX) -> Union[str, 'hashlib.HASH']: + def digest(self, hash_type: 'stem.descriptor.DigestHash' = DigestHash.SHA1, encoding: 'stem.descriptor.DigestEncoding' = DigestEncoding.HEX) -> Union[str, 'hashlib._HASH']: # type: ignore if hash_type == DigestHash.SHA1 and encoding == DigestEncoding.HEX: return self._digest elif hash_type == DigestHash.SHA256 and encoding == DigestEncoding.BASE64: @@ -1005,7 +1009,7 @@ class BridgeExtraInfoDescriptor(ExtraInfoDescriptor): else: raise NotImplementedError('Bridge extrainfo digests are only available as sha1/hex and sha256/base64, not %s/%s' % (hash_type, encoding))
- def _required_fields(self) -> Tuple[str]: + def _required_fields(self) -> Tuple[str, ...]: excluded_fields = [ 'router-signature', ] diff --git a/stem/descriptor/hidden_service.py b/stem/descriptor/hidden_service.py index 8d23838e..2eb7d02f 100644 --- a/stem/descriptor/hidden_service.py +++ b/stem/descriptor/hidden_service.py @@ -51,9 +51,10 @@ import stem.util.tor_tools
from stem.client.datatype import CertType from stem.descriptor.certificate import ExtensionType, Ed25519Extension, Ed25519Certificate, Ed25519CertificateV1 -from typing import Any, BinaryIO, Callable, Dict, Iterator, Mapping, Optional, Sequence, Tuple, Type, Union +from typing import Any, BinaryIO, Callable, Dict, Iterator, List, Mapping, Optional, Sequence, Tuple, Type, Union
from stem.descriptor import ( + ENTRY_TYPE, PGP_BLOCK_END, Descriptor, _descriptor_content, @@ -104,7 +105,7 @@ INTRODUCTION_POINTS_ATTR = { 'onion_key': None, 'service_key': None, 'intro_authentication': [], -} +} # type: Dict[str, Any]
# introduction-point fields that can only appear once
@@ -133,7 +134,7 @@ class DecryptionFailure(Exception): """
-class IntroductionPointV2(collections.namedtuple('IntroductionPointV2', INTRODUCTION_POINTS_ATTR.keys())): +class IntroductionPointV2(collections.namedtuple('IntroductionPointV2', INTRODUCTION_POINTS_ATTR.keys())): # type: ignore """ Introduction point for a v2 hidden service.
@@ -163,7 +164,7 @@ class IntroductionPointV3(collections.namedtuple('IntroductionPointV3', ['link_s """
@staticmethod - def parse(content: str) -> 'stem.descriptor.hidden_service.IntroductionPointV3': + def parse(content: bytes) -> 'stem.descriptor.hidden_service.IntroductionPointV3': """ Parses an introduction point from its descriptor content.
@@ -175,7 +176,7 @@ class IntroductionPointV3(collections.namedtuple('IntroductionPointV3', ['link_s """
entry = _descriptor_components(content, False) - link_specifiers = IntroductionPointV3._parse_link_specifiers(_value('introduction-point', entry)) + link_specifiers = IntroductionPointV3._parse_link_specifiers(stem.util.str_tools._to_bytes(_value('introduction-point', entry)))
onion_key_line = _value('onion-key', entry) onion_key = onion_key_line[5:] if onion_key_line.startswith('ntor ') else None @@ -201,7 +202,7 @@ class IntroductionPointV3(collections.namedtuple('IntroductionPointV3', ['link_s return IntroductionPointV3(link_specifiers, onion_key, auth_key_cert, enc_key, enc_key_cert, legacy_key, legacy_key_cert)
@staticmethod - def create_for_address(address: str, port: int, expiration: Optional[datetime.datetime] = None, onion_key: Optional[str] = None, enc_key: Optional[str] = None, auth_key: Optional[str] = None, signing_key: Optional['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey'] = None) -> 'stem.descriptor.hidden_service.IntroductionPointV3': + def create_for_address(address: str, port: int, expiration: Optional[datetime.datetime] = None, onion_key: Optional[str] = None, enc_key: Optional[str] = None, auth_key: Optional[str] = None, signing_key: Optional['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey'] = None) -> 'stem.descriptor.hidden_service.IntroductionPointV3': # type: ignore """ Simplified constructor for a single address/port link specifier.
@@ -223,6 +224,8 @@ class IntroductionPointV3(collections.namedtuple('IntroductionPointV3', ['link_s if not stem.util.connection.is_valid_port(port): raise ValueError("'%s' is an invalid port" % port)
+ link_specifiers = None # type: Optional[List[stem.client.datatype.LinkSpecifier]] + if stem.util.connection.is_valid_ipv4_address(address): link_specifiers = [stem.client.datatype.LinkByIPv4(address, port)] elif stem.util.connection.is_valid_ipv6_address(address): @@ -233,7 +236,7 @@ class IntroductionPointV3(collections.namedtuple('IntroductionPointV3', ['link_s return IntroductionPointV3.create_for_link_specifiers(link_specifiers, expiration = None, onion_key = None, enc_key = None, auth_key = None, signing_key = None)
@staticmethod - def create_for_link_specifiers(link_specifiers: Sequence['stem.client.datatype.LinkSpecifier'], expiration: Optional[datetime.datetime] = None, onion_key: Optional[str] = None, enc_key: Optional[str] = None, auth_key: Optional[str] = None, signing_key: Optional['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey'] = None) -> 'stem.descriptor.hidden_service.IntroductionPointV3': + def create_for_link_specifiers(link_specifiers: Sequence['stem.client.datatype.LinkSpecifier'], expiration: Optional[datetime.datetime] = None, onion_key: Optional[str] = None, enc_key: Optional[str] = None, auth_key: Optional[str] = None, signing_key: Optional['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey'] = None) -> 'stem.descriptor.hidden_service.IntroductionPointV3': # type: ignore """ Simplified constructor. For more sophisticated use cases you can use this as a template for how introduction points are properly created. @@ -300,7 +303,7 @@ class IntroductionPointV3(collections.namedtuple('IntroductionPointV3', ['link_s
return '\n'.join(lines)
- def onion_key(self) -> 'cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey': + def onion_key(self) -> 'cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey': # type: ignore """ Provides our ntor introduction point public key.
@@ -313,7 +316,7 @@ class IntroductionPointV3(collections.namedtuple('IntroductionPointV3', ['link_s
return IntroductionPointV3._key_as(self.onion_key_raw, x25519 = True)
- def auth_key(self) -> 'cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PublicKey': + def auth_key(self) -> 'cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PublicKey': # type: ignore """ Provides our authentication certificate's public key.
@@ -326,7 +329,7 @@ class IntroductionPointV3(collections.namedtuple('IntroductionPointV3', ['link_s
return IntroductionPointV3._key_as(self.auth_key_cert.key, ed25519 = True)
- def enc_key(self) -> 'cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey': + def enc_key(self) -> 'cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey': # type: ignore """ Provides our encryption key.
@@ -339,7 +342,7 @@ class IntroductionPointV3(collections.namedtuple('IntroductionPointV3', ['link_s
return IntroductionPointV3._key_as(self.enc_key_raw, x25519 = True)
- def legacy_key(self) -> 'cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey': + def legacy_key(self) -> 'cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey': # type: ignore """ Provides our legacy introduction point public key.
@@ -353,7 +356,7 @@ class IntroductionPointV3(collections.namedtuple('IntroductionPointV3', ['link_s return IntroductionPointV3._key_as(self.legacy_key_raw, x25519 = True)
@staticmethod - def _key_as(value: str, x25519: bool = False, ed25519: bool = False) -> Union['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PublicKey', 'cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey']: + def _key_as(value: bytes, x25519: bool = False, ed25519: bool = False) -> Union['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PublicKey', 'cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey']: # type: ignore if value is None or (not x25519 and not ed25519): return value
@@ -376,11 +379,11 @@ class IntroductionPointV3(collections.namedtuple('IntroductionPointV3', ['link_s return Ed25519PublicKey.from_public_bytes(value)
@staticmethod - def _parse_link_specifiers(content: str) -> 'stem.client.datatype.LinkSpecifier': + def _parse_link_specifiers(content: bytes) -> List['stem.client.datatype.LinkSpecifier']: try: content = base64.b64decode(content) except Exception as exc: - raise ValueError('Unable to base64 decode introduction point (%s): %s' % (exc, content)) + raise ValueError('Unable to base64 decode introduction point (%s): %s' % (exc, stem.util.str_tools._to_unicode(content)))
link_specifiers = [] count, content = stem.client.datatype.Size.CHAR.pop(content) @@ -390,7 +393,7 @@ class IntroductionPointV3(collections.namedtuple('IntroductionPointV3', ['link_s link_specifiers.append(link_specifier)
if content: - raise ValueError('Introduction point had excessive data (%s)' % content) + raise ValueError('Introduction point had excessive data (%s)' % stem.util.str_tools._to_unicode(content))
return link_specifiers
@@ -418,7 +421,7 @@ class AuthorizedClient(object): :var str cookie: base64 encoded authentication cookie """
- def __init__(self, id: str = None, iv: str = None, cookie: str = None) -> None: + def __init__(self, id: Optional[str] = None, iv: Optional[str] = None, cookie: Optional[str] = None) -> None: self.id = stem.util.str_tools._to_unicode(id if id else base64.b64encode(os.urandom(8)).rstrip(b'=')) self.iv = stem.util.str_tools._to_unicode(iv if iv else base64.b64encode(os.urandom(16)).rstrip(b'=')) self.cookie = stem.util.str_tools._to_unicode(cookie if cookie else base64.b64encode(os.urandom(16)).rstrip(b'=')) @@ -433,7 +436,7 @@ class AuthorizedClient(object): return not self == other
-def _parse_file(descriptor_file: BinaryIO, desc_type: str = None, validate: bool = False, **kwargs: Any) -> Iterator['stem.descriptor.hidden_service.HiddenServiceDescriptor']: +def _parse_file(descriptor_file: BinaryIO, desc_type: Optional[Type['stem.descriptor.hidden_service.HiddenServiceDescriptor']] = None, validate: bool = False, **kwargs: Any) -> Iterator['stem.descriptor.hidden_service.HiddenServiceDescriptor']: """ Iterates over the hidden service descriptors in a file.
@@ -468,12 +471,12 @@ def _parse_file(descriptor_file: BinaryIO, desc_type: str = None, validate: bool if descriptor_content[0].startswith(b'@type'): descriptor_content = descriptor_content[1:]
- yield desc_type(bytes.join(b'', descriptor_content), validate, **kwargs) + yield desc_type(bytes.join(b'', descriptor_content), validate, **kwargs) # type: ignore else: break # done parsing file
-def _decrypt_layer(encrypted_block: bytes, constant: bytes, revision_counter: int, subcredential: bytes, blinded_key: bytes) -> str: +def _decrypt_layer(encrypted_block: str, constant: bytes, revision_counter: int, subcredential: bytes, blinded_key: bytes) -> str: if encrypted_block.startswith('-----BEGIN MESSAGE-----\n') and encrypted_block.endswith('\n-----END MESSAGE-----'): encrypted_block = encrypted_block[24:-22]
@@ -492,7 +495,7 @@ def _decrypt_layer(encrypted_block: bytes, constant: bytes, revision_counter: in cipher, mac_for = _layer_cipher(constant, revision_counter, subcredential, blinded_key, salt)
if expected_mac != mac_for(ciphertext): - raise ValueError('Malformed mac (expected %s, but was %s)' % (expected_mac, mac_for(ciphertext))) + raise ValueError('Malformed mac (expected %s, but was %s)' % (stem.util.str_tools._to_unicode(expected_mac), stem.util.str_tools._to_unicode(mac_for(ciphertext))))
decryptor = cipher.decryptor() plaintext = decryptor.update(ciphertext) + decryptor.finalize() @@ -500,7 +503,7 @@ def _decrypt_layer(encrypted_block: bytes, constant: bytes, revision_counter: in return stem.util.str_tools._to_unicode(plaintext)
-def _encrypt_layer(plaintext: str, constant: bytes, revision_counter: int, subcredential: bytes, blinded_key: bytes) -> bytes: +def _encrypt_layer(plaintext: bytes, constant: bytes, revision_counter: int, subcredential: bytes, blinded_key: bytes) -> bytes: salt = os.urandom(16) cipher, mac_for = _layer_cipher(constant, revision_counter, subcredential, blinded_key, salt)
@@ -511,7 +514,7 @@ def _encrypt_layer(plaintext: str, constant: bytes, revision_counter: int, subcr return b'-----BEGIN MESSAGE-----\n%s\n-----END MESSAGE-----' % b'\n'.join(stem.util.str_tools._split_by_length(encoded, 64))
-def _layer_cipher(constant: bytes, revision_counter: int, subcredential: bytes, blinded_key: bytes, salt: bytes) -> Tuple['cryptography.hazmat.primitives.ciphers.Cipher', Callable[[bytes], bytes]]: +def _layer_cipher(constant: bytes, revision_counter: int, subcredential: bytes, blinded_key: bytes, salt: bytes) -> Tuple['cryptography.hazmat.primitives.ciphers.Cipher', Callable[[bytes], bytes]]: # type: ignore try: from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from cryptography.hazmat.backends import default_backend @@ -531,7 +534,7 @@ def _layer_cipher(constant: bytes, revision_counter: int, subcredential: bytes, return cipher, lambda ciphertext: hashlib.sha3_256(mac_prefix + ciphertext).digest()
-def _parse_protocol_versions_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_protocol_versions_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: value = _value('protocol-versions', entries)
try: @@ -546,7 +549,7 @@ def _parse_protocol_versions_line(descriptor: 'stem.descriptor.Descriptor', entr descriptor.protocol_versions = versions
-def _parse_introduction_points_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_introduction_points_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: _, block_type, block_contents = entries['introduction-points'][0]
if not block_contents or block_type != 'MESSAGE': @@ -560,7 +563,7 @@ def _parse_introduction_points_line(descriptor: 'stem.descriptor.Descriptor', en raise ValueError("'introduction-points' isn't base64 encoded content:\n%s" % block_contents)
-def _parse_v3_outer_clients(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_v3_outer_clients(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "auth-client" client-id iv encrypted-cookie
clients = {} @@ -576,7 +579,7 @@ def _parse_v3_outer_clients(descriptor: 'stem.descriptor.Descriptor', entries: D descriptor.clients = clients
-def _parse_v3_inner_formats(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_v3_inner_formats(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: value, formats = _value('create2-formats', entries), []
for entry in value.split(' '): @@ -588,7 +591,7 @@ def _parse_v3_inner_formats(descriptor: 'stem.descriptor.Descriptor', entries: D descriptor.formats = formats
-def _parse_v3_introduction_points(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_v3_introduction_points(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: if hasattr(descriptor, '_unparsed_introduction_points'): introduction_points = [] remaining = descriptor._unparsed_introduction_points @@ -674,7 +677,7 @@ class HiddenServiceDescriptorV2(HiddenServiceDescriptor): 'introduction_points_encoded': (None, _parse_introduction_points_line), 'introduction_points_content': (None, _parse_introduction_points_line), 'signature': (None, _parse_v2_signature_line), - } + } # type: Dict[str, Tuple[Any, Callable[['stem.descriptor.Descriptor', ENTRY_TYPE], None]]]
PARSER_FOR_LINE = { 'rendezvous-service-descriptor': _parse_rendezvous_service_descriptor_line, @@ -688,7 +691,7 @@ class HiddenServiceDescriptorV2(HiddenServiceDescriptor): }
@classmethod - def content(cls: Type['stem.descriptor.hidden_service.HiddenServiceDescriptorV2'], attr: Mapping[str, str] = None, exclude: Sequence[str] = ()) -> str: + def content(cls: Type['stem.descriptor.hidden_service.HiddenServiceDescriptorV2'], attr: Mapping[str, str] = None, exclude: Sequence[str] = ()) -> bytes: return _descriptor_content(attr, exclude, ( ('rendezvous-service-descriptor', 'y3olqqblqw2gbh6phimfuiroechjjafa'), ('version', '2'), @@ -705,7 +708,7 @@ class HiddenServiceDescriptorV2(HiddenServiceDescriptor): def create(cls: Type['stem.descriptor.hidden_service.HiddenServiceDescriptorV2'], attr: Mapping[str, str] = None, exclude: Sequence[str] = (), validate: bool = True) -> 'stem.descriptor.hidden_service.HiddenServiceDescriptorV2': return cls(cls.content(attr, exclude), validate = validate, skip_crypto_validation = True)
- def __init__(self, raw_contents: str, validate: bool = False, skip_crypto_validation: bool = False) -> None: + def __init__(self, raw_contents: bytes, validate: bool = False, skip_crypto_validation: bool = False) -> None: super(HiddenServiceDescriptorV2, self).__init__(raw_contents, lazy_load = not validate) entries = _descriptor_components(raw_contents, validate, non_ascii_fields = ('introduction-points'))
@@ -737,11 +740,11 @@ class HiddenServiceDescriptorV2(HiddenServiceDescriptor): self._entries = entries
@functools.lru_cache() - def introduction_points(self, authentication_cookie: Optional[str] = None) -> Sequence['stem.descriptor.hidden_service.IntroductionPointV2']: + def introduction_points(self, authentication_cookie: Optional[bytes] = None) -> Sequence['stem.descriptor.hidden_service.IntroductionPointV2']: """ Provided this service's introduction points.
- :param str authentication_cookie: base64 encoded authentication cookie + :param bytes authentication_cookie: base64 encoded authentication cookie
:returns: **list** of :class:`~stem.descriptor.hidden_service.IntroductionPointV2`
@@ -777,7 +780,7 @@ class HiddenServiceDescriptorV2(HiddenServiceDescriptor): return HiddenServiceDescriptorV2._parse_introduction_points(content)
@staticmethod - def _decrypt_basic_auth(content: bytes, authentication_cookie: str) -> bytes: + def _decrypt_basic_auth(content: bytes, authentication_cookie: bytes) -> bytes: try: from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from cryptography.hazmat.backends import default_backend @@ -787,7 +790,7 @@ class HiddenServiceDescriptorV2(HiddenServiceDescriptor): try: client_blocks = int(binascii.hexlify(content[1:2]), 16) except ValueError: - raise DecryptionFailure("When using basic auth the content should start with a number of blocks but wasn't a hex digit: %s" % binascii.hexlify(content[1:2])) + raise DecryptionFailure("When using basic auth the content should start with a number of blocks but wasn't a hex digit: %s" % binascii.hexlify(content[1:2]).decode('utf-8'))
# parse the client id and encrypted session keys
@@ -824,7 +827,7 @@ class HiddenServiceDescriptorV2(HiddenServiceDescriptor): return content # nope, unable to decrypt the content
@staticmethod - def _decrypt_stealth_auth(content: bytes, authentication_cookie: str) -> bytes: + def _decrypt_stealth_auth(content: bytes, authentication_cookie: bytes) -> bytes: try: from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from cryptography.hazmat.backends import default_backend @@ -888,7 +891,7 @@ class HiddenServiceDescriptorV2(HiddenServiceDescriptor): auth_type, auth_data = auth_value.split(' ')[:2] auth_entries.append((auth_type, auth_data))
- introduction_points.append(IntroductionPointV2(**attr)) + introduction_points.append(IntroductionPointV2(**attr)) # type: ignore
return introduction_points
@@ -931,7 +934,7 @@ class HiddenServiceDescriptorV3(HiddenServiceDescriptor): }
@classmethod - def content(cls: Type['stem.descriptor.hidden_service.HiddenServiceDescriptorV3'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), sign: bool = False, inner_layer: Optional['stem.descriptor.hidden_service.InnerLayer'] = None, outer_layer: Optional['stem.descriptor.hidden_service.OuterLayer'] = None, identity_key: Optional['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey'] = None, signing_key: Optional['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey'] = None, signing_cert: Optional['stem.descriptor.Ed25519CertificateV1'] = None, revision_counter: int = None, blinding_nonce: bytes = None) -> str: + def content(cls: Type['stem.descriptor.hidden_service.HiddenServiceDescriptorV3'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), sign: bool = False, inner_layer: Optional['stem.descriptor.hidden_service.InnerLayer'] = None, outer_layer: Optional['stem.descriptor.hidden_service.OuterLayer'] = None, identity_key: Optional['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey'] = None, signing_key: Optional['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey'] = None, signing_cert: Optional['stem.descriptor.certificate.Ed25519CertificateV1'] = None, revision_counter: int = None, blinding_nonce: bytes = None) -> bytes: # type: ignore """ Hidden service v3 descriptors consist of three parts:
@@ -992,7 +995,12 @@ class HiddenServiceDescriptorV3(HiddenServiceDescriptor):
blinded_key = _blinded_pubkey(identity_key, blinding_nonce) if blinding_nonce else b'a' * 32 subcredential = HiddenServiceDescriptorV3._subcredential(identity_key, blinded_key) - custom_sig = attr.pop('signature') if (attr and 'signature' in attr) else None + + if attr and 'signature' in attr: + custom_sig = attr['signature'] + attr = dict(filter(lambda entry: entry[0] != 'signature', attr.items())) + else: + custom_sig = None
if not outer_layer: outer_layer = OuterLayer.create( @@ -1014,7 +1022,7 @@ class HiddenServiceDescriptorV3(HiddenServiceDescriptor): ('descriptor-lifetime', '180'), ('descriptor-signing-key-cert', '\n' + signing_cert.to_base64(pem = True)), ('revision-counter', str(revision_counter)), - ('superencrypted', b'\n' + outer_layer._encrypt(revision_counter, subcredential, blinded_key)), + ('superencrypted', stem.util.str_tools._to_unicode(b'\n' + outer_layer._encrypt(revision_counter, subcredential, blinded_key))), ), ()) + b'\n'
if custom_sig: @@ -1026,13 +1034,13 @@ class HiddenServiceDescriptorV3(HiddenServiceDescriptor): return desc_content
@classmethod - def create(cls: Type['stem.descriptor.hidden_service.HiddenServiceDescriptorV3'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), validate: bool = True, sign: bool = False, inner_layer: Optional['stem.descriptor.hidden_service.InnerLayer'] = None, outer_layer: Optional['stem.descriptor.hidden_service.OuterLayer'] = None, identity_key: Optional['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey'] = None, signing_key: Optional['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey'] = None, signing_cert: Optional['stem.descriptor.Ed25519CertificateV1'] = None, revision_counter: int = None, blinding_nonce: bytes = None) -> 'stem.descriptor.hidden_service.HiddenServiceDescriptorV3': + def create(cls: Type['stem.descriptor.hidden_service.HiddenServiceDescriptorV3'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), validate: bool = True, sign: bool = False, inner_layer: Optional['stem.descriptor.hidden_service.InnerLayer'] = None, outer_layer: Optional['stem.descriptor.hidden_service.OuterLayer'] = None, identity_key: Optional['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey'] = None, signing_key: Optional['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey'] = None, signing_cert: Optional['stem.descriptor.certificate.Ed25519CertificateV1'] = None, revision_counter: int = None, blinding_nonce: bytes = None) -> 'stem.descriptor.hidden_service.HiddenServiceDescriptorV3': # type: ignore return cls(cls.content(attr, exclude, sign, inner_layer, outer_layer, identity_key, signing_key, signing_cert, revision_counter, blinding_nonce), validate = validate)
def __init__(self, raw_contents: bytes, validate: bool = False) -> None: super(HiddenServiceDescriptorV3, self).__init__(raw_contents, lazy_load = not validate)
- self._inner_layer = None + self._inner_layer = None # type: Optional[stem.descriptor.hidden_service.InnerLayer] entries = _descriptor_components(raw_contents, validate)
if validate: @@ -1089,7 +1097,7 @@ class HiddenServiceDescriptorV3(HiddenServiceDescriptor): return self._inner_layer
@staticmethod - def address_from_identity_key(key: Union[bytes, 'cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PublicKey', 'cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey'], suffix: bool = True) -> str: + def address_from_identity_key(key: Union[bytes, 'cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PublicKey', 'cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey'], suffix: bool = True) -> str: # type: ignore """ Converts a hidden service identity key into its address. This accepts all key formats (private, public, or public bytes). @@ -1112,7 +1120,7 @@ class HiddenServiceDescriptorV3(HiddenServiceDescriptor): return stem.util.str_tools._to_unicode(onion_address + b'.onion' if suffix else onion_address).lower()
@staticmethod - def identity_key_from_address(onion_address: str) -> bool: + def identity_key_from_address(onion_address: str) -> bytes: """ Converts a hidden service address into its public identity key.
@@ -1149,7 +1157,7 @@ class HiddenServiceDescriptorV3(HiddenServiceDescriptor): return pubkey
@staticmethod - def _subcredential(identity_key: bytes, blinded_key: bytes) -> bytes: + def _subcredential(identity_key: 'cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey', blinded_key: bytes) -> bytes: # type: ignore # credential = H('credential' | public-identity-key) # subcredential = H('subcredential' | credential | blinded-public-key)
@@ -1179,7 +1187,7 @@ class OuterLayer(Descriptor): 'ephemeral_key': (None, _parse_v3_outer_ephemeral_key), 'clients': ({}, _parse_v3_outer_clients), 'encrypted': (None, _parse_v3_outer_encrypted), - } + } # type: Dict[str, Tuple[Any, Callable[['stem.descriptor.Descriptor', ENTRY_TYPE], None]]]
PARSER_FOR_LINE = { 'desc-auth-type': _parse_v3_outer_auth_type, @@ -1189,9 +1197,9 @@ class OuterLayer(Descriptor): }
@staticmethod - def _decrypt(encrypted: bytes, revision_counter: int, subcredential: bytes, blinded_key: bytes) -> 'stem.descriptor.hidden_service.OuterLayer': + def _decrypt(encrypted: str, revision_counter: int, subcredential: bytes, blinded_key: bytes) -> 'stem.descriptor.hidden_service.OuterLayer': plaintext = _decrypt_layer(encrypted, b'hsdir-superencrypted-data', revision_counter, subcredential, blinded_key) - return OuterLayer(plaintext) + return OuterLayer(stem.util.str_tools._to_bytes(plaintext))
def _encrypt(self, revision_counter: int, subcredential: bytes, blinded_key: bytes) -> bytes: # Spec mandated padding: "Before encryption the plaintext is padded with @@ -1204,7 +1212,7 @@ class OuterLayer(Descriptor): return _encrypt_layer(content, b'hsdir-superencrypted-data', revision_counter, subcredential, blinded_key)
@classmethod - def content(cls: Type['stem.descriptor.hidden_service.OuterLayer'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), validate: bool = True, sign: bool = False, inner_layer: Optional['stem.descriptor.hidden_service.InnerLayer'] = None, revision_counter: Optional[int] = None, authorized_clients: Optional[Sequence['stem.descriptor.hidden_service.AuthorizedClient']] = None, subcredential: bytes = None, blinded_key: bytes = None) -> str: + def content(cls: Type['stem.descriptor.hidden_service.OuterLayer'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), validate: bool = True, sign: bool = False, inner_layer: Optional['stem.descriptor.hidden_service.InnerLayer'] = None, revision_counter: Optional[int] = None, authorized_clients: Optional[Sequence['stem.descriptor.hidden_service.AuthorizedClient']] = None, subcredential: bytes = None, blinded_key: bytes = None) -> bytes: try: from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey @@ -1230,11 +1238,11 @@ class OuterLayer(Descriptor):
return _descriptor_content(attr, exclude, [ ('desc-auth-type', 'x25519'), - ('desc-auth-ephemeral-key', base64.b64encode(stem.util._pubkey_bytes(X25519PrivateKey.generate()))), + ('desc-auth-ephemeral-key', stem.util.str_tools._to_unicode(base64.b64encode(stem.util._pubkey_bytes(X25519PrivateKey.generate())))), ] + [ ('auth-client', '%s %s %s' % (c.id, c.iv, c.cookie)) for c in authorized_clients ], ( - ('encrypted', b'\n' + inner_layer._encrypt(revision_counter, subcredential, blinded_key)), + ('encrypted', stem.util.str_tools._to_unicode(b'\n' + inner_layer._encrypt(revision_counter, subcredential, blinded_key))), ))
@classmethod @@ -1285,17 +1293,17 @@ class InnerLayer(Descriptor): }
@staticmethod - def _decrypt(outer_layer: 'stem.descriptor.hidden_service.OuterLayer', revision_counter: int, subcredential: bytes, blinded_key: bytes) -> bytes: + def _decrypt(outer_layer: 'stem.descriptor.hidden_service.OuterLayer', revision_counter: int, subcredential: bytes, blinded_key: bytes) -> 'stem.descriptor.hidden_service.InnerLayer': plaintext = _decrypt_layer(outer_layer.encrypted, b'hsdir-encrypted-data', revision_counter, subcredential, blinded_key) - return InnerLayer(plaintext, validate = True, outer_layer = outer_layer) + return InnerLayer(stem.util.str_tools._to_bytes(plaintext), validate = True, outer_layer = outer_layer)
- def _encrypt(self, revision_counter, subcredential, blinded_key): + def _encrypt(self, revision_counter: int, subcredential: bytes, blinded_key: bytes) -> bytes: # encrypt back into an outer layer's 'encrypted' field
return _encrypt_layer(self.get_bytes(), b'hsdir-encrypted-data', revision_counter, subcredential, blinded_key)
@classmethod - def content(cls: Type['stem.descriptor.hidden_service.InnerLayer'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), introduction_points: Optional[Sequence['stem.descriptor.hidden_service.IntroductionPointV3']] = None) -> str: + def content(cls: Type['stem.descriptor.hidden_service.InnerLayer'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), introduction_points: Optional[Sequence['stem.descriptor.hidden_service.IntroductionPointV3']] = None) -> bytes: if introduction_points: suffix = '\n' + '\n'.join(map(IntroductionPointV3.encode, introduction_points)) else: @@ -1342,7 +1350,7 @@ def _blinded_pubkey(identity_key: bytes, blinding_nonce: bytes) -> bytes: return ed25519.encodepoint(ed25519.scalarmult(P, mult))
-def _blinded_sign(msg: bytes, identity_key: 'cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey', blinded_key: bytes, blinding_nonce: bytes) -> bytes: +def _blinded_sign(msg: bytes, identity_key: 'cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey', blinded_key: bytes, blinding_nonce: bytes) -> bytes: # type: ignore try: from cryptography.hazmat.primitives import serialization except ImportError: diff --git a/stem/descriptor/microdescriptor.py b/stem/descriptor/microdescriptor.py index c2c104ff..7bd241e4 100644 --- a/stem/descriptor/microdescriptor.py +++ b/stem/descriptor/microdescriptor.py @@ -72,6 +72,7 @@ import stem.exit_policy from typing import Any, BinaryIO, Dict, Iterator, Mapping, Optional, Sequence, Type, Union
from stem.descriptor import ( + ENTRY_TYPE, Descriptor, DigestHash, DigestEncoding, @@ -120,6 +121,9 @@ def _parse_file(descriptor_file: BinaryIO, validate: bool = False, **kwargs: Any * **IOError** if the file can't be read """
+ if kwargs: + raise ValueError('BUG: keyword arguments unused by microdescriptors') + while True: annotations = _read_until_keywords('onion-key', descriptor_file)
@@ -156,12 +160,12 @@ def _parse_file(descriptor_file: BinaryIO, validate: bool = False, **kwargs: Any
descriptor_text = bytes.join(b'', descriptor_lines)
- yield Microdescriptor(descriptor_text, validate, annotations, **kwargs) + yield Microdescriptor(descriptor_text, validate, annotations) else: break # done parsing descriptors
-def _parse_id_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_id_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: identities = {}
for entry in _values('id', entries): @@ -246,12 +250,12 @@ class Microdescriptor(Descriptor): }
@classmethod - def content(cls: Type['stem.descriptor.microdescriptor.Microdescriptor'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> str: + def content(cls: Type['stem.descriptor.microdescriptor.Microdescriptor'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> bytes: return _descriptor_content(attr, exclude, ( ('onion-key', _random_crypto_blob('RSA PUBLIC KEY')), ))
- def __init__(self, raw_contents, validate = False, annotations = None): + def __init__(self, raw_contents: bytes, validate: bool = False, annotations: Optional[Sequence[bytes]] = None) -> None: super(Microdescriptor, self).__init__(raw_contents, lazy_load = not validate) self._annotation_lines = annotations if annotations else [] entries = _descriptor_components(raw_contents, validate) @@ -262,7 +266,7 @@ class Microdescriptor(Descriptor): else: self._entries = entries
- def digest(self, hash_type: 'stem.descriptor.DigestHash' = DigestHash.SHA256, encoding: 'stem.descriptor.DigestEncoding' = DigestEncoding.BASE64) -> Union[str, 'hashlib.HASH']: + def digest(self, hash_type: 'stem.descriptor.DigestHash' = DigestHash.SHA256, encoding: 'stem.descriptor.DigestEncoding' = DigestEncoding.BASE64) -> Union[str, 'hashlib._HASH']: # type: ignore """ Digest of this microdescriptor. These are referenced by...
@@ -287,7 +291,7 @@ class Microdescriptor(Descriptor): raise NotImplementedError('Microdescriptor digests are only available in sha1 and sha256, not %s' % hash_type)
@functools.lru_cache() - def get_annotations(self) -> Dict[str, str]: + def get_annotations(self) -> Dict[bytes, bytes]: """ Provides content that appeared prior to the descriptor. If this comes from the cached-microdescs then this commonly contains content like... @@ -310,7 +314,7 @@ class Microdescriptor(Descriptor):
return annotation_dict
- def get_annotation_lines(self) -> Sequence[str]: + def get_annotation_lines(self) -> Sequence[bytes]: """ Provides the lines of content that appeared prior to the descriptor. This is the same as the @@ -322,7 +326,7 @@ class Microdescriptor(Descriptor):
return self._annotation_lines
- def _check_constraints(self, entries: Dict[str, Sequence[str]]) -> None: + def _check_constraints(self, entries: ENTRY_TYPE) -> None: """ Does a basic check that the entries conform to this descriptor type's constraints. diff --git a/stem/descriptor/networkstatus.py b/stem/descriptor/networkstatus.py index 48940987..6c0f5e8f 100644 --- a/stem/descriptor/networkstatus.py +++ b/stem/descriptor/networkstatus.py @@ -65,9 +65,10 @@ import stem.util.str_tools import stem.util.tor_tools import stem.version
-from typing import Any, BinaryIO, Callable, Dict, Iterator, Mapping, Optional, Sequence, Type +from typing import Any, BinaryIO, Callable, Dict, Iterator, List, Mapping, Optional, Sequence, Tuple, Type, Union
from stem.descriptor import ( + ENTRY_TYPE, PGP_BLOCK_END, Descriptor, DigestHash, @@ -295,7 +296,7 @@ class DocumentDigest(collections.namedtuple('DocumentDigest', ['flavor', 'algori """
-def _parse_file(document_file: BinaryIO, document_type: Optional[Type['stem.descriptor.networkstatus.NetworkStatusDocument']] = None, validate: bool = False, is_microdescriptor: bool = False, document_handler: 'stem.descriptor.DocumentHandler' = DocumentHandler.ENTRIES, **kwargs: Any) -> 'stem.descriptor.networkstatus.NetworkStatusDocument': +def _parse_file(document_file: BinaryIO, document_type: Optional[Type] = None, validate: bool = False, is_microdescriptor: bool = False, document_handler: 'stem.descriptor.DocumentHandler' = DocumentHandler.ENTRIES, **kwargs: Any) -> Iterator[Union['stem.descriptor.networkstatus.NetworkStatusDocument', 'stem.descriptor.router_status_entry.RouterStatusEntry']]: """ Parses a network status and iterates over the RouterStatusEntry in it. The document that these instances reference have an empty 'routers' attribute to @@ -324,6 +325,8 @@ def _parse_file(document_file: BinaryIO, document_type: Optional[Type['stem.desc if document_type is None: document_type = NetworkStatusDocumentV3
+ router_type = None # type: Optional[Type[stem.descriptor.router_status_entry.RouterStatusEntry]] + if document_type == NetworkStatusDocumentV2: document_type, router_type = NetworkStatusDocumentV2, RouterStatusEntryV2 elif document_type == NetworkStatusDocumentV3: @@ -334,10 +337,10 @@ def _parse_file(document_file: BinaryIO, document_type: Optional[Type['stem.desc yield document_type(document_file.read(), validate, **kwargs) return else: - raise ValueError("Document type %i isn't recognized (only able to parse v2, v3, and bridge)" % document_type) + raise ValueError("Document type %s isn't recognized (only able to parse v2, v3, and bridge)" % document_type)
if document_handler == DocumentHandler.DOCUMENT: - yield document_type(document_file.read(), validate, **kwargs) + yield document_type(document_file.read(), validate, **kwargs) # type: ignore return
# getting the document without the routers section @@ -355,7 +358,7 @@ def _parse_file(document_file: BinaryIO, document_type: Optional[Type['stem.desc document_content = bytes.join(b'', header + footer)
if document_handler == DocumentHandler.BARE_DOCUMENT: - yield document_type(document_content, validate, **kwargs) + yield document_type(document_content, validate, **kwargs) # type: ignore elif document_handler == DocumentHandler.ENTRIES: desc_iterator = stem.descriptor.router_status_entry._parse_file( document_file, @@ -433,7 +436,7 @@ class NetworkStatusDocument(Descriptor): Common parent for network status documents. """
- def digest(self, hash_type: 'stem.descriptor.DigestHash' = DigestHash.SHA1, encoding: 'stem.descriptor.DigestEncoding' = DigestEncoding.HEX) -> None: + def digest(self, hash_type: 'stem.descriptor.DigestHash' = DigestHash.SHA1, encoding: 'stem.descriptor.DigestEncoding' = DigestEncoding.HEX) -> Union[str, 'hashlib._HASH']: # type: ignore """ Digest of this descriptor's content. These are referenced by...
@@ -460,8 +463,8 @@ class NetworkStatusDocument(Descriptor): raise NotImplementedError('Network status document digests are only available in sha1 and sha256, not %s' % hash_type)
-def _parse_version_line(keyword: str, attribute: str, expected_version: int) -> Callable[['stem.descriptor.Descriptor', Dict[str, Sequence[str]]], None]: - def _parse(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_version_line(keyword: str, attribute: str, expected_version: int) -> Callable[['stem.descriptor.Descriptor', ENTRY_TYPE], None]: + def _parse(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: value = _value(keyword, entries)
if not value.isdigit(): @@ -475,7 +478,7 @@ def _parse_version_line(keyword: str, attribute: str, expected_version: int) -> return _parse
-def _parse_dir_source_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_dir_source_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: value = _value('dir-source', entries) dir_source_comp = value.split()
@@ -495,7 +498,7 @@ def _parse_dir_source_line(descriptor: 'stem.descriptor.Descriptor', entries: Di descriptor.dir_port = None if dir_source_comp[2] == '0' else int(dir_source_comp[2])
-def _parse_additional_digests(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_additional_digests(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: digests = []
for val in _values('additional-digest', entries): @@ -509,7 +512,7 @@ def _parse_additional_digests(descriptor: 'stem.descriptor.Descriptor', entries: descriptor.additional_digests = digests
-def _parse_additional_signatures(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_additional_signatures(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: signatures = []
for val, block_type, block_contents in entries['additional-signature']: @@ -584,7 +587,7 @@ class NetworkStatusDocumentV2(NetworkStatusDocument):
'signing_authority': (None, _parse_directory_signature_line), 'signatures': (None, _parse_directory_signature_line), - } + } # type: Dict[str, Tuple[Any, Callable[['stem.descriptor.Descriptor', ENTRY_TYPE], None]]]
PARSER_FOR_LINE = { 'network-status-version': _parse_network_status_version_line, @@ -600,7 +603,7 @@ class NetworkStatusDocumentV2(NetworkStatusDocument): }
@classmethod - def content(cls: Type['stem.descriptor.networkstatus.NetworkStatusDocumentV2'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> str: + def content(cls: Type['stem.descriptor.networkstatus.NetworkStatusDocumentV2'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> bytes: return _descriptor_content(attr, exclude, ( ('network-status-version', '2'), ('dir-source', '%s %s 80' % (_random_ipv4_address(), _random_ipv4_address())), @@ -648,7 +651,7 @@ class NetworkStatusDocumentV2(NetworkStatusDocument): else: self._entries = entries
- def _check_constraints(self, entries: Dict[str, Sequence[str]]) -> None: + def _check_constraints(self, entries: ENTRY_TYPE) -> None: required_fields = [field for (field, is_mandatory) in NETWORK_STATUS_V2_FIELDS if is_mandatory] for keyword in required_fields: if keyword not in entries: @@ -664,7 +667,7 @@ class NetworkStatusDocumentV2(NetworkStatusDocument): raise ValueError("Network status document (v2) are expected to start with a 'network-status-version' line:\n%s" % str(self))
-def _parse_header_network_status_version_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_header_network_status_version_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "network-status-version" version
value = _value('network-status-version', entries) @@ -685,7 +688,7 @@ def _parse_header_network_status_version_line(descriptor: 'stem.descriptor.Descr raise ValueError("Expected a version 3 network status document, got version '%s' instead" % descriptor.version)
-def _parse_header_vote_status_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_header_vote_status_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "vote-status" type # # The consensus-method and consensus-methods fields are optional since @@ -702,7 +705,7 @@ def _parse_header_vote_status_line(descriptor: 'stem.descriptor.Descriptor', ent raise ValueError("A network status document's vote-status line can only be 'consensus' or 'vote', got '%s' instead" % value)
-def _parse_header_consensus_methods_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_header_consensus_methods_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "consensus-methods" IntegerList
if descriptor._lazy_loading and descriptor.is_vote: @@ -719,7 +722,7 @@ def _parse_header_consensus_methods_line(descriptor: 'stem.descriptor.Descriptor descriptor.consensus_methods = consensus_methods
-def _parse_header_consensus_method_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_header_consensus_method_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "consensus-method" Integer
if descriptor._lazy_loading and descriptor.is_consensus: @@ -733,7 +736,7 @@ def _parse_header_consensus_method_line(descriptor: 'stem.descriptor.Descriptor' descriptor.consensus_method = int(value)
-def _parse_header_voting_delay_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_header_voting_delay_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "voting-delay" VoteSeconds DistSeconds
value = _value('voting-delay', entries) @@ -746,8 +749,8 @@ def _parse_header_voting_delay_line(descriptor: 'stem.descriptor.Descriptor', en raise ValueError("A network status document's 'voting-delay' line must be a pair of integer values, but was '%s'" % value)
-def _parse_versions_line(keyword: str, attribute: str) -> Callable[['stem.descriptor.Descriptor', Dict[str, Sequence[str]]], None]: - def _parse(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_versions_line(keyword: str, attribute: str) -> Callable[['stem.descriptor.Descriptor', ENTRY_TYPE], None]: + def _parse(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: value, entries = _value(keyword, entries), []
for entry in value.split(','): @@ -761,7 +764,7 @@ def _parse_versions_line(keyword: str, attribute: str) -> Callable[['stem.descri return _parse
-def _parse_header_flag_thresholds_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_header_flag_thresholds_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "flag-thresholds" SP THRESHOLDS
value, thresholds = _value('flag-thresholds', entries).strip(), {} @@ -784,7 +787,7 @@ def _parse_header_flag_thresholds_line(descriptor: 'stem.descriptor.Descriptor', descriptor.flag_thresholds = thresholds
-def _parse_header_parameters_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_header_parameters_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "params" [Parameters] # Parameter ::= Keyword '=' Int32 # Int32 ::= A decimal integer between -2147483648 and 2147483647. @@ -800,7 +803,7 @@ def _parse_header_parameters_line(descriptor: 'stem.descriptor.Descriptor', entr descriptor._check_params_constraints()
-def _parse_directory_footer_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_directory_footer_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # nothing to parse, simply checking that we don't have a value
value = _value('directory-footer', entries) @@ -809,7 +812,7 @@ def _parse_directory_footer_line(descriptor: 'stem.descriptor.Descriptor', entri raise ValueError("A network status document's 'directory-footer' line shouldn't have any content, got 'directory-footer %s'" % value)
-def _parse_footer_directory_signature_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_footer_directory_signature_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: signatures = []
for sig_value, block_type, block_contents in entries['directory-signature']: @@ -830,7 +833,7 @@ def _parse_footer_directory_signature_line(descriptor: 'stem.descriptor.Descript descriptor.signatures = signatures
-def _parse_package_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_package_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: package_versions = []
for value, _, _ in entries['package']: @@ -851,7 +854,7 @@ def _parse_package_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[ descriptor.packages = package_versions
-def _parsed_shared_rand_commit(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parsed_shared_rand_commit(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "shared-rand-commit" Version AlgName Identity Commit [Reveal]
commitments = [] @@ -873,7 +876,7 @@ def _parsed_shared_rand_commit(descriptor: 'stem.descriptor.Descriptor', entries descriptor.shared_randomness_commitments = commitments
-def _parse_shared_rand_previous_value(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_shared_rand_previous_value(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "shared-rand-previous-value" NumReveals Value
value = _value('shared-rand-previous-value', entries) @@ -886,7 +889,7 @@ def _parse_shared_rand_previous_value(descriptor: 'stem.descriptor.Descriptor', raise ValueError("A network status document's 'shared-rand-previous-value' line must be a pair of values, the first an integer but was '%s'" % value)
-def _parse_shared_rand_current_value(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_shared_rand_current_value(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "shared-rand-current-value" NumReveals Value
value = _value('shared-rand-current-value', entries) @@ -899,7 +902,7 @@ def _parse_shared_rand_current_value(descriptor: 'stem.descriptor.Descriptor', e raise ValueError("A network status document's 'shared-rand-current-value' line must be a pair of values, the first an integer but was '%s'" % value)
-def _parse_bandwidth_file_headers(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_bandwidth_file_headers(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "bandwidth-file-headers" KeyValues # KeyValues ::= "" | KeyValue | KeyValues SP KeyValue # KeyValue ::= Keyword '=' Value @@ -914,7 +917,7 @@ def _parse_bandwidth_file_headers(descriptor: 'stem.descriptor.Descriptor', entr descriptor.bandwidth_file_headers = results
-def _parse_bandwidth_file_digest(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_bandwidth_file_digest(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "bandwidth-file-digest" 1*(SP algorithm "=" digest)
value = _value('bandwidth-file-digest', entries) @@ -1098,7 +1101,7 @@ class NetworkStatusDocumentV3(NetworkStatusDocument): }
@classmethod - def content(cls: Type['stem.descriptor.networkstatus.NetworkStatusDocumentV3'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), authorities: Optional[Sequence['stem.descriptor.networkstatus.DirectoryAuthority']] = None, routers: Optional[Sequence['stem.descriptor.router_status_entry.RouterStatusEntryV3']] = None) -> str: + def content(cls: Type['stem.descriptor.networkstatus.NetworkStatusDocumentV3'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), authorities: Optional[Sequence['stem.descriptor.networkstatus.DirectoryAuthority']] = None, routers: Optional[Sequence['stem.descriptor.router_status_entry.RouterStatusEntryV3']] = None) -> bytes: attr = {} if attr is None else dict(attr) is_vote = attr.get('vote-status') == 'vote'
@@ -1170,10 +1173,10 @@ class NetworkStatusDocumentV3(NetworkStatusDocument): return desc_content
@classmethod - def create(cls: Type['stem.descriptor.NetworkStatusDocumentV3'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), validate: bool = True, authorities: Optional[Sequence['stem.directory.DirectoryAuthority']] = None, routers: Optional[Sequence['stem.descriptor.router_status_entry.RouterStatusEntryV3']] = None) -> 'stem.descriptor.NetworkStatusDocumentV3': + def create(cls: Type['stem.descriptor.networkstatus.NetworkStatusDocumentV3'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), validate: bool = True, authorities: Optional[Sequence['stem.descriptor.networkstatus.DirectoryAuthority']] = None, routers: Optional[Sequence['stem.descriptor.router_status_entry.RouterStatusEntryV3']] = None) -> 'stem.descriptor.networkstatus.NetworkStatusDocumentV3': return cls(cls.content(attr, exclude, authorities, routers), validate = validate)
- def __init__(self, raw_content: str, validate: bool = False, default_params: bool = True) -> None: + def __init__(self, raw_content: bytes, validate: bool = False, default_params: bool = True) -> None: """ Parse a v3 network status document.
@@ -1188,13 +1191,15 @@ class NetworkStatusDocumentV3(NetworkStatusDocument): super(NetworkStatusDocumentV3, self).__init__(raw_content, lazy_load = not validate) document_file = io.BytesIO(raw_content)
+ self._header_entries = None # type: Optional[ENTRY_TYPE] + self._default_params = default_params self._header(document_file, validate)
self.directory_authorities = tuple(stem.descriptor.router_status_entry._parse_file( document_file, validate, - entry_class = DirectoryAuthority, + entry_class = DirectoryAuthority, # type: ignore # TODO: move to another parse_file() entry_keyword = AUTH_START, section_end_keywords = (ROUTERS_START, FOOTER_START, V2_FOOTER_START), extra_args = (self.is_vote,), @@ -1255,13 +1260,13 @@ class NetworkStatusDocumentV3(NetworkStatusDocument):
return self.valid_after < datetime.datetime.utcnow() < self.fresh_until
- def validate_signatures(self, key_certs: Sequence['stem.descriptor.networkstatus.KeyCertificates']) -> None: + def validate_signatures(self, key_certs: Sequence['stem.descriptor.networkstatus.KeyCertificate']) -> None: """ Validates we're properly signed by the signing certificates.
.. versionadded:: 1.6.0
- :param list key_certs: :class:`~stem.descriptor.networkstatus.KeyCertificates` + :param list key_certs: :class:`~stem.descriptor.networkstatus.KeyCertificate` to validate the consensus against
:raises: **ValueError** if an insufficient number of valid signatures are present. @@ -1289,7 +1294,7 @@ class NetworkStatusDocumentV3(NetworkStatusDocument): if valid_digests < required_digests: raise ValueError('Network Status Document has %i valid signatures out of %i total, needed %i' % (valid_digests, total_digests, required_digests))
- def get_unrecognized_lines(self) -> Sequence[str]: + def get_unrecognized_lines(self) -> List[str]: if self._lazy_loading: self._parse(self._header_entries, False, parser_for_line = self._HEADER_PARSER_FOR_LINE) self._parse(self._footer_entries, False, parser_for_line = self._FOOTER_PARSER_FOR_LINE) @@ -1308,10 +1313,10 @@ class NetworkStatusDocumentV3(NetworkStatusDocument): :returns: **True** if we meet the given consensus-method, and **False** otherwise """
- if self.consensus_method is not None: - return self.consensus_method >= method - elif self.consensus_methods is not None: - return bool([x for x in self.consensus_methods if x >= method]) + if self.consensus_method is not None: # type: ignore + return self.consensus_method >= method # type: ignore + elif self.consensus_methods is not None: # type: ignore + return bool([x for x in self.consensus_methods if x >= method]) # type: ignore else: return False # malformed document
@@ -1341,9 +1346,9 @@ class NetworkStatusDocumentV3(NetworkStatusDocument):
# default consensus_method and consensus_methods based on if we're a consensus or vote
- if self.is_consensus and not self.consensus_method: + if self.is_consensus and not self.consensus_method: # type: ignore self.consensus_method = 1 - elif self.is_vote and not self.consensus_methods: + elif self.is_vote and not self.consensus_methods: # type: ignore self.consensus_methods = [1] else: self._header_entries = entries @@ -1400,7 +1405,7 @@ class NetworkStatusDocumentV3(NetworkStatusDocument): raise ValueError("'%s' value on the params line must be in the range of %i - %i, was %i" % (key, minimum, maximum, value))
-def _check_for_missing_and_disallowed_fields(document: 'stem.descriptor.networkstatus.NetworkStatusDocumentV3', entries: Mapping[str, str], fields: Sequence[str]) -> None: +def _check_for_missing_and_disallowed_fields(document: 'stem.descriptor.networkstatus.NetworkStatusDocumentV3', entries: ENTRY_TYPE, fields: Sequence[Tuple[str, bool, bool, bool]]) -> None: """ Checks that we have mandatory fields for our type, and that we don't have any fields exclusive to the other (ie, no vote-only fields appear in a @@ -1438,7 +1443,8 @@ def _parse_int_mappings(keyword: str, value: str, validate: bool) -> Dict[str, i # - values are integers # - keys are sorted in lexical order
- results, seen_keys = {}, [] + results = {} # type: Dict[str, int] + seen_keys = [] # type: List[str] error_template = "Unable to parse network status document's '%s' line (%%s): %s'" % (keyword, value)
for key, val in _mappings_for(keyword, value): @@ -1463,7 +1469,7 @@ def _parse_int_mappings(keyword: str, value: str, validate: bool) -> Dict[str, i return results
-def _parse_dirauth_source_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_dirauth_source_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "dir-source" nickname identity address IP dirport orport
value = _value('dir-source', entries) @@ -1582,7 +1588,7 @@ class DirectoryAuthority(Descriptor): }
@classmethod - def content(cls: Type['stem.descriptor.networkstatus.DirectoryAuthority'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), is_vote: bool = False) -> str: + def content(cls: Type['stem.descriptor.networkstatus.DirectoryAuthority'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), is_vote: bool = False) -> bytes: attr = {} if attr is None else dict(attr)
# include mandatory 'vote-digest' if a consensus @@ -1604,7 +1610,7 @@ class DirectoryAuthority(Descriptor): def create(cls: Type['stem.descriptor.networkstatus.DirectoryAuthority'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), validate: bool = True, is_vote: bool = False) -> 'stem.descriptor.networkstatus.DirectoryAuthority': return cls(cls.content(attr, exclude, is_vote), validate = validate, is_vote = is_vote)
- def __init__(self, raw_content: str, validate: bool = False, is_vote: bool = False) -> None: + def __init__(self, raw_content: bytes, validate: bool = False, is_vote: bool = False) -> None: """ Parse a directory authority entry in a v3 network status document.
@@ -1623,12 +1629,12 @@ class DirectoryAuthority(Descriptor): key_div = content.find('\ndir-key-certificate-version')
if key_div != -1: - self.key_certificate = KeyCertificate(content[key_div + 1:], validate) + self.key_certificate = KeyCertificate(content[key_div + 1:].encode('utf-8'), validate) content = content[:key_div + 1] else: self.key_certificate = None
- entries = _descriptor_components(content, validate) + entries = _descriptor_components(content.encode('utf-8'), validate)
if validate and 'dir-source' != list(entries.keys())[0]: raise ValueError("Authority entries are expected to start with a 'dir-source' line:\n%s" % (content)) @@ -1679,7 +1685,7 @@ class DirectoryAuthority(Descriptor): self._entries = entries
-def _parse_dir_address_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_dir_address_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "dir-address" IPPort
value = _value('dir-address', entries) @@ -1754,7 +1760,7 @@ class KeyCertificate(Descriptor): }
@classmethod - def content(cls: Type['stem.descriptor.networkstatus.KeyCertificate'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> str: + def content(cls: Type['stem.descriptor.networkstatus.KeyCertificate'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> bytes: return _descriptor_content(attr, exclude, ( ('dir-key-certificate-version', '3'), ('fingerprint', _random_fingerprint()), @@ -1766,26 +1772,26 @@ class KeyCertificate(Descriptor): ('dir-key-certification', _random_crypto_blob('SIGNATURE')), ))
- def __init__(self, raw_content: str, validate: str = False) -> None: + def __init__(self, raw_content: bytes, validate: bool = False) -> None: super(KeyCertificate, self).__init__(raw_content, lazy_load = not validate) entries = _descriptor_components(raw_content, validate)
if validate: if 'dir-key-certificate-version' != list(entries.keys())[0]: - raise ValueError("Key certificates must start with a 'dir-key-certificate-version' line:\n%s" % (raw_content)) + raise ValueError("Key certificates must start with a 'dir-key-certificate-version' line:\n%s" % stem.util.str_tools._to_unicode(raw_content)) elif 'dir-key-certification' != list(entries.keys())[-1]: - raise ValueError("Key certificates must end with a 'dir-key-certification' line:\n%s" % (raw_content)) + raise ValueError("Key certificates must end with a 'dir-key-certification' line:\n%s" % stem.util.str_tools._to_unicode(raw_content))
# check that we have mandatory fields and that our known fields only # appear once
for keyword, is_mandatory in KEY_CERTIFICATE_PARAMS: if is_mandatory and keyword not in entries: - raise ValueError("Key certificates must have a '%s' line:\n%s" % (keyword, raw_content)) + raise ValueError("Key certificates must have a '%s' line:\n%s" % (keyword, stem.util.str_tools._to_unicode(raw_content)))
entry_count = len(entries.get(keyword, [])) if entry_count > 1: - raise ValueError("Key certificates can only have a single '%s' line, got %i:\n%s" % (keyword, entry_count, raw_content)) + raise ValueError("Key certificates can only have a single '%s' line, got %i:\n%s" % (keyword, entry_count, stem.util.str_tools._to_unicode(raw_content)))
self._parse(entries, validate) else: @@ -1887,7 +1893,7 @@ class DetachedSignature(Descriptor): 'additional_digests': ([], _parse_additional_digests), 'additional_signatures': ([], _parse_additional_signatures), 'signatures': ([], _parse_footer_directory_signature_line), - } + } # type: Dict[str, Tuple[Any, Callable[['stem.descriptor.Descriptor', ENTRY_TYPE], None]]]
PARSER_FOR_LINE = { 'consensus-digest': _parse_consensus_digest_line, @@ -1900,7 +1906,7 @@ class DetachedSignature(Descriptor): }
@classmethod - def content(cls: Type['stem.descriptor.networkstatus.DetachedSignature'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> str: + def content(cls: Type['stem.descriptor.networkstatus.DetachedSignature'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> bytes: return _descriptor_content(attr, exclude, ( ('consensus-digest', '6D3CC0EFA408F228410A4A8145E1B0BB0670E442'), ('valid-after', _random_date()), @@ -1908,23 +1914,23 @@ class DetachedSignature(Descriptor): ('valid-until', _random_date()), ))
- def __init__(self, raw_content: str, validate: bool = False) -> None: + def __init__(self, raw_content: bytes, validate: bool = False) -> None: super(DetachedSignature, self).__init__(raw_content, lazy_load = not validate) entries = _descriptor_components(raw_content, validate)
if validate: if 'consensus-digest' != list(entries.keys())[0]: - raise ValueError("Detached signatures must start with a 'consensus-digest' line:\n%s" % (raw_content)) + raise ValueError("Detached signatures must start with a 'consensus-digest' line:\n%s" % stem.util.str_tools._to_unicode(raw_content))
# check that we have mandatory fields and certain fields only appear once
for keyword, is_mandatory, is_multiple in DETACHED_SIGNATURE_PARAMS: if is_mandatory and keyword not in entries: - raise ValueError("Detached signatures must have a '%s' line:\n%s" % (keyword, raw_content)) + raise ValueError("Detached signatures must have a '%s' line:\n%s" % (keyword, stem.util.str_tools._to_unicode(raw_content)))
entry_count = len(entries.get(keyword, [])) if not is_multiple and entry_count > 1: - raise ValueError("Detached signatures can only have a single '%s' line, got %i:\n%s" % (keyword, entry_count, raw_content)) + raise ValueError("Detached signatures can only have a single '%s' line, got %i:\n%s" % (keyword, entry_count, stem.util.str_tools._to_unicode(raw_content)))
self._parse(entries, validate) else: @@ -1943,7 +1949,7 @@ class BridgeNetworkStatusDocument(NetworkStatusDocument):
TYPE_ANNOTATION_NAME = 'bridge-network-status'
- def __init__(self, raw_content: str, validate: bool = False) -> None: + def __init__(self, raw_content: bytes, validate: bool = False) -> None: super(BridgeNetworkStatusDocument, self).__init__(raw_content)
self.published = None diff --git a/stem/descriptor/remote.py b/stem/descriptor/remote.py index f3c6d6bd..2e2bb53b 100644 --- a/stem/descriptor/remote.py +++ b/stem/descriptor/remote.py @@ -101,7 +101,7 @@ import stem.util.tor_tools
from stem.descriptor import Compression from stem.util import log, str_tools -from typing import Any, Dict, Iterator, Optional, Sequence, Tuple, Union +from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union
# Tor has a limited number of descriptors we can fetch explicitly by their # fingerprint or hashes due to a limit on the url length by squid proxies. @@ -371,7 +371,7 @@ class Query(object): the same as running **query.run(True)** (default is **False**) """
- def __init__(self, resource: str, descriptor_type: Optional[str] = None, endpoints: Optional[Sequence['stem.Endpoint']] = None, compression: Sequence['stem.descriptor.Compression'] = (Compression.GZIP,), retries: int = 2, fall_back_to_authority: bool = False, timeout: Optional[float] = None, start: bool = True, block: bool = False, validate: bool = False, document_handler: 'stem.descriptor.DocumentHandler' = stem.descriptor.DocumentHandler.ENTRIES, **kwargs: Any) -> None: + def __init__(self, resource: str, descriptor_type: Optional[str] = None, endpoints: Optional[Sequence[stem.Endpoint]] = None, compression: Union[stem.descriptor._Compression, Sequence[stem.descriptor._Compression]] = (Compression.GZIP,), retries: int = 2, fall_back_to_authority: bool = False, timeout: Optional[float] = None, start: bool = True, block: bool = False, validate: bool = False, document_handler: stem.descriptor.DocumentHandler = stem.descriptor.DocumentHandler.ENTRIES, **kwargs: Any) -> None: if not resource.startswith('/'): raise ValueError("Resources should start with a '/': %s" % resource)
@@ -380,8 +380,10 @@ class Query(object): resource = resource[:-2] elif isinstance(compression, tuple): compression = list(compression) - elif not isinstance(compression, list): + elif isinstance(compression, stem.descriptor._Compression): compression = [compression] # caller provided only a single option + else: + raise ValueError('Compression should be a list of stem.descriptor.Compression, was %s (%s)' % (compression, type(compression).__name__))
if Compression.ZSTD in compression and not Compression.ZSTD.available: compression.remove(Compression.ZSTD) @@ -411,21 +413,21 @@ class Query(object): self.retries = retries self.fall_back_to_authority = fall_back_to_authority
- self.content = None - self.error = None + self.content = None # type: Optional[bytes] + self.error = None # type: Optional[BaseException] self.is_done = False - self.download_url = None + self.download_url = None # type: Optional[str]
- self.start_time = None + self.start_time = None # type: Optional[float] self.timeout = timeout - self.runtime = None + self.runtime = None # type: Optional[float]
self.validate = validate self.document_handler = document_handler - self.reply_headers = None + self.reply_headers = None # type: Optional[Dict[str, str]] self.kwargs = kwargs
- self._downloader_thread = None + self._downloader_thread = None # type: Optional[threading.Thread] self._downloader_thread_lock = threading.RLock()
if start: @@ -450,7 +452,7 @@ class Query(object): self._downloader_thread.setDaemon(True) self._downloader_thread.start()
- def run(self, suppress: bool = False) -> Sequence['stem.descriptor.Descriptor']: + def run(self, suppress: bool = False) -> List['stem.descriptor.Descriptor']: """ Blocks until our request is complete then provides the descriptors. If we haven't yet started our request then this does so. @@ -470,7 +472,7 @@ class Query(object):
return list(self._run(suppress))
- def _run(self, suppress: bool) -> Iterator['stem.descriptor.Descriptor']: + def _run(self, suppress: bool) -> Iterator[stem.descriptor.Descriptor]: with self._downloader_thread_lock: self.start() self._downloader_thread.join() @@ -506,11 +508,11 @@ class Query(object):
raise self.error
- def __iter__(self) -> Iterator['stem.descriptor.Descriptor']: + def __iter__(self) -> Iterator[stem.descriptor.Descriptor]: for desc in self._run(True): yield desc
- def _pick_endpoint(self, use_authority: bool = False) -> 'stem.Endpoint': + def _pick_endpoint(self, use_authority: bool = False) -> stem.Endpoint: """ Provides an endpoint to query. If we have multiple endpoints then one is picked at random. @@ -576,7 +578,7 @@ class DescriptorDownloader(object): def __init__(self, use_mirrors: bool = False, **default_args: Any) -> None: self._default_args = default_args
- self._endpoints = None + self._endpoints = None # type: Optional[List[stem.DirPort]]
if use_mirrors: try: @@ -586,7 +588,7 @@ class DescriptorDownloader(object): except Exception as exc: log.debug('Unable to retrieve directory mirrors: %s' % exc)
- def use_directory_mirrors(self) -> 'stem.descriptor.networkstatus.NetworkStatusDocumentV3': + def use_directory_mirrors(self) -> stem.descriptor.networkstatus.NetworkStatusDocumentV3: """ Downloads the present consensus and configures ourselves to use directory mirrors, in addition to authorities. @@ -610,7 +612,7 @@ class DescriptorDownloader(object):
self._endpoints = list(new_endpoints)
- return consensus + return consensus # type: ignore
def their_server_descriptor(self, **query_args: Any) -> 'stem.descriptor.remote.Query': """ @@ -776,7 +778,7 @@ class DescriptorDownloader(object):
return consensus_query
- def get_vote(self, authority: 'stem.directory.Authority', **query_args: Any) -> 'stem.descriptor.remote.Query': + def get_vote(self, authority: stem.directory.Authority, **query_args: Any) -> 'stem.descriptor.remote.Query': """ Provides the present vote for a given directory authority.
@@ -924,7 +926,7 @@ class DescriptorDownloader(object): return Query(resource, **args)
-def _download_from_orport(endpoint: 'stem.ORPort', compression: Sequence['stem.Compression'], resource: str) -> Tuple[bytes, Dict[str, str]]: +def _download_from_orport(endpoint: stem.ORPort, compression: Sequence[stem.descriptor._Compression], resource: str) -> Tuple[bytes, Dict[str, str]]: """ Downloads descriptors from the given orport. Payload is just like an http response (headers and all)... @@ -974,7 +976,7 @@ def _download_from_orport(endpoint: 'stem.ORPort', compression: Sequence['stem.C
for line in str_tools._to_unicode(header_data).splitlines(): if ': ' not in line: - raise stem.ProtocolError("'%s' is not a HTTP header:\n\n%s" % line) + raise stem.ProtocolError("'%s' is not a HTTP header:\n\n%s" % (line, header_data.decode('utf-8')))
key, value = line.split(': ', 1) headers[key] = value @@ -982,7 +984,7 @@ def _download_from_orport(endpoint: 'stem.ORPort', compression: Sequence['stem.C return _decompress(body_data, headers.get('Content-Encoding')), headers
-def _download_from_dirport(url: str, compression: Sequence['stem.descriptor.Compression'], timeout: Optional[float]) -> Tuple[bytes, Dict[str, str]]: +def _download_from_dirport(url: str, compression: Sequence[stem.descriptor._Compression], timeout: Optional[float]) -> Tuple[bytes, Dict[str, str]]: """ Downloads descriptors from the given url.
@@ -1011,8 +1013,8 @@ def _download_from_dirport(url: str, compression: Sequence['stem.descriptor.Comp except socket.timeout as exc: raise stem.DownloadTimeout(url, exc, sys.exc_info()[2], timeout) except: - exc, stacktrace = sys.exc_info()[1:3] - raise stem.DownloadFailed(url, exc, stacktrace) + exception, stacktrace = sys.exc_info()[1:3] + raise stem.DownloadFailed(url, exception, stacktrace)
return _decompress(response.read(), response.headers.get('Content-Encoding')), response.headers
diff --git a/stem/descriptor/router_status_entry.py b/stem/descriptor/router_status_entry.py index 20822c82..2c4937f3 100644 --- a/stem/descriptor/router_status_entry.py +++ b/stem/descriptor/router_status_entry.py @@ -27,9 +27,10 @@ import io import stem.exit_policy import stem.util.str_tools
-from typing import Any, BinaryIO, Dict, Iterator, Mapping, Optional, Sequence, Tuple, Type +from typing import Any, BinaryIO, Iterator, List, Mapping, Optional, Sequence, Tuple, Type, Union
from stem.descriptor import ( + ENTRY_TYPE, KEYWORD_LINE, Descriptor, _descriptor_content, @@ -37,7 +38,7 @@ from stem.descriptor import ( _values, _descriptor_components, _parse_protocol_line, - _read_until_keywords, + _read_until_keywords_with_ending_keyword, _random_nickname, _random_ipv4_address, _random_date, @@ -46,7 +47,7 @@ from stem.descriptor import ( _parse_pr_line = _parse_protocol_line('pr', 'protocols')
-def _parse_file(document_file: BinaryIO, validate: bool, entry_class: Type['stem.descriptor.router_status_entry.RouterStatusEntry'], entry_keyword: str = 'r', start_position: int = None, end_position: int = None, section_end_keywords: Sequence[str] = (), extra_args: Sequence[str] = ()) -> Iterator['stem.descriptor.router_status_entry.RouterStatusEntry']: +def _parse_file(document_file: BinaryIO, validate: bool, entry_class: Type['stem.descriptor.router_status_entry.RouterStatusEntry'], entry_keyword: str = 'r', start_position: Optional[int] = None, end_position: Optional[int] = None, section_end_keywords: Tuple[str, ...] = (), extra_args: Sequence[Any] = ()) -> Iterator['stem.descriptor.router_status_entry.RouterStatusEntry']: """ Reads a range of the document_file containing some number of entry_class instances. We deliminate the entry_class entries by the keyword on their @@ -93,7 +94,7 @@ def _parse_file(document_file: BinaryIO, validate: bool, entry_class: Type['stem return
while end_position is None or document_file.tell() < end_position: - desc_lines, ending_keyword = _read_until_keywords( + desc_lines, ending_keyword = _read_until_keywords_with_ending_keyword( (entry_keyword,) + section_end_keywords, document_file, ignore_first = True, @@ -113,7 +114,7 @@ def _parse_file(document_file: BinaryIO, validate: bool, entry_class: Type['stem break
-def _parse_r_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_r_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # Parses a RouterStatusEntry's 'r' line. They're very nearly identical for # all current entry types (v2, v3, and microdescriptor v3) with one little # wrinkle: only the microdescriptor flavor excludes a 'digest' field. @@ -165,7 +166,7 @@ def _parse_r_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, S raise ValueError("Publication time time wasn't parsable: r %s" % value)
-def _parse_a_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_a_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "a" SP address ":" portlist # example: a [2001:888:2133:0:82:94:251:204]:9001
@@ -188,7 +189,7 @@ def _parse_a_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, S descriptor.or_addresses = or_addresses
-def _parse_s_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_s_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "s" Flags # example: s Named Running Stable Valid
@@ -203,7 +204,7 @@ def _parse_s_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, S raise ValueError("%s had extra whitespace on its 's' line: s %s" % (descriptor._name(), value))
-def _parse_v_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_v_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "v" version # example: v Tor 0.2.2.35 # @@ -221,7 +222,7 @@ def _parse_v_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, S raise ValueError('%s has a malformed tor version (%s): v %s' % (descriptor._name(), exc, value))
-def _parse_w_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_w_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "w" "Bandwidth=" INT ["Measured=" INT] ["Unmeasured=1"] # example: w Bandwidth=7980
@@ -268,7 +269,7 @@ def _parse_w_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, S descriptor.unrecognized_bandwidth_entries = unrecognized_bandwidth_entries
-def _parse_p_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_p_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "p" ("accept" / "reject") PortList # # examples: @@ -284,7 +285,7 @@ def _parse_p_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, S raise ValueError('%s exit policy is malformed (%s): p %s' % (descriptor._name(), exc, value))
-def _parse_id_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_id_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "id" "ed25519" ed25519-identity # # examples: @@ -307,7 +308,7 @@ def _parse_id_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, raise ValueError("'id' lines should contain both the key type and digest: id %s" % value)
-def _parse_m_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_m_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "m" methods 1*(algorithm "=" digest) # example: m 8,9,10,11,12 sha256=g1vx9si329muxV3tquWIXXySNOIwRGMeAESKs/v4DWs
@@ -341,7 +342,7 @@ def _parse_m_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, S descriptor.microdescriptor_hashes = all_hashes
-def _parse_microdescriptor_m_line(descriptor: 'stem.descriptor.Descriptor', entries): +def _parse_microdescriptor_m_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "m" digest # example: m aiUklwBrua82obG5AsTX+iEpkjQA2+AQHxZ7GwMfY70
@@ -422,7 +423,7 @@ class RouterStatusEntry(Descriptor): }
@classmethod - def from_str(cls: Type['stem.descriptor.router_status_entry.RouterStatusEntry'], content: str, **kwargs: Any) -> 'stem.descriptor.router_status_entry.RouterStatusEntry': + def from_str(cls: Type['stem.descriptor.router_status_entry.RouterStatusEntry'], content: str, **kwargs: Any) -> Union['stem.descriptor.router_status_entry.RouterStatusEntry', List['stem.descriptor.router_status_entry.RouterStatusEntry']]: # type: ignore # Router status entries don't have their own @type annotation, so to make # our subclass from_str() work we need to do the type inferencing ourself.
@@ -442,7 +443,7 @@ class RouterStatusEntry(Descriptor): else: raise ValueError("Descriptor.from_str() expected a single descriptor, but had %i instead. Please include 'multiple = True' if you want a list of results instead." % len(results))
- def __init__(self, content: str, validate: bool = False, document: Optional['stem.descriptor.NetworkStatusDocument'] = None) -> None: + def __init__(self, content: bytes, validate: bool = False, document: Optional['stem.descriptor.networkstatus.NetworkStatusDocument'] = None) -> None: """ Parse a router descriptor in a network status document.
@@ -481,14 +482,14 @@ class RouterStatusEntry(Descriptor):
return 'Router status entries' if is_plural else 'Router status entry'
- def _required_fields(self) -> Tuple[str]: + def _required_fields(self) -> Tuple[str, ...]: """ Provides lines that must appear in the descriptor. """
return ()
- def _single_fields(self) -> Tuple[str]: + def _single_fields(self) -> Tuple[str, ...]: """ Provides lines that can only appear in the descriptor once. """ @@ -514,7 +515,7 @@ class RouterStatusEntryV2(RouterStatusEntry): })
@classmethod - def content(cls: Type['stem.descriptor.router_status_entry.RouterStatusEntryV2'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> str: + def content(cls: Type['stem.descriptor.router_status_entry.RouterStatusEntryV2'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> bytes: return _descriptor_content(attr, exclude, ( ('r', '%s p1aag7VwarGxqctS7/fS0y5FU+s oQZFLYe9e4A7bOkWKR7TaNxb0JE %s %s 9001 0' % (_random_nickname(), _random_date(), _random_ipv4_address())), )) @@ -522,10 +523,10 @@ class RouterStatusEntryV2(RouterStatusEntry): def _name(self, is_plural: bool = False) -> str: return 'Router status entries (v2)' if is_plural else 'Router status entry (v2)'
- def _required_fields(self) -> Tuple[str]: + def _required_fields(self) -> Tuple[str, ...]: return ('r',)
- def _single_fields(self) -> Tuple[str]: + def _single_fields(self) -> Tuple[str, ...]: return ('r', 's', 'v')
@@ -579,7 +580,7 @@ class RouterStatusEntryV3(RouterStatusEntry):
TYPE_ANNOTATION_NAME = 'network-status-consensus-3'
- ATTRIBUTES = dict(RouterStatusEntry.ATTRIBUTES, **{ + ATTRIBUTES = dict(RouterStatusEntry.ATTRIBUTES, **{ # type: ignore 'digest': (None, _parse_r_line), 'or_addresses': ([], _parse_a_line), 'identifier_type': (None, _parse_id_line), @@ -595,7 +596,7 @@ class RouterStatusEntryV3(RouterStatusEntry): 'microdescriptor_hashes': ([], _parse_m_line), })
- PARSER_FOR_LINE = dict(RouterStatusEntry.PARSER_FOR_LINE, **{ + PARSER_FOR_LINE = dict(RouterStatusEntry.PARSER_FOR_LINE, **{ # type: ignore 'a': _parse_a_line, 'w': _parse_w_line, 'p': _parse_p_line, @@ -605,7 +606,7 @@ class RouterStatusEntryV3(RouterStatusEntry): })
@classmethod - def content(cls: Type['stem.descriptor.router_status_entry.RouterStatusEntryV3'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> str: + def content(cls: Type['stem.descriptor.router_status_entry.RouterStatusEntryV3'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> bytes: return _descriptor_content(attr, exclude, ( ('r', '%s p1aag7VwarGxqctS7/fS0y5FU+s oQZFLYe9e4A7bOkWKR7TaNxb0JE %s %s 9001 0' % (_random_nickname(), _random_date(), _random_ipv4_address())), ('s', 'Fast Named Running Stable Valid'), @@ -614,10 +615,10 @@ class RouterStatusEntryV3(RouterStatusEntry): def _name(self, is_plural: bool = False) -> str: return 'Router status entries (v3)' if is_plural else 'Router status entry (v3)'
- def _required_fields(self) -> Tuple[str]: + def _required_fields(self) -> Tuple[str, ...]: return ('r', 's')
- def _single_fields(self) -> Tuple[str]: + def _single_fields(self) -> Tuple[str, ...]: return ('r', 's', 'v', 'w', 'p', 'pr')
@@ -652,7 +653,7 @@ class RouterStatusEntryMicroV3(RouterStatusEntry):
TYPE_ANNOTATION_NAME = 'network-status-microdesc-consensus-3'
- ATTRIBUTES = dict(RouterStatusEntry.ATTRIBUTES, **{ + ATTRIBUTES = dict(RouterStatusEntry.ATTRIBUTES, **{ # type: ignore 'or_addresses': ([], _parse_a_line), 'bandwidth': (None, _parse_w_line), 'measured': (None, _parse_w_line), @@ -662,7 +663,7 @@ class RouterStatusEntryMicroV3(RouterStatusEntry): 'microdescriptor_digest': (None, _parse_microdescriptor_m_line), })
- PARSER_FOR_LINE = dict(RouterStatusEntry.PARSER_FOR_LINE, **{ + PARSER_FOR_LINE = dict(RouterStatusEntry.PARSER_FOR_LINE, **{ # type: ignore 'a': _parse_a_line, 'w': _parse_w_line, 'm': _parse_microdescriptor_m_line, @@ -670,7 +671,7 @@ class RouterStatusEntryMicroV3(RouterStatusEntry): })
@classmethod - def content(cls: Type['stem.descriptor.router_status_entry.RouterStatusEntryMicroV3'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> str: + def content(cls: Type['stem.descriptor.router_status_entry.RouterStatusEntryMicroV3'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> bytes: return _descriptor_content(attr, exclude, ( ('r', '%s ARIJF2zbqirB9IwsW0mQznccWww %s %s 9001 9030' % (_random_nickname(), _random_date(), _random_ipv4_address())), ('m', 'aiUklwBrua82obG5AsTX+iEpkjQA2+AQHxZ7GwMfY70'), @@ -680,8 +681,8 @@ class RouterStatusEntryMicroV3(RouterStatusEntry): def _name(self, is_plural: bool = False) -> str: return 'Router status entries (micro v3)' if is_plural else 'Router status entry (micro v3)'
- def _required_fields(self) -> Tuple[str]: + def _required_fields(self) -> Tuple[str, ...]: return ('r', 's', 'm')
- def _single_fields(self) -> Tuple[str]: + def _single_fields(self) -> Tuple[str, ...]: return ('r', 's', 'v', 'w', 'm', 'pr') diff --git a/stem/descriptor/server_descriptor.py b/stem/descriptor/server_descriptor.py index 11b44972..fbb5c633 100644 --- a/stem/descriptor/server_descriptor.py +++ b/stem/descriptor/server_descriptor.py @@ -61,16 +61,17 @@ import stem.version
from stem.descriptor.certificate import Ed25519Certificate from stem.descriptor.router_status_entry import RouterStatusEntryV3 -from typing import Any, BinaryIO, Dict, Iterator, Optional, Mapping, Sequence, Tuple, Type, Union +from typing import Any, BinaryIO, Iterator, Optional, Mapping, Sequence, Tuple, Type, Union
from stem.descriptor import ( + ENTRY_TYPE, PGP_BLOCK_END, Descriptor, DigestHash, DigestEncoding, create_signing_key, _descriptor_content, - _descriptor_components, + _descriptor_components_with_extra, _read_until_keywords, _bytes_for_block, _value, @@ -214,14 +215,17 @@ def _parse_file(descriptor_file: BinaryIO, is_bridge: bool = False, validate: bo descriptor_text = bytes.join(b'', descriptor_content)
if is_bridge: - yield BridgeDescriptor(descriptor_text, validate, **kwargs) + if kwargs: + raise ValueError('BUG: keyword arguments unused by bridge descriptors') + + yield BridgeDescriptor(descriptor_text, validate) else: yield RelayDescriptor(descriptor_text, validate, **kwargs) else: break # done parsing descriptors
-def _parse_router_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_router_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "router" nickname address ORPort SocksPort DirPort
value = _value('router', entries) @@ -247,7 +251,7 @@ def _parse_router_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[s descriptor.dir_port = None if router_comp[4] == '0' else int(router_comp[4])
-def _parse_bandwidth_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_bandwidth_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "bandwidth" bandwidth-avg bandwidth-burst bandwidth-observed
value = _value('bandwidth', entries) @@ -267,7 +271,7 @@ def _parse_bandwidth_line(descriptor: 'stem.descriptor.Descriptor', entries: Dic descriptor.observed_bandwidth = int(bandwidth_comp[2])
-def _parse_platform_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_platform_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "platform" string
_parse_bytes_line('platform', 'platform')(descriptor, entries) @@ -293,7 +297,7 @@ def _parse_platform_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict pass
-def _parse_fingerprint_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_fingerprint_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # This is forty hex digits split into space separated groups of four. # Checking that we match this pattern.
@@ -310,7 +314,7 @@ def _parse_fingerprint_line(descriptor: 'stem.descriptor.Descriptor', entries: D descriptor.fingerprint = fingerprint
-def _parse_extrainfo_digest_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_extrainfo_digest_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: value = _value('extra-info-digest', entries) digest_comp = value.split(' ')
@@ -321,7 +325,7 @@ def _parse_extrainfo_digest_line(descriptor: 'stem.descriptor.Descriptor', entri descriptor.extra_info_sha256_digest = digest_comp[1] if len(digest_comp) >= 2 else None
-def _parse_hibernating_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_hibernating_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: # "hibernating" 0|1 (in practice only set if one)
value = _value('hibernating', entries) @@ -332,7 +336,7 @@ def _parse_hibernating_line(descriptor: 'stem.descriptor.Descriptor', entries: D descriptor.hibernating = value == '1'
-def _parse_protocols_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_protocols_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: value = _value('protocols', entries) protocols_match = re.match('^Link (.*) Circuit (.*)$', value)
@@ -344,7 +348,7 @@ def _parse_protocols_line(descriptor: 'stem.descriptor.Descriptor', entries: Dic descriptor.circuit_protocols = circuit_versions.split(' ')
-def _parse_or_address_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_or_address_line(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: all_values = _values('or-address', entries) or_addresses = []
@@ -367,7 +371,7 @@ def _parse_or_address_line(descriptor: 'stem.descriptor.Descriptor', entries: Di descriptor.or_addresses = or_addresses
-def _parse_history_line(keyword: str, history_end_attribute: str, history_interval_attribute: str, history_values_attribute: str, descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_history_line(keyword: str, history_end_attribute: str, history_interval_attribute: str, history_values_attribute: str, descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: value = _value(keyword, entries) timestamp, interval, remainder = stem.descriptor.extrainfo_descriptor._parse_timestamp_and_interval(keyword, value)
@@ -384,7 +388,7 @@ def _parse_history_line(keyword: str, history_end_attribute: str, history_interv setattr(descriptor, history_values_attribute, history_values)
-def _parse_exit_policy(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: +def _parse_exit_policy(descriptor: 'stem.descriptor.Descriptor', entries: ENTRY_TYPE) -> None: if hasattr(descriptor, '_unparsed_exit_policy'): if descriptor._unparsed_exit_policy and stem.util.str_tools._to_unicode(descriptor._unparsed_exit_policy[0]) == 'reject *:*': descriptor.exit_policy = REJECT_ALL_POLICY @@ -577,7 +581,7 @@ class ServerDescriptor(Descriptor): 'eventdns': _parse_eventdns_line, }
- def __init__(self, raw_contents: str, validate: bool = False) -> None: + def __init__(self, raw_contents: bytes, validate: bool = False) -> None: """ Server descriptor constructor, created from an individual relay's descriptor content (as provided by 'GETINFO desc/*', cached descriptors, @@ -604,7 +608,7 @@ class ServerDescriptor(Descriptor): # influences the resulting exit policy, but for everything else the order # does not matter so breaking it into key / value pairs.
- entries, self._unparsed_exit_policy = _descriptor_components(stem.util.str_tools._to_unicode(raw_contents), validate, extra_keywords = ('accept', 'reject'), non_ascii_fields = ('contact', 'platform')) + entries, self._unparsed_exit_policy = _descriptor_components_with_extra(raw_contents, validate, extra_keywords = ('accept', 'reject'), non_ascii_fields = ('contact', 'platform'))
if validate: self._parse(entries, validate) @@ -622,7 +626,7 @@ class ServerDescriptor(Descriptor): else: self._entries = entries
- def digest(self, hash_type: 'stem.descriptor.DigestHash' = DigestHash.SHA1, encoding: 'stem.descriptor.DigestEncoding' = DigestEncoding.HEX) -> Union[str, 'hashlib.HASH']: + def digest(self, hash_type: 'stem.descriptor.DigestHash' = DigestHash.SHA1, encoding: 'stem.descriptor.DigestEncoding' = DigestEncoding.HEX) -> Union[str, 'hashlib._HASH']: # type: ignore """ Digest of this descriptor's content. These are referenced by...
@@ -642,7 +646,7 @@ class ServerDescriptor(Descriptor):
raise NotImplementedError('Unsupported Operation: this should be implemented by the ServerDescriptor subclass')
- def _check_constraints(self, entries: Dict[str, Sequence[str]]) -> None: + def _check_constraints(self, entries: ENTRY_TYPE) -> None: """ Does a basic check that the entries conform to this descriptor type's constraints. @@ -680,16 +684,16 @@ class ServerDescriptor(Descriptor): # Constraints that the descriptor must meet to be valid. These can be None if # not applicable.
- def _required_fields(self) -> Tuple[str]: + def _required_fields(self) -> Tuple[str, ...]: return REQUIRED_FIELDS
- def _single_fields(self) -> Tuple[str]: + def _single_fields(self) -> Tuple[str, ...]: return REQUIRED_FIELDS + SINGLE_FIELDS
def _first_keyword(self) -> str: return 'router'
- def _last_keyword(self) -> str: + def _last_keyword(self) -> Optional[str]: return 'router-signature'
@@ -754,7 +758,7 @@ class RelayDescriptor(ServerDescriptor): 'router-signature': _parse_router_signature_line, })
- def __init__(self, raw_contents: str, validate: bool = False, skip_crypto_validation: bool = False) -> None: + def __init__(self, raw_contents: bytes, validate: bool = False, skip_crypto_validation: bool = False) -> None: super(RelayDescriptor, self).__init__(raw_contents, validate)
if validate: @@ -786,9 +790,8 @@ class RelayDescriptor(ServerDescriptor): pass # cryptography module unavailable
@classmethod - def content(cls: Type['stem.descriptor.server_descriptor.RelayDescriptor'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), sign: bool = False, signing_key: Optional['stem.descriptor.SigningKey'] = None, exit_policy: Optional['stem.exit_policy.ExitPolicy'] = None) -> str: - if attr is None: - attr = {} + def content(cls: Type['stem.descriptor.server_descriptor.RelayDescriptor'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), sign: bool = False, signing_key: Optional['stem.descriptor.SigningKey'] = None, exit_policy: Optional['stem.exit_policy.ExitPolicy'] = None) -> bytes: + attr = dict(attr) if attr else {}
if exit_policy is None: exit_policy = REJECT_ALL_POLICY @@ -798,7 +801,7 @@ class RelayDescriptor(ServerDescriptor): ('published', _random_date()), ('bandwidth', '153600 256000 104590'), ] + [ - tuple(line.split(' ', 1)) for line in str(exit_policy).splitlines() + tuple(line.split(' ', 1)) for line in str(exit_policy).splitlines() # type: ignore ] + [ ('onion-key', _random_crypto_blob('RSA PUBLIC KEY')), ('signing-key', _random_crypto_blob('RSA PUBLIC KEY')), @@ -832,7 +835,7 @@ class RelayDescriptor(ServerDescriptor): return cls(cls.content(attr, exclude, sign, signing_key, exit_policy), validate = validate, skip_crypto_validation = not sign)
@functools.lru_cache() - def digest(self, hash_type: 'stem.descriptor.DigestHash' = DigestHash.SHA1, encoding: 'stem.descriptor.DigestEncoding' = DigestEncoding.HEX) -> Union[str, 'hashlib.HASH']: + def digest(self, hash_type: 'stem.descriptor.DigestHash' = DigestHash.SHA1, encoding: 'stem.descriptor.DigestEncoding' = DigestEncoding.HEX) -> Union[str, 'hashlib._HASH']: # type: ignore """ Provides the digest of our descriptor's content.
@@ -889,7 +892,7 @@ class RelayDescriptor(ServerDescriptor): if self.certificate: attr['id'] = 'ed25519 %s' % _truncated_b64encode(self.certificate.key)
- return RouterStatusEntryV3.create(attr) + return RouterStatusEntryV3.create(attr) # type: ignore
@functools.lru_cache() def _onion_key_crosscert_digest(self) -> str: @@ -906,7 +909,7 @@ class RelayDescriptor(ServerDescriptor): data = signing_key_digest + base64.b64decode(stem.util.str_tools._to_bytes(self.ed25519_master_key) + b'=') return stem.util.str_tools._to_unicode(binascii.hexlify(data).upper())
- def _check_constraints(self, entries: Dict[str, Sequence[str]]) -> None: + def _check_constraints(self, entries: ENTRY_TYPE) -> None: super(RelayDescriptor, self)._check_constraints(entries)
if self.certificate: @@ -945,7 +948,7 @@ class BridgeDescriptor(ServerDescriptor): })
@classmethod - def content(cls: Type['stem.descriptor.server_descriptor.BridgeDescriptor'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> str: + def content(cls: Type['stem.descriptor.server_descriptor.BridgeDescriptor'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> bytes: return _descriptor_content(attr, exclude, ( ('router', '%s %s 9001 0 0' % (_random_nickname(), _random_ipv4_address())), ('router-digest', '006FD96BA35E7785A6A3B8B75FE2E2435A13BDB4'), @@ -954,7 +957,7 @@ class BridgeDescriptor(ServerDescriptor): ('reject', '*:*'), ))
- def digest(self, hash_type: 'stem.descriptor.DigestHash' = DigestHash.SHA1, encoding: 'stem.descriptor.DigestEncoding' = DigestEncoding.HEX) -> Union[str, 'hashlib.HASH']: + def digest(self, hash_type: 'stem.descriptor.DigestHash' = DigestHash.SHA1, encoding: 'stem.descriptor.DigestEncoding' = DigestEncoding.HEX) -> Union[str, 'hashlib._HASH']: # type: ignore if hash_type == DigestHash.SHA1 and encoding == DigestEncoding.HEX: return self._digest else: @@ -1007,7 +1010,7 @@ class BridgeDescriptor(ServerDescriptor):
return issues
- def _required_fields(self) -> Tuple[str]: + def _required_fields(self) -> Tuple[str, ...]: # bridge required fields are the same as a relay descriptor, minus items # excluded according to the format page
@@ -1023,8 +1026,8 @@ class BridgeDescriptor(ServerDescriptor):
return tuple(included_fields + [f for f in REQUIRED_FIELDS if f not in excluded_fields])
- def _single_fields(self) -> str: + def _single_fields(self) -> Tuple[str, ...]: return self._required_fields() + SINGLE_FIELDS
- def _last_keyword(self) -> str: + def _last_keyword(self) -> Optional[str]: return None diff --git a/stem/descriptor/tordnsel.py b/stem/descriptor/tordnsel.py index c36e343d..6b9d4ceb 100644 --- a/stem/descriptor/tordnsel.py +++ b/stem/descriptor/tordnsel.py @@ -10,13 +10,16 @@ exit list files. TorDNSEL - Exit list provided by TorDNSEL """
+import datetime + import stem.util.connection import stem.util.str_tools import stem.util.tor_tools
-from typing import Any, BinaryIO, Dict, Iterator, Sequence +from typing import Any, BinaryIO, Callable, Dict, Iterator, List, Optional, Tuple
from stem.descriptor import ( + ENTRY_TYPE, Descriptor, _read_until_keywords, _descriptor_components, @@ -35,6 +38,9 @@ def _parse_file(tordnsel_file: BinaryIO, validate: bool = False, **kwargs: Any) * **IOError** if the file can't be read """
+ if kwargs: + raise ValueError("TorDNSEL doesn't support additional arguments: %s" % kwargs) + # skip content prior to the first ExitNode _read_until_keywords('ExitNode', tordnsel_file, skip = True)
@@ -43,7 +49,7 @@ def _parse_file(tordnsel_file: BinaryIO, validate: bool = False, **kwargs: Any) contents += _read_until_keywords('ExitNode', tordnsel_file)
if contents: - yield TorDNSEL(bytes.join(b'', contents), validate, **kwargs) + yield TorDNSEL(bytes.join(b'', contents), validate) else: break # done parsing file
@@ -64,19 +70,21 @@ class TorDNSEL(Descriptor):
TYPE_ANNOTATION_NAME = 'tordnsel'
- def __init__(self, raw_contents: str, validate: bool) -> None: + def __init__(self, raw_contents: bytes, validate: bool) -> None: super(TorDNSEL, self).__init__(raw_contents) - raw_contents = stem.util.str_tools._to_unicode(raw_contents) entries = _descriptor_components(raw_contents, validate)
- self.fingerprint = None - self.published = None - self.last_status = None - self.exit_addresses = [] + self.fingerprint = None # type: Optional[str] + self.published = None # type: Optional[datetime.datetime] + self.last_status = None # type: Optional[datetime.datetime] + self.exit_addresses = [] # type: List[Tuple[str, datetime.datetime]]
self._parse(entries, validate)
- def _parse(self, entries: Dict[str, Sequence[str]], validate: bool) -> None: + def _parse(self, entries: ENTRY_TYPE, validate: bool, parser_for_line: Optional[Dict[str, Callable]] = None) -> None: + if parser_for_line: + raise ValueError('parser_for_line is unused by TorDNSEL') + for keyword, values in list(entries.items()): value, block_type, block_content = values[0]
@@ -102,7 +110,7 @@ class TorDNSEL(Descriptor): raise ValueError("LastStatus time wasn't parsable: %s" % value) elif keyword == 'ExitAddress': for value, block_type, block_content in values: - address, date = value.split(' ', 1) + address, date_str = value.split(' ', 1)
if validate: if not stem.util.connection.is_valid_ipv4_address(address): @@ -111,7 +119,7 @@ class TorDNSEL(Descriptor): raise ValueError('Unexpected block content: %s' % block_content)
try: - date = stem.util.str_tools._parse_timestamp(date) + date = stem.util.str_tools._parse_timestamp(date_str) self.exit_addresses.append((address, date)) except ValueError: if validate: diff --git a/stem/directory.py b/stem/directory.py index f96adfbb..3ecb0b71 100644 --- a/stem/directory.py +++ b/stem/directory.py @@ -49,7 +49,7 @@ import stem.util import stem.util.conf
from stem.util import connection, str_tools, tor_tools -from typing import Any, Callable, Dict, Iterator, Mapping, Optional, Pattern, Sequence, Tuple +from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Pattern, Sequence, Tuple, Union
GITWEB_AUTHORITY_URL = 'https://gitweb.torproject.org/tor.git/plain/src/app/config/auth_dirs.inc' GITWEB_FALLBACK_URL = 'https://gitweb.torproject.org/tor.git/plain/src/app/config/fallback_dirs.inc' @@ -69,7 +69,7 @@ FALLBACK_EXTRAINFO = re.compile('/\* extrainfo=([0-1]) \*/') FALLBACK_IPV6 = re.compile('" ipv6=\[([\da-f:]+)\]:(\d+)"')
-def _match_with(lines: Sequence[str], regexes: Sequence[Pattern], required: Optional[bool] = None) -> Dict[Pattern, Tuple[str]]: +def _match_with(lines: Sequence[str], regexes: Sequence[Pattern], required: Optional[Sequence[Pattern]] = None) -> Dict[Pattern, Union[str, List[str]]]: """ Scans the given content against a series of regex matchers, providing back a mapping of regexes to their capture groups. This maping is with the value if @@ -102,7 +102,7 @@ def _match_with(lines: Sequence[str], regexes: Sequence[Pattern], required: Opti return matches
-def _directory_entries(lines: Sequence[str], pop_section_func: Callable[[Sequence[str]], Sequence[str]], regexes: Sequence[Pattern], required: bool = None) -> Iterator[Dict[Pattern, Tuple[str]]]: +def _directory_entries(lines: List[str], pop_section_func: Callable[[List[str]], List[str]], regexes: Sequence[Pattern], required: Optional[Sequence[Pattern]] = None) -> Iterator[Dict[Pattern, Union[str, List[str]]]]: next_section = pop_section_func(lines)
while next_section: @@ -130,11 +130,11 @@ class Directory(object): :var int dir_port: port on which directory information is available :var str fingerprint: relay fingerprint :var str nickname: relay nickname - :var str orport_v6: **(address, port)** tuple for the directory's IPv6 + :var tuple orport_v6: **(address, port)** tuple for the directory's IPv6 ORPort, or **None** if it doesn't have one """
- def __init__(self, address: str, or_port: int, dir_port: int, fingerprint: str, nickname: str, orport_v6: str) -> None: + def __init__(self, address: str, or_port: Union[int, str], dir_port: Union[int, str], fingerprint: str, nickname: str, orport_v6: Tuple[str, int]) -> None: identifier = '%s (%s)' % (fingerprint, nickname) if nickname else fingerprint
if not connection.is_valid_ipv4_address(address): @@ -164,7 +164,7 @@ class Directory(object): self.orport_v6 = (orport_v6[0], int(orport_v6[1])) if orport_v6 else None
@staticmethod - def from_cache() -> Dict[str, 'stem.directory.Directory']: + def from_cache() -> Dict[str, Any]: """ Provides cached Tor directory information. This information is hardcoded into Tor and occasionally changes, so the information provided by this @@ -182,7 +182,7 @@ class Directory(object): raise NotImplementedError('Unsupported Operation: this should be implemented by the Directory subclass')
@staticmethod - def from_remote(timeout: int = 60) -> Dict[str, 'stem.directory.Directory']: + def from_remote(timeout: int = 60) -> Dict[str, Any]: """ Reads and parses tor's directory data `from gitweb.torproject.org https://gitweb.torproject.org/`_. Note that while convenient, this reliance on GitWeb means you should alway @@ -232,7 +232,7 @@ class Authority(Directory): :var str v3ident: identity key fingerprint used to sign votes and consensus """
- def __init__(self, address: Optional[str] = None, or_port: Optional[int] = None, dir_port: Optional[int] = None, fingerprint: Optional[str] = None, nickname: Optional[str] = None, orport_v6: Optional[int] = None, v3ident: Optional[str] = None) -> None: + def __init__(self, address: Optional[str] = None, or_port: Optional[Union[int, str]] = None, dir_port: Optional[Union[int, str]] = None, fingerprint: Optional[str] = None, nickname: Optional[str] = None, orport_v6: Optional[Tuple[str, int]] = None, v3ident: Optional[str] = None) -> None: super(Authority, self).__init__(address, or_port, dir_port, fingerprint, nickname, orport_v6)
if v3ident and not tor_tools.is_valid_fingerprint(v3ident): @@ -276,8 +276,8 @@ class Authority(Directory): dir_port = dir_port, fingerprint = fingerprint.replace(' ', ''), nickname = nickname, - orport_v6 = matches.get(AUTHORITY_IPV6), - v3ident = matches.get(AUTHORITY_V3IDENT), + orport_v6 = matches.get(AUTHORITY_IPV6), # type: ignore + v3ident = matches.get(AUTHORITY_V3IDENT), # type: ignore ) except ValueError as exc: raise IOError(str(exc)) @@ -285,7 +285,7 @@ class Authority(Directory): return results
@staticmethod - def _pop_section(lines: Sequence[str]) -> Sequence[str]: + def _pop_section(lines: List[str]) -> List[str]: """ Provides the next authority entry. """ @@ -349,7 +349,7 @@ class Fallback(Directory): :var collections.OrderedDict header: metadata about the fallback directory file this originated from """
- def __init__(self, address: Optional[str] = None, or_port: Optional[int] = None, dir_port: Optional[int] = None, fingerprint: Optional[str] = None, nickname: Optional[str] = None, has_extrainfo: bool = False, orport_v6: Optional[int] = None, header: Optional[Mapping[str, str]] = None) -> None: + def __init__(self, address: Optional[str] = None, or_port: Optional[Union[int, str]] = None, dir_port: Optional[Union[int, str]] = None, fingerprint: Optional[str] = None, nickname: Optional[str] = None, has_extrainfo: bool = False, orport_v6: Optional[Tuple[str, int]] = None, header: Optional[Mapping[str, str]] = None) -> None: super(Fallback, self).__init__(address, or_port, dir_port, fingerprint, nickname, orport_v6) self.has_extrainfo = has_extrainfo self.header = collections.OrderedDict(header) if header else collections.OrderedDict() @@ -440,9 +440,9 @@ class Fallback(Directory): or_port = int(or_port), dir_port = int(dir_port), fingerprint = fingerprint, - nickname = matches.get(FALLBACK_NICKNAME), + nickname = matches.get(FALLBACK_NICKNAME), # type: ignore has_extrainfo = matches.get(FALLBACK_EXTRAINFO) == '1', - orport_v6 = matches.get(FALLBACK_IPV6), + orport_v6 = matches.get(FALLBACK_IPV6), # type: ignore header = header, ) except ValueError as exc: @@ -451,7 +451,7 @@ class Fallback(Directory): return results
@staticmethod - def _pop_section(lines: Sequence[str]) -> Sequence[str]: + def _pop_section(lines: List[str]) -> List[str]: """ Provides lines up through the next divider. This excludes lines with just a comma since they're an artifact of these being C strings. @@ -514,7 +514,7 @@ class Fallback(Directory): return not self == other
-def _fallback_directory_differences(previous_directories: Sequence['stem.directory.Dirctory'], new_directories: Sequence['stem.directory.Directory']) -> str: +def _fallback_directory_differences(previous_directories: Mapping[str, 'stem.directory.Fallback'], new_directories: Mapping[str, 'stem.directory.Fallback']) -> str: """ Provides a description of how fallback directories differ. """ diff --git a/stem/exit_policy.py b/stem/exit_policy.py index 076611d2..19178c9a 100644 --- a/stem/exit_policy.py +++ b/stem/exit_policy.py @@ -71,7 +71,7 @@ import stem.util.connection import stem.util.enum import stem.util.str_tools
-from typing import Any, Iterator, Optional, Sequence, Union +from typing import Any, Iterator, List, Optional, Sequence, Set, Union
AddressType = stem.util.enum.Enum(('WILDCARD', 'Wildcard'), ('IPv4', 'IPv4'), ('IPv6', 'IPv6'))
@@ -167,6 +167,8 @@ class ExitPolicy(object): def __init__(self, *rules: Union[str, 'stem.exit_policy.ExitPolicyRule']) -> None: # sanity check the types
+ self._input_rules = None # type: Optional[Union[bytes, Sequence[Union[str, bytes, stem.exit_policy.ExitPolicyRule]]]] + for rule in rules: if not isinstance(rule, (bytes, str)) and not isinstance(rule, ExitPolicyRule): raise TypeError('Exit policy rules can only contain strings or ExitPolicyRules, got a %s (%s)' % (type(rule), rules)) @@ -183,13 +185,14 @@ class ExitPolicy(object): is_all_str = False
if rules and is_all_str: - byte_rules = [stem.util.str_tools._to_bytes(r) for r in rules] + byte_rules = [stem.util.str_tools._to_bytes(r) for r in rules] # type: ignore self._input_rules = zlib.compress(b','.join(byte_rules)) else: self._input_rules = rules
- self._rules = None - self._hash = None + self._policy_str = None # type: Optional[str] + self._rules = None # type: List[stem.exit_policy.ExitPolicyRule] + self._hash = None # type: Optional[int]
# Result when no rules apply. According to the spec policies default to 'is # allowed', but our microdescriptor policy subclass might want to change @@ -228,7 +231,7 @@ class ExitPolicy(object): otherwise. """
- rejected_ports = set() + rejected_ports = set() # type: Set[int]
for rule in self._get_rules(): if rule.is_accept: @@ -298,7 +301,8 @@ class ExitPolicy(object):
# convert port list to a list of ranges (ie, ['1-3'] rather than [1, 2, 3]) if display_ports: - display_ranges, temp_range = [], [] + display_ranges = [] + temp_range = [] # type: List[int] display_ports.sort() display_ports.append(None) # ending item to include last range in loop
@@ -384,23 +388,28 @@ class ExitPolicy(object): input_rules = self._input_rules
if self._rules is None and input_rules is not None: - rules = [] + rules = [] # type: List[stem.exit_policy.ExitPolicyRule] is_all_accept, is_all_reject = True, True + decompressed_rules = None # type: Optional[Sequence[Union[str, bytes, stem.exit_policy.ExitPolicyRule]]]
if isinstance(input_rules, bytes): decompressed_rules = zlib.decompress(input_rules).split(b',') else: decompressed_rules = input_rules
- for rule in decompressed_rules: - if isinstance(rule, bytes): - rule = stem.util.str_tools._to_unicode(rule) + for rule_val in decompressed_rules: + if isinstance(rule_val, bytes): + rule_val = stem.util.str_tools._to_unicode(rule_val)
- if isinstance(rule, (bytes, str)): - if not rule.strip(): + if isinstance(rule_val, str): + if not rule_val.strip(): continue
- rule = ExitPolicyRule(rule.strip()) + rule = ExitPolicyRule(rule_val.strip()) + elif isinstance(rule_val, stem.exit_policy.ExitPolicyRule): + rule = rule_val + else: + raise TypeError('BUG: unexpected type within decompressed policy: %s (%s)' % (stem.util.str_tools._to_unicode(rule_val), type(rule_val).__name__))
if rule.is_accept: is_all_reject = False @@ -446,9 +455,11 @@ class ExitPolicy(object): for rule in self._get_rules(): yield rule
- @functools.lru_cache() def __str__(self) -> str: - return ', '.join([str(rule) for rule in self._get_rules()]) + if self._policy_str is None: + self._policy_str = ', '.join([str(rule) for rule in self._get_rules()]) + + return self._policy_str
def __hash__(self) -> int: if self._hash is None: @@ -505,7 +516,7 @@ class MicroExitPolicy(ExitPolicy): # PortList ::= PortList "," PortOrRange # PortOrRange ::= INT "-" INT / INT
- self._policy = policy + policy_str = policy
if policy.startswith('accept'): self.is_accept = True @@ -517,7 +528,7 @@ class MicroExitPolicy(ExitPolicy): policy = policy[6:]
if not policy.startswith(' '): - raise ValueError('A microdescriptor exit policy should have a space separating accept/reject from its port list: %s' % self._policy) + raise ValueError('A microdescriptor exit policy should have a space separating accept/reject from its port list: %s' % policy_str)
policy = policy.lstrip()
@@ -538,9 +549,10 @@ class MicroExitPolicy(ExitPolicy):
super(MicroExitPolicy, self).__init__(*rules) self._is_allowed_default = not self.is_accept + self._policy_str = policy_str
def __str__(self) -> str: - return self._policy + return self._policy_str
def __hash__(self) -> int: return hash(str(self)) @@ -606,17 +618,17 @@ class ExitPolicyRule(object): if ':' not in exitpattern or ']' in exitpattern.rsplit(':', 1)[1]: raise ValueError("An exitpattern must be of the form 'addrspec:portspec': %s" % rule)
- self.address = None - self._address_type = None - self._masked_bits = None - self.min_port = self.max_port = None - self._hash = None + self.address = None # type: Optional[str] + self._address_type = None # type: Optional[stem.exit_policy.AddressType] + self._masked_bits = None # type: Optional[int] + self.min_port = self.max_port = None # type: Optional[int] + self._hash = None # type: Optional[int]
# Our mask in ip notation (ex. '255.255.255.0'). This is only set if we # either have a custom mask that can't be represented by a number of bits, # or the user has called mask(), lazily loading this.
- self._mask = None + self._mask = None # type: Optional[str]
# Malformed exit policies are rejected, but there's an exception where it's # just skipped: when an accept6/reject6 rule has an IPv4 address... diff --git a/stem/interpreter/__init__.py b/stem/interpreter/__init__.py index 2a3cff18..1d08abb6 100644 --- a/stem/interpreter/__init__.py +++ b/stem/interpreter/__init__.py @@ -54,13 +54,13 @@ def main() -> None: import stem.interpreter.commands
try: - args = stem.interpreter.arguments.parse(sys.argv[1:]) + args = stem.interpreter.arguments.Arguments.parse(sys.argv[1:]) except ValueError as exc: print(exc) sys.exit(1)
if args.print_help: - print(stem.interpreter.arguments.get_help()) + print(stem.interpreter.arguments.Arguments.get_help()) sys.exit()
if args.disable_color or not sys.stdout.isatty(): @@ -82,13 +82,11 @@ def main() -> None: if not args.run_cmd and not args.run_path: print(format(msg('msg.starting_tor'), *HEADER_OUTPUT))
- control_port = '9051' if args.control_port == 'default' else str(args.control_port) - try: stem.process.launch_tor_with_config( config = { 'SocksPort': '0', - 'ControlPort': control_port, + 'ControlPort': '9051' if args.control_port is None else str(args.control_port), 'CookieAuthentication': '1', 'ExitPolicy': 'reject *:*', }, @@ -115,7 +113,7 @@ def main() -> None: control_port = control_port, control_socket = control_socket, password_prompt = True, - ) + ) # type: stem.control.Controller
if controller is None: sys.exit(1) @@ -126,7 +124,7 @@ def main() -> None:
if args.run_cmd: if args.run_cmd.upper().startswith('SETEVENTS '): - controller._handle_event = lambda event_message: print(format(str(event_message), *STANDARD_OUTPUT)) + controller._handle_event = lambda event_message: print(format(str(event_message), *STANDARD_OUTPUT)) # type: ignore
if sys.stdout.isatty(): events = args.run_cmd.upper().split(' ', 1)[1] diff --git a/stem/interpreter/arguments.py b/stem/interpreter/arguments.py index 8ac1c2c1..dd0b19bb 100644 --- a/stem/interpreter/arguments.py +++ b/stem/interpreter/arguments.py @@ -5,103 +5,102 @@ Commandline argument parsing for our interpreter prompt. """
-import collections import getopt import os
import stem.interpreter import stem.util.connection
-from typing import NamedTuple, Sequence - -DEFAULT_ARGS = { - 'control_address': '127.0.0.1', - 'control_port': 'default', - 'user_provided_port': False, - 'control_socket': '/var/run/tor/control', - 'user_provided_socket': False, - 'tor_path': 'tor', - 'run_cmd': None, - 'run_path': None, - 'disable_color': False, - 'print_help': False, -} +from typing import Any, Dict, NamedTuple, Optional, Sequence
OPT = 'i:s:h' OPT_EXPANDED = ['interface=', 'socket=', 'tor=', 'run=', 'no-color', 'help']
-def parse(argv: Sequence[str]) -> NamedTuple: - """ - Parses our arguments, providing a named tuple with their values. - - :param list argv: input arguments to be parsed - - :returns: a **named tuple** with our parsed arguments - - :raises: **ValueError** if we got an invalid argument - """ - - args = dict(DEFAULT_ARGS) - - try: - recognized_args, unrecognized_args = getopt.getopt(argv, OPT, OPT_EXPANDED) - - if unrecognized_args: - error_msg = "aren't recognized arguments" if len(unrecognized_args) > 1 else "isn't a recognized argument" - raise getopt.GetoptError("'%s' %s" % ("', '".join(unrecognized_args), error_msg)) - except Exception as exc: - raise ValueError('%s (for usage provide --help)' % exc) - - for opt, arg in recognized_args: - if opt in ('-i', '--interface'): - if ':' in arg: - address, port = arg.rsplit(':', 1) - else: - address, port = None, arg - - if address is not None: - if not stem.util.connection.is_valid_ipv4_address(address): - raise ValueError("'%s' isn't a valid IPv4 address" % address) - - args['control_address'] = address - - if not stem.util.connection.is_valid_port(port): - raise ValueError("'%s' isn't a valid port number" % port) - - args['control_port'] = int(port) - args['user_provided_port'] = True - elif opt in ('-s', '--socket'): - args['control_socket'] = arg - args['user_provided_socket'] = True - elif opt in ('--tor'): - args['tor_path'] = arg - elif opt in ('--run'): - if os.path.exists(arg): - args['run_path'] = arg - else: - args['run_cmd'] = arg - elif opt == '--no-color': - args['disable_color'] = True - elif opt in ('-h', '--help'): - args['print_help'] = True - - # translates our args dict into a named tuple - - Args = collections.namedtuple('Args', args.keys()) - return Args(**args) - - -def get_help() -> str: - """ - Provides our --help usage information. 
- - :returns: **str** with our usage information - """ - - return stem.interpreter.msg( - 'msg.help', - address = DEFAULT_ARGS['control_address'], - port = DEFAULT_ARGS['control_port'], - socket = DEFAULT_ARGS['control_socket'], - ) +class Arguments(NamedTuple): + control_address: str = '127.0.0.1' + control_port: Optional[int] = None + user_provided_port: bool = False + control_socket: str = '/var/run/tor/control' + user_provided_socket: bool = False + tor_path: str = 'tor' + run_cmd: Optional[str] = None + run_path: Optional[str] = None + disable_color: bool = False + print_help: bool = False + + @staticmethod + def parse(argv: Sequence[str]) -> 'stem.interpreter.arguments.Arguments': + """ + Parses our commandline arguments into this class. + + :param list argv: input arguments to be parsed + + :returns: :class:`stem.interpreter.arguments.Arguments` for this + commandline input + + :raises: **ValueError** if we got an invalid argument + """ + + args = {} # type: Dict[str, Any] + + try: + recognized_args, unrecognized_args = getopt.getopt(argv, OPT, OPT_EXPANDED) # type: ignore + + if unrecognized_args: + error_msg = "aren't recognized arguments" if len(unrecognized_args) > 1 else "isn't a recognized argument" + raise getopt.GetoptError("'%s' %s" % ("', '".join(unrecognized_args), error_msg)) + except Exception as exc: + raise ValueError('%s (for usage provide --help)' % exc) + + for opt, arg in recognized_args: + if opt in ('-i', '--interface'): + if ':' in arg: + address, port = arg.rsplit(':', 1) + else: + address, port = None, arg + + if address is not None: + if not stem.util.connection.is_valid_ipv4_address(address): + raise ValueError("'%s' isn't a valid IPv4 address" % address) + + args['control_address'] = address + + if not stem.util.connection.is_valid_port(port): + raise ValueError("'%s' isn't a valid port number" % port) + + args['control_port'] = int(port) + args['user_provided_port'] = True + elif opt in ('-s', '--socket'): + args['control_socket'] 
= arg + args['user_provided_socket'] = True + elif opt in ('--tor'): + args['tor_path'] = arg + elif opt in ('--run'): + if os.path.exists(arg): + args['run_path'] = arg + else: + args['run_cmd'] = arg + elif opt == '--no-color': + args['disable_color'] = True + elif opt in ('-h', '--help'): + args['print_help'] = True + + return Arguments(**args) + + @staticmethod + def get_help() -> str: + """ + Provides our --help usage information. + + :returns: **str** with our usage information + """ + + defaults = Arguments() + + return stem.interpreter.msg( + 'msg.help', + address = defaults.control_address, + port = defaults.control_port if defaults.control_port else 'default', + socket = defaults.control_socket, + ) diff --git a/stem/interpreter/autocomplete.py b/stem/interpreter/autocomplete.py index 671085a7..e310ed28 100644 --- a/stem/interpreter/autocomplete.py +++ b/stem/interpreter/autocomplete.py @@ -7,12 +7,15 @@ Tab completion for our interpreter prompt.
import functools
+import stem.control +import stem.util.conf + from stem.interpreter import uses_settings -from typing import Optional, Sequence +from typing import List, Optional
@uses_settings -def _get_commands(controller: 'stem.control.Controller', config: 'stem.util.conf.Config') -> Sequence[str]: +def _get_commands(controller: stem.control.Controller, config: stem.util.conf.Config) -> List[str]: """ Provides commands recognized by tor. """ @@ -77,11 +80,11 @@ def _get_commands(controller: 'stem.control.Controller', config: 'stem.util.conf
class Autocompleter(object): - def __init__(self, controller: 'stem.control.Controller') -> None: + def __init__(self, controller: stem.control.Controller) -> None: self._commands = _get_commands(controller)
@functools.lru_cache() - def matches(self, text: str) -> Sequence[str]: + def matches(self, text: str) -> List[str]: """ Provides autocompletion matches for the given text.
diff --git a/stem/interpreter/commands.py b/stem/interpreter/commands.py index 1d610dac..254e46a1 100644 --- a/stem/interpreter/commands.py +++ b/stem/interpreter/commands.py @@ -21,12 +21,12 @@ import stem.util.tor_tools
from stem.interpreter import STANDARD_OUTPUT, BOLD_OUTPUT, ERROR_OUTPUT, uses_settings, msg from stem.util.term import format -from typing import BinaryIO, Iterator, Sequence, Tuple +from typing import Iterator, List, TextIO
MAX_EVENTS = 100
-def _get_fingerprint(arg: str, controller: 'stem.control.Controller') -> str: +def _get_fingerprint(arg: str, controller: stem.control.Controller) -> str: """ Resolves user input into a relay fingerprint. This accepts...
@@ -91,7 +91,7 @@ def _get_fingerprint(arg: str, controller: 'stem.control.Controller') -> str:
@contextlib.contextmanager -def redirect(stdout: BinaryIO, stderr: BinaryIO) -> Iterator[None]: +def redirect(stdout: TextIO, stderr: TextIO) -> Iterator[None]: original = sys.stdout, sys.stderr sys.stdout, sys.stderr = stdout, stderr
@@ -107,8 +107,8 @@ class ControlInterpreter(code.InteractiveConsole): for special irc style subcommands. """
- def __init__(self, controller: 'stem.control.Controller') -> None: - self._received_events = [] + def __init__(self, controller: stem.control.Controller) -> None: + self._received_events = [] # type: List[stem.response.events.Event]
code.InteractiveConsole.__init__(self, { 'stem': stem, @@ -130,18 +130,19 @@ class ControlInterpreter(code.InteractiveConsole):
handle_event_real = self._controller._handle_event
- def handle_event_wrapper(event_message: 'stem.response.events.Event') -> None: + def handle_event_wrapper(event_message: stem.response.ControlMessage) -> None: handle_event_real(event_message) - self._received_events.insert(0, event_message) + self._received_events.insert(0, event_message) # type: ignore
if len(self._received_events) > MAX_EVENTS: self._received_events.pop()
- self._controller._handle_event = handle_event_wrapper + # type check disabled due to https://github.com/python/mypy/issues/708
- def get_events(self, *event_types: 'stem.control.EventType') -> Sequence['stem.response.events.Event']: + self._controller._handle_event = handle_event_wrapper # type: ignore + + def get_events(self, *event_types: stem.control.EventType) -> List[stem.response.events.Event]: events = list(self._received_events) - event_types = list(map(str.upper, event_types)) # make filtering case insensitive
if event_types: events = [e for e in events if e.type in event_types] @@ -296,7 +297,7 @@ class ControlInterpreter(code.InteractiveConsole): return format(response, *STANDARD_OUTPUT)
@uses_settings - def run_command(self, command: str, config: 'stem.util.conf.Config', print_response: bool = False) -> Sequence[Tuple[str, int]]: + def run_command(self, command: str, config: stem.util.conf.Config, print_response: bool = False) -> str: """ Runs the given command. Requests starting with a '/' are special commands to the interpreter, and anything else is sent to the control port. @@ -304,8 +305,7 @@ class ControlInterpreter(code.InteractiveConsole): :param str command: command to be processed :param bool print_response: prints the response to stdout if true
- :returns: **list** out output lines, each line being a list of - (msg, format) tuples + :returns: **str** output of the command
:raises: **stem.SocketClosed** if the control connection has been severed """ @@ -363,7 +363,7 @@ class ControlInterpreter(code.InteractiveConsole): output = console_output.getvalue().strip() else: try: - output = format(self._controller.msg(command).raw_content().strip(), *STANDARD_OUTPUT) + output = format(str(self._controller.msg(command).raw_content()).strip(), *STANDARD_OUTPUT) except stem.ControllerError as exc: if isinstance(exc, stem.SocketClosed): raise diff --git a/stem/interpreter/help.py b/stem/interpreter/help.py index 81c76d34..3a206c35 100644 --- a/stem/interpreter/help.py +++ b/stem/interpreter/help.py @@ -7,6 +7,11 @@ Provides our /help responses.
import functools
+import stem.control +import stem.util.conf + +from stem.util.term import format + from stem.interpreter import ( STANDARD_OUTPUT, BOLD_OUTPUT, @@ -15,10 +20,8 @@ from stem.interpreter import ( uses_settings, )
-from stem.util.term import format -
-def response(controller: 'stem.control.Controller', arg: str) -> str: +def response(controller: stem.control.Controller, arg: str) -> str: """ Provides our /help response.
@@ -33,7 +36,7 @@ def response(controller: 'stem.control.Controller', arg: str) -> str: return _response(controller, _normalize(arg))
-def _normalize(arg) -> str: +def _normalize(arg: str) -> str: arg = arg.upper()
# If there's multiple arguments then just take the first. This is @@ -52,7 +55,7 @@ def _normalize(arg) -> str:
@functools.lru_cache() @uses_settings -def _response(controller: 'stem.control.Controller', arg: str, config: 'stem.util.conf.Config') -> str: +def _response(controller: stem.control.Controller, arg: str, config: stem.util.conf.Config) -> str: if not arg: return _general_help()
diff --git a/stem/manual.py b/stem/manual.py index e28e0e6f..9bc10b85 100644 --- a/stem/manual.py +++ b/stem/manual.py @@ -61,9 +61,10 @@ import stem.util import stem.util.conf import stem.util.enum import stem.util.log +import stem.util.str_tools import stem.util.system
-from typing import Any, Dict, Mapping, Optional, Sequence, TextIO, Tuple, Union +from typing import Any, Dict, IO, List, Mapping, Optional, Sequence, Tuple, Union
Category = stem.util.enum.Enum('GENERAL', 'CLIENT', 'RELAY', 'DIRECTORY', 'AUTHORITY', 'HIDDEN_SERVICE', 'DENIAL_OF_SERVICE', 'TESTING', 'UNKNOWN') GITWEB_MANUAL_URL = 'https://gitweb.torproject.org/tor.git/plain/doc/tor.1.txt' @@ -111,7 +112,7 @@ class SchemaMismatch(IOError): self.supported_schemas = supported_schemas
-def query(query: str, *param: str) -> 'sqlite3.Cursor': +def query(query: str, *param: str) -> 'sqlite3.Cursor': # type: ignore """ Performs the given query on our sqlite manual cache. This database should be treated as being read-only. File permissions generally enforce this, and @@ -182,7 +183,7 @@ class ConfigOption(object):
@functools.lru_cache() -def _config(lowercase: bool = True) -> Dict[str, Union[Sequence[str], str]]: +def _config(lowercase: bool = True) -> Dict[str, Union[List[str], str]]: """ Provides a dictionary for our settings.cfg. This has a couple categories...
@@ -264,7 +265,7 @@ def is_important(option: str) -> bool: return option.lower() in _config()['manual.important']
-def download_man_page(path: Optional[str] = None, file_handle: Optional[TextIO] = None, url: str = GITWEB_MANUAL_URL, timeout: int = 20) -> None: +def download_man_page(path: Optional[str] = None, file_handle: Optional[IO[bytes]] = None, url: str = GITWEB_MANUAL_URL, timeout: int = 20) -> None: """ Downloads tor's latest man page from `gitweb.torproject.org https://gitweb.torproject.org/tor.git/plain/doc/tor.1.txt`_. This method is @@ -303,7 +304,7 @@ def download_man_page(path: Optional[str] = None, file_handle: Optional[TextIO] if not os.path.exists(manual_path): raise OSError('no man page was generated') except stem.util.system.CallError as exc: - raise IOError("Unable to run '%s': %s" % (exc.command, exc.stderr)) + raise IOError("Unable to run '%s': %s" % (exc.command, stem.util.str_tools._to_unicode(exc.stderr)))
if path: try: @@ -349,7 +350,7 @@ class Manual(object): :var str stem_commit: stem commit to cache this manual information """
- def __init__(self, name: str, synopsis: str, description: str, commandline_options: Mapping[str, str], signals: Mapping[str, str], files: Mapping[str, str], config_options: Mapping[str, str]) -> None: + def __init__(self, name: str, synopsis: str, description: str, commandline_options: Mapping[str, str], signals: Mapping[str, str], files: Mapping[str, str], config_options: Mapping[str, 'stem.manual.ConfigOption']) -> None: self.name = name self.synopsis = synopsis self.description = description @@ -449,7 +450,8 @@ class Manual(object): except OSError as exc: raise IOError("Unable to run '%s': %s" % (man_cmd, exc))
- categories, config_options = _get_categories(man_output), collections.OrderedDict() + categories = _get_categories(man_output) + config_options = collections.OrderedDict() # type: collections.OrderedDict[str, stem.manual.ConfigOption]
for category_header, category_enum in CATEGORY_SECTIONS.items(): _add_config_options(config_options, category_enum, categories.get(category_header, [])) @@ -561,7 +563,7 @@ class Manual(object): return not self == other
-def _get_categories(content: str) -> Dict[str, str]: +def _get_categories(content: Sequence[str]) -> Dict[str, List[str]]: """ The man page is headers followed by an indented section. First pass gets the mapping of category titles to their lines. @@ -576,7 +578,8 @@ def _get_categories(content: str) -> Dict[str, str]: content = content[:-1]
categories = collections.OrderedDict() - category, lines = None, [] + category = None + lines = [] # type: List[str]
for line in content: # replace non-ascii characters @@ -607,7 +610,7 @@ def _get_categories(content: str) -> Dict[str, str]: return categories
-def _get_indented_descriptions(lines: Sequence[str]) -> Dict[str, Sequence[str]]: +def _get_indented_descriptions(lines: Sequence[str]) -> Dict[str, str]: """ Parses the commandline argument and signal sections. These are options followed by an indented description. For example... @@ -624,7 +627,8 @@ def _get_indented_descriptions(lines: Sequence[str]) -> Dict[str, Sequence[str]] ignoring those. """
- options, last_arg = collections.OrderedDict(), None + options = collections.OrderedDict() # type: collections.OrderedDict[str, List[str]] + last_arg = None
for line in lines: if line == ' Note': @@ -637,7 +641,7 @@ def _get_indented_descriptions(lines: Sequence[str]) -> Dict[str, Sequence[str]] return dict([(arg, ' '.join(desc_lines)) for arg, desc_lines in options.items() if desc_lines])
-def _add_config_options(config_options: Mapping[str, 'stem.manual.ConfigOption'], category: str, lines: Sequence[str]) -> None: +def _add_config_options(config_options: Dict[str, 'stem.manual.ConfigOption'], category: str, lines: Sequence[str]) -> None: """ Parses a section of tor configuration options. These have usage information, followed by an indented description. For instance... @@ -655,7 +659,7 @@ def _add_config_options(config_options: Mapping[str, 'stem.manual.ConfigOption'] since that platform lacks getrlimit(). (Default: 1000) """
- def add_option(title: str, description: str) -> None: + def add_option(title: str, description: List[str]) -> None: if 'PER INSTANCE OPTIONS' in title: return # skip, unfortunately amid the options
@@ -669,7 +673,7 @@ def _add_config_options(config_options: Mapping[str, 'stem.manual.ConfigOption'] add_option(subtitle, description) else: name, usage = title.split(' ', 1) if ' ' in title else (title, '') - summary = _config().get('manual.summary.%s' % name.lower(), '') + summary = str(_config().get('manual.summary.%s' % name.lower(), '')) config_options[name] = ConfigOption(name, category, usage, summary, _join_lines(description).strip())
# Remove the section's description by finding the sentence the section @@ -681,7 +685,8 @@ def _add_config_options(config_options: Mapping[str, 'stem.manual.ConfigOption'] lines = lines[max(end_indices):] # trim to the description paragrah lines = lines[lines.index(''):] # drop the paragraph
- last_title, description = None, [] + last_title = None + description = [] # type: List[str]
for line in lines: if line and not line.startswith(' '): @@ -704,7 +709,7 @@ def _join_lines(lines: Sequence[str]) -> str: Simple join, except we want empty lines to still provide a newline. """
- result = [] + result = [] # type: List[str]
for line in lines: if not line: diff --git a/stem/process.py b/stem/process.py index bfab4967..3c7688a5 100644 --- a/stem/process.py +++ b/stem/process.py @@ -29,7 +29,7 @@ import stem.util.str_tools import stem.util.system import stem.version
-from typing import Any, Callable, Mapping, Optional, Sequence, Union +from typing import Any, Callable, Dict, Optional, Sequence, Union
NO_TORRC = '<no torrc>' DEFAULT_INIT_TIMEOUT = 90 @@ -199,7 +199,7 @@ def launch_tor(tor_cmd: str = 'tor', args: Optional[Sequence[str]] = None, torrc pass
-def launch_tor_with_config(config: Mapping[str, Union[str, Sequence[str]]], tor_cmd: str = 'tor', completion_percent: int = 100, init_msg_handler: Optional[Callable[[str], None]] = None, timeout: int = DEFAULT_INIT_TIMEOUT, take_ownership: bool = False, close_output: bool = True) -> subprocess.Popen: +def launch_tor_with_config(config: Dict[str, Union[str, Sequence[str]]], tor_cmd: str = 'tor', completion_percent: int = 100, init_msg_handler: Optional[Callable[[str], None]] = None, timeout: int = DEFAULT_INIT_TIMEOUT, take_ownership: bool = False, close_output: bool = True) -> subprocess.Popen: """ Initializes a tor process, like :func:`~stem.process.launch_tor`, but with a customized configuration. This writes a temporary torrc to disk, launches @@ -260,7 +260,7 @@ def launch_tor_with_config(config: Mapping[str, Union[str, Sequence[str]]], tor_ break
if not has_stdout: - config['Log'].append('NOTICE stdout') + config['Log'] = list(config['Log']) + ['NOTICE stdout']
config_str = ''
diff --git a/stem/response/__init__.py b/stem/response/__init__.py index 4b1f9533..2f851389 100644 --- a/stem/response/__init__.py +++ b/stem/response/__init__.py @@ -38,7 +38,7 @@ import stem.socket import stem.util import stem.util.str_tools
-from typing import Any, Iterator, Optional, Sequence, Tuple, Union +from typing import Any, Iterator, List, Optional, Sequence, Tuple, Union
__all__ = [ 'add_onion', @@ -123,7 +123,40 @@ def convert(response_type: str, message: 'stem.response.ControlMessage', **kwarg raise TypeError('Unsupported response type: %s' % response_type)
message.__class__ = response_class - message._parse_message(**kwargs) + message._parse_message(**kwargs) # type: ignore + + +# TODO: These aliases are for type hint compatibility. We should refactor how +message conversion is performed to avoid this headache. + +def _convert_to_single_line(message: 'stem.response.ControlMessage', **kwargs: Any) -> 'stem.response.SingleLineResponse': + stem.response.convert('SINGLELINE', message) + return message # type: ignore + + +def _convert_to_event(message: 'stem.response.ControlMessage', **kwargs: Any) -> 'stem.response.events.Event': + stem.response.convert('EVENT', message) + return message # type: ignore + + +def _convert_to_getinfo(message: 'stem.response.ControlMessage', **kwargs: Any) -> 'stem.response.getinfo.GetInfoResponse': + stem.response.convert('GETINFO', message) + return message # type: ignore + + +def _convert_to_getconf(message: 'stem.response.ControlMessage', **kwargs: Any) -> 'stem.response.getconf.GetConfResponse': + stem.response.convert('GETCONF', message) + return message # type: ignore + + +def _convert_to_add_onion(message: 'stem.response.ControlMessage', **kwargs: Any) -> 'stem.response.add_onion.AddOnionResponse': + stem.response.convert('ADD_ONION', message) + return message # type: ignore + + +def _convert_to_mapaddress(message: 'stem.response.ControlMessage', **kwargs: Any) -> 'stem.response.mapaddress.MapAddressResponse': + stem.response.convert('MAPADDRESS', message) + return message # type: ignore
class ControlMessage(object): @@ -142,7 +175,7 @@ class ControlMessage(object): """
@staticmethod - def from_str(content: str, msg_type: Optional[str] = None, normalize: bool = False, **kwargs: Any) -> 'stem.response.ControlMessage': + def from_str(content: Union[str, bytes], msg_type: Optional[str] = None, normalize: bool = False, **kwargs: Any) -> 'stem.response.ControlMessage': """ Provides a ControlMessage for the given content.
@@ -160,28 +193,35 @@ class ControlMessage(object): :returns: stem.response.ControlMessage instance """
+ if isinstance(content, str): + content = stem.util.str_tools._to_bytes(content) + if normalize: - if not content.endswith('\n'): - content += '\n' + if not content.endswith(b'\n'): + content += b'\n'
- content = re.sub('([\r]?)\n', '\r\n', content) + content = re.sub(b'([\r]?)\n', b'\r\n', content)
- msg = stem.socket.recv_message(io.BytesIO(stem.util.str_tools._to_bytes(content)), arrived_at = kwargs.pop('arrived_at', None)) + msg = stem.socket.recv_message(io.BytesIO(content), arrived_at = kwargs.pop('arrived_at', None))
if msg_type is not None: convert(msg_type, msg, **kwargs)
return msg
- def __init__(self, parsed_content: Sequence[Tuple[str, str, bytes]], raw_content: bytes, arrived_at: Optional[int] = None) -> None: + def __init__(self, parsed_content: Sequence[Tuple[str, str, bytes]], raw_content: bytes, arrived_at: Optional[float] = None) -> None: if not parsed_content: raise ValueError("ControlMessages can't be empty")
- self.arrived_at = arrived_at if arrived_at else int(time.time()) + # TODO: Change arrived_at to a float (can't yet because it causes Event + # equality checks to fail - events include arrived_at within their hash + # whereas ControlMessages don't). + + self.arrived_at = int(arrived_at if arrived_at else time.time())
self._parsed_content = parsed_content self._raw_content = raw_content - self._str = None + self._str = None # type: Optional[str] self._hash = stem.util._hash_attr(self, '_raw_content')
def is_ok(self) -> bool: @@ -197,7 +237,12 @@ class ControlMessage(object):
return False
- def content(self, get_bytes: bool = False) -> Sequence[Tuple[str, str, bytes]]: + # TODO: drop this alias when we provide better type support + + def _content_bytes(self) -> List[Tuple[str, str, bytes]]: + return self.content(get_bytes = True) # type: ignore + + def content(self, get_bytes: bool = False) -> List[Tuple[str, str, str]]: """ Provides the parsed message content. These are entries of the form...
@@ -234,9 +279,9 @@ class ControlMessage(object): if not get_bytes: return [(code, div, stem.util.str_tools._to_unicode(content)) for (code, div, content) in self._parsed_content] else: - return list(self._parsed_content) + return list(self._parsed_content) # type: ignore
- def raw_content(self, get_bytes: bytes = False) -> Union[str, bytes]: + def raw_content(self, get_bytes: bool = False) -> Union[str, bytes]: """ Provides the unparsed content read from the control socket.
@@ -253,6 +298,9 @@ class ControlMessage(object): else: return self._raw_content
+ def _parse_message(self) -> None: + raise NotImplementedError('Implemented by subclasses') + def __str__(self) -> str: """ Content of the message, stripped of status code and divider protocol @@ -288,9 +336,7 @@ class ControlMessage(object): """
for _, _, content in self._parsed_content: - content = stem.util.str_tools._to_unicode(content) - - yield ControlLine(content) + yield ControlLine(stem.util.str_tools._to_unicode(content))
def __len__(self) -> int: """ @@ -330,7 +376,7 @@ class ControlLine(str): """
def __new__(self, value: str) -> 'stem.response.ControlLine': - return str.__new__(self, value) + return str.__new__(self, value) # type: ignore
def __init__(self, value: str) -> None: self._remainder = value @@ -443,7 +489,12 @@ class ControlLine(str): with self._remainder_lock: next_entry, remainder = _parse_entry(self._remainder, quoted, escaped, False) self._remainder = remainder - return next_entry + return next_entry # type: ignore + + # TODO: drop this alias when we provide better type support + + def _pop_mapping_bytes(self, quoted: bool = False, escaped: bool = False) -> Tuple[str, bytes]: + return self.pop_mapping(quoted, escaped, get_bytes = True) # type: ignore
def pop_mapping(self, quoted: bool = False, escaped: bool = False, get_bytes: bool = False) -> Tuple[str, str]: """ @@ -479,7 +530,7 @@ class ControlLine(str):
next_entry, remainder = _parse_entry(remainder, quoted, escaped, get_bytes) self._remainder = remainder - return (key, next_entry) + return (key, next_entry) # type: ignore
def _parse_entry(line: str, quoted: bool, escaped: bool, get_bytes: bool) -> Tuple[Union[str, bytes], str]: @@ -532,15 +583,15 @@ def _parse_entry(line: str, quoted: bool, escaped: bool, get_bytes: bool) -> Tup # # https://stackoverflow.com/questions/14820429/how-do-i-decodestring-escape-in...
- next_entry = codecs.escape_decode(next_entry)[0] + next_entry = codecs.escape_decode(next_entry)[0] # type: ignore
if not get_bytes: next_entry = stem.util.str_tools._to_unicode(next_entry) # normalize back to str
if get_bytes: - next_entry = stem.util.str_tools._to_bytes(next_entry) - - return (next_entry, remainder.lstrip()) + return (stem.util.str_tools._to_bytes(next_entry), remainder.lstrip()) + else: + return (next_entry, remainder.lstrip())
def _get_quote_indices(line: str, escaped: bool) -> Tuple[int, int]: @@ -566,7 +617,7 @@ def _get_quote_indices(line: str, escaped: bool) -> Tuple[int, int]:
indices.append(quote_index)
- return tuple(indices) + return tuple(indices) # type: ignore
class SingleLineResponse(ControlMessage): @@ -604,4 +655,7 @@ class SingleLineResponse(ControlMessage): elif len(content) == 0: raise stem.ProtocolError('Received empty response') else: - self.code, _, self.message = content[0] + code, _, msg = content[0] + + self.code = stem.util.str_tools._to_unicode(code) + self.message = stem.util.str_tools._to_unicode(msg) diff --git a/stem/response/events.py b/stem/response/events.py index 0e112373..65419fe6 100644 --- a/stem/response/events.py +++ b/stem/response/events.py @@ -1,6 +1,9 @@ # Copyright 2012-2020, Damian Johnson and The Tor Project # See LICENSE for licensing information +# +# mypy: ignore-errors
+import datetime import io import re
@@ -12,7 +15,7 @@ import stem.util import stem.version
from stem.util import connection, log, str_tools, tor_tools -from typing import Any, Dict, Sequence +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
# Matches keyword=value arguments. This can't be a simple "(.*)=(.*)" pattern # because some positional arguments, like circuit paths, can have an equal @@ -34,10 +37,13 @@ class Event(stem.response.ControlMessage): :var dict keyword_args: key/value arguments of the event """
- _POSITIONAL_ARGS = () # attribute names for recognized positional arguments - _KEYWORD_ARGS = {} # map of 'keyword => attribute' for recognized attributes - _QUOTED = () # positional arguments that are quoted - _OPTIONALLY_QUOTED = () # positional arguments that may or may not be quoted + # TODO: Replace metaprogramming with concrete implementations (to simplify type information) + # TODO: _QUOTED looks to be unused + + _POSITIONAL_ARGS = () # type: Tuple[str, ...] # attribute names for recognized positional arguments + _KEYWORD_ARGS = {} # type: Dict[str, str] # map of 'keyword => attribute' for recognized attributes + _QUOTED = () # type: Tuple[str, ...] # positional arguments that are quoted + _OPTIONALLY_QUOTED = () # type: Tuple[str, ...] # positional arguments that may or may not be quoted _SKIP_PARSING = False # skip parsing contents into our positional_args and keyword_args _VERSION_ADDED = stem.version.Version('0.1.1.1-alpha') # minimum version with control-spec V1 event support
@@ -46,13 +52,14 @@ class Event(stem.response.ControlMessage): raise stem.ProtocolError('Received a blank tor event. Events must at the very least have a type.')
self.type = str(self).split()[0] - self.positional_args = [] - self.keyword_args = {} + self.positional_args = [] # type: List[str] + self.keyword_args = {} # type: Dict[str, str]
# if we're a recognized event type then translate ourselves into that subclass
if self.type in EVENT_TYPE_TO_CLASS: self.__class__ = EVENT_TYPE_TO_CLASS[self.type] + self.__init__() # type: ignore
if not self._SKIP_PARSING: self._parse_standard_attr() @@ -123,7 +130,7 @@ class Event(stem.response.ControlMessage): for controller_attr_name, attr_name in self._KEYWORD_ARGS.items(): setattr(self, attr_name, self.keyword_args.get(controller_attr_name))
- def _iso_timestamp(self, timestamp: str) -> 'datetime.datetime': + def _iso_timestamp(self, timestamp: str) -> datetime.datetime: """ Parses an iso timestamp (ISOTime2Frac in the control-spec).
@@ -146,7 +153,7 @@ class Event(stem.response.ControlMessage): def _parse(self) -> None: pass
- def _log_if_unrecognized(self, attr: str, attr_enum: 'stem.util.enum.Enum') -> None: + def _log_if_unrecognized(self, attr: str, attr_enum: Union[stem.util.enum.Enum, Sequence[stem.util.enum.Enum]]) -> None: """ Checks if an attribute exists in a given enumeration, logging a message if it isn't. Attributes can either be for a string or collection of strings @@ -195,7 +202,15 @@ class AddrMapEvent(Event): 'EXPIRES': 'utc_expiry', 'CACHED': 'cached', } - _OPTIONALLY_QUOTED = ('expiry') + _OPTIONALLY_QUOTED = ('expiry',) + + def __init__(self): + self.hostname = None # type: Optional[str] + self.destination = None # type: Optional[str] + self.expiry = None # type: Optional[datetime.datetime] + self.error = None # type: Optional[str] + self.utc_expiry = None # type: Optional[datetime.datetime] + self.cached = None # type: Optional[bool]
def _parse(self) -> None: if self.destination == '<error>': @@ -235,6 +250,10 @@ class BandwidthEvent(Event):
_POSITIONAL_ARGS = ('read', 'written')
+ def __init__(self): + self.read = None # type: Optional[int] + self.written = None # type: Optional[int] + def _parse(self) -> None: if not self.read: raise stem.ProtocolError('BW event is missing its read value') @@ -278,6 +297,17 @@ class BuildTimeoutSetEvent(Event): } _VERSION_ADDED = stem.version.Version('0.2.2.7-alpha')
+ def __init__(self): + self.set_type = None # type: Optional[stem.TimeoutSetType] + self.total_times = None # type: Optional[int] + self.timeout = None # type: Optional[int] + self.xm = None # type: Optional[int] + self.alpha = None # type: Optional[float] + self.quantile = None # type: Optional[float] + self.timeout_rate = None # type: Optional[float] + self.close_timeout = None # type: Optional[int] + self.close_rate = None # type: Optional[float] + def _parse(self) -> None: # convert our integer and float parameters
@@ -347,6 +377,20 @@ class CircuitEvent(Event): 'SOCKS_PASSWORD': 'socks_password', }
+ def __init__(self): + self.id = None # type: Optional[str] + self.status = None # type: Optional[stem.CircStatus] + self.path = None # type: Optional[Tuple[Tuple[str, str], ...]] + self.build_flags = None # type: Optional[Tuple[stem.CircBuildFlag, ...]] + self.purpose = None # type: Optional[stem.CircPurpose] + self.hs_state = None # type: Optional[stem.HiddenServiceState] + self.rend_query = None # type: Optional[str] + self.created = None # type: Optional[datetime.datetime] + self.reason = None # type: Optional[stem.CircClosureReason] + self.remote_reason = None # type: Optional[stem.CircClosureReason] + self.socks_username = None # type: Optional[str] + self.socks_password = None # type: Optional[str] + def _parse(self) -> None: self.path = tuple(stem.control._parse_circ_path(self.path)) self.created = self._iso_timestamp(self.created) @@ -415,6 +459,18 @@ class CircMinorEvent(Event): } _VERSION_ADDED = stem.version.Version('0.2.3.11-alpha')
+ def __init__(self): + self.id = None # type: Optional[str] + self.event = None # type: Optional[stem.CircEvent] + self.path = None # type: Optional[Tuple[Tuple[str, str], ...]] + self.build_flags = None # type: Optional[Tuple[stem.CircBuildFlag, ...]] + self.purpose = None # type: Optional[stem.CircPurpose] + self.hs_state = None # type: Optional[stem.HiddenServiceState] + self.rend_query = None # type: Optional[str] + self.created = None # type: Optional[datetime.datetime] + self.old_purpose = None # type: Optional[stem.CircPurpose] + self.old_hs_state = None # type: Optional[stem.HiddenServiceState] + def _parse(self) -> None: self.path = tuple(stem.control._parse_circ_path(self.path)) self.created = self._iso_timestamp(self.created) @@ -451,6 +507,11 @@ class ClientsSeenEvent(Event): } _VERSION_ADDED = stem.version.Version('0.2.1.10-alpha')
+ def __init__(self): + self.start_time = None # type: Optional[datetime.datetime] + self.locales = None # type: Optional[Dict[str, int]] + self.ip_versions = None # type: Optional[Dict[str, int]] + def _parse(self) -> None: if self.start_time is not None: self.start_time = stem.util.str_tools._parse_timestamp(self.start_time) @@ -510,6 +571,10 @@ class ConfChangedEvent(Event): _SKIP_PARSING = True _VERSION_ADDED = stem.version.Version('0.2.3.3-alpha')
+ def __init__(self): + self.changed = {} # type: Dict[str, List[str]] + self.unset = [] # type: List[str] + def _parse(self) -> None: self.changed = {} self.unset = [] @@ -541,6 +606,9 @@ class DescChangedEvent(Event):
_VERSION_ADDED = stem.version.Version('0.1.2.2-alpha')
+ def __init__(self): + pass +
class GuardEvent(Event): """ @@ -564,10 +632,14 @@ class GuardEvent(Event): _VERSION_ADDED = stem.version.Version('0.1.2.5-alpha') _POSITIONAL_ARGS = ('guard_type', 'endpoint', 'status')
- def _parse(self) -> None: - self.endpoint_fingerprint = None - self.endpoint_nickname = None + def __init__(self): + self.guard_type = None # type: Optional[stem.GuardType] + self.endpoint = None # type: Optional[str] + self.endpoint_fingerprint = None # type: Optional[str] + self.endpoint_nickname = None # type: Optional[str] + self.status = None # type: Optional[stem.GuardStatus]
+ def _parse(self) -> None: try: self.endpoint_fingerprint, self.endpoint_nickname = \ stem.control._parse_circ_entry(self.endpoint) @@ -611,10 +683,19 @@ class HSDescEvent(Event): _POSITIONAL_ARGS = ('action', 'address', 'authentication', 'directory', 'descriptor_id') _KEYWORD_ARGS = {'REASON': 'reason', 'REPLICA': 'replica', 'HSDIR_INDEX': 'index'}
- def _parse(self) -> None: - self.directory_fingerprint = None - self.directory_nickname = None + def __init__(self): + self.action = None # type: Optional[stem.HSDescAction] + self.address = None # type: Optional[str] + self.authentication = None # type: Optional[stem.HSAuth] + self.directory = None # type: Optional[str] + self.directory_fingerprint = None # type: Optional[str] + self.directory_nickname = None # type: Optional[str] + self.descriptor_id = None # type: Optional[str] + self.reason = None # type: Optional[stem.HSDescReason] + self.replica = None # type: Optional[int] + self.index = None # type: Optional[str]
+ def _parse(self) -> None: if self.directory != 'UNKNOWN': try: self.directory_fingerprint, self.directory_nickname = \ @@ -651,13 +732,18 @@ class HSDescContentEvent(Event): _VERSION_ADDED = stem.version.Version('0.2.7.1-alpha') _POSITIONAL_ARGS = ('address', 'descriptor_id', 'directory')
+ def __init__(self): + self.address = None # type: Optional[str] + self.descriptor_id = None # type: Optional[str] + self.directory = None # type: Optional[str] + self.directory_fingerprint = None # type: Optional[str] + self.directory_nickname = None # type: Optional[str] + self.descriptor = None # type: Optional[stem.descriptor.hidden_service.HiddenServiceDescriptorV2] + def _parse(self) -> None: if self.address == 'UNKNOWN': self.address = None
- self.directory_fingerprint = None - self.directory_nickname = None - try: self.directory_fingerprint, self.directory_nickname = \ stem.control._parse_circ_entry(self.directory) @@ -687,6 +773,10 @@ class LogEvent(Event):
_SKIP_PARSING = True
+ def __init__(self): + self.runlevel = None # type: Optional[stem.Runlevel] + self.message = None # type: Optional[str] + def _parse(self) -> None: self.runlevel = self.type self._log_if_unrecognized('runlevel', stem.Runlevel) @@ -710,6 +800,9 @@ class NetworkStatusEvent(Event): _SKIP_PARSING = True _VERSION_ADDED = stem.version.Version('0.1.2.3-alpha')
+ def __init__(self): + self.descriptors = None # type: Optional[List[stem.descriptor.router_status_entry.RouterStatusEntryV3]] + def _parse(self) -> None: content = str(self).lstrip('NS\n').rstrip('\nOK')
@@ -735,6 +828,9 @@ class NetworkLivenessEvent(Event): _VERSION_ADDED = stem.version.Version('0.2.7.2-alpha') _POSITIONAL_ARGS = ('status',)
+ def __init__(self): + self.status = None # type: Optional[str] +
class NewConsensusEvent(Event): """ @@ -754,11 +850,14 @@ class NewConsensusEvent(Event): _SKIP_PARSING = True _VERSION_ADDED = stem.version.Version('0.2.1.13-alpha')
+ def __init__(self): + self.consensus_content = None # type: Optional[str] + self._parsed = None # type: List[stem.descriptor.router_status_entry.RouterStatusEntryV3] + def _parse(self) -> None: self.consensus_content = str(self).lstrip('NEWCONSENSUS\n').rstrip('\nOK') - self._parsed = None
- def entries(self) -> Sequence['stem.descriptor.router_status_entry.RouterStatusEntryV3']: + def entries(self) -> List[stem.descriptor.router_status_entry.RouterStatusEntryV3]: """ Relay router status entries residing within this consensus.
@@ -774,7 +873,7 @@ class NewConsensusEvent(Event): entry_class = stem.descriptor.router_status_entry.RouterStatusEntryV3, ))
- return self._parsed + return list(self._parsed)
class NewDescEvent(Event): @@ -792,6 +891,9 @@ class NewDescEvent(Event): new descriptors """
+ def __init__(self): + self.relays = () # type: Tuple[Tuple[str, str], ...] + def _parse(self) -> None: self.relays = tuple([stem.control._parse_circ_entry(entry) for entry in str(self).split()[1:]])
@@ -833,12 +935,18 @@ class ORConnEvent(Event): 'ID': 'id', }
- def _parse(self) -> None: - self.endpoint_fingerprint = None - self.endpoint_nickname = None - self.endpoint_address = None - self.endpoint_port = None + def __init__(self): + self.id = None # type: Optional[str] + self.endpoint = None # type: Optional[str] + self.endpoint_fingerprint = None # type: Optional[str] + self.endpoint_nickname = None # type: Optional[str] + self.endpoint_address = None # type: Optional[str] + self.endpoint_port = None # type: Optional[int] + self.status = None # type: Optional[stem.ORStatus] + self.reason = None # type: Optional[stem.ORClosureReason] + self.circ_count = None # type: Optional[int]
+ def _parse(self) -> None: try: self.endpoint_fingerprint, self.endpoint_nickname = \ stem.control._parse_circ_entry(self.endpoint) @@ -887,6 +995,9 @@ class SignalEvent(Event): _POSITIONAL_ARGS = ('signal',) _VERSION_ADDED = stem.version.Version('0.2.3.1-alpha')
+ def __init__(self): + self.signal = None # type: Optional[stem.Signal] + def _parse(self) -> None: # log if we recieved an unrecognized signal expected_signals = ( @@ -919,6 +1030,12 @@ class StatusEvent(Event): _POSITIONAL_ARGS = ('runlevel', 'action') _VERSION_ADDED = stem.version.Version('0.1.2.3-alpha')
+ def __init__(self): + self.status_type = None # type: Optional[stem.StatusType] + self.runlevel = None # type: Optional[stem.Runlevel] + self.action = None # type: Optional[str] + self.arguments = None # type: Optional[Dict[str, str]] + def _parse(self) -> None: if self.type == 'STATUS_GENERAL': self.status_type = stem.StatusType.GENERAL @@ -971,6 +1088,21 @@ class StreamEvent(Event): 'PURPOSE': 'purpose', }
+ def __init__(self): + self.id = None # type: Optional[str] + self.status = None # type: Optional[stem.StreamStatus] + self.circ_id = None # type: Optional[str] + self.target = None # type: Optional[str] + self.target_address = None # type: Optional[str] + self.target_port = None # type: Optional[int] + self.reason = None # type: Optional[stem.StreamClosureReason] + self.remote_reason = None # type: Optional[stem.StreamClosureReason] + self.source = None # type: Optional[stem.StreamSource] + self.source_addr = None # type: Optional[str] + self.source_address = None # type: Optional[str] + self.source_port = None # type: Optional[str] + self.purpose = None # type: Optional[stem.StreamPurpose] + def _parse(self) -> None: if self.target is None: raise stem.ProtocolError("STREAM event didn't have a target: %s" % self) @@ -1030,6 +1162,12 @@ class StreamBwEvent(Event): _POSITIONAL_ARGS = ('id', 'written', 'read', 'time') _VERSION_ADDED = stem.version.Version('0.1.2.8-beta')
+ def __init__(self): + self.id = None # type: Optional[str] + self.written = None # type: Optional[int] + self.read = None # type: Optional[int] + self.time = None # type: Optional[datetime.datetime] + def _parse(self) -> None: if not tor_tools.is_valid_stream_id(self.id): raise stem.ProtocolError("Stream IDs must be one to sixteen alphanumeric characters, got '%s': %s" % (self.id, self)) @@ -1063,6 +1201,12 @@ class TransportLaunchedEvent(Event): _POSITIONAL_ARGS = ('type', 'name', 'address', 'port') _VERSION_ADDED = stem.version.Version('0.2.5.0-alpha')
+ def __init__(self): + self.type = None # type: Optional[str] + self.name = None # type: Optional[str] + self.address = None # type: Optional[str] + self.port = None # type: Optional[int] + def _parse(self) -> None: if self.type not in ('server', 'client'): raise stem.ProtocolError("Transport type should either be 'server' or 'client': %s" % self) @@ -1105,6 +1249,12 @@ class ConnectionBandwidthEvent(Event):
_VERSION_ADDED = stem.version.Version('0.2.5.2-alpha')
+ def __init__(self): + self.id = None # type: Optional[str] + self.conn_type = None # type: Optional[stem.ConnectionType] + self.read = None # type: Optional[int] + self.written = None # type: Optional[int] + def _parse(self) -> None: if not self.id: raise stem.ProtocolError('CONN_BW event is missing its id') @@ -1164,6 +1314,16 @@ class CircuitBandwidthEvent(Event):
_VERSION_ADDED = stem.version.Version('0.2.5.2-alpha')
+ def __init__(self): + self.id = None # type: Optional[str] + self.read = None # type: Optional[int] + self.written = None # type: Optional[int] + self.delivered_read = None # type: Optional[int] + self.delivered_written = None # type: Optional[int] + self.overhead_read = None # type: Optional[int] + self.overhead_written = None # type: Optional[int] + self.time = None # type: Optional[datetime.datetime] + def _parse(self) -> None: if not self.id: raise stem.ProtocolError('CIRC_BW event is missing its id') @@ -1234,6 +1394,19 @@ class CellStatsEvent(Event):
_VERSION_ADDED = stem.version.Version('0.2.5.2-alpha')
+ def __init__(self): + self.id = None # type: Optional[str] + self.inbound_queue = None # type: Optional[str] + self.inbound_connection = None # type: Optional[str] + self.inbound_added = None # type: Optional[Dict[str, int]] + self.inbound_removed = None # type: Optional[Dict[str, int]] + self.inbound_time = None # type: Optional[Dict[str, int]] + self.outbound_queue = None # type: Optional[str] + self.outbound_connection = None # type: Optional[str] + self.outbound_added = None # type: Optional[Dict[str, int]] + self.outbound_removed = None # type: Optional[Dict[str, int]] + self.outbound_time = None # type: Optional[Dict[str, int]] + def _parse(self) -> None: if self.id and not tor_tools.is_valid_circuit_id(self.id): raise stem.ProtocolError("Circuit IDs must be one to sixteen alphanumeric characters, got '%s': %s" % (self.id, self)) @@ -1280,6 +1453,13 @@ class TokenBucketEmptyEvent(Event):
_VERSION_ADDED = stem.version.Version('0.2.5.2-alpha')
+ def __init__(self): + self.bucket = None # type: Optional[stem.TokenBucket] + self.id = None # type: Optional[str] + self.read = None # type: Optional[int] + self.written = None # type: Optional[int] + self.last_refill = None # type: Optional[int] + def _parse(self) -> None: if self.id and not tor_tools.is_valid_connection_id(self.id): raise stem.ProtocolError("Connection IDs must be one to sixteen alphanumeric characters, got '%s': %s" % (self.id, self)) diff --git a/stem/response/getconf.py b/stem/response/getconf.py index 7ba972ae..6c65c4ec 100644 --- a/stem/response/getconf.py +++ b/stem/response/getconf.py @@ -4,6 +4,8 @@ import stem.response import stem.socket
+from typing import Dict, List +
class GetConfResponse(stem.response.ControlMessage): """ @@ -23,7 +25,7 @@ class GetConfResponse(stem.response.ControlMessage): # 250-DataDirectory=/home/neena/.tor # 250 DirPort
- self.entries = {} + self.entries = {} # type: Dict[str, List[str]] remaining_lines = list(self)
if self.content() == [('250', ' ', 'OK')]: diff --git a/stem/response/getinfo.py b/stem/response/getinfo.py index 7aebd70a..9d9da21b 100644 --- a/stem/response/getinfo.py +++ b/stem/response/getinfo.py @@ -4,7 +4,7 @@ import stem.response import stem.socket
-from typing import Sequence +from typing import Dict, Set
class GetInfoResponse(stem.response.ControlMessage): @@ -27,8 +27,8 @@ class GetInfoResponse(stem.response.ControlMessage): # . # 250 OK
- self.entries = {} - remaining_lines = [content for (code, div, content) in self.content(get_bytes = True)] + self.entries = {} # type: Dict[str, bytes] + remaining_lines = [content for (code, div, content) in self._content_bytes()]
if not self.is_ok() or not remaining_lines.pop() == b'OK': unrecognized_keywords = [] @@ -51,11 +51,11 @@ class GetInfoResponse(stem.response.ControlMessage):
while remaining_lines: try: - key, value = remaining_lines.pop(0).split(b'=', 1) + key_bytes, value = remaining_lines.pop(0).split(b'=', 1) except ValueError: raise stem.ProtocolError('GETINFO replies should only contain parameter=value mappings:\n%s' % self)
- key = stem.util.str_tools._to_unicode(key) + key = stem.util.str_tools._to_unicode(key_bytes)
# if the value is a multiline value then it *must* be of the form # '<key>=\n<value>' @@ -68,7 +68,7 @@ class GetInfoResponse(stem.response.ControlMessage):
self.entries[key] = value
- def _assert_matches(self, params: Sequence[str]) -> None: + def _assert_matches(self, params: Set[str]) -> None: """ Checks if we match a given set of parameters, and raise a ProtocolError if not.
diff --git a/stem/response/protocolinfo.py b/stem/response/protocolinfo.py index 330b165e..c1387fab 100644 --- a/stem/response/protocolinfo.py +++ b/stem/response/protocolinfo.py @@ -9,6 +9,7 @@ import stem.version import stem.util.str_tools
from stem.util import log +from typing import Tuple
class ProtocolInfoResponse(stem.response.ControlMessage): @@ -36,8 +37,8 @@ class ProtocolInfoResponse(stem.response.ControlMessage):
self.protocol_version = None self.tor_version = None - self.auth_methods = () - self.unknown_auth_methods = () + self.auth_methods = () # type: Tuple[stem.connection.AuthMethod, ...] + self.unknown_auth_methods = () # type: Tuple[str, ...] self.cookie_path = None
auth_methods, unknown_auth_methods = [], [] @@ -107,7 +108,7 @@ class ProtocolInfoResponse(stem.response.ControlMessage): # parse optional COOKIEFILE mapping (quoted and can have escapes)
if line.is_next_mapping('COOKIEFILE', True, True): - self.cookie_path = line.pop_mapping(True, True, get_bytes = True)[1].decode(sys.getfilesystemencoding()) + self.cookie_path = line._pop_mapping_bytes(True, True)[1].decode(sys.getfilesystemencoding()) self.cookie_path = stem.util.str_tools._to_unicode(self.cookie_path) # normalize back to str elif line_type == 'VERSION': # Line format: diff --git a/stem/socket.py b/stem/socket.py index 179ae16e..b5da4b78 100644 --- a/stem/socket.py +++ b/stem/socket.py @@ -80,7 +80,7 @@ import stem.util.str_tools
from stem.util import log from types import TracebackType -from typing import BinaryIO, Callable, Optional, Type +from typing import BinaryIO, Callable, List, Optional, Tuple, Type, Union, overload
MESSAGE_PREFIX = re.compile(b'^[a-zA-Z0-9]{3}[-+ ]') ERROR_MSG = 'Error while receiving a control message (%s): %s' @@ -96,7 +96,8 @@ class BaseSocket(object): """
def __init__(self) -> None: - self._socket, self._socket_file = None, None + self._socket = None # type: Optional[Union[socket.socket, ssl.SSLSocket]] + self._socket_file = None # type: Optional[BinaryIO] self._is_alive = False self._connection_time = 0.0 # time when we last connected or disconnected
@@ -218,7 +219,7 @@ class BaseSocket(object): if is_change: self._close()
- def _send(self, message: str, handler: Callable[[socket.socket, BinaryIO, str], None]) -> None: + def _send(self, message: Union[bytes, str], handler: Callable[[Union[socket.socket, ssl.SSLSocket], BinaryIO, Union[bytes, str]], None]) -> None: """ Send message in a thread safe manner. Handler is expected to be of the form...
@@ -242,7 +243,15 @@ class BaseSocket(object):
raise
- def _recv(self, handler: Callable[[socket.socket, BinaryIO], None]) -> bytes: + @overload + def _recv(self, handler: Callable[[ssl.SSLSocket, BinaryIO], bytes]) -> bytes: + ... + + @overload + def _recv(self, handler: Callable[[socket.socket, BinaryIO], stem.response.ControlMessage]) -> stem.response.ControlMessage: + ... + + def _recv(self, handler): """ Receives a message in a thread safe manner. Handler is expected to be of the form...
@@ -317,7 +326,7 @@ class BaseSocket(object):
pass
- def _make_socket(self) -> socket.socket: + def _make_socket(self) -> Union[socket.socket, ssl.SSLSocket]: """ Constructs and connects new socket. This is implemented by subclasses.
@@ -362,7 +371,7 @@ class RelaySocket(BaseSocket): if connect: self.connect()
- def send(self, message: str) -> None: + def send(self, message: Union[str, bytes]) -> None: """ Sends a message to the relay's ORPort.
@@ -389,26 +398,26 @@ class RelaySocket(BaseSocket): * :class:`stem.SocketClosed` if the socket closes before we receive a complete message """
- def wrapped_recv(s: socket.socket, sf: BinaryIO) -> bytes: + def wrapped_recv(s: ssl.SSLSocket, sf: BinaryIO) -> bytes: if timeout is None: - return s.recv() + return s.recv(1024) else: - s.setblocking(0) + s.setblocking(False) s.settimeout(timeout)
try: - return s.recv() + return s.recv(1024) except (socket.timeout, ssl.SSLError, ssl.SSLWantReadError): return None finally: - s.setblocking(1) + s.setblocking(True)
return self._recv(wrapped_recv)
def is_localhost(self) -> bool: return self.address == '127.0.0.1'
- def _make_socket(self) -> socket.socket: + def _make_socket(self) -> ssl.SSLSocket: try: relay_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) relay_socket.connect((self.address, self.port)) @@ -430,7 +439,7 @@ class ControlSocket(BaseSocket): def __init__(self) -> None: super(ControlSocket, self).__init__()
- def send(self, message: str) -> None: + def send(self, message: Union[bytes, str]) -> None: """ Formats and sends a message to the control socket. For more information see the :func:`~stem.socket.send_message` function. @@ -536,7 +545,7 @@ class ControlSocketFile(ControlSocket): raise stem.SocketError(exc)
-def send_message(control_file: BinaryIO, message: str, raw: bool = False) -> None: +def send_message(control_file: BinaryIO, message: Union[bytes, str], raw: bool = False) -> None: """ Sends a message to the control socket, adding the expected formatting for single verses multi-line messages. Neither message type should contain an @@ -568,6 +577,8 @@ def send_message(control_file: BinaryIO, message: str, raw: bool = False) -> Non * :class:`stem.SocketClosed` if the socket is known to be shut down """
+ message = stem.util.str_tools._to_unicode(message) + if not raw: message = send_formatting(message)
@@ -579,7 +590,7 @@ def send_message(control_file: BinaryIO, message: str, raw: bool = False) -> Non log.trace('Sent to tor:%s%s' % (msg_div, log_message))
-def _write_to_socket(socket_file: BinaryIO, message: str) -> None: +def _write_to_socket(socket_file: BinaryIO, message: Union[str, bytes]) -> None: try: socket_file.write(stem.util.str_tools._to_bytes(message)) socket_file.flush() @@ -618,7 +629,9 @@ def recv_message(control_file: BinaryIO, arrived_at: Optional[float] = None) -> a complete message """
- parsed_content, raw_content, first_line = None, None, True + parsed_content = [] # type: List[Tuple[str, str, bytes]] + raw_content = bytearray() + first_line = True
while True: try: @@ -649,10 +662,10 @@ def recv_message(control_file: BinaryIO, arrived_at: Optional[float] = None) -> log.info(ERROR_MSG % ('SocketClosed', 'empty socket content')) raise stem.SocketClosed('Received empty socket content.') elif not MESSAGE_PREFIX.match(line): - log.info(ERROR_MSG % ('ProtocolError', 'malformed status code/divider, "%s"' % log.escape(line))) + log.info(ERROR_MSG % ('ProtocolError', 'malformed status code/divider, "%s"' % log.escape(line.decode('utf-8')))) raise stem.ProtocolError('Badly formatted reply line: beginning is malformed') elif not line.endswith(b'\r\n'): - log.info(ERROR_MSG % ('ProtocolError', 'no CRLF linebreak, "%s"' % log.escape(line))) + log.info(ERROR_MSG % ('ProtocolError', 'no CRLF linebreak, "%s"' % log.escape(line.decode('utf-8')))) raise stem.ProtocolError('All lines should end with CRLF')
status_code, divider, content = line[:3], line[3:4], line[4:-2] # strip CRLF off content @@ -691,11 +704,11 @@ def recv_message(control_file: BinaryIO, arrived_at: Optional[float] = None) -> line = control_file.readline() raw_content += line except socket.error as exc: - log.info(ERROR_MSG % ('SocketClosed', 'received an exception while mid-way through a data reply (exception: "%s", read content: "%s")' % (exc, log.escape(bytes(raw_content))))) + log.info(ERROR_MSG % ('SocketClosed', 'received an exception while mid-way through a data reply (exception: "%s", read content: "%s")' % (exc, log.escape(bytes(raw_content).decode('utf-8'))))) raise stem.SocketClosed(exc)
if not line.endswith(b'\r\n'): - log.info(ERROR_MSG % ('ProtocolError', 'CRLF linebreaks missing from a data reply, "%s"' % log.escape(bytes(raw_content)))) + log.info(ERROR_MSG % ('ProtocolError', 'CRLF linebreaks missing from a data reply, "%s"' % log.escape(bytes(raw_content).decode('utf-8')))) raise stem.ProtocolError('All lines should end with CRLF') elif line == b'.\r\n': break # data block termination @@ -722,7 +735,7 @@ def recv_message(control_file: BinaryIO, arrived_at: Optional[float] = None) -> raise stem.ProtocolError("Unrecognized divider type '%s': %s" % (divider, stem.util.str_tools._to_unicode(line)))
-def send_formatting(message: str) -> None: +def send_formatting(message: str) -> str: """ Performs the formatting expected from sent control messages. For more information see the :func:`~stem.socket.send_message` function. diff --git a/stem/util/__init__.py b/stem/util/__init__.py index 050f6c91..498234cd 100644 --- a/stem/util/__init__.py +++ b/stem/util/__init__.py @@ -80,13 +80,15 @@ def datetime_to_unix(timestamp: 'datetime.datetime') -> float: return (timestamp - datetime.datetime(1970, 1, 1)).total_seconds()
-def _pubkey_bytes(key: Union['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey', 'cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PublicKey', 'cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey', 'cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey']) -> bytes: +def _pubkey_bytes(key: Union['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey', 'cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PublicKey', 'cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey', 'cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey']) -> bytes: # type: ignore """ Normalizes X25509 and ED25519 keys into their public key bytes. """
- if isinstance(key, (bytes, str)): + if isinstance(key, bytes): return key + elif isinstance(key, str): + return key.encode('utf-8')
try: from cryptography.hazmat.primitives import serialization diff --git a/stem/util/conf.py b/stem/util/conf.py index 37d1c5f4..1fd31fd0 100644 --- a/stem/util/conf.py +++ b/stem/util/conf.py @@ -162,8 +162,10 @@ import inspect import os import threading
+import stem.util.enum + from stem.util import log -from typing import Any, Callable, Mapping, Optional, Sequence, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Set, Union
CONFS = {} # mapping of identifier to singleton instances of configs
@@ -186,10 +188,10 @@ class _SyncListener(object): if interceptor_value: new_value = interceptor_value
- self.config_dict[key] = new_value + self.config_dict[key] = new_value # type: ignore
-def config_dict(handle: str, conf_mappings: Mapping[str, Any], handler: Optional[Callable[[str, Any], Any]] = None) -> Mapping[str, Any]: +def config_dict(handle: str, conf_mappings: Dict[str, Any], handler: Optional[Callable[[str, Any], Any]] = None) -> Dict[str, Any]: """ Makes a dictionary that stays synchronized with a configuration.
@@ -308,7 +310,7 @@ def parse_enum(key: str, value: str, enumeration: 'stem.util.enum.Enum') -> Any: return parse_enum_csv(key, value, enumeration, 1)[0]
-def parse_enum_csv(key: str, value: str, enumeration: 'stem.util.enum.Enum', count: Optional[Union[int, Sequence[int]]] = None) -> Sequence[Any]: +def parse_enum_csv(key: str, value: str, enumeration: 'stem.util.enum.Enum', count: Optional[Union[int, Sequence[int]]] = None) -> List[Any]: """ Parses a given value as being a comma separated listing of enumeration keys, returning the corresponding enumeration values. This is intended to be a @@ -449,15 +451,15 @@ class Config(object): """
def __init__(self) -> None: - self._path = None # location we last loaded from or saved to - self._contents = collections.OrderedDict() # configuration key/value pairs - self._listeners = [] # functors to be notified of config changes + self._path = None # type: Optional[str] # location we last loaded from or saved to + self._contents = collections.OrderedDict() # type: Dict[str, Any] # configuration key/value pairs + self._listeners = [] # type: List[Callable[['stem.util.conf.Config', str], Any]] # functors to be notified of config changes
# used for accessing _contents self._contents_lock = threading.RLock()
# keys that have been requested (used to provide unused config contents) - self._requested_keys = set() + self._requested_keys = set() # type: Set[str]
# flag to support lazy loading in uses_settings() self._settings_loaded = False @@ -577,7 +579,7 @@ class Config(object): self._contents.clear() self._requested_keys = set()
- def add_listener(self, listener: Callable[[str, Any], Any], backfill: bool = True) -> None: + def add_listener(self, listener: Callable[['stem.util.conf.Config', str], Any], backfill: bool = True) -> None: """ Registers the function to be notified of configuration updates. Listeners are expected to be functors which accept (config, key). @@ -600,7 +602,7 @@ class Config(object):
self._listeners = []
- def keys(self) -> Sequence[str]: + def keys(self) -> List[str]: """ Provides all keys in the currently loaded configuration.
@@ -609,7 +611,7 @@ class Config(object):
return list(self._contents.keys())
- def unused_keys(self) -> Sequence[str]: + def unused_keys(self) -> Set[str]: """ Provides the configuration keys that have never been provided to a caller via :func:`~stem.util.conf.config_dict` or the @@ -740,7 +742,7 @@ class Config(object):
return val
- def get_value(self, key: str, default: Optional[Any] = None, multiple: bool = False) -> Union[str, Sequence[str]]: + def get_value(self, key: str, default: Optional[Any] = None, multiple: bool = False) -> Union[str, List[str]]: """ This provides the current value associated with a given key.
diff --git a/stem/util/connection.py b/stem/util/connection.py index 2f815a46..21745c43 100644 --- a/stem/util/connection.py +++ b/stem/util/connection.py @@ -65,7 +65,7 @@ import stem.util.proc import stem.util.system
from stem.util import conf, enum, log, str_tools -from typing import Optional, Sequence, Union +from typing import List, Optional, Sequence, Tuple, Union
# Connection resolution is risky to log about since it's highly likely to # contain sensitive information. That said, it's also difficult to get right in @@ -158,15 +158,15 @@ class Connection(collections.namedtuple('Connection', ['local_address', 'local_p """
-def download(url: str, timeout: Optional[int] = None, retries: Optional[int] = None) -> bytes: +def download(url: str, timeout: Optional[float] = None, retries: Optional[int] = None) -> bytes: """ Download from the given url.
.. versionadded:: 1.8.0
:param str url: uncompressed url to download from - :param int timeout: timeout when connection becomes idle, no timeout applied - if **None** + :param float timeout: timeout when connection becomes idle, no timeout + applied if **None** :param int retires: maximum attempts to impose
:returns: **bytes** content of the given url @@ -186,17 +186,17 @@ def download(url: str, timeout: Optional[int] = None, retries: Optional[int] = N except socket.timeout as exc: raise stem.DownloadTimeout(url, exc, sys.exc_info()[2], timeout) except: - exc, stacktrace = sys.exc_info()[1:3] + exception, stacktrace = sys.exc_info()[1:3]
if timeout is not None: timeout -= time.time() - start_time
if retries > 0 and (timeout is None or timeout > 0): - log.debug('Failed to download from %s (%i retries remaining): %s' % (url, retries, exc)) + log.debug('Failed to download from %s (%i retries remaining): %s' % (url, retries, exception)) return download(url, timeout, retries - 1) else: - log.debug('Failed to download from %s: %s' % (url, exc)) - raise stem.DownloadFailed(url, exc, stacktrace) + log.debug('Failed to download from %s: %s' % (url, exception)) + raise stem.DownloadFailed(url, exception, stacktrace)
def get_connections(resolver: Optional['stem.util.connection.Resolver'] = None, process_pid: Optional[int] = None, process_name: Optional[str] = None) -> Sequence['stem.util.connection.Connection']: @@ -254,7 +254,7 @@ def get_connections(resolver: Optional['stem.util.connection.Resolver'] = None, raise ValueError('Process pid was non-numeric: %s' % process_pid)
if process_pid is None: - all_pids = stem.util.system.pid_by_name(process_name, True) + all_pids = stem.util.system.pid_by_name(process_name, True) # type: List[int] # type: ignore
if len(all_pids) == 0: if resolver in (Resolver.NETSTAT_WINDOWS, Resolver.PROC, Resolver.BSD_PROCSTAT): @@ -289,7 +289,7 @@ def get_connections(resolver: Optional['stem.util.connection.Resolver'] = None, connections = [] resolver_regex = re.compile(resolver_regex_str)
- def _parse_address_str(addr_type: str, addr_str: str, line: str) -> str: + def _parse_address_str(addr_type: str, addr_str: str, line: str) -> Tuple[str, int]: addr, port = addr_str.rsplit(':', 1)
if not is_valid_ipv4_address(addr) and not is_valid_ipv6_address(addr, allow_brackets = True): @@ -524,8 +524,15 @@ def is_valid_port(entry: Union[str, int, Sequence[str], Sequence[int]], allow_ze :returns: **True** if input is an integer and within the valid port range, **False** otherwise """
+ if isinstance(entry, (tuple, list)): + for port in entry: + if not is_valid_port(port, allow_zero): + return False + + return True + try: - value = int(entry) + value = int(entry) # type: ignore
if str(value) != str(entry): return False # invalid leading char, e.g. space or zero @@ -534,14 +541,7 @@ def is_valid_port(entry: Union[str, int, Sequence[str], Sequence[int]], allow_ze else: return value > 0 and value < 65536 except TypeError: - if isinstance(entry, (tuple, list)): - for port in entry: - if not is_valid_port(port, allow_zero): - return False - - return True - else: - return False + return False except ValueError: return False
@@ -621,6 +621,9 @@ def expand_ipv6_address(address: str) -> str: :raises: **ValueError** if the address can't be expanded due to being malformed """
+ if isinstance(address, bytes): + address = str_tools._to_unicode(address) + if not is_valid_ipv6_address(address): raise ValueError("'%s' isn't a valid IPv6 address" % address)
diff --git a/stem/util/enum.py b/stem/util/enum.py index b70d29f4..719a4c06 100644 --- a/stem/util/enum.py +++ b/stem/util/enum.py @@ -40,10 +40,10 @@ constructed as simple type listings... +- __iter__ - iterator over our enum keys """
-from typing import Iterator, Sequence +from typing import Any, Iterator, List, Sequence, Tuple, Union
-def UppercaseEnum(*args: str) -> 'stem.util.enum.Enum': +def UppercaseEnum(*args: str) -> 'Enum': """ Provides an :class:`~stem.util.enum.Enum` instance where the values are identical to the keys. Since the keys are uppercase by convention this means @@ -69,14 +69,15 @@ class Enum(object): Basic enumeration. """
- def __init__(self, *args: str) -> None: + def __init__(self, *args: Union[str, Tuple[str, Any]]) -> None: from stem.util.str_tools import _to_camel_case
# ordered listings of our keys and values - keys, values = [], [] + keys = [] # type: List[str] + values = [] # type: List[Any]
for entry in args: - if isinstance(entry, (bytes, str)): + if isinstance(entry, str): key, val = entry, _to_camel_case(entry) elif isinstance(entry, tuple) and len(entry) == 2: key, val = entry @@ -99,11 +100,11 @@ class Enum(object):
return list(self._keys)
- def index_of(self, value: str) -> int: + def index_of(self, value: Any) -> int: """ Provides the index of the given value in the collection.
- :param str value: entry to be looked up + :param object value: entry to be looked up
:returns: **int** index of the given entry
@@ -112,11 +113,11 @@ class Enum(object):
return self._values.index(value)
- def next(self, value: str) -> str: + def next(self, value: Any) -> Any: """ Provides the next enumeration after the given value.
- :param str value: enumeration for which to get the next entry + :param object value: enumeration for which to get the next entry
:returns: enum value following the given entry
@@ -129,11 +130,11 @@ class Enum(object): next_index = (self._values.index(value) + 1) % len(self._values) return self._values[next_index]
- def previous(self, value: str) -> str: + def previous(self, value: Any) -> Any: """ Provides the previous enumeration before the given value.
- :param str value: enumeration for which to get the previous entry + :param object value: enumeration for which to get the previous entry
:returns: enum value proceeding the given entry
@@ -146,13 +147,13 @@ class Enum(object): prev_index = (self._values.index(value) - 1) % len(self._values) return self._values[prev_index]
- def __getitem__(self, item: str) -> str: + def __getitem__(self, item: str) -> Any: """ Provides the values for the given key.
- :param str item: key to be looked up + :param str item: key to look up
- :returns: **str** with the value for the given key + :returns: value for the given key
:raises: **ValueError** if the key doesn't exist """ @@ -163,7 +164,7 @@ class Enum(object): keys = ', '.join(self.keys()) raise ValueError("'%s' isn't among our enumeration keys, which includes: %s" % (item, keys))
- def __iter__(self) -> Iterator[str]: + def __iter__(self) -> Iterator[Any]: """ Provides an ordered listing of the enums in this set. """ diff --git a/stem/util/log.py b/stem/util/log.py index 940469a3..404249a7 100644 --- a/stem/util/log.py +++ b/stem/util/log.py @@ -172,7 +172,7 @@ def log(runlevel: 'stem.util.log.Runlevel', message: str) -> None: LOGGER.log(LOG_VALUES[runlevel], message)
-def log_once(message_id: str, runlevel: 'stem.util.log.Runlevel', message: str) -> None: +def log_once(message_id: str, runlevel: 'stem.util.log.Runlevel', message: str) -> bool: """ Logs a message at the given runlevel. If a message with this ID has already been logged then this is a no-op. @@ -189,6 +189,7 @@ def log_once(message_id: str, runlevel: 'stem.util.log.Runlevel', message: str) else: DEDUPLICATION_MESSAGE_IDS.add(message_id) log(runlevel, message) + return True
# shorter aliases for logging at a runlevel
diff --git a/stem/util/proc.py b/stem/util/proc.py index 10f2ae60..e180bb66 100644 --- a/stem/util/proc.py +++ b/stem/util/proc.py @@ -56,7 +56,7 @@ import stem.util.enum import stem.util.str_tools
from stem.util import log -from typing import Any, Mapping, Optional, Sequence, Set, Tuple, Type +from typing import Any, Mapping, Optional, Sequence, Set, Tuple
try: # unavailable on windows (#19823) @@ -233,7 +233,7 @@ def memory_usage(pid: int) -> Tuple[int, int]: raise exc
-def stats(pid: int, *stat_types: 'stem.util.proc.Stat') -> Sequence[Any]: +def stats(pid: int, *stat_types: 'stem.util.proc.Stat') -> Sequence[str]: """ Provides process specific information. See the :data:`~stem.util.proc.Stat` enum for valid options. @@ -290,7 +290,7 @@ def stats(pid: int, *stat_types: 'stem.util.proc.Stat') -> Sequence[Any]: results.append(str(float(stat_comp[14]) / CLOCK_TICKS)) elif stat_type == Stat.START_TIME: if pid == 0: - return system_start_time() + results.append(str(system_start_time())) else: # According to documentation, starttime is in field 21 and the unit is # jiffies (clock ticks). We divide it for clock ticks, then add the @@ -452,7 +452,7 @@ def _inodes_for_sockets(pid: int) -> Set[bytes]: return inodes
-def _unpack_addr(addr: str) -> str: +def _unpack_addr(addr: bytes) -> str: """ Translates an address entry in the /proc/net/* contents to a human readable form (`reference http://linuxdevcenter.com/pub/a/linux/2000/11/16/LinuxAdmin.html`_, @@ -554,7 +554,7 @@ def _get_lines(file_path: str, line_prefixes: Sequence[str], parameter: str) -> raise
-def _log_runtime(parameter: str, proc_location: str, start_time: int) -> None: +def _log_runtime(parameter: str, proc_location: str, start_time: float) -> None: """ Logs a message indicating a successful proc query.
@@ -567,7 +567,7 @@ def _log_runtime(parameter: str, proc_location: str, start_time: int) -> None: log.debug('proc call (%s): %s (runtime: %0.4f)' % (parameter, proc_location, runtime))
-def _log_failure(parameter: str, exc: Type[Exception]) -> None: +def _log_failure(parameter: str, exc: BaseException) -> None: """ Logs a message indicating that the proc query failed.
diff --git a/stem/util/str_tools.py b/stem/util/str_tools.py index c606906a..a0bef734 100644 --- a/stem/util/str_tools.py +++ b/stem/util/str_tools.py @@ -26,7 +26,7 @@ import sys import stem.util import stem.util.enum
-from typing import Sequence, Tuple, Union +from typing import List, Sequence, Tuple, Union, overload
# label conversion tuples of the form... # (bits / bytes / seconds, short label, long label) @@ -73,7 +73,7 @@ def _to_bytes(msg: Union[str, bytes]) -> bytes: """
if isinstance(msg, str): - return codecs.latin_1_encode(msg, 'replace')[0] + return codecs.latin_1_encode(msg, 'replace')[0] # type: ignore else: return msg
@@ -95,7 +95,7 @@ def _to_unicode(msg: Union[str, bytes]) -> str: return msg
-def _decode_b64(msg: Union[str, bytes]) -> str: +def _decode_b64(msg: bytes) -> bytes: """ Base64 decode, without padding concerns. """ @@ -103,7 +103,7 @@ def _decode_b64(msg: Union[str, bytes]) -> str: missing_padding = len(msg) % 4 padding_chr = b'=' if isinstance(msg, bytes) else '='
- return base64.b64decode(msg + padding_chr * missing_padding) + return base64.b64decode(msg + (padding_chr * missing_padding))
def _to_int(msg: Union[str, bytes]) -> int: @@ -150,7 +150,17 @@ def _to_camel_case(label: str, divider: str = '_', joiner: str = ' ') -> str: return joiner.join(words)
-def _split_by_length(msg: str, size: int) -> Sequence[str]: +@overload +def _split_by_length(msg: bytes, size: int) -> List[bytes]: + ... + + +@overload +def _split_by_length(msg: str, size: int) -> List[str]: + ... + + +def _split_by_length(msg, size): """ Splits a string into a list of strings up to the given size.
@@ -174,7 +184,7 @@ def _split_by_length(msg: str, size: int) -> Sequence[str]: Ending = stem.util.enum.Enum('ELLIPSE', 'HYPHEN')
-def crop(msg: str, size: int, min_word_length: int = 4, min_crop: int = 0, ending: 'stem.util.str_tools.Ending' = Ending.ELLIPSE, get_remainder: bool = False) -> str: +def crop(msg: str, size: int, min_word_length: int = 4, min_crop: int = 0, ending: 'stem.util.str_tools.Ending' = Ending.ELLIPSE, get_remainder: bool = False) -> Union[str, Tuple[str, str]]: """ Shortens a string to a given length.
@@ -381,7 +391,7 @@ def time_labels(seconds: int, is_long: bool = False) -> Sequence[str]: for count_per_unit, _, _ in TIME_UNITS: if abs(seconds) >= count_per_unit: time_labels.append(_get_label(TIME_UNITS, seconds, 0, is_long)) - seconds %= count_per_unit + seconds %= int(count_per_unit)
return time_labels
@@ -413,7 +423,7 @@ def short_time_label(seconds: int) -> str:
for amount, _, label in TIME_UNITS: count = int(seconds / amount) - seconds %= amount + seconds %= int(amount) time_comp[label.strip()] = count
label = '%02i:%02i' % (time_comp['minute'], time_comp['second']) @@ -471,7 +481,7 @@ def parse_short_time_label(label: str) -> int: raise ValueError('Non-numeric value in time entry: %s' % label)
-def _parse_timestamp(entry: str) -> 'datetime.datetime': +def _parse_timestamp(entry: str) -> datetime.datetime: """ Parses the date and time that in format like like...
@@ -535,7 +545,7 @@ def _parse_iso_timestamp(entry: str) -> 'datetime.datetime': return timestamp + datetime.timedelta(microseconds = int(microseconds))
-def _get_label(units: Tuple[int, str, str], count: int, decimal: int, is_long: bool, round: bool = False) -> str: +def _get_label(units: Sequence[Tuple[float, str, str]], count: int, decimal: int, is_long: bool, round: bool = False) -> str: """ Provides label corresponding to units of the highest significance in the provided set. This rounds down (ie, integer truncation after visible units). @@ -580,3 +590,5 @@ def _get_label(units: Tuple[int, str, str], count: int, decimal: int, is_long: b return count_label + long_label + ('s' if is_plural else '') else: return count_label + short_label + + raise ValueError('BUG: value should always be divisible by a unit (%s)' % str(units)) diff --git a/stem/util/system.py b/stem/util/system.py index 8a61b2b9..a5147976 100644 --- a/stem/util/system.py +++ b/stem/util/system.py @@ -82,7 +82,7 @@ import stem.util.str_tools
from stem import UNDEFINED from stem.util import log -from typing import Any, Callable, Iterator, Mapping, Optional, Sequence, TextIO, Union +from typing import Any, BinaryIO, Callable, Collection, Dict, Iterator, List, Mapping, Optional, Sequence, Type, Union
State = stem.util.enum.UppercaseEnum( 'PENDING', @@ -98,11 +98,11 @@ SIZE_RECURSES = { dict: lambda d: itertools.chain.from_iterable(d.items()), set: iter, frozenset: iter, -} +} # type: Dict[Type, Callable]
# Mapping of commands to if they're available or not.
-CMD_AVAILABLE_CACHE = {} +CMD_AVAILABLE_CACHE = {} # type: Dict[str, bool]
# An incomplete listing of commands provided by the shell. Expand this as # needed. Some noteworthy things about shell commands... @@ -186,11 +186,11 @@ class CallError(OSError): :var str command: command that was ran :var int exit_status: exit code of the process :var float runtime: time the command took to run - :var str stdout: stdout of the process - :var str stderr: stderr of the process + :var bytes stdout: stdout of the process + :var bytes stderr: stderr of the process """
- def __init__(self, msg: str, command: str, exit_status: int, runtime: float, stdout: str, stderr: str) -> None: + def __init__(self, msg: str, command: str, exit_status: int, runtime: float, stdout: bytes, stderr: bytes) -> None: self.msg = msg self.command = command self.exit_status = exit_status @@ -211,7 +211,7 @@ class CallTimeoutError(CallError): :var float timeout: time we waited """
- def __init__(self, msg: str, command: str, exit_status: int, runtime: float, stdout: str, stderr: str, timeout: float) -> None: + def __init__(self, msg: str, command: str, exit_status: int, runtime: float, stdout: bytes, stderr: bytes, timeout: float) -> None: super(CallTimeoutError, self).__init__(msg, command, exit_status, runtime, stdout, stderr) self.timeout = timeout
@@ -242,8 +242,8 @@ class DaemonTask(object): self.result = None self.error = None
- self._process = None - self._pipe = None + self._process = None # type: Optional[multiprocessing.Process] + self._pipe = None # type: Optional[multiprocessing.connection.Connection]
if start: self.run() @@ -462,7 +462,7 @@ def is_running(command: Union[str, int, Sequence[str]]) -> bool: return None
-def size_of(obj: Any, exclude: Optional[Sequence[int]] = None) -> int: +def size_of(obj: Any, exclude: Optional[Collection[int]] = None) -> int: """ Provides the `approximate memory usage of an object https://code.activestate.com/recipes/577504/`_. This can recurse tuples, @@ -486,9 +486,9 @@ def size_of(obj: Any, exclude: Optional[Sequence[int]] = None) -> int: if platform.python_implementation() == 'PyPy': raise NotImplementedError('PyPy does not implement sys.getsizeof()')
- if exclude is None: - exclude = set() - elif id(obj) in exclude: + exclude = set(exclude) if exclude is not None else set() + + if id(obj) in exclude: return 0
try: @@ -548,7 +548,7 @@ def name_by_pid(pid: int) -> Optional[str]: return process_name
-def pid_by_name(process_name: str, multiple: bool = False) -> Union[int, Sequence[int]]: +def pid_by_name(process_name: str, multiple: bool = False) -> Union[int, List[int]]: """ Attempts to determine the process id for a running process, using...
@@ -996,10 +996,8 @@ def user(pid: int) -> Optional[str]: import pwd # only available on unix platforms
uid = stem.util.proc.uid(pid) - - if uid and uid.isdigit(): - return pwd.getpwuid(int(uid)).pw_name - except: + return pwd.getpwuid(uid).pw_name + except ImportError: pass
if is_available('ps'): @@ -1042,7 +1040,7 @@ def start_time(pid: str) -> Optional[float]: return None
-def tail(target: Union[str, TextIO], lines: Optional[int] = None) -> Iterator[str]: +def tail(target: Union[str, BinaryIO], lines: Optional[int] = None) -> Iterator[str]: """ Provides lines of a file starting with the end. For instance, 'tail -n 50 /tmp/my_log' could be done with... @@ -1061,8 +1059,8 @@ def tail(target: Union[str, TextIO], lines: Optional[int] = None) -> Iterator[st
if isinstance(target, str): with open(target, 'rb') as target_file: - for line in tail(target_file, lines): - yield line + for tail_line in tail(target_file, lines): + yield tail_line
return
@@ -1299,7 +1297,7 @@ def call(command: Union[str, Sequence[str]], default: Any = UNDEFINED, ignore_ex if timeout: while process.poll() is None: if time.time() - start_time > timeout: - raise CallTimeoutError("Process didn't finish after %0.1f seconds" % timeout, ' '.join(command_list), None, timeout, '', '', timeout) + raise CallTimeoutError("Process didn't finish after %0.1f seconds" % timeout, ' '.join(command_list), None, timeout, b'', b'', timeout)
time.sleep(0.001)
@@ -1313,11 +1311,11 @@ def call(command: Union[str, Sequence[str]], default: Any = UNDEFINED, ignore_ex trace_prefix = 'Received from system (%s)' % command
if stdout and stderr: - log.trace(trace_prefix + ', stdout:\n%s\nstderr:\n%s' % (stdout, stderr)) + log.trace(trace_prefix + ', stdout:\n%s\nstderr:\n%s' % (stdout.decode('utf-8'), stderr.decode('utf-8'))) elif stdout: - log.trace(trace_prefix + ', stdout:\n%s' % stdout) + log.trace(trace_prefix + ', stdout:\n%s' % stdout.decode('utf-8')) elif stderr: - log.trace(trace_prefix + ', stderr:\n%s' % stderr) + log.trace(trace_prefix + ', stderr:\n%s' % stderr.decode('utf-8'))
exit_status = process.poll()
diff --git a/stem/util/term.py b/stem/util/term.py index acc52cad..862767c4 100644 --- a/stem/util/term.py +++ b/stem/util/term.py @@ -72,14 +72,14 @@ CSI = '\x1B[%sm' RESET = CSI % '0'
-def encoding(*attrs: Union['stem.util.terminal.Color', 'stem.util.terminal.BgColor', 'stem.util.terminal.Attr']) -> Optional[str]: +def encoding(*attrs: Union['stem.util.term.Color', 'stem.util.term.BgColor', 'stem.util.term.Attr']) -> Optional[str]: """ Provides the ANSI escape sequence for these terminal color or attributes.
.. versionadded:: 1.5.0
- :param list attr: :data:`~stem.util.terminal.Color`, - :data:`~stem.util.terminal.BgColor`, or :data:`~stem.util.terminal.Attr` to + :param list attr: :data:`~stem.util.term.Color`, + :data:`~stem.util.term.BgColor`, or :data:`~stem.util.term.Attr` to provide an ecoding for
:returns: **str** of the ANSI escape sequence, **None** no attributes are @@ -99,9 +99,11 @@ def encoding(*attrs: Union['stem.util.terminal.Color', 'stem.util.terminal.BgCol
if term_encodings: return CSI % ';'.join(term_encodings) + else: + return None
-def format(msg: str, *attr: Union['stem.util.terminal.Color', 'stem.util.terminal.BgColor', 'stem.util.terminal.Attr']) -> str: +def format(msg: str, *attr: Union['stem.util.term.Color', 'stem.util.term.BgColor', 'stem.util.term.Attr']) -> str: """ Simple terminal text formatting using `ANSI escape sequences https://en.wikipedia.org/wiki/ANSI_escape_code#CSI_codes`_. @@ -125,12 +127,12 @@ def format(msg: str, *attr: Union['stem.util.terminal.Color', 'stem.util.termina """
msg = stem.util.str_tools._to_unicode(msg) + attr = list(attr)
if DISABLE_COLOR_SUPPORT: return msg
if Attr.LINES in attr: - attr = list(attr) attr.remove(Attr.LINES) lines = [format(line, *attr) for line in msg.split('\n')] return '\n'.join(lines) diff --git a/stem/util/test_tools.py b/stem/util/test_tools.py index 80de447e..34573450 100644 --- a/stem/util/test_tools.py +++ b/stem/util/test_tools.py @@ -44,7 +44,7 @@ import stem.util.conf import stem.util.enum import stem.util.system
-from typing import Any, Callable, Iterator, Mapping, Optional, Sequence, Tuple, Type +from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Sequence, Tuple, Type, Union
CONFIG = stem.util.conf.config_dict('test', { 'pycodestyle.ignore': [], @@ -53,8 +53,8 @@ CONFIG = stem.util.conf.config_dict('test', { 'exclude_paths': [], })
-TEST_RUNTIMES = {} -ASYNC_TESTS = {} +TEST_RUNTIMES: Dict[str, float] = {} +ASYNC_TESTS: Dict[str, 'stem.util.test_tools.AsyncTest'] = {}
AsyncStatus = stem.util.enum.UppercaseEnum('PENDING', 'RUNNING', 'FINISHED') AsyncResult = collections.namedtuple('AsyncResult', 'type msg') @@ -147,11 +147,11 @@ class AsyncTest(object):
self.method = lambda test: self.result(test) # method that can be mixed into TestCases
- self._process = None - self._process_pipe = None + self._process = None # type: Optional[Union[threading.Thread, multiprocessing.Process]] + self._process_pipe = None # type: Optional[multiprocessing.connection.Connection] self._process_lock = threading.RLock()
- self._result = None + self._result = None # type: Optional[stem.util.test_tools.AsyncResult] self._status = AsyncStatus.PENDING
def run(self, *runner_args: Any, **kwargs: Any) -> None: @@ -194,9 +194,9 @@ class AsyncTest(object): self._process.start() self._status = AsyncStatus.RUNNING
- def pid(self) -> int: + def pid(self) -> Optional[int]: with self._process_lock: - return self._process.pid if (self._process and not self._threaded) else None + return getattr(self._process, 'pid', None)
def join(self) -> None: self.result(None) @@ -238,9 +238,9 @@ class TimedTestRunner(unittest.TextTestRunner): .. versionadded:: 1.6.0 """
- def run(self, test: 'unittest.TestCase') -> None: - for t in test._tests: - original_type = type(t) + def run(self, test: Union[unittest.TestCase, unittest.TestSuite]) -> unittest.TestResult: + for t in getattr(test, '_tests', ()): + original_type = type(t) # type: Any
class _TestWrapper(original_type): def run(self, result: Optional[Any] = None) -> Any: @@ -273,7 +273,7 @@ class TimedTestRunner(unittest.TextTestRunner): return super(TimedTestRunner, self).run(test)
-def test_runtimes() -> Mapping[str, float]: +def test_runtimes() -> Dict[str, float]: """ Provides the runtimes of tests executed through TimedTestRunners.
@@ -286,7 +286,7 @@ def test_runtimes() -> Mapping[str, float]: return dict(TEST_RUNTIMES)
-def clean_orphaned_pyc(paths: Sequence[str]) -> Sequence[str]: +def clean_orphaned_pyc(paths: Sequence[str]) -> List[str]: """ Deletes any file with a \*.pyc extention without a corresponding \*.py. This helps to address a common gotcha when deleting python files... @@ -302,7 +302,7 @@ def clean_orphaned_pyc(paths: Sequence[str]) -> Sequence[str]:
:param list paths: paths to search for orphaned pyc files
- :returns: list of absolute paths that were deleted + :returns: **list** of absolute paths that were deleted """
orphaned_pyc = [] @@ -366,7 +366,7 @@ def is_mypy_available() -> bool: return _module_exists('mypy.api')
-def stylistic_issues(paths: Sequence[str], check_newlines: bool = False, check_exception_keyword: bool = False, prefer_single_quotes: bool = False) -> Mapping[str, 'stem.util.test_tools.Issue']: +def stylistic_issues(paths: Sequence[str], check_newlines: bool = False, check_exception_keyword: bool = False, prefer_single_quotes: bool = False) -> Dict[str, List['stem.util.test_tools.Issue']]: """ Checks for stylistic issues that are an issue according to the parts of PEP8 we conform to. You can suppress pycodestyle issues by making a 'test' @@ -424,7 +424,7 @@ def stylistic_issues(paths: Sequence[str], check_newlines: bool = False, check_e :returns: dict of paths list of :class:`stem.util.test_tools.Issue` instances """
- issues = {} + issues = {} # type: Dict[str, List[stem.util.test_tools.Issue]]
ignore_rules = [] ignore_for_file = [] @@ -505,7 +505,7 @@ def stylistic_issues(paths: Sequence[str], check_newlines: bool = False, check_e return issues
-def pyflakes_issues(paths: Sequence[str]) -> Mapping[str, 'stem.util.test_tools.Issue']: +def pyflakes_issues(paths: Sequence[str]) -> Dict[str, List['stem.util.test_tools.Issue']]: """ Performs static checks via pyflakes. False positives can be ignored via 'pyflakes.ignore' entries in our 'test' config. For instance... @@ -531,7 +531,7 @@ def pyflakes_issues(paths: Sequence[str]) -> Mapping[str, 'stem.util.test_tools. :returns: dict of paths list of :class:`stem.util.test_tools.Issue` instances """
- issues = {} + issues = {} # type: Dict[str, List[stem.util.test_tools.Issue]]
if is_pyflakes_available(): import pyflakes.api @@ -539,19 +539,19 @@ def pyflakes_issues(paths: Sequence[str]) -> Mapping[str, 'stem.util.test_tools.
class Reporter(pyflakes.reporter.Reporter): def __init__(self) -> None: - self._ignored_issues = {} + self._ignored_issues = {} # type: Dict[str, List[str]]
for line in CONFIG['pyflakes.ignore']: path, issue = line.split('=>') self._ignored_issues.setdefault(path.strip(), []).append(issue.strip())
- def unexpectedError(self, filename: str, msg: str) -> None: + def unexpectedError(self, filename: str, msg: 'pyflakes.messages.Message') -> None: self._register_issue(filename, None, msg, None)
def syntaxError(self, filename: str, msg: str, lineno: int, offset: int, text: str) -> None: self._register_issue(filename, lineno, msg, text)
- def flake(self, msg: str) -> None: + def flake(self, msg: 'pyflakes.messages.Message') -> None: self._register_issue(msg.filename, msg.lineno, msg.message % msg.message_args, None)
def _register_issue(self, path: str, line_number: int, issue: str, line: str) -> None: @@ -569,7 +569,7 @@ def pyflakes_issues(paths: Sequence[str]) -> Mapping[str, 'stem.util.test_tools. return issues
-def type_issues(paths: Sequence[str]) -> Mapping[str, 'stem.util.test_tools.Issue']: +def type_issues(args: Sequence[str]) -> Dict[str, List['stem.util.test_tools.Issue']]: """ Performs type checks via mypy. False positives can be ignored via 'mypy.ignore' entries in our 'test' config. For instance... @@ -578,23 +578,25 @@ def type_issues(paths: Sequence[str]) -> Mapping[str, 'stem.util.test_tools.Issu
mypy.ignore stem/util/system.py => Incompatible types in assignment*
- :param list paths: paths to search for problems + :param list args: mypy commandline arguments
:returns: dict of paths list of :class:`stem.util.test_tools.Issue` instances """
- issues = {} + issues = {} # type: Dict[str, List[stem.util.test_tools.Issue]]
if is_mypy_available(): import mypy.api
- ignored_issues = {} + ignored_issues = {} # type: Dict[str, List[str]]
for line in CONFIG['mypy.ignore']: path, issue = line.split('=>') ignored_issues.setdefault(path.strip(), []).append(issue.strip())
- lines = mypy.api.run(paths)[0].splitlines() # mypy returns (report, errors, exit_status) + # mypy returns (report, errors, exit_status) + + lines = mypy.api.run(args)[0].splitlines() # type: ignore
for line in lines: # example: @@ -606,13 +608,13 @@ def type_issues(paths: Sequence[str]) -> Mapping[str, 'stem.util.test_tools.Issu raise ValueError('Failed to parse mypy line: %s' % line)
path, line_number, _, issue = line.split(':', 3) - issue = issue.strip()
- if line_number.isdigit(): - line_number = int(line_number) - else: + if not line_number.isdigit(): raise ValueError('Malformed line number on: %s' % line)
+ issue = issue.strip() + line_number = int(line_number) + if _is_ignored(ignored_issues, path, issue): continue
@@ -660,16 +662,21 @@ def _python_files(paths: Sequence[str]) -> Iterator[str]:
def _is_ignored(config: Mapping[str, Sequence[str]], path: str, issue: str) -> bool: for ignored_path, ignored_issues in config.items(): - if path.endswith(ignored_path): - if issue in ignored_issues: - return True - - for prefix in [i[:1] for i in ignored_issues if i.endswith('*')]: - if issue.startswith(prefix): + if ignored_path == '*' or path.endswith(ignored_path): + for ignored_issue in ignored_issues: + if issue == ignored_issue: return True
- for suffix in [i[1:] for i in ignored_issues if i.startswith('*')]: - if issue.endswith(suffix): - return True + # TODO: try using glob module instead? + + if ignored_issue.startswith('*') and ignored_issue.endswith('*'): + if ignored_issue[1:-1] in issue: + return True # substring match + elif ignored_issue.startswith('*'): + if issue.endswith(ignored_issue[1:]): + return True # prefix match + elif ignored_issue.endswith('*'): + if issue.startswith(ignored_issue[:-1]): + return True # suffix match
return False diff --git a/stem/version.py b/stem/version.py index 8ec35293..6ef7c890 100644 --- a/stem/version.py +++ b/stem/version.py @@ -72,9 +72,9 @@ def get_system_tor_version(tor_cmd: str = 'tor') -> 'stem.version.Version':
if 'No such file or directory' in str(exc): if os.path.isabs(tor_cmd): - exc = "Unable to check tor's version. '%s' doesn't exist." % tor_cmd + raise IOError("Unable to check tor's version. '%s' doesn't exist." % tor_cmd) else: - exc = "Unable to run '%s'. Maybe tor isn't in your PATH?" % version_cmd + raise IOError("Unable to run '%s'. Maybe tor isn't in your PATH?" % version_cmd)
raise IOError(exc)
@@ -132,13 +132,12 @@ class Version(object): version_parts = VERSION_PATTERN.match(version_str)
if version_parts: - major, minor, micro, patch, status, extra_str, _ = version_parts.groups() + major, minor, micro, patch_str, status, extra_str, _ = version_parts.groups()
# The patch and status matches are optional (may be None) and have an extra # proceeding period or dash if they exist. Stripping those off.
- if patch: - patch = int(patch[1:]) + patch = int(patch_str[1:]) if patch_str else None
if status: status = status[1:] @@ -166,7 +165,7 @@ class Version(object):
return self.version_str
- def _compare(self, other: Any, method: Callable[[Any, Any], bool]) -> Callable[[Any, Any], bool]: + def _compare(self, other: Any, method: Callable[[Any, Any], bool]) -> bool: """ Compares version ordering according to the spec. """ diff --git a/test/arguments.py b/test/arguments.py index d0f0dc3f..e06148c4 100644 --- a/test/arguments.py +++ b/test/arguments.py @@ -5,13 +5,14 @@ Commandline argument parsing for our test runner. """
-import collections import getopt
import stem.util.conf import stem.util.log import test
+from typing import Any, Dict, List, NamedTuple, Optional, Sequence + LOG_TYPE_ERROR = """\ '%s' isn't a logging runlevel, use one of the following instead: TRACE, DEBUG, INFO, NOTICE, WARN, ERROR @@ -23,138 +24,136 @@ CONFIG = stem.util.conf.config_dict('test', { 'target.torrc': {}, })
-DEFAULT_ARGS = { - 'run_unit': False, - 'run_integ': False, - 'specific_test': [], - 'exclude_test': [], - 'logging_runlevel': None, - 'logging_path': None, - 'tor_path': 'tor', - 'run_targets': [test.Target.RUN_OPEN], - 'attribute_targets': [], - 'quiet': False, - 'verbose': False, - 'print_help': False, -} - OPT = 'auit:l:qvh' OPT_EXPANDED = ['all', 'unit', 'integ', 'targets=', 'test=', 'exclude-test=', 'log=', 'log-file=', 'tor=', 'quiet', 'verbose', 'help']
-def parse(argv): - """ - Parses our arguments, providing a named tuple with their values. +class Arguments(NamedTuple): + run_unit: bool = False + run_integ: bool = False + specific_test: List[str] = [] + exclude_test: List[str] = [] + logging_runlevel: Optional[str] = None + logging_path: Optional[str] = None + tor_path: str = 'tor' + run_targets: List['test.Target'] = [test.Target.RUN_OPEN] + attribute_targets: List['test.Target'] = [] + quiet: bool = False + verbose: bool = False + print_help: bool = False + + @staticmethod + def parse(argv: Sequence[str]) -> 'test.arguments.Arguments': + """ + Parses our commandline arguments into this class. + + :param list argv: input arguments to be parsed + + :returns: :class:`test.arguments.Arguments` for this commandline input + + :raises: **ValueError** if we got an invalid argument + """
- :param list argv: input arguments to be parsed + args = {} # type: Dict[str, Any]
- :returns: a **named tuple** with our parsed arguments + try: + recognized_args, unrecognized_args = getopt.getopt(argv, OPT, OPT_EXPANDED) # type: ignore
- :raises: **ValueError** if we got an invalid argument - """ + if unrecognized_args: + error_msg = "aren't recognized arguments" if len(unrecognized_args) > 1 else "isn't a recognized argument" + raise getopt.GetoptError("'%s' %s" % ("', '".join(unrecognized_args), error_msg)) + except Exception as exc: + raise ValueError('%s (for usage provide --help)' % exc)
- args = dict(DEFAULT_ARGS) - - try: - recognized_args, unrecognized_args = getopt.getopt(argv, OPT, OPT_EXPANDED) - - if unrecognized_args: - error_msg = "aren't recognized arguments" if len(unrecognized_args) > 1 else "isn't a recognized argument" - raise getopt.GetoptError("'%s' %s" % ("', '".join(unrecognized_args), error_msg)) - except Exception as exc: - raise ValueError('%s (for usage provide --help)' % exc) - - for opt, arg in recognized_args: - if opt in ('-a', '--all'): - args['run_unit'] = True - args['run_integ'] = True - elif opt in ('-u', '--unit'): - args['run_unit'] = True - elif opt in ('-i', '--integ'): - args['run_integ'] = True - elif opt in ('-t', '--targets'): - run_targets, attribute_targets = [], [] - - integ_targets = arg.split(',') - all_run_targets = [t for t in test.Target if CONFIG['target.torrc'].get(t) is not None] - - # validates the targets and split them into run and attribute targets - - if not integ_targets: - raise ValueError('No targets provided') - - for target in integ_targets: - if target not in test.Target: - raise ValueError('Invalid integration target: %s' % target) - elif target in all_run_targets: - run_targets.append(target) - else: - attribute_targets.append(target) - - # check if we were told to use all run targets - - if test.Target.RUN_ALL in attribute_targets: - attribute_targets.remove(test.Target.RUN_ALL) - run_targets = all_run_targets - - # if no RUN_* targets are provided then keep the default (otherwise we - # won't have any tests to run) - - if run_targets: - args['run_targets'] = run_targets - - args['attribute_targets'] = attribute_targets - elif opt == '--test': - args['specific_test'].append(crop_module_name(arg)) - elif opt == '--exclude-test': - args['exclude_test'].append(crop_module_name(arg)) - elif opt in ('-l', '--log'): - arg = arg.upper() - - if arg not in stem.util.log.LOG_VALUES: - raise ValueError(LOG_TYPE_ERROR % arg) - - args['logging_runlevel'] = arg - elif opt == '--log-file': - 
args['logging_path'] = arg - elif opt in ('--tor'): - args['tor_path'] = arg - elif opt in ('-q', '--quiet'): - args['quiet'] = True - elif opt in ('-v', '--verbose'): - args['verbose'] = True - elif opt in ('-h', '--help'): - args['print_help'] = True - - # translates our args dict into a named tuple - - Args = collections.namedtuple('Args', args.keys()) - return Args(**args) - - -def get_help(): - """ - Provides usage information, as provided by the '--help' argument. This - includes a listing of the valid integration targets. + for opt, arg in recognized_args: + if opt in ('-a', '--all'): + args['run_unit'] = True + args['run_integ'] = True + elif opt in ('-u', '--unit'): + args['run_unit'] = True + elif opt in ('-i', '--integ'): + args['run_integ'] = True + elif opt in ('-t', '--targets'): + run_targets, attribute_targets = [], []
- :returns: **str** with our usage information - """ + integ_targets = arg.split(',') + all_run_targets = [t for t in test.Target if CONFIG['target.torrc'].get(t) is not None] + + # validates the targets and split them into run and attribute targets + + if not integ_targets: + raise ValueError('No targets provided') + + for target in integ_targets: + if target not in test.Target: + raise ValueError('Invalid integration target: %s' % target) + elif target in all_run_targets: + run_targets.append(target) + else: + attribute_targets.append(target) + + # check if we were told to use all run targets + + if test.Target.RUN_ALL in attribute_targets: + attribute_targets.remove(test.Target.RUN_ALL) + run_targets = all_run_targets + + # if no RUN_* targets are provided then keep the default (otherwise we + # won't have any tests to run) + + if run_targets: + args['run_targets'] = run_targets + + args['attribute_targets'] = attribute_targets + elif opt == '--test': + args['specific_test'].append(crop_module_name(arg)) + elif opt == '--exclude-test': + args['exclude_test'].append(crop_module_name(arg)) + elif opt in ('-l', '--log'): + arg = arg.upper() + + if arg not in stem.util.log.LOG_VALUES: + raise ValueError(LOG_TYPE_ERROR % arg) + + args['logging_runlevel'] = arg + elif opt == '--log-file': + args['logging_path'] = arg + elif opt in ('--tor'): + args['tor_path'] = arg + elif opt in ('-q', '--quiet'): + args['quiet'] = True + elif opt in ('-v', '--verbose'): + args['verbose'] = True + elif opt in ('-h', '--help'): + args['print_help'] = True + + return Arguments(**args) + + @staticmethod + def get_help() -> str: + """ + Provides usage information, as provided by the '--help' argument. This + includes a listing of the valid integration targets. + + :returns: **str** with our usage information + """ + + help_msg = CONFIG['msg.help']
- help_msg = CONFIG['msg.help'] + # gets the longest target length so we can show the entries in columns
- # gets the longest target length so we can show the entries in columns - target_name_length = max(map(len, test.Target)) - description_format = '\n %%-%is - %%s' % target_name_length + target_name_length = max(map(len, test.Target)) + description_format = '\n %%-%is - %%s' % target_name_length
- for target in test.Target: - help_msg += description_format % (target, CONFIG['target.description'].get(target, '')) + for target in test.Target: + help_msg += description_format % (target, CONFIG['target.description'].get(target, ''))
- help_msg += '\n' + help_msg += '\n'
- return help_msg + return help_msg
-def crop_module_name(name): +def crop_module_name(name: str) -> str: """ Test modules have a 'test.unit.' or 'test.integ.' prefix which can be omitted from our '--test' argument. Cropping this so we can do diff --git a/test/mypy.ini b/test/mypy.ini new file mode 100644 index 00000000..1c77449a --- /dev/null +++ b/test/mypy.ini @@ -0,0 +1,6 @@ +[mypy] +allow_redefinition = True +ignore_missing_imports = True +show_error_codes = True +strict_optional = False +warn_unused_ignores = True diff --git a/test/settings.cfg b/test/settings.cfg index 8c6423bb..51109f96 100644 --- a/test/settings.cfg +++ b/test/settings.cfg @@ -196,10 +196,12 @@ pyflakes.ignore stem/manual.py => undefined name 'sqlite3' pyflakes.ignore stem/client/cell.py => undefined name 'cryptography' pyflakes.ignore stem/client/cell.py => undefined name 'hashlib' pyflakes.ignore stem/client/datatype.py => redefinition of unused 'pop' from * +pyflakes.ignore stem/descriptor/__init__.py => undefined name 'cryptography' pyflakes.ignore stem/descriptor/hidden_service.py => undefined name 'cryptography' pyflakes.ignore stem/interpreter/autocomplete.py => undefined name 'stem' pyflakes.ignore stem/interpreter/help.py => undefined name 'stem' pyflakes.ignore stem/response/events.py => undefined name 'datetime' +pyflakes.ignore stem/socket.py => redefinition of unused '_recv'* pyflakes.ignore stem/util/conf.py => undefined name 'stem' pyflakes.ignore stem/util/enum.py => undefined name 'stem' pyflakes.ignore test/require.py => 'cryptography.utils.int_from_bytes' imported but unused @@ -214,6 +216,23 @@ pyflakes.ignore test/unit/response/events.py => 'from stem import *' used; unabl pyflakes.ignore test/unit/response/events.py => *may be undefined, or defined from star imports: stem pyflakes.ignore test/integ/interpreter.py => 'readline' imported but unused
+# Our enum class confuses mypy. Ignore this until we can change to python 3.x's +# new enum builtin. +# +# For example... +# +# See https://mypy.readthedocs.io/en/latest/common_issues.html#variables-vs-type-a... +# Variable "stem.control.EventType" is not valid as a type [valid-type] + +mypy.ignore * => "Enum" has no attribute * +mypy.ignore * => "_IntegerEnum" has no attribute * +mypy.ignore * => See https://mypy.readthedocs.io/en/latest/common_issues.html* +mypy.ignore * => *is not valid as a type* + +# Metaprogramming prevents mypy from determining descriptor attributes. + +mypy.ignore * => "Descriptor" has no attribute "* + # Test modules we want to run. Modules are roughly ordered by the dependencies # so the lowest level tests come first. This is because a problem in say, # controller message parsing, will cause all higher level tests to fail too. diff --git a/test/task.py b/test/task.py index 2366564c..b2957e65 100644 --- a/test/task.py +++ b/test/task.py @@ -355,7 +355,7 @@ PYCODESTYLE_TASK = StaticCheckTask( MYPY_TASK = StaticCheckTask( 'running mypy', stem.util.test_tools.type_issues, - args = ([os.path.join(test.STEM_BASE, 'stem')],), + args = (['--config-file', os.path.join(test.STEM_BASE, 'test', 'mypy.ini'), os.path.join(test.STEM_BASE, 'stem')],), is_available = stem.util.test_tools.is_mypy_available(), unavailable_msg = MYPY_UNAVAILABLE, ) diff --git a/test/unit/client/address.py b/test/unit/client/address.py index c4e51f4e..352a8d8c 100644 --- a/test/unit/client/address.py +++ b/test/unit/client/address.py @@ -50,7 +50,7 @@ class TestAddress(unittest.TestCase): self.assertEqual(AddrType.UNKNOWN, addr.type) self.assertEqual(12, addr.type_int) self.assertEqual(None, addr.value) - self.assertEqual('hello', addr.value_bin) + self.assertEqual(b'hello', addr.value_bin)
def test_packing(self): test_data = { diff --git a/test/unit/control/controller.py b/test/unit/control/controller.py index 37e252b9..02ed2774 100644 --- a/test/unit/control/controller.py +++ b/test/unit/control/controller.py @@ -206,7 +206,7 @@ class TestControl(unittest.TestCase):
get_info_mock.side_effect = InvalidArguments
- get_conf_mock.side_effect = lambda param, **kwargs: { + get_conf_mock.side_effect = lambda param, *args, **kwargs: { 'ControlPort': '9050', 'ControlListenAddress': ['127.0.0.1'], }[param] @@ -217,7 +217,7 @@ class TestControl(unittest.TestCase):
# non-local addresss
- get_conf_mock.side_effect = lambda param, **kwargs: { + get_conf_mock.side_effect = lambda param, *args, **kwargs: { 'ControlPort': '9050', 'ControlListenAddress': ['27.4.4.1'], }[param] @@ -679,7 +679,7 @@ class TestControl(unittest.TestCase):
# check default if nothing was set
- get_conf_mock.side_effect = lambda param, **kwargs: { + get_conf_mock.side_effect = lambda param, *args, **kwargs: { 'BandwidthRate': '1073741824', 'BandwidthBurst': '1073741824', 'RelayBandwidthRate': '0', diff --git a/test/unit/descriptor/bandwidth_file.py b/test/unit/descriptor/bandwidth_file.py index 9bee5f95..5e56f9d2 100644 --- a/test/unit/descriptor/bandwidth_file.py +++ b/test/unit/descriptor/bandwidth_file.py @@ -7,6 +7,7 @@ import datetime import unittest
import stem.descriptor +import stem.util.str_tools
from unittest.mock import Mock, patch
@@ -334,5 +335,5 @@ class TestBandwidthFile(unittest.TestCase): )
for value in test_values: - expected_exc = "First line should be a unix timestamp, but was '%s'" % value + expected_exc = "First line should be a unix timestamp, but was '%s'" % stem.util.str_tools._to_unicode(value) self.assertRaisesWith(ValueError, expected_exc, BandwidthFile.create, {'timestamp': value}) diff --git a/test/unit/interpreter/arguments.py b/test/unit/interpreter/arguments.py index df81e7e3..d61de42d 100644 --- a/test/unit/interpreter/arguments.py +++ b/test/unit/interpreter/arguments.py @@ -1,39 +1,39 @@ import unittest
-from stem.interpreter.arguments import DEFAULT_ARGS, parse, get_help +from stem.interpreter.arguments import Arguments
class TestArgumentParsing(unittest.TestCase): def test_that_we_get_default_values(self): - args = parse([]) + args = Arguments.parse([])
- for attr in DEFAULT_ARGS: - self.assertEqual(DEFAULT_ARGS[attr], getattr(args, attr)) + for attr, value in Arguments._field_defaults.items(): + self.assertEqual(value, getattr(args, attr))
def test_that_we_load_arguments(self): - args = parse(['--interface', '10.0.0.25:80']) + args = Arguments.parse(['--interface', '10.0.0.25:80']) self.assertEqual('10.0.0.25', args.control_address) self.assertEqual(80, args.control_port)
- args = parse(['--interface', '80']) - self.assertEqual(DEFAULT_ARGS['control_address'], args.control_address) + args = Arguments.parse(['--interface', '80']) + self.assertEqual('127.0.0.1', args.control_address) self.assertEqual(80, args.control_port)
- args = parse(['--socket', '/tmp/my_socket']) + args = Arguments.parse(['--socket', '/tmp/my_socket']) self.assertEqual('/tmp/my_socket', args.control_socket)
- args = parse(['--help']) + args = Arguments.parse(['--help']) self.assertEqual(True, args.print_help)
def test_examples(self): - args = parse(['-i', '1643']) + args = Arguments.parse(['-i', '1643']) self.assertEqual(1643, args.control_port)
- args = parse(['-s', '~/.tor/socket']) + args = Arguments.parse(['-s', '~/.tor/socket']) self.assertEqual('~/.tor/socket', args.control_socket)
def test_that_we_reject_unrecognized_arguments(self): - self.assertRaises(ValueError, parse, ['--blarg', 'stuff']) + self.assertRaises(ValueError, Arguments.parse, ['--blarg', 'stuff'])
def test_that_we_reject_invalid_interfaces(self): invalid_inputs = ( @@ -49,15 +49,15 @@ class TestArgumentParsing(unittest.TestCase): )
for invalid_input in invalid_inputs: - self.assertRaises(ValueError, parse, ['--interface', invalid_input]) + self.assertRaises(ValueError, Arguments.parse, ['--interface', invalid_input])
def test_run_with_command(self): - self.assertEqual('GETINFO version', parse(['--run', 'GETINFO version']).run_cmd) + self.assertEqual('GETINFO version', Arguments.parse(['--run', 'GETINFO version']).run_cmd)
def test_run_with_path(self): - self.assertEqual(__file__, parse(['--run', __file__]).run_path) + self.assertEqual(__file__, Arguments.parse(['--run', __file__]).run_path)
def test_get_help(self): - help_text = get_help() + help_text = Arguments.get_help() self.assertTrue('Interactive interpreter for Tor.' in help_text) self.assertTrue('change control interface from 127.0.0.1:default' in help_text) diff --git a/test/unit/util/proc.py b/test/unit/util/proc.py index 2316f669..39087cbe 100644 --- a/test/unit/util/proc.py +++ b/test/unit/util/proc.py @@ -147,18 +147,17 @@ class TestProc(unittest.TestCase):
# tests the case where pid = 0
- if 'start time' in args: - response = 10 - else: - response = () - - for arg in args: - if arg == 'command': - response += ('sched',) - elif arg == 'utime': - response += ('0',) - elif arg == 'stime': - response += ('0',) + response = () + + for arg in args: + if arg == 'command': + response += ('sched',) + elif arg == 'utime': + response += ('0',) + elif arg == 'stime': + response += ('0',) + elif arg == 'start time': + response += ('10',)
get_line_mock.side_effect = lambda *params: { ('/proc/0/stat', '0', 'process %s' % ', '.join(args)): stat
tor-commits@lists.torproject.org