diff --git a/internal/multiplex/blackhole_test.go b/internal/multiplex/blackhole_test.go new file mode 100644 index 0000000..7d34215 --- /dev/null +++ b/internal/multiplex/blackhole_test.go @@ -0,0 +1,41 @@ +package multiplex + +import ( + "bufio" + "io" + "io/ioutil" + "net" + "time" +) + +type blackhole struct { + hole *bufio.Writer + closer chan int +} + +func newBlackHole() *blackhole { + return &blackhole{ + hole: bufio.NewWriter(ioutil.Discard), + closer: make(chan int), + } +} +func (b *blackhole) Read([]byte) (int, error) { + <-b.closer + return 0, io.EOF +} +func (b *blackhole) Write(in []byte) (int, error) { return b.hole.Write(in) } +func (b *blackhole) Close() error { + b.closer <- 1 + return nil +} +func (b *blackhole) LocalAddr() net.Addr { + ret, _ := net.ResolveTCPAddr("tcp", "127.0.0.1") + return ret +} +func (b *blackhole) RemoteAddr() net.Addr { + ret, _ := net.ResolveTCPAddr("tcp", "127.0.0.1") + return ret +} +func (b *blackhole) SetDeadline(t time.Time) error { return nil } +func (b *blackhole) SetReadDeadline(t time.Time) error { return nil } +func (b *blackhole) SetWriteDeadline(t time.Time) error { return nil } diff --git a/internal/multiplex/stream_test.go b/internal/multiplex/stream_test.go index 1592fed..3426df3 100644 --- a/internal/multiplex/stream_test.go +++ b/internal/multiplex/stream_test.go @@ -1,10 +1,8 @@ package multiplex import ( - "bufio" "bytes" "github.com/cbeuw/Cloak/internal/util" - "io/ioutil" "math/rand" "net" "testing" @@ -25,29 +23,6 @@ func setupSesh(unordered bool) *Session { return MakeSession(0, seshConfig) } -type blackhole struct { - hole *bufio.Writer -} - -func newBlackHole() *blackhole { return &blackhole{hole: bufio.NewWriter(ioutil.Discard)} } -func (b *blackhole) Read([]byte) (int, error) { - time.Sleep(1 * time.Hour) - return 0, nil -} -func (b *blackhole) Write(in []byte) (int, error) { return b.hole.Write(in) } -func (b *blackhole) Close() error { return nil } -func (b *blackhole) LocalAddr() net.Addr { - ret, _ := net.ResolveTCPAddr("tcp", "127.0.0.1") - return ret -} -func (b *blackhole) RemoteAddr() net.Addr { - ret, _ := net.ResolveTCPAddr("tcp", "127.0.0.1") - return ret -} -func (b *blackhole) SetDeadline(t time.Time) error { return nil } -func (b *blackhole) SetReadDeadline(t time.Time) error { return nil } -func (b *blackhole) SetWriteDeadline(t time.Time) error { return nil } - func BenchmarkStream_Write_Ordered(b *testing.B) { const PAYLOAD_LEN = 1000 hole := newBlackHole() diff --git a/internal/multiplex/switchboard.go b/internal/multiplex/switchboard.go index b711a3c..837b907 100644 --- a/internal/multiplex/switchboard.go +++ b/internal/multiplex/switchboard.go @@ -72,15 +72,22 @@ func (sb *switchboard) send(data []byte, connId *uint32) (n int, err error) { sb.connsM.RLock() defer sb.connsM.RUnlock() if sb.strategy == UNIFORM_SPREAD { - randConnId := rand.Intn(len(sb.conns)) - conn, ok := sb.conns[uint32(randConnId)] - if !ok { + if atomic.LoadUint32(&sb.broken) == 1 || len(sb.conns) == 0 { return 0, errBrokenSwitchboard - } else { - n, err = conn.Write(data) - sb.AddTx(int64(n)) - return } + + r := rand.Intn(len(sb.conns)) + var c int + for newConnId := range sb.conns { + if r == c { + conn, _ := sb.conns[newConnId] + n, err = conn.Write(data) + sb.AddTx(int64(n)) + return + } + c++ + } + return 0, errBrokenSwitchboard } else { var conn net.Conn conn, ok := sb.conns[*connId] diff --git a/internal/multiplex/switchboard_test.go b/internal/multiplex/switchboard_test.go index 49bc4cd..a1e5926 100644 --- a/internal/multiplex/switchboard_test.go +++ b/internal/multiplex/switchboard_test.go @@ -6,15 +6,79 @@ import ( "testing" ) +func TestSwitchboard_Send(t *testing.T) { + doTest := func(seshConfig *SessionConfig) { + sesh := MakeSession(0, seshConfig) + hole0 := newBlackHole() + sesh.sb.addConn(hole0) + connId, err := sesh.sb.assignRandomConn() + if err != nil { + t.Error("failed to get a random conn", err) + return + } + data := make([]byte, 1000) + rand.Read(data) + _, err = sesh.sb.send(data, &connId) + if err != nil { + t.Error(err) + return + } + + hole1 := newBlackHole() + sesh.sb.addConn(hole1) + connId, err = sesh.sb.assignRandomConn() + if err != nil { + t.Error("failed to get a random conn", err) + return + } + _, err = sesh.sb.send(data, &connId) + if err != nil { + t.Error(err) + return + } + + hole0.Close() + + connId, err = sesh.sb.assignRandomConn() + if err != nil { + t.Error("failed to get a random conn", err) + return + } + _, err = sesh.sb.send(data, &connId) + if err != nil { + t.Error(err) + return + } + } + + t.Run("Ordered", func(t *testing.T) { + seshConfig := &SessionConfig{ + Obfuscator: nil, + Valve: nil, + UnitRead: util.ReadTLS, + Unordered: false, + } + doTest(seshConfig) + }) + t.Run("Unordered", func(t *testing.T) { + seshConfig := &SessionConfig{ + Obfuscator: nil, + Valve: nil, + UnitRead: util.ReadTLS, + Unordered: true, + } + doTest(seshConfig) + }) +} + func BenchmarkSwitchboard_Send(b *testing.B) { + hole := newBlackHole() seshConfig := &SessionConfig{ Obfuscator: nil, Valve: nil, UnitRead: util.ReadTLS, } sesh := MakeSession(0, seshConfig) - - hole := newBlackHole() sesh.sb.addConn(hole) connId, err := sesh.sb.assignRandomConn() if err != nil {