commit 0c1afe22b38596675f9aa06d75fbb754ea915638
Author: David Fifield <david@bamsoftware.com>
Date:   Fri Sep 7 05:04:53 2012 -0700
Add tests for WebSocket request handling. --- flashproxy-client-test | 169 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 files changed, 169 insertions(+), 0 deletions(-)
diff --git a/flashproxy-client-test b/flashproxy-client-test index 527093c..084a272 100755 --- a/flashproxy-client-test +++ b/flashproxy-client-test @@ -1,11 +1,20 @@ #!/usr/bin/env python # -*- coding: utf-8 -*-
+import base64 +import cStringIO +import httplib import socket import subprocess import sys import unittest
+try: + from hashlib import sha1 +except ImportError: + # Python 2.4 uses this name. + from sha import sha as sha1 + # Special tricks to load a module whose filename contains a dash and doesn't end # in ".py". import imp @@ -13,6 +22,7 @@ dont_write_bytecode = sys.dont_write_bytecode sys.dont_write_bytecode = True flashproxy = imp.load_source("flashproxy", "flashproxy-client") parse_socks_request = flashproxy.parse_socks_request +handle_websocket_request = flashproxy.handle_websocket_request WebSocketDecoder = flashproxy.WebSocketDecoder WebSocketEncoder = flashproxy.WebSocketEncoder sys.dont_write_bytecode = dont_write_bytecode @@ -40,6 +50,165 @@ class TestSocks(unittest.TestCase): def test_parse_socks_request_hostname(self): dest, port = parse_socks_request("\x04\x01\x99\x99\x00\x00\x00\x01userid\x00abc\x00")
class DummySocket(object):
    """In-memory stand-in for one end of a connected socket.

    Reads come from read_fd and writes go to write_fd. Two DummySockets whose
    file objects are crossed (see dummy_socketpair) act like the two ends of
    a connection. Because the peers share the same buffers, each end keeps
    its own read offset (readp) and restores it before every read, and every
    write is appended at the end of the destination buffer.
    """

    def __init__(self, read_fd, write_fd):
        self.read_fd = read_fd
        self.write_fd = write_fd
        # Offset of the next unread byte in read_fd. The peer's writes move
        # the shared file position, so the position itself cannot be trusted
        # between calls.
        self.readp = 0

    def _read_via(self, method_name, *args, **kwargs):
        """Call read_fd.<method_name> from the saved offset and save the new one."""
        self.read_fd.seek(self.readp, 0)
        data = getattr(self.read_fd, method_name)(*args, **kwargs)
        self.readp = self.read_fd.tell()
        return data

    def read(self, *args, **kwargs):
        return self._read_via("read", *args, **kwargs)

    def readline(self, *args, **kwargs):
        return self._read_via("readline", *args, **kwargs)

    def recv(self, size, *args, **kwargs):
        # Socket flags are accepted for interface compatibility and ignored.
        return self.read(size)

    def write(self, data):
        # Always append, regardless of where the last read left the position.
        self.write_fd.seek(0, 2)
        self.write_fd.write(data)

    def send(self, data, *args, **kwargs):
        """Write all of data; like socket.send, return the number of bytes sent."""
        self.write(data)
        return len(data)

    def sendall(self, data, *args, **kwargs):
        """Write all of data; like socket.sendall, return None."""
        self.write(data)

    def makefile(self, *args, **kwargs):
        # This object already provides read() and readline(), so it can serve
        # as its own file object (transact_http hands it to
        # httplib.HTTPResponse, which reads through makefile()).
        return self

def dummy_socketpair():
    """Return a pair (a, b) of DummySockets connected to each other.

    Bytes sent on a can be read from b, and bytes sent on b can be read
    from a.
    """
    f1 = cStringIO.StringIO()
    f2 = cStringIO.StringIO()
    return (DummySocket(f1, f2), DummySocket(f2, f1))

class HTTPRequest(object):
    """Minimal description of an HTTP request, serialized by transact_http."""

    def __init__(self):
        self.method = "GET"
        self.path = "/"
        # Header name -> value; emitted in dict iteration order.
        self.headers = {}

def transact_http(req):
    """Feed req to handle_websocket_request and parse the reply.

    req is written as an HTTP/1.0 request onto one end of a dummy socket
    pair. handle_websocket_request processes the other end, and whatever it
    writes back is parsed with httplib.HTTPResponse.

    Returns a tuple (resp, protocols): the parsed httplib.HTTPResponse and
    the value returned by handle_websocket_request.
    """
    server_end, client_end = dummy_socketpair()
    lines = ["%s %s HTTP/1.0\r\n" % (req.method, req.path)]
    for name, value in req.headers.items():
        lines.append("%s: %s\r\n" % (name, value))
    # Blank line terminates the header block.
    lines.append("\r\n")
    client_end.sendall("".join(lines))
    protocols = handle_websocket_request(server_end)

    resp = httplib.HTTPResponse(client_end)
    resp.begin()
    return resp, protocols

class TestHandleWebSocketRequest(unittest.TestCase):
    """Tests of handle_websocket_request's handling of the client handshake."""

    # 16-byte key; its base64 form is what clients put in Sec-WebSocket-Key.
    DEFAULT_KEY = "0123456789ABCDEF"
    DEFAULT_KEY_BASE64 = base64.b64encode(DEFAULT_KEY)
    # Fixed GUID appended to the key when computing Sec-WebSocket-Accept
    # (RFC 6455, section 4.2.2).
    MAGIC_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"

    @staticmethod
    def default_req():
        """Return a well-formed version-13 WebSocket upgrade request."""
        req = HTTPRequest()
        req.method = "GET"
        req.path = "/"
        req.headers["Upgrade"] = "websocket"
        req.headers["Connection"] = "Upgrade"
        req.headers["Sec-WebSocket-Key"] = TestHandleWebSocketRequest.DEFAULT_KEY_BASE64
        req.headers["Sec-WebSocket-Version"] = "13"

        return req

    def expected_accept(self, key_base64):
        """Return the Sec-WebSocket-Accept value expected for key_base64."""
        return base64.b64encode(sha1(key_base64 + self.MAGIC_GUID).digest())

    def assert_ok(self, req):
        """Assert that req is accepted with a 101 upgrade and no protocols."""
        resp, protocols = transact_http(req)
        self.assertEqual(resp.status, 101)
        # "or ''" turns a missing header into an assertion failure rather
        # than an AttributeError on None.
        self.assertEqual((resp.getheader("Upgrade") or "").lower(), "websocket")
        self.assertEqual((resp.getheader("Connection") or "").lower(), "upgrade")
        # Derive the expected value from the key actually sent.
        self.assertEqual(resp.getheader("Sec-WebSocket-Accept"),
            self.expected_accept(req.headers["Sec-WebSocket-Key"]))
        self.assertEqual(protocols, [])

    def assert_not_ok(self, req):
        """Assert that req is rejected with a 4xx status and None protocols."""
        resp, protocols = transact_http(req)
        self.assertEqual(resp.status // 100, 4)
        self.assertEqual(protocols, None)

    def assert_protocols(self, header, expected_protocols, expected_selected):
        """Send header as Sec-WebSocket-Protocol and check the outcome.

        expected_protocols is the list handle_websocket_request should
        return; expected_selected is the Sec-WebSocket-Protocol value the
        response should carry (None for absent).
        """
        req = self.default_req()
        req.headers["Sec-WebSocket-Protocol"] = header
        resp, protocols = transact_http(req)
        self.assertEqual(resp.status, 101)
        self.assertEqual(protocols, expected_protocols)
        self.assertEqual(resp.getheader("Sec-WebSocket-Protocol"), expected_selected)

    def test_default(self):
        req = self.default_req()
        self.assert_ok(req)

    def test_missing_upgrade(self):
        req = self.default_req()
        del req.headers["Upgrade"]
        self.assert_not_ok(req)

    def test_missing_connection(self):
        req = self.default_req()
        del req.headers["Connection"]
        self.assert_not_ok(req)

    def test_case_insensitivity(self):
        """Test that the values of the Upgrade and Connection headers are
        case-insensitive."""
        req = self.default_req()
        req.headers["Upgrade"] = req.headers["Upgrade"].lower()
        self.assert_ok(req)
        req.headers["Upgrade"] = req.headers["Upgrade"].upper()
        self.assert_ok(req)
        req.headers["Connection"] = req.headers["Connection"].lower()
        self.assert_ok(req)
        req.headers["Connection"] = req.headers["Connection"].upper()
        self.assert_ok(req)

    def test_bogus_key(self):
        req = self.default_req()
        # Valid base64, but it decodes to 15 bytes instead of 16.
        req.headers["Sec-WebSocket-Key"] = base64.b64encode(self.DEFAULT_KEY[:-1])
        self.assert_not_ok(req)

        # Not valid base64 at all.
        req.headers["Sec-WebSocket-Key"] = "///"
        self.assert_not_ok(req)

    def test_versions(self):
        req = self.default_req()
        # Accepted versions.
        req.headers["Sec-WebSocket-Version"] = "13"
        self.assert_ok(req)
        req.headers["Sec-WebSocket-Version"] = "8"
        self.assert_ok(req)

        # Rejected versions.
        req.headers["Sec-WebSocket-Version"] = "7"
        self.assert_not_ok(req)
        req.headers["Sec-WebSocket-Version"] = "9"
        self.assert_not_ok(req)

        # A missing version header is rejected too.
        del req.headers["Sec-WebSocket-Version"]
        self.assert_not_ok(req)

    def test_protocols(self):
        # Every offered protocol is returned to the caller, but these cases
        # expect only "base64" to be echoed back in the response.
        self.assert_protocols("base64", ["base64"], "base64")
        self.assert_protocols("cat", ["cat"], None)
        self.assert_protocols("cat, base64", ["cat", "base64"], "base64")

# Unchanged diff context follows in the original patch (truncated here):
# def read_frames(dec): frames = [] while True:
tor-commits@lists.torproject.org