[tor-commits] [flashproxy/master] Add tests for WebSocket request handling.

dcf at torproject.org dcf at torproject.org
Fri Sep 7 12:06:14 UTC 2012


commit 0c1afe22b38596675f9aa06d75fbb754ea915638
Author: David Fifield <david at bamsoftware.com>
Date:   Fri Sep 7 05:04:53 2012 -0700

    Add tests for WebSocket request handling.
---
 flashproxy-client-test |  169 ++++++++++++++++++++++++++++++++++++++++++++++++
 1 files changed, 169 insertions(+), 0 deletions(-)

diff --git a/flashproxy-client-test b/flashproxy-client-test
index 527093c..084a272 100755
--- a/flashproxy-client-test
+++ b/flashproxy-client-test
@@ -1,11 +1,20 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
+import base64
+import cStringIO
+import httplib
 import socket
 import subprocess
 import sys
 import unittest
 
+try:
+    from hashlib import sha1
+except ImportError:
+    # Python 2.4 uses this name.
+    from sha import sha as sha1
+
 # Special tricks to load a module whose filename contains a dash and doesn't end
 # in ".py".
 import imp
@@ -13,6 +22,7 @@ dont_write_bytecode = sys.dont_write_bytecode
 sys.dont_write_bytecode = True
 flashproxy = imp.load_source("flashproxy", "flashproxy-client")
 parse_socks_request = flashproxy.parse_socks_request
+handle_websocket_request = flashproxy.handle_websocket_request
 WebSocketDecoder = flashproxy.WebSocketDecoder
 WebSocketEncoder = flashproxy.WebSocketEncoder
 sys.dont_write_bytecode = dont_write_bytecode
@@ -40,6 +50,165 @@ class TestSocks(unittest.TestCase):
     def test_parse_socks_request_hostname(self):
         dest, port = parse_socks_request("\x04\x01\x99\x99\x00\x00\x00\x01userid\x00abc\x00")
 
+class DummySocket(object):  # In-memory stand-in for a socket, backed by file-like buffers.
+    def __init__(self, read_fd, write_fd):  # read_fd is read from, write_fd is written to.
+        self.read_fd = read_fd
+        self.write_fd = write_fd
+        self.readp = 0  # This object's own read offset into read_fd.
+
+    def read(self, *args, **kwargs):  # Same arguments as read_fd.read().
+        self.read_fd.seek(self.readp, 0)  # Resume at our saved offset; other code writing to read_fd may have moved its position.
+        data = self.read_fd.read(*args, **kwargs)
+        self.readp = self.read_fd.tell()  # Remember how far this object has consumed.
+        return data
+
+    def readline(self, *args, **kwargs):  # Like read(), but delegates to read_fd.readline().
+        self.read_fd.seek(self.readp, 0)
+        data = self.read_fd.readline(*args, **kwargs)
+        self.readp = self.read_fd.tell()
+        return data
+
+    def recv(self, size, *args, **kwargs):  # Extra socket arguments (e.g. flags) are ignored.
+        return self.read(size)
+
+    def write(self, data):  # Append data to write_fd; has no return value (returns None).
+        self.write_fd.seek(0, 2)  # Always append at the end, regardless of where a reader left the position.
+        self.write_fd.write(data)
+
+    def send(self, data, *args, **kwargs):  # NOTE(review): returns None, not a byte count like socket.send.
+        return self.write(data)
+
+    def sendall(self, data, *args, **kwargs):  # Writes all of data in one call; extra arguments are ignored.
+        return self.write(data)
+
+    def makefile(self, *args, **kwargs):  # The object serves as its own file via read()/readline()/write().
+        return self
+
+def dummy_socketpair():  # Return two DummySockets connected to each other through in-memory buffers.
+    f1 = cStringIO.StringIO()  # Written by the second socket, read by the first.
+    f2 = cStringIO.StringIO()  # Written by the first socket, read by the second.
+    return (DummySocket(f1, f2), DummySocket(f2, f1))  # Crossed so each side reads what the other sends.
+
+class HTTPRequest(object):  # Minimal mutable request description consumed by transact_http.
+    def __init__(self):
+        self.method = "GET"
+        self.path = "/"
+        self.headers = {}  # Header name -> value; transact_http sends them in dict iteration order.
+
+def transact_http(req):  # Feed req to handle_websocket_request over a dummy socket pair; return (resp, protocols).
+    l, r = dummy_socketpair()  # l is handed to the handler under test; r plays the client.
+    r.send("%s %s HTTP/1.0\r\n" % (req.method, req.path))  # The whole request is buffered before the handler runs.
+    for k, v in req.headers.items():
+        r.send("%s: %s\r\n" % (k, v))
+    r.send("\r\n")  # Empty line terminates the header block.
+    protocols = handle_websocket_request(l)  # Return value is passed through unchanged for the caller to check.
+    # Parse the response the handler wrote back toward the client side.
+    resp = httplib.HTTPResponse(r)
+    resp.begin()  # Reads the status line and headers from r.
+    return resp, protocols
+
+class TestHandleWebSocketRequest(unittest.TestCase):  # Tests for handle_websocket_request.
+    DEFAULT_KEY = "0123456789ABCDEF"  # 16 raw key bytes; test_bogus_key expects a 15-byte key to be rejected.
+    DEFAULT_KEY_BASE64 = base64.b64encode(DEFAULT_KEY)  # Value sent in the Sec-WebSocket-Key header.
+    MAGIC_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"  # Appended to the key to compute Sec-WebSocket-Accept (RFC 6455).
+
+    @staticmethod
+    def default_req():  # Build a fresh, valid version-13 upgrade request that a test can modify.
+        req = HTTPRequest()
+        req.method = "GET"
+        req.path = "/"
+        req.headers["Upgrade"] = "websocket"
+        req.headers["Connection"] = "Upgrade"
+        req.headers["Sec-WebSocket-Key"] = TestHandleWebSocketRequest.DEFAULT_KEY_BASE64
+        req.headers["Sec-WebSocket-Version"] = "13"
+
+        return req
+
+    def assert_ok(self, req):  # Expect a 101 upgrade; the Accept check assumes req still uses DEFAULT_KEY_BASE64.
+        resp, protocols = transact_http(req)
+        self.assertEqual(resp.status, 101)
+        self.assertEqual(resp.getheader("Upgrade").lower(), "websocket")
+        self.assertEqual(resp.getheader("Connection").lower(), "upgrade")
+        self.assertEqual(resp.getheader("Sec-WebSocket-Accept"), base64.b64encode(sha1(self.DEFAULT_KEY_BASE64 + self.MAGIC_GUID).digest()))
+        self.assertEqual(protocols, [])  # Callers pass requests without Sec-WebSocket-Protocol, so expect an empty list.
+
+    def assert_not_ok(self, req):  # Expect a 4xx rejection and a protocols value of None.
+        resp, protocols = transact_http(req)
+        self.assertEqual(resp.status // 100, 4)  # Any 4xx status is acceptable.
+        self.assertEqual(protocols, None)
+
+    def test_default(self):
+        req = self.default_req()
+        self.assert_ok(req)
+
+    def test_missing_upgrade(self):
+        req = self.default_req()
+        del req.headers["Upgrade"]
+        self.assert_not_ok(req)
+
+    def test_missing_connection(self):
+        req = self.default_req()
+        del req.headers["Connection"]
+        self.assert_not_ok(req)
+
+    def test_case_insensitivity(self):
+        """Test that the values of the Upgrade and Connection headers are
+        case-insensitive."""
+        req = self.default_req()
+        req.headers["Upgrade"] = req.headers["Upgrade"].lower()
+        self.assert_ok(req)
+        req.headers["Upgrade"] = req.headers["Upgrade"].upper()
+        self.assert_ok(req)
+        req.headers["Connection"] = req.headers["Connection"].lower()  # Upgrade stays upper-case here; changes accumulate.
+        self.assert_ok(req)
+        req.headers["Connection"] = req.headers["Connection"].upper()
+        self.assert_ok(req)
+
+    def test_bogus_key(self):
+        req = self.default_req()
+        req.headers["Sec-WebSocket-Key"] = base64.b64encode(self.DEFAULT_KEY[:-1])  # Valid base64, but only 15 bytes.
+        self.assert_not_ok(req)
+
+        req.headers["Sec-WebSocket-Key"] = "///"  # Not a valid base64 encoding of a 16-byte key.
+        self.assert_not_ok(req)
+
+    def test_versions(self):  # Versions 13 and 8 are accepted; 7, 9 and a missing header are rejected.
+        req = self.default_req()
+        req.headers["Sec-WebSocket-Version"] = "13"
+        self.assert_ok(req)
+        req.headers["Sec-WebSocket-Version"] = "8"
+        self.assert_ok(req)
+
+        req.headers["Sec-WebSocket-Version"] = "7"
+        self.assert_not_ok(req)
+        req.headers["Sec-WebSocket-Version"] = "9"
+        self.assert_not_ok(req)
+
+        del req.headers["Sec-WebSocket-Version"]
+        self.assert_not_ok(req)
+
+    def test_protocols(self):  # Offered protocols are returned in order; only "base64" is echoed in the response.
+        req = self.default_req()
+        req.headers["Sec-WebSocket-Protocol"] = "base64"
+        resp, protocols = transact_http(req)
+        self.assertEqual(resp.status, 101)
+        self.assertEqual(protocols, ["base64"])
+        self.assertEqual(resp.getheader("Sec-WebSocket-Protocol"), "base64")
+
+        req = self.default_req()
+        req.headers["Sec-WebSocket-Protocol"] = "cat"
+        resp, protocols = transact_http(req)
+        self.assertEqual(resp.status, 101)  # Offering only a protocol other than "base64" is not a rejection...
+        self.assertEqual(protocols, ["cat"])
+        self.assertEqual(resp.getheader("Sec-WebSocket-Protocol"), None)  # ...but no protocol is selected in the response.
+
+        req = self.default_req()
+        req.headers["Sec-WebSocket-Protocol"] = "cat, base64"
+        resp, protocols = transact_http(req)
+        self.assertEqual(resp.status, 101)
+        self.assertEqual(protocols, ["cat", "base64"])  # Comma-separated offer is split into names, keeping order.
+        self.assertEqual(resp.getheader("Sec-WebSocket-Protocol"), "base64")
+
 def read_frames(dec):
     frames = []
     while True:





More information about the tor-commits mailing list