Fix authority/admin/api tests

pull/914/head
Mariano Cano 2 years ago
parent 2ab7dc6f9d
commit a8a4261980

@ -29,6 +29,17 @@ func readProtoJSON(r io.ReadCloser, m proto.Message) error {
return protojson.Unmarshal(data, m)
}
func mockMustAuthority(t *testing.T, a adminAuthority) {
t.Helper()
fn := mustAuthority
t.Cleanup(func() {
mustAuthority = fn
})
mustAuthority = func(ctx context.Context) adminAuthority {
return a
}
}
func TestHandler_requireEABEnabled(t *testing.T) {
type test struct {
ctx context.Context
@ -54,6 +65,7 @@ func TestHandler_requireEABEnabled(t *testing.T) {
return test{
ctx: ctx,
auth: auth,
adminDB: &admin.MockDB{},
err: err,
statusCode: 500,
}
@ -143,16 +155,12 @@ func TestHandler_requireEABEnabled(t *testing.T) {
for name, prep := range tests {
tc := prep(t)
t.Run(name, func(t *testing.T) {
h := &Handler{
auth: tc.auth,
adminDB: tc.adminDB,
acmeDB: nil,
}
mockMustAuthority(t, tc.auth)
ctx := admin.NewContext(tc.ctx, tc.adminDB)
req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup
req = req.WithContext(tc.ctx)
req = req.WithContext(ctx)
w := httptest.NewRecorder()
h.requireEABEnabled(tc.next)(w, req)
requireEABEnabled(tc.next)(w, req)
res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode)
@ -194,6 +202,7 @@ func TestHandler_provisionerHasEABEnabled(t *testing.T) {
}
return test{
auth: auth,
adminDB: &admin.MockDB{},
provisionerName: "provName",
want: false,
err: admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName"),
@ -358,12 +367,9 @@ func TestHandler_provisionerHasEABEnabled(t *testing.T) {
for name, prep := range tests {
tc := prep(t)
t.Run(name, func(t *testing.T) {
h := &Handler{
auth: tc.auth,
adminDB: tc.adminDB,
acmeDB: nil,
}
got, prov, err := h.provisionerHasEABEnabled(context.TODO(), tc.provisionerName)
mockMustAuthority(t, tc.auth)
ctx := admin.NewContext(context.Background(), tc.adminDB)
got, prov, err := provisionerHasEABEnabled(ctx, tc.provisionerName)
if (err != nil) != (tc.err != nil) {
t.Errorf("Handler.provisionerHasEABEnabled() error = %v, want err %v", err, tc.err)
return

@ -317,14 +317,11 @@ func TestHandler_GetAdmin(t *testing.T) {
for name, prep := range tests {
tc := prep(t)
t.Run(name, func(t *testing.T) {
h := &Handler{
auth: tc.auth,
}
mockMustAuthority(t, tc.auth)
req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
h.GetAdmin(w, req)
GetAdmin(w, req)
res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode)
@ -456,13 +453,10 @@ func TestHandler_GetAdmins(t *testing.T) {
for name, prep := range tests {
tc := prep(t)
t.Run(name, func(t *testing.T) {
h := &Handler{
auth: tc.auth,
}
mockMustAuthority(t, tc.auth)
req := tc.req.WithContext(tc.ctx)
w := httptest.NewRecorder()
h.GetAdmins(w, req)
GetAdmins(w, req)
res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode)
@ -640,13 +634,11 @@ func TestHandler_CreateAdmin(t *testing.T) {
for name, prep := range tests {
tc := prep(t)
t.Run(name, func(t *testing.T) {
h := &Handler{
auth: tc.auth,
}
mockMustAuthority(t, tc.auth)
req := httptest.NewRequest("GET", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
h.CreateAdmin(w, req)
CreateAdmin(w, req)
res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode)
@ -732,13 +724,11 @@ func TestHandler_DeleteAdmin(t *testing.T) {
for name, prep := range tests {
tc := prep(t)
t.Run(name, func(t *testing.T) {
h := &Handler{
auth: tc.auth,
}
mockMustAuthority(t, tc.auth)
req := httptest.NewRequest("DELETE", "/foo", nil) // chi routing is prepared in test setup
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
h.DeleteAdmin(w, req)
DeleteAdmin(w, req)
res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode)
@ -877,13 +867,11 @@ func TestHandler_UpdateAdmin(t *testing.T) {
for name, prep := range tests {
tc := prep(t)
t.Run(name, func(t *testing.T) {
h := &Handler{
auth: tc.auth,
}
mockMustAuthority(t, tc.auth)
req := httptest.NewRequest("GET", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
h.UpdateAdmin(w, req)
UpdateAdmin(w, req)
res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode)

@ -64,13 +64,11 @@ func TestHandler_requireAPIEnabled(t *testing.T) {
for name, prep := range tests {
tc := prep(t)
t.Run(name, func(t *testing.T) {
h := &Handler{
auth: tc.auth,
}
mockMustAuthority(t, tc.auth)
req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
h.requireAPIEnabled(tc.next)(w, req)
requireAPIEnabled(tc.next)(w, req)
res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode)
@ -194,13 +192,10 @@ func TestHandler_extractAuthorizeTokenAdmin(t *testing.T) {
for name, prep := range tests {
tc := prep(t)
t.Run(name, func(t *testing.T) {
h := &Handler{
auth: tc.auth,
}
mockMustAuthority(t, tc.auth)
req := tc.req.WithContext(tc.ctx)
w := httptest.NewRecorder()
h.extractAuthorizeTokenAdmin(tc.next)(w, req)
extractAuthorizeTokenAdmin(tc.next)(w, req)
res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode)

@ -47,6 +47,7 @@ func TestHandler_GetProvisioner(t *testing.T) {
ctx: ctx,
req: req,
auth: auth,
adminDB: &admin.MockDB{},
statusCode: 500,
err: &admin.Error{
Type: admin.ErrorServerInternalType.String(),
@ -71,6 +72,7 @@ func TestHandler_GetProvisioner(t *testing.T) {
ctx: ctx,
req: req,
auth: auth,
adminDB: &admin.MockDB{},
statusCode: 500,
err: &admin.Error{
Type: admin.ErrorServerInternalType.String(),
@ -153,13 +155,11 @@ func TestHandler_GetProvisioner(t *testing.T) {
for name, prep := range tests {
tc := prep(t)
t.Run(name, func(t *testing.T) {
h := &Handler{
auth: tc.auth,
adminDB: tc.adminDB,
}
req := tc.req.WithContext(tc.ctx)
mockMustAuthority(t, tc.auth)
ctx := admin.NewContext(tc.ctx, tc.adminDB)
req := tc.req.WithContext(ctx)
w := httptest.NewRecorder()
h.GetProvisioner(w, req)
GetProvisioner(w, req)
res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode)
@ -277,12 +277,10 @@ func TestHandler_GetProvisioners(t *testing.T) {
for name, prep := range tests {
tc := prep(t)
t.Run(name, func(t *testing.T) {
h := &Handler{
auth: tc.auth,
}
mockMustAuthority(t, tc.auth)
req := tc.req.WithContext(tc.ctx)
w := httptest.NewRecorder()
h.GetProvisioners(w, req)
GetProvisioners(w, req)
res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode)
@ -402,13 +400,11 @@ func TestHandler_CreateProvisioner(t *testing.T) {
for name, prep := range tests {
tc := prep(t)
t.Run(name, func(t *testing.T) {
h := &Handler{
auth: tc.auth,
}
mockMustAuthority(t, tc.auth)
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
h.CreateProvisioner(w, req)
CreateProvisioner(w, req)
res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode)
@ -562,12 +558,10 @@ func TestHandler_DeleteProvisioner(t *testing.T) {
for name, prep := range tests {
tc := prep(t)
t.Run(name, func(t *testing.T) {
h := &Handler{
auth: tc.auth,
}
mockMustAuthority(t, tc.auth)
req := tc.req.WithContext(tc.ctx)
w := httptest.NewRecorder()
h.DeleteProvisioner(w, req)
DeleteProvisioner(w, req)
res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode)
@ -616,6 +610,7 @@ func TestHandler_UpdateProvisioner(t *testing.T) {
return test{
ctx: context.Background(),
body: body,
adminDB: &admin.MockDB{},
statusCode: 500,
err: &admin.Error{ // TODO(hs): this probably needs a better error
Type: "",
@ -645,6 +640,7 @@ func TestHandler_UpdateProvisioner(t *testing.T) {
return test{
ctx: ctx,
body: body,
adminDB: &admin.MockDB{},
auth: auth,
statusCode: 500,
err: &admin.Error{
@ -1052,14 +1048,12 @@ func TestHandler_UpdateProvisioner(t *testing.T) {
for name, prep := range tests {
tc := prep(t)
t.Run(name, func(t *testing.T) {
h := &Handler{
auth: tc.auth,
adminDB: tc.adminDB,
}
mockMustAuthority(t, tc.auth)
ctx := admin.NewContext(tc.ctx, tc.adminDB)
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
req = req.WithContext(tc.ctx)
req = req.WithContext(ctx)
w := httptest.NewRecorder()
h.UpdateProvisioner(w, req)
UpdateProvisioner(w, req)
res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode)

Loading…
Cancel
Save