[tor-commits] [flashproxy/master] Add WebSocketEncoder and tests.

dcf at torproject.org dcf at torproject.org
Mon Apr 9 04:08:42 UTC 2012


commit 7ea31933aef27276cb443eee49369bccb6661ffa
Author: David Fifield <david at bamsoftware.com>
Date:   Wed Mar 28 00:14:51 2012 -0700

    Add WebSocketEncoder and tests.
---
 connector-test.py |   28 +++++++++++++++++++++++++++-
 connector.py      |   51 ++++++++++++++++++++++++++++++++++++++++++---------
 2 files changed, 69 insertions(+), 10 deletions(-)

diff --git a/connector-test.py b/connector-test.py
index c0479bd..a15c1ad 100755
--- a/connector-test.py
+++ b/connector-test.py
@@ -2,7 +2,7 @@
 # -*- coding: utf-8 -*-
 
 import unittest
-from connector import WebSocketDecoder
+from connector import WebSocketDecoder, WebSocketEncoder
 
 def read_frames(dec):
     frames = []
@@ -147,5 +147,37 @@ class TestWebSocketDecoder(unittest.TestCase):
         dec.feed("\x82\x7f\x00\x00\x00\x00\x01\x00\x00\x00")
         self.assertRaises(ValueError, dec.read_frame)
 
+class TestWebSocketEncoder(unittest.TestCase):
+    def test_length(self):
+        """Test that payload lengths are encoded using the smallest number of
+        bytes, and that the first two header bytes carry the FIN bit, opcode,
+        mask bit, and 7-bit length code."""
+        TESTS = [(0, 0, 0), (125, 0, 125), (126, 2, 126), (65535, 2, 126),
+                 (65536, 8, 127)]
+        for length, encoded_length, len_b in TESTS:
+            for use_mask, mask_len, mask_bit in ((False, 0, 0x00), (True, 4, 0x80)):
+                enc = WebSocketEncoder(use_mask = use_mask)
+                eframe = enc.encode_frame(2, "\x00" * length)
+                self.assertEqual(len(eframe), 1 + 1 + encoded_length + mask_len + length)
+                self.assertEqual(ord(eframe[0]), 0x82)
+                self.assertEqual(ord(eframe[1]), mask_bit | len_b)
+
+    def test_roundtrip(self):
+        """Test that encoded messages decode to the same opcode and payload,
+        with and without masking, including empty and >125-byte payloads."""
+        TESTS = [
+            (1, u"Hello world"),
+            (1, u"Hello \N{WHITE SMILING FACE}"),
+            (1, u""),
+            (1, u"\N{WHITE SMILING FACE}" * 100),
+        ]
+        for opcode, payload in TESTS:
+            for use_mask in (False, True):
+                enc = WebSocketEncoder(use_mask = use_mask)
+                enc_message = enc.encode_message(opcode, payload)
+                dec = WebSocketDecoder(use_mask = use_mask)
+                dec.feed(enc_message)
+                self.assertEqual(read_messages(dec), [(opcode, payload)])
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/connector.py b/connector.py
index 39d89e8..2a06523 100755
--- a/connector.py
+++ b/connector.py
@@ -131,6 +131,13 @@ class BufferSocket(object):
         return time.time() - self.birthday > timeout
 
 
+def apply_mask(payload, mask_key):
+    """XOR byte i of payload with mask_key[i % 4] (RFC 6455, section 5.3) and
+    return the result as a string of the same length. Masking and unmasking
+    are the same operation, so this serves both the encoder and decoder."""
+    return "".join(chr(ord(c) ^ ord(mask_key[i % 4]))
+        for i, c in enumerate(payload))
+
 class WebSocketFrame(object):
     def __init__(self):
         self.fin = False
@@ -171,14 +178,6 @@ class WebSocketDecoder(object):
     def feed(self, data):
         self.buf += data
 
-    @staticmethod
-    def mask(payload, mask_key):
-        result = []
-        for i, c in enumerate(payload):
-            mc = chr(ord(payload[i]) ^ ord(mask_key[i%4]))
-            result.append(mc)
-        return "".join(result)
-
     def read_frame(self):
         """Read a frame from the internal buffer, if one is available. Returns a
         WebSocketFrame object, or None if there are no complete frames to
@@ -226,7 +225,7 @@ class WebSocketDecoder(object):
 
         if len(self.buf) < offset + payload_len:
             return None
-        payload = WebSocketDecoder.mask(self.buf[offset:offset+payload_len], mask_key)
+        payload = apply_mask(self.buf[offset:offset+payload_len], mask_key)
         self.buf = self.buf[offset+payload_len:]
 
         frame = WebSocketFrame()
@@ -283,6 +282,61 @@ class WebSocketDecoder(object):
             message.payload = message.payload.decode("utf-8")
         return message
 
+class WebSocketEncoder(object):
+    """Serialize WebSocket frames as described in RFC 6455, section 5.2.
+
+    If use_mask is true, each frame is masked with a fresh random 4-byte key
+    from os.urandom (RFC 6455 requires clients to mask the frames they send).
+    """
+    def __init__(self, use_mask = False):
+        self.use_mask = use_mask
+
+    def encode_frame(self, opcode, payload):
+        """Return one complete frame (FIN bit set) carrying the byte-string
+        payload. Raises ValueError if opcode is not in the range 0-15, if a
+        control frame (opcode >= 8) has a payload longer than 125 bytes, or if
+        the payload length does not fit in 63 bits."""
+        if opcode >= 16:
+            raise ValueError("Opcode of %d is >= 16" % opcode)
+        if opcode < 0:
+            raise ValueError("Opcode of %d is negative" % opcode)
+        length = len(payload)
+        # RFC 6455 section 5.5: control frames (high bit of the opcode set)
+        # must not have a payload longer than 125 bytes.
+        if opcode & 0x08 and length > 125:
+            raise ValueError("control frame payload length of %d is > 125" % length)
+
+        if self.use_mask:
+            mask_key = os.urandom(4)
+            payload = apply_mask(payload, mask_key)
+            mask_bit = 0x80
+        else:
+            mask_key = ""
+            mask_bit = 0x00
+
+        # Use the shortest length encoding: up to 125 inline, else a 126 or
+        # 127 marker followed by a 16-bit or 64-bit big-endian length.
+        if length < 126:
+            len_b, len_ext = length, ""
+        elif length < 0x10000:
+            len_b, len_ext = 126, struct.pack(">H", length)
+        elif length < 0x8000000000000000:
+            # The most significant bit of the 64-bit length must be 0.
+            len_b, len_ext = 127, struct.pack(">Q", length)
+        else:
+            raise ValueError("payload length of %d is too long" % length)
+
+        return chr(0x80 | opcode) + chr(mask_bit | len_b) + len_ext + mask_key + payload
+
+    def encode_message(self, opcode, payload):
+        """Encode a complete message as a single frame. For a text message
+        (opcode 1), payload should be unicode; it is UTF-8 encoded first.
+        Other payloads are passed to encode_frame unchanged."""
+        if opcode == 1:
+            payload = payload.encode("utf-8")
+        return self.encode_frame(opcode, payload)
+
+
 def listen_socket(addr):
     """Return a nonblocking socket listening on the given address."""
     addrinfo = socket.getaddrinfo(addr[0], addr[1], 0, socket.SOCK_STREAM, socket.IPPROTO_TCP)[0]





More information about the tor-commits mailing list