From 9628fa356224902ba9527a41d0ec99a627cd3344 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Tue, 26 Apr 2022 12:54:54 -0700 Subject: [PATCH 01/40] Add methods to store and retrieve an authority from the context. --- authority/authority.go | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/authority/authority.go b/authority/authority.go index 9db38e14..091a01ae 100644 --- a/authority/authority.go +++ b/authority/authority.go @@ -7,6 +7,7 @@ import ( "crypto/x509" "encoding/hex" "log" + "net/http" "strings" "sync" "time" @@ -153,6 +154,27 @@ func NewEmbedded(opts ...Option) (*Authority, error) { return a, nil } +type authorityKey struct{} + +// NewContext adds the given authority to the context. +func NewContext(ctx context.Context, a *Authority) context.Context { + return context.WithValue(ctx, authorityKey{}, a) +} + +// FromContext returns the current authority from the given context. +func FromContext(ctx context.Context) (a *Authority, ok bool) { + a, ok = ctx.Value(authorityKey{}).(*Authority) + return +} + +// Middleware adds the current authority to the request context. +func (a *Authority) Middleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := NewContext(r.Context(), a) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + // reloadAdminResources reloads admins and provisioners from the DB. func (a *Authority) reloadAdminResources(ctx context.Context) error { var ( From 900a640f016981b2b9d13bbbb5626c6532dff35a Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Tue, 26 Apr 2022 12:55:28 -0700 Subject: [PATCH 02/40] Enable the authority middleware in the server --- ca/ca.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ca/ca.go b/ca/ca.go index 0380d166..bb8e15ac 100644 --- a/ca/ca.go +++ b/ca/ca.go @@ -279,6 +279,10 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) { insecureHandler = logger.Middleware(insecureHandler) } + // Add authority handler + handler = auth.Middleware(handler) + insecureHandler = auth.Middleware(insecureHandler) + ca.srv = server.New(cfg.Address, handler, tlsConfig) // only start the insecure server if the insecure address is configured From a6b8e65d69f7703108c65ec8b6ef221692dea5df Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Tue, 26 Apr 2022 12:58:40 -0700 Subject: [PATCH 03/40] Retrieve the authority from the context in api methods. --- api/api.go | 94 +++++++++++++++++++++++++++--------------------- api/rekey.go | 7 ++-- api/renew.go | 14 ++++---- api/revoke.go | 11 +++--- api/sign.go | 9 ++--- api/ssh.go | 44 +++++++++++++---------- api/sshRekey.go | 10 +++--- api/sshRenew.go | 13 +++---- api/sshRevoke.go | 9 +++-- 9 files changed, 121 insertions(+), 90 deletions(-) diff --git a/api/api.go b/api/api.go index da6309fd..9b795cf0 100644 --- a/api/api.go +++ b/api/api.go @@ -52,6 +52,16 @@ type Authority interface { Version() authority.Version } +var errAuthority = errors.New("authority is not in context") + +func mustAuthority(ctx context.Context) Authority { + a, ok := authority.FromContext(ctx) + if !ok { + panic(errAuthority) + } + return a +} + // TimeDuration is an alias of provisioner.TimeDuration type TimeDuration = provisioner.TimeDuration @@ -251,40 +261,40 @@ func New(auth Authority) RouterHandler { } func (h *caHandler) Route(r Router) { - r.MethodFunc("GET", "/version", h.Version) - r.MethodFunc("GET", "/health", h.Health) - r.MethodFunc("GET", "/root/{sha}", h.Root) - r.MethodFunc("POST", "/sign", h.Sign) - r.MethodFunc("POST", "/renew", h.Renew) - r.MethodFunc("POST", "/rekey", h.Rekey) - r.MethodFunc("POST", "/revoke", h.Revoke) - r.MethodFunc("GET", "/provisioners", h.Provisioners) - r.MethodFunc("GET", "/provisioners/{kid}/encrypted-key", h.ProvisionerKey) - r.MethodFunc("GET", "/roots", h.Roots) - r.MethodFunc("GET", "/roots.pem", h.RootsPEM) - r.MethodFunc("GET", "/federation", h.Federation) + r.MethodFunc("GET", "/version", Version) + r.MethodFunc("GET", "/health", Health) + r.MethodFunc("GET", "/root/{sha}", Root) + r.MethodFunc("POST", "/sign", Sign) + r.MethodFunc("POST", "/renew", Renew) + r.MethodFunc("POST", "/rekey", Rekey) + r.MethodFunc("POST", "/revoke", Revoke) + r.MethodFunc("GET", "/provisioners", Provisioners) + r.MethodFunc("GET", "/provisioners/{kid}/encrypted-key", ProvisionerKey) + r.MethodFunc("GET", "/roots", Roots) + r.MethodFunc("GET", "/roots.pem", RootsPEM) + r.MethodFunc("GET", "/federation", Federation) // SSH CA - r.MethodFunc("POST", "/ssh/sign", h.SSHSign) - r.MethodFunc("POST", "/ssh/renew", h.SSHRenew) - r.MethodFunc("POST", "/ssh/revoke", h.SSHRevoke) - r.MethodFunc("POST", "/ssh/rekey", h.SSHRekey) - r.MethodFunc("GET", "/ssh/roots", h.SSHRoots) - r.MethodFunc("GET", "/ssh/federation", h.SSHFederation) - r.MethodFunc("POST", "/ssh/config", h.SSHConfig) - r.MethodFunc("POST", "/ssh/config/{type}", h.SSHConfig) - r.MethodFunc("POST", "/ssh/check-host", h.SSHCheckHost) - r.MethodFunc("GET", "/ssh/hosts", h.SSHGetHosts) - r.MethodFunc("POST", "/ssh/bastion", h.SSHBastion) + r.MethodFunc("POST", "/ssh/sign", SSHSign) + r.MethodFunc("POST", "/ssh/renew", SSHRenew) + r.MethodFunc("POST", "/ssh/revoke", SSHRevoke) + r.MethodFunc("POST", "/ssh/rekey", SSHRekey) + r.MethodFunc("GET", "/ssh/roots", SSHRoots) + r.MethodFunc("GET", "/ssh/federation", SSHFederation) + r.MethodFunc("POST", "/ssh/config", SSHConfig) + r.MethodFunc("POST", "/ssh/config/{type}", SSHConfig) + r.MethodFunc("POST", "/ssh/check-host", SSHCheckHost) + r.MethodFunc("GET", "/ssh/hosts", SSHGetHosts) + r.MethodFunc("POST", "/ssh/bastion", SSHBastion) // For compatibility with old code: - r.MethodFunc("POST", "/re-sign", h.Renew) - r.MethodFunc("POST", "/sign-ssh", h.SSHSign) - r.MethodFunc("GET", "/ssh/get-hosts", h.SSHGetHosts) + r.MethodFunc("POST", "/re-sign", Renew) + r.MethodFunc("POST", "/sign-ssh", SSHSign) + r.MethodFunc("GET", "/ssh/get-hosts", SSHGetHosts) } // Version is an HTTP handler that returns the version of the server. -func (h *caHandler) Version(w http.ResponseWriter, r *http.Request) { - v := h.Authority.Version() +func Version(w http.ResponseWriter, r *http.Request) { + v := mustAuthority(r.Context()).Version() render.JSON(w, VersionResponse{ Version: v.Version, RequireClientAuthentication: v.RequireClientAuthentication, @@ -292,17 +302,17 @@ func (h *caHandler) Version(w http.ResponseWriter, r *http.Request) { } // Health is an HTTP handler that returns the status of the server. -func (h *caHandler) Health(w http.ResponseWriter, r *http.Request) { +func Health(w http.ResponseWriter, r *http.Request) { render.JSON(w, HealthResponse{Status: "ok"}) } // Root is an HTTP handler that using the SHA256 from the URL, returns the root // certificate for the given SHA256. -func (h *caHandler) Root(w http.ResponseWriter, r *http.Request) { +func Root(w http.ResponseWriter, r *http.Request) { sha := chi.URLParam(r, "sha") sum := strings.ToLower(strings.ReplaceAll(sha, "-", "")) // Load root certificate with the - cert, err := h.Authority.Root(sum) + cert, err := mustAuthority(r.Context()).Root(sum) if err != nil { render.Error(w, errs.Wrapf(http.StatusNotFound, err, "%s was not found", r.RequestURI)) return @@ -320,18 +330,19 @@ func certChainToPEM(certChain []*x509.Certificate) []Certificate { } // Provisioners returns the list of provisioners configured in the authority. -func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) { +func Provisioners(w http.ResponseWriter, r *http.Request) { cursor, limit, err := ParseCursor(r) if err != nil { render.Error(w, err) return } - p, next, err := h.Authority.GetProvisioners(cursor, limit) + p, next, err := mustAuthority(r.Context()).GetProvisioners(cursor, limit) if err != nil { render.Error(w, errs.InternalServerErr(err)) return } + render.JSON(w, &ProvisionersResponse{ Provisioners: p, NextCursor: next, @@ -339,19 +350,20 @@ func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) { } // ProvisionerKey returns the encrypted key of a provisioner by it's key id. -func (h *caHandler) ProvisionerKey(w http.ResponseWriter, r *http.Request) { +func ProvisionerKey(w http.ResponseWriter, r *http.Request) { kid := chi.URLParam(r, "kid") - key, err := h.Authority.GetEncryptedKey(kid) + key, err := mustAuthority(r.Context()).GetEncryptedKey(kid) if err != nil { render.Error(w, errs.NotFoundErr(err)) return } + render.JSON(w, &ProvisionerKeyResponse{key}) } // Roots returns all the root certificates for the CA. -func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) { - roots, err := h.Authority.GetRoots() +func Roots(w http.ResponseWriter, r *http.Request) { + roots, err := mustAuthority(r.Context()).GetRoots() if err != nil { render.Error(w, errs.ForbiddenErr(err, "error getting roots")) return @@ -368,8 +380,8 @@ func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) { } // RootsPEM returns all the root certificates for the CA in PEM format. -func (h *caHandler) RootsPEM(w http.ResponseWriter, r *http.Request) { - roots, err := h.Authority.GetRoots() +func RootsPEM(w http.ResponseWriter, r *http.Request) { + roots, err := mustAuthority(r.Context()).GetRoots() if err != nil { render.Error(w, errs.InternalServerErr(err)) return @@ -391,8 +403,8 @@ func (h *caHandler) RootsPEM(w http.ResponseWriter, r *http.Request) { } // Federation returns all the public certificates in the federation. -func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) { - federated, err := h.Authority.GetFederation() +func Federation(w http.ResponseWriter, r *http.Request) { + federated, err := mustAuthority(r.Context()).GetFederation() if err != nil { render.Error(w, errs.ForbiddenErr(err, "error getting federated roots")) return diff --git a/api/rekey.go b/api/rekey.go index 3116cf74..cda843a3 100644 --- a/api/rekey.go +++ b/api/rekey.go @@ -27,7 +27,7 @@ func (s *RekeyRequest) Validate() error { } // Rekey is similar to renew except that the certificate will be renewed with new key from csr. -func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) { +func Rekey(w http.ResponseWriter, r *http.Request) { if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { render.Error(w, errs.BadRequest("missing client certificate")) return @@ -44,7 +44,8 @@ func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) { return } - certChain, err := h.Authority.Rekey(r.TLS.PeerCertificates[0], body.CsrPEM.CertificateRequest.PublicKey) + a := mustAuthority(r.Context()) + certChain, err := a.Rekey(r.TLS.PeerCertificates[0], body.CsrPEM.CertificateRequest.PublicKey) if err != nil { render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Rekey")) return @@ -60,6 +61,6 @@ func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) { ServerPEM: certChainPEM[0], CaPEM: caPEM, CertChainPEM: certChainPEM, - TLSOptions: h.Authority.GetTLSOptions(), + TLSOptions: a.GetTLSOptions(), }, http.StatusCreated) } diff --git a/api/renew.go b/api/renew.go index 9c4bff32..6e9f680f 100644 --- a/api/renew.go +++ b/api/renew.go @@ -16,14 +16,15 @@ const ( // Renew uses the information of certificate in the TLS connection to create a // new one. -func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) { - cert, err := h.getPeerCertificate(r) +func Renew(w http.ResponseWriter, r *http.Request) { + cert, err := getPeerCertificate(r) if err != nil { render.Error(w, err) return } - certChain, err := h.Authority.Renew(cert) + a := mustAuthority(r.Context()) + certChain, err := a.Renew(cert) if err != nil { render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew")) return @@ -39,17 +40,18 @@ func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) { ServerPEM: certChainPEM[0], CaPEM: caPEM, CertChainPEM: certChainPEM, - TLSOptions: h.Authority.GetTLSOptions(), + TLSOptions: a.GetTLSOptions(), }, http.StatusCreated) } -func (h *caHandler) getPeerCertificate(r *http.Request) (*x509.Certificate, error) { +func getPeerCertificate(r *http.Request) (*x509.Certificate, error) { if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 { return r.TLS.PeerCertificates[0], nil } if s := r.Header.Get(authorizationHeader); s != "" { if parts := strings.SplitN(s, bearerScheme+" ", 2); len(parts) == 2 { - return h.Authority.AuthorizeRenewToken(r.Context(), parts[1]) + ctx := r.Context() + return mustAuthority(ctx).AuthorizeRenewToken(ctx, parts[1]) } } return nil, errs.BadRequest("missing client certificate") diff --git a/api/revoke.go b/api/revoke.go index c9da2c18..aebbb875 100644 --- a/api/revoke.go +++ b/api/revoke.go @@ -1,7 +1,6 @@ package api import ( - "context" "net/http" "golang.org/x/crypto/ocsp" @@ -49,7 +48,7 @@ func (r *RevokeRequest) Validate() (err error) { // NOTE: currently only Passive revocation is supported. // // TODO: Add CRL and OCSP support. -func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) { +func Revoke(w http.ResponseWriter, r *http.Request) { var body RevokeRequest if err := read.JSON(r.Body, &body); err != nil { render.Error(w, errs.BadRequestErr(err, "error reading request body")) @@ -68,12 +67,14 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) { PassiveOnly: body.Passive, } - ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.RevokeMethod) + ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.RevokeMethod) + a := mustAuthority(ctx) + // A token indicates that we are using the api via a provisioner token, // otherwise it is assumed that the certificate is revoking itself over mTLS. if len(body.OTT) > 0 { logOtt(w, body.OTT) - if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil { + if _, err := a.Authorize(ctx, body.OTT); err != nil { render.Error(w, errs.UnauthorizedErr(err)) return } @@ -98,7 +99,7 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) { opts.MTLS = true } - if err := h.Authority.Revoke(ctx, opts); err != nil { + if err := a.Revoke(ctx, opts); err != nil { render.Error(w, errs.ForbiddenErr(err, "error revoking certificate")) return } diff --git a/api/sign.go b/api/sign.go index b6bfcc8b..b263e2e9 100644 --- a/api/sign.go +++ b/api/sign.go @@ -49,7 +49,7 @@ type SignResponse struct { // Sign is an HTTP handler that reads a certificate request and an // one-time-token (ott) from the body and creates a new certificate with the // information in the certificate request. -func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) { +func Sign(w http.ResponseWriter, r *http.Request) { var body SignRequest if err := read.JSON(r.Body, &body); err != nil { render.Error(w, errs.BadRequestErr(err, "error reading request body")) @@ -68,13 +68,14 @@ func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) { TemplateData: body.TemplateData, } - signOpts, err := h.Authority.AuthorizeSign(body.OTT) + a := mustAuthority(r.Context()) + signOpts, err := a.AuthorizeSign(body.OTT) if err != nil { render.Error(w, errs.UnauthorizedErr(err)) return } - certChain, err := h.Authority.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...) + certChain, err := a.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...) if err != nil { render.Error(w, errs.ForbiddenErr(err, "error signing certificate")) return @@ -89,6 +90,6 @@ func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) { ServerPEM: certChainPEM[0], CaPEM: caPEM, CertChainPEM: certChainPEM, - TLSOptions: h.Authority.GetTLSOptions(), + TLSOptions: a.GetTLSOptions(), }, http.StatusCreated) } diff --git a/api/ssh.go b/api/ssh.go index 3b0de7c1..f3056fc5 100644 --- a/api/ssh.go +++ b/api/ssh.go @@ -250,7 +250,7 @@ type SSHBastionResponse struct { // SSHSign is an HTTP handler that reads an SignSSHRequest with a one-time-token // (ott) from the body and creates a new SSH certificate with the information in // the request. -func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { +func SSHSign(w http.ResponseWriter, r *http.Request) { var body SSHSignRequest if err := read.JSON(r.Body, &body); err != nil { render.Error(w, errs.BadRequestErr(err, "error reading request body")) @@ -288,13 +288,15 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { } ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHSignMethod) - signOpts, err := h.Authority.Authorize(ctx, body.OTT) + + a := mustAuthority(ctx) + signOpts, err := a.Authorize(ctx, body.OTT) if err != nil { render.Error(w, errs.UnauthorizedErr(err)) return } - cert, err := h.Authority.SignSSH(ctx, publicKey, opts, signOpts...) + cert, err := a.SignSSH(ctx, publicKey, opts, signOpts...) if err != nil { render.Error(w, errs.ForbiddenErr(err, "error signing ssh certificate")) return @@ -302,7 +304,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { var addUserCertificate *SSHCertificate if addUserPublicKey != nil && authority.IsValidForAddUser(cert) == nil { - addUserCert, err := h.Authority.SignSSHAddUser(ctx, addUserPublicKey, cert) + addUserCert, err := a.SignSSHAddUser(ctx, addUserPublicKey, cert) if err != nil { render.Error(w, errs.ForbiddenErr(err, "error signing ssh certificate")) return @@ -315,7 +317,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { if cr := body.IdentityCSR.CertificateRequest; cr != nil { ctx := authority.NewContextWithSkipTokenReuse(r.Context()) ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod) - signOpts, err := h.Authority.Authorize(ctx, body.OTT) + signOpts, err := a.Authorize(ctx, body.OTT) if err != nil { render.Error(w, errs.UnauthorizedErr(err)) return @@ -327,7 +329,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { NotAfter: time.Unix(int64(cert.ValidBefore), 0), }) - certChain, err := h.Authority.Sign(cr, provisioner.SignOptions{}, signOpts...) + certChain, err := a.Sign(cr, provisioner.SignOptions{}, signOpts...) if err != nil { render.Error(w, errs.ForbiddenErr(err, "error signing identity certificate")) return @@ -344,8 +346,9 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { // SSHRoots is an HTTP handler that returns the SSH public keys for user and host // certificates. -func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) { - keys, err := h.Authority.GetSSHRoots(r.Context()) +func SSHRoots(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + keys, err := mustAuthority(ctx).GetSSHRoots(ctx) if err != nil { render.Error(w, errs.InternalServerErr(err)) return @@ -369,8 +372,9 @@ func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) { // SSHFederation is an HTTP handler that returns the federated SSH public keys // for user and host certificates. -func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) { - keys, err := h.Authority.GetSSHFederation(r.Context()) +func SSHFederation(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + keys, err := mustAuthority(ctx).GetSSHFederation(ctx) if err != nil { render.Error(w, errs.InternalServerErr(err)) return @@ -394,7 +398,7 @@ func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) { // SSHConfig is an HTTP handler that returns rendered templates for ssh clients // and servers. -func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) { +func SSHConfig(w http.ResponseWriter, r *http.Request) { var body SSHConfigRequest if err := read.JSON(r.Body, &body); err != nil { render.Error(w, errs.BadRequestErr(err, "error reading request body")) @@ -405,7 +409,8 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) { return } - ts, err := h.Authority.GetSSHConfig(r.Context(), body.Type, body.Data) + ctx := r.Context() + ts, err := mustAuthority(ctx).GetSSHConfig(ctx, body.Type, body.Data) if err != nil { render.Error(w, errs.InternalServerErr(err)) return @@ -426,7 +431,7 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) { } // SSHCheckHost is the HTTP handler that returns if a hosts certificate exists or not. -func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) { +func SSHCheckHost(w http.ResponseWriter, r *http.Request) { var body SSHCheckPrincipalRequest if err := read.JSON(r.Body, &body); err != nil { render.Error(w, errs.BadRequestErr(err, "error reading request body")) @@ -437,7 +442,8 @@ func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) { return } - exists, err := h.Authority.CheckSSHHost(r.Context(), body.Principal, body.Token) + ctx := r.Context() + exists, err := mustAuthority(ctx).CheckSSHHost(ctx, body.Principal, body.Token) if err != nil { render.Error(w, errs.InternalServerErr(err)) return @@ -448,13 +454,14 @@ func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) { } // SSHGetHosts is the HTTP handler that returns a list of valid ssh hosts. -func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) { +func SSHGetHosts(w http.ResponseWriter, r *http.Request) { var cert *x509.Certificate if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 { cert = r.TLS.PeerCertificates[0] } - hosts, err := h.Authority.GetSSHHosts(r.Context(), cert) + ctx := r.Context() + hosts, err := mustAuthority(ctx).GetSSHHosts(ctx, cert) if err != nil { render.Error(w, errs.InternalServerErr(err)) return @@ -465,7 +472,7 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) { } // SSHBastion provides returns the bastion configured if any. -func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) { +func SSHBastion(w http.ResponseWriter, r *http.Request) { var body SSHBastionRequest if err := read.JSON(r.Body, &body); err != nil { render.Error(w, errs.BadRequestErr(err, "error reading request body")) @@ -476,7 +483,8 @@ func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) { return } - bastion, err := h.Authority.GetSSHBastion(r.Context(), body.User, body.Hostname) + ctx := r.Context() + bastion, err := mustAuthority(ctx).GetSSHBastion(ctx, body.User, body.Hostname) if err != nil { render.Error(w, errs.InternalServerErr(err)) return diff --git a/api/sshRekey.go b/api/sshRekey.go index 92278950..184f208a 100644 --- a/api/sshRekey.go +++ b/api/sshRekey.go @@ -39,7 +39,7 @@ type SSHRekeyResponse struct { // SSHRekey is an HTTP handler that reads an RekeySSHRequest with a one-time-token // (ott) from the body and creates a new SSH certificate with the information in // the request. -func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) { +func SSHRekey(w http.ResponseWriter, r *http.Request) { var body SSHRekeyRequest if err := read.JSON(r.Body, &body); err != nil { render.Error(w, errs.BadRequestErr(err, "error reading request body")) @@ -59,7 +59,9 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) { } ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRekeyMethod) - signOpts, err := h.Authority.Authorize(ctx, body.OTT) + + a := mustAuthority(ctx) + signOpts, err := a.Authorize(ctx, body.OTT) if err != nil { render.Error(w, errs.UnauthorizedErr(err)) return @@ -70,7 +72,7 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) { return } - newCert, err := h.Authority.RekeySSH(ctx, oldCert, publicKey, signOpts...) + newCert, err := a.RekeySSH(ctx, oldCert, publicKey, signOpts...) if err != nil { render.Error(w, errs.ForbiddenErr(err, "error rekeying ssh certificate")) return @@ -80,7 +82,7 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) { notBefore := time.Unix(int64(oldCert.ValidAfter), 0) notAfter := time.Unix(int64(oldCert.ValidBefore), 0) - identity, err := h.renewIdentityCertificate(r, notBefore, notAfter) + identity, err := renewIdentityCertificate(r, notBefore, notAfter) if err != nil { render.Error(w, errs.ForbiddenErr(err, "error renewing identity certificate")) return diff --git a/api/sshRenew.go b/api/sshRenew.go index 78d16fa6..606b45bb 100644 --- a/api/sshRenew.go +++ b/api/sshRenew.go @@ -37,7 +37,7 @@ type SSHRenewResponse struct { // SSHRenew is an HTTP handler that reads an RenewSSHRequest with a one-time-token // (ott) from the body and creates a new SSH certificate with the information in // the request. -func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) { +func SSHRenew(w http.ResponseWriter, r *http.Request) { var body SSHRenewRequest if err := read.JSON(r.Body, &body); err != nil { render.Error(w, errs.BadRequestErr(err, "error reading request body")) @@ -51,7 +51,8 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) { } ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRenewMethod) - _, err := h.Authority.Authorize(ctx, body.OTT) + a := mustAuthority(ctx) + _, err := a.Authorize(ctx, body.OTT) if err != nil { render.Error(w, errs.UnauthorizedErr(err)) return @@ -62,7 +63,7 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) { return } - newCert, err := h.Authority.RenewSSH(ctx, oldCert) + newCert, err := a.RenewSSH(ctx, oldCert) if err != nil { render.Error(w, errs.ForbiddenErr(err, "error renewing ssh certificate")) return @@ -72,7 +73,7 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) { notBefore := time.Unix(int64(oldCert.ValidAfter), 0) notAfter := time.Unix(int64(oldCert.ValidBefore), 0) - identity, err := h.renewIdentityCertificate(r, notBefore, notAfter) + identity, err := renewIdentityCertificate(r, notBefore, notAfter) if err != nil { render.Error(w, errs.ForbiddenErr(err, "error renewing identity certificate")) return @@ -85,7 +86,7 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) { } // renewIdentityCertificate request the client TLS certificate if present. If notBefore and notAfter are passed the -func (h *caHandler) renewIdentityCertificate(r *http.Request, notBefore, notAfter time.Time) ([]Certificate, error) { +func renewIdentityCertificate(r *http.Request, notBefore, notAfter time.Time) ([]Certificate, error) { if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { return nil, nil } @@ -105,7 +106,7 @@ func (h *caHandler) renewIdentityCertificate(r *http.Request, notBefore, notAfte cert.NotAfter = notAfter } - certChain, err := h.Authority.Renew(cert) + certChain, err := mustAuthority(r.Context()).Renew(cert) if err != nil { return nil, err } diff --git a/api/sshRevoke.go b/api/sshRevoke.go index a33082cd..d377def9 100644 --- a/api/sshRevoke.go +++ b/api/sshRevoke.go @@ -48,7 +48,7 @@ func (r *SSHRevokeRequest) Validate() (err error) { // Revoke supports handful of different methods that revoke a Certificate. // // NOTE: currently only Passive revocation is supported. -func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) { +func SSHRevoke(w http.ResponseWriter, r *http.Request) { var body SSHRevokeRequest if err := read.JSON(r.Body, &body); err != nil { render.Error(w, errs.BadRequestErr(err, "error reading request body")) @@ -68,16 +68,19 @@ func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) { } ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRevokeMethod) + a := mustAuthority(ctx) + // A token indicates that we are using the api via a provisioner token, // otherwise it is assumed that the certificate is revoking itself over mTLS. logOtt(w, body.OTT) - if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil { + + if _, err := a.Authorize(ctx, body.OTT); err != nil { render.Error(w, errs.UnauthorizedErr(err)) return } opts.OTT = body.OTT - if err := h.Authority.Revoke(ctx, opts); err != nil { + if err := a.Revoke(ctx, opts); err != nil { render.Error(w, errs.ForbiddenErr(err, "error revoking ssh certificate")) return } From a93653ea8e2d1c2274784a719e8665660a69b574 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Tue, 26 Apr 2022 14:32:55 -0700 Subject: [PATCH 04/40] Use api.Route instead of the caHandler. --- api/api.go | 15 +++++++++------ ca/ca.go | 5 ++--- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/api/api.go b/api/api.go index 9b795cf0..2137e29a 100644 --- a/api/api.go +++ b/api/api.go @@ -249,18 +249,21 @@ type FederationResponse struct { } // caHandler is the type used to implement the different CA HTTP endpoints. -type caHandler struct { - Authority Authority +type caHandler struct{} + +// Route configures the http request router. +func (h *caHandler) Route(r Router) { + Route(r) } // New creates a new RouterHandler with the CA endpoints. +// +// Deprecated: Use api.Route(r Router) func New(auth Authority) RouterHandler { - return &caHandler{ - Authority: auth, - } + return &caHandler{} } -func (h *caHandler) Route(r Router) { +func Route(r Router) { r.MethodFunc("GET", "/version", Version) r.MethodFunc("GET", "/health", Health) r.MethodFunc("GET", "/root/{sha}", Root) diff --git a/ca/ca.go b/ca/ca.go index bb8e15ac..24da6311 100644 --- a/ca/ca.go +++ b/ca/ca.go @@ -170,10 +170,9 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) { insecureHandler := http.Handler(insecureMux) // Add regular CA api endpoints in / and /1.0 - routerHandler := api.New(auth) - routerHandler.Route(mux) + api.Route(mux) mux.Route("/1.0", func(r chi.Router) { - routerHandler.Route(r) + api.Route(r) }) //Add ACME api endpoints in /acme and /1.0/acme From 817af3d6965be47ccec02d602cbd0e3a10d8bf59 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Wed, 27 Apr 2022 10:38:53 -0700 Subject: [PATCH 05/40] Fix unit tests on the api package --- api/api.go | 7 +++-- api/api_test.go | 78 +++++++++++++++++++++++++--------------------- api/revoke_test.go | 4 +-- api/ssh_test.go | 56 ++++++++++++++++----------------- 4 files changed, 77 insertions(+), 68 deletions(-) diff --git a/api/api.go b/api/api.go index 2137e29a..e5f4f944 100644 --- a/api/api.go +++ b/api/api.go @@ -54,7 +54,8 @@ type Authority interface { var errAuthority = errors.New("authority is not in context") -func mustAuthority(ctx context.Context) Authority { +// mustAuthority will be replaced on unit tests. +var mustAuthority = func(ctx context.Context) Authority { a, ok := authority.FromContext(ctx) if !ok { panic(errAuthority) @@ -249,7 +250,9 @@ type FederationResponse struct { } // caHandler is the type used to implement the different CA HTTP endpoints. -type caHandler struct{} +type caHandler struct { + Authority Authority +} // Route configures the http request router. func (h *caHandler) Route(r Router) { diff --git a/api/api_test.go b/api/api_test.go index 39c77de7..698b629c 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -171,6 +171,17 @@ func parseCertificateRequest(data string) *x509.CertificateRequest { return csr } +func mockMustAuthority(t *testing.T, a Authority) { + t.Helper() + fn := mustAuthority + t.Cleanup(func() { + mustAuthority = fn + }) + mustAuthority = func(ctx context.Context) Authority { + return a + } +} + type mockAuthority struct { ret1, ret2 interface{} err error @@ -789,11 +800,10 @@ func Test_caHandler_Route(t *testing.T) { } } -func Test_caHandler_Health(t *testing.T) { +func Test_Health(t *testing.T) { req := httptest.NewRequest("GET", "http://example.com/health", nil) w := httptest.NewRecorder() - h := New(&mockAuthority{}).(*caHandler) - h.Health(w, req) + Health(w, req) res := w.Result() if res.StatusCode != 200 { @@ -811,7 +821,7 @@ func Test_caHandler_Health(t *testing.T) { } } -func Test_caHandler_Root(t *testing.T) { +func Test_Root(t *testing.T) { tests := []struct { name string root *x509.Certificate @@ -832,9 +842,9 @@ func Test_caHandler_Root(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ret1: tt.root, err: tt.err}).(*caHandler) + mockMustAuthority(t, &mockAuthority{ret1: tt.root, err: tt.err}) w := httptest.NewRecorder() - h.Root(w, req) + Root(w, req) res := w.Result() if res.StatusCode != tt.statusCode { @@ -855,7 +865,7 @@ func Test_caHandler_Root(t *testing.T) { } } -func Test_caHandler_Sign(t *testing.T) { +func Test_Sign(t *testing.T) { csr := parseCertificateRequest(csrPEM) valid, err := json.Marshal(SignRequest{ CsrPEM: CertificateRequest{csr}, @@ -896,7 +906,7 @@ func Test_caHandler_Sign(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ + mockMustAuthority(t, &mockAuthority{ ret1: tt.cert, ret2: tt.root, err: tt.signErr, authorizeSign: func(ott string) ([]provisioner.SignOption, error) { return tt.certAttrOpts, tt.autherr @@ -904,10 +914,10 @@ func Test_caHandler_Sign(t *testing.T) { getTLSOptions: func() *authority.TLSOptions { return nil }, - }).(*caHandler) + }) req := httptest.NewRequest("POST", "http://example.com/sign", strings.NewReader(tt.input)) w := httptest.NewRecorder() - h.Sign(logging.NewResponseLogger(w), req) + Sign(logging.NewResponseLogger(w), req) res := w.Result() if res.StatusCode != tt.statusCode { @@ -928,7 +938,7 @@ func Test_caHandler_Sign(t *testing.T) { } } -func Test_caHandler_Renew(t *testing.T) { +func Test_Renew(t *testing.T) { cs := &tls.ConnectionState{ PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)}, } @@ -1018,7 +1028,7 @@ func Test_caHandler_Renew(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ + mockMustAuthority(t, &mockAuthority{ ret1: tt.cert, ret2: tt.root, err: tt.err, authorizeRenewToken: func(ctx context.Context, ott string) (*x509.Certificate, error) { jwt, chain, err := jose.ParseX5cInsecure(ott, []*x509.Certificate{tt.root}) @@ -1039,12 +1049,12 @@ func Test_caHandler_Renew(t *testing.T) { getTLSOptions: func() *authority.TLSOptions { return nil }, - }).(*caHandler) + }) req := httptest.NewRequest("POST", "http://example.com/renew", nil) req.TLS = tt.tls req.Header = tt.header w := httptest.NewRecorder() - h.Renew(logging.NewResponseLogger(w), req) + Renew(logging.NewResponseLogger(w), req) res := w.Result() defer res.Body.Close() @@ -1073,7 +1083,7 @@ func Test_caHandler_Renew(t *testing.T) { } } -func Test_caHandler_Rekey(t *testing.T) { +func Test_Rekey(t *testing.T) { cs := &tls.ConnectionState{ PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)}, } @@ -1104,16 +1114,16 @@ func Test_caHandler_Rekey(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ + mockMustAuthority(t, &mockAuthority{ ret1: tt.cert, ret2: tt.root, err: tt.err, getTLSOptions: func() *authority.TLSOptions { return nil }, - }).(*caHandler) + }) req := httptest.NewRequest("POST", "http://example.com/rekey", strings.NewReader(tt.input)) req.TLS = tt.tls w := httptest.NewRecorder() - h.Rekey(logging.NewResponseLogger(w), req) + Rekey(logging.NewResponseLogger(w), req) res := w.Result() if res.StatusCode != tt.statusCode { @@ -1134,7 +1144,7 @@ func Test_caHandler_Rekey(t *testing.T) { } } -func Test_caHandler_Provisioners(t *testing.T) { +func Test_Provisioners(t *testing.T) { type fields struct { Authority Authority } @@ -1200,10 +1210,8 @@ func Test_caHandler_Provisioners(t *testing.T) { assert.FatalError(t, err) for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := &caHandler{ - Authority: tt.fields.Authority, - } - h.Provisioners(tt.args.w, tt.args.r) + mockMustAuthority(t, tt.fields.Authority) + Provisioners(tt.args.w, tt.args.r) rec := tt.args.w.(*httptest.ResponseRecorder) res := rec.Result() @@ -1238,7 +1246,7 @@ func Test_caHandler_Provisioners(t *testing.T) { } } -func Test_caHandler_ProvisionerKey(t *testing.T) { +func Test_ProvisionerKey(t *testing.T) { type fields struct { Authority Authority } @@ -1270,10 +1278,8 @@ func Test_caHandler_ProvisionerKey(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := &caHandler{ - Authority: tt.fields.Authority, - } - h.ProvisionerKey(tt.args.w, tt.args.r) + mockMustAuthority(t, tt.fields.Authority) + ProvisionerKey(tt.args.w, tt.args.r) rec := tt.args.w.(*httptest.ResponseRecorder) res := rec.Result() @@ -1298,7 +1304,7 @@ func Test_caHandler_ProvisionerKey(t *testing.T) { } } -func Test_caHandler_Roots(t *testing.T) { +func Test_Roots(t *testing.T) { cs := &tls.ConnectionState{ PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)}, } @@ -1319,11 +1325,11 @@ func Test_caHandler_Roots(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}).(*caHandler) + mockMustAuthority(t, &mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}) req := httptest.NewRequest("GET", "http://example.com/roots", nil) req.TLS = tt.tls w := httptest.NewRecorder() - h.Roots(w, req) + Roots(w, req) res := w.Result() if res.StatusCode != tt.statusCode { @@ -1360,10 +1366,10 @@ func Test_caHandler_RootsPEM(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ret1: tt.roots, err: tt.err}).(*caHandler) + mockMustAuthority(t, &mockAuthority{ret1: tt.roots, err: tt.err}) req := httptest.NewRequest("GET", "https://example.com/roots", nil) w := httptest.NewRecorder() - h.RootsPEM(w, req) + RootsPEM(w, req) res := w.Result() if res.StatusCode != tt.statusCode { @@ -1384,7 +1390,7 @@ func Test_caHandler_RootsPEM(t *testing.T) { } } -func Test_caHandler_Federation(t *testing.T) { +func Test_Federation(t *testing.T) { cs := &tls.ConnectionState{ PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)}, } @@ -1405,11 +1411,11 @@ func Test_caHandler_Federation(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}).(*caHandler) + mockMustAuthority(t, &mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}) req := httptest.NewRequest("GET", "http://example.com/federation", nil) req.TLS = tt.tls w := httptest.NewRecorder() - h.Federation(w, req) + Federation(w, req) res := w.Result() if res.StatusCode != tt.statusCode { diff --git a/api/revoke_test.go b/api/revoke_test.go index 7635ce68..fa46dd90 100644 --- a/api/revoke_test.go +++ b/api/revoke_test.go @@ -223,13 +223,13 @@ func Test_caHandler_Revoke(t *testing.T) { for name, _tc := range tests { tc := _tc(t) t.Run(name, func(t *testing.T) { - h := New(tc.auth).(*caHandler) + mockMustAuthority(t, tc.auth) req := httptest.NewRequest("POST", "http://example.com/revoke", strings.NewReader(tc.input)) if tc.tls != nil { req.TLS = tc.tls } w := httptest.NewRecorder() - h.Revoke(logging.NewResponseLogger(w), req) + Revoke(logging.NewResponseLogger(w), req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) diff --git a/api/ssh_test.go b/api/ssh_test.go index 88a301f5..c6fee2de 100644 --- a/api/ssh_test.go +++ b/api/ssh_test.go @@ -251,7 +251,7 @@ func TestSignSSHRequest_Validate(t *testing.T) { } } -func Test_caHandler_SSHSign(t *testing.T) { +func Test_SSHSign(t *testing.T) { user, err := getSignedUserCertificate() assert.FatalError(t, err) host, err := getSignedHostCertificate() @@ -315,7 +315,7 @@ func Test_caHandler_SSHSign(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ + mockMustAuthority(t, &mockAuthority{ authorizeSign: func(ott string) ([]provisioner.SignOption, error) { return []provisioner.SignOption{}, tt.authErr }, @@ -328,11 +328,11 @@ func Test_caHandler_SSHSign(t *testing.T) { sign: func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { return tt.tlsSignCerts, tt.tlsSignErr }, - }).(*caHandler) + }) req := httptest.NewRequest("POST", "http://example.com/ssh/sign", bytes.NewReader(tt.req)) w := httptest.NewRecorder() - h.SSHSign(logging.NewResponseLogger(w), req) + SSHSign(logging.NewResponseLogger(w), req) res := w.Result() if res.StatusCode != tt.statusCode { @@ -353,7 +353,7 @@ func Test_caHandler_SSHSign(t *testing.T) { } } -func Test_caHandler_SSHRoots(t *testing.T) { +func Test_SSHRoots(t *testing.T) { user, err := ssh.NewPublicKey(sshUserKey.Public()) assert.FatalError(t, err) userB64 := base64.StdEncoding.EncodeToString(user.Marshal()) @@ -378,15 +378,15 @@ func Test_caHandler_SSHRoots(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ + mockMustAuthority(t, &mockAuthority{ getSSHRoots: func(ctx context.Context) (*authority.SSHKeys, error) { return tt.keys, tt.keysErr }, - }).(*caHandler) + }) req := httptest.NewRequest("GET", "http://example.com/ssh/roots", http.NoBody) w := httptest.NewRecorder() - h.SSHRoots(logging.NewResponseLogger(w), req) + SSHRoots(logging.NewResponseLogger(w), req) res := w.Result() if res.StatusCode != tt.statusCode { @@ -407,7 +407,7 @@ func Test_caHandler_SSHRoots(t *testing.T) { } } -func Test_caHandler_SSHFederation(t *testing.T) { +func Test_SSHFederation(t *testing.T) { user, err := ssh.NewPublicKey(sshUserKey.Public()) assert.FatalError(t, err) userB64 := base64.StdEncoding.EncodeToString(user.Marshal()) @@ -432,15 +432,15 @@ func Test_caHandler_SSHFederation(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ + mockMustAuthority(t, &mockAuthority{ getSSHFederation: func(ctx context.Context) (*authority.SSHKeys, error) { return tt.keys, tt.keysErr }, - }).(*caHandler) + }) req := httptest.NewRequest("GET", "http://example.com/ssh/federation", http.NoBody) w := httptest.NewRecorder() - h.SSHFederation(logging.NewResponseLogger(w), req) + SSHFederation(logging.NewResponseLogger(w), req) res := w.Result() if res.StatusCode != tt.statusCode { @@ -461,7 +461,7 @@ func Test_caHandler_SSHFederation(t *testing.T) { } } -func Test_caHandler_SSHConfig(t *testing.T) { +func Test_SSHConfig(t *testing.T) { userOutput := []templates.Output{ {Name: "config.tpl", Type: templates.File, Comment: "#", Path: "ssh/config", Content: []byte("UserKnownHostsFile /home/user/.step/ssh/known_hosts")}, {Name: "known_host.tpl", Type: templates.File, Comment: "#", Path: "ssh/known_host", Content: []byte("@cert-authority * ecdsa-sha2-nistp256 AAAA...=")}, @@ -492,15 +492,15 @@ func Test_caHandler_SSHConfig(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ + mockMustAuthority(t, &mockAuthority{ getSSHConfig: func(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) { return tt.output, tt.err }, - }).(*caHandler) + }) req := httptest.NewRequest("GET", "http://example.com/ssh/config", strings.NewReader(tt.req)) w := httptest.NewRecorder() - h.SSHConfig(logging.NewResponseLogger(w), req) + SSHConfig(logging.NewResponseLogger(w), req) res := w.Result() if res.StatusCode != tt.statusCode { @@ -521,7 +521,7 @@ func Test_caHandler_SSHConfig(t *testing.T) { } } -func Test_caHandler_SSHCheckHost(t *testing.T) { +func Test_SSHCheckHost(t *testing.T) { tests := []struct { name string req string @@ -539,15 +539,15 @@ func Test_caHandler_SSHCheckHost(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ + mockMustAuthority(t, &mockAuthority{ checkSSHHost: func(ctx context.Context, principal, token string) (bool, error) { return tt.exists, tt.err }, - }).(*caHandler) + }) req := httptest.NewRequest("GET", "http://example.com/ssh/check-host", strings.NewReader(tt.req)) w := httptest.NewRecorder() - h.SSHCheckHost(logging.NewResponseLogger(w), req) + SSHCheckHost(logging.NewResponseLogger(w), req) res := w.Result() if res.StatusCode != tt.statusCode { @@ -568,7 +568,7 @@ func Test_caHandler_SSHCheckHost(t *testing.T) { } } -func Test_caHandler_SSHGetHosts(t *testing.T) { +func Test_SSHGetHosts(t *testing.T) { hosts := []authority.Host{ {HostID: "1", HostTags: []authority.HostTag{{ID: "1", Name: "group", Value: "1"}}, Hostname: "host1"}, {HostID: "2", HostTags: []authority.HostTag{{ID: "1", Name: "group", Value: "1"}, {ID: "2", Name: "group", Value: "2"}}, Hostname: "host2"}, @@ -590,15 +590,15 @@ func Test_caHandler_SSHGetHosts(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ + mockMustAuthority(t, &mockAuthority{ getSSHHosts: func(context.Context, *x509.Certificate) ([]authority.Host, error) { return tt.hosts, tt.err }, - }).(*caHandler) + }) req := httptest.NewRequest("GET", "http://example.com/ssh/host", http.NoBody) w := httptest.NewRecorder() - h.SSHGetHosts(logging.NewResponseLogger(w), req) + SSHGetHosts(logging.NewResponseLogger(w), req) res := w.Result() if res.StatusCode != tt.statusCode { @@ -619,7 +619,7 @@ func Test_caHandler_SSHGetHosts(t *testing.T) { } } -func Test_caHandler_SSHBastion(t *testing.T) { +func Test_SSHBastion(t *testing.T) { bastion := &authority.Bastion{ Hostname: "bastion.local", } @@ -645,15 +645,15 @@ func Test_caHandler_SSHBastion(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ + mockMustAuthority(t, &mockAuthority{ getSSHBastion: func(ctx context.Context, user, hostname string) (*authority.Bastion, error) { return tt.bastion, tt.bastionErr }, - }).(*caHandler) + }) req := httptest.NewRequest("POST", "http://example.com/ssh/bastion", bytes.NewReader(tt.req)) w := httptest.NewRecorder() - h.SSHBastion(logging.NewResponseLogger(w), req) + SSHBastion(logging.NewResponseLogger(w), req) res := w.Result() if res.StatusCode != tt.statusCode { From d5070ecf31b5c1c268a8e9f0243d6f240c6e3738 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Wed, 27 Apr 2022 11:06:55 -0700 Subject: [PATCH 06/40] Use server BaseContext Instead of using the authority middleware this change adds the authority in the base context of the server. --- ca/ca.go | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/ca/ca.go b/ca/ca.go index 24da6311..795fa77a 100644 --- a/ca/ca.go +++ b/ca/ca.go @@ -1,10 +1,12 @@ package ca import ( + "context" "crypto/tls" "crypto/x509" "fmt" "log" + "net" "net/http" "net/url" "reflect" @@ -279,10 +281,12 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) { } // Add authority handler - handler = auth.Middleware(handler) - insecureHandler = auth.Middleware(insecureHandler) + baseContext := buildContext(auth) ca.srv = server.New(cfg.Address, handler, tlsConfig) + ca.srv.BaseContext = func(net.Listener) context.Context { + return baseContext + } // only start the insecure server if the insecure address is configured // and, currently, also only when it should serve SCEP endpoints. @@ -292,11 +296,20 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) { // will probably introduce more complexity in terms of graceful // reload. ca.insecureSrv = server.New(cfg.InsecureAddress, insecureHandler, nil) + ca.insecureSrv.BaseContext = func(net.Listener) context.Context { + return baseContext + } } return ca, nil } +func buildContext(a *authority.Authority) context.Context { + ctx := authority.NewContext(context.Background(), a) + + return ctx +} + // Run starts the CA calling to the server ListenAndServe method. func (ca *CA) Run() error { var wg sync.WaitGroup From 48e2fabeb828b42c043820e6bc010db08b765b96 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Wed, 27 Apr 2022 11:38:06 -0700 Subject: [PATCH 07/40] Add authority.MustFromContext --- api/api.go | 8 +------- authority/authority.go | 15 ++++++++------- 2 files changed, 9 insertions(+), 14 deletions(-) diff --git a/api/api.go b/api/api.go index e5f4f944..0ca4a5ef 100644 --- a/api/api.go +++ b/api/api.go @@ -52,15 +52,9 @@ type Authority interface { Version() authority.Version } -var errAuthority = errors.New("authority is not in context") - // mustAuthority will be replaced on unit tests. var mustAuthority = func(ctx context.Context) Authority { - a, ok := authority.FromContext(ctx) - if !ok { - panic(errAuthority) - } - return a + return authority.MustFromContext(ctx) } // TimeDuration is an alias of provisioner.TimeDuration diff --git a/authority/authority.go b/authority/authority.go index 091a01ae..92ed6b31 100644 --- a/authority/authority.go +++ b/authority/authority.go @@ -7,7 +7,6 @@ import ( "crypto/x509" "encoding/hex" "log" - "net/http" "strings" "sync" "time" @@ -167,12 +166,14 @@ func FromContext(ctx context.Context) (a *Authority, ok bool) { return } -// Middleware adds the current authority to the request context. -func (a *Authority) Middleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx := NewContext(r.Context(), a) - next.ServeHTTP(w, r.WithContext(ctx)) - }) +// MustFromContext returns the current authority from the given context. It will +// panic if the authority is not in the context. +func MustFromContext(ctx context.Context) *Authority { + if a, ok := FromContext(ctx); !ok { + panic("authority is not in the context") + } else { + return a + } } // reloadAdminResources reloads admins and provisioners from the DB. From 623c2965557a1aeadff32c1f0b6293d2277fb9e8 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Wed, 27 Apr 2022 11:58:52 -0700 Subject: [PATCH 08/40] Create context methods from admin database --- authority/admin/db.go | 23 +++++++++++++++++++++++ ca/ca.go | 8 +++++++- 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/authority/admin/db.go b/authority/admin/db.go index bf34a3c2..2da1a59a 100644 --- a/authority/admin/db.go +++ b/authority/admin/db.go @@ -71,6 +71,29 @@ type DB interface { DeleteAdmin(ctx context.Context, id string) error } +type dbKey struct{} + +// NewContext adds the given admin database to the context. +func NewContext(ctx context.Context, db DB) context.Context { + return context.WithValue(ctx, dbKey{}, db) +} + +// FromContext returns the current admin database from the given context. +func FromContext(ctx context.Context) (db DB, ok bool) { + db, ok = ctx.Value(dbKey{}).(DB) + return +} + +// MustFromContext returns the current admin database from the given context. It +// will panic if it's not in the context. +func MustFromContext(ctx context.Context) DB { + if db, ok := FromContext(ctx); !ok { + panic("admin database is not in the context") + } else { + return db + } +} + // MockDB is an implementation of the DB interface that should only be used as // a mock in tests. type MockDB struct { diff --git a/ca/ca.go b/ca/ca.go index 795fa77a..2df52555 100644 --- a/ca/ca.go +++ b/ca/ca.go @@ -20,6 +20,7 @@ import ( acmeNoSQL "github.com/smallstep/certificates/acme/db/nosql" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/authority" + "github.com/smallstep/certificates/authority/admin" adminAPI "github.com/smallstep/certificates/authority/admin/api" "github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/db" @@ -280,7 +281,7 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) { insecureHandler = logger.Middleware(insecureHandler) } - // Add authority handler + // Create context with all the necessary values. baseContext := buildContext(auth) ca.srv = server.New(cfg.Address, handler, tlsConfig) @@ -304,9 +305,14 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) { return ca, nil } +// buildContext builds the server base context. func buildContext(a *authority.Authority) context.Context { ctx := authority.NewContext(context.Background(), a) + if db := a.GetAdminDatabase(); db != nil { + ctx = admin.NewContext(ctx, db) + } + return ctx } From 00f181dec3aa66962fb35788f9bf433e7b48c781 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Wed, 27 Apr 2022 11:59:32 -0700 Subject: [PATCH 09/40] Use contexts in admin api handlers --- authority/admin/api/acme.go | 14 +++++--- authority/admin/api/admin.go | 25 ++++++------- authority/admin/api/handler.go | 56 +++++++++++++++++------------ authority/admin/api/middleware.go | 14 ++++---- authority/admin/api/provisioner.go | 58 +++++++++++++++++------------- 5 files changed, 95 insertions(+), 72 deletions(-) diff --git a/authority/admin/api/acme.go b/authority/admin/api/acme.go index 21a7229d..2c189624 100644 --- a/authority/admin/api/acme.go +++ b/authority/admin/api/acme.go @@ -40,11 +40,11 @@ type GetExternalAccountKeysResponse struct { // requireEABEnabled is a middleware that ensures ACME EAB is enabled // before serving requests that act on ACME EAB credentials. -func (h *Handler) requireEABEnabled(next nextHTTP) nextHTTP { +func requireEABEnabled(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() provName := chi.URLParam(r, "provisionerName") - eabEnabled, prov, err := h.provisionerHasEABEnabled(ctx, provName) + eabEnabled, prov, err := provisionerHasEABEnabled(ctx, provName) if err != nil { render.Error(w, err) return @@ -60,16 +60,20 @@ func (h *Handler) requireEABEnabled(next nextHTTP) nextHTTP { // provisionerHasEABEnabled determines if the "requireEAB" setting for an ACME // provisioner is set to true and thus has EAB enabled. -func (h *Handler) provisionerHasEABEnabled(ctx context.Context, provisionerName string) (bool, *linkedca.Provisioner, error) { +func provisionerHasEABEnabled(ctx context.Context, provisionerName string) (bool, *linkedca.Provisioner, error) { var ( p provisioner.Interface err error ) - if p, err = h.auth.LoadProvisionerByName(provisionerName); err != nil { + + auth := mustAuthority(ctx) + db := admin.MustFromContext(ctx) + + if p, err = auth.LoadProvisionerByName(provisionerName); err != nil { return false, nil, admin.WrapErrorISE(err, "error loading provisioner %s", provisionerName) } - prov, err := h.adminDB.GetProvisioner(ctx, p.GetID()) + prov, err := db.GetProvisioner(ctx, p.GetID()) if err != nil { return false, nil, admin.WrapErrorISE(err, "error getting provisioner with ID: %s", p.GetID()) } diff --git a/authority/admin/api/admin.go b/authority/admin/api/admin.go index 5e4b9c30..6ef6f0eb 100644 --- a/authority/admin/api/admin.go +++ b/authority/admin/api/admin.go @@ -81,10 +81,10 @@ type DeleteResponse struct { } // GetAdmin returns the requested admin, or an error. -func (h *Handler) GetAdmin(w http.ResponseWriter, r *http.Request) { +func GetAdmin(w http.ResponseWriter, r *http.Request) { id := chi.URLParam(r, "id") - adm, ok := h.auth.LoadAdminByID(id) + adm, ok := mustAuthority(r.Context()).LoadAdminByID(id) if !ok { render.Error(w, admin.NewError(admin.ErrorNotFoundType, "admin %s not found", id)) @@ -94,7 +94,7 @@ func (h *Handler) GetAdmin(w http.ResponseWriter, r *http.Request) { } // GetAdmins returns a segment of admins associated with the authority. -func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) { +func GetAdmins(w http.ResponseWriter, r *http.Request) { cursor, limit, err := api.ParseCursor(r) if err != nil { render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, @@ -102,7 +102,7 @@ func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) { return } - admins, nextCursor, err := h.auth.GetAdmins(cursor, limit) + admins, nextCursor, err := mustAuthority(r.Context()).GetAdmins(cursor, limit) if err != nil { render.Error(w, admin.WrapErrorISE(err, "error retrieving paginated admins")) return @@ -114,7 +114,7 @@ func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) { } // CreateAdmin creates a new admin. -func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) { +func CreateAdmin(w http.ResponseWriter, r *http.Request) { var body CreateAdminRequest if err := read.JSON(r.Body, &body); err != nil { render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body")) @@ -126,7 +126,8 @@ func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) { return } - p, err := h.auth.LoadProvisionerByName(body.Provisioner) + auth := mustAuthority(r.Context()) + p, err := auth.LoadProvisionerByName(body.Provisioner) if err != nil { render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", body.Provisioner)) return @@ -137,7 +138,7 @@ func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) { Type: body.Type, } // Store to authority collection. - if err := h.auth.StoreAdmin(r.Context(), adm, p); err != nil { + if err := auth.StoreAdmin(r.Context(), adm, p); err != nil { render.Error(w, admin.WrapErrorISE(err, "error storing admin")) return } @@ -146,10 +147,10 @@ func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) { } // DeleteAdmin deletes admin. -func (h *Handler) DeleteAdmin(w http.ResponseWriter, r *http.Request) { +func DeleteAdmin(w http.ResponseWriter, r *http.Request) { id := chi.URLParam(r, "id") - if err := h.auth.RemoveAdmin(r.Context(), id); err != nil { + if err := mustAuthority(r.Context()).RemoveAdmin(r.Context(), id); err != nil { render.Error(w, admin.WrapErrorISE(err, "error deleting admin %s", id)) return } @@ -158,7 +159,7 @@ func (h *Handler) DeleteAdmin(w http.ResponseWriter, r *http.Request) { } // UpdateAdmin updates an existing admin. -func (h *Handler) UpdateAdmin(w http.ResponseWriter, r *http.Request) { +func UpdateAdmin(w http.ResponseWriter, r *http.Request) { var body UpdateAdminRequest if err := read.JSON(r.Body, &body); err != nil { render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body")) @@ -171,8 +172,8 @@ func (h *Handler) UpdateAdmin(w http.ResponseWriter, r *http.Request) { } id := chi.URLParam(r, "id") - - adm, err := h.auth.UpdateAdmin(r.Context(), id, &linkedca.Admin{Type: body.Type}) + auth := mustAuthority(r.Context()) + adm, err := auth.UpdateAdmin(r.Context(), id, &linkedca.Admin{Type: body.Type}) if err != nil { render.Error(w, admin.WrapErrorISE(err, "error updating admin %s", id)) return diff --git a/authority/admin/api/handler.go b/authority/admin/api/handler.go index 99e74c88..0acd3ca9 100644 --- a/authority/admin/api/handler.go +++ b/authority/admin/api/handler.go @@ -1,56 +1,66 @@ package api import ( + "context" + "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/admin" ) // Handler is the Admin API request handler. type Handler struct { - adminDB admin.DB - auth adminAuthority - acmeDB acme.DB acmeResponder acmeAdminResponderInterface } +// Route traffic and implement the Router interface. +// +// Deprecated: use Route(r api.Router, acmeResponder acmeAdminResponderInterface) +func (h *Handler) Route(r api.Router) { + Route(r, h.acmeResponder) +} + // NewHandler returns a new Authority Config Handler. +// +// Deprecated: use Route(r api.Router, acmeResponder acmeAdminResponderInterface) func NewHandler(auth adminAuthority, adminDB admin.DB, acmeDB acme.DB, acmeResponder acmeAdminResponderInterface) api.RouterHandler { return &Handler{ - auth: auth, - adminDB: adminDB, - acmeDB: acmeDB, acmeResponder: acmeResponder, } } +var mustAuthority = func(ctx context.Context) adminAuthority { + return authority.MustFromContext(ctx) +} + // Route traffic and implement the Router interface. -func (h *Handler) Route(r api.Router) { +func Route(r api.Router, acmeResponder acmeAdminResponderInterface) { authnz := func(next nextHTTP) nextHTTP { - return h.extractAuthorizeTokenAdmin(h.requireAPIEnabled(next)) + return extractAuthorizeTokenAdmin(requireAPIEnabled(next)) } requireEABEnabled := func(next nextHTTP) nextHTTP { - return h.requireEABEnabled(next) + return requireEABEnabled(next) } // Provisioners - r.MethodFunc("GET", "/provisioners/{name}", authnz(h.GetProvisioner)) - r.MethodFunc("GET", "/provisioners", authnz(h.GetProvisioners)) - r.MethodFunc("POST", "/provisioners", authnz(h.CreateProvisioner)) - r.MethodFunc("PUT", "/provisioners/{name}", authnz(h.UpdateProvisioner)) - r.MethodFunc("DELETE", "/provisioners/{name}", authnz(h.DeleteProvisioner)) + r.MethodFunc("GET", "/provisioners/{name}", authnz(GetProvisioner)) + r.MethodFunc("GET", "/provisioners", authnz(GetProvisioners)) + r.MethodFunc("POST", "/provisioners", authnz(CreateProvisioner)) + r.MethodFunc("PUT", "/provisioners/{name}", authnz(UpdateProvisioner)) + r.MethodFunc("DELETE", "/provisioners/{name}", authnz(DeleteProvisioner)) // Admins - r.MethodFunc("GET", "/admins/{id}", authnz(h.GetAdmin)) - r.MethodFunc("GET", "/admins", authnz(h.GetAdmins)) - r.MethodFunc("POST", "/admins", authnz(h.CreateAdmin)) - r.MethodFunc("PATCH", "/admins/{id}", authnz(h.UpdateAdmin)) - r.MethodFunc("DELETE", "/admins/{id}", authnz(h.DeleteAdmin)) + r.MethodFunc("GET", "/admins/{id}", authnz(GetAdmin)) + r.MethodFunc("GET", "/admins", authnz(GetAdmins)) + r.MethodFunc("POST", "/admins", authnz(CreateAdmin)) + r.MethodFunc("PATCH", "/admins/{id}", authnz(UpdateAdmin)) + r.MethodFunc("DELETE", "/admins/{id}", authnz(DeleteAdmin)) // ACME External Account Binding Keys - r.MethodFunc("GET", "/acme/eab/{provisionerName}/{reference}", authnz(requireEABEnabled(h.acmeResponder.GetExternalAccountKeys))) - r.MethodFunc("GET", "/acme/eab/{provisionerName}", authnz(requireEABEnabled(h.acmeResponder.GetExternalAccountKeys))) - r.MethodFunc("POST", "/acme/eab/{provisionerName}", authnz(requireEABEnabled(h.acmeResponder.CreateExternalAccountKey))) - r.MethodFunc("DELETE", "/acme/eab/{provisionerName}/{id}", authnz(requireEABEnabled(h.acmeResponder.DeleteExternalAccountKey))) + r.MethodFunc("GET", "/acme/eab/{provisionerName}/{reference}", authnz(requireEABEnabled(acmeResponder.GetExternalAccountKeys))) + r.MethodFunc("GET", "/acme/eab/{provisionerName}", authnz(requireEABEnabled(acmeResponder.GetExternalAccountKeys))) + r.MethodFunc("POST", "/acme/eab/{provisionerName}", authnz(requireEABEnabled(acmeResponder.CreateExternalAccountKey))) + r.MethodFunc("DELETE", "/acme/eab/{provisionerName}/{id}", authnz(requireEABEnabled(acmeResponder.DeleteExternalAccountKey))) } diff --git a/authority/admin/api/middleware.go b/authority/admin/api/middleware.go index b57dd6eb..9bd6c698 100644 --- a/authority/admin/api/middleware.go +++ b/authority/admin/api/middleware.go @@ -12,11 +12,10 @@ type nextHTTP = func(http.ResponseWriter, *http.Request) // requireAPIEnabled is a middleware that ensures the Administration API // is enabled before servicing requests. -func (h *Handler) requireAPIEnabled(next nextHTTP) nextHTTP { +func requireAPIEnabled(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { - if !h.auth.IsAdminAPIEnabled() { - render.Error(w, admin.NewError(admin.ErrorNotImplementedType, - "administration API not enabled")) + if !mustAuthority(r.Context()).IsAdminAPIEnabled() { + render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "administration API not enabled")) return } next(w, r) @@ -24,7 +23,7 @@ func (h *Handler) requireAPIEnabled(next nextHTTP) nextHTTP { } // extractAuthorizeTokenAdmin is a middleware that extracts and caches the bearer token. -func (h *Handler) extractAuthorizeTokenAdmin(next nextHTTP) nextHTTP { +func extractAuthorizeTokenAdmin(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { tok := r.Header.Get("Authorization") if tok == "" { @@ -33,13 +32,14 @@ func (h *Handler) extractAuthorizeTokenAdmin(next nextHTTP) nextHTTP { return } - adm, err := h.auth.AuthorizeAdminToken(r, tok) + ctx := r.Context() + adm, err := mustAuthority(ctx).AuthorizeAdminToken(r, tok) if err != nil { render.Error(w, err) return } - ctx := context.WithValue(r.Context(), adminContextKey, adm) + ctx = context.WithValue(ctx, adminContextKey, adm) next(w, r.WithContext(ctx)) } } diff --git a/authority/admin/api/provisioner.go b/authority/admin/api/provisioner.go index 1cad62dd..149f2c6a 100644 --- a/authority/admin/api/provisioner.go +++ b/authority/admin/api/provisioner.go @@ -23,29 +23,31 @@ type GetProvisionersResponse struct { } // GetProvisioner returns the requested provisioner, or an error. -func (h *Handler) GetProvisioner(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - - id := r.URL.Query().Get("id") - name := chi.URLParam(r, "name") - +func GetProvisioner(w http.ResponseWriter, r *http.Request) { var ( p provisioner.Interface err error ) + + ctx := r.Context() + id := r.URL.Query().Get("id") + name := chi.URLParam(r, "name") + auth := mustAuthority(ctx) + db := admin.MustFromContext(ctx) + if len(id) > 0 { - if p, err = h.auth.LoadProvisionerByID(id); err != nil { + if p, err = auth.LoadProvisionerByID(id); err != nil { render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", id)) return } } else { - if p, err = h.auth.LoadProvisionerByName(name); err != nil { + if p, err = auth.LoadProvisionerByName(name); err != nil { render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name)) return } } - prov, err := h.adminDB.GetProvisioner(ctx, p.GetID()) + prov, err := db.GetProvisioner(ctx, p.GetID()) if err != nil { render.Error(w, err) return @@ -54,7 +56,7 @@ func (h *Handler) GetProvisioner(w http.ResponseWriter, r *http.Request) { } // GetProvisioners returns the given segment of provisioners associated with the authority. -func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) { +func GetProvisioners(w http.ResponseWriter, r *http.Request) { cursor, limit, err := api.ParseCursor(r) if err != nil { render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, @@ -62,7 +64,7 @@ func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) { return } - p, next, err := h.auth.GetProvisioners(cursor, limit) + p, next, err := mustAuthority(r.Context()).GetProvisioners(cursor, limit) if err != nil { render.Error(w, errs.InternalServerErr(err)) return @@ -74,7 +76,7 @@ func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) { } // CreateProvisioner creates a new prov. -func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) { +func CreateProvisioner(w http.ResponseWriter, r *http.Request) { var prov = new(linkedca.Provisioner) if err := read.ProtoJSON(r.Body, prov); err != nil { render.Error(w, err) @@ -87,7 +89,7 @@ func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) { return } - if err := h.auth.StoreProvisioner(r.Context(), prov); err != nil { + if err := mustAuthority(r.Context()).StoreProvisioner(r.Context(), prov); err != nil { render.Error(w, admin.WrapErrorISE(err, "error storing provisioner %s", prov.Name)) return } @@ -95,27 +97,29 @@ func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) { } // DeleteProvisioner deletes a provisioner. -func (h *Handler) DeleteProvisioner(w http.ResponseWriter, r *http.Request) { - id := r.URL.Query().Get("id") - name := chi.URLParam(r, "name") - +func DeleteProvisioner(w http.ResponseWriter, r *http.Request) { var ( p provisioner.Interface err error ) + + id := r.URL.Query().Get("id") + name := chi.URLParam(r, "name") + auth := mustAuthority(r.Context()) + if len(id) > 0 { - if p, err = h.auth.LoadProvisionerByID(id); err != nil { + if p, err = auth.LoadProvisionerByID(id); err != nil { render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", id)) return } } else { - if p, err = h.auth.LoadProvisionerByName(name); err != nil { + if p, err = auth.LoadProvisionerByName(name); err != nil { render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name)) return } } - if err := h.auth.RemoveProvisioner(r.Context(), p.GetID()); err != nil { + if err := auth.RemoveProvisioner(r.Context(), p.GetID()); err != nil { render.Error(w, admin.WrapErrorISE(err, "error removing provisioner %s", p.GetName())) return } @@ -124,23 +128,27 @@ func (h *Handler) DeleteProvisioner(w http.ResponseWriter, r *http.Request) { } // UpdateProvisioner updates an existing prov. -func (h *Handler) UpdateProvisioner(w http.ResponseWriter, r *http.Request) { +func UpdateProvisioner(w http.ResponseWriter, r *http.Request) { var nu = new(linkedca.Provisioner) if err := read.ProtoJSON(r.Body, nu); err != nil { render.Error(w, err) return } + ctx := r.Context() name := chi.URLParam(r, "name") - _old, err := h.auth.LoadProvisionerByName(name) + auth := mustAuthority(ctx) + db := admin.MustFromContext(ctx) + + p, err := auth.LoadProvisionerByName(name) if err != nil { render.Error(w, admin.WrapErrorISE(err, "error loading provisioner from cached configuration '%s'", name)) return } - old, err := h.adminDB.GetProvisioner(r.Context(), _old.GetID()) + old, err := db.GetProvisioner(r.Context(), p.GetID()) if err != nil { - render.Error(w, admin.WrapErrorISE(err, "error loading provisioner from db '%s'", _old.GetID())) + render.Error(w, admin.WrapErrorISE(err, "error loading provisioner from db '%s'", p.GetID())) return } @@ -171,7 +179,7 @@ func (h *Handler) UpdateProvisioner(w http.ResponseWriter, r *http.Request) { return } - if err := h.auth.UpdateProvisioner(r.Context(), nu); err != nil { + if err := auth.UpdateProvisioner(r.Context(), nu); err != nil { render.Error(w, err) return } From 0446e823208559907ab89f6efe5ac88f2ba43edf Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Wed, 27 Apr 2022 12:05:19 -0700 Subject: [PATCH 10/40] Add context methods for the authority database --- ca/ca.go | 9 +++++---- db/db.go | 24 ++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/ca/ca.go b/ca/ca.go index 2df52555..f5cf30db 100644 --- a/ca/ca.go +++ b/ca/ca.go @@ -308,11 +308,12 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) { // buildContext builds the server base context. func buildContext(a *authority.Authority) context.Context { ctx := authority.NewContext(context.Background(), a) - - if db := a.GetAdminDatabase(); db != nil { - ctx = admin.NewContext(ctx, db) + if authDB := a.GetDatabase(); authDB != nil { + ctx = db.NewContext(ctx, authDB) + } + if adminDB := a.GetAdminDatabase(); adminDB != nil { + ctx = admin.NewContext(ctx, adminDB) } - return ctx } diff --git a/db/db.go b/db/db.go index eccaf801..c4b1c8a7 100644 --- a/db/db.go +++ b/db/db.go @@ -1,6 +1,7 @@ package db import ( + "context" "crypto/x509" "encoding/json" "strconv" @@ -58,6 +59,29 @@ type AuthDB interface { Shutdown() error } +type dbKey struct{} + +// NewContext adds the given authority database to the context. +func NewContext(ctx context.Context, db AuthDB) context.Context { + return context.WithValue(ctx, dbKey{}, db) +} + +// FromContext returns the current authority database from the given context. +func FromContext(ctx context.Context) (db AuthDB, ok bool) { + db, ok = ctx.Value(dbKey{}).(AuthDB) + return +} + +// MustFromContext returns the current database from the given context. It +// will panic if it's not in the context. +func MustFromContext(ctx context.Context) AuthDB { + if db, ok := FromContext(ctx); !ok { + panic("authority database is not in the context") + } else { + return db + } +} + // DB is a wrapper over the nosql.DB interface. type DB struct { nosql.DB From bd412c9f4285aeaec0f0a1b9488492e9516d52d6 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Wed, 27 Apr 2022 12:11:00 -0700 Subject: [PATCH 11/40] Add context methods for the acme database --- acme/db.go | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/acme/db.go b/acme/db.go index 412276fd..a8637f57 100644 --- a/acme/db.go +++ b/acme/db.go @@ -48,6 +48,29 @@ type DB interface { UpdateOrder(ctx context.Context, o *Order) error } +type dbKey struct{} + +// NewContext adds the given acme database to the context. +func NewContext(ctx context.Context, db DB) context.Context { + return context.WithValue(ctx, dbKey{}, db) +} + +// FromContext returns the current acme database from the given context. +func FromContext(ctx context.Context) (db DB, ok bool) { + db, ok = ctx.Value(dbKey{}).(DB) + return +} + +// MustFromContext returns the current database from the given context. It +// will panic if it's not in the context. +func MustFromContext(ctx context.Context) DB { + if db, ok := FromContext(ctx); !ok { + panic("acme database is not in the context") + } else { + return db + } +} + // MockDB is an implementation of the DB interface that should only be used as // a mock in tests. type MockDB struct { From 8bd4e1d73e3886894f5a667cb41aea630c3ade0f Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Wed, 27 Apr 2022 12:13:16 -0700 Subject: [PATCH 12/40] Inject the acme database in the context --- ca/ca.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/ca/ca.go b/ca/ca.go index f5cf30db..80756559 100644 --- a/ca/ca.go +++ b/ca/ca.go @@ -282,7 +282,7 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) { } // Create context with all the necessary values. - baseContext := buildContext(auth) + baseContext := buildContext(auth, acmeDB) ca.srv = server.New(cfg.Address, handler, tlsConfig) ca.srv.BaseContext = func(net.Listener) context.Context { @@ -306,7 +306,7 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) { } // buildContext builds the server base context. -func buildContext(a *authority.Authority) context.Context { +func buildContext(a *authority.Authority, acmeDB acme.DB) context.Context { ctx := authority.NewContext(context.Background(), a) if authDB := a.GetDatabase(); authDB != nil { ctx = db.NewContext(ctx, authDB) @@ -314,6 +314,9 @@ func buildContext(a *authority.Authority) context.Context { if adminDB := a.GetAdminDatabase(); adminDB != nil { ctx = admin.NewContext(ctx, adminDB) } + if acmeDB != nil { + ctx = acme.NewContext(ctx, acmeDB) + } return ctx } From 439cb81b133d994ecfc6a34b03c2067c4bb02d76 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Wed, 27 Apr 2022 12:16:16 -0700 Subject: [PATCH 13/40] Use admin Route function --- ca/ca.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ca/ca.go b/ca/ca.go index 80756559..783255ce 100644 --- a/ca/ca.go +++ b/ca/ca.go @@ -221,9 +221,8 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) { adminDB := auth.GetAdminDatabase() if adminDB != nil { acmeAdminResponder := adminAPI.NewACMEAdminResponder() - adminHandler := adminAPI.NewHandler(auth, adminDB, acmeDB, acmeAdminResponder) mux.Route("/admin", func(r chi.Router) { - adminHandler.Route(r) + adminAPI.Route(r, acmeAdminResponder) }) } } From d13537d426cce5a121115b40758105f8f46380ce Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Wed, 27 Apr 2022 15:42:26 -0700 Subject: [PATCH 14/40] Use context in the acme handlers. --- acme/api/account.go | 33 +++--- acme/api/eab.go | 5 +- acme/api/handler.go | 234 +++++++++++++++++++++++++++-------------- acme/api/middleware.go | 80 ++++++++------ acme/api/order.go | 47 +++++---- acme/api/revoke.go | 18 ++-- 6 files changed, 267 insertions(+), 150 deletions(-) diff --git a/acme/api/account.go b/acme/api/account.go index ade51aef..8c8c4d97 100644 --- a/acme/api/account.go +++ b/acme/api/account.go @@ -67,7 +67,7 @@ func (u *UpdateAccountRequest) Validate() error { } // NewAccount is the handler resource for creating new ACME accounts. -func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { +func NewAccount(w http.ResponseWriter, r *http.Request) { ctx := r.Context() payload, err := payloadFromContext(ctx) if err != nil { @@ -114,18 +114,19 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { return } - eak, err := h.validateExternalAccountBinding(ctx, &nar) + eak, err := validateExternalAccountBinding(ctx, &nar) if err != nil { render.Error(w, err) return } + db := acme.MustFromContext(ctx) acc = &acme.Account{ Key: jwk, Contact: nar.Contact, Status: acme.StatusValid, } - if err := h.db.CreateAccount(ctx, acc); err != nil { + if err := db.CreateAccount(ctx, acc); err != nil { render.Error(w, acme.WrapErrorISE(err, "error creating account")) return } @@ -136,7 +137,7 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { render.Error(w, err) return } - if err := h.db.UpdateExternalAccountKey(ctx, prov.ID, eak); err != nil { + if err := db.UpdateExternalAccountKey(ctx, prov.ID, eak); err != nil { render.Error(w, acme.WrapErrorISE(err, "error updating external account binding key")) return } @@ -147,14 +148,15 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { httpStatus = http.StatusOK } - h.linker.LinkAccount(ctx, acc) + o := optionsFromContext(ctx) + o.linker.LinkAccount(ctx, acc) - w.Header().Set("Location", h.linker.GetLink(r.Context(), AccountLinkType, acc.ID)) + w.Header().Set("Location", o.linker.GetLink(r.Context(), AccountLinkType, acc.ID)) render.JSONStatus(w, acc, httpStatus) } // GetOrUpdateAccount is the api for updating an ACME account. -func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) { +func GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) { ctx := r.Context() acc, err := accountFromContext(ctx) if err != nil { @@ -187,16 +189,18 @@ func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) { acc.Contact = uar.Contact } - if err := h.db.UpdateAccount(ctx, acc); err != nil { + db := acme.MustFromContext(ctx) + if err := db.UpdateAccount(ctx, acc); err != nil { render.Error(w, acme.WrapErrorISE(err, "error updating account")) return } } } - h.linker.LinkAccount(ctx, acc) + o := optionsFromContext(ctx) + o.linker.LinkAccount(ctx, acc) - w.Header().Set("Location", h.linker.GetLink(ctx, AccountLinkType, acc.ID)) + w.Header().Set("Location", o.linker.GetLink(ctx, AccountLinkType, acc.ID)) render.JSON(w, acc) } @@ -210,7 +214,7 @@ func logOrdersByAccount(w http.ResponseWriter, oids []string) { } // GetOrdersByAccountID ACME api for retrieving the list of order urls belonging to an account. -func (h *Handler) GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) { +func GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) { ctx := r.Context() acc, err := accountFromContext(ctx) if err != nil { @@ -222,13 +226,16 @@ func (h *Handler) GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) { render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account ID '%s' does not match url param '%s'", acc.ID, accID)) return } - orders, err := h.db.GetOrdersByAccountID(ctx, acc.ID) + + db := acme.MustFromContext(ctx) + orders, err := db.GetOrdersByAccountID(ctx, acc.ID) if err != nil { render.Error(w, err) return } - h.linker.LinkOrdersByAccountID(ctx, orders) + o := optionsFromContext(ctx) + o.linker.LinkOrdersByAccountID(ctx, orders) render.JSON(w, orders) logOrdersByAccount(w, orders) diff --git a/acme/api/eab.go b/acme/api/eab.go index 3660d066..2c94a4ed 100644 --- a/acme/api/eab.go +++ b/acme/api/eab.go @@ -16,7 +16,7 @@ type ExternalAccountBinding struct { } // validateExternalAccountBinding validates the externalAccountBinding property in a call to new-account. -func (h *Handler) validateExternalAccountBinding(ctx context.Context, nar *NewAccountRequest) (*acme.ExternalAccountKey, error) { +func validateExternalAccountBinding(ctx context.Context, nar *NewAccountRequest) (*acme.ExternalAccountKey, error) { acmeProv, err := acmeProvisionerFromContext(ctx) if err != nil { return nil, acme.WrapErrorISE(err, "could not load ACME provisioner from context") @@ -47,7 +47,8 @@ func (h *Handler) validateExternalAccountBinding(ctx context.Context, nar *NewAc return nil, acmeErr } - externalAccountKey, err := h.db.GetExternalAccountKey(ctx, acmeProv.ID, keyID) + db := acme.MustFromContext(ctx) + externalAccountKey, err := db.GetExternalAccountKey(ctx, acmeProv.ID, keyID) if err != nil { if _, ok := err.(*acme.Error); ok { return nil, acme.WrapError(acme.ErrorUnauthorizedType, err, "the field 'kid' references an unknown key") diff --git a/acme/api/handler.go b/acme/api/handler.go index 10eb22cb..04680656 100644 --- a/acme/api/handler.go +++ b/acme/api/handler.go @@ -16,6 +16,7 @@ import ( "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api/render" + "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" ) @@ -39,38 +40,89 @@ type payloadInfo struct { isEmptyJSON bool } -// Handler is the ACME API request handler. -type Handler struct { - db acme.DB - backdate provisioner.Duration - ca acme.CertificateAuthority - linker Linker - validateChallengeOptions *acme.ValidateChallengeOptions - prerequisitesChecker func(ctx context.Context) (bool, error) -} - // HandlerOptions required to create a new ACME API request handler. type HandlerOptions struct { - Backdate provisioner.Duration // DB storage backend that impements the acme.DB interface. + // + // Deprecated: use acme.NewContex(context.Context, acme.DB) DB acme.DB + + // CA is the certificate authority interface. + // + // Deprecated: use authority.NewContext(context.Context, *authority.Authority) + CA acme.CertificateAuthority + + // Backdate is the duration that the CA will substract from the current time + // to set the NotBefore in the certificate. + Backdate provisioner.Duration + // DNS the host used to generate accurate ACME links. By default the authority // will use the Host from the request, so this value will only be used if // request.Host is empty. DNS string + // Prefix is a URL path prefix under which the ACME api is served. This // prefix is required to generate accurate ACME links. // E.g. https://ca.smallstep.com/acme/my-acme-provisioner/new-account -- // "acme" is the prefix from which the ACME api is accessed. Prefix string - CA acme.CertificateAuthority + // PrerequisitesChecker checks if all prerequisites for serving ACME are // met by the CA configuration. PrerequisitesChecker func(ctx context.Context) (bool, error) + + linker Linker + validateChallengeOptions *acme.ValidateChallengeOptions +} + +type optionsKey struct{} + +func newOptionsContext(ctx context.Context, o *HandlerOptions) context.Context { + return context.WithValue(ctx, optionsKey{}, o) +} + +func optionsFromContext(ctx context.Context) *HandlerOptions { + o, ok := ctx.Value(optionsKey{}).(*HandlerOptions) + if !ok { + panic("handler options are not in the context") + } + return o +} + +var mustAuthority = func(ctx context.Context) acme.CertificateAuthority { + return authority.MustFromContext(ctx) +} + +// Handler is the ACME API request handler. +type Handler struct { + opts *HandlerOptions +} + +// Route traffic and implement the Router interface. +// +// Deprecated: Use api.Route(r Router, opts *HandlerOptions) +func (h *Handler) Route(r api.Router) { + Route(r, h.opts) } // NewHandler returns a new ACME API handler. +// +// Deprecated: Use api.Route(r Router, opts *HandlerOptions) func NewHandler(ops HandlerOptions) api.RouterHandler { + return &Handler{ + opts: &ops, + } +} + +// Route traffic and implement the Router interface. +func Route(r api.Router, opts *HandlerOptions) { + // by default all prerequisites are met + if opts.PrerequisitesChecker == nil { + opts.PrerequisitesChecker = func(ctx context.Context) (bool, error) { + return true, nil + } + } + transport := &http.Transport{ TLSClientConfig: &tls.Config{ InsecureSkipVerify: true, @@ -83,67 +135,85 @@ func NewHandler(ops HandlerOptions) api.RouterHandler { dialer := &net.Dialer{ Timeout: 30 * time.Second, } - prerequisitesChecker := func(ctx context.Context) (bool, error) { - // by default all prerequisites are met - return true, nil - } - if ops.PrerequisitesChecker != nil { - prerequisitesChecker = ops.PrerequisitesChecker - } - return &Handler{ - ca: ops.CA, - db: ops.DB, - backdate: ops.Backdate, - linker: NewLinker(ops.DNS, ops.Prefix), - validateChallengeOptions: &acme.ValidateChallengeOptions{ - HTTPGet: client.Get, - LookupTxt: net.LookupTXT, - TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { - return tls.DialWithDialer(dialer, network, addr, config) - }, + + opts.linker = NewLinker(opts.DNS, opts.Prefix) + opts.validateChallengeOptions = &acme.ValidateChallengeOptions{ + HTTPGet: client.Get, + LookupTxt: net.LookupTXT, + TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { + return tls.DialWithDialer(dialer, network, addr, config) }, - prerequisitesChecker: prerequisitesChecker, } -} -// Route traffic and implement the Router interface. -func (h *Handler) Route(r api.Router) { - getPath := h.linker.GetUnescapedPathSuffix - // Standard ACME API - r.MethodFunc("GET", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.addNonce(h.addDirLink(h.GetNonce)))))) - r.MethodFunc("HEAD", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.addNonce(h.addDirLink(h.GetNonce)))))) - r.MethodFunc("GET", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.GetDirectory)))) - r.MethodFunc("HEAD", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.GetDirectory)))) + withOptions := func(next nextHTTP) nextHTTP { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + // For backward compatibility with NewHandler. + if ca, ok := opts.CA.(*authority.Authority); ok && ca != nil { + ctx = authority.NewContext(ctx, ca) + } + if opts.DB != nil { + ctx = acme.NewContext(ctx, opts.DB) + } + + ctx = newOptionsContext(ctx, opts) + next(w, r.WithContext(ctx)) + } + } validatingMiddleware := func(next nextHTTP) nextHTTP { - return h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(next)))))))) + return withOptions(baseURLFromRequest(lookupProvisioner(checkPrerequisites(addNonce(addDirLink(verifyContentType(parseJWS(validateJWS(next))))))))) } extractPayloadByJWK := func(next nextHTTP) nextHTTP { - return validatingMiddleware(h.extractJWK(h.verifyAndExtractJWSPayload(next))) + return withOptions(validatingMiddleware(extractJWK(verifyAndExtractJWSPayload(next)))) } extractPayloadByKid := func(next nextHTTP) nextHTTP { - return validatingMiddleware(h.lookupJWK(h.verifyAndExtractJWSPayload(next))) + return withOptions(validatingMiddleware(lookupJWK(verifyAndExtractJWSPayload(next)))) } extractPayloadByKidOrJWK := func(next nextHTTP) nextHTTP { - return validatingMiddleware(h.extractOrLookupJWK(h.verifyAndExtractJWSPayload(next))) + return withOptions(validatingMiddleware(extractOrLookupJWK(verifyAndExtractJWSPayload(next)))) } - r.MethodFunc("POST", getPath(NewAccountLinkType, "{provisionerID}"), extractPayloadByJWK(h.NewAccount)) - r.MethodFunc("POST", getPath(AccountLinkType, "{provisionerID}", "{accID}"), extractPayloadByKid(h.GetOrUpdateAccount)) - r.MethodFunc("POST", getPath(KeyChangeLinkType, "{provisionerID}", "{accID}"), extractPayloadByKid(h.NotImplemented)) - r.MethodFunc("POST", getPath(NewOrderLinkType, "{provisionerID}"), extractPayloadByKid(h.NewOrder)) - r.MethodFunc("POST", getPath(OrderLinkType, "{provisionerID}", "{ordID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrder))) - r.MethodFunc("POST", getPath(OrdersByAccountLinkType, "{provisionerID}", "{accID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrdersByAccountID))) - r.MethodFunc("POST", getPath(FinalizeLinkType, "{provisionerID}", "{ordID}"), extractPayloadByKid(h.FinalizeOrder)) - r.MethodFunc("POST", getPath(AuthzLinkType, "{provisionerID}", "{authzID}"), extractPayloadByKid(h.isPostAsGet(h.GetAuthorization))) - r.MethodFunc("POST", getPath(ChallengeLinkType, "{provisionerID}", "{authzID}", "{chID}"), extractPayloadByKid(h.GetChallenge)) - r.MethodFunc("POST", getPath(CertificateLinkType, "{provisionerID}", "{certID}"), extractPayloadByKid(h.isPostAsGet(h.GetCertificate))) - r.MethodFunc("POST", getPath(RevokeCertLinkType, "{provisionerID}"), extractPayloadByKidOrJWK(h.RevokeCert)) + getPath := opts.linker.GetUnescapedPathSuffix + + // Standard ACME API + r.MethodFunc("GET", getPath(NewNonceLinkType, "{provisionerID}"), + withOptions(baseURLFromRequest(lookupProvisioner(checkPrerequisites(addNonce(addDirLink(GetNonce))))))) + r.MethodFunc("HEAD", getPath(NewNonceLinkType, "{provisionerID}"), + withOptions(baseURLFromRequest(lookupProvisioner(checkPrerequisites(addNonce(addDirLink(GetNonce))))))) + r.MethodFunc("GET", getPath(DirectoryLinkType, "{provisionerID}"), + withOptions(baseURLFromRequest(lookupProvisioner(checkPrerequisites(GetDirectory))))) + r.MethodFunc("HEAD", getPath(DirectoryLinkType, "{provisionerID}"), + withOptions(baseURLFromRequest(lookupProvisioner(checkPrerequisites(GetDirectory))))) + + r.MethodFunc("POST", getPath(NewAccountLinkType, "{provisionerID}"), + extractPayloadByJWK(NewAccount)) + r.MethodFunc("POST", getPath(AccountLinkType, "{provisionerID}", "{accID}"), + extractPayloadByKid(GetOrUpdateAccount)) + r.MethodFunc("POST", getPath(KeyChangeLinkType, "{provisionerID}", "{accID}"), + extractPayloadByKid(NotImplemented)) + r.MethodFunc("POST", getPath(NewOrderLinkType, "{provisionerID}"), + extractPayloadByKid(NewOrder)) + r.MethodFunc("POST", getPath(OrderLinkType, "{provisionerID}", "{ordID}"), + extractPayloadByKid(isPostAsGet(GetOrder))) + r.MethodFunc("POST", getPath(OrdersByAccountLinkType, "{provisionerID}", "{accID}"), + extractPayloadByKid(isPostAsGet(GetOrdersByAccountID))) + r.MethodFunc("POST", getPath(FinalizeLinkType, "{provisionerID}", "{ordID}"), + extractPayloadByKid(FinalizeOrder)) + r.MethodFunc("POST", getPath(AuthzLinkType, "{provisionerID}", "{authzID}"), + extractPayloadByKid(isPostAsGet(GetAuthorization))) + r.MethodFunc("POST", getPath(ChallengeLinkType, "{provisionerID}", "{authzID}", "{chID}"), + extractPayloadByKid(GetChallenge)) + r.MethodFunc("POST", getPath(CertificateLinkType, "{provisionerID}", "{certID}"), + extractPayloadByKid(isPostAsGet(GetCertificate))) + r.MethodFunc("POST", getPath(RevokeCertLinkType, "{provisionerID}"), + extractPayloadByKidOrJWK(RevokeCert)) } // GetNonce just sets the right header since a Nonce is added to each response // by middleware by default. -func (h *Handler) GetNonce(w http.ResponseWriter, r *http.Request) { +func GetNonce(w http.ResponseWriter, r *http.Request) { if r.Method == "HEAD" { w.WriteHeader(http.StatusOK) } else { @@ -179,8 +249,10 @@ func (d *Directory) ToLog() (interface{}, error) { // GetDirectory is the ACME resource for returning a directory configuration // for client configuration. -func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) { +func GetDirectory(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + o := optionsFromContext(ctx) + acmeProv, err := acmeProvisionerFromContext(ctx) if err != nil { render.Error(w, err) @@ -188,11 +260,11 @@ func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) { } render.JSON(w, &Directory{ - NewNonce: h.linker.GetLink(ctx, NewNonceLinkType), - NewAccount: h.linker.GetLink(ctx, NewAccountLinkType), - NewOrder: h.linker.GetLink(ctx, NewOrderLinkType), - RevokeCert: h.linker.GetLink(ctx, RevokeCertLinkType), - KeyChange: h.linker.GetLink(ctx, KeyChangeLinkType), + NewNonce: o.linker.GetLink(ctx, NewNonceLinkType), + NewAccount: o.linker.GetLink(ctx, NewAccountLinkType), + NewOrder: o.linker.GetLink(ctx, NewOrderLinkType), + RevokeCert: o.linker.GetLink(ctx, RevokeCertLinkType), + KeyChange: o.linker.GetLink(ctx, KeyChangeLinkType), Meta: Meta{ ExternalAccountRequired: acmeProv.RequireEAB, }, @@ -201,19 +273,22 @@ func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) { // NotImplemented returns a 501 and is generally a placeholder for functionality which // MAY be added at some point in the future but is not in any way a guarantee of such. -func (h *Handler) NotImplemented(w http.ResponseWriter, r *http.Request) { +func NotImplemented(w http.ResponseWriter, r *http.Request) { render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "this API is not implemented")) } // GetAuthorization ACME api for retrieving an Authz. -func (h *Handler) GetAuthorization(w http.ResponseWriter, r *http.Request) { +func GetAuthorization(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + o := optionsFromContext(ctx) + db := acme.MustFromContext(ctx) + acc, err := accountFromContext(ctx) if err != nil { render.Error(w, err) return } - az, err := h.db.GetAuthorization(ctx, chi.URLParam(r, "authzID")) + az, err := db.GetAuthorization(ctx, chi.URLParam(r, "authzID")) if err != nil { render.Error(w, acme.WrapErrorISE(err, "error retrieving authorization")) return @@ -223,20 +298,23 @@ func (h *Handler) GetAuthorization(w http.ResponseWriter, r *http.Request) { "account '%s' does not own authorization '%s'", acc.ID, az.ID)) return } - if err = az.UpdateStatus(ctx, h.db); err != nil { + if err = az.UpdateStatus(ctx, db); err != nil { render.Error(w, acme.WrapErrorISE(err, "error updating authorization status")) return } - h.linker.LinkAuthorization(ctx, az) + o.linker.LinkAuthorization(ctx, az) - w.Header().Set("Location", h.linker.GetLink(ctx, AuthzLinkType, az.ID)) + w.Header().Set("Location", o.linker.GetLink(ctx, AuthzLinkType, az.ID)) render.JSON(w, az) } // GetChallenge ACME api for retrieving a Challenge. -func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) { +func GetChallenge(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + o := optionsFromContext(ctx) + db := acme.MustFromContext(ctx) + acc, err := accountFromContext(ctx) if err != nil { render.Error(w, err) @@ -257,7 +335,7 @@ func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) { // we'll just ignore the body. azID := chi.URLParam(r, "authzID") - ch, err := h.db.GetChallenge(ctx, chi.URLParam(r, "chID"), azID) + ch, err := db.GetChallenge(ctx, chi.URLParam(r, "chID"), azID) if err != nil { render.Error(w, acme.WrapErrorISE(err, "error retrieving challenge")) return @@ -273,29 +351,31 @@ func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) { render.Error(w, err) return } - if err = ch.Validate(ctx, h.db, jwk, h.validateChallengeOptions); err != nil { + if err = ch.Validate(ctx, db, jwk, o.validateChallengeOptions); err != nil { render.Error(w, acme.WrapErrorISE(err, "error validating challenge")) return } - h.linker.LinkChallenge(ctx, ch, azID) + o.linker.LinkChallenge(ctx, ch, azID) - w.Header().Add("Link", link(h.linker.GetLink(ctx, AuthzLinkType, azID), "up")) - w.Header().Set("Location", h.linker.GetLink(ctx, ChallengeLinkType, azID, ch.ID)) + w.Header().Add("Link", link(o.linker.GetLink(ctx, AuthzLinkType, azID), "up")) + w.Header().Set("Location", o.linker.GetLink(ctx, ChallengeLinkType, azID, ch.ID)) render.JSON(w, ch) } // GetCertificate ACME api for retrieving a Certificate. -func (h *Handler) GetCertificate(w http.ResponseWriter, r *http.Request) { +func GetCertificate(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + db := acme.MustFromContext(ctx) + acc, err := accountFromContext(ctx) if err != nil { render.Error(w, err) return } - certID := chi.URLParam(r, "certID") - cert, err := h.db.GetCertificate(ctx, certID) + certID := chi.URLParam(r, "certID") + cert, err := db.GetCertificate(ctx, certID) if err != nil { render.Error(w, acme.WrapErrorISE(err, "error retrieving certificate")) return diff --git a/acme/api/middleware.go b/acme/api/middleware.go index 10f7841f..564a16f5 100644 --- a/acme/api/middleware.go +++ b/acme/api/middleware.go @@ -31,15 +31,15 @@ func logNonce(w http.ResponseWriter, nonce string) { } } -// baseURLFromRequest determines the base URL which should be used for +// getBaseURLFromRequest determines the base URL which should be used for // constructing link URLs in e.g. the ACME directory result by taking the // request Host into consideration. // // If the Request.Host is an empty string, we return an empty string, to // indicate that the configured URL values should be used instead. If this -// function returns a non-empty result, then this should be used in -// constructing ACME link URLs. -func baseURLFromRequest(r *http.Request) *url.URL { +// function returns a non-empty result, then this should be used in constructing +// ACME link URLs. +func getBaseURLFromRequest(r *http.Request) *url.URL { // NOTE: See https://github.com/letsencrypt/boulder/blob/master/web/relative.go // for an implementation that allows HTTP requests using the x-forwarded-proto // header. @@ -53,17 +53,18 @@ func baseURLFromRequest(r *http.Request) *url.URL { // baseURLFromRequest is a middleware that extracts and caches the baseURL // from the request. // E.g. https://ca.smallstep.com/ -func (h *Handler) baseURLFromRequest(next nextHTTP) nextHTTP { +func baseURLFromRequest(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { - ctx := context.WithValue(r.Context(), baseURLContextKey, baseURLFromRequest(r)) + ctx := context.WithValue(r.Context(), baseURLContextKey, getBaseURLFromRequest(r)) next(w, r.WithContext(ctx)) } } // addNonce is a middleware that adds a nonce to the response header. -func (h *Handler) addNonce(next nextHTTP) nextHTTP { +func addNonce(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { - nonce, err := h.db.CreateNonce(r.Context()) + db := acme.MustFromContext(r.Context()) + nonce, err := db.CreateNonce(r.Context()) if err != nil { render.Error(w, err) return @@ -77,25 +78,31 @@ func (h *Handler) addNonce(next nextHTTP) nextHTTP { // addDirLink is a middleware that adds a 'Link' response reader with the // directory index url. -func (h *Handler) addDirLink(next nextHTTP) nextHTTP { +func addDirLink(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { - w.Header().Add("Link", link(h.linker.GetLink(r.Context(), DirectoryLinkType), "index")) + ctx := r.Context() + opts := optionsFromContext(ctx) + + w.Header().Add("Link", link(opts.linker.GetLink(ctx, DirectoryLinkType), "index")) next(w, r) } } // verifyContentType is a middleware that verifies that content type is // application/jose+json. -func (h *Handler) verifyContentType(next nextHTTP) nextHTTP { +func verifyContentType(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { var expected []string - p, err := provisionerFromContext(r.Context()) + ctx := r.Context() + opts := optionsFromContext(ctx) + + p, err := provisionerFromContext(ctx) if err != nil { render.Error(w, err) return } - u := url.URL{Path: h.linker.GetUnescapedPathSuffix(CertificateLinkType, p.GetName(), "")} + u := url.URL{Path: opts.linker.GetUnescapedPathSuffix(CertificateLinkType, p.GetName(), "")} if strings.Contains(r.URL.String(), u.EscapedPath()) { // GET /certificate requests allow a greater range of content types. expected = []string{"application/jose+json", "application/pkix-cert", "application/pkcs7-mime"} @@ -117,7 +124,7 @@ func (h *Handler) verifyContentType(next nextHTTP) nextHTTP { } // parseJWS is a middleware that parses a request body into a JSONWebSignature struct. -func (h *Handler) parseJWS(next nextHTTP) nextHTTP { +func parseJWS(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { body, err := io.ReadAll(r.Body) if err != nil { @@ -149,10 +156,12 @@ func (h *Handler) parseJWS(next nextHTTP) nextHTTP { // * “nonce” (defined in Section 6.5) // * “url” (defined in Section 6.4) // * Either “jwk” (JSON Web Key) or “kid” (Key ID) as specified below -func (h *Handler) validateJWS(next nextHTTP) nextHTTP { +func validateJWS(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - jws, err := jwsFromContext(r.Context()) + db := acme.MustFromContext(ctx) + + jws, err := jwsFromContext(ctx) if err != nil { render.Error(w, err) return @@ -202,7 +211,7 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP { } // Check the validity/freshness of the Nonce. - if err := h.db.DeleteNonce(ctx, acme.Nonce(hdr.Nonce)); err != nil { + if err := db.DeleteNonce(ctx, acme.Nonce(hdr.Nonce)); err != nil { render.Error(w, err) return } @@ -235,10 +244,12 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP { // extractJWK is a middleware that extracts the JWK from the JWS and saves it // in the context. Make sure to parse and validate the JWS before running this // middleware. -func (h *Handler) extractJWK(next nextHTTP) nextHTTP { +func extractJWK(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - jws, err := jwsFromContext(r.Context()) + db := acme.MustFromContext(ctx) + + jws, err := jwsFromContext(ctx) if err != nil { render.Error(w, err) return @@ -264,7 +275,7 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP { ctx = context.WithValue(ctx, jwkContextKey, jwk) // Get Account OR continue to generate a new one OR continue Revoke with certificate private key - acc, err := h.db.GetAccountByKeyID(ctx, jwk.KeyID) + acc, err := db.GetAccountByKeyID(ctx, jwk.KeyID) switch { case errors.Is(err, acme.ErrNotFound): // For NewAccount and Revoke requests ... @@ -285,7 +296,7 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP { // lookupProvisioner loads the provisioner associated with the request. // Responds 404 if the provisioner does not exist. -func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP { +func lookupProvisioner(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() nameEscaped := chi.URLParam(r, "provisionerID") @@ -294,7 +305,7 @@ func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP { render.Error(w, acme.WrapErrorISE(err, "error url unescaping provisioner name '%s'", nameEscaped)) return } - p, err := h.ca.LoadProvisionerByName(name) + p, err := mustAuthority(r.Context()).LoadProvisionerByName(name) if err != nil { render.Error(w, err) return @@ -311,10 +322,12 @@ func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP { // checkPrerequisites checks if all prerequisites for serving ACME // are met by the CA configuration. -func (h *Handler) checkPrerequisites(next nextHTTP) nextHTTP { +func checkPrerequisites(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - ok, err := h.prerequisitesChecker(ctx) + opts := optionsFromContext(ctx) + + ok, err := opts.PrerequisitesChecker(ctx) if err != nil { render.Error(w, acme.WrapErrorISE(err, "error checking acme provisioner prerequisites")) return @@ -330,16 +343,19 @@ func (h *Handler) checkPrerequisites(next nextHTTP) nextHTTP { // lookupJWK loads the JWK associated with the acme account referenced by the // kid parameter of the signed payload. // Make sure to parse and validate the JWS before running this middleware. -func (h *Handler) lookupJWK(next nextHTTP) nextHTTP { +func lookupJWK(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + opts := optionsFromContext(ctx) + db := acme.MustFromContext(ctx) + jws, err := jwsFromContext(ctx) if err != nil { render.Error(w, err) return } - kidPrefix := h.linker.GetLink(ctx, AccountLinkType, "") + kidPrefix := opts.linker.GetLink(ctx, AccountLinkType, "") kid := jws.Signatures[0].Protected.KeyID if !strings.HasPrefix(kid, kidPrefix) { render.Error(w, acme.NewError(acme.ErrorMalformedType, @@ -349,7 +365,7 @@ func (h *Handler) lookupJWK(next nextHTTP) nextHTTP { } accID := strings.TrimPrefix(kid, kidPrefix) - acc, err := h.db.GetAccount(ctx, accID) + acc, err := db.GetAccount(ctx, accID) switch { case nosql.IsErrNotFound(err): render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "account with ID '%s' not found", accID)) @@ -372,7 +388,7 @@ func (h *Handler) lookupJWK(next nextHTTP) nextHTTP { // extractOrLookupJWK forwards handling to either extractJWK or // lookupJWK based on the presence of a JWK or a KID, respectively. -func (h *Handler) extractOrLookupJWK(next nextHTTP) nextHTTP { +func extractOrLookupJWK(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() jws, err := jwsFromContext(ctx) @@ -385,13 +401,13 @@ func (h *Handler) extractOrLookupJWK(next nextHTTP) nextHTTP { // and it can be used to check if a JWK exists. This flow is used when the ACME client // signed the payload with a certificate private key. if canExtractJWKFrom(jws) { - h.extractJWK(next)(w, r) + extractJWK(next)(w, r) return } // default to looking up the JWK based on KeyID. This flow is used when the ACME client // signed the payload with an account private key. - h.lookupJWK(next)(w, r) + lookupJWK(next)(w, r) } } @@ -408,7 +424,7 @@ func canExtractJWKFrom(jws *jose.JSONWebSignature) bool { // verifyAndExtractJWSPayload extracts the JWK from the JWS and saves it in the context. // Make sure to parse and validate the JWS before running this middleware. -func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP { +func verifyAndExtractJWSPayload(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() jws, err := jwsFromContext(ctx) @@ -440,7 +456,7 @@ func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP { } // isPostAsGet asserts that the request is a PostAsGet (empty JWS payload). -func (h *Handler) isPostAsGet(next nextHTTP) nextHTTP { +func isPostAsGet(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { payload, err := payloadFromContext(r.Context()) if err != nil { diff --git a/acme/api/order.go b/acme/api/order.go index 99eb0e95..ebd0c7f5 100644 --- a/acme/api/order.go +++ b/acme/api/order.go @@ -68,7 +68,7 @@ var defaultOrderExpiry = time.Hour * 24 var defaultOrderBackdate = time.Minute // NewOrder ACME api for creating a new order. -func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { +func NewOrder(w http.ResponseWriter, r *http.Request) { ctx := r.Context() acc, err := accountFromContext(ctx) if err != nil { @@ -117,7 +117,7 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { ExpiresAt: o.ExpiresAt, Status: acme.StatusPending, } - if err := h.newAuthorization(ctx, az); err != nil { + if err := newAuthorization(ctx, az); err != nil { render.Error(w, err) return } @@ -136,18 +136,20 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { o.NotBefore = o.NotBefore.Add(-defaultOrderBackdate) } - if err := h.db.CreateOrder(ctx, o); err != nil { + db := acme.MustFromContext(ctx) + if err := db.CreateOrder(ctx, o); err != nil { render.Error(w, acme.WrapErrorISE(err, "error creating order")) return } - h.linker.LinkOrder(ctx, o) + opts := optionsFromContext(ctx) + opts.linker.LinkOrder(ctx, o) - w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID)) + w.Header().Set("Location", opts.linker.GetLink(ctx, OrderLinkType, o.ID)) render.JSONStatus(w, o, http.StatusCreated) } -func (h *Handler) newAuthorization(ctx context.Context, az *acme.Authorization) error { +func newAuthorization(ctx context.Context, az *acme.Authorization) error { if strings.HasPrefix(az.Identifier.Value, "*.") { az.Wildcard = true az.Identifier = acme.Identifier{ @@ -163,6 +165,8 @@ func (h *Handler) newAuthorization(ctx context.Context, az *acme.Authorization) if err != nil { return acme.WrapErrorISE(err, "error generating random alphanumeric ID") } + + db := acme.MustFromContext(ctx) az.Challenges = make([]*acme.Challenge, len(chTypes)) for i, typ := range chTypes { ch := &acme.Challenge{ @@ -172,19 +176,19 @@ func (h *Handler) newAuthorization(ctx context.Context, az *acme.Authorization) Token: az.Token, Status: acme.StatusPending, } - if err := h.db.CreateChallenge(ctx, ch); err != nil { + if err := db.CreateChallenge(ctx, ch); err != nil { return acme.WrapErrorISE(err, "error creating challenge") } az.Challenges[i] = ch } - if err = h.db.CreateAuthorization(ctx, az); err != nil { + if err = db.CreateAuthorization(ctx, az); err != nil { return acme.WrapErrorISE(err, "error creating authorization") } return nil } // GetOrder ACME api for retrieving an order. -func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) { +func GetOrder(w http.ResponseWriter, r *http.Request) { ctx := r.Context() acc, err := accountFromContext(ctx) if err != nil { @@ -196,7 +200,9 @@ func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) { render.Error(w, err) return } - o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID")) + + db := acme.MustFromContext(ctx) + o, err := db.GetOrder(ctx, chi.URLParam(r, "ordID")) if err != nil { render.Error(w, acme.WrapErrorISE(err, "error retrieving order")) return @@ -211,19 +217,20 @@ func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) { "provisioner '%s' does not own order '%s'", prov.GetID(), o.ID)) return } - if err = o.UpdateStatus(ctx, h.db); err != nil { + if err = o.UpdateStatus(ctx, db); err != nil { render.Error(w, acme.WrapErrorISE(err, "error updating order status")) return } - h.linker.LinkOrder(ctx, o) + opts := optionsFromContext(ctx) + opts.linker.LinkOrder(ctx, o) - w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID)) + w.Header().Set("Location", opts.linker.GetLink(ctx, OrderLinkType, o.ID)) render.JSON(w, o) } // FinalizeOrder attemptst to finalize an order and create a certificate. -func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) { +func FinalizeOrder(w http.ResponseWriter, r *http.Request) { ctx := r.Context() acc, err := accountFromContext(ctx) if err != nil { @@ -251,7 +258,8 @@ func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) { return } - o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID")) + db := acme.MustFromContext(ctx) + o, err := db.GetOrder(ctx, chi.URLParam(r, "ordID")) if err != nil { render.Error(w, acme.WrapErrorISE(err, "error retrieving order")) return @@ -266,14 +274,17 @@ func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) { "provisioner '%s' does not own order '%s'", prov.GetID(), o.ID)) return } - if err = o.Finalize(ctx, h.db, fr.csr, h.ca, prov); err != nil { + + ca := mustAuthority(ctx) + if err = o.Finalize(ctx, db, fr.csr, ca, prov); err != nil { render.Error(w, acme.WrapErrorISE(err, "error finalizing order")) return } - h.linker.LinkOrder(ctx, o) + opts := optionsFromContext(ctx) + opts.linker.LinkOrder(ctx, o) - w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID)) + w.Header().Set("Location", opts.linker.GetLink(ctx, OrderLinkType, o.ID)) render.JSON(w, o) } diff --git a/acme/api/revoke.go b/acme/api/revoke.go index 4b71bc22..55774aea 100644 --- a/acme/api/revoke.go +++ b/acme/api/revoke.go @@ -26,8 +26,7 @@ type revokePayload struct { } // RevokeCert attempts to revoke a certificate. -func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) { - +func RevokeCert(w http.ResponseWriter, r *http.Request) { ctx := r.Context() jws, err := jwsFromContext(ctx) if err != nil { @@ -68,8 +67,9 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) { return } + db := acme.MustFromContext(ctx) serial := certToBeRevoked.SerialNumber.String() - dbCert, err := h.db.GetCertificateBySerial(ctx, serial) + dbCert, err := db.GetCertificateBySerial(ctx, serial) if err != nil { render.Error(w, acme.WrapErrorISE(err, "error retrieving certificate by serial")) return @@ -87,7 +87,7 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) { render.Error(w, err) return } - acmeErr := h.isAccountAuthorized(ctx, dbCert, certToBeRevoked, account) + acmeErr := isAccountAuthorized(ctx, dbCert, certToBeRevoked, account) if acmeErr != nil { render.Error(w, acmeErr) return @@ -103,7 +103,8 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) { } } - hasBeenRevokedBefore, err := h.ca.IsRevoked(serial) + ca := mustAuthority(ctx) + hasBeenRevokedBefore, err := ca.IsRevoked(serial) if err != nil { render.Error(w, acme.WrapErrorISE(err, "error retrieving revocation status of certificate")) return @@ -130,14 +131,15 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) { } options := revokeOptions(serial, certToBeRevoked, reasonCode) - err = h.ca.Revoke(ctx, options) + err = ca.Revoke(ctx, options) if err != nil { render.Error(w, wrapRevokeErr(err)) return } logRevoke(w, options) - w.Header().Add("Link", link(h.linker.GetLink(ctx, DirectoryLinkType), "index")) + o := optionsFromContext(ctx) + w.Header().Add("Link", link(o.linker.GetLink(ctx, DirectoryLinkType), "index")) w.Write(nil) } @@ -148,7 +150,7 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) { // the identifiers in the certificate are extracted and compared against the (valid) Authorizations // that are stored for the ACME Account. If these sets match, the Account is considered authorized // to revoke the certificate. If this check fails, the client will receive an unauthorized error. -func (h *Handler) isAccountAuthorized(ctx context.Context, dbCert *acme.Certificate, certToBeRevoked *x509.Certificate, account *acme.Account) *acme.Error { +func isAccountAuthorized(ctx context.Context, dbCert *acme.Certificate, certToBeRevoked *x509.Certificate, account *acme.Account) *acme.Error { if !account.IsValid() { return wrapUnauthorizedError(certToBeRevoked, nil, fmt.Sprintf("account '%s' has status '%s'", account.ID, account.Status), nil) } From 216d8f0efbb95336c948faec338db1c5cd56e97c Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Wed, 27 Apr 2022 15:44:41 -0700 Subject: [PATCH 15/40] Handle acme requests with the new api --- ca/ca.go | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/ca/ca.go b/ca/ca.go index 783255ce..933db275 100644 --- a/ca/ca.go +++ b/ca/ca.go @@ -200,20 +200,18 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) { return nil, errors.Wrap(err, "error configuring ACME DB interface") } } - acmeHandler := acmeAPI.NewHandler(acmeAPI.HandlerOptions{ + acmeOptions := &acmeAPI.HandlerOptions{ Backdate: *cfg.AuthorityConfig.Backdate, - DB: acmeDB, DNS: dns, Prefix: prefix, - CA: auth, - }) + } mux.Route("/"+prefix, func(r chi.Router) { - acmeHandler.Route(r) + acmeAPI.Route(r, acmeOptions) }) // Use 2.0 because, at the moment, our ACME api is only compatible with v2.0 // of the ACME spec. mux.Route("/2.0/"+prefix, func(r chi.Router) { - acmeHandler.Route(r) + acmeAPI.Route(r, acmeOptions) }) // Admin API Router From 688f9ceb5648805502d21db6f285c9453395b767 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Wed, 27 Apr 2022 18:02:37 -0700 Subject: [PATCH 16/40] Add scep authority to context. --- ca/ca.go | 10 +++++++--- scep/authority.go | 39 +++++++++++++++++++++++++++++++++++++-- 2 files changed, 44 insertions(+), 5 deletions(-) diff --git a/ca/ca.go b/ca/ca.go index 933db275..a8ecbb05 100644 --- a/ca/ca.go +++ b/ca/ca.go @@ -225,9 +225,10 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) { } } + var scepAuthority *scep.Authority if ca.shouldServeSCEPEndpoints() { scepPrefix := "scep" - scepAuthority, err := scep.New(auth, scep.AuthorityOptions{ + scepAuthority, err = scep.New(auth, scep.AuthorityOptions{ Service: auth.GetSCEPService(), DNS: dns, Prefix: scepPrefix, @@ -279,7 +280,7 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) { } // Create context with all the necessary values. - baseContext := buildContext(auth, acmeDB) + baseContext := buildContext(auth, scepAuthority, acmeDB) ca.srv = server.New(cfg.Address, handler, tlsConfig) ca.srv.BaseContext = func(net.Listener) context.Context { @@ -303,7 +304,7 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) { } // buildContext builds the server base context. -func buildContext(a *authority.Authority, acmeDB acme.DB) context.Context { +func buildContext(a *authority.Authority, scepAuthority *scep.Authority, acmeDB acme.DB) context.Context { ctx := authority.NewContext(context.Background(), a) if authDB := a.GetDatabase(); authDB != nil { ctx = db.NewContext(ctx, authDB) @@ -311,6 +312,9 @@ func buildContext(a *authority.Authority, acmeDB acme.DB) context.Context { if adminDB := a.GetAdminDatabase(); adminDB != nil { ctx = admin.NewContext(ctx, adminDB) } + if scepAuthority != nil { + ctx = scep.NewContext(ctx, scepAuthority) + } if acmeDB != nil { ctx = acme.NewContext(ctx, acmeDB) } diff --git a/scep/authority.go b/scep/authority.go index 71f92152..946fa948 100644 --- a/scep/authority.go +++ b/scep/authority.go @@ -27,6 +27,29 @@ type Authority struct { signAuth SignAuthority } +type authorityKey struct{} + +// NewContext adds the given authority to the context. +func NewContext(ctx context.Context, a *Authority) context.Context { + return context.WithValue(ctx, authorityKey{}, a) +} + +// FromContext returns the current authority from the given context. +func FromContext(ctx context.Context) (a *Authority, ok bool) { + a, ok = ctx.Value(authorityKey{}).(*Authority) + return +} + +// MustFromContext returns the current authority from the given context. It will +// panic if the authority is not in the context. +func MustFromContext(ctx context.Context) *Authority { + if a, ok := FromContext(ctx); !ok { + panic("scep authority is not in the context") + } else { + return a + } +} + // AuthorityOptions required to create a new SCEP Authority. type AuthorityOptions struct { // Service provides the certificate chain, the signer and the decrypter to the Authority @@ -40,6 +63,20 @@ type AuthorityOptions struct { Prefix string } +type optionsKey struct{} + +func newOptionsContext(ctx context.Context, o *AuthorityOptions) context.Context { + return context.WithValue(ctx, optionsKey{}, o) +} + +func optionsFromContext(ctx context.Context) *AuthorityOptions { + o, ok := ctx.Value(optionsKey{}).(*AuthorityOptions) + if !ok { + panic("scep options are not in the context") + } + return o +} + // SignAuthority is the interface for a signing authority type SignAuthority interface { Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) @@ -163,7 +200,6 @@ func (a *Authority) GetCACertificates(ctx context.Context) ([]*x509.Certificate, // DecryptPKIEnvelope decrypts an enveloped message func (a *Authority) DecryptPKIEnvelope(ctx context.Context, msg *PKIMessage) error { - p7c, err := pkcs7.Parse(msg.P7.Content) if err != nil { return fmt.Errorf("error parsing pkcs7 content: %w", err) @@ -210,7 +246,6 @@ func (a *Authority) DecryptPKIEnvelope(ctx context.Context, msg *PKIMessage) err // SignCSR creates an x509.Certificate based on a CSR template and Cert Authority credentials // returns a new PKIMessage with CertRep data func (a *Authority) SignCSR(ctx context.Context, csr *x509.CertificateRequest, msg *PKIMessage) (*PKIMessage, error) { - // TODO: intermediate storage of the request? In SCEP it's possible to request a csr/certificate // to be signed, which can be performed asynchronously / out-of-band. In that case a client can // poll for the status. It seems to be similar as what can happen in ACME, so might want to model From 42435ace642edb1fd7abc0232c8afc4dc66f9287 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Wed, 27 Apr 2022 18:06:27 -0700 Subject: [PATCH 17/40] Use scep authority from context This commit also converts all the methods from the handler to functions. --- acme/api/handler.go | 2 +- scep/api/api.go | 82 ++++++++++++++++++++++----------------------- 2 files changed, 42 insertions(+), 42 deletions(-) diff --git a/acme/api/handler.go b/acme/api/handler.go index 04680656..4b916404 100644 --- a/acme/api/handler.go +++ b/acme/api/handler.go @@ -84,7 +84,7 @@ func newOptionsContext(ctx context.Context, o *HandlerOptions) context.Context { func optionsFromContext(ctx context.Context) *HandlerOptions { o, ok := ctx.Value(optionsKey{}).(*HandlerOptions) if !ok { - panic("handler options are not in the context") + panic("acme options are not in the context") } return o } diff --git a/scep/api/api.go b/scep/api/api.go index 31f0f10d..0d62904d 100644 --- a/scep/api/api.go +++ b/scep/api/api.go @@ -48,29 +48,32 @@ type response struct { } // handler is the SCEP request handler. -type handler struct { - auth *scep.Authority +type handler struct{} + +// Route traffic and implement the Router interface. +// +// Deprecated: use scep.Route(r api.Router) +func (h *handler) Route(r api.Router) { + Route(r) } // New returns a new SCEP API router. +// +// Deprecated: use scep.Route(r api.Router) func New(auth *scep.Authority) api.RouterHandler { - return &handler{ - auth: auth, - } + return &handler{} } // Route traffic and implement the Router interface. -func (h *handler) Route(r api.Router) { - getLink := h.auth.GetLinkExplicit - r.MethodFunc(http.MethodGet, getLink("{provisionerName}/*", false, nil), h.lookupProvisioner(h.Get)) - r.MethodFunc(http.MethodGet, getLink("{provisionerName}", false, nil), h.lookupProvisioner(h.Get)) - r.MethodFunc(http.MethodPost, getLink("{provisionerName}/*", false, nil), h.lookupProvisioner(h.Post)) - r.MethodFunc(http.MethodPost, getLink("{provisionerName}", false, nil), h.lookupProvisioner(h.Post)) +func Route(r api.Router) { + r.MethodFunc(http.MethodGet, "/{provisionerName}/*", lookupProvisioner(Get)) + r.MethodFunc(http.MethodGet, "/{provisionerName}", lookupProvisioner(Get)) + r.MethodFunc(http.MethodPost, "/{provisionerName}/*", lookupProvisioner(Post)) + r.MethodFunc(http.MethodPost, "/{provisionerName}", lookupProvisioner(Post)) } // Get handles all SCEP GET requests -func (h *handler) Get(w http.ResponseWriter, r *http.Request) { - +func Get(w http.ResponseWriter, r *http.Request) { req, err := decodeRequest(r) if err != nil { fail(w, fmt.Errorf("invalid scep get request: %w", err)) @@ -82,9 +85,9 @@ func (h *handler) Get(w http.ResponseWriter, r *http.Request) { switch req.Operation { case opnGetCACert: - res, err = h.GetCACert(ctx) + res, err = GetCACert(ctx) case opnGetCACaps: - res, err = h.GetCACaps(ctx) + res, err = GetCACaps(ctx) case opnPKIOperation: // TODO: implement the GET for PKI operation? Default CACAPS doesn't specify this is in use, though default: @@ -100,20 +103,17 @@ func (h *handler) Get(w http.ResponseWriter, r *http.Request) { } // Post handles all SCEP POST requests -func (h *handler) Post(w http.ResponseWriter, r *http.Request) { - +func Post(w http.ResponseWriter, r *http.Request) { req, err := decodeRequest(r) if err != nil { fail(w, fmt.Errorf("invalid scep post request: %w", err)) return } - ctx := r.Context() var res response - switch req.Operation { case opnPKIOperation: - res, err = h.PKIOperation(ctx, req) + res, err = PKIOperation(r.Context(), req) default: err = fmt.Errorf("unknown operation: %s", req.Operation) } @@ -127,7 +127,6 @@ func (h *handler) Post(w http.ResponseWriter, r *http.Request) { } func decodeRequest(r *http.Request) (request, error) { - defer r.Body.Close() method := r.Method @@ -179,9 +178,8 @@ func decodeRequest(r *http.Request) (request, error) { // lookupProvisioner loads the provisioner associated with the request. // Responds 404 if the provisioner does not exist. -func (h *handler) lookupProvisioner(next http.HandlerFunc) http.HandlerFunc { +func lookupProvisioner(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - name := chi.URLParam(r, "provisionerName") provisionerName, err := url.PathUnescape(name) if err != nil { @@ -189,7 +187,9 @@ func (h *handler) lookupProvisioner(next http.HandlerFunc) http.HandlerFunc { return } - p, err := h.auth.LoadProvisionerByName(provisionerName) + ctx := r.Context() + auth := scep.MustFromContext(ctx) + p, err := auth.LoadProvisionerByName(provisionerName) if err != nil { fail(w, err) return @@ -201,16 +201,15 @@ func (h *handler) lookupProvisioner(next http.HandlerFunc) http.HandlerFunc { return } - ctx := r.Context() ctx = context.WithValue(ctx, scep.ProvisionerContextKey, scep.Provisioner(prov)) next(w, r.WithContext(ctx)) } } // GetCACert returns the CA certificates in a SCEP response -func (h *handler) GetCACert(ctx context.Context) (response, error) { - - certs, err := h.auth.GetCACertificates(ctx) +func GetCACert(ctx context.Context) (response, error) { + auth := scep.MustFromContext(ctx) + certs, err := auth.GetCACertificates(ctx) if err != nil { return response{}, err } @@ -241,9 +240,9 @@ func (h *handler) GetCACert(ctx context.Context) (response, error) { } // GetCACaps returns the CA capabilities in a SCEP response -func (h *handler) GetCACaps(ctx context.Context) (response, error) { - - caps := h.auth.GetCACaps(ctx) +func GetCACaps(ctx context.Context) (response, error) { + auth := scep.MustFromContext(ctx) + caps := auth.GetCACaps(ctx) res := response{ Operation: opnGetCACaps, @@ -254,8 +253,7 @@ func (h *handler) GetCACaps(ctx context.Context) (response, error) { } // PKIOperation performs PKI operations and returns a SCEP response -func (h *handler) PKIOperation(ctx context.Context, req request) (response, error) { - +func PKIOperation(ctx context.Context, req request) (response, error) { // parse the message using microscep implementation microMsg, err := microscep.ParsePKIMessage(req.Message) if err != nil { @@ -280,7 +278,8 @@ func (h *handler) PKIOperation(ctx context.Context, req request) (response, erro P7: p7, } - if err := h.auth.DecryptPKIEnvelope(ctx, msg); err != nil { + auth := scep.MustFromContext(ctx) + if err := auth.DecryptPKIEnvelope(ctx, msg); err != nil { return response{}, err } @@ -293,13 +292,13 @@ func (h *handler) PKIOperation(ctx context.Context, req request) (response, erro // a certificate exists; then it will use RenewalReq. Adding the challenge check here may be a small breaking change for clients. // We'll have to see how it works out. if msg.MessageType == microscep.PKCSReq || msg.MessageType == microscep.RenewalReq { - challengeMatches, err := h.auth.MatchChallengePassword(ctx, msg.CSRReqMessage.ChallengePassword) + challengeMatches, err := auth.MatchChallengePassword(ctx, msg.CSRReqMessage.ChallengePassword) if err != nil { - return h.createFailureResponse(ctx, csr, msg, microscep.BadRequest, errors.New("error when checking password")) + return createFailureResponse(ctx, csr, msg, microscep.BadRequest, errors.New("error when checking password")) } if !challengeMatches { // TODO: can this be returned safely to the client? In the end, if the password was correct, that gains a bit of info too. - return h.createFailureResponse(ctx, csr, msg, microscep.BadRequest, errors.New("wrong password provided")) + return createFailureResponse(ctx, csr, msg, microscep.BadRequest, errors.New("wrong password provided")) } } @@ -311,9 +310,9 @@ func (h *handler) PKIOperation(ctx context.Context, req request) (response, erro // Authentication by the (self-signed) certificate with an optional challenge is required; supporting renewals incl. verification // of the client cert is not. - certRep, err := h.auth.SignCSR(ctx, csr, msg) + certRep, err := auth.SignCSR(ctx, csr, msg) if err != nil { - return h.createFailureResponse(ctx, csr, msg, microscep.BadRequest, fmt.Errorf("error when signing new certificate: %w", err)) + return createFailureResponse(ctx, csr, msg, microscep.BadRequest, fmt.Errorf("error when signing new certificate: %w", err)) } res := response{ @@ -350,8 +349,9 @@ func fail(w http.ResponseWriter, err error) { http.Error(w, err.Error(), http.StatusInternalServerError) } -func (h *handler) createFailureResponse(ctx context.Context, csr *x509.CertificateRequest, msg *scep.PKIMessage, info microscep.FailInfo, failError error) (response, error) { - certRepMsg, err := h.auth.CreateFailureResponse(ctx, csr, msg, scep.FailInfoName(info), failError.Error()) +func createFailureResponse(ctx context.Context, csr *x509.CertificateRequest, msg *scep.PKIMessage, info microscep.FailInfo, failError error) (response, error) { + auth := scep.MustFromContext(ctx) + certRepMsg, err := auth.CreateFailureResponse(ctx, csr, msg, scep.FailInfoName(info), failError.Error()) if err != nil { return response{}, err } From bb8d85a20128ce772f9f6709abe8e0af0ae37a85 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Wed, 27 Apr 2022 19:08:16 -0700 Subject: [PATCH 18/40] Fix unit tests - work in progress --- acme/api/account_test.go | 12 +++---- acme/api/eab_test.go | 8 ++--- acme/api/handler_test.go | 21 ++++++------ acme/api/middleware_test.go | 64 ++++++++++++++++++------------------- acme/api/order_test.go | 16 +++++----- acme/api/revoke_test.go | 8 ++--- 6 files changed, 65 insertions(+), 64 deletions(-) diff --git a/acme/api/account_test.go b/acme/api/account_test.go index 4c3404ec..3fbabfe5 100644 --- a/acme/api/account_test.go +++ b/acme/api/account_test.go @@ -315,11 +315,11 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")} + // h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")} req := httptest.NewRequest("GET", u, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.GetOrdersByAccountID(w, req) + GetOrdersByAccountID(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -759,11 +759,11 @@ func TestHandler_NewAccount(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")} + // h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")} req := httptest.NewRequest("GET", "/foo/bar", nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.NewAccount(w, req) + NewAccount(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -959,11 +959,11 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")} + // h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")} req := httptest.NewRequest("GET", "/foo/bar", nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.GetOrUpdateAccount(w, req) + GetOrUpdateAccount(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) diff --git a/acme/api/eab_test.go b/acme/api/eab_test.go index dce9f36d..1c76618b 100644 --- a/acme/api/eab_test.go +++ b/acme/api/eab_test.go @@ -762,10 +762,10 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{ - db: tc.db, - } - got, err := h.validateExternalAccountBinding(tc.ctx, tc.nar) + // h := &Handler{ + // db: tc.db, + // } + got, err := validateExternalAccountBinding(tc.ctx, tc.nar) wantErr := tc.err != nil gotErr := err != nil if wantErr != gotErr { diff --git a/acme/api/handler_test.go b/acme/api/handler_test.go index 67f7df30..fcc33a87 100644 --- a/acme/api/handler_test.go +++ b/acme/api/handler_test.go @@ -38,10 +38,10 @@ func TestHandler_GetNonce(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := &Handler{} + // h := &Handler{} w := httptest.NewRecorder() req.Method = tt.name - h.GetNonce(w, req) + GetNonce(w, req) res := w.Result() if res.StatusCode != tt.statusCode { @@ -53,6 +53,7 @@ func TestHandler_GetNonce(t *testing.T) { func TestHandler_GetDirectory(t *testing.T) { linker := NewLinker("ca.smallstep.com", "acme") + _ = linker type test struct { ctx context.Context statusCode int @@ -130,11 +131,11 @@ func TestHandler_GetDirectory(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{linker: linker} + // h := &Handler{linker: linker} req := httptest.NewRequest("GET", "/foo/bar", nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.GetDirectory(w, req) + GetDirectory(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -304,11 +305,11 @@ func TestHandler_GetAuthorization(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")} + // h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")} req := httptest.NewRequest("GET", "/foo/bar", nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.GetAuthorization(w, req) + GetAuthorization(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -447,11 +448,11 @@ func TestHandler_GetCertificate(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{db: tc.db} + // h := &Handler{db: tc.db} req := httptest.NewRequest("GET", u, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.GetCertificate(w, req) + GetCertificate(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -703,11 +704,11 @@ func TestHandler_GetChallenge(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{db: tc.db, linker: NewLinker("dns", "acme"), validateChallengeOptions: tc.vco} + // h := &Handler{db: tc.db, linker: NewLinker("dns", "acme"), validateChallengeOptions: tc.vco} req := httptest.NewRequest("GET", u, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.GetChallenge(w, req) + GetChallenge(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) diff --git a/acme/api/middleware_test.go b/acme/api/middleware_test.go index 8003fa16..f192e67e 100644 --- a/acme/api/middleware_test.go +++ b/acme/api/middleware_test.go @@ -70,7 +70,7 @@ func Test_baseURLFromRequest(t *testing.T) { if tc.requestPreparer != nil { tc.requestPreparer(request) } - result := baseURLFromRequest(request) + result := getBaseURLFromRequest(request) if result == nil || tc.expectedResult == nil { assert.Equals(t, result, tc.expectedResult) } else if result.String() != tc.expectedResult.String() { @@ -81,7 +81,7 @@ func Test_baseURLFromRequest(t *testing.T) { } func TestHandler_baseURLFromRequest(t *testing.T) { - h := &Handler{} + // h := &Handler{} req := httptest.NewRequest("GET", "/foo", nil) req.Host = "test.ca.smallstep.com:8080" w := httptest.NewRecorder() @@ -94,7 +94,7 @@ func TestHandler_baseURLFromRequest(t *testing.T) { } } - h.baseURLFromRequest(next)(w, req) + baseURLFromRequest(next)(w, req) req = httptest.NewRequest("GET", "/foo", nil) req.Host = "" @@ -103,7 +103,7 @@ func TestHandler_baseURLFromRequest(t *testing.T) { assert.Equals(t, baseURLFromContext(r.Context()), nil) } - h.baseURLFromRequest(next)(w, req) + baseURLFromRequest(next)(w, req) } func TestHandler_addNonce(t *testing.T) { @@ -139,10 +139,10 @@ func TestHandler_addNonce(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{db: tc.db} + // h := &Handler{db: tc.db} req := httptest.NewRequest("GET", u, nil) w := httptest.NewRecorder() - h.addNonce(testNext)(w, req) + addNonce(testNext)(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -195,11 +195,11 @@ func TestHandler_addDirLink(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{linker: tc.linker} + // h := &Handler{linker: tc.linker} req := httptest.NewRequest("GET", "/foo", nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.addDirLink(testNext)(w, req) + addDirLink(testNext)(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -242,7 +242,7 @@ func TestHandler_verifyContentType(t *testing.T) { "fail/provisioner-not-set": func(t *testing.T) test { return test{ h: Handler{ - linker: NewLinker("dns", "acme"), + // linker: NewLinker("dns", "acme"), }, url: u, ctx: context.Background(), @@ -254,7 +254,7 @@ func TestHandler_verifyContentType(t *testing.T) { "fail/general-bad-content-type": func(t *testing.T) test { return test{ h: Handler{ - linker: NewLinker("dns", "acme"), + // linker: NewLinker("dns", "acme"), }, url: u, ctx: context.WithValue(context.Background(), provisionerContextKey, prov), @@ -266,7 +266,7 @@ func TestHandler_verifyContentType(t *testing.T) { "fail/certificate-bad-content-type": func(t *testing.T) test { return test{ h: Handler{ - linker: NewLinker("dns", "acme"), + // linker: NewLinker("dns", "acme"), }, ctx: context.WithValue(context.Background(), provisionerContextKey, prov), contentType: "foo", @@ -277,7 +277,7 @@ func TestHandler_verifyContentType(t *testing.T) { "ok": func(t *testing.T) test { return test{ h: Handler{ - linker: NewLinker("dns", "acme"), + // linker: NewLinker("dns", "acme"), }, ctx: context.WithValue(context.Background(), provisionerContextKey, prov), contentType: "application/jose+json", @@ -287,7 +287,7 @@ func TestHandler_verifyContentType(t *testing.T) { "ok/certificate/pkix-cert": func(t *testing.T) test { return test{ h: Handler{ - linker: NewLinker("dns", "acme"), + // linker: NewLinker("dns", "acme"), }, ctx: context.WithValue(context.Background(), provisionerContextKey, prov), contentType: "application/pkix-cert", @@ -297,7 +297,7 @@ func TestHandler_verifyContentType(t *testing.T) { "ok/certificate/jose+json": func(t *testing.T) test { return test{ h: Handler{ - linker: NewLinker("dns", "acme"), + // linker: NewLinker("dns", "acme"), }, ctx: context.WithValue(context.Background(), provisionerContextKey, prov), contentType: "application/jose+json", @@ -307,7 +307,7 @@ func TestHandler_verifyContentType(t *testing.T) { "ok/certificate/pkcs7-mime": func(t *testing.T) test { return test{ h: Handler{ - linker: NewLinker("dns", "acme"), + // linker: NewLinker("dns", "acme"), }, ctx: context.WithValue(context.Background(), provisionerContextKey, prov), contentType: "application/pkcs7-mime", @@ -326,7 +326,7 @@ func TestHandler_verifyContentType(t *testing.T) { req = req.WithContext(tc.ctx) req.Header.Add("Content-Type", tc.contentType) w := httptest.NewRecorder() - tc.h.verifyContentType(testNext)(w, req) + verifyContentType(testNext)(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -390,11 +390,11 @@ func TestHandler_isPostAsGet(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{} + // h := &Handler{} req := httptest.NewRequest("GET", u, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.isPostAsGet(testNext)(w, req) + isPostAsGet(testNext)(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -481,10 +481,10 @@ func TestHandler_parseJWS(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{} + // h := &Handler{} req := httptest.NewRequest("GET", u, tc.body) w := httptest.NewRecorder() - h.parseJWS(tc.next)(w, req) + parseJWS(tc.next)(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -679,11 +679,11 @@ func TestHandler_verifyAndExtractJWSPayload(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{} + // h := &Handler{} req := httptest.NewRequest("GET", u, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.verifyAndExtractJWSPayload(tc.next)(w, req) + verifyAndExtractJWSPayload(tc.next)(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -881,11 +881,11 @@ func TestHandler_lookupJWK(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{db: tc.db, linker: tc.linker} + // h := &Handler{db: tc.db, linker: tc.linker} req := httptest.NewRequest("GET", u, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.lookupJWK(tc.next)(w, req) + lookupJWK(tc.next)(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -1077,11 +1077,11 @@ func TestHandler_extractJWK(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{db: tc.db} + // h := &Handler{db: tc.db} req := httptest.NewRequest("GET", u, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.extractJWK(tc.next)(w, req) + extractJWK(tc.next)(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -1444,11 +1444,11 @@ func TestHandler_validateJWS(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{db: tc.db} + // h := &Handler{db: tc.db} req := httptest.NewRequest("GET", u, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.validateJWS(tc.next)(w, req) + validateJWS(tc.next)(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -1628,11 +1628,11 @@ func TestHandler_extractOrLookupJWK(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{db: tc.db, linker: tc.linker} + // h := &Handler{db: tc.db, linker: tc.linker} req := httptest.NewRequest("GET", u, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.extractOrLookupJWK(tc.next)(w, req) + extractOrLookupJWK(tc.next)(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -1717,11 +1717,11 @@ func TestHandler_checkPrerequisites(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{db: nil, linker: tc.linker, prerequisitesChecker: tc.prerequisitesChecker} + // h := &Handler{db: nil, linker: tc.linker, prerequisitesChecker: tc.prerequisitesChecker} req := httptest.NewRequest("GET", u, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.checkPrerequisites(tc.next)(w, req) + checkPrerequisites(tc.next)(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) diff --git a/acme/api/order_test.go b/acme/api/order_test.go index 1ce034e7..f0a2d1d4 100644 --- a/acme/api/order_test.go +++ b/acme/api/order_test.go @@ -421,11 +421,11 @@ func TestHandler_GetOrder(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db} + // h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db} req := httptest.NewRequest("GET", u, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.GetOrder(w, req) + GetOrder(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -636,8 +636,8 @@ func TestHandler_newAuthorization(t *testing.T) { for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) - h := &Handler{db: tc.db} - if err := h.newAuthorization(context.Background(), tc.az); err != nil { + // h := &Handler{db: tc.db} + if err := newAuthorization(context.Background(), tc.az); err != nil { if assert.NotNil(t, tc.err) { switch k := err.(type) { case *acme.Error: @@ -1334,11 +1334,11 @@ func TestHandler_NewOrder(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db} + // h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db} req := httptest.NewRequest("GET", u, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.NewOrder(w, req) + NewOrder(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -1624,11 +1624,11 @@ func TestHandler_FinalizeOrder(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db} + // h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db} req := httptest.NewRequest("GET", u, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.FinalizeOrder(w, req) + FinalizeOrder(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) diff --git a/acme/api/revoke_test.go b/acme/api/revoke_test.go index 4ff54405..3a0ba70d 100644 --- a/acme/api/revoke_test.go +++ b/acme/api/revoke_test.go @@ -1057,11 +1057,11 @@ func TestHandler_RevokeCert(t *testing.T) { for name, setup := range tests { tc := setup(t) t.Run(name, func(t *testing.T) { - h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db, ca: tc.ca} + // h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db, ca: tc.ca} req := httptest.NewRequest("POST", revokeURL, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.RevokeCert(w, req) + RevokeCert(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -1198,8 +1198,8 @@ func TestHandler_isAccountAuthorized(t *testing.T) { for name, setup := range tests { tc := setup(t) t.Run(name, func(t *testing.T) { - h := &Handler{db: tc.db} - acmeErr := h.isAccountAuthorized(tc.ctx, tc.existingCert, tc.certToBeRevoked, tc.account) + // h := &Handler{db: tc.db} + acmeErr := isAccountAuthorized(tc.ctx, tc.existingCert, tc.certToBeRevoked, tc.account) expectError := tc.err != nil gotError := acmeErr != nil From 55b0f7282144f1c5a343d0c623248c9108acf706 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Thu, 28 Apr 2022 15:14:15 -0700 Subject: [PATCH 19/40] Add context methods for the acme linker. --- acme/api/linker.go | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/acme/api/linker.go b/acme/api/linker.go index a605ffc3..114ba698 100644 --- a/acme/api/linker.go +++ b/acme/api/linker.go @@ -41,6 +41,29 @@ type Linker interface { LinkOrdersByAccountID(ctx context.Context, orders []string) } +type linkerKey struct{} + +// NewLinkerContext adds the given linker to the context. +func NewLinkerContext(ctx context.Context, v Linker) context.Context { + return context.WithValue(ctx, linkerKey{}, v) +} + +// LinkerFromContext returns the current linker from the given context. +func LinkerFromContext(ctx context.Context) (v Linker, ok bool) { + v, ok = ctx.Value(linkerKey{}).(Linker) + return +} + +// MustLinkerFromContext returns the current linker from the given context. It +// will panic if it's not in the context. +func MustLinkerFromContext(ctx context.Context) Linker { + if v, ok := LinkerFromContext(ctx); !ok { + panic("acme linker is not the context") + } else { + return v + } +} + // linker generates ACME links. type linker struct { prefix string From fddd6f7d9542c85b050c2d5ac04a8cfb07a24af4 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Thu, 28 Apr 2022 15:15:50 -0700 Subject: [PATCH 20/40] Move linker to the acme package. --- acme/{api => }/linker.go | 2 +- acme/{api => }/linker_test.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) rename acme/{api => }/linker.go (99%) rename acme/{api => }/linker_test.go (99%) diff --git a/acme/api/linker.go b/acme/linker.go similarity index 99% rename from acme/api/linker.go rename to acme/linker.go index 114ba698..8dc87b14 100644 --- a/acme/api/linker.go +++ b/acme/linker.go @@ -1,4 +1,4 @@ -package api +package acme import ( "context" diff --git a/acme/api/linker_test.go b/acme/linker_test.go similarity index 99% rename from acme/api/linker_test.go rename to acme/linker_test.go index 74c2c8b0..a8612e6b 100644 --- a/acme/api/linker_test.go +++ b/acme/linker_test.go @@ -1,4 +1,4 @@ -package api +package acme import ( "context" From d1f75f172078370776d74edde2086d94122ccbd9 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Thu, 28 Apr 2022 19:15:18 -0700 Subject: [PATCH 21/40] Refactor ACME api. --- acme/api/account.go | 25 ++-- acme/api/eab.go | 3 +- acme/api/handler.go | 201 ++++++++++++++------------------ acme/api/middleware.go | 106 +++++------------ acme/api/order.go | 47 +++----- acme/api/revoke.go | 14 +-- acme/challenge.go | 35 +++--- acme/client.go | 79 +++++++++++++ acme/common.go | 72 +++++++++++- acme/db.go | 16 +-- acme/linker.go | 257 ++++++++++++++++++++++++----------------- acme/linker_test.go | 43 ++++--- ca/ca.go | 36 +++--- 13 files changed, 503 insertions(+), 431 deletions(-) create mode 100644 acme/client.go diff --git a/acme/api/account.go b/acme/api/account.go index 8c8c4d97..d88c7066 100644 --- a/acme/api/account.go +++ b/acme/api/account.go @@ -69,6 +69,9 @@ func (u *UpdateAccountRequest) Validate() error { // NewAccount is the handler resource for creating new ACME accounts. func NewAccount(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + db := acme.MustDatabaseFromContext(ctx) + linker := acme.MustLinkerFromContext(ctx) + payload, err := payloadFromContext(ctx) if err != nil { render.Error(w, err) @@ -120,7 +123,6 @@ func NewAccount(w http.ResponseWriter, r *http.Request) { return } - db := acme.MustFromContext(ctx) acc = &acme.Account{ Key: jwk, Contact: nar.Contact, @@ -148,16 +150,18 @@ func NewAccount(w http.ResponseWriter, r *http.Request) { httpStatus = http.StatusOK } - o := optionsFromContext(ctx) - o.linker.LinkAccount(ctx, acc) + linker.LinkAccount(ctx, acc) - w.Header().Set("Location", o.linker.GetLink(r.Context(), AccountLinkType, acc.ID)) + w.Header().Set("Location", linker.GetLink(r.Context(), acme.AccountLinkType, acc.ID)) render.JSONStatus(w, acc, httpStatus) } // GetOrUpdateAccount is the api for updating an ACME account. func GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + db := acme.MustDatabaseFromContext(ctx) + linker := acme.MustLinkerFromContext(ctx) + acc, err := accountFromContext(ctx) if err != nil { render.Error(w, err) @@ -189,7 +193,6 @@ func GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) { acc.Contact = uar.Contact } - db := acme.MustFromContext(ctx) if err := db.UpdateAccount(ctx, acc); err != nil { render.Error(w, acme.WrapErrorISE(err, "error updating account")) return @@ -197,10 +200,9 @@ func GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) { } } - o := optionsFromContext(ctx) - o.linker.LinkAccount(ctx, acc) + linker.LinkAccount(ctx, acc) - w.Header().Set("Location", o.linker.GetLink(ctx, AccountLinkType, acc.ID)) + w.Header().Set("Location", linker.GetLink(ctx, acme.AccountLinkType, acc.ID)) render.JSON(w, acc) } @@ -216,6 +218,9 @@ func logOrdersByAccount(w http.ResponseWriter, oids []string) { // GetOrdersByAccountID ACME api for retrieving the list of order urls belonging to an account. func GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + db := acme.MustDatabaseFromContext(ctx) + linker := acme.MustLinkerFromContext(ctx) + acc, err := accountFromContext(ctx) if err != nil { render.Error(w, err) @@ -227,15 +232,13 @@ func GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) { return } - db := acme.MustFromContext(ctx) orders, err := db.GetOrdersByAccountID(ctx, acc.ID) if err != nil { render.Error(w, err) return } - o := optionsFromContext(ctx) - o.linker.LinkOrdersByAccountID(ctx, orders) + linker.LinkOrdersByAccountID(ctx, orders) render.JSON(w, orders) logOrdersByAccount(w, orders) diff --git a/acme/api/eab.go b/acme/api/eab.go index 2c94a4ed..13928ac4 100644 --- a/acme/api/eab.go +++ b/acme/api/eab.go @@ -47,7 +47,7 @@ func validateExternalAccountBinding(ctx context.Context, nar *NewAccountRequest) return nil, acmeErr } - db := acme.MustFromContext(ctx) + db := acme.MustDatabaseFromContext(ctx) externalAccountKey, err := db.GetExternalAccountKey(ctx, acmeProv.ID, keyID) if err != nil { if _, ok := err.(*acme.Error); ok { @@ -103,7 +103,6 @@ func keysAreEqual(x, y *jose.JSONWebKey) bool { // o The "nonce" field MUST NOT be present // o The "url" field MUST be set to the same value as the outer JWS func validateEABJWS(ctx context.Context, jws *jose.JSONWebSignature) (string, *acme.Error) { - if jws == nil { return "", acme.NewErrorISE("no JWS provided") } diff --git a/acme/api/handler.go b/acme/api/handler.go index 4b916404..efe2b780 100644 --- a/acme/api/handler.go +++ b/acme/api/handler.go @@ -2,12 +2,10 @@ package api import ( "context" - "crypto/tls" "crypto/x509" "encoding/json" "encoding/pem" "fmt" - "net" "net/http" "time" @@ -70,144 +68,117 @@ type HandlerOptions struct { // PrerequisitesChecker checks if all prerequisites for serving ACME are // met by the CA configuration. PrerequisitesChecker func(ctx context.Context) (bool, error) - - linker Linker - validateChallengeOptions *acme.ValidateChallengeOptions -} - -type optionsKey struct{} - -func newOptionsContext(ctx context.Context, o *HandlerOptions) context.Context { - return context.WithValue(ctx, optionsKey{}, o) -} - -func optionsFromContext(ctx context.Context) *HandlerOptions { - o, ok := ctx.Value(optionsKey{}).(*HandlerOptions) - if !ok { - panic("acme options are not in the context") - } - return o } var mustAuthority = func(ctx context.Context) acme.CertificateAuthority { return authority.MustFromContext(ctx) } -// Handler is the ACME API request handler. -type Handler struct { +// handler is the ACME API request handler. +type handler struct { opts *HandlerOptions } // Route traffic and implement the Router interface. -// -// Deprecated: Use api.Route(r Router, opts *HandlerOptions) -func (h *Handler) Route(r api.Router) { - Route(r, h.opts) +func (h *handler) Route(r api.Router) { + route(r, h.opts) } // NewHandler returns a new ACME API handler. -// -// Deprecated: Use api.Route(r Router, opts *HandlerOptions) -func NewHandler(ops HandlerOptions) api.RouterHandler { - return &Handler{ - opts: &ops, +func NewHandler(opts HandlerOptions) api.RouterHandler { + return &handler{ + opts: &opts, } } -// Route traffic and implement the Router interface. -func Route(r api.Router, opts *HandlerOptions) { - // by default all prerequisites are met - if opts.PrerequisitesChecker == nil { - opts.PrerequisitesChecker = func(ctx context.Context) (bool, error) { - return true, nil - } - } - - transport := &http.Transport{ - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - } - client := http.Client{ - Timeout: 30 * time.Second, - Transport: transport, - } - dialer := &net.Dialer{ - Timeout: 30 * time.Second, - } - - opts.linker = NewLinker(opts.DNS, opts.Prefix) - opts.validateChallengeOptions = &acme.ValidateChallengeOptions{ - HTTPGet: client.Get, - LookupTxt: net.LookupTXT, - TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { - return tls.DialWithDialer(dialer, network, addr, config) - }, - } - - withOptions := func(next nextHTTP) nextHTTP { - return func(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() +// Route traffic and implement the Router interface. This method requires that +// all the acme components, authority, db, client, linker, and prerequisite +// checker to be present in the context. +func Route(r api.Router) { + route(r, nil) +} - // For backward compatibility with NewHandler. - if ca, ok := opts.CA.(*authority.Authority); ok && ca != nil { - ctx = authority.NewContext(ctx, ca) +func route(r api.Router, opts *HandlerOptions) { + var withContext func(next nextHTTP) nextHTTP + + // For backward compatibility this block adds will add a new middleware that + // will set the ACME components to the context. + if opts != nil { + client := acme.NewClient() + linker := acme.NewLinker(opts.DNS, opts.Prefix) + + withContext = func(next nextHTTP) nextHTTP { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if ca, ok := opts.CA.(*authority.Authority); ok && ca != nil { + ctx = authority.NewContext(ctx, ca) + } + ctx = acme.NewContext(ctx, opts.DB, client, linker, opts.PrerequisitesChecker) + next(w, r.WithContext(ctx)) } - if opts.DB != nil { - ctx = acme.NewContext(ctx, opts.DB) + } + } else { + withContext = func(next nextHTTP) nextHTTP { + return func(w http.ResponseWriter, r *http.Request) { + next(w, r) } - - ctx = newOptionsContext(ctx, opts) - next(w, r.WithContext(ctx)) } } + commonMiddleware := func(next nextHTTP) nextHTTP { + return withContext(func(w http.ResponseWriter, r *http.Request) { + // Linker middleware gets the provisioner and current url from the + // request and sets them in the context. + linker := acme.MustLinkerFromContext(r.Context()) + linker.Middleware(http.HandlerFunc(checkPrerequisites(next))).ServeHTTP(w, r) + }) + } validatingMiddleware := func(next nextHTTP) nextHTTP { - return withOptions(baseURLFromRequest(lookupProvisioner(checkPrerequisites(addNonce(addDirLink(verifyContentType(parseJWS(validateJWS(next))))))))) + return commonMiddleware(addNonce(addDirLink(verifyContentType(parseJWS(validateJWS(next)))))) } extractPayloadByJWK := func(next nextHTTP) nextHTTP { - return withOptions(validatingMiddleware(extractJWK(verifyAndExtractJWSPayload(next)))) + return validatingMiddleware(extractJWK(verifyAndExtractJWSPayload(next))) } extractPayloadByKid := func(next nextHTTP) nextHTTP { - return withOptions(validatingMiddleware(lookupJWK(verifyAndExtractJWSPayload(next)))) + return validatingMiddleware(lookupJWK(verifyAndExtractJWSPayload(next))) } extractPayloadByKidOrJWK := func(next nextHTTP) nextHTTP { - return withOptions(validatingMiddleware(extractOrLookupJWK(verifyAndExtractJWSPayload(next)))) + return validatingMiddleware(extractOrLookupJWK(verifyAndExtractJWSPayload(next))) } - getPath := opts.linker.GetUnescapedPathSuffix + getPath := acme.GetUnescapedPathSuffix // Standard ACME API - r.MethodFunc("GET", getPath(NewNonceLinkType, "{provisionerID}"), - withOptions(baseURLFromRequest(lookupProvisioner(checkPrerequisites(addNonce(addDirLink(GetNonce))))))) - r.MethodFunc("HEAD", getPath(NewNonceLinkType, "{provisionerID}"), - withOptions(baseURLFromRequest(lookupProvisioner(checkPrerequisites(addNonce(addDirLink(GetNonce))))))) - r.MethodFunc("GET", getPath(DirectoryLinkType, "{provisionerID}"), - withOptions(baseURLFromRequest(lookupProvisioner(checkPrerequisites(GetDirectory))))) - r.MethodFunc("HEAD", getPath(DirectoryLinkType, "{provisionerID}"), - withOptions(baseURLFromRequest(lookupProvisioner(checkPrerequisites(GetDirectory))))) - - r.MethodFunc("POST", getPath(NewAccountLinkType, "{provisionerID}"), + r.MethodFunc("GET", getPath(acme.NewNonceLinkType, "{provisionerID}"), + commonMiddleware(addNonce(addDirLink(GetNonce)))) + r.MethodFunc("HEAD", getPath(acme.NewNonceLinkType, "{provisionerID}"), + commonMiddleware(addNonce(addDirLink(GetNonce)))) + r.MethodFunc("GET", getPath(acme.DirectoryLinkType, "{provisionerID}"), + commonMiddleware(GetDirectory)) + r.MethodFunc("HEAD", getPath(acme.DirectoryLinkType, "{provisionerID}"), + commonMiddleware(GetDirectory)) + + r.MethodFunc("POST", getPath(acme.NewAccountLinkType, "{provisionerID}"), extractPayloadByJWK(NewAccount)) - r.MethodFunc("POST", getPath(AccountLinkType, "{provisionerID}", "{accID}"), + r.MethodFunc("POST", getPath(acme.AccountLinkType, "{provisionerID}", "{accID}"), extractPayloadByKid(GetOrUpdateAccount)) - r.MethodFunc("POST", getPath(KeyChangeLinkType, "{provisionerID}", "{accID}"), + r.MethodFunc("POST", getPath(acme.KeyChangeLinkType, "{provisionerID}", "{accID}"), extractPayloadByKid(NotImplemented)) - r.MethodFunc("POST", getPath(NewOrderLinkType, "{provisionerID}"), + r.MethodFunc("POST", getPath(acme.NewOrderLinkType, "{provisionerID}"), extractPayloadByKid(NewOrder)) - r.MethodFunc("POST", getPath(OrderLinkType, "{provisionerID}", "{ordID}"), + r.MethodFunc("POST", getPath(acme.OrderLinkType, "{provisionerID}", "{ordID}"), extractPayloadByKid(isPostAsGet(GetOrder))) - r.MethodFunc("POST", getPath(OrdersByAccountLinkType, "{provisionerID}", "{accID}"), + r.MethodFunc("POST", getPath(acme.OrdersByAccountLinkType, "{provisionerID}", "{accID}"), extractPayloadByKid(isPostAsGet(GetOrdersByAccountID))) - r.MethodFunc("POST", getPath(FinalizeLinkType, "{provisionerID}", "{ordID}"), + r.MethodFunc("POST", getPath(acme.FinalizeLinkType, "{provisionerID}", "{ordID}"), extractPayloadByKid(FinalizeOrder)) - r.MethodFunc("POST", getPath(AuthzLinkType, "{provisionerID}", "{authzID}"), + r.MethodFunc("POST", getPath(acme.AuthzLinkType, "{provisionerID}", "{authzID}"), extractPayloadByKid(isPostAsGet(GetAuthorization))) - r.MethodFunc("POST", getPath(ChallengeLinkType, "{provisionerID}", "{authzID}", "{chID}"), + r.MethodFunc("POST", getPath(acme.ChallengeLinkType, "{provisionerID}", "{authzID}", "{chID}"), extractPayloadByKid(GetChallenge)) - r.MethodFunc("POST", getPath(CertificateLinkType, "{provisionerID}", "{certID}"), + r.MethodFunc("POST", getPath(acme.CertificateLinkType, "{provisionerID}", "{certID}"), extractPayloadByKid(isPostAsGet(GetCertificate))) - r.MethodFunc("POST", getPath(RevokeCertLinkType, "{provisionerID}"), + r.MethodFunc("POST", getPath(acme.RevokeCertLinkType, "{provisionerID}"), extractPayloadByKidOrJWK(RevokeCert)) } @@ -251,20 +222,20 @@ func (d *Directory) ToLog() (interface{}, error) { // for client configuration. func GetDirectory(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - o := optionsFromContext(ctx) - acmeProv, err := acmeProvisionerFromContext(ctx) + fmt.Println(acmeProv, err) if err != nil { render.Error(w, err) return } + linker := acme.MustLinkerFromContext(ctx) render.JSON(w, &Directory{ - NewNonce: o.linker.GetLink(ctx, NewNonceLinkType), - NewAccount: o.linker.GetLink(ctx, NewAccountLinkType), - NewOrder: o.linker.GetLink(ctx, NewOrderLinkType), - RevokeCert: o.linker.GetLink(ctx, RevokeCertLinkType), - KeyChange: o.linker.GetLink(ctx, KeyChangeLinkType), + NewNonce: linker.GetLink(ctx, acme.NewNonceLinkType), + NewAccount: linker.GetLink(ctx, acme.NewAccountLinkType), + NewOrder: linker.GetLink(ctx, acme.NewOrderLinkType), + RevokeCert: linker.GetLink(ctx, acme.RevokeCertLinkType), + KeyChange: linker.GetLink(ctx, acme.KeyChangeLinkType), Meta: Meta{ ExternalAccountRequired: acmeProv.RequireEAB, }, @@ -280,8 +251,8 @@ func NotImplemented(w http.ResponseWriter, r *http.Request) { // GetAuthorization ACME api for retrieving an Authz. func GetAuthorization(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - o := optionsFromContext(ctx) - db := acme.MustFromContext(ctx) + db := acme.MustDatabaseFromContext(ctx) + linker := acme.MustLinkerFromContext(ctx) acc, err := accountFromContext(ctx) if err != nil { @@ -303,17 +274,17 @@ func GetAuthorization(w http.ResponseWriter, r *http.Request) { return } - o.linker.LinkAuthorization(ctx, az) + linker.LinkAuthorization(ctx, az) - w.Header().Set("Location", o.linker.GetLink(ctx, AuthzLinkType, az.ID)) + w.Header().Set("Location", linker.GetLink(ctx, acme.AuthzLinkType, az.ID)) render.JSON(w, az) } // GetChallenge ACME api for retrieving a Challenge. func GetChallenge(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - o := optionsFromContext(ctx) - db := acme.MustFromContext(ctx) + db := acme.MustDatabaseFromContext(ctx) + linker := acme.MustLinkerFromContext(ctx) acc, err := accountFromContext(ctx) if err != nil { @@ -351,22 +322,22 @@ func GetChallenge(w http.ResponseWriter, r *http.Request) { render.Error(w, err) return } - if err = ch.Validate(ctx, db, jwk, o.validateChallengeOptions); err != nil { + if err = ch.Validate(ctx, db, jwk); err != nil { render.Error(w, acme.WrapErrorISE(err, "error validating challenge")) return } - o.linker.LinkChallenge(ctx, ch, azID) + linker.LinkChallenge(ctx, ch, azID) - w.Header().Add("Link", link(o.linker.GetLink(ctx, AuthzLinkType, azID), "up")) - w.Header().Set("Location", o.linker.GetLink(ctx, ChallengeLinkType, azID, ch.ID)) + w.Header().Add("Link", link(linker.GetLink(ctx, acme.AuthzLinkType, azID), "up")) + w.Header().Set("Location", linker.GetLink(ctx, acme.ChallengeLinkType, azID, ch.ID)) render.JSON(w, ch) } // GetCertificate ACME api for retrieving a Certificate. func GetCertificate(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - db := acme.MustFromContext(ctx) + db := acme.MustDatabaseFromContext(ctx) acc, err := accountFromContext(ctx) if err != nil { diff --git a/acme/api/middleware.go b/acme/api/middleware.go index 564a16f5..09e88b8d 100644 --- a/acme/api/middleware.go +++ b/acme/api/middleware.go @@ -31,39 +31,10 @@ func logNonce(w http.ResponseWriter, nonce string) { } } -// getBaseURLFromRequest determines the base URL which should be used for -// constructing link URLs in e.g. the ACME directory result by taking the -// request Host into consideration. -// -// If the Request.Host is an empty string, we return an empty string, to -// indicate that the configured URL values should be used instead. If this -// function returns a non-empty result, then this should be used in constructing -// ACME link URLs. -func getBaseURLFromRequest(r *http.Request) *url.URL { - // NOTE: See https://github.com/letsencrypt/boulder/blob/master/web/relative.go - // for an implementation that allows HTTP requests using the x-forwarded-proto - // header. - - if r.Host == "" { - return nil - } - return &url.URL{Scheme: "https", Host: r.Host} -} - -// baseURLFromRequest is a middleware that extracts and caches the baseURL -// from the request. -// E.g. https://ca.smallstep.com/ -func baseURLFromRequest(next nextHTTP) nextHTTP { - return func(w http.ResponseWriter, r *http.Request) { - ctx := context.WithValue(r.Context(), baseURLContextKey, getBaseURLFromRequest(r)) - next(w, r.WithContext(ctx)) - } -} - // addNonce is a middleware that adds a nonce to the response header. func addNonce(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { - db := acme.MustFromContext(r.Context()) + db := acme.MustDatabaseFromContext(r.Context()) nonce, err := db.CreateNonce(r.Context()) if err != nil { render.Error(w, err) @@ -81,9 +52,9 @@ func addNonce(next nextHTTP) nextHTTP { func addDirLink(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - opts := optionsFromContext(ctx) + linker := acme.MustLinkerFromContext(ctx) - w.Header().Add("Link", link(opts.linker.GetLink(ctx, DirectoryLinkType), "index")) + w.Header().Add("Link", link(linker.GetLink(ctx, acme.DirectoryLinkType), "index")) next(w, r) } } @@ -92,17 +63,12 @@ func addDirLink(next nextHTTP) nextHTTP { // application/jose+json. func verifyContentType(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { - var expected []string - ctx := r.Context() - opts := optionsFromContext(ctx) - - p, err := provisionerFromContext(ctx) - if err != nil { - render.Error(w, err) - return + p := acme.MustProvisionerFromContext(r.Context()) + u := &url.URL{ + Path: acme.GetUnescapedPathSuffix(acme.CertificateLinkType, p.GetName(), ""), } - u := url.URL{Path: opts.linker.GetUnescapedPathSuffix(CertificateLinkType, p.GetName(), "")} + var expected []string if strings.Contains(r.URL.String(), u.EscapedPath()) { // GET /certificate requests allow a greater range of content types. expected = []string{"application/jose+json", "application/pkix-cert", "application/pkcs7-mime"} @@ -159,7 +125,7 @@ func parseJWS(next nextHTTP) nextHTTP { func validateJWS(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - db := acme.MustFromContext(ctx) + db := acme.MustDatabaseFromContext(ctx) jws, err := jwsFromContext(ctx) if err != nil { @@ -247,7 +213,7 @@ func validateJWS(next nextHTTP) nextHTTP { func extractJWK(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - db := acme.MustFromContext(ctx) + db := acme.MustDatabaseFromContext(ctx) jws, err := jwsFromContext(ctx) if err != nil { @@ -325,18 +291,20 @@ func lookupProvisioner(next nextHTTP) nextHTTP { func checkPrerequisites(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - opts := optionsFromContext(ctx) - - ok, err := opts.PrerequisitesChecker(ctx) - if err != nil { - render.Error(w, acme.WrapErrorISE(err, "error checking acme provisioner prerequisites")) - return - } - if !ok { - render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites")) - return + // If the function is not set assume that all prerequisites are met. + checkFunc, ok := acme.PrerequisitesCheckerFromContext(ctx) + if ok { + ok, err := checkFunc(ctx) + if err != nil { + render.Error(w, acme.WrapErrorISE(err, "error checking acme provisioner prerequisites")) + return + } + if !ok { + render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites")) + return + } } - next(w, r.WithContext(ctx)) + next(w, r) } } @@ -346,8 +314,8 @@ func checkPrerequisites(next nextHTTP) nextHTTP { func lookupJWK(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - opts := optionsFromContext(ctx) - db := acme.MustFromContext(ctx) + db := acme.MustDatabaseFromContext(ctx) + linker := acme.MustLinkerFromContext(ctx) jws, err := jwsFromContext(ctx) if err != nil { @@ -355,7 +323,7 @@ func lookupJWK(next nextHTTP) nextHTTP { return } - kidPrefix := opts.linker.GetLink(ctx, AccountLinkType, "") + kidPrefix := linker.GetLink(ctx, acme.AccountLinkType, "") kid := jws.Signatures[0].Protected.KeyID if !strings.HasPrefix(kid, kidPrefix) { render.Error(w, acme.NewError(acme.ErrorMalformedType, @@ -527,32 +495,14 @@ func jwsFromContext(ctx context.Context) (*jose.JSONWebSignature, error) { return val, nil } -// provisionerFromContext searches the context for a provisioner. Returns the -// provisioner or an error. -func provisionerFromContext(ctx context.Context) (acme.Provisioner, error) { - val := ctx.Value(provisionerContextKey) - if val == nil { - return nil, acme.NewErrorISE("provisioner expected in request context") - } - pval, ok := val.(acme.Provisioner) - if !ok || pval == nil { - return nil, acme.NewErrorISE("provisioner in context is not an ACME provisioner") - } - return pval, nil -} - // acmeProvisionerFromContext searches the context for an ACME provisioner. Returns // pointer to an ACME provisioner or an error. func acmeProvisionerFromContext(ctx context.Context) (*provisioner.ACME, error) { - prov, err := provisionerFromContext(ctx) - if err != nil { - return nil, err - } - acmeProv, ok := prov.(*provisioner.ACME) - if !ok || acmeProv == nil { + p, ok := acme.MustProvisionerFromContext(ctx).(*provisioner.ACME) + if !ok { return nil, acme.NewErrorISE("provisioner in context is not an ACME provisioner") } - return acmeProv, nil + return p, nil } // payloadFromContext searches the context for a payload. Returns the payload diff --git a/acme/api/order.go b/acme/api/order.go index ebd0c7f5..2b9f912e 100644 --- a/acme/api/order.go +++ b/acme/api/order.go @@ -70,16 +70,15 @@ var defaultOrderBackdate = time.Minute // NewOrder ACME api for creating a new order. func NewOrder(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + db := acme.MustDatabaseFromContext(ctx) + linker := acme.MustLinkerFromContext(ctx) + prov := acme.MustProvisionerFromContext(ctx) + acc, err := accountFromContext(ctx) if err != nil { render.Error(w, err) return } - prov, err := provisionerFromContext(ctx) - if err != nil { - render.Error(w, err) - return - } payload, err := payloadFromContext(ctx) if err != nil { render.Error(w, err) @@ -136,16 +135,14 @@ func NewOrder(w http.ResponseWriter, r *http.Request) { o.NotBefore = o.NotBefore.Add(-defaultOrderBackdate) } - db := acme.MustFromContext(ctx) if err := db.CreateOrder(ctx, o); err != nil { render.Error(w, acme.WrapErrorISE(err, "error creating order")) return } - opts := optionsFromContext(ctx) - opts.linker.LinkOrder(ctx, o) + linker.LinkOrder(ctx, o) - w.Header().Set("Location", opts.linker.GetLink(ctx, OrderLinkType, o.ID)) + w.Header().Set("Location", linker.GetLink(ctx, acme.OrderLinkType, o.ID)) render.JSONStatus(w, o, http.StatusCreated) } @@ -166,7 +163,7 @@ func newAuthorization(ctx context.Context, az *acme.Authorization) error { return acme.WrapErrorISE(err, "error generating random alphanumeric ID") } - db := acme.MustFromContext(ctx) + db := acme.MustDatabaseFromContext(ctx) az.Challenges = make([]*acme.Challenge, len(chTypes)) for i, typ := range chTypes { ch := &acme.Challenge{ @@ -190,18 +187,16 @@ func newAuthorization(ctx context.Context, az *acme.Authorization) error { // GetOrder ACME api for retrieving an order. func GetOrder(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + db := acme.MustDatabaseFromContext(ctx) + linker := acme.MustLinkerFromContext(ctx) + prov := acme.MustProvisionerFromContext(ctx) + acc, err := accountFromContext(ctx) if err != nil { render.Error(w, err) return } - prov, err := provisionerFromContext(ctx) - if err != nil { - render.Error(w, err) - return - } - db := acme.MustFromContext(ctx) o, err := db.GetOrder(ctx, chi.URLParam(r, "ordID")) if err != nil { render.Error(w, acme.WrapErrorISE(err, "error retrieving order")) @@ -222,26 +217,24 @@ func GetOrder(w http.ResponseWriter, r *http.Request) { return } - opts := optionsFromContext(ctx) - opts.linker.LinkOrder(ctx, o) + linker.LinkOrder(ctx, o) - w.Header().Set("Location", opts.linker.GetLink(ctx, OrderLinkType, o.ID)) + w.Header().Set("Location", linker.GetLink(ctx, acme.OrderLinkType, o.ID)) render.JSON(w, o) } // FinalizeOrder attemptst to finalize an order and create a certificate. func FinalizeOrder(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + db := acme.MustDatabaseFromContext(ctx) + linker := acme.MustLinkerFromContext(ctx) + prov := acme.MustProvisionerFromContext(ctx) + acc, err := accountFromContext(ctx) if err != nil { render.Error(w, err) return } - prov, err := provisionerFromContext(ctx) - if err != nil { - render.Error(w, err) - return - } payload, err := payloadFromContext(ctx) if err != nil { render.Error(w, err) @@ -258,7 +251,6 @@ func FinalizeOrder(w http.ResponseWriter, r *http.Request) { return } - db := acme.MustFromContext(ctx) o, err := db.GetOrder(ctx, chi.URLParam(r, "ordID")) if err != nil { render.Error(w, acme.WrapErrorISE(err, "error retrieving order")) @@ -281,10 +273,9 @@ func FinalizeOrder(w http.ResponseWriter, r *http.Request) { return } - opts := optionsFromContext(ctx) - opts.linker.LinkOrder(ctx, o) + linker.LinkOrder(ctx, o) - w.Header().Set("Location", opts.linker.GetLink(ctx, OrderLinkType, o.ID)) + w.Header().Set("Location", linker.GetLink(ctx, acme.OrderLinkType, o.ID)) render.JSON(w, o) } diff --git a/acme/api/revoke.go b/acme/api/revoke.go index 55774aea..584ed27e 100644 --- a/acme/api/revoke.go +++ b/acme/api/revoke.go @@ -28,13 +28,11 @@ type revokePayload struct { // RevokeCert attempts to revoke a certificate. func RevokeCert(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - jws, err := jwsFromContext(ctx) - if err != nil { - render.Error(w, err) - return - } + db := acme.MustDatabaseFromContext(ctx) + linker := acme.MustLinkerFromContext(ctx) + prov := acme.MustProvisionerFromContext(ctx) - prov, err := provisionerFromContext(ctx) + jws, err := jwsFromContext(ctx) if err != nil { render.Error(w, err) return @@ -67,7 +65,6 @@ func RevokeCert(w http.ResponseWriter, r *http.Request) { return } - db := acme.MustFromContext(ctx) serial := certToBeRevoked.SerialNumber.String() dbCert, err := db.GetCertificateBySerial(ctx, serial) if err != nil { @@ -138,8 +135,7 @@ func RevokeCert(w http.ResponseWriter, r *http.Request) { } logRevoke(w, options) - o := optionsFromContext(ctx) - w.Header().Add("Link", link(o.linker.GetLink(ctx, DirectoryLinkType), "index")) + w.Header().Add("Link", link(linker.GetLink(ctx, acme.DirectoryLinkType), "index")) w.Write(nil) } diff --git a/acme/challenge.go b/acme/challenge.go index 9f08bae5..8d8466bd 100644 --- a/acme/challenge.go +++ b/acme/challenge.go @@ -14,7 +14,6 @@ import ( "fmt" "io" "net" - "net/http" "net/url" "reflect" "strings" @@ -61,27 +60,28 @@ func (ch *Challenge) ToLog() (interface{}, error) { // type using the DB interface. // satisfactorily validated, the 'status' and 'validated' attributes are // updated. -func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error { +func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey) error { // If already valid or invalid then return without performing validation. if ch.Status != StatusPending { return nil } switch ch.Type { case HTTP01: - return http01Validate(ctx, ch, db, jwk, vo) + return http01Validate(ctx, ch, db, jwk) case DNS01: - return dns01Validate(ctx, ch, db, jwk, vo) + return dns01Validate(ctx, ch, db, jwk) case TLSALPN01: - return tlsalpn01Validate(ctx, ch, db, jwk, vo) + return tlsalpn01Validate(ctx, ch, db, jwk) default: return NewErrorISE("unexpected challenge type '%s'", ch.Type) } } -func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error { +func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey) error { u := &url.URL{Scheme: "http", Host: http01ChallengeHost(ch.Value), Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)} - resp, err := vo.HTTPGet(u.String()) + vc := MustClientFromContext(ctx) + resp, err := vc.Get(u.String()) if err != nil { return storeError(ctx, db, ch, false, WrapError(ErrorConnectionType, err, "error doing http GET for url %s", u)) @@ -141,7 +141,7 @@ func tlsAlert(err error) uint8 { return 0 } -func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error { +func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey) error { config := &tls.Config{ NextProtos: []string{"acme-tls/1"}, // https://tools.ietf.org/html/rfc8737#section-4 @@ -154,7 +154,8 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON hostPort := net.JoinHostPort(ch.Value, "443") - conn, err := vo.TLSDial("tcp", hostPort, config) + vc := MustClientFromContext(ctx) + conn, err := vc.TLSDial("tcp", hostPort, config) if err != nil { // With Go 1.17+ tls.Dial fails if there's no overlap between configured // client and server protocols. When this happens the connection is @@ -253,14 +254,15 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON "incorrect certificate for tls-alpn-01 challenge: missing acmeValidationV1 extension")) } -func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error { +func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey) error { // Normalize domain for wildcard DNS names // This is done to avoid making TXT lookups for domains like // _acme-challenge.*.example.com // Instead perform txt lookup for _acme-challenge.example.com domain := strings.TrimPrefix(ch.Value, "*.") - txtRecords, err := vo.LookupTxt("_acme-challenge." + domain) + vc := MustClientFromContext(ctx) + txtRecords, err := vc.LookupTxt("_acme-challenge." + domain) if err != nil { return storeError(ctx, db, ch, false, WrapError(ErrorDNSType, err, "error looking up TXT records for domain %s", domain)) @@ -376,14 +378,3 @@ func storeError(ctx context.Context, db DB, ch *Challenge, markInvalid bool, err } return nil } - -type httpGetter func(string) (*http.Response, error) -type lookupTxt func(string) ([]string, error) -type tlsDialer func(network, addr string, config *tls.Config) (*tls.Conn, error) - -// ValidateChallengeOptions are ACME challenge validator functions. -type ValidateChallengeOptions struct { - HTTPGet httpGetter - LookupTxt lookupTxt - TLSDial tlsDialer -} diff --git a/acme/client.go b/acme/client.go new file mode 100644 index 00000000..2b200e45 --- /dev/null +++ b/acme/client.go @@ -0,0 +1,79 @@ +package acme + +import ( + "context" + "crypto/tls" + "net" + "net/http" + "time" +) + +// Client is the interface used to verify ACME challenges. +type Client interface { + // Get issues an HTTP GET to the specified URL. + Get(url string) (*http.Response, error) + + // LookupTXT returns the DNS TXT records for the given domain name. + LookupTxt(name string) ([]string, error) + + // TLSDial connects to the given network address using net.Dialer and then + // initiates a TLS handshake, returning the resulting TLS connection. + TLSDial(network, addr string, config *tls.Config) (*tls.Conn, error) +} + +type clientKey struct{} + +// NewClientContext adds the given client to the context. +func NewClientContext(ctx context.Context, c Client) context.Context { + return context.WithValue(ctx, clientKey{}, c) +} + +// ClientFromContext returns the current client from the given context. +func ClientFromContext(ctx context.Context) (c Client, ok bool) { + c, ok = ctx.Value(clientKey{}).(Client) + return +} + +// MustClientFromContext returns the current client from the given context. It will +// return a new instance of the client if it does not exist. +func MustClientFromContext(ctx context.Context) Client { + if c, ok := ClientFromContext(ctx); !ok { + return NewClient() + } else { + return c + } +} + +type client struct { + http *http.Client + dialer *net.Dialer +} + +// NewClient returns an implementation of Client for verifying ACME challenges. +func NewClient() Client { + return &client{ + http: &http.Client{ + Timeout: 30 * time.Second, + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + }, + }, + dialer: &net.Dialer{ + Timeout: 30 * time.Second, + }, + } +} + +func (c *client) Get(url string) (*http.Response, error) { + return c.http.Get(url) +} + +func (c *client) LookupTxt(name string) ([]string, error) { + return net.LookupTXT(name) +} + +func (c *client) TLSDial(network, addr string, config *tls.Config) (*tls.Conn, error) { + return tls.DialWithDialer(c.dialer, network, addr, config) +} diff --git a/acme/common.go b/acme/common.go index 0c9e83dc..5290c06d 100644 --- a/acme/common.go +++ b/acme/common.go @@ -9,6 +9,16 @@ import ( "github.com/smallstep/certificates/authority/provisioner" ) +// Clock that returns time in UTC rounded to seconds. +type Clock struct{} + +// Now returns the UTC time rounded to seconds. +func (c *Clock) Now() time.Time { + return time.Now().UTC().Truncate(time.Second) +} + +var clock Clock + // CertificateAuthority is the interface implemented by a CA authority. type CertificateAuthority interface { Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) @@ -17,15 +27,42 @@ type CertificateAuthority interface { LoadProvisionerByName(string) (provisioner.Interface, error) } -// Clock that returns time in UTC rounded to seconds. -type Clock struct{} +// NewContext adds the given acme components to the context. +func NewContext(ctx context.Context, db DB, client Client, linker Linker, fn PrerequisitesChecker) context.Context { + ctx = NewDatabaseContext(ctx, db) + ctx = NewClientContext(ctx, client) + ctx = NewLinkerContext(ctx, linker) + // Prerequisite checker is optional. + if fn != nil { + ctx = NewPrerequisitesCheckerContext(ctx, fn) + } + return ctx +} -// Now returns the UTC time rounded to seconds. -func (c *Clock) Now() time.Time { - return time.Now().UTC().Truncate(time.Second) +// PrerequisitesChecker is a function that checks if all prerequisites for +// serving ACME are met by the CA configuration. +type PrerequisitesChecker func(ctx context.Context) (bool, error) + +// DefaultPrerequisitesChecker is the default PrerequisiteChecker and returns +// always true. +func DefaultPrerequisitesChecker(ctx context.Context) (bool, error) { + return true, nil } -var clock Clock +type prerequisitesKey struct{} + +// NewPrerequisitesCheckerContext adds the given PrerequisitesChecker to the +// context. +func NewPrerequisitesCheckerContext(ctx context.Context, fn PrerequisitesChecker) context.Context { + return context.WithValue(ctx, prerequisitesKey{}, fn) +} + +// PrerequisitesCheckerFromContext returns the PrerequisitesChecker in the +// context. +func PrerequisitesCheckerFromContext(ctx context.Context) (PrerequisitesChecker, bool) { + fn, ok := ctx.Value(prerequisitesKey{}).(PrerequisitesChecker) + return fn, ok && fn != nil +} // Provisioner is an interface that implements a subset of the provisioner.Interface -- // only those methods required by the ACME api/authority. @@ -38,6 +75,29 @@ type Provisioner interface { GetOptions() *provisioner.Options } +type provisionerKey struct{} + +// NewProvisionerContext adds the given provisioner to the context. +func NewProvisionerContext(ctx context.Context, v Provisioner) context.Context { + return context.WithValue(ctx, provisionerKey{}, v) +} + +// ProvisionerFromContext returns the current provisioner from the given context. +func ProvisionerFromContext(ctx context.Context) (v Provisioner, ok bool) { + v, ok = ctx.Value(provisionerKey{}).(Provisioner) + return +} + +// MustLinkerFromContext returns the current provisioner from the given context. +// It will panic if it's not in the context. +func MustProvisionerFromContext(ctx context.Context) Provisioner { + if v, ok := ProvisionerFromContext(ctx); !ok { + panic("acme provisioner is not the context") + } else { + return v + } +} + // MockProvisioner for testing type MockProvisioner struct { Mret1 interface{} diff --git a/acme/db.go b/acme/db.go index a8637f57..3d781156 100644 --- a/acme/db.go +++ b/acme/db.go @@ -50,21 +50,21 @@ type DB interface { type dbKey struct{} -// NewContext adds the given acme database to the context. -func NewContext(ctx context.Context, db DB) context.Context { +// NewDatabaseContext adds the given acme database to the context. +func NewDatabaseContext(ctx context.Context, db DB) context.Context { return context.WithValue(ctx, dbKey{}, db) } -// FromContext returns the current acme database from the given context. -func FromContext(ctx context.Context) (db DB, ok bool) { +// DatabaseFromContext returns the current acme database from the given context. +func DatabaseFromContext(ctx context.Context) (db DB, ok bool) { db, ok = ctx.Value(dbKey{}).(DB) return } -// MustFromContext returns the current database from the given context. It -// will panic if it's not in the context. -func MustFromContext(ctx context.Context) DB { - if db, ok := FromContext(ctx); !ok { +// MustDatabaseFromContext returns the current database from the given context. +// It will panic if it's not in the context. +func MustDatabaseFromContext(ctx context.Context) DB { + if db, ok := DatabaseFromContext(ctx); !ok { panic("acme database is not in the context") } else { return db diff --git a/acme/linker.go b/acme/linker.go index 8dc87b14..6e9110c2 100644 --- a/acme/linker.go +++ b/acme/linker.go @@ -4,12 +4,98 @@ import ( "context" "fmt" "net" + "net/http" "net/url" "strings" - "github.com/smallstep/certificates/acme" + "github.com/go-chi/chi" + "github.com/smallstep/certificates/api/render" + "github.com/smallstep/certificates/authority" + "github.com/smallstep/certificates/authority/provisioner" ) +// LinkType captures the link type. +type LinkType int + +const ( + // NewNonceLinkType new-nonce + NewNonceLinkType LinkType = iota + // NewAccountLinkType new-account + NewAccountLinkType + // AccountLinkType account + AccountLinkType + // OrderLinkType order + OrderLinkType + // NewOrderLinkType new-order + NewOrderLinkType + // OrdersByAccountLinkType list of orders owned by account + OrdersByAccountLinkType + // FinalizeLinkType finalize order + FinalizeLinkType + // NewAuthzLinkType authz + NewAuthzLinkType + // AuthzLinkType new-authz + AuthzLinkType + // ChallengeLinkType challenge + ChallengeLinkType + // CertificateLinkType certificate + CertificateLinkType + // DirectoryLinkType directory + DirectoryLinkType + // RevokeCertLinkType revoke certificate + RevokeCertLinkType + // KeyChangeLinkType key rollover + KeyChangeLinkType +) + +func (l LinkType) String() string { + switch l { + case NewNonceLinkType: + return "new-nonce" + case NewAccountLinkType: + return "new-account" + case AccountLinkType: + return "account" + case NewOrderLinkType: + return "new-order" + case OrderLinkType: + return "order" + case NewAuthzLinkType: + return "new-authz" + case AuthzLinkType: + return "authz" + case ChallengeLinkType: + return "challenge" + case CertificateLinkType: + return "certificate" + case DirectoryLinkType: + return "directory" + case RevokeCertLinkType: + return "revoke-cert" + case KeyChangeLinkType: + return "key-change" + default: + return fmt.Sprintf("unexpected LinkType '%d'", int(l)) + } +} + +func GetUnescapedPathSuffix(typ LinkType, provisionerName string, inputs ...string) string { + switch typ { + case NewNonceLinkType, NewAccountLinkType, NewOrderLinkType, NewAuthzLinkType, DirectoryLinkType, KeyChangeLinkType, RevokeCertLinkType: + return fmt.Sprintf("/%s/%s", provisionerName, typ) + case AccountLinkType, OrderLinkType, AuthzLinkType, CertificateLinkType: + return fmt.Sprintf("/%s/%s/%s", provisionerName, typ, inputs[0]) + case ChallengeLinkType: + return fmt.Sprintf("/%s/%s/%s/%s", provisionerName, typ, inputs[0], inputs[1]) + case OrdersByAccountLinkType: + return fmt.Sprintf("/%s/%s/%s/orders", provisionerName, AccountLinkType, inputs[0]) + case FinalizeLinkType: + return fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLinkType, inputs[0]) + default: + return "" + } +} + // NewLinker returns a new Directory type. func NewLinker(dns, prefix string) Linker { _, _, err := net.SplitHostPort(dns) @@ -32,12 +118,11 @@ func NewLinker(dns, prefix string) Linker { // Linker interface for generating links for ACME resources. type Linker interface { GetLink(ctx context.Context, typ LinkType, inputs ...string) string - GetUnescapedPathSuffix(typ LinkType, provName string, inputs ...string) string - - LinkOrder(ctx context.Context, o *acme.Order) - LinkAccount(ctx context.Context, o *acme.Account) - LinkChallenge(ctx context.Context, o *acme.Challenge, azID string) - LinkAuthorization(ctx context.Context, o *acme.Authorization) + Middleware(http.Handler) http.Handler + LinkOrder(ctx context.Context, o *Order) + LinkAccount(ctx context.Context, o *Account) + LinkChallenge(ctx context.Context, o *Challenge, azID string) + LinkAuthorization(ctx context.Context, o *Authorization) LinkOrdersByAccountID(ctx context.Context, orders []string) } @@ -64,127 +149,81 @@ func MustLinkerFromContext(ctx context.Context) Linker { } } +type baseURLKey struct{} + +func newBaseURLContext(ctx context.Context, r *http.Request) context.Context { + var u *url.URL + if r.Host != "" { + u = &url.URL{Scheme: "https", Host: r.Host} + } + return context.WithValue(ctx, baseURLKey{}, u) +} + +func baseURLFromContext(ctx context.Context) *url.URL { + if u, ok := ctx.Value(baseURLKey{}).(*url.URL); ok { + return u + } + return nil +} + // linker generates ACME links. type linker struct { prefix string dns string } -func (l *linker) GetUnescapedPathSuffix(typ LinkType, provisionerName string, inputs ...string) string { - switch typ { - case NewNonceLinkType, NewAccountLinkType, NewOrderLinkType, NewAuthzLinkType, DirectoryLinkType, KeyChangeLinkType, RevokeCertLinkType: - return fmt.Sprintf("/%s/%s", provisionerName, typ) - case AccountLinkType, OrderLinkType, AuthzLinkType, CertificateLinkType: - return fmt.Sprintf("/%s/%s/%s", provisionerName, typ, inputs[0]) - case ChallengeLinkType: - return fmt.Sprintf("/%s/%s/%s/%s", provisionerName, typ, inputs[0], inputs[1]) - case OrdersByAccountLinkType: - return fmt.Sprintf("/%s/%s/%s/orders", provisionerName, AccountLinkType, inputs[0]) - case FinalizeLinkType: - return fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLinkType, inputs[0]) - default: - return "" - } +// Middleware gets the provisioner and current url from the request and sets +// them in the context so we can use the linker to create ACME links. +func (l *linker) Middleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Add base url to the context. + ctx := newBaseURLContext(r.Context(), r) + + // Add provisioner to the context. + nameEscaped := chi.URLParam(r, "provisionerID") + name, err := url.PathUnescape(nameEscaped) + if err != nil { + render.Error(w, WrapErrorISE(err, "error url unescaping provisioner name '%s'", nameEscaped)) + return + } + + p, err := authority.MustFromContext(ctx).LoadProvisionerByName(name) + if err != nil { + render.Error(w, err) + return + } + + acmeProv, ok := p.(*provisioner.ACME) + if !ok { + render.Error(w, NewError(ErrorAccountDoesNotExistType, "provisioner must be of type ACME")) + return + } + + ctx = NewProvisionerContext(ctx, Provisioner(acmeProv)) + next.ServeHTTP(w, r.WithContext(ctx)) + }) } -// GetLink is a helper for GetLinkExplicit +// GetLink is a helper for GetLinkExplicit. func (l *linker) GetLink(ctx context.Context, typ LinkType, inputs ...string) string { - var ( - provName string - baseURL = baseURLFromContext(ctx) - u = url.URL{} - ) - if p, err := provisionerFromContext(ctx); err == nil && p != nil { - provName = p.GetName() - } - // Copy the baseURL value from the pointer. https://github.com/golang/go/issues/38351 - if baseURL != nil { + var u url.URL + if baseURL := baseURLFromContext(ctx); baseURL != nil { u = *baseURL } - - u.Path = l.GetUnescapedPathSuffix(typ, provName, inputs...) - - // If no Scheme is set, then default to https. if u.Scheme == "" { u.Scheme = "https" } - - // If no Host is set, then use the default (first DNS attr in the ca.json). if u.Host == "" { u.Host = l.dns } - u.Path = l.prefix + u.Path + p := MustProvisionerFromContext(ctx) + u.Path = l.prefix + GetUnescapedPathSuffix(typ, p.GetName(), inputs...) return u.String() } -// LinkType captures the link type. -type LinkType int - -const ( - // NewNonceLinkType new-nonce - NewNonceLinkType LinkType = iota - // NewAccountLinkType new-account - NewAccountLinkType - // AccountLinkType account - AccountLinkType - // OrderLinkType order - OrderLinkType - // NewOrderLinkType new-order - NewOrderLinkType - // OrdersByAccountLinkType list of orders owned by account - OrdersByAccountLinkType - // FinalizeLinkType finalize order - FinalizeLinkType - // NewAuthzLinkType authz - NewAuthzLinkType - // AuthzLinkType new-authz - AuthzLinkType - // ChallengeLinkType challenge - ChallengeLinkType - // CertificateLinkType certificate - CertificateLinkType - // DirectoryLinkType directory - DirectoryLinkType - // RevokeCertLinkType revoke certificate - RevokeCertLinkType - // KeyChangeLinkType key rollover - KeyChangeLinkType -) - -func (l LinkType) String() string { - switch l { - case NewNonceLinkType: - return "new-nonce" - case NewAccountLinkType: - return "new-account" - case AccountLinkType: - return "account" - case NewOrderLinkType: - return "new-order" - case OrderLinkType: - return "order" - case NewAuthzLinkType: - return "new-authz" - case AuthzLinkType: - return "authz" - case ChallengeLinkType: - return "challenge" - case CertificateLinkType: - return "certificate" - case DirectoryLinkType: - return "directory" - case RevokeCertLinkType: - return "revoke-cert" - case KeyChangeLinkType: - return "key-change" - default: - return fmt.Sprintf("unexpected LinkType '%d'", int(l)) - } -} - // LinkOrder sets the ACME links required by an ACME order. -func (l *linker) LinkOrder(ctx context.Context, o *acme.Order) { +func (l *linker) LinkOrder(ctx context.Context, o *Order) { o.AuthorizationURLs = make([]string, len(o.AuthorizationIDs)) for i, azID := range o.AuthorizationIDs { o.AuthorizationURLs[i] = l.GetLink(ctx, AuthzLinkType, azID) @@ -196,17 +235,17 @@ func (l *linker) LinkOrder(ctx context.Context, o *acme.Order) { } // LinkAccount sets the ACME links required by an ACME account. -func (l *linker) LinkAccount(ctx context.Context, acc *acme.Account) { +func (l *linker) LinkAccount(ctx context.Context, acc *Account) { acc.OrdersURL = l.GetLink(ctx, OrdersByAccountLinkType, acc.ID) } // LinkChallenge sets the ACME links required by an ACME challenge. -func (l *linker) LinkChallenge(ctx context.Context, ch *acme.Challenge, azID string) { +func (l *linker) LinkChallenge(ctx context.Context, ch *Challenge, azID string) { ch.URL = l.GetLink(ctx, ChallengeLinkType, azID, ch.ID) } // LinkAuthorization sets the ACME links required by an ACME authorization. -func (l *linker) LinkAuthorization(ctx context.Context, az *acme.Authorization) { +func (l *linker) LinkAuthorization(ctx context.Context, az *Authorization) { for _, ch := range az.Challenges { l.LinkChallenge(ctx, ch, az.ID) } diff --git a/acme/linker_test.go b/acme/linker_test.go index a8612e6b..1946dd88 100644 --- a/acme/linker_test.go +++ b/acme/linker_test.go @@ -7,7 +7,6 @@ import ( "testing" "github.com/smallstep/assert" - "github.com/smallstep/certificates/acme" ) func TestLinker_GetUnescapedPathSuffix(t *testing.T) { @@ -173,27 +172,27 @@ func TestLinker_LinkOrder(t *testing.T) { linkerPrefix := "acme" l := NewLinker("dns", linkerPrefix) type test struct { - o *acme.Order - validate func(o *acme.Order) + o *Order + validate func(o *Order) } var tests = map[string]test{ "no-authz-and-no-cert": { - o: &acme.Order{ + o: &Order{ ID: oid, }, - validate: func(o *acme.Order) { + validate: func(o *Order) { assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid)) assert.Equals(t, o.AuthorizationURLs, []string{}) assert.Equals(t, o.CertificateURL, "") }, }, "one-authz-and-cert": { - o: &acme.Order{ + o: &Order{ ID: oid, CertificateID: certID, AuthorizationIDs: []string{"foo"}, }, - validate: func(o *acme.Order) { + validate: func(o *Order) { assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid)) assert.Equals(t, o.AuthorizationURLs, []string{ fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "foo"), @@ -202,12 +201,12 @@ func TestLinker_LinkOrder(t *testing.T) { }, }, "many-authz": { - o: &acme.Order{ + o: &Order{ ID: oid, CertificateID: certID, AuthorizationIDs: []string{"foo", "bar", "zap"}, }, - validate: func(o *acme.Order) { + validate: func(o *Order) { assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid)) assert.Equals(t, o.AuthorizationURLs, []string{ fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "foo"), @@ -237,15 +236,15 @@ func TestLinker_LinkAccount(t *testing.T) { linkerPrefix := "acme" l := NewLinker("dns", linkerPrefix) type test struct { - a *acme.Account - validate func(o *acme.Account) + a *Account + validate func(o *Account) } var tests = map[string]test{ "ok": { - a: &acme.Account{ + a: &Account{ ID: accID, }, - validate: func(a *acme.Account) { + validate: func(a *Account) { assert.Equals(t, a.OrdersURL, fmt.Sprintf("%s/%s/%s/account/%s/orders", baseURL, linkerPrefix, provName, accID)) }, }, @@ -270,15 +269,15 @@ func TestLinker_LinkChallenge(t *testing.T) { linkerPrefix := "acme" l := NewLinker("dns", linkerPrefix) type test struct { - ch *acme.Challenge - validate func(o *acme.Challenge) + ch *Challenge + validate func(o *Challenge) } var tests = map[string]test{ "ok": { - ch: &acme.Challenge{ + ch: &Challenge{ ID: chID, }, - validate: func(ch *acme.Challenge) { + validate: func(ch *Challenge) { assert.Equals(t, ch.URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, azID, ch.ID)) }, }, @@ -305,20 +304,20 @@ func TestLinker_LinkAuthorization(t *testing.T) { linkerPrefix := "acme" l := NewLinker("dns", linkerPrefix) type test struct { - az *acme.Authorization - validate func(o *acme.Authorization) + az *Authorization + validate func(o *Authorization) } var tests = map[string]test{ "ok": { - az: &acme.Authorization{ + az: &Authorization{ ID: azID, - Challenges: []*acme.Challenge{ + Challenges: []*Challenge{ {ID: chID0}, {ID: chID1}, {ID: chID2}, }, }, - validate: func(az *acme.Authorization) { + validate: func(az *Authorization) { assert.Equals(t, az.Challenges[0].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID0)) assert.Equals(t, az.Challenges[1].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID1)) assert.Equals(t, az.Challenges[2].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID2)) diff --git a/ca/ca.go b/ca/ca.go index a8ecbb05..e910da74 100644 --- a/ca/ca.go +++ b/ca/ca.go @@ -189,30 +189,24 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) { dns = fmt.Sprintf("%s:%s", dns, port) } - // ACME Router - prefix := "acme" + // ACME Router is only available if we have a database. var acmeDB acme.DB - if cfg.DB == nil { - acmeDB = nil - } else { + var acmeLinker acme.Linker + if cfg.DB != nil { acmeDB, err = acmeNoSQL.New(auth.GetDatabase().(nosql.DB)) if err != nil { return nil, errors.Wrap(err, "error configuring ACME DB interface") } + acmeLinker = acme.NewLinker(dns, "acme") + mux.Route("/acme", func(r chi.Router) { + acmeAPI.Route(r) + }) + // Use 2.0 because, at the moment, our ACME api is only compatible with v2.0 + // of the ACME spec. + mux.Route("/2.0/acme", func(r chi.Router) { + acmeAPI.Route(r) + }) } - acmeOptions := &acmeAPI.HandlerOptions{ - Backdate: *cfg.AuthorityConfig.Backdate, - DNS: dns, - Prefix: prefix, - } - mux.Route("/"+prefix, func(r chi.Router) { - acmeAPI.Route(r, acmeOptions) - }) - // Use 2.0 because, at the moment, our ACME api is only compatible with v2.0 - // of the ACME spec. - mux.Route("/2.0/"+prefix, func(r chi.Router) { - acmeAPI.Route(r, acmeOptions) - }) // Admin API Router if cfg.AuthorityConfig.EnableAdmin { @@ -280,7 +274,7 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) { } // Create context with all the necessary values. - baseContext := buildContext(auth, scepAuthority, acmeDB) + baseContext := buildContext(auth, scepAuthority, acmeDB, acmeLinker) ca.srv = server.New(cfg.Address, handler, tlsConfig) ca.srv.BaseContext = func(net.Listener) context.Context { @@ -304,7 +298,7 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) { } // buildContext builds the server base context. -func buildContext(a *authority.Authority, scepAuthority *scep.Authority, acmeDB acme.DB) context.Context { +func buildContext(a *authority.Authority, scepAuthority *scep.Authority, acmeDB acme.DB, acmeLinker acme.Linker) context.Context { ctx := authority.NewContext(context.Background(), a) if authDB := a.GetDatabase(); authDB != nil { ctx = db.NewContext(ctx, authDB) @@ -316,7 +310,7 @@ func buildContext(a *authority.Authority, scepAuthority *scep.Authority, acmeDB ctx = scep.NewContext(ctx, scepAuthority) } if acmeDB != nil { - ctx = acme.NewContext(ctx, acmeDB) + ctx = acme.NewContext(ctx, acmeDB, acme.NewClient(), acmeLinker, nil) } return ctx } From 6f9d847bc6489f7669997edd0e6db5dcb0b9e2d1 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Mon, 2 May 2022 17:35:35 -0700 Subject: [PATCH 22/40] Fix panic in acme/api tests. --- acme/api/account_test.go | 78 ++++++------ acme/api/eab_test.go | 48 +++---- acme/api/handler.go | 1 - acme/api/handler_test.go | 93 ++++++++------ acme/api/middleware.go | 66 ++++------ acme/api/middleware_test.go | 241 ++++++++++++------------------------ acme/api/order.go | 18 ++- acme/api/order_test.go | 112 +++++++++-------- acme/api/revoke.go | 7 +- acme/api/revoke_test.go | 62 +++++----- 10 files changed, 333 insertions(+), 393 deletions(-) diff --git a/acme/api/account_test.go b/acme/api/account_test.go index 3fbabfe5..18d24ab6 100644 --- a/acme/api/account_test.go +++ b/acme/api/account_test.go @@ -296,10 +296,9 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) { }, "ok": func(t *testing.T) test { acc := &acme.Account{ID: accID} - ctx := context.WithValue(context.Background(), accContextKey, acc) - ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + ctx = acme.NewProvisionerContext(ctx, prov) + ctx = context.WithValue(ctx, accContextKey, acc) return test{ db: &acme.MockDB{ MockGetOrdersByAccountID: func(ctx context.Context, id string) ([]string, error) { @@ -315,9 +314,9 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - // h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")} + ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil) req := httptest.NewRequest("GET", u, nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() GetOrdersByAccountID(w, req) res := w.Result() @@ -363,6 +362,7 @@ func TestHandler_NewAccount(t *testing.T) { var tests = map[string]func(t *testing.T) test{ "fail/no-payload": func(t *testing.T) test { return test{ + db: &acme.MockDB{}, ctx: context.Background(), statusCode: 500, err: acme.NewErrorISE("payload expected in request context"), @@ -371,6 +371,7 @@ func TestHandler_NewAccount(t *testing.T) { "fail/nil-payload": func(t *testing.T) test { ctx := context.WithValue(context.Background(), payloadContextKey, nil) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("payload expected in request context"), @@ -379,6 +380,7 @@ func TestHandler_NewAccount(t *testing.T) { "fail/unmarshal-payload-error": func(t *testing.T) test { ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{}) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "failed to "+ @@ -393,6 +395,7 @@ func TestHandler_NewAccount(t *testing.T) { assert.FatalError(t, err) ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string"), @@ -405,8 +408,9 @@ func TestHandler_NewAccount(t *testing.T) { b, err := json.Marshal(nar) assert.FatalError(t, err) ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), @@ -418,9 +422,10 @@ func TestHandler_NewAccount(t *testing.T) { } b, err := json.Marshal(nar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("jwk expected in request context"), @@ -432,10 +437,11 @@ func TestHandler_NewAccount(t *testing.T) { } b, err := json.Marshal(nar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, jwkContextKey, nil) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("jwk expected in request context"), @@ -454,9 +460,9 @@ func TestHandler_NewAccount(t *testing.T) { prov.RequireEAB = true ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorExternalAccountRequiredType, "no external account binding provided"), @@ -471,7 +477,7 @@ func TestHandler_NewAccount(t *testing.T) { jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwkContextKey, jwk) return test{ db: &acme.MockDB{ @@ -510,9 +516,9 @@ func TestHandler_NewAccount(t *testing.T) { } ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, scepProvisioner) + ctx = acme.NewProvisionerContext(ctx, scepProvisioner) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewError(acme.ErrorServerInternalType, "provisioner in context is not an ACME provisioner"), @@ -551,8 +557,7 @@ func TestHandler_NewAccount(t *testing.T) { prov.RequireEAB = true ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) eak := &acme.ExternalAccountKey{ ID: "eakID", @@ -599,8 +604,7 @@ func TestHandler_NewAccount(t *testing.T) { assert.FatalError(t, err) ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) return test{ db: &acme.MockDB{ MockCreateAccount: func(ctx context.Context, acc *acme.Account) error { @@ -635,11 +639,11 @@ func TestHandler_NewAccount(t *testing.T) { Status: acme.StatusValid, Contact: []string{"foo", "bar"}, } - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, accContextKey, acc) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ + db: &acme.MockDB{}, ctx: ctx, acc: acc, statusCode: 200, @@ -664,8 +668,7 @@ func TestHandler_NewAccount(t *testing.T) { prov.RequireEAB = false ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) return test{ db: &acme.MockDB{ MockCreateAccount: func(ctx context.Context, acc *acme.Account) error { @@ -719,8 +722,7 @@ func TestHandler_NewAccount(t *testing.T) { prov.RequireEAB = true ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ db: &acme.MockDB{ @@ -759,9 +761,9 @@ func TestHandler_NewAccount(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - // h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")} + ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil) req := httptest.NewRequest("GET", "/foo/bar", nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() NewAccount(w, req) res := w.Result() @@ -814,6 +816,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) { var tests = map[string]func(t *testing.T) test{ "fail/no-account": func(t *testing.T) test { return test{ + db: &acme.MockDB{}, ctx: context.Background(), statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), @@ -822,6 +825,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) { "fail/nil-account": func(t *testing.T) test { ctx := context.WithValue(context.Background(), accContextKey, nil) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), @@ -830,6 +834,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) { "fail/no-payload": func(t *testing.T) test { ctx := context.WithValue(context.Background(), accContextKey, &acc) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("payload expected in request context"), @@ -839,6 +844,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) { ctx := context.WithValue(context.Background(), accContextKey, &acc) ctx = context.WithValue(ctx, payloadContextKey, nil) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("payload expected in request context"), @@ -848,6 +854,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) { ctx := context.WithValue(context.Background(), accContextKey, &acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{}) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "failed to unmarshal new-account request payload: unexpected end of JSON input"), @@ -862,6 +869,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) { ctx := context.WithValue(context.Background(), accContextKey, &acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string"), @@ -894,10 +902,9 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) { } b, err := json.Marshal(uar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, &acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ db: &acme.MockDB{ MockUpdateAccount: func(ctx context.Context, upd *acme.Account) error { @@ -914,11 +921,11 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) { uar := &UpdateAccountRequest{} b, err := json.Marshal(uar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, &acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 200, } @@ -929,10 +936,9 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) { } b, err := json.Marshal(uar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, &acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ db: &acme.MockDB{ MockUpdateAccount: func(ctx context.Context, upd *acme.Account) error { @@ -946,11 +952,11 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) { } }, "ok/post-as-get": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, &acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isPostAsGet: true}) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 200, } @@ -959,9 +965,9 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - // h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")} + ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil) req := httptest.NewRequest("GET", "/foo/bar", nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() GetOrUpdateAccount(w, req) res := w.Result() diff --git a/acme/api/eab_test.go b/acme/api/eab_test.go index 1c76618b..ae47a1b9 100644 --- a/acme/api/eab_test.go +++ b/acme/api/eab_test.go @@ -98,8 +98,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { assert.FatalError(t, err) prov := newACMEProv(t) ctx := context.WithValue(context.Background(), jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) return test{ db: &acme.MockDB{}, ctx: ctx, @@ -143,8 +142,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { prov := newACMEProv(t) prov.RequireEAB = true ctx := context.WithValue(context.Background(), jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) createdAt := time.Now() return test{ @@ -198,8 +196,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { } ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, scepProvisioner) + ctx = acme.NewProvisionerContext(ctx, scepProvisioner) return test{ ctx: ctx, err: acme.NewError(acme.ErrorServerInternalType, "could not load ACME provisioner from context: provisioner in context is not an ACME provisioner"), @@ -218,8 +215,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { prov := newACMEProv(t) prov.RequireEAB = true ctx := context.WithValue(context.Background(), jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) return test{ db: &acme.MockDB{}, ctx: ctx, @@ -264,8 +260,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { prov := newACMEProv(t) prov.RequireEAB = true ctx := context.WithValue(context.Background(), jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ db: &acme.MockDB{}, @@ -310,8 +305,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { prov := newACMEProv(t) prov.RequireEAB = true ctx := context.WithValue(context.Background(), jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ db: &acme.MockDB{ @@ -358,8 +352,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { prov := newACMEProv(t) prov.RequireEAB = true ctx := context.WithValue(context.Background(), jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ db: &acme.MockDB{ @@ -408,8 +401,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { prov := newACMEProv(t) prov.RequireEAB = true ctx := context.WithValue(context.Background(), jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ db: &acme.MockDB{ @@ -458,8 +450,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { prov := newACMEProv(t) prov.RequireEAB = true ctx := context.WithValue(context.Background(), jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ db: &acme.MockDB{ @@ -506,8 +497,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { prov := newACMEProv(t) prov.RequireEAB = true ctx := context.WithValue(context.Background(), jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) createdAt := time.Now() boundAt := time.Now().Add(1 * time.Second) @@ -565,8 +555,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { prov := newACMEProv(t) prov.RequireEAB = true ctx := context.WithValue(context.Background(), jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ db: &acme.MockDB{ @@ -623,8 +612,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { prov := newACMEProv(t) prov.RequireEAB = true ctx := context.WithValue(context.Background(), jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ db: &acme.MockDB{ @@ -678,8 +666,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { assert.FatalError(t, err) prov := newACMEProv(t) prov.RequireEAB = true - ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ db: &acme.MockDB{ @@ -734,8 +721,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { prov := newACMEProv(t) prov.RequireEAB = true ctx := context.WithValue(context.Background(), jwkContextKey, nil) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ db: &acme.MockDB{ @@ -762,10 +748,8 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - // h := &Handler{ - // db: tc.db, - // } - got, err := validateExternalAccountBinding(tc.ctx, tc.nar) + ctx := acme.NewDatabaseContext(tc.ctx, tc.db) + got, err := validateExternalAccountBinding(ctx, tc.nar) wantErr := tc.err != nil gotErr := err != nil if wantErr != gotErr { diff --git a/acme/api/handler.go b/acme/api/handler.go index efe2b780..f6d79031 100644 --- a/acme/api/handler.go +++ b/acme/api/handler.go @@ -223,7 +223,6 @@ func (d *Directory) ToLog() (interface{}, error) { func GetDirectory(w http.ResponseWriter, r *http.Request) { ctx := r.Context() acmeProv, err := acmeProvisionerFromContext(ctx) - fmt.Println(acmeProv, err) if err != nil { render.Error(w, err) return diff --git a/acme/api/handler_test.go b/acme/api/handler_test.go index fcc33a87..2ac41228 100644 --- a/acme/api/handler_test.go +++ b/acme/api/handler_test.go @@ -3,6 +3,7 @@ package api import ( "bytes" "context" + "crypto/tls" "crypto/x509" "encoding/json" "encoding/pem" @@ -24,6 +25,29 @@ import ( "go.step.sm/crypto/pemutil" ) +type mockClient struct { + get func(url string) (*http.Response, error) + lookupTxt func(name string) ([]string, error) + tlsDial func(network, addr string, config *tls.Config) (*tls.Conn, error) +} + +func (m *mockClient) Get(url string) (*http.Response, error) { return m.get(url) } +func (m *mockClient) LookupTxt(name string) ([]string, error) { return m.lookupTxt(name) } +func (m *mockClient) TLSDial(network, addr string, config *tls.Config) (*tls.Conn, error) { + return m.tlsDial(network, addr, config) +} + +func mockMustAuthority(t *testing.T, a acme.CertificateAuthority) { + t.Helper() + fn := mustAuthority + t.Cleanup(func() { + mustAuthority = fn + }) + mustAuthority = func(ctx context.Context) acme.CertificateAuthority { + return a + } +} + func TestHandler_GetNonce(t *testing.T) { tests := []struct { name string @@ -52,7 +76,7 @@ func TestHandler_GetNonce(t *testing.T) { } func TestHandler_GetDirectory(t *testing.T) { - linker := NewLinker("ca.smallstep.com", "acme") + linker := acme.NewLinker("ca.smallstep.com", "acme") _ = linker type test struct { ctx context.Context @@ -62,13 +86,10 @@ func TestHandler_GetDirectory(t *testing.T) { } var tests = map[string]func(t *testing.T) test{ "fail/no-provisioner": func(t *testing.T) test { - baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - ctx := context.WithValue(context.Background(), provisionerContextKey, nil) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - ctx: ctx, + ctx: context.Background(), statusCode: 500, - err: acme.NewErrorISE("provisioner in context is not an ACME provisioner"), + err: acme.NewErrorISE("provisioner is not in context"), } }, "fail/different-provisioner": func(t *testing.T) test { @@ -76,9 +97,7 @@ func TestHandler_GetDirectory(t *testing.T) { Type: "SCEP", Name: "test@scep-provisioner.com", } - baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx := acme.NewProvisionerContext(context.Background(), prov) return test{ ctx: ctx, statusCode: 500, @@ -89,8 +108,7 @@ func TestHandler_GetDirectory(t *testing.T) { prov := newProv() provName := url.PathEscape(prov.GetName()) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx := acme.NewProvisionerContext(context.Background(), prov) expDir := Directory{ NewNonce: fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName), NewAccount: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName), @@ -109,8 +127,7 @@ func TestHandler_GetDirectory(t *testing.T) { prov.RequireEAB = true provName := url.PathEscape(prov.GetName()) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx := acme.NewProvisionerContext(context.Background(), prov) expDir := Directory{ NewNonce: fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName), NewAccount: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName), @@ -131,9 +148,9 @@ func TestHandler_GetDirectory(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - // h := &Handler{linker: linker} + ctx := acme.NewLinkerContext(tc.ctx, acme.NewLinker("test.ca.smallstep.com", "acme")) req := httptest.NewRequest("GET", "/foo/bar", nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() GetDirectory(w, req) res := w.Result() @@ -220,7 +237,7 @@ func TestHandler_GetAuthorization(t *testing.T) { } }, "fail/nil-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, nil) return test{ db: &acme.MockDB{}, @@ -286,10 +303,9 @@ func TestHandler_GetAuthorization(t *testing.T) { }, "ok": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ db: &acme.MockDB{ MockGetAuthorization: func(ctx context.Context, id string) (*acme.Authorization, error) { @@ -305,9 +321,9 @@ func TestHandler_GetAuthorization(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - // h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")} + ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil) req := httptest.NewRequest("GET", "/foo/bar", nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() GetAuthorization(w, req) res := w.Result() @@ -448,9 +464,9 @@ func TestHandler_GetCertificate(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - // h := &Handler{db: tc.db} + ctx := acme.NewDatabaseContext(tc.ctx, tc.db) req := httptest.NewRequest("GET", u, nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() GetCertificate(w, req) res := w.Result() @@ -492,7 +508,7 @@ func TestHandler_GetChallenge(t *testing.T) { type test struct { db acme.DB - vco *acme.ValidateChallengeOptions + vc acme.Client ctx context.Context statusCode int ch *acme.Challenge @@ -501,6 +517,7 @@ func TestHandler_GetChallenge(t *testing.T) { var tests = map[string]func(t *testing.T) test{ "fail/no-account": func(t *testing.T) test { return test{ + db: &acme.MockDB{}, ctx: context.Background(), statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), @@ -508,6 +525,7 @@ func TestHandler_GetChallenge(t *testing.T) { }, "fail/nil-account": func(t *testing.T) test { return test{ + db: &acme.MockDB{}, ctx: context.WithValue(context.Background(), accContextKey, nil), statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), @@ -517,6 +535,7 @@ func TestHandler_GetChallenge(t *testing.T) { acc := &acme.Account{ID: "accID"} ctx := context.WithValue(context.Background(), accContextKey, acc) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("payload expected in request context"), @@ -524,10 +543,11 @@ func TestHandler_GetChallenge(t *testing.T) { }, "fail/nil-payload": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, nil) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("payload expected in request context"), @@ -535,7 +555,7 @@ func TestHandler_GetChallenge(t *testing.T) { }, "fail/db.GetChallenge-error": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) @@ -554,7 +574,7 @@ func TestHandler_GetChallenge(t *testing.T) { }, "fail/account-id-mismatch": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) @@ -573,7 +593,7 @@ func TestHandler_GetChallenge(t *testing.T) { }, "fail/no-jwk": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) @@ -592,7 +612,7 @@ func TestHandler_GetChallenge(t *testing.T) { }, "fail/nil-jwk": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) ctx = context.WithValue(ctx, jwkContextKey, nil) @@ -612,7 +632,7 @@ func TestHandler_GetChallenge(t *testing.T) { }, "fail/validate-challenge-error": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) _jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) @@ -640,8 +660,8 @@ func TestHandler_GetChallenge(t *testing.T) { return acme.NewErrorISE("force") }, }, - vco: &acme.ValidateChallengeOptions{ - HTTPGet: func(string) (*http.Response, error) { + vc: &mockClient{ + get: func(string) (*http.Response, error) { return nil, errors.New("force") }, }, @@ -652,14 +672,13 @@ func TestHandler_GetChallenge(t *testing.T) { }, "ok": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) _jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) _pub := _jwk.Public() ctx = context.WithValue(ctx, jwkContextKey, &_pub) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ db: &acme.MockDB{ @@ -691,8 +710,8 @@ func TestHandler_GetChallenge(t *testing.T) { URL: u, Error: acme.NewError(acme.ErrorConnectionType, "force"), }, - vco: &acme.ValidateChallengeOptions{ - HTTPGet: func(string) (*http.Response, error) { + vc: &mockClient{ + get: func(string) (*http.Response, error) { return nil, errors.New("force") }, }, @@ -704,9 +723,9 @@ func TestHandler_GetChallenge(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - // h := &Handler{db: tc.db, linker: NewLinker("dns", "acme"), validateChallengeOptions: tc.vco} + ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil) req := httptest.NewRequest("GET", u, nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() GetChallenge(w, req) res := w.Result() diff --git a/acme/api/middleware.go b/acme/api/middleware.go index 09e88b8d..a254a83b 100644 --- a/acme/api/middleware.go +++ b/acme/api/middleware.go @@ -9,7 +9,6 @@ import ( "net/url" "strings" - "github.com/go-chi/chi" "go.step.sm/crypto/jose" "go.step.sm/crypto/keyutil" @@ -63,7 +62,12 @@ func addDirLink(next nextHTTP) nextHTTP { // application/jose+json. func verifyContentType(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { - p := acme.MustProvisionerFromContext(r.Context()) + p, err := provisionerFromContext(r.Context()) + if err != nil { + render.Error(w, err) + return + } + u := &url.URL{ Path: acme.GetUnescapedPathSuffix(acme.CertificateLinkType, p.GetName(), ""), } @@ -260,32 +264,6 @@ func extractJWK(next nextHTTP) nextHTTP { } } -// lookupProvisioner loads the provisioner associated with the request. -// Responds 404 if the provisioner does not exist. -func lookupProvisioner(next nextHTTP) nextHTTP { - return func(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - nameEscaped := chi.URLParam(r, "provisionerID") - name, err := url.PathUnescape(nameEscaped) - if err != nil { - render.Error(w, acme.WrapErrorISE(err, "error url unescaping provisioner name '%s'", nameEscaped)) - return - } - p, err := mustAuthority(r.Context()).LoadProvisionerByName(name) - if err != nil { - render.Error(w, err) - return - } - acmeProv, ok := p.(*provisioner.ACME) - if !ok { - render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "provisioner must be of type ACME")) - return - } - ctx = context.WithValue(ctx, provisionerContextKey, acme.Provisioner(acmeProv)) - next(w, r.WithContext(ctx)) - } -} - // checkPrerequisites checks if all prerequisites for serving ACME // are met by the CA configuration. func checkPrerequisites(next nextHTTP) nextHTTP { @@ -446,16 +424,12 @@ type ContextKey string const ( // accContextKey account key accContextKey = ContextKey("acc") - // baseURLContextKey baseURL key - baseURLContextKey = ContextKey("baseURL") // jwsContextKey jws key jwsContextKey = ContextKey("jws") // jwkContextKey jwk key jwkContextKey = ContextKey("jwk") // payloadContextKey payload key payloadContextKey = ContextKey("payload") - // provisionerContextKey provisioner key - provisionerContextKey = ContextKey("provisioner") ) // accountFromContext searches the context for an ACME account. Returns the @@ -468,15 +442,6 @@ func accountFromContext(ctx context.Context) (*acme.Account, error) { return val, nil } -// baseURLFromContext returns the baseURL if one is stored in the context. -func baseURLFromContext(ctx context.Context) *url.URL { - val, ok := ctx.Value(baseURLContextKey).(*url.URL) - if !ok || val == nil { - return nil - } - return val -} - // jwkFromContext searches the context for a JWK. Returns the JWK or an error. func jwkFromContext(ctx context.Context) (*jose.JSONWebKey, error) { val, ok := ctx.Value(jwkContextKey).(*jose.JSONWebKey) @@ -495,14 +460,29 @@ func jwsFromContext(ctx context.Context) (*jose.JSONWebSignature, error) { return val, nil } +// provisionerFromContext searches the context for a provisioner. Returns the +// provisioner or an error. +func provisionerFromContext(ctx context.Context) (acme.Provisioner, error) { + p, ok := acme.ProvisionerFromContext(ctx) + if !ok || p == nil { + return nil, acme.NewErrorISE("provisioner expected in request context") + } + return p, nil +} + // acmeProvisionerFromContext searches the context for an ACME provisioner. Returns // pointer to an ACME provisioner or an error. func acmeProvisionerFromContext(ctx context.Context) (*provisioner.ACME, error) { - p, ok := acme.MustProvisionerFromContext(ctx).(*provisioner.ACME) + p, err := provisionerFromContext(ctx) + if err != nil { + return nil, err + } + ap, ok := p.(*provisioner.ACME) if !ok { return nil, acme.NewErrorISE("provisioner in context is not an ACME provisioner") } - return p, nil + + return ap, nil } // payloadFromContext searches the context for a payload. Returns the payload diff --git a/acme/api/middleware_test.go b/acme/api/middleware_test.go index f192e67e..39a696ae 100644 --- a/acme/api/middleware_test.go +++ b/acme/api/middleware_test.go @@ -27,83 +27,18 @@ func testNext(w http.ResponseWriter, r *http.Request) { w.Write(testBody) } -func Test_baseURLFromRequest(t *testing.T) { - tests := []struct { - name string - targetURL string - expectedResult *url.URL - requestPreparer func(*http.Request) - }{ - { - "HTTPS host pass-through failed.", - "https://my.dummy.host", - &url.URL{Scheme: "https", Host: "my.dummy.host"}, - nil, - }, - { - "Port pass-through failed", - "https://host.with.port:8080", - &url.URL{Scheme: "https", Host: "host.with.port:8080"}, - nil, - }, - { - "Explicit host from Request.Host was not used.", - "https://some.target.host:8080", - &url.URL{Scheme: "https", Host: "proxied.host"}, - func(r *http.Request) { - r.Host = "proxied.host" - }, - }, - { - "Missing Request.Host value did not result in empty string result.", - "https://some.host", - nil, - func(r *http.Request) { - r.Host = "" - }, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - request := httptest.NewRequest("GET", tc.targetURL, nil) - if tc.requestPreparer != nil { - tc.requestPreparer(request) - } - result := getBaseURLFromRequest(request) - if result == nil || tc.expectedResult == nil { - assert.Equals(t, result, tc.expectedResult) - } else if result.String() != tc.expectedResult.String() { - t.Errorf("Expected %q, but got %q", tc.expectedResult.String(), result.String()) - } - }) - } -} - -func TestHandler_baseURLFromRequest(t *testing.T) { - // h := &Handler{} - req := httptest.NewRequest("GET", "/foo", nil) - req.Host = "test.ca.smallstep.com:8080" - w := httptest.NewRecorder() - - next := func(w http.ResponseWriter, r *http.Request) { - bu := baseURLFromContext(r.Context()) - if assert.NotNil(t, bu) { - assert.Equals(t, bu.Host, "test.ca.smallstep.com:8080") - assert.Equals(t, bu.Scheme, "https") +func newBaseContext(ctx context.Context, args ...interface{}) context.Context { + for _, a := range args { + switch v := a.(type) { + case acme.DB: + ctx = acme.NewDatabaseContext(ctx, v) + case acme.Linker: + ctx = acme.NewLinkerContext(ctx, v) + case acme.PrerequisitesChecker: + ctx = acme.NewPrerequisitesCheckerContext(ctx, v) } } - - baseURLFromRequest(next)(w, req) - - req = httptest.NewRequest("GET", "/foo", nil) - req.Host = "" - - next = func(w http.ResponseWriter, r *http.Request) { - assert.Equals(t, baseURLFromContext(r.Context()), nil) - } - - baseURLFromRequest(next)(w, req) + return ctx } func TestHandler_addNonce(t *testing.T) { @@ -139,8 +74,8 @@ func TestHandler_addNonce(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - // h := &Handler{db: tc.db} - req := httptest.NewRequest("GET", u, nil) + ctx := newBaseContext(context.Background(), tc.db) + req := httptest.NewRequest("GET", u, nil).WithContext(ctx) w := httptest.NewRecorder() addNonce(testNext)(w, req) res := w.Result() @@ -175,17 +110,15 @@ func TestHandler_addDirLink(t *testing.T) { baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} type test struct { link string - linker Linker statusCode int ctx context.Context err *acme.Error } var tests = map[string]func(t *testing.T) test{ "ok": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx := acme.NewProvisionerContext(context.Background(), prov) + ctx = acme.NewLinkerContext(ctx, acme.NewLinker("test.ca.smallstep.com", "acme")) return test{ - linker: NewLinker("dns", "acme"), ctx: ctx, link: fmt.Sprintf("%s/acme/%s/directory", baseURL.String(), provName), statusCode: 200, @@ -195,7 +128,6 @@ func TestHandler_addDirLink(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - // h := &Handler{linker: tc.linker} req := httptest.NewRequest("GET", "/foo", nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() @@ -231,7 +163,6 @@ func TestHandler_verifyContentType(t *testing.T) { baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} u := fmt.Sprintf("%s/acme/%s/certificate/abc123", baseURL.String(), escProvName) type test struct { - h Handler ctx context.Context contentType string err *acme.Error @@ -241,9 +172,6 @@ func TestHandler_verifyContentType(t *testing.T) { var tests = map[string]func(t *testing.T) test{ "fail/provisioner-not-set": func(t *testing.T) test { return test{ - h: Handler{ - // linker: NewLinker("dns", "acme"), - }, url: u, ctx: context.Background(), contentType: "foo", @@ -253,11 +181,8 @@ func TestHandler_verifyContentType(t *testing.T) { }, "fail/general-bad-content-type": func(t *testing.T) test { return test{ - h: Handler{ - // linker: NewLinker("dns", "acme"), - }, url: u, - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + ctx: acme.NewProvisionerContext(context.Background(), prov), contentType: "foo", statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "expected content-type to be in [application/jose+json], but got foo"), @@ -265,10 +190,7 @@ func TestHandler_verifyContentType(t *testing.T) { }, "fail/certificate-bad-content-type": func(t *testing.T) test { return test{ - h: Handler{ - // linker: NewLinker("dns", "acme"), - }, - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + ctx: acme.NewProvisionerContext(context.Background(), prov), contentType: "foo", statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "expected content-type to be in [application/jose+json application/pkix-cert application/pkcs7-mime], but got foo"), @@ -276,40 +198,28 @@ func TestHandler_verifyContentType(t *testing.T) { }, "ok": func(t *testing.T) test { return test{ - h: Handler{ - // linker: NewLinker("dns", "acme"), - }, - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + ctx: acme.NewProvisionerContext(context.Background(), prov), contentType: "application/jose+json", statusCode: 200, } }, "ok/certificate/pkix-cert": func(t *testing.T) test { return test{ - h: Handler{ - // linker: NewLinker("dns", "acme"), - }, - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + ctx: acme.NewProvisionerContext(context.Background(), prov), contentType: "application/pkix-cert", statusCode: 200, } }, "ok/certificate/jose+json": func(t *testing.T) test { return test{ - h: Handler{ - // linker: NewLinker("dns", "acme"), - }, - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + ctx: acme.NewProvisionerContext(context.Background(), prov), contentType: "application/jose+json", statusCode: 200, } }, "ok/certificate/pkcs7-mime": func(t *testing.T) test { return test{ - h: Handler{ - // linker: NewLinker("dns", "acme"), - }, - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + ctx: acme.NewProvisionerContext(context.Background(), prov), contentType: "application/pkcs7-mime", statusCode: 200, } @@ -733,7 +643,7 @@ func TestHandler_lookupJWK(t *testing.T) { parsedJWS, err := jose.ParseJWS(raw) assert.FatalError(t, err) type test struct { - linker Linker + linker acme.Linker db acme.DB ctx context.Context next func(http.ResponseWriter, *http.Request) @@ -743,15 +653,19 @@ func TestHandler_lookupJWK(t *testing.T) { var tests = map[string]func(t *testing.T) test{ "fail/no-jws": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + db: &acme.MockDB{}, + linker: acme.NewLinker("test.ca.smallstep.com", "acme"), + ctx: acme.NewProvisionerContext(context.Background(), prov), statusCode: 500, err: acme.NewErrorISE("jws expected in request context"), } }, "fail/nil-jws": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, nil) return test{ + db: &acme.MockDB{}, + linker: acme.NewLinker("test.ca.smallstep.com", "acme"), ctx: ctx, statusCode: 500, err: acme.NewErrorISE("jws expected in request context"), @@ -765,11 +679,11 @@ func TestHandler_lookupJWK(t *testing.T) { assert.FatalError(t, err) _jws, err := _signer.Sign([]byte("baz")) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, _jws) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - linker: NewLinker("dns", "acme"), + db: &acme.MockDB{}, + linker: acme.NewLinker("test.ca.smallstep.com", "acme"), ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "kid does not have required prefix; expected %s, but got ", prefix), @@ -789,22 +703,21 @@ func TestHandler_lookupJWK(t *testing.T) { assert.FatalError(t, err) _parsed, err := jose.ParseJWS(_raw) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, _parsed) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - linker: NewLinker("dns", "acme"), + db: &acme.MockDB{}, + linker: acme.NewLinker("test.ca.smallstep.com", "acme"), ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "kid does not have required prefix; expected %s, but got foo", prefix), } }, "fail/account-not-found": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - linker: NewLinker("dns", "acme"), + linker: acme.NewLinker("test.ca.smallstep.com", "acme"), db: &acme.MockDB{ MockGetAccount: func(ctx context.Context, accID string) (*acme.Account, error) { assert.Equals(t, accID, accID) @@ -817,11 +730,10 @@ func TestHandler_lookupJWK(t *testing.T) { } }, "fail/GetAccount-error": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - linker: NewLinker("dns", "acme"), + linker: acme.NewLinker("test.ca.smallstep.com", "acme"), db: &acme.MockDB{ MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) { assert.Equals(t, id, accID) @@ -835,11 +747,10 @@ func TestHandler_lookupJWK(t *testing.T) { }, "fail/account-not-valid": func(t *testing.T) test { acc := &acme.Account{Status: "deactivated"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - linker: NewLinker("dns", "acme"), + linker: acme.NewLinker("test.ca.smallstep.com", "acme"), db: &acme.MockDB{ MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) { assert.Equals(t, id, accID) @@ -853,11 +764,10 @@ func TestHandler_lookupJWK(t *testing.T) { }, "ok": func(t *testing.T) test { acc := &acme.Account{Status: "valid", Key: jwk} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - linker: NewLinker("dns", "acme"), + linker: acme.NewLinker("test.ca.smallstep.com", "acme"), db: &acme.MockDB{ MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) { assert.Equals(t, id, accID) @@ -881,9 +791,9 @@ func TestHandler_lookupJWK(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - // h := &Handler{db: tc.db, linker: tc.linker} + ctx := newBaseContext(tc.ctx, tc.db, tc.linker) req := httptest.NewRequest("GET", u, nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() lookupJWK(tc.next)(w, req) res := w.Result() @@ -945,15 +855,17 @@ func TestHandler_extractJWK(t *testing.T) { var tests = map[string]func(t *testing.T) test{ "fail/no-jws": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + db: &acme.MockDB{}, + ctx: acme.NewProvisionerContext(context.Background(), prov), statusCode: 500, err: acme.NewErrorISE("jws expected in request context"), } }, "fail/nil-jws": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, nil) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("jws expected in request context"), @@ -969,9 +881,10 @@ func TestHandler_extractJWK(t *testing.T) { }, }, } - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, _jws) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "jwk expected in protected header"), @@ -987,16 +900,17 @@ func TestHandler_extractJWK(t *testing.T) { }, }, } - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, _jws) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header"), } }, "fail/GetAccountByKey-error": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ ctx: ctx, @@ -1012,7 +926,7 @@ func TestHandler_extractJWK(t *testing.T) { }, "fail/account-not-valid": func(t *testing.T) test { acc := &acme.Account{Status: "deactivated"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ ctx: ctx, @@ -1028,7 +942,7 @@ func TestHandler_extractJWK(t *testing.T) { }, "ok": func(t *testing.T) test { acc := &acme.Account{Status: "valid"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ ctx: ctx, @@ -1051,7 +965,7 @@ func TestHandler_extractJWK(t *testing.T) { } }, "ok/no-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ ctx: ctx, @@ -1077,9 +991,9 @@ func TestHandler_extractJWK(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - // h := &Handler{db: tc.db} + ctx := newBaseContext(tc.ctx, tc.db) req := httptest.NewRequest("GET", u, nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() extractJWK(tc.next)(w, req) res := w.Result() @@ -1118,6 +1032,7 @@ func TestHandler_validateJWS(t *testing.T) { var tests = map[string]func(t *testing.T) test{ "fail/no-jws": func(t *testing.T) test { return test{ + db: &acme.MockDB{}, ctx: context.Background(), statusCode: 500, err: acme.NewErrorISE("jws expected in request context"), @@ -1125,6 +1040,7 @@ func TestHandler_validateJWS(t *testing.T) { }, "fail/nil-jws": func(t *testing.T) test { return test{ + db: &acme.MockDB{}, ctx: context.WithValue(context.Background(), jwsContextKey, nil), statusCode: 500, err: acme.NewErrorISE("jws expected in request context"), @@ -1132,6 +1048,7 @@ func TestHandler_validateJWS(t *testing.T) { }, "fail/no-signature": func(t *testing.T) test { return test{ + db: &acme.MockDB{}, ctx: context.WithValue(context.Background(), jwsContextKey, &jose.JSONWebSignature{}), statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "request body does not contain a signature"), @@ -1145,6 +1062,7 @@ func TestHandler_validateJWS(t *testing.T) { }, } return test{ + db: &acme.MockDB{}, ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "request body contains more than one signature"), @@ -1157,6 +1075,7 @@ func TestHandler_validateJWS(t *testing.T) { }, } return test{ + db: &acme.MockDB{}, ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "unprotected header must not be used"), @@ -1169,6 +1088,7 @@ func TestHandler_validateJWS(t *testing.T) { }, } return test{ + db: &acme.MockDB{}, ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 400, err: acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: none"), @@ -1181,6 +1101,7 @@ func TestHandler_validateJWS(t *testing.T) { }, } return test{ + db: &acme.MockDB{}, ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 400, err: acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: %s", jose.HS256), @@ -1444,9 +1365,9 @@ func TestHandler_validateJWS(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - // h := &Handler{db: tc.db} + ctx := newBaseContext(tc.ctx, tc.db) req := httptest.NewRequest("GET", u, nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() validateJWS(tc.next)(w, req) res := w.Result() @@ -1542,7 +1463,7 @@ func TestHandler_extractOrLookupJWK(t *testing.T) { u := "https://ca.smallstep.com/acme/account" type test struct { db acme.DB - linker Linker + linker acme.Linker statusCode int ctx context.Context err *acme.Error @@ -1570,7 +1491,7 @@ func TestHandler_extractOrLookupJWK(t *testing.T) { parsedJWS, err := jose.ParseJWS(raw) assert.FatalError(t, err) return test{ - linker: NewLinker("dns", "acme"), + linker: acme.NewLinker("dns", "acme"), db: &acme.MockDB{ MockGetAccountByKeyID: func(ctx context.Context, kid string) (*acme.Account, error) { assert.Equals(t, kid, pub.KeyID) @@ -1606,11 +1527,10 @@ func TestHandler_extractOrLookupJWK(t *testing.T) { parsedJWS, err := jose.ParseJWS(raw) assert.FatalError(t, err) acc := &acme.Account{ID: "accID", Key: jwk, Status: "valid"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ - linker: NewLinker("test.ca.smallstep.com", "acme"), + linker: acme.NewLinker("test.ca.smallstep.com", "acme"), db: &acme.MockDB{ MockGetAccount: func(ctx context.Context, accID string) (*acme.Account, error) { assert.Equals(t, accID, acc.ID) @@ -1628,9 +1548,9 @@ func TestHandler_extractOrLookupJWK(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - // h := &Handler{db: tc.db, linker: tc.linker} + ctx := newBaseContext(tc.ctx, tc.db, tc.linker) req := httptest.NewRequest("GET", u, nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() extractOrLookupJWK(tc.next)(w, req) res := w.Result() @@ -1664,7 +1584,7 @@ func TestHandler_checkPrerequisites(t *testing.T) { u := fmt.Sprintf("%s/acme/%s/account/1234", baseURL, provName) type test struct { - linker Linker + linker acme.Linker ctx context.Context prerequisitesChecker func(context.Context) (bool, error) next func(http.ResponseWriter, *http.Request) @@ -1673,10 +1593,9 @@ func TestHandler_checkPrerequisites(t *testing.T) { } var tests = map[string]func(t *testing.T) test{ "fail/error": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx := acme.NewProvisionerContext(context.Background(), prov) return test{ - linker: NewLinker("dns", "acme"), + linker: acme.NewLinker("dns", "acme"), ctx: ctx, prerequisitesChecker: func(context.Context) (bool, error) { return false, errors.New("force") }, next: func(w http.ResponseWriter, r *http.Request) { @@ -1687,10 +1606,9 @@ func TestHandler_checkPrerequisites(t *testing.T) { } }, "fail/prerequisites-nok": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx := acme.NewProvisionerContext(context.Background(), prov) return test{ - linker: NewLinker("dns", "acme"), + linker: acme.NewLinker("dns", "acme"), ctx: ctx, prerequisitesChecker: func(context.Context) (bool, error) { return false, nil }, next: func(w http.ResponseWriter, r *http.Request) { @@ -1701,10 +1619,9 @@ func TestHandler_checkPrerequisites(t *testing.T) { } }, "ok": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx := acme.NewProvisionerContext(context.Background(), prov) return test{ - linker: NewLinker("dns", "acme"), + linker: acme.NewLinker("dns", "acme"), ctx: ctx, prerequisitesChecker: func(context.Context) (bool, error) { return true, nil }, next: func(w http.ResponseWriter, r *http.Request) { diff --git a/acme/api/order.go b/acme/api/order.go index 2b9f912e..08718977 100644 --- a/acme/api/order.go +++ b/acme/api/order.go @@ -72,13 +72,17 @@ func NewOrder(w http.ResponseWriter, r *http.Request) { ctx := r.Context() db := acme.MustDatabaseFromContext(ctx) linker := acme.MustLinkerFromContext(ctx) - prov := acme.MustProvisionerFromContext(ctx) acc, err := accountFromContext(ctx) if err != nil { render.Error(w, err) return } + prov, err := provisionerFromContext(ctx) + if err != nil { + render.Error(w, err) + return + } payload, err := payloadFromContext(ctx) if err != nil { render.Error(w, err) @@ -189,13 +193,17 @@ func GetOrder(w http.ResponseWriter, r *http.Request) { ctx := r.Context() db := acme.MustDatabaseFromContext(ctx) linker := acme.MustLinkerFromContext(ctx) - prov := acme.MustProvisionerFromContext(ctx) acc, err := accountFromContext(ctx) if err != nil { render.Error(w, err) return } + prov, err := provisionerFromContext(ctx) + if err != nil { + render.Error(w, err) + return + } o, err := db.GetOrder(ctx, chi.URLParam(r, "ordID")) if err != nil { @@ -228,13 +236,17 @@ func FinalizeOrder(w http.ResponseWriter, r *http.Request) { ctx := r.Context() db := acme.MustDatabaseFromContext(ctx) linker := acme.MustLinkerFromContext(ctx) - prov := acme.MustProvisionerFromContext(ctx) acc, err := accountFromContext(ctx) if err != nil { render.Error(w, err) return } + prov, err := provisionerFromContext(ctx) + if err != nil { + render.Error(w, err) + return + } payload, err := payloadFromContext(ctx) if err != nil { render.Error(w, err) diff --git a/acme/api/order_test.go b/acme/api/order_test.go index f0a2d1d4..0ab76778 100644 --- a/acme/api/order_test.go +++ b/acme/api/order_test.go @@ -276,15 +276,17 @@ func TestHandler_GetOrder(t *testing.T) { var tests = map[string]func(t *testing.T) test{ "fail/no-account": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + db: &acme.MockDB{}, + ctx: acme.NewProvisionerContext(context.Background(), prov), statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/nil-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, nil) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), @@ -294,6 +296,7 @@ func TestHandler_GetOrder(t *testing.T) { acc := &acme.Account{ID: "accountID"} ctx := context.WithValue(context.Background(), accContextKey, acc) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("provisioner does not exist"), @@ -301,9 +304,10 @@ func TestHandler_GetOrder(t *testing.T) { }, "fail/nil-provisioner": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, nil) + ctx := acme.NewProvisionerContext(context.Background(), nil) ctx = context.WithValue(ctx, accContextKey, acc) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("provisioner does not exist"), @@ -311,7 +315,7 @@ func TestHandler_GetOrder(t *testing.T) { }, "fail/db.GetOrder-error": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ @@ -325,7 +329,7 @@ func TestHandler_GetOrder(t *testing.T) { }, "fail/account-id-mismatch": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ @@ -341,7 +345,7 @@ func TestHandler_GetOrder(t *testing.T) { }, "fail/provisioner-id-mismatch": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ @@ -357,7 +361,7 @@ func TestHandler_GetOrder(t *testing.T) { }, "fail/order-update-error": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ @@ -381,10 +385,9 @@ func TestHandler_GetOrder(t *testing.T) { }, "ok": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ db: &acme.MockDB{ MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) { @@ -421,9 +424,9 @@ func TestHandler_GetOrder(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - // h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db} + ctx := newBaseContext(tc.ctx, tc.db, acme.NewLinker("test.ca.smallstep.com", "acme")) req := httptest.NewRequest("GET", u, nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() GetOrder(w, req) res := w.Result() @@ -636,8 +639,8 @@ func TestHandler_newAuthorization(t *testing.T) { for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) - // h := &Handler{db: tc.db} - if err := newAuthorization(context.Background(), tc.az); err != nil { + ctx := newBaseContext(context.Background(), tc.db) + if err := newAuthorization(ctx, tc.az); err != nil { if assert.NotNil(t, tc.err) { switch k := err.(type) { case *acme.Error: @@ -677,15 +680,17 @@ func TestHandler_NewOrder(t *testing.T) { var tests = map[string]func(t *testing.T) test{ "fail/no-account": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + db: &acme.MockDB{}, + ctx: acme.NewProvisionerContext(context.Background(), prov), statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/nil-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, nil) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), @@ -695,6 +700,7 @@ func TestHandler_NewOrder(t *testing.T) { acc := &acme.Account{ID: "accountID"} ctx := context.WithValue(context.Background(), accContextKey, acc) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("provisioner does not exist"), @@ -702,9 +708,10 @@ func TestHandler_NewOrder(t *testing.T) { }, "fail/nil-provisioner": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, nil) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("provisioner does not exist"), @@ -713,8 +720,9 @@ func TestHandler_NewOrder(t *testing.T) { "fail/no-payload": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} ctx := context.WithValue(context.Background(), accContextKey, acc) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("payload does not exist"), @@ -722,10 +730,11 @@ func TestHandler_NewOrder(t *testing.T) { }, "fail/nil-payload": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, nil) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("paylod does not exist"), @@ -733,10 +742,11 @@ func TestHandler_NewOrder(t *testing.T) { }, "fail/unmarshal-payload-error": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{}) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "failed to unmarshal new-order request payload: unexpected end of JSON input"), @@ -747,10 +757,11 @@ func TestHandler_NewOrder(t *testing.T) { fr := &NewOrderRequest{} b, err := json.Marshal(fr) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "identifiers list cannot be empty"), @@ -765,7 +776,7 @@ func TestHandler_NewOrder(t *testing.T) { } b, err := json.Marshal(fr) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) return test{ @@ -793,7 +804,7 @@ func TestHandler_NewOrder(t *testing.T) { } b, err := json.Marshal(fr) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) var ( @@ -863,10 +874,9 @@ func TestHandler_NewOrder(t *testing.T) { } b, err := json.Marshal(nor) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) var ( ch1, ch2, ch3, ch4 **acme.Challenge az1ID, az2ID *string @@ -978,10 +988,9 @@ func TestHandler_NewOrder(t *testing.T) { } b, err := json.Marshal(nor) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) var ( ch1, ch2, ch3 **acme.Challenge az1ID *string @@ -1070,10 +1079,9 @@ func TestHandler_NewOrder(t *testing.T) { } b, err := json.Marshal(nor) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) var ( ch1, ch2, ch3 **acme.Challenge az1ID *string @@ -1161,10 +1169,9 @@ func TestHandler_NewOrder(t *testing.T) { } b, err := json.Marshal(nor) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) var ( ch1, ch2, ch3 **acme.Challenge az1ID *string @@ -1253,10 +1260,9 @@ func TestHandler_NewOrder(t *testing.T) { } b, err := json.Marshal(nor) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) var ( ch1, ch2, ch3 **acme.Challenge az1ID *string @@ -1334,9 +1340,9 @@ func TestHandler_NewOrder(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - // h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db} + ctx := newBaseContext(tc.ctx, tc.db, acme.NewLinker("test.ca.smallstep.com", "acme")) req := httptest.NewRequest("GET", u, nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() NewOrder(w, req) res := w.Result() @@ -1371,6 +1377,7 @@ func TestHandler_NewOrder(t *testing.T) { } func TestHandler_FinalizeOrder(t *testing.T) { + mockMustAuthority(t, &mockCA{}) prov := newProv() escProvName := url.PathEscape(prov.GetName()) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} @@ -1429,15 +1436,17 @@ func TestHandler_FinalizeOrder(t *testing.T) { var tests = map[string]func(t *testing.T) test{ "fail/no-account": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + db: &acme.MockDB{}, + ctx: acme.NewProvisionerContext(context.Background(), prov), statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/nil-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, nil) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), @@ -1447,6 +1456,7 @@ func TestHandler_FinalizeOrder(t *testing.T) { acc := &acme.Account{ID: "accountID"} ctx := context.WithValue(context.Background(), accContextKey, acc) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("provisioner does not exist"), @@ -1454,9 +1464,10 @@ func TestHandler_FinalizeOrder(t *testing.T) { }, "fail/nil-provisioner": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, nil) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("provisioner does not exist"), @@ -1465,8 +1476,9 @@ func TestHandler_FinalizeOrder(t *testing.T) { "fail/no-payload": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} ctx := context.WithValue(context.Background(), accContextKey, acc) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("payload does not exist"), @@ -1474,10 +1486,11 @@ func TestHandler_FinalizeOrder(t *testing.T) { }, "fail/nil-payload": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, nil) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("paylod does not exist"), @@ -1485,10 +1498,11 @@ func TestHandler_FinalizeOrder(t *testing.T) { }, "fail/unmarshal-payload-error": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{}) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "failed to unmarshal finalize-order request payload: unexpected end of JSON input"), @@ -1499,10 +1513,11 @@ func TestHandler_FinalizeOrder(t *testing.T) { fr := &FinalizeRequest{} b, err := json.Marshal(fr) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "unable to parse csr: asn1: syntax error: sequence truncated"), @@ -1511,7 +1526,7 @@ func TestHandler_FinalizeOrder(t *testing.T) { "fail/db.GetOrder-error": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) @@ -1526,7 +1541,7 @@ func TestHandler_FinalizeOrder(t *testing.T) { }, "fail/account-id-mismatch": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) @@ -1543,7 +1558,7 @@ func TestHandler_FinalizeOrder(t *testing.T) { }, "fail/provisioner-id-mismatch": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) @@ -1560,7 +1575,7 @@ func TestHandler_FinalizeOrder(t *testing.T) { }, "fail/order-finalize-error": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) @@ -1585,10 +1600,9 @@ func TestHandler_FinalizeOrder(t *testing.T) { }, "ok": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ db: &acme.MockDB{ @@ -1624,9 +1638,9 @@ func TestHandler_FinalizeOrder(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - // h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db} + ctx := newBaseContext(tc.ctx, tc.db, acme.NewLinker("test.ca.smallstep.com", "acme")) req := httptest.NewRequest("GET", u, nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() FinalizeOrder(w, req) res := w.Result() diff --git a/acme/api/revoke.go b/acme/api/revoke.go index 584ed27e..a8b98f3f 100644 --- a/acme/api/revoke.go +++ b/acme/api/revoke.go @@ -30,7 +30,6 @@ func RevokeCert(w http.ResponseWriter, r *http.Request) { ctx := r.Context() db := acme.MustDatabaseFromContext(ctx) linker := acme.MustLinkerFromContext(ctx) - prov := acme.MustProvisionerFromContext(ctx) jws, err := jwsFromContext(ctx) if err != nil { @@ -38,6 +37,12 @@ func RevokeCert(w http.ResponseWriter, r *http.Request) { return } + prov, err := provisionerFromContext(ctx) + if err != nil { + render.Error(w, err) + return + } + payload, err := payloadFromContext(ctx) if err != nil { render.Error(w, err) diff --git a/acme/api/revoke_test.go b/acme/api/revoke_test.go index 3a0ba70d..c746c11b 100644 --- a/acme/api/revoke_test.go +++ b/acme/api/revoke_test.go @@ -511,6 +511,7 @@ func TestHandler_RevokeCert(t *testing.T) { "fail/no-jws": func(t *testing.T) test { ctx := context.Background() return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("jws expected in request context"), @@ -519,6 +520,7 @@ func TestHandler_RevokeCert(t *testing.T) { "fail/nil-jws": func(t *testing.T) test { ctx := context.WithValue(context.Background(), jwsContextKey, nil) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("jws expected in request context"), @@ -527,6 +529,7 @@ func TestHandler_RevokeCert(t *testing.T) { "fail/no-provisioner": func(t *testing.T) test { ctx := context.WithValue(context.Background(), jwsContextKey, jws) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("provisioner does not exist"), @@ -534,8 +537,9 @@ func TestHandler_RevokeCert(t *testing.T) { }, "fail/nil-provisioner": func(t *testing.T) test { ctx := context.WithValue(context.Background(), jwsContextKey, jws) - ctx = context.WithValue(ctx, provisionerContextKey, nil) + ctx = acme.NewProvisionerContext(ctx, nil) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("provisioner does not exist"), @@ -543,8 +547,9 @@ func TestHandler_RevokeCert(t *testing.T) { }, "fail/no-payload": func(t *testing.T) test { ctx := context.WithValue(context.Background(), jwsContextKey, jws) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("payload does not exist"), @@ -552,9 +557,10 @@ func TestHandler_RevokeCert(t *testing.T) { }, "fail/nil-payload": func(t *testing.T) test { ctx := context.WithValue(context.Background(), jwsContextKey, jws) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, payloadContextKey, nil) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("payload does not exist"), @@ -563,9 +569,10 @@ func TestHandler_RevokeCert(t *testing.T) { "fail/unmarshal-payload": func(t *testing.T) test { malformedPayload := []byte(`{"payload":malformed?}`) ctx := context.WithValue(context.Background(), jwsContextKey, jws) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: malformedPayload}) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("error unmarshaling payload"), @@ -577,10 +584,11 @@ func TestHandler_RevokeCert(t *testing.T) { } wronglyEncodedPayloadBytes, err := json.Marshal(wrongPayload) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: wronglyEncodedPayloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: &acme.Error{ @@ -596,10 +604,11 @@ func TestHandler_RevokeCert(t *testing.T) { } emptyPayloadBytes, err := json.Marshal(emptyPayload) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: emptyPayloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: &acme.Error{ @@ -610,7 +619,7 @@ func TestHandler_RevokeCert(t *testing.T) { } }, "fail/db.GetCertificateBySerial": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) db := &acme.MockDB{ @@ -628,7 +637,7 @@ func TestHandler_RevokeCert(t *testing.T) { "fail/different-certificate-contents": func(t *testing.T) test { aDifferentCert, _, err := generateCertKeyPair() assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) db := &acme.MockDB{ @@ -647,7 +656,7 @@ func TestHandler_RevokeCert(t *testing.T) { } }, "fail/no-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) db := &acme.MockDB{ @@ -666,7 +675,7 @@ func TestHandler_RevokeCert(t *testing.T) { } }, "fail/nil-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) ctx = context.WithValue(ctx, accContextKey, nil) @@ -687,11 +696,10 @@ func TestHandler_RevokeCert(t *testing.T) { }, "fail/account-not-valid": func(t *testing.T) test { acc := &acme.Account{ID: "accountID", Status: acme.StatusInvalid} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) db := &acme.MockDB{ MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { @@ -717,11 +725,10 @@ func TestHandler_RevokeCert(t *testing.T) { }, "fail/account-not-authorized": func(t *testing.T) test { acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) db := &acme.MockDB{ MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { @@ -771,10 +778,9 @@ func TestHandler_RevokeCert(t *testing.T) { assert.FatalError(t, err) unauthorizedPayloadBytes, err := json.Marshal(jwsPayload) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: unauthorizedPayloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) db := &acme.MockDB{ MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { @@ -798,11 +804,10 @@ func TestHandler_RevokeCert(t *testing.T) { }, "fail/certificate-revoked-check-fails": func(t *testing.T) test { acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) db := &acme.MockDB{ MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { @@ -832,7 +837,7 @@ func TestHandler_RevokeCert(t *testing.T) { }, "fail/certificate-already-revoked": func(t *testing.T) test { acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) @@ -870,7 +875,7 @@ func TestHandler_RevokeCert(t *testing.T) { invalidReasonCodePayloadBytes, err := json.Marshal(invalidReasonPayload) assert.FatalError(t, err) acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: invalidReasonCodePayloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) @@ -908,7 +913,7 @@ func TestHandler_RevokeCert(t *testing.T) { }, } acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} - ctx := context.WithValue(context.Background(), provisionerContextKey, mockACMEProv) + ctx := acme.NewProvisionerContext(context.Background(), mockACMEProv) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) @@ -940,7 +945,7 @@ func TestHandler_RevokeCert(t *testing.T) { }, "fail/ca.Revoke": func(t *testing.T) test { acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) @@ -972,7 +977,7 @@ func TestHandler_RevokeCert(t *testing.T) { }, "fail/ca.Revoke-already-revoked": func(t *testing.T) test { acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) @@ -1003,11 +1008,10 @@ func TestHandler_RevokeCert(t *testing.T) { }, "ok/using-account-key": func(t *testing.T) test { acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) db := &acme.MockDB{ MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { @@ -1031,10 +1035,9 @@ func TestHandler_RevokeCert(t *testing.T) { assert.FatalError(t, err) jws, err := jose.ParseJWS(string(jwsBytes)) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) db := &acme.MockDB{ MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { @@ -1057,9 +1060,10 @@ func TestHandler_RevokeCert(t *testing.T) { for name, setup := range tests { tc := setup(t) t.Run(name, func(t *testing.T) { - // h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db, ca: tc.ca} + ctx := newBaseContext(tc.ctx, tc.db, acme.NewLinker("test.ca.smallstep.com", "acme")) + mockMustAuthority(t, tc.ca) req := httptest.NewRequest("POST", revokeURL, nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() RevokeCert(w, req) res := w.Result() From ba499eeb2ad26f66d5bf0b45f0af12478dec4573 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Mon, 2 May 2022 17:40:10 -0700 Subject: [PATCH 23/40] Fix acme/api tests. --- acme/api/middleware_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/acme/api/middleware_test.go b/acme/api/middleware_test.go index 39a696ae..193f5347 100644 --- a/acme/api/middleware_test.go +++ b/acme/api/middleware_test.go @@ -1634,9 +1634,9 @@ func TestHandler_checkPrerequisites(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - // h := &Handler{db: nil, linker: tc.linker, prerequisitesChecker: tc.prerequisitesChecker} + ctx := acme.NewPrerequisitesCheckerContext(tc.ctx, tc.prerequisitesChecker) req := httptest.NewRequest("GET", u, nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() checkPrerequisites(tc.next)(w, req) res := w.Result() From 2ab7dc6f9d2278fabc926404b0f9781f9284d323 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Mon, 2 May 2022 18:09:26 -0700 Subject: [PATCH 24/40] Fix acme tests. --- acme/challenge_test.go | 222 ++++++++++++++++++++++------------------- acme/linker.go | 8 +- acme/linker_test.go | 72 ++++++++----- 3 files changed, 171 insertions(+), 131 deletions(-) diff --git a/acme/challenge_test.go b/acme/challenge_test.go index c05b25e7..e1b6816a 100644 --- a/acme/challenge_test.go +++ b/acme/challenge_test.go @@ -29,6 +29,18 @@ import ( "github.com/smallstep/assert" ) +type mockClient struct { + get func(url string) (*http.Response, error) + lookupTxt func(name string) ([]string, error) + tlsDial func(network, addr string, config *tls.Config) (*tls.Conn, error) +} + +func (m *mockClient) Get(url string) (*http.Response, error) { return m.get(url) } +func (m *mockClient) LookupTxt(name string) ([]string, error) { return m.lookupTxt(name) } +func (m *mockClient) TLSDial(network, addr string, config *tls.Config) (*tls.Conn, error) { + return m.tlsDial(network, addr, config) +} + func Test_storeError(t *testing.T) { type test struct { ch *Challenge @@ -229,7 +241,7 @@ func TestKeyAuthorization(t *testing.T) { func TestChallenge_Validate(t *testing.T) { type test struct { ch *Challenge - vo *ValidateChallengeOptions + vc Client jwk *jose.JSONWebKey db DB srv *httptest.Server @@ -273,8 +285,8 @@ func TestChallenge_Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - HTTPGet: func(url string) (*http.Response, error) { + vc: &mockClient{ + get: func(url string) (*http.Response, error) { return nil, errors.New("force") }, }, @@ -309,8 +321,8 @@ func TestChallenge_Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - HTTPGet: func(url string) (*http.Response, error) { + vc: &mockClient{ + get: func(url string) (*http.Response, error) { return nil, errors.New("force") }, }, @@ -344,8 +356,8 @@ func TestChallenge_Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - LookupTxt: func(url string) ([]string, error) { + vc: &mockClient{ + lookupTxt: func(url string) ([]string, error) { return nil, errors.New("force") }, }, @@ -381,8 +393,8 @@ func TestChallenge_Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - LookupTxt: func(url string) ([]string, error) { + vc: &mockClient{ + lookupTxt: func(url string) ([]string, error) { return nil, errors.New("force") }, }, @@ -416,8 +428,8 @@ func TestChallenge_Validate(t *testing.T) { } return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { + vc: &mockClient{ + tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { return nil, errors.New("force") }, }, @@ -466,8 +478,8 @@ func TestChallenge_Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -493,7 +505,8 @@ func TestChallenge_Validate(t *testing.T) { defer tc.srv.Close() } - if err := tc.ch.Validate(context.Background(), tc.db, tc.jwk, tc.vo); err != nil { + ctx := NewClientContext(context.Background(), tc.vc) + if err := tc.ch.Validate(ctx, tc.db, tc.jwk); err != nil { if assert.NotNil(t, tc.err) { switch k := err.(type) { case *Error: @@ -524,7 +537,7 @@ func (errReader) Close() error { func TestHTTP01Validate(t *testing.T) { type test struct { - vo *ValidateChallengeOptions + vc Client ch *Challenge jwk *jose.JSONWebKey db DB @@ -541,8 +554,8 @@ func TestHTTP01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - HTTPGet: func(url string) (*http.Response, error) { + vc: &mockClient{ + get: func(url string) (*http.Response, error) { return nil, errors.New("force") }, }, @@ -575,8 +588,8 @@ func TestHTTP01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - HTTPGet: func(url string) (*http.Response, error) { + vc: &mockClient{ + get: func(url string) (*http.Response, error) { return nil, errors.New("force") }, }, @@ -608,8 +621,8 @@ func TestHTTP01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - HTTPGet: func(url string) (*http.Response, error) { + vc: &mockClient{ + get: func(url string) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusBadRequest, Body: errReader(0), @@ -645,8 +658,8 @@ func TestHTTP01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - HTTPGet: func(url string) (*http.Response, error) { + vc: &mockClient{ + get: func(url string) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusBadRequest, Body: errReader(0), @@ -681,8 +694,8 @@ func TestHTTP01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - HTTPGet: func(url string) (*http.Response, error) { + vc: &mockClient{ + get: func(url string) (*http.Response, error) { return &http.Response{ Body: errReader(0), }, nil @@ -704,8 +717,8 @@ func TestHTTP01Validate(t *testing.T) { jwk.Key = "foo" return test{ ch: ch, - vo: &ValidateChallengeOptions{ - HTTPGet: func(url string) (*http.Response, error) { + vc: &mockClient{ + get: func(url string) (*http.Response, error) { return &http.Response{ Body: io.NopCloser(bytes.NewBufferString("foo")), }, nil @@ -730,8 +743,8 @@ func TestHTTP01Validate(t *testing.T) { assert.FatalError(t, err) return test{ ch: ch, - vo: &ValidateChallengeOptions{ - HTTPGet: func(url string) (*http.Response, error) { + vc: &mockClient{ + get: func(url string) (*http.Response, error) { return &http.Response{ Body: io.NopCloser(bytes.NewBufferString("foo")), }, nil @@ -772,8 +785,8 @@ func TestHTTP01Validate(t *testing.T) { assert.FatalError(t, err) return test{ ch: ch, - vo: &ValidateChallengeOptions{ - HTTPGet: func(url string) (*http.Response, error) { + vc: &mockClient{ + get: func(url string) (*http.Response, error) { return &http.Response{ Body: io.NopCloser(bytes.NewBufferString("foo")), }, nil @@ -815,8 +828,8 @@ func TestHTTP01Validate(t *testing.T) { assert.FatalError(t, err) return test{ ch: ch, - vo: &ValidateChallengeOptions{ - HTTPGet: func(url string) (*http.Response, error) { + vc: &mockClient{ + get: func(url string) (*http.Response, error) { return &http.Response{ Body: io.NopCloser(bytes.NewBufferString(expKeyAuth)), }, nil @@ -857,8 +870,8 @@ func TestHTTP01Validate(t *testing.T) { assert.FatalError(t, err) return test{ ch: ch, - vo: &ValidateChallengeOptions{ - HTTPGet: func(url string) (*http.Response, error) { + vc: &mockClient{ + get: func(url string) (*http.Response, error) { return &http.Response{ Body: io.NopCloser(bytes.NewBufferString(expKeyAuth)), }, nil @@ -887,7 +900,8 @@ func TestHTTP01Validate(t *testing.T) { for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) - if err := http01Validate(context.Background(), tc.ch, tc.db, tc.jwk, tc.vo); err != nil { + ctx := NewClientContext(context.Background(), tc.vc) + if err := http01Validate(ctx, tc.ch, tc.db, tc.jwk); err != nil { if assert.NotNil(t, tc.err) { switch k := err.(type) { case *Error: @@ -911,7 +925,7 @@ func TestDNS01Validate(t *testing.T) { fulldomain := "*.zap.internal" domain := strings.TrimPrefix(fulldomain, "*.") type test struct { - vo *ValidateChallengeOptions + vc Client ch *Challenge jwk *jose.JSONWebKey db DB @@ -928,8 +942,8 @@ func TestDNS01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - LookupTxt: func(url string) ([]string, error) { + vc: &mockClient{ + lookupTxt: func(url string) ([]string, error) { return nil, errors.New("force") }, }, @@ -963,8 +977,8 @@ func TestDNS01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - LookupTxt: func(url string) ([]string, error) { + vc: &mockClient{ + lookupTxt: func(url string) ([]string, error) { return nil, errors.New("force") }, }, @@ -1001,8 +1015,8 @@ func TestDNS01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - LookupTxt: func(url string) ([]string, error) { + vc: &mockClient{ + lookupTxt: func(url string) ([]string, error) { return []string{"foo"}, nil }, }, @@ -1026,8 +1040,8 @@ func TestDNS01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - LookupTxt: func(url string) ([]string, error) { + vc: &mockClient{ + lookupTxt: func(url string) ([]string, error) { return []string{"foo", "bar"}, nil }, }, @@ -1068,8 +1082,8 @@ func TestDNS01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - LookupTxt: func(url string) ([]string, error) { + vc: &mockClient{ + lookupTxt: func(url string) ([]string, error) { return []string{"foo", "bar"}, nil }, }, @@ -1111,8 +1125,8 @@ func TestDNS01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - LookupTxt: func(url string) ([]string, error) { + vc: &mockClient{ + lookupTxt: func(url string) ([]string, error) { return []string{"foo", expected}, nil }, }, @@ -1156,8 +1170,8 @@ func TestDNS01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - LookupTxt: func(url string) ([]string, error) { + vc: &mockClient{ + lookupTxt: func(url string) ([]string, error) { return []string{"foo", expected}, nil }, }, @@ -1186,7 +1200,8 @@ func TestDNS01Validate(t *testing.T) { for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) - if err := dns01Validate(context.Background(), tc.ch, tc.db, tc.jwk, tc.vo); err != nil { + ctx := NewClientContext(context.Background(), tc.vc) + if err := dns01Validate(ctx, tc.ch, tc.db, tc.jwk); err != nil { if assert.NotNil(t, tc.err) { switch k := err.(type) { case *Error: @@ -1206,6 +1221,8 @@ func TestDNS01Validate(t *testing.T) { } } +type tlsDialer func(network, addr string, config *tls.Config) (conn *tls.Conn, err error) + func newTestTLSALPNServer(validationCert *tls.Certificate) (*httptest.Server, tlsDialer) { srv := httptest.NewUnstartedServer(http.NewServeMux()) @@ -1309,7 +1326,7 @@ func TestTLSALPN01Validate(t *testing.T) { } } type test struct { - vo *ValidateChallengeOptions + vc Client ch *Challenge jwk *jose.JSONWebKey db DB @@ -1321,8 +1338,8 @@ func TestTLSALPN01Validate(t *testing.T) { ch := makeTLSCh() return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { + vc: &mockClient{ + tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { return nil, errors.New("force") }, }, @@ -1351,8 +1368,8 @@ func TestTLSALPN01Validate(t *testing.T) { ch := makeTLSCh() return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { + vc: &mockClient{ + tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { return nil, errors.New("force") }, }, @@ -1384,8 +1401,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -1413,8 +1430,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { + vc: &mockClient{ + tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { return tls.Client(&noopConn{}, config), nil }, }, @@ -1443,8 +1460,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { + vc: &mockClient{ + tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { return tls.Client(&noopConn{}, config), nil }, }, @@ -1479,8 +1496,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { + vc: &mockClient{ + tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { return tls.DialWithDialer(&net.Dialer{Timeout: time.Second}, "tcp", srv.Listener.Addr().String(), config) }, }, @@ -1516,8 +1533,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { + vc: &mockClient{ + tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { return tls.DialWithDialer(&net.Dialer{Timeout: time.Second}, "tcp", srv.Listener.Addr().String(), config) }, }, @@ -1562,8 +1579,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -1605,8 +1622,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -1649,8 +1666,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -1692,8 +1709,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -1736,8 +1753,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, srv: srv, jwk: jwk, @@ -1758,8 +1775,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -1797,8 +1814,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -1841,8 +1858,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -1884,8 +1901,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -1924,8 +1941,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -1963,8 +1980,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -2008,8 +2025,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -2054,8 +2071,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -2100,8 +2117,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -2144,8 +2161,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -2189,8 +2206,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -2226,8 +2243,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -2253,7 +2270,8 @@ func TestTLSALPN01Validate(t *testing.T) { defer tc.srv.Close() } - if err := tlsalpn01Validate(context.Background(), tc.ch, tc.db, tc.jwk, tc.vo); err != nil { + ctx := NewClientContext(context.Background(), tc.vc) + if err := tlsalpn01Validate(ctx, tc.ch, tc.db, tc.jwk); err != nil { if assert.NotNil(t, tc.err) { switch k := err.(type) { case *Error: diff --git a/acme/linker.go b/acme/linker.go index 6e9110c2..bddc21f1 100644 --- a/acme/linker.go +++ b/acme/linker.go @@ -206,6 +206,11 @@ func (l *linker) Middleware(next http.Handler) http.Handler { // GetLink is a helper for GetLinkExplicit. func (l *linker) GetLink(ctx context.Context, typ LinkType, inputs ...string) string { + var name string + if p, ok := ProvisionerFromContext(ctx); ok { + name = p.GetName() + } + var u url.URL if baseURL := baseURLFromContext(ctx); baseURL != nil { u = *baseURL @@ -217,8 +222,7 @@ func (l *linker) GetLink(ctx context.Context, typ LinkType, inputs ...string) st u.Host = l.dns } - p := MustProvisionerFromContext(ctx) - u.Path = l.prefix + GetUnescapedPathSuffix(typ, p.GetName(), inputs...) + u.Path = l.prefix + GetUnescapedPathSuffix(typ, name, inputs...) return u.String() } diff --git a/acme/linker_test.go b/acme/linker_test.go index 1946dd88..b85d1a53 100644 --- a/acme/linker_test.go +++ b/acme/linker_test.go @@ -5,16 +5,34 @@ import ( "fmt" "net/url" "testing" + "time" "github.com/smallstep/assert" + "github.com/smallstep/certificates/authority/provisioner" ) -func TestLinker_GetUnescapedPathSuffix(t *testing.T) { - dns := "ca.smallstep.com" - prefix := "acme" - linker := NewLinker(dns, prefix) +func mockProvisioner(t *testing.T) Provisioner { + t.Helper() + var defaultDisableRenewal = false + + // Initialize provisioners + p := &provisioner.ACME{ + Type: "ACME", + Name: "test@acme-provisioner.com", + } + if err := p.Init(provisioner.Config{Claims: provisioner.Claims{ + MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute}, + MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, + DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, + DisableRenewal: &defaultDisableRenewal, + }}); err != nil { + fmt.Printf("%v", err) + } + return p +} - getPath := linker.GetUnescapedPathSuffix +func TestGetUnescapedPathSuffix(t *testing.T) { + getPath := GetUnescapedPathSuffix assert.Equals(t, getPath(NewNonceLinkType, "{provisionerID}"), "/{provisionerID}/new-nonce") assert.Equals(t, getPath(DirectoryLinkType, "{provisionerID}"), "/{provisionerID}/directory") @@ -31,9 +49,9 @@ func TestLinker_GetUnescapedPathSuffix(t *testing.T) { } func TestLinker_DNS(t *testing.T) { - prov := newProv() + prov := mockProvisioner(t) escProvName := url.PathEscape(prov.GetName()) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := NewProvisionerContext(context.Background(), prov) type test struct { name string dns string @@ -116,19 +134,19 @@ func TestLinker_GetLink(t *testing.T) { linker := NewLinker(dns, prefix) id := "1234" - prov := newProv() + prov := mockProvisioner(t) escProvName := url.PathEscape(prov.GetName()) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx := NewProvisionerContext(context.Background(), prov) + ctx = context.WithValue(ctx, baseURLKey{}, baseURL) // No provisioner and no BaseURL from request assert.Equals(t, linker.GetLink(context.Background(), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", "")) // Provisioner: yes, BaseURL: no - assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), provisionerContextKey, prov), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", escProvName)) + assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), provisionerKey{}, prov), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", escProvName)) // Provisioner: no, BaseURL: yes - assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), baseURLContextKey, baseURL), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://test.ca.smallstep.com", "")) + assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), baseURLKey{}, baseURL), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://test.ca.smallstep.com", "")) assert.Equals(t, linker.GetLink(ctx, NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, escProvName)) assert.Equals(t, linker.GetLink(ctx, NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, escProvName)) @@ -162,10 +180,10 @@ func TestLinker_GetLink(t *testing.T) { func TestLinker_LinkOrder(t *testing.T) { baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - prov := newProv() + prov := mockProvisioner(t) provName := url.PathEscape(prov.GetName()) - ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx := NewProvisionerContext(context.Background(), prov) + ctx = context.WithValue(ctx, baseURLKey{}, baseURL) oid := "orderID" certID := "certID" @@ -227,10 +245,10 @@ func TestLinker_LinkOrder(t *testing.T) { func TestLinker_LinkAccount(t *testing.T) { baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - prov := newProv() + prov := mockProvisioner(t) provName := url.PathEscape(prov.GetName()) - ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx := NewProvisionerContext(context.Background(), prov) + ctx = context.WithValue(ctx, baseURLKey{}, baseURL) accID := "accountID" linkerPrefix := "acme" @@ -259,10 +277,10 @@ func TestLinker_LinkAccount(t *testing.T) { func TestLinker_LinkChallenge(t *testing.T) { baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - prov := newProv() + prov := mockProvisioner(t) provName := url.PathEscape(prov.GetName()) - ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx := NewProvisionerContext(context.Background(), prov) + ctx = context.WithValue(ctx, baseURLKey{}, baseURL) chID := "chID" azID := "azID" @@ -292,10 +310,10 @@ func TestLinker_LinkChallenge(t *testing.T) { func TestLinker_LinkAuthorization(t *testing.T) { baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - prov := newProv() + prov := mockProvisioner(t) provName := url.PathEscape(prov.GetName()) - ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx := NewProvisionerContext(context.Background(), prov) + ctx = context.WithValue(ctx, baseURLKey{}, baseURL) chID0 := "chID-0" chID1 := "chID-1" @@ -334,10 +352,10 @@ func TestLinker_LinkAuthorization(t *testing.T) { func TestLinker_LinkOrdersByAccountID(t *testing.T) { baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - prov := newProv() + prov := mockProvisioner(t) provName := url.PathEscape(prov.GetName()) - ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx := NewProvisionerContext(context.Background(), prov) + ctx = context.WithValue(ctx, baseURLKey{}, baseURL) linkerPrefix := "acme" l := NewLinker("dns", linkerPrefix) From a8a42619804ef67cd39e4589bc472060366e17e7 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Mon, 2 May 2022 18:39:03 -0700 Subject: [PATCH 25/40] Fix authority/admin/api tests --- authority/admin/api/acme_test.go | 34 +++++++++++--------- authority/admin/api/admin_test.go | 32 ++++++------------- authority/admin/api/middleware_test.go | 13 +++----- authority/admin/api/provisioner_test.go | 42 +++++++++++-------------- 4 files changed, 52 insertions(+), 69 deletions(-) diff --git a/authority/admin/api/acme_test.go b/authority/admin/api/acme_test.go index 6ffe1418..6b89b288 100644 --- a/authority/admin/api/acme_test.go +++ b/authority/admin/api/acme_test.go @@ -29,6 +29,17 @@ func readProtoJSON(r io.ReadCloser, m proto.Message) error { return protojson.Unmarshal(data, m) } +func mockMustAuthority(t *testing.T, a adminAuthority) { + t.Helper() + fn := mustAuthority + t.Cleanup(func() { + mustAuthority = fn + }) + mustAuthority = func(ctx context.Context) adminAuthority { + return a + } +} + func TestHandler_requireEABEnabled(t *testing.T) { type test struct { ctx context.Context @@ -54,6 +65,7 @@ func TestHandler_requireEABEnabled(t *testing.T) { return test{ ctx: ctx, auth: auth, + adminDB: &admin.MockDB{}, err: err, statusCode: 500, } @@ -143,16 +155,12 @@ func TestHandler_requireEABEnabled(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{ - auth: tc.auth, - adminDB: tc.adminDB, - acmeDB: nil, - } - + mockMustAuthority(t, tc.auth) + ctx := admin.NewContext(tc.ctx, tc.adminDB) req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() - h.requireEABEnabled(tc.next)(w, req) + requireEABEnabled(tc.next)(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) @@ -194,6 +202,7 @@ func TestHandler_provisionerHasEABEnabled(t *testing.T) { } return test{ auth: auth, + adminDB: &admin.MockDB{}, provisionerName: "provName", want: false, err: admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName"), @@ -358,12 +367,9 @@ func TestHandler_provisionerHasEABEnabled(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{ - auth: tc.auth, - adminDB: tc.adminDB, - acmeDB: nil, - } - got, prov, err := h.provisionerHasEABEnabled(context.TODO(), tc.provisionerName) + mockMustAuthority(t, tc.auth) + ctx := admin.NewContext(context.Background(), tc.adminDB) + got, prov, err := provisionerHasEABEnabled(ctx, tc.provisionerName) if (err != nil) != (tc.err != nil) { t.Errorf("Handler.provisionerHasEABEnabled() error = %v, want err %v", err, tc.err) return diff --git a/authority/admin/api/admin_test.go b/authority/admin/api/admin_test.go index 8d223b52..2f5528e1 100644 --- a/authority/admin/api/admin_test.go +++ b/authority/admin/api/admin_test.go @@ -317,14 +317,11 @@ func TestHandler_GetAdmin(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{ - auth: tc.auth, - } - + mockMustAuthority(t, tc.auth) req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup req = req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.GetAdmin(w, req) + GetAdmin(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) @@ -456,13 +453,10 @@ func TestHandler_GetAdmins(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{ - auth: tc.auth, - } - + mockMustAuthority(t, tc.auth) req := tc.req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.GetAdmins(w, req) + GetAdmins(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) @@ -640,13 +634,11 @@ func TestHandler_CreateAdmin(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{ - auth: tc.auth, - } + mockMustAuthority(t, tc.auth) req := httptest.NewRequest("GET", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.CreateAdmin(w, req) + CreateAdmin(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) @@ -732,13 +724,11 @@ func TestHandler_DeleteAdmin(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{ - auth: tc.auth, - } + mockMustAuthority(t, tc.auth) req := httptest.NewRequest("DELETE", "/foo", nil) // chi routing is prepared in test setup req = req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.DeleteAdmin(w, req) + DeleteAdmin(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) @@ -877,13 +867,11 @@ func TestHandler_UpdateAdmin(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{ - auth: tc.auth, - } + mockMustAuthority(t, tc.auth) req := httptest.NewRequest("GET", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.UpdateAdmin(w, req) + UpdateAdmin(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) diff --git a/authority/admin/api/middleware_test.go b/authority/admin/api/middleware_test.go index 7fb4671a..3445a3b5 100644 --- a/authority/admin/api/middleware_test.go +++ b/authority/admin/api/middleware_test.go @@ -64,13 +64,11 @@ func TestHandler_requireAPIEnabled(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{ - auth: tc.auth, - } + mockMustAuthority(t, tc.auth) req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup req = req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.requireAPIEnabled(tc.next)(w, req) + requireAPIEnabled(tc.next)(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) @@ -194,13 +192,10 @@ func TestHandler_extractAuthorizeTokenAdmin(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{ - auth: tc.auth, - } - + mockMustAuthority(t, tc.auth) req := tc.req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.extractAuthorizeTokenAdmin(tc.next)(w, req) + extractAuthorizeTokenAdmin(tc.next)(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) diff --git a/authority/admin/api/provisioner_test.go b/authority/admin/api/provisioner_test.go index 6d5024f2..6ee26dba 100644 --- a/authority/admin/api/provisioner_test.go +++ b/authority/admin/api/provisioner_test.go @@ -47,6 +47,7 @@ func TestHandler_GetProvisioner(t *testing.T) { ctx: ctx, req: req, auth: auth, + adminDB: &admin.MockDB{}, statusCode: 500, err: &admin.Error{ Type: admin.ErrorServerInternalType.String(), @@ -71,6 +72,7 @@ func TestHandler_GetProvisioner(t *testing.T) { ctx: ctx, req: req, auth: auth, + adminDB: &admin.MockDB{}, statusCode: 500, err: &admin.Error{ Type: admin.ErrorServerInternalType.String(), @@ -153,13 +155,11 @@ func TestHandler_GetProvisioner(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{ - auth: tc.auth, - adminDB: tc.adminDB, - } - req := tc.req.WithContext(tc.ctx) + mockMustAuthority(t, tc.auth) + ctx := admin.NewContext(tc.ctx, tc.adminDB) + req := tc.req.WithContext(ctx) w := httptest.NewRecorder() - h.GetProvisioner(w, req) + GetProvisioner(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) @@ -277,12 +277,10 @@ func TestHandler_GetProvisioners(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{ - auth: tc.auth, - } + mockMustAuthority(t, tc.auth) req := tc.req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.GetProvisioners(w, req) + GetProvisioners(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) @@ -402,13 +400,11 @@ func TestHandler_CreateProvisioner(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{ - auth: tc.auth, - } + mockMustAuthority(t, tc.auth) req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.CreateProvisioner(w, req) + CreateProvisioner(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) @@ -562,12 +558,10 @@ func TestHandler_DeleteProvisioner(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{ - auth: tc.auth, - } + mockMustAuthority(t, tc.auth) req := tc.req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.DeleteProvisioner(w, req) + DeleteProvisioner(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) @@ -616,6 +610,7 @@ func TestHandler_UpdateProvisioner(t *testing.T) { return test{ ctx: context.Background(), body: body, + adminDB: &admin.MockDB{}, statusCode: 500, err: &admin.Error{ // TODO(hs): this probably needs a better error Type: "", @@ -645,6 +640,7 @@ func TestHandler_UpdateProvisioner(t *testing.T) { return test{ ctx: ctx, body: body, + adminDB: &admin.MockDB{}, auth: auth, statusCode: 500, err: &admin.Error{ @@ -1052,14 +1048,12 @@ func TestHandler_UpdateProvisioner(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{ - auth: tc.auth, - adminDB: tc.adminDB, - } + mockMustAuthority(t, tc.auth) + ctx := admin.NewContext(tc.ctx, tc.adminDB) req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() - h.UpdateProvisioner(w, req) + UpdateProvisioner(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) From 9147356d8afd01dcc4553a2574db36f236b79ba8 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Mon, 2 May 2022 18:47:47 -0700 Subject: [PATCH 26/40] Fix linter errors --- acme/api/handler.go | 4 ++-- acme/api/handler_test.go | 2 +- acme/client.go | 6 ++--- authority/admin/api/handler.go | 4 ---- ca/ca.go | 5 ++-- scep/api/api.go | 42 +++++++++++++++++----------------- scep/authority.go | 14 ------------ 7 files changed, 29 insertions(+), 48 deletions(-) diff --git a/acme/api/handler.go b/acme/api/handler.go index f6d79031..d00f8275 100644 --- a/acme/api/handler.go +++ b/acme/api/handler.go @@ -40,7 +40,7 @@ type payloadInfo struct { // HandlerOptions required to create a new ACME API request handler. type HandlerOptions struct { - // DB storage backend that impements the acme.DB interface. + // DB storage backend that implements the acme.DB interface. // // Deprecated: use acme.NewContex(context.Context, acme.DB) DB acme.DB @@ -50,7 +50,7 @@ type HandlerOptions struct { // Deprecated: use authority.NewContext(context.Context, *authority.Authority) CA acme.CertificateAuthority - // Backdate is the duration that the CA will substract from the current time + // Backdate is the duration that the CA will subtract from the current time // to set the NotBefore in the certificate. Backdate provisioner.Duration diff --git a/acme/api/handler_test.go b/acme/api/handler_test.go index 2ac41228..bd88c96f 100644 --- a/acme/api/handler_test.go +++ b/acme/api/handler_test.go @@ -31,7 +31,7 @@ type mockClient struct { tlsDial func(network, addr string, config *tls.Config) (*tls.Conn, error) } -func (m *mockClient) Get(url string) (*http.Response, error) { return m.get(url) } +func (m *mockClient) Get(u string) (*http.Response, error) { return m.get(u) } func (m *mockClient) LookupTxt(name string) ([]string, error) { return m.lookupTxt(name) } func (m *mockClient) TLSDial(network, addr string, config *tls.Config) (*tls.Conn, error) { return m.tlsDial(network, addr, config) diff --git a/acme/client.go b/acme/client.go index 2b200e45..31f4c975 100644 --- a/acme/client.go +++ b/acme/client.go @@ -37,11 +37,11 @@ func ClientFromContext(ctx context.Context) (c Client, ok bool) { // MustClientFromContext returns the current client from the given context. It will // return a new instance of the client if it does not exist. func MustClientFromContext(ctx context.Context) Client { - if c, ok := ClientFromContext(ctx); !ok { + c, ok := ClientFromContext(ctx) + if !ok { return NewClient() - } else { - return c } + return c } type client struct { diff --git a/authority/admin/api/handler.go b/authority/admin/api/handler.go index 0acd3ca9..95b9cd9c 100644 --- a/authority/admin/api/handler.go +++ b/authority/admin/api/handler.go @@ -40,10 +40,6 @@ func Route(r api.Router, acmeResponder acmeAdminResponderInterface) { return extractAuthorizeTokenAdmin(requireAPIEnabled(next)) } - requireEABEnabled := func(next nextHTTP) nextHTTP { - return requireEABEnabled(next) - } - // Provisioners r.MethodFunc("GET", "/provisioners/{name}", authnz(GetProvisioner)) r.MethodFunc("GET", "/provisioners", authnz(GetProvisioners)) diff --git a/ca/ca.go b/ca/ca.go index e910da74..d7943a6c 100644 --- a/ca/ca.go +++ b/ca/ca.go @@ -230,13 +230,12 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) { if err != nil { return nil, errors.Wrap(err, "error creating SCEP authority") } - scepRouterHandler := scepAPI.New(scepAuthority) // According to the RFC (https://tools.ietf.org/html/rfc8894#section-7.10), // SCEP operations are performed using HTTP, so that's why the API is mounted // to the insecure mux. insecureMux.Route("/"+scepPrefix, func(r chi.Router) { - scepRouterHandler.Route(r) + scepAPI.Route(r) }) // The RFC also mentions usage of HTTPS, but seems to advise @@ -246,7 +245,7 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) { // as well as HTTPS can be used to request certificates // using SCEP. mux.Route("/"+scepPrefix, func(r chi.Router) { - scepRouterHandler.Route(r) + scepAPI.Route(r) }) } diff --git a/scep/api/api.go b/scep/api/api.go index 0d62904d..e513aa43 100644 --- a/scep/api/api.go +++ b/scep/api/api.go @@ -38,8 +38,8 @@ type request struct { Message []byte } -// response is a SCEP server response. -type response struct { +// Response is a SCEP server Response. +type Response struct { Operation string CACertNum int Data []byte @@ -81,7 +81,7 @@ func Get(w http.ResponseWriter, r *http.Request) { } ctx := r.Context() - var res response + var res Response switch req.Operation { case opnGetCACert: @@ -110,7 +110,7 @@ func Post(w http.ResponseWriter, r *http.Request) { return } - var res response + var res Response switch req.Operation { case opnPKIOperation: res, err = PKIOperation(r.Context(), req) @@ -207,18 +207,18 @@ func lookupProvisioner(next http.HandlerFunc) http.HandlerFunc { } // GetCACert returns the CA certificates in a SCEP response -func GetCACert(ctx context.Context) (response, error) { +func GetCACert(ctx context.Context) (Response, error) { auth := scep.MustFromContext(ctx) certs, err := auth.GetCACertificates(ctx) if err != nil { - return response{}, err + return Response{}, err } if len(certs) == 0 { - return response{}, errors.New("missing CA cert") + return Response{}, errors.New("missing CA cert") } - res := response{ + res := Response{ Operation: opnGetCACert, CACertNum: len(certs), } @@ -231,7 +231,7 @@ func GetCACert(ctx context.Context) (response, error) { // not signed or encrypted data has to be returned. data, err := microscep.DegenerateCertificates(certs) if err != nil { - return response{}, err + return Response{}, err } res.Data = data } @@ -240,11 +240,11 @@ func GetCACert(ctx context.Context) (response, error) { } // GetCACaps returns the CA capabilities in a SCEP response -func GetCACaps(ctx context.Context) (response, error) { +func GetCACaps(ctx context.Context) (Response, error) { auth := scep.MustFromContext(ctx) caps := auth.GetCACaps(ctx) - res := response{ + res := Response{ Operation: opnGetCACaps, Data: formatCapabilities(caps), } @@ -253,12 +253,12 @@ func GetCACaps(ctx context.Context) (response, error) { } // PKIOperation performs PKI operations and returns a SCEP response -func PKIOperation(ctx context.Context, req request) (response, error) { +func PKIOperation(ctx context.Context, req request) (Response, error) { // parse the message using microscep implementation microMsg, err := microscep.ParsePKIMessage(req.Message) if err != nil { // return the error, because we can't use the msg for creating a CertRep - return response{}, err + return Response{}, err } // this is essentially doing the same as microscep.ParsePKIMessage, but @@ -266,7 +266,7 @@ func PKIOperation(ctx context.Context, req request) (response, error) { // wrapper for the microscep implementation. p7, err := pkcs7.Parse(microMsg.Raw) if err != nil { - return response{}, err + return Response{}, err } // copy over properties to our internal PKIMessage @@ -280,7 +280,7 @@ func PKIOperation(ctx context.Context, req request) (response, error) { auth := scep.MustFromContext(ctx) if err := auth.DecryptPKIEnvelope(ctx, msg); err != nil { - return response{}, err + return Response{}, err } // NOTE: at this point we have sufficient information for returning nicely signed CertReps @@ -315,7 +315,7 @@ func PKIOperation(ctx context.Context, req request) (response, error) { return createFailureResponse(ctx, csr, msg, microscep.BadRequest, fmt.Errorf("error when signing new certificate: %w", err)) } - res := response{ + res := Response{ Operation: opnPKIOperation, Data: certRep.Raw, Certificate: certRep.Certificate, @@ -329,7 +329,7 @@ func formatCapabilities(caps []string) []byte { } // writeResponse writes a SCEP response back to the SCEP client. -func writeResponse(w http.ResponseWriter, res response) { +func writeResponse(w http.ResponseWriter, res Response) { if res.Error != nil { log.Error(w, res.Error) @@ -349,20 +349,20 @@ func fail(w http.ResponseWriter, err error) { http.Error(w, err.Error(), http.StatusInternalServerError) } -func createFailureResponse(ctx context.Context, csr *x509.CertificateRequest, msg *scep.PKIMessage, info microscep.FailInfo, failError error) (response, error) { +func createFailureResponse(ctx context.Context, csr *x509.CertificateRequest, msg *scep.PKIMessage, info microscep.FailInfo, failError error) (Response, error) { auth := scep.MustFromContext(ctx) certRepMsg, err := auth.CreateFailureResponse(ctx, csr, msg, scep.FailInfoName(info), failError.Error()) if err != nil { - return response{}, err + return Response{}, err } - return response{ + return Response{ Operation: opnPKIOperation, Data: certRepMsg.Raw, Error: failError, }, nil } -func contentHeader(r response) string { +func contentHeader(r Response) string { switch r.Operation { default: return "text/plain" diff --git a/scep/authority.go b/scep/authority.go index 946fa948..7fe01c1d 100644 --- a/scep/authority.go +++ b/scep/authority.go @@ -63,20 +63,6 @@ type AuthorityOptions struct { Prefix string } -type optionsKey struct{} - -func newOptionsContext(ctx context.Context, o *AuthorityOptions) context.Context { - return context.WithValue(ctx, optionsKey{}, o) -} - -func optionsFromContext(ctx context.Context) *AuthorityOptions { - o, ok := ctx.Value(optionsKey{}).(*AuthorityOptions) - if !ok { - panic("scep options are not in the context") - } - return o -} - // SignAuthority is the interface for a signing authority type SignAuthority interface { Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) From 62d93a644e8d00d23d07582e403104c8096c3374 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Mon, 2 May 2022 19:39:50 -0700 Subject: [PATCH 27/40] Apply base context to test of the ca package --- ca/bootstrap_test.go | 4 ++++ ca/ca_test.go | 19 +++++++++++++------ ca/tls_test.go | 8 +++++++- 3 files changed, 24 insertions(+), 7 deletions(-) diff --git a/ca/bootstrap_test.go b/ca/bootstrap_test.go index 9aaa5f1f..1cda8232 100644 --- a/ca/bootstrap_test.go +++ b/ca/bootstrap_test.go @@ -53,7 +53,11 @@ func startCABootstrapServer() *httptest.Server { if err != nil { panic(err) } + baseContext := buildContext(ca.auth, nil, nil, nil) srv.Config.Handler = ca.srv.Handler + srv.Config.BaseContext = func(net.Listener) context.Context { + return baseContext + } srv.TLS = ca.srv.TLSConfig srv.StartTLS() // Force the use of GetCertificate on IPs diff --git a/ca/ca_test.go b/ca/ca_test.go index e4c35a90..29eac575 100644 --- a/ca/ca_test.go +++ b/ca/ca_test.go @@ -2,6 +2,7 @@ package ca import ( "bytes" + "context" "crypto" "crypto/rand" "crypto/sha1" @@ -281,7 +282,8 @@ ZEp7knvU2psWRw== assert.FatalError(t, err) rr := httptest.NewRecorder() - tc.ca.srv.Handler.ServeHTTP(rr, rq) + ctx := authority.NewContext(context.Background(), tc.ca.auth) + tc.ca.srv.Handler.ServeHTTP(rr, rq.WithContext(ctx)) if assert.Equals(t, rr.Code, tc.status) { body := &ClosingBuffer{rr.Body} @@ -360,7 +362,8 @@ func TestCAProvisioners(t *testing.T) { assert.FatalError(t, err) rr := httptest.NewRecorder() - tc.ca.srv.Handler.ServeHTTP(rr, rq) + ctx := authority.NewContext(context.Background(), tc.ca.auth) + tc.ca.srv.Handler.ServeHTTP(rr, rq.WithContext(ctx)) if assert.Equals(t, rr.Code, tc.status) { body := &ClosingBuffer{rr.Body} @@ -426,7 +429,8 @@ func TestCAProvisionerEncryptedKey(t *testing.T) { assert.FatalError(t, err) rr := httptest.NewRecorder() - tc.ca.srv.Handler.ServeHTTP(rr, rq) + ctx := authority.NewContext(context.Background(), tc.ca.auth) + tc.ca.srv.Handler.ServeHTTP(rr, rq.WithContext(ctx)) if assert.Equals(t, rr.Code, tc.status) { body := &ClosingBuffer{rr.Body} @@ -487,7 +491,8 @@ func TestCARoot(t *testing.T) { assert.FatalError(t, err) rr := httptest.NewRecorder() - tc.ca.srv.Handler.ServeHTTP(rr, rq) + ctx := authority.NewContext(context.Background(), tc.ca.auth) + tc.ca.srv.Handler.ServeHTTP(rr, rq.WithContext(ctx)) if assert.Equals(t, rr.Code, tc.status) { body := &ClosingBuffer{rr.Body} @@ -534,7 +539,8 @@ func TestCAHealth(t *testing.T) { assert.FatalError(t, err) rr := httptest.NewRecorder() - tc.ca.srv.Handler.ServeHTTP(rr, rq) + ctx := authority.NewContext(context.Background(), tc.ca.auth) + tc.ca.srv.Handler.ServeHTTP(rr, rq.WithContext(ctx)) if assert.Equals(t, rr.Code, tc.status) { body := &ClosingBuffer{rr.Body} @@ -628,7 +634,8 @@ func TestCARenew(t *testing.T) { rq.TLS = tc.tlsConnState rr := httptest.NewRecorder() - tc.ca.srv.Handler.ServeHTTP(rr, rq) + ctx := authority.NewContext(context.Background(), tc.ca.auth) + tc.ca.srv.Handler.ServeHTTP(rr, rq.WithContext(ctx)) if assert.Equals(t, rr.Code, tc.status) { body := &ClosingBuffer{rr.Body} diff --git a/ca/tls_test.go b/ca/tls_test.go index 93dbe9b3..946a6cb5 100644 --- a/ca/tls_test.go +++ b/ca/tls_test.go @@ -10,6 +10,7 @@ import ( "encoding/hex" "io" "log" + "net" "net/http" "net/http/httptest" "reflect" @@ -77,7 +78,12 @@ func startCATestServer() *httptest.Server { panic(err) } // Use a httptest.Server instead - return startTestServer(ca.srv.TLSConfig, ca.srv.Handler) + srv := startTestServer(ca.srv.TLSConfig, ca.srv.Handler) + baseContext := buildContext(ca.auth, nil, nil, nil) + srv.Config.BaseContext = func(net.Listener) context.Context { + return baseContext + } + return srv } func sign(domain string) (*Client, *api.SignResponse, crypto.PrivateKey) { From 43ddcf2efe96d6bbf41a92c6f59ad9442e5c3f48 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Wed, 4 May 2022 17:35:34 -0700 Subject: [PATCH 28/40] Do not use deprecated AuthorizeSign --- api/api.go | 1 - api/api_test.go | 12 ++++-------- api/revoke_test.go | 8 ++++---- api/sign.go | 7 +++++-- api/ssh_test.go | 2 +- authority/authorize.go | 3 +-- 6 files changed, 15 insertions(+), 18 deletions(-) diff --git a/api/api.go b/api/api.go index 0ca4a5ef..75d26237 100644 --- a/api/api.go +++ b/api/api.go @@ -35,7 +35,6 @@ type Authority interface { SSHAuthority // context specifies the Authorize[Sign|Revoke|etc.] method. Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error) - AuthorizeSign(ott string) ([]provisioner.SignOption, error) AuthorizeRenewToken(ctx context.Context, ott string) (*x509.Certificate, error) GetTLSOptions() *config.TLSOptions Root(shasum string) (*x509.Certificate, error) diff --git a/api/api_test.go b/api/api_test.go index 698b629c..1f27ab8c 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -185,7 +185,7 @@ func mockMustAuthority(t *testing.T, a Authority) { type mockAuthority struct { ret1, ret2 interface{} err error - authorizeSign func(ott string) ([]provisioner.SignOption, error) + authorize func(ctx context.Context, ott string) ([]provisioner.SignOption, error) authorizeRenewToken func(ctx context.Context, ott string) (*x509.Certificate, error) getTLSOptions func() *authority.TLSOptions root func(shasum string) (*x509.Certificate, error) @@ -214,12 +214,8 @@ type mockAuthority struct { // TODO: remove once Authorize is deprecated. func (m *mockAuthority) Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error) { - return m.AuthorizeSign(ott) -} - -func (m *mockAuthority) AuthorizeSign(ott string) ([]provisioner.SignOption, error) { - if m.authorizeSign != nil { - return m.authorizeSign(ott) + if m.authorize != nil { + return m.authorize(ctx, ott) } return m.ret1.([]provisioner.SignOption), m.err } @@ -908,7 +904,7 @@ func Test_Sign(t *testing.T) { t.Run(tt.name, func(t *testing.T) { mockMustAuthority(t, &mockAuthority{ ret1: tt.cert, ret2: tt.root, err: tt.signErr, - authorizeSign: func(ott string) ([]provisioner.SignOption, error) { + authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) { return tt.certAttrOpts, tt.autherr }, getTLSOptions: func() *authority.TLSOptions { diff --git a/api/revoke_test.go b/api/revoke_test.go index fa46dd90..c3fa6ceb 100644 --- a/api/revoke_test.go +++ b/api/revoke_test.go @@ -108,7 +108,7 @@ func Test_caHandler_Revoke(t *testing.T) { input: string(input), statusCode: http.StatusOK, auth: &mockAuthority{ - authorizeSign: func(ott string) ([]provisioner.SignOption, error) { + authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) { return nil, nil }, revoke: func(ctx context.Context, opts *authority.RevokeOptions) error { @@ -152,7 +152,7 @@ func Test_caHandler_Revoke(t *testing.T) { statusCode: http.StatusOK, tls: cs, auth: &mockAuthority{ - authorizeSign: func(ott string) ([]provisioner.SignOption, error) { + authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) { return nil, nil }, revoke: func(ctx context.Context, ri *authority.RevokeOptions) error { @@ -187,7 +187,7 @@ func Test_caHandler_Revoke(t *testing.T) { input: string(input), statusCode: http.StatusInternalServerError, auth: &mockAuthority{ - authorizeSign: func(ott string) ([]provisioner.SignOption, error) { + authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) { return nil, nil }, revoke: func(ctx context.Context, opts *authority.RevokeOptions) error { @@ -209,7 +209,7 @@ func Test_caHandler_Revoke(t *testing.T) { input: string(input), statusCode: http.StatusForbidden, auth: &mockAuthority{ - authorizeSign: func(ott string) ([]provisioner.SignOption, error) { + authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) { return nil, nil }, revoke: func(ctx context.Context, opts *authority.RevokeOptions) error { diff --git a/api/sign.go b/api/sign.go index b263e2e9..f7c3cc5a 100644 --- a/api/sign.go +++ b/api/sign.go @@ -68,8 +68,11 @@ func Sign(w http.ResponseWriter, r *http.Request) { TemplateData: body.TemplateData, } - a := mustAuthority(r.Context()) - signOpts, err := a.AuthorizeSign(body.OTT) + ctx := r.Context() + a := mustAuthority(ctx) + + ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod) + signOpts, err := a.Authorize(ctx, body.OTT) if err != nil { render.Error(w, errs.UnauthorizedErr(err)) return diff --git a/api/ssh_test.go b/api/ssh_test.go index c6fee2de..57dd6775 100644 --- a/api/ssh_test.go +++ b/api/ssh_test.go @@ -316,7 +316,7 @@ func Test_SSHSign(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { mockMustAuthority(t, &mockAuthority{ - authorizeSign: func(ott string) ([]provisioner.SignOption, error) { + authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) { return []provisioner.SignOption{}, tt.authErr }, signSSH: func(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { diff --git a/authority/authorize.go b/authority/authorize.go index 7f9f456c..c0722a1b 100644 --- a/authority/authorize.go +++ b/authority/authorize.go @@ -251,8 +251,7 @@ func (a *Authority) authorizeSign(ctx context.Context, token string) ([]provisio // AuthorizeSign authorizes a signature request by validating and authenticating // a token that must be sent w/ the request. // -// NOTE: This method is deprecated and should not be used. We make it available -// in the short term os as not to break existing clients. +// Deprecated: Use Authorize(context.Context, string) ([]provisioner.SignOption, error). func (a *Authority) AuthorizeSign(token string) ([]provisioner.SignOption, error) { ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod) return a.Authorize(ctx, token) From d51c6b7d83b566182f0584ea6d6e82332057ed39 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Wed, 4 May 2022 19:20:34 -0700 Subject: [PATCH 29/40] Make step handler backward compatible --- scep/api/api.go | 34 +++++++++++++++++++++++++++------- scep/authority.go | 1 - 2 files changed, 27 insertions(+), 8 deletions(-) diff --git a/scep/api/api.go b/scep/api/api.go index e513aa43..49a5267a 100644 --- a/scep/api/api.go +++ b/scep/api/api.go @@ -48,28 +48,48 @@ type Response struct { } // handler is the SCEP request handler. -type handler struct{} +type handler struct { + auth *scep.Authority +} // Route traffic and implement the Router interface. // // Deprecated: use scep.Route(r api.Router) func (h *handler) Route(r api.Router) { - Route(r) + route(r, func(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := scep.NewContext(r.Context(), h.auth) + next(w, r.WithContext(ctx)) + } + }) } // New returns a new SCEP API router. // // Deprecated: use scep.Route(r api.Router) func New(auth *scep.Authority) api.RouterHandler { - return &handler{} + return &handler{auth: auth} } // Route traffic and implement the Router interface. func Route(r api.Router) { - r.MethodFunc(http.MethodGet, "/{provisionerName}/*", lookupProvisioner(Get)) - r.MethodFunc(http.MethodGet, "/{provisionerName}", lookupProvisioner(Get)) - r.MethodFunc(http.MethodPost, "/{provisionerName}/*", lookupProvisioner(Post)) - r.MethodFunc(http.MethodPost, "/{provisionerName}", lookupProvisioner(Post)) + route(r, nil) +} + +func route(r api.Router, middleware func(next http.HandlerFunc) http.HandlerFunc) { + getHandler := lookupProvisioner(Get) + postHandler := lookupProvisioner(Post) + + // For backward compatibility. + if middleware != nil { + getHandler = middleware(getHandler) + postHandler = middleware(postHandler) + } + + r.MethodFunc(http.MethodGet, "/{provisionerName}/*", getHandler) + r.MethodFunc(http.MethodGet, "/{provisionerName}", getHandler) + r.MethodFunc(http.MethodPost, "/{provisionerName}/*", postHandler) + r.MethodFunc(http.MethodPost, "/{provisionerName}", postHandler) } // Get handles all SCEP GET requests diff --git a/scep/authority.go b/scep/authority.go index 7fe01c1d..7dbbb8c5 100644 --- a/scep/authority.go +++ b/scep/authority.go @@ -453,7 +453,6 @@ func (a *Authority) CreateFailureResponse(ctx context.Context, csr *x509.Certifi // MatchChallengePassword verifies a SCEP challenge password func (a *Authority) MatchChallengePassword(ctx context.Context, password string) (bool, error) { - p, err := provisionerFromContext(ctx) if err != nil { return false, err From 2ea0c703448f5dff7398535167da9b346d1fd5a9 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Thu, 5 May 2022 12:25:07 -0700 Subject: [PATCH 30/40] Move acme context middleware to deprecated handler --- acme/api/handler.go | 57 +++++++++++++++++++++------------------------ 1 file changed, 26 insertions(+), 31 deletions(-) diff --git a/acme/api/handler.go b/acme/api/handler.go index d00f8275..96e22d85 100644 --- a/acme/api/handler.go +++ b/acme/api/handler.go @@ -79,12 +79,29 @@ type handler struct { opts *HandlerOptions } -// Route traffic and implement the Router interface. +// Route traffic and implement the Router interface. For backward compatibility +// this route adds will add a new middleware that will set the ACME components +// on the context. +// +// Deprecated: use api.Route(r api.Router) func (h *handler) Route(r api.Router) { - route(r, h.opts) + client := acme.NewClient() + linker := acme.NewLinker(h.opts.DNS, h.opts.Prefix) + route(r, func(next nextHTTP) nextHTTP { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if ca, ok := h.opts.CA.(*authority.Authority); ok && ca != nil { + ctx = authority.NewContext(ctx, ca) + } + ctx = acme.NewContext(ctx, h.opts.DB, client, linker, h.opts.PrerequisitesChecker) + next(w, r.WithContext(ctx)) + } + }) } // NewHandler returns a new ACME API handler. +// +// Deprecated: use api.Route(r api.Router) func NewHandler(opts HandlerOptions) api.RouterHandler { return &handler{ opts: &opts, @@ -98,40 +115,18 @@ func Route(r api.Router) { route(r, nil) } -func route(r api.Router, opts *HandlerOptions) { - var withContext func(next nextHTTP) nextHTTP - - // For backward compatibility this block adds will add a new middleware that - // will set the ACME components to the context. - if opts != nil { - client := acme.NewClient() - linker := acme.NewLinker(opts.DNS, opts.Prefix) - - withContext = func(next nextHTTP) nextHTTP { - return func(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - if ca, ok := opts.CA.(*authority.Authority); ok && ca != nil { - ctx = authority.NewContext(ctx, ca) - } - ctx = acme.NewContext(ctx, opts.DB, client, linker, opts.PrerequisitesChecker) - next(w, r.WithContext(ctx)) - } - } - } else { - withContext = func(next nextHTTP) nextHTTP { - return func(w http.ResponseWriter, r *http.Request) { - next(w, r) - } - } - } - +func route(r api.Router, middleware func(next nextHTTP) nextHTTP) { commonMiddleware := func(next nextHTTP) nextHTTP { - return withContext(func(w http.ResponseWriter, r *http.Request) { + handler := func(w http.ResponseWriter, r *http.Request) { // Linker middleware gets the provisioner and current url from the // request and sets them in the context. linker := acme.MustLinkerFromContext(r.Context()) linker.Middleware(http.HandlerFunc(checkPrerequisites(next))).ServeHTTP(w, r) - }) + } + if middleware != nil { + handler = middleware(handler) + } + return handler } validatingMiddleware := func(next nextHTTP) nextHTTP { return commonMiddleware(addNonce(addDirLink(verifyContentType(parseJWS(validateJWS(next)))))) From f639bfc53b7d53f089e5edc22ead70f91957e532 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Fri, 6 May 2022 14:05:08 -0700 Subject: [PATCH 31/40] Use contexts on the new PolicyAdminResponder --- authority/admin/api/handler.go | 6 +- authority/admin/api/policy.go | 165 ++++++++++------------- authority/admin/api/policy_test.go | 202 +++++++++++++++++++---------- ca/ca.go | 2 +- 4 files changed, 209 insertions(+), 166 deletions(-) diff --git a/authority/admin/api/handler.go b/authority/admin/api/handler.go index bb871c2a..0ab417e6 100644 --- a/authority/admin/api/handler.go +++ b/authority/admin/api/handler.go @@ -13,7 +13,7 @@ import ( // Handler is the Admin API request handler. type Handler struct { acmeResponder acmeAdminResponderInterface - policyResponder policyAdminResponderInterface + policyResponder PolicyAdminResponder } // Route traffic and implement the Router interface. @@ -24,7 +24,7 @@ func (h *Handler) Route(r api.Router) { } // NewHandler returns a new Authority Config Handler. -func NewHandler(auth adminAuthority, adminDB admin.DB, acmeDB acme.DB, acmeResponder acmeAdminResponderInterface, policyResponder policyAdminResponderInterface) api.RouterHandler { +func NewHandler(auth adminAuthority, adminDB admin.DB, acmeDB acme.DB, acmeResponder acmeAdminResponderInterface, policyResponder PolicyAdminResponder) api.RouterHandler { return &Handler{ acmeResponder: acmeResponder, policyResponder: policyResponder, @@ -36,7 +36,7 @@ var mustAuthority = func(ctx context.Context) adminAuthority { } // Route traffic and implement the Router interface. -func Route(r api.Router, acmeResponder acmeAdminResponderInterface, policyResponder policyAdminResponderInterface) { +func Route(r api.Router, acmeResponder acmeAdminResponderInterface, policyResponder PolicyAdminResponder) { authnz := func(next http.HandlerFunc) http.HandlerFunc { return extractAuthorizeTokenAdmin(requireAPIEnabled(next)) } diff --git a/authority/admin/api/policy.go b/authority/admin/api/policy.go index 6af1104a..9f338c0b 100644 --- a/authority/admin/api/policy.go +++ b/authority/admin/api/policy.go @@ -1,6 +1,7 @@ package api import ( + "context" "errors" "net/http" @@ -14,7 +15,9 @@ import ( "github.com/smallstep/certificates/authority/policy" ) -type policyAdminResponderInterface interface { +// PolicyAdminResponder is the interface responsible for writing ACME admin +// responses. +type PolicyAdminResponder interface { GetAuthorityPolicy(w http.ResponseWriter, r *http.Request) CreateAuthorityPolicy(w http.ResponseWriter, r *http.Request) UpdateAuthorityPolicy(w http.ResponseWriter, r *http.Request) @@ -29,39 +32,24 @@ type policyAdminResponderInterface interface { DeleteACMEAccountPolicy(w http.ResponseWriter, r *http.Request) } -// PolicyAdminResponder is responsible for writing ACME admin responses -type PolicyAdminResponder struct { - auth adminAuthority - adminDB admin.DB - acmeDB acme.DB - isLinkedCA bool -} - -// NewACMEAdminResponder returns a new ACMEAdminResponder -func NewPolicyAdminResponder(auth adminAuthority, adminDB admin.DB, acmeDB acme.DB) *PolicyAdminResponder { - - var isLinkedCA bool - if a, ok := adminDB.(interface{ IsLinkedCA() bool }); ok { - isLinkedCA = a.IsLinkedCA() - } +// policyAdminResponder is responsible for writing ACME admin responses. +type policyAdminResponder struct{} - return &PolicyAdminResponder{ - auth: auth, - adminDB: adminDB, - acmeDB: acmeDB, - isLinkedCA: isLinkedCA, - } +// NewACMEAdminResponder returns a new PolicyAdminResponder. +func NewPolicyAdminResponder() PolicyAdminResponder { + return &policyAdminResponder{} } // GetAuthorityPolicy handles the GET /admin/authority/policy request -func (par *PolicyAdminResponder) GetAuthorityPolicy(w http.ResponseWriter, r *http.Request) { - - if err := par.blockLinkedCA(); err != nil { +func (par *policyAdminResponder) GetAuthorityPolicy(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if err := blockLinkedCA(ctx); err != nil { render.Error(w, err) return } - authorityPolicy, err := par.auth.GetAuthorityPolicy(r.Context()) + auth := mustAuthority(ctx) + authorityPolicy, err := auth.GetAuthorityPolicy(r.Context()) if ae, ok := err.(*admin.Error); ok && !ae.IsType(admin.ErrorNotFoundType) { render.Error(w, admin.WrapErrorISE(ae, "error retrieving authority policy")) return @@ -76,15 +64,15 @@ func (par *PolicyAdminResponder) GetAuthorityPolicy(w http.ResponseWriter, r *ht } // CreateAuthorityPolicy handles the POST /admin/authority/policy request -func (par *PolicyAdminResponder) CreateAuthorityPolicy(w http.ResponseWriter, r *http.Request) { - - if err := par.blockLinkedCA(); err != nil { +func (par *policyAdminResponder) CreateAuthorityPolicy(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if err := blockLinkedCA(ctx); err != nil { render.Error(w, err) return } - ctx := r.Context() - authorityPolicy, err := par.auth.GetAuthorityPolicy(ctx) + auth := mustAuthority(ctx) + authorityPolicy, err := auth.GetAuthorityPolicy(ctx) if ae, ok := err.(*admin.Error); ok && !ae.IsType(admin.ErrorNotFoundType) { render.Error(w, admin.WrapErrorISE(err, "error retrieving authority policy")) @@ -113,7 +101,7 @@ func (par *PolicyAdminResponder) CreateAuthorityPolicy(w http.ResponseWriter, r adm := linkedca.MustAdminFromContext(ctx) var createdPolicy *linkedca.Policy - if createdPolicy, err = par.auth.CreateAuthorityPolicy(ctx, adm, newPolicy); err != nil { + if createdPolicy, err = auth.CreateAuthorityPolicy(ctx, adm, newPolicy); err != nil { if isBadRequest(err) { render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error storing authority policy")) return @@ -127,15 +115,15 @@ func (par *PolicyAdminResponder) CreateAuthorityPolicy(w http.ResponseWriter, r } // UpdateAuthorityPolicy handles the PUT /admin/authority/policy request -func (par *PolicyAdminResponder) UpdateAuthorityPolicy(w http.ResponseWriter, r *http.Request) { - - if err := par.blockLinkedCA(); err != nil { +func (par *policyAdminResponder) UpdateAuthorityPolicy(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if err := blockLinkedCA(ctx); err != nil { render.Error(w, err) return } - ctx := r.Context() - authorityPolicy, err := par.auth.GetAuthorityPolicy(ctx) + auth := mustAuthority(ctx) + authorityPolicy, err := auth.GetAuthorityPolicy(ctx) if ae, ok := err.(*admin.Error); ok && !ae.IsType(admin.ErrorNotFoundType) { render.Error(w, admin.WrapErrorISE(err, "error retrieving authority policy")) @@ -163,7 +151,7 @@ func (par *PolicyAdminResponder) UpdateAuthorityPolicy(w http.ResponseWriter, r adm := linkedca.MustAdminFromContext(ctx) var updatedPolicy *linkedca.Policy - if updatedPolicy, err = par.auth.UpdateAuthorityPolicy(ctx, adm, newPolicy); err != nil { + if updatedPolicy, err = auth.UpdateAuthorityPolicy(ctx, adm, newPolicy); err != nil { if isBadRequest(err) { render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error updating authority policy")) return @@ -177,15 +165,15 @@ func (par *PolicyAdminResponder) UpdateAuthorityPolicy(w http.ResponseWriter, r } // DeleteAuthorityPolicy handles the DELETE /admin/authority/policy request -func (par *PolicyAdminResponder) DeleteAuthorityPolicy(w http.ResponseWriter, r *http.Request) { - - if err := par.blockLinkedCA(); err != nil { +func (par *policyAdminResponder) DeleteAuthorityPolicy(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if err := blockLinkedCA(ctx); err != nil { render.Error(w, err) return } - ctx := r.Context() - authorityPolicy, err := par.auth.GetAuthorityPolicy(ctx) + auth := mustAuthority(ctx) + authorityPolicy, err := auth.GetAuthorityPolicy(ctx) if ae, ok := err.(*admin.Error); ok && !ae.IsType(admin.ErrorNotFoundType) { render.Error(w, admin.WrapErrorISE(ae, "error retrieving authority policy")) @@ -197,7 +185,7 @@ func (par *PolicyAdminResponder) DeleteAuthorityPolicy(w http.ResponseWriter, r return } - if err := par.auth.RemoveAuthorityPolicy(ctx); err != nil { + if err := auth.RemoveAuthorityPolicy(ctx); err != nil { render.Error(w, admin.WrapErrorISE(err, "error deleting authority policy")) return } @@ -206,15 +194,14 @@ func (par *PolicyAdminResponder) DeleteAuthorityPolicy(w http.ResponseWriter, r } // GetProvisionerPolicy handles the GET /admin/provisioners/{name}/policy request -func (par *PolicyAdminResponder) GetProvisionerPolicy(w http.ResponseWriter, r *http.Request) { - - if err := par.blockLinkedCA(); err != nil { +func (par *policyAdminResponder) GetProvisionerPolicy(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if err := blockLinkedCA(ctx); err != nil { render.Error(w, err) return } - prov := linkedca.MustProvisionerFromContext(r.Context()) - + prov := linkedca.MustProvisionerFromContext(ctx) provisionerPolicy := prov.GetPolicy() if provisionerPolicy == nil { render.Error(w, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist")) @@ -225,16 +212,14 @@ func (par *PolicyAdminResponder) GetProvisionerPolicy(w http.ResponseWriter, r * } // CreateProvisionerPolicy handles the POST /admin/provisioners/{name}/policy request -func (par *PolicyAdminResponder) CreateProvisionerPolicy(w http.ResponseWriter, r *http.Request) { - - if err := par.blockLinkedCA(); err != nil { +func (par *policyAdminResponder) CreateProvisionerPolicy(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if err := blockLinkedCA(ctx); err != nil { render.Error(w, err) return } - ctx := r.Context() prov := linkedca.MustProvisionerFromContext(ctx) - provisionerPolicy := prov.GetPolicy() if provisionerPolicy != nil { adminErr := admin.NewError(admin.ErrorConflictType, "provisioner %s already has a policy", prov.Name) @@ -256,8 +241,8 @@ func (par *PolicyAdminResponder) CreateProvisionerPolicy(w http.ResponseWriter, } prov.Policy = newPolicy - - if err := par.auth.UpdateProvisioner(ctx, prov); err != nil { + auth := mustAuthority(ctx) + if err := auth.UpdateProvisioner(ctx, prov); err != nil { if isBadRequest(err) { render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error creating provisioner policy")) return @@ -271,16 +256,14 @@ func (par *PolicyAdminResponder) CreateProvisionerPolicy(w http.ResponseWriter, } // UpdateProvisionerPolicy handles the PUT /admin/provisioners/{name}/policy request -func (par *PolicyAdminResponder) UpdateProvisionerPolicy(w http.ResponseWriter, r *http.Request) { - - if err := par.blockLinkedCA(); err != nil { +func (par *policyAdminResponder) UpdateProvisionerPolicy(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if err := blockLinkedCA(ctx); err != nil { render.Error(w, err) return } - ctx := r.Context() prov := linkedca.MustProvisionerFromContext(ctx) - provisionerPolicy := prov.GetPolicy() if provisionerPolicy == nil { render.Error(w, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist")) @@ -301,7 +284,8 @@ func (par *PolicyAdminResponder) UpdateProvisionerPolicy(w http.ResponseWriter, } prov.Policy = newPolicy - if err := par.auth.UpdateProvisioner(ctx, prov); err != nil { + auth := mustAuthority(ctx) + if err := auth.UpdateProvisioner(ctx, prov); err != nil { if isBadRequest(err) { render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error updating provisioner policy")) return @@ -315,16 +299,14 @@ func (par *PolicyAdminResponder) UpdateProvisionerPolicy(w http.ResponseWriter, } // DeleteProvisionerPolicy handles the DELETE /admin/provisioners/{name}/policy request -func (par *PolicyAdminResponder) DeleteProvisionerPolicy(w http.ResponseWriter, r *http.Request) { - - if err := par.blockLinkedCA(); err != nil { +func (par *policyAdminResponder) DeleteProvisionerPolicy(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if err := blockLinkedCA(ctx); err != nil { render.Error(w, err) return } - ctx := r.Context() prov := linkedca.MustProvisionerFromContext(ctx) - if prov.Policy == nil { render.Error(w, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist")) return @@ -333,7 +315,8 @@ func (par *PolicyAdminResponder) DeleteProvisionerPolicy(w http.ResponseWriter, // remove the policy prov.Policy = nil - if err := par.auth.UpdateProvisioner(ctx, prov); err != nil { + auth := mustAuthority(ctx) + if err := auth.UpdateProvisioner(ctx, prov); err != nil { render.Error(w, admin.WrapErrorISE(err, "error deleting provisioner policy")) return } @@ -341,16 +324,14 @@ func (par *PolicyAdminResponder) DeleteProvisionerPolicy(w http.ResponseWriter, render.JSONStatus(w, DeleteResponse{Status: "ok"}, http.StatusOK) } -func (par *PolicyAdminResponder) GetACMEAccountPolicy(w http.ResponseWriter, r *http.Request) { - - if err := par.blockLinkedCA(); err != nil { +func (par *policyAdminResponder) GetACMEAccountPolicy(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if err := blockLinkedCA(ctx); err != nil { render.Error(w, err) return } - ctx := r.Context() eak := linkedca.MustExternalAccountKeyFromContext(ctx) - eakPolicy := eak.GetPolicy() if eakPolicy == nil { render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist")) @@ -360,17 +341,15 @@ func (par *PolicyAdminResponder) GetACMEAccountPolicy(w http.ResponseWriter, r * render.ProtoJSONStatus(w, eakPolicy, http.StatusOK) } -func (par *PolicyAdminResponder) CreateACMEAccountPolicy(w http.ResponseWriter, r *http.Request) { - - if err := par.blockLinkedCA(); err != nil { +func (par *policyAdminResponder) CreateACMEAccountPolicy(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if err := blockLinkedCA(ctx); err != nil { render.Error(w, err) return } - ctx := r.Context() prov := linkedca.MustProvisionerFromContext(ctx) eak := linkedca.MustExternalAccountKeyFromContext(ctx) - eakPolicy := eak.GetPolicy() if eakPolicy != nil { adminErr := admin.NewError(admin.ErrorConflictType, "ACME EAK %s already has a policy", eak.Id) @@ -394,7 +373,8 @@ func (par *PolicyAdminResponder) CreateACMEAccountPolicy(w http.ResponseWriter, eak.Policy = newPolicy acmeEAK := linkedEAKToCertificates(eak) - if err := par.acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil { + acmeDB := acme.MustDatabaseFromContext(ctx) + if err := acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil { render.Error(w, admin.WrapErrorISE(err, "error creating ACME EAK policy")) return } @@ -402,17 +382,15 @@ func (par *PolicyAdminResponder) CreateACMEAccountPolicy(w http.ResponseWriter, render.ProtoJSONStatus(w, newPolicy, http.StatusCreated) } -func (par *PolicyAdminResponder) UpdateACMEAccountPolicy(w http.ResponseWriter, r *http.Request) { - - if err := par.blockLinkedCA(); err != nil { +func (par *policyAdminResponder) UpdateACMEAccountPolicy(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if err := blockLinkedCA(ctx); err != nil { render.Error(w, err) return } - ctx := r.Context() prov := linkedca.MustProvisionerFromContext(ctx) eak := linkedca.MustExternalAccountKeyFromContext(ctx) - eakPolicy := eak.GetPolicy() if eakPolicy == nil { render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist")) @@ -434,7 +412,8 @@ func (par *PolicyAdminResponder) UpdateACMEAccountPolicy(w http.ResponseWriter, eak.Policy = newPolicy acmeEAK := linkedEAKToCertificates(eak) - if err := par.acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil { + acmeDB := acme.MustDatabaseFromContext(ctx) + if err := acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil { render.Error(w, admin.WrapErrorISE(err, "error updating ACME EAK policy")) return } @@ -442,17 +421,15 @@ func (par *PolicyAdminResponder) UpdateACMEAccountPolicy(w http.ResponseWriter, render.ProtoJSONStatus(w, newPolicy, http.StatusOK) } -func (par *PolicyAdminResponder) DeleteACMEAccountPolicy(w http.ResponseWriter, r *http.Request) { - - if err := par.blockLinkedCA(); err != nil { +func (par *policyAdminResponder) DeleteACMEAccountPolicy(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if err := blockLinkedCA(ctx); err != nil { render.Error(w, err) return } - ctx := r.Context() prov := linkedca.MustProvisionerFromContext(ctx) eak := linkedca.MustExternalAccountKeyFromContext(ctx) - eakPolicy := eak.GetPolicy() if eakPolicy == nil { render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist")) @@ -463,7 +440,8 @@ func (par *PolicyAdminResponder) DeleteACMEAccountPolicy(w http.ResponseWriter, eak.Policy = nil acmeEAK := linkedEAKToCertificates(eak) - if err := par.acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil { + acmeDB := acme.MustDatabaseFromContext(ctx) + if err := acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil { render.Error(w, admin.WrapErrorISE(err, "error deleting ACME EAK policy")) return } @@ -472,9 +450,10 @@ func (par *PolicyAdminResponder) DeleteACMEAccountPolicy(w http.ResponseWriter, } // blockLinkedCA blocks all API operations on linked deployments -func (par *PolicyAdminResponder) blockLinkedCA() error { +func blockLinkedCA(ctx context.Context) error { // temporary blocking linked deployments - if par.isLinkedCA { + adminDB := admin.MustFromContext(ctx) + if a, ok := adminDB.(interface{ IsLinkedCA() bool }); ok && a.IsLinkedCA() { return admin.NewError(admin.ErrorNotImplementedType, "policy operations not yet supported in linked deployments") } return nil diff --git a/authority/admin/api/policy_test.go b/authority/admin/api/policy_test.go index 1e70db52..1ec88fb6 100644 --- a/authority/admin/api/policy_test.go +++ b/authority/admin/api/policy_test.go @@ -109,7 +109,8 @@ func TestPolicyAdminResponder_GetAuthorityPolicy(t *testing.T) { err := admin.WrapErrorISE(errors.New("force"), "error retrieving authority policy") err.Message = "error retrieving authority policy: force" return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { return nil, admin.NewError(admin.ErrorServerInternalType, "force") @@ -124,7 +125,8 @@ func TestPolicyAdminResponder_GetAuthorityPolicy(t *testing.T) { err := admin.NewError(admin.ErrorNotFoundType, "authority policy does not exist") err.Message = "authority policy does not exist" return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { return nil, admin.NewError(admin.ErrorNotFoundType, "not found") @@ -179,7 +181,8 @@ func TestPolicyAdminResponder_GetAuthorityPolicy(t *testing.T) { }, } return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { return policy, nil @@ -234,11 +237,12 @@ func TestPolicyAdminResponder_GetAuthorityPolicy(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - - par := NewPolicyAdminResponder(tc.auth, tc.adminDB, nil) + mockMustAuthority(t, tc.auth) + ctx := admin.NewContext(tc.ctx, tc.adminDB) + par := NewPolicyAdminResponder() req := httptest.NewRequest("GET", "/foo", nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() par.GetAuthorityPolicy(w, req) @@ -301,7 +305,8 @@ func TestPolicyAdminResponder_CreateAuthorityPolicy(t *testing.T) { err := admin.WrapErrorISE(errors.New("force"), "error retrieving authority policy") err.Message = "error retrieving authority policy: force" return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { return nil, admin.NewError(admin.ErrorServerInternalType, "force") @@ -316,7 +321,8 @@ func TestPolicyAdminResponder_CreateAuthorityPolicy(t *testing.T) { err := admin.NewError(admin.ErrorConflictType, "authority already has a policy") err.Message = "authority already has a policy" return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { return &linkedca.Policy{}, nil @@ -332,7 +338,8 @@ func TestPolicyAdminResponder_CreateAuthorityPolicy(t *testing.T) { adminErr.Message = "proto: syntax error (line 1:2): invalid value ?" body := []byte("{?}") return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { return nil, admin.NewError(admin.ErrorNotFoundType, "not found") @@ -358,7 +365,8 @@ func TestPolicyAdminResponder_CreateAuthorityPolicy(t *testing.T) { } }`) return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { return nil, admin.NewError(admin.ErrorNotFoundType, "not found") @@ -509,11 +517,13 @@ func TestPolicyAdminResponder_CreateAuthorityPolicy(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - - par := NewPolicyAdminResponder(tc.auth, tc.adminDB, tc.acmeDB) + mockMustAuthority(t, tc.auth) + ctx := admin.NewContext(tc.ctx, tc.adminDB) + ctx = acme.NewDatabaseContext(ctx, tc.acmeDB) + par := NewPolicyAdminResponder() req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() par.CreateAuthorityPolicy(w, req) @@ -586,7 +596,8 @@ func TestPolicyAdminResponder_UpdateAuthorityPolicy(t *testing.T) { err := admin.WrapErrorISE(errors.New("force"), "error retrieving authority policy") err.Message = "error retrieving authority policy: force" return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { return nil, admin.NewError(admin.ErrorServerInternalType, "force") @@ -602,7 +613,8 @@ func TestPolicyAdminResponder_UpdateAuthorityPolicy(t *testing.T) { err.Message = "authority policy does not exist" err.Status = http.StatusNotFound return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { return nil, nil @@ -625,7 +637,8 @@ func TestPolicyAdminResponder_UpdateAuthorityPolicy(t *testing.T) { adminErr.Message = "proto: syntax error (line 1:2): invalid value ?" body := []byte("{?}") return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { return policy, nil @@ -658,7 +671,8 @@ func TestPolicyAdminResponder_UpdateAuthorityPolicy(t *testing.T) { } }`) return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { return policy, nil @@ -809,11 +823,13 @@ func TestPolicyAdminResponder_UpdateAuthorityPolicy(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - - par := NewPolicyAdminResponder(tc.auth, tc.adminDB, tc.acmeDB) + mockMustAuthority(t, tc.auth) + ctx := admin.NewContext(tc.ctx, tc.adminDB) + ctx = acme.NewDatabaseContext(ctx, tc.acmeDB) + par := NewPolicyAdminResponder() req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() par.UpdateAuthorityPolicy(w, req) @@ -886,7 +902,8 @@ func TestPolicyAdminResponder_DeleteAuthorityPolicy(t *testing.T) { err := admin.WrapErrorISE(errors.New("force"), "error retrieving authority policy") err.Message = "error retrieving authority policy: force" return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { return nil, admin.NewError(admin.ErrorServerInternalType, "force") @@ -902,7 +919,8 @@ func TestPolicyAdminResponder_DeleteAuthorityPolicy(t *testing.T) { err.Message = "authority policy does not exist" err.Status = http.StatusNotFound return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { return nil, nil @@ -924,7 +942,8 @@ func TestPolicyAdminResponder_DeleteAuthorityPolicy(t *testing.T) { err := admin.NewErrorISE("error deleting authority policy: force") err.Message = "error deleting authority policy: force" return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { return policy, nil @@ -947,7 +966,8 @@ func TestPolicyAdminResponder_DeleteAuthorityPolicy(t *testing.T) { } ctx := context.Background() return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { return policy, nil @@ -963,11 +983,13 @@ func TestPolicyAdminResponder_DeleteAuthorityPolicy(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - - par := NewPolicyAdminResponder(tc.auth, tc.adminDB, tc.acmeDB) + mockMustAuthority(t, tc.auth) + ctx := admin.NewContext(tc.ctx, tc.adminDB) + ctx = acme.NewDatabaseContext(ctx, tc.acmeDB) + par := NewPolicyAdminResponder() req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() par.DeleteAuthorityPolicy(w, req) @@ -1033,6 +1055,7 @@ func TestPolicyAdminResponder_GetProvisionerPolicy(t *testing.T) { err.Message = "provisioner policy does not exist" return test{ ctx: ctx, + adminDB: &admin.MockDB{}, err: err, statusCode: 404, } @@ -1085,7 +1108,8 @@ func TestPolicyAdminResponder_GetProvisionerPolicy(t *testing.T) { } ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, response: &testPolicyResponse{ X509: &testX509Policy{ Allow: &testX509Names{ @@ -1135,11 +1159,13 @@ func TestPolicyAdminResponder_GetProvisionerPolicy(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - - par := NewPolicyAdminResponder(tc.auth, tc.adminDB, tc.acmeDB) + mockMustAuthority(t, tc.auth) + ctx := admin.NewContext(tc.ctx, tc.adminDB) + ctx = acme.NewDatabaseContext(ctx, tc.acmeDB) + par := NewPolicyAdminResponder() req := httptest.NewRequest("GET", "/foo", nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() par.GetProvisionerPolicy(w, req) @@ -1214,6 +1240,7 @@ func TestPolicyAdminResponder_CreateProvisionerPolicy(t *testing.T) { err.Message = "provisioner provName already has a policy" return test{ ctx: ctx, + adminDB: &admin.MockDB{}, err: err, statusCode: 409, } @@ -1228,6 +1255,7 @@ func TestPolicyAdminResponder_CreateProvisionerPolicy(t *testing.T) { body := []byte("{?}") return test{ ctx: ctx, + adminDB: &admin.MockDB{}, body: body, err: adminErr, statusCode: 400, @@ -1251,7 +1279,8 @@ func TestPolicyAdminResponder_CreateProvisionerPolicy(t *testing.T) { } }`) return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { return nil, admin.NewError(admin.ErrorNotFoundType, "not found") @@ -1283,7 +1312,8 @@ func TestPolicyAdminResponder_CreateProvisionerPolicy(t *testing.T) { body, err := protojson.Marshal(policy) assert.NoError(t, err) return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error { return &authority.PolicyError{ @@ -1318,7 +1348,8 @@ func TestPolicyAdminResponder_CreateProvisionerPolicy(t *testing.T) { body, err := protojson.Marshal(policy) assert.NoError(t, err) return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error { return &authority.PolicyError{ @@ -1351,7 +1382,8 @@ func TestPolicyAdminResponder_CreateProvisionerPolicy(t *testing.T) { body, err := protojson.Marshal(policy) assert.NoError(t, err) return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error { return nil @@ -1372,11 +1404,12 @@ func TestPolicyAdminResponder_CreateProvisionerPolicy(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - - par := NewPolicyAdminResponder(tc.auth, tc.adminDB, nil) + mockMustAuthority(t, tc.auth) + ctx := admin.NewContext(tc.ctx, tc.adminDB) + par := NewPolicyAdminResponder() req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() par.CreateProvisionerPolicy(w, req) @@ -1452,6 +1485,7 @@ func TestPolicyAdminResponder_UpdateProvisionerPolicy(t *testing.T) { err.Message = "provisioner policy does not exist" return test{ ctx: ctx, + adminDB: &admin.MockDB{}, err: err, statusCode: 404, } @@ -1474,6 +1508,7 @@ func TestPolicyAdminResponder_UpdateProvisionerPolicy(t *testing.T) { body := []byte("{?}") return test{ ctx: ctx, + adminDB: &admin.MockDB{}, body: body, err: adminErr, statusCode: 400, @@ -1505,7 +1540,8 @@ func TestPolicyAdminResponder_UpdateProvisionerPolicy(t *testing.T) { } }`) return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { return nil, admin.NewError(admin.ErrorNotFoundType, "not found") @@ -1538,7 +1574,8 @@ func TestPolicyAdminResponder_UpdateProvisionerPolicy(t *testing.T) { body, err := protojson.Marshal(policy) assert.NoError(t, err) return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error { return &authority.PolicyError{ @@ -1574,7 +1611,8 @@ func TestPolicyAdminResponder_UpdateProvisionerPolicy(t *testing.T) { body, err := protojson.Marshal(policy) assert.NoError(t, err) return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error { return &authority.PolicyError{ @@ -1608,7 +1646,8 @@ func TestPolicyAdminResponder_UpdateProvisionerPolicy(t *testing.T) { body, err := protojson.Marshal(policy) assert.NoError(t, err) return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error { return nil @@ -1629,11 +1668,12 @@ func TestPolicyAdminResponder_UpdateProvisionerPolicy(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - - par := NewPolicyAdminResponder(tc.auth, tc.adminDB, nil) + mockMustAuthority(t, tc.auth) + ctx := admin.NewContext(tc.ctx, tc.adminDB) + par := NewPolicyAdminResponder() req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() par.UpdateProvisionerPolicy(w, req) @@ -1710,6 +1750,7 @@ func TestPolicyAdminResponder_DeleteProvisionerPolicy(t *testing.T) { err.Message = "provisioner policy does not exist" return test{ ctx: ctx, + adminDB: &admin.MockDB{}, err: err, statusCode: 404, } @@ -1723,7 +1764,8 @@ func TestPolicyAdminResponder_DeleteProvisionerPolicy(t *testing.T) { err := admin.NewErrorISE("error deleting provisioner policy: force") err.Message = "error deleting provisioner policy: force" return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error { return errors.New("force") @@ -1740,7 +1782,8 @@ func TestPolicyAdminResponder_DeleteProvisionerPolicy(t *testing.T) { } ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error { return nil @@ -1753,11 +1796,13 @@ func TestPolicyAdminResponder_DeleteProvisionerPolicy(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - - par := NewPolicyAdminResponder(tc.auth, tc.adminDB, tc.acmeDB) + mockMustAuthority(t, tc.auth) + ctx := admin.NewContext(tc.ctx, tc.adminDB) + ctx = acme.NewDatabaseContext(ctx, tc.acmeDB) + par := NewPolicyAdminResponder() req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() par.DeleteProvisionerPolicy(w, req) @@ -1828,6 +1873,7 @@ func TestPolicyAdminResponder_GetACMEAccountPolicy(t *testing.T) { err.Message = "ACME EAK policy does not exist" return test{ ctx: ctx, + adminDB: &admin.MockDB{}, err: err, statusCode: 404, } @@ -1885,7 +1931,8 @@ func TestPolicyAdminResponder_GetACMEAccountPolicy(t *testing.T) { ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak) return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, response: &testPolicyResponse{ X509: &testX509Policy{ Allow: &testX509Names{ @@ -1935,11 +1982,12 @@ func TestPolicyAdminResponder_GetACMEAccountPolicy(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - - par := NewPolicyAdminResponder(nil, tc.adminDB, tc.acmeDB) + ctx := admin.NewContext(tc.ctx, tc.adminDB) + ctx = acme.NewDatabaseContext(ctx, tc.acmeDB) + par := NewPolicyAdminResponder() req := httptest.NewRequest("GET", "/foo", nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() par.GetACMEAccountPolicy(w, req) @@ -2018,6 +2066,7 @@ func TestPolicyAdminResponder_CreateACMEAccountPolicy(t *testing.T) { err.Message = "ACME EAK eakID already has a policy" return test{ ctx: ctx, + adminDB: &admin.MockDB{}, err: err, statusCode: 409, } @@ -2036,6 +2085,7 @@ func TestPolicyAdminResponder_CreateACMEAccountPolicy(t *testing.T) { body := []byte("{?}") return test{ ctx: ctx, + adminDB: &admin.MockDB{}, body: body, err: adminErr, statusCode: 400, @@ -2064,6 +2114,7 @@ func TestPolicyAdminResponder_CreateACMEAccountPolicy(t *testing.T) { }`) return test{ ctx: ctx, + adminDB: &admin.MockDB{}, body: body, err: adminErr, statusCode: 400, @@ -2091,7 +2142,8 @@ func TestPolicyAdminResponder_CreateACMEAccountPolicy(t *testing.T) { body, err := protojson.Marshal(policy) assert.NoError(t, err) return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, acmeDB: &acme.MockDB{ MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error { assert.Equal(t, "provID", provisionerID) @@ -2124,7 +2176,8 @@ func TestPolicyAdminResponder_CreateACMEAccountPolicy(t *testing.T) { body, err := protojson.Marshal(policy) assert.NoError(t, err) return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, acmeDB: &acme.MockDB{ MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error { assert.Equal(t, "provID", provisionerID) @@ -2147,11 +2200,12 @@ func TestPolicyAdminResponder_CreateACMEAccountPolicy(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - - par := NewPolicyAdminResponder(nil, tc.adminDB, tc.acmeDB) + ctx := admin.NewContext(tc.ctx, tc.adminDB) + ctx = acme.NewDatabaseContext(ctx, tc.acmeDB) + par := NewPolicyAdminResponder() req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() par.CreateACMEAccountPolicy(w, req) @@ -2231,6 +2285,7 @@ func TestPolicyAdminResponder_UpdateACMEAccountPolicy(t *testing.T) { err.Message = "ACME EAK policy does not exist" return test{ ctx: ctx, + adminDB: &admin.MockDB{}, err: err, statusCode: 404, } @@ -2257,6 +2312,7 @@ func TestPolicyAdminResponder_UpdateACMEAccountPolicy(t *testing.T) { body := []byte("{?}") return test{ ctx: ctx, + adminDB: &admin.MockDB{}, body: body, err: adminErr, statusCode: 400, @@ -2293,6 +2349,7 @@ func TestPolicyAdminResponder_UpdateACMEAccountPolicy(t *testing.T) { }`) return test{ ctx: ctx, + adminDB: &admin.MockDB{}, body: body, err: adminErr, statusCode: 400, @@ -2321,7 +2378,8 @@ func TestPolicyAdminResponder_UpdateACMEAccountPolicy(t *testing.T) { body, err := protojson.Marshal(policy) assert.NoError(t, err) return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, acmeDB: &acme.MockDB{ MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error { assert.Equal(t, "provID", provisionerID) @@ -2355,7 +2413,8 @@ func TestPolicyAdminResponder_UpdateACMEAccountPolicy(t *testing.T) { body, err := protojson.Marshal(policy) assert.NoError(t, err) return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, acmeDB: &acme.MockDB{ MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error { assert.Equal(t, "provID", provisionerID) @@ -2378,11 +2437,12 @@ func TestPolicyAdminResponder_UpdateACMEAccountPolicy(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - - par := NewPolicyAdminResponder(nil, tc.adminDB, tc.acmeDB) + ctx := admin.NewContext(tc.ctx, tc.adminDB) + ctx = acme.NewDatabaseContext(ctx, tc.acmeDB) + par := NewPolicyAdminResponder() req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() par.UpdateACMEAccountPolicy(w, req) @@ -2462,6 +2522,7 @@ func TestPolicyAdminResponder_DeleteACMEAccountPolicy(t *testing.T) { err.Message = "ACME EAK policy does not exist" return test{ ctx: ctx, + adminDB: &admin.MockDB{}, err: err, statusCode: 404, } @@ -2487,7 +2548,8 @@ func TestPolicyAdminResponder_DeleteACMEAccountPolicy(t *testing.T) { err := admin.NewErrorISE("error deleting ACME EAK policy: force") err.Message = "error deleting ACME EAK policy: force" return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, acmeDB: &acme.MockDB{ MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error { assert.Equal(t, "provID", provisionerID) @@ -2518,7 +2580,8 @@ func TestPolicyAdminResponder_DeleteACMEAccountPolicy(t *testing.T) { ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak) return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, acmeDB: &acme.MockDB{ MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error { assert.Equal(t, "provID", provisionerID) @@ -2533,11 +2596,12 @@ func TestPolicyAdminResponder_DeleteACMEAccountPolicy(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - - par := NewPolicyAdminResponder(nil, tc.adminDB, tc.acmeDB) + ctx := admin.NewContext(tc.ctx, tc.adminDB) + ctx = acme.NewDatabaseContext(ctx, tc.acmeDB) + par := NewPolicyAdminResponder() req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() par.DeleteACMEAccountPolicy(w, req) diff --git a/ca/ca.go b/ca/ca.go index 16a5c600..9252fff7 100644 --- a/ca/ca.go +++ b/ca/ca.go @@ -213,7 +213,7 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) { adminDB := auth.GetAdminDatabase() if adminDB != nil { acmeAdminResponder := adminAPI.NewACMEAdminResponder() - policyAdminResponder := adminAPI.NewPolicyAdminResponder(auth, adminDB, acmeDB) + policyAdminResponder := adminAPI.NewPolicyAdminResponder() mux.Route("/admin", func(r chi.Router) { adminAPI.Route(r, acmeAdminResponder, policyAdminResponder) }) From 1e03bbb1afb8f84f99ccd4d2b7a62f0051b4f1c8 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Fri, 6 May 2022 14:11:10 -0700 Subject: [PATCH 32/40] Change types in the ACMEAdminResponder --- authority/admin/api/acme.go | 17 +++++---- authority/admin/api/handler.go | 69 +++++++++++++++++++--------------- authority/admin/api/policy.go | 2 +- 3 files changed, 48 insertions(+), 40 deletions(-) diff --git a/authority/admin/api/acme.go b/authority/admin/api/acme.go index 814ca226..db393e9a 100644 --- a/authority/admin/api/acme.go +++ b/authority/admin/api/acme.go @@ -53,32 +53,33 @@ func requireEABEnabled(next http.HandlerFunc) http.HandlerFunc { } } -type acmeAdminResponderInterface interface { +// ACMEAdminResponder is responsible for writing ACME admin responses +type ACMEAdminResponder interface { GetExternalAccountKeys(w http.ResponseWriter, r *http.Request) CreateExternalAccountKey(w http.ResponseWriter, r *http.Request) DeleteExternalAccountKey(w http.ResponseWriter, r *http.Request) } -// ACMEAdminResponder is responsible for writing ACME admin responses -type ACMEAdminResponder struct{} +// acmeAdminResponder implements ACMEAdminResponder. +type acmeAdminResponder struct{} // NewACMEAdminResponder returns a new ACMEAdminResponder -func NewACMEAdminResponder() *ACMEAdminResponder { - return &ACMEAdminResponder{} +func NewACMEAdminResponder() ACMEAdminResponder { + return &acmeAdminResponder{} } // GetExternalAccountKeys writes the response for the EAB keys GET endpoint -func (h *ACMEAdminResponder) GetExternalAccountKeys(w http.ResponseWriter, r *http.Request) { +func (h *acmeAdminResponder) GetExternalAccountKeys(w http.ResponseWriter, r *http.Request) { render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) } // CreateExternalAccountKey writes the response for the EAB key POST endpoint -func (h *ACMEAdminResponder) CreateExternalAccountKey(w http.ResponseWriter, r *http.Request) { +func (h *acmeAdminResponder) CreateExternalAccountKey(w http.ResponseWriter, r *http.Request) { render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) } // DeleteExternalAccountKey writes the response for the EAB key DELETE endpoint -func (h *ACMEAdminResponder) DeleteExternalAccountKey(w http.ResponseWriter, r *http.Request) { +func (h *acmeAdminResponder) DeleteExternalAccountKey(w http.ResponseWriter, r *http.Request) { render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) } diff --git a/authority/admin/api/handler.go b/authority/admin/api/handler.go index 0ab417e6..1e5919ce 100644 --- a/authority/admin/api/handler.go +++ b/authority/admin/api/handler.go @@ -12,19 +12,21 @@ import ( // Handler is the Admin API request handler. type Handler struct { - acmeResponder acmeAdminResponderInterface + acmeResponder ACMEAdminResponder policyResponder PolicyAdminResponder } // Route traffic and implement the Router interface. // -// Deprecated: use Route(r api.Router, acmeResponder acmeAdminResponderInterface) +// Deprecated: use Route(r api.Router, acmeResponder ACMEAdminResponder, policyResponder PolicyAdminResponder) func (h *Handler) Route(r api.Router) { Route(r, h.acmeResponder, h.policyResponder) } // NewHandler returns a new Authority Config Handler. -func NewHandler(auth adminAuthority, adminDB admin.DB, acmeDB acme.DB, acmeResponder acmeAdminResponderInterface, policyResponder PolicyAdminResponder) api.RouterHandler { +// +// Deprecated: use Route(r api.Router, acmeResponder ACMEAdminResponder, policyResponder PolicyAdminResponder) +func NewHandler(auth adminAuthority, adminDB admin.DB, acmeDB acme.DB, acmeResponder ACMEAdminResponder, policyResponder PolicyAdminResponder) api.RouterHandler { return &Handler{ acmeResponder: acmeResponder, policyResponder: policyResponder, @@ -36,7 +38,7 @@ var mustAuthority = func(ctx context.Context) adminAuthority { } // Route traffic and implement the Router interface. -func Route(r api.Router, acmeResponder acmeAdminResponderInterface, policyResponder PolicyAdminResponder) { +func Route(r api.Router, acmeResponder ACMEAdminResponder, policyResponder PolicyAdminResponder) { authnz := func(next http.HandlerFunc) http.HandlerFunc { return extractAuthorizeTokenAdmin(requireAPIEnabled(next)) } @@ -79,32 +81,37 @@ func Route(r api.Router, acmeResponder acmeAdminResponderInterface, policyRespon r.MethodFunc("PATCH", "/admins/{id}", authnz(UpdateAdmin)) r.MethodFunc("DELETE", "/admins/{id}", authnz(DeleteAdmin)) - // ACME External Account Binding Keys - r.MethodFunc("GET", "/acme/eab/{provisionerName}/{reference}", acmeEABMiddleware(acmeResponder.GetExternalAccountKeys)) - r.MethodFunc("GET", "/acme/eab/{provisionerName}", acmeEABMiddleware(acmeResponder.GetExternalAccountKeys)) - r.MethodFunc("POST", "/acme/eab/{provisionerName}", acmeEABMiddleware(acmeResponder.CreateExternalAccountKey)) - r.MethodFunc("DELETE", "/acme/eab/{provisionerName}/{id}", acmeEABMiddleware(acmeResponder.DeleteExternalAccountKey)) - - // Policy - Authority - r.MethodFunc("GET", "/policy", authorityPolicyMiddleware(policyResponder.GetAuthorityPolicy)) - r.MethodFunc("POST", "/policy", authorityPolicyMiddleware(policyResponder.CreateAuthorityPolicy)) - r.MethodFunc("PUT", "/policy", authorityPolicyMiddleware(policyResponder.UpdateAuthorityPolicy)) - r.MethodFunc("DELETE", "/policy", authorityPolicyMiddleware(policyResponder.DeleteAuthorityPolicy)) - - // Policy - Provisioner - r.MethodFunc("GET", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(policyResponder.GetProvisionerPolicy)) - r.MethodFunc("POST", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(policyResponder.CreateProvisionerPolicy)) - r.MethodFunc("PUT", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(policyResponder.UpdateProvisionerPolicy)) - r.MethodFunc("DELETE", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(policyResponder.DeleteProvisionerPolicy)) - - // Policy - ACME Account - r.MethodFunc("GET", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(policyResponder.GetACMEAccountPolicy)) - r.MethodFunc("GET", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(policyResponder.GetACMEAccountPolicy)) - r.MethodFunc("POST", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(policyResponder.CreateACMEAccountPolicy)) - r.MethodFunc("POST", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(policyResponder.CreateACMEAccountPolicy)) - r.MethodFunc("PUT", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(policyResponder.UpdateACMEAccountPolicy)) - r.MethodFunc("PUT", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(policyResponder.UpdateACMEAccountPolicy)) - r.MethodFunc("DELETE", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(policyResponder.DeleteACMEAccountPolicy)) - r.MethodFunc("DELETE", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(policyResponder.DeleteACMEAccountPolicy)) + // ACME responder + if acmeResponder != nil { + // ACME External Account Binding Keys + r.MethodFunc("GET", "/acme/eab/{provisionerName}/{reference}", acmeEABMiddleware(acmeResponder.GetExternalAccountKeys)) + r.MethodFunc("GET", "/acme/eab/{provisionerName}", acmeEABMiddleware(acmeResponder.GetExternalAccountKeys)) + r.MethodFunc("POST", "/acme/eab/{provisionerName}", acmeEABMiddleware(acmeResponder.CreateExternalAccountKey)) + r.MethodFunc("DELETE", "/acme/eab/{provisionerName}/{id}", acmeEABMiddleware(acmeResponder.DeleteExternalAccountKey)) + } + // Policy responder + if policyResponder != nil { + // Policy - Authority + r.MethodFunc("GET", "/policy", authorityPolicyMiddleware(policyResponder.GetAuthorityPolicy)) + r.MethodFunc("POST", "/policy", authorityPolicyMiddleware(policyResponder.CreateAuthorityPolicy)) + r.MethodFunc("PUT", "/policy", authorityPolicyMiddleware(policyResponder.UpdateAuthorityPolicy)) + r.MethodFunc("DELETE", "/policy", authorityPolicyMiddleware(policyResponder.DeleteAuthorityPolicy)) + + // Policy - Provisioner + r.MethodFunc("GET", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(policyResponder.GetProvisionerPolicy)) + r.MethodFunc("POST", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(policyResponder.CreateProvisionerPolicy)) + r.MethodFunc("PUT", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(policyResponder.UpdateProvisionerPolicy)) + r.MethodFunc("DELETE", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(policyResponder.DeleteProvisionerPolicy)) + + // Policy - ACME Account + r.MethodFunc("GET", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(policyResponder.GetACMEAccountPolicy)) + r.MethodFunc("GET", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(policyResponder.GetACMEAccountPolicy)) + r.MethodFunc("POST", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(policyResponder.CreateACMEAccountPolicy)) + r.MethodFunc("POST", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(policyResponder.CreateACMEAccountPolicy)) + r.MethodFunc("PUT", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(policyResponder.UpdateACMEAccountPolicy)) + r.MethodFunc("PUT", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(policyResponder.UpdateACMEAccountPolicy)) + r.MethodFunc("DELETE", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(policyResponder.DeleteACMEAccountPolicy)) + r.MethodFunc("DELETE", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(policyResponder.DeleteACMEAccountPolicy)) + } } diff --git a/authority/admin/api/policy.go b/authority/admin/api/policy.go index 9f338c0b..a478c83c 100644 --- a/authority/admin/api/policy.go +++ b/authority/admin/api/policy.go @@ -32,7 +32,7 @@ type PolicyAdminResponder interface { DeleteACMEAccountPolicy(w http.ResponseWriter, r *http.Request) } -// policyAdminResponder is responsible for writing ACME admin responses. +// policyAdminResponder implements PolicyAdminResponder. type policyAdminResponder struct{} // NewACMEAdminResponder returns a new PolicyAdminResponder. From 894242297350750f33d134545213cb1e2a8c9a76 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Tue, 10 May 2022 16:51:09 -0700 Subject: [PATCH 33/40] Add GetID() and add authority to initial context --- authority/authority.go | 28 +++++++++++++++++++--------- authority/authority_test.go | 29 +++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 9 deletions(-) diff --git a/authority/authority.go b/authority/authority.go index cdf2c8bf..c184c6e9 100644 --- a/authority/authority.go +++ b/authority/authority.go @@ -250,6 +250,7 @@ func (a *Authority) init() error { } var err error + ctx := NewContext(context.Background(), a) // Set password if they are not set. var configPassword []byte @@ -285,7 +286,7 @@ func (a *Authority) init() error { if a.config.KMS != nil { options = *a.config.KMS } - a.keyManager, err = kms.New(context.Background(), options) + a.keyManager, err = kms.New(ctx, options) if err != nil { return err } @@ -315,7 +316,7 @@ func (a *Authority) init() error { // Configure linked RA if linkedcaClient != nil && options.CertificateAuthority == "" { - conf, err := linkedcaClient.GetConfiguration(context.Background()) + conf, err := linkedcaClient.GetConfiguration(ctx) if err != nil { return err } @@ -349,7 +350,7 @@ func (a *Authority) init() error { } } - a.x509CAService, err = cas.New(context.Background(), options) + a.x509CAService, err = cas.New(ctx, options) if err != nil { return err } @@ -536,7 +537,7 @@ func (a *Authority) init() error { } } - a.scepService, err = scep.NewService(context.Background(), options) + a.scepService, err = scep.NewService(ctx, options) if err != nil { return err } @@ -558,19 +559,19 @@ func (a *Authority) init() error { } } - provs, err := a.adminDB.GetProvisioners(context.Background()) + provs, err := a.adminDB.GetProvisioners(ctx) if err != nil { return admin.WrapErrorISE(err, "error loading provisioners to initialize authority") } if len(provs) == 0 && !strings.EqualFold(a.config.AuthorityConfig.DeploymentType, "linked") { // Create First Provisioner - prov, err := CreateFirstProvisioner(context.Background(), a.adminDB, string(a.password)) + prov, err := CreateFirstProvisioner(ctx, a.adminDB, string(a.password)) if err != nil { return admin.WrapErrorISE(err, "error creating first provisioner") } // Create first admin - if err := a.adminDB.CreateAdmin(context.Background(), &linkedca.Admin{ + if err := a.adminDB.CreateAdmin(ctx, &linkedca.Admin{ ProvisionerId: prov.Id, Subject: "step", Type: linkedca.Admin_SUPER_ADMIN, @@ -581,12 +582,12 @@ func (a *Authority) init() error { } // Load Provisioners and Admins - if err := a.reloadAdminResources(context.Background()); err != nil { + if err := a.reloadAdminResources(ctx); err != nil { return err } // Load x509 and SSH Policy Engines - if err := a.reloadPolicyEngines(context.Background()); err != nil { + if err := a.reloadPolicyEngines(ctx); err != nil { return err } @@ -611,6 +612,15 @@ func (a *Authority) init() error { return nil } +// GetID returns the define authority id or a zero uuid. +func (a *Authority) GetID() string { + const zeroUUID = "00000000-0000-0000-0000-000000000000" + if id := a.config.AuthorityConfig.AuthorityID; id != "" { + return id + } + return zeroUUID +} + // GetDatabase returns the authority database. If the configuration does not // define a database, GetDatabase will return a db.SimpleDB instance. func (a *Authority) GetDatabase() db.AuthDB { diff --git a/authority/authority_test.go b/authority/authority_test.go index 1f63333d..9f35f23e 100644 --- a/authority/authority_test.go +++ b/authority/authority_test.go @@ -14,6 +14,7 @@ import ( "github.com/pkg/errors" "github.com/smallstep/assert" + "github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/db" "go.step.sm/crypto/jose" @@ -421,3 +422,31 @@ func TestAuthority_GetSCEPService(t *testing.T) { }) } } + +func TestAuthority_GetID(t *testing.T) { + type fields struct { + authorityID string + } + tests := []struct { + name string + fields fields + want string + }{ + {"ok", fields{""}, "00000000-0000-0000-0000-000000000000"}, + {"ok with id", fields{"10b9a431-ed3b-4a5f-abee-ec35119b65e7"}, "10b9a431-ed3b-4a5f-abee-ec35119b65e7"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &Authority{ + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + AuthorityID: tt.fields.authorityID, + }, + }, + } + if got := a.GetID(); got != tt.want { + t.Errorf("Authority.GetID() = %v, want %v", got, tt.want) + } + }) + } +} From 400b1ece0bc19b217262712ea292cdeea44b0135 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Thu, 12 May 2022 17:39:36 -0700 Subject: [PATCH 34/40] Remove scep handler after merge. --- scep/api/api.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scep/api/api.go b/scep/api/api.go index f063df21..b738a933 100644 --- a/scep/api/api.go +++ b/scep/api/api.go @@ -109,7 +109,7 @@ func Get(w http.ResponseWriter, r *http.Request) { case opnGetCACaps: res, err = GetCACaps(ctx) case opnPKIOperation: - res, err = h.PKIOperation(ctx, req) + res, err = PKIOperation(ctx, req) default: err = fmt.Errorf("unknown operation: %s", req.Operation) } From dec1067addc5bd8f43fafeaca7cd4f6264738abf Mon Sep 17 00:00:00 2001 From: Erik De Lamarter Date: Mon, 25 Apr 2022 22:45:22 +0200 Subject: [PATCH 35/40] vault kubernetes auth --- cas/vaultcas/vaultcas.go | 70 ++++++++++++++++++++++++++-------------- go.mod | 1 + go.sum | 2 ++ 3 files changed, 48 insertions(+), 25 deletions(-) diff --git a/cas/vaultcas/vaultcas.go b/cas/vaultcas/vaultcas.go index c29ef691..8a09a850 100644 --- a/cas/vaultcas/vaultcas.go +++ b/cas/vaultcas/vaultcas.go @@ -18,6 +18,7 @@ import ( vault "github.com/hashicorp/vault/api" auth "github.com/hashicorp/vault/api/auth/approle" + kubeauth "github.com/hashicorp/vault/api/auth/kubernetes" ) func init() { @@ -34,6 +35,7 @@ type VaultOptions struct { PKIRoleRSA string `json:"pkiRoleRSA,omitempty"` PKIRoleEC string `json:"pkiRoleEC,omitempty"` PKIRoleEd25519 string `json:"pkiRoleEd25519,omitempty"` + KubernetesRole string `json:"kubernetesRole,omitempty"` RoleID string `json:"roleID,omitempty"` SecretID auth.SecretID `json:"secretID,omitempty"` AppRole string `json:"appRole,omitempty"` @@ -77,31 +79,49 @@ func New(ctx context.Context, opts apiv1.Options) (*VaultCAS, error) { return nil, fmt.Errorf("unable to initialize vault client: %w", err) } - var appRoleAuth *auth.AppRoleAuth - if vc.IsWrappingToken { - appRoleAuth, err = auth.NewAppRoleAuth( - vc.RoleID, - &vc.SecretID, - auth.WithWrappingToken(), - auth.WithMountPath(vc.AppRole), + if vc.KubernetesRole != "" { + var kubernetesAuth *kubeauth.KubernetesAuth + kubernetesAuth, err = kubeauth.NewKubernetesAuth( + vc.KubernetesRole, ) + if err != nil { + return nil, fmt.Errorf("unable to initialize Kubernetes auth method: %w", err) + } + + authInfo, err := client.Auth().Login(ctx, kubernetesAuth) + if err != nil { + return nil, fmt.Errorf("unable to login to Kubernetes auth method: %w", err) + } + if authInfo == nil { + return nil, errors.New("no auth info was returned after login") + } } else { - appRoleAuth, err = auth.NewAppRoleAuth( - vc.RoleID, - &vc.SecretID, - auth.WithMountPath(vc.AppRole), - ) - } - if err != nil { - return nil, fmt.Errorf("unable to initialize AppRole auth method: %w", err) - } + var appRoleAuth *auth.AppRoleAuth + if vc.IsWrappingToken { + appRoleAuth, err = auth.NewAppRoleAuth( + vc.RoleID, + &vc.SecretID, + auth.WithWrappingToken(), + auth.WithMountPath(vc.AppRole), + ) + } else { + appRoleAuth, err = auth.NewAppRoleAuth( + vc.RoleID, + &vc.SecretID, + auth.WithMountPath(vc.AppRole), + ) + } + if err != nil { + return nil, fmt.Errorf("unable to initialize AppRole auth method: %w", err) + } - authInfo, err := client.Auth().Login(ctx, appRoleAuth) - if err != nil { - return nil, fmt.Errorf("unable to login to AppRole auth method: %w", err) - } - if authInfo == nil { - return nil, errors.New("no auth info was returned after login") + authInfo, err := client.Auth().Login(ctx, appRoleAuth) + if err != nil { + return nil, fmt.Errorf("unable to login to AppRole auth method: %w", err) + } + if authInfo == nil { + return nil, errors.New("no auth info was returned after login") + } } return &VaultCAS{ @@ -272,11 +292,11 @@ func loadOptions(config json.RawMessage) (*VaultOptions, error) { vc.PKIRoleEd25519 = vc.PKIRoleDefault } - if vc.RoleID == "" { - return nil, errors.New("vaultCAS config options must define `roleID`") + if vc.RoleID == "" && vc.KubernetesRole == "" { + return nil, errors.New("vaultCAS config options must define `roleID` or `kubernetesRole`") } - if vc.SecretID.FromEnv == "" && vc.SecretID.FromFile == "" && vc.SecretID.FromString == "" { + if vc.SecretID.FromEnv == "" && vc.SecretID.FromFile == "" && vc.SecretID.FromString == "" && vc.RoleID != "" { return nil, errors.New("vaultCAS config options must define `secretID` object with one of `FromEnv`, `FromFile` or `FromString`") } diff --git a/go.mod b/go.mod index 8b66f470..0b772018 100644 --- a/go.mod +++ b/go.mod @@ -29,6 +29,7 @@ require ( github.com/googleapis/gax-go/v2 v2.1.1 github.com/hashicorp/vault/api v1.3.1 github.com/hashicorp/vault/api/auth/approle v0.1.1 + github.com/hashicorp/vault/api/auth/kubernetes v0.1.0 github.com/jhump/protoreflect v1.9.0 // indirect github.com/mattn/go-colorable v0.1.8 // indirect github.com/mattn/go-isatty v0.0.13 // indirect diff --git a/go.sum b/go.sum index 4780111e..d76648c2 100644 --- a/go.sum +++ b/go.sum @@ -449,6 +449,8 @@ github.com/hashicorp/vault/api v1.3.1 h1:pkDkcgTh47PRjY1NEFeofqR4W/HkNUi9qIakESO github.com/hashicorp/vault/api v1.3.1/go.mod h1:QeJoWxMFt+MsuWcYhmwRLwKEXrjwAFFywzhptMsTIUw= github.com/hashicorp/vault/api/auth/approle v0.1.1 h1:R5yA+xcNvw1ix6bDuWOaLOq2L4L77zDCVsethNw97xQ= github.com/hashicorp/vault/api/auth/approle v0.1.1/go.mod h1:mHOLgh//xDx4dpqXoq6tS8Ob0FoCFWLU2ibJ26Lfmag= +github.com/hashicorp/vault/api/auth/kubernetes v0.1.0 h1:6BtyahbF4aQp8gg3ww0A/oIoqzbhpNP1spXU3nHE0n0= +github.com/hashicorp/vault/api/auth/kubernetes v0.1.0/go.mod h1:Pdgk78uIs0mgDOLvc3a+h/vYIT9rznw2sz+ucuH9024= github.com/hashicorp/vault/sdk v0.3.0 h1:kR3dpxNkhh/wr6ycaJYqp6AFT/i2xaftbfnwZduTKEY= github.com/hashicorp/vault/sdk v0.3.0/go.mod h1:aZ3fNuL5VNydQk8GcLJ2TV8YCRVvyaakYkhZRoVuhj0= github.com/hashicorp/yamux v0.0.0-20180604194846-3520598351bb h1:b5rjCoWHc7eqmAS4/qyk21ZsHyb6Mxv/jykxvNTkU4M= From 6c44291d8df63e16e662a9cc03ffa8783fa364ce Mon Sep 17 00:00:00 2001 From: Erik De Lamarter Date: Mon, 9 May 2022 13:27:37 +0200 Subject: [PATCH 36/40] refactor vault auth --- cas/vaultcas/auth/approle/approle.go | 46 ++++ cas/vaultcas/auth/approle/approle_test.go | 16 ++ cas/vaultcas/auth/kubernetes/kubernetes.go | 43 +++ .../auth/kubernetes/kubernetes_test.go | 21 ++ cas/vaultcas/auth/kubernetes/token | 1 + cas/vaultcas/vaultcas.go | 120 +++----- cas/vaultcas/vaultcas_test.go | 256 ++++-------------- 7 files changed, 220 insertions(+), 283 deletions(-) create mode 100644 cas/vaultcas/auth/approle/approle.go create mode 100644 cas/vaultcas/auth/approle/approle_test.go create mode 100644 cas/vaultcas/auth/kubernetes/kubernetes.go create mode 100644 cas/vaultcas/auth/kubernetes/kubernetes_test.go create mode 100644 cas/vaultcas/auth/kubernetes/token diff --git a/cas/vaultcas/auth/approle/approle.go b/cas/vaultcas/auth/approle/approle.go new file mode 100644 index 00000000..38d3c51c --- /dev/null +++ b/cas/vaultcas/auth/approle/approle.go @@ -0,0 +1,46 @@ +package approle + +import ( + "encoding/json" + "fmt" + + "github.com/hashicorp/vault/api/auth/approle" +) + +// AuthOptions defines the configuration options added using the +// VaultOptions.AuthOptions field when AuthType is approle +type AuthOptions struct { + RoleID string `json:"roleID,omitempty"` + SecretID string `json:"secretID,omitempty"` + IsWrappingToken bool `json:"isWrappingToken,omitempty"` +} + +func NewApproleAuthMethod(mountPath string, options json.RawMessage) (*approle.AppRoleAuth, error) { + var opts *AuthOptions + + err := json.Unmarshal(options, &opts) + if err != nil { + return nil, fmt.Errorf("error decoding AppRole auth options: %w", err) + } + + var approleAuth *approle.AppRoleAuth + + var loginOptions []approle.LoginOption + if mountPath != "" { + loginOptions = append(loginOptions, approle.WithMountPath(mountPath)) + } + if opts.IsWrappingToken { + loginOptions = append(loginOptions, approle.WithWrappingToken()) + } + + sid := approle.SecretID{ + FromString: opts.SecretID, + } + + approleAuth, err = approle.NewAppRoleAuth(opts.RoleID, &sid, loginOptions...) + if err != nil { + return nil, fmt.Errorf("unable to initialize Kubernetes auth method: %w", err) + } + + return approleAuth, nil +} diff --git a/cas/vaultcas/auth/approle/approle_test.go b/cas/vaultcas/auth/approle/approle_test.go new file mode 100644 index 00000000..ab7e6a97 --- /dev/null +++ b/cas/vaultcas/auth/approle/approle_test.go @@ -0,0 +1,16 @@ +package approle + +import ( + "encoding/json" + "testing" +) + +func TestKubernetes_NewKubernetesAuthMethod(t *testing.T) { + mountPath := "approle" + raw := `{"roleID": "roleID", "secretID": "secretIDwrapped", "isWrappedToken": true}` + + _, err := NewApproleAuthMethod(mountPath, json.RawMessage(raw)) + if err != nil { + t.Fatal(err) + } +} diff --git a/cas/vaultcas/auth/kubernetes/kubernetes.go b/cas/vaultcas/auth/kubernetes/kubernetes.go new file mode 100644 index 00000000..0c4db62f --- /dev/null +++ b/cas/vaultcas/auth/kubernetes/kubernetes.go @@ -0,0 +1,43 @@ +package kubernetes + +import ( + "encoding/json" + "fmt" + + "github.com/hashicorp/vault/api/auth/kubernetes" +) + +// AuthOptions defines the configuration options added using the +// VaultOptions.AuthOptions field when AuthType is kubernetes +type AuthOptions struct { + Role string `json:"role,omitempty"` + TokenPath string `json:"tokenPath,omitempty"` +} + +func NewKubernetesAuthMethod(mountPath string, options json.RawMessage) (*kubernetes.KubernetesAuth, error) { + var opts *AuthOptions + + err := json.Unmarshal(options, &opts) + if err != nil { + return nil, fmt.Errorf("error decoding Kubernetes auth options: %w", err) + } + + var kubernetesAuth *kubernetes.KubernetesAuth + + var loginOptions []kubernetes.LoginOption + if mountPath != "" { + loginOptions = append(loginOptions, kubernetes.WithMountPath(mountPath)) + } + if opts.TokenPath != "" { + loginOptions = append(loginOptions, kubernetes.WithServiceAccountTokenPath(opts.TokenPath)) + } + kubernetesAuth, err = kubernetes.NewKubernetesAuth( + opts.Role, + loginOptions..., + ) + if err != nil { + return nil, fmt.Errorf("unable to initialize Kubernetes auth method: %w", err) + } + + return kubernetesAuth, nil +} diff --git a/cas/vaultcas/auth/kubernetes/kubernetes_test.go b/cas/vaultcas/auth/kubernetes/kubernetes_test.go new file mode 100644 index 00000000..604f1898 --- /dev/null +++ b/cas/vaultcas/auth/kubernetes/kubernetes_test.go @@ -0,0 +1,21 @@ +package kubernetes + +import ( + "encoding/json" + "path" + "path/filepath" + "runtime" + "testing" +) + +func TestKubernetes_NewKubernetesAuthMethod(t *testing.T) { + _, filename, _, _ := runtime.Caller(0) + tokenPath := filepath.Join(path.Dir(filename), "token") + mountPath := "kubernetes" + raw := `{"role": "SomeRoleName", "tokenPath": "` + tokenPath + `"}` + + _, err := NewKubernetesAuthMethod(mountPath, json.RawMessage(raw)) + if err != nil { + t.Fatal(err) + } +} diff --git a/cas/vaultcas/auth/kubernetes/token b/cas/vaultcas/auth/kubernetes/token new file mode 100644 index 00000000..6745be67 --- /dev/null +++ b/cas/vaultcas/auth/kubernetes/token @@ -0,0 +1 @@ +token \ No newline at end of file diff --git a/cas/vaultcas/vaultcas.go b/cas/vaultcas/vaultcas.go index 8a09a850..02c814b7 100644 --- a/cas/vaultcas/vaultcas.go +++ b/cas/vaultcas/vaultcas.go @@ -15,10 +15,10 @@ import ( "time" "github.com/smallstep/certificates/cas/apiv1" + "github.com/smallstep/certificates/cas/vaultcas/auth/approle" + "github.com/smallstep/certificates/cas/vaultcas/auth/kubernetes" vault "github.com/hashicorp/vault/api" - auth "github.com/hashicorp/vault/api/auth/approle" - kubeauth "github.com/hashicorp/vault/api/auth/kubernetes" ) func init() { @@ -30,16 +30,14 @@ func init() { // VaultOptions defines the configuration options added using the // apiv1.Options.Config field. type VaultOptions struct { - PKI string `json:"pki,omitempty"` - PKIRoleDefault string `json:"pkiRoleDefault,omitempty"` - PKIRoleRSA string `json:"pkiRoleRSA,omitempty"` - PKIRoleEC string `json:"pkiRoleEC,omitempty"` - PKIRoleEd25519 string `json:"pkiRoleEd25519,omitempty"` - KubernetesRole string `json:"kubernetesRole,omitempty"` - RoleID string `json:"roleID,omitempty"` - SecretID auth.SecretID `json:"secretID,omitempty"` - AppRole string `json:"appRole,omitempty"` - IsWrappingToken bool `json:"isWrappingToken,omitempty"` + PKIMountPath string `json:"pkiMountPath,omitempty"` + PKIRoleDefault string `json:"pkiRoleDefault,omitempty"` + PKIRoleRSA string `json:"pkiRoleRSA,omitempty"` + PKIRoleEC string `json:"pkiRoleEC,omitempty"` + PKIRoleEd25519 string `json:"pkiRoleEd25519,omitempty"` + AuthType string `json:"authType,omitempty"` + AuthMountPath string `json:"authMountPath,omitempty"` + AuthOptions json.RawMessage `json:"authOptions,omitempty"` } // VaultCAS implements a Certificate Authority Service using Hashicorp Vault. @@ -79,49 +77,25 @@ func New(ctx context.Context, opts apiv1.Options) (*VaultCAS, error) { return nil, fmt.Errorf("unable to initialize vault client: %w", err) } - if vc.KubernetesRole != "" { - var kubernetesAuth *kubeauth.KubernetesAuth - kubernetesAuth, err = kubeauth.NewKubernetesAuth( - vc.KubernetesRole, - ) - if err != nil { - return nil, fmt.Errorf("unable to initialize Kubernetes auth method: %w", err) - } - - authInfo, err := client.Auth().Login(ctx, kubernetesAuth) - if err != nil { - return nil, fmt.Errorf("unable to login to Kubernetes auth method: %w", err) - } - if authInfo == nil { - return nil, errors.New("no auth info was returned after login") - } - } else { - var appRoleAuth *auth.AppRoleAuth - if vc.IsWrappingToken { - appRoleAuth, err = auth.NewAppRoleAuth( - vc.RoleID, - &vc.SecretID, - auth.WithWrappingToken(), - auth.WithMountPath(vc.AppRole), - ) - } else { - appRoleAuth, err = auth.NewAppRoleAuth( - vc.RoleID, - &vc.SecretID, - auth.WithMountPath(vc.AppRole), - ) - } - if err != nil { - return nil, fmt.Errorf("unable to initialize AppRole auth method: %w", err) - } + var method vault.AuthMethod + switch vc.AuthType { + case "kubernetes": + method, err = kubernetes.NewKubernetesAuthMethod(vc.AuthMountPath, vc.AuthOptions) + case "approle": + method, err = approle.NewApproleAuthMethod(vc.AuthMountPath, vc.AuthOptions) + default: + return nil, fmt.Errorf("unknown auth type: %v", vc.AuthType) + } + if err != nil { + return nil, fmt.Errorf("unable to configure auth method: %w", err) + } - authInfo, err := client.Auth().Login(ctx, appRoleAuth) - if err != nil { - return nil, fmt.Errorf("unable to login to AppRole auth method: %w", err) - } - if authInfo == nil { - return nil, errors.New("no auth info was returned after login") - } + authInfo, err := client.Auth().Login(ctx, method) + if err != nil { + return nil, fmt.Errorf("unable to login to Kubernetes auth method: %w", err) + } + if authInfo == nil { + return nil, errors.New("no auth info was returned after login") } return &VaultCAS{ @@ -154,7 +128,7 @@ func (v *VaultCAS) CreateCertificate(req *apiv1.CreateCertificateRequest) (*apiv // GetCertificateAuthority returns the root certificate of the certificate // authority using the configured fingerprint. func (v *VaultCAS) GetCertificateAuthority(req *apiv1.GetCertificateAuthorityRequest) (*apiv1.GetCertificateAuthorityResponse, error) { - secret, err := v.client.Logical().Read(v.config.PKI + "/cert/ca_chain") + secret, err := v.client.Logical().Read(v.config.PKIMountPath + "/cert/ca_chain") if err != nil { return nil, fmt.Errorf("error reading ca chain: %w", err) } @@ -210,7 +184,7 @@ func (v *VaultCAS) RevokeCertificate(req *apiv1.RevokeCertificateRequest) (*apiv vaultReq := map[string]interface{}{ "serial_number": formatSerialNumber(sn), } - _, err := v.client.Logical().Write(v.config.PKI+"/revoke/", vaultReq) + _, err := v.client.Logical().Write(v.config.PKIMountPath+"/revoke/", vaultReq) if err != nil { return nil, fmt.Errorf("error revoking certificate: %w", err) } @@ -244,7 +218,7 @@ func (v *VaultCAS) createCertificate(cr *x509.CertificateRequest, lifetime time. "ttl": lifetime.Seconds(), } - secret, err := v.client.Logical().Write(v.config.PKI+"/sign/"+vaultPKIRole, vaultReq) + secret, err := v.client.Logical().Write(v.config.PKIMountPath+"/sign/"+vaultPKIRole, vaultReq) if err != nil { return nil, nil, fmt.Errorf("error signing certificate: %w", err) } @@ -267,21 +241,17 @@ func (v *VaultCAS) createCertificate(cr *x509.CertificateRequest, lifetime time. } func loadOptions(config json.RawMessage) (*VaultOptions, error) { - var vc *VaultOptions + // setup default values + vc := VaultOptions{ + PKIMountPath: "pki", + PKIRoleDefault: "default", + } err := json.Unmarshal(config, &vc) if err != nil { return nil, fmt.Errorf("error decoding vaultCAS config: %w", err) } - if vc.PKI == "" { - vc.PKI = "pki" // use default pki vault name - } - - if vc.PKIRoleDefault == "" { - vc.PKIRoleDefault = "default" // use default pki role name - } - if vc.PKIRoleRSA == "" { vc.PKIRoleRSA = vc.PKIRoleDefault } @@ -292,23 +262,7 @@ func loadOptions(config json.RawMessage) (*VaultOptions, error) { vc.PKIRoleEd25519 = vc.PKIRoleDefault } - if vc.RoleID == "" && vc.KubernetesRole == "" { - return nil, errors.New("vaultCAS config options must define `roleID` or `kubernetesRole`") - } - - if vc.SecretID.FromEnv == "" && vc.SecretID.FromFile == "" && vc.SecretID.FromString == "" && vc.RoleID != "" { - return nil, errors.New("vaultCAS config options must define `secretID` object with one of `FromEnv`, `FromFile` or `FromString`") - } - - if vc.PKI == "" { - vc.PKI = "pki" // use default pki vault name - } - - if vc.AppRole == "" { - vc.AppRole = "auth/approle" - } - - return vc, nil + return &vc, nil } func parseCertificates(pemCert string) []*x509.Certificate { diff --git a/cas/vaultcas/vaultcas_test.go b/cas/vaultcas/vaultcas_test.go index 9f73a1ee..3c1f09a3 100644 --- a/cas/vaultcas/vaultcas_test.go +++ b/cas/vaultcas/vaultcas_test.go @@ -14,7 +14,6 @@ import ( "time" vault "github.com/hashicorp/vault/api" - auth "github.com/hashicorp/vault/api/auth/approle" "github.com/smallstep/certificates/cas/apiv1" "go.step.sm/crypto/pemutil" ) @@ -99,7 +98,7 @@ func testCAHelper(t *testing.T) (*url.URL, *vault.Client) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch { - case r.RequestURI == "/v1/auth/auth/approle/login": + case r.RequestURI == "/v1/auth/approle/login": w.WriteHeader(http.StatusOK) fmt.Fprintf(w, `{ "auth": { @@ -183,11 +182,10 @@ func TestNew_register(t *testing.T) { CertificateAuthority: caURL.String(), CertificateAuthorityFingerprint: testRootFingerprint, Config: json.RawMessage(`{ - "PKI": "pki", + "PKIMountPath": "pki", "PKIRoleDefault": "pki-role", - "RoleID": "roleID", - "SecretID": {"FromString": "secretID"}, - "IsWrappingToken": false + "AuthType": "approle", + "AuthOptions": {"RoleID":"roleID","SecretID":"secretID","IsWrappingToken":false} }`), }) @@ -201,15 +199,13 @@ func TestVaultCAS_CreateCertificate(t *testing.T) { _, client := testCAHelper(t) options := VaultOptions{ - PKI: "pki", - PKIRoleDefault: "role", - PKIRoleRSA: "rsa", - PKIRoleEC: "ec", - PKIRoleEd25519: "ed25519", - RoleID: "roleID", - SecretID: auth.SecretID{FromString: "secretID"}, - AppRole: "approle", - IsWrappingToken: false, + PKIMountPath: "pki", + PKIRoleDefault: "role", + PKIRoleRSA: "rsa", + PKIRoleEC: "ec", + PKIRoleEd25519: "ed25519", + AuthType: "approle", + AuthOptions: json.RawMessage(`{"RoleID":"roleID","SecretID":"secretID","IsWrappingToken":false}`), } type fields struct { @@ -291,7 +287,7 @@ func TestVaultCAS_GetCertificateAuthority(t *testing.T) { } options := VaultOptions{ - PKI: "pki", + PKIMountPath: "pki", } rootCert := parseCertificates(testRootCertificate)[0] @@ -335,15 +331,13 @@ func TestVaultCAS_RevokeCertificate(t *testing.T) { _, client := testCAHelper(t) options := VaultOptions{ - PKI: "pki", - PKIRoleDefault: "role", - PKIRoleRSA: "rsa", - PKIRoleEC: "ec", - PKIRoleEd25519: "ed25519", - RoleID: "roleID", - SecretID: auth.SecretID{FromString: "secretID"}, - AppRole: "approle", - IsWrappingToken: false, + PKIMountPath: "pki", + PKIRoleDefault: "role", + PKIRoleRSA: "rsa", + PKIRoleEC: "ec", + PKIRoleEd25519: "ed25519", + AuthType: "approle", + AuthOptions: json.RawMessage(`{"RoleID":"roleID","SecretID":"secretID","IsWrappingToken":false}`), } type fields struct { @@ -407,15 +401,13 @@ func TestVaultCAS_RenewCertificate(t *testing.T) { _, client := testCAHelper(t) options := VaultOptions{ - PKI: "pki", - PKIRoleDefault: "role", - PKIRoleRSA: "rsa", - PKIRoleEC: "ec", - PKIRoleEd25519: "ed25519", - RoleID: "roleID", - SecretID: auth.SecretID{FromString: "secretID"}, - AppRole: "approle", - IsWrappingToken: false, + PKIMountPath: "pki", + PKIRoleDefault: "role", + PKIRoleRSA: "rsa", + PKIRoleEC: "ec", + PKIRoleEd25519: "ed25519", + AuthType: "approle", + AuthOptions: json.RawMessage(`{"RoleID":"roleID","SecretID":"secretID","IsWrappingToken":false}`), } type fields struct { @@ -464,202 +456,66 @@ func TestVaultCAS_loadOptions(t *testing.T) { want *VaultOptions wantErr bool }{ - { - "ok mandatory with SecretID FromString", - `{"RoleID": "roleID", "SecretID": {"FromString": "secretID"}}`, - &VaultOptions{ - PKI: "pki", - PKIRoleDefault: "default", - PKIRoleRSA: "default", - PKIRoleEC: "default", - PKIRoleEd25519: "default", - RoleID: "roleID", - SecretID: auth.SecretID{FromString: "secretID"}, - AppRole: "auth/approle", - IsWrappingToken: false, - }, - false, - }, - { - "ok mandatory with SecretID FromFile", - `{"RoleID": "roleID", "SecretID": {"FromFile": "secretID"}}`, - &VaultOptions{ - PKI: "pki", - PKIRoleDefault: "default", - PKIRoleRSA: "default", - PKIRoleEC: "default", - PKIRoleEd25519: "default", - RoleID: "roleID", - SecretID: auth.SecretID{FromFile: "secretID"}, - AppRole: "auth/approle", - IsWrappingToken: false, - }, - false, - }, - { - "ok mandatory with SecretID FromEnv", - `{"RoleID": "roleID", "SecretID": {"FromEnv": "secretID"}}`, - &VaultOptions{ - PKI: "pki", - PKIRoleDefault: "default", - PKIRoleRSA: "default", - PKIRoleEC: "default", - PKIRoleEd25519: "default", - RoleID: "roleID", - SecretID: auth.SecretID{FromEnv: "secretID"}, - AppRole: "auth/approle", - IsWrappingToken: false, - }, - false, - }, { "ok mandatory PKIRole PKIRoleEd25519", - `{"PKIRoleDefault": "role", "PKIRoleEd25519": "ed25519" , "RoleID": "roleID", "SecretID": {"FromEnv": "secretID"}}`, + `{"PKIRoleDefault": "role", "PKIRoleEd25519": "ed25519"}`, &VaultOptions{ - PKI: "pki", - PKIRoleDefault: "role", - PKIRoleRSA: "role", - PKIRoleEC: "role", - PKIRoleEd25519: "ed25519", - RoleID: "roleID", - SecretID: auth.SecretID{FromEnv: "secretID"}, - AppRole: "auth/approle", - IsWrappingToken: false, + PKIMountPath: "pki", + PKIRoleDefault: "role", + PKIRoleRSA: "role", + PKIRoleEC: "role", + PKIRoleEd25519: "ed25519", }, false, }, { "ok mandatory PKIRole PKIRoleEC", - `{"PKIRoleDefault": "role", "PKIRoleEC": "ec" , "RoleID": "roleID", "SecretID": {"FromEnv": "secretID"}}`, + `{"PKIRoleDefault": "role", "PKIRoleEC": "ec"}`, &VaultOptions{ - PKI: "pki", - PKIRoleDefault: "role", - PKIRoleRSA: "role", - PKIRoleEC: "ec", - PKIRoleEd25519: "role", - RoleID: "roleID", - SecretID: auth.SecretID{FromEnv: "secretID"}, - AppRole: "auth/approle", - IsWrappingToken: false, + PKIMountPath: "pki", + PKIRoleDefault: "role", + PKIRoleRSA: "role", + PKIRoleEC: "ec", + PKIRoleEd25519: "role", }, false, }, { "ok mandatory PKIRole PKIRoleRSA", - `{"PKIRoleDefault": "role", "PKIRoleRSA": "rsa" , "RoleID": "roleID", "SecretID": {"FromEnv": "secretID"}}`, + `{"PKIRoleDefault": "role", "PKIRoleRSA": "rsa"}`, &VaultOptions{ - PKI: "pki", - PKIRoleDefault: "role", - PKIRoleRSA: "rsa", - PKIRoleEC: "role", - PKIRoleEd25519: "role", - RoleID: "roleID", - SecretID: auth.SecretID{FromEnv: "secretID"}, - AppRole: "auth/approle", - IsWrappingToken: false, + PKIMountPath: "pki", + PKIRoleDefault: "role", + PKIRoleRSA: "rsa", + PKIRoleEC: "role", + PKIRoleEd25519: "role", }, false, }, { "ok mandatory PKIRoleRSA PKIRoleEC PKIRoleEd25519", - `{"PKIRoleRSA": "rsa", "PKIRoleEC": "ec", "PKIRoleEd25519": "ed25519", "RoleID": "roleID", "SecretID": {"FromEnv": "secretID"}}`, + `{"PKIRoleRSA": "rsa", "PKIRoleEC": "ec", "PKIRoleEd25519": "ed25519"}`, &VaultOptions{ - PKI: "pki", - PKIRoleDefault: "default", - PKIRoleRSA: "rsa", - PKIRoleEC: "ec", - PKIRoleEd25519: "ed25519", - RoleID: "roleID", - SecretID: auth.SecretID{FromEnv: "secretID"}, - AppRole: "auth/approle", - IsWrappingToken: false, + PKIMountPath: "pki", + PKIRoleDefault: "default", + PKIRoleRSA: "rsa", + PKIRoleEC: "ec", + PKIRoleEd25519: "ed25519", }, false, }, { "ok mandatory PKIRoleRSA PKIRoleEC PKIRoleEd25519 with useless PKIRoleDefault", - `{"PKIRoleDefault": "role", "PKIRoleRSA": "rsa", "PKIRoleEC": "ec", "PKIRoleEd25519": "ed25519", "RoleID": "roleID", "SecretID": {"FromEnv": "secretID"}}`, + `{"PKIRoleDefault": "role", "PKIRoleRSA": "rsa", "PKIRoleEC": "ec", "PKIRoleEd25519": "ed25519"}`, &VaultOptions{ - PKI: "pki", - PKIRoleDefault: "role", - PKIRoleRSA: "rsa", - PKIRoleEC: "ec", - PKIRoleEd25519: "ed25519", - RoleID: "roleID", - SecretID: auth.SecretID{FromEnv: "secretID"}, - AppRole: "auth/approle", - IsWrappingToken: false, + PKIMountPath: "pki", + PKIRoleDefault: "role", + PKIRoleRSA: "rsa", + PKIRoleEC: "ec", + PKIRoleEd25519: "ed25519", }, false, }, - { - "ok mandatory with AppRole", - `{"AppRole": "test", "RoleID": "roleID", "SecretID": {"FromString": "secretID"}}`, - &VaultOptions{ - PKI: "pki", - PKIRoleDefault: "default", - PKIRoleRSA: "default", - PKIRoleEC: "default", - PKIRoleEd25519: "default", - RoleID: "roleID", - SecretID: auth.SecretID{FromString: "secretID"}, - AppRole: "test", - IsWrappingToken: false, - }, - false, - }, - { - "ok mandatory with IsWrappingToken", - `{"IsWrappingToken": true, "RoleID": "roleID", "SecretID": {"FromString": "secretID"}}`, - &VaultOptions{ - PKI: "pki", - PKIRoleDefault: "default", - PKIRoleRSA: "default", - PKIRoleEC: "default", - PKIRoleEd25519: "default", - RoleID: "roleID", - SecretID: auth.SecretID{FromString: "secretID"}, - AppRole: "auth/approle", - IsWrappingToken: true, - }, - false, - }, - { - "fail with SecretID FromFail", - `{"RoleID": "roleID", "SecretID": {"FromFail": "secretID"}}`, - nil, - true, - }, - { - "fail with SecretID empty FromEnv", - `{"RoleID": "roleID", "SecretID": {"FromEnv": ""}}`, - nil, - true, - }, - { - "fail with SecretID empty FromFile", - `{"RoleID": "roleID", "SecretID": {"FromFile": ""}}`, - nil, - true, - }, - { - "fail with SecretID empty FromString", - `{"RoleID": "roleID", "SecretID": {"FromString": ""}}`, - nil, - true, - }, - { - "fail mandatory with SecretID FromFail", - `{"RoleID": "roleID", "SecretID": {"FromFail": "secretID"}}`, - nil, - true, - }, - { - "fail missing RoleID", - `{"SecretID": {"FromString": "secretID"}}`, - nil, - true, - }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { From 6989c7f146c534df3ebb9fbac5c92e9735fde882 Mon Sep 17 00:00:00 2001 From: Erik De Lamarter Date: Sun, 15 May 2022 17:42:08 +0200 Subject: [PATCH 37/40] vault auth unit tests --- cas/vaultcas/auth/approle/approle.go | 24 ++- cas/vaultcas/auth/approle/approle_test.go | 163 +++++++++++++++++- cas/vaultcas/auth/kubernetes/kubernetes.go | 6 + .../auth/kubernetes/kubernetes_test.go | 140 ++++++++++++++- cas/vaultcas/vaultcas_test.go | 8 - 5 files changed, 321 insertions(+), 20 deletions(-) diff --git a/cas/vaultcas/auth/approle/approle.go b/cas/vaultcas/auth/approle/approle.go index 38d3c51c..d842bae0 100644 --- a/cas/vaultcas/auth/approle/approle.go +++ b/cas/vaultcas/auth/approle/approle.go @@ -2,6 +2,7 @@ package approle import ( "encoding/json" + "errors" "fmt" "github.com/hashicorp/vault/api/auth/approle" @@ -12,6 +13,8 @@ import ( type AuthOptions struct { RoleID string `json:"roleID,omitempty"` SecretID string `json:"secretID,omitempty"` + SecretIDFile string `json:"secretIDFile,omitempty"` + SecretIDEnv string `json:"secretIDEnv,omitempty"` IsWrappingToken bool `json:"isWrappingToken,omitempty"` } @@ -33,8 +36,25 @@ func NewApproleAuthMethod(mountPath string, options json.RawMessage) (*approle.A loginOptions = append(loginOptions, approle.WithWrappingToken()) } - sid := approle.SecretID{ - FromString: opts.SecretID, + if opts.RoleID == "" { + return nil, errors.New("you must set roleID") + } + + var sid approle.SecretID + if opts.SecretID != "" { + sid = approle.SecretID{ + FromString: opts.SecretID, + } + } else if opts.SecretIDFile != "" { + sid = approle.SecretID{ + FromFile: opts.SecretIDFile, + } + } else if opts.SecretIDEnv != "" { + sid = approle.SecretID{ + FromEnv: opts.SecretIDEnv, + } + } else { + return nil, errors.New("you must set one of secretID, secretIDFile or secretIDEnv") } approleAuth, err = approle.NewAppRoleAuth(opts.RoleID, &sid, loginOptions...) diff --git a/cas/vaultcas/auth/approle/approle_test.go b/cas/vaultcas/auth/approle/approle_test.go index ab7e6a97..ec4d523f 100644 --- a/cas/vaultcas/auth/approle/approle_test.go +++ b/cas/vaultcas/auth/approle/approle_test.go @@ -1,16 +1,171 @@ package approle import ( + "context" "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "net/url" "testing" + + vault "github.com/hashicorp/vault/api" ) -func TestKubernetes_NewKubernetesAuthMethod(t *testing.T) { - mountPath := "approle" - raw := `{"roleID": "roleID", "secretID": "secretIDwrapped", "isWrappedToken": true}` +func testCAHelper(t *testing.T) (*url.URL, *vault.Client) { + t.Helper() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case r.RequestURI == "/v1/auth/approle/login": + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, `{ + "auth": { + "client_token": "hvs.0000" + } + }`) + case r.RequestURI == "/v1/auth/custom-approle/login": + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, `{ + "auth": { + "client_token": "hvs.9999" + } + }`) + default: + w.WriteHeader(http.StatusNotFound) + fmt.Fprintf(w, `{"error":"not found"}`) + } + })) + t.Cleanup(func() { + srv.Close() + }) + u, err := url.Parse(srv.URL) + if err != nil { + srv.Close() + t.Fatal(err) + } + + config := vault.DefaultConfig() + config.Address = srv.URL - _, err := NewApproleAuthMethod(mountPath, json.RawMessage(raw)) + client, err := vault.NewClient(config) if err != nil { + srv.Close() t.Fatal(err) } + + return u, client +} + +func TestApprole_LoginMountPaths(t *testing.T) { + caURL, _ := testCAHelper(t) + + config := vault.DefaultConfig() + config.Address = caURL.String() + client, _ := vault.NewClient(config) + + tests := []struct { + name string + mountPath string + token string + }{ + { + name: "ok default mount path", + mountPath: "", + token: "hvs.0000", + }, + { + name: "ok explicit mount path", + mountPath: "approle", + token: "hvs.0000", + }, + { + name: "ok custom mount path", + mountPath: "custom-approle", + token: "hvs.9999", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + method, err := NewApproleAuthMethod(tt.mountPath, json.RawMessage(`{"RoleID":"roleID","SecretID":"secretID","IsWrappingToken":false}`)) + if err != nil { + t.Errorf("NewApproleAuthMethod() error = %v", err) + return + } + + secret, err := client.Auth().Login(context.Background(), method) + if err != nil { + t.Errorf("Login() error = %v", err) + return + } + + token, _ := secret.TokenID() + if token != tt.token { + t.Errorf("Token error got %v, expected %v", token, tt.token) + return + } + }) + } +} + +func TestApprole_NewApproleAuthMethod(t *testing.T) { + tests := []struct { + name string + mountPath string + raw string + wantErr bool + }{ + { + "ok secret-id string", + "", + `{"RoleID": "0000-0000-0000-0000", "SecretID": "0000-0000-0000-0000"}`, + false, + }, + { + "ok secret-id string and wrapped", + "", + `{"RoleID": "0000-0000-0000-0000", "SecretID": "0000-0000-0000-0000", "isWrappedToken": true}`, + false, + }, + { + "ok secret-id string and wrapped with custom mountPath", + "approle2", + `{"RoleID": "0000-0000-0000-0000", "SecretID": "0000-0000-0000-0000", "isWrappedToken": true}`, + false, + }, + { + "ok secret-id file", + "", + `{"RoleID": "0000-0000-0000-0000", "SecretIDFile": "./secret-id"}`, + false, + }, + { + "ok secret-id env", + "", + `{"RoleID": "0000-0000-0000-0000", "SecretIDEnv": "VAULT_APPROLE_SECRETID"}`, + false, + }, + { + "fail mandatory role-id", + "", + `{}`, + true, + }, + { + "fail mandatory secret-id any", + "", + `{"RoleID": "0000-0000-0000-0000"}`, + true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := NewApproleAuthMethod(tt.mountPath, json.RawMessage(tt.raw)) + if (err != nil) != tt.wantErr { + t.Errorf("Approle.NewApproleAuthMethod() error = %v, wantErr %v", err, tt.wantErr) + return + } + }) + } } diff --git a/cas/vaultcas/auth/kubernetes/kubernetes.go b/cas/vaultcas/auth/kubernetes/kubernetes.go index 0c4db62f..267bcdca 100644 --- a/cas/vaultcas/auth/kubernetes/kubernetes.go +++ b/cas/vaultcas/auth/kubernetes/kubernetes.go @@ -2,6 +2,7 @@ package kubernetes import ( "encoding/json" + "errors" "fmt" "github.com/hashicorp/vault/api/auth/kubernetes" @@ -31,6 +32,11 @@ func NewKubernetesAuthMethod(mountPath string, options json.RawMessage) (*kubern if opts.TokenPath != "" { loginOptions = append(loginOptions, kubernetes.WithServiceAccountTokenPath(opts.TokenPath)) } + + if opts.Role == "" { + return nil, errors.New("you must set role") + } + kubernetesAuth, err = kubernetes.NewKubernetesAuth( opts.Role, loginOptions..., diff --git a/cas/vaultcas/auth/kubernetes/kubernetes_test.go b/cas/vaultcas/auth/kubernetes/kubernetes_test.go index 604f1898..55be904d 100644 --- a/cas/vaultcas/auth/kubernetes/kubernetes_test.go +++ b/cas/vaultcas/auth/kubernetes/kubernetes_test.go @@ -1,21 +1,149 @@ package kubernetes import ( + "context" "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "net/url" "path" "path/filepath" "runtime" "testing" + + vault "github.com/hashicorp/vault/api" ) -func TestKubernetes_NewKubernetesAuthMethod(t *testing.T) { - _, filename, _, _ := runtime.Caller(0) - tokenPath := filepath.Join(path.Dir(filename), "token") - mountPath := "kubernetes" - raw := `{"role": "SomeRoleName", "tokenPath": "` + tokenPath + `"}` +func testCAHelper(t *testing.T) (*url.URL, *vault.Client) { + t.Helper() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case r.RequestURI == "/v1/auth/kubernetes/login": + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, `{ + "auth": { + "client_token": "hvs.0000" + } + }`) + case r.RequestURI == "/v1/auth/custom-kubernetes/login": + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, `{ + "auth": { + "client_token": "hvs.9999" + } + }`) + default: + w.WriteHeader(http.StatusNotFound) + fmt.Fprintf(w, `{"error":"not found"}`) + } + })) + t.Cleanup(func() { + srv.Close() + }) + u, err := url.Parse(srv.URL) + if err != nil { + srv.Close() + t.Fatal(err) + } - _, err := NewKubernetesAuthMethod(mountPath, json.RawMessage(raw)) + config := vault.DefaultConfig() + config.Address = srv.URL + + client, err := vault.NewClient(config) if err != nil { + srv.Close() t.Fatal(err) } + + return u, client +} + +func TestApprole_LoginMountPaths(t *testing.T) { + caURL, _ := testCAHelper(t) + _, filename, _, _ := runtime.Caller(0) + tokenPath := filepath.Join(path.Dir(filename), "token") + + config := vault.DefaultConfig() + config.Address = caURL.String() + client, _ := vault.NewClient(config) + + tests := []struct { + name string + mountPath string + token string + }{ + { + name: "ok default mount path", + mountPath: "", + token: "hvs.0000", + }, + { + name: "ok explicit mount path", + mountPath: "kubernetes", + token: "hvs.0000", + }, + { + name: "ok custom mount path", + mountPath: "custom-kubernetes", + token: "hvs.9999", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + method, err := NewKubernetesAuthMethod(tt.mountPath, json.RawMessage(`{"role": "SomeRoleName", "tokenPath": "`+tokenPath+`"}`)) + if err != nil { + t.Errorf("NewApproleAuthMethod() error = %v", err) + return + } + + secret, err := client.Auth().Login(context.Background(), method) + if err != nil { + t.Errorf("Login() error = %v", err) + return + } + + token, _ := secret.TokenID() + if token != tt.token { + t.Errorf("Token error got %v, expected %v", token, tt.token) + return + } + }) + } +} + +func TestApprole_NewApproleAuthMethod(t *testing.T) { + _, filename, _, _ := runtime.Caller(0) + tokenPath := filepath.Join(path.Dir(filename), "token") + + tests := []struct { + name string + mountPath string + raw string + wantErr bool + }{ + { + "ok secret-id string", + "", + `{"role": "SomeRoleName", "tokenPath": "` + tokenPath + `"}`, + false, + }, + { + "fail mandatory role", + "", + `{}`, + true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := NewKubernetesAuthMethod(tt.mountPath, json.RawMessage(tt.raw)) + if (err != nil) != tt.wantErr { + t.Errorf("Kubernetes.NewKubernetesAuthMethod() error = %v, wantErr %v", err, tt.wantErr) + return + } + }) + } } diff --git a/cas/vaultcas/vaultcas_test.go b/cas/vaultcas/vaultcas_test.go index 3c1f09a3..0ea0c4b1 100644 --- a/cas/vaultcas/vaultcas_test.go +++ b/cas/vaultcas/vaultcas_test.go @@ -182,8 +182,6 @@ func TestNew_register(t *testing.T) { CertificateAuthority: caURL.String(), CertificateAuthorityFingerprint: testRootFingerprint, Config: json.RawMessage(`{ - "PKIMountPath": "pki", - "PKIRoleDefault": "pki-role", "AuthType": "approle", "AuthOptions": {"RoleID":"roleID","SecretID":"secretID","IsWrappingToken":false} }`), @@ -204,8 +202,6 @@ func TestVaultCAS_CreateCertificate(t *testing.T) { PKIRoleRSA: "rsa", PKIRoleEC: "ec", PKIRoleEd25519: "ed25519", - AuthType: "approle", - AuthOptions: json.RawMessage(`{"RoleID":"roleID","SecretID":"secretID","IsWrappingToken":false}`), } type fields struct { @@ -336,8 +332,6 @@ func TestVaultCAS_RevokeCertificate(t *testing.T) { PKIRoleRSA: "rsa", PKIRoleEC: "ec", PKIRoleEd25519: "ed25519", - AuthType: "approle", - AuthOptions: json.RawMessage(`{"RoleID":"roleID","SecretID":"secretID","IsWrappingToken":false}`), } type fields struct { @@ -406,8 +400,6 @@ func TestVaultCAS_RenewCertificate(t *testing.T) { PKIRoleRSA: "rsa", PKIRoleEC: "ec", PKIRoleEd25519: "ed25519", - AuthType: "approle", - AuthOptions: json.RawMessage(`{"RoleID":"roleID","SecretID":"secretID","IsWrappingToken":false}`), } type fields struct { From 9ec154aab02f25fad6ee47a545a8250ed76e3345 Mon Sep 17 00:00:00 2001 From: Erik De Lamarter Date: Tue, 17 May 2022 22:13:11 +0200 Subject: [PATCH 38/40] rewrite and improve secret-id config --- cas/vaultcas/auth/approle/approle.go | 9 +++++---- cas/vaultcas/auth/approle/approle_test.go | 24 +++++++++++++++++++++++ 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/cas/vaultcas/auth/approle/approle.go b/cas/vaultcas/auth/approle/approle.go index d842bae0..118afb10 100644 --- a/cas/vaultcas/auth/approle/approle.go +++ b/cas/vaultcas/auth/approle/approle.go @@ -41,19 +41,20 @@ func NewApproleAuthMethod(mountPath string, options json.RawMessage) (*approle.A } var sid approle.SecretID - if opts.SecretID != "" { + switch { + case opts.SecretID != "" && opts.SecretIDFile == "" && opts.SecretIDEnv == "": sid = approle.SecretID{ FromString: opts.SecretID, } - } else if opts.SecretIDFile != "" { + case opts.SecretIDFile != "" && opts.SecretID == "" && opts.SecretIDEnv == "": sid = approle.SecretID{ FromFile: opts.SecretIDFile, } - } else if opts.SecretIDEnv != "" { + case opts.SecretIDEnv != "" && opts.SecretIDFile == "" && opts.SecretID == "": sid = approle.SecretID{ FromEnv: opts.SecretIDEnv, } - } else { + default: return nil, errors.New("you must set one of secretID, secretIDFile or secretIDEnv") } diff --git a/cas/vaultcas/auth/approle/approle_test.go b/cas/vaultcas/auth/approle/approle_test.go index ec4d523f..28b7b7f7 100644 --- a/cas/vaultcas/auth/approle/approle_test.go +++ b/cas/vaultcas/auth/approle/approle_test.go @@ -158,6 +158,30 @@ func TestApprole_NewApproleAuthMethod(t *testing.T) { `{"RoleID": "0000-0000-0000-0000"}`, true, }, + { + "fail multiple secret-id types id and env", + "", + `{"RoleID": "0000-0000-0000-0000", "SecretID": "0000-0000-0000-0000", "SecretIDEnv": "VAULT_APPROLE_SECRETID"}`, + true, + }, + { + "fail multiple secret-id types id and file", + "", + `{"RoleID": "0000-0000-0000-0000", "SecretID": "0000-0000-0000-0000", "SecretIDFile": "./secret-id"}`, + true, + }, + { + "fail multiple secret-id types env and file", + "", + `{"RoleID": "0000-0000-0000-0000", "SecretIDFile": "./secret-id", "SecretIDEnv": "VAULT_APPROLE_SECRETID"}`, + true, + }, + { + "fail multiple secret-id types all", + "", + `{"RoleID": "0000-0000-0000-0000", "SecretID": "0000-0000-0000-0000", "SecretIDFile": "./secret-id", "SecretIDEnv": "VAULT_APPROLE_SECRETID"}`, + true, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { From 07984a968fba1eedfa514d80e088a41e1f59651f Mon Sep 17 00:00:00 2001 From: Erik DeLamarter Date: Sat, 21 May 2022 21:00:50 +0200 Subject: [PATCH 39/40] better error messages Co-authored-by: Mariano Cano --- cas/vaultcas/vaultcas.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cas/vaultcas/vaultcas.go b/cas/vaultcas/vaultcas.go index 02c814b7..a5658620 100644 --- a/cas/vaultcas/vaultcas.go +++ b/cas/vaultcas/vaultcas.go @@ -84,15 +84,15 @@ func New(ctx context.Context, opts apiv1.Options) (*VaultCAS, error) { case "approle": method, err = approle.NewApproleAuthMethod(vc.AuthMountPath, vc.AuthOptions) default: - return nil, fmt.Errorf("unknown auth type: %v", vc.AuthType) + return nil, fmt.Errorf("unknown auth type: %s, only 'kubernetes' and 'approle' currently supported", vc.AuthType) } if err != nil { - return nil, fmt.Errorf("unable to configure auth method: %w", err) + return nil, fmt.Errorf("unable to configure %s auth method: %w", vc.AuthType, err) } authInfo, err := client.Auth().Login(ctx, method) if err != nil { - return nil, fmt.Errorf("unable to login to Kubernetes auth method: %w", err) + return nil, fmt.Errorf("unable to login to %s auth method: %w", vc.AuthType, err) } if authInfo == nil { return nil, errors.New("no auth info was returned after login") From e7f4eaf6c42b9334866ae182ca4befccb36f72a8 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Mon, 23 May 2022 14:04:31 -0700 Subject: [PATCH 40/40] Remove explicit deprecation notice This will avoid linter errors on other projects for now. --- acme/api/handler.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/acme/api/handler.go b/acme/api/handler.go index 96e22d85..2e3931b1 100644 --- a/acme/api/handler.go +++ b/acme/api/handler.go @@ -83,7 +83,9 @@ type handler struct { // this route adds will add a new middleware that will set the ACME components // on the context. // -// Deprecated: use api.Route(r api.Router) +// Note: this method is deprecated in step-ca, other applications can still use +// this to support ACME, but the recommendation is to use use +// api.Route(api.Router) and acme.NewContext() instead. func (h *handler) Route(r api.Router) { client := acme.NewClient() linker := acme.NewLinker(h.opts.DNS, h.opts.Prefix) @@ -101,7 +103,9 @@ func (h *handler) Route(r api.Router) { // NewHandler returns a new ACME API handler. // -// Deprecated: use api.Route(r api.Router) +// Note: this method is deprecated in step-ca, other applications can still use +// this to support ACME, but the recommendation is to use use +// api.Route(api.Router) and acme.NewContext() instead. func NewHandler(opts HandlerOptions) api.RouterHandler { return &handler{ opts: &opts,