commit 24206c3b7301ed2fbd2a0bd539f951453dea5684 Author: Damian Johnson <atagar@torproject.org> Date: Tue Jul 31 11:47:12 2018 -0700
Make LinkProtocol an integer subclass
Link protocols are ints. I made a LinkProtocol class so we can bundle protocol characteristics (field size and such), but at the end of the day it's nice for LinkProtocol to still behave like an int.
Dave made the good point that LinkProtocols should be comparable with ints...
https://trac.torproject.org/projects/tor/ticket/26432
Taking that one step further and making them an int subclass instead. Experimenting, both the __init__ and __new__ approaches seem to work in practice for our purposes, but the Python docs advise __new__, so I'm using that...
https://jfine-python-classes.readthedocs.io/en/latest/subclass-int.html --- stem/client/__init__.py | 2 +- stem/client/cell.py | 10 ++++----- stem/client/datatype.py | 35 ++++++++++++++++++++++-------- test/settings.cfg | 1 + test/unit/client/kdf.py | 2 +- test/unit/client/link_protocol.py | 45 +++++++++++++++++++++++++++++++++++++++ 6 files changed, 79 insertions(+), 16 deletions(-)
diff --git a/stem/client/__init__.py b/stem/client/__init__.py index a12959f7..24b6de38 100644 --- a/stem/client/__init__.py +++ b/stem/client/__init__.py @@ -52,7 +52,7 @@ class Relay(object): """
def __init__(self, orport, link_protocol): - self.link_protocol = LinkProtocol.for_version(link_protocol) + self.link_protocol = LinkProtocol(link_protocol) self._orport = orport self._orport_lock = threading.RLock() self._circuits = {} diff --git a/stem/client/cell.py b/stem/client/cell.py index 51e2fbd3..ec214018 100644 --- a/stem/client/cell.py +++ b/stem/client/cell.py @@ -166,7 +166,7 @@ class Cell(object): * NotImplementedError if unable to unpack this cell type """
- link_protocol = LinkProtocol.for_version(link_protocol) + link_protocol = LinkProtocol(link_protocol)
circ_id, content = link_protocol.circ_id_size.pop(content) command, content = Size.CHAR.pop(content) @@ -216,7 +216,7 @@ class Cell(object):
circ_id = 0 # cell doesn't concern a circuit, default field to zero
- link_protocol = LinkProtocol.for_version(link_protocol) + link_protocol = LinkProtocol(link_protocol)
cell = bytearray() cell += link_protocol.circ_id_size.pack(circ_id) @@ -230,10 +230,10 @@ class Cell(object): # pad fixed sized cells to the required length
if cls.IS_FIXED_SIZE: - if len(cell) > link_protocol.fixed_cell_len: - raise ValueError('Cell of type %s is too large (%i bytes), must not be more than %i. Check payload size (was %i bytes)' % (cls.NAME, len(cell), link_protocol.fixed_cell_len, len(payload))) + if len(cell) > link_protocol.fixed_cell_length: + raise ValueError('Cell of type %s is too large (%i bytes), must not be more than %i. Check payload size (was %i bytes)' % (cls.NAME, len(cell), link_protocol.fixed_cell_length, len(payload)))
- cell += ZERO * (link_protocol.fixed_cell_len - len(cell)) + cell += ZERO * (link_protocol.fixed_cell_length - len(cell))
return bytes(cell)
diff --git a/stem/client/datatype.py b/stem/client/datatype.py index 22827245..4859e24a 100644 --- a/stem/client/datatype.py +++ b/stem/client/datatype.py @@ -228,7 +228,7 @@ def split(content, size): return content[:size], content[size:]
-class LinkProtocol(collections.namedtuple('LinkProtocol', ['version', 'circ_id_size', 'fixed_cell_len', 'first_circ_id'])): +class LinkProtocol(int): """ Constants that vary by our link protocol version.
@@ -239,18 +239,35 @@ class LinkProtocol(collections.namedtuple('LinkProtocol', ['version', 'circ_id_s from a range that's determined by our link protocol. """
- @staticmethod - def for_version(version): + def __new__(cls, version): if isinstance(version, LinkProtocol): return version # already a LinkProtocol - elif isinstance(version, int): - circ_id_size = Size.LONG if version > 3 else Size.SHORT - fixed_cell_len = 514 if version > 3 else 512 - first_circ_id = 0x80000000 if version > 3 else 0x01
- return LinkProtocol(version, circ_id_size, fixed_cell_len, first_circ_id) + protocol = int.__new__(cls, version) + protocol.version = version + protocol.circ_id_size = Size.LONG if version > 3 else Size.SHORT + protocol.fixed_cell_length = 514 if version > 3 else 512 + protocol.first_circ_id = 0x80000000 if version > 3 else 0x01 + + return protocol + + def __hash__(self): + # All LinkProtocol attributes can be derived from our version, so that's + # all we need in our hash. Offsetting by our type so we don't hash conflict + # with ints. + + return self.version * hash(str(type(self))) + + def __eq__(self, other): + if isinstance(other, int): + return self.version == other + elif isinstance(other, LinkProtocol): + return hash(self) == hash(other) else: - raise TypeError('LinkProtocol.for_version() should receiving an int, not %s' % type(version).__name__) + return False + + def __ne__(self, other): + return not self == other
def __int__(self): return self.version diff --git a/test/settings.cfg b/test/settings.cfg index b8142b5c..8423614b 100644 --- a/test/settings.cfg +++ b/test/settings.cfg @@ -243,6 +243,7 @@ test.unit_tests |test.unit.response.mapaddress.TestMapAddressResponse |test.unit.client.size.TestSize |test.unit.client.address.TestAddress +|test.unit.client.link_protocol.TestLinkProtocol |test.unit.client.certificate.TestCertificate |test.unit.client.kdf.TestKDF |test.unit.client.cell.TestCell diff --git a/test/unit/client/kdf.py b/test/unit/client/kdf.py index b9fcb9ce..085a6657 100644 --- a/test/unit/client/kdf.py +++ b/test/unit/client/kdf.py @@ -1,5 +1,5 @@ """ -Unit tests for stem.client.KDF. +Unit tests for stem.client.datatype.KDF. """
import unittest diff --git a/test/unit/client/link_protocol.py b/test/unit/client/link_protocol.py new file mode 100644 index 00000000..ffaa30c2 --- /dev/null +++ b/test/unit/client/link_protocol.py @@ -0,0 +1,45 @@ +""" +Unit tests for stem.client.datatype.LinkProtocol. +""" + +import unittest + +from stem.client.datatype import Size, LinkProtocol + + +class TestLinkProtocol(unittest.TestCase): + def test_invalid_type(self): + self.assertRaises(ValueError, LinkProtocol, 'hello') + + def test_attributes(self): + protocol = LinkProtocol(1) + self.assertEqual(1, protocol.version) + self.assertEqual(Size.SHORT, protocol.circ_id_size) + self.assertEqual(512, protocol.fixed_cell_length) + self.assertEqual(0x01, protocol.first_circ_id) + + protocol = LinkProtocol(10) + self.assertEqual(10, protocol.version) + self.assertEqual(Size.LONG, protocol.circ_id_size) + self.assertEqual(514, protocol.fixed_cell_length) + self.assertEqual(0x80000000, protocol.first_circ_id) + + def test_use_as_int(self): + protocol = LinkProtocol(5) + + self.assertEqual(7, protocol + 2) + self.assertEqual(3, protocol - 2) + self.assertEqual(15, protocol * 3) + self.assertEqual(1, protocol / 3) + + def test_equality(self): + # LinkProtocols should be comparable with both other LinkProtocols and + # integers. + + protocol = LinkProtocol(1) + + self.assertEqual(LinkProtocol(1), protocol) + self.assertNotEqual(LinkProtocol(2), protocol) + + self.assertEqual(1, protocol) + self.assertNotEqual(2, protocol)