From 6f9d847bc6489f7669997edd0e6db5dcb0b9e2d1 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Mon, 2 May 2022 17:35:35 -0700 Subject: [PATCH] Fix panic in acme/api tests. --- acme/api/account_test.go | 78 ++++++------ acme/api/eab_test.go | 48 +++---- acme/api/handler.go | 1 - acme/api/handler_test.go | 93 ++++++++------ acme/api/middleware.go | 66 ++++------ acme/api/middleware_test.go | 241 ++++++++++++------------------------ acme/api/order.go | 18 ++- acme/api/order_test.go | 112 +++++++++-------- acme/api/revoke.go | 7 +- acme/api/revoke_test.go | 62 +++++----- 10 files changed, 333 insertions(+), 393 deletions(-) diff --git a/acme/api/account_test.go b/acme/api/account_test.go index 3fbabfe5..18d24ab6 100644 --- a/acme/api/account_test.go +++ b/acme/api/account_test.go @@ -296,10 +296,9 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) { }, "ok": func(t *testing.T) test { acc := &acme.Account{ID: accID} - ctx := context.WithValue(context.Background(), accContextKey, acc) - ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + ctx = acme.NewProvisionerContext(ctx, prov) + ctx = context.WithValue(ctx, accContextKey, acc) return test{ db: &acme.MockDB{ MockGetOrdersByAccountID: func(ctx context.Context, id string) ([]string, error) { @@ -315,9 +314,9 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - // h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")} + ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil) req := httptest.NewRequest("GET", u, nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() GetOrdersByAccountID(w, req) res := w.Result() @@ -363,6 +362,7 @@ func TestHandler_NewAccount(t *testing.T) { var tests = map[string]func(t *testing.T) test{ "fail/no-payload": func(t *testing.T) test { return test{ + db: &acme.MockDB{}, ctx: context.Background(), statusCode: 500, err: acme.NewErrorISE("payload expected in request context"), @@ -371,6 +371,7 @@ func TestHandler_NewAccount(t *testing.T) { "fail/nil-payload": func(t *testing.T) test { ctx := context.WithValue(context.Background(), payloadContextKey, nil) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("payload expected in request context"), @@ -379,6 +380,7 @@ func TestHandler_NewAccount(t *testing.T) { "fail/unmarshal-payload-error": func(t *testing.T) test { ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{}) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "failed to "+ @@ -393,6 +395,7 @@ func TestHandler_NewAccount(t *testing.T) { assert.FatalError(t, err) ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string"), @@ -405,8 +408,9 @@ func TestHandler_NewAccount(t *testing.T) { b, err := json.Marshal(nar) assert.FatalError(t, err) ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), @@ -418,9 +422,10 @@ func TestHandler_NewAccount(t *testing.T) { } b, err := json.Marshal(nar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("jwk expected in request context"), @@ -432,10 +437,11 @@ func TestHandler_NewAccount(t *testing.T) { } b, err := json.Marshal(nar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, jwkContextKey, nil) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("jwk expected in request context"), @@ -454,9 +460,9 @@ func TestHandler_NewAccount(t *testing.T) { prov.RequireEAB = true ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorExternalAccountRequiredType, "no external account binding provided"), @@ -471,7 +477,7 @@ func TestHandler_NewAccount(t *testing.T) { jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwkContextKey, jwk) return test{ db: &acme.MockDB{ @@ -510,9 +516,9 @@ func TestHandler_NewAccount(t *testing.T) { } ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, scepProvisioner) + ctx = acme.NewProvisionerContext(ctx, scepProvisioner) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewError(acme.ErrorServerInternalType, "provisioner in context is not an ACME provisioner"), @@ -551,8 +557,7 @@ func TestHandler_NewAccount(t *testing.T) { prov.RequireEAB = true ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) eak := &acme.ExternalAccountKey{ ID: "eakID", @@ -599,8 +604,7 @@ func TestHandler_NewAccount(t *testing.T) { assert.FatalError(t, err) ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) return test{ db: &acme.MockDB{ MockCreateAccount: func(ctx context.Context, acc *acme.Account) error { @@ -635,11 +639,11 @@ func TestHandler_NewAccount(t *testing.T) { Status: acme.StatusValid, Contact: []string{"foo", "bar"}, } - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, accContextKey, acc) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ + db: &acme.MockDB{}, ctx: ctx, acc: acc, statusCode: 200, @@ -664,8 +668,7 @@ func TestHandler_NewAccount(t *testing.T) { prov.RequireEAB = false ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) return test{ db: &acme.MockDB{ MockCreateAccount: func(ctx context.Context, acc *acme.Account) error { @@ -719,8 +722,7 @@ func TestHandler_NewAccount(t *testing.T) { prov.RequireEAB = true ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ db: &acme.MockDB{ @@ -759,9 +761,9 @@ func TestHandler_NewAccount(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - // h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")} + ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil) req := httptest.NewRequest("GET", "/foo/bar", nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() NewAccount(w, req) res := w.Result() @@ -814,6 +816,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) { var tests = map[string]func(t *testing.T) test{ "fail/no-account": func(t *testing.T) test { return test{ + db: &acme.MockDB{}, ctx: context.Background(), statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), @@ -822,6 +825,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) { "fail/nil-account": func(t *testing.T) test { ctx := context.WithValue(context.Background(), accContextKey, nil) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), @@ -830,6 +834,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) { "fail/no-payload": func(t *testing.T) test { ctx := context.WithValue(context.Background(), accContextKey, &acc) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("payload expected in request context"), @@ -839,6 +844,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) { ctx := context.WithValue(context.Background(), accContextKey, &acc) ctx = context.WithValue(ctx, payloadContextKey, nil) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("payload expected in request context"), @@ -848,6 +854,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) { ctx := context.WithValue(context.Background(), accContextKey, &acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{}) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "failed to unmarshal new-account request payload: unexpected end of JSON input"), @@ -862,6 +869,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) { ctx := context.WithValue(context.Background(), accContextKey, &acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string"), @@ -894,10 +902,9 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) { } b, err := json.Marshal(uar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, &acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ db: &acme.MockDB{ MockUpdateAccount: func(ctx context.Context, upd *acme.Account) error { @@ -914,11 +921,11 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) { uar := &UpdateAccountRequest{} b, err := json.Marshal(uar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, &acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 200, } @@ -929,10 +936,9 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) { } b, err := json.Marshal(uar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, &acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ db: &acme.MockDB{ MockUpdateAccount: func(ctx context.Context, upd *acme.Account) error { @@ -946,11 +952,11 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) { } }, "ok/post-as-get": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, &acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isPostAsGet: true}) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 200, } @@ -959,9 +965,9 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - // h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")} + ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil) req := httptest.NewRequest("GET", "/foo/bar", nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() GetOrUpdateAccount(w, req) res := w.Result() diff --git a/acme/api/eab_test.go b/acme/api/eab_test.go index 1c76618b..ae47a1b9 100644 --- a/acme/api/eab_test.go +++ b/acme/api/eab_test.go @@ -98,8 +98,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { assert.FatalError(t, err) prov := newACMEProv(t) ctx := context.WithValue(context.Background(), jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) return test{ db: &acme.MockDB{}, ctx: ctx, @@ -143,8 +142,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { prov := newACMEProv(t) prov.RequireEAB = true ctx := context.WithValue(context.Background(), jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) createdAt := time.Now() return test{ @@ -198,8 +196,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { } ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, scepProvisioner) + ctx = acme.NewProvisionerContext(ctx, scepProvisioner) return test{ ctx: ctx, err: acme.NewError(acme.ErrorServerInternalType, "could not load ACME provisioner from context: provisioner in context is not an ACME provisioner"), @@ -218,8 +215,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { prov := newACMEProv(t) prov.RequireEAB = true ctx := context.WithValue(context.Background(), jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) return test{ db: &acme.MockDB{}, ctx: ctx, @@ -264,8 +260,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { prov := newACMEProv(t) prov.RequireEAB = true ctx := context.WithValue(context.Background(), jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ db: &acme.MockDB{}, @@ -310,8 +305,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { prov := newACMEProv(t) prov.RequireEAB = true ctx := context.WithValue(context.Background(), jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ db: &acme.MockDB{ @@ -358,8 +352,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { prov := newACMEProv(t) prov.RequireEAB = true ctx := context.WithValue(context.Background(), jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ db: &acme.MockDB{ @@ -408,8 +401,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { prov := newACMEProv(t) prov.RequireEAB = true ctx := context.WithValue(context.Background(), jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ db: &acme.MockDB{ @@ -458,8 +450,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { prov := newACMEProv(t) prov.RequireEAB = true ctx := context.WithValue(context.Background(), jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ db: &acme.MockDB{ @@ -506,8 +497,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { prov := newACMEProv(t) prov.RequireEAB = true ctx := context.WithValue(context.Background(), jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) createdAt := time.Now() boundAt := time.Now().Add(1 * time.Second) @@ -565,8 +555,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { prov := newACMEProv(t) prov.RequireEAB = true ctx := context.WithValue(context.Background(), jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ db: &acme.MockDB{ @@ -623,8 +612,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { prov := newACMEProv(t) prov.RequireEAB = true ctx := context.WithValue(context.Background(), jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ db: &acme.MockDB{ @@ -678,8 +666,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { assert.FatalError(t, err) prov := newACMEProv(t) prov.RequireEAB = true - ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ db: &acme.MockDB{ @@ -734,8 +721,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { prov := newACMEProv(t) prov.RequireEAB = true ctx := context.WithValue(context.Background(), jwkContextKey, nil) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ db: &acme.MockDB{ @@ -762,10 +748,8 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - // h := &Handler{ - // db: tc.db, - // } - got, err := validateExternalAccountBinding(tc.ctx, tc.nar) + ctx := acme.NewDatabaseContext(tc.ctx, tc.db) + got, err := validateExternalAccountBinding(ctx, tc.nar) wantErr := tc.err != nil gotErr := err != nil if wantErr != gotErr { diff --git a/acme/api/handler.go b/acme/api/handler.go index efe2b780..f6d79031 100644 --- a/acme/api/handler.go +++ b/acme/api/handler.go @@ -223,7 +223,6 @@ func (d *Directory) ToLog() (interface{}, error) { func GetDirectory(w http.ResponseWriter, r *http.Request) { ctx := r.Context() acmeProv, err := acmeProvisionerFromContext(ctx) - fmt.Println(acmeProv, err) if err != nil { render.Error(w, err) return diff --git a/acme/api/handler_test.go b/acme/api/handler_test.go index fcc33a87..2ac41228 100644 --- a/acme/api/handler_test.go +++ b/acme/api/handler_test.go @@ -3,6 +3,7 @@ package api import ( "bytes" "context" + "crypto/tls" "crypto/x509" "encoding/json" "encoding/pem" @@ -24,6 +25,29 @@ import ( "go.step.sm/crypto/pemutil" ) +type mockClient struct { + get func(url string) (*http.Response, error) + lookupTxt func(name string) ([]string, error) + tlsDial func(network, addr string, config *tls.Config) (*tls.Conn, error) +} + +func (m *mockClient) Get(url string) (*http.Response, error) { return m.get(url) } +func (m *mockClient) LookupTxt(name string) ([]string, error) { return m.lookupTxt(name) } +func (m *mockClient) TLSDial(network, addr string, config *tls.Config) (*tls.Conn, error) { + return m.tlsDial(network, addr, config) +} + +func mockMustAuthority(t *testing.T, a acme.CertificateAuthority) { + t.Helper() + fn := mustAuthority + t.Cleanup(func() { + mustAuthority = fn + }) + mustAuthority = func(ctx context.Context) acme.CertificateAuthority { + return a + } +} + func TestHandler_GetNonce(t *testing.T) { tests := []struct { name string @@ -52,7 +76,7 @@ func TestHandler_GetNonce(t *testing.T) { } func TestHandler_GetDirectory(t *testing.T) { - linker := NewLinker("ca.smallstep.com", "acme") + linker := acme.NewLinker("ca.smallstep.com", "acme") _ = linker type test struct { ctx context.Context @@ -62,13 +86,10 @@ func TestHandler_GetDirectory(t *testing.T) { } var tests = map[string]func(t *testing.T) test{ "fail/no-provisioner": func(t *testing.T) test { - baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - ctx := context.WithValue(context.Background(), provisionerContextKey, nil) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - ctx: ctx, + ctx: context.Background(), statusCode: 500, - err: acme.NewErrorISE("provisioner in context is not an ACME provisioner"), + err: acme.NewErrorISE("provisioner is not in context"), } }, "fail/different-provisioner": func(t *testing.T) test { @@ -76,9 +97,7 @@ func TestHandler_GetDirectory(t *testing.T) { Type: "SCEP", Name: "test@scep-provisioner.com", } - baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx := acme.NewProvisionerContext(context.Background(), prov) return test{ ctx: ctx, statusCode: 500, @@ -89,8 +108,7 @@ func TestHandler_GetDirectory(t *testing.T) { prov := newProv() provName := url.PathEscape(prov.GetName()) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx := acme.NewProvisionerContext(context.Background(), prov) expDir := Directory{ NewNonce: fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName), NewAccount: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName), @@ -109,8 +127,7 @@ func TestHandler_GetDirectory(t *testing.T) { prov.RequireEAB = true provName := url.PathEscape(prov.GetName()) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx := acme.NewProvisionerContext(context.Background(), prov) expDir := Directory{ NewNonce: fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName), NewAccount: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName), @@ -131,9 +148,9 @@ func TestHandler_GetDirectory(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - // h := &Handler{linker: linker} + ctx := acme.NewLinkerContext(tc.ctx, acme.NewLinker("test.ca.smallstep.com", "acme")) req := httptest.NewRequest("GET", "/foo/bar", nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() GetDirectory(w, req) res := w.Result() @@ -220,7 +237,7 @@ func TestHandler_GetAuthorization(t *testing.T) { } }, "fail/nil-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, nil) return test{ db: &acme.MockDB{}, @@ -286,10 +303,9 @@ func TestHandler_GetAuthorization(t *testing.T) { }, "ok": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ db: &acme.MockDB{ MockGetAuthorization: func(ctx context.Context, id string) (*acme.Authorization, error) { @@ -305,9 +321,9 @@ func TestHandler_GetAuthorization(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - // h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")} + ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil) req := httptest.NewRequest("GET", "/foo/bar", nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() GetAuthorization(w, req) res := w.Result() @@ -448,9 +464,9 @@ func TestHandler_GetCertificate(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - // h := &Handler{db: tc.db} + ctx := acme.NewDatabaseContext(tc.ctx, tc.db) req := httptest.NewRequest("GET", u, nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() GetCertificate(w, req) res := w.Result() @@ -492,7 +508,7 @@ func TestHandler_GetChallenge(t *testing.T) { type test struct { db acme.DB - vco *acme.ValidateChallengeOptions + vc acme.Client ctx context.Context statusCode int ch *acme.Challenge @@ -501,6 +517,7 @@ func TestHandler_GetChallenge(t *testing.T) { var tests = map[string]func(t *testing.T) test{ "fail/no-account": func(t *testing.T) test { return test{ + db: &acme.MockDB{}, ctx: context.Background(), statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), @@ -508,6 +525,7 @@ func TestHandler_GetChallenge(t *testing.T) { }, "fail/nil-account": func(t *testing.T) test { return test{ + db: &acme.MockDB{}, ctx: context.WithValue(context.Background(), accContextKey, nil), statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), @@ -517,6 +535,7 @@ func TestHandler_GetChallenge(t *testing.T) { acc := &acme.Account{ID: "accID"} ctx := context.WithValue(context.Background(), accContextKey, acc) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("payload expected in request context"), @@ -524,10 +543,11 @@ func TestHandler_GetChallenge(t *testing.T) { }, "fail/nil-payload": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, nil) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("payload expected in request context"), @@ -535,7 +555,7 @@ func TestHandler_GetChallenge(t *testing.T) { }, "fail/db.GetChallenge-error": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) @@ -554,7 +574,7 @@ func TestHandler_GetChallenge(t *testing.T) { }, "fail/account-id-mismatch": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) @@ -573,7 +593,7 @@ func TestHandler_GetChallenge(t *testing.T) { }, "fail/no-jwk": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) @@ -592,7 +612,7 @@ func TestHandler_GetChallenge(t *testing.T) { }, "fail/nil-jwk": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) ctx = context.WithValue(ctx, jwkContextKey, nil) @@ -612,7 +632,7 @@ func TestHandler_GetChallenge(t *testing.T) { }, "fail/validate-challenge-error": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) _jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) @@ -640,8 +660,8 @@ func TestHandler_GetChallenge(t *testing.T) { return acme.NewErrorISE("force") }, }, - vco: &acme.ValidateChallengeOptions{ - HTTPGet: func(string) (*http.Response, error) { + vc: &mockClient{ + get: func(string) (*http.Response, error) { return nil, errors.New("force") }, }, @@ -652,14 +672,13 @@ func TestHandler_GetChallenge(t *testing.T) { }, "ok": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) _jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) _pub := _jwk.Public() ctx = context.WithValue(ctx, jwkContextKey, &_pub) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ db: &acme.MockDB{ @@ -691,8 +710,8 @@ func TestHandler_GetChallenge(t *testing.T) { URL: u, Error: acme.NewError(acme.ErrorConnectionType, "force"), }, - vco: &acme.ValidateChallengeOptions{ - HTTPGet: func(string) (*http.Response, error) { + vc: &mockClient{ + get: func(string) (*http.Response, error) { return nil, errors.New("force") }, }, @@ -704,9 +723,9 @@ func TestHandler_GetChallenge(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - // h := &Handler{db: tc.db, linker: NewLinker("dns", "acme"), validateChallengeOptions: tc.vco} + ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil) req := httptest.NewRequest("GET", u, nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() GetChallenge(w, req) res := w.Result() diff --git a/acme/api/middleware.go b/acme/api/middleware.go index 09e88b8d..a254a83b 100644 --- a/acme/api/middleware.go +++ b/acme/api/middleware.go @@ -9,7 +9,6 @@ import ( "net/url" "strings" - "github.com/go-chi/chi" "go.step.sm/crypto/jose" "go.step.sm/crypto/keyutil" @@ -63,7 +62,12 @@ func addDirLink(next nextHTTP) nextHTTP { // application/jose+json. func verifyContentType(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { - p := acme.MustProvisionerFromContext(r.Context()) + p, err := provisionerFromContext(r.Context()) + if err != nil { + render.Error(w, err) + return + } + u := &url.URL{ Path: acme.GetUnescapedPathSuffix(acme.CertificateLinkType, p.GetName(), ""), } @@ -260,32 +264,6 @@ func extractJWK(next nextHTTP) nextHTTP { } } -// lookupProvisioner loads the provisioner associated with the request. -// Responds 404 if the provisioner does not exist. -func lookupProvisioner(next nextHTTP) nextHTTP { - return func(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - nameEscaped := chi.URLParam(r, "provisionerID") - name, err := url.PathUnescape(nameEscaped) - if err != nil { - render.Error(w, acme.WrapErrorISE(err, "error url unescaping provisioner name '%s'", nameEscaped)) - return - } - p, err := mustAuthority(r.Context()).LoadProvisionerByName(name) - if err != nil { - render.Error(w, err) - return - } - acmeProv, ok := p.(*provisioner.ACME) - if !ok { - render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "provisioner must be of type ACME")) - return - } - ctx = context.WithValue(ctx, provisionerContextKey, acme.Provisioner(acmeProv)) - next(w, r.WithContext(ctx)) - } -} - // checkPrerequisites checks if all prerequisites for serving ACME // are met by the CA configuration. func checkPrerequisites(next nextHTTP) nextHTTP { @@ -446,16 +424,12 @@ type ContextKey string const ( // accContextKey account key accContextKey = ContextKey("acc") - // baseURLContextKey baseURL key - baseURLContextKey = ContextKey("baseURL") // jwsContextKey jws key jwsContextKey = ContextKey("jws") // jwkContextKey jwk key jwkContextKey = ContextKey("jwk") // payloadContextKey payload key payloadContextKey = ContextKey("payload") - // provisionerContextKey provisioner key - provisionerContextKey = ContextKey("provisioner") ) // accountFromContext searches the context for an ACME account. Returns the @@ -468,15 +442,6 @@ func accountFromContext(ctx context.Context) (*acme.Account, error) { return val, nil } -// baseURLFromContext returns the baseURL if one is stored in the context. -func baseURLFromContext(ctx context.Context) *url.URL { - val, ok := ctx.Value(baseURLContextKey).(*url.URL) - if !ok || val == nil { - return nil - } - return val -} - // jwkFromContext searches the context for a JWK. Returns the JWK or an error. func jwkFromContext(ctx context.Context) (*jose.JSONWebKey, error) { val, ok := ctx.Value(jwkContextKey).(*jose.JSONWebKey) @@ -495,14 +460,29 @@ func jwsFromContext(ctx context.Context) (*jose.JSONWebSignature, error) { return val, nil } +// provisionerFromContext searches the context for a provisioner. Returns the +// provisioner or an error. +func provisionerFromContext(ctx context.Context) (acme.Provisioner, error) { + p, ok := acme.ProvisionerFromContext(ctx) + if !ok || p == nil { + return nil, acme.NewErrorISE("provisioner expected in request context") + } + return p, nil +} + // acmeProvisionerFromContext searches the context for an ACME provisioner. Returns // pointer to an ACME provisioner or an error. func acmeProvisionerFromContext(ctx context.Context) (*provisioner.ACME, error) { - p, ok := acme.MustProvisionerFromContext(ctx).(*provisioner.ACME) + p, err := provisionerFromContext(ctx) + if err != nil { + return nil, err + } + ap, ok := p.(*provisioner.ACME) if !ok { return nil, acme.NewErrorISE("provisioner in context is not an ACME provisioner") } - return p, nil + + return ap, nil } // payloadFromContext searches the context for a payload. Returns the payload diff --git a/acme/api/middleware_test.go b/acme/api/middleware_test.go index f192e67e..39a696ae 100644 --- a/acme/api/middleware_test.go +++ b/acme/api/middleware_test.go @@ -27,83 +27,18 @@ func testNext(w http.ResponseWriter, r *http.Request) { w.Write(testBody) } -func Test_baseURLFromRequest(t *testing.T) { - tests := []struct { - name string - targetURL string - expectedResult *url.URL - requestPreparer func(*http.Request) - }{ - { - "HTTPS host pass-through failed.", - "https://my.dummy.host", - &url.URL{Scheme: "https", Host: "my.dummy.host"}, - nil, - }, - { - "Port pass-through failed", - "https://host.with.port:8080", - &url.URL{Scheme: "https", Host: "host.with.port:8080"}, - nil, - }, - { - "Explicit host from Request.Host was not used.", - "https://some.target.host:8080", - &url.URL{Scheme: "https", Host: "proxied.host"}, - func(r *http.Request) { - r.Host = "proxied.host" - }, - }, - { - "Missing Request.Host value did not result in empty string result.", - "https://some.host", - nil, - func(r *http.Request) { - r.Host = "" - }, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - request := httptest.NewRequest("GET", tc.targetURL, nil) - if tc.requestPreparer != nil { - tc.requestPreparer(request) - } - result := getBaseURLFromRequest(request) - if result == nil || tc.expectedResult == nil { - assert.Equals(t, result, tc.expectedResult) - } else if result.String() != tc.expectedResult.String() { - t.Errorf("Expected %q, but got %q", tc.expectedResult.String(), result.String()) - } - }) - } -} - -func TestHandler_baseURLFromRequest(t *testing.T) { - // h := &Handler{} - req := httptest.NewRequest("GET", "/foo", nil) - req.Host = "test.ca.smallstep.com:8080" - w := httptest.NewRecorder() - - next := func(w http.ResponseWriter, r *http.Request) { - bu := baseURLFromContext(r.Context()) - if assert.NotNil(t, bu) { - assert.Equals(t, bu.Host, "test.ca.smallstep.com:8080") - assert.Equals(t, bu.Scheme, "https") +func newBaseContext(ctx context.Context, args ...interface{}) context.Context { + for _, a := range args { + switch v := a.(type) { + case acme.DB: + ctx = acme.NewDatabaseContext(ctx, v) + case acme.Linker: + ctx = acme.NewLinkerContext(ctx, v) + case acme.PrerequisitesChecker: + ctx = acme.NewPrerequisitesCheckerContext(ctx, v) } } - - baseURLFromRequest(next)(w, req) - - req = httptest.NewRequest("GET", "/foo", nil) - req.Host = "" - - next = func(w http.ResponseWriter, r *http.Request) { - assert.Equals(t, baseURLFromContext(r.Context()), nil) - } - - baseURLFromRequest(next)(w, req) + return ctx } func TestHandler_addNonce(t *testing.T) { @@ -139,8 +74,8 @@ func TestHandler_addNonce(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - // h := &Handler{db: tc.db} - req := httptest.NewRequest("GET", u, nil) + ctx := newBaseContext(context.Background(), tc.db) + req := httptest.NewRequest("GET", u, nil).WithContext(ctx) w := httptest.NewRecorder() addNonce(testNext)(w, req) res := w.Result() @@ -175,17 +110,15 @@ func TestHandler_addDirLink(t *testing.T) { baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} type test struct { link string - linker Linker statusCode int ctx context.Context err *acme.Error } var tests = map[string]func(t *testing.T) test{ "ok": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx := acme.NewProvisionerContext(context.Background(), prov) + ctx = acme.NewLinkerContext(ctx, acme.NewLinker("test.ca.smallstep.com", "acme")) return test{ - linker: NewLinker("dns", "acme"), ctx: ctx, link: fmt.Sprintf("%s/acme/%s/directory", baseURL.String(), provName), statusCode: 200, @@ -195,7 +128,6 @@ func TestHandler_addDirLink(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - // h := &Handler{linker: tc.linker} req := httptest.NewRequest("GET", "/foo", nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() @@ -231,7 +163,6 @@ func TestHandler_verifyContentType(t *testing.T) { baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} u := fmt.Sprintf("%s/acme/%s/certificate/abc123", baseURL.String(), escProvName) type test struct { - h Handler ctx context.Context contentType string err *acme.Error @@ -241,9 +172,6 @@ func TestHandler_verifyContentType(t *testing.T) { var tests = map[string]func(t *testing.T) test{ "fail/provisioner-not-set": func(t *testing.T) test { return test{ - h: Handler{ - // linker: NewLinker("dns", "acme"), - }, url: u, ctx: context.Background(), contentType: "foo", @@ -253,11 +181,8 @@ func TestHandler_verifyContentType(t *testing.T) { }, "fail/general-bad-content-type": func(t *testing.T) test { return test{ - h: Handler{ - // linker: NewLinker("dns", "acme"), - }, url: u, - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + ctx: acme.NewProvisionerContext(context.Background(), prov), contentType: "foo", statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "expected content-type to be in [application/jose+json], but got foo"), @@ -265,10 +190,7 @@ func TestHandler_verifyContentType(t *testing.T) { }, "fail/certificate-bad-content-type": func(t *testing.T) test { return test{ - h: Handler{ - // linker: NewLinker("dns", "acme"), - }, - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + ctx: acme.NewProvisionerContext(context.Background(), prov), contentType: "foo", statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "expected content-type to be in [application/jose+json application/pkix-cert application/pkcs7-mime], but got foo"), @@ -276,40 +198,28 @@ func TestHandler_verifyContentType(t *testing.T) { }, "ok": func(t *testing.T) test { return test{ - h: Handler{ - // linker: NewLinker("dns", "acme"), - }, - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + ctx: acme.NewProvisionerContext(context.Background(), prov), contentType: "application/jose+json", statusCode: 200, } }, "ok/certificate/pkix-cert": func(t *testing.T) test { return test{ - h: Handler{ - // linker: NewLinker("dns", "acme"), - }, - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + ctx: acme.NewProvisionerContext(context.Background(), prov), contentType: "application/pkix-cert", statusCode: 200, } }, "ok/certificate/jose+json": func(t *testing.T) test { return test{ - h: Handler{ - // linker: NewLinker("dns", "acme"), - }, - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + ctx: acme.NewProvisionerContext(context.Background(), prov), contentType: "application/jose+json", statusCode: 200, } }, "ok/certificate/pkcs7-mime": func(t *testing.T) test { return test{ - h: Handler{ - // linker: NewLinker("dns", "acme"), - }, - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + ctx: acme.NewProvisionerContext(context.Background(), prov), contentType: "application/pkcs7-mime", statusCode: 200, } @@ -733,7 +643,7 @@ func TestHandler_lookupJWK(t *testing.T) { parsedJWS, err := jose.ParseJWS(raw) assert.FatalError(t, err) type test struct { - linker Linker + linker acme.Linker db acme.DB ctx context.Context next func(http.ResponseWriter, *http.Request) @@ -743,15 +653,19 @@ func TestHandler_lookupJWK(t *testing.T) { var tests = map[string]func(t *testing.T) test{ "fail/no-jws": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + db: &acme.MockDB{}, + linker: acme.NewLinker("test.ca.smallstep.com", "acme"), + ctx: acme.NewProvisionerContext(context.Background(), prov), statusCode: 500, err: acme.NewErrorISE("jws expected in request context"), } }, "fail/nil-jws": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, nil) return test{ + db: &acme.MockDB{}, + linker: acme.NewLinker("test.ca.smallstep.com", "acme"), ctx: ctx, statusCode: 500, err: acme.NewErrorISE("jws expected in request context"), @@ -765,11 +679,11 @@ func TestHandler_lookupJWK(t *testing.T) { assert.FatalError(t, err) _jws, err := _signer.Sign([]byte("baz")) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, _jws) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - linker: NewLinker("dns", "acme"), + db: &acme.MockDB{}, + linker: acme.NewLinker("test.ca.smallstep.com", "acme"), ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "kid does not have required prefix; expected %s, but got ", prefix), @@ -789,22 +703,21 @@ func TestHandler_lookupJWK(t *testing.T) { assert.FatalError(t, err) _parsed, err := jose.ParseJWS(_raw) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, _parsed) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - linker: NewLinker("dns", "acme"), + db: &acme.MockDB{}, + linker: acme.NewLinker("test.ca.smallstep.com", "acme"), ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "kid does not have required prefix; expected %s, but got foo", prefix), } }, "fail/account-not-found": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - linker: NewLinker("dns", "acme"), + linker: acme.NewLinker("test.ca.smallstep.com", "acme"), db: &acme.MockDB{ MockGetAccount: func(ctx context.Context, accID string) (*acme.Account, error) { assert.Equals(t, accID, accID) @@ -817,11 +730,10 @@ func TestHandler_lookupJWK(t *testing.T) { } }, "fail/GetAccount-error": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - linker: NewLinker("dns", "acme"), + linker: acme.NewLinker("test.ca.smallstep.com", "acme"), db: &acme.MockDB{ MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) { assert.Equals(t, id, accID) @@ -835,11 +747,10 @@ func TestHandler_lookupJWK(t *testing.T) { }, "fail/account-not-valid": func(t *testing.T) test { acc := &acme.Account{Status: "deactivated"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - linker: NewLinker("dns", "acme"), + linker: acme.NewLinker("test.ca.smallstep.com", "acme"), db: &acme.MockDB{ MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) { assert.Equals(t, id, accID) @@ -853,11 +764,10 @@ func TestHandler_lookupJWK(t *testing.T) { }, "ok": func(t *testing.T) test { acc := &acme.Account{Status: "valid", Key: jwk} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - linker: NewLinker("dns", "acme"), + linker: acme.NewLinker("test.ca.smallstep.com", "acme"), db: &acme.MockDB{ MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) { assert.Equals(t, id, accID) @@ -881,9 +791,9 @@ func TestHandler_lookupJWK(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - // h := &Handler{db: tc.db, linker: tc.linker} + ctx := newBaseContext(tc.ctx, tc.db, tc.linker) req := httptest.NewRequest("GET", u, nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() lookupJWK(tc.next)(w, req) res := w.Result() @@ -945,15 +855,17 @@ func TestHandler_extractJWK(t *testing.T) { var tests = map[string]func(t *testing.T) test{ "fail/no-jws": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + db: &acme.MockDB{}, + ctx: acme.NewProvisionerContext(context.Background(), prov), statusCode: 500, err: acme.NewErrorISE("jws expected in request context"), } }, "fail/nil-jws": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, nil) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("jws expected in request context"), @@ -969,9 +881,10 @@ func TestHandler_extractJWK(t *testing.T) { }, }, } - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, _jws) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "jwk expected in protected header"), @@ -987,16 +900,17 @@ func TestHandler_extractJWK(t *testing.T) { }, }, } - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, _jws) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header"), } }, "fail/GetAccountByKey-error": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ ctx: ctx, @@ -1012,7 +926,7 @@ func TestHandler_extractJWK(t *testing.T) { }, "fail/account-not-valid": func(t *testing.T) test { acc := &acme.Account{Status: "deactivated"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ ctx: ctx, @@ -1028,7 +942,7 @@ func TestHandler_extractJWK(t *testing.T) { }, "ok": func(t *testing.T) test { acc := &acme.Account{Status: "valid"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ ctx: ctx, @@ -1051,7 +965,7 @@ func TestHandler_extractJWK(t *testing.T) { } }, "ok/no-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ ctx: ctx, @@ -1077,9 +991,9 @@ func TestHandler_extractJWK(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - // h := &Handler{db: tc.db} + ctx := newBaseContext(tc.ctx, tc.db) req := httptest.NewRequest("GET", u, nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() extractJWK(tc.next)(w, req) res := w.Result() @@ -1118,6 +1032,7 @@ func TestHandler_validateJWS(t *testing.T) { var tests = map[string]func(t *testing.T) test{ "fail/no-jws": func(t *testing.T) test { return test{ + db: &acme.MockDB{}, ctx: context.Background(), statusCode: 500, err: acme.NewErrorISE("jws expected in request context"), @@ -1125,6 +1040,7 @@ func TestHandler_validateJWS(t *testing.T) { }, "fail/nil-jws": func(t *testing.T) test { return test{ + db: &acme.MockDB{}, ctx: context.WithValue(context.Background(), jwsContextKey, nil), statusCode: 500, err: acme.NewErrorISE("jws expected in request context"), @@ -1132,6 +1048,7 @@ func TestHandler_validateJWS(t *testing.T) { }, "fail/no-signature": func(t *testing.T) test { return test{ + db: &acme.MockDB{}, ctx: context.WithValue(context.Background(), jwsContextKey, &jose.JSONWebSignature{}), statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "request body does not contain a signature"), @@ -1145,6 +1062,7 @@ func TestHandler_validateJWS(t *testing.T) { }, } return test{ + db: &acme.MockDB{}, ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "request body contains more than one signature"), @@ -1157,6 +1075,7 @@ func TestHandler_validateJWS(t *testing.T) { }, } return test{ + db: &acme.MockDB{}, ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "unprotected header must not be used"), @@ -1169,6 +1088,7 @@ func TestHandler_validateJWS(t *testing.T) { }, } return test{ + db: &acme.MockDB{}, ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 400, err: acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: none"), @@ -1181,6 +1101,7 @@ func TestHandler_validateJWS(t *testing.T) { }, } return test{ + db: &acme.MockDB{}, ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 400, err: acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: %s", jose.HS256), @@ -1444,9 +1365,9 @@ func TestHandler_validateJWS(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - // h := &Handler{db: tc.db} + ctx := newBaseContext(tc.ctx, tc.db) req := httptest.NewRequest("GET", u, nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() validateJWS(tc.next)(w, req) res := w.Result() @@ -1542,7 +1463,7 @@ func TestHandler_extractOrLookupJWK(t *testing.T) { u := "https://ca.smallstep.com/acme/account" type test struct { db acme.DB - linker Linker + linker acme.Linker statusCode int ctx context.Context err *acme.Error @@ -1570,7 +1491,7 @@ func TestHandler_extractOrLookupJWK(t *testing.T) { parsedJWS, err := jose.ParseJWS(raw) assert.FatalError(t, err) return test{ - linker: NewLinker("dns", "acme"), + linker: acme.NewLinker("dns", "acme"), db: &acme.MockDB{ MockGetAccountByKeyID: func(ctx context.Context, kid string) (*acme.Account, error) { assert.Equals(t, kid, pub.KeyID) @@ -1606,11 +1527,10 @@ func TestHandler_extractOrLookupJWK(t *testing.T) { parsedJWS, err := jose.ParseJWS(raw) assert.FatalError(t, err) acc := &acme.Account{ID: "accID", Key: jwk, Status: "valid"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ - linker: NewLinker("test.ca.smallstep.com", "acme"), + linker: acme.NewLinker("test.ca.smallstep.com", "acme"), db: &acme.MockDB{ MockGetAccount: func(ctx context.Context, accID string) (*acme.Account, error) { assert.Equals(t, accID, acc.ID) @@ -1628,9 +1548,9 @@ func TestHandler_extractOrLookupJWK(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - // h := &Handler{db: tc.db, linker: tc.linker} + ctx := newBaseContext(tc.ctx, tc.db, tc.linker) req := httptest.NewRequest("GET", u, nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() extractOrLookupJWK(tc.next)(w, req) res := w.Result() @@ -1664,7 +1584,7 @@ func TestHandler_checkPrerequisites(t *testing.T) { u := fmt.Sprintf("%s/acme/%s/account/1234", baseURL, provName) type test struct { - linker Linker + linker acme.Linker ctx context.Context prerequisitesChecker func(context.Context) (bool, error) next func(http.ResponseWriter, *http.Request) @@ -1673,10 +1593,9 @@ func TestHandler_checkPrerequisites(t *testing.T) { } var tests = map[string]func(t *testing.T) test{ "fail/error": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx := acme.NewProvisionerContext(context.Background(), prov) return test{ - linker: NewLinker("dns", "acme"), + linker: acme.NewLinker("dns", "acme"), ctx: ctx, prerequisitesChecker: func(context.Context) (bool, error) { return false, errors.New("force") }, next: func(w http.ResponseWriter, r *http.Request) { @@ -1687,10 +1606,9 @@ func TestHandler_checkPrerequisites(t *testing.T) { } }, "fail/prerequisites-nok": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx := acme.NewProvisionerContext(context.Background(), prov) return test{ - linker: NewLinker("dns", "acme"), + linker: acme.NewLinker("dns", "acme"), ctx: ctx, prerequisitesChecker: func(context.Context) (bool, error) { return false, nil }, next: func(w http.ResponseWriter, r *http.Request) { @@ -1701,10 +1619,9 @@ func TestHandler_checkPrerequisites(t *testing.T) { } }, "ok": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx := acme.NewProvisionerContext(context.Background(), prov) return test{ - linker: NewLinker("dns", "acme"), + linker: acme.NewLinker("dns", "acme"), ctx: ctx, prerequisitesChecker: func(context.Context) (bool, error) { return true, nil }, next: func(w http.ResponseWriter, r *http.Request) { diff --git a/acme/api/order.go b/acme/api/order.go index 2b9f912e..08718977 100644 --- a/acme/api/order.go +++ b/acme/api/order.go @@ -72,13 +72,17 @@ func NewOrder(w http.ResponseWriter, r *http.Request) { ctx := r.Context() db := acme.MustDatabaseFromContext(ctx) linker := acme.MustLinkerFromContext(ctx) - prov := acme.MustProvisionerFromContext(ctx) acc, err := accountFromContext(ctx) if err != nil { render.Error(w, err) return } + prov, err := provisionerFromContext(ctx) + if err != nil { + render.Error(w, err) + return + } payload, err := payloadFromContext(ctx) if err != nil { render.Error(w, err) @@ -189,13 +193,17 @@ func GetOrder(w http.ResponseWriter, r *http.Request) { ctx := r.Context() db := acme.MustDatabaseFromContext(ctx) linker := acme.MustLinkerFromContext(ctx) - prov := acme.MustProvisionerFromContext(ctx) acc, err := accountFromContext(ctx) if err != nil { render.Error(w, err) return } + prov, err := provisionerFromContext(ctx) + if err != nil { + render.Error(w, err) + return + } o, err := db.GetOrder(ctx, chi.URLParam(r, "ordID")) if err != nil { @@ -228,13 +236,17 @@ func FinalizeOrder(w http.ResponseWriter, r *http.Request) { ctx := r.Context() db := acme.MustDatabaseFromContext(ctx) linker := acme.MustLinkerFromContext(ctx) - prov := acme.MustProvisionerFromContext(ctx) acc, err := accountFromContext(ctx) if err != nil { render.Error(w, err) return } + prov, err := provisionerFromContext(ctx) + if err != nil { + render.Error(w, err) + return + } payload, err := payloadFromContext(ctx) if err != nil { render.Error(w, err) diff --git a/acme/api/order_test.go b/acme/api/order_test.go index f0a2d1d4..0ab76778 100644 --- a/acme/api/order_test.go +++ b/acme/api/order_test.go @@ -276,15 +276,17 @@ func TestHandler_GetOrder(t *testing.T) { var tests = map[string]func(t *testing.T) test{ "fail/no-account": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + db: &acme.MockDB{}, + ctx: acme.NewProvisionerContext(context.Background(), prov), statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/nil-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, nil) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), @@ -294,6 +296,7 @@ func TestHandler_GetOrder(t *testing.T) { acc := &acme.Account{ID: "accountID"} ctx := context.WithValue(context.Background(), accContextKey, acc) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("provisioner does not exist"), @@ -301,9 +304,10 @@ func TestHandler_GetOrder(t *testing.T) { }, "fail/nil-provisioner": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, nil) + ctx := acme.NewProvisionerContext(context.Background(), nil) ctx = context.WithValue(ctx, accContextKey, acc) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("provisioner does not exist"), @@ -311,7 +315,7 @@ func TestHandler_GetOrder(t *testing.T) { }, "fail/db.GetOrder-error": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ @@ -325,7 +329,7 @@ func TestHandler_GetOrder(t *testing.T) { }, "fail/account-id-mismatch": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ @@ -341,7 +345,7 @@ func TestHandler_GetOrder(t *testing.T) { }, "fail/provisioner-id-mismatch": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ @@ -357,7 +361,7 @@ func TestHandler_GetOrder(t *testing.T) { }, "fail/order-update-error": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ @@ -381,10 +385,9 @@ func TestHandler_GetOrder(t *testing.T) { }, "ok": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ db: &acme.MockDB{ MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) { @@ -421,9 +424,9 @@ func TestHandler_GetOrder(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - // h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db} + ctx := newBaseContext(tc.ctx, tc.db, acme.NewLinker("test.ca.smallstep.com", "acme")) req := httptest.NewRequest("GET", u, nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() GetOrder(w, req) res := w.Result() @@ -636,8 +639,8 @@ func TestHandler_newAuthorization(t *testing.T) { for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) - // h := &Handler{db: tc.db} - if err := newAuthorization(context.Background(), tc.az); err != nil { + ctx := newBaseContext(context.Background(), tc.db) + if err := newAuthorization(ctx, tc.az); err != nil { if assert.NotNil(t, tc.err) { switch k := err.(type) { case *acme.Error: @@ -677,15 +680,17 @@ func TestHandler_NewOrder(t *testing.T) { var tests = map[string]func(t *testing.T) test{ "fail/no-account": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + db: &acme.MockDB{}, + ctx: acme.NewProvisionerContext(context.Background(), prov), statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/nil-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, nil) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), @@ -695,6 +700,7 @@ func TestHandler_NewOrder(t *testing.T) { acc := &acme.Account{ID: "accountID"} ctx := context.WithValue(context.Background(), accContextKey, acc) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("provisioner does not exist"), @@ -702,9 +708,10 @@ func TestHandler_NewOrder(t *testing.T) { }, "fail/nil-provisioner": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, nil) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("provisioner does not exist"), @@ -713,8 +720,9 @@ func TestHandler_NewOrder(t *testing.T) { "fail/no-payload": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} ctx := context.WithValue(context.Background(), accContextKey, acc) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("payload does not exist"), @@ -722,10 +730,11 @@ func TestHandler_NewOrder(t *testing.T) { }, "fail/nil-payload": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, nil) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("paylod does not exist"), @@ -733,10 +742,11 @@ func TestHandler_NewOrder(t *testing.T) { }, "fail/unmarshal-payload-error": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{}) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "failed to unmarshal new-order request payload: unexpected end of JSON input"), @@ -747,10 +757,11 @@ func TestHandler_NewOrder(t *testing.T) { fr := &NewOrderRequest{} b, err := json.Marshal(fr) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "identifiers list cannot be empty"), @@ -765,7 +776,7 @@ func TestHandler_NewOrder(t *testing.T) { } b, err := json.Marshal(fr) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) return test{ @@ -793,7 +804,7 @@ func TestHandler_NewOrder(t *testing.T) { } b, err := json.Marshal(fr) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) var ( @@ -863,10 +874,9 @@ func TestHandler_NewOrder(t *testing.T) { } b, err := json.Marshal(nor) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) var ( ch1, ch2, ch3, ch4 **acme.Challenge az1ID, az2ID *string @@ -978,10 +988,9 @@ func TestHandler_NewOrder(t *testing.T) { } b, err := json.Marshal(nor) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) var ( ch1, ch2, ch3 **acme.Challenge az1ID *string @@ -1070,10 +1079,9 @@ func TestHandler_NewOrder(t *testing.T) { } b, err := json.Marshal(nor) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) var ( ch1, ch2, ch3 **acme.Challenge az1ID *string @@ -1161,10 +1169,9 @@ func TestHandler_NewOrder(t *testing.T) { } b, err := json.Marshal(nor) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) var ( ch1, ch2, ch3 **acme.Challenge az1ID *string @@ -1253,10 +1260,9 @@ func TestHandler_NewOrder(t *testing.T) { } b, err := json.Marshal(nor) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) var ( ch1, ch2, ch3 **acme.Challenge az1ID *string @@ -1334,9 +1340,9 @@ func TestHandler_NewOrder(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - // h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db} + ctx := newBaseContext(tc.ctx, tc.db, acme.NewLinker("test.ca.smallstep.com", "acme")) req := httptest.NewRequest("GET", u, nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() NewOrder(w, req) res := w.Result() @@ -1371,6 +1377,7 @@ func TestHandler_NewOrder(t *testing.T) { } func TestHandler_FinalizeOrder(t *testing.T) { + mockMustAuthority(t, &mockCA{}) prov := newProv() escProvName := url.PathEscape(prov.GetName()) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} @@ -1429,15 +1436,17 @@ func TestHandler_FinalizeOrder(t *testing.T) { var tests = map[string]func(t *testing.T) test{ "fail/no-account": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + db: &acme.MockDB{}, + ctx: acme.NewProvisionerContext(context.Background(), prov), statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/nil-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, nil) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), @@ -1447,6 +1456,7 @@ func TestHandler_FinalizeOrder(t *testing.T) { acc := &acme.Account{ID: "accountID"} ctx := context.WithValue(context.Background(), accContextKey, acc) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("provisioner does not exist"), @@ -1454,9 +1464,10 @@ func TestHandler_FinalizeOrder(t *testing.T) { }, "fail/nil-provisioner": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, nil) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("provisioner does not exist"), @@ -1465,8 +1476,9 @@ func TestHandler_FinalizeOrder(t *testing.T) { "fail/no-payload": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} ctx := context.WithValue(context.Background(), accContextKey, acc) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("payload does not exist"), @@ -1474,10 +1486,11 @@ func TestHandler_FinalizeOrder(t *testing.T) { }, "fail/nil-payload": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, nil) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("paylod does not exist"), @@ -1485,10 +1498,11 @@ func TestHandler_FinalizeOrder(t *testing.T) { }, "fail/unmarshal-payload-error": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{}) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "failed to unmarshal finalize-order request payload: unexpected end of JSON input"), @@ -1499,10 +1513,11 @@ func TestHandler_FinalizeOrder(t *testing.T) { fr := &FinalizeRequest{} b, err := json.Marshal(fr) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "unable to parse csr: asn1: syntax error: sequence truncated"), @@ -1511,7 +1526,7 @@ func TestHandler_FinalizeOrder(t *testing.T) { "fail/db.GetOrder-error": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) @@ -1526,7 +1541,7 @@ func TestHandler_FinalizeOrder(t *testing.T) { }, "fail/account-id-mismatch": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) @@ -1543,7 +1558,7 @@ func TestHandler_FinalizeOrder(t *testing.T) { }, "fail/provisioner-id-mismatch": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) @@ -1560,7 +1575,7 @@ func TestHandler_FinalizeOrder(t *testing.T) { }, "fail/order-finalize-error": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) @@ -1585,10 +1600,9 @@ func TestHandler_FinalizeOrder(t *testing.T) { }, "ok": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ db: &acme.MockDB{ @@ -1624,9 +1638,9 @@ func TestHandler_FinalizeOrder(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - // h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db} + ctx := newBaseContext(tc.ctx, tc.db, acme.NewLinker("test.ca.smallstep.com", "acme")) req := httptest.NewRequest("GET", u, nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() FinalizeOrder(w, req) res := w.Result() diff --git a/acme/api/revoke.go b/acme/api/revoke.go index 584ed27e..a8b98f3f 100644 --- a/acme/api/revoke.go +++ b/acme/api/revoke.go @@ -30,7 +30,6 @@ func RevokeCert(w http.ResponseWriter, r *http.Request) { ctx := r.Context() db := acme.MustDatabaseFromContext(ctx) linker := acme.MustLinkerFromContext(ctx) - prov := acme.MustProvisionerFromContext(ctx) jws, err := jwsFromContext(ctx) if err != nil { @@ -38,6 +37,12 @@ func RevokeCert(w http.ResponseWriter, r *http.Request) { return } + prov, err := provisionerFromContext(ctx) + if err != nil { + render.Error(w, err) + return + } + payload, err := payloadFromContext(ctx) if err != nil { render.Error(w, err) diff --git a/acme/api/revoke_test.go b/acme/api/revoke_test.go index 3a0ba70d..c746c11b 100644 --- a/acme/api/revoke_test.go +++ b/acme/api/revoke_test.go @@ -511,6 +511,7 @@ func TestHandler_RevokeCert(t *testing.T) { "fail/no-jws": func(t *testing.T) test { ctx := context.Background() return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("jws expected in request context"), @@ -519,6 +520,7 @@ func TestHandler_RevokeCert(t *testing.T) { "fail/nil-jws": func(t *testing.T) test { ctx := context.WithValue(context.Background(), jwsContextKey, nil) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("jws expected in request context"), @@ -527,6 +529,7 @@ func TestHandler_RevokeCert(t *testing.T) { "fail/no-provisioner": func(t *testing.T) test { ctx := context.WithValue(context.Background(), jwsContextKey, jws) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("provisioner does not exist"), @@ -534,8 +537,9 @@ func TestHandler_RevokeCert(t *testing.T) { }, "fail/nil-provisioner": func(t *testing.T) test { ctx := context.WithValue(context.Background(), jwsContextKey, jws) - ctx = context.WithValue(ctx, provisionerContextKey, nil) + ctx = acme.NewProvisionerContext(ctx, nil) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("provisioner does not exist"), @@ -543,8 +547,9 @@ func TestHandler_RevokeCert(t *testing.T) { }, "fail/no-payload": func(t *testing.T) test { ctx := context.WithValue(context.Background(), jwsContextKey, jws) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("payload does not exist"), @@ -552,9 +557,10 @@ func TestHandler_RevokeCert(t *testing.T) { }, "fail/nil-payload": func(t *testing.T) test { ctx := context.WithValue(context.Background(), jwsContextKey, jws) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, payloadContextKey, nil) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("payload does not exist"), @@ -563,9 +569,10 @@ func TestHandler_RevokeCert(t *testing.T) { "fail/unmarshal-payload": func(t *testing.T) test { malformedPayload := []byte(`{"payload":malformed?}`) ctx := context.WithValue(context.Background(), jwsContextKey, jws) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: malformedPayload}) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("error unmarshaling payload"), @@ -577,10 +584,11 @@ func TestHandler_RevokeCert(t *testing.T) { } wronglyEncodedPayloadBytes, err := json.Marshal(wrongPayload) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: wronglyEncodedPayloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: &acme.Error{ @@ -596,10 +604,11 @@ func TestHandler_RevokeCert(t *testing.T) { } emptyPayloadBytes, err := json.Marshal(emptyPayload) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: emptyPayloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: &acme.Error{ @@ -610,7 +619,7 @@ func TestHandler_RevokeCert(t *testing.T) { } }, "fail/db.GetCertificateBySerial": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) db := &acme.MockDB{ @@ -628,7 +637,7 @@ func TestHandler_RevokeCert(t *testing.T) { "fail/different-certificate-contents": func(t *testing.T) test { aDifferentCert, _, err := generateCertKeyPair() assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) db := &acme.MockDB{ @@ -647,7 +656,7 @@ func TestHandler_RevokeCert(t *testing.T) { } }, "fail/no-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) db := &acme.MockDB{ @@ -666,7 +675,7 @@ func TestHandler_RevokeCert(t *testing.T) { } }, "fail/nil-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) ctx = context.WithValue(ctx, accContextKey, nil) @@ -687,11 +696,10 @@ func TestHandler_RevokeCert(t *testing.T) { }, "fail/account-not-valid": func(t *testing.T) test { acc := &acme.Account{ID: "accountID", Status: acme.StatusInvalid} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) db := &acme.MockDB{ MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { @@ -717,11 +725,10 @@ func TestHandler_RevokeCert(t *testing.T) { }, "fail/account-not-authorized": func(t *testing.T) test { acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) db := &acme.MockDB{ MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { @@ -771,10 +778,9 @@ func TestHandler_RevokeCert(t *testing.T) { assert.FatalError(t, err) unauthorizedPayloadBytes, err := json.Marshal(jwsPayload) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: unauthorizedPayloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) db := &acme.MockDB{ MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { @@ -798,11 +804,10 @@ func TestHandler_RevokeCert(t *testing.T) { }, "fail/certificate-revoked-check-fails": func(t *testing.T) test { acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) db := &acme.MockDB{ MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { @@ -832,7 +837,7 @@ func TestHandler_RevokeCert(t *testing.T) { }, "fail/certificate-already-revoked": func(t *testing.T) test { acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) @@ -870,7 +875,7 @@ func TestHandler_RevokeCert(t *testing.T) { invalidReasonCodePayloadBytes, err := json.Marshal(invalidReasonPayload) assert.FatalError(t, err) acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: invalidReasonCodePayloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) @@ -908,7 +913,7 @@ func TestHandler_RevokeCert(t *testing.T) { }, } acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} - ctx := context.WithValue(context.Background(), provisionerContextKey, mockACMEProv) + ctx := acme.NewProvisionerContext(context.Background(), mockACMEProv) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) @@ -940,7 +945,7 @@ func TestHandler_RevokeCert(t *testing.T) { }, "fail/ca.Revoke": func(t *testing.T) test { acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) @@ -972,7 +977,7 @@ func TestHandler_RevokeCert(t *testing.T) { }, "fail/ca.Revoke-already-revoked": func(t *testing.T) test { acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) @@ -1003,11 +1008,10 @@ func TestHandler_RevokeCert(t *testing.T) { }, "ok/using-account-key": func(t *testing.T) test { acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) db := &acme.MockDB{ MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { @@ -1031,10 +1035,9 @@ func TestHandler_RevokeCert(t *testing.T) { assert.FatalError(t, err) jws, err := jose.ParseJWS(string(jwsBytes)) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) db := &acme.MockDB{ MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { @@ -1057,9 +1060,10 @@ func TestHandler_RevokeCert(t *testing.T) { for name, setup := range tests { tc := setup(t) t.Run(name, func(t *testing.T) { - // h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db, ca: tc.ca} + ctx := newBaseContext(tc.ctx, tc.db, acme.NewLinker("test.ca.smallstep.com", "acme")) + mockMustAuthority(t, tc.ca) req := httptest.NewRequest("POST", revokeURL, nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() RevokeCert(w, req) res := w.Result()