commit b8063b3b23af95e02b27848f6ab5c82edd644609 Author: Damian Johnson atagar@torproject.org Date: Tue Mar 24 17:55:31 2020 -0700
Type hints
Now that our minimum supported Python version is finally 3.5, we can provide type hints for our IDE users.
https://docs.python.org/3/library/typing.html
This is just a best effort first pass. I don't have an IDE that uses these. No doubt there will be mistakes that need adjustment. --- stem/__init__.py | 30 ++-- stem/client/__init__.py | 37 ++--- stem/client/cell.py | 128 +++++++++-------- stem/client/datatype.py | 84 +++++------ stem/connection.py | 44 +++--- stem/control.py | 241 ++++++++++++++++---------------- stem/descriptor/__init__.py | 66 ++++----- stem/descriptor/bandwidth_file.py | 26 ++-- stem/descriptor/certificate.py | 39 +++--- stem/descriptor/collector.py | 49 +++---- stem/descriptor/extrainfo_descriptor.py | 59 ++++---- stem/descriptor/hidden_service.py | 109 ++++++++------- stem/descriptor/microdescriptor.py | 18 +-- stem/descriptor/networkstatus.py | 122 ++++++++-------- stem/descriptor/remote.py | 67 ++++----- stem/descriptor/router_status_entry.py | 62 ++++---- stem/descriptor/server_descriptor.py | 74 +++++----- stem/descriptor/tordnsel.py | 9 +- stem/directory.py | 49 +++---- stem/exit_policy.py | 100 ++++++------- stem/interpreter/__init__.py | 6 +- stem/interpreter/arguments.py | 6 +- stem/interpreter/autocomplete.py | 9 +- stem/interpreter/commands.py | 22 +-- stem/interpreter/help.py | 8 +- stem/manual.py | 50 +++---- stem/process.py | 8 +- stem/response/__init__.py | 55 ++++---- stem/response/add_onion.py | 2 +- stem/response/authchallenge.py | 2 +- stem/response/events.py | 71 +++++----- stem/response/getconf.py | 2 +- stem/response/getinfo.py | 6 +- stem/response/mapaddress.py | 2 +- stem/response/protocolinfo.py | 5 +- stem/socket.py | 73 +++++----- stem/util/__init__.py | 10 +- stem/util/conf.py | 47 ++++--- stem/util/connection.py | 39 +++--- stem/util/enum.py | 18 +-- stem/util/log.py | 34 ++--- stem/util/proc.py | 34 ++--- stem/util/str_tools.py | 32 +++-- stem/util/system.py | 71 +++++----- stem/util/term.py | 10 +- stem/util/test_tools.py | 68 ++++----- stem/util/tor_tools.py | 25 ++-- stem/version.py | 22 +-- test/integ/control/controller.py | 8 +- 
test/integ/response/protocolinfo.py | 8 +- test/settings.cfg | 23 ++- test/unit/response/protocolinfo.py | 2 +- 52 files changed, 1143 insertions(+), 1048 deletions(-)
diff --git a/stem/__init__.py b/stem/__init__.py index 907156fe..c0efab19 100644 --- a/stem/__init__.py +++ b/stem/__init__.py @@ -507,6 +507,8 @@ import traceback import stem.util import stem.util.enum
+from typing import Any, Optional, Sequence + __version__ = '1.8.0-dev' __author__ = 'Damian Johnson' __contact__ = 'atagar@torproject.org' @@ -584,7 +586,7 @@ class Endpoint(object): :var int port: port of the endpoint """
- def __init__(self, address, port): + def __init__(self, address: str, port: int) -> None: if not stem.util.connection.is_valid_ipv4_address(address) and not stem.util.connection.is_valid_ipv6_address(address): raise ValueError("'%s' isn't a valid IPv4 or IPv6 address" % address) elif not stem.util.connection.is_valid_port(port): @@ -593,13 +595,13 @@ class Endpoint(object): self.address = address self.port = int(port)
- def __hash__(self): + def __hash__(self) -> int: return stem.util._hash_attr(self, 'address', 'port', cache = True)
- def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return hash(self) == hash(other) if isinstance(other, Endpoint) else False
- def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self == other
@@ -610,11 +612,11 @@ class ORPort(Endpoint): :var list link_protocols: link protocol version we're willing to establish """
- def __init__(self, address, port, link_protocols = None): + def __init__(self, address: str, port: int, link_protocols: Optional[Sequence[int]] = None) -> None: super(ORPort, self).__init__(address, port) self.link_protocols = link_protocols
- def __hash__(self): + def __hash__(self) -> int: return stem.util._hash_attr(self, 'link_protocols', parent = Endpoint, cache = True)
@@ -642,7 +644,7 @@ class OperationFailed(ControllerError): message """
- def __init__(self, code = None, message = None): + def __init__(self, code: Optional[str] = None, message: Optional[str] = None) -> None: super(ControllerError, self).__init__(message) self.code = code self.message = message @@ -658,10 +660,10 @@ class CircuitExtensionFailed(UnsatisfiableRequest): """ An attempt to create or extend a circuit failed.
- :var stem.response.CircuitEvent circ: response notifying us of the failure + :var stem.response.events.CircuitEvent circ: response notifying us of the failure """
- def __init__(self, message, circ = None): + def __init__(self, message: str, circ: Optional['stem.response.events.CircuitEvent'] = None) -> None: super(CircuitExtensionFailed, self).__init__(message = message) self.circ = circ
@@ -674,7 +676,7 @@ class DescriptorUnavailable(UnsatisfiableRequest): Subclassed under UnsatisfiableRequest rather than OperationFailed. """
- def __init__(self, message): + def __init__(self, message: str) -> None: super(DescriptorUnavailable, self).__init__(message = message)
@@ -685,7 +687,7 @@ class Timeout(UnsatisfiableRequest): .. versionadded:: 1.7.0 """
- def __init__(self, message): + def __init__(self, message: str) -> None: super(Timeout, self).__init__(message = message)
@@ -705,7 +707,7 @@ class InvalidArguments(InvalidRequest): :var list arguments: a list of arguments which were invalid """
- def __init__(self, code = None, message = None, arguments = None): + def __init__(self, code: Optional[str] = None, message: Optional[str] = None, arguments: Optional[Sequence[str]] = None) -> None: super(InvalidArguments, self).__init__(code, message) self.arguments = arguments
@@ -736,7 +738,7 @@ class DownloadFailed(IOError): :var str stacktrace_str: string representation of the stacktrace """
- def __init__(self, url, error, stacktrace, message = None): + def __init__(self, url: str, error: BaseException, stacktrace: Any, message: Optional[str] = None) -> None: if message is None: # The string representation of exceptions can reside in several places. # urllib.URLError use a 'reason' attribute that in turn may referrence @@ -773,7 +775,7 @@ class DownloadTimeout(DownloadFailed): .. versionadded:: 1.8.0 """
- def __init__(self, url, error, stacktrace, timeout): + def __init__(self, url: str, error: BaseException, stacktrace: Any, timeout: float) -> None: message = 'Failed to download from %s: %0.1f second timeout reached' % (url, timeout) super(DownloadTimeout, self).__init__(url, error, stacktrace, message)
diff --git a/stem/client/__init__.py b/stem/client/__init__.py index 57cd3457..2972985d 100644 --- a/stem/client/__init__.py +++ b/stem/client/__init__.py @@ -33,6 +33,9 @@ import stem.client.cell import stem.socket import stem.util.connection
+from types import TracebackType +from typing import Iterator, Optional, Sequence, Type, Union + from stem.client.cell import ( CELL_TYPE_SIZE, FIXED_PAYLOAD_LEN, @@ -63,7 +66,7 @@ class Relay(object): :var int link_protocol: link protocol version we established """
- def __init__(self, orport, link_protocol): + def __init__(self, orport: 'stem.socket.RelaySocket', link_protocol: int) -> None: self.link_protocol = LinkProtocol(link_protocol) self._orport = orport self._orport_buffer = b'' # unread bytes @@ -71,7 +74,7 @@ class Relay(object): self._circuits = {}
@staticmethod - def connect(address, port, link_protocols = DEFAULT_LINK_PROTOCOLS): + def connect(address: str, port: int, link_protocols: Sequence[int] = DEFAULT_LINK_PROTOCOLS) -> 'stem.client.Relay': """ Establishes a connection with the given ORPort.
@@ -144,7 +147,7 @@ class Relay(object):
return Relay(conn, link_protocol)
- def _recv(self, raw = False): + def _recv(self, raw: bool = False) -> Union['stem.client.cell.Cell', bytes]: """ Reads the next cell from our ORPort. If none is present this blocks until one is available. @@ -185,7 +188,7 @@ class Relay(object): cell, self._orport_buffer = Cell.pop(self._orport_buffer, self.link_protocol) return cell
- def _msg(self, cell): + def _msg(self, cell: 'stem.client.cell.Cell') -> Iterator['stem.client.cell.Cell']: """ Sends a cell on the ORPort and provides the response we receive in reply.
@@ -217,7 +220,7 @@ class Relay(object): for received_cell in stem.client.cell.Cell.pop(response, self.link_protocol): yield received_cell
- def is_alive(self): + def is_alive(self) -> bool: """ Checks if our socket is currently connected. This is a pass-through for our socket's :func:`~stem.socket.BaseSocket.is_alive` method. @@ -227,7 +230,7 @@ class Relay(object):
return self._orport.is_alive()
- def connection_time(self): + def connection_time(self) -> float: """ Provides the unix timestamp for when our socket was either connected or disconnected. That is to say, the time we connected if we're currently @@ -239,7 +242,7 @@ class Relay(object):
return self._orport.connection_time()
- def close(self): + def close(self) -> None: """ Closes our socket connection. This is a pass-through for our socket's :func:`~stem.socket.BaseSocket.close` method. @@ -248,7 +251,7 @@ class Relay(object): with self._orport_lock: return self._orport.close()
- def create_circuit(self): + def create_circuit(self) -> 'stem.client.Circuit': """ Establishes a new circuit. """ @@ -277,15 +280,15 @@ class Relay(object):
return circ
- def __iter__(self): + def __iter__(self) -> Iterator['stem.client.Circuit']: with self._orport_lock: for circ in self._circuits.values(): yield circ
- def __enter__(self): + def __enter__(self) -> 'stem.client.Relay': return self
- def __exit__(self, exit_type, value, traceback): + def __exit__(self, exit_type: Optional[Type[BaseException]], value: Optional[BaseException], traceback: Optional[TracebackType]) -> None: self.close()
@@ -304,7 +307,7 @@ class Circuit(object): :raises: **ImportError** if the cryptography module is unavailable """
- def __init__(self, relay, circ_id, kdf): + def __init__(self, relay: 'stem.client.Relay', circ_id: int, kdf: 'stem.client.datatype.KDF') -> None: try: from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from cryptography.hazmat.backends import default_backend @@ -320,7 +323,7 @@ class Circuit(object): self.forward_key = Cipher(algorithms.AES(kdf.forward_key), ctr, default_backend()).encryptor() self.backward_key = Cipher(algorithms.AES(kdf.backward_key), ctr, default_backend()).decryptor()
- def directory(self, request, stream_id = 0): + def directory(self, request: str, stream_id: int = 0) -> bytes: """ Request descriptors from the relay.
@@ -355,7 +358,7 @@ class Circuit(object): else: response.append(decrypted_cell)
- def _send(self, command, data = '', stream_id = 0): + def _send(self, command: Union[str, int], data: Union[str, bytes] = b'', stream_id: int = 0) -> None: """ Sends a message over the circuit.
@@ -375,13 +378,13 @@ class Circuit(object): self.forward_digest = forward_digest self.forward_key = forward_key
- def close(self): + def close(self) -> None: with self.relay._orport_lock: self.relay._orport.send(stem.client.cell.DestroyCell(self.id).pack(self.relay.link_protocol)) del self.relay._circuits[self.id]
- def __enter__(self): + def __enter__(self) -> 'stem.client.Circuit': return self
- def __exit__(self, exit_type, value, traceback): + def __exit__(self, exit_type: Optional[Type[BaseException]], value: Optional[BaseException], traceback: Optional[TracebackType]) -> None: self.close() diff --git a/stem/client/cell.py b/stem/client/cell.py index 83888556..ef445a64 100644 --- a/stem/client/cell.py +++ b/stem/client/cell.py @@ -49,6 +49,8 @@ from stem import UNDEFINED from stem.client.datatype import HASH_LEN, ZERO, LinkProtocol, Address, Certificate, CloseReason, RelayCommand, Size, split from stem.util import datetime_to_unix, str_tools
+from typing import Any, Iterator, Optional, Sequence, Tuple, Type, Union + FIXED_PAYLOAD_LEN = 509 # PAYLOAD_LEN, per tor-spec section 0.2 AUTH_CHALLENGE_SIZE = 32
@@ -96,17 +98,19 @@ class Cell(object): VALUE = -1 IS_FIXED_SIZE = False
- def __init__(self, unused = b''): + def __init__(self, unused: bytes = b'') -> None: super(Cell, self).__init__() self.unused = unused
@staticmethod - def by_name(name): + def by_name(name: str) -> Type['stem.client.cell.Cell']: """ Provides cell attributes by its name.
:param str name: cell command to fetch
+ :returns: cell class with this name + :raises: **ValueError** if cell type is invalid """
@@ -117,12 +121,14 @@ class Cell(object): raise ValueError("'%s' isn't a valid cell type" % name)
@staticmethod - def by_value(value): + def by_value(value: int) -> Type['stem.client.cell.Cell']: """ Provides cell attributes by its value.
:param int value: cell value to fetch
+ :returns: cell class with this numeric value + :raises: **ValueError** if cell type is invalid """
@@ -136,7 +142,7 @@ class Cell(object): raise NotImplementedError('Packing not yet implemented for %s cells' % type(self).NAME)
@staticmethod - def unpack(content, link_protocol): + def unpack(content: bytes, link_protocol: 'stem.client.datatype.LinkProtocol') -> Iterator['stem.client.cell.Cell']: """ Unpacks all cells from a response.
@@ -155,7 +161,7 @@ class Cell(object): yield cell
@staticmethod - def pop(content, link_protocol): + def pop(content: bytes, link_protocol: 'stem.client.datatype.LinkProtocol') -> Tuple['stem.client.cell.Cell', bytes]: """ Unpacks the first cell.
@@ -187,7 +193,7 @@ class Cell(object): return cls._unpack(payload, circ_id, link_protocol), content
@classmethod - def _pack(cls, link_protocol, payload, unused = b'', circ_id = None): + def _pack(cls: Type['stem.client.cell.Cell'], link_protocol: 'stem.client.datatype.LinkProtocol', payload: bytes, unused: bytes = b'', circ_id: Optional[int] = None) -> bytes: """ Provides bytes that can be used on the wire for these cell attributes. Format of a properly packed cell depends on if it's fixed or variable @@ -241,13 +247,13 @@ class Cell(object): return bytes(cell)
@classmethod - def _unpack(cls, content, circ_id, link_protocol): + def _unpack(cls: Type['stem.client.cell.Cell'], content: bytes, circ_id: int, link_protocol: 'stem.client.datatype.LinkProtocol') -> 'stem.client.cell.Cell': """ Subclass implementation for unpacking cell content.
:param bytes content: payload to decode - :param stem.client.datatype.LinkProtocol link_protocol: link protocol version :param int circ_id: circuit id cell is for + :param stem.client.datatype.LinkProtocol link_protocol: link protocol version
:returns: instance of this cell type
@@ -256,10 +262,10 @@ class Cell(object):
raise NotImplementedError('Unpacking not yet implemented for %s cells' % cls.NAME)
- def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return hash(self) == hash(other) if isinstance(other, Cell) else False
- def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self == other
@@ -270,7 +276,7 @@ class CircuitCell(Cell): :var int circ_id: circuit id """
- def __init__(self, circ_id, unused = b''): + def __init__(self, circ_id: int, unused: bytes = b'') -> None: super(CircuitCell, self).__init__(unused) self.circ_id = circ_id
@@ -286,7 +292,7 @@ class PaddingCell(Cell): VALUE = 0 IS_FIXED_SIZE = True
- def __init__(self, payload = None): + def __init__(self, payload: Optional[bytes] = None) -> None: if not payload: payload = os.urandom(FIXED_PAYLOAD_LEN) elif len(payload) != FIXED_PAYLOAD_LEN: @@ -295,14 +301,14 @@ class PaddingCell(Cell): super(PaddingCell, self).__init__() self.payload = payload
- def pack(self, link_protocol): + def pack(self, link_protocol: 'stem.client.datatype.LinkProtocol') -> bytes: return PaddingCell._pack(link_protocol, self.payload)
@classmethod - def _unpack(cls, content, circ_id, link_protocol): + def _unpack(cls, content: bytes, circ_id: int, link_protocol: 'stem.client.datatype.LinkProtocol') -> 'stem.client.cell.PaddingCell': return PaddingCell(content)
- def __hash__(self): + def __hash__(self) -> int: return stem.util._hash_attr(self, 'payload', cache = True)
@@ -311,7 +317,7 @@ class CreateCell(CircuitCell): VALUE = 1 IS_FIXED_SIZE = True
- def __init__(self): + def __init__(self) -> None: super(CreateCell, self).__init__() # TODO: implement
@@ -320,7 +326,7 @@ class CreatedCell(CircuitCell): VALUE = 2 IS_FIXED_SIZE = True
- def __init__(self): + def __init__(self) -> None: super(CreatedCell, self).__init__() # TODO: implement
@@ -346,7 +352,7 @@ class RelayCell(CircuitCell): VALUE = 3 IS_FIXED_SIZE = True
- def __init__(self, circ_id, command, data, digest = 0, stream_id = 0, recognized = 0, unused = b''): + def __init__(self, circ_id: int, command: Union[str, int], data: Union[str, bytes], digest: Union[int, bytes, str, 'hashlib.HASH'] = 0, stream_id: int = 0, recognized: int = 0, unused: bytes = b'') -> None: if 'hash' in str(type(digest)).lower(): # Unfortunately hashlib generates from a dynamic private class so # isinstance() isn't such a great option. With python2/python3 the @@ -375,7 +381,7 @@ class RelayCell(CircuitCell): elif stream_id and self.command in STREAM_ID_DISALLOWED: raise ValueError('%s relay cells concern the circuit itself and cannot have a stream id' % self.command)
- def pack(self, link_protocol): + def pack(self, link_protocol: 'stem.client.datatype.LinkProtocol') -> bytes: payload = bytearray() payload += Size.CHAR.pack(self.command_int) payload += Size.SHORT.pack(self.recognized) @@ -387,7 +393,7 @@ class RelayCell(CircuitCell): return RelayCell._pack(link_protocol, bytes(payload), self.unused, self.circ_id)
@staticmethod - def decrypt(link_protocol, content, key, digest): + def decrypt(link_protocol: 'stem.client.datatype.LinkProtocol', content: bytes, key: 'cryptography.hazmat.primitives.ciphers.CipherContext', digest: 'hashlib.HASH') -> Tuple['stem.client.cell.RelayCell', 'cryptography.hazmat.primitives.ciphers.CipherContext', 'hashlib.HASH']: """ Decrypts content as a relay cell addressed to us. This provides back a tuple of the form... @@ -441,7 +447,7 @@ class RelayCell(CircuitCell):
return cell, new_key, new_digest
- def encrypt(self, link_protocol, key, digest): + def encrypt(self, link_protocol: 'stem.client.datatype.LinkProtocol', key: 'cryptography.hazmat.primitives.ciphers.CipherContext', digest: 'hashlib.HASH') -> Tuple[bytes, 'cryptography.hazmat.primitives.ciphers.CipherContext', 'hashlib.HASH']: """ Encrypts our cell content to be sent with the given key. This provides back a tuple of the form... @@ -477,7 +483,7 @@ class RelayCell(CircuitCell): return header + new_key.update(payload), new_key, new_digest
@classmethod - def _unpack(cls, content, circ_id, link_protocol): + def _unpack(cls, content: bytes, circ_id: int, link_protocol: 'stem.client.datatype.LinkProtocol') -> 'stem.client.cell.RelayCell': command, content = Size.CHAR.pop(content) recognized, content = Size.SHORT.pop(content) # 'recognized' field stream_id, content = Size.SHORT.pop(content) @@ -490,7 +496,7 @@ class RelayCell(CircuitCell):
return RelayCell(circ_id, command, data, digest, stream_id, recognized, unused)
- def __hash__(self): + def __hash__(self) -> int: return stem.util._hash_attr(self, 'command_int', 'stream_id', 'digest', 'data', cache = True)
@@ -506,19 +512,19 @@ class DestroyCell(CircuitCell): VALUE = 4 IS_FIXED_SIZE = True
- def __init__(self, circ_id, reason = CloseReason.NONE, unused = b''): + def __init__(self, circ_id: int, reason: Union[str, int] = CloseReason.NONE, unused: bytes = b'') -> None: super(DestroyCell, self).__init__(circ_id, unused) self.reason, self.reason_int = CloseReason.get(reason)
- def pack(self, link_protocol): + def pack(self, link_protocol: 'stem.client.datatype.LinkProtocol') -> bytes: return DestroyCell._pack(link_protocol, Size.CHAR.pack(self.reason_int), self.unused, self.circ_id)
@classmethod - def _unpack(cls, content, circ_id, link_protocol): + def _unpack(cls: Type['stem.client.cell.DestroyCell'], content: bytes, circ_id: int, link_protocol: 'stem.client.datatype.LinkProtocol') -> 'stem.client.cell.DestroyCell': reason, unused = Size.CHAR.pop(content) return DestroyCell(circ_id, reason, unused)
- def __hash__(self): + def __hash__(self) -> int: return stem.util._hash_attr(self, 'circ_id', 'reason_int', cache = True)
@@ -534,7 +540,7 @@ class CreateFastCell(CircuitCell): VALUE = 5 IS_FIXED_SIZE = True
- def __init__(self, circ_id, key_material = None, unused = b''): + def __init__(self, circ_id: int, key_material: Optional[bytes] = None, unused: bytes = b'') -> None: if not key_material: key_material = os.urandom(HASH_LEN) elif len(key_material) != HASH_LEN: @@ -543,11 +549,11 @@ class CreateFastCell(CircuitCell): super(CreateFastCell, self).__init__(circ_id, unused) self.key_material = key_material
- def pack(self, link_protocol): + def pack(self, link_protocol: 'stem.client.datatype.LinkProtocol') -> bytes: return CreateFastCell._pack(link_protocol, self.key_material, self.unused, self.circ_id)
@classmethod - def _unpack(cls, content, circ_id, link_protocol): + def _unpack(cls, content: bytes, circ_id: int, link_protocol: 'stem.client.datatype.LinkProtocol') -> 'stem.client.cell.CreateFastCell': key_material, unused = split(content, HASH_LEN)
if len(key_material) != HASH_LEN: @@ -555,7 +561,7 @@ class CreateFastCell(CircuitCell):
return CreateFastCell(circ_id, key_material, unused)
- def __hash__(self): + def __hash__(self) -> int: return stem.util._hash_attr(self, 'circ_id', 'key_material', cache = True)
@@ -571,7 +577,7 @@ class CreatedFastCell(CircuitCell): VALUE = 6 IS_FIXED_SIZE = True
- def __init__(self, circ_id, derivative_key, key_material = None, unused = b''): + def __init__(self, circ_id: int, derivative_key: bytes, key_material: Optional[bytes] = None, unused: bytes = b'') -> None: if not key_material: key_material = os.urandom(HASH_LEN) elif len(key_material) != HASH_LEN: @@ -584,11 +590,11 @@ class CreatedFastCell(CircuitCell): self.key_material = key_material self.derivative_key = derivative_key
- def pack(self, link_protocol): + def pack(self, link_protocol: 'stem.client.datatype.LinkProtocol') -> bytes: return CreatedFastCell._pack(link_protocol, self.key_material + self.derivative_key, self.unused, self.circ_id)
@classmethod - def _unpack(cls, content, circ_id, link_protocol): + def _unpack(cls, content: bytes, circ_id: int, link_protocol: 'stem.client.datatype.LinkProtocol') -> 'stem.client.cell.CreatedFastCell': if len(content) < HASH_LEN * 2: raise ValueError('Key material and derivatived key should be %i bytes, but was %i' % (HASH_LEN * 2, len(content)))
@@ -597,7 +603,7 @@ class CreatedFastCell(CircuitCell):
return CreatedFastCell(circ_id, derivative_key, key_material, content)
- def __hash__(self): + def __hash__(self) -> int: return stem.util._hash_attr(self, 'circ_id', 'derivative_key', 'key_material', cache = True)
@@ -612,16 +618,16 @@ class VersionsCell(Cell): VALUE = 7 IS_FIXED_SIZE = False
- def __init__(self, versions): + def __init__(self, versions: Sequence[int]) -> None: super(VersionsCell, self).__init__() self.versions = versions
- def pack(self, link_protocol): + def pack(self, link_protocol: 'stem.client.datatype.LinkProtocol') -> bytes: payload = b''.join([Size.SHORT.pack(v) for v in self.versions]) return VersionsCell._pack(link_protocol, payload)
@classmethod - def _unpack(cls, content, circ_id, link_protocol): + def _unpack(cls: Type['stem.client.cell.VersionsCell'], content: bytes, circ_id: int, link_protocol: 'stem.client.datatype.LinkProtocol') -> 'stem.client.cell.VersionsCell': link_protocols = []
while content: @@ -630,7 +636,7 @@ class VersionsCell(Cell):
return VersionsCell(link_protocols)
- def __hash__(self): + def __hash__(self) -> int: return stem.util._hash_attr(self, 'versions', cache = True)
@@ -647,13 +653,13 @@ class NetinfoCell(Cell): VALUE = 8 IS_FIXED_SIZE = True
- def __init__(self, receiver_address, sender_addresses, timestamp = None, unused = b''): + def __init__(self, receiver_address: 'stem.client.datatype.Address', sender_addresses: Sequence['stem.client.datatype.Address'], timestamp: Optional[datetime.datetime] = None, unused: bytes = b'') -> None: super(NetinfoCell, self).__init__(unused) self.timestamp = timestamp if timestamp else datetime.datetime.now() self.receiver_address = receiver_address self.sender_addresses = sender_addresses
- def pack(self, link_protocol): + def pack(self, link_protocol: 'stem.client.datatype.LinkProtocol') -> bytes: payload = bytearray() payload += Size.LONG.pack(int(datetime_to_unix(self.timestamp))) payload += self.receiver_address.pack() @@ -665,7 +671,7 @@ class NetinfoCell(Cell): return NetinfoCell._pack(link_protocol, bytes(payload), self.unused)
@classmethod - def _unpack(cls, content, circ_id, link_protocol): + def _unpack(cls, content: bytes, circ_id: int, link_protocol: 'stem.client.datatype.LinkProtocol') -> 'stem.client.cell.NetinfoCell': timestamp, content = Size.LONG.pop(content) receiver_address, content = Address.pop(content)
@@ -678,7 +684,7 @@ class NetinfoCell(Cell):
return NetinfoCell(receiver_address, sender_addresses, datetime.datetime.utcfromtimestamp(timestamp), unused = content)
- def __hash__(self): + def __hash__(self) -> int: return stem.util._hash_attr(self, 'timestamp', 'receiver_address', 'sender_addresses', cache = True)
@@ -687,7 +693,7 @@ class RelayEarlyCell(CircuitCell): VALUE = 9 IS_FIXED_SIZE = True
- def __init__(self): + def __init__(self) -> None: super(RelayEarlyCell, self).__init__() # TODO: implement
@@ -696,7 +702,7 @@ class Create2Cell(CircuitCell): VALUE = 10 IS_FIXED_SIZE = True
- def __init__(self): + def __init__(self) -> None: super(Create2Cell, self).__init__() # TODO: implement
@@ -705,7 +711,7 @@ class Created2Cell(Cell): VALUE = 11 IS_FIXED_SIZE = True
- def __init__(self): + def __init__(self) -> None: super(Created2Cell, self).__init__() # TODO: implement
@@ -714,7 +720,7 @@ class PaddingNegotiateCell(Cell): VALUE = 12 IS_FIXED_SIZE = True
- def __init__(self): + def __init__(self) -> None: super(PaddingNegotiateCell, self).__init__() # TODO: implement
@@ -729,7 +735,7 @@ class VPaddingCell(Cell): VALUE = 128 IS_FIXED_SIZE = False
- def __init__(self, size = None, payload = None): + def __init__(self, size: Optional[int] = None, payload: Optional[bytes] = None) -> None: if size is None and payload is None: raise ValueError('VPaddingCell constructor must specify payload or size') elif size is not None and size < 0: @@ -740,14 +746,14 @@ class VPaddingCell(Cell): super(VPaddingCell, self).__init__() self.payload = payload if payload is not None else os.urandom(size)
- def pack(self, link_protocol): + def pack(self, link_protocol: 'stem.client.datatype.LinkProtocol') -> bytes: return VPaddingCell._pack(link_protocol, self.payload)
@classmethod - def _unpack(cls, content, circ_id, link_protocol): + def _unpack(cls, content: bytes, circ_id: int, link_protocol: 'stem.client.datatype.LinkProtocol') -> 'stem.client.cell.VPaddingCell': return VPaddingCell(payload = content)
- def __hash__(self): + def __hash__(self) -> int: return stem.util._hash_attr(self, 'payload', cache = True)
@@ -762,15 +768,15 @@ class CertsCell(Cell): VALUE = 129 IS_FIXED_SIZE = False
- def __init__(self, certs, unused = b''): + def __init__(self, certs: Sequence['stem.client.datatype.Certificate'], unused: bytes = b'') -> None: super(CertsCell, self).__init__(unused) self.certificates = certs
- def pack(self, link_protocol): + def pack(self, link_protocol: 'stem.client.datatype.LinkProtocol') -> bytes: return CertsCell._pack(link_protocol, Size.CHAR.pack(len(self.certificates)) + b''.join([cert.pack() for cert in self.certificates]), self.unused)
@classmethod - def _unpack(cls, content, circ_id, link_protocol): + def _unpack(cls, content: bytes, circ_id: int, link_protocol: 'stem.client.datatype.LinkProtocol') -> 'stem.client.cell.CertsCell': cert_count, content = Size.CHAR.pop(content) certs = []
@@ -783,7 +789,7 @@ class CertsCell(Cell):
return CertsCell(certs, unused = content)
- def __hash__(self): + def __hash__(self) -> int: return stem.util._hash_attr(self, 'certificates', cache = True)
@@ -800,7 +806,7 @@ class AuthChallengeCell(Cell): VALUE = 130 IS_FIXED_SIZE = False
- def __init__(self, methods, challenge = None, unused = b''): + def __init__(self, methods: Sequence[int], challenge: Optional[bytes] = None, unused: bytes = b'') -> None: if not challenge: challenge = os.urandom(AUTH_CHALLENGE_SIZE) elif len(challenge) != AUTH_CHALLENGE_SIZE: @@ -810,7 +816,7 @@ class AuthChallengeCell(Cell): self.challenge = challenge self.methods = methods
- def pack(self, link_protocol): + def pack(self, link_protocol: 'stem.client.datatype.LinkProtocol') -> bytes: payload = bytearray() payload += self.challenge payload += Size.SHORT.pack(len(self.methods)) @@ -821,7 +827,7 @@ class AuthChallengeCell(Cell): return AuthChallengeCell._pack(link_protocol, bytes(payload), self.unused)
@classmethod - def _unpack(cls, content, circ_id, link_protocol): + def _unpack(cls: Type['stem.client.cell.AuthChallengeCell'], content: bytes, circ_id: int, link_protocol: 'stem.client.datatype.LinkProtocol') -> 'stem.client.cell.AuthChallengeCell': min_size = AUTH_CHALLENGE_SIZE + Size.SHORT.size if len(content) < min_size: raise ValueError('AUTH_CHALLENGE payload should be at least %i bytes, but was %i' % (min_size, len(content))) @@ -840,7 +846,7 @@ class AuthChallengeCell(Cell):
return AuthChallengeCell(methods, challenge, unused = content)
- def __hash__(self): + def __hash__(self) -> int: return stem.util._hash_attr(self, 'challenge', 'methods', cache = True)
@@ -849,7 +855,7 @@ class AuthenticateCell(Cell): VALUE = 131 IS_FIXED_SIZE = False
- def __init__(self): + def __init__(self) -> None: super(AuthenticateCell, self).__init__() # TODO: implement
@@ -858,5 +864,5 @@ class AuthorizeCell(Cell): VALUE = 132 IS_FIXED_SIZE = False
- def __init__(self): + def __init__(self) -> None: super(AuthorizeCell, self).__init__() # TODO: implement diff --git a/stem/client/datatype.py b/stem/client/datatype.py index 4f7110e9..8d8ae7fb 100644 --- a/stem/client/datatype.py +++ b/stem/client/datatype.py @@ -144,6 +144,8 @@ import stem.util import stem.util.connection import stem.util.enum
+from typing import Any, Tuple, Type, Union + ZERO = b'\x00' HASH_LEN = 20 KEY_LEN = 16 @@ -155,7 +157,7 @@ class _IntegerEnum(stem.util.enum.Enum): **UNKNOWN** value for integer values that lack a mapping. """
- def __init__(self, *args): + def __init__(self, *args: Tuple[str, int]) -> None: self._enum_to_int = {} self._int_to_enum = {} parent_args = [] @@ -176,7 +178,7 @@ class _IntegerEnum(stem.util.enum.Enum): parent_args.append(('UNKNOWN', 'UNKNOWN')) super(_IntegerEnum, self).__init__(*parent_args)
- def get(self, val): + def get(self, val: Union[int, str]) -> Tuple[str, int]: """ Provides the (enum, int_value) tuple for a given value. """ @@ -246,7 +248,7 @@ CloseReason = _IntegerEnum( )
-def split(content, size): +def split(content: bytes, size: int) -> Tuple[bytes, bytes]: """ Simple split of bytes into two substrings.
@@ -270,7 +272,7 @@ class LinkProtocol(int): from a range that's determined by our link protocol. """
- def __new__(cls, version): + def __new__(cls: Type['stem.client.datatype.LinkProtocol'], version: int) -> 'stem.client.datatype.LinkProtocol': if isinstance(version, LinkProtocol): return version # already a LinkProtocol
@@ -284,14 +286,14 @@ class LinkProtocol(int):
return protocol
- def __hash__(self): + def __hash__(self) -> int: # All LinkProtocol attributes can be derived from our version, so that's # all we need in our hash. Offsetting by our type so we don't hash conflict # with ints.
return self.version * hash(str(type(self)))
- def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if isinstance(other, int): return self.version == other elif isinstance(other, LinkProtocol): @@ -299,10 +301,10 @@ class LinkProtocol(int): else: return False
- def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self == other
- def __int__(self): + def __int__(self) -> int: return self.version
@@ -311,7 +313,7 @@ class Field(object): Packable and unpackable datatype. """
- def pack(self): + def pack(self) -> bytes: """ Encodes field into bytes.
@@ -323,7 +325,7 @@ class Field(object): raise NotImplementedError('Not yet available')
@classmethod - def unpack(cls, packed): + def unpack(cls, packed: bytes) -> 'stem.client.datatype.Field': """ Decodes bytes into a field of this type.
@@ -342,7 +344,7 @@ class Field(object): return unpacked
@staticmethod - def pop(packed): + def pop(packed: bytes) -> Tuple[Any, bytes]: """ Decodes bytes as this field type, providing it and the remainder.
@@ -355,10 +357,10 @@ class Field(object):
raise NotImplementedError('Not yet available')
- def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return hash(self) == hash(other) if isinstance(other, Field) else False
- def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self == other
@@ -378,15 +380,15 @@ class Size(Field): ==================== =========== """
- def __init__(self, name, size): + def __init__(self, name: str, size: int) -> None: self.name = name self.size = size
@staticmethod - def pop(packed): + def pop(packed: bytes) -> Tuple[int, bytes]: raise NotImplementedError("Use our constant's unpack() and pop() instead")
- def pack(self, content): + def pack(self, content: int) -> bytes: try: return content.to_bytes(self.size, 'big') except: @@ -397,18 +399,18 @@ class Size(Field): else: raise
- def unpack(self, packed): + def unpack(self, packed: bytes) -> int: if self.size != len(packed): raise ValueError('%s is the wrong size for a %s field' % (repr(packed), self.name))
return int.from_bytes(packed, 'big')
- def pop(self, packed): + def pop(self, packed: bytes) -> Tuple[int, bytes]: to_unpack, remainder = split(packed, self.size)
return self.unpack(to_unpack), remainder
- def __hash__(self): + def __hash__(self) -> int: return stem.util._hash_attr(self, 'name', 'size', cache = True)
@@ -422,7 +424,7 @@ class Address(Field): :var bytes value_bin: encoded address value """
- def __init__(self, value, addr_type = None): + def __init__(self, value: str, addr_type: Union[int, 'stem.client.datatype.AddrType'] = None) -> None: if addr_type is None: if stem.util.connection.is_valid_ipv4_address(value): addr_type = AddrType.IPv4 @@ -461,7 +463,7 @@ class Address(Field): self.value = None self.value_bin = value
- def pack(self): + def pack(self) -> bytes: cell = bytearray() cell += Size.CHAR.pack(self.type_int) cell += Size.CHAR.pack(len(self.value_bin)) @@ -469,7 +471,7 @@ class Address(Field): return bytes(cell)
@staticmethod - def pop(content): + def pop(content: bytes) -> Tuple['stem.client.datatype.Address', bytes]: addr_type, content = Size.CHAR.pop(content) addr_length, content = Size.CHAR.pop(content)
@@ -480,7 +482,7 @@ class Address(Field):
return Address(addr_value, addr_type), content
- def __hash__(self): + def __hash__(self) -> int: return stem.util._hash_attr(self, 'type_int', 'value_bin', cache = True)
@@ -493,11 +495,11 @@ class Certificate(Field): :var bytes value: certificate value """
- def __init__(self, cert_type, value): + def __init__(self, cert_type: Union[int, 'stem.client.datatype.CertType'], value: bytes) -> None: self.type, self.type_int = CertType.get(cert_type) self.value = value
- def pack(self): + def pack(self) -> bytes: cell = bytearray() cell += Size.CHAR.pack(self.type_int) cell += Size.SHORT.pack(len(self.value)) @@ -505,7 +507,7 @@ class Certificate(Field): return bytes(cell)
@staticmethod - def pop(content): + def pop(content: bytes) -> Tuple['stem.client.datatype.Certificate', bytes]: cert_type, content = Size.CHAR.pop(content) cert_size, content = Size.SHORT.pop(content)
@@ -515,7 +517,7 @@ class Certificate(Field): cert_bytes, content = split(content, cert_size) return Certificate(cert_type, cert_bytes), content
- def __hash__(self): + def __hash__(self) -> int: return stem.util._hash_attr(self, 'type_int', 'value')
@@ -532,12 +534,12 @@ class LinkSpecifier(Field): :var bytes value: encoded link specification destination """
- def __init__(self, link_type, value): + def __init__(self, link_type: int, value: bytes) -> None: self.type = link_type self.value = value
@staticmethod - def pop(packed): + def pop(packed: bytes) -> Tuple['stem.client.datatype.LinkSpecifier', bytes]: # LSTYPE (Link specifier type) [1 byte] # LSLEN (Link specifier length) [1 byte] # LSPEC (Link specifier) [LSLEN bytes] @@ -561,7 +563,7 @@ class LinkSpecifier(Field): else: return LinkSpecifier(link_type, value), packed # unrecognized type
- def pack(self): + def pack(self) -> bytes: cell = bytearray() cell += Size.CHAR.pack(self.type) cell += Size.CHAR.pack(len(self.value)) @@ -579,14 +581,14 @@ class LinkByIPv4(LinkSpecifier): :var int port: relay ORPort """
- def __init__(self, address, port): + def __init__(self, address: str, port: int) -> None: super(LinkByIPv4, self).__init__(0, _pack_ipv4_address(address) + Size.SHORT.pack(port))
self.address = address self.port = port
@staticmethod - def unpack(value): + def unpack(value: bytes) -> 'stem.client.datatype.LinkByIPv4': if len(value) != 6: raise ValueError('IPv4 link specifiers should be six bytes, but was %i instead: %s' % (len(value), binascii.hexlify(value)))
@@ -604,14 +606,14 @@ class LinkByIPv6(LinkSpecifier): :var int port: relay ORPort """
- def __init__(self, address, port): + def __init__(self, address: str, port: int) -> None: super(LinkByIPv6, self).__init__(1, _pack_ipv6_address(address) + Size.SHORT.pack(port))
self.address = address self.port = port
@staticmethod - def unpack(value): + def unpack(value: bytes) -> 'stem.client.datatype.LinkByIPv6': if len(value) != 18: raise ValueError('IPv6 link specifiers should be eighteen bytes, but was %i instead: %s' % (len(value), binascii.hexlify(value)))
@@ -628,7 +630,7 @@ class LinkByFingerprint(LinkSpecifier): :var str fingerprint: relay sha1 fingerprint """
- def __init__(self, value): + def __init__(self, value: bytes) -> None: super(LinkByFingerprint, self).__init__(2, value)
if len(value) != 20: @@ -646,7 +648,7 @@ class LinkByEd25519(LinkSpecifier): :var str fingerprint: relay ed25519 fingerprint """
- def __init__(self, value): + def __init__(self, value: bytes) -> None: super(LinkByEd25519, self).__init__(3, value)
if len(value) != 32: @@ -668,7 +670,7 @@ class KDF(collections.namedtuple('KDF', ['key_hash', 'forward_digest', 'backward """
@staticmethod - def from_value(key_material): + def from_value(key_material: bytes) -> 'stem.client.datatype.KDF': # Derived key material, as per... # # K = H(K0 | [00]) | H(K0 | [01]) | H(K0 | [02]) | ... @@ -689,19 +691,19 @@ class KDF(collections.namedtuple('KDF', ['key_hash', 'forward_digest', 'backward return KDF(key_hash, forward_digest, backward_digest, forward_key, backward_key)
-def _pack_ipv4_address(address): +def _pack_ipv4_address(address: str) -> bytes: return b''.join([Size.CHAR.pack(int(v)) for v in address.split('.')])
-def _unpack_ipv4_address(value): +def _unpack_ipv4_address(value: bytes) -> str: return '.'.join([str(Size.CHAR.unpack(value[i:i + 1])) for i in range(4)])
-def _pack_ipv6_address(address): +def _pack_ipv6_address(address: str) -> bytes: return b''.join([Size.SHORT.pack(int(v, 16)) for v in address.split(':')])
-def _unpack_ipv6_address(value): +def _unpack_ipv6_address(value: bytes) -> str: return ':'.join(['%04x' % Size.SHORT.unpack(value[i * 2:(i + 1) * 2]) for i in range(8)])
diff --git a/stem/connection.py b/stem/connection.py index e3032784..3d3eb3ee 100644 --- a/stem/connection.py +++ b/stem/connection.py @@ -135,6 +135,7 @@ import os
import stem.control import stem.response +import stem.response.protocolinfo import stem.socket import stem.util.connection import stem.util.enum @@ -142,6 +143,7 @@ import stem.util.str_tools import stem.util.system import stem.version
+from typing import Any, Optional, Sequence, Tuple, Type, Union from stem.util import log
AuthMethod = stem.util.enum.Enum('NONE', 'PASSWORD', 'COOKIE', 'SAFECOOKIE', 'UNKNOWN') @@ -209,7 +211,7 @@ COMMON_TOR_COMMANDS = ( )
-def connect(control_port = ('127.0.0.1', 'default'), control_socket = '/var/run/tor/control', password = None, password_prompt = False, chroot_path = None, controller = stem.control.Controller): +def connect(control_port: Tuple[str, Union[str, int]] = ('127.0.0.1', 'default'), control_socket: str = '/var/run/tor/control', password: Optional[str] = None, password_prompt: bool = False, chroot_path: Optional[str] = None, controller: type = stem.control.Controller) -> Optional[Union[stem.control.BaseController, stem.socket.ControlSocket]]: """ Convenience function for quickly getting a control connection. This is very handy for debugging or CLI setup, handling setup and prompting for a password @@ -234,7 +236,7 @@ def connect(control_port = ('127.0.0.1', 'default'), control_socket = '/var/run/ Use both port 9051 and 9151 by default.
:param tuple contol_port: address and port tuple, for instance **('127.0.0.1', 9051)** - :param str path: path where the control socket is located + :param str control_socket: path where the control socket is located :param str password: passphrase to authenticate to the socket :param bool password_prompt: prompt for the controller password if it wasn't supplied @@ -295,7 +297,7 @@ def connect(control_port = ('127.0.0.1', 'default'), control_socket = '/var/run/ return _connect_auth(control_connection, password, password_prompt, chroot_path, controller)
-def _connect_auth(control_socket, password, password_prompt, chroot_path, controller): +def _connect_auth(control_socket: stem.socket.ControlSocket, password: Optional[str], password_prompt: bool, chroot_path: Optional[str], controller: Union[Type[stem.control.BaseController], Type[stem.socket.ControlSocket]]) -> Optional[Union[stem.control.BaseController, stem.socket.ControlSocket]]: """ Helper for the connect_* functions that authenticates the socket and constructs the controller. @@ -361,7 +363,7 @@ def _connect_auth(control_socket, password, password_prompt, chroot_path, contro return None
-def authenticate(controller, password = None, chroot_path = None, protocolinfo_response = None): +def authenticate(controller: Any, password: Optional[str] = None, chroot_path: Optional[str] = None, protocolinfo_response: Optional[stem.response.protocolinfo.ProtocolInfoResponse] = None) -> None: """ Authenticates to a control socket using the information provided by a PROTOCOLINFO response. In practice this will often be all we need to @@ -575,7 +577,7 @@ def authenticate(controller, password = None, chroot_path = None, protocolinfo_r raise AssertionError('BUG: Authentication failed without providing a recognized exception: %s' % str(auth_exceptions))
-def authenticate_none(controller, suppress_ctl_errors = True): +def authenticate_none(controller: Union[stem.control.BaseController, stem.socket.ControlSocket], suppress_ctl_errors: bool = True) -> None: """ Authenticates to an open control socket. All control connections need to authenticate before they can be used, even if tor hasn't been configured to @@ -622,7 +624,7 @@ def authenticate_none(controller, suppress_ctl_errors = True): raise OpenAuthRejected('Socket failed (%s)' % exc)
-def authenticate_password(controller, password, suppress_ctl_errors = True): +def authenticate_password(controller: Union[stem.control.BaseController, stem.socket.ControlSocket], password: str, suppress_ctl_errors: bool = True) -> None: """ Authenticates to a control socket that uses a password (via the HashedControlPassword torrc option). Quotes in the password are escaped. @@ -692,7 +694,7 @@ def authenticate_password(controller, password, suppress_ctl_errors = True): raise PasswordAuthRejected('Socket failed (%s)' % exc)
-def authenticate_cookie(controller, cookie_path, suppress_ctl_errors = True): +def authenticate_cookie(controller: Union[stem.control.BaseController, stem.socket.ControlSocket], cookie_path: str, suppress_ctl_errors: bool = True) -> None: """ Authenticates to a control socket that uses the contents of an authentication cookie (generated via the CookieAuthentication torrc option). This does basic @@ -782,7 +784,7 @@ def authenticate_cookie(controller, cookie_path, suppress_ctl_errors = True): raise CookieAuthRejected('Socket failed (%s)' % exc, cookie_path, False)
-def authenticate_safecookie(controller, cookie_path, suppress_ctl_errors = True): +def authenticate_safecookie(controller: Union[stem.control.BaseController, stem.socket.ControlSocket], cookie_path: str, suppress_ctl_errors: bool = True) -> None: """ Authenticates to a control socket using the safe cookie method, which is enabled by setting the CookieAuthentication torrc option on Tor client's which @@ -931,7 +933,7 @@ def authenticate_safecookie(controller, cookie_path, suppress_ctl_errors = True) raise CookieAuthRejected(str(auth_response), cookie_path, True, auth_response)
-def get_protocolinfo(controller): +def get_protocolinfo(controller: Union[stem.control.BaseController, stem.socket.ControlSocket]) -> stem.response.protocolinfo.ProtocolInfoResponse: """ Issues a PROTOCOLINFO query to a control socket, getting information about the tor process running on it. If the socket is already closed then it is @@ -971,7 +973,7 @@ def get_protocolinfo(controller): return protocolinfo_response
-def _msg(controller, message): +def _msg(controller: Union[stem.control.BaseController, stem.socket.ControlSocket], message: str) -> stem.response.ControlMessage: """ Sends and receives a message with either a :class:`~stem.socket.ControlSocket` or :class:`~stem.control.BaseController`. @@ -984,7 +986,7 @@ def _msg(controller, message): return controller.msg(message)
-def _connection_for_default_port(address): +def _connection_for_default_port(address: str) -> stem.socket.ControlPort: """ Attempts to provide a controller connection for either port 9051 (default for relays) or 9151 (default for Tor Browser). If both fail then this raises the @@ -1006,7 +1008,7 @@ def _connection_for_default_port(address): raise exc
-def _read_cookie(cookie_path, is_safecookie): +def _read_cookie(cookie_path: str, is_safecookie: bool) -> str: """ Provides the contents of a given cookie file.
@@ -1014,6 +1016,8 @@ def _read_cookie(cookie_path, is_safecookie): :param bool is_safecookie: **True** if this was for SAFECOOKIE authentication, **False** if for COOKIE
+ :returns: **str** with the cookie file content + :raises: * :class:`stem.connection.UnreadableCookieFile` if the cookie file is unreadable @@ -1048,7 +1052,7 @@ def _read_cookie(cookie_path, is_safecookie): raise UnreadableCookieFile(exc_msg, cookie_path, is_safecookie)
-def _hmac_sha256(key, msg): +def _hmac_sha256(key: str, msg: str) -> bytes: """ Generates a sha256 digest using the given key and message.
@@ -1065,11 +1069,11 @@ class AuthenticationFailure(Exception): """ Base error for authentication failures.
- :var stem.socket.ControlMessage auth_response: AUTHENTICATE response from the + :var stem.response.ControlMessage auth_response: AUTHENTICATE response from the control socket, **None** if one wasn't received """
- def __init__(self, message, auth_response = None): + def __init__(self, message: str, auth_response: Optional[stem.response.ControlMessage] = None) -> None: super(AuthenticationFailure, self).__init__(message) self.auth_response = auth_response
@@ -1081,7 +1085,7 @@ class UnrecognizedAuthMethods(AuthenticationFailure): :var list unknown_auth_methods: authentication methods that weren't recognized """
- def __init__(self, message, unknown_auth_methods): + def __init__(self, message: str, unknown_auth_methods: Sequence[str]) -> None: super(UnrecognizedAuthMethods, self).__init__(message) self.unknown_auth_methods = unknown_auth_methods
@@ -1125,7 +1129,7 @@ class CookieAuthFailed(AuthenticationFailure): authentication attempt """
- def __init__(self, message, cookie_path, is_safecookie, auth_response = None): + def __init__(self, message: str, cookie_path: str, is_safecookie: bool, auth_response: Optional[stem.response.ControlMessage] = None) -> None: super(CookieAuthFailed, self).__init__(message, auth_response) self.is_safecookie = is_safecookie self.cookie_path = cookie_path @@ -1152,7 +1156,7 @@ class AuthChallengeFailed(CookieAuthFailed): AUTHCHALLENGE command has failed. """
- def __init__(self, message, cookie_path): + def __init__(self, message: str, cookie_path: str) -> None: super(AuthChallengeFailed, self).__init__(message, cookie_path, True)
@@ -1169,7 +1173,7 @@ class UnrecognizedAuthChallengeMethod(AuthChallengeFailed): :var str authchallenge_method: AUTHCHALLENGE method that Tor couldn't recognize """
- def __init__(self, message, cookie_path, authchallenge_method): + def __init__(self, message: str, cookie_path: str, authchallenge_method: str) -> None: super(UnrecognizedAuthChallengeMethod, self).__init__(message, cookie_path) self.authchallenge_method = authchallenge_method
@@ -1201,7 +1205,7 @@ class NoAuthCookie(MissingAuthInfo): authentication, **False** if for COOKIE """
- def __init__(self, message, is_safecookie): + def __init__(self, message: str, is_safecookie: bool) -> None: super(NoAuthCookie, self).__init__(message) self.is_safecookie = is_safecookie
diff --git a/stem/control.py b/stem/control.py index 4016e762..ec4ba54e 100644 --- a/stem/control.py +++ b/stem/control.py @@ -255,7 +255,9 @@ import stem.descriptor.router_status_entry import stem.descriptor.server_descriptor import stem.exit_policy import stem.response +import stem.response.add_onion import stem.response.events +import stem.response.protocolinfo import stem.socket import stem.util import stem.util.conf @@ -268,6 +270,8 @@ import stem.version
from stem import UNDEFINED, CircStatus, Signal from stem.util import log +from types import TracebackType +from typing import Any, Callable, Dict, Iterator, Mapping, Optional, Sequence, Tuple, Type, Union
# When closing the controller we attempt to finish processing enqueued events, # but if it takes longer than this we terminate. @@ -447,15 +451,15 @@ class CreateHiddenServiceOutput(collections.namedtuple('CreateHiddenServiceOutpu """
-def with_default(yields = False): +def with_default(yields: bool = False) -> Callable: """ Provides a decorator to support having a default value. This should be treated as private. """
- def decorator(func): - def get_default(func, args, kwargs): - arg_names = inspect.getargspec(func).args[1:] # drop 'self' + def decorator(func: Callable) -> Callable: + def get_default(func: Callable, args: Any, kwargs: Any) -> Any: + arg_names = inspect.getfullargspec(func).args[1:] # drop 'self' default_position = arg_names.index('default') if 'default' in arg_names else None
if default_position is not None and default_position < len(args): @@ -465,7 +469,7 @@ def with_default(yields = False):
if not yields: @functools.wraps(func) - def wrapped(self, *args, **kwargs): + def wrapped(self, *args: Any, **kwargs: Any) -> Any: try: return func(self, *args, **kwargs) except: @@ -477,7 +481,7 @@ def with_default(yields = False): return default else: @functools.wraps(func) - def wrapped(self, *args, **kwargs): + def wrapped(self, *args: Any, **kwargs: Any) -> Any: try: for val in func(self, *args, **kwargs): yield val @@ -496,7 +500,7 @@ def with_default(yields = False): return decorator
-def event_description(event): +def event_description(event: str) -> str: """ Provides a description for Tor events.
@@ -538,7 +542,7 @@ class BaseController(object): socket as though it hasn't yet been authenticated. """
- def __init__(self, control_socket, is_authenticated = False): + def __init__(self, control_socket: stem.socket.ControlSocket, is_authenticated: bool = False) -> None: self._socket = control_socket self._msg_lock = threading.RLock()
@@ -576,7 +580,7 @@ class BaseController(object): if is_authenticated: self._post_authentication()
- def msg(self, message): + def msg(self, message: str) -> stem.response.ControlMessage: """ Sends a message to our control socket and provides back its reply.
@@ -659,7 +663,7 @@ class BaseController(object): self.close() raise
- def is_alive(self): + def is_alive(self) -> bool: """ Checks if our socket is currently connected. This is a pass-through for our socket's :func:`~stem.socket.BaseSocket.is_alive` method. @@ -669,7 +673,7 @@ class BaseController(object):
return self._socket.is_alive()
- def is_localhost(self): + def is_localhost(self) -> bool: """ Returns if the connection is for the local system or not.
@@ -680,7 +684,7 @@ class BaseController(object):
return self._socket.is_localhost()
- def connection_time(self): + def connection_time(self) -> float: """ Provides the unix timestamp for when our socket was either connected or disconnected. That is to say, the time we connected if we're currently @@ -694,7 +698,7 @@ class BaseController(object):
return self._socket.connection_time()
- def is_authenticated(self): + def is_authenticated(self) -> bool: """ Checks if our socket is both connected and authenticated.
@@ -704,7 +708,7 @@ class BaseController(object):
return self._is_authenticated if self.is_alive() else False
- def connect(self): + def connect(self) -> None: """ Reconnects our control socket. This is a pass-through for our socket's :func:`~stem.socket.ControlSocket.connect` method. @@ -714,7 +718,7 @@ class BaseController(object):
self._socket.connect()
- def close(self): + def close(self) -> None: """ Closes our socket connection. This is a pass-through for our socket's :func:`~stem.socket.BaseSocket.close` method. @@ -733,7 +737,7 @@ class BaseController(object): if t.is_alive() and threading.current_thread() != t: t.join()
- def get_socket(self): + def get_socket(self) -> stem.socket.ControlSocket: """ Provides the socket used to speak with the tor process. Communicating with the socket directly isn't advised since it may confuse this controller. @@ -743,7 +747,7 @@ class BaseController(object):
return self._socket
- def get_latest_heartbeat(self): + def get_latest_heartbeat(self) -> float: """ Provides the unix timestamp for when we last heard from tor. This is zero if we've never received a message. @@ -753,7 +757,7 @@ class BaseController(object):
return self._last_heartbeat
- def add_status_listener(self, callback, spawn = True): + def add_status_listener(self, callback: Callable[['stem.control.Controller', 'stem.control.State', float], None], spawn: bool = True) -> None: """ Notifies a given function when the state of our socket changes. Functions are expected to be of the form... @@ -783,7 +787,7 @@ class BaseController(object): with self._status_listeners_lock: self._status_listeners.append((callback, spawn))
- def remove_status_listener(self, callback): + def remove_status_listener(self, callback: Callable[['stem.control.Controller', 'stem.control.State', float], None]) -> bool: """ Stops listener from being notified of further events.
@@ -805,13 +809,13 @@ class BaseController(object): self._status_listeners = new_listeners return is_changed
- def __enter__(self): + def __enter__(self) -> 'stem.control.BaseController': return self
- def __exit__(self, exit_type, value, traceback): + def __exit__(self, exit_type: Optional[Type[BaseException]], value: Optional[BaseException], traceback: Optional[TracebackType]) -> None: self.close()
- def _handle_event(self, event_message): + def _handle_event(self, event_message: stem.response.ControlMessage) -> None: """ Callback to be overwritten by subclasses for event listening. This is notified whenever we receive an event from the control socket. @@ -822,13 +826,13 @@ class BaseController(object):
pass
- def _connect(self): + def _connect(self) -> None: self._launch_threads() self._notify_status_listeners(State.INIT) self._socket_connect() self._is_authenticated = False
- def _close(self): + def _close(self) -> None: # Our is_alive() state is now false. Our reader thread should already be # awake from recv() raising a closure exception. Wake up the event thread # too so it can end. @@ -846,12 +850,12 @@ class BaseController(object):
self._socket_close()
- def _post_authentication(self): + def _post_authentication(self) -> None: # actions to be taken after we have a newly authenticated connection
self._is_authenticated = True
- def _notify_status_listeners(self, state): + def _notify_status_listeners(self, state: 'stem.control.State') -> None: """ Informs our status listeners that a state change occurred.
@@ -895,7 +899,7 @@ class BaseController(object): else: listener(self, state, change_timestamp)
- def _launch_threads(self): + def _launch_threads(self) -> None: """ Initializes daemon threads. Threads can't be reused so we need to recreate them if we're restarted. @@ -915,7 +919,7 @@ class BaseController(object): self._event_thread.setDaemon(True) self._event_thread.start()
- def _reader_loop(self): + def _reader_loop(self) -> None: """ Continually pulls from the control socket, directing the messages into queues based on their type. Controller messages come in two varieties... @@ -944,7 +948,7 @@ class BaseController(object):
self._reply_queue.put(exc)
- def _event_loop(self): + def _event_loop(self) -> None: """ Continually pulls messages from the _event_queue and sends them to our handle_event callback. This is done via its own thread so subclasses with a @@ -982,7 +986,7 @@ class Controller(BaseController): """
@staticmethod - def from_port(address = '127.0.0.1', port = 'default'): + def from_port(address: str = '127.0.0.1', port: Union[int, str] = 'default') -> 'stem.control.Controller': """ Constructs a :class:`~stem.socket.ControlPort` based Controller.
@@ -1016,7 +1020,7 @@ class Controller(BaseController): return Controller(control_port)
@staticmethod - def from_socket_file(path = '/var/run/tor/control'): + def from_socket_file(path: str = '/var/run/tor/control') -> 'stem.control.Controller': """ Constructs a :class:`~stem.socket.ControlSocketFile` based Controller.
@@ -1030,7 +1034,7 @@ class Controller(BaseController): control_socket = stem.socket.ControlSocketFile(path) return Controller(control_socket)
- def __init__(self, control_socket, is_authenticated = False): + def __init__(self, control_socket: stem.socket.ControlSocket, is_authenticated: bool = False) -> None: self._is_caching_enabled = True self._request_cache = {} self._last_newnym = 0.0 @@ -1048,14 +1052,14 @@ class Controller(BaseController):
super(Controller, self).__init__(control_socket, is_authenticated)
- def _sighup_listener(event): + def _sighup_listener(event: stem.response.events.Event) -> None: if event.signal == Signal.RELOAD: self.clear_cache() self._notify_status_listeners(State.RESET)
self.add_event_listener(_sighup_listener, EventType.SIGNAL)
- def _confchanged_listener(event): + def _confchanged_listener(event: stem.response.events.Event) -> None: if self.is_caching_enabled(): to_cache_changed = dict((k.lower(), v) for k, v in event.changed.items()) to_cache_unset = dict((k.lower(), []) for k in event.unset) # [] represents None value in cache @@ -1070,7 +1074,7 @@ class Controller(BaseController):
self.add_event_listener(_confchanged_listener, EventType.CONF_CHANGED)
- def _address_changed_listener(event): + def _address_changed_listener(event: stem.response.events.Event) -> None: if event.action in ('EXTERNAL_ADDRESS', 'DNS_USELESS'): self._set_cache({'exit_policy': None}) self._set_cache({'address': None}, 'getinfo') @@ -1078,11 +1082,11 @@ class Controller(BaseController):
self.add_event_listener(_address_changed_listener, EventType.STATUS_SERVER)
- def close(self): + def close(self) -> None: self.clear_cache() super(Controller, self).close()
- def authenticate(self, *args, **kwargs): + def authenticate(self, *args: Any, **kwargs: Any) -> None: """ A convenience method to authenticate the controller. This is just a pass-through to :func:`stem.connection.authenticate`. @@ -1091,7 +1095,7 @@ class Controller(BaseController): import stem.connection stem.connection.authenticate(self, *args, **kwargs)
- def reconnect(self, *args, **kwargs): + def reconnect(self, *args: Any, **kwargs: Any) -> None: """ Reconnects and authenticates to our control socket.
@@ -1108,7 +1112,7 @@ class Controller(BaseController): self.authenticate(*args, **kwargs)
@with_default() - def get_info(self, params, default = UNDEFINED, get_bytes = False): + def get_info(self, params: Union[str, Sequence[str]], default: Any = UNDEFINED, get_bytes: bool = False) -> Union[str, Dict[str, str]]: """ get_info(params, default = UNDEFINED, get_bytes = False)
@@ -1232,7 +1236,7 @@ class Controller(BaseController): raise
@with_default() - def get_version(self, default = UNDEFINED): + def get_version(self, default: Any = UNDEFINED) -> stem.version.Version: """ get_version(default = UNDEFINED)
@@ -1261,7 +1265,7 @@ class Controller(BaseController): return version
@with_default() - def get_exit_policy(self, default = UNDEFINED): + def get_exit_policy(self, default: Any = UNDEFINED) -> stem.exit_policy.ExitPolicy: """ get_exit_policy(default = UNDEFINED)
@@ -1293,7 +1297,7 @@ class Controller(BaseController): return policy
@with_default() - def get_ports(self, listener_type, default = UNDEFINED): + def get_ports(self, listener_type: 'stem.control.Listener', default: Any = UNDEFINED) -> Sequence[int]: """ get_ports(listener_type, default = UNDEFINED)
@@ -1315,7 +1319,7 @@ class Controller(BaseController): and no default was provided """
- def is_localhost(address): + def is_localhost(address: str) -> bool: if stem.util.connection.is_valid_ipv4_address(address): return address == '0.0.0.0' or address.startswith('127.') elif stem.util.connection.is_valid_ipv6_address(address): @@ -1330,7 +1334,7 @@ class Controller(BaseController): return [port for (addr, port) in self.get_listeners(listener_type) if is_localhost(addr)]
@with_default() - def get_listeners(self, listener_type, default = UNDEFINED): + def get_listeners(self, listener_type: 'stem.control.Listener', default: Any = UNDEFINED) -> Sequence[Tuple[str, int]]: """ get_listeners(listener_type, default = UNDEFINED)
@@ -1436,7 +1440,7 @@ class Controller(BaseController): return listeners
@with_default() - def get_accounting_stats(self, default = UNDEFINED): + def get_accounting_stats(self, default: Any = UNDEFINED) -> 'stem.control.AccountingStats': """ get_accounting_stats(default = UNDEFINED)
@@ -1480,7 +1484,7 @@ class Controller(BaseController): )
@with_default() - def get_protocolinfo(self, default = UNDEFINED): + def get_protocolinfo(self, default: Any = UNDEFINED) -> stem.response.protocolinfo.ProtocolInfoResponse: """ get_protocolinfo(default = UNDEFINED)
@@ -1503,7 +1507,7 @@ class Controller(BaseController): return stem.connection.get_protocolinfo(self)
@with_default() - def get_user(self, default = UNDEFINED): + def get_user(self, default: Any = UNDEFINED) -> str: """ get_user(default = UNDEFINED)
@@ -1538,7 +1542,7 @@ class Controller(BaseController): raise ValueError("Unable to resolve tor's user" if self.is_localhost() else "Tor isn't running locally")
@with_default() - def get_pid(self, default = UNDEFINED): + def get_pid(self, default: Any = UNDEFINED) -> int: """ get_pid(default = UNDEFINED)
@@ -1594,7 +1598,7 @@ class Controller(BaseController): raise ValueError("Unable to resolve tor's pid" if self.is_localhost() else "Tor isn't running locally")
@with_default() - def get_start_time(self, default = UNDEFINED): + def get_start_time(self, default: Any = UNDEFINED) -> float: """ get_start_time(default = UNDEFINED)
@@ -1644,7 +1648,7 @@ class Controller(BaseController): raise ValueError("Unable to resolve when tor began" if self.is_localhost() else "Tor isn't running locally")
@with_default() - def get_uptime(self, default = UNDEFINED): + def get_uptime(self, default: Any = UNDEFINED) -> float: """ get_uptime(default = UNDEFINED)
@@ -1662,7 +1666,7 @@ class Controller(BaseController):
return time.time() - self.get_start_time()
- def is_user_traffic_allowed(self): + def is_user_traffic_allowed(self) -> 'stem.control.UserTrafficAllowed': """ Checks if we're likely to service direct user traffic. This essentially boils down to... @@ -1704,7 +1708,7 @@ class Controller(BaseController): return UserTrafficAllowed(inbound_allowed, outbound_allowed)
@with_default() - def get_microdescriptor(self, relay = None, default = UNDEFINED): + def get_microdescriptor(self, relay: Optional[str] = None, default: Any = UNDEFINED) -> stem.descriptor.microdescriptor.Microdescriptor: """ get_microdescriptor(relay = None, default = UNDEFINED)
@@ -1762,7 +1766,7 @@ class Controller(BaseController): return stem.descriptor.microdescriptor.Microdescriptor(desc_content)
@with_default(yields = True) - def get_microdescriptors(self, default = UNDEFINED): + def get_microdescriptors(self, default: Any = UNDEFINED) -> Iterator[stem.descriptor.microdescriptor.Microdescriptor]: """ get_microdescriptors(default = UNDEFINED)
@@ -1793,7 +1797,7 @@ class Controller(BaseController): yield desc
@with_default() - def get_server_descriptor(self, relay = None, default = UNDEFINED): + def get_server_descriptor(self, relay: Optional[str] = None, default: Any = UNDEFINED) -> stem.descriptor.server_descriptor.RelayDescriptor: """ get_server_descriptor(relay = None, default = UNDEFINED)
@@ -1856,7 +1860,7 @@ class Controller(BaseController): return stem.descriptor.server_descriptor.RelayDescriptor(desc_content)
@with_default(yields = True) - def get_server_descriptors(self, default = UNDEFINED): + def get_server_descriptors(self, default: Any = UNDEFINED) -> Iterator[stem.descriptor.server_descriptor.RelayDescriptor]: """ get_server_descriptors(default = UNDEFINED)
@@ -1892,7 +1896,7 @@ class Controller(BaseController): yield desc
@with_default() - def get_network_status(self, relay = None, default = UNDEFINED): + def get_network_status(self, relay: Optional[str] = None, default: Any = UNDEFINED) -> stem.descriptor.router_status_entry.RouterStatusEntryV3: """ get_network_status(relay = None, default = UNDEFINED)
@@ -1951,7 +1955,7 @@ class Controller(BaseController): return stem.descriptor.router_status_entry.RouterStatusEntryV3(desc_content)
@with_default(yields = True) - def get_network_statuses(self, default = UNDEFINED): + def get_network_statuses(self, default: Any = UNDEFINED) -> Iterator[stem.descriptor.router_status_entry.RouterStatusEntryV3]: """ get_network_statuses(default = UNDEFINED)
@@ -1988,7 +1992,7 @@ class Controller(BaseController): yield desc
@with_default() - def get_hidden_service_descriptor(self, address, default = UNDEFINED, servers = None, await_result = True, timeout = None): + def get_hidden_service_descriptor(self, address: str, default: Any = UNDEFINED, servers: Optional[Sequence[str]] = None, await_result: bool = True, timeout: Optional[float] = None) -> stem.descriptor.hidden_service.HiddenServiceDescriptorV2: """ get_hidden_service_descriptor(address, default = UNDEFINED, servers = None, await_result = True)
@@ -2036,10 +2040,10 @@ class Controller(BaseController): start_time = time.time()
if await_result: - def hs_desc_listener(event): + def hs_desc_listener(event: stem.response.events.Event) -> None: hs_desc_queue.put(event)
- def hs_desc_content_listener(event): + def hs_desc_content_listener(event: stem.response.events.Event) -> None: hs_desc_content_queue.put(event)
self.add_event_listener(hs_desc_listener, EventType.HS_DESC) @@ -2084,7 +2088,7 @@ class Controller(BaseController): if hs_desc_content_listener: self.remove_event_listener(hs_desc_content_listener)
- def get_conf(self, param, default = UNDEFINED, multiple = False): + def get_conf(self, param: str, default: Any = UNDEFINED, multiple: bool = False) -> Union[str, Sequence[str]]: """ get_conf(param, default = UNDEFINED, multiple = False)
@@ -2133,7 +2137,7 @@ class Controller(BaseController): entries = self.get_conf_map(param, default, multiple) return _case_insensitive_lookup(entries, param, default)
- def get_conf_map(self, params, default = UNDEFINED, multiple = True): + def get_conf_map(self, params: Union[str, Sequence[str]], default: Any = UNDEFINED, multiple: bool = True) -> Dict[str, Union[str, Sequence[str]]]: """ get_conf_map(params, default = UNDEFINED, multiple = True)
@@ -2251,7 +2255,7 @@ class Controller(BaseController): else: raise
- def _get_conf_dict_to_response(self, config_dict, default, multiple): + def _get_conf_dict_to_response(self, config_dict: Mapping[str, Sequence[str]], default: Any, multiple: bool) -> Dict[str, Union[str, Sequence[str]]]: """ Translates a dictionary of 'config key => [value1, value2...]' into the return value of :func:`~stem.control.Controller.get_conf_map`, taking into @@ -2273,7 +2277,7 @@ class Controller(BaseController): return return_dict
@with_default() - def is_set(self, param, default = UNDEFINED): + def is_set(self, param: str, default: Any = UNDEFINED) -> bool: """ is_set(param, default = UNDEFINED)
@@ -2293,7 +2297,7 @@ class Controller(BaseController):
return param in self._get_custom_options()
- def _get_custom_options(self): + def _get_custom_options(self) -> Dict[str, str]: result = self._get_cache('get_custom_options')
if not result: @@ -2320,7 +2324,7 @@ class Controller(BaseController):
return result
- def set_conf(self, param, value): + def set_conf(self, param: str, value: Union[str, Sequence[str]]) -> None: """ Changes the value of a tor configuration option. Our value can be any of the following... @@ -2342,7 +2346,7 @@ class Controller(BaseController):
self.set_options({param: value}, False)
- def reset_conf(self, *params): + def reset_conf(self, *params: str) -> None: """ Reverts one or more parameters to their default values.
@@ -2357,7 +2361,7 @@ class Controller(BaseController):
self.set_options(dict([(entry, None) for entry in params]), True)
- def set_options(self, params, reset = False): + def set_options(self, params: Union[Mapping[str, Union[str, Sequence[str]]], Sequence[Tuple[str, Union[str, Sequence[str]]]]], reset: bool = False) -> None: """ Changes multiple tor configuration options via either a SETCONF or RESETCONF query. Both behave identically unless our value is None, in which @@ -2439,7 +2443,7 @@ class Controller(BaseController): raise stem.ProtocolError('Returned unexpected status code: %s' % response.code)
@with_default() - def get_hidden_service_conf(self, default = UNDEFINED): + def get_hidden_service_conf(self, default: Any = UNDEFINED) -> Dict[str, Any]: """ get_hidden_service_conf(default = UNDEFINED)
@@ -2534,7 +2538,7 @@ class Controller(BaseController): self._set_cache({'hidden_service_conf': service_dir_map}) return service_dir_map
- def set_hidden_service_conf(self, conf): + def set_hidden_service_conf(self, conf: Mapping[str, Any]) -> None: """ Update all the configured hidden services from a dictionary having the same format as @@ -2599,7 +2603,7 @@ class Controller(BaseController):
self.set_options(hidden_service_options)
- def create_hidden_service(self, path, port, target_address = None, target_port = None, auth_type = None, client_names = None): + def create_hidden_service(self, path: str, port: int, target_address: Optional[str] = None, target_port: Optional[int] = None, auth_type: Optional[str] = None, client_names: Optional[Sequence[str]] = None) -> 'stem.control.CreateHiddenServiceOutput': """ Create a new hidden service. If the directory is already present, a new port is added. @@ -2717,7 +2721,7 @@ class Controller(BaseController): config = conf, )
- def remove_hidden_service(self, path, port = None): + def remove_hidden_service(self, path: str, port: Optional[int] = None) -> bool: """ Discontinues a given hidden service.
@@ -2759,7 +2763,7 @@ class Controller(BaseController): return True
@with_default() - def list_ephemeral_hidden_services(self, default = UNDEFINED, our_services = True, detached = False): + def list_ephemeral_hidden_services(self, default: Any = UNDEFINED, our_services: bool = True, detached: bool = False) -> Sequence[str]: """ list_ephemeral_hidden_services(default = UNDEFINED, our_services = True, detached = False)
@@ -2799,7 +2803,7 @@ class Controller(BaseController):
return [r for r in result if r] # drop any empty responses (GETINFO is blank if unset)
- def create_ephemeral_hidden_service(self, ports, key_type = 'NEW', key_content = 'BEST', discard_key = False, detached = False, await_publication = False, timeout = None, basic_auth = None, max_streams = None): + def create_ephemeral_hidden_service(self, ports: Union[int, Sequence[int], Mapping[int, str]], key_type: str = 'NEW', key_content: str = 'BEST', discard_key: bool = False, detached: bool = False, await_publication: bool = False, timeout: Optional[float] = None, basic_auth: Optional[Mapping[str, str]] = None, max_streams: Optional[int] = None) -> stem.response.add_onion.AddOnionResponse: """ Creates a new hidden service. Unlike :func:`~stem.control.Controller.create_hidden_service` this style of @@ -2905,7 +2909,7 @@ class Controller(BaseController): start_time = time.time()
if await_publication: - def hs_desc_listener(event): + def hs_desc_listener(event: stem.response.events.Event) -> None: hs_desc_queue.put(event)
self.add_event_listener(hs_desc_listener, EventType.HS_DESC) @@ -2983,7 +2987,7 @@ class Controller(BaseController):
return response
- def remove_ephemeral_hidden_service(self, service_id): + def remove_ephemeral_hidden_service(self, service_id: str) -> bool: """ Discontinues a given hidden service that was created with :func:`~stem.control.Controller.create_ephemeral_hidden_service`. @@ -3008,7 +3012,7 @@ class Controller(BaseController): else: raise stem.ProtocolError('DEL_ONION returned unexpected response code: %s' % response.code)
- def add_event_listener(self, listener, *events): + def add_event_listener(self, listener: Callable[[stem.response.events.Event], None], *events: 'stem.control.EventType') -> None: """ Directs further tor controller events to a given function. The function is expected to take a single argument, which is a @@ -3066,7 +3070,7 @@ class Controller(BaseController): if failed_events: raise stem.ProtocolError('SETEVENTS rejected %s' % ', '.join(failed_events))
- def remove_event_listener(self, listener): + def remove_event_listener(self, listener: Callable[[stem.response.events.Event], None]) -> None: """ Stops a listener from being notified of further tor events.
@@ -3092,7 +3096,7 @@ class Controller(BaseController): if not response.is_ok(): raise stem.ProtocolError('SETEVENTS received unexpected response\n%s' % response)
- def _get_cache(self, param, namespace = None): + def _get_cache(self, param: str, namespace: Optional[str] = None) -> Any: """ Queries our request cache for the given key.
@@ -3109,7 +3113,7 @@ class Controller(BaseController): cache_key = '%s.%s' % (namespace, param) if namespace else param return self._request_cache.get(cache_key, None)
- def _get_cache_map(self, params, namespace = None): + def _get_cache_map(self, params: Sequence[str], namespace: Optional[str] = None) -> Dict[str, Any]: """ Queries our request cache for multiple entries.
@@ -3131,7 +3135,7 @@ class Controller(BaseController):
return cached_values
- def _set_cache(self, params, namespace = None): + def _set_cache(self, params: Mapping[str, Any], namespace: Optional[str] = None) -> None: """ Sets the given request cache entries. If the new cache value is **None** then it is removed from our cache. @@ -3173,7 +3177,7 @@ class Controller(BaseController): else: self._request_cache[cache_key] = value
- def _confchanged_cache_invalidation(self, params): + def _confchanged_cache_invalidation(self, params: Mapping[str, Any]) -> None: """ Drops dependent portions of the cache when configuration changes.
@@ -3197,7 +3201,7 @@ class Controller(BaseController):
self._set_cache({'exit_policy': None}) # numerous options can change our policy
- def is_caching_enabled(self): + def is_caching_enabled(self) -> bool: """ **True** if caching has been enabled, **False** otherwise.
@@ -3206,7 +3210,7 @@ class Controller(BaseController):
return self._is_caching_enabled
- def set_caching(self, enabled): + def set_caching(self, enabled: bool) -> None: """ Enables or disables caching of information retrieved from tor.
@@ -3218,7 +3222,7 @@ class Controller(BaseController): if not self._is_caching_enabled: self.clear_cache()
- def clear_cache(self): + def clear_cache(self) -> None: """ Drops any cached results. """ @@ -3227,7 +3231,7 @@ class Controller(BaseController): self._request_cache = {} self._last_newnym = 0.0
- def load_conf(self, configtext): + def load_conf(self, configtext: str) -> None: """ Sends the configuration text to Tor and loads it as if it has been read from the torrc. @@ -3247,7 +3251,7 @@ class Controller(BaseController): elif not response.is_ok(): raise stem.ProtocolError('+LOADCONF Received unexpected response\n%s' % str(response))
- def save_conf(self, force = False): + def save_conf(self, force: bool = False) -> None: """ Saves the current configuration options into the active torrc file.
@@ -3273,7 +3277,7 @@ class Controller(BaseController): else: raise stem.ProtocolError('SAVECONF returned unexpected response code')
- def is_feature_enabled(self, feature): + def is_feature_enabled(self, feature: str) -> bool: """ Checks if a control connection feature is enabled. These features can be enabled using :func:`~stem.control.Controller.enable_feature`. @@ -3290,7 +3294,7 @@ class Controller(BaseController):
return feature in self._enabled_features
- def enable_feature(self, features): + def enable_feature(self, features: Union[str, Sequence[str]]) -> None: """ Enables features that are disabled by default to maintain backward compatibility. Once enabled, a feature cannot be disabled and a new @@ -3324,7 +3328,7 @@ class Controller(BaseController): self._enabled_features += [entry.upper() for entry in features]
@with_default() - def get_circuit(self, circuit_id, default = UNDEFINED): + def get_circuit(self, circuit_id: int, default: Any = UNDEFINED) -> stem.response.events.CircuitEvent: """ get_circuit(circuit_id, default = UNDEFINED)
@@ -3349,7 +3353,7 @@ class Controller(BaseController): raise ValueError("Tor currently does not have a circuit with the id of '%s'" % circuit_id)
@with_default() - def get_circuits(self, default = UNDEFINED): + def get_circuits(self, default: Any = UNDEFINED) -> Sequence[stem.response.events.CircuitEvent]: """ get_circuits(default = UNDEFINED)
@@ -3372,7 +3376,7 @@ class Controller(BaseController):
return circuits
- def new_circuit(self, path = None, purpose = 'general', await_build = False, timeout = None): + def new_circuit(self, path: Union[None, str, Sequence[str]] = None, purpose: str = 'general', await_build: bool = False, timeout: Optional[float] = None) -> str: """ Requests a new circuit. If the path isn't provided, one is automatically selected. @@ -3380,7 +3384,7 @@ class Controller(BaseController): .. versionchanged:: 1.7.0 Added the timeout argument.
- :param list,str path: one or more relays to make a circuit through + :param str,list path: one or more relays to make a circuit through :param str purpose: 'general' or 'controller' :param bool await_build: blocks until the circuit is built if **True** :param float timeout: seconds to wait when **await_build** is **True** @@ -3394,7 +3398,7 @@ class Controller(BaseController):
return self.extend_circuit('0', path, purpose, await_build, timeout)
- def extend_circuit(self, circuit_id = '0', path = None, purpose = 'general', await_build = False, timeout = None): + def extend_circuit(self, circuit_id: str = '0', path: Union[None, str, Sequence[str]] = None, purpose: str = 'general', await_build: bool = False, timeout: Optional[float] = None) -> str: """ Either requests the creation of a new circuit or extends an existing one.
@@ -3418,7 +3422,7 @@ class Controller(BaseController): Added the timeout argument.
:param str circuit_id: id of a circuit to be extended - :param list,str path: one or more relays to make a circuit through, this is + :param str,list path: one or more relays to make a circuit through, this is required if the circuit id is non-zero :param str purpose: 'general' or 'controller' :param bool await_build: blocks until the circuit is built if **True** @@ -3442,7 +3446,7 @@ class Controller(BaseController): start_time = time.time()
if await_build: - def circ_listener(event): + def circ_listener(event: stem.response.events.Event) -> None: circ_queue.put(event)
self.add_event_listener(circ_listener, EventType.CIRC) @@ -3489,7 +3493,7 @@ class Controller(BaseController): if circ_listener: self.remove_event_listener(circ_listener)
- def repurpose_circuit(self, circuit_id, purpose): + def repurpose_circuit(self, circuit_id: str, purpose: str) -> None: """ Changes a circuit's purpose. Currently, two purposes are recognized... * general @@ -3510,7 +3514,7 @@ class Controller(BaseController): else: raise stem.ProtocolError('SETCIRCUITPURPOSE returned unexpected response code: %s' % response.code)
- def close_circuit(self, circuit_id, flag = ''): + def close_circuit(self, circuit_id: str, flag: str = '') -> None: """ Closes the specified circuit.
@@ -3518,8 +3522,9 @@ class Controller(BaseController): :param str flag: optional value to modify closing, the only flag available is 'IfUnused' which will not close the circuit unless it is unused
- :raises: :class:`stem.InvalidArguments` if the circuit is unknown - :raises: :class:`stem.InvalidRequest` if not enough information is provided + :raises: + * :class:`stem.InvalidArguments` if the circuit is unknown + * :class:`stem.InvalidRequest` if not enough information is provided """
response = self.msg('CLOSECIRCUIT %s %s' % (circuit_id, flag)) @@ -3534,7 +3539,7 @@ class Controller(BaseController): raise stem.ProtocolError('CLOSECIRCUIT returned unexpected response code: %s' % response.code)
@with_default() - def get_streams(self, default = UNDEFINED): + def get_streams(self, default: Any = UNDEFINED) -> Sequence[stem.response.events.StreamEvent]: """ get_streams(default = UNDEFINED)
@@ -3558,7 +3563,7 @@ class Controller(BaseController):
return streams
- def attach_stream(self, stream_id, circuit_id, exiting_hop = None): + def attach_stream(self, stream_id: str, circuit_id: str, exiting_hop: Optional[int] = None) -> None: """ Attaches a stream to a circuit.
@@ -3593,7 +3598,7 @@ class Controller(BaseController): else: raise stem.ProtocolError('ATTACHSTREAM returned unexpected response code: %s' % response.code)
- def close_stream(self, stream_id, reason = stem.RelayEndReason.MISC, flag = ''): + def close_stream(self, stream_id: str, reason: stem.RelayEndReason = stem.RelayEndReason.MISC, flag: str = '') -> None: """ Closes the specified stream.
@@ -3622,7 +3627,7 @@ class Controller(BaseController): else: raise stem.ProtocolError('CLOSESTREAM returned unexpected response code: %s' % response.code)
- def signal(self, signal): + def signal(self, signal: stem.Signal) -> None: """ Sends a signal to the Tor client.
@@ -3645,7 +3650,7 @@ class Controller(BaseController):
raise stem.ProtocolError('SIGNAL response contained unrecognized status code: %s' % response.code)
- def is_newnym_available(self): + def is_newnym_available(self) -> bool: """ Indicates if tor would currently accept a NEWNYM signal. This can only account for signals sent via this controller. @@ -3661,7 +3666,7 @@ class Controller(BaseController): else: return False
- def get_newnym_wait(self): + def get_newnym_wait(self) -> float: """ Provides the number of seconds until a NEWNYM signal would be respected. This can only account for signals sent via this controller. @@ -3675,7 +3680,7 @@ class Controller(BaseController): return max(0.0, self._last_newnym + 10 - time.time())
@with_default() - def get_effective_rate(self, default = UNDEFINED, burst = False): + def get_effective_rate(self, default: Any = UNDEFINED, burst: bool = False) -> int: """ get_effective_rate(default = UNDEFINED, burst = False)
@@ -3714,7 +3719,7 @@ class Controller(BaseController):
return value
- def map_address(self, mapping): + def map_address(self, mapping: Mapping[str, str]) -> Dict[str, str]: """ Map addresses to replacement addresses. Tor replaces subseqent connections to the original addresses with the replacement addresses. @@ -3726,11 +3731,11 @@ class Controller(BaseController):
:param dict mapping: mapping of original addresses to replacement addresses
+ :returns: **dict** with 'original -> replacement' address mappings + :raises: * :class:`stem.InvalidRequest` if the addresses are malformed * :class:`stem.OperationFailed` if Tor couldn't fulfill the request - - :returns: **dict** with 'original -> replacement' address mappings """
mapaddress_arg = ' '.join(['%s=%s' % (k, v) for (k, v) in list(mapping.items())]) @@ -3739,7 +3744,7 @@ class Controller(BaseController):
return response.entries
- def drop_guards(self): + def drop_guards(self) -> None: """ Drops our present guard nodes and picks a new set.
@@ -3750,7 +3755,7 @@ class Controller(BaseController):
self.msg('DROPGUARDS')
- def _post_authentication(self): + def _post_authentication(self) -> None: super(Controller, self)._post_authentication()
# try to re-attach event listeners to the new instance @@ -3788,7 +3793,7 @@ class Controller(BaseController): else: log.warn('We were unable assert ownership of tor through TAKEOWNERSHIP, despite being configured to be the owning process through __OwningControllerProcess. (%s)' % response)
- def _handle_event(self, event_message): + def _handle_event(self, event_message: stem.response.ControlMessage) -> None: try: stem.response.convert('EVENT', event_message) event_type = event_message.type @@ -3805,7 +3810,7 @@ class Controller(BaseController): except Exception as exc: log.warn('Event listener raised an uncaught exception (%s): %s' % (exc, event_message))
- def _attach_listeners(self): + def _attach_listeners(self) -> Tuple[Sequence[str], Sequence[str]]: """ Attempts to subscribe to the self._event_listeners events from tor. This is a no-op if we're not currently authenticated. @@ -3849,7 +3854,7 @@ class Controller(BaseController): return (set_events, failed_events)
-def _parse_circ_path(path): +def _parse_circ_path(path: str) -> Sequence[Tuple[str, str]]: """ Parses a circuit path as a list of **(fingerprint, nickname)** tuples. Tor circuit paths are defined as being of the form... @@ -3892,7 +3897,7 @@ def _parse_circ_path(path): return []
-def _parse_circ_entry(entry): +def _parse_circ_entry(entry: str) -> Tuple[str, str]: """ Parses a single relay's 'LongName' or 'ServerID'. See the :func:`~stem.control._parse_circ_path` function for more information. @@ -3930,7 +3935,7 @@ def _parse_circ_entry(entry):
@with_default() -def _case_insensitive_lookup(entries, key, default = UNDEFINED): +def _case_insensitive_lookup(entries: Union[Sequence[str], Mapping[str, Any]], key: str, default: Any = UNDEFINED) -> Any: """ Makes a case insensitive lookup within a list or dictionary, providing the first matching entry that we come across. @@ -3957,7 +3962,7 @@ def _case_insensitive_lookup(entries, key, default = UNDEFINED): raise ValueError("key '%s' doesn't exist in dict: %s" % (key, entries))
-def _get_with_timeout(event_queue, timeout, start_time): +def _get_with_timeout(event_queue: queue.Queue, timeout: float, start_time: float) -> Any: """ Pulls an item from a queue with a given timeout. """ diff --git a/stem/descriptor/__init__.py b/stem/descriptor/__init__.py index ff273405..9c769749 100644 --- a/stem/descriptor/__init__.py +++ b/stem/descriptor/__init__.py @@ -120,6 +120,8 @@ import stem.util.enum import stem.util.str_tools import stem.util.system
+from typing import Any, BinaryIO, Callable, Dict, Iterator, Mapping, Optional, Sequence, Tuple, Type + __all__ = [ 'bandwidth_file', 'certificate', @@ -192,7 +194,7 @@ class _Compression(object): .. versionadded:: 1.8.0 """
- def __init__(self, name, module, encoding, extension, decompression_func): + def __init__(self, name: str, module: Optional[str], encoding: str, extension: str, decompression_func: Callable[[Any, bytes], bytes]) -> None: if module is None: self._module = None self.available = True @@ -222,7 +224,7 @@ class _Compression(object): self._module_name = module self._decompression_func = decompression_func
- def decompress(self, content): + def decompress(self, content: bytes) -> bytes: """ Decompresses the given content via this method.
@@ -250,11 +252,11 @@ class _Compression(object): except Exception as exc: raise IOError('Failed to decompress as %s: %s' % (self, exc))
- def __str__(self): + def __str__(self) -> str: return self._name
-def _zstd_decompress(module, content): +def _zstd_decompress(module: Any, content: bytes) -> bytes: output_buffer = io.BytesIO()
with module.ZstdDecompressor().write_to(output_buffer) as decompressor: @@ -286,7 +288,7 @@ class TypeAnnotation(collections.namedtuple('TypeAnnotation', ['name', 'major_ve :var int minor_version: minor version number """
- def __str__(self): + def __str__(self) -> str: return '@type %s %s.%s' % (self.name, self.major_version, self.minor_version)
@@ -302,7 +304,7 @@ class SigningKey(collections.namedtuple('SigningKey', ['private', 'public', 'pub """
-def parse_file(descriptor_file, descriptor_type = None, validate = False, document_handler = DocumentHandler.ENTRIES, normalize_newlines = None, **kwargs): +def parse_file(descriptor_file: BinaryIO, descriptor_type: Optional[str] = None, validate: bool = False, document_handler: 'stem.descriptor.DocumentHandler' = DocumentHandler.ENTRIES, normalize_newlines: Optional[bool] = None, **kwargs: Any) -> Iterator['stem.descriptor.Descriptor']: """ Simple function to read the descriptor contents from a file, providing an iterator for its :class:`~stem.descriptor.__init__.Descriptor` contents. @@ -405,7 +407,7 @@ def parse_file(descriptor_file, descriptor_type = None, validate = False, docume descriptor_path = getattr(descriptor_file, 'name', None) filename = '<undefined>' if descriptor_path is None else os.path.basename(descriptor_file.name)
- def parse(descriptor_file): + def parse(descriptor_file: BinaryIO) -> Iterator['stem.descriptor.Descriptor']: if normalize_newlines: descriptor_file = NewlineNormalizer(descriptor_file)
@@ -448,20 +450,20 @@ def parse_file(descriptor_file, descriptor_type = None, validate = False, docume yield desc
-def _parse_file_for_path(descriptor_file, *args, **kwargs): +def _parse_file_for_path(descriptor_file: str, *args: Any, **kwargs: Any) -> Iterator['stem.descriptor.Descriptor']: with open(descriptor_file, 'rb') as desc_file: for desc in parse_file(desc_file, *args, **kwargs): yield desc
-def _parse_file_for_tar_path(descriptor_file, *args, **kwargs): +def _parse_file_for_tar_path(descriptor_file: str, *args: Any, **kwargs: Any) -> Iterator['stem.descriptor.Descriptor']: with tarfile.open(descriptor_file) as tar_file: for desc in parse_file(tar_file, *args, **kwargs): desc._set_path(os.path.abspath(descriptor_file)) yield desc
-def _parse_file_for_tarfile(descriptor_file, *args, **kwargs): +def _parse_file_for_tarfile(descriptor_file: tarfile.TarFile, *args: Any, **kwargs: Any) -> Iterator['stem.descriptor.Descriptor']: for tar_entry in descriptor_file: if tar_entry.isfile(): entry = descriptor_file.extractfile(tar_entry) @@ -477,7 +479,7 @@ def _parse_file_for_tarfile(descriptor_file, *args, **kwargs): entry.close()
-def _parse_metrics_file(descriptor_type, major_version, minor_version, descriptor_file, validate, document_handler, **kwargs): +def _parse_metrics_file(descriptor_type: Type['stem.descriptor.Descriptor'], major_version: int, minor_version: int, descriptor_file: BinaryIO, validate: bool, document_handler: 'stem.descriptor.DocumentHandler', **kwargs: Any) -> Iterator['stem.descriptor.Descriptor']: # Parses descriptor files from metrics, yielding individual descriptors. This # throws a TypeError if the descriptor_type or version isn't recognized.
@@ -547,7 +549,7 @@ def _parse_metrics_file(descriptor_type, major_version, minor_version, descripto raise TypeError("Unrecognized metrics descriptor format. type: '%s', version: '%i.%i'" % (descriptor_type, major_version, minor_version))
-def _descriptor_content(attr = None, exclude = (), header_template = (), footer_template = ()): +def _descriptor_content(attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), header_template: Sequence[str] = (), footer_template: Sequence[str] = ()) -> bytes: """ Constructs a minimal descriptor with the given attributes. The content we provide back is of the form... @@ -619,28 +621,28 @@ def _descriptor_content(attr = None, exclude = (), header_template = (), footer_ return stem.util.str_tools._to_bytes('\n'.join(header_content + remainder + footer_content))
-def _value(line, entries): +def _value(line: str, entries: Dict[str, Sequence[str]]) -> str: return entries[line][0][0]
-def _values(line, entries): +def _values(line: str, entries: Dict[str, Sequence[str]]) -> Sequence[str]: return [entry[0] for entry in entries[line]]
-def _parse_simple_line(keyword, attribute, func = None): - def _parse(descriptor, entries): +def _parse_simple_line(keyword: str, attribute: str, func: Optional[Callable[[str], Any]] = None) -> Callable[['stem.descriptor.Descriptor', Dict[str, Sequence[str]]], None]: + def _parse(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: value = _value(keyword, entries) setattr(descriptor, attribute, func(value) if func else value)
return _parse
-def _parse_if_present(keyword, attribute): +def _parse_if_present(keyword: str, attribute: str) -> Callable[['stem.descriptor.Descriptor', Dict[str, Sequence[str]]], None]: return lambda descriptor, entries: setattr(descriptor, attribute, keyword in entries)
-def _parse_bytes_line(keyword, attribute): - def _parse(descriptor, entries): +def _parse_bytes_line(keyword: str, attribute: str) -> Callable[['stem.descriptor.Descriptor', Dict[str, Sequence[str]]], None]: + def _parse(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: line_match = re.search(stem.util.str_tools._to_bytes('^(opt )?%s(?:[%s]+(.*))?$' % (keyword, WHITESPACE)), descriptor.get_bytes(), re.MULTILINE) result = None
@@ -653,8 +655,8 @@ def _parse_bytes_line(keyword, attribute): return _parse
-def _parse_int_line(keyword, attribute, allow_negative = True): - def _parse(descriptor, entries): +def _parse_int_line(keyword: str, attribute: str, allow_negative: bool = True) -> Callable[['stem.descriptor.Descriptor', Dict[str, Sequence[str]]], None]: + def _parse(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: value = _value(keyword, entries)
try: @@ -670,10 +672,10 @@ def _parse_int_line(keyword, attribute, allow_negative = True): return _parse
-def _parse_timestamp_line(keyword, attribute): +def _parse_timestamp_line(keyword: str, attribute: str) -> Callable[['stem.descriptor.Descriptor', Dict[str, Sequence[str]]], None]: # "<keyword>" YYYY-MM-DD HH:MM:SS
- def _parse(descriptor, entries): + def _parse(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: value = _value(keyword, entries)
try: @@ -684,10 +686,10 @@ def _parse_timestamp_line(keyword, attribute): return _parse
-def _parse_forty_character_hex(keyword, attribute): +def _parse_forty_character_hex(keyword: str, attribute: str) -> Callable[['stem.descriptor.Descriptor', Dict[str, Sequence[str]]], None]: # format of fingerprints, sha1 digests, etc
- def _parse(descriptor, entries): + def _parse(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: value = _value(keyword, entries)
if not stem.util.tor_tools.is_hex_digits(value, 40): @@ -698,8 +700,8 @@ def _parse_forty_character_hex(keyword, attribute): return _parse
-def _parse_protocol_line(keyword, attribute): - def _parse(descriptor, entries): +def _parse_protocol_line(keyword: str, attribute: str) -> Callable[['stem.descriptor.Descriptor', Dict[str, Sequence[str]]], None]: + def _parse(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: # parses 'protocol' entries like: Cons=1-2 Desc=1-2 DirCache=1 HSDir=1
value = _value(keyword, entries) @@ -729,8 +731,8 @@ def _parse_protocol_line(keyword, attribute): return _parse
-def _parse_key_block(keyword, attribute, expected_block_type, value_attribute = None): - def _parse(descriptor, entries): +def _parse_key_block(keyword: str, attribute: str, expected_block_type: str, value_attribute: Optional[str] = None) -> Callable[['stem.descriptor.Descriptor', Dict[str, Sequence[str]]], None]: + def _parse(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: value, block_type, block_contents = entries[keyword][0]
if not block_contents or block_type != expected_block_type: @@ -744,7 +746,7 @@ def _parse_key_block(keyword, attribute, expected_block_type, value_attribute = return _parse
-def _mappings_for(keyword, value, require_value = False, divider = ' '): +def _mappings_for(keyword: str, value: str, require_value: bool = False, divider: str = ' ') -> Iterator[Tuple[str, str]]: """ Parses an attribute as a series of 'key=value' mappings. Unlike _parse_* functions this is a helper, returning the attribute value rather than setting @@ -777,7 +779,7 @@ def _mappings_for(keyword, value, require_value = False, divider = ' '): yield k, v
-def _copy(default): +def _copy(default: Any) -> Any: if default is None or isinstance(default, (bool, stem.exit_policy.ExitPolicy)): return default # immutable elif default in EMPTY_COLLECTION: @@ -786,7 +788,7 @@ def _copy(default): return copy.copy(default)
-def _encode_digest(hash_value, encoding): +def _encode_digest(hash_value: bytes, encoding: 'stem.descriptor.DigestEncoding') -> str: """ Encodes a hash value with the given HashEncoding. """ diff --git a/stem/descriptor/bandwidth_file.py b/stem/descriptor/bandwidth_file.py index 3cf20595..49df3173 100644 --- a/stem/descriptor/bandwidth_file.py +++ b/stem/descriptor/bandwidth_file.py @@ -21,6 +21,8 @@ import time
import stem.util.str_tools
+from typing import Any, BinaryIO, Dict, Iterator, Mapping, Optional, Sequence, Type + from stem.descriptor import ( _mappings_for, Descriptor, @@ -50,7 +52,7 @@ class RecentStats(object): :var RelayFailures relay_failures: number of relays we failed to measure """
- def __init__(self): + def __init__(self) -> None: self.consensus_count = None self.prioritized_relays = None self.prioritized_relay_lists = None @@ -73,7 +75,7 @@ class RelayFailures(object): by default) """
- def __init__(self): + def __init__(self) -> None: self.no_measurement = None self.insuffient_period = None self.insufficient_measurements = None @@ -83,22 +85,22 @@ class RelayFailures(object): # Converts header attributes to a given type. Malformed fields should be # ignored according to the spec.
-def _str(val): +def _str(val: str) -> str: return val # already a str
-def _int(val): +def _int(val: str) -> Optional[int]: return int(val) if (val and val.isdigit()) else None
-def _date(val): +def _date(val: str) -> Optional[datetime.datetime]: try: return stem.util.str_tools._parse_iso_timestamp(val) except ValueError: return None # not an iso formatted date
-def _csv(val): +def _csv(val: Optional[str]) -> Optional[Sequence[str]]: return list(map(lambda v: v.strip(), val.split(','))) if val is not None else None
@@ -150,7 +152,7 @@ HEADER_DEFAULT = { }
-def _parse_file(descriptor_file, validate = False, **kwargs): +def _parse_file(descriptor_file: BinaryIO, validate: bool = False, **kwargs: Any) -> Iterator['stem.descriptor.bandwidth_file.BandwidthFile']: """ Iterates over the bandwidth authority metrics in a file.
@@ -169,7 +171,7 @@ def _parse_file(descriptor_file, validate = False, **kwargs): yield BandwidthFile(descriptor_file.read(), validate, **kwargs)
-def _parse_header(descriptor, entries): +def _parse_header(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: header = collections.OrderedDict() content = io.BytesIO(descriptor.get_bytes())
@@ -214,7 +216,7 @@ def _parse_header(descriptor, entries): raise ValueError("The 'version' header must be in the second position")
-def _parse_timestamp(descriptor, entries): +def _parse_timestamp(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: first_line = io.BytesIO(descriptor.get_bytes()).readline().strip()
if first_line.isdigit(): @@ -223,7 +225,7 @@ def _parse_timestamp(descriptor, entries): raise ValueError("First line should be a unix timestamp, but was '%s'" % first_line)
-def _parse_body(descriptor, entries): +def _parse_body(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: # In version 1.0.0 the body is everything after the first line. Otherwise # it's everything after the header's divider.
@@ -301,7 +303,7 @@ class BandwidthFile(Descriptor): ATTRIBUTES.update(dict([(k, (None, _parse_header)) for k in HEADER_ATTR.keys()]))
@classmethod - def content(cls, attr = None, exclude = ()): + def content(cls: Type['stem.descriptor.bandwidth_file.BandwidthFile'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> bytes: """ Creates descriptor content with the given attributes. This descriptor type differs somewhat from others and treats our attr/exclude attributes as @@ -352,7 +354,7 @@ class BandwidthFile(Descriptor):
return b'\n'.join(lines)
- def __init__(self, raw_content, validate = False): + def __init__(self, raw_content: bytes, validate: bool = False) -> None: super(BandwidthFile, self).__init__(raw_content, lazy_load = not validate)
if validate: diff --git a/stem/descriptor/certificate.py b/stem/descriptor/certificate.py index 0522b883..6956a60f 100644 --- a/stem/descriptor/certificate.py +++ b/stem/descriptor/certificate.py @@ -64,6 +64,7 @@ import stem.util.enum import stem.util.str_tools
from stem.client.datatype import CertType, Field, Size, split +from typing import Callable, Dict, Optional, Sequence, Tuple, Union
ED25519_KEY_LENGTH = 32 ED25519_HEADER_LENGTH = 40 @@ -88,7 +89,7 @@ class Ed25519Extension(Field): :var bytes data: data the extension concerns """
- def __init__(self, ext_type, flag_val, data): + def __init__(self, ext_type: 'stem.descriptor.certificate.ExtensionType', flag_val: int, data: bytes) -> None: self.type = ext_type self.flags = [] self.flag_int = flag_val if flag_val else 0 @@ -104,7 +105,7 @@ class Ed25519Extension(Field): if ext_type == ExtensionType.HAS_SIGNING_KEY and len(data) != 32: raise ValueError('Ed25519 HAS_SIGNING_KEY extension must be 32 bytes, but was %i.' % len(data))
- def pack(self): + def pack(self) -> bytes: encoded = bytearray() encoded += Size.SHORT.pack(len(self.data)) encoded += Size.CHAR.pack(self.type) @@ -113,7 +114,7 @@ class Ed25519Extension(Field): return bytes(encoded)
@staticmethod - def pop(content): + def pop(content: bytes) -> Tuple['stem.descriptor.certificate.Ed25519Extension', bytes]: if len(content) < 4: raise ValueError('Ed25519 extension is missing header fields')
@@ -127,7 +128,7 @@ class Ed25519Extension(Field):
return Ed25519Extension(ext_type, flags, data), content
- def __hash__(self): + def __hash__(self) -> int: return stem.util._hash_attr(self, 'type', 'flag_int', 'data', cache = True)
@@ -138,11 +139,11 @@ class Ed25519Certificate(object): :var int version: certificate format version """
- def __init__(self, version): + def __init__(self, version: int) -> None: self.version = version
@staticmethod - def unpack(content): + def unpack(content: bytes) -> 'stem.descriptor.certificate.Ed25519Certificate': """ Parses a byte encoded ED25519 certificate.
@@ -162,7 +163,7 @@ class Ed25519Certificate(object): raise ValueError('Ed25519 certificate is version %i. Parser presently only supports version 1.' % version)
@staticmethod - def from_base64(content): + def from_base64(content: str) -> 'stem.descriptor.certificate.Ed25519Certificate': """ Parses a base64 encoded ED25519 certificate.
@@ -189,7 +190,7 @@ class Ed25519Certificate(object): except (TypeError, binascii.Error) as exc: raise ValueError("Ed25519 certificate wasn't propoerly base64 encoded (%s):\n%s" % (exc, content))
- def pack(self): + def pack(self) -> bytes: """ Encoded byte representation of our certificate.
@@ -198,7 +199,7 @@ class Ed25519Certificate(object):
raise NotImplementedError('Certificate encoding has not been implemented for %s' % type(self).__name__)
- def to_base64(self, pem = False): + def to_base64(self, pem: bool = False) -> str: """ Base64 encoded certificate data.
@@ -206,7 +207,7 @@ class Ed25519Certificate(object): https://en.wikipedia.org/wiki/Privacy-Enhanced_Mail`_, for more information see `RFC 7468 https://tools.ietf.org/html/rfc7468`_
- :returns: **unicode** for our encoded certificate representation + :returns: **str** for our encoded certificate representation """
encoded = b'\n'.join(stem.util.str_tools._split_by_length(base64.b64encode(self.pack()), 64)) @@ -217,7 +218,7 @@ class Ed25519Certificate(object): return stem.util.str_tools._to_unicode(encoded)
@staticmethod - def _from_descriptor(keyword, attribute): + def _from_descriptor(keyword: str, attribute: str) -> Callable[['stem.descriptor.Descriptor', Dict[str, Sequence[str]]], None]: def _parse(descriptor, entries): value, block_type, block_contents = entries[keyword][0]
@@ -228,7 +229,7 @@ class Ed25519Certificate(object):
return _parse
- def __str__(self): + def __str__(self) -> str: return self.to_base64(pem = True)
@@ -252,7 +253,7 @@ class Ed25519CertificateV1(Ed25519Certificate): is unavailable """
- def __init__(self, cert_type = None, expiration = None, key_type = None, key = None, extensions = None, signature = None, signing_key = None): + def __init__(self, cert_type: Optional['stem.client.datatype.CertType'] = None, expiration: Optional[datetime.datetime] = None, key_type: Optional[int] = None, key: Optional[bytes] = None, extensions: Optional[Sequence['stem.descriptor.certificate.Ed25519Extension']] = None, signature: Optional[bytes] = None, signing_key: Optional['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey'] = None) -> None: super(Ed25519CertificateV1, self).__init__(1)
if cert_type is None: @@ -284,7 +285,7 @@ class Ed25519CertificateV1(Ed25519Certificate): elif self.type == CertType.UNKNOWN: raise ValueError('Ed25519 certificate type %i is unrecognized' % self.type_int)
- def pack(self): + def pack(self) -> bytes: encoded = bytearray() encoded += Size.CHAR.pack(self.version) encoded += Size.CHAR.pack(self.type_int) @@ -302,7 +303,7 @@ class Ed25519CertificateV1(Ed25519Certificate): return bytes(encoded)
@staticmethod - def unpack(content): + def unpack(content: bytes) -> 'stem.descriptor.certificate.Ed25519CertificateV1': if len(content) < ED25519_HEADER_LENGTH + ED25519_SIGNATURE_LENGTH: raise ValueError('Ed25519 certificate was %i bytes, but should be at least %i' % (len(content), ED25519_HEADER_LENGTH + ED25519_SIGNATURE_LENGTH))
@@ -329,7 +330,7 @@ class Ed25519CertificateV1(Ed25519Certificate):
return Ed25519CertificateV1(cert_type, datetime.datetime.utcfromtimestamp(expiration_hours * 3600), key_type, key, extensions, signature)
- def is_expired(self): + def is_expired(self) -> bool: """ Checks if this certificate is presently expired or not.
@@ -338,7 +339,7 @@ class Ed25519CertificateV1(Ed25519Certificate):
return datetime.datetime.now() > self.expiration
- def signing_key(self): + def signing_key(self) -> Optional[bytes]: """ Provides this certificate's signing key.
@@ -354,7 +355,7 @@ class Ed25519CertificateV1(Ed25519Certificate):
return None
- def validate(self, descriptor): + def validate(self, descriptor: Union['stem.descriptor.server_descriptor.RelayDescriptor', 'stem.descriptor.hidden_service.HiddenServiceDescriptorV3']) -> None: """ Validate our descriptor content matches its ed25519 signature. Supported descriptor types include... @@ -410,7 +411,7 @@ class Ed25519CertificateV1(Ed25519Certificate): raise ValueError('Descriptor Ed25519 certificate signature invalid (signature forged or corrupt)')
@staticmethod - def _signed_content(descriptor): + def _signed_content(descriptor: Union['stem.descriptor.server_descriptor.RelayDescriptor', 'stem.descriptor.hidden_service.HiddenServiceDescriptorV3']) -> bytes: """ Provides this descriptor's signing constant, appended with the portion of the descriptor that's signed. diff --git a/stem/descriptor/collector.py b/stem/descriptor/collector.py index 7aeb298b..1f1b1e95 100644 --- a/stem/descriptor/collector.py +++ b/stem/descriptor/collector.py @@ -63,6 +63,7 @@ import stem.util.connection import stem.util.str_tools
from stem.descriptor import Compression, DocumentHandler +from typing import Any, Dict, Iterator, Optional, Sequence, Tuple, Union
COLLECTOR_URL = 'https://collector.torproject.org/' REFRESH_INDEX_RATE = 3600 # get new index if cached copy is an hour old @@ -76,7 +77,7 @@ SEC_DATE = re.compile('(\d{4}-\d{2}-\d{2}-\d{2}-\d{2}-\d{2})') FUTURE = datetime.datetime(9999, 1, 1)
-def get_instance(): +def get_instance() -> 'stem.descriptor.collector.CollecTor': """ Provides the singleton :class:`~stem.descriptor.collector.CollecTor` used for this module's shorthand functions. @@ -92,7 +93,7 @@ def get_instance(): return SINGLETON_COLLECTOR
-def get_server_descriptors(start = None, end = None, cache_to = None, bridge = False, timeout = None, retries = 3): +def get_server_descriptors(start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, cache_to: Optional[str] = None, bridge: bool = False, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator['stem.descriptor.server_descriptor.RelayDescriptor']: """ Shorthand for :func:`~stem.descriptor.collector.CollecTor.get_server_descriptors` @@ -103,7 +104,7 @@ def get_server_descriptors(start = None, end = None, cache_to = None, bridge = F yield desc
-def get_extrainfo_descriptors(start = None, end = None, cache_to = None, bridge = False, timeout = None, retries = 3): +def get_extrainfo_descriptors(start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, cache_to: Optional[str] = None, bridge: bool = False, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator['stem.descriptor.extrainfo_descriptor.RelayExtraInfoDescriptor']: """ Shorthand for :func:`~stem.descriptor.collector.CollecTor.get_extrainfo_descriptors` @@ -114,7 +115,7 @@ def get_extrainfo_descriptors(start = None, end = None, cache_to = None, bridge yield desc
-def get_microdescriptors(start = None, end = None, cache_to = None, timeout = None, retries = 3): +def get_microdescriptors(start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, cache_to: Optional[str] = None, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator['stem.descriptor.microdescriptor.Microdescriptor']: """ Shorthand for :func:`~stem.descriptor.collector.CollecTor.get_microdescriptors` @@ -125,7 +126,7 @@ def get_microdescriptors(start = None, end = None, cache_to = None, timeout = No yield desc
-def get_consensus(start = None, end = None, cache_to = None, document_handler = DocumentHandler.ENTRIES, version = 3, microdescriptor = False, bridge = False, timeout = None, retries = 3): +def get_consensus(start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, cache_to: Optional[str] = None, document_handler: 'stem.descriptor.DocumentHandler' = DocumentHandler.ENTRIES, version: int = 3, microdescriptor: bool = False, bridge: bool = False, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator['stem.descriptor.router_status_entry.RouterStatusEntry']: """ Shorthand for :func:`~stem.descriptor.collector.CollecTor.get_consensus` @@ -136,7 +137,7 @@ def get_consensus(start = None, end = None, cache_to = None, document_handler = yield desc
-def get_key_certificates(start = None, end = None, cache_to = None, timeout = None, retries = 3): +def get_key_certificates(start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, cache_to: Optional[str] = None, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator['stem.descriptor.networkstatus.KeyCertificate']: """ Shorthand for :func:`~stem.descriptor.collector.CollecTor.get_key_certificates` @@ -147,7 +148,7 @@ def get_key_certificates(start = None, end = None, cache_to = None, timeout = No yield desc
-def get_bandwidth_files(start = None, end = None, cache_to = None, timeout = None, retries = 3): +def get_bandwidth_files(start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, cache_to: Optional[str] = None, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator['stem.descriptor.bandwidth_file.BandwidthFile']: """ Shorthand for :func:`~stem.descriptor.collector.CollecTor.get_bandwidth_files` @@ -158,7 +159,7 @@ def get_bandwidth_files(start = None, end = None, cache_to = None, timeout = Non yield desc
-def get_exit_lists(start = None, end = None, cache_to = None, timeout = None, retries = 3): +def get_exit_lists(start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, cache_to: Optional[str] = None, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator['stem.descriptor.tordnsel.TorDNSEL']: """ Shorthand for :func:`~stem.descriptor.collector.CollecTor.get_exit_lists` @@ -187,7 +188,7 @@ class File(object): :var datetime last_modified: when the file was last modified """
- def __init__(self, path, types, size, sha256, first_published, last_published, last_modified): + def __init__(self, path: str, types: Sequence[str], size: int, sha256: str, first_published: datetime.datetime, last_published: datetime.datetime, last_modified: datetime.datetime) -> None: self.path = path self.types = tuple(types) if types else () self.compression = File._guess_compression(path) @@ -205,7 +206,7 @@ class File(object): else: self.start, self.end = File._guess_time_range(path)
- def read(self, directory = None, descriptor_type = None, start = None, end = None, document_handler = DocumentHandler.ENTRIES, timeout = None, retries = 3): + def read(self, directory: Optional[str] = None, descriptor_type: Optional[str] = None, start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, document_handler: 'stem.descriptor.DocumentHandler' = DocumentHandler.ENTRIES, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator['stem.descriptor.Descriptor']: """ Provides descriptors from this archive. Descriptors are downloaded or read from disk as follows... @@ -289,7 +290,7 @@ class File(object):
yield desc
- def download(self, directory, decompress = True, timeout = None, retries = 3, overwrite = False): + def download(self, directory: str, decompress: bool = True, timeout: Optional[int] = None, retries: Optional[int] = 3, overwrite: bool = False) -> str: """ Downloads this file to the given location. If a file already exists this is a no-op. @@ -345,7 +346,7 @@ class File(object): return path
@staticmethod - def _guess_compression(path): + def _guess_compression(path: str) -> 'stem.descriptor.Compression': """ Determine file comprssion from CollecTor's filename. """ @@ -357,7 +358,7 @@ class File(object): return Compression.PLAINTEXT
@staticmethod - def _guess_time_range(path): + def _guess_time_range(path: str) -> Tuple[Optional[datetime.datetime], Optional[datetime.datetime]]: """ Attemt to determine the (start, end) time range from CollecTor's filename. This provides (None, None) if this cannot be determined. @@ -398,7 +399,7 @@ class CollecTor(object): :var float timeout: duration before we'll time out our request """
- def __init__(self, retries = 2, timeout = None): + def __init__(self, retries: Optional[int] = 2, timeout: Optional[int] = None) -> None: self.retries = retries self.timeout = timeout
@@ -406,7 +407,7 @@ class CollecTor(object): self._cached_files = None self._cached_index_at = 0
- def get_server_descriptors(self, start = None, end = None, cache_to = None, bridge = False, timeout = None, retries = 3): + def get_server_descriptors(self, start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, cache_to: Optional[str] = None, bridge: bool = False, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator['stem.descriptor.server_descriptor.RelayDescriptor']: """ Provides server descriptors published during the given time range, sorted oldest to newest. @@ -433,7 +434,7 @@ class CollecTor(object): for desc in f.read(cache_to, desc_type, start, end, timeout = timeout, retries = retries): yield desc
- def get_extrainfo_descriptors(self, start = None, end = None, cache_to = None, bridge = False, timeout = None, retries = 3): + def get_extrainfo_descriptors(self, start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, cache_to: Optional[str] = None, bridge: bool = False, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator['stem.descriptor.extrainfo_descriptor.RelayExtraInfoDescriptor']: """ Provides extrainfo descriptors published during the given time range, sorted oldest to newest. @@ -460,7 +461,7 @@ class CollecTor(object): for desc in f.read(cache_to, desc_type, start, end, timeout = timeout, retries = retries): yield desc
- def get_microdescriptors(self, start = None, end = None, cache_to = None, timeout = None, retries = 3): + def get_microdescriptors(self, start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, cache_to: Optional[str] = None, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator['stem.descriptor.microdescriptor.Microdescriptor']: """ Provides microdescriptors estimated to be published during the given time range, sorted oldest to newest. Unlike server/extrainfo descriptors, @@ -494,7 +495,7 @@ class CollecTor(object): for desc in f.read(cache_to, 'microdescriptor', start, end, timeout = timeout, retries = retries): yield desc
- def get_consensus(self, start = None, end = None, cache_to = None, document_handler = DocumentHandler.ENTRIES, version = 3, microdescriptor = False, bridge = False, timeout = None, retries = 3): + def get_consensus(self, start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, cache_to: Optional[str] = None, document_handler: 'stem.descriptor.DocumentHandler' = DocumentHandler.ENTRIES, version: int = 3, microdescriptor: bool = False, bridge: bool = False, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator['stem.descriptor.router_status_entry.RouterStatusEntry']: """ Provides consensus router status entries published during the given time range, sorted oldest to newest. @@ -538,7 +539,7 @@ class CollecTor(object): for desc in f.read(cache_to, desc_type, start, end, document_handler, timeout = timeout, retries = retries): yield desc
- def get_key_certificates(self, start = None, end = None, cache_to = None, timeout = None, retries = 3): + def get_key_certificates(self, start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, cache_to: Optional[str] = None, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator['stem.descriptor.networkstatus.KeyCertificate']: """ Directory authority key certificates for the given time range, sorted oldest to newest. @@ -562,7 +563,7 @@ class CollecTor(object): for desc in f.read(cache_to, 'dir-key-certificate-3', start, end, timeout = timeout, retries = retries): yield desc
- def get_bandwidth_files(self, start = None, end = None, cache_to = None, timeout = None, retries = 3): + def get_bandwidth_files(self, start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, cache_to: Optional[str] = None, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator['stem.descriptor.bandwidth_file.BandwidthFile']: """ Bandwidth authority heuristics for the given time range, sorted oldest to newest. @@ -586,7 +587,7 @@ class CollecTor(object): for desc in f.read(cache_to, 'bandwidth-file', start, end, timeout = timeout, retries = retries): yield desc
- def get_exit_lists(self, start = None, end = None, cache_to = None, timeout = None, retries = 3): + def get_exit_lists(self, start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None, cache_to: Optional[str] = None, timeout: Optional[int] = None, retries: Optional[int] = 3) -> Iterator['stem.descriptor.tordnsel.TorDNSEL']: """ `TorDNSEL exit lists https://www.torproject.org/projects/tordnsel.html.en`_ for the given time range, sorted oldest to newest. @@ -610,7 +611,7 @@ class CollecTor(object): for desc in f.read(cache_to, 'tordnsel', start, end, timeout = timeout, retries = retries): yield desc
- def index(self, compression = 'best'): + def index(self, compression: Union[str, 'stem.descriptor.Compression'] = 'best') -> Dict[str, Any]: """ Provides the archives available in CollecTor.
@@ -645,7 +646,7 @@ class CollecTor(object):
return self._cached_index
- def files(self, descriptor_type = None, start = None, end = None): + def files(self, descriptor_type: Optional[str] = None, start: Optional[datetime.datetime] = None, end: Optional[datetime.datetime] = None) -> Sequence['stem.descriptor.collector.File']: """ Provides files CollecTor presently has, sorted oldest to newest.
@@ -680,7 +681,7 @@ class CollecTor(object): return matches
@staticmethod - def _files(val, path): + def _files(val: str, path: Sequence[str]) -> Sequence['stem.descriptor.collector.File']: """ Recursively provies files within the index.
diff --git a/stem/descriptor/extrainfo_descriptor.py b/stem/descriptor/extrainfo_descriptor.py index d92bb770..6aca3c29 100644 --- a/stem/descriptor/extrainfo_descriptor.py +++ b/stem/descriptor/extrainfo_descriptor.py @@ -67,6 +67,7 @@ Extra-info descriptors are available from a few sources... ===================== =========== """
+import datetime import functools import hashlib import re @@ -75,6 +76,8 @@ import stem.util.connection import stem.util.enum import stem.util.str_tools
+from typing import Any, BinaryIO, Dict, Iterator, Mapping, Optional, Sequence, Tuple, Type, Union + from stem.descriptor import ( PGP_BLOCK_END, Descriptor, @@ -163,7 +166,7 @@ _timestamp_re = re.compile('^(.*) \(([0-9]+) s\)( .*)?$') _locale_re = re.compile('^[a-zA-Z0-9\?]{2}$')
-def _parse_file(descriptor_file, is_bridge = False, validate = False, **kwargs): +def _parse_file(descriptor_file: BinaryIO, is_bridge: bool = False, validate: bool = False, **kwargs: Any) -> Iterator['stem.descriptor.extrainfo_descriptor.ExtraInfoDescriptor']: """ Iterates over the extra-info descriptors in a file.
@@ -204,7 +207,7 @@ def _parse_file(descriptor_file, is_bridge = False, validate = False, **kwargs): break # done parsing file
-def _parse_timestamp_and_interval(keyword, content): +def _parse_timestamp_and_interval(keyword: str, content: str) -> Tuple[datetime.datetime, int, str]: """ Parses a 'YYYY-MM-DD HH:MM:SS (NSEC s) *' entry.
@@ -238,7 +241,7 @@ def _parse_timestamp_and_interval(keyword, content): raise ValueError("%s line's timestamp wasn't parsable: %s" % (keyword, line))
-def _parse_extra_info_line(descriptor, entries): +def _parse_extra_info_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: # "extra-info" Nickname Fingerprint
value = _value('extra-info', entries) @@ -255,7 +258,7 @@ def _parse_extra_info_line(descriptor, entries): descriptor.fingerprint = extra_info_comp[1]
-def _parse_transport_line(descriptor, entries): +def _parse_transport_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: # "transport" transportname address:port [arglist] # Everything after the transportname is scrubbed in published bridge # descriptors, so we'll never see it in practice. @@ -301,7 +304,7 @@ def _parse_transport_line(descriptor, entries): descriptor.transport = transports
-def _parse_padding_counts_line(descriptor, entries): +def _parse_padding_counts_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: # "padding-counts" YYYY-MM-DD HH:MM:SS (NSEC s) key=val key=val...
value = _value('padding-counts', entries) @@ -316,7 +319,7 @@ def _parse_padding_counts_line(descriptor, entries): setattr(descriptor, 'padding_counts', counts)
-def _parse_dirreq_line(keyword, recognized_counts_attr, unrecognized_counts_attr, descriptor, entries): +def _parse_dirreq_line(keyword: str, recognized_counts_attr: str, unrecognized_counts_attr: str, descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: value = _value(keyword, entries)
recognized_counts = {} @@ -340,7 +343,7 @@ def _parse_dirreq_line(keyword, recognized_counts_attr, unrecognized_counts_attr setattr(descriptor, unrecognized_counts_attr, unrecognized_counts)
-def _parse_dirreq_share_line(keyword, attribute, descriptor, entries): +def _parse_dirreq_share_line(keyword: str, attribute: str, descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: value = _value(keyword, entries)
if not value.endswith('%'): @@ -353,7 +356,7 @@ def _parse_dirreq_share_line(keyword, attribute, descriptor, entries): setattr(descriptor, attribute, float(value[:-1]) / 100)
-def _parse_cell_line(keyword, attribute, descriptor, entries): +def _parse_cell_line(keyword: str, attribute: str, descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: # "<keyword>" num,...,num
value = _value(keyword, entries) @@ -375,7 +378,7 @@ def _parse_cell_line(keyword, attribute, descriptor, entries): raise exc
-def _parse_timestamp_and_interval_line(keyword, end_attribute, interval_attribute, descriptor, entries): +def _parse_timestamp_and_interval_line(keyword: str, end_attribute: str, interval_attribute: str, descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: # "<keyword>" YYYY-MM-DD HH:MM:SS (NSEC s)
timestamp, interval, _ = _parse_timestamp_and_interval(keyword, _value(keyword, entries)) @@ -383,7 +386,7 @@ def _parse_timestamp_and_interval_line(keyword, end_attribute, interval_attribut setattr(descriptor, interval_attribute, interval)
-def _parse_conn_bi_direct_line(descriptor, entries): +def _parse_conn_bi_direct_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: # "conn-bi-direct" YYYY-MM-DD HH:MM:SS (NSEC s) BELOW,READ,WRITE,BOTH
value = _value('conn-bi-direct', entries) @@ -401,7 +404,7 @@ def _parse_conn_bi_direct_line(descriptor, entries): descriptor.conn_bi_direct_both = int(stats[3])
-def _parse_history_line(keyword, end_attribute, interval_attribute, values_attribute, descriptor, entries): +def _parse_history_line(keyword: str, end_attribute: str, interval_attribute: str, values_attribute: str, descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: # "<keyword>" YYYY-MM-DD HH:MM:SS (NSEC s) NUM,NUM,NUM,NUM,NUM...
value = _value(keyword, entries) @@ -419,7 +422,7 @@ def _parse_history_line(keyword, end_attribute, interval_attribute, values_attri setattr(descriptor, values_attribute, history_values)
-def _parse_port_count_line(keyword, attribute, descriptor, entries): +def _parse_port_count_line(keyword: str, attribute: str, descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: # "<keyword>" port=N,port=N,...
value, port_mappings = _value(keyword, entries), {} @@ -434,7 +437,7 @@ def _parse_port_count_line(keyword, attribute, descriptor, entries): setattr(descriptor, attribute, port_mappings)
-def _parse_geoip_to_count_line(keyword, attribute, descriptor, entries): +def _parse_geoip_to_count_line(keyword: str, attribute: str, descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: # "<keyword>" CC=N,CC=N,... # # The maxmind geoip (https://www.maxmind.com/app/iso3166) has numeric @@ -454,7 +457,7 @@ def _parse_geoip_to_count_line(keyword, attribute, descriptor, entries): setattr(descriptor, attribute, locale_usage)
-def _parse_bridge_ip_versions_line(descriptor, entries): +def _parse_bridge_ip_versions_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: value, ip_versions = _value('bridge-ip-versions', entries), {}
for protocol, count in _mappings_for('bridge-ip-versions', value, divider = ','): @@ -466,7 +469,7 @@ def _parse_bridge_ip_versions_line(descriptor, entries): descriptor.ip_versions = ip_versions
-def _parse_bridge_ip_transports_line(descriptor, entries): +def _parse_bridge_ip_transports_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: value, ip_transports = _value('bridge-ip-transports', entries), {}
for protocol, count in _mappings_for('bridge-ip-transports', value, divider = ','): @@ -478,7 +481,7 @@ def _parse_bridge_ip_transports_line(descriptor, entries): descriptor.ip_transports = ip_transports
-def _parse_hs_stats(keyword, stat_attribute, extra_attribute, descriptor, entries): +def _parse_hs_stats(keyword: str, stat_attribute: str, extra_attribute: str, descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: # "<keyword>" num key=val key=val...
value, stat, extra = _value(keyword, entries), None, {} @@ -814,7 +817,7 @@ class ExtraInfoDescriptor(Descriptor): 'bridge-ip-transports': _parse_bridge_ip_transports_line, }
- def __init__(self, raw_contents, validate = False): + def __init__(self, raw_contents: str, validate: bool = False) -> None: """ Extra-info descriptor constructor. By default this validates the descriptor's content as it's parsed. This validation can be disabled to @@ -851,7 +854,7 @@ class ExtraInfoDescriptor(Descriptor): else: self._entries = entries
- def digest(self, hash_type = DigestHash.SHA1, encoding = DigestEncoding.HEX): + def digest(self, hash_type: 'stem.descriptor.DigestHash' = DigestHash.SHA1, encoding: 'stem.descriptor.DigestEncoding' = DigestEncoding.HEX) -> Union[str, 'hashlib.HASH']: """ Digest of this descriptor's content. These are referenced by...
@@ -876,13 +879,13 @@ class ExtraInfoDescriptor(Descriptor):
raise NotImplementedError('Unsupported Operation: this should be implemented by the ExtraInfoDescriptor subclass')
- def _required_fields(self): + def _required_fields(self) -> Tuple[str, ...]: return REQUIRED_FIELDS
- def _first_keyword(self): + def _first_keyword(self) -> str: return 'extra-info'
- def _last_keyword(self): + def _last_keyword(self) -> str: return 'router-signature'
@@ -917,7 +920,7 @@ class RelayExtraInfoDescriptor(ExtraInfoDescriptor): })
@classmethod - def content(cls, attr = None, exclude = (), sign = False, signing_key = None): + def content(cls: Type['stem.descriptor.extrainfo.RelayExtraInfoDescriptor'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), sign: bool = False, signing_key: Optional['stem.descriptor.SigningKey'] = None) -> str: base_header = ( ('extra-info', '%s %s' % (_random_nickname(), _random_fingerprint())), ('published', _random_date()), @@ -938,11 +941,11 @@ class RelayExtraInfoDescriptor(ExtraInfoDescriptor): ))
@classmethod - def create(cls, attr = None, exclude = (), validate = True, sign = False, signing_key = None): + def create(cls: Type['stem.descriptor.extrainfo.RelayExtraInfoDescriptor'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), validate: bool = True, sign: bool = False, signing_key: Optional['stem.descriptor.SigningKey'] = None) -> 'stem.descriptor.extrainfo.RelayExtraInfoDescriptor': return cls(cls.content(attr, exclude, sign, signing_key), validate = validate)
@functools.lru_cache() - def digest(self, hash_type = DigestHash.SHA1, encoding = DigestEncoding.HEX): + def digest(self, hash_type: 'stem.descriptor.DigestHash' = DigestHash.SHA1, encoding: 'stem.descriptor.DigestEncoding' = DigestEncoding.HEX) -> Union[str, 'hashlib.HASH']: if hash_type == DigestHash.SHA1: # our digest is calculated from everything except our signature
@@ -986,7 +989,7 @@ class BridgeExtraInfoDescriptor(ExtraInfoDescriptor): })
@classmethod - def content(cls, attr = None, exclude = ()): + def content(cls: Type['stem.descriptor.extrainfo.BridgeExtraInfoDescriptor'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> str: return _descriptor_content(attr, exclude, ( ('extra-info', 'ec2bridgereaac65a3 %s' % _random_fingerprint()), ('published', _random_date()), @@ -994,7 +997,7 @@ class BridgeExtraInfoDescriptor(ExtraInfoDescriptor): ('router-digest', _random_fingerprint()), ))
- def digest(self, hash_type = DigestHash.SHA1, encoding = DigestEncoding.HEX): + def digest(self, hash_type: 'stem.descriptor.DigestHash' = DigestHash.SHA1, encoding: 'stem.descriptor.DigestEncoding' = DigestEncoding.HEX) -> Union[str, 'hashlib.HASH']: if hash_type == DigestHash.SHA1 and encoding == DigestEncoding.HEX: return self._digest elif hash_type == DigestHash.SHA256 and encoding == DigestEncoding.BASE64: @@ -1002,7 +1005,7 @@ class BridgeExtraInfoDescriptor(ExtraInfoDescriptor): else: raise NotImplementedError('Bridge extrainfo digests are only available as sha1/hex and sha256/base64, not %s/%s' % (hash_type, encoding))
- def _required_fields(self): + def _required_fields(self) -> Tuple[str]: excluded_fields = [ 'router-signature', ] @@ -1013,5 +1016,5 @@ class BridgeExtraInfoDescriptor(ExtraInfoDescriptor):
return tuple(included_fields + [f for f in REQUIRED_FIELDS if f not in excluded_fields])
- def _last_keyword(self): + def _last_keyword(self) -> str: return None diff --git a/stem/descriptor/hidden_service.py b/stem/descriptor/hidden_service.py index 75a78d2e..8d23838e 100644 --- a/stem/descriptor/hidden_service.py +++ b/stem/descriptor/hidden_service.py @@ -51,6 +51,7 @@ import stem.util.tor_tools
from stem.client.datatype import CertType from stem.descriptor.certificate import ExtensionType, Ed25519Extension, Ed25519Certificate, Ed25519CertificateV1 +from typing import Any, BinaryIO, Callable, Dict, Iterator, Mapping, Optional, Sequence, Tuple, Type, Union
from stem.descriptor import ( PGP_BLOCK_END, @@ -162,7 +163,7 @@ class IntroductionPointV3(collections.namedtuple('IntroductionPointV3', ['link_s """
@staticmethod - def parse(content): + def parse(content: str) -> 'stem.descriptor.hidden_service.IntroductionPointV3': """ Parses an introduction point from its descriptor content.
@@ -200,7 +201,7 @@ class IntroductionPointV3(collections.namedtuple('IntroductionPointV3', ['link_s return IntroductionPointV3(link_specifiers, onion_key, auth_key_cert, enc_key, enc_key_cert, legacy_key, legacy_key_cert)
@staticmethod - def create_for_address(address, port, expiration = None, onion_key = None, enc_key = None, auth_key = None, signing_key = None): + def create_for_address(address: str, port: int, expiration: Optional[datetime.datetime] = None, onion_key: Optional[str] = None, enc_key: Optional[str] = None, auth_key: Optional[str] = None, signing_key: Optional['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey'] = None) -> 'stem.descriptor.hidden_service.IntroductionPointV3': """ Simplified constructor for a single address/port link specifier.
@@ -232,7 +233,7 @@ class IntroductionPointV3(collections.namedtuple('IntroductionPointV3', ['link_s return IntroductionPointV3.create_for_link_specifiers(link_specifiers, expiration = None, onion_key = None, enc_key = None, auth_key = None, signing_key = None)
@staticmethod - def create_for_link_specifiers(link_specifiers, expiration = None, onion_key = None, enc_key = None, auth_key = None, signing_key = None): + def create_for_link_specifiers(link_specifiers: Sequence['stem.client.datatype.LinkSpecifier'], expiration: Optional[datetime.datetime] = None, onion_key: Optional[str] = None, enc_key: Optional[str] = None, auth_key: Optional[str] = None, signing_key: Optional['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey'] = None) -> 'stem.descriptor.hidden_service.IntroductionPointV3': """ Simplified constructor. For more sophisticated use cases you can use this as a template for how introduction points are properly created. @@ -271,7 +272,7 @@ class IntroductionPointV3(collections.namedtuple('IntroductionPointV3', ['link_s
return IntroductionPointV3(link_specifiers, onion_key, auth_key_cert, enc_key, enc_key_cert, None, None)
- def encode(self): + def encode(self) -> str: """ Descriptor representation of this introduction point.
@@ -299,7 +300,7 @@ class IntroductionPointV3(collections.namedtuple('IntroductionPointV3', ['link_s
return '\n'.join(lines)
- def onion_key(self): + def onion_key(self) -> 'cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey': """ Provides our ntor introduction point public key.
@@ -312,7 +313,7 @@ class IntroductionPointV3(collections.namedtuple('IntroductionPointV3', ['link_s
return IntroductionPointV3._key_as(self.onion_key_raw, x25519 = True)
- def auth_key(self): + def auth_key(self) -> 'cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PublicKey': """ Provides our authentication certificate's public key.
@@ -325,7 +326,7 @@ class IntroductionPointV3(collections.namedtuple('IntroductionPointV3', ['link_s
return IntroductionPointV3._key_as(self.auth_key_cert.key, ed25519 = True)
- def enc_key(self): + def enc_key(self) -> 'cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey': """ Provides our encryption key.
@@ -338,7 +339,7 @@ class IntroductionPointV3(collections.namedtuple('IntroductionPointV3', ['link_s
return IntroductionPointV3._key_as(self.enc_key_raw, x25519 = True)
- def legacy_key(self): + def legacy_key(self) -> 'cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey': """ Provides our legacy introduction point public key.
@@ -352,7 +353,7 @@ class IntroductionPointV3(collections.namedtuple('IntroductionPointV3', ['link_s return IntroductionPointV3._key_as(self.legacy_key_raw, x25519 = True)
@staticmethod - def _key_as(value, x25519 = False, ed25519 = False): + def _key_as(value: str, x25519: bool = False, ed25519: bool = False) -> Union['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PublicKey', 'cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey']: if value is None or (not x25519 and not ed25519): return value
@@ -375,7 +376,7 @@ class IntroductionPointV3(collections.namedtuple('IntroductionPointV3', ['link_s return Ed25519PublicKey.from_public_bytes(value)
@staticmethod - def _parse_link_specifiers(content): + def _parse_link_specifiers(content: str) -> 'stem.client.datatype.LinkSpecifier': try: content = base64.b64decode(content) except Exception as exc: @@ -393,16 +394,16 @@ class IntroductionPointV3(collections.namedtuple('IntroductionPointV3', ['link_s
return link_specifiers
- def __hash__(self): + def __hash__(self) -> int: if not hasattr(self, '_hash'): self._hash = hash(self.encode())
return self._hash
- def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return hash(self) == hash(other) if isinstance(other, IntroductionPointV3) else False
- def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self == other
@@ -417,22 +418,22 @@ class AuthorizedClient(object): :var str cookie: base64 encoded authentication cookie """
- def __init__(self, id = None, iv = None, cookie = None): + def __init__(self, id: str = None, iv: str = None, cookie: str = None) -> None: self.id = stem.util.str_tools._to_unicode(id if id else base64.b64encode(os.urandom(8)).rstrip(b'=')) self.iv = stem.util.str_tools._to_unicode(iv if iv else base64.b64encode(os.urandom(16)).rstrip(b'=')) self.cookie = stem.util.str_tools._to_unicode(cookie if cookie else base64.b64encode(os.urandom(16)).rstrip(b'='))
- def __hash__(self): + def __hash__(self) -> int: return stem.util._hash_attr(self, 'id', 'iv', 'cookie', cache = True)
- def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return hash(self) == hash(other) if isinstance(other, AuthorizedClient) else False
- def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self == other
-def _parse_file(descriptor_file, desc_type = None, validate = False, **kwargs): +def _parse_file(descriptor_file: BinaryIO, desc_type: str = None, validate: bool = False, **kwargs: Any) -> Iterator['stem.descriptor.hidden_service.HiddenServiceDescriptor']: """ Iterates over the hidden service descriptors in a file.
@@ -442,7 +443,7 @@ def _parse_file(descriptor_file, desc_type = None, validate = False, **kwargs): **True**, skips these checks otherwise :param dict kwargs: additional arguments for the descriptor constructor
- :returns: iterator for :class:`~stem.descriptor.hidden_service.HiddenServiceDescriptorV2` + :returns: iterator for :class:`~stem.descriptor.hidden_service.HiddenServiceDescriptor` instances in the file
:raises: @@ -472,7 +473,7 @@ def _parse_file(descriptor_file, desc_type = None, validate = False, **kwargs): break # done parsing file
-def _decrypt_layer(encrypted_block, constant, revision_counter, subcredential, blinded_key): +def _decrypt_layer(encrypted_block: bytes, constant: bytes, revision_counter: int, subcredential: bytes, blinded_key: bytes) -> str: if encrypted_block.startswith('-----BEGIN MESSAGE-----\n') and encrypted_block.endswith('\n-----END MESSAGE-----'): encrypted_block = encrypted_block[24:-22]
@@ -499,7 +500,7 @@ def _decrypt_layer(encrypted_block, constant, revision_counter, subcredential, b return stem.util.str_tools._to_unicode(plaintext)
-def _encrypt_layer(plaintext, constant, revision_counter, subcredential, blinded_key): +def _encrypt_layer(plaintext: str, constant: bytes, revision_counter: int, subcredential: bytes, blinded_key: bytes) -> bytes: salt = os.urandom(16) cipher, mac_for = _layer_cipher(constant, revision_counter, subcredential, blinded_key, salt)
@@ -510,7 +511,7 @@ def _encrypt_layer(plaintext, constant, revision_counter, subcredential, blinded return b'-----BEGIN MESSAGE-----\n%s\n-----END MESSAGE-----' % b'\n'.join(stem.util.str_tools._split_by_length(encoded, 64))
-def _layer_cipher(constant, revision_counter, subcredential, blinded_key, salt): +def _layer_cipher(constant: bytes, revision_counter: int, subcredential: bytes, blinded_key: bytes, salt: bytes) -> Tuple['cryptography.hazmat.primitives.ciphers.Cipher', Callable[[bytes], bytes]]: try: from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from cryptography.hazmat.backends import default_backend @@ -530,7 +531,7 @@ def _layer_cipher(constant, revision_counter, subcredential, blinded_key, salt): return cipher, lambda ciphertext: hashlib.sha3_256(mac_prefix + ciphertext).digest()
-def _parse_protocol_versions_line(descriptor, entries): +def _parse_protocol_versions_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: value = _value('protocol-versions', entries)
try: @@ -545,7 +546,7 @@ def _parse_protocol_versions_line(descriptor, entries): descriptor.protocol_versions = versions
-def _parse_introduction_points_line(descriptor, entries): +def _parse_introduction_points_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: _, block_type, block_contents = entries['introduction-points'][0]
if not block_contents or block_type != 'MESSAGE': @@ -559,7 +560,7 @@ def _parse_introduction_points_line(descriptor, entries): raise ValueError("'introduction-points' isn't base64 encoded content:\n%s" % block_contents)
-def _parse_v3_outer_clients(descriptor, entries): +def _parse_v3_outer_clients(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: # "auth-client" client-id iv encrypted-cookie
clients = {} @@ -575,7 +576,7 @@ def _parse_v3_outer_clients(descriptor, entries): descriptor.clients = clients
-def _parse_v3_inner_formats(descriptor, entries): +def _parse_v3_inner_formats(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: value, formats = _value('create2-formats', entries), []
for entry in value.split(' '): @@ -587,7 +588,7 @@ def _parse_v3_inner_formats(descriptor, entries): descriptor.formats = formats
-def _parse_v3_introduction_points(descriptor, entries): +def _parse_v3_introduction_points(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: if hasattr(descriptor, '_unparsed_introduction_points'): introduction_points = [] remaining = descriptor._unparsed_introduction_points @@ -687,7 +688,7 @@ class HiddenServiceDescriptorV2(HiddenServiceDescriptor): }
@classmethod - def content(cls, attr = None, exclude = ()): + def content(cls: Type['stem.descriptor.hidden_service.HiddenServiceDescriptorV2'], attr: Mapping[str, str] = None, exclude: Sequence[str] = ()) -> str: return _descriptor_content(attr, exclude, ( ('rendezvous-service-descriptor', 'y3olqqblqw2gbh6phimfuiroechjjafa'), ('version', '2'), @@ -701,10 +702,10 @@ class HiddenServiceDescriptorV2(HiddenServiceDescriptor): ))
@classmethod - def create(cls, attr = None, exclude = (), validate = True): + def create(cls: Type['stem.descriptor.hidden_service.HiddenServiceDescriptorV2'], attr: Mapping[str, str] = None, exclude: Sequence[str] = (), validate: bool = True) -> 'stem.descriptor.hidden_service.HiddenServiceDescriptorV2': return cls(cls.content(attr, exclude), validate = validate, skip_crypto_validation = True)
- def __init__(self, raw_contents, validate = False, skip_crypto_validation = False): + def __init__(self, raw_contents: str, validate: bool = False, skip_crypto_validation: bool = False) -> None: super(HiddenServiceDescriptorV2, self).__init__(raw_contents, lazy_load = not validate) entries = _descriptor_components(raw_contents, validate, non_ascii_fields = ('introduction-points'))
@@ -736,10 +737,12 @@ class HiddenServiceDescriptorV2(HiddenServiceDescriptor): self._entries = entries
@functools.lru_cache() - def introduction_points(self, authentication_cookie = None): + def introduction_points(self, authentication_cookie: Optional[str] = None) -> Sequence['stem.descriptor.hidden_service.IntroductionPointV2']: """ Provided this service's introduction points.
+ :param str authentication_cookie: base64 encoded authentication cookie + :returns: **list** of :class:`~stem.descriptor.hidden_service.IntroductionPointV2`
:raises: @@ -774,7 +777,7 @@ class HiddenServiceDescriptorV2(HiddenServiceDescriptor): return HiddenServiceDescriptorV2._parse_introduction_points(content)
@staticmethod - def _decrypt_basic_auth(content, authentication_cookie): + def _decrypt_basic_auth(content: bytes, authentication_cookie: str) -> bytes: try: from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from cryptography.hazmat.backends import default_backend @@ -821,7 +824,7 @@ class HiddenServiceDescriptorV2(HiddenServiceDescriptor): return content # nope, unable to decrypt the content
@staticmethod - def _decrypt_stealth_auth(content, authentication_cookie): + def _decrypt_stealth_auth(content: bytes, authentication_cookie: str) -> bytes: try: from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from cryptography.hazmat.backends import default_backend @@ -836,7 +839,7 @@ class HiddenServiceDescriptorV2(HiddenServiceDescriptor): return decryptor.update(encrypted) + decryptor.finalize()
@staticmethod - def _parse_introduction_points(content): + def _parse_introduction_points(content: bytes) -> Sequence['stem.descriptor.hidden_service.IntroductionPointV2']: """ Provides the parsed list of IntroductionPointV2 for the unencrypted content. """ @@ -928,7 +931,7 @@ class HiddenServiceDescriptorV3(HiddenServiceDescriptor): }
@classmethod - def content(cls, attr = None, exclude = (), sign = False, inner_layer = None, outer_layer = None, identity_key = None, signing_key = None, signing_cert = None, revision_counter = None, blinding_nonce = None): + def content(cls: Type['stem.descriptor.hidden_service.HiddenServiceDescriptorV3'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), sign: bool = False, inner_layer: Optional['stem.descriptor.hidden_service.InnerLayer'] = None, outer_layer: Optional['stem.descriptor.hidden_service.OuterLayer'] = None, identity_key: Optional['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey'] = None, signing_key: Optional['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey'] = None, signing_cert: Optional['stem.descriptor.Ed25519CertificateV1'] = None, revision_counter: int = None, blinding_nonce: bytes = None) -> str: """ Hidden service v3 descriptors consist of three parts:
@@ -1023,10 +1026,10 @@ class HiddenServiceDescriptorV3(HiddenServiceDescriptor): return desc_content
@classmethod - def create(cls, attr = None, exclude = (), validate = True, sign = False, inner_layer = None, outer_layer = None, identity_key = None, signing_key = None, signing_cert = None, revision_counter = None, blinding_nonce = None): + def create(cls: Type['stem.descriptor.hidden_service.HiddenServiceDescriptorV3'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), validate: bool = True, sign: bool = False, inner_layer: Optional['stem.descriptor.hidden_service.InnerLayer'] = None, outer_layer: Optional['stem.descriptor.hidden_service.OuterLayer'] = None, identity_key: Optional['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey'] = None, signing_key: Optional['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey'] = None, signing_cert: Optional['stem.descriptor.Ed25519CertificateV1'] = None, revision_counter: int = None, blinding_nonce: bytes = None) -> 'stem.descriptor.hidden_service.HiddenServiceDescriptorV3': return cls(cls.content(attr, exclude, sign, inner_layer, outer_layer, identity_key, signing_key, signing_cert, revision_counter, blinding_nonce), validate = validate)
- def __init__(self, raw_contents, validate = False): + def __init__(self, raw_contents: bytes, validate: bool = False) -> None: super(HiddenServiceDescriptorV3, self).__init__(raw_contents, lazy_load = not validate)
self._inner_layer = None @@ -1054,7 +1057,7 @@ class HiddenServiceDescriptorV3(HiddenServiceDescriptor): else: self._entries = entries
- def decrypt(self, onion_address): + def decrypt(self, onion_address: str) -> 'stem.descriptor.hidden_service.InnerLayer': """ Decrypt this descriptor. Hidden serice descriptors contain two encryption layers (:class:`~stem.descriptor.hidden_service.OuterLayer` and @@ -1086,7 +1089,7 @@ class HiddenServiceDescriptorV3(HiddenServiceDescriptor): return self._inner_layer
@staticmethod - def address_from_identity_key(key, suffix = True): + def address_from_identity_key(key: Union[bytes, 'cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PublicKey', 'cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey'], suffix: bool = True) -> str: """ Converts a hidden service identity key into its address. This accepts all key formats (private, public, or public bytes). @@ -1094,7 +1097,7 @@ class HiddenServiceDescriptorV3(HiddenServiceDescriptor): :param Ed25519PublicKey,Ed25519PrivateKey,bytes key: hidden service identity key :param bool suffix: includes the '.onion' suffix if true, excluded otherwise
- :returns: **unicode** hidden service address + :returns: **str** hidden service address
:raises: **ImportError** if key is a cryptographic type and ed25519 support is unavailable @@ -1109,7 +1112,7 @@ class HiddenServiceDescriptorV3(HiddenServiceDescriptor): return stem.util.str_tools._to_unicode(onion_address + b'.onion' if suffix else onion_address).lower()
@staticmethod - def identity_key_from_address(onion_address): + def identity_key_from_address(onion_address: str) -> bool: """ Converts a hidden service address into its public identity key.
@@ -1146,7 +1149,7 @@ class HiddenServiceDescriptorV3(HiddenServiceDescriptor): return pubkey
@staticmethod - def _subcredential(identity_key, blinded_key): + def _subcredential(identity_key: bytes, blinded_key: bytes) -> bytes: # credential = H('credential' | public-identity-key) # subcredential = H('subcredential' | credential | blinded-public-key)
@@ -1186,11 +1189,11 @@ class OuterLayer(Descriptor): }
@staticmethod - def _decrypt(encrypted, revision_counter, subcredential, blinded_key): + def _decrypt(encrypted: bytes, revision_counter: int, subcredential: bytes, blinded_key: bytes) -> 'stem.descriptor.hidden_service.OuterLayer': plaintext = _decrypt_layer(encrypted, b'hsdir-superencrypted-data', revision_counter, subcredential, blinded_key) return OuterLayer(plaintext)
- def _encrypt(self, revision_counter, subcredential, blinded_key): + def _encrypt(self, revision_counter: int, subcredential: bytes, blinded_key: bytes) -> bytes: # Spec mandated padding: "Before encryption the plaintext is padded with # NUL bytes to the nearest multiple of 10k bytes."
@@ -1201,7 +1204,7 @@ class OuterLayer(Descriptor): return _encrypt_layer(content, b'hsdir-superencrypted-data', revision_counter, subcredential, blinded_key)
@classmethod - def content(cls, attr = None, exclude = (), validate = True, sign = False, inner_layer = None, revision_counter = None, authorized_clients = None, subcredential = None, blinded_key = None): + def content(cls: Type['stem.descriptor.hidden_service.OuterLayer'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), validate: bool = True, sign: bool = False, inner_layer: Optional['stem.descriptor.hidden_service.InnerLayer'] = None, revision_counter: Optional[int] = None, authorized_clients: Optional[Sequence['stem.descriptor.hidden_service.AuthorizedClient']] = None, subcredential: bytes = None, blinded_key: bytes = None) -> str: try: from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey @@ -1235,10 +1238,10 @@ class OuterLayer(Descriptor): ))
@classmethod - def create(cls, attr = None, exclude = (), validate = True, sign = False, inner_layer = None, revision_counter = None, authorized_clients = None, subcredential = None, blinded_key = None): + def create(cls: Type['stem.descriptor.hidden_service.OuterLayer'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), validate: bool = True, sign: bool = False, inner_layer: Optional['stem.descriptor.hidden_service.InnerLayer'] = None, revision_counter: int = None, authorized_clients: Optional[Sequence['stem.descriptor.hidden_service.AuthorizedClient']] = None, subcredential: bytes = None, blinded_key: bytes = None) -> 'stem.descriptor.hidden_service.OuterLayer': return cls(cls.content(attr, exclude, validate, sign, inner_layer, revision_counter, authorized_clients, subcredential, blinded_key), validate = validate)
- def __init__(self, content, validate = False): + def __init__(self, content: bytes, validate: bool = False) -> None: content = stem.util.str_tools._to_bytes(content).rstrip(b'\x00') # strip null byte padding
super(OuterLayer, self).__init__(content, lazy_load = not validate) @@ -1282,7 +1285,7 @@ class InnerLayer(Descriptor): }
@staticmethod - def _decrypt(outer_layer, revision_counter, subcredential, blinded_key): + def _decrypt(outer_layer: 'stem.descriptor.hidden_service.OuterLayer', revision_counter: int, subcredential: bytes, blinded_key: bytes) -> bytes: plaintext = _decrypt_layer(outer_layer.encrypted, b'hsdir-encrypted-data', revision_counter, subcredential, blinded_key) return InnerLayer(plaintext, validate = True, outer_layer = outer_layer)
@@ -1292,7 +1295,7 @@ class InnerLayer(Descriptor): return _encrypt_layer(self.get_bytes(), b'hsdir-encrypted-data', revision_counter, subcredential, blinded_key)
@classmethod - def content(cls, attr = None, exclude = (), introduction_points = None): + def content(cls: Type['stem.descriptor.hidden_service.InnerLayer'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), introduction_points: Optional[Sequence['stem.descriptor.hidden_service.IntroductionPointV3']] = None) -> str: if introduction_points: suffix = '\n' + '\n'.join(map(IntroductionPointV3.encode, introduction_points)) else: @@ -1303,10 +1306,10 @@ class InnerLayer(Descriptor): )) + stem.util.str_tools._to_bytes(suffix)
@classmethod - def create(cls, attr = None, exclude = (), validate = True, introduction_points = None): + def create(cls: Type['stem.descriptor.hidden_service.InnerLayer'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), validate: bool = True, introduction_points: Optional[Sequence['stem.descriptor.hidden_service.IntroductionPointV3']] = None) -> 'stem.descriptor.hidden_service.InnerLayer': return cls(cls.content(attr, exclude, introduction_points), validate = validate)
- def __init__(self, content, validate = False, outer_layer = None): + def __init__(self, content: bytes, validate: bool = False, outer_layer: Optional['stem.descriptor.hidden_service.OuterLayer'] = None) -> None: super(InnerLayer, self).__init__(content, lazy_load = not validate) self.outer = outer_layer
@@ -1331,7 +1334,7 @@ class InnerLayer(Descriptor): self._entries = entries
-def _blinded_pubkey(identity_key, blinding_nonce): +def _blinded_pubkey(identity_key: bytes, blinding_nonce: bytes) -> bytes: from stem.util import ed25519
mult = 2 ** (ed25519.b - 2) + sum(2 ** i * ed25519.bit(blinding_nonce, i) for i in range(3, ed25519.b - 2)) @@ -1339,7 +1342,7 @@ def _blinded_pubkey(identity_key, blinding_nonce): return ed25519.encodepoint(ed25519.scalarmult(P, mult))
-def _blinded_sign(msg, identity_key, blinded_key, blinding_nonce): +def _blinded_sign(msg: bytes, identity_key: 'cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey', blinded_key: bytes, blinding_nonce: bytes) -> bytes: try: from cryptography.hazmat.primitives import serialization except ImportError: diff --git a/stem/descriptor/microdescriptor.py b/stem/descriptor/microdescriptor.py index c62a3d0d..c2c104ff 100644 --- a/stem/descriptor/microdescriptor.py +++ b/stem/descriptor/microdescriptor.py @@ -69,6 +69,8 @@ import hashlib
import stem.exit_policy
+from typing import Any, BinaryIO, Dict, Iterator, Mapping, Optional, Sequence, Type, Union + from stem.descriptor import ( Descriptor, DigestHash, @@ -102,7 +104,7 @@ SINGLE_FIELDS = ( )
-def _parse_file(descriptor_file, validate = False, **kwargs): +def _parse_file(descriptor_file: BinaryIO, validate: bool = False, **kwargs: Any) -> Iterator['stem.descriptor.microdescriptor.Microdescriptor']: """ Iterates over the microdescriptors in a file.
@@ -159,7 +161,7 @@ def _parse_file(descriptor_file, validate = False, **kwargs): break # done parsing descriptors
-def _parse_id_line(descriptor, entries): +def _parse_id_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: identities = {}
for entry in _values('id', entries): @@ -244,7 +246,7 @@ class Microdescriptor(Descriptor): }
@classmethod - def content(cls, attr = None, exclude = ()): + def content(cls: Type['stem.descriptor.microdescriptor.Microdescriptor'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> str: return _descriptor_content(attr, exclude, ( ('onion-key', _random_crypto_blob('RSA PUBLIC KEY')), )) @@ -260,7 +262,7 @@ class Microdescriptor(Descriptor): else: self._entries = entries
- def digest(self, hash_type = DigestHash.SHA256, encoding = DigestEncoding.BASE64): + def digest(self, hash_type: 'stem.descriptor.DigestHash' = DigestHash.SHA256, encoding: 'stem.descriptor.DigestEncoding' = DigestEncoding.BASE64) -> Union[str, 'hashlib.HASH']: """ Digest of this microdescriptor. These are referenced by...
@@ -285,7 +287,7 @@ class Microdescriptor(Descriptor): raise NotImplementedError('Microdescriptor digests are only available in sha1 and sha256, not %s' % hash_type)
@functools.lru_cache() - def get_annotations(self): + def get_annotations(self) -> Dict[str, str]: """ Provides content that appeared prior to the descriptor. If this comes from the cached-microdescs then this commonly contains content like... @@ -308,7 +310,7 @@ class Microdescriptor(Descriptor):
return annotation_dict
- def get_annotation_lines(self): + def get_annotation_lines(self) -> Sequence[str]: """ Provides the lines of content that appeared prior to the descriptor. This is the same as the @@ -320,7 +322,7 @@ class Microdescriptor(Descriptor):
return self._annotation_lines
- def _check_constraints(self, entries): + def _check_constraints(self, entries: Dict[str, Sequence[str]]) -> None: """ Does a basic check that the entries conform to this descriptor type's constraints. @@ -341,5 +343,5 @@ class Microdescriptor(Descriptor): if 'onion-key' != list(entries.keys())[0]: raise ValueError("Microdescriptor must start with a 'onion-key' entry")
- def _name(self, is_plural = False): + def _name(self, is_plural: bool = False) -> str: return 'microdescriptors' if is_plural else 'microdescriptor' diff --git a/stem/descriptor/networkstatus.py b/stem/descriptor/networkstatus.py index 77c6d612..48940987 100644 --- a/stem/descriptor/networkstatus.py +++ b/stem/descriptor/networkstatus.py @@ -65,6 +65,8 @@ import stem.util.str_tools import stem.util.tor_tools import stem.version
+from typing import Any, BinaryIO, Callable, Dict, Iterator, Mapping, Optional, Sequence, Type + from stem.descriptor import ( PGP_BLOCK_END, Descriptor, @@ -293,7 +295,7 @@ class DocumentDigest(collections.namedtuple('DocumentDigest', ['flavor', 'algori """
-def _parse_file(document_file, document_type = None, validate = False, is_microdescriptor = False, document_handler = DocumentHandler.ENTRIES, **kwargs): +def _parse_file(document_file: BinaryIO, document_type: Optional[Type['stem.descriptor.networkstatus.NetworkStatusDocument']] = None, validate: bool = False, is_microdescriptor: bool = False, document_handler: 'stem.descriptor.DocumentHandler' = DocumentHandler.ENTRIES, **kwargs: Any) -> 'stem.descriptor.networkstatus.NetworkStatusDocument': """ Parses a network status and iterates over the RouterStatusEntry in it. The document that these instances reference have an empty 'routers' attribute to @@ -372,7 +374,7 @@ def _parse_file(document_file, document_type = None, validate = False, is_microd raise ValueError('Unrecognized document_handler: %s' % document_handler)
-def _parse_file_key_certs(certificate_file, validate = False): +def _parse_file_key_certs(certificate_file: BinaryIO, validate: bool = False) -> Iterator['stem.descriptor.networkstatus.KeyCertificate']: """ Parses a file containing one or more authority key certificates.
@@ -401,7 +403,7 @@ def _parse_file_key_certs(certificate_file, validate = False): break # done parsing file
-def _parse_file_detached_sigs(detached_signature_file, validate = False): +def _parse_file_detached_sigs(detached_signature_file: BinaryIO, validate: bool = False) -> Iterator['stem.descriptor.networkstatus.DetachedSignature']: """ Parses a file containing one or more detached signatures.
@@ -431,7 +433,7 @@ class NetworkStatusDocument(Descriptor): Common parent for network status documents. """
- def digest(self, hash_type = DigestHash.SHA1, encoding = DigestEncoding.HEX): + def digest(self, hash_type: 'stem.descriptor.DigestHash' = DigestHash.SHA1, encoding: 'stem.descriptor.DigestEncoding' = DigestEncoding.HEX) -> Union[str, 'hashlib._HASH']: """ Digest of this descriptor's content. These are referenced by...
@@ -458,8 +460,8 @@ class NetworkStatusDocument(Descriptor): raise NotImplementedError('Network status document digests are only available in sha1 and sha256, not %s' % hash_type)
-def _parse_version_line(keyword, attribute, expected_version): - def _parse(descriptor, entries): +def _parse_version_line(keyword: str, attribute: str, expected_version: int) -> Callable[['stem.descriptor.Descriptor', Dict[str, Sequence[str]]], None]: + def _parse(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: value = _value(keyword, entries)
if not value.isdigit(): @@ -473,7 +475,7 @@ def _parse_version_line(keyword, attribute, expected_version): return _parse
-def _parse_dir_source_line(descriptor, entries): +def _parse_dir_source_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: value = _value('dir-source', entries) dir_source_comp = value.split()
@@ -493,7 +495,7 @@ def _parse_dir_source_line(descriptor, entries): descriptor.dir_port = None if dir_source_comp[2] == '0' else int(dir_source_comp[2])
-def _parse_additional_digests(descriptor, entries): +def _parse_additional_digests(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: digests = []
for val in _values('additional-digest', entries): @@ -507,7 +509,7 @@ def _parse_additional_digests(descriptor, entries): descriptor.additional_digests = digests
-def _parse_additional_signatures(descriptor, entries): +def _parse_additional_signatures(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: signatures = []
for val, block_type, block_contents in entries['additional-signature']: @@ -598,7 +600,7 @@ class NetworkStatusDocumentV2(NetworkStatusDocument): }
@classmethod - def content(cls, attr = None, exclude = ()): + def content(cls: Type['stem.descriptor.networkstatus.NetworkStatusDocumentV2'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> bytes: return _descriptor_content(attr, exclude, ( ('network-status-version', '2'), ('dir-source', '%s %s 80' % (_random_ipv4_address(), _random_ipv4_address())), @@ -610,7 +612,7 @@ class NetworkStatusDocumentV2(NetworkStatusDocument): ('directory-signature', 'moria2' + _random_crypto_blob('SIGNATURE')), ))
- def __init__(self, raw_content, validate = False): + def __init__(self, raw_content: bytes, validate: bool = False) -> None: super(NetworkStatusDocumentV2, self).__init__(raw_content, lazy_load = not validate)
# Splitting the document from the routers. Unlike v3 documents we're not @@ -646,7 +648,7 @@ class NetworkStatusDocumentV2(NetworkStatusDocument): else: self._entries = entries
- def _check_constraints(self, entries): + def _check_constraints(self, entries: Dict[str, Sequence[str]]) -> None: required_fields = [field for (field, is_mandatory) in NETWORK_STATUS_V2_FIELDS if is_mandatory] for keyword in required_fields: if keyword not in entries: @@ -662,7 +664,7 @@ class NetworkStatusDocumentV2(NetworkStatusDocument): raise ValueError("Network status document (v2) are expected to start with a 'network-status-version' line:\n%s" % str(self))
-def _parse_header_network_status_version_line(descriptor, entries): +def _parse_header_network_status_version_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: # "network-status-version" version
value = _value('network-status-version', entries) @@ -683,7 +685,7 @@ def _parse_header_network_status_version_line(descriptor, entries): raise ValueError("Expected a version 3 network status document, got version '%s' instead" % descriptor.version)
-def _parse_header_vote_status_line(descriptor, entries): +def _parse_header_vote_status_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: # "vote-status" type # # The consensus-method and consensus-methods fields are optional since @@ -700,7 +702,7 @@ def _parse_header_vote_status_line(descriptor, entries): raise ValueError("A network status document's vote-status line can only be 'consensus' or 'vote', got '%s' instead" % value)
-def _parse_header_consensus_methods_line(descriptor, entries): +def _parse_header_consensus_methods_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: # "consensus-methods" IntegerList
if descriptor._lazy_loading and descriptor.is_vote: @@ -717,7 +719,7 @@ def _parse_header_consensus_methods_line(descriptor, entries): descriptor.consensus_methods = consensus_methods
-def _parse_header_consensus_method_line(descriptor, entries): +def _parse_header_consensus_method_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: # "consensus-method" Integer
if descriptor._lazy_loading and descriptor.is_consensus: @@ -731,7 +733,7 @@ def _parse_header_consensus_method_line(descriptor, entries): descriptor.consensus_method = int(value)
-def _parse_header_voting_delay_line(descriptor, entries): +def _parse_header_voting_delay_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: # "voting-delay" VoteSeconds DistSeconds
value = _value('voting-delay', entries) @@ -744,8 +746,8 @@ def _parse_header_voting_delay_line(descriptor, entries): raise ValueError("A network status document's 'voting-delay' line must be a pair of integer values, but was '%s'" % value)
-def _parse_versions_line(keyword, attribute): - def _parse(descriptor, entries): +def _parse_versions_line(keyword: str, attribute: str) -> Callable[['stem.descriptor.Descriptor', Dict[str, Sequence[str]]], None]: + def _parse(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: value, entries = _value(keyword, entries), []
for entry in value.split(','): @@ -759,7 +761,7 @@ def _parse_versions_line(keyword, attribute): return _parse
-def _parse_header_flag_thresholds_line(descriptor, entries): +def _parse_header_flag_thresholds_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: # "flag-thresholds" SP THRESHOLDS
value, thresholds = _value('flag-thresholds', entries).strip(), {} @@ -782,7 +784,7 @@ def _parse_header_flag_thresholds_line(descriptor, entries): descriptor.flag_thresholds = thresholds
-def _parse_header_parameters_line(descriptor, entries): +def _parse_header_parameters_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: # "params" [Parameters] # Parameter ::= Keyword '=' Int32 # Int32 ::= A decimal integer between -2147483648 and 2147483647. @@ -798,7 +800,7 @@ def _parse_header_parameters_line(descriptor, entries): descriptor._check_params_constraints()
-def _parse_directory_footer_line(descriptor, entries): +def _parse_directory_footer_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: # nothing to parse, simply checking that we don't have a value
value = _value('directory-footer', entries) @@ -807,7 +809,7 @@ def _parse_directory_footer_line(descriptor, entries): raise ValueError("A network status document's 'directory-footer' line shouldn't have any content, got 'directory-footer %s'" % value)
-def _parse_footer_directory_signature_line(descriptor, entries): +def _parse_footer_directory_signature_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: signatures = []
for sig_value, block_type, block_contents in entries['directory-signature']: @@ -828,7 +830,7 @@ def _parse_footer_directory_signature_line(descriptor, entries): descriptor.signatures = signatures
-def _parse_package_line(descriptor, entries): +def _parse_package_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: package_versions = []
for value, _, _ in entries['package']: @@ -849,7 +851,7 @@ def _parse_package_line(descriptor, entries): descriptor.packages = package_versions
-def _parsed_shared_rand_commit(descriptor, entries): +def _parsed_shared_rand_commit(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: # "shared-rand-commit" Version AlgName Identity Commit [Reveal]
commitments = [] @@ -871,7 +873,7 @@ def _parsed_shared_rand_commit(descriptor, entries): descriptor.shared_randomness_commitments = commitments
-def _parse_shared_rand_previous_value(descriptor, entries): +def _parse_shared_rand_previous_value(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: # "shared-rand-previous-value" NumReveals Value
value = _value('shared-rand-previous-value', entries) @@ -884,7 +886,7 @@ def _parse_shared_rand_previous_value(descriptor, entries): raise ValueError("A network status document's 'shared-rand-previous-value' line must be a pair of values, the first an integer but was '%s'" % value)
-def _parse_shared_rand_current_value(descriptor, entries): +def _parse_shared_rand_current_value(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: # "shared-rand-current-value" NumReveals Value
value = _value('shared-rand-current-value', entries) @@ -897,7 +899,7 @@ def _parse_shared_rand_current_value(descriptor, entries): raise ValueError("A network status document's 'shared-rand-current-value' line must be a pair of values, the first an integer but was '%s'" % value)
-def _parse_bandwidth_file_headers(descriptor, entries): +def _parse_bandwidth_file_headers(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: # "bandwidth-file-headers" KeyValues # KeyValues ::= "" | KeyValue | KeyValues SP KeyValue # KeyValue ::= Keyword '=' Value @@ -912,7 +914,7 @@ def _parse_bandwidth_file_headers(descriptor, entries): descriptor.bandwidth_file_headers = results
-def _parse_bandwidth_file_digest(descriptor, entries): +def _parse_bandwidth_file_digest(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: # "bandwidth-file-digest" 1*(SP algorithm "=" digest)
value = _value('bandwidth-file-digest', entries) @@ -1096,7 +1098,7 @@ class NetworkStatusDocumentV3(NetworkStatusDocument): }
@classmethod - def content(cls, attr = None, exclude = (), authorities = None, routers = None): + def content(cls: Type['stem.descriptor.networkstatus.NetworkStatusDocumentV3'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), authorities: Optional[Sequence['stem.descriptor.networkstatus.DirectoryAuthority']] = None, routers: Optional[Sequence['stem.descriptor.router_status_entry.RouterStatusEntryV3']] = None) -> bytes: attr = {} if attr is None else dict(attr) is_vote = attr.get('vote-status') == 'vote'
@@ -1168,10 +1170,10 @@ class NetworkStatusDocumentV3(NetworkStatusDocument): return desc_content
@classmethod - def create(cls, attr = None, exclude = (), validate = True, authorities = None, routers = None): + def create(cls: Type['stem.descriptor.networkstatus.NetworkStatusDocumentV3'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), validate: bool = True, authorities: Optional[Sequence['stem.descriptor.networkstatus.DirectoryAuthority']] = None, routers: Optional[Sequence['stem.descriptor.router_status_entry.RouterStatusEntryV3']] = None) -> 'stem.descriptor.networkstatus.NetworkStatusDocumentV3': return cls(cls.content(attr, exclude, authorities, routers), validate = validate)
- def __init__(self, raw_content, validate = False, default_params = True): + def __init__(self, raw_content: bytes, validate: bool = False, default_params: bool = True) -> None: """ Parse a v3 network status document.
@@ -1213,7 +1215,7 @@ class NetworkStatusDocumentV3(NetworkStatusDocument): self.routers = dict((desc.fingerprint, desc) for desc in router_iter) self._footer(document_file, validate)
- def type_annotation(self): + def type_annotation(self) -> 'stem.descriptor.TypeAnnotation': if isinstance(self, BridgeNetworkStatusDocument): return TypeAnnotation('bridge-network-status', 1, 0) elif not self.is_microdescriptor: @@ -1225,7 +1227,7 @@ class NetworkStatusDocumentV3(NetworkStatusDocument):
return TypeAnnotation('network-status-microdesc-consensus-3', 1, 0)
- def is_valid(self): + def is_valid(self) -> bool: """ Checks if the current time is between this document's **valid_after** and **valid_until** timestamps. To be valid means the information within this @@ -1239,7 +1241,7 @@ class NetworkStatusDocumentV3(NetworkStatusDocument):
return self.valid_after < datetime.datetime.utcnow() < self.valid_until
- def is_fresh(self): + def is_fresh(self) -> bool: """ Checks if the current time is between this document's **valid_after** and **fresh_until** timestamps. To be fresh means this should be the latest @@ -1253,7 +1255,7 @@ class NetworkStatusDocumentV3(NetworkStatusDocument):
return self.valid_after < datetime.datetime.utcnow() < self.fresh_until
- def validate_signatures(self, key_certs): + def validate_signatures(self, key_certs: Sequence['stem.descriptor.networkstatus.KeyCertificate']) -> None: """ Validates we're properly signed by the signing certificates.
@@ -1287,7 +1289,7 @@ class NetworkStatusDocumentV3(NetworkStatusDocument): if valid_digests < required_digests: raise ValueError('Network Status Document has %i valid signatures out of %i total, needed %i' % (valid_digests, total_digests, required_digests))
- def get_unrecognized_lines(self): + def get_unrecognized_lines(self) -> Sequence[str]: if self._lazy_loading: self._parse(self._header_entries, False, parser_for_line = self._HEADER_PARSER_FOR_LINE) self._parse(self._footer_entries, False, parser_for_line = self._FOOTER_PARSER_FOR_LINE) @@ -1295,7 +1297,7 @@ class NetworkStatusDocumentV3(NetworkStatusDocument):
return super(NetworkStatusDocumentV3, self).get_unrecognized_lines()
- def meets_consensus_method(self, method): + def meets_consensus_method(self, method: int) -> bool: """ Checks if we meet the given consensus-method. This works for both votes and consensuses, checking our 'consensus-method' and 'consensus-methods' @@ -1313,7 +1315,7 @@ class NetworkStatusDocumentV3(NetworkStatusDocument): else: return False # malformed document
- def _header(self, document_file, validate): + def _header(self, document_file: BinaryIO, validate: bool) -> None: content = bytes.join(b'', _read_until_keywords((AUTH_START, ROUTERS_START, FOOTER_START), document_file)) entries = _descriptor_components(content, validate) header_fields = [attr[0] for attr in HEADER_STATUS_DOCUMENT_FIELDS] @@ -1347,7 +1349,7 @@ class NetworkStatusDocumentV3(NetworkStatusDocument): self._header_entries = entries self._entries.update(entries)
- def _footer(self, document_file, validate): + def _footer(self, document_file: BinaryIO, validate: bool) -> None: entries = _descriptor_components(document_file.read(), validate) footer_fields = [attr[0] for attr in FOOTER_STATUS_DOCUMENT_FIELDS]
@@ -1379,7 +1381,7 @@ class NetworkStatusDocumentV3(NetworkStatusDocument): self._footer_entries = entries self._entries.update(entries)
- def _check_params_constraints(self): + def _check_params_constraints(self) -> None: """ Checks that the params we know about are within their documented ranges. """ @@ -1398,7 +1400,7 @@ class NetworkStatusDocumentV3(NetworkStatusDocument): raise ValueError("'%s' value on the params line must be in the range of %i - %i, was %i" % (key, minimum, maximum, value))
-def _check_for_missing_and_disallowed_fields(document, entries, fields): +def _check_for_missing_and_disallowed_fields(document: 'stem.descriptor.networkstatus.NetworkStatusDocumentV3', entries: Dict[str, Sequence[str]], fields: Sequence[Tuple[str, bool, bool, bool]]) -> None: """ Checks that we have mandatory fields for our type, and that we don't have any fields exclusive to the other (ie, no vote-only fields appear in a @@ -1431,7 +1433,7 @@ def _check_for_missing_and_disallowed_fields(document, entries, fields): raise ValueError("Network status document has fields that shouldn't appear in this document type or version: %s" % ', '.join(disallowed_fields))
-def _parse_int_mappings(keyword, value, validate): +def _parse_int_mappings(keyword: str, value: str, validate: bool) -> Dict[str, int]: # Parse a series of 'key=value' entries, checking the following: # - values are integers # - keys are sorted in lexical order @@ -1461,7 +1463,7 @@ def _parse_int_mappings(keyword, value, validate): return results
-def _parse_dirauth_source_line(descriptor, entries): +def _parse_dirauth_source_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: # "dir-source" nickname identity address IP dirport orport
value = _value('dir-source', entries) @@ -1580,7 +1582,7 @@ class DirectoryAuthority(Descriptor): }
@classmethod - def content(cls, attr = None, exclude = (), is_vote = False): + def content(cls: Type['stem.descriptor.networkstatus.DirectoryAuthority'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), is_vote: bool = False) -> bytes: attr = {} if attr is None else dict(attr)
# include mandatory 'vote-digest' if a consensus @@ -1599,10 +1601,10 @@ class DirectoryAuthority(Descriptor): return content
@classmethod - def create(cls, attr = None, exclude = (), validate = True, is_vote = False): + def create(cls: Type['stem.descriptor.networkstatus.DirectoryAuthority'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), validate: bool = True, is_vote: bool = False) -> 'stem.descriptor.networkstatus.DirectoryAuthority': return cls(cls.content(attr, exclude, is_vote), validate = validate, is_vote = is_vote)
- def __init__(self, raw_content, validate = False, is_vote = False): + def __init__(self, raw_content: bytes, validate: bool = False, is_vote: bool = False) -> None: """ Parse a directory authority entry in a v3 network status document.
@@ -1677,7 +1679,7 @@ class DirectoryAuthority(Descriptor): self._entries = entries
-def _parse_dir_address_line(descriptor, entries): +def _parse_dir_address_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: # "dir-address" IPPort
value = _value('dir-address', entries) @@ -1752,7 +1754,7 @@ class KeyCertificate(Descriptor): }
@classmethod - def content(cls, attr = None, exclude = ()): + def content(cls: Type['stem.descriptor.networkstatus.KeyCertificate'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> bytes: return _descriptor_content(attr, exclude, ( ('dir-key-certificate-version', '3'), ('fingerprint', _random_fingerprint()), @@ -1764,7 +1766,7 @@ class KeyCertificate(Descriptor): ('dir-key-certification', _random_crypto_blob('SIGNATURE')), ))
- def __init__(self, raw_content, validate = False): + def __init__(self, raw_content: bytes, validate: bool = False) -> None: super(KeyCertificate, self).__init__(raw_content, lazy_load = not validate) entries = _descriptor_components(raw_content, validate)
@@ -1805,7 +1807,7 @@ class DocumentSignature(object): :raises: **ValueError** if a validity check fails """
- def __init__(self, method, identity, key_digest, signature, flavor = None, validate = False): + def __init__(self, method: str, identity: str, key_digest: str, signature: str, flavor: Optional[str] = None, validate: bool = False) -> None: # Checking that these attributes are valid. Technically the key # digest isn't a fingerprint, but it has the same characteristics.
@@ -1822,7 +1824,7 @@ class DocumentSignature(object): self.signature = signature self.flavor = flavor
- def _compare(self, other, method): + def _compare(self, other: Any, method: Callable[[Any, Any], bool]) -> bool: if not isinstance(other, DocumentSignature): return False
@@ -1832,19 +1834,19 @@ class DocumentSignature(object):
return method(True, True) # we're equal
- def __hash__(self): + def __hash__(self) -> int: return hash(str(self).strip())
- def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return self._compare(other, lambda s, o: s == o)
- def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self == other
- def __lt__(self, other): + def __lt__(self, other: Any) -> bool: return self._compare(other, lambda s, o: s < o)
- def __le__(self, other): + def __le__(self, other: Any) -> bool: return self._compare(other, lambda s, o: s <= o)
@@ -1898,7 +1900,7 @@ class DetachedSignature(Descriptor): }
@classmethod - def content(cls, attr = None, exclude = ()): + def content(cls: Type['stem.descriptor.networkstatus.DetachedSignature'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> bytes: return _descriptor_content(attr, exclude, ( ('consensus-digest', '6D3CC0EFA408F228410A4A8145E1B0BB0670E442'), ('valid-after', _random_date()), @@ -1906,7 +1908,7 @@ class DetachedSignature(Descriptor): ('valid-until', _random_date()), ))
- def __init__(self, raw_content, validate = False): + def __init__(self, raw_content: bytes, validate: bool = False) -> None: super(DetachedSignature, self).__init__(raw_content, lazy_load = not validate) entries = _descriptor_components(raw_content, validate)
@@ -1941,7 +1943,7 @@ class BridgeNetworkStatusDocument(NetworkStatusDocument):
TYPE_ANNOTATION_NAME = 'bridge-network-status'
- def __init__(self, raw_content, validate = False): + def __init__(self, raw_content: bytes, validate: bool = False) -> None: super(BridgeNetworkStatusDocument, self).__init__(raw_content)
self.published = None diff --git a/stem/descriptor/remote.py b/stem/descriptor/remote.py index 24eb7b9b..f3c6d6bd 100644 --- a/stem/descriptor/remote.py +++ b/stem/descriptor/remote.py @@ -101,6 +101,7 @@ import stem.util.tor_tools
from stem.descriptor import Compression from stem.util import log, str_tools +from typing import Any, Dict, Iterator, Optional, Sequence, Tuple, Union
# Tor has a limited number of descriptors we can fetch explicitly by their # fingerprint or hashes due to a limit on the url length by squid proxies. @@ -121,7 +122,7 @@ SINGLETON_DOWNLOADER = None DIR_PORT_BLACKLIST = ('tor26', 'Serge')
-def get_instance(): +def get_instance() -> 'stem.descriptor.remote.DescriptorDownloader': """ Provides the singleton :class:`~stem.descriptor.remote.DescriptorDownloader` used for this module's shorthand functions. @@ -139,7 +140,7 @@ def get_instance(): return SINGLETON_DOWNLOADER
-def their_server_descriptor(**query_args): +def their_server_descriptor(**query_args: Any) -> 'stem.descriptor.remote.Query': """ Provides the server descriptor of the relay we're downloading from.
@@ -154,7 +155,7 @@ def their_server_descriptor(**query_args): return get_instance().their_server_descriptor(**query_args)
-def get_server_descriptors(fingerprints = None, **query_args): +def get_server_descriptors(fingerprints: Optional[Union[str, Sequence[str]]] = None, **query_args: Any) -> 'stem.descriptor.remote.Query': """ Shorthand for :func:`~stem.descriptor.remote.DescriptorDownloader.get_server_descriptors` @@ -166,7 +167,7 @@ def get_server_descriptors(fingerprints = None, **query_args): return get_instance().get_server_descriptors(fingerprints, **query_args)
-def get_extrainfo_descriptors(fingerprints = None, **query_args): +def get_extrainfo_descriptors(fingerprints: Optional[Union[str, Sequence[str]]] = None, **query_args: Any) -> 'stem.descriptor.remote.Query': """ Shorthand for :func:`~stem.descriptor.remote.DescriptorDownloader.get_extrainfo_descriptors` @@ -178,7 +179,7 @@ def get_extrainfo_descriptors(fingerprints = None, **query_args): return get_instance().get_extrainfo_descriptors(fingerprints, **query_args)
-def get_microdescriptors(hashes, **query_args): +def get_microdescriptors(hashes: Union[str, Sequence[str]], **query_args: Any) -> 'stem.descriptor.remote.Query': """ Shorthand for :func:`~stem.descriptor.remote.DescriptorDownloader.get_microdescriptors` @@ -190,7 +191,7 @@ def get_microdescriptors(hashes, **query_args): return get_instance().get_microdescriptors(hashes, **query_args)
-def get_consensus(authority_v3ident = None, microdescriptor = False, **query_args): +def get_consensus(authority_v3ident: Optional[str] = None, microdescriptor: bool = False, **query_args: Any) -> 'stem.descriptor.remote.Query': """ Shorthand for :func:`~stem.descriptor.remote.DescriptorDownloader.get_consensus` @@ -202,7 +203,7 @@ def get_consensus(authority_v3ident = None, microdescriptor = False, **query_arg return get_instance().get_consensus(authority_v3ident, microdescriptor, **query_args)
-def get_bandwidth_file(**query_args): +def get_bandwidth_file(**query_args: Any) -> 'stem.descriptor.remote.Query': """ Shorthand for :func:`~stem.descriptor.remote.DescriptorDownloader.get_bandwidth_file` @@ -214,7 +215,7 @@ def get_bandwidth_file(**query_args): return get_instance().get_bandwidth_file(**query_args)
-def get_detached_signatures(**query_args): +def get_detached_signatures(**query_args: Any) -> 'stem.descriptor.remote.Query': """ Shorthand for :func:`~stem.descriptor.remote.DescriptorDownloader.get_detached_signatures` @@ -370,7 +371,7 @@ class Query(object): the same as running **query.run(True)** (default is **False**) """
- def __init__(self, resource, descriptor_type = None, endpoints = None, compression = (Compression.GZIP,), retries = 2, fall_back_to_authority = False, timeout = None, start = True, block = False, validate = False, document_handler = stem.descriptor.DocumentHandler.ENTRIES, **kwargs): + def __init__(self, resource: str, descriptor_type: Optional[str] = None, endpoints: Optional[Sequence['stem.Endpoint']] = None, compression: Sequence['stem.descriptor.Compression'] = (Compression.GZIP,), retries: int = 2, fall_back_to_authority: bool = False, timeout: Optional[float] = None, start: bool = True, block: bool = False, validate: bool = False, document_handler: 'stem.descriptor.DocumentHandler' = stem.descriptor.DocumentHandler.ENTRIES, **kwargs: Any) -> None: if not resource.startswith('/'): raise ValueError("Resources should start with a '/': %s" % resource)
@@ -433,7 +434,7 @@ class Query(object): if block: self.run(True)
- def start(self): + def start(self) -> None: """ Starts downloading the scriptors if we haven't started already. """ @@ -449,7 +450,7 @@ class Query(object): self._downloader_thread.setDaemon(True) self._downloader_thread.start()
- def run(self, suppress = False): + def run(self, suppress: bool = False) -> Sequence['stem.descriptor.Descriptor']: """ Blocks until our request is complete then provides the descriptors. If we haven't yet started our request then this does so. @@ -469,7 +470,7 @@ class Query(object):
return list(self._run(suppress))
- def _run(self, suppress): + def _run(self, suppress: bool) -> Iterator['stem.descriptor.Descriptor']: with self._downloader_thread_lock: self.start() self._downloader_thread.join() @@ -505,11 +506,11 @@ class Query(object):
raise self.error
- def __iter__(self): + def __iter__(self) -> Iterator['stem.descriptor.Descriptor']: for desc in self._run(True): yield desc
- def _pick_endpoint(self, use_authority = False): + def _pick_endpoint(self, use_authority: bool = False) -> 'stem.Endpoint': """ Provides an endpoint to query. If we have multiple endpoints then one is picked at random. @@ -527,7 +528,7 @@ class Query(object): else: return random.choice(self.endpoints)
- def _download_descriptors(self, retries, timeout): + def _download_descriptors(self, retries: int, timeout: Optional[float]) -> None: try: self.start_time = time.time() endpoint = self._pick_endpoint(use_authority = retries == 0 and self.fall_back_to_authority) @@ -572,7 +573,7 @@ class DescriptorDownloader(object): :class:`~stem.descriptor.remote.Query` constructor """
- def __init__(self, use_mirrors = False, **default_args): + def __init__(self, use_mirrors: bool = False, **default_args: Any) -> None: self._default_args = default_args
self._endpoints = None @@ -585,7 +586,7 @@ class DescriptorDownloader(object): except Exception as exc: log.debug('Unable to retrieve directory mirrors: %s' % exc)
- def use_directory_mirrors(self): + def use_directory_mirrors(self) -> 'stem.descriptor.networkstatus.NetworkStatusDocumentV3': """ Downloads the present consensus and configures ourselves to use directory mirrors, in addition to authorities. @@ -611,7 +612,7 @@ class DescriptorDownloader(object):
return consensus
- def their_server_descriptor(self, **query_args): + def their_server_descriptor(self, **query_args: Any) -> 'stem.descriptor.remote.Query': """ Provides the server descriptor of the relay we're downloading from.
@@ -625,7 +626,7 @@ class DescriptorDownloader(object):
return self.query('/tor/server/authority', **query_args)
- def get_server_descriptors(self, fingerprints = None, **query_args): + def get_server_descriptors(self, fingerprints: Optional[Union[str, Sequence[str]]] = None, **query_args: Any) -> 'stem.descriptor.remote.Query': """ Provides the server descriptors with the given fingerprints. If no fingerprints are provided then this returns all descriptors known @@ -655,7 +656,7 @@ class DescriptorDownloader(object):
return self.query(resource, **query_args)
- def get_extrainfo_descriptors(self, fingerprints = None, **query_args): + def get_extrainfo_descriptors(self, fingerprints: Optional[Union[str, Sequence[str]]] = None, **query_args: Any) -> 'stem.descriptor.remote.Query': """ Provides the extrainfo descriptors with the given fingerprints. If no fingerprints are provided then this returns all descriptors in the present @@ -685,7 +686,7 @@ class DescriptorDownloader(object):
return self.query(resource, **query_args)
- def get_microdescriptors(self, hashes, **query_args): + def get_microdescriptors(self, hashes: Union[str, Sequence[str]], **query_args: Any) -> 'stem.descriptor.remote.Query': """ Provides the microdescriptors with the given hashes. To get these see the **microdescriptor_digest** attribute of @@ -731,7 +732,7 @@ class DescriptorDownloader(object):
return self.query('/tor/micro/d/%s' % '-'.join(hashes), **query_args)
- def get_consensus(self, authority_v3ident = None, microdescriptor = False, **query_args): + def get_consensus(self, authority_v3ident: Optional[str] = None, microdescriptor: bool = False, **query_args: Any) -> 'stem.descriptor.remote.Query': """ Provides the present router status entries.
@@ -775,7 +776,7 @@ class DescriptorDownloader(object):
return consensus_query
- def get_vote(self, authority, **query_args): + def get_vote(self, authority: 'stem.directory.Authority', **query_args: Any) -> 'stem.descriptor.remote.Query': """ Provides the present vote for a given directory authority.
@@ -794,13 +795,13 @@ class DescriptorDownloader(object):
return self.query(resource, **query_args)
- def get_key_certificates(self, authority_v3idents = None, **query_args): + def get_key_certificates(self, authority_v3idents: Optional[Union[str, Sequence[str]]] = None, **query_args: Any) -> 'stem.descriptor.remote.Query': """ Provides the key certificates for authorities with the given fingerprints. If no fingerprints are provided then this returns all present key certificates.
- :param str authority_v3idents: fingerprint or list of fingerprints of the + :param str,list authority_v3idents: fingerprint or list of fingerprints of the authority keys, see `'v3ident' in tor's config.c https://gitweb.torproject.org/tor.git/tree/src/or/config.c#n819`_ for the values. @@ -827,7 +828,7 @@ class DescriptorDownloader(object):
return self.query(resource, **query_args)
- def get_bandwidth_file(self, **query_args): + def get_bandwidth_file(self, **query_args: Any) -> 'stem.descriptor.remote.Query': """ Provides the bandwidth authority heuristics used to make the next consensus. @@ -843,7 +844,7 @@ class DescriptorDownloader(object):
return self.query('/tor/status-vote/next/bandwidth', **query_args)
- def get_detached_signatures(self, **query_args): + def get_detached_signatures(self, **query_args: Any) -> 'stem.descriptor.remote.Query': """ Provides the detached signatures that will be used to make the next consensus. Please note that **these are only available during minutes 55-60 @@ -896,7 +897,7 @@ class DescriptorDownloader(object):
return self.query('/tor/status-vote/next/consensus-signatures', **query_args)
- def query(self, resource, **query_args): + def query(self, resource: str, **query_args: Any) -> 'stem.descriptor.remote.Query': """ Issues a request for the given resource.
@@ -923,7 +924,7 @@ class DescriptorDownloader(object): return Query(resource, **args)
-def _download_from_orport(endpoint, compression, resource): +def _download_from_orport(endpoint: 'stem.ORPort', compression: Sequence['stem.descriptor.Compression'], resource: str) -> Tuple[bytes, Dict[str, str]]: """ Downloads descriptors from the given orport. Payload is just like an http response (headers and all)... @@ -981,7 +982,7 @@ def _download_from_orport(endpoint, compression, resource): return _decompress(body_data, headers.get('Content-Encoding')), headers
-def _download_from_dirport(url, compression, timeout): +def _download_from_dirport(url: str, compression: Sequence['stem.descriptor.Compression'], timeout: Optional[float]) -> Tuple[bytes, Dict[str, str]]: """ Downloads descriptors from the given url.
@@ -1016,7 +1017,7 @@ def _download_from_dirport(url, compression, timeout): return _decompress(response.read(), response.headers.get('Content-Encoding')), response.headers
-def _decompress(data, encoding): +def _decompress(data: bytes, encoding: Optional[str]) -> bytes: """ Decompresses descriptor data.
@@ -1030,6 +1031,8 @@ def _decompress(data, encoding): :param bytes data: data we received :param str encoding: 'Content-Encoding' header of the response
+ :returns: **bytes** with the decompressed data + :raises: * **ValueError** if encoding is unrecognized * **ImportError** if missing the decompression module @@ -1045,7 +1048,7 @@ def _decompress(data, encoding): raise ValueError("'%s' isn't a recognized type of encoding" % encoding)
-def _guess_descriptor_type(resource): +def _guess_descriptor_type(resource: str) -> str: # Attempts to determine the descriptor type based on the resource url. This # raises a ValueError if the resource isn't recognized.
diff --git a/stem/descriptor/router_status_entry.py b/stem/descriptor/router_status_entry.py index c2d8dd07..20822c82 100644 --- a/stem/descriptor/router_status_entry.py +++ b/stem/descriptor/router_status_entry.py @@ -27,6 +27,8 @@ import io import stem.exit_policy import stem.util.str_tools
+from typing import Any, BinaryIO, Dict, Iterator, Mapping, Optional, Sequence, Tuple, Type + from stem.descriptor import ( KEYWORD_LINE, Descriptor, @@ -44,7 +46,7 @@ from stem.descriptor import ( _parse_pr_line = _parse_protocol_line('pr', 'protocols')
-def _parse_file(document_file, validate, entry_class, entry_keyword = 'r', start_position = None, end_position = None, section_end_keywords = (), extra_args = ()): +def _parse_file(document_file: BinaryIO, validate: bool, entry_class: Type['stem.descriptor.router_status_entry.RouterStatusEntry'], entry_keyword: str = 'r', start_position: Optional[int] = None, end_position: Optional[int] = None, section_end_keywords: Sequence[str] = (), extra_args: Sequence[str] = ()) -> Iterator['stem.descriptor.router_status_entry.RouterStatusEntry']: """ Reads a range of the document_file containing some number of entry_class instances. We deliminate the entry_class entries by the keyword on their @@ -111,7 +113,7 @@ def _parse_file(document_file, validate, entry_class, entry_keyword = 'r', start break
-def _parse_r_line(descriptor, entries): +def _parse_r_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: # Parses a RouterStatusEntry's 'r' line. They're very nearly identical for # all current entry types (v2, v3, and microdescriptor v3) with one little # wrinkle: only the microdescriptor flavor excludes a 'digest' field. @@ -163,7 +165,7 @@ def _parse_r_line(descriptor, entries): raise ValueError("Publication time time wasn't parsable: r %s" % value)
-def _parse_a_line(descriptor, entries): +def _parse_a_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: # "a" SP address ":" portlist # example: a [2001:888:2133:0:82:94:251:204]:9001
@@ -186,7 +188,7 @@ def _parse_a_line(descriptor, entries): descriptor.or_addresses = or_addresses
-def _parse_s_line(descriptor, entries): +def _parse_s_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: # "s" Flags # example: s Named Running Stable Valid
@@ -201,7 +203,7 @@ def _parse_s_line(descriptor, entries): raise ValueError("%s had extra whitespace on its 's' line: s %s" % (descriptor._name(), value))
-def _parse_v_line(descriptor, entries): +def _parse_v_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: # "v" version # example: v Tor 0.2.2.35 # @@ -219,7 +221,7 @@ def _parse_v_line(descriptor, entries): raise ValueError('%s has a malformed tor version (%s): v %s' % (descriptor._name(), exc, value))
-def _parse_w_line(descriptor, entries): +def _parse_w_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: # "w" "Bandwidth=" INT ["Measured=" INT] ["Unmeasured=1"] # example: w Bandwidth=7980
@@ -266,7 +268,7 @@ def _parse_w_line(descriptor, entries): descriptor.unrecognized_bandwidth_entries = unrecognized_bandwidth_entries
-def _parse_p_line(descriptor, entries): +def _parse_p_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: # "p" ("accept" / "reject") PortList # # examples: @@ -282,7 +284,7 @@ def _parse_p_line(descriptor, entries): raise ValueError('%s exit policy is malformed (%s): p %s' % (descriptor._name(), exc, value))
-def _parse_id_line(descriptor, entries): +def _parse_id_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: # "id" "ed25519" ed25519-identity # # examples: @@ -305,7 +307,7 @@ def _parse_id_line(descriptor, entries): raise ValueError("'id' lines should contain both the key type and digest: id %s" % value)
-def _parse_m_line(descriptor, entries): +def _parse_m_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: # "m" methods 1*(algorithm "=" digest) # example: m 8,9,10,11,12 sha256=g1vx9si329muxV3tquWIXXySNOIwRGMeAESKs/v4DWs
@@ -339,14 +341,14 @@ def _parse_m_line(descriptor, entries): descriptor.microdescriptor_hashes = all_hashes
-def _parse_microdescriptor_m_line(descriptor, entries): +def _parse_microdescriptor_m_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: # "m" digest # example: m aiUklwBrua82obG5AsTX+iEpkjQA2+AQHxZ7GwMfY70
descriptor.microdescriptor_digest = _value('m', entries)
-def _base64_to_hex(identity, check_if_fingerprint = True): +def _base64_to_hex(identity: str, check_if_fingerprint: bool = True) -> str: """ Decodes a base64 value to hex. For example...
@@ -420,7 +422,7 @@ class RouterStatusEntry(Descriptor): }
@classmethod - def from_str(cls, content, **kwargs): + def from_str(cls: Type['stem.descriptor.router_status_entry.RouterStatusEntry'], content: str, **kwargs: Any) -> 'stem.descriptor.router_status_entry.RouterStatusEntry': # Router status entries don't have their own @type annotation, so to make # our subclass from_str() work we need to do the type inferencing ourself.
@@ -440,14 +442,14 @@ class RouterStatusEntry(Descriptor): else: raise ValueError("Descriptor.from_str() expected a single descriptor, but had %i instead. Please include 'multiple = True' if you want a list of results instead." % len(results))
- def __init__(self, content, validate = False, document = None): + def __init__(self, content: str, validate: bool = False, document: Optional['stem.descriptor.networkstatus.NetworkStatusDocument'] = None) -> None: """ Parse a router descriptor in a network status document.
:param str content: router descriptor content to be parsed - :param NetworkStatusDocument document: document this descriptor came from :param bool validate: checks the validity of the content if **True**, skips these checks otherwise + :param NetworkStatusDocument document: document this descriptor came from
:raises: **ValueError** if the descriptor data is invalid """ @@ -472,21 +474,21 @@ class RouterStatusEntry(Descriptor): else: self._entries = entries
- def _name(self, is_plural = False): + def _name(self, is_plural: bool = False) -> str: """ Name for this descriptor type. """
return 'Router status entries' if is_plural else 'Router status entry'
- def _required_fields(self): + def _required_fields(self) -> Tuple[str, ...]: """ Provides lines that must appear in the descriptor. """
return ()
- def _single_fields(self): + def _single_fields(self) -> Tuple[str, ...]: """ Provides lines that can only appear in the descriptor once. """ @@ -512,18 +514,18 @@ class RouterStatusEntryV2(RouterStatusEntry): })
@classmethod - def content(cls, attr = None, exclude = ()): + def content(cls: Type['stem.descriptor.router_status_entry.RouterStatusEntryV2'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> str: return _descriptor_content(attr, exclude, ( ('r', '%s p1aag7VwarGxqctS7/fS0y5FU+s oQZFLYe9e4A7bOkWKR7TaNxb0JE %s %s 9001 0' % (_random_nickname(), _random_date(), _random_ipv4_address())), ))
- def _name(self, is_plural = False): + def _name(self, is_plural: bool = False) -> str: return 'Router status entries (v2)' if is_plural else 'Router status entry (v2)'
- def _required_fields(self): - return ('r') + def _required_fields(self) -> Tuple[str, ...]: + return ('r',)
- def _single_fields(self): + def _single_fields(self) -> Tuple[str, ...]: return ('r', 's', 'v')
@@ -603,19 +605,19 @@ class RouterStatusEntryV3(RouterStatusEntry): })
@classmethod - def content(cls, attr = None, exclude = ()): + def content(cls: Type['stem.descriptor.router_status_entry.RouterStatusEntryV3'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> str: return _descriptor_content(attr, exclude, ( ('r', '%s p1aag7VwarGxqctS7/fS0y5FU+s oQZFLYe9e4A7bOkWKR7TaNxb0JE %s %s 9001 0' % (_random_nickname(), _random_date(), _random_ipv4_address())), ('s', 'Fast Named Running Stable Valid'), ))
- def _name(self, is_plural = False): + def _name(self, is_plural: bool = False) -> str: return 'Router status entries (v3)' if is_plural else 'Router status entry (v3)'
- def _required_fields(self): + def _required_fields(self) -> Tuple[str, ...]: return ('r', 's')
- def _single_fields(self): + def _single_fields(self) -> Tuple[str, ...]: return ('r', 's', 'v', 'w', 'p', 'pr')
@@ -668,18 +670,18 @@ class RouterStatusEntryMicroV3(RouterStatusEntry): })
@classmethod - def content(cls, attr = None, exclude = ()): + def content(cls: Type['stem.descriptor.router_status_entry.RouterStatusEntryMicroV3'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> str: return _descriptor_content(attr, exclude, ( ('r', '%s ARIJF2zbqirB9IwsW0mQznccWww %s %s 9001 9030' % (_random_nickname(), _random_date(), _random_ipv4_address())), ('m', 'aiUklwBrua82obG5AsTX+iEpkjQA2+AQHxZ7GwMfY70'), ('s', 'Fast Guard HSDir Named Running Stable V2Dir Valid'), ))
- def _name(self, is_plural = False): + def _name(self, is_plural: bool = False) -> str: return 'Router status entries (micro v3)' if is_plural else 'Router status entry (micro v3)'
- def _required_fields(self): + def _required_fields(self) -> Tuple[str, ...]: return ('r', 's', 'm')
- def _single_fields(self): + def _single_fields(self) -> Tuple[str, ...]: return ('r', 's', 'v', 'w', 'm', 'pr') diff --git a/stem/descriptor/server_descriptor.py b/stem/descriptor/server_descriptor.py index 955b8429..11b44972 100644 --- a/stem/descriptor/server_descriptor.py +++ b/stem/descriptor/server_descriptor.py @@ -61,6 +61,7 @@ import stem.version
from stem.descriptor.certificate import Ed25519Certificate from stem.descriptor.router_status_entry import RouterStatusEntryV3 +from typing import Any, BinaryIO, Dict, Iterator, Optional, Mapping, Sequence, Tuple, Type, Union
from stem.descriptor import ( PGP_BLOCK_END, @@ -139,11 +140,11 @@ REJECT_ALL_POLICY = stem.exit_policy.ExitPolicy('reject *:*') DEFAULT_BRIDGE_DISTRIBUTION = 'any'
-def _truncated_b64encode(content): +def _truncated_b64encode(content: bytes) -> str: return stem.util.str_tools._to_unicode(base64.b64encode(content).rstrip(b'='))
-def _parse_file(descriptor_file, is_bridge = False, validate = False, **kwargs): +def _parse_file(descriptor_file: BinaryIO, is_bridge: bool = False, validate: bool = False, **kwargs: Any) -> Iterator['stem.descriptor.server_descriptor.ServerDescriptor']: """ Iterates over the server descriptors in a file.
@@ -220,7 +221,7 @@ def _parse_file(descriptor_file, is_bridge = False, validate = False, **kwargs): break # done parsing descriptors
-def _parse_router_line(descriptor, entries): +def _parse_router_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: # "router" nickname address ORPort SocksPort DirPort
value = _value('router', entries) @@ -246,7 +247,7 @@ def _parse_router_line(descriptor, entries): descriptor.dir_port = None if router_comp[4] == '0' else int(router_comp[4])
-def _parse_bandwidth_line(descriptor, entries): +def _parse_bandwidth_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: # "bandwidth" bandwidth-avg bandwidth-burst bandwidth-observed
value = _value('bandwidth', entries) @@ -266,7 +267,7 @@ def _parse_bandwidth_line(descriptor, entries): descriptor.observed_bandwidth = int(bandwidth_comp[2])
-def _parse_platform_line(descriptor, entries): +def _parse_platform_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: # "platform" string
_parse_bytes_line('platform', 'platform')(descriptor, entries) @@ -292,7 +293,7 @@ def _parse_platform_line(descriptor, entries): pass
-def _parse_fingerprint_line(descriptor, entries): +def _parse_fingerprint_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: # This is forty hex digits split into space separated groups of four. # Checking that we match this pattern.
@@ -309,7 +310,7 @@ def _parse_fingerprint_line(descriptor, entries): descriptor.fingerprint = fingerprint
-def _parse_extrainfo_digest_line(descriptor, entries): +def _parse_extrainfo_digest_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: value = _value('extra-info-digest', entries) digest_comp = value.split(' ')
@@ -320,7 +321,7 @@ def _parse_extrainfo_digest_line(descriptor, entries): descriptor.extra_info_sha256_digest = digest_comp[1] if len(digest_comp) >= 2 else None
-def _parse_hibernating_line(descriptor, entries): +def _parse_hibernating_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: # "hibernating" 0|1 (in practice only set if one)
value = _value('hibernating', entries) @@ -331,7 +332,7 @@ def _parse_hibernating_line(descriptor, entries): descriptor.hibernating = value == '1'
-def _parse_protocols_line(descriptor, entries): +def _parse_protocols_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: value = _value('protocols', entries) protocols_match = re.match('^Link (.*) Circuit (.*)$', value)
@@ -343,7 +344,7 @@ def _parse_protocols_line(descriptor, entries): descriptor.circuit_protocols = circuit_versions.split(' ')
-def _parse_or_address_line(descriptor, entries): +def _parse_or_address_line(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: all_values = _values('or-address', entries) or_addresses = []
@@ -366,7 +367,7 @@ def _parse_or_address_line(descriptor, entries): descriptor.or_addresses = or_addresses
-def _parse_history_line(keyword, history_end_attribute, history_interval_attribute, history_values_attribute, descriptor, entries): +def _parse_history_line(keyword: str, history_end_attribute: str, history_interval_attribute: str, history_values_attribute: str, descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: value = _value(keyword, entries) timestamp, interval, remainder = stem.descriptor.extrainfo_descriptor._parse_timestamp_and_interval(keyword, value)
@@ -383,7 +384,7 @@ def _parse_history_line(keyword, history_end_attribute, history_interval_attribu setattr(descriptor, history_values_attribute, history_values)
-def _parse_exit_policy(descriptor, entries): +def _parse_exit_policy(descriptor: 'stem.descriptor.Descriptor', entries: Dict[str, Sequence[str]]) -> None: if hasattr(descriptor, '_unparsed_exit_policy'): if descriptor._unparsed_exit_policy and stem.util.str_tools._to_unicode(descriptor._unparsed_exit_policy[0]) == 'reject *:*': descriptor.exit_policy = REJECT_ALL_POLICY @@ -576,7 +577,7 @@ class ServerDescriptor(Descriptor): 'eventdns': _parse_eventdns_line, }
- def __init__(self, raw_contents, validate = False): + def __init__(self, raw_contents: str, validate: bool = False) -> None: """ Server descriptor constructor, created from an individual relay's descriptor content (as provided by 'GETINFO desc/*', cached descriptors, @@ -621,7 +622,7 @@ class ServerDescriptor(Descriptor): else: self._entries = entries
- def digest(self, hash_type = DigestHash.SHA1, encoding = DigestEncoding.HEX): + def digest(self, hash_type: 'stem.descriptor.DigestHash' = DigestHash.SHA1, encoding: 'stem.descriptor.DigestEncoding' = DigestEncoding.HEX) -> Union[str, 'hashlib.HASH']: """ Digest of this descriptor's content. These are referenced by...
@@ -641,7 +642,7 @@ class ServerDescriptor(Descriptor):
raise NotImplementedError('Unsupported Operation: this should be implemented by the ServerDescriptor subclass')
- def _check_constraints(self, entries): + def _check_constraints(self, entries: Dict[str, Sequence[str]]) -> None: """ Does a basic check that the entries conform to this descriptor type's constraints. @@ -679,16 +680,16 @@ class ServerDescriptor(Descriptor): # Constraints that the descriptor must meet to be valid. These can be None if # not applicable.
- def _required_fields(self): + def _required_fields(self) -> Tuple[str, ...]: return REQUIRED_FIELDS
- def _single_fields(self): + def _single_fields(self) -> Tuple[str, ...]: return REQUIRED_FIELDS + SINGLE_FIELDS
- def _first_keyword(self): + def _first_keyword(self) -> str: return 'router'
- def _last_keyword(self): + def _last_keyword(self) -> str: return 'router-signature'
@@ -753,7 +754,7 @@ class RelayDescriptor(ServerDescriptor): 'router-signature': _parse_router_signature_line, })
- def __init__(self, raw_contents, validate = False, skip_crypto_validation = False): + def __init__(self, raw_contents: str, validate: bool = False, skip_crypto_validation: bool = False) -> None: super(RelayDescriptor, self).__init__(raw_contents, validate)
if validate: @@ -785,7 +786,7 @@ class RelayDescriptor(ServerDescriptor): pass # cryptography module unavailable
@classmethod - def content(cls, attr = None, exclude = (), sign = False, signing_key = None, exit_policy = None): + def content(cls: Type['stem.descriptor.server_descriptor.RelayDescriptor'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), sign: bool = False, signing_key: Optional['stem.descriptor.SigningKey'] = None, exit_policy: Optional['stem.exit_policy.ExitPolicy'] = None) -> str: if attr is None: attr = {}
@@ -827,15 +828,18 @@ class RelayDescriptor(ServerDescriptor): ))
@classmethod - def create(cls, attr = None, exclude = (), validate = True, sign = False, signing_key = None, exit_policy = None): + def create(cls: Type['stem.descriptor.server_descriptor.RelayDescriptor'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = (), validate: bool = True, sign: bool = False, signing_key: Optional['stem.descriptor.SigningKey'] = None, exit_policy: Optional['stem.exit_policy.ExitPolicy'] = None) -> 'stem.descriptor.server_descriptor.RelayDescriptor': return cls(cls.content(attr, exclude, sign, signing_key, exit_policy), validate = validate, skip_crypto_validation = not sign)
@functools.lru_cache() - def digest(self, hash_type = DigestHash.SHA1, encoding = DigestEncoding.HEX): + def digest(self, hash_type: 'stem.descriptor.DigestHash' = DigestHash.SHA1, encoding: 'stem.descriptor.DigestEncoding' = DigestEncoding.HEX) -> Union[str, 'hashlib.HASH']: """ Provides the digest of our descriptor's content.
- :returns: the digest string encoded in uppercase hex + :param stem.descriptor.DigestHash hash_type: digest hashing algorithm + :param stem.descriptor.DigestEncoding encoding: digest encoding + + :returns: **hashlib.HASH** or **str** based on our encoding argument
:raises: ValueError if the digest cannot be calculated """ @@ -849,7 +853,7 @@ class RelayDescriptor(ServerDescriptor): else: raise NotImplementedError('Server descriptor digests are only available in sha1 and sha256, not %s' % hash_type)
- def make_router_status_entry(self): + def make_router_status_entry(self) -> 'stem.descriptor.router_status_entry.RouterStatusEntryV3': """ Provides a RouterStatusEntryV3 for this descriptor content.
@@ -888,12 +892,12 @@ class RelayDescriptor(ServerDescriptor): return RouterStatusEntryV3.create(attr)
@functools.lru_cache() - def _onion_key_crosscert_digest(self): + def _onion_key_crosscert_digest(self) -> str: """ Provides the digest of the onion-key-crosscert data. This consists of the RSA identity key sha1 and ed25519 identity key.
- :returns: **unicode** digest encoded in uppercase hex + :returns: **str** digest encoded in uppercase hex
:raises: ValueError if the digest cannot be calculated """ @@ -902,7 +906,7 @@ class RelayDescriptor(ServerDescriptor): data = signing_key_digest + base64.b64decode(stem.util.str_tools._to_bytes(self.ed25519_master_key) + b'=') return stem.util.str_tools._to_unicode(binascii.hexlify(data).upper())
- def _check_constraints(self, entries): + def _check_constraints(self, entries: Dict[str, Sequence[str]]) -> None: super(RelayDescriptor, self)._check_constraints(entries)
if self.certificate: @@ -941,7 +945,7 @@ class BridgeDescriptor(ServerDescriptor): })
@classmethod - def content(cls, attr = None, exclude = ()): + def content(cls: Type['stem.descriptor.server_descriptor.BridgeDescriptor'], attr: Optional[Mapping[str, str]] = None, exclude: Sequence[str] = ()) -> str: return _descriptor_content(attr, exclude, ( ('router', '%s %s 9001 0 0' % (_random_nickname(), _random_ipv4_address())), ('router-digest', '006FD96BA35E7785A6A3B8B75FE2E2435A13BDB4'), @@ -950,13 +954,13 @@ class BridgeDescriptor(ServerDescriptor): ('reject', '*:*'), ))
- def digest(self, hash_type = DigestHash.SHA1, encoding = DigestEncoding.HEX): + def digest(self, hash_type: 'stem.descriptor.DigestHash' = DigestHash.SHA1, encoding: 'stem.descriptor.DigestEncoding' = DigestEncoding.HEX) -> Union[str, 'hashlib.HASH']: if hash_type == DigestHash.SHA1 and encoding == DigestEncoding.HEX: return self._digest else: raise NotImplementedError('Bridge server descriptor digests are only available as sha1/hex, not %s/%s' % (hash_type, encoding))
- def is_scrubbed(self): + def is_scrubbed(self) -> bool: """ Checks if we've been properly scrubbed in accordance with the `bridge descriptor specification @@ -969,7 +973,7 @@ class BridgeDescriptor(ServerDescriptor): return self.get_scrubbing_issues() == []
@functools.lru_cache() - def get_scrubbing_issues(self): + def get_scrubbing_issues(self) -> Sequence[str]: """ Provides issues with our scrubbing.
@@ -1003,7 +1007,7 @@ class BridgeDescriptor(ServerDescriptor):
return issues
- def _required_fields(self): + def _required_fields(self) -> Tuple[str, ...]: # bridge required fields are the same as a relay descriptor, minus items # excluded according to the format page
@@ -1019,8 +1023,8 @@ class BridgeDescriptor(ServerDescriptor):
return tuple(included_fields + [f for f in REQUIRED_FIELDS if f not in excluded_fields])
- def _single_fields(self): + def _single_fields(self) -> Tuple[str, ...]: return self._required_fields() + SINGLE_FIELDS
- def _last_keyword(self): + def _last_keyword(self) -> Optional[str]: return None diff --git a/stem/descriptor/tordnsel.py b/stem/descriptor/tordnsel.py index d0f57b93..c36e343d 100644 --- a/stem/descriptor/tordnsel.py +++ b/stem/descriptor/tordnsel.py @@ -14,6 +14,8 @@ import stem.util.connection import stem.util.str_tools import stem.util.tor_tools
+from typing import Any, BinaryIO, Dict, Iterator, Sequence + from stem.descriptor import ( Descriptor, _read_until_keywords, @@ -21,7 +23,7 @@ from stem.descriptor import ( )
-def _parse_file(tordnsel_file, validate = False, **kwargs): +def _parse_file(tordnsel_file: BinaryIO, validate: bool = False, **kwargs: Any) -> Iterator['stem.descriptor.tordnsel.TorDNSEL']: """ Iterates over a tordnsel file.
@@ -62,7 +64,7 @@ class TorDNSEL(Descriptor):
TYPE_ANNOTATION_NAME = 'tordnsel'
- def __init__(self, raw_contents, validate): + def __init__(self, raw_contents: str, validate: bool) -> None: super(TorDNSEL, self).__init__(raw_contents) raw_contents = stem.util.str_tools._to_unicode(raw_contents) entries = _descriptor_components(raw_contents, validate) @@ -74,8 +76,7 @@ class TorDNSEL(Descriptor):
self._parse(entries, validate)
- def _parse(self, entries, validate): - + def _parse(self, entries: Dict[str, Sequence[str]], validate: bool) -> None: for keyword, values in list(entries.items()): value, block_type, block_content = values[0]
diff --git a/stem/directory.py b/stem/directory.py index 67079c80..f96adfbb 100644 --- a/stem/directory.py +++ b/stem/directory.py @@ -49,6 +49,7 @@ import stem.util import stem.util.conf
from stem.util import connection, str_tools, tor_tools +from typing import Any, Callable, Dict, Iterator, Mapping, Optional, Pattern, Sequence, Tuple, Union
GITWEB_AUTHORITY_URL = 'https://gitweb.torproject.org/tor.git/plain/src/app/config/auth_dirs.inc' GITWEB_FALLBACK_URL = 'https://gitweb.torproject.org/tor.git/plain/src/app/config/fallback_dirs.inc' @@ -68,7 +69,7 @@ FALLBACK_EXTRAINFO = re.compile('/\* extrainfo=([0-1]) \*/') FALLBACK_IPV6 = re.compile('" ipv6=\[([\da-f:]+)\]:(\d+)"')
-def _match_with(lines, regexes, required = None): +def _match_with(lines: Sequence[str], regexes: Sequence[Pattern], required: Optional[Sequence[Pattern]] = None) -> Dict[Pattern, Union[str, Tuple[str, ...]]]: """ Scans the given content against a series of regex matchers, providing back a mapping of regexes to their capture groups. This maping is with the value if @@ -101,7 +102,7 @@ def _match_with(lines, regexes, required = None): return matches
-def _directory_entries(lines, pop_section_func, regexes, required = None): +def _directory_entries(lines: Sequence[str], pop_section_func: Callable[[Sequence[str]], Sequence[str]], regexes: Sequence[Pattern], required: Optional[Sequence[Pattern]] = None) -> Iterator[Dict[Pattern, Union[str, Tuple[str, ...]]]]: next_section = pop_section_func(lines)
while next_section: @@ -133,7 +134,7 @@ class Directory(object): ORPort, or **None** if it doesn't have one """
- def __init__(self, address, or_port, dir_port, fingerprint, nickname, orport_v6): + def __init__(self, address: str, or_port: int, dir_port: int, fingerprint: str, nickname: str, orport_v6: Optional[Tuple[str, int]]) -> None: identifier = '%s (%s)' % (fingerprint, nickname) if nickname else fingerprint
if not connection.is_valid_ipv4_address(address): @@ -163,7 +164,7 @@ class Directory(object): self.orport_v6 = (orport_v6[0], int(orport_v6[1])) if orport_v6 else None
@staticmethod - def from_cache(): + def from_cache() -> Dict[str, 'stem.directory.Directory']: """ Provides cached Tor directory information. This information is hardcoded into Tor and occasionally changes, so the information provided by this @@ -181,7 +182,7 @@ class Directory(object): raise NotImplementedError('Unsupported Operation: this should be implemented by the Directory subclass')
@staticmethod - def from_remote(timeout = 60): + def from_remote(timeout: int = 60) -> Dict[str, 'stem.directory.Directory']: """ Reads and parses tor's directory data `from gitweb.torproject.org https://gitweb.torproject.org/`_. Note that while convenient, this reliance on GitWeb means you should alway @@ -209,13 +210,13 @@ class Directory(object):
raise NotImplementedError('Unsupported Operation: this should be implemented by the Directory subclass')
- def __hash__(self): + def __hash__(self) -> int: return stem.util._hash_attr(self, 'address', 'or_port', 'dir_port', 'fingerprint', 'nickname', 'orport_v6')
- def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return hash(self) == hash(other) if isinstance(other, Directory) else False
- def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self == other
@@ -231,7 +232,7 @@ class Authority(Directory): :var str v3ident: identity key fingerprint used to sign votes and consensus """
- def __init__(self, address = None, or_port = None, dir_port = None, fingerprint = None, nickname = None, orport_v6 = None, v3ident = None): + def __init__(self, address: Optional[str] = None, or_port: Optional[int] = None, dir_port: Optional[int] = None, fingerprint: Optional[str] = None, nickname: Optional[str] = None, orport_v6: Optional[Tuple[str, int]] = None, v3ident: Optional[str] = None) -> None: super(Authority, self).__init__(address, or_port, dir_port, fingerprint, nickname, orport_v6)
if v3ident and not tor_tools.is_valid_fingerprint(v3ident): @@ -241,11 +242,11 @@ class Authority(Directory): self.v3ident = v3ident
@staticmethod - def from_cache(): + def from_cache() -> Dict[str, 'stem.directory.Authority']: return dict(DIRECTORY_AUTHORITIES)
@staticmethod - def from_remote(timeout = 60): + def from_remote(timeout: int = 60) -> Dict[str, 'stem.directory.Authority']: try: lines = str_tools._to_unicode(urllib.request.urlopen(GITWEB_AUTHORITY_URL, timeout = timeout).read()).splitlines()
@@ -284,7 +285,7 @@ class Authority(Directory): return results
@staticmethod - def _pop_section(lines): + def _pop_section(lines: Sequence[str]) -> Sequence[str]: """ Provides the next authority entry. """ @@ -299,13 +300,13 @@ class Authority(Directory):
return section_lines
- def __hash__(self): + def __hash__(self) -> int: return stem.util._hash_attr(self, 'v3ident', parent = Directory, cache = True)
- def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return hash(self) == hash(other) if isinstance(other, Authority) else False
- def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self == other
@@ -348,13 +349,13 @@ class Fallback(Directory): :var collections.OrderedDict header: metadata about the fallback directory file this originated from """
- def __init__(self, address = None, or_port = None, dir_port = None, fingerprint = None, nickname = None, has_extrainfo = False, orport_v6 = None, header = None): + def __init__(self, address: Optional[str] = None, or_port: Optional[int] = None, dir_port: Optional[int] = None, fingerprint: Optional[str] = None, nickname: Optional[str] = None, has_extrainfo: bool = False, orport_v6: Optional[Tuple[str, int]] = None, header: Optional[Mapping[str, str]] = None) -> None: super(Fallback, self).__init__(address, or_port, dir_port, fingerprint, nickname, orport_v6) self.has_extrainfo = has_extrainfo self.header = collections.OrderedDict(header) if header else collections.OrderedDict()
@staticmethod - def from_cache(path = FALLBACK_CACHE_PATH): + def from_cache(path: str = FALLBACK_CACHE_PATH) -> Dict[str, 'stem.directory.Fallback']: conf = stem.util.conf.Config() conf.load(path) headers = collections.OrderedDict([(k.split('.', 1)[1], conf.get(k)) for k in conf.keys() if k.startswith('header.')]) @@ -393,7 +394,7 @@ class Fallback(Directory): return results
@staticmethod - def from_remote(timeout = 60): + def from_remote(timeout: int = 60) -> Dict[str, 'stem.directory.Fallback']: try: lines = str_tools._to_unicode(urllib.request.urlopen(GITWEB_FALLBACK_URL, timeout = timeout).read()).splitlines()
@@ -450,7 +451,7 @@ class Fallback(Directory): return results
@staticmethod - def _pop_section(lines): + def _pop_section(lines: Sequence[str]) -> Sequence[str]: """ Provides lines up through the next divider. This excludes lines with just a comma since they're an artifact of these being C strings. @@ -470,7 +471,7 @@ class Fallback(Directory): return section_lines
@staticmethod - def _write(fallbacks, tor_commit, stem_commit, headers, path = FALLBACK_CACHE_PATH): + def _write(fallbacks: Dict[str, 'stem.directory.Fallback'], tor_commit: str, stem_commit: str, headers: Mapping[str, str], path: str = FALLBACK_CACHE_PATH) -> None: """ Persists fallback directories to a location in a way that can be read by from_cache(). @@ -503,17 +504,17 @@ class Fallback(Directory):
conf.save(path)
- def __hash__(self): + def __hash__(self) -> int: return stem.util._hash_attr(self, 'has_extrainfo', 'header', parent = Directory, cache = True)
- def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return hash(self) == hash(other) if isinstance(other, Fallback) else False
- def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self == other
-def _fallback_directory_differences(previous_directories, new_directories): +def _fallback_directory_differences(previous_directories: Sequence['stem.directory.Directory'], new_directories: Sequence['stem.directory.Directory']) -> str: """ Provides a description of how fallback directories differ. """ diff --git a/stem/exit_policy.py b/stem/exit_policy.py index ddcd7dfd..076611d2 100644 --- a/stem/exit_policy.py +++ b/stem/exit_policy.py @@ -71,6 +71,8 @@ import stem.util.connection import stem.util.enum import stem.util.str_tools
+from typing import Any, Iterator, Optional, Sequence, Union + AddressType = stem.util.enum.Enum(('WILDCARD', 'Wildcard'), ('IPv4', 'IPv4'), ('IPv6', 'IPv6'))
# Addresses aliased by the 'private' policy. From the tor man page... @@ -89,7 +91,7 @@ PRIVATE_ADDRESSES = ( )
-def _flag_private_rules(rules): +def _flag_private_rules(rules: Sequence['ExitPolicyRule']) -> None: """ Determine if part of our policy was expanded from the 'private' keyword. This doesn't differentiate if this actually came from the 'private' keyword or a @@ -139,7 +141,7 @@ def _flag_private_rules(rules): last_rule._is_private = True
-def _flag_default_rules(rules): +def _flag_default_rules(rules: Sequence['ExitPolicyRule']) -> None: """ Determine if part of our policy ends with the defaultly appended suffix. """ @@ -162,7 +164,7 @@ class ExitPolicy(object): entries that make up this policy """
- def __init__(self, *rules): + def __init__(self, *rules: Union[str, 'stem.exit_policy.ExitPolicyRule']) -> None: # sanity check the types
for rule in rules: @@ -196,7 +198,7 @@ class ExitPolicy(object): self._is_allowed_default = True
@functools.lru_cache() - def can_exit_to(self, address = None, port = None, strict = False): + def can_exit_to(self, address: Optional[str] = None, port: Optional[int] = None, strict: bool = False) -> bool: """ Checks if this policy allows exiting to a given destination or not. If the address or port is omitted then this will check if we're allowed to exit to @@ -220,7 +222,7 @@ class ExitPolicy(object): return self._is_allowed_default
@functools.lru_cache() - def is_exiting_allowed(self): + def is_exiting_allowed(self) -> bool: """ Provides **True** if the policy allows exiting whatsoever, **False** otherwise. @@ -242,7 +244,7 @@ class ExitPolicy(object): return self._is_allowed_default
@functools.lru_cache() - def summary(self): + def summary(self) -> str: """ Provides a short description of our policy chain, similar to a microdescriptor. This excludes entries that don't cover all IP @@ -320,7 +322,7 @@ class ExitPolicy(object):
return (label_prefix + ', '.join(display_ranges)).strip()
- def has_private(self): + def has_private(self) -> bool: """ Checks if we have any rules expanded from the 'private' keyword. Tor appends these by default to the start of the policy and includes a dynamic @@ -338,7 +340,7 @@ class ExitPolicy(object):
return False
- def strip_private(self): + def strip_private(self) -> 'ExitPolicy': """ Provides a copy of this policy without 'private' policy entries.
@@ -349,7 +351,7 @@ class ExitPolicy(object):
return ExitPolicy(*[rule for rule in self._get_rules() if not rule.is_private()])
- def has_default(self): + def has_default(self) -> bool: """ Checks if we have the default policy suffix.
@@ -364,7 +366,7 @@ class ExitPolicy(object):
return False
- def strip_default(self): + def strip_default(self) -> 'ExitPolicy': """ Provides a copy of this policy without the default policy suffix.
@@ -375,7 +377,7 @@ class ExitPolicy(object):
return ExitPolicy(*[rule for rule in self._get_rules() if not rule.is_default()])
- def _get_rules(self): + def _get_rules(self) -> Sequence['stem.exit_policy.ExitPolicyRule']: # Local reference to our input_rules so this can be lock free. Otherwise # another thread might unset our input_rules while processing them.
@@ -437,18 +439,18 @@ class ExitPolicy(object):
return self._rules
- def __len__(self): + def __len__(self) -> int: return len(self._get_rules())
- def __iter__(self): + def __iter__(self) -> Iterator['stem.exit_policy.ExitPolicyRule']: for rule in self._get_rules(): yield rule
@functools.lru_cache() - def __str__(self): + def __str__(self) -> str: return ', '.join([str(rule) for rule in self._get_rules()])
- def __hash__(self): + def __hash__(self) -> int: if self._hash is None: my_hash = 0
@@ -460,10 +462,10 @@ class ExitPolicy(object):
return self._hash
- def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return hash(self) == hash(other) if isinstance(other, ExitPolicy) else False
- def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self == other
@@ -495,7 +497,7 @@ class MicroExitPolicy(ExitPolicy): :param str policy: policy string that describes this policy """
- def __init__(self, policy): + def __init__(self, policy: str) -> None: # Microdescriptor policies are of the form... # # MicrodescriptrPolicy ::= ("accept" / "reject") SP PortList NL @@ -537,16 +539,16 @@ class MicroExitPolicy(ExitPolicy): super(MicroExitPolicy, self).__init__(*rules) self._is_allowed_default = not self.is_accept
- def __str__(self): + def __str__(self) -> str: return self._policy
- def __hash__(self): + def __hash__(self) -> int: return hash(str(self))
- def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return hash(self) == hash(other) if isinstance(other, MicroExitPolicy) else False
- def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self == other
@@ -580,7 +582,7 @@ class ExitPolicyRule(object): :raises: **ValueError** if input isn't a valid tor exit policy rule """
- def __init__(self, rule): + def __init__(self, rule: str) -> None: # policy ::= "accept[6]" exitpattern | "reject[6]" exitpattern # exitpattern ::= addrspec ":" portspec
@@ -634,7 +636,7 @@ class ExitPolicyRule(object): self._is_private = False self._is_default_suffix = False
- def is_address_wildcard(self): + def is_address_wildcard(self) -> bool: """ **True** if we'll match against **any** address, **False** otherwise.
@@ -646,7 +648,7 @@ class ExitPolicyRule(object):
return self._address_type == _address_type_to_int(AddressType.WILDCARD)
- def is_port_wildcard(self): + def is_port_wildcard(self) -> bool: """ **True** if we'll match against any port, **False** otherwise.
@@ -655,7 +657,7 @@ class ExitPolicyRule(object):
return self.min_port in (0, 1) and self.max_port == 65535
- def is_match(self, address = None, port = None, strict = False): + def is_match(self, address: Optional[str] = None, port: Optional[int] = None, strict: bool = False) -> bool: """ **True** if we match against the given destination, **False** otherwise. If the address or port is omitted then this will check if we're allowed to @@ -726,7 +728,7 @@ class ExitPolicyRule(object): else: return True
- def get_address_type(self): + def get_address_type(self) -> AddressType: """ Provides the :data:`~stem.exit_policy.AddressType` for our policy.
@@ -735,7 +737,7 @@ class ExitPolicyRule(object):
return _int_to_address_type(self._address_type)
- def get_mask(self, cache = True): + def get_mask(self, cache: bool = True) -> Optional[str]: """ Provides the address represented by our mask. This is **None** if our address type is a wildcard. @@ -765,7 +767,7 @@ class ExitPolicyRule(object):
return self._mask
- def get_masked_bits(self): + def get_masked_bits(self) -> Optional[int]: """ Provides the number of bits our subnet mask represents. This is **None** if our mask can't have a bit representation. @@ -775,7 +777,7 @@ class ExitPolicyRule(object):
return self._masked_bits
- def is_private(self): + def is_private(self) -> bool: """ Checks if this rule was expanded from the 'private' policy keyword.
@@ -786,7 +788,7 @@ class ExitPolicyRule(object):
return self._is_private
- def is_default(self): + def is_default(self) -> bool: """ Checks if this rule belongs to the default exit policy suffix.
@@ -798,7 +800,7 @@ class ExitPolicyRule(object): return self._is_default_suffix
@functools.lru_cache() - def __str__(self): + def __str__(self) -> str: """ Provides the string representation of our policy. This does not necessarily match the rule that we were constructed from (due to things @@ -842,18 +844,18 @@ class ExitPolicyRule(object): return label
@functools.lru_cache() - def _get_mask_bin(self): + def _get_mask_bin(self) -> int: # provides an integer representation of our mask
return int(stem.util.connection._address_to_binary(self.get_mask(False)), 2)
@functools.lru_cache() - def _get_address_bin(self): + def _get_address_bin(self) -> int: # provides an integer representation of our address
return stem.util.connection.address_to_int(self.address) & self._get_mask_bin()
- def _apply_addrspec(self, rule, addrspec, is_ipv6_only): + def _apply_addrspec(self, rule: str, addrspec: str, is_ipv6_only: bool) -> None: # Parses the addrspec... # addrspec ::= "*" | ip4spec | ip6spec
@@ -924,7 +926,7 @@ class ExitPolicyRule(object): else: raise ValueError("'%s' isn't a wildcard, IPv4, or IPv6 address: %s" % (addrspec, rule))
- def _apply_portspec(self, rule, portspec): + def _apply_portspec(self, rule: str, portspec: str) -> None: # Parses the portspec... # portspec ::= "*" | port | port "-" port # port ::= an integer between 1 and 65535, inclusive. @@ -955,24 +957,24 @@ class ExitPolicyRule(object): else: raise ValueError("Port value isn't a wildcard, integer, or range: %s" % rule)
- def __hash__(self): + def __hash__(self) -> int: if self._hash is None: self._hash = stem.util._hash_attr(self, 'is_accept', 'address', 'min_port', 'max_port') * 1024 + hash(self.get_mask(False))
return self._hash
- def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return hash(self) == hash(other) if isinstance(other, ExitPolicyRule) else False
- def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self == other
-def _address_type_to_int(address_type): +def _address_type_to_int(address_type: AddressType) -> int: return AddressType.index_of(address_type)
-def _int_to_address_type(address_type_int): +def _int_to_address_type(address_type_int: int) -> AddressType: return list(AddressType)[address_type_int]
@@ -981,32 +983,32 @@ class MicroExitPolicyRule(ExitPolicyRule): Lighter weight ExitPolicyRule derivative for microdescriptors. """
- def __init__(self, is_accept, min_port, max_port): + def __init__(self, is_accept: bool, min_port: int, max_port: int) -> None: self.is_accept = is_accept self.address = None # wildcard address self.min_port = min_port self.max_port = max_port self._skip_rule = False
- def is_address_wildcard(self): + def is_address_wildcard(self) -> bool: return True
- def get_address_type(self): + def get_address_type(self) -> AddressType: return AddressType.WILDCARD
- def get_mask(self, cache = True): + def get_mask(self, cache: bool = True) -> Optional[str]: return None
- def get_masked_bits(self): + def get_masked_bits(self) -> Optional[int]: return None
- def __hash__(self): + def __hash__(self) -> int: return stem.util._hash_attr(self, 'is_accept', 'min_port', 'max_port', cache = True)
- def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return hash(self) == hash(other) if isinstance(other, MicroExitPolicyRule) else False
- def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self == other
diff --git a/stem/interpreter/__init__.py b/stem/interpreter/__init__.py index 07a5f573..2a3cff18 100644 --- a/stem/interpreter/__init__.py +++ b/stem/interpreter/__init__.py @@ -38,11 +38,11 @@ uses_settings = stem.util.conf.uses_settings('stem_interpreter', settings_path)
@uses_settings -def msg(message, config, **attr): +def msg(message: str, config: 'stem.util.conf.Config', **attr: str) -> str: return config.get(message).format(**attr)
-def main(): +def main() -> None: try: import readline except ImportError: @@ -135,7 +135,7 @@ def main(): controller.msg(args.run_cmd)
try: - raw_input() + input() except (KeyboardInterrupt, stem.SocketClosed): pass else: diff --git a/stem/interpreter/arguments.py b/stem/interpreter/arguments.py index 00c8891d..8ac1c2c1 100644 --- a/stem/interpreter/arguments.py +++ b/stem/interpreter/arguments.py @@ -12,6 +12,8 @@ import os import stem.interpreter import stem.util.connection
+from typing import NamedTuple, Sequence + DEFAULT_ARGS = { 'control_address': '127.0.0.1', 'control_port': 'default', @@ -29,7 +31,7 @@ OPT = 'i:s:h' OPT_EXPANDED = ['interface=', 'socket=', 'tor=', 'run=', 'no-color', 'help']
-def parse(argv): +def parse(argv: Sequence[str]) -> NamedTuple: """ Parses our arguments, providing a named tuple with their values.
@@ -90,7 +92,7 @@ def parse(argv): return Args(**args)
-def get_help(): +def get_help() -> str: """ Provides our --help usage information.
diff --git a/stem/interpreter/autocomplete.py b/stem/interpreter/autocomplete.py index 9f5f2659..671085a7 100644 --- a/stem/interpreter/autocomplete.py +++ b/stem/interpreter/autocomplete.py @@ -8,10 +8,11 @@ Tab completion for our interpreter prompt. import functools
from stem.interpreter import uses_settings +from typing import Optional, Sequence
@uses_settings -def _get_commands(controller, config): +def _get_commands(controller: 'stem.control.Controller', config: 'stem.util.conf.Config') -> Sequence[str]: """ Provides commands recognized by tor. """ @@ -76,11 +77,11 @@ def _get_commands(controller, config):
class Autocompleter(object): - def __init__(self, controller): + def __init__(self, controller: 'stem.control.Controller') -> None: self._commands = _get_commands(controller)
@functools.lru_cache() - def matches(self, text): + def matches(self, text: str) -> Sequence[str]: """ Provides autocompletion matches for the given text.
@@ -92,7 +93,7 @@ class Autocompleter(object): lowercase_text = text.lower() return [cmd for cmd in self._commands if cmd.lower().startswith(lowercase_text)]
- def complete(self, text, state): + def complete(self, text: str, state: int) -> Optional[str]: """ Provides case insensetive autocompletion options, acting as a functor for the readlines set_completer function. diff --git a/stem/interpreter/commands.py b/stem/interpreter/commands.py index 6e61fdda..1d610dac 100644 --- a/stem/interpreter/commands.py +++ b/stem/interpreter/commands.py @@ -21,11 +21,12 @@ import stem.util.tor_tools
from stem.interpreter import STANDARD_OUTPUT, BOLD_OUTPUT, ERROR_OUTPUT, uses_settings, msg from stem.util.term import format +from typing import Iterator, Sequence, TextIO, Tuple
MAX_EVENTS = 100
-def _get_fingerprint(arg, controller): +def _get_fingerprint(arg: str, controller: 'stem.control.Controller') -> str: """ Resolves user input into a relay fingerprint. This accepts...
@@ -90,7 +91,7 @@ def _get_fingerprint(arg, controller):
@contextlib.contextmanager -def redirect(stdout, stderr): +def redirect(stdout: TextIO, stderr: TextIO) -> Iterator[None]: original = sys.stdout, sys.stderr sys.stdout, sys.stderr = stdout, stderr
@@ -106,7 +107,7 @@ class ControlInterpreter(code.InteractiveConsole): for special irc style subcommands. """
- def __init__(self, controller): + def __init__(self, controller: 'stem.control.Controller') -> None: self._received_events = []
code.InteractiveConsole.__init__(self, { @@ -129,7 +130,7 @@ class ControlInterpreter(code.InteractiveConsole):
handle_event_real = self._controller._handle_event
- def handle_event_wrapper(event_message): + def handle_event_wrapper(event_message: 'stem.response.events.Event') -> None: handle_event_real(event_message) self._received_events.insert(0, event_message)
@@ -138,7 +139,7 @@ class ControlInterpreter(code.InteractiveConsole):
self._controller._handle_event = handle_event_wrapper
- def get_events(self, *event_types): + def get_events(self, *event_types: 'stem.control.EventType') -> Sequence['stem.response.events.Event']: events = list(self._received_events) event_types = list(map(str.upper, event_types)) # make filtering case insensitive
@@ -147,7 +148,7 @@ class ControlInterpreter(code.InteractiveConsole):
return events
- def do_help(self, arg): + def do_help(self, arg: str) -> str: """ Performs the '/help' operation, giving usage information for the given argument or a general summary if there wasn't one. @@ -155,7 +156,7 @@ class ControlInterpreter(code.InteractiveConsole):
return stem.interpreter.help.response(self._controller, arg)
- def do_events(self, arg): + def do_events(self, arg: str) -> str: """ Performs the '/events' operation, dumping the events that we've received belonging to the given types. If no types are specified then this provides @@ -173,7 +174,7 @@ class ControlInterpreter(code.InteractiveConsole):
return '\n'.join([format(str(e), *STANDARD_OUTPUT) for e in self.get_events(*event_types)])
- def do_info(self, arg): + def do_info(self, arg: str) -> str: """ Performs the '/info' operation, looking up a relay by fingerprint, IP address, or nickname and printing its descriptor and consensus entries in a @@ -271,7 +272,7 @@ class ControlInterpreter(code.InteractiveConsole):
return '\n'.join(lines)
- def do_python(self, arg): + def do_python(self, arg: str) -> str: """ Performs the '/python' operation, toggling if we accept python commands or not. @@ -295,12 +296,11 @@ class ControlInterpreter(code.InteractiveConsole): return format(response, *STANDARD_OUTPUT)
@uses_settings - def run_command(self, command, config, print_response = False): + def run_command(self, command: str, config: 'stem.util.conf.Config', print_response: bool = False) -> Sequence[Tuple[str, int]]: """ Runs the given command. Requests starting with a '/' are special commands to the interpreter, and anything else is sent to the control port.
- :param stem.control.Controller controller: tor control connection :param str command: command to be processed :param bool print_response: prints the response to stdout if true
diff --git a/stem/interpreter/help.py b/stem/interpreter/help.py index 1f242a8e..81c76d34 100644 --- a/stem/interpreter/help.py +++ b/stem/interpreter/help.py @@ -18,7 +18,7 @@ from stem.interpreter import ( from stem.util.term import format
-def response(controller, arg): +def response(controller: 'stem.control.Controller', arg: str) -> str: """ Provides our /help response.
@@ -33,7 +33,7 @@ def response(controller, arg): return _response(controller, _normalize(arg))
-def _normalize(arg): +def _normalize(arg: str) -> str: arg = arg.upper()
# If there's multiple arguments then just take the first. This is @@ -52,7 +52,7 @@ def _normalize(arg):
@functools.lru_cache() @uses_settings -def _response(controller, arg, config): +def _response(controller: 'stem.control.Controller', arg: str, config: 'stem.util.conf.Config') -> str: if not arg: return _general_help()
@@ -126,7 +126,7 @@ def _response(controller, arg, config): return output.rstrip()
-def _general_help(): +def _general_help() -> str: lines = []
for line in msg('help.general').splitlines(): diff --git a/stem/manual.py b/stem/manual.py index 367b6d7e..e28e0e6f 100644 --- a/stem/manual.py +++ b/stem/manual.py @@ -63,6 +63,8 @@ import stem.util.enum import stem.util.log import stem.util.system
+from typing import Any, Dict, Mapping, Optional, Sequence, TextIO, Tuple, Union + Category = stem.util.enum.Enum('GENERAL', 'CLIENT', 'RELAY', 'DIRECTORY', 'AUTHORITY', 'HIDDEN_SERVICE', 'DENIAL_OF_SERVICE', 'TESTING', 'UNKNOWN') GITWEB_MANUAL_URL = 'https://gitweb.torproject.org/tor.git/plain/doc/tor.1.txt' CACHE_PATH = os.path.join(os.path.dirname(__file__), 'cached_manual.sqlite') @@ -103,13 +105,13 @@ class SchemaMismatch(IOError): :var tuple supported_schemas: schemas library supports """
- def __init__(self, message, database_schema, library_schema): + def __init__(self, message: str, database_schema: int, supported_schemas: Tuple[int, ...]) -> None: super(SchemaMismatch, self).__init__(message) self.database_schema = database_schema - self.library_schema = library_schema + self.supported_schemas = supported_schemas
-def query(query, *param): +def query(query: str, *param: str) -> 'sqlite3.Cursor': """ Performs the given query on our sqlite manual cache. This database should be treated as being read-only. File permissions generally enforce this, and @@ -162,25 +164,25 @@ class ConfigOption(object): :var str description: longer manual description with details """
- def __init__(self, name, category = Category.UNKNOWN, usage = '', summary = '', description = ''): + def __init__(self, name: str, category: 'stem.manual.Category' = Category.UNKNOWN, usage: str = '', summary: str = '', description: str = '') -> None: self.name = name self.category = category self.usage = usage self.summary = summary self.description = description
- def __hash__(self): + def __hash__(self) -> int: return stem.util._hash_attr(self, 'name', 'category', 'usage', 'summary', 'description', cache = True)
- def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return hash(self) == hash(other) if isinstance(other, ConfigOption) else False
- def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self == other
@functools.lru_cache() -def _config(lowercase = True): +def _config(lowercase: bool = True) -> Dict[str, Union[Sequence[str], str]]: """ Provides a dictionary for our settings.cfg. This has a couple categories...
@@ -204,7 +206,7 @@ def _config(lowercase = True): return {}
-def _manual_differences(previous_manual, new_manual): +def _manual_differences(previous_manual: 'stem.manual.Manual', new_manual: 'stem.manual.Manual') -> str: """ Provides a description of how two manuals differ. """ @@ -249,7 +251,7 @@ def _manual_differences(previous_manual, new_manual): return '\n'.join(lines)
-def is_important(option): +def is_important(option: str) -> bool: """ Indicates if a configuration option of particularly common importance or not.
@@ -262,7 +264,7 @@ def is_important(option): return option.lower() in _config()['manual.important']
-def download_man_page(path = None, file_handle = None, url = GITWEB_MANUAL_URL, timeout = 20): +def download_man_page(path: Optional[str] = None, file_handle: Optional[TextIO] = None, url: str = GITWEB_MANUAL_URL, timeout: int = 20) -> None: """ Downloads tor's latest man page from `gitweb.torproject.org https://gitweb.torproject.org/tor.git/plain/doc/tor.1.txt`_. This method is @@ -347,7 +349,7 @@ class Manual(object): :var str stem_commit: stem commit to cache this manual information """
- def __init__(self, name, synopsis, description, commandline_options, signals, files, config_options): + def __init__(self, name: str, synopsis: str, description: str, commandline_options: Mapping[str, str], signals: Mapping[str, str], files: Mapping[str, str], config_options: Mapping[str, 'stem.manual.ConfigOption']) -> None: self.name = name self.synopsis = synopsis self.description = description @@ -360,7 +362,7 @@ class Manual(object): self.schema = None
@staticmethod - def from_cache(path = None): + def from_cache(path: Optional[str] = None) -> 'stem.manual.Manual': """ Provides manual information cached with Stem. Unlike :func:`~stem.manual.Manual.from_man` and @@ -424,7 +426,7 @@ class Manual(object): return manual
@staticmethod - def from_man(man_path = 'tor'): + def from_man(man_path: str = 'tor') -> 'stem.manual.Manual': """ Reads and parses a given man page.
@@ -467,7 +469,7 @@ class Manual(object): )
@staticmethod - def from_remote(timeout = 60): + def from_remote(timeout: int = 60) -> 'stem.manual.Manual': """ Reads and parses the latest tor man page `from gitweb.torproject.org https://gitweb.torproject.org/tor.git/plain/doc/tor.1.txt`_. Note that @@ -500,7 +502,7 @@ class Manual(object): download_man_page(file_handle = tmp, timeout = timeout) return Manual.from_man(tmp.name)
- def save(self, path): + def save(self, path: str) -> None: """ Persists the manual content to a given location.
@@ -549,17 +551,17 @@ class Manual(object):
os.rename(tmp_path, path)
- def __hash__(self): + def __hash__(self) -> int: return stem.util._hash_attr(self, 'name', 'synopsis', 'description', 'commandline_options', 'signals', 'files', 'config_options', cache = True)
- def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return hash(self) == hash(other) if isinstance(other, Manual) else False
- def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self == other
-def _get_categories(content): +def _get_categories(content: str) -> Dict[str, Sequence[str]]: """ The man page is headers followed by an indented section. First pass gets the mapping of category titles to their lines. @@ -605,7 +607,7 @@ def _get_categories(content):
-def _get_indented_descriptions(lines): +def _get_indented_descriptions(lines: Sequence[str]) -> Dict[str, str]: """ Parses the commandline argument and signal sections. These are options followed by an indented description. For example... @@ -635,7 +637,7 @@ def _get_indented_descriptions(lines): return dict([(arg, ' '.join(desc_lines)) for arg, desc_lines in options.items() if desc_lines])
-def _add_config_options(config_options, category, lines): +def _add_config_options(config_options: Mapping[str, 'stem.manual.ConfigOption'], category: str, lines: Sequence[str]) -> None: """ Parses a section of tor configuration options. These have usage information, followed by an indented description. For instance... @@ -653,7 +655,7 @@ def _add_config_options(config_options, category, lines): since that platform lacks getrlimit(). (Default: 1000) """
- def add_option(title, description): + def add_option(title: str, description: str) -> None: if 'PER INSTANCE OPTIONS' in title: return # skip, unfortunately amid the options
@@ -697,7 +699,7 @@ def _add_config_options(config_options, category, lines): add_option(last_title, description)
-def _join_lines(lines): +def _join_lines(lines: Sequence[str]) -> str: """ Simple join, except we want empty lines to still provide a newline. """ diff --git a/stem/process.py b/stem/process.py index a1d805ec..bfab4967 100644 --- a/stem/process.py +++ b/stem/process.py @@ -29,11 +29,13 @@ import stem.util.str_tools import stem.util.system import stem.version
+from typing import Any, Callable, Mapping, Optional, Sequence, Union + NO_TORRC = '<no torrc>' DEFAULT_INIT_TIMEOUT = 90
-def launch_tor(tor_cmd = 'tor', args = None, torrc_path = None, completion_percent = 100, init_msg_handler = None, timeout = DEFAULT_INIT_TIMEOUT, take_ownership = False, close_output = True, stdin = None): +def launch_tor(tor_cmd: str = 'tor', args: Optional[Sequence[str]] = None, torrc_path: Optional[str] = None, completion_percent: int = 100, init_msg_handler: Optional[Callable[[str], None]] = None, timeout: int = DEFAULT_INIT_TIMEOUT, take_ownership: bool = False, close_output: bool = True, stdin: Optional[str] = None) -> subprocess.Popen: """ Initializes a tor process. This blocks until initialization completes or we error out. @@ -131,7 +133,7 @@ def launch_tor(tor_cmd = 'tor', args = None, torrc_path = None, completion_perce tor_process.stdin.close()
if timeout: - def timeout_handler(signum, frame): + def timeout_handler(signum: int, frame: Any) -> None: raise OSError('reached a %i second timeout without success' % timeout)
signal.signal(signal.SIGALRM, timeout_handler) @@ -197,7 +199,7 @@ def launch_tor(tor_cmd = 'tor', args = None, torrc_path = None, completion_perce pass
-def launch_tor_with_config(config, tor_cmd = 'tor', completion_percent = 100, init_msg_handler = None, timeout = DEFAULT_INIT_TIMEOUT, take_ownership = False, close_output = True): +def launch_tor_with_config(config: Mapping[str, Union[str, Sequence[str]]], tor_cmd: str = 'tor', completion_percent: int = 100, init_msg_handler: Optional[Callable[[str], None]] = None, timeout: int = DEFAULT_INIT_TIMEOUT, take_ownership: bool = False, close_output: bool = True) -> subprocess.Popen: """ Initializes a tor process, like :func:`~stem.process.launch_tor`, but with a customized configuration. This writes a temporary torrc to disk, launches diff --git a/stem/response/__init__.py b/stem/response/__init__.py index 2fbb9c48..4b1f9533 100644 --- a/stem/response/__init__.py +++ b/stem/response/__init__.py @@ -38,6 +38,8 @@ import stem.socket import stem.util import stem.util.str_tools
+from typing import Any, Iterator, Optional, Sequence, Tuple, Union + __all__ = [ 'add_onion', 'events', @@ -54,7 +56,7 @@ __all__ = [ KEY_ARG = re.compile('^(\S+)=')
-def convert(response_type, message, **kwargs): +def convert(response_type: str, message: 'stem.response.ControlMessage', **kwargs: Any) -> None: """ Converts a :class:`~stem.response.ControlMessage` into a particular kind of tor response. This does an in-place conversion of the message from being a @@ -140,7 +142,7 @@ class ControlMessage(object): """
@staticmethod - def from_str(content, msg_type = None, normalize = False, **kwargs): + def from_str(content: str, msg_type: Optional[str] = None, normalize: bool = False, **kwargs: Any) -> 'stem.response.ControlMessage': """ Provides a ControlMessage for the given content.
@@ -171,7 +173,7 @@ class ControlMessage(object):
return msg
- def __init__(self, parsed_content, raw_content, arrived_at = None): + def __init__(self, parsed_content: Sequence[Tuple[str, str, bytes]], raw_content: bytes, arrived_at: Optional[int] = None) -> None: if not parsed_content: raise ValueError("ControlMessages can't be empty")
@@ -182,7 +184,7 @@ class ControlMessage(object): self._str = None self._hash = stem.util._hash_attr(self, '_raw_content')
- def is_ok(self): + def is_ok(self) -> bool: """ Checks if any of our lines have a 250 response.
@@ -195,7 +197,7 @@ class ControlMessage(object):
return False
- def content(self, get_bytes = False): + def content(self, get_bytes: bool = False) -> Sequence[Tuple[str, str, bytes]]: """ Provides the parsed message content. These are entries of the form...
@@ -234,7 +236,7 @@ class ControlMessage(object): else: return list(self._parsed_content)
- def raw_content(self, get_bytes = False): + def raw_content(self, get_bytes: bool = False) -> Union[str, bytes]: """ Provides the unparsed content read from the control socket.
@@ -251,7 +253,7 @@ class ControlMessage(object): else: return self._raw_content
- def __str__(self): + def __str__(self) -> str: """ Content of the message, stripped of status code and divider protocol formatting. @@ -262,7 +264,7 @@ class ControlMessage(object):
return self._str
- def __iter__(self): + def __iter__(self) -> Iterator['stem.response.ControlLine']: """ Provides :class:`~stem.response.ControlLine` instances for the content of the message. This is stripped of status codes and dividers, for instance... @@ -290,14 +292,14 @@ class ControlMessage(object):
yield ControlLine(content)
- def __len__(self): + def __len__(self) -> int: """ :returns: number of ControlLines """
return len(self._parsed_content)
- def __getitem__(self, index): + def __getitem__(self, index: int) -> 'stem.response.ControlLine': """ :returns: :class:`~stem.response.ControlLine` at the index """ @@ -307,13 +309,13 @@ class ControlMessage(object):
return ControlLine(content)
- def __hash__(self): + def __hash__(self) -> int: return self._hash
- def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return hash(self) == hash(other) if isinstance(other, ControlMessage) else False
- def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self == other
@@ -327,14 +329,14 @@ class ControlLine(str): immutable). All methods are thread safe. """
- def __new__(self, value): + def __new__(self, value: str) -> 'stem.response.ControlLine': return str.__new__(self, value)
- def __init__(self, value): + def __init__(self, value: str) -> None: self._remainder = value self._remainder_lock = threading.RLock()
- def remainder(self): + def remainder(self) -> str: """ Provides our unparsed content. This is an empty string after we've popped all entries. @@ -344,7 +346,7 @@ class ControlLine(str):
return self._remainder
- def is_empty(self): + def is_empty(self) -> bool: """ Checks if we have further content to pop or not.
@@ -353,7 +355,7 @@ class ControlLine(str):
return self._remainder == ''
- def is_next_quoted(self, escaped = False): + def is_next_quoted(self, escaped: bool = False) -> bool: """ Checks if our next entry is a quoted value or not.
@@ -365,7 +367,7 @@ class ControlLine(str): start_quote, end_quote = _get_quote_indices(self._remainder, escaped) return start_quote == 0 and end_quote != -1
- def is_next_mapping(self, key = None, quoted = False, escaped = False): + def is_next_mapping(self, key: Optional[str] = None, quoted: bool = False, escaped: bool = False) -> bool: """ Checks if our next entry is a KEY=VALUE mapping or not.
@@ -393,7 +395,7 @@ class ControlLine(str): else: return False # doesn't start with a key
- def peek_key(self): + def peek_key(self) -> str: """ Provides the key of the next entry, providing **None** if it isn't a key/value mapping. @@ -409,7 +411,7 @@ class ControlLine(str): else: return None
- def pop(self, quoted = False, escaped = False): + def pop(self, quoted: bool = False, escaped: bool = False) -> str: """ Parses the next space separated entry, removing it and the space from our remaining content. Examples... @@ -443,7 +445,7 @@ class ControlLine(str): self._remainder = remainder return next_entry
- def pop_mapping(self, quoted = False, escaped = False, get_bytes = False): + def pop_mapping(self, quoted: bool = False, escaped: bool = False, get_bytes: bool = False) -> Tuple[str, str]: """ Parses the next space separated entry as a KEY=VALUE mapping, removing it and the space from our remaining content. @@ -480,13 +482,14 @@ class ControlLine(str): return (key, next_entry)
-def _parse_entry(line, quoted, escaped, get_bytes): +def _parse_entry(line: str, quoted: bool, escaped: bool, get_bytes: bool) -> Tuple[Union[str, bytes], str]: """ Parses the next entry from the given space separated content.
:param str line: content to be parsed :param bool quoted: parses the next entry as a quoted value, removing the quotes :param bool escaped: unescapes the string + :param bool get_bytes: provides **bytes** for the entry rather than a **str**
:returns: **tuple** of the form (entry, remainder)
@@ -540,7 +543,7 @@ def _parse_entry(line, quoted, escaped, get_bytes): return (next_entry, remainder.lstrip())
-def _get_quote_indices(line, escaped): +def _get_quote_indices(line: str, escaped: bool) -> Tuple[int, int]: """ Provides the indices of the next two quotes in the given content.
@@ -576,7 +579,7 @@ class SingleLineResponse(ControlMessage): :var str message: content of the line """
- def is_ok(self, strict = False): + def is_ok(self, strict: bool = False) -> bool: """ Checks if the response code is "250". If strict is **True** then this checks if the response is "250 OK" @@ -593,7 +596,7 @@ class SingleLineResponse(ControlMessage):
return self.content()[0][0] == '250'
- def _parse_message(self): + def _parse_message(self) -> None: content = self.content()
if len(content) > 1: diff --git a/stem/response/add_onion.py b/stem/response/add_onion.py index 64d58282..3f52f9f2 100644 --- a/stem/response/add_onion.py +++ b/stem/response/add_onion.py @@ -15,7 +15,7 @@ class AddOnionResponse(stem.response.ControlMessage): :var dict client_auth: newly generated client credentials the service accepts """
- def _parse_message(self): + def _parse_message(self) -> None: # Example: # 250-ServiceID=gfzprpioee3hoppz # 250-PrivateKey=RSA1024:MIICXgIBAAKBgQDZvYVxv... diff --git a/stem/response/authchallenge.py b/stem/response/authchallenge.py index d9cc5491..80a1c0f5 100644 --- a/stem/response/authchallenge.py +++ b/stem/response/authchallenge.py @@ -17,7 +17,7 @@ class AuthChallengeResponse(stem.response.ControlMessage): :var str server_nonce: server nonce provided by tor """
- def _parse_message(self): + def _parse_message(self) -> None: # Example: # 250 AUTHCHALLENGE SERVERHASH=680A73C9836C4F557314EA1C4EDE54C285DB9DC89C83627401AEF9D7D27A95D5 SERVERNONCE=F8EA4B1F2C8B40EF1AF68860171605B910E3BBCABADF6FC3DB1FA064F4690E85
diff --git a/stem/response/events.py b/stem/response/events.py index fdd17a25..0e112373 100644 --- a/stem/response/events.py +++ b/stem/response/events.py @@ -12,6 +12,7 @@ import stem.util import stem.version
from stem.util import connection, log, str_tools, tor_tools +from typing import Any, Dict, Sequence
# Matches keyword=value arguments. This can't be a simple "(.*)=(.*)" pattern # because some positional arguments, like circuit paths, can have an equal @@ -40,7 +41,7 @@ class Event(stem.response.ControlMessage): _SKIP_PARSING = False # skip parsing contents into our positional_args and keyword_args _VERSION_ADDED = stem.version.Version('0.1.1.1-alpha') # minimum version with control-spec V1 event support
- def _parse_message(self): + def _parse_message(self) -> None: if not str(self).strip(): raise stem.ProtocolError('Received a blank tor event. Events must at the very least have a type.')
@@ -58,10 +59,10 @@ class Event(stem.response.ControlMessage):
self._parse()
- def __hash__(self): + def __hash__(self) -> int: return stem.util._hash_attr(self, 'arrived_at', parent = stem.response.ControlMessage, cache = True)
- def _parse_standard_attr(self): + def _parse_standard_attr(self) -> None: """ Most events are of the form... 650 *( positional_args ) *( key "=" value ) @@ -122,7 +123,7 @@ class Event(stem.response.ControlMessage): for controller_attr_name, attr_name in self._KEYWORD_ARGS.items(): setattr(self, attr_name, self.keyword_args.get(controller_attr_name))
- def _iso_timestamp(self, timestamp): + def _iso_timestamp(self, timestamp: str) -> 'datetime.datetime': """ Parses an iso timestamp (ISOTime2Frac in the control-spec).
@@ -142,10 +143,10 @@ class Event(stem.response.ControlMessage): raise stem.ProtocolError('Unable to parse timestamp (%s): %s' % (exc, self))
# method overwritten by our subclasses for special handling that they do - def _parse(self): + def _parse(self) -> None: pass
- def _log_if_unrecognized(self, attr, attr_enum): + def _log_if_unrecognized(self, attr: str, attr_enum: 'stem.util.enum.Enum') -> None: """ Checks if an attribute exists in a given enumeration, logging a message if it isn't. Attributes can either be for a string or collection of strings @@ -196,7 +197,7 @@ class AddrMapEvent(Event): } _OPTIONALLY_QUOTED = ('expiry')
- def _parse(self): + def _parse(self) -> None: if self.destination == '<error>': self.destination = None
@@ -234,7 +235,7 @@ class BandwidthEvent(Event):
_POSITIONAL_ARGS = ('read', 'written')
- def _parse(self): + def _parse(self) -> None: if not self.read: raise stem.ProtocolError('BW event is missing its read value') elif not self.written: @@ -277,7 +278,7 @@ class BuildTimeoutSetEvent(Event): } _VERSION_ADDED = stem.version.Version('0.2.2.7-alpha')
- def _parse(self): + def _parse(self) -> None: # convert our integer and float parameters
for param in ('total_times', 'timeout', 'xm', 'close_timeout'): @@ -346,7 +347,7 @@ class CircuitEvent(Event): 'SOCKS_PASSWORD': 'socks_password', }
- def _parse(self): + def _parse(self) -> None: self.path = tuple(stem.control._parse_circ_path(self.path)) self.created = self._iso_timestamp(self.created)
@@ -363,7 +364,7 @@ class CircuitEvent(Event): self._log_if_unrecognized('reason', stem.CircClosureReason) self._log_if_unrecognized('remote_reason', stem.CircClosureReason)
- def _compare(self, other, method): + def _compare(self, other: Any, method: Any) -> bool: # sorting circuit events by their identifier
if not isinstance(other, CircuitEvent): @@ -374,10 +375,10 @@ class CircuitEvent(Event):
return method(my_id, their_id) if my_id != their_id else method(hash(self), hash(other))
- def __gt__(self, other): + def __gt__(self, other: Any) -> bool: return self._compare(other, lambda s, o: s > o)
- def __ge__(self, other): + def __ge__(self, other: Any) -> bool: return self._compare(other, lambda s, o: s >= o)
@@ -414,7 +415,7 @@ class CircMinorEvent(Event): } _VERSION_ADDED = stem.version.Version('0.2.3.11-alpha')
- def _parse(self): + def _parse(self) -> None: self.path = tuple(stem.control._parse_circ_path(self.path)) self.created = self._iso_timestamp(self.created)
@@ -450,7 +451,7 @@ class ClientsSeenEvent(Event): } _VERSION_ADDED = stem.version.Version('0.2.1.10-alpha')
- def _parse(self): + def _parse(self) -> None: if self.start_time is not None: self.start_time = stem.util.str_tools._parse_timestamp(self.start_time)
@@ -509,7 +510,7 @@ class ConfChangedEvent(Event): _SKIP_PARSING = True _VERSION_ADDED = stem.version.Version('0.2.3.3-alpha')
- def _parse(self): + def _parse(self) -> None: self.changed = {} self.unset = []
@@ -563,7 +564,7 @@ class GuardEvent(Event): _VERSION_ADDED = stem.version.Version('0.1.2.5-alpha') _POSITIONAL_ARGS = ('guard_type', 'endpoint', 'status')
- def _parse(self): + def _parse(self) -> None: self.endpoint_fingerprint = None self.endpoint_nickname = None
@@ -610,7 +611,7 @@ class HSDescEvent(Event): _POSITIONAL_ARGS = ('action', 'address', 'authentication', 'directory', 'descriptor_id') _KEYWORD_ARGS = {'REASON': 'reason', 'REPLICA': 'replica', 'HSDIR_INDEX': 'index'}
- def _parse(self): + def _parse(self) -> None: self.directory_fingerprint = None self.directory_nickname = None
@@ -650,7 +651,7 @@ class HSDescContentEvent(Event): _VERSION_ADDED = stem.version.Version('0.2.7.1-alpha') _POSITIONAL_ARGS = ('address', 'descriptor_id', 'directory')
- def _parse(self): + def _parse(self) -> None: if self.address == 'UNKNOWN': self.address = None
@@ -686,7 +687,7 @@ class LogEvent(Event):
_SKIP_PARSING = True
- def _parse(self): + def _parse(self) -> None: self.runlevel = self.type self._log_if_unrecognized('runlevel', stem.Runlevel)
@@ -709,7 +710,7 @@ class NetworkStatusEvent(Event): _SKIP_PARSING = True _VERSION_ADDED = stem.version.Version('0.1.2.3-alpha')
- def _parse(self): + def _parse(self) -> None: content = str(self).lstrip('NS\n').rstrip('\nOK')
self.descriptors = list(stem.descriptor.router_status_entry._parse_file( @@ -753,11 +754,11 @@ class NewConsensusEvent(Event): _SKIP_PARSING = True _VERSION_ADDED = stem.version.Version('0.2.1.13-alpha')
- def _parse(self): + def _parse(self) -> None: self.consensus_content = str(self).lstrip('NEWCONSENSUS\n').rstrip('\nOK') self._parsed = None
- def entries(self): + def entries(self) -> Sequence['stem.descriptor.router_status_entry.RouterStatusEntryV3']: """ Relay router status entries residing within this consensus.
@@ -791,7 +792,7 @@ class NewDescEvent(Event): new descriptors """
- def _parse(self): + def _parse(self) -> None: self.relays = tuple([stem.control._parse_circ_entry(entry) for entry in str(self).split()[1:]])
@@ -832,7 +833,7 @@ class ORConnEvent(Event): 'ID': 'id', }
- def _parse(self): + def _parse(self) -> None: self.endpoint_fingerprint = None self.endpoint_nickname = None self.endpoint_address = None @@ -886,7 +887,7 @@ class SignalEvent(Event): _POSITIONAL_ARGS = ('signal',) _VERSION_ADDED = stem.version.Version('0.2.3.1-alpha')
- def _parse(self): + def _parse(self) -> None: # log if we recieved an unrecognized signal expected_signals = ( stem.Signal.RELOAD, @@ -918,7 +919,7 @@ class StatusEvent(Event): _POSITIONAL_ARGS = ('runlevel', 'action') _VERSION_ADDED = stem.version.Version('0.1.2.3-alpha')
- def _parse(self): + def _parse(self) -> None: if self.type == 'STATUS_GENERAL': self.status_type = stem.StatusType.GENERAL elif self.type == 'STATUS_CLIENT': @@ -970,7 +971,7 @@ class StreamEvent(Event): 'PURPOSE': 'purpose', }
- def _parse(self): + def _parse(self) -> None: if self.target is None: raise stem.ProtocolError("STREAM event didn't have a target: %s" % self) else: @@ -1029,7 +1030,7 @@ class StreamBwEvent(Event): _POSITIONAL_ARGS = ('id', 'written', 'read', 'time') _VERSION_ADDED = stem.version.Version('0.1.2.8-beta')
- def _parse(self): + def _parse(self) -> None: if not tor_tools.is_valid_stream_id(self.id): raise stem.ProtocolError("Stream IDs must be one to sixteen alphanumeric characters, got '%s': %s" % (self.id, self)) elif not self.written: @@ -1062,7 +1063,7 @@ class TransportLaunchedEvent(Event): _POSITIONAL_ARGS = ('type', 'name', 'address', 'port') _VERSION_ADDED = stem.version.Version('0.2.5.0-alpha')
- def _parse(self): + def _parse(self) -> None: if self.type not in ('server', 'client'): raise stem.ProtocolError("Transport type should either be 'server' or 'client': %s" % self)
@@ -1104,7 +1105,7 @@ class ConnectionBandwidthEvent(Event):
_VERSION_ADDED = stem.version.Version('0.2.5.2-alpha')
- def _parse(self): + def _parse(self) -> None: if not self.id: raise stem.ProtocolError('CONN_BW event is missing its id') elif not self.conn_type: @@ -1163,7 +1164,7 @@ class CircuitBandwidthEvent(Event):
_VERSION_ADDED = stem.version.Version('0.2.5.2-alpha')
- def _parse(self): + def _parse(self) -> None: if not self.id: raise stem.ProtocolError('CIRC_BW event is missing its id') elif not self.read: @@ -1233,7 +1234,7 @@ class CellStatsEvent(Event):
_VERSION_ADDED = stem.version.Version('0.2.5.2-alpha')
- def _parse(self): + def _parse(self) -> None: if self.id and not tor_tools.is_valid_circuit_id(self.id): raise stem.ProtocolError("Circuit IDs must be one to sixteen alphanumeric characters, got '%s': %s" % (self.id, self)) elif self.inbound_queue and not tor_tools.is_valid_circuit_id(self.inbound_queue): @@ -1279,7 +1280,7 @@ class TokenBucketEmptyEvent(Event):
_VERSION_ADDED = stem.version.Version('0.2.5.2-alpha')
- def _parse(self): + def _parse(self) -> None: if self.id and not tor_tools.is_valid_connection_id(self.id): raise stem.ProtocolError("Connection IDs must be one to sixteen alphanumeric characters, got '%s': %s" % (self.id, self)) elif not self.read.isdigit(): @@ -1296,7 +1297,7 @@ class TokenBucketEmptyEvent(Event): self._log_if_unrecognized('bucket', stem.TokenBucket)
-def _parse_cell_type_mapping(mapping): +def _parse_cell_type_mapping(mapping: str) -> Dict[str, int]: """ Parses a mapping of the form...
diff --git a/stem/response/getconf.py b/stem/response/getconf.py index 6de49b1f..7ba972ae 100644 --- a/stem/response/getconf.py +++ b/stem/response/getconf.py @@ -16,7 +16,7 @@ class GetConfResponse(stem.response.ControlMessage): values (**list** of **str**) """
- def _parse_message(self): + def _parse_message(self) -> None: # Example: # 250-CookieAuthentication=0 # 250-ControlPort=9100 diff --git a/stem/response/getinfo.py b/stem/response/getinfo.py index 27442ffd..7aebd70a 100644 --- a/stem/response/getinfo.py +++ b/stem/response/getinfo.py @@ -4,6 +4,8 @@ import stem.response import stem.socket
+from typing import Sequence +
class GetInfoResponse(stem.response.ControlMessage): """ @@ -12,7 +14,7 @@ class GetInfoResponse(stem.response.ControlMessage): :var dict entries: mapping between the queried options and their bytes values """
- def _parse_message(self): + def _parse_message(self) -> None: # Example: # 250-version=0.2.3.11-alpha-dev (git-ef0bc7f8f26a917c) # 250+config-text= @@ -66,7 +68,7 @@ class GetInfoResponse(stem.response.ControlMessage):
self.entries[key] = value
- def _assert_matches(self, params): + def _assert_matches(self, params: Sequence[str]) -> None: """ Checks if we match a given set of parameters, and raise a ProtocolError if not.
diff --git a/stem/response/mapaddress.py b/stem/response/mapaddress.py index 73ed84f1..92ce16d2 100644 --- a/stem/response/mapaddress.py +++ b/stem/response/mapaddress.py @@ -17,7 +17,7 @@ class MapAddressResponse(stem.response.ControlMessage): * :class:`stem.InvalidRequest` if the addresses provided were invalid """
- def _parse_message(self): + def _parse_message(self) -> None: # Example: # 250-127.192.10.10=torproject.org # 250 1.2.3.4=tor.freehaven.net diff --git a/stem/response/protocolinfo.py b/stem/response/protocolinfo.py index 459fef5b..330b165e 100644 --- a/stem/response/protocolinfo.py +++ b/stem/response/protocolinfo.py @@ -8,7 +8,6 @@ import stem.socket import stem.version import stem.util.str_tools
-from stem.connection import AuthMethod from stem.util import log
@@ -26,13 +25,15 @@ class ProtocolInfoResponse(stem.response.ControlMessage): :var str cookie_path: path of tor's authentication cookie """
- def _parse_message(self): + def _parse_message(self) -> None: # Example: # 250-PROTOCOLINFO 1 # 250-AUTH METHODS=COOKIE COOKIEFILE="/home/atagar/.tor/control_auth_cookie" # 250-VERSION Tor="0.2.1.30" # 250 OK
+ from stem.connection import AuthMethod + self.protocol_version = None self.tor_version = None self.auth_methods = () diff --git a/stem/socket.py b/stem/socket.py index db110973..179ae16e 100644 --- a/stem/socket.py +++ b/stem/socket.py @@ -62,8 +62,7 @@ Tor... |- is_localhost - returns if the socket is for the local system or not |- connection_time - timestamp when socket last connected or disconnected |- connect - connects a new socket - |- close - shuts down the socket - +- __enter__ / __exit__ - manages socket connection + +- close - shuts down the socket
send_message - Writes a message to a control socket. recv_message - Reads a ControlMessage from a control socket. @@ -80,6 +79,8 @@ import stem.response import stem.util.str_tools
from stem.util import log +from types import TracebackType +from typing import BinaryIO, Callable, Optional, Type
MESSAGE_PREFIX = re.compile(b'^[a-zA-Z0-9]{3}[-+ ]') ERROR_MSG = 'Error while receiving a control message (%s): %s' @@ -94,7 +95,7 @@ class BaseSocket(object): Thread safe socket, providing common socket functionality. """
- def __init__(self): + def __init__(self) -> None: self._socket, self._socket_file = None, None self._is_alive = False self._connection_time = 0.0 # time when we last connected or disconnected @@ -106,7 +107,7 @@ class BaseSocket(object): self._send_lock = threading.RLock() self._recv_lock = threading.RLock()
- def is_alive(self): + def is_alive(self) -> bool: """ Checks if the socket is known to be closed. We won't be aware if it is until we either use it or have explicitily shut it down. @@ -125,7 +126,7 @@ class BaseSocket(object):
return self._is_alive
- def is_localhost(self): + def is_localhost(self) -> bool: """ Returns if the connection is for the local system or not.
@@ -135,7 +136,7 @@ class BaseSocket(object):
return False
- def connection_time(self): + def connection_time(self) -> float: """ Provides the unix timestamp for when our socket was either connected or disconnected. That is to say, the time we connected if we're currently @@ -149,7 +150,7 @@ class BaseSocket(object):
return self._connection_time
- def connect(self): + def connect(self) -> None: """ Connects to a new socket, closing our previous one if we're already attached. @@ -181,7 +182,7 @@ class BaseSocket(object): except stem.SocketError: self._connect() # single retry
- def close(self): + def close(self) -> None: """ Shuts down the socket. If it's already closed then this is a no-op. """ @@ -217,7 +218,7 @@ class BaseSocket(object): if is_change: self._close()
- def _send(self, message, handler): + def _send(self, message: str, handler: Callable[[socket.socket, BinaryIO, str], None]) -> None: """ Send message in a thread safe manner. Handler is expected to be of the form...
@@ -241,7 +242,7 @@ class BaseSocket(object):
raise
- def _recv(self, handler): + def _recv(self, handler: Callable[[socket.socket, BinaryIO], None]) -> bytes: """ Receives a message in a thread safe manner. Handler is expected to be of the form...
@@ -283,7 +284,7 @@ class BaseSocket(object):
raise
- def _get_send_lock(self): + def _get_send_lock(self) -> threading.RLock: """ The send lock is useful to classes that interact with us at a deep level because it's used to lock :func:`stem.socket.ControlSocket.connect` / @@ -296,27 +297,27 @@ class BaseSocket(object):
return self._send_lock
- def __enter__(self): + def __enter__(self) -> 'stem.socket.BaseSocket': return self
- def __exit__(self, exit_type, value, traceback): + def __exit__(self, exit_type: Optional[Type[BaseException]], value: Optional[BaseException], traceback: Optional[TracebackType]): self.close()
- def _connect(self): + def _connect(self) -> None: """ Connection callback that can be overwritten by subclasses and wrappers. """
pass
- def _close(self): + def _close(self) -> None: """ Disconnection callback that can be overwritten by subclasses and wrappers. """
pass
- def _make_socket(self): + def _make_socket(self) -> socket.socket: """ Constructs and connects new socket. This is implemented by subclasses.
@@ -342,7 +343,7 @@ class RelaySocket(BaseSocket): :var int port: ORPort our socket connects to """
- def __init__(self, address = '127.0.0.1', port = 9050, connect = True): + def __init__(self, address: str = '127.0.0.1', port: int = 9050, connect: bool = True) -> None: """ RelaySocket constructor.
@@ -361,7 +362,7 @@ class RelaySocket(BaseSocket): if connect: self.connect()
- def send(self, message): + def send(self, message: str) -> None: """ Sends a message to the relay's ORPort.
@@ -374,7 +375,7 @@ class RelaySocket(BaseSocket):
self._send(message, lambda s, sf, msg: _write_to_socket(sf, msg))
- def recv(self, timeout = None): + def recv(self, timeout: Optional[float] = None) -> bytes: """ Receives a message from the relay.
@@ -388,7 +389,7 @@ class RelaySocket(BaseSocket): * :class:`stem.SocketClosed` if the socket closes before we receive a complete message """
- def wrapped_recv(s, sf): + def wrapped_recv(s: socket.socket, sf: BinaryIO) -> bytes: if timeout is None: return s.recv() else: @@ -404,10 +405,10 @@ class RelaySocket(BaseSocket):
return self._recv(wrapped_recv)
- def is_localhost(self): + def is_localhost(self) -> bool: return self.address == '127.0.0.1'
- def _make_socket(self): + def _make_socket(self) -> socket.socket: try: relay_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) relay_socket.connect((self.address, self.port)) @@ -426,10 +427,10 @@ class ControlSocket(BaseSocket): which are expected to implement the **_make_socket()** method. """
- def __init__(self): + def __init__(self) -> None: super(ControlSocket, self).__init__()
- def send(self, message): + def send(self, message: str) -> None: """ Formats and sends a message to the control socket. For more information see the :func:`~stem.socket.send_message` function. @@ -443,7 +444,7 @@ class ControlSocket(BaseSocket):
self._send(message, lambda s, sf, msg: send_message(sf, msg))
- def recv(self): + def recv(self) -> stem.response.ControlMessage: """ Receives a message from the control socket, blocking until we've received one. For more information see the :func:`~stem.socket.recv_message` function. @@ -467,7 +468,7 @@ class ControlPort(ControlSocket): :var int port: ControlPort our socket connects to """
- def __init__(self, address = '127.0.0.1', port = 9051, connect = True): + def __init__(self, address: str = '127.0.0.1', port: int = 9051, connect: bool = True) -> None: """ ControlPort constructor.
@@ -486,10 +487,10 @@ class ControlPort(ControlSocket): if connect: self.connect()
- def is_localhost(self): + def is_localhost(self) -> bool: return self.address == '127.0.0.1'
- def _make_socket(self): + def _make_socket(self) -> socket.socket: try: control_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) control_socket.connect((self.address, self.port)) @@ -506,7 +507,7 @@ class ControlSocketFile(ControlSocket): :var str path: filesystem path of the socket we connect to """
- def __init__(self, path = '/var/run/tor/control', connect = True): + def __init__(self, path: str = '/var/run/tor/control', connect: bool = True) -> None: """ ControlSocketFile constructor.
@@ -523,10 +524,10 @@ class ControlSocketFile(ControlSocket): if connect: self.connect()
- def is_localhost(self): + def is_localhost(self) -> bool: return True
- def _make_socket(self): + def _make_socket(self) -> socket.socket: try: control_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) control_socket.connect(self.path) @@ -535,7 +536,7 @@ class ControlSocketFile(ControlSocket): raise stem.SocketError(exc)
-def send_message(control_file, message, raw = False): +def send_message(control_file: BinaryIO, message: str, raw: bool = False) -> None: """ Sends a message to the control socket, adding the expected formatting for single verses multi-line messages. Neither message type should contain an @@ -578,7 +579,7 @@ def send_message(control_file, message, raw = False): log.trace('Sent to tor:%s%s' % (msg_div, log_message))
-def _write_to_socket(socket_file, message): +def _write_to_socket(socket_file: BinaryIO, message: str) -> None: try: socket_file.write(stem.util.str_tools._to_bytes(message)) socket_file.flush() @@ -601,7 +602,7 @@ def _write_to_socket(socket_file, message): raise stem.SocketClosed('file has been closed')
-def recv_message(control_file, arrived_at = None): +def recv_message(control_file: BinaryIO, arrived_at: Optional[float] = None) -> stem.response.ControlMessage: """ Pulls from a control socket until we either have a complete message or encounter a problem. @@ -721,7 +722,7 @@ def recv_message(control_file, arrived_at = None): raise stem.ProtocolError("Unrecognized divider type '%s': %s" % (divider, stem.util.str_tools._to_unicode(line)))
-def send_formatting(message): +def send_formatting(message: str) -> None: """ Performs the formatting expected from sent control messages. For more information see the :func:`~stem.socket.send_message` function. @@ -750,7 +751,7 @@ def send_formatting(message): return message + '\r\n'
-def _log_trace(response): +def _log_trace(response: bytes) -> None: if not log.is_tracing(): return
diff --git a/stem/util/__init__.py b/stem/util/__init__.py index e4e08174..050f6c91 100644 --- a/stem/util/__init__.py +++ b/stem/util/__init__.py @@ -7,6 +7,8 @@ Utility functions used by the stem library.
import datetime
+from typing import Any, Union + __all__ = [ 'conf', 'connection', @@ -43,7 +45,7 @@ __all__ = [ HASH_TYPES = True
-def _hash_value(val): +def _hash_value(val: Any) -> int: if not HASH_TYPES: my_hash = 0 else: @@ -64,7 +66,7 @@ def _hash_value(val): return my_hash
-def datetime_to_unix(timestamp): +def datetime_to_unix(timestamp: 'datetime.datetime') -> float: """ Converts a utc datetime object to a unix timestamp.
@@ -78,7 +80,7 @@ def datetime_to_unix(timestamp): return (timestamp - datetime.datetime(1970, 1, 1)).total_seconds()
-def _pubkey_bytes(key): +def _pubkey_bytes(key: Union['cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey', 'cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PublicKey', 'cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey', 'cryptography.hazmat.primitives.asymmetric.x25519.X25519PublicKey']) -> bytes: """ Normalizes X25509 and ED25519 keys into their public key bytes. """ @@ -107,7 +109,7 @@ def _pubkey_bytes(key): raise ValueError('Key must be a string or cryptographic public/private key (was %s)' % type(key).__name__)
-def _hash_attr(obj, *attributes, **kwargs): +def _hash_attr(obj: Any, *attributes: str, **kwargs: Any): """ Provide a hash value for the given set of attributes.
diff --git a/stem/util/conf.py b/stem/util/conf.py index a06f1fd7..37d1c5f4 100644 --- a/stem/util/conf.py +++ b/stem/util/conf.py @@ -163,16 +163,17 @@ import os import threading
from stem.util import log +from typing import Any, Callable, Mapping, Optional, Sequence, Union
CONFS = {} # mapping of identifier to singleton instances of configs
class _SyncListener(object): - def __init__(self, config_dict, interceptor): + def __init__(self, config_dict: Mapping[str, Any], interceptor: Callable[[str, Any], Any]) -> None: self.config_dict = config_dict self.interceptor = interceptor
- def update(self, config, key): + def update(self, config: 'stem.util.conf.Config', key: str) -> None: if key in self.config_dict: new_value = config.get(key, self.config_dict[key])
@@ -188,7 +189,7 @@ class _SyncListener(object): self.config_dict[key] = new_value
-def config_dict(handle, conf_mappings, handler = None): +def config_dict(handle: str, conf_mappings: Mapping[str, Any], handler: Optional[Callable[[str, Any], Any]] = None) -> Mapping[str, Any]: """ Makes a dictionary that stays synchronized with a configuration.
@@ -214,6 +215,8 @@ def config_dict(handle, conf_mappings, handler = None): :param str handle: unique identifier for a config instance :param dict conf_mappings: config key/value mappings used as our defaults :param functor handler: function referred to prior to assigning values + + :returns: mapping of attributes to their current configuration value """
selected_config = get_config(handle) @@ -221,7 +224,7 @@ def config_dict(handle, conf_mappings, handler = None): return conf_mappings
-def get_config(handle): +def get_config(handle: str) -> 'stem.util.conf.Config': """ Singleton constructor for configuration file instances. If a configuration already exists for the handle then it's returned. Otherwise a fresh instance @@ -236,7 +239,7 @@ def get_config(handle): return CONFS[handle]
-def uses_settings(handle, path, lazy_load = True): +def uses_settings(handle: str, path: str, lazy_load: bool = True) -> Callable: """ Provides a function that can be used as a decorator for other functions that require settings to be loaded. Functions with this decorator will be provided @@ -272,13 +275,13 @@ def uses_settings(handle, path, lazy_load = True): config.load(path) config._settings_loaded = True
- def decorator(func): - def wrapped(*args, **kwargs): + def decorator(func: Callable) -> Callable: + def wrapped(*args: Any, **kwargs: Any) -> Any: if lazy_load and not config._settings_loaded: config.load(path) config._settings_loaded = True
- if 'config' in inspect.getargspec(func).args: + if 'config' in inspect.getfullargspec(func).args: return func(*args, config = config, **kwargs) else: return func(*args, **kwargs) @@ -288,7 +291,7 @@ def uses_settings(handle, path, lazy_load = True): return decorator
-def parse_enum(key, value, enumeration): +def parse_enum(key: str, value: str, enumeration: 'stem.util.enum.Enum') -> Any: """ Provides the enumeration value for a given key. This is a case insensitive lookup and raises an exception if the enum key doesn't exist. @@ -305,7 +308,7 @@ def parse_enum(key, value, enumeration): return parse_enum_csv(key, value, enumeration, 1)[0]
-def parse_enum_csv(key, value, enumeration, count = None): +def parse_enum_csv(key: str, value: str, enumeration: 'stem.util.enum.Enum', count: Optional[Union[int, Sequence[int]]] = None) -> Sequence[Any]: """ Parses a given value as being a comma separated listing of enumeration keys, returning the corresponding enumeration values. This is intended to be a @@ -445,7 +448,7 @@ class Config(object): Class can now be used as a dictionary. """
- def __init__(self): + def __init__(self) -> None: self._path = None # location we last loaded from or saved to self._contents = collections.OrderedDict() # configuration key/value pairs self._listeners = [] # functors to be notified of config changes @@ -459,7 +462,7 @@ class Config(object): # flag to support lazy loading in uses_settings() self._settings_loaded = False
- def load(self, path = None, commenting = True): + def load(self, path: Optional[str] = None, commenting: bool = True) -> None: """ Reads in the contents of the given path, adding its configuration values to our current contents. If the path is a directory then this loads each @@ -534,7 +537,7 @@ class Config(object): else: self.set(line, '', False) # default to a key => '' mapping
- def save(self, path = None): + def save(self, path: Optional[str] = None) -> None: """ Saves configuration contents to disk. If a path is provided then it replaces the configuration location that we track. @@ -564,7 +567,7 @@ class Config(object):
output_file.write('%s %s\n' % (entry_key, entry_value))
- def clear(self): + def clear(self) -> None: """ Drops the configuration contents and reverts back to a blank, unloaded state. @@ -574,7 +577,7 @@ class Config(object): self._contents.clear() self._requested_keys = set()
- def add_listener(self, listener, backfill = True): + def add_listener(self, listener: Callable[[str, Any], Any], backfill: bool = True) -> None: """ Registers the function to be notified of configuration updates. Listeners are expected to be functors which accept (config, key). @@ -590,14 +593,14 @@ class Config(object): for key in self.keys(): listener(self, key)
- def clear_listeners(self): + def clear_listeners(self) -> None: """ Removes all attached listeners. """
self._listeners = []
- def keys(self): + def keys(self) -> Sequence[str]: """ Provides all keys in the currently loaded configuration.
@@ -606,7 +609,7 @@ class Config(object):
return list(self._contents.keys())
- def unused_keys(self): + def unused_keys(self) -> Sequence[str]: """ Provides the configuration keys that have never been provided to a caller via :func:`~stem.util.conf.config_dict` or the @@ -618,7 +621,7 @@ class Config(object):
return set(self.keys()).difference(self._requested_keys)
- def set(self, key, value, overwrite = True): + def set(self, key: str, value: Union[str, Sequence[str]], overwrite: bool = True) -> None: """ Appends the given key/value configuration mapping, behaving the same as if we'd loaded this from a configuration file. @@ -657,7 +660,7 @@ class Config(object): else: raise ValueError("Config.set() only accepts str (bytes or unicode), list, or tuple. Provided value was a '%s'" % type(value))
- def get(self, key, default = None): + def get(self, key: str, default: Optional[Any] = None) -> Any: """ Fetches the given configuration, using the key and default value to determine the type it should be. Recognized inferences are: @@ -737,7 +740,7 @@ class Config(object):
return val
- def get_value(self, key, default = None, multiple = False): + def get_value(self, key: str, default: Optional[Any] = None, multiple: bool = False) -> Union[str, Sequence[str]]: """ This provides the current value associated with a given key.
@@ -763,6 +766,6 @@ class Config(object): log.log_once(message_id, log.TRACE, "config entry '%s' not found, defaulting to '%s'" % (key, default)) return default
- def __getitem__(self, key): + def __getitem__(self, key: str) -> Any: with self._contents_lock: return self._contents[key] diff --git a/stem/util/connection.py b/stem/util/connection.py index eaeafec4..2f815a46 100644 --- a/stem/util/connection.py +++ b/stem/util/connection.py @@ -65,6 +65,7 @@ import stem.util.proc import stem.util.system
from stem.util import conf, enum, log, str_tools +from typing import Optional, Sequence, Union
# Connection resolution is risky to log about since it's highly likely to # contain sensitive information. That said, it's also difficult to get right in @@ -157,7 +158,7 @@ class Connection(collections.namedtuple('Connection', ['local_address', 'local_p """
-def download(url, timeout = None, retries = None): +def download(url: str, timeout: Optional[int] = None, retries: Optional[int] = None) -> bytes: """ Download from the given url.
@@ -198,7 +199,7 @@ def download(url, timeout = None, retries = None): raise stem.DownloadFailed(url, exc, stacktrace)
-def get_connections(resolver = None, process_pid = None, process_name = None): +def get_connections(resolver: Optional['stem.util.connection.Resolver'] = None, process_pid: Optional[int] = None, process_name: Optional[str] = None) -> Sequence['stem.util.connection.Connection']: """ Retrieves a list of the current connections for a given process. This provides a list of :class:`~stem.util.connection.Connection`. Note that @@ -239,7 +240,7 @@ def get_connections(resolver = None, process_pid = None, process_name = None): if not process_pid and not process_name: raise ValueError('You must provide a pid or process name to provide connections for')
- def _log(msg): + def _log(msg: str) -> None: if LOG_CONNECTION_RESOLUTION: log.debug(msg)
@@ -288,7 +289,7 @@ def get_connections(resolver = None, process_pid = None, process_name = None): connections = [] resolver_regex = re.compile(resolver_regex_str)
- def _parse_address_str(addr_type, addr_str, line): + def _parse_address_str(addr_type: str, addr_str: str, line: str) -> str: addr, port = addr_str.rsplit(':', 1)
if not is_valid_ipv4_address(addr) and not is_valid_ipv6_address(addr, allow_brackets = True): @@ -334,7 +335,7 @@ def get_connections(resolver = None, process_pid = None, process_name = None): return connections
-def system_resolvers(system = None): +def system_resolvers(system: Optional[str] = None) -> Sequence['stem.util.connection.Resolver']: """ Provides the types of connection resolvers likely to be available on this platform.
@@ -383,7 +384,7 @@ def system_resolvers(system = None): return resolvers
-def port_usage(port): +def port_usage(port: int) -> Optional[str]: """ Provides the common use of a given port. For example, 'HTTP' for port 80 or 'SSH' for 22. @@ -429,7 +430,7 @@ def port_usage(port): return PORT_USES.get(port)
-def is_valid_ipv4_address(address): +def is_valid_ipv4_address(address: str) -> bool: """ Checks if a string is a valid IPv4 address.
@@ -458,7 +459,7 @@ def is_valid_ipv4_address(address): return True
-def is_valid_ipv6_address(address, allow_brackets = False): +def is_valid_ipv6_address(address: str, allow_brackets: bool = False) -> bool: """ Checks if a string is a valid IPv6 address.
@@ -513,7 +514,7 @@ def is_valid_ipv6_address(address, allow_brackets = False): return True
-def is_valid_port(entry, allow_zero = False): +def is_valid_port(entry: Union[str, int, Sequence[str], Sequence[int]], allow_zero: bool = False) -> bool: """ Checks if a string or int is a valid port number.
@@ -545,7 +546,7 @@ def is_valid_port(entry, allow_zero = False): return False
-def is_private_address(address): +def is_private_address(address: str) -> bool: """ Checks if the IPv4 address is in a range belonging to the local network or loopback. These include: @@ -581,7 +582,7 @@ def is_private_address(address): return False
-def address_to_int(address): +def address_to_int(address: str) -> int: """ Provides an integer representation of a IPv4 or IPv6 address that can be used for sorting. @@ -599,7 +600,7 @@ def address_to_int(address): return int(_address_to_binary(address), 2)
-def expand_ipv6_address(address): +def expand_ipv6_address(address: str) -> str: """ Expands abbreviated IPv6 addresses to their full colon separated hex format. For instance... @@ -660,7 +661,7 @@ def expand_ipv6_address(address): return address
-def get_mask_ipv4(bits): +def get_mask_ipv4(bits: int) -> str: """ Provides the IPv4 mask for a given number of bits, in the dotted-quad format.
@@ -686,7 +687,7 @@ def get_mask_ipv4(bits): return '.'.join([str(int(octet, 2)) for octet in octets])
-def get_mask_ipv6(bits): +def get_mask_ipv6(bits: int) -> str: """ Provides the IPv6 mask for a given number of bits, in the hex colon-delimited format. @@ -713,7 +714,7 @@ def get_mask_ipv6(bits): return ':'.join(['%04x' % int(group, 2) for group in groupings]).upper()
-def _get_masked_bits(mask): +def _get_masked_bits(mask: str) -> int: """ Provides the number of bits that an IPv4 subnet mask represents. Note that not all masks can be represented by a bit count. @@ -738,13 +739,15 @@ def _get_masked_bits(mask): raise ValueError('Unable to convert mask to a bit count: %s' % mask)
-def _get_binary(value, bits): +def _get_binary(value: int, bits: int) -> str: """ Provides the given value as a binary string, padded with zeros to the given number of bits.
:param int value: value to be converted :param int bits: number of bits to pad to + + :returns: **str** of this binary value """
# http://www.daniweb.com/code/snippet216539.html @@ -754,10 +757,12 @@ def _get_binary(value, bits): # TODO: In stem 2.x we should consider unifying this with # stem.client.datatype's _unpack_ipv4_address() and _unpack_ipv6_address().
-def _address_to_binary(address): +def _address_to_binary(address: str) -> str: """ Provides the binary value for an IPv4 or IPv6 address.
+ :param str address: address to convert + :returns: **str** with the binary representation of this address
:raises: **ValueError** if address is neither an IPv4 nor IPv6 address diff --git a/stem/util/enum.py b/stem/util/enum.py index 56bf119d..b70d29f4 100644 --- a/stem/util/enum.py +++ b/stem/util/enum.py @@ -40,8 +40,10 @@ constructed as simple type listings... +- __iter__ - iterator over our enum keys """
+from typing import Iterator, Sequence
-def UppercaseEnum(*args): + +def UppercaseEnum(*args: str) -> 'stem.util.enum.Enum': """ Provides an :class:`~stem.util.enum.Enum` instance where the values are identical to the keys. Since the keys are uppercase by convention this means @@ -67,7 +69,7 @@ class Enum(object): Basic enumeration. """
- def __init__(self, *args): + def __init__(self, *args: str) -> None: from stem.util.str_tools import _to_camel_case
# ordered listings of our keys and values @@ -88,7 +90,7 @@ class Enum(object): self._keys = tuple(keys) self._values = tuple(values)
- def keys(self): + def keys(self) -> Sequence[str]: """ Provides an ordered listing of the enumeration keys in this set.
@@ -97,7 +99,7 @@ class Enum(object):
return list(self._keys)
- def index_of(self, value): + def index_of(self, value: str) -> int: """ Provides the index of the given value in the collection.
@@ -110,7 +112,7 @@ class Enum(object):
return self._values.index(value)
- def next(self, value): + def next(self, value: str) -> str: """ Provides the next enumeration after the given value.
@@ -127,7 +129,7 @@ class Enum(object): next_index = (self._values.index(value) + 1) % len(self._values) return self._values[next_index]
- def previous(self, value): + def previous(self, value: str) -> str: """ Provides the previous enumeration before the given value.
@@ -144,7 +146,7 @@ class Enum(object): prev_index = (self._values.index(value) - 1) % len(self._values) return self._values[prev_index]
- def __getitem__(self, item): + def __getitem__(self, item: str) -> str: """ Provides the values for the given key.
@@ -161,7 +163,7 @@ class Enum(object): keys = ', '.join(self.keys()) raise ValueError("'%s' isn't among our enumeration keys, which includes: %s" % (item, keys))
- def __iter__(self): + def __iter__(self) -> Iterator[str]: """ Provides an ordered listing of the enums in this set. """ diff --git a/stem/util/log.py b/stem/util/log.py index 94d055ff..940469a3 100644 --- a/stem/util/log.py +++ b/stem/util/log.py @@ -92,10 +92,10 @@ DEDUPLICATION_MESSAGE_IDS = set()
class _NullHandler(logging.Handler): - def __init__(self): + def __init__(self) -> None: logging.Handler.__init__(self, level = logging.FATAL + 5) # disable logging
- def emit(self, record): + def emit(self, record: logging.LogRecord) -> None: pass
@@ -103,7 +103,7 @@ if not LOGGER.handlers: LOGGER.addHandler(_NullHandler())
-def get_logger(): +def get_logger() -> logging.Logger: """ Provides the stem logger.
@@ -113,7 +113,7 @@ def get_logger(): return LOGGER
-def logging_level(runlevel): +def logging_level(runlevel: 'stem.util.log.Runlevel') -> int: """ Translates a runlevel into the value expected by the logging module.
@@ -126,7 +126,7 @@ def logging_level(runlevel): return logging.FATAL + 5
-def is_tracing(): +def is_tracing() -> bool: """ Checks if we're logging at the trace runlevel.
@@ -142,7 +142,7 @@ def is_tracing(): return False
-def escape(message): +def escape(message: str) -> str: """ Escapes specific sequences for logging (newlines, tabs, carriage returns). If the input is **bytes** then this converts it to **unicode** under python 3.x. @@ -160,7 +160,7 @@ def escape(message): return message
-def log(runlevel, message): +def log(runlevel: 'stem.util.log.Runlevel', message: str) -> None: """ Logs a message at the given runlevel.
@@ -172,7 +172,7 @@ def log(runlevel, message): LOGGER.log(LOG_VALUES[runlevel], message)
-def log_once(message_id, runlevel, message): +def log_once(message_id: str, runlevel: 'stem.util.log.Runlevel', message: str) -> None: """ Logs a message at the given runlevel. If a message with this ID has already been logged then this is a no-op. @@ -193,43 +193,43 @@ def log_once(message_id, runlevel, message): # shorter aliases for logging at a runlevel
-def trace(message): +def trace(message: str) -> None: log(Runlevel.TRACE, message)
-def debug(message): +def debug(message: str) -> None: log(Runlevel.DEBUG, message)
-def info(message): +def info(message: str) -> None: log(Runlevel.INFO, message)
-def notice(message): +def notice(message: str) -> None: log(Runlevel.NOTICE, message)
-def warn(message): +def warn(message: str) -> None: log(Runlevel.WARN, message)
-def error(message): +def error(message: str) -> None: log(Runlevel.ERROR, message)
class _StdoutLogger(logging.Handler): - def __init__(self, runlevel): + def __init__(self, runlevel: 'stem.util.log.Runlevel') -> None: logging.Handler.__init__(self, level = logging_level(runlevel))
self.formatter = logging.Formatter( fmt = '%(asctime)s [%(levelname)s] %(message)s', datefmt = '%m/%d/%Y %H:%M:%S')
- def emit(self, record): + def emit(self, record: logging.LogRecord) -> None: print(self.formatter.format(record))
-def log_to_stdout(runlevel): +def log_to_stdout(runlevel: 'stem.util.log.Runlevel') -> None: """ Logs further events to stdout.
diff --git a/stem/util/proc.py b/stem/util/proc.py index 3589af13..10f2ae60 100644 --- a/stem/util/proc.py +++ b/stem/util/proc.py @@ -56,6 +56,7 @@ import stem.util.enum import stem.util.str_tools
from stem.util import log +from typing import Any, Mapping, Optional, Sequence, Set, Tuple, Type
try: # unavailable on windows (#19823) @@ -80,7 +81,7 @@ Stat = stem.util.enum.Enum(
@functools.lru_cache() -def is_available(): +def is_available() -> bool: """ Checks if proc information is available on this platform.
@@ -101,7 +102,7 @@ def is_available():
@functools.lru_cache() -def system_start_time(): +def system_start_time() -> float: """ Provides the unix time (seconds since epoch) when the system started.
@@ -124,7 +125,7 @@ def system_start_time():
@functools.lru_cache() -def physical_memory(): +def physical_memory() -> int: """ Provides the total physical memory on the system in bytes.
@@ -146,7 +147,7 @@ def physical_memory(): raise exc
-def cwd(pid): +def cwd(pid: int) -> str: """ Provides the current working directory for the given process.
@@ -174,7 +175,7 @@ def cwd(pid): return cwd
-def uid(pid): +def uid(pid: int) -> int: """ Provides the user ID the given process is running under.
@@ -199,7 +200,7 @@ def uid(pid): raise exc
-def memory_usage(pid): +def memory_usage(pid: int) -> Tuple[int, int]: """ Provides the memory usage in bytes for the given process.
@@ -232,7 +233,7 @@ def memory_usage(pid): raise exc
-def stats(pid, *stat_types): +def stats(pid: int, *stat_types: 'stem.util.proc.Stat') -> Sequence[Any]: """ Provides process specific information. See the :data:`~stem.util.proc.Stat` enum for valid options. @@ -270,6 +271,7 @@ def stats(pid, *stat_types): raise exc
results = [] + for stat_type in stat_types: if stat_type == Stat.COMMAND: if pid == 0: @@ -300,7 +302,7 @@ def stats(pid, *stat_types): return tuple(results)
-def file_descriptors_used(pid): +def file_descriptors_used(pid: int) -> int: """ Provides the number of file descriptors currently being used by a process.
@@ -327,7 +329,7 @@ def file_descriptors_used(pid): raise IOError('Unable to check number of file descriptors used: %s' % exc)
-def connections(pid = None, user = None): +def connections(pid: Optional[int] = None, user: Optional[str] = None) -> Sequence['stem.util.connection.Connection']: """ Queries connections from the proc contents. This matches netstat, lsof, and friends but is much faster. If no **pid** or **user** are provided this @@ -412,7 +414,7 @@ def connections(pid = None, user = None): raise
-def _inodes_for_sockets(pid): +def _inodes_for_sockets(pid: int) -> Set[bytes]: """ Provides inodes in use by a process for its sockets.
@@ -450,7 +452,7 @@ def _inodes_for_sockets(pid): return inodes
-def _unpack_addr(addr): +def _unpack_addr(addr: str) -> str: """ Translates an address entry in the /proc/net/* contents to a human readable form (`reference http://linuxdevcenter.com/pub/a/linux/2000/11/16/LinuxAdmin.html`_, @@ -494,7 +496,7 @@ def _unpack_addr(addr): return ENCODED_ADDR[addr]
-def _is_float(*value): +def _is_float(*value: Any) -> bool: try: for v in value: float(v) @@ -504,11 +506,11 @@ def _is_float(*value): return False
-def _get_line(file_path, line_prefix, parameter): +def _get_line(file_path: str, line_prefix: str, parameter: str) -> str: return _get_lines(file_path, (line_prefix, ), parameter)[line_prefix]
-def _get_lines(file_path, line_prefixes, parameter): +def _get_lines(file_path: str, line_prefixes: Sequence[str], parameter: str) -> Mapping[str, str]: """ Fetches lines with the given prefixes from a file. This only provides back the first instance of each prefix. @@ -552,7 +554,7 @@ def _get_lines(file_path, line_prefixes, parameter): raise
-def _log_runtime(parameter, proc_location, start_time): +def _log_runtime(parameter: str, proc_location: str, start_time: int) -> None: """ Logs a message indicating a successful proc query.
@@ -565,7 +567,7 @@ def _log_runtime(parameter, proc_location, start_time): log.debug('proc call (%s): %s (runtime: %0.4f)' % (parameter, proc_location, runtime))
-def _log_failure(parameter, exc): +def _log_failure(parameter: str, exc: Type[Exception]) -> None: """ Logs a message indicating that the proc query failed.
diff --git a/stem/util/str_tools.py b/stem/util/str_tools.py index c1285626..c606906a 100644 --- a/stem/util/str_tools.py +++ b/stem/util/str_tools.py @@ -26,6 +26,8 @@ import sys import stem.util import stem.util.enum
+from typing import Sequence, Tuple, Union + # label conversion tuples of the form... # (bits / bytes / seconds, short label, long label)
@@ -57,7 +59,7 @@ TIME_UNITS = ( _timestamp_re = re.compile(r'(\d{4})-(\d{2})-(\d{2}) (\d{2}):(\d{2}):(\d{2})')
-def _to_bytes(msg): +def _to_bytes(msg: Union[str, bytes]) -> bytes: """ Provides the ASCII bytes for the given string. This is purely to provide python 3 compatability, normalizing the unicode/ASCII change in the version @@ -76,7 +78,7 @@ def _to_bytes(msg): return msg
-def _to_unicode(msg): +def _to_unicode(msg: Union[str, bytes]) -> str: """ Provides the unicode string for the given ASCII bytes. This is purely to provide python 3 compatability, normalizing the unicode/ASCII change in the @@ -93,7 +95,7 @@ def _to_unicode(msg): return msg
-def _decode_b64(msg): +def _decode_b64(msg: Union[str, bytes]) -> str: """ Base64 decode, without padding concerns. """ @@ -104,7 +106,7 @@ def _decode_b64(msg): return base64.b64decode(msg + padding_chr * missing_padding)
-def _to_int(msg): +def _to_int(msg: Union[str, bytes]) -> int: """ Serializes a string to a number.
@@ -120,7 +122,7 @@ def _to_int(msg): return sum([pow(256, (len(msg) - i - 1)) * ord(c) for (i, c) in enumerate(msg)])
-def _to_camel_case(label, divider = '_', joiner = ' '): +def _to_camel_case(label: str, divider: str = '_', joiner: str = ' ') -> str: """ Converts the given string to camel case, ie:
@@ -148,7 +150,7 @@ def _to_camel_case(label, divider = '_', joiner = ' '): return joiner.join(words)
-def _split_by_length(msg, size): +def _split_by_length(msg: str, size: int) -> Sequence[str]: """ Splits a string into a list of strings up to the given size.
@@ -172,7 +174,7 @@ def _split_by_length(msg, size): Ending = stem.util.enum.Enum('ELLIPSE', 'HYPHEN')
-def crop(msg, size, min_word_length = 4, min_crop = 0, ending = Ending.ELLIPSE, get_remainder = False): +def crop(msg: str, size: int, min_word_length: int = 4, min_crop: int = 0, ending: 'stem.util.str_tools.Ending' = Ending.ELLIPSE, get_remainder: bool = False) -> str: """ Shortens a string to a given length.
@@ -286,7 +288,7 @@ def crop(msg, size, min_word_length = 4, min_crop = 0, ending = Ending.ELLIPSE, return (return_msg, remainder) if get_remainder else return_msg
-def size_label(byte_count, decimal = 0, is_long = False, is_bytes = True, round = False): +def size_label(byte_count: int, decimal: int = 0, is_long: bool = False, is_bytes: bool = True, round: bool = False) -> str: """ Converts a number of bytes into a human readable label in its most significant units. For instance, 7500 bytes would return "7 KB". If the @@ -323,7 +325,7 @@ def size_label(byte_count, decimal = 0, is_long = False, is_bytes = True, round return _get_label(SIZE_UNITS_BITS, byte_count, decimal, is_long, round)
-def time_label(seconds, decimal = 0, is_long = False): +def time_label(seconds: int, decimal: int = 0, is_long: bool = False) -> str: """ Converts seconds into a time label truncated to its most significant units. For instance, 7500 seconds would return "2h". Units go up through days. @@ -354,7 +356,7 @@ def time_label(seconds, decimal = 0, is_long = False): return _get_label(TIME_UNITS, seconds, decimal, is_long)
-def time_labels(seconds, is_long = False): +def time_labels(seconds: int, is_long: bool = False) -> Sequence[str]: """ Provides a list of label conversions for each time unit, starting with its most significant units on down. Any counts that evaluate to zero are omitted. @@ -384,7 +386,7 @@ def time_labels(seconds, is_long = False): return time_labels
-def short_time_label(seconds): +def short_time_label(seconds: int) -> str: """ Provides a time in the following format: [[dd-]hh:]mm:ss @@ -424,7 +426,7 @@ def short_time_label(seconds): return label
-def parse_short_time_label(label): +def parse_short_time_label(label: str) -> int: """ Provides the number of seconds corresponding to the formatting used for the cputime and etime fields of ps: @@ -469,7 +471,7 @@ def parse_short_time_label(label): raise ValueError('Non-numeric value in time entry: %s' % label)
-def _parse_timestamp(entry): +def _parse_timestamp(entry: str) -> 'datetime.datetime': """ Parses the date and time that in format like like...
@@ -495,7 +497,7 @@ def _parse_timestamp(entry): return datetime.datetime(time[0], time[1], time[2], time[3], time[4], time[5])
-def _parse_iso_timestamp(entry): +def _parse_iso_timestamp(entry: str) -> 'datetime.datetime': """ Parses the ISO 8601 standard that provides for timestamps like...
@@ -533,7 +535,7 @@ def _parse_iso_timestamp(entry): return timestamp + datetime.timedelta(microseconds = int(microseconds))
-def _get_label(units, count, decimal, is_long, round = False): +def _get_label(units: Tuple[int, str, str], count: int, decimal: int, is_long: bool, round: bool = False) -> str: """ Provides label corresponding to units of the highest significance in the provided set. This rounds down (ie, integer truncation after visible units). diff --git a/stem/util/system.py b/stem/util/system.py index b3dee151..8a61b2b9 100644 --- a/stem/util/system.py +++ b/stem/util/system.py @@ -82,6 +82,7 @@ import stem.util.str_tools
from stem import UNDEFINED from stem.util import log +from typing import Any, Callable, Iterator, Mapping, Optional, Sequence, TextIO, Union
State = stem.util.enum.UppercaseEnum( 'PENDING', @@ -189,7 +190,7 @@ class CallError(OSError): :var str stderr: stderr of the process """
- def __init__(self, msg, command, exit_status, runtime, stdout, stderr): + def __init__(self, msg: str, command: str, exit_status: int, runtime: float, stdout: str, stderr: str) -> None: self.msg = msg self.command = command self.exit_status = exit_status @@ -197,7 +198,7 @@ class CallError(OSError): self.stdout = stdout self.stderr = stderr
- def __str__(self): + def __str__(self) -> str: return self.msg
@@ -210,7 +211,7 @@ class CallTimeoutError(CallError): :var float timeout: time we waited """
- def __init__(self, msg, command, exit_status, runtime, stdout, stderr, timeout): + def __init__(self, msg: str, command: str, exit_status: int, runtime: float, stdout: str, stderr: str, timeout: float) -> None: super(CallTimeoutError, self).__init__(msg, command, exit_status, runtime, stdout, stderr) self.timeout = timeout
@@ -231,7 +232,7 @@ class DaemonTask(object): :var exception error: exception raised by subprocess if it failed """
- def __init__(self, runner, args = None, priority = 15, start = False): + def __init__(self, runner: Callable, args: Optional[Sequence[Any]] = None, priority: int = 15, start: bool = False) -> None: self.runner = runner self.args = args self.priority = priority @@ -247,7 +248,7 @@ class DaemonTask(object): if start: self.run()
- def run(self): + def run(self) -> None: """ Invokes the task if it hasn't already been started. If it has this is a no-op. @@ -259,7 +260,7 @@ class DaemonTask(object): self._process.start() self.status = State.RUNNING
- def join(self): + def join(self) -> Any: """ Provides the result of the daemon task. If still running this blocks until the task is completed. @@ -292,7 +293,7 @@ class DaemonTask(object): raise RuntimeError('BUG: unexpected status from daemon task, %s' % self.status)
@staticmethod - def _run_wrapper(conn, priority, runner, args): + def _run_wrapper(conn: 'multiprocessing.connection.Connection', priority: int, runner: Callable, args: Sequence[Any]) -> None: start_time = time.time() os.nice(priority)
@@ -305,7 +306,7 @@ class DaemonTask(object): conn.close()
-def is_windows(): +def is_windows() -> bool: """ Checks if we are running on Windows.
@@ -315,7 +316,7 @@ def is_windows(): return platform.system() == 'Windows'
-def is_mac(): +def is_mac() -> bool: """ Checks if we are running on Mac OSX.
@@ -325,7 +326,7 @@ def is_mac(): return platform.system() == 'Darwin'
-def is_gentoo(): +def is_gentoo() -> bool: """ Checks if we're running on Gentoo.
@@ -335,7 +336,7 @@ def is_gentoo(): return os.path.exists('/etc/gentoo-release')
-def is_slackware(): +def is_slackware() -> bool: """ Checks if we are running on a Slackware system.
@@ -345,7 +346,7 @@ def is_slackware(): return os.path.exists('/etc/slackware-version')
-def is_bsd(): +def is_bsd() -> bool: """ Checks if we are within the BSD family of operating systems. This currently recognizes Macs, FreeBSD, and OpenBSD but may be expanded later. @@ -356,7 +357,7 @@ def is_bsd(): return platform.system() in ('Darwin', 'FreeBSD', 'OpenBSD', 'NetBSD')
-def is_available(command, cached=True): +def is_available(command: str, cached: bool = True) -> bool: """ Checks the current PATH to see if a command is available or not. If more than one command is present (for instance "ls -a | grep foo") then this @@ -399,7 +400,7 @@ def is_available(command, cached=True): return cmd_exists
-def is_running(command): +def is_running(command: Union[str, int, Sequence[str]]) -> bool: """ Checks for if a process with a given name or pid is running.
@@ -461,7 +462,7 @@ def is_running(command): return None
-def size_of(obj, exclude = None): +def size_of(obj: Any, exclude: Optional[Sequence[int]] = None) -> int: """ Provides the `approximate memory usage of an object https://code.activestate.com/recipes/577504/`_. This can recurse tuples, @@ -504,7 +505,7 @@ def size_of(obj, exclude = None): return size
-def name_by_pid(pid): +def name_by_pid(pid: int) -> Optional[str]: """ Attempts to determine the name a given process is running under (not including arguments). This uses... @@ -547,7 +548,7 @@ def name_by_pid(pid): return process_name
-def pid_by_name(process_name, multiple = False): +def pid_by_name(process_name: str, multiple: bool = False) -> Union[int, Sequence[int]]: """ Attempts to determine the process id for a running process, using...
@@ -718,7 +719,7 @@ def pid_by_name(process_name, multiple = False): return [] if multiple else None
-def pid_by_port(port): +def pid_by_port(port: int) -> Optional[int]: """ Attempts to determine the process id for a process with the given port, using... @@ -838,7 +839,7 @@ def pid_by_port(port): return None # all queries failed
-def pid_by_open_file(path): +def pid_by_open_file(path: str) -> Optional[int]: """ Attempts to determine the process id for a process with the given open file, using... @@ -876,7 +877,7 @@ def pid_by_open_file(path): return None # all queries failed
-def pids_by_user(user): +def pids_by_user(user: str) -> Optional[Sequence[int]]: """ Provides processes owned by a given user.
@@ -908,7 +909,7 @@ def pids_by_user(user): return None
-def cwd(pid): +def cwd(pid: int) -> Optional[str]: """ Provides the working directory of the given process.
@@ -977,7 +978,7 @@ def cwd(pid): return None # all queries failed
-def user(pid): +def user(pid: int) -> Optional[str]: """ Provides the user a process is running under.
@@ -1010,7 +1011,7 @@ def user(pid): return None
-def start_time(pid): +def start_time(pid: str) -> Optional[float]: """ Provides the unix timestamp when the given process started.
@@ -1041,7 +1042,7 @@ def start_time(pid): return None
-def tail(target, lines = None): +def tail(target: Union[str, TextIO], lines: Optional[int] = None) -> Iterator[str]: """ Provides lines of a file starting with the end. For instance, 'tail -n 50 /tmp/my_log' could be done with... @@ -1094,7 +1095,7 @@ def tail(target, lines = None): block_number -= 1
-def bsd_jail_id(pid): +def bsd_jail_id(pid: int) -> int: """ Gets the jail id for a process. These seem to only exist for FreeBSD (this style for jails does not exist on Linux, OSX, or OpenBSD). @@ -1129,7 +1130,7 @@ def bsd_jail_id(pid): return 0
-def bsd_jail_path(jid): +def bsd_jail_path(jid: int) -> Optional[str]: """ Provides the path of the given FreeBSD jail.
@@ -1151,7 +1152,7 @@ def bsd_jail_path(jid): return None
-def is_tarfile(path): +def is_tarfile(path: str) -> bool: """ Returns if the path belongs to a tarfile or not.
@@ -1177,7 +1178,7 @@ def is_tarfile(path): return mimetypes.guess_type(path)[0] == 'application/x-tar'
-def expand_path(path, cwd = None): +def expand_path(path: str, cwd: Optional[str] = None) -> str: """ Provides an absolute path, expanding tildes with the user's home and appending a current working directory if the path was relative. @@ -1222,7 +1223,7 @@ def expand_path(path, cwd = None): return relative_path
-def files_with_suffix(base_path, suffix): +def files_with_suffix(base_path: str, suffix: str) -> Iterator[str]: """ Iterates over files in a given directory, providing filenames with a certain suffix. @@ -1245,7 +1246,7 @@ def files_with_suffix(base_path, suffix): yield os.path.join(root, filename)
-def call(command, default = UNDEFINED, ignore_exit_status = False, timeout = None, cwd = None, env = None): +def call(command: Union[str, Sequence[str]], default: Any = UNDEFINED, ignore_exit_status: bool = False, timeout: Optional[float] = None, cwd: Optional[str] = None, env: Optional[Mapping[str, str]] = None) -> Sequence[str]: """ call(command, default = UNDEFINED, ignore_exit_status = False)
@@ -1346,7 +1347,7 @@ def call(command, default = UNDEFINED, ignore_exit_status = False, timeout = Non SYSTEM_CALL_TIME += time.time() - start_time
-def get_process_name(): +def get_process_name() -> str: """ Provides the present name of our process.
@@ -1398,7 +1399,7 @@ def get_process_name(): return _PROCESS_NAME
-def set_process_name(process_name): +def set_process_name(process_name: str) -> None: """ Renames our current process from "python <args>" to a custom name. This is best-effort, not necessarily working on all platforms. @@ -1432,7 +1433,7 @@ def set_process_name(process_name): _set_proc_title(process_name)
-def _set_argv(process_name): +def _set_argv(process_name: str) -> None: """ Overwrites our argv in a similar fashion to how it's done in C with: strcpy(argv[0], 'new_name'); @@ -1462,7 +1463,7 @@ def _set_argv(process_name): _PROCESS_NAME = process_name
-def _set_prctl_name(process_name): +def _set_prctl_name(process_name: str) -> None: """ Sets the prctl name, which is used by top and killall. This appears to be Linux specific and has the max of 15 characters. @@ -1477,7 +1478,7 @@ def _set_prctl_name(process_name): libc.prctl(PR_SET_NAME, ctypes.byref(name_buffer), 0, 0, 0)
-def _set_proc_title(process_name): +def _set_proc_title(process_name: str) -> None: """ BSD specific calls (should be compataible with both FreeBSD and OpenBSD: http://fxr.watson.org/fxr/source/gen/setproctitle.c?v=FREEBSD-LIBC diff --git a/stem/util/term.py b/stem/util/term.py index 06391441..acc52cad 100644 --- a/stem/util/term.py +++ b/stem/util/term.py @@ -50,6 +50,8 @@ Utilities for working with the terminal. import stem.util.enum import stem.util.str_tools
+from typing import Optional, Union + TERM_COLORS = ('BLACK', 'RED', 'GREEN', 'YELLOW', 'BLUE', 'MAGENTA', 'CYAN', 'WHITE')
# DISABLE_COLOR_SUPPORT is *not* being vended to Stem users. This is likely to @@ -70,7 +72,7 @@ CSI = '\x1B[%sm' RESET = CSI % '0'
-def encoding(*attrs): +def encoding(*attrs: Union['stem.util.terminal.Color', 'stem.util.terminal.BgColor', 'stem.util.terminal.Attr']) -> Optional[str]: """ Provides the ANSI escape sequence for these terminal color or attributes.
@@ -81,7 +83,7 @@ def encoding(*attrs): provide an ecoding for
:returns: **str** of the ANSI escape sequence, **None** no attributes are - recognized + unrecognized """
term_encodings = [] @@ -99,7 +101,7 @@ def encoding(*attrs): return CSI % ';'.join(term_encodings)
-def format(msg, *attr): +def format(msg: str, *attr: Union['stem.util.terminal.Color', 'stem.util.terminal.BgColor', 'stem.util.terminal.Attr']) -> str: """ Simple terminal text formatting using `ANSI escape sequences https://en.wikipedia.org/wiki/ANSI_escape_code#CSI_codes`_. @@ -118,7 +120,7 @@ def format(msg, *attr): :data:`~stem.util.term.BgColor`, or :data:`~stem.util.term.Attr` enums and are case insensitive (so strings like 'red' are fine)
- :returns: **unicode** wrapped with ANSI escape encodings, starting with the given + :returns: **str** wrapped with ANSI escape encodings, starting with the given attributes and ending with a reset """
diff --git a/stem/util/test_tools.py b/stem/util/test_tools.py index 71165214..d5d0f842 100644 --- a/stem/util/test_tools.py +++ b/stem/util/test_tools.py @@ -42,6 +42,8 @@ import stem.util.conf import stem.util.enum import stem.util.system
+from typing import Any, Callable, Iterator, Mapping, Optional, Sequence, Tuple, Type + CONFIG = stem.util.conf.config_dict('test', { 'pycodestyle.ignore': [], 'pyflakes.ignore': [], @@ -55,7 +57,7 @@ AsyncStatus = stem.util.enum.UppercaseEnum('PENDING', 'RUNNING', 'FINISHED') AsyncResult = collections.namedtuple('AsyncResult', 'type msg')
-def assert_equal(expected, actual, msg = None): +def assert_equal(expected: Any, actual: Any, msg: Optional[str] = None) -> None: """ Function form of a TestCase's assertEqual.
@@ -72,7 +74,7 @@ def assert_equal(expected, actual, msg = None): raise AssertionError("Expected '%s' but was '%s'" % (expected, actual) if msg is None else msg)
-def assert_in(expected, actual, msg = None): +def assert_in(expected: Any, actual: Any, msg: Optional[str] = None) -> None: """ Asserts that a given value is within this content.
@@ -89,7 +91,7 @@ def assert_in(expected, actual, msg = None): raise AssertionError("Expected '%s' to be within '%s'" % (expected, actual) if msg is None else msg)
-def skip(msg): +def skip(msg: str) -> None: """ Function form of a TestCase's skipTest.
@@ -100,10 +102,12 @@ def skip(msg): :raises: **unittest.case.SkipTest** for this reason """
+ # TODO: remove now that python 2.x is unsupported? + raise unittest.case.SkipTest(msg)
-def asynchronous(func): +def asynchronous(func: Callable) -> Callable: test = stem.util.test_tools.AsyncTest(func) ASYNC_TESTS[test.name] = test return test.method @@ -131,7 +135,7 @@ class AsyncTest(object): .. versionadded:: 1.6.0 """
- def __init__(self, runner, args = None, threaded = False): + def __init__(self, runner: Callable, args: Optional[Any] = None, threaded: bool = False) -> None: self.name = '%s.%s' % (runner.__module__, runner.__name__)
self._runner = runner @@ -147,8 +151,8 @@ class AsyncTest(object): self._result = None self._status = AsyncStatus.PENDING
- def run(self, *runner_args, **kwargs): - def _wrapper(conn, runner, args): + def run(self, *runner_args: Any, **kwargs: Any) -> None: + def _wrapper(conn: 'multiprocessing.connection.Connection', runner: Callable, args: Any) -> None: os.nice(12)
try: @@ -187,14 +191,14 @@ class AsyncTest(object): self._process.start() self._status = AsyncStatus.RUNNING
- def pid(self): + def pid(self) -> int: with self._process_lock: return self._process.pid if (self._process and not self._threaded) else None
- def join(self): + def join(self) -> None: self.result(None)
- def result(self, test): + def result(self, test: 'unittest.TestCase') -> None: with self._process_lock: if self._status == AsyncStatus.PENDING: self.run() @@ -231,18 +235,18 @@ class TimedTestRunner(unittest.TextTestRunner): .. versionadded:: 1.6.0 """
- def run(self, test): + def run(self, test: 'unittest.TestCase') -> None: for t in test._tests: original_type = type(t)
class _TestWrapper(original_type): - def run(self, result = None): + def run(self, result: Optional[Any] = None) -> Any: start_time = time.time() result = super(type(self), self).run(result) TEST_RUNTIMES[self.id()] = time.time() - start_time return result
- def assertRaisesWith(self, exc_type, exc_msg, func, *args, **kwargs): + def assertRaisesWith(self, exc_type: Type[Exception], exc_msg: str, func: Callable, *args: Any, **kwargs: Any) -> None: """ Asserts the given invokation raises the expected excepiton. This is similar to unittest's assertRaises and assertRaisesRegexp, but checks @@ -255,10 +259,10 @@ class TimedTestRunner(unittest.TextTestRunner):
return self.assertRaisesRegexp(exc_type, '^%s$' % re.escape(exc_msg), func, *args, **kwargs)
- def id(self): + def id(self) -> str: return '%s.%s.%s' % (original_type.__module__, original_type.__name__, self._testMethodName)
- def __str__(self): + def __str__(self) -> str: return '%s (%s.%s)' % (self._testMethodName, original_type.__module__, original_type.__name__)
t.__class__ = _TestWrapper @@ -266,7 +270,7 @@ class TimedTestRunner(unittest.TextTestRunner): return super(TimedTestRunner, self).run(test)
-def test_runtimes(): +def test_runtimes() -> Mapping[str, float]: """ Provides the runtimes of tests executed through TimedTestRunners.
@@ -279,7 +283,7 @@ def test_runtimes(): return dict(TEST_RUNTIMES)
-def clean_orphaned_pyc(paths): +def clean_orphaned_pyc(paths: Sequence[str]) -> Sequence[str]: """ Deletes any file with a \*.pyc extention without a corresponding \*.py. This helps to address a common gotcha when deleting python files... @@ -324,7 +328,7 @@ def clean_orphaned_pyc(paths): return orphaned_pyc
-def is_pyflakes_available(): +def is_pyflakes_available() -> bool: """ Checks if pyflakes is availalbe.
@@ -334,7 +338,7 @@ def is_pyflakes_available(): return _module_exists('pyflakes.api') and _module_exists('pyflakes.reporter')
-def is_pycodestyle_available(): +def is_pycodestyle_available() -> bool: """ Checks if pycodestyle is availalbe.
@@ -349,7 +353,7 @@ def is_pycodestyle_available(): return hasattr(pycodestyle, 'BaseReport')
-def stylistic_issues(paths, check_newlines = False, check_exception_keyword = False, prefer_single_quotes = False): +def stylistic_issues(paths: Sequence[str], check_newlines: bool = False, check_exception_keyword: bool = False, prefer_single_quotes: bool = False) -> Mapping[str, 'stem.util.test_tools.Issue']: """ Checks for stylistic issues that are an issue according to the parts of PEP8 we conform to. You can suppress pycodestyle issues by making a 'test' @@ -425,7 +429,7 @@ def stylistic_issues(paths, check_newlines = False, check_exception_keyword = Fa else: ignore_rules.append(rule)
- def is_ignored(path, rule, code): + def is_ignored(path: str, rule: str, code: str) -> bool: for ignored_path, ignored_rule, ignored_code in ignore_for_file: if path.endswith(ignored_path) and ignored_rule == rule and code.strip().startswith(ignored_code): return True @@ -440,7 +444,7 @@ def stylistic_issues(paths, check_newlines = False, check_exception_keyword = Fa import pycodestyle
class StyleReport(pycodestyle.BaseReport): - def init_file(self, filename, lines, expected, line_offset): + def init_file(self, filename: str, lines: Sequence[str], expected: Tuple[str], line_offset: int) -> None: super(StyleReport, self).init_file(filename, lines, expected, line_offset)
if not check_newlines and not check_exception_keyword and not prefer_single_quotes: @@ -473,7 +477,7 @@ def stylistic_issues(paths, check_newlines = False, check_exception_keyword = Fa
issues.setdefault(filename, []).append(Issue(index + 1, 'use single rather than double quotes', line))
- def error(self, line_number, offset, text, check): + def error(self, line_number: int, offset: int, text: str, check: str) -> None: code = super(StyleReport, self).error(line_number, offset, text, check)
if code: @@ -488,7 +492,7 @@ def stylistic_issues(paths, check_newlines = False, check_exception_keyword = Fa return issues
-def pyflakes_issues(paths): +def pyflakes_issues(paths: Sequence[str]) -> Mapping[str, 'stem.util.test_tools.Issue']: """ Performs static checks via pyflakes. False positives can be ignored via 'pyflakes.ignore' entries in our 'test' config. For instance... @@ -521,23 +525,23 @@ def pyflakes_issues(paths): import pyflakes.reporter
class Reporter(pyflakes.reporter.Reporter): - def __init__(self): + def __init__(self) -> None: self._ignored_issues = {}
for line in CONFIG['pyflakes.ignore']: path, issue = line.split('=>') self._ignored_issues.setdefault(path.strip(), []).append(issue.strip())
- def unexpectedError(self, filename, msg): + def unexpectedError(self, filename: str, msg: str) -> None: self._register_issue(filename, None, msg, None)
- def syntaxError(self, filename, msg, lineno, offset, text): + def syntaxError(self, filename: str, msg: str, lineno: int, offset: int, text: str) -> None: self._register_issue(filename, lineno, msg, text)
- def flake(self, msg): + def flake(self, msg: str) -> None: self._register_issue(msg.filename, msg.lineno, msg.message % msg.message_args, None)
- def _is_ignored(self, path, issue): + def _is_ignored(self, path: str, issue: str) -> bool: # Paths in pyflakes_ignore are relative, so we need to check to see if our # path ends with any of them.
@@ -556,7 +560,7 @@ def pyflakes_issues(paths):
return False
- def _register_issue(self, path, line_number, issue, line): + def _register_issue(self, path: str, line_number: int, issue: str, line: int) -> None: if not self._is_ignored(path, issue): if path and line_number and not line: line = linecache.getline(path, line_number).strip() @@ -571,7 +575,7 @@ def pyflakes_issues(paths): return issues
-def _module_exists(module_name): +def _module_exists(module_name: str) -> bool: """ Checks if a module exists.
@@ -587,7 +591,7 @@ def _module_exists(module_name): return False
-def _python_files(paths): +def _python_files(paths: Sequence[str]) -> Iterator[str]: for path in paths: for file_path in stem.util.system.files_with_suffix(path, '.py'): skip = False diff --git a/stem/util/tor_tools.py b/stem/util/tor_tools.py index 8987635e..2398b7bc 100644 --- a/stem/util/tor_tools.py +++ b/stem/util/tor_tools.py @@ -23,6 +23,8 @@ import re
import stem.util.str_tools
+from typing import Optional, Sequence, Union + # The control-spec defines the following as... # # Fingerprint = "$" 40*HEXDIG @@ -45,7 +47,7 @@ HS_V2_ADDRESS_PATTERN = re.compile('^[a-z2-7]{16}$') HS_V3_ADDRESS_PATTERN = re.compile('^[a-z2-7]{56}$')
-def is_valid_fingerprint(entry, check_prefix = False): +def is_valid_fingerprint(entry: str, check_prefix: bool = False) -> bool: """ Checks if a string is a properly formatted relay fingerprint. This checks for a '$' prefix if check_prefix is true, otherwise this only validates the hex @@ -72,11 +74,11 @@ def is_valid_fingerprint(entry, check_prefix = False): return False
-def is_valid_nickname(entry): +def is_valid_nickname(entry: str) -> bool: """ Checks if a string is a valid format for being a nickname.
- :param str entry: string to be checked + :param str entry: string to check
:returns: **True** if the string could be a nickname, **False** otherwise """ @@ -90,10 +92,12 @@ def is_valid_nickname(entry): return False
-def is_valid_circuit_id(entry): +def is_valid_circuit_id(entry: str) -> bool: """ Checks if a string is a valid format for being a circuit identifier.
+ :param str entry: string to check + :returns: **True** if the string could be a circuit id, **False** otherwise """
@@ -106,29 +110,33 @@ def is_valid_circuit_id(entry): return False
-def is_valid_stream_id(entry): +def is_valid_stream_id(entry: str) -> bool: """ Checks if a string is a valid format for being a stream identifier. Currently, this is just an alias to :func:`~stem.util.tor_tools.is_valid_circuit_id`.
+ :param str entry: string to check + :returns: **True** if the string could be a stream id, **False** otherwise """
return is_valid_circuit_id(entry)
-def is_valid_connection_id(entry): +def is_valid_connection_id(entry: str) -> bool: """ Checks if a string is a valid format for being a connection identifier. Currently, this is just an alias to :func:`~stem.util.tor_tools.is_valid_circuit_id`.
+ :param str entry: string to check + :returns: **True** if the string could be a connection id, **False** otherwise """
return is_valid_circuit_id(entry)
-def is_valid_hidden_service_address(entry, version = None): +def is_valid_hidden_service_address(entry: str, version: Optional[Union[int, Sequence[int]]] = None) -> bool: """ Checks if a string is a valid format for being a hidden service address (not including the '.onion' suffix). @@ -137,6 +145,7 @@ def is_valid_hidden_service_address(entry, version = None): Added the **version** argument, and responds with **True** if a version 3 hidden service address rather than just version 2 addresses.
+ :param str entry: string to check :param int,list version: versions to check for, if unspecified either v2 or v3 hidden service address will provide **True**
@@ -166,7 +175,7 @@ def is_valid_hidden_service_address(entry, version = None): return False
-def is_hex_digits(entry, count): +def is_hex_digits(entry: str, count: int) -> bool: """ Checks if a string is the given number of hex digits. Digits represented by letters are case insensitive. diff --git a/stem/version.py b/stem/version.py index 181aec8a..8ec35293 100644 --- a/stem/version.py +++ b/stem/version.py @@ -42,13 +42,15 @@ import stem.util import stem.util.enum import stem.util.system
+from typing import Any, Callable + # cache for the get_system_tor_version function VERSION_CACHE = {}
VERSION_PATTERN = re.compile(r'^([0-9]+).([0-9]+).([0-9]+)(.[0-9]+)?(-\S*)?(( (\S*))*)$')
-def get_system_tor_version(tor_cmd = 'tor'): +def get_system_tor_version(tor_cmd: str = 'tor') -> 'stem.version.Version': """ Queries tor for its version. This is os dependent, only working on linux, osx, and bsd. @@ -96,7 +98,7 @@ def get_system_tor_version(tor_cmd = 'tor'):
@functools.lru_cache() -def _get_version(version_str): +def _get_version(version_str: str) -> 'stem.version.Version': return Version(version_str)
@@ -125,7 +127,7 @@ class Version(object): :raises: **ValueError** if input isn't a valid tor version """
- def __init__(self, version_str): + def __init__(self, version_str: str) -> None: self.version_str = version_str version_parts = VERSION_PATTERN.match(version_str)
@@ -157,14 +159,14 @@ class Version(object): else: raise ValueError("'%s' isn't a properly formatted tor version" % version_str)
- def __str__(self): + def __str__(self) -> str: """ Provides the string used to construct the version. """
return self.version_str
- def _compare(self, other, method): + def _compare(self, other: Any, method: Callable[[Any, Any], bool]) -> Callable[[Any, Any], bool]: """ Compares version ordering according to the spec. """ @@ -195,23 +197,23 @@ class Version(object):
return method(my_status, other_status)
- def __hash__(self): + def __hash__(self) -> int: return stem.util._hash_attr(self, 'major', 'minor', 'micro', 'patch', 'status', cache = True)
- def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return self._compare(other, lambda s, o: s == o)
- def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self == other
- def __gt__(self, other): + def __gt__(self, other: Any) -> bool: """ Checks if this version meets the requirements for a given feature. """
return self._compare(other, lambda s, o: s > o)
- def __ge__(self, other): + def __ge__(self, other: Any) -> bool: return self._compare(other, lambda s, o: s >= o)
diff --git a/test/integ/control/controller.py b/test/integ/control/controller.py index 8b8b3205..732ae50a 100644 --- a/test/integ/control/controller.py +++ b/test/integ/control/controller.py @@ -339,14 +339,14 @@ class TestController(unittest.TestCase): auth_methods = []
if test.runner.Torrc.COOKIE in tor_options: - auth_methods.append(stem.response.protocolinfo.AuthMethod.COOKIE) - auth_methods.append(stem.response.protocolinfo.AuthMethod.SAFECOOKIE) + auth_methods.append(stem.connection.AuthMethod.COOKIE) + auth_methods.append(stem.connection.AuthMethod.SAFECOOKIE)
if test.runner.Torrc.PASSWORD in tor_options: - auth_methods.append(stem.response.protocolinfo.AuthMethod.PASSWORD) + auth_methods.append(stem.connection.AuthMethod.PASSWORD)
if not auth_methods: - auth_methods.append(stem.response.protocolinfo.AuthMethod.NONE) + auth_methods.append(stem.connection.AuthMethod.NONE)
self.assertEqual(tuple(auth_methods), protocolinfo.auth_methods)
diff --git a/test/integ/response/protocolinfo.py b/test/integ/response/protocolinfo.py index 2fb060db..3a9ee0be 100644 --- a/test/integ/response/protocolinfo.py +++ b/test/integ/response/protocolinfo.py @@ -125,8 +125,8 @@ class TestProtocolInfo(unittest.TestCase): auth_methods, auth_cookie_path = [], None
if test.runner.Torrc.COOKIE in tor_options: - auth_methods.append(stem.response.protocolinfo.AuthMethod.COOKIE) - auth_methods.append(stem.response.protocolinfo.AuthMethod.SAFECOOKIE) + auth_methods.append(stem.connection.AuthMethod.COOKIE) + auth_methods.append(stem.connection.AuthMethod.SAFECOOKIE)
chroot_path = runner.get_chroot() auth_cookie_path = runner.get_auth_cookie_path() @@ -135,10 +135,10 @@ class TestProtocolInfo(unittest.TestCase): auth_cookie_path = auth_cookie_path[len(chroot_path):]
if test.runner.Torrc.PASSWORD in tor_options: - auth_methods.append(stem.response.protocolinfo.AuthMethod.PASSWORD) + auth_methods.append(stem.connection.AuthMethod.PASSWORD)
if not auth_methods: - auth_methods.append(stem.response.protocolinfo.AuthMethod.NONE) + auth_methods.append(stem.connection.AuthMethod.NONE)
self.assertEqual((), protocolinfo_response.unknown_auth_methods) self.assertEqual(tuple(auth_methods), protocolinfo_response.auth_methods) diff --git a/test/settings.cfg b/test/settings.cfg index 38a37ef9..8c6423bb 100644 --- a/test/settings.cfg +++ b/test/settings.cfg @@ -192,20 +192,16 @@ pycodestyle.ignore test/unit/util/connection.py => W291: _tor tor 158 # False positives from pyflakes. These are mappings between the path and the # issue.
-pyflakes.ignore run_tests.py => 'unittest' imported but unused -pyflakes.ignore stem/control.py => undefined name 'controller' -pyflakes.ignore stem/manual.py => undefined name 'unichr' +pyflakes.ignore stem/manual.py => undefined name 'sqlite3' +pyflakes.ignore stem/client/cell.py => undefined name 'cryptography' +pyflakes.ignore stem/client/cell.py => undefined name 'hashlib' pyflakes.ignore stem/client/datatype.py => redefinition of unused 'pop' from * -pyflakes.ignore stem/descriptor/hidden_service_descriptor.py => 'stem.descriptor.hidden_service.*' imported but unused -pyflakes.ignore stem/descriptor/hidden_service_descriptor.py => 'from stem.descriptor.hidden_service import *' used; unable to detect undefined names -pyflakes.ignore stem/interpreter/__init__.py => undefined name 'raw_input' -pyflakes.ignore stem/response/events.py => undefined name 'long' -pyflakes.ignore stem/util/__init__.py => undefined name 'long' -pyflakes.ignore stem/util/__init__.py => undefined name 'unicode' -pyflakes.ignore stem/util/conf.py => undefined name 'unicode' -pyflakes.ignore stem/util/test_tools.py => 'pyflakes' imported but unused -pyflakes.ignore stem/util/test_tools.py => 'pycodestyle' imported but unused -pyflakes.ignore test/__init__.py => undefined name 'test' +pyflakes.ignore stem/descriptor/hidden_service.py => undefined name 'cryptography' +pyflakes.ignore stem/interpreter/autocomplete.py => undefined name 'stem' +pyflakes.ignore stem/interpreter/help.py => undefined name 'stem' +pyflakes.ignore stem/response/events.py => undefined name 'datetime' +pyflakes.ignore stem/util/conf.py => undefined name 'stem' +pyflakes.ignore stem/util/enum.py => undefined name 'stem' pyflakes.ignore test/require.py => 'cryptography.utils.int_from_bytes' imported but unused pyflakes.ignore test/require.py => 'cryptography.utils.int_to_bytes' imported but unused pyflakes.ignore test/require.py => 'cryptography.hazmat.backends.default_backend' imported but unused @@ -216,7 +212,6 @@ 
pyflakes.ignore test/require.py => 'cryptography.hazmat.primitives.serialization pyflakes.ignore test/require.py => 'cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PublicKey' imported but unused pyflakes.ignore test/unit/response/events.py => 'from stem import *' used; unable to detect undefined names pyflakes.ignore test/unit/response/events.py => *may be undefined, or defined from star imports: stem -pyflakes.ignore stem/util/str_tools.py => undefined name 'unicode' pyflakes.ignore test/integ/interpreter.py => 'readline' imported but unused
# Test modules we want to run. Modules are roughly ordered by the dependencies diff --git a/test/unit/response/protocolinfo.py b/test/unit/response/protocolinfo.py index dd8d2160..a71746c9 100644 --- a/test/unit/response/protocolinfo.py +++ b/test/unit/response/protocolinfo.py @@ -13,8 +13,8 @@ import stem.version
from unittest.mock import Mock, patch
+from stem.connection import AuthMethod from stem.response import ControlMessage -from stem.response.protocolinfo import AuthMethod
NO_AUTH = """250-PROTOCOLINFO 1 250-AUTH METHODS=NULL