commit 331127483838c416b279d30cf041deb678984ab2 Author: Damian Johnson <atagar@torproject.org> Date: Tue Feb 6 09:53:08 2018 -0800
Make Relay thread safe --- stem/client/__init__.py | 91 +++++++++++++++++++++++++++---------------------- 1 file changed, 51 insertions(+), 40 deletions(-)
diff --git a/stem/client/__init__.py b/stem/client/__init__.py index 8d34b626..5a9d09e6 100644 --- a/stem/client/__init__.py +++ b/stem/client/__init__.py @@ -16,11 +16,18 @@ a wrapper for :class:`~stem.socket.RelaySocket`, much the same way as | |- is_alive - reports if our connection is open or closed |- connection_time - time when we last connected or disconnected - +- close - shuts down our connection + |- close - shuts down our connection + | + +- create_circuit - establishes a new circuit + + Circuit - Circuit we've established through a relay. + |- send - sends a message through this circuit + +- close - closes this circuit """
import copy import hashlib +import threading
import stem import stem.client.cell @@ -47,6 +54,7 @@ class Relay(object): def __init__(self, orport, link_protocol): self.link_protocol = link_protocol self._orport = orport + self._orport_lock = threading.RLock() self._circuits = {}
@staticmethod @@ -138,40 +146,42 @@ class Relay(object): :func:`~stem.socket.BaseSocket.close` method. """
- return self._orport.close() + with self._orport_lock: + return self._orport.close()
def create_circuit(self): """ Establishes a new circuit. """
- # Find an unused circuit id. Since we're initiating the circuit we pick any - # value from a range that's determined by our link protocol. + with self._orport_lock: + # Find an unused circuit id. Since we're initiating the circuit we pick any + # value from a range that's determined by our link protocol.
- circ_id = 0x80000000 if self.link_protocol > 3 else 0x01 + circ_id = 0x80000000 if self.link_protocol > 3 else 0x01
- while circ_id in self._circuits: - circ_id += 1 + while circ_id in self._circuits: + circ_id += 1
- create_fast_cell = stem.client.cell.CreateFastCell(circ_id) - self._orport.send(create_fast_cell.pack(self.link_protocol)) + create_fast_cell = stem.client.cell.CreateFastCell(circ_id) + self._orport.send(create_fast_cell.pack(self.link_protocol))
- response = stem.client.cell.Cell.unpack(self._orport.recv(), self.link_protocol) - created_fast_cells = filter(lambda cell: isinstance(cell, stem.client.cell.CreatedFastCell), response) + response = stem.client.cell.Cell.unpack(self._orport.recv(), self.link_protocol) + created_fast_cells = filter(lambda cell: isinstance(cell, stem.client.cell.CreatedFastCell), response)
- if not created_fast_cells: - raise ValueError('We should get a CREATED_FAST response from a CREATE_FAST request') + if not created_fast_cells: + raise ValueError('We should get a CREATED_FAST response from a CREATE_FAST request')
- created_fast_cell = created_fast_cells[0] - kdf = KDF.from_value(create_fast_cell.key_material + created_fast_cell.key_material) + created_fast_cell = created_fast_cells[0] + kdf = KDF.from_value(create_fast_cell.key_material + created_fast_cell.key_material)
- if created_fast_cell.derivative_key != kdf.key_hash: - raise ValueError('Remote failed to prove that it knows our shared key') + if created_fast_cell.derivative_key != kdf.key_hash: + raise ValueError('Remote failed to prove that it knows our shared key')
- circ = Circuit(self, circ_id, kdf) - self._circuits[circ.id] = circ + circ = Circuit(self, circ_id, kdf) + self._circuits[circ.id] = circ
- return circ + return circ
def __enter__(self): return self @@ -219,30 +229,31 @@ class Circuit(object): """
# TODO: move RelayCommand to this base module? - # TODO: add lock
- orig_digest = self.forward_digest.copy() - orig_key = copy.copy(self.forward_key) + with self.relay._orport_lock: + orig_digest = self.forward_digest.copy() + orig_key = copy.copy(self.forward_key)
- try: - cell = stem.client.cell.RelayCell(self.id, command, data, 0, stream_id) - payload_without_digest = cell.pack(self.relay.link_protocol)[3:] - self.forward_digest.update(payload_without_digest) + try: + cell = stem.client.cell.RelayCell(self.id, command, data, 0, stream_id) + payload_without_digest = cell.pack(self.relay.link_protocol)[3:] + self.forward_digest.update(payload_without_digest)
- cell = stem.client.cell.RelayCell(self.id, command, data, self.forward_digest, stream_id) - header, payload = split(cell.pack(self.relay.link_protocol), 3) - encrypted_payload = header + self.forward_key.update(payload) + cell = stem.client.cell.RelayCell(self.id, command, data, self.forward_digest, stream_id) + header, payload = split(cell.pack(self.relay.link_protocol), 3) + encrypted_payload = header + self.forward_key.update(payload)
- self.relay._orport.send(encrypted_payload) - reply = next(stem.client.cell.Cell.unpack(self.relay._orport.recv(), self.relay.link_protocol)) + self.relay._orport.send(encrypted_payload) + reply = next(stem.client.cell.Cell.unpack(self.relay._orport.recv(), self.relay.link_protocol))
- decrypted = self.backward_key.update(reply.pack(3)[3:]) - return stem.client.cell.RelayCell._unpack(decrypted, self.id, 3) - except: - self.forward_digest = orig_digest - self.forward_key = orig_key - raise + decrypted = self.backward_key.update(reply.pack(3)[3:]) + return stem.client.cell.RelayCell._unpack(decrypted, self.id, 3) + except: + self.forward_digest = orig_digest + self.forward_key = orig_key + raise
def close(self): - self.relay._orport.send(stem.client.cell.DestroyCell(self.id).pack(self.relay.link_protocol)) - del self.relay._circuits[self.id] + with self.relay._orport_lock: + self.relay._orport.send(stem.client.cell.DestroyCell(self.id).pack(self.relay.link_protocol)) + del self.relay._circuits[self.id]