This is an automated email from the git hooks/post-receive script.
meskio pushed a commit to branch main in repository pluggable-transports/snowflake.
commit c097d5f3bc9e95403006527b90207dfb11ce6438
Author: David Fifield <david@bamsoftware.com>
AuthorDate: Tue Apr 4 18:45:26 2023 -0600
Use a sync.Pool to reuse packet buffers in QueuePacketConn.
This is meant to reduce overall allocations. See past discussion at
https://gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowf... ff.
---
 common/turbotunnel/queuepacketconn.go      | 47 +++++++++++++++----
 common/turbotunnel/queuepacketconn_test.go | 72 ++++++++++++++++++++++++++++--
 server/lib/http.go                         |  5 ++-
 server/lib/snowflake.go                    |  6 ++-
 4 files changed, 116 insertions(+), 14 deletions(-)
diff --git a/common/turbotunnel/queuepacketconn.go b/common/turbotunnel/queuepacketconn.go index 5cdb559..6fcc3bf 100644 --- a/common/turbotunnel/queuepacketconn.go +++ b/common/turbotunnel/queuepacketconn.go @@ -27,23 +27,29 @@ type QueuePacketConn struct { recvQueue chan taggedPacket closeOnce sync.Once closed chan struct{} + mtu int + // Pool of reusable mtu-sized buffers. + bufPool sync.Pool // What error to return when the QueuePacketConn is closed. err atomic.Value }
// NewQueuePacketConn makes a new QueuePacketConn, set to track recent clients -// for at least a duration of timeout. -func NewQueuePacketConn(localAddr net.Addr, timeout time.Duration) *QueuePacketConn { +// for at least a duration of timeout. The maximum packet size is mtu. +func NewQueuePacketConn(localAddr net.Addr, timeout time.Duration, mtu int) *QueuePacketConn { return &QueuePacketConn{ clients: NewClientMap(timeout), localAddr: localAddr, recvQueue: make(chan taggedPacket, queueSize), closed: make(chan struct{}), + mtu: mtu, + bufPool: sync.Pool{New: func() interface{} { return make([]byte, mtu) }}, } }
// QueueIncoming queues an incoming packet and its source address, to be -// returned in a future call to ReadFrom. +// returned in a future call to ReadFrom. If p is longer than the MTU, only its +// first MTU bytes will be used. func (c *QueuePacketConn) QueueIncoming(p []byte, addr net.Addr) { select { case <-c.closed: @@ -52,12 +58,18 @@ func (c *QueuePacketConn) QueueIncoming(p []byte, addr net.Addr) { default: } // Copy the slice so that the caller may reuse it. - buf := make([]byte, len(p)) + buf := c.bufPool.Get().([]byte) + if len(p) < cap(buf) { + buf = buf[:len(p)] + } else { + buf = buf[:cap(buf)] + } copy(buf, p) select { case c.recvQueue <- taggedPacket{buf, addr}: default: // Drop the incoming packet if the receive queue is full. + c.Restore(buf) } }
@@ -68,6 +80,16 @@ func (c *QueuePacketConn) OutgoingQueue(addr net.Addr) <-chan []byte { return c.clients.SendQueue(addr) }
+// Restore adds a slice to the internal pool of packet buffers. Typically you +// will call this with a slice from the OutgoingQueue channel once you are done +// using it. (It is not an error to fail to do so, it will just result in more +// allocations.) +func (c *QueuePacketConn) Restore(p []byte) { + if cap(p) >= c.mtu { + c.bufPool.Put(p) + } +} + // ReadFrom returns a packet and address previously stored by QueueIncoming. func (c *QueuePacketConn) ReadFrom(p []byte) (int, net.Addr, error) { select { @@ -79,12 +101,15 @@ func (c *QueuePacketConn) ReadFrom(p []byte) (int, net.Addr, error) { case <-c.closed: return 0, nil, &net.OpError{Op: "read", Net: c.LocalAddr().Network(), Addr: c.LocalAddr(), Err: c.err.Load().(error)} case packet := <-c.recvQueue: - return copy(p, packet.P), packet.Addr, nil + n := copy(p, packet.P) + c.Restore(packet.P) + return n, packet.Addr, nil } }
// WriteTo queues an outgoing packet for the given address. The queue can later -// be retrieved using the OutgoingQueue method. +// be retrieved using the OutgoingQueue method. If p is longer than the MTU, +// only its first MTU bytes will be used. func (c *QueuePacketConn) WriteTo(p []byte, addr net.Addr) (int, error) { select { case <-c.closed: @@ -92,14 +117,20 @@ func (c *QueuePacketConn) WriteTo(p []byte, addr net.Addr) (int, error) { default: } // Copy the slice so that the caller may reuse it. - buf := make([]byte, len(p)) + buf := c.bufPool.Get().([]byte) + if len(p) < cap(buf) { + buf = buf[:len(p)] + } else { + buf = buf[:cap(buf)] + } copy(buf, p) select { case c.clients.SendQueue(addr) <- buf: return len(buf), nil default: // Drop the outgoing packet if the send queue is full. - return len(buf), nil + c.Restore(buf) + return len(p), nil } }
diff --git a/common/turbotunnel/queuepacketconn_test.go b/common/turbotunnel/queuepacketconn_test.go index 37f46bc..b9f62c9 100644 --- a/common/turbotunnel/queuepacketconn_test.go +++ b/common/turbotunnel/queuepacketconn_test.go @@ -23,7 +23,7 @@ func (i intAddr) String() string { return fmt.Sprintf("%d", i) }
// Run with -benchmem to see memory allocations. func BenchmarkQueueIncoming(b *testing.B) { - conn := NewQueuePacketConn(emptyAddr{}, 1*time.Hour) + conn := NewQueuePacketConn(emptyAddr{}, 1*time.Hour, 500) defer conn.Close()
b.ResetTimer() @@ -36,7 +36,7 @@ func BenchmarkQueueIncoming(b *testing.B) {
// BenchmarkWriteTo benchmarks the QueuePacketConn.WriteTo function. func BenchmarkWriteTo(b *testing.B) { - conn := NewQueuePacketConn(emptyAddr{}, 1*time.Hour) + conn := NewQueuePacketConn(emptyAddr{}, 1*time.Hour, 500) defer conn.Close()
b.ResetTimer() @@ -47,6 +47,72 @@ func BenchmarkWriteTo(b *testing.B) { b.StopTimer() }
+// TestQueueIncomingOversize tests that QueueIncoming truncates packets that are +// larger than the MTU. +func TestQueueIncomingOversize(t *testing.T) { + const payload = "abcdefghijklmnopqrstuvwxyz" + conn := NewQueuePacketConn(emptyAddr{}, 1*time.Hour, len(payload)-1) + defer conn.Close() + conn.QueueIncoming([]byte(payload), emptyAddr{}) + var p [500]byte + n, _, err := conn.ReadFrom(p[:]) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(p[:n], []byte(payload[:len(payload)-1])) { + t.Fatalf("payload was %+q, expected %+q", p[:n], payload[:len(payload)-1]) + } +} + +// TestWriteToOversize tests that WriteTo truncates packets that are larger than +// the MTU. +func TestWriteToOversize(t *testing.T) { + const payload = "abcdefghijklmnopqrstuvwxyz" + conn := NewQueuePacketConn(emptyAddr{}, 1*time.Hour, len(payload)-1) + defer conn.Close() + conn.WriteTo([]byte(payload), emptyAddr{}) + p := <-conn.OutgoingQueue(emptyAddr{}) + if !bytes.Equal(p, []byte(payload[:len(payload)-1])) { + t.Fatalf("payload was %+q, expected %+q", p, payload[:len(payload)-1]) + } +} + +// TestRestoreMTU tests that Restore ignores any inputs that are not at least +// MTU-sized. +func TestRestoreMTU(t *testing.T) { + const mtu = 500 + const payload = "hello" + conn := NewQueuePacketConn(emptyAddr{}, 1*time.Hour, mtu) + defer conn.Close() + conn.Restore(make([]byte, mtu-1)) + // This WriteTo may use the short slice we just gave to Restore. + conn.WriteTo([]byte(payload), emptyAddr{}) + // Read the queued slice and ensure its capacity is at least the MTU. + p := <-conn.OutgoingQueue(emptyAddr{}) + if cap(p) != mtu { + t.Fatalf("cap was %v, expected %v", cap(p), mtu) + } + // Check the payload while we're at it. + if !bytes.Equal(p, []byte(payload)) { + t.Fatalf("payload was %+q, expected %+q", p, payload) + } +} + +// TestRestoreCap tests that Restore can use slices whose cap is at least the +// MTU, even if the len is shorter. 
+func TestRestoreCap(t *testing.T) { + const mtu = 500 + const payload = "hello" + conn := NewQueuePacketConn(emptyAddr{}, 1*time.Hour, mtu) + defer conn.Close() + conn.Restore(make([]byte, 0, mtu)) + conn.WriteTo([]byte(payload), emptyAddr{}) + p := <-conn.OutgoingQueue(emptyAddr{}) + if !bytes.Equal(p, []byte(payload)) { + t.Fatalf("payload was %+q, expected %+q", p, payload) + } +} + // DiscardPacketConn is a net.PacketConn whose ReadFrom method block forever and // whose WriteTo method discards whatever it is called with. type DiscardPacketConn struct{} @@ -122,7 +188,7 @@ func TestQueuePacketConnWriteToKCP(t *testing.T) { } }()
- pconn := NewQueuePacketConn(emptyAddr{}, 1*time.Hour) + pconn := NewQueuePacketConn(emptyAddr{}, 1*time.Hour, 500) defer pconn.Close() addr1 := intAddr(1) outgoing := pconn.OutgoingQueue(addr1) diff --git a/server/lib/http.go b/server/lib/http.go index 3a01884..8c0343f 100644 --- a/server/lib/http.go +++ b/server/lib/http.go @@ -69,10 +69,10 @@ type httpHandler struct {
// newHTTPHandler creates a new http.Handler that exchanges encapsulated packets // over incoming WebSocket connections. -func newHTTPHandler(localAddr net.Addr, numInstances int) *httpHandler { +func newHTTPHandler(localAddr net.Addr, numInstances int, mtu int) *httpHandler { pconns := make([]*turbotunnel.QueuePacketConn, 0, numInstances) for i := 0; i < numInstances; i++ { - pconns = append(pconns, turbotunnel.NewQueuePacketConn(localAddr, clientMapTimeout)) + pconns = append(pconns, turbotunnel.NewQueuePacketConn(localAddr, clientMapTimeout, mtu)) }
clientIDLookupKey := make([]byte, 16) @@ -200,6 +200,7 @@ func (handler *httpHandler) turbotunnelMode(conn net.Conn, addr net.Addr) error return } _, err := encapsulation.WriteData(bw, p) + pconn.Restore(p) if err == nil { err = bw.Flush() } diff --git a/server/lib/snowflake.go b/server/lib/snowflake.go index c4d3fbc..3c3c440 100644 --- a/server/lib/snowflake.go +++ b/server/lib/snowflake.go @@ -79,7 +79,11 @@ func (t *Transport) Listen(addr net.Addr, numKCPInstances int) (*SnowflakeListen ln: make([]*kcp.Listener, 0, numKCPInstances), }
- handler := newHTTPHandler(addr, numKCPInstances) + // kcp-go doesn't provide an accessor for the current MTU setting (and + // anyway we could not create a kcp.Listener without creating a + // net.PacketConn for it first), so assume the default kcp.IKCP_MTU_DEF + // (1400 bytes) and don't increase it elsewhere. + handler := newHTTPHandler(addr, numKCPInstances, kcp.IKCP_MTU_DEF) server := &http.Server{ Addr: addr.String(), Handler: handler,