[tor-commits] [stem/master] Make Relay thread safe

atagar at torproject.org atagar at torproject.org
Wed Feb 7 19:44:51 UTC 2018


commit 331127483838c416b279d30cf041deb678984ab2
Author: Damian Johnson <atagar at torproject.org>
Date:   Tue Feb 6 09:53:08 2018 -0800

    Make Relay thread safe
---
 stem/client/__init__.py | 91 +++++++++++++++++++++++++++----------------------
 1 file changed, 51 insertions(+), 40 deletions(-)

diff --git a/stem/client/__init__.py b/stem/client/__init__.py
index 8d34b626..5a9d09e6 100644
--- a/stem/client/__init__.py
+++ b/stem/client/__init__.py
@@ -16,11 +16,18 @@ a wrapper for :class:`~stem.socket.RelaySocket`, much the same way as
     |
     |- is_alive - reports if our connection is open or closed
     |- connection_time - time when we last connected or disconnected
-    +- close - shuts down our connection
+    |- close - shuts down our connection
+    |
+    +- create_circuit - establishes a new circuit
+
+  Circuit - Circuit we've established through a relay.
+    |- send - sends a message through this circuit
+    +- close - closes this circuit
 """
 
 import copy
 import hashlib
+import threading
 
 import stem
 import stem.client.cell
@@ -47,6 +54,7 @@ class Relay(object):
   def __init__(self, orport, link_protocol):
     self.link_protocol = link_protocol
     self._orport = orport
+    self._orport_lock = threading.RLock()
     self._circuits = {}
 
   @staticmethod
@@ -138,40 +146,42 @@ class Relay(object):
     :func:`~stem.socket.BaseSocket.close` method.
     """
 
-    return self._orport.close()
+    with self._orport_lock:
+      return self._orport.close()
 
   def create_circuit(self):
     """
     Establishes a new circuit.
     """
 
-    # Find an unused circuit id. Since we're initiating the circuit we pick any
-    # value from a range that's determined by our link protocol.
+    with self._orport_lock:
+      # Find an unused circuit id. Since we're initiating the circuit we pick any
+      # value from a range that's determined by our link protocol.
 
-    circ_id = 0x80000000 if self.link_protocol > 3 else 0x01
+      circ_id = 0x80000000 if self.link_protocol > 3 else 0x01
 
-    while circ_id in self._circuits:
-      circ_id += 1
+      while circ_id in self._circuits:
+        circ_id += 1
 
-    create_fast_cell = stem.client.cell.CreateFastCell(circ_id)
-    self._orport.send(create_fast_cell.pack(self.link_protocol))
+      create_fast_cell = stem.client.cell.CreateFastCell(circ_id)
+      self._orport.send(create_fast_cell.pack(self.link_protocol))
 
-    response = stem.client.cell.Cell.unpack(self._orport.recv(), self.link_protocol)
-    created_fast_cells = filter(lambda cell: isinstance(cell, stem.client.cell.CreatedFastCell), response)
+      response = stem.client.cell.Cell.unpack(self._orport.recv(), self.link_protocol)
+      created_fast_cells = filter(lambda cell: isinstance(cell, stem.client.cell.CreatedFastCell), response)
 
-    if not created_fast_cells:
-      raise ValueError('We should get a CREATED_FAST response from a CREATE_FAST request')
+      if not created_fast_cells:
+        raise ValueError('We should get a CREATED_FAST response from a CREATE_FAST request')
 
-    created_fast_cell = created_fast_cells[0]
-    kdf = KDF.from_value(create_fast_cell.key_material + created_fast_cell.key_material)
+      created_fast_cell = created_fast_cells[0]
+      kdf = KDF.from_value(create_fast_cell.key_material + created_fast_cell.key_material)
 
-    if created_fast_cell.derivative_key != kdf.key_hash:
-      raise ValueError('Remote failed to prove that it knows our shared key')
+      if created_fast_cell.derivative_key != kdf.key_hash:
+        raise ValueError('Remote failed to prove that it knows our shared key')
 
-    circ = Circuit(self, circ_id, kdf)
-    self._circuits[circ.id] = circ
+      circ = Circuit(self, circ_id, kdf)
+      self._circuits[circ.id] = circ
 
-    return circ
+      return circ
 
   def __enter__(self):
     return self
@@ -219,30 +229,31 @@ class Circuit(object):
     """
 
     # TODO: move RelayCommand to this base module?
-    # TODO: add lock
 
-    orig_digest = self.forward_digest.copy()
-    orig_key = copy.copy(self.forward_key)
+    with self.relay._orport_lock:
+      orig_digest = self.forward_digest.copy()
+      orig_key = copy.copy(self.forward_key)
 
-    try:
-      cell = stem.client.cell.RelayCell(self.id, command, data, 0, stream_id)
-      payload_without_digest = cell.pack(self.relay.link_protocol)[3:]
-      self.forward_digest.update(payload_without_digest)
+      try:
+        cell = stem.client.cell.RelayCell(self.id, command, data, 0, stream_id)
+        payload_without_digest = cell.pack(self.relay.link_protocol)[3:]
+        self.forward_digest.update(payload_without_digest)
 
-      cell = stem.client.cell.RelayCell(self.id, command, data, self.forward_digest, stream_id)
-      header, payload = split(cell.pack(self.relay.link_protocol), 3)
-      encrypted_payload = header + self.forward_key.update(payload)
+        cell = stem.client.cell.RelayCell(self.id, command, data, self.forward_digest, stream_id)
+        header, payload = split(cell.pack(self.relay.link_protocol), 3)
+        encrypted_payload = header + self.forward_key.update(payload)
 
-      self.relay._orport.send(encrypted_payload)
-      reply = next(stem.client.cell.Cell.unpack(self.relay._orport.recv(), self.relay.link_protocol))
+        self.relay._orport.send(encrypted_payload)
+        reply = next(stem.client.cell.Cell.unpack(self.relay._orport.recv(), self.relay.link_protocol))
 
-      decrypted = self.backward_key.update(reply.pack(3)[3:])
-      return stem.client.cell.RelayCell._unpack(decrypted, self.id, 3)
-    except:
-      self.forward_digest = orig_digest
-      self.forward_key = orig_key
-      raise
+        decrypted = self.backward_key.update(reply.pack(3)[3:])
+        return stem.client.cell.RelayCell._unpack(decrypted, self.id, 3)
+      except:
+        self.forward_digest = orig_digest
+        self.forward_key = orig_key
+        raise
 
   def close(self):
-    self.relay._orport.send(stem.client.cell.DestroyCell(self.id).pack(self.relay.link_protocol))
-    del self.relay._circuits[self.id]
+    with self.relay._orport_lock:
+      self.relay._orport.send(stem.client.cell.DestroyCell(self.id).pack(self.relay.link_protocol))
+      del self.relay._circuits[self.id]





More information about the tor-commits mailing list