commit f42ec9286f6dbb50fa164873424a23e067dcf7fd Author: Damian Johnson atagar@torproject.org Date: Sat Jan 27 17:07:08 2018 -0800
Replace pack function with a method
One sucky thing about Python: no method overloading. We can have an instance method or a class method called pack, but not both.
This is trivial to get around by renaming one of them, but on reflection I suspect an instance method is where we'll want to go in the long run anyway. Let's run with this for a bit to see how it goes. --- stem/client/cell.py | 272 ++++++++++------------------------------------- test/unit/client/cell.py | 32 +++--- 2 files changed, 75 insertions(+), 229 deletions(-)
diff --git a/stem/client/cell.py b/stem/client/cell.py index 37c3d5b9..e2e589f5 100644 --- a/stem/client/cell.py +++ b/stem/client/cell.py @@ -126,7 +126,7 @@ class Cell(object): return cls._unpack(payload, circ_id, link_version), content
def pack(self, link_version): - raise NotImplementedError('Unpacking not yet implemented for %s cells' % cls.NAME) + raise NotImplementedError('Unpacking not yet implemented for %s cells' % type(self).NAME)
@classmethod def _pack(cls, link_version, payload, circ_id = 0): @@ -223,24 +223,16 @@ class PaddingCell(Cell): VALUE = 0 IS_FIXED_SIZE = True
- def __init__(self, payload): + def __init__(self, payload = None): + if not payload: + payload = os.urandom(FIXED_PAYLOAD_LEN) + elif len(payload) != FIXED_PAYLOAD_LEN: + raise ValueError('Padding payload should be %i bytes, but was %i' % (FIXED_PAYLOAD_LEN, len(payload))) + self.payload = payload
def pack(self, link_version): - return PaddingCell.pack(link_version, self.payload) - - @classmethod - def pack(cls, link_version, payload = None): - """ - Provides a randomized padding payload. - - :param int link_version: link protocol version - :param bytes payload: padding payload - - :returns: **bytes** with randomized content - """ - - return cls._pack(link_version, payload if payload else os.urandom(FIXED_PAYLOAD_LEN)) + return PaddingCell._pack(link_version, self.payload)
@classmethod def _unpack(cls, content, circ_id, link_version): @@ -297,34 +289,15 @@ class RelayCell(CircuitCell): raise ValueError('%s relay cells concern the circuit itself and cannot have a stream id' % self.command)
def pack(self, link_version): - return RelayCell.pack(link_version, self.circ_id, self.command_int, self.data, self.digest, self.stream_id) - - @classmethod - def pack(cls, link_version, circ_id, command, data, digest, stream_id = 0): - """ - Provides payload of a relay cell. - - :param int link_version: link protocol version - :param int circ_id: circuit id - :param stem.client.RelayCommand command: reason the circuit is being closed - :param bytes data: payload of the cell - :param int digest: running digest held with the relay - :param int stream_id: specific stream this concerns - - :returns: **bytes** to close the circuit - """ - - cell = RelayCell(circ_id, command, data, digest, stream_id) - payload = io.BytesIO() - payload.write(Size.CHAR.pack(cell.command_int)) + payload.write(Size.CHAR.pack(self.command_int)) payload.write(Size.SHORT.pack(0)) # 'recognized' field - payload.write(Size.SHORT.pack(cell.stream_id)) - payload.write(Size.LONG.pack(cell.digest)) - payload.write(Size.SHORT.pack(len(cell.data))) - payload.write(cell.data) + payload.write(Size.SHORT.pack(self.stream_id)) + payload.write(Size.LONG.pack(self.digest)) + payload.write(Size.SHORT.pack(len(self.data))) + payload.write(self.data)
- return cls._pack(link_version, payload.getvalue(), circ_id) + return RelayCell._pack(link_version, payload.getvalue(), self.circ_id)
@classmethod def _unpack(cls, content, circ_id, link_version): @@ -353,26 +326,12 @@ class DestroyCell(CircuitCell): VALUE = 4 IS_FIXED_SIZE = True
- def __init__(self, circ_id, reason): + def __init__(self, circ_id, reason = stem.client.CloseReason.NONE): super(DestroyCell, self).__init__(circ_id) self.reason, self.reason_int = stem.client.CloseReason.get(reason)
def pack(self, link_version): - return DestroyCell.pack(link_version, self.circ_id, self.reason_int) - - @classmethod - def pack(cls, link_version, circ_id, reason = stem.client.CloseReason.NONE): - """ - Provides payload to close the given circuit. - - :param int link_version: link protocol version - :param int circ_id: circuit id - :param stem.client.CloseReason reason: reason to close the circuit - - :returns: **bytes** to close the circuit - """ - - return cls._pack(link_version, Size.CHAR.pack(stem.client.CloseReason.get(reason)[1]), circ_id) + return DestroyCell._pack(link_version, Size.CHAR.pack(self.reason_int), self.circ_id)
@classmethod def _unpack(cls, content, circ_id, link_version): @@ -401,7 +360,7 @@ class CreateFastCell(CircuitCell): VALUE = 5 IS_FIXED_SIZE = True
- def __init__(self, circ_id = None, key_material = None): + def __init__(self, circ_id, key_material = None): if not key_material: key_material = os.urandom(HASH_LEN) elif len(key_material) != HASH_LEN: @@ -411,30 +370,7 @@ class CreateFastCell(CircuitCell): self.key_material = key_material
def pack(self, link_version): - return CreateFastCell.pack(link_version, self.circ_id, self.key_material) - - @classmethod - def pack(cls, link_version, circ_id = None, key_material = None): - """ - Provides a randomized circuit construction payload. - - :param int link_version: link protocol version - :param int circ_id: circuit id - :param bytes key_material: randomized key material - - :returns: **bytes** with our randomized key material - """ - - cell = CreateFastCell(circ_id, key_material) - circ_id = cell.circ_id - - if not circ_id: - # When initiating a circuit the v4 link protocol requires us to set the - # most significant bit. Otherwise any id will do. - - circ_id = 0x80000000 if link_version >= 4 else 0x01 - - return cls._pack(link_version, cell.key_material, circ_id) + return CreateFastCell._pack(link_version, self.key_material, self.circ_id)
@classmethod def _unpack(cls, content, circ_id, link_version): @@ -461,36 +397,21 @@ class CreatedFastCell(CircuitCell): VALUE = 6 IS_FIXED_SIZE = True
- def __init__(self, circ_id, key_material, derivative_key): + def __init__(self, circ_id, derivative_key, key_material = None): + if not key_material: + key_material = os.urandom(HASH_LEN) + elif len(key_material) != HASH_LEN: + raise ValueError('Key material should be %i bytes, but was %i' % (HASH_LEN, len(key_material))) + + if len(derivative_key) != HASH_LEN: + raise ValueError('Derivatived key should be %i bytes, but was %i' % (HASH_LEN, len(derivative_key))) + super(CreatedFastCell, self).__init__(circ_id) self.key_material = key_material self.derivative_key = derivative_key
def pack(self, link_version): - return CreatedFastCell.pack(link_version, self.circ_id, delf.derived_key, self.key_material) - - @classmethod - def pack(cls, link_version, circ_id, derivative_key, key_material = None): - """ - Provides a randomized circuit construction payload. - - :param int link_version: link protocol version - :param int circ_id: circuit id - :param bytes derivative_key: hash proving the relay knows our shared key - :param bytes key_material: randomized key material - - :returns: **bytes** with our randomized key material - """ - - if key_material and len(key_material) != HASH_LEN: - raise ValueError('Key material should be %i bytes, but was %i' % (HASH_LEN, len(key_material))) - elif len(derivative_key) != HASH_LEN: - raise ValueError('Derivatived key should be %i bytes, but was %i' % (HASH_LEN, len(derivative_key))) - - if not key_material: - key_material = os.urandom(HASH_LEN) - - return cls._pack(link_version, key_material + derivative_key, circ_id) + return CreatedFastCell._pack(link_version, self.key_material + self.derivative_key, self.circ_id)
@classmethod def _unpack(cls, content, circ_id, link_version): @@ -499,7 +420,7 @@ class CreatedFastCell(CircuitCell): if len(content) != HASH_LEN * 2: raise ValueError('Key material and derivatived key should be %i bytes, but was %i' % (HASH_LEN * 2, len(content)))
- return CreatedFastCell(circ_id, content[:HASH_LEN], content[HASH_LEN:]) + return CreatedFastCell(circ_id, content[HASH_LEN:], content[:HASH_LEN])
def __hash__(self): return _hash_attr(self, 'circ_id', 'derivative_key', 'key_material') @@ -519,24 +440,12 @@ class VersionsCell(Cell): def __init__(self, versions): self.versions = versions
- def pack(self, link_version): - return VersionsCell.pack(self.versions) - - @classmethod - def pack(cls, versions): - """ - Provides the payload for a series of link versions. - - :param list versions: link versions to serialize - - :returns: **bytes** with a payload for these versions - """ - + def pack(self, link_version = None): # Used for link version negotiation so we don't have that yet. This is fine # since VERSION cells avoid most version dependent attributes.
- payload = b''.join([Size.SHORT.pack(v) for v in versions]) - return cls._pack(2, payload) + payload = b''.join([Size.SHORT.pack(v) for v in self.versions]) + return VersionsCell._pack(2, payload)
@classmethod def _unpack(cls, content, circ_id, link_version): @@ -565,39 +474,21 @@ class NetinfoCell(Cell): VALUE = 8 IS_FIXED_SIZE = True
- def __init__(self, timestamp, receiver_address, sender_addresses): - self.timestamp = timestamp + def __init__(self, receiver_address, sender_addresses, timestamp = None): + self.timestamp = timestamp if timestamp else datetime.datetime.now() self.receiver_address = receiver_address self.sender_addresses = sender_addresses
def pack(self, link_version): - return NetinfoCell.pack(link_version, self.receiver_address, self.sender_address, self.timestamp) - - @classmethod - def pack(cls, link_version, receiver_address, sender_addresses, timestamp = None): - """ - Payload about our timestamp and versions. - - :param int link_version: link protocol version - :param stem.client.Address receiver_address: address of the receiver - :param list sender_addresses: our addresses - :param datetime timestamp: current time according to our clock - - :returns: **bytes** with a payload for these versions - """ - - if timestamp is None: - timestamp = datetime.datetime.now() - payload = io.BytesIO() - payload.write(Size.LONG.pack(int(datetime_to_unix(timestamp)))) - payload.write(receiver_address.pack()) - payload.write(Size.CHAR.pack(len(sender_addresses))) + payload.write(Size.LONG.pack(int(datetime_to_unix(self.timestamp)))) + payload.write(self.receiver_address.pack()) + payload.write(Size.CHAR.pack(len(self.sender_addresses)))
- for addr in sender_addresses: + for addr in self.sender_addresses: payload.write(addr.pack())
- return cls._pack(link_version, payload.getvalue()) + return NetinfoCell._pack(link_version, payload.getvalue())
@classmethod def _unpack(cls, content, circ_id, link_version): @@ -614,7 +505,7 @@ class NetinfoCell(Cell): addr, content = Address.pop(content) sender_addresses.append(addr)
- return NetinfoCell(datetime.datetime.utcfromtimestamp(timestamp), receiver_address, sender_addresses) + return NetinfoCell(receiver_address, sender_addresses, datetime.datetime.utcfromtimestamp(timestamp))
def __hash__(self): return _hash_attr(self, 'timestamp', 'receiver_address', 'sender_addresses') @@ -655,38 +546,20 @@ class VPaddingCell(Cell): VALUE = 128 IS_FIXED_SIZE = False
- def __init__(self, payload): - self.payload = payload - - def pack(self, link_version): - return VPaddingCell.pack(link_version, payload = self.payload) - - @classmethod - def pack(cls, link_version, size = None, payload = None): - """ - Provides a randomized padding payload. If no size or payload is provided - then this provides padding of an arbitrarily chosen size between 128-1024. - - :param int link_version: link protocol version - :param int size: number of bytes to pad - :param bytes payload: padding payload - - :returns: **bytes** with randomized content - - :raises: **ValueError** if both a size and payload are provided, and they - mismatch - """ - + def __init__(self, size = None, payload = None): if payload is None: payload = os.urandom(size) if size else os.urandom(random.randint(128, 1024)) elif size is not None and size != len(payload): - raise ValueError('VPaddingCell.pack caller specified both a size of %i bytes and payload of %i bytes' % (size, len(payload))) + raise ValueError('VPaddingCell constructor specified both a size of %i bytes and payload of %i bytes' % (size, len(payload))) + + self.payload = payload
- return cls._pack(link_version, payload) + def pack(self, link_version): + return VPaddingCell._pack(link_version, self.payload)
@classmethod def _unpack(cls, content, circ_id, link_version): - return VPaddingCell(content) + return VPaddingCell(payload = content)
def __hash__(self): return _hash_attr(self, 'payload') @@ -707,20 +580,7 @@ class CertsCell(Cell): self.certificates = certs
def pack(self, link_version): - return CertsCell.pack(link_version, self.certificates) - - @classmethod - def pack(cls, link_version, certs): - """ - Provides the payload for a series of certificates. - - :param int link_version: link protocol version - :param list certs: series of :class:`~stem.client.Certificate` for the cell - - :returns: **bytes** with a payload for these versions - """ - - return cls._pack(link_version, Size.CHAR.pack(len(certs)) + ''.join([cert.pack() for cert in certs])) + return CertsCell._pack(link_version, Size.CHAR.pack(len(self.certificates)) + ''.join([cert.pack() for cert in self.certificates]))
@classmethod def _unpack(cls, content, circ_id, link_version): @@ -753,38 +613,24 @@ class AuthChallengeCell(Cell): VALUE = 130 IS_FIXED_SIZE = False
- def __init__(self, challenge, methods): - self.challenge = challenge - self.methods = methods - - def pack(self, link_version): - return AuthChallengeCell.pack(link_version, self.methods, self.challenge) - - @classmethod - def pack(cls, link_version, methods, challenge = None): - """ - Provides an authentication challenge. - - :param int link_version: link protocol version - :param list methods: authentication methods we support - :param bytes challenge: randomized string for the receiver to sign - - :returns: **bytes** with a payload for this challenge - """ - - if challenge is None: + def __init__(self, methods, challenge = None): + if not challenge: challenge = os.urandom(AUTH_CHALLENGE_SIZE) elif len(challenge) != AUTH_CHALLENGE_SIZE: raise ValueError('AUTH_CHALLENGE must be %i bytes, but was %i' % (AUTH_CHALLENGE_SIZE, len(challenge)))
+ self.challenge = challenge + self.methods = methods + + def pack(self, link_version): payload = io.BytesIO() - payload.write(challenge) - payload.write(Size.SHORT.pack(len(methods))) + payload.write(self.challenge) + payload.write(Size.SHORT.pack(len(self.methods)))
- for method in methods: + for method in self.methods: payload.write(Size.SHORT.pack(method))
- return cls._pack(link_version, payload.getvalue()) + return AuthChallengeCell._pack(link_version, payload.getvalue())
@classmethod def _unpack(cls, content, circ_id, link_version): @@ -803,7 +649,7 @@ class AuthChallengeCell(Cell): method, content = Size.SHORT.pop(content) methods.append(method)
- return AuthChallengeCell(challenge, methods) + return AuthChallengeCell(methods, challenge)
def __hash__(self): return _hash_attr(self, 'challenge', 'methods') diff --git a/test/unit/client/cell.py b/test/unit/client/cell.py index e2a85b6f..ffead90f 100644 --- a/test/unit/client/cell.py +++ b/test/unit/client/cell.py @@ -125,7 +125,7 @@ class TestCell(unittest.TestCase): self.assertTrue(certs_cell.certificates[i].value.startswith(cert_prefix))
auth_challenge_cell, content = Cell.unpack(content, 2) - self.assertEqual(AuthChallengeCell('\x89Y\t\x99\xb2\x1e\xd9*V\xb6\x1bn\n\x05\xd8/\xe3QH\x85\x13Z\x17\xfc\x1c\x00{\xa9\xae\x83^K', [1, 3]), auth_challenge_cell) + self.assertEqual(AuthChallengeCell([1, 3], '\x89Y\t\x99\xb2\x1e\xd9*V\xb6\x1bn\n\x05\xd8/\xe3QH\x85\x13Z\x17\xfc\x1c\x00{\xa9\xae\x83^K'), auth_challenge_cell)
netinfo_cell, content = Cell.unpack(content, 2) self.assertEqual(NetinfoCell, type(netinfo_cell)) @@ -137,13 +137,13 @@ class TestCell(unittest.TestCase):
def test_padding_cell(self): for cell_bytes, payload in PADDING_CELLS.items(): - self.assertEqual(cell_bytes, PaddingCell.pack(2, payload)) + self.assertEqual(cell_bytes, PaddingCell(payload).pack(2)) self.assertEqual(payload, Cell.unpack(cell_bytes, 2)[0].payload)
def test_relay_cell(self): for cell_bytes, (command, command_int, circ_id, stream_id, data, digest) in RELAY_CELLS.items(): - self.assertEqual(cell_bytes, RelayCell.pack(2, circ_id, command, data, digest, stream_id)) - self.assertEqual(cell_bytes, RelayCell.pack(2, circ_id, command_int, data, digest, stream_id)) + self.assertEqual(cell_bytes, RelayCell(circ_id, command, data, digest, stream_id).pack(2)) + self.assertEqual(cell_bytes, RelayCell(circ_id, command_int, data, digest, stream_id).pack(2))
cell = Cell.unpack(cell_bytes, 2)[0] self.assertEqual(circ_id, cell.circ_id) @@ -155,8 +155,8 @@ class TestCell(unittest.TestCase):
def test_destroy_cell(self): for cell_bytes, (circ_id, reason, reason_int) in DESTROY_CELLS.items(): - self.assertEqual(cell_bytes, DestroyCell.pack(5, circ_id, reason)) - self.assertEqual(cell_bytes, DestroyCell.pack(5, circ_id, reason_int)) + self.assertEqual(cell_bytes, DestroyCell(circ_id, reason).pack(5)) + self.assertEqual(cell_bytes, DestroyCell(circ_id, reason_int).pack(5))
cell = Cell.unpack(cell_bytes, 5)[0] self.assertEqual(circ_id, cell.circ_id) @@ -167,33 +167,33 @@ class TestCell(unittest.TestCase):
def test_create_fast_cell(self): for cell_bytes, (circ_id, key_material) in CREATE_FAST_CELLS.items(): - self.assertEqual(cell_bytes, CreateFastCell.pack(5, circ_id, key_material)) + self.assertEqual(cell_bytes, CreateFastCell(circ_id, key_material).pack(5))
cell = Cell.unpack(cell_bytes, 5)[0] self.assertEqual(circ_id, cell.circ_id) self.assertEqual(key_material, cell.key_material)
- self.assertRaisesRegexp(ValueError, 'Key material should be 20 bytes, but was 3', CreateFastCell.pack, 2, 5, 'boo') + self.assertRaisesRegexp(ValueError, 'Key material should be 20 bytes, but was 3', CreateFastCell, 5, 'boo')
def test_created_fast_cell(self): for cell_bytes, (circ_id, key_material, derivative_key) in CREATED_FAST_CELLS.items(): - self.assertEqual(cell_bytes, CreatedFastCell.pack(5, circ_id, derivative_key, key_material)) + self.assertEqual(cell_bytes, CreatedFastCell(circ_id, derivative_key, key_material).pack(5))
cell = Cell.unpack(cell_bytes, 5)[0] self.assertEqual(circ_id, cell.circ_id) self.assertEqual(key_material, cell.key_material) self.assertEqual(derivative_key, cell.derivative_key)
- self.assertRaisesRegexp(ValueError, 'Key material should be 20 bytes, but was 3', CreateFastCell.pack, 2, 5, 'boo') + self.assertRaisesRegexp(ValueError, 'Key material should be 20 bytes, but was 3', CreateFastCell, 5, 'boo')
def test_versions_cell(self): for cell_bytes, versions in VERSIONS_CELLS.items(): - self.assertEqual(cell_bytes, VersionsCell.pack(versions)) + self.assertEqual(cell_bytes, VersionsCell(versions).pack()) self.assertEqual(versions, Cell.unpack(cell_bytes, 2)[0].versions)
def test_netinfo_cell(self): for cell_bytes, (timestamp, receiver_address, sender_addresses) in NETINFO_CELLS.items(): - self.assertEqual(cell_bytes, NetinfoCell.pack(2, receiver_address, sender_addresses, timestamp)) + self.assertEqual(cell_bytes, NetinfoCell(receiver_address, sender_addresses, timestamp).pack(2))
cell = Cell.unpack(cell_bytes, 2)[0] self.assertEqual(timestamp, cell.timestamp) @@ -202,14 +202,14 @@ class TestCell(unittest.TestCase):
def test_vpadding_cell(self): for cell_bytes, payload in VPADDING_CELLS.items(): - self.assertEqual(cell_bytes, VPaddingCell.pack(2, payload = payload)) + self.assertEqual(cell_bytes, VPaddingCell(payload = payload).pack(2)) self.assertEqual(payload, Cell.unpack(cell_bytes, 2)[0].payload)
- self.assertRaisesRegexp(ValueError, 'VPaddingCell.pack caller specified both a size of 5 bytes and payload of 1 bytes', VPaddingCell.pack, 2, 5, '\x02') + self.assertRaisesRegexp(ValueError, 'VPaddingCell constructor specified both a size of 5 bytes and payload of 1 bytes', VPaddingCell, 5, '\x02')
def test_certs_cell(self): for cell_bytes, certs in CERTS_CELLS.items(): - self.assertEqual(cell_bytes, CertsCell.pack(2, certs)) + self.assertEqual(cell_bytes, CertsCell(certs).pack(2)) self.assertEqual(certs, Cell.unpack(cell_bytes, 2)[0].certificates)
# extra bytes after the last certificate should be ignored @@ -223,7 +223,7 @@ class TestCell(unittest.TestCase):
def test_auth_challenge_cell(self): for cell_bytes, (challenge, methods) in AUTH_CHALLENGE_CELLS.items(): - self.assertEqual(cell_bytes, AuthChallengeCell.pack(2, methods, challenge)) + self.assertEqual(cell_bytes, AuthChallengeCell(methods, challenge).pack(2))
cell = Cell.unpack(cell_bytes, 2)[0] self.assertEqual(challenge, cell.challenge)