From 3b656c9360c3c0d9858bd87e7a7daebd1a070939 Mon Sep 17 00:00:00 2001 From: Qian Wang Date: Fri, 23 Nov 2018 23:57:35 +0000 Subject: [PATCH] Use sync.Once to close die ch --- cmd/ck-server/ck-server.go | 5 ++-- internal/multiplex/session.go | 12 ++------- internal/multiplex/stream.go | 51 +++++++++++++---------------------- 3 files changed, 23 insertions(+), 45 deletions(-) diff --git a/cmd/ck-server/ck-server.go b/cmd/ck-server/ck-server.go index a5fde7d..5b25365 100644 --- a/cmd/ck-server/ck-server.go +++ b/cmd/ck-server/ck-server.go @@ -152,8 +152,9 @@ func dispatchConnection(conn net.Conn, sta *server.State) { for { newStream, err := sesh.AcceptStream() if err != nil { - log.Printf("Failed to get new stream: %v", err) + log.Printf("Failed to get new stream: %v\n", err) if err == mux.ErrBrokenSession { + log.Printf("Session closed: %x:%v\n", UID, sessionID) user.DelSession(sessionID) return } else { @@ -162,7 +163,7 @@ func dispatchConnection(conn net.Conn, sta *server.State) { } ssConn, err := net.Dial("tcp", sta.SS_LOCAL_HOST+":"+sta.SS_LOCAL_PORT) if err != nil { - log.Printf("Failed to connect to ssserver: %v", err) + log.Printf("Failed to connect to ssserver: %v\n", err) continue } go pipe(ssConn, newStream) diff --git a/internal/multiplex/session.go b/internal/multiplex/session.go index 6f2cfb2..8da565b 100644 --- a/internal/multiplex/session.go +++ b/internal/multiplex/session.go @@ -38,10 +38,8 @@ type Session struct { // For accepting new streams acceptCh chan *Stream - // TODO: use sync.Once for this - closingM sync.Mutex die chan struct{} - closing bool + overdose sync.Once // fentanyl? beware of respiratory depression } // 1 conn is needed to make a session @@ -123,13 +121,7 @@ func (sesh *Session) addStream(id uint32) *Stream { func (sesh *Session) Close() error { // Because closing a closed channel causes panic - sesh.closingM.Lock() - if sesh.closing { - sesh.closingM.Unlock() - return errRepeatSessionClosing - } - sesh.closing = true - close(sesh.die) + sesh.overdose.Do(func() { close(sesh.die) }) sesh.streamsM.Lock() for id, stream := range sesh.streams { // If we call stream.Close() here, streamsM will result in a deadlock diff --git a/internal/multiplex/stream.go b/internal/multiplex/stream.go index 3d89e25..d42bd9e 100644 --- a/internal/multiplex/stream.go +++ b/internal/multiplex/stream.go @@ -10,7 +10,6 @@ import ( ) var errBrokenStream = errors.New("broken stream") -var errRepeatStreamClosing = errors.New("trying to close a closed stream") type Stream struct { id uint32 @@ -31,11 +30,11 @@ type Stream struct { // atomic nextSendSeq uint32 - closingM sync.RWMutex + writingM sync.RWMutex + // close(die) is used to notify different goroutines that this stream is closing - die chan struct{} - // to prevent closing a closed channel - closing bool + die chan struct{} + heliumMask sync.Once // my personal fav } func makeStream(id uint32, sesh *Session) *Stream { @@ -84,10 +83,10 @@ func (stream *Stream) Write(in []byte) (n int, err error) { // The use of RWMutex is so that the stream will not actively close // in the middle of the execution of Write. This may cause the closing frame // to be sent before the data frame and cause loss of packet. - stream.closingM.RLock() + stream.writingM.RLock() select { case <-stream.die: - stream.closingM.RUnlock() + stream.writingM.RUnlock() return 0, errBrokenStream default: } @@ -101,43 +100,26 @@ func (stream *Stream) Write(in []byte) (n int, err error) { tlsRecord := stream.session.obfs(f) n, err = stream.session.sb.send(tlsRecord) - stream.closingM.RUnlock() + stream.writingM.RUnlock() return } -func (stream *Stream) shutdown() error { - // Lock here because closing a closed channel causes panic - stream.closingM.Lock() - if stream.closing { - stream.closingM.Unlock() - return errRepeatStreamClosing - } - stream.closing = true - close(stream.die) - stream.closingM.Unlock() - return nil -} - // only close locally. Used when the stream close is notified by the remote func (stream *Stream) passiveClose() error { - err := stream.shutdown() - if err != nil { - return err - } + stream.heliumMask.Do(func() { close(stream.die) }) stream.session.delStream(stream.id) log.Printf("%v passive closing\n", stream.id) + // TODO: really need to return an error? return nil } // active close. Close locally and tell the remote that this stream is being closed func (stream *Stream) Close() error { - err := stream.shutdown() - if err != nil { - return err - } + stream.writingM.Lock() + stream.heliumMask.Do(func() { close(stream.die) }) // Notify remote that this stream is closed prand.Seed(int64(stream.id)) @@ -151,17 +133,20 @@ func (stream *Stream) Close() error { Payload: pad, } tlsRecord := stream.session.obfs(f) - // FIXME: despite sb.send being always called after Write(), the actual TCP sending - // may still be out of order stream.session.sb.send(tlsRecord) stream.session.delStream(stream.id) log.Printf("%v actively closed\n", stream.id) + stream.writingM.Unlock() return nil } -// Same as Close() but no call to session.delStream. +// Same as passiveClose() but no call to session.delStream. // This is called in session.Close() to avoid mutex deadlock +// We don't notify the remote because session.Close() is always +// called when the session is passively closed func (stream *Stream) closeNoDelMap() error { - return stream.shutdown() + stream.heliumMask.Do(func() { close(stream.die) }) + // TODO: really need to return an error? + return nil }