commit e53705519e1aace78da1fad441c49114c90b858b Author: David Fifield david@bamsoftware.com Date: Sat Mar 26 12:05:06 2016 -0700
Add a copy of websocket-server in the server directory.
From commit ea3ee6b5cf035a9465c02b4a2b38882c9c6a7050 in https://git.torproject.org/pluggable-transports/websocket.git --- server/server.go | 291 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 291 insertions(+)
diff --git a/server/server.go b/server/server.go new file mode 100644 index 0000000..6fb5a0b --- /dev/null +++ b/server/server.go @@ -0,0 +1,291 @@ +// Tor websocket server transport plugin. +// +// Usage in torrc: +// ExtORPort 6669 +// ServerTransportPlugin websocket exec ./websocket-server --port 9901 +package main + +import ( + "encoding/base64" + "errors" + "flag" + "fmt" + "io" + "log" + "net" + "net/http" + "os" + "os/signal" + "sync" + "syscall" + "time" + + "../websocket" + + "git.torproject.org/pluggable-transports/goptlib.git" +) + +const ptMethodName = "websocket" +const requestTimeout = 10 * time.Second + +// "4/3+1" accounts for possible base64 encoding. +const maxMessageSize = 64*1024*4/3 + 1 + +var logFile = os.Stderr + +var ptInfo pt.ServerInfo + +// When a connection handler starts, +1 is written to this channel; when it +// ends, -1 is written. +var handlerChan = make(chan int) + +func usage() { + fmt.Printf("Usage: %s [OPTIONS]\n", os.Args[0]) + fmt.Printf("WebSocket server pluggable transport for Tor.\n") + fmt.Printf("Works only as a managed proxy.\n") + fmt.Printf("\n") + fmt.Printf(" -h, --help show this help.\n") + fmt.Printf(" --log FILE log messages to FILE (default stderr).\n") + fmt.Printf(" --port PORT listen on PORT (overrides Tor's requested port).\n") +} + +// An abstraction that makes an underlying WebSocket connection look like an +// io.ReadWriteCloser. It internally takes care of things like base64 encoding +// and decoding. +type webSocketConn struct { + Ws *websocket.WebSocket + Base64 bool + messageBuf []byte +} + +// Implements io.Reader. +func (conn *webSocketConn) Read(b []byte) (n int, err error) { + for len(conn.messageBuf) == 0 { + var m websocket.Message + m, err = conn.Ws.ReadMessage() + if err != nil { + return + } + if m.Opcode == 8 { + err = io.EOF + return + } + if conn.Base64 { + if m.Opcode != 1 { + err = errors.New(fmt.Sprintf("got non-text opcode %d with the base64 subprotocol", m.Opcode)) + return + } + conn.messageBuf = make([]byte, base64.StdEncoding.DecodedLen(len(m.Payload))) + var num int + num, err = base64.StdEncoding.Decode(conn.messageBuf, m.Payload) + if err != nil { + return + } + conn.messageBuf = conn.messageBuf[:num] + } else { + if m.Opcode != 2 { + err = errors.New(fmt.Sprintf("got non-binary opcode %d with no subprotocol", m.Opcode)) + return + } + conn.messageBuf = m.Payload + } + } + + n = copy(b, conn.messageBuf) + conn.messageBuf = conn.messageBuf[n:] + + return +} + +// Implements io.Writer. +func (conn *webSocketConn) Write(b []byte) (n int, err error) { + if conn.Base64 { + buf := make([]byte, base64.StdEncoding.EncodedLen(len(b))) + base64.StdEncoding.Encode(buf, b) + err = conn.Ws.WriteMessage(1, buf) + if err != nil { + return + } + n = len(b) + } else { + err = conn.Ws.WriteMessage(2, b) + n = len(b) + } + return +} + +// Implements io.Closer. +func (conn *webSocketConn) Close() error { + // Ignore any error in trying to write a Close frame. + _ = conn.Ws.WriteFrame(8, nil) + return conn.Ws.Conn.Close() +} + +// Create a new webSocketConn. +func newWebSocketConn(ws *websocket.WebSocket) webSocketConn { + var conn webSocketConn + conn.Ws = ws + conn.Base64 = (ws.Subprotocol == "base64") + return conn +} + +// Copy from WebSocket to socket and vice versa. +func proxy(local *net.TCPConn, conn *webSocketConn) { + var wg sync.WaitGroup + wg.Add(2) + + go func() { + _, err := io.Copy(conn, local) + if err != nil { + log.Printf("error copying ORPort to WebSocket") + } + local.CloseRead() + conn.Close() + wg.Done() + }() + go func() { + _, err := io.Copy(local, conn) + if err != nil { + log.Printf("error copying WebSocket to ORPort") + } + local.CloseWrite() + conn.Close() + wg.Done() + }() + + wg.Wait() +} + +func webSocketHandler(ws *websocket.WebSocket) { + // Undo timeouts on HTTP request handling. + ws.Conn.SetDeadline(time.Time{}) + conn := newWebSocketConn(ws) + defer conn.Close() + + handlerChan <- 1 + defer func() { + handlerChan <- -1 + }() + + or, err := pt.DialOr(&ptInfo, ws.Conn.RemoteAddr().String(), ptMethodName) + if err != nil { + log.Printf("Failed to connect to ORPort: " + err.Error()) + return + } + defer or.Close() + + proxy(or, &conn) +} + +func startListener(addr *net.TCPAddr) (*net.TCPListener, error) { + ln, err := net.ListenTCP("tcp", addr) + if err != nil { + return nil, err + } + go func() { + defer ln.Close() + var config websocket.Config + config.Subprotocols = []string{"base64"} + config.MaxMessageSize = maxMessageSize + s := &http.Server{ + Handler: config.Handler(webSocketHandler), + ReadTimeout: requestTimeout, + } + err = s.Serve(ln) + if err != nil { + log.Printf("http.Serve: " + err.Error()) + } + }() + return ln, nil +} + +func main() { + var logFilename string + var port int + + flag.Usage = usage + flag.StringVar(&logFilename, "log", "", "log file to write to") + flag.IntVar(&port, "port", 0, "port to listen on if unspecified by Tor") + flag.Parse() + + if logFilename != "" { + f, err := os.OpenFile(logFilename, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0600) + if err != nil { + fmt.Fprintf(os.Stderr, "Can't open log file %q: %s.\n", logFilename, err.Error()) + os.Exit(1) + } + log.SetOutput(f) + } + + log.SetFlags(log.LstdFlags | log.LUTC) + log.Printf("starting") + var err error + ptInfo, err = pt.ServerSetup(nil) + if err != nil { + log.Printf("error in setup: %s", err) + os.Exit(1) + } + + listeners := make([]*net.TCPListener, 0) + for _, bindaddr := range ptInfo.Bindaddrs { + // Override tor's requested port (which is 0 if this transport + // has not been run before) with the one requested by the --port + // option. + if port != 0 { + bindaddr.Addr.Port = port + } + + switch bindaddr.MethodName { + case ptMethodName: + ln, err := startListener(bindaddr.Addr) + if err != nil { + pt.SmethodError(bindaddr.MethodName, err.Error()) + break + } + pt.Smethod(bindaddr.MethodName, ln.Addr()) + log.Printf("listening on %s", ln.Addr().String()) + listeners = append(listeners, ln) + default: + pt.SmethodError(bindaddr.MethodName, "no such method") + } + } + pt.SmethodsDone() + + var numHandlers int = 0 + var sig os.Signal + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + + // wait for first signal + sig = nil + for sig == nil { + select { + case n := <-handlerChan: + numHandlers += n + case sig = <-sigChan: + } + } + log.Printf("Got first signal %q with %d running handlers.", sig, numHandlers) + for _, ln := range listeners { + ln.Close() + } + + if sig == syscall.SIGTERM { + log.Printf("Caught signal %q, exiting.", sig) + return + } + + // wait for second signal or no more handlers + sig = nil + for sig == nil && numHandlers != 0 { + select { + case n := <-handlerChan: + numHandlers += n + log.Printf("%d remaining handlers.", numHandlers) + case sig = <-sigChan: + } + } + if sig != nil { + log.Printf("Got second signal %q with %d running handlers.", sig, numHandlers) + } +}
tor-commits@lists.torproject.org