commit 1db0e6b84e870a5f228f3a770daca542bdef5d4e
Author: Illia Volochii <illia.volochii(a)gmail.com>
Date: Sun Apr 26 22:31:12 2020 +0300
Fix unit tests for the asyncio-based controller and connection APIs
---
test/unit/connection/authentication.py | 36 +++--
test/unit/connection/connect.py | 19 +--
test/unit/control/controller.py | 254 +++++++++++++++++++--------------
test/unit/response/control_message.py | 10 +-
4 files changed, 188 insertions(+), 131 deletions(-)
diff --git a/test/unit/connection/authentication.py b/test/unit/connection/authentication.py
index f6241e0e..596fa50c 100644
--- a/test/unit/connection/authentication.py
+++ b/test/unit/connection/authentication.py
@@ -14,41 +14,52 @@ import unittest
import stem.connection
import test
-from unittest.mock import Mock, patch
+from unittest.mock import patch
from stem.response import ControlMessage
from stem.util import log
+from test.unit.util.asynchronous import (
+ async_test,
+ coro_func_raising_exc,
+ coro_func_returning_value,
+)
class TestAuthenticate(unittest.TestCase):
@patch('stem.connection.get_protocolinfo')
- @patch('stem.connection.authenticate_none', Mock())
- def test_with_get_protocolinfo(self, get_protocolinfo_mock):
+ @patch('stem.connection.authenticate_none')
+ @async_test
+ async def test_with_get_protocolinfo(self, authenticate_none_mock, get_protocolinfo_mock):  # patch mocks are passed innermost decorator first
"""
Tests the authenticate() function when it needs to make a get_protocolinfo.
"""
# tests where get_protocolinfo succeeds
+ authenticate_none_mock.side_effect = coro_func_returning_value(None)  # NOTE(review): helper presumably wraps None in a coroutine function for awaiting callers
+
protocolinfo_message = ControlMessage.from_str('250-PROTOCOLINFO 1\r\n250 OK\r\n', 'PROTOCOLINFO')
protocolinfo_message.auth_methods = (stem.connection.AuthMethod.NONE, )
- get_protocolinfo_mock.return_value = protocolinfo_message
+ get_protocolinfo_mock.side_effect = coro_func_returning_value(protocolinfo_message)
- stem.connection.authenticate(None)
+ await stem.connection.authenticate(None)
# tests where get_protocolinfo raises an exception
get_protocolinfo_mock.side_effect = stem.ProtocolError
- self.assertRaises(stem.connection.IncorrectSocketType, stem.connection.authenticate, None)
+ with self.assertRaises(stem.connection.IncorrectSocketType):  # a bare exception side_effect is raised as soon as the mock is called
+ await stem.connection.authenticate(None)
get_protocolinfo_mock.side_effect = stem.SocketError
- self.assertRaises(stem.connection.AuthenticationFailure, stem.connection.authenticate, None)
+ with self.assertRaises(stem.connection.AuthenticationFailure):
+ await stem.connection.authenticate(None)
@patch('stem.connection.authenticate_none')
@patch('stem.connection.authenticate_password')
@patch('stem.connection.authenticate_cookie')
@patch('stem.connection.authenticate_safecookie')
- def test_all_use_cases(self, authenticate_safecookie_mock, authenticate_cookie_mock, authenticate_password_mock, authenticate_none_mock):
+ @async_test
+ async def test_all_use_cases(self, authenticate_safecookie_mock, authenticate_cookie_mock, authenticate_password_mock, authenticate_none_mock):
"""
Does basic validation that all valid use cases for the PROTOCOLINFO input
and dependent functions result in either success or a AuthenticationFailed
@@ -133,15 +144,16 @@ class TestAuthenticate(unittest.TestCase):
auth_mock, raised_exc = authenticate_safecookie_mock, auth_cookie_exc
if raised_exc:
- auth_mock.side_effect = raised_exc
+ auth_mock.side_effect = coro_func_raising_exc(raised_exc)  # NOTE(review): presumably raises raised_exc once the returned coroutine is awaited
else:
- auth_mock.side_effect = None
+ auth_mock.side_effect = coro_func_returning_value(None)
expect_success = True
if expect_success:
- stem.connection.authenticate(None, 'blah', None, protocolinfo)
+ await stem.connection.authenticate(None, 'blah', None, protocolinfo)
else:
- self.assertRaises(stem.connection.AuthenticationFailure, stem.connection.authenticate, None, 'blah', None, protocolinfo)
+ with self.assertRaises(stem.connection.AuthenticationFailure):
+ await stem.connection.authenticate(None, 'blah', None, protocolinfo)
# revert logging back to normal
stem_logger.setLevel(log.logging_level(log.TRACE))
diff --git a/test/unit/connection/connect.py b/test/unit/connection/connect.py
index 175a1ebd..d2a22f18 100644
--- a/test/unit/connection/connect.py
+++ b/test/unit/connection/connect.py
@@ -11,6 +11,8 @@ import stem.socket
from unittest.mock import Mock, patch
+from test.unit.util.asynchronous import coro_func_raising_exc, coro_func_returning_value
+
class TestConnect(unittest.TestCase):
@patch('sys.stdout', new_callable = io.StringIO)
@@ -85,6 +87,7 @@ class TestConnect(unittest.TestCase):
@patch('stem.connection.authenticate')
def test_auth_success(self, authenticate_mock):
+ authenticate_mock.side_effect = coro_func_returning_value(None)  # NOTE(review): assumes _connect_auth awaits authenticate(), hence a coroutine-returning side_effect
control_socket = Mock()
stem.connection._connect_auth(control_socket, None, False, None, None)
@@ -99,7 +102,7 @@ class TestConnect(unittest.TestCase):
def test_auth_success_with_password_prompt(self, authenticate_mock, getpass_mock):
control_socket = Mock()
- def authenticate_mock_func(controller, password, *args):
+ async def authenticate_mock_func(controller, password, *args):  # async def: exceptions raised here surface when the call is awaited
if password is None:
raise stem.connection.MissingPassword('no password')
elif password == 'my_password':
@@ -117,25 +120,25 @@ class TestConnect(unittest.TestCase):
@patch('sys.stdout', new_callable = io.StringIO)
@patch('stem.connection.authenticate')
def test_auth_failure(self, authenticate_mock, stdout_mock):
- control_socket = stem.socket.ControlPort(connect = False)
+ control_socket = stem.socket.ControlPort()  # NOTE(review): connect = False was dropped; assumes ControlPort() no longer connects on construction - confirm
- authenticate_mock.side_effect = stem.connection.IncorrectSocketType('unable to connect to socket')
+ authenticate_mock.side_effect = coro_func_raising_exc(stem.connection.IncorrectSocketType('unable to connect to socket'))
self._assert_authenticate_fails_with(control_socket, stdout_mock, 'Please check in your torrc that 9051 is the ControlPort.')
- control_socket = stem.socket.ControlSocketFile(connect = False)
+ control_socket = stem.socket.ControlSocketFile()  # NOTE(review): assumes ControlSocketFile() does not connect on construction - confirm
self._assert_authenticate_fails_with(control_socket, stdout_mock, 'Are you sure the interface you specified belongs to')
- authenticate_mock.side_effect = stem.connection.UnrecognizedAuthMethods('unable to connect', ['telepathy'])
+ authenticate_mock.side_effect = coro_func_raising_exc(stem.connection.UnrecognizedAuthMethods('unable to connect', ['telepathy']))
self._assert_authenticate_fails_with(control_socket, stdout_mock, 'Tor is using a type of authentication we do not recognize...\n\n telepathy')
- authenticate_mock.side_effect = stem.connection.IncorrectPassword('password rejected')
+ authenticate_mock.side_effect = coro_func_raising_exc(stem.connection.IncorrectPassword('password rejected'))
self._assert_authenticate_fails_with(control_socket, stdout_mock, 'Incorrect password')
- authenticate_mock.side_effect = stem.connection.UnreadableCookieFile('permission denied', '/tmp/my_cookie', False)
+ authenticate_mock.side_effect = coro_func_raising_exc(stem.connection.UnreadableCookieFile('permission denied', '/tmp/my_cookie', False))
self._assert_authenticate_fails_with(control_socket, stdout_mock, "We were unable to read tor's authentication cookie...\n\n Path: /tmp/my_cookie\n Issue: permission denied")
- authenticate_mock.side_effect = stem.connection.OpenAuthRejected('crazy failure')
+ authenticate_mock.side_effect = coro_func_raising_exc(stem.connection.OpenAuthRejected('crazy failure'))
self._assert_authenticate_fails_with(control_socket, stdout_mock, 'Unable to authenticate: crazy failure')
def _assert_authenticate_fails_with(self, control_socket, stdout_mock, msg):
diff --git a/test/unit/control/controller.py b/test/unit/control/controller.py
index c0a07e2a..d09b5ca8 100644
--- a/test/unit/control/controller.py
+++ b/test/unit/control/controller.py
@@ -3,6 +3,7 @@ Unit tests for the stem.control module. The module's primarily exercised via
integ tests, but a few bits lend themselves to unit testing.
"""
+import asyncio
import datetime
import io
import unittest
@@ -20,6 +21,11 @@ from stem import ControllerError, DescriptorUnavailable, InvalidArguments, Inval
from stem.control import MALFORMED_EVENTS, _parse_circ_path, Listener, Controller, EventType
from stem.response import ControlMessage
from stem.exit_policy import ExitPolicy
+from test.unit.util.asynchronous import (
+ async_test,
+ coro_func_raising_exc,
+ coro_func_returning_value,
+)
NS_DESC = 'r %s %s u5lTXJKGsLKufRLnSyVqT7TdGYw 2012-12-30 22:02:49 77.223.43.54 9001 0\ns Fast Named Running Stable Valid\nw Bandwidth=75'
TEST_TIMESTAMP = 12345
@@ -36,8 +42,9 @@ class TestControl(unittest.TestCase):
# When initially constructing a controller we need to suppress msg, so our
# constructor's SETEVENTS requests pass.
- with patch('stem.control.BaseController.msg', Mock()):
+ with patch('stem.control.BaseController.msg', Mock(side_effect = coro_func_returning_value(None))):  # NOTE(review): presumably msg is awaited during construction, so the stand-in returns a coroutine
self.controller = Controller(socket)
+ self.async_controller = self.controller._async_controller  # private state asserted below (e.g. _last_address_exc) is read from this object
self.circ_listener = Mock()
self.controller.add_event_listener(self.circ_listener, EventType.CIRC)
@@ -59,18 +66,24 @@ class TestControl(unittest.TestCase):
for event in stem.control.EventType:
self.assertTrue(stem.control.event_description(event) is not None)
- @patch('stem.control.Controller.msg')
+ @patch('stem.control.AsyncController.msg')
def test_get_info(self, msg_mock):
- msg_mock.return_value = ControlMessage.from_str('250-hello=hi right back!\r\n250 OK\r\n', 'GETINFO')
+ message = ControlMessage.from_str('250-hello=hi right back!\r\n250 OK\r\n', 'GETINFO')
+ msg_mock.side_effect = coro_func_returning_value(message)
self.assertEqual('hi right back!', self.controller.get_info('hello'))
- @patch('stem.control.Controller.msg')
- def test_get_info_address_caching(self, msg_mock):
- msg_mock.return_value = ControlMessage.from_str('551 Address unknown\r\n')
+ @patch('stem.control.AsyncController.msg')
+ @async_test
+ async def test_get_info_address_caching(self, msg_mock):
+ def set_message(*args):  # swap the canned ControlMessage that the mocked msg() yields
+ message = ControlMessage.from_str(*args)
+ msg_mock.side_effect = coro_func_returning_value(message)
- self.assertEqual(None, self.controller._last_address_exc)
+ set_message('551 Address unknown\r\n')
+
+ self.assertEqual(None, self.async_controller._last_address_exc)
self.assertRaisesWith(stem.OperationFailed, 'Address unknown', self.controller.get_info, 'address')
- self.assertEqual('Address unknown', str(self.controller._last_address_exc))
+ self.assertEqual('Address unknown', str(self.async_controller._last_address_exc))
self.assertEqual(1, msg_mock.call_count)
# now that we have a cached failure we should provide that back
@@ -80,27 +93,28 @@ class TestControl(unittest.TestCase):
# invalidates the cache, transitioning from no address to having one
- msg_mock.return_value = ControlMessage.from_str('250-address=17.2.89.80\r\n250 OK\r\n', 'GETINFO')
+ set_message('250-address=17.2.89.80\r\n250 OK\r\n', 'GETINFO')
self.assertRaisesWith(stem.OperationFailed, 'Address unknown', self.controller.get_info, 'address')
- self.controller._handle_event(ControlMessage.from_str('650 STATUS_SERVER NOTICE EXTERNAL_ADDRESS ADDRESS=17.2.89.80 METHOD=DIRSERV\r\n'))
+ await self.async_controller._handle_event(ControlMessage.from_str('650 STATUS_SERVER NOTICE EXTERNAL_ADDRESS ADDRESS=17.2.89.80 METHOD=DIRSERV\r\n'))
self.assertEqual('17.2.89.80', self.controller.get_info('address'))
# invalidates the cache, transitioning from one address to another
- msg_mock.return_value = ControlMessage.from_str('250-address=80.89.2.17\r\n250 OK\r\n', 'GETINFO')
+ set_message('250-address=80.89.2.17\r\n250 OK\r\n', 'GETINFO')
self.assertEqual('17.2.89.80', self.controller.get_info('address'))
- self.controller._handle_event(ControlMessage.from_str('650 STATUS_SERVER NOTICE EXTERNAL_ADDRESS ADDRESS=80.89.2.17 METHOD=DIRSERV\r\n'))
+ await self.async_controller._handle_event(ControlMessage.from_str('650 STATUS_SERVER NOTICE EXTERNAL_ADDRESS ADDRESS=80.89.2.17 METHOD=DIRSERV\r\n'))
self.assertEqual('80.89.2.17', self.controller.get_info('address'))
- @patch('stem.control.Controller.msg')
- @patch('stem.control.Controller.get_conf')
+ @patch('stem.control.AsyncController.msg')
+ @patch('stem.control.AsyncController.get_conf')
def test_get_info_without_fingerprint(self, get_conf_mock, msg_mock):
- msg_mock.return_value = ControlMessage.from_str('551 Not running in server mode\r\n')
+ message = ControlMessage.from_str('551 Not running in server mode\r\n')
+ msg_mock.side_effect = coro_func_returning_value(message)  # NOTE(review): get_conf_mock below still uses a plain return_value; confirm get_info never awaits get_conf on this path
get_conf_mock.return_value = None
- self.assertEqual(None, self.controller._last_fingerprint_exc)
+ self.assertEqual(None, self.async_controller._last_fingerprint_exc)
self.assertRaisesWith(stem.OperationFailed, 'Not running in server mode', self.controller.get_info, 'fingerprint')
- self.assertEqual('Not running in server mode', str(self.controller._last_fingerprint_exc))
+ self.assertEqual('Not running in server mode', str(self.async_controller._last_fingerprint_exc))
self.assertEqual(1, msg_mock.call_count)
# now that we have a cached failure we should provide that back
@@ -114,7 +128,7 @@ class TestControl(unittest.TestCase):
self.assertRaisesWith(stem.OperationFailed, 'Not running in server mode', self.controller.get_info, 'fingerprint')
self.assertEqual(2, msg_mock.call_count)
- @patch('stem.control.Controller.get_info')
+ @patch('stem.control.AsyncController.get_info')
def test_get_version(self, get_info_mock):
"""
Exercises the get_version() method.
@@ -124,7 +138,7 @@ class TestControl(unittest.TestCase):
# Use one version for first check.
version_2_1 = '0.2.1.32'
version_2_1_object = stem.version.Version(version_2_1)
- get_info_mock.return_value = version_2_1
+ get_info_mock.side_effect = coro_func_returning_value(version_2_1)
# Return a version with a cold cache.
self.assertEqual(version_2_1_object, self.controller.get_version())
@@ -132,23 +146,23 @@ class TestControl(unittest.TestCase):
# Use a different version for second check.
version_2_2 = '0.2.2.39'
version_2_2_object = stem.version.Version(version_2_2)
- get_info_mock.return_value = version_2_2
+ get_info_mock.side_effect = coro_func_returning_value(version_2_2)
# Return a version with a hot cache, so it will be the old version.
self.assertEqual(version_2_1_object, self.controller.get_version())
# Turn off caching.
- self.controller._is_caching_enabled = False
+ self.async_controller._is_caching_enabled = False  # caching is toggled on the async controller
# Return a version without caching, so it will be the new version.
self.assertEqual(version_2_2_object, self.controller.get_version())
# Spec says the getinfo response may optionally be prefixed by 'Tor '. In
# practice it doesn't but we should accept that.
- get_info_mock.return_value = 'Tor 0.2.1.32'
+ get_info_mock.side_effect = coro_func_returning_value('Tor 0.2.1.32')
self.assertEqual(version_2_1_object, self.controller.get_version())
# Raise an exception in the get_info() call.
- get_info_mock.side_effect = InvalidArguments
+ get_info_mock.side_effect = coro_func_raising_exc(InvalidArguments)
# Get a default value when the call fails.
self.assertEqual(
@@ -161,22 +175,24 @@ class TestControl(unittest.TestCase):
# Give a bad version. The stem.version.Version ValueError should bubble up.
version_A_42 = '0.A.42.spam'
- get_info_mock.return_value = version_A_42
- get_info_mock.side_effect = None
+ get_info_mock.side_effect = coro_func_returning_value(version_A_42)
self.assertRaises(ValueError, self.controller.get_version)
finally:
# Turn caching back on before we leave.
- self.controller._is_caching_enabled = True
+ self.async_controller._is_caching_enabled = True  # restore the same flag that was disabled above
- @patch('stem.control.Controller.get_info')
+ @patch('stem.control.AsyncController.get_info')
def test_get_exit_policy(self, get_info_mock):
"""
Exercises the get_exit_policy() method.
"""
- get_info_mock.side_effect = lambda param, default = None: {
- 'exit-policy/full': 'reject *:25,reject *:119,reject *:135-139,reject *:445,reject *:563,reject *:1214,reject *:4661-4666,reject *:6346-6429,reject *:6699,reject *:6881-6999,accept *:*',
- }[param]
+ async def get_info_mock_side_effect(param, default = None):  # only 'exit-policy/full' is answered; any other key raises KeyError
+ return {
+ 'exit-policy/full': 'reject *:25,reject *:119,reject *:135-139,reject *:445,reject *:563,reject *:1214,reject *:4661-4666,reject *:6346-6429,reject *:6699,reject *:6881-6999,accept *:*',
+ }[param]
+
+ get_info_mock.side_effect = get_info_mock_side_effect
expected = ExitPolicy(
'reject *:25',
@@ -194,8 +210,8 @@ class TestControl(unittest.TestCase):
self.assertEqual(str(expected), str(self.controller.get_exit_policy()))
- @patch('stem.control.Controller.get_info')
- @patch('stem.control.Controller.get_conf')
+ @patch('stem.control.AsyncController.get_info')
+ @patch('stem.control.AsyncController.get_conf')
def test_get_ports(self, get_conf_mock, get_info_mock):
"""
Exercises the get_ports() and get_listeners() methods.
@@ -204,12 +220,15 @@ class TestControl(unittest.TestCase):
# Exercise as an old version of tor that doesn't support the 'GETINFO
# net/listeners/*' options.
- get_info_mock.side_effect = InvalidArguments
+ get_info_mock.side_effect = coro_func_raising_exc(InvalidArguments)
+
+ async def get_conf_mock_side_effect(param, *args, **kwargs):  # *args kept from the replaced lambda so positional extras (e.g. a default) are tolerated
+ return {
+ 'ControlPort': '9050',
+ 'ControlListenAddress': ['127.0.0.1'],
+ }[param]
- get_conf_mock.side_effect = lambda param, *args, **kwargs: {
- 'ControlPort': '9050',
- 'ControlListenAddress': ['127.0.0.1'],
- }[param]
+ get_conf_mock.side_effect = get_conf_mock_side_effect
self.assertEqual([('127.0.0.1', 9050)], self.controller.get_listeners(Listener.CONTROL))
self.assertEqual([9050], self.controller.get_ports(Listener.CONTROL))
@@ -217,10 +236,13 @@ class TestControl(unittest.TestCase):
# non-local addresss
- get_conf_mock.side_effect = lambda param, *args, **kwargs: {
- 'ControlPort': '9050',
- 'ControlListenAddress': ['27.4.4.1'],
- }[param]
+ async def get_conf_mock_side_effect(param, *args, **kwargs):  # *args kept for parity with the replaced lambda
+ return {
+ 'ControlPort': '9050',
+ 'ControlListenAddress': ['27.4.4.1'],
+ }[param]
+
+ get_conf_mock.side_effect = get_conf_mock_side_effect
self.assertEqual([('27.4.4.1', 9050)], self.controller.get_listeners(Listener.CONTROL))
self.assertEqual([], self.controller.get_ports(Listener.CONTROL))
@@ -228,8 +250,8 @@ class TestControl(unittest.TestCase):
# exercise via the GETINFO option
- get_info_mock.side_effect = None
- get_info_mock.return_value = '"127.0.0.1:1112" "127.0.0.1:1114"'
+ listeners = '"127.0.0.1:1112" "127.0.0.1:1114"'
+ get_info_mock.side_effect = coro_func_returning_value(listeners)
self.assertEqual(
[('127.0.0.1', 1112), ('127.0.0.1', 1114)],
@@ -241,15 +263,16 @@ class TestControl(unittest.TestCase):
# with all localhost addresses, including a couple that aren't
- get_info_mock.side_effect = None
- get_info_mock.return_value = '"27.4.4.1:1113" "127.0.0.5:1114" "0.0.0.0:1115" "[::]:1116" "[::1]:1117" "[10::]:1118"'
+ listeners = '"27.4.4.1:1113" "127.0.0.5:1114" "0.0.0.0:1115" "[::]:1116" "[::1]:1117" "[10::]:1118"'
+ get_info_mock.side_effect = coro_func_returning_value(listeners)
self.assertEqual([1114, 1115, 1116, 1117], self.controller.get_ports(Listener.OR))
self.controller.clear_cache()
# IPv6 address
- get_info_mock.return_value = '"0.0.0.0:9001" "[fe80:0000:0000:0000:0202:b3ff:fe1e:8329]:9001"'
+ listeners = '"0.0.0.0:9001" "[fe80:0000:0000:0000:0202:b3ff:fe1e:8329]:9001"'
+ get_info_mock.side_effect = coro_func_returning_value(listeners)
self.assertEqual(
[('0.0.0.0', 9001), ('fe80:0000:0000:0000:0202:b3ff:fe1e:8329', 9001)],
@@ -259,25 +282,28 @@ class TestControl(unittest.TestCase):
# unix socket file
self.controller.clear_cache()
- get_info_mock.return_value = '"unix:/tmp/tor/socket"'
+ get_info_mock.side_effect = coro_func_returning_value('"unix:/tmp/tor/socket"')
self.assertEqual([], self.controller.get_listeners(Listener.CONTROL))
self.assertEqual([], self.controller.get_ports(Listener.CONTROL))
- @patch('stem.control.Controller.get_info')
+ @patch('stem.control.AsyncController.get_info')
@patch('time.time', Mock(return_value = 1410723598.276578))
def test_get_accounting_stats(self, get_info_mock):
"""
Exercises the get_accounting_stats() method.
"""
- get_info_mock.side_effect = lambda param, **kwargs: {
- 'accounting/enabled': '1',
- 'accounting/hibernating': 'awake',
- 'accounting/interval-end': '2014-09-14 19:41:00',
- 'accounting/bytes': '4837 2050',
- 'accounting/bytes-left': '102944 7440',
- }[param]
+ async def get_info_mock_side_effect(param, **kwargs):  # answers only the accounting/* keys below; anything else raises KeyError
+ return {
+ 'accounting/enabled': '1',
+ 'accounting/hibernating': 'awake',
+ 'accounting/interval-end': '2014-09-14 19:41:00',
+ 'accounting/bytes': '4837 2050',
+ 'accounting/bytes-left': '102944 7440',
+ }[param]
+
+ get_info_mock.side_effect = get_info_mock_side_effect
expected = stem.control.AccountingStats(
1410723598.276578,
@@ -290,7 +316,7 @@ class TestControl(unittest.TestCase):
self.assertEqual(expected, self.controller.get_accounting_stats())
- get_info_mock.side_effect = ControllerError('nope, too bad')
+ get_info_mock.side_effect = coro_func_raising_exc(ControllerError('nope, too bad'))  # NOTE(review): presumably raises when the returned coroutine is awaited
self.assertRaises(ControllerError, self.controller.get_accounting_stats)
self.assertEqual('my default', self.controller.get_accounting_stats('my default'))
@@ -303,7 +329,7 @@ class TestControl(unittest.TestCase):
# use the handy mocked protocolinfo response
protocolinfo_msg = ControlMessage.from_str('250-PROTOCOLINFO 1\r\n250 OK\r\n', 'PROTOCOLINFO')
- get_protocolinfo_mock.return_value = protocolinfo_msg
+ get_protocolinfo_mock.side_effect = coro_func_returning_value(protocolinfo_msg)  # NOTE(review): assumes get_protocolinfo is awaited by the code under test
# compare the str representation of these object, because the class
# does not have, nor need, a direct comparison operator
@@ -315,7 +341,7 @@ class TestControl(unittest.TestCase):
# raise an exception in the stem.connection.get_protocolinfo() call
- get_protocolinfo_mock.side_effect = ProtocolError
+ get_protocolinfo_mock.side_effect = coro_func_raising_exc(ProtocolError)  # the ProtocolError class is passed; the helper presumably raises it when awaited
# get a default value when the call fails
@@ -338,7 +364,7 @@ class TestControl(unittest.TestCase):
self.assertEqual(123, self.controller.get_user(123))
@patch('stem.socket.ControlSocket.is_localhost', Mock(return_value = True))
- @patch('stem.control.Controller.get_info', Mock(return_value = 'atagar'))
+ @patch('stem.control.AsyncController.get_info', Mock(side_effect = coro_func_returning_value('atagar')))
def test_get_user_by_getinfo(self):
"""
Exercise the get_user() resolution via its getinfo option.
@@ -366,7 +392,7 @@ class TestControl(unittest.TestCase):
self.assertEqual(123, self.controller.get_pid(123))
@patch('stem.socket.ControlSocket.is_localhost', Mock(return_value = True))
- @patch('stem.control.Controller.get_info', Mock(return_value = '321'))
+ @patch('stem.control.AsyncController.get_info', Mock(side_effect = coro_func_returning_value('321')))
def test_get_pid_by_getinfo(self):
"""
Exercise the get_pid() resolution via its getinfo option.
@@ -375,14 +401,14 @@ class TestControl(unittest.TestCase):
self.assertEqual(321, self.controller.get_pid())
@patch('stem.socket.ControlSocket.is_localhost', Mock(return_value = True))
- @patch('stem.control.Controller.get_conf')
+ @patch('stem.control.AsyncController.get_conf')
@patch('stem.control.open', create = True)
def test_get_pid_by_pid_file(self, open_mock, get_conf_mock):
"""
Exercise the get_pid() resolution via a PidFile.
"""
- get_conf_mock.return_value = '/tmp/pid_file'
+ get_conf_mock.side_effect = coro_func_returning_value('/tmp/pid_file')  # PidFile path reported by the mocked get_conf
open_mock.return_value = io.BytesIO(b'432')
self.assertEqual(432, self.controller.get_pid())
@@ -397,25 +423,25 @@ class TestControl(unittest.TestCase):
self.assertEqual(432, self.controller.get_pid())
- @patch('stem.control.Controller.get_version', Mock(return_value = stem.version.Version('0.5.0.14')))
+ @patch('stem.control.AsyncController.get_version', Mock(side_effect = coro_func_returning_value(stem.version.Version('0.5.0.14'))))
@patch('stem.socket.ControlSocket.is_localhost', Mock(return_value = False))
- @patch('stem.control.Controller.get_info')
+ @patch('stem.control.AsyncController.get_info')
@patch('time.time', Mock(return_value = 1000.0))
def test_get_uptime_by_getinfo(self, getinfo_mock):
"""
Exercise the get_uptime() resolution via a GETINFO query.
"""
- getinfo_mock.return_value = '321'
+ getinfo_mock.side_effect = coro_func_returning_value('321')  # uptime reported by the mocked GETINFO
self.assertEqual(321.0, self.controller.get_uptime())
self.controller.clear_cache()
- getinfo_mock.return_value = 'abc'
+ getinfo_mock.side_effect = coro_func_returning_value('abc')  # non-numeric reply, expected to raise ValueError below
self.assertRaisesWith(ValueError, "'GETINFO uptime' did not provide a valid numeric response: abc", self.controller.get_uptime)
@patch('stem.socket.ControlSocket.is_localhost', Mock(return_value = True))
- @patch('stem.control.Controller.get_version', Mock(return_value = stem.version.Version('0.1.0.14')))
- @patch('stem.control.Controller.get_pid', Mock(return_value = '12'))
+ @patch('stem.control.AsyncController.get_version', Mock(side_effect = coro_func_returning_value(stem.version.Version('0.1.0.14'))))
+ @patch('stem.control.AsyncController.get_pid', Mock(side_effect = coro_func_returning_value('12')))
@patch('stem.util.system.start_time', Mock(return_value = 5000.0))
@patch('time.time', Mock(return_value = 5200.0))
def test_get_uptime_by_process(self):
@@ -425,7 +451,7 @@ class TestControl(unittest.TestCase):
self.assertEqual(200.0, self.controller.get_uptime())
- @patch('stem.control.Controller.get_info')
+ @patch('stem.control.AsyncController.get_info')
def test_get_network_status_for_ourselves(self, get_info_mock):
"""
Exercises the get_network_status() method for getting our own relay.
@@ -433,7 +459,7 @@ class TestControl(unittest.TestCase):
# when there's an issue getting our fingerprint
- get_info_mock.side_effect = ControllerError('nope, too bad')
+ get_info_mock.side_effect = coro_func_raising_exc(ControllerError('nope, too bad'))  # failure while resolving our own fingerprint
exc_msg = 'Unable to determine our own fingerprint: nope, too bad'
self.assertRaisesWith(ControllerError, exc_msg, self.controller.get_network_status)
@@ -443,25 +469,29 @@ class TestControl(unittest.TestCase):
desc = NS_DESC % ('moria1', '/96bKo4soysolMgKn5Hex2nyFSY')
- get_info_mock.side_effect = lambda param, **kwargs: {
- 'fingerprint': '9695DFC35FFEB861329B9F1AB04C46397020CE31',
- 'ns/id/9695DFC35FFEB861329B9F1AB04C46397020CE31': desc,
- }[param]
+ async def get_info_mock_side_effect(param, **kwargs):  # answers the fingerprint lookup and the matching ns/id/ query
+ return {
+ 'fingerprint': '9695DFC35FFEB861329B9F1AB04C46397020CE31',
+ 'ns/id/9695DFC35FFEB861329B9F1AB04C46397020CE31': desc,
+ }[param]
+
+ get_info_mock.side_effect = get_info_mock_side_effect
self.assertEqual(stem.descriptor.router_status_entry.RouterStatusEntryV3(desc), self.controller.get_network_status())
- @patch('stem.control.Controller.get_info')
+ @patch('stem.control.AsyncController.get_info')
def test_get_network_status_when_unavailable(self, get_info_mock):
"""
Exercises the get_network_status() method.
"""
- get_info_mock.side_effect = InvalidArguments(None, 'GETINFO request contained unrecognized keywords: ns/id/5AC9C5AA75BA1F18D8459B326B4B8111A856D290')
+ exc = InvalidArguments(None, 'GETINFO request contained unrecognized keywords: ns/id/5AC9C5AA75BA1F18D8459B326B4B8111A856D290')
+ get_info_mock.side_effect = coro_func_raising_exc(exc)  # NOTE(review): presumably raises exc when the returned coroutine is awaited
exc_msg = "Tor was unable to provide the descriptor for '5AC9C5AA75BA1F18D8459B326B4B8111A856D290'"
self.assertRaisesWith(DescriptorUnavailable, exc_msg, self.controller.get_network_status, '5AC9C5AA75BA1F18D8459B326B4B8111A856D290')
- @patch('stem.control.Controller.get_info')
+ @patch('stem.control.AsyncController.get_info')
def test_get_network_status(self, get_info_mock):
"""
Exercises the get_network_status() method.
@@ -476,7 +506,7 @@ class TestControl(unittest.TestCase):
# always return the same router status entry
- get_info_mock.return_value = desc
+ get_info_mock.side_effect = coro_func_returning_value(desc)  # same entry for every lookup
# pretend to get the router status entry with its name
@@ -494,7 +524,7 @@ class TestControl(unittest.TestCase):
# raise an exception in the get_info() call
- get_info_mock.side_effect = InvalidArguments
+ get_info_mock.side_effect = coro_func_raising_exc(InvalidArguments)
# get a default value when the call fails
@@ -507,22 +537,28 @@ class TestControl(unittest.TestCase):
self.assertRaises(InvalidArguments, self.controller.get_network_status, nickname)
- @patch('stem.control.Controller.is_authenticated', Mock(return_value = True))
- @patch('stem.control.Controller._attach_listeners', Mock(return_value = ([], [])))
- @patch('stem.control.Controller.get_version')
- def test_add_event_listener(self, get_version_mock):
+ @patch('stem.control.AsyncController.is_authenticated', Mock(return_value = True))
+ @patch('stem.control.AsyncController._attach_listeners')
+ @patch('stem.control.AsyncController.get_version')
+ def test_add_event_listener(self, get_version_mock, attach_listeners_mock):
"""
Exercises the add_event_listener and remove_event_listener methods.
"""
+ attach_listeners_mock.side_effect = coro_func_returning_value(([], []))  # NOTE(review): presumably awaited by add_event_listener; yields ([], [])
+
+ def set_version(version_str):  # change the version the mocked get_version() yields
+ version = stem.version.Version(version_str)
+ get_version_mock.side_effect = coro_func_returning_value(version)
+
# set up for failure to create any events
- get_version_mock.return_value = stem.version.Version('0.1.0.14')
+ set_version('0.1.0.14')
self.assertRaises(InvalidRequest, self.controller.add_event_listener, Mock(), EventType.BW)
# set up to only fail newer events
- get_version_mock.return_value = stem.version.Version('0.2.0.35')
+ set_version('0.2.0.35')
# EventType.BW is one of the earliest events
@@ -551,7 +587,7 @@ class TestControl(unittest.TestCase):
event thread.
"""
- self.circ_listener.side_effect = ValueError('boom')
+ self.circ_listener.side_effect = coro_func_raising_exc(ValueError('boom'))  # NOTE(review): listener now returns a coroutine; ValueError only fires if the event loop awaits listener results - confirm
self._emit_event(CIRC_EVENT)
self.circ_listener.assert_called_once_with(CIRC_EVENT)
@@ -582,10 +618,10 @@ class TestControl(unittest.TestCase):
self._emit_event(BW_EVENT)
self.bw_listener.assert_called_once_with(BW_EVENT)
- @patch('stem.control.Controller.get_version', Mock(return_value = stem.version.Version('0.5.0.14')))
- @patch('stem.control.Controller.msg', Mock(return_value = ControlMessage.from_str('250 OK\r\n')))
- @patch('stem.control.Controller.add_event_listener', Mock())
- @patch('stem.control.Controller.remove_event_listener', Mock())
+ @patch('stem.control.AsyncController.get_version', Mock(side_effect = coro_func_returning_value(stem.version.Version('0.5.0.14'))))
+ @patch('stem.control.AsyncController.msg', Mock(side_effect = coro_func_returning_value(ControlMessage.from_str('250 OK\r\n'))))
+ @patch('stem.control.AsyncController.add_event_listener', Mock(side_effect = coro_func_returning_value(None)))
+ @patch('stem.control.AsyncController.remove_event_listener', Mock(side_effect = coro_func_returning_value(None)))
def test_timeout(self):
"""
Methods that have an 'await' argument also have an optional timeout. Check
@@ -607,8 +643,9 @@ class TestControl(unittest.TestCase):
)
response = ''.join(['%s\r\n' % ' '.join(entry) for entry in valid_streams])
+ get_info_mock = Mock(side_effect = coro_func_returning_value(response))  # built up front and handed to patch() below
- with patch('stem.control.Controller.get_info', Mock(return_value = response)):
+ with patch('stem.control.AsyncController.get_info', get_info_mock):
streams = self.controller.get_streams()
self.assertEqual(len(valid_streams), len(streams))
@@ -627,8 +664,9 @@ class TestControl(unittest.TestCase):
# instance, it's already open).
response = stem.response.ControlMessage.from_str('555 Connection is not managed by controller.\r\n')
+ msg_mock = Mock(side_effect = coro_func_returning_value(response))  # built up front and handed to patch() below
- with patch('stem.control.Controller.msg', Mock(return_value = response)):
+ with patch('stem.control.AsyncController.msg', msg_mock):
self.assertRaises(UnsatisfiableRequest, self.controller.attach_stream, 'stream_id', 'circ_id')
def test_parse_circ_path(self):
@@ -671,7 +709,7 @@ class TestControl(unittest.TestCase):
for test_input in malformed_inputs:
self.assertRaises(ProtocolError, _parse_circ_path, test_input)
- @patch('stem.control.Controller.get_conf')
+ @patch('stem.control.AsyncController.get_conf')
def test_get_effective_rate(self, get_conf_mock):
"""
Exercise the get_effective_rate() method.
@@ -679,18 +717,21 @@ class TestControl(unittest.TestCase):
# check default if nothing was set
- get_conf_mock.side_effect = lambda param, *args, **kwargs: {
- 'BandwidthRate': '1073741824',
- 'BandwidthBurst': '1073741824',
- 'RelayBandwidthRate': '0',
- 'RelayBandwidthBurst': '0',
- 'MaxAdvertisedBandwidth': '1073741824',
- }[param]
+ async def get_conf_mock_side_effect(param, *args, **kwargs):  # *args kept from the replaced lambda so positional extras (e.g. a default) are tolerated
+ return {
+ 'BandwidthRate': '1073741824',
+ 'BandwidthBurst': '1073741824',
+ 'RelayBandwidthRate': '0',
+ 'RelayBandwidthBurst': '0',
+ 'MaxAdvertisedBandwidth': '1073741824',
+ }[param]
+
+ get_conf_mock.side_effect = get_conf_mock_side_effect
self.assertEqual(1073741824, self.controller.get_effective_rate())
self.assertEqual(1073741824, self.controller.get_effective_rate(burst = True))
- get_conf_mock.side_effect = ControllerError('nope, too bad')
+ get_conf_mock.side_effect = coro_func_raising_exc(ControllerError('nope, too bad'))
self.assertRaises(ControllerError, self.controller.get_effective_rate)
self.assertEqual('my_default', self.controller.get_effective_rate('my_default'))
@@ -705,18 +746,19 @@ class TestControl(unittest.TestCase):
# with its work is to join on the thread.
with patch('time.time', Mock(return_value = TEST_TIMESTAMP)):
- with patch('stem.control.Controller.is_alive') as is_alive_mock:
+ with patch('stem.control.AsyncController.is_alive') as is_alive_mock:
is_alive_mock.return_value = True
- self.controller._create_loop_tasks()
+ loop = self.controller._asyncio_loop
+ asyncio.run_coroutine_threadsafe(self.async_controller._event_loop(), loop)
try:
# Converting an event back into an uncast ControlMessage, then feeding it
# into our controller's event queue.
uncast_event = ControlMessage.from_str(event.raw_content())
- self.controller._event_queue.put(uncast_event)
- self.controller._event_notice.set()
- self.controller._event_queue.join() # block until the event is consumed
+ event_queue = self.async_controller._event_queue
+ asyncio.run_coroutine_threadsafe(event_queue.put(uncast_event), loop).result()
+ asyncio.run_coroutine_threadsafe(event_queue.join(), loop).result() # block until the event is consumed
finally:
is_alive_mock.return_value = False
- self.controller._close()
+ asyncio.run_coroutine_threadsafe(self.async_controller._close(), loop).result()
diff --git a/test/unit/response/control_message.py b/test/unit/response/control_message.py
index abf5debf..414dcf63 100644
--- a/test/unit/response/control_message.py
+++ b/test/unit/response/control_message.py
@@ -126,7 +126,7 @@ class TestControlMessage(unittest.TestCase):
# replace the CRLF for the line
infonames_lines[index] = line.rstrip('\r\n') + '\n'
test_socket_file = io.BytesIO(stem.util.str_tools._to_bytes(''.join(infonames_lines)))
- self.assertRaises(stem.ProtocolError, stem.socket.recv_message, test_socket_file)
+ self.assertRaises(stem.ProtocolError, stem.socket.recv_message_from_bytes_io, test_socket_file)
# puts the CRLF back
infonames_lines[index] = infonames_lines[index].rstrip('\n') + '\r\n'
@@ -151,8 +151,8 @@ class TestControlMessage(unittest.TestCase):
# - this is part of the message prefix
# - this is disrupting the line ending
- self.assertRaises(stem.ProtocolError, stem.socket.recv_message, io.BytesIO(stem.util.str_tools._to_bytes(removal_test_input)))
- self.assertRaises(stem.ProtocolError, stem.socket.recv_message, io.BytesIO(stem.util.str_tools._to_bytes(replacement_test_input)))
+ self.assertRaises(stem.ProtocolError, stem.socket.recv_message_from_bytes_io, io.BytesIO(stem.util.str_tools._to_bytes(removal_test_input)))
+ self.assertRaises(stem.ProtocolError, stem.socket.recv_message_from_bytes_io, io.BytesIO(stem.util.str_tools._to_bytes(replacement_test_input)))
else:
# otherwise the data will be malformed, but this goes undetected
self._assert_message_parses(removal_test_input)
@@ -166,7 +166,7 @@ class TestControlMessage(unittest.TestCase):
control_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
control_socket_file = control_socket.makefile()
- self.assertRaises(stem.SocketClosed, stem.socket.recv_message, control_socket_file)
+ self.assertRaises(stem.SocketClosed, stem.socket.recv_message_from_bytes_io, control_socket_file)
def test_equality(self):
msg = stem.response.ControlMessage.from_str(EVENT_BW)
@@ -200,7 +200,7 @@ class TestControlMessage(unittest.TestCase):
stem.response.ControlMessage for the given input
"""
- message = stem.socket.recv_message(io.BytesIO(stem.util.str_tools._to_bytes(controller_reply)))
+ message = stem.socket.recv_message_from_bytes_io(io.BytesIO(stem.util.str_tools._to_bytes(controller_reply)))
# checks that the raw_content equals the input value
self.assertEqual(controller_reply, message.raw_content())