commit 7ea31933aef27276cb443eee49369bccb6661ffa Author: David Fifield david@bamsoftware.com Date: Wed Mar 28 00:14:51 2012 -0700
Add WebSocketEncoder and tests. --- connector-test.py | 28 +++++++++++++++++++++++++++- connector.py | 51 ++++++++++++++++++++++++++++++++++++++++++--------- 2 files changed, 69 insertions(+), 10 deletions(-)
diff --git a/connector-test.py b/connector-test.py index c0479bd..a15c1ad 100755 --- a/connector-test.py +++ b/connector-test.py @@ -2,7 +2,7 @@ # -*- coding: utf-8 -*-
import unittest -from connector import WebSocketDecoder +from connector import WebSocketDecoder, WebSocketEncoder
def read_frames(dec): frames = [] @@ -147,5 +147,31 @@ class TestWebSocketDecoder(unittest.TestCase): dec.feed("\x82\x7f\x00\x00\x00\x00\x01\x00\x00\x00") self.assertRaises(ValueError, dec.read_frame)
+class TestWebSocketEncoder(unittest.TestCase): + def test_length(self): + """Test that payload lengths are encoded using the smallest number of bytes.""" + TESTS = [(0, 0), (125, 0), (126, 2), (65535, 2), (65536, 8)] + for length, encoded_length in TESTS: + for use_mask, mask_key_length in ((False, 0), (True, 4)): + eframe = WebSocketEncoder(use_mask = use_mask).encode_frame(2, "\x00" * length) + self.assertEqual(len(eframe), 1 + 1 + encoded_length + mask_key_length + length) + + def test_bad_opcode(self): + """Test that an opcode that does not fit in four bits is rejected.""" + self.assertRaises(ValueError, WebSocketEncoder().encode_frame, 16, "") + + def test_roundtrip(self): + """Test that decoding an encoded message yields the original opcode and payload.""" + TESTS = [ + (1, u"Hello world"), + (1, u"Hello \N{WHITE SMILING FACE}"), + ] + for opcode, payload in TESTS: + for use_mask in (False, True): + enc_message = WebSocketEncoder(use_mask = use_mask).encode_message(opcode, payload) + dec = WebSocketDecoder(use_mask = use_mask) + dec.feed(enc_message) + self.assertEqual(read_messages(dec), [(opcode, payload)]) + if __name__ == "__main__": unittest.main() diff --git a/connector.py b/connector.py index 39d89e8..2a06523 100755 --- a/connector.py +++ b/connector.py @@ -131,6 +131,13 @@ class BufferSocket(object): return time.time() - self.birthday > timeout
+def apply_mask(payload, mask_key): + """XOR payload with mask_key repeated cyclically (RFC 6455 section 5.3); applying it twice with the same key restores the input.""" + result = [] + for i, c in enumerate(payload): + result.append(chr(ord(c) ^ ord(mask_key[i % 4]))) + return "".join(result) + class WebSocketFrame(object): def __init__(self): self.fin = False @@ -171,14 +178,6 @@ class WebSocketDecoder(object): def feed(self, data): self.buf += data
- @staticmethod - def mask(payload, mask_key): - result = [] - for i, c in enumerate(payload): - mc = chr(ord(payload[i]) ^ ord(mask_key[i%4])) - result.append(mc) - return "".join(result) - def read_frame(self): """Read a frame from the internal buffer, if one is available. Returns a WebSocketFrame object, or None if there are no complete frames to @@ -226,7 +225,7 @@ class WebSocketDecoder(object):
if len(self.buf) < offset + payload_len: return None - payload = WebSocketDecoder.mask(self.buf[offset:offset+payload_len], mask_key) + payload = apply_mask(self.buf[offset:offset+payload_len], mask_key) self.buf = self.buf[offset+payload_len:]
frame = WebSocketFrame() @@ -283,6 +282,40 @@ class WebSocketDecoder(object): message.payload = message.payload.decode("utf-8") return message
+class WebSocketEncoder(object): + """Serialize WebSocket frames, masking each frame with a fresh random 4-byte key when use_mask is true.""" + def __init__(self, use_mask = False): + self.use_mask = use_mask + + def encode_frame(self, opcode, payload): + """Return one complete frame (FIN bit set) carrying the byte string payload. Raises ValueError if opcode is outside 0-15 or payload is too long to encode.""" + if not 0 <= opcode < 16: + raise ValueError("Opcode of %d is not in the range 0-15" % opcode) + length = len(payload) + if self.use_mask: + mask_key = os.urandom(4) + payload = apply_mask(payload, mask_key) + mask_bit = 0x80 + else: + mask_key, mask_bit = "", 0x00 + # Use the shortest length encoding; the 64-bit form must keep its most significant bit clear (RFC 6455 section 5.2). + if length < 126: + len_b, len_ext = length, "" + elif length < 0x10000: + len_b, len_ext = 126, struct.pack(">H", length) + elif length < 0x8000000000000000: + len_b, len_ext = 127, struct.pack(">Q", length) + else: + raise ValueError("payload length of %d is too long" % length) + return chr(0x80 | opcode) + chr(mask_bit | len_b) + len_ext + mask_key + payload + + def encode_message(self, opcode, payload): + """Encode payload as a single-frame message; a text (opcode 1) payload is converted to UTF-8 first.""" + if opcode == 1: + payload = payload.encode("utf-8") + return self.encode_frame(opcode, payload) + + def listen_socket(addr): """Return a nonblocking socket listening on the given address.""" addrinfo = socket.getaddrinfo(addr[0], addr[1], 0, socket.SOCK_STREAM, socket.IPPROTO_TCP)[0]