[tor-commits] [flashproxy/master] Add WebSocketDecoder and tests.

dcf at torproject.org dcf at torproject.org
Mon Apr 9 04:08:42 UTC 2012


commit 4726aa424262be3af244a57fe274805da6f26842
Author: David Fifield <david at bamsoftware.com>
Date:   Tue Mar 27 22:59:31 2012 -0700

    Add WebSocketDecoder and tests.
---
 Makefile          |    1 +
 connector-test.py |  151 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 connector.py      |  153 +++++++++++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 305 insertions(+), 0 deletions(-)

diff --git a/Makefile b/Makefile
index 0fb0c79..61c3901 100644
--- a/Makefile
+++ b/Makefile
@@ -20,6 +20,7 @@ clean:
 	rm -f $(TARGETS)
 
 test:
+	./connector-test.py
 	./flashproxy-test.js
 
 .PHONY: all clean test
diff --git a/connector-test.py b/connector-test.py
new file mode 100755
index 0000000..c0479bd
--- /dev/null
+++ b/connector-test.py
@@ -0,0 +1,151 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import unittest
+from connector import WebSocketDecoder
+
+def read_frames(dec):
+    frames = []
+    while True:
+        frame = dec.read_frame()
+        if frame is None:
+            break
+        frames.append((frame.fin, frame.opcode, frame.payload))
+    return frames
+
+def read_messages(dec):
+    messages = []
+    while True:
+        message = dec.read_message()
+        if message is None:
+            break
+        messages.append((message.opcode, message.payload))
+    return messages
+
+class TestWebSocketDecoder(unittest.TestCase):
+    def test_rfc(self):
+        """Test samples from RFC 6455 section 5.7."""
+        TESTS = [
+            ("\x81\x05\x48\x65\x6c\x6c\x6f", False,
+                [(True, 1, "Hello")],
+                [(1, u"Hello")]),
+            ("\x81\x85\x37\xfa\x21\x3d\x7f\x9f\x4d\x51\x58", True,
+                [(True, 1, "Hello")],
+                [(1, u"Hello")]),
+            ("\x01\x03\x48\x65\x6c\x80\x02\x6c\x6f", False,
+                [(False, 1, "Hel"), (True, 0, "lo")],
+                [(1, u"Hello")]),
+            ("\x89\x05\x48\x65\x6c\x6c\x6f", False,
+                [(True, 9, "Hello")],
+                [(9, u"Hello")]),
+            ("\x8a\x85\x37\xfa\x21\x3d\x7f\x9f\x4d\x51\x58", True,
+                [(True, 10, "Hello")],
+                [(10, u"Hello")]),
+            ("\x82\x7e\x01\x00" + "\x00" * 256, False,
+                [(True, 2, "\x00" * 256)],
+                [(2, "\x00" * 256)]),
+            ("\x82\x7f\x00\x00\x00\x00\x00\x01\x00\x00" + "\x00" * 65536, False,
+                [(True, 2, "\x00" * 65536)],
+                [(2, "\x00" * 65536)]),
+        ]
+        for data, use_mask, expected_frames, expected_messages in TESTS:
+            dec = WebSocketDecoder(use_mask = use_mask)
+            dec.feed(data)
+            actual_frames = read_frames(dec)
+            self.assertEqual(actual_frames, expected_frames)
+
+            dec = WebSocketDecoder(use_mask = use_mask)
+            dec.feed(data)
+            actual_messages = read_messages(dec)
+            self.assertEqual(actual_messages, expected_messages)
+
+            dec = WebSocketDecoder(use_mask = not use_mask)
+            dec.feed(data)
+            self.assertRaises(WebSocketDecoder.MaskingError, dec.read_frame)
+
+    def test_empty_feed(self):
+        """Test that the decoder can handle a zero-byte feed."""
+        dec = WebSocketDecoder()
+        self.assertIsNone(dec.read_frame())
+        dec.feed("")
+        self.assertIsNone(dec.read_frame())
+        dec.feed("\x81\x05H")
+        self.assertIsNone(dec.read_frame())
+        dec.feed("ello")
+        self.assertEqual(read_frames(dec), [(True, 1, u"Hello")])
+
+    def test_empty_frame(self):
+        """Test that a frame may contain a zero-byte payload."""
+        dec = WebSocketDecoder()
+        dec.feed("\x81\x00")
+        self.assertEqual(read_frames(dec), [(True, 1, u"")])
+        dec.feed("\x82\x00")
+        self.assertEqual(read_frames(dec), [(True, 2, "")])
+
+    def test_empty_message(self):
+        """Test that a message may have a zero-byte payload."""
+        dec = WebSocketDecoder()
+        dec.feed("\x01\x00\x00\x00\x80\x00")
+        self.assertEqual(read_messages(dec), [(1, u"")])
+        dec.feed("\x02\x00\x00\x00\x80\x00")
+        self.assertEqual(read_messages(dec), [(2, "")])
+
+    def test_interleaved_control(self):
+        """Test that control messages interleaved with fragmented messages are
+        returned."""
+        dec = WebSocketDecoder()
+        dec.feed("\x89\x04PING\x01\x03Hel\x8a\x04PONG\x80\x02lo\x89\x04PING")
+        self.assertEqual(read_messages(dec), [(9, "PING"), (10, "PONG"), (1, u"Hello"), (9, "PING")])
+
+    def test_fragmented_control(self):
+        """Test that illegal fragmented control messages cause an error."""
+        dec = WebSocketDecoder()
+        dec.feed("\x09\x04PING")
+        self.assertRaises(ValueError, dec.read_message)
+
+    def test_zero_opcode(self):
+        """Test that it is an error for the first frame in a message to have an
+        opcode of 0."""
+        dec = WebSocketDecoder()
+        dec.feed("\x80\x05Hello")
+        self.assertRaises(ValueError, dec.read_message)
+        dec = WebSocketDecoder()
+        dec.feed("\x00\x05Hello")
+        self.assertRaises(ValueError, dec.read_message)
+
+    def test_nonzero_opcode(self):
+        """Test that every frame after the first must have a zero opcode."""
+        dec = WebSocketDecoder()
+        dec.feed("\x01\x01H\x01\x02el\x80\x02lo")
+        self.assertRaises(ValueError, dec.read_message)
+        dec = WebSocketDecoder()
+        dec.feed("\x01\x01H\x00\x02el\x01\x02lo")
+        self.assertRaises(ValueError, dec.read_message)
+
+    def test_utf8(self):
+        """Test that text frames (opcode 1) are decoded from UTF-8."""
+        text = u"Hello World or Καλημέρα κόσμε or こんにちは 世界 or \U0001f639"
+        utf8_text = text.encode("utf-8")
+        dec = WebSocketDecoder()
+        dec.feed("\x81" + chr(len(utf8_text)) + utf8_text)
+        self.assertEqual(read_messages(dec), [(1, text)])
+
+    def test_wrong_utf8(self):
+        """Test that failed UTF-8 decoding causes an error."""
+        TESTS = [
+            "\xc0\x41", # Non-shortest form.
+            "\xc2", # Unfinished sequence.
+        ]
+        for test in TESTS:
+            dec = WebSocketDecoder()
+            dec.feed("\x81" + chr(len(test)) + test)
+            self.assertRaises(ValueError, dec.read_message)
+
+    def test_overly_large_payload(self):
+        """Test that large payloads are rejected."""
+        dec = WebSocketDecoder()
+        dec.feed("\x82\x7f\x00\x00\x00\x00\x01\x00\x00\x00")
+        self.assertRaises(ValueError, dec.read_frame)
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/connector.py b/connector.py
index f84f8a2..39d89e8 100755
--- a/connector.py
+++ b/connector.py
@@ -130,6 +130,159 @@ class BufferSocket(object):
     def is_expired(self, timeout):
         return time.time() - self.birthday > timeout
 
+
+class WebSocketFrame(object):
+    def __init__(self):
+        self.fin = False
+        self.opcode = None
+        self.payload = None
+
+    def is_control(self):
+        return (self.opcode & 0x08) != 0
+
+class WebSocketMessage(object):
+    def __init__(self):
+        self.opcode = None
+        self.payload = None
+
+    def is_control(self):
+        return (self.opcode & 0x08) != 0
+
+class WebSocketDecoder(object):
+    """RFC 6455 section 5 is about the WebSocket framing format."""
+    # Raise an exception rather than buffer anything larger than this.
+    MAX_MESSAGE_LENGTH = 1024 * 1024
+
+    class MaskingError(ValueError):
+        pass
+
+    def __init__(self, use_mask = False):
+        """use_mask should be True when decoding frames sent by a client (which
+        must be masked), and False when decoding frames sent by a server."""
+        self.use_mask = use_mask
+
+        # Per-frame state.
+        self.buf = ""
+
+        # Per-message state.
+        self.message_buf = ""
+        self.message_opcode = None
+
+    def feed(self, data):
+        self.buf += data
+
+    @staticmethod
+    def mask(payload, mask_key):
+        result = []
+        for i, c in enumerate(payload):
+            mc = chr(ord(payload[i]) ^ ord(mask_key[i%4]))
+            result.append(mc)
+        return "".join(result)
+
+    def read_frame(self):
+        """Read a frame from the internal buffer, if one is available. Returns a
+        WebSocketFrame object, or None if there are no complete frames to
+        read."""
+        # RFC 6455 section 5.2.
+        if len(self.buf) < 2:
+            return None
+        offset = 0
+        b0, b1 = struct.unpack_from(">BB", self.buf, offset)
+        offset += 2
+        fin = (b0 & 0x80) != 0
+        opcode = b0 & 0x0f
+        frame_masked = (b1 & 0x80) != 0
+        payload_len = b1 & 0x7f
+
+        if payload_len == 126:
+            if len(self.buf) < offset + 2:
+                return None
+            payload_len, = struct.unpack_from(">H", self.buf, offset)
+            offset += 2
+        elif payload_len == 127:
+            if len(self.buf) < offset + 8:
+                return None
+            payload_len, = struct.unpack_from(">Q", self.buf, offset)
+            offset += 8
+
+        if frame_masked:
+            if not self.use_mask:
+                # "A client MUST close a connection if it detects a masked
+                # frame."
+                raise self.MaskingError("Got masked payload from server")
+            if len(self.buf) < offset + 4:
+                return None
+            mask_key = self.buf[offset:offset+4]
+            offset += 4
+        else:
+            if self.use_mask:
+                # "The server MUST close the connection upon receiving a frame
+                # that is not masked."
+                raise self.MaskingError("Got unmasked payload from client")
+            mask_key = "\x00\x00\x00\x00"
+
+        if payload_len > self.MAX_MESSAGE_LENGTH:
+            raise ValueError("Refusing to buffer payload of %d bytes" % payload_len)
+
+        if len(self.buf) < offset + payload_len:
+            return None
+        payload = WebSocketDecoder.mask(self.buf[offset:offset+payload_len], mask_key)
+        self.buf = self.buf[offset+payload_len:]
+
+        frame = WebSocketFrame()
+        frame.fin = fin
+        frame.opcode = opcode
+        frame.payload = payload
+
+        return frame
+
+    def read_message(self):
+        """Read a complete message. If the opcode is 1, the payload is decoded
+        from a UTF-8 binary string to a unicode string. If a control frame is
+        read while another fragmented message is in progress, the control frame
+        is returned as a new message immediately. Returns None if there is no
+        complete frame to be read."""
+        # RFC 6455 section 5.4 is about fragmentation.
+        while True:
+            frame = self.read_frame()
+            if frame is None:
+                return None
+            # "Control frames (see Section 5.5) MAY be injected in the middle of
+            # a fragmented message. Control frames themselves MUST NOT be
+            # fragmented."
+            if frame.is_control():
+                if not frame.fin:
+                    raise ValueError("Control frame (opcode %d) has FIN bit clear" % frame.opcode)
+                message = WebSocketMessage()
+                message.opcode = frame.opcode
+                message.payload = frame.payload
+                return message
+
+            if self.message_opcode is None:
+                if frame.opcode == 0:
+                    raise ValueError("First frame has opcode 0")
+                self.message_opcode = frame.opcode
+            else:
+                if frame.opcode != 0:
+                    raise ValueError("Non-first frame has nonzero opcode %d" % frame.opcode)
+            self.message_buf += frame.payload
+
+            if frame.fin:
+                break
+        message = WebSocketMessage()
+        message.opcode = self.message_opcode
+        message.payload = self.message_buf
+        self.postprocess_message(message)
+        self.message_opcode = None
+        self.message_buf = ""
+
+        return message
+
+    def postprocess_message(self, message):
+        if message.opcode == 1:
+            message.payload = message.payload.decode("utf-8")
+        return message
+
 def listen_socket(addr):
     """Return a nonblocking socket listening on the given address."""
     addrinfo = socket.getaddrinfo(addr[0], addr[1], 0, socket.SOCK_STREAM, socket.IPPROTO_TCP)[0]





More information about the tor-commits mailing list