Use defer to unlock mutexes

pull/71/head
Qian Wang 5 years ago
parent bf8d373f79
commit 059a222394

@ -74,9 +74,9 @@ func InitState(localHost, localPort, remoteHost, remotePort string, nowFunc func
func (sta *State) UpdateIntervalKeys() { func (sta *State) UpdateIntervalKeys() {
sta.intervalDataM.Lock() sta.intervalDataM.Lock()
defer sta.intervalDataM.Unlock()
currentInterval := sta.Now().Unix() / int64(sta.TicketTimeHint) currentInterval := sta.Now().Unix() / int64(sta.TicketTimeHint)
if currentInterval == sta.intervalData.interval { if currentInterval == sta.intervalData.interval {
sta.intervalDataM.Unlock()
return return
} }
sta.intervalData.interval = currentInterval sta.intervalData.interval = currentInterval
@ -84,7 +84,6 @@ func (sta *State) UpdateIntervalKeys() {
intervalKey := ecdh.GenerateSharedSecret(ephPv, sta.staticPub) intervalKey := ecdh.GenerateSharedSecret(ephPv, sta.staticPub)
seed := int64(binary.BigEndian.Uint64(ephPv.(*[32]byte)[0:8])) seed := int64(binary.BigEndian.Uint64(ephPv.(*[32]byte)[0:8]))
sta.intervalData.ephPv, sta.intervalData.ephPub, sta.intervalData.intervalKey, sta.intervalData.seed = ephPv, ephPub, intervalKey, seed sta.intervalData.ephPv, sta.intervalData.ephPub, sta.intervalData.intervalKey, sta.intervalData.seed = ephPv, ephPub, intervalKey, seed
sta.intervalDataM.Unlock()
} }
func (sta *State) GetIntervalKeys() (crypto.PublicKey, []byte, int64) { func (sta *State) GetIntervalKeys() (crypto.PublicKey, []byte, int64) {

@ -110,22 +110,20 @@ func (sesh *Session) getStream(id uint32, closingFrame bool) *Stream {
// it would have been neater to use defer Unlock(), however it gives // it would have been neater to use defer Unlock(), however it gives
// non-negligable overhead and this function is performance critical // non-negligable overhead and this function is performance critical
sesh.streamsM.Lock() sesh.streamsM.Lock()
defer sesh.streamsM.Unlock()
stream := sesh.streams[id] stream := sesh.streams[id]
if stream != nil { if stream != nil {
sesh.streamsM.Unlock()
return stream return stream
} else { } else {
if closingFrame { if closingFrame {
// If the stream has been closed and the current frame is a closing frame, // If the stream has been closed and the current frame is a closing frame,
// we return nil // we return nil
sesh.streamsM.Unlock()
return nil return nil
} else { } else {
stream = makeStream(id, sesh) stream = makeStream(id, sesh)
sesh.streams[id] = stream sesh.streams[id] = stream
sesh.acceptCh <- stream sesh.acceptCh <- stream
//log.Printf("Adding stream %v\n", id) //log.Printf("Adding stream %v\n", id)
sesh.streamsM.Unlock()
return stream return stream
} }
} }

@ -80,8 +80,8 @@ func (s *Stream) Write(in []byte) (n int, err error) {
// in the middle of the execution of Write. This may cause the closing frame // in the middle of the execution of Write. This may cause the closing frame
// to be sent before the data frame and cause loss of packet. // to be sent before the data frame and cause loss of packet.
s.writingM.RLock() s.writingM.RLock()
defer s.writingM.RUnlock()
if s.isClosed() { if s.isClosed() {
s.writingM.RUnlock()
return 0, ErrBrokenStream return 0, ErrBrokenStream
} }
@ -94,12 +94,9 @@ func (s *Stream) Write(in []byte) (n int, err error) {
tlsRecord, err := s.session.obfs(f) tlsRecord, err := s.session.obfs(f)
if err != nil { if err != nil {
s.writingM.RUnlock()
return 0, err return 0, err
} }
n, err = s.session.sb.send(tlsRecord) n, err = s.session.sb.send(tlsRecord)
s.writingM.RUnlock()
return return
} }
@ -122,8 +119,8 @@ func (s *Stream) passiveClose() {
func (s *Stream) Close() error { func (s *Stream) Close() error {
s.writingM.Lock() s.writingM.Lock()
defer s.writingM.Unlock()
if s.isClosed() { if s.isClosed() {
s.writingM.Unlock()
return errors.New("Already Closed") return errors.New("Already Closed")
} }
@ -144,7 +141,6 @@ func (s *Stream) Close() error {
s._close() s._close()
s.session.delStream(s.id) s.session.delStream(s.id)
//log.Printf("%v actively closed\n", stream.id) //log.Printf("%v actively closed\n", stream.id)
s.writingM.Unlock()
return nil return nil
} }

@ -32,18 +32,16 @@ func (u *ActiveUser) DelSession(sessionID uint32) {
func (u *ActiveUser) GetSession(sessionID uint32, obfs mux.Obfser, deobfs mux.Deobfser, obfsedRead func(net.Conn, []byte) (int, error)) (sesh *mux.Session, existing bool, err error) { func (u *ActiveUser) GetSession(sessionID uint32, obfs mux.Obfser, deobfs mux.Deobfser, obfsedRead func(net.Conn, []byte) (int, error)) (sesh *mux.Session, existing bool, err error) {
u.sessionsM.Lock() u.sessionsM.Lock()
defer u.sessionsM.Unlock()
if sesh = u.sessions[sessionID]; sesh != nil { if sesh = u.sessions[sessionID]; sesh != nil {
u.sessionsM.Unlock()
return sesh, true, nil return sesh, true, nil
} else { } else {
err := u.panel.Manager.authoriseNewSession(u) err := u.panel.Manager.authoriseNewSession(u)
if err != nil { if err != nil {
u.sessionsM.Unlock()
return nil, false, err return nil, false, err
} }
sesh = mux.MakeSession(sessionID, u.valve, obfs, deobfs, obfsedRead) sesh = mux.MakeSession(sessionID, u.valve, obfs, deobfs, obfsedRead)
u.sessions[sessionID] = sesh u.sessions[sessionID] = sesh
u.sessionsM.Unlock()
return sesh, false, nil return sesh, false, nil
} }
} }
@ -64,7 +62,6 @@ func (u *ActiveUser) Terminate(reason string) {
func (u *ActiveUser) NumSession() int { func (u *ActiveUser) NumSession() int {
u.sessionsM.RLock() u.sessionsM.RLock()
l := len(u.sessions) defer u.sessionsM.RUnlock()
u.sessionsM.RUnlock() return len(u.sessions)
return l
} }

Loading…
Cancel
Save