commit fa2f68b5d91e07670622ab7054a2d96d85183c0e
Author: David Fifield <david@bamsoftware.com>
Date:   Wed Dec 11 01:05:26 2013 -0800
Stop setting GOPATH.
This is something users are supposed to set for themselves, and we shouldn't mess with it. I went back to relative imports, as they now seem to be supported by gccgo.
http://golang.org/doc/code.html really wants you to take all the Go code you download and put it in $GOPATH, with imports relative to $GOPATH. Downloading a repo elsewhere and then doing "go build", like we're used to with other source packages, doesn't work as well. --- Makefile | 23 +- pt/.gitignore | 2 + pt/examples/dummy-client/dummy-client.go | 137 ++++++ pt/examples/dummy-server/dummy-server.go | 121 +++++ pt/pt.go | 611 ++++++++++++++++++++++++++ pt/pt_test.go | 61 +++ pt/socks/socks.go | 107 +++++ src/pt/.gitignore | 2 - src/pt/examples/dummy-client/dummy-client.go | 137 ------ src/pt/examples/dummy-server/dummy-server.go | 121 ----- src/pt/pt.go | 611 -------------------------- src/pt/pt_test.go | 61 --- src/pt/socks/socks.go | 107 ----- src/websocket-client/websocket-client.go | 254 ----------- src/websocket-server/websocket-server.go | 285 ------------ src/websocket/websocket.go | 431 ------------------ websocket-client/websocket-client.go | 254 +++++++++++ websocket-server/websocket-server.go | 285 ++++++++++++ websocket/websocket.go | 431 ++++++++++++++++++ 19 files changed, 2018 insertions(+), 2023 deletions(-)
diff --git a/Makefile b/Makefile index 2ce4ffa..95aa243 100644 --- a/Makefile +++ b/Makefile @@ -2,33 +2,28 @@ DESTDIR = PREFIX = /usr/local BINDIR = $(PREFIX)/bin
-PROGRAMS = websocket-client websocket-server - -export GOPATH = $(CURDIR) GOBUILDFLAGS = # Alternate flags to use gccgo, allowing cross-compiling for x86 from # x86_64, and presumably better optimization. Install this package: # apt-get install gccgo-multilib # GOBUILDFLAGS = -compiler gccgo -gccgoflags "-O3 -m32 -static-libgo"
-all: websocket-server +all: websocket-server/websocket-server
-%: $(GOPATH)/src/%/*.go - go build $(GOBUILDFLAGS) "$*" +websocket-server/websocket-server: websocket-server/*.go websocket/*.go pt/*.go + cd websocket-server && go build $(GOBUILDFLAGS)
-# websocket-client has a special rule because "go get" is necessary. -websocket-client: $(GOPATH)/src/websocket-client/*.go - go get -d $(GOBUILDFLAGS) websocket-client - go build $(GOBUILDFLAGS) websocket-client +websocket-client/websocket-client: websocket-client/*.go websocket/*.go pt/*.go + cd websocket-client && go build $(GOBUILDFLAGS)
-install: +install: websocket-server/websocket-server mkdir -p "$(DESTDIR)$(BINDIR)" - cp -f websocket-server "$(DESTDIR)$(BINDIR)" + cp -f websocket-server/websocket-server "$(DESTDIR)$(BINDIR)"
clean: - rm -f $(PROGRAMS) + rm -f websocket-server/websocket-server websocket-client/websocket-client
fmt: - go fmt $(PROGRAMS) + go fmt ./websocket-server ./websocket-client ./websocket ./pt
.PHONY: all install clean fmt diff --git a/pt/.gitignore b/pt/.gitignore new file mode 100644 index 0000000..d4d5132 --- /dev/null +++ b/pt/.gitignore @@ -0,0 +1,2 @@ +/examples/dummy-client/dummy-client +/examples/dummy-server/dummy-server diff --git a/pt/examples/dummy-client/dummy-client.go b/pt/examples/dummy-client/dummy-client.go new file mode 100644 index 0000000..3cf7b45 --- /dev/null +++ b/pt/examples/dummy-client/dummy-client.go @@ -0,0 +1,137 @@ +// Usage (in torrc): +// UseBridges 1 +// Bridge dummy X.X.X.X:YYYY +// ClientTransportPlugin dummy exec dummy-client +// Because this transport doesn't do anything to the traffic, you can use any +// ordinary relay's ORPort in the Bridge line. + +package main + +import ( + "io" + "net" + "os" + "os/signal" + "sync" + "syscall" +) + +import "git.torproject.org/pluggable-transports/websocket.git/src/pt" +import "git.torproject.org/pluggable-transports/websocket.git/src/pt/socks" + +var ptInfo pt.ClientInfo + +// When a connection handler starts, +1 is written to this channel; when it +// ends, -1 is written. 
+var handlerChan = make(chan int) + +func copyLoop(a, b net.Conn) { + var wg sync.WaitGroup + wg.Add(2) + + go func() { + io.Copy(b, a) + wg.Done() + }() + go func() { + io.Copy(a, b) + wg.Done() + }() + + wg.Wait() +} + +func handleConnection(local net.Conn) error { + defer local.Close() + + handlerChan <- 1 + defer func() { + handlerChan <- -1 + }() + + var remote net.Conn + err := socks.AwaitSocks4aConnect(local.(*net.TCPConn), func(dest string) (*net.TCPAddr, error) { + var err error + // set remote in outer function environment + remote, err = net.Dial("tcp", dest) + if err != nil { + return nil, err + } + return remote.RemoteAddr().(*net.TCPAddr), nil + }) + if err != nil { + return err + } + defer remote.Close() + copyLoop(local, remote) + + return nil +} + +func acceptLoop(ln net.Listener) error { + for { + conn, err := ln.Accept() + if err != nil { + return err + } + go handleConnection(conn) + } + return nil +} + +func startListener(addr string) (net.Listener, error) { + ln, err := net.Listen("tcp", addr) + if err != nil { + return nil, err + } + go acceptLoop(ln) + return ln, nil +} + +func main() { + ptInfo = pt.ClientSetup([]string{"dummy"}) + + listeners := make([]net.Listener, 0) + for _, methodName := range ptInfo.MethodNames { + ln, err := startListener("127.0.0.1:0") + if err != nil { + pt.CmethodError(methodName, err.Error()) + continue + } + pt.Cmethod(methodName, "socks4", ln.Addr()) + listeners = append(listeners, ln) + } + pt.CmethodsDone() + + var numHandlers int = 0 + var sig os.Signal + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + + // wait for first signal + sig = nil + for sig == nil { + select { + case n := <-handlerChan: + numHandlers += n + case sig = <-sigChan: + } + } + for _, ln := range listeners { + ln.Close() + } + + if sig == syscall.SIGTERM { + return + } + + // wait for second signal or no more handlers + sig = nil + for sig == nil && numHandlers != 0 { + select { + case n := 
<-handlerChan: + numHandlers += n + case sig = <-sigChan: + } + } +} diff --git a/pt/examples/dummy-server/dummy-server.go b/pt/examples/dummy-server/dummy-server.go new file mode 100644 index 0000000..26314d0 --- /dev/null +++ b/pt/examples/dummy-server/dummy-server.go @@ -0,0 +1,121 @@ +// Usage (in torrc): +// BridgeRelay 1 +// ORPort 9001 +// ExtORPort 6669 +// ServerTransportPlugin dummy exec dummy-server + +package main + +import ( + "io" + "net" + "os" + "os/signal" + "sync" + "syscall" +) + +import "git.torproject.org/pluggable-transports/websocket.git/src/pt" + +var ptInfo pt.ServerInfo + +// When a connection handler starts, +1 is written to this channel; when it +// ends, -1 is written. +var handlerChan = make(chan int) + +func copyLoop(a, b net.Conn) { + var wg sync.WaitGroup + wg.Add(2) + + go func() { + io.Copy(b, a) + wg.Done() + }() + go func() { + io.Copy(a, b) + wg.Done() + }() + + wg.Wait() +} + +func handleConnection(conn net.Conn) { + handlerChan <- 1 + defer func() { + handlerChan <- -1 + }() + + or, err := pt.ConnectOr(&ptInfo, conn, "dummy") + if err != nil { + return + } + copyLoop(conn, or) +} + +func acceptLoop(ln net.Listener) error { + for { + conn, err := ln.Accept() + if err != nil { + return err + } + go handleConnection(conn) + } + return nil +} + +func startListener(addr *net.TCPAddr) (net.Listener, error) { + ln, err := net.ListenTCP("tcp", addr) + if err != nil { + return nil, err + } + go acceptLoop(ln) + return ln, nil +} + +func main() { + ptInfo = pt.ServerSetup([]string{"dummy"}) + + listeners := make([]net.Listener, 0) + for _, bindAddr := range ptInfo.BindAddrs { + ln, err := startListener(bindAddr.Addr) + if err != nil { + pt.SmethodError(bindAddr.MethodName, err.Error()) + continue + } + pt.Smethod(bindAddr.MethodName, ln.Addr()) + listeners = append(listeners, ln) + } + pt.SmethodsDone() + + var numHandlers int = 0 + var sig os.Signal + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, 
syscall.SIGTERM) + + // wait for first signal + sig = nil + for sig == nil { + select { + case n := <-handlerChan: + numHandlers += n + case sig = <-sigChan: + } + } + for _, ln := range listeners { + ln.Close() + } + + if sig == syscall.SIGTERM { + return + } + + // wait for second signal or no more handlers + sig = nil + for sig == nil && numHandlers != 0 { + select { + case n := <-handlerChan: + numHandlers += n + case sig = <-sigChan: + } + } +} diff --git a/pt/pt.go b/pt/pt.go new file mode 100644 index 0000000..526f3b7 --- /dev/null +++ b/pt/pt.go @@ -0,0 +1,611 @@ +// Tor pluggable transports library. +// +// Sample client usage: +// +// import "git.torproject.org/pluggable-transports/websocket.git/src/pt" +// var ptInfo pt.ClientInfo +// ptInfo = pt.ClientSetup([]string{"foo"}) +// for _, methodName := range ptInfo.MethodNames { +// ln, err := startSocksListener() +// if err != nil { +// pt.CmethodError(methodName, err.Error()) +// continue +// } +// pt.Cmethod(methodName, "socks4", ln.Addr()) +// } +// pt.CmethodsDone() +// +// Sample server usage: +// +// import "git.torproject.org/pluggable-transports/websocket.git/src/pt" +// var ptInfo pt.ServerInfo +// ptInfo = pt.ServerSetup([]string{"foo", "bar"}) +// for _, bindAddr := range ptInfo.BindAddrs { +// ln, err := startListener(bindAddr.Addr, bindAddr.MethodName) +// if err != nil { +// pt.SmethodError(bindAddr.MethodName, err.Error()) +// continue +// } +// pt.Smethod(bindAddr.MethodName, ln.Addr()) +// } +// pt.SmethodsDone() +// func handler(conn net.Conn, methodName string) { +// or, err := pt.ConnectOr(&ptInfo, conn, methodName) +// if err != nil { +// return +// } +// // Do something with or and conn. 
+// } + +package pt + +import ( + "bufio" + "bytes" + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "crypto/subtle" + "encoding/binary" + "errors" + "fmt" + "io" + "net" + "os" + "strings" + "time" +) + +func getenv(key string) string { + return os.Getenv(key) +} + +// Abort with an ENV-ERROR if the environment variable isn't set. +func getenvRequired(key string) string { + value := os.Getenv(key) + if value == "" { + EnvError(fmt.Sprintf("no %s environment variable", key)) + } + return value +} + +// Escape a string so it contains no byte values over 127 and doesn't contain +// any of the characters '\x00' or '\n'. +func escape(s string) string { + var buf bytes.Buffer + for _, b := range []byte(s) { + if b == '\n' { + buf.WriteString("\n") + } else if b == '\' { + buf.WriteString("\\") + } else if 0 < b && b < 128 { + buf.WriteByte(b) + } else { + fmt.Fprintf(&buf, "\x%02x", b) + } + } + return buf.String() +} + +// Print a pluggable transports protocol line to stdout. The line consists of an +// unescaped keyword, followed by any number of escaped strings. +func Line(keyword string, v ...string) { + var buf bytes.Buffer + buf.WriteString(keyword) + for _, x := range v { + buf.WriteString(" " + escape(x)) + } + fmt.Println(buf.String()) + os.Stdout.Sync() +} + +// All of the *Error functions call os.Exit(1). + +// Emit an ENV-ERROR with explanation text. +func EnvError(msg string) { + Line("ENV-ERROR", msg) + os.Exit(1) +} + +// Emit a VERSION-ERROR with explanation text. +func VersionError(msg string) { + Line("VERSION-ERROR", msg) + os.Exit(1) +} + +// Emit a CMETHOD-ERROR with explanation text. +func CmethodError(methodName, msg string) { + Line("CMETHOD-ERROR", methodName, msg) + os.Exit(1) +} + +// Emit an SMETHOD-ERROR with explanation text. +func SmethodError(methodName, msg string) { + Line("SMETHOD-ERROR", methodName, msg) + os.Exit(1) +} + +// Emit a CMETHOD line. socks must be "socks4" or "socks5". 
Call this once for +// each listening client SOCKS port. +func Cmethod(name string, socks string, addr net.Addr) { + Line("CMETHOD", name, socks, addr.String()) +} + +// Emit a CMETHODS DONE line. Call this after opening all client listeners. +func CmethodsDone() { + Line("CMETHODS", "DONE") +} + +// Emit an SMETHOD line. Call this once for each listening server port. +func Smethod(name string, addr net.Addr) { + Line("SMETHOD", name, addr.String()) +} + +// Emit an SMETHODS DONE line. Call this after opening all server listeners. +func SmethodsDone() { + Line("SMETHODS", "DONE") +} + +// Get a pluggable transports version offered by Tor and understood by us, if +// any. The only version we understand is "1". This function reads the +// environment variable TOR_PT_MANAGED_TRANSPORT_VER. +func getManagedTransportVer() string { + const transportVersion = "1" + for _, offered := range strings.Split(getenvRequired("TOR_PT_MANAGED_TRANSPORT_VER"), ",") { + if offered == transportVersion { + return offered + } + } + return "" +} + +// Get the intersection of the method names offered by Tor and those in +// methodNames. This function reads the environment variable +// TOR_PT_CLIENT_TRANSPORTS. +func getClientTransports(methodNames []string) []string { + clientTransports := getenvRequired("TOR_PT_CLIENT_TRANSPORTS") + if clientTransports == "*" { + return methodNames + } + result := make([]string, 0) + for _, requested := range strings.Split(clientTransports, ",") { + for _, methodName := range methodNames { + if requested == methodName { + result = append(result, methodName) + break + } + } + } + return result +} + +// This structure is returned by ClientSetup. It consists of a list of method +// names. +type ClientInfo struct { + MethodNames []string +} + +// Check the client pluggable transports environments, emitting an error message +// and exiting the program if any error is encountered. Returns a subset of +// methodNames requested by Tor. 
+func ClientSetup(methodNames []string) ClientInfo { + var info ClientInfo + + ver := getManagedTransportVer() + if ver == "" { + VersionError("no-version") + } else { + Line("VERSION", ver) + } + + info.MethodNames = getClientTransports(methodNames) + if len(info.MethodNames) == 0 { + CmethodsDone() + os.Exit(1) + } + + return info +} + +// A combination of a method name and an address, as extracted from +// TOR_PT_SERVER_BINDADDR. +type BindAddr struct { + MethodName string + Addr *net.TCPAddr +} + +// Resolve an address string into a net.TCPAddr. +func resolveBindAddr(bindAddr string) (*net.TCPAddr, error) { + addr, err := net.ResolveTCPAddr("tcp", bindAddr) + if err == nil { + return addr, nil + } + // Before the fixing of bug #7011, tor doesn't put brackets around IPv6 + // addresses. Split after the last colon, assuming it is a port + // separator, and try adding the brackets. + parts := strings.Split(bindAddr, ":") + if len(parts) <= 2 { + return nil, err + } + bindAddr = "[" + strings.Join(parts[:len(parts)-1], ":") + "]:" + parts[len(parts)-1] + return net.ResolveTCPAddr("tcp", bindAddr) +} + +// Return a new slice, the members of which are those members of addrs having a +// MethodName in methodNames. +func filterBindAddrs(addrs []BindAddr, methodNames []string) []BindAddr { + var result []BindAddr + + for _, ba := range addrs { + for _, methodName := range methodNames { + if ba.MethodName == methodName { + result = append(result, ba) + break + } + } + } + + return result +} + +// Return a map from method names to bind addresses. The map is the contents of +// TOR_PT_SERVER_BINDADDR, with keys filtered by TOR_PT_SERVER_TRANSPORTS, and +// further filtered by the methods in methodNames. +func getServerBindAddrs(methodNames []string) []BindAddr { + var result []BindAddr + + // Get the list of all requested bindaddrs. 
+ var serverBindAddr = getenvRequired("TOR_PT_SERVER_BINDADDR") + for _, spec := range strings.Split(serverBindAddr, ",") { + var bindAddr BindAddr + + parts := strings.SplitN(spec, "-", 2) + if len(parts) != 2 { + EnvError(fmt.Sprintf("TOR_PT_SERVER_BINDADDR: %q: doesn't contain "-"", spec)) + } + bindAddr.MethodName = parts[0] + addr, err := resolveBindAddr(parts[1]) + if err != nil { + EnvError(fmt.Sprintf("TOR_PT_SERVER_BINDADDR: %q: %s", spec, err.Error())) + } + bindAddr.Addr = addr + result = append(result, bindAddr) + } + + // Filter by TOR_PT_SERVER_TRANSPORTS. + serverTransports := getenvRequired("TOR_PT_SERVER_TRANSPORTS") + if serverTransports != "*" { + result = filterBindAddrs(result, strings.Split(serverTransports, ",")) + } + + // Finally filter by what we understand. + result = filterBindAddrs(result, methodNames) + + return result +} + +// Read and validate the contents of an auth cookie file. Returns the 32-byte +// cookie. See section 4.2.1.2 of pt-spec.txt. +func readAuthCookieFile(filename string) ([]byte, error) { + authCookieHeader := []byte("! Extended ORPort Auth Cookie !\x0a") + header := make([]byte, 32) + cookie := make([]byte, 32) + + f, err := os.Open(filename) + if err != nil { + return cookie, err + } + defer f.Close() + + n, err := io.ReadFull(f, header) + if err != nil { + return cookie, err + } + n, err = io.ReadFull(f, cookie) + if err != nil { + return cookie, err + } + // Check that the file ends here. + n, err = f.Read(make([]byte, 1)) + if n != 0 { + return cookie, errors.New(fmt.Sprintf("file is longer than 64 bytes")) + } else if err != io.EOF { + return cookie, errors.New(fmt.Sprintf("did not find EOF at end of file")) + } + + if !bytes.Equal(header, authCookieHeader) { + return cookie, errors.New(fmt.Sprintf("missing auth cookie header")) + } + + return cookie, nil +} + +// This structure is returned by ServerSetup. 
It consists of a list of +// BindAddrs, an address for the ORPort, an address for the extended ORPort (if +// any), and an authentication cookie (if any). +type ServerInfo struct { + BindAddrs []BindAddr + OrAddr *net.TCPAddr + ExtendedOrAddr *net.TCPAddr + AuthCookie []byte +} + +// Check the server pluggable transports environments, emitting an error message +// and exiting the program if any error is encountered. Resolves the various +// requested bind addresses, the server ORPort and extended ORPort, and reads +// the auth cookie file. Returns a ServerInfo struct. +func ServerSetup(methodNames []string) ServerInfo { + var info ServerInfo + var err error + + ver := getManagedTransportVer() + if ver == "" { + VersionError("no-version") + } else { + Line("VERSION", ver) + } + + var orPort = getenvRequired("TOR_PT_ORPORT") + info.OrAddr, err = net.ResolveTCPAddr("tcp", orPort) + if err != nil { + EnvError(fmt.Sprintf("cannot resolve TOR_PT_ORPORT %q: %s", orPort, err.Error())) + } + + info.BindAddrs = getServerBindAddrs(methodNames) + if len(info.BindAddrs) == 0 { + SmethodsDone() + os.Exit(1) + } + + var extendedOrPort = getenv("TOR_PT_EXTENDED_SERVER_PORT") + if extendedOrPort != "" { + info.ExtendedOrAddr, err = net.ResolveTCPAddr("tcp", extendedOrPort) + if err != nil { + EnvError(fmt.Sprintf("cannot resolve TOR_PT_EXTENDED_SERVER_PORT %q: %s", extendedOrPort, err.Error())) + } + } + + var authCookieFilename = getenv("TOR_PT_AUTH_COOKIE_FILE") + if authCookieFilename != "" { + info.AuthCookie, err = readAuthCookieFile(authCookieFilename) + if err != nil { + EnvError(fmt.Sprintf("error reading TOR_PT_AUTH_COOKIE_FILE %q: %s", authCookieFilename, err.Error())) + } + } + + return info +} + +// See 217-ext-orport-auth.txt section 4.2.1.3. 
+func computeServerHash(info *ServerInfo, clientNonce, serverNonce []byte) []byte { + h := hmac.New(sha256.New, info.AuthCookie) + io.WriteString(h, "ExtORPort authentication server-to-client hash") + h.Write(clientNonce) + h.Write(serverNonce) + return h.Sum([]byte{}) +} + +// See 217-ext-orport-auth.txt section 4.2.1.3. +func computeClientHash(info *ServerInfo, clientNonce, serverNonce []byte) []byte { + h := hmac.New(sha256.New, info.AuthCookie) + io.WriteString(h, "ExtORPort authentication client-to-server hash") + h.Write(clientNonce) + h.Write(serverNonce) + return h.Sum([]byte{}) +} + +func extOrPortAuthenticate(s *net.TCPConn, info *ServerInfo) error { + r := bufio.NewReader(s) + + // Read auth types. 217-ext-orport-auth.txt section 4.1. + var authTypes [256]bool + var count int + for count = 0; count < 256; count++ { + b, err := r.ReadByte() + if err != nil { + return err + } + if b == 0 { + break + } + authTypes[b] = true + } + if count >= 256 { + return errors.New(fmt.Sprintf("read 256 auth types without seeing \x00")) + } + + // We support only type 1, SAFE_COOKIE. 
+ if !authTypes[1] { + return errors.New(fmt.Sprintf("server didn't offer auth type 1")) + } + _, err := s.Write([]byte{1}) + if err != nil { + return err + } + + clientNonce := make([]byte, 32) + clientHash := make([]byte, 32) + serverNonce := make([]byte, 32) + serverHash := make([]byte, 32) + + _, err = io.ReadFull(rand.Reader, clientNonce) + if err != nil { + return err + } + _, err = s.Write(clientNonce) + if err != nil { + return err + } + + _, err = io.ReadFull(r, serverHash) + if err != nil { + return err + } + _, err = io.ReadFull(r, serverNonce) + if err != nil { + return err + } + + expectedServerHash := computeServerHash(info, clientNonce, serverNonce) + if subtle.ConstantTimeCompare(serverHash, expectedServerHash) != 1 { + return errors.New(fmt.Sprintf("mismatch in server hash")) + } + + clientHash = computeClientHash(info, clientNonce, serverNonce) + _, err = s.Write(clientHash) + if err != nil { + return err + } + + status := make([]byte, 1) + _, err = io.ReadFull(r, status) + if err != nil { + return err + } + if status[0] != 1 { + return errors.New(fmt.Sprintf("server rejected authentication")) + } + + if r.Buffered() != 0 { + return errors.New(fmt.Sprintf("%d bytes left after extended OR port authentication", r.Buffered())) + } + + return nil +} + +// See section 3.1 of 196-transport-control-ports.txt. 
+const ( + extOrCmdDone = 0x0000 + extOrCmdUserAddr = 0x0001 + extOrCmdTransport = 0x0002 + extOrCmdOkay = 0x1000 + extOrCmdDeny = 0x1001 +) + +func extOrPortWriteCommand(s *net.TCPConn, cmd uint16, body []byte) error { + var buf bytes.Buffer + if len(body) > 65535 { + return errors.New("command exceeds maximum length of 65535") + } + err := binary.Write(&buf, binary.BigEndian, cmd) + if err != nil { + return err + } + err = binary.Write(&buf, binary.BigEndian, uint16(len(body))) + if err != nil { + return err + } + err = binary.Write(&buf, binary.BigEndian, body) + if err != nil { + return err + } + _, err = s.Write(buf.Bytes()) + if err != nil { + return err + } + + return nil +} + +// Send a USERADDR command on s. See section 3.1.2.1 of +// 196-transport-control-ports.txt. +func extOrPortSendUserAddr(s *net.TCPConn, conn net.Conn) error { + return extOrPortWriteCommand(s, extOrCmdUserAddr, []byte(conn.RemoteAddr().String())) +} + +// Send a TRANSPORT command on s. See section 3.1.2.2 of +// 196-transport-control-ports.txt. +func extOrPortSendTransport(s *net.TCPConn, methodName string) error { + return extOrPortWriteCommand(s, extOrCmdTransport, []byte(methodName)) +} + +// Send a DONE command on s. See section 3.1 of 196-transport-control-ports.txt. +func extOrPortSendDone(s *net.TCPConn) error { + return extOrPortWriteCommand(s, extOrCmdDone, []byte{}) +} + +func extOrPortRecvCommand(s *net.TCPConn) (cmd uint16, body []byte, err error) { + var bodyLen uint16 + data := make([]byte, 4) + + _, err = io.ReadFull(s, data) + if err != nil { + return + } + buf := bytes.NewBuffer(data) + err = binary.Read(buf, binary.BigEndian, &cmd) + if err != nil { + return + } + err = binary.Read(buf, binary.BigEndian, &bodyLen) + if err != nil { + return + } + body = make([]byte, bodyLen) + _, err = io.ReadFull(s, body) + if err != nil { + return + } + + return cmd, body, err +} + +// Send USERADDR and TRANSPORT commands followed by a DONE command. 
Wait for an +// OKAY or DENY response command from the server. Returns nil if and only if +// OKAY is received. +func extOrPortSetup(s *net.TCPConn, conn net.Conn, methodName string) error { + var err error + + err = extOrPortSendUserAddr(s, conn) + if err != nil { + return err + } + err = extOrPortSendTransport(s, methodName) + if err != nil { + return err + } + err = extOrPortSendDone(s) + if err != nil { + return err + } + cmd, _, err := extOrPortRecvCommand(s) + if err != nil { + return err + } + if cmd == extOrCmdDeny { + return errors.New("server returned DENY after our USERADDR and DONE") + } else if cmd != extOrCmdOkay { + return errors.New(fmt.Sprintf("server returned unknown command 0x%04x after our USERADDR and DONE", cmd)) + } + + return nil +} + +// Connect to info.ExtendedOrAddr if defined, or else info.OrAddr, and return an +// open *net.TCPConn. If connecting to the extended OR port, extended OR port +// authentication à la 217-ext-orport-auth.txt is done before returning; an +// error is returned if authentication fails. 
+func ConnectOr(info *ServerInfo, conn net.Conn, methodName string) (*net.TCPConn, error) { + if info.ExtendedOrAddr == nil { + return net.DialTCP("tcp", nil, info.OrAddr) + } + + s, err := net.DialTCP("tcp", nil, info.ExtendedOrAddr) + if err != nil { + return nil, err + } + s.SetDeadline(time.Now().Add(5 * time.Second)) + err = extOrPortAuthenticate(s, info) + if err != nil { + s.Close() + return nil, err + } + err = extOrPortSetup(s, conn, methodName) + if err != nil { + s.Close() + return nil, err + } + s.SetDeadline(time.Time{}) + + return s, nil +} diff --git a/pt/pt_test.go b/pt/pt_test.go new file mode 100644 index 0000000..cc7924a --- /dev/null +++ b/pt/pt_test.go @@ -0,0 +1,61 @@ +package pt + +import "os" +import "testing" + +func stringIsSafe(s string) bool { + for _, c := range []byte(s) { + if c == '\x00' || c == '\n' || c > 127 { + return false + } + } + return true +} + +func TestEscape(t *testing.T) { + tests := [...]string{ + "", + "abc", + "a\nb", + "a\b", + "ab\", + "ab\\n", + "ab\n\", + } + + check := func(input string) { + output := escape(input) + if !stringIsSafe(output) { + t.Errorf("escape(%q) → %q", input, output) + } + } + for _, input := range tests { + check(input) + } + for b := 0; b < 256; b++ { + // check one-byte string with each byte value 0–255 + check(string([]byte{byte(b)})) + // check UTF-8 encoding of each character 0–255 + check(string(b)) + } +} + +func TestGetManagedTransportVer(t *testing.T) { + tests := [...]struct { + input, expected string + }{ + {"1", "1"}, + {"1,1", "1"}, + {"1,2", "1"}, + {"2,1", "1"}, + {"2", ""}, + } + + for _, test := range tests { + os.Setenv("TOR_PT_MANAGED_TRANSPORT_VER", test.input) + output := getManagedTransportVer() + if output != test.expected { + t.Errorf("%q → %q (expected %q)", test.input, output, test.expected) + } + } +} diff --git a/pt/socks/socks.go b/pt/socks/socks.go new file mode 100644 index 0000000..788d53c --- /dev/null +++ b/pt/socks/socks.go @@ -0,0 +1,107 @@ +// SOCKS4a 
server library. + +package socks + +import ( + "bufio" + "errors" + "fmt" + "io" + "net" +) + +const ( + socksVersion = 0x04 + socksCmdConnect = 0x01 + socksResponseVersion = 0x00 + socksRequestGranted = 0x5a + socksRequestFailed = 0x5b +) + +// Read a SOCKS4a connect request, and call the given connect callback with the +// requested destination string. If the callback returns an error, sends a SOCKS +// request failed message. Otherwise, sends a SOCKS request granted message for +// the destination address returned by the callback. +func AwaitSocks4aConnect(conn *net.TCPConn, connect func(string) (*net.TCPAddr, error)) error { + dest, err := readSocks4aConnect(conn) + if err != nil { + sendSocks4aResponseFailed(conn) + return err + } + destAddr, err := connect(dest) + if err != nil { + sendSocks4aResponseFailed(conn) + return err + } + sendSocks4aResponseGranted(conn, destAddr) + return nil +} + +// Read a SOCKS4a connect request. Returns a "host:port" string. +func readSocks4aConnect(s io.Reader) (string, error) { + r := bufio.NewReader(s) + + var h [8]byte + n, err := io.ReadFull(r, h[:]) + if err != nil { + return "", errors.New(fmt.Sprintf("after %d bytes of SOCKS header: %s", n, err)) + } + if h[0] != socksVersion { + return "", errors.New(fmt.Sprintf("SOCKS header had version 0x%02x, not 0x%02x", h[0], socksVersion)) + } + if h[1] != socksCmdConnect { + return "", errors.New(fmt.Sprintf("SOCKS header had command 0x%02x, not 0x%02x", h[1], socksCmdConnect)) + } + + _, err = r.ReadBytes('\x00') + if err != nil { + return "", errors.New(fmt.Sprintf("reading SOCKS userid: %s", err)) + } + + var port int + var host string + + port = int(h[2])<<8 | int(h[3])<<0 + if h[4] == 0 && h[5] == 0 && h[6] == 0 && h[7] != 0 { + hostBytes, err := r.ReadBytes('\x00') + if err != nil { + return "", errors.New(fmt.Sprintf("reading SOCKS4a destination: %s", err)) + } + host = string(hostBytes[:len(hostBytes)-1]) + } else { + host = net.IPv4(h[4], h[5], h[6], h[7]).String() + } 
+ + if r.Buffered() != 0 { + return "", errors.New(fmt.Sprintf("%d bytes left after SOCKS header", r.Buffered())) + } + + return fmt.Sprintf("%s:%d", host, port), nil +} + +// Send a SOCKS4a response with the given code and address. +func sendSocks4aResponse(w io.Writer, code byte, addr *net.TCPAddr) error { + var resp [8]byte + resp[0] = socksResponseVersion + resp[1] = code + resp[2] = byte((addr.Port >> 8) & 0xff) + resp[3] = byte((addr.Port >> 0) & 0xff) + resp[4] = addr.IP[0] + resp[5] = addr.IP[1] + resp[6] = addr.IP[2] + resp[7] = addr.IP[3] + _, err := w.Write(resp[:]) + return err +} + +var emptyAddr = net.TCPAddr{IP: net.IPv4(0, 0, 0, 0), Port: 0} + +// Send a SOCKS4a response code 0x5a. +func sendSocks4aResponseGranted(w io.Writer, addr *net.TCPAddr) error { + return sendSocks4aResponse(w, socksRequestGranted, addr) +} + +// Send a SOCKS4a response code 0x5b (with an all-zero address). +func sendSocks4aResponseFailed(w io.Writer) error { + return sendSocks4aResponse(w, socksRequestFailed, &emptyAddr) +} diff --git a/src/pt/.gitignore b/src/pt/.gitignore deleted file mode 100644 index d4d5132..0000000 --- a/src/pt/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -/examples/dummy-client/dummy-client -/examples/dummy-server/dummy-server diff --git a/src/pt/examples/dummy-client/dummy-client.go b/src/pt/examples/dummy-client/dummy-client.go deleted file mode 100644 index 3cf7b45..0000000 --- a/src/pt/examples/dummy-client/dummy-client.go +++ /dev/null @@ -1,137 +0,0 @@ -// Usage (in torrc): -// UseBridges 1 -// Bridge dummy X.X.X.X:YYYY -// ClientTransportPlugin dummy exec dummy-client -// Because this transport doesn't do anything to the traffic, you can use any -// ordinary relay's ORPort in the Bridge line. 
- -package main - -import ( - "io" - "net" - "os" - "os/signal" - "sync" - "syscall" -) - -import "git.torproject.org/pluggable-transports/websocket.git/src/pt" -import "git.torproject.org/pluggable-transports/websocket.git/src/pt/socks" - -var ptInfo pt.ClientInfo - -// When a connection handler starts, +1 is written to this channel; when it -// ends, -1 is written. -var handlerChan = make(chan int) - -func copyLoop(a, b net.Conn) { - var wg sync.WaitGroup - wg.Add(2) - - go func() { - io.Copy(b, a) - wg.Done() - }() - go func() { - io.Copy(a, b) - wg.Done() - }() - - wg.Wait() -} - -func handleConnection(local net.Conn) error { - defer local.Close() - - handlerChan <- 1 - defer func() { - handlerChan <- -1 - }() - - var remote net.Conn - err := socks.AwaitSocks4aConnect(local.(*net.TCPConn), func(dest string) (*net.TCPAddr, error) { - var err error - // set remote in outer function environment - remote, err = net.Dial("tcp", dest) - if err != nil { - return nil, err - } - return remote.RemoteAddr().(*net.TCPAddr), nil - }) - if err != nil { - return err - } - defer remote.Close() - copyLoop(local, remote) - - return nil -} - -func acceptLoop(ln net.Listener) error { - for { - conn, err := ln.Accept() - if err != nil { - return err - } - go handleConnection(conn) - } - return nil -} - -func startListener(addr string) (net.Listener, error) { - ln, err := net.Listen("tcp", addr) - if err != nil { - return nil, err - } - go acceptLoop(ln) - return ln, nil -} - -func main() { - ptInfo = pt.ClientSetup([]string{"dummy"}) - - listeners := make([]net.Listener, 0) - for _, methodName := range ptInfo.MethodNames { - ln, err := startListener("127.0.0.1:0") - if err != nil { - pt.CmethodError(methodName, err.Error()) - continue - } - pt.Cmethod(methodName, "socks4", ln.Addr()) - listeners = append(listeners, ln) - } - pt.CmethodsDone() - - var numHandlers int = 0 - var sig os.Signal - sigChan := make(chan os.Signal, 1) - signal.Notify(sigChan, syscall.SIGINT, 
syscall.SIGTERM) - - // wait for first signal - sig = nil - for sig == nil { - select { - case n := <-handlerChan: - numHandlers += n - case sig = <-sigChan: - } - } - for _, ln := range listeners { - ln.Close() - } - - if sig == syscall.SIGTERM { - return - } - - // wait for second signal or no more handlers - sig = nil - for sig == nil && numHandlers != 0 { - select { - case n := <-handlerChan: - numHandlers += n - case sig = <-sigChan: - } - } -} diff --git a/src/pt/examples/dummy-server/dummy-server.go b/src/pt/examples/dummy-server/dummy-server.go deleted file mode 100644 index 26314d0..0000000 --- a/src/pt/examples/dummy-server/dummy-server.go +++ /dev/null @@ -1,121 +0,0 @@ -// Usage (in torrc): -// BridgeRelay 1 -// ORPort 9001 -// ExtORPort 6669 -// ServerTransportPlugin dummy exec dummy-server - -package main - -import ( - "io" - "net" - "os" - "os/signal" - "sync" - "syscall" -) - -import "git.torproject.org/pluggable-transports/websocket.git/src/pt" - -var ptInfo pt.ServerInfo - -// When a connection handler starts, +1 is written to this channel; when it -// ends, -1 is written. 
-var handlerChan = make(chan int) - -func copyLoop(a, b net.Conn) { - var wg sync.WaitGroup - wg.Add(2) - - go func() { - io.Copy(b, a) - wg.Done() - }() - go func() { - io.Copy(a, b) - wg.Done() - }() - - wg.Wait() -} - -func handleConnection(conn net.Conn) { - handlerChan <- 1 - defer func() { - handlerChan <- -1 - }() - - or, err := pt.ConnectOr(&ptInfo, conn, "dummy") - if err != nil { - return - } - copyLoop(conn, or) -} - -func acceptLoop(ln net.Listener) error { - for { - conn, err := ln.Accept() - if err != nil { - return err - } - go handleConnection(conn) - } - return nil -} - -func startListener(addr *net.TCPAddr) (net.Listener, error) { - ln, err := net.ListenTCP("tcp", addr) - if err != nil { - return nil, err - } - go acceptLoop(ln) - return ln, nil -} - -func main() { - ptInfo = pt.ServerSetup([]string{"dummy"}) - - listeners := make([]net.Listener, 0) - for _, bindAddr := range ptInfo.BindAddrs { - ln, err := startListener(bindAddr.Addr) - if err != nil { - pt.SmethodError(bindAddr.MethodName, err.Error()) - continue - } - pt.Smethod(bindAddr.MethodName, ln.Addr()) - listeners = append(listeners, ln) - } - pt.SmethodsDone() - - var numHandlers int = 0 - var sig os.Signal - sigChan := make(chan os.Signal, 1) - signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) - - // wait for first signal - sig = nil - for sig == nil { - select { - case n := <-handlerChan: - numHandlers += n - case sig = <-sigChan: - } - } - for _, ln := range listeners { - ln.Close() - } - - if sig == syscall.SIGTERM { - return - } - - // wait for second signal or no more handlers - sig = nil - for sig == nil && numHandlers != 0 { - select { - case n := <-handlerChan: - numHandlers += n - case sig = <-sigChan: - } - } -} diff --git a/src/pt/pt.go b/src/pt/pt.go deleted file mode 100644 index 526f3b7..0000000 --- a/src/pt/pt.go +++ /dev/null @@ -1,611 +0,0 @@ -// Tor pluggable transports library. 
-// -// Sample client usage: -// -// import "git.torproject.org/pluggable-transports/websocket.git/src/pt" -// var ptInfo pt.ClientInfo -// ptInfo = pt.ClientSetup([]string{"foo"}) -// for _, methodName := range ptInfo.MethodNames { -// ln, err := startSocksListener() -// if err != nil { -// pt.CmethodError(methodName, err.Error()) -// continue -// } -// pt.Cmethod(methodName, "socks4", ln.Addr()) -// } -// pt.CmethodsDone() -// -// Sample server usage: -// -// import "git.torproject.org/pluggable-transports/websocket.git/src/pt" -// var ptInfo pt.ServerInfo -// ptInfo = pt.ServerSetup([]string{"foo", "bar"}) -// for _, bindAddr := range ptInfo.BindAddrs { -// ln, err := startListener(bindAddr.Addr, bindAddr.MethodName) -// if err != nil { -// pt.SmethodError(bindAddr.MethodName, err.Error()) -// continue -// } -// pt.Smethod(bindAddr.MethodName, ln.Addr()) -// } -// pt.SmethodsDone() -// func handler(conn net.Conn, methodName string) { -// or, err := pt.ConnectOr(&ptInfo, conn, methodName) -// if err != nil { -// return -// } -// // Do something with or and conn. -// } - -package pt - -import ( - "bufio" - "bytes" - "crypto/hmac" - "crypto/rand" - "crypto/sha256" - "crypto/subtle" - "encoding/binary" - "errors" - "fmt" - "io" - "net" - "os" - "strings" - "time" -) - -func getenv(key string) string { - return os.Getenv(key) -} - -// Abort with an ENV-ERROR if the environment variable isn't set. -func getenvRequired(key string) string { - value := os.Getenv(key) - if value == "" { - EnvError(fmt.Sprintf("no %s environment variable", key)) - } - return value -} - -// Escape a string so it contains no byte values over 127 and doesn't contain -// any of the characters '\x00' or '\n'. 
-func escape(s string) string { - var buf bytes.Buffer - for _, b := range []byte(s) { - if b == '\n' { - buf.WriteString("\n") - } else if b == '\' { - buf.WriteString("\\") - } else if 0 < b && b < 128 { - buf.WriteByte(b) - } else { - fmt.Fprintf(&buf, "\x%02x", b) - } - } - return buf.String() -} - -// Print a pluggable transports protocol line to stdout. The line consists of an -// unescaped keyword, followed by any number of escaped strings. -func Line(keyword string, v ...string) { - var buf bytes.Buffer - buf.WriteString(keyword) - for _, x := range v { - buf.WriteString(" " + escape(x)) - } - fmt.Println(buf.String()) - os.Stdout.Sync() -} - -// All of the *Error functions call os.Exit(1). - -// Emit an ENV-ERROR with explanation text. -func EnvError(msg string) { - Line("ENV-ERROR", msg) - os.Exit(1) -} - -// Emit a VERSION-ERROR with explanation text. -func VersionError(msg string) { - Line("VERSION-ERROR", msg) - os.Exit(1) -} - -// Emit a CMETHOD-ERROR with explanation text. -func CmethodError(methodName, msg string) { - Line("CMETHOD-ERROR", methodName, msg) - os.Exit(1) -} - -// Emit an SMETHOD-ERROR with explanation text. -func SmethodError(methodName, msg string) { - Line("SMETHOD-ERROR", methodName, msg) - os.Exit(1) -} - -// Emit a CMETHOD line. socks must be "socks4" or "socks5". Call this once for -// each listening client SOCKS port. -func Cmethod(name string, socks string, addr net.Addr) { - Line("CMETHOD", name, socks, addr.String()) -} - -// Emit a CMETHODS DONE line. Call this after opening all client listeners. -func CmethodsDone() { - Line("CMETHODS", "DONE") -} - -// Emit an SMETHOD line. Call this once for each listening server port. -func Smethod(name string, addr net.Addr) { - Line("SMETHOD", name, addr.String()) -} - -// Emit an SMETHODS DONE line. Call this after opening all server listeners. 
-func SmethodsDone() { - Line("SMETHODS", "DONE") -} - -// Get a pluggable transports version offered by Tor and understood by us, if -// any. The only version we understand is "1". This function reads the -// environment variable TOR_PT_MANAGED_TRANSPORT_VER. -func getManagedTransportVer() string { - const transportVersion = "1" - for _, offered := range strings.Split(getenvRequired("TOR_PT_MANAGED_TRANSPORT_VER"), ",") { - if offered == transportVersion { - return offered - } - } - return "" -} - -// Get the intersection of the method names offered by Tor and those in -// methodNames. This function reads the environment variable -// TOR_PT_CLIENT_TRANSPORTS. -func getClientTransports(methodNames []string) []string { - clientTransports := getenvRequired("TOR_PT_CLIENT_TRANSPORTS") - if clientTransports == "*" { - return methodNames - } - result := make([]string, 0) - for _, requested := range strings.Split(clientTransports, ",") { - for _, methodName := range methodNames { - if requested == methodName { - result = append(result, methodName) - break - } - } - } - return result -} - -// This structure is returned by ClientSetup. It consists of a list of method -// names. -type ClientInfo struct { - MethodNames []string -} - -// Check the client pluggable transports environments, emitting an error message -// and exiting the program if any error is encountered. Returns a subset of -// methodNames requested by Tor. -func ClientSetup(methodNames []string) ClientInfo { - var info ClientInfo - - ver := getManagedTransportVer() - if ver == "" { - VersionError("no-version") - } else { - Line("VERSION", ver) - } - - info.MethodNames = getClientTransports(methodNames) - if len(info.MethodNames) == 0 { - CmethodsDone() - os.Exit(1) - } - - return info -} - -// A combination of a method name and an address, as extracted from -// TOR_PT_SERVER_BINDADDR. -type BindAddr struct { - MethodName string - Addr *net.TCPAddr -} - -// Resolve an address string into a net.TCPAddr. 
-func resolveBindAddr(bindAddr string) (*net.TCPAddr, error) { - addr, err := net.ResolveTCPAddr("tcp", bindAddr) - if err == nil { - return addr, nil - } - // Before the fixing of bug #7011, tor doesn't put brackets around IPv6 - // addresses. Split after the last colon, assuming it is a port - // separator, and try adding the brackets. - parts := strings.Split(bindAddr, ":") - if len(parts) <= 2 { - return nil, err - } - bindAddr = "[" + strings.Join(parts[:len(parts)-1], ":") + "]:" + parts[len(parts)-1] - return net.ResolveTCPAddr("tcp", bindAddr) -} - -// Return a new slice, the members of which are those members of addrs having a -// MethodName in methodNames. -func filterBindAddrs(addrs []BindAddr, methodNames []string) []BindAddr { - var result []BindAddr - - for _, ba := range addrs { - for _, methodName := range methodNames { - if ba.MethodName == methodName { - result = append(result, ba) - break - } - } - } - - return result -} - -// Return a map from method names to bind addresses. The map is the contents of -// TOR_PT_SERVER_BINDADDR, with keys filtered by TOR_PT_SERVER_TRANSPORTS, and -// further filtered by the methods in methodNames. -func getServerBindAddrs(methodNames []string) []BindAddr { - var result []BindAddr - - // Get the list of all requested bindaddrs. - var serverBindAddr = getenvRequired("TOR_PT_SERVER_BINDADDR") - for _, spec := range strings.Split(serverBindAddr, ",") { - var bindAddr BindAddr - - parts := strings.SplitN(spec, "-", 2) - if len(parts) != 2 { - EnvError(fmt.Sprintf("TOR_PT_SERVER_BINDADDR: %q: doesn't contain "-"", spec)) - } - bindAddr.MethodName = parts[0] - addr, err := resolveBindAddr(parts[1]) - if err != nil { - EnvError(fmt.Sprintf("TOR_PT_SERVER_BINDADDR: %q: %s", spec, err.Error())) - } - bindAddr.Addr = addr - result = append(result, bindAddr) - } - - // Filter by TOR_PT_SERVER_TRANSPORTS. 
- serverTransports := getenvRequired("TOR_PT_SERVER_TRANSPORTS") - if serverTransports != "*" { - result = filterBindAddrs(result, strings.Split(serverTransports, ",")) - } - - // Finally filter by what we understand. - result = filterBindAddrs(result, methodNames) - - return result -} - -// Read and validate the contents of an auth cookie file. Returns the 32-byte -// cookie. See section 4.2.1.2 of pt-spec.txt. -func readAuthCookieFile(filename string) ([]byte, error) { - authCookieHeader := []byte("! Extended ORPort Auth Cookie !\x0a") - header := make([]byte, 32) - cookie := make([]byte, 32) - - f, err := os.Open(filename) - if err != nil { - return cookie, err - } - defer f.Close() - - n, err := io.ReadFull(f, header) - if err != nil { - return cookie, err - } - n, err = io.ReadFull(f, cookie) - if err != nil { - return cookie, err - } - // Check that the file ends here. - n, err = f.Read(make([]byte, 1)) - if n != 0 { - return cookie, errors.New(fmt.Sprintf("file is longer than 64 bytes")) - } else if err != io.EOF { - return cookie, errors.New(fmt.Sprintf("did not find EOF at end of file")) - } - - if !bytes.Equal(header, authCookieHeader) { - return cookie, errors.New(fmt.Sprintf("missing auth cookie header")) - } - - return cookie, nil -} - -// This structure is returned by ServerSetup. It consists of a list of -// BindAddrs, an address for the ORPort, an address for the extended ORPort (if -// any), and an authentication cookie (if any). -type ServerInfo struct { - BindAddrs []BindAddr - OrAddr *net.TCPAddr - ExtendedOrAddr *net.TCPAddr - AuthCookie []byte -} - -// Check the server pluggable transports environments, emitting an error message -// and exiting the program if any error is encountered. Resolves the various -// requested bind addresses, the server ORPort and extended ORPort, and reads -// the auth cookie file. Returns a ServerInfo struct. 
-func ServerSetup(methodNames []string) ServerInfo { - var info ServerInfo - var err error - - ver := getManagedTransportVer() - if ver == "" { - VersionError("no-version") - } else { - Line("VERSION", ver) - } - - var orPort = getenvRequired("TOR_PT_ORPORT") - info.OrAddr, err = net.ResolveTCPAddr("tcp", orPort) - if err != nil { - EnvError(fmt.Sprintf("cannot resolve TOR_PT_ORPORT %q: %s", orPort, err.Error())) - } - - info.BindAddrs = getServerBindAddrs(methodNames) - if len(info.BindAddrs) == 0 { - SmethodsDone() - os.Exit(1) - } - - var extendedOrPort = getenv("TOR_PT_EXTENDED_SERVER_PORT") - if extendedOrPort != "" { - info.ExtendedOrAddr, err = net.ResolveTCPAddr("tcp", extendedOrPort) - if err != nil { - EnvError(fmt.Sprintf("cannot resolve TOR_PT_EXTENDED_SERVER_PORT %q: %s", extendedOrPort, err.Error())) - } - } - - var authCookieFilename = getenv("TOR_PT_AUTH_COOKIE_FILE") - if authCookieFilename != "" { - info.AuthCookie, err = readAuthCookieFile(authCookieFilename) - if err != nil { - EnvError(fmt.Sprintf("error reading TOR_PT_AUTH_COOKIE_FILE %q: %s", authCookieFilename, err.Error())) - } - } - - return info -} - -// See 217-ext-orport-auth.txt section 4.2.1.3. -func computeServerHash(info *ServerInfo, clientNonce, serverNonce []byte) []byte { - h := hmac.New(sha256.New, info.AuthCookie) - io.WriteString(h, "ExtORPort authentication server-to-client hash") - h.Write(clientNonce) - h.Write(serverNonce) - return h.Sum([]byte{}) -} - -// See 217-ext-orport-auth.txt section 4.2.1.3. -func computeClientHash(info *ServerInfo, clientNonce, serverNonce []byte) []byte { - h := hmac.New(sha256.New, info.AuthCookie) - io.WriteString(h, "ExtORPort authentication client-to-server hash") - h.Write(clientNonce) - h.Write(serverNonce) - return h.Sum([]byte{}) -} - -func extOrPortAuthenticate(s *net.TCPConn, info *ServerInfo) error { - r := bufio.NewReader(s) - - // Read auth types. 217-ext-orport-auth.txt section 4.1. 
- var authTypes [256]bool - var count int - for count = 0; count < 256; count++ { - b, err := r.ReadByte() - if err != nil { - return err - } - if b == 0 { - break - } - authTypes[b] = true - } - if count >= 256 { - return errors.New(fmt.Sprintf("read 256 auth types without seeing \x00")) - } - - // We support only type 1, SAFE_COOKIE. - if !authTypes[1] { - return errors.New(fmt.Sprintf("server didn't offer auth type 1")) - } - _, err := s.Write([]byte{1}) - if err != nil { - return err - } - - clientNonce := make([]byte, 32) - clientHash := make([]byte, 32) - serverNonce := make([]byte, 32) - serverHash := make([]byte, 32) - - _, err = io.ReadFull(rand.Reader, clientNonce) - if err != nil { - return err - } - _, err = s.Write(clientNonce) - if err != nil { - return err - } - - _, err = io.ReadFull(r, serverHash) - if err != nil { - return err - } - _, err = io.ReadFull(r, serverNonce) - if err != nil { - return err - } - - expectedServerHash := computeServerHash(info, clientNonce, serverNonce) - if subtle.ConstantTimeCompare(serverHash, expectedServerHash) != 1 { - return errors.New(fmt.Sprintf("mismatch in server hash")) - } - - clientHash = computeClientHash(info, clientNonce, serverNonce) - _, err = s.Write(clientHash) - if err != nil { - return err - } - - status := make([]byte, 1) - _, err = io.ReadFull(r, status) - if err != nil { - return err - } - if status[0] != 1 { - return errors.New(fmt.Sprintf("server rejected authentication")) - } - - if r.Buffered() != 0 { - return errors.New(fmt.Sprintf("%d bytes left after extended OR port authentication", r.Buffered())) - } - - return nil -} - -// See section 3.1 of 196-transport-control-ports.txt. 
-const ( - extOrCmdDone = 0x0000 - extOrCmdUserAddr = 0x0001 - extOrCmdTransport = 0x0002 - extOrCmdOkay = 0x1000 - extOrCmdDeny = 0x1001 -) - -func extOrPortWriteCommand(s *net.TCPConn, cmd uint16, body []byte) error { - var buf bytes.Buffer - if len(body) > 65535 { - return errors.New("command exceeds maximum length of 65535") - } - err := binary.Write(&buf, binary.BigEndian, cmd) - if err != nil { - return err - } - err = binary.Write(&buf, binary.BigEndian, uint16(len(body))) - if err != nil { - return err - } - err = binary.Write(&buf, binary.BigEndian, body) - if err != nil { - return err - } - _, err = s.Write(buf.Bytes()) - if err != nil { - return err - } - - return nil -} - -// Send a USERADDR command on s. See section 3.1.2.1 of -// 196-transport-control-ports.txt. -func extOrPortSendUserAddr(s *net.TCPConn, conn net.Conn) error { - return extOrPortWriteCommand(s, extOrCmdUserAddr, []byte(conn.RemoteAddr().String())) -} - -// Send a TRANSPORT command on s. See section 3.1.2.2 of -// 196-transport-control-ports.txt. -func extOrPortSendTransport(s *net.TCPConn, methodName string) error { - return extOrPortWriteCommand(s, extOrCmdTransport, []byte(methodName)) -} - -// Send a DONE command on s. See section 3.1 of 196-transport-control-ports.txt. -func extOrPortSendDone(s *net.TCPConn) error { - return extOrPortWriteCommand(s, extOrCmdDone, []byte{}) -} - -func extOrPortRecvCommand(s *net.TCPConn) (cmd uint16, body []byte, err error) { - var bodyLen uint16 - data := make([]byte, 4) - - _, err = io.ReadFull(s, data) - if err != nil { - return - } - buf := bytes.NewBuffer(data) - err = binary.Read(buf, binary.BigEndian, &cmd) - if err != nil { - return - } - err = binary.Read(buf, binary.BigEndian, &bodyLen) - if err != nil { - return - } - body = make([]byte, bodyLen) - _, err = io.ReadFull(s, body) - if err != nil { - return - } - - return cmd, body, err -} - -// Send USERADDR and TRANSPORT commands followed by a DONE command. 
Wait for an -// OKAY or DENY response command from the server. Returns nil if and only if -// OKAY is received. -func extOrPortSetup(s *net.TCPConn, conn net.Conn, methodName string) error { - var err error - - err = extOrPortSendUserAddr(s, conn) - if err != nil { - return err - } - err = extOrPortSendTransport(s, methodName) - if err != nil { - return err - } - err = extOrPortSendDone(s) - if err != nil { - return err - } - cmd, _, err := extOrPortRecvCommand(s) - if err != nil { - return err - } - if cmd == extOrCmdDeny { - return errors.New("server returned DENY after our USERADDR and DONE") - } else if cmd != extOrCmdOkay { - return errors.New(fmt.Sprintf("server returned unknown command 0x%04x after our USERADDR and DONE", cmd)) - } - - return nil -} - -// Connect to info.ExtendedOrAddr if defined, or else info.OrAddr, and return an -// open *net.TCPConn. If connecting to the extended OR port, extended OR port -// authentication à la 217-ext-orport-auth.txt is done before returning; an -// error is returned if authentication fails. 
-func ConnectOr(info *ServerInfo, conn net.Conn, methodName string) (*net.TCPConn, error) { - if info.ExtendedOrAddr == nil { - return net.DialTCP("tcp", nil, info.OrAddr) - } - - s, err := net.DialTCP("tcp", nil, info.ExtendedOrAddr) - if err != nil { - return nil, err - } - s.SetDeadline(time.Now().Add(5 * time.Second)) - err = extOrPortAuthenticate(s, info) - if err != nil { - s.Close() - return nil, err - } - err = extOrPortSetup(s, conn, methodName) - if err != nil { - s.Close() - return nil, err - } - s.SetDeadline(time.Time{}) - - return s, nil -} diff --git a/src/pt/pt_test.go b/src/pt/pt_test.go deleted file mode 100644 index cc7924a..0000000 --- a/src/pt/pt_test.go +++ /dev/null @@ -1,61 +0,0 @@ -package pt - -import "os" -import "testing" - -func stringIsSafe(s string) bool { - for _, c := range []byte(s) { - if c == '\x00' || c == '\n' || c > 127 { - return false - } - } - return true -} - -func TestEscape(t *testing.T) { - tests := [...]string{ - "", - "abc", - "a\nb", - "a\b", - "ab\", - "ab\\n", - "ab\n\", - } - - check := func(input string) { - output := escape(input) - if !stringIsSafe(output) { - t.Errorf("escape(%q) → %q", input, output) - } - } - for _, input := range tests { - check(input) - } - for b := 0; b < 256; b++ { - // check one-byte string with each byte value 0–255 - check(string([]byte{byte(b)})) - // check UTF-8 encoding of each character 0–255 - check(string(b)) - } -} - -func TestGetManagedTransportVer(t *testing.T) { - tests := [...]struct { - input, expected string - }{ - {"1", "1"}, - {"1,1", "1"}, - {"1,2", "1"}, - {"2,1", "1"}, - {"2", ""}, - } - - for _, test := range tests { - os.Setenv("TOR_PT_MANAGED_TRANSPORT_VER", test.input) - output := getManagedTransportVer() - if output != test.expected { - t.Errorf("%q → %q (expected %q)", test.input, output, test.expected) - } - } -} diff --git a/src/pt/socks/socks.go b/src/pt/socks/socks.go deleted file mode 100644 index 788d53c..0000000 --- a/src/pt/socks/socks.go +++ /dev/null 
@@ -1,107 +0,0 @@ -// SOCKS4a server library. - -package socks - -import ( - "bufio" - "errors" - "fmt" - "io" - "net" -) - -const ( - socksVersion = 0x04 - socksCmdConnect = 0x01 - socksResponseVersion = 0x00 - socksRequestGranted = 0x5a - socksRequestFailed = 0x5b -) - -// Read a SOCKS4a connect request, and call the given connect callback with the -// requested destination string. If the callback returns an error, sends a SOCKS -// request failed message. Otherwise, sends a SOCKS request granted message for -// the destination address returned by the callback. -func AwaitSocks4aConnect(conn *net.TCPConn, connect func(string) (*net.TCPAddr, error)) error { - dest, err := readSocks4aConnect(conn) - if err != nil { - sendSocks4aResponseFailed(conn) - return err - } - destAddr, err := connect(dest) - if err != nil { - sendSocks4aResponseFailed(conn) - return err - } - sendSocks4aResponseGranted(conn, destAddr) - return nil -} - -// Read a SOCKS4a connect request. Returns a "host:port" string. 
-func readSocks4aConnect(s io.Reader) (string, error) { - r := bufio.NewReader(s) - - var h [8]byte - n, err := io.ReadFull(r, h[:]) - if err != nil { - return "", errors.New(fmt.Sprintf("after %d bytes of SOCKS header: %s", n, err)) - } - if h[0] != socksVersion { - return "", errors.New(fmt.Sprintf("SOCKS header had version 0x%02x, not 0x%02x", h[0], socksVersion)) - } - if h[1] != socksCmdConnect { - return "", errors.New(fmt.Sprintf("SOCKS header had command 0x%02x, not 0x%02x", h[1], socksCmdConnect)) - } - - _, err = r.ReadBytes('\x00') - if err != nil { - return "", errors.New(fmt.Sprintf("reading SOCKS userid: %s", err)) - } - - var port int - var host string - - port = int(h[2])<<8 | int(h[3])<<0 - if h[4] == 0 && h[5] == 0 && h[6] == 0 && h[7] != 0 { - hostBytes, err := r.ReadBytes('\x00') - if err != nil { - return "", errors.New(fmt.Sprintf("reading SOCKS4a destination: %s", err)) - } - host = string(hostBytes[:len(hostBytes)-1]) - } else { - host = net.IPv4(h[4], h[5], h[6], h[7]).String() - } - - if r.Buffered() != 0 { - return "", errors.New(fmt.Sprintf("%d bytes left after SOCKS header", r.Buffered())) - } - - return fmt.Sprintf("%s:%d", host, port), nil -} - -// Send a SOCKS4a response with the given code and address. -func sendSocks4aResponse(w io.Writer, code byte, addr *net.TCPAddr) error { - var resp [8]byte - resp[0] = socksResponseVersion - resp[1] = code - resp[2] = byte((addr.Port >> 8) & 0xff) - resp[3] = byte((addr.Port >> 0) & 0xff) - resp[4] = addr.IP[0] - resp[5] = addr.IP[1] - resp[6] = addr.IP[2] - resp[7] = addr.IP[3] - _, err := w.Write(resp[:]) - return err -} - -var emptyAddr = net.TCPAddr{IP: net.IPv4(0, 0, 0, 0), Port: 0} - -// Send a SOCKS4a response code 0x5a. -func sendSocks4aResponseGranted(w io.Writer, addr *net.TCPAddr) error { - return sendSocks4aResponse(w, socksRequestGranted, addr) -} - -// Send a SOCKS4a response code 0x5b (with an all-zero address). 
-func sendSocks4aResponseFailed(w io.Writer) error { - return sendSocks4aResponse(w, socksRequestFailed, &emptyAddr) -} diff --git a/src/websocket-client/websocket-client.go b/src/websocket-client/websocket-client.go deleted file mode 100644 index 1c3b3b9..0000000 --- a/src/websocket-client/websocket-client.go +++ /dev/null @@ -1,254 +0,0 @@ -// Tor websocket client transport plugin. -// -// Usage: -// ClientTransportPlugin websocket exec ./websocket-client - -package main - -import ( - "code.google.com/p/go.net/websocket" - "flag" - "fmt" - "io" - "net" - "net/url" - "os" - "os/signal" - "sync" - "time" -) - -import "pt" -import "pt/socks" - -const ptMethodName = "websocket" -const socksTimeout = 2 * time.Second -const bufSiz = 1500 - -var logFile = os.Stderr - -// When a connection handler starts, +1 is written to this channel; when it -// ends, -1 is written. -var handlerChan = make(chan int) - -var logMutex sync.Mutex - -func usage() { - fmt.Printf("Usage: %s [OPTIONS]\n", os.Args[0]) - fmt.Printf("WebSocket client pluggable transport for Tor.\n") - fmt.Printf("Works only as a managed proxy.\n") - fmt.Printf("\n") - fmt.Printf(" -h, --help show this help.\n") - fmt.Printf(" --log FILE log messages to FILE (default stderr).\n") - fmt.Printf(" --socks ADDR listen for SOCKS on ADDR.\n") -} - -func Log(format string, v ...interface{}) { - dateStr := time.Now().Format("2006-01-02 15:04:05") - logMutex.Lock() - defer logMutex.Unlock() - msg := fmt.Sprintf(format, v...) - fmt.Fprintf(logFile, "%s %s\n", dateStr, msg) -} - -func proxy(local *net.TCPConn, ws *websocket.Conn) { - var wg sync.WaitGroup - - wg.Add(2) - - // Local-to-WebSocket read loop. 
- go func() { - buf := make([]byte, bufSiz) - var err error - for { - n, er := local.Read(buf[:]) - if n > 0 { - ew := websocket.Message.Send(ws, buf[:n]) - if ew != nil { - err = ew - break - } - } - if er != nil { - err = er - break - } - } - if err != nil && err != io.EOF { - Log("%s", err) - } - local.CloseRead() - ws.Close() - - wg.Done() - }() - - // WebSocket-to-local read loop. - go func() { - var buf []byte - var err error - for { - er := websocket.Message.Receive(ws, &buf) - if er != nil { - err = er - break - } - n, ew := local.Write(buf) - if ew != nil { - err = ew - break - } - if n != len(buf) { - err = io.ErrShortWrite - break - } - } - if err != nil && err != io.EOF { - Log("%s", err) - } - local.CloseWrite() - ws.Close() - - wg.Done() - }() - - wg.Wait() -} - -func handleConnection(conn *net.TCPConn) error { - defer conn.Close() - - handlerChan <- 1 - defer func() { - handlerChan <- -1 - }() - - var ws *websocket.Conn - - conn.SetDeadline(time.Now().Add(socksTimeout)) - err := socks.AwaitSocks4aConnect(conn, func(dest string) (*net.TCPAddr, error) { - // Disable deadline. 
- conn.SetDeadline(time.Time{}) - Log("SOCKS request for %s", dest) - destAddr, err := net.ResolveTCPAddr("tcp", dest) - if err != nil { - return nil, err - } - wsUrl := url.URL{Scheme: "ws", Host: dest} - ws, err = websocket.Dial(wsUrl.String(), "", wsUrl.String()) - if err != nil { - return nil, err - } - Log("WebSocket connection to %s", ws.Config().Location.String()) - return destAddr, nil - }) - if err != nil { - return err - } - defer ws.Close() - proxy(conn, ws) - return nil -} - -func socksAcceptLoop(ln *net.TCPListener) error { - for { - socks, err := ln.AcceptTCP() - if err != nil { - return err - } - go func() { - err := handleConnection(socks) - if err != nil { - Log("SOCKS from %s: %s", socks.RemoteAddr(), err) - } - }() - } - return nil -} - -func startListener(addrStr string) (*net.TCPListener, error) { - addr, err := net.ResolveTCPAddr("tcp", addrStr) - if err != nil { - return nil, err - } - ln, err := net.ListenTCP("tcp", addr) - if err != nil { - return nil, err - } - go func() { - err := socksAcceptLoop(ln) - if err != nil { - Log("accept: %s", err) - } - }() - return ln, nil -} - -func main() { - var logFilename string - var socksAddrStrs = []string{"127.0.0.1:0"} - var socksArg string - - flag.Usage = usage - flag.StringVar(&logFilename, "log", "", "log file to write to") - flag.StringVar(&socksArg, "socks", "", "address on which to listen for SOCKS connections") - flag.Parse() - - if logFilename != "" { - f, err := os.OpenFile(logFilename, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0600) - if err != nil { - fmt.Fprintf(os.Stderr, "Can't open log file %q: %s.\n", logFilename, err.Error()) - os.Exit(1) - } - logFile = f - } - - if socksArg != "" { - socksAddrStrs = []string{socksArg} - } - - Log("starting") - pt.ClientSetup([]string{ptMethodName}) - - listeners := make([]*net.TCPListener, 0) - for _, socksAddrStr := range socksAddrStrs { - ln, err := startListener(socksAddrStr) - if err != nil { - pt.CmethodError(ptMethodName, err.Error()) - } - 
pt.Cmethod(ptMethodName, "socks4", ln.Addr()) - Log("listening on %s", ln.Addr().String()) - listeners = append(listeners, ln) - } - pt.CmethodsDone() - - var numHandlers int = 0 - - signalChan := make(chan os.Signal, 1) - signal.Notify(signalChan, os.Interrupt) - var sigint bool = false - for !sigint { - select { - case n := <-handlerChan: - numHandlers += n - case <-signalChan: - Log("SIGINT") - sigint = true - } - } - - for _, ln := range listeners { - ln.Close() - } - - sigint = false - for numHandlers != 0 && !sigint { - select { - case n := <-handlerChan: - numHandlers += n - case <-signalChan: - Log("SIGINT") - sigint = true - } - } -} diff --git a/src/websocket-server/websocket-server.go b/src/websocket-server/websocket-server.go deleted file mode 100644 index 207be8d..0000000 --- a/src/websocket-server/websocket-server.go +++ /dev/null @@ -1,285 +0,0 @@ -// Tor websocket server transport plugin. -// -// Usage: -// ServerTransportPlugin websocket exec ./websocket-server --port 9901 - -package main - -import ( - "encoding/base64" - "errors" - "flag" - "fmt" - "io" - "net" - "net/http" - "os" - "os/signal" - "sync" - "syscall" - "time" -) - -import "pt" -import "websocket" - -const ptMethodName = "websocket" -const requestTimeout = 10 * time.Second - -// "4/3+1" accounts for possible base64 encoding. -const maxMessageSize = 64*1024*4/3 + 1 - -var logFile = os.Stderr - -var ptInfo pt.ServerInfo - -// When a connection handler starts, +1 is written to this channel; when it -// ends, -1 is written. 
-var handlerChan = make(chan int) - -func usage() { - fmt.Printf("Usage: %s [OPTIONS]\n", os.Args[0]) - fmt.Printf("WebSocket server pluggable transport for Tor.\n") - fmt.Printf("Works only as a managed proxy.\n") - fmt.Printf("\n") - fmt.Printf(" -h, --help show this help.\n") - fmt.Printf(" --log FILE log messages to FILE (default stderr).\n") - fmt.Printf(" --port PORT listen on PORT (overrides Tor's requested port).\n") -} - -var logMutex sync.Mutex - -func log(format string, v ...interface{}) { - dateStr := time.Now().Format("2006-01-02 15:04:05") - logMutex.Lock() - defer logMutex.Unlock() - msg := fmt.Sprintf(format, v...) - fmt.Fprintf(logFile, "%s %s\n", dateStr, msg) -} - -// An abstraction that makes an underlying WebSocket connection look like an -// io.ReadWriteCloser. It internally takes care of things like base64 encoding -// and decoding. -type webSocketConn struct { - Ws *websocket.WebSocket - Base64 bool - messageBuf []byte -} - -// Implements io.Reader. -func (conn *webSocketConn) Read(b []byte) (n int, err error) { - for len(conn.messageBuf) == 0 { - var m websocket.Message - m, err = conn.Ws.ReadMessage() - if err != nil { - return - } - if m.Opcode == 8 { - err = io.EOF - return - } - if conn.Base64 { - if m.Opcode != 1 { - err = errors.New(fmt.Sprintf("got non-text opcode %d with the base64 subprotocol", m.Opcode)) - return - } - conn.messageBuf = make([]byte, base64.StdEncoding.DecodedLen(len(m.Payload))) - var num int - num, err = base64.StdEncoding.Decode(conn.messageBuf, m.Payload) - if err != nil { - return - } - conn.messageBuf = conn.messageBuf[:num] - } else { - if m.Opcode != 2 { - err = errors.New(fmt.Sprintf("got non-binary opcode %d with no subprotocol", m.Opcode)) - return - } - conn.messageBuf = m.Payload - } - } - - n = copy(b, conn.messageBuf) - conn.messageBuf = conn.messageBuf[n:] - - return -} - -// Implements io.Writer. 
-func (conn *webSocketConn) Write(b []byte) (n int, err error) { - if conn.Base64 { - buf := make([]byte, base64.StdEncoding.EncodedLen(len(b))) - base64.StdEncoding.Encode(buf, b) - err = conn.Ws.WriteMessage(1, buf) - if err != nil { - return - } - n = len(b) - } else { - err = conn.Ws.WriteMessage(2, b) - n = len(b) - } - return -} - -// Implements io.Closer. -func (conn *webSocketConn) Close() error { - // Ignore any error in trying to write a Close frame. - _ = conn.Ws.WriteFrame(8, nil) - return conn.Ws.Conn.Close() -} - -// Create a new webSocketConn. -func newWebSocketConn(ws *websocket.WebSocket) webSocketConn { - var conn webSocketConn - conn.Ws = ws - conn.Base64 = (ws.Subprotocol == "base64") - return conn -} - -// Copy from WebSocket to socket and vice versa. -func proxy(local *net.TCPConn, conn *webSocketConn) { - var wg sync.WaitGroup - - wg.Add(2) - - go func() { - _, err := io.Copy(conn, local) - if err != nil { - log("error copying ORPort to WebSocket") - } - local.CloseRead() - conn.Close() - wg.Done() - }() - - go func() { - _, err := io.Copy(local, conn) - if err != nil { - log("error copying WebSocket to ORPort") - } - local.CloseWrite() - conn.Close() - wg.Done() - }() - - wg.Wait() -} - -func webSocketHandler(ws *websocket.WebSocket) { - // Undo timeouts on HTTP request handling. 
- ws.Conn.SetDeadline(time.Time{}) - conn := newWebSocketConn(ws) - - handlerChan <- 1 - defer func() { - handlerChan <- -1 - }() - - s, err := pt.ConnectOr(&ptInfo, ws.Conn, ptMethodName) - if err != nil { - log("Failed to connect to ORPort: " + err.Error()) - return - } - - proxy(s, &conn) -} - -func startListener(addr *net.TCPAddr) (*net.TCPListener, error) { - ln, err := net.ListenTCP("tcp", addr) - if err != nil { - return nil, err - } - go func() { - var config websocket.Config - config.Subprotocols = []string{"base64"} - config.MaxMessageSize = maxMessageSize - s := &http.Server{ - Handler: config.Handler(webSocketHandler), - ReadTimeout: requestTimeout, - } - err = s.Serve(ln) - if err != nil { - log("http.Serve: " + err.Error()) - } - }() - return ln, nil -} - -func main() { - var logFilename string - var port int - - flag.Usage = usage - flag.StringVar(&logFilename, "log", "", "log file to write to") - flag.IntVar(&port, "port", 0, "port to listen on if unspecified by Tor") - flag.Parse() - - if logFilename != "" { - f, err := os.OpenFile(logFilename, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0600) - if err != nil { - fmt.Fprintf(os.Stderr, "Can't open log file %q: %s.\n", logFilename, err.Error()) - os.Exit(1) - } - logFile = f - } - - log("starting") - ptInfo = pt.ServerSetup([]string{ptMethodName}) - - listeners := make([]*net.TCPListener, 0) - for _, bindAddr := range ptInfo.BindAddrs { - // Override tor's requested port (which is 0 if this transport - // has not been run before) with the one requested by the --port - // option. 
- if port != 0 { - bindAddr.Addr.Port = port - } - - ln, err := startListener(bindAddr.Addr) - if err != nil { - pt.SmethodError(bindAddr.MethodName, err.Error()) - continue - } - pt.Smethod(bindAddr.MethodName, ln.Addr()) - log("listening on %s", ln.Addr().String()) - listeners = append(listeners, ln) - } - pt.SmethodsDone() - - var numHandlers int = 0 - var sig os.Signal - sigChan := make(chan os.Signal, 1) - signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) - - sig = nil - for sig == nil { - select { - case n := <-handlerChan: - numHandlers += n - case sig = <-sigChan: - } - } - log("Got first signal %q with %d running handlers.", sig, numHandlers) - for _, ln := range listeners { - ln.Close() - } - - if sig == syscall.SIGTERM { - log("Caught signal %q, exiting.", sig) - return - } - - sig = nil - for sig == nil && numHandlers != 0 { - select { - case n := <-handlerChan: - numHandlers += n - log("%d remaining handlers.", numHandlers) - case sig = <-sigChan: - } - } - if sig != nil { - log("Got second signal %q with %d running handlers.", sig, numHandlers) - } -} diff --git a/src/websocket/websocket.go b/src/websocket/websocket.go deleted file mode 100644 index dc228d1..0000000 --- a/src/websocket/websocket.go +++ /dev/null @@ -1,431 +0,0 @@ -// WebSocket library. Only the RFC 6455 variety of WebSocket is supported. -// -// Reading and writing is strictly per-frame (or per-message). There is no way -// to partially read a frame. Config.MaxMessageSize affords control of the -// maximum buffering of messages. -// -// The reason for using this custom implementation instead of -// code.google.com/p/go.net/websocket is that the latter has problems with long -// messages and does not support server subprotocols. 
-// "Denial of Service Protection in Go HTTP Servers" -// https://code.google.com/p/go/issues/detail?id=2093 -// "go.websocket: Read/Copy fail with long frames" -// https://code.google.com/p/go/issues/detail?id=2134 -// http://golang.org/pkg/net/textproto/#pkg-bugs -// "To let callers manage exposure to denial of service attacks, Reader should -// allow them to set and reset a limit on the number of bytes read from the -// connection." -// "websocket.Dial doesn't limit response header length as http.Get does" -// https://groups.google.com/forum/?fromgroups=#!topic/golang-nuts/2Tge6U8-QYI -// -// Example usage: -// -// func doSomething(ws *WebSocket) { -// } -// var config websocket.Config -// config.Subprotocols = []string{"base64"} -// config.MaxMessageSize = 2500 -// http.Handle("/", config.Handler(doSomething)) -// err = http.ListenAndServe(":8080", nil) - -package websocket - -import ( - "bufio" - "bytes" - "crypto/rand" - "crypto/sha1" - "encoding/base64" - "encoding/binary" - "errors" - "fmt" - "io" - "net" - "net/http" - "strings" -) - -// Settings for potential WebSocket connections. Subprotocols is a list of -// supported subprotocols as in RFC 6455 section 1.9. When answering client -// requests, the first of the client's requests subprotocols that is also in -// this list (if any) will be used as the subprotocol for the connection. -// MaxMessageSize is a limit on buffering messages. -type Config struct { - Subprotocols []string - MaxMessageSize int -} - -// Representation of a WebSocket frame. The Payload is always without masking. -type Frame struct { - Fin bool - Opcode byte - Payload []byte -} - -// Return true iff the frame's opcode says it is a control frame. -func (frame *Frame) IsControl() bool { - return (frame.Opcode & 0x08) != 0 -} - -// Representation of a WebSocket message. The Payload is always without masking. -type Message struct { - Opcode byte - Payload []byte -} - -// A WebSocket connection after hijacking from HTTP. 
-type WebSocket struct { - // Conn and ReadWriter from http.ResponseWriter.Hijack. - Conn net.Conn - Bufrw *bufio.ReadWriter - // Whether we are a client or a server has implications for masking. - IsClient bool - // Set from a parent Config. - MaxMessageSize int - // The single selected subprotocol after negotiation, or "". - Subprotocol string - // Buffer for message payloads, which may be interrupted by control - // messages. - messageBuf bytes.Buffer -} - -func applyMask(payload []byte, maskKey [4]byte) { - for i := 0; i < len(payload); i++ { - payload[i] = payload[i] ^ maskKey[i%4] - } -} - -func (ws *WebSocket) maxMessageSize() int { - if ws.MaxMessageSize == 0 { - return 64000 - } - return ws.MaxMessageSize -} - -// Read a single frame from the WebSocket. -func (ws *WebSocket) ReadFrame() (frame Frame, err error) { - var b byte - err = binary.Read(ws.Bufrw, binary.BigEndian, &b) - if err != nil { - return - } - frame.Fin = (b & 0x80) != 0 - frame.Opcode = b & 0x0f - err = binary.Read(ws.Bufrw, binary.BigEndian, &b) - if err != nil { - return - } - masked := (b & 0x80) != 0 - - payloadLen := uint64(b & 0x7f) - if payloadLen == 126 { - var short uint16 - err = binary.Read(ws.Bufrw, binary.BigEndian, &short) - if err != nil { - return - } - payloadLen = uint64(short) - } else if payloadLen == 127 { - var long uint64 - err = binary.Read(ws.Bufrw, binary.BigEndian, &long) - if err != nil { - return - } - payloadLen = long - } - if payloadLen > uint64(ws.maxMessageSize()) { - err = errors.New(fmt.Sprintf("frame payload length of %d exceeds maximum of %d", payloadLen, ws.MaxMessageSize)) - return - } - - maskKey := [4]byte{} - if masked { - if ws.IsClient { - err = errors.New("client got masked frame") - return - } - err = binary.Read(ws.Bufrw, binary.BigEndian, &maskKey) - if err != nil { - return - } - } else { - if !ws.IsClient { - err = errors.New("server got unmasked frame") - return - } - } - - frame.Payload = make([]byte, payloadLen) - _, err = 
io.ReadFull(ws.Bufrw, frame.Payload) - if err != nil { - return - } - if masked { - applyMask(frame.Payload, maskKey) - } - - return frame, nil -} - -// Read a single message from the WebSocket. Multiple fragmented frames are -// combined into a single message before being returned. Non-control messages -// may be interrupted by control frames. The control frames are returned as -// individual messages before the message that they interrupt. -func (ws *WebSocket) ReadMessage() (message Message, err error) { - var opcode byte = 0 - for { - var frame Frame - frame, err = ws.ReadFrame() - if err != nil { - return - } - if frame.IsControl() { - if !frame.Fin { - err = errors.New("control frame has fin bit unset") - return - } - message.Opcode = frame.Opcode - message.Payload = frame.Payload - return message, nil - } - - if opcode == 0 { - if frame.Opcode == 0 { - err = errors.New("first frame has opcode 0") - return - } - opcode = frame.Opcode - } else { - if frame.Opcode != 0 { - err = errors.New(fmt.Sprintf("non-first frame has nonzero opcode %d", frame.Opcode)) - return - } - } - if ws.messageBuf.Len()+len(frame.Payload) > ws.MaxMessageSize { - err = errors.New(fmt.Sprintf("message payload length of %d exceeds maximum of %d", - ws.messageBuf.Len()+len(frame.Payload), ws.MaxMessageSize)) - return - } - ws.messageBuf.Write(frame.Payload) - if frame.Fin { - break - } - } - message.Opcode = opcode - message.Payload = ws.messageBuf.Bytes() - ws.messageBuf.Reset() - - return message, nil -} - -// Write a single frame to the WebSocket stream. Destructively masks payload in -// place if ws.IsClient. Frames are always unfragmented. 
-func (ws *WebSocket) WriteFrame(opcode byte, payload []byte) (err error) { - if opcode >= 16 { - err = errors.New(fmt.Sprintf("opcode %d is >= 16", opcode)) - return - } - ws.Bufrw.WriteByte(0x80 | opcode) - - var maskBit byte - var maskKey [4]byte - if ws.IsClient { - _, err = io.ReadFull(rand.Reader, maskKey[:]) - if err != nil { - return - } - applyMask(payload, maskKey) - maskBit = 0x80 - } else { - maskBit = 0x00 - } - - if len(payload) < 126 { - ws.Bufrw.WriteByte(maskBit | byte(len(payload))) - } else if len(payload) <= 0xffff { - ws.Bufrw.WriteByte(maskBit | 126) - binary.Write(ws.Bufrw, binary.BigEndian, uint16(len(payload))) - } else { - ws.Bufrw.WriteByte(maskBit | 127) - binary.Write(ws.Bufrw, binary.BigEndian, uint64(len(payload))) - } - - if ws.IsClient { - _, err = ws.Bufrw.Write(maskKey[:]) - if err != nil { - return - } - } - _, err = ws.Bufrw.Write(payload) - if err != nil { - return - } - - ws.Bufrw.Flush() - - return -} - -// Write a single message to the WebSocket stream. Destructively masks payload -// in place if ws.IsClient. Messages are always sent as a single unfragmented -// frame. -func (ws *WebSocket) WriteMessage(opcode byte, payload []byte) (err error) { - return ws.WriteFrame(opcode, payload) -} - -// Split a string on commas and trim whitespace. -func commaSplit(s string) []string { - var result []string - if strings.TrimSpace(s) == "" { - return result - } - for _, e := range strings.Split(s, ",") { - result = append(result, strings.TrimSpace(e)) - } - return result -} - -// Returns true iff one of the strings in haystack is needle (case-insensitive). -func containsCase(haystack []string, needle string) bool { - for _, e := range haystack { - if strings.ToLower(e) == strings.ToLower(needle) { - return true - } - } - return false -} - -// One-step SHA-1 hash of a string. 
-func sha1Hash(data string) []byte { - h := sha1.New() - h.Write([]byte(data)) - return h.Sum(nil) -} - -func httpError(w http.ResponseWriter, bufrw *bufio.ReadWriter, code int) { - w.Header().Set("Connection", "close") - bufrw.WriteString(fmt.Sprintf("HTTP/1.0 %d %s\r\n", code, http.StatusText(code))) - w.Header().Write(bufrw) - bufrw.WriteString("\r\n") - bufrw.Flush() -} - -// An implementation of http.Handler with a Config. The ServeHTTP function calls -// Callback assuming WebSocket HTTP negotiation is successful. -type HTTPHandler struct { - Config *Config - Callback func(*WebSocket) -} - -// Implements the http.Handler interface. -func (handler *HTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { - conn, bufrw, err := w.(http.Hijacker).Hijack() - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - defer conn.Close() - - // See RFC 6455 section 4.2.1 for this sequence of checks. - - // 1. An HTTP/1.1 or higher GET request, including a "Request-URI"... - if req.Method != "GET" { - httpError(w, bufrw, http.StatusMethodNotAllowed) - return - } - if req.URL.Path != "/" { - httpError(w, bufrw, http.StatusNotFound) - return - } - // 2. A |Host| header field containing the server's authority. - // We deliberately skip this test. - // 3. An |Upgrade| header field containing the value "websocket", - // treated as an ASCII case-insensitive value. - if !containsCase(commaSplit(req.Header.Get("Upgrade")), "websocket") { - httpError(w, bufrw, http.StatusBadRequest) - return - } - // 4. A |Connection| header field that includes the token "Upgrade", - // treated as an ASCII case-insensitive value. - if !containsCase(commaSplit(req.Header.Get("Connection")), "Upgrade") { - httpError(w, bufrw, http.StatusBadRequest) - return - } - // 5. A |Sec-WebSocket-Key| header field with a base64-encoded value - // that, when decoded, is 16 bytes in length. 
- websocketKey := req.Header.Get("Sec-WebSocket-Key") - key, err := base64.StdEncoding.DecodeString(websocketKey) - if err != nil || len(key) != 16 { - httpError(w, bufrw, http.StatusBadRequest) - return - } - // 6. A |Sec-WebSocket-Version| header field, with a value of 13. - // We also allow 8 from draft-ietf-hybi-thewebsocketprotocol-10. - var knownVersions = []string{"8", "13"} - websocketVersion := req.Header.Get("Sec-WebSocket-Version") - if !containsCase(knownVersions, websocketVersion) { - // "If this version does not match a version understood by the - // server, the server MUST abort the WebSocket handshake - // described in this section and instead send an appropriate - // HTTP error code (such as 426 Upgrade Required) and a - // |Sec-WebSocket-Version| header field indicating the - // version(s) the server is capable of understanding." - w.Header().Set("Sec-WebSocket-Version", strings.Join(knownVersions, ", ")) - httpError(w, bufrw, 426) - return - } - // 7. Optionally, an |Origin| header field. - // 8. Optionally, a |Sec-WebSocket-Protocol| header field, with a list of - // values indicating which protocols the client would like to speak, ordered - // by preference. - clientProtocols := commaSplit(req.Header.Get("Sec-WebSocket-Protocol")) - // 9. Optionally, a |Sec-WebSocket-Extensions| header field... - // 10. Optionally, other header fields... - - var ws WebSocket - ws.Conn = conn - ws.Bufrw = bufrw - ws.IsClient = false - ws.MaxMessageSize = handler.Config.MaxMessageSize - - // See RFC 6455 section 4.2.2, item 5 for these steps. - - // 1. A Status-Line with a 101 response code as per RFC 2616. - bufrw.WriteString(fmt.Sprintf("HTTP/1.0 %d %s\r\n", http.StatusSwitchingProtocols, http.StatusText(http.StatusSwitchingProtocols))) - // 2. An |Upgrade| header field with value "websocket" as per RFC 2616. - w.Header().Set("Upgrade", "websocket") - // 3. A |Connection| header field with value "Upgrade". - w.Header().Set("Connection", "Upgrade") - // 4. 
A |Sec-WebSocket-Accept| header field. The value of this header - // field is constructed by concatenating /key/, defined above in step 4 - // in Section 4.2.2, with the string - // "258EAFA5-E914-47DA-95CA-C5AB0DC85B11", taking the SHA-1 hash of this - // concatenated value to obtain a 20-byte value and base64-encoding (see - // Section 4 of [RFC4648]) this 20-byte hash. - const magicGUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" - acceptKey := base64.StdEncoding.EncodeToString(sha1Hash(websocketKey + magicGUID)) - w.Header().Set("Sec-WebSocket-Accept", acceptKey) - // 5. Optionally, a |Sec-WebSocket-Protocol| header field, with a value - // /subprotocol/ as defined in step 4 in Section 4.2.2. - for _, clientProto := range clientProtocols { - for _, serverProto := range handler.Config.Subprotocols { - if clientProto == serverProto { - ws.Subprotocol = clientProto - w.Header().Set("Sec-WebSocket-Protocol", clientProto) - break - } - } - } - // 6. Optionally, a |Sec-WebSocket-Extensions| header field... - w.Header().Write(bufrw) - bufrw.WriteString("\r\n") - bufrw.Flush() - - // Call the WebSocket-specific handler. - handler.Callback(&ws) -} - -// Return an http.Handler with the given callback function. -func (config *Config) Handler(callback func(*WebSocket)) http.Handler { - return &HTTPHandler{config, callback} -} diff --git a/websocket-client/websocket-client.go b/websocket-client/websocket-client.go new file mode 100644 index 0000000..7f838bb --- /dev/null +++ b/websocket-client/websocket-client.go @@ -0,0 +1,254 @@ +// Tor websocket client transport plugin. 
+// +// Usage: +// ClientTransportPlugin websocket exec ./websocket-client + +package main + +import ( + "code.google.com/p/go.net/websocket" + "flag" + "fmt" + "io" + "net" + "net/url" + "os" + "os/signal" + "sync" + "time" +) + +import "../pt" +import "../pt/socks" + +const ptMethodName = "websocket" +const socksTimeout = 2 * time.Second +const bufSiz = 1500 + +var logFile = os.Stderr + +// When a connection handler starts, +1 is written to this channel; when it +// ends, -1 is written. +var handlerChan = make(chan int) + +var logMutex sync.Mutex + +func usage() { + fmt.Printf("Usage: %s [OPTIONS]\n", os.Args[0]) + fmt.Printf("WebSocket client pluggable transport for Tor.\n") + fmt.Printf("Works only as a managed proxy.\n") + fmt.Printf("\n") + fmt.Printf(" -h, --help show this help.\n") + fmt.Printf(" --log FILE log messages to FILE (default stderr).\n") + fmt.Printf(" --socks ADDR listen for SOCKS on ADDR.\n") +} + +func Log(format string, v ...interface{}) { + dateStr := time.Now().Format("2006-01-02 15:04:05") + logMutex.Lock() + defer logMutex.Unlock() + msg := fmt.Sprintf(format, v...) + fmt.Fprintf(logFile, "%s %s\n", dateStr, msg) +} + +func proxy(local *net.TCPConn, ws *websocket.Conn) { + var wg sync.WaitGroup + + wg.Add(2) + + // Local-to-WebSocket read loop. + go func() { + buf := make([]byte, bufSiz) + var err error + for { + n, er := local.Read(buf[:]) + if n > 0 { + ew := websocket.Message.Send(ws, buf[:n]) + if ew != nil { + err = ew + break + } + } + if er != nil { + err = er + break + } + } + if err != nil && err != io.EOF { + Log("%s", err) + } + local.CloseRead() + ws.Close() + + wg.Done() + }() + + // WebSocket-to-local read loop. 
+ go func() { + var buf []byte + var err error + for { + er := websocket.Message.Receive(ws, &buf) + if er != nil { + err = er + break + } + n, ew := local.Write(buf) + if ew != nil { + err = ew + break + } + if n != len(buf) { + err = io.ErrShortWrite + break + } + } + if err != nil && err != io.EOF { + Log("%s", err) + } + local.CloseWrite() + ws.Close() + + wg.Done() + }() + + wg.Wait() +} + +func handleConnection(conn *net.TCPConn) error { + defer conn.Close() + + handlerChan <- 1 + defer func() { + handlerChan <- -1 + }() + + var ws *websocket.Conn + + conn.SetDeadline(time.Now().Add(socksTimeout)) + err := socks.AwaitSocks4aConnect(conn, func(dest string) (*net.TCPAddr, error) { + // Disable deadline. + conn.SetDeadline(time.Time{}) + Log("SOCKS request for %s", dest) + destAddr, err := net.ResolveTCPAddr("tcp", dest) + if err != nil { + return nil, err + } + wsUrl := url.URL{Scheme: "ws", Host: dest} + ws, err = websocket.Dial(wsUrl.String(), "", wsUrl.String()) + if err != nil { + return nil, err + } + Log("WebSocket connection to %s", ws.Config().Location.String()) + return destAddr, nil + }) + if err != nil { + return err + } + defer ws.Close() + proxy(conn, ws) + return nil +} + +func socksAcceptLoop(ln *net.TCPListener) error { + for { + socks, err := ln.AcceptTCP() + if err != nil { + return err + } + go func() { + err := handleConnection(socks) + if err != nil { + Log("SOCKS from %s: %s", socks.RemoteAddr(), err) + } + }() + } + return nil +} + +func startListener(addrStr string) (*net.TCPListener, error) { + addr, err := net.ResolveTCPAddr("tcp", addrStr) + if err != nil { + return nil, err + } + ln, err := net.ListenTCP("tcp", addr) + if err != nil { + return nil, err + } + go func() { + err := socksAcceptLoop(ln) + if err != nil { + Log("accept: %s", err) + } + }() + return ln, nil +} + +func main() { + var logFilename string + var socksAddrStrs = []string{"127.0.0.1:0"} + var socksArg string + + flag.Usage = usage + flag.StringVar(&logFilename, 
"log", "", "log file to write to") + flag.StringVar(&socksArg, "socks", "", "address on which to listen for SOCKS connections") + flag.Parse() + + if logFilename != "" { + f, err := os.OpenFile(logFilename, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0600) + if err != nil { + fmt.Fprintf(os.Stderr, "Can't open log file %q: %s.\n", logFilename, err.Error()) + os.Exit(1) + } + logFile = f + } + + if socksArg != "" { + socksAddrStrs = []string{socksArg} + } + + Log("starting") + pt.ClientSetup([]string{ptMethodName}) + + listeners := make([]*net.TCPListener, 0) + for _, socksAddrStr := range socksAddrStrs { + ln, err := startListener(socksAddrStr) + if err != nil { + pt.CmethodError(ptMethodName, err.Error()) + } + pt.Cmethod(ptMethodName, "socks4", ln.Addr()) + Log("listening on %s", ln.Addr().String()) + listeners = append(listeners, ln) + } + pt.CmethodsDone() + + var numHandlers int = 0 + + signalChan := make(chan os.Signal, 1) + signal.Notify(signalChan, os.Interrupt) + var sigint bool = false + for !sigint { + select { + case n := <-handlerChan: + numHandlers += n + case <-signalChan: + Log("SIGINT") + sigint = true + } + } + + for _, ln := range listeners { + ln.Close() + } + + sigint = false + for numHandlers != 0 && !sigint { + select { + case n := <-handlerChan: + numHandlers += n + case <-signalChan: + Log("SIGINT") + sigint = true + } + } +} diff --git a/websocket-server/websocket-server.go b/websocket-server/websocket-server.go new file mode 100644 index 0000000..e5ed1c5 --- /dev/null +++ b/websocket-server/websocket-server.go @@ -0,0 +1,285 @@ +// Tor websocket server transport plugin. 
+// +// Usage: +// ServerTransportPlugin websocket exec ./websocket-server --port 9901 + +package main + +import ( + "encoding/base64" + "errors" + "flag" + "fmt" + "io" + "net" + "net/http" + "os" + "os/signal" + "sync" + "syscall" + "time" +) + +import "../pt" +import "../websocket" + +const ptMethodName = "websocket" +const requestTimeout = 10 * time.Second + +// "4/3+1" accounts for possible base64 encoding. +const maxMessageSize = 64*1024*4/3 + 1 + +var logFile = os.Stderr + +var ptInfo pt.ServerInfo + +// When a connection handler starts, +1 is written to this channel; when it +// ends, -1 is written. +var handlerChan = make(chan int) + +func usage() { + fmt.Printf("Usage: %s [OPTIONS]\n", os.Args[0]) + fmt.Printf("WebSocket server pluggable transport for Tor.\n") + fmt.Printf("Works only as a managed proxy.\n") + fmt.Printf("\n") + fmt.Printf(" -h, --help show this help.\n") + fmt.Printf(" --log FILE log messages to FILE (default stderr).\n") + fmt.Printf(" --port PORT listen on PORT (overrides Tor's requested port).\n") +} + +var logMutex sync.Mutex + +func log(format string, v ...interface{}) { + dateStr := time.Now().Format("2006-01-02 15:04:05") + logMutex.Lock() + defer logMutex.Unlock() + msg := fmt.Sprintf(format, v...) + fmt.Fprintf(logFile, "%s %s\n", dateStr, msg) +} + +// An abstraction that makes an underlying WebSocket connection look like an +// io.ReadWriteCloser. It internally takes care of things like base64 encoding +// and decoding. +type webSocketConn struct { + Ws *websocket.WebSocket + Base64 bool + messageBuf []byte +} + +// Implements io.Reader. 
+func (conn *webSocketConn) Read(b []byte) (n int, err error) { + for len(conn.messageBuf) == 0 { + var m websocket.Message + m, err = conn.Ws.ReadMessage() + if err != nil { + return + } + if m.Opcode == 8 { + err = io.EOF + return + } + if conn.Base64 { + if m.Opcode != 1 { + err = errors.New(fmt.Sprintf("got non-text opcode %d with the base64 subprotocol", m.Opcode)) + return + } + conn.messageBuf = make([]byte, base64.StdEncoding.DecodedLen(len(m.Payload))) + var num int + num, err = base64.StdEncoding.Decode(conn.messageBuf, m.Payload) + if err != nil { + return + } + conn.messageBuf = conn.messageBuf[:num] + } else { + if m.Opcode != 2 { + err = errors.New(fmt.Sprintf("got non-binary opcode %d with no subprotocol", m.Opcode)) + return + } + conn.messageBuf = m.Payload + } + } + + n = copy(b, conn.messageBuf) + conn.messageBuf = conn.messageBuf[n:] + + return +} + +// Implements io.Writer. +func (conn *webSocketConn) Write(b []byte) (n int, err error) { + if conn.Base64 { + buf := make([]byte, base64.StdEncoding.EncodedLen(len(b))) + base64.StdEncoding.Encode(buf, b) + err = conn.Ws.WriteMessage(1, buf) + if err != nil { + return + } + n = len(b) + } else { + err = conn.Ws.WriteMessage(2, b) + n = len(b) + } + return +} + +// Implements io.Closer. +func (conn *webSocketConn) Close() error { + // Ignore any error in trying to write a Close frame. + _ = conn.Ws.WriteFrame(8, nil) + return conn.Ws.Conn.Close() +} + +// Create a new webSocketConn. +func newWebSocketConn(ws *websocket.WebSocket) webSocketConn { + var conn webSocketConn + conn.Ws = ws + conn.Base64 = (ws.Subprotocol == "base64") + return conn +} + +// Copy from WebSocket to socket and vice versa. 
+func proxy(local *net.TCPConn, conn *webSocketConn) { + var wg sync.WaitGroup + + wg.Add(2) + + go func() { + _, err := io.Copy(conn, local) + if err != nil { + log("error copying ORPort to WebSocket") + } + local.CloseRead() + conn.Close() + wg.Done() + }() + + go func() { + _, err := io.Copy(local, conn) + if err != nil { + log("error copying WebSocket to ORPort") + } + local.CloseWrite() + conn.Close() + wg.Done() + }() + + wg.Wait() +} + +func webSocketHandler(ws *websocket.WebSocket) { + // Undo timeouts on HTTP request handling. + ws.Conn.SetDeadline(time.Time{}) + conn := newWebSocketConn(ws) + + handlerChan <- 1 + defer func() { + handlerChan <- -1 + }() + + s, err := pt.ConnectOr(&ptInfo, ws.Conn, ptMethodName) + if err != nil { + log("Failed to connect to ORPort: " + err.Error()) + return + } + + proxy(s, &conn) +} + +func startListener(addr *net.TCPAddr) (*net.TCPListener, error) { + ln, err := net.ListenTCP("tcp", addr) + if err != nil { + return nil, err + } + go func() { + var config websocket.Config + config.Subprotocols = []string{"base64"} + config.MaxMessageSize = maxMessageSize + s := &http.Server{ + Handler: config.Handler(webSocketHandler), + ReadTimeout: requestTimeout, + } + err = s.Serve(ln) + if err != nil { + log("http.Serve: " + err.Error()) + } + }() + return ln, nil +} + +func main() { + var logFilename string + var port int + + flag.Usage = usage + flag.StringVar(&logFilename, "log", "", "log file to write to") + flag.IntVar(&port, "port", 0, "port to listen on if unspecified by Tor") + flag.Parse() + + if logFilename != "" { + f, err := os.OpenFile(logFilename, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0600) + if err != nil { + fmt.Fprintf(os.Stderr, "Can't open log file %q: %s.\n", logFilename, err.Error()) + os.Exit(1) + } + logFile = f + } + + log("starting") + ptInfo = pt.ServerSetup([]string{ptMethodName}) + + listeners := make([]*net.TCPListener, 0) + for _, bindAddr := range ptInfo.BindAddrs { + // Override tor's requested port 
(which is 0 if this transport + // has not been run before) with the one requested by the --port + // option. + if port != 0 { + bindAddr.Addr.Port = port + } + + ln, err := startListener(bindAddr.Addr) + if err != nil { + pt.SmethodError(bindAddr.MethodName, err.Error()) + continue + } + pt.Smethod(bindAddr.MethodName, ln.Addr()) + log("listening on %s", ln.Addr().String()) + listeners = append(listeners, ln) + } + pt.SmethodsDone() + + var numHandlers int = 0 + var sig os.Signal + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + + sig = nil + for sig == nil { + select { + case n := <-handlerChan: + numHandlers += n + case sig = <-sigChan: + } + } + log("Got first signal %q with %d running handlers.", sig, numHandlers) + for _, ln := range listeners { + ln.Close() + } + + if sig == syscall.SIGTERM { + log("Caught signal %q, exiting.", sig) + return + } + + sig = nil + for sig == nil && numHandlers != 0 { + select { + case n := <-handlerChan: + numHandlers += n + log("%d remaining handlers.", numHandlers) + case sig = <-sigChan: + } + } + if sig != nil { + log("Got second signal %q with %d running handlers.", sig, numHandlers) + } +} diff --git a/websocket/websocket.go b/websocket/websocket.go new file mode 100644 index 0000000..dc228d1 --- /dev/null +++ b/websocket/websocket.go @@ -0,0 +1,431 @@ +// WebSocket library. Only the RFC 6455 variety of WebSocket is supported. +// +// Reading and writing is strictly per-frame (or per-message). There is no way +// to partially read a frame. Config.MaxMessageSize affords control of the +// maximum buffering of messages. +// +// The reason for using this custom implementation instead of +// code.google.com/p/go.net/websocket is that the latter has problems with long +// messages and does not support server subprotocols. 
+// "Denial of Service Protection in Go HTTP Servers" +// https://code.google.com/p/go/issues/detail?id=2093 +// "go.websocket: Read/Copy fail with long frames" +// https://code.google.com/p/go/issues/detail?id=2134 +// http://golang.org/pkg/net/textproto/#pkg-bugs +// "To let callers manage exposure to denial of service attacks, Reader should +// allow them to set and reset a limit on the number of bytes read from the +// connection." +// "websocket.Dial doesn't limit response header length as http.Get does" +// https://groups.google.com/forum/?fromgroups=#!topic/golang-nuts/2Tge6U8-QYI +// +// Example usage: +// +// func doSomething(ws *WebSocket) { +// } +// var config websocket.Config +// config.Subprotocols = []string{"base64"} +// config.MaxMessageSize = 2500 +// http.Handle("/", config.Handler(doSomething)) +// err = http.ListenAndServe(":8080", nil) + +package websocket + +import ( + "bufio" + "bytes" + "crypto/rand" + "crypto/sha1" + "encoding/base64" + "encoding/binary" + "errors" + "fmt" + "io" + "net" + "net/http" + "strings" +) + +// Settings for potential WebSocket connections. Subprotocols is a list of +// supported subprotocols as in RFC 6455 section 1.9. When answering client +// requests, the first of the client's requests subprotocols that is also in +// this list (if any) will be used as the subprotocol for the connection. +// MaxMessageSize is a limit on buffering messages. +type Config struct { + Subprotocols []string + MaxMessageSize int +} + +// Representation of a WebSocket frame. The Payload is always without masking. +type Frame struct { + Fin bool + Opcode byte + Payload []byte +} + +// Return true iff the frame's opcode says it is a control frame. +func (frame *Frame) IsControl() bool { + return (frame.Opcode & 0x08) != 0 +} + +// Representation of a WebSocket message. The Payload is always without masking. +type Message struct { + Opcode byte + Payload []byte +} + +// A WebSocket connection after hijacking from HTTP. 
+type WebSocket struct { + // Conn and ReadWriter from http.ResponseWriter.Hijack. + Conn net.Conn + Bufrw *bufio.ReadWriter + // Whether we are a client or a server has implications for masking. + IsClient bool + // Set from a parent Config. + MaxMessageSize int + // The single selected subprotocol after negotiation, or "". + Subprotocol string + // Buffer for message payloads, which may be interrupted by control + // messages. + messageBuf bytes.Buffer +} + +func applyMask(payload []byte, maskKey [4]byte) { + for i := 0; i < len(payload); i++ { + payload[i] = payload[i] ^ maskKey[i%4] + } +} + +func (ws *WebSocket) maxMessageSize() int { + if ws.MaxMessageSize == 0 { + return 64000 + } + return ws.MaxMessageSize +} + +// Read a single frame from the WebSocket. +func (ws *WebSocket) ReadFrame() (frame Frame, err error) { + var b byte + err = binary.Read(ws.Bufrw, binary.BigEndian, &b) + if err != nil { + return + } + frame.Fin = (b & 0x80) != 0 + frame.Opcode = b & 0x0f + err = binary.Read(ws.Bufrw, binary.BigEndian, &b) + if err != nil { + return + } + masked := (b & 0x80) != 0 + + payloadLen := uint64(b & 0x7f) + if payloadLen == 126 { + var short uint16 + err = binary.Read(ws.Bufrw, binary.BigEndian, &short) + if err != nil { + return + } + payloadLen = uint64(short) + } else if payloadLen == 127 { + var long uint64 + err = binary.Read(ws.Bufrw, binary.BigEndian, &long) + if err != nil { + return + } + payloadLen = long + } + if payloadLen > uint64(ws.maxMessageSize()) { + err = errors.New(fmt.Sprintf("frame payload length of %d exceeds maximum of %d", payloadLen, ws.MaxMessageSize)) + return + } + + maskKey := [4]byte{} + if masked { + if ws.IsClient { + err = errors.New("client got masked frame") + return + } + err = binary.Read(ws.Bufrw, binary.BigEndian, &maskKey) + if err != nil { + return + } + } else { + if !ws.IsClient { + err = errors.New("server got unmasked frame") + return + } + } + + frame.Payload = make([]byte, payloadLen) + _, err = 
io.ReadFull(ws.Bufrw, frame.Payload) + if err != nil { + return + } + if masked { + applyMask(frame.Payload, maskKey) + } + + return frame, nil +} + +// Read a single message from the WebSocket. Multiple fragmented frames are +// combined into a single message before being returned. Non-control messages +// may be interrupted by control frames. The control frames are returned as +// individual messages before the message that they interrupt. +func (ws *WebSocket) ReadMessage() (message Message, err error) { + var opcode byte = 0 + for { + var frame Frame + frame, err = ws.ReadFrame() + if err != nil { + return + } + if frame.IsControl() { + if !frame.Fin { + err = errors.New("control frame has fin bit unset") + return + } + message.Opcode = frame.Opcode + message.Payload = frame.Payload + return message, nil + } + + if opcode == 0 { + if frame.Opcode == 0 { + err = errors.New("first frame has opcode 0") + return + } + opcode = frame.Opcode + } else { + if frame.Opcode != 0 { + err = errors.New(fmt.Sprintf("non-first frame has nonzero opcode %d", frame.Opcode)) + return + } + } + if ws.messageBuf.Len()+len(frame.Payload) > ws.MaxMessageSize { + err = errors.New(fmt.Sprintf("message payload length of %d exceeds maximum of %d", + ws.messageBuf.Len()+len(frame.Payload), ws.MaxMessageSize)) + return + } + ws.messageBuf.Write(frame.Payload) + if frame.Fin { + break + } + } + message.Opcode = opcode + message.Payload = ws.messageBuf.Bytes() + ws.messageBuf.Reset() + + return message, nil +} + +// Write a single frame to the WebSocket stream. Destructively masks payload in +// place if ws.IsClient. Frames are always unfragmented. 
+func (ws *WebSocket) WriteFrame(opcode byte, payload []byte) (err error) { + if opcode >= 16 { + err = errors.New(fmt.Sprintf("opcode %d is >= 16", opcode)) + return + } + ws.Bufrw.WriteByte(0x80 | opcode) + + var maskBit byte + var maskKey [4]byte + if ws.IsClient { + _, err = io.ReadFull(rand.Reader, maskKey[:]) + if err != nil { + return + } + applyMask(payload, maskKey) + maskBit = 0x80 + } else { + maskBit = 0x00 + } + + if len(payload) < 126 { + ws.Bufrw.WriteByte(maskBit | byte(len(payload))) + } else if len(payload) <= 0xffff { + ws.Bufrw.WriteByte(maskBit | 126) + binary.Write(ws.Bufrw, binary.BigEndian, uint16(len(payload))) + } else { + ws.Bufrw.WriteByte(maskBit | 127) + binary.Write(ws.Bufrw, binary.BigEndian, uint64(len(payload))) + } + + if ws.IsClient { + _, err = ws.Bufrw.Write(maskKey[:]) + if err != nil { + return + } + } + _, err = ws.Bufrw.Write(payload) + if err != nil { + return + } + + ws.Bufrw.Flush() + + return +} + +// Write a single message to the WebSocket stream. Destructively masks payload +// in place if ws.IsClient. Messages are always sent as a single unfragmented +// frame. +func (ws *WebSocket) WriteMessage(opcode byte, payload []byte) (err error) { + return ws.WriteFrame(opcode, payload) +} + +// Split a string on commas and trim whitespace. +func commaSplit(s string) []string { + var result []string + if strings.TrimSpace(s) == "" { + return result + } + for _, e := range strings.Split(s, ",") { + result = append(result, strings.TrimSpace(e)) + } + return result +} + +// Returns true iff one of the strings in haystack is needle (case-insensitive). +func containsCase(haystack []string, needle string) bool { + for _, e := range haystack { + if strings.ToLower(e) == strings.ToLower(needle) { + return true + } + } + return false +} + +// One-step SHA-1 hash of a string. 
+func sha1Hash(data string) []byte { + h := sha1.New() + h.Write([]byte(data)) + return h.Sum(nil) +} + +func httpError(w http.ResponseWriter, bufrw *bufio.ReadWriter, code int) { + w.Header().Set("Connection", "close") + bufrw.WriteString(fmt.Sprintf("HTTP/1.0 %d %s\r\n", code, http.StatusText(code))) + w.Header().Write(bufrw) + bufrw.WriteString("\r\n") + bufrw.Flush() +} + +// An implementation of http.Handler with a Config. The ServeHTTP function calls +// Callback assuming WebSocket HTTP negotiation is successful. +type HTTPHandler struct { + Config *Config + Callback func(*WebSocket) +} + +// Implements the http.Handler interface. +func (handler *HTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { + conn, bufrw, err := w.(http.Hijacker).Hijack() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + defer conn.Close() + + // See RFC 6455 section 4.2.1 for this sequence of checks. + + // 1. An HTTP/1.1 or higher GET request, including a "Request-URI"... + if req.Method != "GET" { + httpError(w, bufrw, http.StatusMethodNotAllowed) + return + } + if req.URL.Path != "/" { + httpError(w, bufrw, http.StatusNotFound) + return + } + // 2. A |Host| header field containing the server's authority. + // We deliberately skip this test. + // 3. An |Upgrade| header field containing the value "websocket", + // treated as an ASCII case-insensitive value. + if !containsCase(commaSplit(req.Header.Get("Upgrade")), "websocket") { + httpError(w, bufrw, http.StatusBadRequest) + return + } + // 4. A |Connection| header field that includes the token "Upgrade", + // treated as an ASCII case-insensitive value. + if !containsCase(commaSplit(req.Header.Get("Connection")), "Upgrade") { + httpError(w, bufrw, http.StatusBadRequest) + return + } + // 5. A |Sec-WebSocket-Key| header field with a base64-encoded value + // that, when decoded, is 16 bytes in length. 
+ websocketKey := req.Header.Get("Sec-WebSocket-Key") + key, err := base64.StdEncoding.DecodeString(websocketKey) + if err != nil || len(key) != 16 { + httpError(w, bufrw, http.StatusBadRequest) + return + } + // 6. A |Sec-WebSocket-Version| header field, with a value of 13. + // We also allow 8 from draft-ietf-hybi-thewebsocketprotocol-10. + var knownVersions = []string{"8", "13"} + websocketVersion := req.Header.Get("Sec-WebSocket-Version") + if !containsCase(knownVersions, websocketVersion) { + // "If this version does not match a version understood by the + // server, the server MUST abort the WebSocket handshake + // described in this section and instead send an appropriate + // HTTP error code (such as 426 Upgrade Required) and a + // |Sec-WebSocket-Version| header field indicating the + // version(s) the server is capable of understanding." + w.Header().Set("Sec-WebSocket-Version", strings.Join(knownVersions, ", ")) + httpError(w, bufrw, 426) + return + } + // 7. Optionally, an |Origin| header field. + // 8. Optionally, a |Sec-WebSocket-Protocol| header field, with a list of + // values indicating which protocols the client would like to speak, ordered + // by preference. + clientProtocols := commaSplit(req.Header.Get("Sec-WebSocket-Protocol")) + // 9. Optionally, a |Sec-WebSocket-Extensions| header field... + // 10. Optionally, other header fields... + + var ws WebSocket + ws.Conn = conn + ws.Bufrw = bufrw + ws.IsClient = false + ws.MaxMessageSize = handler.Config.MaxMessageSize + + // See RFC 6455 section 4.2.2, item 5 for these steps. + + // 1. A Status-Line with a 101 response code as per RFC 2616. + bufrw.WriteString(fmt.Sprintf("HTTP/1.0 %d %s\r\n", http.StatusSwitchingProtocols, http.StatusText(http.StatusSwitchingProtocols))) + // 2. An |Upgrade| header field with value "websocket" as per RFC 2616. + w.Header().Set("Upgrade", "websocket") + // 3. A |Connection| header field with value "Upgrade". + w.Header().Set("Connection", "Upgrade") + // 4. 
A |Sec-WebSocket-Accept| header field. The value of this header + // field is constructed by concatenating /key/, defined above in step 4 + // in Section 4.2.2, with the string + // "258EAFA5-E914-47DA-95CA-C5AB0DC85B11", taking the SHA-1 hash of this + // concatenated value to obtain a 20-byte value and base64-encoding (see + // Section 4 of [RFC4648]) this 20-byte hash. + const magicGUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" + acceptKey := base64.StdEncoding.EncodeToString(sha1Hash(websocketKey + magicGUID)) + w.Header().Set("Sec-WebSocket-Accept", acceptKey) + // 5. Optionally, a |Sec-WebSocket-Protocol| header field, with a value + // /subprotocol/ as defined in step 4 in Section 4.2.2. + for _, clientProto := range clientProtocols { + for _, serverProto := range handler.Config.Subprotocols { + if clientProto == serverProto { + ws.Subprotocol = clientProto + w.Header().Set("Sec-WebSocket-Protocol", clientProto) + break + } + } + } + // 6. Optionally, a |Sec-WebSocket-Extensions| header field... + w.Header().Write(bufrw) + bufrw.WriteString("\r\n") + bufrw.Flush() + + // Call the WebSocket-specific handler. + handler.Callback(&ws) +} + +// Return an http.Handler with the given callback function. +func (config *Config) Handler(callback func(*WebSocket)) http.Handler { + return &HTTPHandler{config, callback} +}
tor-commits@lists.torproject.org