commit 1db0e6b84e870a5f228f3a770daca542bdef5d4e Author: Illia Volochii illia.volochii@gmail.com Date: Sun Apr 26 22:31:12 2020 +0300
Fix unit tests --- test/unit/connection/authentication.py | 36 +++-- test/unit/connection/connect.py | 19 +-- test/unit/control/controller.py | 254 +++++++++++++++++++-------------- test/unit/response/control_message.py | 10 +- 4 files changed, 188 insertions(+), 131 deletions(-)
diff --git a/test/unit/connection/authentication.py b/test/unit/connection/authentication.py index f6241e0e..596fa50c 100644 --- a/test/unit/connection/authentication.py +++ b/test/unit/connection/authentication.py @@ -14,41 +14,52 @@ import unittest import stem.connection import test
-from unittest.mock import Mock, patch +from unittest.mock import patch
from stem.response import ControlMessage from stem.util import log +from test.unit.util.asynchronous import ( + async_test, + coro_func_raising_exc, + coro_func_returning_value, +)
class TestAuthenticate(unittest.TestCase): @patch('stem.connection.get_protocolinfo') - @patch('stem.connection.authenticate_none', Mock()) - def test_with_get_protocolinfo(self, get_protocolinfo_mock): + @patch('stem.connection.authenticate_none') + @async_test + async def test_with_get_protocolinfo(self, authenticate_none_mock, get_protocolinfo_mock): """ Tests the authenticate() function when it needs to make a get_protocolinfo. """
# tests where get_protocolinfo succeeds
+ authenticate_none_mock.side_effect = coro_func_returning_value(None) + protocolinfo_message = ControlMessage.from_str('250-PROTOCOLINFO 1\r\n250 OK\r\n', 'PROTOCOLINFO') protocolinfo_message.auth_methods = (stem.connection.AuthMethod.NONE, ) - get_protocolinfo_mock.return_value = protocolinfo_message + get_protocolinfo_mock.side_effect = coro_func_returning_value(protocolinfo_message)
- stem.connection.authenticate(None) + await stem.connection.authenticate(None)
# tests where get_protocolinfo raises an exception
get_protocolinfo_mock.side_effect = stem.ProtocolError - self.assertRaises(stem.connection.IncorrectSocketType, stem.connection.authenticate, None) + with self.assertRaises(stem.connection.IncorrectSocketType): + await stem.connection.authenticate(None)
get_protocolinfo_mock.side_effect = stem.SocketError - self.assertRaises(stem.connection.AuthenticationFailure, stem.connection.authenticate, None) + with self.assertRaises(stem.connection.AuthenticationFailure): + await stem.connection.authenticate(None)
@patch('stem.connection.authenticate_none') @patch('stem.connection.authenticate_password') @patch('stem.connection.authenticate_cookie') @patch('stem.connection.authenticate_safecookie') - def test_all_use_cases(self, authenticate_safecookie_mock, authenticate_cookie_mock, authenticate_password_mock, authenticate_none_mock): + @async_test + async def test_all_use_cases(self, authenticate_safecookie_mock, authenticate_cookie_mock, authenticate_password_mock, authenticate_none_mock): """ Does basic validation that all valid use cases for the PROTOCOLINFO input and dependent functions result in either success or a AuthenticationFailed @@ -133,15 +144,16 @@ class TestAuthenticate(unittest.TestCase): auth_mock, raised_exc = authenticate_safecookie_mock, auth_cookie_exc
if raised_exc: - auth_mock.side_effect = raised_exc + auth_mock.side_effect = coro_func_raising_exc(raised_exc) else: - auth_mock.side_effect = None + auth_mock.side_effect = coro_func_returning_value(None) expect_success = True
if expect_success: - stem.connection.authenticate(None, 'blah', None, protocolinfo) + await stem.connection.authenticate(None, 'blah', None, protocolinfo) else: - self.assertRaises(stem.connection.AuthenticationFailure, stem.connection.authenticate, None, 'blah', None, protocolinfo) + with self.assertRaises(stem.connection.AuthenticationFailure): + await stem.connection.authenticate(None, 'blah', None, protocolinfo)
# revert logging back to normal stem_logger.setLevel(log.logging_level(log.TRACE)) diff --git a/test/unit/connection/connect.py b/test/unit/connection/connect.py index 175a1ebd..d2a22f18 100644 --- a/test/unit/connection/connect.py +++ b/test/unit/connection/connect.py @@ -11,6 +11,8 @@ import stem.socket
from unittest.mock import Mock, patch
+from test.unit.util.asynchronous import coro_func_raising_exc, coro_func_returning_value +
class TestConnect(unittest.TestCase): @patch('sys.stdout', new_callable = io.StringIO) @@ -85,6 +87,7 @@ class TestConnect(unittest.TestCase):
@patch('stem.connection.authenticate') def test_auth_success(self, authenticate_mock): + authenticate_mock.side_effect = coro_func_returning_value(None) control_socket = Mock()
stem.connection._connect_auth(control_socket, None, False, None, None) @@ -99,7 +102,7 @@ class TestConnect(unittest.TestCase): def test_auth_success_with_password_prompt(self, authenticate_mock, getpass_mock): control_socket = Mock()
- def authenticate_mock_func(controller, password, *args): + async def authenticate_mock_func(controller, password, *args): if password is None: raise stem.connection.MissingPassword('no password') elif password == 'my_password': @@ -117,25 +120,25 @@ class TestConnect(unittest.TestCase): @patch('sys.stdout', new_callable = io.StringIO) @patch('stem.connection.authenticate') def test_auth_failure(self, authenticate_mock, stdout_mock): - control_socket = stem.socket.ControlPort(connect = False) + control_socket = stem.socket.ControlPort()
- authenticate_mock.side_effect = stem.connection.IncorrectSocketType('unable to connect to socket') + authenticate_mock.side_effect = coro_func_raising_exc(stem.connection.IncorrectSocketType('unable to connect to socket')) self._assert_authenticate_fails_with(control_socket, stdout_mock, 'Please check in your torrc that 9051 is the ControlPort.')
- control_socket = stem.socket.ControlSocketFile(connect = False) + control_socket = stem.socket.ControlSocketFile()
self._assert_authenticate_fails_with(control_socket, stdout_mock, 'Are you sure the interface you specified belongs to')
- authenticate_mock.side_effect = stem.connection.UnrecognizedAuthMethods('unable to connect', ['telepathy']) + authenticate_mock.side_effect = coro_func_raising_exc(stem.connection.UnrecognizedAuthMethods('unable to connect', ['telepathy'])) self._assert_authenticate_fails_with(control_socket, stdout_mock, 'Tor is using a type of authentication we do not recognize...\n\n telepathy')
- authenticate_mock.side_effect = stem.connection.IncorrectPassword('password rejected') + authenticate_mock.side_effect = coro_func_raising_exc(stem.connection.IncorrectPassword('password rejected')) self._assert_authenticate_fails_with(control_socket, stdout_mock, 'Incorrect password')
- authenticate_mock.side_effect = stem.connection.UnreadableCookieFile('permission denied', '/tmp/my_cookie', False) + authenticate_mock.side_effect = coro_func_raising_exc(stem.connection.UnreadableCookieFile('permission denied', '/tmp/my_cookie', False)) self._assert_authenticate_fails_with(control_socket, stdout_mock, "We were unable to read tor's authentication cookie...\n\n Path: /tmp/my_cookie\n Issue: permission denied")
- authenticate_mock.side_effect = stem.connection.OpenAuthRejected('crazy failure') + authenticate_mock.side_effect = coro_func_raising_exc(stem.connection.OpenAuthRejected('crazy failure')) self._assert_authenticate_fails_with(control_socket, stdout_mock, 'Unable to authenticate: crazy failure')
def _assert_authenticate_fails_with(self, control_socket, stdout_mock, msg): diff --git a/test/unit/control/controller.py b/test/unit/control/controller.py index c0a07e2a..d09b5ca8 100644 --- a/test/unit/control/controller.py +++ b/test/unit/control/controller.py @@ -3,6 +3,7 @@ Unit tests for the stem.control module. The module's primarily exercised via integ tests, but a few bits lend themselves to unit testing. """
+import asyncio import datetime import io import unittest @@ -20,6 +21,11 @@ from stem import ControllerError, DescriptorUnavailable, InvalidArguments, Inval from stem.control import MALFORMED_EVENTS, _parse_circ_path, Listener, Controller, EventType from stem.response import ControlMessage from stem.exit_policy import ExitPolicy +from test.unit.util.asynchronous import ( + async_test, + coro_func_raising_exc, + coro_func_returning_value, +)
NS_DESC = 'r %s %s u5lTXJKGsLKufRLnSyVqT7TdGYw 2012-12-30 22:02:49 77.223.43.54 9001 0\ns Fast Named Running Stable Valid\nw Bandwidth=75' TEST_TIMESTAMP = 12345 @@ -36,8 +42,9 @@ class TestControl(unittest.TestCase): # When initially constructing a controller we need to suppress msg, so our # constructor's SETEVENTS requests pass.
- with patch('stem.control.BaseController.msg', Mock()): + with patch('stem.control.BaseController.msg', Mock(side_effect = coro_func_returning_value(None))): self.controller = Controller(socket) + self.async_controller = self.controller._async_controller
self.circ_listener = Mock() self.controller.add_event_listener(self.circ_listener, EventType.CIRC) @@ -59,18 +66,24 @@ class TestControl(unittest.TestCase): for event in stem.control.EventType: self.assertTrue(stem.control.event_description(event) is not None)
- @patch('stem.control.Controller.msg') + @patch('stem.control.AsyncController.msg') def test_get_info(self, msg_mock): - msg_mock.return_value = ControlMessage.from_str('250-hello=hi right back!\r\n250 OK\r\n', 'GETINFO') + message = ControlMessage.from_str('250-hello=hi right back!\r\n250 OK\r\n', 'GETINFO') + msg_mock.side_effect = coro_func_returning_value(message) self.assertEqual('hi right back!', self.controller.get_info('hello'))
- @patch('stem.control.Controller.msg') - def test_get_info_address_caching(self, msg_mock): - msg_mock.return_value = ControlMessage.from_str('551 Address unknown\r\n') + @patch('stem.control.AsyncController.msg') + @async_test + async def test_get_info_address_caching(self, msg_mock): + def set_message(*args): + message = ControlMessage.from_str(*args) + msg_mock.side_effect = coro_func_returning_value(message)
- self.assertEqual(None, self.controller._last_address_exc) + set_message('551 Address unknown\r\n') + + self.assertEqual(None, self.async_controller._last_address_exc) self.assertRaisesWith(stem.OperationFailed, 'Address unknown', self.controller.get_info, 'address') - self.assertEqual('Address unknown', str(self.controller._last_address_exc)) + self.assertEqual('Address unknown', str(self.async_controller._last_address_exc)) self.assertEqual(1, msg_mock.call_count)
# now that we have a cached failure we should provide that back @@ -80,27 +93,28 @@ class TestControl(unittest.TestCase):
# invalidates the cache, transitioning from no address to having one
- msg_mock.return_value = ControlMessage.from_str('250-address=17.2.89.80\r\n250 OK\r\n', 'GETINFO') + set_message('250-address=17.2.89.80\r\n250 OK\r\n', 'GETINFO') self.assertRaisesWith(stem.OperationFailed, 'Address unknown', self.controller.get_info, 'address') - self.controller._handle_event(ControlMessage.from_str('650 STATUS_SERVER NOTICE EXTERNAL_ADDRESS ADDRESS=17.2.89.80 METHOD=DIRSERV\r\n')) + await self.async_controller._handle_event(ControlMessage.from_str('650 STATUS_SERVER NOTICE EXTERNAL_ADDRESS ADDRESS=17.2.89.80 METHOD=DIRSERV\r\n')) self.assertEqual('17.2.89.80', self.controller.get_info('address'))
# invalidates the cache, transitioning from one address to another
- msg_mock.return_value = ControlMessage.from_str('250-address=80.89.2.17\r\n250 OK\r\n', 'GETINFO') + set_message('250-address=80.89.2.17\r\n250 OK\r\n', 'GETINFO') self.assertEqual('17.2.89.80', self.controller.get_info('address')) - self.controller._handle_event(ControlMessage.from_str('650 STATUS_SERVER NOTICE EXTERNAL_ADDRESS ADDRESS=80.89.2.17 METHOD=DIRSERV\r\n')) + await self.async_controller._handle_event(ControlMessage.from_str('650 STATUS_SERVER NOTICE EXTERNAL_ADDRESS ADDRESS=80.89.2.17 METHOD=DIRSERV\r\n')) self.assertEqual('80.89.2.17', self.controller.get_info('address'))
- @patch('stem.control.Controller.msg') - @patch('stem.control.Controller.get_conf') + @patch('stem.control.AsyncController.msg') + @patch('stem.control.AsyncController.get_conf') def test_get_info_without_fingerprint(self, get_conf_mock, msg_mock): - msg_mock.return_value = ControlMessage.from_str('551 Not running in server mode\r\n') + message = ControlMessage.from_str('551 Not running in server mode\r\n') + msg_mock.side_effect = coro_func_returning_value(message) get_conf_mock.return_value = None
- self.assertEqual(None, self.controller._last_fingerprint_exc) + self.assertEqual(None, self.async_controller._last_fingerprint_exc) self.assertRaisesWith(stem.OperationFailed, 'Not running in server mode', self.controller.get_info, 'fingerprint') - self.assertEqual('Not running in server mode', str(self.controller._last_fingerprint_exc)) + self.assertEqual('Not running in server mode', str(self.async_controller._last_fingerprint_exc)) self.assertEqual(1, msg_mock.call_count)
# now that we have a cached failure we should provide that back @@ -114,7 +128,7 @@ class TestControl(unittest.TestCase): self.assertRaisesWith(stem.OperationFailed, 'Not running in server mode', self.controller.get_info, 'fingerprint') self.assertEqual(2, msg_mock.call_count)
- @patch('stem.control.Controller.get_info') + @patch('stem.control.AsyncController.get_info') def test_get_version(self, get_info_mock): """ Exercises the get_version() method. @@ -124,7 +138,7 @@ class TestControl(unittest.TestCase): # Use one version for first check. version_2_1 = '0.2.1.32' version_2_1_object = stem.version.Version(version_2_1) - get_info_mock.return_value = version_2_1 + get_info_mock.side_effect = coro_func_returning_value(version_2_1)
# Return a version with a cold cache. self.assertEqual(version_2_1_object, self.controller.get_version()) @@ -132,23 +146,23 @@ class TestControl(unittest.TestCase): # Use a different version for second check. version_2_2 = '0.2.2.39' version_2_2_object = stem.version.Version(version_2_2) - get_info_mock.return_value = version_2_2 + get_info_mock.side_effect = coro_func_returning_value(version_2_2)
# Return a version with a hot cache, so it will be the old version. self.assertEqual(version_2_1_object, self.controller.get_version())
# Turn off caching. - self.controller._is_caching_enabled = False + self.async_controller._is_caching_enabled = False # Return a version without caching, so it will be the new version. self.assertEqual(version_2_2_object, self.controller.get_version())
# Spec says the getinfo response may optionally be prefixed by 'Tor '. In # practice it doesn't but we should accept that. - get_info_mock.return_value = 'Tor 0.2.1.32' + get_info_mock.side_effect = coro_func_returning_value('Tor 0.2.1.32') self.assertEqual(version_2_1_object, self.controller.get_version())
# Raise an exception in the get_info() call. - get_info_mock.side_effect = InvalidArguments + get_info_mock.side_effect = coro_func_raising_exc(InvalidArguments)
# Get a default value when the call fails. self.assertEqual( @@ -161,22 +175,24 @@ class TestControl(unittest.TestCase):
# Give a bad version. The stem.version.Version ValueError should bubble up. version_A_42 = '0.A.42.spam' - get_info_mock.return_value = version_A_42 - get_info_mock.side_effect = None + get_info_mock.side_effect = coro_func_returning_value(version_A_42) self.assertRaises(ValueError, self.controller.get_version) finally: # Turn caching back on before we leave. self.controller._is_caching_enabled = True
- @patch('stem.control.Controller.get_info') + @patch('stem.control.AsyncController.get_info') def test_get_exit_policy(self, get_info_mock): """ Exercises the get_exit_policy() method. """
- get_info_mock.side_effect = lambda param, default = None: { - 'exit-policy/full': 'reject *:25,reject *:119,reject *:135-139,reject *:445,reject *:563,reject *:1214,reject *:4661-4666,reject *:6346-6429,reject *:6699,reject *:6881-6999,accept *:*', - }[param] + async def get_info_mock_side_effect(param, default = None): + return { + 'exit-policy/full': 'reject *:25,reject *:119,reject *:135-139,reject *:445,reject *:563,reject *:1214,reject *:4661-4666,reject *:6346-6429,reject *:6699,reject *:6881-6999,accept *:*', + }[param] + + get_info_mock.side_effect = get_info_mock_side_effect
expected = ExitPolicy( 'reject *:25', @@ -194,8 +210,8 @@ class TestControl(unittest.TestCase):
self.assertEqual(str(expected), str(self.controller.get_exit_policy()))
- @patch('stem.control.Controller.get_info') - @patch('stem.control.Controller.get_conf') + @patch('stem.control.AsyncController.get_info') + @patch('stem.control.AsyncController.get_conf') def test_get_ports(self, get_conf_mock, get_info_mock): """ Exercises the get_ports() and get_listeners() methods. @@ -204,12 +220,15 @@ class TestControl(unittest.TestCase): # Exercise as an old version of tor that doesn't support the 'GETINFO # net/listeners/*' options.
- get_info_mock.side_effect = InvalidArguments + get_info_mock.side_effect = coro_func_raising_exc(InvalidArguments) + + async def get_conf_mock_side_effect(param, **kwargs): + return { + 'ControlPort': '9050', + 'ControlListenAddress': ['127.0.0.1'], + }[param]
- get_conf_mock.side_effect = lambda param, *args, **kwargs: { - 'ControlPort': '9050', - 'ControlListenAddress': ['127.0.0.1'], - }[param] + get_conf_mock.side_effect = get_conf_mock_side_effect
self.assertEqual([('127.0.0.1', 9050)], self.controller.get_listeners(Listener.CONTROL)) self.assertEqual([9050], self.controller.get_ports(Listener.CONTROL)) @@ -217,10 +236,13 @@ class TestControl(unittest.TestCase):
# non-local addresss
- get_conf_mock.side_effect = lambda param, *args, **kwargs: { - 'ControlPort': '9050', - 'ControlListenAddress': ['27.4.4.1'], - }[param] + async def get_conf_mock_side_effect(param, **kwargs): + return { + 'ControlPort': '9050', + 'ControlListenAddress': ['27.4.4.1'], + }[param] + + get_conf_mock.side_effect = get_conf_mock_side_effect
self.assertEqual([('27.4.4.1', 9050)], self.controller.get_listeners(Listener.CONTROL)) self.assertEqual([], self.controller.get_ports(Listener.CONTROL)) @@ -228,8 +250,8 @@ class TestControl(unittest.TestCase):
# exercise via the GETINFO option
- get_info_mock.side_effect = None - get_info_mock.return_value = '"127.0.0.1:1112" "127.0.0.1:1114"' + listeners = '"127.0.0.1:1112" "127.0.0.1:1114"' + get_info_mock.side_effect = coro_func_returning_value(listeners)
self.assertEqual( [('127.0.0.1', 1112), ('127.0.0.1', 1114)], @@ -241,15 +263,16 @@ class TestControl(unittest.TestCase):
# with all localhost addresses, including a couple that aren't
- get_info_mock.side_effect = None - get_info_mock.return_value = '"27.4.4.1:1113" "127.0.0.5:1114" "0.0.0.0:1115" "[::]:1116" "[::1]:1117" "[10::]:1118"' + listeners = '"27.4.4.1:1113" "127.0.0.5:1114" "0.0.0.0:1115" "[::]:1116" "[::1]:1117" "[10::]:1118"' + get_info_mock.side_effect = coro_func_returning_value(listeners)
self.assertEqual([1114, 1115, 1116, 1117], self.controller.get_ports(Listener.OR)) self.controller.clear_cache()
# IPv6 address
- get_info_mock.return_value = '"0.0.0.0:9001" "[fe80:0000:0000:0000:0202:b3ff:fe1e:8329]:9001"' + listeners = '"0.0.0.0:9001" "[fe80:0000:0000:0000:0202:b3ff:fe1e:8329]:9001"' + get_info_mock.side_effect = coro_func_returning_value(listeners)
self.assertEqual( [('0.0.0.0', 9001), ('fe80:0000:0000:0000:0202:b3ff:fe1e:8329', 9001)], @@ -259,25 +282,28 @@ class TestControl(unittest.TestCase): # unix socket file
self.controller.clear_cache() - get_info_mock.return_value = '"unix:/tmp/tor/socket"' + get_info_mock.side_effect = coro_func_returning_value('"unix:/tmp/tor/socket"')
self.assertEqual([], self.controller.get_listeners(Listener.CONTROL)) self.assertEqual([], self.controller.get_ports(Listener.CONTROL))
- @patch('stem.control.Controller.get_info') + @patch('stem.control.AsyncController.get_info') @patch('time.time', Mock(return_value = 1410723598.276578)) def test_get_accounting_stats(self, get_info_mock): """ Exercises the get_accounting_stats() method. """
- get_info_mock.side_effect = lambda param, **kwargs: { - 'accounting/enabled': '1', - 'accounting/hibernating': 'awake', - 'accounting/interval-end': '2014-09-14 19:41:00', - 'accounting/bytes': '4837 2050', - 'accounting/bytes-left': '102944 7440', - }[param] + async def get_info_mock_side_effect(param, **kwargs): + return { + 'accounting/enabled': '1', + 'accounting/hibernating': 'awake', + 'accounting/interval-end': '2014-09-14 19:41:00', + 'accounting/bytes': '4837 2050', + 'accounting/bytes-left': '102944 7440', + }[param] + + get_info_mock.side_effect = get_info_mock_side_effect
expected = stem.control.AccountingStats( 1410723598.276578, @@ -290,7 +316,7 @@ class TestControl(unittest.TestCase):
self.assertEqual(expected, self.controller.get_accounting_stats())
- get_info_mock.side_effect = ControllerError('nope, too bad') + get_info_mock.side_effect = coro_func_raising_exc(ControllerError('nope, too bad')) self.assertRaises(ControllerError, self.controller.get_accounting_stats) self.assertEqual('my default', self.controller.get_accounting_stats('my default'))
@@ -303,7 +329,7 @@ class TestControl(unittest.TestCase): # use the handy mocked protocolinfo response
protocolinfo_msg = ControlMessage.from_str('250-PROTOCOLINFO 1\r\n250 OK\r\n', 'PROTOCOLINFO') - get_protocolinfo_mock.return_value = protocolinfo_msg + get_protocolinfo_mock.side_effect = coro_func_returning_value(protocolinfo_msg)
# compare the str representation of these object, because the class # does not have, nor need, a direct comparison operator @@ -315,7 +341,7 @@ class TestControl(unittest.TestCase):
# raise an exception in the stem.connection.get_protocolinfo() call
- get_protocolinfo_mock.side_effect = ProtocolError + get_protocolinfo_mock.side_effect = coro_func_raising_exc(ProtocolError)
# get a default value when the call fails
@@ -338,7 +364,7 @@ class TestControl(unittest.TestCase): self.assertEqual(123, self.controller.get_user(123))
@patch('stem.socket.ControlSocket.is_localhost', Mock(return_value = True)) - @patch('stem.control.Controller.get_info', Mock(return_value = 'atagar')) + @patch('stem.control.AsyncController.get_info', Mock(side_effect = coro_func_returning_value('atagar'))) def test_get_user_by_getinfo(self): """ Exercise the get_user() resolution via its getinfo option. @@ -366,7 +392,7 @@ class TestControl(unittest.TestCase): self.assertEqual(123, self.controller.get_pid(123))
@patch('stem.socket.ControlSocket.is_localhost', Mock(return_value = True)) - @patch('stem.control.Controller.get_info', Mock(return_value = '321')) + @patch('stem.control.AsyncController.get_info', Mock(side_effect = coro_func_returning_value('321'))) def test_get_pid_by_getinfo(self): """ Exercise the get_pid() resolution via its getinfo option. @@ -375,14 +401,14 @@ class TestControl(unittest.TestCase): self.assertEqual(321, self.controller.get_pid())
@patch('stem.socket.ControlSocket.is_localhost', Mock(return_value = True)) - @patch('stem.control.Controller.get_conf') + @patch('stem.control.AsyncController.get_conf') @patch('stem.control.open', create = True) def test_get_pid_by_pid_file(self, open_mock, get_conf_mock): """ Exercise the get_pid() resolution via a PidFile. """
- get_conf_mock.return_value = '/tmp/pid_file' + get_conf_mock.side_effect = coro_func_returning_value('/tmp/pid_file') open_mock.return_value = io.BytesIO(b'432')
self.assertEqual(432, self.controller.get_pid()) @@ -397,25 +423,25 @@ class TestControl(unittest.TestCase):
self.assertEqual(432, self.controller.get_pid())
- @patch('stem.control.Controller.get_version', Mock(return_value = stem.version.Version('0.5.0.14'))) + @patch('stem.control.AsyncController.get_version', Mock(side_effect = coro_func_returning_value(stem.version.Version('0.5.0.14')))) @patch('stem.socket.ControlSocket.is_localhost', Mock(return_value = False)) - @patch('stem.control.Controller.get_info') + @patch('stem.control.AsyncController.get_info') @patch('time.time', Mock(return_value = 1000.0)) def test_get_uptime_by_getinfo(self, getinfo_mock): """ Exercise the get_uptime() resolution via a GETINFO query. """
- getinfo_mock.return_value = '321' + getinfo_mock.side_effect = coro_func_returning_value('321') self.assertEqual(321.0, self.controller.get_uptime()) self.controller.clear_cache()
- getinfo_mock.return_value = 'abc' + getinfo_mock.side_effect = coro_func_returning_value('abc') self.assertRaisesWith(ValueError, "'GETINFO uptime' did not provide a valid numeric response: abc", self.controller.get_uptime)
@patch('stem.socket.ControlSocket.is_localhost', Mock(return_value = True)) - @patch('stem.control.Controller.get_version', Mock(return_value = stem.version.Version('0.1.0.14'))) - @patch('stem.control.Controller.get_pid', Mock(return_value = '12')) + @patch('stem.control.AsyncController.get_version', Mock(side_effect = coro_func_returning_value(stem.version.Version('0.1.0.14')))) + @patch('stem.control.AsyncController.get_pid', Mock(side_effect = coro_func_returning_value('12'))) @patch('stem.util.system.start_time', Mock(return_value = 5000.0)) @patch('time.time', Mock(return_value = 5200.0)) def test_get_uptime_by_process(self): @@ -425,7 +451,7 @@ class TestControl(unittest.TestCase):
self.assertEqual(200.0, self.controller.get_uptime())
- @patch('stem.control.Controller.get_info') + @patch('stem.control.AsyncController.get_info') def test_get_network_status_for_ourselves(self, get_info_mock): """ Exercises the get_network_status() method for getting our own relay. @@ -433,7 +459,7 @@ class TestControl(unittest.TestCase):
# when there's an issue getting our fingerprint
- get_info_mock.side_effect = ControllerError('nope, too bad') + get_info_mock.side_effect = coro_func_raising_exc(ControllerError('nope, too bad'))
exc_msg = 'Unable to determine our own fingerprint: nope, too bad' self.assertRaisesWith(ControllerError, exc_msg, self.controller.get_network_status) @@ -443,25 +469,29 @@ class TestControl(unittest.TestCase):
desc = NS_DESC % ('moria1', '/96bKo4soysolMgKn5Hex2nyFSY')
- get_info_mock.side_effect = lambda param, **kwargs: { - 'fingerprint': '9695DFC35FFEB861329B9F1AB04C46397020CE31', - 'ns/id/9695DFC35FFEB861329B9F1AB04C46397020CE31': desc, - }[param] + async def get_info_mock_side_effect(param, **kwargs): + return { + 'fingerprint': '9695DFC35FFEB861329B9F1AB04C46397020CE31', + 'ns/id/9695DFC35FFEB861329B9F1AB04C46397020CE31': desc, + }[param] + + get_info_mock.side_effect = get_info_mock_side_effect
self.assertEqual(stem.descriptor.router_status_entry.RouterStatusEntryV3(desc), self.controller.get_network_status())
- @patch('stem.control.Controller.get_info') + @patch('stem.control.AsyncController.get_info') def test_get_network_status_when_unavailable(self, get_info_mock): """ Exercises the get_network_status() method. """
- get_info_mock.side_effect = InvalidArguments(None, 'GETINFO request contained unrecognized keywords: ns/id/5AC9C5AA75BA1F18D8459B326B4B8111A856D290') + exc = InvalidArguments(None, 'GETINFO request contained unrecognized keywords: ns/id/5AC9C5AA75BA1F18D8459B326B4B8111A856D290') + get_info_mock.side_effect = coro_func_raising_exc(exc)
exc_msg = "Tor was unable to provide the descriptor for '5AC9C5AA75BA1F18D8459B326B4B8111A856D290'" self.assertRaisesWith(DescriptorUnavailable, exc_msg, self.controller.get_network_status, '5AC9C5AA75BA1F18D8459B326B4B8111A856D290')
- @patch('stem.control.Controller.get_info') + @patch('stem.control.AsyncController.get_info') def test_get_network_status(self, get_info_mock): """ Exercises the get_network_status() method. @@ -476,7 +506,7 @@ class TestControl(unittest.TestCase):
# always return the same router status entry
- get_info_mock.return_value = desc + get_info_mock.side_effect = coro_func_returning_value(desc)
# pretend to get the router status entry with its name
@@ -494,7 +524,7 @@ class TestControl(unittest.TestCase):
# raise an exception in the get_info() call
- get_info_mock.side_effect = InvalidArguments + get_info_mock.side_effect = coro_func_raising_exc(InvalidArguments)
# get a default value when the call fails
@@ -507,22 +537,28 @@ class TestControl(unittest.TestCase):
self.assertRaises(InvalidArguments, self.controller.get_network_status, nickname)
- @patch('stem.control.Controller.is_authenticated', Mock(return_value = True)) - @patch('stem.control.Controller._attach_listeners', Mock(return_value = ([], []))) - @patch('stem.control.Controller.get_version') - def test_add_event_listener(self, get_version_mock): + @patch('stem.control.AsyncController.is_authenticated', Mock(return_value = True)) + @patch('stem.control.AsyncController._attach_listeners') + @patch('stem.control.AsyncController.get_version') + def test_add_event_listener(self, get_version_mock, attach_listeners_mock): """ Exercises the add_event_listener and remove_event_listener methods. """
+ attach_listeners_mock.side_effect = coro_func_returning_value(([], [])) + + def set_version(version_str): + version = stem.version.Version(version_str) + get_version_mock.side_effect = coro_func_returning_value(version) + # set up for failure to create any events
- get_version_mock.return_value = stem.version.Version('0.1.0.14') + set_version('0.1.0.14') self.assertRaises(InvalidRequest, self.controller.add_event_listener, Mock(), EventType.BW)
# set up to only fail newer events
- get_version_mock.return_value = stem.version.Version('0.2.0.35') + set_version('0.2.0.35')
# EventType.BW is one of the earliest events
@@ -551,7 +587,7 @@ class TestControl(unittest.TestCase): event thread. """
- self.circ_listener.side_effect = ValueError('boom') + self.circ_listener.side_effect = coro_func_raising_exc(ValueError('boom'))
self._emit_event(CIRC_EVENT) self.circ_listener.assert_called_once_with(CIRC_EVENT) @@ -582,10 +618,10 @@ class TestControl(unittest.TestCase): self._emit_event(BW_EVENT) self.bw_listener.assert_called_once_with(BW_EVENT)
- @patch('stem.control.Controller.get_version', Mock(return_value = stem.version.Version('0.5.0.14'))) - @patch('stem.control.Controller.msg', Mock(return_value = ControlMessage.from_str('250 OK\r\n'))) - @patch('stem.control.Controller.add_event_listener', Mock()) - @patch('stem.control.Controller.remove_event_listener', Mock()) + @patch('stem.control.AsyncController.get_version', Mock(side_effect = coro_func_returning_value(stem.version.Version('0.5.0.14')))) + @patch('stem.control.AsyncController.msg', Mock(side_effect = coro_func_returning_value(ControlMessage.from_str('250 OK\r\n')))) + @patch('stem.control.AsyncController.add_event_listener', Mock(side_effect = coro_func_returning_value(None))) + @patch('stem.control.AsyncController.remove_event_listener', Mock(side_effect = coro_func_returning_value(None))) def test_timeout(self): """ Methods that have an 'await' argument also have an optional timeout. Check @@ -607,8 +643,9 @@ class TestControl(unittest.TestCase): )
response = ''.join(['%s\r\n' % ' '.join(entry) for entry in valid_streams]) + get_info_mock = Mock(side_effect = coro_func_returning_value(response))
- with patch('stem.control.Controller.get_info', Mock(return_value = response)): + with patch('stem.control.AsyncController.get_info', get_info_mock): streams = self.controller.get_streams() self.assertEqual(len(valid_streams), len(streams))
@@ -627,8 +664,9 @@ class TestControl(unittest.TestCase): # instance, it's already open).
response = stem.response.ControlMessage.from_str('555 Connection is not managed by controller.\r\n') + msg_mock = Mock(side_effect = coro_func_returning_value(response))
- with patch('stem.control.Controller.msg', Mock(return_value = response)): + with patch('stem.control.AsyncController.msg', msg_mock): self.assertRaises(UnsatisfiableRequest, self.controller.attach_stream, 'stream_id', 'circ_id')
def test_parse_circ_path(self): @@ -671,7 +709,7 @@ class TestControl(unittest.TestCase): for test_input in malformed_inputs: self.assertRaises(ProtocolError, _parse_circ_path, test_input)
- @patch('stem.control.Controller.get_conf') + @patch('stem.control.AsyncController.get_conf') def test_get_effective_rate(self, get_conf_mock): """ Exercise the get_effective_rate() method. @@ -679,18 +717,21 @@ class TestControl(unittest.TestCase):
# check default if nothing was set
- get_conf_mock.side_effect = lambda param, *args, **kwargs: { - 'BandwidthRate': '1073741824', - 'BandwidthBurst': '1073741824', - 'RelayBandwidthRate': '0', - 'RelayBandwidthBurst': '0', - 'MaxAdvertisedBandwidth': '1073741824', - }[param] + async def get_conf_mock_side_effect(param, **kwargs): + return { + 'BandwidthRate': '1073741824', + 'BandwidthBurst': '1073741824', + 'RelayBandwidthRate': '0', + 'RelayBandwidthBurst': '0', + 'MaxAdvertisedBandwidth': '1073741824', + }[param] + + get_conf_mock.side_effect = get_conf_mock_side_effect
self.assertEqual(1073741824, self.controller.get_effective_rate()) self.assertEqual(1073741824, self.controller.get_effective_rate(burst = True))
- get_conf_mock.side_effect = ControllerError('nope, too bad') + get_conf_mock.side_effect = coro_func_raising_exc(ControllerError('nope, too bad')) self.assertRaises(ControllerError, self.controller.get_effective_rate) self.assertEqual('my_default', self.controller.get_effective_rate('my_default'))
@@ -705,18 +746,19 @@ class TestControl(unittest.TestCase): # with its work is to join on the thread.
with patch('time.time', Mock(return_value = TEST_TIMESTAMP)): - with patch('stem.control.Controller.is_alive') as is_alive_mock: + with patch('stem.control.AsyncController.is_alive') as is_alive_mock: is_alive_mock.return_value = True - self.controller._create_loop_tasks() + loop = self.controller._asyncio_loop + asyncio.run_coroutine_threadsafe(self.async_controller._event_loop(), loop)
try: # Converting an event back into an uncast ControlMessage, then feeding it # into our controller's event queue.
uncast_event = ControlMessage.from_str(event.raw_content()) - self.controller._event_queue.put(uncast_event) - self.controller._event_notice.set() - self.controller._event_queue.join() # block until the event is consumed + event_queue = self.async_controller._event_queue + asyncio.run_coroutine_threadsafe(event_queue.put(uncast_event), loop).result() + asyncio.run_coroutine_threadsafe(event_queue.join(), loop).result() # block until the event is consumed finally: is_alive_mock.return_value = False - self.controller._close() + asyncio.run_coroutine_threadsafe(self.async_controller._close(), loop).result() diff --git a/test/unit/response/control_message.py b/test/unit/response/control_message.py index abf5debf..414dcf63 100644 --- a/test/unit/response/control_message.py +++ b/test/unit/response/control_message.py @@ -126,7 +126,7 @@ class TestControlMessage(unittest.TestCase): # replace the CRLF for the line infonames_lines[index] = line.rstrip('\r\n') + '\n' test_socket_file = io.BytesIO(stem.util.str_tools._to_bytes(''.join(infonames_lines))) - self.assertRaises(stem.ProtocolError, stem.socket.recv_message, test_socket_file) + self.assertRaises(stem.ProtocolError, stem.socket.recv_message_from_bytes_io, test_socket_file)
# puts the CRLF back infonames_lines[index] = infonames_lines[index].rstrip('\n') + '\r\n' @@ -151,8 +151,8 @@ class TestControlMessage(unittest.TestCase): # - this is part of the message prefix # - this is disrupting the line ending
- self.assertRaises(stem.ProtocolError, stem.socket.recv_message, io.BytesIO(stem.util.str_tools._to_bytes(removal_test_input))) - self.assertRaises(stem.ProtocolError, stem.socket.recv_message, io.BytesIO(stem.util.str_tools._to_bytes(replacement_test_input))) + self.assertRaises(stem.ProtocolError, stem.socket.recv_message_from_bytes_io, io.BytesIO(stem.util.str_tools._to_bytes(removal_test_input))) + self.assertRaises(stem.ProtocolError, stem.socket.recv_message_from_bytes_io, io.BytesIO(stem.util.str_tools._to_bytes(replacement_test_input))) else: # otherwise the data will be malformed, but this goes undetected self._assert_message_parses(removal_test_input) @@ -166,7 +166,7 @@ class TestControlMessage(unittest.TestCase):
control_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) control_socket_file = control_socket.makefile() - self.assertRaises(stem.SocketClosed, stem.socket.recv_message, control_socket_file) + self.assertRaises(stem.SocketClosed, stem.socket.recv_message_from_bytes_io, control_socket_file)
def test_equality(self): msg = stem.response.ControlMessage.from_str(EVENT_BW) @@ -200,7 +200,7 @@ class TestControlMessage(unittest.TestCase): stem.response.ControlMessage for the given input """
- message = stem.socket.recv_message(io.BytesIO(stem.util.str_tools._to_bytes(controller_reply))) + message = stem.socket.recv_message_from_bytes_io(io.BytesIO(stem.util.str_tools._to_bytes(controller_reply)))
# checks that the raw_content equals the input value self.assertEqual(controller_reply, message.raw_content())