[tor-commits] [ooni-probe/master] Fixed error with reporting results for tcp flags tests, and cleaned up test

isis at torproject.org isis at torproject.org
Tue Dec 18 05:53:46 UTC 2012


commit 9688f83110a7d8986e20d665b93f4acbf6ba40f6
Author: Isis Lovecruft <isis at torproject.org>
Date:   Sun Dec 9 23:13:42 2012 +0000

    Fixed error with reporting results for tcp flags tests, and cleaned up test
    abort code.
---
 nettests/bridge_reachability/tcpsyn.py |  156 ++++++++++++++------------------
 ooni/nettest.py                        |   77 +++++++---------
 ooni/reporter.py                       |    3 +-
 ooni/runner.py                         |   62 ++++++++-----
 4 files changed, 140 insertions(+), 158 deletions(-)

diff --git a/nettests/bridge_reachability/tcpsyn.py b/nettests/bridge_reachability/tcpsyn.py
index 6a4b8db..79b8e52 100644
--- a/nettests/bridge_reachability/tcpsyn.py
+++ b/nettests/bridge_reachability/tcpsyn.py
@@ -16,32 +16,29 @@
 import os
 import sys
 
-from ipaddr                 import IPAddress
 from twisted.python         import usage
 from twisted.python.failure import Failure
-from twisted.internet       import reactor, defer, address
+from twisted.internet       import reactor, defer
 from ooni                   import nettest, config
 from ooni.utils             import net, log
 from ooni.utils.otime       import timestamp
 
 try:
-    from scapy.all          import TCP, IP
+    from scapy.all          import TCP, IP, sr
     from ooni.utils         import txscapy
 except:
     log.msg("This test requires scapy, see www.secdev.org/projects/scapy")
 
 
-class TCPFlagOptions(usage.Options):
+class TCPFlagsOptions(usage.Options):
     """Options for TCPTest."""
     optParameters = [
         ['dst', 'd', None, 'Host IP to ping'],
         ['port', 'p', None, 'Host port'],
-        ['flags', 's', None, 'Comma separated flags to set [S|A|F]'],
+        ['flags', 's', 'S', 'Comma separated flags to set, eg. "SA"'],
         ['count', 'c', 3, 'Number of SYN packets to send', int],
         ['interface', 'i', None, 'Network interface to use'],
-        ['hexdump', 'x', False, 'Show hexdump of responses'],
-        ['pdf', 'y', False,
-         'Create pdf of visual representation of packet conversations']]
+        ['hexdump', 'x', False, 'Show hexdump of responses']]
 
 class TCPFlagsTest(nettest.NetTestCase):
     """
@@ -57,29 +54,28 @@ class TCPFlagsTest(nettest.NetTestCase):
     version      = '0.1.1'
     requiresRoot = True
 
-    usageOptions = TCPFlagOptions
+    usageOptions = TCPFlagsOptions
     inputFile    = ['file', 'f', None, 'File of list of IP:PORTs to ping']
 
+    destinations = {}
+
     def setUp(self, *a, **kw):
         """Configure commandline parameters for TCPSynTest."""
-        self.report = {}
-        self.packets = {'results': [], 'unanswered': []}
-
         if self.localOptions:
             for key, value in self.localOptions.items():
                 setattr(self, key, value)
         if not self.interface:
             try:
                 iface = net.getDefaultIface()
-            except net.IfaceError, ie:
-                log.warn("Could not find a working network interface!")
-                log.fail(ie)
-            else:
                 self.interface = iface
+            except net.IfaceError:
+                self.abortClass("Could not find a working network interface!")
+        if self.flags:
+            self.flags = self.flags.split(',')
         if config.advanced.debug:
             defer.setDebugging('on')
 
-    def addToDestinations(self, addr='0.0.0.0', port='443'):
+    def addToDestinations(self, addr=None, port='443'):
         """
         Validate and add an IP address and port to the dictionary of
         destinations to send to. If the host's IP is already in the
@@ -89,12 +85,16 @@ class TCPFlagsTest(nettest.NetTestCase):
         @param port: A string representing a port number.
         @returns: A 2-tuple containing the address and port.
         """
+        if addr is None:
+            return (None, None) # do we want to return SkipTest?
+
         dst, dport = net.checkIPandPort(addr, port)
-        if not dst in self.report.keys():
-            self.report[dst] = {'dst': dst, 'dport': [dport]}
+        if not dst in self.destinations.keys():
+            self.destinations[dst] = {'dst': dst,
+                                      'dport': [dport]}
         else:
             log.debug("Got additional port for destination.")
-            self.report[dst]['dport'].append(dport)
+            self.destinations[dst]['dport'].append(dport)
         return (dst, dport)
 
     def inputProcessor(self, input_file=None):
@@ -118,7 +118,7 @@ class TCPFlagsTest(nettest.NetTestCase):
                     raw_ip, raw_port = one.rsplit(':', 1)
                     yield self.addToDestinations(raw_ip, raw_port)
 
-    def tcp_flags(self, flags=None):
+    def test_tcp_flags(self):
         """
         Generate, send, and listen for responses to, a list of TCP/IP packets
         to an address and port pair taken from the current input, and a string
@@ -128,11 +128,14 @@ class TCPFlagsTest(nettest.NetTestCase):
             A string representing the TCP flags to be set, i.e. "SA" or "F".
             Defaults to "S".
         """
-        def build_packets(addr, port, flags=None, count=3):
+        def build_packets(addr, port):
             """Construct a list of packets to send out."""
             packets = []
-            for x in xrange(count):
-                packets.append( IP(dst=addr)/TCP(dport=port, flags=flags) )
+            for flag in self.flags:
+                log.debug("Generating packets with %s flags for %s:%d..."
+                          % (flag, addr, port))
+                for x in xrange(self.count):
+                    packets.append( IP(dst=addr)/TCP(dport=port, flags=flag) )
             return packets
 
         def process_packets(packet_list):
@@ -144,72 +147,45 @@ class TCPFlagsTest(nettest.NetTestCase):
             @param packet_list:
                 A :class:scapy.plist.PacketList
             """
+            log.msg("Processing received packets...")
             results, unanswered = packet_list
-            self.packets['results'].append([r for r in results])
-            self.packets['unanswered'].append([u for u in unanswered])
-    
+
             for (q, r) in results:
-                request_data = {'dst': q.dst,
-                                'dport': q.dport,
-                                'summary': q.summary(),
-                                'command': q.command(),
-                                'hexdump': None,
-                                'sent_time': q.time}
-                response_data = {'src': r['IP'].src,
-                                 'flags': r['IP'].flags,
-                                 'summary': r.summary(),
-                                 'command': r.command(),
-                                 'hexdump': None,
-                                 'recv_time': r.time,
-                                 'delay': r.time - q.time}
+                request = {'dst': q.dst,
+                           'dport': q.dport,
+                           'summary': q.summary(),
+                           'hexdump': None,
+                           'sent_time': q.time}
+                response = {'src': r['IP'].src,
+                            'flags': r['IP'].flags,
+                            'summary': r.summary(),
+                            'hexdump': None,
+                            'recv_time': r.time,
+                            'delay': r.time - q.time}
                 if self.hexdump:
-                    request_data.update('hexdump', q.hexdump())
-                    response_data.update('hexdump', r.hexdump())
-
-                for dest, data in self.report.items():
-                    if data['dst'] == response_data['src']:
-                        if not 'reachable' in data:
-                            if self.hexdump:
-                                log.msg("%s\n%s" % (q.hexdump(), r.hexdump()))
-                            else:
-                                log.msg(" Received response:\n%s ==> %s"
-                                        % (q.mysummary(), r.mysummary()))
-                            data.update( {'reachable': True,
-                                          'request': request_data,
-                                          'response': response_data} )
-                            self.report[response_data['src']['data'].update(data)
-
-            if unanswered is not None and len(unanswered) > 0:
-                log.msg("Waiting on responses from\n%s" %
-                        '\n'.join( [unans.summary() for unans in unanswered] ))
-            log.msg("Writing response packet information to report...")
- 
-        (addr, port) = self.input
-        packets = build_packets(addr, port, str(flags), self.count)
-        d = txscapy.sr(packets, iface=self.interface)
-        #d.addCallbacks(process_packets, log.exception)
-        #d.addCallbacks(process_unanswered, log.exception)
-        d.addCallback(process_packets)
-        d.addErrback(process_unanswered)
-
-        return d
-
-    @log.catch
-    def createPDF(self):
-        pdfname = self.name + '_' + timestamp()
-        self.packets['results'].pdfdump(pdfname)
-        log.msg("Visual packet conversation saved to %s.pdf" % pdfname)
-
-    def test_tcp_flags(self):
-        """Send packets with given TCP flags to an address:port pair."""
-        flag_list = self.flags.split(',')
-
-        dl = []
-        for flag in flag_list:
-            dl.append(self.tcp_flags(flag))
-        d = defer.DeferredList(dl)
-
-        if self.pdf:
-            d.addCallback(self.createPDF)
-
-        return d
+                    request['hexdump'] = q.hexdump()
+                    response['hexdump'] = r.hexdump()
+
+                for dest, data in self.destinations.items():
+                    if response['src'] == data['dst']:
+                        log.msg(" Received response from %s:\n%s ==> %s" % (
+                                response['src'], q.mysummary(), r.mysummary()))
+                        if self.hexdump:
+                            log.msg("%s\n%s" % (q.hexdump(), r.hexdump()))
+
+                        self.report['request'] = request
+                        self.report['response'] = response
+
+            if unanswered is not None:
+                unans = [un.summary() for un in unanswered]
+                log.msg(" Waiting on responses from:\n%s" % '\n'.join(unans))
+                self.report['unanswered'] = unans
+
+        try:
+            self.report = {}
+            (addr, port) = self.input
+            pkts = build_packets(addr, port)
+            d = process_packets(sr(pkts, iface=self.interface, timeout=5))
+            return d
+        except Exception, ex:
+            log.exception(ex)
diff --git a/ooni/nettest.py b/ooni/nettest.py
index bd6ef9b..29ced70 100644
--- a/ooni/nettest.py
+++ b/ooni/nettest.py
@@ -12,9 +12,11 @@ import sys
 import os
 import itertools
 import traceback
+import inspect
 
 from twisted.trial import unittest, itrial
 from twisted.trial import util as txtrutil
+from twisted.trial.test import skipping
 from twisted.internet import defer, utils
 from twisted.python import usage
 
@@ -171,59 +173,46 @@ class NetTestCase(object):
     def __repr__(self):
         return "<%s inputs=%s>" % (self.__class__, self.inputs)
 
-    def _abort(self, reason):
-        """
-
-        Abort running the current input. Raises
-        :class:`twisted.trial.test.skipping.SkipTest <SkipTest>` test_method,
-        or test_class. If called with only one argument, assume we're going to
-        ignore the current input. Otherwise, the name of the method or class
-        in relation to the test_instance, i.e. "self" should be given as value
-        for the keyword argument "obj".
+    def _abortMethod(self, reason, method=None):
+        if method is None:
+            test_method = self._testMethod
+        else:
+            test_method = getattr(self.__class__, method, False)
 
-        XXX call oreporter.allDone() from parent stack frame
-        """
-        reason = str(reason)
-        raise SkipTest("%s\n%s" % (str(reason), str(self.input)) )
-
-    def _abortMethod(self, reason, method):
-        if inspect.ismethod(method):
-            abort = getattr(self.__class__, method, False)
-            log.debug("Aborting remaining inputs for %s" % str(abort.func_name))
-            setattr(abort, 'skip', reason)
+        if inspect.ismethod(test_method):
+            method_name = test_method.im_func.func_name
+            setattr(test_method, 'skip', reason)
+            raise skipping.SkipTest("Aborting %s for reason: %s"
+                                    % (method_name, reason) )
         else:
-            log.debug("abortMethod(): could not find method %s" % str(method))
+            log.debug("_abortMethod(): could not find method %s" % test_method)
     
-    @log.catch
-    def _abortClass(self, reason, cls):
-        if not inspect.isclass(obj) or not isTestCase(obj):
-            log.debug("_abortClass() could not find class %s" % str(cls))
-            return
-        abort = getattr(obj, '__class__', self.__class__)
-        log.debug("Aborting %s test" % str(abort.name))
-        setattr(abort, 'skip', reason)
-
-    def abortCurrentInput(self, reason):
+    def abortInput(self, reason):
         """
         Abort the current input.
         
         @param reason: A string explaining why this test is being skipped.
+        @raises: A :class:`twisted.trial.test.skipping.SkipTest <SkipTest>` 
         """
-        return self._abort(reason)
-
-    def abortInput(self, reason):
-        return self._abort(reason)
+        raise skipping.SkipTest(" Reason: %s\nCurrent input: %s"
+                                % (reason, self.input))
 
+    def abortMethod(self, reason, test_method=None):
+        """
+        Abort all remaining inputs for the current test method.
 
-# This needs to be here so that NetTestCase.abort() can call it, since we
-# cannot import runner because runner imports NetTestCase.
-def isTestCase(obj):
-    """
-    Return True if obj is a subclass of NetTestCase, false if otherwise.
-    """
-    try:
-        return issubclass(obj, NetTestCase)
-    except TypeError:
-        return False
+        @param reason: A string explaining why the current test_method is
+                       being skipped.
+        @param test_method: (optional) The test_method to skip, defaults to
+                            the currently running test_method.
+        """
+        return self._abortMethod(reason, test_method)
 
+    def abortClass(self, reason='unspecified'):
+        """
+        Abort the entire NetTestCase class.
 
+        @param reason: A string explaining why the class is being skipped.
+        """
+        log.msg("Aborting %s: %s" % (self.__class__.name, reason))
+        setattr(self.__class__, 'skip', reason)
diff --git a/ooni/reporter.py b/ooni/reporter.py
index 6fdc142..133b98d 100644
--- a/ooni/reporter.py
+++ b/ooni/reporter.py
@@ -218,7 +218,8 @@ class OONIBReporter(OReporter):
             response = yield self.agent.request("PUT", url, 
                                 bodyProducer=bodyProducer)
         except:
-            # XXX we must trap this in the runner and make sure to report the data later.
+            # XXX we must trap this in the runner and make sure to report the
+            # data later.
             raise OONIBReportUpdateFailed
 
         #parsed_response = json.loads(backend_response)
diff --git a/ooni/runner.py b/ooni/runner.py
index 4214360..5969dd5 100644
--- a/ooni/runner.py
+++ b/ooni/runner.py
@@ -16,7 +16,7 @@ import traceback
 import itertools
 
 from twisted.python import reflect, usage, failure
-from twisted.internet import defer
+from twisted.internet import defer, reactor
 from twisted.trial.runner import filenameToModule
 from twisted.trial import reporter as txreporter
 from twisted.trial import util as txtrutil
@@ -28,6 +28,15 @@ from ooni.inputunit import InputUnitFactory
 from ooni import reporter, nettest
 from ooni.utils import log, checkForRoot, PermissionsError
 
+def isTestCase(obj):
+    """
+    Return True if obj is a subclass of NetTestCase, false if otherwise.
+    """
+    try:
+        return issubclass(obj, nettest.NetTestCase)
+    except TypeError:
+        return False
+
 def processTest(obj, cmd_line_options):
     """
     Process the parameters and :class:`twisted.python.usage.Options` of a
@@ -112,7 +121,7 @@ def findTestClassesFromConfig(cmd_line_options):
 
     module = filenameToModule(filename)
     for name, val in inspect.getmembers(module):
-        if nettest.isTestCase(val):
+        if isTestCase(val):
             classes.append(processTest(val, cmd_line_options))
     return classes
 
@@ -185,13 +194,19 @@ def runTestWithInput(test_class, test_method, test_input, oreporter):
             d.errback(timeout_fail)
         except defer.AlreadyCalledError:
             # if the deferred has already been called but the *back chain is
-            # still unfinished, crash the reactor and report the timeout
+            # still unfinished, safely crash the reactor and report the timeout
             reactor.crash()
             test_instance._timedOut = True    # see test_instance._wait
             test_instance._test_result.addExpectedFailure(test_instance, fail)
     test_timeout = txtrutils.suppressWarnings(
         test_timeout, txtrutil.suppress(category=DeprecationWarning))
 
+    def test_skip_class(reason):
+        try:
+            d.errback(failure.Failure(SkipTest("%s" % reason)))
+        except defer.AlreadyCalledError:
+            pass # XXX not sure what to do here...
+
     def test_done(result, test_instance, test_name):
         log.debug("Concluded %s with inputs %s"
                   % (test_name, test_instance.input))
@@ -199,10 +214,7 @@ def runTestWithInput(test_class, test_method, test_input, oreporter):
 
     def test_error(error, test_instance, test_name):
         if isinstance(error, SkipTest):
-            if len(error.args) > 0:
-                skip_what = error.args[1]
-                # XXX we'll need to handle methods and classes
-            log.info("%s" % error.message)
+            log.warn("%s" % error.message)
         else:
             log.exception(error)
 
@@ -219,26 +231,30 @@ def runTestWithInput(test_class, test_method, test_input, oreporter):
     test_instance._setUp()
     test_instance.setUp()
 
-    test_skip = txtrutil.acquireAttribute(
-        test_instance._parents, 'skip', None)
-    if test_skip is not None:
-        # XXX we'll need to do something more than warn
+    test_skip = txtrutil.acquireAttribute(test_instance._parents, 'skip', None)
+    if test_skip:
         log.warn("%s marked these tests to be skipped: %s"
-                  % (test_instance.name, test_skip))
+                 % (test_instance.name, test_skip))
     skip_list = [test_skip]
 
-    if not test_method in skip_list:
-        test = getattr(test_instance, test_method)
-        d = defer.maybeDeferred(test)
+    test = getattr(test_instance, test_method)
+    test_instance._testMethod = test
+
+    d = defer.maybeDeferred(test)
+
+    # register the timer with the reactor
+    call_timeout = reactor.callLater(test_instance.timeout, test_timeout, d)
+    d.addBoth(lambda x: call_timeout.active() and call_timeout.cancel() or x)
+
+    # check if the class has been aborted
+    if hasattr(test_instance.__class__, 'skip'):
+        reason = getattr(test_instance.__class__, 'skip')
+        call_skip = reactor.callLater(0, test_skip_class, reason)
+        d.addBoth(lambda x: call_skip.active() and call_skip.cancel() or x)
+
+    d.addCallback(test_done, test_instance, test_method)
+    d.addErrback(test_error, test_instance, test_method)
 
-        # register the timer with the reactor
-        call = reactor.callLater(test_instance.timeout, test_timeout, d)
-        d.addBoth(lambda x: call.active() and call.cancel() or x)
-    
-        d.addCallback(test_done, test_instance, test_method)
-        d.addErrback(test_error, test_instance, test_method)
-    else:
-        d = defer.Deferred()
     return d
 
 def runTestWithInputUnit(test_class, test_method, input_unit, oreporter):





More information about the tor-commits mailing list