Refactor for clarity and add comments

pull/132/head
Andy Wang 4 years ago
parent d706e8f087
commit c7c3f7706d

@ -54,11 +54,11 @@ func MakeSession(connConfig RemoteConnConfig, authInfo AuthInfo, dialer common.D
} }
seshConfig := mux.SessionConfig{ seshConfig := mux.SessionConfig{
Singleplex: connConfig.Singleplex, Singleplex: connConfig.Singleplex,
Obfuscator: obfuscator, Obfuscator: obfuscator,
Valve: nil, Valve: nil,
Unordered: authInfo.Unordered, Unordered: authInfo.Unordered,
MaxFrameSize: appDataMaxLength, MsgOnWireSizeLimit: appDataMaxLength,
} }
sesh := mux.MakeSession(authInfo.SessionId, seshConfig) sesh := mux.MakeSession(authInfo.SessionId, seshConfig)

@ -34,10 +34,14 @@ type SessionConfig struct {
Singleplex bool Singleplex bool
// maximum size of Frame.Payload // maximum size of an obfuscated frame, including headers and overhead
MaxFrameSize int MsgOnWireSizeLimit int
SendBufferSize int
ReceiveBufferSize int // this sets the buffer size used to send data from a Stream (Stream.obfsBuf)
StreamSendBufferSize int
// this sets the buffer size used to receive data from an underlying Conn (allocated in
// switchboard.deplex)
ConnReceiveBufferSize int
} }
type Session struct { type Session struct {
@ -66,6 +70,7 @@ type Session struct {
terminalMsg atomic.Value terminalMsg atomic.Value
// the max size passed to Write calls before it splits it into multiple frames // the max size passed to Write calls before it splits it into multiple frames
// i.e. the max size a piece of data can fit into a Frame.Payload
maxStreamUnitWrite int maxStreamUnitWrite int
} }
@ -81,29 +86,19 @@ func MakeSession(id uint32, config SessionConfig) *Session {
if config.Valve == nil { if config.Valve == nil {
sesh.Valve = UNLIMITED_VALVE sesh.Valve = UNLIMITED_VALVE
} }
if config.SendBufferSize <= 0 { if config.StreamSendBufferSize <= 0 {
sesh.SendBufferSize = defaultSendRecvBufSize sesh.StreamSendBufferSize = defaultSendRecvBufSize
} }
if config.ReceiveBufferSize <= 0 { if config.ConnReceiveBufferSize <= 0 {
sesh.ReceiveBufferSize = defaultSendRecvBufSize sesh.ConnReceiveBufferSize = defaultSendRecvBufSize
} }
if config.MaxFrameSize <= 0 { if config.MsgOnWireSizeLimit <= 0 {
sesh.MaxFrameSize = defaultSendRecvBufSize - 1024 sesh.MsgOnWireSizeLimit = defaultSendRecvBufSize - 1024
} }
// todo: validation. this must be smaller than the buffer sizes // todo: validation. this must be smaller than StreamSendBufferSize
sesh.maxStreamUnitWrite = sesh.MaxFrameSize - HEADER_LEN - sesh.Obfuscator.maxOverhead sesh.maxStreamUnitWrite = sesh.MsgOnWireSizeLimit - HEADER_LEN - sesh.Obfuscator.maxOverhead
sbConfig := switchboardConfig{ sesh.sb = makeSwitchboard(sesh)
valve: sesh.Valve,
recvBufferSize: sesh.ReceiveBufferSize,
}
if sesh.Unordered {
log.Debug("Connection is unordered")
sbConfig.strategy = UNIFORM_SPREAD
} else {
sbConfig.strategy = FIXED_CONN_MAPPING
}
sesh.sb = makeSwitchboard(sesh, sbConfig)
go sesh.timeoutAfter(30 * time.Second) go sesh.timeoutAfter(30 * time.Second)
return sesh return sesh
} }
@ -218,12 +213,12 @@ func (sesh *Session) recvDataFromRemote(data []byte) error {
// this is when the stream existed before but has since been closed. We do nothing // this is when the stream existed before but has since been closed. We do nothing
return nil return nil
} }
return existingStreamI.(*Stream).writeFrame(*frame) return existingStreamI.(*Stream).recvFrame(*frame)
} else { } else {
// new stream // new stream
sesh.streamCountIncr() sesh.streamCountIncr()
sesh.acceptCh <- newStream sesh.acceptCh <- newStream
return newStream.writeFrame(*frame) return newStream.recvFrame(*frame)
} }
} }

@ -27,9 +27,12 @@ type Stream struct {
// atomic // atomic
closed uint32 closed uint32
// only alloc when writing to the stream // lazy allocation for obfsBuf. This is desirable because obfsBuf is only used when data is sent from
// the stream (through Write or ReadFrom). Some streams never send data so eager allocation will waste
// memory
allocIdempot sync.Once allocIdempot sync.Once
obfsBuf []byte // obfuscation happens in this buffer
obfsBuf []byte
// we assign each stream a fixed underlying TCP connection to utilise order guarantee provided by TCP itself // we assign each stream a fixed underlying TCP connection to utilise order guarantee provided by TCP itself
// so that frameSorter should have few to none ooo frames to deal with // so that frameSorter should have few to none ooo frames to deal with
@ -59,7 +62,8 @@ func makeStream(sesh *Session, id uint32) *Stream {
func (s *Stream) isClosed() bool { return atomic.LoadUint32(&s.closed) == 1 } func (s *Stream) isClosed() bool { return atomic.LoadUint32(&s.closed) == 1 }
func (s *Stream) writeFrame(frame Frame) error { // receive a readily deobfuscated Frame so its payload can later be Read
func (s *Stream) recvFrame(frame Frame) error {
toBeClosed, err := s.recvBuf.Write(frame) toBeClosed, err := s.recvBuf.Write(frame)
if toBeClosed { if toBeClosed {
err = s.passiveClose() err = s.passiveClose()
@ -125,7 +129,7 @@ func (s *Stream) Write(in []byte) (n int, err error) {
} }
if s.obfsBuf == nil { if s.obfsBuf == nil {
s.obfsBuf = make([]byte, s.session.SendBufferSize) s.obfsBuf = make([]byte, s.session.StreamSendBufferSize)
} }
for n < len(in) { for n < len(in) {
var framePayload []byte var framePayload []byte
@ -156,7 +160,7 @@ func (s *Stream) Write(in []byte) (n int, err error) {
func (s *Stream) ReadFrom(r io.Reader) (n int64, err error) { func (s *Stream) ReadFrom(r io.Reader) (n int64, err error) {
if s.obfsBuf == nil { if s.obfsBuf == nil {
s.obfsBuf = make([]byte, s.session.SendBufferSize) s.obfsBuf = make([]byte, s.session.StreamSendBufferSize)
} }
for { for {
if s.rfTimeout != 0 { if s.rfTimeout != 0 {
@ -204,16 +208,17 @@ func (s *Stream) Close() error {
return s.session.closeStream(s, true) return s.session.closeStream(s, true)
} }
// the following functions are purely for implementing net.Conn interface.
// they are not used
var errNotImplemented = errors.New("Not implemented")
func (s *Stream) LocalAddr() net.Addr { return s.session.addrs.Load().([]net.Addr)[0] } func (s *Stream) LocalAddr() net.Addr { return s.session.addrs.Load().([]net.Addr)[0] }
func (s *Stream) RemoteAddr() net.Addr { return s.session.addrs.Load().([]net.Addr)[1] } func (s *Stream) RemoteAddr() net.Addr { return s.session.addrs.Load().([]net.Addr)[1] }
// TODO: implement the following // TODO: implement the following
func (s *Stream) SetDeadline(t time.Time) error { return errNotImplemented }
func (s *Stream) SetWriteToTimeout(d time.Duration) { s.recvBuf.SetWriteToTimeout(d) } func (s *Stream) SetWriteToTimeout(d time.Duration) { s.recvBuf.SetWriteToTimeout(d) }
func (s *Stream) SetReadDeadline(t time.Time) error { s.recvBuf.SetReadDeadline(t); return nil } func (s *Stream) SetReadDeadline(t time.Time) error { s.recvBuf.SetReadDeadline(t); return nil }
func (s *Stream) SetReadFromTimeout(d time.Duration) { s.rfTimeout = d } func (s *Stream) SetReadFromTimeout(d time.Duration) { s.rfTimeout = d }
// the following functions are purely for implementing net.Conn interface.
// they are not used
var errNotImplemented = errors.New("Not implemented")
func (s *Stream) SetDeadline(t time.Time) error { return errNotImplemented }
func (s *Stream) SetWriteDeadline(t time.Time) error { return errNotImplemented } func (s *Stream) SetWriteDeadline(t time.Time) error { return errNotImplemented }

@ -14,12 +14,6 @@ const (
UNIFORM_SPREAD UNIFORM_SPREAD
) )
type switchboardConfig struct {
valve Valve
strategy switchboardStrategy
recvBufferSize int
}
// switchboard is responsible for managing TCP connections between client and server. // switchboard is responsible for managing TCP connections between client and server.
// It has several purposes: constantly receiving incoming data from all connections // It has several purposes: constantly receiving incoming data from all connections
// and pass them to Session.recvDataFromRemote(); accepting data through // and pass them to Session.recvDataFromRemote(); accepting data through
@ -29,8 +23,10 @@ type switchboardConfig struct {
type switchboard struct { type switchboard struct {
session *Session session *Session
switchboardConfig valve Valve
strategy switchboardStrategy
// map of connId to net.Conn
conns sync.Map conns sync.Map
numConns uint32 numConns uint32
nextConnId uint32 nextConnId uint32
@ -38,13 +34,19 @@ type switchboard struct {
broken uint32 broken uint32
} }
func makeSwitchboard(sesh *Session, config switchboardConfig) *switchboard { func makeSwitchboard(sesh *Session) *switchboard {
// rates are uint64 because in the usermanager we want the bandwidth to be atomically var strategy switchboardStrategy
// operated (so that the bandwidth can change on the fly). if sesh.Unordered {
log.Debug("Connection is unordered")
strategy = UNIFORM_SPREAD
} else {
strategy = FIXED_CONN_MAPPING
}
sb := &switchboard{ sb := &switchboard{
session: sesh, session: sesh,
switchboardConfig: config, strategy: strategy,
nextConnId: 1, valve: sesh.Valve,
nextConnId: 1,
} }
return sb return sb
} }
@ -156,7 +158,7 @@ func (sb *switchboard) closeAll() {
// deplex function costantly reads from a TCP connection // deplex function costantly reads from a TCP connection
func (sb *switchboard) deplex(connId uint32, conn net.Conn) { func (sb *switchboard) deplex(connId uint32, conn net.Conn) {
defer conn.Close() defer conn.Close()
buf := make([]byte, sb.recvBufferSize) buf := make([]byte, sb.session.ConnReceiveBufferSize)
for { for {
n, err := conn.Read(buf) n, err := conn.Read(buf)
sb.valve.rxWait(n) sb.valve.rxWait(n)

@ -181,10 +181,10 @@ func dispatchConnection(conn net.Conn, sta *State) {
} }
seshConfig := mux.SessionConfig{ seshConfig := mux.SessionConfig{
Obfuscator: obfuscator, Obfuscator: obfuscator,
Valve: nil, Valve: nil,
Unordered: ci.Unordered, Unordered: ci.Unordered,
MaxFrameSize: appDataMaxLength, MsgOnWireSizeLimit: appDataMaxLength,
} }
// adminUID can use the server as normal with unlimited QoS credits. The adminUID is not // adminUID can use the server as normal with unlimited QoS credits. The adminUID is not

Loading…
Cancel
Save