diff --git a/authority/admin/api/acme_test.go b/authority/admin/api/acme_test.go index 6ffe1418..6b89b288 100644 --- a/authority/admin/api/acme_test.go +++ b/authority/admin/api/acme_test.go @@ -29,6 +29,17 @@ func readProtoJSON(r io.ReadCloser, m proto.Message) error { return protojson.Unmarshal(data, m) } +func mockMustAuthority(t *testing.T, a adminAuthority) { + t.Helper() + fn := mustAuthority + t.Cleanup(func() { + mustAuthority = fn + }) + mustAuthority = func(ctx context.Context) adminAuthority { + return a + } +} + func TestHandler_requireEABEnabled(t *testing.T) { type test struct { ctx context.Context @@ -54,6 +65,7 @@ func TestHandler_requireEABEnabled(t *testing.T) { return test{ ctx: ctx, auth: auth, + adminDB: &admin.MockDB{}, err: err, statusCode: 500, } @@ -143,16 +155,12 @@ func TestHandler_requireEABEnabled(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{ - auth: tc.auth, - adminDB: tc.adminDB, - acmeDB: nil, - } - + mockMustAuthority(t, tc.auth) + ctx := admin.NewContext(tc.ctx, tc.adminDB) req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() - h.requireEABEnabled(tc.next)(w, req) + requireEABEnabled(tc.next)(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) @@ -194,6 +202,7 @@ func TestHandler_provisionerHasEABEnabled(t *testing.T) { } return test{ auth: auth, + adminDB: &admin.MockDB{}, provisionerName: "provName", want: false, err: admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName"), @@ -358,12 +367,9 @@ func TestHandler_provisionerHasEABEnabled(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{ - auth: tc.auth, - adminDB: tc.adminDB, - acmeDB: nil, - } - got, prov, err := h.provisionerHasEABEnabled(context.TODO(), tc.provisionerName) + mockMustAuthority(t, tc.auth) + ctx := admin.NewContext(context.Background(), tc.adminDB) + got, prov, err := provisionerHasEABEnabled(ctx, tc.provisionerName) if (err != nil) != (tc.err != nil) { t.Errorf("Handler.provisionerHasEABEnabled() error = %v, want err %v", err, tc.err) return diff --git a/authority/admin/api/admin_test.go b/authority/admin/api/admin_test.go index 8d223b52..2f5528e1 100644 --- a/authority/admin/api/admin_test.go +++ b/authority/admin/api/admin_test.go @@ -317,14 +317,11 @@ func TestHandler_GetAdmin(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{ - auth: tc.auth, - } - + mockMustAuthority(t, tc.auth) req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup req = req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.GetAdmin(w, req) + GetAdmin(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) @@ -456,13 +453,10 @@ func TestHandler_GetAdmins(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{ - auth: tc.auth, - } - + mockMustAuthority(t, tc.auth) req := tc.req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.GetAdmins(w, req) + GetAdmins(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) @@ -640,13 +634,11 @@ func TestHandler_CreateAdmin(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{ - auth: tc.auth, - } + mockMustAuthority(t, tc.auth) req := httptest.NewRequest("GET", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.CreateAdmin(w, req) + CreateAdmin(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) @@ -732,13 +724,11 @@ func TestHandler_DeleteAdmin(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{ - auth: tc.auth, - } + mockMustAuthority(t, tc.auth) req := httptest.NewRequest("DELETE", "/foo", nil) // chi routing is prepared in test setup req = req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.DeleteAdmin(w, req) + DeleteAdmin(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) @@ -877,13 +867,11 @@ func TestHandler_UpdateAdmin(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{ - auth: tc.auth, - } + mockMustAuthority(t, tc.auth) req := httptest.NewRequest("GET", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.UpdateAdmin(w, req) + UpdateAdmin(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) diff --git a/authority/admin/api/middleware_test.go b/authority/admin/api/middleware_test.go index 7fb4671a..3445a3b5 100644 --- a/authority/admin/api/middleware_test.go +++ b/authority/admin/api/middleware_test.go @@ -64,13 +64,11 @@ func TestHandler_requireAPIEnabled(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{ - auth: tc.auth, - } + mockMustAuthority(t, tc.auth) req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup req = req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.requireAPIEnabled(tc.next)(w, req) + requireAPIEnabled(tc.next)(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) @@ -194,13 +192,10 @@ func TestHandler_extractAuthorizeTokenAdmin(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{ - auth: tc.auth, - } - + mockMustAuthority(t, tc.auth) req := tc.req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.extractAuthorizeTokenAdmin(tc.next)(w, req) + extractAuthorizeTokenAdmin(tc.next)(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) diff --git a/authority/admin/api/provisioner_test.go b/authority/admin/api/provisioner_test.go index 6d5024f2..6ee26dba 100644 --- a/authority/admin/api/provisioner_test.go +++ b/authority/admin/api/provisioner_test.go @@ -47,6 +47,7 @@ func TestHandler_GetProvisioner(t *testing.T) { ctx: ctx, req: req, auth: auth, + adminDB: &admin.MockDB{}, statusCode: 500, err: &admin.Error{ Type: admin.ErrorServerInternalType.String(), @@ -71,6 +72,7 @@ func TestHandler_GetProvisioner(t *testing.T) { ctx: ctx, req: req, auth: auth, + adminDB: &admin.MockDB{}, statusCode: 500, err: &admin.Error{ Type: admin.ErrorServerInternalType.String(), @@ -153,13 +155,11 @@ func TestHandler_GetProvisioner(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{ - auth: tc.auth, - adminDB: tc.adminDB, - } - req := tc.req.WithContext(tc.ctx) + mockMustAuthority(t, tc.auth) + ctx := admin.NewContext(tc.ctx, tc.adminDB) + req := tc.req.WithContext(ctx) w := httptest.NewRecorder() - h.GetProvisioner(w, req) + GetProvisioner(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) @@ -277,12 +277,10 @@ func TestHandler_GetProvisioners(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{ - auth: tc.auth, - } + mockMustAuthority(t, tc.auth) req := tc.req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.GetProvisioners(w, req) + GetProvisioners(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) @@ -402,13 +400,11 @@ func TestHandler_CreateProvisioner(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{ - auth: tc.auth, - } + mockMustAuthority(t, tc.auth) req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.CreateProvisioner(w, req) + CreateProvisioner(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) @@ -562,12 +558,10 @@ func TestHandler_DeleteProvisioner(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{ - auth: tc.auth, - } + mockMustAuthority(t, tc.auth) req := tc.req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.DeleteProvisioner(w, req) + DeleteProvisioner(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) @@ -616,6 +610,7 @@ func TestHandler_UpdateProvisioner(t *testing.T) { return test{ ctx: context.Background(), body: body, + adminDB: &admin.MockDB{}, statusCode: 500, err: &admin.Error{ // TODO(hs): this probably needs a better error Type: "", @@ -645,6 +640,7 @@ func TestHandler_UpdateProvisioner(t *testing.T) { return test{ ctx: ctx, body: body, + adminDB: &admin.MockDB{}, auth: auth, statusCode: 500, err: &admin.Error{ @@ -1052,14 +1048,12 @@ func TestHandler_UpdateProvisioner(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{ - auth: tc.auth, - adminDB: tc.adminDB, - } + mockMustAuthority(t, tc.auth) + ctx := admin.NewContext(tc.ctx, tc.adminDB) req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() - h.UpdateProvisioner(w, req) + UpdateProvisioner(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode)