commit 879255e21821087a438f418b79e8ad7977832797
Author: Nick Mathewson <nickm(a)torproject.org>
Date: Wed Oct 8 10:42:02 2014 -0400
Add support for remembering a position within the input stream
---
doc/trunnel.md | 22 +++++++++
lib/trunnel/CodeGen.py | 37 ++++++++++++++
lib/trunnel/Grammar.py | 29 ++++++++++-
test/Makefile | 6 +++
test/c/test.c | 1 +
test/c/test.h | 1 +
test/c/test_positions.c | 111 ++++++++++++++++++++++++++++++++++++++++++
test/failing/badptr.trunnel | 5 ++
test/valid/positions.trunnel | 10 ++++
9 files changed, 221 insertions(+), 1 deletion(-)
diff --git a/doc/trunnel.md b/doc/trunnel.md
index 2ffbab8..eaf768a 100644
--- a/doc/trunnel.md
+++ b/doc/trunnel.md
@@ -281,6 +281,28 @@ In newly constructed structures, all variable-length arrays are empty.
It's an error to try to encode a variable-length array with a length field if
that array's length field doesn't match its actual length.
+### Structure members: zero-length indices into the input
+
+Sometimes you need to record the position in the input the corresponds to
+a position in the structure. You can use an `@ptr` field to record
+a position within a structure when parsing it:
+
+ struct s {
+ nulterm unsigned_header;
+ @ptr start_of_signed_material;
+ u32 bodylen;
+ u8 body[bodylen];
+ u64 flags;
+ @ptr end_of_signed_material;
+ u16 signature_len;
+ u8 signature[signature_len];
+ }
+
+When an object of this type is parsed, then `start_of_signed_material`
+and `end_of_signed_material` will get set to pointers into the input.
+These pointers are only set when the input is parsed; you don't need
+to set them to encode the object.
+
### Structure members: unions
You can specify that different elements should be parsed based on some
diff --git a/lib/trunnel/CodeGen.py b/lib/trunnel/CodeGen.py
index 5867fd0..5a5f070 100644
--- a/lib/trunnel/CodeGen.py
+++ b/lib/trunnel/CodeGen.py
@@ -338,6 +338,9 @@ class Checker(ASTVisitor):
def visitSMString(self, sms):
self.addMemberName(sms.name)
+ def visitSMPosition(self, smp):
+ self.addMemberName(smp.name)
+
def visitSMLenConstrained(self, sml):
if sml.lengthfield != None:
self.checkIntField(
@@ -544,6 +547,9 @@ class Annotator(ASTVisitor):
def visitSMString(self, ss):
self.annotateMember(ss)
+ def visitSMPosition(self, smp):
+ self.annotateMember(smp)
+
def visitSMLenConstrained(self, sml):
sml.lengthfieldmember = None
if sml.lengthfield is not None:
@@ -781,6 +787,11 @@ class DeclarationGenerationVisitor(CodeGenerator):
self.w("char *%s;\n" % (ss.c_name))
+ def visitSMPosition(self, smp):
+ if smp.annotation != None:
+ self.w(smp.annotation)
+ self.w("const uint8_t *%s;\n" % smp.c_name)
+
def visitSMLenConstrained(self, sml):
sml.visitChildren(self)
@@ -1072,6 +1083,9 @@ class FreeFnGenerator(CodeGenerator):
self.w("trunnel_wipestr(obj->%s);\n" % (ss.c_name))
self.w("trunnel_free(obj->%s);\n" % (ss.c_name))
+ def visitSMPosition(self, smp):
+ pass
+
def visitSMLenConstrained(self, sml):
sml.visitChildren(self)
@@ -1579,6 +1593,17 @@ class AccessorFnGenerator(CodeGenerator):
return 0;
}}""", c_name=sms.c_name)
+ def visitSMPosition(self, smp):
+ st = self.structName
+ nm = smp.c_fn_name
+ self.docstring("Return the position for %s when we parsed "
+ "this object"%nm)
+ self.declaration("const uint8_t *",
+ "%s_get_%s(const %s_t *inp)" % (st,nm,st))
+ self.format("""
+ {{
+ return inp->{nm};
+ }}""", nm = smp.c_name)
def iterateOverFixedArray(generator, sfa, body, extraDecl=""):
"""Helper: write the code needed to iterate over every element of a
@@ -1739,6 +1764,9 @@ class CheckFnGenerator(CodeGenerator):
self.w('if (NULL == obj->%s)\n return "Missing %s";\n' %
(ss.c_name, ss.c_name))
+ def visitSMPosition(self, smp):
+ pass
+
def visitSMLenConstrained(self, sml):
# To check a SMlenConstrained, check its children.
sml.visitChildren(self)
@@ -1872,6 +1900,9 @@ class EncodedLenFnGenerator(CodeGenerator):
self.eltHeader(ss)
self.w("result += strlen(obj->%s) + 1;\n" % ss.c_name)
+ def visitSMPosition(self, smp):
+ pass
+
def visitSMLenConstrained(self, sml):
sml.visitChildren(self)
@@ -2191,6 +2222,9 @@ class EncodeFnGenerator(CodeGenerator):
ptr += len + 1; written += len + 1;
}}""", c_name=ss.c_name)
+ def visitSMPosition(self, smp):
+ pass
+
def visitSMLenConstrained(self, sml):
# To encode a length-constained field of a structure,
# remember the position at which we began writing to the union.
@@ -2643,6 +2677,9 @@ class ParseFnGenerator(CodeGenerator):
remaining -= memlen; ptr += memlen;
}}""", c_name=ss.c_name, truncated=self.truncatedLabel)
+ def visitSMPosition(self, smp):
+ self.format("obj->{c_name} = ptr;", c_name=smp.c_name);
+
def visitSMLenConstrained(self, sml):
# To parse a length-constrained region, make sure that at
# least that many bytes remain in the structure. Then,
diff --git a/lib/trunnel/Grammar.py b/lib/trunnel/Grammar.py
index 0de7fa7..25dc53a 100644
--- a/lib/trunnel/Grammar.py
+++ b/lib/trunnel/Grammar.py
@@ -113,7 +113,7 @@ class Lexer(trunnel.spark.GenericScanner, object):
trunnel.spark.GenericScanner.tokenize(self, input)
return self.rv
- @pattern(r"(?:[;{}\[\]\-=,:]|\.\.\.|\.\.|\.)")
+ @pattern(r"(?:[;{}@\[\]\-=,:]|\.\.\.|\.\.|\.)")
def t_punctuation(self, s):
self.rv.append(Token(s, self.lineno))
@@ -570,6 +570,15 @@ class SMIgnore(StructMember):
ignored."""
pass
+class SMPosition(StructMember):
+ """ A struct member: notes that we should store a pointer to this point
+ in the input when we """
+ def __init__(self, name):
+ StructMember.__init__(self, name)
+
+ def __str__(self):
+ return "@" + self.name
+
class IDReference(AST):
@@ -778,6 +787,10 @@ class Parser(trunnel.spark.GenericParser, object):
def p_StructMember_4(self, info):
return info[0]
+ @rule(" StructMember ::= SMPosition ")
+ def p_StructMember_5(self, info):
+ return info[0]
+
@rule(" SMInteger ::= IntType ID OptIntConstraint ")
def p_SMInteger(self, info):
return SMInteger(info[0], str(info[1]), info[2])
@@ -997,6 +1010,10 @@ class Parser(trunnel.spark.GenericParser, object):
def p_UnionField_5(self, info):
return info[0]
+ @rule(" UnionField ::= SMSPosition ")
+ def p_UnionField_6(self, info):
+ return info[0]
+
@rule(" ContextDecl ::= context ID { ContextMembers } ")
def p_ContextDecl(self, info):
return StructDecl(str(info[1]), info[3], isContext=True)
@@ -1017,6 +1034,16 @@ class Parser(trunnel.spark.GenericParser, object):
def p_ContextMember(self, info):
return SMInteger(info[0], str(info[1]), None)
+ @rule(" SMPosition ::= @ PtrKW ID ")
+ def p_SMPosition(self, info):
+ return SMPosition(str(info[2]))
+
+ @rule(" PtrKW ::= ID ")
+ def p_PtrKW(self, info):
+ if str(info[0]) != 'ptr':
+ raise SyntaxError("Expected 'ptr' at %s" % info[0].lineno)
+ return None
+
if __name__ == '__main__':
print ("===== Here is our actual grammar, extracted from Grammar.py\n")
diff --git a/test/Makefile b/test/Makefile
index ba70e8d..603526f 100644
--- a/test/Makefile
+++ b/test/Makefile
@@ -31,6 +31,7 @@ TEST_OBJS = \
c/test_contexts_varsize2.o \
c/test_contexts_complex.o \
c/test_remainder_repeats.o \
+ c/test_positions.o \
c/test_util.o
BOILERPLATE_FILES=\
@@ -44,6 +45,7 @@ OBJS=tinytest/tinytest.o \
valid/opaque.o \
valid/leftover.o \
valid/contexts.o \
+ valid/positions.o \
./include/trunnel.o \
$(TEST_OBJS)
@@ -69,6 +71,7 @@ valid/derived.o: valid/derived.h valid/derived.c
valid/opaque.o: valid/opaque.h valid/opaque.h
valid/leftover.o: valid/leftover.h valid/leftover.c
valid/contexts.o: valid/contexts.h
+valid/positions.o: valid/positions.h
$(TEST_OBJS) : tinytest/tinytest.h tinytest/tinytest_macros.h valid/simple.h valid/derived.h
$(OBJS) : include/trunnel.h include/trunnel-impl.h
tinytest/tinytest.o: tinytest/tinytest.h tinytest/tinytest_macros.h
@@ -88,5 +91,8 @@ valid/leftover.c valid/leftover.h: valid/leftover.trunnel ../lib/trunnel/*py
valid/contexts.c valid/contexts.h: valid/contexts.trunnel ../lib/trunnel/*py
PYTHONPATH=../lib:${PYTHONPATH} python -m trunnel valid/contexts.trunnel
+valid/positions.c valid/positions.h: valid/positions.trunnel ../lib/trunnel/*py
+ PYTHONPATH=../lib:${PYTHONPATH} python -m trunnel valid/positions.trunnel
+
$(BOILERPLATE_FILES): ../lib/trunnel/*py ../lib/trunnel/data/*.[ch]
PYTHONPATH=../lib:${PYTHONPATH} python -m trunnel --target-dir=./include --write-c-files
diff --git a/test/c/test.c b/test/c/test.c
index 534ae69..cd9c803 100644
--- a/test/c/test.c
+++ b/test/c/test.c
@@ -21,6 +21,7 @@ struct testgroup_t test_groups[] = {
{ "contexts/varsize/", contexts_varsize_tests },
{ "contexts/varsize2/", contexts_varsize2_tests },
{ "contexts/complex/", contexts_complex_tests },
+ { "positions/", positions_tests },
END_OF_GROUPS,
};
diff --git a/test/c/test.h b/test/c/test.h
index 36fbde8..44242e1 100644
--- a/test/c/test.h
+++ b/test/c/test.h
@@ -31,6 +31,7 @@ extern struct testcase_t contexts_uniontag_tests[];
extern struct testcase_t contexts_varsize_tests[];
extern struct testcase_t contexts_varsize2_tests[];
extern struct testcase_t contexts_complex_tests[];
+extern struct testcase_t positions_tests[];
ssize_t unhex(uint8_t *out, size_t outlen, const char *in);
const uint8_t *ux(const char *in);
diff --git a/test/c/test_positions.c b/test/c/test_positions.c
new file mode 100644
index 0000000..60ede17
--- /dev/null
+++ b/test/c/test_positions.c
@@ -0,0 +1,111 @@
+#include "test.h"
+
+#include "valid/positions.h"
+
+static void
+test_pos_invalid(void *arg)
+{
+ haspos_t *hp = NULL, *hp2 = NULL;
+ uint8_t buf[64];
+ int i;
+ (void)arg;
+
+ /* Encode invalid */
+ tt_int_op(-1, ==, haspos_encode(buf, sizeof(buf), NULL));
+ hp = haspos_new();
+ tt_int_op(-1, ==, haspos_encode(buf, sizeof(buf), hp));
+ haspos_set_s1(hp, "Foo");
+ tt_int_op(-1, ==, haspos_encode(buf, sizeof(buf), hp));
+ haspos_set_s2(hp, "Bar");
+ tt_int_op(12, ==, haspos_encode(buf, sizeof(buf), hp));
+
+ /* Encode truncated */
+ for (i = 0; i < 12; ++i)
+ tt_int_op(-2, ==, haspos_encode(buf, i, hp));
+ tt_int_op(12, ==, haspos_encode(buf, 12, hp));
+
+ /* Parse truncated */
+ for (i = 0; i < 12; ++i)
+ tt_int_op(-2, ==, haspos_parse(&hp2, buf, i));
+
+ end:
+ haspos_free(hp);
+}
+
+static void
+test_pos_encdec(void *arg)
+{
+ haspos_t *hp = NULL, *hp2 = NULL;
+ uint8_t buf[64];
+ (void)arg;
+
+ hp = haspos_new();
+ haspos_set_s1(hp, "hello");
+ haspos_set_s2(hp, "world");
+ haspos_set_x(hp, 3);
+ tt_int_op(16, ==, haspos_encode(buf, sizeof(buf), hp));
+ tt_mem_op("hello\0world\0\0\0\0\x03", ==, buf, 16);
+
+ tt_int_op(16, ==, haspos_parse(&hp2, buf, sizeof(buf)));
+ tt_str_op("hello", ==, haspos_get_s1(hp2));
+ tt_str_op("world", ==, haspos_get_s2(hp2));
+ tt_int_op(3, ==, haspos_get_x(hp));
+ tt_ptr_op(buf + 6, ==, haspos_get_pos1(hp2));
+ tt_ptr_op(buf + 12, ==, haspos_get_pos2(hp2));
+
+ end:
+ haspos_free(hp);
+ haspos_free(hp2);
+}
+
+static void
+test_pos_allocfail(void *arg)
+{
+#ifdef ALLOCFAIL
+ haspos_t *hp = NULL;
+ const uint8_t inp[] = "hello\0world\0\0\0\0\x03";
+ uint8_t buf[32];
+ (void)arg;
+
+ set_alloc_fail(1);
+ tt_ptr_op(NULL, ==, haspos_new());
+
+ set_alloc_fail(1);
+ tt_int_op(-1, ==, haspos_parse(&hp, inp, sizeof(inp)));
+
+ set_alloc_fail(2);
+ tt_int_op(-1, ==, haspos_parse(&hp, inp, sizeof(inp)));
+
+ set_alloc_fail(3);
+ tt_int_op(-1, ==, haspos_parse(&hp, inp, sizeof(inp)));
+
+ hp = haspos_new();
+ tt_assert(hp);
+ set_alloc_fail(1);
+ tt_int_op(-1, ==, haspos_set_s1(hp,"Hi"));
+ tt_int_op(0, ==, haspos_set_s1(hp,"Hi"));
+
+ set_alloc_fail(1);
+ tt_int_op(-1, ==, haspos_set_s2(hp,"Hi"));
+ tt_int_op(0, ==, haspos_set_s2(hp,"Hi"));
+
+ tt_int_op(-1, ==, haspos_encode(buf, sizeof(buf), hp));
+ haspos_clear_errors(hp);
+ tt_int_op(10, ==, haspos_encode(buf, sizeof(buf), hp));
+
+ haspos_free(hp); hp = NULL;
+
+ end:
+ haspos_free(hp);
+#else
+ (void)arg;
+ tt_skip();
+#endif
+}
+
+struct testcase_t positions_tests[] = {
+ { "invalid", test_pos_invalid, 0, NULL, NULL },
+ { "encdec", test_pos_encdec, 0, NULL, NULL },
+ { "allocfail", test_pos_allocfail, 0, NULL, NULL },
+ END_OF_TESTCASES
+};
diff --git a/test/failing/badptr.trunnel b/test/failing/badptr.trunnel
new file mode 100644
index 0000000..14e55da
--- /dev/null
+++ b/test/failing/badptr.trunnel
@@ -0,0 +1,5 @@
+
+struct s {
+ nulterm s;
+ @foo bar;
+}
\ No newline at end of file
diff --git a/test/valid/positions.trunnel b/test/valid/positions.trunnel
new file mode 100644
index 0000000..497ebb0
--- /dev/null
+++ b/test/valid/positions.trunnel
@@ -0,0 +1,10 @@
+
+struct haspos {
+ nulterm s1;
+ /** Position right after the first NUL. */
+ @ptr pos1;
+ nulterm s2;
+ @ptr pos2;
+ u32 x;
+}
+