diff --git a/internal/multiplex/streamBuffer_test.go b/internal/multiplex/streamBuffer_test.go index 1df21eb..20ab571 100644 --- a/internal/multiplex/streamBuffer_test.go +++ b/internal/multiplex/streamBuffer_test.go @@ -2,6 +2,7 @@ package multiplex import ( "encoding/binary" + "io" "time" //"log" @@ -72,3 +73,23 @@ func TestRecvNewFrame(t *testing.T) { test(outOfOrder2, t) }) } + +func TestStreamBuffer_RecvThenClose(t *testing.T) { + const testDataLen = 128 + sb := NewStreamBuffer() + testData := make([]byte, testDataLen) + testFrame := Frame{ + StreamID: 0, + Seq: 0, + Closing: 0, + Payload: testData, + } + sb.Write(testFrame) + sb.Close() + + readBuf := make([]byte, testDataLen) + _, err := io.ReadFull(sb, readBuf) + if err != nil { + t.Error(err) + } +} diff --git a/internal/multiplex/stream_test.go b/internal/multiplex/stream_test.go index a6f10d3..9e53197 100644 --- a/internal/multiplex/stream_test.go +++ b/internal/multiplex/stream_test.go @@ -2,6 +2,7 @@ package multiplex import ( "bytes" + "github.com/cbeuw/Cloak/internal/common" "io" "math/rand" "net" @@ -11,10 +12,10 @@ import ( "github.com/cbeuw/connutil" ) -func setupSesh(unordered bool) *Session { - var sessionKey [32]byte - rand.Read(sessionKey[:]) - obfuscator, _ := MakeObfuscator(0x00, sessionKey) +const payloadLen = 1000 + +func setupSesh(unordered bool, key [32]byte) *Session { + obfuscator, _ := MakeObfuscator(0x00, key) seshConfig := SessionConfig{ Obfuscator: obfuscator, @@ -25,11 +26,12 @@ func setupSesh(unordered bool) *Session { } func BenchmarkStream_Write_Ordered(b *testing.B) { - const PAYLOAD_LEN = 1000 hole := connutil.Discard() - sesh := setupSesh(false) + var sessionKey [32]byte + rand.Read(sessionKey[:]) + sesh := setupSesh(false, sessionKey) sesh.AddConnection(hole) - testData := make([]byte, PAYLOAD_LEN) + testData := make([]byte, payloadLen) rand.Read(testData) stream, _ := sesh.OpenStream() @@ -42,14 +44,15 @@ func BenchmarkStream_Write_Ordered(b *testing.B) { "got", err, ) } - b.SetBytes(PAYLOAD_LEN) + b.SetBytes(payloadLen) } } func BenchmarkStream_Read_Ordered(b *testing.B) { - sesh := setupSesh(false) - const PAYLOAD_LEN = 1000 - testPayload := make([]byte, PAYLOAD_LEN) + var sessionKey [32]byte + rand.Read(sessionKey[:]) + sesh := setupSesh(false, sessionKey) + testPayload := make([]byte, payloadLen) rand.Read(testPayload) f := &Frame{ @@ -84,7 +87,7 @@ func BenchmarkStream_Read_Ordered(b *testing.B) { //time.Sleep(5*time.Second) // wait for buffer to fill up - readBuf := make([]byte, PAYLOAD_LEN) + readBuf := make([]byte, payloadLen) b.ResetTimer() for j := 0; j < b.N; j++ { n, err := stream.Read(readBuf) @@ -100,11 +103,12 @@ func BenchmarkStream_Read_Ordered(b *testing.B) { } func TestStream_Write(t *testing.T) { - const PAYLOAD_LEN = 1000 hole := connutil.Discard() - sesh := setupSesh(false) + var sessionKey [32]byte + rand.Read(sessionKey[:]) + sesh := setupSesh(false, sessionKey) sesh.AddConnection(hole) - testData := make([]byte, PAYLOAD_LEN) + testData := make([]byte, payloadLen) rand.Read(testData) stream, _ := sesh.OpenStream() @@ -117,8 +121,57 @@ func TestStream_Write(t *testing.T) { } } +func TestStream_WriteSync(t *testing.T) { + // Close calls made after write MUST have a higher seq + var sessionKey [32]byte + rand.Read(sessionKey[:]) + clientSesh := setupSesh(false, sessionKey) + serverSesh := setupSesh(false, sessionKey) + w, r := connutil.AsyncPipe() + clientSesh.AddConnection(&common.TLSConn{Conn: w}) + serverSesh.AddConnection(&common.TLSConn{Conn: r}) + testData := make([]byte, payloadLen) + rand.Read(testData) + + t.Run("test single", func(t *testing.T) { + go func() { + stream, _ := clientSesh.OpenStream() + stream.Write(testData) + stream.Close() + }() + + recvBuf := make([]byte, payloadLen) + serverStream, _ := serverSesh.Accept() + _, err := io.ReadFull(serverStream, recvBuf) + if err != nil { + t.Error(err) + } + }) + + t.Run("test multiple", func(t *testing.T) { + const numStreams = 100 + for i := 0; i < numStreams; i++ { + go func() { + stream, _ := clientSesh.OpenStream() + stream.Write(testData) + stream.Close() + }() + } + for i := 0; i < numStreams; i++ { + recvBuf := make([]byte, payloadLen) + serverStream, _ := serverSesh.Accept() + _, err := io.ReadFull(serverStream, recvBuf) + if err != nil { + t.Error(err) + } + } + }) +} + func TestStream_Close(t *testing.T) { - sesh := setupSesh(false) + var sessionKey [32]byte + rand.Read(sessionKey[:]) + sesh := setupSesh(false, sessionKey) testPayload := []byte{42, 42, 42} f := &Frame{ @@ -161,9 +214,11 @@ func TestStream_Close(t *testing.T) { } func TestStream_Read(t *testing.T) { - sesh := setupSesh(false) + var sessionKey [32]byte + rand.Read(sessionKey[:]) + sesh := setupSesh(false, sessionKey) testPayload := []byte{42, 42, 42} - const PAYLOAD_LEN = 3 + const smallPayloadLen = 3 f := &Frame{ 1, @@ -195,8 +250,8 @@ func TestStream_Read(t *testing.T) { t.Error("failed to read", err) return } - if i != PAYLOAD_LEN { - t.Errorf("expected read %v, got %v", PAYLOAD_LEN, i) + if i != smallPayloadLen { + t.Errorf("expected read %v, got %v", smallPayloadLen, i) return } if !bytes.Equal(buf[:i], testPayload) { @@ -230,8 +285,8 @@ func TestStream_Read(t *testing.T) { if err != nil { t.Error("failed to read", err) } - if i != PAYLOAD_LEN { - t.Errorf("expected read %v, got %v", PAYLOAD_LEN, i) + if i != smallPayloadLen { + t.Errorf("expected read %v, got %v", smallPayloadLen, i) } if !bytes.Equal(buf[:i], testPayload) { t.Error("expected", testPayload, @@ -255,8 +310,8 @@ func TestStream_Read(t *testing.T) { if err != nil { t.Error("failed to read", err) } - if i != PAYLOAD_LEN { - t.Errorf("expected read %v, got %v", PAYLOAD_LEN, i) + if i != smallPayloadLen { + t.Errorf("expected read %v, got %v", smallPayloadLen, i) } if !bytes.Equal(buf[:i], testPayload) { t.Error("expected", testPayload, @@ -272,9 +327,11 @@ func TestStream_Read(t *testing.T) { } func TestStream_UnorderedRead(t *testing.T) { - sesh := setupSesh(true) + var sessionKey [32]byte + rand.Read(sessionKey[:]) + sesh := setupSesh(false, sessionKey) testPayload := []byte{42, 42, 42} - const PAYLOAD_LEN = 3 + const smallPayloadLen = 3 f := &Frame{ 1, @@ -304,8 +361,8 @@ func TestStream_UnorderedRead(t *testing.T) { if err != nil { t.Error("failed to read", err) } - if i != PAYLOAD_LEN { - t.Errorf("expected read %v, got %v", PAYLOAD_LEN, i) + if i != smallPayloadLen { + t.Errorf("expected read %v, got %v", smallPayloadLen, i) } if !bytes.Equal(buf[:i], testPayload) { t.Error("expected", testPayload, @@ -337,8 +394,8 @@ func TestStream_UnorderedRead(t *testing.T) { if err != nil { t.Error("failed to read", err) } - if i != PAYLOAD_LEN { - t.Errorf("expected read %v, got %v", PAYLOAD_LEN, i) + if i != smallPayloadLen { + t.Errorf("expected read %v, got %v", smallPayloadLen, i) } if !bytes.Equal(buf[:i], testPayload) { t.Error("expected", testPayload, @@ -362,8 +419,8 @@ func TestStream_UnorderedRead(t *testing.T) { if err != nil { t.Error("failed to read", err) } - if i != PAYLOAD_LEN { - t.Errorf("expected read %v, got %v", PAYLOAD_LEN, i) + if i != smallPayloadLen { + t.Errorf("expected read %v, got %v", smallPayloadLen, i) } if !bytes.Equal(buf[:i], testPayload) { t.Error("expected", testPayload, diff --git a/internal/multiplex/switchboard_test.go b/internal/multiplex/switchboard_test.go index af14c1d..27595ba 100644 --- a/internal/multiplex/switchboard_test.go +++ b/internal/multiplex/switchboard_test.go @@ -137,7 +137,9 @@ func TestSwitchboard_TxCredit(t *testing.T) { } func TestSwitchboard_CloseOnOneDisconn(t *testing.T) { - sesh := setupSesh(false) + var sessionKey [32]byte + rand.Read(sessionKey[:]) + sesh := setupSesh(false, sessionKey) conn0client, conn0server := connutil.AsyncPipe() sesh.AddConnection(conn0client)