diff --git a/internal/multiplex/stream.go b/internal/multiplex/stream.go index 66f7036..b302f95 100644 --- a/internal/multiplex/stream.go +++ b/internal/multiplex/stream.go @@ -70,11 +70,7 @@ func (s *Stream) writeFrame(frame Frame) error { func (s *Stream) Read(buf []byte) (n int, err error) { //log.Tracef("attempting to read from stream %v", s.id) if len(buf) == 0 { - if s.isClosed() { - return 0, ErrBrokenStream - } else { - return 0, nil - } + return 0, nil } n, err = s.recvBuf.Read(buf) diff --git a/internal/multiplex/stream_test.go b/internal/multiplex/stream_test.go index 435956b..ccb49db 100644 --- a/internal/multiplex/stream_test.go +++ b/internal/multiplex/stream_test.go @@ -3,6 +3,7 @@ package multiplex import ( "bytes" "github.com/cbeuw/Cloak/internal/util" + "io" "math/rand" "net" "testing" @@ -121,10 +122,9 @@ func TestStream_Write(t *testing.T) { func TestStream_Close(t *testing.T) { sesh := setupSesh(false) testPayload := []byte{42, 42, 42} - streamID := uint32(1) f := &Frame{ - streamID, + 1, 0, 0, testPayload, @@ -147,10 +147,19 @@ func TestStream_Close(t *testing.T) { return } - if sI, _ := sesh.streams.Load(streamID); sI != nil { + if sI, _ := sesh.streams.Load(stream.(*Stream).id); sI != nil { t.Error("stream still exists") return } + + readBuf := make([]byte, len(testPayload)) + _, err = io.ReadFull(stream, readBuf) + if err != nil { + t.Errorf("can't read residual data %v", err) + } + if !bytes.Equal(readBuf, testPayload) { + t.Errorf("read wrong data") + } } func TestStream_Read(t *testing.T) { @@ -210,14 +219,6 @@ func TestStream_Read(t *testing.T) { t.Error("expecting", 0, nil, "got", i, err) } - - stream.Close() - i, err = stream.Read(nil) - if i != 0 || err != ErrBrokenStream { - t.Error("expecting", 0, ErrBrokenStream, - "got", i, err) - } - }) t.Run("Read after stream close", func(t *testing.T) { f.StreamID = streamID @@ -325,14 +326,6 @@ func TestStream_UnorderedRead(t *testing.T) { t.Error("expecting", 0, nil, "got", i, err) } - - stream.Close() - i, err = stream.Read(nil) - if i != 0 || err != ErrBrokenStream { - t.Error("expecting", 0, ErrBrokenStream, - "got", i, err) - } - }) t.Run("Read after stream close", func(t *testing.T) { f.StreamID = streamID