[tor-commits] [stem/master] Replace pack function with a method

atagar at torproject.org atagar at torproject.org
Wed Feb 7 19:44:51 UTC 2018


commit f42ec9286f6dbb50fa164873424a23e067dcf7fd
Author: Damian Johnson <atagar at torproject.org>
Date:   Sat Jan 27 17:07:08 2018 -0800

    Replace pack function with a method
    
    One sucky thing about Python: no method overloading. We can have an instance
    method or class method called pack, but not both.
    
    Trivial to get around by renaming one of them but on reflection I suspect an
    instance method is where we'll want to go in the long run anyway. Let's run with
    this for a bit to see how it goes.
---
 stem/client/cell.py      | 272 ++++++++++-------------------------------------
 test/unit/client/cell.py |  32 +++---
 2 files changed, 75 insertions(+), 229 deletions(-)

diff --git a/stem/client/cell.py b/stem/client/cell.py
index 37c3d5b9..e2e589f5 100644
--- a/stem/client/cell.py
+++ b/stem/client/cell.py
@@ -126,7 +126,7 @@ class Cell(object):
     return cls._unpack(payload, circ_id, link_version), content
 
   def pack(self, link_version):
-    raise NotImplementedError('Unpacking not yet implemented for %s cells' % cls.NAME)
+    raise NotImplementedError('Unpacking not yet implemented for %s cells' % type(self).NAME)
 
   @classmethod
   def _pack(cls, link_version, payload, circ_id = 0):
@@ -223,24 +223,16 @@ class PaddingCell(Cell):
   VALUE = 0
   IS_FIXED_SIZE = True
 
-  def __init__(self, payload):
+  def __init__(self, payload = None):
+    if not payload:
+      payload = os.urandom(FIXED_PAYLOAD_LEN)
+    elif len(payload) != FIXED_PAYLOAD_LEN:
+      raise ValueError('Padding payload should be %i bytes, but was %i' % (FIXED_PAYLOAD_LEN, len(payload)))
+
     self.payload = payload
 
   def pack(self, link_version):
-    return PaddingCell.pack(link_version, self.payload)
-
-  @classmethod
-  def pack(cls, link_version, payload = None):
-    """
-    Provides a randomized padding payload.
-
-    :param int link_version: link protocol version
-    :param bytes payload: padding payload
-
-    :returns: **bytes** with randomized content
-    """
-
-    return cls._pack(link_version, payload if payload else os.urandom(FIXED_PAYLOAD_LEN))
+    return PaddingCell._pack(link_version, self.payload)
 
   @classmethod
   def _unpack(cls, content, circ_id, link_version):
@@ -297,34 +289,15 @@ class RelayCell(CircuitCell):
       raise ValueError('%s relay cells concern the circuit itself and cannot have a stream id' % self.command)
 
   def pack(self, link_version):
-    return RelayCell.pack(link_version, self.circ_id, self.command_int, self.data, self.digest, self.stream_id)
-
-  @classmethod
-  def pack(cls, link_version, circ_id, command, data, digest, stream_id = 0):
-    """
-    Provides payload of a relay cell.
-
-    :param int link_version: link protocol version
-    :param int circ_id: circuit id
-    :param stem.client.RelayCommand command: reason the circuit is being closed
-    :param bytes data: payload of the cell
-    :param int digest: running digest held with the relay
-    :param int stream_id: specific stream this concerns
-
-    :returns: **bytes** to close the circuit
-    """
-
-    cell = RelayCell(circ_id, command, data, digest, stream_id)
-
     payload = io.BytesIO()
-    payload.write(Size.CHAR.pack(cell.command_int))
+    payload.write(Size.CHAR.pack(self.command_int))
     payload.write(Size.SHORT.pack(0))  # 'recognized' field
-    payload.write(Size.SHORT.pack(cell.stream_id))
-    payload.write(Size.LONG.pack(cell.digest))
-    payload.write(Size.SHORT.pack(len(cell.data)))
-    payload.write(cell.data)
+    payload.write(Size.SHORT.pack(self.stream_id))
+    payload.write(Size.LONG.pack(self.digest))
+    payload.write(Size.SHORT.pack(len(self.data)))
+    payload.write(self.data)
 
-    return cls._pack(link_version, payload.getvalue(), circ_id)
+    return RelayCell._pack(link_version, payload.getvalue(), self.circ_id)
 
   @classmethod
   def _unpack(cls, content, circ_id, link_version):
@@ -353,26 +326,12 @@ class DestroyCell(CircuitCell):
   VALUE = 4
   IS_FIXED_SIZE = True
 
-  def __init__(self, circ_id, reason):
+  def __init__(self, circ_id, reason = stem.client.CloseReason.NONE):
     super(DestroyCell, self).__init__(circ_id)
     self.reason, self.reason_int = stem.client.CloseReason.get(reason)
 
   def pack(self, link_version):
-    return DestroyCell.pack(link_version, self.circ_id, self.reason_int)
-
-  @classmethod
-  def pack(cls, link_version, circ_id, reason = stem.client.CloseReason.NONE):
-    """
-    Provides payload to close the given circuit.
-
-    :param int link_version: link protocol version
-    :param int circ_id: circuit id
-    :param stem.client.CloseReason reason: reason to close the circuit
-
-    :returns: **bytes** to close the circuit
-    """
-
-    return cls._pack(link_version, Size.CHAR.pack(stem.client.CloseReason.get(reason)[1]), circ_id)
+    return DestroyCell._pack(link_version, Size.CHAR.pack(self.reason_int), self.circ_id)
 
   @classmethod
   def _unpack(cls, content, circ_id, link_version):
@@ -401,7 +360,7 @@ class CreateFastCell(CircuitCell):
   VALUE = 5
   IS_FIXED_SIZE = True
 
-  def __init__(self, circ_id = None, key_material = None):
+  def __init__(self, circ_id, key_material = None):
     if not key_material:
       key_material = os.urandom(HASH_LEN)
     elif len(key_material) != HASH_LEN:
@@ -411,30 +370,7 @@ class CreateFastCell(CircuitCell):
     self.key_material = key_material
 
   def pack(self, link_version):
-    return CreateFastCell.pack(link_version, self.circ_id, self.key_material)
-
-  @classmethod
-  def pack(cls, link_version, circ_id = None, key_material = None):
-    """
-    Provides a randomized circuit construction payload.
-
-    :param int link_version: link protocol version
-    :param int circ_id: circuit id
-    :param bytes key_material: randomized key material
-
-    :returns: **bytes** with our randomized key material
-    """
-
-    cell = CreateFastCell(circ_id, key_material)
-    circ_id = cell.circ_id
-
-    if not circ_id:
-      # When initiating a circuit the v4 link protocol requires us to set the
-      # most significant bit. Otherwise any id will do.
-
-      circ_id = 0x80000000 if link_version >= 4 else 0x01
-
-    return cls._pack(link_version, cell.key_material, circ_id)
+    return CreateFastCell._pack(link_version, self.key_material, self.circ_id)
 
   @classmethod
   def _unpack(cls, content, circ_id, link_version):
@@ -461,36 +397,21 @@ class CreatedFastCell(CircuitCell):
   VALUE = 6
   IS_FIXED_SIZE = True
 
-  def __init__(self, circ_id, key_material, derivative_key):
+  def __init__(self, circ_id, derivative_key, key_material = None):
+    if not key_material:
+      key_material = os.urandom(HASH_LEN)
+    elif len(key_material) != HASH_LEN:
+      raise ValueError('Key material should be %i bytes, but was %i' % (HASH_LEN, len(key_material)))
+
+    if len(derivative_key) != HASH_LEN:
+      raise ValueError('Derivatived key should be %i bytes, but was %i' % (HASH_LEN, len(derivative_key)))
+
     super(CreatedFastCell, self).__init__(circ_id)
     self.key_material = key_material
     self.derivative_key = derivative_key
 
   def pack(self, link_version):
-    return CreatedFastCell.pack(link_version, self.circ_id, delf.derived_key, self.key_material)
-
-  @classmethod
-  def pack(cls, link_version, circ_id, derivative_key, key_material = None):
-    """
-    Provides a randomized circuit construction payload.
-
-    :param int link_version: link protocol version
-    :param int circ_id: circuit id
-    :param bytes derivative_key: hash proving the relay knows our shared key
-    :param bytes key_material: randomized key material
-
-    :returns: **bytes** with our randomized key material
-    """
-
-    if key_material and len(key_material) != HASH_LEN:
-      raise ValueError('Key material should be %i bytes, but was %i' % (HASH_LEN, len(key_material)))
-    elif len(derivative_key) != HASH_LEN:
-      raise ValueError('Derivatived key should be %i bytes, but was %i' % (HASH_LEN, len(derivative_key)))
-
-    if not key_material:
-      key_material = os.urandom(HASH_LEN)
-
-    return cls._pack(link_version, key_material + derivative_key, circ_id)
+    return CreatedFastCell._pack(link_version, self.key_material + self.derivative_key, self.circ_id)
 
   @classmethod
   def _unpack(cls, content, circ_id, link_version):
@@ -499,7 +420,7 @@ class CreatedFastCell(CircuitCell):
     if len(content) != HASH_LEN * 2:
       raise ValueError('Key material and derivatived key should be %i bytes, but was %i' % (HASH_LEN * 2, len(content)))
 
-    return CreatedFastCell(circ_id, content[:HASH_LEN], content[HASH_LEN:])
+    return CreatedFastCell(circ_id, content[HASH_LEN:], content[:HASH_LEN])
 
   def __hash__(self):
     return _hash_attr(self, 'circ_id', 'derivative_key', 'key_material')
@@ -519,24 +440,12 @@ class VersionsCell(Cell):
   def __init__(self, versions):
     self.versions = versions
 
-  def pack(self, link_version):
-    return VersionsCell.pack(self.versions)
-
-  @classmethod
-  def pack(cls, versions):
-    """
-    Provides the payload for a series of link versions.
-
-    :param list versions: link versions to serialize
-
-    :returns: **bytes** with a payload for these versions
-    """
-
+  def pack(self, link_version = None):
     # Used for link version negotiation so we don't have that yet. This is fine
     # since VERSION cells avoid most version dependent attributes.
 
-    payload = b''.join([Size.SHORT.pack(v) for v in versions])
-    return cls._pack(2, payload)
+    payload = b''.join([Size.SHORT.pack(v) for v in self.versions])
+    return VersionsCell._pack(2, payload)
 
   @classmethod
   def _unpack(cls, content, circ_id, link_version):
@@ -565,39 +474,21 @@ class NetinfoCell(Cell):
   VALUE = 8
   IS_FIXED_SIZE = True
 
-  def __init__(self, timestamp, receiver_address, sender_addresses):
-    self.timestamp = timestamp
+  def __init__(self, receiver_address, sender_addresses, timestamp = None):
+    self.timestamp = timestamp if timestamp else datetime.datetime.now()
     self.receiver_address = receiver_address
     self.sender_addresses = sender_addresses
 
   def pack(self, link_version):
-    return NetinfoCell.pack(link_version, self.receiver_address, self.sender_address, self.timestamp)
-
-  @classmethod
-  def pack(cls, link_version, receiver_address, sender_addresses, timestamp = None):
-    """
-    Payload about our timestamp and versions.
-
-    :param int link_version: link protocol version
-    :param stem.client.Address receiver_address: address of the receiver
-    :param list sender_addresses: our addresses
-    :param datetime timestamp: current time according to our clock
-
-    :returns: **bytes** with a payload for these versions
-    """
-
-    if timestamp is None:
-      timestamp = datetime.datetime.now()
-
     payload = io.BytesIO()
-    payload.write(Size.LONG.pack(int(datetime_to_unix(timestamp))))
-    payload.write(receiver_address.pack())
-    payload.write(Size.CHAR.pack(len(sender_addresses)))
+    payload.write(Size.LONG.pack(int(datetime_to_unix(self.timestamp))))
+    payload.write(self.receiver_address.pack())
+    payload.write(Size.CHAR.pack(len(self.sender_addresses)))
 
-    for addr in sender_addresses:
+    for addr in self.sender_addresses:
       payload.write(addr.pack())
 
-    return cls._pack(link_version, payload.getvalue())
+    return NetinfoCell._pack(link_version, payload.getvalue())
 
   @classmethod
   def _unpack(cls, content, circ_id, link_version):
@@ -614,7 +505,7 @@ class NetinfoCell(Cell):
       addr, content = Address.pop(content)
       sender_addresses.append(addr)
 
-    return NetinfoCell(datetime.datetime.utcfromtimestamp(timestamp), receiver_address, sender_addresses)
+    return NetinfoCell(receiver_address, sender_addresses, datetime.datetime.utcfromtimestamp(timestamp))
 
   def __hash__(self):
     return _hash_attr(self, 'timestamp', 'receiver_address', 'sender_addresses')
@@ -655,38 +546,20 @@ class VPaddingCell(Cell):
   VALUE = 128
   IS_FIXED_SIZE = False
 
-  def __init__(self, payload):
-    self.payload = payload
-
-  def pack(self, link_version):
-    return VPaddingCell.pack(link_version, payload = self.payload)
-
-  @classmethod
-  def pack(cls, link_version, size = None, payload = None):
-    """
-    Provides a randomized padding payload. If no size or payload is provided
-    then this provides padding of an arbitrarily chosen size between 128-1024.
-
-    :param int link_version: link protocol version
-    :param int size: number of bytes to pad
-    :param bytes payload: padding payload
-
-    :returns: **bytes** with randomized content
-
-    :raises: **ValueError** if both a size and payload are provided, and they
-      mismatch
-    """
-
+  def __init__(self, size = None, payload = None):
     if payload is None:
       payload = os.urandom(size) if size else os.urandom(random.randint(128, 1024))
     elif size is not None and size != len(payload):
-      raise ValueError('VPaddingCell.pack caller specified both a size of %i bytes and payload of %i bytes' % (size, len(payload)))
+      raise ValueError('VPaddingCell constructor specified both a size of %i bytes and payload of %i bytes' % (size, len(payload)))
+
+    self.payload = payload
 
-    return cls._pack(link_version, payload)
+  def pack(self, link_version):
+    return VPaddingCell._pack(link_version, self.payload)
 
   @classmethod
   def _unpack(cls, content, circ_id, link_version):
-    return VPaddingCell(content)
+    return VPaddingCell(payload = content)
 
   def __hash__(self):
     return _hash_attr(self, 'payload')
@@ -707,20 +580,7 @@ class CertsCell(Cell):
     self.certificates = certs
 
   def pack(self, link_version):
-    return CertsCell.pack(link_version, self.certificates)
-
-  @classmethod
-  def pack(cls, link_version, certs):
-    """
-    Provides the payload for a series of certificates.
-
-    :param int link_version: link protocol version
-    :param list certs: series of :class:`~stem.client.Certificate` for the cell
-
-    :returns: **bytes** with a payload for these versions
-    """
-
-    return cls._pack(link_version, Size.CHAR.pack(len(certs)) + ''.join([cert.pack() for cert in certs]))
+    return CertsCell._pack(link_version, Size.CHAR.pack(len(self.certificates)) + ''.join([cert.pack() for cert in self.certificates]))
 
   @classmethod
   def _unpack(cls, content, circ_id, link_version):
@@ -753,38 +613,24 @@ class AuthChallengeCell(Cell):
   VALUE = 130
   IS_FIXED_SIZE = False
 
-  def __init__(self, challenge, methods):
-    self.challenge = challenge
-    self.methods = methods
-
-  def pack(self, link_version):
-    return AuthChallengeCell.pack(link_version, self.methods, self.challenge)
-
-  @classmethod
-  def pack(cls, link_version, methods, challenge = None):
-    """
-    Provides an authentication challenge.
-
-    :param int link_version: link protocol version
-    :param list methods: authentication methods we support
-    :param bytes challenge: randomized string for the receiver to sign
-
-    :returns: **bytes** with a payload for this challenge
-    """
-
-    if challenge is None:
+  def __init__(self, methods, challenge = None):
+    if not challenge:
       challenge = os.urandom(AUTH_CHALLENGE_SIZE)
     elif len(challenge) != AUTH_CHALLENGE_SIZE:
       raise ValueError('AUTH_CHALLENGE must be %i bytes, but was %i' % (AUTH_CHALLENGE_SIZE, len(challenge)))
 
+    self.challenge = challenge
+    self.methods = methods
+
+  def pack(self, link_version):
     payload = io.BytesIO()
-    payload.write(challenge)
-    payload.write(Size.SHORT.pack(len(methods)))
+    payload.write(self.challenge)
+    payload.write(Size.SHORT.pack(len(self.methods)))
 
-    for method in methods:
+    for method in self.methods:
       payload.write(Size.SHORT.pack(method))
 
-    return cls._pack(link_version, payload.getvalue())
+    return AuthChallengeCell._pack(link_version, payload.getvalue())
 
   @classmethod
   def _unpack(cls, content, circ_id, link_version):
@@ -803,7 +649,7 @@ class AuthChallengeCell(Cell):
       method, content = Size.SHORT.pop(content)
       methods.append(method)
 
-    return AuthChallengeCell(challenge, methods)
+    return AuthChallengeCell(methods, challenge)
 
   def __hash__(self):
     return _hash_attr(self, 'challenge', 'methods')
diff --git a/test/unit/client/cell.py b/test/unit/client/cell.py
index e2a85b6f..ffead90f 100644
--- a/test/unit/client/cell.py
+++ b/test/unit/client/cell.py
@@ -125,7 +125,7 @@ class TestCell(unittest.TestCase):
       self.assertTrue(certs_cell.certificates[i].value.startswith(cert_prefix))
 
     auth_challenge_cell, content = Cell.unpack(content, 2)
-    self.assertEqual(AuthChallengeCell('\x89Y\t\x99\xb2\x1e\xd9*V\xb6\x1bn\n\x05\xd8/\xe3QH\x85\x13Z\x17\xfc\x1c\x00{\xa9\xae\x83^K', [1, 3]), auth_challenge_cell)
+    self.assertEqual(AuthChallengeCell([1, 3], '\x89Y\t\x99\xb2\x1e\xd9*V\xb6\x1bn\n\x05\xd8/\xe3QH\x85\x13Z\x17\xfc\x1c\x00{\xa9\xae\x83^K'), auth_challenge_cell)
 
     netinfo_cell, content = Cell.unpack(content, 2)
     self.assertEqual(NetinfoCell, type(netinfo_cell))
@@ -137,13 +137,13 @@ class TestCell(unittest.TestCase):
 
   def test_padding_cell(self):
     for cell_bytes, payload in PADDING_CELLS.items():
-      self.assertEqual(cell_bytes, PaddingCell.pack(2, payload))
+      self.assertEqual(cell_bytes, PaddingCell(payload).pack(2))
       self.assertEqual(payload, Cell.unpack(cell_bytes, 2)[0].payload)
 
   def test_relay_cell(self):
     for cell_bytes, (command, command_int, circ_id, stream_id, data, digest) in RELAY_CELLS.items():
-      self.assertEqual(cell_bytes, RelayCell.pack(2, circ_id, command, data, digest, stream_id))
-      self.assertEqual(cell_bytes, RelayCell.pack(2, circ_id, command_int, data, digest, stream_id))
+      self.assertEqual(cell_bytes, RelayCell(circ_id, command, data, digest, stream_id).pack(2))
+      self.assertEqual(cell_bytes, RelayCell(circ_id, command_int, data, digest, stream_id).pack(2))
 
       cell = Cell.unpack(cell_bytes, 2)[0]
       self.assertEqual(circ_id, cell.circ_id)
@@ -155,8 +155,8 @@ class TestCell(unittest.TestCase):
 
   def test_destroy_cell(self):
     for cell_bytes, (circ_id, reason, reason_int) in DESTROY_CELLS.items():
-      self.assertEqual(cell_bytes, DestroyCell.pack(5, circ_id, reason))
-      self.assertEqual(cell_bytes, DestroyCell.pack(5, circ_id, reason_int))
+      self.assertEqual(cell_bytes, DestroyCell(circ_id, reason).pack(5))
+      self.assertEqual(cell_bytes, DestroyCell(circ_id, reason_int).pack(5))
 
       cell = Cell.unpack(cell_bytes, 5)[0]
       self.assertEqual(circ_id, cell.circ_id)
@@ -167,33 +167,33 @@ class TestCell(unittest.TestCase):
 
   def test_create_fast_cell(self):
     for cell_bytes, (circ_id, key_material) in CREATE_FAST_CELLS.items():
-      self.assertEqual(cell_bytes, CreateFastCell.pack(5, circ_id, key_material))
+      self.assertEqual(cell_bytes, CreateFastCell(circ_id, key_material).pack(5))
 
       cell = Cell.unpack(cell_bytes, 5)[0]
       self.assertEqual(circ_id, cell.circ_id)
       self.assertEqual(key_material, cell.key_material)
 
-    self.assertRaisesRegexp(ValueError, 'Key material should be 20 bytes, but was 3', CreateFastCell.pack, 2, 5, 'boo')
+    self.assertRaisesRegexp(ValueError, 'Key material should be 20 bytes, but was 3', CreateFastCell, 5, 'boo')
 
   def test_created_fast_cell(self):
     for cell_bytes, (circ_id, key_material, derivative_key) in CREATED_FAST_CELLS.items():
-      self.assertEqual(cell_bytes, CreatedFastCell.pack(5, circ_id, derivative_key, key_material))
+      self.assertEqual(cell_bytes, CreatedFastCell(circ_id, derivative_key, key_material).pack(5))
 
       cell = Cell.unpack(cell_bytes, 5)[0]
       self.assertEqual(circ_id, cell.circ_id)
       self.assertEqual(key_material, cell.key_material)
       self.assertEqual(derivative_key, cell.derivative_key)
 
-    self.assertRaisesRegexp(ValueError, 'Key material should be 20 bytes, but was 3', CreateFastCell.pack, 2, 5, 'boo')
+    self.assertRaisesRegexp(ValueError, 'Key material should be 20 bytes, but was 3', CreateFastCell, 5, 'boo')
 
   def test_versions_cell(self):
     for cell_bytes, versions in VERSIONS_CELLS.items():
-      self.assertEqual(cell_bytes, VersionsCell.pack(versions))
+      self.assertEqual(cell_bytes, VersionsCell(versions).pack())
       self.assertEqual(versions, Cell.unpack(cell_bytes, 2)[0].versions)
 
   def test_netinfo_cell(self):
     for cell_bytes, (timestamp, receiver_address, sender_addresses) in NETINFO_CELLS.items():
-      self.assertEqual(cell_bytes, NetinfoCell.pack(2, receiver_address, sender_addresses, timestamp))
+      self.assertEqual(cell_bytes, NetinfoCell(receiver_address, sender_addresses, timestamp).pack(2))
 
       cell = Cell.unpack(cell_bytes, 2)[0]
       self.assertEqual(timestamp, cell.timestamp)
@@ -202,14 +202,14 @@ class TestCell(unittest.TestCase):
 
   def test_vpadding_cell(self):
     for cell_bytes, payload in VPADDING_CELLS.items():
-      self.assertEqual(cell_bytes, VPaddingCell.pack(2, payload = payload))
+      self.assertEqual(cell_bytes, VPaddingCell(payload = payload).pack(2))
       self.assertEqual(payload, Cell.unpack(cell_bytes, 2)[0].payload)
 
-    self.assertRaisesRegexp(ValueError, 'VPaddingCell.pack caller specified both a size of 5 bytes and payload of 1 bytes', VPaddingCell.pack, 2, 5, '\x02')
+    self.assertRaisesRegexp(ValueError, 'VPaddingCell constructor specified both a size of 5 bytes and payload of 1 bytes', VPaddingCell, 5, '\x02')
 
   def test_certs_cell(self):
     for cell_bytes, certs in CERTS_CELLS.items():
-      self.assertEqual(cell_bytes, CertsCell.pack(2, certs))
+      self.assertEqual(cell_bytes, CertsCell(certs).pack(2))
       self.assertEqual(certs, Cell.unpack(cell_bytes, 2)[0].certificates)
 
     # extra bytes after the last certificate should be ignored
@@ -223,7 +223,7 @@ class TestCell(unittest.TestCase):
 
   def test_auth_challenge_cell(self):
     for cell_bytes, (challenge, methods) in AUTH_CHALLENGE_CELLS.items():
-      self.assertEqual(cell_bytes, AuthChallengeCell.pack(2, methods, challenge))
+      self.assertEqual(cell_bytes, AuthChallengeCell(methods, challenge).pack(2))
 
       cell = Cell.unpack(cell_bytes, 2)[0]
       self.assertEqual(challenge, cell.challenge)





More information about the tor-commits mailing list