commit a8e92bbb1532f25da843a29b1b771eeb80c477bd
Author: Illia Volochii <illia.volochii(a)gmail.com>
Date: Mon Apr 27 18:57:02 2020 +0300
Fix `stem.connection.connect`
---
stem/connection.py | 39 +++++++++++++++++++--------------------
stem/control.py | 10 +++++++---
2 files changed, 26 insertions(+), 23 deletions(-)
diff --git a/stem/connection.py b/stem/connection.py
index 3f7983de..6b985725 100644
--- a/stem/connection.py
+++ b/stem/connection.py
@@ -267,10 +267,17 @@ def connect(control_port: Tuple[str, Union[str, int]] = ('127.0.0.1', 'default')
control_connection = None # type: Optional[stem.socket.ControlSocket]
error_msg = ''
+ async_controller_thread = stem.control._AsyncControllerThread()
+ async_controller_thread.start()
+
+ def connect_socket(socket):
+ asyncio.run_coroutine_threadsafe(socket.connect(), async_controller_thread.loop).result()
+
if control_socket:
if os.path.exists(control_socket):
try:
control_connection = stem.socket.ControlSocketFile(control_socket)
+ connect_socket(control_connection)
except stem.SocketError as exc:
error_msg = CONNECT_MESSAGES['unable_to_use_socket'].format(path = control_socket, error = exc)
else:
@@ -284,6 +291,7 @@ def connect(control_port: Tuple[str, Union[str, int]] = ('127.0.0.1', 'default')
control_connection = _connection_for_default_port(address)
else:
control_connection = stem.socket.ControlPort(address, int(port))
+ connect_socket(control_connection)
except stem.SocketError as exc:
error_msg = CONNECT_MESSAGES['unable_to_use_port'].format(address = address, port = port, error = exc)
@@ -297,12 +305,14 @@ def connect(control_port: Tuple[str, Union[str, int]] = ('127.0.0.1', 'default')
error_msg = CONNECT_MESSAGES['no_control_port'] if is_tor_running else CONNECT_MESSAGES['tor_isnt_running']
print(error_msg)
+ if async_controller_thread.is_alive():
+ async_controller_thread.join()
return None
- return _connect_auth(control_connection, password, password_prompt, chroot_path, controller)
+ return _connect_auth(control_connection, password, password_prompt, chroot_path, controller, async_controller_thread)
-def _connect_auth(control_socket: stem.socket.ControlSocket, password: str, password_prompt: bool, chroot_path: str, controller: Optional[Type[stem.control.BaseController]]) -> Any:
+def _connect_auth(control_socket: stem.socket.ControlSocket, password: str, password_prompt: bool, chroot_path: str, controller: Optional[Type[stem.control.BaseController]], async_controller_thread: 'threading.Thread') -> Any:
"""
Helper for the connect_* functions that authenticates the socket and
constructs the controller.
@@ -318,20 +328,13 @@ def _connect_auth(control_socket: stem.socket.ControlSocket, password: str, pass
:returns: authenticated control connection, the type based on the controller argument
"""
- if controller:
- loop = controller._asyncio_loop
- asyncio_thread = None
- else:
- loop = asyncio.new_event_loop()
- asyncio_thread = threading.Thread(target=loop.run_forever, name='async_auth')
- asyncio_thread.setDaemon(True)
- asyncio_thread.start()
-
def run_coroutine(coroutine):
- asyncio.run_coroutine_threadsafe(coroutine, loop).result()
+ asyncio.run_coroutine_threadsafe(coroutine, async_controller_thread.loop).result()
def close_control_socket():
run_coroutine(control_socket.close())
+ if async_controller_thread.is_alive():
+ async_controller_thread.join()
try:
run_coroutine(authenticate(control_socket, password, chroot_path))
@@ -339,7 +342,7 @@ def _connect_auth(control_socket: stem.socket.ControlSocket, password: str, pass
if controller is None:
return control_socket
else:
- return controller(control_socket, is_authenticated = True)
+ return controller(control_socket, is_authenticated = True, started_async_controller_thread = async_controller_thread)
except IncorrectSocketType:
if isinstance(control_socket, stem.socket.ControlPort):
print(CONNECT_MESSAGES['wrong_port_type'].format(port = control_socket.port))
@@ -381,12 +384,6 @@ def _connect_auth(control_socket: stem.socket.ControlSocket, password: str, pass
print(CONNECT_MESSAGES['general_auth_failure'].format(error = exc))
close_control_socket()
return None
- finally:
- if asyncio_thread:
- loop.call_soon_threadsafe(loop.stop)
- if asyncio_thread.is_alive():
- asyncio_thread.join()
- loop.close()
async def authenticate(controller: Union[stem.control.BaseController, stem.socket.ControlSocket], password: Optional[str] = None, chroot_path: Optional[str] = None, protocolinfo_response: Optional[stem.response.protocolinfo.ProtocolInfoResponse] = None) -> None:
@@ -994,7 +991,9 @@ async def get_protocolinfo(controller: Union[stem.control.BaseController, stem.s
# next followed by authentication. Transparently reconnect if that happens.
if not protocolinfo_response or str(protocolinfo_response) == 'Authentication required.':
- controller.connect()
+ potential_coroutine = controller.connect()
+ if asyncio.iscoroutine(potential_coroutine):
+ await potential_coroutine
try:
protocolinfo_response = await _msg(controller, 'PROTOCOLINFO 1')
diff --git a/stem/control.py b/stem/control.py
index 20173a7a..21a89a5a 100644
--- a/stem/control.py
+++ b/stem/control.py
@@ -3933,9 +3933,13 @@ class Controller(_ControllerClassMethodMixin, _BaseControllerSocketMixin):
instance.connect()
return instance
- def __init__(self, control_socket: 'stem.socket.ControlSocket', is_authenticated: bool = False) -> None:
- self._async_controller_thread = _AsyncControllerThread()
- self._async_controller_thread.start()
+ def __init__(self, control_socket: 'stem.socket.ControlSocket', is_authenticated: bool = False, started_async_controller_thread: Optional['threading.Thread'] = None) -> None:
+ def __init__(self, control_socket, is_authenticated = False, started_async_controller_thread = None):
+ if started_async_controller_thread:
+ self._async_controller_thread = started_async_controller_thread
+ else:
+ self._async_controller_thread = _AsyncControllerThread()
+ self._async_controller_thread.start()
self._asyncio_loop = self._async_controller_thread.loop
self._async_controller = self._init_async_controller(control_socket, is_authenticated)