commit 222ab3d85a4113088db3e3b742411806922c028c Author: David Fifield david@bamsoftware.com Date: Tue Jan 28 02:29:34 2020 -0700
Import Turbo Tunnel support code.
Copied and slightly modified from https://gitweb.torproject.org/pluggable-transports/meek.git/log/?h=turbotunn... https://github.com/net4people/bbs/issues/21
RedialPacketConn is adapted from clientPacketConn in https://dip.torproject.org/dcf/obfs4/blob/c64a61c6da3bf1c2f98221bb1e1af8a358... https://github.com/net4people/bbs/issues/14#issuecomment-544747519 --- common/encapsulation/encapsulation.go | 194 +++++++++++++++++ common/encapsulation/encapsulation_test.go | 330 +++++++++++++++++++++++++++++ common/turbotunnel/clientid.go | 28 +++ common/turbotunnel/clientmap.go | 144 +++++++++++++ common/turbotunnel/consts.go | 13 ++ common/turbotunnel/queuepacketconn.go | 137 ++++++++++++ common/turbotunnel/redialpacketconn.go | 204 ++++++++++++++++++ 7 files changed, 1050 insertions(+)
diff --git a/common/encapsulation/encapsulation.go b/common/encapsulation/encapsulation.go new file mode 100644 index 0000000..bfe9b5b --- /dev/null +++ b/common/encapsulation/encapsulation.go @@ -0,0 +1,194 @@ +// Package encapsulation implements a way of encoding variable-size chunks of +// data and padding into a byte stream. +// +// Each chunk of data or padding starts with a variable-size length prefix. One +// bit ("d") in the first byte of the prefix indicates whether the chunk +// represents data or padding (1=data, 0=padding). Another bit ("c" for +// "continuation") indicates whether there are more bytes in the length +// prefix. The remaining 6 bits ("x") encode part of the length value. +// dcxxxxxx +// If the continuation bit is set, then the next byte is also part of the length +// prefix. It lacks the "d" bit, has its own "c" bit, and 7 value-carrying bits +// ("y"). +// cyyyyyyy +// The length is decoded by concatenating the value-carrying bits, from left +// to right, of all prefix bytes up to and including the first byte whose +// "c" bit is 0. Although in principle this encoding would allow for length +// prefixes of any size, length prefixes are arbitrarily limited to 3 bytes and +// any attempt to read or write a longer one is an error. These are therefore +// the only valid formats: +// 00xxxxxx xxxxxx₂ bytes of padding +// 10xxxxxx xxxxxx₂ bytes of data +// 01xxxxxx 0yyyyyyy xxxxxxyyyyyyy₂ bytes of padding +// 11xxxxxx 0yyyyyyy xxxxxxyyyyyyy₂ bytes of data +// 01xxxxxx 1yyyyyyy 0zzzzzzz xxxxxxyyyyyyyzzzzzzz₂ bytes of padding +// 11xxxxxx 1yyyyyyy 0zzzzzzz xxxxxxyyyyyyyzzzzzzz₂ bytes of data +// The maximum encodable length is 11111111111111111111₂ = 0xfffff = 1048575. +// There is no requirement to use a length prefix of minimum size; i.e. 00000100 +// and 01000000 00000100 are both valid encodings of the value 4. +// +// After the length prefix follow that many bytes of padding or data. 
There are +// no restrictions on the value of bytes comprising padding. +// +// The idea for this encapsulation is sketched here: +// https://github.com/net4people/bbs/issues/9#issuecomment-524095186 +package encapsulation + +import ( + "errors" + "io" + "io/ioutil" +) + +// ErrTooLong is the error returned when an encoded length prefix is longer than +// 3 bytes, or when WriteData receives an input whose length is too large to +// encode in a 3-byte length prefix. +var ErrTooLong = errors.New("length prefix is too long") + +// ReadData returns a new slice with the contents of the next available data +// chunk, skipping over any padding chunks that may come first. The returned +// error value is nil if and only if a data chunk was present and was read in +// its entirety. The returned error is io.EOF only if r ended before the first +// byte of a length prefix. If r ended in the middle of a length prefix or +// data/padding, the returned error is io.ErrUnexpectedEOF. +func ReadData(r io.Reader) ([]byte, error) { + for { + var b [1]byte + _, err := r.Read(b[:]) + if err != nil { + // This is the only place we may return a real io.EOF. + return nil, err + } + isData := (b[0] & 0x80) != 0 + moreLength := (b[0] & 0x40) != 0 + n := int(b[0] & 0x3f) + for i := 0; moreLength; i++ { + if i >= 2 { + return nil, ErrTooLong + } + _, err := r.Read(b[:]) + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + if err != nil { + return nil, err + } + moreLength = (b[0] & 0x80) != 0 + n = (n << 7) | int(b[0]&0x7f) + } + if isData { + p := make([]byte, n) + _, err := io.ReadFull(r, p) + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + if err != nil { + return nil, err + } + return p, err + } else { + _, err := io.CopyN(ioutil.Discard, r, int64(n)) + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + if err != nil { + return nil, err + } + } + } +} + +// dataPrefixForLength returns a length prefix for the given length, with the +// "d" bit set to 1. 
+func dataPrefixForLength(n int) ([]byte, error) { + switch { + case (n>>0)&0x3f == (n >> 0): + return []byte{0x80 | byte((n>>0)&0x3f)}, nil + case (n>>7)&0x3f == (n >> 7): + return []byte{0xc0 | byte((n>>7)&0x3f), byte((n >> 0) & 0x7f)}, nil + case (n>>14)&0x3f == (n >> 14): + return []byte{0xc0 | byte((n>>14)&0x3f), 0x80 | byte((n>>7)&0x7f), byte((n >> 0) & 0x7f)}, nil + default: + return nil, ErrTooLong + } +} + +// WriteData encodes a data chunk into w. It returns the total number of bytes +// written; i.e., including the length prefix. The error is ErrTooLong if the +// length of data cannot fit into a length prefix. +func WriteData(w io.Writer, data []byte) (int, error) { + prefix, err := dataPrefixForLength(len(data)) + if err != nil { + return 0, err + } + total := 0 + n, err := w.Write(prefix) + total += n + if err != nil { + return total, err + } + n, err = w.Write(data) + total += n + return total, err +} + +var paddingBuffer = make([]byte, 1024) + +// WritePadding encodes padding chunks, whose total size (including their own +// length prefixes) is n. Returns the total number of bytes written to w, which +// will be exactly n unless there was an error. The error cannot be ErrTooLong +// because this function will write multiple padding chunks if necessary to +// reach the requested size. Panics if n is negative. 
+func WritePadding(w io.Writer, n int) (int, error) { + if n < 0 { + panic("negative length") + } + total := 0 + for n > 0 { + p := len(paddingBuffer) + if p > n { + p = n + } + n -= p + var prefix []byte + switch { + case ((p-1)>>0)&0x3f == ((p - 1) >> 0): + p = p - 1 + prefix = []byte{byte((p >> 0) & 0x3f)} + case ((p-2)>>7)&0x3f == ((p - 2) >> 7): + p = p - 2 + prefix = []byte{0x40 | byte((p>>7)&0x3f), byte((p >> 0) & 0x7f)} + case ((p-3)>>14)&0x3f == ((p - 3) >> 14): + p = p - 3 + prefix = []byte{0x40 | byte((p>>14)&0x3f), 0x80 | byte((p>>7)&0x7f), byte((p >> 0) & 0x7f)} // the middle byte carries 7 value bits, as in dataPrefixForLength + } + nn, err := w.Write(prefix) + total += nn + if err != nil { + return total, err + } + nn, err = w.Write(paddingBuffer[:p]) + total += nn + if err != nil { + return total, err + } + } + return total, nil +} + +// MaxDataForSize returns the length of the longest slice that can be passed to +// WriteData, whose total encoded size (including length prefix) is no larger +// than n. Call this to find out if a chunk of data will fit into a length +// budget. Panics if n == 0. +func MaxDataForSize(n int) int { + if n == 0 { + panic("zero length") + } + prefix, err := dataPrefixForLength(n) + if err == ErrTooLong { + return (1 << (6 + 7 + 7)) - 1 - 3 + } else if err != nil { + panic(err) + } + return n - len(prefix) +} diff --git a/common/encapsulation/encapsulation_test.go b/common/encapsulation/encapsulation_test.go new file mode 100644 index 0000000..333abb4 --- /dev/null +++ b/common/encapsulation/encapsulation_test.go @@ -0,0 +1,330 @@ +package encapsulation + +import ( + "bytes" + "io" + "math/rand" + "testing" +) + +// Return a byte slice with non-trivial contents. 
+func pseudorandomBuffer(n int) []byte { + source := rand.NewSource(0) + p := make([]byte, n) + for i := 0; i < len(p); i++ { + p[i] = byte(source.Int63() & 0xff) + } + return p +} + +func mustWriteData(w io.Writer, p []byte) int { + n, err := WriteData(w, p) + if err != nil { + panic(err) + } + return n +} + +func mustWritePadding(w io.Writer, n int) int { + n, err := WritePadding(w, n) + if err != nil { + panic(err) + } + return n +} + +// Test that ReadData(WriteData()) recovers the original data. +func TestRoundtrip(t *testing.T) { + // Test above and below interesting thresholds. + for _, i := range []int{ + 0x00, 0x01, + 0x3e, 0x3f, 0x40, 0x41, + 0xfe, 0xff, 0x100, 0x101, + 0x1ffe, 0x1fff, 0x2000, 0x2001, + 0xfffe, 0xffff, 0x10000, 0x10001, + 0xffffe, 0xfffff, + } { + original := pseudorandomBuffer(i) + var enc bytes.Buffer + n, err := WriteData(&enc, original) + if err != nil { + t.Fatalf("size %d, WriteData returned error %v", i, err) + } + if enc.Len() != n { + t.Fatalf("size %d, returned length was %d, written length was %d", + i, n, enc.Len()) + } + inverse, err := ReadData(&enc) + if err != nil { + t.Fatalf("size %d, ReadData returned error %v", i, err) + } + if !bytes.Equal(inverse, original) { + t.Fatalf("size %d, got <%x>, expected <%x>", i, inverse, original) + } + } +} + +// Test that WritePadding writes exactly as much as requested. +func TestPaddingLength(t *testing.T) { + // Test above and below interesting thresholds. WritePadding also gets + // values above 0xfffff, the maximum value of a single length prefix. 
+ for _, i := range []int{ + 0x00, 0x01, + 0x3f, 0x40, 0x41, 0x42, + 0xff, 0x100, 0x101, 0x102, + 0x2000, 0x2001, 0x2002, 0x2003, + 0x10000, 0x10001, 0x10002, 0x10003, + 0x100001, 0x100002, 0x100003, 0x100004, + } { + var enc bytes.Buffer + n, err := WritePadding(&enc, i) + if err != nil { + t.Fatalf("size %d, WritePadding returned error %v", i, err) + } + if n != i { + t.Fatalf("requested %d bytes, returned %d", i, n) + } + if enc.Len() != n { + t.Fatalf("requested %d bytes, wrote %d bytes", i, enc.Len()) + } + } +} + +// Test that ReadData skips over padding. +func TestSkipPadding(t *testing.T) { + var data = [][]byte{{}, {}, []byte("hello"), {}, []byte("world")} + var enc bytes.Buffer + mustWritePadding(&enc, 10) + mustWritePadding(&enc, 100) + mustWriteData(&enc, data[0]) + mustWriteData(&enc, data[1]) + mustWritePadding(&enc, 10) + mustWriteData(&enc, data[2]) + mustWriteData(&enc, data[3]) + mustWritePadding(&enc, 10) + mustWriteData(&enc, data[4]) + mustWritePadding(&enc, 10) + mustWritePadding(&enc, 10) + for i, expected := range data { + actual, err := ReadData(&enc) + if err != nil { + t.Fatalf("slice %d, got error %v, expected %v", i, err, nil) + } + if !bytes.Equal(actual, expected) { + t.Fatalf("slice %d, got <%x>, expected <%x>", i, actual, expected) + } + } + p, err := ReadData(&enc) + if p != nil || err != io.EOF { + t.Fatalf("got (<%x>, %v), expected (%v, %v)", p, err, nil, io.EOF) + } +} + +// Test that EOF before a length prefix returns io.EOF. +func TestEOF(t *testing.T) { + p, err := ReadData(bytes.NewReader(nil)) + if p != nil || err != io.EOF { + t.Fatalf("got (<%x>, %v), expected (%v, %v)", p, err, nil, io.EOF) + } +} + +// Test that an EOF while reading a length prefix, or while reading the +// subsequent data/padding, returns io.ErrUnexpectedEOF. 
+func TestUnexpectedEOF(t *testing.T) { + for _, test := range [][]byte{ + {0x40}, // expecting a second length byte + {0xc0}, // expecting a second length byte + {0x41, 0x80}, // expecting a third length byte + {0xc1, 0x80}, // expecting a third length byte + {0x02}, // expecting 2 bytes of padding + {0x82}, // expecting 2 bytes of data + {0x02, 'X'}, // expecting 1 byte of padding + {0x82, 'X'}, // expecting 1 byte of data + {0x41, 0x00}, // expecting 128 bytes of padding + {0xc1, 0x00}, // expecting 128 bytes of data + {0x41, 0x00, 'X'}, // expecting 127 bytes of padding + {0xc1, 0x00, 'X'}, // expecting 127 bytes of data + {0x41, 0x80, 0x00}, // expecting 32768 bytes of padding + {0xc1, 0x80, 0x00}, // expecting 32768 bytes of data + {0x41, 0x80, 0x00, 'X'}, // expecting 32767 bytes of padding + {0xc1, 0x80, 0x00, 'X'}, // expecting 32767 bytes of data + } { + p, err := ReadData(bytes.NewReader(test)) + if p != nil || err != io.ErrUnexpectedEOF { + t.Fatalf("<%x> got (<%x>, %v), expected (%v, %v)", test, p, err, nil, io.ErrUnexpectedEOF) + } + } +} + +// Test that length encodings that are longer than they could be are still +// interpreted. +func TestNonMinimalLengthEncoding(t *testing.T) { + for _, test := range []struct { + enc []byte + expected []byte + }{ + {[]byte{0x81, 'X'}, []byte("X")}, + {[]byte{0xc0, 0x01, 'X'}, []byte("X")}, + {[]byte{0xc0, 0x80, 0x01, 'X'}, []byte("X")}, + } { + p, err := ReadData(bytes.NewReader(test.enc)) + if err != nil { + t.Fatalf("<%x> got error %v, expected %v", test.enc, err, nil) + } + if !bytes.Equal(p, test.expected) { + t.Fatalf("<%x> got <%x>, expected <%x>", test.enc, p, test.expected) + } + } +} + +// Test that ReadData only reads up to 3 bytes of length prefix. +func TestReadLimits(t *testing.T) { + // Test the maximum length that's possible with 3 bytes of length + // prefix. 
+ maxLength := (0x3f << 14) | (0x7f << 7) | 0x7f + data := bytes.Repeat([]byte{'X'}, maxLength) + prefix := []byte{0xff, 0xff, 0x7f} // encodes 0xfffff + p, err := ReadData(bytes.NewReader(append(prefix, data...))) + if err != nil { + t.Fatalf("got error %v, expected %v", err, nil) + } + if !bytes.Equal(p, data) { + t.Fatalf("got %d bytes unequal to %d bytes", len(p), len(data)) + } + // Test a 4-byte prefix. + prefix = []byte{0xc0, 0xc0, 0x80, 0x80} // encodes 0x100000 + data = bytes.Repeat([]byte{'X'}, maxLength+1) + p, err = ReadData(bytes.NewReader(append(prefix, data...))) + if p != nil || err != ErrTooLong { + t.Fatalf("got (<%x>, %v), expected (%v, %v)", p, err, nil, ErrTooLong) + } + // Test that 4 bytes don't work, even when they encode an integer that + // would fit in 3 bytes. + prefix = []byte{0xc0, 0x80, 0x80, 0x80} // encodes 0x0 + data = []byte{} + p, err = ReadData(bytes.NewReader(append(prefix, data...))) + if p != nil || err != ErrTooLong { + t.Fatalf("got (<%x>, %v), expected (%v, %v)", p, err, nil, ErrTooLong) + } + + // Do the same tests with padding lengths. 
+ data = []byte("hello") + prefix = []byte{0x7f, 0xff, 0x7f} // encodes 0xfffff + padding := bytes.Repeat([]byte{'X'}, maxLength) + enc := bytes.NewBuffer(append(prefix, padding...)) + mustWriteData(enc, data) + p, err = ReadData(enc) + if err != nil { + t.Fatalf("got error %v, expected %v", err, nil) + } + if !bytes.Equal(p, data) { + t.Fatalf("got <%x>, expected <%x>", p, data) + } + prefix = []byte{0x40, 0xc0, 0x80, 0x80} // encodes 0x100000 + padding = bytes.Repeat([]byte{'X'}, maxLength+1) + enc = bytes.NewBuffer(append(prefix, padding...)) + mustWriteData(enc, data) + p, err = ReadData(enc) + if p != nil || err != ErrTooLong { + t.Fatalf("got (<%x>, %v), expected (%v, %v)", p, err, nil, ErrTooLong) + } + prefix = []byte{0x40, 0x80, 0x80, 0x80} // encodes 0x0 + padding = []byte{} + enc = bytes.NewBuffer(append(prefix, padding...)) + mustWriteData(enc, data) + p, err = ReadData(enc) + if p != nil || err != ErrTooLong { + t.Fatalf("got (<%x>, %v), expected (%v, %v)", p, err, nil, ErrTooLong) + } +} + +// Test that WriteData and WritePadding only accept lengths that can be encoded +// in up to 3 bytes of length prefix. +func TestWriteLimits(t *testing.T) { + maxLength := (0x3f << 14) | (0x7f << 7) | 0x7f + var enc bytes.Buffer + n, err := WriteData(&enc, bytes.Repeat([]byte{'X'}, maxLength)) + if n != maxLength+3 || err != nil { + t.Fatalf("got (%d, %v), expected (%d, %v)", n, err, maxLength+3, nil) // expected count includes the 3-byte prefix + } + enc.Reset() + n, err = WriteData(&enc, bytes.Repeat([]byte{'X'}, maxLength+1)) + if n != 0 || err != ErrTooLong { + t.Fatalf("got (%d, %v), expected (%d, %v)", n, err, 0, ErrTooLong) + } + + // Padding gets an extra 3 bytes because the prefix is counted as part + // of the length. + enc.Reset() + n, err = WritePadding(&enc, maxLength+3) + if n != maxLength+3 || err != nil { + t.Fatalf("got (%d, %v), expected (%d, %v)", n, err, maxLength+3, nil) + } + // Writing a too-long padding is okay because WritePadding will break it + // into smaller chunks. 
+ enc.Reset() + n, err = WritePadding(&enc, maxLength+4) + if n != maxLength+4 || err != nil { + t.Fatalf("got (%d, %v), expected (%d, %v)", n, err, maxLength+4, nil) + } +} + +// Test that WritePadding panics when given a negative length. +func TestNegativeLength(t *testing.T) { + for _, n := range []int{-1, ^0} { + var enc bytes.Buffer + panicked, nn, err := testNegativeLengthSub(t, &enc, n) + if !panicked { + t.Fatalf("WritePadding(%d) returned (%d, %v) instead of panicking", n, nn, err) + } + } +} + +// Calls WritePadding(w, n) and augments the return value with a flag indicating +// whether the call panicked. +func testNegativeLengthSub(t *testing.T, w io.Writer, n int) (panicked bool, nn int, err error) { + defer func() { + if r := recover(); r != nil { + panicked = true + } + }() + t.Helper() + nn, err = WritePadding(w, n) + return false, nn, err // report the count WritePadding returned, not the requested length +} + +// Test that MaxDataForSize panics when given a 0 length. +func TestMaxDataForSizeZero(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Fatal("didn't panic") + } + }() + MaxDataForSize(0) +} + +// Test thresholds of available sizes for MaxDataForSize. 
+func TestMaxDataForSize(t *testing.T) { + for _, test := range []struct { + size int + expected int + }{ + {0x01, 0x00}, + {0x02, 0x01}, + {0x3f, 0x3e}, + {0x40, 0x3e}, + {0x41, 0x3f}, + {0x1fff, 0x1ffd}, + {0x2000, 0x1ffd}, + {0x2001, 0x1ffe}, + {0xfffff, 0xffffc}, + {0x100000, 0xffffc}, + {0x100001, 0xffffc}, + {0x7fffffff, 0xffffc}, + } { + max := MaxDataForSize(test.size) + if max != test.expected { + t.Fatalf("size %d, got %d, expected %d", test.size, max, test.expected) + } + } +} diff --git a/common/turbotunnel/clientid.go b/common/turbotunnel/clientid.go new file mode 100644 index 0000000..17257e1 --- /dev/null +++ b/common/turbotunnel/clientid.go @@ -0,0 +1,28 @@ +package turbotunnel + +import ( + "crypto/rand" + "encoding/hex" +) + +// ClientID is an abstract identifier that binds together all the communications +// belonging to a single client session, even though those communications may +// arrive from multiple IP addresses or over multiple lower-level connections. +// It plays the same role that an (IP address, port number) tuple plays in a +// net.UDPConn: it's the return address pertaining to a long-lived abstract +// client session. The client attaches its ClientID to each of its +// communications, enabling the server to disambiguate requests among its many +// clients. ClientID implements the net.Addr interface. 
+type ClientID [8]byte + +func NewClientID() ClientID { + var id ClientID + _, err := rand.Read(id[:]) + if err != nil { + panic(err) + } + return id +} + +func (id ClientID) Network() string { return "clientid" } +func (id ClientID) String() string { return hex.EncodeToString(id[:]) } diff --git a/common/turbotunnel/clientmap.go b/common/turbotunnel/clientmap.go new file mode 100644 index 0000000..fa12915 --- /dev/null +++ b/common/turbotunnel/clientmap.go @@ -0,0 +1,144 @@ +package turbotunnel + +import ( + "container/heap" + "net" + "sync" + "time" +) + +// clientRecord is a record of a recently seen client, with the time it was last +// seen and a send queue. +type clientRecord struct { + Addr net.Addr + LastSeen time.Time + SendQueue chan []byte +} + +// ClientMap manages a mapping of live clients (keyed by address, which will be +// a ClientID) to their respective send queues. ClientMap's functions are safe +// to call from multiple goroutines. +type ClientMap struct { + // We use an inner structure to avoid exposing public heap.Interface + // functions to users of clientMap. + inner clientMapInner + // Synchronizes access to inner. + lock sync.Mutex +} + +// NewClientMap creates a ClientMap that expires clients after a timeout. +// +// The timeout does not have to be kept in sync with QUIC's internal idle +// timeout. If a client is removed from the client map while the QUIC session is +// still live, the worst that can happen is a loss of whatever packets were in +// the send queue at the time. If QUIC later decides to send more packets to the +// same client, we'll instantiate a new send queue, and if the client ever +// connects again with the proper client ID, we'll deliver them. 
+func NewClientMap(timeout time.Duration) *ClientMap { + m := &ClientMap{ + inner: clientMapInner{ + byAge: make([]*clientRecord, 0), + byAddr: make(map[net.Addr]int), + }, + } + go func() { + for { + time.Sleep(timeout / 2) + now := time.Now() + m.lock.Lock() + m.inner.removeExpired(now, timeout) + m.lock.Unlock() + } + }() + return m +} + +// SendQueue returns the send queue corresponding to addr, creating it if +// necessary. +func (m *ClientMap) SendQueue(addr net.Addr) chan []byte { + m.lock.Lock() + defer m.lock.Unlock() + return m.inner.SendQueue(addr, time.Now()) +} + +// clientMapInner is the inner type of ClientMap, implementing heap.Interface. +// byAge is the backing store, a heap ordered by LastSeen time, to facilitate +// expiring old client records. byAddr is a map from addresses (i.e., ClientIDs) +// to heap indices, to allow looking up by address. Unlike ClientMap, +// clientMapInner requires external synchronization. +type clientMapInner struct { + byAge []*clientRecord + byAddr map[net.Addr]int +} + +// removeExpired removes all client records whose LastSeen timestamp is at +// least timeout in the past. +func (inner *clientMapInner) removeExpired(now time.Time, timeout time.Duration) { + for len(inner.byAge) > 0 && now.Sub(inner.byAge[0].LastSeen) >= timeout { + heap.Pop(inner) + } +} + +// SendQueue finds the existing client record corresponding to addr, or creates +// a new one if none exists yet. It updates the client record's LastSeen time +// and returns its SendQueue. +func (inner *clientMapInner) SendQueue(addr net.Addr, now time.Time) chan []byte { + var record *clientRecord + i, ok := inner.byAddr[addr] + if ok { + // Found one, update its LastSeen. + record = inner.byAge[i] + record.LastSeen = now + heap.Fix(inner, i) + } else { + // Not found, create a new one. 
+ record = &clientRecord{ + Addr: addr, + LastSeen: now, + SendQueue: make(chan []byte, queueSize), + } + heap.Push(inner, record) + } + return record.SendQueue +} + +// heap.Interface for clientMapInner. + +func (inner *clientMapInner) Len() int { + if len(inner.byAge) != len(inner.byAddr) { + panic("inconsistent clientMap") + } + return len(inner.byAge) +} + +func (inner *clientMapInner) Less(i, j int) bool { + return inner.byAge[i].LastSeen.Before(inner.byAge[j].LastSeen) +} + +func (inner *clientMapInner) Swap(i, j int) { + inner.byAge[i], inner.byAge[j] = inner.byAge[j], inner.byAge[i] + inner.byAddr[inner.byAge[i].Addr] = i + inner.byAddr[inner.byAge[j].Addr] = j +} + +func (inner *clientMapInner) Push(x interface{}) { + record := x.(*clientRecord) + if _, ok := inner.byAddr[record.Addr]; ok { + panic("duplicate address in clientMap") + } + // Insert into byAddr map. + inner.byAddr[record.Addr] = len(inner.byAge) + // Insert into byAge slice. + inner.byAge = append(inner.byAge, record) +} + +func (inner *clientMapInner) Pop() interface{} { + n := len(inner.byAddr) + // Remove from byAge slice. + record := inner.byAge[n-1] + inner.byAge[n-1] = nil + inner.byAge = inner.byAge[:n-1] + // Remove from byAddr map. + delete(inner.byAddr, record.Addr) + return record +} diff --git a/common/turbotunnel/consts.go b/common/turbotunnel/consts.go new file mode 100644 index 0000000..4699d1d --- /dev/null +++ b/common/turbotunnel/consts.go @@ -0,0 +1,13 @@ +// Package turbotunnel provides support for overlaying a virtual net.PacketConn +// on some other network carrier. +// +// https://github.com/net4people/bbs/issues/9 +package turbotunnel + +import "errors" + +// The size of receive and send queues. 
+const queueSize = 32 + +var errClosedPacketConn = errors.New("operation on closed connection") +var errNotImplemented = errors.New("not implemented") diff --git a/common/turbotunnel/queuepacketconn.go b/common/turbotunnel/queuepacketconn.go new file mode 100644 index 0000000..14a9833 --- /dev/null +++ b/common/turbotunnel/queuepacketconn.go @@ -0,0 +1,137 @@ +package turbotunnel + +import ( + "net" + "sync" + "sync/atomic" + "time" +) + +// taggedPacket is a combination of a []byte and a net.Addr, encapsulating the +// return type of PacketConn.ReadFrom. +type taggedPacket struct { + P []byte + Addr net.Addr +} + +// QueuePacketConn implements net.PacketConn by storing queues of packets. There +// is one incoming queue (where packets are additionally tagged by the source +// address of the client that sent them). There are many outgoing queues, one +// for each client address that has been recently seen. The QueueIncoming method +// inserts a packet into the incoming queue, to eventually be returned by +// ReadFrom. WriteTo inserts a packet into an address-specific outgoing queue, +// which can later be accessed through the OutgoingQueue method. +type QueuePacketConn struct { + clients *ClientMap + localAddr net.Addr + recvQueue chan taggedPacket + closeOnce sync.Once + closed chan struct{} + // What error to return when the QueuePacketConn is closed. + err atomic.Value +} + +// NewQueuePacketConn makes a new QueuePacketConn, set to track recent clients +// for at least a duration of timeout. +func NewQueuePacketConn(localAddr net.Addr, timeout time.Duration) *QueuePacketConn { + return &QueuePacketConn{ + clients: NewClientMap(timeout), + localAddr: localAddr, + recvQueue: make(chan taggedPacket, queueSize), + closed: make(chan struct{}), + } +} + +// QueueIncoming queues an incoming packet and its source address, to be +// returned in a future call to ReadFrom. 
+func (c *QueuePacketConn) QueueIncoming(p []byte, addr net.Addr) { + select { + case <-c.closed: + // If we're closed, silently drop it. + return + default: + } + // Copy the slice so that the caller may reuse it. + buf := make([]byte, len(p)) + copy(buf, p) + select { + case c.recvQueue <- taggedPacket{buf, addr}: + default: + // Drop the incoming packet if the receive queue is full. + } +} + +// OutgoingQueue returns the queue of outgoing packets corresponding to addr, +// creating it if necessary. The contents of the queue will be packets that are +// written to the address in question using WriteTo. +func (c *QueuePacketConn) OutgoingQueue(addr net.Addr) <-chan []byte { + return c.clients.SendQueue(addr) +} + +// ReadFrom returns a packet and address previously stored by QueueIncoming. +func (c *QueuePacketConn) ReadFrom(p []byte) (int, net.Addr, error) { + select { + case <-c.closed: + return 0, nil, &net.OpError{Op: "read", Net: c.LocalAddr().Network(), Addr: c.LocalAddr(), Err: c.err.Load().(error)} + default: + } + select { + case <-c.closed: + return 0, nil, &net.OpError{Op: "read", Net: c.LocalAddr().Network(), Addr: c.LocalAddr(), Err: c.err.Load().(error)} + case packet := <-c.recvQueue: + return copy(p, packet.P), packet.Addr, nil + } +} + +// WriteTo queues an outgoing packet for the given address. The queue can later +// be retrieved using the OutgoingQueue method. +func (c *QueuePacketConn) WriteTo(p []byte, addr net.Addr) (int, error) { + select { + case <-c.closed: + return 0, &net.OpError{Op: "write", Net: c.LocalAddr().Network(), Addr: c.LocalAddr(), Err: c.err.Load().(error)} + default: + } + // Copy the slice so that the caller may reuse it. + buf := make([]byte, len(p)) + copy(buf, p) + select { + case c.clients.SendQueue(addr) <- buf: + return len(buf), nil + default: + // Drop the outgoing packet if the send queue is full. 
+ return len(buf), nil + } +} + +// closeWithError unblocks pending operations and makes future operations fail +// with the given error. If err is nil, it becomes errClosedPacketConn. +func (c *QueuePacketConn) closeWithError(err error) error { + var newlyClosed bool + c.closeOnce.Do(func() { + newlyClosed = true + // Store the error to be returned by future PacketConn + // operations. + if err == nil { + err = errClosedPacketConn + } + c.err.Store(err) + close(c.closed) + }) + if !newlyClosed { + return &net.OpError{Op: "close", Net: c.LocalAddr().Network(), Addr: c.LocalAddr(), Err: c.err.Load().(error)} + } + return nil +} + +// Close unblocks pending operations and makes future operations fail with a +// "closed connection" error. +func (c *QueuePacketConn) Close() error { + return c.closeWithError(nil) +} + +// LocalAddr returns the localAddr value that was passed to NewQueuePacketConn. +func (c *QueuePacketConn) LocalAddr() net.Addr { return c.localAddr } + +func (c *QueuePacketConn) SetDeadline(t time.Time) error { return errNotImplemented } +func (c *QueuePacketConn) SetReadDeadline(t time.Time) error { return errNotImplemented } +func (c *QueuePacketConn) SetWriteDeadline(t time.Time) error { return errNotImplemented } diff --git a/common/turbotunnel/redialpacketconn.go b/common/turbotunnel/redialpacketconn.go new file mode 100644 index 0000000..cf6a8c9 --- /dev/null +++ b/common/turbotunnel/redialpacketconn.go @@ -0,0 +1,204 @@ +package turbotunnel + +import ( + "context" + "errors" + "net" + "sync" + "sync/atomic" + "time" +) + +// RedialPacketConn implements a long-lived net.PacketConn atop a sequence of +// other, transient net.PacketConns. RedialPacketConn creates a new +// net.PacketConn by calling a provided dialContext function. Whenever the +// net.PacketConn experiences a ReadFrom or WriteTo error, RedialPacketConn +// calls the dialContext function again and starts sending and receiving packets +// on the new net.PacketConn. 
RedialPacketConn's own ReadFrom and WriteTo +// methods return an error only after Close or a failed dialContext call. +// +// RedialPacketConn uses static local and remote addresses that are independent +// of those of any dialed net.PacketConn. +type RedialPacketConn struct { + localAddr net.Addr + remoteAddr net.Addr + dialContext func(context.Context) (net.PacketConn, error) + recvQueue chan []byte + sendQueue chan []byte + closed chan struct{} + closeOnce sync.Once + // The first dial error, which causes the RedialPacketConn to be + // closed and is returned from future read/write operations. Compare to + // the rerr and werr in io.Pipe. + err atomic.Value +} + +// NewRedialPacketConn makes a new RedialPacketConn, with the given static local +// and remote addresses, and dialContext function. +func NewRedialPacketConn( + localAddr, remoteAddr net.Addr, + dialContext func(context.Context) (net.PacketConn, error), +) *RedialPacketConn { + c := &RedialPacketConn{ + localAddr: localAddr, + remoteAddr: remoteAddr, + dialContext: dialContext, + recvQueue: make(chan []byte, queueSize), + sendQueue: make(chan []byte, queueSize), + closed: make(chan struct{}), + err: atomic.Value{}, + } + go c.dialLoop() + return c +} + +// dialLoop repeatedly calls c.dialContext and passes the resulting +// net.PacketConn to c.exchange. It returns only when c is closed or dialContext +// returns an error. +func (c *RedialPacketConn) dialLoop() { + ctx, cancel := context.WithCancel(context.Background()) + for { + select { + case <-c.closed: + cancel() + return + default: + } + conn, err := c.dialContext(ctx) + if err != nil { + c.closeWithError(err) + cancel() + return + } + c.exchange(conn) + conn.Close() + } +} + +// exchange calls ReadFrom on the given net.PacketConn and places the resulting +// packets in the receive queue, and takes packets from the send queue and calls +// WriteTo on them, making the current net.PacketConn active. 
+func (c *RedialPacketConn) exchange(conn net.PacketConn) { + readErrCh := make(chan error, 1) // capacity 1: each goroutine sends at most once, so a send made after exchange has returned cannot block and leak the goroutine + writeErrCh := make(chan error, 1) + + go func() { + defer close(readErrCh) + for { + select { + case <-c.closed: + return + case <-writeErrCh: + return + default: + } + + var buf [1500]byte + n, _, err := conn.ReadFrom(buf[:]) + if err != nil { + readErrCh <- err + return + } + p := make([]byte, n) + copy(p, buf[:]) + select { + case c.recvQueue <- p: + default: // OK to drop packets. + } + } + }() + + go func() { + defer close(writeErrCh) + for { + select { + case <-c.closed: + return + case <-readErrCh: + return + case p := <-c.sendQueue: + _, err := conn.WriteTo(p, c.remoteAddr) + if err != nil { + writeErrCh <- err + return + } + } + } + }() + + select { + case <-readErrCh: + case <-writeErrCh: + } +} + +// ReadFrom reads a packet from the currently active net.PacketConn. The +// packet's original remote address is replaced with the RedialPacketConn's own +// remote address. +func (c *RedialPacketConn) ReadFrom(p []byte) (int, net.Addr, error) { + select { + case <-c.closed: + return 0, nil, &net.OpError{Op: "read", Net: c.LocalAddr().Network(), Source: c.LocalAddr(), Addr: c.remoteAddr, Err: c.err.Load().(error)} + default: + } + select { + case <-c.closed: + return 0, nil, &net.OpError{Op: "read", Net: c.LocalAddr().Network(), Source: c.LocalAddr(), Addr: c.remoteAddr, Err: c.err.Load().(error)} + case buf := <-c.recvQueue: + return copy(p, buf), c.remoteAddr, nil + } +} + +// WriteTo writes a packet to the currently active net.PacketConn. The addr +// argument is ignored and instead replaced with the RedialPacketConn's own +// remote address. +func (c *RedialPacketConn) WriteTo(p []byte, addr net.Addr) (int, error) { + // addr is ignored. 
+ select { + case <-c.closed: + return 0, &net.OpError{Op: "write", Net: c.LocalAddr().Network(), Source: c.LocalAddr(), Addr: c.remoteAddr, Err: c.err.Load().(error)} + default: + } + buf := make([]byte, len(p)) + copy(buf, p) + select { + case c.sendQueue <- buf: + return len(buf), nil + default: + // Drop the outgoing packet if the send queue is full. + return len(buf), nil + } +} + +// closeWithError unblocks pending operations and makes future operations fail +// with the given error. If err is nil, it becomes errClosedPacketConn. +func (c *RedialPacketConn) closeWithError(err error) error { + var once bool + c.closeOnce.Do(func() { + // Store the error to be returned by future read/write + // operations. + if err == nil { + err = errors.New("operation on closed connection") + } + c.err.Store(err) + close(c.closed) + once = true + }) + if !once { + return &net.OpError{Op: "close", Net: c.LocalAddr().Network(), Addr: c.LocalAddr(), Err: c.err.Load().(error)} + } + return nil +} + +// Close unblocks pending operations and makes future operations fail with a +// "closed connection" error. +func (c *RedialPacketConn) Close() error { + return c.closeWithError(nil) +} + +// LocalAddr returns the localAddr value that was passed to NewRedialPacketConn. +func (c *RedialPacketConn) LocalAddr() net.Addr { return c.localAddr } + +func (c *RedialPacketConn) SetDeadline(t time.Time) error { return errNotImplemented } +func (c *RedialPacketConn) SetReadDeadline(t time.Time) error { return errNotImplemented } +func (c *RedialPacketConn) SetWriteDeadline(t time.Time) error { return errNotImplemented }