commit 08d1c08dc39b9f2535fb185f4339496b8e3ea2de Author: Illia Volochii illia.volochii@gmail.com Date: Sun Apr 12 18:47:24 2020 +0300
Replace "stem/socket.py" with its asynchronous implementation --- stem/async_socket.py | 717 --------------------------------------------------- stem/socket.py | 177 +++++-------- 2 files changed, 65 insertions(+), 829 deletions(-)
diff --git a/stem/async_socket.py b/stem/async_socket.py deleted file mode 100644 index 512e6bde..00000000 --- a/stem/async_socket.py +++ /dev/null @@ -1,717 +0,0 @@ -# Copyright 2011-2020, Damian Johnson and The Tor Project -# See LICENSE for licensing information - -""" -Supports communication with sockets speaking Tor protocols. This -allows us to send messages as basic strings, and receive responses as -:class:`~stem.response.ControlMessage` instances. - -**This module only consists of low level components, and is not intended for -users.** See our `tutorials <../tutorials.html>`_ and `Control Module -<control.html>`_ if you're new to Stem and looking to get started. - -With that aside, these can still be used for raw socket communication with -Tor... - -:: - - import stem - import stem.connection - import stem.socket - - if __name__ == '__main__': - try: - control_socket = stem.socket.ControlPort(port = 9051) - stem.connection.authenticate(control_socket) - except stem.SocketError as exc: - print 'Unable to connect to tor on port 9051: %s' % exc - sys.exit(1) - except stem.connection.AuthenticationFailure as exc: - print 'Unable to authenticate: %s' % exc - sys.exit(1) - - print "Issuing 'GETINFO version' query...\n" - control_socket.send('GETINFO version') - print control_socket.recv() - -:: - - % python example.py - Issuing 'GETINFO version' query... - - version=0.2.4.10-alpha-dev (git-8be6058d8f31e578) - OK - -**Module Overview:** - -:: - - BaseSocket - Thread safe socket. - |- RelaySocket - Socket for a relay's ORPort. - | |- send - sends a message to the socket - | +- recv - receives a response from the socket - | - |- ControlSocket - Socket wrapper that speaks the tor control protocol. - | |- ControlPort - Control connection via a port. - | |- ControlSocketFile - Control connection via a local file socket. - | | - | |- send - sends a message to the socket - | +- recv - receives a ControlMessage from the socket - | - |- is_alive - reports if the socket is known to be closed - |- is_localhost - returns if the socket is for the local system or not - |- connection_time - timestamp when socket last connected or disconnected - |- connect - connects a new socket - |- close - shuts down the socket - +- __enter__ / __exit__ - manages socket connection - - send_message - Writes a message to a control socket. - recv_message - Reads a ControlMessage from a control socket. - send_formatting - Performs the formatting expected from sent messages. -""" - -from __future__ import absolute_import - -import asyncio -import re -import socket -import ssl -import threading -import time - -import stem.response -import stem.util.str_tools - -from stem.util import log - -MESSAGE_PREFIX = re.compile(b'^[a-zA-Z0-9]{3}[-+ ]') -ERROR_MSG = 'Error while receiving a control message (%s): %s' - -# lines to limit our trace logging to, you can disable this by setting it to None - -TRUNCATE_LOGS = 10 - - -class BaseSocket(object): - """ - Thread safe socket, providing common socket functionality. - """ - - def __init__(self): - self._reader = None - self._writer = None - self._is_alive = False - self._connection_time = 0.0 # time when we last connected or disconnected - - # Tracks sending and receiving separately. This should be safe, and doing - # so prevents deadlock where we block writes because we're waiting to read - # a message that isn't coming. - - self._send_lock = threading.RLock() - self._recv_lock = threading.RLock() - - def is_alive(self): - """ - Checks if the socket is known to be closed. We won't be aware if it is - until we either use it or have explicitily shut it down. - - In practice a socket derived from a port knows about its disconnection - after failing to receive data, whereas socket file derived connections - know after either sending or receiving data. - - This means that to have reliable detection for when we're disconnected - you need to continually pull from the socket (which is part of what the - :class:`~stem.control.BaseController` does). - - :returns: **bool** that's **True** if our socket is connected and **False** - otherwise - """ - - return self._is_alive - - def is_localhost(self): - """ - Returns if the connection is for the local system or not. - - :returns: **bool** that's **True** if the connection is for the local host - and **False** otherwise - """ - - return False - - def connection_time(self): - """ - Provides the unix timestamp for when our socket was either connected or - disconnected. That is to say, the time we connected if we're currently - connected and the time we disconnected if we're not connected. - - .. versionadded:: 1.3.0 - - :returns: **float** for when we last connected or disconnected, zero if - we've never connected - """ - - return self._connection_time - - async def connect(self): - """ - Connects to a new socket, closing our previous one if we're already - attached. - - :raises: :class:`stem.SocketError` if unable to make a socket - """ - - with self._send_lock: - # Closes the socket if we're currently attached to one. Once we're no - # longer alive it'll be safe to acquire the recv lock because recv() - # calls no longer block (raising SocketClosed instead). - - if self.is_alive(): - await self.close() - - with self._recv_lock: - self._reader, self._writer = await self._open_connection() - self._is_alive = True - self._connection_time = time.time() - - # It's possible for this to have a transient failure... - # SocketError: [Errno 4] Interrupted system call - # - # It's safe to retry, so give it another try if it fails. - - try: - await self._connect() - except stem.SocketError: - await self._connect() # single retry - - async def close(self): - """ - Shuts down the socket. If it's already closed then this is a no-op. - """ - - with self._send_lock: - # Function is idempotent with one exception: we notify _close() if this - # is causing our is_alive() state to change. - - is_change = self.is_alive() - - if self._writer: - self._writer.close() - # `StreamWriter.wait_closed` was added in Python 3.7. - if hasattr(self._writer, 'wait_closed'): - await self._writer.wait_closed() - - self._reader = None - self._writer = None - self._is_alive = False - self._connection_time = time.time() - - if is_change: - await self._close() - - async def _send(self, message, handler): - """ - Send message in a thread safe manner. Handler is expected to be of the form... - - :: - - my_handler(socket, socket_file, message) - """ - - with self._send_lock: - try: - if not self.is_alive(): - raise stem.SocketClosed() - - await handler(self._writer, message) - except stem.SocketClosed: - # if send_message raises a SocketClosed then we should properly shut - # everything down - - if self.is_alive(): - await self.close() - - raise - - async def _recv(self, handler): - """ - Receives a message in a thread safe manner. Handler is expected to be of the form... - - :: - - my_handler(socket, socket_file) - """ - - with self._recv_lock: - try: - # makes a temporary reference to the _reader because connect() - # and close() may set or unset it - - my_reader = self._reader - - if not my_reader: - raise stem.SocketClosed() - - return await handler(my_reader) - except stem.SocketClosed: - # If recv_message raises a SocketClosed then we should properly shut - # everything down. However, there's a couple cases where this will - # cause deadlock... - # - # * This SocketClosed was *caused by* a close() call, which is joining - # on our thread. - # - # * A send() call that's currently in flight is about to call close(), - # also attempting to join on us. - # - # To resolve this we make a non-blocking call to acquire the send lock. - # If we get it then great, we can close safely. If not then one of the - # above are in progress and we leave the close to them. - - if self.is_alive(): - if self._send_lock.acquire(False): - await self.close() - self._send_lock.release() - - raise - - def _get_send_lock(self): - """ - The send lock is useful to classes that interact with us at a deep level - because it's used to lock :func:`stem.socket.ControlSocket.connect` / - :func:`stem.socket.BaseSocket.close`, and by extension our - :func:`stem.socket.BaseSocket.is_alive` state changes. - - :returns: **threading.RLock** that governs sending messages to our socket - and state changes - """ - - return self._send_lock - - async def __aenter__(self): - return self - - async def __aexit__(self, exit_type, value, traceback): - await self.close() - - async def _connect(self): - """ - Connection callback that can be overwritten by subclasses and wrappers. - """ - - pass - - async def _close(self): - """ - Disconnection callback that can be overwritten by subclasses and wrappers. - """ - - pass - - async def _open_connection(self): - raise NotImplementedError('Unsupported Operation: this should be implemented by the BaseSocket subclass') - - -class RelaySocket(BaseSocket): - """ - `Link-level connection - https://gitweb.torproject.org/torspec.git/tree/tor-spec.txt`_ to a Tor - relay. - - .. versionadded:: 1.7.0 - - :var str address: address our socket connects to - :var int port: ORPort our socket connects to - """ - - def __init__(self, address = '127.0.0.1', port = 9050): - """ - RelaySocket constructor. - - :param str address: ip address of the relay - :param int port: orport of the relay - """ - - super(RelaySocket, self).__init__() - self.address = address - self.port = port - - async def send(self, message): - """ - Sends a message to the relay's ORPort. - - :param str message: message to be formatted and sent to the socket - - :raises: - * :class:`stem.SocketError` if a problem arises in using the socket - * :class:`stem.SocketClosed` if the socket is known to be shut down - """ - - await self._send(message, _write_to_socket) - - async def recv(self, timeout = None): - """ - Receives a message from the relay. - - :param float timeout: maxiumum number of seconds to await a response, this - blocks indefinitely if **None** - - :returns: bytes for the message received - - :raises: - * :class:`stem.ProtocolError` the content from the socket is malformed - * :class:`stem.SocketClosed` if the socket closes before we receive a complete message - """ - - async def wrapped_recv(reader): - read_coroutine = reader.read(1024) - if timeout is None: - return await read_coroutine - else: - try: - return await asyncio.wait_for(read_coroutine, timeout) - except (asyncio.TimeoutError, ssl.SSLError, ssl.SSLWantReadError): - return None - - return await self._recv(wrapped_recv) - - def is_localhost(self): - return self.address == '127.0.0.1' - - async def _open_connection(self): - try: - return await asyncio.open_connection(self.address, self.port, ssl=ssl.SSLContext()) - except socket.error as exc: - raise stem.SocketError(exc) - - -class ControlSocket(BaseSocket): - """ - Wrapper for a socket connection that speaks the Tor control protocol. To the - better part this transparently handles the formatting for sending and - receiving complete messages. - - Callers should not instantiate this class directly, but rather use subclasses - which are expected to implement the **_make_socket()** method. - """ - - def __init__(self): - super(ControlSocket, self).__init__() - - async def send(self, message): - """ - Formats and sends a message to the control socket. For more information see - the :func:`~stem.socket.send_message` function. - - :param str message: message to be formatted and sent to the socket - - :raises: - * :class:`stem.SocketError` if a problem arises in using the socket - * :class:`stem.SocketClosed` if the socket is known to be shut down - """ - - await self._send(message, send_message) - - async def recv(self): - """ - Receives a message from the control socket, blocking until we've received - one. For more information see the :func:`~stem.socket.recv_message` function. - - :returns: :class:`~stem.response.ControlMessage` for the message received - - :raises: - * :class:`stem.ProtocolError` the content from the socket is malformed - * :class:`stem.SocketClosed` if the socket closes before we receive a complete message - """ - - return await self._recv(recv_message) - - -class ControlPort(ControlSocket): - """ - Control connection to tor. For more information see tor's ControlPort torrc - option. - - :var str address: address our socket connects to - :var int port: ControlPort our socket connects to - """ - - def __init__(self, address = '127.0.0.1', port = 9051): - """ - ControlPort constructor. - - :param str address: ip address of the controller - :param int port: port number of the controller - """ - - super(ControlPort, self).__init__() - self.address = address - self.port = port - - def is_localhost(self): - return self.address == '127.0.0.1' - - async def _open_connection(self): - try: - return await asyncio.open_connection(self.address, self.port) - except socket.error as exc: - raise stem.SocketError(exc) - - -class ControlSocketFile(ControlSocket): - """ - Control connection to tor. For more information see tor's ControlSocket torrc - option. - - :var str path: filesystem path of the socket we connect to - """ - - def __init__(self, path = '/var/run/tor/control'): - """ - ControlSocketFile constructor. - - :param str socket_path: path where the control socket is located - """ - - super(ControlSocketFile, self).__init__() - self.path = path - - def is_localhost(self): - return True - - async def _open_connection(self): - try: - return await asyncio.open_unix_connection(self.path) - except socket.error as exc: - raise stem.SocketError(exc) - - -async def send_message(writer, message, raw = False): - """ - Sends a message to the control socket, adding the expected formatting for - single verses multi-line messages. Neither message type should contain an - ending newline (if so it'll be treated as a multi-line message with a blank - line at the end). If the message doesn't contain a newline then it's sent - as... - - :: - - <message>\r\n - - and if it does contain newlines then it's split on ``\n`` and sent as... - - :: - - +<line 1>\r\n - <line 2>\r\n - <line 3>\r\n - .\r\n - - :param file control_file: file derived from the control socket (see the - socket's makefile() method for more information) - :param str message: message to be sent on the control socket - :param bool raw: leaves the message formatting untouched, passing it to the - socket as-is - - :raises: - * :class:`stem.SocketError` if a problem arises in using the socket - * :class:`stem.SocketClosed` if the socket is known to be shut down - """ - - if not raw: - message = send_formatting(message) - - await _write_to_socket(writer, message) - - if log.is_tracing(): - log_message = message.replace('\r\n', '\n').rstrip() - msg_div = '\n' if '\n' in log_message else ' ' - log.trace('Sent to tor:%s%s' % (msg_div, log_message)) - - -async def _write_to_socket(writer, message): - try: - writer.write(stem.util.str_tools._to_bytes(message)) - await writer.drain() - except socket.error as exc: - log.info('Failed to send: %s' % exc) - - # When sending there doesn't seem to be a reliable method for - # distinguishing between failures from a disconnect verses other things. - # Just accounting for known disconnection responses. - - if str(exc) == '[Errno 32] Broken pipe': - raise stem.SocketClosed(exc) - else: - raise stem.SocketError(exc) - except AttributeError: - # if the control_file has been closed then flush will receive: - # AttributeError: 'NoneType' object has no attribute 'sendall' - - log.info('Failed to send: file has been closed') - raise stem.SocketClosed('file has been closed') - - -async def recv_message(reader, arrived_at = None): - """ - Pulls from a control socket until we either have a complete message or - encounter a problem. - - :param file control_file: file derived from the control socket (see the - socket's makefile() method for more information) - - :returns: :class:`~stem.response.ControlMessage` read from the socket - - :raises: - * :class:`stem.ProtocolError` the content from the socket is malformed - * :class:`stem.SocketClosed` if the socket closes before we receive - a complete message - """ - - parsed_content, raw_content, first_line = None, None, True - - while True: - try: - line = await reader.readline() - except AttributeError: - # if the control_file has been closed then we will receive: - # AttributeError: 'NoneType' object has no attribute 'recv' - - log.info(ERROR_MSG % ('SocketClosed', 'socket file has been closed')) - raise stem.SocketClosed('socket file has been closed') - except (OSError, ValueError) as exc: - # when disconnected this errors with... - # - # * ValueError: I/O operation on closed file - # * OSError: [Errno 107] Transport endpoint is not connected - # * OSError: [Errno 9] Bad file descriptor - - log.info(ERROR_MSG % ('SocketClosed', 'received exception "%s"' % exc)) - raise stem.SocketClosed(exc) - - # Parses the tor control lines. These are of the form... - # <status code><divider><content>\r\n - - if not line: - # if the socket is disconnected then the readline() method will provide - # empty content - - log.info(ERROR_MSG % ('SocketClosed', 'empty socket content')) - raise stem.SocketClosed('Received empty socket content.') - elif not MESSAGE_PREFIX.match(line): - log.info(ERROR_MSG % ('ProtocolError', 'malformed status code/divider, "%s"' % log.escape(line))) - raise stem.ProtocolError('Badly formatted reply line: beginning is malformed') - elif not line.endswith(b'\r\n'): - log.info(ERROR_MSG % ('ProtocolError', 'no CRLF linebreak, "%s"' % log.escape(line))) - raise stem.ProtocolError('All lines should end with CRLF') - - status_code, divider, content = line[:3], line[3:4], line[4:-2] # strip CRLF off content - - status_code = stem.util.str_tools._to_unicode(status_code) - divider = stem.util.str_tools._to_unicode(divider) - - # Most controller responses are single lines, in which case we don't need - # so much overhead. - - if first_line: - if divider == ' ': - _log_trace(line) - return stem.response.ControlMessage([(status_code, divider, content)], line, arrived_at = arrived_at) - else: - parsed_content, raw_content, first_line = [], bytearray(), False - - raw_content += line - - if divider == '-': - # mid-reply line, keep pulling for more content - parsed_content.append((status_code, divider, content)) - elif divider == ' ': - # end of the message, return the message - parsed_content.append((status_code, divider, content)) - _log_trace(bytes(raw_content)) - return stem.response.ControlMessage(parsed_content, bytes(raw_content), arrived_at = arrived_at) - elif divider == '+': - # data entry, all of the following lines belong to the content until we - # get a line with just a period - - content_block = bytearray(content) - - while True: - try: - line = await reader.readline() - raw_content += line - except socket.error as exc: - log.info(ERROR_MSG % ('SocketClosed', 'received an exception while mid-way through a data reply (exception: "%s", read content: "%s")' % (exc, log.escape(bytes(raw_content))))) - raise stem.SocketClosed(exc) - - if not line.endswith(b'\r\n'): - log.info(ERROR_MSG % ('ProtocolError', 'CRLF linebreaks missing from a data reply, "%s"' % log.escape(bytes(raw_content)))) - raise stem.ProtocolError('All lines should end with CRLF') - elif line == b'.\r\n': - break # data block termination - - line = line[:-2] # strips off the CRLF - - # lines starting with a period are escaped by a second period (as per - # section 2.4 of the control-spec) - - if line.startswith(b'..'): - line = line[1:] - - content_block += b'\n' + line - - # joins the content using a newline rather than CRLF separator (more - # conventional for multi-line string content outside the windows world) - - parsed_content.append((status_code, divider, bytes(content_block))) - else: - # this should never be reached due to the prefix regex, but might as well - # be safe... - - log.warn(ERROR_MSG % ('ProtocolError', ""%s" isn't a recognized divider type" % divider)) - raise stem.ProtocolError("Unrecognized divider type '%s': %s" % (divider, stem.util.str_tools._to_unicode(line))) - - -def send_formatting(message): - """ - Performs the formatting expected from sent control messages. For more - information see the :func:`~stem.socket.send_message` function. - - :param str message: message to be formatted - - :returns: **str** of the message wrapped by the formatting expected from - controllers - """ - - # From control-spec section 2.2... - # Command = Keyword OptArguments CRLF / "+" Keyword OptArguments CRLF CmdData - # Keyword = 1*ALPHA - # OptArguments = [ SP *(SP / VCHAR) ] - # - # A command is either a single line containing a Keyword and arguments, or a - # multiline command whose initial keyword begins with +, and whose data - # section ends with a single "." on a line of its own. - - # if we already have \r\n entries then standardize on \n to start with - message = message.replace('\r\n', '\n') - - if '\n' in message: - return '+%s\r\n.\r\n' % message.replace('\n', '\r\n') - else: - return message + '\r\n' - - -def _log_trace(response): - if not log.is_tracing(): - return - - log_message = stem.util.str_tools._to_unicode(response.replace(b'\r\n', b'\n').rstrip()) - log_message_lines = log_message.split('\n') - - if TRUNCATE_LOGS and len(log_message_lines) > TRUNCATE_LOGS: - log_message = '\n'.join(log_message_lines[:TRUNCATE_LOGS] + ['... %i more lines...' % (len(log_message_lines) - TRUNCATE_LOGS)]) - - if len(log_message_lines) > 2: - log.trace('Received from tor:\n%s' % log_message) - else: - log.trace('Received from tor: %s' % log_message.replace('\n', '\n')) diff --git a/stem/socket.py b/stem/socket.py index 81019cf2..dd123751 100644 --- a/stem/socket.py +++ b/stem/socket.py @@ -69,6 +69,7 @@ Tor... send_formatting - Performs the formatting expected from sent messages. """
+import asyncio import re import socket import ssl @@ -96,8 +97,8 @@ class BaseSocket(object): """
def __init__(self) -> None: - self._socket = None # type: Optional[Union[socket.socket, ssl.SSLSocket]] - self._socket_file = None # type: Optional[BinaryIO] + self._reader = None + self._writer = None self._is_alive = False self._connection_time = 0.0 # time when we last connected or disconnected
@@ -151,7 +152,7 @@ class BaseSocket(object):
return self._connection_time
- def connect(self) -> None: + async def connect(self) -> None: """ Connects to a new socket, closing our previous one if we're already attached. @@ -165,11 +166,10 @@ class BaseSocket(object): # calls no longer block (raising SocketClosed instead).
if self.is_alive(): - self.close() + await self.close()
with self._recv_lock: - self._socket = self._make_socket() - self._socket_file = self._socket.makefile(mode = 'rwb') + self._reader, self._writer = await self._open_connection() self._is_alive = True self._connection_time = time.time()
@@ -179,11 +179,11 @@ class BaseSocket(object): # It's safe to retry, so give it another try if it fails.
try: - self._connect() + await self._connect() except stem.SocketError: - self._connect() # single retry + await self._connect() # single retry
- def close(self) -> None: + async def close(self) -> None: """ Shuts down the socket. If it's already closed then this is a no-op. """ @@ -194,32 +194,21 @@ class BaseSocket(object):
is_change = self.is_alive()
- if self._socket: - # if we haven't yet established a connection then this raises an error - # socket.error: [Errno 107] Transport endpoint is not connected + if self._writer: + self._writer.close() + # `StreamWriter.wait_closed` was added in Python 3.7. + if hasattr(self._writer, 'wait_closed'): + await self._writer.wait_closed()
- try: - self._socket.shutdown(socket.SHUT_RDWR) - except socket.error: - pass - - self._socket.close() - - if self._socket_file: - try: - self._socket_file.close() - except BrokenPipeError: - pass - - self._socket = None - self._socket_file = None + self._reader = None + self._writer = None self._is_alive = False self._connection_time = time.time()
if is_change: - self._close() + await self._close()
- def _send(self, message: Union[bytes, str], handler: Callable[[Union[socket.socket, ssl.SSLSocket], BinaryIO, Union[bytes, str]], None]) -> None: + async def _send(self, message: Union[bytes, str], handler: Callable[[Union[socket.socket, ssl.SSLSocket], BinaryIO, Union[bytes, str]], None]) -> None: """ Send message in a thread safe manner. Handler is expected to be of the form...
@@ -233,25 +222,25 @@ class BaseSocket(object): if not self.is_alive(): raise stem.SocketClosed()
- handler(self._socket, self._socket_file, message) + await handler(self._writer, message) except stem.SocketClosed: # if send_message raises a SocketClosed then we should properly shut # everything down
if self.is_alive(): - self.close() + await self.close()
raise
@overload - def _recv(self, handler: Callable[[ssl.SSLSocket, BinaryIO], bytes]) -> bytes: + async def _recv(self, handler: Callable[[ssl.SSLSocket, BinaryIO], bytes]) -> bytes: ...
@overload - def _recv(self, handler: Callable[[socket.socket, BinaryIO], stem.response.ControlMessage]) -> stem.response.ControlMessage: + async def _recv(self, handler: Callable[[socket.socket, BinaryIO], stem.response.ControlMessage]) -> stem.response.ControlMessage: ...
- def _recv(self, handler): + async def _recv(self, handler): """ Receives a message in a thread safe manner. Handler is expected to be of the form...
@@ -262,15 +251,15 @@ class BaseSocket(object):
with self._recv_lock: try: - # makes a temporary reference to the _socket_file because connect() + # makes a temporary reference to the _reader because connect() # and close() may set or unset it
- my_socket, my_socket_file = self._socket, self._socket_file + my_reader = self._reader
- if not my_socket or not my_socket_file: + if not my_reader: raise stem.SocketClosed()
- return handler(my_socket, my_socket_file) + return await handler(my_reader) except stem.SocketClosed: # If recv_message raises a SocketClosed then we should properly shut # everything down. However, there's a couple cases where this will @@ -288,7 +277,7 @@ class BaseSocket(object):
if self.is_alive(): if self._send_lock.acquire(False): - self.close() + await self.close() self._send_lock.release()
raise @@ -306,35 +295,31 @@ class BaseSocket(object):
return self._send_lock
- def __enter__(self) -> 'stem.socket.BaseSocket': + async def __aenter__(self) -> 'stem.socket.BaseSocket': return self
- def __exit__(self, exit_type: Optional[Type[BaseException]], value: Optional[BaseException], traceback: Optional[TracebackType]): + async def __aexit__(self, exit_type: Optional[Type[BaseException]], value: Optional[BaseException], traceback: Optional[TracebackType]): self.close()
- def _connect(self) -> None: + async def _connect(self) -> None: """ Connection callback that can be overwritten by subclasses and wrappers. """
pass
- def _close(self) -> None: + async def _close(self) -> None: """ Disconnection callback that can be overwritten by subclasses and wrappers. """
pass
- def _make_socket(self) -> Union[socket.socket, ssl.SSLSocket]: + async def _open_connection(self) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]: """ Constructs and connects new socket. This is implemented by subclasses.
- :returns: **socket.socket** for our configuration - - :raises: - * :class:`stem.SocketError` if unable to make a socket - * **NotImplementedError** if not implemented by a subclass + :returns: **tuple** with our reader and writer streams """
raise NotImplementedError('Unsupported Operation: this should be implemented by the BaseSocket subclass') @@ -352,26 +337,19 @@ class RelaySocket(BaseSocket): :var int port: ORPort our socket connects to """
- def __init__(self, address: str = '127.0.0.1', port: int = 9050, connect: bool = True) -> None: + def __init__(self, address: str = '127.0.0.1', port: int = 9050) -> None: """ RelaySocket constructor.
:param address: ip address of the relay :param port: orport of the relay - :param connect: connects to the socket if True, leaves it unconnected otherwise - - :raises: :class:`stem.SocketError` if connect is **True** and we're - unable to establish a connection """
super(RelaySocket, self).__init__() self.address = address self.port = port
- if connect: - self.connect() - - def send(self, message: Union[str, bytes]) -> None: + async def send(self, message: Union[str, bytes]) -> None: """ Sends a message to the relay's ORPort.
@@ -382,9 +360,9 @@ class RelaySocket(BaseSocket): * :class:`stem.SocketClosed` if the socket is known to be shut down """
- self._send(message, lambda s, sf, msg: _write_to_socket(sf, msg)) + await self._send(message, _write_to_socket)
- def recv(self, timeout: Optional[float] = None) -> bytes: + async def recv(self, timeout: Optional[float] = None) -> bytes: """ Receives a message from the relay.
@@ -398,30 +376,24 @@ class RelaySocket(BaseSocket): * :class:`stem.SocketClosed` if the socket closes before we receive a complete message """
- def wrapped_recv(s: ssl.SSLSocket, sf: BinaryIO) -> bytes: + async def wrapped_recv(s: ssl.SSLSocket, sf: BinaryIO) -> bytes: + read_coroutine = reader.read(1024) if timeout is None: - return s.recv(1024) + return await read_coroutine else: - s.setblocking(False) - s.settimeout(timeout) - try: - return s.recv(1024) - except (socket.timeout, ssl.SSLError, ssl.SSLWantReadError): + return await asyncio.wait_for(read_coroutine, timeout) + except (asyncio.TimeoutError, ssl.SSLError, ssl.SSLWantReadError): return None - finally: - s.setblocking(True)
- return self._recv(wrapped_recv) + return await self._recv(wrapped_recv)
def is_localhost(self) -> bool: return self.address == '127.0.0.1'
- def _make_socket(self) -> ssl.SSLSocket: + async def _open_connection(self) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]: try: - relay_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - relay_socket.connect((self.address, self.port)) - return ssl.wrap_socket(relay_socket) + return await asyncio.open_connection(self.address, self.port, ssl=ssl.SSLContext()) except socket.error as exc: raise stem.SocketError(exc)
@@ -439,7 +411,7 @@ class ControlSocket(BaseSocket): def __init__(self) -> None: super(ControlSocket, self).__init__()
- def send(self, message: Union[bytes, str]) -> None: + async def send(self, message: Union[bytes, str]) -> None: """ Formats and sends a message to the control socket. For more information see the :func:`~stem.socket.send_message` function. @@ -451,9 +423,9 @@ class ControlSocket(BaseSocket): * :class:`stem.SocketClosed` if the socket is known to be shut down """
- self._send(message, lambda s, sf, msg: send_message(sf, msg)) + await self._send(message, send_message)
- def recv(self) -> stem.response.ControlMessage: + async def recv(self) -> stem.response.ControlMessage: """ Receives a message from the control socket, blocking until we've received one. For more information see the :func:`~stem.socket.recv_message` function. @@ -465,7 +437,7 @@ class ControlSocket(BaseSocket): * :class:`stem.SocketClosed` if the socket closes before we receive a complete message """
- return self._recv(lambda s, sf: recv_message(sf)) + return await self._recv(recv_message)
class ControlPort(ControlSocket): @@ -477,33 +449,24 @@ class ControlPort(ControlSocket): :var int port: ControlPort our socket connects to """
- def __init__(self, address: str = '127.0.0.1', port: int = 9051, connect: bool = True) -> None: + def __init__(self, address: str = '127.0.0.1', port: int = 9051) -> None: """ ControlPort constructor.
:param address: ip address of the controller :param port: port number of the controller - :param connect: connects to the socket if True, leaves it unconnected otherwise - - :raises: :class:`stem.SocketError` if connect is **True** and we're - unable to establish a connection """
super(ControlPort, self).__init__() self.address = address self.port = port
- if connect: - self.connect() - def is_localhost(self) -> bool: return self.address == '127.0.0.1'
- def _make_socket(self) -> socket.socket: + async def _open_connection(self) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]: try: - control_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - control_socket.connect((self.address, self.port)) - return control_socket + return await asyncio.open_connection(self.address, self.port) except socket.error as exc: raise stem.SocketError(exc)
@@ -516,36 +479,27 @@ class ControlSocketFile(ControlSocket): :var str path: filesystem path of the socket we connect to """
- def __init__(self, path: str = '/var/run/tor/control', connect: bool = True) -> None: + def __init__(self, path: str = '/var/run/tor/control') -> None: """ ControlSocketFile constructor.
:param socket_path: path where the control socket is located - :param connect: connects to the socket if True, leaves it unconnected otherwise - - :raises: :class:`stem.SocketError` if connect is **True** and we're - unable to establish a connection """
super(ControlSocketFile, self).__init__() self.path = path
- if connect: - self.connect() - def is_localhost(self) -> bool: return True
- def _make_socket(self) -> socket.socket: + async def _open_connection(self) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]: try: - control_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - control_socket.connect(self.path) - return control_socket + return await asyncio.open_unix_connection(self.path) except socket.error as exc: raise stem.SocketError(exc)
-def send_message(control_file: BinaryIO, message: Union[bytes, str], raw: bool = False) -> None: +async def send_message(writer: asyncio.StreamWriter, message: Union[bytes, str], raw: bool = False) -> None: """ Sends a message to the control socket, adding the expected formatting for single verses multi-line messages. Neither message type should contain an @@ -566,8 +520,7 @@ def send_message(control_file: BinaryIO, message: Union[bytes, str], raw: bool = <line 3>\r\n .\r\n
- :param control_file: file derived from the control socket (see the - socket's makefile() method for more information) + :param writer: stream derived from the control socket :param message: message to be sent on the control socket :param raw: leaves the message formatting untouched, passing it to the socket as-is @@ -582,7 +535,7 @@ def send_message(control_file: BinaryIO, message: Union[bytes, str], raw: bool = if not raw: message = send_formatting(message)
- _write_to_socket(control_file, message) + await _write_to_socket(writer, message)
if log.is_tracing(): log_message = message.replace('\r\n', '\n').rstrip() @@ -590,10 +543,10 @@ def send_message(control_file: BinaryIO, message: Union[bytes, str], raw: bool = log.trace('Sent to tor:%s%s' % (msg_div, log_message))
-def _write_to_socket(socket_file: BinaryIO, message: Union[str, bytes]) -> None: +async def _write_to_socket(writer: asyncio.StreamWriter, message: Union[str, bytes]) -> None: try: - socket_file.write(stem.util.str_tools._to_bytes(message)) - socket_file.flush() + writer.write(stem.util.str_tools._to_bytes(message)) + await writer.drain() except socket.error as exc: log.info('Failed to send: %s' % exc)
@@ -613,7 +566,7 @@ def _write_to_socket(socket_file: BinaryIO, message: Union[str, bytes]) -> None: raise stem.SocketClosed('file has been closed')
-def recv_message(control_file: BinaryIO, arrived_at: Optional[float] = None) -> stem.response.ControlMessage: +async def recv_message(reader: asyncio.StreamReader, arrived_at: Optional[float] = None) -> stem.response.ControlMessage: """ Pulls from a control socket until we either have a complete message or encounter a problem. @@ -635,7 +588,7 @@ def recv_message(control_file: BinaryIO, arrived_at: Optional[float] = None) ->
while True: try: - line = control_file.readline() + line = await reader.readline() except AttributeError: # if the control_file has been closed then we will receive: # AttributeError: 'NoneType' object has no attribute 'recv' @@ -701,7 +654,7 @@ def recv_message(control_file: BinaryIO, arrived_at: Optional[float] = None) ->
while True: try: - line = control_file.readline() + line = await reader.readline() raw_content += line except socket.error as exc: log.info(ERROR_MSG % ('SocketClosed', 'received an exception while mid-way through a data reply (exception: "%s", read content: "%s")' % (exc, log.escape(bytes(raw_content).decode('utf-8')))))