Merge pull request #1849 from smallstep/mariano/log-errors

Log errors using slog.Logger
pull/1844/head^2
Mariano Cano 1 month ago committed by GitHub
commit 14959dbb2e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -82,23 +82,23 @@ func NewAccount(w http.ResponseWriter, r *http.Request) {
payload, err := payloadFromContext(ctx) payload, err := payloadFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
var nar NewAccountRequest var nar NewAccountRequest
if err := json.Unmarshal(payload.value, &nar); err != nil { if err := json.Unmarshal(payload.value, &nar); err != nil {
render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, render.Error(w, r, acme.WrapError(acme.ErrorMalformedType, err,
"failed to unmarshal new-account request payload")) "failed to unmarshal new-account request payload"))
return return
} }
if err := nar.Validate(); err != nil { if err := nar.Validate(); err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
prov, err := acmeProvisionerFromContext(ctx) prov, err := acmeProvisionerFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
@ -108,26 +108,26 @@ func NewAccount(w http.ResponseWriter, r *http.Request) {
var acmeErr *acme.Error var acmeErr *acme.Error
if !errors.As(err, &acmeErr) || acmeErr.Status != http.StatusBadRequest { if !errors.As(err, &acmeErr) || acmeErr.Status != http.StatusBadRequest {
// Something went wrong ... // Something went wrong ...
render.Error(w, err) render.Error(w, r, err)
return return
} }
// Account does not exist // // Account does not exist //
if nar.OnlyReturnExisting { if nar.OnlyReturnExisting {
render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, render.Error(w, r, acme.NewError(acme.ErrorAccountDoesNotExistType,
"account does not exist")) "account does not exist"))
return return
} }
jwk, err := jwkFromContext(ctx) jwk, err := jwkFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
eak, err := validateExternalAccountBinding(ctx, &nar) eak, err := validateExternalAccountBinding(ctx, &nar)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
@ -140,17 +140,17 @@ func NewAccount(w http.ResponseWriter, r *http.Request) {
ProvisionerName: prov.Name, ProvisionerName: prov.Name,
} }
if err := db.CreateAccount(ctx, acc); err != nil { if err := db.CreateAccount(ctx, acc); err != nil {
render.Error(w, acme.WrapErrorISE(err, "error creating account")) render.Error(w, r, acme.WrapErrorISE(err, "error creating account"))
return return
} }
if eak != nil { // means that we have a (valid) External Account Binding key that should be bound, updated and sent in the response if eak != nil { // means that we have a (valid) External Account Binding key that should be bound, updated and sent in the response
if err := eak.BindTo(acc); err != nil { if err := eak.BindTo(acc); err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
if err := db.UpdateExternalAccountKey(ctx, prov.ID, eak); err != nil { if err := db.UpdateExternalAccountKey(ctx, prov.ID, eak); err != nil {
render.Error(w, acme.WrapErrorISE(err, "error updating external account binding key")) render.Error(w, r, acme.WrapErrorISE(err, "error updating external account binding key"))
return return
} }
acc.ExternalAccountBinding = nar.ExternalAccountBinding acc.ExternalAccountBinding = nar.ExternalAccountBinding
@ -163,7 +163,7 @@ func NewAccount(w http.ResponseWriter, r *http.Request) {
linker.LinkAccount(ctx, acc) linker.LinkAccount(ctx, acc)
w.Header().Set("Location", getAccountLocationPath(ctx, linker, acc.ID)) w.Header().Set("Location", getAccountLocationPath(ctx, linker, acc.ID))
render.JSONStatus(w, acc, httpStatus) render.JSONStatus(w, r, acc, httpStatus)
} }
// GetOrUpdateAccount is the api for updating an ACME account. // GetOrUpdateAccount is the api for updating an ACME account.
@ -174,12 +174,12 @@ func GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
acc, err := accountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
payload, err := payloadFromContext(ctx) payload, err := payloadFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
@ -188,12 +188,12 @@ func GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
if !payload.isPostAsGet { if !payload.isPostAsGet {
var uar UpdateAccountRequest var uar UpdateAccountRequest
if err := json.Unmarshal(payload.value, &uar); err != nil { if err := json.Unmarshal(payload.value, &uar); err != nil {
render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, render.Error(w, r, acme.WrapError(acme.ErrorMalformedType, err,
"failed to unmarshal new-account request payload")) "failed to unmarshal new-account request payload"))
return return
} }
if err := uar.Validate(); err != nil { if err := uar.Validate(); err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
if len(uar.Status) > 0 || len(uar.Contact) > 0 { if len(uar.Status) > 0 || len(uar.Contact) > 0 {
@ -204,7 +204,7 @@ func GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
} }
if err := db.UpdateAccount(ctx, acc); err != nil { if err := db.UpdateAccount(ctx, acc); err != nil {
render.Error(w, acme.WrapErrorISE(err, "error updating account")) render.Error(w, r, acme.WrapErrorISE(err, "error updating account"))
return return
} }
} }
@ -213,7 +213,7 @@ func GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
linker.LinkAccount(ctx, acc) linker.LinkAccount(ctx, acc)
w.Header().Set("Location", linker.GetLink(ctx, acme.AccountLinkType, acc.ID)) w.Header().Set("Location", linker.GetLink(ctx, acme.AccountLinkType, acc.ID))
render.JSON(w, acc) render.JSON(w, r, acc)
} }
func logOrdersByAccount(w http.ResponseWriter, oids []string) { func logOrdersByAccount(w http.ResponseWriter, oids []string) {
@ -233,23 +233,23 @@ func GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) {
acc, err := accountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
accID := chi.URLParam(r, "accID") accID := chi.URLParam(r, "accID")
if acc.ID != accID { if acc.ID != accID {
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account ID '%s' does not match url param '%s'", acc.ID, accID)) render.Error(w, r, acme.NewError(acme.ErrorUnauthorizedType, "account ID '%s' does not match url param '%s'", acc.ID, accID))
return return
} }
orders, err := db.GetOrdersByAccountID(ctx, acc.ID) orders, err := db.GetOrdersByAccountID(ctx, acc.ID)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
linker.LinkOrdersByAccountID(ctx, orders) linker.LinkOrdersByAccountID(ctx, orders)
render.JSON(w, orders) render.JSON(w, r, orders)
logOrdersByAccount(w, orders) logOrdersByAccount(w, orders)
} }

@ -223,13 +223,13 @@ func GetDirectory(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
acmeProv, err := acmeProvisionerFromContext(ctx) acmeProv, err := acmeProvisionerFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
linker := acme.MustLinkerFromContext(ctx) linker := acme.MustLinkerFromContext(ctx)
render.JSON(w, &Directory{ render.JSON(w, r, &Directory{
NewNonce: linker.GetLink(ctx, acme.NewNonceLinkType), NewNonce: linker.GetLink(ctx, acme.NewNonceLinkType),
NewAccount: linker.GetLink(ctx, acme.NewAccountLinkType), NewAccount: linker.GetLink(ctx, acme.NewAccountLinkType),
NewOrder: linker.GetLink(ctx, acme.NewOrderLinkType), NewOrder: linker.GetLink(ctx, acme.NewOrderLinkType),
@ -273,8 +273,8 @@ func shouldAddMetaObject(p *provisioner.ACME) bool {
// NotImplemented returns a 501 and is generally a placeholder for functionality which // NotImplemented returns a 501 and is generally a placeholder for functionality which
// MAY be added at some point in the future but is not in any way a guarantee of such. // MAY be added at some point in the future but is not in any way a guarantee of such.
func NotImplemented(w http.ResponseWriter, _ *http.Request) { func NotImplemented(w http.ResponseWriter, r *http.Request) {
render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "this API is not implemented")) render.Error(w, r, acme.NewError(acme.ErrorNotImplementedType, "this API is not implemented"))
} }
// GetAuthorization ACME api for retrieving an Authz. // GetAuthorization ACME api for retrieving an Authz.
@ -285,28 +285,28 @@ func GetAuthorization(w http.ResponseWriter, r *http.Request) {
acc, err := accountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
az, err := db.GetAuthorization(ctx, chi.URLParam(r, "authzID")) az, err := db.GetAuthorization(ctx, chi.URLParam(r, "authzID"))
if err != nil { if err != nil {
render.Error(w, acme.WrapErrorISE(err, "error retrieving authorization")) render.Error(w, r, acme.WrapErrorISE(err, "error retrieving authorization"))
return return
} }
if acc.ID != az.AccountID { if acc.ID != az.AccountID {
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, render.Error(w, r, acme.NewError(acme.ErrorUnauthorizedType,
"account '%s' does not own authorization '%s'", acc.ID, az.ID)) "account '%s' does not own authorization '%s'", acc.ID, az.ID))
return return
} }
if err = az.UpdateStatus(ctx, db); err != nil { if err = az.UpdateStatus(ctx, db); err != nil {
render.Error(w, acme.WrapErrorISE(err, "error updating authorization status")) render.Error(w, r, acme.WrapErrorISE(err, "error updating authorization status"))
return return
} }
linker.LinkAuthorization(ctx, az) linker.LinkAuthorization(ctx, az)
w.Header().Set("Location", linker.GetLink(ctx, acme.AuthzLinkType, az.ID)) w.Header().Set("Location", linker.GetLink(ctx, acme.AuthzLinkType, az.ID))
render.JSON(w, az) render.JSON(w, r, az)
} }
// GetChallenge ACME api for retrieving a Challenge. // GetChallenge ACME api for retrieving a Challenge.
@ -317,13 +317,13 @@ func GetChallenge(w http.ResponseWriter, r *http.Request) {
acc, err := accountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
payload, err := payloadFromContext(ctx) payload, err := payloadFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
@ -336,22 +336,22 @@ func GetChallenge(w http.ResponseWriter, r *http.Request) {
azID := chi.URLParam(r, "authzID") azID := chi.URLParam(r, "authzID")
ch, err := db.GetChallenge(ctx, chi.URLParam(r, "chID"), azID) ch, err := db.GetChallenge(ctx, chi.URLParam(r, "chID"), azID)
if err != nil { if err != nil {
render.Error(w, acme.WrapErrorISE(err, "error retrieving challenge")) render.Error(w, r, acme.WrapErrorISE(err, "error retrieving challenge"))
return return
} }
ch.AuthorizationID = azID ch.AuthorizationID = azID
if acc.ID != ch.AccountID { if acc.ID != ch.AccountID {
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, render.Error(w, r, acme.NewError(acme.ErrorUnauthorizedType,
"account '%s' does not own challenge '%s'", acc.ID, ch.ID)) "account '%s' does not own challenge '%s'", acc.ID, ch.ID))
return return
} }
jwk, err := jwkFromContext(ctx) jwk, err := jwkFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
if err = ch.Validate(ctx, db, jwk, payload.value); err != nil { if err = ch.Validate(ctx, db, jwk, payload.value); err != nil {
render.Error(w, acme.WrapErrorISE(err, "error validating challenge")) render.Error(w, r, acme.WrapErrorISE(err, "error validating challenge"))
return return
} }
@ -359,7 +359,7 @@ func GetChallenge(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Link", link(linker.GetLink(ctx, acme.AuthzLinkType, azID), "up")) w.Header().Add("Link", link(linker.GetLink(ctx, acme.AuthzLinkType, azID), "up"))
w.Header().Set("Location", linker.GetLink(ctx, acme.ChallengeLinkType, azID, ch.ID)) w.Header().Set("Location", linker.GetLink(ctx, acme.ChallengeLinkType, azID, ch.ID))
render.JSON(w, ch) render.JSON(w, r, ch)
} }
// GetCertificate ACME api for retrieving a Certificate. // GetCertificate ACME api for retrieving a Certificate.
@ -369,18 +369,18 @@ func GetCertificate(w http.ResponseWriter, r *http.Request) {
acc, err := accountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
certID := chi.URLParam(r, "certID") certID := chi.URLParam(r, "certID")
cert, err := db.GetCertificate(ctx, certID) cert, err := db.GetCertificate(ctx, certID)
if err != nil { if err != nil {
render.Error(w, acme.WrapErrorISE(err, "error retrieving certificate")) render.Error(w, r, acme.WrapErrorISE(err, "error retrieving certificate"))
return return
} }
if cert.AccountID != acc.ID { if cert.AccountID != acc.ID {
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, render.Error(w, r, acme.NewError(acme.ErrorUnauthorizedType,
"account '%s' does not own certificate '%s'", acc.ID, certID)) "account '%s' does not own certificate '%s'", acc.ID, certID))
return return
} }

@ -36,7 +36,7 @@ func addNonce(next nextHTTP) nextHTTP {
db := acme.MustDatabaseFromContext(r.Context()) db := acme.MustDatabaseFromContext(r.Context())
nonce, err := db.CreateNonce(r.Context()) nonce, err := db.CreateNonce(r.Context())
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
w.Header().Set("Replay-Nonce", string(nonce)) w.Header().Set("Replay-Nonce", string(nonce))
@ -64,7 +64,7 @@ func verifyContentType(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
p, err := provisionerFromContext(r.Context()) p, err := provisionerFromContext(r.Context())
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
@ -88,7 +88,7 @@ func verifyContentType(next nextHTTP) nextHTTP {
return return
} }
} }
render.Error(w, acme.NewError(acme.ErrorMalformedType, render.Error(w, r, acme.NewError(acme.ErrorMalformedType,
"expected content-type to be in %s, but got %s", expected, ct)) "expected content-type to be in %s, but got %s", expected, ct))
} }
} }
@ -98,12 +98,12 @@ func parseJWS(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(r.Body) body, err := io.ReadAll(r.Body)
if err != nil { if err != nil {
render.Error(w, acme.WrapErrorISE(err, "failed to read request body")) render.Error(w, r, acme.WrapErrorISE(err, "failed to read request body"))
return return
} }
jws, err := jose.ParseJWS(string(body)) jws, err := jose.ParseJWS(string(body))
if err != nil { if err != nil {
render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "failed to parse JWS from request body")) render.Error(w, r, acme.WrapError(acme.ErrorMalformedType, err, "failed to parse JWS from request body"))
return return
} }
ctx := context.WithValue(r.Context(), jwsContextKey, jws) ctx := context.WithValue(r.Context(), jwsContextKey, jws)
@ -133,15 +133,15 @@ func validateJWS(next nextHTTP) nextHTTP {
jws, err := jwsFromContext(ctx) jws, err := jwsFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
if len(jws.Signatures) == 0 { if len(jws.Signatures) == 0 {
render.Error(w, acme.NewError(acme.ErrorMalformedType, "request body does not contain a signature")) render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "request body does not contain a signature"))
return return
} }
if len(jws.Signatures) > 1 { if len(jws.Signatures) > 1 {
render.Error(w, acme.NewError(acme.ErrorMalformedType, "request body contains more than one signature")) render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "request body contains more than one signature"))
return return
} }
@ -152,7 +152,7 @@ func validateJWS(next nextHTTP) nextHTTP {
uh.Algorithm != "" || uh.Algorithm != "" ||
uh.Nonce != "" || uh.Nonce != "" ||
len(uh.ExtraHeaders) > 0 { len(uh.ExtraHeaders) > 0 {
render.Error(w, acme.NewError(acme.ErrorMalformedType, "unprotected header must not be used")) render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "unprotected header must not be used"))
return return
} }
hdr := sig.Protected hdr := sig.Protected
@ -162,13 +162,13 @@ func validateJWS(next nextHTTP) nextHTTP {
switch k := hdr.JSONWebKey.Key.(type) { switch k := hdr.JSONWebKey.Key.(type) {
case *rsa.PublicKey: case *rsa.PublicKey:
if k.Size() < keyutil.MinRSAKeyBytes { if k.Size() < keyutil.MinRSAKeyBytes {
render.Error(w, acme.NewError(acme.ErrorMalformedType, render.Error(w, r, acme.NewError(acme.ErrorMalformedType,
"rsa keys must be at least %d bits (%d bytes) in size", "rsa keys must be at least %d bits (%d bytes) in size",
8*keyutil.MinRSAKeyBytes, keyutil.MinRSAKeyBytes)) 8*keyutil.MinRSAKeyBytes, keyutil.MinRSAKeyBytes))
return return
} }
default: default:
render.Error(w, acme.NewError(acme.ErrorMalformedType, render.Error(w, r, acme.NewError(acme.ErrorMalformedType,
"jws key type and algorithm do not match")) "jws key type and algorithm do not match"))
return return
} }
@ -176,35 +176,35 @@ func validateJWS(next nextHTTP) nextHTTP {
case jose.ES256, jose.ES384, jose.ES512, jose.EdDSA: case jose.ES256, jose.ES384, jose.ES512, jose.EdDSA:
// we good // we good
default: default:
render.Error(w, acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: %s", hdr.Algorithm)) render.Error(w, r, acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: %s", hdr.Algorithm))
return return
} }
// Check the validity/freshness of the Nonce. // Check the validity/freshness of the Nonce.
if err := db.DeleteNonce(ctx, acme.Nonce(hdr.Nonce)); err != nil { if err := db.DeleteNonce(ctx, acme.Nonce(hdr.Nonce)); err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
// Check that the JWS url matches the requested url. // Check that the JWS url matches the requested url.
jwsURL, ok := hdr.ExtraHeaders["url"].(string) jwsURL, ok := hdr.ExtraHeaders["url"].(string)
if !ok { if !ok {
render.Error(w, acme.NewError(acme.ErrorMalformedType, "jws missing url protected header")) render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "jws missing url protected header"))
return return
} }
reqURL := &url.URL{Scheme: "https", Host: r.Host, Path: r.URL.Path} reqURL := &url.URL{Scheme: "https", Host: r.Host, Path: r.URL.Path}
if jwsURL != reqURL.String() { if jwsURL != reqURL.String() {
render.Error(w, acme.NewError(acme.ErrorMalformedType, render.Error(w, r, acme.NewError(acme.ErrorMalformedType,
"url header in JWS (%s) does not match request url (%s)", jwsURL, reqURL)) "url header in JWS (%s) does not match request url (%s)", jwsURL, reqURL))
return return
} }
if hdr.JSONWebKey != nil && hdr.KeyID != "" { if hdr.JSONWebKey != nil && hdr.KeyID != "" {
render.Error(w, acme.NewError(acme.ErrorMalformedType, "jwk and kid are mutually exclusive")) render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "jwk and kid are mutually exclusive"))
return return
} }
if hdr.JSONWebKey == nil && hdr.KeyID == "" { if hdr.JSONWebKey == nil && hdr.KeyID == "" {
render.Error(w, acme.NewError(acme.ErrorMalformedType, "either jwk or kid must be defined in jws protected header")) render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "either jwk or kid must be defined in jws protected header"))
return return
} }
next(w, r) next(w, r)
@ -221,23 +221,23 @@ func extractJWK(next nextHTTP) nextHTTP {
jws, err := jwsFromContext(ctx) jws, err := jwsFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
jwk := jws.Signatures[0].Protected.JSONWebKey jwk := jws.Signatures[0].Protected.JSONWebKey
if jwk == nil { if jwk == nil {
render.Error(w, acme.NewError(acme.ErrorMalformedType, "jwk expected in protected header")) render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "jwk expected in protected header"))
return return
} }
if !jwk.Valid() { if !jwk.Valid() {
render.Error(w, acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header")) render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header"))
return return
} }
// Overwrite KeyID with the JWK thumbprint. // Overwrite KeyID with the JWK thumbprint.
jwk.KeyID, err = acme.KeyToID(jwk) jwk.KeyID, err = acme.KeyToID(jwk)
if err != nil { if err != nil {
render.Error(w, acme.WrapErrorISE(err, "error getting KeyID from JWK")) render.Error(w, r, acme.WrapErrorISE(err, "error getting KeyID from JWK"))
return return
} }
@ -251,11 +251,11 @@ func extractJWK(next nextHTTP) nextHTTP {
// For NewAccount and Revoke requests ... // For NewAccount and Revoke requests ...
break break
case err != nil: case err != nil:
render.Error(w, err) render.Error(w, r, err)
return return
default: default:
if !acc.IsValid() { if !acc.IsValid() {
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active")) render.Error(w, r, acme.NewError(acme.ErrorUnauthorizedType, "account is not active"))
return return
} }
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
@ -274,11 +274,11 @@ func checkPrerequisites(next nextHTTP) nextHTTP {
if ok { if ok {
ok, err := checkFunc(ctx) ok, err := checkFunc(ctx)
if err != nil { if err != nil {
render.Error(w, acme.WrapErrorISE(err, "error checking acme provisioner prerequisites")) render.Error(w, r, acme.WrapErrorISE(err, "error checking acme provisioner prerequisites"))
return return
} }
if !ok { if !ok {
render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites")) render.Error(w, r, acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites"))
return return
} }
} }
@ -296,13 +296,13 @@ func lookupJWK(next nextHTTP) nextHTTP {
jws, err := jwsFromContext(ctx) jws, err := jwsFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
kid := jws.Signatures[0].Protected.KeyID kid := jws.Signatures[0].Protected.KeyID
if kid == "" { if kid == "" {
render.Error(w, acme.NewError(acme.ErrorMalformedType, "signature missing 'kid'")) render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "signature missing 'kid'"))
return return
} }
@ -310,14 +310,14 @@ func lookupJWK(next nextHTTP) nextHTTP {
acc, err := db.GetAccount(ctx, accID) acc, err := db.GetAccount(ctx, accID)
switch { switch {
case acme.IsErrNotFound(err): case acme.IsErrNotFound(err):
render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "account with ID '%s' not found", accID)) render.Error(w, r, acme.NewError(acme.ErrorAccountDoesNotExistType, "account with ID '%s' not found", accID))
return return
case err != nil: case err != nil:
render.Error(w, err) render.Error(w, r, err)
return return
default: default:
if !acc.IsValid() { if !acc.IsValid() {
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active")) render.Error(w, r, acme.NewError(acme.ErrorUnauthorizedType, "account is not active"))
return return
} }
@ -325,7 +325,7 @@ func lookupJWK(next nextHTTP) nextHTTP {
if kid != storedLocation { if kid != storedLocation {
// ACME accounts should have a stored location equivalent to the // ACME accounts should have a stored location equivalent to the
// kid in the ACME request. // kid in the ACME request.
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, render.Error(w, r, acme.NewError(acme.ErrorUnauthorizedType,
"kid does not match stored account location; expected %s, but got %s", "kid does not match stored account location; expected %s, but got %s",
storedLocation, kid)) storedLocation, kid))
return return
@ -339,7 +339,7 @@ func lookupJWK(next nextHTTP) nextHTTP {
if reqProvName != accProvName { if reqProvName != accProvName {
// Provisioner in the URL must match the provisioner with // Provisioner in the URL must match the provisioner with
// which the account was created. // which the account was created.
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, render.Error(w, r, acme.NewError(acme.ErrorUnauthorizedType,
"account provisioner does not match requested provisioner; account provisioner = %s, requested provisioner = %s", "account provisioner does not match requested provisioner; account provisioner = %s, requested provisioner = %s",
accProvName, reqProvName)) accProvName, reqProvName))
return return
@ -353,7 +353,7 @@ func lookupJWK(next nextHTTP) nextHTTP {
linker := acme.MustLinkerFromContext(ctx) linker := acme.MustLinkerFromContext(ctx)
kidPrefix := linker.GetLink(ctx, acme.AccountLinkType, "") kidPrefix := linker.GetLink(ctx, acme.AccountLinkType, "")
if !strings.HasPrefix(kid, kidPrefix) { if !strings.HasPrefix(kid, kidPrefix) {
render.Error(w, acme.NewError(acme.ErrorMalformedType, render.Error(w, r, acme.NewError(acme.ErrorMalformedType,
"kid does not have required prefix; expected %s, but got %s", "kid does not have required prefix; expected %s, but got %s",
kidPrefix, kid)) kidPrefix, kid))
return return
@ -374,7 +374,7 @@ func extractOrLookupJWK(next nextHTTP) nextHTTP {
ctx := r.Context() ctx := r.Context()
jws, err := jwsFromContext(ctx) jws, err := jwsFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
@ -410,16 +410,16 @@ func verifyAndExtractJWSPayload(next nextHTTP) nextHTTP {
ctx := r.Context() ctx := r.Context()
jws, err := jwsFromContext(ctx) jws, err := jwsFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
jwk, err := jwkFromContext(ctx) jwk, err := jwkFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
if jwk.Algorithm != "" && jwk.Algorithm != jws.Signatures[0].Protected.Algorithm { if jwk.Algorithm != "" && jwk.Algorithm != jws.Signatures[0].Protected.Algorithm {
render.Error(w, acme.NewError(acme.ErrorMalformedType, "verifier and signature algorithm do not match")) render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "verifier and signature algorithm do not match"))
return return
} }
@ -428,11 +428,11 @@ func verifyAndExtractJWSPayload(next nextHTTP) nextHTTP {
case errors.Is(err, jose.ErrCryptoFailure): case errors.Is(err, jose.ErrCryptoFailure):
payload, err = retryVerificationWithPatchedSignatures(jws, jwk) payload, err = retryVerificationWithPatchedSignatures(jws, jwk)
if err != nil { if err != nil {
render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "error verifying jws with patched signature(s)")) render.Error(w, r, acme.WrapError(acme.ErrorMalformedType, err, "error verifying jws with patched signature(s)"))
return return
} }
case err != nil: case err != nil:
render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "error verifying jws")) render.Error(w, r, acme.WrapError(acme.ErrorMalformedType, err, "error verifying jws"))
return return
} }
@ -549,11 +549,11 @@ func isPostAsGet(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
payload, err := payloadFromContext(r.Context()) payload, err := payloadFromContext(r.Context())
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
if !payload.isPostAsGet { if !payload.isPostAsGet {
render.Error(w, acme.NewError(acme.ErrorMalformedType, "expected POST-as-GET")) render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "expected POST-as-GET"))
return return
} }
next(w, r) next(w, r)

@ -99,29 +99,29 @@ func NewOrder(w http.ResponseWriter, r *http.Request) {
acc, err := accountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
prov, err := provisionerFromContext(ctx) prov, err := provisionerFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
payload, err := payloadFromContext(ctx) payload, err := payloadFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
var nor NewOrderRequest var nor NewOrderRequest
if err := json.Unmarshal(payload.value, &nor); err != nil { if err := json.Unmarshal(payload.value, &nor); err != nil {
render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, render.Error(w, r, acme.WrapError(acme.ErrorMalformedType, err,
"failed to unmarshal new-order request payload")) "failed to unmarshal new-order request payload"))
return return
} }
if err := nor.Validate(); err != nil { if err := nor.Validate(); err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
@ -130,39 +130,39 @@ func NewOrder(w http.ResponseWriter, r *http.Request) {
acmeProv, err := acmeProvisionerFromContext(ctx) acmeProv, err := acmeProvisionerFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
var eak *acme.ExternalAccountKey var eak *acme.ExternalAccountKey
if acmeProv.RequireEAB { if acmeProv.RequireEAB {
if eak, err = db.GetExternalAccountKeyByAccountID(ctx, prov.GetID(), acc.ID); err != nil { if eak, err = db.GetExternalAccountKeyByAccountID(ctx, prov.GetID(), acc.ID); err != nil {
render.Error(w, acme.WrapErrorISE(err, "error retrieving external account binding key")) render.Error(w, r, acme.WrapErrorISE(err, "error retrieving external account binding key"))
return return
} }
} }
acmePolicy, err := newACMEPolicyEngine(eak) acmePolicy, err := newACMEPolicyEngine(eak)
if err != nil { if err != nil {
render.Error(w, acme.WrapErrorISE(err, "error creating ACME policy engine")) render.Error(w, r, acme.WrapErrorISE(err, "error creating ACME policy engine"))
return return
} }
for _, identifier := range nor.Identifiers { for _, identifier := range nor.Identifiers {
// evaluate the ACME account level policy // evaluate the ACME account level policy
if err = isIdentifierAllowed(acmePolicy, identifier); err != nil { if err = isIdentifierAllowed(acmePolicy, identifier); err != nil {
render.Error(w, acme.WrapError(acme.ErrorRejectedIdentifierType, err, "not authorized")) render.Error(w, r, acme.WrapError(acme.ErrorRejectedIdentifierType, err, "not authorized"))
return return
} }
// evaluate the provisioner level policy // evaluate the provisioner level policy
orderIdentifier := provisioner.ACMEIdentifier{Type: provisioner.ACMEIdentifierType(identifier.Type), Value: identifier.Value} orderIdentifier := provisioner.ACMEIdentifier{Type: provisioner.ACMEIdentifierType(identifier.Type), Value: identifier.Value}
if err = prov.AuthorizeOrderIdentifier(ctx, orderIdentifier); err != nil { if err = prov.AuthorizeOrderIdentifier(ctx, orderIdentifier); err != nil {
render.Error(w, acme.WrapError(acme.ErrorRejectedIdentifierType, err, "not authorized")) render.Error(w, r, acme.WrapError(acme.ErrorRejectedIdentifierType, err, "not authorized"))
return return
} }
// evaluate the authority level policy // evaluate the authority level policy
if err = ca.AreSANsAllowed(ctx, []string{identifier.Value}); err != nil { if err = ca.AreSANsAllowed(ctx, []string{identifier.Value}); err != nil {
render.Error(w, acme.WrapError(acme.ErrorRejectedIdentifierType, err, "not authorized")) render.Error(w, r, acme.WrapError(acme.ErrorRejectedIdentifierType, err, "not authorized"))
return return
} }
} }
@ -188,7 +188,7 @@ func NewOrder(w http.ResponseWriter, r *http.Request) {
Status: acme.StatusPending, Status: acme.StatusPending,
} }
if err := newAuthorization(ctx, az); err != nil { if err := newAuthorization(ctx, az); err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
o.AuthorizationIDs[i] = az.ID o.AuthorizationIDs[i] = az.ID
@ -207,14 +207,14 @@ func NewOrder(w http.ResponseWriter, r *http.Request) {
} }
if err := db.CreateOrder(ctx, o); err != nil { if err := db.CreateOrder(ctx, o); err != nil {
render.Error(w, acme.WrapErrorISE(err, "error creating order")) render.Error(w, r, acme.WrapErrorISE(err, "error creating order"))
return return
} }
linker.LinkOrder(ctx, o) linker.LinkOrder(ctx, o)
w.Header().Set("Location", linker.GetLink(ctx, acme.OrderLinkType, o.ID)) w.Header().Set("Location", linker.GetLink(ctx, acme.OrderLinkType, o.ID))
render.JSONStatus(w, o, http.StatusCreated) render.JSONStatus(w, r, o, http.StatusCreated)
} }
func isIdentifierAllowed(acmePolicy policy.X509Policy, identifier acme.Identifier) error { func isIdentifierAllowed(acmePolicy policy.X509Policy, identifier acme.Identifier) error {
@ -288,39 +288,39 @@ func GetOrder(w http.ResponseWriter, r *http.Request) {
acc, err := accountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
prov, err := provisionerFromContext(ctx) prov, err := provisionerFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
o, err := db.GetOrder(ctx, chi.URLParam(r, "ordID")) o, err := db.GetOrder(ctx, chi.URLParam(r, "ordID"))
if err != nil { if err != nil {
render.Error(w, acme.WrapErrorISE(err, "error retrieving order")) render.Error(w, r, acme.WrapErrorISE(err, "error retrieving order"))
return return
} }
if acc.ID != o.AccountID { if acc.ID != o.AccountID {
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, render.Error(w, r, acme.NewError(acme.ErrorUnauthorizedType,
"account '%s' does not own order '%s'", acc.ID, o.ID)) "account '%s' does not own order '%s'", acc.ID, o.ID))
return return
} }
if prov.GetID() != o.ProvisionerID { if prov.GetID() != o.ProvisionerID {
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, render.Error(w, r, acme.NewError(acme.ErrorUnauthorizedType,
"provisioner '%s' does not own order '%s'", prov.GetID(), o.ID)) "provisioner '%s' does not own order '%s'", prov.GetID(), o.ID))
return return
} }
if err = o.UpdateStatus(ctx, db); err != nil { if err = o.UpdateStatus(ctx, db); err != nil {
render.Error(w, acme.WrapErrorISE(err, "error updating order status")) render.Error(w, r, acme.WrapErrorISE(err, "error updating order status"))
return return
} }
linker.LinkOrder(ctx, o) linker.LinkOrder(ctx, o)
w.Header().Set("Location", linker.GetLink(ctx, acme.OrderLinkType, o.ID)) w.Header().Set("Location", linker.GetLink(ctx, acme.OrderLinkType, o.ID))
render.JSON(w, o) render.JSON(w, r, o)
} }
// FinalizeOrder attempts to finalize an order and create a certificate. // FinalizeOrder attempts to finalize an order and create a certificate.
@ -331,56 +331,56 @@ func FinalizeOrder(w http.ResponseWriter, r *http.Request) {
acc, err := accountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
prov, err := provisionerFromContext(ctx) prov, err := provisionerFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
payload, err := payloadFromContext(ctx) payload, err := payloadFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
var fr FinalizeRequest var fr FinalizeRequest
if err := json.Unmarshal(payload.value, &fr); err != nil { if err := json.Unmarshal(payload.value, &fr); err != nil {
render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, render.Error(w, r, acme.WrapError(acme.ErrorMalformedType, err,
"failed to unmarshal finalize-order request payload")) "failed to unmarshal finalize-order request payload"))
return return
} }
if err := fr.Validate(); err != nil { if err := fr.Validate(); err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
o, err := db.GetOrder(ctx, chi.URLParam(r, "ordID")) o, err := db.GetOrder(ctx, chi.URLParam(r, "ordID"))
if err != nil { if err != nil {
render.Error(w, acme.WrapErrorISE(err, "error retrieving order")) render.Error(w, r, acme.WrapErrorISE(err, "error retrieving order"))
return return
} }
if acc.ID != o.AccountID { if acc.ID != o.AccountID {
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, render.Error(w, r, acme.NewError(acme.ErrorUnauthorizedType,
"account '%s' does not own order '%s'", acc.ID, o.ID)) "account '%s' does not own order '%s'", acc.ID, o.ID))
return return
} }
if prov.GetID() != o.ProvisionerID { if prov.GetID() != o.ProvisionerID {
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, render.Error(w, r, acme.NewError(acme.ErrorUnauthorizedType,
"provisioner '%s' does not own order '%s'", prov.GetID(), o.ID)) "provisioner '%s' does not own order '%s'", prov.GetID(), o.ID))
return return
} }
ca := mustAuthority(ctx) ca := mustAuthority(ctx)
if err = o.Finalize(ctx, db, fr.csr, ca, prov); err != nil { if err = o.Finalize(ctx, db, fr.csr, ca, prov); err != nil {
render.Error(w, acme.WrapErrorISE(err, "error finalizing order")) render.Error(w, r, acme.WrapErrorISE(err, "error finalizing order"))
return return
} }
linker.LinkOrder(ctx, o) linker.LinkOrder(ctx, o)
w.Header().Set("Location", linker.GetLink(ctx, acme.OrderLinkType, o.ID)) w.Header().Set("Location", linker.GetLink(ctx, acme.OrderLinkType, o.ID))
render.JSON(w, o) render.JSON(w, r, o)
} }
// challengeTypes determines the types of challenges that should be used // challengeTypes determines the types of challenges that should be used

@ -33,65 +33,65 @@ func RevokeCert(w http.ResponseWriter, r *http.Request) {
jws, err := jwsFromContext(ctx) jws, err := jwsFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
prov, err := provisionerFromContext(ctx) prov, err := provisionerFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
payload, err := payloadFromContext(ctx) payload, err := payloadFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
var p revokePayload var p revokePayload
err = json.Unmarshal(payload.value, &p) err = json.Unmarshal(payload.value, &p)
if err != nil { if err != nil {
render.Error(w, acme.WrapErrorISE(err, "error unmarshaling payload")) render.Error(w, r, acme.WrapErrorISE(err, "error unmarshaling payload"))
return return
} }
certBytes, err := base64.RawURLEncoding.DecodeString(p.Certificate) certBytes, err := base64.RawURLEncoding.DecodeString(p.Certificate)
if err != nil { if err != nil {
// in this case the most likely cause is a client that didn't properly encode the certificate // in this case the most likely cause is a client that didn't properly encode the certificate
render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "error base64url decoding payload certificate property")) render.Error(w, r, acme.WrapError(acme.ErrorMalformedType, err, "error base64url decoding payload certificate property"))
return return
} }
certToBeRevoked, err := x509.ParseCertificate(certBytes) certToBeRevoked, err := x509.ParseCertificate(certBytes)
if err != nil { if err != nil {
// in this case a client may have encoded something different than a certificate // in this case a client may have encoded something different than a certificate
render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "error parsing certificate")) render.Error(w, r, acme.WrapError(acme.ErrorMalformedType, err, "error parsing certificate"))
return return
} }
serial := certToBeRevoked.SerialNumber.String() serial := certToBeRevoked.SerialNumber.String()
dbCert, err := db.GetCertificateBySerial(ctx, serial) dbCert, err := db.GetCertificateBySerial(ctx, serial)
if err != nil { if err != nil {
render.Error(w, acme.WrapErrorISE(err, "error retrieving certificate by serial")) render.Error(w, r, acme.WrapErrorISE(err, "error retrieving certificate by serial"))
return return
} }
if !bytes.Equal(dbCert.Leaf.Raw, certToBeRevoked.Raw) { if !bytes.Equal(dbCert.Leaf.Raw, certToBeRevoked.Raw) {
// this should never happen // this should never happen
render.Error(w, acme.NewErrorISE("certificate raw bytes are not equal")) render.Error(w, r, acme.NewErrorISE("certificate raw bytes are not equal"))
return return
} }
if shouldCheckAccountFrom(jws) { if shouldCheckAccountFrom(jws) {
account, err := accountFromContext(ctx) account, err := accountFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
acmeErr := isAccountAuthorized(ctx, dbCert, certToBeRevoked, account) acmeErr := isAccountAuthorized(ctx, dbCert, certToBeRevoked, account)
if acmeErr != nil { if acmeErr != nil {
render.Error(w, acmeErr) render.Error(w, r, acmeErr)
return return
} }
} else { } else {
@ -100,7 +100,7 @@ func RevokeCert(w http.ResponseWriter, r *http.Request) {
_, err := jws.Verify(certToBeRevoked.PublicKey) _, err := jws.Verify(certToBeRevoked.PublicKey)
if err != nil { if err != nil {
// TODO(hs): possible to determine an error vs. unauthorized and thus provide an ISE vs. Unauthorized? // TODO(hs): possible to determine an error vs. unauthorized and thus provide an ISE vs. Unauthorized?
render.Error(w, wrapUnauthorizedError(certToBeRevoked, nil, "verification of jws using certificate public key failed", err)) render.Error(w, r, wrapUnauthorizedError(certToBeRevoked, nil, "verification of jws using certificate public key failed", err))
return return
} }
} }
@ -108,19 +108,19 @@ func RevokeCert(w http.ResponseWriter, r *http.Request) {
ca := mustAuthority(ctx) ca := mustAuthority(ctx)
hasBeenRevokedBefore, err := ca.IsRevoked(serial) hasBeenRevokedBefore, err := ca.IsRevoked(serial)
if err != nil { if err != nil {
render.Error(w, acme.WrapErrorISE(err, "error retrieving revocation status of certificate")) render.Error(w, r, acme.WrapErrorISE(err, "error retrieving revocation status of certificate"))
return return
} }
if hasBeenRevokedBefore { if hasBeenRevokedBefore {
render.Error(w, acme.NewError(acme.ErrorAlreadyRevokedType, "certificate was already revoked")) render.Error(w, r, acme.NewError(acme.ErrorAlreadyRevokedType, "certificate was already revoked"))
return return
} }
reasonCode := p.ReasonCode reasonCode := p.ReasonCode
acmeErr := validateReasonCode(reasonCode) acmeErr := validateReasonCode(reasonCode)
if acmeErr != nil { if acmeErr != nil {
render.Error(w, acmeErr) render.Error(w, r, acmeErr)
return return
} }
@ -128,14 +128,14 @@ func RevokeCert(w http.ResponseWriter, r *http.Request) {
ctx = provisioner.NewContextWithMethod(ctx, provisioner.RevokeMethod) ctx = provisioner.NewContextWithMethod(ctx, provisioner.RevokeMethod)
err = prov.AuthorizeRevoke(ctx, "") err = prov.AuthorizeRevoke(ctx, "")
if err != nil { if err != nil {
render.Error(w, acme.WrapErrorISE(err, "error authorizing revocation on provisioner")) render.Error(w, r, acme.WrapErrorISE(err, "error authorizing revocation on provisioner"))
return return
} }
options := revokeOptions(serial, certToBeRevoked, reasonCode) options := revokeOptions(serial, certToBeRevoked, reasonCode)
err = ca.Revoke(ctx, options) err = ca.Revoke(ctx, options)
if err != nil { if err != nil {
render.Error(w, wrapRevokeErr(err)) render.Error(w, r, wrapRevokeErr(err))
return return
} }

@ -424,7 +424,7 @@ func (e *Error) ToLog() (interface{}, error) {
} }
// Render implements render.RenderableError for Error. // Render implements render.RenderableError for Error.
func (e *Error) Render(w http.ResponseWriter) { func (e *Error) Render(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/problem+json") w.Header().Set("Content-Type", "application/problem+json")
render.JSONStatus(w, e, e.StatusCode()) render.JSONStatus(w, r, e, e.StatusCode())
} }

@ -186,19 +186,19 @@ func (l *linker) Middleware(next http.Handler) http.Handler {
nameEscaped := chi.URLParam(r, "provisionerID") nameEscaped := chi.URLParam(r, "provisionerID")
name, err := url.PathUnescape(nameEscaped) name, err := url.PathUnescape(nameEscaped)
if err != nil { if err != nil {
render.Error(w, WrapErrorISE(err, "error url unescaping provisioner name '%s'", nameEscaped)) render.Error(w, r, WrapErrorISE(err, "error url unescaping provisioner name '%s'", nameEscaped))
return return
} }
p, err := authority.MustFromContext(ctx).LoadProvisionerByName(name) p, err := authority.MustFromContext(ctx).LoadProvisionerByName(name)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
acmeProv, ok := p.(*provisioner.ACME) acmeProv, ok := p.(*provisioner.ACME)
if !ok { if !ok {
render.Error(w, NewError(ErrorAccountDoesNotExistType, "provisioner must be of type ACME")) render.Error(w, r, NewError(ErrorAccountDoesNotExistType, "provisioner must be of type ACME"))
return return
} }

@ -353,15 +353,15 @@ func Route(r Router) {
// Version is an HTTP handler that returns the version of the server. // Version is an HTTP handler that returns the version of the server.
func Version(w http.ResponseWriter, r *http.Request) { func Version(w http.ResponseWriter, r *http.Request) {
v := mustAuthority(r.Context()).Version() v := mustAuthority(r.Context()).Version()
render.JSON(w, VersionResponse{ render.JSON(w, r, VersionResponse{
Version: v.Version, Version: v.Version,
RequireClientAuthentication: v.RequireClientAuthentication, RequireClientAuthentication: v.RequireClientAuthentication,
}) })
} }
// Health is an HTTP handler that returns the status of the server. // Health is an HTTP handler that returns the status of the server.
func Health(w http.ResponseWriter, _ *http.Request) { func Health(w http.ResponseWriter, r *http.Request) {
render.JSON(w, HealthResponse{Status: "ok"}) render.JSON(w, r, HealthResponse{Status: "ok"})
} }
// Root is an HTTP handler that using the SHA256 from the URL, returns the root // Root is an HTTP handler that using the SHA256 from the URL, returns the root
@ -372,11 +372,11 @@ func Root(w http.ResponseWriter, r *http.Request) {
// Load root certificate with the // Load root certificate with the
cert, err := mustAuthority(r.Context()).Root(sum) cert, err := mustAuthority(r.Context()).Root(sum)
if err != nil { if err != nil {
render.Error(w, errs.Wrapf(http.StatusNotFound, err, "%s was not found", r.RequestURI)) render.Error(w, r, errs.Wrapf(http.StatusNotFound, err, "%s was not found", r.RequestURI))
return return
} }
render.JSON(w, &RootResponse{RootPEM: Certificate{cert}}) render.JSON(w, r, &RootResponse{RootPEM: Certificate{cert}})
} }
func certChainToPEM(certChain []*x509.Certificate) []Certificate { func certChainToPEM(certChain []*x509.Certificate) []Certificate {
@ -391,17 +391,17 @@ func certChainToPEM(certChain []*x509.Certificate) []Certificate {
func Provisioners(w http.ResponseWriter, r *http.Request) { func Provisioners(w http.ResponseWriter, r *http.Request) {
cursor, limit, err := ParseCursor(r) cursor, limit, err := ParseCursor(r)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
p, next, err := mustAuthority(r.Context()).GetProvisioners(cursor, limit) p, next, err := mustAuthority(r.Context()).GetProvisioners(cursor, limit)
if err != nil { if err != nil {
render.Error(w, errs.InternalServerErr(err)) render.Error(w, r, errs.InternalServerErr(err))
return return
} }
render.JSON(w, &ProvisionersResponse{ render.JSON(w, r, &ProvisionersResponse{
Provisioners: p, Provisioners: p,
NextCursor: next, NextCursor: next,
}) })
@ -412,18 +412,18 @@ func ProvisionerKey(w http.ResponseWriter, r *http.Request) {
kid := chi.URLParam(r, "kid") kid := chi.URLParam(r, "kid")
key, err := mustAuthority(r.Context()).GetEncryptedKey(kid) key, err := mustAuthority(r.Context()).GetEncryptedKey(kid)
if err != nil { if err != nil {
render.Error(w, errs.NotFoundErr(err)) render.Error(w, r, errs.NotFoundErr(err))
return return
} }
render.JSON(w, &ProvisionerKeyResponse{key}) render.JSON(w, r, &ProvisionerKeyResponse{key})
} }
// Roots returns all the root certificates for the CA. // Roots returns all the root certificates for the CA.
func Roots(w http.ResponseWriter, r *http.Request) { func Roots(w http.ResponseWriter, r *http.Request) {
roots, err := mustAuthority(r.Context()).GetRoots() roots, err := mustAuthority(r.Context()).GetRoots()
if err != nil { if err != nil {
render.Error(w, errs.ForbiddenErr(err, "error getting roots")) render.Error(w, r, errs.ForbiddenErr(err, "error getting roots"))
return return
} }
@ -432,7 +432,7 @@ func Roots(w http.ResponseWriter, r *http.Request) {
certs[i] = Certificate{roots[i]} certs[i] = Certificate{roots[i]}
} }
render.JSONStatus(w, &RootsResponse{ render.JSONStatus(w, r, &RootsResponse{
Certificates: certs, Certificates: certs,
}, http.StatusCreated) }, http.StatusCreated)
} }
@ -441,7 +441,7 @@ func Roots(w http.ResponseWriter, r *http.Request) {
func RootsPEM(w http.ResponseWriter, r *http.Request) { func RootsPEM(w http.ResponseWriter, r *http.Request) {
roots, err := mustAuthority(r.Context()).GetRoots() roots, err := mustAuthority(r.Context()).GetRoots()
if err != nil { if err != nil {
render.Error(w, errs.InternalServerErr(err)) render.Error(w, r, errs.InternalServerErr(err))
return return
} }
@ -454,7 +454,7 @@ func RootsPEM(w http.ResponseWriter, r *http.Request) {
}) })
if _, err := w.Write(block); err != nil { if _, err := w.Write(block); err != nil {
log.Error(w, err) log.Error(w, r, err)
return return
} }
} }
@ -464,7 +464,7 @@ func RootsPEM(w http.ResponseWriter, r *http.Request) {
func Federation(w http.ResponseWriter, r *http.Request) { func Federation(w http.ResponseWriter, r *http.Request) {
federated, err := mustAuthority(r.Context()).GetFederation() federated, err := mustAuthority(r.Context()).GetFederation()
if err != nil { if err != nil {
render.Error(w, errs.ForbiddenErr(err, "error getting federated roots")) render.Error(w, r, errs.ForbiddenErr(err, "error getting federated roots"))
return return
} }
@ -473,7 +473,7 @@ func Federation(w http.ResponseWriter, r *http.Request) {
certs[i] = Certificate{federated[i]} certs[i] = Certificate{federated[i]}
} }
render.JSONStatus(w, &FederationResponse{ render.JSONStatus(w, r, &FederationResponse{
Certificates: certs, Certificates: certs,
}, http.StatusCreated) }, http.StatusCreated)
} }

@ -13,12 +13,12 @@ import (
func CRL(w http.ResponseWriter, r *http.Request) { func CRL(w http.ResponseWriter, r *http.Request) {
crlInfo, err := mustAuthority(r.Context()).GetCertificateRevocationList() crlInfo, err := mustAuthority(r.Context()).GetCertificateRevocationList()
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
if crlInfo == nil { if crlInfo == nil {
render.Error(w, errs.New(http.StatusNotFound, "no CRL available")) render.Error(w, r, errs.New(http.StatusNotFound, "no CRL available"))
return return
} }

@ -2,6 +2,7 @@
package log package log
import ( import (
"context"
"fmt" "fmt"
"net/http" "net/http"
"os" "os"
@ -9,6 +10,29 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
) )
type errorLoggerKey struct{}
// ErrorLogger is the function type used to log errors.
type ErrorLogger func(http.ResponseWriter, *http.Request, error)
func (fn ErrorLogger) call(w http.ResponseWriter, r *http.Request, err error) {
if fn == nil {
return
}
fn(w, r, err)
}
// WithErrorLogger returns a new context with the given error logger.
func WithErrorLogger(ctx context.Context, fn ErrorLogger) context.Context {
return context.WithValue(ctx, errorLoggerKey{}, fn)
}
// ErrorLoggerFromContext returns an error logger from the context.
func ErrorLoggerFromContext(ctx context.Context) (fn ErrorLogger) {
fn, _ = ctx.Value(errorLoggerKey{}).(ErrorLogger)
return
}
// StackTracedError is the set of errors implementing the StackTrace function. // StackTracedError is the set of errors implementing the StackTrace function.
// //
// Errors implementing this interface have their stack traces logged when passed // Errors implementing this interface have their stack traces logged when passed
@ -27,8 +51,10 @@ type fieldCarrier interface {
// Error adds to the response writer the given error if it implements // Error adds to the response writer the given error if it implements
// logging.ResponseLogger. If it does not implement it, then writes the error // logging.ResponseLogger. If it does not implement it, then writes the error
// using the log package. // using the log package.
func Error(rw http.ResponseWriter, err error) { func Error(w http.ResponseWriter, r *http.Request, err error) {
fc, ok := rw.(fieldCarrier) ErrorLoggerFromContext(r.Context()).call(w, r, err)
fc, ok := w.(fieldCarrier)
if !ok { if !ok {
return return
} }
@ -51,7 +77,7 @@ func Error(rw http.ResponseWriter, err error) {
// EnabledResponse log the response object if it implements the EnableLogger // EnabledResponse log the response object if it implements the EnableLogger
// interface. // interface.
func EnabledResponse(rw http.ResponseWriter, v any) { func EnabledResponse(rw http.ResponseWriter, r *http.Request, v any) {
type enableLogger interface { type enableLogger interface {
ToLog() (any, error) ToLog() (any, error)
} }
@ -59,7 +85,7 @@ func EnabledResponse(rw http.ResponseWriter, v any) {
if el, ok := v.(enableLogger); ok { if el, ok := v.(enableLogger); ok {
out, err := el.ToLog() out, err := el.ToLog()
if err != nil { if err != nil {
Error(rw, err) Error(rw, r, err)
return return
} }

@ -1,6 +1,9 @@
package log package log
import ( import (
"bytes"
"encoding/json"
"log/slog"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
@ -27,21 +30,34 @@ func (stackTracedError) StackTrace() pkgerrors.StackTrace {
} }
func TestError(t *testing.T) { func TestError(t *testing.T) {
var buf bytes.Buffer
logger := slog.New(slog.NewJSONHandler(&buf, &slog.HandlerOptions{}))
req := httptest.NewRequest("GET", "/test", http.NoBody)
reqWithLogger := req.WithContext(WithErrorLogger(req.Context(), func(w http.ResponseWriter, r *http.Request, err error) {
if err != nil {
logger.ErrorContext(r.Context(), "request failed", slog.Any("error", err))
}
}))
tests := []struct { tests := []struct {
name string name string
error error
rw http.ResponseWriter rw http.ResponseWriter
r *http.Request
isFieldCarrier bool isFieldCarrier bool
isSlogLogger bool
stepDebug bool stepDebug bool
expectStackTrace bool expectStackTrace bool
}{ }{
{"noLogger", nil, nil, false, false, false}, {"noLogger", nil, nil, req, false, false, false, false},
{"noError", nil, logging.NewResponseLogger(httptest.NewRecorder()), true, false, false}, {"noError", nil, logging.NewResponseLogger(httptest.NewRecorder()), req, true, false, false, false},
{"noErrorDebug", nil, logging.NewResponseLogger(httptest.NewRecorder()), true, true, false}, {"noErrorDebug", nil, logging.NewResponseLogger(httptest.NewRecorder()), req, true, false, true, false},
{"anError", assert.AnError, logging.NewResponseLogger(httptest.NewRecorder()), true, false, false}, {"anError", assert.AnError, logging.NewResponseLogger(httptest.NewRecorder()), req, true, false, false, false},
{"anErrorDebug", assert.AnError, logging.NewResponseLogger(httptest.NewRecorder()), true, true, false}, {"anErrorDebug", assert.AnError, logging.NewResponseLogger(httptest.NewRecorder()), req, true, false, true, false},
{"stackTracedError", new(stackTracedError), logging.NewResponseLogger(httptest.NewRecorder()), true, true, true}, {"stackTracedError", new(stackTracedError), logging.NewResponseLogger(httptest.NewRecorder()), req, true, false, true, true},
{"stackTracedErrorDebug", new(stackTracedError), logging.NewResponseLogger(httptest.NewRecorder()), true, true, true}, {"stackTracedErrorDebug", new(stackTracedError), logging.NewResponseLogger(httptest.NewRecorder()), req, true, false, true, true},
{"slogWithNoError", nil, logging.NewResponseLogger(httptest.NewRecorder()), reqWithLogger, true, true, false, false},
{"slogWithError", assert.AnError, logging.NewResponseLogger(httptest.NewRecorder()), reqWithLogger, true, true, false, false},
} }
for _, tt := range tests { for _, tt := range tests {
@ -52,27 +68,41 @@ func TestError(t *testing.T) {
t.Setenv("STEPDEBUG", "0") t.Setenv("STEPDEBUG", "0")
} }
Error(tt.rw, tt.error) Error(tt.rw, tt.r, tt.error)
// return early if test case doesn't use logger // return early if test case doesn't use logger
if !tt.isFieldCarrier { if !tt.isFieldCarrier && !tt.isSlogLogger {
return return
} }
fields := tt.rw.(logging.ResponseLogger).Fields() if tt.isFieldCarrier {
fields := tt.rw.(logging.ResponseLogger).Fields()
// expect the error field to be (not) set and to be the same error that was fed to Error // expect the error field to be (not) set and to be the same error that was fed to Error
if tt.error == nil { if tt.error == nil {
assert.Nil(t, fields["error"]) assert.Nil(t, fields["error"])
} else { } else {
assert.Same(t, tt.error, fields["error"]) assert.Same(t, tt.error, fields["error"])
}
// check if stack-trace is set when expected
if _, hasStackTrace := fields["stack-trace"]; tt.expectStackTrace && !hasStackTrace {
t.Error(`ResponseLogger["stack-trace"] not set`)
} else if !tt.expectStackTrace && hasStackTrace {
t.Error(`ResponseLogger["stack-trace"] was set`)
}
} }
// check if stack-trace is set when expected if tt.isSlogLogger {
if _, hasStackTrace := fields["stack-trace"]; tt.expectStackTrace && !hasStackTrace { b := buf.Bytes()
t.Error(`ResponseLogger["stack-trace"] not set`) if tt.error == nil {
} else if !tt.expectStackTrace && hasStackTrace { assert.Empty(t, b)
t.Error(`ResponseLogger["stack-trace"] was set`) } else if assert.NotEmpty(t, b) {
var m map[string]any
assert.NoError(t, json.Unmarshal(b, &m))
assert.Equal(t, tt.error.Error(), m["error"])
}
buf.Reset()
} }
}) })
} }

@ -51,7 +51,7 @@ func (e badProtoJSONError) Error() string {
} }
// Render implements render.RenderableError for badProtoJSONError // Render implements render.RenderableError for badProtoJSONError
func (e badProtoJSONError) Render(w http.ResponseWriter) { func (e badProtoJSONError) Render(w http.ResponseWriter, r *http.Request) {
v := struct { v := struct {
Type string `json:"type"` Type string `json:"type"`
Detail string `json:"detail"` Detail string `json:"detail"`
@ -62,5 +62,5 @@ func (e badProtoJSONError) Render(w http.ResponseWriter) {
// trim the proto prefix for the message // trim the proto prefix for the message
Message: strings.TrimSpace(strings.TrimPrefix(e.Error(), "proto:")), Message: strings.TrimSpace(strings.TrimPrefix(e.Error(), "proto:")),
} }
render.JSONStatus(w, v, http.StatusBadRequest) render.JSONStatus(w, r, v, http.StatusBadRequest)
} }

@ -142,7 +142,8 @@ func Test_badProtoJSONError_Render(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
tt.e.Render(w) r := httptest.NewRequest("POST", "/test", http.NoBody)
tt.e.Render(w, r)
res := w.Result() res := w.Result()
defer res.Body.Close() defer res.Body.Close()

@ -29,25 +29,25 @@ func (s *RekeyRequest) Validate() error {
// Rekey is similar to renew except that the certificate will be renewed with new key from csr. // Rekey is similar to renew except that the certificate will be renewed with new key from csr.
func Rekey(w http.ResponseWriter, r *http.Request) { func Rekey(w http.ResponseWriter, r *http.Request) {
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
render.Error(w, errs.BadRequest("missing client certificate")) render.Error(w, r, errs.BadRequest("missing client certificate"))
return return
} }
var body RekeyRequest var body RekeyRequest
if err := read.JSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
render.Error(w, errs.BadRequestErr(err, "error reading request body")) render.Error(w, r, errs.BadRequestErr(err, "error reading request body"))
return return
} }
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
a := mustAuthority(r.Context()) a := mustAuthority(r.Context())
certChain, err := a.Rekey(r.TLS.PeerCertificates[0], body.CsrPEM.CertificateRequest.PublicKey) certChain, err := a.Rekey(r.TLS.PeerCertificates[0], body.CsrPEM.CertificateRequest.PublicKey)
if err != nil { if err != nil {
render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Rekey")) render.Error(w, r, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Rekey"))
return return
} }
certChainPEM := certChainToPEM(certChain) certChainPEM := certChainToPEM(certChain)
@ -57,7 +57,7 @@ func Rekey(w http.ResponseWriter, r *http.Request) {
} }
LogCertificate(w, certChain[0]) LogCertificate(w, certChain[0])
render.JSONStatus(w, &SignResponse{ render.JSONStatus(w, r, &SignResponse{
ServerPEM: certChainPEM[0], ServerPEM: certChainPEM[0],
CaPEM: caPEM, CaPEM: caPEM,
CertChainPEM: certChainPEM, CertChainPEM: certChainPEM,

@ -13,8 +13,8 @@ import (
) )
// JSON is shorthand for JSONStatus(w, v, http.StatusOK). // JSON is shorthand for JSONStatus(w, v, http.StatusOK).
func JSON(w http.ResponseWriter, v interface{}) { func JSON(w http.ResponseWriter, r *http.Request, v interface{}) {
JSONStatus(w, v, http.StatusOK) JSONStatus(w, r, v, http.StatusOK)
} }
// JSONStatus marshals v into w. It additionally sets the status code of // JSONStatus marshals v into w. It additionally sets the status code of
@ -22,7 +22,7 @@ func JSON(w http.ResponseWriter, v interface{}) {
// //
// JSONStatus sets the Content-Type of w to application/json unless one is // JSONStatus sets the Content-Type of w to application/json unless one is
// specified. // specified.
func JSONStatus(w http.ResponseWriter, v interface{}, status int) { func JSONStatus(w http.ResponseWriter, r *http.Request, v interface{}, status int) {
setContentTypeUnlessPresent(w, "application/json") setContentTypeUnlessPresent(w, "application/json")
w.WriteHeader(status) w.WriteHeader(status)
@ -43,7 +43,7 @@ func JSONStatus(w http.ResponseWriter, v interface{}, status int) {
} }
} }
log.EnabledResponse(w, v) log.EnabledResponse(w, r, v)
} }
// ProtoJSON is shorthand for ProtoJSONStatus(w, m, http.StatusOK). // ProtoJSON is shorthand for ProtoJSONStatus(w, m, http.StatusOK).
@ -80,22 +80,22 @@ func setContentTypeUnlessPresent(w http.ResponseWriter, contentType string) {
type RenderableError interface { type RenderableError interface {
error error
Render(http.ResponseWriter) Render(http.ResponseWriter, *http.Request)
} }
// Error marshals the JSON representation of err to w. In case err implements // Error marshals the JSON representation of err to w. In case err implements
// RenderableError its own Render method will be called instead. // RenderableError its own Render method will be called instead.
func Error(w http.ResponseWriter, err error) { func Error(rw http.ResponseWriter, r *http.Request, err error) {
log.Error(w, err) log.Error(rw, r, err)
var r RenderableError var re RenderableError
if errors.As(err, &r) { if errors.As(err, &re) {
r.Render(w) re.Render(rw, r)
return return
} }
JSONStatus(w, err, statusCodeFromError(err)) JSONStatus(rw, r, err, statusCodeFromError(err))
} }
// StatusCodedError is the set of errors that implement the basic StatusCode // StatusCodedError is the set of errors that implement the basic StatusCode

@ -18,8 +18,8 @@ import (
func TestJSON(t *testing.T) { func TestJSON(t *testing.T) {
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
rw := logging.NewResponseLogger(rec) rw := logging.NewResponseLogger(rec)
r := httptest.NewRequest("POST", "/test", http.NoBody)
JSON(rw, map[string]interface{}{"foo": "bar"}) JSON(rw, r, map[string]interface{}{"foo": "bar"})
assert.Equal(t, http.StatusOK, rec.Result().StatusCode) assert.Equal(t, http.StatusOK, rec.Result().StatusCode)
assert.Equal(t, "application/json", rec.Header().Get("Content-Type")) assert.Equal(t, "application/json", rec.Header().Get("Content-Type"))
@ -64,7 +64,8 @@ func jsonPanicTest[T json.UnsupportedTypeError | json.UnsupportedValueError | js
assert.ErrorAs(t, err, &e) assert.ErrorAs(t, err, &e)
}() }()
JSON(httptest.NewRecorder(), v) r := httptest.NewRequest("POST", "/test", http.NoBody)
JSON(httptest.NewRecorder(), r, v)
} }
type renderableError struct { type renderableError struct {
@ -76,10 +77,9 @@ func (err renderableError) Error() string {
return err.Message return err.Message
} }
func (err renderableError) Render(w http.ResponseWriter) { func (err renderableError) Render(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "something/custom") w.Header().Set("Content-Type", "something/custom")
JSONStatus(w, r, err, err.Code)
JSONStatus(w, err, err.Code)
} }
type statusedError struct { type statusedError struct {
@ -116,8 +116,8 @@ func TestError(t *testing.T) {
t.Run(strconv.Itoa(caseIndex), func(t *testing.T) { t.Run(strconv.Itoa(caseIndex), func(t *testing.T) {
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
r := httptest.NewRequest("POST", "/test", http.NoBody)
Error(rec, kase.err) Error(rec, r, kase.err)
assert.Equal(t, kase.code, rec.Result().StatusCode) assert.Equal(t, kase.code, rec.Result().StatusCode)
assert.Equal(t, kase.body, rec.Body.String()) assert.Equal(t, kase.body, rec.Body.String())

@ -23,19 +23,20 @@ func Renew(w http.ResponseWriter, r *http.Request) {
// Get the leaf certificate from the peer or the token. // Get the leaf certificate from the peer or the token.
cert, token, err := getPeerCertificate(r) cert, token, err := getPeerCertificate(r)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
// The token can be used by RAs to renew a certificate. // The token can be used by RAs to renew a certificate.
if token != "" { if token != "" {
ctx = authority.NewTokenContext(ctx, token) ctx = authority.NewTokenContext(ctx, token)
logOtt(w, token)
} }
a := mustAuthority(ctx) a := mustAuthority(ctx)
certChain, err := a.RenewContext(ctx, cert, nil) certChain, err := a.RenewContext(ctx, cert, nil)
if err != nil { if err != nil {
render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew")) render.Error(w, r, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew"))
return return
} }
certChainPEM := certChainToPEM(certChain) certChainPEM := certChainToPEM(certChain)
@ -45,7 +46,7 @@ func Renew(w http.ResponseWriter, r *http.Request) {
} }
LogCertificate(w, certChain[0]) LogCertificate(w, certChain[0])
render.JSONStatus(w, &SignResponse{ render.JSONStatus(w, r, &SignResponse{
ServerPEM: certChainPEM[0], ServerPEM: certChainPEM[0],
CaPEM: caPEM, CaPEM: caPEM,
CertChainPEM: certChainPEM, CertChainPEM: certChainPEM,

@ -57,12 +57,12 @@ func (r *RevokeRequest) Validate() (err error) {
func Revoke(w http.ResponseWriter, r *http.Request) { func Revoke(w http.ResponseWriter, r *http.Request) {
var body RevokeRequest var body RevokeRequest
if err := read.JSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
render.Error(w, errs.BadRequestErr(err, "error reading request body")) render.Error(w, r, errs.BadRequestErr(err, "error reading request body"))
return return
} }
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
@ -81,7 +81,7 @@ func Revoke(w http.ResponseWriter, r *http.Request) {
if body.OTT != "" { if body.OTT != "" {
logOtt(w, body.OTT) logOtt(w, body.OTT)
if _, err := a.Authorize(ctx, body.OTT); err != nil { if _, err := a.Authorize(ctx, body.OTT); err != nil {
render.Error(w, errs.UnauthorizedErr(err)) render.Error(w, r, errs.UnauthorizedErr(err))
return return
} }
opts.OTT = body.OTT opts.OTT = body.OTT
@ -90,12 +90,12 @@ func Revoke(w http.ResponseWriter, r *http.Request) {
// the client certificate Serial Number must match the serial number // the client certificate Serial Number must match the serial number
// being revoked. // being revoked.
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
render.Error(w, errs.BadRequest("missing ott or client certificate")) render.Error(w, r, errs.BadRequest("missing ott or client certificate"))
return return
} }
opts.Crt = r.TLS.PeerCertificates[0] opts.Crt = r.TLS.PeerCertificates[0]
if opts.Crt.SerialNumber.String() != opts.Serial { if opts.Crt.SerialNumber.String() != opts.Serial {
render.Error(w, errs.BadRequest("serial number in client certificate different than body")) render.Error(w, r, errs.BadRequest("serial number in client certificate different than body"))
return return
} }
// TODO: should probably be checking if the certificate was revoked here. // TODO: should probably be checking if the certificate was revoked here.
@ -106,12 +106,12 @@ func Revoke(w http.ResponseWriter, r *http.Request) {
} }
if err := a.Revoke(ctx, opts); err != nil { if err := a.Revoke(ctx, opts); err != nil {
render.Error(w, errs.ForbiddenErr(err, "error revoking certificate")) render.Error(w, r, errs.ForbiddenErr(err, "error revoking certificate"))
return return
} }
logRevoke(w, opts) logRevoke(w, opts)
render.JSON(w, &RevokeResponse{Status: "ok"}) render.JSON(w, r, &RevokeResponse{Status: "ok"})
} }
func logRevoke(w http.ResponseWriter, ri *authority.RevokeOptions) { func logRevoke(w http.ResponseWriter, ri *authority.RevokeOptions) {

@ -52,13 +52,13 @@ type SignResponse struct {
func Sign(w http.ResponseWriter, r *http.Request) { func Sign(w http.ResponseWriter, r *http.Request) {
var body SignRequest var body SignRequest
if err := read.JSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
render.Error(w, errs.BadRequestErr(err, "error reading request body")) render.Error(w, r, errs.BadRequestErr(err, "error reading request body"))
return return
} }
logOtt(w, body.OTT) logOtt(w, body.OTT)
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
@ -74,13 +74,13 @@ func Sign(w http.ResponseWriter, r *http.Request) {
ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod) ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod)
signOpts, err := a.Authorize(ctx, body.OTT) signOpts, err := a.Authorize(ctx, body.OTT)
if err != nil { if err != nil {
render.Error(w, errs.UnauthorizedErr(err)) render.Error(w, r, errs.UnauthorizedErr(err))
return return
} }
certChain, err := a.SignWithContext(ctx, body.CsrPEM.CertificateRequest, opts, signOpts...) certChain, err := a.SignWithContext(ctx, body.CsrPEM.CertificateRequest, opts, signOpts...)
if err != nil { if err != nil {
render.Error(w, errs.ForbiddenErr(err, "error signing certificate")) render.Error(w, r, errs.ForbiddenErr(err, "error signing certificate"))
return return
} }
certChainPEM := certChainToPEM(certChain) certChainPEM := certChainToPEM(certChain)
@ -90,7 +90,7 @@ func Sign(w http.ResponseWriter, r *http.Request) {
} }
LogCertificate(w, certChain[0]) LogCertificate(w, certChain[0])
render.JSONStatus(w, &SignResponse{ render.JSONStatus(w, r, &SignResponse{
ServerPEM: certChainPEM[0], ServerPEM: certChainPEM[0],
CaPEM: caPEM, CaPEM: caPEM,
CertChainPEM: certChainPEM, CertChainPEM: certChainPEM,

@ -253,19 +253,19 @@ type SSHBastionResponse struct {
func SSHSign(w http.ResponseWriter, r *http.Request) { func SSHSign(w http.ResponseWriter, r *http.Request) {
var body SSHSignRequest var body SSHSignRequest
if err := read.JSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
render.Error(w, errs.BadRequestErr(err, "error reading request body")) render.Error(w, r, errs.BadRequestErr(err, "error reading request body"))
return return
} }
logOtt(w, body.OTT) logOtt(w, body.OTT)
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
publicKey, err := ssh.ParsePublicKey(body.PublicKey) publicKey, err := ssh.ParsePublicKey(body.PublicKey)
if err != nil { if err != nil {
render.Error(w, errs.BadRequestErr(err, "error parsing publicKey")) render.Error(w, r, errs.BadRequestErr(err, "error parsing publicKey"))
return return
} }
@ -273,7 +273,7 @@ func SSHSign(w http.ResponseWriter, r *http.Request) {
if body.AddUserPublicKey != nil { if body.AddUserPublicKey != nil {
addUserPublicKey, err = ssh.ParsePublicKey(body.AddUserPublicKey) addUserPublicKey, err = ssh.ParsePublicKey(body.AddUserPublicKey)
if err != nil { if err != nil {
render.Error(w, errs.BadRequestErr(err, "error parsing addUserPublicKey")) render.Error(w, r, errs.BadRequestErr(err, "error parsing addUserPublicKey"))
return return
} }
} }
@ -293,13 +293,13 @@ func SSHSign(w http.ResponseWriter, r *http.Request) {
a := mustAuthority(ctx) a := mustAuthority(ctx)
signOpts, err := a.Authorize(ctx, body.OTT) signOpts, err := a.Authorize(ctx, body.OTT)
if err != nil { if err != nil {
render.Error(w, errs.UnauthorizedErr(err)) render.Error(w, r, errs.UnauthorizedErr(err))
return return
} }
cert, err := a.SignSSH(ctx, publicKey, opts, signOpts...) cert, err := a.SignSSH(ctx, publicKey, opts, signOpts...)
if err != nil { if err != nil {
render.Error(w, errs.ForbiddenErr(err, "error signing ssh certificate")) render.Error(w, r, errs.ForbiddenErr(err, "error signing ssh certificate"))
return return
} }
@ -307,7 +307,7 @@ func SSHSign(w http.ResponseWriter, r *http.Request) {
if addUserPublicKey != nil && authority.IsValidForAddUser(cert) == nil { if addUserPublicKey != nil && authority.IsValidForAddUser(cert) == nil {
addUserCert, err := a.SignSSHAddUser(ctx, addUserPublicKey, cert) addUserCert, err := a.SignSSHAddUser(ctx, addUserPublicKey, cert)
if err != nil { if err != nil {
render.Error(w, errs.ForbiddenErr(err, "error signing ssh certificate")) render.Error(w, r, errs.ForbiddenErr(err, "error signing ssh certificate"))
return return
} }
addUserCertificate = &SSHCertificate{addUserCert} addUserCertificate = &SSHCertificate{addUserCert}
@ -320,7 +320,7 @@ func SSHSign(w http.ResponseWriter, r *http.Request) {
ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignIdentityMethod) ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignIdentityMethod)
signOpts, err := a.Authorize(ctx, body.OTT) signOpts, err := a.Authorize(ctx, body.OTT)
if err != nil { if err != nil {
render.Error(w, errs.UnauthorizedErr(err)) render.Error(w, r, errs.UnauthorizedErr(err))
return return
} }
@ -332,14 +332,14 @@ func SSHSign(w http.ResponseWriter, r *http.Request) {
certChain, err := a.SignWithContext(ctx, cr, provisioner.SignOptions{}, signOpts...) certChain, err := a.SignWithContext(ctx, cr, provisioner.SignOptions{}, signOpts...)
if err != nil { if err != nil {
render.Error(w, errs.ForbiddenErr(err, "error signing identity certificate")) render.Error(w, r, errs.ForbiddenErr(err, "error signing identity certificate"))
return return
} }
identityCertificate = certChainToPEM(certChain) identityCertificate = certChainToPEM(certChain)
} }
LogSSHCertificate(w, cert) LogSSHCertificate(w, cert)
render.JSONStatus(w, &SSHSignResponse{ render.JSONStatus(w, r, &SSHSignResponse{
Certificate: SSHCertificate{cert}, Certificate: SSHCertificate{cert},
AddUserCertificate: addUserCertificate, AddUserCertificate: addUserCertificate,
IdentityCertificate: identityCertificate, IdentityCertificate: identityCertificate,
@ -352,12 +352,12 @@ func SSHRoots(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
keys, err := mustAuthority(ctx).GetSSHRoots(ctx) keys, err := mustAuthority(ctx).GetSSHRoots(ctx)
if err != nil { if err != nil {
render.Error(w, errs.InternalServerErr(err)) render.Error(w, r, errs.InternalServerErr(err))
return return
} }
if len(keys.HostKeys) == 0 && len(keys.UserKeys) == 0 { if len(keys.HostKeys) == 0 && len(keys.UserKeys) == 0 {
render.Error(w, errs.NotFound("no keys found")) render.Error(w, r, errs.NotFound("no keys found"))
return return
} }
@ -369,7 +369,7 @@ func SSHRoots(w http.ResponseWriter, r *http.Request) {
resp.UserKeys = append(resp.UserKeys, SSHPublicKey{PublicKey: k}) resp.UserKeys = append(resp.UserKeys, SSHPublicKey{PublicKey: k})
} }
render.JSON(w, resp) render.JSON(w, r, resp)
} }
// SSHFederation is an HTTP handler that returns the federated SSH public keys // SSHFederation is an HTTP handler that returns the federated SSH public keys
@ -378,12 +378,12 @@ func SSHFederation(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
keys, err := mustAuthority(ctx).GetSSHFederation(ctx) keys, err := mustAuthority(ctx).GetSSHFederation(ctx)
if err != nil { if err != nil {
render.Error(w, errs.InternalServerErr(err)) render.Error(w, r, errs.InternalServerErr(err))
return return
} }
if len(keys.HostKeys) == 0 && len(keys.UserKeys) == 0 { if len(keys.HostKeys) == 0 && len(keys.UserKeys) == 0 {
render.Error(w, errs.NotFound("no keys found")) render.Error(w, r, errs.NotFound("no keys found"))
return return
} }
@ -395,7 +395,7 @@ func SSHFederation(w http.ResponseWriter, r *http.Request) {
resp.UserKeys = append(resp.UserKeys, SSHPublicKey{PublicKey: k}) resp.UserKeys = append(resp.UserKeys, SSHPublicKey{PublicKey: k})
} }
render.JSON(w, resp) render.JSON(w, r, resp)
} }
// SSHConfig is an HTTP handler that returns rendered templates for ssh clients // SSHConfig is an HTTP handler that returns rendered templates for ssh clients
@ -403,18 +403,18 @@ func SSHFederation(w http.ResponseWriter, r *http.Request) {
func SSHConfig(w http.ResponseWriter, r *http.Request) { func SSHConfig(w http.ResponseWriter, r *http.Request) {
var body SSHConfigRequest var body SSHConfigRequest
if err := read.JSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
render.Error(w, errs.BadRequestErr(err, "error reading request body")) render.Error(w, r, errs.BadRequestErr(err, "error reading request body"))
return return
} }
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
ctx := r.Context() ctx := r.Context()
ts, err := mustAuthority(ctx).GetSSHConfig(ctx, body.Type, body.Data) ts, err := mustAuthority(ctx).GetSSHConfig(ctx, body.Type, body.Data)
if err != nil { if err != nil {
render.Error(w, errs.InternalServerErr(err)) render.Error(w, r, errs.InternalServerErr(err))
return return
} }
@ -425,32 +425,32 @@ func SSHConfig(w http.ResponseWriter, r *http.Request) {
case provisioner.SSHHostCert: case provisioner.SSHHostCert:
cfg.HostTemplates = ts cfg.HostTemplates = ts
default: default:
render.Error(w, errs.InternalServer("it should hot get here")) render.Error(w, r, errs.InternalServer("it should hot get here"))
return return
} }
render.JSON(w, cfg) render.JSON(w, r, cfg)
} }
// SSHCheckHost is the HTTP handler that returns if a hosts certificate exists or not. // SSHCheckHost is the HTTP handler that returns if a hosts certificate exists or not.
func SSHCheckHost(w http.ResponseWriter, r *http.Request) { func SSHCheckHost(w http.ResponseWriter, r *http.Request) {
var body SSHCheckPrincipalRequest var body SSHCheckPrincipalRequest
if err := read.JSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
render.Error(w, errs.BadRequestErr(err, "error reading request body")) render.Error(w, r, errs.BadRequestErr(err, "error reading request body"))
return return
} }
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
ctx := r.Context() ctx := r.Context()
exists, err := mustAuthority(ctx).CheckSSHHost(ctx, body.Principal, body.Token) exists, err := mustAuthority(ctx).CheckSSHHost(ctx, body.Principal, body.Token)
if err != nil { if err != nil {
render.Error(w, errs.InternalServerErr(err)) render.Error(w, r, errs.InternalServerErr(err))
return return
} }
render.JSON(w, &SSHCheckPrincipalResponse{ render.JSON(w, r, &SSHCheckPrincipalResponse{
Exists: exists, Exists: exists,
}) })
} }
@ -465,10 +465,10 @@ func SSHGetHosts(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
hosts, err := mustAuthority(ctx).GetSSHHosts(ctx, cert) hosts, err := mustAuthority(ctx).GetSSHHosts(ctx, cert)
if err != nil { if err != nil {
render.Error(w, errs.InternalServerErr(err)) render.Error(w, r, errs.InternalServerErr(err))
return return
} }
render.JSON(w, &SSHGetHostsResponse{ render.JSON(w, r, &SSHGetHostsResponse{
Hosts: hosts, Hosts: hosts,
}) })
} }
@ -477,22 +477,22 @@ func SSHGetHosts(w http.ResponseWriter, r *http.Request) {
func SSHBastion(w http.ResponseWriter, r *http.Request) { func SSHBastion(w http.ResponseWriter, r *http.Request) {
var body SSHBastionRequest var body SSHBastionRequest
if err := read.JSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
render.Error(w, errs.BadRequestErr(err, "error reading request body")) render.Error(w, r, errs.BadRequestErr(err, "error reading request body"))
return return
} }
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
ctx := r.Context() ctx := r.Context()
bastion, err := mustAuthority(ctx).GetSSHBastion(ctx, body.User, body.Hostname) bastion, err := mustAuthority(ctx).GetSSHBastion(ctx, body.User, body.Hostname)
if err != nil { if err != nil {
render.Error(w, errs.InternalServerErr(err)) render.Error(w, r, errs.InternalServerErr(err))
return return
} }
render.JSON(w, &SSHBastionResponse{ render.JSON(w, r, &SSHBastionResponse{
Hostname: body.Hostname, Hostname: body.Hostname,
Bastion: bastion, Bastion: bastion,
}) })

@ -42,19 +42,19 @@ type SSHRekeyResponse struct {
func SSHRekey(w http.ResponseWriter, r *http.Request) { func SSHRekey(w http.ResponseWriter, r *http.Request) {
var body SSHRekeyRequest var body SSHRekeyRequest
if err := read.JSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
render.Error(w, errs.BadRequestErr(err, "error reading request body")) render.Error(w, r, errs.BadRequestErr(err, "error reading request body"))
return return
} }
logOtt(w, body.OTT) logOtt(w, body.OTT)
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
publicKey, err := ssh.ParsePublicKey(body.PublicKey) publicKey, err := ssh.ParsePublicKey(body.PublicKey)
if err != nil { if err != nil {
render.Error(w, errs.BadRequestErr(err, "error parsing publicKey")) render.Error(w, r, errs.BadRequestErr(err, "error parsing publicKey"))
return return
} }
@ -64,18 +64,18 @@ func SSHRekey(w http.ResponseWriter, r *http.Request) {
a := mustAuthority(ctx) a := mustAuthority(ctx)
signOpts, err := a.Authorize(ctx, body.OTT) signOpts, err := a.Authorize(ctx, body.OTT)
if err != nil { if err != nil {
render.Error(w, errs.UnauthorizedErr(err)) render.Error(w, r, errs.UnauthorizedErr(err))
return return
} }
oldCert, _, err := provisioner.ExtractSSHPOPCert(body.OTT) oldCert, _, err := provisioner.ExtractSSHPOPCert(body.OTT)
if err != nil { if err != nil {
render.Error(w, errs.InternalServerErr(err)) render.Error(w, r, errs.InternalServerErr(err))
return return
} }
newCert, err := a.RekeySSH(ctx, oldCert, publicKey, signOpts...) newCert, err := a.RekeySSH(ctx, oldCert, publicKey, signOpts...)
if err != nil { if err != nil {
render.Error(w, errs.ForbiddenErr(err, "error rekeying ssh certificate")) render.Error(w, r, errs.ForbiddenErr(err, "error rekeying ssh certificate"))
return return
} }
@ -85,12 +85,12 @@ func SSHRekey(w http.ResponseWriter, r *http.Request) {
identity, err := renewIdentityCertificate(r, notBefore, notAfter) identity, err := renewIdentityCertificate(r, notBefore, notAfter)
if err != nil { if err != nil {
render.Error(w, errs.ForbiddenErr(err, "error renewing identity certificate")) render.Error(w, r, errs.ForbiddenErr(err, "error renewing identity certificate"))
return return
} }
LogSSHCertificate(w, newCert) LogSSHCertificate(w, newCert)
render.JSONStatus(w, &SSHRekeyResponse{ render.JSONStatus(w, r, &SSHRekeyResponse{
Certificate: SSHCertificate{newCert}, Certificate: SSHCertificate{newCert},
IdentityCertificate: identity, IdentityCertificate: identity,
}, http.StatusCreated) }, http.StatusCreated)

@ -40,13 +40,13 @@ type SSHRenewResponse struct {
func SSHRenew(w http.ResponseWriter, r *http.Request) { func SSHRenew(w http.ResponseWriter, r *http.Request) {
var body SSHRenewRequest var body SSHRenewRequest
if err := read.JSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
render.Error(w, errs.BadRequestErr(err, "error reading request body")) render.Error(w, r, errs.BadRequestErr(err, "error reading request body"))
return return
} }
logOtt(w, body.OTT) logOtt(w, body.OTT)
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
@ -56,18 +56,18 @@ func SSHRenew(w http.ResponseWriter, r *http.Request) {
a := mustAuthority(ctx) a := mustAuthority(ctx)
_, err := a.Authorize(ctx, body.OTT) _, err := a.Authorize(ctx, body.OTT)
if err != nil { if err != nil {
render.Error(w, errs.UnauthorizedErr(err)) render.Error(w, r, errs.UnauthorizedErr(err))
return return
} }
oldCert, _, err := provisioner.ExtractSSHPOPCert(body.OTT) oldCert, _, err := provisioner.ExtractSSHPOPCert(body.OTT)
if err != nil { if err != nil {
render.Error(w, errs.InternalServerErr(err)) render.Error(w, r, errs.InternalServerErr(err))
return return
} }
newCert, err := a.RenewSSH(ctx, oldCert) newCert, err := a.RenewSSH(ctx, oldCert)
if err != nil { if err != nil {
render.Error(w, errs.ForbiddenErr(err, "error renewing ssh certificate")) render.Error(w, r, errs.ForbiddenErr(err, "error renewing ssh certificate"))
return return
} }
@ -77,12 +77,12 @@ func SSHRenew(w http.ResponseWriter, r *http.Request) {
identity, err := renewIdentityCertificate(r, notBefore, notAfter) identity, err := renewIdentityCertificate(r, notBefore, notAfter)
if err != nil { if err != nil {
render.Error(w, errs.ForbiddenErr(err, "error renewing identity certificate")) render.Error(w, r, errs.ForbiddenErr(err, "error renewing identity certificate"))
return return
} }
LogSSHCertificate(w, newCert) LogSSHCertificate(w, newCert)
render.JSONStatus(w, &SSHSignResponse{ render.JSONStatus(w, r, &SSHSignResponse{
Certificate: SSHCertificate{newCert}, Certificate: SSHCertificate{newCert},
IdentityCertificate: identity, IdentityCertificate: identity,
}, http.StatusCreated) }, http.StatusCreated)

@ -51,12 +51,12 @@ func (r *SSHRevokeRequest) Validate() (err error) {
func SSHRevoke(w http.ResponseWriter, r *http.Request) { func SSHRevoke(w http.ResponseWriter, r *http.Request) {
var body SSHRevokeRequest var body SSHRevokeRequest
if err := read.JSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
render.Error(w, errs.BadRequestErr(err, "error reading request body")) render.Error(w, r, errs.BadRequestErr(err, "error reading request body"))
return return
} }
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
@ -75,18 +75,18 @@ func SSHRevoke(w http.ResponseWriter, r *http.Request) {
logOtt(w, body.OTT) logOtt(w, body.OTT)
if _, err := a.Authorize(ctx, body.OTT); err != nil { if _, err := a.Authorize(ctx, body.OTT); err != nil {
render.Error(w, errs.UnauthorizedErr(err)) render.Error(w, r, errs.UnauthorizedErr(err))
return return
} }
opts.OTT = body.OTT opts.OTT = body.OTT
if err := a.Revoke(ctx, opts); err != nil { if err := a.Revoke(ctx, opts); err != nil {
render.Error(w, errs.ForbiddenErr(err, "error revoking ssh certificate")) render.Error(w, r, errs.ForbiddenErr(err, "error revoking ssh certificate"))
return return
} }
logSSHRevoke(w, opts) logSSHRevoke(w, opts)
render.JSON(w, &SSHRevokeResponse{Status: "ok"}) render.JSON(w, r, &SSHRevokeResponse{Status: "ok"})
} }
func logSSHRevoke(w http.ResponseWriter, ri *authority.RevokeOptions) { func logSSHRevoke(w http.ResponseWriter, ri *authority.RevokeOptions) {

@ -40,12 +40,12 @@ func requireEABEnabled(next http.HandlerFunc) http.HandlerFunc {
acmeProvisioner := prov.GetDetails().GetACME() acmeProvisioner := prov.GetDetails().GetACME()
if acmeProvisioner == nil { if acmeProvisioner == nil {
render.Error(w, admin.NewErrorISE("error getting ACME details for provisioner '%s'", prov.GetName())) render.Error(w, r, admin.NewErrorISE("error getting ACME details for provisioner '%s'", prov.GetName()))
return return
} }
if !acmeProvisioner.RequireEab { if !acmeProvisioner.RequireEab {
render.Error(w, admin.NewError(admin.ErrorBadRequestType, "ACME EAB not enabled for provisioner '%s'", prov.GetName())) render.Error(w, r, admin.NewError(admin.ErrorBadRequestType, "ACME EAB not enabled for provisioner '%s'", prov.GetName()))
return return
} }
@ -69,18 +69,18 @@ func NewACMEAdminResponder() ACMEAdminResponder {
} }
// GetExternalAccountKeys writes the response for the EAB keys GET endpoint // GetExternalAccountKeys writes the response for the EAB keys GET endpoint
func (h *acmeAdminResponder) GetExternalAccountKeys(w http.ResponseWriter, _ *http.Request) { func (h *acmeAdminResponder) GetExternalAccountKeys(w http.ResponseWriter, r *http.Request) {
render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) render.Error(w, r, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
} }
// CreateExternalAccountKey writes the response for the EAB key POST endpoint // CreateExternalAccountKey writes the response for the EAB key POST endpoint
func (h *acmeAdminResponder) CreateExternalAccountKey(w http.ResponseWriter, _ *http.Request) { func (h *acmeAdminResponder) CreateExternalAccountKey(w http.ResponseWriter, r *http.Request) {
render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) render.Error(w, r, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
} }
// DeleteExternalAccountKey writes the response for the EAB key DELETE endpoint // DeleteExternalAccountKey writes the response for the EAB key DELETE endpoint
func (h *acmeAdminResponder) DeleteExternalAccountKey(w http.ResponseWriter, _ *http.Request) { func (h *acmeAdminResponder) DeleteExternalAccountKey(w http.ResponseWriter, r *http.Request) {
render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) render.Error(w, r, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
} }
func eakToLinked(k *acme.ExternalAccountKey) *linkedca.EABKey { func eakToLinked(k *acme.ExternalAccountKey) *linkedca.EABKey {

@ -90,7 +90,7 @@ func GetAdmin(w http.ResponseWriter, r *http.Request) {
adm, ok := mustAuthority(r.Context()).LoadAdminByID(id) adm, ok := mustAuthority(r.Context()).LoadAdminByID(id)
if !ok { if !ok {
render.Error(w, admin.NewError(admin.ErrorNotFoundType, render.Error(w, r, admin.NewError(admin.ErrorNotFoundType,
"admin %s not found", id)) "admin %s not found", id))
return return
} }
@ -101,17 +101,17 @@ func GetAdmin(w http.ResponseWriter, r *http.Request) {
func GetAdmins(w http.ResponseWriter, r *http.Request) { func GetAdmins(w http.ResponseWriter, r *http.Request) {
cursor, limit, err := api.ParseCursor(r) cursor, limit, err := api.ParseCursor(r)
if err != nil { if err != nil {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err,
"error parsing cursor and limit from query params")) "error parsing cursor and limit from query params"))
return return
} }
admins, nextCursor, err := mustAuthority(r.Context()).GetAdmins(cursor, limit) admins, nextCursor, err := mustAuthority(r.Context()).GetAdmins(cursor, limit)
if err != nil { if err != nil {
render.Error(w, admin.WrapErrorISE(err, "error retrieving paginated admins")) render.Error(w, r, admin.WrapErrorISE(err, "error retrieving paginated admins"))
return return
} }
render.JSON(w, &GetAdminsResponse{ render.JSON(w, r, &GetAdminsResponse{
Admins: admins, Admins: admins,
NextCursor: nextCursor, NextCursor: nextCursor,
}) })
@ -121,19 +121,19 @@ func GetAdmins(w http.ResponseWriter, r *http.Request) {
func CreateAdmin(w http.ResponseWriter, r *http.Request) { func CreateAdmin(w http.ResponseWriter, r *http.Request) {
var body CreateAdminRequest var body CreateAdminRequest
if err := read.JSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body")) render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body"))
return return
} }
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
auth := mustAuthority(r.Context()) auth := mustAuthority(r.Context())
p, err := auth.LoadProvisionerByName(body.Provisioner) p, err := auth.LoadProvisionerByName(body.Provisioner)
if err != nil { if err != nil {
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", body.Provisioner)) render.Error(w, r, admin.WrapErrorISE(err, "error loading provisioner %s", body.Provisioner))
return return
} }
adm := &linkedca.Admin{ adm := &linkedca.Admin{
@ -143,7 +143,7 @@ func CreateAdmin(w http.ResponseWriter, r *http.Request) {
} }
// Store to authority collection. // Store to authority collection.
if err := auth.StoreAdmin(r.Context(), adm, p); err != nil { if err := auth.StoreAdmin(r.Context(), adm, p); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error storing admin")) render.Error(w, r, admin.WrapErrorISE(err, "error storing admin"))
return return
} }
@ -155,23 +155,23 @@ func DeleteAdmin(w http.ResponseWriter, r *http.Request) {
id := chi.URLParam(r, "id") id := chi.URLParam(r, "id")
if err := mustAuthority(r.Context()).RemoveAdmin(r.Context(), id); err != nil { if err := mustAuthority(r.Context()).RemoveAdmin(r.Context(), id); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error deleting admin %s", id)) render.Error(w, r, admin.WrapErrorISE(err, "error deleting admin %s", id))
return return
} }
render.JSON(w, &DeleteResponse{Status: "ok"}) render.JSON(w, r, &DeleteResponse{Status: "ok"})
} }
// UpdateAdmin updates an existing admin. // UpdateAdmin updates an existing admin.
func UpdateAdmin(w http.ResponseWriter, r *http.Request) { func UpdateAdmin(w http.ResponseWriter, r *http.Request) {
var body UpdateAdminRequest var body UpdateAdminRequest
if err := read.JSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body")) render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body"))
return return
} }
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
@ -179,7 +179,7 @@ func UpdateAdmin(w http.ResponseWriter, r *http.Request) {
auth := mustAuthority(r.Context()) auth := mustAuthority(r.Context())
adm, err := auth.UpdateAdmin(r.Context(), id, &linkedca.Admin{Type: body.Type}) adm, err := auth.UpdateAdmin(r.Context(), id, &linkedca.Admin{Type: body.Type})
if err != nil { if err != nil {
render.Error(w, admin.WrapErrorISE(err, "error updating admin %s", id)) render.Error(w, r, admin.WrapErrorISE(err, "error updating admin %s", id))
return return
} }

@ -19,7 +19,7 @@ import (
func requireAPIEnabled(next http.HandlerFunc) http.HandlerFunc { func requireAPIEnabled(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
if !mustAuthority(r.Context()).IsAdminAPIEnabled() { if !mustAuthority(r.Context()).IsAdminAPIEnabled() {
render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "administration API not enabled")) render.Error(w, r, admin.NewError(admin.ErrorNotImplementedType, "administration API not enabled"))
return return
} }
next(w, r) next(w, r)
@ -31,7 +31,7 @@ func extractAuthorizeTokenAdmin(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
tok := r.Header.Get("Authorization") tok := r.Header.Get("Authorization")
if tok == "" { if tok == "" {
render.Error(w, admin.NewError(admin.ErrorUnauthorizedType, render.Error(w, r, admin.NewError(admin.ErrorUnauthorizedType,
"missing authorization header token")) "missing authorization header token"))
return return
} }
@ -39,7 +39,7 @@ func extractAuthorizeTokenAdmin(next http.HandlerFunc) http.HandlerFunc {
ctx := r.Context() ctx := r.Context()
adm, err := mustAuthority(ctx).AuthorizeAdminToken(r, tok) adm, err := mustAuthority(ctx).AuthorizeAdminToken(r, tok)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
@ -64,13 +64,13 @@ func loadProvisionerByName(next http.HandlerFunc) http.HandlerFunc {
// TODO(hs): distinguish 404 vs. 500 // TODO(hs): distinguish 404 vs. 500
if p, err = auth.LoadProvisionerByName(name); err != nil { if p, err = auth.LoadProvisionerByName(name); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name)) render.Error(w, r, admin.WrapErrorISE(err, "error loading provisioner %s", name))
return return
} }
prov, err := adminDB.GetProvisioner(ctx, p.GetID()) prov, err := adminDB.GetProvisioner(ctx, p.GetID())
if err != nil { if err != nil {
render.Error(w, admin.WrapErrorISE(err, "error retrieving provisioner %s", name)) render.Error(w, r, admin.WrapErrorISE(err, "error retrieving provisioner %s", name))
return return
} }
@ -91,7 +91,7 @@ func checkAction(next http.HandlerFunc, supportedInStandalone bool) http.Handler
// when an action is not supported in standalone mode and when // when an action is not supported in standalone mode and when
// using a nosql.DB backend, actions are not supported // using a nosql.DB backend, actions are not supported
if _, ok := admin.MustFromContext(r.Context()).(*nosql.DB); ok { if _, ok := admin.MustFromContext(r.Context()).(*nosql.DB); ok {
render.Error(w, admin.NewError(admin.ErrorNotImplementedType, render.Error(w, r, admin.NewError(admin.ErrorNotImplementedType,
"operation not supported in standalone mode")) "operation not supported in standalone mode"))
return return
} }
@ -125,15 +125,15 @@ func loadExternalAccountKey(next http.HandlerFunc) http.HandlerFunc {
if err != nil { if err != nil {
if acme.IsErrNotFound(err) { if acme.IsErrNotFound(err) {
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found")) render.Error(w, r, admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found"))
return return
} }
render.Error(w, admin.WrapErrorISE(err, "error retrieving ACME External Account Key")) render.Error(w, r, admin.WrapErrorISE(err, "error retrieving ACME External Account Key"))
return return
} }
if eak == nil { if eak == nil {
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found")) render.Error(w, r, admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found"))
return return
} }

@ -44,7 +44,7 @@ func NewPolicyAdminResponder() PolicyAdminResponder {
func (par *policyAdminResponder) GetAuthorityPolicy(w http.ResponseWriter, r *http.Request) { func (par *policyAdminResponder) GetAuthorityPolicy(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
if err := blockLinkedCA(ctx); err != nil { if err := blockLinkedCA(ctx); err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
@ -52,12 +52,12 @@ func (par *policyAdminResponder) GetAuthorityPolicy(w http.ResponseWriter, r *ht
authorityPolicy, err := auth.GetAuthorityPolicy(r.Context()) authorityPolicy, err := auth.GetAuthorityPolicy(r.Context())
var ae *admin.Error var ae *admin.Error
if errors.As(err, &ae) && !ae.IsType(admin.ErrorNotFoundType) { if errors.As(err, &ae) && !ae.IsType(admin.ErrorNotFoundType) {
render.Error(w, admin.WrapErrorISE(ae, "error retrieving authority policy")) render.Error(w, r, admin.WrapErrorISE(ae, "error retrieving authority policy"))
return return
} }
if authorityPolicy == nil { if authorityPolicy == nil {
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "authority policy does not exist")) render.Error(w, r, admin.NewError(admin.ErrorNotFoundType, "authority policy does not exist"))
return return
} }
@ -68,7 +68,7 @@ func (par *policyAdminResponder) GetAuthorityPolicy(w http.ResponseWriter, r *ht
func (par *policyAdminResponder) CreateAuthorityPolicy(w http.ResponseWriter, r *http.Request) { func (par *policyAdminResponder) CreateAuthorityPolicy(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
if err := blockLinkedCA(ctx); err != nil { if err := blockLinkedCA(ctx); err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
@ -77,26 +77,26 @@ func (par *policyAdminResponder) CreateAuthorityPolicy(w http.ResponseWriter, r
var ae *admin.Error var ae *admin.Error
if errors.As(err, &ae) && !ae.IsType(admin.ErrorNotFoundType) { if errors.As(err, &ae) && !ae.IsType(admin.ErrorNotFoundType) {
render.Error(w, admin.WrapErrorISE(err, "error retrieving authority policy")) render.Error(w, r, admin.WrapErrorISE(err, "error retrieving authority policy"))
return return
} }
if authorityPolicy != nil { if authorityPolicy != nil {
adminErr := admin.NewError(admin.ErrorConflictType, "authority already has a policy") adminErr := admin.NewError(admin.ErrorConflictType, "authority already has a policy")
render.Error(w, adminErr) render.Error(w, r, adminErr)
return return
} }
var newPolicy = new(linkedca.Policy) var newPolicy = new(linkedca.Policy)
if err := read.ProtoJSON(r.Body, newPolicy); err != nil { if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
newPolicy.Deduplicate() newPolicy.Deduplicate()
if err := validatePolicy(newPolicy); err != nil { if err := validatePolicy(newPolicy); err != nil {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating authority policy")) render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error validating authority policy"))
return return
} }
@ -105,11 +105,11 @@ func (par *policyAdminResponder) CreateAuthorityPolicy(w http.ResponseWriter, r
var createdPolicy *linkedca.Policy var createdPolicy *linkedca.Policy
if createdPolicy, err = auth.CreateAuthorityPolicy(ctx, adm, newPolicy); err != nil { if createdPolicy, err = auth.CreateAuthorityPolicy(ctx, adm, newPolicy); err != nil {
if isBadRequest(err) { if isBadRequest(err) {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error storing authority policy")) render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error storing authority policy"))
return return
} }
render.Error(w, admin.WrapErrorISE(err, "error storing authority policy")) render.Error(w, r, admin.WrapErrorISE(err, "error storing authority policy"))
return return
} }
@ -120,7 +120,7 @@ func (par *policyAdminResponder) CreateAuthorityPolicy(w http.ResponseWriter, r
func (par *policyAdminResponder) UpdateAuthorityPolicy(w http.ResponseWriter, r *http.Request) { func (par *policyAdminResponder) UpdateAuthorityPolicy(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
if err := blockLinkedCA(ctx); err != nil { if err := blockLinkedCA(ctx); err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
@ -129,25 +129,25 @@ func (par *policyAdminResponder) UpdateAuthorityPolicy(w http.ResponseWriter, r
var ae *admin.Error var ae *admin.Error
if errors.As(err, &ae) && !ae.IsType(admin.ErrorNotFoundType) { if errors.As(err, &ae) && !ae.IsType(admin.ErrorNotFoundType) {
render.Error(w, admin.WrapErrorISE(err, "error retrieving authority policy")) render.Error(w, r, admin.WrapErrorISE(err, "error retrieving authority policy"))
return return
} }
if authorityPolicy == nil { if authorityPolicy == nil {
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "authority policy does not exist")) render.Error(w, r, admin.NewError(admin.ErrorNotFoundType, "authority policy does not exist"))
return return
} }
var newPolicy = new(linkedca.Policy) var newPolicy = new(linkedca.Policy)
if err := read.ProtoJSON(r.Body, newPolicy); err != nil { if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
newPolicy.Deduplicate() newPolicy.Deduplicate()
if err := validatePolicy(newPolicy); err != nil { if err := validatePolicy(newPolicy); err != nil {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating authority policy")) render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error validating authority policy"))
return return
} }
@ -156,11 +156,11 @@ func (par *policyAdminResponder) UpdateAuthorityPolicy(w http.ResponseWriter, r
var updatedPolicy *linkedca.Policy var updatedPolicy *linkedca.Policy
if updatedPolicy, err = auth.UpdateAuthorityPolicy(ctx, adm, newPolicy); err != nil { if updatedPolicy, err = auth.UpdateAuthorityPolicy(ctx, adm, newPolicy); err != nil {
if isBadRequest(err) { if isBadRequest(err) {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error updating authority policy")) render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error updating authority policy"))
return return
} }
render.Error(w, admin.WrapErrorISE(err, "error updating authority policy")) render.Error(w, r, admin.WrapErrorISE(err, "error updating authority policy"))
return return
} }
@ -171,7 +171,7 @@ func (par *policyAdminResponder) UpdateAuthorityPolicy(w http.ResponseWriter, r
func (par *policyAdminResponder) DeleteAuthorityPolicy(w http.ResponseWriter, r *http.Request) { func (par *policyAdminResponder) DeleteAuthorityPolicy(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
if err := blockLinkedCA(ctx); err != nil { if err := blockLinkedCA(ctx); err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
@ -180,35 +180,35 @@ func (par *policyAdminResponder) DeleteAuthorityPolicy(w http.ResponseWriter, r
var ae *admin.Error var ae *admin.Error
if errors.As(err, &ae) && !ae.IsType(admin.ErrorNotFoundType) { if errors.As(err, &ae) && !ae.IsType(admin.ErrorNotFoundType) {
render.Error(w, admin.WrapErrorISE(ae, "error retrieving authority policy")) render.Error(w, r, admin.WrapErrorISE(ae, "error retrieving authority policy"))
return return
} }
if authorityPolicy == nil { if authorityPolicy == nil {
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "authority policy does not exist")) render.Error(w, r, admin.NewError(admin.ErrorNotFoundType, "authority policy does not exist"))
return return
} }
if err := auth.RemoveAuthorityPolicy(ctx); err != nil { if err := auth.RemoveAuthorityPolicy(ctx); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error deleting authority policy")) render.Error(w, r, admin.WrapErrorISE(err, "error deleting authority policy"))
return return
} }
render.JSONStatus(w, DeleteResponse{Status: "ok"}, http.StatusOK) render.JSONStatus(w, r, DeleteResponse{Status: "ok"}, http.StatusOK)
} }
// GetProvisionerPolicy handles the GET /admin/provisioners/{name}/policy request // GetProvisionerPolicy handles the GET /admin/provisioners/{name}/policy request
func (par *policyAdminResponder) GetProvisionerPolicy(w http.ResponseWriter, r *http.Request) { func (par *policyAdminResponder) GetProvisionerPolicy(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
if err := blockLinkedCA(ctx); err != nil { if err := blockLinkedCA(ctx); err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
prov := linkedca.MustProvisionerFromContext(ctx) prov := linkedca.MustProvisionerFromContext(ctx)
provisionerPolicy := prov.GetPolicy() provisionerPolicy := prov.GetPolicy()
if provisionerPolicy == nil { if provisionerPolicy == nil {
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist")) render.Error(w, r, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist"))
return return
} }
@ -219,7 +219,7 @@ func (par *policyAdminResponder) GetProvisionerPolicy(w http.ResponseWriter, r *
func (par *policyAdminResponder) CreateProvisionerPolicy(w http.ResponseWriter, r *http.Request) { func (par *policyAdminResponder) CreateProvisionerPolicy(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
if err := blockLinkedCA(ctx); err != nil { if err := blockLinkedCA(ctx); err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
@ -227,20 +227,20 @@ func (par *policyAdminResponder) CreateProvisionerPolicy(w http.ResponseWriter,
provisionerPolicy := prov.GetPolicy() provisionerPolicy := prov.GetPolicy()
if provisionerPolicy != nil { if provisionerPolicy != nil {
adminErr := admin.NewError(admin.ErrorConflictType, "provisioner %s already has a policy", prov.Name) adminErr := admin.NewError(admin.ErrorConflictType, "provisioner %s already has a policy", prov.Name)
render.Error(w, adminErr) render.Error(w, r, adminErr)
return return
} }
var newPolicy = new(linkedca.Policy) var newPolicy = new(linkedca.Policy)
if err := read.ProtoJSON(r.Body, newPolicy); err != nil { if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
newPolicy.Deduplicate() newPolicy.Deduplicate()
if err := validatePolicy(newPolicy); err != nil { if err := validatePolicy(newPolicy); err != nil {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating provisioner policy")) render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error validating provisioner policy"))
return return
} }
@ -248,11 +248,11 @@ func (par *policyAdminResponder) CreateProvisionerPolicy(w http.ResponseWriter,
auth := mustAuthority(ctx) auth := mustAuthority(ctx)
if err := auth.UpdateProvisioner(ctx, prov); err != nil { if err := auth.UpdateProvisioner(ctx, prov); err != nil {
if isBadRequest(err) { if isBadRequest(err) {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error creating provisioner policy")) render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error creating provisioner policy"))
return return
} }
render.Error(w, admin.WrapErrorISE(err, "error creating provisioner policy")) render.Error(w, r, admin.WrapErrorISE(err, "error creating provisioner policy"))
return return
} }
@ -263,27 +263,27 @@ func (par *policyAdminResponder) CreateProvisionerPolicy(w http.ResponseWriter,
func (par *policyAdminResponder) UpdateProvisionerPolicy(w http.ResponseWriter, r *http.Request) { func (par *policyAdminResponder) UpdateProvisionerPolicy(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
if err := blockLinkedCA(ctx); err != nil { if err := blockLinkedCA(ctx); err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
prov := linkedca.MustProvisionerFromContext(ctx) prov := linkedca.MustProvisionerFromContext(ctx)
provisionerPolicy := prov.GetPolicy() provisionerPolicy := prov.GetPolicy()
if provisionerPolicy == nil { if provisionerPolicy == nil {
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist")) render.Error(w, r, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist"))
return return
} }
var newPolicy = new(linkedca.Policy) var newPolicy = new(linkedca.Policy)
if err := read.ProtoJSON(r.Body, newPolicy); err != nil { if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
newPolicy.Deduplicate() newPolicy.Deduplicate()
if err := validatePolicy(newPolicy); err != nil { if err := validatePolicy(newPolicy); err != nil {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating provisioner policy")) render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error validating provisioner policy"))
return return
} }
@ -291,11 +291,11 @@ func (par *policyAdminResponder) UpdateProvisionerPolicy(w http.ResponseWriter,
auth := mustAuthority(ctx) auth := mustAuthority(ctx)
if err := auth.UpdateProvisioner(ctx, prov); err != nil { if err := auth.UpdateProvisioner(ctx, prov); err != nil {
if isBadRequest(err) { if isBadRequest(err) {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error updating provisioner policy")) render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error updating provisioner policy"))
return return
} }
render.Error(w, admin.WrapErrorISE(err, "error updating provisioner policy")) render.Error(w, r, admin.WrapErrorISE(err, "error updating provisioner policy"))
return return
} }
@ -306,13 +306,13 @@ func (par *policyAdminResponder) UpdateProvisionerPolicy(w http.ResponseWriter,
func (par *policyAdminResponder) DeleteProvisionerPolicy(w http.ResponseWriter, r *http.Request) { func (par *policyAdminResponder) DeleteProvisionerPolicy(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
if err := blockLinkedCA(ctx); err != nil { if err := blockLinkedCA(ctx); err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
prov := linkedca.MustProvisionerFromContext(ctx) prov := linkedca.MustProvisionerFromContext(ctx)
if prov.Policy == nil { if prov.Policy == nil {
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist")) render.Error(w, r, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist"))
return return
} }
@ -321,24 +321,24 @@ func (par *policyAdminResponder) DeleteProvisionerPolicy(w http.ResponseWriter,
auth := mustAuthority(ctx) auth := mustAuthority(ctx)
if err := auth.UpdateProvisioner(ctx, prov); err != nil { if err := auth.UpdateProvisioner(ctx, prov); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error deleting provisioner policy")) render.Error(w, r, admin.WrapErrorISE(err, "error deleting provisioner policy"))
return return
} }
render.JSONStatus(w, DeleteResponse{Status: "ok"}, http.StatusOK) render.JSONStatus(w, r, DeleteResponse{Status: "ok"}, http.StatusOK)
} }
func (par *policyAdminResponder) GetACMEAccountPolicy(w http.ResponseWriter, r *http.Request) { func (par *policyAdminResponder) GetACMEAccountPolicy(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
if err := blockLinkedCA(ctx); err != nil { if err := blockLinkedCA(ctx); err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
eak := linkedca.MustExternalAccountKeyFromContext(ctx) eak := linkedca.MustExternalAccountKeyFromContext(ctx)
eakPolicy := eak.GetPolicy() eakPolicy := eak.GetPolicy()
if eakPolicy == nil { if eakPolicy == nil {
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist")) render.Error(w, r, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist"))
return return
} }
@ -348,7 +348,7 @@ func (par *policyAdminResponder) GetACMEAccountPolicy(w http.ResponseWriter, r *
func (par *policyAdminResponder) CreateACMEAccountPolicy(w http.ResponseWriter, r *http.Request) { func (par *policyAdminResponder) CreateACMEAccountPolicy(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
if err := blockLinkedCA(ctx); err != nil { if err := blockLinkedCA(ctx); err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
@ -357,20 +357,20 @@ func (par *policyAdminResponder) CreateACMEAccountPolicy(w http.ResponseWriter,
eakPolicy := eak.GetPolicy() eakPolicy := eak.GetPolicy()
if eakPolicy != nil { if eakPolicy != nil {
adminErr := admin.NewError(admin.ErrorConflictType, "ACME EAK %s already has a policy", eak.Id) adminErr := admin.NewError(admin.ErrorConflictType, "ACME EAK %s already has a policy", eak.Id)
render.Error(w, adminErr) render.Error(w, r, adminErr)
return return
} }
var newPolicy = new(linkedca.Policy) var newPolicy = new(linkedca.Policy)
if err := read.ProtoJSON(r.Body, newPolicy); err != nil { if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
newPolicy.Deduplicate() newPolicy.Deduplicate()
if err := validatePolicy(newPolicy); err != nil { if err := validatePolicy(newPolicy); err != nil {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating ACME EAK policy")) render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error validating ACME EAK policy"))
return return
} }
@ -379,7 +379,7 @@ func (par *policyAdminResponder) CreateACMEAccountPolicy(w http.ResponseWriter,
acmeEAK := linkedEAKToCertificates(eak) acmeEAK := linkedEAKToCertificates(eak)
acmeDB := acme.MustDatabaseFromContext(ctx) acmeDB := acme.MustDatabaseFromContext(ctx)
if err := acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil { if err := acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error creating ACME EAK policy")) render.Error(w, r, admin.WrapErrorISE(err, "error creating ACME EAK policy"))
return return
} }
@ -389,7 +389,7 @@ func (par *policyAdminResponder) CreateACMEAccountPolicy(w http.ResponseWriter,
func (par *policyAdminResponder) UpdateACMEAccountPolicy(w http.ResponseWriter, r *http.Request) { func (par *policyAdminResponder) UpdateACMEAccountPolicy(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
if err := blockLinkedCA(ctx); err != nil { if err := blockLinkedCA(ctx); err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
@ -397,20 +397,20 @@ func (par *policyAdminResponder) UpdateACMEAccountPolicy(w http.ResponseWriter,
eak := linkedca.MustExternalAccountKeyFromContext(ctx) eak := linkedca.MustExternalAccountKeyFromContext(ctx)
eakPolicy := eak.GetPolicy() eakPolicy := eak.GetPolicy()
if eakPolicy == nil { if eakPolicy == nil {
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist")) render.Error(w, r, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist"))
return return
} }
var newPolicy = new(linkedca.Policy) var newPolicy = new(linkedca.Policy)
if err := read.ProtoJSON(r.Body, newPolicy); err != nil { if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
newPolicy.Deduplicate() newPolicy.Deduplicate()
if err := validatePolicy(newPolicy); err != nil { if err := validatePolicy(newPolicy); err != nil {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating ACME EAK policy")) render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error validating ACME EAK policy"))
return return
} }
@ -418,7 +418,7 @@ func (par *policyAdminResponder) UpdateACMEAccountPolicy(w http.ResponseWriter,
acmeEAK := linkedEAKToCertificates(eak) acmeEAK := linkedEAKToCertificates(eak)
acmeDB := acme.MustDatabaseFromContext(ctx) acmeDB := acme.MustDatabaseFromContext(ctx)
if err := acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil { if err := acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error updating ACME EAK policy")) render.Error(w, r, admin.WrapErrorISE(err, "error updating ACME EAK policy"))
return return
} }
@ -428,7 +428,7 @@ func (par *policyAdminResponder) UpdateACMEAccountPolicy(w http.ResponseWriter,
func (par *policyAdminResponder) DeleteACMEAccountPolicy(w http.ResponseWriter, r *http.Request) { func (par *policyAdminResponder) DeleteACMEAccountPolicy(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
if err := blockLinkedCA(ctx); err != nil { if err := blockLinkedCA(ctx); err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
@ -436,7 +436,7 @@ func (par *policyAdminResponder) DeleteACMEAccountPolicy(w http.ResponseWriter,
eak := linkedca.MustExternalAccountKeyFromContext(ctx) eak := linkedca.MustExternalAccountKeyFromContext(ctx)
eakPolicy := eak.GetPolicy() eakPolicy := eak.GetPolicy()
if eakPolicy == nil { if eakPolicy == nil {
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist")) render.Error(w, r, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist"))
return return
} }
@ -446,11 +446,11 @@ func (par *policyAdminResponder) DeleteACMEAccountPolicy(w http.ResponseWriter,
acmeEAK := linkedEAKToCertificates(eak) acmeEAK := linkedEAKToCertificates(eak)
acmeDB := acme.MustDatabaseFromContext(ctx) acmeDB := acme.MustDatabaseFromContext(ctx)
if err := acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil { if err := acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error deleting ACME EAK policy")) render.Error(w, r, admin.WrapErrorISE(err, "error deleting ACME EAK policy"))
return return
} }
render.JSONStatus(w, DeleteResponse{Status: "ok"}, http.StatusOK) render.JSONStatus(w, r, DeleteResponse{Status: "ok"}, http.StatusOK)
} }
// blockLinkedCA blocks all API operations on linked deployments // blockLinkedCA blocks all API operations on linked deployments

@ -40,19 +40,19 @@ func GetProvisioner(w http.ResponseWriter, r *http.Request) {
if id != "" { if id != "" {
if p, err = auth.LoadProvisionerByID(id); err != nil { if p, err = auth.LoadProvisionerByID(id); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", id)) render.Error(w, r, admin.WrapErrorISE(err, "error loading provisioner %s", id))
return return
} }
} else { } else {
if p, err = auth.LoadProvisionerByName(name); err != nil { if p, err = auth.LoadProvisionerByName(name); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name)) render.Error(w, r, admin.WrapErrorISE(err, "error loading provisioner %s", name))
return return
} }
} }
prov, err := db.GetProvisioner(ctx, p.GetID()) prov, err := db.GetProvisioner(ctx, p.GetID())
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
render.ProtoJSON(w, prov) render.ProtoJSON(w, prov)
@ -62,17 +62,17 @@ func GetProvisioner(w http.ResponseWriter, r *http.Request) {
func GetProvisioners(w http.ResponseWriter, r *http.Request) { func GetProvisioners(w http.ResponseWriter, r *http.Request) {
cursor, limit, err := api.ParseCursor(r) cursor, limit, err := api.ParseCursor(r)
if err != nil { if err != nil {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err,
"error parsing cursor and limit from query params")) "error parsing cursor and limit from query params"))
return return
} }
p, next, err := mustAuthority(r.Context()).GetProvisioners(cursor, limit) p, next, err := mustAuthority(r.Context()).GetProvisioners(cursor, limit)
if err != nil { if err != nil {
render.Error(w, errs.InternalServerErr(err)) render.Error(w, r, errs.InternalServerErr(err))
return return
} }
render.JSON(w, &GetProvisionersResponse{ render.JSON(w, r, &GetProvisionersResponse{
Provisioners: p, Provisioners: p,
NextCursor: next, NextCursor: next,
}) })
@ -82,24 +82,24 @@ func GetProvisioners(w http.ResponseWriter, r *http.Request) {
func CreateProvisioner(w http.ResponseWriter, r *http.Request) { func CreateProvisioner(w http.ResponseWriter, r *http.Request) {
var prov = new(linkedca.Provisioner) var prov = new(linkedca.Provisioner)
if err := read.ProtoJSON(r.Body, prov); err != nil { if err := read.ProtoJSON(r.Body, prov); err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
// TODO: Validate inputs // TODO: Validate inputs
if err := authority.ValidateClaims(prov.Claims); err != nil { if err := authority.ValidateClaims(prov.Claims); err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
// validate the templates and template data // validate the templates and template data
if err := validateTemplates(prov.X509Template, prov.SshTemplate); err != nil { if err := validateTemplates(prov.X509Template, prov.SshTemplate); err != nil {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "invalid template")) render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "invalid template"))
return return
} }
if err := mustAuthority(r.Context()).StoreProvisioner(r.Context(), prov); err != nil { if err := mustAuthority(r.Context()).StoreProvisioner(r.Context(), prov); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error storing provisioner %s", prov.Name)) render.Error(w, r, admin.WrapErrorISE(err, "error storing provisioner %s", prov.Name))
return return
} }
render.ProtoJSONStatus(w, prov, http.StatusCreated) render.ProtoJSONStatus(w, prov, http.StatusCreated)
@ -118,29 +118,29 @@ func DeleteProvisioner(w http.ResponseWriter, r *http.Request) {
if id != "" { if id != "" {
if p, err = auth.LoadProvisionerByID(id); err != nil { if p, err = auth.LoadProvisionerByID(id); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", id)) render.Error(w, r, admin.WrapErrorISE(err, "error loading provisioner %s", id))
return return
} }
} else { } else {
if p, err = auth.LoadProvisionerByName(name); err != nil { if p, err = auth.LoadProvisionerByName(name); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name)) render.Error(w, r, admin.WrapErrorISE(err, "error loading provisioner %s", name))
return return
} }
} }
if err := auth.RemoveProvisioner(r.Context(), p.GetID()); err != nil { if err := auth.RemoveProvisioner(r.Context(), p.GetID()); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error removing provisioner %s", p.GetName())) render.Error(w, r, admin.WrapErrorISE(err, "error removing provisioner %s", p.GetName()))
return return
} }
render.JSON(w, &DeleteResponse{Status: "ok"}) render.JSON(w, r, &DeleteResponse{Status: "ok"})
} }
// UpdateProvisioner updates an existing prov. // UpdateProvisioner updates an existing prov.
func UpdateProvisioner(w http.ResponseWriter, r *http.Request) { func UpdateProvisioner(w http.ResponseWriter, r *http.Request) {
var nu = new(linkedca.Provisioner) var nu = new(linkedca.Provisioner)
if err := read.ProtoJSON(r.Body, nu); err != nil { if err := read.ProtoJSON(r.Body, nu); err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
@ -151,51 +151,51 @@ func UpdateProvisioner(w http.ResponseWriter, r *http.Request) {
p, err := auth.LoadProvisionerByName(name) p, err := auth.LoadProvisionerByName(name)
if err != nil { if err != nil {
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner from cached configuration '%s'", name)) render.Error(w, r, admin.WrapErrorISE(err, "error loading provisioner from cached configuration '%s'", name))
return return
} }
old, err := db.GetProvisioner(r.Context(), p.GetID()) old, err := db.GetProvisioner(r.Context(), p.GetID())
if err != nil { if err != nil {
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner from db '%s'", p.GetID())) render.Error(w, r, admin.WrapErrorISE(err, "error loading provisioner from db '%s'", p.GetID()))
return return
} }
if nu.Id != old.Id { if nu.Id != old.Id {
render.Error(w, admin.NewErrorISE("cannot change provisioner ID")) render.Error(w, r, admin.NewErrorISE("cannot change provisioner ID"))
return return
} }
if nu.Type != old.Type { if nu.Type != old.Type {
render.Error(w, admin.NewErrorISE("cannot change provisioner type")) render.Error(w, r, admin.NewErrorISE("cannot change provisioner type"))
return return
} }
if nu.AuthorityId != old.AuthorityId { if nu.AuthorityId != old.AuthorityId {
render.Error(w, admin.NewErrorISE("cannot change provisioner authorityID")) render.Error(w, r, admin.NewErrorISE("cannot change provisioner authorityID"))
return return
} }
if !nu.CreatedAt.AsTime().Equal(old.CreatedAt.AsTime()) { if !nu.CreatedAt.AsTime().Equal(old.CreatedAt.AsTime()) {
render.Error(w, admin.NewErrorISE("cannot change provisioner createdAt")) render.Error(w, r, admin.NewErrorISE("cannot change provisioner createdAt"))
return return
} }
if !nu.DeletedAt.AsTime().Equal(old.DeletedAt.AsTime()) { if !nu.DeletedAt.AsTime().Equal(old.DeletedAt.AsTime()) {
render.Error(w, admin.NewErrorISE("cannot change provisioner deletedAt")) render.Error(w, r, admin.NewErrorISE("cannot change provisioner deletedAt"))
return return
} }
// TODO: Validate inputs // TODO: Validate inputs
if err := authority.ValidateClaims(nu.Claims); err != nil { if err := authority.ValidateClaims(nu.Claims); err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
// validate the templates and template data // validate the templates and template data
if err := validateTemplates(nu.X509Template, nu.SshTemplate); err != nil { if err := validateTemplates(nu.X509Template, nu.SshTemplate); err != nil {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "invalid template")) render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "invalid template"))
return return
} }
if err := auth.UpdateProvisioner(r.Context(), nu); err != nil { if err := auth.UpdateProvisioner(r.Context(), nu); err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
render.ProtoJSON(w, nu) render.ProtoJSON(w, nu)

@ -71,28 +71,28 @@ func (war *webhookAdminResponder) CreateProvisionerWebhook(w http.ResponseWriter
var newWebhook = new(linkedca.Webhook) var newWebhook = new(linkedca.Webhook)
if err := read.ProtoJSON(r.Body, newWebhook); err != nil { if err := read.ProtoJSON(r.Body, newWebhook); err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
if err := validateWebhook(newWebhook); err != nil { if err := validateWebhook(newWebhook); err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
if newWebhook.Secret != "" { if newWebhook.Secret != "" {
err := admin.NewError(admin.ErrorBadRequestType, "webhook secret must not be set") err := admin.NewError(admin.ErrorBadRequestType, "webhook secret must not be set")
render.Error(w, err) render.Error(w, r, err)
return return
} }
if newWebhook.Id != "" { if newWebhook.Id != "" {
err := admin.NewError(admin.ErrorBadRequestType, "webhook ID must not be set") err := admin.NewError(admin.ErrorBadRequestType, "webhook ID must not be set")
render.Error(w, err) render.Error(w, r, err)
return return
} }
id, err := randutil.UUIDv4() id, err := randutil.UUIDv4()
if err != nil { if err != nil {
render.Error(w, admin.WrapErrorISE(err, "error generating webhook id")) render.Error(w, r, admin.WrapErrorISE(err, "error generating webhook id"))
return return
} }
newWebhook.Id = id newWebhook.Id = id
@ -101,14 +101,14 @@ func (war *webhookAdminResponder) CreateProvisionerWebhook(w http.ResponseWriter
for _, wh := range prov.Webhooks { for _, wh := range prov.Webhooks {
if wh.Name == newWebhook.Name { if wh.Name == newWebhook.Name {
err := admin.NewError(admin.ErrorConflictType, "provisioner %q already has a webhook with the name %q", prov.Name, newWebhook.Name) err := admin.NewError(admin.ErrorConflictType, "provisioner %q already has a webhook with the name %q", prov.Name, newWebhook.Name)
render.Error(w, err) render.Error(w, r, err)
return return
} }
} }
secret, err := randutil.Bytes(64) secret, err := randutil.Bytes(64)
if err != nil { if err != nil {
render.Error(w, admin.WrapErrorISE(err, "error generating webhook secret")) render.Error(w, r, admin.WrapErrorISE(err, "error generating webhook secret"))
return return
} }
newWebhook.Secret = base64.StdEncoding.EncodeToString(secret) newWebhook.Secret = base64.StdEncoding.EncodeToString(secret)
@ -117,11 +117,11 @@ func (war *webhookAdminResponder) CreateProvisionerWebhook(w http.ResponseWriter
if err := auth.UpdateProvisioner(ctx, prov); err != nil { if err := auth.UpdateProvisioner(ctx, prov); err != nil {
if isBadRequest(err) { if isBadRequest(err) {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error creating provisioner webhook")) render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error creating provisioner webhook"))
return return
} }
render.Error(w, admin.WrapErrorISE(err, "error creating provisioner webhook")) render.Error(w, r, admin.WrapErrorISE(err, "error creating provisioner webhook"))
return return
} }
@ -145,21 +145,21 @@ func (war *webhookAdminResponder) DeleteProvisionerWebhook(w http.ResponseWriter
} }
} }
if !found { if !found {
render.JSONStatus(w, DeleteResponse{Status: "ok"}, http.StatusOK) render.JSONStatus(w, r, DeleteResponse{Status: "ok"}, http.StatusOK)
return return
} }
if err := auth.UpdateProvisioner(ctx, prov); err != nil { if err := auth.UpdateProvisioner(ctx, prov); err != nil {
if isBadRequest(err) { if isBadRequest(err) {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error deleting provisioner webhook")) render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error deleting provisioner webhook"))
return return
} }
render.Error(w, admin.WrapErrorISE(err, "error deleting provisioner webhook")) render.Error(w, r, admin.WrapErrorISE(err, "error deleting provisioner webhook"))
return return
} }
render.JSONStatus(w, DeleteResponse{Status: "ok"}, http.StatusOK) render.JSONStatus(w, r, DeleteResponse{Status: "ok"}, http.StatusOK)
} }
func (war *webhookAdminResponder) UpdateProvisionerWebhook(w http.ResponseWriter, r *http.Request) { func (war *webhookAdminResponder) UpdateProvisionerWebhook(w http.ResponseWriter, r *http.Request) {
@ -170,12 +170,12 @@ func (war *webhookAdminResponder) UpdateProvisionerWebhook(w http.ResponseWriter
var newWebhook = new(linkedca.Webhook) var newWebhook = new(linkedca.Webhook)
if err := read.ProtoJSON(r.Body, newWebhook); err != nil { if err := read.ProtoJSON(r.Body, newWebhook); err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
if err := validateWebhook(newWebhook); err != nil { if err := validateWebhook(newWebhook); err != nil {
render.Error(w, err) render.Error(w, r, err)
return return
} }
@ -186,13 +186,13 @@ func (war *webhookAdminResponder) UpdateProvisionerWebhook(w http.ResponseWriter
} }
if newWebhook.Secret != "" && newWebhook.Secret != wh.Secret { if newWebhook.Secret != "" && newWebhook.Secret != wh.Secret {
err := admin.NewError(admin.ErrorBadRequestType, "webhook secret cannot be updated") err := admin.NewError(admin.ErrorBadRequestType, "webhook secret cannot be updated")
render.Error(w, err) render.Error(w, r, err)
return return
} }
newWebhook.Secret = wh.Secret newWebhook.Secret = wh.Secret
if newWebhook.Id != "" && newWebhook.Id != wh.Id { if newWebhook.Id != "" && newWebhook.Id != wh.Id {
err := admin.NewError(admin.ErrorBadRequestType, "webhook ID cannot be updated") err := admin.NewError(admin.ErrorBadRequestType, "webhook ID cannot be updated")
render.Error(w, err) render.Error(w, r, err)
return return
} }
newWebhook.Id = wh.Id newWebhook.Id = wh.Id
@ -203,17 +203,17 @@ func (war *webhookAdminResponder) UpdateProvisionerWebhook(w http.ResponseWriter
if !found { if !found {
msg := fmt.Sprintf("provisioner %q has no webhook with the name %q", prov.Name, newWebhook.Name) msg := fmt.Sprintf("provisioner %q has no webhook with the name %q", prov.Name, newWebhook.Name)
err := admin.NewError(admin.ErrorNotFoundType, msg) err := admin.NewError(admin.ErrorNotFoundType, msg)
render.Error(w, err) render.Error(w, r, err)
return return
} }
if err := auth.UpdateProvisioner(ctx, prov); err != nil { if err := auth.UpdateProvisioner(ctx, prov); err != nil {
if isBadRequest(err) { if isBadRequest(err) {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error updating provisioner webhook")) render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error updating provisioner webhook"))
return return
} }
render.Error(w, admin.WrapErrorISE(err, "error updating provisioner webhook")) render.Error(w, r, admin.WrapErrorISE(err, "error updating provisioner webhook"))
return return
} }

@ -205,8 +205,8 @@ func (e *Error) ToLog() (interface{}, error) {
} }
// Render implements render.RenderableError for Error. // Render implements render.RenderableError for Error.
func (e *Error) Render(w http.ResponseWriter) { func (e *Error) Render(w http.ResponseWriter, r *http.Request) {
e.Message = e.Err.Error() e.Message = e.Err.Error()
render.JSONStatus(w, e, e.StatusCode()) render.JSONStatus(w, r, e, e.StatusCode())
} }

@ -108,19 +108,19 @@ func TestNewACMEClient(t *testing.T) {
tc := run(t) tc := run(t)
i := 0 i := 0
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equals(t, "step-http-client/1.0", req.Header.Get("User-Agent")) // check default User-Agent header assert.Equals(t, "step-http-client/1.0", r.Header.Get("User-Agent")) // check default User-Agent header
switch { switch {
case i == 0: case i == 0:
render.JSONStatus(w, tc.r1, tc.rc1) render.JSONStatus(w, r, tc.r1, tc.rc1)
i++ i++
case i == 1: case i == 1:
w.Header().Set("Replay-Nonce", "abc123") w.Header().Set("Replay-Nonce", "abc123")
render.JSONStatus(w, []byte{}, 200) render.JSONStatus(w, r, []byte{}, 200)
i++ i++
default: default:
w.Header().Set("Location", accLocation) w.Header().Set("Location", accLocation)
render.JSONStatus(w, tc.r2, tc.rc2) render.JSONStatus(w, r, tc.r2, tc.rc2)
} }
}) })
@ -203,10 +203,10 @@ func TestACMEClient_GetNonce(t *testing.T) {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
tc := run(t) tc := run(t)
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equals(t, "step-http-client/1.0", req.Header.Get("User-Agent")) // check default User-Agent header assert.Equals(t, "step-http-client/1.0", r.Header.Get("User-Agent")) // check default User-Agent header
w.Header().Set("Replay-Nonce", expectedNonce) w.Header().Set("Replay-Nonce", expectedNonce)
render.JSONStatus(w, tc.r1, tc.rc1) render.JSONStatus(w, r, tc.r1, tc.rc1)
}) })
if nonce, err := ac.GetNonce(); err != nil { if nonce, err := ac.GetNonce(); err != nil {
@ -310,18 +310,18 @@ func TestACMEClient_post(t *testing.T) {
tc := run(t) tc := run(t)
i := 0 i := 0
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equals(t, "step-http-client/1.0", req.Header.Get("User-Agent")) // check default User-Agent header assert.Equals(t, "step-http-client/1.0", r.Header.Get("User-Agent")) // check default User-Agent header
w.Header().Set("Replay-Nonce", expectedNonce) w.Header().Set("Replay-Nonce", expectedNonce)
if i == 0 { if i == 0 {
render.JSONStatus(w, tc.r1, tc.rc1) render.JSONStatus(w, r, tc.r1, tc.rc1)
i++ i++
return return
} }
// validate jws request protected headers and body // validate jws request protected headers and body
body, err := io.ReadAll(req.Body) body, err := io.ReadAll(r.Body)
assert.FatalError(t, err) assert.FatalError(t, err)
jws, err := jose.ParseJWS(string(body)) jws, err := jose.ParseJWS(string(body))
assert.FatalError(t, err) assert.FatalError(t, err)
@ -338,7 +338,7 @@ func TestACMEClient_post(t *testing.T) {
assert.Equals(t, hdr.KeyID, ac.kid) assert.Equals(t, hdr.KeyID, ac.kid)
} }
render.JSONStatus(w, tc.r2, tc.rc2) render.JSONStatus(w, r, tc.r2, tc.rc2)
}) })
if resp, err := tc.client.post(tc.payload, url, tc.ops...); err != nil { if resp, err := tc.client.post(tc.payload, url, tc.ops...); err != nil {
@ -450,18 +450,18 @@ func TestACMEClient_NewOrder(t *testing.T) {
tc := run(t) tc := run(t)
i := 0 i := 0
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equals(t, "step-http-client/1.0", req.Header.Get("User-Agent")) // check default User-Agent header assert.Equals(t, "step-http-client/1.0", r.Header.Get("User-Agent")) // check default User-Agent header
w.Header().Set("Replay-Nonce", expectedNonce) w.Header().Set("Replay-Nonce", expectedNonce)
if i == 0 { if i == 0 {
render.JSONStatus(w, tc.r1, tc.rc1) render.JSONStatus(w, r, tc.r1, tc.rc1)
i++ i++
return return
} }
// validate jws request protected headers and body // validate jws request protected headers and body
body, err := io.ReadAll(req.Body) body, err := io.ReadAll(r.Body)
assert.FatalError(t, err) assert.FatalError(t, err)
jws, err := jose.ParseJWS(string(body)) jws, err := jose.ParseJWS(string(body))
assert.FatalError(t, err) assert.FatalError(t, err)
@ -477,7 +477,7 @@ func TestACMEClient_NewOrder(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
assert.Equals(t, payload, norb) assert.Equals(t, payload, norb)
render.JSONStatus(w, tc.r2, tc.rc2) render.JSONStatus(w, r, tc.r2, tc.rc2)
}) })
if res, err := ac.NewOrder(norb); err != nil { if res, err := ac.NewOrder(norb); err != nil {
@ -572,18 +572,18 @@ func TestACMEClient_GetOrder(t *testing.T) {
tc := run(t) tc := run(t)
i := 0 i := 0
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equals(t, "step-http-client/1.0", req.Header.Get("User-Agent")) // check default User-Agent header assert.Equals(t, "step-http-client/1.0", r.Header.Get("User-Agent")) // check default User-Agent header
w.Header().Set("Replay-Nonce", expectedNonce) w.Header().Set("Replay-Nonce", expectedNonce)
if i == 0 { if i == 0 {
render.JSONStatus(w, tc.r1, tc.rc1) render.JSONStatus(w, r, tc.r1, tc.rc1)
i++ i++
return return
} }
// validate jws request protected headers and body // validate jws request protected headers and body
body, err := io.ReadAll(req.Body) body, err := io.ReadAll(r.Body)
assert.FatalError(t, err) assert.FatalError(t, err)
jws, err := jose.ParseJWS(string(body)) jws, err := jose.ParseJWS(string(body))
assert.FatalError(t, err) assert.FatalError(t, err)
@ -599,7 +599,7 @@ func TestACMEClient_GetOrder(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
assert.Equals(t, len(payload), 0) assert.Equals(t, len(payload), 0)
render.JSONStatus(w, tc.r2, tc.rc2) render.JSONStatus(w, r, tc.r2, tc.rc2)
}) })
if res, err := ac.GetOrder(url); err != nil { if res, err := ac.GetOrder(url); err != nil {
@ -694,18 +694,18 @@ func TestACMEClient_GetAuthz(t *testing.T) {
tc := run(t) tc := run(t)
i := 0 i := 0
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equals(t, "step-http-client/1.0", req.Header.Get("User-Agent")) // check default User-Agent header assert.Equals(t, "step-http-client/1.0", r.Header.Get("User-Agent")) // check default User-Agent header
w.Header().Set("Replay-Nonce", expectedNonce) w.Header().Set("Replay-Nonce", expectedNonce)
if i == 0 { if i == 0 {
render.JSONStatus(w, tc.r1, tc.rc1) render.JSONStatus(w, r, tc.r1, tc.rc1)
i++ i++
return return
} }
// validate jws request protected headers and body // validate jws request protected headers and body
body, err := io.ReadAll(req.Body) body, err := io.ReadAll(r.Body)
assert.FatalError(t, err) assert.FatalError(t, err)
jws, err := jose.ParseJWS(string(body)) jws, err := jose.ParseJWS(string(body))
assert.FatalError(t, err) assert.FatalError(t, err)
@ -721,7 +721,7 @@ func TestACMEClient_GetAuthz(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
assert.Equals(t, len(payload), 0) assert.Equals(t, len(payload), 0)
render.JSONStatus(w, tc.r2, tc.rc2) render.JSONStatus(w, r, tc.r2, tc.rc2)
}) })
if res, err := ac.GetAuthz(url); err != nil { if res, err := ac.GetAuthz(url); err != nil {
@ -816,18 +816,18 @@ func TestACMEClient_GetChallenge(t *testing.T) {
tc := run(t) tc := run(t)
i := 0 i := 0
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equals(t, "step-http-client/1.0", req.Header.Get("User-Agent")) // check default User-Agent header assert.Equals(t, "step-http-client/1.0", r.Header.Get("User-Agent")) // check default User-Agent header
w.Header().Set("Replay-Nonce", expectedNonce) w.Header().Set("Replay-Nonce", expectedNonce)
if i == 0 { if i == 0 {
render.JSONStatus(w, tc.r1, tc.rc1) render.JSONStatus(w, r, tc.r1, tc.rc1)
i++ i++
return return
} }
// validate jws request protected headers and body // validate jws request protected headers and body
body, err := io.ReadAll(req.Body) body, err := io.ReadAll(r.Body)
assert.FatalError(t, err) assert.FatalError(t, err)
jws, err := jose.ParseJWS(string(body)) jws, err := jose.ParseJWS(string(body))
assert.FatalError(t, err) assert.FatalError(t, err)
@ -844,7 +844,7 @@ func TestACMEClient_GetChallenge(t *testing.T) {
assert.Equals(t, len(payload), 0) assert.Equals(t, len(payload), 0)
render.JSONStatus(w, tc.r2, tc.rc2) render.JSONStatus(w, r, tc.r2, tc.rc2)
}) })
if res, err := ac.GetChallenge(url); err != nil { if res, err := ac.GetChallenge(url); err != nil {
@ -939,18 +939,18 @@ func TestACMEClient_ValidateChallenge(t *testing.T) {
tc := run(t) tc := run(t)
i := 0 i := 0
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equals(t, "step-http-client/1.0", req.Header.Get("User-Agent")) // check default User-Agent header assert.Equals(t, "step-http-client/1.0", r.Header.Get("User-Agent")) // check default User-Agent header
w.Header().Set("Replay-Nonce", expectedNonce) w.Header().Set("Replay-Nonce", expectedNonce)
if i == 0 { if i == 0 {
render.JSONStatus(w, tc.r1, tc.rc1) render.JSONStatus(w, r, tc.r1, tc.rc1)
i++ i++
return return
} }
// validate jws request protected headers and body // validate jws request protected headers and body
body, err := io.ReadAll(req.Body) body, err := io.ReadAll(r.Body)
assert.FatalError(t, err) assert.FatalError(t, err)
jws, err := jose.ParseJWS(string(body)) jws, err := jose.ParseJWS(string(body))
assert.FatalError(t, err) assert.FatalError(t, err)
@ -967,7 +967,7 @@ func TestACMEClient_ValidateChallenge(t *testing.T) {
assert.Equals(t, payload, []byte("{}")) assert.Equals(t, payload, []byte("{}"))
render.JSONStatus(w, tc.r2, tc.rc2) render.JSONStatus(w, r, tc.r2, tc.rc2)
}) })
if err := ac.ValidateChallenge(url); err != nil { if err := ac.ValidateChallenge(url); err != nil {
@ -983,22 +983,22 @@ func TestACMEClient_ValidateWithPayload(t *testing.T) {
key, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) key, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err) assert.FatalError(t, err)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equals(t, "step-http-client/1.0", req.Header.Get("User-Agent")) // check default User-Agent header assert.Equals(t, "step-http-client/1.0", r.Header.Get("User-Agent")) // check default User-Agent header
t.Log(req.RequestURI) t.Log(r.RequestURI)
w.Header().Set("Replay-Nonce", "nonce") w.Header().Set("Replay-Nonce", "nonce")
switch req.RequestURI { switch r.RequestURI {
case "/nonce": case "/nonce":
render.JSONStatus(w, []byte{}, 200) render.JSONStatus(w, r, []byte{}, 200)
return return
case "/fail-nonce": case "/fail-nonce":
render.JSONStatus(w, acme.NewError(acme.ErrorMalformedType, "malformed request"), 400) render.JSONStatus(w, r, acme.NewError(acme.ErrorMalformedType, "malformed request"), 400)
return return
} }
// validate jws request protected headers and body // validate jws request protected headers and body
body, err := io.ReadAll(req.Body) body, err := io.ReadAll(r.Body)
assert.FatalError(t, err) assert.FatalError(t, err)
jws, err := jose.ParseJWS(string(body)) jws, err := jose.ParseJWS(string(body))
@ -1015,15 +1015,15 @@ func TestACMEClient_ValidateWithPayload(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
assert.Equals(t, payload, []byte("the-payload")) assert.Equals(t, payload, []byte("the-payload"))
switch req.RequestURI { switch r.RequestURI {
case "/ok": case "/ok":
render.JSONStatus(w, acme.Challenge{ render.JSONStatus(w, r, acme.Challenge{
Type: "device-attestation-01", Type: "device-attestation-01",
Status: "valid", Status: "valid",
Token: "foo", Token: "foo",
}, 200) }, 200)
case "/fail": case "/fail":
render.JSONStatus(w, acme.NewError(acme.ErrorMalformedType, "malformed request"), 400) render.JSONStatus(w, r, acme.NewError(acme.ErrorMalformedType, "malformed request"), 400)
} }
})) }))
defer srv.Close() defer srv.Close()
@ -1160,18 +1160,18 @@ func TestACMEClient_FinalizeOrder(t *testing.T) {
tc := run(t) tc := run(t)
i := 0 i := 0
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equals(t, "step-http-client/1.0", req.Header.Get("User-Agent")) // check default User-Agent header assert.Equals(t, "step-http-client/1.0", r.Header.Get("User-Agent")) // check default User-Agent header
w.Header().Set("Replay-Nonce", expectedNonce) w.Header().Set("Replay-Nonce", expectedNonce)
if i == 0 { if i == 0 {
render.JSONStatus(w, tc.r1, tc.rc1) render.JSONStatus(w, r, tc.r1, tc.rc1)
i++ i++
return return
} }
// validate jws request protected headers and body // validate jws request protected headers and body
body, err := io.ReadAll(req.Body) body, err := io.ReadAll(r.Body)
assert.FatalError(t, err) assert.FatalError(t, err)
jws, err := jose.ParseJWS(string(body)) jws, err := jose.ParseJWS(string(body))
assert.FatalError(t, err) assert.FatalError(t, err)
@ -1187,7 +1187,7 @@ func TestACMEClient_FinalizeOrder(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
assert.Equals(t, payload, frb) assert.Equals(t, payload, frb)
render.JSONStatus(w, tc.r2, tc.rc2) render.JSONStatus(w, r, tc.r2, tc.rc2)
}) })
if err := ac.FinalizeOrder(url, csr); err != nil { if err := ac.FinalizeOrder(url, csr); err != nil {
@ -1289,18 +1289,18 @@ func TestACMEClient_GetAccountOrders(t *testing.T) {
tc := run(t) tc := run(t)
i := 0 i := 0
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equals(t, "step-http-client/1.0", req.Header.Get("User-Agent")) // check default User-Agent header assert.Equals(t, "step-http-client/1.0", r.Header.Get("User-Agent")) // check default User-Agent header
w.Header().Set("Replay-Nonce", expectedNonce) w.Header().Set("Replay-Nonce", expectedNonce)
if i == 0 { if i == 0 {
render.JSONStatus(w, tc.r1, tc.rc1) render.JSONStatus(w, r, tc.r1, tc.rc1)
i++ i++
return return
} }
// validate jws request protected headers and body // validate jws request protected headers and body
body, err := io.ReadAll(req.Body) body, err := io.ReadAll(r.Body)
assert.FatalError(t, err) assert.FatalError(t, err)
jws, err := jose.ParseJWS(string(body)) jws, err := jose.ParseJWS(string(body))
assert.FatalError(t, err) assert.FatalError(t, err)
@ -1316,7 +1316,7 @@ func TestACMEClient_GetAccountOrders(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
assert.Equals(t, len(payload), 0) assert.Equals(t, len(payload), 0)
render.JSONStatus(w, tc.r2, tc.rc2) render.JSONStatus(w, r, tc.r2, tc.rc2)
}) })
if res, err := tc.client.GetAccountOrders(); err != nil { if res, err := tc.client.GetAccountOrders(); err != nil {
@ -1420,18 +1420,18 @@ func TestACMEClient_GetCertificate(t *testing.T) {
tc := run(t) tc := run(t)
i := 0 i := 0
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equals(t, "step-http-client/1.0", req.Header.Get("User-Agent")) // check default User-Agent header assert.Equals(t, "step-http-client/1.0", r.Header.Get("User-Agent")) // check default User-Agent header
w.Header().Set("Replay-Nonce", expectedNonce) w.Header().Set("Replay-Nonce", expectedNonce)
if i == 0 { if i == 0 {
render.JSONStatus(w, tc.r1, tc.rc1) render.JSONStatus(w, r, tc.r1, tc.rc1)
i++ i++
return return
} }
// validate jws request protected headers and body // validate jws request protected headers and body
body, err := io.ReadAll(req.Body) body, err := io.ReadAll(r.Body)
assert.FatalError(t, err) assert.FatalError(t, err)
jws, err := jose.ParseJWS(string(body)) jws, err := jose.ParseJWS(string(body))
assert.FatalError(t, err) assert.FatalError(t, err)
@ -1450,7 +1450,7 @@ func TestACMEClient_GetCertificate(t *testing.T) {
if tc.certBytes != nil { if tc.certBytes != nil {
w.Write(tc.certBytes) w.Write(tc.certBytes)
} else { } else {
render.JSONStatus(w, tc.r2, tc.rc2) render.JSONStatus(w, r, tc.r2, tc.rc2)
} }
}) })

@ -87,7 +87,7 @@ func startCAServer(configFile string) (*CA, string, error) {
func mTLSMiddleware(next http.Handler, nonAuthenticatedPaths ...string) http.Handler { func mTLSMiddleware(next http.Handler, nonAuthenticatedPaths ...string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/version" { if r.URL.Path == "/version" {
render.JSON(w, api.VersionResponse{ render.JSON(w, r, api.VersionResponse{
Version: "test", Version: "test",
RequireClientAuthentication: true, RequireClientAuthentication: true,
}) })
@ -102,7 +102,7 @@ func mTLSMiddleware(next http.Handler, nonAuthenticatedPaths ...string) http.Han
} }
isMTLS := r.TLS != nil && len(r.TLS.PeerCertificates) > 0 isMTLS := r.TLS != nil && len(r.TLS.PeerCertificates) > 0
if !isMTLS { if !isMTLS {
render.Error(w, errs.Unauthorized("missing peer certificate")) render.Error(w, r, errs.Unauthorized("missing peer certificate"))
} else { } else {
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
} }
@ -412,7 +412,7 @@ func TestBootstrapClientServerRotation(t *testing.T) {
//nolint:gosec // insecure test server //nolint:gosec // insecure test server
server, err := BootstrapServer(context.Background(), token, &http.Server{ server, err := BootstrapServer(context.Background(), token, &http.Server{
Addr: ":0", Addr: ":0",
Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("ok")) w.Write([]byte("ok"))
}), }),
}, RequireAndVerifyClientCert()) }, RequireAndVerifyClientCert())
@ -531,7 +531,7 @@ func TestBootstrapClientServerFederation(t *testing.T) {
//nolint:gosec // insecure test server //nolint:gosec // insecure test server
server, err := BootstrapServer(context.Background(), token, &http.Server{ server, err := BootstrapServer(context.Background(), token, &http.Server{
Addr: ":0", Addr: ":0",
Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("ok")) w.Write([]byte("ok"))
}), }),
}, RequireAndVerifyClientCert(), AddFederationToClientCAs()) }, RequireAndVerifyClientCert(), AddFederationToClientCAs())

@ -177,8 +177,8 @@ func TestClient_Version(t *testing.T) {
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
require.NoError(t, err) require.NoError(t, err)
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
render.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, r, tt.response, tt.responseCode)
}) })
got, err := c.Version() got, err := c.Version()
@ -218,8 +218,8 @@ func TestClient_Health(t *testing.T) {
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
require.NoError(t, err) require.NoError(t, err)
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
render.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, r, tt.response, tt.responseCode)
}) })
got, err := c.Health() got, err := c.Health()
@ -262,12 +262,12 @@ func TestClient_Root(t *testing.T) {
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
require.NoError(t, err) require.NoError(t, err)
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
expected := "/root/" + tt.shasum expected := "/root/" + tt.shasum
if req.RequestURI != expected { if r.RequestURI != expected {
t.Errorf("RequestURI = %s, want %s", req.RequestURI, expected) t.Errorf("RequestURI = %s, want %s", r.RequestURI, expected)
} }
render.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, r, tt.response, tt.responseCode)
}) })
got, err := c.Root(tt.shasum) got, err := c.Root(tt.shasum)
@ -323,12 +323,12 @@ func TestClient_Sign(t *testing.T) {
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
require.NoError(t, err) require.NoError(t, err)
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body := new(api.SignRequest) body := new(api.SignRequest)
if err := read.JSON(req.Body, body); err != nil { if err := read.JSON(r.Body, body); err != nil {
e, ok := tt.response.(error) e, ok := tt.response.(error)
require.True(t, ok, "response expected to be error type") require.True(t, ok, "response expected to be error type")
render.Error(w, e) render.Error(w, r, e)
return return
} else if !equalJSON(t, body, tt.request) { } else if !equalJSON(t, body, tt.request) {
if tt.request == nil { if tt.request == nil {
@ -339,7 +339,7 @@ func TestClient_Sign(t *testing.T) {
t.Errorf("Client.Sign() request = %v, wants %v", body, tt.request) t.Errorf("Client.Sign() request = %v, wants %v", body, tt.request)
} }
} }
render.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, r, tt.response, tt.responseCode)
}) })
got, err := c.Sign(tt.request) got, err := c.Sign(tt.request)
@ -385,12 +385,12 @@ func TestClient_Revoke(t *testing.T) {
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
require.NoError(t, err) require.NoError(t, err)
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body := new(api.RevokeRequest) body := new(api.RevokeRequest)
if err := read.JSON(req.Body, body); err != nil { if err := read.JSON(r.Body, body); err != nil {
e, ok := tt.response.(error) e, ok := tt.response.(error)
require.True(t, ok, "response expected to be error type") require.True(t, ok, "response expected to be error type")
render.Error(w, e) render.Error(w, r, e)
return return
} else if !equalJSON(t, body, tt.request) { } else if !equalJSON(t, body, tt.request) {
if tt.request == nil { if tt.request == nil {
@ -401,7 +401,7 @@ func TestClient_Revoke(t *testing.T) {
t.Errorf("Client.Revoke() request = %v, wants %v", body, tt.request) t.Errorf("Client.Revoke() request = %v, wants %v", body, tt.request)
} }
} }
render.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, r, tt.response, tt.responseCode)
}) })
got, err := c.Revoke(tt.request, nil) got, err := c.Revoke(tt.request, nil)
@ -450,8 +450,8 @@ func TestClient_Renew(t *testing.T) {
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
require.NoError(t, err) require.NoError(t, err)
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
render.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, r, tt.response, tt.responseCode)
}) })
got, err := c.Renew(nil) got, err := c.Renew(nil)
@ -504,11 +504,11 @@ func TestClient_RenewWithToken(t *testing.T) {
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
require.NoError(t, err) require.NoError(t, err)
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if req.Header.Get("Authorization") != "Bearer token" { if r.Header.Get("Authorization") != "Bearer token" {
render.JSONStatus(w, errs.InternalServer("force"), 500) render.JSONStatus(w, r, errs.InternalServer("force"), 500)
} else { } else {
render.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, r, tt.response, tt.responseCode)
} }
}) })
@ -567,8 +567,8 @@ func TestClient_Rekey(t *testing.T) {
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
require.NoError(t, err) require.NoError(t, err)
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
render.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, r, tt.response, tt.responseCode)
}) })
got, err := c.Rekey(tt.request, nil) got, err := c.Rekey(tt.request, nil)
@ -619,11 +619,11 @@ func TestClient_Provisioners(t *testing.T) {
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
require.NoError(t, err) require.NoError(t, err)
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if req.RequestURI != tt.expectedURI { if r.RequestURI != tt.expectedURI {
t.Errorf("RequestURI = %s, want %s", req.RequestURI, tt.expectedURI) t.Errorf("RequestURI = %s, want %s", r.RequestURI, tt.expectedURI)
} }
render.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, r, tt.response, tt.responseCode)
}) })
got, err := c.Provisioners(tt.args...) got, err := c.Provisioners(tt.args...)
@ -666,12 +666,12 @@ func TestClient_ProvisionerKey(t *testing.T) {
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
require.NoError(t, err) require.NoError(t, err)
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
expected := "/provisioners/" + tt.kid + "/encrypted-key" expected := "/provisioners/" + tt.kid + "/encrypted-key"
if req.RequestURI != expected { if r.RequestURI != expected {
t.Errorf("RequestURI = %s, want %s", req.RequestURI, expected) t.Errorf("RequestURI = %s, want %s", r.RequestURI, expected)
} }
render.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, r, tt.response, tt.responseCode)
}) })
got, err := c.ProvisionerKey(tt.kid) got, err := c.ProvisionerKey(tt.kid)
@ -720,8 +720,8 @@ func TestClient_Roots(t *testing.T) {
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
require.NoError(t, err) require.NoError(t, err)
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
render.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, r, tt.response, tt.responseCode)
}) })
got, err := c.Roots() got, err := c.Roots()
@ -769,8 +769,8 @@ func TestClient_Federation(t *testing.T) {
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
require.NoError(t, err) require.NoError(t, err)
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
render.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, r, tt.response, tt.responseCode)
}) })
got, err := c.Federation() got, err := c.Federation()
@ -820,8 +820,8 @@ func TestClient_SSHRoots(t *testing.T) {
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
require.NoError(t, err) require.NoError(t, err)
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
render.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, r, tt.response, tt.responseCode)
}) })
got, err := c.SSHRoots() got, err := c.SSHRoots()
@ -912,8 +912,8 @@ func TestClient_RootFingerprint(t *testing.T) {
c, err := NewClient(tt.server.URL, WithTransport(tr)) c, err := NewClient(tt.server.URL, WithTransport(tr))
require.NoError(t, err) require.NoError(t, err)
tt.server.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { tt.server.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
render.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, r, tt.response, tt.responseCode)
}) })
got, err := c.RootFingerprint() got, err := c.RootFingerprint()
@ -970,8 +970,8 @@ func TestClient_SSHBastion(t *testing.T) {
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
require.NoError(t, err) require.NoError(t, err)
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
render.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, r, tt.response, tt.responseCode)
}) })
got, err := c.SSHBastion(tt.request) got, err := c.SSHBastion(tt.request)

@ -97,7 +97,7 @@ func route(r api.Router, middleware func(next http.HandlerFunc) http.HandlerFunc
func Get(w http.ResponseWriter, r *http.Request) { func Get(w http.ResponseWriter, r *http.Request) {
req, err := decodeRequest(r) req, err := decodeRequest(r)
if err != nil { if err != nil {
fail(w, fmt.Errorf("invalid scep get request: %w", err)) fail(w, r, fmt.Errorf("invalid scep get request: %w", err))
return return
} }
@ -116,18 +116,18 @@ func Get(w http.ResponseWriter, r *http.Request) {
} }
if err != nil { if err != nil {
fail(w, fmt.Errorf("scep get request failed: %w", err)) fail(w, r, fmt.Errorf("scep get request failed: %w", err))
return return
} }
writeResponse(w, res) writeResponse(w, r, res)
} }
// Post handles all SCEP POST requests // Post handles all SCEP POST requests
func Post(w http.ResponseWriter, r *http.Request) { func Post(w http.ResponseWriter, r *http.Request) {
req, err := decodeRequest(r) req, err := decodeRequest(r)
if err != nil { if err != nil {
fail(w, fmt.Errorf("invalid scep post request: %w", err)) fail(w, r, fmt.Errorf("invalid scep post request: %w", err))
return return
} }
@ -140,11 +140,11 @@ func Post(w http.ResponseWriter, r *http.Request) {
} }
if err != nil { if err != nil {
fail(w, fmt.Errorf("scep post request failed: %w", err)) fail(w, r, fmt.Errorf("scep post request failed: %w", err))
return return
} }
writeResponse(w, res) writeResponse(w, r, res)
} }
func decodeRequest(r *http.Request) (request, error) { func decodeRequest(r *http.Request) (request, error) {
@ -274,7 +274,7 @@ func lookupProvisioner(next http.HandlerFunc) http.HandlerFunc {
name := chi.URLParam(r, "provisionerName") name := chi.URLParam(r, "provisionerName")
provisionerName, err := url.PathUnescape(name) provisionerName, err := url.PathUnescape(name)
if err != nil { if err != nil {
fail(w, fmt.Errorf("error url unescaping provisioner name '%s'", name)) fail(w, r, fmt.Errorf("error url unescaping provisioner name '%s'", name))
return return
} }
@ -282,13 +282,13 @@ func lookupProvisioner(next http.HandlerFunc) http.HandlerFunc {
auth := authority.MustFromContext(ctx) auth := authority.MustFromContext(ctx)
p, err := auth.LoadProvisionerByName(provisionerName) p, err := auth.LoadProvisionerByName(provisionerName)
if err != nil { if err != nil {
fail(w, err) fail(w, r, err)
return return
} }
prov, ok := p.(*provisioner.SCEP) prov, ok := p.(*provisioner.SCEP)
if !ok { if !ok {
fail(w, errors.New("provisioner must be of type SCEP")) fail(w, r, errors.New("provisioner must be of type SCEP"))
return return
} }
@ -430,9 +430,9 @@ func formatCapabilities(caps []string) []byte {
} }
// writeResponse writes a SCEP response back to the SCEP client. // writeResponse writes a SCEP response back to the SCEP client.
func writeResponse(w http.ResponseWriter, res Response) { func writeResponse(w http.ResponseWriter, r *http.Request, res Response) {
if res.Error != nil { if res.Error != nil {
log.Error(w, res.Error) log.Error(w, r, res.Error)
} }
if res.Certificate != nil { if res.Certificate != nil {
@ -443,8 +443,8 @@ func writeResponse(w http.ResponseWriter, res Response) {
_, _ = w.Write(res.Data) _, _ = w.Write(res.Data)
} }
func fail(w http.ResponseWriter, err error) { func fail(w http.ResponseWriter, r *http.Request, err error) {
log.Error(w, err) log.Error(w, r, err)
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
} }

Loading…
Cancel
Save