mirror of
https://github.com/smallstep/certificates.git
synced 2024-10-31 03:20:16 +00:00
Fix authority/admin/api tests
This commit is contained in:
parent
2ab7dc6f9d
commit
a8a4261980
@ -29,6 +29,17 @@ func readProtoJSON(r io.ReadCloser, m proto.Message) error {
|
||||
return protojson.Unmarshal(data, m)
|
||||
}
|
||||
|
||||
func mockMustAuthority(t *testing.T, a adminAuthority) {
|
||||
t.Helper()
|
||||
fn := mustAuthority
|
||||
t.Cleanup(func() {
|
||||
mustAuthority = fn
|
||||
})
|
||||
mustAuthority = func(ctx context.Context) adminAuthority {
|
||||
return a
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandler_requireEABEnabled(t *testing.T) {
|
||||
type test struct {
|
||||
ctx context.Context
|
||||
@ -54,6 +65,7 @@ func TestHandler_requireEABEnabled(t *testing.T) {
|
||||
return test{
|
||||
ctx: ctx,
|
||||
auth: auth,
|
||||
adminDB: &admin.MockDB{},
|
||||
err: err,
|
||||
statusCode: 500,
|
||||
}
|
||||
@ -143,16 +155,12 @@ func TestHandler_requireEABEnabled(t *testing.T) {
|
||||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
adminDB: tc.adminDB,
|
||||
acmeDB: nil,
|
||||
}
|
||||
|
||||
mockMustAuthority(t, tc.auth)
|
||||
ctx := admin.NewContext(tc.ctx, tc.adminDB)
|
||||
req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup
|
||||
req = req.WithContext(tc.ctx)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.requireEABEnabled(tc.next)(w, req)
|
||||
requireEABEnabled(tc.next)(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
@ -194,6 +202,7 @@ func TestHandler_provisionerHasEABEnabled(t *testing.T) {
|
||||
}
|
||||
return test{
|
||||
auth: auth,
|
||||
adminDB: &admin.MockDB{},
|
||||
provisionerName: "provName",
|
||||
want: false,
|
||||
err: admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName"),
|
||||
@ -358,12 +367,9 @@ func TestHandler_provisionerHasEABEnabled(t *testing.T) {
|
||||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
adminDB: tc.adminDB,
|
||||
acmeDB: nil,
|
||||
}
|
||||
got, prov, err := h.provisionerHasEABEnabled(context.TODO(), tc.provisionerName)
|
||||
mockMustAuthority(t, tc.auth)
|
||||
ctx := admin.NewContext(context.Background(), tc.adminDB)
|
||||
got, prov, err := provisionerHasEABEnabled(ctx, tc.provisionerName)
|
||||
if (err != nil) != (tc.err != nil) {
|
||||
t.Errorf("Handler.provisionerHasEABEnabled() error = %v, want err %v", err, tc.err)
|
||||
return
|
||||
|
@ -317,14 +317,11 @@ func TestHandler_GetAdmin(t *testing.T) {
|
||||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
}
|
||||
|
||||
mockMustAuthority(t, tc.auth)
|
||||
req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.GetAdmin(w, req)
|
||||
GetAdmin(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
@ -456,13 +453,10 @@ func TestHandler_GetAdmins(t *testing.T) {
|
||||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
}
|
||||
|
||||
mockMustAuthority(t, tc.auth)
|
||||
req := tc.req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.GetAdmins(w, req)
|
||||
GetAdmins(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
@ -640,13 +634,11 @@ func TestHandler_CreateAdmin(t *testing.T) {
|
||||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
}
|
||||
mockMustAuthority(t, tc.auth)
|
||||
req := httptest.NewRequest("GET", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.CreateAdmin(w, req)
|
||||
CreateAdmin(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
@ -732,13 +724,11 @@ func TestHandler_DeleteAdmin(t *testing.T) {
|
||||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
}
|
||||
mockMustAuthority(t, tc.auth)
|
||||
req := httptest.NewRequest("DELETE", "/foo", nil) // chi routing is prepared in test setup
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.DeleteAdmin(w, req)
|
||||
DeleteAdmin(w, req)
|
||||
res := w.Result()
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
||||
@ -877,13 +867,11 @@ func TestHandler_UpdateAdmin(t *testing.T) {
|
||||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
}
|
||||
mockMustAuthority(t, tc.auth)
|
||||
req := httptest.NewRequest("GET", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.UpdateAdmin(w, req)
|
||||
UpdateAdmin(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
@ -64,13 +64,11 @@ func TestHandler_requireAPIEnabled(t *testing.T) {
|
||||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
}
|
||||
mockMustAuthority(t, tc.auth)
|
||||
req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.requireAPIEnabled(tc.next)(w, req)
|
||||
requireAPIEnabled(tc.next)(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
@ -194,13 +192,10 @@ func TestHandler_extractAuthorizeTokenAdmin(t *testing.T) {
|
||||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
}
|
||||
|
||||
mockMustAuthority(t, tc.auth)
|
||||
req := tc.req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.extractAuthorizeTokenAdmin(tc.next)(w, req)
|
||||
extractAuthorizeTokenAdmin(tc.next)(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
@ -47,6 +47,7 @@ func TestHandler_GetProvisioner(t *testing.T) {
|
||||
ctx: ctx,
|
||||
req: req,
|
||||
auth: auth,
|
||||
adminDB: &admin.MockDB{},
|
||||
statusCode: 500,
|
||||
err: &admin.Error{
|
||||
Type: admin.ErrorServerInternalType.String(),
|
||||
@ -71,6 +72,7 @@ func TestHandler_GetProvisioner(t *testing.T) {
|
||||
ctx: ctx,
|
||||
req: req,
|
||||
auth: auth,
|
||||
adminDB: &admin.MockDB{},
|
||||
statusCode: 500,
|
||||
err: &admin.Error{
|
||||
Type: admin.ErrorServerInternalType.String(),
|
||||
@ -153,13 +155,11 @@ func TestHandler_GetProvisioner(t *testing.T) {
|
||||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
adminDB: tc.adminDB,
|
||||
}
|
||||
req := tc.req.WithContext(tc.ctx)
|
||||
mockMustAuthority(t, tc.auth)
|
||||
ctx := admin.NewContext(tc.ctx, tc.adminDB)
|
||||
req := tc.req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.GetProvisioner(w, req)
|
||||
GetProvisioner(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
@ -277,12 +277,10 @@ func TestHandler_GetProvisioners(t *testing.T) {
|
||||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
}
|
||||
mockMustAuthority(t, tc.auth)
|
||||
req := tc.req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.GetProvisioners(w, req)
|
||||
GetProvisioners(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
@ -402,13 +400,11 @@ func TestHandler_CreateProvisioner(t *testing.T) {
|
||||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
}
|
||||
mockMustAuthority(t, tc.auth)
|
||||
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.CreateProvisioner(w, req)
|
||||
CreateProvisioner(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
@ -562,12 +558,10 @@ func TestHandler_DeleteProvisioner(t *testing.T) {
|
||||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
}
|
||||
mockMustAuthority(t, tc.auth)
|
||||
req := tc.req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.DeleteProvisioner(w, req)
|
||||
DeleteProvisioner(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
@ -616,6 +610,7 @@ func TestHandler_UpdateProvisioner(t *testing.T) {
|
||||
return test{
|
||||
ctx: context.Background(),
|
||||
body: body,
|
||||
adminDB: &admin.MockDB{},
|
||||
statusCode: 500,
|
||||
err: &admin.Error{ // TODO(hs): this probably needs a better error
|
||||
Type: "",
|
||||
@ -645,6 +640,7 @@ func TestHandler_UpdateProvisioner(t *testing.T) {
|
||||
return test{
|
||||
ctx: ctx,
|
||||
body: body,
|
||||
adminDB: &admin.MockDB{},
|
||||
auth: auth,
|
||||
statusCode: 500,
|
||||
err: &admin.Error{
|
||||
@ -1052,14 +1048,12 @@ func TestHandler_UpdateProvisioner(t *testing.T) {
|
||||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
adminDB: tc.adminDB,
|
||||
}
|
||||
mockMustAuthority(t, tc.auth)
|
||||
ctx := admin.NewContext(tc.ctx, tc.adminDB)
|
||||
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
||||
req = req.WithContext(tc.ctx)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.UpdateProvisioner(w, req)
|
||||
UpdateProvisioner(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
Loading…
Reference in New Issue
Block a user