tor-commits
Threads by month
- ----- 2025 -----
- May
- April
- March
- February
- January
- ----- 2024 -----
- December
- November
- October
- September
- August
- July
- June
- May
- April
- March
- February
- January
- ----- 2023 -----
- December
- November
- October
- September
- August
- July
- June
- May
- April
- March
- February
- January
- ----- 2022 -----
- December
- November
- October
- September
- August
- July
- June
- May
- April
- March
- February
- January
- ----- 2021 -----
- December
- November
- October
- September
- August
- July
- June
- May
- April
- March
- February
- January
- ----- 2020 -----
- December
- November
- October
- September
- August
- July
- June
- May
- April
- March
- February
- January
- ----- 2019 -----
- December
- November
- October
- September
- August
- July
- June
- May
- April
- March
- February
- January
- ----- 2018 -----
- December
- November
- October
- September
- August
- July
- June
- May
- April
- March
- February
- January
- ----- 2017 -----
- December
- November
- October
- September
- August
- July
- June
- May
- April
- March
- February
- January
- ----- 2016 -----
- December
- November
- October
- September
- August
- July
- June
- May
- April
- March
- February
- January
- ----- 2015 -----
- December
- November
- October
- September
- August
- July
- June
- May
- April
- March
- February
- January
- ----- 2014 -----
- December
- November
- October
- September
- August
- July
- June
- May
- April
- March
- February
- January
- ----- 2013 -----
- December
- November
- October
- September
- August
- July
- June
- May
- April
- March
- February
- January
- ----- 2012 -----
- December
- November
- October
- September
- August
- July
- June
- May
- April
- March
- February
- January
- ----- 2011 -----
- December
- November
- October
- September
- August
- July
- June
- May
- April
- March
- February
July 2020
- 17 participants
- 2100 discussions
commit 9f71ce9b21d8c710025a440e69920006e37eeb88
Author: Damian Johnson <atagar(a)torproject.org>
Date: Tue Jul 7 18:38:06 2020 -0700
Make Synchronous class resumable
Our Controller needs to start and stop along with its connect/close methods, so
we need this class to be resumable.
---
stem/util/__init__.py | 50 +++++++++++++++++++++++++++---------------
test/settings.cfg | 2 ++
test/unit/descriptor/remote.py | 20 ++++++++---------
test/unit/util/synchronous.py | 29 +++++++++++++++++-------
4 files changed, 65 insertions(+), 36 deletions(-)
diff --git a/stem/util/__init__.py b/stem/util/__init__.py
index cddce755..54f90376 100644
--- a/stem/util/__init__.py
+++ b/stem/util/__init__.py
@@ -12,7 +12,7 @@ import inspect
import threading
from concurrent.futures import Future
-from typing import Any, AsyncIterator, Callable, Iterator, Type, Union
+from typing import Any, AsyncIterator, Callable, Iterator, Optional, Type, Union
__all__ = [
'conf',
@@ -162,12 +162,12 @@ class Synchronous(object):
def sync_demo():
instance = Example()
print('%s from a synchronous context' % instance.hello())
- instance.close()
+ instance.stop()
async def async_demo():
instance = Example()
print('%s from an asynchronous context' % await instance.hello())
- instance.close()
+ instance.stop()
sync_demo()
asyncio.run(async_demo())
@@ -194,35 +194,34 @@ class Synchronous(object):
# asyncio.get_running_loop(), and construct objects that
# require it (like asyncio.Queue and asyncio.Lock).
- Users are responsible for calling :func:`~stem.util.Synchronous.close` when
+ Users are responsible for calling :func:`~stem.util.Synchronous.stop` when
finished to clean up underlying resources.
"""
def __init__(self) -> None:
+ self._loop_thread = None # type: Optional[threading.Thread]
+ self._loop_thread_lock = threading.RLock()
+
if Synchronous.is_asyncio_context():
self._loop = asyncio.get_running_loop()
- self._loop_thread = None
self.__ainit__()
else:
self._loop = asyncio.new_event_loop()
- self._loop_thread = threading.Thread(
- name = '%s asyncio' % type(self).__name__,
- target = self._loop.run_forever,
- daemon = True,
- )
- self._loop_thread.start()
+ Synchronous.start(self)
# call any coroutines through this loop
def call_async(func: Callable, *args: Any, **kwargs: Any) -> Any:
if Synchronous.is_asyncio_context():
return func(*args, **kwargs)
- elif not self._loop_thread.is_alive():
- raise RuntimeError('%s has been closed' % type(self).__name__)
- return asyncio.run_coroutine_threadsafe(func(*args, **kwargs), self._loop).result()
+ with self._loop_thread_lock:
+ if not self._loop_thread.is_alive():
+ raise RuntimeError('%s has been stopped' % type(self).__name__)
+
+ return asyncio.run_coroutine_threadsafe(func(*args, **kwargs), self._loop).result()
for method_name, func in inspect.getmembers(self, predicate = inspect.ismethod):
if inspect.iscoroutinefunction(func):
@@ -273,16 +272,31 @@ class Synchronous(object):
pass
- def close(self) -> None:
+ def start(self) -> None:
+ """
+ Initiate resources to make this object callable from synchronous contexts.
+ """
+
+ with self._loop_thread_lock:
+ self._loop_thread = threading.Thread(
+ name = '%s asyncio' % type(self).__name__,
+ target = self._loop.run_forever,
+ daemon = True,
+ )
+
+ self._loop_thread.start()
+
+ def stop(self) -> None:
"""
Terminate resources that permits this from being callable from synchronous
contexts. Once called any further synchronous invocations will fail with a
**RuntimeError**.
"""
- if self._loop_thread and self._loop_thread.is_alive():
- self._loop.call_soon_threadsafe(self._loop.stop)
- self._loop_thread.join()
+ with self._loop_thread_lock:
+ if self._loop_thread and self._loop_thread.is_alive():
+ self._loop.call_soon_threadsafe(self._loop.stop)
+ self._loop_thread.join()
@staticmethod
def is_asyncio_context() -> bool:
diff --git a/test/settings.cfg b/test/settings.cfg
index fcef5ec1..70bdd069 100644
--- a/test/settings.cfg
+++ b/test/settings.cfg
@@ -229,6 +229,8 @@ mypy.ignore * => "_IntegerEnum" has no attribute *
mypy.ignore * => See https://mypy.readthedocs.io/en/latest/common_issues.html*
mypy.ignore * => *is not valid as a type*
+mypy.ignore stem/descriptor/remote.py => Return type "Coroutine[Any, Any, None]" of "start" *
+
# Metaprogramming prevents mypy from determining descriptor attributes.
mypy.ignore * => "Descriptor" has no attribute "*
diff --git a/test/unit/descriptor/remote.py b/test/unit/descriptor/remote.py
index 1fd2aaf9..bb6f554c 100644
--- a/test/unit/descriptor/remote.py
+++ b/test/unit/descriptor/remote.py
@@ -100,7 +100,7 @@ class TestDescriptorDownloader(unittest.TestCase):
self.assertEqual('9695DFC35FFEB861329B9F1AB04C46397020CE31', desc.fingerprint)
self.assertEqual(TEST_DESCRIPTOR, desc.get_bytes())
- reply.close()
+ reply.stop()
def test_response_header_code(self):
"""
@@ -150,13 +150,13 @@ class TestDescriptorDownloader(unittest.TestCase):
descriptors = list(query)
self.assertEqual(1, len(descriptors))
self.assertEqual('moria1', descriptors[0].nickname)
- query.close()
+ query.stop()
def test_gzip_url_override(self):
query = stem.descriptor.remote.Query(TEST_RESOURCE + '.z', compression = Compression.PLAINTEXT, start = False)
self.assertEqual([stem.descriptor.Compression.GZIP], query.compression)
self.assertEqual(TEST_RESOURCE, query.resource)
- query.close()
+ query.stop()
@mock_download(read_resource('compressed_identity'), encoding = 'identity')
def test_compression_plaintext(self):
@@ -172,7 +172,7 @@ class TestDescriptorDownloader(unittest.TestCase):
)
descriptors = list(query)
- query.close()
+ query.stop()
self.assertEqual(1, len(descriptors))
self.assertEqual('moria1', descriptors[0].nickname)
@@ -191,7 +191,7 @@ class TestDescriptorDownloader(unittest.TestCase):
)
descriptors = list(query)
- query.close()
+ query.stop()
self.assertEqual(1, len(descriptors))
self.assertEqual('moria1', descriptors[0].nickname)
@@ -212,7 +212,7 @@ class TestDescriptorDownloader(unittest.TestCase):
)
descriptors = list(query)
- query.close()
+ query.stop()
self.assertEqual(1, len(descriptors))
self.assertEqual('moria1', descriptors[0].nickname)
@@ -233,7 +233,7 @@ class TestDescriptorDownloader(unittest.TestCase):
)
descriptors = list(query)
- query.close()
+ query.stop()
self.assertEqual(1, len(descriptors))
self.assertEqual('moria1', descriptors[0].nickname)
@@ -258,7 +258,7 @@ class TestDescriptorDownloader(unittest.TestCase):
queries.append(downloader.get_detached_signatures())
for query in queries:
- query.close()
+ query.stop()
@mock_download(b'some malformed stuff')
def test_malformed_content(self):
@@ -285,7 +285,7 @@ class TestDescriptorDownloader(unittest.TestCase):
self.assertRaises(ValueError, query.run)
- query.close()
+ query.stop()
def test_query_with_invalid_endpoints(self):
invalid_endpoints = {
@@ -316,4 +316,4 @@ class TestDescriptorDownloader(unittest.TestCase):
self.assertEqual(1, len(list(query)))
self.assertEqual(1, len(list(query)))
- query.close()
+ query.stop()
diff --git a/test/unit/util/synchronous.py b/test/unit/util/synchronous.py
index 22271ffd..dd27c3c6 100644
--- a/test/unit/util/synchronous.py
+++ b/test/unit/util/synchronous.py
@@ -31,12 +31,12 @@ class TestSynchronous(unittest.TestCase):
def sync_demo():
instance = Example()
print('%s from a synchronous context' % instance.hello())
- instance.close()
+ instance.stop()
async def async_demo():
instance = Example()
print('%s from an asynchronous context' % await instance.hello())
- instance.close()
+ instance.stop()
sync_demo()
asyncio.run(async_demo())
@@ -66,20 +66,33 @@ class TestSynchronous(unittest.TestCase):
sync_demo()
asyncio.run(async_demo())
- def test_after_close(self):
+ def test_after_stop(self):
"""
- Check that closed instances raise a RuntimeError to synchronous callers.
+ Check that stopped instances raise a RuntimeError to synchronous callers.
"""
- # close a used instance
+ # stop a used instance
instance = Example()
self.assertEqual('hello', instance.hello())
- instance.close()
+ instance.stop()
self.assertRaises(RuntimeError, instance.hello)
- # close an unused instance
+ # stop an unused instance
instance = Example()
- instance.close()
+ instance.stop()
self.assertRaises(RuntimeError, instance.hello)
+
+ def test_resuming(self):
+ """
+ Resume a previously stopped instance.
+ """
+
+ instance = Example()
+ self.assertEqual('hello', instance.hello())
+ instance.stop()
+ self.assertRaises(RuntimeError, instance.hello)
+ instance.start()
+ self.assertEqual('hello', instance.hello())
+ instance.stop()
1
0
commit 1cbf3397ccbb77ea7df35109839a522dcf03556b
Author: Damian Johnson <atagar(a)torproject.org>
Date: Sat Jun 20 16:31:17 2020 -0700
Synchronous mixin
Illia's AsyncClassWrapper does the trick, but I think we can make this more
transparent. Let's try a mixin that overwrites asyncio methods dynamically.
Earlier I added an AsyncClassWrapper.__del__() method to clean itself up,
but doing so was a mistake. When our Python interpreter shuts down, asyncio
closes its scheduler *before* this method is invoked, which makes join() hang
because loop.stop() never runs.
To avoid these deadlocks we need Synchronous (or AsyncClassWrapper) users to
explicitly close their instances themselves.
---
stem/util/__init__.py | 89 +++++++++++++++++++++++++++++++++++++++++++
test/settings.cfg | 1 +
test/unit/util/synchronous.py | 54 ++++++++++++++++++++++++++
3 files changed, 144 insertions(+)
diff --git a/stem/util/__init__.py b/stem/util/__init__.py
index 25282b99..72239273 100644
--- a/stem/util/__init__.py
+++ b/stem/util/__init__.py
@@ -7,6 +7,8 @@ Utility functions used by the stem library.
import asyncio
import datetime
+import functools
+import inspect
import threading
from concurrent.futures import Future
@@ -144,6 +146,93 @@ def _hash_attr(obj: Any, *attributes: str, **kwargs: Any):
return my_hash
+class Synchronous(object):
+ """
+ Mixin that lets a class be called from both synchronous and asynchronous
+ contexts.
+
+ ::
+
+ class Example(Synchronous):
+ async def hello(self):
+ return 'hello'
+
+ def sync_demo():
+ instance = Example()
+ print('%s from a synchronous context' % instance.hello())
+ instance.close()
+
+ async def async_demo():
+ instance = Example()
+ print('%s from an asynchronous context' % await instance.hello())
+ instance.close()
+
+ sync_demo()
+ asyncio.run(async_demo())
+
+ Users are responsible for calling :func:`~stem.util.Synchronous.close` when
+ finished to clean up underlying resources.
+ """
+
+ def __init__(self):
+ self._loop = asyncio.new_event_loop()
+ self._loop_lock = threading.RLock()
+ self._loop_thread = threading.Thread(
+ name = '%s asyncio' % self.__class__.__name__,
+ target = self._loop.run_forever,
+ daemon = True,
+ )
+
+ self._is_closed = False
+
+ # overwrite asynchronous class methods with instance methods that can be
+ # called from either context
+
+ def wrap(func, *args, **kwargs):
+ if Synchronous.is_asyncio_context():
+ return func(*args, **kwargs)
+ else:
+ with self._loop_lock:
+ if self._is_closed:
+ raise RuntimeError('%s has been closed' % type(self).__name__)
+ elif not self._loop_thread.is_alive():
+ self._loop_thread.start()
+
+ return asyncio.run_coroutine_threadsafe(func(*args, **kwargs), self._loop).result()
+
+ for method_name, func in inspect.getmembers(self, predicate = inspect.ismethod):
+ if inspect.iscoroutinefunction(func):
+ setattr(self, method_name, functools.partial(wrap, func))
+
+ def close(self):
+ """
+ Terminate resources that permits this from being callable from synchronous
+ contexts. Once called any further synchronous invocations will fail with a
+ **RuntimeError**.
+ """
+
+ with self._loop_lock:
+ if self._loop_thread.is_alive():
+ self._loop.call_soon_threadsafe(self._loop.stop)
+ self._loop_thread.join()
+
+ self._is_closed = True
+
+ @staticmethod
+ def is_asyncio_context():
+ """
+ Check if running within a synchronous or asynchronous context.
+
+ :returns: **True** if within an asyncio conext, **False** otherwise
+ """
+
+ try:
+ asyncio.get_running_loop()
+ return True
+ except RuntimeError:
+ return False
+
+
class AsyncClassWrapper:
_loop: asyncio.AbstractEventLoop
_loop_thread: threading.Thread
diff --git a/test/settings.cfg b/test/settings.cfg
index 51109f96..fcef5ec1 100644
--- a/test/settings.cfg
+++ b/test/settings.cfg
@@ -248,6 +248,7 @@ test.unit_tests
|test.unit.util.system.TestSystem
|test.unit.util.term.TestTerminal
|test.unit.util.tor_tools.TestTorTools
+|test.unit.util.synchronous.TestSynchronous
|test.unit.util.__init__.TestBaseUtil
|test.unit.installation.TestInstallation
|test.unit.descriptor.descriptor.TestDescriptor
diff --git a/test/unit/util/synchronous.py b/test/unit/util/synchronous.py
new file mode 100644
index 00000000..26dad98d
--- /dev/null
+++ b/test/unit/util/synchronous.py
@@ -0,0 +1,54 @@
+"""
+Unit tests for the stem.util.Synchronous class.
+"""
+
+import asyncio
+import io
+import unittest
+
+from unittest.mock import patch
+
+from stem.util import Synchronous
+
+EXAMPLE_OUTPUT = """\
+hello from a synchronous context
+hello from an asynchronous context
+"""
+
+
+class Example(Synchronous):
+ async def hello(self):
+ return 'hello'
+
+
+class TestSynchronous(unittest.TestCase):
+ @patch('sys.stdout', new_callable = io.StringIO)
+ def test_example(self, stdout_mock):
+ def sync_demo():
+ instance = Example()
+ print('%s from a synchronous context' % instance.hello())
+ instance.close()
+
+ async def async_demo():
+ instance = Example()
+ print('%s from an asynchronous context' % await instance.hello())
+ instance.close()
+
+ sync_demo()
+ asyncio.run(async_demo())
+
+ self.assertEqual(EXAMPLE_OUTPUT, stdout_mock.getvalue())
+
+ def test_after_close(self):
+ # close a used instance
+
+ instance = Example()
+ self.assertEqual('hello', instance.hello())
+ instance.close()
+ self.assertRaises(RuntimeError, instance.hello)
+
+ # close an unused instance
+
+ instance = Example()
+ instance.close()
+ self.assertRaises(RuntimeError, instance.hello)
1
0
commit 007cf1ae5654ac057cc56dc06364561ce1d25c58
Author: Damian Johnson <atagar(a)torproject.org>
Date: Mon Jul 13 16:09:35 2020 -0700
Match loop scope to thread
Asyncio threads can be restarted, but doing so lacks a significant benefit and
can get complicated. For instance, when we're stopped from an async method, our
loop is closed asynchronously (because we cannot join our own thread). This is
fine, except that start() can subsequently fail because we cannot resume a
running loop...
Traceback (most recent call last):
File "/home/atagar/Python-3.7.0/Lib/threading.py", line 917, in _bootstrap_inner
self.run()
File "/home/atagar/Python-3.7.0/Lib/threading.py", line 865, in run
self._target(*self._args, **self._kwargs)
File "/home/atagar/Python-3.7.0/Lib/asyncio/base_events.py", line 510, in run_forever
raise RuntimeError('This event loop is already running')
RuntimeError: This event loop is already running
By creating a new loop for each thread we not only sidestep this but also
simplify asynchronicity, because each run of our class will have its own event queue.
---
stem/util/__init__.py | 40 +++++++++++++++++++---------------------
test/unit/util/synchronous.py | 9 +++++++--
2 files changed, 26 insertions(+), 23 deletions(-)
diff --git a/stem/util/__init__.py b/stem/util/__init__.py
index fbd844f8..c28b0b27 100644
--- a/stem/util/__init__.py
+++ b/stem/util/__init__.py
@@ -204,17 +204,19 @@ class Synchronous(object):
def __init__(self) -> None:
self._loop = None # type: Optional[asyncio.AbstractEventLoop]
self._loop_thread = None # type: Optional[threading.Thread]
- self._loop_thread_lock = threading.RLock()
+ self._loop_lock = threading.RLock()
- # this class is a no-op when created from an asyncio context
+ # this class is a no-op when created within an asyncio context
self._no_op = Synchronous.is_asyncio_context()
- if not self._no_op:
- self._loop = asyncio.new_event_loop()
+ if self._no_op:
+ self.__ainit__() # this is already an asyncio context
+ else:
Synchronous.start(self)
- # call any coroutines through our loop
+ # Run coroutines through our loop. This calls methods by name rather than
+ # reference so runtime replacements (like mocks) work.
for name, func in inspect.getmembers(self):
if name in ('__aiter__', '__aenter__', '__aexit__'):
@@ -224,9 +226,6 @@ class Synchronous(object):
elif inspect.ismethod(func) and inspect.iscoroutinefunction(func):
setattr(self, name, functools.partial(self._run_async_method, name))
- if self._no_op:
- self.__ainit__() # this is already an asyncio context
- else:
asyncio.run_coroutine_threadsafe(asyncio.coroutine(self.__ainit__)(), self._loop).result()
def __ainit__(self):
@@ -277,8 +276,9 @@ class Synchronous(object):
Initiate resources to make this object callable from synchronous contexts.
"""
- with self._loop_thread_lock:
- if not self._no_op and self._loop_thread is None:
+ with self._loop_lock:
+ if not self._no_op and self._loop is None:
+ self._loop = asyncio.new_event_loop()
self._loop_thread = threading.Thread(
name = '%s asyncio' % type(self).__name__,
target = self._loop.run_forever,
@@ -294,13 +294,14 @@ class Synchronous(object):
**RuntimeError**.
"""
- with self._loop_thread_lock:
- if not self._no_op and self._loop_thread is not None:
+ with self._loop_lock:
+ if not self._no_op and self._loop is not None:
self._loop.call_soon_threadsafe(self._loop.stop)
if threading.current_thread() != self._loop_thread:
self._loop_thread.join()
+ self._loop = None
self._loop_thread = None
@staticmethod
@@ -330,19 +331,16 @@ class Synchronous(object):
:raises: **AttributeError** if this method doesn't exist
"""
- # Retrieving methods by name (rather than keeping a reference) so runtime
- # replacements like test mocks work.
-
- func = getattr(type(self), method_name)
+ func = getattr(type(self), method_name, None)
- if self._no_op or Synchronous.is_asyncio_context():
+ if not func:
+ raise AttributeError("'%s' does not have a %s method" % (type(self).__name__, method_name))
+ elif self._no_op or Synchronous.is_asyncio_context():
return func(self, *args, **kwargs)
- with self._loop_thread_lock:
- if self._loop_thread is None:
+ with self._loop_lock:
+ if self._loop is None:
raise RuntimeError('%s has been stopped' % type(self).__name__)
- elif not func:
- raise TypeError("'%s' does not have a %s method" % (type(self).__name__, method_name))
# convert iterator if indicated by this method's name or type hint
diff --git a/test/unit/util/synchronous.py b/test/unit/util/synchronous.py
index bbf91d18..bfe6113c 100644
--- a/test/unit/util/synchronous.py
+++ b/test/unit/util/synchronous.py
@@ -120,15 +120,20 @@ class TestSynchronous(unittest.TestCase):
def test_stop_from_async(self):
"""
- Ensure we can stop our instance from within an async method without
- deadlock.
+ Ensure we can start and stop our instance from within an async method
+ without deadlock.
"""
class AsyncStop(Synchronous):
+ async def restart(self):
+ self.stop()
+ self.start()
+
async def call_stop(self):
self.stop()
instance = AsyncStop()
+ instance.restart()
instance.call_stop()
self.assertRaises(RuntimeError, instance.call_stop)
1
0

16 Jul '20
commit a2e0da98669a7a173e25a6abe6e3e2e996be8e67
Author: Damian Johnson <atagar(a)torproject.org>
Date: Sun Jul 12 17:01:18 2020 -0700
Fix deadlock when stopping from an async context
Reentrant locks can only be acquired multiple times from within the same thread.
When our _run_async_method() invoked start() or stop() we deadlocked.
---
stem/util/__init__.py | 6 ++++--
test/unit/util/synchronous.py | 14 ++++++++++++++
2 files changed, 18 insertions(+), 2 deletions(-)
diff --git a/stem/util/__init__.py b/stem/util/__init__.py
index c147a5a4..fbd844f8 100644
--- a/stem/util/__init__.py
+++ b/stem/util/__init__.py
@@ -350,9 +350,11 @@ class Synchronous(object):
async def convert_generator(generator: AsyncIterator) -> Iterator:
return iter([d async for d in generator])
- return asyncio.run_coroutine_threadsafe(convert_generator(func(self, *args, **kwargs)), self._loop).result()
+ future = asyncio.run_coroutine_threadsafe(convert_generator(func(self, *args, **kwargs)), self._loop)
else:
- return asyncio.run_coroutine_threadsafe(func(self, *args, **kwargs), self._loop).result()
+ future = asyncio.run_coroutine_threadsafe(func(self, *args, **kwargs), self._loop)
+
+ return future.result()
def __iter__(self) -> Iterator:
return self._run_async_method('__aiter__')
diff --git a/test/unit/util/synchronous.py b/test/unit/util/synchronous.py
index d4429f3e..bbf91d18 100644
--- a/test/unit/util/synchronous.py
+++ b/test/unit/util/synchronous.py
@@ -118,6 +118,20 @@ class TestSynchronous(unittest.TestCase):
sync_test()
asyncio.run(async_test())
+ def test_stop_from_async(self):
+ """
+ Ensure we can stop our instance from within an async method without
+ deadlock.
+ """
+
+ class AsyncStop(Synchronous):
+ async def call_stop(self):
+ self.stop()
+
+ instance = AsyncStop()
+ instance.call_stop()
+ self.assertRaises(RuntimeError, instance.call_stop)
+
def test_resuming(self):
"""
Resume a previously stopped instance.
1
0

16 Jul '20
commit 75834c06d1d2ade574f11c95f980143408350f11
Author: Damian Johnson <atagar(a)torproject.org>
Date: Tue Jul 14 14:12:47 2020 -0700
Resume Synchronous when async methods are called
When our Synchronous class was stopped, all further invocations of an async
method raised a RuntimeError. For most classes (sockets, threads, etc.) this is
proper, but it made working with these objects within synchronous contexts
error-prone.
For example, our Controller's async connect() method resumes our instance, but
was uncallable due to this behavior. Stopping should be the last action callers
take, and failing to do so is inconsequential (it simply orphans a daemon
thread), so we err toward our object always being callable.
---
stem/util/__init__.py | 13 ++++++------
test/unit/util/synchronous.py | 48 +++++++++++--------------------------------
2 files changed, 18 insertions(+), 43 deletions(-)
diff --git a/stem/util/__init__.py b/stem/util/__init__.py
index c28b0b27..d780a0de 100644
--- a/stem/util/__init__.py
+++ b/stem/util/__init__.py
@@ -213,8 +213,6 @@ class Synchronous(object):
if self._no_op:
self.__ainit__() # this is already an asyncio context
else:
- Synchronous.start(self)
-
# Run coroutines through our loop. This calls methods by name rather than
# reference so runtime replacements (like mocks) work.
@@ -226,6 +224,7 @@ class Synchronous(object):
elif inspect.ismethod(func) and inspect.iscoroutinefunction(func):
setattr(self, name, functools.partial(self._run_async_method, name))
+ Synchronous.start(self)
asyncio.run_coroutine_threadsafe(asyncio.coroutine(self.__ainit__)(), self._loop).result()
def __ainit__(self):
@@ -246,8 +245,8 @@ class Synchronous(object):
#
# However, when constructed from an asynchronous context the above will
# likely hang because our loop is already processing a task (namely,
- # whatever is constructing us). While we can schedule tasks, we cannot
- # invoke it during our construction.
+ # whatever is constructing us). While we can schedule a follow-up task, we
+ # cannot invoke it during our construction.
#
# Finally, when this method is simple we could directly invoke it...
#
@@ -290,8 +289,8 @@ class Synchronous(object):
def stop(self) -> None:
"""
Terminate resources that permits this from being callable from synchronous
- contexts. Once called any further synchronous invocations will fail with a
- **RuntimeError**.
+ contexts. Calling either :func:`~stem.util.Synchronous.start` or any async
+ method will resume us.
"""
with self._loop_lock:
@@ -340,7 +339,7 @@ class Synchronous(object):
with self._loop_lock:
if self._loop is None:
- raise RuntimeError('%s has been stopped' % type(self).__name__)
+ Synchronous.start(self)
# convert iterator if indicated by this method's name or type hint
diff --git a/test/unit/util/synchronous.py b/test/unit/util/synchronous.py
index bfe6113c..d99da922 100644
--- a/test/unit/util/synchronous.py
+++ b/test/unit/util/synchronous.py
@@ -92,7 +92,7 @@ class TestSynchronous(unittest.TestCase):
def test_stop(self):
"""
- Synchronous callers should receive a RuntimeError when stopped.
+ Stop and resume our instances.
"""
def sync_test():
@@ -100,11 +100,17 @@ class TestSynchronous(unittest.TestCase):
self.assertEqual('async call', instance.async_method())
instance.stop()
- self.assertRaises(RuntimeError, instance.async_method)
-
- # synchronous methods still work
+ # synchronous methods won't resume us
self.assertEqual('sync call', instance.sync_method())
+ self.assertTrue(instance._loop is None)
+
+ # ... but async methods will
+
+ self.assertEqual('async call', instance.async_method())
+ self.assertTrue(isinstance(instance._loop, asyncio.AbstractEventLoop))
+
+ instance.stop()
async def async_test():
instance = Demo()
@@ -120,7 +126,7 @@ class TestSynchronous(unittest.TestCase):
def test_stop_from_async(self):
"""
- Ensure we can start and stop our instance from within an async method
+ Ensure we can restart and stop our instance from within an async method
without deadlock.
"""
@@ -135,37 +141,7 @@ class TestSynchronous(unittest.TestCase):
instance = AsyncStop()
instance.restart()
instance.call_stop()
- self.assertRaises(RuntimeError, instance.call_stop)
-
- def test_resuming(self):
- """
- Resume a previously stopped instance.
- """
-
- def sync_test():
- instance = Demo()
- self.assertEqual('async call', instance.async_method())
- instance.stop()
-
- self.assertRaises(RuntimeError, instance.async_method)
-
- instance.start()
- self.assertEqual('async call', instance.async_method())
- instance.stop()
-
- async def async_test():
- instance = Demo()
- self.assertEqual('async call', await instance.async_method())
- instance.stop()
-
- # start has no affect on async users
-
- instance.start()
- self.assertEqual('async call', await instance.async_method())
- instance.stop()
-
- sync_test()
- asyncio.run(async_test())
+ self.assertTrue(instance._loop is None)
def test_iteration(self):
"""
1
0
commit 8b539f2facdd86d07e76b5bf5daa379bf0d3d2ba
Author: Damian Johnson <atagar(a)torproject.org>
Date: Thu Jul 9 17:29:58 2020 -0700
Synchronous context management
Make our class handle 'with' statements, and tidy up both its implementation
and tests.
---
stem/util/__init__.py | 79 +++++++++------
test/unit/util/synchronous.py | 223 +++++++++++++++++++++++++++++-------------
2 files changed, 205 insertions(+), 97 deletions(-)
diff --git a/stem/util/__init__.py b/stem/util/__init__.py
index 15840f56..c147a5a4 100644
--- a/stem/util/__init__.py
+++ b/stem/util/__init__.py
@@ -10,10 +10,11 @@ import datetime
import functools
import inspect
import threading
+import typing
import unittest.mock
from concurrent.futures import Future
-
+from types import TracebackType
from typing import Any, AsyncIterator, Iterator, Optional, Type, Union
__all__ = [
@@ -201,26 +202,31 @@ class Synchronous(object):
"""
def __init__(self) -> None:
+ self._loop = None # type: Optional[asyncio.AbstractEventLoop]
self._loop_thread = None # type: Optional[threading.Thread]
self._loop_thread_lock = threading.RLock()
- if Synchronous.is_asyncio_context():
- self._loop = asyncio.get_running_loop()
+ # this class is a no-op when created from an asyncio context
- self.__ainit__()
- else:
- self._loop = asyncio.new_event_loop()
+ self._no_op = Synchronous.is_asyncio_context()
+ if not self._no_op:
+ self._loop = asyncio.new_event_loop()
Synchronous.start(self)
- # call any coroutines through this loop
+ # call any coroutines through our loop
for name, func in inspect.getmembers(self):
- if isinstance(func, unittest.mock.Mock) and inspect.iscoroutinefunction(func.side_effect):
- setattr(self, name, functools.partial(self._call_async_method, name))
+ if name in ('__aiter__', '__aenter__', '__aexit__'):
+ pass # async object methods with synchronous counterparts
+ elif isinstance(func, unittest.mock.Mock) and inspect.iscoroutinefunction(func.side_effect):
+ setattr(self, name, functools.partial(self._run_async_method, name))
elif inspect.ismethod(func) and inspect.iscoroutinefunction(func):
- setattr(self, name, functools.partial(self._call_async_method, name))
+ setattr(self, name, functools.partial(self._run_async_method, name))
+ if self._no_op:
+ self.__ainit__() # this is already an asyncio context
+ else:
asyncio.run_coroutine_threadsafe(asyncio.coroutine(self.__ainit__)(), self._loop).result()
def __ainit__(self):
@@ -272,13 +278,14 @@ class Synchronous(object):
"""
with self._loop_thread_lock:
- self._loop_thread = threading.Thread(
- name = '%s asyncio' % type(self).__name__,
- target = self._loop.run_forever,
- daemon = True,
- )
+ if not self._no_op and self._loop_thread is None:
+ self._loop_thread = threading.Thread(
+ name = '%s asyncio' % type(self).__name__,
+ target = self._loop.run_forever,
+ daemon = True,
+ )
- self._loop_thread.start()
+ self._loop_thread.start()
def stop(self) -> None:
"""
@@ -288,9 +295,13 @@ class Synchronous(object):
"""
with self._loop_thread_lock:
- if self._loop_thread and self._loop_thread.is_alive():
+ if not self._no_op and self._loop_thread is not None:
self._loop.call_soon_threadsafe(self._loop.stop)
- self._loop_thread.join()
+
+ if threading.current_thread() != self._loop_thread:
+ self._loop_thread.join()
+
+ self._loop_thread = None
@staticmethod
def is_asyncio_context() -> bool:
@@ -306,7 +317,7 @@ class Synchronous(object):
except RuntimeError:
return False
- def _call_async_method(self, method_name: str, *args: Any, **kwargs: Any) -> Any:
+ def _run_async_method(self, method_name: str, *args: Any, **kwargs: Any) -> Any:
"""
Run this async method from either a synchronous or asynchronous context.
@@ -324,25 +335,33 @@ class Synchronous(object):
func = getattr(type(self), method_name)
- if Synchronous.is_asyncio_context():
+ if self._no_op or Synchronous.is_asyncio_context():
return func(self, *args, **kwargs)
with self._loop_thread_lock:
- if self._loop_thread and not self._loop_thread.is_alive():
- raise RuntimeError('%s has been closed' % type(self).__name__)
+ if self._loop_thread is None:
+ raise RuntimeError('%s has been stopped' % type(self).__name__)
+ elif not func:
+ raise TypeError("'%s' does not have a %s method" % (type(self).__name__, method_name))
+
+ # convert iterator if indicated by this method's name or type hint
+
+ if method_name == '__aiter__' or (inspect.ismethod(func) and typing.get_type_hints(func).get('return') == AsyncIterator):
+ async def convert_generator(generator: AsyncIterator) -> Iterator:
+ return iter([d async for d in generator])
- return asyncio.run_coroutine_threadsafe(func(self, *args, **kwargs), self._loop).result()
+ return asyncio.run_coroutine_threadsafe(convert_generator(func(self, *args, **kwargs)), self._loop).result()
+ else:
+ return asyncio.run_coroutine_threadsafe(func(self, *args, **kwargs), self._loop).result()
def __iter__(self) -> Iterator:
- async def convert_generator(generator: AsyncIterator) -> Iterator:
- return iter([d async for d in generator])
+ return self._run_async_method('__aiter__')
- iter_func = getattr(self, '__aiter__', None)
+ def __enter__(self):
+ return self._run_async_method('__aenter__')
- if iter_func:
- return asyncio.run_coroutine_threadsafe(convert_generator(iter_func()), self._loop).result()
- else:
- raise TypeError("'%s' object is not iterable" % type(self).__name__)
+ def __exit__(self, exit_type: Optional[Type[BaseException]], value: Optional[BaseException], traceback: Optional[TracebackType]):
+ return self._run_async_method('__aexit__', exit_type, value, traceback)
class AsyncClassWrapper:
diff --git a/test/unit/util/synchronous.py b/test/unit/util/synchronous.py
index 5b38a7b5..d4429f3e 100644
--- a/test/unit/util/synchronous.py
+++ b/test/unit/util/synchronous.py
@@ -17,12 +17,34 @@ hello from an asynchronous context
"""
-class Example(Synchronous):
- async def hello(self):
- return 'hello'
+class Demo(Synchronous):
+ def __init__(self):
+ super(Demo, self).__init__()
+
+ self.called_enter = False
+ self.called_exit = False
+
+ def __ainit__(self):
+ self.ainit_loop = asyncio.get_running_loop()
+
+ async def async_method(self):
+ return 'async call'
+
+ def sync_method(self):
+ return 'sync call'
+
+ async def __aiter__(self):
+ for i in range(3):
+ yield i
+
+ async def __aenter__(self):
+ self.called_enter = True
+ return self
+
+ async def __aexit__(self, exit_type, value, traceback):
+ self.called_exit = True
+ return
- def sync_hello(self):
- return 'hello'
class TestSynchronous(unittest.TestCase):
@patch('sys.stdout', new_callable = io.StringIO)
@@ -31,6 +53,10 @@ class TestSynchronous(unittest.TestCase):
Run the example from our pydoc.
"""
+ class Example(Synchronous):
+ async def hello(self):
+ return 'hello'
+
def sync_demo():
instance = Example()
print('%s from a synchronous context' % instance.hello())
@@ -48,102 +74,165 @@ class TestSynchronous(unittest.TestCase):
def test_ainit(self):
"""
- Check that our constructor runs __ainit__ when present.
+ Check that construction runs __ainit__ with a loop when present.
"""
- class AinitDemo(Synchronous):
- def __init__(self):
- super(AinitDemo, self).__init__()
-
- def __ainit__(self):
- self.ainit_loop = asyncio.get_running_loop()
-
- def sync_demo():
- instance = AinitDemo()
- self.assertTrue(hasattr(instance, 'ainit_loop'))
+ def sync_test():
+ instance = Demo()
+ self.assertTrue(isinstance(instance.ainit_loop, asyncio.AbstractEventLoop))
+ instance.stop()
- async def async_demo():
- instance = AinitDemo()
- self.assertTrue(hasattr(instance, 'ainit_loop'))
+ async def async_test():
+ instance = Demo()
+ self.assertTrue(isinstance(instance.ainit_loop, asyncio.AbstractEventLoop))
+ instance.stop()
- sync_demo()
- asyncio.run(async_demo())
+ sync_test()
+ asyncio.run(async_test())
- def test_after_stop(self):
+ def test_stop(self):
"""
- Check that stopped instances raise a RuntimeError to synchronous callers.
+ Synchronous callers should receive a RuntimeError when stopped.
"""
- # stop a used instance
+ def sync_test():
+ instance = Demo()
+ self.assertEqual('async call', instance.async_method())
+ instance.stop()
+
+ self.assertRaises(RuntimeError, instance.async_method)
+
+ # synchronous methods still work
- instance = Example()
- self.assertEqual('hello', instance.hello())
- instance.stop()
- self.assertRaises(RuntimeError, instance.hello)
+ self.assertEqual('sync call', instance.sync_method())
- # stop an unused instance
+ async def async_test():
+ instance = Demo()
+ self.assertEqual('async call', await instance.async_method())
+ instance.stop()
+
+ # stop has no affect on async users
+
+ self.assertEqual('async call', await instance.async_method())
- instance = Example()
- instance.stop()
- self.assertRaises(RuntimeError, instance.hello)
+ sync_test()
+ asyncio.run(async_test())
def test_resuming(self):
"""
Resume a previously stopped instance.
"""
- instance = Example()
- self.assertEqual('hello', instance.hello())
- instance.stop()
- self.assertRaises(RuntimeError, instance.hello)
- instance.start()
- self.assertEqual('hello', instance.hello())
- instance.stop()
+ def sync_test():
+ instance = Demo()
+ self.assertEqual('async call', instance.async_method())
+ instance.stop()
+
+ self.assertRaises(RuntimeError, instance.async_method)
- def test_asynchronous_mockability(self):
+ instance.start()
+ self.assertEqual('async call', instance.async_method())
+ instance.stop()
+
+ async def async_test():
+ instance = Demo()
+ self.assertEqual('async call', await instance.async_method())
+ instance.stop()
+
+ # start has no affect on async users
+
+ instance.start()
+ self.assertEqual('async call', await instance.async_method())
+ instance.stop()
+
+ sync_test()
+ asyncio.run(async_test())
+
+ def test_iteration(self):
"""
- Check that method mocks are respected.
+ Check that we can iterate in both contexts.
"""
- # mock prior to construction
+ def sync_test():
+ instance = Demo()
+ result = []
- with patch('test.unit.util.synchronous.Example.hello', Mock(side_effect = coro_func_returning_value('mocked hello'))):
- instance = Example()
- self.assertEqual('mocked hello', instance.hello())
+ for val in instance:
+ result.append(val)
- self.assertEqual('hello', instance.hello()) # mock should now be reverted
- instance.stop()
+ self.assertEqual([0, 1, 2], result)
+ instance.stop()
- # mock after construction
+ async def async_test():
+ instance = Demo()
+ result = []
- instance = Example()
+ async for val in instance:
+ result.append(val)
- with patch('test.unit.util.synchronous.Example.hello', Mock(side_effect = coro_func_returning_value('mocked hello'))):
- self.assertEqual('mocked hello', instance.hello())
+ self.assertEqual([0, 1, 2], result)
+ instance.stop()
- self.assertEqual('hello', instance.hello())
- instance.stop()
+ sync_test()
+ asyncio.run(async_test())
- def test_synchronous_mockability(self):
+ def test_context_management(self):
"""
- Ensure we do not disrupt non-asynchronous method mocks.
+ Exercise context management via 'with' statements.
"""
- # mock prior to construction
+ def sync_test():
+ instance = Demo()
- with patch('test.unit.util.synchronous.Example.sync_hello', Mock(return_value = 'mocked hello')):
- instance = Example()
- self.assertEqual('mocked hello', instance.sync_hello())
+ self.assertFalse(instance.called_enter)
+ self.assertFalse(instance.called_exit)
+
+ with instance:
+ self.assertTrue(instance.called_enter)
+ self.assertFalse(instance.called_exit)
+
+ self.assertTrue(instance.called_enter)
+ self.assertTrue(instance.called_exit)
+
+ async def async_test():
+ instance = Demo()
+
+ self.assertFalse(instance.called_enter)
+ self.assertFalse(instance.called_exit)
+
+ async with instance:
+ self.assertTrue(instance.called_enter)
+ self.assertFalse(instance.called_exit)
+
+ self.assertTrue(instance.called_enter)
+ self.assertTrue(instance.called_exit)
+
+ sync_test()
+ asyncio.run(async_test())
+
+ def test_mockability(self):
+ """
+ Check that method mocks are respected for both previously constructed
+ instances and those made after the mock.
+ """
+
+ pre_constructed = Demo()
+
+ with patch('test.unit.util.synchronous.Demo.async_method', Mock(side_effect = coro_func_returning_value('mocked call'))):
+ post_constructed = Demo()
+
+ self.assertEqual('mocked call', pre_constructed.async_method())
+ self.assertEqual('mocked call', post_constructed.async_method())
- self.assertEqual('hello', instance.sync_hello()) # mock should now be reverted
- instance.stop()
+ self.assertEqual('async call', pre_constructed.async_method())
+ self.assertEqual('async call', post_constructed.async_method())
- # mock after construction
+ # synchronous methods are unaffected
- instance = Example()
+ with patch('test.unit.util.synchronous.Demo.sync_method', Mock(return_value = 'mocked call')):
+ self.assertEqual('mocked call', pre_constructed.sync_method())
- with patch('test.unit.util.synchronous.Example.sync_hello', Mock(return_value = 'mocked hello')):
- self.assertEqual('mocked hello', instance.sync_hello())
+ self.assertEqual('sync call', pre_constructed.sync_method())
- self.assertEqual('hello', instance.sync_hello())
- instance.stop()
+ pre_constructed.stop()
+ post_constructed.stop()
1
0
commit f4b77a2c9194aa964aec0e633255e549186b7f2d
Author: Damian Johnson <atagar(a)torproject.org>
Date: Tue Jul 14 14:54:01 2020 -0700
Always test for lingering threads
Our unit tests are just as liable to orphan threads as our integration tests.
It's confusing to only detect unit test leaks when running them alongside our
integration tests, so this makes the check independent of which test suite we run.
---
run_tests.py | 15 ++++++---------
stem/descriptor/remote.py | 3 +++
test/unit/util/synchronous.py | 4 ++++
3 files changed, 13 insertions(+), 9 deletions(-)
diff --git a/run_tests.py b/run_tests.py
index fd46211f..c738f9b8 100755
--- a/run_tests.py
+++ b/run_tests.py
@@ -321,18 +321,15 @@ def main():
integ_runner.stop()
println()
- # We should have joined on all threads. If not then that indicates a
- # leak that could both likely be a bug and disrupt further targets.
+ # ensure that we join all our threads
- active_threads = threading.enumerate()
+ active_threads = threading.enumerate()
- if len(active_threads) > 1:
- println('Threads lingering after test run:', ERROR)
+ if len(active_threads) > 1:
+ println('Threads lingering after test run:', ERROR)
- for lingering_thread in active_threads:
- println(' %s' % lingering_thread, ERROR)
-
- break
+ for lingering_thread in active_threads:
+ println(' %s' % lingering_thread, ERROR)
static_check_issues = {}
diff --git a/stem/descriptor/remote.py b/stem/descriptor/remote.py
index ad6a02ab..942d81e9 100644
--- a/stem/descriptor/remote.py
+++ b/stem/descriptor/remote.py
@@ -369,6 +369,7 @@ class Query(Synchronous):
super(Query, self).__init__()
if not resource.startswith('/'):
+ self.stop()
raise ValueError("Resources should start with a '/': %s" % resource)
if resource.endswith('.z'):
@@ -379,6 +380,7 @@ class Query(Synchronous):
elif isinstance(compression, stem.descriptor._Compression):
compression = [compression] # caller provided only a single option
else:
+ self.stop()
raise ValueError('Compression should be a list of stem.descriptor.Compression, was %s (%s)' % (compression, type(compression).__name__))
if Compression.ZSTD in compression and not Compression.ZSTD.available:
@@ -402,6 +404,7 @@ class Query(Synchronous):
if isinstance(endpoint, (stem.ORPort, stem.DirPort)):
self.endpoints.append(endpoint)
else:
+ self.stop()
raise ValueError("Endpoints must be an stem.ORPort or stem.DirPort. '%s' is a %s." % (endpoint, type(endpoint).__name__))
self.resource = resource
diff --git a/test/unit/util/synchronous.py b/test/unit/util/synchronous.py
index d99da922..602d81f4 100644
--- a/test/unit/util/synchronous.py
+++ b/test/unit/util/synchronous.py
@@ -189,6 +189,8 @@ class TestSynchronous(unittest.TestCase):
self.assertTrue(instance.called_enter)
self.assertTrue(instance.called_exit)
+ instance.stop()
+
async def async_test():
instance = Demo()
@@ -202,6 +204,8 @@ class TestSynchronous(unittest.TestCase):
self.assertTrue(instance.called_enter)
self.assertTrue(instance.called_exit)
+ instance.stop()
+
sync_test()
asyncio.run(async_test())
1
0
commit ef1e41ebce0aa1bb5fde9064410402bee9887451
Author: Damian Johnson <atagar(a)torproject.org>
Date: Wed Jul 8 17:12:39 2020 -0700
Synchronous class mockability
The meta-programming behind our Synchronous class doesn't play well with test
mocks. Handling this, and testing the permutations I can think of.
---
stem/util/__init__.py | 49 +++++++++++++++++++++++++++-----------
test/unit/util/synchronous.py | 55 +++++++++++++++++++++++++++++++++++++++++--
2 files changed, 88 insertions(+), 16 deletions(-)
diff --git a/stem/util/__init__.py b/stem/util/__init__.py
index 54f90376..15840f56 100644
--- a/stem/util/__init__.py
+++ b/stem/util/__init__.py
@@ -10,9 +10,11 @@ import datetime
import functools
import inspect
import threading
+import unittest.mock
+
from concurrent.futures import Future
-from typing import Any, AsyncIterator, Callable, Iterator, Optional, Type, Union
+from typing import Any, AsyncIterator, Iterator, Optional, Type, Union
__all__ = [
'conf',
@@ -213,19 +215,11 @@ class Synchronous(object):
# call any coroutines through this loop
- def call_async(func: Callable, *args: Any, **kwargs: Any) -> Any:
- if Synchronous.is_asyncio_context():
- return func(*args, **kwargs)
-
- with self._loop_thread_lock:
- if not self._loop_thread.is_alive():
- raise RuntimeError('%s has been stopped' % type(self).__name__)
-
- return asyncio.run_coroutine_threadsafe(func(*args, **kwargs), self._loop).result()
-
- for method_name, func in inspect.getmembers(self, predicate = inspect.ismethod):
- if inspect.iscoroutinefunction(func):
- setattr(self, method_name, functools.partial(call_async, func))
+ for name, func in inspect.getmembers(self):
+ if isinstance(func, unittest.mock.Mock) and inspect.iscoroutinefunction(func.side_effect):
+ setattr(self, name, functools.partial(self._call_async_method, name))
+ elif inspect.ismethod(func) and inspect.iscoroutinefunction(func):
+ setattr(self, name, functools.partial(self._call_async_method, name))
asyncio.run_coroutine_threadsafe(asyncio.coroutine(self.__ainit__)(), self._loop).result()
@@ -312,6 +306,33 @@ class Synchronous(object):
except RuntimeError:
return False
+ def _call_async_method(self, method_name: str, *args: Any, **kwargs: Any) -> Any:
+ """
+ Run this async method from either a synchronous or asynchronous context.
+
+ :param method_name: name of the method to invoke
+ :param args: positional arguments
+ :param kwargs: keyword arguments
+
+ :returns: method's return value
+
+ :raises: **AttributeError** if this method doesn't exist
+ """
+
+ # Retrieving methods by name (rather than keeping a reference) so runtime
+ # replacements like test mocks work.
+
+ func = getattr(type(self), method_name)
+
+ if Synchronous.is_asyncio_context():
+ return func(self, *args, **kwargs)
+
+ with self._loop_thread_lock:
+ if self._loop_thread and not self._loop_thread.is_alive():
+ raise RuntimeError('%s has been closed' % type(self).__name__)
+
+ return asyncio.run_coroutine_threadsafe(func(self, *args, **kwargs), self._loop).result()
+
def __iter__(self) -> Iterator:
async def convert_generator(generator: AsyncIterator) -> Iterator:
return iter([d async for d in generator])
diff --git a/test/unit/util/synchronous.py b/test/unit/util/synchronous.py
index dd27c3c6..5b38a7b5 100644
--- a/test/unit/util/synchronous.py
+++ b/test/unit/util/synchronous.py
@@ -6,9 +6,10 @@ import asyncio
import io
import unittest
-from unittest.mock import patch
+from unittest.mock import patch, Mock
from stem.util import Synchronous
+from stem.util.test_tools import coro_func_returning_value
EXAMPLE_OUTPUT = """\
hello from a synchronous context
@@ -20,6 +21,8 @@ class Example(Synchronous):
async def hello(self):
return 'hello'
+ def sync_hello(self):
+ return 'hello'
class TestSynchronous(unittest.TestCase):
@patch('sys.stdout', new_callable = io.StringIO)
@@ -45,7 +48,7 @@ class TestSynchronous(unittest.TestCase):
def test_ainit(self):
"""
- Check that our constructor runs __ainit__ if present.
+ Check that our constructor runs __ainit__ when present.
"""
class AinitDemo(Synchronous):
@@ -96,3 +99,51 @@ class TestSynchronous(unittest.TestCase):
instance.start()
self.assertEqual('hello', instance.hello())
instance.stop()
+
+ def test_asynchronous_mockability(self):
+ """
+ Check that method mocks are respected.
+ """
+
+ # mock prior to construction
+
+ with patch('test.unit.util.synchronous.Example.hello', Mock(side_effect = coro_func_returning_value('mocked hello'))):
+ instance = Example()
+ self.assertEqual('mocked hello', instance.hello())
+
+ self.assertEqual('hello', instance.hello()) # mock should now be reverted
+ instance.stop()
+
+ # mock after construction
+
+ instance = Example()
+
+ with patch('test.unit.util.synchronous.Example.hello', Mock(side_effect = coro_func_returning_value('mocked hello'))):
+ self.assertEqual('mocked hello', instance.hello())
+
+ self.assertEqual('hello', instance.hello())
+ instance.stop()
+
+ def test_synchronous_mockability(self):
+ """
+ Ensure we do not disrupt non-asynchronous method mocks.
+ """
+
+ # mock prior to construction
+
+ with patch('test.unit.util.synchronous.Example.sync_hello', Mock(return_value = 'mocked hello')):
+ instance = Example()
+ self.assertEqual('mocked hello', instance.sync_hello())
+
+ self.assertEqual('hello', instance.sync_hello()) # mock should now be reverted
+ instance.stop()
+
+ # mock after construction
+
+ instance = Example()
+
+ with patch('test.unit.util.synchronous.Example.sync_hello', Mock(return_value = 'mocked hello')):
+ self.assertEqual('mocked hello', instance.sync_hello())
+
+ self.assertEqual('hello', instance.sync_hello())
+ instance.stop()
1
0
commit dc93ee7257b9a0eed1a4632dec3c7b16a65e9782
Author: Damian Johnson <atagar(a)torproject.org>
Date: Thu Jun 25 16:13:54 2020 -0700
Use Synchronous for Controller
Finally migrating our Controller class from Illia's AsyncClassWrapper to our
Synchronous mixin.
Benefits are...
* Class no longer requires a synchronous and asynchronous copy.
* Controller can be implemented as a fully asynchronous class, while still
functioning in synchronous contexts.
Downside is...
* Python type checkers (like mypy) only recognize our Controller as an
asynchronous class, producing false positives for synchronous users.
---
run_tests.py | 4 +
stem/connection.py | 9 +-
stem/control.py | 579 +++++-----------------
stem/descriptor/remote.py | 8 +-
stem/interpreter/__init__.py | 2 +-
stem/interpreter/commands.py | 4 +-
stem/util/__init__.py | 43 +-
stem/util/test_tools.py | 3 +
test/integ/connection/authentication.py | 13 +-
test/integ/control/controller.py | 818 ++++++++++++++++++--------------
test/runner.py | 64 ++-
test/settings.cfg | 8 +
test/unit/control/controller.py | 119 +++--
test/unit/descriptor/remote.py | 2 +-
14 files changed, 715 insertions(+), 961 deletions(-)
diff --git a/run_tests.py b/run_tests.py
index c738f9b8..5218008f 100755
--- a/run_tests.py
+++ b/run_tests.py
@@ -259,6 +259,10 @@ def main():
# 2.7 or later because before that test results didn't have a 'skipped'
# attribute.
+ # TODO: handling of earlier python versions is no longer necessary here
+ # TODO: this invokes all asynchronous tests, even if we have a --test or
+ # --exclude-test argument
+
skipped_tests = 0
if args.run_integ:
diff --git a/stem/connection.py b/stem/connection.py
index 86d32d7f..8495da2a 100644
--- a/stem/connection.py
+++ b/stem/connection.py
@@ -89,7 +89,7 @@ fine-grained control over the authentication process. For instance...
::
connect - Simple method for getting authenticated control connection for synchronous usage.
- async_connect - Simple method for getting authenticated control connection for asynchronous usage.
+ async_connect - Simple method for getting authenticated control connection for asynchronous usage.
authenticate - Main method for authenticating to a control socket
authenticate_none - Authenticates to an open control socket
@@ -292,7 +292,7 @@ def connect(control_port: Tuple[str, Union[str, int]] = ('127.0.0.1', 'default')
raise
-async def connect_async(control_port: Tuple[str, Union[str, int]] = ('127.0.0.1', 'default'), control_socket: str = '/var/run/tor/control', password: Optional[str] = None, password_prompt: bool = False, chroot_path: Optional[str] = None, controller: Type[stem.control.BaseController] = stem.control.AsyncController) -> Any:
+async def connect_async(control_port: Tuple[str, Union[str, int]] = ('127.0.0.1', 'default'), control_socket: str = '/var/run/tor/control', password: Optional[str] = None, password_prompt: bool = False, chroot_path: Optional[str] = None, controller: Type[stem.control.BaseController] = stem.control.Controller) -> Any:
"""
Convenience function for quickly getting a control connection for
asynchronous usage. This is very handy for debugging or CLI setup, handling
@@ -364,6 +364,7 @@ async def _connect_async(control_port: Tuple[str, Union[str, int]], control_sock
control_connection = _connection_for_default_port(address)
else:
control_connection = stem.socket.ControlPort(address, int(port))
+
await control_connection.connect()
except stem.SocketError as exc:
error_msg = CONNECT_MESSAGES['unable_to_use_port'].format(address = address, port = port, error = exc)
@@ -405,9 +406,7 @@ async def _connect_auth(control_socket: stem.socket.ControlSocket, password: str
if controller is None:
return control_socket
- elif issubclass(controller, stem.control.BaseController) or issubclass(controller, stem.control.Controller):
- # TODO: Controller no longer extends BaseController (we'll probably change that)
-
+ else:
return controller(control_socket, is_authenticated = True)
except IncorrectSocketType:
if isinstance(control_socket, stem.socket.ControlPort):
diff --git a/stem/control.py b/stem/control.py
index 47ddaa35..7b90eed0 100644
--- a/stem/control.py
+++ b/stem/control.py
@@ -269,9 +269,9 @@ import stem.util.tor_tools
import stem.version
from stem import UNDEFINED, CircStatus, Signal
-from stem.util import log
+from stem.util import Synchronous, log
from types import TracebackType
-from typing import Any, AsyncIterator, Awaitable, Callable, Dict, Iterator, List, Mapping, Optional, Sequence, Tuple, Type, Union
+from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union
# When closing the controller we attempt to finish processing enqueued events,
# but if it takes longer than this we terminate.
@@ -553,56 +553,7 @@ def event_description(event: str) -> str:
return EVENT_DESCRIPTIONS.get(event.lower())
-class _BaseControllerSocketMixin:
- _socket: stem.socket.ControlSocket
-
- def is_alive(self) -> bool:
- """
- Checks if our socket is currently connected. This is a pass-through for our
- socket's :func:`~stem.socket.BaseSocket.is_alive` method.
-
- :returns: **bool** that's **True** if our socket is connected and **False** otherwise
- """
-
- return self._socket.is_alive()
-
- def is_localhost(self) -> bool:
- """
- Returns if the connection is for the local system or not.
-
- .. versionadded:: 1.3.0
-
- :returns: **bool** that's **True** if the connection is for the local host and **False** otherwise
- """
-
- return self._socket.is_localhost()
-
- def connection_time(self) -> float:
- """
- Provides the unix timestamp for when our socket was either connected or
- disconnected. That is to say, the time we connected if we're currently
- connected and the time we disconnected if we're not connected.
-
- .. versionadded:: 1.3.0
-
- :returns: **float** for when we last connected or disconnected, zero if
- we've never connected
- """
-
- return self._socket.connection_time()
-
- def get_socket(self) -> stem.socket.ControlSocket:
- """
- Provides the socket used to speak with the tor process. Communicating with
- the socket directly isn't advised since it may confuse this controller.
-
- :returns: :class:`~stem.socket.ControlSocket` we're communicating with
- """
-
- return self._socket
-
-
-class BaseController(_BaseControllerSocketMixin):
+class BaseController(Synchronous):
"""
Controller for the tor process. This is a minimal base class for other
controllers, providing basic process communication and event listing. Don't
@@ -619,21 +570,13 @@ class BaseController(_BaseControllerSocketMixin):
"""
def __init__(self, control_socket: stem.socket.ControlSocket, is_authenticated: bool = False) -> None:
- self._socket = control_socket
+ super(BaseController, self).__init__()
- self._asyncio_loop = asyncio.get_event_loop()
-
- self._msg_lock = asyncio.Lock()
+ self._socket = control_socket
self._status_listeners = [] # type: List[Tuple[Callable[[stem.control.BaseController, stem.control.State, float], None], bool]] # tuples of the form (callback, spawn_thread)
self._status_listeners_lock = threading.RLock()
- # queues where incoming messages are directed
- self._reply_queue = asyncio.Queue() # type: asyncio.Queue[Union[stem.response.ControlMessage, stem.ControllerError]]
- self._event_queue = asyncio.Queue() # type: asyncio.Queue[stem.response.ControlMessage]
-
- self._event_notice = asyncio.Event()
-
# saves our socket's prior _connect() and _close() methods so they can be
# called along with ours
@@ -650,11 +593,22 @@ class BaseController(_BaseControllerSocketMixin):
self._reader_loop_task = None # type: Optional[asyncio.Task]
self._event_loop_task = None # type: Optional[asyncio.Task]
+
if self._socket.is_alive():
self._create_loop_tasks()
if is_authenticated:
- self._asyncio_loop.create_task(self._post_authentication())
+ self._loop.create_task(self._post_authentication())
+
+ def __ainit__(self) -> None:
+ self._msg_lock = asyncio.Lock()
+
+ # queues where incoming messages are directed
+
+ self._reply_queue = asyncio.Queue() # type: asyncio.Queue[Union[stem.response.ControlMessage, stem.ControllerError]]
+ self._event_queue = asyncio.Queue() # type: asyncio.Queue[stem.response.ControlMessage]
+
+ self._event_notice = asyncio.Event()
async def msg(self, message: str) -> stem.response.ControlMessage:
"""
@@ -736,8 +690,43 @@ class BaseController(_BaseControllerSocketMixin):
# provide an assurance to the caller that when we raise a SocketClosed
# exception we are shut down afterward for realz.
- await self.close()
- raise
+ await self.close()
+ raise
+
+ def is_alive(self) -> bool:
+ """
+ Checks if our socket is currently connected. This is a pass-through for our
+ socket's :func:`~stem.socket.BaseSocket.is_alive` method.
+
+ :returns: **bool** that's **True** if our socket is connected and **False** otherwise
+ """
+
+ return self._socket.is_alive()
+
+ def is_localhost(self) -> bool:
+ """
+ Returns if the connection is for the local system or not.
+
+ .. versionadded:: 1.3.0
+
+ :returns: **bool** that's **True** if the connection is for the local host and **False** otherwise
+ """
+
+ return self._socket.is_localhost()
+
+ def connection_time(self) -> float:
+ """
+ Provides the unix timestamp for when our socket was either connected or
+ disconnected. That is to say, the time we connected if we're currently
+ connected and the time we disconnected if we're not connected.
+
+ .. versionadded:: 1.3.0
+
+ :returns: **float** for when we last connected or disconnected, zero if
+ we've never connected
+ """
+
+ return self._socket.connection_time()
def is_authenticated(self) -> bool:
"""
@@ -778,6 +767,18 @@ class BaseController(_BaseControllerSocketMixin):
if t.is_alive() and threading.current_thread() != t:
t.join()
+ self.stop()
+
+ def get_socket(self) -> stem.socket.ControlSocket:
+ """
+ Provides the socket used to speak with the tor process. Communicating with
+ the socket directly isn't advised since it may confuse this controller.
+
+ :returns: :class:`~stem.socket.ControlSocket` we're communicating with
+ """
+
+ return self._socket
+
def get_latest_heartbeat(self) -> float:
"""
Provides the unix timestamp for when we last heard from tor. This is zero
@@ -858,7 +859,7 @@ class BaseController(_BaseControllerSocketMixin):
async def _connect(self) -> None:
self._create_loop_tasks()
- await self._notify_status_listeners(State.INIT, acquire_send_lock=False)
+ await self._notify_status_listeners(State.INIT, acquire_send_lock = False)
await self._socket_connect()
self._is_authenticated = False
@@ -874,13 +875,14 @@ class BaseController(_BaseControllerSocketMixin):
self._reader_loop_task = None
event_loop_task = self._event_loop_task
self._event_loop_task = None
+
if reader_loop_task and self.is_alive():
await reader_loop_task
+
if event_loop_task:
await event_loop_task
- await self._notify_status_listeners(State.CLOSED, acquire_send_lock=False)
-
+ await self._notify_status_listeners(State.CLOSED, acquire_send_lock = False)
await self._socket_close()
async def _post_authentication(self) -> None:
@@ -899,6 +901,7 @@ class BaseController(_BaseControllerSocketMixin):
# need to have it to ensure it doesn't change beneath us.
send_lock = self._socket._get_send_lock()
+
try:
if acquire_send_lock:
await send_lock.acquire()
@@ -944,8 +947,8 @@ class BaseController(_BaseControllerSocketMixin):
them if we're restarted.
"""
- self._reader_loop_task = self._asyncio_loop.create_task(self._reader_loop())
- self._event_loop_task = self._asyncio_loop.create_task(self._event_loop())
+ self._reader_loop_task = self._loop.create_task(self._reader_loop())
+ self._event_loop_task = self._loop.create_task(self._event_loop())
async def _reader_loop(self) -> None:
"""
@@ -1011,21 +1014,24 @@ class BaseController(_BaseControllerSocketMixin):
self._event_notice.clear()
-class AsyncController(BaseController):
+class Controller(BaseController):
"""
Connection with Tor's control socket. This is built on top of the
BaseController and provides a more user friendly API for library users.
"""
- @classmethod
- def from_port(cls, address: str = '127.0.0.1', port: Union[int, str] = 'default') -> 'AsyncController':
+ @classmethod
+ def from_port(cls, address: str = '127.0.0.1', port: Union[int, str] = 'default') -> 'stem.control.Controller':
"""
- Constructs a :class:`~stem.socket.ControlPort` based AsyncController.
+ Constructs a :class:`~stem.socket.ControlPort` based Controller.
If the **port** is **'default'** then this checks on both 9051 (default
for relays) and 9151 (default for the Tor Browser). This default may change
in the future.
+ .. versionchanged:: 1.5.0
+ Use both port 9051 and 9151 by default.
+
:param address: ip address of the controller
:param port: port number of the controller
@@ -1034,13 +1040,31 @@ class AsyncController(BaseController):
:raises: :class:`stem.SocketError` if we're unable to establish a connection
"""
- control_socket = _init_control_port(address, port)
- return cls(control_socket)
+ import stem.connection
+
+ if not stem.util.connection.is_valid_ipv4_address(address):
+ raise ValueError('Invalid IP address: %s' % address)
+ elif port != 'default' and not stem.util.connection.is_valid_port(port):
+ raise ValueError('Invalid port: %s' % port)
+
+ if port == 'default':
+ control_port = stem.connection._connection_for_default_port(address)
+ else:
+ control_port = stem.socket.ControlPort(address, int(port))
+
+ controller = cls(control_port)
- @classmethod
- def from_socket_file(cls, path: str = '/var/run/tor/control') -> 'AsyncController':
+ try:
+ controller.connect()
+ return controller
+ except BaseException:
+ controller.stop()
+ raise
+
+ @classmethod
+ def from_socket_file(cls, path: str = '/var/run/tor/control') -> 'stem.control.Controller':
"""
- Constructs a :class:`~stem.socket.ControlSocketFile` based AsyncController.
+ Constructs a :class:`~stem.socket.ControlSocketFile` based Controller.
:param path: path where the control socket is located
@@ -1049,8 +1073,15 @@ class AsyncController(BaseController):
:raises: :class:`stem.SocketError` if we're unable to establish a connection
"""
- control_socket = _init_control_socket_file(path)
- return cls(control_socket)
+ control_socket = stem.socket.ControlSocketFile(path)
+ controller = cls(control_socket)
+
+ try:
+ controller.connect()
+ return controller
+ except BaseException:
+ controller.stop()
+ raise
def __init__(self, control_socket: stem.socket.ControlSocket, is_authenticated: bool = False) -> None:
self._is_caching_enabled = True
@@ -1062,13 +1093,12 @@ class AsyncController(BaseController):
# mapping of event types to their listeners
self._event_listeners = {} # type: Dict[stem.control.EventType, List[Callable[[stem.response.events.Event], Union[None, Awaitable[None]]]]]
- self._event_listeners_lock = asyncio.Lock()
self._enabled_features = [] # type: List[str]
self._last_address_exc = None # type: Optional[BaseException]
self._last_fingerprint_exc = None # type: Optional[BaseException]
- super(AsyncController, self).__init__(control_socket, is_authenticated)
+ super(Controller, self).__init__(control_socket, is_authenticated)
async def _sighup_listener(event: stem.response.events.SignalEvent) -> None:
if event.signal == Signal.RELOAD:
@@ -1101,11 +1131,16 @@ class AsyncController(BaseController):
self.add_event_listener(_address_changed_listener, EventType.STATUS_SERVER),
)
- self._asyncio_loop.create_task(_add_event_listeners())
+ self._loop.create_task(_add_event_listeners())
+
+ def __ainit__(self) -> None:
+ super(Controller, self).__ainit__()
+
+ self._event_listeners_lock = asyncio.Lock()
async def close(self) -> None:
self.clear_cache()
- await super(AsyncController, self).close()
+ await super(Controller, self).close()
async def authenticate(self, *args: Any, **kwargs: Any) -> None:
"""
@@ -1186,7 +1221,7 @@ class AsyncController(BaseController):
raise stem.ProtocolError('Tor geoip database is unavailable')
elif param == 'address' and self._last_address_exc:
raise self._last_address_exc # we already know we can't resolve an address
- elif param == 'fingerprint' and self._last_fingerprint_exc and self.get_conf('ORPort', None) is None:
+ elif param == 'fingerprint' and self._last_fingerprint_exc and await self.get_conf('ORPort', None) is None:
raise self._last_fingerprint_exc # we already know we're not a relay
# check for cached results
@@ -2082,7 +2117,6 @@ class AsyncController(BaseController):
request += ' ' + ' '.join(['SERVER=%s' % s for s in servers])
response = stem.response._convert_to_single_line(await self.msg(request))
- stem.response.convert('SINGLELINE', response)
if not response.is_ok():
raise stem.ProtocolError('HSFETCH returned unexpected response code: %s' % response.code)
@@ -3778,7 +3812,7 @@ class AsyncController(BaseController):
await self.msg('DROPGUARDS')
async def _post_authentication(self) -> None:
- await super(AsyncController, self)._post_authentication()
+ await super(Controller, self)._post_authentication()
# try to re-attach event listeners to the new instance
@@ -3834,9 +3868,10 @@ class AsyncController(BaseController):
if listener_type == event_type:
for listener in event_listeners:
try:
- potential_coroutine = listener(event)
- if asyncio.iscoroutine(potential_coroutine):
- await potential_coroutine
+ listener_call = listener(event)
+
+ if asyncio.iscoroutine(listener_call):
+ await listener_call
except Exception as exc:
log.warn('Event listener raised an uncaught exception (%s): %s' % (exc, event))
@@ -3883,346 +3918,6 @@ class AsyncController(BaseController):
return (set_events, failed_events)
-def _set_doc_from_async_controller(func: Callable) -> Callable:
- func.__doc__ = getattr(AsyncController, func.__name__).__doc__
- return func
-
-
-class Controller(_BaseControllerSocketMixin, stem.util.AsyncClassWrapper):
- """
- Connection with Tor's control socket. This wraps
- :class:`~stem.control.AsyncController` to provide a synchronous
- interface and for backwards compatibility.
- """
-
- @classmethod
- def from_port(cls, address: str = '127.0.0.1', port: Union[int, str] = 'default') -> 'Controller':
- """
- Constructs a :class:`~stem.socket.ControlPort` based Controller.
-
- If the **port** is **'default'** then this checks on both 9051 (default
- for relays) and 9151 (default for the Tor Browser). This default may change
- in the future.
-
- .. versionchanged:: 1.5.0
- Use both port 9051 and 9151 by default.
-
- :param address: ip address of the controller
- :param port: port number of the controller
-
- :returns: :class:`~stem.control.Controller` attached to the given port
-
- :raises: :class:`stem.SocketError` if we're unable to establish a connection
- """
-
- control_socket = _init_control_port(address, port)
- controller = cls(control_socket)
- controller.connect()
- return controller
-
- @classmethod
- def from_socket_file(cls, path: str = '/var/run/tor/control') -> 'Controller':
- """
- Constructs a :class:`~stem.socket.ControlSocketFile` based Controller.
-
- :param str path: path where the control socket is located
-
- :returns: :class:`~stem.control.Controller` attached to the given socket file
-
- :raises: :class:`stem.SocketError` if we're unable to establish a connection
- """
-
- control_socket = _init_control_socket_file(path)
- controller = cls(control_socket)
- controller.connect()
- return controller
-
- def __init__(
- self,
- control_socket: stem.socket.ControlSocket,
- is_authenticated: bool = False,
- ) -> None:
- # if within an asyncio context use its loop, otherwise spawn our own
-
- try:
- self._loop = asyncio.get_running_loop()
- self._loop_thread = threading.current_thread()
- except RuntimeError:
- self._loop = asyncio.new_event_loop()
- self._loop_thread = threading.Thread(target = self._loop.run_forever, name = 'asyncio')
- self._loop_thread.setDaemon(True)
- self._loop_thread.start()
-
- self._wrapped_instance: AsyncController = self._init_async_class(AsyncController, control_socket, is_authenticated) # type: ignore
- self._socket = self._wrapped_instance._socket
-
- @_set_doc_from_async_controller
- def msg(self, message: str) -> stem.response.ControlMessage:
- return self._execute_async_method('msg', message)
-
- @_set_doc_from_async_controller
- def is_authenticated(self) -> bool:
- return self._wrapped_instance.is_authenticated()
-
- @_set_doc_from_async_controller
- def connect(self) -> None:
- self._execute_async_method('connect')
-
- @_set_doc_from_async_controller
- def reconnect(self, *args: Any, **kwargs: Any) -> None:
- self._execute_async_method('reconnect', *args, **kwargs)
-
- @_set_doc_from_async_controller
- def close(self) -> None:
- self._execute_async_method('close')
-
- @_set_doc_from_async_controller
- def get_latest_heartbeat(self) -> float:
- return self._wrapped_instance.get_latest_heartbeat()
-
- @_set_doc_from_async_controller
- def add_status_listener(self, callback: Callable[['stem.control.BaseController', 'stem.control.State', float], None], spawn: bool = True) -> None:
- self._wrapped_instance.add_status_listener(callback, spawn)
-
- @_set_doc_from_async_controller
- def remove_status_listener(self, callback: Callable[['stem.control.Controller', 'stem.control.State', float], None]) -> bool:
- return self._wrapped_instance.remove_status_listener(callback)
-
- @_set_doc_from_async_controller
- def authenticate(self, *args: Any, **kwargs: Any) -> None:
- self._execute_async_method('authenticate', *args, **kwargs)
-
- @_set_doc_from_async_controller
- def get_info(self, params: Union[str, Sequence[str]], default: Any = UNDEFINED, get_bytes: bool = False) -> Union[str, Dict[str, str]]:
- return self._execute_async_method('get_info', params, default, get_bytes)
-
- @_set_doc_from_async_controller
- def get_version(self, default: Any = UNDEFINED) -> stem.version.Version:
- return self._execute_async_method('get_version', default)
-
- @_set_doc_from_async_controller
- def get_exit_policy(self, default: Any = UNDEFINED) -> stem.exit_policy.ExitPolicy:
- return self._execute_async_method('get_exit_policy', default)
-
- @_set_doc_from_async_controller
- def get_ports(self, listener_type: 'stem.control.Listener', default: Any = UNDEFINED) -> Sequence[int]:
- return self._execute_async_method('get_ports', listener_type, default)
-
- @_set_doc_from_async_controller
- def get_listeners(self, listener_type: 'stem.control.Listener', default: Any = UNDEFINED) -> Sequence[Tuple[str, int]]:
- return self._execute_async_method('get_listeners', listener_type, default)
-
- @_set_doc_from_async_controller
- def get_accounting_stats(self, default: Any = UNDEFINED) -> 'stem.control.AccountingStats':
- return self._execute_async_method('get_accounting_stats', default)
-
- @_set_doc_from_async_controller
- def get_protocolinfo(self, default: Any = UNDEFINED) -> stem.response.protocolinfo.ProtocolInfoResponse:
- return self._execute_async_method('get_protocolinfo', default)
-
- @_set_doc_from_async_controller
- def get_user(self, default: Any = UNDEFINED) -> str:
- return self._execute_async_method('get_user', default)
-
- @_set_doc_from_async_controller
- def get_pid(self, default: Any = UNDEFINED) -> int:
- return self._execute_async_method('get_pid', default)
-
- @_set_doc_from_async_controller
- def get_start_time(self, default: Any = UNDEFINED) -> float:
- return self._execute_async_method('get_start_time', default)
-
- @_set_doc_from_async_controller
- def get_uptime(self, default: Any = UNDEFINED) -> float:
- return self._execute_async_method('get_uptime', default)
-
- @_set_doc_from_async_controller
- def is_user_traffic_allowed(self) -> 'stem.control.UserTrafficAllowed':
- return self._execute_async_method('is_user_traffic_allowed')
-
- @_set_doc_from_async_controller
- def get_microdescriptor(self, relay: Optional[str] = None, default: Any = UNDEFINED) -> stem.descriptor.microdescriptor.Microdescriptor:
- return self._execute_async_method('get_microdescriptor', relay, default)
-
- @_set_doc_from_async_controller
- def get_microdescriptors(self, default: Any = UNDEFINED) -> Iterator[stem.descriptor.microdescriptor.Microdescriptor]:
- return self._execute_async_generator_method('get_microdescriptors', default)
-
- @_set_doc_from_async_controller
- def get_server_descriptor(self, relay: Optional[str] = None, default: Any = UNDEFINED) -> stem.descriptor.server_descriptor.RelayDescriptor:
- return self._execute_async_method('get_server_descriptor', relay, default)
-
- @_set_doc_from_async_controller
- def get_server_descriptors(self, default: Any = UNDEFINED) -> Iterator[stem.descriptor.server_descriptor.RelayDescriptor]:
- return self._execute_async_generator_method('get_server_descriptors', default)
-
- @_set_doc_from_async_controller
- def get_network_status(self, relay: Optional[str] = None, default: Any = UNDEFINED) -> stem.descriptor.router_status_entry.RouterStatusEntryV3:
- return self._execute_async_method('get_network_status', relay, default)
-
- @_set_doc_from_async_controller
- def get_network_statuses(self, default: Any = UNDEFINED) -> Iterator[stem.descriptor.router_status_entry.RouterStatusEntryV3]:
- return self._execute_async_generator_method('get_network_statuses', default)
-
- @_set_doc_from_async_controller
- def get_hidden_service_descriptor(self, address: str, default: Any = UNDEFINED, servers: Optional[Sequence[str]] = None, await_result: bool = True, timeout: Optional[float] = None) -> stem.descriptor.hidden_service.HiddenServiceDescriptorV2:
- return self._execute_async_method('get_hidden_service_descriptor', address, default, servers, await_result, timeout)
-
- @_set_doc_from_async_controller
- def get_conf(self, param: str, default: Any = UNDEFINED, multiple: bool = False) -> Union[str, Sequence[str]]:
- return self._execute_async_method('get_conf', param, default, multiple)
-
- @_set_doc_from_async_controller
- def get_conf_map(self, params: Union[str, Sequence[str]], default: Any = UNDEFINED, multiple: bool = True) -> Dict[str, Union[str, Sequence[str]]]:
- return self._execute_async_method('get_conf_map', params, default, multiple)
-
- @_set_doc_from_async_controller
- def is_set(self, param: str, default: Any = UNDEFINED) -> bool:
- return self._execute_async_method('is_set', param, default)
-
- @_set_doc_from_async_controller
- def set_conf(self, param: str, value: Union[str, Sequence[str]]) -> None:
- self._execute_async_method('set_conf', param, value)
-
- @_set_doc_from_async_controller
- def reset_conf(self, *params: str) -> None:
- self._execute_async_method('reset_conf', *params)
-
- @_set_doc_from_async_controller
- def set_options(self, params: Union[Mapping[str, Union[str, Sequence[str]]], Sequence[Tuple[str, Union[str, Sequence[str]]]]], reset: bool = False) -> None:
- self._execute_async_method('set_options', params, reset)
-
- @_set_doc_from_async_controller
- def get_hidden_service_conf(self, default: Any = UNDEFINED) -> Dict[str, Any]:
- return self._execute_async_method('get_hidden_service_conf', default)
-
- @_set_doc_from_async_controller
- def set_hidden_service_conf(self, conf: Mapping[str, Any]) -> None:
- self._execute_async_method('set_hidden_service_conf', conf)
-
- @_set_doc_from_async_controller
- def create_hidden_service(self, path: str, port: int, target_address: Optional[str] = None, target_port: Optional[int] = None, auth_type: Optional[str] = None, client_names: Optional[Sequence[str]] = None) -> 'stem.control.CreateHiddenServiceOutput':
- return self._execute_async_method('create_hidden_service', path, port, target_address, target_port, auth_type, client_names)
-
- @_set_doc_from_async_controller
- def remove_hidden_service(self, path: str, port: Optional[int] = None) -> bool:
- return self._execute_async_method('remove_hidden_service', path, port)
-
- @_set_doc_from_async_controller
- def list_ephemeral_hidden_services(self, default: Any = UNDEFINED, our_services: bool = True, detached: bool = False) -> Sequence[str]:
- return self._execute_async_method('list_ephemeral_hidden_services', default, our_services, detached)
-
- @_set_doc_from_async_controller
- def create_ephemeral_hidden_service(self, ports: Union[int, Sequence[int], Mapping[int, str]], key_type: str = 'NEW', key_content: str = 'BEST', discard_key: bool = False, detached: bool = False, await_publication: bool = False, timeout: Optional[float] = None, basic_auth: Optional[Mapping[str, str]] = None, max_streams: Optional[int] = None) -> stem.response.add_onion.AddOnionResponse:
- return self._execute_async_method('create_ephemeral_hidden_service', ports, key_type, key_content, discard_key, detached, await_publication, timeout, basic_auth, max_streams)
-
- @_set_doc_from_async_controller
- def remove_ephemeral_hidden_service(self, service_id: str) -> bool:
- return self._execute_async_method('remove_ephemeral_hidden_service', service_id)
-
- @_set_doc_from_async_controller
- def add_event_listener(self, listener: Callable[[stem.response.events.Event], Union[None, Awaitable[None]]], *events: 'stem.control.EventType') -> None:
- self._execute_async_method('add_event_listener', listener, *events)
-
- @_set_doc_from_async_controller
- def remove_event_listener(self, listener: Callable[[stem.response.events.Event], Union[None, Awaitable[None]]]) -> None:
- self._execute_async_method('remove_event_listener', listener)
-
- @_set_doc_from_async_controller
- def is_caching_enabled(self) -> bool:
- return self._wrapped_instance.is_caching_enabled()
-
- @_set_doc_from_async_controller
- def set_caching(self, enabled: bool) -> None:
- self._wrapped_instance.set_caching(enabled)
-
- @_set_doc_from_async_controller
- def clear_cache(self) -> None:
- self._wrapped_instance.clear_cache()
-
- @_set_doc_from_async_controller
- def load_conf(self, configtext: str) -> None:
- self._execute_async_method('load_conf', configtext)
-
- @_set_doc_from_async_controller
- def save_conf(self, force: bool = False) -> None:
- return self._execute_async_method('save_conf', force)
-
- @_set_doc_from_async_controller
- def is_feature_enabled(self, feature: str) -> bool:
- return self._wrapped_instance.is_feature_enabled(feature)
-
- @_set_doc_from_async_controller
- def enable_feature(self, features: Union[str, Sequence[str]]) -> None:
- self._wrapped_instance.enable_feature(features)
-
- @_set_doc_from_async_controller
- def get_circuit(self, circuit_id: int, default: Any = UNDEFINED) -> stem.response.events.CircuitEvent:
- return self._execute_async_method('get_circuit', circuit_id, default)
-
- @_set_doc_from_async_controller
- def get_circuits(self, default: Any = UNDEFINED) -> List[stem.response.events.CircuitEvent]:
- return self._execute_async_method('get_circuits', default)
-
- @_set_doc_from_async_controller
- def new_circuit(self, path: Union[None, str, Sequence[str]] = None, purpose: str = 'general', await_build: bool = False, timeout: Optional[float] = None) -> str:
- return self._execute_async_method('new_circuit', path, purpose, await_build, timeout)
-
- @_set_doc_from_async_controller
- def extend_circuit(self, circuit_id: str = '0', path: Union[None, str, Sequence[str]] = None, purpose: str = 'general', await_build: bool = False, timeout: Optional[float] = None) -> str:
- return self._execute_async_method('extend_circuit', circuit_id, path, purpose, await_build, timeout)
-
- @_set_doc_from_async_controller
- def repurpose_circuit(self, circuit_id: str, purpose: str) -> None:
- self._execute_async_method('repurpose_circuit', circuit_id, purpose)
-
- @_set_doc_from_async_controller
- def close_circuit(self, circuit_id: str, flag: str = '') -> None:
- self._execute_async_method('close_circuit', circuit_id, flag)
-
- @_set_doc_from_async_controller
- def get_streams(self, default: Any = UNDEFINED) -> List[stem.response.events.StreamEvent]:
- return self._execute_async_method('get_streams', default)
-
- @_set_doc_from_async_controller
- def attach_stream(self, stream_id: str, circuit_id: str, exiting_hop: Optional[int] = None) -> None:
- self._execute_async_method('attach_stream', stream_id, circuit_id, exiting_hop)
-
- @_set_doc_from_async_controller
- def close_stream(self, stream_id: str, reason: stem.RelayEndReason = stem.RelayEndReason.MISC, flag: str = '') -> None:
- self._execute_async_method('close_stream', stream_id, reason, flag)
-
- @_set_doc_from_async_controller
- def signal(self, signal: stem.Signal) -> None:
- self._execute_async_method('signal', signal)
-
- @_set_doc_from_async_controller
- def is_newnym_available(self) -> bool:
- return self._wrapped_instance.is_newnym_available()
-
- @_set_doc_from_async_controller
- def get_newnym_wait(self) -> float:
- return self._wrapped_instance.get_newnym_wait()
-
- @_set_doc_from_async_controller
- def get_effective_rate(self, default: Any = UNDEFINED, burst: bool = False) -> int:
- return self._execute_async_method('get_effective_rate', default, burst)
-
- @_set_doc_from_async_controller
- def map_address(self, mapping: Mapping[str, str]) -> Dict[str, str]:
- return self._execute_async_method('map_address', mapping)
-
- @_set_doc_from_async_controller
- def drop_guards(self) -> None:
- self._execute_async_method('drop_guards')
-
- def __enter__(self) -> 'stem.control.Controller':
- return self
-
- def __exit__(self, exit_type: Optional[Type[BaseException]], value: Optional[BaseException], traceback: Optional[TracebackType]) -> None:
- self.close()
-
-
def _parse_circ_path(path: str) -> Sequence[Tuple[str, str]]:
"""
Parses a circuit path as a list of **(fingerprint, nickname)** tuples. Tor
@@ -4342,26 +4037,6 @@ async def _get_with_timeout(event_queue: asyncio.Queue, timeout: Optional[float]
time_left = None
try:
- return await asyncio.wait_for(event_queue.get(), timeout=time_left)
+ return await asyncio.wait_for(event_queue.get(), timeout = time_left)
except asyncio.TimeoutError:
raise stem.Timeout('Reached our %0.1f second timeout' % timeout)
-
-
-def _init_control_port(address: str, port: Union[int, str]) -> stem.socket.ControlPort:
- import stem.connection
-
- if not stem.util.connection.is_valid_ipv4_address(address):
- raise ValueError('Invalid IP address: %s' % address)
- elif port != 'default' and not stem.util.connection.is_valid_port(port):
- raise ValueError('Invalid port: %s' % port)
-
- if port == 'default':
- control_port = stem.connection._connection_for_default_port(address)
- else:
- control_port = stem.socket.ControlPort(address, int(port))
-
- return control_port
-
-
-def _init_control_socket_file(path: str) -> stem.socket.ControlSocketFile:
- return stem.socket.ControlSocketFile(path)
diff --git a/stem/descriptor/remote.py b/stem/descriptor/remote.py
index 942d81e9..3428f0d2 100644
--- a/stem/descriptor/remote.py
+++ b/stem/descriptor/remote.py
@@ -445,13 +445,13 @@ class Query(Synchronous):
loop = asyncio.get_running_loop()
self._downloader_task = loop.create_task(self._download_descriptors(self.retries, self.timeout))
- async def run(self, suppress: bool = False, close: bool = True) -> List['stem.descriptor.Descriptor']:
+ async def run(self, suppress: bool = False, stop: bool = True) -> List['stem.descriptor.Descriptor']:
"""
Blocks until our request is complete then provides the descriptors. If we
haven't yet started our request then this does so.
:param suppress: avoids raising exceptions if **True**
- :param close: terminates the resources backing this query if **True**,
+ :param stop: terminates the resources backing this query if **True**,
further method calls will raise a RuntimeError
:returns: list for the requested :class:`~stem.descriptor.__init__.Descriptor` instances
@@ -465,14 +465,14 @@ class Query(Synchronous):
* :class:`~stem.DownloadFailed` if our request fails
"""
- # TODO: We should replace our 'close' argument with a new API design prior
+ # TODO: We should replace our 'stop' argument with a new API design prior
# to release. Self-destructing this object by default for synchronous users
# is quite a step backward, but is acceptable as we iterate on this.
try:
return [desc async for desc in self._run(suppress)]
finally:
- if close:
+ if stop:
self._loop.call_soon_threadsafe(self._loop.stop)
async def _run(self, suppress: bool) -> AsyncIterator[stem.descriptor.Descriptor]:
diff --git a/stem/interpreter/__init__.py b/stem/interpreter/__init__.py
index 370b9aa6..872e7f6f 100644
--- a/stem/interpreter/__init__.py
+++ b/stem/interpreter/__init__.py
@@ -127,7 +127,7 @@ def main() -> None:
async def handle_event(event_message: stem.response.ControlMessage) -> None:
print(format(str(event_message), *STANDARD_OUTPUT))
- controller._wrapped_instance._handle_event = handle_event # type: ignore
+ controller._handle_event = handle_event # type: ignore
if sys.stdout.isatty():
events = args.run_cmd.upper().split(' ', 1)[1]
diff --git a/stem/interpreter/commands.py b/stem/interpreter/commands.py
index 99f1219d..b04fcc85 100644
--- a/stem/interpreter/commands.py
+++ b/stem/interpreter/commands.py
@@ -128,7 +128,7 @@ class ControlInterpreter(code.InteractiveConsole):
# Intercept events our controller hears about at a pretty low level since
# the user will likely be requesting them by direct 'SETEVENTS' calls.
- handle_event_real = self._controller._wrapped_instance._handle_event
+ handle_event_real = self._controller._handle_event
async def handle_event_wrapper(event_message: stem.response.ControlMessage) -> None:
await handle_event_real(event_message)
@@ -139,7 +139,7 @@ class ControlInterpreter(code.InteractiveConsole):
# type check disabled due to https://github.com/python/mypy/issues/708
- self._controller._wrapped_instance._handle_event = handle_event_wrapper # type: ignore
+ self._controller._handle_event = handle_event_wrapper # type: ignore
def get_events(self, *event_types: stem.control.EventType) -> List[stem.response.events.Event]:
events = list(self._received_events)
diff --git a/stem/util/__init__.py b/stem/util/__init__.py
index d780a0de..de946fd9 100644
--- a/stem/util/__init__.py
+++ b/stem/util/__init__.py
@@ -13,7 +13,6 @@ import threading
import typing
import unittest.mock
-from concurrent.futures import Future
from types import TracebackType
from typing import Any, AsyncIterator, Iterator, Optional, Type, Union
@@ -211,6 +210,7 @@ class Synchronous(object):
self._no_op = Synchronous.is_asyncio_context()
if self._no_op:
+ self._loop = asyncio.get_running_loop()
self.__ainit__() # this is already an asyncio context
else:
# Run coroutines through our loop. This calls methods by name rather than
@@ -361,44 +361,3 @@ class Synchronous(object):
def __exit__(self, exit_type: Optional[Type[BaseException]], value: Optional[BaseException], traceback: Optional[TracebackType]):
return self._run_async_method('__aexit__', exit_type, value, traceback)
-
-
-class AsyncClassWrapper:
- _loop: asyncio.AbstractEventLoop
- _loop_thread: threading.Thread
- _wrapped_instance: type
-
- def _init_async_class(self, async_class: Type, *args: Any, **kwargs: Any) -> Any:
- # The asynchronous class should be initialized in the thread where
- # its methods will be executed.
- if self._loop_thread != threading.current_thread():
- async def init():
- return async_class(*args, **kwargs)
-
- return asyncio.run_coroutine_threadsafe(init(), self._loop).result()
-
- return async_class(*args, **kwargs)
-
- def _call_async_method_soon(self, method_name: str, *args: Any, **kwargs: Any) -> Future:
- return asyncio.run_coroutine_threadsafe(
- getattr(self._wrapped_instance, method_name)(*args, **kwargs),
- self._loop,
- )
-
- def _execute_async_method(self, method_name: str, *args: Any, **kwargs: Any) -> Any:
- return self._call_async_method_soon(method_name, *args, **kwargs).result()
-
- def _execute_async_generator_method(self, method_name: str, *args: Any, **kwargs: Any) -> Iterator:
- async def convert_async_generator(generator: AsyncIterator) -> Iterator:
- return iter([d async for d in generator])
-
- return asyncio.run_coroutine_threadsafe(
- convert_async_generator(
- getattr(self._wrapped_instance, method_name)(*args, **kwargs),
- ),
- self._loop,
- ).result()
-
- def __del__(self) -> None:
- self._loop.call_soon_threadsafe(self._loop.stop)
- self._loop_thread.join()
diff --git a/stem/util/test_tools.py b/stem/util/test_tools.py
index 67133195..ac9f8b88 100644
--- a/stem/util/test_tools.py
+++ b/stem/util/test_tools.py
@@ -696,11 +696,14 @@ def async_test(func: Callable) -> Callable:
@functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
loop = asyncio.new_event_loop()
+
try:
result = loop.run_until_complete(func(*args, **kwargs))
finally:
loop.close()
+
return result
+
return wrapper
diff --git a/test/integ/connection/authentication.py b/test/integ/connection/authentication.py
index 042d3939..bbca5e43 100644
--- a/test/integ/connection/authentication.py
+++ b/test/integ/connection/authentication.py
@@ -3,7 +3,6 @@ Integration tests for authenticating to the control socket via
stem.connection.authenticate* functions.
"""
-import asyncio
import os
import unittest
@@ -121,11 +120,8 @@ class TestAuthenticate(unittest.TestCase):
runner = test.runner.get_runner()
- with runner.get_tor_controller(False) as controller:
- asyncio.run_coroutine_threadsafe(
- stem.connection.authenticate(controller._wrapped_instance, test.runner.CONTROL_PASSWORD, runner.get_chroot()),
- controller._loop,
- ).result()
+ async with await runner.get_tor_controller(False) as controller:
+ await stem.connection.authenticate(controller, test.runner.CONTROL_PASSWORD, runner.get_chroot())
await test.runner.exercise_controller(self, controller)
@test.require.controller
@@ -276,7 +272,8 @@ class TestAuthenticate(unittest.TestCase):
await self._check_auth(auth_type, auth_value)
@test.require.controller
- def test_wrong_password_with_controller(self):
+ @async_test
+ async def test_wrong_password_with_controller(self):
"""
We ran into a race condition where providing the wrong password to the
Controller caused inconsistent responses. Checking for that...
@@ -290,7 +287,7 @@ class TestAuthenticate(unittest.TestCase):
self.skipTest('(requires only password auth)')
for i in range(10):
- with runner.get_tor_controller(False) as controller:
+ async with await runner.get_tor_controller(False) as controller:
with self.assertRaises(stem.connection.IncorrectPassword):
- controller.authenticate('wrong_password')
+ await controller.authenticate('wrong_password')
diff --git a/test/integ/control/controller.py b/test/integ/control/controller.py
index 7853d407..33b62ea7 100644
--- a/test/integ/control/controller.py
+++ b/test/integ/control/controller.py
@@ -7,7 +7,6 @@ import os
import shutil
import socket
import tempfile
-import threading
import time
import unittest
@@ -38,24 +37,25 @@ TEST_ROUTER_STATUS_ENTRY = None
class TestController(unittest.TestCase):
@test.require.only_run_once
@test.require.controller
- def test_missing_capabilities(self):
+ @async_test
+ async def test_missing_capabilities(self):
"""
Check to see if tor supports any events, signals, or features that we
don't.
"""
- with test.runner.get_runner().get_tor_controller() as controller:
- for event in controller.get_info('events/names').split():
+ async with await test.runner.get_runner().get_tor_controller() as controller:
+ for event in (await controller.get_info('events/names')).split():
if event not in EventType:
test.register_new_capability('Event', event)
- for signal in controller.get_info('signal/names').split():
+ for signal in (await controller.get_info('signal/names')).split():
if signal not in Signal:
test.register_new_capability('Signal', signal)
# new features should simply be added to enable_feature()'s docs
- for feature in controller.get_info('features/names').split():
+ for feature in (await controller.get_info('features/names')).split():
if feature not in ('EXTENDED_EVENTS', 'VERBOSE_NAMES'):
test.register_new_capability('Feature', feature)
@@ -88,7 +88,7 @@ class TestController(unittest.TestCase):
Checks that a notificiation listener is... well, notified of SIGHUPs.
"""
- with test.runner.get_runner().get_tor_controller() as controller:
+ async with await test.runner.get_runner().get_tor_controller() as controller:
received_events = []
def status_listener(my_controller, state, timestamp):
@@ -97,7 +97,7 @@ class TestController(unittest.TestCase):
controller.add_status_listener(status_listener)
before = time.time()
- controller.signal(Signal.HUP)
+ await controller.signal(Signal.HUP)
# I really hate adding a sleep here, but signal() is non-blocking.
while len(received_events) == 0:
@@ -112,20 +112,21 @@ class TestController(unittest.TestCase):
state_controller, state_type, state_timestamp = received_events[0]
- self.assertEqual(controller._wrapped_instance, state_controller)
+ self.assertEqual(controller, state_controller)
self.assertEqual(State.RESET, state_type)
self.assertTrue(state_timestamp > before and state_timestamp < after)
- controller.reset_conf('__OwningControllerProcess')
+ await controller.reset_conf('__OwningControllerProcess')
@test.require.controller
- def test_event_handling(self):
+ @async_test
+ async def test_event_handling(self):
"""
Add a couple listeners for various events and make sure that they receive
them. Then remove the listeners.
"""
- event_notice1, event_notice2 = threading.Event(), threading.Event()
+ event_notice1, event_notice2 = asyncio.Event(), asyncio.Event()
event_buffer1, event_buffer2 = [], []
def listener1(event):
@@ -138,30 +139,30 @@ class TestController(unittest.TestCase):
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
- controller.add_event_listener(listener1, EventType.CONF_CHANGED)
- controller.add_event_listener(listener2, EventType.CONF_CHANGED, EventType.DEBUG)
+ async with await runner.get_tor_controller() as controller:
+ await controller.add_event_listener(listener1, EventType.CONF_CHANGED)
+ await controller.add_event_listener(listener2, EventType.CONF_CHANGED, EventType.DEBUG)
# The NodeFamily is a harmless option we can toggle
- controller.set_conf('NodeFamily', 'FD4CC275C5AA4D27A487C6CA29097900F85E2C33')
+ await controller.set_conf('NodeFamily', 'FD4CC275C5AA4D27A487C6CA29097900F85E2C33')
# Wait for the event. Assert that we get it within 10 seconds
- event_notice1.wait(10)
+ await asyncio.wait_for(event_notice1.wait(), timeout = 10)
self.assertEqual(len(event_buffer1), 1)
event_notice1.clear()
- event_notice2.wait(10)
+ await asyncio.wait_for(event_notice2.wait(), timeout = 10)
self.assertTrue(len(event_buffer2) >= 1)
event_notice2.clear()
# Checking that a listener's no longer called after being removed.
- controller.remove_event_listener(listener2)
+ await controller.remove_event_listener(listener2)
buffer2_size = len(event_buffer2)
- controller.set_conf('NodeFamily', 'A82F7EFDB570F6BC801805D0328D30A99403C401')
- event_notice1.wait(10)
+ await controller.set_conf('NodeFamily', 'A82F7EFDB570F6BC801805D0328D30A99403C401')
+ await asyncio.wait_for(event_notice1.wait(), timeout = 10)
self.assertEqual(len(event_buffer1), 2)
event_notice1.clear()
@@ -174,16 +175,17 @@ class TestController(unittest.TestCase):
self.assertTrue(isinstance(event, stem.response.events.ConfChangedEvent))
- controller.reset_conf('NodeFamily')
+ await controller.reset_conf('NodeFamily')
@test.require.controller
- def test_reattaching_listeners(self):
+ @async_test
+ async def test_reattaching_listeners(self):
"""
Checks that event listeners are re-attached when a controller disconnects
then reconnects to tor.
"""
- event_notice = threading.Event()
+ event_notice = asyncio.Event()
event_buffer = []
def listener(event):
@@ -192,79 +194,85 @@ class TestController(unittest.TestCase):
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
- controller.add_event_listener(listener, EventType.CONF_CHANGED)
+ async with await runner.get_tor_controller() as controller:
+ await controller.add_event_listener(listener, EventType.CONF_CHANGED)
# trigger an event
- controller.set_conf('NodeFamily', 'FD4CC275C5AA4D27A487C6CA29097900F85E2C33')
- event_notice.wait(4)
+ await controller.set_conf('NodeFamily', 'FD4CC275C5AA4D27A487C6CA29097900F85E2C33')
+ await asyncio.wait_for(event_notice.wait(), timeout = 4)
self.assertTrue(len(event_buffer) >= 1)
# disconnect, then reconnect and check that we get events again
- controller.close()
+ await controller.close()
event_notice.clear()
event_buffer = []
- controller.connect()
- controller.authenticate(password = test.runner.CONTROL_PASSWORD)
+ await controller.connect()
+ await controller.authenticate(password = test.runner.CONTROL_PASSWORD)
self.assertTrue(len(event_buffer) == 0)
- controller.set_conf('NodeFamily', 'A82F7EFDB570F6BC801805D0328D30A99403C401')
+ await controller.set_conf('NodeFamily', 'A82F7EFDB570F6BC801805D0328D30A99403C401')
- event_notice.wait(4)
+ await asyncio.wait_for(event_notice.wait(), timeout = 4)
self.assertTrue(len(event_buffer) >= 1)
- controller.reset_conf('NodeFamily')
+ await controller.reset_conf('NodeFamily')
@test.require.controller
- def test_getinfo(self):
+ @async_test
+ async def test_getinfo(self):
"""
Exercises GETINFO with valid and invalid queries.
"""
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
+ async with await runner.get_tor_controller() as controller:
# successful single query
torrc_path = runner.get_torrc_path()
- self.assertEqual(torrc_path, controller.get_info('config-file'))
- self.assertEqual(torrc_path, controller.get_info('config-file', 'ho hum'))
+ self.assertEqual(torrc_path, await controller.get_info('config-file'))
+ self.assertEqual(torrc_path, await controller.get_info('config-file', 'ho hum'))
expected = {'config-file': torrc_path}
- self.assertEqual(expected, controller.get_info(['config-file']))
- self.assertEqual(expected, controller.get_info(['config-file'], 'ho hum'))
+ self.assertEqual(expected, await controller.get_info(['config-file']))
+ self.assertEqual(expected, await controller.get_info(['config-file'], 'ho hum'))
# successful batch query, we don't know the values so just checking for
# the keys
getinfo_params = set(['version', 'config-file', 'config/names'])
- self.assertEqual(getinfo_params, set(controller.get_info(['version', 'config-file', 'config/names']).keys()))
+ self.assertEqual(getinfo_params, set((await controller.get_info(['version', 'config-file', 'config/names'])).keys()))
# non-existant option
- self.assertRaises(stem.ControllerError, controller.get_info, 'blarg')
- self.assertEqual('ho hum', controller.get_info('blarg', 'ho hum'))
+ with self.assertRaises(stem.ControllerError):
+ await controller.get_info('blarg')
+
+ self.assertEqual('ho hum', await controller.get_info('blarg', 'ho hum'))
# empty input
- self.assertRaises(stem.ControllerError, controller.get_info, '')
- self.assertEqual('ho hum', controller.get_info('', 'ho hum'))
+ with self.assertRaises(stem.ControllerError):
+ await controller.get_info('')
+
+ self.assertEqual('ho hum', await controller.get_info('', 'ho hum'))
- self.assertEqual({}, controller.get_info([]))
- self.assertEqual({}, controller.get_info([], {}))
+ self.assertEqual({}, await controller.get_info([]))
+ self.assertEqual({}, await controller.get_info([], {}))
@test.require.controller
- def test_getinfo_freshrelaydescs(self):
+ @async_test
+ async def test_getinfo_freshrelaydescs(self):
"""
Exercises 'GETINFO status/fresh-relay-descs'.
"""
- with test.runner.get_runner().get_tor_controller() as controller:
- response = controller.get_info('status/fresh-relay-descs')
+ async with await test.runner.get_runner().get_tor_controller() as controller:
+ response = await controller.get_info('status/fresh-relay-descs')
div = response.find('\nextra-info ')
- nickname = controller.get_conf('Nickname')
+ nickname = await controller.get_conf('Nickname')
if div == -1:
self.fail('GETINFO response should have both a server and extrainfo descriptor:\n%s' % response)
@@ -274,44 +282,47 @@ class TestController(unittest.TestCase):
self.assertEqual(nickname, server_desc.nickname)
self.assertEqual(nickname, extrainfo_desc.nickname)
- self.assertEqual(controller.get_info('address'), server_desc.address)
+ self.assertEqual(await controller.get_info('address'), server_desc.address)
self.assertEqual(test.runner.ORPORT, server_desc.or_port)
@test.require.controller
@test.require.online
- def test_getinfo_dir_status(self):
+ @async_test
+ async def test_getinfo_dir_status(self):
"""
Exercise 'GETINFO dir/status-vote/*'.
"""
- with test.runner.get_runner().get_tor_controller() as controller:
- consensus = controller.get_info('dir/status-vote/current/consensus')
+ async with await test.runner.get_runner().get_tor_controller() as controller:
+ consensus = await controller.get_info('dir/status-vote/current/consensus')
self.assertTrue('moria1' in consensus, 'moria1 not found in the consensus')
if test.tor_version() >= stem.version.Version('0.4.3.1-alpha'):
- microdescs = controller.get_info('dir/status-vote/current/consensus-microdesc')
+ microdescs = await controller.get_info('dir/status-vote/current/consensus-microdesc')
self.assertTrue('moria1' in microdescs, 'moria1 not found in the microdescriptor consensus')
@test.require.controller
- def test_get_version(self):
+ @async_test
+ async def test_get_version(self):
"""
Test that the convenient method get_version() works.
"""
- with test.runner.get_runner().get_tor_controller() as controller:
- version = controller.get_version()
+ async with await test.runner.get_runner().get_tor_controller() as controller:
+ version = await controller.get_version()
self.assertTrue(isinstance(version, stem.version.Version))
self.assertEqual(version, test.tor_version())
@test.require.controller
- def test_get_exit_policy(self):
+ @async_test
+ async def test_get_exit_policy(self):
"""
Sanity test for get_exit_policy(). Our 'ExitRelay 0' torrc entry causes us
to have a simple reject-all policy.
"""
- with test.runner.get_runner().get_tor_controller() as controller:
- self.assertEqual(ExitPolicy('reject *:*'), controller.get_exit_policy())
+ async with await test.runner.get_runner().get_tor_controller() as controller:
+ self.assertEqual(ExitPolicy('reject *:*'), await controller.get_exit_policy())
@test.require.controller
@async_test
@@ -322,20 +333,21 @@ class TestController(unittest.TestCase):
runner = test.runner.get_runner()
- with runner.get_tor_controller(False) as controller:
- controller.authenticate(test.runner.CONTROL_PASSWORD)
+ async with await runner.get_tor_controller(False) as controller:
+ await controller.authenticate(test.runner.CONTROL_PASSWORD)
await test.runner.exercise_controller(self, controller)
@test.require.controller
- def test_protocolinfo(self):
+ @async_test
+ async def test_protocolinfo(self):
"""
Test that the convenient method protocolinfo() works.
"""
runner = test.runner.get_runner()
- with runner.get_tor_controller(False) as controller:
- protocolinfo = controller.get_protocolinfo()
+ async with await runner.get_tor_controller(False) as controller:
+ protocolinfo = await controller.get_protocolinfo()
self.assertTrue(isinstance(protocolinfo, stem.response.protocolinfo.ProtocolInfoResponse))
# Doing a sanity test on the ProtocolInfoResponse instance returned.
@@ -355,14 +367,15 @@ class TestController(unittest.TestCase):
self.assertEqual(tuple(auth_methods), protocolinfo.auth_methods)
@test.require.controller
- def test_getconf(self):
+ @async_test
+ async def test_getconf(self):
"""
Exercises GETCONF with valid and invalid queries.
"""
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
+ async with await runner.get_tor_controller() as controller:
control_socket = controller.get_socket()
if isinstance(control_socket, stem.socket.ControlPort):
@@ -373,79 +386,89 @@ class TestController(unittest.TestCase):
config_key = 'ControlSocket'
# successful single query
- self.assertEqual(connection_value, controller.get_conf(config_key))
- self.assertEqual(connection_value, controller.get_conf(config_key, 'la-di-dah'))
+ self.assertEqual(connection_value, await controller.get_conf(config_key))
+ self.assertEqual(connection_value, await controller.get_conf(config_key, 'la-di-dah'))
# succeessful batch query
expected = {config_key: [connection_value]}
- self.assertEqual(expected, controller.get_conf_map([config_key]))
- self.assertEqual(expected, controller.get_conf_map([config_key], 'la-di-dah'))
+ self.assertEqual(expected, await controller.get_conf_map([config_key]))
+ self.assertEqual(expected, await controller.get_conf_map([config_key], 'la-di-dah'))
request_params = ['ControlPORT', 'dirport', 'datadirectory']
- reply_params = controller.get_conf_map(request_params, multiple=False).keys()
+ reply_params = (await controller.get_conf_map(request_params, multiple=False)).keys()
self.assertEqual(set(request_params), set(reply_params))
# queries an option that is unset
- self.assertEqual(None, controller.get_conf('HTTPSProxy'))
- self.assertEqual('la-di-dah', controller.get_conf('HTTPSProxy', 'la-di-dah'))
- self.assertEqual([], controller.get_conf('HTTPSProxy', [], multiple = True))
+ self.assertEqual(None, await controller.get_conf('HTTPSProxy'))
+ self.assertEqual('la-di-dah', await controller.get_conf('HTTPSProxy', 'la-di-dah'))
+ self.assertEqual([], await controller.get_conf('HTTPSProxy', [], multiple = True))
# non-existant option(s)
- self.assertRaises(stem.InvalidArguments, controller.get_conf, 'blarg')
- self.assertEqual('la-di-dah', controller.get_conf('blarg', 'la-di-dah'))
- self.assertRaises(stem.InvalidArguments, controller.get_conf_map, 'blarg')
- self.assertEqual({'blarg': 'la-di-dah'}, controller.get_conf_map('blarg', 'la-di-dah'))
- self.assertRaises(stem.InvalidRequest, controller.get_conf_map, ['blarg', 'huadf'], multiple = True)
- self.assertEqual({'erfusdj': 'la-di-dah', 'afiafj': 'la-di-dah'}, controller.get_conf_map(['erfusdj', 'afiafj'], 'la-di-dah', multiple = True))
+ with self.assertRaises(stem.InvalidArguments):
+ await controller.get_conf('blarg')
+
+ self.assertEqual('la-di-dah', await controller.get_conf('blarg', 'la-di-dah'))
+
+ with self.assertRaises(stem.InvalidArguments):
+ await controller.get_conf_map('blarg')
+
+ self.assertEqual({'blarg': 'la-di-dah'}, await controller.get_conf_map('blarg', 'la-di-dah'))
+
+ with self.assertRaises(stem.InvalidRequest):
+ await controller.get_conf_map(['blarg', 'huadf'], multiple = True)
+
+ self.assertEqual({'erfusdj': 'la-di-dah', 'afiafj': 'la-di-dah'}, await controller.get_conf_map(['erfusdj', 'afiafj'], 'la-di-dah', multiple = True))
# multivalue configuration keys
nodefamilies = [('abc', 'xyz', 'pqrs'), ('mno', 'tuv', 'wxyz')]
- controller.msg('SETCONF %s' % ' '.join(['nodefamily="' + ','.join(x) + '"' for x in nodefamilies]))
- self.assertEqual([','.join(n) for n in nodefamilies], controller.get_conf('nodefamily', multiple = True))
- controller.msg('RESETCONF NodeFamily')
+ await controller.msg('SETCONF %s' % ' '.join(['nodefamily="' + ','.join(x) + '"' for x in nodefamilies]))
+ self.assertEqual([','.join(n) for n in nodefamilies], await controller.get_conf('nodefamily', multiple = True))
+ await controller.msg('RESETCONF NodeFamily')
# empty input
- self.assertEqual(None, controller.get_conf(''))
- self.assertEqual({}, controller.get_conf_map([]))
- self.assertEqual({}, controller.get_conf_map(['']))
- self.assertEqual(None, controller.get_conf(' '))
- self.assertEqual({}, controller.get_conf_map([' ', ' ']))
+ self.assertEqual(None, await controller.get_conf(''))
+ self.assertEqual({}, await controller.get_conf_map([]))
+ self.assertEqual({}, await controller.get_conf_map(['']))
+ self.assertEqual(None, await controller.get_conf(' '))
+ self.assertEqual({}, await controller.get_conf_map([' ', ' ']))
- self.assertEqual('la-di-dah', controller.get_conf('', 'la-di-dah'))
- self.assertEqual({}, controller.get_conf_map('', 'la-di-dah'))
- self.assertEqual({}, controller.get_conf_map([], 'la-di-dah'))
+ self.assertEqual('la-di-dah', await controller.get_conf('', 'la-di-dah'))
+ self.assertEqual({}, await controller.get_conf_map('', 'la-di-dah'))
+ self.assertEqual({}, await controller.get_conf_map([], 'la-di-dah'))
@test.require.controller
- def test_is_set(self):
+ @async_test
+ async def test_is_set(self):
"""
Exercises our is_set() method.
"""
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
- custom_options = controller._execute_async_method('_get_custom_options')
+ async with await runner.get_tor_controller() as controller:
+ custom_options = await controller._get_custom_options()
self.assertTrue('ControlPort' in custom_options or 'ControlSocket' in custom_options)
self.assertEqual('1', custom_options['DownloadExtraInfo'])
self.assertEqual('1112', custom_options['SocksPort'])
- self.assertTrue(controller.is_set('DownloadExtraInfo'))
- self.assertTrue(controller.is_set('SocksPort'))
- self.assertFalse(controller.is_set('CellStatistics'))
- self.assertFalse(controller.is_set('ConnLimit'))
+ self.assertTrue(await controller.is_set('DownloadExtraInfo'))
+ self.assertTrue(await controller.is_set('SocksPort'))
+ self.assertFalse(await controller.is_set('CellStatistics'))
+ self.assertFalse(await controller.is_set('ConnLimit'))
# check we update when setting and resetting values
- controller.set_conf('ConnLimit', '1005')
- self.assertTrue(controller.is_set('ConnLimit'))
+ await controller.set_conf('ConnLimit', '1005')
+ self.assertTrue(await controller.is_set('ConnLimit'))
- controller.reset_conf('ConnLimit')
- self.assertFalse(controller.is_set('ConnLimit'))
+ await controller.reset_conf('ConnLimit')
+ self.assertFalse(await controller.is_set('ConnLimit'))
@test.require.controller
- def test_hidden_services_conf(self):
+ @async_test
+ async def test_hidden_services_conf(self):
"""
Exercises the hidden service family of methods (get_hidden_service_conf,
set_hidden_service_conf, create_hidden_service, and remove_hidden_service).
@@ -459,16 +482,16 @@ class TestController(unittest.TestCase):
service3_path = os.path.join(test_dir, 'test_hidden_service3')
service4_path = os.path.join(test_dir, 'test_hidden_service4')
- with runner.get_tor_controller() as controller:
+ async with await runner.get_tor_controller() as controller:
try:
# initially we shouldn't be running any hidden services
- self.assertEqual({}, controller.get_hidden_service_conf())
+ self.assertEqual({}, await controller.get_hidden_service_conf())
# try setting a blank config, shouldn't have any impact
- controller.set_hidden_service_conf({})
- self.assertEqual({}, controller.get_hidden_service_conf())
+ await controller.set_hidden_service_conf({})
+ self.assertEqual({}, await controller.get_hidden_service_conf())
# create a hidden service
@@ -491,58 +514,58 @@ class TestController(unittest.TestCase):
},
}
- controller.set_hidden_service_conf(initialconf)
- self.assertEqual(initialconf, controller.get_hidden_service_conf())
+ await controller.set_hidden_service_conf(initialconf)
+ self.assertEqual(initialconf, await controller.get_hidden_service_conf())
# add already existing services, with/without explicit target
- self.assertEqual(None, controller.create_hidden_service(service1_path, 8020))
- self.assertEqual(None, controller.create_hidden_service(service1_path, 8021, target_port = 8021))
- self.assertEqual(initialconf, controller.get_hidden_service_conf())
+ self.assertEqual(None, await controller.create_hidden_service(service1_path, 8020))
+ self.assertEqual(None, await controller.create_hidden_service(service1_path, 8021, target_port = 8021))
+ self.assertEqual(initialconf, await controller.get_hidden_service_conf())
# add a new service, with/without explicit target
hs_path = os.path.join(os.getcwd(), service3_path)
- hs_address1 = controller.create_hidden_service(hs_path, 8888).hostname
- hs_address2 = controller.create_hidden_service(hs_path, 8989, target_port = 8021).hostname
+ hs_address1 = (await controller.create_hidden_service(hs_path, 8888)).hostname
+ hs_address2 = (await controller.create_hidden_service(hs_path, 8989, target_port = 8021)).hostname
self.assertEqual(hs_address1, hs_address2)
self.assertTrue(hs_address1.endswith('.onion'))
- conf = controller.get_hidden_service_conf()
+ conf = await controller.get_hidden_service_conf()
self.assertEqual(3, len(conf))
self.assertEqual(2, len(conf[hs_path]['HiddenServicePort']))
# remove a hidden service, the service dir should still be there
- controller.remove_hidden_service(hs_path, 8888)
- self.assertEqual(3, len(controller.get_hidden_service_conf()))
+ await controller.remove_hidden_service(hs_path, 8888)
+ self.assertEqual(3, len(await controller.get_hidden_service_conf()))
# remove a service completely, it should now be gone
- controller.remove_hidden_service(hs_path, 8989)
- self.assertEqual(2, len(controller.get_hidden_service_conf()))
+ await controller.remove_hidden_service(hs_path, 8989)
+ self.assertEqual(2, len(await controller.get_hidden_service_conf()))
# add a new service, this time with client authentication
hs_path = os.path.join(os.getcwd(), service4_path)
- hs_attributes = controller.create_hidden_service(hs_path, 8888, auth_type = 'basic', client_names = ['c1', 'c2'])
+ hs_attributes = await controller.create_hidden_service(hs_path, 8888, auth_type = 'basic', client_names = ['c1', 'c2'])
self.assertEqual(2, len(hs_attributes.hostname.splitlines()))
self.assertEqual(2, len(hs_attributes.hostname_for_client))
self.assertTrue(hs_attributes.hostname_for_client['c1'].endswith('.onion'))
self.assertTrue(hs_attributes.hostname_for_client['c2'].endswith('.onion'))
- conf = controller.get_hidden_service_conf()
+ conf = await controller.get_hidden_service_conf()
self.assertEqual(3, len(conf))
self.assertEqual(1, len(conf[hs_path]['HiddenServicePort']))
# remove a hidden service
- controller.remove_hidden_service(hs_path, 8888)
- self.assertEqual(2, len(controller.get_hidden_service_conf()))
+ await controller.remove_hidden_service(hs_path, 8888)
+ self.assertEqual(2, len(await controller.get_hidden_service_conf()))
finally:
- controller.set_hidden_service_conf({}) # drop hidden services created during the test
+ await controller.set_hidden_service_conf({}) # drop hidden services created during the test
# clean up the hidden service directories created as part of this test
@@ -553,47 +576,50 @@ class TestController(unittest.TestCase):
pass
@test.require.controller
- def test_without_ephemeral_hidden_services(self):
+ @async_test
+ async def test_without_ephemeral_hidden_services(self):
"""
Exercises ephemeral hidden service methods when none are present.
"""
- with test.runner.get_runner().get_tor_controller() as controller:
- self.assertEqual([], controller.list_ephemeral_hidden_services())
- self.assertEqual([], controller.list_ephemeral_hidden_services(detached = True))
- self.assertEqual(False, controller.remove_ephemeral_hidden_service('gfzprpioee3hoppz'))
+ async with await test.runner.get_runner().get_tor_controller() as controller:
+ self.assertEqual([], await controller.list_ephemeral_hidden_services())
+ self.assertEqual([], await controller.list_ephemeral_hidden_services(detached = True))
+ self.assertEqual(False, await controller.remove_ephemeral_hidden_service('gfzprpioee3hoppz'))
@test.require.controller
- def test_with_invalid_ephemeral_hidden_service_port(self):
- with test.runner.get_runner().get_tor_controller() as controller:
+ @async_test
+ async def test_with_invalid_ephemeral_hidden_service_port(self):
+ async with await test.runner.get_runner().get_tor_controller() as controller:
for ports in (4567890, [4567, 4567890], {4567: '-:4567'}):
- exc_msg = "ADD_ONION response didn't have an OK status: Invalid VIRTPORT/TARGET"
- self.assertRaisesWith(stem.ProtocolError, exc_msg, controller.create_ephemeral_hidden_service, ports)
+ with self.assertRaisesWith(stem.ProtocolError, "ADD_ONION response didn't have an OK status: Invalid VIRTPORT/TARGET"):
+ await controller.create_ephemeral_hidden_service(ports)
@test.require.controller
- def test_ephemeral_hidden_services_v2(self):
+ @async_test
+ async def test_ephemeral_hidden_services_v2(self):
"""
Exercises creating v2 ephemeral hidden services.
"""
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
- response = controller.create_ephemeral_hidden_service(4567, key_content = 'RSA1024')
- self.assertEqual([response.service_id], controller.list_ephemeral_hidden_services())
+ async with await runner.get_tor_controller() as controller:
+ response = await controller.create_ephemeral_hidden_service(4567, key_content = 'RSA1024')
+ self.assertEqual([response.service_id], await controller.list_ephemeral_hidden_services())
self.assertTrue(response.private_key is not None)
self.assertEqual('RSA1024', response.private_key_type)
self.assertEqual({}, response.client_auth)
# drop the service
- self.assertEqual(True, controller.remove_ephemeral_hidden_service(response.service_id))
- self.assertEqual([], controller.list_ephemeral_hidden_services())
+ self.assertEqual(True, await controller.remove_ephemeral_hidden_service(response.service_id))
+ self.assertEqual([], await controller.list_ephemeral_hidden_services())
# recreate the service with the same private key
- recreate_response = controller.create_ephemeral_hidden_service(4567, key_type = response.private_key_type, key_content = response.private_key)
- self.assertEqual([response.service_id], controller.list_ephemeral_hidden_services())
+ recreate_response = await controller.create_ephemeral_hidden_service(4567, key_type = response.private_key_type, key_content = response.private_key)
+ self.assertEqual([response.service_id], await controller.list_ephemeral_hidden_services())
self.assertEqual(response.service_id, recreate_response.service_id)
# the response only includes the private key when making a new one
@@ -603,41 +629,42 @@ class TestController(unittest.TestCase):
# create a service where we never see the private key
- response = controller.create_ephemeral_hidden_service(4568, key_content = 'RSA1024', discard_key = True)
- self.assertTrue(response.service_id in controller.list_ephemeral_hidden_services())
+ response = await controller.create_ephemeral_hidden_service(4568, key_content = 'RSA1024', discard_key = True)
+ self.assertTrue(response.service_id in await controller.list_ephemeral_hidden_services())
self.assertEqual(None, response.private_key)
self.assertEqual(None, response.private_key_type)
# other controllers shouldn't be able to see these hidden services
- with runner.get_tor_controller() as second_controller:
- self.assertEqual(2, len(controller.list_ephemeral_hidden_services()))
- self.assertEqual(0, len(second_controller.list_ephemeral_hidden_services()))
+ async with await runner.get_tor_controller() as second_controller:
+ self.assertEqual(2, len(await controller.list_ephemeral_hidden_services()))
+ self.assertEqual(0, len(await second_controller.list_ephemeral_hidden_services()))
@test.require.controller
- def test_ephemeral_hidden_services_v3(self):
+ @async_test
+ async def test_ephemeral_hidden_services_v3(self):
"""
Exercises creating v3 ephemeral hidden services.
"""
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
- response = controller.create_ephemeral_hidden_service(4567, key_content = 'ED25519-V3')
- self.assertEqual([response.service_id], controller.list_ephemeral_hidden_services())
+ async with await runner.get_tor_controller() as controller:
+ response = await controller.create_ephemeral_hidden_service(4567, key_content = 'ED25519-V3')
+ self.assertEqual([response.service_id], await controller.list_ephemeral_hidden_services())
self.assertTrue(response.private_key is not None)
self.assertEqual('ED25519-V3', response.private_key_type)
self.assertEqual({}, response.client_auth)
# drop the service
- self.assertEqual(True, controller.remove_ephemeral_hidden_service(response.service_id))
- self.assertEqual([], controller.list_ephemeral_hidden_services())
+ self.assertEqual(True, await controller.remove_ephemeral_hidden_service(response.service_id))
+ self.assertEqual([], await controller.list_ephemeral_hidden_services())
# recreate the service with the same private key
- recreate_response = controller.create_ephemeral_hidden_service(4567, key_type = response.private_key_type, key_content = response.private_key)
- self.assertEqual([response.service_id], controller.list_ephemeral_hidden_services())
+ recreate_response = await controller.create_ephemeral_hidden_service(4567, key_type = response.private_key_type, key_content = response.private_key)
+ self.assertEqual([response.service_id], await controller.list_ephemeral_hidden_services())
self.assertEqual(response.service_id, recreate_response.service_id)
# the response only includes the private key when making a new one
@@ -647,38 +674,40 @@ class TestController(unittest.TestCase):
# create a service where we never see the private key
- response = controller.create_ephemeral_hidden_service(4568, key_content = 'ED25519-V3', discard_key = True)
- self.assertTrue(response.service_id in controller.list_ephemeral_hidden_services())
+ response = await controller.create_ephemeral_hidden_service(4568, key_content = 'ED25519-V3', discard_key = True)
+ self.assertTrue(response.service_id in await controller.list_ephemeral_hidden_services())
self.assertEqual(None, response.private_key)
self.assertEqual(None, response.private_key_type)
# other controllers shouldn't be able to see these hidden services
- with runner.get_tor_controller() as second_controller:
- self.assertEqual(2, len(controller.list_ephemeral_hidden_services()))
- self.assertEqual(0, len(second_controller.list_ephemeral_hidden_services()))
+ async with await runner.get_tor_controller() as second_controller:
+ self.assertEqual(2, len(await controller.list_ephemeral_hidden_services()))
+ self.assertEqual(0, len(await second_controller.list_ephemeral_hidden_services()))
@test.require.controller
- def test_with_ephemeral_hidden_services_basic_auth(self):
+ @async_test
+ async def test_with_ephemeral_hidden_services_basic_auth(self):
"""
Exercises creating ephemeral hidden services that uses basic authentication.
"""
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
- response = controller.create_ephemeral_hidden_service(4567, key_content = 'RSA1024', basic_auth = {'alice': 'nKwfvVPmTNr2k2pG0pzV4g', 'bob': None})
- self.assertEqual([response.service_id], controller.list_ephemeral_hidden_services())
+ async with await runner.get_tor_controller() as controller:
+ response = await controller.create_ephemeral_hidden_service(4567, key_content = 'RSA1024', basic_auth = {'alice': 'nKwfvVPmTNr2k2pG0pzV4g', 'bob': None})
+ self.assertEqual([response.service_id], await controller.list_ephemeral_hidden_services())
self.assertTrue(response.private_key is not None)
self.assertEqual(['bob'], list(response.client_auth.keys())) # newly created credentials were only created for bob
# drop the service
- self.assertEqual(True, controller.remove_ephemeral_hidden_service(response.service_id))
- self.assertEqual([], controller.list_ephemeral_hidden_services())
+ self.assertEqual(True, await controller.remove_ephemeral_hidden_service(response.service_id))
+ self.assertEqual([], await controller.list_ephemeral_hidden_services())
@test.require.controller
- def test_with_ephemeral_hidden_services_basic_auth_no_credentials(self):
+ @async_test
+ async def test_with_ephemeral_hidden_services_basic_auth_no_credentials(self):
"""
Exercises creating ephemeral hidden services when attempting to use basic
auth but not including any credentials.
@@ -686,12 +715,13 @@ class TestController(unittest.TestCase):
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
- exc_msg = "ADD_ONION response didn't have an OK status: No auth clients specified"
- self.assertRaisesWith(stem.ProtocolError, exc_msg, controller.create_ephemeral_hidden_service, 4567, basic_auth = {})
+ async with await runner.get_tor_controller() as controller:
+ with self.assertRaisesWith(stem.ProtocolError, "ADD_ONION response didn't have an OK status: No auth clients specified"):
+ await controller.create_ephemeral_hidden_service(4567, basic_auth = {})
@test.require.controller
- def test_with_detached_ephemeral_hidden_services(self):
+ @async_test
+ async def test_with_detached_ephemeral_hidden_services(self):
"""
Exercises creating detached ephemeral hidden services and methods when
they're present.
@@ -699,34 +729,35 @@ class TestController(unittest.TestCase):
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
- response = controller.create_ephemeral_hidden_service(4567, detached = True)
- self.assertEqual([], controller.list_ephemeral_hidden_services())
- self.assertEqual([response.service_id], controller.list_ephemeral_hidden_services(detached = True))
+ async with await runner.get_tor_controller() as controller:
+ response = await controller.create_ephemeral_hidden_service(4567, detached = True)
+ self.assertEqual([], await controller.list_ephemeral_hidden_services())
+ self.assertEqual([response.service_id], await controller.list_ephemeral_hidden_services(detached = True))
# drop and recreate the service
- self.assertEqual(True, controller.remove_ephemeral_hidden_service(response.service_id))
- self.assertEqual([], controller.list_ephemeral_hidden_services(detached = True))
- controller.create_ephemeral_hidden_service(4567, key_type = response.private_key_type, key_content = response.private_key, detached = True)
- self.assertEqual([response.service_id], controller.list_ephemeral_hidden_services(detached = True))
+ self.assertEqual(True, await controller.remove_ephemeral_hidden_service(response.service_id))
+ self.assertEqual([], await controller.list_ephemeral_hidden_services(detached = True))
+ await controller.create_ephemeral_hidden_service(4567, key_type = response.private_key_type, key_content = response.private_key, detached = True)
+ self.assertEqual([response.service_id], await controller.list_ephemeral_hidden_services(detached = True))
# other controllers should be able to see this service, and drop it
- with runner.get_tor_controller() as second_controller:
- self.assertEqual([response.service_id], second_controller.list_ephemeral_hidden_services(detached = True))
- self.assertEqual(True, second_controller.remove_ephemeral_hidden_service(response.service_id))
- self.assertEqual([], controller.list_ephemeral_hidden_services(detached = True))
+ async with await runner.get_tor_controller() as second_controller:
+ self.assertEqual([response.service_id], await second_controller.list_ephemeral_hidden_services(detached = True))
+ self.assertEqual(True, await second_controller.remove_ephemeral_hidden_service(response.service_id))
+ self.assertEqual([], await controller.list_ephemeral_hidden_services(detached = True))
# recreate the service and confirms that it outlives this controller
- response = second_controller.create_ephemeral_hidden_service(4567, detached = True)
+ response = await second_controller.create_ephemeral_hidden_service(4567, detached = True)
- self.assertEqual([response.service_id], controller.list_ephemeral_hidden_services(detached = True))
- controller.remove_ephemeral_hidden_service(response.service_id)
+ self.assertEqual([response.service_id], await controller.list_ephemeral_hidden_services(detached = True))
+ await controller.remove_ephemeral_hidden_service(response.service_id)
@test.require.controller
- def test_rejecting_unanonymous_hidden_services_creation(self):
+ @async_test
+ async def test_rejecting_unanonymous_hidden_services_creation(self):
"""
Attempt to create a non-anonymous hidden service despite not setting
HiddenServiceSingleHopMode and HiddenServiceNonAnonymousMode.
@@ -734,11 +765,12 @@ class TestController(unittest.TestCase):
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
- self.assertEqual('Tor is in anonymous hidden service mode', str(controller.msg('ADD_ONION NEW:BEST Flags=NonAnonymous Port=4567')))
+ async with await runner.get_tor_controller() as controller:
+ self.assertEqual('Tor is in anonymous hidden service mode', str(await controller.msg('ADD_ONION NEW:BEST Flags=NonAnonymous Port=4567')))
@test.require.controller
- def test_set_conf(self):
+ @async_test
+ async def test_set_conf(self):
"""
Exercises set_conf(), reset_conf(), and set_options() methods with valid
and invalid requests.
@@ -748,42 +780,42 @@ class TestController(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdir:
- with runner.get_tor_controller() as controller:
+ async with await runner.get_tor_controller() as controller:
try:
# successfully set a single option
- connlimit = int(controller.get_conf('ConnLimit'))
- controller.set_conf('connlimit', str(connlimit - 1))
- self.assertEqual(connlimit - 1, int(controller.get_conf('ConnLimit')))
+ connlimit = int(await controller.get_conf('ConnLimit'))
+ await controller.set_conf('connlimit', str(connlimit - 1))
+ self.assertEqual(connlimit - 1, int(await controller.get_conf('ConnLimit')))
# successfully set a single list option
exit_policy = ['accept *:7777', 'reject *:*']
- controller.set_conf('ExitPolicy', exit_policy)
- self.assertEqual(exit_policy, controller.get_conf('ExitPolicy', multiple = True))
+ await controller.set_conf('ExitPolicy', exit_policy)
+ self.assertEqual(exit_policy, await controller.get_conf('ExitPolicy', multiple = True))
# fail to set a single option
try:
- controller.set_conf('invalidkeyboo', 'abcde')
+ await controller.set_conf('invalidkeyboo', 'abcde')
self.fail()
except stem.InvalidArguments as exc:
self.assertEqual(['invalidkeyboo'], exc.arguments)
# resets configuration parameters
- controller.reset_conf('ConnLimit', 'ExitPolicy')
- self.assertEqual(connlimit, int(controller.get_conf('ConnLimit')))
- self.assertEqual(None, controller.get_conf('ExitPolicy'))
+ await controller.reset_conf('ConnLimit', 'ExitPolicy')
+ self.assertEqual(connlimit, int(await controller.get_conf('ConnLimit')))
+ self.assertEqual(None, await controller.get_conf('ExitPolicy'))
# successfully sets multiple config options
- controller.set_options({
+ await controller.set_options({
'connlimit': str(connlimit - 2),
'contactinfo': 'stem@testing',
})
- self.assertEqual(connlimit - 2, int(controller.get_conf('ConnLimit')))
- self.assertEqual('stem@testing', controller.get_conf('contactinfo'))
+ self.assertEqual(connlimit - 2, int(await controller.get_conf('ConnLimit')))
+ self.assertEqual('stem@testing', await controller.get_conf('contactinfo'))
# fail to set multiple config options
try:
- controller.set_options({
+ await controller.set_options({
'contactinfo': 'stem@testing',
'bombay': 'vadapav',
})
@@ -792,17 +824,17 @@ class TestController(unittest.TestCase):
self.assertEqual(['bombay'], exc.arguments)
# context-sensitive keys (the only retched things for which order matters)
- controller.set_options((
+ await controller.set_options((
('HiddenServiceDir', tmpdir),
('HiddenServicePort', '17234 127.0.0.1:17235'),
))
- self.assertEqual(tmpdir, controller.get_conf('HiddenServiceDir'))
- self.assertEqual('17234 127.0.0.1:17235', controller.get_conf('HiddenServicePort'))
+ self.assertEqual(tmpdir, await controller.get_conf('HiddenServiceDir'))
+ self.assertEqual('17234 127.0.0.1:17235', await controller.get_conf('HiddenServicePort'))
finally:
# reverts configuration changes
- controller.set_options((
+ await controller.set_options((
('ExitPolicy', 'reject *:*'),
('ConnLimit', None),
('ContactInfo', None),
@@ -811,47 +843,53 @@ class TestController(unittest.TestCase):
), reset = True)
@test.require.controller
- def test_set_conf_for_usebridges(self):
+ @async_test
+ async def test_set_conf_for_usebridges(self):
"""
Ensure we can set UseBridges=1 and also set a Bridge. This is a tor
regression check (:trac:`31945`).
"""
- with test.runner.get_runner().get_tor_controller() as controller:
- orport = controller.get_conf('ORPort')
+ async with await test.runner.get_runner().get_tor_controller() as controller:
+ orport = await controller.get_conf('ORPort')
try:
- controller.set_conf('ORPort', '0') # ensure we're not a relay so UseBridges is usabe
- controller.set_options([('UseBridges', '1'), ('Bridge', '127.0.0.1:9999')])
- self.assertEqual('127.0.0.1:9999', controller.get_conf('Bridge'))
+ await controller.set_conf('ORPort', '0') # ensure we're not a relay so UseBridges is usabe
+ await controller.set_options([('UseBridges', '1'), ('Bridge', '127.0.0.1:9999')])
+ self.assertEqual('127.0.0.1:9999', await controller.get_conf('Bridge'))
finally:
# reverts configuration changes
- controller.set_options((
+ await controller.set_options((
('ORPort', orport),
('UseBridges', None),
('Bridge', None),
), reset = True)
@test.require.controller
- def test_set_conf_when_immutable(self):
+ @async_test
+ async def test_set_conf_when_immutable(self):
"""
Issue a SETCONF for tor options that cannot be changed while running.
"""
- with test.runner.get_runner().get_tor_controller() as controller:
- self.assertRaisesWith(stem.InvalidArguments, "DisableAllSwap cannot be changed while tor's running", controller.set_conf, 'DisableAllSwap', '1')
- self.assertRaisesWith(stem.InvalidArguments, "DisableAllSwap, User cannot be changed while tor's running", controller.set_options, {'User': 'atagar', 'DisableAllSwap': '1'})
+ async with await test.runner.get_runner().get_tor_controller() as controller:
+ with self.assertRaisesWith(stem.InvalidArguments, "DisableAllSwap cannot be changed while tor's running"):
+ await controller.set_conf('DisableAllSwap', '1')
+
+ with self.assertRaisesWith(stem.InvalidArguments, "DisableAllSwap, User cannot be changed while tor's running"):
+ await controller.set_options({'User': 'atagar', 'DisableAllSwap': '1'})
@test.require.controller
- def test_loadconf(self):
+ @async_test
+ async def test_loadconf(self):
"""
Exercises Controller.load_conf with valid and invalid requests.
"""
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
+ async with await runner.get_tor_controller() as controller:
oldconf = runner.get_torrc_contents()
try:
@@ -863,98 +901,105 @@ class TestController(unittest.TestCase):
# ("/home/atagar/Desktop/stem/test/data"->"/home/atagar/.tor") is not
# allowed.
- self.assertRaises(stem.InvalidRequest, controller.load_conf, 'ContactInfo confloaded')
+ with self.assertRaises(stem.InvalidRequest):
+ await controller.load_conf('ContactInfo confloaded')
try:
- controller.load_conf('Blahblah blah')
+ await controller.load_conf('Blahblah blah')
self.fail()
except stem.InvalidArguments as exc:
self.assertEqual(['Blahblah'], exc.arguments)
# valid config
- controller.load_conf(runner.get_torrc_contents() + '\nContactInfo confloaded\n')
- self.assertEqual('confloaded', controller.get_conf('ContactInfo'))
+ await controller.load_conf(runner.get_torrc_contents() + '\nContactInfo confloaded\n')
+ self.assertEqual('confloaded', await controller.get_conf('ContactInfo'))
finally:
# reload original valid config
- controller.load_conf(oldconf)
- controller.reset_conf('__OwningControllerProcess')
+ await controller.load_conf(oldconf)
+ await controller.reset_conf('__OwningControllerProcess')
@test.require.controller
- def test_saveconf(self):
+ @async_test
+ async def test_saveconf(self):
runner = test.runner.get_runner()
# only testing for success, since we need to run out of disk space to test
# for failure
- with runner.get_tor_controller() as controller:
+ async with await runner.get_tor_controller() as controller:
oldconf = runner.get_torrc_contents()
try:
- controller.set_conf('ContactInfo', 'confsaved')
- controller.save_conf()
+ await controller.set_conf('ContactInfo', 'confsaved')
+ await controller.save_conf()
with open(runner.get_torrc_path()) as torrcfile:
self.assertTrue('\nContactInfo confsaved\n' in torrcfile.read())
finally:
- controller.load_conf(oldconf)
- controller.save_conf()
- controller.reset_conf('__OwningControllerProcess')
+ await controller.load_conf(oldconf)
+ await controller.save_conf()
+ await controller.reset_conf('__OwningControllerProcess')
@test.require.controller
- def test_get_ports(self):
+ @async_test
+ async def test_get_ports(self):
"""
Test Controller.get_ports against a running tor instance.
"""
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
- self.assertEqual([test.runner.ORPORT], controller.get_ports(Listener.OR))
- self.assertEqual([], controller.get_ports(Listener.DIR))
- self.assertEqual([test.runner.SOCKS_PORT], controller.get_ports(Listener.SOCKS))
- self.assertEqual([], controller.get_ports(Listener.TRANS))
- self.assertEqual([], controller.get_ports(Listener.NATD))
- self.assertEqual([], controller.get_ports(Listener.DNS))
+ async with await runner.get_tor_controller() as controller:
+ self.assertEqual([test.runner.ORPORT], await controller.get_ports(Listener.OR))
+ self.assertEqual([], await controller.get_ports(Listener.DIR))
+ self.assertEqual([test.runner.SOCKS_PORT], await controller.get_ports(Listener.SOCKS))
+ self.assertEqual([], await controller.get_ports(Listener.TRANS))
+ self.assertEqual([], await controller.get_ports(Listener.NATD))
+ self.assertEqual([], await controller.get_ports(Listener.DNS))
if test.runner.Torrc.PORT in runner.get_options():
- self.assertEqual([test.runner.CONTROL_PORT], controller.get_ports(Listener.CONTROL))
+ self.assertEqual([test.runner.CONTROL_PORT], await controller.get_ports(Listener.CONTROL))
else:
- self.assertEqual([], controller.get_ports(Listener.CONTROL))
+ self.assertEqual([], await controller.get_ports(Listener.CONTROL))
@test.require.controller
- def test_get_listeners(self):
+ @async_test
+ async def test_get_listeners(self):
"""
Test Controller.get_listeners against a running tor instance.
"""
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
- self.assertEqual([('0.0.0.0', test.runner.ORPORT)], controller.get_listeners(Listener.OR))
- self.assertEqual([], controller.get_listeners(Listener.DIR))
- self.assertEqual([('127.0.0.1', test.runner.SOCKS_PORT)], controller.get_listeners(Listener.SOCKS))
- self.assertEqual([], controller.get_listeners(Listener.TRANS))
- self.assertEqual([], controller.get_listeners(Listener.NATD))
- self.assertEqual([], controller.get_listeners(Listener.DNS))
+ async with await runner.get_tor_controller() as controller:
+ self.assertEqual([('0.0.0.0', test.runner.ORPORT)], await controller.get_listeners(Listener.OR))
+ self.assertEqual([], await controller.get_listeners(Listener.DIR))
+ self.assertEqual([('127.0.0.1', test.runner.SOCKS_PORT)], await controller.get_listeners(Listener.SOCKS))
+ self.assertEqual([], await controller.get_listeners(Listener.TRANS))
+ self.assertEqual([], await controller.get_listeners(Listener.NATD))
+ self.assertEqual([], await controller.get_listeners(Listener.DNS))
if test.runner.Torrc.PORT in runner.get_options():
- self.assertEqual([('127.0.0.1', test.runner.CONTROL_PORT)], controller.get_listeners(Listener.CONTROL))
+ self.assertEqual([('127.0.0.1', test.runner.CONTROL_PORT)], await controller.get_listeners(Listener.CONTROL))
else:
- self.assertEqual([], controller.get_listeners(Listener.CONTROL))
+ self.assertEqual([], await controller.get_listeners(Listener.CONTROL))
@test.require.controller
@test.require.online
@test.require.version(stem.version.Version('0.1.2.2-alpha'))
- def test_enable_feature(self):
+ @async_test
+ async def test_enable_feature(self):
"""
Test Controller.enable_feature with valid and invalid inputs.
"""
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
+ async with await runner.get_tor_controller() as controller:
self.assertTrue(controller.is_feature_enabled('VERBOSE_NAMES'))
- self.assertRaises(stem.InvalidArguments, controller.enable_feature, ['NOT', 'A', 'FEATURE'])
+
+ with self.assertRaises(stem.InvalidArguments):
+ await controller.enable_feature(['NOT', 'A', 'FEATURE'])
try:
controller.enable_feature(['NOT', 'A', 'FEATURE'])
@@ -964,58 +1009,70 @@ class TestController(unittest.TestCase):
self.fail()
@test.require.controller
- def test_signal(self):
+ @async_test
+ async def test_signal(self):
"""
Test controller.signal with valid and invalid signals.
"""
- with test.runner.get_runner().get_tor_controller() as controller:
+ async with await test.runner.get_runner().get_tor_controller() as controller:
# valid signal
- controller.signal('CLEARDNSCACHE')
+ await controller.signal('CLEARDNSCACHE')
# invalid signals
- self.assertRaises(stem.InvalidArguments, controller.signal, 'FOOBAR')
+
+ with self.assertRaises(stem.InvalidArguments):
+ await controller.signal('FOOBAR')
@test.require.controller
- def test_newnym_availability(self):
+ @async_test
+ async def test_newnym_availability(self):
"""
Test the is_newnym_available and get_newnym_wait methods.
"""
- with test.runner.get_runner().get_tor_controller() as controller:
+ async with await test.runner.get_runner().get_tor_controller() as controller:
self.assertEqual(True, controller.is_newnym_available())
self.assertEqual(0.0, controller.get_newnym_wait())
- controller.signal(stem.Signal.NEWNYM)
+ await controller.signal(stem.Signal.NEWNYM)
self.assertEqual(False, controller.is_newnym_available())
self.assertTrue(controller.get_newnym_wait() > 9.0)
@test.require.controller
@test.require.online
- def test_extendcircuit(self):
- with test.runner.get_runner().get_tor_controller() as controller:
+ @async_test
+ async def test_extendcircuit(self):
+ async with await test.runner.get_runner().get_tor_controller() as controller:
circuit_id = controller.extend_circuit('0')
# check if our circuit was created
+
self.assertNotEqual(None, controller.get_circuit(circuit_id, None))
circuit_id = controller.new_circuit()
self.assertNotEqual(None, controller.get_circuit(circuit_id, None))
- self.assertRaises(stem.InvalidRequest, controller.extend_circuit, 'foo')
- self.assertRaises(stem.InvalidRequest, controller.extend_circuit, '0', 'thisroutershouldntexistbecausestemexists!@##$%#')
- self.assertRaises(stem.InvalidRequest, controller.extend_circuit, '0', 'thisroutershouldntexistbecausestemexists!@##$%#', 'foo')
+ with self.assertRaises(stem.InvalidRequest):
+ await controller.extend_circuit('foo')
+
+ with self.assertRaises(stem.InvalidRequest):
+ await controller.extend_circuit('0', 'thisroutershouldntexistbecausestemexists!@##$%#')
+
+ with self.assertRaises(stem.InvalidRequest):
+ await controller.extend_circuit('0', 'thisroutershouldntexistbecausestemexists!@##$%#', 'foo')
@test.require.controller
@test.require.online
- def test_repurpose_circuit(self):
+ @async_test
+ async def test_repurpose_circuit(self):
"""
Tests Controller.repurpose_circuit with valid and invalid input.
"""
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
+ async with await runner.get_tor_controller() as controller:
circ_id = controller.new_circuit()
controller.repurpose_circuit(circ_id, 'CONTROLLER')
circuit = controller.get_circuit(circ_id)
@@ -1025,38 +1082,47 @@ class TestController(unittest.TestCase):
circuit = controller.get_circuit(circ_id)
self.assertTrue(circuit.purpose == 'GENERAL')
- self.assertRaises(stem.InvalidRequest, controller.repurpose_circuit, 'f934h9f3h4', 'fooo')
- self.assertRaises(stem.InvalidRequest, controller.repurpose_circuit, '4', 'fooo')
+ with self.assertRaises(stem.InvalidRequest):
+ await controller.repurpose_circuit('f934h9f3h4', 'fooo')
+
+ with self.assertRaises(stem.InvalidRequest):
+ await controller.repurpose_circuit('4', 'fooo')
@test.require.controller
@test.require.online
- def test_close_circuit(self):
+ @async_test
+ async def test_close_circuit(self):
"""
Tests Controller.close_circuit with valid and invalid input.
"""
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
+ async with await runner.get_tor_controller() as controller:
circuit_id = controller.new_circuit()
controller.close_circuit(circuit_id)
- circuit_output = controller.get_info('circuit-status')
+ circuit_output = await controller.get_info('circuit-status')
circ = [x.split()[0] for x in circuit_output.splitlines()]
self.assertFalse(circuit_id in circ)
circuit_id = controller.new_circuit()
controller.close_circuit(circuit_id, 'IfUnused')
- circuit_output = controller.get_info('circuit-status')
+ circuit_output = await controller.get_info('circuit-status')
circ = [x.split()[0] for x in circuit_output.splitlines()]
self.assertFalse(circuit_id in circ)
circuit_id = controller.new_circuit()
- self.assertRaises(stem.InvalidArguments, controller.close_circuit, circuit_id + '1024')
- self.assertRaises(stem.InvalidRequest, controller.close_circuit, '')
+
+ with self.assertRaises(stem.InvalidArguments):
+ await controller.close_circuit(circuit_id + '1024')
+
+ with self.assertRaises(stem.InvalidRequest):
+ await controller.close_circuit('')
@test.require.controller
@test.require.online
- def test_get_streams(self):
+ @async_test
+ async def test_get_streams(self):
"""
Tests Controller.get_streams().
"""
@@ -1065,9 +1131,11 @@ class TestController(unittest.TestCase):
port = 443
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
+
+ async with await runner.get_tor_controller() as controller:
# we only need one proxy port, so take the first
- socks_listener = controller.get_listeners(Listener.SOCKS)[0]
+
+ socks_listener = (await controller.get_listeners(Listener.SOCKS))[0]
with test.network.Socks(socks_listener) as s:
s.settimeout(30)
@@ -1081,17 +1149,18 @@ class TestController(unittest.TestCase):
@test.require.controller
@test.require.online
- def test_close_stream(self):
+ @async_test
+ async def test_close_stream(self):
"""
Tests Controller.close_stream with valid and invalid input.
"""
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
+ async with await runner.get_tor_controller() as controller:
# use the first socks listener
- socks_listener = controller.get_listeners(Listener.SOCKS)[0]
+ socks_listener = (await controller.get_listeners(Listener.SOCKS))[0]
with test.network.Socks(socks_listener) as s:
s.settimeout(30)
@@ -1116,16 +1185,18 @@ class TestController(unittest.TestCase):
# unknown stream
- self.assertRaises(stem.InvalidArguments, controller.close_stream, 'blarg')
+ with self.assertRaises(stem.InvalidArguments):
+ await controller.close_stream('blarg')
@test.require.controller
@test.require.online
- def test_mapaddress(self):
+ @async_test
+ async def test_mapaddress(self):
self.skipTest('(https://trac.torproject.org/projects/tor/ticket/25611)')
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
- controller.map_address({'1.2.1.2': 'ifconfig.me'})
+ async with await runner.get_tor_controller() as controller:
+ await controller.map_address({'1.2.1.2': 'ifconfig.me'})
s = None
response = None
@@ -1136,7 +1207,7 @@ class TestController(unittest.TestCase):
try:
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.settimeout(30)
- s.connect(('127.0.0.1', int(controller.get_conf('SocksPort'))))
+ s.connect(('127.0.0.1', int(await controller.get_conf('SocksPort'))))
test.network.negotiate_socks(s, '1.2.1.2', 80)
s.sendall(stem.util.str_tools._to_bytes(test.network.IP_REQUEST)) # make the http request for the ip address
response = s.recv(1000)
@@ -1158,14 +1229,15 @@ class TestController(unittest.TestCase):
self.assertTrue(stem.util.connection.is_valid_ipv4_address(stem.util.str_tools._to_unicode(ip_addr)), "'%s' isn't an address" % ip_addr)
@test.require.controller
- def test_mapaddress_offline(self):
+ @async_test
+ async def test_mapaddress_offline(self):
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
+ async with await runner.get_tor_controller() as controller:
# try mapping one element, ensuring results are as expected
map1 = {'1.2.1.2': 'ifconfig.me'}
- x = controller.map_address(map1)
+ x = await controller.map_address(map1)
self.assertEqual(x, map1)
# try mapping two elements, ensuring results are as expected
@@ -1173,17 +1245,18 @@ class TestController(unittest.TestCase):
map2 = {'1.2.3.4': 'foobar.example.com',
'1.2.3.5': 'barfuzz.example.com'}
- x = controller.map_address(map2)
+ x = await controller.map_address(map2)
self.assertEqual(x, map2)
# try mapping zero elements
- self.assertRaises(stem.InvalidRequest, controller.map_address, {})
+ with self.assertRaises(stem.InvalidRequest):
+ await controller.map_address({})
# try a virtual mapping to IPv4, the default virtualaddressrange is 127.192.0.0/10
map3 = {'0.0.0.0': 'quux'}
- x = controller.map_address(map3)
+ x = await controller.map_address(map3)
self.assertEquals(len(x), 1)
addr1, target = list(x.items())[0]
@@ -1193,15 +1266,15 @@ class TestController(unittest.TestCase):
# try a virtual mapping to IPv6, the default IPv6 virtualaddressrange is FE80::/10
map4 = {'::': 'quibble'}
- x = controller.map_address(map4)
+ x = await controller.map_address(map4)
self.assertEquals(len(x), 1)
addr2, target = list(x.items())[0]
self.assertTrue(addr2.startswith('[fe'), '%s did not start with [fe.' % addr2)
self.assertEquals(target, 'quibble')
- def address_mappings(addr_type):
- response = controller.get_info(['address-mappings/%s' % addr_type])
+ async def address_mappings(addr_type):
+ response = await controller.get_info(['address-mappings/%s' % addr_type])
result = {}
for line in response['address-mappings/%s' % addr_type].splitlines():
@@ -1218,7 +1291,7 @@ class TestController(unittest.TestCase):
'1.2.3.5': 'barfuzz.example.com',
addr1: 'quux',
addr2: 'quibble',
- }, address_mappings('control'))
+ }, await address_mappings('control'))
# ask for a list of all the address mappings
@@ -1228,29 +1301,40 @@ class TestController(unittest.TestCase):
'1.2.3.5': 'barfuzz.example.com',
addr1: 'quux',
addr2: 'quibble',
- }, address_mappings('all'))
+ }, await address_mappings('all'))
# Now ask for a list of only the mappings configured with the
# configuration. Ours should not be there.
- self.assertEquals({}, address_mappings('config'))
+ self.assertEquals({}, await address_mappings('config'))
@test.require.controller
@test.require.online
- def test_get_microdescriptor(self):
+ @async_test
+ async def test_get_microdescriptor(self):
"""
Basic checks for get_microdescriptor().
"""
- with test.runner.get_runner().get_tor_controller() as controller:
+ async with await test.runner.get_runner().get_tor_controller() as controller:
# we should balk at invalid content
- self.assertRaises(ValueError, controller.get_microdescriptor, '')
- self.assertRaises(ValueError, controller.get_microdescriptor, 5)
- self.assertRaises(ValueError, controller.get_microdescriptor, 'z' * 30)
+
+ with self.assertRaises(ValueError):
+ await controller.get_microdescriptor('')
+
+ with self.assertRaises(ValueError):
+ await controller.get_microdescriptor(5)
+
+ with self.assertRaises(ValueError):
+ await controller.get_microdescriptor('z' * 30)
# try with a relay that doesn't exist
- self.assertRaises(stem.ControllerError, controller.get_microdescriptor, 'blargg')
- self.assertRaises(stem.ControllerError, controller.get_microdescriptor, '5' * 40)
+
+ with self.assertRaises(stem.ControllerError):
+ await controller.get_microdescriptor('blargg')
+
+ with self.assertRaises(stem.ControllerError):
+ await controller.get_microdescriptor('5' * 40)
test_relay = self._get_router_status_entry(controller)
@@ -1261,7 +1345,8 @@ class TestController(unittest.TestCase):
@test.require.controller
@test.require.online
- def test_get_microdescriptors(self):
+ @async_test
+ async def test_get_microdescriptors(self):
"""
Fetches a few descriptors via the get_microdescriptors() method.
"""
@@ -1271,7 +1356,7 @@ class TestController(unittest.TestCase):
if not os.path.exists(runner.get_test_dir('cached-microdescs')):
self.skipTest('(no cached microdescriptors)')
- with runner.get_tor_controller() as controller:
+ async with await runner.get_tor_controller() as controller:
count = 0
for desc in controller.get_microdescriptors():
@@ -1283,22 +1368,33 @@ class TestController(unittest.TestCase):
@test.require.controller
@test.require.online
- def test_get_server_descriptor(self):
+ @async_test
+ async def test_get_server_descriptor(self):
"""
Basic checks for get_server_descriptor().
"""
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
+ async with await runner.get_tor_controller() as controller:
# we should balk at invalid content
- self.assertRaises(ValueError, controller.get_server_descriptor, '')
- self.assertRaises(ValueError, controller.get_server_descriptor, 5)
- self.assertRaises(ValueError, controller.get_server_descriptor, 'z' * 30)
+
+ with self.assertRaises(ValueError):
+ await controller.get_server_descriptor('')
+
+ with self.assertRaises(ValueError):
+ await controller.get_server_descriptor(5)
+
+ with self.assertRaises(ValueError):
+ await controller.get_server_descriptor('z' * 30)
# try with a relay that doesn't exist
- self.assertRaises(stem.ControllerError, controller.get_server_descriptor, 'blargg')
- self.assertRaises(stem.ControllerError, controller.get_server_descriptor, '5' * 40)
+
+ with self.assertRaises(stem.ControllerError):
+ await controller.get_server_descriptor('blargg')
+
+ with self.assertRaises(stem.ControllerError):
+ await controller.get_server_descriptor('5' * 40)
test_relay = self._get_router_status_entry(controller)
@@ -1309,14 +1405,15 @@ class TestController(unittest.TestCase):
@test.require.controller
@test.require.online
- def test_get_server_descriptors(self):
+ @async_test
+ async def test_get_server_descriptors(self):
"""
Fetches a few descriptors via the get_server_descriptors() method.
"""
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
+ async with await runner.get_tor_controller() as controller:
count = 0
for desc in controller.get_server_descriptors():
@@ -1334,20 +1431,31 @@ class TestController(unittest.TestCase):
@test.require.controller
@test.require.online
- def test_get_network_status(self):
+ @async_test
+ async def test_get_network_status(self):
"""
Basic checks for get_network_status().
"""
- with test.runner.get_runner().get_tor_controller() as controller:
+ async with await test.runner.get_runner().get_tor_controller() as controller:
# we should balk at invalid content
- self.assertRaises(ValueError, controller.get_network_status, '')
- self.assertRaises(ValueError, controller.get_network_status, 5)
- self.assertRaises(ValueError, controller.get_network_status, 'z' * 30)
+
+ with self.assertRaises(ValueError):
+ await controller.get_network_status('')
+
+ with self.assertRaises(ValueError):
+ await controller.get_network_status(5)
+
+ with self.assertRaises(ValueError):
+ await controller.get_network_status('z' * 30)
# try with a relay that doesn't exist
- self.assertRaises(stem.ControllerError, controller.get_network_status, 'blargg')
- self.assertRaises(stem.ControllerError, controller.get_network_status, '5' * 40)
+
+ with self.assertRaises(stem.ControllerError):
+ await controller.get_network_status('blargg')
+
+ with self.assertRaises(stem.ControllerError):
+ await controller.get_network_status('5' * 40)
test_relay = self._get_router_status_entry(controller)
@@ -1358,14 +1466,15 @@ class TestController(unittest.TestCase):
@test.require.controller
@test.require.online
- def test_get_network_statuses(self):
+ @async_test
+ async def test_get_network_statuses(self):
"""
Fetches a few descriptors via the get_network_statuses() method.
"""
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
+ async with await runner.get_tor_controller() as controller:
count = 0
for desc in controller.get_network_statuses():
@@ -1381,14 +1490,15 @@ class TestController(unittest.TestCase):
@test.require.controller
@test.require.online
- def test_get_hidden_service_descriptor(self):
+ @async_test
+ async def test_get_hidden_service_descriptor(self):
"""
Fetches a few descriptors via the get_hidden_service_descriptor() method.
"""
runner = test.runner.get_runner()
- with runner.get_tor_controller() as controller:
+ async with await runner.get_tor_controller() as controller:
# fetch the descriptor for DuckDuckGo
desc = controller.get_hidden_service_descriptor('3g2upl4pq6kufc4m.onion')
@@ -1396,8 +1506,8 @@ class TestController(unittest.TestCase):
# try to fetch something that doesn't exist
- exc_msg = 'No running hidden service at m4cfuk6qp4lpu2g3.onion'
- self.assertRaisesWith(stem.DescriptorUnavailable, exc_msg, controller.get_hidden_service_descriptor, 'm4cfuk6qp4lpu2g3')
+ with self.assertRaisesWith(stem.DescriptorUnavailable, 'No running hidden service at m4cfuk6qp4lpu2g3.onion'):
+ await controller.get_hidden_service_descriptor('m4cfuk6qp4lpu2g3')
# ... but shouldn't fail if we have a default argument or aren't awaiting the descriptor
@@ -1406,7 +1516,8 @@ class TestController(unittest.TestCase):
@test.require.controller
@test.require.online
- def test_attachstream(self):
+ @async_test
+ async def test_attachstream(self):
host = socket.gethostbyname('www.torproject.org')
port = 80
@@ -1416,15 +1527,16 @@ class TestController(unittest.TestCase):
if stream.status == 'NEW' and circuit_id:
controller.attach_stream(stream.id, circuit_id)
- with test.runner.get_runner().get_tor_controller() as controller:
+ async with await test.runner.get_runner().get_tor_controller() as controller:
# try 10 times to build a circuit we can connect through
+
for i in range(10):
- controller.add_event_listener(handle_streamcreated, stem.control.EventType.STREAM)
- controller.set_conf('__LeaveStreamsUnattached', '1')
+ await controller.add_event_listener(handle_streamcreated, stem.control.EventType.STREAM)
+ await controller.set_conf('__LeaveStreamsUnattached', '1')
try:
circuit_id = controller.new_circuit(await_build = True)
- socks_listener = controller.get_listeners(Listener.SOCKS)[0]
+ socks_listener = (await controller.get_listeners(Listener.SOCKS))[0]
with test.network.Socks(socks_listener) as s:
s.settimeout(30)
@@ -1435,7 +1547,7 @@ class TestController(unittest.TestCase):
continue
finally:
controller.remove_event_listener(handle_streamcreated)
- controller.reset_conf('__LeaveStreamsUnattached')
+ await controller.reset_conf('__LeaveStreamsUnattached')
our_stream = [stream for stream in streams if stream.target_address == host][0]
@@ -1446,38 +1558,40 @@ class TestController(unittest.TestCase):
@test.require.controller
@test.require.online
- def test_get_circuits(self):
+ @async_test
+ async def test_get_circuits(self):
"""
Fetches circuits via the get_circuits() method.
"""
- with test.runner.get_runner().get_tor_controller() as controller:
+ async with await test.runner.get_runner().get_tor_controller() as controller:
new_circ = controller.new_circuit()
circuits = controller.get_circuits()
self.assertTrue(new_circ in [circ.id for circ in circuits])
@test.require.controller
- def test_transition_to_relay(self):
+ @async_test
+ async def test_transition_to_relay(self):
"""
Transitions Tor to turn into a relay, then back to a client. This helps to
catch transition issues such as the one cited in :trac:`14901`.
"""
- with test.runner.get_runner().get_tor_controller() as controller:
+ async with await test.runner.get_runner().get_tor_controller() as controller:
try:
- controller.reset_conf('OrPort', 'DisableNetwork')
- self.assertEqual(None, controller.get_conf('OrPort'))
+ await controller.reset_conf('OrPort', 'DisableNetwork')
+ self.assertEqual(None, await controller.get_conf('OrPort'))
# DisableNetwork ensures no port is actually opened
- controller.set_options({'OrPort': '9090', 'DisableNetwork': '1'})
+ await controller.set_options({'OrPort': '9090', 'DisableNetwork': '1'})
# TODO once tor 0.2.7.x exists, test that we can generate a descriptor on demand.
- self.assertEqual('9090', controller.get_conf('OrPort'))
- controller.reset_conf('OrPort', 'DisableNetwork')
- self.assertEqual(None, controller.get_conf('OrPort'))
+ self.assertEqual('9090', await controller.get_conf('OrPort'))
+ await controller.reset_conf('OrPort', 'DisableNetwork')
+ self.assertEqual(None, await controller.get_conf('OrPort'))
finally:
- controller.set_conf('OrPort', str(test.runner.ORPORT))
+ await controller.set_conf('OrPort', str(test.runner.ORPORT))
def _get_router_status_entry(self, controller):
"""
diff --git a/test/runner.py b/test/runner.py
index b132b8f5..a02fb769 100644
--- a/test/runner.py
+++ b/test/runner.py
@@ -88,7 +88,7 @@ class TorInaccessable(Exception):
async def exercise_controller(test_case, controller):
- """with await test.runner.get_runner().get_tor_socket
+ """
Checks that we can now use the socket by issuing a 'GETINFO config-file'
query. Controller can be either a :class:`stem.socket.ControlSocket` or
:class:`stem.control.BaseController`.
@@ -102,11 +102,10 @@ async def exercise_controller(test_case, controller):
if isinstance(controller, stem.socket.ControlSocket):
await controller.send('GETINFO config-file')
+
config_file_response = await controller.recv()
else:
- config_file_response = controller.msg('GETINFO config-file')
- if asyncio.iscoroutine(config_file_response):
- config_file_response = await config_file_response
+ config_file_response = await controller.msg('GETINFO config-file')
test_case.assertEqual('config-file=%s\nOK' % torrc_path, str(config_file_response))
@@ -261,9 +260,19 @@ class Runner(object):
stem.socket.recv_message = _chroot_recv_message
if self.is_accessible():
- self._owner_controller = stem.control.Controller(self._get_unconnected_socket(), False)
- self._owner_controller.connect()
- self._authenticate_controller(self._owner_controller)
+ # TODO: refactor so owner controller is less convoluted
+
+ loop = asyncio.new_event_loop()
+
+ self._owner_controller_thread = threading.Thread(
+ name = 'owning_controller',
+ target = loop.run_forever,
+ daemon = True,
+ )
+
+ self._owner_controller_thread.start()
+
+ self._owner_controller = asyncio.run_coroutine_threadsafe(self.get_tor_controller(True), loop).result()
if test.Target.RELATIVE in self.attribute_targets:
os.chdir(original_cwd) # revert our cwd back to normal
@@ -279,7 +288,9 @@ class Runner(object):
println('Shutting down tor... ', STATUS, NO_NL)
if self._owner_controller:
- self._owner_controller.close()
+ asyncio.run_coroutine_threadsafe(self._owner_controller.close(), self._owner_controller._loop).result()
+ self._owner_controller._loop.call_soon_threadsafe(self._owner_controller._loop.stop)
+ self._owner_controller_thread.join()
self._owner_controller = None
if self._tor_process:
@@ -445,16 +456,6 @@ class Runner(object):
tor_process = self._get('_tor_process')
return tor_process.pid
- def _get_unconnected_socket(self):
- if Torrc.PORT in self._custom_opts:
- control_socket = stem.socket.ControlPort(port = CONTROL_PORT)
- elif Torrc.SOCKET in self._custom_opts:
- control_socket = stem.socket.ControlSocketFile(CONTROL_SOCKET_PATH)
- else:
- raise TorInaccessable('Unable to connect to tor')
-
- return control_socket
-
async def get_tor_socket(self, authenticate = True):
"""
Provides a socket connected to our tor test instance.
@@ -466,7 +467,13 @@ class Runner(object):
:raises: :class:`test.runner.TorInaccessable` if tor can't be connected to
"""
- control_socket = self._get_unconnected_socket()
+ if Torrc.PORT in self._custom_opts:
+ control_socket = stem.socket.ControlPort(port = CONTROL_PORT)
+ elif Torrc.SOCKET in self._custom_opts:
+ control_socket = stem.socket.ControlSocketFile(CONTROL_SOCKET_PATH)
+ else:
+ raise TorInaccessable('Unable to connect to tor')
+
await control_socket.connect()
if authenticate:
@@ -474,10 +481,7 @@ class Runner(object):
return control_socket
- def _authenticate_controller(self, controller):
- controller.authenticate(password=CONTROL_PASSWORD, chroot_path=self.get_chroot())
-
- def get_tor_controller(self, authenticate = True):
+ async def get_tor_controller(self, authenticate = True):
"""
Provides a controller connected to our tor test instance.
@@ -488,19 +492,11 @@ class Runner(object):
:raises: :class: `test.runner.TorInaccessable` if tor can't be connected to
"""
- loop = asyncio.new_event_loop()
- loop_thread = threading.Thread(target = loop.run_forever, name = 'get_tor_controller')
- loop_thread.setDaemon(True)
- loop_thread.start()
-
- async def wrapped_get_controller():
- control_socket = await self.get_tor_socket(False)
- return stem.control.Controller(control_socket)
-
- controller = asyncio.run_coroutine_threadsafe(wrapped_get_controller(), loop).result()
+ control_socket = await self.get_tor_socket(False)
+ controller = stem.control.Controller(control_socket)
if authenticate:
- self._authenticate_controller(controller)
+ await controller.authenticate(password = CONTROL_PASSWORD, chroot_path = self.get_chroot())
return controller
diff --git a/test/settings.cfg b/test/settings.cfg
index 70bdd069..ef543a18 100644
--- a/test/settings.cfg
+++ b/test/settings.cfg
@@ -235,6 +235,14 @@ mypy.ignore stem/descriptor/remote.py => Return type "Coroutine[Any, Any, None]"
mypy.ignore * => "Descriptor" has no attribute "*
+# Metaprogramming false positive for our close method.
+
+mypy.ignore stem/control.py => Return type "Coroutine[Any, Any, None]" of "close" *
+
+# Interpreter uses a synchronous controller, which can cause false positives.
+
+mypy.ignore stem/interpreter/commands.py => "Coroutine[Any, Any, ControlMessage]" has no attribute "*
+
# Test modules we want to run. Modules are roughly ordered by the dependencies
# so the lowest level tests come first. This is because a problem in say,
# controller message parsing, will cause all higher level tests to fail too.
diff --git a/test/unit/control/controller.py b/test/unit/control/controller.py
index 84fcdfed..6c33da6b 100644
--- a/test/unit/control/controller.py
+++ b/test/unit/control/controller.py
@@ -21,11 +21,7 @@ from stem import ControllerError, DescriptorUnavailable, InvalidArguments, Inval
from stem.control import MALFORMED_EVENTS, _parse_circ_path, Listener, Controller, EventType
from stem.response import ControlMessage
from stem.exit_policy import ExitPolicy
-from stem.util.test_tools import (
- async_test,
- coro_func_raising_exc,
- coro_func_returning_value,
-)
+from stem.util.test_tools import coro_func_raising_exc, coro_func_returning_value
NS_DESC = 'r %s %s u5lTXJKGsLKufRLnSyVqT7TdGYw 2012-12-30 22:02:49 77.223.43.54 9001 0\ns Fast Named Running Stable Valid\nw Bandwidth=75'
TEST_TIMESTAMP = 12345
@@ -44,7 +40,6 @@ class TestControl(unittest.TestCase):
with patch('stem.control.BaseController.msg', Mock(side_effect = coro_func_returning_value(None))):
self.controller = Controller(socket)
- self.async_controller = self.controller._wrapped_instance
self.circ_listener = Mock()
self.controller.add_event_listener(self.circ_listener, EventType.CIRC)
@@ -69,24 +64,23 @@ class TestControl(unittest.TestCase):
for event in stem.control.EventType:
self.assertTrue(stem.control.event_description(event) is not None)
- @patch('stem.control.AsyncController.msg')
+ @patch('stem.control.Controller.msg')
def test_get_info(self, msg_mock):
message = ControlMessage.from_str('250-hello=hi right back!\r\n250 OK\r\n', 'GETINFO')
msg_mock.side_effect = coro_func_returning_value(message)
self.assertEqual('hi right back!', self.controller.get_info('hello'))
- @patch('stem.control.AsyncController.msg')
- @async_test
- async def test_get_info_address_caching(self, msg_mock):
+ @patch('stem.control.Controller.msg')
+ def test_get_info_address_caching(self, msg_mock):
def set_message(*args):
message = ControlMessage.from_str(*args)
msg_mock.side_effect = coro_func_returning_value(message)
set_message('551 Address unknown\r\n')
- self.assertEqual(None, self.async_controller._last_address_exc)
+ self.assertEqual(None, self.controller._last_address_exc)
self.assertRaisesWith(stem.OperationFailed, 'Address unknown', self.controller.get_info, 'address')
- self.assertEqual('Address unknown', str(self.async_controller._last_address_exc))
+ self.assertEqual('Address unknown', str(self.controller._last_address_exc))
self.assertEqual(1, msg_mock.call_count)
# now that we have a cached failure we should provide that back
@@ -98,26 +92,26 @@ class TestControl(unittest.TestCase):
set_message('250-address=17.2.89.80\r\n250 OK\r\n', 'GETINFO')
self.assertRaisesWith(stem.OperationFailed, 'Address unknown', self.controller.get_info, 'address')
- await self.async_controller._handle_event(ControlMessage.from_str('650 STATUS_SERVER NOTICE EXTERNAL_ADDRESS ADDRESS=17.2.89.80 METHOD=DIRSERV\r\n'))
+ self.controller._handle_event(ControlMessage.from_str('650 STATUS_SERVER NOTICE EXTERNAL_ADDRESS ADDRESS=17.2.89.80 METHOD=DIRSERV\r\n'))
self.assertEqual('17.2.89.80', self.controller.get_info('address'))
# invalidates the cache, transitioning from one address to another
set_message('250-address=80.89.2.17\r\n250 OK\r\n', 'GETINFO')
self.assertEqual('17.2.89.80', self.controller.get_info('address'))
- await self.async_controller._handle_event(ControlMessage.from_str('650 STATUS_SERVER NOTICE EXTERNAL_ADDRESS ADDRESS=80.89.2.17 METHOD=DIRSERV\r\n'))
+ self.controller._handle_event(ControlMessage.from_str('650 STATUS_SERVER NOTICE EXTERNAL_ADDRESS ADDRESS=80.89.2.17 METHOD=DIRSERV\r\n'))
self.assertEqual('80.89.2.17', self.controller.get_info('address'))
- @patch('stem.control.AsyncController.msg')
- @patch('stem.control.AsyncController.get_conf')
+ @patch('stem.control.Controller.msg')
+ @patch('stem.control.Controller.get_conf')
def test_get_info_without_fingerprint(self, get_conf_mock, msg_mock):
message = ControlMessage.from_str('551 Not running in server mode\r\n')
msg_mock.side_effect = coro_func_returning_value(message)
- get_conf_mock.return_value = None
+ get_conf_mock.side_effect = coro_func_returning_value(None)
- self.assertEqual(None, self.async_controller._last_fingerprint_exc)
+ self.assertEqual(None, self.controller._last_fingerprint_exc)
self.assertRaisesWith(stem.OperationFailed, 'Not running in server mode', self.controller.get_info, 'fingerprint')
- self.assertEqual('Not running in server mode', str(self.async_controller._last_fingerprint_exc))
+ self.assertEqual('Not running in server mode', str(self.controller._last_fingerprint_exc))
self.assertEqual(1, msg_mock.call_count)
# now that we have a cached failure we should provide that back
@@ -127,11 +121,11 @@ class TestControl(unittest.TestCase):
# ... but if we become a relay we'll call it again
- get_conf_mock.return_value = '443'
+ get_conf_mock.side_effect = coro_func_returning_value('443')
self.assertRaisesWith(stem.OperationFailed, 'Not running in server mode', self.controller.get_info, 'fingerprint')
self.assertEqual(2, msg_mock.call_count)
- @patch('stem.control.AsyncController.get_info')
+ @patch('stem.control.Controller.get_info')
def test_get_version(self, get_info_mock):
"""
Exercises the get_version() method.
@@ -155,7 +149,7 @@ class TestControl(unittest.TestCase):
self.assertEqual(version_2_1_object, self.controller.get_version())
# Turn off caching.
- self.async_controller._is_caching_enabled = False
+ self.controller._is_caching_enabled = False
# Return a version without caching, so it will be the new version.
self.assertEqual(version_2_2_object, self.controller.get_version())
@@ -184,13 +178,13 @@ class TestControl(unittest.TestCase):
# Turn caching back on before we leave.
self.controller._is_caching_enabled = True
- @patch('stem.control.AsyncController.get_info')
+ @patch('stem.control.Controller.get_info')
def test_get_exit_policy(self, get_info_mock):
"""
Exercises the get_exit_policy() method.
"""
- async def get_info_mock_side_effect(param, default = None):
+ async def get_info_mock_side_effect(self, param, default = None):
return {
'exit-policy/full': 'reject *:25,reject *:119,reject *:135-139,reject *:445,reject *:563,reject *:1214,reject *:4661-4666,reject *:6346-6429,reject *:6699,reject *:6881-6999,accept *:*',
}[param]
@@ -213,8 +207,8 @@ class TestControl(unittest.TestCase):
self.assertEqual(str(expected), str(self.controller.get_exit_policy()))
- @patch('stem.control.AsyncController.get_info')
- @patch('stem.control.AsyncController.get_conf')
+ @patch('stem.control.Controller.get_info')
+ @patch('stem.control.Controller.get_conf')
def test_get_ports(self, get_conf_mock, get_info_mock):
"""
Exercises the get_ports() and get_listeners() methods.
@@ -225,7 +219,7 @@ class TestControl(unittest.TestCase):
get_info_mock.side_effect = coro_func_raising_exc(InvalidArguments)
- async def get_conf_mock_side_effect(param, *args, **kwargs):
+ async def get_conf_mock_side_effect(self, param, *args, **kwargs):
return {
'ControlPort': '9050',
'ControlListenAddress': ['127.0.0.1'],
@@ -239,7 +233,7 @@ class TestControl(unittest.TestCase):
# non-local addresss
- async def get_conf_mock_side_effect(param, *args, **kwargs):
+ async def get_conf_mock_side_effect(self, param, *args, **kwargs):
return {
'ControlPort': '9050',
'ControlListenAddress': ['27.4.4.1'],
@@ -290,14 +284,14 @@ class TestControl(unittest.TestCase):
self.assertEqual([], self.controller.get_listeners(Listener.CONTROL))
self.assertEqual([], self.controller.get_ports(Listener.CONTROL))
- @patch('stem.control.AsyncController.get_info')
+ @patch('stem.control.Controller.get_info')
@patch('time.time', Mock(return_value = 1410723598.276578))
def test_get_accounting_stats(self, get_info_mock):
"""
Exercises the get_accounting_stats() method.
"""
- async def get_info_mock_side_effect(param, **kwargs):
+ async def get_info_mock_side_effect(self, param, **kwargs):
return {
'accounting/enabled': '1',
'accounting/hibernating': 'awake',
@@ -358,6 +352,7 @@ class TestControl(unittest.TestCase):
self.assertRaises(ProtocolError, self.controller.get_protocolinfo)
@patch('stem.socket.ControlSocket.is_localhost', Mock(return_value = False))
+ @patch('stem.control.Controller.get_info', Mock(side_effect = coro_func_returning_value(None)))
def test_get_user_remote(self):
"""
Exercise the get_user() method for a non-local socket.
@@ -367,7 +362,7 @@ class TestControl(unittest.TestCase):
self.assertEqual(123, self.controller.get_user(123))
@patch('stem.socket.ControlSocket.is_localhost', Mock(return_value = True))
- @patch('stem.control.AsyncController.get_info', Mock(side_effect = coro_func_returning_value('atagar')))
+ @patch('stem.control.Controller.get_info', Mock(side_effect = coro_func_returning_value('atagar')))
def test_get_user_by_getinfo(self):
"""
Exercise the get_user() resolution via its getinfo option.
@@ -376,7 +371,8 @@ class TestControl(unittest.TestCase):
self.assertEqual('atagar', self.controller.get_user())
@patch('stem.socket.ControlSocket.is_localhost', Mock(return_value = True))
- @patch('stem.util.system.pid_by_name', Mock(return_value = 432))
+ @patch('stem.control.Controller.get_info', Mock(side_effect = coro_func_returning_value(None)))
+ @patch('stem.control.Controller.get_pid', Mock(side_effect = coro_func_returning_value(432)))
@patch('stem.util.system.user', Mock(return_value = 'atagar'))
def test_get_user_by_system(self):
"""
@@ -386,6 +382,7 @@ class TestControl(unittest.TestCase):
self.assertEqual('atagar', self.controller.get_user())
@patch('stem.socket.ControlSocket.is_localhost', Mock(return_value = False))
+ @patch('stem.control.Controller.get_info', Mock(side_effect = coro_func_returning_value(None)))
def test_get_pid_remote(self):
"""
Exercise the get_pid() method for a non-local socket.
@@ -395,7 +392,7 @@ class TestControl(unittest.TestCase):
self.assertEqual(123, self.controller.get_pid(123))
@patch('stem.socket.ControlSocket.is_localhost', Mock(return_value = True))
- @patch('stem.control.AsyncController.get_info', Mock(side_effect = coro_func_returning_value('321')))
+ @patch('stem.control.Controller.get_info', Mock(side_effect = coro_func_returning_value('321')))
def test_get_pid_by_getinfo(self):
"""
Exercise the get_pid() resolution via its getinfo option.
@@ -404,7 +401,8 @@ class TestControl(unittest.TestCase):
self.assertEqual(321, self.controller.get_pid())
@patch('stem.socket.ControlSocket.is_localhost', Mock(return_value = True))
- @patch('stem.control.AsyncController.get_conf')
+ @patch('stem.control.Controller.get_info', Mock(side_effect = coro_func_returning_value(None)))
+ @patch('stem.control.Controller.get_conf')
@patch('stem.control.open', create = True)
def test_get_pid_by_pid_file(self, open_mock, get_conf_mock):
"""
@@ -418,6 +416,8 @@ class TestControl(unittest.TestCase):
open_mock.assert_called_once_with('/tmp/pid_file')
@patch('stem.socket.ControlSocket.is_localhost', Mock(return_value = True))
+ @patch('stem.control.Controller.get_info', Mock(side_effect = coro_func_returning_value(None)))
+ @patch('stem.control.Controller.get_conf', Mock(side_effect = coro_func_returning_value(None)))
@patch('stem.util.system.pid_by_name', Mock(return_value = 432))
def test_get_pid_by_name(self):
"""
@@ -426,9 +426,9 @@ class TestControl(unittest.TestCase):
self.assertEqual(432, self.controller.get_pid())
- @patch('stem.control.AsyncController.get_version', Mock(side_effect = coro_func_returning_value(stem.version.Version('0.5.0.14'))))
+ @patch('stem.control.Controller.get_version', Mock(side_effect = coro_func_returning_value(stem.version.Version('0.5.0.14'))))
@patch('stem.socket.ControlSocket.is_localhost', Mock(return_value = False))
- @patch('stem.control.AsyncController.get_info')
+ @patch('stem.control.Controller.get_info')
@patch('time.time', Mock(return_value = 1000.0))
def test_get_uptime_by_getinfo(self, getinfo_mock):
"""
@@ -443,8 +443,9 @@ class TestControl(unittest.TestCase):
self.assertRaisesWith(ValueError, "'GETINFO uptime' did not provide a valid numeric response: abc", self.controller.get_uptime)
@patch('stem.socket.ControlSocket.is_localhost', Mock(return_value = True))
- @patch('stem.control.AsyncController.get_version', Mock(side_effect = coro_func_returning_value(stem.version.Version('0.1.0.14'))))
- @patch('stem.control.AsyncController.get_pid', Mock(side_effect = coro_func_returning_value('12')))
+ @patch('stem.control.Controller.get_info', Mock(side_effect = coro_func_returning_value(None)))
+ @patch('stem.control.Controller.get_version', Mock(side_effect = coro_func_returning_value(stem.version.Version('0.1.0.14'))))
+ @patch('stem.control.Controller.get_pid', Mock(side_effect = coro_func_returning_value('12')))
@patch('stem.util.system.start_time', Mock(return_value = 5000.0))
@patch('time.time', Mock(return_value = 5200.0))
def test_get_uptime_by_process(self):
@@ -454,7 +455,7 @@ class TestControl(unittest.TestCase):
self.assertEqual(200.0, self.controller.get_uptime())
- @patch('stem.control.AsyncController.get_info')
+ @patch('stem.control.Controller.get_info')
def test_get_network_status_for_ourselves(self, get_info_mock):
"""
Exercises the get_network_status() method for getting our own relay.
@@ -472,7 +473,7 @@ class TestControl(unittest.TestCase):
desc = NS_DESC % ('moria1', '/96bKo4soysolMgKn5Hex2nyFSY')
- async def get_info_mock_side_effect(param, **kwargs):
+ async def get_info_mock_side_effect(self, param, **kwargs):
return {
'fingerprint': '9695DFC35FFEB861329B9F1AB04C46397020CE31',
'ns/id/9695DFC35FFEB861329B9F1AB04C46397020CE31': desc,
@@ -482,7 +483,7 @@ class TestControl(unittest.TestCase):
self.assertEqual(stem.descriptor.router_status_entry.RouterStatusEntryV3(desc), self.controller.get_network_status())
- @patch('stem.control.AsyncController.get_info')
+ @patch('stem.control.Controller.get_info')
def test_get_network_status_when_unavailable(self, get_info_mock):
"""
Exercises the get_network_status() method.
@@ -494,7 +495,7 @@ class TestControl(unittest.TestCase):
exc_msg = "Tor was unable to provide the descriptor for '5AC9C5AA75BA1F18D8459B326B4B8111A856D290'"
self.assertRaisesWith(DescriptorUnavailable, exc_msg, self.controller.get_network_status, '5AC9C5AA75BA1F18D8459B326B4B8111A856D290')
- @patch('stem.control.AsyncController.get_info')
+ @patch('stem.control.Controller.get_info')
def test_get_network_status(self, get_info_mock):
"""
Exercises the get_network_status() method.
@@ -540,16 +541,14 @@ class TestControl(unittest.TestCase):
self.assertRaises(InvalidArguments, self.controller.get_network_status, nickname)
- @patch('stem.control.AsyncController.is_authenticated', Mock(return_value = True))
- @patch('stem.control.AsyncController._attach_listeners')
- @patch('stem.control.AsyncController.get_version')
- def test_add_event_listener(self, get_version_mock, attach_listeners_mock):
+ @patch('stem.control.Controller.is_authenticated', Mock(return_value = True))
+ @patch('stem.control.Controller._attach_listeners', Mock(side_effect = coro_func_returning_value(([], []))))
+ @patch('stem.control.Controller.get_version')
+ def test_add_event_listener(self, get_version_mock):
"""
Exercises the add_event_listener and remove_event_listener methods.
"""
- attach_listeners_mock.side_effect = coro_func_returning_value(([], []))
-
def set_version(version_str):
version = stem.version.Version(version_str)
get_version_mock.side_effect = coro_func_returning_value(version)
@@ -621,10 +620,10 @@ class TestControl(unittest.TestCase):
self._emit_event(BW_EVENT)
self.bw_listener.assert_called_once_with(BW_EVENT)
- @patch('stem.control.AsyncController.get_version', Mock(side_effect = coro_func_returning_value(stem.version.Version('0.5.0.14'))))
- @patch('stem.control.AsyncController.msg', Mock(side_effect = coro_func_returning_value(ControlMessage.from_str('250 OK\r\n'))))
- @patch('stem.control.AsyncController.add_event_listener', Mock(side_effect = coro_func_returning_value(None)))
- @patch('stem.control.AsyncController.remove_event_listener', Mock(side_effect = coro_func_returning_value(None)))
+ @patch('stem.control.Controller.get_version', Mock(side_effect = coro_func_returning_value(stem.version.Version('0.5.0.14'))))
+ @patch('stem.control.Controller.msg', Mock(side_effect = coro_func_returning_value(ControlMessage.from_str('250 OK\r\n'))))
+ @patch('stem.control.Controller.add_event_listener', Mock(side_effect = coro_func_returning_value(None)))
+ @patch('stem.control.Controller.remove_event_listener', Mock(side_effect = coro_func_returning_value(None)))
def test_timeout(self):
"""
Methods that have an 'await' argument also have an optional timeout. Check
@@ -648,7 +647,7 @@ class TestControl(unittest.TestCase):
response = ''.join(['%s\r\n' % ' '.join(entry) for entry in valid_streams])
get_info_mock = Mock(side_effect = coro_func_returning_value(response))
- with patch('stem.control.AsyncController.get_info', get_info_mock):
+ with patch('stem.control.Controller.get_info', get_info_mock):
streams = self.controller.get_streams()
self.assertEqual(len(valid_streams), len(streams))
@@ -669,7 +668,7 @@ class TestControl(unittest.TestCase):
response = stem.response.ControlMessage.from_str('555 Connection is not managed by controller.\r\n')
msg_mock = Mock(side_effect = coro_func_returning_value(response))
- with patch('stem.control.AsyncController.msg', msg_mock):
+ with patch('stem.control.Controller.msg', msg_mock):
self.assertRaises(UnsatisfiableRequest, self.controller.attach_stream, 'stream_id', 'circ_id')
def test_parse_circ_path(self):
@@ -712,7 +711,7 @@ class TestControl(unittest.TestCase):
for test_input in malformed_inputs:
self.assertRaises(ProtocolError, _parse_circ_path, test_input)
- @patch('stem.control.AsyncController.get_conf')
+ @patch('stem.control.Controller.get_conf')
def test_get_effective_rate(self, get_conf_mock):
"""
Exercise the get_effective_rate() method.
@@ -720,7 +719,7 @@ class TestControl(unittest.TestCase):
# check default if nothing was set
- async def get_conf_mock_side_effect(param, *args, **kwargs):
+ async def get_conf_mock_side_effect(self, param, *args, **kwargs):
return {
'BandwidthRate': '1073741824',
'BandwidthBurst': '1073741824',
@@ -749,19 +748,19 @@ class TestControl(unittest.TestCase):
# with its work is to join on the thread.
with patch('time.time', Mock(return_value = TEST_TIMESTAMP)):
- with patch('stem.control.AsyncController.is_alive') as is_alive_mock:
+ with patch('stem.control.Controller.is_alive') as is_alive_mock:
is_alive_mock.return_value = True
loop = self.controller._loop
- asyncio.run_coroutine_threadsafe(self.async_controller._event_loop(), loop)
+ asyncio.run_coroutine_threadsafe(Controller._event_loop(self.controller), loop)
try:
# Converting an event back into an uncast ControlMessage, then feeding it
# into our controller's event queue.
uncast_event = ControlMessage.from_str(event.raw_content())
- event_queue = self.async_controller._event_queue
+ event_queue = self.controller._event_queue
asyncio.run_coroutine_threadsafe(event_queue.put(uncast_event), loop).result()
asyncio.run_coroutine_threadsafe(event_queue.join(), loop).result() # block until the event is consumed
finally:
is_alive_mock.return_value = False
- asyncio.run_coroutine_threadsafe(self.async_controller._close(), loop).result()
+ self.controller._close()
diff --git a/test/unit/descriptor/remote.py b/test/unit/descriptor/remote.py
index bb6f554c..3facd6a5 100644
--- a/test/unit/descriptor/remote.py
+++ b/test/unit/descriptor/remote.py
@@ -135,7 +135,7 @@ class TestDescriptorDownloader(unittest.TestCase):
def test_reply_header_data(self):
query = stem.descriptor.remote.get_server_descriptors('9695DFC35FFEB861329B9F1AB04C46397020CE31', start = False)
self.assertEqual(None, query.reply_headers) # initially we don't have a reply
- query.run(close = False)
+ query.run(stop = False)
self.assertEqual('Fri, 13 Apr 2018 16:35:50 GMT', query.reply_headers.get('Date'))
self.assertEqual('application/octet-stream', query.reply_headers.get('Content-Type'))
1
0
commit 6f8f8d5c26079ed0e1e08bbb5815f2f39891a46e
Merge: 8634aa04 46dec7b9
Author: Damian Johnson <atagar(a)torproject.org>
Date: Wed Jul 15 18:10:33 2020 -0700
Migrate to asyncio
Python 3.4 added asyncio, an asynchronous IO framework similar to Twisted, which Python 3.5 then expanded upon...
https://www.python.org/dev/peps/pep-0492/
This added new Python keywords (async and await) which give callers more
control over how they await asynchronous operations (for instance,
asyncio.wait_for() to apply a timeout).
This branch migrates our stem.control, stem.client, and stem.descriptor.remote
modules from a synchronous to asynchronous implementation. Usually this would
preclude us from being used by non-asyncio users, but this also adds a
Synchronous mixin that allows us to be used from either context.
In other words, internally Stem is now an asynchronous library that is usable
by asyncio users, while retaining its ability to be used by synchronous users
in the exact same way we always have.
Win-win for everyone. Many thanks to Illia for all his hard work on this
branch!
.gitignore | 3 +
run_tests.py | 19 +-
stem/client/__init__.py | 93 ++--
stem/connection.py | 227 ++++++---
stem/control.py | 648 ++++++++++++++-----------
stem/descriptor/remote.py | 266 +++++-----
stem/interpreter/__init__.py | 5 +-
stem/interpreter/autocomplete.py | 8 +-
stem/interpreter/commands.py | 12 +-
stem/interpreter/help.py | 9 +-
stem/response/__init__.py | 2 +-
stem/socket.py | 441 +++++++++--------
stem/util/__init__.py | 226 ++++++++-
stem/util/test_tools.py | 61 ++-
test/integ/client/connection.py | 38 +-
test/integ/connection/authentication.py | 143 +++---
test/integ/connection/connect.py | 17 +-
test/integ/control/base_controller.py | 133 ++---
test/integ/control/controller.py | 828 ++++++++++++++++++--------------
test/integ/process.py | 28 +-
test/integ/response/protocolinfo.py | 43 +-
test/integ/socket/control_message.py | 84 ++--
test/integ/socket/control_socket.py | 84 ++--
test/integ/util/connection.py | 6 +-
test/integ/util/proc.py | 6 +-
test/integ/version.py | 12 +-
test/runner.py | 48 +-
test/settings.cfg | 11 +
test/unit/connection/authentication.py | 36 +-
test/unit/connection/connect.py | 54 ++-
test/unit/control/controller.py | 206 ++++----
test/unit/descriptor/remote.py | 230 ++++-----
test/unit/response/control_message.py | 10 +-
test/unit/util/synchronous.py | 237 +++++++++
34 files changed, 2602 insertions(+), 1672 deletions(-)
1
0