diff --git a/internal/multiplex/frameSorter.go b/internal/multiplex/frameSorter.go index 6738893..0679673 100644 --- a/internal/multiplex/frameSorter.go +++ b/internal/multiplex/frameSorter.go @@ -2,7 +2,9 @@ package multiplex import ( "container/heap" - //"log" + "io" + + log "github.com/sirupsen/logrus" ) // The data is multiplexed through several TCP connections, therefore the @@ -50,73 +52,101 @@ func (sh *sorterHeap) Pop() interface{} { return x } -func (s *Stream) writeNewFrame(f *Frame) { - s.newFrameCh <- f +type frameSorter struct { + nextRecvSeq uint32 + rev int + sh sorterHeap + wrapMode bool + + // New frames are received through newFrameCh by frameSorter + newFrameCh chan *Frame + + output io.WriteCloser +} + +func NewFrameSorter(output io.WriteCloser) *frameSorter { + fs := &frameSorter{ + sh: []*frameNode{}, + newFrameCh: make(chan *Frame, 1024), + rev: 0, + output: output, + } + go fs.recvNewFrame() + return fs +} + +func (fs *frameSorter) writeNewFrame(f *Frame) { + fs.newFrameCh <- f +} +func (fs *frameSorter) Close() error { + fs.newFrameCh <- nil + return nil } // recvNewFrame is a forever running loop which receives frames unordered, // cache and order them and send them into sortedBufCh -func (s *Stream) recvNewFrame() { +func (fs *frameSorter) recvNewFrame() { + defer log.Tracef("a recvNewFrame has returned gracefully") for { - f := <-s.newFrameCh + f := <-fs.newFrameCh if f == nil { return } - // when there's no ooo packages in heap and we receive the next package in order - if len(s.sh) == 0 && f.Seq == s.nextRecvSeq { + // when there'fs no ooo packages in heap and we receive the next package in order + if len(fs.sh) == 0 && f.Seq == fs.nextRecvSeq { if f.Closing == 1 { // empty data indicates closing signal - s.passiveClose() + fs.output.Close() return } else { - s.sortedBuf.Write(f.Payload) - s.nextRecvSeq += 1 - if s.nextRecvSeq == 0 { // getting wrapped - s.rev += 1 - s.wrapMode = false + fs.output.Write(f.Payload) + fs.nextRecvSeq += 1 + if fs.nextRecvSeq == 0 { // getting wrapped + fs.rev += 1 + fs.wrapMode = false } } continue } - fs := &frameNode{ + node := &frameNode{ trueSeq: 0, frame: f, } - if f.Seq < s.nextRecvSeq { + if f.Seq < fs.nextRecvSeq { // For the ease of demonstration, assume seq is uint8, i.e. it wraps around after 255 // e.g. we are on rev=0 (wrap has not happened yet) // and we get the order of recv as 253 254 0 1 // after 254, nextN should be 255, but 0 is received and 0 < 255 // now 0 should have a trueSeq of 256 - if !s.wrapMode { + if !fs.wrapMode { // wrapMode is true when the latest seq is wrapped but nextN is not - s.wrapMode = true + fs.wrapMode = true } - fs.trueSeq = uint64(1<<32)*uint64(s.rev+1) + uint64(f.Seq) + 1 + node.trueSeq = uint64(1<<32)*uint64(fs.rev+1) + uint64(f.Seq) + 1 // +1 because wrapped 0 should have trueSeq of 256 instead of 255 // when this bit was run on 1, the trueSeq of 1 would become 256 } else { - fs.trueSeq = uint64(1<<32)*uint64(s.rev) + uint64(f.Seq) + node.trueSeq = uint64(1<<32)*uint64(fs.rev) + uint64(f.Seq) // when this bit was run on 255, the trueSeq of 255 would be 255 } - heap.Push(&s.sh, fs) + heap.Push(&fs.sh, node) // Keep popping from the heap until empty or to the point that the wanted seq was not received - for len(s.sh) > 0 && s.sh[0].frame.Seq == s.nextRecvSeq { - f = heap.Pop(&s.sh).(*frameNode).frame + for len(fs.sh) > 0 && fs.sh[0].frame.Seq == fs.nextRecvSeq { + f = heap.Pop(&fs.sh).(*frameNode).frame if f.Closing == 1 { // empty data indicates closing signal - s.passiveClose() + fs.output.Close() return } else { - s.sortedBuf.Write(f.Payload) - s.nextRecvSeq += 1 - if s.nextRecvSeq == 0 { // getting wrapped - s.rev += 1 - s.wrapMode = false + fs.output.Write(f.Payload) + fs.nextRecvSeq += 1 + if fs.nextRecvSeq == 0 { // getting wrapped + fs.rev += 1 + fs.wrapMode = false } } } diff --git a/internal/multiplex/frameSorter_test.go b/internal/multiplex/frameSorter_test.go index 39c3b99..0524caa 100644 --- a/internal/multiplex/frameSorter_test.go +++ b/internal/multiplex/frameSorter_test.go @@ -1,21 +1,32 @@ package multiplex import ( + "bytes" "encoding/binary" + "time" + //"log" "sort" "testing" ) +type BufferReaderWriterCloser struct { + *bytes.Buffer +} + +func (b *BufferReaderWriterCloser) Close() error { + return nil +} func TestRecvNewFrame(t *testing.T) { inOrder := []uint64{5, 6, 7, 8, 9, 10, 11} outOfOrder0 := []uint64{5, 7, 8, 6, 11, 10, 9} outOfOrder1 := []uint64{1, 96, 47, 2, 29, 18, 60, 8, 74, 22, 82, 58, 44, 51, 57, 71, 90, 94, 68, 83, 61, 91, 39, 97, 85, 63, 46, 73, 54, 84, 76, 98, 93, 79, 75, 50, 67, 37, 92, 99, 42, 77, 17, 16, 38, 3, 100, 24, 31, 7, 36, 40, 86, 64, 34, 45, 12, 5, 9, 27, 21, 26, 35, 6, 65, 69, 53, 4, 48, 28, 30, 56, 32, 11, 80, 66, 25, 41, 78, 13, 88, 62, 15, 70, 49, 43, 72, 23, 10, 55, 52, 95, 14, 59, 87, 33, 19, 20, 81, 89} outOfOrderWrap0 := []uint64{1<<32 - 5, 1<<32 + 3, 1 << 32, 1<<32 - 3, 1<<32 - 4, 1<<32 + 2, 1<<32 - 2, 1<<32 - 1, 1<<32 + 1} - sets := [][]uint64{inOrder, outOfOrder0, outOfOrder1, outOfOrderWrap0} - for _, set := range sets { - stream := makeStream(1, &Session{}) - stream.nextRecvSeq = uint32(set[0]) + + sortedBuf := &BufferReaderWriterCloser{new(bytes.Buffer)} + test := func(set []uint64, ct *testing.T) { + fs := NewFrameSorter(sortedBuf) + fs.nextRecvSeq = uint32(set[0]) for _, n := range set { bu64 := make([]byte, 8) binary.BigEndian.PutUint64(bu64, n) @@ -23,33 +34,49 @@ func TestRecvNewFrame(t *testing.T) { Seq: uint32(n), Payload: bu64, } - stream.writeNewFrame(frame) + fs.writeNewFrame(frame) } - var testSorted []uint32 + time.Sleep(100 * time.Microsecond) + + var sortedResult []uint64 for x := 0; x < len(set); x++ { oct := make([]byte, 8) - stream.sortedBuf.Read(oct) + n, err := sortedBuf.Read(oct) + if n != 8 || err != nil { + ct.Error("failed to read from sorted Buf", n, err) + } //log.Print(p) - testSorted = append(testSorted, uint32(binary.BigEndian.Uint64(oct))) - } - sorted64 := make([]uint64, len(set)) - copy(sorted64, set) - sort.Slice(sorted64, func(i, j int) bool { return sorted64[i] < sorted64[j] }) - sorted32 := make([]uint32, len(set)) - for i, _ := range sorted64 { - sorted32[i] = uint32(sorted64[i]) + sortedResult = append(sortedResult, binary.BigEndian.Uint64(oct)) } + targetSorted := make([]uint64, len(set)) + copy(targetSorted, set) + sort.Slice(targetSorted, func(i, j int) bool { return targetSorted[i] < targetSorted[j] }) - for i, _ := range sorted32 { - if sorted32[i] != testSorted[i] { - t.Error( - "For", set, - "expecting", sorted32, - "got", testSorted, - ) + for i, _ := range targetSorted { + if sortedResult[i] != targetSorted[i] { + goto fail } } - stream.newFrameCh <- nil + fs.Close() + return + fail: + ct.Error( + "expecting", targetSorted, + "got", sortedResult, + ) } + + t.Run("in order", func(t *testing.T) { + test(inOrder, t) + }) + t.Run("out of order0", func(t *testing.T) { + test(outOfOrder0, t) + }) + t.Run("out of order1", func(t *testing.T) { + test(outOfOrder1, t) + }) + t.Run("out of order wrap", func(t *testing.T) { + test(outOfOrderWrap0, t) + }) } diff --git a/internal/multiplex/stream.go b/internal/multiplex/stream.go index 1d17af7..cbc5b7b 100644 --- a/internal/multiplex/stream.go +++ b/internal/multiplex/stream.go @@ -19,17 +19,10 @@ type Stream struct { session *Session - // Explanations of the following 4 fields can be found in frameSorter.go - nextRecvSeq uint32 - rev int - sh sorterHeap - wrapMode bool - - // New frames are received through newFrameCh by frameSorter - newFrameCh chan *Frame - sortedBuf *bufferedPipe + sorter *frameSorter + // atomic nextSendSeq uint32 @@ -42,15 +35,16 @@ type Stream struct { } func makeStream(id uint32, sesh *Session) *Stream { + buf := NewBufferedPipe() + stream := &Stream{ - id: id, - session: sesh, - sh: []*frameNode{}, - newFrameCh: make(chan *Frame, 1024), - sortedBuf: NewBufferedPipe(), - obfsBuf: make([]byte, 17000), + id: id, + session: sesh, + sortedBuf: buf, + obfsBuf: make([]byte, 17000), + sorter: NewFrameSorter(buf), } - go stream.recvNewFrame() + log.Tracef("stream %v opened", id) return stream } @@ -108,7 +102,7 @@ func (s *Stream) Write(in []byte) (n int, err error) { // the necessary steps to mark the stream as closed and to release resources func (s *Stream) _close() { atomic.StoreUint32(&s.closed, 1) - s.newFrameCh <- nil // this will trigger frameSorter to return + s.sorter.Close() // this will trigger frameSorter to return s.sortedBuf.Close() } diff --git a/internal/multiplex/switchboard.go b/internal/multiplex/switchboard.go index d43720a..a0788b7 100644 --- a/internal/multiplex/switchboard.go +++ b/internal/multiplex/switchboard.go @@ -148,7 +148,7 @@ func (sb *switchboard) deplex(ce *connEnclave) { // (this happens when ss-server and ss-local closes the stream // simutaneously), we don't do anything if stream != nil { - stream.writeNewFrame(frame) + stream.sorter.writeNewFrame(frame) } } }