diff --git a/cmd/ck-client/ck-client.go b/cmd/ck-client/ck-client.go index 41c5a98..e495c8c 100644 --- a/cmd/ck-client/ck-client.go +++ b/cmd/ck-client/ck-client.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "log" + "math/rand" "net" "os" "time" @@ -60,7 +61,7 @@ func makeRemoteConn(sta *client.State) (net.Conn, error) { // Three discarded messages: ServerHello, ChangeCipherSpec and Finished discardBuf := make([]byte, 1024) for c := 0; c < 3; c++ { - _, err = util.ReadTillDrain(remoteConn, discardBuf) + _, err = util.ReadTLS(remoteConn, discardBuf) if err != nil { log.Printf("Reading discarded message %v: %v\n", c, err) return nil, err @@ -122,9 +123,13 @@ func main() { log.Printf("Starting standalone mode. Listening for ss on %v:%v\n", localHost, localPort) } - opaque := time.Now().UnixNano() + // sessionID is usergenerated. There shouldn't be a security concern because the scope of + // sessionID is limited to its UID. + rand.Seed(time.Now().UnixNano()) + sessionID := rand.Uint32() + // opaque is used to generate the padding of session ticket - sta := client.InitState(localHost, localPort, remoteHost, remotePort, time.Now, opaque) + sta := client.InitState(localHost, localPort, remoteHost, remotePort, time.Now, sessionID) err := sta.ParseConfig(pluginOpts) if err != nil { log.Fatal(err) @@ -140,19 +145,19 @@ func main() { log.Fatal("TicketTimeHint cannot be empty or 0") } - obfs := util.MakeObfs(sta.SID) - deobfs := util.MakeDeobfs(sta.SID) - sesh := mux.MakeSession(0, 1e9, 1e9, obfs, deobfs, util.ReadTillDrain) + valve := mux.MakeValve(1e9, 1e9, 1e9, 1e9) + obfs := util.MakeObfs(sta.UID) + deobfs := util.MakeDeobfs(sta.UID) + sesh := mux.MakeSession(0, valve, obfs, deobfs, util.ReadTLS) + // TODO: use sync group for i := 0; i < sta.NumConn; i++ { - go func() { - conn, err := makeRemoteConn(sta) - if err != nil { - log.Printf("Failed to establish new connections to remote: %v\n", err) - return - } - sesh.AddConnection(conn) - }() + conn, err := makeRemoteConn(sta) + if err != nil { + log.Printf("Failed to establish new connections to remote: %v\n", err) + return + } + sesh.AddConnection(conn) } listener, err := net.Listen("tcp", sta.SS_LOCAL_HOST+":"+sta.SS_LOCAL_PORT) @@ -175,8 +180,12 @@ func main() { stream, err := sesh.OpenStream() if err != nil { ssConn.Close() + return + } + _, err = stream.Write(data[:i]) + if err != nil { + log.Println(err) } - stream.Write(data[:i]) go pipe(ssConn, stream) pipe(stream, ssConn) }() diff --git a/cmd/ck-server/ck-server.go b/cmd/ck-server/ck-server.go index 6ee9be5..a4decb1 100644 --- a/cmd/ck-server/ck-server.go +++ b/cmd/ck-server/ck-server.go @@ -1,15 +1,16 @@ package main import ( + "encoding/hex" "flag" "fmt" "io" "log" "net" - //"net/http" - //_ "net/http/pprof" + "net/http" + _ "net/http/pprof" "os" - //"runtime" + "runtime" "strings" "time" @@ -70,14 +71,21 @@ func dispatchConnection(conn net.Conn, sta *server.State) { return } - isSS, SID := server.TouchStone(ch, sta) + isSS, UID, sessionID := server.TouchStone(ch, sta) if !isSS { log.Printf("+1 non SS TLS traffic from %v\n", conn.RemoteAddr()) goWeb(data) return } - // TODO: verify SID + var arrUID [32]byte + copy(arrUID[:], UID) + user, err := sta.Userpanel.GetAndActivateUser(arrUID) + log.Printf("UID: %x\n", UID) + if err != nil { + log.Printf("+1 unauthorised user from %v, uid: %x\n", conn.RemoteAddr(), UID) + goWeb(data) + } reply := server.ComposeReply(ch) _, err = conn.Write(reply) @@ -90,7 +98,7 @@ func dispatchConnection(conn net.Conn, sta *server.State) { // Two discarded messages: ChangeCipherSpec and Finished discardBuf := make([]byte, 1024) for c := 0; c < 2; c++ { - _, err = util.ReadTillDrain(conn, discardBuf) + _, err = util.ReadTLS(conn, discardBuf) if err != nil { log.Printf("Reading discarded message %v: %v\n", c, err) go conn.Close() @@ -98,45 +106,36 @@ func dispatchConnection(conn net.Conn, sta *server.State) { } } - go func() { - var arrSID [32]byte - copy(arrSID[:], SID) - var sesh *mux.Session - if sesh = sta.GetSession(arrSID); sesh == nil { - sesh = mux.MakeSession(0, 1e9, 1e9, util.MakeObfs(SID), util.MakeDeobfs(SID), util.ReadTillDrain) - sta.PutSession(arrSID, sesh) - } - sesh.AddConnection(conn) - go func() { - for { - newStream, err := sesh.AcceptStream() - if err != nil { - log.Printf("Failed to get new stream: %v", err) - if err == mux.ErrBrokenSession { - sta.DelSession(arrSID) - return - } else { - continue - } - } - ssConn, err := net.Dial("tcp", sta.SS_LOCAL_HOST+":"+sta.SS_LOCAL_PORT) - if err != nil { - log.Printf("Failed to connect to ssserver: %v", err) - continue - } - go pipe(ssConn, newStream) - go pipe(newStream, ssConn) + // FIXME: the following code should not be executed for every single remote connection + sesh := user.GetOrCreateSession(sessionID, util.MakeObfs(UID), util.MakeDeobfs(UID), util.ReadTLS) + sesh.AddConnection(conn) + for { + newStream, err := sesh.AcceptStream() + if err != nil { + log.Printf("Failed to get new stream: %v", err) + if err == mux.ErrBrokenSession { + user.DelSession(sessionID) + return + } else { + continue } - }() - }() + } + ssConn, err := net.Dial("tcp", sta.SS_LOCAL_HOST+":"+sta.SS_LOCAL_PORT) + if err != nil { + log.Printf("Failed to connect to ssserver: %v", err) + continue + } + go pipe(ssConn, newStream) + go pipe(newStream, ssConn) + } } func main() { - //runtime.SetBlockProfileRate(5) - //go func() { - // log.Println(http.ListenAndServe("0.0.0.0:8001", nil)) - //}() + runtime.SetBlockProfileRate(5) + go func() { + log.Println(http.ListenAndServe("0.0.0.0:8001", nil)) + }() // Should be 127.0.0.1 to listen to ss-server on this machine var localHost string // server_port in ss config, same as remotePort in plugin mode @@ -181,7 +180,13 @@ func main() { localPort = strings.Split(*localAddr, ":")[1] log.Printf("Starting standalone mode, listening on %v:%v to ss at %v:%v\n", remoteHost, remotePort, localHost, localPort) } - sta := server.InitState(localHost, localPort, remoteHost, remotePort, time.Now) + sta, _ := server.InitState(localHost, localPort, remoteHost, remotePort, time.Now, "userinfo.db") + + //debug + var arrUID [32]byte + UID, _ := hex.DecodeString("50d858e0985ecc7f60418aaf0cc5ab587f42c2570a884095a9e8ccacd0f6545c") + copy(arrUID[:], UID) + sta.Userpanel.AddNewUser(arrUID, 10, 1e12, 1e12, 1e12, 1e12) err := sta.ParseConfig(pluginOpts) if err != nil { log.Fatalf("Configuration file error: %v", err) diff --git a/internal/client/auth.go b/internal/client/auth.go index 596600d..2e8ec66 100644 --- a/internal/client/auth.go +++ b/internal/client/auth.go @@ -21,7 +21,7 @@ func MakeRandomField(sta *State) []byte { rdm := make([]byte, 16) io.ReadFull(rand.Reader, rdm) preHash := make([]byte, 56) - copy(preHash[0:32], sta.SID) + copy(preHash[0:32], sta.UID) copy(preHash[32:40], t) copy(preHash[40:56], rdm) h := sha256.New() @@ -33,9 +33,9 @@ func MakeRandomField(sta *State) []byte { } func MakeSessionTicket(sta *State) []byte { - // sessionTicket: [marshalled ephemeral pub key 32 bytes][encrypted SID 32 bytes][padding 128 bytes] + // sessionTicket: [marshalled ephemeral pub key 32 bytes][encrypted UID+sessionID 36 bytes][padding 124 bytes] // The first 16 bytes of the marshalled ephemeral public key is used as the IV - // for encrypting the SID + // for encrypting the UID tthInterval := sta.Now().Unix() / int64(sta.TicketTimeHint) ec := ecdh.NewCurve25519ECDH() ephKP := sta.getKeyPair(tthInterval) @@ -50,8 +50,21 @@ func MakeSessionTicket(sta *State) []byte { ticket := make([]byte, 192) copy(ticket[0:32], ec.Marshal(ephKP.PublicKey)) key, _ := ec.GenerateSharedSecret(ephKP.PrivateKey, sta.staticPub) - cipherSID := util.AESEncrypt(ticket[0:16], key, sta.SID) - copy(ticket[32:64], cipherSID) - copy(ticket[64:192], util.PsudoRandBytes(128, tthInterval+sta.opaque)) + plainUIDsID := make([]byte, 36) + copy(plainUIDsID, sta.UID) + binary.BigEndian.PutUint32(plainUIDsID[32:36], sta.sessionID) + cipherUIDsID := util.AESEncrypt(ticket[0:16], key, plainUIDsID) + copy(ticket[32:68], cipherUIDsID) + // The purpose of adding sessionID is that, the generated padding of sessionTicket needs to be unpredictable. + // As shown in auth.go, the padding is generated by a psudo random generator. The seed + // needs to be the same for each TicketTimeHint interval. However the value of epoch/TicketTimeHint + // is public knowledge, so is the psudo random algorithm used by math/rand. Therefore not only + // can the firewall tell that the padding is generated in this specific way, this padding is identical + // for all ckclients in the same TicketTimeHint interval. This will expose us. + // + // With the sessionID value generated at startup of ckclient and used as a part of the seed, the + // sessionTicket is still identical for each TicketTimeHint interval, but others won't be able to know + // how it was generated. It will also be different for each client. + copy(ticket[68:192], util.PsudoRandBytes(124, tthInterval+int64(sta.sessionID))) return ticket } diff --git a/internal/client/state.go b/internal/client/state.go index 8be4310..affaa93 100644 --- a/internal/client/state.go +++ b/internal/client/state.go @@ -29,8 +29,8 @@ type State struct { SS_REMOTE_PORT string Now func() time.Time - opaque int64 - SID []byte + sessionID uint32 + UID []byte staticPub crypto.PublicKey keyPairsM sync.RWMutex keyPairs map[int64]*keyPair @@ -41,14 +41,14 @@ type State struct { NumConn int } -func InitState(localHost, localPort, remoteHost, remotePort string, nowFunc func() time.Time, opaque int64) *State { +func InitState(localHost, localPort, remoteHost, remotePort string, nowFunc func() time.Time, sessionID uint32) *State { ret := &State{ SS_LOCAL_HOST: localHost, SS_LOCAL_PORT: localPort, SS_REMOTE_HOST: remoteHost, SS_REMOTE_PORT: remotePort, Now: nowFunc, - opaque: opaque, + sessionID: sessionID, } ret.keyPairs = make(map[int64]*keyPair) return ret @@ -56,6 +56,7 @@ func InitState(localHost, localPort, remoteHost, remotePort string, nowFunc func // semi-colon separated value. This is for Android plugin options func ssvToJson(ssv string) (ret []byte) { + // TODO: base64 encoded data has =. How to escape? unescape := func(s string) string { r := strings.Replace(s, "\\\\", "\\", -1) r = strings.Replace(r, "\\=", "=", -1) @@ -104,16 +105,16 @@ func (sta *State) ParseConfig(conf string) (err error) { sta.TicketTimeHint = preParse.TicketTimeHint sta.MaskBrowser = preParse.MaskBrowser sta.NumConn = preParse.NumConn - sid, pub, err := parseKey(preParse.Key) + uid, pub, err := parseKey(preParse.Key) if err != nil { return errors.New("Failed to parse Key: " + err.Error()) } - sta.SID = sid + sta.UID = uid sta.staticPub = pub return nil } -// Structure: [SID 32 bytes][marshalled public key 32 bytes] +// Structure: [UID 32 bytes][marshalled public key 32 bytes] func parseKey(b64 string) ([]byte, crypto.PublicKey, error) { b, err := base64.StdEncoding.DecodeString(b64) if err != nil { diff --git a/internal/multiplex/frameSorter.go b/internal/multiplex/frameSorter.go index 395293a..2cf0022 100644 --- a/internal/multiplex/frameSorter.go +++ b/internal/multiplex/frameSorter.go @@ -15,8 +15,7 @@ import ( // make sure packets arrive in order. // // Cloak packets will have a 32-bit sequence number on them, so we know in which order -// they should be sent to shadowsocks. In the case that the packets arrive out-of-order, -// the code in this file provides buffering and sorting. +// they should be sent to shadowsocks. The code in this file provides buffering and sorting. // // Similar to TCP, the next seq number after 2^32-1 is 0. This is called wrap around. // @@ -54,6 +53,12 @@ func (sh *sorterHeap) Pop() interface{} { return x } +func (s *Stream) writeNewFrame(f *Frame) { + s.newFrameCh <- f +} + +// recvNewFrame is a forever running loop which receives frames unordered, +// cache and order them and send them into sortedBufCh func (s *Stream) recvNewFrame() { for { var f *Frame @@ -69,7 +74,7 @@ func (s *Stream) recvNewFrame() { if len(s.sh) == 0 && f.Seq == s.nextRecvSeq { if f.Closing == 1 { - s.passiveClose() + s.sortedBufCh <- []byte{} return } @@ -115,7 +120,7 @@ func (s *Stream) recvNewFrame() { frame := heap.Pop(&s.sh).(*frameNode).frame if frame.Closing == 1 { - s.passiveClose() + s.sortedBufCh <- []byte{} return } payload := frame.Payload diff --git a/internal/multiplex/qos.go b/internal/multiplex/qos.go new file mode 100644 index 0000000..581a6ec --- /dev/null +++ b/internal/multiplex/qos.go @@ -0,0 +1,58 @@ +package multiplex + +import ( + "sync/atomic" + + "github.com/juju/ratelimit" +) + +// Valve needs to be universal, across all sessions that belong to a user +// gabe please don't sue +type Valve struct { + // traffic directions from the server's perspective are refered + // exclusively as rx and tx. + // rx is from client to server, tx is from server to client + // DO NOT use terms up or down as this is used in usermanager + // for bandwidth limiting + rxtb atomic.Value // *ratelimit.Bucket + txtb atomic.Value // *ratelimit.Bucket + + rxCredit int64 + txCredit int64 +} + +func MakeValve(rxRate, txRate, rxCredit, txCredit int64) *Valve { + v := &Valve{ + rxCredit: rxCredit, + txCredit: txCredit, + } + v.SetRxRate(rxRate) + v.SetTxRate(txRate) + return v +} + +func (v *Valve) SetRxRate(rate int64) { + v.rxtb.Store(ratelimit.NewBucketWithRate(float64(rate), rate)) +} + +func (v *Valve) SetTxRate(rate int64) { + v.txtb.Store(ratelimit.NewBucketWithRate(float64(rate), rate)) +} + +func (v *Valve) rxWait(n int) { + v.rxtb.Load().(*ratelimit.Bucket).Wait(int64(n)) +} + +func (v *Valve) txWait(n int) { + v.txtb.Load().(*ratelimit.Bucket).Wait(int64(n)) +} + +// n can be negative +func (v *Valve) AddRxCredit(n int64) int64 { + return atomic.AddInt64(&v.rxCredit, n) +} + +// n can be negative +func (v *Valve) AddTxCredit(n int64) int64 { + return atomic.AddInt64(&v.txCredit, n) +} diff --git a/internal/multiplex/session.go b/internal/multiplex/session.go index d70870b..1301008 100644 --- a/internal/multiplex/session.go +++ b/internal/multiplex/session.go @@ -2,6 +2,7 @@ package multiplex import ( "errors" + "log" "net" "sync" "sync/atomic" @@ -16,14 +17,14 @@ var ErrBrokenSession = errors.New("broken session") var errRepeatSessionClosing = errors.New("trying to close a closed session") type Session struct { - id int + id uint32 // This field isn't acutally used // Used in Stream.Write. Add multiplexing headers, encrypt and add TLS header obfs func(*Frame) []byte // Remove TLS header, decrypt and unmarshall multiplexing headers deobfs func([]byte) *Frame // This is supposed to read one TLS message, the same as GoQuiet's ReadTillDrain - obfsedReader func(net.Conn, []byte) (int, error) + obfsedRead func(net.Conn, []byte) (int, error) // atomic nextStreamID uint32 @@ -37,24 +38,25 @@ type Session struct { // For accepting new streams acceptCh chan *Stream + // TODO: use sync.Once for this closingM sync.Mutex die chan struct{} closing bool } // 1 conn is needed to make a session -func MakeSession(id int, uprate, downrate float64, obfs func(*Frame) []byte, deobfs func([]byte) *Frame, obfsedReader func(net.Conn, []byte) (int, error)) *Session { +func MakeSession(id uint32, valve *Valve, obfs func(*Frame) []byte, deobfs func([]byte) *Frame, obfsedRead func(net.Conn, []byte) (int, error)) *Session { sesh := &Session{ id: id, obfs: obfs, deobfs: deobfs, - obfsedReader: obfsedReader, + obfsedRead: obfsedRead, nextStreamID: 1, streams: make(map[uint32]*Stream), acceptCh: make(chan *Stream, acceptBacklog), die: make(chan struct{}), } - sesh.sb = makeSwitchboard(sesh, uprate, downrate) + sesh.sb = makeSwitchboard(sesh, valve) return sesh } @@ -63,12 +65,18 @@ func (sesh *Session) AddConnection(conn net.Conn) { } func (sesh *Session) OpenStream() (*Stream, error) { - id := atomic.AddUint32(&sesh.nextStreamID, 1) - id -= 1 // Because atomic.AddUint32 returns the value after incrementation + select { + case <-sesh.die: + return nil, ErrBrokenSession + default: + } + id := atomic.AddUint32(&sesh.nextStreamID, 1) - 1 + // Because atomic.AddUint32 returns the value after incrementation stream := makeStream(id, sesh) sesh.streamsM.Lock() sesh.streams[id] = stream sesh.streamsM.Unlock() + log.Printf("Opening stream %v\n", id) return stream, nil } @@ -108,6 +116,7 @@ func (sesh *Session) addStream(id uint32) *Stream { sesh.streams[id] = stream sesh.streamsM.Unlock() sesh.acceptCh <- stream + log.Printf("Adding stream %v\n", id) return stream } diff --git a/internal/multiplex/stream.go b/internal/multiplex/stream.go index 777e4ea..0625b92 100644 --- a/internal/multiplex/stream.go +++ b/internal/multiplex/stream.go @@ -31,7 +31,7 @@ type Stream struct { // atomic nextSendSeq uint32 - closingM sync.Mutex + closingM sync.RWMutex // close(die) is used to notify different goroutines that this stream is closing die chan struct{} // to prevent closing a closed channel @@ -45,7 +45,7 @@ func makeStream(id uint32, sesh *Session) *Stream { die: make(chan struct{}), sh: []*frameNode{}, newFrameCh: make(chan *Frame, 1024), - sortedBufCh: make(chan []byte, 4096), + sortedBufCh: make(chan []byte, 1024), } go stream.recvNewFrame() return stream @@ -64,6 +64,10 @@ func (stream *Stream) Read(buf []byte) (n int, err error) { case <-stream.die: return 0, errBrokenStream case data := <-stream.sortedBufCh: + if len(data) == 0 { + stream.passiveClose() + return 0, errBrokenStream + } if len(buf) < len(data) { log.Println(len(data)) return 0, errors.New("buf too small") @@ -75,6 +79,13 @@ func (stream *Stream) Read(buf []byte) (n int, err error) { } func (stream *Stream) Write(in []byte) (n int, err error) { + // RWMutex used here isn't really for RW. + // we use it to exploit the fact that RLock doesn't create contention. + // The use of RWMutex is so that the stream will not actively close + // in the middle of the execution of Write. This may cause the closing frame + // to be sent before the data frame and cause loss of packet. + stream.closingM.RLock() + defer stream.closingM.RUnlock() select { case <-stream.die: return 0, errBrokenStream @@ -83,13 +94,11 @@ func (stream *Stream) Write(in []byte) (n int, err error) { f := &Frame{ StreamID: stream.id, - Seq: atomic.LoadUint32(&stream.nextSendSeq), + Seq: atomic.AddUint32(&stream.nextSendSeq, 1) - 1, Closing: 0, Payload: in, } - atomic.AddUint32(&stream.nextSendSeq, 1) - tlsRecord := stream.session.obfs(f) n, err = stream.session.sb.send(tlsRecord) @@ -97,9 +106,7 @@ func (stream *Stream) Write(in []byte) (n int, err error) { } -// only close locally. Used when the stream close is notified by the remote -func (stream *Stream) passiveClose() error { - +func (stream *Stream) shutdown() error { // Lock here because closing a closed channel causes panic stream.closingM.Lock() defer stream.closingM.Unlock() @@ -108,29 +115,36 @@ func (stream *Stream) passiveClose() error { } stream.closing = true close(stream.die) + return nil +} + +// only close locally. Used when the stream close is notified by the remote +func (stream *Stream) passiveClose() error { + err := stream.shutdown() + if err != nil { + return err + } stream.session.delStream(stream.id) + log.Printf("%v passive closing\n", stream.id) return nil } // active close. Close locally and tell the remote that this stream is being closed func (stream *Stream) Close() error { - // Lock here because closing a closed channel causes panic - stream.closingM.Lock() - defer stream.closingM.Unlock() - if stream.closing { - return errRepeatStreamClosing + err := stream.shutdown() + if err != nil { + return err } - stream.closing = true - close(stream.die) + // Notify remote that this stream is closed prand.Seed(int64(stream.id)) padLen := int(math.Floor(prand.Float64()*200 + 300)) pad := make([]byte, padLen) prand.Read(pad) f := &Frame{ StreamID: stream.id, - Seq: atomic.LoadUint32(&stream.nextSendSeq), + Seq: atomic.AddUint32(&stream.nextSendSeq, 1) - 1, Closing: 1, Payload: pad, } @@ -138,20 +152,12 @@ func (stream *Stream) Close() error { stream.session.sb.send(tlsRecord) stream.session.delStream(stream.id) + log.Printf("%v actively closed\n", stream.id) return nil } // Same as Close() but no call to session.delStream. // This is called in session.Close() to avoid mutex deadlock func (stream *Stream) closeNoDelMap() error { - - // Lock here because closing a closed channel causes panic - stream.closingM.Lock() - defer stream.closingM.Unlock() - if stream.closing { - return errRepeatStreamClosing - } - stream.closing = true - close(stream.die) - return nil + return stream.shutdown() } diff --git a/internal/multiplex/switchboard.go b/internal/multiplex/switchboard.go index 35cd5c9..9723281 100644 --- a/internal/multiplex/switchboard.go +++ b/internal/multiplex/switchboard.go @@ -6,20 +6,34 @@ import ( "net" "sync" "sync/atomic" - - "github.com/juju/ratelimit" ) // switchboard is responsible for keeping the reference of TLS connections between client and server type switchboard struct { session *Session - wtb *ratelimit.Bucket - rtb *ratelimit.Bucket + *Valve - optimum atomic.Value + // optimum is the connEnclave with the smallest sendQueue + optimum atomic.Value // *connEnclave cesM sync.RWMutex ces []*connEnclave + + //debug + hM sync.Mutex + used map[uint32]bool +} + +func (sb *switchboard) getOptimum() *connEnclave { + if i := sb.optimum.Load(); i == nil { + return nil + } else { + return i.(*connEnclave) + } +} + +func (sb *switchboard) setOptimum(ce *connEnclave) { + sb.optimum.Store(ce) } // Some data comes from a Stream to be sent through one of the many @@ -27,45 +41,51 @@ type switchboard struct { // // In this case, we pick the remoteConn that has about the smallest sendQueue. type connEnclave struct { - sb *switchboard remoteConn net.Conn sendQueue uint32 } -// It takes at least 1 conn to start a switchboard -// TODO: does it really? -func makeSwitchboard(sesh *Session, uprate, downrate float64) *switchboard { +func makeSwitchboard(sesh *Session, valve *Valve) *switchboard { + // rates are uint64 because in the usermanager we want the bandwidth to be atomically + // operated (so that the bandwidth can change on the fly). sb := &switchboard{ session: sesh, - wtb: ratelimit.NewBucketWithRate(uprate, int64(uprate)), - rtb: ratelimit.NewBucketWithRate(downrate, int64(downrate)), + Valve: valve, ces: []*connEnclave{}, + used: make(map[uint32]bool), } return sb } var errNilOptimum error = errors.New("The optimal connection is nil") +var ErrNoRxCredit error = errors.New("No Rx credit is left") +var ErrNoTxCredit error = errors.New("No Tx credit is left") + func (sb *switchboard) send(data []byte) (int, error) { - ce := sb.optimum.Load().(*connEnclave) + ce := sb.getOptimum() if ce == nil { return 0, errNilOptimum } - sb.wtb.Wait(int64(len(data))) atomic.AddUint32(&ce.sendQueue, uint32(len(data))) go sb.updateOptimum() n, err := ce.remoteConn.Write(data) if err != nil { - return 0, err + return n, err // TODO } + if sb.AddTxCredit(-int64(n)) < 0 { + log.Println(ErrNoTxCredit) + defer sb.session.Close() + return n, ErrNoTxCredit + } atomic.AddUint32(&ce.sendQueue, ^uint32(n-1)) go sb.updateOptimum() return n, nil } func (sb *switchboard) updateOptimum() { - currentOpti := sb.optimum.Load().(*connEnclave) + currentOpti := sb.getOptimum() currentOptiQ := atomic.LoadUint32(¤tOpti.sendQueue) sb.cesM.RLock() for _, ce := range sb.ces { @@ -76,20 +96,18 @@ func (sb *switchboard) updateOptimum() { } } sb.cesM.RUnlock() - sb.optimum.Store(currentOpti) + sb.setOptimum(currentOpti) } func (sb *switchboard) addConn(conn net.Conn) { - newCe := &connEnclave{ - sb: sb, remoteConn: conn, sendQueue: 0, } sb.cesM.Lock() sb.ces = append(sb.ces, newCe) sb.cesM.Unlock() - sb.optimum.Store(newCe) + sb.setOptimum(newCe) go sb.deplex(newCe) } @@ -101,10 +119,10 @@ func (sb *switchboard) removeConn(closing *connEnclave) { break } } - sb.cesM.Unlock() if len(sb.ces) == 0 { sb.session.Close() } + sb.cesM.Unlock() } func (sb *switchboard) shutdown() { @@ -118,19 +136,40 @@ func (sb *switchboard) shutdown() { func (sb *switchboard) deplex(ce *connEnclave) { buf := make([]byte, 20480) for { - i, err := sb.session.obfsedReader(ce.remoteConn, buf) - sb.rtb.Wait(int64(i)) + n, err := sb.session.obfsedRead(ce.remoteConn, buf) + sb.rxWait(n) if err != nil { log.Println(err) go ce.remoteConn.Close() sb.removeConn(ce) return } - frame := sb.session.deobfs(buf[:i]) + if sb.AddRxCredit(-int64(n)) < 0 { + log.Println(ErrNoRxCredit) + sb.session.Close() + return + } + frame := sb.session.deobfs(buf[:n]) + + //debug + var stream *Stream if stream = sb.session.getStream(frame.StreamID); stream == nil { + if frame.Closing == 1 { + // if the frame is telling us to close a closed stream + // (this happens when ss-server and ss-local closes the stream + // simutaneously), we don't do anything + continue + } + //debug + sb.hM.Lock() + if sb.used[frame.StreamID] { + log.Printf("%v lost!\n", frame.StreamID) + } + sb.used[frame.StreamID] = true + sb.hM.Unlock() stream = sb.session.addStream(frame.StreamID) } - stream.newFrameCh <- frame + stream.writeNewFrame(frame) } } diff --git a/internal/server/auth.go b/internal/server/auth.go index 0b746ec..7d5e36d 100644 --- a/internal/server/auth.go +++ b/internal/server/auth.go @@ -11,54 +11,54 @@ import ( ecdh "github.com/cbeuw/go-ecdh" ) -// input ticket, return SID -func decryptSessionTicket(staticPv crypto.PrivateKey, ticket []byte) ([]byte, error) { +// input ticket, return UID +func decryptSessionTicket(staticPv crypto.PrivateKey, ticket []byte) ([]byte, uint32, error) { ec := ecdh.NewCurve25519ECDH() ephPub, _ := ec.Unmarshal(ticket[0:32]) key, err := ec.GenerateSharedSecret(staticPv, ephPub) if err != nil { - return nil, err + return nil, 0, err } - SID := util.AESDecrypt(ticket[0:16], key, ticket[32:64]) - return SID, nil + UIDsID := util.AESDecrypt(ticket[0:16], key, ticket[32:68]) + sessionID := binary.BigEndian.Uint32(UIDsID[32:36]) + return UIDsID[0:32], sessionID, nil } -func validateRandom(random []byte, SID []byte, time int64) bool { +func validateRandom(random []byte, UID []byte, time int64) bool { t := make([]byte, 8) binary.BigEndian.PutUint64(t, uint64(time/(12*60*60))) rdm := random[0:16] preHash := make([]byte, 56) - copy(preHash[0:32], SID) + copy(preHash[0:32], UID) copy(preHash[32:40], t) copy(preHash[40:56], rdm) h := sha256.New() h.Write(preHash) return bytes.Equal(h.Sum(nil)[0:16], random[16:32]) } -func TouchStone(ch *ClientHello, sta *State) (bool, []byte) { +func TouchStone(ch *ClientHello, sta *State) (isSS bool, UID []byte, sessionID uint32) { var random [32]byte copy(random[:], ch.random) used := sta.getUsedRandom(random) if used != 0 { log.Println("Replay! Duplicate random") - return false, nil + return false, nil, 0 } sta.putUsedRandom(random) ticket := ch.extensions[[2]byte{0x00, 0x23}] if len(ticket) < 64 { - return false, nil + return false, nil, 0 } - SID, err := decryptSessionTicket(sta.staticPv, ticket) + UID, sessionID, err := decryptSessionTicket(sta.staticPv, ticket) if err != nil { log.Printf("ts: %v\n", err) - return false, nil + return false, nil, 0 } - log.Printf("SID: %x\n", SID) - isSS := validateRandom(ch.random, SID, sta.Now().Unix()) + isSS = validateRandom(ch.random, UID, sta.Now().Unix()) if !isSS { - return false, nil + return false, nil, 0 } - return true, SID + return } diff --git a/internal/server/state.go b/internal/server/state.go index 7971b8b..8807478 100644 --- a/internal/server/state.go +++ b/internal/server/state.go @@ -9,7 +9,7 @@ import ( "sync" "time" - mux "github.com/cbeuw/Cloak/internal/multiplex" + "github.com/cbeuw/Cloak/internal/server/usermanager" ) type rawConfig struct { @@ -31,25 +31,28 @@ type State struct { Now func() time.Time staticPv crypto.PrivateKey + Userpanel *usermanager.Userpanel usedRandomM sync.RWMutex usedRandom map[[32]byte]int - sessionsM sync.RWMutex - sessions map[[32]byte]*mux.Session WebServerAddr string } -func InitState(localHost, localPort, remoteHost, remotePort string, nowFunc func() time.Time) *State { +func InitState(localHost, localPort, remoteHost, remotePort string, nowFunc func() time.Time, dbPath string) (*State, error) { + up, err := usermanager.MakeUserpanel(dbPath) + if err != nil { + return nil, err + } ret := &State{ SS_LOCAL_HOST: localHost, SS_LOCAL_PORT: localPort, SS_REMOTE_HOST: remoteHost, SS_REMOTE_PORT: remotePort, Now: nowFunc, + Userpanel: up, } ret.usedRandom = make(map[[32]byte]int) - ret.sessions = make(map[[32]byte]*mux.Session) - return ret + return ret, nil } // semi-colon separated value. @@ -115,28 +118,6 @@ func (sta *State) ParseConfig(conf string) (err error) { return nil } -func (sta *State) GetSession(SID [32]byte) *mux.Session { - sta.sessionsM.RLock() - defer sta.sessionsM.RUnlock() - if sesh, ok := sta.sessions[SID]; ok { - return sesh - } else { - return nil - } -} - -func (sta *State) PutSession(SID [32]byte, sesh *mux.Session) { - sta.sessionsM.Lock() - sta.sessions[SID] = sesh - sta.sessionsM.Unlock() -} - -func (sta *State) DelSession(SID [32]byte) { - sta.sessionsM.Lock() - delete(sta.sessions, SID) - sta.sessionsM.Unlock() -} - func (sta *State) getUsedRandom(random [32]byte) int { sta.usedRandomM.Lock() defer sta.usedRandomM.Unlock() diff --git a/internal/server/usermanager/user.go b/internal/server/usermanager/user.go new file mode 100644 index 0000000..7b0e72e --- /dev/null +++ b/internal/server/usermanager/user.go @@ -0,0 +1,86 @@ +package usermanager + +import ( + mux "github.com/cbeuw/Cloak/internal/multiplex" + "log" + "net" + "sync" + "sync/atomic" +) + +/* +type userParams struct { + sessionsCap uint32 + upRate int64 + downRate int64 + upCredit int64 + downCredit int64 +} +*/ + +type user struct { + up *Userpanel + + uid [32]byte + + sessionsCap uint32 //userParams + + valve *mux.Valve + + sessionsM sync.RWMutex + sessions map[uint32]*mux.Session +} + +func MakeUser(up *Userpanel, uid [32]byte, sessionsCap uint32, upRate, downRate, upCredit, downCredit int64) *user { + valve := mux.MakeValve(upRate, downRate, upCredit, downCredit) + u := &user{ + up: up, + uid: uid, + valve: valve, + sessionsCap: sessionsCap, + sessions: make(map[uint32]*mux.Session), + } + return u +} + +func (u *user) setSessionsCap(cap uint32) { + atomic.StoreUint32(&u.sessionsCap, cap) +} + +func (u *user) GetSession(sessionID uint32) *mux.Session { + u.sessionsM.RLock() + defer u.sessionsM.RUnlock() + if sesh, ok := u.sessions[sessionID]; ok { + return sesh + } else { + return nil + } +} + +func (u *user) PutSession(sessionID uint32, sesh *mux.Session) { + u.sessionsM.Lock() + u.sessions[sessionID] = sesh + u.sessionsM.Unlock() +} + +func (u *user) DelSession(sessionID uint32) { + u.sessionsM.Lock() + delete(u.sessions, sessionID) + if len(u.sessions) == 0 { + u.sessionsM.Unlock() + u.up.delActiveUser(u.uid) + return + } + u.sessionsM.Unlock() +} + +func (u *user) GetOrCreateSession(sessionID uint32, obfs func(*mux.Frame) []byte, deobfs func([]byte) *mux.Frame, obfsedRead func(net.Conn, []byte) (int, error)) (sesh *mux.Session) { + log.Printf("getting sessionID %v\n", sessionID) + if sesh = u.GetSession(sessionID); sesh != nil { + return + } else { + sesh = mux.MakeSession(sessionID, u.valve, obfs, deobfs, obfsedRead) + u.PutSession(sessionID, sesh) + return + } +} diff --git a/internal/server/usermanager/userpanel.go b/internal/server/usermanager/userpanel.go new file mode 100644 index 0000000..8408c37 --- /dev/null +++ b/internal/server/usermanager/userpanel.go @@ -0,0 +1,151 @@ +package usermanager + +import ( + "encoding/binary" + "errors" + "github.com/boltdb/bolt" + "sync" +) + +type Userpanel struct { + db *bolt.DB + + activeUsersM sync.RWMutex + activeUsers map[[32]byte]*user +} + +func MakeUserpanel(dbPath string) (*Userpanel, error) { + db, err := bolt.Open(dbPath, 0600, nil) + if err != nil { + return nil, err + } + up := &Userpanel{ + db: db, + activeUsers: make(map[[32]byte]*user), + } + return up, nil +} + +var ErrUserNotFound = errors.New("User does not exist in memory or db") + +// GetUser is used to retrieve a user if s/he is active, or to retrieve the user's infor +// from the db and mark it as an active user +func (up *Userpanel) GetAndActivateUser(UID [32]byte) (*user, error) { + up.activeUsersM.RLock() + if user, ok := up.activeUsers[UID]; ok { + up.activeUsersM.RUnlock() + return user, nil + } + up.activeUsersM.RUnlock() + + var sessionsCap uint32 + var upRate, downRate, upCredit, downCredit int64 + err := up.db.View(func(tx *bolt.Tx) error { + b := tx.Bucket(UID[:]) + if b == nil { + return ErrUserNotFound + } + sessionsCap = binary.BigEndian.Uint32(b.Get([]byte("sessionsCap"))) + upRate = int64(binary.BigEndian.Uint64(b.Get([]byte("upRate")))) + downRate = int64(binary.BigEndian.Uint64(b.Get([]byte("downRate")))) + upCredit = int64(binary.BigEndian.Uint64(b.Get([]byte("upCredit")))) // reee brackets + downCredit = int64(binary.BigEndian.Uint64(b.Get([]byte("downCredit")))) + return nil + }) + if err != nil { + return nil, err + } + // TODO: put all of these parameters in a struct instead + u := MakeUser(up, UID, sessionsCap, upRate, downRate, upCredit, downCredit) + up.activeUsersM.Lock() + up.activeUsers[UID] = u + up.activeUsersM.Unlock() + return u, nil +} + +func (up *Userpanel) AddNewUser(UID [32]byte, sessionsCap uint32, upRate, downRate, upCredit, downCredit int64) error { + err := up.db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket(UID[:]) + if err != nil { + return err + } + // FIXME: obnoxious code + quad := make([]byte, 4) + binary.BigEndian.PutUint32(quad, sessionsCap) + if err = b.Put([]byte("sessionsCap"), quad); err != nil { + return err + } + oct := make([]byte, 8) + binary.BigEndian.PutUint64(oct, uint64(upRate)) + if err = b.Put([]byte("upRate"), oct); err != nil { + return err + } + binary.BigEndian.PutUint64(oct, uint64(downRate)) + if err = b.Put([]byte("downRate"), oct); err != nil { + return err + } + binary.BigEndian.PutUint64(oct, uint64(upCredit)) + if err = b.Put([]byte("upCredit"), oct); err != nil { + return err + } + binary.BigEndian.PutUint64(oct, uint64(downCredit)) + if err = b.Put([]byte("downCredit"), oct); err != nil { + return err + } + return nil + }) + return err +} + +func (up *Userpanel) updateDBEntryUint32(UID [32]byte, key string, value uint32) error { + err := up.db.Update(func(tx *bolt.Tx) error { + b := tx.Bucket(UID[:]) + if b == nil { + return ErrUserNotFound + } + quad := make([]byte, 4) + binary.BigEndian.PutUint32(quad, value) + if err := b.Put([]byte(key), quad); err != nil { + return err + } + return nil + }) + return err +} + +func (up *Userpanel) updateDBEntryInt64(UID [32]byte, key string, value int64) error { + err := up.db.Update(func(tx *bolt.Tx) error { + b := tx.Bucket(UID[:]) + if b == nil { + return ErrUserNotFound + } + oct := make([]byte, 8) + binary.BigEndian.PutUint64(oct, uint64(value)) + if err := b.Put([]byte(key), oct); err != nil { + return err + } + return nil + }) + return err +} + +// This is used when all sessions of a user close +func (up *Userpanel) delActiveUser(UID [32]byte) { + up.activeUsersM.Lock() + delete(up.activeUsers, UID) + up.activeUsersM.Unlock() +} + +func (up *Userpanel) getActiveUser(UID [32]byte) *user { + up.activeUsersM.RLock() + defer up.activeUsersM.RUnlock() + return up.activeUsers[UID] +} + +func (up *Userpanel) SetSessionsCap(UID [32]byte, newSessionsCap uint32) error { + if u := up.getActiveUser(UID); u != nil { + u.setSessionsCap(newSessionsCap) + } + err := up.updateDBEntryUint32(UID, "sessionsCap", newSessionsCap) + return err +} diff --git a/internal/util/obfs.go b/internal/util/obfs.go index c0ee450..8286c69 100644 --- a/internal/util/obfs.go +++ b/internal/util/obfs.go @@ -9,12 +9,13 @@ import ( // For each frame, the three parts of the header is xored with three keys. // The keys are generated from the SID and the payload of the frame. -func genXorKeys(SID []byte, data []byte) (i uint32, ii uint32, iii uint32) { +// FIXME: this code will panic if len(data)<18. +func genXorKeys(secret []byte, data []byte) (i uint32, ii uint32, iii uint32) { h := xxhash.New32() ret := make([]uint32, 3) preHash := make([]byte, 16) for j := 0; j < 3; j++ { - copy(preHash[0:10], SID[j*10:j*10+10]) + copy(preHash[0:10], secret[j*10:j*10+10]) copy(preHash[10:16], data[j*6:j*6+6]) h.Write(preHash) ret[j] = h.Sum32() diff --git a/internal/util/util.go b/internal/util/util.go index ef38ffa..cc8e6d0 100644 --- a/internal/util/util.go +++ b/internal/util/util.go @@ -43,14 +43,14 @@ func BtoInt(b []byte) int { // PsudoRandBytes returns a byte slice filled with psudorandom bytes generated by the seed func PsudoRandBytes(length int, seed int64) []byte { - prand.Seed(seed) + r := prand.New(prand.NewSource(seed)) ret := make([]byte, length) - prand.Read(ret) + r.Read(ret) return ret } -// ReadTillDrain reads TLS data according to its record layer -func ReadTillDrain(conn net.Conn, buffer []byte) (n int, err error) { +// ReadTLS reads TLS data according to its record layer +func ReadTLS(conn net.Conn, buffer []byte) (n int, err error) { // TCP is a stream. Multiple TLS messages can arrive at the same time, // a single message can also be segmented due to MTU of the IP layer. // This function guareentees a single TLS message to be read and everything