commit 9c5acd661f1d616c65987545884912bd38c05404 Author: David Fifield david@bamsoftware.com Date: Sun Aug 11 00:07:12 2013 -0700
Use Go workspace layout.
http://golang.org/doc/code.html#Workspaces
This allows importing the pt package without using a local import like "./pt", which is not well supported by gccgo and, I gather, is mostly discouraged. --- websocket-transport/Makefile | 13 +- websocket-transport/pt/pt.go | 602 -------------------- websocket-transport/socks.go | 107 ---- websocket-transport/src/pt/pt.go | 602 ++++++++++++++++++++ websocket-transport/src/websocket-client/socks.go | 107 ++++ .../src/websocket-client/websocket-client.go | 253 ++++++++ .../src/websocket-server/websocket-server.go | 275 +++++++++ .../src/websocket-server/websocket.go | 432 ++++++++++++++ websocket-transport/websocket-client.go | 253 -------- websocket-transport/websocket-server.go | 275 --------- websocket-transport/websocket.go | 432 -------------- 11 files changed, 1677 insertions(+), 1674 deletions(-)
diff --git a/websocket-transport/Makefile b/websocket-transport/Makefile index 6a6b8a7..e03f5b4 100644 --- a/websocket-transport/Makefile +++ b/websocket-transport/Makefile @@ -3,6 +3,7 @@ BINDIR = $(PREFIX)/bin
PROGRAMS = websocket-client websocket-server
+export GOPATH = $(CURDIR) GOBUILDFLAGS = # Alternate flags to use gccgo, allowing cross-compiling for x86 from # x86_64, and presumably better optimization. Install this package: @@ -11,11 +12,13 @@ GOBUILDFLAGS =
all: websocket-server
-websocket-client: websocket-client.go socks.go -websocket-server: websocket-server.go websocket.go +%: $(GOPATH)/src/%/*.go + go build $(GOBUILDFLAGS) "$*"
-%: %.go - go build $(GOBUILDFLAGS) -o $@ $^ +# websocket-client has a special rule because "go get" is necessary. +websocket-client: $(GOPATH)/src/websocket-client/*.go + go get -d $(GOBUILDFLAGS) websocket-client + go build $(GOBUILDFLAGS) websocket-client
install: mkdir -p "$(BINDIR)" @@ -25,6 +28,6 @@ clean: rm -f $(PROGRAMS)
fmt: - go fmt + go fmt $$(basename -a src/*)
.PHONY: all install clean fmt diff --git a/websocket-transport/pt/pt.go b/websocket-transport/pt/pt.go deleted file mode 100644 index 60e4507..0000000 --- a/websocket-transport/pt/pt.go +++ /dev/null @@ -1,602 +0,0 @@ -// Tor pluggable transports library. -// -// Sample client usage: -// -// pt.ClientSetup([]string{"foo"}) -// ln, err := startSocksListener() -// if err != nil { -// panic(err.Error()) -// } -// pt.Cmethod("foo", "socks4", ln.Addr()) -// pt.CmethodsDone() -// -// Sample server usage: -// -// var ptInfo pt.ServerInfo -// info = pt.ServerSetup([]string{"foo", "bar"}) -// for _, bindAddr := range info.BindAddrs { -// ln, err := startListener(bindAddr.Addr) -// if err != nil { -// pt.SmethodError(bindAddr.MethodName, err.Error()) -// } -// pt.Smethod(bindAddr.MethodName, ln.Addr()) -// } -// pt.SmethodsDone() -// func handler(conn net.Conn) { -// or, err := pt.ConnectOr(&ptInfo, ws.Conn) -// if err != nil { -// return -// } -// // Do something with or and conn. -// } - -package pt - -import ( - "bufio" - "bytes" - "crypto/hmac" - "crypto/rand" - "crypto/sha256" - "crypto/subtle" - "encoding/binary" - "errors" - "fmt" - "io" - "net" - "os" - "strings" - "time" -) - -func getenv(key string) string { - return os.Getenv(key) -} - -// Abort with an ENV-ERROR if the environment variable isn't set. -func getenvRequired(key string) string { - value := os.Getenv(key) - if value == "" { - EnvError(fmt.Sprintf("no %s environment variable", key)) - } - return value -} - -// Escape a string so it contains no byte values over 127 and doesn't contain -// any of the characters '\x00', '\n', or '\'. -func escape(s string) string { - var buf bytes.Buffer - for _, b := range []byte(s) { - if b == '\n' { - buf.WriteString("\n") - } else if b == '\' { - buf.WriteString("\\") - } else if 0 < b && b < 128 { - buf.WriteByte(b) - } else { - fmt.Fprintf(&buf, "\x%02x", b) - } - } - return buf.String() -} - -// Print a pluggable transports protocol line to stdout. 
The line consists of an -// unescaped keyword, followed by any number of escaped strings. -func Line(keyword string, v ...string) { - var buf bytes.Buffer - buf.WriteString(keyword) - for _, x := range v { - buf.WriteString(" " + escape(x)) - } - fmt.Println(buf.String()) - os.Stdout.Sync() -} - -// All of the *Error functions call os.Exit(1). - -// Emit an ENV-ERROR with explanation text. -func EnvError(msg string) { - Line("ENV-ERROR", msg) - os.Exit(1) -} - -// Emit a VERSION-ERROR with explanation text. -func VersionError(msg string) { - Line("VERSION-ERROR", msg) - os.Exit(1) -} - -// Emit a CMETHOD-ERROR with explanation text. -func CmethodError(methodName, msg string) { - Line("CMETHOD-ERROR", methodName, msg) - os.Exit(1) -} - -// Emit an SMETHOD-ERROR with explanation text. -func SmethodError(methodName, msg string) { - Line("SMETHOD-ERROR", methodName, msg) - os.Exit(1) -} - -// Emit a CMETHOD line. socks must be "socks4" or "socks5". Call this once for -// each listening client SOCKS port. -func Cmethod(name string, socks string, addr net.Addr) { - Line("CMETHOD", name, socks, addr.String()) -} - -// Emit a CMETHODS DONE line. Call this after opening all client listeners. -func CmethodsDone() { - Line("CMETHODS", "DONE") -} - -// Emit an SMETHOD line. Call this once for each listening server port. -func Smethod(name string, addr net.Addr) { - Line("SMETHOD", name, addr.String()) -} - -// Emit an SMETHODS DONE line. Call this after opening all server listeners. -func SmethodsDone() { - Line("SMETHODS", "DONE") -} - -// Get a pluggable transports version offered by Tor and understood by us, if -// any. The only version we understand is "1". This function reads the -// environment variable TOR_PT_MANAGED_TRANSPORT_VER. 
-func getManagedTransportVer() string { - const transportVersion = "1" - for _, offered := range strings.Split(getenvRequired("TOR_PT_MANAGED_TRANSPORT_VER"), ",") { - if offered == transportVersion { - return offered - } - } - return "" -} - -// Get the intersection of the method names offered by Tor and those in -// methodNames. This function reads the environment variable -// TOR_PT_CLIENT_TRANSPORTS. -func getClientTransports(methodNames []string) []string { - clientTransports := getenvRequired("TOR_PT_CLIENT_TRANSPORTS") - if clientTransports == "*" { - return methodNames - } - result := make([]string, 0) - for _, requested := range strings.Split(clientTransports, ",") { - for _, methodName := range methodNames { - if requested == methodName { - result = append(result, methodName) - break - } - } - } - return result -} - -// This structure is returned by ClientSetup. It consists of a list of method -// names. -type ClientInfo struct { - MethodNames []string -} - -// Check the client pluggable transports environments, emitting an error message -// and exiting the program if any error is encountered. Returns a subset of -// methodNames requested by Tor. -func ClientSetup(methodNames []string) ClientInfo { - var info ClientInfo - - ver := getManagedTransportVer() - if ver == "" { - VersionError("no-version") - } else { - Line("VERSION", ver) - } - - info.MethodNames = getClientTransports(methodNames) - if len(info.MethodNames) == 0 { - CmethodsDone() - os.Exit(1) - } - - return info -} - -// A combination of a method name and an address, as extracted from -// TOR_PT_SERVER_BINDADDR. -type BindAddr struct { - MethodName string - Addr *net.TCPAddr -} - -// Resolve an address string into a net.TCPAddr. -func resolveBindAddr(bindAddr string) (*net.TCPAddr, error) { - addr, err := net.ResolveTCPAddr("tcp", bindAddr) - if err == nil { - return addr, nil - } - // Before the fixing of bug #7011, tor doesn't put brackets around IPv6 - // addresses. 
Split after the last colon, assuming it is a port - // separator, and try adding the brackets. - parts := strings.Split(bindAddr, ":") - if len(parts) <= 2 { - return nil, err - } - bindAddr = "[" + strings.Join(parts[:len(parts)-1], ":") + "]:" + parts[len(parts)-1] - return net.ResolveTCPAddr("tcp", bindAddr) -} - -// Return a new slice, the members of which are those members of addrs having a -// MethodName in methodsNames. -func filterBindAddrs(addrs []BindAddr, methodNames []string) []BindAddr { - var result []BindAddr - - for _, ba := range addrs { - for _, methodName := range methodNames { - if ba.MethodName == methodName { - result = append(result, ba) - break - } - } - } - - return result -} - -// Return a map from method names to bind addresses. The map is the contents of -// TOR_PT_SERVER_BINDADDR, with keys filtered by TOR_PT_SERVER_TRANSPORTS, and -// further filtered by the methods in methodNames. -func getServerBindAddrs(methodNames []string) []BindAddr { - var result []BindAddr - - // Get the list of all requested bindaddrs. - var serverBindAddr = getenvRequired("TOR_PT_SERVER_BINDADDR") - for _, spec := range strings.Split(serverBindAddr, ",") { - var bindAddr BindAddr - - parts := strings.SplitN(spec, "-", 2) - if len(parts) != 2 { - EnvError(fmt.Sprintf("TOR_PT_SERVER_BINDADDR: %q: doesn't contain "-"", spec)) - } - bindAddr.MethodName = parts[0] - addr, err := resolveBindAddr(parts[1]) - if err != nil { - EnvError(fmt.Sprintf("TOR_PT_SERVER_BINDADDR: %q: %s", spec, err.Error())) - } - bindAddr.Addr = addr - result = append(result, bindAddr) - } - - // Filter by TOR_PT_SERVER_TRANSPORTS. - serverTransports := getenvRequired("TOR_PT_SERVER_TRANSPORTS") - if serverTransports != "*" { - result = filterBindAddrs(result, strings.Split(serverTransports, ",")) - } - - // Finally filter by what we understand. - result = filterBindAddrs(result, methodNames) - - return result -} - -// Reads and validates the contents of an auth cookie file. 
Returns the 32-byte -// cookie. See section 4.2.1.2 of pt-spec.txt. -func readAuthCookieFile(filename string) ([]byte, error) { - authCookieHeader := []byte("! Extended ORPort Auth Cookie !\x0a") - header := make([]byte, 32) - cookie := make([]byte, 32) - - f, err := os.Open(filename) - if err != nil { - return cookie, err - } - defer f.Close() - - n, err := io.ReadFull(f, header) - if err != nil { - return cookie, err - } - n, err = io.ReadFull(f, cookie) - if err != nil { - return cookie, err - } - // Check that the file ends here. - n, err = f.Read(make([]byte, 1)) - if n != 0 { - return cookie, errors.New(fmt.Sprintf("file is longer than 64 bytes")) - } else if err != io.EOF { - return cookie, errors.New(fmt.Sprintf("did not find EOF at end of file")) - } - - if !bytes.Equal(header, authCookieHeader) { - return cookie, errors.New(fmt.Sprintf("missing auth cookie header")) - } - - return cookie, nil -} - -// This structure is returned by ServerSetup. It consists of a list of -// BindAddrs, along with a single address for the ORPort. -type ServerInfo struct { - BindAddrs []BindAddr - OrAddr *net.TCPAddr - ExtendedOrAddr *net.TCPAddr - AuthCookie []byte -} - -// Check the server pluggable transports environments, emitting an error message -// and exiting the program if any error is encountered. Resolves the various -// requested bind addresses and the server ORPort. Returns a ServerInfo struct. 
-func ServerSetup(methodNames []string) ServerInfo { - var info ServerInfo - var err error - - ver := getManagedTransportVer() - if ver == "" { - VersionError("no-version") - } else { - Line("VERSION", ver) - } - - var orPort = getenvRequired("TOR_PT_ORPORT") - info.OrAddr, err = net.ResolveTCPAddr("tcp", orPort) - if err != nil { - EnvError(fmt.Sprintf("cannot resolve TOR_PT_ORPORT %q: %s", orPort, err.Error())) - } - - info.BindAddrs = getServerBindAddrs(methodNames) - if len(info.BindAddrs) == 0 { - SmethodsDone() - os.Exit(1) - } - - var extendedOrPort = getenv("TOR_PT_EXTENDED_SERVER_PORT") - if extendedOrPort != "" { - info.ExtendedOrAddr, err = net.ResolveTCPAddr("tcp", extendedOrPort) - if err != nil { - EnvError(fmt.Sprintf("cannot resolve TOR_PT_EXTENDED_SERVER_PORT %q: %s", extendedOrPort, err.Error())) - } - } - - var authCookieFilename = getenv("TOR_PT_AUTH_COOKIE_FILE") - if authCookieFilename != "" { - info.AuthCookie, err = readAuthCookieFile(authCookieFilename) - if err != nil { - EnvError(fmt.Sprintf("error reading TOR_PT_AUTH_COOKIE_FILE %q: %s", authCookieFilename, err.Error())) - } - } - - return info -} - -// See 217-ext-orport-auth.txt section 4.2.1.3. -func computeServerHash(info *ServerInfo, clientNonce, serverNonce []byte) []byte { - h := hmac.New(sha256.New, info.AuthCookie) - io.WriteString(h, "ExtORPort authentication server-to-client hash") - h.Write(clientNonce) - h.Write(serverNonce) - return h.Sum([]byte{}) -} - -// See 217-ext-orport-auth.txt section 4.2.1.3. -func computeClientHash(info *ServerInfo, clientNonce, serverNonce []byte) []byte { - h := hmac.New(sha256.New, info.AuthCookie) - io.WriteString(h, "ExtORPort authentication client-to-server hash") - h.Write(clientNonce) - h.Write(serverNonce) - return h.Sum([]byte{}) -} - -func extOrPortAuthenticate(s *net.TCPConn, info *ServerInfo) error { - r := bufio.NewReader(s) - - // Read auth types. 217-ext-orport-auth.txt section 4.1. 
- var authTypes [256]bool - var count int - for count = 0; count < 256; count++ { - b, err := r.ReadByte() - if err != nil { - return err - } - if b == 0 { - break - } - authTypes[b] = true - } - if count >= 256 { - return errors.New(fmt.Sprintf("read 256 auth types without seeing \x00")) - } - - // We support only type 1, SAFE_COOKIE. - if !authTypes[1] { - return errors.New(fmt.Sprintf("server didn't offer auth type 1")) - } - _, err := s.Write([]byte{1}) - if err != nil { - return err - } - - clientNonce := make([]byte, 32) - clientHash := make([]byte, 32) - serverNonce := make([]byte, 32) - serverHash := make([]byte, 32) - - _, err = io.ReadFull(rand.Reader, clientNonce) - if err != nil { - return err - } - _, err = s.Write(clientNonce) - if err != nil { - return err - } - - _, err = io.ReadFull(r, serverHash) - if err != nil { - return err - } - _, err = io.ReadFull(r, serverNonce) - if err != nil { - return err - } - - expectedServerHash := computeServerHash(info, clientNonce, serverNonce) - if subtle.ConstantTimeCompare(serverHash, expectedServerHash) != 1 { - return errors.New(fmt.Sprintf("mismatch in server hash")) - } - - clientHash = computeClientHash(info, clientNonce, serverNonce) - _, err = s.Write(clientHash) - if err != nil { - return err - } - - status := make([]byte, 1) - _, err = io.ReadFull(r, status) - if err != nil { - return err - } - if status[0] != 1 { - return errors.New(fmt.Sprintf("server rejected authentication")) - } - - if r.Buffered() != 0 { - return errors.New(fmt.Sprintf("%d bytes left after extended OR port authentication", r.Buffered())) - } - - return nil -} - -// See section 3.1 of 196-transport-control-ports.txt. 
-const ( - extOrCmdDone = 0x0000 - extOrCmdUserAddr = 0x0001 - extOrCmdTransport = 0x0002 - extOrCmdOkay = 0x1000 - extOrCmdDeny = 0x1001 -) - -func extOrPortWriteCommand(s *net.TCPConn, cmd uint16, body []byte) error { - var buf bytes.Buffer - if len(body) > 65535 { - return errors.New("command exceeds maximum length of 65535") - } - err := binary.Write(&buf, binary.BigEndian, cmd) - if err != nil { - return err - } - err = binary.Write(&buf, binary.BigEndian, uint16(len(body))) - if err != nil { - return err - } - err = binary.Write(&buf, binary.BigEndian, body) - if err != nil { - return err - } - _, err = s.Write(buf.Bytes()) - if err != nil { - return err - } - - return nil -} - -// Send a USERADDR command on s. See section 3.1.2.1 of -// 196-transport-control-ports.txt. -func extOrPortSendUserAddr(s *net.TCPConn, conn net.Conn) error { - return extOrPortWriteCommand(s, extOrCmdUserAddr, []byte(conn.RemoteAddr().String())) -} - -// Send a TRANSPORT command on s. See section 3.1.2.2 of -// 196-transport-control-ports.txt. -func extOrPortSendTransport(s *net.TCPConn, methodName string) error { - return extOrPortWriteCommand(s, extOrCmdTransport, []byte(methodName)) -} - -// Send a DONE command on s. See section 3.1 of 196-transport-control-ports.txt. -func extOrPortSendDone(s *net.TCPConn) error { - return extOrPortWriteCommand(s, extOrCmdDone, []byte{}) -} - -func extOrPortRecvCommand(s *net.TCPConn) (cmd uint16, body []byte, err error) { - var bodyLen uint16 - data := make([]byte, 4) - - _, err = io.ReadFull(s, data) - if err != nil { - return - } - buf := bytes.NewBuffer(data) - err = binary.Read(buf, binary.BigEndian, &cmd) - if err != nil { - return - } - err = binary.Read(buf, binary.BigEndian, &bodyLen) - if err != nil { - return - } - body = make([]byte, bodyLen) - _, err = io.ReadFull(s, body) - if err != nil { - return - } - - return cmd, body, err -} - -// Send USERADDR and TRANSPORT commands followed by a DONE command. 
Wait for an -// OKAY or DENY response command from the server. Returns nil if and only if -// OKAY is received. -func extOrPortSetup(s *net.TCPConn, conn net.Conn, methodName string) error { - var err error - - err = extOrPortSendUserAddr(s, conn) - if err != nil { - return err - } - err = extOrPortSendTransport(s, methodName) - if err != nil { - return err - } - err = extOrPortSendDone(s) - if err != nil { - return err - } - cmd, _, err := extOrPortRecvCommand(s) - if err != nil { - return err - } - if cmd == extOrCmdDeny { - return errors.New("server returned DENY after our USERADDR and DONE") - } else if cmd != extOrCmdOkay { - return errors.New(fmt.Sprintf("server returned unknown command 0x%04x after our USERADDR and DONE", cmd)) - } - - return nil -} - -// Connect to info.ExtendedOrAddr if defined, or else info.OrAddr, and return an -// open *net.TCPConn. If connecting to the extended OR port, extended OR port -// authentication à la 217-ext-orport-auth.txt is done before returning; an -// error is returned if authentication fails. -func ConnectOr(info *ServerInfo, conn net.Conn, methodName string) (*net.TCPConn, error) { - if info.ExtendedOrAddr == nil { - return net.DialTCP("tcp", nil, info.OrAddr) - } - - s, err := net.DialTCP("tcp", nil, info.ExtendedOrAddr) - if err != nil { - return nil, err - } - s.SetDeadline(time.Now().Add(5 * time.Second)) - err = extOrPortAuthenticate(s, info) - if err != nil { - s.Close() - return nil, err - } - err = extOrPortSetup(s, conn, methodName) - if err != nil { - s.Close() - return nil, err - } - s.SetDeadline(time.Time{}) - - return s, nil -} diff --git a/websocket-transport/socks.go b/websocket-transport/socks.go deleted file mode 100644 index 1fa847f..0000000 --- a/websocket-transport/socks.go +++ /dev/null @@ -1,107 +0,0 @@ -// SOCKS4a server library. 
- -package main - -import ( - "bufio" - "errors" - "fmt" - "io" - "net" -) - -const ( - socksVersion = 0x04 - socksCmdConnect = 0x01 - socksResponseVersion = 0x00 - socksRequestGranted = 0x5a - socksRequestFailed = 0x5b -) - -// Read a SOCKS4a connect request, and call the given connect callback with the -// requested destination string. If the callback returns an error, sends a SOCKS -// request failed message. Otherwise, sends a SOCKS request granted message for -// the destination address returned by the callback. -func AwaitSocks4aConnect(conn *net.TCPConn, connect func(string) (*net.TCPAddr, error)) error { - dest, err := ReadSocks4aConnect(conn) - if err != nil { - SendSocks4aResponseFailed(conn) - return err - } - destAddr, err := connect(dest) - if err != nil { - SendSocks4aResponseFailed(conn) - return err - } - SendSocks4aResponseGranted(conn, destAddr) - return nil -} - -// Read a SOCKS4a connect request. Returns a "host:port" string. -func ReadSocks4aConnect(s io.Reader) (string, error) { - r := bufio.NewReader(s) - - var h [8]byte - n, err := io.ReadFull(r, h[:]) - if err != nil { - return "", errors.New(fmt.Sprintf("after %d bytes of SOCKS header: %s", n, err)) - } - if h[0] != socksVersion { - return "", errors.New(fmt.Sprintf("SOCKS header had version 0x%02x, not 0x%02x", h[0], socksVersion)) - } - if h[1] != socksCmdConnect { - return "", errors.New(fmt.Sprintf("SOCKS header had command 0x%02x, not 0x%02x", h[1], socksCmdConnect)) - } - - _, err = r.ReadBytes('\x00') - if err != nil { - return "", errors.New(fmt.Sprintf("reading SOCKS userid: %s", err)) - } - - var port int - var host string - - port = int(h[2])<<8 | int(h[3])<<0 - if h[4] == 0 && h[5] == 0 && h[6] == 0 && h[7] != 0 { - hostBytes, err := r.ReadBytes('\x00') - if err != nil { - return "", errors.New(fmt.Sprintf("reading SOCKS4a destination: %s", err)) - } - host = string(hostBytes[:len(hostBytes)-1]) - } else { - host = net.IPv4(h[4], h[5], h[6], h[7]).String() - } - - if 
r.Buffered() != 0 { - return "", errors.New(fmt.Sprintf("%d bytes left after SOCKS header", r.Buffered())) - } - - return fmt.Sprintf("%s:%d", host, port), nil -} - -// Send a SOCKS4a response with the given code and address. -func SendSocks4aResponse(w io.Writer, code byte, addr *net.TCPAddr) error { - var resp [8]byte - resp[0] = socksResponseVersion - resp[1] = code - resp[2] = byte((addr.Port >> 8) & 0xff) - resp[3] = byte((addr.Port >> 0) & 0xff) - resp[4] = addr.IP[0] - resp[5] = addr.IP[1] - resp[6] = addr.IP[2] - resp[7] = addr.IP[3] - _, err := w.Write(resp[:]) - return err -} - -var emptyAddr = net.TCPAddr{IP: net.IPv4(0, 0, 0, 0), Port: 0} - -// Send a SOCKS4a response code 0x5a. -func SendSocks4aResponseGranted(w io.Writer, addr *net.TCPAddr) error { - return SendSocks4aResponse(w, socksRequestGranted, addr) -} - -// Send a SOCKS4a response code 0x5b (with an all-zero address). -func SendSocks4aResponseFailed(w io.Writer) error { - return SendSocks4aResponse(w, socksRequestFailed, &emptyAddr) -} diff --git a/websocket-transport/src/pt/pt.go b/websocket-transport/src/pt/pt.go new file mode 100644 index 0000000..60e4507 --- /dev/null +++ b/websocket-transport/src/pt/pt.go @@ -0,0 +1,602 @@ +// Tor pluggable transports library. +// +// Sample client usage: +// +// pt.ClientSetup([]string{"foo"}) +// ln, err := startSocksListener() +// if err != nil { +// panic(err.Error()) +// } +// pt.Cmethod("foo", "socks4", ln.Addr()) +// pt.CmethodsDone() +// +// Sample server usage: +// +// var ptInfo pt.ServerInfo +// info = pt.ServerSetup([]string{"foo", "bar"}) +// for _, bindAddr := range info.BindAddrs { +// ln, err := startListener(bindAddr.Addr) +// if err != nil { +// pt.SmethodError(bindAddr.MethodName, err.Error()) +// } +// pt.Smethod(bindAddr.MethodName, ln.Addr()) +// } +// pt.SmethodsDone() +// func handler(conn net.Conn) { +// or, err := pt.ConnectOr(&ptInfo, ws.Conn) +// if err != nil { +// return +// } +// // Do something with or and conn. 
+// } + +package pt + +import ( + "bufio" + "bytes" + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "crypto/subtle" + "encoding/binary" + "errors" + "fmt" + "io" + "net" + "os" + "strings" + "time" +) + +func getenv(key string) string { + return os.Getenv(key) +} + +// Abort with an ENV-ERROR if the environment variable isn't set. +func getenvRequired(key string) string { + value := os.Getenv(key) + if value == "" { + EnvError(fmt.Sprintf("no %s environment variable", key)) + } + return value +} + +// Escape a string so it contains no byte values over 127 and doesn't contain +// any of the characters '\x00', '\n', or '\'. +func escape(s string) string { + var buf bytes.Buffer + for _, b := range []byte(s) { + if b == '\n' { + buf.WriteString("\n") + } else if b == '\' { + buf.WriteString("\\") + } else if 0 < b && b < 128 { + buf.WriteByte(b) + } else { + fmt.Fprintf(&buf, "\x%02x", b) + } + } + return buf.String() +} + +// Print a pluggable transports protocol line to stdout. The line consists of an +// unescaped keyword, followed by any number of escaped strings. +func Line(keyword string, v ...string) { + var buf bytes.Buffer + buf.WriteString(keyword) + for _, x := range v { + buf.WriteString(" " + escape(x)) + } + fmt.Println(buf.String()) + os.Stdout.Sync() +} + +// All of the *Error functions call os.Exit(1). + +// Emit an ENV-ERROR with explanation text. +func EnvError(msg string) { + Line("ENV-ERROR", msg) + os.Exit(1) +} + +// Emit a VERSION-ERROR with explanation text. +func VersionError(msg string) { + Line("VERSION-ERROR", msg) + os.Exit(1) +} + +// Emit a CMETHOD-ERROR with explanation text. +func CmethodError(methodName, msg string) { + Line("CMETHOD-ERROR", methodName, msg) + os.Exit(1) +} + +// Emit an SMETHOD-ERROR with explanation text. +func SmethodError(methodName, msg string) { + Line("SMETHOD-ERROR", methodName, msg) + os.Exit(1) +} + +// Emit a CMETHOD line. socks must be "socks4" or "socks5". 
Call this once for +// each listening client SOCKS port. +func Cmethod(name string, socks string, addr net.Addr) { + Line("CMETHOD", name, socks, addr.String()) +} + +// Emit a CMETHODS DONE line. Call this after opening all client listeners. +func CmethodsDone() { + Line("CMETHODS", "DONE") +} + +// Emit an SMETHOD line. Call this once for each listening server port. +func Smethod(name string, addr net.Addr) { + Line("SMETHOD", name, addr.String()) +} + +// Emit an SMETHODS DONE line. Call this after opening all server listeners. +func SmethodsDone() { + Line("SMETHODS", "DONE") +} + +// Get a pluggable transports version offered by Tor and understood by us, if +// any. The only version we understand is "1". This function reads the +// environment variable TOR_PT_MANAGED_TRANSPORT_VER. +func getManagedTransportVer() string { + const transportVersion = "1" + for _, offered := range strings.Split(getenvRequired("TOR_PT_MANAGED_TRANSPORT_VER"), ",") { + if offered == transportVersion { + return offered + } + } + return "" +} + +// Get the intersection of the method names offered by Tor and those in +// methodNames. This function reads the environment variable +// TOR_PT_CLIENT_TRANSPORTS. +func getClientTransports(methodNames []string) []string { + clientTransports := getenvRequired("TOR_PT_CLIENT_TRANSPORTS") + if clientTransports == "*" { + return methodNames + } + result := make([]string, 0) + for _, requested := range strings.Split(clientTransports, ",") { + for _, methodName := range methodNames { + if requested == methodName { + result = append(result, methodName) + break + } + } + } + return result +} + +// This structure is returned by ClientSetup. It consists of a list of method +// names. +type ClientInfo struct { + MethodNames []string +} + +// Check the client pluggable transports environments, emitting an error message +// and exiting the program if any error is encountered. Returns a subset of +// methodNames requested by Tor. 
+func ClientSetup(methodNames []string) ClientInfo { + var info ClientInfo + + ver := getManagedTransportVer() + if ver == "" { + VersionError("no-version") + } else { + Line("VERSION", ver) + } + + info.MethodNames = getClientTransports(methodNames) + if len(info.MethodNames) == 0 { + CmethodsDone() + os.Exit(1) + } + + return info +} + +// A combination of a method name and an address, as extracted from +// TOR_PT_SERVER_BINDADDR. +type BindAddr struct { + MethodName string + Addr *net.TCPAddr +} + +// Resolve an address string into a net.TCPAddr. +func resolveBindAddr(bindAddr string) (*net.TCPAddr, error) { + addr, err := net.ResolveTCPAddr("tcp", bindAddr) + if err == nil { + return addr, nil + } + // Before the fixing of bug #7011, tor doesn't put brackets around IPv6 + // addresses. Split after the last colon, assuming it is a port + // separator, and try adding the brackets. + parts := strings.Split(bindAddr, ":") + if len(parts) <= 2 { + return nil, err + } + bindAddr = "[" + strings.Join(parts[:len(parts)-1], ":") + "]:" + parts[len(parts)-1] + return net.ResolveTCPAddr("tcp", bindAddr) +} + +// Return a new slice, the members of which are those members of addrs having a +// MethodName in methodsNames. +func filterBindAddrs(addrs []BindAddr, methodNames []string) []BindAddr { + var result []BindAddr + + for _, ba := range addrs { + for _, methodName := range methodNames { + if ba.MethodName == methodName { + result = append(result, ba) + break + } + } + } + + return result +} + +// Return a map from method names to bind addresses. The map is the contents of +// TOR_PT_SERVER_BINDADDR, with keys filtered by TOR_PT_SERVER_TRANSPORTS, and +// further filtered by the methods in methodNames. +func getServerBindAddrs(methodNames []string) []BindAddr { + var result []BindAddr + + // Get the list of all requested bindaddrs. 
+ var serverBindAddr = getenvRequired("TOR_PT_SERVER_BINDADDR") + for _, spec := range strings.Split(serverBindAddr, ",") { + var bindAddr BindAddr + + parts := strings.SplitN(spec, "-", 2) + if len(parts) != 2 { + EnvError(fmt.Sprintf("TOR_PT_SERVER_BINDADDR: %q: doesn't contain "-"", spec)) + } + bindAddr.MethodName = parts[0] + addr, err := resolveBindAddr(parts[1]) + if err != nil { + EnvError(fmt.Sprintf("TOR_PT_SERVER_BINDADDR: %q: %s", spec, err.Error())) + } + bindAddr.Addr = addr + result = append(result, bindAddr) + } + + // Filter by TOR_PT_SERVER_TRANSPORTS. + serverTransports := getenvRequired("TOR_PT_SERVER_TRANSPORTS") + if serverTransports != "*" { + result = filterBindAddrs(result, strings.Split(serverTransports, ",")) + } + + // Finally filter by what we understand. + result = filterBindAddrs(result, methodNames) + + return result +} + +// Reads and validates the contents of an auth cookie file. Returns the 32-byte +// cookie. See section 4.2.1.2 of pt-spec.txt. +func readAuthCookieFile(filename string) ([]byte, error) { + authCookieHeader := []byte("! Extended ORPort Auth Cookie !\x0a") + header := make([]byte, 32) + cookie := make([]byte, 32) + + f, err := os.Open(filename) + if err != nil { + return cookie, err + } + defer f.Close() + + n, err := io.ReadFull(f, header) + if err != nil { + return cookie, err + } + n, err = io.ReadFull(f, cookie) + if err != nil { + return cookie, err + } + // Check that the file ends here. + n, err = f.Read(make([]byte, 1)) + if n != 0 { + return cookie, errors.New(fmt.Sprintf("file is longer than 64 bytes")) + } else if err != io.EOF { + return cookie, errors.New(fmt.Sprintf("did not find EOF at end of file")) + } + + if !bytes.Equal(header, authCookieHeader) { + return cookie, errors.New(fmt.Sprintf("missing auth cookie header")) + } + + return cookie, nil +} + +// This structure is returned by ServerSetup. It consists of a list of +// BindAddrs, along with a single address for the ORPort. 
+type ServerInfo struct { + BindAddrs []BindAddr + OrAddr *net.TCPAddr + ExtendedOrAddr *net.TCPAddr + AuthCookie []byte +} + +// Check the server pluggable transports environments, emitting an error message +// and exiting the program if any error is encountered. Resolves the various +// requested bind addresses and the server ORPort. Returns a ServerInfo struct. +func ServerSetup(methodNames []string) ServerInfo { + var info ServerInfo + var err error + + ver := getManagedTransportVer() + if ver == "" { + VersionError("no-version") + } else { + Line("VERSION", ver) + } + + var orPort = getenvRequired("TOR_PT_ORPORT") + info.OrAddr, err = net.ResolveTCPAddr("tcp", orPort) + if err != nil { + EnvError(fmt.Sprintf("cannot resolve TOR_PT_ORPORT %q: %s", orPort, err.Error())) + } + + info.BindAddrs = getServerBindAddrs(methodNames) + if len(info.BindAddrs) == 0 { + SmethodsDone() + os.Exit(1) + } + + var extendedOrPort = getenv("TOR_PT_EXTENDED_SERVER_PORT") + if extendedOrPort != "" { + info.ExtendedOrAddr, err = net.ResolveTCPAddr("tcp", extendedOrPort) + if err != nil { + EnvError(fmt.Sprintf("cannot resolve TOR_PT_EXTENDED_SERVER_PORT %q: %s", extendedOrPort, err.Error())) + } + } + + var authCookieFilename = getenv("TOR_PT_AUTH_COOKIE_FILE") + if authCookieFilename != "" { + info.AuthCookie, err = readAuthCookieFile(authCookieFilename) + if err != nil { + EnvError(fmt.Sprintf("error reading TOR_PT_AUTH_COOKIE_FILE %q: %s", authCookieFilename, err.Error())) + } + } + + return info +} + +// See 217-ext-orport-auth.txt section 4.2.1.3. +func computeServerHash(info *ServerInfo, clientNonce, serverNonce []byte) []byte { + h := hmac.New(sha256.New, info.AuthCookie) + io.WriteString(h, "ExtORPort authentication server-to-client hash") + h.Write(clientNonce) + h.Write(serverNonce) + return h.Sum([]byte{}) +} + +// See 217-ext-orport-auth.txt section 4.2.1.3. 
+func computeClientHash(info *ServerInfo, clientNonce, serverNonce []byte) []byte { + h := hmac.New(sha256.New, info.AuthCookie) + io.WriteString(h, "ExtORPort authentication client-to-server hash") + h.Write(clientNonce) + h.Write(serverNonce) + return h.Sum([]byte{}) +} + +func extOrPortAuthenticate(s *net.TCPConn, info *ServerInfo) error { + r := bufio.NewReader(s) + + // Read auth types. 217-ext-orport-auth.txt section 4.1. + var authTypes [256]bool + var count int + for count = 0; count < 256; count++ { + b, err := r.ReadByte() + if err != nil { + return err + } + if b == 0 { + break + } + authTypes[b] = true + } + if count >= 256 { + return errors.New(fmt.Sprintf("read 256 auth types without seeing \x00")) + } + + // We support only type 1, SAFE_COOKIE. + if !authTypes[1] { + return errors.New(fmt.Sprintf("server didn't offer auth type 1")) + } + _, err := s.Write([]byte{1}) + if err != nil { + return err + } + + clientNonce := make([]byte, 32) + clientHash := make([]byte, 32) + serverNonce := make([]byte, 32) + serverHash := make([]byte, 32) + + _, err = io.ReadFull(rand.Reader, clientNonce) + if err != nil { + return err + } + _, err = s.Write(clientNonce) + if err != nil { + return err + } + + _, err = io.ReadFull(r, serverHash) + if err != nil { + return err + } + _, err = io.ReadFull(r, serverNonce) + if err != nil { + return err + } + + expectedServerHash := computeServerHash(info, clientNonce, serverNonce) + if subtle.ConstantTimeCompare(serverHash, expectedServerHash) != 1 { + return errors.New(fmt.Sprintf("mismatch in server hash")) + } + + clientHash = computeClientHash(info, clientNonce, serverNonce) + _, err = s.Write(clientHash) + if err != nil { + return err + } + + status := make([]byte, 1) + _, err = io.ReadFull(r, status) + if err != nil { + return err + } + if status[0] != 1 { + return errors.New(fmt.Sprintf("server rejected authentication")) + } + + if r.Buffered() != 0 { + return errors.New(fmt.Sprintf("%d bytes left after extended OR port 
authentication", r.Buffered())) + } + + return nil +} + +// See section 3.1 of 196-transport-control-ports.txt. +const ( + extOrCmdDone = 0x0000 + extOrCmdUserAddr = 0x0001 + extOrCmdTransport = 0x0002 + extOrCmdOkay = 0x1000 + extOrCmdDeny = 0x1001 +) + +func extOrPortWriteCommand(s *net.TCPConn, cmd uint16, body []byte) error { + var buf bytes.Buffer + if len(body) > 65535 { + return errors.New("command exceeds maximum length of 65535") + } + err := binary.Write(&buf, binary.BigEndian, cmd) + if err != nil { + return err + } + err = binary.Write(&buf, binary.BigEndian, uint16(len(body))) + if err != nil { + return err + } + err = binary.Write(&buf, binary.BigEndian, body) + if err != nil { + return err + } + _, err = s.Write(buf.Bytes()) + if err != nil { + return err + } + + return nil +} + +// Send a USERADDR command on s. See section 3.1.2.1 of +// 196-transport-control-ports.txt. +func extOrPortSendUserAddr(s *net.TCPConn, conn net.Conn) error { + return extOrPortWriteCommand(s, extOrCmdUserAddr, []byte(conn.RemoteAddr().String())) +} + +// Send a TRANSPORT command on s. See section 3.1.2.2 of +// 196-transport-control-ports.txt. +func extOrPortSendTransport(s *net.TCPConn, methodName string) error { + return extOrPortWriteCommand(s, extOrCmdTransport, []byte(methodName)) +} + +// Send a DONE command on s. See section 3.1 of 196-transport-control-ports.txt. 
+func extOrPortSendDone(s *net.TCPConn) error { + return extOrPortWriteCommand(s, extOrCmdDone, []byte{}) +} + +func extOrPortRecvCommand(s *net.TCPConn) (cmd uint16, body []byte, err error) { + var bodyLen uint16 + data := make([]byte, 4) + + _, err = io.ReadFull(s, data) + if err != nil { + return + } + buf := bytes.NewBuffer(data) + err = binary.Read(buf, binary.BigEndian, &cmd) + if err != nil { + return + } + err = binary.Read(buf, binary.BigEndian, &bodyLen) + if err != nil { + return + } + body = make([]byte, bodyLen) + _, err = io.ReadFull(s, body) + if err != nil { + return + } + + return cmd, body, err +} + +// Send USERADDR and TRANSPORT commands followed by a DONE command. Wait for an +// OKAY or DENY response command from the server. Returns nil if and only if +// OKAY is received. +func extOrPortSetup(s *net.TCPConn, conn net.Conn, methodName string) error { + var err error + + err = extOrPortSendUserAddr(s, conn) + if err != nil { + return err + } + err = extOrPortSendTransport(s, methodName) + if err != nil { + return err + } + err = extOrPortSendDone(s) + if err != nil { + return err + } + cmd, _, err := extOrPortRecvCommand(s) + if err != nil { + return err + } + if cmd == extOrCmdDeny { + return errors.New("server returned DENY after our USERADDR and DONE") + } else if cmd != extOrCmdOkay { + return errors.New(fmt.Sprintf("server returned unknown command 0x%04x after our USERADDR and DONE", cmd)) + } + + return nil +} + +// Connect to info.ExtendedOrAddr if defined, or else info.OrAddr, and return an +// open *net.TCPConn. If connecting to the extended OR port, extended OR port +// authentication à la 217-ext-orport-auth.txt is done before returning; an +// error is returned if authentication fails. 
+func ConnectOr(info *ServerInfo, conn net.Conn, methodName string) (*net.TCPConn, error) { + if info.ExtendedOrAddr == nil { + return net.DialTCP("tcp", nil, info.OrAddr) + } + + s, err := net.DialTCP("tcp", nil, info.ExtendedOrAddr) + if err != nil { + return nil, err + } + s.SetDeadline(time.Now().Add(5 * time.Second)) + err = extOrPortAuthenticate(s, info) + if err != nil { + s.Close() + return nil, err + } + err = extOrPortSetup(s, conn, methodName) + if err != nil { + s.Close() + return nil, err + } + s.SetDeadline(time.Time{}) + + return s, nil +} diff --git a/websocket-transport/src/websocket-client/socks.go b/websocket-transport/src/websocket-client/socks.go new file mode 100644 index 0000000..1fa847f --- /dev/null +++ b/websocket-transport/src/websocket-client/socks.go @@ -0,0 +1,107 @@ +// SOCKS4a server library. + +package main + +import ( + "bufio" + "errors" + "fmt" + "io" + "net" +) + +const ( + socksVersion = 0x04 + socksCmdConnect = 0x01 + socksResponseVersion = 0x00 + socksRequestGranted = 0x5a + socksRequestFailed = 0x5b +) + +// Read a SOCKS4a connect request, and call the given connect callback with the +// requested destination string. If the callback returns an error, sends a SOCKS +// request failed message. Otherwise, sends a SOCKS request granted message for +// the destination address returned by the callback. +func AwaitSocks4aConnect(conn *net.TCPConn, connect func(string) (*net.TCPAddr, error)) error { + dest, err := ReadSocks4aConnect(conn) + if err != nil { + SendSocks4aResponseFailed(conn) + return err + } + destAddr, err := connect(dest) + if err != nil { + SendSocks4aResponseFailed(conn) + return err + } + SendSocks4aResponseGranted(conn, destAddr) + return nil +} + +// Read a SOCKS4a connect request. Returns a "host:port" string. 
+func ReadSocks4aConnect(s io.Reader) (string, error) { + r := bufio.NewReader(s) + + var h [8]byte + n, err := io.ReadFull(r, h[:]) + if err != nil { + return "", errors.New(fmt.Sprintf("after %d bytes of SOCKS header: %s", n, err)) + } + if h[0] != socksVersion { + return "", errors.New(fmt.Sprintf("SOCKS header had version 0x%02x, not 0x%02x", h[0], socksVersion)) + } + if h[1] != socksCmdConnect { + return "", errors.New(fmt.Sprintf("SOCKS header had command 0x%02x, not 0x%02x", h[1], socksCmdConnect)) + } + + _, err = r.ReadBytes('\x00') + if err != nil { + return "", errors.New(fmt.Sprintf("reading SOCKS userid: %s", err)) + } + + var port int + var host string + + port = int(h[2])<<8 | int(h[3])<<0 + if h[4] == 0 && h[5] == 0 && h[6] == 0 && h[7] != 0 { + hostBytes, err := r.ReadBytes('\x00') + if err != nil { + return "", errors.New(fmt.Sprintf("reading SOCKS4a destination: %s", err)) + } + host = string(hostBytes[:len(hostBytes)-1]) + } else { + host = net.IPv4(h[4], h[5], h[6], h[7]).String() + } + + if r.Buffered() != 0 { + return "", errors.New(fmt.Sprintf("%d bytes left after SOCKS header", r.Buffered())) + } + + return fmt.Sprintf("%s:%d", host, port), nil +} + +// Send a SOCKS4a response with the given code and address. +func SendSocks4aResponse(w io.Writer, code byte, addr *net.TCPAddr) error { + var resp [8]byte + resp[0] = socksResponseVersion + resp[1] = code + resp[2] = byte((addr.Port >> 8) & 0xff) + resp[3] = byte((addr.Port >> 0) & 0xff) + resp[4] = addr.IP[0] + resp[5] = addr.IP[1] + resp[6] = addr.IP[2] + resp[7] = addr.IP[3] + _, err := w.Write(resp[:]) + return err +} + +var emptyAddr = net.TCPAddr{IP: net.IPv4(0, 0, 0, 0), Port: 0} + +// Send a SOCKS4a response code 0x5a. +func SendSocks4aResponseGranted(w io.Writer, addr *net.TCPAddr) error { + return SendSocks4aResponse(w, socksRequestGranted, addr) +} + +// Send a SOCKS4a response code 0x5b (with an all-zero address). 
+func SendSocks4aResponseFailed(w io.Writer) error { + return SendSocks4aResponse(w, socksRequestFailed, &emptyAddr) +} diff --git a/websocket-transport/src/websocket-client/websocket-client.go b/websocket-transport/src/websocket-client/websocket-client.go new file mode 100644 index 0000000..1bfc746 --- /dev/null +++ b/websocket-transport/src/websocket-client/websocket-client.go @@ -0,0 +1,253 @@ +// Tor websocket client transport plugin. +// +// Usage: +// ClientTransportPlugin websocket exec ./websocket-client + +package main + +import ( + "code.google.com/p/go.net/websocket" + "flag" + "fmt" + "io" + "net" + "net/url" + "os" + "os/signal" + "sync" + "time" +) + +import "pt" + +const ptMethodName = "websocket" +const socksTimeout = 2 * time.Second +const bufSiz = 1500 + +var logFile = os.Stderr + +// When a connection handler starts, +1 is written to this channel; when it +// ends, -1 is written. +var handlerChan = make(chan int) + +var logMutex sync.Mutex + +func usage() { + fmt.Printf("Usage: %s [OPTIONS]\n", os.Args[0]) + fmt.Printf("WebSocket client pluggable transport for Tor.\n") + fmt.Printf("Works only as a managed proxy.\n") + fmt.Printf("\n") + fmt.Printf(" -h, --help show this help.\n") + fmt.Printf(" --log FILE log messages to FILE (default stderr).\n") + fmt.Printf(" --socks ADDR listen for SOCKS on ADDR.\n") +} + +func Log(format string, v ...interface{}) { + dateStr := time.Now().Format("2006-01-02 15:04:05") + logMutex.Lock() + defer logMutex.Unlock() + msg := fmt.Sprintf(format, v...) + fmt.Fprintf(logFile, "%s %s\n", dateStr, msg) +} + +func proxy(local *net.TCPConn, ws *websocket.Conn) { + var wg sync.WaitGroup + + wg.Add(2) + + // Local-to-WebSocket read loop. 
+ go func() { + buf := make([]byte, bufSiz) + var err error + for { + n, er := local.Read(buf[:]) + if n > 0 { + ew := websocket.Message.Send(ws, buf[:n]) + if ew != nil { + err = ew + break + } + } + if er != nil { + err = er + break + } + } + if err != nil && err != io.EOF { + Log("%s", err) + } + local.CloseRead() + ws.Close() + + wg.Done() + }() + + // WebSocket-to-local read loop. + go func() { + var buf []byte + var err error + for { + er := websocket.Message.Receive(ws, &buf) + if er != nil { + err = er + break + } + n, ew := local.Write(buf) + if ew != nil { + err = ew + break + } + if n != len(buf) { + err = io.ErrShortWrite + break + } + } + if err != nil && err != io.EOF { + Log("%s", err) + } + local.CloseWrite() + ws.Close() + + wg.Done() + }() + + wg.Wait() +} + +func handleConnection(conn *net.TCPConn) error { + defer conn.Close() + + handlerChan <- 1 + defer func() { + handlerChan <- -1 + }() + + var ws *websocket.Conn + + conn.SetDeadline(time.Now().Add(socksTimeout)) + err := AwaitSocks4aConnect(conn, func(dest string) (*net.TCPAddr, error) { + // Disable deadline. 
+ conn.SetDeadline(time.Time{}) + Log("SOCKS request for %s", dest) + destAddr, err := net.ResolveTCPAddr("tcp", dest) + if err != nil { + return nil, err + } + wsUrl := url.URL{Scheme: "ws", Host: dest} + ws, err = websocket.Dial(wsUrl.String(), "", wsUrl.String()) + if err != nil { + return nil, err + } + Log("WebSocket connection to %s", ws.Config().Location.String()) + return destAddr, nil + }) + if err != nil { + return err + } + defer ws.Close() + proxy(conn, ws) + return nil +} + +func socksAcceptLoop(ln *net.TCPListener) error { + for { + socks, err := ln.AcceptTCP() + if err != nil { + return err + } + go func() { + err := handleConnection(socks) + if err != nil { + Log("SOCKS from %s: %s", socks.RemoteAddr(), err) + } + }() + } + return nil +} + +func startListener(addrStr string) (*net.TCPListener, error) { + addr, err := net.ResolveTCPAddr("tcp", addrStr) + if err != nil { + return nil, err + } + ln, err := net.ListenTCP("tcp", addr) + if err != nil { + return nil, err + } + go func() { + err := socksAcceptLoop(ln) + if err != nil { + Log("accept: %s", err) + } + }() + return ln, nil +} + +func main() { + var logFilename string + var socksAddrStrs = []string{"127.0.0.1:0"} + var socksArg string + + flag.Usage = usage + flag.StringVar(&logFilename, "log", "", "log file to write to") + flag.StringVar(&socksArg, "socks", "", "address on which to listen for SOCKS connections") + flag.Parse() + + if logFilename != "" { + f, err := os.OpenFile(logFilename, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0600) + if err != nil { + fmt.Fprintf(os.Stderr, "Can't open log file %q: %s.\n", logFilename, err.Error()) + os.Exit(1) + } + logFile = f + } + + if socksArg != "" { + socksAddrStrs = []string{socksArg} + } + + Log("starting") + pt.ClientSetup([]string{ptMethodName}) + + listeners := make([]*net.TCPListener, 0) + for _, socksAddrStr := range socksAddrStrs { + ln, err := startListener(socksAddrStr) + if err != nil { + pt.CmethodError(ptMethodName, err.Error()) + } + 
pt.Cmethod(ptMethodName, "socks4", ln.Addr()) + Log("listening on %s", ln.Addr().String()) + listeners = append(listeners, ln) + } + pt.CmethodsDone() + + var numHandlers int = 0 + + signalChan := make(chan os.Signal, 1) + signal.Notify(signalChan, os.Interrupt) + var sigint bool = false + for !sigint { + select { + case n := <-handlerChan: + numHandlers += n + case <-signalChan: + Log("SIGINT") + sigint = true + } + } + + for _, ln := range listeners { + ln.Close() + } + + sigint = false + for numHandlers != 0 && !sigint { + select { + case n := <-handlerChan: + numHandlers += n + case <-signalChan: + Log("SIGINT") + sigint = true + } + } +} diff --git a/websocket-transport/src/websocket-server/websocket-server.go b/websocket-transport/src/websocket-server/websocket-server.go new file mode 100644 index 0000000..6e6c0b7 --- /dev/null +++ b/websocket-transport/src/websocket-server/websocket-server.go @@ -0,0 +1,275 @@ +// Tor websocket server transport plugin. +// +// Usage: +// ServerTransportPlugin websocket exec ./websocket-server --port 9901 + +package main + +import ( + "encoding/base64" + "errors" + "flag" + "fmt" + "io" + "net" + "net/http" + "os" + "os/signal" + "sync" + "time" +) + +import "pt" + +const ptMethodName = "websocket" +const requestTimeout = 10 * time.Second +// "4/3+1" accounts for possible base64 encoding. +const maxMessageSize = 64*1024*4/3 + 1 + +var logFile = os.Stderr + +var ptInfo pt.ServerInfo + +// When a connection handler starts, +1 is written to this channel; when it +// ends, -1 is written. 
+var handlerChan = make(chan int) + +var logMutex sync.Mutex + +func usage() { + fmt.Printf("Usage: %s [OPTIONS]\n", os.Args[0]) + fmt.Printf("WebSocket server pluggable transport for Tor.\n") + fmt.Printf("Works only as a managed proxy.\n") + fmt.Printf("\n") + fmt.Printf(" -h, --help show this help.\n") + fmt.Printf(" --log FILE log messages to FILE (default stderr).\n") + fmt.Printf(" --port PORT listen on PORT (overrides Tor's requested port).\n") +} + +func Log(format string, v ...interface{}) { + dateStr := time.Now().Format("2006-01-02 15:04:05") + logMutex.Lock() + defer logMutex.Unlock() + msg := fmt.Sprintf(format, v...) + fmt.Fprintf(logFile, "%s %s\n", dateStr, msg) +} + +// An abstraction that makes an underlying WebSocket connection look like an +// io.ReadWriteCloser. It internally takes care of things like base64 encoding and +// decoding. +type websocketConn struct { + Ws *Websocket + Base64 bool + messageBuf []byte +} + +// Implements io.Reader. +func (conn *websocketConn) Read(b []byte) (n int, err error) { + for len(conn.messageBuf) == 0 { + var m WebsocketMessage + m, err = conn.Ws.ReadMessage() + if err != nil { + return + } + if m.Opcode == 8 { + err = io.EOF + return + } + if conn.Base64 { + if m.Opcode != 1 { + err = errors.New(fmt.Sprintf("got non-text opcode %d with the base64 subprotocol", m.Opcode)) + return + } + conn.messageBuf = make([]byte, base64.StdEncoding.DecodedLen(len(m.Payload))) + var num int + num, err = base64.StdEncoding.Decode(conn.messageBuf, m.Payload) + if err != nil { + return + } + conn.messageBuf = conn.messageBuf[:num] + } else { + if m.Opcode != 2 { + err = errors.New(fmt.Sprintf("got non-binary opcode %d with no subprotocol", m.Opcode)) + return + } + conn.messageBuf = m.Payload + } + } + + n = copy(b, conn.messageBuf) + conn.messageBuf = conn.messageBuf[n:] + + return +} + +// Implements io.Writer. 
+func (conn *websocketConn) Write(b []byte) (n int, err error) { + if conn.Base64 { + buf := make([]byte, base64.StdEncoding.EncodedLen(len(b))) + base64.StdEncoding.Encode(buf, b) + err = conn.Ws.WriteMessage(1, buf) + if err != nil { + return + } + n = len(b) + } else { + err = conn.Ws.WriteMessage(2, b) + n = len(b) + } + return +} + +// Implements io.Closer. +func (conn *websocketConn) Close() error { + // Ignore any error in trying to write a Close frame. + _ = conn.Ws.WriteFrame(8, nil) + return conn.Ws.Conn.Close() +} + +// Create a new websocketConn. +func NewWebsocketConn(ws *Websocket) websocketConn { + var conn websocketConn + conn.Ws = ws + conn.Base64 = (ws.Subprotocol == "base64") + return conn +} + +// Copy from WebSocket to socket and vice versa. +func proxy(local *net.TCPConn, conn *websocketConn) { + var wg sync.WaitGroup + + wg.Add(2) + + go func() { + _, err := io.Copy(conn, local) + if err != nil { + Log("error copying ORPort to WebSocket") + } + local.CloseRead() + conn.Close() + wg.Done() + }() + + go func() { + _, err := io.Copy(local, conn) + if err != nil { + Log("error copying WebSocket to ORPort") + } + local.CloseWrite() + conn.Close() + wg.Done() + }() + + wg.Wait() +} + +func websocketHandler(ws *Websocket) { + // Undo timeouts on HTTP request handling. 
+ ws.Conn.SetDeadline(time.Time{}) + conn := NewWebsocketConn(ws) + + handlerChan <- 1 + defer func() { + handlerChan <- -1 + }() + + s, err := pt.ConnectOr(&ptInfo, ws.Conn, ptMethodName) + if err != nil { + Log("Failed to connect to ORPort: " + err.Error()) + return + } + + proxy(s, &conn) +} + +func startListener(addr *net.TCPAddr) (*net.TCPListener, error) { + ln, err := net.ListenTCP("tcp", addr) + if err != nil { + return nil, err + } + go func() { + var config WebsocketConfig + config.Subprotocols = []string{"base64"} + config.MaxMessageSize = maxMessageSize + s := &http.Server{ + Handler: config.Handler(websocketHandler), + ReadTimeout: requestTimeout, + } + err = s.Serve(ln) + if err != nil { + Log("http.Serve: " + err.Error()) + } + }() + return ln, nil +} + +func main() { + var logFilename string + var port int + + flag.Usage = usage + flag.StringVar(&logFilename, "log", "", "log file to write to") + flag.IntVar(&port, "port", 0, "port to listen on if unspecified by Tor") + flag.Parse() + + if logFilename != "" { + f, err := os.OpenFile(logFilename, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0600) + if err != nil { + fmt.Fprintf(os.Stderr, "Can't open log file %q: %s.\n", logFilename, err.Error()) + os.Exit(1) + } + logFile = f + } + + Log("starting") + ptInfo = pt.ServerSetup([]string{ptMethodName}) + + listeners := make([]*net.TCPListener, 0) + for _, bindAddr := range ptInfo.BindAddrs { + // Override tor's requested port (which is 0 if this transport + // has not been run before) with the one requested by the --port + // option. 
+ if port != 0 { + bindAddr.Addr.Port = port + } + + ln, err := startListener(bindAddr.Addr) + if err != nil { + pt.SmethodError(bindAddr.MethodName, err.Error()) + } + pt.Smethod(bindAddr.MethodName, ln.Addr()) + Log("listening on %s", ln.Addr().String()) + listeners = append(listeners, ln) + } + pt.SmethodsDone() + + var numHandlers int = 0 + + signalChan := make(chan os.Signal, 1) + signal.Notify(signalChan, os.Interrupt) + var sigint bool = false + for !sigint { + select { + case n := <-handlerChan: + numHandlers += n + case <-signalChan: + Log("SIGINT") + sigint = true + } + } + + for _, ln := range listeners { + ln.Close() + } + + sigint = false + for numHandlers != 0 && !sigint { + select { + case n := <-handlerChan: + numHandlers += n + case <-signalChan: + Log("SIGINT") + sigint = true + } + } +} diff --git a/websocket-transport/src/websocket-server/websocket.go b/websocket-transport/src/websocket-server/websocket.go new file mode 100644 index 0000000..f195828 --- /dev/null +++ b/websocket-transport/src/websocket-server/websocket.go @@ -0,0 +1,432 @@ +// WebSocket library. Only the RFC 6455 variety of WebSocket is supported. +// +// Reading and writing is strictly per-frame (or per-message). There is no way +// to partially read a frame. WebsocketConfig.MaxMessageSize affords control of +// the maximum buffering of messages. +// +// The reason for using this custom implementation instead of +// code.google.com/p/go.net/websocket is that the latter has problems with long +// messages and does not support server subprotocols. +// "Denial of Service Protection in Go HTTP Servers" +// https://code.google.com/p/go/issues/detail?id=2093 +// "go.websocket: Read/Copy fail with long frames" +// https://code.google.com/p/go/issues/detail?id=2134 +// http://golang.org/pkg/net/textproto/#pkg-bugs +// "To let callers manage exposure to denial of service attacks, Reader should +// allow them to set and reset a limit on the number of bytes read from the +// connection." 
+// "websocket.Dial doesn't limit response header length as http.Get does" +// https://groups.google.com/forum/?fromgroups=#!topic/golang-nuts/2Tge6U8-QYI +// +// Example usage: +// +// func doSomething(ws *Websocket) { +// } +// var config WebsocketConfig +// config.Subprotocols = []string{"base64"} +// config.MaxMessageSize = 2500 +// http.Handle("/", config.Handler(doSomething)) +// err = http.ListenAndServe(":8080", nil) + +package main + +import ( + "bufio" + "bytes" + "crypto/rand" + "crypto/sha1" + "encoding/base64" + "encoding/binary" + "errors" + "fmt" + "io" + "net" + "net/http" + "strings" +) + +// Settings for potential WebSocket connections. Subprotocols is a list of +// supported subprotocols as in RFC 6455 section 1.9. When answering client +// requests, the first of the client's requests subprotocols that is also in +// this list (if any) will be used as the subprotocol for the connection. +// MaxMessageSize is a limit on buffering messages. +type WebsocketConfig struct { + Subprotocols []string + MaxMessageSize int +} + +// Representation of a WebSocket frame. The Payload is always without masking. +type WebsocketFrame struct { + Fin bool + Opcode byte + Payload []byte +} + +// Return true iff the frame's opcode says it is a control frame. +func (frame *WebsocketFrame) IsControl() bool { + return (frame.Opcode & 0x08) != 0 +} + +// Representation of a WebSocket message. The Payload is always without masking. +type WebsocketMessage struct { + Opcode byte + Payload []byte +} + +// A WebSocket connection after hijacking from HTTP. +type Websocket struct { + // Conn and ReadWriter from http.ResponseWriter.Hijack. + Conn net.Conn + Bufrw *bufio.ReadWriter + // Whether we are a client or a server has implications for masking. + IsClient bool + // Set from a parent WebsocketConfig. + MaxMessageSize int + // The single selected subprotocol after negotiation, or "". 
+ Subprotocol string + // Buffer for message payloads, which may be interrupted by control + // messages. + messageBuf bytes.Buffer +} + +func applyMask(payload []byte, maskKey [4]byte) { + for i := 0; i < len(payload); i++ { + payload[i] = payload[i] ^ maskKey[i%4] + } +} + +func (ws *Websocket) maxMessageSize() int { + if ws.MaxMessageSize == 0 { + return 64000 + } + return ws.MaxMessageSize +} + +// Read a single frame from the WebSocket. +func (ws *Websocket) ReadFrame() (frame WebsocketFrame, err error) { + var b byte + err = binary.Read(ws.Bufrw, binary.BigEndian, &b) + if err != nil { + return + } + frame.Fin = (b & 0x80) != 0 + frame.Opcode = b & 0x0f + err = binary.Read(ws.Bufrw, binary.BigEndian, &b) + if err != nil { + return + } + masked := (b & 0x80) != 0 + + payloadLen := uint64(b & 0x7f) + if payloadLen == 126 { + var short uint16 + err = binary.Read(ws.Bufrw, binary.BigEndian, &short) + if err != nil { + return + } + payloadLen = uint64(short) + } else if payloadLen == 127 { + var long uint64 + err = binary.Read(ws.Bufrw, binary.BigEndian, &long) + if err != nil { + return + } + payloadLen = long + } + if payloadLen > uint64(ws.maxMessageSize()) { + err = errors.New(fmt.Sprintf("frame payload length of %d exceeds maximum of %d", payloadLen, ws.MaxMessageSize)) + return + } + + maskKey := [4]byte{} + if masked { + if ws.IsClient { + err = errors.New("client got masked frame") + return + } + err = binary.Read(ws.Bufrw, binary.BigEndian, &maskKey) + if err != nil { + return + } + } else { + if !ws.IsClient { + err = errors.New("server got unmasked frame") + return + } + } + + frame.Payload = make([]byte, payloadLen) + _, err = io.ReadFull(ws.Bufrw, frame.Payload) + if err != nil { + return + } + if masked { + applyMask(frame.Payload, maskKey) + } + + return frame, nil +} + +// Read a single message from the WebSocket. Multiple fragmented frames are +// combined into a single message before being returned. 
Non-control messages +// may be interrupted by control frames. The control frames are returned as +// individual messages before the message that they interrupt. +func (ws *Websocket) ReadMessage() (message WebsocketMessage, err error) { + var opcode byte = 0 + for { + var frame WebsocketFrame + frame, err = ws.ReadFrame() + if err != nil { + return + } + if frame.IsControl() { + if !frame.Fin { + err = errors.New("control frame has fin bit unset") + return + } + message.Opcode = frame.Opcode + message.Payload = frame.Payload + return message, nil + } + + if opcode == 0 { + if frame.Opcode == 0 { + err = errors.New("first frame has opcode 0") + return + } + opcode = frame.Opcode + } else { + if frame.Opcode != 0 { + err = errors.New(fmt.Sprintf("non-first frame has nonzero opcode %d", frame.Opcode)) + return + } + } + if ws.messageBuf.Len()+len(frame.Payload) > ws.MaxMessageSize { + err = errors.New(fmt.Sprintf("message payload length of %d exceeds maximum of %d", + ws.messageBuf.Len()+len(frame.Payload), ws.MaxMessageSize)) + return + } + ws.messageBuf.Write(frame.Payload) + if frame.Fin { + break + } + } + message.Opcode = opcode + message.Payload = ws.messageBuf.Bytes() + ws.messageBuf.Reset() + + return message, nil +} + +// Write a single frame to the WebSocket stream. Destructively masks payload in +// place if ws.IsClient. Frames are always unfragmented. 
+func (ws *Websocket) WriteFrame(opcode byte, payload []byte) (err error) { + if opcode >= 16 { + err = errors.New(fmt.Sprintf("opcode %d is >= 16", opcode)) + return + } + ws.Bufrw.WriteByte(0x80 | opcode) + + var maskBit byte + var maskKey [4]byte + if ws.IsClient { + _, err = io.ReadFull(rand.Reader, maskKey[:]) + if err != nil { + return + } + applyMask(payload, maskKey) + maskBit = 0x80 + } else { + maskBit = 0x00 + } + + if len(payload) < 126 { + ws.Bufrw.WriteByte(maskBit | byte(len(payload))) + } else if len(payload) <= 0xffff { + ws.Bufrw.WriteByte(maskBit | 126) + binary.Write(ws.Bufrw, binary.BigEndian, uint16(len(payload))) + } else { + ws.Bufrw.WriteByte(maskBit | 127) + binary.Write(ws.Bufrw, binary.BigEndian, uint64(len(payload))) + } + + if ws.IsClient { + _, err = ws.Bufrw.Write(maskKey[:]) + if err != nil { + return + } + } + _, err = ws.Bufrw.Write(payload) + if err != nil { + return + } + + ws.Bufrw.Flush() + + return +} + +// Write a single message to the WebSocket stream. Destructively masks payload +// in place if ws.IsClient. Messages are always sent as a single unfragmented +// frame. +func (ws *Websocket) WriteMessage(opcode byte, payload []byte) (err error) { + return ws.WriteFrame(opcode, payload) +} + +// Split a string on commas and trim whitespace. +func commaSplit(s string) []string { + var result []string + if strings.TrimSpace(s) == "" { + return result + } + for _, e := range strings.Split(s, ",") { + result = append(result, strings.TrimSpace(e)) + } + return result +} + +// Returns true iff one of the strings in haystack is needle. +func containsCase(haystack []string, needle string) bool { + for _, e := range haystack { + if strings.ToLower(e) == strings.ToLower(needle) { + return true + } + } + return false +} + +// One-step SHA-1 hash of a string. 
+func sha1Hash(data string) []byte { + h := sha1.New() + h.Write([]byte(data)) + return h.Sum(nil) +} + +func httpError(w http.ResponseWriter, bufrw *bufio.ReadWriter, code int) { + w.Header().Set("Connection", "close") + bufrw.WriteString(fmt.Sprintf("HTTP/1.0 %d %s\r\n", code, http.StatusText(code))) + w.Header().Write(bufrw) + bufrw.WriteString("\r\n") + bufrw.Flush() +} + +// An implementation of http.Handler with a WebsocketConfig. The ServeHTTP +// function calls websocketCallback assuming WebSocket HTTP negotiation is +// successful. +type WebSocketHTTPHandler struct { + Config *WebsocketConfig + WebsocketCallback func(*Websocket) +} + +// Implements the http.Handler interface. +func (handler *WebSocketHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { + conn, bufrw, err := w.(http.Hijacker).Hijack() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + defer conn.Close() + + // See RFC 6455 section 4.2.1 for this sequence of checks. + + // 1. An HTTP/1.1 or higher GET request, including a "Request-URI"... + if req.Method != "GET" { + httpError(w, bufrw, http.StatusMethodNotAllowed) + return + } + if req.URL.Path != "/" { + httpError(w, bufrw, http.StatusNotFound) + return + } + // 2. A |Host| header field containing the server's authority. + // We deliberately skip this test. + // 3. An |Upgrade| header field containing the value "websocket", + // treated as an ASCII case-insensitive value. + if !containsCase(commaSplit(req.Header.Get("Upgrade")), "websocket") { + httpError(w, bufrw, http.StatusBadRequest) + return + } + // 4. A |Connection| header field that includes the token "Upgrade", + // treated as an ASCII case-insensitive value. + if !containsCase(commaSplit(req.Header.Get("Connection")), "Upgrade") { + httpError(w, bufrw, http.StatusBadRequest) + return + } + // 5. A |Sec-WebSocket-Key| header field with a base64-encoded value + // that, when decoded, is 16 bytes in length. 
+ websocketKey := req.Header.Get("Sec-WebSocket-Key") + key, err := base64.StdEncoding.DecodeString(websocketKey) + if err != nil || len(key) != 16 { + httpError(w, bufrw, http.StatusBadRequest) + return + } + // 6. A |Sec-WebSocket-Version| header field, with a value of 13. + // We also allow 8 from draft-ietf-hybi-thewebsocketprotocol-10. + var knownVersions = []string{"8", "13"} + websocketVersion := req.Header.Get("Sec-WebSocket-Version") + if !containsCase(knownVersions, websocketVersion) { + // "If this version does not match a version understood by the + // server, the server MUST abort the WebSocket handshake + // described in this section and instead send an appropriate + // HTTP error code (such as 426 Upgrade Required) and a + // |Sec-WebSocket-Version| header field indicating the + // version(s) the server is capable of understanding." + w.Header().Set("Sec-WebSocket-Version", strings.Join(knownVersions, ", ")) + httpError(w, bufrw, 426) + return + } + // 7. Optionally, an |Origin| header field. + // 8. Optionally, a |Sec-WebSocket-Protocol| header field, with a list of + // values indicating which protocols the client would like to speak, ordered + // by preference. + clientProtocols := commaSplit(req.Header.Get("Sec-WebSocket-Protocol")) + // 9. Optionally, a |Sec-WebSocket-Extensions| header field... + // 10. Optionally, other header fields... + + var ws Websocket + ws.Conn = conn + ws.Bufrw = bufrw + ws.IsClient = false + ws.MaxMessageSize = handler.Config.MaxMessageSize + + // See RFC 6455 section 4.2.2, item 5 for these steps. + + // 1. A Status-Line with a 101 response code as per RFC 2616. + bufrw.WriteString(fmt.Sprintf("HTTP/1.0 %d %s\r\n", http.StatusSwitchingProtocols, http.StatusText(http.StatusSwitchingProtocols))) + // 2. An |Upgrade| header field with value "websocket" as per RFC 2616. + w.Header().Set("Upgrade", "websocket") + // 3. A |Connection| header field with value "Upgrade". + w.Header().Set("Connection", "Upgrade") + // 4. 
A |Sec-WebSocket-Accept| header field. The value of this header + // field is constructed by concatenating /key/, defined above in step 4 + // in Section 4.2.2, with the string + // "258EAFA5-E914-47DA-95CA-C5AB0DC85B11", taking the SHA-1 hash of this + // concatenated value to obtain a 20-byte value and base64-encoding (see + // Section 4 of [RFC4648]) this 20-byte hash. + const magicGUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" + acceptKey := base64.StdEncoding.EncodeToString(sha1Hash(websocketKey + magicGUID)) + w.Header().Set("Sec-WebSocket-Accept", acceptKey) + // 5. Optionally, a |Sec-WebSocket-Protocol| header field, with a value + // /subprotocol/ as defined in step 4 in Section 4.2.2. + for _, clientProto := range clientProtocols { + for _, serverProto := range handler.Config.Subprotocols { + if clientProto == serverProto { + ws.Subprotocol = clientProto + w.Header().Set("Sec-WebSocket-Protocol", clientProto) + break + } + } + } + // 6. Optionally, a |Sec-WebSocket-Extensions| header field... + w.Header().Write(bufrw) + bufrw.WriteString("\r\n") + bufrw.Flush() + + // Call the WebSocket-specific handler. + handler.WebsocketCallback(&ws) +} + +// Return an http.Handler with the given callback function. +func (config *WebsocketConfig) Handler(callback func(*Websocket)) http.Handler { + return &WebSocketHTTPHandler{config, callback} +} diff --git a/websocket-transport/websocket-client.go b/websocket-transport/websocket-client.go deleted file mode 100644 index 20e0c7b..0000000 --- a/websocket-transport/websocket-client.go +++ /dev/null @@ -1,253 +0,0 @@ -// Tor websocket client transport plugin. 
-// -// Usage: -// ClientTransportPlugin websocket exec ./websocket-client - -package main - -import ( - "code.google.com/p/go.net/websocket" - "flag" - "fmt" - "io" - "net" - "net/url" - "os" - "os/signal" - "sync" - "time" -) - -import "./pt" - -const ptMethodName = "websocket" -const socksTimeout = 2 * time.Second -const bufSiz = 1500 - -var logFile = os.Stderr - -// When a connection handler starts, +1 is written to this channel; when it -// ends, -1 is written. -var handlerChan = make(chan int) - -var logMutex sync.Mutex - -func usage() { - fmt.Printf("Usage: %s [OPTIONS]\n", os.Args[0]) - fmt.Printf("WebSocket client pluggable transport for Tor.\n") - fmt.Printf("Works only as a managed proxy.\n") - fmt.Printf("\n") - fmt.Printf(" -h, --help show this help.\n") - fmt.Printf(" --log FILE log messages to FILE (default stderr).\n") - fmt.Printf(" --socks ADDR listen for SOCKS on ADDR.\n") -} - -func Log(format string, v ...interface{}) { - dateStr := time.Now().Format("2006-01-02 15:04:05") - logMutex.Lock() - defer logMutex.Unlock() - msg := fmt.Sprintf(format, v...) - fmt.Fprintf(logFile, "%s %s\n", dateStr, msg) -} - -func proxy(local *net.TCPConn, ws *websocket.Conn) { - var wg sync.WaitGroup - - wg.Add(2) - - // Local-to-WebSocket read loop. - go func() { - buf := make([]byte, bufSiz) - var err error - for { - n, er := local.Read(buf[:]) - if n > 0 { - ew := websocket.Message.Send(ws, buf[:n]) - if ew != nil { - err = ew - break - } - } - if er != nil { - err = er - break - } - } - if err != nil && err != io.EOF { - Log("%s", err) - } - local.CloseRead() - ws.Close() - - wg.Done() - }() - - // WebSocket-to-local read loop. 
- go func() { - var buf []byte - var err error - for { - er := websocket.Message.Receive(ws, &buf) - if er != nil { - err = er - break - } - n, ew := local.Write(buf) - if ew != nil { - err = ew - break - } - if n != len(buf) { - err = io.ErrShortWrite - break - } - } - if err != nil && err != io.EOF { - Log("%s", err) - } - local.CloseWrite() - ws.Close() - - wg.Done() - }() - - wg.Wait() -} - -func handleConnection(conn *net.TCPConn) error { - defer conn.Close() - - handlerChan <- 1 - defer func() { - handlerChan <- -1 - }() - - var ws *websocket.Conn - - conn.SetDeadline(time.Now().Add(socksTimeout)) - err := AwaitSocks4aConnect(conn, func(dest string) (*net.TCPAddr, error) { - // Disable deadline. - conn.SetDeadline(time.Time{}) - Log("SOCKS request for %s", dest) - destAddr, err := net.ResolveTCPAddr("tcp", dest) - if err != nil { - return nil, err - } - wsUrl := url.URL{Scheme: "ws", Host: dest} - ws, err = websocket.Dial(wsUrl.String(), "", wsUrl.String()) - if err != nil { - return nil, err - } - Log("WebSocket connection to %s", ws.Config().Location.String()) - return destAddr, nil - }) - if err != nil { - return err - } - defer ws.Close() - proxy(conn, ws) - return nil -} - -func socksAcceptLoop(ln *net.TCPListener) error { - for { - socks, err := ln.AcceptTCP() - if err != nil { - return err - } - go func() { - err := handleConnection(socks) - if err != nil { - Log("SOCKS from %s: %s", socks.RemoteAddr(), err) - } - }() - } - return nil -} - -func startListener(addrStr string) (*net.TCPListener, error) { - addr, err := net.ResolveTCPAddr("tcp", addrStr) - if err != nil { - return nil, err - } - ln, err := net.ListenTCP("tcp", addr) - if err != nil { - return nil, err - } - go func() { - err := socksAcceptLoop(ln) - if err != nil { - Log("accept: %s", err) - } - }() - return ln, nil -} - -func main() { - var logFilename string - var socksAddrStrs = []string{"127.0.0.1:0"} - var socksArg string - - flag.Usage = usage - flag.StringVar(&logFilename, "log", 
"", "log file to write to") - flag.StringVar(&socksArg, "socks", "", "address on which to listen for SOCKS connections") - flag.Parse() - - if logFilename != "" { - f, err := os.OpenFile(logFilename, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0600) - if err != nil { - fmt.Fprintf(os.Stderr, "Can't open log file %q: %s.\n", logFilename, err.Error()) - os.Exit(1) - } - logFile = f - } - - if socksArg != "" { - socksAddrStrs = []string{socksArg} - } - - Log("starting") - pt.ClientSetup([]string{ptMethodName}) - - listeners := make([]*net.TCPListener, 0) - for _, socksAddrStr := range socksAddrStrs { - ln, err := startListener(socksAddrStr) - if err != nil { - pt.CmethodError(ptMethodName, err.Error()) - } - pt.Cmethod(ptMethodName, "socks4", ln.Addr()) - Log("listening on %s", ln.Addr().String()) - listeners = append(listeners, ln) - } - pt.CmethodsDone() - - var numHandlers int = 0 - - signalChan := make(chan os.Signal, 1) - signal.Notify(signalChan, os.Interrupt) - var sigint bool = false - for !sigint { - select { - case n := <-handlerChan: - numHandlers += n - case <-signalChan: - Log("SIGINT") - sigint = true - } - } - - for _, ln := range listeners { - ln.Close() - } - - sigint = false - for numHandlers != 0 && !sigint { - select { - case n := <-handlerChan: - numHandlers += n - case <-signalChan: - Log("SIGINT") - sigint = true - } - } -} diff --git a/websocket-transport/websocket-server.go b/websocket-transport/websocket-server.go deleted file mode 100644 index 4362fd6..0000000 --- a/websocket-transport/websocket-server.go +++ /dev/null @@ -1,275 +0,0 @@ -// Tor websocket server transport plugin. 
-// -// Usage: -// ServerTransportPlugin websocket exec ./websocket-server --port 9901 - -package main - -import ( - "encoding/base64" - "errors" - "flag" - "fmt" - "io" - "net" - "net/http" - "os" - "os/signal" - "sync" - "time" -) - -import "./pt" - -const ptMethodName = "websocket" -const requestTimeout = 10 * time.Second -// "4/3+1" accounts for possible base64 encoding. -const maxMessageSize = 64*1024*4/3 + 1 - -var logFile = os.Stderr - -var ptInfo pt.ServerInfo - -// When a connection handler starts, +1 is written to this channel; when it -// ends, -1 is written. -var handlerChan = make(chan int) - -var logMutex sync.Mutex - -func usage() { - fmt.Printf("Usage: %s [OPTIONS]\n", os.Args[0]) - fmt.Printf("WebSocket server pluggable transport for Tor.\n") - fmt.Printf("Works only as a managed proxy.\n") - fmt.Printf("\n") - fmt.Printf(" -h, --help show this help.\n") - fmt.Printf(" --log FILE log messages to FILE (default stderr).\n") - fmt.Printf(" --port PORT listen on PORT (overrides Tor's requested port).\n") -} - -func Log(format string, v ...interface{}) { - dateStr := time.Now().Format("2006-01-02 15:04:05") - logMutex.Lock() - defer logMutex.Unlock() - msg := fmt.Sprintf(format, v...) - fmt.Fprintf(logFile, "%s %s\n", dateStr, msg) -} - -// An abstraction that makes an underlying WebSocket connection look like an -// io.ReadWriteCloser. It internally takes care of things like base64 encoding and -// decoding. -type websocketConn struct { - Ws *Websocket - Base64 bool - messageBuf []byte -} - -// Implements io.Reader. 
-func (conn *websocketConn) Read(b []byte) (n int, err error) { - for len(conn.messageBuf) == 0 { - var m WebsocketMessage - m, err = conn.Ws.ReadMessage() - if err != nil { - return - } - if m.Opcode == 8 { - err = io.EOF - return - } - if conn.Base64 { - if m.Opcode != 1 { - err = errors.New(fmt.Sprintf("got non-text opcode %d with the base64 subprotocol", m.Opcode)) - return - } - conn.messageBuf = make([]byte, base64.StdEncoding.DecodedLen(len(m.Payload))) - var num int - num, err = base64.StdEncoding.Decode(conn.messageBuf, m.Payload) - if err != nil { - return - } - conn.messageBuf = conn.messageBuf[:num] - } else { - if m.Opcode != 2 { - err = errors.New(fmt.Sprintf("got non-binary opcode %d with no subprotocol", m.Opcode)) - return - } - conn.messageBuf = m.Payload - } - } - - n = copy(b, conn.messageBuf) - conn.messageBuf = conn.messageBuf[n:] - - return -} - -// Implements io.Writer. -func (conn *websocketConn) Write(b []byte) (n int, err error) { - if conn.Base64 { - buf := make([]byte, base64.StdEncoding.EncodedLen(len(b))) - base64.StdEncoding.Encode(buf, b) - err = conn.Ws.WriteMessage(1, buf) - if err != nil { - return - } - n = len(b) - } else { - err = conn.Ws.WriteMessage(2, b) - n = len(b) - } - return -} - -// Implements io.Closer. -func (conn *websocketConn) Close() error { - // Ignore any error in trying to write a Close frame. - _ = conn.Ws.WriteFrame(8, nil) - return conn.Ws.Conn.Close() -} - -// Create a new websocketConn. -func NewWebsocketConn(ws *Websocket) websocketConn { - var conn websocketConn - conn.Ws = ws - conn.Base64 = (ws.Subprotocol == "base64") - return conn -} - -// Copy from WebSocket to socket and vice versa. 
-func proxy(local *net.TCPConn, conn *websocketConn) { - var wg sync.WaitGroup - - wg.Add(2) - - go func() { - _, err := io.Copy(conn, local) - if err != nil { - Log("error copying ORPort to WebSocket") - } - local.CloseRead() - conn.Close() - wg.Done() - }() - - go func() { - _, err := io.Copy(local, conn) - if err != nil { - Log("error copying WebSocket to ORPort") - } - local.CloseWrite() - conn.Close() - wg.Done() - }() - - wg.Wait() -} - -func websocketHandler(ws *Websocket) { - // Undo timeouts on HTTP request handling. - ws.Conn.SetDeadline(time.Time{}) - conn := NewWebsocketConn(ws) - - handlerChan <- 1 - defer func() { - handlerChan <- -1 - }() - - s, err := pt.ConnectOr(&ptInfo, ws.Conn, ptMethodName) - if err != nil { - Log("Failed to connect to ORPort: " + err.Error()) - return - } - - proxy(s, &conn) -} - -func startListener(addr *net.TCPAddr) (*net.TCPListener, error) { - ln, err := net.ListenTCP("tcp", addr) - if err != nil { - return nil, err - } - go func() { - var config WebsocketConfig - config.Subprotocols = []string{"base64"} - config.MaxMessageSize = maxMessageSize - s := &http.Server{ - Handler: config.Handler(websocketHandler), - ReadTimeout: requestTimeout, - } - err = s.Serve(ln) - if err != nil { - Log("http.Serve: " + err.Error()) - } - }() - return ln, nil -} - -func main() { - var logFilename string - var port int - - flag.Usage = usage - flag.StringVar(&logFilename, "log", "", "log file to write to") - flag.IntVar(&port, "port", 0, "port to listen on if unspecified by Tor") - flag.Parse() - - if logFilename != "" { - f, err := os.OpenFile(logFilename, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0600) - if err != nil { - fmt.Fprintf(os.Stderr, "Can't open log file %q: %s.\n", logFilename, err.Error()) - os.Exit(1) - } - logFile = f - } - - Log("starting") - ptInfo = pt.ServerSetup([]string{ptMethodName}) - - listeners := make([]*net.TCPListener, 0) - for _, bindAddr := range ptInfo.BindAddrs { - // Override tor's requested port (which is 0 if 
this transport - // has not been run before) with the one requested by the --port - // option. - if port != 0 { - bindAddr.Addr.Port = port - } - - ln, err := startListener(bindAddr.Addr) - if err != nil { - pt.SmethodError(bindAddr.MethodName, err.Error()) - } - pt.Smethod(bindAddr.MethodName, ln.Addr()) - Log("listening on %s", ln.Addr().String()) - listeners = append(listeners, ln) - } - pt.SmethodsDone() - - var numHandlers int = 0 - - signalChan := make(chan os.Signal, 1) - signal.Notify(signalChan, os.Interrupt) - var sigint bool = false - for !sigint { - select { - case n := <-handlerChan: - numHandlers += n - case <-signalChan: - Log("SIGINT") - sigint = true - } - } - - for _, ln := range listeners { - ln.Close() - } - - sigint = false - for numHandlers != 0 && !sigint { - select { - case n := <-handlerChan: - numHandlers += n - case <-signalChan: - Log("SIGINT") - sigint = true - } - } -} diff --git a/websocket-transport/websocket.go b/websocket-transport/websocket.go deleted file mode 100644 index f195828..0000000 --- a/websocket-transport/websocket.go +++ /dev/null @@ -1,432 +0,0 @@ -// WebSocket library. Only the RFC 6455 variety of WebSocket is supported. -// -// Reading and writing is strictly per-frame (or per-message). There is no way -// to partially read a frame. WebsocketConfig.MaxMessageSize affords control of -// the maximum buffering of messages. -// -// The reason for using this custom implementation instead of -// code.google.com/p/go.net/websocket is that the latter has problems with long -// messages and does not support server subprotocols. 
-// "Denial of Service Protection in Go HTTP Servers" -// https://code.google.com/p/go/issues/detail?id=2093 -// "go.websocket: Read/Copy fail with long frames" -// https://code.google.com/p/go/issues/detail?id=2134 -// http://golang.org/pkg/net/textproto/#pkg-bugs -// "To let callers manage exposure to denial of service attacks, Reader should -// allow them to set and reset a limit on the number of bytes read from the -// connection." -// "websocket.Dial doesn't limit response header length as http.Get does" -// https://groups.google.com/forum/?fromgroups=#!topic/golang-nuts/2Tge6U8-QYI -// -// Example usage: -// -// func doSomething(ws *Websocket) { -// } -// var config WebsocketConfig -// config.Subprotocols = []string{"base64"} -// config.MaxMessageSize = 2500 -// http.Handle("/", config.Handler(doSomething)) -// err = http.ListenAndServe(":8080", nil) - -package main - -import ( - "bufio" - "bytes" - "crypto/rand" - "crypto/sha1" - "encoding/base64" - "encoding/binary" - "errors" - "fmt" - "io" - "net" - "net/http" - "strings" -) - -// Settings for potential WebSocket connections. Subprotocols is a list of -// supported subprotocols as in RFC 6455 section 1.9. When answering client -// requests, the first of the client's requests subprotocols that is also in -// this list (if any) will be used as the subprotocol for the connection. -// MaxMessageSize is a limit on buffering messages. -type WebsocketConfig struct { - Subprotocols []string - MaxMessageSize int -} - -// Representation of a WebSocket frame. The Payload is always without masking. -type WebsocketFrame struct { - Fin bool - Opcode byte - Payload []byte -} - -// Return true iff the frame's opcode says it is a control frame. -func (frame *WebsocketFrame) IsControl() bool { - return (frame.Opcode & 0x08) != 0 -} - -// Representation of a WebSocket message. The Payload is always without masking. 
-type WebsocketMessage struct { - Opcode byte - Payload []byte -} - -// A WebSocket connection after hijacking from HTTP. -type Websocket struct { - // Conn and ReadWriter from http.ResponseWriter.Hijack. - Conn net.Conn - Bufrw *bufio.ReadWriter - // Whether we are a client or a server has implications for masking. - IsClient bool - // Set from a parent WebsocketConfig. - MaxMessageSize int - // The single selected subprotocol after negotiation, or "". - Subprotocol string - // Buffer for message payloads, which may be interrupted by control - // messages. - messageBuf bytes.Buffer -} - -func applyMask(payload []byte, maskKey [4]byte) { - for i := 0; i < len(payload); i++ { - payload[i] = payload[i] ^ maskKey[i%4] - } -} - -func (ws *Websocket) maxMessageSize() int { - if ws.MaxMessageSize == 0 { - return 64000 - } - return ws.MaxMessageSize -} - -// Read a single frame from the WebSocket. -func (ws *Websocket) ReadFrame() (frame WebsocketFrame, err error) { - var b byte - err = binary.Read(ws.Bufrw, binary.BigEndian, &b) - if err != nil { - return - } - frame.Fin = (b & 0x80) != 0 - frame.Opcode = b & 0x0f - err = binary.Read(ws.Bufrw, binary.BigEndian, &b) - if err != nil { - return - } - masked := (b & 0x80) != 0 - - payloadLen := uint64(b & 0x7f) - if payloadLen == 126 { - var short uint16 - err = binary.Read(ws.Bufrw, binary.BigEndian, &short) - if err != nil { - return - } - payloadLen = uint64(short) - } else if payloadLen == 127 { - var long uint64 - err = binary.Read(ws.Bufrw, binary.BigEndian, &long) - if err != nil { - return - } - payloadLen = long - } - if payloadLen > uint64(ws.maxMessageSize()) { - err = errors.New(fmt.Sprintf("frame payload length of %d exceeds maximum of %d", payloadLen, ws.MaxMessageSize)) - return - } - - maskKey := [4]byte{} - if masked { - if ws.IsClient { - err = errors.New("client got masked frame") - return - } - err = binary.Read(ws.Bufrw, binary.BigEndian, &maskKey) - if err != nil { - return - } - } else { - if 
!ws.IsClient { - err = errors.New("server got unmasked frame") - return - } - } - - frame.Payload = make([]byte, payloadLen) - _, err = io.ReadFull(ws.Bufrw, frame.Payload) - if err != nil { - return - } - if masked { - applyMask(frame.Payload, maskKey) - } - - return frame, nil -} - -// Read a single message from the WebSocket. Multiple fragmented frames are -// combined into a single message before being returned. Non-control messages -// may be interrupted by control frames. The control frames are returned as -// individual messages before the message that they interrupt. -func (ws *Websocket) ReadMessage() (message WebsocketMessage, err error) { - var opcode byte = 0 - for { - var frame WebsocketFrame - frame, err = ws.ReadFrame() - if err != nil { - return - } - if frame.IsControl() { - if !frame.Fin { - err = errors.New("control frame has fin bit unset") - return - } - message.Opcode = frame.Opcode - message.Payload = frame.Payload - return message, nil - } - - if opcode == 0 { - if frame.Opcode == 0 { - err = errors.New("first frame has opcode 0") - return - } - opcode = frame.Opcode - } else { - if frame.Opcode != 0 { - err = errors.New(fmt.Sprintf("non-first frame has nonzero opcode %d", frame.Opcode)) - return - } - } - if ws.messageBuf.Len()+len(frame.Payload) > ws.MaxMessageSize { - err = errors.New(fmt.Sprintf("message payload length of %d exceeds maximum of %d", - ws.messageBuf.Len()+len(frame.Payload), ws.MaxMessageSize)) - return - } - ws.messageBuf.Write(frame.Payload) - if frame.Fin { - break - } - } - message.Opcode = opcode - message.Payload = ws.messageBuf.Bytes() - ws.messageBuf.Reset() - - return message, nil -} - -// Write a single frame to the WebSocket stream. Destructively masks payload in -// place if ws.IsClient. Frames are always unfragmented. 
-func (ws *Websocket) WriteFrame(opcode byte, payload []byte) (err error) { - if opcode >= 16 { - err = errors.New(fmt.Sprintf("opcode %d is >= 16", opcode)) - return - } - ws.Bufrw.WriteByte(0x80 | opcode) - - var maskBit byte - var maskKey [4]byte - if ws.IsClient { - _, err = io.ReadFull(rand.Reader, maskKey[:]) - if err != nil { - return - } - applyMask(payload, maskKey) - maskBit = 0x80 - } else { - maskBit = 0x00 - } - - if len(payload) < 126 { - ws.Bufrw.WriteByte(maskBit | byte(len(payload))) - } else if len(payload) <= 0xffff { - ws.Bufrw.WriteByte(maskBit | 126) - binary.Write(ws.Bufrw, binary.BigEndian, uint16(len(payload))) - } else { - ws.Bufrw.WriteByte(maskBit | 127) - binary.Write(ws.Bufrw, binary.BigEndian, uint64(len(payload))) - } - - if ws.IsClient { - _, err = ws.Bufrw.Write(maskKey[:]) - if err != nil { - return - } - } - _, err = ws.Bufrw.Write(payload) - if err != nil { - return - } - - ws.Bufrw.Flush() - - return -} - -// Write a single message to the WebSocket stream. Destructively masks payload -// in place if ws.IsClient. Messages are always sent as a single unfragmented -// frame. -func (ws *Websocket) WriteMessage(opcode byte, payload []byte) (err error) { - return ws.WriteFrame(opcode, payload) -} - -// Split a string on commas and trim whitespace. -func commaSplit(s string) []string { - var result []string - if strings.TrimSpace(s) == "" { - return result - } - for _, e := range strings.Split(s, ",") { - result = append(result, strings.TrimSpace(e)) - } - return result -} - -// Returns true iff one of the strings in haystack is needle. -func containsCase(haystack []string, needle string) bool { - for _, e := range haystack { - if strings.ToLower(e) == strings.ToLower(needle) { - return true - } - } - return false -} - -// One-step SHA-1 hash of a string. 
-func sha1Hash(data string) []byte { - h := sha1.New() - h.Write([]byte(data)) - return h.Sum(nil) -} - -func httpError(w http.ResponseWriter, bufrw *bufio.ReadWriter, code int) { - w.Header().Set("Connection", "close") - bufrw.WriteString(fmt.Sprintf("HTTP/1.0 %d %s\r\n", code, http.StatusText(code))) - w.Header().Write(bufrw) - bufrw.WriteString("\r\n") - bufrw.Flush() -} - -// An implementation of http.Handler with a WebsocketConfig. The ServeHTTP -// function calls websocketCallback assuming WebSocket HTTP negotiation is -// successful. -type WebSocketHTTPHandler struct { - Config *WebsocketConfig - WebsocketCallback func(*Websocket) -} - -// Implements the http.Handler interface. -func (handler *WebSocketHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { - conn, bufrw, err := w.(http.Hijacker).Hijack() - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - defer conn.Close() - - // See RFC 6455 section 4.2.1 for this sequence of checks. - - // 1. An HTTP/1.1 or higher GET request, including a "Request-URI"... - if req.Method != "GET" { - httpError(w, bufrw, http.StatusMethodNotAllowed) - return - } - if req.URL.Path != "/" { - httpError(w, bufrw, http.StatusNotFound) - return - } - // 2. A |Host| header field containing the server's authority. - // We deliberately skip this test. - // 3. An |Upgrade| header field containing the value "websocket", - // treated as an ASCII case-insensitive value. - if !containsCase(commaSplit(req.Header.Get("Upgrade")), "websocket") { - httpError(w, bufrw, http.StatusBadRequest) - return - } - // 4. A |Connection| header field that includes the token "Upgrade", - // treated as an ASCII case-insensitive value. - if !containsCase(commaSplit(req.Header.Get("Connection")), "Upgrade") { - httpError(w, bufrw, http.StatusBadRequest) - return - } - // 5. A |Sec-WebSocket-Key| header field with a base64-encoded value - // that, when decoded, is 16 bytes in length. 
- websocketKey := req.Header.Get("Sec-WebSocket-Key") - key, err := base64.StdEncoding.DecodeString(websocketKey) - if err != nil || len(key) != 16 { - httpError(w, bufrw, http.StatusBadRequest) - return - } - // 6. A |Sec-WebSocket-Version| header field, with a value of 13. - // We also allow 8 from draft-ietf-hybi-thewebsocketprotocol-10. - var knownVersions = []string{"8", "13"} - websocketVersion := req.Header.Get("Sec-WebSocket-Version") - if !containsCase(knownVersions, websocketVersion) { - // "If this version does not match a version understood by the - // server, the server MUST abort the WebSocket handshake - // described in this section and instead send an appropriate - // HTTP error code (such as 426 Upgrade Required) and a - // |Sec-WebSocket-Version| header field indicating the - // version(s) the server is capable of understanding." - w.Header().Set("Sec-WebSocket-Version", strings.Join(knownVersions, ", ")) - httpError(w, bufrw, 426) - return - } - // 7. Optionally, an |Origin| header field. - // 8. Optionally, a |Sec-WebSocket-Protocol| header field, with a list of - // values indicating which protocols the client would like to speak, ordered - // by preference. - clientProtocols := commaSplit(req.Header.Get("Sec-WebSocket-Protocol")) - // 9. Optionally, a |Sec-WebSocket-Extensions| header field... - // 10. Optionally, other header fields... - - var ws Websocket - ws.Conn = conn - ws.Bufrw = bufrw - ws.IsClient = false - ws.MaxMessageSize = handler.Config.MaxMessageSize - - // See RFC 6455 section 4.2.2, item 5 for these steps. - - // 1. A Status-Line with a 101 response code as per RFC 2616. - bufrw.WriteString(fmt.Sprintf("HTTP/1.0 %d %s\r\n", http.StatusSwitchingProtocols, http.StatusText(http.StatusSwitchingProtocols))) - // 2. An |Upgrade| header field with value "websocket" as per RFC 2616. - w.Header().Set("Upgrade", "websocket") - // 3. A |Connection| header field with value "Upgrade". - w.Header().Set("Connection", "Upgrade") - // 4. 
A |Sec-WebSocket-Accept| header field. The value of this header - // field is constructed by concatenating /key/, defined above in step 4 - // in Section 4.2.2, with the string - // "258EAFA5-E914-47DA-95CA-C5AB0DC85B11", taking the SHA-1 hash of this - // concatenated value to obtain a 20-byte value and base64-encoding (see - // Section 4 of [RFC4648]) this 20-byte hash. - const magicGUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" - acceptKey := base64.StdEncoding.EncodeToString(sha1Hash(websocketKey + magicGUID)) - w.Header().Set("Sec-WebSocket-Accept", acceptKey) - // 5. Optionally, a |Sec-WebSocket-Protocol| header field, with a value - // /subprotocol/ as defined in step 4 in Section 4.2.2. - for _, clientProto := range clientProtocols { - for _, serverProto := range handler.Config.Subprotocols { - if clientProto == serverProto { - ws.Subprotocol = clientProto - w.Header().Set("Sec-WebSocket-Protocol", clientProto) - break - } - } - } - // 6. Optionally, a |Sec-WebSocket-Extensions| header field... - w.Header().Write(bufrw) - bufrw.WriteString("\r\n") - bufrw.Flush() - - // Call the WebSocket-specific handler. - handler.WebsocketCallback(&ws) -} - -// Return an http.Handler with the given callback function. -func (config *WebsocketConfig) Handler(callback func(*Websocket)) http.Handler { - return &WebSocketHTTPHandler{config, callback} -}