[tor-commits] [stem/master] Localize protocol attributes into a LinkProtocol class

atagar at torproject.org atagar at torproject.org
Sun Jun 17 00:23:10 UTC 2018


commit 526dc8b27f391d7c24136a2bdc461e514c7da07c
Author: Damian Johnson <atagar at torproject.org>
Date:   Sat Jun 16 16:36:04 2018 -0700

    Localize protocol attributes into a LinkProtocol class
    
    Several constants vary depending on our link protocol version. Rather than
    encoding a bunch of 'x if link_protocol > y else z' conditionals, centralize
    them in a LinkProtocol class.
    
    LinkProtocol.for_version() normalizes an integer *or* LinkProtocol into a
    LinkProtocol.
    
    The pattern we should follow is:
    
      * Public methods should always accept an integer.
    
      * Any time we use or store a link protocol version it's normalized to a
        LinkProtocol instance.
---
 stem/client/__init__.py | 16 +++++-----------
 stem/client/cell.py     | 36 +++++++++++-------------------------
 stem/client/datatype.py | 30 ++++++++++++++++++++++++++++++
 3 files changed, 46 insertions(+), 36 deletions(-)

diff --git a/stem/client/__init__.py b/stem/client/__init__.py
index 1d87966f..c0099182 100644
--- a/stem/client/__init__.py
+++ b/stem/client/__init__.py
@@ -34,7 +34,7 @@ import stem.client.cell
 import stem.socket
 import stem.util.connection
 
-from stem.client.datatype import ZERO, Address, Size, KDF, split
+from stem.client.datatype import ZERO, LinkProtocol, Address, Size, KDF, split
 
 __all__ = [
   'cell',
@@ -52,7 +52,7 @@ class Relay(object):
   """
 
   def __init__(self, orport, link_protocol):
-    self.link_protocol = link_protocol
+    self.link_protocol = LinkProtocol.for_version(link_protocol)
     self._orport = orport
     self._orport_lock = threading.RLock()
     self._circuits = {}
@@ -151,13 +151,7 @@ class Relay(object):
     """
 
     with self._orport_lock:
-      # Find an unused circuit id. Since we're initiating the circuit we pick any
-      # value from a range that's determined by our link protocol.
-
-      circ_id = 0x80000000 if self.link_protocol > 3 else 0x01
-
-      while circ_id in self._circuits:
-        circ_id += 1
+      circ_id = max(self._circuits) + 1 if self._circuits else self.link_protocol.first_circ_id
 
       create_fast_cell = stem.client.cell.CreateFastCell(circ_id)
       self._orport.send(create_fast_cell.pack(self.link_protocol))
@@ -239,7 +233,7 @@ class Circuit(object):
       # doesn't include the initial circuit id and cell type fields.
       # Circuit ids vary in length depending on the protocol version.
 
-      header_size = 5 if self.relay.link_protocol > 3 else 3
+      header_size = self.relay.link_protocol.circ_id_size.size + 1
 
       try:
         cell = stem.client.cell.RelayCell(self.id, command, data, 0, stream_id)
@@ -263,7 +257,7 @@ class Circuit(object):
           raise stem.ProtocolError('Circuit response should be a series of RELAY cells, but received an unexpected size for a response: %i' % len(reply))
 
         while reply:
-          circ_id, reply = Size.SHORT.pop(reply) if self.relay.link_protocol < 4 else Size.LONG.pop(reply)
+          circ_id, reply = self.relay.link_protocol.circ_id_size.pop(reply)
           command, reply = Size.CHAR.pop(reply)
           payload, reply = split(reply, stem.client.cell.FIXED_PAYLOAD_LEN)
 
diff --git a/stem/client/cell.py b/stem/client/cell.py
index e2aba8b6..7e3e78ef 100644
--- a/stem/client/cell.py
+++ b/stem/client/cell.py
@@ -45,7 +45,7 @@ import sys
 import stem.util
 
 from stem import UNDEFINED
-from stem.client.datatype import HASH_LEN, ZERO, Address, Certificate, CloseReason, RelayCommand, Size, split
+from stem.client.datatype import HASH_LEN, ZERO, LinkProtocol, Address, Certificate, CloseReason, RelayCommand, Size, split
 from stem.util import _hash_attr, datetime_to_unix, str_tools
 
 FIXED_PAYLOAD_LEN = 509
@@ -113,22 +113,6 @@ class Cell(object):
 
     raise ValueError("'%s' isn't a valid cell value" % value)
 
-  @staticmethod
-  def _get_circ_id_size(link_protocol):
-    """
-    Gets the proper Size for the link_protocol.
-
-    :param int link_protocol: link protocol version
-
-    :returns: :class:`~stem.client.datatype.Size`
-    """
-
-    # per tor-spec section 3
-    # CIRCID_LEN :=
-    #   2 for link protocol versions 1, 2, and 3
-    #   4 for link protocol versions 4+
-    return Size.LONG if link_protocol >= 4 else Size.SHORT
-
   def pack(self, link_protocol):
     raise NotImplementedError('Unpacking not yet implemented for %s cells' % type(self).NAME)
 
@@ -166,7 +150,9 @@ class Cell(object):
       * NotImplementedError if unable to unpack this cell type
     """
 
-    circ_id, content = Cell._get_circ_id_size(link_protocol).pop(content)
+    link_protocol = LinkProtocol.for_version(link_protocol)
+
+    circ_id, content = link_protocol.circ_id_size.pop(content)
     command, content = Size.CHAR.pop(content)
     cls = Cell.by_value(command)
 
@@ -206,8 +192,10 @@ class Cell(object):
     if isinstance(cls, CircuitCell) and circ_id is None:
       raise ValueError('%s cells require a circ_id' % cls.NAME)
 
+    link_protocol = LinkProtocol.for_version(link_protocol)
+
     cell = bytearray()
-    cell += Cell._get_circ_id_size(link_protocol).pack(circ_id)
+    cell += link_protocol.circ_id_size.pack(circ_id)
     cell += Size.CHAR.pack(cls.VALUE)
     cell += b'' if cls.IS_FIXED_SIZE else Size.SHORT.pack(len(payload))
     cell += payload
@@ -215,12 +203,10 @@ class Cell(object):
     # pad fixed sized cells to the required length
 
     if cls.IS_FIXED_SIZE:
-      fixed_cell_len = 514 if link_protocol > 3 else 512
-
-      if len(cell) > fixed_cell_len:
-        raise ValueError('Payload of %s is too large (%i bytes), must be less than %i' % (cls.NAME, len(cell), fixed_cell_len))
+      if len(cell) > link_protocol.fixed_cell_len:
+        raise ValueError('Payload of %s is too large (%i bytes), must be less than %i' % (cls.NAME, len(cell), link_protocol.fixed_cell_len))
 
-      cell += ZERO * (fixed_cell_len - len(cell))
+      cell += ZERO * (link_protocol.fixed_cell_len - len(cell))
 
     return bytes(cell)
 
@@ -230,7 +216,7 @@ class Cell(object):
     Subclass implementation for unpacking cell content.
 
     :param bytes content: payload to decode
-    :param int link_protocol: link protocol version
+    :param stem.client.datatype.LinkProtocol link_protocol: link protocol version
     :param int circ_id: circuit id cell is for
 
     :returns: instance of this cell type
diff --git a/stem/client/datatype.py b/stem/client/datatype.py
index 75fae662..9812ced0 100644
--- a/stem/client/datatype.py
+++ b/stem/client/datatype.py
@@ -14,6 +14,8 @@ users.** See our :class:`~stem.client.Relay` the API you probably want.
 
   split - splits bytes into substrings
 
+  LinkProtocol - ORPort protocol version.
+
   Field - Packable and unpackable datatype.
     |- Size - Field of a static size.
     |- Address - Relay address.
@@ -228,6 +230,34 @@ def split(content, size):
   return content[:size], content[size:]
 
 
+class LinkProtocol(collections.namedtuple('LinkProtocol', ['version', 'circ_id_size', 'fixed_cell_len', 'first_circ_id'])):
+  """
+  Constants that vary by our link protocol version.
+
+  :var int version: link protocol version
+  :var stem.client.datatype.Size circ_id_size: circuit identifier field size
+  :var int fixed_cell_length: size of cells with a fixed length
+  :var int first_circ_id: When creating circuits we pick an unused identifier
+    from a range that's determined by our link protocol.
+  """
+
+  @staticmethod
+  def for_version(version):
+    if isinstance(version, LinkProtocol):
+      return version  # already a LinkProtocol
+    elif isinstance(version, int):
+      circ_id_size = Size.LONG if version > 3 else Size.SHORT
+      fixed_cell_len = 514 if version > 3 else 512
+      first_circ_id = 0x80000000 if version > 3 else 0x01
+
+      return LinkProtocol(version, circ_id_size, fixed_cell_len, first_circ_id)
+    else:
+      raise TypeError('LinkProtocol.for_version() should receiving an int, not %s' % type(version).__name__)
+
+  def __int__(self):
+    return self.version
+
+
 class Field(object):
   """
   Packable and unpackable datatype.





More information about the tor-commits mailing list