[tor-commits] [trunnel/master] Add support for remembering a position within the input stream

nickm at torproject.org nickm at torproject.org
Wed Oct 8 14:42:12 UTC 2014


commit 879255e21821087a438f418b79e8ad7977832797
Author: Nick Mathewson <nickm at torproject.org>
Date:   Wed Oct 8 10:42:02 2014 -0400

    Add support for remembering a position within the input stream
---
 doc/trunnel.md               |   22 +++++++++
 lib/trunnel/CodeGen.py       |   37 ++++++++++++++
 lib/trunnel/Grammar.py       |   29 ++++++++++-
 test/Makefile                |    6 +++
 test/c/test.c                |    1 +
 test/c/test.h                |    1 +
 test/c/test_positions.c      |  111 ++++++++++++++++++++++++++++++++++++++++++
 test/failing/badptr.trunnel  |    5 ++
 test/valid/positions.trunnel |   10 ++++
 9 files changed, 221 insertions(+), 1 deletion(-)

diff --git a/doc/trunnel.md b/doc/trunnel.md
index 2ffbab8..eaf768a 100644
--- a/doc/trunnel.md
+++ b/doc/trunnel.md
@@ -281,6 +281,28 @@ In newly constructed structures, all variable-length arrays are empty.
 It's an error to try to encode a variable-length array with a length field if
 that array's length field doesn't match its actual length.
 
+### Structure members: zero-length indices into the input
+
+Sometimes you need to record the position in the input that corresponds to
+a position in the structure.  You can use an `@ptr` field to record
+a position within a structure when parsing it:
+
+    struct s {
+      nulterm unsigned_header;
+      @ptr start_of_signed_material;
+      u32 bodylen;
+      u8 body[bodylen];
+      u64 flags;
+      @ptr end_of_signed_material;
+      u16 signature_len;
+      u8 signature[signature_len];
+    }
+
+When an object of this type is parsed, then `start_of_signed_material`
+and `end_of_signed_material` will get set to pointers into the input.
+These pointers are only set when the input is parsed; you don't need
+to set them to encode the object.
+
 ### Structure members: unions
 
 You can specify that different elements should be parsed based on some
diff --git a/lib/trunnel/CodeGen.py b/lib/trunnel/CodeGen.py
index 5867fd0..5a5f070 100644
--- a/lib/trunnel/CodeGen.py
+++ b/lib/trunnel/CodeGen.py
@@ -338,6 +338,9 @@ class Checker(ASTVisitor):
     def visitSMString(self, sms):
         self.addMemberName(sms.name)
 
+    def visitSMPosition(self, smp):
+        self.addMemberName(smp.name)
+
     def visitSMLenConstrained(self, sml):
         if sml.lengthfield != None:
             self.checkIntField(
@@ -544,6 +547,9 @@ class Annotator(ASTVisitor):
     def visitSMString(self, ss):
         self.annotateMember(ss)
 
+    def visitSMPosition(self, smp):
+        self.annotateMember(smp)
+
     def visitSMLenConstrained(self, sml):
         sml.lengthfieldmember = None
         if sml.lengthfield is not None:
@@ -781,6 +787,11 @@ class DeclarationGenerationVisitor(CodeGenerator):
 
         self.w("char *%s;\n" % (ss.c_name))
 
+    def visitSMPosition(self, smp):
+        if smp.annotation != None:
+            self.w(smp.annotation)
+        self.w("const uint8_t *%s;\n" % smp.c_name)
+
     def visitSMLenConstrained(self, sml):
         sml.visitChildren(self)
 
@@ -1072,6 +1083,9 @@ class FreeFnGenerator(CodeGenerator):
         self.w("trunnel_wipestr(obj->%s);\n" % (ss.c_name))
         self.w("trunnel_free(obj->%s);\n" % (ss.c_name))
 
+    def visitSMPosition(self, smp):
+        pass
+
     def visitSMLenConstrained(self, sml):
         sml.visitChildren(self)
 
@@ -1579,6 +1593,17 @@ class AccessorFnGenerator(CodeGenerator):
                return 0;
              }}""", c_name=sms.c_name)
 
+    def visitSMPosition(self, smp):
+        st = self.structName
+        nm = smp.c_fn_name
+        self.docstring("Return the position for %s when we parsed "
+                       "this object"%nm)
+        self.declaration("const uint8_t *",
+                         "%s_get_%s(const %s_t *inp)" % (st,nm,st))
+        self.format("""
+              {{
+                return inp->{nm};
+              }}""", nm = smp.c_name)
 
 def iterateOverFixedArray(generator, sfa, body, extraDecl=""):
     """Helper: write the code needed to iterate over every element of a
@@ -1739,6 +1764,9 @@ class CheckFnGenerator(CodeGenerator):
         self.w('if (NULL == obj->%s)\n  return "Missing %s";\n' %
                (ss.c_name, ss.c_name))
 
+    def visitSMPosition(self, smp):
+        pass
+
     def visitSMLenConstrained(self, sml):
         # To check a SMlenConstrained, check its children.
         sml.visitChildren(self)
@@ -1872,6 +1900,9 @@ class EncodedLenFnGenerator(CodeGenerator):
         self.eltHeader(ss)
         self.w("result += strlen(obj->%s) + 1;\n" % ss.c_name)
 
+    def visitSMPosition(self, smp):
+        pass
+
     def visitSMLenConstrained(self, sml):
         sml.visitChildren(self)
 
@@ -2191,6 +2222,9 @@ class EncodeFnGenerator(CodeGenerator):
                   ptr += len + 1; written += len + 1;
                 }}""", c_name=ss.c_name)
 
+    def visitSMPosition(self, smp):
+        pass
+
     def visitSMLenConstrained(self, sml):
         # To encode a length-constained field of a structure,
         # remember the position at which we began writing to the union.
@@ -2643,6 +2677,9 @@ class ParseFnGenerator(CodeGenerator):
                   remaining -= memlen; ptr += memlen;
                 }}""", c_name=ss.c_name, truncated=self.truncatedLabel)
 
+    def visitSMPosition(self, smp):
+        self.format("obj->{c_name} = ptr;", c_name=smp.c_name);
+
     def visitSMLenConstrained(self, sml):
         # To parse a length-constrained region, make sure that at
         # least that many bytes remain in the structure.  Then,
diff --git a/lib/trunnel/Grammar.py b/lib/trunnel/Grammar.py
index 0de7fa7..25dc53a 100644
--- a/lib/trunnel/Grammar.py
+++ b/lib/trunnel/Grammar.py
@@ -113,7 +113,7 @@ class Lexer(trunnel.spark.GenericScanner, object):
         trunnel.spark.GenericScanner.tokenize(self, input)
         return self.rv
 
-    @pattern(r"(?:[;{}\[\]\-=,:]|\.\.\.|\.\.|\.)")
+    @pattern(r"(?:[;{}@\[\]\-=,:]|\.\.\.|\.\.|\.)")
     def t_punctuation(self, s):
         self.rv.append(Token(s, self.lineno))
 
@@ -570,6 +570,15 @@ class SMIgnore(StructMember):
        ignored."""
     pass
 
+class SMPosition(StructMember):
+    """ A struct member: notes that we should store a pointer to this point
+        in the input when we parse the enclosing structure. """
+    def __init__(self, name):
+        StructMember.__init__(self, name)
+
+    def __str__(self):
+        return "@" + self.name
+
 
 class IDReference(AST):
 
@@ -778,6 +787,10 @@ class Parser(trunnel.spark.GenericParser, object):
     def p_StructMember_4(self, info):
         return info[0]
 
+    @rule(" StructMember ::= SMPosition ")
+    def p_StructMember_5(self, info):
+        return info[0]
+
     @rule(" SMInteger ::= IntType ID OptIntConstraint ")
     def p_SMInteger(self, info):
         return SMInteger(info[0], str(info[1]), info[2])
@@ -997,6 +1010,10 @@ class Parser(trunnel.spark.GenericParser, object):
     def p_UnionField_5(self, info):
         return info[0]
 
+    @rule(" UnionField ::= SMSPosition ")
+    def p_UnionField_6(self, info):
+        return info[0]
+
     @rule(" ContextDecl ::= context ID { ContextMembers } ")
     def p_ContextDecl(self, info):
         return StructDecl(str(info[1]), info[3], isContext=True)
@@ -1017,6 +1034,16 @@ class Parser(trunnel.spark.GenericParser, object):
     def p_ContextMember(self, info):
         return SMInteger(info[0], str(info[1]), None)
 
+    @rule(" SMPosition ::= @ PtrKW ID ")
+    def p_SMPosition(self, info):
+        return SMPosition(str(info[2]))
+
+    @rule(" PtrKW ::= ID ")
+    def p_PtrKW(self, info):
+        if str(info[0]) != 'ptr':
+            raise SyntaxError("Expected 'ptr' at %s" % info[0].lineno)
+        return None
+
 if __name__ == '__main__':
     print ("===== Here is our actual grammar, extracted from Grammar.py\n")
 
diff --git a/test/Makefile b/test/Makefile
index ba70e8d..603526f 100644
--- a/test/Makefile
+++ b/test/Makefile
@@ -31,6 +31,7 @@ TEST_OBJS = \
     c/test_contexts_varsize2.o \
     c/test_contexts_complex.o \
     c/test_remainder_repeats.o \
+    c/test_positions.o \
     c/test_util.o
 
 BOILERPLATE_FILES=\
@@ -44,6 +45,7 @@ OBJS=tinytest/tinytest.o \
     valid/opaque.o \
     valid/leftover.o \
     valid/contexts.o \
+    valid/positions.o \
     ./include/trunnel.o \
     $(TEST_OBJS)
 
@@ -69,6 +71,7 @@ valid/derived.o: valid/derived.h valid/derived.c
 valid/opaque.o: valid/opaque.h valid/opaque.h
 valid/leftover.o: valid/leftover.h valid/leftover.c
 valid/contexts.o: valid/contexts.h
+valid/positions.o: valid/positions.h
 $(TEST_OBJS) : tinytest/tinytest.h tinytest/tinytest_macros.h valid/simple.h valid/derived.h
 $(OBJS) : include/trunnel.h include/trunnel-impl.h
 tinytest/tinytest.o: tinytest/tinytest.h tinytest/tinytest_macros.h
@@ -88,5 +91,8 @@ valid/leftover.c valid/leftover.h: valid/leftover.trunnel ../lib/trunnel/*py
 valid/contexts.c valid/contexts.h: valid/contexts.trunnel ../lib/trunnel/*py
 	PYTHONPATH=../lib:${PYTHONPATH} python -m trunnel valid/contexts.trunnel
 
+valid/positions.c valid/positions.h: valid/positions.trunnel ../lib/trunnel/*py
+	PYTHONPATH=../lib:${PYTHONPATH} python -m trunnel valid/positions.trunnel
+
 $(BOILERPLATE_FILES): ../lib/trunnel/*py ../lib/trunnel/data/*.[ch]
 	PYTHONPATH=../lib:${PYTHONPATH} python -m trunnel --target-dir=./include --write-c-files
diff --git a/test/c/test.c b/test/c/test.c
index 534ae69..cd9c803 100644
--- a/test/c/test.c
+++ b/test/c/test.c
@@ -21,6 +21,7 @@ struct testgroup_t test_groups[] = {
   { "contexts/varsize/", contexts_varsize_tests },
   { "contexts/varsize2/", contexts_varsize2_tests },
   { "contexts/complex/", contexts_complex_tests },
+  { "positions/", positions_tests },
   END_OF_GROUPS,
 };
 
diff --git a/test/c/test.h b/test/c/test.h
index 36fbde8..44242e1 100644
--- a/test/c/test.h
+++ b/test/c/test.h
@@ -31,6 +31,7 @@ extern struct testcase_t contexts_uniontag_tests[];
 extern struct testcase_t contexts_varsize_tests[];
 extern struct testcase_t contexts_varsize2_tests[];
 extern struct testcase_t contexts_complex_tests[];
+extern struct testcase_t positions_tests[];
 
 ssize_t unhex(uint8_t *out, size_t outlen, const char *in);
 const uint8_t *ux(const char *in);
diff --git a/test/c/test_positions.c b/test/c/test_positions.c
new file mode 100644
index 0000000..60ede17
--- /dev/null
+++ b/test/c/test_positions.c
@@ -0,0 +1,111 @@
+#include "test.h"
+
+#include "valid/positions.h"
+
+static void
+test_pos_invalid(void *arg)
+{
+  haspos_t *hp = NULL, *hp2 = NULL;
+  uint8_t buf[64];
+  int i;
+  (void)arg;
+
+  /* Encode invalid */
+  tt_int_op(-1, ==, haspos_encode(buf, sizeof(buf), NULL));
+  hp = haspos_new();
+  tt_int_op(-1, ==, haspos_encode(buf, sizeof(buf), hp));
+  haspos_set_s1(hp, "Foo");
+  tt_int_op(-1, ==, haspos_encode(buf, sizeof(buf), hp));
+  haspos_set_s2(hp, "Bar");
+  tt_int_op(12, ==, haspos_encode(buf, sizeof(buf), hp));
+
+  /* Encode truncated */
+  for (i = 0; i < 12; ++i)
+    tt_int_op(-2, ==, haspos_encode(buf, i, hp));
+  tt_int_op(12, ==, haspos_encode(buf, 12, hp));
+
+  /* Parse truncated */
+  for (i = 0; i < 12; ++i)
+    tt_int_op(-2, ==, haspos_parse(&hp2, buf, i));
+
+ end:
+  haspos_free(hp);
+}
+
+static void
+test_pos_encdec(void *arg)
+{
+  haspos_t *hp = NULL, *hp2 = NULL;
+  uint8_t buf[64];
+  (void)arg;
+  /* Build a reference object and check its exact encoding. */
+  hp = haspos_new();
+  haspos_set_s1(hp, "hello");
+  haspos_set_s2(hp, "world");
+  haspos_set_x(hp, 3);
+  tt_int_op(16, ==, haspos_encode(buf, sizeof(buf), hp));
+  tt_mem_op("hello\0world\0\0\0\0\x03", ==, buf, 16);
+  /* Parse it back; every check below must inspect the parsed copy (hp2). */
+  tt_int_op(16, ==, haspos_parse(&hp2, buf, sizeof(buf)));
+  tt_str_op("hello", ==, haspos_get_s1(hp2));
+  tt_str_op("world", ==, haspos_get_s2(hp2));
+  tt_int_op(3, ==, haspos_get_x(hp2));
+  tt_ptr_op(buf + 6, ==, haspos_get_pos1(hp2));
+  tt_ptr_op(buf + 12, ==, haspos_get_pos2(hp2));
+
+ end:
+  haspos_free(hp);
+  haspos_free(hp2);
+}
+
+static void
+test_pos_allocfail(void *arg)
+{
+#ifdef ALLOCFAIL
+  haspos_t *hp = NULL;
+  const uint8_t inp[] = "hello\0world\0\0\0\0\x03";
+  uint8_t buf[32];
+  (void)arg;
+
+  set_alloc_fail(1);
+  tt_ptr_op(NULL, ==, haspos_new());
+
+  set_alloc_fail(1);
+  tt_int_op(-1, ==, haspos_parse(&hp, inp, sizeof(inp)));
+
+  set_alloc_fail(2);
+  tt_int_op(-1, ==, haspos_parse(&hp, inp, sizeof(inp)));
+
+  set_alloc_fail(3);
+  tt_int_op(-1, ==, haspos_parse(&hp, inp, sizeof(inp)));
+
+  hp = haspos_new();
+  tt_assert(hp);
+  set_alloc_fail(1);
+  tt_int_op(-1, ==, haspos_set_s1(hp,"Hi"));
+  tt_int_op(0, ==, haspos_set_s1(hp,"Hi"));
+
+  set_alloc_fail(1);
+  tt_int_op(-1, ==, haspos_set_s2(hp,"Hi"));
+  tt_int_op(0, ==, haspos_set_s2(hp,"Hi"));
+
+  tt_int_op(-1, ==, haspos_encode(buf, sizeof(buf), hp));
+  haspos_clear_errors(hp);
+  tt_int_op(10, ==, haspos_encode(buf, sizeof(buf), hp));
+
+  haspos_free(hp); hp = NULL;
+
+ end:
+  haspos_free(hp);
+#else
+  (void)arg;
+  tt_skip();
+#endif
+}
+
+struct testcase_t positions_tests[] = {
+  { "invalid", test_pos_invalid, 0, NULL, NULL },
+  { "encdec", test_pos_encdec, 0, NULL, NULL },
+  { "allocfail", test_pos_allocfail, 0, NULL, NULL },
+  END_OF_TESTCASES
+};
diff --git a/test/failing/badptr.trunnel b/test/failing/badptr.trunnel
new file mode 100644
index 0000000..14e55da
--- /dev/null
+++ b/test/failing/badptr.trunnel
@@ -0,0 +1,5 @@
+
+struct s {
+  nulterm s;
+  @foo bar;
+}
\ No newline at end of file
diff --git a/test/valid/positions.trunnel b/test/valid/positions.trunnel
new file mode 100644
index 0000000..497ebb0
--- /dev/null
+++ b/test/valid/positions.trunnel
@@ -0,0 +1,10 @@
+
+struct haspos {
+  nulterm s1;
+  /** Position right after the first NUL. */
+  @ptr pos1;
+  nulterm s2;
+  @ptr pos2;
+  u32 x;
+}
+



More information about the tor-commits mailing list