mirror of https://github.com/cbeuw/Cloak
Refactor usermanager
parent
b353638c1c
commit
69a73ecfc0
@ -0,0 +1,134 @@
|
||||
package usermanager
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
gmux "github.com/gorilla/mux"
|
||||
)
|
||||
|
||||
type APIRouter struct {
|
||||
*gmux.Router
|
||||
manager UserManager
|
||||
}
|
||||
|
||||
func APIRouterOf(manager UserManager) *APIRouter {
|
||||
ret := &APIRouter{
|
||||
manager: manager,
|
||||
}
|
||||
ret.registerMux()
|
||||
return ret
|
||||
}
|
||||
|
||||
func corsMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func (ar *APIRouter) registerMux() {
|
||||
ar.Router = gmux.NewRouter()
|
||||
ar.HandleFunc("/admin/users", ar.listAllUsersHlr).Methods("GET")
|
||||
ar.HandleFunc("/admin/users/{UID}", ar.getUserInfoHlr).Methods("GET")
|
||||
ar.HandleFunc("/admin/users/{UID}", ar.writeUserInfoHlr).Methods("POST")
|
||||
ar.HandleFunc("/admin/users/{UID}", ar.deleteUserHlr).Methods("DELETE")
|
||||
ar.Methods("OPTIONS").HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET,POST,DELETE,OPTIONS")
|
||||
})
|
||||
ar.Use(corsMiddleware)
|
||||
}
|
||||
|
||||
func (ar *APIRouter) listAllUsersHlr(w http.ResponseWriter, r *http.Request) {
|
||||
infos, err := ar.manager.ListAllUsers()
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
resp, err := json.Marshal(infos)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
_, _ = w.Write(resp)
|
||||
}
|
||||
|
||||
func (ar *APIRouter) getUserInfoHlr(w http.ResponseWriter, r *http.Request) {
|
||||
b64UID := gmux.Vars(r)["UID"]
|
||||
if b64UID == "" {
|
||||
http.Error(w, "UID cannot be empty", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
UID, err := base64.URLEncoding.DecodeString(b64UID)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
uinfo, err := ar.manager.GetUserInfo(UID)
|
||||
if err == ErrUserNotFound {
|
||||
http.Error(w, ErrUserNotFound.Error(), http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
resp, err := json.Marshal(uinfo)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
_, _ = w.Write(resp)
|
||||
}
|
||||
|
||||
func (ar *APIRouter) writeUserInfoHlr(w http.ResponseWriter, r *http.Request) {
|
||||
b64UID := gmux.Vars(r)["UID"]
|
||||
if b64UID == "" {
|
||||
http.Error(w, "UID cannot be empty", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
UID, err := base64.URLEncoding.DecodeString(b64UID)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
jsonUinfo := r.FormValue("UserInfo")
|
||||
if jsonUinfo == "" {
|
||||
http.Error(w, "UserInfo cannot be empty", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
var uinfo UserInfo
|
||||
err = json.Unmarshal([]byte(jsonUinfo), &uinfo)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if !bytes.Equal(UID, uinfo.UID) {
|
||||
http.Error(w, "UID mismatch", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
err = ar.manager.WriteUserInfo(uinfo)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
}
|
||||
|
||||
func (ar *APIRouter) deleteUserHlr(w http.ResponseWriter, r *http.Request) {
|
||||
b64UID := gmux.Vars(r)["UID"]
|
||||
if b64UID == "" {
|
||||
http.Error(w, "UID cannot be empty", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
UID, err := base64.URLEncoding.DecodeString(b64UID)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
err = ar.manager.DeleteUser(UID)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
@ -1,165 +0,0 @@
|
||||
package usermanager
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
gmux "github.com/gorilla/mux"
|
||||
bolt "go.etcd.io/bbolt"
|
||||
)
|
||||
|
||||
type UserInfo struct {
|
||||
UID []byte
|
||||
SessionsCap int
|
||||
UpRate int64
|
||||
DownRate int64
|
||||
UpCredit int64
|
||||
DownCredit int64
|
||||
ExpiryTime int64
|
||||
}
|
||||
|
||||
func (manager *localManager) listAllUsersHlr(w http.ResponseWriter, r *http.Request) {
|
||||
var infos []UserInfo
|
||||
_ = manager.db.View(func(tx *bolt.Tx) error {
|
||||
err := tx.ForEach(func(UID []byte, bucket *bolt.Bucket) error {
|
||||
var uinfo UserInfo
|
||||
uinfo.UID = UID
|
||||
uinfo.SessionsCap = int(Uint32(bucket.Get([]byte("SessionsCap"))))
|
||||
uinfo.UpRate = int64(Uint64(bucket.Get([]byte("UpRate"))))
|
||||
uinfo.DownRate = int64(Uint64(bucket.Get([]byte("DownRate"))))
|
||||
uinfo.UpCredit = int64(Uint64(bucket.Get([]byte("UpCredit"))))
|
||||
uinfo.DownCredit = int64(Uint64(bucket.Get([]byte("DownCredit"))))
|
||||
uinfo.ExpiryTime = int64(Uint64(bucket.Get([]byte("ExpiryTime"))))
|
||||
infos = append(infos, uinfo)
|
||||
return nil
|
||||
})
|
||||
return err
|
||||
})
|
||||
resp, err := json.Marshal(infos)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
_, _ = w.Write(resp)
|
||||
}
|
||||
|
||||
func (manager *localManager) getUserInfoHlr(w http.ResponseWriter, r *http.Request) {
|
||||
b64UID := gmux.Vars(r)["UID"]
|
||||
if b64UID == "" {
|
||||
http.Error(w, "UID cannot be empty", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
UID, err := base64.URLEncoding.DecodeString(b64UID)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var uinfo UserInfo
|
||||
err = manager.db.View(func(tx *bolt.Tx) error {
|
||||
bucket := tx.Bucket(UID)
|
||||
if bucket == nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
uinfo.UID = UID
|
||||
uinfo.SessionsCap = int(Uint32(bucket.Get([]byte("SessionsCap"))))
|
||||
uinfo.UpRate = int64(Uint64(bucket.Get([]byte("UpRate"))))
|
||||
uinfo.DownRate = int64(Uint64(bucket.Get([]byte("DownRate"))))
|
||||
uinfo.UpCredit = int64(Uint64(bucket.Get([]byte("UpCredit"))))
|
||||
uinfo.DownCredit = int64(Uint64(bucket.Get([]byte("DownCredit"))))
|
||||
uinfo.ExpiryTime = int64(Uint64(bucket.Get([]byte("ExpiryTime"))))
|
||||
return nil
|
||||
})
|
||||
if err == ErrUserNotFound {
|
||||
http.Error(w, ErrUserNotFound.Error(), http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
resp, err := json.Marshal(uinfo)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
_, _ = w.Write(resp)
|
||||
}
|
||||
|
||||
func (manager *localManager) writeUserInfoHlr(w http.ResponseWriter, r *http.Request) {
|
||||
b64UID := gmux.Vars(r)["UID"]
|
||||
if b64UID == "" {
|
||||
http.Error(w, "UID cannot be empty", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
UID, err := base64.URLEncoding.DecodeString(b64UID)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
jsonUinfo := r.FormValue("UserInfo")
|
||||
if jsonUinfo == "" {
|
||||
http.Error(w, "UserInfo cannot be empty", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
var uinfo UserInfo
|
||||
err = json.Unmarshal([]byte(jsonUinfo), &uinfo)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if !bytes.Equal(UID, uinfo.UID) {
|
||||
http.Error(w, "UID mismatch", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
err = manager.db.Update(func(tx *bolt.Tx) error {
|
||||
bucket, err := tx.CreateBucketIfNotExists(uinfo.UID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err = bucket.Put([]byte("SessionsCap"), i32ToB(int32(uinfo.SessionsCap))); err != nil {
|
||||
return err
|
||||
}
|
||||
if err = bucket.Put([]byte("UpRate"), i64ToB(uinfo.UpRate)); err != nil {
|
||||
return err
|
||||
}
|
||||
if err = bucket.Put([]byte("DownRate"), i64ToB(uinfo.DownRate)); err != nil {
|
||||
return err
|
||||
}
|
||||
if err = bucket.Put([]byte("UpCredit"), i64ToB(uinfo.UpCredit)); err != nil {
|
||||
return err
|
||||
}
|
||||
if err = bucket.Put([]byte("DownCredit"), i64ToB(uinfo.DownCredit)); err != nil {
|
||||
return err
|
||||
}
|
||||
if err = bucket.Put([]byte("ExpiryTime"), i64ToB(uinfo.ExpiryTime)); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
}
|
||||
|
||||
func (manager *localManager) deleteUserHlr(w http.ResponseWriter, r *http.Request) {
|
||||
b64UID := gmux.Vars(r)["UID"]
|
||||
if b64UID == "" {
|
||||
http.Error(w, "UID cannot be empty", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
UID, err := base64.URLEncoding.DecodeString(b64UID)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
err = manager.db.Update(func(tx *bolt.Tx) error {
|
||||
return tx.DeleteBucket(UID)
|
||||
})
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
@ -0,0 +1,208 @@
|
||||
package usermanager
|
||||
|
||||
import (
|
||||
"github.com/cbeuw/Cloak/internal/common"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
var mockUID = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
|
||||
var mockWorldState = common.WorldOfTime(time.Unix(1, 0))
|
||||
var mockUserInfo = UserInfo{
|
||||
UID: mockUID,
|
||||
SessionsCap: 0,
|
||||
UpRate: 0,
|
||||
DownRate: 0,
|
||||
UpCredit: 0,
|
||||
DownCredit: 0,
|
||||
ExpiryTime: 100,
|
||||
}
|
||||
|
||||
func TestLocalManager_WriteUserInfo(t *testing.T) {
|
||||
var tmpDB, _ = ioutil.TempFile("", "ck_user_info")
|
||||
defer os.Remove(tmpDB.Name())
|
||||
mgr, err := MakeLocalManager(tmpDB.Name(), mockWorldState)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = mgr.WriteUserInfo(mockUserInfo)
|
||||
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalManager_GetUserInfo(t *testing.T) {
|
||||
var tmpDB, _ = ioutil.TempFile("", "ck_user_info")
|
||||
defer os.Remove(tmpDB.Name())
|
||||
mgr, err := MakeLocalManager(tmpDB.Name(), mockWorldState)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Run("simple fetch", func(t *testing.T) {
|
||||
_ = mgr.WriteUserInfo(mockUserInfo)
|
||||
gotInfo, err := mgr.GetUserInfo(mockUID)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if !reflect.DeepEqual(gotInfo, mockUserInfo) {
|
||||
t.Errorf("got wrong user info: %v", gotInfo)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("update a field", func(t *testing.T) {
|
||||
_ = mgr.WriteUserInfo(mockUserInfo)
|
||||
updatedUserInfo := mockUserInfo
|
||||
updatedUserInfo.SessionsCap = mockUserInfo.SessionsCap + 1
|
||||
|
||||
err = mgr.WriteUserInfo(updatedUserInfo)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
gotInfo, err := mgr.GetUserInfo(mockUID)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if !reflect.DeepEqual(gotInfo, updatedUserInfo) {
|
||||
t.Errorf("got wrong user info: %v", updatedUserInfo)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("non existent user", func(t *testing.T) {
|
||||
_, err := mgr.GetUserInfo(make([]byte, 16))
|
||||
if err != ErrUserNotFound {
|
||||
t.Errorf("expecting error %v, got %v", ErrUserNotFound, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestLocalManager_DeleteUser(t *testing.T) {
|
||||
var tmpDB, _ = ioutil.TempFile("", "ck_user_info")
|
||||
defer os.Remove(tmpDB.Name())
|
||||
mgr, err := MakeLocalManager(tmpDB.Name(), mockWorldState)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_ = mgr.WriteUserInfo(mockUserInfo)
|
||||
err = mgr.DeleteUser(mockUID)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
_, err = mgr.GetUserInfo(mockUID)
|
||||
if err != ErrUserNotFound {
|
||||
t.Error("user not deleted")
|
||||
}
|
||||
}
|
||||
|
||||
var validUserInfo = UserInfo{
|
||||
UID: mockUID,
|
||||
SessionsCap: 10,
|
||||
UpRate: 100,
|
||||
DownRate: 1000,
|
||||
UpCredit: 10000,
|
||||
DownCredit: 100000,
|
||||
ExpiryTime: 1000000,
|
||||
}
|
||||
|
||||
func TestLocalManager_AuthenticateUser(t *testing.T) {
|
||||
var tmpDB, _ = ioutil.TempFile("", "ck_user_info")
|
||||
defer os.Remove(tmpDB.Name())
|
||||
mgr, err := MakeLocalManager(tmpDB.Name(), mockWorldState)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Run("normal auth", func(t *testing.T) {
|
||||
_ = mgr.WriteUserInfo(validUserInfo)
|
||||
upRate, downRate, err := mgr.AuthenticateUser(validUserInfo.UID)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if upRate != validUserInfo.UpRate || downRate != validUserInfo.DownRate {
|
||||
t.Error("wrong up or down rate")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("non existent user", func(t *testing.T) {
|
||||
_, _, err := mgr.AuthenticateUser(make([]byte, 16))
|
||||
if err != ErrUserNotFound {
|
||||
t.Error("user found")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("expired user", func(t *testing.T) {
|
||||
expiredUserInfo := validUserInfo
|
||||
expiredUserInfo.ExpiryTime = mockWorldState.Now().Add(-10 * time.Second).Unix()
|
||||
|
||||
_ = mgr.WriteUserInfo(expiredUserInfo)
|
||||
|
||||
_, _, err := mgr.AuthenticateUser(expiredUserInfo.UID)
|
||||
if err != ErrUserExpired {
|
||||
t.Error("user not expired")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no credit", func(t *testing.T) {
|
||||
creditlessUserInfo := validUserInfo
|
||||
creditlessUserInfo.UpCredit, creditlessUserInfo.DownCredit = -1, -1
|
||||
|
||||
_ = mgr.WriteUserInfo(creditlessUserInfo)
|
||||
|
||||
_, _, err := mgr.AuthenticateUser(creditlessUserInfo.UID)
|
||||
if err != ErrNoUpCredit && err != ErrNoDownCredit {
|
||||
t.Error("user not creditless")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestLocalManager_AuthoriseNewSession(t *testing.T) {
|
||||
var tmpDB, _ = ioutil.TempFile("", "ck_user_info")
|
||||
defer os.Remove(tmpDB.Name())
|
||||
mgr, err := MakeLocalManager(tmpDB.Name(), mockWorldState)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Run("normal auth", func(t *testing.T) {
|
||||
_ = mgr.WriteUserInfo(validUserInfo)
|
||||
err := mgr.AuthoriseNewSession(validUserInfo.UID, AuthorisationInfo{NumExistingSessions: 0})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("non existent user", func(t *testing.T) {
|
||||
err := mgr.AuthoriseNewSession(make([]byte, 16), AuthorisationInfo{NumExistingSessions: 0})
|
||||
if err != ErrUserNotFound {
|
||||
t.Error("user found")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("expired user", func(t *testing.T) {
|
||||
expiredUserInfo := validUserInfo
|
||||
expiredUserInfo.ExpiryTime = mockWorldState.Now().Add(-10 * time.Second).Unix()
|
||||
|
||||
_ = mgr.WriteUserInfo(expiredUserInfo)
|
||||
err := mgr.AuthoriseNewSession(expiredUserInfo.UID, AuthorisationInfo{NumExistingSessions: 0})
|
||||
if err != ErrUserExpired {
|
||||
t.Error("user not expired")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("too many sessions", func(t *testing.T) {
|
||||
_ = mgr.WriteUserInfo(validUserInfo)
|
||||
err := mgr.AuthoriseNewSession(validUserInfo.UID, AuthorisationInfo{NumExistingSessions: validUserInfo.SessionsCap + 1})
|
||||
if err != ErrSessionsCapReached {
|
||||
t.Error("session cap not reached")
|
||||
}
|
||||
})
|
||||
}
|
Loading…
Reference in New Issue