package multiplex import ( "errors" "io" "net" "time" log "github.com/sirupsen/logrus" "sync" "sync/atomic" ) var ErrBrokenStream = errors.New("broken stream") type Stream struct { id uint32 session *Session recvBuf recvBuffer // atomic nextSendSeq uint64 writingM sync.Mutex // atomic closed uint32 // only alloc when writing to the stream allocIdempot sync.Once obfsBuf []byte // we assign each stream a fixed underlying TCP connection to utilise order guarantee provided by TCP itself // so that frameSorter should have few to none ooo frames to deal with // overall the streams in a session should be uniformly distributed across all connections // This is not used in unordered connection mode assignedConnId uint32 } func makeStream(sesh *Session, id uint32) *Stream { var recvBuf recvBuffer if sesh.Unordered { recvBuf = NewDatagramBuffer() } else { recvBuf = NewStreamBuffer() } stream := &Stream{ id: id, session: sesh, recvBuf: recvBuf, } return stream } func (s *Stream) isClosed() bool { return atomic.LoadUint32(&s.closed) == 1 } func (s *Stream) writeFrame(frame Frame) error { toBeClosed, err := s.recvBuf.Write(frame) if toBeClosed { return s.passiveClose() } return err } // Read implements io.Read func (s *Stream) Read(buf []byte) (n int, err error) { //log.Tracef("attempting to read from stream %v", s.id) if len(buf) == 0 { return 0, nil } n, err = s.recvBuf.Read(buf) if err == io.EOF { return n, ErrBrokenStream } log.Tracef("%v read from stream %v with err %v", n, s.id, err) return } func (s *Stream) WriteTo(w io.Writer) (int64, error) { // will keep writing until the underlying buffer is closed n, err := s.recvBuf.WriteTo(w) log.Tracef("%v read from stream %v with err %v", n, s.id, err) if err == io.EOF { return n, ErrBrokenStream } return n, nil } func (s *Stream) writePayload(seq uint64, payload []byte) error { f := &Frame{ StreamID: s.id, Seq: seq, Closing: C_NOOP, Payload: payload, } var cipherTextLen int cipherTextLen, err := s.session.Obfs(f, s.obfsBuf) if err != nil { return err } _, err = s.session.sb.send(s.obfsBuf[:cipherTextLen], &s.assignedConnId) log.Tracef("%v sent to remote through stream %v with err %v", len(payload), s.id, err) if err != nil { if err == errBrokenSwitchboard { s.session.SetTerminalMsg(err.Error()) s.session.passiveClose() } return err } return nil } // Write implements io.Write func (s *Stream) Write(in []byte) (n int, err error) { s.writingM.Lock() defer s.writingM.Unlock() if s.isClosed() { return 0, ErrBrokenStream } if s.obfsBuf == nil { s.obfsBuf = make([]byte, s.session.SendBufferSize) } for n < len(in) { var framePayload []byte if len(in)-n <= s.session.maxStreamUnitWrite { framePayload = in[n:] } else { if s.session.Unordered { // no splitting err = io.ErrShortBuffer return } framePayload = in[n : s.session.maxStreamUnitWrite+n] } err = s.writePayload(atomic.AddUint64(&s.nextSendSeq, 1)-1, framePayload) if err != nil { return } n += len(framePayload) } return } func (s *Stream) ReadFrom(r io.Reader) (n int64, err error) { s.writingM.Lock() defer s.writingM.Unlock() if s.obfsBuf == nil { s.obfsBuf = make([]byte, s.session.SendBufferSize) } for { read, er := r.Read(s.obfsBuf[HEADER_LEN : HEADER_LEN+s.session.maxStreamUnitWrite]) if er != nil { return n, er } if s.isClosed() { return 0, ErrBrokenStream } seq := atomic.AddUint64(&s.nextSendSeq, 1) - 1 err = s.writePayload(seq, s.obfsBuf[HEADER_LEN:HEADER_LEN+read]) if err != nil { return } n += int64(read) } } func (s *Stream) passiveClose() error { return s.session.closeStream(s, false) } // active close. Close locally and tell the remote that this stream is being closed func (s *Stream) Close() error { return s.session.closeStream(s, true) } // the following functions are purely for implementing net.Conn interface. // they are not used var errNotImplemented = errors.New("Not implemented") func (s *Stream) LocalAddr() net.Addr { return s.session.addrs.Load().([]net.Addr)[0] } func (s *Stream) RemoteAddr() net.Addr { return s.session.addrs.Load().([]net.Addr)[1] } // TODO: implement the following func (s *Stream) SetDeadline(t time.Time) error { return errNotImplemented } func (s *Stream) SetReadDeadline(t time.Time) error { s.recvBuf.SetReadDeadline(t); return nil } func (s *Stream) SetWriteDeadline(t time.Time) error { return errNotImplemented }