[tor-commits] [stem/master] Switch to asyncio locks in socket classes to make them usable too

atagar at torproject.org atagar at torproject.org
Thu Jul 16 01:29:00 UTC 2020


commit cad4b7204b7bbec52f3eb5b04811f332a12aa85d
Author: Illia Volochii <illia.volochii at gmail.com>
Date:   Sun May 24 02:23:16 2020 +0300

    Switch to asyncio locks in socket classes to make them usable too
---
 stem/control.py | 18 +++++++----
 stem/socket.py  | 93 ++++++++++++++++++++++++++++-----------------------------
 2 files changed, 58 insertions(+), 53 deletions(-)

diff --git a/stem/control.py b/stem/control.py
index 6ca6c23b..293d4bd3 100644
--- a/stem/control.py
+++ b/stem/control.py
@@ -858,7 +858,7 @@ class BaseController(_BaseControllerSocketMixin):
 
   async def _connect(self) -> None:
     self._create_loop_tasks()
-    self._notify_status_listeners(State.INIT)
+    await self._notify_status_listeners(State.INIT, acquire_send_lock=False)
     await self._socket_connect()
     self._is_authenticated = False
 
@@ -879,7 +879,7 @@ class BaseController(_BaseControllerSocketMixin):
     if event_loop_task:
       await event_loop_task
 
-    self._notify_status_listeners(State.CLOSED)
+    await self._notify_status_listeners(State.CLOSED, acquire_send_lock=False)
 
     await self._socket_close()
 
@@ -888,7 +888,7 @@ class BaseController(_BaseControllerSocketMixin):
 
     self._is_authenticated = True
 
-  def _notify_status_listeners(self, state: 'stem.control.State') -> None:
+  async def _notify_status_listeners(self, state: 'stem.control.State', acquire_send_lock: bool = True) -> None:
     """
     Informs our status listeners that a state change occurred.
 
@@ -898,7 +898,10 @@ class BaseController(_BaseControllerSocketMixin):
     # Any changes to our is_alive() state happen under the send lock, so we
     # need to have it to ensure it doesn't change beneath us.
 
-    with self._socket._get_send_lock():
+    send_lock = self._socket._get_send_lock()
+    try:
+      if acquire_send_lock:
+        await send_lock.acquire()
       with self._status_listeners_lock:
         # States imply that our socket is either alive or not, which may not
         # hold true when multiple events occur in quick succession. For
@@ -931,6 +934,9 @@ class BaseController(_BaseControllerSocketMixin):
             self._state_change_threads.append(notice_thread)
           else:
             listener(self, state, change_timestamp)
+    finally:
+      if acquire_send_lock:
+        send_lock.release()
 
   def _create_loop_tasks(self) -> None:
     """
@@ -1064,10 +1070,10 @@ class AsyncController(BaseController):
 
     super(AsyncController, self).__init__(control_socket, is_authenticated)
 
-    def _sighup_listener(event: stem.response.events.SignalEvent) -> None:
+    async def _sighup_listener(event: stem.response.events.SignalEvent) -> None:
       if event.signal == Signal.RELOAD:
         self.clear_cache()
-        self._notify_status_listeners(State.RESET)
+        await self._notify_status_listeners(State.RESET)
 
     def _confchanged_listener(event: stem.response.events.ConfChangedEvent) -> None:
       if self.is_caching_enabled():
diff --git a/stem/socket.py b/stem/socket.py
index ff99c5b1..a0e5bf55 100644
--- a/stem/socket.py
+++ b/stem/socket.py
@@ -86,7 +86,6 @@ import re
 import socket
 import ssl
 import sys
-import threading
 import time
 
 import stem.response
@@ -119,8 +118,14 @@ class BaseSocket(object):
     # so prevents deadlock where we block writes because we're waiting to read
     # a message that isn't coming.
 
-    self._send_lock = threading.RLock()
-    self._recv_lock = threading.RLock()
+    # This class is often initialized in a thread whose event loop differs
+    # from the one where it will be used. Asyncio locks are bound to the event
+    # loop of the thread in which they are created, so we create them lazily
+    # in _get_send_lock() and _get_recv_lock() the first time they are
+    # needed.
+
+    self._send_lock = None  # type: Optional[asyncio.Lock]
+    self._recv_lock = None  # type: Optional[asyncio.Lock]
 
   def is_alive(self) -> bool:
     """
@@ -173,15 +178,15 @@ class BaseSocket(object):
     :raises: :class:`stem.SocketError` if unable to make a socket
     """
 
-    with self._send_lock:
+    async with self._get_send_lock():
       # Closes the socket if we're currently attached to one. Once we're no
       # longer alive it'll be safe to acquire the recv lock because recv()
       # calls no longer block (raising SocketClosed instead).
 
       if self.is_alive():
-        await self.close()
+        await self._close_wo_send_lock()
 
-      with self._recv_lock:
+      async with self._get_recv_lock():
         self._reader, self._writer = await self._open_connection()
         self._is_alive = True
         self._connection_time = time.time()
@@ -201,32 +206,35 @@ class BaseSocket(object):
     Shuts down the socket. If it's already closed then this is a no-op.
     """
 
-    with self._send_lock:
-      # Function is idempotent with one exception: we notify _close() if this
-      # is causing our is_alive() state to change.
+    async with self._get_send_lock():
+      await self._close_wo_send_lock()
 
-      is_change = self.is_alive()
+  async def _close_wo_send_lock(self) -> None:
+    # This function is idempotent with one exception: we notify _close() if
+    # this call changes our is_alive() state.
 
-      if self._writer:
-        self._writer.close()
-        # `StreamWriter.wait_closed` was added in Python 3.7.
-        if sys.version_info >= (3, 7):
-          await self._writer.wait_closed()
+    is_change = self.is_alive()
 
-      self._reader = None
-      self._writer = None
-      self._is_alive = False
-      self._connection_time = time.time()
+    if self._writer:
+      self._writer.close()
+      # `StreamWriter.wait_closed` was added in Python 3.7.
+      if sys.version_info >= (3, 7):
+        await self._writer.wait_closed()
+
+    self._reader = None
+    self._writer = None
+    self._is_alive = False
+    self._connection_time = time.time()
 
-      if is_change:
-        await self._close()
+    if is_change:
+      await self._close()
 
   async def _send(self, message: Union[bytes, str], handler: Callable[[asyncio.StreamWriter, Union[bytes, str]], Awaitable[None]]) -> None:
     """
     Send message in a thread safe manner.
     """
 
-    with self._send_lock:
+    async with self._get_send_lock():
       try:
         if not self.is_alive():
           raise stem.SocketClosed()
@@ -237,7 +245,7 @@ class BaseSocket(object):
         # everything down
 
         if self.is_alive():
-          await self.close()
+          await self._close_wo_send_lock()
 
         raise
 
@@ -254,8 +262,8 @@ class BaseSocket(object):
     Receives a message in a thread safe manner.
     """
 
-    with self._recv_lock:
-      try:
+    try:
+      async with self._get_recv_lock():
         # makes a temporary reference to the _reader because connect()
         # and close() may set or unset it
 
@@ -265,41 +273,32 @@ class BaseSocket(object):
           raise stem.SocketClosed()
 
         return await handler(my_reader)
-      except stem.SocketClosed:
-        # If recv_message raises a SocketClosed then we should properly shut
-        # everything down. However, there's a couple cases where this will
-        # cause deadlock...
-        #
-        # * This SocketClosed was *caused by* a close() call, which is joining
-        #   on our thread.
-        #
-        # * A send() call that's currently in flight is about to call close(),
-        #   also attempting to join on us.
-        #
-        # To resolve this we make a non-blocking call to acquire the send lock.
-        # If we get it then great, we can close safely. If not then one of the
-        # above are in progress and we leave the close to them.
-
-        if self.is_alive():
-          if self._send_lock.acquire(False):
-            await self.close()
-            self._send_lock.release()
+    except stem.SocketClosed:
+      if self.is_alive():
+        await self.close()
 
-        raise
+      raise
 
-  def _get_send_lock(self) -> threading.RLock:
+  def _get_send_lock(self) -> asyncio.Lock:
     """
     The send lock is useful to classes that interact with us at a deep level
     because it's used to lock :func:`stem.socket.ControlSocket.connect` /
     :func:`stem.socket.BaseSocket.close`, and by extension our
     :func:`stem.socket.BaseSocket.is_alive` state changes.
 
-    :returns: **threading.RLock** that governs sending messages to our socket
+    :returns: **asyncio.Lock** that governs sending messages to our socket
       and state changes
     """
 
+    if self._send_lock is None:
+      self._send_lock = asyncio.Lock()
     return self._send_lock
 
+  def _get_recv_lock(self) -> asyncio.Lock:
+    if self._recv_lock is None:
+      self._recv_lock = asyncio.Lock()
+    return self._recv_lock
+
   async def __aenter__(self) -> 'stem.socket.BaseSocket':
     return self
 





More information about the tor-commits mailing list