Use sync.Map for lock free pickRandConn

pull/158/head
Andy Wang 3 years ago
parent 8ab0c2d96b
commit b4d65d8a0e
No known key found for this signature in database
GPG Key ID: 181B49F9F38F3374

@ -23,8 +23,6 @@ var errRepeatSessionClosing = errors.New("trying to close a closed session")
var errRepeatStreamClosing = errors.New("trying to close a closed stream")
var errNoMultiplex = errors.New("a singleplexing session can have only one stream")
type switchboardStrategy int
type SessionConfig struct {
Obfuscator

@ -10,6 +10,8 @@ import (
"time"
)
type switchboardStrategy int
const (
FIXED_CONN_MAPPING switchboardStrategy = iota
UNIFORM_SPREAD
@ -28,9 +30,9 @@ type switchboard struct {
valve Valve
strategy switchboardStrategy
connsM sync.RWMutex
conns []net.Conn
randPool sync.Pool
conns sync.Map
connsCount uint32
randPool sync.Pool
broken uint32
}
@ -57,27 +59,14 @@ func makeSwitchboard(sesh *Session) *switchboard {
var errBrokenSwitchboard = errors.New("the switchboard is broken")
func (sb *switchboard) delConn(conn net.Conn) {
sb.connsM.Lock()
defer sb.connsM.Unlock()
if len(sb.conns) <= 1 {
sb.conns = nil
return
if _, ok := sb.conns.LoadAndDelete(conn); ok {
atomic.AddUint32(&sb.connsCount, ^uint32(0))
}
var i int
var c net.Conn
for i, c = range sb.conns {
if c == conn {
break
}
}
sb.conns = append(sb.conns[:i], sb.conns[i+1:]...)
}
func (sb *switchboard) addConn(conn net.Conn) {
sb.connsM.Lock()
sb.conns = append(sb.conns, conn)
sb.connsM.Unlock()
atomic.AddUint32(&sb.connsCount, 1)
sb.conns.Store(conn, conn)
go sb.deplex(conn)
}
@ -133,18 +122,28 @@ func (sb *switchboard) pickRandConn() (net.Conn, error) {
return nil, errBrokenSwitchboard
}
randReader := sb.randPool.Get().(*rand.Rand)
sb.connsM.RLock()
defer sb.connsM.RUnlock()
connsCount := len(sb.conns)
connsCount := atomic.LoadUint32(&sb.connsCount)
if connsCount == 0 {
return nil, errBrokenSwitchboard
}
r := randReader.Intn(connsCount)
randReader := sb.randPool.Get().(*rand.Rand)
r := randReader.Intn(int(connsCount))
sb.randPool.Put(randReader)
return sb.conns[r], nil
var c int
var ret net.Conn
sb.conns.Range(func(_, conn interface{}) bool {
if r == c {
ret = conn.(net.Conn)
return false
}
c++
return true
})
return ret, nil
}
// actively triggered by session.Close()
@ -152,12 +151,12 @@ func (sb *switchboard) closeAll() {
if !atomic.CompareAndSwapUint32(&sb.broken, 0, 1) {
return
}
sb.connsM.Lock()
for _, conn := range sb.conns {
conn.Close()
}
sb.conns = nil
sb.connsM.Unlock()
sb.conns.Range(func(_, conn interface{}) bool {
conn.(net.Conn).Close()
sb.conns.Delete(conn)
atomic.AddUint32(&sb.connsCount, ^uint32(0))
return true
})
}
// deplex function costantly reads from a TCP connection

@ -5,6 +5,7 @@ import (
"github.com/stretchr/testify/assert"
"math/rand"
"sync"
"sync/atomic"
"testing"
"time"
)
@ -173,17 +174,13 @@ func TestSwitchboard_ConnsCount(t *testing.T) {
}
wg.Wait()
sesh.sb.connsM.RLock()
if len(sesh.sb.conns) != 1000 {
if atomic.LoadUint32(&sesh.sb.connsCount) != 1000 {
t.Error("connsCount incorrect")
}
sesh.sb.connsM.RUnlock()
sesh.sb.closeAll()
assert.Eventuallyf(t, func() bool {
sesh.sb.connsM.RLock()
defer sesh.sb.connsM.RUnlock()
return len(sesh.sb.conns) == 0
}, time.Second, 10*time.Millisecond, "connsCount incorrect")
return atomic.LoadUint32(&sesh.sb.connsCount) == 0
}, time.Second, 10*time.Millisecond, "connsCount incorrect: %v", atomic.LoadUint32(&sesh.sb.connsCount))
}

Loading…
Cancel
Save