This is an automated email from the git hooks/post-receive script.
shelikhoo pushed a commit to branch main in repository pluggable-transports/snowflake.
commit 3132f680122e27bb9cfb957fbb29c3cbe73935cf
Author: Shelikhoo <xiaokangwang@outlook.com>
AuthorDate: Wed Feb 16 11:11:37 2022 +0000
Add connection expire time for uTLS pendingConn
---
 common/utls/roundtripper.go | 47 +++++++++++++++++++++++++++++++++++++++++----
 1 file changed, 43 insertions(+), 4 deletions(-)
diff --git a/common/utls/roundtripper.go b/common/utls/roundtripper.go index e2fc82b..df31ff4 100644 --- a/common/utls/roundtripper.go +++ b/common/utls/roundtripper.go @@ -7,6 +7,7 @@ import ( "net" "net/http" "sync" + "time"
utls "github.com/refraction-networking/utls" "golang.org/x/net/http2" @@ -19,7 +20,7 @@ func NewUTLSHTTPRoundTripper(clientHelloID utls.ClientHelloID, uTlsConfig *utls. config: uTlsConfig, connectWithH1: map[string]bool{}, backdropTransport: backdropTransport, - pendingConn: map[pendingConnKey]net.Conn{}, + pendingConn: map[pendingConnKey]*unclaimedConnection{}, removeSNI: removeSNI, } rtImpl.init() @@ -38,7 +39,7 @@ type uTLSHTTPRoundTripperImpl struct { backdropTransport http.RoundTripper
accessDialingConnection sync.Mutex - pendingConn map[pendingConnKey]net.Conn + pendingConn map[pendingConnKey]*unclaimedConnection
removeSNI bool } @@ -50,6 +51,7 @@ type pendingConnKey struct {
var errEAGAIN = errors.New("incorrect ALPN negotiated, try again with another ALPN") var errEAGAINTooMany = errors.New("incorrect ALPN negotiated") +var errExpired = errors.New("connection have expired")
func (r *uTLSHTTPRoundTripperImpl) RoundTrip(req *http.Request) (*http.Response, error) { if req.URL.Scheme != "https" { @@ -99,12 +101,15 @@ func getPendingConnectionID(dest string, alpnIsH2 bool) pendingConnKey {
func (r *uTLSHTTPRoundTripperImpl) putConn(addr string, alpnIsH2 bool, conn net.Conn) { connId := getPendingConnectionID(addr, alpnIsH2) - r.pendingConn[connId] = conn + r.pendingConn[connId] = NewUnclaimedConnection(conn, time.Minute) } func (r *uTLSHTTPRoundTripperImpl) getConn(addr string, alpnIsH2 bool) net.Conn { connId := getPendingConnectionID(addr, alpnIsH2) if conn, ok := r.pendingConn[connId]; ok { - return conn + delete(r.pendingConn, connId) + if claimedConnection, err := conn.claimConnection(); err == nil { + return claimedConnection + } } return nil } @@ -189,3 +194,37 @@ func (r *uTLSHTTPRoundTripperImpl) init() { }, } } + +func NewUnclaimedConnection(conn net.Conn, expireTime time.Duration) *unclaimedConnection { + c := &unclaimedConnection{ + Conn: conn, + } + time.AfterFunc(expireTime, c.tick) + return c +} + +type unclaimedConnection struct { + net.Conn + claimed bool + access sync.Mutex +} + +func (c *unclaimedConnection) claimConnection() (net.Conn, error) { + c.access.Lock() + defer c.access.Unlock() + if !c.claimed { + c.claimed = true + return c.Conn, nil + } + return nil, errExpired +} + +func (c *unclaimedConnection) tick() { + c.access.Lock() + defer c.access.Unlock() + if !c.claimed { + c.claimed = true + c.Conn.Close() + c.Conn = nil + } +}