[stem/master] Start using asyncio in "stem/async_socket.py"

commit 0f727a91abb27e487fad49794a428f78e1582f31 Author: Illia Volochii <illia.volochii@gmail.com> Date: Sat Mar 14 22:01:22 2020 +0200 Start using asyncio in "stem/async_socket.py" --- stem/async_socket.py | 175 ++++++++++++++++++--------------------------------- 1 file changed, 62 insertions(+), 113 deletions(-) diff --git a/stem/async_socket.py b/stem/async_socket.py index 2ef42dd5..512e6bde 100644 --- a/stem/async_socket.py +++ b/stem/async_socket.py @@ -72,6 +72,7 @@ Tor... from __future__ import absolute_import +import asyncio import re import socket import ssl @@ -97,7 +98,8 @@ class BaseSocket(object): """ def __init__(self): - self._socket, self._socket_file = None, None + self._reader = None + self._writer = None self._is_alive = False self._connection_time = 0.0 # time when we last connected or disconnected @@ -151,7 +153,7 @@ class BaseSocket(object): return self._connection_time - def connect(self): + async def connect(self): """ Connects to a new socket, closing our previous one if we're already attached. @@ -165,11 +167,10 @@ class BaseSocket(object): # calls no longer block (raising SocketClosed instead). if self.is_alive(): - self.close() + await self.close() with self._recv_lock: - self._socket = self._make_socket() - self._socket_file = self._socket.makefile(mode = 'rwb') + self._reader, self._writer = await self._open_connection() self._is_alive = True self._connection_time = time.time() @@ -179,11 +180,11 @@ class BaseSocket(object): # It's safe to retry, so give it another try if it fails. try: - self._connect() + await self._connect() except stem.SocketError: - self._connect() # single retry + await self._connect() # single retry - def close(self): + async def close(self): """ Shuts down the socket. If it's already closed then this is a no-op. 
""" @@ -194,32 +195,21 @@ class BaseSocket(object): is_change = self.is_alive() - if self._socket: - # if we haven't yet established a connection then this raises an error - # socket.error: [Errno 107] Transport endpoint is not connected + if self._writer: + self._writer.close() + # `StreamWriter.wait_closed` was added in Python 3.7. + if hasattr(self._writer, 'wait_closed'): + await self._writer.wait_closed() - try: - self._socket.shutdown(socket.SHUT_RDWR) - except socket.error: - pass - - self._socket.close() - - if self._socket_file: - try: - self._socket_file.close() - except BrokenPipeError: - pass - - self._socket = None - self._socket_file = None + self._reader = None + self._writer = None self._is_alive = False self._connection_time = time.time() if is_change: - self._close() + await self._close() - def _send(self, message, handler): + async def _send(self, message, handler): """ Send message in a thread safe manner. Handler is expected to be of the form... @@ -233,17 +223,17 @@ class BaseSocket(object): if not self.is_alive(): raise stem.SocketClosed() - handler(self._socket, self._socket_file, message) + await handler(self._writer, message) except stem.SocketClosed: # if send_message raises a SocketClosed then we should properly shut # everything down if self.is_alive(): - self.close() + await self.close() raise - def _recv(self, handler): + async def _recv(self, handler): """ Receives a message in a thread safe manner. Handler is expected to be of the form... 
@@ -254,15 +244,15 @@ class BaseSocket(object): with self._recv_lock: try: - # makes a temporary reference to the _socket_file because connect() + # makes a temporary reference to the _reader because connect() # and close() may set or unset it - my_socket, my_socket_file = self._socket, self._socket_file + my_reader = self._reader - if not my_socket or not my_socket_file: + if not my_reader: raise stem.SocketClosed() - return handler(my_socket, my_socket_file) + return await handler(my_reader) except stem.SocketClosed: # If recv_message raises a SocketClosed then we should properly shut # everything down. However, there's a couple cases where this will @@ -280,7 +270,7 @@ class BaseSocket(object): if self.is_alive(): if self._send_lock.acquire(False): - self.close() + await self.close() self._send_lock.release() raise @@ -298,37 +288,27 @@ class BaseSocket(object): return self._send_lock - def __enter__(self): + async def __aenter__(self): return self - def __exit__(self, exit_type, value, traceback): - self.close() + async def __aexit__(self, exit_type, value, traceback): + await self.close() - def _connect(self): + async def _connect(self): """ Connection callback that can be overwritten by subclasses and wrappers. """ pass - def _close(self): + async def _close(self): """ Disconnection callback that can be overwritten by subclasses and wrappers. """ pass - def _make_socket(self): - """ - Constructs and connects new socket. This is implemented by subclasses. 
- - :returns: **socket.socket** for our configuration - - :raises: - * :class:`stem.SocketError` if unable to make a socket - * **NotImplementedError** if not implemented by a subclass - """ - + async def _open_connection(self): raise NotImplementedError('Unsupported Operation: this should be implemented by the BaseSocket subclass') @@ -344,26 +324,19 @@ class RelaySocket(BaseSocket): :var int port: ORPort our socket connects to """ - def __init__(self, address = '127.0.0.1', port = 9050, connect = True): + def __init__(self, address = '127.0.0.1', port = 9050): """ RelaySocket constructor. :param str address: ip address of the relay :param int port: orport of the relay - :param bool connect: connects to the socket if True, leaves it unconnected otherwise - - :raises: :class:`stem.SocketError` if connect is **True** and we're - unable to establish a connection """ super(RelaySocket, self).__init__() self.address = address self.port = port - if connect: - self.connect() - - def send(self, message): + async def send(self, message): """ Sends a message to the relay's ORPort. @@ -374,9 +347,9 @@ class RelaySocket(BaseSocket): * :class:`stem.SocketClosed` if the socket is known to be shut down """ - self._send(message, lambda s, sf, msg: _write_to_socket(sf, msg)) + await self._send(message, _write_to_socket) - def recv(self, timeout = None): + async def recv(self, timeout = None): """ Receives a message from the relay. 
@@ -390,30 +363,24 @@ class RelaySocket(BaseSocket): * :class:`stem.SocketClosed` if the socket closes before we receive a complete message """ - def wrapped_recv(s, sf): + async def wrapped_recv(reader): + read_coroutine = reader.read(1024) if timeout is None: - return s.recv() + return await read_coroutine else: - s.setblocking(0) - s.settimeout(timeout) - try: - return s.recv() - except (socket.timeout, ssl.SSLError, ssl.SSLWantReadError): + return await asyncio.wait_for(read_coroutine, timeout) + except (asyncio.TimeoutError, ssl.SSLError, ssl.SSLWantReadError): return None - finally: - s.setblocking(1) - return self._recv(wrapped_recv) + return await self._recv(wrapped_recv) def is_localhost(self): return self.address == '127.0.0.1' - def _make_socket(self): + async def _open_connection(self): try: - relay_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - relay_socket.connect((self.address, self.port)) - return ssl.wrap_socket(relay_socket) + return await asyncio.open_connection(self.address, self.port, ssl=ssl.SSLContext()) except socket.error as exc: raise stem.SocketError(exc) @@ -431,7 +398,7 @@ class ControlSocket(BaseSocket): def __init__(self): super(ControlSocket, self).__init__() - def send(self, message): + async def send(self, message): """ Formats and sends a message to the control socket. For more information see the :func:`~stem.socket.send_message` function. @@ -443,9 +410,9 @@ class ControlSocket(BaseSocket): * :class:`stem.SocketClosed` if the socket is known to be shut down """ - self._send(message, lambda s, sf, msg: send_message(sf, msg)) + await self._send(message, send_message) - def recv(self): + async def recv(self): """ Receives a message from the control socket, blocking until we've received one. For more information see the :func:`~stem.socket.recv_message` function. 
@@ -457,7 +424,7 @@ class ControlSocket(BaseSocket): * :class:`stem.SocketClosed` if the socket closes before we receive a complete message """ - return self._recv(lambda s, sf: recv_message(sf)) + return await self._recv(recv_message) class ControlPort(ControlSocket): @@ -469,33 +436,24 @@ class ControlPort(ControlSocket): :var int port: ControlPort our socket connects to """ - def __init__(self, address = '127.0.0.1', port = 9051, connect = True): + def __init__(self, address = '127.0.0.1', port = 9051): """ ControlPort constructor. :param str address: ip address of the controller :param int port: port number of the controller - :param bool connect: connects to the socket if True, leaves it unconnected otherwise - - :raises: :class:`stem.SocketError` if connect is **True** and we're - unable to establish a connection """ super(ControlPort, self).__init__() self.address = address self.port = port - if connect: - self.connect() - def is_localhost(self): return self.address == '127.0.0.1' - def _make_socket(self): + async def _open_connection(self): try: - control_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - control_socket.connect((self.address, self.port)) - return control_socket + return await asyncio.open_connection(self.address, self.port) except socket.error as exc: raise stem.SocketError(exc) @@ -508,36 +466,27 @@ class ControlSocketFile(ControlSocket): :var str path: filesystem path of the socket we connect to """ - def __init__(self, path = '/var/run/tor/control', connect = True): + def __init__(self, path = '/var/run/tor/control'): """ ControlSocketFile constructor. 
:param str socket_path: path where the control socket is located - :param bool connect: connects to the socket if True, leaves it unconnected otherwise - - :raises: :class:`stem.SocketError` if connect is **True** and we're - unable to establish a connection """ super(ControlSocketFile, self).__init__() self.path = path - if connect: - self.connect() - def is_localhost(self): return True - def _make_socket(self): + async def _open_connection(self): try: - control_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - control_socket.connect(self.path) - return control_socket + return await asyncio.open_unix_connection(self.path) except socket.error as exc: raise stem.SocketError(exc) -def send_message(control_file, message, raw = False): +async def send_message(writer, message, raw = False): """ Sends a message to the control socket, adding the expected formatting for single verses multi-line messages. Neither message type should contain an @@ -572,7 +521,7 @@ def send_message(control_file, message, raw = False): if not raw: message = send_formatting(message) - _write_to_socket(control_file, message) + await _write_to_socket(writer, message) if log.is_tracing(): log_message = message.replace('\r\n', '\n').rstrip() @@ -580,10 +529,10 @@ def send_message(control_file, message, raw = False): log.trace('Sent to tor:%s%s' % (msg_div, log_message)) -def _write_to_socket(socket_file, message): +async def _write_to_socket(writer, message): try: - socket_file.write(stem.util.str_tools._to_bytes(message)) - socket_file.flush() + writer.write(stem.util.str_tools._to_bytes(message)) + await writer.drain() except socket.error as exc: log.info('Failed to send: %s' % exc) @@ -603,7 +552,7 @@ def _write_to_socket(socket_file, message): raise stem.SocketClosed('file has been closed') -def recv_message(control_file, arrived_at = None): +async def recv_message(reader, arrived_at = None): """ Pulls from a control socket until we either have a complete message or encounter a 
problem. @@ -623,7 +572,7 @@ def recv_message(control_file, arrived_at = None): while True: try: - line = control_file.readline() + line = await reader.readline() except AttributeError: # if the control_file has been closed then we will receive: # AttributeError: 'NoneType' object has no attribute 'recv' @@ -689,7 +638,7 @@ def recv_message(control_file, arrived_at = None): while True: try: - line = control_file.readline() + line = await reader.readline() raw_content += line except socket.error as exc: log.info(ERROR_MSG % ('SocketClosed', 'received an exception while mid-way through a data reply (exception: "%s", read content: "%s")' % (exc, log.escape(bytes(raw_content)))))
Participants (1):
- atagar@torproject.org