From 80a6640103132db93fa8d45edf4c3f11a2f2533b Mon Sep 17 00:00:00 2001 From: max furman Date: Thu, 4 Mar 2021 23:10:46 -0800 Subject: [PATCH] [acme db interface] wip --- acme/account.go | 4 +- acme/api/account.go | 92 ++++---- acme/api/handler.go | 220 +++++++++++++------ acme/api/linker.go | 164 +++++++++++++++ acme/api/middleware.go | 195 ++++++++++++----- acme/api/order.go | 170 ++++++++++++--- acme/authority.go | 420 ------------------------------------- acme/authorization.go | 16 +- acme/certificate.go | 3 +- acme/challenge.go | 45 ++-- acme/common.go | 104 ++------- acme/db/nosql/account.go | 1 - acme/db/nosql/authz.go | 20 +- acme/db/nosql/challenge.go | 3 +- acme/db/nosql/order.go | 73 +++---- acme/directory.go | 148 ------------- acme/errors.go | 34 +-- acme/nonce.go | 6 + acme/order.go | 17 +- ca/acmeClient.go | 12 +- ca/ca.go | 7 +- 21 files changed, 783 insertions(+), 971 deletions(-) create mode 100644 acme/api/linker.go delete mode 100644 acme/authority.go delete mode 100644 acme/directory.go diff --git a/acme/account.go b/acme/account.go index 80cc66ef..354ebdc7 100644 --- a/acme/account.go +++ b/acme/account.go @@ -22,7 +22,7 @@ type Account struct { func (a *Account) ToLog() (interface{}, error) { b, err := json.Marshal(a) if err != nil { - return nil, ErrorWrap(ErrorServerInternalType, err, "error marshaling account for logging") + return nil, WrapErrorISE(err, "error marshaling account for logging") } return string(b), nil } @@ -46,7 +46,7 @@ func (a *Account) IsValid() bool { func KeyToID(jwk *jose.JSONWebKey) (string, error) { kid, err := jwk.Thumbprint(crypto.SHA256) if err != nil { - return "", ErrorWrap(ErrorServerInternalType, err, "error generating jwk thumbprint") + return "", WrapErrorISE(err, "error generating jwk thumbprint") } return base64.RawURLEncoding.EncodeToString(kid), nil } diff --git a/acme/api/account.go b/acme/api/account.go index 5e208a5f..16cc1f79 100644 --- a/acme/api/account.go +++ b/acme/api/account.go @@ -4,8 +4,6 @@ import ( "encoding/json" "net/http" - "github.com/go-chi/chi" - "github.com/pkg/errors" "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/logging" @@ -37,14 +35,8 @@ func (n *NewAccountRequest) Validate() error { // UpdateAccountRequest represents an update-account request. type UpdateAccountRequest struct { - Contact []string `json:"contact"` - Status string `json:"status"` -} - -// IsDeactivateRequest returns true if the update request is a deactivation -// request, false otherwise. -func (u *UpdateAccountRequest) IsDeactivateRequest() bool { - return u.Status == string(acme.StatusDeactivated) + Contact []string `json:"contact"` + Status acme.Status `json:"status"` } // Validate validates a update-account request body. @@ -59,7 +51,7 @@ func (u *UpdateAccountRequest) Validate() error { } return nil case len(u.Status) > 0: - if u.Status != string(acme.StatusDeactivated) { + if u.Status != acme.StatusDeactivated { return acme.NewError(acme.ErrorMalformedType, "cannot update account "+ "status to %s, only deactivated", u.Status) } @@ -80,7 +72,7 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { } var nar NewAccountRequest if err := json.Unmarshal(payload.value, &nar); err != nil { - api.WriteError(w, acme.ErrorWrap(acme.ErrorMalformedType, err, + api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "failed to unmarshal new-account request payload")) return } @@ -90,7 +82,7 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { } httpStatus := http.StatusCreated - acc, err := acme.AccountFromContext(r.Context()) + acc, err := accountFromContext(r.Context()) if err != nil { acmeErr, ok := err.(*acme.Error) if !ok || acmeErr.Status != http.StatusBadRequest { @@ -105,18 +97,19 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { "account does not exist")) return } - jwk, err := acme.JwkFromContext(r.Context()) + jwk, err := jwkFromContext(r.Context()) if err != nil { api.WriteError(w, err) return } - if acc, err = h.Auth.NewAccount(r.Context(), &acme.Account{ + acc := &acme.Account{ Key: jwk, Contact: nar.Contact, Status: acme.StatusValid, - }); err != nil { - api.WriteError(w, err) + } + if err := h.db.CreateAccount(r.Context(), acc); err != nil { + api.WriteError(w, acme.WrapErrorISE(err, "error creating account")) return } } else { @@ -124,14 +117,16 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { httpStatus = http.StatusOK } - w.Header().Set("Location", h.Auth.GetLink(r.Context(), acme.AccountLink, - true, acc.GetID())) + h.linker.LinkAccount(ctx, acc) + + w.Header().Set("Location", h.linker.GetLink(r.Context(), AccountLinkType, + true, acc.ID)) api.JSONStatus(w, acc, httpStatus) } // GetUpdateAccount is the api for updating an ACME account. func (h *Handler) GetUpdateAccount(w http.ResponseWriter, r *http.Request) { - acc, err := acme.AccountFromContext(r.Context()) + acc, err := accountFromContext(r.Context()) if err != nil { api.WriteError(w, err) return @@ -147,7 +142,7 @@ func (h *Handler) GetUpdateAccount(w http.ResponseWriter, r *http.Request) { if !payload.isPostAsGet { var uar UpdateAccountRequest if err := json.Unmarshal(payload.value, &uar); err != nil { - api.WriteError(w, acme.ErrorWrap(acme.ErrorMalformedType, err, + api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "failed to unmarshal new-account request payload")) return } @@ -159,18 +154,18 @@ func (h *Handler) GetUpdateAccount(w http.ResponseWriter, r *http.Request) { // If neither the status nor the contacts are being updated then ignore // the updates and return 200. This conforms with the behavior detailed // in the ACME spec (https://tools.ietf.org/html/rfc8555#section-7.3.2). - if uar.IsDeactivateRequest() { - acc, err = h.Auth.DeactivateAccount(r.Context(), acc.GetID()) - } else if len(uar.Contact) > 0 { - acc, err = h.Auth.UpdateAccount(r.Context(), acc.GetID(), uar.Contact) - } - if err != nil { - api.WriteError(w, err) + acc.Status = uar.Status + acc.Contact = uar.Contact + if err = h.db.UpdateAccount(r.Context(), acc); err != nil { + api.WriteError(w, acme.WrapErrorISE(err, "error updating account")) return } } - w.Header().Set("Location", h.Auth.GetLink(r.Context(), acme.AccountLink, - true, acc.GetID())) + + h.linker.LinkAccount(ctx, acc) + + w.Header().Set("Location", h.linker.GetLink(r.Context(), AccountLinkType, + true, acc.ID)) api.JSON(w, acc) } @@ -185,21 +180,24 @@ func logOrdersByAccount(w http.ResponseWriter, oids []string) { // GetOrdersByAccount ACME api for retrieving the list of order urls belonging to an account. func (h *Handler) GetOrdersByAccount(w http.ResponseWriter, r *http.Request) { - acc, err := acme.AccountFromContext(r.Context()) - if err != nil { - api.WriteError(w, err) - return - } - accID := chi.URLParam(r, "accID") - if acc.ID != accID { - api.WriteError(w, acme.UnauthorizedErr(errors.New("account ID does not match url param"))) - return - } - orders, err := h.Auth.GetOrdersByAccount(r.Context(), acc.GetID()) - if err != nil { - api.WriteError(w, err) - return - } - api.JSON(w, orders) - logOrdersByAccount(w, orders) + /* + acc, err := acme.AccountFromContext(r.Context()) + if err != nil { + api.WriteError(w, err) + return + } + accID := chi.URLParam(r, "accID") + if acc.ID != accID { + api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, "account ID does not match url param")) + return + } + orders, err := h.Auth.GetOrdersByAccount(r.Context(), acc.GetID()) + if err != nil { + api.WriteError(w, err) + return + } + api.JSON(w, orders) + logOrdersByAccount(w, orders) + */ + return } diff --git a/acme/api/handler.go b/acme/api/handler.go index 921e614e..997456a7 100644 --- a/acme/api/handler.go +++ b/acme/api/handler.go @@ -1,56 +1,82 @@ package api import ( - "context" - "crypto/x509" - "encoding/pem" + "crypto/tls" + "encoding/json" "fmt" + "net" "net/http" + "time" "github.com/go-chi/chi" - "github.com/pkg/errors" "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/authority/provisioner" ) func link(url, typ string) string { return fmt.Sprintf("<%s>;rel=\"%s\"", url, typ) } +// Clock that returns time in UTC rounded to seconds. +type Clock int + +// Now returns the UTC time rounded to seconds. +func (c *Clock) Now() time.Time { + return time.Now().UTC().Round(time.Second) +} + +var clock = new(Clock) + type payloadInfo struct { value []byte isPostAsGet bool isEmptyJSON bool } -// payloadFromContext searches the context for a payload. Returns the payload -// or an error. -func payloadFromContext(ctx context.Context) (*payloadInfo, error) { - val, ok := ctx.Value(acme.PayloadContextKey).(*payloadInfo) - if !ok || val == nil { - return nil, acme.ServerInternalErr(errors.Errorf("payload expected in request context")) - } - return val, nil +// Handler is the ACME API request handler. +type Handler struct { + db acme.DB + backdate provisioner.Duration + ca acme.CertificateAuthority + linker *Linker } -// New returns a new ACME API router. -func New(acmeAuth acme.Interface) api.RouterHandler { - return &Handler{acmeAuth} +// HandlerOptions required to create a new ACME API request handler. +type HandlerOptions struct { + Backdate provisioner.Duration + // DB storage backend that impements the acme.DB interface. + DB acme.DB + // DNS the host used to generate accurate ACME links. By default the authority + // will use the Host from the request, so this value will only be used if + // request.Host is empty. + DNS string + // Prefix is a URL path prefix under which the ACME api is served. This + // prefix is required to generate accurate ACME links. + // E.g. https://ca.smallstep.com/acme/my-acme-provisioner/new-account -- + // "acme" is the prefix from which the ACME api is accessed. + Prefix string + CA acme.CertificateAuthority } -// Handler is the ACME request handler. -type Handler struct { - Auth acme.Interface +// NewHandler returns a new ACME API handler. +func NewHandler(ops HandlerOptions) api.RouterHandler { + return &Handler{ + ca: ops.CA, + db: ops.DB, + backdate: ops.Backdate, + linker: NewLinker(ops.DNS, ops.Prefix), + } } // Route traffic and implement the Router interface. func (h *Handler) Route(r api.Router) { - getLink := h.Auth.GetLinkExplicit + getLink := h.linker.GetLinkExplicit // Standard ACME API - r.MethodFunc("GET", getLink(acme.NewNonceLink, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetNonce)))) - r.MethodFunc("HEAD", getLink(acme.NewNonceLink, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetNonce)))) - r.MethodFunc("GET", getLink(acme.DirectoryLink, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetDirectory)))) - r.MethodFunc("HEAD", getLink(acme.DirectoryLink, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetDirectory)))) + r.MethodFunc("GET", getLink(NewNonceLinkType, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetNonce)))) + r.MethodFunc("HEAD", getLink(NewNonceLinkType, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetNonce)))) + r.MethodFunc("GET", getLink(DirectoryLinkType, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetDirectory)))) + r.MethodFunc("HEAD", getLink(DirectoryLinkType, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetDirectory)))) extractPayloadByJWK := func(next nextHTTP) nextHTTP { return h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(h.extractJWK(h.verifyAndExtractJWSPayload(next))))))))) @@ -59,16 +85,16 @@ func (h *Handler) Route(r api.Router) { return h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(h.lookupJWK(h.verifyAndExtractJWSPayload(next))))))))) } - r.MethodFunc("POST", getLink(acme.NewAccountLink, "{provisionerID}", false, nil), extractPayloadByJWK(h.NewAccount)) - r.MethodFunc("POST", getLink(acme.AccountLink, "{provisionerID}", false, nil, "{accID}"), extractPayloadByKid(h.GetUpdateAccount)) - r.MethodFunc("POST", getLink(acme.KeyChangeLink, "{provisionerID}", false, nil, "{accID}"), extractPayloadByKid(h.NotImplemented)) - r.MethodFunc("POST", getLink(acme.NewOrderLink, "{provisionerID}", false, nil), extractPayloadByKid(h.NewOrder)) - r.MethodFunc("POST", getLink(acme.OrderLink, "{provisionerID}", false, nil, "{ordID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrder))) - r.MethodFunc("POST", getLink(acme.OrdersByAccountLink, "{provisionerID}", false, nil, "{accID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrdersByAccount))) - r.MethodFunc("POST", getLink(acme.FinalizeLink, "{provisionerID}", false, nil, "{ordID}"), extractPayloadByKid(h.FinalizeOrder)) - r.MethodFunc("POST", getLink(acme.AuthzLink, "{provisionerID}", false, nil, "{authzID}"), extractPayloadByKid(h.isPostAsGet(h.GetAuthz))) - r.MethodFunc("POST", getLink(acme.ChallengeLink, "{provisionerID}", false, nil, "{chID}"), extractPayloadByKid(h.GetChallenge)) - r.MethodFunc("POST", getLink(acme.CertificateLink, "{provisionerID}", false, nil, "{certID}"), extractPayloadByKid(h.isPostAsGet(h.GetCertificate))) + r.MethodFunc("POST", getLink(NewAccountLinkType, "{provisionerID}", false, nil), extractPayloadByJWK(h.NewAccount)) + r.MethodFunc("POST", getLink(AccountLinkType, "{provisionerID}", false, nil, "{accID}"), extractPayloadByKid(h.GetUpdateAccount)) + r.MethodFunc("POST", getLink(KeyChangeLinkType, "{provisionerID}", false, nil, "{accID}"), extractPayloadByKid(h.NotImplemented)) + r.MethodFunc("POST", getLink(NewOrderLinkType, "{provisionerID}", false, nil), extractPayloadByKid(h.NewOrder)) + r.MethodFunc("POST", getLink(OrderLinkType, "{provisionerID}", false, nil, "{ordID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrder))) + r.MethodFunc("POST", getLink(OrdersByAccountLinkType, "{provisionerID}", false, nil, "{accID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrdersByAccount))) + r.MethodFunc("POST", getLink(FinalizeLinkType, "{provisionerID}", false, nil, "{ordID}"), extractPayloadByKid(h.FinalizeOrder)) + r.MethodFunc("POST", getLink(AuthzLinkType, "{provisionerID}", false, nil, "{authzID}"), extractPayloadByKid(h.isPostAsGet(h.GetAuthz))) + r.MethodFunc("POST", getLink(ChallengeLinkType, "{provisionerID}", false, nil, "{authzID}", "{chID}"), extractPayloadByKid(h.GetChallenge)) + r.MethodFunc("POST", getLink(CertificateLinkType, "{provisionerID}", false, nil, "{certID}"), extractPayloadByKid(h.isPostAsGet(h.GetCertificate))) } // GetNonce just sets the right header since a Nonce is added to each response @@ -81,101 +107,165 @@ func (h *Handler) GetNonce(w http.ResponseWriter, r *http.Request) { } } +// Directory represents an ACME directory for configuring clients. +type Directory struct { + NewNonce string `json:"newNonce,omitempty"` + NewAccount string `json:"newAccount,omitempty"` + NewOrder string `json:"newOrder,omitempty"` + NewAuthz string `json:"newAuthz,omitempty"` + RevokeCert string `json:"revokeCert,omitempty"` + KeyChange string `json:"keyChange,omitempty"` +} + +// ToLog enables response logging for the Directory type. +func (d *Directory) ToLog() (interface{}, error) { + b, err := json.Marshal(d) + if err != nil { + return nil, acme.WrapErrorISE(err, "error marshaling directory for logging") + } + return string(b), nil +} + +type directory struct { + prefix, dns string +} + // GetDirectory is the ACME resource for returning a directory configuration // for client configuration. func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) { - dir, err := h.Auth.GetDirectory(r.Context()) - if err != nil { - api.WriteError(w, err) - return - } - api.JSON(w, dir) + ctx := r.Context() + api.JSON(w, &Directory{ + NewNonce: h.linker.GetLink(ctx, NewNonceLinkType, true), + NewAccount: h.linker.GetLink(ctx, NewAccountLinkType, true), + NewOrder: h.linker.GetLink(ctx, NewOrderLinkType, true), + RevokeCert: h.linker.GetLink(ctx, RevokeCertLinkType, true), + KeyChange: h.linker.GetLink(ctx, KeyChangeLinkType, true), + }) } // NotImplemented returns a 501 and is generally a placeholder for functionality which // MAY be added at some point in the future but is not in any way a guarantee of such. func (h *Handler) NotImplemented(w http.ResponseWriter, r *http.Request) { - api.WriteError(w, acme.NotImplemented(nil).ToACME()) + api.WriteError(w, acme.NewError(acme.ErrorNotImplementedType, "this API is not implemented")) } // GetAuthz ACME api for retrieving an Authz. func (h *Handler) GetAuthz(w http.ResponseWriter, r *http.Request) { - acc, err := acme.AccountFromContext(r.Context()) + ctx := r.Context() + acc, err := accountFromContext(ctx) if err != nil { api.WriteError(w, err) return } - authz, err := h.Auth.GetAuthz(r.Context(), acc.GetID(), chi.URLParam(r, "authzID")) + az, err := h.db.GetAuthorization(ctx, chi.URLParam(r, "authzID")) if err != nil { - api.WriteError(w, err) + api.WriteError(w, acme.WrapErrorISE(err, "error retrieving authorization")) + return + } + if acc.ID != az.AccountID { + api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, + "account '%s' does not own authorization '%s'", acc.ID, az.ID)) return } + if err = az.UpdateStatus(ctx, h.db); err != nil { + api.WriteError(w, acme.WrapErrorISE(err, "error updating authorization status")) + } + + h.linker.LinkAuthorization(ctx, az) - w.Header().Set("Location", h.Auth.GetLink(r.Context(), acme.AuthzLink, true, authz.GetID())) - api.JSON(w, authz) + w.Header().Set("Location", h.linker.GetLink(ctx, AuthzLinkType, true, az.ID)) + api.JSON(w, az) } // GetChallenge ACME api for retrieving a Challenge. func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) { - acc, err := acme.AccountFromContext(r.Context()) + ctx := r.Context() + acc, err := accountFromContext(ctx) if err != nil { api.WriteError(w, err) return } // Just verify that the payload was set, since we're not strictly adhering // to ACME V2 spec for reasons specified below. - _, err = payloadFromContext(r.Context()) + _, err = payloadFromContext(ctx) if err != nil { api.WriteError(w, err) return } - // NOTE: We should be checking that the request is either a POST-as-GET, or + // NOTE: We should be checking ^^^ that the request is either a POST-as-GET, or // that the payload is an empty JSON block ({}). However, older ACME clients // still send a vestigial body (rather than an empty JSON block) and // strict enforcement would render these clients broken. For the time being // we'll just ignore the body. - var ( - ch *acme.Challenge - chID = chi.URLParam(r, "chID") - ) - ch, err = h.Auth.ValidateChallenge(r.Context(), acc.GetID(), chID, acc.GetKey()) + + ch, err := h.db.GetChallenge(ctx, chi.URLParam(r, "chID"), chi.URLParam(r, "authzID")) + if err != nil { + api.WriteError(w, acme.WrapErrorISE(err, "error retrieving challenge")) + return + } + if acc.ID != ch.AccountID { + api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, + "account '%s' does not own challenge '%s'", acc.ID, ch.ID)) + return + } + client := http.Client{ + Timeout: time.Duration(30 * time.Second), + } + dialer := &net.Dialer{ + Timeout: 30 * time.Second, + } + jwk, err := jwkFromContext(ctx) if err != nil { api.WriteError(w, err) return } + if err = ch.Validate(ctx, h.db, jwk, acme.ValidateOptions{ + HTTPGet: client.Get, + LookupTxt: net.LookupTXT, + TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { + return tls.DialWithDialer(dialer, network, addr, config) + }, + }); err != nil { + api.WriteError(w, acme.WrapErrorISE(err, "error validating challenge")) + return + } - w.Header().Add("Link", link(h.Auth.GetLink(r.Context(), acme.AuthzLink, true, ch.GetAuthzID()), "up")) - w.Header().Set("Location", h.Auth.GetLink(r.Context(), acme.ChallengeLink, true, ch.GetID())) + h.linker.LinkChallenge(ctx, ch) + + w.Header().Add("Link", link(h.linker.GetLink(ctx, AuthzLinkType, true, ch.AuthzID), "up")) + w.Header().Set("Location", h.linker.GetLink(ctx, ChallengeLinkType, true, ch.AuthzID, ch.ID)) api.JSON(w, ch) } // GetCertificate ACME api for retrieving a Certificate. func (h *Handler) GetCertificate(w http.ResponseWriter, r *http.Request) { - acc, err := acme.AccountFromContext(r.Context()) + ctx := r.Context() + acc, err := accountFromContext(ctx) if err != nil { api.WriteError(w, err) return } certID := chi.URLParam(r, "certID") - certBytes, err := h.Auth.GetCertificate(acc.GetID(), certID) + + cert, err := h.db.GetCertificate(ctx, certID) if err != nil { - api.WriteError(w, err) + api.WriteError(w, acme.WrapErrorISE(err, "error retrieving certificate")) return } - - block, _ := pem.Decode(certBytes) - if block == nil { - api.WriteError(w, acme.ServerInternalErr(errors.New("failed to decode any certificates from generated certBytes"))) + if cert.AccountID != acc.ID { + api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, + "account '%s' does not own certificate '%s'", acc.ID, certID)) return } - cert, err := x509.ParseCertificate(block.Bytes) + + certBytes, err := cert.ToACME() if err != nil { - api.WriteError(w, acme.Wrap(err, "failed to parse generated leaf certificate")) + api.WriteError(w, acme.WrapErrorISE(err, "error converting cert to ACME representation")) return } - api.LogCertificate(w, cert) + api.LogCertificate(w, cert.Leaf) w.Header().Set("Content-Type", "application/pem-certificate-chain; charset=utf-8") w.Write(certBytes) } diff --git a/acme/api/linker.go b/acme/api/linker.go new file mode 100644 index 00000000..dd3b4540 --- /dev/null +++ b/acme/api/linker.go @@ -0,0 +1,164 @@ +package api + +import ( + "context" + "fmt" + "net/url" + + "github.com/smallstep/certificates/acme" +) + +// NewLinker returns a new Directory type. +func NewLinker(dns, prefix string) *Linker { + return &Linker{Prefix: prefix, DNS: dns} +} + +// Linker generates ACME links. +type Linker struct { + Prefix string + DNS string +} + +// GetLink is a helper for GetLinkExplicit +func (l *Linker) GetLink(ctx context.Context, typ LinkType, abs bool, inputs ...string) string { + var provName string + if p, err := provisionerFromContext(ctx); err == nil && p != nil { + provName = p.GetName() + } + return l.GetLinkExplicit(typ, provName, abs, baseURLFromContext(ctx), inputs...) +} + +// GetLinkExplicit returns an absolute or partial path to the given resource and a base +// URL dynamically obtained from the request for which the link is being +// calculated. +func (l *Linker) GetLinkExplicit(typ LinkType, provisionerName string, abs bool, baseURL *url.URL, inputs ...string) string { + var link string + switch typ { + case NewNonceLinkType, NewAccountLinkType, NewOrderLinkType, NewAuthzLinkType, DirectoryLinkType, KeyChangeLinkType, RevokeCertLinkType: + link = fmt.Sprintf("/%s/%s", provisionerName, typ) + case AccountLinkType, OrderLinkType, AuthzLinkType, CertificateLinkType: + link = fmt.Sprintf("/%s/%s/%s", provisionerName, typ, inputs[0]) + case ChallengeLinkType: + link = fmt.Sprintf("/%s/%s/%s/%s", provisionerName, typ, inputs[0], inputs[1]) + case OrdersByAccountLinkType: + link = fmt.Sprintf("/%s/%s/%s/orders", provisionerName, AccountLinkType, inputs[0]) + case FinalizeLinkType: + link = fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLinkType, inputs[0]) + } + + if abs { + // Copy the baseURL value from the pointer. https://github.com/golang/go/issues/38351 + u := url.URL{} + if baseURL != nil { + u = *baseURL + } + + // If no Scheme is set, then default to https. + if u.Scheme == "" { + u.Scheme = "https" + } + + // If no Host is set, then use the default (first DNS attr in the ca.json). + if u.Host == "" { + u.Host = l.DNS + } + + u.Path = l.Prefix + link + return u.String() + } + return link +} + +// LinkType captures the link type. +type LinkType int + +const ( + // NewNonceLinkType new-nonce + NewNonceLinkType LinkType = iota + // NewAccountLinkType new-account + NewAccountLinkType + // AccountLinkType account + AccountLinkType + // OrderLinkType order + OrderLinkType + // NewOrderLinkType new-order + NewOrderLinkType + // OrdersByAccountLinkType list of orders owned by account + OrdersByAccountLinkType + // FinalizeLinkType finalize order + FinalizeLinkType + // NewAuthzLinkType authz + NewAuthzLinkType + // AuthzLinkType new-authz + AuthzLinkType + // ChallengeLinkType challenge + ChallengeLinkType + // CertificateLinkType certificate + CertificateLinkType + // DirectoryLinkType directory + DirectoryLinkType + // RevokeCertLinkType revoke certificate + RevokeCertLinkType + // KeyChangeLinkType key rollover + KeyChangeLinkType +) + +func (l LinkType) String() string { + switch l { + case NewNonceLinkType: + return "new-nonce" + case NewAccountLinkType: + return "new-account" + case AccountLinkType: + return "account" + case NewOrderLinkType: + return "new-order" + case OrderLinkType: + return "order" + case NewAuthzLinkType: + return "new-authz" + case AuthzLinkType: + return "authz" + case ChallengeLinkType: + return "challenge" + case CertificateLinkType: + return "certificate" + case DirectoryLinkType: + return "directory" + case RevokeCertLinkType: + return "revoke-cert" + case KeyChangeLinkType: + return "key-change" + default: + return fmt.Sprintf("unexpected LinkType '%d'", int(l)) + } +} + +// LinkOrder sets the ACME links required by an ACME order. +func (l *Linker) LinkOrder(ctx context.Context, o *acme.Order) { + o.azURLs = make([]string, len(o.AuthorizationIDs)) + for i, azID := range o.AutohrizationIDs { + o.azURLs[i] = l.GetLink(ctx, AuthzLinkType, true, azID) + } + o.FinalizeURL = l.GetLink(ctx, FinalizeLinkType, true, o.ID) + if o.CertificateID != "" { + o.CertificateURL = l.GetLink(ctx, CertificateLinkType, true, o.CertificateID) + } +} + +// LinkAccount sets the ACME links required by an ACME account. +func (l *Linker) LinkAccount(ctx context.Context, acc *acme.Account) { + a.Orders = l.GetLink(ctx, OrdersByAccountLinkType, true, acc.ID) +} + +// LinkChallenge sets the ACME links required by an ACME account. +func (l *Linker) LinkChallenge(ctx context.Context, ch *acme.Challenge) { + a.URL = l.GetLink(ctx, ChallengeLinkType, true, ch.AuthzID, ch.ID) +} + +// LinkAuthorization sets the ACME links required by an ACME account. +func (l *Linker) LinkAuthorization(ctx context.Context, az *acme.Authorization) { + for _, ch := range az.Challenges { + l.LinkChallenge(ctx, ch) + } +} diff --git a/acme/api/middleware.go b/acme/api/middleware.go index 3bf5f89a..7a3529cd 100644 --- a/acme/api/middleware.go +++ b/acme/api/middleware.go @@ -9,7 +9,6 @@ import ( "strings" "github.com/go-chi/chi" - "github.com/pkg/errors" "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/authority/provisioner" @@ -54,7 +53,7 @@ func baseURLFromRequest(r *http.Request) *url.URL { // E.g. https://ca.smallstep.com/ func (h *Handler) baseURLFromRequest(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { - ctx := context.WithValue(r.Context(), acme.BaseURLContextKey, baseURLFromRequest(r)) + ctx := context.WithValue(r.Context(), baseURLContextKey, baseURLFromRequest(r)) next(w, r.WithContext(ctx)) } } @@ -62,14 +61,14 @@ func (h *Handler) baseURLFromRequest(next nextHTTP) nextHTTP { // addNonce is a middleware that adds a nonce to the response header. func (h *Handler) addNonce(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { - nonce, err := h.Auth.NewNonce() + nonce, err := h.db.CreateNonce(r.Context()) if err != nil { api.WriteError(w, err) return } - w.Header().Set("Replay-Nonce", nonce) + w.Header().Set("Replay-Nonce", string(nonce)) w.Header().Set("Cache-Control", "no-store") - logNonce(w, nonce) + logNonce(w, string(nonce)) next(w, r) } } @@ -78,8 +77,8 @@ func (h *Handler) addNonce(next nextHTTP) nextHTTP { // directory index url. func (h *Handler) addDirLink(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { - w.Header().Add("Link", link(h.Auth.GetLink(r.Context(), - acme.DirectoryLink, true), "index")) + w.Header().Add("Link", link(h.linker.GetLink(r.Context(), + DirectoryLinkType, true), "index")) next(w, r) } } @@ -90,7 +89,7 @@ func (h *Handler) verifyContentType(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ct := r.Header.Get("Content-Type") var expected []string - if strings.Contains(r.URL.Path, h.Auth.GetLink(r.Context(), acme.CertificateLink, false, "")) { + if strings.Contains(r.URL.Path, h.linker.GetLink(r.Context(), CertificateLinkType, false, "")) { // GET /certificate requests allow a greater range of content types. expected = []string{"application/jose+json", "application/pkix-cert", "application/pkcs7-mime"} } else { @@ -103,8 +102,8 @@ func (h *Handler) verifyContentType(next nextHTTP) nextHTTP { return } } - api.WriteError(w, acme.MalformedErr(errors.Errorf( - "expected content-type to be in %s, but got %s", expected, ct))) + api.WriteError(w, acme.NewError(acme.ErrorMalformedType, + "expected content-type to be in %s, but got %s", expected, ct)) } } @@ -113,15 +112,15 @@ func (h *Handler) parseJWS(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { body, err := ioutil.ReadAll(r.Body) if err != nil { - api.WriteError(w, acme.ServerInternalErr(errors.Wrap(err, "failed to read request body"))) + api.WriteError(w, acme.WrapErrorISE(err, "failed to read request body")) return } jws, err := jose.ParseJWS(string(body)) if err != nil { - api.WriteError(w, acme.MalformedErr(errors.Wrap(err, "failed to parse JWS from request body"))) + api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "failed to parse JWS from request body")) return } - ctx := context.WithValue(r.Context(), acme.JwsContextKey, jws) + ctx := context.WithValue(r.Context(), jwsContextKey, jws) next(w, r.WithContext(ctx)) } } @@ -143,17 +142,18 @@ func (h *Handler) parseJWS(next nextHTTP) nextHTTP { // * Either “jwk” (JSON Web Key) or “kid” (Key ID) as specified below func (h *Handler) validateJWS(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { - jws, err := acme.JwsFromContext(r.Context()) + ctx := r.Context() + jws, err := jwsFromContext(r.Context()) if err != nil { api.WriteError(w, err) return } if len(jws.Signatures) == 0 { - api.WriteError(w, acme.MalformedErr(errors.Errorf("request body does not contain a signature"))) + api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "request body does not contain a signature")) return } if len(jws.Signatures) > 1 { - api.WriteError(w, acme.MalformedErr(errors.Errorf("request body contains more than one signature"))) + api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "request body contains more than one signature")) return } @@ -164,7 +164,7 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP { len(uh.Algorithm) > 0 || len(uh.Nonce) > 0 || len(uh.ExtraHeaders) > 0 { - api.WriteError(w, acme.MalformedErr(errors.Errorf("unprotected header must not be used"))) + api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "unprotected header must not be used")) return } hdr := sig.Protected @@ -174,25 +174,26 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP { switch k := hdr.JSONWebKey.Key.(type) { case *rsa.PublicKey: if k.Size() < keyutil.MinRSAKeyBytes { - api.WriteError(w, acme.MalformedErr(errors.Errorf("rsa "+ - "keys must be at least %d bits (%d bytes) in size", - 8*keyutil.MinRSAKeyBytes, keyutil.MinRSAKeyBytes))) + api.WriteError(w, acme.NewError(acme.ErrorMalformedType, + "rsa keys must be at least %d bits (%d bytes) in size", + 8*keyutil.MinRSAKeyBytes, keyutil.MinRSAKeyBytes)) return } default: - api.WriteError(w, acme.MalformedErr(errors.Errorf("jws key type and algorithm do not match"))) + api.WriteError(w, acme.NewError(acme.ErrorMalformedType, + "jws key type and algorithm do not match")) return } } case jose.ES256, jose.ES384, jose.ES512, jose.EdDSA: // we good default: - api.WriteError(w, acme.MalformedErr(errors.Errorf("unsuitable algorithm: %s", hdr.Algorithm))) + api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "unsuitable algorithm: %s", hdr.Algorithm)) return } // Check the validity/freshness of the Nonce. - if err := h.Auth.UseNonce(hdr.Nonce); err != nil { + if err := h.db.DeleteNonce(ctx, acme.Nonce(hdr.Nonce)); err != nil { api.WriteError(w, err) return } @@ -200,21 +201,22 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP { // Check that the JWS url matches the requested url. jwsURL, ok := hdr.ExtraHeaders["url"].(string) if !ok { - api.WriteError(w, acme.MalformedErr(errors.Errorf("jws missing url protected header"))) + api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jws missing url protected header")) return } reqURL := &url.URL{Scheme: "https", Host: r.Host, Path: r.URL.Path} if jwsURL != reqURL.String() { - api.WriteError(w, acme.MalformedErr(errors.Errorf("url header in JWS (%s) does not match request url (%s)", jwsURL, reqURL))) + api.WriteError(w, acme.NewError(acme.ErrorMalformedType, + "url header in JWS (%s) does not match request url (%s)", jwsURL, reqURL)) return } if hdr.JSONWebKey != nil && len(hdr.KeyID) > 0 { - api.WriteError(w, acme.MalformedErr(errors.Errorf("jwk and kid are mutually exclusive"))) + api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jwk and kid are mutually exclusive")) return } if hdr.JSONWebKey == nil && len(hdr.KeyID) == 0 { - api.WriteError(w, acme.MalformedErr(errors.Errorf("either jwk or kid must be defined in jws protected header"))) + api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "either jwk or kid must be defined in jws protected header")) return } next(w, r) @@ -227,22 +229,27 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP { func (h *Handler) extractJWK(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - jws, err := acme.JwsFromContext(r.Context()) + jws, err := jwsFromContext(r.Context()) if err != nil { api.WriteError(w, err) return } jwk := jws.Signatures[0].Protected.JSONWebKey if jwk == nil { - api.WriteError(w, acme.MalformedErr(errors.Errorf("jwk expected in protected header"))) + api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jwk expected in protected header")) return } if !jwk.Valid() { - api.WriteError(w, acme.MalformedErr(errors.Errorf("invalid jwk in protected header"))) + api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header")) + return + } + ctx = context.WithValue(ctx, jwkContextKey, jwk) + kid, err := acme.KeyToID(jwk) + if err != nil { + api.WriteError(w, acme.WrapErrorISE(err, "error getting KeyID from JWK")) return } - ctx = context.WithValue(ctx, acme.JwkContextKey, jwk) - acc, err := h.Auth.GetAccountByKey(ctx, jwk) + acc, err := h.db.GetAccountByKeyID(ctx, kid) switch { case nosql.IsErrNotFound(err): // For NewAccount requests ... @@ -252,10 +259,10 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP { return default: if !acc.IsValid() { - api.WriteError(w, acme.UnauthorizedErr(errors.New("account is not active"))) + api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active")) return } - ctx = context.WithValue(ctx, acme.AccContextKey, acc) + ctx = context.WithValue(ctx, accContextKey, acc) } next(w, r.WithContext(ctx)) } @@ -270,20 +277,20 @@ func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP { name := chi.URLParam(r, "provisionerID") provID, err := url.PathUnescape(name) if err != nil { - api.WriteError(w, acme.ServerInternalErr(errors.Wrapf(err, "error url unescaping provisioner id '%s'", name))) + api.WriteError(w, acme.WrapErrorISE(err, "error url unescaping provisioner id '%s'", name)) return } - p, err := h.Auth.LoadProvisionerByID("acme/" + provID) + p, err := h.ca.LoadProvisionerByID("acme/" + provID) if err != nil { api.WriteError(w, err) return } acmeProv, ok := p.(*provisioner.ACME) if !ok { - api.WriteError(w, acme.AccountDoesNotExistErr(errors.New("provisioner must be of type ACME"))) + api.WriteError(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "provisioner must be of type ACME")) return } - ctx = context.WithValue(ctx, acme.ProvisionerContextKey, acme.Provisioner(acmeProv)) + ctx = context.WithValue(ctx, provisionerContextKey, acme.Provisioner(acmeProv)) next(w, r.WithContext(ctx)) } } @@ -294,36 +301,37 @@ func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP { func (h *Handler) lookupJWK(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - jws, err := acme.JwsFromContext(ctx) + jws, err := jwsFromContext(ctx) if err != nil { api.WriteError(w, err) return } - kidPrefix := h.Auth.GetLink(ctx, acme.AccountLink, true, "") + kidPrefix := h.linker.GetLink(ctx, AccountLinkType, true, "") kid := jws.Signatures[0].Protected.KeyID if !strings.HasPrefix(kid, kidPrefix) { - api.WriteError(w, acme.MalformedErr(errors.Errorf("kid does not have "+ - "required prefix; expected %s, but got %s", kidPrefix, kid))) + api.WriteError(w, acme.NewError(acme.ErrorMalformedType, + "kid does not have required prefix; expected %s, but got %s", + kidPrefix, kid)) return } accID := strings.TrimPrefix(kid, kidPrefix) - acc, err := h.Auth.GetAccount(r.Context(), accID) + acc, err := h.db.GetAccount(ctx, accID) switch { case nosql.IsErrNotFound(err): - api.WriteError(w, acme.AccountDoesNotExistErr(nil)) + api.WriteError(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "account with ID '%s' not found", accID)) return case err != nil: api.WriteError(w, err) return default: if !acc.IsValid() { - api.WriteError(w, acme.UnauthorizedErr(errors.New("account is not active"))) + api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active")) return } - ctx = context.WithValue(ctx, acme.AccContextKey, acc) - ctx = context.WithValue(ctx, acme.JwkContextKey, acc.Key) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, jwkContextKey, acc.Key) next(w, r.WithContext(ctx)) return } @@ -334,26 +342,27 @@ func (h *Handler) lookupJWK(next nextHTTP) nextHTTP { // Make sure to parse and validate the JWS before running this middleware. func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { - jws, err := acme.JwsFromContext(r.Context()) + ctx := r.Context() + jws, err := jwsFromContext(ctx) if err != nil { api.WriteError(w, err) return } - jwk, err := acme.JwkFromContext(r.Context()) + jwk, err := jwkFromContext(ctx) if err != nil { api.WriteError(w, err) return } if len(jwk.Algorithm) != 0 && jwk.Algorithm != jws.Signatures[0].Protected.Algorithm { - api.WriteError(w, acme.MalformedErr(errors.New("verifier and signature algorithm do not match"))) + api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "verifier and signature algorithm do not match")) return } payload, err := jws.Verify(jwk) if err != nil { - api.WriteError(w, acme.MalformedErr(errors.Wrap(err, "error verifying jws"))) + api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "error verifying jws")) return } - ctx := context.WithValue(r.Context(), acme.PayloadContextKey, &payloadInfo{ + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{ value: payload, isPostAsGet: string(payload) == "", isEmptyJSON: string(payload) == "{}", @@ -371,9 +380,89 @@ func (h *Handler) isPostAsGet(next nextHTTP) nextHTTP { return } if !payload.isPostAsGet { - api.WriteError(w, acme.MalformedErr(errors.Errorf("expected POST-as-GET"))) + api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "expected POST-as-GET")) return } next(w, r) } } + +// ContextKey is the key type for storing and searching for ACME request +// essentials in the context of a request. +type ContextKey string + +const ( + // accContextKey account key + accContextKey = ContextKey("acc") + // baseURLContextKey baseURL key + baseURLContextKey = ContextKey("baseURL") + // jwsContextKey jws key + jwsContextKey = ContextKey("jws") + // jwkContextKey jwk key + jwkContextKey = ContextKey("jwk") + // payloadContextKey payload key + payloadContextKey = ContextKey("payload") + // provisionerContextKey provisioner key + provisionerContextKey = ContextKey("provisioner") +) + +// accountFromContext searches the context for an ACME account. Returns the +// account or an error. +func accountFromContext(ctx context.Context) (*acme.Account, error) { + val, ok := ctx.Value(accContextKey).(*acme.Account) + if !ok || val == nil { + return nil, acme.NewErrorISE("account not in context") + } + return val, nil +} + +// baseURLFromContext returns the baseURL if one is stored in the context. +func baseURLFromContext(ctx context.Context) *url.URL { + val, ok := ctx.Value(baseURLContextKey).(*url.URL) + if !ok || val == nil { + return nil + } + return val +} + +// jwkFromContext searches the context for a JWK. Returns the JWK or an error. +func jwkFromContext(ctx context.Context) (*jose.JSONWebKey, error) { + val, ok := ctx.Value(jwkContextKey).(*jose.JSONWebKey) + if !ok || val == nil { + return nil, acme.NewErrorISE("jwk expected in request context") + } + return val, nil +} + +// jwsFromContext searches the context for a JWS. Returns the JWS or an error. +func jwsFromContext(ctx context.Context) (*jose.JSONWebSignature, error) { + val, ok := ctx.Value(jwsContextKey).(*jose.JSONWebSignature) + if !ok || val == nil { + return nil, acme.NewErrorISE("jws expected in request context") + } + return val, nil +} + +// provisionerFromContext searches the context for a provisioner. Returns the +// provisioner or an error. +func provisionerFromContext(ctx context.Context) (acme.Provisioner, error) { + val := ctx.Value(provisionerContextKey) + if val == nil { + return nil, acme.NewErrorISE("provisioner expected in request context") + } + pval, ok := val.(acme.Provisioner) + if !ok || pval == nil { + return nil, acme.NewErrorISE("provisioner in context is not an ACME provisioner") + } + return pval, nil +} + +// payloadFromContext searches the context for a payload. Returns the payload +// or an error. +func payloadFromContext(ctx context.Context) (*payloadInfo, error) { + val, ok := ctx.Value(payloadContextKey).(*payloadInfo) + if !ok || val == nil { + return nil, acme.NewErrorISE("payload expected in request context") + } + return val, nil +} diff --git a/acme/api/order.go b/acme/api/order.go index 1fead85c..2bf7d2ef 100644 --- a/acme/api/order.go +++ b/acme/api/order.go @@ -1,16 +1,18 @@ package api import ( + "context" "crypto/x509" "encoding/base64" "encoding/json" "net/http" + "strings" "time" "github.com/go-chi/chi" - "github.com/pkg/errors" "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/api" + "go.step.sm/crypto/randutil" ) // NewOrderRequest represents the body for a NewOrder request. @@ -23,11 +25,11 @@ type NewOrderRequest struct { // Validate validates a new-order request body. func (n *NewOrderRequest) Validate() error { if len(n.Identifiers) == 0 { - return acme.NewError(ErrorMalformedType, "identifiers list cannot be empty") + return acme.NewError(acme.ErrorMalformedType, "identifiers list cannot be empty") } for _, id := range n.Identifiers { if id.Type != "dns" { - return acme.NewError(ErrorMalformedType, "identifier type unsupported: %s", id.Type) + return acme.NewError(acme.ErrorMalformedType, "identifier type unsupported: %s", id.Type) } } return nil @@ -44,22 +46,29 @@ func (f *FinalizeRequest) Validate() error { var err error csrBytes, err := base64.RawURLEncoding.DecodeString(f.CSR) if err != nil { - return acme.MalformedErr(errors.Wrap(err, "error base64url decoding csr")) + return acme.WrapError(acme.ErrorMalformedType, err, "error base64url decoding csr") } f.csr, err = x509.ParseCertificateRequest(csrBytes) if err != nil { - return acme.MalformedErr(errors.Wrap(err, "unable to parse csr")) + return acme.WrapError(acme.ErrorMalformedType, err, "unable to parse csr") } if err = f.csr.CheckSignature(); err != nil { - return acme.MalformedErr(errors.Wrap(err, "csr failed signature check")) + return acme.WrapError(acme.ErrorMalformedType, err, "csr failed signature check") } return nil } +var defaultOrderExpiry = time.Hour * 24 + // NewOrder ACME api for creating a new order. func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - acc, err := acme.AccountFromContext(ctx) + acc, err := accountFromContext(ctx) + if err != nil { + api.WriteError(w, err) + return + } + prov, err := provisionerFromContext(ctx) if err != nil { api.WriteError(w, err) return @@ -71,8 +80,8 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { } var nor NewOrderRequest if err := json.Unmarshal(payload.value, &nor); err != nil { - api.WriteError(w, acme.MalformedErr(errors.Wrap(err, - "failed to unmarshal new-order request payload"))) + api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, + "failed to unmarshal new-order request payload")) return } if err := nor.Validate(); err != nil { @@ -80,44 +89,133 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { return } - o, err := h.Auth.NewOrder(ctx, acme.OrderOptions{ - AccountID: acc.GetID(), - Identifiers: nor.Identifiers, - NotBefore: nor.NotBefore, - NotAfter: nor.NotAfter, - }) - if err != nil { - api.WriteError(w, err) + // New order. + o := &acme.Order{Identifiers: nor.Identifiers} + + o.AuthorizationIDs = make([]string, len(o.Identifiers)) + for i, identifier := range o.Identifiers { + az := &acme.Authorization{ + AccountID: acc.ID, + Identifier: identifier, + } + if err := h.newAuthorization(ctx, az); err != nil { + api.WriteError(w, err) + return + } + o.AuthorizationIDs[i] = az.ID + } + + now := clock.Now() + if o.NotBefore.IsZero() { + o.NotBefore = now + } + if o.NotAfter.IsZero() { + o.NotAfter = o.NotBefore.Add(prov.DefaultTLSCertDuration()) + } + o.Expires = now.Add(defaultOrderExpiry) + + if err := h.db.CreateOrder(ctx, o); err != nil { + api.WriteError(w, acme.WrapErrorISE(err, "error creating order")) return } - w.Header().Set("Location", h.Auth.GetLink(ctx, acme.OrderLink, true, o.GetID())) + h.linker.Link(ctx, o) + + w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, true, o.ID)) api.JSONStatus(w, o, http.StatusCreated) } +func (h *Handler) newAuthorization(ctx context.Context, az *acme.Authorization) error { + if strings.HasPrefix(az.Identifier.Value, "*.") { + az.Wildcard = true + az.Identifier = acme.Identifier{ + Value: strings.TrimPrefix(az.Identifier.Value, "*."), + Type: az.Identifier.Type, + } + } + + var ( + err error + chTypes = []string{"dns-01"} + ) + // HTTP and TLS challenges can only be used for identifiers without wildcards. + if !az.Wildcard { + chTypes = append(chTypes, []string{"http-01", "tls-alpn-01"}...) + } + + az.Token, err = randutil.Alphanumeric(32) + if err != nil { + return acme.WrapErrorISE(err, "error generating random alphanumeric ID") + } + + az.Challenges = make([]*acme.Challenge, len(chTypes)) + for i, typ := range chTypes { + ch := &acme.Challenge{ + AccountID: az.AccountID, + AuthzID: az.ID, + Value: az.Identifier.Value, + Type: typ, + Token: az.Token, + } + if err := h.db.CreateChallenge(ctx, ch); err != nil { + return err + } + az.Challenges[i] = ch + } + if err = h.db.CreateAuthorization(ctx, az); err != nil { + return acme.WrapErrorISE(err, "error creating authorization") + } + return nil +} + // GetOrder ACME api for retrieving an order. func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - acc, err := acme.AccountFromContext(ctx) + acc, err := accountFromContext(ctx) if err != nil { api.WriteError(w, err) return } - oid := chi.URLParam(r, "ordID") - o, err := h.Auth.GetOrder(ctx, acc.GetID(), oid) + prov, err := provisionerFromContext(ctx) if err != nil { api.WriteError(w, err) return } + o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID")) + if err != nil { + api.WriteError(w, acme.WrapErrorISE(err, "error retrieving order")) + return + } + if acc.ID != o.AccountID { + api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, + "account '%s' does not own order '%s'", acc.ID, o.ID)) + return + } + if prov.GetID() != o.ProvisionerID { + api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, + "provisioner '%s' does not own order '%s'", prov.GetID(), o.ID)) + return + } + if err = o.UpdateStatus(ctx, h.db); err != nil { + api.WriteError(w, acme.WrapErrorISE(err, "error updating order status")) + return + } - w.Header().Set("Location", h.Auth.GetLink(ctx, acme.OrderLink, true, o.GetID())) + h.linker.LinkOrder(ctx, o) + + w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, true, o.ID)) api.JSON(w, o) } // FinalizeOrder attemptst to finalize an order and create a certificate. func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - acc, err := acme.AccountFromContext(ctx) + acc, err := accountFromContext(ctx) + if err != nil { + api.WriteError(w, err) + return + } + prov, err := provisionerFromContext(ctx) if err != nil { api.WriteError(w, err) return @@ -129,7 +227,8 @@ func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) { } var fr FinalizeRequest if err := json.Unmarshal(payload.value, &fr); err != nil { - api.WriteError(w, acme.MalformedErr(errors.Wrap(err, "failed to unmarshal finalize-order request payload"))) + api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, + "failed to unmarshal finalize-order request payload")) return } if err := fr.Validate(); err != nil { @@ -137,13 +236,28 @@ func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) { return } - oid := chi.URLParam(r, "ordID") - o, err := h.Auth.FinalizeOrder(ctx, acc.GetID(), oid, fr.csr) + o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID")) if err != nil { - api.WriteError(w, err) + api.WriteError(w, acme.WrapErrorISE(err, "error retrieving order")) return } + if acc.ID != o.AccountID { + api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, + "account '%s' does not own order '%s'", acc.ID, o.ID)) + return + } + if prov.GetID() != o.ProvisionerID { + api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, + "provisioner '%s' does not own order '%s'", prov.GetID(), o.ID)) + return + } + if err = o.Finalize(ctx, h.db, fr.csr, h.ca, prov); err != nil { + api.WriteError(w, acme.WrapErrorISE(err, "error finalizing order")) + return + } + + h.linker.LinkOrder(ctx, o) - w.Header().Set("Location", h.Auth.GetLink(ctx, acme.OrderLink, true, o.ID)) + w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, true, o.ID)) api.JSON(w, o) } diff --git a/acme/authority.go b/acme/authority.go deleted file mode 100644 index 92e1c8f7..00000000 --- a/acme/authority.go +++ /dev/null @@ -1,420 +0,0 @@ -package acme - -import ( - "context" - "crypto/tls" - "crypto/x509" - "log" - "net" - "net/http" - "net/url" - "strings" - "time" - - "github.com/smallstep/certificates/authority/provisioner" - "go.step.sm/crypto/jose" - "go.step.sm/crypto/randutil" -) - -// Interface is the acme authority interface. -type Interface interface { - GetDirectory(ctx context.Context) (*Directory, error) - NewNonce() (string, error) - UseNonce(string) error - - DeactivateAccount(ctx context.Context, accID string) (*Account, error) - GetAccount(ctx context.Context, accID string) (*Account, error) - GetAccountByKey(ctx context.Context, key *jose.JSONWebKey) (*Account, error) - NewAccount(ctx context.Context, acc *Account) (*Account, error) - UpdateAccount(ctx context.Context, acc *Account) (*Account, error) - - GetAuthz(ctx context.Context, accID string, authzID string) (*Authorization, error) - ValidateChallenge(ctx context.Context, accID string, chID string, key *jose.JSONWebKey) (*Challenge, error) - - FinalizeOrder(ctx context.Context, accID string, orderID string, csr *x509.CertificateRequest) (*Order, error) - GetOrder(ctx context.Context, accID string, orderID string) (*Order, error) - GetOrdersByAccount(ctx context.Context, accID string) ([]string, error) - NewOrder(ctx context.Context, o *Order) (*Order, error) - - GetCertificate(string, string) ([]byte, error) - - LoadProvisionerByID(string) (provisioner.Interface, error) - GetLink(ctx context.Context, linkType Link, absoluteLink bool, inputs ...string) string - GetLinkExplicit(linkType Link, provName string, absoluteLink bool, baseURL *url.URL, inputs ...string) string -} - -// Authority is the layer that handles all ACME interactions. -type Authority struct { - backdate provisioner.Duration - db DB - dir *directory - signAuth SignAuthority -} - -// AuthorityOptions required to create a new ACME Authority. -type AuthorityOptions struct { - Backdate provisioner.Duration - // DB storage backend that impements the acme.DB interface. - DB DB - // DNS the host used to generate accurate ACME links. By default the authority - // will use the Host from the request, so this value will only be used if - // request.Host is empty. - DNS string - // Prefix is a URL path prefix under which the ACME api is served. This - // prefix is required to generate accurate ACME links. - // E.g. https://ca.smallstep.com/acme/my-acme-provisioner/new-account -- - // "acme" is the prefix from which the ACME api is accessed. - Prefix string -} - -// NewAuthority returns a new Authority that implements the ACME interface. -// -// Deprecated: NewAuthority exists for hitorical compatibility and should not -// be used. Use acme.New() instead. -func NewAuthority(db DB, dns, prefix string, signAuth SignAuthority) (*Authority, error) { - return New(signAuth, AuthorityOptions{ - DB: db, - DNS: dns, - Prefix: prefix, - }) -} - -// New returns a new Authority that implements the ACME interface. -func New(signAuth SignAuthority, ops AuthorityOptions) (*Authority, error) { - return &Authority{ - backdate: ops.Backdate, db: ops.DB, dir: newDirectory(ops.DNS, ops.Prefix), signAuth: signAuth, - }, nil -} - -// GetLink returns the requested link from the directory. -func (a *Authority) GetLink(ctx context.Context, typ Link, abs bool, inputs ...string) string { - return a.dir.getLink(ctx, typ, abs, inputs...) -} - -// GetLinkExplicit returns the requested link from the directory. -func (a *Authority) GetLinkExplicit(typ Link, provName string, abs bool, baseURL *url.URL, inputs ...string) string { - return a.dir.getLinkExplicit(typ, provName, abs, baseURL, inputs...) -} - -// GetDirectory returns the ACME directory object. -func (a *Authority) GetDirectory(ctx context.Context) (*Directory, error) { - return &Directory{ - NewNonce: a.dir.getLink(ctx, NewNonceLink, true), - NewAccount: a.dir.getLink(ctx, NewAccountLink, true), - NewOrder: a.dir.getLink(ctx, NewOrderLink, true), - RevokeCert: a.dir.getLink(ctx, RevokeCertLink, true), - KeyChange: a.dir.getLink(ctx, KeyChangeLink, true), - }, nil -} - -// LoadProvisionerByID calls out to the SignAuthority interface to load a -// provisioner by ID. -func (a *Authority) LoadProvisionerByID(id string) (provisioner.Interface, error) { - return a.signAuth.LoadProvisionerByID(id) -} - -// NewNonce generates, stores, and returns a new ACME nonce. -func (a *Authority) NewNonce(ctx context.Context) (Nonce, error) { - return a.db.CreateNonce(ctx) -} - -// UseNonce consumes the given nonce if it is valid, returns error otherwise. -func (a *Authority) UseNonce(ctx context.Context, nonce string) error { - return a.db.DeleteNonce(ctx, Nonce(nonce)) -} - -// NewAccount creates, stores, and returns a new ACME account. -func (a *Authority) NewAccount(ctx context.Context, acc *Account) error { - if err := a.db.CreateAccount(ctx, acc); err != nil { - return ErrorISEWrap(err, "error creating account") - } - return nil -} - -// UpdateAccount updates an ACME account. -func (a *Authority) UpdateAccount(ctx context.Context, acc *Account) (*Account, error) { - /* - acc.Contact = auo.Contact - acc.Status = auo.Status - */ - if err := a.db.UpdateAccount(ctx, acc); err != nil { - return nil, ErrorISEWrap(err, "error updating account") - } - return acc, nil -} - -// GetAccount returns an ACME account. -func (a *Authority) GetAccount(ctx context.Context, id string) (*Account, error) { - acc, err := a.db.GetAccount(ctx, id) - if err != nil { - return nil, ErrorISEWrap(err, "error retrieving account") - } - return acc, nil -} - -// GetAccountByKey returns the ACME associated with the jwk id. -func (a *Authority) GetAccountByKey(ctx context.Context, jwk *jose.JSONWebKey) (*Account, error) { - kid, err := KeyToID(jwk) - if err != nil { - return nil, err - } - acc, err := a.db.GetAccountByKeyID(ctx, kid) - return acc, err -} - -// GetOrder returns an ACME order. -func (a *Authority) GetOrder(ctx context.Context, accID, orderID string) (*Order, error) { - prov, err := ProvisionerFromContext(ctx) - if err != nil { - return nil, err - } - o, err := a.db.GetOrder(ctx, orderID) - if err != nil { - return nil, ErrorISEWrap(err, "error retrieving order") - } - if accID != o.AccountID { - log.Printf("account-id from request ('%s') does not match order account-id ('%s')", accID, o.AccountID) - return nil, NewError(ErrorUnauthorizedType, "account does not own order") - } - if prov.GetID() != o.ProvisionerID { - log.Printf("provisioner-id from request ('%s') does not match order provisioner-id ('%s')", prov.GetID(), o.ProvisionerID) - return nil, NewError(ErrorUnauthorizedType, "provisioner does not own order") - } - if err = o.UpdateStatus(ctx, a.db); err != nil { - return nil, ErrorISEWrap(err, "error updating order") - } - return o, nil -} - -/* -// GetOrdersByAccount returns the list of order urls owned by the account. -func (a *Authority) GetOrdersByAccount(ctx context.Context, id string) ([]string, error) { - ordersByAccountMux.Lock() - defer ordersByAccountMux.Unlock() - - var oiba = orderIDsByAccount{} - oids, err := oiba.unsafeGetOrderIDsByAccount(a.db, id) - if err != nil { - return nil, err - } - - var ret = []string{} - for _, oid := range oids { - ret = append(ret, a.dir.getLink(ctx, OrderLink, true, oid)) - } - return ret, nil -} -*/ - -// NewOrder generates, stores, and returns a new ACME order. -func (a *Authority) NewOrder(ctx context.Context, o *Order) error { - if len(o.AccountID) == 0 { - return NewErrorISE("account-id cannot be empty") - } - if len(o.ProvisionerID) == 0 { - return NewErrorISE("provisioner-id cannot be empty") - } - if len(o.Identifiers) == 0 { - return NewErrorISE("identifiers cannot be empty") - } - if o.DefaultDuration == 0 { - return NewErrorISE("default-duration cannot be empty") - } - - o.AuthorizationIDs = make([]string, len(o.Identifiers)) - for i, identifier := range o.Identifiers { - az := &Authorization{ - AccountID: o.AccountID, - Identifier: identifier, - } - if err := a.NewAuthorization(ctx, az); err != nil { - return err - } - o.AuthorizationIDs[i] = az.ID - } - - now := clock.Now() - if o.NotBefore.IsZero() { - o.NotBefore = now - } - if o.NotAfter.IsZero() { - o.NotAfter = o.NotBefore.Add(o.DefaultDuration) - } - - if err := a.db.CreateOrder(ctx, o); err != nil { - return ErrorISEWrap(err, "error creating order") - } - return nil - /* - o.DefaultDuration = prov.DefaultTLSCertDuration() - o.Backdate = a.backdate.Duration - o.ProvisionerID = prov.GetID() - - if err = a.db.CreateOrder(ctx, o); err != nil { - return nil, ErrorWrap(ErrorServerInternalType, err, "error creating order") - } - return o, nil - */ -} - -// FinalizeOrder attempts to finalize an order and generate a new certificate. -func (a *Authority) FinalizeOrder(ctx context.Context, accID, orderID string, csr *x509.CertificateRequest) (*Order, error) { - prov, err := ProvisionerFromContext(ctx) - if err != nil { - return nil, err - } - o, err := a.db.GetOrder(ctx, orderID) - if err != nil { - return nil, ErrorISEWrap(err, "error retrieving order") - } - if accID != o.AccountID { - log.Printf("account-id from request ('%s') does not match order account-id ('%s')", accID, o.AccountID) - return nil, NewError(ErrorUnauthorizedType, "account does not own order") - } - if prov.GetID() != o.ProvisionerID { - log.Printf("provisioner-id from request ('%s') does not match order provisioner-id ('%s')", prov.GetID(), o.ProvisionerID) - return nil, NewError(ErrorUnauthorizedType, "provisioner does not own order") - } - if err = o.Finalize(ctx, a.db, csr, a.signAuth, prov); err != nil { - return nil, ErrorISEWrap(err, "error finalizing order") - } - return o, nil -} - -// NewAuthorization generates and stores an ACME Authorization type along with -// any associated resources. -func (a *Authority) NewAuthorization(ctx context.Context, az *Authorization) error { - if len(az.AccountID) == 0 { - return NewErrorISE("account-id cannot be empty") - } - if len(az.Identifier.Value) == 0 { - return NewErrorISE("identifier cannot be empty") - } - - if strings.HasPrefix(az.Identifier.Value, "*.") { - az.Wildcard = true - az.Identifier = Identifier{ - Value: strings.TrimPrefix(az.Identifier.Value, "*."), - Type: az.Identifier.Type, - } - } - - var ( - err error - chTypes = []string{"dns-01"} - ) - // HTTP and TLS challenges can only be used for identifiers without wildcards. - if !az.Wildcard { - chTypes = append(chTypes, []string{"http-01", "tls-alpn-01"}...) - } - - az.Token, err = randutil.Alphanumeric(32) - if err != nil { - return ErrorISEWrap(err, "error generating random alphanumeric ID") - } - - az.Challenges = make([]*Challenge, len(chTypes)) - for i, typ := range chTypes { - ch := &Challenge{ - AccountID: az.AccountID, - AuthzID: az.ID, - Value: az.Identifier.Value, - Type: typ, - Token: az.Token, - } - if err := a.NewChallenge(ctx, ch); err != nil { - return err - } - az.Challenges[i] = ch - } - if err = a.db.CreateAuthorization(ctx, az); err != nil { - return ErrorISEWrap(err, "error creating authorization") - } - return nil -} - -// GetAuthorization retrieves and attempts to update the status on an ACME authz -// before returning. -func (a *Authority) GetAuthorization(ctx context.Context, accID, authzID string) (*Authorization, error) { - az, err := a.db.GetAuthorization(ctx, authzID) - if err != nil { - return nil, ErrorISEWrap(err, "error retrieving authorization") - } - if accID != az.AccountID { - log.Printf("account-id from request ('%s') does not match authz account-id ('%s')", accID, az.AccountID) - return nil, NewError(ErrorUnauthorizedType, "account does not own order") - } - if err = az.UpdateStatus(ctx, a.db); err != nil { - return nil, ErrorISEWrap(err, "error updating authorization status") - } - return az, nil -} - -// NewChallenge generates and stores an ACME challenge and associated resources. -func (a *Authority) NewChallenge(ctx context.Context, ch *Challenge) error { - if len(ch.AccountID) == 0 { - return NewErrorISE("account-id cannot be empty") - } - if len(ch.AuthzID) == 0 { - return NewErrorISE("authz-id cannot be empty") - } - if len(ch.Token) == 0 { - return NewErrorISE("token cannot be empty") - } - if len(ch.Value) == 0 { - return NewErrorISE("value cannot be empty") - } - - switch ch.Type { - case "dns-01", "http-01", "tls-alpn-01": - break - default: - return NewErrorISE("unexpected error type '%s'", ch.Type) - } - - if err := a.db.CreateChallenge(ctx, ch); err != nil { - return ErrorISEWrap(err, "error creating challenge") - } - return nil -} - -// GetValidateChallenge attempts to validate the challenge. -func (a *Authority) GetValidateChallenge(ctx context.Context, accID, chID, azID string, jwk *jose.JSONWebKey) (*Challenge, error) { - ch, err := a.db.GetChallenge(ctx, chID, "todo") - if err != nil { - return nil, ErrorISEWrap(err, "error retrieving challenge") - } - if accID != ch.AccountID { - log.Printf("account-id from request ('%s') does not match challenge account-id ('%s')", accID, ch.AccountID) - return nil, NewError(ErrorUnauthorizedType, "account does not own order") - } - client := http.Client{ - Timeout: time.Duration(30 * time.Second), - } - dialer := &net.Dialer{ - Timeout: 30 * time.Second, - } - if err = ch.Validate(ctx, a.db, jwk, validateOptions{ - httpGet: client.Get, - lookupTxt: net.LookupTXT, - tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { - return tls.DialWithDialer(dialer, network, addr, config) - }, - }); err != nil { - return nil, ErrorISEWrap(err, "error validating challenge") - } - return ch, nil -} - -// GetCertificate retrieves the Certificate by ID. -func (a *Authority) GetCertificate(ctx context.Context, accID, certID string) ([]byte, error) { - cert, err := a.db.GetCertificate(ctx, certID) - if err != nil { - return nil, ErrorISEWrap(err, "error retrieving certificate") - } - if cert.AccountID != accID { - log.Printf("account-id from request ('%s') does not match challenge account-id ('%s')", accID, cert.AccountID) - return nil, NewError(ErrorUnauthorizedType, "account does not own order") - } - return cert.ToACME(ctx) -} diff --git a/acme/authorization.go b/acme/authorization.go index 7f15f4c6..df4ac229 100644 --- a/acme/authorization.go +++ b/acme/authorization.go @@ -22,7 +22,7 @@ type Authorization struct { func (az *Authorization) ToLog() (interface{}, error) { b, err := json.Marshal(az) if err != nil { - return nil, ErrorISEWrap(err, "error marshaling authz for logging") + return nil, WrapErrorISE(err, "error marshaling authz for logging") } return string(b), nil } @@ -30,11 +30,7 @@ func (az *Authorization) ToLog() (interface{}, error) { // UpdateStatus updates the ACME Authorization Status if necessary. // Changes to the Authorization are saved using the database interface. func (az *Authorization) UpdateStatus(ctx context.Context, db DB) error { - now := time.Now().UTC() - expiry, err := time.Parse(time.RFC3339, az.Expires) - if err != nil { - return ErrorISEWrap(err, "error converting expiry string to time") - } + now := clock.Now() switch az.Status { case StatusInvalid: @@ -43,7 +39,7 @@ func (az *Authorization) UpdateStatus(ctx context.Context, db DB) error { return nil case StatusPending: // check expiry - if now.After(expiry) { + if now.After(az.Expires) { az.Status = StatusInvalid break } @@ -61,11 +57,11 @@ func (az *Authorization) UpdateStatus(ctx context.Context, db DB) error { } az.Status = StatusValid default: - return NewError(ErrorServerInternalType, "unrecognized authorization status: %s", az.Status) + return NewErrorISE("unrecognized authorization status: %s", az.Status) } - if err = db.UpdateAuthorization(ctx, az); err != nil { - return ErrorISEWrap(err, "error updating authorization") + if err := db.UpdateAuthorization(ctx, az); err != nil { + return WrapErrorISE(err, "error updating authorization") } return nil } diff --git a/acme/certificate.go b/acme/certificate.go index 356c0121..daf9556b 100644 --- a/acme/certificate.go +++ b/acme/certificate.go @@ -1,7 +1,6 @@ package acme import ( - "context" "crypto/x509" "encoding/pem" ) @@ -16,7 +15,7 @@ type Certificate struct { } // ToACME encodes the entire X509 chain into a PEM list. -func (cert *Certificate) ToACME(ctx context.Context) ([]byte, error) { +func (cert *Certificate) ToACME() ([]byte, error) { var ret []byte for _, c := range append([]*x509.Certificate{cert.Leaf}, cert.Intermediates...) { ret = append(ret, pem.EncodeToMemory(&pem.Block{ diff --git a/acme/challenge.go b/acme/challenge.go index ca2e5562..2abc808c 100644 --- a/acme/challenge.go +++ b/acme/challenge.go @@ -38,7 +38,7 @@ type Challenge struct { func (ch *Challenge) ToLog() (interface{}, error) { b, err := json.Marshal(ch) if err != nil { - return nil, ErrorISEWrap(err, "error marshaling challenge for logging") + return nil, WrapErrorISE(err, "error marshaling challenge for logging") } return string(b), nil } @@ -47,7 +47,7 @@ func (ch *Challenge) ToLog() (interface{}, error) { // type using the DB interface. // satisfactorily validated, the 'status' and 'validated' attributes are // updated. -func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey, vo validateOptions) error { +func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey, vo ValidateOptions) error { // If already valid or invalid then return without performing validation. if ch.Status == StatusValid || ch.Status == StatusInvalid { return nil @@ -60,16 +60,16 @@ func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey, case "tls-alpn-01": return tlsalpn01Validate(ctx, ch, db, jwk, vo) default: - return NewError(ErrorServerInternalType, "unexpected challenge type '%s'", ch.Type) + return NewErrorISE("unexpected challenge type '%s'", ch.Type) } } -func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo validateOptions) error { +func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo ValidateOptions) error { url := fmt.Sprintf("http://%s/.well-known/acme-challenge/%s", ch.Value, ch.Token) - resp, err := vo.httpGet(url) + resp, err := vo.HTTPGet(url) if err != nil { - return storeError(ctx, ch, db, ErrorWrap(ErrorConnectionType, err, + return storeError(ctx, ch, db, WrapError(ErrorConnectionType, err, "error doing http GET for url %s", url)) } if resp.StatusCode >= 400 { @@ -80,7 +80,7 @@ func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWeb body, err := ioutil.ReadAll(resp.Body) if err != nil { - return ErrorISEWrap(err, "error reading "+ + return WrapErrorISE(err, "error reading "+ "response body for url %s", url) } keyAuth := strings.Trim(string(body), "\r\n") @@ -100,12 +100,12 @@ func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWeb ch.Validated = clock.Now().Format(time.RFC3339) if err = db.UpdateChallenge(ctx, ch); err != nil { - return ErrorISEWrap(err, "error updating challenge") + return WrapErrorISE(err, "error updating challenge") } return nil } -func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo validateOptions) error { +func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo ValidateOptions) error { config := &tls.Config{ NextProtos: []string{"acme-tls/1"}, ServerName: ch.Value, @@ -114,9 +114,9 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON hostPort := net.JoinHostPort(ch.Value, "443") - conn, err := vo.tlsDial("tcp", hostPort, config) + conn, err := vo.TLSDial("tcp", hostPort, config) if err != nil { - return storeError(ctx, ch, db, ErrorWrap(ErrorConnectionType, err, + return storeError(ctx, ch, db, WrapError(ErrorConnectionType, err, "error doing TLS dial for %s", hostPort)) } defer conn.Close() @@ -178,7 +178,7 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON ch.Validated = clock.Now().Format(time.RFC3339) if err = db.UpdateChallenge(ctx, ch); err != nil { - return ErrorISEWrap(err, "tlsalpn01ValidateChallenge - error updating challenge") + return WrapErrorISE(err, "tlsalpn01ValidateChallenge - error updating challenge") } return nil } @@ -197,16 +197,16 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON "incorrect certificate for tls-alpn-01 challenge: missing acmeValidationV1 extension")) } -func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo validateOptions) error { +func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo ValidateOptions) error { // Normalize domain for wildcard DNS names // This is done to avoid making TXT lookups for domains like // _acme-challenge.*.example.com // Instead perform txt lookup for _acme-challenge.example.com domain := strings.TrimPrefix(ch.Value, "*.") - txtRecords, err := vo.lookupTxt("_acme-challenge." + domain) + txtRecords, err := vo.LookupTxt("_acme-challenge." + domain) if err != nil { - return storeError(ctx, ch, db, ErrorWrap(ErrorDNSType, err, + return storeError(ctx, ch, db, WrapError(ErrorDNSType, err, "error looking up TXT records for domain %s", domain)) } @@ -234,7 +234,7 @@ func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebK ch.Validated = clock.Now().UTC().Format(time.RFC3339) if err = db.UpdateChallenge(ctx, ch); err != nil { - return ErrorISEWrap(err, "error updating challenge") + return WrapErrorISE(err, "error updating challenge") } return nil } @@ -244,7 +244,7 @@ func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebK func KeyAuthorization(token string, jwk *jose.JSONWebKey) (string, error) { thumbprint, err := jwk.Thumbprint(crypto.SHA256) if err != nil { - return "", ErrorISEWrap(err, "error generating JWK thumbprint") + return "", WrapErrorISE(err, "error generating JWK thumbprint") } encPrint := base64.RawURLEncoding.EncodeToString(thumbprint) return fmt.Sprintf("%s.%s", token, encPrint), nil @@ -254,7 +254,7 @@ func KeyAuthorization(token string, jwk *jose.JSONWebKey) (string, error) { func storeError(ctx context.Context, ch *Challenge, db DB, err *Error) error { ch.Error = err if err := db.UpdateChallenge(ctx, ch); err != nil { - return ErrorISEWrap(err, "failure saving error to acme challenge") + return WrapErrorISE(err, "failure saving error to acme challenge") } return nil } @@ -263,8 +263,9 @@ type httpGetter func(string) (*http.Response, error) type lookupTxt func(string) ([]string, error) type tlsDialer func(network, addr string, config *tls.Config) (*tls.Conn, error) -type validateOptions struct { - httpGet httpGetter - lookupTxt lookupTxt - tlsDial tlsDialer +// ValidateOptions are ACME challenge validator functions. +type ValidateOptions struct { + HTTPGet httpGetter + LookupTxt lookupTxt + TLSDial tlsDialer } diff --git a/acme/common.go b/acme/common.go index b9dc6ff2..f7fd7141 100644 --- a/acme/common.go +++ b/acme/common.go @@ -3,13 +3,27 @@ package acme import ( "context" "crypto/x509" - "net/url" "time" "github.com/smallstep/certificates/authority/provisioner" - "go.step.sm/crypto/jose" ) +// CertificateAuthority is the interface implemented by a CA authority. +type CertificateAuthority interface { + Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) + LoadProvisionerByID(string) (provisioner.Interface, error) +} + +// Clock that returns time in UTC rounded to seconds. +type Clock int + +// Now returns the UTC time rounded to seconds. +func (c *Clock) Now() time.Time { + return time.Now().UTC().Round(time.Second) +} + +var clock = new(Clock) + // Provisioner is an interface that implements a subset of the provisioner.Interface -- // only those methods required by the ACME api/authority. type Provisioner interface { @@ -70,89 +84,3 @@ func (m *MockProvisioner) GetID() string { } return m.Mret1.(string) } - -// ContextKey is the key type for storing and searching for ACME request -// essentials in the context of a request. -type ContextKey string - -const ( - // AccContextKey account key - AccContextKey = ContextKey("acc") - // BaseURLContextKey baseURL key - BaseURLContextKey = ContextKey("baseURL") - // JwsContextKey jws key - JwsContextKey = ContextKey("jws") - // JwkContextKey jwk key - JwkContextKey = ContextKey("jwk") - // PayloadContextKey payload key - PayloadContextKey = ContextKey("payload") - // ProvisionerContextKey provisioner key - ProvisionerContextKey = ContextKey("provisioner") -) - -// AccountFromContext searches the context for an ACME account. Returns the -// account or an error. -func AccountFromContext(ctx context.Context) (*Account, error) { - val, ok := ctx.Value(AccContextKey).(*Account) - if !ok || val == nil { - return nil, NewError(ErrorServerInternalType, "account not in context") - } - return val, nil -} - -// BaseURLFromContext returns the baseURL if one is stored in the context. -func BaseURLFromContext(ctx context.Context) *url.URL { - val, ok := ctx.Value(BaseURLContextKey).(*url.URL) - if !ok || val == nil { - return nil - } - return val -} - -// JwkFromContext searches the context for a JWK. Returns the JWK or an error. -func JwkFromContext(ctx context.Context) (*jose.JSONWebKey, error) { - val, ok := ctx.Value(JwkContextKey).(*jose.JSONWebKey) - if !ok || val == nil { - return nil, NewError(ErrorServerInternalType, "jwk expected in request context") - } - return val, nil -} - -// JwsFromContext searches the context for a JWS. Returns the JWS or an error. -func JwsFromContext(ctx context.Context) (*jose.JSONWebSignature, error) { - val, ok := ctx.Value(JwsContextKey).(*jose.JSONWebSignature) - if !ok || val == nil { - return nil, NewError(ErrorServerInternalType, "jws expected in request context") - } - return val, nil -} - -// ProvisionerFromContext searches the context for a provisioner. Returns the -// provisioner or an error. -func ProvisionerFromContext(ctx context.Context) (Provisioner, error) { - val := ctx.Value(ProvisionerContextKey) - if val == nil { - return nil, NewError(ErrorServerInternalType, "provisioner expected in request context") - } - pval, ok := val.(Provisioner) - if !ok || pval == nil { - return nil, NewError(ErrorServerInternalType, "provisioner in context is not an ACME provisioner") - } - return pval, nil -} - -// SignAuthority is the interface implemented by a CA authority. -type SignAuthority interface { - Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) - LoadProvisionerByID(string) (provisioner.Interface, error) -} - -// Clock that returns time in UTC rounded to seconds. -type Clock int - -// Now returns the UTC time rounded to seconds. -func (c *Clock) Now() time.Time { - return time.Now().UTC().Round(time.Second) -} - -var clock = new(Clock) diff --git a/acme/db/nosql/account.go b/acme/db/nosql/account.go index 40961ce3..befeb54d 100644 --- a/acme/db/nosql/account.go +++ b/acme/db/nosql/account.go @@ -74,7 +74,6 @@ func (db *DB) GetAccount(ctx context.Context, id string) (*acme.Account, error) return &acme.Account{ Status: dbacc.Status, Contact: dbacc.Contact, - Orders: dir.getLink(ctx, OrdersByAccountLink, true, dbacc.ID), Key: dbacc.Key, ID: dbacc.ID, }, nil diff --git a/acme/db/nosql/authz.go b/acme/db/nosql/authz.go index 818f5c2d..0992509d 100644 --- a/acme/db/nosql/authz.go +++ b/acme/db/nosql/authz.go @@ -14,15 +14,15 @@ var defaultExpiryDuration = time.Hour * 24 // dbAuthz is the base authz type that others build from. type dbAuthz struct { - ID string `json:"id"` - AccountID string `json:"accountID"` - Identifier *acme.Identifier `json:"identifier"` - Status acme.Status `json:"status"` - Expires time.Time `json:"expires"` - Challenges []string `json:"challenges"` - Wildcard bool `json:"wildcard"` - Created time.Time `json:"created"` - Error *acme.Error `json:"error"` + ID string `json:"id"` + AccountID string `json:"accountID"` + Identifier acme.Identifier `json:"identifier"` + Status acme.Status `json:"status"` + Expires time.Time `json:"expires"` + Challenges []string `json:"challenges"` + Wildcard bool `json:"wildcard"` + Created time.Time `json:"created"` + Error *acme.Error `json:"error"` } func (ba *dbAuthz) clone() *dbAuthz { @@ -66,7 +66,7 @@ func (db *DB) GetAuthorization(ctx context.Context, id string) (*acme.Authorizat Status: dbaz.Status, Challenges: chs, Wildcard: dbaz.Wildcard, - Expires: dbaz.Expires.Format(time.RFC3339), + Expires: dbaz.Expires, ID: dbaz.ID, }, nil } diff --git a/acme/db/nosql/challenge.go b/acme/db/nosql/challenge.go index 378b1f7b..48340cf4 100644 --- a/acme/db/nosql/challenge.go +++ b/acme/db/nosql/challenge.go @@ -21,7 +21,7 @@ type dbChallenge struct { Value string `json:"value"` Validated string `json:"validated"` Created time.Time `json:"created"` - Error *AError `json:"error"` + Error *acme.Error `json:"error"` } func (dbc *dbChallenge) clone() *dbChallenge { @@ -79,7 +79,6 @@ func (db *DB) GetChallenge(ctx context.Context, id, authzID string) (*acme.Chall Type: dbch.Type, Status: dbch.Status, Token: dbch.Token, - URL: dir.getLink(ctx, ChallengeLink, true, dbch.ID), ID: dbch.ID, AuthzID: dbch.AuthzID, Error: dbch.Error, diff --git a/acme/db/nosql/order.go b/acme/db/nosql/order.go index d2146e22..2f5ee11b 100644 --- a/acme/db/nosql/order.go +++ b/acme/db/nosql/order.go @@ -11,8 +11,6 @@ import ( "github.com/smallstep/nosql" ) -var defaultOrderExpiry = time.Hour * 24 - // Mutex for locking ordersByAccount index operations. var ordersByAccountMux sync.Mutex @@ -26,16 +24,16 @@ type dbOrder struct { Identifiers []acme.Identifier `json:"identifiers"` NotBefore time.Time `json:"notBefore,omitempty"` NotAfter time.Time `json:"notAfter,omitempty"` - Error *Error `json:"error,omitempty"` + Error *acme.Error `json:"error,omitempty"` Authorizations []string `json:"authorizations"` - Certificate string `json:"certificate,omitempty"` + CertificateID string `json:"certificate,omitempty"` } // getDBOrder retrieves and unmarshals an ACME Order type from the database. func (db *DB) getDBOrder(id string) (*dbOrder, error) { b, err := db.db.Get(orderTable, []byte(id)) if nosql.IsErrNotFound(err) { - return nil, errors.Wrapf(err, "order %s not found", id) + return nil, acme.WrapError(acme.ErrorMalformedType, err, "order %s not found", id) } else if err != nil { return nil, errors.Wrapf(err, "error loading order %s", id) } @@ -49,34 +47,31 @@ func (db *DB) getDBOrder(id string) (*dbOrder, error) { // GetOrder retrieves an ACME Order from the database. func (db *DB) GetOrder(ctx context.Context, id string) (*acme.Order, error) { dbo, err := db.getDBOrder(id) - - azs := make([]string, len(dbo.Authorizations)) - for i, aid := range dbo.Authorizations { - azs[i] = dir.getLink(ctx, AuthzLink, true, aid) + if err != nil { + return nil, err } + o := &acme.Order{ - Status: dbo.Status, - Expires: dbo.Expires.Format(time.RFC3339), - Identifiers: dbo.Identifiers, - NotBefore: dbo.NotBefore.Format(time.RFC3339), - NotAfter: dbo.NotAfter.Format(time.RFC3339), - Authorizations: azs, - FinalizeURL: dir.getLink(ctx, FinalizeLink, true, o.ID), - ID: dbo.ID, - ProvisionerID: dbo.ProvisionerID, + Status: dbo.Status, + Expires: dbo.Expires, + Identifiers: dbo.Identifiers, + NotBefore: dbo.NotBefore, + NotAfter: dbo.NotAfter, + AuthorizationIDs: dbo.Authorizations, + ID: dbo.ID, + ProvisionerID: dbo.ProvisionerID, + CertificateID: dbo.CertificateID, } - if dbo.Certificate != "" { - o.Certificate = dir.getLink(ctx, CertificateLink, true, o.Certificate) - } return o, nil } // CreateOrder creates ACME Order resources and saves them to the DB. func (db *DB) CreateOrder(ctx context.Context, o *acme.Order) error { + var err error o.ID, err = randID() if err != nil { - return nil, err + return err } now := clock.Now() @@ -85,23 +80,23 @@ func (db *DB) CreateOrder(ctx context.Context, o *acme.Order) error { AccountID: o.AccountID, ProvisionerID: o.ProvisionerID, Created: now, - Status: StatusPending, - Expires: now.Add(defaultOrderExpiry), + Status: acme.StatusPending, + Expires: o.Expires, Identifiers: o.Identifiers, NotBefore: o.NotBefore, NotAfter: o.NotBefore, Authorizations: o.AuthorizationIDs, } - if err := db.save(ctx, o.ID, dbo, nil, orderTable); err != nil { - return nil, err + if err := db.save(ctx, o.ID, dbo, nil, "order", orderTable); err != nil { + return err } var oidHelper = orderIDsByAccount{} _, err = oidHelper.addOrderID(db, o.AccountID, o.ID) if err != nil { - return nil, err + return err } - return o, nil + return nil } type orderIDsByAccount struct{} @@ -135,11 +130,11 @@ func (oiba orderIDsByAccount) unsafeGetOrderIDsByAccount(db nosql.DB, accID stri if nosql.IsErrNotFound(err) { return []string{}, nil } - return nil, ServerInternalErr(errors.Wrapf(err, "error loading orderIDs for account %s", accID)) + return nil, errors.Wrapf(err, "error loading orderIDs for account %s", accID) } var oids []string if err := json.Unmarshal(b, &oids); err != nil { - return nil, ServerInternalErr(errors.Wrapf(err, "error unmarshaling orderIDs for account %s", accID)) + return nil, errors.Wrapf(err, "error unmarshaling orderIDs for account %s", accID) } // Remove any order that is not in PENDING state and update the stored list @@ -152,21 +147,21 @@ func (oiba orderIDsByAccount) unsafeGetOrderIDsByAccount(db nosql.DB, accID stri for _, oid := range oids { o, err := getOrder(db, oid) if err != nil { - return nil, ServerInternalErr(errors.Wrapf(err, "error loading order %s for account %s", oid, accID)) + return nil, errors.Wrapf(err, "error loading order %s for account %s", oid, accID) } if o, err = o.UpdateStatus(db); err != nil { - return nil, ServerInternalErr(errors.Wrapf(err, "error updating order %s for account %s", oid, accID)) + return nil, errors.Wrapf(err, "error updating order %s for account %s", oid, accID) } - if o.Status == StatusPending { + if o.Status == acme.StatusPending { pendOids = append(pendOids, oid) } } // If the number of pending orders is less than the number of orders in the // list, then update the pending order list. if len(pendOids) != len(oids) { - if err = orderIDs(pendOiUs).save(db, oids, accID); err != nil { - return nil, ServerInternalErr(errors.Wrapf(err, "error storing orderIDs as part of getOrderIDsByAccount logic: "+ - "len(orderIDs) = %d", len(pendOids))) + if err = orderIDs(pendOids).save(db, oids, accID); err != nil { + return nil, errors.Wrapf(err, "error storing orderIDs as part of getOrderIDsByAccount logic: "+ + "len(orderIDs) = %d", len(pendOids)) } } @@ -192,7 +187,7 @@ func (oids orderIDs) save(db nosql.DB, old orderIDs, accID string) error { } else { oldb, err = json.Marshal(old) if err != nil { - return ServerInternalErr(errors.Wrap(err, "error marshaling old order IDs slice")) + return errors.Wrap(err, "error marshaling old order IDs slice") } } if len(oids) == 0 { @@ -200,13 +195,13 @@ func (oids orderIDs) save(db nosql.DB, old orderIDs, accID string) error { } else { newb, err = json.Marshal(oids) if err != nil { - return ServerInternalErr(errors.Wrap(err, "error marshaling new order IDs slice")) + return errors.Wrap(err, "error marshaling new order IDs slice") } } _, swapped, err := db.CmpAndSwap(ordersByAccountIDTable, []byte(accID), oldb, newb) switch { case err != nil: - return ServerInternalErr(errors.Wrapf(err, "error storing order IDs for account %s", accID)) + return errors.Wrapf(err, "error storing order IDs for account %s", accID) case !swapped: return ServerInternalErr(errors.Errorf("error storing order IDs "+ "for account %s; order IDs changed since last read", accID)) diff --git a/acme/directory.go b/acme/directory.go deleted file mode 100644 index 8520d0e9..00000000 --- a/acme/directory.go +++ /dev/null @@ -1,148 +0,0 @@ -package acme - -import ( - "context" - "encoding/json" - "fmt" - "net/url" -) - -// Directory represents an ACME directory for configuring clients. -type Directory struct { - NewNonce string `json:"newNonce,omitempty"` - NewAccount string `json:"newAccount,omitempty"` - NewOrder string `json:"newOrder,omitempty"` - NewAuthz string `json:"newAuthz,omitempty"` - RevokeCert string `json:"revokeCert,omitempty"` - KeyChange string `json:"keyChange,omitempty"` -} - -// ToLog enables response logging for the Directory type. -func (d *Directory) ToLog() (interface{}, error) { - b, err := json.Marshal(d) - if err != nil { - return nil, ErrorISEWrap(err, "error marshaling directory for logging") - } - return string(b), nil -} - -type directory struct { - prefix, dns string -} - -// newDirectory returns a new Directory type. -func newDirectory(dns, prefix string) *directory { - return &directory{prefix: prefix, dns: dns} -} - -// Link captures the link type. -type Link int - -const ( - // NewNonceLink new-nonce - NewNonceLink Link = iota - // NewAccountLink new-account - NewAccountLink - // AccountLink account - AccountLink - // OrderLink order - OrderLink - // NewOrderLink new-order - NewOrderLink - // OrdersByAccountLink list of orders owned by account - OrdersByAccountLink - // FinalizeLink finalize order - FinalizeLink - // NewAuthzLink authz - NewAuthzLink - // AuthzLink new-authz - AuthzLink - // ChallengeLink challenge - ChallengeLink - // CertificateLink certificate - CertificateLink - // DirectoryLink directory - DirectoryLink - // RevokeCertLink revoke certificate - RevokeCertLink - // KeyChangeLink key rollover - KeyChangeLink -) - -func (l Link) String() string { - switch l { - case NewNonceLink: - return "new-nonce" - case NewAccountLink: - return "new-account" - case AccountLink: - return "account" - case NewOrderLink: - return "new-order" - case OrderLink: - return "order" - case NewAuthzLink: - return "new-authz" - case AuthzLink: - return "authz" - case ChallengeLink: - return "challenge" - case CertificateLink: - return "certificate" - case DirectoryLink: - return "directory" - case RevokeCertLink: - return "revoke-cert" - case KeyChangeLink: - return "key-change" - default: - return "unexpected" - } -} - -func (d *directory) getLink(ctx context.Context, typ Link, abs bool, inputs ...string) string { - var provName string - if p, err := ProvisionerFromContext(ctx); err == nil && p != nil { - provName = p.GetName() - } - return d.getLinkExplicit(typ, provName, abs, BaseURLFromContext(ctx), inputs...) -} - -// getLinkExplicit returns an absolute or partial path to the given resource and a base -// URL dynamically obtained from the request for which the link is being -// calculated. -func (d *directory) getLinkExplicit(typ Link, provisionerName string, abs bool, baseURL *url.URL, inputs ...string) string { - var link string - switch typ { - case NewNonceLink, NewAccountLink, NewOrderLink, NewAuthzLink, DirectoryLink, KeyChangeLink, RevokeCertLink: - link = fmt.Sprintf("/%s/%s", provisionerName, typ.String()) - case AccountLink, OrderLink, AuthzLink, ChallengeLink, CertificateLink: - link = fmt.Sprintf("/%s/%s/%s", provisionerName, typ.String(), inputs[0]) - case OrdersByAccountLink: - link = fmt.Sprintf("/%s/%s/%s/orders", provisionerName, AccountLink.String(), inputs[0]) - case FinalizeLink: - link = fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLink.String(), inputs[0]) - } - - if abs { - // Copy the baseURL value from the pointer. https://github.com/golang/go/issues/38351 - u := url.URL{} - if baseURL != nil { - u = *baseURL - } - - // If no Scheme is set, then default to https. - if u.Scheme == "" { - u.Scheme = "https" - } - - // If no Host is set, then use the default (first DNS attr in the ca.json). - if u.Host == "" { - u.Host = d.dns - } - - u.Path = d.prefix + link - return u.String() - } - return link -} diff --git a/acme/errors.go b/acme/errors.go index 8fe2559d..41305c87 100644 --- a/acme/errors.go +++ b/acme/errors.go @@ -262,7 +262,7 @@ var ( // Error represents an ACME type Error struct { Type string `json:"type"` - Details string `json:"detail"` + Detail string `json:"detail"` Subproblems []interface{} `json:"subproblems,omitempty"` Identifier interface{} `json:"identifier,omitempty"` Err error `json:"-"` @@ -275,18 +275,18 @@ func NewError(pt ProblemType, msg string, args ...interface{}) *Error { if !ok { meta = errorServerInternalMetadata return &Error{ - Type: meta.typ, - Details: meta.details, - Status: meta.status, - Err: errors.Errorf("unrecognized problemType %v", pt), + Type: meta.typ, + Detail: meta.details, + Status: meta.status, + Err: errors.Errorf("unrecognized problemType %v", pt), } } return &Error{ - Type: meta.typ, - Details: meta.details, - Status: meta.status, - Err: errors.Errorf(msg, args...), + Type: meta.typ, + Detail: meta.details, + Status: meta.status, + Err: errors.Errorf(msg, args...), } } @@ -295,14 +295,14 @@ func NewErrorISE(msg string, args ...interface{}) *Error { return NewError(ErrorServerInternalType, msg, args...) } -// ErrorWrap attempts to wrap the internal error. -func ErrorWrap(typ ProblemType, err error, msg string, args ...interface{}) *Error { +// WrapError attempts to wrap the internal error. +func WrapError(typ ProblemType, err error, msg string, args ...interface{}) *Error { switch e := err.(type) { case nil: return nil case *Error: if e.Err == nil { - e.Err = errors.Errorf(msg+"; "+e.Details, args...) + e.Err = errors.Errorf(msg+"; "+e.Detail, args...) } else { e.Err = errors.Wrapf(e.Err, msg, args...) } @@ -312,9 +312,9 @@ func ErrorWrap(typ ProblemType, err error, msg string, args ...interface{}) *Err } } -// ErrorISEWrap shortcut to wrap an internal server error type. -func ErrorISEWrap(err error, msg string, args ...interface{}) *Error { - return ErrorWrap(ErrorServerInternalType, err, msg, args...) +// WrapErrorISE shortcut to wrap an internal server error type. +func WrapErrorISE(err error, msg string, args ...interface{}) *Error { + return WrapError(ErrorServerInternalType, err, msg, args...) } // StatusCode returns the status code and implements the StatusCoder interface. @@ -324,13 +324,13 @@ func (e *Error) StatusCode() int { // Error allows AError to implement the error interface. func (e *Error) Error() string { - return e.Details + return e.Detail } // Cause returns the internal error and implements the Causer interface. func (e *Error) Cause() error { if e.Err == nil { - return errors.New(e.Details) + return errors.New(e.Detail) } return e.Err } diff --git a/acme/nonce.go b/acme/nonce.go index 4234e818..25c86360 100644 --- a/acme/nonce.go +++ b/acme/nonce.go @@ -1,3 +1,9 @@ package acme +// Nonce represents an ACME nonce type. type Nonce string + +// String implements the ToString interface. +func (n Nonce) String() string { + return string(n) +} diff --git a/acme/order.go b/acme/order.go index bf3297f9..1719d899 100644 --- a/acme/order.go +++ b/acme/order.go @@ -26,10 +26,11 @@ type Order struct { NotBefore time.Time `json:"notBefore,omitempty"` NotAfter time.Time `json:"notAfter,omitempty"` Error interface{} `json:"error,omitempty"` - AuthorizationURLs []string `json:"authorizations"` AuthorizationIDs []string `json:"-"` + AuthorizationURLs []string `json:"authorizations"` FinalizeURL string `json:"finalize"` - Certificate string `json:"certificate,omitempty"` + CertificateID string `json:"-"` + CertificateURL string `json:"certificate,omitempty"` ID string `json:"-"` AccountID string `json:"-"` ProvisionerID string `json:"-"` @@ -41,7 +42,7 @@ type Order struct { func (o *Order) ToLog() (interface{}, error) { b, err := json.Marshal(o) if err != nil { - return nil, ErrorISEWrap(err, "error marshaling order for logging") + return nil, WrapErrorISE(err, "error marshaling order for logging") } return string(b), nil } @@ -111,7 +112,7 @@ func (o *Order) UpdateStatus(ctx context.Context, db DB) error { // Finalize signs a certificate if the necessary conditions for Order completion // have been met. -func (o *Order) Finalize(ctx context.Context, db DB, csr *x509.CertificateRequest, auth SignAuthority, p Provisioner) error { +func (o *Order) Finalize(ctx context.Context, db DB, csr *x509.CertificateRequest, auth CertificateAuthority, p Provisioner) error { if err := o.UpdateStatus(ctx, db); err != nil { return err } @@ -170,7 +171,7 @@ func (o *Order) Finalize(ctx context.Context, db DB, csr *x509.CertificateReques ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod) signOps, err := p.AuthorizeSign(ctx, "") if err != nil { - return ErrorISEWrap(err, "error retrieving authorization options from ACME provisioner") + return WrapErrorISE(err, "error retrieving authorization options from ACME provisioner") } // Template data @@ -180,7 +181,7 @@ func (o *Order) Finalize(ctx context.Context, db DB, csr *x509.CertificateReques templateOptions, err := provisioner.TemplateOptions(p.GetOptions(), data) if err != nil { - return ErrorISEWrap(err, "error creating template options from ACME provisioner") + return WrapErrorISE(err, "error creating template options from ACME provisioner") } signOps = append(signOps, templateOptions) @@ -190,7 +191,7 @@ func (o *Order) Finalize(ctx context.Context, db DB, csr *x509.CertificateReques NotAfter: provisioner.NewTimeDuration(o.NotAfter), }, signOps...) if err != nil { - return ErrorISEWrap(err, "error signing certificate for order %s", o.ID) + return WrapErrorISE(err, "error signing certificate for order %s", o.ID) } cert := &Certificate{ @@ -203,7 +204,7 @@ func (o *Order) Finalize(ctx context.Context, db DB, csr *x509.CertificateReques return err } - o.Certificate = cert.ID + o.CertificateID = cert.ID o.Status = StatusValid return db.UpdateOrder(ctx, o) } diff --git a/ca/acmeClient.go b/ca/acmeClient.go index deb8a3a2..b19ad664 100644 --- a/ca/acmeClient.go +++ b/ca/acmeClient.go @@ -21,7 +21,7 @@ import ( type ACMEClient struct { client *http.Client dirLoc string - dir *acme.Directory + dir *acmeAPI.Directory acc *acme.Account Key *jose.JSONWebKey kid string @@ -53,7 +53,7 @@ func NewACMEClient(endpoint string, contact []string, opts ...ClientOption) (*AC if resp.StatusCode >= 400 { return nil, readACMEError(resp.Body) } - var dir acme.Directory + var dir acmeAPI.Directory if err := readJSON(resp.Body, &dir); err != nil { return nil, errors.Wrapf(err, "error reading %s", endpoint) } @@ -93,7 +93,7 @@ func NewACMEClient(endpoint string, contact []string, opts ...ClientOption) (*AC // GetDirectory makes a directory request to the ACME api and returns an // ACME directory object. -func (c *ACMEClient) GetDirectory() (*acme.Directory, error) { +func (c *ACMEClient) GetDirectory() (*acmeAPI.Directory, error) { return c.dir, nil } @@ -231,7 +231,7 @@ func (c *ACMEClient) ValidateChallenge(url string) error { } // GetAuthz returns the Authz at the given path. -func (c *ACMEClient) GetAuthz(url string) (*acme.Authz, error) { +func (c *ACMEClient) GetAuthz(url string) (*acme.Authorization, error) { resp, err := c.post(nil, url, withKid(c)) if err != nil { return nil, err @@ -240,7 +240,7 @@ func (c *ACMEClient) GetAuthz(url string) (*acme.Authz, error) { return nil, readACMEError(resp.Body) } - var az acme.Authz + var az acme.Authorization if err := readJSON(resp.Body, &az); err != nil { return nil, errors.Wrapf(err, "error reading %s", url) } @@ -342,7 +342,7 @@ func readACMEError(r io.ReadCloser) error { if err != nil { return errors.Wrap(err, "error reading from body") } - ae := new(acme.AError) + ae := new(acme.Error) err = json.Unmarshal(b, &ae) // If we successfully marshaled to an ACMEError then return the ACMEError. if err != nil || len(ae.Error()) == 0 { diff --git a/ca/ca.go b/ca/ca.go index 5ba81e9e..5ebc0919 100644 --- a/ca/ca.go +++ b/ca/ca.go @@ -11,8 +11,8 @@ import ( "github.com/go-chi/chi" "github.com/pkg/errors" - "github.com/smallstep/certificates/acme" acmeAPI "github.com/smallstep/certificates/acme/api" + acmeNoSQL "github.com/smallstep/certificates/acme/db/nosql" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/db" @@ -124,11 +124,12 @@ func (ca *CA) Init(config *authority.Config) (*CA, error) { } prefix := "acme" - acmeAuth, err := acme.New(auth, acme.AuthorityOptions{ + acmeAuth, err := acmeAPI.NewHandler(acmeAPI.HandlerOptions{ Backdate: *config.AuthorityConfig.Backdate, - DB: auth.GetDatabase().(nosql.DB), + DB: acmeNoSQL.New(auth.GetDatabase().(nosql.DB)), DNS: dns, Prefix: prefix, + CA: auth, }) if err != nil { return nil, errors.Wrap(err, "error creating ACME authority")