Introduce a special Session closing frame

This commit is contained in:
Andy Wang 2019-10-14 15:34:14 +01:00
parent c9318dc90b
commit 6580e38e44
3 changed files with 97 additions and 68 deletions

View File

@ -1,6 +1,7 @@
package multiplex package multiplex
import ( import (
"crypto/rand"
"errors" "errors"
"fmt" "fmt"
"net" "net"
@ -134,14 +135,46 @@ func (sesh *Session) Accept() (net.Conn, error) {
return stream, nil return stream, nil
} }
func (sesh *Session) delStream(id uint32) { func (sesh *Session) closeStream(s *Stream, active bool) error {
atomic.StoreUint32(&s.closed, 1)
_ = s.recvBuf.Close() // both datagramBuffer and streamBuffer won't return err on Close()
if active {
s.writingM.Lock()
defer s.writingM.Unlock()
if s.isClosed() {
return errors.New("Already Closed")
}
// Notify remote that this stream is closed
pad := genRandomPadding()
f := &Frame{
StreamID: s.id,
Seq: atomic.AddUint64(&s.nextSendSeq, 1) - 1,
Closing: 1,
Payload: pad,
}
i, err := s.session.Obfs(f, s.obfsBuf)
if err != nil {
return err
}
_, err = s.session.sb.send(s.obfsBuf[:i], &s.assignedConnId)
if err != nil {
return err
}
log.Tracef("stream %v actively closed", s.id)
} else {
log.Tracef("stream %v passively closed", s.id)
}
sesh.streamsM.Lock() sesh.streamsM.Lock()
delete(sesh.streams, id) delete(sesh.streams, s.id)
if len(sesh.streams) == 0 { if len(sesh.streams) == 0 {
log.Tracef("session %v has no active stream left", sesh.id) log.Tracef("session %v has no active stream left", sesh.id)
go sesh.timeoutAfter(30 * time.Second) go sesh.timeoutAfter(30 * time.Second)
} }
sesh.streamsM.Unlock() sesh.streamsM.Unlock()
return nil
} }
func (sesh *Session) recvDataFromRemote(data []byte) error { func (sesh *Session) recvDataFromRemote(data []byte) error {
@ -159,6 +192,9 @@ func (sesh *Session) recvDataFromRemote(data []byte) error {
if frame.Closing == 1 { if frame.Closing == 1 {
// If the stream has been closed and the current frame is a closing frame, we do noop // If the stream has been closed and the current frame is a closing frame, we do noop
return nil return nil
} else if frame.Closing == 2 {
// Closing session
return sesh.passiveClose()
} else { } else {
// it may be tempting to use the connId from which the frame was received. However it doesn't make // it may be tempting to use the connId from which the frame was received. However it doesn't make
// any difference because we only care to send the data from the same stream through the same // any difference because we only care to send the data from the same stream through the same
@ -171,7 +207,6 @@ func (sesh *Session) recvDataFromRemote(data []byte) error {
return stream.writeFrame(*frame) return stream.writeFrame(*frame)
} }
} }
} }
func (sesh *Session) SetTerminalMsg(msg string) { func (sesh *Session) SetTerminalMsg(msg string) {
@ -187,20 +222,18 @@ func (sesh *Session) TerminalMsg() string {
} }
} }
func (sesh *Session) Close() error { func (sesh *Session) passiveClose() error {
log.Debugf("attempting to close session %v", sesh.id) log.Debugf("attempting to passively close session %v", sesh.id)
if atomic.SwapUint32(&sesh.closed, 1) == 1 { if atomic.SwapUint32(&sesh.closed, 1) == 1 {
log.Debugf("session %v has already been closed", sesh.id) log.Debugf("session %v has already been closed", sesh.id)
return errRepeatSessionClosing return errRepeatSessionClosing
} }
sesh.streamsM.Lock()
sesh.acceptCh <- nil sesh.acceptCh <- nil
sesh.streamsM.Lock()
for id, stream := range sesh.streams { for id, stream := range sesh.streams {
// If we call stream.Close() here, streamsM will result in a deadlock atomic.StoreUint32(&stream.closed, 1)
// because stream.Close calls sesh.delStream, which locks the mutex. _ = stream.recvBuf.Close() // both datagramBuffer and streamBuffer won't return err on Close()
// so we need to implement a method of stream that closes the stream without calling
// sesh.delStream
go stream.closeNoDelMap()
delete(sesh.streams, id) delete(sesh.streams, id)
} }
sesh.streamsM.Unlock() sesh.streamsM.Unlock()
@ -208,7 +241,52 @@ func (sesh *Session) Close() error {
sesh.sb.closeAll() sesh.sb.closeAll()
log.Debugf("session %v closed gracefully", sesh.id) log.Debugf("session %v closed gracefully", sesh.id)
return nil return nil
}
func genRandomPadding() []byte {
lenB := make([]byte, 1)
rand.Read(lenB)
pad := make([]byte, lenB[0])
rand.Read(pad)
return pad
}
func (sesh *Session) Close() error {
log.Debugf("attempting to actively close session %v", sesh.id)
if atomic.SwapUint32(&sesh.closed, 1) == 1 {
log.Debugf("session %v has already been closed", sesh.id)
return errRepeatSessionClosing
}
sesh.acceptCh <- nil
sesh.streamsM.Lock()
for id, stream := range sesh.streams {
atomic.StoreUint32(&stream.closed, 1)
_ = stream.recvBuf.Close() // both datagramBuffer and streamBuffer won't return err on Close()
delete(sesh.streams, id)
}
sesh.streamsM.Unlock()
pad := genRandomPadding()
f := &Frame{
StreamID: 0xffffffff,
Seq: 0,
Closing: 2,
Payload: pad,
}
obfsBuf := make([]byte, len(pad)+64)
i, err := sesh.Obfs(f, obfsBuf)
if err != nil {
return err
}
_, err = sesh.sb.send(obfsBuf[:i], new(uint32))
if err != nil {
return err
}
sesh.sb.closeAll()
log.Debugf("session %v closed gracefully", sesh.id)
return nil
} }
func (sesh *Session) IsClosed() bool { func (sesh *Session) IsClosed() bool {

View File

@ -7,8 +7,6 @@ import (
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"math"
prand "math/rand"
"sync" "sync"
"sync/atomic" "sync/atomic"
) )
@ -113,7 +111,8 @@ func (s *Stream) Write(in []byte) (n int, err error) {
log.Tracef("%v sent to remote through stream %v with err %v", len(in), s.id, err) log.Tracef("%v sent to remote through stream %v with err %v", len(in), s.id, err)
if err != nil { if err != nil {
if err == errBrokenSwitchboard { if err == errBrokenSwitchboard {
s.session.Close() s.session.SetTerminalMsg(err.Error())
s.session.passiveClose()
} }
return return
} }
@ -121,61 +120,13 @@ func (s *Stream) Write(in []byte) (n int, err error) {
} }
// the necessary steps to mark the stream as closed and to release resources func (s *Stream) passiveClose() error {
func (s *Stream) _close() { return s.session.closeStream(s, false)
atomic.StoreUint32(&s.closed, 1)
_ = s.recvBuf.Close() // both datagramBuffer and streamBuffer won't return err on Close()
}
// only close locally. Used when the stream close is notified by the remote
func (s *Stream) passiveClose() {
s._close()
s.session.delStream(s.id)
log.Tracef("stream %v passively closed", s.id)
} }
// active close. Close locally and tell the remote that this stream is being closed // active close. Close locally and tell the remote that this stream is being closed
func (s *Stream) Close() error { func (s *Stream) Close() error {
return s.session.closeStream(s, true)
s.writingM.Lock()
defer s.writingM.Unlock()
if s.isClosed() {
return errors.New("Already Closed")
}
// Notify remote that this stream is closed
prand.Seed(int64(s.id))
padLen := int(math.Floor(prand.Float64()*200 + 300))
pad := make([]byte, padLen)
prand.Read(pad)
f := &Frame{
StreamID: s.id,
Seq: atomic.AddUint64(&s.nextSendSeq, 1) - 1,
Closing: 1,
Payload: pad,
}
i, err := s.session.Obfs(f, s.obfsBuf)
if err != nil {
return err
}
_, err = s.session.sb.send(s.obfsBuf[:i], &s.assignedConnId)
if err != nil {
return err
}
s._close()
s.session.delStream(s.id)
log.Tracef("stream %v actively closed", s.id)
return nil
}
// Same as passiveClose() but no call to session.delStream.
// This is called in session.Close() to avoid mutex deadlock
// We don't notify the remote because session.Close() is always
// called when the session is passively closed
func (s *Stream) closeNoDelMap() {
log.Tracef("stream %v closed by session", s.id)
s._close()
} }
// the following functions are purely for implementing net.Conn interface. // the following functions are purely for implementing net.Conn interface.

View File

@ -62,7 +62,7 @@ func (sb *switchboard) removeConn(connId uint32) {
if remaining == 0 { if remaining == 0 {
atomic.StoreUint32(&sb.broken, 1) atomic.StoreUint32(&sb.broken, 1)
sb.session.SetTerminalMsg("no underlying connection left") sb.session.SetTerminalMsg("no underlying connection left")
sb.session.Close() sb.session.passiveClose()
} }
} }
@ -149,12 +149,12 @@ func (sb *switchboard) closeAll() {
if atomic.SwapUint32(&sb.broken, 1) == 1 { if atomic.SwapUint32(&sb.broken, 1) == 1 {
return return
} }
sb.connsM.RLock() sb.connsM.Lock()
for key, conn := range sb.conns { for key, conn := range sb.conns {
conn.Close() conn.Close()
delete(sb.conns, key) delete(sb.conns, key)
} }
sb.connsM.RUnlock() sb.connsM.Unlock()
} }
// deplex function costantly reads from a TCP connection // deplex function costantly reads from a TCP connection