commit 0f727a91abb27e487fad49794a428f78e1582f31
Author: Illia Volochii <illia.volochii(a)gmail.com>
Date: Sat Mar 14 22:01:22 2020 +0200
Start using asyncio in "stem/async_socket.py"
---
stem/async_socket.py | 175 ++++++++++++++++++---------------------------------
1 file changed, 62 insertions(+), 113 deletions(-)
diff --git a/stem/async_socket.py b/stem/async_socket.py
index 2ef42dd5..512e6bde 100644
--- a/stem/async_socket.py
+++ b/stem/async_socket.py
@@ -72,6 +72,7 @@ Tor...
from __future__ import absolute_import
+import asyncio
import re
import socket
import ssl
@@ -97,7 +98,8 @@ class BaseSocket(object):
"""
def __init__(self):
- self._socket, self._socket_file = None, None
+ self._reader = None
+ self._writer = None
self._is_alive = False
self._connection_time = 0.0 # time when we last connected or disconnected
@@ -151,7 +153,7 @@ class BaseSocket(object):
return self._connection_time
- def connect(self):
+ async def connect(self):
"""
Connects to a new socket, closing our previous one if we're already
attached.
@@ -165,11 +167,10 @@ class BaseSocket(object):
# calls no longer block (raising SocketClosed instead).
if self.is_alive():
- self.close()
+ await self.close()
with self._recv_lock:
- self._socket = self._make_socket()
- self._socket_file = self._socket.makefile(mode = 'rwb')
+ self._reader, self._writer = await self._open_connection()
self._is_alive = True
self._connection_time = time.time()
@@ -179,11 +180,11 @@ class BaseSocket(object):
# It's safe to retry, so give it another try if it fails.
try:
- self._connect()
+ await self._connect()
except stem.SocketError:
- self._connect() # single retry
+ await self._connect() # single retry
- def close(self):
+ async def close(self):
"""
Shuts down the socket. If it's already closed then this is a no-op.
"""
@@ -194,32 +195,21 @@ class BaseSocket(object):
is_change = self.is_alive()
- if self._socket:
- # if we haven't yet established a connection then this raises an error
- # socket.error: [Errno 107] Transport endpoint is not connected
+ if self._writer:
+ self._writer.close()
+ # `StreamWriter.wait_closed` was added in Python 3.7.
+ if hasattr(self._writer, 'wait_closed'):
+ await self._writer.wait_closed()
- try:
- self._socket.shutdown(socket.SHUT_RDWR)
- except socket.error:
- pass
-
- self._socket.close()
-
- if self._socket_file:
- try:
- self._socket_file.close()
- except BrokenPipeError:
- pass
-
- self._socket = None
- self._socket_file = None
+ self._reader = None
+ self._writer = None
self._is_alive = False
self._connection_time = time.time()
if is_change:
- self._close()
+ await self._close()
- def _send(self, message, handler):
+ async def _send(self, message, handler):
"""
Send message in a thread safe manner. Handler is expected to be of the form...
@@ -233,17 +223,17 @@ class BaseSocket(object):
if not self.is_alive():
raise stem.SocketClosed()
- handler(self._socket, self._socket_file, message)
+ await handler(self._writer, message)
except stem.SocketClosed:
# if send_message raises a SocketClosed then we should properly shut
# everything down
if self.is_alive():
- self.close()
+ await self.close()
raise
- def _recv(self, handler):
+ async def _recv(self, handler):
"""
Receives a message in a thread safe manner. Handler is expected to be of the form...
@@ -254,15 +244,15 @@ class BaseSocket(object):
with self._recv_lock:
try:
- # makes a temporary reference to the _socket_file because connect()
+ # makes a temporary reference to the _reader because connect()
# and close() may set or unset it
- my_socket, my_socket_file = self._socket, self._socket_file
+ my_reader = self._reader
- if not my_socket or not my_socket_file:
+ if not my_reader:
raise stem.SocketClosed()
- return handler(my_socket, my_socket_file)
+ return await handler(my_reader)
except stem.SocketClosed:
# If recv_message raises a SocketClosed then we should properly shut
# everything down. However, there's a couple cases where this will
@@ -280,7 +270,7 @@ class BaseSocket(object):
if self.is_alive():
if self._send_lock.acquire(False):
- self.close()
+ await self.close()
self._send_lock.release()
raise
@@ -298,37 +288,27 @@ class BaseSocket(object):
return self._send_lock
- def __enter__(self):
+ async def __aenter__(self):
return self
- def __exit__(self, exit_type, value, traceback):
- self.close()
+ async def __aexit__(self, exit_type, value, traceback):
+ await self.close()
- def _connect(self):
+ async def _connect(self):
"""
Connection callback that can be overwritten by subclasses and wrappers.
"""
pass
- def _close(self):
+ async def _close(self):
"""
Disconnection callback that can be overwritten by subclasses and wrappers.
"""
pass
- def _make_socket(self):
- """
- Constructs and connects new socket. This is implemented by subclasses.
-
- :returns: **socket.socket** for our configuration
-
- :raises:
- * :class:`stem.SocketError` if unable to make a socket
- * **NotImplementedError** if not implemented by a subclass
- """
-
+ async def _open_connection(self):
raise NotImplementedError('Unsupported Operation: this should be implemented by the BaseSocket subclass')
@@ -344,26 +324,19 @@ class RelaySocket(BaseSocket):
:var int port: ORPort our socket connects to
"""
- def __init__(self, address = '127.0.0.1', port = 9050, connect = True):
+ def __init__(self, address = '127.0.0.1', port = 9050):
"""
RelaySocket constructor.
:param str address: ip address of the relay
:param int port: orport of the relay
- :param bool connect: connects to the socket if True, leaves it unconnected otherwise
-
- :raises: :class:`stem.SocketError` if connect is **True** and we're
- unable to establish a connection
"""
super(RelaySocket, self).__init__()
self.address = address
self.port = port
- if connect:
- self.connect()
-
- def send(self, message):
+ async def send(self, message):
"""
Sends a message to the relay's ORPort.
@@ -374,9 +347,9 @@ class RelaySocket(BaseSocket):
* :class:`stem.SocketClosed` if the socket is known to be shut down
"""
- self._send(message, lambda s, sf, msg: _write_to_socket(sf, msg))
+ await self._send(message, _write_to_socket)
- def recv(self, timeout = None):
+ async def recv(self, timeout = None):
"""
Receives a message from the relay.
@@ -390,30 +363,24 @@ class RelaySocket(BaseSocket):
* :class:`stem.SocketClosed` if the socket closes before we receive a complete message
"""
- def wrapped_recv(s, sf):
+ async def wrapped_recv(reader):
+ read_coroutine = reader.read(1024)
if timeout is None:
- return s.recv()
+ return await read_coroutine
else:
- s.setblocking(0)
- s.settimeout(timeout)
-
try:
- return s.recv()
- except (socket.timeout, ssl.SSLError, ssl.SSLWantReadError):
+ return await asyncio.wait_for(read_coroutine, timeout)
+ except (asyncio.TimeoutError, ssl.SSLError, ssl.SSLWantReadError):
return None
- finally:
- s.setblocking(1)
- return self._recv(wrapped_recv)
+ return await self._recv(wrapped_recv)
def is_localhost(self):
return self.address == '127.0.0.1'
- def _make_socket(self):
+ async def _open_connection(self):
try:
- relay_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- relay_socket.connect((self.address, self.port))
- return ssl.wrap_socket(relay_socket)
+ return await asyncio.open_connection(self.address, self.port, ssl=ssl.SSLContext())
except socket.error as exc:
raise stem.SocketError(exc)
@@ -431,7 +398,7 @@ class ControlSocket(BaseSocket):
def __init__(self):
super(ControlSocket, self).__init__()
- def send(self, message):
+ async def send(self, message):
"""
Formats and sends a message to the control socket. For more information see
the :func:`~stem.socket.send_message` function.
@@ -443,9 +410,9 @@ class ControlSocket(BaseSocket):
* :class:`stem.SocketClosed` if the socket is known to be shut down
"""
- self._send(message, lambda s, sf, msg: send_message(sf, msg))
+ await self._send(message, send_message)
- def recv(self):
+ async def recv(self):
"""
Receives a message from the control socket, blocking until we've received
one. For more information see the :func:`~stem.socket.recv_message` function.
@@ -457,7 +424,7 @@ class ControlSocket(BaseSocket):
* :class:`stem.SocketClosed` if the socket closes before we receive a complete message
"""
- return self._recv(lambda s, sf: recv_message(sf))
+ return await self._recv(recv_message)
class ControlPort(ControlSocket):
@@ -469,33 +436,24 @@ class ControlPort(ControlSocket):
:var int port: ControlPort our socket connects to
"""
- def __init__(self, address = '127.0.0.1', port = 9051, connect = True):
+ def __init__(self, address = '127.0.0.1', port = 9051):
"""
ControlPort constructor.
:param str address: ip address of the controller
:param int port: port number of the controller
- :param bool connect: connects to the socket if True, leaves it unconnected otherwise
-
- :raises: :class:`stem.SocketError` if connect is **True** and we're
- unable to establish a connection
"""
super(ControlPort, self).__init__()
self.address = address
self.port = port
- if connect:
- self.connect()
-
def is_localhost(self):
return self.address == '127.0.0.1'
- def _make_socket(self):
+ async def _open_connection(self):
try:
- control_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- control_socket.connect((self.address, self.port))
- return control_socket
+ return await asyncio.open_connection(self.address, self.port)
except socket.error as exc:
raise stem.SocketError(exc)
@@ -508,36 +466,27 @@ class ControlSocketFile(ControlSocket):
:var str path: filesystem path of the socket we connect to
"""
- def __init__(self, path = '/var/run/tor/control', connect = True):
+ def __init__(self, path = '/var/run/tor/control'):
"""
ControlSocketFile constructor.
:param str socket_path: path where the control socket is located
- :param bool connect: connects to the socket if True, leaves it unconnected otherwise
-
- :raises: :class:`stem.SocketError` if connect is **True** and we're
- unable to establish a connection
"""
super(ControlSocketFile, self).__init__()
self.path = path
- if connect:
- self.connect()
-
def is_localhost(self):
return True
- def _make_socket(self):
+ async def _open_connection(self):
try:
- control_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
- control_socket.connect(self.path)
- return control_socket
+ return await asyncio.open_unix_connection(self.path)
except socket.error as exc:
raise stem.SocketError(exc)
-def send_message(control_file, message, raw = False):
+async def send_message(writer, message, raw = False):
"""
Sends a message to the control socket, adding the expected formatting for
single verses multi-line messages. Neither message type should contain an
@@ -572,7 +521,7 @@ def send_message(control_file, message, raw = False):
if not raw:
message = send_formatting(message)
- _write_to_socket(control_file, message)
+ await _write_to_socket(writer, message)
if log.is_tracing():
log_message = message.replace('\r\n', '\n').rstrip()
@@ -580,10 +529,10 @@ def send_message(control_file, message, raw = False):
log.trace('Sent to tor:%s%s' % (msg_div, log_message))
-def _write_to_socket(socket_file, message):
+async def _write_to_socket(writer, message):
try:
- socket_file.write(stem.util.str_tools._to_bytes(message))
- socket_file.flush()
+ writer.write(stem.util.str_tools._to_bytes(message))
+ await writer.drain()
except socket.error as exc:
log.info('Failed to send: %s' % exc)
@@ -603,7 +552,7 @@ def _write_to_socket(socket_file, message):
raise stem.SocketClosed('file has been closed')
-def recv_message(control_file, arrived_at = None):
+async def recv_message(reader, arrived_at = None):
"""
Pulls from a control socket until we either have a complete message or
encounter a problem.
@@ -623,7 +572,7 @@ def recv_message(control_file, arrived_at = None):
while True:
try:
- line = control_file.readline()
+ line = await reader.readline()
except AttributeError:
# if the control_file has been closed then we will receive:
# AttributeError: 'NoneType' object has no attribute 'recv'
@@ -689,7 +638,7 @@ def recv_message(control_file, arrived_at = None):
while True:
try:
- line = control_file.readline()
+ line = await reader.readline()
raw_content += line
except socket.error as exc:
log.info(ERROR_MSG % ('SocketClosed', 'received an exception while mid-way through a data reply (exception: "%s", read content: "%s")' % (exc, log.escape(bytes(raw_content)))))