commit 8bead060681d86cdfa43f35b1b57567e51816aa6 Author: Arturo Filastò art@fuffa.org Date: Wed Feb 27 17:05:39 2013 +0100
Refactor NetTestLoader
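* Split loadNetTest into loadNetTestFile and loadNetTestString, selected by the 'test_file' or 'test_string' option * Make it clear that loadNetTestString executes its input and is therefore extremely dangerous to call with untrusted data * Fix a bug spotted by unit tests --- ooni/nettest.py | 86 +++++++++++++++++++++++++++++-------------------------- 1 file changed, 46 insertions(+), 40 deletions(-)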
diff --git a/ooni/nettest.py b/ooni/nettest.py index 6273d34..6323989 100644 --- a/ooni/nettest.py +++ b/ooni/nettest.py @@ -23,7 +23,10 @@ class NetTestLoader(object):
def __init__(self, options): self.options = options - self.testCases = self.loadNetTest(options['test']) + if 'test_file' in options: + self.loadNetTestFile(options['test_file']) + elif 'test_string' in options: + self.loadNetTestString(options['test_string'])
@property def testDetails(self): @@ -109,7 +112,45 @@ class NetTestLoader(object): assert usage_options == test_class.usageOptions return usage_options
- def loadNetTest(self, net_test_file): + def loadNetTestString(self, net_test_string): + """ + Load NetTest from a string. + WARNING input to this function *MUST* be sanitized and *NEVER* be + untrusted. + Failure to do so will result in code exec. + + net_test_string: + + a string that contains the net test to be run. + """ + net_test_file_object = StringIO(net_test_string) + + ns = {} + test_cases = [] + exec net_test_file_object.read() in ns + for item in ns.itervalues(): + test_cases.extend(self._get_test_methods(item)) + + if not test_cases: + raise NoTestCasesFound + + self.setupTestCases(test_cases) + + def loadNetTestFile(self, net_test_file): + """ + Load NetTest from a file. + """ + test_cases = [] + module = filenameToModule(net_test_file) + for __, item in getmembers(module): + test_cases.extend(self._get_test_methods(item)) + + if not test_cases: + raise NoTestCasesFound + + self.setupTestCases(test_cases) + + def setupTestCases(self, test_cases): """ Creates all the necessary test_cases (a list of tuples containing the NetTestCase (test_class, test_method)) @@ -130,25 +171,10 @@ class NetTestLoader(object): is either a file path or a file like object that will be used to generate the test_cases. """ - test_cases = None - try: - if os.path.isfile(net_test_file): - test_cases = self._loadNetTestFile(net_test_file) - else: - net_test_file = StringIO(net_test_file) - raise TypeError("not a file path") - - except TypeError: - if hasattr(net_test_file, 'read'): - test_cases = self._loadNetTestFromFileObject(net_test_file) - - if not test_cases: - raise NoTestCasesFound - test_class, _ = test_cases[0] self.testVersion = test_class.version self.testName = test_class.name.lower().replace(' ','_') - return test_cases + self.testCases = test_cases
def checkOptions(self): """ @@ -160,7 +186,8 @@ class NetTestLoader(object):
for klass in test_classes: options = self.usageOptions() - options.parseOptions(self.options['subargs']) + options.parseOptions(self.options) + if options: klass.localOptions = options
@@ -175,27 +202,6 @@ class NetTestLoader(object): inputs = [None] klass.inputs = inputs
- def _loadNetTestFromFileObject(self, net_test_string): - """ - Load NetTest from a string - """ - ns = {} - test_cases = [] - exec net_test_string.read() in ns - for item in ns.itervalues(): - test_cases.extend(self._get_test_methods(item)) - return test_cases - - def _loadNetTestFile(self, net_test_file): - """ - Load NetTest from a file - """ - test_cases = [] - module = filenameToModule(net_test_file) - for __, item in getmembers(module): - test_cases.extend(self._get_test_methods(item)) - return test_cases - def _get_test_methods(self, item): """ Look for test_ methods in subclasses of NetTestCase