[tor-commits] [stem/master] @require_controller decorator

atagar at torproject.org atagar at torproject.org
Sat Feb 21 22:08:00 UTC 2015


commit b8c5893617e825b5ff7eed4de9c245921fa60e97
Author: Damian Johnson <atagar at torproject.org>
Date:   Sat Feb 21 13:21:32 2015 -0800

    @require_controller decorator
    
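    Replacing our require_control() function with a @require_controller
    decorator so tests no longer need an 'if require_control(self): return'
    guard at the top of each method. Yay, nicer code!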
---
 test/integ/connection/authentication.py |   48 +++------
 test/integ/connection/connect.py        |   14 +--
 test/integ/control/base_controller.py   |   39 +++-----
 test/integ/control/controller.py        |  164 ++++++++++---------------------
 test/integ/process.py                   |    6 +-
 test/integ/response/protocolinfo.py     |   17 +---
 test/integ/socket/control_message.py    |   27 ++---
 test/integ/socket/control_socket.py     |   30 ++----
 test/integ/version.py                   |   10 +-
 test/runner.py                          |   18 ++--
 10 files changed, 126 insertions(+), 247 deletions(-)

diff --git a/test/integ/connection/authentication.py b/test/integ/connection/authentication.py
index 816e671..3687af3 100644
--- a/test/integ/connection/authentication.py
+++ b/test/integ/connection/authentication.py
@@ -11,6 +11,8 @@ import stem.socket
 import stem.version
 import test.runner
 
+from test.runner import require_controller
+
 # Responses given by tor for various authentication failures. These may change
 # in the future and if they do then this test should be updated.
 
@@ -105,40 +107,36 @@ class TestAuthenticate(unittest.TestCase):
     if tor_version >= stem.version.Requirement.AUTH_SAFECOOKIE:
       self.cookie_auth_methods.append(stem.connection.AuthMethod.SAFECOOKIE)
 
+  @require_controller
   def test_authenticate_general_socket(self):
     """
     Tests that the authenticate function can authenticate to our socket.
     """
 
-    if test.runner.require_control(self):
-      return
-
     runner = test.runner.get_runner()
+
     with runner.get_tor_socket(False) as control_socket:
       stem.connection.authenticate(control_socket, test.runner.CONTROL_PASSWORD, runner.get_chroot())
       test.runner.exercise_controller(self, control_socket)
 
+  @require_controller
   def test_authenticate_general_controller(self):
     """
     Tests that the authenticate function can authenticate via a Controller.
     """
 
-    if test.runner.require_control(self):
-      return
-
     runner = test.runner.get_runner()
+
     with runner.get_tor_controller(False) as controller:
       stem.connection.authenticate(controller, test.runner.CONTROL_PASSWORD, runner.get_chroot())
       test.runner.exercise_controller(self, controller)
 
+  @require_controller
   def test_authenticate_general_example(self):
     """
     Tests the authenticate function with something like its pydoc example.
     """
 
-    if test.runner.require_control(self):
-      return
-
     runner = test.runner.get_runner()
     tor_options = runner.get_options()
 
@@ -169,14 +167,12 @@ class TestAuthenticate(unittest.TestCase):
     finally:
       control_socket.close()
 
+  @require_controller
   def test_authenticate_general_password(self):
     """
     Tests the authenticate function's password argument.
     """
 
-    if test.runner.require_control(self):
-      return
-
     # this is a much better test if we're just using password auth, since
     # authenticate will work reguardless if there's something else to
     # authenticate with
@@ -206,6 +202,7 @@ class TestAuthenticate(unittest.TestCase):
       stem.connection.authenticate(control_socket, test.runner.CONTROL_PASSWORD, runner.get_chroot())
       test.runner.exercise_controller(self, control_socket)
 
+  @require_controller
   def test_authenticate_general_cookie(self):
     """
     Tests the authenticate function with only cookie authentication methods.
@@ -213,9 +210,6 @@ class TestAuthenticate(unittest.TestCase):
     individually.
     """
 
-    if test.runner.require_control(self):
-      return
-
     runner = test.runner.get_runner()
     tor_options = runner.get_options()
     is_cookie_only = test.runner.Torrc.COOKIE in tor_options and test.runner.Torrc.PASSWORD not in tor_options
@@ -233,14 +227,12 @@ class TestAuthenticate(unittest.TestCase):
             protocolinfo_response.auth_methods = (method, )
             stem.connection.authenticate(control_socket, chroot_path = runner.get_chroot(), protocolinfo_response = protocolinfo_response)
 
+  @require_controller
   def test_authenticate_none(self):
     """
     Tests the authenticate_none function.
     """
 
-    if test.runner.require_control(self):
-      return
-
     auth_type = stem.connection.AuthMethod.NONE
 
     if _can_authenticate(auth_type):
@@ -248,14 +240,12 @@ class TestAuthenticate(unittest.TestCase):
     else:
       self.assertRaises(stem.connection.OpenAuthRejected, self._check_auth, auth_type)
 
+  @require_controller
   def test_authenticate_password(self):
     """
     Tests the authenticate_password function.
     """
 
-    if test.runner.require_control(self):
-      return
-
     auth_type = stem.connection.AuthMethod.PASSWORD
     auth_value = test.runner.CONTROL_PASSWORD
 
@@ -278,14 +268,12 @@ class TestAuthenticate(unittest.TestCase):
 
         self.assertRaises(exc_type, self._check_auth, auth_type, auth_value)
 
+  @require_controller
   def test_authenticate_cookie(self):
     """
     Tests the authenticate_cookie function.
     """
 
-    if test.runner.require_control(self):
-      return
-
     auth_value = test.runner.get_runner().get_auth_cookie_path()
 
     for auth_type in self.cookie_auth_methods:
@@ -302,15 +290,13 @@ class TestAuthenticate(unittest.TestCase):
       else:
         self.assertRaises(stem.connection.CookieAuthRejected, self._check_auth, auth_type, auth_value, False)
 
+  @require_controller
   def test_authenticate_cookie_invalid(self):
     """
     Tests the authenticate_cookie function with a properly sized but incorrect
     value.
     """
 
-    if test.runner.require_control(self):
-      return
-
     auth_value = test.runner.get_runner().get_test_dir('fake_cookie')
 
     # we need to create a 32 byte cookie file to load from
@@ -341,19 +327,18 @@ class TestAuthenticate(unittest.TestCase):
 
     os.remove(auth_value)
 
+  @require_controller
   def test_authenticate_cookie_missing(self):
     """
     Tests the authenticate_cookie function with a path that really, really
     shouldn't exist.
     """
 
-    if test.runner.require_control(self):
-      return
-
     for auth_type in self.cookie_auth_methods:
       auth_value = "/if/this/exists/then/they're/asking/for/a/failure"
       self.assertRaises(stem.connection.UnreadableCookieFile, self._check_auth, auth_type, auth_value, False)
 
+  @require_controller
   def test_authenticate_cookie_wrong_size(self):
     """
     Tests the authenticate_cookie function with our torrc as an auth cookie.
@@ -361,9 +346,6 @@ class TestAuthenticate(unittest.TestCase):
     socket.
     """
 
-    if test.runner.require_control(self):
-      return
-
     auth_value = test.runner.get_runner().get_torrc_path(True)
 
     for auth_type in self.cookie_auth_methods:
diff --git a/test/integ/connection/connect.py b/test/integ/connection/connect.py
index c1785fd..b50fdfc 100644
--- a/test/integ/connection/connect.py
+++ b/test/integ/connection/connect.py
@@ -13,6 +13,8 @@ except ImportError:
 import stem.connection
 import test.runner
 
+from test.runner import require_controller
+
 
 class TestConnect(unittest.TestCase):
   def setUp(self):
@@ -23,14 +25,12 @@ class TestConnect(unittest.TestCase):
   def tearDown(self):
     sys.stdout = self.original_stdout
 
+  @require_controller
   def test_connect(self):
     """
     Basic sanity checks for the connect function.
     """
 
-    if test.runner.require_control(self):
-      return
-
     runner = test.runner.get_runner()
 
     control_socket = stem.connection.connect(
@@ -42,14 +42,12 @@ class TestConnect(unittest.TestCase):
 
     test.runner.exercise_controller(self, control_socket)
 
+  @require_controller
   def test_connect_port(self):
     """
     Basic sanity checks for the connect_port function.
     """
 
-    if test.runner.require_control(self):
-      return
-
     runner = test.runner.get_runner()
 
     control_socket = stem.connection.connect_port(
@@ -64,14 +62,12 @@ class TestConnect(unittest.TestCase):
     else:
       self.assertEqual(control_socket, None)
 
+  @require_controller
   def test_connect_socket_file(self):
     """
     Basic sanity checks for the connect_socket_file function.
     """
 
-    if test.runner.require_control(self):
-      return
-
     runner = test.runner.get_runner()
 
     control_socket = stem.connection.connect_socket_file(
diff --git a/test/integ/control/base_controller.py b/test/integ/control/base_controller.py
index 9f0e2aa..48dd3df 100644
--- a/test/integ/control/base_controller.py
+++ b/test/integ/control/base_controller.py
@@ -8,10 +8,13 @@ import time
 import unittest
 
 import stem.control
-import test.runner
 import stem.socket
 import stem.util.system
 
+import test.runner
+
+from test.runner import require_controller
+
 
 class StateObserver(object):
   """
@@ -35,15 +38,14 @@ class StateObserver(object):
 
 
 class TestBaseController(unittest.TestCase):
+  @require_controller
   def test_connect_repeatedly(self):
     """
     Connects and closes the socket repeatedly. This is a simple attempt to
     trigger concurrency issues.
     """
 
-    if test.runner.require_control(self):
-      return
-    elif stem.util.system.is_mac():
+    if stem.util.system.is_mac():
       test.runner.skip(self, '(ticket #6235)')
       return
 
@@ -54,53 +56,46 @@ class TestBaseController(unittest.TestCase):
         controller.connect()
         controller.close()
 
+  @require_controller
   def test_msg(self):
     """
     Tests a basic query with the msg() method.
     """
 
-    if test.runner.require_control(self):
-      return
-
     with test.runner.get_runner().get_tor_socket() as control_socket:
       controller = stem.control.BaseController(control_socket)
       test.runner.exercise_controller(self, controller)
 
+  @require_controller
   def test_msg_invalid(self):
     """
     Tests the msg() method against an invalid controller command.
     """
 
-    if test.runner.require_control(self):
-      return
-
     with test.runner.get_runner().get_tor_socket() as control_socket:
       controller = stem.control.BaseController(control_socket)
       response = controller.msg('invalid')
       self.assertEqual('Unrecognized command "invalid"', str(response))
 
+  @require_controller
   def test_msg_invalid_getinfo(self):
     """
     Tests the msg() method against a non-existant GETINFO option.
     """
 
-    if test.runner.require_control(self):
-      return
-
     with test.runner.get_runner().get_tor_socket() as control_socket:
       controller = stem.control.BaseController(control_socket)
       response = controller.msg('GETINFO blarg')
       self.assertEqual('Unrecognized key "blarg"', str(response))
 
+  @require_controller
   def test_msg_repeatedly(self):
     """
     Connects, sends a burst of messages, and closes the socket repeatedly. This
     is a simple attempt to trigger concurrency issues.
     """
 
-    if test.runner.require_control(self):
-      return
-    elif stem.util.system.is_mac():
+    if stem.util.system.is_mac():
       test.runner.skip(self, '(ticket #6235)')
       return
 
@@ -131,6 +126,7 @@ class TestBaseController(unittest.TestCase):
       for msg_thread in message_threads:
         msg_thread.join()
 
+  @require_controller
   def test_asynchronous_event_handling(self):
     """
     Check that we can both receive asynchronous events while hammering our
@@ -138,9 +134,6 @@ class TestBaseController(unittest.TestCase):
     listeners will still receive all of the enqueued events.
     """
 
-    if test.runner.require_control(self):
-      return
-
     class ControlledListener(stem.control.BaseController):
       """
       Controller that blocks event handling until told to do so.
@@ -189,29 +182,25 @@ class TestBaseController(unittest.TestCase):
         self.assertTrue(re.match('650 BW [0-9]+ [0-9]+\r\n', bw_event.raw_content()))
         self.assertEqual(('650', ' '), bw_event.content()[0][:2])
 
+  @require_controller
   def test_get_latest_heartbeat(self):
     """
     Basic check for get_latest_heartbeat().
     """
 
-    if test.runner.require_control(self):
-      return
-
     # makes a getinfo query, then checks that the heartbeat is close to now
     with test.runner.get_runner().get_tor_socket() as control_socket:
       controller = stem.control.BaseController(control_socket)
       controller.msg('GETINFO version')
       self.assertTrue((time.time() - controller.get_latest_heartbeat()) < 5)
 
+  @require_controller
   def test_status_notifications(self):
     """
     Checks basic functionality of the add_status_listener() and
     remove_status_listener() methods.
     """
 
-    if test.runner.require_control(self):
-      return
-
     state_observer = StateObserver()
 
     with test.runner.get_runner().get_tor_socket(False) as control_socket:
diff --git a/test/integ/control/controller.py b/test/integ/control/controller.py
index 4b2cbfc..632c002 100644
--- a/test/integ/control/controller.py
+++ b/test/integ/control/controller.py
@@ -26,6 +26,8 @@ from stem.control import EventType, Listener, State
 from stem.exit_policy import ExitPolicy
 from stem.version import Requirement
 
+from test.runner import require_controller
+
 # Router status entry for a relay with a nickname other than 'Unnamed'. This is
 # used for a few tests that need to look up a relay.
 
@@ -42,9 +44,6 @@ class TestController(unittest.TestCase):
     Basic sanity check for the from_port constructor.
     """
 
-    if test.runner.require_control(self):
-      return
-
     if test.runner.Torrc.PORT in test.runner.get_runner().get_options():
       with stem.control.Controller.from_port(port = test.runner.CONTROL_PORT) as controller:
         self.assertTrue(isinstance(controller, stem.control.Controller))
@@ -56,23 +55,19 @@ class TestController(unittest.TestCase):
     Basic sanity check for the from_socket_file constructor.
     """
 
-    if test.runner.require_control(self):
-      return
-
     if test.runner.Torrc.SOCKET in test.runner.get_runner().get_options():
       with stem.control.Controller.from_socket_file(path = test.runner.CONTROL_SOCKET_PATH) as controller:
         self.assertTrue(isinstance(controller, stem.control.Controller))
     else:
       self.assertRaises(stem.SocketError, stem.control.Controller.from_socket_file, test.runner.CONTROL_SOCKET_PATH)
 
+  @require_controller
   def test_reset_notification(self):
     """
     Checks that a notificiation listener is... well, notified of SIGHUPs.
     """
 
-    if test.runner.require_control(self):
-      return
-    elif test.runner.require_version(self, stem.version.Requirement.EVENT_SIGNAL):
+    if test.runner.require_version(self, stem.version.Requirement.EVENT_SIGNAL):
       return
 
     with test.runner.get_runner().get_tor_controller() as controller:
@@ -105,15 +100,13 @@ class TestController(unittest.TestCase):
 
       controller.reset_conf('__OwningControllerProcess')
 
+  @require_controller
   def test_event_handling(self):
     """
     Add a couple listeners for various events and make sure that they receive
     them. Then remove the listeners.
     """
 
-    if test.runner.require_control(self):
-      return
-
     event_notice1, event_notice2 = threading.Event(), threading.Event()
     event_buffer1, event_buffer2 = [], []
 
@@ -163,15 +156,13 @@ class TestController(unittest.TestCase):
         self.assertTrue(hasattr(event, 'read'))
         self.assertTrue(hasattr(event, 'written'))
 
+  @require_controller
   def test_reattaching_listeners(self):
     """
     Checks that event listeners are re-attached when a controller disconnects
     then reconnects to tor.
     """
 
-    if test.runner.require_control(self):
-      return
-
     event_notice = threading.Event()
     event_buffer = []
 
@@ -243,14 +234,12 @@ class TestController(unittest.TestCase):
         event_notice.wait(4)
         self.assertTrue(len(event_buffer) >= 1)
 
+  @require_controller
   def test_getinfo(self):
     """
     Exercises GETINFO with valid and invalid queries.
     """
 
-    if test.runner.require_control(self):
-      return
-
     runner = test.runner.get_runner()
 
     with runner.get_tor_controller() as controller:
@@ -283,14 +272,12 @@ class TestController(unittest.TestCase):
       self.assertEqual({}, controller.get_info([]))
       self.assertEqual({}, controller.get_info([], {}))
 
+  @require_controller
   def test_get_version(self):
     """
     Test that the convenient method get_version() works.
     """
 
-    if test.runner.require_control(self):
-      return
-
     runner = test.runner.get_runner()
 
     with runner.get_tor_controller() as controller:
@@ -298,15 +285,13 @@ class TestController(unittest.TestCase):
       self.assertTrue(isinstance(version, stem.version.Version))
       self.assertEqual(version, runner.get_tor_version())
 
+  @require_controller
   def test_get_exit_policy(self):
     """
     Sanity test for get_exit_policy(). We have the default policy (no
     ExitPolicy set) which is a little... long due to the boilerplate.
     """
 
-    if test.runner.require_control(self):
-      return
-
     expected = ExitPolicy(
       'reject 0.0.0.0/8:*',
       'reject 169.254.0.0/16:*',
@@ -343,28 +328,24 @@ class TestController(unittest.TestCase):
       policy_str = policy_str[:public_addr_start] + policy_str[public_addr_end:]
       self.assertEqual(str(expected), policy_str)
 
+  @require_controller
   def test_authenticate(self):
     """
     Test that the convenient method authenticate() works.
     """
 
-    if test.runner.require_control(self):
-      return
-
     runner = test.runner.get_runner()
 
     with runner.get_tor_controller(False) as controller:
       controller.authenticate(test.runner.CONTROL_PASSWORD)
       test.runner.exercise_controller(self, controller)
 
+  @require_controller
   def test_protocolinfo(self):
     """
     Test that the convenient method protocolinfo() works.
     """
 
-    if test.runner.require_control(self):
-      return
-
     runner = test.runner.get_runner()
 
     with runner.get_tor_controller(False) as controller:
@@ -390,14 +371,12 @@ class TestController(unittest.TestCase):
 
       self.assertEqual(tuple(auth_methods), protocolinfo.auth_methods)
 
+  @require_controller
   def test_getconf(self):
     """
     Exercises GETCONF with valid and invalid queries.
     """
 
-    if test.runner.require_control(self):
-      return
-
     runner = test.runner.get_runner()
 
     with runner.get_tor_controller() as controller:
@@ -455,15 +434,13 @@ class TestController(unittest.TestCase):
       self.assertEqual({}, controller.get_conf_map('', 'la-di-dah'))
       self.assertEqual({}, controller.get_conf_map([], 'la-di-dah'))
 
+  @require_controller
   def test_hidden_services_conf(self):
     """
     Exercises the hidden service family of methods (get_hidden_service_conf,
     set_hidden_service_conf, create_hidden_service, and remove_hidden_service).
     """
 
-    if test.runner.require_control(self):
-      return
-
     runner = test.runner.get_runner()
 
     test_dir = runner.get_test_dir()
@@ -548,15 +525,13 @@ class TestController(unittest.TestCase):
           except:
             pass
 
+  @require_controller
   def test_set_conf(self):
     """
     Exercises set_conf(), reset_conf(), and set_options() methods with valid
     and invalid requests.
     """
 
-    if test.runner.require_control(self):
-      return
-
     runner = test.runner.get_runner()
     tmpdir = tempfile.mkdtemp()
 
@@ -623,14 +598,13 @@ class TestController(unittest.TestCase):
 
         shutil.rmtree(tmpdir)
 
+  @require_controller
   def test_loadconf(self):
     """
     Exercises Controller.load_conf with valid and invalid requests.
     """
 
-    if test.runner.require_control(self):
-      return
-    elif test.runner.require_version(self, stem.version.Requirement.LOADCONF):
+    if test.runner.require_version(self, stem.version.Requirement.LOADCONF):
       return
 
     runner = test.runner.get_runner()
@@ -664,10 +638,8 @@ class TestController(unittest.TestCase):
         controller.load_conf(oldconf)
         controller.reset_conf('__OwningControllerProcess')
 
+  @require_controller
   def test_saveconf(self):
-    if test.runner.require_control(self):
-      return
-
     runner = test.runner.get_runner()
 
     # only testing for success, since we need to run out of disk space to test
@@ -686,14 +658,12 @@ class TestController(unittest.TestCase):
         controller.save_conf()
         controller.reset_conf('__OwningControllerProcess')
 
+  @require_controller
   def test_get_ports(self):
     """
     Test Controller.get_ports against a running tor instance.
     """
 
-    if test.runner.require_control(self):
-      return
-
     runner = test.runner.get_runner()
 
     with runner.get_tor_controller() as controller:
@@ -709,14 +679,12 @@ class TestController(unittest.TestCase):
       else:
         self.assertEqual([], controller.get_ports(Listener.CONTROL))
 
+  @require_controller
   def test_get_listeners(self):
     """
     Test Controller.get_listeners against a running tor instance.
     """
 
-    if test.runner.require_control(self):
-      return
-
     runner = test.runner.get_runner()
 
     with runner.get_tor_controller() as controller:
@@ -732,27 +700,21 @@ class TestController(unittest.TestCase):
       else:
         self.assertEqual([], controller.get_listeners(Listener.CONTROL))
 
+  @require_controller
   def test_get_socks_listeners(self):
     """
     Test Controller.get_socks_listeners against a running tor instance.
     """
 
-    if test.runner.require_control(self):
-      return
-
-    runner = test.runner.get_runner()
-
-    with runner.get_tor_controller() as controller:
+    with test.runner.get_runner().get_tor_controller() as controller:
       self.assertEqual([('127.0.0.1', 1112)], controller.get_socks_listeners())
 
+  @require_controller
   def test_enable_feature(self):
     """
     Test Controller.enable_feature with valid and invalid inputs.
     """
 
-    if test.runner.require_control(self):
-      return
-
     runner = test.runner.get_runner()
 
     with runner.get_tor_controller() as controller:
@@ -778,14 +740,12 @@ class TestController(unittest.TestCase):
       else:
         self.fail()
 
+  @require_controller
   def test_signal(self):
     """
     Test controller.signal with valid and invalid signals.
     """
 
-    if test.runner.require_control(self):
-      return
-
     with test.runner.get_runner().get_tor_controller() as controller:
       # valid signal
       controller.signal('CLEARDNSCACHE')
@@ -793,14 +753,12 @@ class TestController(unittest.TestCase):
       # invalid signals
       self.assertRaises(stem.InvalidArguments, controller.signal, 'FOOBAR')
 
+  @require_controller
   def test_newnym_availability(self):
     """
     Test the is_newnym_available and get_newnym_wait methods.
     """
 
-    if test.runner.require_control(self):
-      return
-
     with test.runner.get_runner().get_tor_controller() as controller:
       self.assertEqual(True, controller.is_newnym_available())
       self.assertEqual(0.0, controller.get_newnym_wait())
@@ -810,10 +768,9 @@ class TestController(unittest.TestCase):
       self.assertEqual(False, controller.is_newnym_available())
       self.assertTrue(controller.get_newnym_wait() > 9.0)
 
+  @require_controller
   def test_extendcircuit(self):
-    if test.runner.require_control(self):
-      return
-    elif test.runner.require_online(self):
+    if test.runner.require_online(self):
       return
     elif test.runner.require_version(self, Requirement.EXTENDCIRCUIT_PATH_OPTIONAL):
       return
@@ -830,14 +787,13 @@ class TestController(unittest.TestCase):
       self.assertRaises(stem.InvalidRequest, controller.extend_circuit, '0', 'thisroutershouldntexistbecausestemexists!@##$%#')
       self.assertRaises(stem.InvalidRequest, controller.extend_circuit, '0', 'thisroutershouldntexistbecausestemexists!@##$%#', 'foo')
 
+  @require_controller
   def test_repurpose_circuit(self):
     """
     Tests Controller.repurpose_circuit with valid and invalid input.
     """
 
-    if test.runner.require_control(self):
-      return
-    elif test.runner.require_online(self):
+    if test.runner.require_online(self):
       return
     elif test.runner.require_version(self, Requirement.EXTENDCIRCUIT_PATH_OPTIONAL):
       return
@@ -857,14 +813,13 @@ class TestController(unittest.TestCase):
       self.assertRaises(stem.InvalidRequest, controller.repurpose_circuit, 'f934h9f3h4', 'fooo')
       self.assertRaises(stem.InvalidRequest, controller.repurpose_circuit, '4', 'fooo')
 
+  @require_controller
   def test_close_circuit(self):
     """
     Tests Controller.close_circuit with valid and invalid input.
     """
 
-    if test.runner.require_control(self):
-      return
-    elif test.runner.require_online(self):
+    if test.runner.require_online(self):
       return
     elif test.runner.require_version(self, Requirement.EXTENDCIRCUIT_PATH_OPTIONAL):
       return
@@ -888,14 +843,13 @@ class TestController(unittest.TestCase):
       self.assertRaises(stem.InvalidArguments, controller.close_circuit, circuit_id + '1024')
       self.assertRaises(stem.InvalidRequest, controller.close_circuit, '')
 
+  @require_controller
   def test_get_streams(self):
     """
     Tests Controller.get_streams().
     """
 
-    if test.runner.require_control(self):
-      return
-    elif test.runner.require_online(self):
+    if test.runner.require_online(self):
       return
 
     host = socket.gethostbyname('www.torproject.org')
@@ -916,14 +870,13 @@ class TestController(unittest.TestCase):
 
     self.assertTrue('%s:%s' % (host, port) in [stream.target for stream in streams])
 
+  @require_controller
   def test_close_stream(self):
     """
     Tests Controller.close_stream with valid and invalid input.
     """
 
-    if test.runner.require_control(self):
-      return
-    elif test.runner.require_online(self):
+    if test.runner.require_online(self):
       return
 
     runner = test.runner.get_runner()
@@ -958,10 +911,9 @@ class TestController(unittest.TestCase):
 
       self.assertRaises(stem.InvalidArguments, controller.close_stream, 'blarg')
 
+  @require_controller
   def test_mapaddress(self):
-    if test.runner.require_control(self):
-      return
-    elif test.runner.require_online(self):
+    if test.runner.require_online(self):
       return
 
     runner = test.runner.get_runner()
@@ -999,14 +951,13 @@ class TestController(unittest.TestCase):
       ip_addr = response[response.find(b'\r\n\r\n'):].strip()
       self.assertTrue(stem.util.connection.is_valid_ipv4_address(stem.util.str_tools._to_unicode(ip_addr)))
 
+  @require_controller
   def test_get_microdescriptor(self):
     """
     Basic checks for get_microdescriptor().
     """
 
-    if test.runner.require_control(self):
-      return
-    elif test.runner.require_version(self, Requirement.MICRODESCRIPTOR_IS_DEFAULT):
+    if test.runner.require_version(self, Requirement.MICRODESCRIPTOR_IS_DEFAULT):
       return
     elif test.runner.require_online(self):
       return
@@ -1028,6 +979,7 @@ class TestController(unittest.TestCase):
 
       self.assertEqual(md_by_fingerprint, md_by_nickname)
 
+  @require_controller
   def test_get_microdescriptors(self):
     """
     Fetches a few descriptors via the get_microdescriptors() method.
@@ -1035,9 +987,7 @@ class TestController(unittest.TestCase):
 
     runner = test.runner.get_runner()
 
-    if test.runner.require_control(self):
-      return
-    elif not os.path.exists(runner.get_test_dir('cached-descriptors')):
+    if not os.path.exists(runner.get_test_dir('cached-descriptors')):
       test.runner.skip(self, '(no cached microdescriptors)')
       return
 
@@ -1051,6 +1001,7 @@ class TestController(unittest.TestCase):
         if count > 10:
           break
 
+  @require_controller
   def test_get_server_descriptor(self):
     """
     Basic checks for get_server_descriptor().
@@ -1058,9 +1009,7 @@ class TestController(unittest.TestCase):
 
     runner = test.runner.get_runner()
 
-    if test.runner.require_control(self):
-      return
-    elif runner.get_tor_version() >= Requirement.MICRODESCRIPTOR_IS_DEFAULT:
+    if runner.get_tor_version() >= Requirement.MICRODESCRIPTOR_IS_DEFAULT:
       test.runner.skip(self, '(requires server descriptors)')
       return
 
@@ -1082,6 +1031,7 @@ class TestController(unittest.TestCase):
 
       self.assertEqual(desc_by_fingerprint, desc_by_nickname)
 
+  @require_controller
   def test_get_server_descriptors(self):
     """
     Fetches a few descriptors via the get_server_descriptors() method.
@@ -1089,9 +1039,7 @@ class TestController(unittest.TestCase):
 
     runner = test.runner.get_runner()
 
-    if test.runner.require_control(self):
-      return
-    elif runner.get_tor_version() >= Requirement.MICRODESCRIPTOR_IS_DEFAULT:
+    if runner.get_tor_version() >= Requirement.MICRODESCRIPTOR_IS_DEFAULT:
       test.runner.skip(self, '(requires server descriptors)')
       return
 
@@ -1110,14 +1058,13 @@ class TestController(unittest.TestCase):
         if count > 10:
           break
 
+  @require_controller
   def test_get_network_status(self):
     """
     Basic checks for get_network_status().
     """
 
-    if test.runner.require_control(self):
-      return
-    elif test.runner.require_online(self):
+    if test.runner.require_online(self):
       return
 
     with test.runner.get_runner().get_tor_controller() as controller:
@@ -1137,6 +1084,7 @@ class TestController(unittest.TestCase):
 
       self.assertEqual(desc_by_fingerprint, desc_by_nickname)
 
+  @require_controller
   def test_get_network_statuses(self):
     """
     Fetches a few descriptors via the get_network_statuses() method.
@@ -1144,9 +1092,7 @@ class TestController(unittest.TestCase):
 
     runner = test.runner.get_runner()
 
-    if test.runner.require_control(self):
-      return
-    elif test.runner.require_online(self):
+    if test.runner.require_online(self):
       return
 
     with runner.get_tor_controller() as controller:
@@ -1165,10 +1111,9 @@ class TestController(unittest.TestCase):
         if count > 10:
           break
 
+  @require_controller
   def test_attachstream(self):
-    if test.runner.require_control(self):
-      return
-    elif test.runner.require_online(self):
+    if test.runner.require_online(self):
       return
     elif test.runner.require_version(self, Requirement.EXTENDCIRCUIT_PATH_OPTIONAL):
       return
@@ -1210,14 +1155,13 @@ class TestController(unittest.TestCase):
 
     self.assertEqual(our_stream.circ_id, circuit_id)
 
+  @require_controller
   def test_get_circuits(self):
     """
     Fetches circuits via the get_circuits() method.
     """
 
-    if test.runner.require_control(self):
-      return
-    elif test.runner.require_online(self):
+    if test.runner.require_online(self):
       return
     elif test.runner.require_version(self, Requirement.EXTENDCIRCUIT_PATH_OPTIONAL):
       return
@@ -1227,15 +1171,13 @@ class TestController(unittest.TestCase):
       circuits = controller.get_circuits()
       self.assertTrue(new_circ in [circ.id for circ in circuits])
 
+  @require_controller
   def test_transition_to_relay(self):
     """
     Transitions Tor to turn into a relay, then back to a client. This helps to
     catch transition issues such as the one cited in :trac:`14901`.
     """
 
-    if test.runner.require_control(self):
-      return
-
     with test.runner.get_runner().get_tor_controller() as controller:
       self.assertEqual(None, controller.get_conf('OrPort'))
 
diff --git a/test/integ/process.py b/test/integ/process.py
index 6c6be21..7b10828 100644
--- a/test/integ/process.py
+++ b/test/integ/process.py
@@ -21,6 +21,8 @@ import stem.util.tor_tools
 import stem.version
 import test.runner
 
+from test.runner import require_controller
+
 try:
   # added in python 3.3
   from unittest.mock import patch
@@ -44,14 +46,12 @@ class TestProcess(unittest.TestCase):
   def tearDown(self):
     shutil.rmtree(self.data_directory)
 
+  @require_controller
   def test_version_argument(self):
     """
     Check that 'tor --version' matches 'GETINFO version'.
     """
 
-    if test.runner.require_control(self):
-      return
-
     with test.runner.get_runner().get_tor_controller() as controller:
       self.assertEqual('Tor version %s.\n' % controller.get_version(), self.run_tor('--version'))
 
diff --git a/test/integ/response/protocolinfo.py b/test/integ/response/protocolinfo.py
index c9ca97d..b2ecd2c 100644
--- a/test/integ/response/protocolinfo.py
+++ b/test/integ/response/protocolinfo.py
@@ -11,6 +11,7 @@ import stem.util.system
 import stem.version
 import test.runner
 
+from test.runner import require_controller
 from test.integ.util.system import filter_system_call
 
 try:
@@ -21,15 +22,13 @@ except ImportError:
 
 
 class TestProtocolInfo(unittest.TestCase):
+  @require_controller
   def test_parsing(self):
     """
     Makes a PROTOCOLINFO query and processes the response for our control
     connection.
     """
 
-    if test.runner.require_control(self):
-      return
-
     control_socket = test.runner.get_runner().get_tor_socket(False)
     control_socket.send('PROTOCOLINFO 1')
     protocolinfo_response = control_socket.recv()
@@ -45,6 +44,7 @@ class TestProtocolInfo(unittest.TestCase):
 
     self.assert_matches_test_config(protocolinfo_response)
 
+  @require_controller
   @patch('stem.util.proc.is_available', Mock(return_value = False))
   @patch('stem.util.system.is_available', Mock(return_value = True))
   def test_get_protocolinfo_path_expansion(self):
@@ -58,9 +58,6 @@ class TestProtocolInfo(unittest.TestCase):
     with the 'RELATIVE' target.
     """
 
-    if test.runner.require_control(self):
-      return
-
     if test.runner.Torrc.PORT in test.runner.get_runner().get_options():
       lookup_prefixes = (
         stem.util.system.GET_PID_BY_PORT_NETSTAT,
@@ -90,6 +87,7 @@ class TestProtocolInfo(unittest.TestCase):
       self.assertTrue(control_socket.is_alive())
       control_socket.close()
 
+  @require_controller
   def test_multiple_protocolinfo_calls(self):
     """
     Tests making repeated PROTOCOLINFO queries. This use case is interesting
@@ -97,23 +95,18 @@ class TestProtocolInfo(unittest.TestCase):
     re-establish it.
     """
 
-    if test.runner.require_control(self):
-      return
-
     with test.runner.get_runner().get_tor_socket(False) as control_socket:
       for _ in range(5):
         protocolinfo_response = stem.connection.get_protocolinfo(control_socket)
         self.assert_matches_test_config(protocolinfo_response)
 
+  @require_controller
   def test_pre_disconnected_query(self):
     """
     Tests making a PROTOCOLINFO query when previous use of the socket had
     already disconnected it.
     """
 
-    if test.runner.require_control(self):
-      return
-
     with test.runner.get_runner().get_tor_socket(False) as control_socket:
       # makes a couple protocolinfo queries outside of get_protocolinfo first
       control_socket.send('PROTOCOLINFO 1')
diff --git a/test/integ/socket/control_message.py b/test/integ/socket/control_message.py
index a2eb6f2..e9faf84 100644
--- a/test/integ/socket/control_message.py
+++ b/test/integ/socket/control_message.py
@@ -9,16 +9,16 @@ import stem.socket
 import stem.version
 import test.runner
 
+from test.runner import require_controller
+
 
 class TestControlMessage(unittest.TestCase):
+  @require_controller
   def test_unestablished_socket(self):
     """
     Checks message parsing when we have a valid but unauthenticated socket.
     """
 
-    if test.runner.require_control(self):
-      return
-
     # If an unauthenticated connection gets a message besides AUTHENTICATE or
     # PROTOCOLINFO then tor will give an 'Authentication required.' message and
     # hang up.
@@ -54,14 +54,12 @@ class TestControlMessage(unittest.TestCase):
     self.assertRaises(stem.SocketClosed, control_socket.send, 'GETINFO version')
     self.assertRaises(stem.SocketClosed, control_socket.recv)
 
+  @require_controller
   def test_invalid_command(self):
     """
     Parses the response for a command which doesn't exist.
     """
 
-    if test.runner.require_control(self):
-      return
-
     with test.runner.get_runner().get_tor_socket() as control_socket:
       control_socket.send('blarg')
       unrecognized_command_response = control_socket.recv()
@@ -70,14 +68,12 @@ class TestControlMessage(unittest.TestCase):
       self.assertEqual('510 Unrecognized command "blarg"\r\n', unrecognized_command_response.raw_content())
       self.assertEqual([('510', ' ', 'Unrecognized command "blarg"')], unrecognized_command_response.content())
 
+  @require_controller
   def test_invalid_getinfo(self):
     """
     Parses the response for a GETINFO query which doesn't exist.
     """
 
-    if test.runner.require_control(self):
-      return
-
     with test.runner.get_runner().get_tor_socket() as control_socket:
       control_socket.send('GETINFO blarg')
       unrecognized_key_response = control_socket.recv()
@@ -86,14 +82,12 @@ class TestControlMessage(unittest.TestCase):
       self.assertEqual('552 Unrecognized key "blarg"\r\n', unrecognized_key_response.raw_content())
       self.assertEqual([('552', ' ', 'Unrecognized key "blarg"')], unrecognized_key_response.content())
 
+  @require_controller
   def test_getinfo_config_file(self):
     """
     Parses the 'GETINFO config-file' response.
     """
 
-    if test.runner.require_control(self):
-      return
-
     runner = test.runner.get_runner()
     torrc_dst = runner.get_torrc_path()
 
@@ -105,14 +99,13 @@ class TestControlMessage(unittest.TestCase):
       self.assertEqual('250-config-file=%s\r\n250 OK\r\n' % torrc_dst, config_file_response.raw_content())
       self.assertEqual([('250', '-', 'config-file=%s' % torrc_dst), ('250', ' ', 'OK')], config_file_response.content())
 
+  @require_controller
   def test_getinfo_config_text(self):
     """
     Parses the 'GETINFO config-text' response.
     """
 
-    if test.runner.require_control(self):
-      return
-    elif test.runner.require_version(self, stem.version.Requirement.GETINFO_CONFIG_TEXT):
+    if test.runner.require_version(self, stem.version.Requirement.GETINFO_CONFIG_TEXT):
       return
 
     runner = test.runner.get_runner()
@@ -150,14 +143,12 @@ class TestControlMessage(unittest.TestCase):
         self.assertTrue('%s\r\n' % torrc_entry in config_text_response.raw_content())
         self.assertTrue('%s' % torrc_entry in config_text_response.content()[0][2])
 
+  @require_controller
   def test_bw_event(self):
     """
     Issues 'SETEVENTS BW' and parses a couple events.
     """
 
-    if test.runner.require_control(self):
-      return
-
     with test.runner.get_runner().get_tor_socket() as control_socket:
       control_socket.send('SETEVENTS BW')
       setevents_response = control_socket.recv()
diff --git a/test/integ/socket/control_socket.py b/test/integ/socket/control_socket.py
index a45c47e..83851be 100644
--- a/test/integ/socket/control_socket.py
+++ b/test/integ/socket/control_socket.py
@@ -16,16 +16,16 @@ import stem.control
 import stem.socket
 import test.runner
 
+from test.runner import require_controller
+
 
 class TestControlSocket(unittest.TestCase):
+  @require_controller
   def test_connection_time(self):
     """
     Checks that our connection_time method tracks when our state's changed.
     """
 
-    if test.runner.require_control(self):
-      return
-
     test_start = time.time()
     runner = test.runner.get_runner()
 
@@ -54,14 +54,12 @@ class TestControlSocket(unittest.TestCase):
       reconnection_time = control_socket.connection_time()
       self.assertTrue(disconnection_time < reconnection_time <= time.time())
 
+  @require_controller
   def test_send_buffered(self):
     """
     Sends multiple requests before receiving back any of the replies.
     """
 
-    if test.runner.require_control(self):
-      return
-
     runner = test.runner.get_runner()
     tor_version = runner.get_tor_version()
 
@@ -74,14 +72,12 @@ class TestControlSocket(unittest.TestCase):
         self.assertTrue(str(response).startswith('version=%s' % tor_version))
         self.assertTrue(str(response).endswith('\nOK'))
 
+  @require_controller
   def test_send_closed(self):
     """
     Sends a message after we've closed the connection.
     """
 
-    if test.runner.require_control(self):
-      return
-
     with test.runner.get_runner().get_tor_socket() as control_socket:
       self.assertTrue(control_socket.is_alive())
       control_socket.close()
@@ -89,6 +85,7 @@ class TestControlSocket(unittest.TestCase):
 
       self.assertRaises(stem.SocketClosed, control_socket.send, 'blarg')
 
+  @require_controller
   def test_send_disconnected(self):
     """
     Sends a message to a socket that has been disconnected by the other end.
@@ -99,9 +96,6 @@ class TestControlSocket(unittest.TestCase):
     call. With a file socket, however, we'll also fail when calling send().
     """
 
-    if test.runner.require_control(self):
-      return
-
     with test.runner.get_runner().get_tor_socket() as control_socket:
       control_socket.send('QUIT')
       self.assertEqual('closing connection', str(control_socket.recv()))
@@ -117,14 +111,12 @@ class TestControlSocket(unittest.TestCase):
         self.assertRaises(stem.SocketClosed, control_socket.send, 'blarg')
         self.assertFalse(control_socket.is_alive())
 
+  @require_controller
   def test_recv_closed(self):
     """
     Receives a message after we've closed the connection.
     """
 
-    if test.runner.require_control(self):
-      return
-
     with test.runner.get_runner().get_tor_socket() as control_socket:
       self.assertTrue(control_socket.is_alive())
       control_socket.close()
@@ -132,15 +124,13 @@ class TestControlSocket(unittest.TestCase):
 
       self.assertRaises(stem.SocketClosed, control_socket.recv)
 
+  @require_controller
   def test_recv_disconnected(self):
     """
     Receives a message from a socket that has been disconnected by the other
     end.
     """
 
-    if test.runner.require_control(self):
-      return
-
     with test.runner.get_runner().get_tor_socket() as control_socket:
       control_socket.send('QUIT')
       self.assertEqual('closing connection', str(control_socket.recv()))
@@ -153,14 +143,12 @@ class TestControlSocket(unittest.TestCase):
       self.assertRaises(stem.SocketClosed, control_socket.recv)
       self.assertFalse(control_socket.is_alive())
 
+  @require_controller
   def test_connect_repeatedly(self):
     """
     Checks that we can reconnect, use, and disconnect a socket repeatedly.
     """
 
-    if test.runner.require_control(self):
-      return
-
     with test.runner.get_runner().get_tor_socket(False) as control_socket:
       for _ in range(10):
         # this will raise if the PROTOCOLINFO query fails
diff --git a/test/integ/version.py b/test/integ/version.py
index 0cd15eb..8347b35 100644
--- a/test/integ/version.py
+++ b/test/integ/version.py
@@ -9,6 +9,8 @@ import stem.prereq
 import stem.version
 import test.runner
 
+from test.runner import require_controller
+
 
 class TestVersion(unittest.TestCase):
   def test_get_system_tor_version(self):
@@ -32,28 +34,24 @@ class TestVersion(unittest.TestCase):
     # try running against a command that doesn't exist
     self.assertRaises(IOError, stem.version.get_system_tor_version, 'blarg')
 
+  @require_controller
   def test_get_system_tor_version_value(self):
     """
     Checks that the get_system_tor_version() provides the same value as our
     test instance provides.
     """
 
-    if test.runner.require_control(self):
-      return
-
     runner = test.runner.get_runner()
     system_tor_version = stem.version.get_system_tor_version(runner.get_tor_command())
     self.assertEqual(runner.get_tor_version(), system_tor_version)
 
+  @require_controller
   def test_getinfo_version_parsing(self):
     """
     Issues a 'GETINFO version' query to our test instance and makes sure that
     we can parse it.
     """
 
-    if test.runner.require_control(self):
-      return
-
     control_socket = test.runner.get_runner().get_tor_socket()
     control_socket.send('GETINFO version')
     version_response = control_socket.recv()
diff --git a/test/runner.py b/test/runner.py
index 3d0e43b..2f9fbf1 100644
--- a/test/runner.py
+++ b/test/runner.py
@@ -12,7 +12,7 @@ about the tor test instance they're running against.
   TorInaccessable - Tor can't be queried for the information
 
   skip - skips the current test if we can
-  require_control - skips the test unless tor provides a controller endpoint
+  require_controller - skips the test unless tor provides a controller endpoint
   require_version - skips the test unless we meet a tor version requirement
   require_online - skips unless targets allow for online tests
   only_run_once - skip the test if it has been ran before
@@ -117,18 +117,18 @@ def skip(test_case, message):
     test_case.skipTest(message)
 
 
-def require_control(test_case):
+def require_controller(func):
   """
   Skips the test unless tor provides an endpoint for controllers to attach to.
-
-  :param unittest.TestCase test_case: test being ran
-
-  :returns: True if test should be skipped, False otherwise
   """
 
-  if not get_runner().is_accessible():
-    skip(test_case, '(no connection)')
-    return True
+  def wrapped(self, *args, **kwargs):
+    if get_runner().is_accessible():
+      return func(self, *args, **kwargs)
+    else:
+      skip(self, '(no connection)')
+
+  return wrapped
 
 
 def require_version(test_case, req_version):





More information about the tor-commits mailing list