[tor-commits] [stem/master] Move circuit construction into the Relay class

atagar at torproject.org atagar at torproject.org
Wed Feb 7 19:44:51 UTC 2018


commit 2ca71ddb1d91154b7dd4a9c8e0c9e3bac06be204
Author: Damian Johnson <atagar at torproject.org>
Date:   Sun Feb 4 14:49:30 2018 -0800

    Move circuit construction into the Relay class
    
    Preferably I'd like to keep all socket activity within the Relay class. We
    might need to bend this in practice but let's first give it a try.
---
 stem/client/__init__.py |  68 +++++++++++++++++++++++++++++++-
 stem/client/cell.py     | 100 +++++++++++++++++++++++-------------------------
 stem/client/datatype.py |  51 ------------------------
 3 files changed, 114 insertions(+), 105 deletions(-)

diff --git a/stem/client/__init__.py b/stem/client/__init__.py
index 85473685..fdb436b6 100644
--- a/stem/client/__init__.py
+++ b/stem/client/__init__.py
@@ -19,12 +19,14 @@ a wrapper for :class:`~stem.socket.RelaySocket`, much the same way as
     +- close - shuts down our connection
 """
 
+import hashlib
+
 import stem
 import stem.client.cell
 import stem.socket
 import stem.util.connection
 
-from stem.client.datatype import AddrType, Address
+from stem.client.datatype import ZERO, AddrType, Address, KDF
 
 __all__ = [
   'cell',
@@ -44,6 +46,7 @@ class Relay(object):
   def __init__(self, orport, link_protocol):
     self.link_protocol = link_protocol
     self._orport = orport
+    self._circuits = {}
 
   @staticmethod
   def connect(address, port, link_protocols = DEFAULT_LINK_PROTOCOLS):
@@ -136,8 +139,71 @@ class Relay(object):
 
     return self._orport.close()
 
+  def create_circuit(self):
+    """
+    Establishes a new circuit using a CREATE_FAST handshake.
+
+    :returns: :class:`~stem.client.Circuit` that's been established
+
+    :raises: **ValueError** if we don't get a CREATED_FAST response or it fails key verification
+    """
+    # As the initiator we can pick any unused id from a range set by our link protocol.
+    circ_id = 0x80000000 if self.link_protocol > 3 else 0x01
+
+    while circ_id in self._circuits:
+      circ_id += 1
+
+    create_fast_cell = stem.client.cell.CreateFastCell(circ_id)
+    self._orport.send(create_fast_cell.pack(self.link_protocol))
+
+    # A list, not filter(), which is lazy (always truthy, unindexable) on python 3.
+    response = stem.client.cell.Cell.unpack(self._orport.recv(), self.link_protocol)
+    created_fast_cells = [cell for cell in response if isinstance(cell, stem.client.cell.CreatedFastCell)]
+
+    if not created_fast_cells:
+      raise ValueError('We should get a CREATED_FAST response from a CREATE_FAST request')
+
+    created_fast_cell = created_fast_cells[0]
+    kdf = KDF.from_value(create_fast_cell.key_material + created_fast_cell.key_material)
+    if created_fast_cell.derivative_key != kdf.key_hash:
+      raise ValueError('Remote failed to prove that it knows our shared key')
+
+    circ = Circuit(self, circ_id, kdf)
+    self._circuits[circ.id] = circ
+    return circ
+
   def __enter__(self):
     return self
 
   def __exit__(self, exit_type, value, traceback):
     self.close()
+
+
+class Circuit(object):
+  """
+  Circuit through which requests can be made of a `Tor relay's ORPort
+  <https://gitweb.torproject.org/torspec.git/tree/tor-spec.txt>`_.
+
+  :var stem.client.Relay relay: relay through which this circuit has been established
+  :var int id: circuit id
+  :var hashlib.sha1 forward_digest: digest for forward integrity check
+  :var hashlib.sha1 backward_digest: digest for backward integrity check
+  :var CipherContext forward_key: AES-CTR encryptor keyed with the KDF's forward key
+  :var CipherContext backward_key: AES-CTR decryptor keyed with the KDF's backward key
+  """
+
+  def __init__(self, relay, circ_id, kdf):
+    if not stem.prereq.is_crypto_available():
+      raise ImportError('Circuit construction requires the cryptography module')
+
+    from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
+    from cryptography.hazmat.backends import default_backend
+
+    # AES.block_size is in bits; floor division keeps the byte count an int on python 3
+    ctr = modes.CTR(ZERO * (algorithms.AES.block_size // 8))
+    self.relay = relay
+    self.id = circ_id
+    self.forward_digest = hashlib.sha1(kdf.forward_digest)
+    self.backward_digest = hashlib.sha1(kdf.backward_digest)
+    self.forward_key = Cipher(algorithms.AES(kdf.forward_key), ctr, default_backend()).encryptor()
+    self.backward_key = Cipher(algorithms.AES(kdf.backward_key), ctr, default_backend()).decryptor()
diff --git a/stem/client/cell.py b/stem/client/cell.py
index a018b2a3..ed8932c4 100644
--- a/stem/client/cell.py
+++ b/stem/client/cell.py
@@ -113,16 +113,16 @@ class Cell(object):
 
     raise ValueError("'%s' isn't a valid cell value" % value)
 
-  def pack(self, link_version):
+  def pack(self, link_protocol):
     raise NotImplementedError('Unpacking not yet implemented for %s cells' % type(self).NAME)
 
   @staticmethod
-  def unpack(content, link_version):
+  def unpack(content, link_protocol):
     """
     Unpacks all cells from a response.
 
     :param bytes content: payload to decode
-    :param int link_version: link protocol version
+    :param int link_protocol: link protocol version
 
     :returns: :class:`~stem.client.cell.Cell` generator
 
@@ -132,16 +132,16 @@ class Cell(object):
     """
 
     while content:
-      cell, content = Cell.pop(content, link_version)
+      cell, content = Cell.pop(content, link_protocol)
       yield cell
 
   @staticmethod
-  def pop(content, link_version):
+  def pop(content, link_protocol):
     """
     Unpacks the first cell.
 
     :param bytes content: payload to decode
-    :param int link_version: link protocol version
+    :param int link_protocol: link protocol version
 
     :returns: (:class:`~stem.client.cell.Cell`, remainder) tuple
 
@@ -150,7 +150,7 @@ class Cell(object):
       * NotImplementedError if unable to unpack this cell type
     """
 
-    circ_id, content = Size.SHORT.pop(content) if link_version < 4 else Size.LONG.pop(content)
+    circ_id, content = Size.SHORT.pop(content) if link_protocol < 4 else Size.LONG.pop(content)
     command, content = Size.CHAR.pop(content)
     cls = Cell.by_value(command)
 
@@ -163,10 +163,10 @@ class Cell(object):
       raise ValueError('%s cell should have a payload of %i bytes, but only had %i' % (cls.NAME, payload_len, len(content)))
 
     payload, content = split(content, payload_len)
-    return cls._unpack(payload, circ_id, link_version), content
+    return cls._unpack(payload, circ_id, link_protocol), content
 
   @classmethod
-  def _pack(cls, link_version, payload, circ_id = 0):
+  def _pack(cls, link_protocol, payload, circ_id = 0):
     """
     Provides bytes that can be used on the wire for these cell attributes.
     Format of a properly packed cell depends on if it's fixed or variable
@@ -178,7 +178,7 @@ class Cell(object):
       Variable: [ CircuitID ][ Command ][ Size ][ Payload ]
 
     :param str name: cell command
-    :param int link_version: link protocol version
+    :param int link_protocol: link protocol version
     :param bytes payload: cell payload
     :param int circ_id: circuit id, if a CircuitCell
 
@@ -188,16 +188,10 @@ class Cell(object):
     """
 
     if isinstance(cls, CircuitCell) and circ_id is None:
-      if cls.NAME.startswith('CREATE'):
-        # Since we're initiating the circuit we pick any value from a range
-        # that's determined by our link version.
-
-        circ_id = 0x80000000 if link_version > 3 else 0x01
-      else:
-        raise ValueError('%s cells require a circ_id' % cls.NAME)
+      raise ValueError('%s cells require a circ_id' % cls.NAME)
 
     cell = io.BytesIO()
-    cell.write(Size.LONG.pack(circ_id) if link_version > 3 else Size.SHORT.pack(circ_id))
+    cell.write(Size.LONG.pack(circ_id) if link_protocol > 3 else Size.SHORT.pack(circ_id))
     cell.write(Size.CHAR.pack(cls.VALUE))
     cell.write(b'' if cls.IS_FIXED_SIZE else Size.SHORT.pack(len(payload)))
     cell.write(payload)
@@ -206,7 +200,7 @@ class Cell(object):
 
     if cls.IS_FIXED_SIZE:
       cell_size = cell.seek(0, io.SEEK_END)
-      fixed_cell_len = 514 if link_version > 3 else 512
+      fixed_cell_len = 514 if link_protocol > 3 else 512
 
       if cell_size > fixed_cell_len:
         raise ValueError('Payload of %s is too large (%i bytes), must be less than %i' % (cls.NAME, cell_size, fixed_cell_len))
@@ -216,12 +210,12 @@ class Cell(object):
     return cell.getvalue()
 
   @classmethod
-  def _unpack(cls, content, circ_id, link_version):
+  def _unpack(cls, content, circ_id, link_protocol):
     """
     Subclass implementation for unpacking cell content.
 
     :param bytes content: payload to decode
-    :param int link_version: link protocol version
+    :param int link_protocol: link protocol version
     :param int circ_id: circuit id cell is for
 
     :returns: instance of this cell type
@@ -268,11 +262,11 @@ class PaddingCell(Cell):
 
     self.payload = payload
 
-  def pack(self, link_version):
-    return PaddingCell._pack(link_version, self.payload)
+  def pack(self, link_protocol):
+    return PaddingCell._pack(link_protocol, self.payload)
 
   @classmethod
-  def _unpack(cls, content, circ_id, link_version):
+  def _unpack(cls, content, circ_id, link_protocol):
     return PaddingCell(content)
 
   def __hash__(self):
@@ -325,7 +319,7 @@ class RelayCell(CircuitCell):
     elif stream_id and self.command in STREAM_ID_DISALLOWED:
       raise ValueError('%s relay cells concern the circuit itself and cannot have a stream id' % self.command)
 
-  def pack(self, link_version):
+  def pack(self, link_protocol):
     payload = io.BytesIO()
     payload.write(Size.CHAR.pack(self.command_int))
     payload.write(Size.SHORT.pack(0))  # 'recognized' field
@@ -334,10 +328,10 @@ class RelayCell(CircuitCell):
     payload.write(Size.SHORT.pack(len(self.data)))
     payload.write(self.data)
 
-    return RelayCell._pack(link_version, payload.getvalue(), self.circ_id)
+    return RelayCell._pack(link_protocol, payload.getvalue(), self.circ_id)
 
   @classmethod
-  def _unpack(cls, content, circ_id, link_version):
+  def _unpack(cls, content, circ_id, link_protocol):
     command, content = Size.CHAR.pop(content)
     _, content = Size.SHORT.pop(content)  # 'recognized' field
     stream_id, content = Size.SHORT.pop(content)
@@ -367,11 +361,11 @@ class DestroyCell(CircuitCell):
     super(DestroyCell, self).__init__(circ_id)
     self.reason, self.reason_int = CloseReason.get(reason)
 
-  def pack(self, link_version):
-    return DestroyCell._pack(link_version, Size.CHAR.pack(self.reason_int), self.circ_id)
+  def pack(self, link_protocol):
+    return DestroyCell._pack(link_protocol, Size.CHAR.pack(self.reason_int), self.circ_id)
 
   @classmethod
-  def _unpack(cls, content, circ_id, link_version):
+  def _unpack(cls, content, circ_id, link_protocol):
     content = content.rstrip(ZERO)
 
     if not content:
@@ -406,11 +400,11 @@ class CreateFastCell(CircuitCell):
     super(CreateFastCell, self).__init__(circ_id)
     self.key_material = key_material
 
-  def pack(self, link_version):
-    return CreateFastCell._pack(link_version, self.key_material, self.circ_id)
+  def pack(self, link_protocol):
+    return CreateFastCell._pack(link_protocol, self.key_material, self.circ_id)
 
   @classmethod
-  def _unpack(cls, content, circ_id, link_version):
+  def _unpack(cls, content, circ_id, link_protocol):
     content = content.rstrip(ZERO)
 
     if len(content) != HASH_LEN:
@@ -447,11 +441,11 @@ class CreatedFastCell(CircuitCell):
     self.key_material = key_material
     self.derivative_key = derivative_key
 
-  def pack(self, link_version):
-    return CreatedFastCell._pack(link_version, self.key_material + self.derivative_key, self.circ_id)
+  def pack(self, link_protocol):
+    return CreatedFastCell._pack(link_protocol, self.key_material + self.derivative_key, self.circ_id)
 
   @classmethod
-  def _unpack(cls, content, circ_id, link_version):
+  def _unpack(cls, content, circ_id, link_protocol):
     content = content.rstrip(ZERO)
 
     if len(content) != HASH_LEN * 2:
@@ -477,7 +471,7 @@ class VersionsCell(Cell):
   def __init__(self, versions):
     self.versions = versions
 
-  def pack(self, link_version = None):
+  def pack(self, link_protocol = None):
     # Used for link version negotiation so we don't have that yet. This is fine
     # since VERSION cells avoid most version dependent attributes.
 
@@ -485,14 +479,14 @@ class VersionsCell(Cell):
     return VersionsCell._pack(2, payload)
 
   @classmethod
-  def _unpack(cls, content, circ_id, link_version):
-    link_versions = []
+  def _unpack(cls, content, circ_id, link_protocol):
+    link_protocols = []
 
     while content:
       version, content = Size.SHORT.pop(content)
-      link_versions.append(version)
+      link_protocols.append(version)
 
-    return VersionsCell(link_versions)
+    return VersionsCell(link_protocols)
 
   def __hash__(self):
     return _hash_attr(self, 'versions')
@@ -516,7 +510,7 @@ class NetinfoCell(Cell):
     self.receiver_address = receiver_address
     self.sender_addresses = sender_addresses
 
-  def pack(self, link_version):
+  def pack(self, link_protocol):
     payload = io.BytesIO()
     payload.write(Size.LONG.pack(int(datetime_to_unix(self.timestamp))))
     payload.write(self.receiver_address.pack())
@@ -525,10 +519,10 @@ class NetinfoCell(Cell):
     for addr in self.sender_addresses:
       payload.write(addr.pack())
 
-    return NetinfoCell._pack(link_version, payload.getvalue())
+    return NetinfoCell._pack(link_protocol, payload.getvalue())
 
   @classmethod
-  def _unpack(cls, content, circ_id, link_version):
+  def _unpack(cls, content, circ_id, link_protocol):
     if len(content) < 4:
       raise ValueError('NETINFO cell expected to start with a timestamp')
 
@@ -591,11 +585,11 @@ class VPaddingCell(Cell):
 
     self.payload = payload
 
-  def pack(self, link_version):
-    return VPaddingCell._pack(link_version, self.payload)
+  def pack(self, link_protocol):
+    return VPaddingCell._pack(link_protocol, self.payload)
 
   @classmethod
-  def _unpack(cls, content, circ_id, link_version):
+  def _unpack(cls, content, circ_id, link_protocol):
     return VPaddingCell(payload = content)
 
   def __hash__(self):
@@ -616,11 +610,11 @@ class CertsCell(Cell):
   def __init__(self, certs):
     self.certificates = certs
 
-  def pack(self, link_version):
-    return CertsCell._pack(link_version, Size.CHAR.pack(len(self.certificates)) + ''.join([cert.pack() for cert in self.certificates]))
+  def pack(self, link_protocol):
+    return CertsCell._pack(link_protocol, Size.CHAR.pack(len(self.certificates)) + ''.join([cert.pack() for cert in self.certificates]))
 
   @classmethod
-  def _unpack(cls, content, circ_id, link_version):
+  def _unpack(cls, content, circ_id, link_protocol):
     cert_count, content = Size.CHAR.pop(content)
     certs = []
 
@@ -659,7 +653,7 @@ class AuthChallengeCell(Cell):
     self.challenge = challenge
     self.methods = methods
 
-  def pack(self, link_version):
+  def pack(self, link_protocol):
     payload = io.BytesIO()
     payload.write(self.challenge)
     payload.write(Size.SHORT.pack(len(self.methods)))
@@ -667,10 +661,10 @@ class AuthChallengeCell(Cell):
     for method in self.methods:
       payload.write(Size.SHORT.pack(method))
 
-    return AuthChallengeCell._pack(link_version, payload.getvalue())
+    return AuthChallengeCell._pack(link_protocol, payload.getvalue())
 
   @classmethod
-  def _unpack(cls, content, circ_id, link_version):
+  def _unpack(cls, content, circ_id, link_protocol):
     if len(content) < AUTH_CHALLENGE_SIZE + 2:
       raise ValueError('AUTH_CHALLENGE payload should be at least 34 bytes, but was %i' % len(content))
 
diff --git a/stem/client/datatype.py b/stem/client/datatype.py
index f5805a2a..8b4c8e64 100644
--- a/stem/client/datatype.py
+++ b/stem/client/datatype.py
@@ -476,57 +476,6 @@ class KDF(collections.namedtuple('KDF', ['key_hash', 'forward_digest', 'backward
     return KDF(key_hash, forward_digest, backward_digest, forward_key, backward_key)
 
 
-class Circuit(collections.namedtuple('Circuit', ['socket', 'id', 'forward_digest', 'backward_digest', 'forward_key', 'backward_key'])):
-  """
-  Circuit through which requests can be made of a `Tor relay's ORPort
-  <https://gitweb.torproject.org/torspec.git/tree/tor-spec.txt>`_.
-
-  :var stem.socket.RelaySocket socket: socket through which this circuit has been established
-  :var int id: circuit id
-  :var hashlib.sha1 forward_digest: digest for forward integrity check
-  :var hashlib.sha1 backward_digest: digest for backward integrity check
-  :var bytes forward_key: forward encryption key
-  :var bytes backward_key: backward encryption key
-  """
-
-  @staticmethod
-  def create(relay_socket, circ_id, link_version):
-    """
-    Constructs a new circuit over the given ORPort.
-    """
-
-    if not stem.prereq.is_crypto_available():
-      raise ImportError('Circuit construction requires the cryptography module')
-
-    from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
-    from cryptography.hazmat.backends import default_backend
-
-    create_fast_cell = stem.client.cell.CreateFastCell(circ_id)
-    relay_socket.send(create_fast_cell.pack(link_version))
-
-    response = stem.client.cell.Cell.unpack(relay_socket.recv(), link_version)
-    created_fast_cells = filter(lambda cell: isinstance(cell, stem.client.cell.CreatedFastCell), response)
-
-    if not created_fast_cells:
-      raise ValueError('We should get a CREATED_FAST response from a CREATE_FAST request')
-
-    created_fast_cell = created_fast_cells[0]
-    kdf = KDF.from_value(create_fast_cell.key_material + created_fast_cell.key_material)
-    ctr = modes.CTR(ZERO * (algorithms.AES.block_size / 8))
-
-    if created_fast_cell.derivative_key != kdf.key_hash:
-      raise ValueError('Remote failed to prove that it knows our shared key')
-
-    return Circuit(
-      relay_socket,
-      circ_id,
-      hashlib.sha1(kdf.forward_digest),
-      hashlib.sha1(kdf.backward_digest),
-      Cipher(algorithms.AES(kdf.forward_key), ctr, default_backend()).encryptor(),
-      Cipher(algorithms.AES(kdf.backward_key), ctr, default_backend()).decryptor(),
-    )
-
-
 setattr(Size, 'CHAR', Size('CHAR', 1, '!B'))
 setattr(Size, 'SHORT', Size('SHORT', 2, '!H'))
 setattr(Size, 'LONG', Size('LONG', 4, '!L'))





More information about the tor-commits mailing list