[tor-commits] [stem/master] @asynchronous testing decorator

atagar at torproject.org atagar at torproject.org
Thu Jun 8 17:17:55 UTC 2017


commit 06836f52258e0e0d92acf523ac249e9e39cae7f0
Author: Damian Johnson <atagar at torproject.org>
Date:   Tue Jun 6 09:42:24 2017 -0700

    @asynchronous testing decorator
    
    Adds an @asynchronous decorator that registers the function as an
    asynchronous test while still allowing it to be run normally. If the user
    calls run() the test starts early; otherwise unittest executes it like a
    normal test.
---
 run_tests.py                                  |  7 +-
 stem/util/test_tools.py                       | 92 ++++++++++++++++++---------
 test/integ/descriptor/extrainfo_descriptor.py |  7 +-
 test/integ/descriptor/microdescriptor.py      |  6 +-
 test/integ/descriptor/networkstatus.py        | 10 +--
 test/integ/descriptor/server_descriptor.py    |  6 +-
 test/integ/installation.py                    | 12 ++--
 test/integ/process.py                         | 61 +++++++-----------
 8 files changed, 114 insertions(+), 87 deletions(-)

diff --git a/run_tests.py b/run_tests.py
index 0a162eb..973da06 100755
--- a/run_tests.py
+++ b/run_tests.py
@@ -347,9 +347,10 @@ def main():
     if task:
       task.join()
 
-      for path, issues in task.result.items():
-        for issue in issues:
-          static_check_issues.setdefault(path, []).append(issue)
+      if task.result:
+        for path, issues in task.result.items():
+          for issue in issues:
+            static_check_issues.setdefault(path, []).append(issue)
     elif not task.is_available and task.unavailable_msg:
       println(task.unavailable_msg, ERROR)
 
diff --git a/stem/util/test_tools.py b/stem/util/test_tools.py
index 0d167d0..43befd4 100644
--- a/stem/util/test_tools.py
+++ b/stem/util/test_tools.py
@@ -39,6 +39,7 @@ import unittest
 
 import stem.prereq
 import stem.util.conf
+import stem.util.enum
 import stem.util.system
 
 CONFIG = stem.util.conf.config_dict('test', {
@@ -49,7 +50,13 @@ CONFIG = stem.util.conf.config_dict('test', {
 })
 
 TEST_RUNTIMES = {}
+ASYNC_TESTS = {}
 
+AsyncStatus = stem.util.enum.UppercaseEnum('PENDING', 'RUNNING', 'FINISHED')
+AsyncResult = collections.namedtuple('AsyncResult', 'type msg')
+
+# TODO: On python 2.6 we define our own SkipTest since unittest.case.SkipTest
+# isn't available there. Drop this when we remove python 2.6 support.
 
 if stem.prereq._is_python_26():
   class SkipTest(Exception):
@@ -58,6 +65,12 @@ else:
   SkipTest = unittest.case.SkipTest
 
 
+def asynchronous(func):
+  test = stem.util.test_tools.AsyncTest(func)
+  ASYNC_TESTS['%s.%s' % (func.__module__, func.__name__)] = test
+  return test.method
+
+
 class AsyncTest(object):
   """
   Test that's run asychronously. These are functions (no self reference)
@@ -80,51 +93,72 @@ class AsyncTest(object):
   .. versionadded:: 1.6.0
   """
 
-  def __init__(self, test_runner, args = None, threaded = False):
-    def _wrapper(conn, runner, test_args):
+  def __init__(self, runner, args = None, threaded = False):
+    self._runner = runner
+    self._runner_args = args
+    self._threaded = threaded
+
+    self.method = lambda test: self.result(test)  # method that can be mixed into TestCases
+    self.method.async = self
+
+    self._process = None
+    self._process_pipe = None
+    self._process_lock = threading.RLock()
+
+    self._result = None
+    self._status = AsyncStatus.PENDING
+
+  def run(self, *runner_args, **kwargs):
+    def _wrapper(conn, runner, args):
       try:
-        runner(*test_args) if test_args else runner()
-        conn.send(('success', None))
+        runner(*args) if args else runner()
+        conn.send(AsyncResult('success', None))
       except AssertionError as exc:
-        conn.send(('failure', str(exc)))
+        conn.send(AsyncResult('failure', str(exc)))
       except SkipTest as exc:
-        conn.send(('skipped', str(exc)))
+        conn.send(AsyncResult('skipped', str(exc)))
       finally:
         conn.close()
 
-    self.method = lambda test: self.result(test)  # method that can be mixed into TestCases
+    with self._process_lock:
+      if self._status == AsyncStatus.PENDING:
+        if runner_args:
+          self._runner_args = runner_args
 
-    self._result_type, self._result_msg = None, None
-    self._result_lock = threading.RLock()
-    self._results_pipe, child_pipe = multiprocessing.Pipe()
+        if 'threaded' in kwargs:
+          self._threaded = kwargs['threaded']
 
-    if threaded:
-      self._test_process = threading.Thread(target = _wrapper, args = (child_pipe, test_runner, args))
-    else:
-      self._test_process = multiprocessing.Process(target = _wrapper, args = (child_pipe, test_runner, args))
+        self._process_pipe, child_pipe = multiprocessing.Pipe()
+
+        if self._threaded:
+          self._process = threading.Thread(target = _wrapper, args = (child_pipe, self._runner, self._runner_args))
+        else:
+          self._process = multiprocessing.Process(target = _wrapper, args = (child_pipe, self._runner, self._runner_args))
 
-    self._test_process.start()
+        self._process.start()
+        self._status = AsyncStatus.RUNNING
 
   def pid(self):
-    with self._result_lock:
-      return self._test_process.pid if self._test_process else None
+    with self._process_lock:
+      return self._process.pid if (self._process and not self._threaded) else None
 
   def join(self):
     self.result(None)
 
   def result(self, test):
-    with self._result_lock:
-      if self._test_process:
-        self._result_type, self._result_msg = self._results_pipe.recv()
-        self._test_process.join()
-        self._test_process = None
-
-      if not test:
-        return
-      elif self._result_type == 'failure':
-        test.fail(self._result_msg)
-      elif self._result_type == 'skipped':
-        test.skipTest(self._result_msg)
+    with self._process_lock:
+      if self._status == AsyncStatus.PENDING:
+        self.run()
+
+      if self._status == AsyncStatus.RUNNING:
+        self._result = self._process_pipe.recv()
+        self._process.join()
+        self._status = AsyncStatus.FINISHED
+
+      if test and self._result.type == 'failure':
+        test.fail(self._result.msg)
+      elif test and self._result.type == 'skipped':
+        test.skipTest(self._result.msg)
 
 
 class Issue(collections.namedtuple('Issue', ['line_number', 'message', 'line'])):
diff --git a/test/integ/descriptor/extrainfo_descriptor.py b/test/integ/descriptor/extrainfo_descriptor.py
index e2350ce..0e13435 100644
--- a/test/integ/descriptor/extrainfo_descriptor.py
+++ b/test/integ/descriptor/extrainfo_descriptor.py
@@ -8,15 +8,16 @@ import unittest
 import stem.descriptor
 import stem.util.test_tools
 import test
-import test.require
+
+from stem.util.test_tools import asynchronous
 
 
 class TestExtraInfoDescriptor(unittest.TestCase):
   @staticmethod
   def run_tests(test_dir):
-    TestExtraInfoDescriptor.test_cached_descriptor = stem.util.test_tools.AsyncTest(TestExtraInfoDescriptor.test_cached_descriptor, args = (test_dir,), threaded = True).method
+    stem.util.test_tools.ASYNC_TESTS['test.integ.descriptor.extrainfo_descriptor.test_cached_descriptor'].run(test_dir, threaded = True)
 
-  @staticmethod
+  @asynchronous
   def test_cached_descriptor(test_dir):
     """
     Parses the cached descriptor file in our data directory, checking that it
diff --git a/test/integ/descriptor/microdescriptor.py b/test/integ/descriptor/microdescriptor.py
index cc48fce..8987906 100644
--- a/test/integ/descriptor/microdescriptor.py
+++ b/test/integ/descriptor/microdescriptor.py
@@ -9,13 +9,15 @@ import stem.descriptor
 import stem.util.test_tools
 import test
 
+from stem.util.test_tools import asynchronous
+
 
 class TestMicrodescriptor(unittest.TestCase):
   @staticmethod
   def run_tests(test_dir):
-    TestMicrodescriptor.test_cached_microdescriptors = stem.util.test_tools.AsyncTest(TestMicrodescriptor.test_cached_microdescriptors, args = (test_dir,), threaded = True).method
+    stem.util.test_tools.ASYNC_TESTS['test.integ.descriptor.microdescriptor.test_cached_microdescriptors'].run(test_dir, threaded = True)
 
-  @staticmethod
+  @asynchronous
   def test_cached_microdescriptors(test_dir):
     """
     Parses the cached microdescriptor file in our data directory, checking that
diff --git a/test/integ/descriptor/networkstatus.py b/test/integ/descriptor/networkstatus.py
index 46d7654..dee8e7a 100644
--- a/test/integ/descriptor/networkstatus.py
+++ b/test/integ/descriptor/networkstatus.py
@@ -14,12 +14,14 @@ import test
 import test.require
 import test.runner
 
+from stem.util.test_tools import asynchronous
+
 
 class TestNetworkStatus(unittest.TestCase):
   @staticmethod
   def run_tests(test_dir):
-    TestNetworkStatus.test_cached_consensus = stem.util.test_tools.AsyncTest(TestNetworkStatus.test_cached_consensus, args = (test_dir,), threaded = True).method
-    TestNetworkStatus.test_cached_microdesc_consensus = stem.util.test_tools.AsyncTest(TestNetworkStatus.test_cached_microdesc_consensus, args = (test_dir,), threaded = True).method
+    stem.util.test_tools.ASYNC_TESTS['test.integ.descriptor.networkstatus.test_cached_consensus'].run(test_dir, threaded = True)
+    stem.util.test_tools.ASYNC_TESTS['test.integ.descriptor.networkstatus.test_cached_microdesc_consensus'].run(test_dir, threaded = True)
 
   @test.require.only_run_once
   @test.require.online
@@ -32,7 +34,7 @@ class TestNetworkStatus(unittest.TestCase):
 
     stem.descriptor.remote.get_consensus(document_handler = stem.descriptor.DocumentHandler.DOCUMENT, validate = True).run()
 
-  @staticmethod
+  @asynchronous
   def test_cached_consensus(test_dir):
     """
     Parses the cached-consensus file in our data directory.
@@ -68,7 +70,7 @@ class TestNetworkStatus(unittest.TestCase):
     if count < 100:
       raise AssertionError('%s only included %s relays' % (consensus_path, count))
 
-  @staticmethod
+  @asynchronous
   def test_cached_microdesc_consensus(test_dir):
     """
     Parses the cached-microdesc-consensus file in our data directory.
diff --git a/test/integ/descriptor/server_descriptor.py b/test/integ/descriptor/server_descriptor.py
index 0726c35..a4a78ac 100644
--- a/test/integ/descriptor/server_descriptor.py
+++ b/test/integ/descriptor/server_descriptor.py
@@ -9,13 +9,15 @@ import stem.descriptor
 import stem.util.test_tools
 import test
 
+from stem.util.test_tools import asynchronous
+
 
 class TestServerDescriptor(unittest.TestCase):
   @staticmethod
   def run_tests(test_dir):
-    TestServerDescriptor.test_cached_descriptor = stem.util.test_tools.AsyncTest(TestServerDescriptor.test_cached_descriptor, args = (test_dir,), threaded = True).method
+    stem.util.test_tools.ASYNC_TESTS['test.integ.descriptor.server_descriptor.test_cached_descriptor'].run(test_dir, threaded = True)
 
-  @staticmethod
+  @asynchronous
   def test_cached_descriptor(test_dir):
     """
     Parses the cached descriptor file in our data directory, checking that it
diff --git a/test/integ/installation.py b/test/integ/installation.py
index 0253efd..82bc717 100644
--- a/test/integ/installation.py
+++ b/test/integ/installation.py
@@ -15,6 +15,8 @@ import stem.util.system
 import stem.util.test_tools
 import test
 
+from stem.util.test_tools import asynchronous
+
 BASE_INSTALL_PATH = '/tmp/stem_test'
 DIST_PATH = os.path.join(test.STEM_BASE, 'dist')
 PYTHON_EXE = sys.executable if sys.executable else 'python'
@@ -56,11 +58,11 @@ def _assert_has_all_files(path):
 class TestInstallation(unittest.TestCase):
   @staticmethod
   def run_tests():
-    test_install = stem.util.test_tools.AsyncTest(TestInstallation.test_install)
-    TestInstallation.test_install = test_install.method
-    TestInstallation.test_sdist = stem.util.test_tools.AsyncTest(TestInstallation.test_sdist, args = (test_install.pid(),)).method
+    test_install = stem.util.test_tools.ASYNC_TESTS['test.integ.installation.test_install']
+    test_install.run()
+    stem.util.test_tools.ASYNC_TESTS['test.integ.installation.test_sdist'].run(test_install.pid())
 
-  @staticmethod
+  @asynchronous
   def test_install():
     """
     Installs with 'python setup.py install' and checks we can use what we
@@ -89,7 +91,7 @@ class TestInstallation(unittest.TestCase):
       if os.path.exists(BASE_INSTALL_PATH):
         shutil.rmtree(BASE_INSTALL_PATH)
 
-  @staticmethod
+  @asynchronous
   def test_sdist(dependency_pid):
     """
     Creates a source distribution tarball with 'python setup.py sdist' and
diff --git a/test/integ/process.py b/test/integ/process.py
index 4e59e00..0aa639d 100644
--- a/test/integ/process.py
+++ b/test/integ/process.py
@@ -26,6 +26,8 @@ import test
 import test.require
 import test.runner
 
+from stem.util.test_tools import asynchronous
+
 try:
   # added in python 3.3
   from unittest.mock import patch, Mock
@@ -80,28 +82,9 @@ def run_tor(tor_cmd, *args, **kwargs):
 class TestProcess(unittest.TestCase):
   @staticmethod
   def run_tests(tor_cmd):
-    async_tests = (
-      'test_version_argument',
-      'test_help_argument',
-      'test_quiet_argument',
-      'test_hush_argument',
-      'test_hash_password',
-      'test_hash_password_requires_argument',
-      'test_list_torrc_options_argument',
-      'test_torrc_arguments',
-      'test_torrc_arguments_via_stdin',
-      'test_with_missing_torrc',
-      'test_can_run_multithreaded',
-      'test_launch_tor_with_config_via_file',
-      'test_launch_tor_with_config_via_stdin',
-      'test_with_invalid_config',
-      'test_launch_tor_with_timeout',
-      'test_take_ownership_via_pid',
-      'test_take_ownership_via_controller',
-    )
-
-    for func in async_tests:
-      setattr(TestProcess, func, stem.util.test_tools.AsyncTest(getattr(TestProcess, func), args = (tor_cmd,)).method)
+    for func, async_test in stem.util.test_tools.ASYNC_TESTS.items():
+      if func.startswith('test.integ.process.'):
+        async_test.run(tor_cmd)
 
   def setUp(self):
     self.data_directory = tempfile.mkdtemp()
@@ -111,7 +94,7 @@ class TestProcess(unittest.TestCase):
   def tearDown(self):
     shutil.rmtree(self.data_directory)
 
-  @staticmethod
+  @asynchronous
   def test_version_argument(tor_cmd):
     """
     Check that 'tor --version' matches 'GETINFO version'.
@@ -122,7 +105,7 @@ class TestProcess(unittest.TestCase):
     if 'Tor version %s.\n' % test.tor_version() != version_output:
       raise AssertionError('Unexpected response: %s' % version_output)
 
-  @staticmethod
+  @asynchronous
   def test_help_argument(tor_cmd):
     """
     Check that 'tor --help' provides the expected output.
@@ -136,7 +119,7 @@ class TestProcess(unittest.TestCase):
     if help_output != run_tor(tor_cmd, '-h'):
       raise AssertionError("'tor -h' should simply be an alias for 'tor --help'")
 
-  @staticmethod
+  @asynchronous
   def test_quiet_argument(tor_cmd):
     """
     Check that we don't provide anything on stdout when running 'tor --quiet'.
@@ -145,7 +128,7 @@ class TestProcess(unittest.TestCase):
     if '' != run_tor(tor_cmd, '--quiet', '--invalid_argument', 'true', expect_failure = True):
       raise AssertionError('No output should be provided with the --quiet argument')
 
-  @staticmethod
+  @asynchronous
   def test_hush_argument(tor_cmd):
     """
     Check that we only get warnings and errors when running 'tor --hush'.
@@ -161,7 +144,7 @@ class TestProcess(unittest.TestCase):
     if "[warn] Failed to parse/validate config: Unknown option 'invalid_argument'.  Failing." not in output:
       raise AssertionError('Unexpected response: %s' % output)
 
-  @staticmethod
+  @asynchronous
   def test_hash_password(tor_cmd):
     """
     Hash a controller password. It's salted so can't assert that we get a
@@ -192,7 +175,7 @@ class TestProcess(unittest.TestCase):
     if hashlib.sha1(inp).digest() != hashed:
       raise AssertionError('Password hash not what we expected (%s rather than %s)' % (hashlib.sha1(inp).digest(), hashed))
 
-  @staticmethod
+  @asynchronous
   def test_hash_password_requires_argument(tor_cmd):
     """
     Check that 'tor --hash-password' balks if not provided with something to
@@ -259,7 +242,7 @@ class TestProcess(unittest.TestCase):
       expected = 'stemIntegTest %s\n' % fingerprint
       self.assertEqual(expected, fingerprint_file.read())
 
-  @staticmethod
+  @asynchronous
   def test_list_torrc_options_argument(tor_cmd):
     """
     Exercise our 'tor --list-torrc-options' argument.
@@ -295,7 +278,7 @@ class TestProcess(unittest.TestCase):
 
         self.assertEqual('nope', str(exc))
 
-  @staticmethod
+  @asynchronous
   def test_torrc_arguments(tor_cmd):
     """
     Pass configuration options on the commandline.
@@ -334,7 +317,7 @@ class TestProcess(unittest.TestCase):
     finally:
       shutil.rmtree(data_directory)
 
-  @staticmethod
+  @asynchronous
   def test_torrc_arguments_via_stdin(tor_cmd):
     """
     Pass configuration options via stdin.
@@ -354,7 +337,7 @@ class TestProcess(unittest.TestCase):
     finally:
       shutil.rmtree(data_directory)
 
-  @staticmethod
+  @asynchronous
   def test_with_missing_torrc(tor_cmd):
     """
     Provide a torrc path that doesn't exist.
@@ -370,7 +353,7 @@ class TestProcess(unittest.TestCase):
     if '[notice] Configuration file "/path/that/really/shouldnt/exist" not present, using reasonable defaults.' not in output:
       raise AssertionError('Missing torrc should be allowed with --ignore-missing-torrc')
 
-  @staticmethod
+  @asynchronous
   def test_can_run_multithreaded(tor_cmd):
     """
     Our launch_tor() function uses signal to support its timeout argument.
@@ -422,7 +405,7 @@ class TestProcess(unittest.TestCase):
     finally:
       shutil.rmtree(data_directory)
 
-  @staticmethod
+  @asynchronous
   def test_launch_tor_with_config_via_file(tor_cmd):
     """
     Exercises launch_tor_with_config when we write a torrc to disk.
@@ -466,7 +449,7 @@ class TestProcess(unittest.TestCase):
 
       shutil.rmtree(data_directory)
 
-  @staticmethod
+  @asynchronous
   def test_launch_tor_with_config_via_stdin(tor_cmd):
     """
     Exercises launch_tor_with_config when we provide our torrc via stdin.
@@ -509,7 +492,7 @@ class TestProcess(unittest.TestCase):
 
       shutil.rmtree(data_directory)
 
-  @staticmethod
+  @asynchronous
   def test_with_invalid_config(tor_cmd):
     """
     Spawn a tor process with a configuration that should make it dead on arrival.
@@ -540,7 +523,7 @@ class TestProcess(unittest.TestCase):
     finally:
       shutil.rmtree(data_directory)
 
-  @staticmethod
+  @asynchronous
   def test_launch_tor_with_timeout(tor_cmd):
     """
     Runs launch_tor where it times out before completing.
@@ -568,7 +551,7 @@ class TestProcess(unittest.TestCase):
     finally:
       shutil.rmtree(data_directory)
 
-  @staticmethod
+  @asynchronous
   def test_take_ownership_via_pid(tor_cmd):
     """
     Checks that the tor process quits after we do if we set take_ownership. To
@@ -618,7 +601,7 @@ class TestProcess(unittest.TestCase):
     finally:
       shutil.rmtree(data_directory)
 
-  @staticmethod
+  @asynchronous
   def test_take_ownership_via_controller(tor_cmd):
     """
     Checks that the tor process quits after the controller that owns it





More information about the tor-commits mailing list