commit 4726aa424262be3af244a57fe274805da6f26842
Author: David Fifield <david(a)bamsoftware.com>
Date: Tue Mar 27 22:59:31 2012 -0700
Add WebSocketDecoder and tests.
---
Makefile | 1 +
connector-test.py | 151 ++++++++++++++++++++++++++++++++++++++++++++++++++++
connector.py | 153 +++++++++++++++++++++++++++++++++++++++++++++++++++++
3 files changed, 305 insertions(+), 0 deletions(-)
diff --git a/Makefile b/Makefile
index 0fb0c79..61c3901 100644
--- a/Makefile
+++ b/Makefile
@@ -20,6 +20,7 @@ clean:
rm -f $(TARGETS)
test:
+ ./connector-test.py
./flashproxy-test.js
.PHONY: all clean test
diff --git a/connector-test.py b/connector-test.py
new file mode 100755
index 0000000..c0479bd
--- /dev/null
+++ b/connector-test.py
@@ -0,0 +1,151 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import unittest
+from connector import WebSocketDecoder
+
+def read_frames(dec):
+ frames = []
+ while True:
+ frame = dec.read_frame()
+ if frame is None:
+ break
+ frames.append((frame.fin, frame.opcode, frame.payload))
+ return frames
+
+def read_messages(dec):
+ messages = []
+ while True:
+ message = dec.read_message()
+ if message is None:
+ break
+ messages.append((message.opcode, message.payload))
+ return messages
+
+class TestWebSocketDecoder(unittest.TestCase):
+ def test_rfc(self):
+ """Test samples from RFC 6455 section 5.7."""
+ TESTS = [
+ ("\x81\x05\x48\x65\x6c\x6c\x6f", False,
+ [(True, 1, "Hello")],
+ [(1, u"Hello")]),
+ ("\x81\x85\x37\xfa\x21\x3d\x7f\x9f\x4d\x51\x58", True,
+ [(True, 1, "Hello")],
+ [(1, u"Hello")]),
+ ("\x01\x03\x48\x65\x6c\x80\x02\x6c\x6f", False,
+ [(False, 1, "Hel"), (True, 0, "lo")],
+ [(1, u"Hello")]),
+ ("\x89\x05\x48\x65\x6c\x6c\x6f", False,
+ [(True, 9, "Hello")],
+ [(9, u"Hello")]),
+ ("\x8a\x85\x37\xfa\x21\x3d\x7f\x9f\x4d\x51\x58", True,
+ [(True, 10, "Hello")],
+ [(10, u"Hello")]),
+ ("\x82\x7e\x01\x00" + "\x00" * 256, False,
+ [(True, 2, "\x00" * 256)],
+ [(2, "\x00" * 256)]),
+ ("\x82\x7f\x00\x00\x00\x00\x00\x01\x00\x00" + "\x00" * 65536, False,
+ [(True, 2, "\x00" * 65536)],
+ [(2, "\x00" * 65536)]),
+ ]
+ for data, use_mask, expected_frames, expected_messages in TESTS:
+ dec = WebSocketDecoder(use_mask = use_mask)
+ dec.feed(data)
+ actual_frames = read_frames(dec)
+ self.assertEqual(actual_frames, expected_frames)
+
+ dec = WebSocketDecoder(use_mask = use_mask)
+ dec.feed(data)
+ actual_messages = read_messages(dec)
+ self.assertEqual(actual_messages, expected_messages)
+
+ dec = WebSocketDecoder(use_mask = not use_mask)
+ dec.feed(data)
+ self.assertRaises(WebSocketDecoder.MaskingError, dec.read_frame)
+
+ def test_empty_feed(self):
+ """Test that the decoder can handle a zero-byte feed."""
+ dec = WebSocketDecoder()
+ self.assertIsNone(dec.read_frame())
+ dec.feed("")
+ self.assertIsNone(dec.read_frame())
+ dec.feed("\x81\x05H")
+ self.assertIsNone(dec.read_frame())
+ dec.feed("ello")
+ self.assertEqual(read_frames(dec), [(True, 1, u"Hello")])
+
+ def test_empty_frame(self):
+ """Test that a frame may contain a zero-byte payload."""
+ dec = WebSocketDecoder()
+ dec.feed("\x81\x00")
+ self.assertEqual(read_frames(dec), [(True, 1, u"")])
+ dec.feed("\x82\x00")
+ self.assertEqual(read_frames(dec), [(True, 2, "")])
+
+ def test_empty_message(self):
+ """Test that a message may have a zero-byte payload."""
+ dec = WebSocketDecoder()
+ dec.feed("\x01\x00\x00\x00\x80\x00")
+ self.assertEqual(read_messages(dec), [(1, u"")])
+ dec.feed("\x02\x00\x00\x00\x80\x00")
+ self.assertEqual(read_messages(dec), [(2, "")])
+
+ def test_interleaved_control(self):
+ """Test that control messages interleaved with fragmented messages are
+ returned."""
+ dec = WebSocketDecoder()
+ dec.feed("\x89\x04PING\x01\x03Hel\x8a\x04PONG\x80\x02lo\x89\x04PING")
+ self.assertEqual(read_messages(dec), [(9, "PING"), (10, "PONG"), (1, u"Hello"), (9, "PING")])
+
+ def test_fragmented_control(self):
+ """Test that illegal fragmented control messages cause an error."""
+ dec = WebSocketDecoder()
+ dec.feed("\x09\x04PING")
+ self.assertRaises(ValueError, dec.read_message)
+
+ def test_zero_opcode(self):
+ """Test that it is an error for the first frame in a message to have an
+ opcode of 0."""
+ dec = WebSocketDecoder()
+ dec.feed("\x80\x05Hello")
+ self.assertRaises(ValueError, dec.read_message)
+ dec = WebSocketDecoder()
+ dec.feed("\x00\x05Hello")
+ self.assertRaises(ValueError, dec.read_message)
+
+ def test_nonzero_opcode(self):
+ """Test that every frame after the first must have a zero opcode."""
+ dec = WebSocketDecoder()
+ dec.feed("\x01\x01H\x01\x02el\x80\x02lo")
+ self.assertRaises(ValueError, dec.read_message)
+ dec = WebSocketDecoder()
+ dec.feed("\x01\x01H\x00\x02el\x01\x02lo")
+ self.assertRaises(ValueError, dec.read_message)
+
+ def test_utf8(self):
+ """Test that text frames (opcode 1) are decoded from UTF-8."""
+ text = u"Hello World or Καλημέρα κόσμε or こんにちは 世界 or \U0001f639"
+ utf8_text = text.encode("utf-8")
+ dec = WebSocketDecoder()
+ dec.feed("\x81" + chr(len(utf8_text)) + utf8_text)
+ self.assertEqual(read_messages(dec), [(1, text)])
+
+ def test_wrong_utf8(self):
+ """Test that failed UTF-8 decoding causes an error."""
+ TESTS = [
+ "\xc0\x41", # Non-shortest form.
+ "\xc2", # Unfinished sequence.
+ ]
+ for test in TESTS:
+ dec = WebSocketDecoder()
+ dec.feed("\x81" + chr(len(test)) + test)
+ self.assertRaises(ValueError, dec.read_message)
+
+ def test_overly_large_payload(self):
+ """Test that large payloads are rejected."""
+ dec = WebSocketDecoder()
+ dec.feed("\x82\x7f\x00\x00\x00\x00\x01\x00\x00\x00")
+ self.assertRaises(ValueError, dec.read_frame)
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/connector.py b/connector.py
index f84f8a2..39d89e8 100755
--- a/connector.py
+++ b/connector.py
@@ -130,6 +130,159 @@ class BufferSocket(object):
def is_expired(self, timeout):
return time.time() - self.birthday > timeout
+
+class WebSocketFrame(object):
+ def __init__(self):
+ self.fin = False
+ self.opcode = None
+ self.payload = None
+
+ def is_control(self):
+ return (self.opcode & 0x08) != 0
+
+class WebSocketMessage(object):
+ def __init__(self):
+ self.opcode = None
+ self.payload = None
+
+ def is_control(self):
+ return (self.opcode & 0x08) != 0
+
+class WebSocketDecoder(object):
+ """RFC 6455 section 5 is about the WebSocket framing format."""
+ # Raise an exception rather than buffer anything larger than this.
+ MAX_MESSAGE_LENGTH = 1024 * 1024
+
+ class MaskingError(ValueError):
+ pass
+
+ def __init__(self, use_mask = False):
+ """use_mask should be True to decode (masked) frames sent by a client,
+ and False to decode (unmasked) frames sent by a server."""
+ self.use_mask = use_mask
+
+ # Per-frame state.
+ self.buf = ""
+
+ # Per-message state.
+ self.message_buf = ""
+ self.message_opcode = None
+
+ def feed(self, data):
+ self.buf += data
+
+ @staticmethod
+ def mask(payload, mask_key):
+ result = []
+ for i, c in enumerate(payload):
+ mc = chr(ord(payload[i]) ^ ord(mask_key[i%4]))
+ result.append(mc)
+ return "".join(result)
+
+ def read_frame(self):
+ """Read a frame from the internal buffer, if one is available. Returns a
+ WebSocketFrame object, or None if there are no complete frames to
+ read."""
+ # RFC 6455 section 5.2.
+ if len(self.buf) < 2:
+ return None
+ offset = 0
+ b0, b1 = struct.unpack_from(">BB", self.buf, offset)
+ offset += 2
+ fin = (b0 & 0x80) != 0
+ opcode = b0 & 0x0f
+ frame_masked = (b1 & 0x80) != 0
+ payload_len = b1 & 0x7f
+
+ if payload_len == 126:
+ if len(self.buf) < offset + 2:
+ return None
+ payload_len, = struct.unpack_from(">H", self.buf, offset)
+ offset += 2
+ elif payload_len == 127:
+ if len(self.buf) < offset + 8:
+ return None
+ payload_len, = struct.unpack_from(">Q", self.buf, offset)
+ offset += 8
+
+ if frame_masked:
+ if not self.use_mask:
+ # "A client MUST close a connection if it detects a masked
+ # frame."
+ raise self.MaskingError("Got masked payload from server")
+ if len(self.buf) < offset + 4:
+ return None
+ mask_key = self.buf[offset:offset+4]
+ offset += 4
+ else:
+ if self.use_mask:
+ # "The server MUST close the connection upon receiving a frame
+ # that is not masked."
+ raise self.MaskingError("Got unmasked payload from client")
+ mask_key = "\x00\x00\x00\x00"
+
+ if payload_len > self.MAX_MESSAGE_LENGTH:
+ raise ValueError("Refusing to buffer payload of %d bytes" % payload_len)
+
+ if len(self.buf) < offset + payload_len:
+ return None
+ payload = WebSocketDecoder.mask(self.buf[offset:offset+payload_len], mask_key)
+ self.buf = self.buf[offset+payload_len:]
+
+ frame = WebSocketFrame()
+ frame.fin = fin
+ frame.opcode = opcode
+ frame.payload = payload
+
+ return frame
+
+ def read_message(self):
+ """Read a complete message. If the opcode is 1, the payload is decoded
+ from a UTF-8 binary string to a unicode string. If a control frame is
+ read while another fragmented message is in progress, the control frame
+ is returned as a new message immediately. Returns None if there is no
+ complete message to be read yet."""
+ # RFC 6455 section 5.4 is about fragmentation.
+ while True:
+ frame = self.read_frame()
+ if frame is None:
+ return None
+ # "Control frames (see Section 5.5) MAY be injected in the middle of
+ # a fragmented message. Control frames themselves MUST NOT be
+ # fragmented."
+ if frame.is_control():
+ if not frame.fin:
+ raise ValueError("Control frame (opcode %d) has FIN bit clear" % frame.opcode)
+ message = WebSocketMessage()
+ message.opcode = frame.opcode
+ message.payload = frame.payload
+ return message
+
+ if self.message_opcode is None:
+ if frame.opcode == 0:
+ raise ValueError("First frame has opcode 0")
+ self.message_opcode = frame.opcode
+ else:
+ if frame.opcode != 0:
+ raise ValueError("Non-first frame has nonzero opcode %d" % frame.opcode)
+ self.message_buf += frame.payload
+
+ if frame.fin:
+ break
+ message = WebSocketMessage()
+ message.opcode = self.message_opcode
+ message.payload = self.message_buf
+ self.postprocess_message(message)
+ self.message_opcode = None
+ self.message_buf = ""
+
+ return message
+
+ def postprocess_message(self, message):
+ if message.opcode == 1:
+ message.payload = message.payload.decode("utf-8")
+ return message
+
def listen_socket(addr):
"""Return a nonblocking socket listening on the given address."""
addrinfo = socket.getaddrinfo(addr[0], addr[1], 0, socket.SOCK_STREAM, socket.IPPROTO_TCP)[0]