diff --git a/internal/multiplex/frameSorter.go b/internal/multiplex/frameSorter.go index 0679673..236f3fe 100644 --- a/internal/multiplex/frameSorter.go +++ b/internal/multiplex/frameSorter.go @@ -86,6 +86,7 @@ func (fs *frameSorter) Close() error { // recvNewFrame is a forever running loop which receives frames unordered, // cache and order them and send them into sortedBufCh func (fs *frameSorter) recvNewFrame() { + // TODO: add timeout defer log.Tracef("a recvNewFrame has returned gracefully") for { f := <-fs.newFrameCh diff --git a/internal/multiplex/switchboard.go b/internal/multiplex/switchboard.go index bec8c5f..8e02fb5 100644 --- a/internal/multiplex/switchboard.go +++ b/internal/multiplex/switchboard.go @@ -67,7 +67,8 @@ func (sb *switchboard) removeConn(connId uint32) { } // a pointer to connId is passed here so that the switchboard can reassign it -func (sb *switchboard) send(data []byte, connId *uint32) (int, error) { +func (sb *switchboard) send(data []byte, connId *uint32) (n int, err error) { + sb.Valve.rxWait(len(data)) sb.connsM.RLock() defer sb.connsM.RUnlock() if sb.strategy == UNIFORM_SPREAD { @@ -76,13 +77,17 @@ func (sb *switchboard) send(data []byte, connId *uint32) (int, error) { if !ok { return 0, errBrokenSwitchboard } else { - return conn.Write(data) + n, err = conn.Write(data) + sb.AddTx(int64(n)) + return } } else { var conn net.Conn conn, ok := sb.conns[*connId] if ok { - return conn.Write(data) + n, err = conn.Write(data) + sb.AddTx(int64(n)) + return } else { // do not call assignRandomConn() here. // we'll have to do connsM.RLock() after we get a new connId from assignRandomConn, in order to @@ -99,7 +104,9 @@ func (sb *switchboard) send(data []byte, connId *uint32) (int, error) { if !ok { return 0, errBrokenSwitchboard } else { - return conn.Write(data) + n, err = conn.Write(data) + sb.AddTx(int64(n)) + return } } } diff --git a/internal/multiplex/switchboard_test.go b/internal/multiplex/switchboard_test.go index 756cc83..49bc4cd 100644 --- a/internal/multiplex/switchboard_test.go +++ b/internal/multiplex/switchboard_test.go @@ -1,6 +1,7 @@ package multiplex import ( + "github.com/cbeuw/Cloak/internal/util" "math/rand" "testing" ) @@ -9,17 +10,13 @@ func BenchmarkSwitchboard_Send(b *testing.B) { seshConfig := &SessionConfig{ Obfuscator: nil, Valve: nil, - UnitRead: nil, + UnitRead: util.ReadTLS, } sesh := MakeSession(0, seshConfig) - sbConfig := &switchboardConfig{ - Valve: UNLIMITED_VALVE, - strategy: FIXED_CONN_MAPPING, - } - sb := makeSwitchboard(sesh, sbConfig) + hole := newBlackHole() - sb.addConn(hole) - connId, err := sb.assignRandomConn() + sesh.sb.addConn(hole) + connId, err := sesh.sb.assignRandomConn() if err != nil { b.Error("failed to get a random conn", err) return @@ -28,7 +25,7 @@ func BenchmarkSwitchboard_Send(b *testing.B) { rand.Read(data) b.ResetTimer() for i := 0; i < b.N; i++ { - n, err := sb.send(data, &connId) + n, err := sesh.sb.send(data, &connId) if err != nil { b.Error(err) return @@ -36,3 +33,54 @@ func BenchmarkSwitchboard_Send(b *testing.B) { b.SetBytes(int64(n)) } } + +func TestSwitchboard_TxCredit(t *testing.T) { + seshConfig := &SessionConfig{ + Obfuscator: nil, + Valve: MakeValve(1<<20, 1<<20), + UnitRead: util.ReadTLS, + } + sesh := MakeSession(0, seshConfig) + hole := newBlackHole() + sesh.sb.addConn(hole) + connId, err := sesh.sb.assignRandomConn() + if err != nil { + t.Error("failed to get a random conn", err) + return + } + data := make([]byte, 1000) + rand.Read(data) + + t.Run("FIXED CONN MAPPING", func(t *testing.T) { + *sesh.sb.Valve.(*LimitedValve).tx = 0 + sesh.sb.strategy = FIXED_CONN_MAPPING + n, err := sesh.sb.send(data[:10], &connId) + if err != nil { + t.Error(err) + return + } + if n != 10 { + t.Errorf("wanted to send %v, got %v", 10, n) + return + } + if *sesh.sb.Valve.(*LimitedValve).tx != 10 { + t.Error("tx credit didn't increase by 10") + } + }) + t.Run("UNIFORM", func(t *testing.T) { + *sesh.sb.Valve.(*LimitedValve).tx = 0 + sesh.sb.strategy = UNIFORM_SPREAD + n, err := sesh.sb.send(data[:10], &connId) + if err != nil { + t.Error(err) + return + } + if n != 10 { + t.Errorf("wanted to send %v, got %v", 10, n) + return + } + if *sesh.sb.Valve.(*LimitedValve).tx != 10 { + t.Error("tx credit didn't increase by 10") + } + }) +}