package api import ( "context" "crypto/rsa" "errors" "io" "net/http" "net/url" "strings" "github.com/go-chi/chi" "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/logging" "github.com/smallstep/nosql" "go.step.sm/crypto/jose" "go.step.sm/crypto/keyutil" ) type nextHTTP = func(http.ResponseWriter, *http.Request) func logNonce(w http.ResponseWriter, nonce string) { if rl, ok := w.(logging.ResponseLogger); ok { m := map[string]interface{}{ "nonce": nonce, } rl.WithFields(m) } } // baseURLFromRequest determines the base URL which should be used for // constructing link URLs in e.g. the ACME directory result by taking the // request Host into consideration. // // If the Request.Host is an empty string, we return an empty string, to // indicate that the configured URL values should be used instead. If this // function returns a non-empty result, then this should be used in // constructing ACME link URLs. func baseURLFromRequest(r *http.Request) *url.URL { // NOTE: See https://github.com/letsencrypt/boulder/blob/master/web/relative.go // for an implementation that allows HTTP requests using the x-forwarded-proto // header. if r.Host == "" { return nil } return &url.URL{Scheme: "https", Host: r.Host} } // baseURLFromRequest is a middleware that extracts and caches the baseURL // from the request. // E.g. https://ca.smallstep.com/ func (h *Handler) baseURLFromRequest(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := context.WithValue(r.Context(), baseURLContextKey, baseURLFromRequest(r)) next(w, r.WithContext(ctx)) } } // addNonce is a middleware that adds a nonce to the response header. func (h *Handler) addNonce(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { nonce, err := h.db.CreateNonce(r.Context()) if err != nil { api.WriteError(w, err) return } w.Header().Set("Replay-Nonce", string(nonce)) w.Header().Set("Cache-Control", "no-store") logNonce(w, string(nonce)) next(w, r) } } // addDirLink is a middleware that adds a 'Link' response reader with the // directory index url. func (h *Handler) addDirLink(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { w.Header().Add("Link", link(h.linker.GetLink(r.Context(), DirectoryLinkType), "index")) next(w, r) } } // verifyContentType is a middleware that verifies that content type is // application/jose+json. func (h *Handler) verifyContentType(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { var expected []string p, err := provisionerFromContext(r.Context()) if err != nil { api.WriteError(w, err) return } u := url.URL{Path: h.linker.GetUnescapedPathSuffix(CertificateLinkType, p.GetName(), "")} if strings.Contains(r.URL.String(), u.EscapedPath()) { // GET /certificate requests allow a greater range of content types. expected = []string{"application/jose+json", "application/pkix-cert", "application/pkcs7-mime"} } else { // By default every request should have content-type applictaion/jose+json. expected = []string{"application/jose+json"} } ct := r.Header.Get("Content-Type") for _, e := range expected { if ct == e { next(w, r) return } } api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "expected content-type to be in %s, but got %s", expected, ct)) } } // parseJWS is a middleware that parses a request body into a JSONWebSignature struct. func (h *Handler) parseJWS(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { body, err := io.ReadAll(r.Body) if err != nil { api.WriteError(w, acme.WrapErrorISE(err, "failed to read request body")) return } jws, err := jose.ParseJWS(string(body)) if err != nil { api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "failed to parse JWS from request body")) return } ctx := context.WithValue(r.Context(), jwsContextKey, jws) next(w, r.WithContext(ctx)) } } // validateJWS checks the request body for to verify that it meets ACME // requirements for a JWS. // // The JWS MUST NOT have multiple signatures // The JWS Unencoded Payload Option [RFC7797] MUST NOT be used // The JWS Unprotected Header [RFC7515] MUST NOT be used // The JWS Payload MUST NOT be detached // The JWS Protected Header MUST include the following fields: // * “alg” (Algorithm) // * This field MUST NOT contain “none” or a Message Authentication Code // (MAC) algorithm (e.g. one in which the algorithm registry description // mentions MAC/HMAC). // * “nonce” (defined in Section 6.5) // * “url” (defined in Section 6.4) // * Either “jwk” (JSON Web Key) or “kid” (Key ID) as specified below func (h *Handler) validateJWS(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() jws, err := jwsFromContext(r.Context()) if err != nil { api.WriteError(w, err) return } if len(jws.Signatures) == 0 { api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "request body does not contain a signature")) return } if len(jws.Signatures) > 1 { api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "request body contains more than one signature")) return } sig := jws.Signatures[0] uh := sig.Unprotected if len(uh.KeyID) > 0 || uh.JSONWebKey != nil || len(uh.Algorithm) > 0 || len(uh.Nonce) > 0 || len(uh.ExtraHeaders) > 0 { api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "unprotected header must not be used")) return } hdr := sig.Protected switch hdr.Algorithm { case jose.RS256, jose.RS384, jose.RS512, jose.PS256, jose.PS384, jose.PS512: if hdr.JSONWebKey != nil { switch k := hdr.JSONWebKey.Key.(type) { case *rsa.PublicKey: if k.Size() < keyutil.MinRSAKeyBytes { api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "rsa keys must be at least %d bits (%d bytes) in size", 8*keyutil.MinRSAKeyBytes, keyutil.MinRSAKeyBytes)) return } default: api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jws key type and algorithm do not match")) return } } case jose.ES256, jose.ES384, jose.ES512, jose.EdDSA: // we good default: api.WriteError(w, acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: %s", hdr.Algorithm)) return } // Check the validity/freshness of the Nonce. if err := h.db.DeleteNonce(ctx, acme.Nonce(hdr.Nonce)); err != nil { api.WriteError(w, err) return } // Check that the JWS url matches the requested url. jwsURL, ok := hdr.ExtraHeaders["url"].(string) if !ok { api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jws missing url protected header")) return } reqURL := &url.URL{Scheme: "https", Host: r.Host, Path: r.URL.Path} if jwsURL != reqURL.String() { api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "url header in JWS (%s) does not match request url (%s)", jwsURL, reqURL)) return } if hdr.JSONWebKey != nil && len(hdr.KeyID) > 0 { api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jwk and kid are mutually exclusive")) return } if hdr.JSONWebKey == nil && hdr.KeyID == "" { api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "either jwk or kid must be defined in jws protected header")) return } next(w, r) } } // extractJWK is a middleware that extracts the JWK from the JWS and saves it // in the context. Make sure to parse and validate the JWS before running this // middleware. func (h *Handler) extractJWK(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() jws, err := jwsFromContext(r.Context()) if err != nil { api.WriteError(w, err) return } jwk := jws.Signatures[0].Protected.JSONWebKey if jwk == nil { api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jwk expected in protected header")) return } if !jwk.Valid() { api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header")) return } // Overwrite KeyID with the JWK thumbprint. jwk.KeyID, err = acme.KeyToID(jwk) if err != nil { api.WriteError(w, acme.WrapErrorISE(err, "error getting KeyID from JWK")) return } // Store the JWK in the context. ctx = context.WithValue(ctx, jwkContextKey, jwk) // Get Account or continue to generate a new one. acc, err := h.db.GetAccountByKeyID(ctx, jwk.KeyID) switch { case errors.Is(err, acme.ErrNotFound): // For NewAccount requests ... break case err != nil: api.WriteError(w, err) return default: if !acc.IsValid() { api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active")) return } ctx = context.WithValue(ctx, accContextKey, acc) } next(w, r.WithContext(ctx)) } } // lookupProvisioner loads the provisioner associated with the request. // Responsds 404 if the provisioner does not exist. func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() nameEscaped := chi.URLParam(r, "provisionerID") name, err := url.PathUnescape(nameEscaped) if err != nil { api.WriteError(w, acme.WrapErrorISE(err, "error url unescaping provisioner name '%s'", nameEscaped)) return } p, err := h.ca.LoadProvisionerByName(name) if err != nil { api.WriteError(w, err) return } acmeProv, ok := p.(*provisioner.ACME) if !ok { api.WriteError(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "provisioner must be of type ACME")) return } ctx = context.WithValue(ctx, provisionerContextKey, acme.Provisioner(acmeProv)) next(w, r.WithContext(ctx)) } } // lookupJWK loads the JWK associated with the acme account referenced by the // kid parameter of the signed payload. // Make sure to parse and validate the JWS before running this middleware. func (h *Handler) lookupJWK(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() jws, err := jwsFromContext(ctx) if err != nil { api.WriteError(w, err) return } kidPrefix := h.linker.GetLink(ctx, AccountLinkType, "") kid := jws.Signatures[0].Protected.KeyID if !strings.HasPrefix(kid, kidPrefix) { api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "kid does not have required prefix; expected %s, but got %s", kidPrefix, kid)) return } accID := strings.TrimPrefix(kid, kidPrefix) acc, err := h.db.GetAccount(ctx, accID) switch { case nosql.IsErrNotFound(err): api.WriteError(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "account with ID '%s' not found", accID)) return case err != nil: api.WriteError(w, err) return default: if !acc.IsValid() { api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active")) return } ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, jwkContextKey, acc.Key) next(w, r.WithContext(ctx)) return } } } // verifyAndExtractJWSPayload extracts the JWK from the JWS and saves it in the context. // Make sure to parse and validate the JWS before running this middleware. func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() jws, err := jwsFromContext(ctx) if err != nil { api.WriteError(w, err) return } jwk, err := jwkFromContext(ctx) if err != nil { api.WriteError(w, err) return } if jwk.Algorithm != "" && jwk.Algorithm != jws.Signatures[0].Protected.Algorithm { api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "verifier and signature algorithm do not match")) return } payload, err := jws.Verify(jwk) if err != nil { api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "error verifying jws")) return } ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{ value: payload, isPostAsGet: len(payload) == 0, isEmptyJSON: string(payload) == "{}", }) next(w, r.WithContext(ctx)) } } // isPostAsGet asserts that the request is a PostAsGet (empty JWS payload). func (h *Handler) isPostAsGet(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { payload, err := payloadFromContext(r.Context()) if err != nil { api.WriteError(w, err) return } if !payload.isPostAsGet { api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "expected POST-as-GET")) return } next(w, r) } } // ContextKey is the key type for storing and searching for ACME request // essentials in the context of a request. type ContextKey string const ( // accContextKey account key accContextKey = ContextKey("acc") // baseURLContextKey baseURL key baseURLContextKey = ContextKey("baseURL") // jwsContextKey jws key jwsContextKey = ContextKey("jws") // jwkContextKey jwk key jwkContextKey = ContextKey("jwk") // payloadContextKey payload key payloadContextKey = ContextKey("payload") // provisionerContextKey provisioner key provisionerContextKey = ContextKey("provisioner") ) // accountFromContext searches the context for an ACME account. Returns the // account or an error. func accountFromContext(ctx context.Context) (*acme.Account, error) { val, ok := ctx.Value(accContextKey).(*acme.Account) if !ok || val == nil { return nil, acme.NewError(acme.ErrorAccountDoesNotExistType, "account not in context") } return val, nil } // baseURLFromContext returns the baseURL if one is stored in the context. func baseURLFromContext(ctx context.Context) *url.URL { val, ok := ctx.Value(baseURLContextKey).(*url.URL) if !ok || val == nil { return nil } return val } // jwkFromContext searches the context for a JWK. Returns the JWK or an error. func jwkFromContext(ctx context.Context) (*jose.JSONWebKey, error) { val, ok := ctx.Value(jwkContextKey).(*jose.JSONWebKey) if !ok || val == nil { return nil, acme.NewErrorISE("jwk expected in request context") } return val, nil } // jwsFromContext searches the context for a JWS. Returns the JWS or an error. func jwsFromContext(ctx context.Context) (*jose.JSONWebSignature, error) { val, ok := ctx.Value(jwsContextKey).(*jose.JSONWebSignature) if !ok || val == nil { return nil, acme.NewErrorISE("jws expected in request context") } return val, nil } // provisionerFromContext searches the context for a provisioner. Returns the // provisioner or an error. func provisionerFromContext(ctx context.Context) (acme.Provisioner, error) { val := ctx.Value(provisionerContextKey) if val == nil { return nil, acme.NewErrorISE("provisioner expected in request context") } pval, ok := val.(acme.Provisioner) if !ok || pval == nil { return nil, acme.NewErrorISE("provisioner in context is not an ACME provisioner") } return pval, nil } // payloadFromContext searches the context for a payload. Returns the payload // or an error. func payloadFromContext(ctx context.Context) (*payloadInfo, error) { val, ok := ctx.Value(payloadContextKey).(*payloadInfo) if !ok || val == nil { return nil, acme.NewErrorISE("payload expected in request context") } return val, nil }