[tor-commits] [stem/master] Fix `stem.connection.connect`

atagar at torproject.org atagar at torproject.org
Thu Jul 16 01:28:59 UTC 2020


commit a8e92bbb1532f25da843a29b1b771eeb80c477bd
Author: Illia Volochii <illia.volochii at gmail.com>
Date:   Mon Apr 27 18:57:02 2020 +0300

    Fix `stem.connection.connect`
---
 stem/connection.py | 39 +++++++++++++++++++--------------------
 stem/control.py    | 10 +++++++---
 2 files changed, 26 insertions(+), 23 deletions(-)

diff --git a/stem/connection.py b/stem/connection.py
index 3f7983de..6b985725 100644
--- a/stem/connection.py
+++ b/stem/connection.py
@@ -267,10 +267,17 @@ def connect(control_port: Tuple[str, Union[str, int]] = ('127.0.0.1', 'default')
   control_connection = None  # type: Optional[stem.socket.ControlSocket]
   error_msg = ''
 
+  async_controller_thread = stem.control._AsyncControllerThread()
+  async_controller_thread.start()
+
+  def connect_socket(socket):
+    asyncio.run_coroutine_threadsafe(socket.connect(), async_controller_thread.loop).result()
+
   if control_socket:
     if os.path.exists(control_socket):
       try:
         control_connection = stem.socket.ControlSocketFile(control_socket)
+        connect_socket(control_connection)
       except stem.SocketError as exc:
         error_msg = CONNECT_MESSAGES['unable_to_use_socket'].format(path = control_socket, error = exc)
     else:
@@ -284,6 +291,7 @@ def connect(control_port: Tuple[str, Union[str, int]] = ('127.0.0.1', 'default')
         control_connection = _connection_for_default_port(address)
       else:
         control_connection = stem.socket.ControlPort(address, int(port))
+      connect_socket(control_connection)
     except stem.SocketError as exc:
       error_msg = CONNECT_MESSAGES['unable_to_use_port'].format(address = address, port = port, error = exc)
 
@@ -297,12 +305,14 @@ def connect(control_port: Tuple[str, Union[str, int]] = ('127.0.0.1', 'default')
       error_msg = CONNECT_MESSAGES['no_control_port'] if is_tor_running else CONNECT_MESSAGES['tor_isnt_running']
 
     print(error_msg)
+    if async_controller_thread.is_alive():
+      async_controller_thread.join()
     return None
 
-  return _connect_auth(control_connection, password, password_prompt, chroot_path, controller)
+  return _connect_auth(control_connection, password, password_prompt, chroot_path, controller, async_controller_thread)
 
 
-def _connect_auth(control_socket: stem.socket.ControlSocket, password: str, password_prompt: bool, chroot_path: str, controller: Optional[Type[stem.control.BaseController]]) -> Any:
+def _connect_auth(control_socket: stem.socket.ControlSocket, password: str, password_prompt: bool, chroot_path: str, controller: Optional[Type[stem.control.BaseController]], async_controller_thread: 'stem.control._AsyncControllerThread') -> Any:
   """
   Helper for the connect_* functions that authenticates the socket and
   constructs the controller.
@@ -318,20 +328,13 @@ def _connect_auth(control_socket: stem.socket.ControlSocket, password: str, pass
   :returns: authenticated control connection, the type based on the controller argument
   """
 
-  if controller:
-    loop = controller._asyncio_loop
-    asyncio_thread = None
-  else:
-    loop = asyncio.new_event_loop()
-    asyncio_thread = threading.Thread(target=loop.run_forever, name='async_auth')
-    asyncio_thread.setDaemon(True)
-    asyncio_thread.start()
-
   def run_coroutine(coroutine):
-    asyncio.run_coroutine_threadsafe(coroutine, loop).result()
+    asyncio.run_coroutine_threadsafe(coroutine, async_controller_thread.loop).result()
 
   def close_control_socket():
     run_coroutine(control_socket.close())
+    if async_controller_thread.is_alive():
+      async_controller_thread.join()
 
   try:
     run_coroutine(authenticate(control_socket, password, chroot_path))
@@ -339,7 +342,7 @@ def _connect_auth(control_socket: stem.socket.ControlSocket, password: str, pass
     if controller is None:
       return control_socket
     else:
-      return controller(control_socket, is_authenticated = True)
+      return controller(control_socket, is_authenticated = True, started_async_controller_thread = async_controller_thread)
   except IncorrectSocketType:
     if isinstance(control_socket, stem.socket.ControlPort):
       print(CONNECT_MESSAGES['wrong_port_type'].format(port = control_socket.port))
@@ -381,12 +384,6 @@ def _connect_auth(control_socket: stem.socket.ControlSocket, password: str, pass
     print(CONNECT_MESSAGES['general_auth_failure'].format(error = exc))
     close_control_socket()
     return None
-  finally:
-    if asyncio_thread:
-      loop.call_soon_threadsafe(loop.stop)
-      if asyncio_thread.is_alive():
-        asyncio_thread.join()
-      loop.close()
 
 
 async def authenticate(controller: Union[stem.control.BaseController, stem.socket.ControlSocket], password: Optional[str] = None, chroot_path: Optional[str] = None, protocolinfo_response: Optional[stem.response.protocolinfo.ProtocolInfoResponse] = None) -> None:
@@ -994,7 +991,9 @@ async def get_protocolinfo(controller: Union[stem.control.BaseController, stem.s
   # next followed by authentication. Transparently reconnect if that happens.
 
   if not protocolinfo_response or str(protocolinfo_response) == 'Authentication required.':
-    controller.connect()
+    potential_coroutine = controller.connect()
+    if asyncio.iscoroutine(potential_coroutine):
+      await potential_coroutine
 
     try:
       protocolinfo_response = await _msg(controller, 'PROTOCOLINFO 1')
diff --git a/stem/control.py b/stem/control.py
index 20173a7a..21a89a5a 100644
--- a/stem/control.py
+++ b/stem/control.py
@@ -3933,9 +3933,13 @@ class Controller(_ControllerClassMethodMixin, _BaseControllerSocketMixin):
     instance.connect()
     return instance
 
-  def __init__(self, control_socket: 'stem.socket.ControlSocket', is_authenticated: bool = False) -> None:
-    self._async_controller_thread = _AsyncControllerThread()
-    self._async_controller_thread.start()
+  def __init__(self, control_socket: 'stem.socket.ControlSocket', is_authenticated: bool = False,
+               started_async_controller_thread: Optional['_AsyncControllerThread'] = None) -> None:
+    if started_async_controller_thread:
+      self._async_controller_thread = started_async_controller_thread
+    else:
+      self._async_controller_thread = _AsyncControllerThread()
+      self._async_controller_thread.start()
     self._asyncio_loop = self._async_controller_thread.loop
 
     self._async_controller = self._init_async_controller(control_socket, is_authenticated)





More information about the tor-commits mailing list