Allow partial json to be POSTed to admin/user/{UID} for only updating select fields

This commit is contained in:
Andy Wang 2021-01-05 21:48:02 +00:00
parent a643402e11
commit f27889af11
No known key found for this signature in database
GPG Key ID: 181B49F9F38F3374
6 changed files with 142 additions and 88 deletions

View File

@ -2,7 +2,7 @@ swagger: '2.0'
info: info:
description: | description: |
This is the API of Cloak server This is the API of Cloak server
version: 1.0.0 version: 0.0.2
title: Cloak Server title: Cloak Server
contact: contact:
email: cbeuw.andy@gmail.com email: cbeuw.andy@gmail.com
@ -12,8 +12,6 @@ info:
# host: petstore.swagger.io # host: petstore.swagger.io
# basePath: /v2 # basePath: /v2
tags: tags:
- name: admin
description: Endpoints used by the host administrators
- name: users - name: users
description: Operations related to user controls by admin description: Operations related to user controls by admin
# schemes: # schemes:
@ -22,7 +20,6 @@ paths:
/admin/users: /admin/users:
get: get:
tags: tags:
- admin
- users - users
summary: Show all users summary: Show all users
description: Returns an array of all UserInfo description: Returns an array of all UserInfo
@ -41,7 +38,6 @@ paths:
/admin/users/{UID}: /admin/users/{UID}:
get: get:
tags: tags:
- admin
- users - users
summary: Show userinfo by UID summary: Show userinfo by UID
description: Returns a UserInfo object description: Returns a UserInfo object
@ -68,7 +64,6 @@ paths:
description: internal error description: internal error
post: post:
tags: tags:
- admin
- users - users
summary: Updates the userinfo of the specified user, if the user does not exist, then a new user is created summary: Updates the userinfo of the specified user, if the user does not exist, then a new user is created
operationId: writeUserInfo operationId: writeUserInfo
@ -100,7 +95,6 @@ paths:
description: internal error description: internal error
delete: delete:
tags: tags:
- admin
- users - users
summary: Deletes a user summary: Deletes a user
operationId: deleteUser operationId: deleteUser

View File

@ -46,6 +46,36 @@ func TestWriteUserInfoHlr(t *testing.T) {
assert.Equalf(t, http.StatusCreated, rr.Code, "response body: %v", rr.Body) assert.Equalf(t, http.StatusCreated, rr.Code, "response body: %v", rr.Body)
}) })
t.Run("partial update", func(t *testing.T) {
req, err := http.NewRequest("POST", "/admin/users/"+mockUIDb64, bytes.NewBuffer(marshalled))
assert.NoError(t, err)
rr := httptest.NewRecorder()
router.ServeHTTP(rr, req)
assert.Equal(t, http.StatusCreated, rr.Code)
partialUserInfo := UserInfo{
UID: mockUID,
SessionsCap: JustInt32(10),
}
partialMarshalled, _ := json.Marshal(partialUserInfo)
req, err = http.NewRequest("POST", "/admin/users/"+mockUIDb64, bytes.NewBuffer(partialMarshalled))
assert.NoError(t, err)
router.ServeHTTP(rr, req)
assert.Equal(t, http.StatusCreated, rr.Code)
req, err = http.NewRequest("GET", "/admin/users/"+mockUIDb64, nil)
assert.NoError(t, err)
router.ServeHTTP(rr, req)
assert.Equal(t, http.StatusCreated, rr.Code)
var got UserInfo
err = json.Unmarshal(rr.Body.Bytes(), &got)
assert.NoError(t, err)
expected := mockUserInfo
expected.SessionsCap = partialUserInfo.SessionsCap
assert.EqualValues(t, expected, got)
})
t.Run("empty parameter", func(t *testing.T) { t.Run("empty parameter", func(t *testing.T) {
req, err := http.NewRequest("POST", "/admin/users/", bytes.NewBuffer(marshalled)) req, err := http.NewRequest("POST", "/admin/users/", bytes.NewBuffer(marshalled))
if err != nil { if err != nil {

View File

@ -127,6 +127,7 @@ func (manager *localManager) UploadStatus(uploads []StatusUpdate) ([]StatusRespo
"User no longer exists", "User no longer exists",
} }
responses = append(responses, resp) responses = append(responses, resp)
continue
} }
oldUp := int64(u64(bucket.Get([]byte("UpCredit")))) oldUp := int64(u64(bucket.Get([]byte("UpCredit"))))
@ -179,12 +180,12 @@ func (manager *localManager) ListAllUsers() (infos []UserInfo, err error) {
err = tx.ForEach(func(UID []byte, bucket *bolt.Bucket) error { err = tx.ForEach(func(UID []byte, bucket *bolt.Bucket) error {
var uinfo UserInfo var uinfo UserInfo
uinfo.UID = UID uinfo.UID = UID
uinfo.SessionsCap = int32(u32(bucket.Get([]byte("SessionsCap")))) uinfo.SessionsCap = JustInt32(int32(u32(bucket.Get([]byte("SessionsCap")))))
uinfo.UpRate = int64(u64(bucket.Get([]byte("UpRate")))) uinfo.UpRate = JustInt64(int64(u64(bucket.Get([]byte("UpRate")))))
uinfo.DownRate = int64(u64(bucket.Get([]byte("DownRate")))) uinfo.DownRate = JustInt64(int64(u64(bucket.Get([]byte("DownRate")))))
uinfo.UpCredit = int64(u64(bucket.Get([]byte("UpCredit")))) uinfo.UpCredit = JustInt64(int64(u64(bucket.Get([]byte("UpCredit")))))
uinfo.DownCredit = int64(u64(bucket.Get([]byte("DownCredit")))) uinfo.DownCredit = JustInt64(int64(u64(bucket.Get([]byte("DownCredit")))))
uinfo.ExpiryTime = int64(u64(bucket.Get([]byte("ExpiryTime")))) uinfo.ExpiryTime = JustInt64(int64(u64(bucket.Get([]byte("ExpiryTime")))))
infos = append(infos, uinfo) infos = append(infos, uinfo)
return nil return nil
}) })
@ -200,40 +201,52 @@ func (manager *localManager) GetUserInfo(UID []byte) (uinfo UserInfo, err error)
return ErrUserNotFound return ErrUserNotFound
} }
uinfo.UID = UID uinfo.UID = UID
uinfo.SessionsCap = int32(u32(bucket.Get([]byte("SessionsCap")))) uinfo.SessionsCap = JustInt32(int32(u32(bucket.Get([]byte("SessionsCap")))))
uinfo.UpRate = int64(u64(bucket.Get([]byte("UpRate")))) uinfo.UpRate = JustInt64(int64(u64(bucket.Get([]byte("UpRate")))))
uinfo.DownRate = int64(u64(bucket.Get([]byte("DownRate")))) uinfo.DownRate = JustInt64(int64(u64(bucket.Get([]byte("DownRate")))))
uinfo.UpCredit = int64(u64(bucket.Get([]byte("UpCredit")))) uinfo.UpCredit = JustInt64(int64(u64(bucket.Get([]byte("UpCredit")))))
uinfo.DownCredit = int64(u64(bucket.Get([]byte("DownCredit")))) uinfo.DownCredit = JustInt64(int64(u64(bucket.Get([]byte("DownCredit")))))
uinfo.ExpiryTime = int64(u64(bucket.Get([]byte("ExpiryTime")))) uinfo.ExpiryTime = JustInt64(int64(u64(bucket.Get([]byte("ExpiryTime")))))
return nil return nil
}) })
return return
} }
func (manager *localManager) WriteUserInfo(uinfo UserInfo) (err error) { func (manager *localManager) WriteUserInfo(u UserInfo) (err error) {
err = manager.db.Update(func(tx *bolt.Tx) error { err = manager.db.Update(func(tx *bolt.Tx) error {
bucket, err := tx.CreateBucketIfNotExists(uinfo.UID) bucket, err := tx.CreateBucketIfNotExists(u.UID)
if err != nil { if err != nil {
return err return err
} }
if err = bucket.Put([]byte("SessionsCap"), i32ToB(int32(uinfo.SessionsCap))); err != nil { if u.SessionsCap != nil {
return err if err = bucket.Put([]byte("SessionsCap"), i32ToB(*u.SessionsCap)); err != nil {
return err
}
} }
if err = bucket.Put([]byte("UpRate"), i64ToB(uinfo.UpRate)); err != nil { if u.UpRate != nil {
return err if err = bucket.Put([]byte("UpRate"), i64ToB(*u.UpRate)); err != nil {
return err
}
} }
if err = bucket.Put([]byte("DownRate"), i64ToB(uinfo.DownRate)); err != nil { if u.DownRate != nil {
return err if err = bucket.Put([]byte("DownRate"), i64ToB(*u.DownRate)); err != nil {
return err
}
} }
if err = bucket.Put([]byte("UpCredit"), i64ToB(uinfo.UpCredit)); err != nil { if u.UpCredit != nil {
return err if err = bucket.Put([]byte("UpCredit"), i64ToB(*u.UpCredit)); err != nil {
return err
}
} }
if err = bucket.Put([]byte("DownCredit"), i64ToB(uinfo.DownCredit)); err != nil { if u.DownCredit != nil {
return err if err = bucket.Put([]byte("DownCredit"), i64ToB(*u.DownCredit)); err != nil {
return err
}
} }
if err = bucket.Put([]byte("ExpiryTime"), i64ToB(uinfo.ExpiryTime)); err != nil { if u.ExpiryTime != nil {
return err if err = bucket.Put([]byte("ExpiryTime"), i64ToB(*u.ExpiryTime)); err != nil {
return err
}
} }
return nil return nil
}) })

View File

@ -3,6 +3,7 @@ package usermanager
import ( import (
"encoding/binary" "encoding/binary"
"github.com/cbeuw/Cloak/internal/common" "github.com/cbeuw/Cloak/internal/common"
"github.com/stretchr/testify/assert"
"io/ioutil" "io/ioutil"
"math/rand" "math/rand"
"os" "os"
@ -17,12 +18,12 @@ var mockUID = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
var mockWorldState = common.WorldOfTime(time.Unix(1, 0)) var mockWorldState = common.WorldOfTime(time.Unix(1, 0))
var mockUserInfo = UserInfo{ var mockUserInfo = UserInfo{
UID: mockUID, UID: mockUID,
SessionsCap: 0, SessionsCap: JustInt32(10),
UpRate: 0, UpRate: JustInt64(100),
DownRate: 0, DownRate: JustInt64(1000),
UpCredit: 0, UpCredit: JustInt64(10000),
DownCredit: 0, DownCredit: JustInt64(100000),
ExpiryTime: 100, ExpiryTime: JustInt64(1000000),
} }
func makeManager(t *testing.T) (mgr *localManager, cleaner func()) { func makeManager(t *testing.T) (mgr *localManager, cleaner func()) {
@ -43,6 +44,23 @@ func TestLocalManager_WriteUserInfo(t *testing.T) {
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
got, err := mgr.GetUserInfo(mockUID)
assert.NoError(t, err)
assert.EqualValues(t, mockUserInfo, got)
/* Partial update */
err = mgr.WriteUserInfo(UserInfo{
UID: mockUID,
SessionsCap: JustInt32(*mockUserInfo.SessionsCap + 1),
})
assert.NoError(t, err)
expected := mockUserInfo
expected.SessionsCap = JustInt32(*mockUserInfo.SessionsCap + 1)
got, err = mgr.GetUserInfo(mockUID)
assert.NoError(t, err)
assert.EqualValues(t, expected, got)
} }
func TestLocalManager_GetUserInfo(t *testing.T) { func TestLocalManager_GetUserInfo(t *testing.T) {
@ -63,7 +81,7 @@ func TestLocalManager_GetUserInfo(t *testing.T) {
t.Run("update a field", func(t *testing.T) { t.Run("update a field", func(t *testing.T) {
_ = mgr.WriteUserInfo(mockUserInfo) _ = mgr.WriteUserInfo(mockUserInfo)
updatedUserInfo := mockUserInfo updatedUserInfo := mockUserInfo
updatedUserInfo.SessionsCap = mockUserInfo.SessionsCap + 1 updatedUserInfo.SessionsCap = JustInt32(*mockUserInfo.SessionsCap + 1)
err := mgr.WriteUserInfo(updatedUserInfo) err := mgr.WriteUserInfo(updatedUserInfo)
if err != nil { if err != nil {
@ -103,15 +121,7 @@ func TestLocalManager_DeleteUser(t *testing.T) {
} }
} }
var validUserInfo = UserInfo{ var validUserInfo = mockUserInfo
UID: mockUID,
SessionsCap: 10,
UpRate: 100,
DownRate: 1000,
UpCredit: 10000,
DownCredit: 100000,
ExpiryTime: 1000000,
}
func TestLocalManager_AuthenticateUser(t *testing.T) { func TestLocalManager_AuthenticateUser(t *testing.T) {
var tmpDB, _ = ioutil.TempFile("", "ck_user_info") var tmpDB, _ = ioutil.TempFile("", "ck_user_info")
@ -128,7 +138,7 @@ func TestLocalManager_AuthenticateUser(t *testing.T) {
t.Error(err) t.Error(err)
} }
if upRate != validUserInfo.UpRate || downRate != validUserInfo.DownRate { if upRate != *validUserInfo.UpRate || downRate != *validUserInfo.DownRate {
t.Error("wrong up or down rate") t.Error("wrong up or down rate")
} }
}) })
@ -142,7 +152,7 @@ func TestLocalManager_AuthenticateUser(t *testing.T) {
t.Run("expired user", func(t *testing.T) { t.Run("expired user", func(t *testing.T) {
expiredUserInfo := validUserInfo expiredUserInfo := validUserInfo
expiredUserInfo.ExpiryTime = mockWorldState.Now().Add(-10 * time.Second).Unix() expiredUserInfo.ExpiryTime = JustInt64(mockWorldState.Now().Add(-10 * time.Second).Unix())
_ = mgr.WriteUserInfo(expiredUserInfo) _ = mgr.WriteUserInfo(expiredUserInfo)
@ -154,7 +164,7 @@ func TestLocalManager_AuthenticateUser(t *testing.T) {
t.Run("no credit", func(t *testing.T) { t.Run("no credit", func(t *testing.T) {
creditlessUserInfo := validUserInfo creditlessUserInfo := validUserInfo
creditlessUserInfo.UpCredit, creditlessUserInfo.DownCredit = -1, -1 creditlessUserInfo.UpCredit, creditlessUserInfo.DownCredit = JustInt64(-1), JustInt64(-1)
_ = mgr.WriteUserInfo(creditlessUserInfo) _ = mgr.WriteUserInfo(creditlessUserInfo)
@ -186,7 +196,7 @@ func TestLocalManager_AuthoriseNewSession(t *testing.T) {
t.Run("expired user", func(t *testing.T) { t.Run("expired user", func(t *testing.T) {
expiredUserInfo := validUserInfo expiredUserInfo := validUserInfo
expiredUserInfo.ExpiryTime = mockWorldState.Now().Add(-10 * time.Second).Unix() expiredUserInfo.ExpiryTime = JustInt64(mockWorldState.Now().Add(-10 * time.Second).Unix())
_ = mgr.WriteUserInfo(expiredUserInfo) _ = mgr.WriteUserInfo(expiredUserInfo)
err := mgr.AuthoriseNewSession(expiredUserInfo.UID, AuthorisationInfo{NumExistingSessions: 0}) err := mgr.AuthoriseNewSession(expiredUserInfo.UID, AuthorisationInfo{NumExistingSessions: 0})
@ -197,7 +207,7 @@ func TestLocalManager_AuthoriseNewSession(t *testing.T) {
t.Run("too many sessions", func(t *testing.T) { t.Run("too many sessions", func(t *testing.T) {
_ = mgr.WriteUserInfo(validUserInfo) _ = mgr.WriteUserInfo(validUserInfo)
err := mgr.AuthoriseNewSession(validUserInfo.UID, AuthorisationInfo{NumExistingSessions: int(validUserInfo.SessionsCap + 1)}) err := mgr.AuthoriseNewSession(validUserInfo.UID, AuthorisationInfo{NumExistingSessions: int(*validUserInfo.SessionsCap + 1)})
if err != ErrSessionsCapReached { if err != ErrSessionsCapReached {
t.Error("session cap not reached") t.Error("session cap not reached")
} }
@ -230,10 +240,10 @@ func TestLocalManager_UploadStatus(t *testing.T) {
t.Error(err) t.Error(err)
} }
if updatedUserInfo.UpCredit != validUserInfo.UpCredit-update.UpUsage { if *updatedUserInfo.UpCredit != *validUserInfo.UpCredit-update.UpUsage {
t.Error("up usage incorrect") t.Error("up usage incorrect")
} }
if updatedUserInfo.DownCredit != validUserInfo.DownCredit-update.DownUsage { if *updatedUserInfo.DownCredit != *validUserInfo.DownCredit-update.DownUsage {
t.Error("down usage incorrect") t.Error("down usage incorrect")
} }
}) })
@ -249,7 +259,7 @@ func TestLocalManager_UploadStatus(t *testing.T) {
UID: validUserInfo.UID, UID: validUserInfo.UID,
Active: true, Active: true,
NumSession: 1, NumSession: 1,
UpUsage: validUserInfo.UpCredit + 100, UpUsage: *validUserInfo.UpCredit + 100,
DownUsage: 0, DownUsage: 0,
Timestamp: mockWorldState.Now().Unix(), Timestamp: mockWorldState.Now().Unix(),
}, },
@ -261,19 +271,19 @@ func TestLocalManager_UploadStatus(t *testing.T) {
Active: true, Active: true,
NumSession: 1, NumSession: 1,
UpUsage: 0, UpUsage: 0,
DownUsage: validUserInfo.DownCredit + 100, DownUsage: *validUserInfo.DownCredit + 100,
Timestamp: mockWorldState.Now().Unix(), Timestamp: mockWorldState.Now().Unix(),
}, },
}, },
{"expired", {"expired",
UserInfo{ UserInfo{
UID: mockUID, UID: mockUID,
SessionsCap: 10, SessionsCap: JustInt32(10),
UpRate: 0, UpRate: JustInt64(0),
DownRate: 0, DownRate: JustInt64(0),
UpCredit: 0, UpCredit: JustInt64(0),
DownCredit: 0, DownCredit: JustInt64(0),
ExpiryTime: -1, ExpiryTime: JustInt64(-1),
}, },
StatusUpdate{ StatusUpdate{
UID: mockUserInfo.UID, UID: mockUserInfo.UID,
@ -318,12 +328,12 @@ func TestLocalManager_ListAllUsers(t *testing.T) {
rand.Read(randUID) rand.Read(randUID)
newUser := UserInfo{ newUser := UserInfo{
UID: randUID, UID: randUID,
SessionsCap: rand.Int31(), SessionsCap: JustInt32(rand.Int31()),
UpRate: rand.Int63(), UpRate: JustInt64(rand.Int63()),
DownRate: rand.Int63(), DownRate: JustInt64(rand.Int63()),
UpCredit: rand.Int63(), UpCredit: JustInt64(rand.Int63()),
DownCredit: rand.Int63(), DownCredit: JustInt64(rand.Int63()),
ExpiryTime: rand.Int63(), ExpiryTime: JustInt64(rand.Int63()),
} }
users = append(users, newUser) users = append(users, newUser)
wg.Add(1) wg.Add(1)

View File

@ -14,16 +14,23 @@ type StatusUpdate struct {
Timestamp int64 Timestamp int64
} }
type MaybeInt32 *int32
type MaybeInt64 *int64
type UserInfo struct { type UserInfo struct {
UID []byte UID []byte
SessionsCap int32 SessionsCap MaybeInt32
UpRate int64 UpRate MaybeInt64
DownRate int64 DownRate MaybeInt64
UpCredit int64 UpCredit MaybeInt64
DownCredit int64 DownCredit MaybeInt64
ExpiryTime int64 ExpiryTime MaybeInt64
} }
func JustInt32(v int32) MaybeInt32 { return &v }
func JustInt64(v int64) MaybeInt64 { return &v }
type StatusResponse struct { type StatusResponse struct {
UID []byte UID []byte
Action int Action int

View File

@ -66,12 +66,12 @@ var mockUID = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
var mockWorldState = common.WorldOfTime(time.Unix(1, 0)) var mockWorldState = common.WorldOfTime(time.Unix(1, 0))
var validUserInfo = usermanager.UserInfo{ var validUserInfo = usermanager.UserInfo{
UID: mockUID, UID: mockUID,
SessionsCap: 10, SessionsCap: usermanager.JustInt32(10),
UpRate: 100, UpRate: usermanager.JustInt64(100),
DownRate: 1000, DownRate: usermanager.JustInt64(1000),
UpCredit: 10000, UpCredit: usermanager.JustInt64(10000),
DownCredit: 100000, DownCredit: usermanager.JustInt64(100000),
ExpiryTime: 1000000, ExpiryTime: usermanager.JustInt64(1000000),
} }
func TestUserPanel_GetUser(t *testing.T) { func TestUserPanel_GetUser(t *testing.T) {
@ -138,10 +138,10 @@ func TestUserPanel_UpdateUsageQueue(t *testing.T) {
} }
updatedUinfo, _ := mgr.GetUserInfo(validUserInfo.UID) updatedUinfo, _ := mgr.GetUserInfo(validUserInfo.UID)
if updatedUinfo.DownCredit != validUserInfo.DownCredit-1 { if *updatedUinfo.DownCredit != *validUserInfo.DownCredit-1 {
t.Error("down credit incorrect update") t.Error("down credit incorrect update")
} }
if updatedUinfo.UpCredit != validUserInfo.UpCredit-2 { if *updatedUinfo.UpCredit != *validUserInfo.UpCredit-2 {
t.Error("up credit incorrect update") t.Error("up credit incorrect update")
} }
@ -155,10 +155,10 @@ func TestUserPanel_UpdateUsageQueue(t *testing.T) {
} }
updatedUinfo, _ = mgr.GetUserInfo(validUserInfo.UID) updatedUinfo, _ = mgr.GetUserInfo(validUserInfo.UID)
if updatedUinfo.DownCredit != validUserInfo.DownCredit-(1+3) { if *updatedUinfo.DownCredit != *validUserInfo.DownCredit-(1+3) {
t.Error("down credit incorrect update") t.Error("down credit incorrect update")
} }
if updatedUinfo.UpCredit != validUserInfo.UpCredit-(2+4) { if *updatedUinfo.UpCredit != *validUserInfo.UpCredit-(2+4) {
t.Error("up credit incorrect update") t.Error("up credit incorrect update")
} }
}) })
@ -170,7 +170,7 @@ func TestUserPanel_UpdateUsageQueue(t *testing.T) {
t.Error(err) t.Error(err)
} }
user.valve.AddTx(validUserInfo.DownCredit + 100) user.valve.AddTx(*validUserInfo.DownCredit + 100)
panel.updateUsageQueue() panel.updateUsageQueue()
err = panel.commitUpdate() err = panel.commitUpdate()
if err != nil { if err != nil {
@ -182,7 +182,7 @@ func TestUserPanel_UpdateUsageQueue(t *testing.T) {
} }
updatedUinfo, _ := mgr.GetUserInfo(validUserInfo.UID) updatedUinfo, _ := mgr.GetUserInfo(validUserInfo.UID)
if updatedUinfo.DownCredit != -100 { if *updatedUinfo.DownCredit != -100 {
t.Error("down credit not updated correctly after the user has been terminated") t.Error("down credit not updated correctly after the user has been terminated")
} }
}) })