From 3699bbda1633b17eb5fae9ced6158df42fe1384b Mon Sep 17 00:00:00 2001
From: David Fifield <david@bamsoftware.com>
Date: Sat, 10 Jun 2017 17:26:13 -0700
Subject: [PATCH] Queue writes through an independent write scheduler.

Send padding packets where there is no application data waiting.

The one place this doesn't work is in the client, after sending the
client handshake and before receiving the Y' and AUTH portions of the
server reply. During this time the client doesn't yet know the session
key and so cannot send anything. This version waits even longer, until
the entire server handshake has been received (including P_S, M_S, and
MAC_S).

This is set up to do sends in a fixed size of 500 bytes at a fixed rate
of 10 sends per second. As a hack, this rounds the client handshake size
to a multiple of 500 bytes, so that it doesn't stall waiting for the
final full chunk to be available to send.
---
 transports/obfs4/handshake_ntor.go |   8 ++-
 transports/obfs4/obfs4.go          | 122 ++++++++++++++++++++++++++++++++-----
 2 files changed, 114 insertions(+), 16 deletions(-)

diff --git a/transports/obfs4/handshake_ntor.go b/transports/obfs4/handshake_ntor.go
index ee1bca8..fb5935a 100644
--- a/transports/obfs4/handshake_ntor.go
+++ b/transports/obfs4/handshake_ntor.go
@@ -127,7 +127,13 @@ func newClientHandshake(nodeID *ntor.NodeID, serverIdentity *ntor.PublicKey, ses
 	hs.keypair = sessionKey
 	hs.nodeID = nodeID
 	hs.serverIdentity = serverIdentity
-	hs.padLen = csrand.IntRange(clientMinPadLength, clientMaxPadLength)
+	padLen := csrand.IntRange(clientMinPadLength, clientMaxPadLength)
+	// Hack: round total handshake size to a multiple of 500 for CBR mode.
+	padLen = padLen - (padLen + clientMinHandshakeLength) % 500
+	if padLen < 0 {
+		padLen += 500
+	}
+	hs.padLen = padLen
 	hs.mac = hmac.New(sha256.New, append(hs.serverIdentity.Bytes()[:], hs.nodeID.Bytes()[:]...))
 
 	return hs
diff --git a/transports/obfs4/obfs4.go b/transports/obfs4/obfs4.go
index 304097e..7210210 100644
--- a/transports/obfs4/obfs4.go
+++ b/transports/obfs4/obfs4.go
@@ -34,6 +34,7 @@ import (
 	"crypto/sha256"
 	"flag"
 	"fmt"
+	"io"
 	"math/rand"
 	"net"
 	"strconv"
@@ -265,7 +266,13 @@ func (sf *obfs4ServerFactory) WrapConn(conn net.Conn) (net.Conn, error) {
 		iatDist = probdist.New(sf.iatSeed, 0, maxIATDelay, biasedDist)
 	}
 
-	c := &obfs4Conn{conn, true, lenDist, iatDist, sf.iatMode, bytes.NewBuffer(nil), bytes.NewBuffer(nil), make([]byte, consumeReadSize), nil, nil}
+	c := &obfs4Conn{conn, true, lenDist, iatDist, sf.iatMode, bytes.NewBuffer(nil), bytes.NewBuffer(nil), make([]byte, consumeReadSize), make(chan []byte), nil, nil, make(chan bool)}
+
+	ws := newWriteScheduler(c)
+	go func() {
+		ws.run()
+		c.Close()
+	}()
 
 	startTime := time.Now()
 
@@ -290,8 +297,11 @@ type obfs4Conn struct {
 	receiveDecodedBuffer *bytes.Buffer
 	readBuffer           []byte
 
-	encoder *framing.Encoder
-	decoder *framing.Decoder
+	writeQueue chan []byte
+
+	encoder         *framing.Encoder
+	decoder         *framing.Decoder
+	codersReadyChan chan bool
 }
 
 func newObfs4ClientConn(conn net.Conn, args *obfs4ClientArgs) (c *obfs4Conn, err error) {
@@ -312,7 +322,13 @@ func newObfs4ClientConn(conn net.Conn, args *obfs4ClientArgs) (c *obfs4Conn, err
 	}
 
 	// Allocate the client structure.
-	c = &obfs4Conn{conn, false, lenDist, iatDist, args.iatMode, bytes.NewBuffer(nil), bytes.NewBuffer(nil), make([]byte, consumeReadSize), nil, nil}
+	c = &obfs4Conn{conn, false, lenDist, iatDist, args.iatMode, bytes.NewBuffer(nil), bytes.NewBuffer(nil), make([]byte, consumeReadSize), make(chan []byte), nil, nil, make(chan bool)}
+
+	ws := newWriteScheduler(c)
+	go func() {
+		ws.run()
+		c.Close()
+	}()
 
 	// Start the handshake timeout.
 	deadline := time.Now().Add(clientHandshakeTimeout)
@@ -343,9 +359,7 @@ func (conn *obfs4Conn) clientHandshake(nodeID *ntor.NodeID, peerIdentityKey *nto
 	if err != nil {
 		return err
 	}
-	if _, err = conn.Conn.Write(blob); err != nil {
-		return err
-	}
+	conn.writeQueue <- blob
 
 	// Consume the server handshake.
 	var hsBuf [maxHandshakeLength]byte
@@ -370,6 +384,8 @@ func (conn *obfs4Conn) clientHandshake(nodeID *ntor.NodeID, peerIdentityKey *nto
 		okm := ntor.Kdf(seed, framing.KeyLength*2)
 		conn.encoder = framing.NewEncoder(okm[:framing.KeyLength])
 		conn.decoder = framing.NewDecoder(okm[framing.KeyLength:])
+		// Signal to the gimmeData function that our encoder is ready.
+		close(conn.codersReadyChan)
 
 		return nil
 	}
@@ -440,9 +456,11 @@ func (conn *obfs4Conn) serverHandshake(sf *obfs4ServerFactory, sessionKey *ntor.
 	if err := conn.makePacket(&frameBuf, packetTypePrngSeed, sf.lenSeed.Bytes()[:], 0); err != nil {
 		return err
 	}
-	if _, err = conn.Conn.Write(frameBuf.Bytes()); err != nil {
-		return err
-	}
+	conn.writeQueue <- frameBuf.Bytes()
+
+	// Signal to the gimmeData function that our encoder is ready.
+	// Need to do this *after* writing the handshake to the writeQueue.
+	close(conn.codersReadyChan)
 
 	return nil
 }
@@ -557,14 +575,13 @@ func (conn *obfs4Conn) Write(b []byte) (n int, err error) {
 			iatDelta := time.Duration(conn.iatDist.Sample() * 100)
 
 			// Write then sleep.
-			_, err = conn.Conn.Write(iatFrame[:iatWrLen])
-			if err != nil {
-				return 0, err
-			}
+			tmpBuf := make([]byte, iatWrLen)
+			copy(tmpBuf, iatFrame[:iatWrLen])
+			conn.writeQueue <- tmpBuf
 			time.Sleep(iatDelta * time.Microsecond)
 		}
 	} else {
-		_, err = conn.Conn.Write(frameBuf.Bytes())
+		conn.writeQueue <- frameBuf.Bytes()
 	}
 
 	return
@@ -637,6 +654,81 @@ func (conn *obfs4Conn) padBurst(burst *bytes.Buffer, toPadTo int) (err error) {
 	return
 }
 
+func (conn *obfs4Conn) gimmeData() ([]byte, error) {
+	// Poll the writeQueue.
+	select {
+	case chunk := <-conn.writeQueue:
+		return chunk, nil
+	default:
+	}
+	// Otherwise wait until the session key is ready. Keep checking the
+	// writeQueue in case something come in meanwhile.
+	select {
+	case chunk := <-conn.writeQueue:
+		return chunk, nil
+	case <-conn.codersReadyChan:
+	}
+	// No actual data to send, but we have a session key, so send a padding
+	// frame. The exact size doesn't matter much.
+	var frameBuf bytes.Buffer
+	err := conn.makePacket(&frameBuf, packetTypePayload, nil, 1024)
+	if err != nil {
+		return nil, err
+	}
+	return frameBuf.Bytes(), nil
+}
+
+type writeScheduler struct {
+	obfs4 *obfs4Conn
+	buf   bytes.Buffer
+}
+
+func newWriteScheduler(obfs4 *obfs4Conn) *writeScheduler {
+	var ws writeScheduler
+	ws.obfs4 = obfs4
+	return &ws
+}
+
+func (ws *writeScheduler) Read(b []byte) (n int, err error) {
+	for {
+		if ws.buf.Len() > 0 {
+			return ws.buf.Read(b)
+		}
+		ws.buf.Reset()
+		data, err := ws.obfs4.gimmeData()
+		if err != nil {
+			return 0, err
+		}
+		ws.buf.Write(data)
+	}
+}
+
+func (ws *writeScheduler) run() error {
+	var buf [500]byte
+	sched := time.Now()
+	for {
+		n, err := io.ReadFull(ws, buf[:])
+		_, err2 := ws.obfs4.Conn.Write(buf[:n])
+		if err2 != nil {
+			return err2
+		}
+		if err == io.EOF || err == io.ErrUnexpectedEOF {
+			break
+		} else if err != nil {
+			return err
+		}
+
+		now := time.Now()
+		sched = sched.Add(100 * time.Millisecond)
+		if sched.Before(now) {
+			sched = now
+		} else {
+			time.Sleep(sched.Sub(now))
+		}
+	}
+	return nil
+}
+
 func init() {
 	flag.BoolVar(&biasedDist, biasCmdArg, false, "Enable obfs4 using ScrambleSuit style table generation")
 }
-- 
2.11.0

