
commit 331127483838c416b279d30cf041deb678984ab2 Author: Damian Johnson <atagar@torproject.org> Date: Tue Feb 6 09:53:08 2018 -0800 Make Relay thread safe --- stem/client/__init__.py | 91 +++++++++++++++++++++++++++---------------------- 1 file changed, 51 insertions(+), 40 deletions(-) diff --git a/stem/client/__init__.py b/stem/client/__init__.py index 8d34b626..5a9d09e6 100644 --- a/stem/client/__init__.py +++ b/stem/client/__init__.py @@ -16,11 +16,18 @@ a wrapper for :class:`~stem.socket.RelaySocket`, much the same way as | |- is_alive - reports if our connection is open or closed |- connection_time - time when we last connected or disconnected - +- close - shuts down our connection + |- close - shuts down our connection + | + +- create_circuit - establishes a new circuit + + Circuit - Circuit we've established through a relay. + |- send - sends a message through this circuit + +- close - closes this circuit """ import copy import hashlib +import threading import stem import stem.client.cell @@ -47,6 +54,7 @@ class Relay(object): def __init__(self, orport, link_protocol): self.link_protocol = link_protocol self._orport = orport + self._orport_lock = threading.RLock() self._circuits = {} @staticmethod @@ -138,40 +146,42 @@ class Relay(object): :func:`~stem.socket.BaseSocket.close` method. """ - return self._orport.close() + with self._orport_lock: + return self._orport.close() def create_circuit(self): """ Establishes a new circuit. """ - # Find an unused circuit id. Since we're initiating the circuit we pick any - # value from a range that's determined by our link protocol. + with self._orport_lock: + # Find an unused circuit id. Since we're initiating the circuit we pick any + # value from a range that's determined by our link protocol. - circ_id = 0x80000000 if self.link_protocol > 3 else 0x01 + circ_id = 0x80000000 if self.link_protocol > 3 else 0x01 - while circ_id in self._circuits: - circ_id += 1 + while circ_id in self._circuits: + circ_id += 1 - create_fast_cell = stem.client.cell.CreateFastCell(circ_id) - self._orport.send(create_fast_cell.pack(self.link_protocol)) + create_fast_cell = stem.client.cell.CreateFastCell(circ_id) + self._orport.send(create_fast_cell.pack(self.link_protocol)) - response = stem.client.cell.Cell.unpack(self._orport.recv(), self.link_protocol) - created_fast_cells = filter(lambda cell: isinstance(cell, stem.client.cell.CreatedFastCell), response) + response = stem.client.cell.Cell.unpack(self._orport.recv(), self.link_protocol) + created_fast_cells = filter(lambda cell: isinstance(cell, stem.client.cell.CreatedFastCell), response) - if not created_fast_cells: - raise ValueError('We should get a CREATED_FAST response from a CREATE_FAST request') + if not created_fast_cells: + raise ValueError('We should get a CREATED_FAST response from a CREATE_FAST request') - created_fast_cell = created_fast_cells[0] - kdf = KDF.from_value(create_fast_cell.key_material + created_fast_cell.key_material) + created_fast_cell = created_fast_cells[0] + kdf = KDF.from_value(create_fast_cell.key_material + created_fast_cell.key_material) - if created_fast_cell.derivative_key != kdf.key_hash: - raise ValueError('Remote failed to prove that it knows our shared key') + if created_fast_cell.derivative_key != kdf.key_hash: + raise ValueError('Remote failed to prove that it knows our shared key') - circ = Circuit(self, circ_id, kdf) - self._circuits[circ.id] = circ + circ = Circuit(self, circ_id, kdf) + self._circuits[circ.id] = circ - return circ + return circ def __enter__(self): return self @@ -219,30 +229,31 @@ class Circuit(object): """ # TODO: move RelayCommand to this base module? - # TODO: add lock - orig_digest = self.forward_digest.copy() - orig_key = copy.copy(self.forward_key) + with self.relay._orport_lock: + orig_digest = self.forward_digest.copy() + orig_key = copy.copy(self.forward_key) - try: - cell = stem.client.cell.RelayCell(self.id, command, data, 0, stream_id) - payload_without_digest = cell.pack(self.relay.link_protocol)[3:] - self.forward_digest.update(payload_without_digest) + try: + cell = stem.client.cell.RelayCell(self.id, command, data, 0, stream_id) + payload_without_digest = cell.pack(self.relay.link_protocol)[3:] + self.forward_digest.update(payload_without_digest) - cell = stem.client.cell.RelayCell(self.id, command, data, self.forward_digest, stream_id) - header, payload = split(cell.pack(self.relay.link_protocol), 3) - encrypted_payload = header + self.forward_key.update(payload) + cell = stem.client.cell.RelayCell(self.id, command, data, self.forward_digest, stream_id) + header, payload = split(cell.pack(self.relay.link_protocol), 3) + encrypted_payload = header + self.forward_key.update(payload) - self.relay._orport.send(encrypted_payload) - reply = next(stem.client.cell.Cell.unpack(self.relay._orport.recv(), self.relay.link_protocol)) + self.relay._orport.send(encrypted_payload) + reply = next(stem.client.cell.Cell.unpack(self.relay._orport.recv(), self.relay.link_protocol)) - decrypted = self.backward_key.update(reply.pack(3)[3:]) - return stem.client.cell.RelayCell._unpack(decrypted, self.id, 3) - except: - self.forward_digest = orig_digest - self.forward_key = orig_key - raise + decrypted = self.backward_key.update(reply.pack(3)[3:]) + return stem.client.cell.RelayCell._unpack(decrypted, self.id, 3) + except: + self.forward_digest = orig_digest + self.forward_key = orig_key + raise def close(self): - self.relay._orport.send(stem.client.cell.DestroyCell(self.id).pack(self.relay.link_protocol)) - del self.relay._circuits[self.id] + with self.relay._orport_lock: + self.relay._orport.send(stem.client.cell.DestroyCell(self.id).pack(self.relay.link_protocol)) + del self.relay._circuits[self.id]
participants (1)
-
atagar@torproject.org