[tor-commits] [stem/master] Start using asyncio in "stem/async_socket.py"

atagar at torproject.org atagar at torproject.org
Thu Jul 16 01:28:57 UTC 2020


commit 0f727a91abb27e487fad49794a428f78e1582f31
Author: Illia Volochii <illia.volochii at gmail.com>
Date:   Sat Mar 14 22:01:22 2020 +0200

    Start using asyncio in "stem/async_socket.py"
---
 stem/async_socket.py | 175 ++++++++++++++++++---------------------------------
 1 file changed, 62 insertions(+), 113 deletions(-)

diff --git a/stem/async_socket.py b/stem/async_socket.py
index 2ef42dd5..512e6bde 100644
--- a/stem/async_socket.py
+++ b/stem/async_socket.py
@@ -72,6 +72,7 @@ Tor...
 
 from __future__ import absolute_import
 
+import asyncio
 import re
 import socket
 import ssl
@@ -97,7 +98,8 @@ class BaseSocket(object):
   """
 
   def __init__(self):
-    self._socket, self._socket_file = None, None
+    self._reader = None
+    self._writer = None
     self._is_alive = False
     self._connection_time = 0.0  # time when we last connected or disconnected
 
@@ -151,7 +153,7 @@ class BaseSocket(object):
 
     return self._connection_time
 
-  def connect(self):
+  async def connect(self):
     """
     Connects to a new socket, closing our previous one if we're already
     attached.
@@ -165,11 +167,10 @@ class BaseSocket(object):
       # calls no longer block (raising SocketClosed instead).
 
       if self.is_alive():
-        self.close()
+        await self.close()
 
       with self._recv_lock:
-        self._socket = self._make_socket()
-        self._socket_file = self._socket.makefile(mode = 'rwb')
+        self._reader, self._writer = await self._open_connection()
         self._is_alive = True
         self._connection_time = time.time()
 
@@ -179,11 +180,11 @@ class BaseSocket(object):
         # It's safe to retry, so give it another try if it fails.
 
         try:
-          self._connect()
+          await self._connect()
         except stem.SocketError:
-          self._connect()  # single retry
+          await self._connect()  # single retry
 
-  def close(self):
+  async def close(self):
     """
     Shuts down the socket. If it's already closed then this is a no-op.
     """
@@ -194,32 +195,21 @@ class BaseSocket(object):
 
       is_change = self.is_alive()
 
-      if self._socket:
-        # if we haven't yet established a connection then this raises an error
-        # socket.error: [Errno 107] Transport endpoint is not connected
+      if self._writer:
+        self._writer.close()
+        # `StreamWriter.wait_closed` was added in Python 3.7.
+        if hasattr(self._writer, 'wait_closed'):
+          await self._writer.wait_closed()
 
-        try:
-          self._socket.shutdown(socket.SHUT_RDWR)
-        except socket.error:
-          pass
-
-        self._socket.close()
-
-      if self._socket_file:
-        try:
-          self._socket_file.close()
-        except BrokenPipeError:
-          pass
-
-      self._socket = None
-      self._socket_file = None
+      self._reader = None
+      self._writer = None
       self._is_alive = False
       self._connection_time = time.time()
 
       if is_change:
-        self._close()
+        await self._close()
 
-  def _send(self, message, handler):
+  async def _send(self, message, handler):
     """
     Send message in a thread safe manner. Handler is expected to be of the form...
 
@@ -233,17 +223,17 @@ class BaseSocket(object):
         if not self.is_alive():
           raise stem.SocketClosed()
 
-        handler(self._socket, self._socket_file, message)
+        await handler(self._writer, message)
       except stem.SocketClosed:
         # if send_message raises a SocketClosed then we should properly shut
         # everything down
 
         if self.is_alive():
-          self.close()
+          await self.close()
 
         raise
 
-  def _recv(self, handler):
+  async def _recv(self, handler):
     """
     Receives a message in a thread safe manner. Handler is expected to be of the form...
 
@@ -254,15 +244,15 @@ class BaseSocket(object):
 
     with self._recv_lock:
       try:
-        # makes a temporary reference to the _socket_file because connect()
+        # makes a temporary reference to the _reader because connect()
         # and close() may set or unset it
 
-        my_socket, my_socket_file = self._socket, self._socket_file
+        my_reader = self._reader
 
-        if not my_socket or not my_socket_file:
+        if not my_reader:
           raise stem.SocketClosed()
 
-        return handler(my_socket, my_socket_file)
+        return await handler(my_reader)
       except stem.SocketClosed:
         # If recv_message raises a SocketClosed then we should properly shut
         # everything down. However, there's a couple cases where this will
@@ -280,7 +270,7 @@ class BaseSocket(object):
 
         if self.is_alive():
           if self._send_lock.acquire(False):
-            self.close()
+            await self.close()
             self._send_lock.release()
 
         raise
@@ -298,37 +288,27 @@ class BaseSocket(object):
 
     return self._send_lock
 
-  def __enter__(self):
+  async def __aenter__(self):
     return self
 
-  def __exit__(self, exit_type, value, traceback):
-    self.close()
+  async def __aexit__(self, exit_type, value, traceback):
+    await self.close()
 
-  def _connect(self):
+  async def _connect(self):
     """
     Connection callback that can be overwritten by subclasses and wrappers.
     """
 
     pass
 
-  def _close(self):
+  async def _close(self):
     """
     Disconnection callback that can be overwritten by subclasses and wrappers.
     """
 
     pass
 
-  def _make_socket(self):
-    """
-    Constructs and connects new socket. This is implemented by subclasses.
-
-    :returns: **socket.socket** for our configuration
-
-    :raises:
-      * :class:`stem.SocketError` if unable to make a socket
-      * **NotImplementedError** if not implemented by a subclass
-    """
-
+  async def _open_connection(self):
     raise NotImplementedError('Unsupported Operation: this should be implemented by the BaseSocket subclass')
 
 
@@ -344,26 +324,19 @@ class RelaySocket(BaseSocket):
   :var int port: ORPort our socket connects to
   """
 
-  def __init__(self, address = '127.0.0.1', port = 9050, connect = True):
+  def __init__(self, address = '127.0.0.1', port = 9050):
     """
     RelaySocket constructor.
 
     :param str address: ip address of the relay
     :param int port: orport of the relay
-    :param bool connect: connects to the socket if True, leaves it unconnected otherwise
-
-    :raises: :class:`stem.SocketError` if connect is **True** and we're
-      unable to establish a connection
     """
 
     super(RelaySocket, self).__init__()
     self.address = address
     self.port = port
 
-    if connect:
-      self.connect()
-
-  def send(self, message):
+  async def send(self, message):
     """
     Sends a message to the relay's ORPort.
 
@@ -374,9 +347,9 @@ class RelaySocket(BaseSocket):
       * :class:`stem.SocketClosed` if the socket is known to be shut down
     """
 
-    self._send(message, lambda s, sf, msg: _write_to_socket(sf, msg))
+    await self._send(message, _write_to_socket)
 
-  def recv(self, timeout = None):
+  async def recv(self, timeout = None):
     """
     Receives a message from the relay.
 
@@ -390,30 +363,24 @@ class RelaySocket(BaseSocket):
       * :class:`stem.SocketClosed` if the socket closes before we receive a complete message
     """
 
-    def wrapped_recv(s, sf):
+    async def wrapped_recv(reader):
+      read_coroutine = reader.read(1024)
       if timeout is None:
-        return s.recv()
+        return await read_coroutine
       else:
-        s.setblocking(0)
-        s.settimeout(timeout)
-
         try:
-          return s.recv()
-        except (socket.timeout, ssl.SSLError, ssl.SSLWantReadError):
+          return await asyncio.wait_for(read_coroutine, timeout)
+        except (asyncio.TimeoutError, ssl.SSLError, ssl.SSLWantReadError):
           return None
-        finally:
-          s.setblocking(1)
 
-    return self._recv(wrapped_recv)
+    return await self._recv(wrapped_recv)
 
   def is_localhost(self):
     return self.address == '127.0.0.1'
 
-  def _make_socket(self):
+  async def _open_connection(self):
     try:
-      relay_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-      relay_socket.connect((self.address, self.port))
-      return ssl.wrap_socket(relay_socket)
+      return await asyncio.open_connection(self.address, self.port, ssl=ssl.SSLContext())
     except socket.error as exc:
       raise stem.SocketError(exc)
 
@@ -431,7 +398,7 @@ class ControlSocket(BaseSocket):
   def __init__(self):
     super(ControlSocket, self).__init__()
 
-  def send(self, message):
+  async def send(self, message):
     """
     Formats and sends a message to the control socket. For more information see
     the :func:`~stem.socket.send_message` function.
@@ -443,9 +410,9 @@ class ControlSocket(BaseSocket):
       * :class:`stem.SocketClosed` if the socket is known to be shut down
     """
 
-    self._send(message, lambda s, sf, msg: send_message(sf, msg))
+    await self._send(message, send_message)
 
-  def recv(self):
+  async def recv(self):
     """
     Receives a message from the control socket, blocking until we've received
     one. For more information see the :func:`~stem.socket.recv_message` function.
@@ -457,7 +424,7 @@ class ControlSocket(BaseSocket):
       * :class:`stem.SocketClosed` if the socket closes before we receive a complete message
     """
 
-    return self._recv(lambda s, sf: recv_message(sf))
+    return await self._recv(recv_message)
 
 
 class ControlPort(ControlSocket):
@@ -469,33 +436,24 @@ class ControlPort(ControlSocket):
   :var int port: ControlPort our socket connects to
   """
 
-  def __init__(self, address = '127.0.0.1', port = 9051, connect = True):
+  def __init__(self, address = '127.0.0.1', port = 9051):
     """
     ControlPort constructor.
 
     :param str address: ip address of the controller
     :param int port: port number of the controller
-    :param bool connect: connects to the socket if True, leaves it unconnected otherwise
-
-    :raises: :class:`stem.SocketError` if connect is **True** and we're
-      unable to establish a connection
     """
 
     super(ControlPort, self).__init__()
     self.address = address
     self.port = port
 
-    if connect:
-      self.connect()
-
   def is_localhost(self):
     return self.address == '127.0.0.1'
 
-  def _make_socket(self):
+  async def _open_connection(self):
     try:
-      control_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-      control_socket.connect((self.address, self.port))
-      return control_socket
+      return await asyncio.open_connection(self.address, self.port)
     except socket.error as exc:
       raise stem.SocketError(exc)
 
@@ -508,36 +466,27 @@ class ControlSocketFile(ControlSocket):
   :var str path: filesystem path of the socket we connect to
   """
 
-  def __init__(self, path = '/var/run/tor/control', connect = True):
+  def __init__(self, path = '/var/run/tor/control'):
     """
     ControlSocketFile constructor.
 
     :param str path: path where the control socket is located
-    :param bool connect: connects to the socket if True, leaves it unconnected otherwise
-
-    :raises: :class:`stem.SocketError` if connect is **True** and we're
-      unable to establish a connection
     """
 
     super(ControlSocketFile, self).__init__()
     self.path = path
 
-    if connect:
-      self.connect()
-
   def is_localhost(self):
     return True
 
-  def _make_socket(self):
+  async def _open_connection(self):
     try:
-      control_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
-      control_socket.connect(self.path)
-      return control_socket
+      return await asyncio.open_unix_connection(self.path)
     except socket.error as exc:
       raise stem.SocketError(exc)
 
 
-def send_message(control_file, message, raw = False):
+async def send_message(writer, message, raw = False):
   """
   Sends a message to the control socket, adding the expected formatting for
   single versus multi-line messages. Neither message type should contain an
@@ -572,7 +521,7 @@ def send_message(control_file, message, raw = False):
   if not raw:
     message = send_formatting(message)
 
-  _write_to_socket(control_file, message)
+  await _write_to_socket(writer, message)
 
   if log.is_tracing():
     log_message = message.replace('\r\n', '\n').rstrip()
@@ -580,10 +529,10 @@ def send_message(control_file, message, raw = False):
     log.trace('Sent to tor:%s%s' % (msg_div, log_message))
 
 
-def _write_to_socket(socket_file, message):
+async def _write_to_socket(writer, message):
   try:
-    socket_file.write(stem.util.str_tools._to_bytes(message))
-    socket_file.flush()
+    writer.write(stem.util.str_tools._to_bytes(message))
+    await writer.drain()
   except socket.error as exc:
     log.info('Failed to send: %s' % exc)
 
@@ -603,7 +552,7 @@ def _write_to_socket(socket_file, message):
     raise stem.SocketClosed('file has been closed')
 
 
-def recv_message(control_file, arrived_at = None):
+async def recv_message(reader, arrived_at = None):
   """
   Pulls from a control socket until we either have a complete message or
   encounter a problem.
@@ -623,7 +572,7 @@ def recv_message(control_file, arrived_at = None):
 
   while True:
     try:
-      line = control_file.readline()
+      line = await reader.readline()
     except AttributeError:
       # if the control_file has been closed then we will receive:
       # AttributeError: 'NoneType' object has no attribute 'recv'
@@ -689,7 +638,7 @@ def recv_message(control_file, arrived_at = None):
 
       while True:
         try:
-          line = control_file.readline()
+          line = await reader.readline()
           raw_content += line
         except socket.error as exc:
           log.info(ERROR_MSG % ('SocketClosed', 'received an exception while mid-way through a data reply (exception: "%s", read content: "%s")' % (exc, log.escape(bytes(raw_content)))))





More information about the tor-commits mailing list