From e33afb258aba989deaddc2e80cff46e0c20dc576 Mon Sep 17 00:00:00 2001 From: Andy Wang Date: Mon, 16 Mar 2020 11:37:09 +0000 Subject: [PATCH] extract util testing function --- internal/multiplex/stream_test.go | 63 ++++++---------------------- internal/server/websocketAux_test.go | 58 +++++++++++++++++++++++++ internal/util/util.go | 19 +++++++++ 3 files changed, 90 insertions(+), 50 deletions(-) create mode 100644 internal/server/websocketAux_test.go diff --git a/internal/multiplex/stream_test.go b/internal/multiplex/stream_test.go index 25ecf8c..26d66f7 100644 --- a/internal/multiplex/stream_test.go +++ b/internal/multiplex/stream_test.go @@ -127,24 +127,12 @@ func TestStream_Close(t *testing.T) { 0, testPayload, } - ch := make(chan []byte) - l, _ := net.Listen("tcp", "127.0.0.1:0") - go func() { - conn, _ := net.Dial("tcp", l.Addr().String()) - for { - data := <-ch - _, err := conn.Write(data) - if err != nil { - t.Error("cannot write to connection", err) - return - } - } - }() - conn, _ := l.Accept() + + conn, writingEnd := util.GetMockConn() sesh.AddConnection(conn) obfsBuf := make([]byte, 512) i, _ := sesh.Obfs(f, obfsBuf) - ch <- obfsBuf[:i] + writingEnd <- obfsBuf[:i] time.Sleep(100 * time.Microsecond) stream, err := sesh.Accept() if err != nil { @@ -175,20 +163,7 @@ func TestStream_Read(t *testing.T) { testPayload, } - ch := make(chan []byte) - l, _ := net.Listen("tcp", "127.0.0.1:0") - go func() { - conn, _ := net.Dial("tcp", l.Addr().String()) - for { - data := <-ch - _, err := conn.Write(data) - if err != nil { - t.Error("cannot write to connection", err) - return - } - } - }() - conn, _ := l.Accept() + conn, writingEnd := util.GetMockConn() sesh.AddConnection(conn) var streamID uint32 @@ -199,7 +174,7 @@ func TestStream_Read(t *testing.T) { f.StreamID = streamID i, _ := sesh.Obfs(f, obfsBuf) streamID++ - ch <- obfsBuf[:i] + writingEnd <- obfsBuf[:i] time.Sleep(100 * time.Microsecond) stream, err := sesh.Accept() if err != nil { @@ -225,7 +200,7 @@ func TestStream_Read(t *testing.T) { f.StreamID = streamID i, _ := sesh.Obfs(f, obfsBuf) streamID++ - ch <- obfsBuf[:i] + writingEnd <- obfsBuf[:i] time.Sleep(100 * time.Microsecond) stream, _ := sesh.Accept() i, err := stream.Read(nil) @@ -246,7 +221,7 @@ func TestStream_Read(t *testing.T) { f.StreamID = streamID i, _ := sesh.Obfs(f, obfsBuf) streamID++ - ch <- obfsBuf[:i] + writingEnd <- obfsBuf[:i] time.Sleep(100 * time.Microsecond) stream, _ := sesh.Accept() stream.Close() @@ -271,7 +246,7 @@ func TestStream_Read(t *testing.T) { f.StreamID = streamID i, _ := sesh.Obfs(f, obfsBuf) streamID++ - ch <- obfsBuf[:i] + writingEnd <- obfsBuf[:i] time.Sleep(100 * time.Microsecond) stream, _ := sesh.Accept() sesh.Close() @@ -307,19 +282,7 @@ func TestStream_UnorderedRead(t *testing.T) { testPayload, } - ch := make(chan []byte) - l, _ := net.Listen("tcp", "127.0.0.1:0") - go func() { - conn, _ := net.Dial("tcp", l.Addr().String()) - for { - data := <-ch - _, err := conn.Write(data) - if err != nil { - t.Error("cannot write to connection", err) - } - } - }() - conn, _ := l.Accept() + conn, writingEnd := util.GetMockConn() sesh.AddConnection(conn) var streamID uint32 @@ -330,7 +293,7 @@ func TestStream_UnorderedRead(t *testing.T) { f.StreamID = streamID i, _ := sesh.Obfs(f, obfsBuf) streamID++ - ch <- obfsBuf[:i] + writingEnd <- obfsBuf[:i] time.Sleep(100 * time.Microsecond) stream, err := sesh.Accept() if err != nil { @@ -352,7 +315,7 @@ func TestStream_UnorderedRead(t *testing.T) { f.StreamID = streamID i, _ := sesh.Obfs(f, obfsBuf) streamID++ - ch <- obfsBuf[:i] + writingEnd <- obfsBuf[:i] time.Sleep(100 * time.Microsecond) stream, _ := sesh.Accept() i, err := stream.Read(nil) @@ -373,7 +336,7 @@ func TestStream_UnorderedRead(t *testing.T) { f.StreamID = streamID i, _ := sesh.Obfs(f, obfsBuf) streamID++ - ch <- obfsBuf[:i] + writingEnd <- obfsBuf[:i] time.Sleep(100 * time.Microsecond) stream, _ := sesh.Accept() stream.Close() @@ -398,7 +361,7 @@ func TestStream_UnorderedRead(t *testing.T) { f.StreamID = streamID i, _ := sesh.Obfs(f, obfsBuf) streamID++ - ch <- obfsBuf[:i] + writingEnd <- obfsBuf[:i] time.Sleep(100 * time.Microsecond) stream, _ := sesh.Accept() sesh.Close() diff --git a/internal/server/websocketAux_test.go b/internal/server/websocketAux_test.go new file mode 100644 index 0000000..eb68c5a --- /dev/null +++ b/internal/server/websocketAux_test.go @@ -0,0 +1,58 @@ +package server + +import ( + "bytes" + "github.com/cbeuw/Cloak/internal/util" + "testing" +) + +func TestFirstBuffedConn_Read(t *testing.T) { + mockConn, writingEnd := util.GetMockConn() + + expectedFirstPacket := []byte{1, 2, 3} + firstBuffedConn := &firstBuffedConn{ + Conn: mockConn, + firstRead: false, + firstPacket: expectedFirstPacket, + } + + buf := make([]byte, 1024) + n ,err :=firstBuffedConn.Read(buf) + if err != nil { + t.Error(err) + return + } + if !bytes.Equal(expectedFirstPacket, buf[:n]){ + t.Error("first read doesn't produce given packet") + return + } + + expectedSecondPacket := []byte{4,5,6,7} + writingEnd <- expectedSecondPacket + n ,err =firstBuffedConn.Read(buf) + if err != nil { + t.Error(err) + return + } + if !bytes.Equal(expectedSecondPacket, buf[:n]){ + t.Error("second read doesn't produce subsequently written packet") + return + } +} + +func TestWsAcceptor(t *testing.T){ + mockConn, _ := util.GetMockConn() + expectedFirstPacket := []byte{1, 2, 3} + + wsAcceptor:=newWsAcceptor(mockConn, expectedFirstPacket) + _,err := wsAcceptor.Accept() + if err != nil { + t.Error(err) + return + } + + _,err = wsAcceptor.Accept() + if err == nil{ + t.Error("accepting second time doesn't return error") + } +} \ No newline at end of file diff --git a/internal/util/util.go b/internal/util/util.go index f45ab8d..f871ef9 100644 --- a/internal/util/util.go +++ b/internal/util/util.go @@ -6,6 +6,7 @@ import ( "crypto/rand" "encoding/binary" "errors" + "fmt" "io" "net" "strconv" @@ -136,3 +137,21 @@ func Pipe(dst net.Conn, src net.Conn, srcReadTimeout time.Duration) { } } } + +func GetMockConn() (net.Conn, chan []byte) { + ch := make(chan []byte) + l, _ := net.Listen("tcp", "127.0.0.1:0") + go func() { + conn, _ := net.Dial("tcp", l.Addr().String()) + for { + data := <-ch + _, err := conn.Write(data) + if err != nil { + fmt.Println("cannot write to connection", err) + } + } + }() + conn, _ := l.Accept() + + return conn, ch +}