commit 1e47a583f7b7ad1cd1f6b24289c1a6f1cd0a9a73
Author: Nick Mathewson <nickm(a)torproject.org>
Date: Thu Dec 29 20:53:49 2016 -0500
Neat feature to help seed fuzzers
To run it, write a trunnel description for what you want to fuzz,
and run python -m trunnel.SeedFuzzer foo.trunnel . The
subdirectories of "fuzzing-inputs" will fill up with strings that
conform to that description, suitable for consumption by afl or
libfuzzer.
---
lib/trunnel/CodeGen.py | 1 +
lib/trunnel/Grammar.py | 2 +
lib/trunnel/SeedFuzzer.py | 549 ++++++++++++++++++++++++++++++++++++++++++++++
3 files changed, 552 insertions(+)
diff --git a/lib/trunnel/CodeGen.py b/lib/trunnel/CodeGen.py
index ba33287..35073f9 100644
--- a/lib/trunnel/CodeGen.py
+++ b/lib/trunnel/CodeGen.py
@@ -266,6 +266,7 @@ class Checker(ASTVisitor):
self.structUses[sd.name] = set(sd.contextList)
self.structUsesContexts[sd.name] = set(sd.contextList)
sd.visitChildren(self)
+ sd.constrainedIntFields = set(self.structIntFieldUsage.keys())
self.structFieldNames = None
self.structIntFieldNames = None
self.structIntFieldUsage = None
diff --git a/lib/trunnel/Grammar.py b/lib/trunnel/Grammar.py
index 25dc53a..3608122 100644
--- a/lib/trunnel/Grammar.py
+++ b/lib/trunnel/Grammar.py
@@ -247,6 +247,8 @@ class StructDecl(AST):
# for every field that is used as the length of a SMLenConstrained
# has_leftover_field -- boolean: true iff this struct contains
# an SMLenConstrained.
+ # constrainedIntFields -- set: names of integer fields that
+ # are referenced elsewhere in the structure.
def __init__(self, name, members, contextList=(), isContext=False):
self.name = name
diff --git a/lib/trunnel/SeedFuzzer.py b/lib/trunnel/SeedFuzzer.py
new file mode 100644
index 0000000..369eb88
--- /dev/null
+++ b/lib/trunnel/SeedFuzzer.py
@@ -0,0 +1,549 @@
+#!/usr/bin/python
+"""Use a trunnel input file to generate examples of that file for
+ fuzzing.
+
+ Here's the strategy:
+
+ First, sort all the types topologically so that we consider
+ every type before any type that depends on it.
+
+ Then we iterate over each type to make examples of it. We do
+ a recursive descent on the syntax tree, yielding a sequence of
+ (entry, constraint) tuples. The "entry" item is a list whose
+ members are bytestrings or NamedInt objects. The "constraint" item
+ is an instance of Constraint that describes which NamedInt entries
+ must have certain values.
+
+ As we handle each (entry,constraint) tuple, we replace each
+ NamedInt value in the entry with its constrained value, then merge
+ the parts of the entry together. If we haven't seen it before for
+ this type, we save it to disk.
+
+ To avoid combinatorial explosions, we limit the fan-out for each
+ step, and choose different combinatoric strategies depending
+ on the number of items to be considered at once.
+"""
+
+
+import trunnel.CodeGen
+import trunnel.Grammar
+
+import os
+import hashlib
+import random
+
+
class Constraints(object):
    """A Constraints object represents a set of constraints on named integer
    values. It may also represent a 'failed constraint', which is
    impossible to satisfy.

    This base class does not implement add() or merge(); the concrete
    subclasses (NoConstraints, FailedConstraint, SomeConstraints) do.
    """
    def __init__(self):
        pass

    def isFailed(self):
        """Return true iff this constraint is unsatisfiable."""
        return False

    def add(self, k, v):
        """Return a new constraint made by adding the constraint "k=v" to this
        constraint.

        Must be overridden by subclasses.
        """
        # 'NotImplemented' (no "Error") is the rich-comparison sentinel and
        # is not callable; raising NotImplemented() would just produce a
        # confusing TypeError.
        raise NotImplementedError()

    def merge(self, other):
        """Return a (maybe) new constraint made by adding all the constraints
        in 'other' to this constraint.

        Must be overridden by subclasses.
        """
        raise NotImplementedError()

    def apply(self, item):
        """Given an object that might be a NamedInt or a byte sequence, return
        a byte sequence obtained by applying this constraint to
        that item.
        """
        if isinstance(item, NamedInt):
            return item.apply(self)
        return item

    def getConstraint(self, name):
        """Return the (integer) value that the integer field 'name'
        must have, or None if there is no such constraint.
        """
        return None
+
+
class NoConstraints(Constraints):
    """The empty set of constraints, which every assignment satisfies.

    A single shared instance (NIL, below) is enough; don't construct more.
    """
    def __init__(self):
        super(NoConstraints, self).__init__()

    def add(self, k, v):
        # Adding one requirement to the empty set gives exactly that one.
        return SomeConstraints({k: v})

    def merge(self, other):
        # The empty set is the identity element for merge.
        return other


NIL = NoConstraints()
+
+
class FailedConstraint(Constraints):
    """Represents an unsatisfiable constraint, probably created by setting
    the same integer to two incompatible values.

    A single shared instance (FAILED, below) is enough; don't construct more.
    """
    def __init__(self):
        Constraints.__init__(self)

    def isFailed(self):
        return True

    def add(self, k, v):
        # Failed can't become any more failed
        return self

    def merge(self, other):
        # Failed can't become any more failed
        return self

    def apply(self, item):
        # Callers must discard failed constraints before encoding anything.
        # Raise explicitly: a bare 'assert False' is removed under
        # 'python -O', which would make this silently return None.
        raise AssertionError("apply() called on a failed constraint")


FAILED = FailedConstraint()
+
+
class SomeConstraints(Constraints):
    """A satisfiable, non-empty set of constraints.

    Stored as a dict mapping integer field names to the values they must
    take. Instances are treated as immutable: operations that would change
    the set return a new object instead.
    """
    def __init__(self, d):
        # Takes ownership of 'd'; the caller must not modify it afterwards.
        Constraints.__init__(self)
        self._d = d

    def add(self, k, v):
        if k not in self._d:
            # A field we haven't seen: extend a copy of our mapping.
            extended = dict(self._d)
            extended[k] = v
            return SomeConstraints(extended)
        # Field already constrained: agreement is a no-op, disagreement
        # makes the whole set unsatisfiable.
        return self if self._d[k] == v else FAILED

    def merge(self, other):
        if not isinstance(other, SomeConstraints):
            # 'other' is NIL or FAILED, which know how to handle a merge.
            return other.merge(self)
        # Check for conflicts by walking whichever mapping is smaller.
        if len(self._d) <= len(other._d):
            small, large = self._d, other._d
        else:
            small, large = other._d, self._d
        for name, value in small.items():
            if name in large and large[name] != value:
                return FAILED
        combined = dict(large)
        combined.update(small)
        return SomeConstraints(combined)

    def getConstraint(self, name):
        return self._d.get(name)
+
+
def constrain(k, v):
    """Return a Constraints object requiring field 'k' to equal 'v'.

    If 'k' is None there is no field to constrain, and NIL is returned.
    """
    if k is not None:
        return SomeConstraints({k: v})
    return NIL
+
+
class NamedInt(object):
    """Represents an integer object with a name whose value (maybe)
    depends on some other part of the structure.

    name  -- field name used to look the value up in a Constraints object.
    width -- encoded size in bytes.
    val   -- value to use when no constraint names this field (or None).
    """
    def __init__(self, name, width, val=None):
        self._name = name
        self._width = width
        self._val = val

    def withVal(self, val):
        """Return a copy of this (value-less) NamedInt with value 'val'."""
        assert self._val is None
        return NamedInt(self._name, self._width, val)

    def __len__(self):
        # Encoded length in bytes (used by findLength()).
        return self._width

    def apply(self, constraints):
        """Encode this integer as 'width' bytes, most-significant byte first
        (big-endian / network order).

        The value comes from 'constraints' if it names this field, else from
        the value given at construction, else defaults to 3. Values wider
        than the field are truncated to their low-order bytes.
        """
        val = constraints.getConstraint(self._name)
        if val is None:
            val = self._val
        if val is None:
            # We expected to have some constraint on this value, but we
            # didn't. How about 3? 3 is a nice number.
            val = 3
        # Shift amounts are in bits: 8 per byte, highest byte first.
        shifts = range(8 * (self._width - 1), -1, -8)
        # bytes(bytearray(...)) yields str on Python 2 and bytes on Python 3.
        return bytes(bytearray((val >> s) & 0xff for s in shifts))
+
+
def findLength(lst):
    """Return the total encoded size, in bytes, of a list whose members are
    bytestrings and/or NamedInts."""
    total = 0
    for part in lst:
        total += len(part)
    return total
+
+
def combineExamples(grp, n, maximum=256):
    """Yield up to 'maximum' bytestrings, each the concatenation of n
    members of 'grp' chosen with replacement.

    If all len(grp)**n combinations fit within 'maximum', yield every one
    of them, in order. Otherwise yield exactly 'maximum' randomly chosen
    combinations (which may repeat).
    """
    if len(grp) ** n <= maximum:
        for combined in combineExhaustively(grp, n):
            yield combined
        return
    # Too many combinations to enumerate: sample instead.
    for _sample in range(maximum):
        yield b"".join(random.choice(grp) for _pick in range(n))
+
+
def combineExhaustively(grp, n):
    """Yield all bytestrings made by concatenating n members of grp
    (with replacement).

    Results come in lexicographic order of choice: the first position
    varies slowest. n == 0 yields a single empty bytestring.
    """
    import itertools
    # itertools.product iterates rather than recursing once per position,
    # so a large n (e.g. a long fixed-size array) cannot exceed Python's
    # recursion limit. Each result is also built with one join rather than
    # repeated concatenation.
    for combo in itertools.product(grp, repeat=n):
        yield b"".join(combo)
+
+
def crossProduct(lol):
    """Given a list of lists of (entry, constraint) pairs, yield the
    cross-product of those lists.

    Each result picks one pair from every list. The result's entry is the
    concatenation of the picked entries and its constraint is their merge.
    Picks whose constraints conflict are skipped. The cross-product of zero
    lists is a single empty entry with no constraints.
    """
    if len(lol) == 0:
        # The product of no factors is one empty combination. Yielding
        # nothing here would silently drop a member list with no members.
        yield [], NIL
        return
    if len(lol) == 1:
        for item, constraint in lol[0]:
            yield item, constraint
        return
    # Compute the tail's product once instead of once per head item.
    rest = list(crossProduct(lol[1:]))
    for item, constraint in lol[0]:
        for irest, crest in rest:
            merged = constraint.merge(crest)
            if not merged.isFailed():
                yield item + irest, merged
+
+
def explore(lol):
    """Like crossProduct(), but for lists of lists whose full product would
    be too large to enumerate.

    Work one position at a time. Let that position range over all of its
    values, while every other position takes the first value compatible
    with the constraints (see exploreAt()).

    For example, with lists (a), (x,y,z), (1,2,3) and no constraints, this
    yields: ax1, ax1, ay1, az1, ax1, ax2, ax3.
    """
    if not lol:
        return
    if len(lol) == 1:
        for item, constraint in lol[0]:
            yield item, constraint
        return
    for position in range(len(lol)):
        for item, constraint in exploreAt(lol, position):
            yield item, constraint
+
+
def findComplying(lol, c):
    """Find a single value from among crossproduct(lol) complying with c.

    Members of each list are tried in order, backtracking on conflicts, so
    the first compatible combination wins.

    Returns (entry, constraint). 'entry' is the concatenation of the chosen
    entries, one from each list. 'constraint' is 'c' merged with their
    constraints. If no combination complies, returns ([], FAILED).
    """
    if len(lol) == 0:
        return [], c

    for item, itemc in lol[0]:
        cboth = c.merge(itemc)
        if cboth.isFailed():
            continue
        rest, call = findComplying(lol[1:], cboth)
        if call.isFailed():
            continue
        # The chosen item must be prepended to the remainder's entry;
        # returning only 'rest' would drop every chosen value.
        return item + rest, call

    return [], FAILED
+
+
def exploreAt(lol, idx):
    """Helper for explore(): vary position 'idx' over all of its entries.

    Every other position is filled with the first entry (chosen by
    findComplying) that is compatible with it. Combinations that cannot
    be satisfied are skipped.
    """
    before = lol[:idx]
    after = lol[idx + 1:]
    for item, constraint in lol[idx]:
        pre, cpre = findComplying(before, constraint)
        if cpre.isFailed():
            continue
        post, call = findComplying(after, cpre)
        if call.isFailed():
            # Don't emit unsatisfiable entries. Callers such as
            # visitSMUnion cap how many results they take (take_n), so
            # failed ones would waste that budget.
            continue
        yield pre + item + post, call
+
+
def take_n(iterator, n):
    """Takes an iterator and yields up to the first n items
    from that iterator.

    Stops as soon as n items have been produced, without pulling an extra
    item from the source. This matters when the source is an expensive
    generator.
    """
    import itertools
    for item in itertools.islice(iterator, n):
        yield item
+
+
class CorpusGenerator(trunnel.CodeGen.ASTVisitor):
    """AST visitor that writes example encodings of each struct to
    <target_dir>/<struct name>/<sha256 hex digest of the example>.

    The visit* methods for struct members are generators. Each yields
    (entry, constraint) pairs: 'entry' is a list of bytestrings and
    NamedInts, and 'constraint' is a Constraints object.
    """
    # target_dir -- where to write items
    # sort_order -- topologically sorted list of structure names
    #   (set by setChecker()).
    # structExamples -- map from structure name to the example encodings
    #   generated for that structure, sorted by (length, content).
    # _expandConst -- helper function that knows how to map constant
    #   names to integers (set by setChecker()).
    # _maxFanout -- used to limit the branching factor when running
    #   combinatorically intense generators.
    # _maxExamples -- maximum number of distinct examples to generate
    #   for each structure
    # _maxCombinatorics -- when building long sequences, we try a
    #   cross-product approach when it would generate fewer than this many
    #   entries. Otherwise, we try an alternative approach; see explore().
    # _constrainedIntFieldNames -- while visiting a struct, the set of its
    #   integer fields that are referenced elsewhere in it (so their values
    #   come from constraints); None between structs.
    # _strictFail -- if true, a 'fail' member yields no examples, so member
    #   lists containing one produce nothing. If false, a 'fail' member is
    #   encoded as the empty string.
    def __init__(self, target_dir):
        trunnel.CodeGen.ASTVisitor.__init__(self)
        self.target_dir = target_dir
        self.structExamples = {}
        self._maxFanout = 128
        self._maxCombinatorics = 1024
        self._maxExamples = 1024
        self._constrainedIntFieldNames = None
        self._strictFail = False

    def setChecker(self, ch):
        """Take the struct ordering and constant expansion from a Checker
        that has already visited the file."""
        self.sort_order = ch.sortedStructs
        self._expandConst = ch.expandConstant

    def expandConst(self, v):
        """If v is a constant name, expand it. Otherwise return v."""
        if isinstance(v, str):
            return self._expandConst(v)
        else:
            return v

    def visitFile(self, f):
        # Visit structs in dependency order, so that the examples for a
        # struct exist before any struct containing it is visited.
        f.visitChildrenSorted(self.sort_order, self)

    def visitConstDecl(self, cd):
        pass

    def visitStructDecl(self, sd):
        """Generate up to _maxExamples distinct encodings of 'sd', write each
        to its own file, and remember them for structs that embed 'sd'."""
        self._constrainedIntFieldNames = sd.constrainedIntFields
        target = os.path.join(self.target_dir, sd.name)
        if not os.path.exists(target):
            os.makedirs(target)
        examples = set()
        for item in self.enumerateStructValues(sd):
            if item in examples:
                continue
            digest = hashlib.sha256(item).hexdigest()
            fname = os.path.join(target, digest)
            print(fname)
            with open(fname, 'wb') as f:
                f.write(item)
            examples.add(item)
            if len(examples) >= self._maxExamples:
                break
        # Break length ties by content. Otherwise the order, and therefore
        # which examples embedding structs pick up when they slice this
        # list, would depend on set iteration order.
        self.structExamples[sd.name] = sorted(
            examples, key=lambda e: (len(e), e))
        self._constrainedIntFieldNames = None

    def enumerateStructValues(self, sd):
        """Helper: yields bytestrings that match a StructDecl."""
        for members, constraints in self.visitListOfMembers(sd.members):
            if constraints.isFailed():
                continue
            result = b"".join(constraints.apply(m) for m in members)
            yield result

    def visitSMInteger(self, smi):
        """Yield example values for an integer member."""
        width = smi.inttype.width
        ni = NamedInt(smi.name, width // 8)
        if smi.name in self._constrainedIntFieldNames:
            # Referenced elsewhere in the struct (e.g. as an array length,
            # a length constraint, or a union tag). The value should come
            # from that constraint, so leave it unset here.
            yield [ni], NIL
        elif smi.constraints is None:
            # Unrestricted: try the smallest and largest representable values.
            yield [ni.withVal(0)], NIL
            yield [ni.withVal((1 << width) - 1)], NIL
        else:
            # Restricted: try both endpoints of every allowed range.
            for lo, hi in smi.constraints.ranges:
                lo = self.expandConst(lo)
                hi = self.expandConst(hi)
                yield [ni.withVal(lo)], NIL
                if lo != hi:
                    yield [ni.withVal(hi)], NIL

    def visitListOfMembers(self, members):
        """Yield (entry, constraint) pairs for a sequence of members.

        Each member contributes at most _maxFanout alternatives. If the
        resulting number of combinations is below _maxCombinatorics, take the
        full cross-product. Otherwise fall back to explore().
        """
        results = []
        n_vals = 1
        for m in members:
            results.append(list(take_n(self.visit(m), self._maxFanout)))
            n_vals *= len(results[-1])
        if n_vals < self._maxCombinatorics:
            combos = crossProduct(results)
        else:
            combos = explore(results)
        for i, c in combos:
            yield i, c

    def visitSMStruct(self, sms):
        # Reuse (up to _maxFanout of) the examples already generated for
        # the embedded struct.
        for e in self.structExamples[sms.structname][:self._maxFanout]:
            yield [e], NIL

    def visitSMString(self, sms):
        # NUL-terminated strings of a few lengths, including empty.
        yield [b"\0"], NIL
        yield [b"a\0"], NIL
        yield [b"abc\0"], NIL

    def visitSMFixedArray(self, sma):
        w = self.expandConst(sma.width)
        if isinstance(sma.basetype, str):
            # Array of structs: combine w examples of the element struct.
            examples = self.structExamples[sma.basetype]
            for e in combineExamples(examples, w, self._maxFanout):
                yield [e], NIL
        elif str(sma.basetype) == 'char':
            yield [b"x"*w], NIL
            yield [b"\xff"*w], NIL
        else:
            bitwidth = sma.basetype.width
            nbytes = w * (bitwidth // 8)
            yield [b"\0"*nbytes], NIL
            yield [b"\xff"*nbytes], NIL

    def visitSMVarArray(self, smva):
        # Yield arrays of 0, 1 and 2 elements, each constraining the width
        # field (if any) to match.
        widthfield = smva.widthfield
        if isinstance(smva.basetype, str):
            examples = self.structExamples[smva.basetype]
            yield [b""], constrain(widthfield, 0)
            c = constrain(widthfield, 1)
            for e in examples[:self._maxFanout]:
                yield [e], c
            c = constrain(widthfield, 2)
            for e in combineExamples(examples, 2, self._maxFanout):
                yield [e], c
        elif str(smva.basetype) == 'char':
            yield [b""], constrain(widthfield, 0)
            yield [b"h"], constrain(widthfield, 1)
            yield [b"hi"], constrain(widthfield, 2)
        else:
            w = smva.basetype.width // 8
            yield [b""], constrain(widthfield, 0)
            yield [b"\x00"*w], constrain(widthfield, 1)
            yield [b"\x00"*w*2], constrain(widthfield, 2)

    def visitSMLenConstrained(self, smlc):
        # Constrain the length field to the encoded size of each example
        # of the (single) enclosed member.
        varname = smlc.lengthfield
        assert len(smlc.members) == 1  # XXX limitation
        for item, constraints in self.visit(smlc.members[0]):
            c = constraints.add(varname, findLength(item))
            if not c.isFailed():
                yield item, c

    def visitSMUnion(self, smu):
        """Yield up to _maxFanout examples per union case.

        For non-default cases, the tag field is constrained to the first
        value of that case's tag.
        """
        tagfield = smu.tagfield
        for m in smu.members:
            for item, constraints in take_n(
                    self.visitListOfMembers(m.decls), self._maxFanout):
                if m.is_default:
                    # NOTE(review): the tag is left unconstrained here. If it
                    # has no other constraint, NamedInt.apply falls back to 3,
                    # which may select a non-default case. TODO confirm.
                    c = constraints
                else:
                    oneval = m.tagvalue[0][0]
                    c = constraints.add(tagfield, self.expandConst(oneval))
                if not c.isFailed():
                    yield item, c

    def visitSMFail(self, x):
        # See _strictFail above.
        if self._strictFail:
            return
        else:
            yield [b""], NIL

    def visitSMEos(self, x):
        yield [b""], NIL

    def visitSMIgnore(self, x):
        # Ignored trailing bytes: try both none and some.
        yield [b""], NIL
        yield [b"bla"], NIL

    def visitSMPosition(self, x):
        yield [b""], NIL
+
+
def generate_corpus(input_fnames, target_dir):
    """Parse each trunnel file in 'input_fnames' and write example encodings
    of its structs under 'target_dir'.

    A single CorpusGenerator is shared across all inputs, so its
    structExamples map accumulates examples from earlier files.
    """
    generator = CorpusGenerator(target_dir)
    for input_fname in input_fnames:
        # 'with' guarantees the file is closed even if reading fails.
        with open(input_fname, 'r') as inp:
            source = inp.read()
        t = trunnel.Grammar.Lexer().tokenize(source)
        parsed = trunnel.Grammar.Parser().parse(t)

        c = trunnel.CodeGen.Checker()
        c.visit(parsed)

        generator.setChecker(c)
        generator.visit(parsed)
+
+
if __name__ == '__main__':
    import getopt
    import sys

    usage = ("Syntax: python -m trunnel.SeedFuzzer [-o <dir>] "
             "<fname...>\n")

    try:
        opts, args = getopt.gnu_getopt(sys.argv[1:],
                                       "o:",
                                       ["output-dir="])
    except getopt.GetoptError as err:
        # Report a bad option as a usage error rather than a traceback.
        sys.stderr.write("%s\n" % err)
        sys.stderr.write(usage)
        sys.exit(1)

    target_dir = "fuzzing-inputs"
    for (k, v) in opts:
        if k in ("-o", "--output-dir"):
            target_dir = v

    if len(args) == 0:
        sys.stderr.write(usage)
        sys.exit(1)

    generate_corpus(args, target_dir)