tor-commits
Threads by month
- ----- 2025 -----
- May
- April
- March
- February
- January
- ----- 2024 -----
- December
- November
- October
- September
- August
- July
- June
- May
- April
- March
- February
- January
- ----- 2023 -----
- December
- November
- October
- September
- August
- July
- June
- May
- April
- March
- February
- January
- ----- 2022 -----
- December
- November
- October
- September
- August
- July
- June
- May
- April
- March
- February
- January
- ----- 2021 -----
- December
- November
- October
- September
- August
- July
- June
- May
- April
- March
- February
- January
- ----- 2020 -----
- December
- November
- October
- September
- August
- July
- June
- May
- April
- March
- February
- January
- ----- 2019 -----
- December
- November
- October
- September
- August
- July
- June
- May
- April
- March
- February
- January
- ----- 2018 -----
- December
- November
- October
- September
- August
- July
- June
- May
- April
- March
- February
- January
- ----- 2017 -----
- December
- November
- October
- September
- August
- July
- June
- May
- April
- March
- February
- January
- ----- 2016 -----
- December
- November
- October
- September
- August
- July
- June
- May
- April
- March
- February
- January
- ----- 2015 -----
- December
- November
- October
- September
- August
- July
- June
- May
- April
- March
- February
- January
- ----- 2014 -----
- December
- November
- October
- September
- August
- July
- June
- May
- April
- March
- February
- January
- ----- 2013 -----
- December
- November
- October
- September
- August
- July
- June
- May
- April
- March
- February
- January
- ----- 2012 -----
- December
- November
- October
- September
- August
- July
- June
- May
- April
- March
- February
- January
- ----- 2011 -----
- December
- November
- October
- September
- August
- July
- June
- May
- April
- March
- February
July 2020
- 17 participants
- 2100 discussions
commit 2b4d3666ef9d3a9faf6742bd5ccfdef88cfbe4fd
Author: Illia Volochii <illia.volochii(a)gmail.com>
Date: Thu Apr 30 19:30:00 2020 +0300
Optimize `_MsgLock` a little bit
---
stem/control.py | 2 ++
1 file changed, 2 insertions(+)
diff --git a/stem/control.py b/stem/control.py
index ff0f3bc1..6de671b6 100644
--- a/stem/control.py
+++ b/stem/control.py
@@ -554,6 +554,8 @@ def event_description(event: str) -> str:
class _MsgLock:
+ __slots__ = ('_r_lock', '_async_lock')
+
def __init__(self):
self._r_lock = threading.RLock()
self._async_lock = asyncio.Lock()
1
0
commit 00e719e07c9a9993607b96d0151aa1baf943d91e
Author: Illia Volochii <illia.volochii(a)gmail.com>
Date: Thu Apr 30 19:43:53 2020 +0300
Fix most of the integration tests
---
test/{unit => }/async_util.py | 0
test/integ/connection/authentication.py | 143 +++++++++--------
test/integ/connection/connect.py | 17 ++-
test/integ/control/base_controller.py | 130 ++++++++--------
test/integ/control/controller.py | 261 +++++++++++++++++++-------------
test/integ/manual.py | 6 +-
test/integ/process.py | 27 ++--
test/integ/response/protocolinfo.py | 44 +++---
test/integ/socket/control_message.py | 84 +++++-----
test/integ/socket/control_socket.py | 84 +++++-----
test/integ/util/connection.py | 6 +-
test/integ/util/proc.py | 6 +-
test/integ/version.py | 12 +-
test/runner.py | 64 +++++---
test/unit/connection/authentication.py | 2 +-
test/unit/connection/connect.py | 2 +-
test/unit/control/controller.py | 2 +-
17 files changed, 516 insertions(+), 374 deletions(-)
diff --git a/test/unit/async_util.py b/test/async_util.py
similarity index 100%
rename from test/unit/async_util.py
rename to test/async_util.py
diff --git a/test/integ/connection/authentication.py b/test/integ/connection/authentication.py
index d07c20b2..b992ac9a 100644
--- a/test/integ/connection/authentication.py
+++ b/test/integ/connection/authentication.py
@@ -12,6 +12,7 @@ import stem.version
import test
import test.require
import test.runner
+from test.async_util import async_test
# Responses given by tor for various authentication failures. These may change
# in the future and if they do then this test should be updated.
@@ -98,31 +99,34 @@ def _get_auth_failure_message(auth_type):
class TestAuthenticate(unittest.TestCase):
@test.require.controller
- def test_authenticate_general_socket(self):
+ @async_test
+ async def test_authenticate_general_socket(self):
"""
Tests that the authenticate function can authenticate to our socket.
"""
runner = test.runner.get_runner()
- with runner.get_tor_socket(False) as control_socket:
- stem.connection.authenticate(control_socket, test.runner.CONTROL_PASSWORD, runner.get_chroot())
- test.runner.exercise_controller(self, control_socket)
+ async with await runner.get_tor_socket(False) as control_socket:
+ await stem.connection.authenticate(control_socket, test.runner.CONTROL_PASSWORD, runner.get_chroot())
+ await test.runner.exercise_controller(self, control_socket)
@test.require.controller
- def test_authenticate_general_controller(self):
+ @async_test
+ async def test_authenticate_general_controller(self):
"""
Tests that the authenticate function can authenticate via a Controller.
"""
runner = test.runner.get_runner()
- with runner.get_tor_controller(False) as controller:
- stem.connection.authenticate(controller, test.runner.CONTROL_PASSWORD, runner.get_chroot())
- test.runner.exercise_controller(self, controller)
+ with await runner.get_tor_controller(False) as controller:
+ await stem.connection.authenticate(controller, test.runner.CONTROL_PASSWORD, runner.get_chroot())
+ await test.runner.exercise_controller(self, controller)
@test.require.controller
- def test_authenticate_general_example(self):
+ @async_test
+ async def test_authenticate_general_example(self):
"""
Tests the authenticate function with something like its pydoc example.
"""
@@ -139,8 +143,8 @@ class TestAuthenticate(unittest.TestCase):
try:
# this authenticate call should work for everything but password-only auth
- stem.connection.authenticate(control_socket, chroot_path = runner.get_chroot())
- test.runner.exercise_controller(self, control_socket)
+ await stem.connection.authenticate(control_socket, chroot_path = runner.get_chroot())
+ await test.runner.exercise_controller(self, control_socket)
except stem.connection.IncorrectSocketType:
self.fail()
except stem.connection.MissingPassword:
@@ -148,17 +152,18 @@ class TestAuthenticate(unittest.TestCase):
controller_password = test.runner.CONTROL_PASSWORD
try:
- stem.connection.authenticate_password(control_socket, controller_password)
- test.runner.exercise_controller(self, control_socket)
+ await stem.connection.authenticate_password(control_socket, controller_password)
+ await test.runner.exercise_controller(self, control_socket)
except stem.connection.PasswordAuthFailed:
self.fail()
except stem.connection.AuthenticationFailure:
self.fail()
finally:
- control_socket.close()
+ await control_socket.close()
@test.require.controller
- def test_authenticate_general_password(self):
+ @async_test
+ async def test_authenticate_general_password(self):
"""
Tests the authenticate function's password argument.
"""
@@ -172,28 +177,31 @@ class TestAuthenticate(unittest.TestCase):
is_password_only = test.runner.Torrc.PASSWORD in tor_options and test.runner.Torrc.COOKIE not in tor_options
# tests without a password
- with runner.get_tor_socket(False) as control_socket:
+ async with await runner.get_tor_socket(False) as control_socket:
if is_password_only:
- self.assertRaises(stem.connection.MissingPassword, stem.connection.authenticate, control_socket)
+ with self.assertRaises(stem.connection.MissingPassword):
+ await stem.connection.authenticate(control_socket)
else:
- stem.connection.authenticate(control_socket, chroot_path = runner.get_chroot())
- test.runner.exercise_controller(self, control_socket)
+ await stem.connection.authenticate(control_socket, chroot_path = runner.get_chroot())
+ await test.runner.exercise_controller(self, control_socket)
# tests with the incorrect password
- with runner.get_tor_socket(False) as control_socket:
+ async with await runner.get_tor_socket(False) as control_socket:
if is_password_only:
- self.assertRaises(stem.connection.IncorrectPassword, stem.connection.authenticate, control_socket, 'blarg')
+ with self.assertRaises(stem.connection.IncorrectPassword):
+ await stem.connection.authenticate(control_socket, 'blarg')
else:
- stem.connection.authenticate(control_socket, 'blarg', runner.get_chroot())
- test.runner.exercise_controller(self, control_socket)
+ await stem.connection.authenticate(control_socket, 'blarg', runner.get_chroot())
+ await test.runner.exercise_controller(self, control_socket)
# tests with the right password
- with runner.get_tor_socket(False) as control_socket:
- stem.connection.authenticate(control_socket, test.runner.CONTROL_PASSWORD, runner.get_chroot())
- test.runner.exercise_controller(self, control_socket)
+ async with await runner.get_tor_socket(False) as control_socket:
+ await stem.connection.authenticate(control_socket, test.runner.CONTROL_PASSWORD, runner.get_chroot())
+ await test.runner.exercise_controller(self, control_socket)
@test.require.controller
- def test_authenticate_general_cookie(self):
+ @async_test
+ async def test_authenticate_general_cookie(self):
"""
Tests the authenticate function with only cookie authentication methods.
This manipulates our PROTOCOLINFO response to test each method
@@ -205,7 +213,7 @@ class TestAuthenticate(unittest.TestCase):
is_cookie_only = test.runner.Torrc.COOKIE in tor_options and test.runner.Torrc.PASSWORD not in tor_options
# test both cookie authentication mechanisms
- with runner.get_tor_socket(False) as control_socket:
+ async with await runner.get_tor_socket(False) as control_socket:
if is_cookie_only:
for method in (stem.connection.AuthMethod.COOKIE, stem.connection.AuthMethod.SAFECOOKIE):
protocolinfo_response = stem.connection.get_protocolinfo(control_socket)
@@ -215,10 +223,11 @@ class TestAuthenticate(unittest.TestCase):
# both independently
protocolinfo_response.auth_methods = (method, )
- stem.connection.authenticate(control_socket, chroot_path = runner.get_chroot(), protocolinfo_response = protocolinfo_response)
+ await stem.connection.authenticate(control_socket, chroot_path = runner.get_chroot(), protocolinfo_response = protocolinfo_response)
@test.require.controller
- def test_authenticate_none(self):
+ @async_test
+ async def test_authenticate_none(self):
"""
Tests the authenticate_none function.
"""
@@ -226,12 +235,14 @@ class TestAuthenticate(unittest.TestCase):
auth_type = stem.connection.AuthMethod.NONE
if _can_authenticate(auth_type):
- self._check_auth(auth_type)
+ await self._check_auth(auth_type)
else:
- self.assertRaises(stem.connection.OpenAuthRejected, self._check_auth, auth_type)
+ with self.assertRaises(stem.connection.OpenAuthRejected):
+ await self._check_auth(auth_type)
@test.require.controller
- def test_authenticate_password(self):
+ @async_test
+ async def test_authenticate_password(self):
"""
Tests the authenticate_password function.
"""
@@ -240,26 +251,29 @@ class TestAuthenticate(unittest.TestCase):
auth_value = test.runner.CONTROL_PASSWORD
if _can_authenticate(auth_type):
- self._check_auth(auth_type, auth_value)
+ await self._check_auth(auth_type, auth_value)
else:
- self.assertRaises(stem.connection.PasswordAuthRejected, self._check_auth, auth_type, auth_value)
+ with self.assertRaises(stem.connection.PasswordAuthRejected):
+ await self._check_auth(auth_type, auth_value)
# Check with an empty, invalid, and quoted password. These should work if
# we have no authentication, and fail otherwise.
for auth_value in ('', 'blarg', 'this has a " in it'):
if _can_authenticate(stem.connection.AuthMethod.NONE):
- self._check_auth(auth_type, auth_value)
+ await self._check_auth(auth_type, auth_value)
else:
if _can_authenticate(stem.connection.AuthMethod.PASSWORD):
exc_type = stem.connection.IncorrectPassword
else:
exc_type = stem.connection.PasswordAuthRejected
- self.assertRaises(exc_type, self._check_auth, auth_type, auth_value)
+ with self.assertRaises(exc_type):
+ await self._check_auth(auth_type, auth_value)
@test.require.controller
- def test_wrong_password_with_controller(self):
+ @async_test
+ async def test_wrong_password_with_controller(self):
"""
We ran into a race condition where providing the wrong password to the
Controller caused inconsistent responses. Checking for that...
@@ -273,11 +287,13 @@ class TestAuthenticate(unittest.TestCase):
self.skipTest('(requires only password auth)')
for i in range(10):
- with runner.get_tor_controller(False) as controller:
- self.assertRaises(stem.connection.IncorrectPassword, controller.authenticate, 'wrong_password')
+ with await runner.get_tor_controller(False) as controller:
+ with self.assertRaises(stem.connection.IncorrectPassword):
+ await controller.authenticate('wrong_password')
@test.require.controller
- def test_authenticate_cookie(self):
+ @async_test
+ async def test_authenticate_cookie(self):
"""
Tests the authenticate_cookie function.
"""
@@ -292,14 +308,17 @@ class TestAuthenticate(unittest.TestCase):
# auth but the function will short circuit with failure due to the
# missing file.
- self.assertRaises(stem.connection.UnreadableCookieFile, self._check_auth, auth_type, auth_value, False)
+ with self.assertRaises(stem.connection.UnreadableCookieFile):
+ await self._check_auth(auth_type, auth_value, False)
elif _can_authenticate(auth_type):
- self._check_auth(auth_type, auth_value)
+ await self._check_auth(auth_type, auth_value)
else:
- self.assertRaises(stem.connection.CookieAuthRejected, self._check_auth, auth_type, auth_value, False)
+ with self.assertRaises(stem.connection.CookieAuthRejected):
+ await self._check_auth(auth_type, auth_value, False)
@test.require.controller
- def test_authenticate_cookie_invalid(self):
+ @async_test
+ async def test_authenticate_cookie_invalid(self):
"""
Tests the authenticate_cookie function with a properly sized but incorrect
value.
@@ -316,10 +335,11 @@ class TestAuthenticate(unittest.TestCase):
if _can_authenticate(stem.connection.AuthMethod.NONE):
# authentication will work anyway unless this is safecookie
if auth_type == stem.connection.AuthMethod.COOKIE:
- self._check_auth(auth_type, auth_value)
+ await self._check_auth(auth_type, auth_value)
elif auth_type == stem.connection.AuthMethod.SAFECOOKIE:
exc_type = stem.connection.CookieAuthRejected
- self.assertRaises(exc_type, self._check_auth, auth_type, auth_value)
+ with self.assertRaises(exc_type):
+ await self._check_auth(auth_type, auth_value)
else:
if auth_type == stem.connection.AuthMethod.SAFECOOKIE:
if _can_authenticate(auth_type):
@@ -331,12 +351,14 @@ class TestAuthenticate(unittest.TestCase):
else:
exc_type = stem.connection.CookieAuthRejected
- self.assertRaises(exc_type, self._check_auth, auth_type, auth_value, False)
+ with self.assertRaises(exc_type):
+ await self._check_auth(auth_type, auth_value, False)
os.remove(auth_value)
@test.require.controller
- def test_authenticate_cookie_missing(self):
+ @async_test
+ async def test_authenticate_cookie_missing(self):
"""
Tests the authenticate_cookie function with a path that really, really
shouldn't exist.
@@ -344,10 +366,12 @@ class TestAuthenticate(unittest.TestCase):
for auth_type in (stem.connection.AuthMethod.COOKIE, stem.connection.AuthMethod.SAFECOOKIE):
auth_value = "/if/this/exists/then/they're/asking/for/a/failure"
- self.assertRaises(stem.connection.UnreadableCookieFile, self._check_auth, auth_type, auth_value, False)
+ with self.assertRaises(stem.connection.UnreadableCookieFile):
+ await self._check_auth(auth_type, auth_value, False)
@test.require.controller
- def test_authenticate_cookie_wrong_size(self):
+ @async_test
+ async def test_authenticate_cookie_wrong_size(self):
"""
Tests the authenticate_cookie function with our torrc as an auth cookie.
This is to confirm that we won't read arbitrary files to the control
@@ -361,9 +385,10 @@ class TestAuthenticate(unittest.TestCase):
# Weird coincidence? Fail so we can pick another file to check against.
self.fail('Our torrc is 32 bytes, preventing the test_authenticate_cookie_wrong_size test from running.')
else:
- self.assertRaises(stem.connection.IncorrectCookieSize, self._check_auth, auth_type, auth_value, False)
+ with self.assertRaises(stem.connection.IncorrectCookieSize):
+ await self._check_auth(auth_type, auth_value, False)
- def _check_auth(self, auth_type, auth_arg = None, check_message = True):
+ async def _check_auth(self, auth_type, auth_arg = None, check_message = True):
"""
Attempts to use the given type of authentication against tor's control
socket. If it succeeds then we check that the socket can then be used. If
@@ -377,19 +402,19 @@ class TestAuthenticate(unittest.TestCase):
:raises: :class:`stem.connection.AuthenticationFailure` if the authentication fails
"""
- with test.runner.get_runner().get_tor_socket(False) as control_socket:
+ async with await test.runner.get_runner().get_tor_socket(False) as control_socket:
# run the authentication, re-raising if there's a problem
try:
if auth_type == stem.connection.AuthMethod.NONE:
- stem.connection.authenticate_none(control_socket)
+ await stem.connection.authenticate_none(control_socket)
elif auth_type == stem.connection.AuthMethod.PASSWORD:
- stem.connection.authenticate_password(control_socket, auth_arg)
+ await stem.connection.authenticate_password(control_socket, auth_arg)
elif auth_type == stem.connection.AuthMethod.COOKIE:
- stem.connection.authenticate_cookie(control_socket, auth_arg)
+ await stem.connection.authenticate_cookie(control_socket, auth_arg)
elif auth_type == stem.connection.AuthMethod.SAFECOOKIE:
- stem.connection.authenticate_safecookie(control_socket, auth_arg)
+ await stem.connection.authenticate_safecookie(control_socket, auth_arg)
- test.runner.exercise_controller(self, control_socket)
+ await test.runner.exercise_controller(self, control_socket)
except stem.connection.AuthenticationFailure as exc:
# authentication functions should re-attach on failure
self.assertTrue(control_socket.is_alive())
diff --git a/test/integ/connection/connect.py b/test/integ/connection/connect.py
index a271843f..b1d2a672 100644
--- a/test/integ/connection/connect.py
+++ b/test/integ/connection/connect.py
@@ -8,6 +8,7 @@ import unittest
import stem.connection
import test.require
import test.runner
+from test.async_util import async_test
from unittest.mock import patch
@@ -15,37 +16,37 @@ from unittest.mock import patch
class TestConnect(unittest.TestCase):
@test.require.controller
@patch('sys.stdout', new_callable = io.StringIO)
- def test_connect(self, stdout_mock):
+ @async_test
+ async def test_connect(self, stdout_mock):
"""
Basic sanity checks for the connect function.
"""
runner = test.runner.get_runner()
- control_socket = stem.connection.connect(
+ control_socket = await stem.connection.connect_async(
control_port = ('127.0.0.1', test.runner.CONTROL_PORT),
control_socket = test.runner.CONTROL_SOCKET_PATH,
password = test.runner.CONTROL_PASSWORD,
chroot_path = runner.get_chroot(),
controller = None)
- test.runner.exercise_controller(self, control_socket)
+ await test.runner.exercise_controller(self, control_socket)
self.assertEqual('', stdout_mock.getvalue())
@test.require.controller
@patch('sys.stdout', new_callable = io.StringIO)
- def test_connect_to_socks_port(self, stdout_mock):
+ @async_test
+ async def test_connect_to_socks_port(self, stdout_mock):
"""
Common user gotcha is connecting to the SocksPort or ORPort rather than the
ControlPort. Testing that connecting to the SocksPort errors in a
reasonable way.
"""
- runner = test.runner.get_runner()
-
- control_socket = stem.connection.connect(
+ control_socket = await stem.connection.connect_async(
control_port = ('127.0.0.1', test.runner.SOCKS_PORT),
- chroot_path = runner.get_chroot(),
+ control_socket = None,
controller = None)
self.assertEqual(None, control_socket)
diff --git a/test/integ/control/base_controller.py b/test/integ/control/base_controller.py
index 323b57c7..ff51e2f1 100644
--- a/test/integ/control/base_controller.py
+++ b/test/integ/control/base_controller.py
@@ -2,10 +2,11 @@
Integration tests for the stem.control.BaseController class.
"""
+import asyncio
import os
import hashlib
+import random
import re
-import threading
import time
import unittest
@@ -14,6 +15,7 @@ import stem.socket
import stem.util.system
import test.require
import test.runner
+from test.async_util import async_test
class StateObserver(object):
@@ -39,7 +41,8 @@ class StateObserver(object):
class TestBaseController(unittest.TestCase):
@test.require.controller
- def test_connect_repeatedly(self):
+ @async_test
+ async def test_connect_repeatedly(self):
"""
Connects and closes the socket repeatedly. This is a simple attempt to
trigger concurrency issues.
@@ -48,47 +51,51 @@ class TestBaseController(unittest.TestCase):
if stem.util.system.is_mac():
self.skipTest('(ticket #6235)')
- with test.runner.get_runner().get_tor_socket() as control_socket:
+ async with await test.runner.get_runner().get_tor_socket() as control_socket:
controller = stem.control.BaseController(control_socket)
for _ in range(50):
- controller.connect()
- controller.close()
+ await controller.connect()
+ await controller.close()
@test.require.controller
- def test_msg(self):
+ @async_test
+ async def test_msg(self):
"""
Tests a basic query with the msg() method.
"""
- with test.runner.get_runner().get_tor_socket() as control_socket:
+ async with await test.runner.get_runner().get_tor_socket() as control_socket:
controller = stem.control.BaseController(control_socket)
- test.runner.exercise_controller(self, controller)
+ await test.runner.exercise_controller(self, controller)
@test.require.controller
- def test_msg_invalid(self):
+ @async_test
+ async def test_msg_invalid(self):
"""
Tests the msg() method against an invalid controller command.
"""
- with test.runner.get_runner().get_tor_socket() as control_socket:
+ async with await test.runner.get_runner().get_tor_socket() as control_socket:
controller = stem.control.BaseController(control_socket)
- response = controller.msg('invalid')
+ response = await controller.msg('invalid')
self.assertEqual('Unrecognized command "invalid"', str(response))
@test.require.controller
- def test_msg_invalid_getinfo(self):
+ @async_test
+ async def test_msg_invalid_getinfo(self):
"""
Tests the msg() method against a non-existant GETINFO option.
"""
- with test.runner.get_runner().get_tor_socket() as control_socket:
+ async with await test.runner.get_runner().get_tor_socket() as control_socket:
controller = stem.control.BaseController(control_socket)
- response = controller.msg('GETINFO blarg')
+ response = await controller.msg('GETINFO blarg')
self.assertEqual('Unrecognized key "blarg"', str(response))
@test.require.controller
- def test_msg_repeatedly(self):
+ @async_test
+ async def test_msg_repeatedly(self):
"""
Connects, sends a burst of messages, and closes the socket repeatedly. This
is a simple attempt to trigger concurrency issues.
@@ -97,35 +104,31 @@ class TestBaseController(unittest.TestCase):
if stem.util.system.is_mac():
self.skipTest('(ticket #6235)')
- with test.runner.get_runner().get_tor_socket() as control_socket:
+ async with await test.runner.get_runner().get_tor_socket() as control_socket:
controller = stem.control.BaseController(control_socket)
- def run_getinfo():
+ async def connect_and_close():
+ await controller.connect()
+ await controller.close()
+
+ async def run_getinfo():
for _ in range(50):
try:
- controller.msg('GETINFO version')
- controller.msg('GETINFO blarg')
- controller.msg('blarg')
+ await controller.msg('GETINFO version')
+ await controller.msg('GETINFO blarg')
+ await controller.msg('blarg')
except stem.ControllerError:
pass
- message_threads = []
-
- for _ in range(5):
- msg_thread = threading.Thread(target = run_getinfo)
- message_threads.append(msg_thread)
- msg_thread.setDaemon(True)
- msg_thread.start()
-
- for index in range(50):
- controller.connect()
- controller.close()
+ coroutines = [connect_and_close()] * 50
+ coroutines.extend(run_getinfo() for _ in range(5))
+ random.shuffle(coroutines)
- for msg_thread in message_threads:
- msg_thread.join()
+ await asyncio.gather(*coroutines)
@test.require.controller
- def test_asynchronous_event_handling(self):
+ @async_test
+ async def test_asynchronous_event_handling(self):
"""
Check that we can both receive asynchronous events while hammering our
socket with queries, and checks that when a controller is closed the
@@ -140,37 +143,27 @@ class TestBaseController(unittest.TestCase):
def __init__(self, control_socket):
stem.control.BaseController.__init__(self, control_socket)
self.received_events = []
- self.receive_notice = threading.Event()
+ self.receive_notice = asyncio.Event()
- def _handle_event(self, event_message):
- self.receive_notice.wait()
+ async def _handle_event(self, event_message):
+ await self.receive_notice.wait()
self.received_events.append(event_message)
- with test.runner.get_runner().get_tor_socket() as control_socket:
+ async with await test.runner.get_runner().get_tor_socket() as control_socket:
controller = ControlledListener(control_socket)
- controller.msg('SETEVENTS CONF_CHANGED')
+ await controller.msg('SETEVENTS CONF_CHANGED')
for i in range(10):
fingerprint = hashlib.sha1(os.urandom(20)).hexdigest().upper()
- controller.msg('SETCONF NodeFamily=%s' % fingerprint)
- test.runner.exercise_controller(self, controller)
-
- controller.msg('SETEVENTS')
- controller.msg('RESETCONF NodeFamily')
-
- # Concurrently shut down the controller. We need to do this in another
- # thread because it'll block on the event handling, which in turn is
- # currently blocking on the reveive_notice.
-
- close_thread = threading.Thread(target = controller.close, name = 'Closing controller')
- close_thread.setDaemon(True)
- close_thread.start()
+ await controller.msg('SETCONF NodeFamily=%s' % fingerprint)
+ await test.runner.exercise_controller(self, controller)
- # Finally start handling the BW events that we've received. We should
- # have at least a couple of them.
+ await controller.msg('SETEVENTS')
+ await controller.msg('RESETCONF NodeFamily')
+ await controller.close()
controller.receive_notice.set()
- close_thread.join()
+ await asyncio.sleep(0)
self.assertTrue(len(controller.received_events) >= 2)
@@ -180,19 +173,21 @@ class TestBaseController(unittest.TestCase):
self.assertEqual(('650', '-'), conf_changed_event.content()[0][:2])
@test.require.controller
- def test_get_latest_heartbeat(self):
+ @async_test
+ async def test_get_latest_heartbeat(self):
"""
Basic check for get_latest_heartbeat().
"""
# makes a getinfo query, then checks that the heartbeat is close to now
- with test.runner.get_runner().get_tor_socket() as control_socket:
+ async with await test.runner.get_runner().get_tor_socket() as control_socket:
controller = stem.control.BaseController(control_socket)
- controller.msg('GETINFO version')
+ await controller.msg('GETINFO version')
self.assertTrue((time.time() - controller.get_latest_heartbeat()) < 5)
@test.require.controller
- def test_status_notifications(self):
+ @async_test
+ async def test_status_notifications(self):
"""
Checks basic functionality of the add_status_listener() and
remove_status_listener() methods.
@@ -200,18 +195,18 @@ class TestBaseController(unittest.TestCase):
state_observer = StateObserver()
- with test.runner.get_runner().get_tor_socket(False) as control_socket:
+ async with await test.runner.get_runner().get_tor_socket(False) as control_socket:
controller = stem.control.BaseController(control_socket)
controller.add_status_listener(state_observer.listener, False)
- controller.close()
+ await controller.close()
self.assertEqual(controller, state_observer.controller)
self.assertEqual(stem.control.State.CLOSED, state_observer.state)
self.assertTrue(state_observer.timestamp <= time.time())
self.assertTrue(state_observer.timestamp > time.time() - 1.0)
state_observer.reset()
- controller.connect()
+ await controller.connect()
self.assertEqual(controller, state_observer.controller)
self.assertEqual(stem.control.State.INIT, state_observer.state)
self.assertTrue(state_observer.timestamp <= time.time())
@@ -219,8 +214,9 @@ class TestBaseController(unittest.TestCase):
state_observer.reset()
# cause the socket to shut down without calling close()
- controller.msg('Blarg!')
- self.assertRaises(stem.SocketClosed, controller.msg, 'blarg')
+ await controller.msg('Blarg!')
+ with self.assertRaises(stem.SocketClosed):
+ await controller.msg('blarg')
self.assertEqual(controller, state_observer.controller)
self.assertEqual(stem.control.State.CLOSED, state_observer.state)
self.assertTrue(state_observer.timestamp <= time.time())
@@ -229,7 +225,7 @@ class TestBaseController(unittest.TestCase):
# remove listener and make sure we don't get further notices
controller.remove_status_listener(state_observer.listener)
- controller.connect()
+ await controller.connect()
self.assertEqual(None, state_observer.controller)
self.assertEqual(None, state_observer.state)
self.assertEqual(None, state_observer.timestamp)
@@ -239,8 +235,8 @@ class TestBaseController(unittest.TestCase):
# get the notice asynchronously
controller.add_status_listener(state_observer.listener, True)
- controller.close()
- time.sleep(0.001) # not much work going on so this doesn't need to be much
+ await controller.close()
+ await asyncio.sleep(0.001) # not much work going on so this doesn't need to be much
self.assertEqual(controller, state_observer.controller)
self.assertEqual(stem.control.State.CLOSED, state_observer.state)
self.assertTrue(state_observer.timestamp <= time.time())
diff --git a/test/integ/control/controller.py b/test/integ/control/controller.py
index 732ae50a..ab1e76c3 100644
--- a/test/integ/control/controller.py
+++ b/test/integ/control/controller.py
@@ -2,6 +2,7 @@
Integration tests for the stem.control.Controller class.
"""
+import asyncio
import os
import shutil
import socket
@@ -22,6 +23,7 @@ import test
import test.network
import test.require
import test.runner
+from test.async_util import async_test
from stem import Flag, Signal
from stem.control import EventType, Listener, State
@@ -36,13 +38,14 @@ TEST_ROUTER_STATUS_ENTRY = None
class TestController(unittest.TestCase):
@test.require.only_run_once
@test.require.controller
- def test_missing_capabilities(self):
+ @async_test
+ async def test_missing_capabilities(self):
"""
Check to see if tor supports any events, signals, or features that we
don't.
"""
- with test.runner.get_runner().get_tor_controller() as controller:
+ with await test.runner.get_runner().get_tor_controller() as controller:
for event in controller.get_info('events/names').split():
if event not in EventType:
test.register_new_capability('Event', event)
@@ -80,12 +83,13 @@ class TestController(unittest.TestCase):
self.assertRaises(stem.SocketError, stem.control.Controller.from_socket_file, test.runner.CONTROL_SOCKET_PATH)
@test.require.controller
- def test_reset_notification(self):
+ @async_test
+ async def test_reset_notification(self):
"""
Checks that a notificiation listener is... well, notified of SIGHUPs.
"""
- with test.runner.get_runner().get_tor_controller() as controller:
+ with await test.runner.get_runner().get_tor_controller() as controller:
received_events = []
def status_listener(my_controller, state, timestamp):
@@ -101,7 +105,7 @@ class TestController(unittest.TestCase):
if (time.time() - before) > 2:
self.fail("We've waited a couple seconds for SIGHUP to generate an event, but it didn't come")
- time.sleep(0.001)
+ await asyncio.sleep(0.001)
after = time.time()
@@ -109,14 +113,15 @@ class TestController(unittest.TestCase):
state_controller, state_type, state_timestamp = received_events[0]
- self.assertEqual(controller, state_controller)
+ self.assertEqual(controller._async_controller, state_controller)
self.assertEqual(State.RESET, state_type)
self.assertTrue(state_timestamp > before and state_timestamp < after)
controller.reset_conf('__OwningControllerProcess')
@test.require.controller
- def test_event_handling(self):
+ @async_test
+ async def test_event_handling(self):
"""
Add a couple listeners for various events and make sure that they receive
them. Then remove the listeners.
@@ -135,7 +140,7 @@ class TestController(unittest.TestCase):
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
+ with await runner.get_tor_controller() as controller:
controller.add_event_listener(listener1, EventType.CONF_CHANGED)
controller.add_event_listener(listener2, EventType.CONF_CHANGED, EventType.DEBUG)
@@ -174,7 +179,8 @@ class TestController(unittest.TestCase):
controller.reset_conf('NodeFamily')
@test.require.controller
- def test_reattaching_listeners(self):
+ @async_test
+ async def test_reattaching_listeners(self):
"""
Checks that event listeners are re-attached when a controller disconnects
then reconnects to tor.
@@ -189,7 +195,7 @@ class TestController(unittest.TestCase):
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
+ with await runner.get_tor_controller() as controller:
controller.add_event_listener(listener, EventType.CONF_CHANGED)
# trigger an event
@@ -215,14 +221,15 @@ class TestController(unittest.TestCase):
controller.reset_conf('NodeFamily')
@test.require.controller
- def test_getinfo(self):
+ @async_test
+ async def test_getinfo(self):
"""
Exercises GETINFO with valid and invalid queries.
"""
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
+ with await runner.get_tor_controller() as controller:
# successful single query
torrc_path = runner.get_torrc_path()
@@ -253,12 +260,13 @@ class TestController(unittest.TestCase):
self.assertEqual({}, controller.get_info([], {}))
@test.require.controller
- def test_getinfo_freshrelaydescs(self):
+ @async_test
+ async def test_getinfo_freshrelaydescs(self):
"""
Exercises 'GETINFO status/fresh-relay-descs'.
"""
- with test.runner.get_runner().get_tor_controller() as controller:
+ with await test.runner.get_runner().get_tor_controller() as controller:
response = controller.get_info('status/fresh-relay-descs')
div = response.find('\nextra-info ')
nickname = controller.get_conf('Nickname')
@@ -276,12 +284,13 @@ class TestController(unittest.TestCase):
@test.require.controller
@test.require.online
- def test_getinfo_dir_status(self):
+ @async_test
+ async def test_getinfo_dir_status(self):
"""
Exercise 'GETINFO dir/status-vote/*'.
"""
- with test.runner.get_runner().get_tor_controller() as controller:
+ with await test.runner.get_runner().get_tor_controller() as controller:
consensus = controller.get_info('dir/status-vote/current/consensus')
self.assertTrue('moria1' in consensus, 'moria1 not found in the consensus')
@@ -290,47 +299,51 @@ class TestController(unittest.TestCase):
self.assertTrue('moria1' in microdescs, 'moria1 not found in the microdescriptor consensus')
@test.require.controller
- def test_get_version(self):
+ @async_test
+ async def test_get_version(self):
"""
Test that the convenient method get_version() works.
"""
- with test.runner.get_runner().get_tor_controller() as controller:
+ with await test.runner.get_runner().get_tor_controller() as controller:
version = controller.get_version()
self.assertTrue(isinstance(version, stem.version.Version))
self.assertEqual(version, test.tor_version())
@test.require.controller
- def test_get_exit_policy(self):
+ @async_test
+ async def test_get_exit_policy(self):
"""
Sanity test for get_exit_policy(). Our 'ExitRelay 0' torrc entry causes us
to have a simple reject-all policy.
"""
- with test.runner.get_runner().get_tor_controller() as controller:
+ with await test.runner.get_runner().get_tor_controller() as controller:
self.assertEqual(ExitPolicy('reject *:*'), controller.get_exit_policy())
@test.require.controller
- def test_authenticate(self):
+ @async_test
+ async def test_authenticate(self):
"""
Test that the convenient method authenticate() works.
"""
runner = test.runner.get_runner()
- with runner.get_tor_controller(False) as controller:
+ with await runner.get_tor_controller(False) as controller:
controller.authenticate(test.runner.CONTROL_PASSWORD)
- test.runner.exercise_controller(self, controller)
+ await test.runner.exercise_controller(self, controller)
@test.require.controller
- def test_protocolinfo(self):
+ @async_test
+ async def test_protocolinfo(self):
"""
Test that the convenient method protocolinfo() works.
"""
runner = test.runner.get_runner()
- with runner.get_tor_controller(False) as controller:
+ with await runner.get_tor_controller(False) as controller:
protocolinfo = controller.get_protocolinfo()
self.assertTrue(isinstance(protocolinfo, stem.response.protocolinfo.ProtocolInfoResponse))
@@ -351,14 +364,15 @@ class TestController(unittest.TestCase):
self.assertEqual(tuple(auth_methods), protocolinfo.auth_methods)
@test.require.controller
- def test_getconf(self):
+ @async_test
+ async def test_getconf(self):
"""
Exercises GETCONF with valid and invalid queries.
"""
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
+ with await runner.get_tor_controller() as controller:
control_socket = controller.get_socket()
if isinstance(control_socket, stem.socket.ControlPort):
@@ -414,15 +428,16 @@ class TestController(unittest.TestCase):
self.assertEqual({}, controller.get_conf_map([], 'la-di-dah'))
@test.require.controller
- def test_is_set(self):
+ @async_test
+ async def test_is_set(self):
"""
Exercises our is_set() method.
"""
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
- custom_options = controller._get_custom_options()
+ with await runner.get_tor_controller() as controller:
+ custom_options = controller._execute_async_method('_get_custom_options')
self.assertTrue('ControlPort' in custom_options or 'ControlSocket' in custom_options)
self.assertEqual('1', custom_options['DownloadExtraInfo'])
self.assertEqual('1112', custom_options['SocksPort'])
@@ -441,7 +456,8 @@ class TestController(unittest.TestCase):
self.assertFalse(controller.is_set('ConnLimit'))
@test.require.controller
- def test_hidden_services_conf(self):
+ @async_test
+ async def test_hidden_services_conf(self):
"""
Exercises the hidden service family of methods (get_hidden_service_conf,
set_hidden_service_conf, create_hidden_service, and remove_hidden_service).
@@ -455,7 +471,7 @@ class TestController(unittest.TestCase):
service3_path = os.path.join(test_dir, 'test_hidden_service3')
service4_path = os.path.join(test_dir, 'test_hidden_service4')
- with runner.get_tor_controller() as controller:
+ with await runner.get_tor_controller() as controller:
try:
# initially we shouldn't be running any hidden services
@@ -549,32 +565,35 @@ class TestController(unittest.TestCase):
pass
@test.require.controller
- def test_without_ephemeral_hidden_services(self):
+ @async_test
+ async def test_without_ephemeral_hidden_services(self):
"""
Exercises ephemeral hidden service methods when none are present.
"""
- with test.runner.get_runner().get_tor_controller() as controller:
+ with await test.runner.get_runner().get_tor_controller() as controller:
self.assertEqual([], controller.list_ephemeral_hidden_services())
self.assertEqual([], controller.list_ephemeral_hidden_services(detached = True))
self.assertEqual(False, controller.remove_ephemeral_hidden_service('gfzprpioee3hoppz'))
@test.require.controller
- def test_with_invalid_ephemeral_hidden_service_port(self):
- with test.runner.get_runner().get_tor_controller() as controller:
+ @async_test
+ async def test_with_invalid_ephemeral_hidden_service_port(self):
+ with await test.runner.get_runner().get_tor_controller() as controller:
for ports in (4567890, [4567, 4567890], {4567: '-:4567'}):
exc_msg = "ADD_ONION response didn't have an OK status: Invalid VIRTPORT/TARGET"
self.assertRaisesWith(stem.ProtocolError, exc_msg, controller.create_ephemeral_hidden_service, ports)
@test.require.controller
- def test_ephemeral_hidden_services_v2(self):
+ @async_test
+ async def test_ephemeral_hidden_services_v2(self):
"""
Exercises creating v2 ephemeral hidden services.
"""
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
+ with await runner.get_tor_controller() as controller:
response = controller.create_ephemeral_hidden_service(4567, key_content = 'RSA1024')
self.assertEqual([response.service_id], controller.list_ephemeral_hidden_services())
self.assertTrue(response.private_key is not None)
@@ -606,19 +625,20 @@ class TestController(unittest.TestCase):
# other controllers shouldn't be able to see these hidden services
- with runner.get_tor_controller() as second_controller:
+ with await runner.get_tor_controller() as second_controller:
self.assertEqual(2, len(controller.list_ephemeral_hidden_services()))
self.assertEqual(0, len(second_controller.list_ephemeral_hidden_services()))
@test.require.controller
- def test_ephemeral_hidden_services_v3(self):
+ @async_test
+ async def test_ephemeral_hidden_services_v3(self):
"""
Exercises creating v3 ephemeral hidden services.
"""
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
+ with await runner.get_tor_controller() as controller:
response = controller.create_ephemeral_hidden_service(4567, key_content = 'ED25519-V3')
self.assertEqual([response.service_id], controller.list_ephemeral_hidden_services())
self.assertTrue(response.private_key is not None)
@@ -650,19 +670,20 @@ class TestController(unittest.TestCase):
# other controllers shouldn't be able to see these hidden services
- with runner.get_tor_controller() as second_controller:
+ with await runner.get_tor_controller() as second_controller:
self.assertEqual(2, len(controller.list_ephemeral_hidden_services()))
self.assertEqual(0, len(second_controller.list_ephemeral_hidden_services()))
@test.require.controller
- def test_with_ephemeral_hidden_services_basic_auth(self):
+ @async_test
+ async def test_with_ephemeral_hidden_services_basic_auth(self):
"""
Exercises creating ephemeral hidden services that uses basic authentication.
"""
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
+ with await runner.get_tor_controller() as controller:
response = controller.create_ephemeral_hidden_service(4567, key_content = 'RSA1024', basic_auth = {'alice': 'nKwfvVPmTNr2k2pG0pzV4g', 'bob': None})
self.assertEqual([response.service_id], controller.list_ephemeral_hidden_services())
self.assertTrue(response.private_key is not None)
@@ -674,7 +695,8 @@ class TestController(unittest.TestCase):
self.assertEqual([], controller.list_ephemeral_hidden_services())
@test.require.controller
- def test_with_ephemeral_hidden_services_basic_auth_no_credentials(self):
+ @async_test
+ async def test_with_ephemeral_hidden_services_basic_auth_no_credentials(self):
"""
Exercises creating ephemeral hidden services when attempting to use basic
auth but not including any credentials.
@@ -682,12 +704,13 @@ class TestController(unittest.TestCase):
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
+ with await runner.get_tor_controller() as controller:
exc_msg = "ADD_ONION response didn't have an OK status: No auth clients specified"
self.assertRaisesWith(stem.ProtocolError, exc_msg, controller.create_ephemeral_hidden_service, 4567, basic_auth = {})
@test.require.controller
- def test_with_detached_ephemeral_hidden_services(self):
+ @async_test
+ async def test_with_detached_ephemeral_hidden_services(self):
"""
Exercises creating detached ephemeral hidden services and methods when
they're present.
@@ -695,7 +718,7 @@ class TestController(unittest.TestCase):
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
+ with await runner.get_tor_controller() as controller:
response = controller.create_ephemeral_hidden_service(4567, detached = True)
self.assertEqual([], controller.list_ephemeral_hidden_services())
self.assertEqual([response.service_id], controller.list_ephemeral_hidden_services(detached = True))
@@ -709,7 +732,7 @@ class TestController(unittest.TestCase):
# other controllers should be able to see this service, and drop it
- with runner.get_tor_controller() as second_controller:
+ with await runner.get_tor_controller() as second_controller:
self.assertEqual([response.service_id], second_controller.list_ephemeral_hidden_services(detached = True))
self.assertEqual(True, second_controller.remove_ephemeral_hidden_service(response.service_id))
self.assertEqual([], controller.list_ephemeral_hidden_services(detached = True))
@@ -722,7 +745,8 @@ class TestController(unittest.TestCase):
controller.remove_ephemeral_hidden_service(response.service_id)
@test.require.controller
- def test_rejecting_unanonymous_hidden_services_creation(self):
+ @async_test
+ async def test_rejecting_unanonymous_hidden_services_creation(self):
"""
Attempt to create a non-anonymous hidden service despite not setting
HiddenServiceSingleHopMode and HiddenServiceNonAnonymousMode.
@@ -730,11 +754,12 @@ class TestController(unittest.TestCase):
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
+ with await runner.get_tor_controller() as controller:
self.assertEqual('Tor is in anonymous hidden service mode', str(controller.msg('ADD_ONION NEW:BEST Flags=NonAnonymous Port=4567')))
@test.require.controller
- def test_set_conf(self):
+ @async_test
+ async def test_set_conf(self):
"""
Exercises set_conf(), reset_conf(), and set_options() methods with valid
and invalid requests.
@@ -744,7 +769,7 @@ class TestController(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdir:
- with runner.get_tor_controller() as controller:
+ with await runner.get_tor_controller() as controller:
try:
# successfully set a single option
connlimit = int(controller.get_conf('ConnLimit'))
@@ -807,13 +832,14 @@ class TestController(unittest.TestCase):
), reset = True)
@test.require.controller
- def test_set_conf_for_usebridges(self):
+ @async_test
+ async def test_set_conf_for_usebridges(self):
"""
Ensure we can set UseBridges=1 and also set a Bridge. This is a tor
regression check (:trac:`31945`).
"""
- with test.runner.get_runner().get_tor_controller() as controller:
+ with await test.runner.get_runner().get_tor_controller() as controller:
orport = controller.get_conf('ORPort')
try:
@@ -830,24 +856,26 @@ class TestController(unittest.TestCase):
), reset = True)
@test.require.controller
- def test_set_conf_when_immutable(self):
+ @async_test
+ async def test_set_conf_when_immutable(self):
"""
Issue a SETCONF for tor options that cannot be changed while running.
"""
- with test.runner.get_runner().get_tor_controller() as controller:
+ with await test.runner.get_runner().get_tor_controller() as controller:
self.assertRaisesWith(stem.InvalidArguments, "DisableAllSwap cannot be changed while tor's running", controller.set_conf, 'DisableAllSwap', '1')
self.assertRaisesWith(stem.InvalidArguments, "DisableAllSwap, User cannot be changed while tor's running", controller.set_options, {'User': 'atagar', 'DisableAllSwap': '1'})
@test.require.controller
- def test_loadconf(self):
+ @async_test
+ async def test_loadconf(self):
"""
Exercises Controller.load_conf with valid and invalid requests.
"""
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
+ with await runner.get_tor_controller() as controller:
oldconf = runner.get_torrc_contents()
try:
@@ -877,12 +905,13 @@ class TestController(unittest.TestCase):
controller.reset_conf('__OwningControllerProcess')
@test.require.controller
- def test_saveconf(self):
+ @async_test
+ async def test_saveconf(self):
runner = test.runner.get_runner()
# only testing for success, since we need to run out of disk space to test
# for failure
- with runner.get_tor_controller() as controller:
+ with await runner.get_tor_controller() as controller:
oldconf = runner.get_torrc_contents()
try:
@@ -897,14 +926,15 @@ class TestController(unittest.TestCase):
controller.reset_conf('__OwningControllerProcess')
@test.require.controller
- def test_get_ports(self):
+ @async_test
+ async def test_get_ports(self):
"""
Test Controller.get_ports against a running tor instance.
"""
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
+ with await runner.get_tor_controller() as controller:
self.assertEqual([test.runner.ORPORT], controller.get_ports(Listener.OR))
self.assertEqual([], controller.get_ports(Listener.DIR))
self.assertEqual([test.runner.SOCKS_PORT], controller.get_ports(Listener.SOCKS))
@@ -918,14 +948,15 @@ class TestController(unittest.TestCase):
self.assertEqual([], controller.get_ports(Listener.CONTROL))
@test.require.controller
- def test_get_listeners(self):
+ @async_test
+ async def test_get_listeners(self):
"""
Test Controller.get_listeners against a running tor instance.
"""
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
+ with await runner.get_tor_controller() as controller:
self.assertEqual([('0.0.0.0', test.runner.ORPORT)], controller.get_listeners(Listener.OR))
self.assertEqual([], controller.get_listeners(Listener.DIR))
self.assertEqual([('127.0.0.1', test.runner.SOCKS_PORT)], controller.get_listeners(Listener.SOCKS))
@@ -941,14 +972,15 @@ class TestController(unittest.TestCase):
@test.require.controller
@test.require.online
@test.require.version(stem.version.Version('0.1.2.2-alpha'))
- def test_enable_feature(self):
+ @async_test
+ async def test_enable_feature(self):
"""
Test Controller.enable_feature with valid and invalid inputs.
"""
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
+ with await runner.get_tor_controller() as controller:
self.assertTrue(controller.is_feature_enabled('VERBOSE_NAMES'))
self.assertRaises(stem.InvalidArguments, controller.enable_feature, ['NOT', 'A', 'FEATURE'])
@@ -960,12 +992,13 @@ class TestController(unittest.TestCase):
self.fail()
@test.require.controller
- def test_signal(self):
+ @async_test
+ async def test_signal(self):
"""
Test controller.signal with valid and invalid signals.
"""
- with test.runner.get_runner().get_tor_controller() as controller:
+ with await test.runner.get_runner().get_tor_controller() as controller:
# valid signal
controller.signal('CLEARDNSCACHE')
@@ -973,12 +1006,13 @@ class TestController(unittest.TestCase):
self.assertRaises(stem.InvalidArguments, controller.signal, 'FOOBAR')
@test.require.controller
- def test_newnym_availability(self):
+ @async_test
+ async def test_newnym_availability(self):
"""
Test the is_newnym_available and get_newnym_wait methods.
"""
- with test.runner.get_runner().get_tor_controller() as controller:
+ with await test.runner.get_runner().get_tor_controller() as controller:
self.assertEqual(True, controller.is_newnym_available())
self.assertEqual(0.0, controller.get_newnym_wait())
@@ -989,8 +1023,9 @@ class TestController(unittest.TestCase):
@test.require.controller
@test.require.online
- def test_extendcircuit(self):
- with test.runner.get_runner().get_tor_controller() as controller:
+ @async_test
+ async def test_extendcircuit(self):
+ with await test.runner.get_runner().get_tor_controller() as controller:
circuit_id = controller.extend_circuit('0')
# check if our circuit was created
@@ -1004,14 +1039,15 @@ class TestController(unittest.TestCase):
@test.require.controller
@test.require.online
- def test_repurpose_circuit(self):
+ @async_test
+ async def test_repurpose_circuit(self):
"""
Tests Controller.repurpose_circuit with valid and invalid input.
"""
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
+ with await runner.get_tor_controller() as controller:
circ_id = controller.new_circuit()
controller.repurpose_circuit(circ_id, 'CONTROLLER')
circuit = controller.get_circuit(circ_id)
@@ -1026,14 +1062,15 @@ class TestController(unittest.TestCase):
@test.require.controller
@test.require.online
- def test_close_circuit(self):
+ @async_test
+ async def test_close_circuit(self):
"""
Tests Controller.close_circuit with valid and invalid input.
"""
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
+ with await runner.get_tor_controller() as controller:
circuit_id = controller.new_circuit()
controller.close_circuit(circuit_id)
circuit_output = controller.get_info('circuit-status')
@@ -1052,7 +1089,8 @@ class TestController(unittest.TestCase):
@test.require.controller
@test.require.online
- def test_get_streams(self):
+ @async_test
+ async def test_get_streams(self):
"""
Tests Controller.get_streams().
"""
@@ -1061,7 +1099,7 @@ class TestController(unittest.TestCase):
port = 443
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
+ with await runner.get_tor_controller() as controller:
# we only need one proxy port, so take the first
socks_listener = controller.get_listeners(Listener.SOCKS)[0]
@@ -1077,14 +1115,15 @@ class TestController(unittest.TestCase):
@test.require.controller
@test.require.online
- def test_close_stream(self):
+ @async_test
+ async def test_close_stream(self):
"""
Tests Controller.close_stream with valid and invalid input.
"""
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
+ with await runner.get_tor_controller() as controller:
# use the first socks listener
socks_listener = controller.get_listeners(Listener.SOCKS)[0]
@@ -1116,11 +1155,12 @@ class TestController(unittest.TestCase):
@test.require.controller
@test.require.online
- def test_mapaddress(self):
+ @async_test
+ async def test_mapaddress(self):
self.skipTest('(https://trac.torproject.org/projects/tor/ticket/25611)')
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
+ with await runner.get_tor_controller() as controller:
controller.map_address({'1.2.1.2': 'ifconfig.me'})
s = None
@@ -1154,10 +1194,11 @@ class TestController(unittest.TestCase):
self.assertTrue(stem.util.connection.is_valid_ipv4_address(stem.util.str_tools._to_unicode(ip_addr)), "'%s' isn't an address" % ip_addr)
@test.require.controller
- def test_mapaddress_offline(self):
+ @async_test
+ async def test_mapaddress_offline(self):
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
+ with await runner.get_tor_controller() as controller:
# try mapping one element, ensuring results are as expected
map1 = {'1.2.1.2': 'ifconfig.me'}
@@ -1233,12 +1274,13 @@ class TestController(unittest.TestCase):
@test.require.controller
@test.require.online
- def test_get_microdescriptor(self):
+ @async_test
+ async def test_get_microdescriptor(self):
"""
Basic checks for get_microdescriptor().
"""
- with test.runner.get_runner().get_tor_controller() as controller:
+ with await test.runner.get_runner().get_tor_controller() as controller:
# we should balk at invalid content
self.assertRaises(ValueError, controller.get_microdescriptor, '')
self.assertRaises(ValueError, controller.get_microdescriptor, 5)
@@ -1257,7 +1299,8 @@ class TestController(unittest.TestCase):
@test.require.controller
@test.require.online
- def test_get_microdescriptors(self):
+ @async_test
+ async def test_get_microdescriptors(self):
"""
Fetches a few descriptors via the get_microdescriptors() method.
"""
@@ -1267,7 +1310,7 @@ class TestController(unittest.TestCase):
if not os.path.exists(runner.get_test_dir('cached-microdescs')):
self.skipTest('(no cached microdescriptors)')
- with runner.get_tor_controller() as controller:
+ with await runner.get_tor_controller() as controller:
count = 0
for desc in controller.get_microdescriptors():
@@ -1279,14 +1322,15 @@ class TestController(unittest.TestCase):
@test.require.controller
@test.require.online
- def test_get_server_descriptor(self):
+ @async_test
+ async def test_get_server_descriptor(self):
"""
Basic checks for get_server_descriptor().
"""
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
+ with await runner.get_tor_controller() as controller:
# we should balk at invalid content
self.assertRaises(ValueError, controller.get_server_descriptor, '')
self.assertRaises(ValueError, controller.get_server_descriptor, 5)
@@ -1305,14 +1349,15 @@ class TestController(unittest.TestCase):
@test.require.controller
@test.require.online
- def test_get_server_descriptors(self):
+ @async_test
+ async def test_get_server_descriptors(self):
"""
Fetches a few descriptors via the get_server_descriptors() method.
"""
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
+ with await runner.get_tor_controller() as controller:
count = 0
for desc in controller.get_server_descriptors():
@@ -1330,12 +1375,13 @@ class TestController(unittest.TestCase):
@test.require.controller
@test.require.online
- def test_get_network_status(self):
+ @async_test
+ async def test_get_network_status(self):
"""
Basic checks for get_network_status().
"""
- with test.runner.get_runner().get_tor_controller() as controller:
+ with await test.runner.get_runner().get_tor_controller() as controller:
# we should balk at invalid content
self.assertRaises(ValueError, controller.get_network_status, '')
self.assertRaises(ValueError, controller.get_network_status, 5)
@@ -1354,14 +1400,15 @@ class TestController(unittest.TestCase):
@test.require.controller
@test.require.online
- def test_get_network_statuses(self):
+ @async_test
+ async def test_get_network_statuses(self):
"""
Fetches a few descriptors via the get_network_statuses() method.
"""
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
+ with await runner.get_tor_controller() as controller:
count = 0
for desc in controller.get_network_statuses():
@@ -1377,14 +1424,15 @@ class TestController(unittest.TestCase):
@test.require.controller
@test.require.online
- def test_get_hidden_service_descriptor(self):
+ @async_test
+ async def test_get_hidden_service_descriptor(self):
"""
Fetches a few descriptors via the get_hidden_service_descriptor() method.
"""
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
+ with await runner.get_tor_controller() as controller:
# fetch the descriptor for DuckDuckGo
desc = controller.get_hidden_service_descriptor('3g2upl4pq6kufc4m.onion')
@@ -1402,7 +1450,8 @@ class TestController(unittest.TestCase):
@test.require.controller
@test.require.online
- def test_attachstream(self):
+ @async_test
+ async def test_attachstream(self):
host = socket.gethostbyname('www.torproject.org')
port = 80
@@ -1412,7 +1461,7 @@ class TestController(unittest.TestCase):
if stream.status == 'NEW' and circuit_id:
controller.attach_stream(stream.id, circuit_id)
- with test.runner.get_runner().get_tor_controller() as controller:
+ with await test.runner.get_runner().get_tor_controller() as controller:
# try 10 times to build a circuit we can connect through
for i in range(10):
controller.add_event_listener(handle_streamcreated, stem.control.EventType.STREAM)
@@ -1442,24 +1491,26 @@ class TestController(unittest.TestCase):
@test.require.controller
@test.require.online
- def test_get_circuits(self):
+ @async_test
+ async def test_get_circuits(self):
"""
Fetches circuits via the get_circuits() method.
"""
- with test.runner.get_runner().get_tor_controller() as controller:
+ with await test.runner.get_runner().get_tor_controller() as controller:
new_circ = controller.new_circuit()
circuits = controller.get_circuits()
self.assertTrue(new_circ in [circ.id for circ in circuits])
@test.require.controller
- def test_transition_to_relay(self):
+ @async_test
+ async def test_transition_to_relay(self):
"""
Transitions Tor to turn into a relay, then back to a client. This helps to
catch transition issues such as the one cited in :trac:`14901`.
"""
- with test.runner.get_runner().get_tor_controller() as controller:
+ with await test.runner.get_runner().get_tor_controller() as controller:
try:
controller.reset_conf('OrPort', 'DisableNetwork')
self.assertEqual(None, controller.get_conf('OrPort'))
diff --git a/test/integ/manual.py b/test/integ/manual.py
index 3e721ac6..df0c0105 100644
--- a/test/integ/manual.py
+++ b/test/integ/manual.py
@@ -14,6 +14,7 @@ import test
import test.runner
from stem.manual import Category
+from test.async_util import async_test
EXPECTED_CATEGORIES = set([
'NAME',
@@ -216,14 +217,15 @@ class TestManual(unittest.TestCase):
self.assertEqual(['tor - The second-generation onion router'], categories['NAME'])
self.assertEqual(['tor [OPTION value]...'], categories['SYNOPSIS'])
- def test_has_all_tor_config_options(self):
+ @async_test
+ async def test_has_all_tor_config_options(self):
"""
Check that all the configuration options tor supports are in the man page.
"""
self.requires_downloaded_manual()
- with test.runner.get_runner().get_tor_controller() as controller:
+ with await test.runner.get_runner().get_tor_controller() as controller:
config_options_in_tor = set([line.split()[0] for line in controller.get_info('config/names').splitlines() if line.split()[1] != 'Virtual'])
# options starting with an underscore are hidden by convention
diff --git a/test/integ/process.py b/test/integ/process.py
index 12cc493a..a2363fea 100644
--- a/test/integ/process.py
+++ b/test/integ/process.py
@@ -27,6 +27,7 @@ from contextlib import contextmanager
from unittest.mock import patch, Mock
from stem.util.test_tools import asynchronous, assert_equal, assert_in, skip
+from test.async_util import async_test
BASIC_RELAY_TORRC = """\
SocksPort 9089
@@ -97,9 +98,9 @@ class TestProcess(unittest.TestCase):
global TOR_CMD
TOR_CMD = args.tor_cmd
- for func, async_test in stem.util.test_tools.ASYNC_TESTS.items():
+ for func, asynchronous_test in stem.util.test_tools.ASYNC_TESTS.items():
if func.startswith('test.integ.process.'):
- async_test.run(TOR_CMD)
+ asynchronous_test.run(TOR_CMD)
@asynchronous
def test_version_argument(tor_cmd):
@@ -407,7 +408,8 @@ class TestProcess(unittest.TestCase):
raise AssertionError('Launching tor with the default timeout should be successful')
@asynchronous
- def test_launch_tor_with_config_via_file(tor_cmd):
+ @async_test
+ async def test_launch_tor_with_config_via_file(tor_cmd):
"""
Exercises launch_tor_with_config when we write a torrc to disk.
"""
@@ -432,23 +434,24 @@ class TestProcess(unittest.TestCase):
)
control_socket = stem.socket.ControlPort(port = int(control_port))
- stem.connection.authenticate(control_socket)
+ await stem.connection.authenticate(control_socket)
# exercises the socket
- control_socket.send('GETCONF ControlPort')
- getconf_response = control_socket.recv()
+ await control_socket.send('GETCONF ControlPort')
+ getconf_response = await control_socket.recv()
assert_equal('ControlPort=%s' % control_port, str(getconf_response))
finally:
if control_socket:
- control_socket.close()
+ await control_socket.close()
if tor_process:
tor_process.kill()
tor_process.wait()
@asynchronous
- def test_launch_tor_with_config_via_stdin(tor_cmd):
+ @async_test
+ async def test_launch_tor_with_config_via_stdin(tor_cmd):
"""
Exercises launch_tor_with_config when we provide our torrc via stdin.
"""
@@ -469,16 +472,16 @@ class TestProcess(unittest.TestCase):
)
control_socket = stem.socket.ControlPort(port = int(control_port))
- stem.connection.authenticate(control_socket)
+ await stem.connection.authenticate(control_socket)
# exercises the socket
- control_socket.send('GETCONF ControlPort')
- getconf_response = control_socket.recv()
+ await control_socket.send('GETCONF ControlPort')
+ getconf_response = await control_socket.recv()
assert_equal('ControlPort=%s' % control_port, str(getconf_response))
finally:
if control_socket:
- control_socket.close()
+ await control_socket.close()
if tor_process:
tor_process.kill()
diff --git a/test/integ/response/protocolinfo.py b/test/integ/response/protocolinfo.py
index 3a9ee0be..f824be5d 100644
--- a/test/integ/response/protocolinfo.py
+++ b/test/integ/response/protocolinfo.py
@@ -16,20 +16,23 @@ import test.runner
from unittest.mock import Mock, patch
+from test.async_util import async_test
+
class TestProtocolInfo(unittest.TestCase):
@test.require.controller
- def test_parsing(self):
+ @async_test
+ async def test_parsing(self):
"""
Makes a PROTOCOLINFO query and processes the response for our control
connection.
"""
- control_socket = test.runner.get_runner().get_tor_socket(False)
- control_socket.send('PROTOCOLINFO 1')
- protocolinfo_response = control_socket.recv()
+ control_socket = await test.runner.get_runner().get_tor_socket(False)
+ await control_socket.send('PROTOCOLINFO 1')
+ protocolinfo_response = await control_socket.recv()
stem.response.convert('PROTOCOLINFO', protocolinfo_response)
- control_socket.close()
+ await control_socket.close()
# according to the control spec the following _could_ differ or be
# undefined but if that actually happens then it's gonna make people sad
@@ -43,7 +46,8 @@ class TestProtocolInfo(unittest.TestCase):
@test.require.controller
@patch('stem.util.proc.is_available', Mock(return_value = False))
@patch('stem.util.system.is_available', Mock(return_value = True))
- def test_get_protocolinfo_path_expansion(self):
+ @async_test
+ async def test_get_protocolinfo_path_expansion(self):
"""
If we're running with the 'RELATIVE' target then test_parsing() will
exercise cookie path expansion when we're able to query the pid by our
@@ -71,47 +75,51 @@ class TestProtocolInfo(unittest.TestCase):
control_socket = stem.socket.ControlSocketFile(test.runner.CONTROL_SOCKET_PATH)
+ await control_socket.connect()
+
call_replacement = test.integ.util.system.filter_system_call(lookup_prefixes)
with patch('stem.util.system.call') as call_mock:
call_mock.side_effect = call_replacement
- protocolinfo_response = stem.connection.get_protocolinfo(control_socket)
+ protocolinfo_response = await stem.connection.get_protocolinfo(control_socket)
self.assert_matches_test_config(protocolinfo_response)
# we should have a usable socket at this point
self.assertTrue(control_socket.is_alive())
- control_socket.close()
+ await control_socket.close()
@test.require.controller
- def test_multiple_protocolinfo_calls(self):
+ @async_test
+ async def test_multiple_protocolinfo_calls(self):
"""
Tests making repeated PROTOCOLINFO queries. This use case is interesting
because tor will shut down the socket and stem should transparently
re-establish it.
"""
- with test.runner.get_runner().get_tor_socket(False) as control_socket:
+ async with await test.runner.get_runner().get_tor_socket(False) as control_socket:
for _ in range(5):
- protocolinfo_response = stem.connection.get_protocolinfo(control_socket)
+ protocolinfo_response = await stem.connection.get_protocolinfo(control_socket)
self.assert_matches_test_config(protocolinfo_response)
@test.require.controller
- def test_pre_disconnected_query(self):
+ @async_test
+ async def test_pre_disconnected_query(self):
"""
Tests making a PROTOCOLINFO query when previous use of the socket had
already disconnected it.
"""
- with test.runner.get_runner().get_tor_socket(False) as control_socket:
+ async with await test.runner.get_runner().get_tor_socket(False) as control_socket:
# makes a couple protocolinfo queries outside of get_protocolinfo first
- control_socket.send('PROTOCOLINFO 1')
- control_socket.recv()
+ await control_socket.send('PROTOCOLINFO 1')
+ await control_socket.recv()
- control_socket.send('PROTOCOLINFO 1')
- control_socket.recv()
+ await control_socket.send('PROTOCOLINFO 1')
+ await control_socket.recv()
- protocolinfo_response = stem.connection.get_protocolinfo(control_socket)
+ protocolinfo_response = await stem.connection.get_protocolinfo(control_socket)
self.assert_matches_test_config(protocolinfo_response)
def assert_matches_test_config(self, protocolinfo_response):
diff --git a/test/integ/socket/control_message.py b/test/integ/socket/control_message.py
index e0a4cca2..80bf4762 100644
--- a/test/integ/socket/control_message.py
+++ b/test/integ/socket/control_message.py
@@ -9,11 +9,13 @@ import stem.socket
import stem.version
import test.require
import test.runner
+from test.async_util import async_test
class TestControlMessage(unittest.TestCase):
@test.require.controller
- def test_unestablished_socket(self):
+ @async_test
+ async def test_unestablished_socket(self):
"""
Checks message parsing when we have a valid but unauthenticated socket.
"""
@@ -22,10 +24,10 @@ class TestControlMessage(unittest.TestCase):
# PROTOCOLINFO then tor will give an 'Authentication required.' message and
# hang up.
- control_socket = test.runner.get_runner().get_tor_socket(False)
- control_socket.send('GETINFO version')
+ control_socket = await test.runner.get_runner().get_tor_socket(False)
+ await control_socket.send('GETINFO version')
- auth_required_response = control_socket.recv()
+ auth_required_response = await control_socket.recv()
self.assertEqual('Authentication required.', str(auth_required_response))
self.assertEqual(['Authentication required.'], list(auth_required_response))
self.assertEqual('514 Authentication required.\r\n', auth_required_response.raw_content())
@@ -35,54 +37,64 @@ class TestControlMessage(unittest.TestCase):
# checked in more depth by the ControlSocket integ tests.
self.assertTrue(control_socket.is_alive())
- self.assertRaises(stem.SocketClosed, control_socket.recv)
+ with self.assertRaises(stem.SocketClosed):
+ await control_socket.recv()
self.assertFalse(control_socket.is_alive())
# Additional socket usage should fail, and pulling more responses will fail
# with more closed exceptions.
- self.assertRaises(stem.SocketError, control_socket.send, 'GETINFO version')
- self.assertRaises(stem.SocketClosed, control_socket.recv)
- self.assertRaises(stem.SocketClosed, control_socket.recv)
- self.assertRaises(stem.SocketClosed, control_socket.recv)
+ with self.assertRaises(stem.SocketClosed):
+ await control_socket.send('GETINFO version')
+ with self.assertRaises(stem.SocketClosed):
+ await control_socket.recv()
+ with self.assertRaises(stem.SocketClosed):
+ await control_socket.recv()
+ with self.assertRaises(stem.SocketClosed):
+ await control_socket.recv()
# The socket connection is already broken so calling close shouldn't have
# an impact.
- control_socket.close()
- self.assertRaises(stem.SocketClosed, control_socket.send, 'GETINFO version')
- self.assertRaises(stem.SocketClosed, control_socket.recv)
+ await control_socket.close()
+ with self.assertRaises(stem.SocketClosed):
+ await control_socket.send('GETINFO version')
+ with self.assertRaises(stem.SocketClosed):
+ await control_socket.recv()
@test.require.controller
- def test_invalid_command(self):
+ @async_test
+ async def test_invalid_command(self):
"""
Parses the response for a command which doesn't exist.
"""
- with test.runner.get_runner().get_tor_socket() as control_socket:
- control_socket.send('blarg')
- unrecognized_command_response = control_socket.recv()
+ async with await test.runner.get_runner().get_tor_socket() as control_socket:
+ await control_socket.send('blarg')
+ unrecognized_command_response = await control_socket.recv()
self.assertEqual('Unrecognized command "blarg"', str(unrecognized_command_response))
self.assertEqual(['Unrecognized command "blarg"'], list(unrecognized_command_response))
self.assertEqual('510 Unrecognized command "blarg"\r\n', unrecognized_command_response.raw_content())
self.assertEqual([('510', ' ', 'Unrecognized command "blarg"')], unrecognized_command_response.content())
@test.require.controller
- def test_invalid_getinfo(self):
+ @async_test
+ async def test_invalid_getinfo(self):
"""
Parses the response for a GETINFO query which doesn't exist.
"""
- with test.runner.get_runner().get_tor_socket() as control_socket:
- control_socket.send('GETINFO blarg')
- unrecognized_key_response = control_socket.recv()
+ async with await test.runner.get_runner().get_tor_socket() as control_socket:
+ await control_socket.send('GETINFO blarg')
+ unrecognized_key_response = await control_socket.recv()
self.assertEqual('Unrecognized key "blarg"', str(unrecognized_key_response))
self.assertEqual(['Unrecognized key "blarg"'], list(unrecognized_key_response))
self.assertEqual('552 Unrecognized key "blarg"\r\n', unrecognized_key_response.raw_content())
self.assertEqual([('552', ' ', 'Unrecognized key "blarg"')], unrecognized_key_response.content())
@test.require.controller
- def test_getinfo_config_file(self):
+ @async_test
+ async def test_getinfo_config_file(self):
"""
Parses the 'GETINFO config-file' response.
"""
@@ -90,16 +102,17 @@ class TestControlMessage(unittest.TestCase):
runner = test.runner.get_runner()
torrc_dst = runner.get_torrc_path()
- with runner.get_tor_socket() as control_socket:
- control_socket.send('GETINFO config-file')
- config_file_response = control_socket.recv()
+ async with await runner.get_tor_socket() as control_socket:
+ await control_socket.send('GETINFO config-file')
+ config_file_response = await control_socket.recv()
self.assertEqual('config-file=%s\nOK' % torrc_dst, str(config_file_response))
self.assertEqual(['config-file=%s' % torrc_dst, 'OK'], list(config_file_response))
self.assertEqual('250-config-file=%s\r\n250 OK\r\n' % torrc_dst, config_file_response.raw_content())
self.assertEqual([('250', '-', 'config-file=%s' % torrc_dst), ('250', ' ', 'OK')], config_file_response.content())
@test.require.controller
- def test_getinfo_config_text(self):
+ @async_test
+ async def test_getinfo_config_text(self):
"""
Parses the 'GETINFO config-text' response.
"""
@@ -120,9 +133,9 @@ class TestControlMessage(unittest.TestCase):
if line and not line.startswith('#'):
torrc_contents.append(line)
- with runner.get_tor_socket() as control_socket:
- control_socket.send('GETINFO config-text')
- config_text_response = control_socket.recv()
+ async with await runner.get_tor_socket() as control_socket:
+ await control_socket.send('GETINFO config-text')
+ config_text_response = await control_socket.recv()
# the response should contain two entries, the first being a data response
self.assertEqual(2, len(list(config_text_response)))
@@ -140,14 +153,15 @@ class TestControlMessage(unittest.TestCase):
self.assertTrue('%s' % torrc_entry in config_text_response.content()[0][2])
@test.require.controller
- def test_setconf_event(self):
+ @async_test
+ async def test_setconf_event(self):
"""
Issues 'SETEVENTS CONF_CHANGED' and parses an events.
"""
- with test.runner.get_runner().get_tor_socket() as control_socket:
- control_socket.send('SETEVENTS CONF_CHANGED')
- setevents_response = control_socket.recv()
+ async with await test.runner.get_runner().get_tor_socket() as control_socket:
+ await control_socket.send('SETEVENTS CONF_CHANGED')
+ setevents_response = await control_socket.recv()
self.assertEqual('OK', str(setevents_response))
self.assertEqual(['OK'], list(setevents_response))
self.assertEqual('250 OK\r\n', setevents_response.raw_content())
@@ -156,9 +170,9 @@ class TestControlMessage(unittest.TestCase):
# We'll receive both a CONF_CHANGED event and 'OK' response for the
# SETCONF, but not necessarily in any specific order.
- control_socket.send('SETCONF NodeFamily=FD4CC275C5AA4D27A487C6CA29097900F85E2C33')
- msg1 = control_socket.recv()
- msg2 = control_socket.recv()
+ await control_socket.send('SETCONF NodeFamily=FD4CC275C5AA4D27A487C6CA29097900F85E2C33')
+ msg1 = await control_socket.recv()
+ msg2 = await control_socket.recv()
if msg1.content()[0][0] == '650':
conf_changed_event, setconf_response = msg1, msg2
diff --git a/test/integ/socket/control_socket.py b/test/integ/socket/control_socket.py
index f479892c..bb2d8873 100644
--- a/test/integ/socket/control_socket.py
+++ b/test/integ/socket/control_socket.py
@@ -8,6 +8,7 @@ those focus on parsing and correctness of the content these are more concerned
with the behavior of the socket itself.
"""
+import asyncio
import time
import unittest
@@ -17,11 +18,13 @@ import stem.socket
import test
import test.require
import test.runner
+from test.async_util import async_test
class TestControlSocket(unittest.TestCase):
@test.require.controller
- def test_connection_time(self):
+ @async_test
+ async def test_connection_time(self):
"""
Checks that our connection_time method tracks when our state's changed.
"""
@@ -29,7 +32,7 @@ class TestControlSocket(unittest.TestCase):
test_start = time.time()
runner = test.runner.get_runner()
- with runner.get_tor_socket() as control_socket:
+ async with await runner.get_tor_socket() as control_socket:
connection_time = control_socket.connection_time()
# connection time should be between our tests start and now
@@ -38,54 +41,58 @@ class TestControlSocket(unittest.TestCase):
# connection time should be absolute (shouldn't change as time goes on)
- time.sleep(0.001)
+ await asyncio.sleep(0.001)
self.assertEqual(connection_time, control_socket.connection_time())
# should change to the disconnection time if we detactch
- control_socket.close()
+ await control_socket.close()
disconnection_time = control_socket.connection_time()
self.assertTrue(connection_time < disconnection_time <= time.time())
# then change again if we reconnect
- time.sleep(0.001)
- control_socket.connect()
+ await asyncio.sleep(0.001)
+ await control_socket.connect()
reconnection_time = control_socket.connection_time()
self.assertTrue(disconnection_time < reconnection_time <= time.time())
@test.require.controller
- def test_send_buffered(self):
+ @async_test
+ async def test_send_buffered(self):
"""
Sends multiple requests before receiving back any of the replies.
"""
runner = test.runner.get_runner()
- with runner.get_tor_socket() as control_socket:
+ async with await runner.get_tor_socket() as control_socket:
for _ in range(100):
- control_socket.send('GETINFO version')
+ await control_socket.send('GETINFO version')
for _ in range(100):
- response = control_socket.recv()
+ response = await control_socket.recv()
self.assertTrue(str(response).startswith('version=%s' % test.tor_version()))
self.assertTrue(str(response).endswith('\nOK'))
@test.require.controller
- def test_send_closed(self):
+ @async_test
+ async def test_send_closed(self):
"""
Sends a message after we've closed the connection.
"""
- with test.runner.get_runner().get_tor_socket() as control_socket:
+ async with await test.runner.get_runner().get_tor_socket() as control_socket:
self.assertTrue(control_socket.is_alive())
- control_socket.close()
+ await control_socket.close()
self.assertFalse(control_socket.is_alive())
- self.assertRaises(stem.SocketClosed, control_socket.send, 'blarg')
+ with self.assertRaises(stem.SocketClosed):
+ await control_socket.send('blarg')
@test.require.controller
- def test_send_disconnected(self):
+ @async_test
+ async def test_send_disconnected(self):
"""
Sends a message to a socket that has been disconnected by the other end.
@@ -95,64 +102,71 @@ class TestControlSocket(unittest.TestCase):
call. With a file socket, however, we'll also fail when calling send().
"""
- with test.runner.get_runner().get_tor_socket() as control_socket:
- control_socket.send('QUIT')
- self.assertEqual('closing connection', str(control_socket.recv()))
+ async with await test.runner.get_runner().get_tor_socket() as control_socket:
+ await control_socket.send('QUIT')
+ self.assertEqual('closing connection', str(await control_socket.recv()))
self.assertTrue(control_socket.is_alive())
# If we send another message to a port based socket then it will seem to
# succeed. However, a file based socket should report a failure.
if isinstance(control_socket, stem.socket.ControlPort):
- control_socket.send('blarg')
+ await control_socket.send('blarg')
self.assertTrue(control_socket.is_alive())
else:
- self.assertRaises(stem.SocketClosed, control_socket.send, 'blarg')
+ with self.assertRaises(stem.SocketClosed):
+ await control_socket.send('blarg')
self.assertFalse(control_socket.is_alive())
@test.require.controller
- def test_recv_closed(self):
+ @async_test
+ async def test_recv_closed(self):
"""
Receives a message after we've closed the connection.
"""
- with test.runner.get_runner().get_tor_socket() as control_socket:
+ async with await test.runner.get_runner().get_tor_socket() as control_socket:
self.assertTrue(control_socket.is_alive())
- control_socket.close()
+ await control_socket.close()
self.assertFalse(control_socket.is_alive())
- self.assertRaises(stem.SocketClosed, control_socket.recv)
+ with self.assertRaises(stem.SocketClosed):
+ await control_socket.recv()
@test.require.controller
- def test_recv_disconnected(self):
+ @async_test
+ async def test_recv_disconnected(self):
"""
Receives a message from a socket that has been disconnected by the other
end.
"""
- with test.runner.get_runner().get_tor_socket() as control_socket:
- control_socket.send('QUIT')
- self.assertEqual('closing connection', str(control_socket.recv()))
+ async with await test.runner.get_runner().get_tor_socket() as control_socket:
+ await control_socket.send('QUIT')
+ self.assertEqual('closing connection', str(await control_socket.recv()))
# Neither a port or file based socket will know that tor has hung up on
# the connection at this point. We should know after calling recv(),
# however.
self.assertTrue(control_socket.is_alive())
- self.assertRaises(stem.SocketClosed, control_socket.recv)
+ with self.assertRaises(stem.SocketClosed):
+ await control_socket.recv()
self.assertFalse(control_socket.is_alive())
@test.require.controller
- def test_connect_repeatedly(self):
+ @async_test
+ async def test_connect_repeatedly(self):
"""
Checks that we can reconnect, use, and disconnect a socket repeatedly.
"""
- with test.runner.get_runner().get_tor_socket(False) as control_socket:
+ async with await test.runner.get_runner().get_tor_socket(False) as control_socket:
for _ in range(10):
# this will raise if the PROTOCOLINFO query fails
- stem.connection.get_protocolinfo(control_socket)
+ await stem.connection.get_protocolinfo(control_socket)
- control_socket.close()
- self.assertRaises(stem.SocketClosed, control_socket.send, 'PROTOCOLINFO 1')
- control_socket.connect()
+ await control_socket.close()
+ with self.assertRaises(stem.SocketClosed):
+ await control_socket.send('PROTOCOLINFO 1')
+ await control_socket.connect()
diff --git a/test/integ/util/connection.py b/test/integ/util/connection.py
index f1745ec0..c35d8448 100644
--- a/test/integ/util/connection.py
+++ b/test/integ/util/connection.py
@@ -13,11 +13,13 @@ import test.require
import test.runner
from stem.util.connection import Resolver
+from test.async_util import async_test
class TestConnection(unittest.TestCase):
@test.require.ptrace
- def check_resolver(self, resolver):
+ @async_test
+ async def check_resolver(self, resolver):
runner = test.runner.get_runner()
if test.runner.Torrc.PORT not in runner.get_options():
@@ -25,7 +27,7 @@ class TestConnection(unittest.TestCase):
elif resolver not in stem.util.connection.system_resolvers():
self.skipTest('(resolver unavailable on this platform)')
- with runner.get_tor_socket():
+ async with await runner.get_tor_socket():
connections = stem.util.connection.get_connections(resolver, process_pid = runner.get_pid())
for conn in connections:
diff --git a/test/integ/util/proc.py b/test/integ/util/proc.py
index 315082d5..4038984c 100644
--- a/test/integ/util/proc.py
+++ b/test/integ/util/proc.py
@@ -10,6 +10,7 @@ import test.require
import test.runner
from stem.util import proc
+from test.async_util import async_test
class TestProc(unittest.TestCase):
@@ -63,7 +64,8 @@ class TestProc(unittest.TestCase):
@test.require.proc
@test.require.ptrace
- def test_connections(self):
+ @async_test
+ async def test_connections(self):
"""
Checks for our control port in the stem.util.proc.connections output if
we have one.
@@ -78,7 +80,7 @@ class TestProc(unittest.TestCase):
self.skipTest('(proc lacks read permissions)')
# making a controller connection so that we have something to query for
- with runner.get_tor_socket():
+ async with await runner.get_tor_socket():
tor_pid = test.runner.get_runner().get_pid()
for conn in proc.connections(tor_pid):
diff --git a/test/integ/version.py b/test/integ/version.py
index 641629d4..d02014a5 100644
--- a/test/integ/version.py
+++ b/test/integ/version.py
@@ -8,6 +8,7 @@ import unittest
import stem.version
import test.require
import test.runner
+from test.async_util import async_test
class TestVersion(unittest.TestCase):
@@ -30,16 +31,17 @@ class TestVersion(unittest.TestCase):
self.assertRaises(IOError, stem.version.get_system_tor_version, 'blarg')
@test.require.controller
- def test_getinfo_version_parsing(self):
+ @async_test
+ async def test_getinfo_version_parsing(self):
"""
Issues a 'GETINFO version' query to our test instance and makes sure that
we can parse it.
"""
- control_socket = test.runner.get_runner().get_tor_socket()
- control_socket.send('GETINFO version')
- version_response = control_socket.recv()
- control_socket.close()
+ control_socket = await test.runner.get_runner().get_tor_socket()
+ await control_socket.send('GETINFO version')
+ version_response = await control_socket.recv()
+ await control_socket.close()
# the getinfo response looks like...
# 250-version=0.2.3.10-alpha-dev (git-65420e4cb5edcd02)
diff --git a/test/runner.py b/test/runner.py
index a8079908..4a38e824 100644
--- a/test/runner.py
+++ b/test/runner.py
@@ -32,6 +32,7 @@ about the tor test instance they're running against.
+- get_tor_command - provides the command used to start tor
"""
+import asyncio
import logging
import os
import shutil
@@ -86,8 +87,8 @@ class TorInaccessable(Exception):
'Raised when information is needed from tor but the instance we have is inaccessible'
-def exercise_controller(test_case, controller):
- """
+async def exercise_controller(test_case, controller):
+ """with await test.runner.get_runner().get_tor_socket
Checks that we can now use the socket by issuing a 'GETINFO config-file'
query. Controller can be either a :class:`stem.socket.ControlSocket` or
:class:`stem.control.BaseController`.
@@ -100,10 +101,12 @@ def exercise_controller(test_case, controller):
torrc_path = runner.get_torrc_path()
if isinstance(controller, stem.socket.ControlSocket):
- controller.send('GETINFO config-file')
- config_file_response = controller.recv()
+ await controller.send('GETINFO config-file')
+ config_file_response = await controller.recv()
else:
config_file_response = controller.msg('GETINFO config-file')
+ if asyncio.iscoroutine(config_file_response):
+ config_file_response = await config_file_response
test_case.assertEqual('config-file=%s\nOK' % torrc_path, str(config_file_response))
@@ -134,8 +137,8 @@ class _MockChrootFile(object):
self.wrapped_file = wrapped_file
self.strip_text = strip_text
- def readline(self):
- return self.wrapped_file.readline().replace(self.strip_text, '')
+ async def readline(self):
+ return (await self.wrapped_file.readline()).replace(self.strip_text, '')
class Runner(object):
@@ -252,13 +255,15 @@ class Runner(object):
self._original_recv_message = stem.socket.recv_message
self._chroot_path = data_dir_path
- def _chroot_recv_message(control_file):
- return self._original_recv_message(_MockChrootFile(control_file, data_dir_path))
+ async def _chroot_recv_message(control_file):
+ return await self._original_recv_message(_MockChrootFile(control_file, data_dir_path))
stem.socket.recv_message = _chroot_recv_message
if self.is_accessible():
- self._owner_controller = self.get_tor_controller(True)
+ self._owner_controller = stem.control.Controller(self._get_unconnected_socket(), False)
+ self._owner_controller.connect()
+ self._authenticate_controller(self._owner_controller)
if test.Target.RELATIVE in self.attribute_targets:
os.chdir(original_cwd) # revert our cwd back to normal
@@ -440,7 +445,17 @@ class Runner(object):
tor_process = self._get('_tor_process')
return tor_process.pid
- def get_tor_socket(self, authenticate = True):
+ def _get_unconnected_socket(self):
+ if Torrc.PORT in self._custom_opts:
+ control_socket = stem.socket.ControlPort(port = CONTROL_PORT)
+ elif Torrc.SOCKET in self._custom_opts:
+ control_socket = stem.socket.ControlSocketFile(CONTROL_SOCKET_PATH)
+ else:
+ raise TorInaccessable('Unable to connect to tor')
+
+ return control_socket
+
+ async def get_tor_socket(self, authenticate = True):
"""
Provides a socket connected to our tor test instance.
@@ -451,19 +466,18 @@ class Runner(object):
:raises: :class:`test.runner.TorInaccessable` if tor can't be connected to
"""
- if Torrc.PORT in self._custom_opts:
- control_socket = stem.socket.ControlPort(port = CONTROL_PORT)
- elif Torrc.SOCKET in self._custom_opts:
- control_socket = stem.socket.ControlSocketFile(CONTROL_SOCKET_PATH)
- else:
- raise TorInaccessable('Unable to connect to tor')
+ control_socket = self._get_unconnected_socket()
+ await control_socket.connect()
if authenticate:
- stem.connection.authenticate(control_socket, CONTROL_PASSWORD, self.get_chroot())
+ await stem.connection.authenticate(control_socket, CONTROL_PASSWORD, self.get_chroot())
return control_socket
- def get_tor_controller(self, authenticate = True):
+ def _authenticate_controller(self, controller):
+ controller.authenticate(password=CONTROL_PASSWORD, chroot_path=self.get_chroot())
+
+ async def get_tor_controller(self, authenticate = True):
"""
Provides a controller connected to our tor test instance.
@@ -474,11 +488,19 @@ class Runner(object):
:raises: :class: `test.runner.TorInaccessable` if tor can't be connected to
"""
- control_socket = self.get_tor_socket(False)
- controller = stem.control.Controller(control_socket)
+ async_controller_thread = stem.control._AsyncControllerThread()
+ async_controller_thread.start()
+
+ try:
+ control_socket = asyncio.run_coroutine_threadsafe(self.get_tor_socket(False), async_controller_thread.loop).result()
+ controller = stem.control.Controller(control_socket, started_async_controller_thread = async_controller_thread)
+ except Exception:
+ if async_controller_thread.is_alive():
+ async_controller_thread.join()
+ raise
if authenticate:
- controller.authenticate(password = CONTROL_PASSWORD, chroot_path = self.get_chroot())
+ self._authenticate_controller(controller)
return controller
diff --git a/test/unit/connection/authentication.py b/test/unit/connection/authentication.py
index 8df38c8f..5e59adae 100644
--- a/test/unit/connection/authentication.py
+++ b/test/unit/connection/authentication.py
@@ -18,7 +18,7 @@ from unittest.mock import patch
from stem.response import ControlMessage
from stem.util import log
-from test.unit.async_util import (
+from test.async_util import (
async_test,
coro_func_raising_exc,
coro_func_returning_value,
diff --git a/test/unit/connection/connect.py b/test/unit/connection/connect.py
index 2112f678..3a0e0767 100644
--- a/test/unit/connection/connect.py
+++ b/test/unit/connection/connect.py
@@ -11,7 +11,7 @@ import stem.socket
from unittest.mock import Mock, patch
-from test.unit.async_util import (
+from test.async_util import (
async_test,
coro_func_raising_exc,
coro_func_returning_value,
diff --git a/test/unit/control/controller.py b/test/unit/control/controller.py
index 99bd5f19..4c03dea3 100644
--- a/test/unit/control/controller.py
+++ b/test/unit/control/controller.py
@@ -21,7 +21,7 @@ from stem import ControllerError, DescriptorUnavailable, InvalidArguments, Inval
from stem.control import MALFORMED_EVENTS, _parse_circ_path, Listener, Controller, EventType
from stem.response import ControlMessage
from stem.exit_policy import ExitPolicy
-from test.unit.async_util import (
+from test.async_util import (
async_test,
coro_func_raising_exc,
coro_func_returning_value,
1
0

[stem/master] Prepare for creating and wrapping one more asynchronous class
by atagar@torproject.org 16 Jul '20
by atagar@torproject.org 16 Jul '20
16 Jul '20
commit 79e8c1b47c63dfd49b62ec47c1b7902f51b06a83
Author: Illia Volochii <illia.volochii(a)gmail.com>
Date: Thu May 14 00:06:49 2020 +0300
Prepare for creating and wrapping one more asynchronous class
---
stem/connection.py | 2 +-
stem/control.py | 108 ++++++++-------------------------------
stem/interpreter/__init__.py | 2 +-
stem/interpreter/commands.py | 5 +-
stem/util/__init__.py | 81 +++++++++++++++++++++++++++++
test/integ/control/controller.py | 2 +-
test/runner.py | 2 +-
test/unit/control/controller.py | 4 +-
8 files changed, 110 insertions(+), 96 deletions(-)
diff --git a/stem/connection.py b/stem/connection.py
index c44fddb1..8f57f3b3 100644
--- a/stem/connection.py
+++ b/stem/connection.py
@@ -257,7 +257,7 @@ def connect(control_port: Tuple[str, Union[str, int]] = ('127.0.0.1', 'default')
if controller is None or not issubclass(controller, stem.control.Controller):
raise ValueError('Controller should be a stem.control.BaseController subclass.')
- async_controller_thread = stem.control._AsyncControllerThread()
+ async_controller_thread = stem.util.ThreadForWrappedAsyncClass()
async_controller_thread.start()
connect_coroutine = _connect_async(control_port, control_socket, password, password_prompt, chroot_path, controller)
diff --git a/stem/control.py b/stem/control.py
index 6de671b6..1488621a 100644
--- a/stem/control.py
+++ b/stem/control.py
@@ -553,29 +553,6 @@ def event_description(event: str) -> str:
return EVENT_DESCRIPTIONS.get(event.lower())
-class _MsgLock:
- __slots__ = ('_r_lock', '_async_lock')
-
- def __init__(self):
- self._r_lock = threading.RLock()
- self._async_lock = asyncio.Lock()
-
- async def acquire(self):
- await self._async_lock.acquire()
- self._r_lock.acquire()
-
- def release(self):
- self._r_lock.release()
- self._async_lock.release()
-
- async def __aenter__(self):
- await self.acquire()
- return self
-
- async def __aexit__(self, exc_type, exc_val, exc_tb):
- self.release()
-
-
class _BaseControllerSocketMixin:
def is_alive(self) -> bool:
"""
@@ -644,7 +621,7 @@ class BaseController(_BaseControllerSocketMixin):
self._asyncio_loop = asyncio.get_event_loop()
- self._msg_lock = _MsgLock()
+ self._msg_lock = stem.util.CombinedReentrantAndAsyncioLock()
self._status_listeners = [] # type: List[Tuple[Callable[[stem.control.BaseController, stem.control.State, float], None], bool]] # tuples of the form (callback, spawn_thread)
self._status_listeners_lock = threading.RLock()
@@ -3901,22 +3878,7 @@ class AsyncController(_ControllerClassMethodMixin, BaseController):
return (set_events, failed_events)
-class _AsyncControllerThread(threading.Thread):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, *kwargs)
- self.loop = asyncio.new_event_loop()
- self.setDaemon(True)
-
- def run(self):
- self.loop.run_forever()
-
- def join(self, timeout = None):
- self.loop.call_soon_threadsafe(self.loop.stop)
- super().join(timeout)
- self.loop.close()
-
-
-class Controller(_ControllerClassMethodMixin, _BaseControllerSocketMixin):
+class Controller(_ControllerClassMethodMixin, _BaseControllerSocketMixin, stem.util.AsyncClassWrapper):
@classmethod
def from_port(cls: Type, address: str = '127.0.0.1', port: Union[int, str] = 'default') -> 'stem.control.Controller':
instance = super().from_port(address, port)
@@ -3932,48 +3894,19 @@ class Controller(_ControllerClassMethodMixin, _BaseControllerSocketMixin):
def __init__(self, control_socket: 'stem.socket.ControlSocket', is_authenticated: bool = False, started_async_controller_thread: Optional['threading.Thread'] = None) -> None:
def __init__(self, control_socket, is_authenticated = False, started_async_controller_thread = None):
if started_async_controller_thread:
- self._async_controller_thread = started_async_controller_thread
+ self._thread_for_wrapped_class = started_async_controller_thread
else:
- self._async_controller_thread = _AsyncControllerThread()
- self._async_controller_thread.start()
- self._asyncio_loop = self._async_controller_thread.loop
-
- self._async_controller = self._init_async_controller(control_socket, is_authenticated)
- self._socket = self._async_controller._socket
-
- def _init_async_controller(self, control_socket: 'stem.socket.ControlSocket', is_authenticated: bool) -> 'stem.control.AsyncController':
- # The asynchronous controller should be initialized in the thread where its
- # methods will be executed.
- if self._async_controller_thread != threading.current_thread():
- async def init_async_controller() -> 'stem.control.AsyncController':
- return AsyncController(control_socket, is_authenticated)
-
- return asyncio.run_coroutine_threadsafe(init_async_controller(), self._asyncio_loop).result()
-
- return AsyncController(control_socket, is_authenticated)
-
- def _execute_async_method(self, method_name: str, *args: Any, **kwargs: Any) -> Any:
- return asyncio.run_coroutine_threadsafe(
- getattr(self._async_controller, method_name)(*args, **kwargs),
- self._asyncio_loop,
- ).result()
-
- def _execute_async_generator_method(self, method_name: str, *args: Any, **kwargs: Any) -> Any:
- async def convert_async_generator(generator):
- return iter([d async for d in generator])
+ self._thread_for_wrapped_class = stem.util.ThreadForWrappedAsyncClass()
+ self._thread_for_wrapped_class.start()
- return asyncio.run_coroutine_threadsafe(
- convert_async_generator(
- getattr(self._async_controller, method_name)(*args, **kwargs),
- ),
- self._asyncio_loop,
- ).result()
+ self._wrapped_instance = self._init_async_class(AsyncController, control_socket, is_authenticated)
+ self._socket = self._wrapped_instance._socket
def msg(self, message: str) -> stem.response.ControlMessage:
return self._execute_async_method('msg', message)
def is_authenticated(self) -> bool:
- return self._async_controller.is_authenticated()
+ return self._wrapped_instance.is_authenticated()
def connect(self) -> None:
self._execute_async_method('connect')
@@ -3985,13 +3918,13 @@ class Controller(_ControllerClassMethodMixin, _BaseControllerSocketMixin):
self._execute_async_method('close')
def get_latest_heartbeat(self) -> float:
- return self._async_controller.get_latest_heartbeat()
+ return self._wrapped_instance.get_latest_heartbeat()
def add_status_listener(self, callback: Callable[['stem.control.BaseController', 'stem.control.State', float], None], spawn: bool = True) -> None:
- self._async_controller.add_status_listener(callback, spawn)
+ self._wrapped_instance.add_status_listener(callback, spawn)
def remove_status_listener(self, callback: Callable[['stem.control.Controller', 'stem.control.State', float], None]) -> bool:
- self._async_controller.remove_status_listener(callback)
+ self._wrapped_instance.remove_status_listener(callback)
def authenticate(self, *args: Any, **kwargs: Any) -> None:
self._execute_async_method('authenticate', *args, **kwargs)
@@ -4099,13 +4032,13 @@ class Controller(_ControllerClassMethodMixin, _BaseControllerSocketMixin):
self._execute_async_method('remove_event_listener', listener)
def is_caching_enabled(self) -> bool:
- return self._async_controller.is_caching_enabled()
+ return self._wrapped_instance.is_caching_enabled()
def set_caching(self, enabled: bool) -> None:
- self._async_controller.set_caching(enabled)
+ self._wrapped_instance.set_caching(enabled)
def clear_cache(self) -> None:
- self._async_controller.clear_cache()
+ self._wrapped_instance.clear_cache()
def load_conf(self, configtext: str) -> None:
self._execute_async_method('load_conf', configtext)
@@ -4114,10 +4047,10 @@ class Controller(_ControllerClassMethodMixin, _BaseControllerSocketMixin):
return self._execute_async_method('save_conf', force)
def is_feature_enabled(self, feature: str) -> bool:
- return self._async_controller.is_feature_enabled(feature)
+ return self._wrapped_instance.is_feature_enabled(feature)
def enable_feature(self, features: Union[str, Sequence[str]]) -> None:
- self._async_controller.enable_feature(features)
+ self._wrapped_instance.enable_feature(features)
def get_circuit(self, circuit_id: int, default: Any = UNDEFINED) -> stem.response.events.CircuitEvent:
return self._execute_async_method('get_circuit', circuit_id, default)
@@ -4150,10 +4083,10 @@ class Controller(_ControllerClassMethodMixin, _BaseControllerSocketMixin):
self._execute_async_method('signal', signal)
def is_newnym_available(self) -> bool:
- return self._async_controller.is_newnym_available()
+ return self._wrapped_instance.is_newnym_available()
def get_newnym_wait(self) -> float:
- return self._async_controller.get_newnym_wait()
+ return self._wrapped_instance.get_newnym_wait()
def get_effective_rate(self, default: Any = UNDEFINED, burst: bool = False) -> int:
return self._execute_async_method('get_effective_rate', default, burst)
@@ -4165,8 +4098,9 @@ class Controller(_ControllerClassMethodMixin, _BaseControllerSocketMixin):
self._execute_async_method('drop_guards')
def __del__(self) -> None:
- if self._asyncio_loop.is_running():
- self._asyncio_loop.call_soon_threadsafe(self._asyncio_loop.stop)
+ loop = self._thread_for_wrapped_class.loop
+ if loop.is_running():
+ loop.call_soon_threadsafe(loop.stop)
def __enter__(self) -> 'stem.control.Controller':
return self
diff --git a/stem/interpreter/__init__.py b/stem/interpreter/__init__.py
index 07353d44..ae064a0a 100644
--- a/stem/interpreter/__init__.py
+++ b/stem/interpreter/__init__.py
@@ -127,7 +127,7 @@ def main() -> None:
async def handle_event(event_message):
print(format(str(event_message), *STANDARD_OUTPUT))
- controller._async_controller._handle_event = handle_event
+ controller._wrapped_instance._handle_event = handle_event
if sys.stdout.isatty():
events = args.run_cmd.upper().split(' ', 1)[1]
diff --git a/stem/interpreter/commands.py b/stem/interpreter/commands.py
index 0d262ab5..edbcca70 100644
--- a/stem/interpreter/commands.py
+++ b/stem/interpreter/commands.py
@@ -128,7 +128,7 @@ class ControlInterpreter(code.InteractiveConsole):
# Intercept events our controller hears about at a pretty low level since
# the user will likely be requesting them by direct 'SETEVENTS' calls.
- handle_event_real = self._controller._async_controller._handle_event
+ handle_event_real = self._controller._wrapped_instance._handle_event
async def handle_event_wrapper(event_message: stem.response.ControlMessage) -> None:
await handle_event_real(event_message)
@@ -139,8 +139,7 @@ class ControlInterpreter(code.InteractiveConsole):
# type check disabled due to https://github.com/python/mypy/issues/708
- self._controller._async_controller._handle_event = handle_event_wrapper
- self._controller._handle_event = handle_event_wrapper # type: ignore
+ self._controller._wrapped_instance._handle_event = handle_event_wrapper
def get_events(self, *event_types: stem.control.EventType) -> List[stem.response.events.Event]:
events = list(self._received_events)
diff --git a/stem/util/__init__.py b/stem/util/__init__.py
index e4fa3ca8..a230cfbd 100644
--- a/stem/util/__init__.py
+++ b/stem/util/__init__.py
@@ -5,7 +5,9 @@
Utility functions used by the stem library.
"""
+import asyncio
import datetime
+import threading
from typing import Any, Union
@@ -139,3 +141,82 @@ def _hash_attr(obj: Any, *attributes: str, **kwargs: Any):
setattr(obj, '_cached_hash', my_hash)
return my_hash
+
+
+class CombinedReentrantAndAsyncioLock:
+ """
+ Lock that combines thread-safe reentrant and not thread-safe asyncio locks.
+ """
+
+ __slots__ = ('_r_lock', '_async_lock')
+
+ def __init__(self):
+ self._r_lock = threading.RLock()
+ self._async_lock = asyncio.Lock()
+
+ async def acquire(self):
+ await self._async_lock.acquire()
+ self._r_lock.acquire()
+
+ def release(self):
+ self._r_lock.release()
+ self._async_lock.release()
+
+ async def __aenter__(self):
+ await self.acquire()
+ return self
+
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
+ self.release()
+
+
+class ThreadForWrappedAsyncClass(threading.Thread):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, *kwargs)
+ self.loop = asyncio.new_event_loop()
+ self.setDaemon(True)
+
+ def run(self):
+ self.loop.run_forever()
+
+ def join(self, timeout=None):
+ self.loop.call_soon_threadsafe(self.loop.stop)
+ super().join(timeout)
+ self.loop.close()
+
+
+class AsyncClassWrapper:
+ _thread_for_wrapped_class: ThreadForWrappedAsyncClass
+ _wrapped_instance: type
+
+ def _init_async_class(self, async_class, *args, **kwargs):
+ thread = self._thread_for_wrapped_class
+ # The asynchronous class should be initialized in the thread where
+ # its methods will be executed.
+ if thread != threading.current_thread():
+ async def init():
+ return async_class(*args, **kwargs)
+
+ return asyncio.run_coroutine_threadsafe(init(), thread.loop).result()
+
+ return async_class(*args, **kwargs)
+
+ def _call_async_method_soon(self, method_name, *args, **kwargs):
+ return asyncio.run_coroutine_threadsafe(
+ getattr(self._wrapped_instance, method_name)(*args, **kwargs),
+ self._thread_for_wrapped_class.loop,
+ )
+
+ def _execute_async_method(self, method_name, *args, **kwargs):
+ return self._call_async_method_soon(method_name, *args, **kwargs).result()
+
+ def _execute_async_generator_method(self, method_name, *args, **kwargs):
+ async def convert_async_generator(generator):
+ return iter([d async for d in generator])
+
+ return asyncio.run_coroutine_threadsafe(
+ convert_async_generator(
+ getattr(self._wrapped_instance, method_name)(*args, **kwargs),
+ ),
+ self._thread_for_wrapped_class.loop,
+ ).result()
diff --git a/test/integ/control/controller.py b/test/integ/control/controller.py
index 73e71fa4..b1772f34 100644
--- a/test/integ/control/controller.py
+++ b/test/integ/control/controller.py
@@ -113,7 +113,7 @@ class TestController(unittest.TestCase):
state_controller, state_type, state_timestamp = received_events[0]
- self.assertEqual(controller._async_controller, state_controller)
+ self.assertEqual(controller._wrapped_instance, state_controller)
self.assertEqual(State.RESET, state_type)
self.assertTrue(state_timestamp > before and state_timestamp < after)
diff --git a/test/runner.py b/test/runner.py
index 4a38e824..189a2d7b 100644
--- a/test/runner.py
+++ b/test/runner.py
@@ -488,7 +488,7 @@ class Runner(object):
:raises: :class: `test.runner.TorInaccessable` if tor can't be connected to
"""
- async_controller_thread = stem.control._AsyncControllerThread()
+ async_controller_thread = stem.util.ThreadForWrappedAsyncClass()
async_controller_thread.start()
try:
diff --git a/test/unit/control/controller.py b/test/unit/control/controller.py
index e8ef4787..a11aba45 100644
--- a/test/unit/control/controller.py
+++ b/test/unit/control/controller.py
@@ -44,7 +44,7 @@ class TestControl(unittest.TestCase):
with patch('stem.control.BaseController.msg', Mock(side_effect = coro_func_returning_value(None))):
self.controller = Controller(socket)
- self.async_controller = self.controller._async_controller
+ self.async_controller = self.controller._wrapped_instance
self.circ_listener = Mock()
self.controller.add_event_listener(self.circ_listener, EventType.CIRC)
@@ -748,7 +748,7 @@ class TestControl(unittest.TestCase):
with patch('time.time', Mock(return_value = TEST_TIMESTAMP)):
with patch('stem.control.AsyncController.is_alive') as is_alive_mock:
is_alive_mock.return_value = True
- loop = self.controller._asyncio_loop
+ loop = self.controller._thread_for_wrapped_class.loop
asyncio.run_coroutine_threadsafe(self.async_controller._event_loop(), loop)
try:
1
0

[stem/master] Fix `stem.connection._msg` operating a synchronous controller
by atagar@torproject.org 16 Jul '20
by atagar@torproject.org 16 Jul '20
16 Jul '20
commit 0be69071af23ed6cb753188160e29bfbcfab328c
Author: Illia Volochii <illia.volochii(a)gmail.com>
Date: Thu Apr 30 19:30:52 2020 +0300
Fix `stem.connection._msg` operating a synchronous controller
---
stem/connection.py | 5 ++++-
1 file changed, 4 insertions(+), 1 deletion(-)
diff --git a/stem/connection.py b/stem/connection.py
index 3d240070..c44fddb1 100644
--- a/stem/connection.py
+++ b/stem/connection.py
@@ -1025,7 +1025,10 @@ async def _msg(controller: Union[stem.control.BaseController, stem.socket.Contro
await controller.send(message)
return await controller.recv()
else:
- return await controller.msg(message)
+ message = controller.msg(message)
+ if asyncio.iscoroutine(message):
+ message = await message
+ return message
def _connection_for_default_port(address: str) -> stem.socket.ControlPort:
1
0

[stem/master] Move utility functions for asynchronous tests to `stem.util.test_tools`
by atagar@torproject.org 16 Jul '20
by atagar@torproject.org 16 Jul '20
16 Jul '20
commit 2b8b8fd51c74af4509f8745f22d081191dbe63b5
Author: Illia Volochii <illia.volochii(a)gmail.com>
Date: Thu Apr 30 19:52:42 2020 +0300
Move utility functions for asynchronous tests to `stem.util.test_tools`
---
stem/util/test_tools.py | 26 ++++++++++++++++++++++++++
test/async_util.py | 26 --------------------------
test/integ/connection/authentication.py | 2 +-
test/integ/connection/connect.py | 2 +-
test/integ/control/base_controller.py | 2 +-
test/integ/control/controller.py | 2 +-
test/integ/manual.py | 2 +-
test/integ/process.py | 3 +--
test/integ/response/protocolinfo.py | 3 +--
test/integ/socket/control_message.py | 2 +-
test/integ/socket/control_socket.py | 2 +-
test/integ/util/connection.py | 2 +-
test/integ/util/proc.py | 2 +-
test/integ/version.py | 2 +-
test/unit/connection/authentication.py | 2 +-
test/unit/connection/connect.py | 2 +-
test/unit/control/controller.py | 2 +-
17 files changed, 41 insertions(+), 43 deletions(-)
diff --git a/stem/util/test_tools.py b/stem/util/test_tools.py
index 36b4110e..f3c736a1 100644
--- a/stem/util/test_tools.py
+++ b/stem/util/test_tools.py
@@ -30,7 +30,9 @@ to match just against the prefix or suffix. For instance...
type_issues - checks for type problems
"""
+import asyncio
import collections
+import functools
import linecache
import multiprocessing
import os
@@ -680,3 +682,27 @@ def _is_ignored(config: Mapping[str, Sequence[str]], path: str, issue: str) -> b
return True # suffix match
return False
+
+
+def async_test(func: Callable) -> Callable:
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs):
+ loop = asyncio.new_event_loop()
+ try:
+ result = loop.run_until_complete(func(*args, **kwargs))
+ finally:
+ loop.close()
+ return result
+ return wrapper
+
+
+def coro_func_returning_value(return_value):
+ async def coroutine_func(*args, **kwargs):
+ return return_value
+ return coroutine_func
+
+
+def coro_func_raising_exc(exc):
+ async def coroutine_func(*args, **kwargs):
+ raise exc
+ return coroutine_func
diff --git a/test/async_util.py b/test/async_util.py
deleted file mode 100644
index aa8a6dd0..00000000
--- a/test/async_util.py
+++ /dev/null
@@ -1,26 +0,0 @@
-import asyncio
-import functools
-
-
-def async_test(func):
- @functools.wraps(func)
- def wrapper(*args, **kwargs):
- loop = asyncio.new_event_loop()
- try:
- result = loop.run_until_complete(func(*args, **kwargs))
- finally:
- loop.close()
- return result
- return wrapper
-
-
-def coro_func_returning_value(return_value):
- async def coroutine_func(*args, **kwargs):
- return return_value
- return coroutine_func
-
-
-def coro_func_raising_exc(exc):
- async def coroutine_func(*args, **kwargs):
- raise exc
- return coroutine_func
diff --git a/test/integ/connection/authentication.py b/test/integ/connection/authentication.py
index b992ac9a..683e555f 100644
--- a/test/integ/connection/authentication.py
+++ b/test/integ/connection/authentication.py
@@ -12,7 +12,7 @@ import stem.version
import test
import test.require
import test.runner
-from test.async_util import async_test
+from stem.util.test_tools import async_test
# Responses given by tor for various authentication failures. These may change
# in the future and if they do then this test should be updated.
diff --git a/test/integ/connection/connect.py b/test/integ/connection/connect.py
index b1d2a672..399598bc 100644
--- a/test/integ/connection/connect.py
+++ b/test/integ/connection/connect.py
@@ -8,7 +8,7 @@ import unittest
import stem.connection
import test.require
import test.runner
-from test.async_util import async_test
+from stem.util.test_tools import async_test
from unittest.mock import patch
diff --git a/test/integ/control/base_controller.py b/test/integ/control/base_controller.py
index ff51e2f1..ac5f1e56 100644
--- a/test/integ/control/base_controller.py
+++ b/test/integ/control/base_controller.py
@@ -15,7 +15,7 @@ import stem.socket
import stem.util.system
import test.require
import test.runner
-from test.async_util import async_test
+from stem.util.test_tools import async_test
class StateObserver(object):
diff --git a/test/integ/control/controller.py b/test/integ/control/controller.py
index ab1e76c3..73e71fa4 100644
--- a/test/integ/control/controller.py
+++ b/test/integ/control/controller.py
@@ -23,11 +23,11 @@ import test
import test.network
import test.require
import test.runner
-from test.async_util import async_test
from stem import Flag, Signal
from stem.control import EventType, Listener, State
from stem.exit_policy import ExitPolicy
+from stem.util.test_tools import async_test
# Router status entry for a relay with a nickname other than 'Unnamed'. This is
# used for a few tests that need to look up a relay.
diff --git a/test/integ/manual.py b/test/integ/manual.py
index df0c0105..d285c758 100644
--- a/test/integ/manual.py
+++ b/test/integ/manual.py
@@ -14,7 +14,7 @@ import test
import test.runner
from stem.manual import Category
-from test.async_util import async_test
+from stem.util.test_tools import async_test
EXPECTED_CATEGORIES = set([
'NAME',
diff --git a/test/integ/process.py b/test/integ/process.py
index a2363fea..30cb0430 100644
--- a/test/integ/process.py
+++ b/test/integ/process.py
@@ -26,8 +26,7 @@ import test.require
from contextlib import contextmanager
from unittest.mock import patch, Mock
-from stem.util.test_tools import asynchronous, assert_equal, assert_in, skip
-from test.async_util import async_test
+from stem.util.test_tools import async_test, asynchronous, assert_equal, assert_in, skip
BASIC_RELAY_TORRC = """\
SocksPort 9089
diff --git a/test/integ/response/protocolinfo.py b/test/integ/response/protocolinfo.py
index f824be5d..244c9e81 100644
--- a/test/integ/response/protocolinfo.py
+++ b/test/integ/response/protocolinfo.py
@@ -13,11 +13,10 @@ import test
import test.integ.util.system
import test.require
import test.runner
+from stem.util.test_tools import async_test
from unittest.mock import Mock, patch
-from test.async_util import async_test
-
class TestProtocolInfo(unittest.TestCase):
@test.require.controller
diff --git a/test/integ/socket/control_message.py b/test/integ/socket/control_message.py
index 80bf4762..8511877b 100644
--- a/test/integ/socket/control_message.py
+++ b/test/integ/socket/control_message.py
@@ -9,7 +9,7 @@ import stem.socket
import stem.version
import test.require
import test.runner
-from test.async_util import async_test
+from stem.util.test_tools import async_test
class TestControlMessage(unittest.TestCase):
diff --git a/test/integ/socket/control_socket.py b/test/integ/socket/control_socket.py
index bb2d8873..14797589 100644
--- a/test/integ/socket/control_socket.py
+++ b/test/integ/socket/control_socket.py
@@ -18,7 +18,7 @@ import stem.socket
import test
import test.require
import test.runner
-from test.async_util import async_test
+from stem.util.test_tools import async_test
class TestControlSocket(unittest.TestCase):
diff --git a/test/integ/util/connection.py b/test/integ/util/connection.py
index c35d8448..3e22667e 100644
--- a/test/integ/util/connection.py
+++ b/test/integ/util/connection.py
@@ -13,7 +13,7 @@ import test.require
import test.runner
from stem.util.connection import Resolver
-from test.async_util import async_test
+from stem.util.test_tools import async_test
class TestConnection(unittest.TestCase):
diff --git a/test/integ/util/proc.py b/test/integ/util/proc.py
index 4038984c..b1006d10 100644
--- a/test/integ/util/proc.py
+++ b/test/integ/util/proc.py
@@ -10,7 +10,7 @@ import test.require
import test.runner
from stem.util import proc
-from test.async_util import async_test
+from stem.util.test_tools import async_test
class TestProc(unittest.TestCase):
diff --git a/test/integ/version.py b/test/integ/version.py
index d02014a5..0df48646 100644
--- a/test/integ/version.py
+++ b/test/integ/version.py
@@ -8,7 +8,7 @@ import unittest
import stem.version
import test.require
import test.runner
-from test.async_util import async_test
+from stem.util.test_tools import async_test
class TestVersion(unittest.TestCase):
diff --git a/test/unit/connection/authentication.py b/test/unit/connection/authentication.py
index 5e59adae..700f458c 100644
--- a/test/unit/connection/authentication.py
+++ b/test/unit/connection/authentication.py
@@ -18,7 +18,7 @@ from unittest.mock import patch
from stem.response import ControlMessage
from stem.util import log
-from test.async_util import (
+from stem.util.test_tools import (
async_test,
coro_func_raising_exc,
coro_func_returning_value,
diff --git a/test/unit/connection/connect.py b/test/unit/connection/connect.py
index 3a0e0767..ec92a6c4 100644
--- a/test/unit/connection/connect.py
+++ b/test/unit/connection/connect.py
@@ -11,7 +11,7 @@ import stem.socket
from unittest.mock import Mock, patch
-from test.async_util import (
+from stem.util.test_tools import (
async_test,
coro_func_raising_exc,
coro_func_returning_value,
diff --git a/test/unit/control/controller.py b/test/unit/control/controller.py
index 4c03dea3..e8ef4787 100644
--- a/test/unit/control/controller.py
+++ b/test/unit/control/controller.py
@@ -21,7 +21,7 @@ from stem import ControllerError, DescriptorUnavailable, InvalidArguments, Inval
from stem.control import MALFORMED_EVENTS, _parse_circ_path, Listener, Controller, EventType
from stem.response import ControlMessage
from stem.exit_policy import ExitPolicy
-from test.async_util import (
+from stem.util.test_tools import (
async_test,
coro_func_raising_exc,
coro_func_returning_value,
1
0
commit 66d597e90741260341ec264ea774f3089d43f7b9
Author: Illia Volochii <illia.volochii(a)gmail.com>
Date: Thu May 14 00:12:02 2020 +0300
Fix errors in static checks of my IDE
---
stem/control.py | 2 +-
stem/descriptor/remote.py | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/stem/control.py b/stem/control.py
index 1488621a..e483d4f3 100644
--- a/stem/control.py
+++ b/stem/control.py
@@ -3899,7 +3899,7 @@ class Controller(_ControllerClassMethodMixin, _BaseControllerSocketMixin, stem.u
self._thread_for_wrapped_class = stem.util.ThreadForWrappedAsyncClass()
self._thread_for_wrapped_class.start()
- self._wrapped_instance = self._init_async_class(AsyncController, control_socket, is_authenticated)
+ self._wrapped_instance: AsyncController = self._init_async_class(AsyncController, control_socket, is_authenticated)
self._socket = self._wrapped_instance._socket
def msg(self, message: str) -> stem.response.ControlMessage:
diff --git a/stem/descriptor/remote.py b/stem/descriptor/remote.py
index eca846ee..daa3b42c 100644
--- a/stem/descriptor/remote.py
+++ b/stem/descriptor/remote.py
@@ -565,7 +565,7 @@ class Query(stem.util.AsyncClassWrapper):
def __init__(self, resource, descriptor_type = None, endpoints = None, compression = (Compression.GZIP,), retries = 2, fall_back_to_authority = False, timeout = None, start = True, block = False, validate = False, document_handler = stem.descriptor.DocumentHandler.ENTRIES, **kwargs):
self._thread_for_wrapped_class = stem.util.ThreadForWrappedAsyncClass()
self._thread_for_wrapped_class.start()
- self._wrapped_instance = self._init_async_class(
+ self._wrapped_instance: AsyncQuery = self._init_async_class(
AsyncQuery,
resource,
descriptor_type,
1
0

16 Jul '20
commit 2a65ee4bb3cd4a25baf14918cd7bb1c614be361e
Author: Illia Volochii <illia.volochii(a)gmail.com>
Date: Thu Apr 30 19:32:25 2020 +0300
Fix monkey patches of the `_handle_event` method
---
stem/interpreter/__init__.py | 5 ++++-
stem/interpreter/commands.py | 9 +++++----
2 files changed, 9 insertions(+), 5 deletions(-)
diff --git a/stem/interpreter/__init__.py b/stem/interpreter/__init__.py
index 1d08abb6..07353d44 100644
--- a/stem/interpreter/__init__.py
+++ b/stem/interpreter/__init__.py
@@ -124,7 +124,10 @@ def main() -> None:
if args.run_cmd:
if args.run_cmd.upper().startswith('SETEVENTS '):
- controller._handle_event = lambda event_message: print(format(str(event_message), *STANDARD_OUTPUT)) # type: ignore
+ async def handle_event(event_message):
+ print(format(str(event_message), *STANDARD_OUTPUT))
+
+ controller._async_controller._handle_event = handle_event
if sys.stdout.isatty():
events = args.run_cmd.upper().split(' ', 1)[1]
diff --git a/stem/interpreter/commands.py b/stem/interpreter/commands.py
index 4c506b5e..0d262ab5 100644
--- a/stem/interpreter/commands.py
+++ b/stem/interpreter/commands.py
@@ -128,17 +128,18 @@ class ControlInterpreter(code.InteractiveConsole):
# Intercept events our controller hears about at a pretty low level since
# the user will likely be requesting them by direct 'SETEVENTS' calls.
- handle_event_real = self._controller._handle_event
+ handle_event_real = self._controller._async_controller._handle_event
- def handle_event_wrapper(event_message: stem.response.ControlMessage) -> None:
- handle_event_real(event_message)
- self._received_events.insert(0, event_message) # type: ignore
+ async def handle_event_wrapper(event_message: stem.response.ControlMessage) -> None:
+ await handle_event_real(event_message)
+ self._received_events.insert(0, event_message)
if len(self._received_events) > MAX_EVENTS:
self._received_events.pop()
# type check disabled due to https://github.com/python/mypy/issues/708
+ self._controller._async_controller._handle_event = handle_event_wrapper
self._controller._handle_event = handle_event_wrapper # type: ignore
def get_events(self, *event_types: stem.control.EventType) -> List[stem.response.events.Event]:
1
0

16 Jul '20
commit a7f190e37417f90de2b255a178a78dbe093b8d5a
Author: Illia Volochii <illia.volochii(a)gmail.com>
Date: Sun May 17 16:13:15 2020 +0300
Add the mypy cache folder to .gitignore
---
.gitignore | 1 +
1 file changed, 1 insertion(+)
diff --git a/.gitignore b/.gitignore
index d8c1f13d..67405a22 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,5 +5,6 @@
.editorconfig
.tox
.idea/
+.mypy_cache/
test/data/
docs/_build/
1
0

[stem/master] Get rid of `_ControllerClassMethodMixin` to fix type checks
by atagar@torproject.org 16 Jul '20
by atagar@torproject.org 16 Jul '20
16 Jul '20
commit 1458e0899eea60763416632406b658a6dca69b6d
Author: Illia Volochii <illia.volochii(a)gmail.com>
Date: Sun May 17 16:12:13 2020 +0300
Get rid of `_ControllerClassMethodMixin` to fix type checks
---
stem/control.py | 106 +++++++++++++++++++++++++++++++++++++-------------------
1 file changed, 70 insertions(+), 36 deletions(-)
diff --git a/stem/control.py b/stem/control.py
index e483d4f3..2ffcf173 100644
--- a/stem/control.py
+++ b/stem/control.py
@@ -990,19 +990,21 @@ class BaseController(_BaseControllerSocketMixin):
self._event_notice.clear()
-class _ControllerClassMethodMixin:
+class AsyncController(BaseController):
+ """
+ Connection with Tor's control socket. This is built on top of the
+ BaseController and provides a more user friendly API for library users.
+ """
+
@classmethod
- def from_port(address: str = '127.0.0.1', port: Union[int, str] = 'default') -> 'stem.control._ControllerClassMethodMixin':
+ def from_port(cls: Type, address: str = '127.0.0.1', port: Union[int, str] = 'default') -> 'stem.control.AsyncController':
"""
- Constructs a :class:`~stem.socket.ControlPort` based Controller.
+ Constructs a :class:`~stem.socket.ControlPort` based AsyncController.
If the **port** is **'default'** then this checks on both 9051 (default
for relays) and 9151 (default for the Tor Browser). This default may change
in the future.
- .. versionchanged:: 1.5.0
- Use both port 9051 and 9151 by default.
-
:param address: ip address of the controller
:param port: port number of the controller
@@ -1011,24 +1013,13 @@ class _ControllerClassMethodMixin:
:raises: :class:`stem.SocketError` if we're unable to establish a connection
"""
- import stem.connection
-
- if not stem.util.connection.is_valid_ipv4_address(address):
- raise ValueError('Invalid IP address: %s' % address)
- elif port != 'default' and not stem.util.connection.is_valid_port(port):
- raise ValueError('Invalid port: %s' % port)
-
- if port == 'default':
- control_port = stem.connection._connection_for_default_port(address)
- else:
- control_port = stem.socket.ControlPort(address, int(port))
-
- return cls(control_port)
+ control_socket = _init_control_port(address, port)
+ return cls(control_socket)
@classmethod
- def from_socket_file(cls: Type, path: str = '/var/run/tor/control') -> 'stem.control.Controller':
+ def from_socket_file(cls: Type, path: str = '/var/run/tor/control') -> 'stem.control.AsyncController':
"""
- Constructs a :class:`~stem.socket.ControlSocketFile` based Controller.
+ Constructs a :class:`~stem.socket.ControlSocketFile` based AsyncController.
:param path: path where the control socket is located
@@ -1037,16 +1028,9 @@ class _ControllerClassMethodMixin:
:raises: :class:`stem.SocketError` if we're unable to establish a connection
"""
- control_socket = stem.socket.ControlSocketFile(path)
+ control_socket = _init_control_socket_file(path)
return cls(control_socket)
-
-class AsyncController(_ControllerClassMethodMixin, BaseController):
- """
- Connection with Tor's control socket. This is built on top of the
- BaseController and provides a more user friendly API for library users.
- """
-
def __init__(self, control_socket: stem.socket.ControlSocket, is_authenticated: bool = False) -> None:
self._is_caching_enabled = True
self._request_cache = {} # type: Dict[str, Any]
@@ -3878,18 +3862,48 @@ class AsyncController(_ControllerClassMethodMixin, BaseController):
return (set_events, failed_events)
-class Controller(_ControllerClassMethodMixin, _BaseControllerSocketMixin, stem.util.AsyncClassWrapper):
+class Controller(_BaseControllerSocketMixin, stem.util.AsyncClassWrapper):
@classmethod
def from_port(cls: Type, address: str = '127.0.0.1', port: Union[int, str] = 'default') -> 'stem.control.Controller':
- instance = super().from_port(address, port)
- instance.connect()
- return instance
+ """
+ Constructs a :class:`~stem.socket.ControlPort` based Controller.
+
+ If the **port** is **'default'** then this checks on both 9051 (default
+ for relays) and 9151 (default for the Tor Browser). This default may change
+ in the future.
+
+ .. versionchanged:: 1.5.0
+ Use both port 9051 and 9151 by default.
+
+ :param str address: ip address of the controller
+ :param int port: port number of the controller
+
+ :returns: :class:`~stem.control.Controller` attached to the given port
+
+ :raises: :class:`stem.SocketError` if we're unable to establish a connection
+ """
+
+ control_socket = _init_control_port(address, port)
+ controller = cls(control_socket)
+ controller.connect()
+ return controller
@classmethod
def from_socket_file(cls: Type, path: str = '/var/run/tor/control') -> 'stem.control.Controller':
- instance = super().from_socket_file(path)
- instance.connect()
- return instance
+ """
+ Constructs a :class:`~stem.socket.ControlSocketFile` based Controller.
+
+ :param str path: path where the control socket is located
+
+ :returns: :class:`~stem.control.Controller` attached to the given socket file
+
+ :raises: :class:`stem.SocketError` if we're unable to establish a connection
+ """
+
+ control_socket = _init_control_socket_file(path)
+ controller = cls(control_socket)
+ controller.connect()
+ return controller
def __init__(self, control_socket: 'stem.socket.ControlSocket', is_authenticated: bool = False, started_async_controller_thread: Optional['threading.Thread'] = None) -> None:
def __init__(self, control_socket, is_authenticated = False, started_async_controller_thread = None):
@@ -4231,3 +4245,23 @@ async def _get_with_timeout(event_queue: queue.Queue, timeout: float, start_time
return await asyncio.wait_for(event_queue.get(), timeout=time_left)
except asyncio.TimeoutError:
raise stem.Timeout('Reached our %0.1f second timeout' % timeout)
+
+
+def _init_control_port(address: str, port: Union[int, str]) -> stem.socket.ControlPort:
+ import stem.connection
+
+ if not stem.util.connection.is_valid_ipv4_address(address):
+ raise ValueError('Invalid IP address: %s' % address)
+ elif port != 'default' and not stem.util.connection.is_valid_port(port):
+ raise ValueError('Invalid port: %s' % port)
+
+ if port == 'default':
+ control_port = stem.connection._connection_for_default_port(address)
+ else:
+ control_port = stem.socket.ControlPort(address, int(port))
+
+ return control_port
+
+
+def _init_control_socket_file(path: str) -> stem.socket.ControlSocketFile:
+ return stem.socket.ControlSocketFile(path)
1
0

[stem/master] Make requesting for descriptor content asynchronous
by atagar@torproject.org, 16 Jul '20
commit 841e2105147177f5959987c8bec1179dc94a59b3
Author: Illia Volochii <illia.volochii(a)gmail.com>
Date: Thu May 14 00:08:59 2020 +0300
Make requesting for descriptor content asynchronous
---
stem/client/__init__.py | 91 +++++++++++++-------------
stem/descriptor/remote.py | 138 ++++++++++++++++++++++++++++------------
stem/util/test_tools.py | 4 +-
test/integ/client/connection.py | 38 +++++++----
test/unit/descriptor/remote.py | 24 +++++--
5 files changed, 187 insertions(+), 108 deletions(-)
diff --git a/stem/client/__init__.py b/stem/client/__init__.py
index 639f118f..941f0ee7 100644
--- a/stem/client/__init__.py
+++ b/stem/client/__init__.py
@@ -26,7 +26,6 @@ a wrapper for :class:`~stem.socket.RelaySocket`, much the same way as
"""
import hashlib
-import threading
import stem
import stem.client.cell
@@ -71,11 +70,10 @@ class Relay(object):
self.link_protocol = LinkProtocol(link_protocol)
self._orport = orport
self._orport_buffer = b'' # unread bytes
- self._orport_lock = threading.RLock()
- self._circuits = {} # type: Dict[int, stem.client.Circuit]
+ self._circuits = {}
@staticmethod
- def connect(address: str, port: int, link_protocols: Sequence['stem.client.datatype.LinkProtocol'] = DEFAULT_LINK_PROTOCOLS) -> 'stem.client.Relay': # type: ignore
+ async def connect(address: str, port: int, link_protocols: Sequence['stem.client.datatype.LinkProtocol'] = DEFAULT_LINK_PROTOCOLS) -> 'stem.client.Relay': # type: ignore
"""
Establishes a connection with the given ORPort.
@@ -97,8 +95,9 @@ class Relay(object):
try:
conn = stem.socket.RelaySocket(address, port)
+ await conn.connect()
except stem.SocketError as exc:
- if 'Connection refused' in str(exc):
+ if 'Connect call failed' in str(exc):
raise stem.SocketError("Failed to connect to %s:%i. Maybe it isn't an ORPort?" % (address, port))
# If not an ORPort (for instance, mistakenly connecting to a ControlPort
@@ -122,21 +121,21 @@ class Relay(object):
# first VERSIONS cell, always have CIRCID_LEN == 2 for backward
# compatibility.
- conn.send(stem.client.cell.VersionsCell(link_protocols).pack(2)) # type: ignore
- response = conn.recv()
+ await conn.send(stem.client.cell.VersionsCell(link_protocols).pack(2)) # type: ignore
+ response = await conn.recv()
# Link negotiation ends right away if we lack a common protocol
# version. (#25139)
if not response:
- conn.close()
+ await conn.close()
raise stem.SocketError('Unable to establish a common link protocol with %s:%i' % (address, port))
versions_reply = stem.client.cell.Cell.pop(response, 2)[0] # type: stem.client.cell.VersionsCell # type: ignore
common_protocols = set(link_protocols).intersection(versions_reply.versions)
if not common_protocols:
- conn.close()
+ await conn.close()
raise stem.SocketError('Unable to find a common link protocol. We support %s but %s:%i supports %s.' % (', '.join(map(str, link_protocols)), address, port, ', '.join(map(str, versions_reply.versions))))
# Establishing connections requires sending a NETINFO, but including our
@@ -144,14 +143,14 @@ class Relay(object):
# where it would help.
link_protocol = max(common_protocols)
- conn.send(stem.client.cell.NetinfoCell(relay_addr, []).pack(link_protocol))
+ await conn.send(stem.client.cell.NetinfoCell(relay_addr, []).pack(link_protocol))
return Relay(conn, link_protocol)
- def _recv_bytes(self) -> bytes:
- return self._recv(True) # type: ignore
+ async def _recv_bytes(self) -> bytes:
+ return await self._recv(True) # type: ignore
- def _recv(self, raw: bool = False) -> 'stem.client.cell.Cell':
+ async def _recv(self, raw: bool = False) -> 'stem.client.cell.Cell':
"""
Reads the next cell from our ORPort. If none is present this blocks
until one is available.
@@ -161,13 +160,13 @@ class Relay(object):
:returns: next :class:`~stem.client.cell.Cell`
"""
- with self._orport_lock:
+ async with self._orport_lock:
# cells begin with [circ_id][cell_type][...]
circ_id_size = self.link_protocol.circ_id_size.size
while len(self._orport_buffer) < (circ_id_size + CELL_TYPE_SIZE.size):
- self._orport_buffer += self._orport.recv() # read until we know the cell type
+ self._orport_buffer += await self._orport.recv() # read until we know the cell type
cell_type = Cell.by_value(CELL_TYPE_SIZE.pop(self._orport_buffer[circ_id_size:])[0])
@@ -177,13 +176,13 @@ class Relay(object):
# variable length, our next field is the payload size
while len(self._orport_buffer) < (circ_id_size + CELL_TYPE_SIZE.size + FIXED_PAYLOAD_LEN):
- self._orport_buffer += self._orport.recv() # read until we know the cell size
+ self._orport_buffer += await self._orport.recv() # read until we know the cell size
payload_len = PAYLOAD_LEN_SIZE.pop(self._orport_buffer[circ_id_size + CELL_TYPE_SIZE.size:])[0]
cell_size = circ_id_size + CELL_TYPE_SIZE.size + payload_len
while len(self._orport_buffer) < cell_size:
- self._orport_buffer += self._orport.recv() # read until we have the full cell
+ self._orport_buffer += await self._orport.recv() # read until we have the full cell
if raw:
content, self._orport_buffer = split(self._orport_buffer, cell_size)
@@ -192,7 +191,7 @@ class Relay(object):
cell, self._orport_buffer = Cell.pop(self._orport_buffer, self.link_protocol)
return cell
- def _msg(self, cell: 'stem.client.cell.Cell') -> Iterator['stem.client.cell.Cell']:
+ async def _msg(self, cell: 'stem.client.cell.Cell') -> Iterator['stem.client.cell.Cell']:
"""
Sends a cell on the ORPort and provides the response we receive in reply.
@@ -219,9 +218,9 @@ class Relay(object):
# TODO: why is this an iterator?
- self._orport.recv(timeout = 0) # discard unread data
- self._orport.send(cell.pack(self.link_protocol))
- response = self._orport.recv(timeout = 1)
+ await self._orport.recv(timeout = 0) # discard unread data
+ await self._orport.send(cell.pack(self.link_protocol))
+ response = await self._orport.recv(timeout = 1)
yield stem.client.cell.Cell.pop(response, self.link_protocol)[0]
def is_alive(self) -> bool:
@@ -246,27 +245,27 @@ class Relay(object):
return self._orport.connection_time()
- def close(self) -> None:
+ async def close(self) -> None:
"""
Closes our socket connection. This is a pass-through for our socket's
:func:`~stem.socket.BaseSocket.close` method.
"""
- with self._orport_lock:
- return self._orport.close()
+ async with self._orport_lock:
+ return await self._orport.close()
- def create_circuit(self) -> 'stem.client.Circuit':
+ async def create_circuit(self) -> 'stem.client.Circuit':
"""
Establishes a new circuit.
"""
- with self._orport_lock:
+ async with self._orport_lock:
circ_id = max(self._circuits) + 1 if self._circuits else self.link_protocol.first_circ_id
create_fast_cell = stem.client.cell.CreateFastCell(circ_id)
created_fast_cell = None
- for cell in self._msg(create_fast_cell):
+ async for cell in self._msg(create_fast_cell):
if isinstance(cell, stem.client.cell.CreatedFastCell):
created_fast_cell = cell
break
@@ -284,16 +283,16 @@ class Relay(object):
return circ
- def __iter__(self) -> Iterator['stem.client.Circuit']:
- with self._orport_lock:
+ async def __aiter__(self) -> Iterator['stem.client.Circuit']:
+ async with self._orport_lock:
for circ in self._circuits.values():
yield circ
- def __enter__(self) -> 'stem.client.Relay':
+ async def __aenter__(self) -> 'stem.client.Relay':
return self
- def __exit__(self, exit_type: Optional[Type[BaseException]], value: Optional[BaseException], traceback: Optional[TracebackType]) -> None:
- self.close()
+ async def __aexit__(self, exit_type: Optional[Type[BaseException]], value: Optional[BaseException], traceback: Optional[TracebackType]) -> None:
+ await self.close()
class Circuit(object):
@@ -327,7 +326,7 @@ class Circuit(object):
self.forward_key = Cipher(algorithms.AES(kdf.forward_key), ctr, default_backend()).encryptor()
self.backward_key = Cipher(algorithms.AES(kdf.backward_key), ctr, default_backend()).decryptor()
- def directory(self, request: str, stream_id: int = 0) -> bytes:
+ async def directory(self, request: str, stream_id: int = 0) -> bytes:
"""
Request descriptors from the relay.
@@ -337,9 +336,9 @@ class Circuit(object):
:returns: **str** with the requested descriptor data
"""
- with self.relay._orport_lock:
- self._send(RelayCommand.BEGIN_DIR, stream_id = stream_id)
- self._send(RelayCommand.DATA, request, stream_id = stream_id)
+ async with self.relay._orport_lock:
+ await self._send(RelayCommand.BEGIN_DIR, stream_id = stream_id)
+ await self._send(RelayCommand.DATA, request, stream_id = stream_id)
response = [] # type: List[stem.client.cell.RelayCell]
@@ -347,7 +346,7 @@ class Circuit(object):
# Decrypt relay cells received in response. Our digest/key only
# updates when handled successfully.
- encrypted_cell = self.relay._recv_bytes()
+ encrypted_cell = await self.relay._recv_bytes()
decrypted_cell, backward_key, backward_digest = stem.client.cell.RelayCell.decrypt(self.relay.link_protocol, encrypted_cell, self.backward_key, self.backward_digest)
@@ -362,7 +361,7 @@ class Circuit(object):
else:
response.append(decrypted_cell)
- def _send(self, command: 'stem.client.datatype.RelayCommand', data: Union[bytes, str] = b'', stream_id: int = 0) -> None:
+ async def _send(self, command: 'stem.client.datatype.RelayCommand', data: Union[bytes, str] = b'', stream_id: int = 0) -> None:
"""
Sends a message over the circuit.
@@ -371,24 +370,24 @@ class Circuit(object):
:param stream_id: specific stream this concerns
"""
- with self.relay._orport_lock:
+ async with self.relay._orport_lock:
# Encrypt and send the cell. Our digest/key only updates if the cell is
# successfully sent.
cell = stem.client.cell.RelayCell(self.id, command, data, stream_id = stream_id)
payload, forward_key, forward_digest = cell.encrypt(self.relay.link_protocol, self.forward_key, self.forward_digest)
- self.relay._orport.send(payload)
+ await self.relay._orport.send(payload)
self.forward_digest = forward_digest
self.forward_key = forward_key
- def close(self) -> None:
- with self.relay._orport_lock:
- self.relay._orport.send(stem.client.cell.DestroyCell(self.id).pack(self.relay.link_protocol))
+ async def close(self)- > None:
+ async with self.relay._orport_lock:
+ await self.relay._orport.send(stem.client.cell.DestroyCell(self.id).pack(self.relay.link_protocol))
del self.relay._circuits[self.id]
- def __enter__(self) -> 'stem.client.Circuit':
+ async def __aenter__(self) -> 'stem.client.Circuit':
return self
- def __exit__(self, exit_type: Optional[Type[BaseException]], value: Optional[BaseException], traceback: Optional[TracebackType]) -> None:
- self.close()
+ async def __aexit__(self, exit_type: Optional[Type[BaseException]], value: Optional[BaseException], traceback: Optional[TracebackType]) -> None:
+ await self.close()
diff --git a/stem/descriptor/remote.py b/stem/descriptor/remote.py
index e90c4442..eca846ee 100644
--- a/stem/descriptor/remote.py
+++ b/stem/descriptor/remote.py
@@ -83,6 +83,8 @@ content. For example...
hashes.
"""
+import asyncio
+import functools
import io
import random
import socket
@@ -93,6 +95,7 @@ import urllib.request
import stem
import stem.client
+import stem.control
import stem.descriptor
import stem.descriptor.networkstatus
import stem.directory
@@ -227,7 +230,7 @@ def get_detached_signatures(**query_args: Any) -> 'stem.descriptor.remote.Query'
return get_instance().get_detached_signatures(**query_args)
-class Query(object):
+class AsyncQuery(object):
"""
Asynchronous request for descriptor content from a directory authority or
mirror. These can either be made through the
@@ -427,32 +430,27 @@ class Query(object):
self.reply_headers = None # type: Optional[Dict[str, str]]
self.kwargs = kwargs
- self._downloader_thread = None # type: Optional[threading.Thread]
- self._downloader_thread_lock = threading.RLock()
+ self._downloader_task = None
+ self._downloader_lock = threading.RLock()
+
+ self._asyncio_loop = asyncio.get_event_loop()
if start:
self.start()
if block:
- self.run(True)
+ self._asyncio_loop.create_task(self.run(True))
def start(self) -> None:
"""
Starts downloading the scriptors if we haven't started already.
"""
- with self._downloader_thread_lock:
- if self._downloader_thread is None:
- self._downloader_thread = threading.Thread(
- name = 'Descriptor query',
- target = self._download_descriptors,
- args = (self.retries, self.timeout)
- )
-
- self._downloader_thread.setDaemon(True)
- self._downloader_thread.start()
+ with self._downloader_lock:
+ if self._downloader_task is None:
+ self._downloader_task = self._asyncio_loop.create_task(self._download_descriptors(self.retries, self.timeout))
- def run(self, suppress: bool = False) -> List['stem.descriptor.Descriptor']:
+ async def run(self, suppress: bool = False) -> List['stem.descriptor.Descriptor']:
"""
Blocks until our request is complete then provides the descriptors. If we
haven't yet started our request then this does so.
@@ -470,12 +468,12 @@ class Query(object):
* :class:`~stem.DownloadFailed` if our request fails
"""
- return list(self._run(suppress))
+ return [desc async for desc in self._run(suppress)]
- def _run(self, suppress: bool) -> Iterator[stem.descriptor.Descriptor]:
- with self._downloader_thread_lock:
+ async def _run(self, suppress: bool) -> Iterator[stem.descriptor.Descriptor]:
+ with self._downloader_lock:
self.start()
- self._downloader_thread.join()
+ await self._downloader_task
if self.error:
if suppress:
@@ -508,8 +506,8 @@ class Query(object):
raise self.error
- def __iter__(self) -> Iterator[stem.descriptor.Descriptor]:
- for desc in self._run(True):
+ async def __aiter__(self) -> Iterator[stem.descriptor.Descriptor]:
+ async for desc in self._run(True):
yield desc
def _pick_endpoint(self, use_authority: bool = False) -> stem.Endpoint:
@@ -530,18 +528,18 @@ class Query(object):
else:
return random.choice(self.endpoints)
- def _download_descriptors(self, retries: int, timeout: Optional[float]) -> None:
+ async def _download_descriptors(self, retries: int, timeout: Optional[float]) -> None:
try:
self.start_time = time.time()
endpoint = self._pick_endpoint(use_authority = retries == 0 and self.fall_back_to_authority)
if isinstance(endpoint, stem.ORPort):
downloaded_from = 'ORPort %s:%s (resource %s)' % (endpoint.address, endpoint.port, self.resource)
- self.content, self.reply_headers = _download_from_orport(endpoint, self.compression, self.resource)
+ self.content, self.reply_headers = await _download_from_orport(endpoint, self.compression, self.resource)
elif isinstance(endpoint, stem.DirPort):
self.download_url = 'http://%s:%i/%s' % (endpoint.address, endpoint.port, self.resource.lstrip('/'))
downloaded_from = self.download_url
- self.content, self.reply_headers = _download_from_dirport(self.download_url, self.compression, timeout)
+ self.content, self.reply_headers = await _download_from_dirport(self.download_url, self.compression, timeout)
else:
raise ValueError("BUG: endpoints can only be ORPorts or DirPorts, '%s' was a %s" % (endpoint, type(endpoint).__name__))
@@ -555,7 +553,7 @@ class Query(object):
if retries > 0 and (timeout is None or timeout > 0):
log.debug("Unable to download descriptors from '%s' (%i retries remaining): %s" % (self.download_url, retries, exc))
- return self._download_descriptors(retries - 1, timeout)
+ return await self._download_descriptors(retries - 1, timeout)
else:
log.debug("Unable to download descriptors from '%s': %s" % (self.download_url, exc))
self.error = exc
@@ -563,6 +561,64 @@ class Query(object):
self.is_done = True
+class Query(stem.util.AsyncClassWrapper):
+ def __init__(self, resource, descriptor_type = None, endpoints = None, compression = (Compression.GZIP,), retries = 2, fall_back_to_authority = False, timeout = None, start = True, block = False, validate = False, document_handler = stem.descriptor.DocumentHandler.ENTRIES, **kwargs):
+ self._thread_for_wrapped_class = stem.util.ThreadForWrappedAsyncClass()
+ self._thread_for_wrapped_class.start()
+ self._wrapped_instance = self._init_async_class(
+ AsyncQuery,
+ resource,
+ descriptor_type,
+ endpoints,
+ compression,
+ retries,
+ fall_back_to_authority,
+ timeout,
+ start,
+ block,
+ validate,
+ document_handler,
+ **kwargs,
+ )
+
+ def start(self):
+ return self._call_async_method_soon('start')
+
+ def run(self, suppress = False):
+ return self._execute_async_method('run', suppress)
+
+ def __iter__(self):
+ for desc in self._execute_async_generator_method('__aiter__'):
+ yield desc
+
+ # Add public attributes of `AsyncQuery` as properties.
+ for attr in (
+ 'descriptor_type',
+ 'endpoints',
+ 'resource',
+ 'compression',
+ 'retries',
+ 'fall_back_to_authority',
+ 'content',
+ 'error',
+ 'is_done',
+ 'download_url',
+ 'start_time',
+ 'timeout',
+ 'runtime',
+ 'validate',
+ 'document_handler',
+ 'reply_headers',
+ 'kwargs',
+ ):
+ locals()[attr] = property(
+ functools.partial(
+ lambda self, attr_name: getattr(self._wrapped_instance, attr_name),
+ attr_name=attr,
+ ),
+ )
+
+
class DescriptorDownloader(object):
"""
Configurable class that issues :class:`~stem.descriptor.remote.Query`
@@ -925,7 +981,7 @@ class DescriptorDownloader(object):
return Query(resource, **args)
-def _download_from_orport(endpoint: stem.ORPort, compression: Sequence[stem.descriptor._Compression], resource: str) -> Tuple[bytes, Dict[str, str]]:
+async def _download_from_orport(endpoint: stem.ORPort, compression: Sequence[stem.descriptor._Compression], resource: str) -> Tuple[bytes, Dict[str, str]]:
"""
Downloads descriptors from the given orport. Payload is just like an http
response (headers and all)...
@@ -956,15 +1012,15 @@ def _download_from_orport(endpoint: stem.ORPort, compression: Sequence[stem.desc
link_protocols = endpoint.link_protocols if endpoint.link_protocols else [3]
- with stem.client.Relay.connect(endpoint.address, endpoint.port, link_protocols) as relay:
- with relay.create_circuit() as circ:
+ async with await stem.client.Relay.connect(endpoint.address, endpoint.port, link_protocols) as relay:
+ async with await relay.create_circuit() as circ:
request = '\r\n'.join((
'GET %s HTTP/1.0' % resource,
'Accept-Encoding: %s' % ', '.join(map(lambda c: c.encoding, compression)),
'User-Agent: %s' % stem.USER_AGENT,
)) + '\r\n\r\n'
- response = circ.directory(request, stream_id = 1)
+ response = await circ.directory(request, stream_id = 1)
first_line, data = response.split(b'\r\n', 1)
header_data, body_data = data.split(b'\r\n\r\n', 1)
@@ -983,7 +1039,7 @@ def _download_from_orport(endpoint: stem.ORPort, compression: Sequence[stem.desc
return _decompress(body_data, headers.get('Content-Encoding')), headers
-def _download_from_dirport(url: str, compression: Sequence[stem.descriptor._Compression], timeout: Optional[float]) -> Tuple[bytes, Dict[str, str]]:
+async def _download_from_dirport(url: str, compression: Sequence[stem.descriptor._Compression], timeout: Optional[float]) -> Tuple[bytes, Dict[str, str]]:
"""
Downloads descriptors from the given url.
@@ -998,17 +1054,19 @@ def _download_from_dirport(url: str, compression: Sequence[stem.descriptor._Comp
* :class:`~stem.DownloadFailed` if our request fails
"""
+ # TODO: use an asyncronous solution for the HTTP request.
+ request = urllib.request.Request(
+ url,
+ headers = {
+ 'Accept-Encoding': ', '.join(map(lambda c: c.encoding, compression)),
+ 'User-Agent': stem.USER_AGENT,
+ }
+ )
+ get_response = functools.partial(urllib.request.urlopen, request, timeout = timeout)
+
+ loop = asyncio.get_event_loop()
try:
- response = urllib.request.urlopen(
- urllib.request.Request(
- url,
- headers = {
- 'Accept-Encoding': ', '.join(map(lambda c: c.encoding, compression)),
- 'User-Agent': stem.USER_AGENT,
- }
- ),
- timeout = timeout,
- )
+ response = await loop.run_in_executor(None, get_response)
except socket.timeout as exc:
raise stem.DownloadTimeout(url, exc, sys.exc_info()[2], timeout)
except:
diff --git a/stem/util/test_tools.py b/stem/util/test_tools.py
index f3c736a1..455f3da3 100644
--- a/stem/util/test_tools.py
+++ b/stem/util/test_tools.py
@@ -251,7 +251,7 @@ class TimedTestRunner(unittest.TextTestRunner):
TEST_RUNTIMES[self.id()] = time.time() - start_time
return result
- def assertRaisesWith(self, exc_type: Type[Exception], exc_msg: str, func: Callable, *args: Any, **kwargs: Any) -> None:
+ def assertRaisesWith(self, exc_type: Type[Exception], exc_msg: str, *args: Any, **kwargs: Any) -> None:
"""
Asserts the given invokation raises the expected excepiton. This is
similar to unittest's assertRaises and assertRaisesRegexp, but checks
@@ -262,7 +262,7 @@ class TimedTestRunner(unittest.TextTestRunner):
vended API then please let us know.
"""
- return self.assertRaisesRegexp(exc_type, '^%s$' % re.escape(exc_msg), func, *args, **kwargs)
+ return self.assertRaisesRegexp(exc_type, '^%s$' % re.escape(exc_msg), *args, **kwargs)
def id(self) -> str:
return '%s.%s.%s' % (original_type.__module__, original_type.__name__, self._testMethodName)
diff --git a/test/integ/client/connection.py b/test/integ/client/connection.py
index 2294a07d..316d54ba 100644
--- a/test/integ/client/connection.py
+++ b/test/integ/client/connection.py
@@ -9,46 +9,57 @@ import stem
import test.runner
from stem.client import Relay
+from stem.util.test_tools import async_test
class TestConnection(unittest.TestCase):
- def test_invalid_arguments(self):
+ @async_test
+ async def test_invalid_arguments(self):
"""
Provide invalid arguments to Relay.connect().
"""
- self.assertRaisesWith(ValueError, "'nope' isn't an IPv4 or IPv6 address", Relay.connect, 'nope', 80)
- self.assertRaisesWith(ValueError, "'-54' isn't a valid port", Relay.connect, '127.0.0.1', -54)
- self.assertRaisesWith(ValueError, "Connection can't be established without a link protocol.", Relay.connect, '127.0.0.1', 54, [])
+ with self.assertRaisesWith(ValueError, "'nope' isn't an IPv4 or IPv6 address"):
+ await Relay.connect('nope', 80)
+ with self.assertRaisesWith(ValueError, "'-54' isn't a valid port"):
+ await Relay.connect('127.0.0.1', -54)
+ with self.assertRaisesWith(ValueError, "Connection can't be established without a link protocol."):
+ await Relay.connect('127.0.0.1', 54, [])
- def test_not_orport(self):
+ @async_test
+ async def test_not_orport(self):
"""
Attempt to connect to an ORPort that doesn't exist.
"""
- self.assertRaisesWith(stem.SocketError, "Failed to connect to 127.0.0.1:1587. Maybe it isn't an ORPort?", Relay.connect, '127.0.0.1', 1587)
+ with self.assertRaisesWith(stem.SocketError, "Failed to connect to 127.0.0.1:1587. Maybe it isn't an ORPort?"):
+ await Relay.connect('127.0.0.1', 1587)
# connect to our ControlPort like it's an ORPort
if test.runner.Torrc.PORT in test.runner.get_runner().get_options():
- self.assertRaisesWith(stem.SocketError, "Failed to SSL authenticate to 127.0.0.1:1111. Maybe it isn't an ORPort?", Relay.connect, '127.0.0.1', test.runner.CONTROL_PORT)
+ with self.assertRaisesWith(stem.SocketError, "Failed to SSL authenticate to 127.0.0.1:1111. Maybe it isn't an ORPort?"):
+ await Relay.connect('127.0.0.1', test.runner.CONTROL_PORT)
- def test_no_common_link_protocol(self):
+ @async_test
+ async def test_no_common_link_protocol(self):
"""
Connection without a commonly accepted link protocol version.
"""
for link_protocol in (1, 2, 6, 20):
- self.assertRaisesWith(stem.SocketError, 'Unable to establish a common link protocol with 127.0.0.1:1113', Relay.connect, '127.0.0.1', test.runner.ORPORT, [link_protocol])
+ with self.assertRaisesWith(stem.SocketError, 'Unable to establish a common link protocol with 127.0.0.1:1113'):
+ await Relay.connect('127.0.0.1', test.runner.ORPORT, [link_protocol])
- def test_connection_time(self):
+ @async_test
+ async def test_connection_time(self):
"""
Checks duration we've been connected.
"""
before = time.time()
- with Relay.connect('127.0.0.1', test.runner.ORPORT) as conn:
+ async with await Relay.connect('127.0.0.1', test.runner.ORPORT) as conn:
connection_time = conn.connection_time()
self.assertTrue(time.time() >= connection_time >= before)
time.sleep(0.02)
@@ -57,10 +68,11 @@ class TestConnection(unittest.TestCase):
self.assertFalse(conn.is_alive())
self.assertTrue(conn.connection_time() >= connection_time + 0.02)
- def test_established(self):
+ @async_test
+ async def test_established(self):
"""
Successfully establish ORPort connection.
"""
- conn = Relay.connect('127.0.0.1', test.runner.ORPORT)
+ conn = await Relay.connect('127.0.0.1', test.runner.ORPORT)
self.assertTrue(int(conn.link_protocol) in (4, 5))
diff --git a/test/unit/descriptor/remote.py b/test/unit/descriptor/remote.py
index e57da92b..33ee57fb 100644
--- a/test/unit/descriptor/remote.py
+++ b/test/unit/descriptor/remote.py
@@ -13,9 +13,10 @@ import stem.descriptor.remote
import stem.util.str_tools
import test.require
-from unittest.mock import patch, Mock, MagicMock
+from unittest.mock import patch, Mock
from stem.descriptor.remote import Compression
+from stem.util.test_tools import coro_func_returning_value
from test.unit.descriptor import read_resource
TEST_RESOURCE = '/tor/server/fp/9695DFC35FFEB861329B9F1AB04C46397020CE31'
@@ -78,11 +79,20 @@ def _orport_mock(data, encoding = 'identity', response_code_header = None):
cell.data = hunk
cells.append(cell)
- connect_mock = MagicMock()
- relay_mock = connect_mock().__enter__()
- circ_mock = relay_mock.create_circuit().__enter__()
- circ_mock.directory.return_value = data
- return connect_mock
+ class AsyncMock(Mock):
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
+ return
+
+ circ_mock = AsyncMock()
+ circ_mock.directory.side_effect = coro_func_returning_value(data)
+
+ relay_mock = AsyncMock()
+ relay_mock.create_circuit.side_effect = coro_func_returning_value(circ_mock)
+
+ return coro_func_returning_value(relay_mock)
def _dirport_mock(data, encoding = 'identity'):
@@ -294,7 +304,7 @@ class TestDescriptorDownloader(unittest.TestCase):
skip_crypto_validation = not test.require.CRYPTOGRAPHY_AVAILABLE,
)
- self.assertEqual(stem.DirPort('128.31.0.39', 9131), query._pick_endpoint())
+ self.assertEqual(stem.DirPort('128.31.0.39', 9131), query._wrapped_instance._pick_endpoint())
descriptors = list(query)
self.assertEqual(1, len(descriptors))
1
0