commit ddcdfc4f0922e00c672797b0d6544371423f2989 Author: Cecylia Bocovich cohosh@torproject.org Date: Thu Jun 17 16:36:50 2021 -0400
Fix data race for WebRTCPeer.closed
The race condition occurs because concurrent goroutines are intermixing reads and writes of `WebRTCPeer.closed`.
Spotted when integrating Snowflake inside OONI in https://github.com/ooni/probe-cli/pull/373. --- client/lib/lib_test.go | 6 +++--- client/lib/peers.go | 4 ++-- client/lib/webrtc.go | 24 ++++++++++++++++++------ 3 files changed, 23 insertions(+), 11 deletions(-)
diff --git a/client/lib/lib_test.go b/client/lib/lib_test.go index e742e06..55ea7b9 100644 --- a/client/lib/lib_test.go +++ b/client/lib/lib_test.go @@ -33,7 +33,7 @@ type FakeDialer struct {
func (w FakeDialer) Catch() (*WebRTCPeer, error) { fmt.Println("Caught a dummy snowflake.") - return &WebRTCPeer{}, nil + return &WebRTCPeer{closed: make(chan struct{})}, nil }
func (w FakeDialer) GetMax() int { @@ -97,7 +97,7 @@ func TestSnowflakeClient(t *testing.T) { So(err, ShouldNotBeNil) So(p.Count(), ShouldEqual, c)
- // But popping and closing allows it to continue. + // But popping allows it to continue. s := p.Pop() s.Close() So(s, ShouldNotBeNil) @@ -127,7 +127,7 @@ func TestSnowflakeClient(t *testing.T) { cnt := 5 p, _ := NewPeers(FakeDialer{max: cnt}) for i := 0; i < cnt; i++ { - p.activePeers.PushBack(&WebRTCPeer{}) + p.activePeers.PushBack(&WebRTCPeer{closed: make(chan struct{})}) } So(p.Count(), ShouldEqual, cnt) p.End() diff --git a/client/lib/peers.go b/client/lib/peers.go index d02eed3..6fa2d29 100644 --- a/client/lib/peers.go +++ b/client/lib/peers.go @@ -83,7 +83,7 @@ func (p *Peers) Pop() *WebRTCPeer { if !ok { return nil } - if snowflake.closed { + if snowflake.Closed() { continue } // Set to use the same rate-limited traffic logger to keep consistency. @@ -110,7 +110,7 @@ func (p *Peers) purgeClosedPeers() { next := e.Next() conn := e.Value.(*WebRTCPeer) // Purge those marked for deletion. - if conn.closed { + if conn.Closed() { p.activePeers.Remove(e) } e = next diff --git a/client/lib/webrtc.go b/client/lib/webrtc.go index 6a42ebd..234f53c 100644 --- a/client/lib/webrtc.go +++ b/client/lib/webrtc.go @@ -28,7 +28,7 @@ type WebRTCPeer struct { lastReceive time.Time
open chan struct{} // Channel to notify when datachannel opens - closed bool + closed chan struct{}
once sync.Once // Synchronization for PeerConnection destruction
@@ -46,6 +46,7 @@ func NewWebRTCPeer(config *webrtc.Configuration, } connection.id = "snowflake-" + hex.EncodeToString(buf[:]) } + connection.closed = make(chan struct{})
// Override with something that's not NullLogger to have real logging. connection.BytesLogger = &BytesNullLogger{} @@ -78,9 +79,19 @@ func (c *WebRTCPeer) Write(b []byte) (int, error) { return len(b), nil }
+//Returns a boolean indicated whether the peer is closed +func (c *WebRTCPeer) Closed() bool { + select { + case <-c.closed: + return true + default: + } + return false +} + func (c *WebRTCPeer) Close() error { c.once.Do(func() { - c.closed = true + close(c.closed) c.cleanup() log.Printf("WebRTC: Closing") }) @@ -95,9 +106,6 @@ func (c *WebRTCPeer) checkForStaleness() { c.lastReceive = time.Now() c.mu.Unlock() for { - if c.closed { - return - } c.mu.Lock() lastReceive := c.lastReceive c.mu.Unlock() @@ -107,7 +115,11 @@ func (c *WebRTCPeer) checkForStaleness() { c.Close() return } - <-time.After(time.Second) + select { + case <-c.closed: + return + case <-time.After(time.Second): + } } }