commit 459612e63181218d79d2a42ab5b0eebd0cb206bf Author: Illia Volochii <illia.volochii@gmail.com> Date: Tue Apr 28 23:54:52 2020 +0300
Make it possible to use a function to connect to the async controller --- stem/connection.py | 76 ++++++++++++++++++++++++----------------- stem/control.py | 11 ++++-- test/unit/connection/connect.py | 48 ++++++++++++-------------- 3 files changed, 74 insertions(+), 61 deletions(-)
diff --git a/stem/connection.py b/stem/connection.py index 12330ca3..3d240070 100644 --- a/stem/connection.py +++ b/stem/connection.py @@ -133,6 +133,7 @@ import getpass import hashlib import hmac import os +import threading
import stem.control import stem.response @@ -253,6 +254,31 @@ def connect(control_port: Tuple[str, Union[str, int]] = ('127.0.0.1', 'default')
# TODO: change this function's API so we can provide a concrete type
+ if controller is None or not issubclass(controller, stem.control.Controller): + raise ValueError('Controller should be a stem.control.BaseController subclass.') + + async_controller_thread = stem.control._AsyncControllerThread() + async_controller_thread.start() + + connect_coroutine = _connect_async(control_port, control_socket, password, password_prompt, chroot_path, controller) + try: + connection = asyncio.run_coroutine_threadsafe(connect_coroutine, async_controller_thread.loop).result() + if connection is None and async_controller_thread.is_alive(): + async_controller_thread.join() + return connection + except: + if async_controller_thread.is_alive(): + async_controller_thread.join() + raise + + +async def connect_async(control_port = ('127.0.0.1', 'default'), control_socket = '/var/run/tor/control', password = None, password_prompt = False, chroot_path = None, controller = stem.control.AsyncController): + if controller and not issubclass(controller, stem.control.BaseController): + raise ValueError('The provided controller should be a stem.control.BaseController subclass.') + return await _connect_async(control_port, control_socket, password, password_prompt, chroot_path, controller) + + +async def _connect_async(control_port, control_socket, password, password_prompt, chroot_path, controller): if control_port is None and control_socket is None: raise ValueError('Neither a control port nor control socket were provided. Nothing to connect to.') elif control_port: @@ -266,17 +292,11 @@ def connect(control_port: Tuple[str, Union[str, int]] = ('127.0.0.1', 'default') control_connection = None # type: Optional[stem.socket.ControlSocket] error_msg = ''
- async_controller_thread = stem.control._AsyncControllerThread() - async_controller_thread.start() - - def connect_socket(socket): - asyncio.run_coroutine_threadsafe(socket.connect(), async_controller_thread.loop).result() - if control_socket: if os.path.exists(control_socket): try: control_connection = stem.socket.ControlSocketFile(control_socket) - connect_socket(control_connection) + await control_connection.connect() except stem.SocketError as exc: error_msg = CONNECT_MESSAGES['unable_to_use_socket'].format(path = control_socket, error = exc) else: @@ -290,7 +310,7 @@ def connect(control_port: Tuple[str, Union[str, int]] = ('127.0.0.1', 'default') control_connection = _connection_for_default_port(address) else: control_connection = stem.socket.ControlPort(address, int(port)) - connect_socket(control_connection) + await control_connection.connect() except stem.SocketError as exc: error_msg = CONNECT_MESSAGES['unable_to_use_port'].format(address = address, port = port, error = exc)
@@ -304,14 +324,12 @@ def connect(control_port: Tuple[str, Union[str, int]] = ('127.0.0.1', 'default') error_msg = CONNECT_MESSAGES['no_control_port'] if is_tor_running else CONNECT_MESSAGES['tor_isnt_running']
print(error_msg) - if async_controller_thread.is_alive(): - async_controller_thread.join() return None
- return _connect_auth(control_connection, password, password_prompt, chroot_path, controller, async_controller_thread) + return await _connect_auth(control_connection, password, password_prompt, chroot_path, controller)
-def _connect_auth(control_socket: stem.socket.ControlSocket, password: str, password_prompt: bool, chroot_path: str, controller: Optional[Type[stem.control.BaseController]], async_controller_thread: 'threading.Thread') -> Any: +async def _connect_auth(control_socket: stem.socket.ControlSocket, password: str, password_prompt: bool, chroot_path: str, controller: Optional[Type[stem.control.BaseController]]) -> Any: """ Helper for the connect_* functions that authenticates the socket and constructs the controller. @@ -327,61 +345,55 @@ def _connect_auth(control_socket: stem.socket.ControlSocket, password: str, pass :returns: authenticated control connection, the type based on the controller argument """
- def run_coroutine(coroutine): - asyncio.run_coroutine_threadsafe(coroutine, async_controller_thread.loop).result() - - def close_control_socket(): - run_coroutine(control_socket.close()) - if async_controller_thread.is_alive(): - async_controller_thread.join() - try: - run_coroutine(authenticate(control_socket, password, chroot_path)) + await authenticate(control_socket, password, chroot_path)
if controller is None: return control_socket - else: - return controller(control_socket, is_authenticated = True, started_async_controller_thread = async_controller_thread) + elif issubclass(controller, stem.control.BaseController): + return controller(control_socket, is_authenticated = True) + elif issubclass(controller, stem.control.Controller): + return controller(control_socket, is_authenticated = True, started_async_controller_thread = threading.current_thread()) except IncorrectSocketType: if isinstance(control_socket, stem.socket.ControlPort): print(CONNECT_MESSAGES['wrong_port_type'].format(port = control_socket.port)) else: print(CONNECT_MESSAGES['wrong_socket_type'])
- close_control_socket() + await control_socket.close() return None except UnrecognizedAuthMethods as exc: print(CONNECT_MESSAGES['uncrcognized_auth_type'].format(auth_methods = ', '.join(exc.unknown_auth_methods))) - close_control_socket() + await control_socket.close() return None except IncorrectPassword: print(CONNECT_MESSAGES['incorrect_password']) - close_control_socket() + await control_socket.close() return None except MissingPassword: if password is not None: - close_control_socket() + await control_socket.close() raise ValueError(CONNECT_MESSAGES['missing_password_bug'])
if password_prompt: try: password = getpass.getpass(CONNECT_MESSAGES['password_prompt'] + ' ') except KeyboardInterrupt: - close_control_socket() + await control_socket.close() return None
- return _connect_auth(control_socket, password, password_prompt, chroot_path, controller, async_controller_thread) + return await _connect_auth(control_socket, password, password_prompt, chroot_path, controller) else: print(CONNECT_MESSAGES['needs_password']) - close_control_socket() + await control_socket.close() return None except UnreadableCookieFile as exc: print(CONNECT_MESSAGES['unreadable_cookie_file'].format(path = exc.cookie_path, issue = str(exc))) - close_control_socket() + await control_socket.close() return None except AuthenticationFailure as exc: print(CONNECT_MESSAGES['general_auth_failure'].format(error = exc)) - close_control_socket() + await control_socket.close() return None
diff --git a/stem/control.py b/stem/control.py index 21a89a5a..b2d2d9d7 100644 --- a/stem/control.py +++ b/stem/control.py @@ -3946,10 +3946,15 @@ class Controller(_ControllerClassMethodMixin, _BaseControllerSocketMixin): self._socket = self._async_controller._socket
def _init_async_controller(self, control_socket: 'stem.socket.ControlSocket', is_authenticated: bool) -> 'stem.control.AsyncController': - async def init_async_controller(): - return AsyncController(control_socket, is_authenticated) + # The asynchronous controller should be initialized in the thread where its + # methods will be executed. + if self._async_controller_thread != threading.current_thread(): + async def init_async_controller() -> 'stem.control.AsyncController': + return AsyncController(control_socket, is_authenticated)
- return asyncio.run_coroutine_threadsafe(init_async_controller(), self._asyncio_loop).result() + return asyncio.run_coroutine_threadsafe(init_async_controller(), self._asyncio_loop).result() + + return AsyncController(control_socket, is_authenticated)
def _execute_async_method(self, method_name: str, *args: Any, **kwargs: Any) -> Any: return asyncio.run_coroutine_threadsafe( diff --git a/test/unit/connection/connect.py b/test/unit/connection/connect.py index 8ba2770c..2112f678 100644 --- a/test/unit/connection/connect.py +++ b/test/unit/connection/connect.py @@ -1,7 +1,7 @@ """ Unit tests for the stem.connection.connect function. """ -import contextlib + import io import unittest
@@ -11,7 +11,11 @@ import stem.socket
from unittest.mock import Mock, patch
-from test.unit.async_util import coro_func_raising_exc, coro_func_returning_value +from test.unit.async_util import ( + async_test, + coro_func_raising_exc, + coro_func_returning_value, +)
class TestConnect(unittest.TestCase): @@ -20,7 +24,7 @@ class TestConnect(unittest.TestCase): @patch('os.path.exists', Mock(return_value = True)) @patch('stem.socket.ControlSocketFile', Mock(side_effect = stem.SocketError('failed'))) @patch('stem.socket.ControlPort', Mock(side_effect = stem.SocketError('failed'))) - @patch('stem.connection._connect_auth', Mock()) + @patch('stem.connection._connect_auth', Mock(side_effect = coro_func_returning_value(None))) def test_failue_with_the_default_endpoint(self, is_running_mock, stdout_mock): is_running_mock.return_value = False self._assert_connect_fails_with({}, stdout_mock, "Unable to connect to tor. Are you sure it's running?") @@ -33,7 +37,7 @@ class TestConnect(unittest.TestCase): @patch('stem.util.system.is_running', Mock(return_value = True)) @patch('stem.socket.ControlSocketFile', Mock(side_effect = stem.SocketError('failed'))) @patch('stem.socket.ControlPort', Mock(side_effect = stem.SocketError('failed'))) - @patch('stem.connection._connect_auth', Mock()) + @patch('stem.connection._connect_auth', Mock(side_effect = coro_func_returning_value(None))) def test_failure_with_a_custom_endpoint(self, path_exists_mock, stdout_mock): path_exists_mock.return_value = True self._assert_connect_fails_with({'control_port': ('127.0.0.1', 80), 'control_socket': None}, stdout_mock, "Unable to connect to 127.0.0.1:80: failed") @@ -45,7 +49,7 @@ class TestConnect(unittest.TestCase):
@patch('stem.socket.ControlPort') @patch('os.path.exists', Mock(return_value = False)) - @patch('stem.connection._connect_auth', Mock()) + @patch('stem.connection._connect_auth', Mock(side_effect = coro_func_returning_value(None))) def test_getting_a_control_port(self, port_mock): port_connect_mock = port_mock.return_value.connect port_connect_mock.side_effect = coro_func_returning_value(None) @@ -59,7 +63,7 @@ class TestConnect(unittest.TestCase):
@patch('stem.socket.ControlSocketFile') @patch('os.path.exists', Mock(return_value = True)) - @patch('stem.connection._connect_auth', Mock()) + @patch('stem.connection._connect_auth', Mock(side_effect = coro_func_returning_value(None))) def test_getting_a_control_socket(self, socket_mock): socket_connect_mock = socket_mock.return_value.connect socket_connect_mock.side_effect = coro_func_returning_value(None) @@ -92,21 +96,22 @@ class TestConnect(unittest.TestCase): self.assertEqual(msg, stdout_output.strip().lstrip('\x00'))
@patch('stem.connection.authenticate') - def test_auth_success(self, authenticate_mock): + @async_test + async def test_auth_success(self, authenticate_mock): authenticate_mock.side_effect = coro_func_returning_value(None) control_socket = Mock()
- with self._get_thread() as thread: - stem.connection._connect_auth(control_socket, None, False, None, None, thread) - authenticate_mock.assert_called_with(control_socket, None, None) - authenticate_mock.reset_mock() + await stem.connection._connect_auth(control_socket, None, False, None, None) + authenticate_mock.assert_called_with(control_socket, None, None) + authenticate_mock.reset_mock()
- stem.connection._connect_auth(control_socket, 's3krit!!!', False, '/my/chroot', None, thread) + await stem.connection._connect_auth(control_socket, 's3krit!!!', False, '/my/chroot', None) authenticate_mock.assert_called_with(control_socket, 's3krit!!!', '/my/chroot')
@patch('getpass.getpass') @patch('stem.connection.authenticate') - def test_auth_success_with_password_prompt(self, authenticate_mock, getpass_mock): + @async_test + async def test_auth_success_with_password_prompt(self, authenticate_mock, getpass_mock): control_socket = Mock()
async def authenticate_mock_func(controller, password, *args): @@ -120,8 +125,7 @@ class TestConnect(unittest.TestCase): authenticate_mock.side_effect = authenticate_mock_func getpass_mock.return_value = 'my_password'
- with self._get_thread() as thread: - stem.connection._connect_auth(control_socket, None, True, None, None, thread) + await stem.connection._connect_auth(control_socket, None, True, None, None) authenticate_mock.assert_any_call(control_socket, None, None) authenticate_mock.assert_any_call(control_socket, 'my_password', None)
@@ -149,9 +153,9 @@ class TestConnect(unittest.TestCase): authenticate_mock.side_effect = coro_func_raising_exc(stem.connection.OpenAuthRejected('crazy failure')) self._assert_authenticate_fails_with(control_socket, stdout_mock, 'Unable to authenticate: crazy failure')
- def _assert_authenticate_fails_with(self, control_socket, stdout_mock, msg): - with self._get_thread() as thread: - result = stem.connection._connect_auth(control_socket, None, False, None, None, thread) + @async_test + async def _assert_authenticate_fails_with(self, control_socket, stdout_mock, msg): + result = await stem.connection._connect_auth(control_socket, None, False, None, None)
if result is not None: self.fail() # _connect_auth() was successful @@ -161,11 +165,3 @@ class TestConnect(unittest.TestCase):
if msg not in stdout_output: self.fail("Expected...\n\n%s\n\n... which couldn't be found in...\n\n%s" % (msg, stdout_output)) - - @contextlib.contextmanager - def _get_thread(self): - thread = stem.control._AsyncControllerThread() - thread.start() - yield thread - if thread.is_alive(): - thread.join()
tor-commits@lists.torproject.org