commit c417fd5599c5d39951c606856c69a9f05941afd3
Author: Arlo Breault <arlolra@gmail.com>
Date:   Wed Oct 16 21:00:13 2019 -0400
Stop using custom websocket library in server
Trac: 31028
---
 .travis.yml      |  2 +-
 server/server.go | 88 +++++++++++++++++++++++++++++++++----------------------
 2 files changed, 53 insertions(+), 37 deletions(-)
Review notes (annotation only: git apply / git am ignore text that comes
before the first "diff --git" line, so the hunks below are unchanged):

* webSocketConn.Read: when conn.r.Read reports io.EOF, the error is
  cleared and n is returned as-is. Read therefore returns (0, nil)
  whenever the message reader signals io.EOF with n == 0, and always
  does so for a zero-length message. The io.Reader contract discourages
  (0, nil). Consider looping on to the next message until n > 0 or a
  real error occurs (and guarding len(b) == 0 so that loop terminates).
* webSocketConn.Read: the old code turned a Close frame (opcode 8) into
  io.EOF and rejected every other non-binary opcode with an error. The
  new code returns NextReader's error unchanged (presumably a
  *websocket.CloseError on close, not io.EOF; confirm what callers
  expect) and accepts TextMessage as well as BinaryMessage.
  gorilla's NextReader is documented to return only TextMessage or
  BinaryMessage, so the `continue` branch looks unreachable.
* The removed `config.MaxMessageSize = maxMessageSize` has no
  replacement, so incoming message size is no longer bounded.
  maxMessageSize is presumably now unreferenced (the rest of server.go
  is not visible here; confirm). If the limit is still wanted, call
  ws.SetReadLimit(int64(maxMessageSize)) right after Upgrade.
* webSocketConn.Close: the old code made a best-effort attempt to write
  a Close frame before closing the TCP connection. gorilla's Conn.Close
  is documented to close the network connection without sending a close
  message. Consider a best-effort WriteControl(websocket.CloseMessage,
  ...) with a short deadline before Close.
* The old handler explicitly cleared the deadlines left over from
  http.Server's ReadTimeout ("Undo timeouts on HTTP request handling").
  ServeHTTP no longer does. NOTE(review): gorilla's Upgrade is expected
  to clear those deadlines itself; verify this against the vendored
  version, otherwise long-lived tunnels would hit requestTimeout.
* upgrader.CheckOrigin is installed by mutating the package-level
  upgrader inside initServer. Setting it in the `var upgrader`
  declaration keeps the configuration in one place. Accepting every
  origin disables gorilla's default same-origin check. That is
  presumably intentional (proxies connect from arbitrary origins), but
  it deserves a comment in the code.
* webSocketConn.Write: NextWriter + Write + Close does what
  Conn.WriteMessage(websocket.BinaryMessage, b) does in one call.

diff --git a/.travis.yml b/.travis.yml
index 9ed48bb..2f02f74 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -26,8 +26,8 @@ install:
   - go get -u github.com/keroserene/go-webrtc
   - go get -u github.com/pion/webrtc
   - go get -u github.com/dchest/uniuri
+  - go get -u github.com/gorilla/websocket
   - go get -u git.torproject.org/pluggable-transports/goptlib.git
-  - go get -u git.torproject.org/pluggable-transports/websocket.git/websocket
   - go get -u google.golang.org/appengine
   - go get -u golang.org/x/crypto/acme/autocert
   - go get -u golang.org/x/net/http2
diff --git a/server/server.go b/server/server.go
index b1b566a..d111fce 100644
--- a/server/server.go
+++ b/server/server.go
@@ -21,7 +21,7 @@ import (

 	pt "git.torproject.org/pluggable-transports/goptlib.git"
 	"git.torproject.org/pluggable-transports/snowflake.git/common/safelog"
-	"git.torproject.org/pluggable-transports/websocket.git/websocket"
+	"github.com/gorilla/websocket"
 	"golang.org/x/crypto/acme/autocert"
 	"golang.org/x/net/http2"
 )
@@ -53,50 +53,60 @@ additional HTTP listener on port 80 to work with ACME.
 // An abstraction that makes an underlying WebSocket connection look like an
 // io.ReadWriteCloser.
 type webSocketConn struct {
-	Ws         *websocket.WebSocket
-	messageBuf []byte
+	Ws *websocket.Conn
+	r  io.Reader
 }

 // Implements io.Reader.
 func (conn *webSocketConn) Read(b []byte) (n int, err error) {
-	for len(conn.messageBuf) == 0 {
-		var m websocket.Message
-		m, err = conn.Ws.ReadMessage()
-		if err != nil {
-			return
-		}
-		if m.Opcode == 8 {
-			err = io.EOF
-			return
-		}
-		if m.Opcode != 2 {
-			err = fmt.Errorf("got non-binary opcode %d", m.Opcode)
-			return
+	var opCode int
+	if conn.r == nil {
+		// New message
+		var r io.Reader
+		for {
+			if opCode, r, err = conn.Ws.NextReader(); err != nil {
+				return
+			}
+			if opCode != websocket.BinaryMessage && opCode != websocket.TextMessage {
+				continue
+			}
+
+			conn.r = r
+			break
 		}
-		conn.messageBuf = m.Payload
 	}

-	n = copy(b, conn.messageBuf)
-	conn.messageBuf = conn.messageBuf[n:]
-
+	n, err = conn.r.Read(b)
+	if err != nil {
+		if err == io.EOF {
+			// Message finished
+			conn.r = nil
+			err = nil
+		}
+	}
 	return
 }

 // Implements io.Writer.
-func (conn *webSocketConn) Write(b []byte) (int, error) {
-	err := conn.Ws.WriteMessage(2, b)
-	return len(b), err
+func (conn *webSocketConn) Write(b []byte) (n int, err error) {
+	var w io.WriteCloser
+	if w, err = conn.Ws.NextWriter(websocket.BinaryMessage); err != nil {
+		return
+	}
+	if n, err = w.Write(b); err != nil {
+		return
+	}
+	err = w.Close()
+	return
 }

 // Implements io.Closer.
 func (conn *webSocketConn) Close() error {
-	// Ignore any error in trying to write a Close frame.
-	_ = conn.Ws.WriteFrame(8, nil)
-	return conn.Ws.Conn.Close()
+	return conn.Ws.Close()
 }

 // Create a new webSocketConn.
-func newWebSocketConn(ws *websocket.WebSocket) webSocketConn {
+func newWebSocketConn(ws *websocket.Conn) webSocketConn {
 	var conn webSocketConn
 	conn.Ws = ws
 	return conn
@@ -145,16 +155,22 @@ func clientAddr(clientIPParam string) string {
 	return (&net.TCPAddr{IP: clientIP, Port: 1, Zone: ""}).String()
 }

-func webSocketHandler(ws *websocket.WebSocket) {
-	// Undo timeouts on HTTP request handling.
-	if err := ws.Conn.SetDeadline(time.Time{}); err != nil {
-		log.Printf("unable to set deadlines with error: %v", err)
+var upgrader = websocket.Upgrader{}
+
+type HTTPHandler struct{}
+
+func (handler *HTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+	ws, err := upgrader.Upgrade(w, r, nil)
+	if err != nil {
+		log.Println(err)
+		return
 	}
+
 	conn := newWebSocketConn(ws)
 	defer conn.Close()

 	// Pass the address of client as the remote address of incoming connection
-	clientIPParam := ws.Request().URL.Query().Get("client_ip")
+	clientIPParam := r.URL.Query().Get("client_ip")
 	addr := clientAddr(clientIPParam)
 	if addr == "" {
 		statsChannel <- false
@@ -162,7 +178,6 @@ func webSocketHandler(ws *websocket.WebSocket) {
 		statsChannel <- true
 	}
 	or, err := pt.DialOr(&ptInfo, addr, ptMethodName)
-
 	if err != nil {
 		log.Printf("failed to connect to ORPort: %s", err)
 		return
@@ -185,11 +200,12 @@ func initServer(addr *net.TCPAddr,
 		return nil, fmt.Errorf("cannot listen on port %d; configure a port using ServerTransportListenAddr", addr.Port)
 	}

-	var config websocket.Config
-	config.MaxMessageSize = maxMessageSize
+	upgrader.CheckOrigin = func(r *http.Request) bool { return true }
+
+	var handler HTTPHandler
 	server := &http.Server{
 		Addr:        addr.String(),
-		Handler:     config.Handler(webSocketHandler),
+		Handler:     &handler,
 		ReadTimeout: requestTimeout,
 	}
 	// We need to override server.TLSConfig.GetCertificate--but first