[tor-commits] [stem/master] Make it possible to use a function to connect to the async controller

atagar at torproject.org atagar at torproject.org
Thu Jul 16 01:28:59 UTC 2020


commit 459612e63181218d79d2a42ab5b0eebd0cb206bf
Author: Illia Volochii <illia.volochii at gmail.com>
Date:   Tue Apr 28 23:54:52 2020 +0300

    Make it possible to use a function to connect to the async controller
---
 stem/connection.py              | 76 ++++++++++++++++++++++++-----------------
 stem/control.py                 | 11 ++++--
 test/unit/connection/connect.py | 48 ++++++++++++--------------
 3 files changed, 74 insertions(+), 61 deletions(-)

diff --git a/stem/connection.py b/stem/connection.py
index 12330ca3..3d240070 100644
--- a/stem/connection.py
+++ b/stem/connection.py
@@ -133,6 +133,7 @@ import getpass
 import hashlib
 import hmac
 import os
+import threading
 
 import stem.control
 import stem.response
@@ -253,6 +254,31 @@ def connect(control_port: Tuple[str, Union[str, int]] = ('127.0.0.1', 'default')
 
   # TODO: change this function's API so we can provide a concrete type
 
+  if controller is None or not issubclass(controller, stem.control.Controller):
+    raise ValueError('Controller should be a stem.control.BaseController subclass.')
+
+  async_controller_thread = stem.control._AsyncControllerThread()
+  async_controller_thread.start()
+
+  connect_coroutine = _connect_async(control_port, control_socket, password, password_prompt, chroot_path, controller)
+  try:
+    connection = asyncio.run_coroutine_threadsafe(connect_coroutine, async_controller_thread.loop).result()
+    if connection is None and async_controller_thread.is_alive():
+      async_controller_thread.join()
+    return connection
+  except:
+    if async_controller_thread.is_alive():
+      async_controller_thread.join()
+    raise
+
+
+async def connect_async(control_port = ('127.0.0.1', 'default'), control_socket = '/var/run/tor/control', password = None, password_prompt = False, chroot_path = None, controller = stem.control.AsyncController):
+  if controller and not issubclass(controller, stem.control.BaseController):
+    raise ValueError('The provided controller should be a stem.control.BaseController subclass.')
+  return await _connect_async(control_port, control_socket, password, password_prompt, chroot_path, controller)
+
+
+async def _connect_async(control_port, control_socket, password, password_prompt, chroot_path, controller):
   if control_port is None and control_socket is None:
     raise ValueError('Neither a control port nor control socket were provided. Nothing to connect to.')
   elif control_port:
@@ -266,17 +292,11 @@ def connect(control_port: Tuple[str, Union[str, int]] = ('127.0.0.1', 'default')
   control_connection = None  # type: Optional[stem.socket.ControlSocket]
   error_msg = ''
 
-  async_controller_thread = stem.control._AsyncControllerThread()
-  async_controller_thread.start()
-
-  def connect_socket(socket):
-    asyncio.run_coroutine_threadsafe(socket.connect(), async_controller_thread.loop).result()
-
   if control_socket:
     if os.path.exists(control_socket):
       try:
         control_connection = stem.socket.ControlSocketFile(control_socket)
-        connect_socket(control_connection)
+        await control_connection.connect()
       except stem.SocketError as exc:
         error_msg = CONNECT_MESSAGES['unable_to_use_socket'].format(path = control_socket, error = exc)
     else:
@@ -290,7 +310,7 @@ def connect(control_port: Tuple[str, Union[str, int]] = ('127.0.0.1', 'default')
         control_connection = _connection_for_default_port(address)
       else:
         control_connection = stem.socket.ControlPort(address, int(port))
-      connect_socket(control_connection)
+      await control_connection.connect()
     except stem.SocketError as exc:
       error_msg = CONNECT_MESSAGES['unable_to_use_port'].format(address = address, port = port, error = exc)
 
@@ -304,14 +324,12 @@ def connect(control_port: Tuple[str, Union[str, int]] = ('127.0.0.1', 'default')
       error_msg = CONNECT_MESSAGES['no_control_port'] if is_tor_running else CONNECT_MESSAGES['tor_isnt_running']
 
     print(error_msg)
-    if async_controller_thread.is_alive():
-      async_controller_thread.join()
     return None
 
-  return _connect_auth(control_connection, password, password_prompt, chroot_path, controller, async_controller_thread)
+  return await _connect_auth(control_connection, password, password_prompt, chroot_path, controller)
 
 
-def _connect_auth(control_socket: stem.socket.ControlSocket, password: str, password_prompt: bool, chroot_path: str, controller: Optional[Type[stem.control.BaseController]], async_controller_thread: 'threading.Thread') -> Any:
+async def _connect_auth(control_socket: stem.socket.ControlSocket, password: str, password_prompt: bool, chroot_path: str, controller: Optional[Type[stem.control.BaseController]]) -> Any:
   """
   Helper for the connect_* functions that authenticates the socket and
   constructs the controller.
@@ -327,61 +345,55 @@ def _connect_auth(control_socket: stem.socket.ControlSocket, password: str, pass
   :returns: authenticated control connection, the type based on the controller argument
   """
 
-  def run_coroutine(coroutine):
-    asyncio.run_coroutine_threadsafe(coroutine, async_controller_thread.loop).result()
-
-  def close_control_socket():
-    run_coroutine(control_socket.close())
-    if async_controller_thread.is_alive():
-      async_controller_thread.join()
-
   try:
-    run_coroutine(authenticate(control_socket, password, chroot_path))
+    await authenticate(control_socket, password, chroot_path)
 
     if controller is None:
       return control_socket
-    else:
-      return controller(control_socket, is_authenticated = True, started_async_controller_thread = async_controller_thread)
+    elif issubclass(controller, stem.control.BaseController):
+      return controller(control_socket, is_authenticated = True)
+    elif issubclass(controller, stem.control.Controller):
+      return controller(control_socket, is_authenticated = True, started_async_controller_thread = threading.current_thread())
   except IncorrectSocketType:
     if isinstance(control_socket, stem.socket.ControlPort):
       print(CONNECT_MESSAGES['wrong_port_type'].format(port = control_socket.port))
     else:
       print(CONNECT_MESSAGES['wrong_socket_type'])
 
-    close_control_socket()
+    await control_socket.close()
     return None
   except UnrecognizedAuthMethods as exc:
     print(CONNECT_MESSAGES['uncrcognized_auth_type'].format(auth_methods = ', '.join(exc.unknown_auth_methods)))
-    close_control_socket()
+    await control_socket.close()
     return None
   except IncorrectPassword:
     print(CONNECT_MESSAGES['incorrect_password'])
-    close_control_socket()
+    await control_socket.close()
     return None
   except MissingPassword:
     if password is not None:
-      close_control_socket()
+      await control_socket.close()
       raise ValueError(CONNECT_MESSAGES['missing_password_bug'])
 
     if password_prompt:
       try:
         password = getpass.getpass(CONNECT_MESSAGES['password_prompt'] + ' ')
       except KeyboardInterrupt:
-        close_control_socket()
+        await control_socket.close()
         return None
 
-      return _connect_auth(control_socket, password, password_prompt, chroot_path, controller, async_controller_thread)
+      return await _connect_auth(control_socket, password, password_prompt, chroot_path, controller)
     else:
       print(CONNECT_MESSAGES['needs_password'])
-      close_control_socket()
+      await control_socket.close()
       return None
   except UnreadableCookieFile as exc:
     print(CONNECT_MESSAGES['unreadable_cookie_file'].format(path = exc.cookie_path, issue = str(exc)))
-    close_control_socket()
+    await control_socket.close()
     return None
   except AuthenticationFailure as exc:
     print(CONNECT_MESSAGES['general_auth_failure'].format(error = exc))
-    close_control_socket()
+    await control_socket.close()
     return None
 
 
diff --git a/stem/control.py b/stem/control.py
index 21a89a5a..b2d2d9d7 100644
--- a/stem/control.py
+++ b/stem/control.py
@@ -3946,10 +3946,15 @@ class Controller(_ControllerClassMethodMixin, _BaseControllerSocketMixin):
     self._socket = self._async_controller._socket
 
   def _init_async_controller(self, control_socket: 'stem.socket.ControlSocket', is_authenticated: bool) -> 'stem.control.AsyncController':
-    async def init_async_controller():
-      return AsyncController(control_socket, is_authenticated)
+    # The asynchronous controller should be initialized in the thread where its
+    # methods will be executed.
+    if self._async_controller_thread != threading.current_thread():
+      async def init_async_controller() -> 'stem.control.AsyncController':
+        return AsyncController(control_socket, is_authenticated)
 
-    return asyncio.run_coroutine_threadsafe(init_async_controller(), self._asyncio_loop).result()
+      return asyncio.run_coroutine_threadsafe(init_async_controller(), self._asyncio_loop).result()
+
+    return AsyncController(control_socket, is_authenticated)
 
   def _execute_async_method(self, method_name: str, *args: Any, **kwargs: Any) -> Any:
     return asyncio.run_coroutine_threadsafe(
diff --git a/test/unit/connection/connect.py b/test/unit/connection/connect.py
index 8ba2770c..2112f678 100644
--- a/test/unit/connection/connect.py
+++ b/test/unit/connection/connect.py
@@ -1,7 +1,7 @@
 """
 Unit tests for the stem.connection.connect function.
 """
-import contextlib
+
 import io
 import unittest
 
@@ -11,7 +11,11 @@ import stem.socket
 
 from unittest.mock import Mock, patch
 
-from test.unit.async_util import coro_func_raising_exc, coro_func_returning_value
+from test.unit.async_util import (
+  async_test,
+  coro_func_raising_exc,
+  coro_func_returning_value,
+)
 
 
 class TestConnect(unittest.TestCase):
@@ -20,7 +24,7 @@ class TestConnect(unittest.TestCase):
   @patch('os.path.exists', Mock(return_value = True))
   @patch('stem.socket.ControlSocketFile', Mock(side_effect = stem.SocketError('failed')))
   @patch('stem.socket.ControlPort', Mock(side_effect = stem.SocketError('failed')))
-  @patch('stem.connection._connect_auth', Mock())
+  @patch('stem.connection._connect_auth', Mock(side_effect = coro_func_returning_value(None)))
   def test_failue_with_the_default_endpoint(self, is_running_mock, stdout_mock):
     is_running_mock.return_value = False
     self._assert_connect_fails_with({}, stdout_mock, "Unable to connect to tor. Are you sure it's running?")
@@ -33,7 +37,7 @@ class TestConnect(unittest.TestCase):
   @patch('stem.util.system.is_running', Mock(return_value = True))
   @patch('stem.socket.ControlSocketFile', Mock(side_effect = stem.SocketError('failed')))
   @patch('stem.socket.ControlPort', Mock(side_effect = stem.SocketError('failed')))
-  @patch('stem.connection._connect_auth', Mock())
+  @patch('stem.connection._connect_auth', Mock(side_effect = coro_func_returning_value(None)))
   def test_failure_with_a_custom_endpoint(self, path_exists_mock, stdout_mock):
     path_exists_mock.return_value = True
     self._assert_connect_fails_with({'control_port': ('127.0.0.1', 80), 'control_socket': None}, stdout_mock, "Unable to connect to 127.0.0.1:80: failed")
@@ -45,7 +49,7 @@ class TestConnect(unittest.TestCase):
 
   @patch('stem.socket.ControlPort')
   @patch('os.path.exists', Mock(return_value = False))
-  @patch('stem.connection._connect_auth', Mock())
+  @patch('stem.connection._connect_auth', Mock(side_effect = coro_func_returning_value(None)))
   def test_getting_a_control_port(self, port_mock):
     port_connect_mock = port_mock.return_value.connect
     port_connect_mock.side_effect = coro_func_returning_value(None)
@@ -59,7 +63,7 @@ class TestConnect(unittest.TestCase):
 
   @patch('stem.socket.ControlSocketFile')
   @patch('os.path.exists', Mock(return_value = True))
-  @patch('stem.connection._connect_auth', Mock())
+  @patch('stem.connection._connect_auth', Mock(side_effect = coro_func_returning_value(None)))
   def test_getting_a_control_socket(self, socket_mock):
     socket_connect_mock = socket_mock.return_value.connect
     socket_connect_mock.side_effect = coro_func_returning_value(None)
@@ -92,21 +96,22 @@ class TestConnect(unittest.TestCase):
     self.assertEqual(msg, stdout_output.strip().lstrip('\x00'))
 
   @patch('stem.connection.authenticate')
-  def test_auth_success(self, authenticate_mock):
+  @async_test
+  async def test_auth_success(self, authenticate_mock):
     authenticate_mock.side_effect = coro_func_returning_value(None)
     control_socket = Mock()
 
-    with self._get_thread() as thread:
-      stem.connection._connect_auth(control_socket, None, False, None, None, thread)
-      authenticate_mock.assert_called_with(control_socket, None, None)
-      authenticate_mock.reset_mock()
+    await stem.connection._connect_auth(control_socket, None, False, None, None)
+    authenticate_mock.assert_called_with(control_socket, None, None)
+    authenticate_mock.reset_mock()
 
-      stem.connection._connect_auth(control_socket, 's3krit!!!', False, '/my/chroot', None, thread)
+    await stem.connection._connect_auth(control_socket, 's3krit!!!', False, '/my/chroot', None)
     authenticate_mock.assert_called_with(control_socket, 's3krit!!!', '/my/chroot')
 
   @patch('getpass.getpass')
   @patch('stem.connection.authenticate')
-  def test_auth_success_with_password_prompt(self, authenticate_mock, getpass_mock):
+  @async_test
+  async def test_auth_success_with_password_prompt(self, authenticate_mock, getpass_mock):
     control_socket = Mock()
 
     async def authenticate_mock_func(controller, password, *args):
@@ -120,8 +125,7 @@ class TestConnect(unittest.TestCase):
     authenticate_mock.side_effect = authenticate_mock_func
     getpass_mock.return_value = 'my_password'
 
-    with self._get_thread() as thread:
-      stem.connection._connect_auth(control_socket, None, True, None, None, thread)
+    await stem.connection._connect_auth(control_socket, None, True, None, None)
     authenticate_mock.assert_any_call(control_socket, None, None)
     authenticate_mock.assert_any_call(control_socket, 'my_password', None)
 
@@ -149,9 +153,9 @@ class TestConnect(unittest.TestCase):
     authenticate_mock.side_effect = coro_func_raising_exc(stem.connection.OpenAuthRejected('crazy failure'))
     self._assert_authenticate_fails_with(control_socket, stdout_mock, 'Unable to authenticate: crazy failure')
 
-  def _assert_authenticate_fails_with(self, control_socket, stdout_mock, msg):
-    with self._get_thread() as thread:
-      result = stem.connection._connect_auth(control_socket, None, False, None, None, thread)
+  @async_test
+  async def _assert_authenticate_fails_with(self, control_socket, stdout_mock, msg):
+    result = await stem.connection._connect_auth(control_socket, None, False, None, None)
 
     if result is not None:
       self.fail()  # _connect_auth() was successful
@@ -161,11 +165,3 @@ class TestConnect(unittest.TestCase):
 
     if msg not in stdout_output:
       self.fail("Expected...\n\n%s\n\n... which couldn't be found in...\n\n%s" % (msg, stdout_output))
-
-  @contextlib.contextmanager
-  def _get_thread(self):
-    thread = stem.control._AsyncControllerThread()
-    thread.start()
-    yield thread
-    if thread.is_alive():
-      thread.join()





More information about the tor-commits mailing list