commit 50e4f4fd61596bab254cb34e850c9ae63d82f891 Author: idk <hankhill19580@gmail.com> Date: Mon Oct 25 22:51:40 2021 -0400
Turn the proxy code into a library
Allow other go programs to easily import the snowflake proxy library and start/stop a snowflake proxy. --- proxy/{ => lib}/proxy-go_test.go | 8 +- proxy/{ => lib}/snowflake.go | 205 ++++++++++++++++++++++----------------- proxy/{ => lib}/tokens.go | 2 +- proxy/{ => lib}/tokens_test.go | 2 +- proxy/{ => lib}/util.go | 18 +++- proxy/{ => lib}/webrtcconn.go | 2 +- proxy/main.go | 48 +++++++++ 7 files changed, 185 insertions(+), 100 deletions(-)
diff --git a/proxy/proxy-go_test.go b/proxy/lib/proxy-go_test.go similarity index 98% rename from proxy/proxy-go_test.go rename to proxy/lib/proxy-go_test.go index 6fb5a0b9..af71648 100644 --- a/proxy/proxy-go_test.go +++ b/proxy/lib/proxy-go_test.go @@ -1,4 +1,4 @@ -package main +package snowflake
import ( "bytes" @@ -365,7 +365,7 @@ func TestBrokerInteractions(t *testing.T) { b, }
- sdp := broker.pollOffer(sampleOffer) + sdp := broker.pollOffer(sampleOffer, nil) expectedSDP, _ := strconv.Unquote(sampleSDP) So(sdp.SDP, ShouldResemble, expectedSDP) }) @@ -379,7 +379,7 @@ func TestBrokerInteractions(t *testing.T) { b, }
- sdp := broker.pollOffer(sampleOffer) + sdp := broker.pollOffer(sampleOffer, nil) So(sdp, ShouldBeNil) }) Convey("sends answer to broker", func() { @@ -478,7 +478,7 @@ func TestUtilityFuncs(t *testing.T) { Convey("CopyLoop", t, func() { c1, s1 := net.Pipe() c2, s2 := net.Pipe() - go CopyLoop(s1, s2) + go copyLoop(s1, s2, nil) go func() { bytes := []byte("Hello!") c1.Write(bytes) diff --git a/proxy/snowflake.go b/proxy/lib/snowflake.go similarity index 72% rename from proxy/snowflake.go rename to proxy/lib/snowflake.go index 7d7f9a2..e35eabd 100644 --- a/proxy/snowflake.go +++ b/proxy/lib/snowflake.go @@ -1,10 +1,9 @@ -package main +package snowflake
import ( "bytes" "crypto/rand" "encoding/base64" - "flag" "fmt" "io" "io/ioutil" @@ -12,27 +11,44 @@ import ( "net" "net/http" "net/url" - "os" "strings" "sync" "time"
"git.torproject.org/pluggable-transports/snowflake.git/common/messages" - "git.torproject.org/pluggable-transports/snowflake.git/common/safelog" "git.torproject.org/pluggable-transports/snowflake.git/common/util" "git.torproject.org/pluggable-transports/snowflake.git/common/websocketconn" "github.com/gorilla/websocket" "github.com/pion/webrtc/v3" )
-const defaultBrokerURL = "https://snowflake-broker.torproject.net/" -const defaultProbeURL = "https://snowflake-broker.torproject.net:8443/probe" -const defaultRelayURL = "wss://snowflake.torproject.net/" -const defaultSTUNURL = "stun:stun.stunprotocol.org:3478" +// DefaultBrokerURL is the torproject.org broker, https://snowflake-broker.torproject.net/ +// It is the default value of the proxy's -broker flag. Applications embedding the +// library choose a broker by setting SnowflakeProxy.RawBrokerURL. +const DefaultBrokerURL = "https://snowflake-broker.torproject.net/" + +// DefaultProbeURL is the torproject.org NAT probe URL, https://snowflake-broker.torproject.net:8443/probe +// SnowflakeProxy.Start always uses it to determine the proxy's NAT type; there is +// currently no SnowflakeProxy field for choosing a different probe server. +const DefaultProbeURL = "https://snowflake-broker.torproject.net:8443/probe" + +// DefaultRelayURL is the torproject.org WebSocket relay, wss://snowflake.torproject.net/ +// It is the default value of the proxy's -relay flag. Applications embedding the +// library choose a relay by setting SnowflakeProxy.RelayURL. +const DefaultRelayURL = "wss://snowflake.torproject.net/" + +// DefaultSTUNURL is the stunprotocol.org STUN server, stun:stun.stunprotocol.org:3478 +// It is the default value of the proxy's -stun flag. Applications embedding the +// library choose a STUN server by setting SnowflakeProxy.StunURL. +const DefaultSTUNURL = "stun:stun.stunprotocol.org:3478" const pollInterval = 5 * time.Second + const ( - NATUnknown = "unknown" - NATRestricted = "restricted" + // NATUnknown represents a NAT type which is unknown. + NATUnknown = "unknown" + // NATRestricted represents a restricted NAT. + NATRestricted = "restricted" + // NATUnrestricted represents an unrestricted NAT. NATUnrestricted = "unrestricted" )
@@ -43,7 +59,6 @@ const dataChannelTimeout = 20 * time.Second const readLimit = 100000 //Maximum number of bytes to be read from an HTTP request
var broker *SignalingServer -var relayURL string
var currentNATType = NATUnknown
@@ -57,6 +72,18 @@ var ( client http.Client )
+// SnowflakeProxy configures a Snowflake proxy embedded in another Go application. +// Set the exported fields, then call Start (which blocks) and Stop from another goroutine. LogOutput is not read by this package. +type SnowflakeProxy struct { + Capacity uint + StunURL string + RawBrokerURL string + KeepLocalAddresses bool + RelayURL string + LogOutput io.Writer + shutdown chan struct{} +} + // Checks whether an IP address is a remote address for the client func isRemoteAddress(ip net.IP) bool { return !(util.IsLocal(ip) || ip.IsUnspecified() || ip.IsLoopback()) @@ -81,6 +108,7 @@ func limitedRead(r io.Reader, limit int64) ([]byte, error) { return p, err }
+// SignalingServer is an HTTP client for a Snowflake signaling endpoint (the broker or the NAT probe server). type SignalingServer struct { url *url.URL transport http.RoundTripper @@ -102,6 +130,7 @@ func newSignalingServer(rawURL string, keepLocalAddresses bool) (*SignalingServe return s, nil }
+// Post sends payload to path in an HTTP POST request and returns the response body, reading no more than readLimit bytes. func (s *SignalingServer) Post(path string, payload io.Reader) ([]byte, error) {
req, err := http.NewRequest("POST", path, payload) @@ -121,7 +150,7 @@ func (s *SignalingServer) Post(path string, payload io.Reader) ([]byte, error) { return limitedRead(resp.Body, readLimit) }
-func (s *SignalingServer) pollOffer(sid string) *webrtc.SessionDescription { +func (s *SignalingServer) pollOffer(sid string, shutdown chan struct{}) *webrtc.SessionDescription { brokerPath := s.url.ResolveReference(&url.URL{Path: "proxy"})
ticker := time.NewTicker(pollInterval) @@ -129,31 +158,36 @@ func (s *SignalingServer) pollOffer(sid string) *webrtc.SessionDescription {
// Run the loop once before hitting the ticker for ; true; <-ticker.C { - numClients := int((tokens.count() / 8) * 8) // Round down to 8 - body, err := messages.EncodePollRequest(sid, "standalone", currentNATType, numClients) - if err != nil { - log.Printf("Error encoding poll message: %s", err.Error()) + select { + case <-shutdown: return nil - } - resp, err := s.Post(brokerPath.String(), bytes.NewBuffer(body)) - if err != nil { - log.Printf("error polling broker: %s", err.Error()) - } + default: + numClients := int((tokens.count() / 8) * 8) // Round down to 8 + body, err := messages.EncodePollRequest(sid, "standalone", currentNATType, numClients) + if err != nil { + log.Printf("Error encoding poll message: %s", err.Error()) + return nil + } + resp, err := s.Post(brokerPath.String(), bytes.NewBuffer(body)) + if err != nil { + log.Printf("error polling broker: %s", err.Error()) + }
- offer, _, err := messages.DecodePollResponse(resp) - if err != nil { - log.Printf("Error reading broker response: %s", err.Error()) - log.Printf("body: %s", resp) - return nil - } - if offer != "" { - offer, err := util.DeserializeSessionDescription(offer) + offer, _, err := messages.DecodePollResponse(resp) if err != nil { - log.Printf("Error processing session description: %s", err.Error()) + log.Printf("Error reading broker response: %s", err.Error()) + log.Printf("body: %s", resp) return nil } - return offer + if offer != "" { + offer, err := util.DeserializeSessionDescription(offer) + if err != nil { + log.Printf("Error processing session description: %s", err.Error()) + return nil + } + return offer
+ } } } return nil @@ -192,33 +226,41 @@ func (s *SignalingServer) sendAnswer(sid string, pc *webrtc.PeerConnection) erro return nil }
-func CopyLoop(c1 io.ReadWriteCloser, c2 io.ReadWriteCloser) { - var wg sync.WaitGroup +func copyLoop(c1 io.ReadWriteCloser, c2 io.ReadWriteCloser, shutdown chan struct{}) { + var once sync.Once // guards close(done): only the first copyer to finish closes it + defer c2.Close() + defer c1.Close() + done := make(chan struct{}) // closed as soon as either copy direction ends copyer := func(dst io.ReadWriteCloser, src io.ReadWriteCloser) { - defer wg.Done() // Ignore io.ErrClosedPipe because it is likely caused by the // termination of copyer in the other direction. if _, err := io.Copy(dst, src); err != nil && err != io.ErrClosedPipe { log.Printf("io.Copy inside CopyLoop generated an error: %v", err) } - dst.Close() - src.Close() + once.Do(func() { + close(done) + }) } - wg.Add(2) + go copyer(c1, c2) go copyer(c2, c1) - wg.Wait() + + select { + case <-done: + case <-shutdown: // never ready when shutdown is nil, so only done ends the loop + } + log.Println("copy loop ended") }
// We pass conn.RemoteAddr() as an additional parameter, rather than calling // conn.RemoteAddr() inside this function, as a workaround for a hang that // otherwise occurs inside of conn.pc.RemoteDescription() (called by // RemoteAddr). https://bugs.torproject.org/18628#comment:8 -func datachannelHandler(conn *webRTCConn, remoteAddr net.Addr) { +func (sf *SnowflakeProxy) datachannelHandler(conn *webRTCConn, remoteAddr net.Addr) { defer conn.Close() defer tokens.ret()
- u, err := url.Parse(relayURL) + u, err := url.Parse(sf.RelayURL) if err != nil { log.Fatalf("invalid relay url: %s", err) } @@ -241,7 +283,7 @@ func datachannelHandler(conn *webRTCConn, remoteAddr net.Addr) { wsConn := websocketconn.New(ws) log.Printf("connected to relay") defer wsConn.Close() - CopyLoop(conn, wsConn) + copyLoop(conn, wsConn, sf.shutdown) log.Printf("datachannelHandler ends") }
@@ -249,7 +291,7 @@ func datachannelHandler(conn *webRTCConn, remoteAddr net.Addr) { // candidates is complete and the answer is available in LocalDescription. // Installs an OnDataChannel callback that creates a webRTCConn and passes it to // datachannelHandler. -func makePeerConnectionFromOffer(sdp *webrtc.SessionDescription, +func (sf *SnowflakeProxy) makePeerConnectionFromOffer(sdp *webrtc.SessionDescription, config webrtc.Configuration, dataChan chan struct{}, handler func(conn *webRTCConn, remoteAddr net.Addr)) (*webrtc.PeerConnection, error) { @@ -333,7 +375,7 @@ func makePeerConnectionFromOffer(sdp *webrtc.SessionDescription,
// Create a new PeerConnection. Blocks until the gathering of ICE // candidates is complete and the answer is available in LocalDescription. -func makeNewPeerConnection(config webrtc.Configuration, +func (sf *SnowflakeProxy) makeNewPeerConnection(config webrtc.Configuration, dataChan chan struct{}) (*webrtc.PeerConnection, error) {
pc, err := webrtc.NewPeerConnection(config) @@ -383,15 +425,15 @@ func makeNewPeerConnection(config webrtc.Configuration, return pc, nil }
-func runSession(sid string) { - offer := broker.pollOffer(sid) +func (sf *SnowflakeProxy) runSession(sid string) { + offer := broker.pollOffer(sid, sf.shutdown) if offer == nil { log.Printf("bad offer from broker") tokens.ret() return } dataChan := make(chan struct{}) - pc, err := makePeerConnectionFromOffer(offer, config, dataChan, datachannelHandler) + pc, err := sf.makePeerConnectionFromOffer(offer, config, dataChan, sf.datachannelHandler) if err != nil { log.Printf("error making WebRTC connection: %s", err) tokens.ret() @@ -421,53 +463,28 @@ func runSession(sid string) { } }
-func main() { - var capacity uint - var stunURL string - var logFilename string - var rawBrokerURL string - var unsafeLogging bool - var keepLocalAddresses bool - - flag.UintVar(&capacity, "capacity", 0, "maximum concurrent clients") - flag.StringVar(&rawBrokerURL, "broker", defaultBrokerURL, "broker URL") - flag.StringVar(&relayURL, "relay", defaultRelayURL, "websocket relay URL") - flag.StringVar(&stunURL, "stun", defaultSTUNURL, "stun URL") - flag.StringVar(&logFilename, "log", "", "log filename") - flag.BoolVar(&unsafeLogging, "unsafe-logging", false, "prevent logs from being scrubbed") - flag.BoolVar(&keepLocalAddresses, "keep-local-addresses", false, "keep local LAN address ICE candidates") - flag.Parse() - - var logOutput io.Writer = os.Stderr +// Start configures and runs the Snowflake proxy, blocking until Stop is called. +// The exported fields are used as given (empty values are not replaced by the +// Default* constants), and invalid settings abort the process via log.Fatal. +func (sf *SnowflakeProxy) Start() { + + sf.shutdown = make(chan struct{}) + log.SetFlags(log.LstdFlags | log.LUTC) - if logFilename != "" { - f, err := os.OpenFile(logFilename, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600) - if err != nil { - log.Fatal(err) - } - defer f.Close() - logOutput = io.MultiWriter(os.Stderr, f) - } - if unsafeLogging { - log.SetOutput(logOutput) - } else { - // We want to send the log output through our scrubber first - log.SetOutput(&safelog.LogScrubber{Output: logOutput}) - }
log.Println("starting")
var err error - broker, err = newSignalingServer(rawBrokerURL, keepLocalAddresses) + broker, err = newSignalingServer(sf.RawBrokerURL, sf.KeepLocalAddresses) if err != nil { log.Fatal(err) }
- _, err = url.Parse(stunURL) + _, err = url.Parse(sf.StunURL) if err != nil { log.Fatalf("invalid stun url: %s", err) } - _, err = url.Parse(relayURL) + _, err = url.Parse(sf.RelayURL) if err != nil { log.Fatalf("invalid relay url: %s", err) } @@ -475,27 +492,37 @@ func main() { config = webrtc.Configuration{ ICEServers: []webrtc.ICEServer{ { - URLs: []string{stunURL}, + URLs: []string{sf.StunURL}, }, }, } - tokens = newTokens(capacity) + tokens = newTokens(sf.Capacity)
// use probetest to determine NAT compatability - checkNATType(config, defaultProbeURL) + sf.checkNATType(config, DefaultProbeURL) log.Printf("NAT type: %s", currentNATType)
ticker := time.NewTicker(pollInterval) defer ticker.Stop()
for ; true; <-ticker.C { - tokens.get() - sessionID := genSessionID() - runSession(sessionID) + select { + case <-sf.shutdown: + return + default: + tokens.get() + sessionID := genSessionID() + sf.runSession(sessionID) + } } }
-func checkNATType(config webrtc.Configuration, probeURL string) { +// Stop shuts the Snowflake down by closing sf.shutdown. Call it at most once, and only after Start has created sf.shutdown: closing a nil or already-closed channel panics. +func (sf *SnowflakeProxy) Stop() { + close(sf.shutdown) +} + +func (sf *SnowflakeProxy) checkNATType(config webrtc.Configuration, probeURL string) {
probe, err := newSignalingServer(probeURL, false) if err != nil { @@ -504,7 +531,7 @@ func checkNATType(config webrtc.Configuration, probeURL string) {
// create offer dataChan := make(chan struct{}) - pc, err := makeNewPeerConnection(config, dataChan) + pc, err := sf.makeNewPeerConnection(config, dataChan) if err != nil { log.Printf("error making WebRTC connection: %s", err) return diff --git a/proxy/tokens.go b/proxy/lib/tokens.go similarity index 97% rename from proxy/tokens.go rename to proxy/lib/tokens.go index fedb8f7..1331778 100644 --- a/proxy/tokens.go +++ b/proxy/lib/tokens.go @@ -1,4 +1,4 @@ -package main +package snowflake
import ( "sync/atomic" diff --git a/proxy/tokens_test.go b/proxy/lib/tokens_test.go similarity index 96% rename from proxy/tokens_test.go rename to proxy/lib/tokens_test.go index 622cc05..702a887 100644 --- a/proxy/tokens_test.go +++ b/proxy/lib/tokens_test.go @@ -1,4 +1,4 @@ -package main +package snowflake
import ( "testing" diff --git a/proxy/util.go b/proxy/lib/util.go similarity index 71% rename from proxy/util.go rename to proxy/lib/util.go index d737056..c6613d9 100644 --- a/proxy/util.go +++ b/proxy/lib/util.go @@ -1,21 +1,28 @@ -package main +package snowflake
import ( "fmt" "time" )
+// BytesLogger records the number of bytes a Snowflake sends and receives and +// summarizes its throughput. BytesNullLogger is an implementation that does nothing. type BytesLogger interface { AddOutbound(int) AddInbound(int) ThroughputSummary() string }
-// Default BytesLogger does nothing. +// BytesNullLogger is a BytesLogger that ignores all byte counts and reports nothing. type BytesNullLogger struct{}
-func (b BytesNullLogger) AddOutbound(amount int) {} -func (b BytesNullLogger) AddInbound(amount int) {} +// AddOutbound discards amount; BytesNullLogger does not track outbound bytes. +func (b BytesNullLogger) AddOutbound(amount int) {} + +// AddInbound discards amount; BytesNullLogger does not track inbound bytes. +func (b BytesNullLogger) AddInbound(amount int) {} + +// ThroughputSummary always returns the empty string for BytesNullLogger. func (b BytesNullLogger) ThroughputSummary() string { return "" }
// BytesSyncLogger uses channels to safely log from multiple sources with output @@ -50,14 +57,17 @@ func (b *BytesSyncLogger) log() { } }
+// AddOutbound adds amount bytes to the outbound total reported by the logger by sending it on b.outboundChan. func (b *BytesSyncLogger) AddOutbound(amount int) { b.outboundChan <- amount }
+// AddInbound adds amount bytes to the inbound total reported by the logger by sending it on b.inboundChan. func (b *BytesSyncLogger) AddInbound(amount int) { b.inboundChan <- amount }
+// ThroughputSummary returns a human-readable summary of the inbound and outbound throughput totals. func (b *BytesSyncLogger) ThroughputSummary() string { var inUnit, outUnit string units := []string{"B", "KB", "MB", "GB"} diff --git a/proxy/webrtcconn.go b/proxy/lib/webrtcconn.go similarity index 99% rename from proxy/webrtcconn.go rename to proxy/lib/webrtcconn.go index 5d95919..5c6192b 100644 --- a/proxy/webrtcconn.go +++ b/proxy/lib/webrtcconn.go @@ -1,4 +1,4 @@ -package main +package snowflake
import ( "fmt" diff --git a/proxy/main.go b/proxy/main.go new file mode 100644 index 0000000..12b3752 --- /dev/null +++ b/proxy/main.go @@ -0,0 +1,48 @@ +package main + +import ( + "flag" + "io" + "log" + "os" + + "git.torproject.org/pluggable-transports/snowflake.git/common/safelog" + "git.torproject.org/pluggable-transports/snowflake.git/proxy/lib" +) + +func main() { + capacity := flag.Uint("capacity", 0, "maximum concurrent clients") // uint flag with the pre-library default; a negative value is rejected instead of wrapping + stunURL := flag.String("stun", snowflake.DefaultSTUNURL, "stun URL") + logFilename := flag.String("log", "", "log filename") + rawBrokerURL := flag.String("broker", snowflake.DefaultBrokerURL, "broker URL") + unsafeLogging := flag.Bool("unsafe-logging", false, "prevent logs from being scrubbed") + keepLocalAddresses := flag.Bool("keep-local-addresses", false, "keep local LAN address ICE candidates") + relayURL := flag.String("relay", snowflake.DefaultRelayURL, "websocket relay URL") + + flag.Parse() + + sf := snowflake.SnowflakeProxy{ + Capacity: *capacity, + StunURL: *stunURL, + RawBrokerURL: *rawBrokerURL, + KeepLocalAddresses: *keepLocalAddresses, + RelayURL: *relayURL, + LogOutput: os.Stderr, + } + + if *logFilename != "" { + f, err := os.OpenFile(*logFilename, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600) + if err != nil { + log.Fatal(err) + } + defer f.Close() + sf.LogOutput = io.MultiWriter(os.Stderr, f) + } + if *unsafeLogging { + log.SetOutput(sf.LogOutput) + } else { + log.SetOutput(&safelog.LogScrubber{Output: sf.LogOutput}) // send log output through the scrubber first + } + + sf.Start() // runs the proxy loop; returns only once the proxy is stopped +}