From 69a73ecfc0c66ae0af3712e466281397045b32e2 Mon Sep 17 00:00:00 2001 From: Andy Wang Date: Fri, 17 Apr 2020 14:21:17 +0100 Subject: [PATCH] Refactor usermanager --- internal/server/activeuser_test.go | 3 +- internal/server/dispatcher.go | 3 +- internal/server/state.go | 21 +- .../{localmanager_api.yaml => api.yaml} | 0 internal/server/usermanager/api_router.go | 134 +++++++++++ internal/server/usermanager/localmanager.go | 122 ++++++---- .../server/usermanager/localmanager_api.go | 165 -------------- .../server/usermanager/localmanager_test.go | 208 ++++++++++++++++++ internal/server/usermanager/usermanager.go | 14 ++ internal/server/userpanel_test.go | 3 +- 10 files changed, 451 insertions(+), 222 deletions(-) rename internal/server/usermanager/{localmanager_api.yaml => api.yaml} (100%) create mode 100644 internal/server/usermanager/api_router.go delete mode 100644 internal/server/usermanager/localmanager_api.go create mode 100644 internal/server/usermanager/localmanager_test.go diff --git a/internal/server/activeuser_test.go b/internal/server/activeuser_test.go index b92bfeb..996cb56 100644 --- a/internal/server/activeuser_test.go +++ b/internal/server/activeuser_test.go @@ -3,6 +3,7 @@ package server import ( "crypto/rand" "encoding/base64" + "github.com/cbeuw/Cloak/internal/common" mux "github.com/cbeuw/Cloak/internal/multiplex" "github.com/cbeuw/Cloak/internal/server/usermanager" "os" @@ -23,7 +24,7 @@ func getSeshConfig(unordered bool) mux.SessionConfig { } func TestActiveUser_Bypass(t *testing.T) { - manager, err := usermanager.MakeLocalManager(MOCK_DB_NAME) + manager, err := usermanager.MakeLocalManager(MOCK_DB_NAME, common.RealWorldState) if err != nil { t.Fatal("failed to make local manager", err) } diff --git a/internal/server/dispatcher.go b/internal/server/dispatcher.go index f3179b1..05358e7 100644 --- a/internal/server/dispatcher.go +++ b/internal/server/dispatcher.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/base64" "github.com/cbeuw/Cloak/internal/common" + "github.com/cbeuw/Cloak/internal/server/usermanager" "io" "net" "net/http" @@ -105,7 +106,7 @@ func dispatchConnection(conn net.Conn, sta *State) { sesh.AddConnection(preparedConn) //TODO: Router could be nil in cnc mode log.WithField("remoteAddr", preparedConn.RemoteAddr()).Info("New admin session") - err = http.Serve(sesh, sta.LocalAPIRouter) + err = http.Serve(sesh, usermanager.APIRouterOf(sta.Panel.Manager)) if err != nil { log.Error(err) return diff --git a/internal/server/state.go b/internal/server/state.go index ba686b0..365f255 100644 --- a/internal/server/state.go +++ b/internal/server/state.go @@ -12,8 +12,6 @@ import ( "strings" "sync" "time" - - gmux "github.com/gorilla/mux" ) type RawConfig struct { @@ -50,8 +48,7 @@ type State struct { usedRandomM sync.RWMutex UsedRandom map[[32]byte]int64 - Panel *userPanel - LocalAPIRouter *gmux.Router + Panel *userPanel } func parseRedirAddr(redirAddr string) (net.Addr, string, error) { @@ -86,17 +83,6 @@ func parseRedirAddr(redirAddr string) (net.Addr, string, error) { return redirHost, port, nil } -func parseLocalPanel(databasePath string) (*userPanel, *gmux.Router, error) { - manager, err := usermanager.MakeLocalManager(databasePath) - if err != nil { - return nil, nil, err - } - panel := MakeUserPanel(manager) - router := manager.Router - return panel, router, nil - -} - func parseProxyBook(bookEntries map[string][]string) (map[string]net.Addr, error) { proxyBook := map[string]net.Addr{} for name, pair := range bookEntries { @@ -156,10 +142,11 @@ func InitState(preParse RawConfig, worldState common.WorldState) (sta *State, er err = errors.New("command & control mode not implemented") return } else { - sta.Panel, sta.LocalAPIRouter, err = parseLocalPanel(preParse.DatabasePath) + manager, err := usermanager.MakeLocalManager(preParse.DatabasePath, worldState) if err != nil { - return + return sta, err } + sta.Panel = MakeUserPanel(manager) } if preParse.StreamTimeout == 0 { diff --git a/internal/server/usermanager/localmanager_api.yaml b/internal/server/usermanager/api.yaml similarity index 100% rename from internal/server/usermanager/localmanager_api.yaml rename to internal/server/usermanager/api.yaml diff --git a/internal/server/usermanager/api_router.go b/internal/server/usermanager/api_router.go new file mode 100644 index 0000000..d27e7e9 --- /dev/null +++ b/internal/server/usermanager/api_router.go @@ -0,0 +1,134 @@ +package usermanager + +import ( + "bytes" + "encoding/base64" + "encoding/json" + "net/http" + + gmux "github.com/gorilla/mux" +) + +type APIRouter struct { + *gmux.Router + manager UserManager +} + +func APIRouterOf(manager UserManager) *APIRouter { + ret := &APIRouter{ + manager: manager, + } + ret.registerMux() + return ret +} + +func corsMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Access-Control-Allow-Origin", "*") + next.ServeHTTP(w, r) + }) +} + +func (ar *APIRouter) registerMux() { + ar.Router = gmux.NewRouter() + ar.HandleFunc("/admin/users", ar.listAllUsersHlr).Methods("GET") + ar.HandleFunc("/admin/users/{UID}", ar.getUserInfoHlr).Methods("GET") + ar.HandleFunc("/admin/users/{UID}", ar.writeUserInfoHlr).Methods("POST") + ar.HandleFunc("/admin/users/{UID}", ar.deleteUserHlr).Methods("DELETE") + ar.Methods("OPTIONS").HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Access-Control-Allow-Methods", "GET,POST,DELETE,OPTIONS") + }) + ar.Use(corsMiddleware) +} + +func (ar *APIRouter) listAllUsersHlr(w http.ResponseWriter, r *http.Request) { + infos, err := ar.manager.ListAllUsers() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + resp, err := json.Marshal(infos) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + _, _ = w.Write(resp) +} + +func (ar *APIRouter) getUserInfoHlr(w http.ResponseWriter, r *http.Request) { + b64UID := gmux.Vars(r)["UID"] + if b64UID == "" { + http.Error(w, "UID cannot be empty", http.StatusBadRequest) + } + + UID, err := base64.URLEncoding.DecodeString(b64UID) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + uinfo, err := ar.manager.GetUserInfo(UID) + if err == ErrUserNotFound { + http.Error(w, ErrUserNotFound.Error(), http.StatusNotFound) + return + } + resp, err := json.Marshal(uinfo) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + _, _ = w.Write(resp) +} + +func (ar *APIRouter) writeUserInfoHlr(w http.ResponseWriter, r *http.Request) { + b64UID := gmux.Vars(r)["UID"] + if b64UID == "" { + http.Error(w, "UID cannot be empty", http.StatusBadRequest) + return + } + UID, err := base64.URLEncoding.DecodeString(b64UID) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + jsonUinfo := r.FormValue("UserInfo") + if jsonUinfo == "" { + http.Error(w, "UserInfo cannot be empty", http.StatusBadRequest) + return + } + var uinfo UserInfo + err = json.Unmarshal([]byte(jsonUinfo), &uinfo) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if !bytes.Equal(UID, uinfo.UID) { + http.Error(w, "UID mismatch", http.StatusBadRequest) + } + + err = ar.manager.WriteUserInfo(uinfo) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + w.WriteHeader(http.StatusCreated) +} + +func (ar *APIRouter) deleteUserHlr(w http.ResponseWriter, r *http.Request) { + b64UID := gmux.Vars(r)["UID"] + if b64UID == "" { + http.Error(w, "UID cannot be empty", http.StatusBadRequest) + return + } + UID, err := base64.URLEncoding.DecodeString(b64UID) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + err = ar.manager.DeleteUser(UID) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + w.WriteHeader(http.StatusOK) +} diff --git a/internal/server/usermanager/localmanager.go b/internal/server/usermanager/localmanager.go index 7a8bdb8..b0281b1 100644 --- a/internal/server/usermanager/localmanager.go +++ b/internal/server/usermanager/localmanager.go @@ -2,68 +2,43 @@ package usermanager import ( "encoding/binary" + "github.com/cbeuw/Cloak/internal/common" log "github.com/sirupsen/logrus" - "net/http" - "time" - - gmux "github.com/gorilla/mux" bolt "go.etcd.io/bbolt" ) var Uint32 = binary.BigEndian.Uint32 var Uint64 = binary.BigEndian.Uint64 -var PutUint32 = binary.BigEndian.PutUint32 -var PutUint64 = binary.BigEndian.PutUint64 func i64ToB(value int64) []byte { oct := make([]byte, 8) - PutUint64(oct, uint64(value)) + binary.BigEndian.PutUint64(oct, uint64(value)) return oct } func i32ToB(value int32) []byte { nib := make([]byte, 4) - PutUint32(nib, uint32(value)) + binary.BigEndian.PutUint32(nib, uint32(value)) return nib } -// localManager is responsible for routing API calls to appropriate handlers and manage the local user database accordingly +// localManager is responsible for managing the local user database type localManager struct { - db *bolt.DB - Router *gmux.Router + db *bolt.DB + world common.WorldState } -func MakeLocalManager(dbPath string) (*localManager, error) { +func MakeLocalManager(dbPath string, worldState common.WorldState) (*localManager, error) { db, err := bolt.Open(dbPath, 0600, nil) if err != nil { return nil, err } ret := &localManager{ - db: db, + db: db, + world: worldState, } - ret.Router = ret.registerMux() return ret, nil } -func corsMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Access-Control-Allow-Origin", "*") - next.ServeHTTP(w, r) - }) -} - -func (manager *localManager) registerMux() *gmux.Router { - r := gmux.NewRouter() - r.HandleFunc("/admin/users", manager.listAllUsersHlr).Methods("GET") - r.HandleFunc("/admin/users/{UID}", manager.getUserInfoHlr).Methods("GET") - r.HandleFunc("/admin/users/{UID}", manager.writeUserInfoHlr).Methods("POST") - r.HandleFunc("/admin/users/{UID}", manager.deleteUserHlr).Methods("DELETE") - r.Methods("OPTIONS").HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Access-Control-Allow-Methods", "GET,POST,DELETE,OPTIONS") - }) - r.Use(corsMiddleware) - return r -} - // Authenticate user returns err==nil along with the users' up and down bandwidths if the UID is allowed to connect // More specifically it checks that the user exists, that it has positive credit and that it hasn't expired func (manager *localManager) AuthenticateUser(UID []byte) (int64, int64, error) { @@ -89,7 +64,7 @@ func (manager *localManager) AuthenticateUser(UID []byte) (int64, int64, error) if downCredit <= 0 { return 0, 0, ErrNoDownCredit } - if expiryTime < time.Now().Unix() { + if expiryTime < manager.world.Now().Unix() { return 0, 0, ErrUserExpired } @@ -123,7 +98,7 @@ func (manager *localManager) AuthoriseNewSession(UID []byte, ainfo Authorisation if downCredit <= 0 { return ErrNoDownCredit } - if expiryTime < time.Now().Unix() { + if expiryTime < manager.world.Now().Unix() { return ErrUserExpired } @@ -190,7 +165,7 @@ func (manager *localManager) UploadStatus(uploads []StatusUpdate) ([]StatusRespo } expiry := int64(Uint64(bucket.Get([]byte("ExpiryTime")))) - if time.Now().Unix() > expiry { + if manager.world.Now().Unix() > expiry { resp = StatusResponse{ status.UID, TERMINATE, @@ -205,6 +180,79 @@ func (manager *localManager) UploadStatus(uploads []StatusUpdate) ([]StatusRespo return responses, err } +func (manager *localManager) ListAllUsers() (infos []UserInfo, err error) { + err = manager.db.View(func(tx *bolt.Tx) error { + err = tx.ForEach(func(UID []byte, bucket *bolt.Bucket) error { + var uinfo UserInfo + uinfo.UID = UID + uinfo.SessionsCap = int(Uint32(bucket.Get([]byte("SessionsCap")))) + uinfo.UpRate = int64(Uint64(bucket.Get([]byte("UpRate")))) + uinfo.DownRate = int64(Uint64(bucket.Get([]byte("DownRate")))) + uinfo.UpCredit = int64(Uint64(bucket.Get([]byte("UpCredit")))) + uinfo.DownCredit = int64(Uint64(bucket.Get([]byte("DownCredit")))) + uinfo.ExpiryTime = int64(Uint64(bucket.Get([]byte("ExpiryTime")))) + infos = append(infos, uinfo) + return nil + }) + return err + }) + return +} + +func (manager *localManager) GetUserInfo(UID []byte) (uinfo UserInfo, err error) { + err = manager.db.View(func(tx *bolt.Tx) error { + bucket := tx.Bucket(UID) + if bucket == nil { + return ErrUserNotFound + } + uinfo.UID = UID + uinfo.SessionsCap = int(Uint32(bucket.Get([]byte("SessionsCap")))) + uinfo.UpRate = int64(Uint64(bucket.Get([]byte("UpRate")))) + uinfo.DownRate = int64(Uint64(bucket.Get([]byte("DownRate")))) + uinfo.UpCredit = int64(Uint64(bucket.Get([]byte("UpCredit")))) + uinfo.DownCredit = int64(Uint64(bucket.Get([]byte("DownCredit")))) + uinfo.ExpiryTime = int64(Uint64(bucket.Get([]byte("ExpiryTime")))) + return nil + }) + return +} + +func (manager *localManager) WriteUserInfo(uinfo UserInfo) (err error) { + err = manager.db.Update(func(tx *bolt.Tx) error { + bucket, err := tx.CreateBucketIfNotExists(uinfo.UID) + if err != nil { + return err + } + if err = bucket.Put([]byte("SessionsCap"), i32ToB(int32(uinfo.SessionsCap))); err != nil { + return err + } + if err = bucket.Put([]byte("UpRate"), i64ToB(uinfo.UpRate)); err != nil { + return err + } + if err = bucket.Put([]byte("DownRate"), i64ToB(uinfo.DownRate)); err != nil { + return err + } + if err = bucket.Put([]byte("UpCredit"), i64ToB(uinfo.UpCredit)); err != nil { + return err + } + if err = bucket.Put([]byte("DownCredit"), i64ToB(uinfo.DownCredit)); err != nil { + return err + } + if err = bucket.Put([]byte("ExpiryTime"), i64ToB(uinfo.ExpiryTime)); err != nil { + return err + } + return nil + }) + return +} + +func (manager *localManager) DeleteUser(UID []byte) (err error) { + err = manager.db.Update(func(tx *bolt.Tx) error { + return tx.DeleteBucket(UID) + }) + return +} + func (manager *localManager) Close() error { return manager.db.Close() } diff --git a/internal/server/usermanager/localmanager_api.go b/internal/server/usermanager/localmanager_api.go deleted file mode 100644 index abc88e5..0000000 --- a/internal/server/usermanager/localmanager_api.go +++ /dev/null @@ -1,165 +0,0 @@ -package usermanager - -import ( - "bytes" - "encoding/base64" - "encoding/json" - "net/http" - - gmux "github.com/gorilla/mux" - bolt "go.etcd.io/bbolt" -) - -type UserInfo struct { - UID []byte - SessionsCap int - UpRate int64 - DownRate int64 - UpCredit int64 - DownCredit int64 - ExpiryTime int64 -} - -func (manager *localManager) listAllUsersHlr(w http.ResponseWriter, r *http.Request) { - var infos []UserInfo - _ = manager.db.View(func(tx *bolt.Tx) error { - err := tx.ForEach(func(UID []byte, bucket *bolt.Bucket) error { - var uinfo UserInfo - uinfo.UID = UID - uinfo.SessionsCap = int(Uint32(bucket.Get([]byte("SessionsCap")))) - uinfo.UpRate = int64(Uint64(bucket.Get([]byte("UpRate")))) - uinfo.DownRate = int64(Uint64(bucket.Get([]byte("DownRate")))) - uinfo.UpCredit = int64(Uint64(bucket.Get([]byte("UpCredit")))) - uinfo.DownCredit = int64(Uint64(bucket.Get([]byte("DownCredit")))) - uinfo.ExpiryTime = int64(Uint64(bucket.Get([]byte("ExpiryTime")))) - infos = append(infos, uinfo) - return nil - }) - return err - }) - resp, err := json.Marshal(infos) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - _, _ = w.Write(resp) -} - -func (manager *localManager) getUserInfoHlr(w http.ResponseWriter, r *http.Request) { - b64UID := gmux.Vars(r)["UID"] - if b64UID == "" { - http.Error(w, "UID cannot be empty", http.StatusBadRequest) - } - - UID, err := base64.URLEncoding.DecodeString(b64UID) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - - var uinfo UserInfo - err = manager.db.View(func(tx *bolt.Tx) error { - bucket := tx.Bucket(UID) - if bucket == nil { - return ErrUserNotFound - } - uinfo.UID = UID - uinfo.SessionsCap = int(Uint32(bucket.Get([]byte("SessionsCap")))) - uinfo.UpRate = int64(Uint64(bucket.Get([]byte("UpRate")))) - uinfo.DownRate = int64(Uint64(bucket.Get([]byte("DownRate")))) - uinfo.UpCredit = int64(Uint64(bucket.Get([]byte("UpCredit")))) - uinfo.DownCredit = int64(Uint64(bucket.Get([]byte("DownCredit")))) - uinfo.ExpiryTime = int64(Uint64(bucket.Get([]byte("ExpiryTime")))) - return nil - }) - if err == ErrUserNotFound { - http.Error(w, ErrUserNotFound.Error(), http.StatusNotFound) - return - } - resp, err := json.Marshal(uinfo) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - _, _ = w.Write(resp) -} - -func (manager *localManager) writeUserInfoHlr(w http.ResponseWriter, r *http.Request) { - b64UID := gmux.Vars(r)["UID"] - if b64UID == "" { - http.Error(w, "UID cannot be empty", http.StatusBadRequest) - return - } - UID, err := base64.URLEncoding.DecodeString(b64UID) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - - jsonUinfo := r.FormValue("UserInfo") - if jsonUinfo == "" { - http.Error(w, "UserInfo cannot be empty", http.StatusBadRequest) - return - } - var uinfo UserInfo - err = json.Unmarshal([]byte(jsonUinfo), &uinfo) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - if !bytes.Equal(UID, uinfo.UID) { - http.Error(w, "UID mismatch", http.StatusBadRequest) - } - - err = manager.db.Update(func(tx *bolt.Tx) error { - bucket, err := tx.CreateBucketIfNotExists(uinfo.UID) - if err != nil { - return err - } - if err = bucket.Put([]byte("SessionsCap"), i32ToB(int32(uinfo.SessionsCap))); err != nil { - return err - } - if err = bucket.Put([]byte("UpRate"), i64ToB(uinfo.UpRate)); err != nil { - return err - } - if err = bucket.Put([]byte("DownRate"), i64ToB(uinfo.DownRate)); err != nil { - return err - } - if err = bucket.Put([]byte("UpCredit"), i64ToB(uinfo.UpCredit)); err != nil { - return err - } - if err = bucket.Put([]byte("DownCredit"), i64ToB(uinfo.DownCredit)); err != nil { - return err - } - if err = bucket.Put([]byte("ExpiryTime"), i64ToB(uinfo.ExpiryTime)); err != nil { - return err - } - return nil - }) - - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - } - w.WriteHeader(http.StatusCreated) -} - -func (manager *localManager) deleteUserHlr(w http.ResponseWriter, r *http.Request) { - b64UID := gmux.Vars(r)["UID"] - if b64UID == "" { - http.Error(w, "UID cannot be empty", http.StatusBadRequest) - return - } - UID, err := base64.URLEncoding.DecodeString(b64UID) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - - err = manager.db.Update(func(tx *bolt.Tx) error { - return tx.DeleteBucket(UID) - }) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - } - w.WriteHeader(http.StatusOK) -} diff --git a/internal/server/usermanager/localmanager_test.go b/internal/server/usermanager/localmanager_test.go new file mode 100644 index 0000000..5238abc --- /dev/null +++ b/internal/server/usermanager/localmanager_test.go @@ -0,0 +1,208 @@ +package usermanager + +import ( + "github.com/cbeuw/Cloak/internal/common" + "io/ioutil" + "os" + "reflect" + "testing" + "time" +) + +var mockUID = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} +var mockWorldState = common.WorldOfTime(time.Unix(1, 0)) +var mockUserInfo = UserInfo{ + UID: mockUID, + SessionsCap: 0, + UpRate: 0, + DownRate: 0, + UpCredit: 0, + DownCredit: 0, + ExpiryTime: 100, +} + +func TestLocalManager_WriteUserInfo(t *testing.T) { + var tmpDB, _ = ioutil.TempFile("", "ck_user_info") + defer os.Remove(tmpDB.Name()) + mgr, err := MakeLocalManager(tmpDB.Name(), mockWorldState) + if err != nil { + t.Fatal(err) + } + + err = mgr.WriteUserInfo(mockUserInfo) + + if err != nil { + t.Error(err) + } +} + +func TestLocalManager_GetUserInfo(t *testing.T) { + var tmpDB, _ = ioutil.TempFile("", "ck_user_info") + defer os.Remove(tmpDB.Name()) + mgr, err := MakeLocalManager(tmpDB.Name(), mockWorldState) + if err != nil { + t.Fatal(err) + } + + t.Run("simple fetch", func(t *testing.T) { + _ = mgr.WriteUserInfo(mockUserInfo) + gotInfo, err := mgr.GetUserInfo(mockUID) + if err != nil { + t.Error(err) + } + if !reflect.DeepEqual(gotInfo, mockUserInfo) { + t.Errorf("got wrong user info: %v", gotInfo) + } + }) + + t.Run("update a field", func(t *testing.T) { + _ = mgr.WriteUserInfo(mockUserInfo) + updatedUserInfo := mockUserInfo + updatedUserInfo.SessionsCap = mockUserInfo.SessionsCap + 1 + + err = mgr.WriteUserInfo(updatedUserInfo) + if err != nil { + t.Error(err) + } + + gotInfo, err := mgr.GetUserInfo(mockUID) + if err != nil { + t.Error(err) + } + if !reflect.DeepEqual(gotInfo, updatedUserInfo) { + t.Errorf("got wrong user info: %v", updatedUserInfo) + } + }) + + t.Run("non existent user", func(t *testing.T) { + _, err := mgr.GetUserInfo(make([]byte, 16)) + if err != ErrUserNotFound { + t.Errorf("expecting error %v, got %v", ErrUserNotFound, err) + } + }) +} + +func TestLocalManager_DeleteUser(t *testing.T) { + var tmpDB, _ = ioutil.TempFile("", "ck_user_info") + defer os.Remove(tmpDB.Name()) + mgr, err := MakeLocalManager(tmpDB.Name(), mockWorldState) + if err != nil { + t.Fatal(err) + } + + _ = mgr.WriteUserInfo(mockUserInfo) + err = mgr.DeleteUser(mockUID) + if err != nil { + t.Error(err) + } + + _, err = mgr.GetUserInfo(mockUID) + if err != ErrUserNotFound { + t.Error("user not deleted") + } +} + +var validUserInfo = UserInfo{ + UID: mockUID, + SessionsCap: 10, + UpRate: 100, + DownRate: 1000, + UpCredit: 10000, + DownCredit: 100000, + ExpiryTime: 1000000, +} + +func TestLocalManager_AuthenticateUser(t *testing.T) { + var tmpDB, _ = ioutil.TempFile("", "ck_user_info") + defer os.Remove(tmpDB.Name()) + mgr, err := MakeLocalManager(tmpDB.Name(), mockWorldState) + if err != nil { + t.Fatal(err) + } + + t.Run("normal auth", func(t *testing.T) { + _ = mgr.WriteUserInfo(validUserInfo) + upRate, downRate, err := mgr.AuthenticateUser(validUserInfo.UID) + if err != nil { + t.Error(err) + } + + if upRate != validUserInfo.UpRate || downRate != validUserInfo.DownRate { + t.Error("wrong up or down rate") + } + }) + + t.Run("non existent user", func(t *testing.T) { + _, _, err := mgr.AuthenticateUser(make([]byte, 16)) + if err != ErrUserNotFound { + t.Error("user found") + } + }) + + t.Run("expired user", func(t *testing.T) { + expiredUserInfo := validUserInfo + expiredUserInfo.ExpiryTime = mockWorldState.Now().Add(-10 * time.Second).Unix() + + _ = mgr.WriteUserInfo(expiredUserInfo) + + _, _, err := mgr.AuthenticateUser(expiredUserInfo.UID) + if err != ErrUserExpired { + t.Error("user not expired") + } + }) + + t.Run("no credit", func(t *testing.T) { + creditlessUserInfo := validUserInfo + creditlessUserInfo.UpCredit, creditlessUserInfo.DownCredit = -1, -1 + + _ = mgr.WriteUserInfo(creditlessUserInfo) + + _, _, err := mgr.AuthenticateUser(creditlessUserInfo.UID) + if err != ErrNoUpCredit && err != ErrNoDownCredit { + t.Error("user not creditless") + } + }) +} + +func TestLocalManager_AuthoriseNewSession(t *testing.T) { + var tmpDB, _ = ioutil.TempFile("", "ck_user_info") + defer os.Remove(tmpDB.Name()) + mgr, err := MakeLocalManager(tmpDB.Name(), mockWorldState) + if err != nil { + t.Fatal(err) + } + + t.Run("normal auth", func(t *testing.T) { + _ = mgr.WriteUserInfo(validUserInfo) + err := mgr.AuthoriseNewSession(validUserInfo.UID, AuthorisationInfo{NumExistingSessions: 0}) + if err != nil { + t.Error(err) + } + }) + + t.Run("non existent user", func(t *testing.T) { + err := mgr.AuthoriseNewSession(make([]byte, 16), AuthorisationInfo{NumExistingSessions: 0}) + if err != ErrUserNotFound { + t.Error("user found") + } + }) + + t.Run("expired user", func(t *testing.T) { + expiredUserInfo := validUserInfo + expiredUserInfo.ExpiryTime = mockWorldState.Now().Add(-10 * time.Second).Unix() + + _ = mgr.WriteUserInfo(expiredUserInfo) + err := mgr.AuthoriseNewSession(expiredUserInfo.UID, AuthorisationInfo{NumExistingSessions: 0}) + if err != ErrUserExpired { + t.Error("user not expired") + } + }) + + t.Run("too many sessions", func(t *testing.T) { + _ = mgr.WriteUserInfo(validUserInfo) + err := mgr.AuthoriseNewSession(validUserInfo.UID, AuthorisationInfo{NumExistingSessions: validUserInfo.SessionsCap + 1}) + if err != ErrSessionsCapReached { + t.Error("session cap not reached") + } + }) +} diff --git a/internal/server/usermanager/usermanager.go b/internal/server/usermanager/usermanager.go index 99ac8d9..95045bd 100644 --- a/internal/server/usermanager/usermanager.go +++ b/internal/server/usermanager/usermanager.go @@ -14,6 +14,16 @@ type StatusUpdate struct { Timestamp int64 } +type UserInfo struct { + UID []byte + SessionsCap int + UpRate int64 + DownRate int64 + UpCredit int64 + DownCredit int64 + ExpiryTime int64 +} + type StatusResponse struct { UID []byte Action int @@ -39,4 +49,8 @@ type UserManager interface { AuthenticateUser([]byte) (int64, int64, error) AuthoriseNewSession([]byte, AuthorisationInfo) error UploadStatus([]StatusUpdate) ([]StatusResponse, error) + ListAllUsers() ([]UserInfo, error) + GetUserInfo(UID []byte) (UserInfo, error) + WriteUserInfo(UserInfo) error + DeleteUser(UID []byte) error } diff --git a/internal/server/userpanel_test.go b/internal/server/userpanel_test.go index 4e13399..cad0283 100644 --- a/internal/server/userpanel_test.go +++ b/internal/server/userpanel_test.go @@ -2,6 +2,7 @@ package server import ( "encoding/base64" + "github.com/cbeuw/Cloak/internal/common" "github.com/cbeuw/Cloak/internal/server/usermanager" "os" "testing" @@ -10,7 +11,7 @@ import ( const MOCK_DB_NAME = "userpanel_test_mock_database.db" func TestUserPanel_BypassUser(t *testing.T) { - manager, err := usermanager.MakeLocalManager(MOCK_DB_NAME) + manager, err := usermanager.MakeLocalManager(MOCK_DB_NAME, common.RealWorldState) if err != nil { t.Error("failed to make local manager", err) }