commit f42ec9286f6dbb50fa164873424a23e067dcf7fd
Author: Damian Johnson <atagar(a)torproject.org>
Date: Sat Jan 27 17:07:08 2018 -0800
Replace pack function with a method
One annoying thing about Python: no method overloading. We can have an instance
method or a class method called pack, but not both - whichever one is defined
last in the class body silently replaces the other.
This is trivial to get around by renaming one of them, but on reflection I
suspect an instance method is where we'll want to go in the long run anyway.
Let's run with this for a bit to see how it goes.
---
stem/client/cell.py | 272 ++++++++++-------------------------------------
test/unit/client/cell.py | 32 +++---
2 files changed, 75 insertions(+), 229 deletions(-)
diff --git a/stem/client/cell.py b/stem/client/cell.py
index 37c3d5b9..e2e589f5 100644
--- a/stem/client/cell.py
+++ b/stem/client/cell.py
@@ -126,7 +126,7 @@ class Cell(object):
return cls._unpack(payload, circ_id, link_version), content
def pack(self, link_version):
- raise NotImplementedError('Unpacking not yet implemented for %s cells' % cls.NAME)
+ raise NotImplementedError('Unpacking not yet implemented for %s cells' % type(self).NAME)
@classmethod
def _pack(cls, link_version, payload, circ_id = 0):
@@ -223,24 +223,16 @@ class PaddingCell(Cell):
VALUE = 0
IS_FIXED_SIZE = True
- def __init__(self, payload):
+ def __init__(self, payload = None):
+ if not payload:
+ payload = os.urandom(FIXED_PAYLOAD_LEN)
+ elif len(payload) != FIXED_PAYLOAD_LEN:
+ raise ValueError('Padding payload should be %i bytes, but was %i' % (FIXED_PAYLOAD_LEN, len(payload)))
+
self.payload = payload
def pack(self, link_version):
- return PaddingCell.pack(link_version, self.payload)
-
- @classmethod
- def pack(cls, link_version, payload = None):
- """
- Provides a randomized padding payload.
-
- :param int link_version: link protocol version
- :param bytes payload: padding payload
-
- :returns: **bytes** with randomized content
- """
-
- return cls._pack(link_version, payload if payload else os.urandom(FIXED_PAYLOAD_LEN))
+ return PaddingCell._pack(link_version, self.payload)
@classmethod
def _unpack(cls, content, circ_id, link_version):
@@ -297,34 +289,15 @@ class RelayCell(CircuitCell):
raise ValueError('%s relay cells concern the circuit itself and cannot have a stream id' % self.command)
def pack(self, link_version):
- return RelayCell.pack(link_version, self.circ_id, self.command_int, self.data, self.digest, self.stream_id)
-
- @classmethod
- def pack(cls, link_version, circ_id, command, data, digest, stream_id = 0):
- """
- Provides payload of a relay cell.
-
- :param int link_version: link protocol version
- :param int circ_id: circuit id
- :param stem.client.RelayCommand command: reason the circuit is being closed
- :param bytes data: payload of the cell
- :param int digest: running digest held with the relay
- :param int stream_id: specific stream this concerns
-
- :returns: **bytes** to close the circuit
- """
-
- cell = RelayCell(circ_id, command, data, digest, stream_id)
-
payload = io.BytesIO()
- payload.write(Size.CHAR.pack(cell.command_int))
+ payload.write(Size.CHAR.pack(self.command_int))
payload.write(Size.SHORT.pack(0)) # 'recognized' field
- payload.write(Size.SHORT.pack(cell.stream_id))
- payload.write(Size.LONG.pack(cell.digest))
- payload.write(Size.SHORT.pack(len(cell.data)))
- payload.write(cell.data)
+ payload.write(Size.SHORT.pack(self.stream_id))
+ payload.write(Size.LONG.pack(self.digest))
+ payload.write(Size.SHORT.pack(len(self.data)))
+ payload.write(self.data)
- return cls._pack(link_version, payload.getvalue(), circ_id)
+ return RelayCell._pack(link_version, payload.getvalue(), self.circ_id)
@classmethod
def _unpack(cls, content, circ_id, link_version):
@@ -353,26 +326,12 @@ class DestroyCell(CircuitCell):
VALUE = 4
IS_FIXED_SIZE = True
- def __init__(self, circ_id, reason):
+ def __init__(self, circ_id, reason = stem.client.CloseReason.NONE):
super(DestroyCell, self).__init__(circ_id)
self.reason, self.reason_int = stem.client.CloseReason.get(reason)
def pack(self, link_version):
- return DestroyCell.pack(link_version, self.circ_id, self.reason_int)
-
- @classmethod
- def pack(cls, link_version, circ_id, reason = stem.client.CloseReason.NONE):
- """
- Provides payload to close the given circuit.
-
- :param int link_version: link protocol version
- :param int circ_id: circuit id
- :param stem.client.CloseReason reason: reason to close the circuit
-
- :returns: **bytes** to close the circuit
- """
-
- return cls._pack(link_version, Size.CHAR.pack(stem.client.CloseReason.get(reason)[1]), circ_id)
+ return DestroyCell._pack(link_version, Size.CHAR.pack(self.reason_int), self.circ_id)
@classmethod
def _unpack(cls, content, circ_id, link_version):
@@ -401,7 +360,7 @@ class CreateFastCell(CircuitCell):
VALUE = 5
IS_FIXED_SIZE = True
- def __init__(self, circ_id = None, key_material = None):
+ def __init__(self, circ_id, key_material = None):
if not key_material:
key_material = os.urandom(HASH_LEN)
elif len(key_material) != HASH_LEN:
@@ -411,30 +370,7 @@ class CreateFastCell(CircuitCell):
self.key_material = key_material
def pack(self, link_version):
- return CreateFastCell.pack(link_version, self.circ_id, self.key_material)
-
- @classmethod
- def pack(cls, link_version, circ_id = None, key_material = None):
- """
- Provides a randomized circuit construction payload.
-
- :param int link_version: link protocol version
- :param int circ_id: circuit id
- :param bytes key_material: randomized key material
-
- :returns: **bytes** with our randomized key material
- """
-
- cell = CreateFastCell(circ_id, key_material)
- circ_id = cell.circ_id
-
- if not circ_id:
- # When initiating a circuit the v4 link protocol requires us to set the
- # most significant bit. Otherwise any id will do.
-
- circ_id = 0x80000000 if link_version >= 4 else 0x01
-
- return cls._pack(link_version, cell.key_material, circ_id)
+ return CreateFastCell._pack(link_version, self.key_material, self.circ_id)
@classmethod
def _unpack(cls, content, circ_id, link_version):
@@ -461,36 +397,21 @@ class CreatedFastCell(CircuitCell):
VALUE = 6
IS_FIXED_SIZE = True
- def __init__(self, circ_id, key_material, derivative_key):
+ def __init__(self, circ_id, derivative_key, key_material = None):
+ if not key_material:
+ key_material = os.urandom(HASH_LEN)
+ elif len(key_material) != HASH_LEN:
+ raise ValueError('Key material should be %i bytes, but was %i' % (HASH_LEN, len(key_material)))
+
+ if len(derivative_key) != HASH_LEN:
+ raise ValueError('Derivatived key should be %i bytes, but was %i' % (HASH_LEN, len(derivative_key)))
+
super(CreatedFastCell, self).__init__(circ_id)
self.key_material = key_material
self.derivative_key = derivative_key
def pack(self, link_version):
- return CreatedFastCell.pack(link_version, self.circ_id, delf.derived_key, self.key_material)
-
- @classmethod
- def pack(cls, link_version, circ_id, derivative_key, key_material = None):
- """
- Provides a randomized circuit construction payload.
-
- :param int link_version: link protocol version
- :param int circ_id: circuit id
- :param bytes derivative_key: hash proving the relay knows our shared key
- :param bytes key_material: randomized key material
-
- :returns: **bytes** with our randomized key material
- """
-
- if key_material and len(key_material) != HASH_LEN:
- raise ValueError('Key material should be %i bytes, but was %i' % (HASH_LEN, len(key_material)))
- elif len(derivative_key) != HASH_LEN:
- raise ValueError('Derivatived key should be %i bytes, but was %i' % (HASH_LEN, len(derivative_key)))
-
- if not key_material:
- key_material = os.urandom(HASH_LEN)
-
- return cls._pack(link_version, key_material + derivative_key, circ_id)
+ return CreatedFastCell._pack(link_version, self.key_material + self.derivative_key, self.circ_id)
@classmethod
def _unpack(cls, content, circ_id, link_version):
@@ -499,7 +420,7 @@ class CreatedFastCell(CircuitCell):
if len(content) != HASH_LEN * 2:
raise ValueError('Key material and derivatived key should be %i bytes, but was %i' % (HASH_LEN * 2, len(content)))
- return CreatedFastCell(circ_id, content[:HASH_LEN], content[HASH_LEN:])
+ return CreatedFastCell(circ_id, content[HASH_LEN:], content[:HASH_LEN])
def __hash__(self):
return _hash_attr(self, 'circ_id', 'derivative_key', 'key_material')
@@ -519,24 +440,12 @@ class VersionsCell(Cell):
def __init__(self, versions):
self.versions = versions
- def pack(self, link_version):
- return VersionsCell.pack(self.versions)
-
- @classmethod
- def pack(cls, versions):
- """
- Provides the payload for a series of link versions.
-
- :param list versions: link versions to serialize
-
- :returns: **bytes** with a payload for these versions
- """
-
+ def pack(self, link_version = None):
# Used for link version negotiation so we don't have that yet. This is fine
# since VERSION cells avoid most version dependent attributes.
- payload = b''.join([Size.SHORT.pack(v) for v in versions])
- return cls._pack(2, payload)
+ payload = b''.join([Size.SHORT.pack(v) for v in self.versions])
+ return VersionsCell._pack(2, payload)
@classmethod
def _unpack(cls, content, circ_id, link_version):
@@ -565,39 +474,21 @@ class NetinfoCell(Cell):
VALUE = 8
IS_FIXED_SIZE = True
- def __init__(self, timestamp, receiver_address, sender_addresses):
- self.timestamp = timestamp
+ def __init__(self, receiver_address, sender_addresses, timestamp = None):
+ self.timestamp = timestamp if timestamp else datetime.datetime.now()
self.receiver_address = receiver_address
self.sender_addresses = sender_addresses
def pack(self, link_version):
- return NetinfoCell.pack(link_version, self.receiver_address, self.sender_address, self.timestamp)
-
- @classmethod
- def pack(cls, link_version, receiver_address, sender_addresses, timestamp = None):
- """
- Payload about our timestamp and versions.
-
- :param int link_version: link protocol version
- :param stem.client.Address receiver_address: address of the receiver
- :param list sender_addresses: our addresses
- :param datetime timestamp: current time according to our clock
-
- :returns: **bytes** with a payload for these versions
- """
-
- if timestamp is None:
- timestamp = datetime.datetime.now()
-
payload = io.BytesIO()
- payload.write(Size.LONG.pack(int(datetime_to_unix(timestamp))))
- payload.write(receiver_address.pack())
- payload.write(Size.CHAR.pack(len(sender_addresses)))
+ payload.write(Size.LONG.pack(int(datetime_to_unix(self.timestamp))))
+ payload.write(self.receiver_address.pack())
+ payload.write(Size.CHAR.pack(len(self.sender_addresses)))
- for addr in sender_addresses:
+ for addr in self.sender_addresses:
payload.write(addr.pack())
- return cls._pack(link_version, payload.getvalue())
+ return NetinfoCell._pack(link_version, payload.getvalue())
@classmethod
def _unpack(cls, content, circ_id, link_version):
@@ -614,7 +505,7 @@ class NetinfoCell(Cell):
addr, content = Address.pop(content)
sender_addresses.append(addr)
- return NetinfoCell(datetime.datetime.utcfromtimestamp(timestamp), receiver_address, sender_addresses)
+ return NetinfoCell(receiver_address, sender_addresses, datetime.datetime.utcfromtimestamp(timestamp))
def __hash__(self):
return _hash_attr(self, 'timestamp', 'receiver_address', 'sender_addresses')
@@ -655,38 +546,20 @@ class VPaddingCell(Cell):
VALUE = 128
IS_FIXED_SIZE = False
- def __init__(self, payload):
- self.payload = payload
-
- def pack(self, link_version):
- return VPaddingCell.pack(link_version, payload = self.payload)
-
- @classmethod
- def pack(cls, link_version, size = None, payload = None):
- """
- Provides a randomized padding payload. If no size or payload is provided
- then this provides padding of an arbitrarily chosen size between 128-1024.
-
- :param int link_version: link protocol version
- :param int size: number of bytes to pad
- :param bytes payload: padding payload
-
- :returns: **bytes** with randomized content
-
- :raises: **ValueError** if both a size and payload are provided, and they
- mismatch
- """
-
+ def __init__(self, size = None, payload = None):
if payload is None:
payload = os.urandom(size) if size else os.urandom(random.randint(128, 1024))
elif size is not None and size != len(payload):
- raise ValueError('VPaddingCell.pack caller specified both a size of %i bytes and payload of %i bytes' % (size, len(payload)))
+ raise ValueError('VPaddingCell constructor specified both a size of %i bytes and payload of %i bytes' % (size, len(payload)))
+
+ self.payload = payload
- return cls._pack(link_version, payload)
+ def pack(self, link_version):
+ return VPaddingCell._pack(link_version, self.payload)
@classmethod
def _unpack(cls, content, circ_id, link_version):
- return VPaddingCell(content)
+ return VPaddingCell(payload = content)
def __hash__(self):
return _hash_attr(self, 'payload')
@@ -707,20 +580,7 @@ class CertsCell(Cell):
self.certificates = certs
def pack(self, link_version):
- return CertsCell.pack(link_version, self.certificates)
-
- @classmethod
- def pack(cls, link_version, certs):
- """
- Provides the payload for a series of certificates.
-
- :param int link_version: link protocol version
- :param list certs: series of :class:`~stem.client.Certificate` for the cell
-
- :returns: **bytes** with a payload for these versions
- """
-
- return cls._pack(link_version, Size.CHAR.pack(len(certs)) + ''.join([cert.pack() for cert in certs]))
+ return CertsCell._pack(link_version, Size.CHAR.pack(len(self.certificates)) + ''.join([cert.pack() for cert in self.certificates]))
@classmethod
def _unpack(cls, content, circ_id, link_version):
@@ -753,38 +613,24 @@ class AuthChallengeCell(Cell):
VALUE = 130
IS_FIXED_SIZE = False
- def __init__(self, challenge, methods):
- self.challenge = challenge
- self.methods = methods
-
- def pack(self, link_version):
- return AuthChallengeCell.pack(link_version, self.methods, self.challenge)
-
- @classmethod
- def pack(cls, link_version, methods, challenge = None):
- """
- Provides an authentication challenge.
-
- :param int link_version: link protocol version
- :param list methods: authentication methods we support
- :param bytes challenge: randomized string for the receiver to sign
-
- :returns: **bytes** with a payload for this challenge
- """
-
- if challenge is None:
+ def __init__(self, methods, challenge = None):
+ if not challenge:
challenge = os.urandom(AUTH_CHALLENGE_SIZE)
elif len(challenge) != AUTH_CHALLENGE_SIZE:
raise ValueError('AUTH_CHALLENGE must be %i bytes, but was %i' % (AUTH_CHALLENGE_SIZE, len(challenge)))
+ self.challenge = challenge
+ self.methods = methods
+
+ def pack(self, link_version):
payload = io.BytesIO()
- payload.write(challenge)
- payload.write(Size.SHORT.pack(len(methods)))
+ payload.write(self.challenge)
+ payload.write(Size.SHORT.pack(len(self.methods)))
- for method in methods:
+ for method in self.methods:
payload.write(Size.SHORT.pack(method))
- return cls._pack(link_version, payload.getvalue())
+ return AuthChallengeCell._pack(link_version, payload.getvalue())
@classmethod
def _unpack(cls, content, circ_id, link_version):
@@ -803,7 +649,7 @@ class AuthChallengeCell(Cell):
method, content = Size.SHORT.pop(content)
methods.append(method)
- return AuthChallengeCell(challenge, methods)
+ return AuthChallengeCell(methods, challenge)
def __hash__(self):
return _hash_attr(self, 'challenge', 'methods')
diff --git a/test/unit/client/cell.py b/test/unit/client/cell.py
index e2a85b6f..ffead90f 100644
--- a/test/unit/client/cell.py
+++ b/test/unit/client/cell.py
@@ -125,7 +125,7 @@ class TestCell(unittest.TestCase):
self.assertTrue(certs_cell.certificates[i].value.startswith(cert_prefix))
auth_challenge_cell, content = Cell.unpack(content, 2)
- self.assertEqual(AuthChallengeCell('\x89Y\t\x99\xb2\x1e\xd9*V\xb6\x1bn\n\x05\xd8/\xe3QH\x85\x13Z\x17\xfc\x1c\x00{\xa9\xae\x83^K', [1, 3]), auth_challenge_cell)
+ self.assertEqual(AuthChallengeCell([1, 3], '\x89Y\t\x99\xb2\x1e\xd9*V\xb6\x1bn\n\x05\xd8/\xe3QH\x85\x13Z\x17\xfc\x1c\x00{\xa9\xae\x83^K'), auth_challenge_cell)
netinfo_cell, content = Cell.unpack(content, 2)
self.assertEqual(NetinfoCell, type(netinfo_cell))
@@ -137,13 +137,13 @@ class TestCell(unittest.TestCase):
def test_padding_cell(self):
for cell_bytes, payload in PADDING_CELLS.items():
- self.assertEqual(cell_bytes, PaddingCell.pack(2, payload))
+ self.assertEqual(cell_bytes, PaddingCell(payload).pack(2))
self.assertEqual(payload, Cell.unpack(cell_bytes, 2)[0].payload)
def test_relay_cell(self):
for cell_bytes, (command, command_int, circ_id, stream_id, data, digest) in RELAY_CELLS.items():
- self.assertEqual(cell_bytes, RelayCell.pack(2, circ_id, command, data, digest, stream_id))
- self.assertEqual(cell_bytes, RelayCell.pack(2, circ_id, command_int, data, digest, stream_id))
+ self.assertEqual(cell_bytes, RelayCell(circ_id, command, data, digest, stream_id).pack(2))
+ self.assertEqual(cell_bytes, RelayCell(circ_id, command_int, data, digest, stream_id).pack(2))
cell = Cell.unpack(cell_bytes, 2)[0]
self.assertEqual(circ_id, cell.circ_id)
@@ -155,8 +155,8 @@ class TestCell(unittest.TestCase):
def test_destroy_cell(self):
for cell_bytes, (circ_id, reason, reason_int) in DESTROY_CELLS.items():
- self.assertEqual(cell_bytes, DestroyCell.pack(5, circ_id, reason))
- self.assertEqual(cell_bytes, DestroyCell.pack(5, circ_id, reason_int))
+ self.assertEqual(cell_bytes, DestroyCell(circ_id, reason).pack(5))
+ self.assertEqual(cell_bytes, DestroyCell(circ_id, reason_int).pack(5))
cell = Cell.unpack(cell_bytes, 5)[0]
self.assertEqual(circ_id, cell.circ_id)
@@ -167,33 +167,33 @@ class TestCell(unittest.TestCase):
def test_create_fast_cell(self):
for cell_bytes, (circ_id, key_material) in CREATE_FAST_CELLS.items():
- self.assertEqual(cell_bytes, CreateFastCell.pack(5, circ_id, key_material))
+ self.assertEqual(cell_bytes, CreateFastCell(circ_id, key_material).pack(5))
cell = Cell.unpack(cell_bytes, 5)[0]
self.assertEqual(circ_id, cell.circ_id)
self.assertEqual(key_material, cell.key_material)
- self.assertRaisesRegexp(ValueError, 'Key material should be 20 bytes, but was 3', CreateFastCell.pack, 2, 5, 'boo')
+ self.assertRaisesRegexp(ValueError, 'Key material should be 20 bytes, but was 3', CreateFastCell, 5, 'boo')
def test_created_fast_cell(self):
for cell_bytes, (circ_id, key_material, derivative_key) in CREATED_FAST_CELLS.items():
- self.assertEqual(cell_bytes, CreatedFastCell.pack(5, circ_id, derivative_key, key_material))
+ self.assertEqual(cell_bytes, CreatedFastCell(circ_id, derivative_key, key_material).pack(5))
cell = Cell.unpack(cell_bytes, 5)[0]
self.assertEqual(circ_id, cell.circ_id)
self.assertEqual(key_material, cell.key_material)
self.assertEqual(derivative_key, cell.derivative_key)
- self.assertRaisesRegexp(ValueError, 'Key material should be 20 bytes, but was 3', CreateFastCell.pack, 2, 5, 'boo')
+ self.assertRaisesRegexp(ValueError, 'Key material should be 20 bytes, but was 3', CreateFastCell, 5, 'boo')
def test_versions_cell(self):
for cell_bytes, versions in VERSIONS_CELLS.items():
- self.assertEqual(cell_bytes, VersionsCell.pack(versions))
+ self.assertEqual(cell_bytes, VersionsCell(versions).pack())
self.assertEqual(versions, Cell.unpack(cell_bytes, 2)[0].versions)
def test_netinfo_cell(self):
for cell_bytes, (timestamp, receiver_address, sender_addresses) in NETINFO_CELLS.items():
- self.assertEqual(cell_bytes, NetinfoCell.pack(2, receiver_address, sender_addresses, timestamp))
+ self.assertEqual(cell_bytes, NetinfoCell(receiver_address, sender_addresses, timestamp).pack(2))
cell = Cell.unpack(cell_bytes, 2)[0]
self.assertEqual(timestamp, cell.timestamp)
@@ -202,14 +202,14 @@ class TestCell(unittest.TestCase):
def test_vpadding_cell(self):
for cell_bytes, payload in VPADDING_CELLS.items():
- self.assertEqual(cell_bytes, VPaddingCell.pack(2, payload = payload))
+ self.assertEqual(cell_bytes, VPaddingCell(payload = payload).pack(2))
self.assertEqual(payload, Cell.unpack(cell_bytes, 2)[0].payload)
- self.assertRaisesRegexp(ValueError, 'VPaddingCell.pack caller specified both a size of 5 bytes and payload of 1 bytes', VPaddingCell.pack, 2, 5, '\x02')
+ self.assertRaisesRegexp(ValueError, 'VPaddingCell constructor specified both a size of 5 bytes and payload of 1 bytes', VPaddingCell, 5, '\x02')
def test_certs_cell(self):
for cell_bytes, certs in CERTS_CELLS.items():
- self.assertEqual(cell_bytes, CertsCell.pack(2, certs))
+ self.assertEqual(cell_bytes, CertsCell(certs).pack(2))
self.assertEqual(certs, Cell.unpack(cell_bytes, 2)[0].certificates)
# extra bytes after the last certificate should be ignored
@@ -223,7 +223,7 @@ class TestCell(unittest.TestCase):
def test_auth_challenge_cell(self):
for cell_bytes, (challenge, methods) in AUTH_CHALLENGE_CELLS.items():
- self.assertEqual(cell_bytes, AuthChallengeCell.pack(2, methods, challenge))
+ self.assertEqual(cell_bytes, AuthChallengeCell(methods, challenge).pack(2))
cell = Cell.unpack(cell_bytes, 2)[0]
self.assertEqual(challenge, cell.challenge)