commit 4726aa424262be3af244a57fe274805da6f26842 Author: David Fifield david@bamsoftware.com Date: Tue Mar 27 22:59:31 2012 -0700
Add WebSocketDecoder and tests. --- Makefile | 1 + connector-test.py | 151 ++++++++++++++++++++++++++++++++++++++++++++++++++++ connector.py | 153 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 305 insertions(+), 0 deletions(-)
diff --git a/Makefile b/Makefile index 0fb0c79..61c3901 100644 --- a/Makefile +++ b/Makefile @@ -20,6 +20,7 @@ clean: rm -f $(TARGETS)
test: + ./connector-test.py ./flashproxy-test.js
# [patch residue, end of Makefile hunk] .PHONY: all clean test
# [patch residue] diff --git a/connector-test.py b/connector-test.py new file mode 100755 index 0000000..c0479bd --- /dev/null +++ b/connector-test.py @@ -0,0 +1,151 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import unittest
from connector import WebSocketDecoder


def _fed_decoder(data, use_mask = False):
    """Return a fresh WebSocketDecoder that has been fed data once."""
    decoder = WebSocketDecoder(use_mask = use_mask)
    decoder.feed(data)
    return decoder


def read_frames(dec):
    """Drain dec with read_frame() until it returns None.

    Returns a list of (fin, opcode, payload) tuples, one per frame read."""
    return [(f.fin, f.opcode, f.payload) for f in iter(dec.read_frame, None)]


def read_messages(dec):
    """Drain dec with read_message() until it returns None.

    Returns a list of (opcode, payload) tuples, one per message read."""
    return [(m.opcode, m.payload) for m in iter(dec.read_message, None)]


class TestWebSocketDecoder(unittest.TestCase):
    """Unit tests for WebSocketDecoder.read_frame and read_message."""

    def test_rfc(self):
        """Test samples from RFC 6455 section 5.7."""
        # Each case: (wire bytes, whether the frame is masked,
        #             expected frames, expected messages).
        cases = [
            ("\x81\x05\x48\x65\x6c\x6c\x6f", False,
                [(True, 1, "Hello")],
                [(1, u"Hello")]),
            ("\x81\x85\x37\xfa\x21\x3d\x7f\x9f\x4d\x51\x58", True,
                [(True, 1, "Hello")],
                [(1, u"Hello")]),
            ("\x01\x03\x48\x65\x6c\x80\x02\x6c\x6f", False,
                [(False, 1, "Hel"), (True, 0, "lo")],
                [(1, u"Hello")]),
            ("\x89\x05\x48\x65\x6c\x6c\x6f", False,
                [(True, 9, "Hello")],
                [(9, u"Hello")]),
            ("\x8a\x85\x37\xfa\x21\x3d\x7f\x9f\x4d\x51\x58", True,
                [(True, 10, "Hello")],
                [(10, u"Hello")]),
            ("\x82\x7e\x01\x00" + "\x00" * 256, False,
                [(True, 2, "\x00" * 256)],
                [(2, "\x00" * 256)]),
            ("\x82\x7f\x00\x00\x00\x00\x00\x01\x00\x00" + "\x00" * 65536, False,
                [(True, 2, "\x00" * 65536)],
                [(2, "\x00" * 65536)]),
        ]
        for wire, masked, want_frames, want_messages in cases:
            # Frame-level view of the sample.
            self.assertEqual(read_frames(_fed_decoder(wire, masked)), want_frames)

            # Message-level view of the same sample.
            self.assertEqual(read_messages(_fed_decoder(wire, masked)), want_messages)

            # A decoder with the opposite masking expectation must refuse it.
            wrong = _fed_decoder(wire, not masked)
            self.assertRaises(WebSocketDecoder.MaskingError, wrong.read_frame)

    def test_empty_feed(self):
        """Test that the decoder can handle a zero-byte feed."""
        decoder = WebSocketDecoder()
        self.assertIsNone(decoder.read_frame())
        # Nothing, then a partial frame: still no frame available.
        for chunk in ("", "\x81\x05H"):
            decoder.feed(chunk)
            self.assertIsNone(decoder.read_frame())
        decoder.feed("ello")
        self.assertEqual(read_frames(decoder), [(True, 1, u"Hello")])

    def test_empty_frame(self):
        """Test that a frame may contain a zero-byte payload."""
        decoder = WebSocketDecoder()
        decoder.feed("\x81\x00")
        self.assertEqual(read_frames(decoder), [(True, 1, u"")])
        decoder.feed("\x82\x00")
        self.assertEqual(read_frames(decoder), [(True, 2, "")])

    def test_empty_message(self):
        """Test that a message may have a zero-byte payload."""
        decoder = WebSocketDecoder()
        decoder.feed("\x01\x00\x00\x00\x80\x00")
        self.assertEqual(read_messages(decoder), [(1, u"")])
        decoder.feed("\x02\x00\x00\x00\x80\x00")
        self.assertEqual(read_messages(decoder), [(2, "")])

    def test_interleaved_control(self):
        """Test that control messages interleaved with fragmented messages are
        returned."""
        decoder = _fed_decoder("\x89\x04PING\x01\x03Hel\x8a\x04PONG\x80\x02lo\x89\x04PING")
        expected = [(9, "PING"), (10, "PONG"), (1, u"Hello"), (9, "PING")]
        self.assertEqual(read_messages(decoder), expected)

    def test_fragmented_control(self):
        """Test that illegal fragmented control messages cause an error."""
        decoder = _fed_decoder("\x09\x04PING")
        self.assertRaises(ValueError, decoder.read_message)

    def test_zero_opcode(self):
        """Test that it is an error for the first frame in a message to have an
        opcode of 0."""
        for wire in ("\x80\x05Hello", "\x00\x05Hello"):
            decoder = _fed_decoder(wire)
            self.assertRaises(ValueError, decoder.read_message)

    def test_nonzero_opcode(self):
        """Test that every frame after the first must have a zero opcode."""
        for wire in ("\x01\x01H\x01\x02el\x80\x02lo", "\x01\x01H\x00\x02el\x01\x02lo"):
            decoder = _fed_decoder(wire)
            self.assertRaises(ValueError, decoder.read_message)

    def test_utf8(self):
        """Test that text frames (opcode 1) are decoded from UTF-8."""
        text = u"Hello World or Καλημέρα κόσμε or こんにちは 世界 or \U0001f639"
        encoded = text.encode("utf-8")
        # The single length byte only works because the payload is short.
        decoder = _fed_decoder("\x81" + chr(len(encoded)) + encoded)
        self.assertEqual(read_messages(decoder), [(1, text)])

    def test_wrong_utf8(self):
        """Test that failed UTF-8 decoding causes an error."""
        bad_sequences = [
            "\xc0\x41", # Non-shortest form.
            "\xc2", # Unfinished sequence.
        ]
        for bad in bad_sequences:
            decoder = _fed_decoder("\x81" + chr(len(bad)) + bad)
            self.assertRaises(ValueError, decoder.read_message)

    def test_overly_large_payload(self):
        """Test that large payloads are rejected."""
        decoder = _fed_decoder("\x82\x7f\x00\x00\x00\x00\x01\x00\x00\x00")
        self.assertRaises(ValueError, decoder.read_frame)


if __name__ == "__main__":
    unittest.main()

# [patch residue] diff --git a/connector.py b/connector.py index f84f8a2..39d89e8 100755 --- a/connector.py +++ b/connector.py @@ -130,6 +130,159 @@ class BufferSocket(object):
# [patch residue, unchanged context]     def is_expired(self, timeout):
# [patch residue, unchanged context]         return time.time() - self.birthday > timeout
class WebSocketFrame(object):
    """One decoded WebSocket frame, as produced by
    WebSocketDecoder.read_frame."""

    def __init__(self):
        self.fin = False     # True when the frame's FIN bit was set.
        self.opcode = None   # Low four bits of the first header byte.
        self.payload = None  # Payload bytes, already unmasked.

    def is_control(self):
        """Return True if the opcode has the 0x08 bit set, which marks a
        control frame."""
        return (self.opcode & 0x08) != 0


class WebSocketMessage(object):
    """One complete message, as produced by WebSocketDecoder.read_message."""

    def __init__(self):
        self.opcode = None   # Opcode of the message's first frame.
        self.payload = None  # unicode for opcode 1, otherwise raw bytes.

    def is_control(self):
        """Return True if the opcode has the 0x08 bit set, which marks a
        control message."""
        return (self.opcode & 0x08) != 0


class WebSocketDecoder(object):
    """Incremental decoder for the WebSocket framing format (RFC 6455
    section 5).

    Raw bytes are appended with feed(). read_frame() and read_message() then
    return decoded objects, or None while the buffered data is still
    incomplete."""
    # Raise an exception rather than buffer anything larger than this.
    MAX_MESSAGE_LENGTH = 1024 * 1024

    class MaskingError(ValueError):
        """A frame's MASK bit did not match what this decoder expects."""
        pass

    def __init__(self, use_mask=False):
        """use_mask says whether incoming frames are required to be masked.

        Pass True when decoding frames sent by a client (unmasked frames are
        then rejected), and False when decoding frames sent by a server
        (masked frames are then rejected)."""
        self.use_mask = use_mask

        # Per-frame state: bytes fed in but not yet consumed as frames.
        self.buf = b""

        # Per-message state: payload and opcode of a fragmented message that
        # has not yet seen its FIN frame.
        self.message_buf = b""
        self.message_opcode = None

    def feed(self, data):
        """Append a byte string of received data to the internal buffer."""
        self.buf += data

    @staticmethod
    def mask(payload, mask_key):
        """Return payload XORed with the 4-byte mask_key, repeated as needed.

        Masking is its own inverse, so the same call both masks and unmasks.
        Raises IndexError if payload is non-empty and mask_key has fewer than
        four bytes."""
        key = bytearray(mask_key)
        # bytearray yields integers on both Python 2 and 3, so no chr/ord
        # round trip is needed for each byte.
        return bytes(bytearray(b ^ key[i % 4]
                               for i, b in enumerate(bytearray(payload))))

    def read_frame(self):
        """Read a frame from the internal buffer, if one is available.

        Returns a WebSocketFrame object, or None if there are no complete
        frames to read (nothing is consumed in that case).

        Raises MaskingError if the frame's MASK bit contradicts use_mask, and
        ValueError if the declared payload length exceeds
        MAX_MESSAGE_LENGTH."""
        # RFC 6455 section 5.2.
        if len(self.buf) < 2:
            return None
        offset = 0
        b0, b1 = struct.unpack_from(">BB", self.buf, offset)
        offset += 2
        fin = (b0 & 0x80) != 0
        opcode = b0 & 0x0f
        frame_masked = (b1 & 0x80) != 0
        payload_len = b1 & 0x7f

        # Lengths 126 and 127 are escapes for a 16-bit or 64-bit extended
        # length field that follows the two-byte header.
        if payload_len == 126:
            if len(self.buf) < offset + 2:
                return None
            payload_len, = struct.unpack_from(">H", self.buf, offset)
            offset += 2
        elif payload_len == 127:
            if len(self.buf) < offset + 8:
                return None
            payload_len, = struct.unpack_from(">Q", self.buf, offset)
            offset += 8

        mask_key = None
        if frame_masked:
            if not self.use_mask:
                # "A client MUST close a connection if it detects a masked
                # frame."
                raise self.MaskingError("Got masked payload from server")
            if len(self.buf) < offset + 4:
                return None
            mask_key = self.buf[offset:offset+4]
            offset += 4
        else:
            if self.use_mask:
                # "The server MUST close the connection upon receiving a frame
                # that is not masked."
                raise self.MaskingError("Got unmasked payload from client")

        # Reject before waiting for the payload so an oversized frame is never
        # accumulated in self.buf on the strength of its header alone.
        if payload_len > self.MAX_MESSAGE_LENGTH:
            raise ValueError("Refusing to buffer payload of %d bytes" % payload_len)

        if len(self.buf) < offset + payload_len:
            return None
        payload = self.buf[offset:offset+payload_len]
        if mask_key is not None:
            # Unmasked frames are used as-is; XORing with zeros is a no-op.
            payload = WebSocketDecoder.mask(payload, mask_key)
        self.buf = self.buf[offset+payload_len:]

        frame = WebSocketFrame()
        frame.fin = fin
        frame.opcode = opcode
        frame.payload = payload

        return frame

    def read_message(self):
        """Read a complete message.

        If the opcode is 1, the payload is decoded from a UTF-8 binary string
        to a unicode string. If a control frame is read while another
        fragmented message is in progress, the control frame is returned as a
        new message immediately. Returns None if there is no complete frame
        to be read; fragments consumed so far are kept for the next call.

        Raises ValueError for a fragmented control frame, a first frame with
        opcode 0, a continuation frame with a nonzero opcode, a reassembled
        message longer than MAX_MESSAGE_LENGTH, or invalid UTF-8 in a text
        message. Errors from read_frame propagate unchanged."""
        # RFC 6455 section 5.4 is about fragmentation.
        while True:
            frame = self.read_frame()
            if frame is None:
                return None
            # "Control frames (see Section 5.5) MAY be injected in the middle of
            # a fragmented message. Control frames themselves MUST NOT be
            # fragmented."
            if frame.is_control():
                if not frame.fin:
                    raise ValueError("Control frame (opcode %d) has FIN bit clear" % frame.opcode)
                message = WebSocketMessage()
                message.opcode = frame.opcode
                message.payload = frame.payload
                return message

            if self.message_opcode is None:
                if frame.opcode == 0:
                    raise ValueError("First frame has opcode 0")
                self.message_opcode = frame.opcode
            else:
                if frame.opcode != 0:
                    raise ValueError("Non-first frame has nonzero opcode %d" % frame.opcode)

            # The per-frame limit in read_frame does not bound a message made
            # of many small fragments, so enforce the limit on the total too.
            total_len = len(self.message_buf) + len(frame.payload)
            if total_len > self.MAX_MESSAGE_LENGTH:
                raise ValueError("Refusing to buffer message of %d bytes" % total_len)
            self.message_buf += frame.payload

            if frame.fin:
                break
        message = WebSocketMessage()
        message.opcode = self.message_opcode
        message.payload = self.message_buf
        self.postprocess_message(message)
        self.message_opcode = None
        self.message_buf = b""

        return message

    def postprocess_message(self, message):
        """Decode the payload of a text message (opcode 1) from UTF-8 in
        place and return the message. Other opcodes are left untouched.

        Raises UnicodeDecodeError (a ValueError) on invalid UTF-8."""
        if message.opcode == 1:
            message.payload = message.payload.decode("utf-8")
        return message

# [patch residue, unchanged context] def listen_socket(addr):
# [patch residue, unchanged context]     """Return a nonblocking socket listening on the given address."""
# [patch residue, unchanged context]     addrinfo = socket.getaddrinfo(addr[0], addr[1], 0, socket.SOCK_STREAM, socket.IPPROTO_TCP)[0]