commit 0f727a91abb27e487fad49794a428f78e1582f31 Author: Illia Volochii illia.volochii@gmail.com Date: Sat Mar 14 22:01:22 2020 +0200
Start using asyncio in "stem/async_socket.py" --- stem/async_socket.py | 175 ++++++++++++++++++--------------------------------- 1 file changed, 62 insertions(+), 113 deletions(-)
diff --git a/stem/async_socket.py b/stem/async_socket.py index 2ef42dd5..512e6bde 100644 --- a/stem/async_socket.py +++ b/stem/async_socket.py @@ -72,6 +72,7 @@ Tor...
from __future__ import absolute_import
+import asyncio import re import socket import ssl @@ -97,7 +98,8 @@ class BaseSocket(object): """
def __init__(self): - self._socket, self._socket_file = None, None + self._reader = None + self._writer = None self._is_alive = False self._connection_time = 0.0 # time when we last connected or disconnected
@@ -151,7 +153,7 @@ class BaseSocket(object):
return self._connection_time
- def connect(self): + async def connect(self): """ Connects to a new socket, closing our previous one if we're already attached. @@ -165,11 +167,10 @@ class BaseSocket(object): # calls no longer block (raising SocketClosed instead).
if self.is_alive(): - self.close() + await self.close()
with self._recv_lock: - self._socket = self._make_socket() - self._socket_file = self._socket.makefile(mode = 'rwb') + self._reader, self._writer = await self._open_connection() self._is_alive = True self._connection_time = time.time()
@@ -179,11 +180,11 @@ class BaseSocket(object): # It's safe to retry, so give it another try if it fails.
try: - self._connect() + await self._connect() except stem.SocketError: - self._connect() # single retry + await self._connect() # single retry
- def close(self): + async def close(self): """ Shuts down the socket. If it's already closed then this is a no-op. """ @@ -194,32 +195,21 @@ class BaseSocket(object):
is_change = self.is_alive()
- if self._socket: - # if we haven't yet established a connection then this raises an error - # socket.error: [Errno 107] Transport endpoint is not connected + if self._writer: + self._writer.close() + # `StreamWriter.wait_closed` was added in Python 3.7. + if hasattr(self._writer, 'wait_closed'): + await self._writer.wait_closed()
- try: - self._socket.shutdown(socket.SHUT_RDWR) - except socket.error: - pass - - self._socket.close() - - if self._socket_file: - try: - self._socket_file.close() - except BrokenPipeError: - pass - - self._socket = None - self._socket_file = None + self._reader = None + self._writer = None self._is_alive = False self._connection_time = time.time()
if is_change: - self._close() + await self._close()
- def _send(self, message, handler): + async def _send(self, message, handler): """ Send message in a thread safe manner. Handler is expected to be of the form...
@@ -233,17 +223,17 @@ class BaseSocket(object): if not self.is_alive(): raise stem.SocketClosed()
- handler(self._socket, self._socket_file, message) + await handler(self._writer, message) except stem.SocketClosed: # if send_message raises a SocketClosed then we should properly shut # everything down
if self.is_alive(): - self.close() + await self.close()
raise
- def _recv(self, handler): + async def _recv(self, handler): """ Receives a message in a thread safe manner. Handler is expected to be of the form...
@@ -254,15 +244,15 @@ class BaseSocket(object):
with self._recv_lock: try: - # makes a temporary reference to the _socket_file because connect() + # makes a temporary reference to the _reader because connect() # and close() may set or unset it
- my_socket, my_socket_file = self._socket, self._socket_file + my_reader = self._reader
- if not my_socket or not my_socket_file: + if not my_reader: raise stem.SocketClosed()
- return handler(my_socket, my_socket_file) + return await handler(my_reader) except stem.SocketClosed: # If recv_message raises a SocketClosed then we should properly shut # everything down. However, there's a couple cases where this will @@ -280,7 +270,7 @@ class BaseSocket(object):
if self.is_alive(): if self._send_lock.acquire(False): - self.close() + await self.close() self._send_lock.release()
raise @@ -298,37 +288,27 @@ class BaseSocket(object):
return self._send_lock
- def __enter__(self): + async def __aenter__(self): return self
- def __exit__(self, exit_type, value, traceback): - self.close() + async def __aexit__(self, exit_type, value, traceback): + await self.close()
- def _connect(self): + async def _connect(self): """ Connection callback that can be overwritten by subclasses and wrappers. """
pass
- def _close(self): + async def _close(self): """ Disconnection callback that can be overwritten by subclasses and wrappers. """
pass
- def _make_socket(self): - """ - Constructs and connects new socket. This is implemented by subclasses. - - :returns: **socket.socket** for our configuration - - :raises: - * :class:`stem.SocketError` if unable to make a socket - * **NotImplementedError** if not implemented by a subclass - """ - + async def _open_connection(self): raise NotImplementedError('Unsupported Operation: this should be implemented by the BaseSocket subclass')
@@ -344,26 +324,19 @@ class RelaySocket(BaseSocket): :var int port: ORPort our socket connects to """
- def __init__(self, address = '127.0.0.1', port = 9050, connect = True): + def __init__(self, address = '127.0.0.1', port = 9050): """ RelaySocket constructor.
:param str address: ip address of the relay :param int port: orport of the relay - :param bool connect: connects to the socket if True, leaves it unconnected otherwise - - :raises: :class:`stem.SocketError` if connect is **True** and we're - unable to establish a connection """
super(RelaySocket, self).__init__() self.address = address self.port = port
- if connect: - self.connect() - - def send(self, message): + async def send(self, message): """ Sends a message to the relay's ORPort.
@@ -374,9 +347,9 @@ class RelaySocket(BaseSocket): * :class:`stem.SocketClosed` if the socket is known to be shut down """
- self._send(message, lambda s, sf, msg: _write_to_socket(sf, msg)) + await self._send(message, _write_to_socket)
- def recv(self, timeout = None): + async def recv(self, timeout = None): """ Receives a message from the relay.
@@ -390,30 +363,24 @@ class RelaySocket(BaseSocket): * :class:`stem.SocketClosed` if the socket closes before we receive a complete message """
- def wrapped_recv(s, sf): + async def wrapped_recv(reader): + read_coroutine = reader.read(1024) if timeout is None: - return s.recv() + return await read_coroutine else: - s.setblocking(0) - s.settimeout(timeout) - try: - return s.recv() - except (socket.timeout, ssl.SSLError, ssl.SSLWantReadError): + return await asyncio.wait_for(read_coroutine, timeout) + except (asyncio.TimeoutError, ssl.SSLError, ssl.SSLWantReadError): return None - finally: - s.setblocking(1)
- return self._recv(wrapped_recv) + return await self._recv(wrapped_recv)
def is_localhost(self): return self.address == '127.0.0.1'
- def _make_socket(self): + async def _open_connection(self): try: - relay_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - relay_socket.connect((self.address, self.port)) - return ssl.wrap_socket(relay_socket) + return await asyncio.open_connection(self.address, self.port, ssl=ssl.SSLContext()) except socket.error as exc: raise stem.SocketError(exc)
@@ -431,7 +398,7 @@ class ControlSocket(BaseSocket): def __init__(self): super(ControlSocket, self).__init__()
- def send(self, message): + async def send(self, message): """ Formats and sends a message to the control socket. For more information see the :func:`~stem.socket.send_message` function. @@ -443,9 +410,9 @@ class ControlSocket(BaseSocket): * :class:`stem.SocketClosed` if the socket is known to be shut down """
- self._send(message, lambda s, sf, msg: send_message(sf, msg)) + await self._send(message, send_message)
- def recv(self): + async def recv(self): """ Receives a message from the control socket, blocking until we've received one. For more information see the :func:`~stem.socket.recv_message` function. @@ -457,7 +424,7 @@ class ControlSocket(BaseSocket): * :class:`stem.SocketClosed` if the socket closes before we receive a complete message """
- return self._recv(lambda s, sf: recv_message(sf)) + return await self._recv(recv_message)
class ControlPort(ControlSocket): @@ -469,33 +436,24 @@ class ControlPort(ControlSocket): :var int port: ControlPort our socket connects to """
- def __init__(self, address = '127.0.0.1', port = 9051, connect = True): + def __init__(self, address = '127.0.0.1', port = 9051): """ ControlPort constructor.
:param str address: ip address of the controller :param int port: port number of the controller - :param bool connect: connects to the socket if True, leaves it unconnected otherwise - - :raises: :class:`stem.SocketError` if connect is **True** and we're - unable to establish a connection """
super(ControlPort, self).__init__() self.address = address self.port = port
- if connect: - self.connect() - def is_localhost(self): return self.address == '127.0.0.1'
- def _make_socket(self): + async def _open_connection(self): try: - control_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - control_socket.connect((self.address, self.port)) - return control_socket + return await asyncio.open_connection(self.address, self.port) except socket.error as exc: raise stem.SocketError(exc)
@@ -508,36 +466,27 @@ class ControlSocketFile(ControlSocket): :var str path: filesystem path of the socket we connect to """
- def __init__(self, path = '/var/run/tor/control', connect = True): + def __init__(self, path = '/var/run/tor/control'): """ ControlSocketFile constructor.
:param str socket_path: path where the control socket is located - :param bool connect: connects to the socket if True, leaves it unconnected otherwise - - :raises: :class:`stem.SocketError` if connect is **True** and we're - unable to establish a connection """
super(ControlSocketFile, self).__init__() self.path = path
- if connect: - self.connect() - def is_localhost(self): return True
- def _make_socket(self): + async def _open_connection(self): try: - control_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - control_socket.connect(self.path) - return control_socket + return await asyncio.open_unix_connection(self.path) except socket.error as exc: raise stem.SocketError(exc)
-def send_message(control_file, message, raw = False): +async def send_message(writer, message, raw = False): """ Sends a message to the control socket, adding the expected formatting for single verses multi-line messages. Neither message type should contain an @@ -572,7 +521,7 @@ def send_message(control_file, message, raw = False): if not raw: message = send_formatting(message)
- _write_to_socket(control_file, message) + await _write_to_socket(writer, message)
if log.is_tracing(): log_message = message.replace('\r\n', '\n').rstrip() @@ -580,10 +529,10 @@ def send_message(control_file, message, raw = False): log.trace('Sent to tor:%s%s' % (msg_div, log_message))
-def _write_to_socket(socket_file, message): +async def _write_to_socket(writer, message): try: - socket_file.write(stem.util.str_tools._to_bytes(message)) - socket_file.flush() + writer.write(stem.util.str_tools._to_bytes(message)) + await writer.drain() except socket.error as exc: log.info('Failed to send: %s' % exc)
@@ -603,7 +552,7 @@ def _write_to_socket(socket_file, message): raise stem.SocketClosed('file has been closed')
-def recv_message(control_file, arrived_at = None): +async def recv_message(reader, arrived_at = None): """ Pulls from a control socket until we either have a complete message or encounter a problem. @@ -623,7 +572,7 @@ def recv_message(control_file, arrived_at = None):
while True: try: - line = control_file.readline() + line = await reader.readline() except AttributeError: # if the control_file has been closed then we will receive: # AttributeError: 'NoneType' object has no attribute 'recv' @@ -689,7 +638,7 @@ def recv_message(control_file, arrived_at = None):
while True: try: - line = control_file.readline() + line = await reader.readline() raw_content += line except socket.error as exc: log.info(ERROR_MSG % ('SocketClosed', 'received an exception while mid-way through a data reply (exception: "%s", read content: "%s")' % (exc, log.escape(bytes(raw_content)))))