[tor-commits] [stem/master] Make LinkProtocol an integer subclass

atagar at torproject.org atagar at torproject.org
Tue Jul 31 19:21:39 UTC 2018


commit 24206c3b7301ed2fbd2a0bd539f951453dea5684
Author: Damian Johnson <atagar at torproject.org>
Date:   Tue Jul 31 11:47:12 2018 -0700

    Make LinkProtocol an integer subclass
    
    Link protocols are ints. I made a LinkProtocol class so we can bundle
    protocol characteristics (field size and such) but at the end of the
    day it's nice for LinkProtocol to still behave like an int.
    
    Dave made the good point that LinkProtocols should be comparable with
    ints...
    
      https://trac.torproject.org/projects/tor/ticket/26432
    
    Taking that one step further and making them an int subclass instead.
    Experimenting, both __init__ and __new__ seem to work in practice for our
    purposes, but the Python docs advise __new__, so using that...
    
      https://jfine-python-classes.readthedocs.io/en/latest/subclass-int.html
---
 stem/client/__init__.py           |  2 +-
 stem/client/cell.py               | 10 ++++-----
 stem/client/datatype.py           | 35 ++++++++++++++++++++++--------
 test/settings.cfg                 |  1 +
 test/unit/client/kdf.py           |  2 +-
 test/unit/client/link_protocol.py | 45 +++++++++++++++++++++++++++++++++++++++
 6 files changed, 79 insertions(+), 16 deletions(-)

diff --git a/stem/client/__init__.py b/stem/client/__init__.py
index a12959f7..24b6de38 100644
--- a/stem/client/__init__.py
+++ b/stem/client/__init__.py
@@ -52,7 +52,7 @@ class Relay(object):
   """
 
   def __init__(self, orport, link_protocol):
-    self.link_protocol = LinkProtocol.for_version(link_protocol)
+    self.link_protocol = LinkProtocol(link_protocol)
     self._orport = orport
     self._orport_lock = threading.RLock()
     self._circuits = {}
diff --git a/stem/client/cell.py b/stem/client/cell.py
index 51e2fbd3..ec214018 100644
--- a/stem/client/cell.py
+++ b/stem/client/cell.py
@@ -166,7 +166,7 @@ class Cell(object):
       * NotImplementedError if unable to unpack this cell type
     """
 
-    link_protocol = LinkProtocol.for_version(link_protocol)
+    link_protocol = LinkProtocol(link_protocol)
 
     circ_id, content = link_protocol.circ_id_size.pop(content)
     command, content = Size.CHAR.pop(content)
@@ -216,7 +216,7 @@ class Cell(object):
 
       circ_id = 0  # cell doesn't concern a circuit, default field to zero
 
-    link_protocol = LinkProtocol.for_version(link_protocol)
+    link_protocol = LinkProtocol(link_protocol)
 
     cell = bytearray()
     cell += link_protocol.circ_id_size.pack(circ_id)
@@ -230,10 +230,10 @@ class Cell(object):
     # pad fixed sized cells to the required length
 
     if cls.IS_FIXED_SIZE:
-      if len(cell) > link_protocol.fixed_cell_len:
-        raise ValueError('Cell of type %s is too large (%i bytes), must not be more than %i. Check payload size (was %i bytes)' % (cls.NAME, len(cell), link_protocol.fixed_cell_len, len(payload)))
+      if len(cell) > link_protocol.fixed_cell_length:
+        raise ValueError('Cell of type %s is too large (%i bytes), must not be more than %i. Check payload size (was %i bytes)' % (cls.NAME, len(cell), link_protocol.fixed_cell_length, len(payload)))
 
-      cell += ZERO * (link_protocol.fixed_cell_len - len(cell))
+      cell += ZERO * (link_protocol.fixed_cell_length - len(cell))
 
     return bytes(cell)
 
diff --git a/stem/client/datatype.py b/stem/client/datatype.py
index 22827245..4859e24a 100644
--- a/stem/client/datatype.py
+++ b/stem/client/datatype.py
@@ -228,7 +228,7 @@ def split(content, size):
   return content[:size], content[size:]
 
 
-class LinkProtocol(collections.namedtuple('LinkProtocol', ['version', 'circ_id_size', 'fixed_cell_len', 'first_circ_id'])):
+class LinkProtocol(int):
   """
   Constants that vary by our link protocol version.
 
@@ -239,18 +239,35 @@ class LinkProtocol(collections.namedtuple('LinkProtocol', ['version', 'circ_id_s
     from a range that's determined by our link protocol.
   """
 
-  @staticmethod
-  def for_version(version):
+  def __new__(cls, version):
     if isinstance(version, LinkProtocol):
       return version  # already a LinkProtocol
-    elif isinstance(version, int):
-      circ_id_size = Size.LONG if version > 3 else Size.SHORT
-      fixed_cell_len = 514 if version > 3 else 512
-      first_circ_id = 0x80000000 if version > 3 else 0x01
 
-      return LinkProtocol(version, circ_id_size, fixed_cell_len, first_circ_id)
+    protocol = int.__new__(cls, version)
+    protocol.version = version = int(version)  # normalize inputs like '5' to 5
+    protocol.circ_id_size = Size.LONG if version > 3 else Size.SHORT
+    protocol.fixed_cell_length = 514 if version > 3 else 512
+    protocol.first_circ_id = 0x80000000 if version > 3 else 0x01
+
+    return protocol
+
+  def __hash__(self):
+    # All LinkProtocol attributes can be derived from our version. We compare
+    # equal to plain ints, so we must hash like them too, otherwise dict and
+    # set lookups mixing LinkProtocols and ints would silently miss.
+
+    return int.__hash__(self)
+
+  def __eq__(self, other):
+    # LinkProtocol is itself an int subclass, so this one check covers both
+    # plain ints and other LinkProtocols.
+    if isinstance(other, int):
+      return self.version == other
     else:
-      raise TypeError('LinkProtocol.for_version() should receiving an int, not %s' % type(version).__name__)
+      return False
+
+  def __ne__(self, other):
+    return not self == other  # keep != consistent with our __eq__
 
   def __int__(self):
     return self.version
diff --git a/test/settings.cfg b/test/settings.cfg
index b8142b5c..8423614b 100644
--- a/test/settings.cfg
+++ b/test/settings.cfg
@@ -243,6 +243,7 @@ test.unit_tests
 |test.unit.response.mapaddress.TestMapAddressResponse
 |test.unit.client.size.TestSize
 |test.unit.client.address.TestAddress
+|test.unit.client.link_protocol.TestLinkProtocol
 |test.unit.client.certificate.TestCertificate
 |test.unit.client.kdf.TestKDF
 |test.unit.client.cell.TestCell
diff --git a/test/unit/client/kdf.py b/test/unit/client/kdf.py
index b9fcb9ce..085a6657 100644
--- a/test/unit/client/kdf.py
+++ b/test/unit/client/kdf.py
@@ -1,5 +1,5 @@
 """
-Unit tests for stem.client.KDF.
+Unit tests for stem.client.datatype.KDF.
 """
 
 import unittest
diff --git a/test/unit/client/link_protocol.py b/test/unit/client/link_protocol.py
new file mode 100644
index 00000000..ffaa30c2
--- /dev/null
+++ b/test/unit/client/link_protocol.py
@@ -0,0 +1,45 @@
+"""
+Unit tests for stem.client.datatype.LinkProtocol.
+"""
+
+import unittest
+
+from stem.client.datatype import Size, LinkProtocol
+
+
+class TestLinkProtocol(unittest.TestCase):
+  def test_invalid_type(self):
+    self.assertRaises(ValueError, LinkProtocol, 'hello')
+
+  def test_attributes(self):
+    protocol = LinkProtocol(1)
+    self.assertEqual(1, protocol.version)
+    self.assertEqual(Size.SHORT, protocol.circ_id_size)
+    self.assertEqual(512, protocol.fixed_cell_length)
+    self.assertEqual(0x01, protocol.first_circ_id)
+
+    protocol = LinkProtocol(10)
+    self.assertEqual(10, protocol.version)
+    self.assertEqual(Size.LONG, protocol.circ_id_size)
+    self.assertEqual(514, protocol.fixed_cell_length)
+    self.assertEqual(0x80000000, protocol.first_circ_id)
+
+  def test_use_as_int(self):
+    protocol = LinkProtocol(5)
+
+    self.assertEqual(7, protocol + 2)
+    self.assertEqual(3, protocol - 2)
+    self.assertEqual(15, protocol * 3)
+    self.assertEqual(1, protocol // 3)  # floor division; '/' is true division on python 3
+
+  def test_equality(self):
+    # LinkProtocols should be comparable with both other LinkProtocols and
+    # integers.
+
+    protocol = LinkProtocol(1)
+
+    self.assertEqual(LinkProtocol(1), protocol)
+    self.assertNotEqual(LinkProtocol(2), protocol)
+
+    self.assertEqual(1, protocol)
+    self.assertNotEqual(2, protocol)



More information about the tor-commits mailing list