mirror of
https://github.com/smallstep/certificates.git
synced 2024-10-31 03:20:16 +00:00
Fix authority/admin/api tests
This commit is contained in:
parent
2ab7dc6f9d
commit
a8a4261980
@ -29,6 +29,17 @@ func readProtoJSON(r io.ReadCloser, m proto.Message) error {
|
|||||||
return protojson.Unmarshal(data, m)
|
return protojson.Unmarshal(data, m)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func mockMustAuthority(t *testing.T, a adminAuthority) {
|
||||||
|
t.Helper()
|
||||||
|
fn := mustAuthority
|
||||||
|
t.Cleanup(func() {
|
||||||
|
mustAuthority = fn
|
||||||
|
})
|
||||||
|
mustAuthority = func(ctx context.Context) adminAuthority {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestHandler_requireEABEnabled(t *testing.T) {
|
func TestHandler_requireEABEnabled(t *testing.T) {
|
||||||
type test struct {
|
type test struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
@ -54,6 +65,7 @@ func TestHandler_requireEABEnabled(t *testing.T) {
|
|||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
auth: auth,
|
auth: auth,
|
||||||
|
adminDB: &admin.MockDB{},
|
||||||
err: err,
|
err: err,
|
||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
}
|
}
|
||||||
@ -143,16 +155,12 @@ func TestHandler_requireEABEnabled(t *testing.T) {
|
|||||||
for name, prep := range tests {
|
for name, prep := range tests {
|
||||||
tc := prep(t)
|
tc := prep(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{
|
mockMustAuthority(t, tc.auth)
|
||||||
auth: tc.auth,
|
ctx := admin.NewContext(tc.ctx, tc.adminDB)
|
||||||
adminDB: tc.adminDB,
|
|
||||||
acmeDB: nil,
|
|
||||||
}
|
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup
|
req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.requireEABEnabled(tc.next)(w, req)
|
requireEABEnabled(tc.next)(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||||
@ -194,6 +202,7 @@ func TestHandler_provisionerHasEABEnabled(t *testing.T) {
|
|||||||
}
|
}
|
||||||
return test{
|
return test{
|
||||||
auth: auth,
|
auth: auth,
|
||||||
|
adminDB: &admin.MockDB{},
|
||||||
provisionerName: "provName",
|
provisionerName: "provName",
|
||||||
want: false,
|
want: false,
|
||||||
err: admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName"),
|
err: admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName"),
|
||||||
@ -358,12 +367,9 @@ func TestHandler_provisionerHasEABEnabled(t *testing.T) {
|
|||||||
for name, prep := range tests {
|
for name, prep := range tests {
|
||||||
tc := prep(t)
|
tc := prep(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{
|
mockMustAuthority(t, tc.auth)
|
||||||
auth: tc.auth,
|
ctx := admin.NewContext(context.Background(), tc.adminDB)
|
||||||
adminDB: tc.adminDB,
|
got, prov, err := provisionerHasEABEnabled(ctx, tc.provisionerName)
|
||||||
acmeDB: nil,
|
|
||||||
}
|
|
||||||
got, prov, err := h.provisionerHasEABEnabled(context.TODO(), tc.provisionerName)
|
|
||||||
if (err != nil) != (tc.err != nil) {
|
if (err != nil) != (tc.err != nil) {
|
||||||
t.Errorf("Handler.provisionerHasEABEnabled() error = %v, want err %v", err, tc.err)
|
t.Errorf("Handler.provisionerHasEABEnabled() error = %v, want err %v", err, tc.err)
|
||||||
return
|
return
|
||||||
|
@ -317,14 +317,11 @@ func TestHandler_GetAdmin(t *testing.T) {
|
|||||||
for name, prep := range tests {
|
for name, prep := range tests {
|
||||||
tc := prep(t)
|
tc := prep(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{
|
mockMustAuthority(t, tc.auth)
|
||||||
auth: tc.auth,
|
|
||||||
}
|
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup
|
req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(tc.ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.GetAdmin(w, req)
|
GetAdmin(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||||
@ -456,13 +453,10 @@ func TestHandler_GetAdmins(t *testing.T) {
|
|||||||
for name, prep := range tests {
|
for name, prep := range tests {
|
||||||
tc := prep(t)
|
tc := prep(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{
|
mockMustAuthority(t, tc.auth)
|
||||||
auth: tc.auth,
|
|
||||||
}
|
|
||||||
|
|
||||||
req := tc.req.WithContext(tc.ctx)
|
req := tc.req.WithContext(tc.ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.GetAdmins(w, req)
|
GetAdmins(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||||
@ -640,13 +634,11 @@ func TestHandler_CreateAdmin(t *testing.T) {
|
|||||||
for name, prep := range tests {
|
for name, prep := range tests {
|
||||||
tc := prep(t)
|
tc := prep(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{
|
mockMustAuthority(t, tc.auth)
|
||||||
auth: tc.auth,
|
|
||||||
}
|
|
||||||
req := httptest.NewRequest("GET", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
req := httptest.NewRequest("GET", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(tc.ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.CreateAdmin(w, req)
|
CreateAdmin(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||||
@ -732,13 +724,11 @@ func TestHandler_DeleteAdmin(t *testing.T) {
|
|||||||
for name, prep := range tests {
|
for name, prep := range tests {
|
||||||
tc := prep(t)
|
tc := prep(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{
|
mockMustAuthority(t, tc.auth)
|
||||||
auth: tc.auth,
|
|
||||||
}
|
|
||||||
req := httptest.NewRequest("DELETE", "/foo", nil) // chi routing is prepared in test setup
|
req := httptest.NewRequest("DELETE", "/foo", nil) // chi routing is prepared in test setup
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(tc.ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.DeleteAdmin(w, req)
|
DeleteAdmin(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||||
|
|
||||||
@ -877,13 +867,11 @@ func TestHandler_UpdateAdmin(t *testing.T) {
|
|||||||
for name, prep := range tests {
|
for name, prep := range tests {
|
||||||
tc := prep(t)
|
tc := prep(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{
|
mockMustAuthority(t, tc.auth)
|
||||||
auth: tc.auth,
|
|
||||||
}
|
|
||||||
req := httptest.NewRequest("GET", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
req := httptest.NewRequest("GET", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(tc.ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.UpdateAdmin(w, req)
|
UpdateAdmin(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||||
|
@ -64,13 +64,11 @@ func TestHandler_requireAPIEnabled(t *testing.T) {
|
|||||||
for name, prep := range tests {
|
for name, prep := range tests {
|
||||||
tc := prep(t)
|
tc := prep(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{
|
mockMustAuthority(t, tc.auth)
|
||||||
auth: tc.auth,
|
|
||||||
}
|
|
||||||
req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup
|
req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(tc.ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.requireAPIEnabled(tc.next)(w, req)
|
requireAPIEnabled(tc.next)(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||||
@ -194,13 +192,10 @@ func TestHandler_extractAuthorizeTokenAdmin(t *testing.T) {
|
|||||||
for name, prep := range tests {
|
for name, prep := range tests {
|
||||||
tc := prep(t)
|
tc := prep(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{
|
mockMustAuthority(t, tc.auth)
|
||||||
auth: tc.auth,
|
|
||||||
}
|
|
||||||
|
|
||||||
req := tc.req.WithContext(tc.ctx)
|
req := tc.req.WithContext(tc.ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.extractAuthorizeTokenAdmin(tc.next)(w, req)
|
extractAuthorizeTokenAdmin(tc.next)(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||||
|
@ -47,6 +47,7 @@ func TestHandler_GetProvisioner(t *testing.T) {
|
|||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
req: req,
|
req: req,
|
||||||
auth: auth,
|
auth: auth,
|
||||||
|
adminDB: &admin.MockDB{},
|
||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
err: &admin.Error{
|
err: &admin.Error{
|
||||||
Type: admin.ErrorServerInternalType.String(),
|
Type: admin.ErrorServerInternalType.String(),
|
||||||
@ -71,6 +72,7 @@ func TestHandler_GetProvisioner(t *testing.T) {
|
|||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
req: req,
|
req: req,
|
||||||
auth: auth,
|
auth: auth,
|
||||||
|
adminDB: &admin.MockDB{},
|
||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
err: &admin.Error{
|
err: &admin.Error{
|
||||||
Type: admin.ErrorServerInternalType.String(),
|
Type: admin.ErrorServerInternalType.String(),
|
||||||
@ -153,13 +155,11 @@ func TestHandler_GetProvisioner(t *testing.T) {
|
|||||||
for name, prep := range tests {
|
for name, prep := range tests {
|
||||||
tc := prep(t)
|
tc := prep(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{
|
mockMustAuthority(t, tc.auth)
|
||||||
auth: tc.auth,
|
ctx := admin.NewContext(tc.ctx, tc.adminDB)
|
||||||
adminDB: tc.adminDB,
|
req := tc.req.WithContext(ctx)
|
||||||
}
|
|
||||||
req := tc.req.WithContext(tc.ctx)
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.GetProvisioner(w, req)
|
GetProvisioner(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||||
@ -277,12 +277,10 @@ func TestHandler_GetProvisioners(t *testing.T) {
|
|||||||
for name, prep := range tests {
|
for name, prep := range tests {
|
||||||
tc := prep(t)
|
tc := prep(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{
|
mockMustAuthority(t, tc.auth)
|
||||||
auth: tc.auth,
|
|
||||||
}
|
|
||||||
req := tc.req.WithContext(tc.ctx)
|
req := tc.req.WithContext(tc.ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.GetProvisioners(w, req)
|
GetProvisioners(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||||
@ -402,13 +400,11 @@ func TestHandler_CreateProvisioner(t *testing.T) {
|
|||||||
for name, prep := range tests {
|
for name, prep := range tests {
|
||||||
tc := prep(t)
|
tc := prep(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{
|
mockMustAuthority(t, tc.auth)
|
||||||
auth: tc.auth,
|
|
||||||
}
|
|
||||||
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(tc.ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.CreateProvisioner(w, req)
|
CreateProvisioner(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||||
@ -562,12 +558,10 @@ func TestHandler_DeleteProvisioner(t *testing.T) {
|
|||||||
for name, prep := range tests {
|
for name, prep := range tests {
|
||||||
tc := prep(t)
|
tc := prep(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{
|
mockMustAuthority(t, tc.auth)
|
||||||
auth: tc.auth,
|
|
||||||
}
|
|
||||||
req := tc.req.WithContext(tc.ctx)
|
req := tc.req.WithContext(tc.ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.DeleteProvisioner(w, req)
|
DeleteProvisioner(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||||
@ -616,6 +610,7 @@ func TestHandler_UpdateProvisioner(t *testing.T) {
|
|||||||
return test{
|
return test{
|
||||||
ctx: context.Background(),
|
ctx: context.Background(),
|
||||||
body: body,
|
body: body,
|
||||||
|
adminDB: &admin.MockDB{},
|
||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
err: &admin.Error{ // TODO(hs): this probably needs a better error
|
err: &admin.Error{ // TODO(hs): this probably needs a better error
|
||||||
Type: "",
|
Type: "",
|
||||||
@ -645,6 +640,7 @@ func TestHandler_UpdateProvisioner(t *testing.T) {
|
|||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
body: body,
|
body: body,
|
||||||
|
adminDB: &admin.MockDB{},
|
||||||
auth: auth,
|
auth: auth,
|
||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
err: &admin.Error{
|
err: &admin.Error{
|
||||||
@ -1052,14 +1048,12 @@ func TestHandler_UpdateProvisioner(t *testing.T) {
|
|||||||
for name, prep := range tests {
|
for name, prep := range tests {
|
||||||
tc := prep(t)
|
tc := prep(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{
|
mockMustAuthority(t, tc.auth)
|
||||||
auth: tc.auth,
|
ctx := admin.NewContext(tc.ctx, tc.adminDB)
|
||||||
adminDB: tc.adminDB,
|
|
||||||
}
|
|
||||||
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.UpdateProvisioner(w, req)
|
UpdateProvisioner(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||||
|
Loading…
Reference in New Issue
Block a user