diff --git a/acme/api/account.go b/acme/api/account.go index 62d62f09..3114dcb3 100644 --- a/acme/api/account.go +++ b/acme/api/account.go @@ -82,23 +82,23 @@ func NewAccount(w http.ResponseWriter, r *http.Request) { payload, err := payloadFromContext(ctx) if err != nil { - render.Error(w, err) + render.Error(w, r, err) return } var nar NewAccountRequest if err := json.Unmarshal(payload.value, &nar); err != nil { - render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, + render.Error(w, r, acme.WrapError(acme.ErrorMalformedType, err, "failed to unmarshal new-account request payload")) return } if err := nar.Validate(); err != nil { - render.Error(w, err) + render.Error(w, r, err) return } prov, err := acmeProvisionerFromContext(ctx) if err != nil { - render.Error(w, err) + render.Error(w, r, err) return } @@ -108,26 +108,26 @@ func NewAccount(w http.ResponseWriter, r *http.Request) { var acmeErr *acme.Error if !errors.As(err, &acmeErr) || acmeErr.Status != http.StatusBadRequest { // Something went wrong ... - render.Error(w, err) + render.Error(w, r, err) return } // Account does not exist // if nar.OnlyReturnExisting { - render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, + render.Error(w, r, acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist")) return } jwk, err := jwkFromContext(ctx) if err != nil { - render.Error(w, err) + render.Error(w, r, err) return } eak, err := validateExternalAccountBinding(ctx, &nar) if err != nil { - render.Error(w, err) + render.Error(w, r, err) return } @@ -140,17 +140,17 @@ func NewAccount(w http.ResponseWriter, r *http.Request) { ProvisionerName: prov.Name, } if err := db.CreateAccount(ctx, acc); err != nil { - render.Error(w, acme.WrapErrorISE(err, "error creating account")) + render.Error(w, r, acme.WrapErrorISE(err, "error creating account")) return } if eak != nil { // means that we have a (valid) External Account Binding key that should be bound, updated and sent in the response if err := eak.BindTo(acc); err != nil { - render.Error(w, err) + render.Error(w, r, err) return } if err := db.UpdateExternalAccountKey(ctx, prov.ID, eak); err != nil { - render.Error(w, acme.WrapErrorISE(err, "error updating external account binding key")) + render.Error(w, r, acme.WrapErrorISE(err, "error updating external account binding key")) return } acc.ExternalAccountBinding = nar.ExternalAccountBinding @@ -163,7 +163,7 @@ func NewAccount(w http.ResponseWriter, r *http.Request) { linker.LinkAccount(ctx, acc) w.Header().Set("Location", getAccountLocationPath(ctx, linker, acc.ID)) - render.JSONStatus(w, acc, httpStatus) + render.JSONStatus(w, r, acc, httpStatus) } // GetOrUpdateAccount is the api for updating an ACME account. @@ -174,12 +174,12 @@ func GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) { acc, err := accountFromContext(ctx) if err != nil { - render.Error(w, err) + render.Error(w, r, err) return } payload, err := payloadFromContext(ctx) if err != nil { - render.Error(w, err) + render.Error(w, r, err) return } @@ -188,12 +188,12 @@ func GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) { if !payload.isPostAsGet { var uar UpdateAccountRequest if err := json.Unmarshal(payload.value, &uar); err != nil { - render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, + render.Error(w, r, acme.WrapError(acme.ErrorMalformedType, err, "failed to unmarshal new-account request payload")) return } if err := uar.Validate(); err != nil { - render.Error(w, err) + render.Error(w, r, err) return } if len(uar.Status) > 0 || len(uar.Contact) > 0 { @@ -204,7 +204,7 @@ func GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) { } if err := db.UpdateAccount(ctx, acc); err != nil { - render.Error(w, acme.WrapErrorISE(err, "error updating account")) + render.Error(w, r, acme.WrapErrorISE(err, "error updating account")) return } } @@ -213,7 +213,7 @@ func GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) { linker.LinkAccount(ctx, acc) w.Header().Set("Location", linker.GetLink(ctx, acme.AccountLinkType, acc.ID)) - render.JSON(w, acc) + render.JSON(w, r, acc) } func logOrdersByAccount(w http.ResponseWriter, oids []string) { @@ -233,23 +233,23 @@ func GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) { acc, err := accountFromContext(ctx) if err != nil { - render.Error(w, err) + render.Error(w, r, err) return } accID := chi.URLParam(r, "accID") if acc.ID != accID { - render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account ID '%s' does not match url param '%s'", acc.ID, accID)) + render.Error(w, r, acme.NewError(acme.ErrorUnauthorizedType, "account ID '%s' does not match url param '%s'", acc.ID, accID)) return } orders, err := db.GetOrdersByAccountID(ctx, acc.ID) if err != nil { - render.Error(w, err) + render.Error(w, r, err) return } linker.LinkOrdersByAccountID(ctx, orders) - render.JSON(w, orders) + render.JSON(w, r, orders) logOrdersByAccount(w, orders) } diff --git a/acme/api/handler.go b/acme/api/handler.go index d2940f49..0722bd9b 100644 --- a/acme/api/handler.go +++ b/acme/api/handler.go @@ -223,13 +223,13 @@ func GetDirectory(w http.ResponseWriter, r *http.Request) { ctx := r.Context() acmeProv, err := acmeProvisionerFromContext(ctx) if err != nil { - render.Error(w, err) + render.Error(w, r, err) return } linker := acme.MustLinkerFromContext(ctx) - render.JSON(w, &Directory{ + render.JSON(w, r, &Directory{ NewNonce: linker.GetLink(ctx, acme.NewNonceLinkType), NewAccount: linker.GetLink(ctx, acme.NewAccountLinkType), NewOrder: linker.GetLink(ctx, acme.NewOrderLinkType), @@ -273,8 +273,8 @@ func shouldAddMetaObject(p *provisioner.ACME) bool { // NotImplemented returns a 501 and is generally a placeholder for functionality which // MAY be added at some point in the future but is not in any way a guarantee of such. -func NotImplemented(w http.ResponseWriter, _ *http.Request) { - render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "this API is not implemented")) +func NotImplemented(w http.ResponseWriter, r *http.Request) { + render.Error(w, r, acme.NewError(acme.ErrorNotImplementedType, "this API is not implemented")) } // GetAuthorization ACME api for retrieving an Authz. @@ -285,28 +285,28 @@ func GetAuthorization(w http.ResponseWriter, r *http.Request) { acc, err := accountFromContext(ctx) if err != nil { - render.Error(w, err) + render.Error(w, r, err) return } az, err := db.GetAuthorization(ctx, chi.URLParam(r, "authzID")) if err != nil { - render.Error(w, acme.WrapErrorISE(err, "error retrieving authorization")) + render.Error(w, r, acme.WrapErrorISE(err, "error retrieving authorization")) return } if acc.ID != az.AccountID { - render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, + render.Error(w, r, acme.NewError(acme.ErrorUnauthorizedType, "account '%s' does not own authorization '%s'", acc.ID, az.ID)) return } if err = az.UpdateStatus(ctx, db); err != nil { - render.Error(w, acme.WrapErrorISE(err, "error updating authorization status")) + render.Error(w, r, acme.WrapErrorISE(err, "error updating authorization status")) return } linker.LinkAuthorization(ctx, az) w.Header().Set("Location", linker.GetLink(ctx, acme.AuthzLinkType, az.ID)) - render.JSON(w, az) + render.JSON(w, r, az) } // GetChallenge ACME api for retrieving a Challenge. @@ -317,13 +317,13 @@ func GetChallenge(w http.ResponseWriter, r *http.Request) { acc, err := accountFromContext(ctx) if err != nil { - render.Error(w, err) + render.Error(w, r, err) return } payload, err := payloadFromContext(ctx) if err != nil { - render.Error(w, err) + render.Error(w, r, err) return } @@ -336,22 +336,22 @@ func GetChallenge(w http.ResponseWriter, r *http.Request) { azID := chi.URLParam(r, "authzID") ch, err := db.GetChallenge(ctx, chi.URLParam(r, "chID"), azID) if err != nil { - render.Error(w, acme.WrapErrorISE(err, "error retrieving challenge")) + render.Error(w, r, acme.WrapErrorISE(err, "error retrieving challenge")) return } ch.AuthorizationID = azID if acc.ID != ch.AccountID { - render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, + render.Error(w, r, acme.NewError(acme.ErrorUnauthorizedType, "account '%s' does not own challenge '%s'", acc.ID, ch.ID)) return } jwk, err := jwkFromContext(ctx) if err != nil { - render.Error(w, err) + render.Error(w, r, err) return } if err = ch.Validate(ctx, db, jwk, payload.value); err != nil { - render.Error(w, acme.WrapErrorISE(err, "error validating challenge")) + render.Error(w, r, acme.WrapErrorISE(err, "error validating challenge")) return } @@ -359,7 +359,7 @@ func GetChallenge(w http.ResponseWriter, r *http.Request) { w.Header().Add("Link", link(linker.GetLink(ctx, acme.AuthzLinkType, azID), "up")) w.Header().Set("Location", linker.GetLink(ctx, acme.ChallengeLinkType, azID, ch.ID)) - render.JSON(w, ch) + render.JSON(w, r, ch) } // GetCertificate ACME api for retrieving a Certificate. @@ -369,18 +369,18 @@ func GetCertificate(w http.ResponseWriter, r *http.Request) { acc, err := accountFromContext(ctx) if err != nil { - render.Error(w, err) + render.Error(w, r, err) return } certID := chi.URLParam(r, "certID") cert, err := db.GetCertificate(ctx, certID) if err != nil { - render.Error(w, acme.WrapErrorISE(err, "error retrieving certificate")) + render.Error(w, r, acme.WrapErrorISE(err, "error retrieving certificate")) return } if cert.AccountID != acc.ID { - render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, + render.Error(w, r, acme.NewError(acme.ErrorUnauthorizedType, "account '%s' does not own certificate '%s'", acc.ID, certID)) return } diff --git a/acme/api/middleware.go b/acme/api/middleware.go index db3f3d6c..628da7ed 100644 --- a/acme/api/middleware.go +++ b/acme/api/middleware.go @@ -36,7 +36,7 @@ func addNonce(next nextHTTP) nextHTTP { db := acme.MustDatabaseFromContext(r.Context()) nonce, err := db.CreateNonce(r.Context()) if err != nil { - render.Error(w, err) + render.Error(w, r, err) return } w.Header().Set("Replay-Nonce", string(nonce)) @@ -64,7 +64,7 @@ func verifyContentType(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { p, err := provisionerFromContext(r.Context()) if err != nil { - render.Error(w, err) + render.Error(w, r, err) return } @@ -88,7 +88,7 @@ func verifyContentType(next nextHTTP) nextHTTP { return } } - render.Error(w, acme.NewError(acme.ErrorMalformedType, + render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "expected content-type to be in %s, but got %s", expected, ct)) } } @@ -98,12 +98,12 @@ func parseJWS(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { body, err := io.ReadAll(r.Body) if err != nil { - render.Error(w, acme.WrapErrorISE(err, "failed to read request body")) + render.Error(w, r, acme.WrapErrorISE(err, "failed to read request body")) return } jws, err := jose.ParseJWS(string(body)) if err != nil { - render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "failed to parse JWS from request body")) + render.Error(w, r, acme.WrapError(acme.ErrorMalformedType, err, "failed to parse JWS from request body")) return } ctx := context.WithValue(r.Context(), jwsContextKey, jws) @@ -133,15 +133,15 @@ func validateJWS(next nextHTTP) nextHTTP { jws, err := jwsFromContext(ctx) if err != nil { - render.Error(w, err) + render.Error(w, r, err) return } if len(jws.Signatures) == 0 { - render.Error(w, acme.NewError(acme.ErrorMalformedType, "request body does not contain a signature")) + render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "request body does not contain a signature")) return } if len(jws.Signatures) > 1 { - render.Error(w, acme.NewError(acme.ErrorMalformedType, "request body contains more than one signature")) + render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "request body contains more than one signature")) return } @@ -152,7 +152,7 @@ func validateJWS(next nextHTTP) nextHTTP { uh.Algorithm != "" || uh.Nonce != "" || len(uh.ExtraHeaders) > 0 { - render.Error(w, acme.NewError(acme.ErrorMalformedType, "unprotected header must not be used")) + render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "unprotected header must not be used")) return } hdr := sig.Protected @@ -162,13 +162,13 @@ func validateJWS(next nextHTTP) nextHTTP { switch k := hdr.JSONWebKey.Key.(type) { case *rsa.PublicKey: if k.Size() < keyutil.MinRSAKeyBytes { - render.Error(w, acme.NewError(acme.ErrorMalformedType, + render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "rsa keys must be at least %d bits (%d bytes) in size", 8*keyutil.MinRSAKeyBytes, keyutil.MinRSAKeyBytes)) return } default: - render.Error(w, acme.NewError(acme.ErrorMalformedType, + render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "jws key type and algorithm do not match")) return } @@ -176,35 +176,35 @@ func validateJWS(next nextHTTP) nextHTTP { case jose.ES256, jose.ES384, jose.ES512, jose.EdDSA: // we good default: - render.Error(w, acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: %s", hdr.Algorithm)) + render.Error(w, r, acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: %s", hdr.Algorithm)) return } // Check the validity/freshness of the Nonce. if err := db.DeleteNonce(ctx, acme.Nonce(hdr.Nonce)); err != nil { - render.Error(w, err) + render.Error(w, r, err) return } // Check that the JWS url matches the requested url. jwsURL, ok := hdr.ExtraHeaders["url"].(string) if !ok { - render.Error(w, acme.NewError(acme.ErrorMalformedType, "jws missing url protected header")) + render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "jws missing url protected header")) return } reqURL := &url.URL{Scheme: "https", Host: r.Host, Path: r.URL.Path} if jwsURL != reqURL.String() { - render.Error(w, acme.NewError(acme.ErrorMalformedType, + render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "url header in JWS (%s) does not match request url (%s)", jwsURL, reqURL)) return } if hdr.JSONWebKey != nil && hdr.KeyID != "" { - render.Error(w, acme.NewError(acme.ErrorMalformedType, "jwk and kid are mutually exclusive")) + render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "jwk and kid are mutually exclusive")) return } if hdr.JSONWebKey == nil && hdr.KeyID == "" { - render.Error(w, acme.NewError(acme.ErrorMalformedType, "either jwk or kid must be defined in jws protected header")) + render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "either jwk or kid must be defined in jws protected header")) return } next(w, r) @@ -221,23 +221,23 @@ func extractJWK(next nextHTTP) nextHTTP { jws, err := jwsFromContext(ctx) if err != nil { - render.Error(w, err) + render.Error(w, r, err) return } jwk := jws.Signatures[0].Protected.JSONWebKey if jwk == nil { - render.Error(w, acme.NewError(acme.ErrorMalformedType, "jwk expected in protected header")) + render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "jwk expected in protected header")) return } if !jwk.Valid() { - render.Error(w, acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header")) + render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header")) return } // Overwrite KeyID with the JWK thumbprint. jwk.KeyID, err = acme.KeyToID(jwk) if err != nil { - render.Error(w, acme.WrapErrorISE(err, "error getting KeyID from JWK")) + render.Error(w, r, acme.WrapErrorISE(err, "error getting KeyID from JWK")) return } @@ -251,11 +251,11 @@ func extractJWK(next nextHTTP) nextHTTP { // For NewAccount and Revoke requests ... break case err != nil: - render.Error(w, err) + render.Error(w, r, err) return default: if !acc.IsValid() { - render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active")) + render.Error(w, r, acme.NewError(acme.ErrorUnauthorizedType, "account is not active")) return } ctx = context.WithValue(ctx, accContextKey, acc) @@ -274,11 +274,11 @@ func checkPrerequisites(next nextHTTP) nextHTTP { if ok { ok, err := checkFunc(ctx) if err != nil { - render.Error(w, acme.WrapErrorISE(err, "error checking acme provisioner prerequisites")) + render.Error(w, r, acme.WrapErrorISE(err, "error checking acme provisioner prerequisites")) return } if !ok { - render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites")) + render.Error(w, r, acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites")) return } } @@ -296,13 +296,13 @@ func lookupJWK(next nextHTTP) nextHTTP { jws, err := jwsFromContext(ctx) if err != nil { - render.Error(w, err) + render.Error(w, r, err) return } kid := jws.Signatures[0].Protected.KeyID if kid == "" { - render.Error(w, acme.NewError(acme.ErrorMalformedType, "signature missing 'kid'")) + render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "signature missing 'kid'")) return } @@ -310,14 +310,14 @@ func lookupJWK(next nextHTTP) nextHTTP { acc, err := db.GetAccount(ctx, accID) switch { case acme.IsErrNotFound(err): - render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "account with ID '%s' not found", accID)) + render.Error(w, r, acme.NewError(acme.ErrorAccountDoesNotExistType, "account with ID '%s' not found", accID)) return case err != nil: - render.Error(w, err) + render.Error(w, r, err) return default: if !acc.IsValid() { - render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active")) + render.Error(w, r, acme.NewError(acme.ErrorUnauthorizedType, "account is not active")) return } @@ -325,7 +325,7 @@ func lookupJWK(next nextHTTP) nextHTTP { if kid != storedLocation { // ACME accounts should have a stored location equivalent to the // kid in the ACME request. - render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, + render.Error(w, r, acme.NewError(acme.ErrorUnauthorizedType, "kid does not match stored account location; expected %s, but got %s", storedLocation, kid)) return @@ -336,12 +336,12 @@ func lookupJWK(next nextHTTP) nextHTTP { reqProv := acme.MustProvisionerFromContext(ctx) switch { case acc.ProvisionerID == "" && acc.ProvisionerName != reqProv.GetName(): - render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, + render.Error(w, r, acme.NewError(acme.ErrorUnauthorizedType, "account provisioner does not match requested provisioner; account provisioner = %s, requested provisioner = %s", acc.ProvisionerName, reqProv.GetName())) return case acc.ProvisionerID != "" && acc.ProvisionerID != reqProv.GetID(): - render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, + render.Error(w, r, acme.NewError(acme.ErrorUnauthorizedType, "account provisioner does not match requested provisioner; account provisioner = %s, requested provisioner = %s", acc.ProvisionerID, reqProv.GetID())) return @@ -355,7 +355,7 @@ func lookupJWK(next nextHTTP) nextHTTP { linker := acme.MustLinkerFromContext(ctx) kidPrefix := linker.GetLink(ctx, acme.AccountLinkType, "") if !strings.HasPrefix(kid, kidPrefix) { - render.Error(w, acme.NewError(acme.ErrorMalformedType, + render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "kid does not have required prefix; expected %s, but got %s", kidPrefix, kid)) return @@ -376,7 +376,7 @@ func extractOrLookupJWK(next nextHTTP) nextHTTP { ctx := r.Context() jws, err := jwsFromContext(ctx) if err != nil { - render.Error(w, err) + render.Error(w, r, err) return } @@ -412,16 +412,16 @@ func verifyAndExtractJWSPayload(next nextHTTP) nextHTTP { ctx := r.Context() jws, err := jwsFromContext(ctx) if err != nil { - render.Error(w, err) + render.Error(w, r, err) return } jwk, err := jwkFromContext(ctx) if err != nil { - render.Error(w, err) + render.Error(w, r, err) return } if jwk.Algorithm != "" && jwk.Algorithm != jws.Signatures[0].Protected.Algorithm { - render.Error(w, acme.NewError(acme.ErrorMalformedType, "verifier and signature algorithm do not match")) + render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "verifier and signature algorithm do not match")) return } @@ -430,11 +430,11 @@ func verifyAndExtractJWSPayload(next nextHTTP) nextHTTP { case errors.Is(err, jose.ErrCryptoFailure): payload, err = retryVerificationWithPatchedSignatures(jws, jwk) if err != nil { - render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "error verifying jws with patched signature(s)")) + render.Error(w, r, acme.WrapError(acme.ErrorMalformedType, err, "error verifying jws with patched signature(s)")) return } case err != nil: - render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "error verifying jws")) + render.Error(w, r, acme.WrapError(acme.ErrorMalformedType, err, "error verifying jws")) return } @@ -551,11 +551,11 @@ func isPostAsGet(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { payload, err := payloadFromContext(r.Context()) if err != nil { - render.Error(w, err) + render.Error(w, r, err) return } if !payload.isPostAsGet { - render.Error(w, acme.NewError(acme.ErrorMalformedType, "expected POST-as-GET")) + render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "expected POST-as-GET")) return } next(w, r) diff --git a/acme/api/order.go b/acme/api/order.go index b207f87c..6aa079fc 100644 --- a/acme/api/order.go +++ b/acme/api/order.go @@ -99,29 +99,29 @@ func NewOrder(w http.ResponseWriter, r *http.Request) { acc, err := accountFromContext(ctx) if err != nil { - render.Error(w, err) + render.Error(w, r, err) return } prov, err := provisionerFromContext(ctx) if err != nil { - render.Error(w, err) + render.Error(w, r, err) return } payload, err := payloadFromContext(ctx) if err != nil { - render.Error(w, err) + render.Error(w, r, err) return } var nor NewOrderRequest if err := json.Unmarshal(payload.value, &nor); err != nil { - render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, + render.Error(w, r, acme.WrapError(acme.ErrorMalformedType, err, "failed to unmarshal new-order request payload")) return } if err := nor.Validate(); err != nil { - render.Error(w, err) + render.Error(w, r, err) return } @@ -130,39 +130,39 @@ func NewOrder(w http.ResponseWriter, r *http.Request) { acmeProv, err := acmeProvisionerFromContext(ctx) if err != nil { - render.Error(w, err) + render.Error(w, r, err) return } var eak *acme.ExternalAccountKey if acmeProv.RequireEAB { if eak, err = db.GetExternalAccountKeyByAccountID(ctx, prov.GetID(), acc.ID); err != nil { - render.Error(w, acme.WrapErrorISE(err, "error retrieving external account binding key")) + render.Error(w, r, acme.WrapErrorISE(err, "error retrieving external account binding key")) return } } acmePolicy, err := newACMEPolicyEngine(eak) if err != nil { - render.Error(w, acme.WrapErrorISE(err, "error creating ACME policy engine")) + render.Error(w, r, acme.WrapErrorISE(err, "error creating ACME policy engine")) return } for _, identifier := range nor.Identifiers { // evaluate the ACME account level policy if err = isIdentifierAllowed(acmePolicy, identifier); err != nil { - render.Error(w, acme.WrapError(acme.ErrorRejectedIdentifierType, err, "not authorized")) + render.Error(w, r, acme.WrapError(acme.ErrorRejectedIdentifierType, err, "not authorized")) return } // evaluate the provisioner level policy orderIdentifier := provisioner.ACMEIdentifier{Type: provisioner.ACMEIdentifierType(identifier.Type), Value: identifier.Value} if err = prov.AuthorizeOrderIdentifier(ctx, orderIdentifier); err != nil { - render.Error(w, acme.WrapError(acme.ErrorRejectedIdentifierType, err, "not authorized")) + render.Error(w, r, acme.WrapError(acme.ErrorRejectedIdentifierType, err, "not authorized")) return } // evaluate the authority level policy if err = ca.AreSANsAllowed(ctx, []string{identifier.Value}); err != nil { - render.Error(w, acme.WrapError(acme.ErrorRejectedIdentifierType, err, "not authorized")) + render.Error(w, r, acme.WrapError(acme.ErrorRejectedIdentifierType, err, "not authorized")) return } } @@ -188,7 +188,7 @@ func NewOrder(w http.ResponseWriter, r *http.Request) { Status: acme.StatusPending, } if err := newAuthorization(ctx, az); err != nil { - render.Error(w, err) + render.Error(w, r, err) return } o.AuthorizationIDs[i] = az.ID @@ -207,14 +207,14 @@ func NewOrder(w http.ResponseWriter, r *http.Request) { } if err := db.CreateOrder(ctx, o); err != nil { - render.Error(w, acme.WrapErrorISE(err, "error creating order")) + render.Error(w, r, acme.WrapErrorISE(err, "error creating order")) return } linker.LinkOrder(ctx, o) w.Header().Set("Location", linker.GetLink(ctx, acme.OrderLinkType, o.ID)) - render.JSONStatus(w, o, http.StatusCreated) + render.JSONStatus(w, r, o, http.StatusCreated) } func isIdentifierAllowed(acmePolicy policy.X509Policy, identifier acme.Identifier) error { @@ -288,39 +288,39 @@ func GetOrder(w http.ResponseWriter, r *http.Request) { acc, err := accountFromContext(ctx) if err != nil { - render.Error(w, err) + render.Error(w, r, err) return } prov, err := provisionerFromContext(ctx) if err != nil { - render.Error(w, err) + render.Error(w, r, err) return } o, err := db.GetOrder(ctx, chi.URLParam(r, "ordID")) if err != nil { - render.Error(w, acme.WrapErrorISE(err, "error retrieving order")) + render.Error(w, r, acme.WrapErrorISE(err, "error retrieving order")) return } if acc.ID != o.AccountID { - render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, + render.Error(w, r, acme.NewError(acme.ErrorUnauthorizedType, "account '%s' does not own order '%s'", acc.ID, o.ID)) return } if prov.GetID() != o.ProvisionerID { - render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, + render.Error(w, r, acme.NewError(acme.ErrorUnauthorizedType, "provisioner '%s' does not own order '%s'", prov.GetID(), o.ID)) return } if err = o.UpdateStatus(ctx, db); err != nil { - render.Error(w, acme.WrapErrorISE(err, "error updating order status")) + render.Error(w, r, acme.WrapErrorISE(err, "error updating order status")) return } linker.LinkOrder(ctx, o) w.Header().Set("Location", linker.GetLink(ctx, acme.OrderLinkType, o.ID)) - render.JSON(w, o) + render.JSON(w, r, o) } // FinalizeOrder attempts to finalize an order and create a certificate. @@ -331,56 +331,56 @@ func FinalizeOrder(w http.ResponseWriter, r *http.Request) { acc, err := accountFromContext(ctx) if err != nil { - render.Error(w, err) + render.Error(w, r, err) return } prov, err := provisionerFromContext(ctx) if err != nil { - render.Error(w, err) + render.Error(w, r, err) return } payload, err := payloadFromContext(ctx) if err != nil { - render.Error(w, err) + render.Error(w, r, err) return } var fr FinalizeRequest if err := json.Unmarshal(payload.value, &fr); err != nil { - render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, + render.Error(w, r, acme.WrapError(acme.ErrorMalformedType, err, "failed to unmarshal finalize-order request payload")) return } if err := fr.Validate(); err != nil { - render.Error(w, err) + render.Error(w, r, err) return } o, err := db.GetOrder(ctx, chi.URLParam(r, "ordID")) if err != nil { - render.Error(w, acme.WrapErrorISE(err, "error retrieving order")) + render.Error(w, r, acme.WrapErrorISE(err, "error retrieving order")) return } if acc.ID != o.AccountID { - render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, + render.Error(w, r, acme.NewError(acme.ErrorUnauthorizedType, "account '%s' does not own order '%s'", acc.ID, o.ID)) return } if prov.GetID() != o.ProvisionerID { - render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, + render.Error(w, r, acme.NewError(acme.ErrorUnauthorizedType, "provisioner '%s' does not own order '%s'", prov.GetID(), o.ID)) return } ca := mustAuthority(ctx) if err = o.Finalize(ctx, db, fr.csr, ca, prov); err != nil { - render.Error(w, acme.WrapErrorISE(err, "error finalizing order")) + render.Error(w, r, acme.WrapErrorISE(err, "error finalizing order")) return } linker.LinkOrder(ctx, o) w.Header().Set("Location", linker.GetLink(ctx, acme.OrderLinkType, o.ID)) - render.JSON(w, o) + render.JSON(w, r, o) } // challengeTypes determines the types of challenges that should be used diff --git a/acme/api/revoke.go b/acme/api/revoke.go index 270a9fbb..c97d54c1 100644 --- a/acme/api/revoke.go +++ b/acme/api/revoke.go @@ -33,65 +33,65 @@ func RevokeCert(w http.ResponseWriter, r *http.Request) { jws, err := jwsFromContext(ctx) if err != nil { - render.Error(w, err) + render.Error(w, r, err) return } prov, err := provisionerFromContext(ctx) if err != nil { - render.Error(w, err) + render.Error(w, r, err) return } payload, err := payloadFromContext(ctx) if err != nil { - render.Error(w, err) + render.Error(w, r, err) return } var p revokePayload err = json.Unmarshal(payload.value, &p) if err != nil { - render.Error(w, acme.WrapErrorISE(err, "error unmarshaling payload")) + render.Error(w, r, acme.WrapErrorISE(err, "error unmarshaling payload")) return } certBytes, err := base64.RawURLEncoding.DecodeString(p.Certificate) if err != nil { // in this case the most likely cause is a client that didn't properly encode the certificate - render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "error base64url decoding payload certificate property")) + render.Error(w, r, acme.WrapError(acme.ErrorMalformedType, err, "error base64url decoding payload certificate property")) return } certToBeRevoked, err := x509.ParseCertificate(certBytes) if err != nil { // in this case a client may have encoded something different than a certificate - render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "error parsing certificate")) + render.Error(w, r, acme.WrapError(acme.ErrorMalformedType, err, "error parsing certificate")) return } serial := certToBeRevoked.SerialNumber.String() dbCert, err := db.GetCertificateBySerial(ctx, serial) if err != nil { - render.Error(w, acme.WrapErrorISE(err, "error retrieving certificate by serial")) + render.Error(w, r, acme.WrapErrorISE(err, "error retrieving certificate by serial")) return } if !bytes.Equal(dbCert.Leaf.Raw, certToBeRevoked.Raw) { // this should never happen - render.Error(w, acme.NewErrorISE("certificate raw bytes are not equal")) + render.Error(w, r, acme.NewErrorISE("certificate raw bytes are not equal")) return } if shouldCheckAccountFrom(jws) { account, err := accountFromContext(ctx) if err != nil { - render.Error(w, err) + render.Error(w, r, err) return } acmeErr := isAccountAuthorized(ctx, dbCert, certToBeRevoked, account) if acmeErr != nil { - render.Error(w, acmeErr) + render.Error(w, r, acmeErr) return } } else { @@ -100,7 +100,7 @@ func RevokeCert(w http.ResponseWriter, r *http.Request) { _, err := jws.Verify(certToBeRevoked.PublicKey) if err != nil { // TODO(hs): possible to determine an error vs. unauthorized and thus provide an ISE vs. Unauthorized? - render.Error(w, wrapUnauthorizedError(certToBeRevoked, nil, "verification of jws using certificate public key failed", err)) + render.Error(w, r, wrapUnauthorizedError(certToBeRevoked, nil, "verification of jws using certificate public key failed", err)) return } } @@ -108,19 +108,19 @@ func RevokeCert(w http.ResponseWriter, r *http.Request) { ca := mustAuthority(ctx) hasBeenRevokedBefore, err := ca.IsRevoked(serial) if err != nil { - render.Error(w, acme.WrapErrorISE(err, "error retrieving revocation status of certificate")) + render.Error(w, r, acme.WrapErrorISE(err, "error retrieving revocation status of certificate")) return } if hasBeenRevokedBefore { - render.Error(w, acme.NewError(acme.ErrorAlreadyRevokedType, "certificate was already revoked")) + render.Error(w, r, acme.NewError(acme.ErrorAlreadyRevokedType, "certificate was already revoked")) return } reasonCode := p.ReasonCode acmeErr := validateReasonCode(reasonCode) if acmeErr != nil { - render.Error(w, acmeErr) + render.Error(w, r, acmeErr) return } @@ -128,14 +128,14 @@ func RevokeCert(w http.ResponseWriter, r *http.Request) { ctx = provisioner.NewContextWithMethod(ctx, provisioner.RevokeMethod) err = prov.AuthorizeRevoke(ctx, "") if err != nil { - render.Error(w, acme.WrapErrorISE(err, "error authorizing revocation on provisioner")) + render.Error(w, r, acme.WrapErrorISE(err, "error authorizing revocation on provisioner")) return } options := revokeOptions(serial, certToBeRevoked, reasonCode) err = ca.Revoke(ctx, options) if err != nil { - render.Error(w, wrapRevokeErr(err)) + render.Error(w, r, wrapRevokeErr(err)) return } diff --git a/acme/errors.go b/acme/errors.go index 658ec6e0..586cfb9b 100644 --- a/acme/errors.go +++ b/acme/errors.go @@ -424,7 +424,7 @@ func (e *Error) ToLog() (interface{}, error) { } // Render implements render.RenderableError for Error. -func (e *Error) Render(w http.ResponseWriter) { +func (e *Error) Render(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/problem+json") - render.JSONStatus(w, e, e.StatusCode()) + render.JSONStatus(w, r, e, e.StatusCode()) } diff --git a/acme/linker.go b/acme/linker.go index d142bf10..18997c5c 100644 --- a/acme/linker.go +++ b/acme/linker.go @@ -186,19 +186,19 @@ func (l *linker) Middleware(next http.Handler) http.Handler { nameEscaped := chi.URLParam(r, "provisionerID") name, err := url.PathUnescape(nameEscaped) if err != nil { - render.Error(w, WrapErrorISE(err, "error url unescaping provisioner name '%s'", nameEscaped)) + render.Error(w, r, WrapErrorISE(err, "error url unescaping provisioner name '%s'", nameEscaped)) return } p, err := authority.MustFromContext(ctx).LoadProvisionerByName(name) if err != nil { - render.Error(w, err) + render.Error(w, r, err) return } acmeProv, ok := p.(*provisioner.ACME) if !ok { - render.Error(w, NewError(ErrorAccountDoesNotExistType, "provisioner must be of type ACME")) + render.Error(w, r, NewError(ErrorAccountDoesNotExistType, "provisioner must be of type ACME")) return } diff --git a/api/api.go b/api/api.go index 6916983b..0b139a71 100644 --- a/api/api.go +++ b/api/api.go @@ -353,15 +353,15 @@ func Route(r Router) { // Version is an HTTP handler that returns the version of the server. func Version(w http.ResponseWriter, r *http.Request) { v := mustAuthority(r.Context()).Version() - render.JSON(w, VersionResponse{ + render.JSON(w, r, VersionResponse{ Version: v.Version, RequireClientAuthentication: v.RequireClientAuthentication, }) } // Health is an HTTP handler that returns the status of the server. -func Health(w http.ResponseWriter, _ *http.Request) { - render.JSON(w, HealthResponse{Status: "ok"}) +func Health(w http.ResponseWriter, r *http.Request) { + render.JSON(w, r, HealthResponse{Status: "ok"}) } // Root is an HTTP handler that using the SHA256 from the URL, returns the root @@ -372,11 +372,11 @@ func Root(w http.ResponseWriter, r *http.Request) { // Load root certificate with the cert, err := mustAuthority(r.Context()).Root(sum) if err != nil { - render.Error(w, errs.Wrapf(http.StatusNotFound, err, "%s was not found", r.RequestURI)) + render.Error(w, r, errs.Wrapf(http.StatusNotFound, err, "%s was not found", r.RequestURI)) return } - render.JSON(w, &RootResponse{RootPEM: Certificate{cert}}) + render.JSON(w, r, &RootResponse{RootPEM: Certificate{cert}}) } func certChainToPEM(certChain []*x509.Certificate) []Certificate { @@ -391,17 +391,17 @@ func certChainToPEM(certChain []*x509.Certificate) []Certificate { func Provisioners(w http.ResponseWriter, r *http.Request) { cursor, limit, err := ParseCursor(r) if err != nil { - render.Error(w, err) + render.Error(w, r, err) return } p, next, err := mustAuthority(r.Context()).GetProvisioners(cursor, limit) if err != nil { - render.Error(w, errs.InternalServerErr(err)) + render.Error(w, r, errs.InternalServerErr(err)) return } - render.JSON(w, &ProvisionersResponse{ + render.JSON(w, r, &ProvisionersResponse{ Provisioners: p, NextCursor: next, }) @@ -412,18 +412,18 @@ func ProvisionerKey(w http.ResponseWriter, r *http.Request) { kid := chi.URLParam(r, "kid") key, err := mustAuthority(r.Context()).GetEncryptedKey(kid) if err != nil { - render.Error(w, errs.NotFoundErr(err)) + render.Error(w, r, errs.NotFoundErr(err)) return } - render.JSON(w, &ProvisionerKeyResponse{key}) + render.JSON(w, r, &ProvisionerKeyResponse{key}) } // Roots returns all the root certificates for the CA. func Roots(w http.ResponseWriter, r *http.Request) { roots, err := mustAuthority(r.Context()).GetRoots() if err != nil { - render.Error(w, errs.ForbiddenErr(err, "error getting roots")) + render.Error(w, r, errs.ForbiddenErr(err, "error getting roots")) return } @@ -432,7 +432,7 @@ func Roots(w http.ResponseWriter, r *http.Request) { certs[i] = Certificate{roots[i]} } - render.JSONStatus(w, &RootsResponse{ + render.JSONStatus(w, r, &RootsResponse{ Certificates: certs, }, http.StatusCreated) } @@ -441,7 +441,7 @@ func Roots(w http.ResponseWriter, r *http.Request) { func RootsPEM(w http.ResponseWriter, r *http.Request) { roots, err := mustAuthority(r.Context()).GetRoots() if err != nil { - render.Error(w, errs.InternalServerErr(err)) + render.Error(w, r, errs.InternalServerErr(err)) return } @@ -454,7 +454,7 @@ func RootsPEM(w http.ResponseWriter, r *http.Request) { }) if _, err := w.Write(block); err != nil { - log.Error(w, err) + log.Error(w, r, err) return } } @@ -464,7 +464,7 @@ func RootsPEM(w http.ResponseWriter, r *http.Request) { func Federation(w http.ResponseWriter, r *http.Request) { federated, err := mustAuthority(r.Context()).GetFederation() if err != nil { - render.Error(w, errs.ForbiddenErr(err, "error getting federated roots")) + render.Error(w, r, errs.ForbiddenErr(err, "error getting federated roots")) return } @@ -473,7 +473,7 @@ func Federation(w http.ResponseWriter, r *http.Request) { certs[i] = Certificate{federated[i]} } - render.JSONStatus(w, &FederationResponse{ + render.JSONStatus(w, r, &FederationResponse{ Certificates: certs, }, http.StatusCreated) } diff --git a/api/crl.go b/api/crl.go index c10d08ca..92e14815 100644 --- a/api/crl.go +++ b/api/crl.go @@ -13,12 +13,12 @@ import ( func CRL(w http.ResponseWriter, r *http.Request) { crlInfo, err := mustAuthority(r.Context()).GetCertificateRevocationList() if err != nil { - render.Error(w, err) + render.Error(w, r, err) return } if crlInfo == nil { - render.Error(w, errs.New(http.StatusNotFound, "no CRL available")) + render.Error(w, r, errs.New(http.StatusNotFound, "no CRL available")) return } diff --git a/api/log/log.go b/api/log/log.go index 687d61c6..6cc61a77 100644 --- a/api/log/log.go +++ b/api/log/log.go @@ -2,6 +2,7 @@ package log import ( + "context" "fmt" "net/http" "os" @@ -9,6 +10,29 @@ import ( "github.com/pkg/errors" ) +type errorLoggerKey struct{} + +// ErrorLogger is the function type used to log errors. +type ErrorLogger func(http.ResponseWriter, *http.Request, error) + +func (fn ErrorLogger) call(w http.ResponseWriter, r *http.Request, err error) { + if fn == nil { + return + } + fn(w, r, err) +} + +// WithErrorLogger returns a new context with the given error logger. +func WithErrorLogger(ctx context.Context, fn ErrorLogger) context.Context { + return context.WithValue(ctx, errorLoggerKey{}, fn) +} + +// ErrorLoggerFromContext returns an error logger from the context. +func ErrorLoggerFromContext(ctx context.Context) (fn ErrorLogger) { + fn, _ = ctx.Value(errorLoggerKey{}).(ErrorLogger) + return +} + // StackTracedError is the set of errors implementing the StackTrace function. // // Errors implementing this interface have their stack traces logged when passed @@ -27,8 +51,10 @@ type fieldCarrier interface { // Error adds to the response writer the given error if it implements // logging.ResponseLogger. If it does not implement it, then writes the error // using the log package. -func Error(rw http.ResponseWriter, err error) { - fc, ok := rw.(fieldCarrier) +func Error(w http.ResponseWriter, r *http.Request, err error) { + ErrorLoggerFromContext(r.Context()).call(w, r, err) + + fc, ok := w.(fieldCarrier) if !ok { return } @@ -51,7 +77,7 @@ func Error(rw http.ResponseWriter, err error) { // EnabledResponse log the response object if it implements the EnableLogger // interface. -func EnabledResponse(rw http.ResponseWriter, v any) { +func EnabledResponse(rw http.ResponseWriter, r *http.Request, v any) { type enableLogger interface { ToLog() (any, error) } @@ -59,7 +85,7 @@ func EnabledResponse(rw http.ResponseWriter, v any) { if el, ok := v.(enableLogger); ok { out, err := el.ToLog() if err != nil { - Error(rw, err) + Error(rw, r, err) return } diff --git a/api/log/log_test.go b/api/log/log_test.go index 7c08b771..e1da274f 100644 --- a/api/log/log_test.go +++ b/api/log/log_test.go @@ -1,6 +1,9 @@ package log import ( + "bytes" + "encoding/json" + "log/slog" "net/http" "net/http/httptest" "testing" @@ -27,21 +30,34 @@ func (stackTracedError) StackTrace() pkgerrors.StackTrace { } func TestError(t *testing.T) { + var buf bytes.Buffer + logger := slog.New(slog.NewJSONHandler(&buf, &slog.HandlerOptions{})) + req := httptest.NewRequest("GET", "/test", http.NoBody) + reqWithLogger := req.WithContext(WithErrorLogger(req.Context(), func(w http.ResponseWriter, r *http.Request, err error) { + if err != nil { + logger.ErrorContext(r.Context(), "request failed", slog.Any("error", err)) + } + })) + tests := []struct { name string error rw http.ResponseWriter + r *http.Request isFieldCarrier bool + isSlogLogger bool stepDebug bool expectStackTrace bool }{ - {"noLogger", nil, nil, false, false, false}, - {"noError", nil, logging.NewResponseLogger(httptest.NewRecorder()), true, false, false}, - {"noErrorDebug", nil, logging.NewResponseLogger(httptest.NewRecorder()), true, true, false}, - {"anError", assert.AnError, logging.NewResponseLogger(httptest.NewRecorder()), true, false, false}, - {"anErrorDebug", assert.AnError, logging.NewResponseLogger(httptest.NewRecorder()), true, true, false}, - {"stackTracedError", new(stackTracedError), logging.NewResponseLogger(httptest.NewRecorder()), true, true, true}, - {"stackTracedErrorDebug", new(stackTracedError), logging.NewResponseLogger(httptest.NewRecorder()), true, true, true}, + {"noLogger", nil, nil, req, false, false, false, false}, + {"noError", nil, logging.NewResponseLogger(httptest.NewRecorder()), req, true, false, false, false}, + {"noErrorDebug", nil, logging.NewResponseLogger(httptest.NewRecorder()), req, true, false, true, false}, + {"anError", assert.AnError, logging.NewResponseLogger(httptest.NewRecorder()), req, true, false, false, false}, + {"anErrorDebug", assert.AnError, logging.NewResponseLogger(httptest.NewRecorder()), req, true, false, true, false}, + {"stackTracedError", new(stackTracedError), logging.NewResponseLogger(httptest.NewRecorder()), req, true, false, true, true}, + {"stackTracedErrorDebug", new(stackTracedError), logging.NewResponseLogger(httptest.NewRecorder()), req, true, false, true, true}, + {"slogWithNoError", nil, logging.NewResponseLogger(httptest.NewRecorder()), reqWithLogger, true, true, false, false}, + {"slogWithError", assert.AnError, logging.NewResponseLogger(httptest.NewRecorder()), reqWithLogger, true, true, false, false}, } for _, tt := range tests { @@ -52,27 +68,41 @@ func TestError(t *testing.T) { t.Setenv("STEPDEBUG", "0") } - Error(tt.rw, tt.error) + Error(tt.rw, tt.r, tt.error) // return early if test case doesn't use logger - if !tt.isFieldCarrier { + if !tt.isFieldCarrier && !tt.isSlogLogger { return } - fields := tt.rw.(logging.ResponseLogger).Fields() + if tt.isFieldCarrier { + fields := tt.rw.(logging.ResponseLogger).Fields() - // expect the error field to be (not) set and to be the same error that was fed to Error - if tt.error == nil { - assert.Nil(t, fields["error"]) - } else { - assert.Same(t, tt.error, fields["error"]) + // expect the error field to be (not) set and to be the same error that was fed to Error + if tt.error == nil { + assert.Nil(t, fields["error"]) + } else { + assert.Same(t, tt.error, fields["error"]) + } + + // check if stack-trace is set when expected + if _, hasStackTrace := fields["stack-trace"]; tt.expectStackTrace && !hasStackTrace { + t.Error(`ResponseLogger["stack-trace"] not set`) + } else if !tt.expectStackTrace && hasStackTrace { + t.Error(`ResponseLogger["stack-trace"] was set`) + } } - // check if stack-trace is set when expected - if _, hasStackTrace := fields["stack-trace"]; tt.expectStackTrace && !hasStackTrace { - t.Error(`ResponseLogger["stack-trace"] not set`) - } else if !tt.expectStackTrace && hasStackTrace { - t.Error(`ResponseLogger["stack-trace"] was set`) + if tt.isSlogLogger { + b := buf.Bytes() + if tt.error == nil { + assert.Empty(t, b) + } else if assert.NotEmpty(t, b) { + var m map[string]any + assert.NoError(t, json.Unmarshal(b, &m)) + assert.Equal(t, tt.error.Error(), m["error"]) + } + buf.Reset() } }) } diff --git a/api/read/read.go b/api/read/read.go index 72530b8c..6f75c41a 100644 --- a/api/read/read.go +++ b/api/read/read.go @@ -51,7 +51,7 @@ func (e badProtoJSONError) Error() string { } // Render implements render.RenderableError for badProtoJSONError -func (e badProtoJSONError) Render(w http.ResponseWriter) { +func (e badProtoJSONError) Render(w http.ResponseWriter, r *http.Request) { v := struct { Type string `json:"type"` Detail string `json:"detail"` @@ -62,5 +62,5 @@ func (e badProtoJSONError) Render(w http.ResponseWriter) { // trim the proto prefix for the message Message: strings.TrimSpace(strings.TrimPrefix(e.Error(), "proto:")), } - render.JSONStatus(w, v, http.StatusBadRequest) + render.JSONStatus(w, r, v, http.StatusBadRequest) } diff --git a/api/read/read_test.go b/api/read/read_test.go index e46e7f61..e557a9a2 100644 --- a/api/read/read_test.go +++ b/api/read/read_test.go @@ -142,7 +142,8 @@ func Test_badProtoJSONError_Render(t *testing.T) { t.Run(tt.name, func(t *testing.T) { w := httptest.NewRecorder() - tt.e.Render(w) + r := httptest.NewRequest("POST", "/test", http.NoBody) + tt.e.Render(w, r) res := w.Result() defer res.Body.Close() diff --git a/api/rekey.go b/api/rekey.go index cda843a3..772de217 100644 --- a/api/rekey.go +++ b/api/rekey.go @@ -29,25 +29,25 @@ func (s *RekeyRequest) Validate() error { // Rekey is similar to renew except that the certificate will be renewed with new key from csr. func Rekey(w http.ResponseWriter, r *http.Request) { if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { - render.Error(w, errs.BadRequest("missing client certificate")) + render.Error(w, r, errs.BadRequest("missing client certificate")) return } var body RekeyRequest if err := read.JSON(r.Body, &body); err != nil { - render.Error(w, errs.BadRequestErr(err, "error reading request body")) + render.Error(w, r, errs.BadRequestErr(err, "error reading request body")) return } if err := body.Validate(); err != nil { - render.Error(w, err) + render.Error(w, r, err) return } a := mustAuthority(r.Context()) certChain, err := a.Rekey(r.TLS.PeerCertificates[0], body.CsrPEM.CertificateRequest.PublicKey) if err != nil { - render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Rekey")) + render.Error(w, r, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Rekey")) return } certChainPEM := certChainToPEM(certChain) @@ -57,7 +57,7 @@ func Rekey(w http.ResponseWriter, r *http.Request) { } LogCertificate(w, certChain[0]) - render.JSONStatus(w, &SignResponse{ + render.JSONStatus(w, r, &SignResponse{ ServerPEM: certChainPEM[0], CaPEM: caPEM, CertChainPEM: certChainPEM, diff --git a/api/render/render.go b/api/render/render.go index 7829ba25..1c66280c 100644 --- a/api/render/render.go +++ b/api/render/render.go @@ -13,8 +13,8 @@ import ( ) // JSON is shorthand for JSONStatus(w, v, http.StatusOK). -func JSON(w http.ResponseWriter, v interface{}) { - JSONStatus(w, v, http.StatusOK) +func JSON(w http.ResponseWriter, r *http.Request, v interface{}) { + JSONStatus(w, r, v, http.StatusOK) } // JSONStatus marshals v into w. It additionally sets the status code of @@ -22,7 +22,7 @@ func JSON(w http.ResponseWriter, v interface{}) { // // JSONStatus sets the Content-Type of w to application/json unless one is // specified. -func JSONStatus(w http.ResponseWriter, v interface{}, status int) { +func JSONStatus(w http.ResponseWriter, r *http.Request, v interface{}, status int) { setContentTypeUnlessPresent(w, "application/json") w.WriteHeader(status) @@ -43,7 +43,7 @@ func JSONStatus(w http.ResponseWriter, v interface{}, status int) { } } - log.EnabledResponse(w, v) + log.EnabledResponse(w, r, v) } // ProtoJSON is shorthand for ProtoJSONStatus(w, m, http.StatusOK). @@ -80,22 +80,22 @@ func setContentTypeUnlessPresent(w http.ResponseWriter, contentType string) { type RenderableError interface { error - Render(http.ResponseWriter) + Render(http.ResponseWriter, *http.Request) } // Error marshals the JSON representation of err to w. In case err implements // RenderableError its own Render method will be called instead. -func Error(w http.ResponseWriter, err error) { - log.Error(w, err) +func Error(rw http.ResponseWriter, r *http.Request, err error) { + log.Error(rw, r, err) - var r RenderableError - if errors.As(err, &r) { - r.Render(w) + var re RenderableError + if errors.As(err, &re) { + re.Render(rw, r) return } - JSONStatus(w, err, statusCodeFromError(err)) + JSONStatus(rw, r, err, statusCodeFromError(err)) } // StatusCodedError is the set of errors that implement the basic StatusCode diff --git a/api/render/render_test.go b/api/render/render_test.go index e88544c7..d7ee37fd 100644 --- a/api/render/render_test.go +++ b/api/render/render_test.go @@ -18,8 +18,8 @@ import ( func TestJSON(t *testing.T) { rec := httptest.NewRecorder() rw := logging.NewResponseLogger(rec) - - JSON(rw, map[string]interface{}{"foo": "bar"}) + r := httptest.NewRequest("POST", "/test", http.NoBody) + JSON(rw, r, map[string]interface{}{"foo": "bar"}) assert.Equal(t, http.StatusOK, rec.Result().StatusCode) assert.Equal(t, "application/json", rec.Header().Get("Content-Type")) @@ -64,7 +64,8 @@ func jsonPanicTest[T json.UnsupportedTypeError | json.UnsupportedValueError | js assert.ErrorAs(t, err, &e) }() - JSON(httptest.NewRecorder(), v) + r := httptest.NewRequest("POST", "/test", http.NoBody) + JSON(httptest.NewRecorder(), r, v) } type renderableError struct { @@ -76,10 +77,9 @@ func (err renderableError) Error() string { return err.Message } -func (err renderableError) Render(w http.ResponseWriter) { +func (err renderableError) Render(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "something/custom") - - JSONStatus(w, err, err.Code) + JSONStatus(w, r, err, err.Code) } type statusedError struct { @@ -116,8 +116,8 @@ func TestError(t *testing.T) { t.Run(strconv.Itoa(caseIndex), func(t *testing.T) { rec := httptest.NewRecorder() - - Error(rec, kase.err) + r := httptest.NewRequest("POST", "/test", http.NoBody) + Error(rec, r, kase.err) assert.Equal(t, kase.code, rec.Result().StatusCode) assert.Equal(t, kase.body, rec.Body.String()) diff --git a/api/renew.go b/api/renew.go index 1b9ed95f..7cd3707d 100644 --- a/api/renew.go +++ b/api/renew.go @@ -23,19 +23,20 @@ func Renew(w http.ResponseWriter, r *http.Request) { // Get the leaf certificate from the peer or the token. cert, token, err := getPeerCertificate(r) if err != nil { - render.Error(w, err) + render.Error(w, r, err) return } // The token can be used by RAs to renew a certificate. if token != "" { ctx = authority.NewTokenContext(ctx, token) + logOtt(w, token) } a := mustAuthority(ctx) certChain, err := a.RenewContext(ctx, cert, nil) if err != nil { - render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew")) + render.Error(w, r, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew")) return } certChainPEM := certChainToPEM(certChain) @@ -45,7 +46,7 @@ func Renew(w http.ResponseWriter, r *http.Request) { } LogCertificate(w, certChain[0]) - render.JSONStatus(w, &SignResponse{ + render.JSONStatus(w, r, &SignResponse{ ServerPEM: certChainPEM[0], CaPEM: caPEM, CertChainPEM: certChainPEM, diff --git a/api/revoke.go b/api/revoke.go index dc639d58..41969c08 100644 --- a/api/revoke.go +++ b/api/revoke.go @@ -57,12 +57,12 @@ func (r *RevokeRequest) Validate() (err error) { func Revoke(w http.ResponseWriter, r *http.Request) { var body RevokeRequest if err := read.JSON(r.Body, &body); err != nil { - render.Error(w, errs.BadRequestErr(err, "error reading request body")) + render.Error(w, r, errs.BadRequestErr(err, "error reading request body")) return } if err := body.Validate(); err != nil { - render.Error(w, err) + render.Error(w, r, err) return } @@ -81,7 +81,7 @@ func Revoke(w http.ResponseWriter, r *http.Request) { if body.OTT != "" { logOtt(w, body.OTT) if _, err := a.Authorize(ctx, body.OTT); err != nil { - render.Error(w, errs.UnauthorizedErr(err)) + render.Error(w, r, errs.UnauthorizedErr(err)) return } opts.OTT = body.OTT @@ -90,12 +90,12 @@ func Revoke(w http.ResponseWriter, r *http.Request) { // the client certificate Serial Number must match the serial number // being revoked. if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { - render.Error(w, errs.BadRequest("missing ott or client certificate")) + render.Error(w, r, errs.BadRequest("missing ott or client certificate")) return } opts.Crt = r.TLS.PeerCertificates[0] if opts.Crt.SerialNumber.String() != opts.Serial { - render.Error(w, errs.BadRequest("serial number in client certificate different than body")) + render.Error(w, r, errs.BadRequest("serial number in client certificate different than body")) return } // TODO: should probably be checking if the certificate was revoked here. @@ -106,12 +106,12 @@ func Revoke(w http.ResponseWriter, r *http.Request) { } if err := a.Revoke(ctx, opts); err != nil { - render.Error(w, errs.ForbiddenErr(err, "error revoking certificate")) + render.Error(w, r, errs.ForbiddenErr(err, "error revoking certificate")) return } logRevoke(w, opts) - render.JSON(w, &RevokeResponse{Status: "ok"}) + render.JSON(w, r, &RevokeResponse{Status: "ok"}) } func logRevoke(w http.ResponseWriter, ri *authority.RevokeOptions) { diff --git a/api/sign.go b/api/sign.go index 26b3c396..bff41763 100644 --- a/api/sign.go +++ b/api/sign.go @@ -52,13 +52,13 @@ type SignResponse struct { func Sign(w http.ResponseWriter, r *http.Request) { var body SignRequest if err := read.JSON(r.Body, &body); err != nil { - render.Error(w, errs.BadRequestErr(err, "error reading request body")) + render.Error(w, r, errs.BadRequestErr(err, "error reading request body")) return } logOtt(w, body.OTT) if err := body.Validate(); err != nil { - render.Error(w, err) + render.Error(w, r, err) return } @@ -74,13 +74,13 @@ func Sign(w http.ResponseWriter, r *http.Request) { ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod) signOpts, err := a.Authorize(ctx, body.OTT) if err != nil { - render.Error(w, errs.UnauthorizedErr(err)) + render.Error(w, r, errs.UnauthorizedErr(err)) return } certChain, err := a.SignWithContext(ctx, body.CsrPEM.CertificateRequest, opts, signOpts...) if err != nil { - render.Error(w, errs.ForbiddenErr(err, "error signing certificate")) + render.Error(w, r, errs.ForbiddenErr(err, "error signing certificate")) return } certChainPEM := certChainToPEM(certChain) @@ -90,7 +90,7 @@ func Sign(w http.ResponseWriter, r *http.Request) { } LogCertificate(w, certChain[0]) - render.JSONStatus(w, &SignResponse{ + render.JSONStatus(w, r, &SignResponse{ ServerPEM: certChainPEM[0], CaPEM: caPEM, CertChainPEM: certChainPEM, diff --git a/api/ssh.go b/api/ssh.go index 08294c71..e0e4e01c 100644 --- a/api/ssh.go +++ b/api/ssh.go @@ -253,19 +253,19 @@ type SSHBastionResponse struct { func SSHSign(w http.ResponseWriter, r *http.Request) { var body SSHSignRequest if err := read.JSON(r.Body, &body); err != nil { - render.Error(w, errs.BadRequestErr(err, "error reading request body")) + render.Error(w, r, errs.BadRequestErr(err, "error reading request body")) return } logOtt(w, body.OTT) if err := body.Validate(); err != nil { - render.Error(w, err) + render.Error(w, r, err) return } publicKey, err := ssh.ParsePublicKey(body.PublicKey) if err != nil { - render.Error(w, errs.BadRequestErr(err, "error parsing publicKey")) + render.Error(w, r, errs.BadRequestErr(err, "error parsing publicKey")) return } @@ -273,7 +273,7 @@ func SSHSign(w http.ResponseWriter, r *http.Request) { if body.AddUserPublicKey != nil { addUserPublicKey, err = ssh.ParsePublicKey(body.AddUserPublicKey) if err != nil { - render.Error(w, errs.BadRequestErr(err, "error parsing addUserPublicKey")) + render.Error(w, r, errs.BadRequestErr(err, "error parsing addUserPublicKey")) return } } @@ -293,13 +293,13 @@ func SSHSign(w http.ResponseWriter, r *http.Request) { a := mustAuthority(ctx) signOpts, err := a.Authorize(ctx, body.OTT) if err != nil { - render.Error(w, errs.UnauthorizedErr(err)) + render.Error(w, r, errs.UnauthorizedErr(err)) return } cert, err := a.SignSSH(ctx, publicKey, opts, signOpts...) if err != nil { - render.Error(w, errs.ForbiddenErr(err, "error signing ssh certificate")) + render.Error(w, r, errs.ForbiddenErr(err, "error signing ssh certificate")) return } @@ -307,7 +307,7 @@ func SSHSign(w http.ResponseWriter, r *http.Request) { if addUserPublicKey != nil && authority.IsValidForAddUser(cert) == nil { addUserCert, err := a.SignSSHAddUser(ctx, addUserPublicKey, cert) if err != nil { - render.Error(w, errs.ForbiddenErr(err, "error signing ssh certificate")) + render.Error(w, r, errs.ForbiddenErr(err, "error signing ssh certificate")) return } addUserCertificate = &SSHCertificate{addUserCert} @@ -320,7 +320,7 @@ func SSHSign(w http.ResponseWriter, r *http.Request) { ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignIdentityMethod) signOpts, err := a.Authorize(ctx, body.OTT) if err != nil { - render.Error(w, errs.UnauthorizedErr(err)) + render.Error(w, r, errs.UnauthorizedErr(err)) return } @@ -332,14 +332,14 @@ func SSHSign(w http.ResponseWriter, r *http.Request) { certChain, err := a.SignWithContext(ctx, cr, provisioner.SignOptions{}, signOpts...) if err != nil { - render.Error(w, errs.ForbiddenErr(err, "error signing identity certificate")) + render.Error(w, r, errs.ForbiddenErr(err, "error signing identity certificate")) return } identityCertificate = certChainToPEM(certChain) } LogSSHCertificate(w, cert) - render.JSONStatus(w, &SSHSignResponse{ + render.JSONStatus(w, r, &SSHSignResponse{ Certificate: SSHCertificate{cert}, AddUserCertificate: addUserCertificate, IdentityCertificate: identityCertificate, @@ -352,12 +352,12 @@ func SSHRoots(w http.ResponseWriter, r *http.Request) { ctx := r.Context() keys, err := mustAuthority(ctx).GetSSHRoots(ctx) if err != nil { - render.Error(w, errs.InternalServerErr(err)) + render.Error(w, r, errs.InternalServerErr(err)) return } if len(keys.HostKeys) == 0 && len(keys.UserKeys) == 0 { - render.Error(w, errs.NotFound("no keys found")) + render.Error(w, r, errs.NotFound("no keys found")) return } @@ -369,7 +369,7 @@ func SSHRoots(w http.ResponseWriter, r *http.Request) { resp.UserKeys = append(resp.UserKeys, SSHPublicKey{PublicKey: k}) } - render.JSON(w, resp) + render.JSON(w, r, resp) } // SSHFederation is an HTTP handler that returns the federated SSH public keys @@ -378,12 +378,12 @@ func SSHFederation(w http.ResponseWriter, r *http.Request) { ctx := r.Context() keys, err := mustAuthority(ctx).GetSSHFederation(ctx) if err != nil { - render.Error(w, errs.InternalServerErr(err)) + render.Error(w, r, errs.InternalServerErr(err)) return } if len(keys.HostKeys) == 0 && len(keys.UserKeys) == 0 { - render.Error(w, errs.NotFound("no keys found")) + render.Error(w, r, errs.NotFound("no keys found")) return } @@ -395,7 +395,7 @@ func SSHFederation(w http.ResponseWriter, r *http.Request) { resp.UserKeys = append(resp.UserKeys, SSHPublicKey{PublicKey: k}) } - render.JSON(w, resp) + render.JSON(w, r, resp) } // SSHConfig is an HTTP handler that returns rendered templates for ssh clients @@ -403,18 +403,18 @@ func SSHFederation(w http.ResponseWriter, r *http.Request) { func SSHConfig(w http.ResponseWriter, r *http.Request) { var body SSHConfigRequest if err := read.JSON(r.Body, &body); err != nil { - render.Error(w, errs.BadRequestErr(err, "error reading request body")) + render.Error(w, r, errs.BadRequestErr(err, "error reading request body")) return } if err := body.Validate(); err != nil { - render.Error(w, err) + render.Error(w, r, err) return } ctx := r.Context() ts, err := mustAuthority(ctx).GetSSHConfig(ctx, body.Type, body.Data) if err != nil { - render.Error(w, errs.InternalServerErr(err)) + render.Error(w, r, errs.InternalServerErr(err)) return } @@ -425,32 +425,32 @@ func SSHConfig(w http.ResponseWriter, r *http.Request) { case provisioner.SSHHostCert: cfg.HostTemplates = ts default: - render.Error(w, errs.InternalServer("it should hot get here")) + render.Error(w, r, errs.InternalServer("it should hot get here")) return } - render.JSON(w, cfg) + render.JSON(w, r, cfg) } // SSHCheckHost is the HTTP handler that returns if a hosts certificate exists or not. func SSHCheckHost(w http.ResponseWriter, r *http.Request) { var body SSHCheckPrincipalRequest if err := read.JSON(r.Body, &body); err != nil { - render.Error(w, errs.BadRequestErr(err, "error reading request body")) + render.Error(w, r, errs.BadRequestErr(err, "error reading request body")) return } if err := body.Validate(); err != nil { - render.Error(w, err) + render.Error(w, r, err) return } ctx := r.Context() exists, err := mustAuthority(ctx).CheckSSHHost(ctx, body.Principal, body.Token) if err != nil { - render.Error(w, errs.InternalServerErr(err)) + render.Error(w, r, errs.InternalServerErr(err)) return } - render.JSON(w, &SSHCheckPrincipalResponse{ + render.JSON(w, r, &SSHCheckPrincipalResponse{ Exists: exists, }) } @@ -465,10 +465,10 @@ func SSHGetHosts(w http.ResponseWriter, r *http.Request) { ctx := r.Context() hosts, err := mustAuthority(ctx).GetSSHHosts(ctx, cert) if err != nil { - render.Error(w, errs.InternalServerErr(err)) + render.Error(w, r, errs.InternalServerErr(err)) return } - render.JSON(w, &SSHGetHostsResponse{ + render.JSON(w, r, &SSHGetHostsResponse{ Hosts: hosts, }) } @@ -477,22 +477,22 @@ func SSHGetHosts(w http.ResponseWriter, r *http.Request) { func SSHBastion(w http.ResponseWriter, r *http.Request) { var body SSHBastionRequest if err := read.JSON(r.Body, &body); err != nil { - render.Error(w, errs.BadRequestErr(err, "error reading request body")) + render.Error(w, r, errs.BadRequestErr(err, "error reading request body")) return } if err := body.Validate(); err != nil { - render.Error(w, err) + render.Error(w, r, err) return } ctx := r.Context() bastion, err := mustAuthority(ctx).GetSSHBastion(ctx, body.User, body.Hostname) if err != nil { - render.Error(w, errs.InternalServerErr(err)) + render.Error(w, r, errs.InternalServerErr(err)) return } - render.JSON(w, &SSHBastionResponse{ + render.JSON(w, r, &SSHBastionResponse{ Hostname: body.Hostname, Bastion: bastion, }) diff --git a/api/sshRekey.go b/api/sshRekey.go index 80fc6d87..0db4d4da 100644 --- a/api/sshRekey.go +++ b/api/sshRekey.go @@ -42,19 +42,19 @@ type SSHRekeyResponse struct { func SSHRekey(w http.ResponseWriter, r *http.Request) { var body SSHRekeyRequest if err := read.JSON(r.Body, &body); err != nil { - render.Error(w, errs.BadRequestErr(err, "error reading request body")) + render.Error(w, r, errs.BadRequestErr(err, "error reading request body")) return } logOtt(w, body.OTT) if err := body.Validate(); err != nil { - render.Error(w, err) + render.Error(w, r, err) return } publicKey, err := ssh.ParsePublicKey(body.PublicKey) if err != nil { - render.Error(w, errs.BadRequestErr(err, "error parsing publicKey")) + render.Error(w, r, errs.BadRequestErr(err, "error parsing publicKey")) return } @@ -64,18 +64,18 @@ func SSHRekey(w http.ResponseWriter, r *http.Request) { a := mustAuthority(ctx) signOpts, err := a.Authorize(ctx, body.OTT) if err != nil { - render.Error(w, errs.UnauthorizedErr(err)) + render.Error(w, r, errs.UnauthorizedErr(err)) return } oldCert, _, err := provisioner.ExtractSSHPOPCert(body.OTT) if err != nil { - render.Error(w, errs.InternalServerErr(err)) + render.Error(w, r, errs.InternalServerErr(err)) return } newCert, err := a.RekeySSH(ctx, oldCert, publicKey, signOpts...) if err != nil { - render.Error(w, errs.ForbiddenErr(err, "error rekeying ssh certificate")) + render.Error(w, r, errs.ForbiddenErr(err, "error rekeying ssh certificate")) return } @@ -85,12 +85,12 @@ func SSHRekey(w http.ResponseWriter, r *http.Request) { identity, err := renewIdentityCertificate(r, notBefore, notAfter) if err != nil { - render.Error(w, errs.ForbiddenErr(err, "error renewing identity certificate")) + render.Error(w, r, errs.ForbiddenErr(err, "error renewing identity certificate")) return } LogSSHCertificate(w, newCert) - render.JSONStatus(w, &SSHRekeyResponse{ + render.JSONStatus(w, r, &SSHRekeyResponse{ Certificate: SSHCertificate{newCert}, IdentityCertificate: identity, }, http.StatusCreated) diff --git a/api/sshRenew.go b/api/sshRenew.go index cd6d9bde..dea7cea7 100644 --- a/api/sshRenew.go +++ b/api/sshRenew.go @@ -40,13 +40,13 @@ type SSHRenewResponse struct { func SSHRenew(w http.ResponseWriter, r *http.Request) { var body SSHRenewRequest if err := read.JSON(r.Body, &body); err != nil { - render.Error(w, errs.BadRequestErr(err, "error reading request body")) + render.Error(w, r, errs.BadRequestErr(err, "error reading request body")) return } logOtt(w, body.OTT) if err := body.Validate(); err != nil { - render.Error(w, err) + render.Error(w, r, err) return } @@ -56,18 +56,18 @@ func SSHRenew(w http.ResponseWriter, r *http.Request) { a := mustAuthority(ctx) _, err := a.Authorize(ctx, body.OTT) if err != nil { - render.Error(w, errs.UnauthorizedErr(err)) + render.Error(w, r, errs.UnauthorizedErr(err)) return } oldCert, _, err := provisioner.ExtractSSHPOPCert(body.OTT) if err != nil { - render.Error(w, errs.InternalServerErr(err)) + render.Error(w, r, errs.InternalServerErr(err)) return } newCert, err := a.RenewSSH(ctx, oldCert) if err != nil { - render.Error(w, errs.ForbiddenErr(err, "error renewing ssh certificate")) + render.Error(w, r, errs.ForbiddenErr(err, "error renewing ssh certificate")) return } @@ -77,12 +77,12 @@ func SSHRenew(w http.ResponseWriter, r *http.Request) { identity, err := renewIdentityCertificate(r, notBefore, notAfter) if err != nil { - render.Error(w, errs.ForbiddenErr(err, "error renewing identity certificate")) + render.Error(w, r, errs.ForbiddenErr(err, "error renewing identity certificate")) return } LogSSHCertificate(w, newCert) - render.JSONStatus(w, &SSHSignResponse{ + render.JSONStatus(w, r, &SSHSignResponse{ Certificate: SSHCertificate{newCert}, IdentityCertificate: identity, }, http.StatusCreated) diff --git a/api/sshRevoke.go b/api/sshRevoke.go index d377def9..2fe49199 100644 --- a/api/sshRevoke.go +++ b/api/sshRevoke.go @@ -51,12 +51,12 @@ func (r *SSHRevokeRequest) Validate() (err error) { func SSHRevoke(w http.ResponseWriter, r *http.Request) { var body SSHRevokeRequest if err := read.JSON(r.Body, &body); err != nil { - render.Error(w, errs.BadRequestErr(err, "error reading request body")) + render.Error(w, r, errs.BadRequestErr(err, "error reading request body")) return } if err := body.Validate(); err != nil { - render.Error(w, err) + render.Error(w, r, err) return } @@ -75,18 +75,18 @@ func SSHRevoke(w http.ResponseWriter, r *http.Request) { logOtt(w, body.OTT) if _, err := a.Authorize(ctx, body.OTT); err != nil { - render.Error(w, errs.UnauthorizedErr(err)) + render.Error(w, r, errs.UnauthorizedErr(err)) return } opts.OTT = body.OTT if err := a.Revoke(ctx, opts); err != nil { - render.Error(w, errs.ForbiddenErr(err, "error revoking ssh certificate")) + render.Error(w, r, errs.ForbiddenErr(err, "error revoking ssh certificate")) return } logSSHRevoke(w, opts) - render.JSON(w, &SSHRevokeResponse{Status: "ok"}) + render.JSON(w, r, &SSHRevokeResponse{Status: "ok"}) } func logSSHRevoke(w http.ResponseWriter, ri *authority.RevokeOptions) { diff --git a/authority/admin/api/acme.go b/authority/admin/api/acme.go index 32f2bdcc..6fc70896 100644 --- a/authority/admin/api/acme.go +++ b/authority/admin/api/acme.go @@ -40,12 +40,12 @@ func requireEABEnabled(next http.HandlerFunc) http.HandlerFunc { acmeProvisioner := prov.GetDetails().GetACME() if acmeProvisioner == nil { - render.Error(w, admin.NewErrorISE("error getting ACME details for provisioner '%s'", prov.GetName())) + render.Error(w, r, admin.NewErrorISE("error getting ACME details for provisioner '%s'", prov.GetName())) return } if !acmeProvisioner.RequireEab { - render.Error(w, admin.NewError(admin.ErrorBadRequestType, "ACME EAB not enabled for provisioner '%s'", prov.GetName())) + render.Error(w, r, admin.NewError(admin.ErrorBadRequestType, "ACME EAB not enabled for provisioner '%s'", prov.GetName())) return } @@ -69,18 +69,18 @@ func NewACMEAdminResponder() ACMEAdminResponder { } // GetExternalAccountKeys writes the response for the EAB keys GET endpoint -func (h *acmeAdminResponder) GetExternalAccountKeys(w http.ResponseWriter, _ *http.Request) { - render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) +func (h *acmeAdminResponder) GetExternalAccountKeys(w http.ResponseWriter, r *http.Request) { + render.Error(w, r, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) } // CreateExternalAccountKey writes the response for the EAB key POST endpoint -func (h *acmeAdminResponder) CreateExternalAccountKey(w http.ResponseWriter, _ *http.Request) { - render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) +func (h *acmeAdminResponder) CreateExternalAccountKey(w http.ResponseWriter, r *http.Request) { + render.Error(w, r, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) } // DeleteExternalAccountKey writes the response for the EAB key DELETE endpoint -func (h *acmeAdminResponder) DeleteExternalAccountKey(w http.ResponseWriter, _ *http.Request) { - render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) +func (h *acmeAdminResponder) DeleteExternalAccountKey(w http.ResponseWriter, r *http.Request) { + render.Error(w, r, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) } func eakToLinked(k *acme.ExternalAccountKey) *linkedca.EABKey { diff --git a/authority/admin/api/admin.go b/authority/admin/api/admin.go index e4d9d9fe..ce22de05 100644 --- a/authority/admin/api/admin.go +++ b/authority/admin/api/admin.go @@ -90,7 +90,7 @@ func GetAdmin(w http.ResponseWriter, r *http.Request) { adm, ok := mustAuthority(r.Context()).LoadAdminByID(id) if !ok { - render.Error(w, admin.NewError(admin.ErrorNotFoundType, + render.Error(w, r, admin.NewError(admin.ErrorNotFoundType, "admin %s not found", id)) return } @@ -101,17 +101,17 @@ func GetAdmin(w http.ResponseWriter, r *http.Request) { func GetAdmins(w http.ResponseWriter, r *http.Request) { cursor, limit, err := api.ParseCursor(r) if err != nil { - render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, + render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error parsing cursor and limit from query params")) return } admins, nextCursor, err := mustAuthority(r.Context()).GetAdmins(cursor, limit) if err != nil { - render.Error(w, admin.WrapErrorISE(err, "error retrieving paginated admins")) + render.Error(w, r, admin.WrapErrorISE(err, "error retrieving paginated admins")) return } - render.JSON(w, &GetAdminsResponse{ + render.JSON(w, r, &GetAdminsResponse{ Admins: admins, NextCursor: nextCursor, }) @@ -121,19 +121,19 @@ func GetAdmins(w http.ResponseWriter, r *http.Request) { func CreateAdmin(w http.ResponseWriter, r *http.Request) { var body CreateAdminRequest if err := read.JSON(r.Body, &body); err != nil { - render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body")) + render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body")) return } if err := body.Validate(); err != nil { - render.Error(w, err) + render.Error(w, r, err) return } auth := mustAuthority(r.Context()) p, err := auth.LoadProvisionerByName(body.Provisioner) if err != nil { - render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", body.Provisioner)) + render.Error(w, r, admin.WrapErrorISE(err, "error loading provisioner %s", body.Provisioner)) return } adm := &linkedca.Admin{ @@ -143,7 +143,7 @@ func CreateAdmin(w http.ResponseWriter, r *http.Request) { } // Store to authority collection. if err := auth.StoreAdmin(r.Context(), adm, p); err != nil { - render.Error(w, admin.WrapErrorISE(err, "error storing admin")) + render.Error(w, r, admin.WrapErrorISE(err, "error storing admin")) return } @@ -155,23 +155,23 @@ func DeleteAdmin(w http.ResponseWriter, r *http.Request) { id := chi.URLParam(r, "id") if err := mustAuthority(r.Context()).RemoveAdmin(r.Context(), id); err != nil { - render.Error(w, admin.WrapErrorISE(err, "error deleting admin %s", id)) + render.Error(w, r, admin.WrapErrorISE(err, "error deleting admin %s", id)) return } - render.JSON(w, &DeleteResponse{Status: "ok"}) + render.JSON(w, r, &DeleteResponse{Status: "ok"}) } // UpdateAdmin updates an existing admin. func UpdateAdmin(w http.ResponseWriter, r *http.Request) { var body UpdateAdminRequest if err := read.JSON(r.Body, &body); err != nil { - render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body")) + render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body")) return } if err := body.Validate(); err != nil { - render.Error(w, err) + render.Error(w, r, err) return } @@ -179,7 +179,7 @@ func UpdateAdmin(w http.ResponseWriter, r *http.Request) { auth := mustAuthority(r.Context()) adm, err := auth.UpdateAdmin(r.Context(), id, &linkedca.Admin{Type: body.Type}) if err != nil { - render.Error(w, admin.WrapErrorISE(err, "error updating admin %s", id)) + render.Error(w, r, admin.WrapErrorISE(err, "error updating admin %s", id)) return } diff --git a/authority/admin/api/middleware.go b/authority/admin/api/middleware.go index a37b6074..68006b7f 100644 --- a/authority/admin/api/middleware.go +++ b/authority/admin/api/middleware.go @@ -19,7 +19,7 @@ import ( func requireAPIEnabled(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { if !mustAuthority(r.Context()).IsAdminAPIEnabled() { - render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "administration API not enabled")) + render.Error(w, r, admin.NewError(admin.ErrorNotImplementedType, "administration API not enabled")) return } next(w, r) @@ -31,7 +31,7 @@ func extractAuthorizeTokenAdmin(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { tok := r.Header.Get("Authorization") if tok == "" { - render.Error(w, admin.NewError(admin.ErrorUnauthorizedType, + render.Error(w, r, admin.NewError(admin.ErrorUnauthorizedType, "missing authorization header token")) return } @@ -39,7 +39,7 @@ func extractAuthorizeTokenAdmin(next http.HandlerFunc) http.HandlerFunc { ctx := r.Context() adm, err := mustAuthority(ctx).AuthorizeAdminToken(r, tok) if err != nil { - render.Error(w, err) + render.Error(w, r, err) return } @@ -64,13 +64,13 @@ func loadProvisionerByName(next http.HandlerFunc) http.HandlerFunc { // TODO(hs): distinguish 404 vs. 500 if p, err = auth.LoadProvisionerByName(name); err != nil { - render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name)) + render.Error(w, r, admin.WrapErrorISE(err, "error loading provisioner %s", name)) return } prov, err := adminDB.GetProvisioner(ctx, p.GetID()) if err != nil { - render.Error(w, admin.WrapErrorISE(err, "error retrieving provisioner %s", name)) + render.Error(w, r, admin.WrapErrorISE(err, "error retrieving provisioner %s", name)) return } @@ -91,7 +91,7 @@ func checkAction(next http.HandlerFunc, supportedInStandalone bool) http.Handler // when an action is not supported in standalone mode and when // using a nosql.DB backend, actions are not supported if _, ok := admin.MustFromContext(r.Context()).(*nosql.DB); ok { - render.Error(w, admin.NewError(admin.ErrorNotImplementedType, + render.Error(w, r, admin.NewError(admin.ErrorNotImplementedType, "operation not supported in standalone mode")) return } @@ -125,15 +125,15 @@ func loadExternalAccountKey(next http.HandlerFunc) http.HandlerFunc { if err != nil { if acme.IsErrNotFound(err) { - render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found")) + render.Error(w, r, admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found")) return } - render.Error(w, admin.WrapErrorISE(err, "error retrieving ACME External Account Key")) + render.Error(w, r, admin.WrapErrorISE(err, "error retrieving ACME External Account Key")) return } if eak == nil { - render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found")) + render.Error(w, r, admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found")) return } diff --git a/authority/admin/api/policy.go b/authority/admin/api/policy.go index c45bc947..d2d8183f 100644 --- a/authority/admin/api/policy.go +++ b/authority/admin/api/policy.go @@ -44,7 +44,7 @@ func NewPolicyAdminResponder() PolicyAdminResponder { func (par *policyAdminResponder) GetAuthorityPolicy(w http.ResponseWriter, r *http.Request) { ctx := r.Context() if err := blockLinkedCA(ctx); err != nil { - render.Error(w, err) + render.Error(w, r, err) return } @@ -52,12 +52,12 @@ func (par *policyAdminResponder) GetAuthorityPolicy(w http.ResponseWriter, r *ht authorityPolicy, err := auth.GetAuthorityPolicy(r.Context()) var ae *admin.Error if errors.As(err, &ae) && !ae.IsType(admin.ErrorNotFoundType) { - render.Error(w, admin.WrapErrorISE(ae, "error retrieving authority policy")) + render.Error(w, r, admin.WrapErrorISE(ae, "error retrieving authority policy")) return } if authorityPolicy == nil { - render.Error(w, admin.NewError(admin.ErrorNotFoundType, "authority policy does not exist")) + render.Error(w, r, admin.NewError(admin.ErrorNotFoundType, "authority policy does not exist")) return } @@ -68,7 +68,7 @@ func (par *policyAdminResponder) GetAuthorityPolicy(w http.ResponseWriter, r *ht func (par *policyAdminResponder) CreateAuthorityPolicy(w http.ResponseWriter, r *http.Request) { ctx := r.Context() if err := blockLinkedCA(ctx); err != nil { - render.Error(w, err) + render.Error(w, r, err) return } @@ -77,26 +77,26 @@ func (par *policyAdminResponder) CreateAuthorityPolicy(w http.ResponseWriter, r var ae *admin.Error if errors.As(err, &ae) && !ae.IsType(admin.ErrorNotFoundType) { - render.Error(w, admin.WrapErrorISE(err, "error retrieving authority policy")) + render.Error(w, r, admin.WrapErrorISE(err, "error retrieving authority policy")) return } if authorityPolicy != nil { adminErr := admin.NewError(admin.ErrorConflictType, "authority already has a policy") - render.Error(w, adminErr) + render.Error(w, r, adminErr) return } var newPolicy = new(linkedca.Policy) if err := read.ProtoJSON(r.Body, newPolicy); err != nil { - render.Error(w, err) + render.Error(w, r, err) return } newPolicy.Deduplicate() if err := validatePolicy(newPolicy); err != nil { - render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating authority policy")) + render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error validating authority policy")) return } @@ -105,11 +105,11 @@ func (par *policyAdminResponder) CreateAuthorityPolicy(w http.ResponseWriter, r var createdPolicy *linkedca.Policy if createdPolicy, err = auth.CreateAuthorityPolicy(ctx, adm, newPolicy); err != nil { if isBadRequest(err) { - render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error storing authority policy")) + render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error storing authority policy")) return } - render.Error(w, admin.WrapErrorISE(err, "error storing authority policy")) + render.Error(w, r, admin.WrapErrorISE(err, "error storing authority policy")) return } @@ -120,7 +120,7 @@ func (par *policyAdminResponder) CreateAuthorityPolicy(w http.ResponseWriter, r func (par *policyAdminResponder) UpdateAuthorityPolicy(w http.ResponseWriter, r *http.Request) { ctx := r.Context() if err := blockLinkedCA(ctx); err != nil { - render.Error(w, err) + render.Error(w, r, err) return } @@ -129,25 +129,25 @@ func (par *policyAdminResponder) UpdateAuthorityPolicy(w http.ResponseWriter, r var ae *admin.Error if errors.As(err, &ae) && !ae.IsType(admin.ErrorNotFoundType) { - render.Error(w, admin.WrapErrorISE(err, "error retrieving authority policy")) + render.Error(w, r, admin.WrapErrorISE(err, "error retrieving authority policy")) return } if authorityPolicy == nil { - render.Error(w, admin.NewError(admin.ErrorNotFoundType, "authority policy does not exist")) + render.Error(w, r, admin.NewError(admin.ErrorNotFoundType, "authority policy does not exist")) return } var newPolicy = new(linkedca.Policy) if err := read.ProtoJSON(r.Body, newPolicy); err != nil { - render.Error(w, err) + render.Error(w, r, err) return } newPolicy.Deduplicate() if err := validatePolicy(newPolicy); err != nil { - render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating authority policy")) + render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error validating authority policy")) return } @@ -156,11 +156,11 @@ func (par *policyAdminResponder) UpdateAuthorityPolicy(w http.ResponseWriter, r var updatedPolicy *linkedca.Policy if updatedPolicy, err = auth.UpdateAuthorityPolicy(ctx, adm, newPolicy); err != nil { if isBadRequest(err) { - render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error updating authority policy")) + render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error updating authority policy")) return } - render.Error(w, admin.WrapErrorISE(err, "error updating authority policy")) + render.Error(w, r, admin.WrapErrorISE(err, "error updating authority policy")) return } @@ -171,7 +171,7 @@ func (par *policyAdminResponder) UpdateAuthorityPolicy(w http.ResponseWriter, r func (par *policyAdminResponder) DeleteAuthorityPolicy(w http.ResponseWriter, r *http.Request) { ctx := r.Context() if err := blockLinkedCA(ctx); err != nil { - render.Error(w, err) + render.Error(w, r, err) return } @@ -180,35 +180,35 @@ func (par *policyAdminResponder) DeleteAuthorityPolicy(w http.ResponseWriter, r var ae *admin.Error if errors.As(err, &ae) && !ae.IsType(admin.ErrorNotFoundType) { - render.Error(w, admin.WrapErrorISE(ae, "error retrieving authority policy")) + render.Error(w, r, admin.WrapErrorISE(ae, "error retrieving authority policy")) return } if authorityPolicy == nil { - render.Error(w, admin.NewError(admin.ErrorNotFoundType, "authority policy does not exist")) + render.Error(w, r, admin.NewError(admin.ErrorNotFoundType, "authority policy does not exist")) return } if err := auth.RemoveAuthorityPolicy(ctx); err != nil { - render.Error(w, admin.WrapErrorISE(err, "error deleting authority policy")) + render.Error(w, r, admin.WrapErrorISE(err, "error deleting authority policy")) return } - render.JSONStatus(w, DeleteResponse{Status: "ok"}, http.StatusOK) + render.JSONStatus(w, r, DeleteResponse{Status: "ok"}, http.StatusOK) } // GetProvisionerPolicy handles the GET /admin/provisioners/{name}/policy request func (par *policyAdminResponder) GetProvisionerPolicy(w http.ResponseWriter, r *http.Request) { ctx := r.Context() if err := blockLinkedCA(ctx); err != nil { - render.Error(w, err) + render.Error(w, r, err) return } prov := linkedca.MustProvisionerFromContext(ctx) provisionerPolicy := prov.GetPolicy() if provisionerPolicy == nil { - render.Error(w, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist")) + render.Error(w, r, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist")) return } @@ -219,7 +219,7 @@ func (par *policyAdminResponder) GetProvisionerPolicy(w http.ResponseWriter, r * func (par *policyAdminResponder) CreateProvisionerPolicy(w http.ResponseWriter, r *http.Request) { ctx := r.Context() if err := blockLinkedCA(ctx); err != nil { - render.Error(w, err) + render.Error(w, r, err) return } @@ -227,20 +227,20 @@ func (par *policyAdminResponder) CreateProvisionerPolicy(w http.ResponseWriter, provisionerPolicy := prov.GetPolicy() if provisionerPolicy != nil { adminErr := admin.NewError(admin.ErrorConflictType, "provisioner %s already has a policy", prov.Name) - render.Error(w, adminErr) + render.Error(w, r, adminErr) return } var newPolicy = new(linkedca.Policy) if err := read.ProtoJSON(r.Body, newPolicy); err != nil { - render.Error(w, err) + render.Error(w, r, err) return } newPolicy.Deduplicate() if err := validatePolicy(newPolicy); err != nil { - render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating provisioner policy")) + render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error validating provisioner policy")) return } @@ -248,11 +248,11 @@ func (par *policyAdminResponder) CreateProvisionerPolicy(w http.ResponseWriter, auth := mustAuthority(ctx) if err := auth.UpdateProvisioner(ctx, prov); err != nil { if isBadRequest(err) { - render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error creating provisioner policy")) + render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error creating provisioner policy")) return } - render.Error(w, admin.WrapErrorISE(err, "error creating provisioner policy")) + render.Error(w, r, admin.WrapErrorISE(err, "error creating provisioner policy")) return } @@ -263,27 +263,27 @@ func (par *policyAdminResponder) CreateProvisionerPolicy(w http.ResponseWriter, func (par *policyAdminResponder) UpdateProvisionerPolicy(w http.ResponseWriter, r *http.Request) { ctx := r.Context() if err := blockLinkedCA(ctx); err != nil { - render.Error(w, err) + render.Error(w, r, err) return } prov := linkedca.MustProvisionerFromContext(ctx) provisionerPolicy := prov.GetPolicy() if provisionerPolicy == nil { - render.Error(w, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist")) + render.Error(w, r, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist")) return } var newPolicy = new(linkedca.Policy) if err := read.ProtoJSON(r.Body, newPolicy); err != nil { - render.Error(w, err) + render.Error(w, r, err) return } newPolicy.Deduplicate() if err := validatePolicy(newPolicy); err != nil { - render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating provisioner policy")) + render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error validating provisioner policy")) return } @@ -291,11 +291,11 @@ func (par *policyAdminResponder) UpdateProvisionerPolicy(w http.ResponseWriter, auth := mustAuthority(ctx) if err := auth.UpdateProvisioner(ctx, prov); err != nil { if isBadRequest(err) { - render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error updating provisioner policy")) + render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error updating provisioner policy")) return } - render.Error(w, admin.WrapErrorISE(err, "error updating provisioner policy")) + render.Error(w, r, admin.WrapErrorISE(err, "error updating provisioner policy")) return } @@ -306,13 +306,13 @@ func (par *policyAdminResponder) UpdateProvisionerPolicy(w http.ResponseWriter, func (par *policyAdminResponder) DeleteProvisionerPolicy(w http.ResponseWriter, r *http.Request) { ctx := r.Context() if err := blockLinkedCA(ctx); err != nil { - render.Error(w, err) + render.Error(w, r, err) return } prov := linkedca.MustProvisionerFromContext(ctx) if prov.Policy == nil { - render.Error(w, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist")) + render.Error(w, r, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist")) return } @@ -321,24 +321,24 @@ func (par *policyAdminResponder) DeleteProvisionerPolicy(w http.ResponseWriter, auth := mustAuthority(ctx) if err := auth.UpdateProvisioner(ctx, prov); err != nil { - render.Error(w, admin.WrapErrorISE(err, "error deleting provisioner policy")) + render.Error(w, r, admin.WrapErrorISE(err, "error deleting provisioner policy")) return } - render.JSONStatus(w, DeleteResponse{Status: "ok"}, http.StatusOK) + render.JSONStatus(w, r, DeleteResponse{Status: "ok"}, http.StatusOK) } func (par *policyAdminResponder) GetACMEAccountPolicy(w http.ResponseWriter, r *http.Request) { ctx := r.Context() if err := blockLinkedCA(ctx); err != nil { - render.Error(w, err) + render.Error(w, r, err) return } eak := linkedca.MustExternalAccountKeyFromContext(ctx) eakPolicy := eak.GetPolicy() if eakPolicy == nil { - render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist")) + render.Error(w, r, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist")) return } @@ -348,7 +348,7 @@ func (par *policyAdminResponder) GetACMEAccountPolicy(w http.ResponseWriter, r * func (par *policyAdminResponder) CreateACMEAccountPolicy(w http.ResponseWriter, r *http.Request) { ctx := r.Context() if err := blockLinkedCA(ctx); err != nil { - render.Error(w, err) + render.Error(w, r, err) return } @@ -357,20 +357,20 @@ func (par *policyAdminResponder) CreateACMEAccountPolicy(w http.ResponseWriter, eakPolicy := eak.GetPolicy() if eakPolicy != nil { adminErr := admin.NewError(admin.ErrorConflictType, "ACME EAK %s already has a policy", eak.Id) - render.Error(w, adminErr) + render.Error(w, r, adminErr) return } var newPolicy = new(linkedca.Policy) if err := read.ProtoJSON(r.Body, newPolicy); err != nil { - render.Error(w, err) + render.Error(w, r, err) return } newPolicy.Deduplicate() if err := validatePolicy(newPolicy); err != nil { - render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating ACME EAK policy")) + render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error validating ACME EAK policy")) return } @@ -379,7 +379,7 @@ func (par *policyAdminResponder) CreateACMEAccountPolicy(w http.ResponseWriter, acmeEAK := linkedEAKToCertificates(eak) acmeDB := acme.MustDatabaseFromContext(ctx) if err := acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil { - render.Error(w, admin.WrapErrorISE(err, "error creating ACME EAK policy")) + render.Error(w, r, admin.WrapErrorISE(err, "error creating ACME EAK policy")) return } @@ -389,7 +389,7 @@ func (par *policyAdminResponder) CreateACMEAccountPolicy(w http.ResponseWriter, func (par *policyAdminResponder) UpdateACMEAccountPolicy(w http.ResponseWriter, r *http.Request) { ctx := r.Context() if err := blockLinkedCA(ctx); err != nil { - render.Error(w, err) + render.Error(w, r, err) return } @@ -397,20 +397,20 @@ func (par *policyAdminResponder) UpdateACMEAccountPolicy(w http.ResponseWriter, eak := linkedca.MustExternalAccountKeyFromContext(ctx) eakPolicy := eak.GetPolicy() if eakPolicy == nil { - render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist")) + render.Error(w, r, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist")) return } var newPolicy = new(linkedca.Policy) if err := read.ProtoJSON(r.Body, newPolicy); err != nil { - render.Error(w, err) + render.Error(w, r, err) return } newPolicy.Deduplicate() if err := validatePolicy(newPolicy); err != nil { - render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating ACME EAK policy")) + render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error validating ACME EAK policy")) return } @@ -418,7 +418,7 @@ func (par *policyAdminResponder) UpdateACMEAccountPolicy(w http.ResponseWriter, acmeEAK := linkedEAKToCertificates(eak) acmeDB := acme.MustDatabaseFromContext(ctx) if err := acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil { - render.Error(w, admin.WrapErrorISE(err, "error updating ACME EAK policy")) + render.Error(w, r, admin.WrapErrorISE(err, "error updating ACME EAK policy")) return } @@ -428,7 +428,7 @@ func (par *policyAdminResponder) UpdateACMEAccountPolicy(w http.ResponseWriter, func (par *policyAdminResponder) DeleteACMEAccountPolicy(w http.ResponseWriter, r *http.Request) { ctx := r.Context() if err := blockLinkedCA(ctx); err != nil { - render.Error(w, err) + render.Error(w, r, err) return } @@ -436,7 +436,7 @@ func (par *policyAdminResponder) DeleteACMEAccountPolicy(w http.ResponseWriter, eak := linkedca.MustExternalAccountKeyFromContext(ctx) eakPolicy := eak.GetPolicy() if eakPolicy == nil { - render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist")) + render.Error(w, r, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist")) return } @@ -446,11 +446,11 @@ func (par *policyAdminResponder) DeleteACMEAccountPolicy(w http.ResponseWriter, acmeEAK := linkedEAKToCertificates(eak) acmeDB := acme.MustDatabaseFromContext(ctx) if err := acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil { - render.Error(w, admin.WrapErrorISE(err, "error deleting ACME EAK policy")) + render.Error(w, r, admin.WrapErrorISE(err, "error deleting ACME EAK policy")) return } - render.JSONStatus(w, DeleteResponse{Status: "ok"}, http.StatusOK) + render.JSONStatus(w, r, DeleteResponse{Status: "ok"}, http.StatusOK) } // blockLinkedCA blocks all API operations on linked deployments diff --git a/authority/admin/api/provisioner.go b/authority/admin/api/provisioner.go index 709399dd..b2a59cfa 100644 --- a/authority/admin/api/provisioner.go +++ b/authority/admin/api/provisioner.go @@ -40,19 +40,19 @@ func GetProvisioner(w http.ResponseWriter, r *http.Request) { if id != "" { if p, err = auth.LoadProvisionerByID(id); err != nil { - render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", id)) + render.Error(w, r, admin.WrapErrorISE(err, "error loading provisioner %s", id)) return } } else { if p, err = auth.LoadProvisionerByName(name); err != nil { - render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name)) + render.Error(w, r, admin.WrapErrorISE(err, "error loading provisioner %s", name)) return } } prov, err := db.GetProvisioner(ctx, p.GetID()) if err != nil { - render.Error(w, err) + render.Error(w, r, err) return } render.ProtoJSON(w, prov) @@ -62,17 +62,17 @@ func GetProvisioner(w http.ResponseWriter, r *http.Request) { func GetProvisioners(w http.ResponseWriter, r *http.Request) { cursor, limit, err := api.ParseCursor(r) if err != nil { - render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, + render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error parsing cursor and limit from query params")) return } p, next, err := mustAuthority(r.Context()).GetProvisioners(cursor, limit) if err != nil { - render.Error(w, errs.InternalServerErr(err)) + render.Error(w, r, errs.InternalServerErr(err)) return } - render.JSON(w, &GetProvisionersResponse{ + render.JSON(w, r, &GetProvisionersResponse{ Provisioners: p, NextCursor: next, }) @@ -82,24 +82,24 @@ func GetProvisioners(w http.ResponseWriter, r *http.Request) { func CreateProvisioner(w http.ResponseWriter, r *http.Request) { var prov = new(linkedca.Provisioner) if err := read.ProtoJSON(r.Body, prov); err != nil { - render.Error(w, err) + render.Error(w, r, err) return } // TODO: Validate inputs if err := authority.ValidateClaims(prov.Claims); err != nil { - render.Error(w, err) + render.Error(w, r, err) return } // validate the templates and template data if err := validateTemplates(prov.X509Template, prov.SshTemplate); err != nil { - render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "invalid template")) + render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "invalid template")) return } if err := mustAuthority(r.Context()).StoreProvisioner(r.Context(), prov); err != nil { - render.Error(w, admin.WrapErrorISE(err, "error storing provisioner %s", prov.Name)) + render.Error(w, r, admin.WrapErrorISE(err, "error storing provisioner %s", prov.Name)) return } render.ProtoJSONStatus(w, prov, http.StatusCreated) @@ -118,29 +118,29 @@ func DeleteProvisioner(w http.ResponseWriter, r *http.Request) { if id != "" { if p, err = auth.LoadProvisionerByID(id); err != nil { - render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", id)) + render.Error(w, r, admin.WrapErrorISE(err, "error loading provisioner %s", id)) return } } else { if p, err = auth.LoadProvisionerByName(name); err != nil { - render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name)) + render.Error(w, r, admin.WrapErrorISE(err, "error loading provisioner %s", name)) return } } if err := auth.RemoveProvisioner(r.Context(), p.GetID()); err != nil { - render.Error(w, admin.WrapErrorISE(err, "error removing provisioner %s", p.GetName())) + render.Error(w, r, admin.WrapErrorISE(err, "error removing provisioner %s", p.GetName())) return } - render.JSON(w, &DeleteResponse{Status: "ok"}) + render.JSON(w, r, &DeleteResponse{Status: "ok"}) } // UpdateProvisioner updates an existing prov. func UpdateProvisioner(w http.ResponseWriter, r *http.Request) { var nu = new(linkedca.Provisioner) if err := read.ProtoJSON(r.Body, nu); err != nil { - render.Error(w, err) + render.Error(w, r, err) return } @@ -151,51 +151,51 @@ func UpdateProvisioner(w http.ResponseWriter, r *http.Request) { p, err := auth.LoadProvisionerByName(name) if err != nil { - render.Error(w, admin.WrapErrorISE(err, "error loading provisioner from cached configuration '%s'", name)) + render.Error(w, r, admin.WrapErrorISE(err, "error loading provisioner from cached configuration '%s'", name)) return } old, err := db.GetProvisioner(r.Context(), p.GetID()) if err != nil { - render.Error(w, admin.WrapErrorISE(err, "error loading provisioner from db '%s'", p.GetID())) + render.Error(w, r, admin.WrapErrorISE(err, "error loading provisioner from db '%s'", p.GetID())) return } if nu.Id != old.Id { - render.Error(w, admin.NewErrorISE("cannot change provisioner ID")) + render.Error(w, r, admin.NewErrorISE("cannot change provisioner ID")) return } if nu.Type != old.Type { - render.Error(w, admin.NewErrorISE("cannot change provisioner type")) + render.Error(w, r, admin.NewErrorISE("cannot change provisioner type")) return } if nu.AuthorityId != old.AuthorityId { - render.Error(w, admin.NewErrorISE("cannot change provisioner authorityID")) + render.Error(w, r, admin.NewErrorISE("cannot change provisioner authorityID")) return } if !nu.CreatedAt.AsTime().Equal(old.CreatedAt.AsTime()) { - render.Error(w, admin.NewErrorISE("cannot change provisioner createdAt")) + render.Error(w, r, admin.NewErrorISE("cannot change provisioner createdAt")) return } if !nu.DeletedAt.AsTime().Equal(old.DeletedAt.AsTime()) { - render.Error(w, admin.NewErrorISE("cannot change provisioner deletedAt")) + render.Error(w, r, admin.NewErrorISE("cannot change provisioner deletedAt")) return } // TODO: Validate inputs if err := authority.ValidateClaims(nu.Claims); err != nil { - render.Error(w, err) + render.Error(w, r, err) return } // validate the templates and template data if err := validateTemplates(nu.X509Template, nu.SshTemplate); err != nil { - render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "invalid template")) + render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "invalid template")) return } if err := auth.UpdateProvisioner(r.Context(), nu); err != nil { - render.Error(w, err) + render.Error(w, r, err) return } render.ProtoJSON(w, nu) diff --git a/authority/admin/api/webhook.go b/authority/admin/api/webhook.go index f01ddb65..04255e15 100644 --- a/authority/admin/api/webhook.go +++ b/authority/admin/api/webhook.go @@ -71,28 +71,28 @@ func (war *webhookAdminResponder) CreateProvisionerWebhook(w http.ResponseWriter var newWebhook = new(linkedca.Webhook) if err := read.ProtoJSON(r.Body, newWebhook); err != nil { - render.Error(w, err) + render.Error(w, r, err) return } if err := validateWebhook(newWebhook); err != nil { - render.Error(w, err) + render.Error(w, r, err) return } if newWebhook.Secret != "" { err := admin.NewError(admin.ErrorBadRequestType, "webhook secret must not be set") - render.Error(w, err) + render.Error(w, r, err) return } if newWebhook.Id != "" { err := admin.NewError(admin.ErrorBadRequestType, "webhook ID must not be set") - render.Error(w, err) + render.Error(w, r, err) return } id, err := randutil.UUIDv4() if err != nil { - render.Error(w, admin.WrapErrorISE(err, "error generating webhook id")) + render.Error(w, r, admin.WrapErrorISE(err, "error generating webhook id")) return } newWebhook.Id = id @@ -101,14 +101,14 @@ func (war *webhookAdminResponder) CreateProvisionerWebhook(w http.ResponseWriter for _, wh := range prov.Webhooks { if wh.Name == newWebhook.Name { err := admin.NewError(admin.ErrorConflictType, "provisioner %q already has a webhook with the name %q", prov.Name, newWebhook.Name) - render.Error(w, err) + render.Error(w, r, err) return } } secret, err := randutil.Bytes(64) if err != nil { - render.Error(w, admin.WrapErrorISE(err, "error generating webhook secret")) + render.Error(w, r, admin.WrapErrorISE(err, "error generating webhook secret")) return } newWebhook.Secret = base64.StdEncoding.EncodeToString(secret) @@ -117,11 +117,11 @@ func (war *webhookAdminResponder) CreateProvisionerWebhook(w http.ResponseWriter if err := auth.UpdateProvisioner(ctx, prov); err != nil { if isBadRequest(err) { - render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error creating provisioner webhook")) + render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error creating provisioner webhook")) return } - render.Error(w, admin.WrapErrorISE(err, "error creating provisioner webhook")) + render.Error(w, r, admin.WrapErrorISE(err, "error creating provisioner webhook")) return } @@ -145,21 +145,21 @@ func (war *webhookAdminResponder) DeleteProvisionerWebhook(w http.ResponseWriter } } if !found { - render.JSONStatus(w, DeleteResponse{Status: "ok"}, http.StatusOK) + render.JSONStatus(w, r, DeleteResponse{Status: "ok"}, http.StatusOK) return } if err := auth.UpdateProvisioner(ctx, prov); err != nil { if isBadRequest(err) { - render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error deleting provisioner webhook")) + render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error deleting provisioner webhook")) return } - render.Error(w, admin.WrapErrorISE(err, "error deleting provisioner webhook")) + render.Error(w, r, admin.WrapErrorISE(err, "error deleting provisioner webhook")) return } - render.JSONStatus(w, DeleteResponse{Status: "ok"}, http.StatusOK) + render.JSONStatus(w, r, DeleteResponse{Status: "ok"}, http.StatusOK) } func (war *webhookAdminResponder) UpdateProvisionerWebhook(w http.ResponseWriter, r *http.Request) { @@ -170,12 +170,12 @@ func (war *webhookAdminResponder) UpdateProvisionerWebhook(w http.ResponseWriter var newWebhook = new(linkedca.Webhook) if err := read.ProtoJSON(r.Body, newWebhook); err != nil { - render.Error(w, err) + render.Error(w, r, err) return } if err := validateWebhook(newWebhook); err != nil { - render.Error(w, err) + render.Error(w, r, err) return } @@ -186,13 +186,13 @@ func (war *webhookAdminResponder) UpdateProvisionerWebhook(w http.ResponseWriter } if newWebhook.Secret != "" && newWebhook.Secret != wh.Secret { err := admin.NewError(admin.ErrorBadRequestType, "webhook secret cannot be updated") - render.Error(w, err) + render.Error(w, r, err) return } newWebhook.Secret = wh.Secret if newWebhook.Id != "" && newWebhook.Id != wh.Id { err := admin.NewError(admin.ErrorBadRequestType, "webhook ID cannot be updated") - render.Error(w, err) + render.Error(w, r, err) return } newWebhook.Id = wh.Id @@ -203,17 +203,17 @@ func (war *webhookAdminResponder) UpdateProvisionerWebhook(w http.ResponseWriter if !found { msg := fmt.Sprintf("provisioner %q has no webhook with the name %q", prov.Name, newWebhook.Name) err := admin.NewError(admin.ErrorNotFoundType, msg) - render.Error(w, err) + render.Error(w, r, err) return } if err := auth.UpdateProvisioner(ctx, prov); err != nil { if isBadRequest(err) { - render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error updating provisioner webhook")) + render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error updating provisioner webhook")) return } - render.Error(w, admin.WrapErrorISE(err, "error updating provisioner webhook")) + render.Error(w, r, admin.WrapErrorISE(err, "error updating provisioner webhook")) return } diff --git a/authority/admin/errors.go b/authority/admin/errors.go index c729c8b2..a14e2dee 100644 --- a/authority/admin/errors.go +++ b/authority/admin/errors.go @@ -205,8 +205,8 @@ func (e *Error) ToLog() (interface{}, error) { } // Render implements render.RenderableError for Error. -func (e *Error) Render(w http.ResponseWriter) { +func (e *Error) Render(w http.ResponseWriter, r *http.Request) { e.Message = e.Err.Error() - render.JSONStatus(w, e, e.StatusCode()) + render.JSONStatus(w, r, e, e.StatusCode()) } diff --git a/authority/root.go b/authority/root.go index 37038cfa..0a6ee639 100644 --- a/authority/root.go +++ b/authority/root.go @@ -57,3 +57,26 @@ func (a *Authority) GetFederation() (federation []*x509.Certificate, err error) }) return } + +// GetIntermediateCertificate return the intermediate certificate that issues +// the leaf certificates in the CA. +// +// This method can return nil if the CA is configured with a Certificate +// Authority Service (CAS) that does not implement the +// CertificateAuthorityGetter interface. +func (a *Authority) GetIntermediateCertificate() *x509.Certificate { + if len(a.intermediateX509Certs) > 0 { + return a.intermediateX509Certs[0] + } + return nil +} + +// GetIntermediateCertificates returns a list of all intermediate certificates +// configured. The first certificate in the list will be the issuer certificate. +// +// This method can return an empty list or nil if the CA is configured with a +// Certificate Authority Service (CAS) that does not implement the +// CertificateAuthorityGetter interface. +func (a *Authority) GetIntermediateCertificates() []*x509.Certificate { + return a.intermediateX509Certs +} diff --git a/authority/root_test.go b/authority/root_test.go index e570b0be..a0811dd2 100644 --- a/authority/root_test.go +++ b/authority/root_test.go @@ -2,15 +2,18 @@ package authority import ( "crypto/x509" + "crypto/x509/pkix" "errors" "net/http" "reflect" "testing" - "go.step.sm/crypto/pemutil" - "github.com/smallstep/assert" "github.com/smallstep/certificates/api/render" + "github.com/stretchr/testify/require" + "go.step.sm/crypto/keyutil" + "go.step.sm/crypto/minica" + "go.step.sm/crypto/pemutil" ) func TestRoot(t *testing.T) { @@ -152,3 +155,63 @@ func TestAuthority_GetFederation(t *testing.T) { }) } } + +func TestAuthority_GetIntermediateCertificate(t *testing.T) { + ca, err := minica.New(minica.WithRootTemplate(`{ + "subject": {{ toJson .Subject }}, + "issuer": {{ toJson .Subject }}, + "keyUsage": ["certSign", "crlSign"], + "basicConstraints": { + "isCA": true, + "maxPathLen": -1 + } + }`), minica.WithIntermediateTemplate(`{ + "subject": {{ toJson .Subject }}, + "keyUsage": ["certSign", "crlSign"], + "basicConstraints": { + "isCA": true, + "maxPathLen": 1 + } + }`)) + require.NoError(t, err) + + signer, err := keyutil.GenerateDefaultSigner() + require.NoError(t, err) + + cert, err := ca.Sign(&x509.Certificate{ + Subject: pkix.Name{CommonName: "MiniCA Intermediate CA 0"}, + PublicKey: signer.Public(), + BasicConstraintsValid: true, + IsCA: true, + MaxPathLen: 0, + }) + require.NoError(t, err) + + type fields struct { + intermediateX509Certs []*x509.Certificate + } + tests := []struct { + name string + fields fields + want *x509.Certificate + wantSlice []*x509.Certificate + }{ + {"ok one", fields{[]*x509.Certificate{ca.Intermediate}}, ca.Intermediate, []*x509.Certificate{ca.Intermediate}}, + {"ok multiple", fields{[]*x509.Certificate{cert, ca.Intermediate}}, cert, []*x509.Certificate{cert, ca.Intermediate}}, + {"ok empty", fields{[]*x509.Certificate{}}, nil, []*x509.Certificate{}}, + {"ok nil", fields{nil}, nil, nil}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &Authority{ + intermediateX509Certs: tt.fields.intermediateX509Certs, + } + if got := a.GetIntermediateCertificate(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("Authority.GetIntermediateCertificate() = %v, want %v", got, tt.want) + } + if got := a.GetIntermediateCertificates(); !reflect.DeepEqual(got, tt.wantSlice) { + t.Errorf("Authority.GetIntermediateCertificates() = %v, want %v", got, tt.wantSlice) + } + }) + } +} diff --git a/ca/acmeClient_test.go b/ca/acmeClient_test.go index 77d380f9..b2c6b41a 100644 --- a/ca/acmeClient_test.go +++ b/ca/acmeClient_test.go @@ -108,19 +108,19 @@ func TestNewACMEClient(t *testing.T) { tc := run(t) i := 0 - srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - assert.Equals(t, "step-http-client/1.0", req.Header.Get("User-Agent")) // check default User-Agent header + srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equals(t, "step-http-client/1.0", r.Header.Get("User-Agent")) // check default User-Agent header switch { case i == 0: - render.JSONStatus(w, tc.r1, tc.rc1) + render.JSONStatus(w, r, tc.r1, tc.rc1) i++ case i == 1: w.Header().Set("Replay-Nonce", "abc123") - render.JSONStatus(w, []byte{}, 200) + render.JSONStatus(w, r, []byte{}, 200) i++ default: w.Header().Set("Location", accLocation) - render.JSONStatus(w, tc.r2, tc.rc2) + render.JSONStatus(w, r, tc.r2, tc.rc2) } }) @@ -203,10 +203,10 @@ func TestACMEClient_GetNonce(t *testing.T) { t.Run(name, func(t *testing.T) { tc := run(t) - srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - assert.Equals(t, "step-http-client/1.0", req.Header.Get("User-Agent")) // check default User-Agent header + srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equals(t, "step-http-client/1.0", r.Header.Get("User-Agent")) // check default User-Agent header w.Header().Set("Replay-Nonce", expectedNonce) - render.JSONStatus(w, tc.r1, tc.rc1) + render.JSONStatus(w, r, tc.r1, tc.rc1) }) if nonce, err := ac.GetNonce(); err != nil { @@ -310,18 +310,18 @@ func TestACMEClient_post(t *testing.T) { tc := run(t) i := 0 - srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - assert.Equals(t, "step-http-client/1.0", req.Header.Get("User-Agent")) // check default User-Agent header + srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equals(t, "step-http-client/1.0", r.Header.Get("User-Agent")) // check default User-Agent header w.Header().Set("Replay-Nonce", expectedNonce) if i == 0 { - render.JSONStatus(w, tc.r1, tc.rc1) + render.JSONStatus(w, r, tc.r1, tc.rc1) i++ return } // validate jws request protected headers and body - body, err := io.ReadAll(req.Body) + body, err := io.ReadAll(r.Body) assert.FatalError(t, err) jws, err := jose.ParseJWS(string(body)) assert.FatalError(t, err) @@ -338,7 +338,7 @@ func TestACMEClient_post(t *testing.T) { assert.Equals(t, hdr.KeyID, ac.kid) } - render.JSONStatus(w, tc.r2, tc.rc2) + render.JSONStatus(w, r, tc.r2, tc.rc2) }) if resp, err := tc.client.post(tc.payload, url, tc.ops...); err != nil { @@ -450,18 +450,18 @@ func TestACMEClient_NewOrder(t *testing.T) { tc := run(t) i := 0 - srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - assert.Equals(t, "step-http-client/1.0", req.Header.Get("User-Agent")) // check default User-Agent header + srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equals(t, "step-http-client/1.0", r.Header.Get("User-Agent")) // check default User-Agent header w.Header().Set("Replay-Nonce", expectedNonce) if i == 0 { - render.JSONStatus(w, tc.r1, tc.rc1) + render.JSONStatus(w, r, tc.r1, tc.rc1) i++ return } // validate jws request protected headers and body - body, err := io.ReadAll(req.Body) + body, err := io.ReadAll(r.Body) assert.FatalError(t, err) jws, err := jose.ParseJWS(string(body)) assert.FatalError(t, err) @@ -477,7 +477,7 @@ func TestACMEClient_NewOrder(t *testing.T) { assert.FatalError(t, err) assert.Equals(t, payload, norb) - render.JSONStatus(w, tc.r2, tc.rc2) + render.JSONStatus(w, r, tc.r2, tc.rc2) }) if res, err := ac.NewOrder(norb); err != nil { @@ -572,18 +572,18 @@ func TestACMEClient_GetOrder(t *testing.T) { tc := run(t) i := 0 - srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - assert.Equals(t, "step-http-client/1.0", req.Header.Get("User-Agent")) // check default User-Agent header + srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equals(t, "step-http-client/1.0", r.Header.Get("User-Agent")) // check default User-Agent header w.Header().Set("Replay-Nonce", expectedNonce) if i == 0 { - render.JSONStatus(w, tc.r1, tc.rc1) + render.JSONStatus(w, r, tc.r1, tc.rc1) i++ return } // validate jws request protected headers and body - body, err := io.ReadAll(req.Body) + body, err := io.ReadAll(r.Body) assert.FatalError(t, err) jws, err := jose.ParseJWS(string(body)) assert.FatalError(t, err) @@ -599,7 +599,7 @@ func TestACMEClient_GetOrder(t *testing.T) { assert.FatalError(t, err) assert.Equals(t, len(payload), 0) - render.JSONStatus(w, tc.r2, tc.rc2) + render.JSONStatus(w, r, tc.r2, tc.rc2) }) if res, err := ac.GetOrder(url); err != nil { @@ -694,18 +694,18 @@ func TestACMEClient_GetAuthz(t *testing.T) { tc := run(t) i := 0 - srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - assert.Equals(t, "step-http-client/1.0", req.Header.Get("User-Agent")) // check default User-Agent header + srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equals(t, "step-http-client/1.0", r.Header.Get("User-Agent")) // check default User-Agent header w.Header().Set("Replay-Nonce", expectedNonce) if i == 0 { - render.JSONStatus(w, tc.r1, tc.rc1) + render.JSONStatus(w, r, tc.r1, tc.rc1) i++ return } // validate jws request protected headers and body - body, err := io.ReadAll(req.Body) + body, err := io.ReadAll(r.Body) assert.FatalError(t, err) jws, err := jose.ParseJWS(string(body)) assert.FatalError(t, err) @@ -721,7 +721,7 @@ func TestACMEClient_GetAuthz(t *testing.T) { assert.FatalError(t, err) assert.Equals(t, len(payload), 0) - render.JSONStatus(w, tc.r2, tc.rc2) + render.JSONStatus(w, r, tc.r2, tc.rc2) }) if res, err := ac.GetAuthz(url); err != nil { @@ -816,18 +816,18 @@ func TestACMEClient_GetChallenge(t *testing.T) { tc := run(t) i := 0 - srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - assert.Equals(t, "step-http-client/1.0", req.Header.Get("User-Agent")) // check default User-Agent header + srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equals(t, "step-http-client/1.0", r.Header.Get("User-Agent")) // check default User-Agent header w.Header().Set("Replay-Nonce", expectedNonce) if i == 0 { - render.JSONStatus(w, tc.r1, tc.rc1) + render.JSONStatus(w, r, tc.r1, tc.rc1) i++ return } // validate jws request protected headers and body - body, err := io.ReadAll(req.Body) + body, err := io.ReadAll(r.Body) assert.FatalError(t, err) jws, err := jose.ParseJWS(string(body)) assert.FatalError(t, err) @@ -844,7 +844,7 @@ func TestACMEClient_GetChallenge(t *testing.T) { assert.Equals(t, len(payload), 0) - render.JSONStatus(w, tc.r2, tc.rc2) + render.JSONStatus(w, r, tc.r2, tc.rc2) }) if res, err := ac.GetChallenge(url); err != nil { @@ -939,18 +939,18 @@ func TestACMEClient_ValidateChallenge(t *testing.T) { tc := run(t) i := 0 - srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - assert.Equals(t, "step-http-client/1.0", req.Header.Get("User-Agent")) // check default User-Agent header + srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equals(t, "step-http-client/1.0", r.Header.Get("User-Agent")) // check default User-Agent header w.Header().Set("Replay-Nonce", expectedNonce) if i == 0 { - render.JSONStatus(w, tc.r1, tc.rc1) + render.JSONStatus(w, r, tc.r1, tc.rc1) i++ return } // validate jws request protected headers and body - body, err := io.ReadAll(req.Body) + body, err := io.ReadAll(r.Body) assert.FatalError(t, err) jws, err := jose.ParseJWS(string(body)) assert.FatalError(t, err) @@ -967,7 +967,7 @@ func TestACMEClient_ValidateChallenge(t *testing.T) { assert.Equals(t, payload, []byte("{}")) - render.JSONStatus(w, tc.r2, tc.rc2) + render.JSONStatus(w, r, tc.r2, tc.rc2) }) if err := ac.ValidateChallenge(url); err != nil { @@ -983,22 +983,22 @@ func TestACMEClient_ValidateWithPayload(t *testing.T) { key, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - assert.Equals(t, "step-http-client/1.0", req.Header.Get("User-Agent")) // check default User-Agent header + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equals(t, "step-http-client/1.0", r.Header.Get("User-Agent")) // check default User-Agent header - t.Log(req.RequestURI) + t.Log(r.RequestURI) w.Header().Set("Replay-Nonce", "nonce") - switch req.RequestURI { + switch r.RequestURI { case "/nonce": - render.JSONStatus(w, []byte{}, 200) + render.JSONStatus(w, r, []byte{}, 200) return case "/fail-nonce": - render.JSONStatus(w, acme.NewError(acme.ErrorMalformedType, "malformed request"), 400) + render.JSONStatus(w, r, acme.NewError(acme.ErrorMalformedType, "malformed request"), 400) return } // validate jws request protected headers and body - body, err := io.ReadAll(req.Body) + body, err := io.ReadAll(r.Body) assert.FatalError(t, err) jws, err := jose.ParseJWS(string(body)) @@ -1015,15 +1015,15 @@ func TestACMEClient_ValidateWithPayload(t *testing.T) { assert.FatalError(t, err) assert.Equals(t, payload, []byte("the-payload")) - switch req.RequestURI { + switch r.RequestURI { case "/ok": - render.JSONStatus(w, acme.Challenge{ + render.JSONStatus(w, r, acme.Challenge{ Type: "device-attestation-01", Status: "valid", Token: "foo", }, 200) case "/fail": - render.JSONStatus(w, acme.NewError(acme.ErrorMalformedType, "malformed request"), 400) + render.JSONStatus(w, r, acme.NewError(acme.ErrorMalformedType, "malformed request"), 400) } })) defer srv.Close() @@ -1160,18 +1160,18 @@ func TestACMEClient_FinalizeOrder(t *testing.T) { tc := run(t) i := 0 - srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - assert.Equals(t, "step-http-client/1.0", req.Header.Get("User-Agent")) // check default User-Agent header + srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equals(t, "step-http-client/1.0", r.Header.Get("User-Agent")) // check default User-Agent header w.Header().Set("Replay-Nonce", expectedNonce) if i == 0 { - render.JSONStatus(w, tc.r1, tc.rc1) + render.JSONStatus(w, r, tc.r1, tc.rc1) i++ return } // validate jws request protected headers and body - body, err := io.ReadAll(req.Body) + body, err := io.ReadAll(r.Body) assert.FatalError(t, err) jws, err := jose.ParseJWS(string(body)) assert.FatalError(t, err) @@ -1187,7 +1187,7 @@ func TestACMEClient_FinalizeOrder(t *testing.T) { assert.FatalError(t, err) assert.Equals(t, payload, frb) - render.JSONStatus(w, tc.r2, tc.rc2) + render.JSONStatus(w, r, tc.r2, tc.rc2) }) if err := ac.FinalizeOrder(url, csr); err != nil { @@ -1289,18 +1289,18 @@ func TestACMEClient_GetAccountOrders(t *testing.T) { tc := run(t) i := 0 - srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - assert.Equals(t, "step-http-client/1.0", req.Header.Get("User-Agent")) // check default User-Agent header + srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equals(t, "step-http-client/1.0", r.Header.Get("User-Agent")) // check default User-Agent header w.Header().Set("Replay-Nonce", expectedNonce) if i == 0 { - render.JSONStatus(w, tc.r1, tc.rc1) + render.JSONStatus(w, r, tc.r1, tc.rc1) i++ return } // validate jws request protected headers and body - body, err := io.ReadAll(req.Body) + body, err := io.ReadAll(r.Body) assert.FatalError(t, err) jws, err := jose.ParseJWS(string(body)) assert.FatalError(t, err) @@ -1316,7 +1316,7 @@ func TestACMEClient_GetAccountOrders(t *testing.T) { assert.FatalError(t, err) assert.Equals(t, len(payload), 0) - render.JSONStatus(w, tc.r2, tc.rc2) + render.JSONStatus(w, r, tc.r2, tc.rc2) }) if res, err := tc.client.GetAccountOrders(); err != nil { @@ -1420,18 +1420,18 @@ func TestACMEClient_GetCertificate(t *testing.T) { tc := run(t) i := 0 - srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - assert.Equals(t, "step-http-client/1.0", req.Header.Get("User-Agent")) // check default User-Agent header + srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equals(t, "step-http-client/1.0", r.Header.Get("User-Agent")) // check default User-Agent header w.Header().Set("Replay-Nonce", expectedNonce) if i == 0 { - render.JSONStatus(w, tc.r1, tc.rc1) + render.JSONStatus(w, r, tc.r1, tc.rc1) i++ return } // validate jws request protected headers and body - body, err := io.ReadAll(req.Body) + body, err := io.ReadAll(r.Body) assert.FatalError(t, err) jws, err := jose.ParseJWS(string(body)) assert.FatalError(t, err) @@ -1450,7 +1450,7 @@ func TestACMEClient_GetCertificate(t *testing.T) { if tc.certBytes != nil { w.Write(tc.certBytes) } else { - render.JSONStatus(w, tc.r2, tc.rc2) + render.JSONStatus(w, r, tc.r2, tc.rc2) } }) diff --git a/ca/bootstrap_test.go b/ca/bootstrap_test.go index 62c422d4..da37eee5 100644 --- a/ca/bootstrap_test.go +++ b/ca/bootstrap_test.go @@ -87,7 +87,7 @@ func startCAServer(configFile string) (*CA, string, error) { func mTLSMiddleware(next http.Handler, nonAuthenticatedPaths ...string) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path == "/version" { - render.JSON(w, api.VersionResponse{ + render.JSON(w, r, api.VersionResponse{ Version: "test", RequireClientAuthentication: true, }) @@ -102,7 +102,7 @@ func mTLSMiddleware(next http.Handler, nonAuthenticatedPaths ...string) http.Han } isMTLS := r.TLS != nil && len(r.TLS.PeerCertificates) > 0 if !isMTLS { - render.Error(w, errs.Unauthorized("missing peer certificate")) + render.Error(w, r, errs.Unauthorized("missing peer certificate")) } else { next.ServeHTTP(w, r) } @@ -412,7 +412,7 @@ func TestBootstrapClientServerRotation(t *testing.T) { //nolint:gosec // insecure test server server, err := BootstrapServer(context.Background(), token, &http.Server{ Addr: ":0", - Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("ok")) }), }, RequireAndVerifyClientCert()) @@ -531,7 +531,7 @@ func TestBootstrapClientServerFederation(t *testing.T) { //nolint:gosec // insecure test server server, err := BootstrapServer(context.Background(), token, &http.Server{ Addr: ":0", - Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("ok")) }), }, RequireAndVerifyClientCert(), AddFederationToClientCAs()) diff --git a/ca/client_test.go b/ca/client_test.go index 44d24c6e..bd05614b 100644 --- a/ca/client_test.go +++ b/ca/client_test.go @@ -177,8 +177,8 @@ func TestClient_Version(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) require.NoError(t, err) - srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - render.JSONStatus(w, tt.response, tt.responseCode) + srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + render.JSONStatus(w, r, tt.response, tt.responseCode) }) got, err := c.Version() @@ -218,8 +218,8 @@ func TestClient_Health(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) require.NoError(t, err) - srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - render.JSONStatus(w, tt.response, tt.responseCode) + srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + render.JSONStatus(w, r, tt.response, tt.responseCode) }) got, err := c.Health() @@ -262,12 +262,12 @@ func TestClient_Root(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) require.NoError(t, err) - srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { expected := "/root/" + tt.shasum - if req.RequestURI != expected { - t.Errorf("RequestURI = %s, want %s", req.RequestURI, expected) + if r.RequestURI != expected { + t.Errorf("RequestURI = %s, want %s", r.RequestURI, expected) } - render.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, r, tt.response, tt.responseCode) }) got, err := c.Root(tt.shasum) @@ -323,12 +323,12 @@ func TestClient_Sign(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) require.NoError(t, err) - srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { body := new(api.SignRequest) - if err := read.JSON(req.Body, body); err != nil { + if err := read.JSON(r.Body, body); err != nil { e, ok := tt.response.(error) require.True(t, ok, "response expected to be error type") - render.Error(w, e) + render.Error(w, r, e) return } else if !equalJSON(t, body, tt.request) { if tt.request == nil { @@ -339,7 +339,7 @@ func TestClient_Sign(t *testing.T) { t.Errorf("Client.Sign() request = %v, wants %v", body, tt.request) } } - render.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, r, tt.response, tt.responseCode) }) got, err := c.Sign(tt.request) @@ -385,12 +385,12 @@ func TestClient_Revoke(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) require.NoError(t, err) - srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { body := new(api.RevokeRequest) - if err := read.JSON(req.Body, body); err != nil { + if err := read.JSON(r.Body, body); err != nil { e, ok := tt.response.(error) require.True(t, ok, "response expected to be error type") - render.Error(w, e) + render.Error(w, r, e) return } else if !equalJSON(t, body, tt.request) { if tt.request == nil { @@ -401,7 +401,7 @@ func TestClient_Revoke(t *testing.T) { t.Errorf("Client.Revoke() request = %v, wants %v", body, tt.request) } } - render.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, r, tt.response, tt.responseCode) }) got, err := c.Revoke(tt.request, nil) @@ -450,8 +450,8 @@ func TestClient_Renew(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) require.NoError(t, err) - srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - render.JSONStatus(w, tt.response, tt.responseCode) + srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + render.JSONStatus(w, r, tt.response, tt.responseCode) }) got, err := c.Renew(nil) @@ -504,11 +504,11 @@ func TestClient_RenewWithToken(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) require.NoError(t, err) - srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - if req.Header.Get("Authorization") != "Bearer token" { - render.JSONStatus(w, errs.InternalServer("force"), 500) + srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Authorization") != "Bearer token" { + render.JSONStatus(w, r, errs.InternalServer("force"), 500) } else { - render.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, r, tt.response, tt.responseCode) } }) @@ -567,8 +567,8 @@ func TestClient_Rekey(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) require.NoError(t, err) - srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - render.JSONStatus(w, tt.response, tt.responseCode) + srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + render.JSONStatus(w, r, tt.response, tt.responseCode) }) got, err := c.Rekey(tt.request, nil) @@ -619,11 +619,11 @@ func TestClient_Provisioners(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) require.NoError(t, err) - srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - if req.RequestURI != tt.expectedURI { - t.Errorf("RequestURI = %s, want %s", req.RequestURI, tt.expectedURI) + srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.RequestURI != tt.expectedURI { + t.Errorf("RequestURI = %s, want %s", r.RequestURI, tt.expectedURI) } - render.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, r, tt.response, tt.responseCode) }) got, err := c.Provisioners(tt.args...) @@ -666,12 +666,12 @@ func TestClient_ProvisionerKey(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) require.NoError(t, err) - srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { expected := "/provisioners/" + tt.kid + "/encrypted-key" - if req.RequestURI != expected { - t.Errorf("RequestURI = %s, want %s", req.RequestURI, expected) + if r.RequestURI != expected { + t.Errorf("RequestURI = %s, want %s", r.RequestURI, expected) } - render.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, r, tt.response, tt.responseCode) }) got, err := c.ProvisionerKey(tt.kid) @@ -720,8 +720,8 @@ func TestClient_Roots(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) require.NoError(t, err) - srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - render.JSONStatus(w, tt.response, tt.responseCode) + srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + render.JSONStatus(w, r, tt.response, tt.responseCode) }) got, err := c.Roots() @@ -769,8 +769,8 @@ func TestClient_Federation(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) require.NoError(t, err) - srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - render.JSONStatus(w, tt.response, tt.responseCode) + srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + render.JSONStatus(w, r, tt.response, tt.responseCode) }) got, err := c.Federation() @@ -820,8 +820,8 @@ func TestClient_SSHRoots(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) require.NoError(t, err) - srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - render.JSONStatus(w, tt.response, tt.responseCode) + srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + render.JSONStatus(w, r, tt.response, tt.responseCode) }) got, err := c.SSHRoots() @@ -912,8 +912,8 @@ func TestClient_RootFingerprint(t *testing.T) { c, err := NewClient(tt.server.URL, WithTransport(tr)) require.NoError(t, err) - tt.server.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - render.JSONStatus(w, tt.response, tt.responseCode) + tt.server.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + render.JSONStatus(w, r, tt.response, tt.responseCode) }) got, err := c.RootFingerprint() @@ -970,8 +970,8 @@ func TestClient_SSHBastion(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) require.NoError(t, err) - srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - render.JSONStatus(w, tt.response, tt.responseCode) + srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + render.JSONStatus(w, r, tt.response, tt.responseCode) }) got, err := c.SSHBastion(tt.request) diff --git a/go.mod b/go.mod index 4c83e83e..d4725c7e 100644 --- a/go.mod +++ b/go.mod @@ -15,13 +15,13 @@ require ( github.com/google/go-cmp v0.6.0 github.com/google/go-tpm v0.9.0 github.com/google/uuid v1.6.0 - github.com/googleapis/gax-go/v2 v2.12.3 + github.com/googleapis/gax-go/v2 v2.12.4 github.com/hashicorp/vault/api v1.13.0 github.com/hashicorp/vault/api/auth/approle v0.6.0 github.com/hashicorp/vault/api/auth/kubernetes v0.6.0 github.com/newrelic/go-agent/v3 v3.33.0 github.com/pkg/errors v0.9.1 - github.com/prometheus/client_golang v1.19.0 + github.com/prometheus/client_golang v1.19.1 github.com/rs/xid v1.5.0 github.com/sirupsen/logrus v1.9.3 github.com/slackhq/nebula v1.6.1 @@ -33,23 +33,23 @@ require ( github.com/stretchr/testify v1.9.0 github.com/urfave/cli v1.22.15 go.step.sm/cli-utils v0.9.0 - go.step.sm/crypto v0.44.8 + go.step.sm/crypto v0.45.0 go.step.sm/linkedca v0.20.1 golang.org/x/crypto v0.23.0 golang.org/x/exp v0.0.0-20240318143956-a85f2c67cd81 - golang.org/x/net v0.24.0 - google.golang.org/api v0.177.0 + golang.org/x/net v0.25.0 + google.golang.org/api v0.180.0 google.golang.org/grpc v1.63.2 - google.golang.org/protobuf v1.34.0 + google.golang.org/protobuf v1.34.1 ) require ( cloud.google.com/go v0.112.2 // indirect - cloud.google.com/go/auth v0.3.0 // indirect + cloud.google.com/go/auth v0.4.1 // indirect cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect cloud.google.com/go/compute/metadata v0.3.0 // indirect - cloud.google.com/go/iam v1.1.7 // indirect - cloud.google.com/go/kms v1.15.8 // indirect + cloud.google.com/go/iam v1.1.8 // indirect + cloud.google.com/go/kms v1.16.0 // indirect filippo.io/edwards25519 v1.1.0 // indirect github.com/AndreasBriese/bbloom v0.0.0-20190825152654-46b345b51c96 // indirect github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1 // indirect @@ -62,18 +62,18 @@ require ( github.com/Masterminds/semver/v3 v3.2.0 // indirect github.com/ThalesIgnite/crypto11 v1.2.5 // indirect github.com/aws/aws-sdk-go-v2 v1.26.1 // indirect - github.com/aws/aws-sdk-go-v2/config v1.27.11 // indirect - github.com/aws/aws-sdk-go-v2/credentials v1.17.11 // indirect + github.com/aws/aws-sdk-go-v2/config v1.27.13 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.17.13 // indirect github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.1 // indirect github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 // indirect github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.2 // indirect github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.7 // indirect - github.com/aws/aws-sdk-go-v2/service/kms v1.31.0 // indirect - github.com/aws/aws-sdk-go-v2/service/sso v1.20.5 // indirect - github.com/aws/aws-sdk-go-v2/service/ssooidc v1.23.4 // indirect - github.com/aws/aws-sdk-go-v2/service/sts v1.28.6 // indirect + github.com/aws/aws-sdk-go-v2/service/kms v1.31.1 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.20.6 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.24.0 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.28.7 // indirect github.com/aws/smithy-go v1.20.2 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cenkalti/backoff/v3 v3.0.0 // indirect @@ -156,13 +156,13 @@ require ( go.opentelemetry.io/otel v1.24.0 // indirect go.opentelemetry.io/otel/metric v1.24.0 // indirect go.opentelemetry.io/otel/trace v1.24.0 // indirect - golang.org/x/oauth2 v0.19.0 // indirect + golang.org/x/oauth2 v0.20.0 // indirect golang.org/x/sync v0.7.0 // indirect golang.org/x/sys v0.20.0 // indirect golang.org/x/text v0.15.0 // indirect golang.org/x/time v0.5.0 // indirect google.golang.org/genproto v0.0.0-20240401170217-c3f982113cda // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20240429193739-8cf5692501f6 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20240506185236-b8a5c65736ae // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240429193739-8cf5692501f6 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index e9db4c9e..aef49cbb 100644 --- a/go.sum +++ b/go.sum @@ -1,16 +1,16 @@ cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.112.2 h1:ZaGT6LiG7dBzi6zNOvVZwacaXlmf3lRqnC4DQzqyRQw= cloud.google.com/go v0.112.2/go.mod h1:iEqjp//KquGIJV/m+Pk3xecgKNhV+ry+vVTsy4TbDms= -cloud.google.com/go/auth v0.3.0 h1:PRyzEpGfx/Z9e8+lHsbkoUVXD0gnu4MNmm7Gp8TQNIs= -cloud.google.com/go/auth v0.3.0/go.mod h1:lBv6NKTWp8E3LPzmO1TbiiRKc4drLOfHsgmlH9ogv5w= +cloud.google.com/go/auth v0.4.1 h1:Z7YNIhlWRtrnKlZke7z3GMqzvuYzdc2z98F9D1NV5Hg= +cloud.google.com/go/auth v0.4.1/go.mod h1:QVBuVEKpCn4Zp58hzRGvL0tjRGU0YqdRTdCHM1IHnro= cloud.google.com/go/auth/oauth2adapt v0.2.2 h1:+TTV8aXpjeChS9M+aTtN/TjdQnzJvmzKFt//oWu7HX4= cloud.google.com/go/auth/oauth2adapt v0.2.2/go.mod h1:wcYjgpZI9+Yu7LyYBg4pqSiaRkfEK3GQcpb7C/uyF1Q= cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc= cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= -cloud.google.com/go/iam v1.1.7 h1:z4VHOhwKLF/+UYXAJDFwGtNF0b6gjsW1Pk9Ml0U/IoM= -cloud.google.com/go/iam v1.1.7/go.mod h1:J4PMPg8TtyurAUvSmPj8FF3EDgY1SPRZxcUGrn7WXGA= -cloud.google.com/go/kms v1.15.8 h1:szIeDCowID8th2i8XE4uRev5PMxQFqW+JjwYxL9h6xs= -cloud.google.com/go/kms v1.15.8/go.mod h1:WoUHcDjD9pluCg7pNds131awnH429QGvRM3N/4MyoVs= +cloud.google.com/go/iam v1.1.8 h1:r7umDwhj+BQyz0ScZMp4QrGXjSTI3ZINnpgU2nlB/K0= +cloud.google.com/go/iam v1.1.8/go.mod h1:GvE6lyMmfxXauzNq8NbgJbeVQNspG+tcdL/W8QO1+zE= +cloud.google.com/go/kms v1.16.0 h1:1yZsRPhmargZOmY+fVAh8IKiR9HzCb0U1zsxb5g2nRY= +cloud.google.com/go/kms v1.16.0/go.mod h1:olQUXy2Xud+1GzYfiBO9N0RhjsJk5IJLU6n/ethLXVc= cloud.google.com/go/longrunning v0.5.7 h1:WLbHekDbjK1fVFD3ibpFFVoyizlLRl73I7YKuAKilhU= cloud.google.com/go/longrunning v0.5.7/go.mod h1:8GClkudohy1Fxm3owmBGid8W0pSgodEMwEAztp38Xng= cloud.google.com/go/security v1.16.1 h1:9Jn8BJpkq8MflNzTdrX4m+SVp2+WeqVhbFiwyNIoXuM= @@ -48,10 +48,10 @@ github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5 github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= github.com/aws/aws-sdk-go-v2 v1.26.1 h1:5554eUqIYVWpU0YmeeYZ0wU64H2VLBs8TlhRB2L+EkA= github.com/aws/aws-sdk-go-v2 v1.26.1/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM= -github.com/aws/aws-sdk-go-v2/config v1.27.11 h1:f47rANd2LQEYHda2ddSCKYId18/8BhSRM4BULGmfgNA= -github.com/aws/aws-sdk-go-v2/config v1.27.11/go.mod h1:SMsV78RIOYdve1vf36z8LmnszlRWkwMQtomCAI0/mIE= -github.com/aws/aws-sdk-go-v2/credentials v1.17.11 h1:YuIB1dJNf1Re822rriUOTxopaHHvIq0l/pX3fwO+Tzs= -github.com/aws/aws-sdk-go-v2/credentials v1.17.11/go.mod h1:AQtFPsDH9bI2O+71anW6EKL+NcD7LG3dpKGMV4SShgo= +github.com/aws/aws-sdk-go-v2/config v1.27.13 h1:WbKW8hOzrWoOA/+35S5okqO/2Ap8hkkFUzoW8Hzq24A= +github.com/aws/aws-sdk-go-v2/config v1.27.13/go.mod h1:XLiyiTMnguytjRER7u5RIkhIqS8Nyz41SwAWb4xEjxs= +github.com/aws/aws-sdk-go-v2/credentials v1.17.13 h1:XDCJDzk/u5cN7Aple7D/MiAhx1Rjo/0nueJ0La8mRuE= +github.com/aws/aws-sdk-go-v2/credentials v1.17.13/go.mod h1:FMNcjQrmuBYvOTZDtOLCIu0esmxjF7RuA/89iSXWzQI= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.1 h1:FVJ0r5XTHSmIHJV6KuDmdYhEpvlHpiSd38RQWhut5J4= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.1/go.mod h1:zusuAeqezXzAB24LGuzuekqMAEgWkVYukBec3kr3jUg= github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 h1:aw39xVGeRWlWx9EzGVnhOR4yOjQDHPQ6o6NmBlscyQg= @@ -64,14 +64,14 @@ github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.2 h1:Ji0DY1x github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.2/go.mod h1:5CsjAbs3NlGQyZNFACh+zztPDI7fU6eW9QsxjfnuBKg= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.7 h1:ogRAwT1/gxJBcSWDMZlgyFUM962F51A5CRhDLbxLdmo= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.7/go.mod h1:YCsIZhXfRPLFFCl5xxY+1T9RKzOKjCut+28JSX2DnAk= -github.com/aws/aws-sdk-go-v2/service/kms v1.31.0 h1:yl7wcqbisxPzknJVfWTLnK83McUvXba+pz2+tPbIUmQ= -github.com/aws/aws-sdk-go-v2/service/kms v1.31.0/go.mod h1:2snWQJQUKsbN66vAawJuOGX7dr37pfOq9hb0tZDGIqQ= -github.com/aws/aws-sdk-go-v2/service/sso v1.20.5 h1:vN8hEbpRnL7+Hopy9dzmRle1xmDc7o8tmY0klsr175w= -github.com/aws/aws-sdk-go-v2/service/sso v1.20.5/go.mod h1:qGzynb/msuZIE8I75DVRCUXw3o3ZyBmUvMwQ2t/BrGM= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.23.4 h1:Jux+gDDyi1Lruk+KHF91tK2KCuY61kzoCpvtvJJBtOE= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.23.4/go.mod h1:mUYPBhaF2lGiukDEjJX2BLRRKTmoUSitGDUgM4tRxak= -github.com/aws/aws-sdk-go-v2/service/sts v1.28.6 h1:cwIxeBttqPN3qkaAjcEcsh8NYr8n2HZPkcKgPAi1phU= -github.com/aws/aws-sdk-go-v2/service/sts v1.28.6/go.mod h1:FZf1/nKNEkHdGGJP/cI2MoIMquumuRK6ol3QQJNDxmw= +github.com/aws/aws-sdk-go-v2/service/kms v1.31.1 h1:5wtyAwuUiJiM3DHYeGZmP5iMonM7DFBWAEaaVPHYZA0= +github.com/aws/aws-sdk-go-v2/service/kms v1.31.1/go.mod h1:2snWQJQUKsbN66vAawJuOGX7dr37pfOq9hb0tZDGIqQ= +github.com/aws/aws-sdk-go-v2/service/sso v1.20.6 h1:o5cTaeunSpfXiLTIBx5xo2enQmiChtu1IBbzXnfU9Hs= +github.com/aws/aws-sdk-go-v2/service/sso v1.20.6/go.mod h1:qGzynb/msuZIE8I75DVRCUXw3o3ZyBmUvMwQ2t/BrGM= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.24.0 h1:Qe0r0lVURDDeBQJ4yP+BOrJkvkiCo/3FH/t+wY11dmw= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.24.0/go.mod h1:mUYPBhaF2lGiukDEjJX2BLRRKTmoUSitGDUgM4tRxak= +github.com/aws/aws-sdk-go-v2/service/sts v1.28.7 h1:et3Ta53gotFR4ERLXXHIHl/Uuk1qYpP5uU7cvNql8ns= +github.com/aws/aws-sdk-go-v2/service/sts v1.28.7/go.mod h1:FZf1/nKNEkHdGGJP/cI2MoIMquumuRK6ol3QQJNDxmw= github.com/aws/smithy-go v1.20.2 h1:tbp628ireGtzcHDDmLT/6ADHidqnwgF57XOXZe6tp4Q= github.com/aws/smithy-go v1.20.2/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= @@ -235,8 +235,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/googleapis/enterprise-certificate-proxy v0.3.2 h1:Vie5ybvEvT75RniqhfFxPRy3Bf7vr3h0cechB90XaQs= github.com/googleapis/enterprise-certificate-proxy v0.3.2/go.mod h1:VLSiSSBs/ksPL8kq3OBOQ6WRI2QnaFynd1DCjZ62+V0= -github.com/googleapis/gax-go/v2 v2.12.3 h1:5/zPPDvw8Q1SuXjrqrZslrqT7dL/uJT2CQii/cLCKqA= -github.com/googleapis/gax-go/v2 v2.12.3/go.mod h1:AKloxT6GtNbaLm8QTNSidHUVsHYcBHwWRvkNFJUQcS4= +github.com/googleapis/gax-go/v2 v2.12.4 h1:9gWcmF85Wvq4ryPFvGFaOgPIs1AQX0d0bcbGw4Z96qg= +github.com/googleapis/gax-go/v2 v2.12.4/go.mod h1:KYEYLorsnIGDi/rPC8b5TdlB9kbKoFubselGIoBMCwI= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= @@ -398,8 +398,8 @@ github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/posener/complete v1.1.1/go.mod h1:em0nMJCgc9GFtwrmVmEMR/ZL6WyhyjMBndrE9hABlRI= -github.com/prometheus/client_golang v1.19.0 h1:ygXvpU1AoN1MhdzckN+PyD9QJOSD4x7kmXYlnfbA6JU= -github.com/prometheus/client_golang v1.19.0/go.mod h1:ZRM9uEAypZakd+q/x7+gmsvXdURP+DABIEIjnmDdp+k= +github.com/prometheus/client_golang v1.19.1 h1:wZWJDwK+NameRJuPGDhlnFgx8e8HN3XHQeLaYJFJBOE= +github.com/prometheus/client_golang v1.19.1/go.mod h1:mP78NwGzrVks5S2H6ab8+ZZGJLZUq1hoULYBAYBw1Ho= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.5.0 h1:VQw1hfvPvk3Uv6Qf29VrPF32JB6rtbgI6cYPYQjL0Qw= github.com/prometheus/client_model v0.5.0/go.mod h1:dTiFglRmd66nLR9Pv9f0mZi7B7fk5Pm3gvsjB5tr+kI= @@ -498,14 +498,14 @@ go.opentelemetry.io/otel v1.24.0 h1:0LAOdjNmQeSTzGBzduGe/rU4tZhMwL5rWgtp9Ku5Jfo= go.opentelemetry.io/otel v1.24.0/go.mod h1:W7b9Ozg4nkF5tWI5zsXkaKKDjdVjpD4oAt9Qi/MArHo= go.opentelemetry.io/otel/metric v1.24.0 h1:6EhoGWWK28x1fbpA4tYTOWBkPefTDQnb8WSGXlc88kI= go.opentelemetry.io/otel/metric v1.24.0/go.mod h1:VYhLe1rFfxuTXLgj4CBiyz+9WYBA8pNGJgDcSFRKBco= -go.opentelemetry.io/otel/sdk v1.22.0 h1:6coWHw9xw7EfClIC/+O31R8IY3/+EiRFHevmHafB2Gw= -go.opentelemetry.io/otel/sdk v1.22.0/go.mod h1:iu7luyVGYovrRpe2fmj3CVKouQNdTOkxtLzPvPz1DOc= +go.opentelemetry.io/otel/sdk v1.24.0 h1:YMPPDNymmQN3ZgczicBY3B6sf9n62Dlj9pWD3ucgoDw= +go.opentelemetry.io/otel/sdk v1.24.0/go.mod h1:KVrIYw6tEubO9E96HQpcmpTKDVn9gdv35HoYiQWGDFg= go.opentelemetry.io/otel/trace v1.24.0 h1:CsKnnL4dUAr/0llH9FKuc698G04IrpWV0MQA/Y1YELI= go.opentelemetry.io/otel/trace v1.24.0/go.mod h1:HPc3Xr/cOApsBI154IU0OI0HJexz+aw5uPdbs3UCjNU= go.step.sm/cli-utils v0.9.0 h1:55jYcsQbnArNqepZyAwcato6Zy2MoZDRkWW+jF+aPfQ= go.step.sm/cli-utils v0.9.0/go.mod h1:Y/CRoWl1FVR9j+7PnAewufAwKmBOTzR6l9+7EYGAnp8= -go.step.sm/crypto v0.44.8 h1:jDSHL6FdB1UTA0d56ECNx9XtLVkewzeg38Vy3HWB3N8= -go.step.sm/crypto v0.44.8/go.mod h1:QEmu4T9YewrDuaJnrV1I0zWZ15aJ/mqRUfL5w3R2WgU= +go.step.sm/crypto v0.45.0 h1:Z0WYAaaOYrJmKP9sJkPW+6wy3pgN3Ija8ek/D4serjc= +go.step.sm/crypto v0.45.0/go.mod h1:6IYlT0L2jfj81nVyCPpvA5cORy0EVHPhieSgQyuwHIY= go.step.sm/linkedca v0.20.1 h1:bHDn1+UG1NgRrERkWbbCiAIvv4lD5NOFaswPDTyO5vU= go.step.sm/linkedca v0.20.1/go.mod h1:Vaq4+Umtjh7DLFI1KuIxeo598vfBzgSYZUjgVJ7Syxw= go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= @@ -566,11 +566,11 @@ golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= -golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w= -golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8= +golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= +golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= -golang.org/x/oauth2 v0.19.0 h1:9+E/EZBCbTLNrbN35fHv/a/d/mOBatymz1zbtQrXpIg= -golang.org/x/oauth2 v0.19.0/go.mod h1:vYi7skDa1x015PmRRYZ7+s1cWyPgrPiSYRe4rnsexc8= +golang.org/x/oauth2 v0.20.0 h1:4mQdhULixXKP1rwYBW0vAijoXnkTG0BLCDRzfe1idMo= +golang.org/x/oauth2 v0.20.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -663,8 +663,8 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/api v0.177.0 h1:8a0p/BbPa65GlqGWtUKxot4p0TV8OGOfyTjtmkXNXmk= -google.golang.org/api v0.177.0/go.mod h1:srbhue4MLjkjbkux5p3dw/ocYOSZTaIEvf7bCOnFQDw= +google.golang.org/api v0.180.0 h1:M2D87Yo0rGBPWpo1orwfCLehUUL6E7/TYe5gvMQWDh4= +google.golang.org/api v0.180.0/go.mod h1:51AiyoEg1MJPSZ9zvklA8VnRILPXxn1iVen9v25XHAE= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= @@ -672,8 +672,8 @@ google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98 google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= google.golang.org/genproto v0.0.0-20240401170217-c3f982113cda h1:wu/KJm9KJwpfHWhkkZGohVC6KRrc1oJNr4jwtQMOQXw= google.golang.org/genproto v0.0.0-20240401170217-c3f982113cda/go.mod h1:g2LLCvCeCSir/JJSWosk19BR4NVxGqHUC6rxIRsd7Aw= -google.golang.org/genproto/googleapis/api v0.0.0-20240429193739-8cf5692501f6 h1:DTJM0R8LECCgFeUwApvcEJHz85HLagW8uRENYxHh1ww= -google.golang.org/genproto/googleapis/api v0.0.0-20240429193739-8cf5692501f6/go.mod h1:10yRODfgim2/T8csjQsMPgZOMvtytXKTDRzH6HRGzRw= +google.golang.org/genproto/googleapis/api v0.0.0-20240506185236-b8a5c65736ae h1:AH34z6WAGVNkllnKs5raNq3yRq93VnjBG6rpfub/jYk= +google.golang.org/genproto/googleapis/api v0.0.0-20240506185236-b8a5c65736ae/go.mod h1:FfiGhwUm6CJviekPrc0oJ+7h29e+DmWU6UtjX0ZvI7Y= google.golang.org/genproto/googleapis/rpc v0.0.0-20240429193739-8cf5692501f6 h1:DujSIu+2tC9Ht0aPNA7jgj23Iq8Ewi5sgkQ++wdvonE= google.golang.org/genproto/googleapis/rpc v0.0.0-20240429193739-8cf5692501f6/go.mod h1:WtryC6hu0hhx87FDGxWCDptyssuo68sk10vYjF+T9fY= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= @@ -692,8 +692,8 @@ google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2 google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= -google.golang.org/protobuf v1.34.0 h1:Qo/qEd2RZPCf2nKuorzksSknv0d3ERwp1vFG38gSmH4= -google.golang.org/protobuf v1.34.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= +google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= +google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/scep/api/api.go b/scep/api/api.go index fd3c61ea..3649bf66 100644 --- a/scep/api/api.go +++ b/scep/api/api.go @@ -97,7 +97,7 @@ func route(r api.Router, middleware func(next http.HandlerFunc) http.HandlerFunc func Get(w http.ResponseWriter, r *http.Request) { req, err := decodeRequest(r) if err != nil { - fail(w, fmt.Errorf("invalid scep get request: %w", err)) + fail(w, r, fmt.Errorf("invalid scep get request: %w", err)) return } @@ -116,18 +116,18 @@ func Get(w http.ResponseWriter, r *http.Request) { } if err != nil { - fail(w, fmt.Errorf("scep get request failed: %w", err)) + fail(w, r, fmt.Errorf("scep get request failed: %w", err)) return } - writeResponse(w, res) + writeResponse(w, r, res) } // Post handles all SCEP POST requests func Post(w http.ResponseWriter, r *http.Request) { req, err := decodeRequest(r) if err != nil { - fail(w, fmt.Errorf("invalid scep post request: %w", err)) + fail(w, r, fmt.Errorf("invalid scep post request: %w", err)) return } @@ -140,11 +140,11 @@ func Post(w http.ResponseWriter, r *http.Request) { } if err != nil { - fail(w, fmt.Errorf("scep post request failed: %w", err)) + fail(w, r, fmt.Errorf("scep post request failed: %w", err)) return } - writeResponse(w, res) + writeResponse(w, r, res) } func decodeRequest(r *http.Request) (request, error) { @@ -274,7 +274,7 @@ func lookupProvisioner(next http.HandlerFunc) http.HandlerFunc { name := chi.URLParam(r, "provisionerName") provisionerName, err := url.PathUnescape(name) if err != nil { - fail(w, fmt.Errorf("error url unescaping provisioner name '%s'", name)) + fail(w, r, fmt.Errorf("error url unescaping provisioner name '%s'", name)) return } @@ -282,13 +282,13 @@ func lookupProvisioner(next http.HandlerFunc) http.HandlerFunc { auth := authority.MustFromContext(ctx) p, err := auth.LoadProvisionerByName(provisionerName) if err != nil { - fail(w, err) + fail(w, r, err) return } prov, ok := p.(*provisioner.SCEP) if !ok { - fail(w, errors.New("provisioner must be of type SCEP")) + fail(w, r, errors.New("provisioner must be of type SCEP")) return } @@ -430,9 +430,9 @@ func formatCapabilities(caps []string) []byte { } // writeResponse writes a SCEP response back to the SCEP client. -func writeResponse(w http.ResponseWriter, res Response) { +func writeResponse(w http.ResponseWriter, r *http.Request, res Response) { if res.Error != nil { - log.Error(w, res.Error) + log.Error(w, r, res.Error) } if res.Certificate != nil { @@ -443,8 +443,8 @@ func writeResponse(w http.ResponseWriter, res Response) { _, _ = w.Write(res.Data) } -func fail(w http.ResponseWriter, err error) { - log.Error(w, err) +func fail(w http.ResponseWriter, r *http.Request, err error) { + log.Error(w, r, err) http.Error(w, err.Error(), http.StatusInternalServerError) }