[tor-commits] [chutney/master] Refactor Source and Sink to not know about their data

teor at torproject.org teor at torproject.org
Thu Jun 20 07:45:41 UTC 2019


commit 1bbe2079c40ec6d6517c2125274e9b9510c09197
Author: Nick Mathewson <nickm at torproject.org>
Date:   Fri May 10 10:12:00 2019 -0400

    Refactor Source and Sink to not know about their data
    
    Now data is generated by a DataSource type, and consumed by a
    DataChecker type that compares its incoming data against the data
    in a DataSource.
---
 lib/chutney/Traffic.py | 113 +++++++++++++++++++++++++++++--------------------
 1 file changed, 68 insertions(+), 45 deletions(-)

diff --git a/lib/chutney/Traffic.py b/lib/chutney/Traffic.py
index f877452..72b9ca3 100755
--- a/lib/chutney/Traffic.py
+++ b/lib/chutney/Traffic.py
@@ -107,7 +107,6 @@ class TestSuite(object):
         return('%s: %d/%d/%d' % (self.tests, self.not_done, self.successes,
                                  self.failures))
 
-
 class Listener(asyncore.dispatcher):
     "A TCP listener, binding, listening and accepting new connections."
 
@@ -132,14 +131,69 @@ class Listener(asyncore.dispatcher):
     def fileno(self):
         return self.socket.fileno()
 
+class DataSource(object):
+    """Return `data` from more() `repetitions` times, then return None.
+
+       Conforms to the asynchat 'producer' API, for push_with_producer().
+    """
+    def __init__(self, data, repetitions=1):
+        self.data = data                 # chunk handed out by each more()
+        self.repetitions = repetitions   # chunks still to be handed out
+        self.sent_any = False            # becomes True on the first more()
+
+    def copy(self):
+        """Return an unstarted DataSource with the same data and count."""
+        assert not self.sent_any  # only an unstarted source may be copied
+        return DataSource(self.data, self.repetitions)
+
+    def more(self):
+        """Return the next chunk, or None once the repetitions run out."""
+        self.sent_any = True
+        if self.repetitions <= 0:
+            return None
+        self.repetitions -= 1
+        return self.data
+
+class DataChecker(object):
+    """Verify incoming bytes against the stream a DataSource produces."""
+    def __init__(self, source):
+        self.source = source      # object whose more() gives expected bytes
+        self.pending = b""        # unmatched rest of the current chunk
+        self.succeeded = False    # every expected byte seen, none extra
+        self.failed = False       # mismatch, or bytes past the end
+
+    def consume(self, inp):
+        """Check inp; sets succeeded or failed once the result is known."""
+        if self.failed or (self.succeeded and not len(inp)):
+            return
+        if self.succeeded:  # any byte after a complete stream is an error
+            self.succeeded, self.failed = False, True
+            return
+        # memoryview slices share memory; slicing bytes here copied the
+        # whole unmatched remainder on every step (quadratic in chunk size).
+        inp = memoryview(inp)
+        while len(inp):
+            n = min(len(inp), len(self.pending))
+            if inp[:n] != self.pending[:n]:
+                self.failed = True
+                return
+            inp, self.pending = inp[n:], self.pending[n:]
+            if not len(self.pending):
+                chunk = self.source.more()
+                if chunk is None:  # source exhausted: leftovers are extra
+                    self.pending = None
+                    self.failed = len(inp) > 0
+                    self.succeeded = not self.failed
+                    return
+                self.pending = memoryview(chunk)
+
 class Sink(asynchat.async_chat):
     "A data sink, reading from its peer and verifying the data."
     def __init__(self, sock, tt):
         asynchat.async_chat.__init__(self, sock)
-        self.inbuf = b""
         self.set_terminator(None)
         self.tt = tt
-        self.repetitions = tt.repetitions
+        self.data_checker = DataChecker(tt.data_source.copy())
         self.testname = "recv-data%s"%id(self)
 
     def get_test_names(self):
@@ -148,33 +202,16 @@ class Sink(asynchat.async_chat):
     def collect_incoming_data(self, inp):
         # shortcut read when we don't ever expect any data
 
-        self.inbuf += inp
-        data = self.tt.data
-        debug("successfully received (bytes=%d)" % len(self.inbuf))
-        while len(self.inbuf) >= len(data):
-            assert(len(self.inbuf) <= len(data) or self.repetitions > 1)
-            if self.inbuf[:len(data)] != data:
-                debug("receive comparison failed (bytes=%d)" % len(data))
-                self.tt.failure(self.testname)
-                self.close()
-            # if we're not debugging, print a dot every dot_repetitions reps
-            elif (not debug_flag and self.tt.dot_repetitions > 0 and
-                  self.repetitions % self.tt.dot_repetitions == 0):
-                sys.stdout.write('.')
-                sys.stdout.flush()
-            # repeatedly check data against self.inbuf if required
-            debug("receive comparison success (bytes=%d)" % len(data))
-            self.inbuf = self.inbuf[len(data):]
-            debug("receive leftover bytes (bytes=%d)" % len(self.inbuf))
-            self.repetitions -= 1
-            debug("receive remaining repetitions (reps=%d)" % self.repetitions)
-        if self.repetitions == 0 and len(self.inbuf) == 0:
+        debug("successfully received (bytes=%d)" % len(inp))
+        self.data_checker.consume(inp)
+        if self.data_checker.succeeded:
             debug("successful verification")
             self.close()
             self.tt.success(self.testname)
-        # calculate the actual length of data remaining, including reps
-        debug("receive remaining bytes (bytes=%d)"
-              % (self.repetitions*len(data) - len(self.inbuf)))
+        elif self.data_checker.failed:
+            debug("receive comparison failed")
+            self.tt.failure(self.testname)
+            self.close()
 
     def fileno(self):
         return self.socket.fileno()
@@ -197,22 +234,13 @@ class Source(asynchat.async_chat):
 
     def __init__(self, tt, server, buf, proxy=None, repetitions=1):
         asynchat.async_chat.__init__(self)
-        self.data = buf
-        self.outbuf = b''
+        self.data_source = DataSource(buf, repetitions)
         self.inbuf = b''
         self.proxy = proxy
         self.server = server
-        self.repetitions = repetitions
-        self._sent_no_bytes = 0
         self.tt = tt
         self.testname = "send-data%s"%id(self)
 
-        # sanity checks
-        if len(self.data) == 0:
-            self.repetitions = 0
-        if self.repetitions == 0:
-            self.data = b""
-
         self.set_terminator(None)
         dest = (self.proxy or self.server)
         self.create_socket(addr_to_family(dest[0]), socket.SOCK_STREAM)
@@ -251,8 +279,7 @@ class Source(asynchat.async_chat):
                     self.close()
 
     def push_output(self):
-        for _ in range(self.repetitions):
-            self.push_with_producer(asynchat.simple_producer(self.data))
+        self.push_with_producer(self.data_source)
 
         self.push_with_producer(CloseSourceProducer(self))
         self.close_when_done()
@@ -279,13 +306,9 @@ class TrafficTester(object):
         self.pending_close = []
         self.timeout = timeout
         self.tests = TestSuite()
-        self.data = data
-        self.repetitions = repetitions
+        self.data_source = DataSource(data, repetitions)
+
         # sanity checks
-        if len(self.data) == 0:
-            self.repetitions = 0
-        if self.repetitions == 0:
-            self.data = b""
         self.dot_repetitions = dot_repetitions
         debug("listener fd=%d" % self.listener.fileno())
 





More information about the tor-commits mailing list