[tor-commits] [snowflake/master] Stop using custom websocket library in server

arlo at torproject.org arlo at torproject.org
Mon Nov 11 22:20:29 UTC 2019


commit c417fd5599c5d39951c606856c69a9f05941afd3
Author: Arlo Breault <arlolra at gmail.com>
Date:   Wed Oct 16 21:00:13 2019 -0400

    Stop using custom websocket library in server
    
    Trac: 31028
---
 .travis.yml      |  2 +-
 server/server.go | 88 +++++++++++++++++++++++++++++++++-----------------------
 2 files changed, 53 insertions(+), 37 deletions(-)

diff --git a/.travis.yml b/.travis.yml
index 9ed48bb..2f02f74 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -26,8 +26,8 @@ install:
     - go get -u github.com/keroserene/go-webrtc
     - go get -u github.com/pion/webrtc
     - go get -u github.com/dchest/uniuri
+    - go get -u github.com/gorilla/websocket
     - go get -u git.torproject.org/pluggable-transports/goptlib.git
-    - go get -u git.torproject.org/pluggable-transports/websocket.git/websocket
     - go get -u google.golang.org/appengine
     - go get -u golang.org/x/crypto/acme/autocert
     - go get -u golang.org/x/net/http2
diff --git a/server/server.go b/server/server.go
index b1b566a..d111fce 100644
--- a/server/server.go
+++ b/server/server.go
@@ -21,7 +21,7 @@ import (
 
 	pt "git.torproject.org/pluggable-transports/goptlib.git"
 	"git.torproject.org/pluggable-transports/snowflake.git/common/safelog"
-	"git.torproject.org/pluggable-transports/websocket.git/websocket"
+	"github.com/gorilla/websocket"
 	"golang.org/x/crypto/acme/autocert"
 	"golang.org/x/net/http2"
 )
@@ -53,50 +53,60 @@ additional HTTP listener on port 80 to work with ACME.
 // An abstraction that makes an underlying WebSocket connection look like an
 // io.ReadWriteCloser.
 type webSocketConn struct {
-	Ws         *websocket.WebSocket
-	messageBuf []byte
+	Ws *websocket.Conn // underlying gorilla/websocket connection
+	r  io.Reader       // reader for the message Read is consuming; nil between messages
 }
 
 // Implements io.Reader.
 func (conn *webSocketConn) Read(b []byte) (n int, err error) {
-	for len(conn.messageBuf) == 0 {
-		var m websocket.Message
-		m, err = conn.Ws.ReadMessage()
-		if err != nil {
-			return
-		}
-		if m.Opcode == 8 {
-			err = io.EOF
-			return
-		}
-		if m.Opcode != 2 {
-			err = fmt.Errorf("got non-binary opcode %d", m.Opcode)
-			return
+	// Keep reading until payload bytes arrive or a real error occurs. A
+	// message boundary (or an empty message) must not surface as a (0, nil)
+	// result, which the io.Reader contract discourages; an empty b is the
+	// only case that returns (0, nil) immediately.
+	for len(b) > 0 && n == 0 && err == nil {
+		if conn.r == nil {
+			// Start of a new message. NextReader only returns text or
+			// binary messages; ping, pong and close frames are processed
+			// inside the websocket library, so no opcode filtering is
+			// needed here.
+			var r io.Reader
+			if _, r, err = conn.Ws.NextReader(); err != nil {
+				return
+			}
+			conn.r = r
+		}
+		n, err = conn.r.Read(b)
+		if err == io.EOF {
+			// Message finished: drop its reader so the next iteration (or
+			// the next call, if data was already read) starts a new one.
+			conn.r = nil
+			err = nil
 		}
-		conn.messageBuf = m.Payload
 	}
 
-	n = copy(b, conn.messageBuf)
-	conn.messageBuf = conn.messageBuf[n:]
-
 	return
 }
 
 // Implements io.Writer.
-func (conn *webSocketConn) Write(b []byte) (int, error) {
-	err := conn.Ws.WriteMessage(2, b)
-	return len(b), err
+func (conn *webSocketConn) Write(b []byte) (n int, err error) {
+	var w io.WriteCloser // writer for a single binary message carrying all of b
+	if w, err = conn.Ws.NextWriter(websocket.BinaryMessage); err != nil {
+		return
+	}
+	if n, err = w.Write(b); err != nil {
+		return
+	}
+	err = w.Close() // close the message writer; its error is returned together with n
+	return
 }
 
 // Implements io.Closer.
 func (conn *webSocketConn) Close() error {
-	// Ignore any error in trying to write a Close frame.
-	_ = conn.Ws.WriteFrame(8, nil)
-	return conn.Ws.Conn.Close()
+	return conn.Ws.Close() // NOTE(review): no WebSocket Close frame is written before closing — confirm intended
 }
 
 // Create a new webSocketConn.
-func newWebSocketConn(ws *websocket.WebSocket) webSocketConn {
+func newWebSocketConn(ws *websocket.Conn) webSocketConn {
 	var conn webSocketConn
 	conn.Ws = ws
 	return conn
@@ -145,16 +155,22 @@ func clientAddr(clientIPParam string) string {
 	return (&net.TCPAddr{IP: clientIP, Port: 1, Zone: ""}).String()
 }
 
-func webSocketHandler(ws *websocket.WebSocket) {
-	// Undo timeouts on HTTP request handling.
-	if err := ws.Conn.SetDeadline(time.Time{}); err != nil {
-		log.Printf("unable to set deadlines with error: %v", err)
+var upgrader = websocket.Upgrader{}
+
+type HTTPHandler struct{}
+
+func (handler *HTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+	ws, err := upgrader.Upgrade(w, r, nil)
+	if err != nil {
+		log.Println(err)
+		return
 	}
+
 	conn := newWebSocketConn(ws)
 	defer conn.Close()
 
 	// Pass the address of client as the remote address of incoming connection
-	clientIPParam := ws.Request().URL.Query().Get("client_ip")
+	clientIPParam := r.URL.Query().Get("client_ip")
 	addr := clientAddr(clientIPParam)
 	if addr == "" {
 		statsChannel <- false
@@ -162,7 +178,6 @@ func webSocketHandler(ws *websocket.WebSocket) {
 		statsChannel <- true
 	}
 	or, err := pt.DialOr(&ptInfo, addr, ptMethodName)
-
 	if err != nil {
 		log.Printf("failed to connect to ORPort: %s", err)
 		return
@@ -185,11 +200,12 @@ func initServer(addr *net.TCPAddr,
 		return nil, fmt.Errorf("cannot listen on port %d; configure a port using ServerTransportListenAddr", addr.Port)
 	}
 
-	var config websocket.Config
-	config.MaxMessageSize = maxMessageSize
+	upgrader.CheckOrigin = func(r *http.Request) bool { return true }
+
+	var handler HTTPHandler
 	server := &http.Server{
 		Addr:        addr.String(),
-		Handler:     config.Handler(webSocketHandler),
+		Handler:     &handler,
 		ReadTimeout: requestTimeout,
 	}
 	// We need to override server.TLSConfig.GetCertificate--but first





More information about the tor-commits mailing list