diff --git a/internal/multiplex/session.go b/internal/multiplex/session.go index ca57f1d..703a465 100644 --- a/internal/multiplex/session.go +++ b/internal/multiplex/session.go @@ -1,6 +1,7 @@ package multiplex import ( + "crypto/rand" "errors" "fmt" "net" @@ -134,14 +135,46 @@ func (sesh *Session) Accept() (net.Conn, error) { return stream, nil } -func (sesh *Session) delStream(id uint32) { +func (sesh *Session) closeStream(s *Stream, active bool) error { + atomic.StoreUint32(&s.closed, 1) + _ = s.recvBuf.Close() // both datagramBuffer and streamBuffer won't return err on Close() + + if active { + s.writingM.Lock() + defer s.writingM.Unlock() + if s.isClosed() { + return errors.New("Already Closed") + } + + // Notify remote that this stream is closed + pad := genRandomPadding() + f := &Frame{ + StreamID: s.id, + Seq: atomic.AddUint64(&s.nextSendSeq, 1) - 1, + Closing: 1, + Payload: pad, + } + i, err := s.session.Obfs(f, s.obfsBuf) + if err != nil { + return err + } + _, err = s.session.sb.send(s.obfsBuf[:i], &s.assignedConnId) + if err != nil { + return err + } + log.Tracef("stream %v actively closed", s.id) + } else { + log.Tracef("stream %v passively closed", s.id) + } + sesh.streamsM.Lock() - delete(sesh.streams, id) + delete(sesh.streams, s.id) if len(sesh.streams) == 0 { log.Tracef("session %v has no active stream left", sesh.id) go sesh.timeoutAfter(30 * time.Second) } sesh.streamsM.Unlock() + return nil } func (sesh *Session) recvDataFromRemote(data []byte) error { @@ -159,6 +192,9 @@ func (sesh *Session) recvDataFromRemote(data []byte) error { if frame.Closing == 1 { // If the stream has been closed and the current frame is a closing frame, we do noop return nil + } else if frame.Closing == 2 { + // Closing session + return sesh.passiveClose() } else { // it may be tempting to use the connId from which the frame was received. However it doesn't make // any difference because we only care to send the data from the same stream through the same @@ -171,7 +207,6 @@ func (sesh *Session) recvDataFromRemote(data []byte) error { return stream.writeFrame(*frame) } } - } func (sesh *Session) SetTerminalMsg(msg string) { @@ -187,20 +222,18 @@ func (sesh *Session) TerminalMsg() string { } } -func (sesh *Session) Close() error { - log.Debugf("attempting to close session %v", sesh.id) +func (sesh *Session) passiveClose() error { + log.Debugf("attempting to passively close session %v", sesh.id) if atomic.SwapUint32(&sesh.closed, 1) == 1 { log.Debugf("session %v has already been closed", sesh.id) return errRepeatSessionClosing } - sesh.streamsM.Lock() sesh.acceptCh <- nil + + sesh.streamsM.Lock() for id, stream := range sesh.streams { - // If we call stream.Close() here, streamsM will result in a deadlock - // because stream.Close calls sesh.delStream, which locks the mutex. - // so we need to implement a method of stream that closes the stream without calling - // sesh.delStream - go stream.closeNoDelMap() + atomic.StoreUint32(&stream.closed, 1) + _ = stream.recvBuf.Close() // both datagramBuffer and streamBuffer won't return err on Close() delete(sesh.streams, id) } sesh.streamsM.Unlock() @@ -208,7 +241,52 @@ func (sesh *Session) Close() error { sesh.sb.closeAll() log.Debugf("session %v closed gracefully", sesh.id) return nil +} +func genRandomPadding() []byte { + lenB := make([]byte, 1) + rand.Read(lenB) + pad := make([]byte, lenB[0]) + rand.Read(pad) + return pad +} + +func (sesh *Session) Close() error { + log.Debugf("attempting to actively close session %v", sesh.id) + if atomic.SwapUint32(&sesh.closed, 1) == 1 { + log.Debugf("session %v has already been closed", sesh.id) + return errRepeatSessionClosing + } + sesh.acceptCh <- nil + + sesh.streamsM.Lock() + for id, stream := range sesh.streams { + atomic.StoreUint32(&stream.closed, 1) + _ = stream.recvBuf.Close() // both datagramBuffer and streamBuffer won't return err on Close() + delete(sesh.streams, id) + } + sesh.streamsM.Unlock() + + pad := genRandomPadding() + f := &Frame{ + StreamID: 0xffffffff, + Seq: 0, + Closing: 2, + Payload: pad, + } + obfsBuf := make([]byte, len(pad)+64) + i, err := sesh.Obfs(f, obfsBuf) + if err != nil { + return err + } + _, err = sesh.sb.send(obfsBuf[:i], new(uint32)) + if err != nil { + return err + } + + sesh.sb.closeAll() + log.Debugf("session %v closed gracefully", sesh.id) + return nil } func (sesh *Session) IsClosed() bool { diff --git a/internal/multiplex/stream.go b/internal/multiplex/stream.go index ca403b8..84301a3 100644 --- a/internal/multiplex/stream.go +++ b/internal/multiplex/stream.go @@ -7,8 +7,6 @@ import ( "time" log "github.com/sirupsen/logrus" - "math" - prand "math/rand" "sync" "sync/atomic" ) @@ -113,7 +111,8 @@ func (s *Stream) Write(in []byte) (n int, err error) { log.Tracef("%v sent to remote through stream %v with err %v", len(in), s.id, err) if err != nil { if err == errBrokenSwitchboard { - s.session.Close() + s.session.SetTerminalMsg(err.Error()) + s.session.passiveClose() } return } @@ -121,61 +120,13 @@ func (s *Stream) Write(in []byte) (n int, err error) { } -// the necessary steps to mark the stream as closed and to release resources -func (s *Stream) _close() { - atomic.StoreUint32(&s.closed, 1) - _ = s.recvBuf.Close() // both datagramBuffer and streamBuffer won't return err on Close() -} - -// only close locally. Used when the stream close is notified by the remote -func (s *Stream) passiveClose() { - s._close() - s.session.delStream(s.id) - log.Tracef("stream %v passively closed", s.id) +func (s *Stream) passiveClose() error { + return s.session.closeStream(s, false) } // active close. Close locally and tell the remote that this stream is being closed func (s *Stream) Close() error { - - s.writingM.Lock() - defer s.writingM.Unlock() - if s.isClosed() { - return errors.New("Already Closed") - } - - // Notify remote that this stream is closed - prand.Seed(int64(s.id)) - padLen := int(math.Floor(prand.Float64()*200 + 300)) - pad := make([]byte, padLen) - prand.Read(pad) - f := &Frame{ - StreamID: s.id, - Seq: atomic.AddUint64(&s.nextSendSeq, 1) - 1, - Closing: 1, - Payload: pad, - } - i, err := s.session.Obfs(f, s.obfsBuf) - if err != nil { - return err - } - _, err = s.session.sb.send(s.obfsBuf[:i], &s.assignedConnId) - if err != nil { - return err - } - - s._close() - s.session.delStream(s.id) - log.Tracef("stream %v actively closed", s.id) - return nil -} - -// Same as passiveClose() but no call to session.delStream. -// This is called in session.Close() to avoid mutex deadlock -// We don't notify the remote because session.Close() is always -// called when the session is passively closed -func (s *Stream) closeNoDelMap() { - log.Tracef("stream %v closed by session", s.id) - s._close() + return s.session.closeStream(s, true) } // the following functions are purely for implementing net.Conn interface. diff --git a/internal/multiplex/switchboard.go b/internal/multiplex/switchboard.go index 837b907..20416f5 100644 --- a/internal/multiplex/switchboard.go +++ b/internal/multiplex/switchboard.go @@ -62,7 +62,7 @@ func (sb *switchboard) removeConn(connId uint32) { if remaining == 0 { atomic.StoreUint32(&sb.broken, 1) sb.session.SetTerminalMsg("no underlying connection left") - sb.session.Close() + sb.session.passiveClose() } } @@ -149,12 +149,12 @@ func (sb *switchboard) closeAll() { if atomic.SwapUint32(&sb.broken, 1) == 1 { return } - sb.connsM.RLock() + sb.connsM.Lock() for key, conn := range sb.conns { conn.Close() delete(sb.conns, key) } - sb.connsM.RUnlock() + sb.connsM.Unlock() } // deplex function costantly reads from a TCP connection