commit 30b5ef8a9e9c7a5b306e9285d1a8db323f8f22b2
Author: Arlo Breault <arlolra@gmail.com>
Date:   Wed Nov 20 19:33:28 2019 -0500

    Use gorilla websocket in proxy-go too

    Trac: 32465
---
 common/websocketconn/websocketconn.go      | 89 +++++++++++++++++++++++++++
 common/websocketconn/websocketconn_test.go | 30 +++++++++
 proxy-go/proxy-go_test.go                  | 19 ------
 proxy-go/snowflake.go                      | 25 ++------
 server/server.go                           | 99 ++---------------------------
 5 files changed, 128 insertions(+), 134 deletions(-)
diff --git a/common/websocketconn/websocketconn.go b/common/websocketconn/websocketconn.go
new file mode 100644
index 0000000..399cbaa
--- /dev/null
+++ b/common/websocketconn/websocketconn.go
@@ -0,0 +1,89 @@
+package websocketconn
+
+import (
+	"io"
+	"log"
+	"sync"
+	"time"
+
+	"github.com/gorilla/websocket"
+)
+
+// An abstraction that makes an underlying WebSocket connection look like an
+// io.ReadWriteCloser.
+type WebSocketConn struct {
+	Ws *websocket.Conn
+	r  io.Reader
+}
+
+// Implements io.Reader.
+func (conn *WebSocketConn) Read(b []byte) (n int, err error) {
+	var opCode int
+	if conn.r == nil {
+		// New message
+		var r io.Reader
+		for {
+			if opCode, r, err = conn.Ws.NextReader(); err != nil {
+				return
+			}
+			if opCode != websocket.BinaryMessage && opCode != websocket.TextMessage {
+				continue
+			}
+
+			conn.r = r
+			break
+		}
+	}
+
+	n, err = conn.r.Read(b)
+	if err == io.EOF {
+		// Message finished
+		conn.r = nil
+		err = nil
+	}
+	return
+}
+
+// Implements io.Writer.
+func (conn *WebSocketConn) Write(b []byte) (n int, err error) {
+	var w io.WriteCloser
+	if w, err = conn.Ws.NextWriter(websocket.BinaryMessage); err != nil {
+		return
+	}
+	if n, err = w.Write(b); err != nil {
+		return
+	}
+	err = w.Close()
+	return
+}
+
+// Implements io.Closer.
+func (conn *WebSocketConn) Close() error {
+	// Ignore any error in trying to write a Close frame.
+	_ = conn.Ws.WriteControl(websocket.CloseMessage, []byte{}, time.Now().Add(time.Second))
+	return conn.Ws.Close()
+}
+
+// Create a new WebSocketConn.
+func NewWebSocketConn(ws *websocket.Conn) WebSocketConn {
+	var conn WebSocketConn
+	conn.Ws = ws
+	return conn
+}
+
+// Copy from WebSocket to socket and vice versa.
+func CopyLoop(c1 io.ReadWriteCloser, c2 io.ReadWriteCloser) {
+	var wg sync.WaitGroup
+	copyer := func(dst io.ReadWriteCloser, src io.ReadWriteCloser) {
+		defer wg.Done()
+		if _, err := io.Copy(dst, src); err != nil {
+			log.Printf("io.Copy inside CopyLoop generated an error: %v", err)
+		}
+		dst.Close()
+		src.Close()
+	}
+	wg.Add(2)
+	go copyer(c1, c2)
+	go copyer(c2, c1)
+	wg.Wait()
+}
diff --git a/common/websocketconn/websocketconn_test.go b/common/websocketconn/websocketconn_test.go
new file mode 100644
index 0000000..3293165
--- /dev/null
+++ b/common/websocketconn/websocketconn_test.go
@@ -0,0 +1,30 @@
+package websocketconn
+
+import (
+	"net"
+	"testing"
+
+	. "github.com/smartystreets/goconvey/convey"
+)
+
+func TestWebsocketConn(t *testing.T) {
+	Convey("CopyLoop", t, func() {
+		c1, s1 := net.Pipe()
+		c2, s2 := net.Pipe()
+		go CopyLoop(s1, s2)
+		go func() {
+			bytes := []byte("Hello!")
+			c1.Write(bytes)
+		}()
+		bytes := make([]byte, 6)
+		n, err := c2.Read(bytes)
+		So(n, ShouldEqual, 6)
+		So(err, ShouldEqual, nil)
+		So(bytes, ShouldResemble, []byte("Hello!"))
+		s1.Close()
+
+		// Check that copy loop has closed other connection
+		_, err = s2.Write(bytes)
+		So(err, ShouldNotBeNil)
+	})
+}
diff --git a/proxy-go/proxy-go_test.go b/proxy-go/proxy-go_test.go
index ebe4381..538957b 100644
--- a/proxy-go/proxy-go_test.go
+++ b/proxy-go/proxy-go_test.go
@@ -374,23 +374,4 @@ func TestUtilityFuncs(t *testing.T) {
 		sid2 := genSessionID()
 		So(sid1, ShouldNotEqual, sid2)
 	})
-	Convey("CopyLoop", t, func() {
-		c1, s1 := net.Pipe()
-		c2, s2 := net.Pipe()
-		go CopyLoop(s1, s2)
-		go func() {
-			bytes := []byte("Hello!")
-			c1.Write(bytes)
-		}()
-		bytes := make([]byte, 6)
-		n, err := c2.Read(bytes)
-		So(n, ShouldEqual, 6)
-		So(err, ShouldEqual, nil)
-		So(bytes, ShouldResemble, []byte("Hello!"))
-		s1.Close()
-
-		//Check that copy loop has closed other connection
-		_, err = s2.Write(bytes)
-		So(err, ShouldNotBeNil)
-	})
 }
diff --git a/proxy-go/snowflake.go b/proxy-go/snowflake.go
index c4b2f0b..0e14eb2 100644
--- a/proxy-go/snowflake.go
+++ b/proxy-go/snowflake.go
@@ -21,8 +21,9 @@ import (
 
 	"git.torproject.org/pluggable-transports/snowflake.git/common/messages"
 	"git.torproject.org/pluggable-transports/snowflake.git/common/safelog"
+	"git.torproject.org/pluggable-transports/snowflake.git/common/websocketconn"
+	"github.com/gorilla/websocket"
 	"github.com/pion/webrtc"
-	"golang.org/x/net/websocket"
 )
 
 const defaultBrokerURL = "https://snowflake-broker.bamsoftware.com/"
@@ -239,22 +240,6 @@ func (b *Broker) sendAnswer(sid string, pc *webrtc.PeerConnection) error {
 	return nil
 }
 
-func CopyLoop(c1 net.Conn, c2 net.Conn) {
-	var wg sync.WaitGroup
-	copyer := func(dst io.ReadWriteCloser, src io.ReadWriteCloser) {
-		defer wg.Done()
-		if _, err := io.Copy(dst, src); err != nil {
-			log.Printf("io.Copy inside CopyLoop generated an error: %v", err)
-		}
-		dst.Close()
-		src.Close()
-	}
-	wg.Add(2)
-	go copyer(c1, c2)
-	go copyer(c2, c1)
-	wg.Wait()
-}
-
 // We pass conn.RemoteAddr() as an additional parameter, rather than calling
 // conn.RemoteAddr() inside this function, as a workaround for a hang that
 // otherwise occurs inside of conn.pc.RemoteDescription() (called by
@@ -279,15 +264,15 @@ func datachannelHandler(conn *webRTCConn, remoteAddr net.Addr) {
 		log.Printf("no remote address given in websocket")
 	}
 
-	wsConn, err := websocket.Dial(u.String(), "", relayURL)
+	ws, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
 	if err != nil {
 		log.Printf("error dialing relay: %s", err)
 		return
 	}
+	wsConn := websocketconn.NewWebSocketConn(ws)
 	log.Printf("connected to relay")
 	defer wsConn.Close()
-	wsConn.PayloadType = websocket.BinaryFrame
-	CopyLoop(conn, wsConn)
+	websocketconn.CopyLoop(conn, &wsConn)
 	log.Printf("datachannelHandler ends")
 }
 
diff --git a/server/server.go b/server/server.go
index ce804fc..d950ddc 100644
--- a/server/server.go
+++ b/server/server.go
@@ -15,12 +15,12 @@ import (
 	"os/signal"
 	"path/filepath"
 	"strings"
-	"sync"
 	"syscall"
 	"time"
 
 	pt "git.torproject.org/pluggable-transports/goptlib.git"
 	"git.torproject.org/pluggable-transports/snowflake.git/common/safelog"
+	"git.torproject.org/pluggable-transports/snowflake.git/common/websocketconn"
 	"github.com/gorilla/websocket"
 	"golang.org/x/crypto/acme/autocert"
 	"golang.org/x/net/http2"
@@ -50,97 +50,6 @@ additional HTTP listener on port 80 to work with ACME.
 	flag.PrintDefaults()
 }
 
-// An abstraction that makes an underlying WebSocket connection look like an
-// io.ReadWriteCloser.
-type webSocketConn struct {
-	Ws *websocket.Conn
-	r  io.Reader
-}
-
-// Implements io.Reader.
-func (conn *webSocketConn) Read(b []byte) (n int, err error) {
-	var opCode int
-	if conn.r == nil {
-		// New message
-		var r io.Reader
-		for {
-			if opCode, r, err = conn.Ws.NextReader(); err != nil {
-				return
-			}
-			if opCode != websocket.BinaryMessage && opCode != websocket.TextMessage {
-				continue
-			}
-
-			conn.r = r
-			break
-		}
-	}
-
-	n, err = conn.r.Read(b)
-	if err == io.EOF {
-		// Message finished
-		conn.r = nil
-		err = nil
-	}
-	return
-}
-
-// Implements io.Writer.
-func (conn *webSocketConn) Write(b []byte) (n int, err error) {
-	var w io.WriteCloser
-	if w, err = conn.Ws.NextWriter(websocket.BinaryMessage); err != nil {
-		return
-	}
-	if n, err = w.Write(b); err != nil {
-		return
-	}
-	err = w.Close()
-	return
-}
-
-// Implements io.Closer.
-func (conn *webSocketConn) Close() error {
-	// Ignore any error in trying to write a Close frame.
-	_ = conn.Ws.WriteControl(websocket.CloseMessage, []byte{}, time.Now().Add(time.Second))
-	return conn.Ws.Close()
-}
-
-// Create a new webSocketConn.
-func newWebSocketConn(ws *websocket.Conn) webSocketConn {
-	var conn webSocketConn
-	conn.Ws = ws
-	return conn
-}
-
-// Copy from WebSocket to socket and vice versa.
-func proxy(local *net.TCPConn, conn *webSocketConn) {
-	var wg sync.WaitGroup
-	wg.Add(2)
-
-	go func() {
-		if _, err := io.Copy(conn, local); err != nil {
-			log.Printf("error copying ORPort to WebSocket %v", err)
-		}
-		if err := local.CloseRead(); err != nil {
-			log.Printf("error closing read after copying ORPort to WebSocket %v", err)
-		}
-		conn.Close()
-		wg.Done()
-	}()
-	go func() {
-		if _, err := io.Copy(local, conn); err != nil {
-			log.Printf("error copying WebSocket to ORPort")
-		}
-		if err := local.CloseWrite(); err != nil {
-			log.Printf("error closing write after copying WebSocket to ORPort %v", err)
-		}
-		conn.Close()
-		wg.Done()
-	}()
-
-	wg.Wait()
-}
-
 // Return an address string suitable to pass into pt.DialOr.
 func clientAddr(clientIPParam string) string {
 	if clientIPParam == "" {
@@ -166,8 +75,8 @@ func (handler *HTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	conn := newWebSocketConn(ws)
-	defer conn.Close()
+	wsConn := websocketconn.NewWebSocketConn(ws)
+	defer wsConn.Close()
 
 	// Pass the address of client as the remote address of incoming connection
 	clientIPParam := r.URL.Query().Get("client_ip")
@@ -184,7 +93,7 @@ func (handler *HTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 	}
 	defer or.Close()
 
-	proxy(or, &conn)
+	websocketconn.CopyLoop(or, &wsConn)
 }
 
 func initServer(addr *net.TCPAddr,