commit 1029140ad660c12f4395df40baff641b534a062a Author: Isis Lovecruft <isis@torproject.org> Date: Thu Dec 6 04:29:55 2012 +0000
Refactoring tcp flag test --- nettests/bridge_reachability/tcpsyn.py | 193 +++++++++++++++----------------- ooni/nettest.py | 29 ----- ooni/oonicli.py | 6 +- ooni/reporter.py | 3 +- ooni/runner.py | 82 ++++++-------- 5 files changed, 127 insertions(+), 186 deletions(-)
diff --git a/nettests/bridge_reachability/tcpsyn.py b/nettests/bridge_reachability/tcpsyn.py index 548fc0e..6a4b8db 100644 --- a/nettests/bridge_reachability/tcpsyn.py +++ b/nettests/bridge_reachability/tcpsyn.py @@ -2,10 +2,10 @@ # -*- coding: utf-8 -*- # # +-----------+ -# | tcpsyn.py | +# | tcpflags.py | # +-----------+ -# Send a TCP SYN packet to a test server to check that -# it is reachable. +# Send packets with various TCP flags set to a test server +# to check that it is reachable. # # @authors: Isis Lovecruft, isis@torproject.org # @version: 0.0.1-pre-alpha @@ -36,15 +36,14 @@ class TCPFlagOptions(usage.Options): optParameters = [ ['dst', 'd', None, 'Host IP to ping'], ['port', 'p', None, 'Host port'], + ['flags', 's', None, 'Comma separated flags to set [S|A|F]'], ['count', 'c', 3, 'Number of SYN packets to send', int], ['interface', 'i', None, 'Network interface to use'], ['hexdump', 'x', False, 'Show hexdump of responses'], ['pdf', 'y', False, - 'Create pdf of visual representation of packet conversations'], - ['cerealize', 'z', False, - 'Cerealize scapy objects for further scripting']] + 'Create pdf of visual representation of packet conversations']]
-class TCPFlagTest(nettest.NetTestCase): +class TCPFlagsTest(nettest.NetTestCase): """ Sends only a TCP SYN packet to a host IP:PORT, and waits for either a SYN/ACK, a RST, or an ICMP error. @@ -52,21 +51,19 @@ class TCPFlagTest(nettest.NetTestCase): TCPSynTest can take an input file containing one IP:Port pair per line, or the commandline switches --dst <IP> and --port <PORT> can be used. """ - name = 'TCP Flag' + name = 'TCP Flags' author = 'Isis Lovecruft isis@torproject.org' description = 'A TCP SYN/ACK/FIN test to see if a host is reachable.' - version = '0.0.1' + version = '0.1.1' requiresRoot = True
usageOptions = TCPFlagOptions inputFile = ['file', 'f', None, 'File of list of IP:PORTs to ping']
- #destinations = {} - - @log.catch def setUp(self, *a, **kw): """Configure commandline parameters for TCPSynTest.""" self.report = {} + self.packets = {'results': [], 'unanswered': []}
if self.localOptions: for key, value in self.localOptions.items(): @@ -78,7 +75,6 @@ class TCPFlagTest(nettest.NetTestCase): log.warn("Could not find a working network interface!") log.fail(ie) else: - log.msg("Using system default interface: %s" % iface) self.interface = iface if config.advanced.debug: defer.setDebugging('on') @@ -94,13 +90,10 @@ class TCPFlagTest(nettest.NetTestCase): @returns: A 2-tuple containing the address and port. """ dst, dport = net.checkIPandPort(addr, port) - #if not dst in self.destinations.keys(): if not dst in self.report.keys(): - #self.destinations[dst] = {'dst': dst, 'dport': [dport]} self.report[dst] = {'dst': dst, 'dport': [dport]} else: log.debug("Got additional port for destination.") - #self.destinations[dst]['dport'].append(dport) self.report[dst]['dport'].append(dport) return (dst, dport)
@@ -112,87 +105,20 @@ class TCPFlagTest(nettest.NetTestCase): """ if self.localOptions['dst'] is not None \ and self.localOptions['port'] is not None: - log.debug("processing commandline destination input") + log.debug("Processing commandline destination") yield self.addToDestinations(self.localOptions['dst'], self.localOptions['port']) if input_file and os.path.isfile(input_file): - log.debug("processing input file %s" % input_file) + log.debug("Processing input file %s" % input_file) with open(input_file) as f: for line in f.readlines(): if line.startswith('#'): continue one = line.strip() - raw_ip, raw_port = one.rsplit(':', 1) ## XXX not ipv6 safe! + raw_ip, raw_port = one.rsplit(':', 1) yield self.addToDestinations(raw_ip, raw_port)
- @log.catch - def createPDF(self, results): - pdfname = self.name + '_' + timestamp() - results.pdfdump(pdfname) - log.msg("Visual packet conversation saved to %s.pdf" % pdfname) - - @staticmethod - def build_packets(addr, port, flags=None, count=3): - """Construct a list of packets to send out.""" - packets = [] - for x in xrange(count): - packets.append( IP(dst=addr)/TCP(dport=port, flags=flags) ) - return packets - - @staticmethod - def process_packets(packet_list): - """ - If the source address of packet in :param:packet_list matches one of our input - destinations, then extract some of the information from it to the test report. - - @param packet_list: - A :class:scapy.plist.PacketList - """ - results, unanswered = packet_list - - if self.pdf: - self.createPDF(results) - - for (q, r) in results: - request_data = {'dst': q.dst, - 'dport': q.dport, - 'summary': q.summary(), - 'command': q.command(), - 'sent_time': q.time} - response_data = {'src': r['IP'].src, - 'flags': r['IP'].flags, - 'summary': r.summary(), - 'command': r.command(), - 'recv_time': r.time, - 'delay': r.time - q.time} - if self.hexdump: - request_data.update('hexdump', q.hexdump()) - response_data.update('hexdump', r.hexdump()) - for dest, data in self.destinations.items(): - if data['dst'] == response_data['src']: - if not 'reachable' in data: - if self.hexdump: - log.msg("%s\n%s" % (q.hexdump(), r.hexdump())) - else: - log.msg(" Received response:\n%s ==> %s" - % (q.mysummary(), r.mysummary())) - data.update( {'reachable': True, - 'request': request_data, - 'response': response_data} ) - return unanswered - - @staticmethod - def process_unanswered(unanswered): - """Callback function to process unanswered packets.""" - if unanswered is not None and len(unanswered) > 0: - log.msg("Waiting on responses from\n%s" % - '\n'.join( [unans.summary() for unans in unanswered] )) - log.msg("Writing response packet information to report...") - self.report = (self.destinations) - return self.destinations 
- - @log.catch - def tcp_flags(self, flags="S"): + def tcp_flags(self, flags=None): """ Generate, send, and listen for responses to, a list of TCP/IP packets to an address and port pair taken from the current input, and a string @@ -202,25 +128,88 @@ class TCPFlagTest(nettest.NetTestCase): A string representing the TCP flags to be set, i.e. "SA" or "F". Defaults to "S". """ + def build_packets(addr, port, flags=None, count=3): + """Construct a list of packets to send out.""" + packets = [] + for x in xrange(count): + packets.append( IP(dst=addr)/TCP(dport=port, flags=flags) ) + return packets + + def process_packets(packet_list): + """ + If the source address of packet in :param:packet_list matches one of + our input destinations, then extract some of the information from it + to the test report. + + @param packet_list: + A :class:scapy.plist.PacketList + """ + results, unanswered = packet_list + self.packets['results'].append([r for r in results]) + self.packets['unanswered'].append([u for u in unanswered]) + + for (q, r) in results: + request_data = {'dst': q.dst, + 'dport': q.dport, + 'summary': q.summary(), + 'command': q.command(), + 'hexdump': None, + 'sent_time': q.time} + response_data = {'src': r['IP'].src, + 'flags': r['IP'].flags, + 'summary': r.summary(), + 'command': r.command(), + 'hexdump': None, + 'recv_time': r.time, + 'delay': r.time - q.time} + if self.hexdump: + request_data.update('hexdump', q.hexdump()) + response_data.update('hexdump', r.hexdump()) + + for dest, data in self.report.items(): + if data['dst'] == response_data['src']: + if not 'reachable' in data: + if self.hexdump: + log.msg("%s\n%s" % (q.hexdump(), r.hexdump())) + else: + log.msg(" Received response:\n%s ==> %s" + % (q.mysummary(), r.mysummary())) + data.update( {'reachable': True, + 'request': request_data, + 'response': response_data} ) + self.report[response_data['src']['data'].update(data) + + if unanswered is not None and len(unanswered) > 0: + log.msg("Waiting on 
responses from\n%s" % + '\n'.join( [unans.summary() for unans in unanswered] )) + log.msg("Writing response packet information to report...") + (addr, port) = self.input - packets = self.build_packets(addr, port, str(flags), self.count) + packets = build_packets(addr, port, str(flags), self.count) d = txscapy.sr(packets, iface=self.interface) - d.addCallbacks(self.process_packets, log.exception) - d.addCallbacks(self.process_unanswered, log.exception) + #d.addCallbacks(process_packets, log.exception) + #d.addCallbacks(process_unanswered, log.exception) + d.addCallback(process_packets) + d.addErrback(process_unanswered) + return d
- def test_tcp_fin(self): - """Send a list of FIN packets to an address and port pair from inputs.""" - return self.tcp_flags("F") + @log.catch + def createPDF(self): + pdfname = self.name + '_' + timestamp() + self.packets['results'].pdfdump(pdfname) + log.msg("Visual packet conversation saved to %s.pdf" % pdfname) + + def test_tcp_flags(self): + """Send packets with given TCP flags to an address:port pair.""" + flag_list = self.flags.split(',')
- def test_tcp_syn(self): - """Send a list of SYN packets to an address and port pair from inputs.""" - return self.tcp_flags("S") + dl = [] + for flag in flag_list: + dl.append(self.tcp_flags(flag)) + d = defer.DeferredList(dl)
- def test_tcp_synack(self): - """Send a list of SYN/ACK packets to an address and port pair from inputs.""" - return self.tcp_flags("SA") + if self.pdf: + d.addCallback(self.createPDF)
- def test_tcp_ack(self): - """Send a list of SYN packets to an address and port pair from inputs.""" - return self.tcp_flags("A") + return d diff --git a/ooni/nettest.py b/ooni/nettest.py index 1d1477d..bd6ef9b 100644 --- a/ooni/nettest.py +++ b/ooni/nettest.py @@ -171,35 +171,6 @@ class NetTestCase(object): def __repr__(self): return "<%s inputs=%s>" % (self.__class__, self.inputs)
- def _getSkip(self): - return txtrutil.acquireAttribute(self._parents, 'skip', None) - - def _getSkipReason(self, method, skip): - return super(TestCase, self)._getSkipReason(self, method, skip) - - def _getTimeout(self): - """ - Returns the timeout value set on this test. Check on the instance - first, the the class, then the module, then package. As soon as it - finds something with a timeout attribute, returns that. Returns - twisted.trial.util.DEFAULT_TIMEOUT_DURATION if it cannot find - anything. See TestCase docstring for more details. - """ - try: - testMethod = getattr(self, methodName) - except: - testMethod = self.setUp - self._parents = [testMethod, self] - self._parents.extend(txtrutil.getPythonContainers(testMethod)) - timeout = txtrutil.acquireAttribute(self._parents, 'timeout', - txtrutil.DEFAULT_TIMEOUT_DURATION) - try: - return float(timeout) - except (ValueError, TypeError): - warnings.warn("'timeout' attribute needs to be a number.", - category=DeprecationWarning) - return txtrutil.DEFAULT_TIMEOUT_DURATION - def _abort(self, reason): """
diff --git a/ooni/oonicli.py b/ooni/oonicli.py index 3362d06..c64e445 100644 --- a/ooni/oonicli.py +++ b/ooni/oonicli.py @@ -81,7 +81,7 @@ class Options(usage.Options, app.ReactorSelectionMixin):
def testsEnded(*arg, **kw): """You can place here all the post shutdown tasks.""" - log.debug("testsEnded: Finished running all tests") + log.debug("Finished running all tests")
def run(): """Call me to begin testing from a file.""" @@ -133,7 +133,3 @@ def run(): tests_d = runner.runTestCases(test_cases, options, cmd_line_options, yamloo_filename) tests_d.addBoth(testsEnded) - - ## it appears that tests run without this? - #reactor.run() - diff --git a/ooni/reporter.py b/ooni/reporter.py index 193d056..6fdc142 100644 --- a/ooni/reporter.py +++ b/ooni/reporter.py @@ -111,8 +111,7 @@ class OReporter(object): pass
def testDone(self, test, test_name): - log.debug("Finished running %s" % test_name) - log.debug("Writing report") + log.debug("Calling reporter to record results") test_report = dict(test.report)
if isinstance(test.input, packet.Packet): diff --git a/ooni/runner.py b/ooni/runner.py index 2b41d59..4214360 100644 --- a/ooni/runner.py +++ b/ooni/runner.py @@ -18,8 +18,8 @@ import itertools from twisted.python import reflect, usage, failure from twisted.internet import defer from twisted.trial.runner import filenameToModule -from twisted.trial import util as txtrutil from twisted.trial import reporter as txreporter +from twisted.trial import util as txtrutil from twisted.trial.unittest import utils as txtrutils from twisted.trial.unittest import SkipTest from twisted.internet import reactor, threads @@ -144,37 +144,31 @@ def loadTestsAndOptions(classes, cmd_line_options):
return test_cases, options
-def abortTestRun(test_class, warn_err_fail, test_input, oreporter): - """ - Abort the entire test, and record the error, failure, or warning for why - it could not be completed. +def getTimeout(test_instance, test_method): """ - log.warn("Aborting remaining tests for %s" % test_name) + Returns the timeout value set on this test. Check on the instance first, + the the class, then the module, then package. As soon as it finds + something with a timeout attribute, returns that. Returns + twisted.trial.util.DEFAULT_TIMEOUT_DURATION if it cannot find anything.
-def abortTestWasCalled(abort_reason, abort_what, test_class, test_instance, - test_method, test_input, oreporter): + See twisted.trial.unittest.TestCase docstring for more details. """ - XXX - """ - if not abort_what in ['class', 'method', 'input']: - log.warn("__test_abort__() must specify 'class', 'method', or 'input'") - abort_what = 'input' - - if not isinstance(abort_reason, Exception): - abort_reason = Exception(str(abort_reason)) - if abort_what == 'input': - log.msg("%s test requested to abort for input: %s" - % (test_instance.name, test_input)) - d = defer.maybeDeferred(lambda x: object) - - if hasattr(test_instance, "abort_all"): - log.msg("%s test requested to abort all remaining inputs" - % test_instance.name) - #else: - # d = defer.Deferred() - # d.cancel() - # d = abortTestRun(test_class, reason, test_input, oreporter) - + try: + testMethod = getattr(test_instance, test_method) + except: + log.debug("_getTimeout couldn't find self.methodName!") + return txtrutil.DEFAULT_TIMEOUT_DURATION + else: + test_instance._parents = [testMethod, test_instance] + test_instance._parents.extend(txtrutil.getPythonContainers(testMethod)) + timeout = txtrutil.acquireAttribute(test_instance._parents, 'timeout', + txtrutil.DEFAULT_TIMEOUT_DURATION) + try: + return float(timeout) + except (ValueError, TypeError): + warnings.warn("'timeout' attribute needs to be a number.", + category=DeprecationWarning) + return txtrutil.DEFAULT_TIMEOUT_DURATION
def runTestWithInput(test_class, test_method, test_input, oreporter): """ @@ -205,6 +199,9 @@ def runTestWithInput(test_class, test_method, test_input, oreporter):
def test_error(error, test_instance, test_name): if isinstance(error, SkipTest): + if len(error.args) > 0: + skip_what = error.args[1] + # XXX we'll need to handle methods and classes log.info("%s" % error.message) else: log.exception(error) @@ -212,32 +209,23 @@ def runTestWithInput(test_class, test_method, test_input, oreporter): test_instance = test_class() test_instance.input = test_input test_instance.report = {} - # XXX TODO - # the twisted.trial.reporter.TestResult is expected by test_timeout(), - # but we should eventually replace it with a stub class + # XXX TODO the twisted.trial.reporter.TestResult is expected by + # test_timeout(), but we should eventually replace it with a stub class test_instance._test_result = txreporter.TestResult() # use this to keep track of the test runtime test_instance._start_time = time.time() - test_instance.timeout = test_instance._getTimeout() + test_instance.timeout = getTimeout(test_instance, test_method) # call setups on the test test_instance._setUp() test_instance.setUp()
- # check that we haven't inherited a skip - test_ignored = txtrutil.acquireAttribute( + test_skip = txtrutil.acquireAttribute( test_instance._parents, 'skip', None) - if test_ignored is not None: + if test_skip is not None: # XXX we'll need to do something more than warn - log.warn("test_skip is %s" % test_ignored) - - # now check our instance for test_methods set to be skipped: - skip_list = test_instance._getSkip() - if skip_list is not None: - log.debug("%s marked these tests to be skipped: %s" - % (test_instance.name, skip_list)) - else: - log.debug("No tests marked as skip") - skip_list = [skip_list] + log.warn("%s marked these tests to be skipped: %s" + % (test_instance.name, test_skip)) + skip_list = [test_skip]
if not test_method in skip_list: test = getattr(test_instance, test_method) @@ -249,10 +237,8 @@ def runTestWithInput(test_class, test_method, test_input, oreporter):
d.addCallback(test_done, test_instance, test_method) d.addErrback(test_error, test_instance, test_method) - log.debug("returning %s input" % test_method) else: d = defer.Deferred() - return d
def runTestWithInputUnit(test_class, test_method, input_unit, oreporter):