From 57fc31a5fc93628661104dc8246755749d5ffe26 Mon Sep 17 00:00:00 2001 From: Qian Wang Date: Fri, 2 Aug 2019 23:23:54 +0100 Subject: [PATCH] Add tests --- internal/multiplex/stream_test.go | 125 +++++++++++++++++++++++++++++- 1 file changed, 123 insertions(+), 2 deletions(-) diff --git a/internal/multiplex/stream_test.go b/internal/multiplex/stream_test.go index 014633b..86d4f88 100644 --- a/internal/multiplex/stream_test.go +++ b/internal/multiplex/stream_test.go @@ -2,6 +2,7 @@ package multiplex import ( "bufio" + "bytes" "crypto/aes" "crypto/cipher" "encoding/binary" @@ -119,9 +120,8 @@ func (b *blackhole) SetDeadline(t time.Time) error { return nil } func (b *blackhole) SetReadDeadline(t time.Time) error { return nil } func (b *blackhole) SetWriteDeadline(t time.Time) error { return nil } -const PAYLOAD_LEN = 1 << 20 * 100 - func BenchmarkStream_Write(b *testing.B) { + const PAYLOAD_LEN = 1 << 20 * 100 hole := newBlackHole() sesh := setupSesh() sesh.AddConnection(hole) @@ -141,3 +141,124 @@ func BenchmarkStream_Write(b *testing.B) { b.SetBytes(PAYLOAD_LEN) } } + +func TestStream_Read(t *testing.T) { + sesh := setupSesh() + testPayload := []byte{42, 42, 42} + const PAYLOAD_LEN = 3 + + f := &Frame{ + 1, + 0, + 0, + testPayload, + } + + ch := make(chan []byte) + l, _ := net.Listen("tcp", ":0") + go func() { + conn, _ := net.Dial("tcp", l.Addr().String()) + for { + data := <-ch + _, err := conn.Write(data) + if err != nil { + t.Error("cannot write to connection", err) + } + } + }() + conn, _ := l.Accept() + sesh.AddConnection(conn) + + var streamID uint32 + buf := make([]byte, 10) + t.Run("Plain read", func(t *testing.T) { + f.StreamID = streamID + obfsed, _ := sesh.Obfs(f) + streamID++ + ch <- obfsed + stream, err := sesh.Accept() + if err != nil { + t.Error("failed to accept stream", err) + } + i, err := stream.Read(buf) + if err != nil { + t.Error("failed to read", err) + } + if i != PAYLOAD_LEN { + t.Errorf("expected read %v, got %v", PAYLOAD_LEN, i) + } + if !bytes.Equal(buf[:i], testPayload) { + t.Error("expected", testPayload, + "got", buf[:i]) + } + }) + t.Run("Nil buf", func(t *testing.T) { + f.StreamID = streamID + obfsed, _ := sesh.Obfs(f) + streamID++ + ch <- obfsed + stream, _ := sesh.Accept() + i, err := stream.Read(nil) + if i != 0 || err != nil { + t.Error("expecting", 0, nil, + "got", i, err) + } + + stream.Close() + i, err = stream.Read(nil) + if i != 0 || err != ErrBrokenStream { + t.Error("expecting", 0, ErrBrokenStream, + "got", i, err) + } + + }) + t.Run("Read after stream close", func(t *testing.T) { + f.StreamID = streamID + obfsed, _ := sesh.Obfs(f) + streamID++ + ch <- obfsed + stream, _ := sesh.Accept() + stream.Close() + i, err := stream.Read(buf) + if err != nil { + t.Error("failed to read", err) + } + if i != PAYLOAD_LEN { + t.Errorf("expected read %v, got %v", PAYLOAD_LEN, i) + } + if !bytes.Equal(buf[:i], testPayload) { + t.Error("expected", testPayload, + "got", buf[:i]) + } + _, err = stream.Read(buf) + if err == nil { + t.Error("expecting error", ErrBrokenStream, + "got nil error") + } + }) + t.Run("Read after session close", func(t *testing.T) { + f.StreamID = streamID + obfsed, _ := sesh.Obfs(f) + streamID++ + ch <- obfsed + stream, _ := sesh.Accept() + sesh.Close() + i, err := stream.Read(buf) + if err != nil { + t.Error("failed to read", err) + } + if i != PAYLOAD_LEN { + t.Errorf("expected read %v, got %v", PAYLOAD_LEN, i) + } + if !bytes.Equal(buf[:i], testPayload) { + t.Error("expected", testPayload, + "got", buf[:i]) + } + _, err = stream.Read(buf) + if err == nil { + t.Error("expecting error", ErrBrokenStream, + "got nil error") + } + }) + +}