commit 879255e21821087a438f418b79e8ad7977832797 Author: Nick Mathewson nickm@torproject.org Date: Wed Oct 8 10:42:02 2014 -0400
Add support for remembering a position within the input stream --- doc/trunnel.md | 22 +++++++++ lib/trunnel/CodeGen.py | 37 ++++++++++++++ lib/trunnel/Grammar.py | 29 ++++++++++- test/Makefile | 6 +++ test/c/test.c | 1 + test/c/test.h | 1 + test/c/test_positions.c | 111 ++++++++++++++++++++++++++++++++++++++++++ test/failing/badptr.trunnel | 5 ++ test/valid/positions.trunnel | 10 ++++ 9 files changed, 221 insertions(+), 1 deletion(-)
diff --git a/doc/trunnel.md b/doc/trunnel.md index 2ffbab8..eaf768a 100644 --- a/doc/trunnel.md +++ b/doc/trunnel.md @@ -281,6 +281,28 @@ In newly constructed structures, all variable-length arrays are empty. It's an error to try to encode a variable-length array with a length field if that array's length field doesn't match its actual length.
+### Structure members: zero-length indices into the input + +Sometimes you need to record the position in the input that corresponds to +a position in the structure. You can use an `@ptr` field to record +a position within a structure when parsing it: + + struct s { + nulterm unsigned_header; + @ptr start_of_signed_material; + u32 bodylen; + u8 body[bodylen]; + u64 flags; + @ptr end_of_signed_material; + u16 signature_len; + u8 signature[signature_len]; + } + +When an object of this type is parsed, then `start_of_signed_material` +and `end_of_signed_material` will get set to pointers into the input. +These pointers are only set when the input is parsed; you don't need +to set them to encode the object. + ### Structure members: unions
You can specify that different elements should be parsed based on some diff --git a/lib/trunnel/CodeGen.py b/lib/trunnel/CodeGen.py index 5867fd0..5a5f070 100644 --- a/lib/trunnel/CodeGen.py +++ b/lib/trunnel/CodeGen.py @@ -338,6 +338,9 @@ class Checker(ASTVisitor): def visitSMString(self, sms): self.addMemberName(sms.name)
+ def visitSMPosition(self, smp): + self.addMemberName(smp.name) + def visitSMLenConstrained(self, sml): if sml.lengthfield != None: self.checkIntField( @@ -544,6 +547,9 @@ class Annotator(ASTVisitor): def visitSMString(self, ss): self.annotateMember(ss)
+ def visitSMPosition(self, smp): + self.annotateMember(smp) + def visitSMLenConstrained(self, sml): sml.lengthfieldmember = None if sml.lengthfield is not None: @@ -781,6 +787,11 @@ class DeclarationGenerationVisitor(CodeGenerator):
self.w("char *%s;\n" % (ss.c_name))
+ def visitSMPosition(self, smp): + if smp.annotation != None: + self.w(smp.annotation) + self.w("const uint8_t *%s;\n" % smp.c_name) + def visitSMLenConstrained(self, sml): sml.visitChildren(self)
@@ -1072,6 +1083,9 @@ class FreeFnGenerator(CodeGenerator): self.w("trunnel_wipestr(obj->%s);\n" % (ss.c_name)) self.w("trunnel_free(obj->%s);\n" % (ss.c_name))
+ def visitSMPosition(self, smp): + pass + def visitSMLenConstrained(self, sml): sml.visitChildren(self)
@@ -1579,6 +1593,17 @@ class AccessorFnGenerator(CodeGenerator): return 0; }}""", c_name=sms.c_name)
+ def visitSMPosition(self, smp): + st = self.structName + nm = smp.c_fn_name + self.docstring("Return the position for %s when we parsed " + "this object"%nm) + self.declaration("const uint8_t *", + "%s_get_%s(const %s_t *inp)" % (st,nm,st)) + self.format(""" + {{ + return inp->{nm}; + }}""", nm = smp.c_name)
def iterateOverFixedArray(generator, sfa, body, extraDecl=""): """Helper: write the code needed to iterate over every element of a @@ -1739,6 +1764,9 @@ class CheckFnGenerator(CodeGenerator): self.w('if (NULL == obj->%s)\n return "Missing %s";\n' % (ss.c_name, ss.c_name))
+ def visitSMPosition(self, smp): + pass + def visitSMLenConstrained(self, sml): # To check a SMlenConstrained, check its children. sml.visitChildren(self) @@ -1872,6 +1900,9 @@ class EncodedLenFnGenerator(CodeGenerator): self.eltHeader(ss) self.w("result += strlen(obj->%s) + 1;\n" % ss.c_name)
+ def visitSMPosition(self, smp): + pass + def visitSMLenConstrained(self, sml): sml.visitChildren(self)
@@ -2191,6 +2222,9 @@ class EncodeFnGenerator(CodeGenerator): ptr += len + 1; written += len + 1; }}""", c_name=ss.c_name)
+ def visitSMPosition(self, smp): + pass + def visitSMLenConstrained(self, sml): # To encode a length-constained field of a structure, # remember the position at which we began writing to the union. @@ -2643,6 +2677,9 @@ class ParseFnGenerator(CodeGenerator): remaining -= memlen; ptr += memlen; }}""", c_name=ss.c_name, truncated=self.truncatedLabel)
+ def visitSMPosition(self, smp): + self.format("obj->{c_name} = ptr;", c_name=smp.c_name); + def visitSMLenConstrained(self, sml): # To parse a length-constrained region, make sure that at # least that many bytes remain in the structure. Then, diff --git a/lib/trunnel/Grammar.py b/lib/trunnel/Grammar.py index 0de7fa7..25dc53a 100644 --- a/lib/trunnel/Grammar.py +++ b/lib/trunnel/Grammar.py @@ -113,7 +113,7 @@ class Lexer(trunnel.spark.GenericScanner, object): trunnel.spark.GenericScanner.tokenize(self, input) return self.rv
- @pattern(r"(?:[;{}\[\]\-=,:]|\.\.\.|\.\.|\.)") + @pattern(r"(?:[;{}@\[\]\-=,:]|\.\.\.|\.\.|\.)") def t_punctuation(self, s): self.rv.append(Token(s, self.lineno))
@@ -570,6 +570,15 @@ class SMIgnore(StructMember): ignored.""" pass
+class SMPosition(StructMember): + """ A struct member: notes that we should store a pointer to this point + in the input when we parse it. """ + def __init__(self, name): + StructMember.__init__(self, name) + + def __str__(self): + return "@" + self.name +
class IDReference(AST):
@@ -778,6 +787,10 @@ class Parser(trunnel.spark.GenericParser, object): def p_StructMember_4(self, info): return info[0]
+ @rule(" StructMember ::= SMPosition ") + def p_StructMember_5(self, info): + return info[0] + @rule(" SMInteger ::= IntType ID OptIntConstraint ") def p_SMInteger(self, info): return SMInteger(info[0], str(info[1]), info[2]) @@ -997,6 +1010,10 @@ class Parser(trunnel.spark.GenericParser, object): def p_UnionField_5(self, info): return info[0]
+ @rule(" UnionField ::= SMPosition ") + def p_UnionField_6(self, info): + return info[0] + @rule(" ContextDecl ::= context ID { ContextMembers } ") def p_ContextDecl(self, info): return StructDecl(str(info[1]), info[3], isContext=True) @@ -1017,6 +1034,16 @@ class Parser(trunnel.spark.GenericParser, object): def p_ContextMember(self, info): return SMInteger(info[0], str(info[1]), None)
+ @rule(" SMPosition ::= @ PtrKW ID ") + def p_SMPosition(self, info): + return SMPosition(str(info[2])) + + @rule(" PtrKW ::= ID ") + def p_PtrKW(self, info): + if str(info[0]) != 'ptr': + raise SyntaxError("Expected 'ptr' at %s" % info[0].lineno) + return None + if __name__ == '__main__': print ("===== Here is our actual grammar, extracted from Grammar.py\n")
diff --git a/test/Makefile b/test/Makefile index ba70e8d..603526f 100644 --- a/test/Makefile +++ b/test/Makefile @@ -31,6 +31,7 @@ TEST_OBJS = \ c/test_contexts_varsize2.o \ c/test_contexts_complex.o \ c/test_remainder_repeats.o \ + c/test_positions.o \ c/test_util.o
BOILERPLATE_FILES=\ @@ -44,6 +45,7 @@ OBJS=tinytest/tinytest.o \ valid/opaque.o \ valid/leftover.o \ valid/contexts.o \ + valid/positions.o \ ./include/trunnel.o \ $(TEST_OBJS)
@@ -69,6 +71,7 @@ valid/derived.o: valid/derived.h valid/derived.c valid/opaque.o: valid/opaque.h valid/opaque.h valid/leftover.o: valid/leftover.h valid/leftover.c valid/contexts.o: valid/contexts.h +valid/positions.o: valid/positions.h $(TEST_OBJS) : tinytest/tinytest.h tinytest/tinytest_macros.h valid/simple.h valid/derived.h $(OBJS) : include/trunnel.h include/trunnel-impl.h tinytest/tinytest.o: tinytest/tinytest.h tinytest/tinytest_macros.h @@ -88,5 +91,8 @@ valid/leftover.c valid/leftover.h: valid/leftover.trunnel ../lib/trunnel/*py valid/contexts.c valid/contexts.h: valid/contexts.trunnel ../lib/trunnel/*py PYTHONPATH=../lib:${PYTHONPATH} python -m trunnel valid/contexts.trunnel
+valid/positions.c valid/positions.h: valid/positions.trunnel ../lib/trunnel/*py + PYTHONPATH=../lib:${PYTHONPATH} python -m trunnel valid/positions.trunnel + $(BOILERPLATE_FILES): ../lib/trunnel/*py ../lib/trunnel/data/*.[ch] PYTHONPATH=../lib:${PYTHONPATH} python -m trunnel --target-dir=./include --write-c-files diff --git a/test/c/test.c b/test/c/test.c index 534ae69..cd9c803 100644 --- a/test/c/test.c +++ b/test/c/test.c @@ -21,6 +21,7 @@ struct testgroup_t test_groups[] = { { "contexts/varsize/", contexts_varsize_tests }, { "contexts/varsize2/", contexts_varsize2_tests }, { "contexts/complex/", contexts_complex_tests }, + { "positions/", positions_tests }, END_OF_GROUPS, };
diff --git a/test/c/test.h b/test/c/test.h index 36fbde8..44242e1 100644 --- a/test/c/test.h +++ b/test/c/test.h @@ -31,6 +31,7 @@ extern struct testcase_t contexts_uniontag_tests[]; extern struct testcase_t contexts_varsize_tests[]; extern struct testcase_t contexts_varsize2_tests[]; extern struct testcase_t contexts_complex_tests[]; +extern struct testcase_t positions_tests[];
ssize_t unhex(uint8_t *out, size_t outlen, const char *in); const uint8_t *ux(const char *in); diff --git a/test/c/test_positions.c b/test/c/test_positions.c new file mode 100644 index 0000000..60ede17 --- /dev/null +++ b/test/c/test_positions.c @@ -0,0 +1,111 @@ +#include "test.h" + +#include "valid/positions.h" + +static void +test_pos_invalid(void *arg) +{ + haspos_t *hp = NULL, *hp2 = NULL; + uint8_t buf[64]; + int i; + (void)arg; + + /* Encode invalid */ + tt_int_op(-1, ==, haspos_encode(buf, sizeof(buf), NULL)); + hp = haspos_new(); + tt_int_op(-1, ==, haspos_encode(buf, sizeof(buf), hp)); + haspos_set_s1(hp, "Foo"); + tt_int_op(-1, ==, haspos_encode(buf, sizeof(buf), hp)); + haspos_set_s2(hp, "Bar"); + tt_int_op(12, ==, haspos_encode(buf, sizeof(buf), hp)); + + /* Encode truncated */ + for (i = 0; i < 12; ++i) + tt_int_op(-2, ==, haspos_encode(buf, i, hp)); + tt_int_op(12, ==, haspos_encode(buf, 12, hp)); + + /* Parse truncated */ + for (i = 0; i < 12; ++i) + tt_int_op(-2, ==, haspos_parse(&hp2, buf, i)); + + end: + haspos_free(hp); +} + +static void +test_pos_encdec(void *arg) +{ + haspos_t *hp = NULL, *hp2 = NULL; + uint8_t buf[64]; + (void)arg; + + hp = haspos_new(); + haspos_set_s1(hp, "hello"); + haspos_set_s2(hp, "world"); + haspos_set_x(hp, 3); + tt_int_op(16, ==, haspos_encode(buf, sizeof(buf), hp)); + tt_mem_op("hello\0world\0\0\0\0\x03", ==, buf, 16); + + tt_int_op(16, ==, haspos_parse(&hp2, buf, sizeof(buf))); + tt_str_op("hello", ==, haspos_get_s1(hp2)); + tt_str_op("world", ==, haspos_get_s2(hp2)); + tt_int_op(3, ==, haspos_get_x(hp)); + tt_ptr_op(buf + 6, ==, haspos_get_pos1(hp2)); + tt_ptr_op(buf + 12, ==, haspos_get_pos2(hp2)); + + end: + haspos_free(hp); + haspos_free(hp2); +} + +static void +test_pos_allocfail(void *arg) +{ +#ifdef ALLOCFAIL + haspos_t *hp = NULL; + const uint8_t inp[] = "hello\0world\0\0\0\0\x03"; + uint8_t buf[32]; + (void)arg; + + set_alloc_fail(1); + tt_ptr_op(NULL, ==, haspos_new()); + + set_alloc_fail(1); + 
tt_int_op(-1, ==, haspos_parse(&hp, inp, sizeof(inp))); + + set_alloc_fail(2); + tt_int_op(-1, ==, haspos_parse(&hp, inp, sizeof(inp))); + + set_alloc_fail(3); + tt_int_op(-1, ==, haspos_parse(&hp, inp, sizeof(inp))); + + hp = haspos_new(); + tt_assert(hp); + set_alloc_fail(1); + tt_int_op(-1, ==, haspos_set_s1(hp,"Hi")); + tt_int_op(0, ==, haspos_set_s1(hp,"Hi")); + + set_alloc_fail(1); + tt_int_op(-1, ==, haspos_set_s2(hp,"Hi")); + tt_int_op(0, ==, haspos_set_s2(hp,"Hi")); + + tt_int_op(-1, ==, haspos_encode(buf, sizeof(buf), hp)); + haspos_clear_errors(hp); + tt_int_op(10, ==, haspos_encode(buf, sizeof(buf), hp)); + + haspos_free(hp); hp = NULL; + + end: + haspos_free(hp); +#else + (void)arg; + tt_skip(); +#endif +} + +struct testcase_t positions_tests[] = { + { "invalid", test_pos_invalid, 0, NULL, NULL }, + { "encdec", test_pos_encdec, 0, NULL, NULL }, + { "allocfail", test_pos_allocfail, 0, NULL, NULL }, + END_OF_TESTCASES +}; diff --git a/test/failing/badptr.trunnel b/test/failing/badptr.trunnel new file mode 100644 index 0000000..14e55da --- /dev/null +++ b/test/failing/badptr.trunnel @@ -0,0 +1,5 @@ + +struct s { + nulterm s; + @foo bar; +} \ No newline at end of file diff --git a/test/valid/positions.trunnel b/test/valid/positions.trunnel new file mode 100644 index 0000000..497ebb0 --- /dev/null +++ b/test/valid/positions.trunnel @@ -0,0 +1,10 @@ + +struct haspos { + nulterm s1; + /** Position right after the first NUL. */ + @ptr pos1; + nulterm s2; + @ptr pos2; + u32 x; +} +