Refactor ACME api.

pull/914/head
Mariano Cano 2 years ago
parent fddd6f7d95
commit d1f75f1720

@ -69,6 +69,9 @@ func (u *UpdateAccountRequest) Validate() error {
// NewAccount is the handler resource for creating new ACME accounts.
func NewAccount(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
db := acme.MustDatabaseFromContext(ctx)
linker := acme.MustLinkerFromContext(ctx)
payload, err := payloadFromContext(ctx)
if err != nil {
render.Error(w, err)
@ -120,7 +123,6 @@ func NewAccount(w http.ResponseWriter, r *http.Request) {
return
}
db := acme.MustFromContext(ctx)
acc = &acme.Account{
Key: jwk,
Contact: nar.Contact,
@ -148,16 +150,18 @@ func NewAccount(w http.ResponseWriter, r *http.Request) {
httpStatus = http.StatusOK
}
o := optionsFromContext(ctx)
o.linker.LinkAccount(ctx, acc)
linker.LinkAccount(ctx, acc)
w.Header().Set("Location", o.linker.GetLink(r.Context(), AccountLinkType, acc.ID))
w.Header().Set("Location", linker.GetLink(r.Context(), acme.AccountLinkType, acc.ID))
render.JSONStatus(w, acc, httpStatus)
}
// GetOrUpdateAccount is the api for updating an ACME account.
func GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
db := acme.MustDatabaseFromContext(ctx)
linker := acme.MustLinkerFromContext(ctx)
acc, err := accountFromContext(ctx)
if err != nil {
render.Error(w, err)
@ -189,7 +193,6 @@ func GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
acc.Contact = uar.Contact
}
db := acme.MustFromContext(ctx)
if err := db.UpdateAccount(ctx, acc); err != nil {
render.Error(w, acme.WrapErrorISE(err, "error updating account"))
return
@ -197,10 +200,9 @@ func GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
}
}
o := optionsFromContext(ctx)
o.linker.LinkAccount(ctx, acc)
linker.LinkAccount(ctx, acc)
w.Header().Set("Location", o.linker.GetLink(ctx, AccountLinkType, acc.ID))
w.Header().Set("Location", linker.GetLink(ctx, acme.AccountLinkType, acc.ID))
render.JSON(w, acc)
}
@ -216,6 +218,9 @@ func logOrdersByAccount(w http.ResponseWriter, oids []string) {
// GetOrdersByAccountID ACME api for retrieving the list of order urls belonging to an account.
func GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
db := acme.MustDatabaseFromContext(ctx)
linker := acme.MustLinkerFromContext(ctx)
acc, err := accountFromContext(ctx)
if err != nil {
render.Error(w, err)
@ -227,15 +232,13 @@ func GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) {
return
}
db := acme.MustFromContext(ctx)
orders, err := db.GetOrdersByAccountID(ctx, acc.ID)
if err != nil {
render.Error(w, err)
return
}
o := optionsFromContext(ctx)
o.linker.LinkOrdersByAccountID(ctx, orders)
linker.LinkOrdersByAccountID(ctx, orders)
render.JSON(w, orders)
logOrdersByAccount(w, orders)

@ -47,7 +47,7 @@ func validateExternalAccountBinding(ctx context.Context, nar *NewAccountRequest)
return nil, acmeErr
}
db := acme.MustFromContext(ctx)
db := acme.MustDatabaseFromContext(ctx)
externalAccountKey, err := db.GetExternalAccountKey(ctx, acmeProv.ID, keyID)
if err != nil {
if _, ok := err.(*acme.Error); ok {
@ -103,7 +103,6 @@ func keysAreEqual(x, y *jose.JSONWebKey) bool {
// o The "nonce" field MUST NOT be present
// o The "url" field MUST be set to the same value as the outer JWS
func validateEABJWS(ctx context.Context, jws *jose.JSONWebSignature) (string, *acme.Error) {
if jws == nil {
return "", acme.NewErrorISE("no JWS provided")
}

@ -2,12 +2,10 @@ package api
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"encoding/pem"
"fmt"
"net"
"net/http"
"time"
@ -70,144 +68,117 @@ type HandlerOptions struct {
// PrerequisitesChecker checks if all prerequisites for serving ACME are
// met by the CA configuration.
PrerequisitesChecker func(ctx context.Context) (bool, error)
linker Linker
validateChallengeOptions *acme.ValidateChallengeOptions
}
type optionsKey struct{}
func newOptionsContext(ctx context.Context, o *HandlerOptions) context.Context {
return context.WithValue(ctx, optionsKey{}, o)
}
func optionsFromContext(ctx context.Context) *HandlerOptions {
o, ok := ctx.Value(optionsKey{}).(*HandlerOptions)
if !ok {
panic("acme options are not in the context")
}
return o
}
var mustAuthority = func(ctx context.Context) acme.CertificateAuthority {
return authority.MustFromContext(ctx)
}
// Handler is the ACME API request handler.
type Handler struct {
// handler is the ACME API request handler.
type handler struct {
opts *HandlerOptions
}
// Route traffic and implement the Router interface.
//
// Deprecated: Use api.Route(r Router, opts *HandlerOptions)
func (h *Handler) Route(r api.Router) {
Route(r, h.opts)
func (h *handler) Route(r api.Router) {
route(r, h.opts)
}
// NewHandler returns a new ACME API handler.
//
// Deprecated: Use api.Route(r Router, opts *HandlerOptions)
func NewHandler(ops HandlerOptions) api.RouterHandler {
return &Handler{
opts: &ops,
func NewHandler(opts HandlerOptions) api.RouterHandler {
return &handler{
opts: &opts,
}
}
// Route traffic and implement the Router interface.
func Route(r api.Router, opts *HandlerOptions) {
// by default all prerequisites are met
if opts.PrerequisitesChecker == nil {
opts.PrerequisitesChecker = func(ctx context.Context) (bool, error) {
return true, nil
}
}
transport := &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
}
client := http.Client{
Timeout: 30 * time.Second,
Transport: transport,
}
dialer := &net.Dialer{
Timeout: 30 * time.Second,
}
opts.linker = NewLinker(opts.DNS, opts.Prefix)
opts.validateChallengeOptions = &acme.ValidateChallengeOptions{
HTTPGet: client.Get,
LookupTxt: net.LookupTXT,
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
return tls.DialWithDialer(dialer, network, addr, config)
},
}
withOptions := func(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
// Route traffic and implement the Router interface. This method requires that
// all the acme components, authority, db, client, linker, and prerequisite
// checker to be present in the context.
func Route(r api.Router) {
route(r, nil)
}
// For backward compatibility with NewHandler.
if ca, ok := opts.CA.(*authority.Authority); ok && ca != nil {
ctx = authority.NewContext(ctx, ca)
func route(r api.Router, opts *HandlerOptions) {
var withContext func(next nextHTTP) nextHTTP
// For backward compatibility this block adds will add a new middleware that
// will set the ACME components to the context.
if opts != nil {
client := acme.NewClient()
linker := acme.NewLinker(opts.DNS, opts.Prefix)
withContext = func(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if ca, ok := opts.CA.(*authority.Authority); ok && ca != nil {
ctx = authority.NewContext(ctx, ca)
}
ctx = acme.NewContext(ctx, opts.DB, client, linker, opts.PrerequisitesChecker)
next(w, r.WithContext(ctx))
}
if opts.DB != nil {
ctx = acme.NewContext(ctx, opts.DB)
}
} else {
withContext = func(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
next(w, r)
}
ctx = newOptionsContext(ctx, opts)
next(w, r.WithContext(ctx))
}
}
commonMiddleware := func(next nextHTTP) nextHTTP {
return withContext(func(w http.ResponseWriter, r *http.Request) {
// Linker middleware gets the provisioner and current url from the
// request and sets them in the context.
linker := acme.MustLinkerFromContext(r.Context())
linker.Middleware(http.HandlerFunc(checkPrerequisites(next))).ServeHTTP(w, r)
})
}
validatingMiddleware := func(next nextHTTP) nextHTTP {
return withOptions(baseURLFromRequest(lookupProvisioner(checkPrerequisites(addNonce(addDirLink(verifyContentType(parseJWS(validateJWS(next)))))))))
return commonMiddleware(addNonce(addDirLink(verifyContentType(parseJWS(validateJWS(next))))))
}
extractPayloadByJWK := func(next nextHTTP) nextHTTP {
return withOptions(validatingMiddleware(extractJWK(verifyAndExtractJWSPayload(next))))
return validatingMiddleware(extractJWK(verifyAndExtractJWSPayload(next)))
}
extractPayloadByKid := func(next nextHTTP) nextHTTP {
return withOptions(validatingMiddleware(lookupJWK(verifyAndExtractJWSPayload(next))))
return validatingMiddleware(lookupJWK(verifyAndExtractJWSPayload(next)))
}
extractPayloadByKidOrJWK := func(next nextHTTP) nextHTTP {
return withOptions(validatingMiddleware(extractOrLookupJWK(verifyAndExtractJWSPayload(next))))
return validatingMiddleware(extractOrLookupJWK(verifyAndExtractJWSPayload(next)))
}
getPath := opts.linker.GetUnescapedPathSuffix
getPath := acme.GetUnescapedPathSuffix
// Standard ACME API
r.MethodFunc("GET", getPath(NewNonceLinkType, "{provisionerID}"),
withOptions(baseURLFromRequest(lookupProvisioner(checkPrerequisites(addNonce(addDirLink(GetNonce)))))))
r.MethodFunc("HEAD", getPath(NewNonceLinkType, "{provisionerID}"),
withOptions(baseURLFromRequest(lookupProvisioner(checkPrerequisites(addNonce(addDirLink(GetNonce)))))))
r.MethodFunc("GET", getPath(DirectoryLinkType, "{provisionerID}"),
withOptions(baseURLFromRequest(lookupProvisioner(checkPrerequisites(GetDirectory)))))
r.MethodFunc("HEAD", getPath(DirectoryLinkType, "{provisionerID}"),
withOptions(baseURLFromRequest(lookupProvisioner(checkPrerequisites(GetDirectory)))))
r.MethodFunc("POST", getPath(NewAccountLinkType, "{provisionerID}"),
r.MethodFunc("GET", getPath(acme.NewNonceLinkType, "{provisionerID}"),
commonMiddleware(addNonce(addDirLink(GetNonce))))
r.MethodFunc("HEAD", getPath(acme.NewNonceLinkType, "{provisionerID}"),
commonMiddleware(addNonce(addDirLink(GetNonce))))
r.MethodFunc("GET", getPath(acme.DirectoryLinkType, "{provisionerID}"),
commonMiddleware(GetDirectory))
r.MethodFunc("HEAD", getPath(acme.DirectoryLinkType, "{provisionerID}"),
commonMiddleware(GetDirectory))
r.MethodFunc("POST", getPath(acme.NewAccountLinkType, "{provisionerID}"),
extractPayloadByJWK(NewAccount))
r.MethodFunc("POST", getPath(AccountLinkType, "{provisionerID}", "{accID}"),
r.MethodFunc("POST", getPath(acme.AccountLinkType, "{provisionerID}", "{accID}"),
extractPayloadByKid(GetOrUpdateAccount))
r.MethodFunc("POST", getPath(KeyChangeLinkType, "{provisionerID}", "{accID}"),
r.MethodFunc("POST", getPath(acme.KeyChangeLinkType, "{provisionerID}", "{accID}"),
extractPayloadByKid(NotImplemented))
r.MethodFunc("POST", getPath(NewOrderLinkType, "{provisionerID}"),
r.MethodFunc("POST", getPath(acme.NewOrderLinkType, "{provisionerID}"),
extractPayloadByKid(NewOrder))
r.MethodFunc("POST", getPath(OrderLinkType, "{provisionerID}", "{ordID}"),
r.MethodFunc("POST", getPath(acme.OrderLinkType, "{provisionerID}", "{ordID}"),
extractPayloadByKid(isPostAsGet(GetOrder)))
r.MethodFunc("POST", getPath(OrdersByAccountLinkType, "{provisionerID}", "{accID}"),
r.MethodFunc("POST", getPath(acme.OrdersByAccountLinkType, "{provisionerID}", "{accID}"),
extractPayloadByKid(isPostAsGet(GetOrdersByAccountID)))
r.MethodFunc("POST", getPath(FinalizeLinkType, "{provisionerID}", "{ordID}"),
r.MethodFunc("POST", getPath(acme.FinalizeLinkType, "{provisionerID}", "{ordID}"),
extractPayloadByKid(FinalizeOrder))
r.MethodFunc("POST", getPath(AuthzLinkType, "{provisionerID}", "{authzID}"),
r.MethodFunc("POST", getPath(acme.AuthzLinkType, "{provisionerID}", "{authzID}"),
extractPayloadByKid(isPostAsGet(GetAuthorization)))
r.MethodFunc("POST", getPath(ChallengeLinkType, "{provisionerID}", "{authzID}", "{chID}"),
r.MethodFunc("POST", getPath(acme.ChallengeLinkType, "{provisionerID}", "{authzID}", "{chID}"),
extractPayloadByKid(GetChallenge))
r.MethodFunc("POST", getPath(CertificateLinkType, "{provisionerID}", "{certID}"),
r.MethodFunc("POST", getPath(acme.CertificateLinkType, "{provisionerID}", "{certID}"),
extractPayloadByKid(isPostAsGet(GetCertificate)))
r.MethodFunc("POST", getPath(RevokeCertLinkType, "{provisionerID}"),
r.MethodFunc("POST", getPath(acme.RevokeCertLinkType, "{provisionerID}"),
extractPayloadByKidOrJWK(RevokeCert))
}
@ -251,20 +222,20 @@ func (d *Directory) ToLog() (interface{}, error) {
// for client configuration.
func GetDirectory(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
o := optionsFromContext(ctx)
acmeProv, err := acmeProvisionerFromContext(ctx)
fmt.Println(acmeProv, err)
if err != nil {
render.Error(w, err)
return
}
linker := acme.MustLinkerFromContext(ctx)
render.JSON(w, &Directory{
NewNonce: o.linker.GetLink(ctx, NewNonceLinkType),
NewAccount: o.linker.GetLink(ctx, NewAccountLinkType),
NewOrder: o.linker.GetLink(ctx, NewOrderLinkType),
RevokeCert: o.linker.GetLink(ctx, RevokeCertLinkType),
KeyChange: o.linker.GetLink(ctx, KeyChangeLinkType),
NewNonce: linker.GetLink(ctx, acme.NewNonceLinkType),
NewAccount: linker.GetLink(ctx, acme.NewAccountLinkType),
NewOrder: linker.GetLink(ctx, acme.NewOrderLinkType),
RevokeCert: linker.GetLink(ctx, acme.RevokeCertLinkType),
KeyChange: linker.GetLink(ctx, acme.KeyChangeLinkType),
Meta: Meta{
ExternalAccountRequired: acmeProv.RequireEAB,
},
@ -280,8 +251,8 @@ func NotImplemented(w http.ResponseWriter, r *http.Request) {
// GetAuthorization ACME api for retrieving an Authz.
func GetAuthorization(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
o := optionsFromContext(ctx)
db := acme.MustFromContext(ctx)
db := acme.MustDatabaseFromContext(ctx)
linker := acme.MustLinkerFromContext(ctx)
acc, err := accountFromContext(ctx)
if err != nil {
@ -303,17 +274,17 @@ func GetAuthorization(w http.ResponseWriter, r *http.Request) {
return
}
o.linker.LinkAuthorization(ctx, az)
linker.LinkAuthorization(ctx, az)
w.Header().Set("Location", o.linker.GetLink(ctx, AuthzLinkType, az.ID))
w.Header().Set("Location", linker.GetLink(ctx, acme.AuthzLinkType, az.ID))
render.JSON(w, az)
}
// GetChallenge ACME api for retrieving a Challenge.
func GetChallenge(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
o := optionsFromContext(ctx)
db := acme.MustFromContext(ctx)
db := acme.MustDatabaseFromContext(ctx)
linker := acme.MustLinkerFromContext(ctx)
acc, err := accountFromContext(ctx)
if err != nil {
@ -351,22 +322,22 @@ func GetChallenge(w http.ResponseWriter, r *http.Request) {
render.Error(w, err)
return
}
if err = ch.Validate(ctx, db, jwk, o.validateChallengeOptions); err != nil {
if err = ch.Validate(ctx, db, jwk); err != nil {
render.Error(w, acme.WrapErrorISE(err, "error validating challenge"))
return
}
o.linker.LinkChallenge(ctx, ch, azID)
linker.LinkChallenge(ctx, ch, azID)
w.Header().Add("Link", link(o.linker.GetLink(ctx, AuthzLinkType, azID), "up"))
w.Header().Set("Location", o.linker.GetLink(ctx, ChallengeLinkType, azID, ch.ID))
w.Header().Add("Link", link(linker.GetLink(ctx, acme.AuthzLinkType, azID), "up"))
w.Header().Set("Location", linker.GetLink(ctx, acme.ChallengeLinkType, azID, ch.ID))
render.JSON(w, ch)
}
// GetCertificate ACME api for retrieving a Certificate.
func GetCertificate(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
db := acme.MustFromContext(ctx)
db := acme.MustDatabaseFromContext(ctx)
acc, err := accountFromContext(ctx)
if err != nil {

@ -31,39 +31,10 @@ func logNonce(w http.ResponseWriter, nonce string) {
}
}
// getBaseURLFromRequest determines the base URL which should be used for
// constructing link URLs in e.g. the ACME directory result by taking the
// request Host into consideration.
//
// If the Request.Host is an empty string, we return an empty string, to
// indicate that the configured URL values should be used instead. If this
// function returns a non-empty result, then this should be used in constructing
// ACME link URLs.
func getBaseURLFromRequest(r *http.Request) *url.URL {
// NOTE: See https://github.com/letsencrypt/boulder/blob/master/web/relative.go
// for an implementation that allows HTTP requests using the x-forwarded-proto
// header.
if r.Host == "" {
return nil
}
return &url.URL{Scheme: "https", Host: r.Host}
}
// baseURLFromRequest is a middleware that extracts and caches the baseURL
// from the request.
// E.g. https://ca.smallstep.com/
func baseURLFromRequest(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
ctx := context.WithValue(r.Context(), baseURLContextKey, getBaseURLFromRequest(r))
next(w, r.WithContext(ctx))
}
}
// addNonce is a middleware that adds a nonce to the response header.
func addNonce(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
db := acme.MustFromContext(r.Context())
db := acme.MustDatabaseFromContext(r.Context())
nonce, err := db.CreateNonce(r.Context())
if err != nil {
render.Error(w, err)
@ -81,9 +52,9 @@ func addNonce(next nextHTTP) nextHTTP {
func addDirLink(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
opts := optionsFromContext(ctx)
linker := acme.MustLinkerFromContext(ctx)
w.Header().Add("Link", link(opts.linker.GetLink(ctx, DirectoryLinkType), "index"))
w.Header().Add("Link", link(linker.GetLink(ctx, acme.DirectoryLinkType), "index"))
next(w, r)
}
}
@ -92,17 +63,12 @@ func addDirLink(next nextHTTP) nextHTTP {
// application/jose+json.
func verifyContentType(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
var expected []string
ctx := r.Context()
opts := optionsFromContext(ctx)
p, err := provisionerFromContext(ctx)
if err != nil {
render.Error(w, err)
return
p := acme.MustProvisionerFromContext(r.Context())
u := &url.URL{
Path: acme.GetUnescapedPathSuffix(acme.CertificateLinkType, p.GetName(), ""),
}
u := url.URL{Path: opts.linker.GetUnescapedPathSuffix(CertificateLinkType, p.GetName(), "")}
var expected []string
if strings.Contains(r.URL.String(), u.EscapedPath()) {
// GET /certificate requests allow a greater range of content types.
expected = []string{"application/jose+json", "application/pkix-cert", "application/pkcs7-mime"}
@ -159,7 +125,7 @@ func parseJWS(next nextHTTP) nextHTTP {
func validateJWS(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
db := acme.MustFromContext(ctx)
db := acme.MustDatabaseFromContext(ctx)
jws, err := jwsFromContext(ctx)
if err != nil {
@ -247,7 +213,7 @@ func validateJWS(next nextHTTP) nextHTTP {
func extractJWK(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
db := acme.MustFromContext(ctx)
db := acme.MustDatabaseFromContext(ctx)
jws, err := jwsFromContext(ctx)
if err != nil {
@ -325,18 +291,20 @@ func lookupProvisioner(next nextHTTP) nextHTTP {
func checkPrerequisites(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
opts := optionsFromContext(ctx)
ok, err := opts.PrerequisitesChecker(ctx)
if err != nil {
render.Error(w, acme.WrapErrorISE(err, "error checking acme provisioner prerequisites"))
return
}
if !ok {
render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites"))
return
// If the function is not set assume that all prerequisites are met.
checkFunc, ok := acme.PrerequisitesCheckerFromContext(ctx)
if ok {
ok, err := checkFunc(ctx)
if err != nil {
render.Error(w, acme.WrapErrorISE(err, "error checking acme provisioner prerequisites"))
return
}
if !ok {
render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites"))
return
}
}
next(w, r.WithContext(ctx))
next(w, r)
}
}
@ -346,8 +314,8 @@ func checkPrerequisites(next nextHTTP) nextHTTP {
func lookupJWK(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
opts := optionsFromContext(ctx)
db := acme.MustFromContext(ctx)
db := acme.MustDatabaseFromContext(ctx)
linker := acme.MustLinkerFromContext(ctx)
jws, err := jwsFromContext(ctx)
if err != nil {
@ -355,7 +323,7 @@ func lookupJWK(next nextHTTP) nextHTTP {
return
}
kidPrefix := opts.linker.GetLink(ctx, AccountLinkType, "")
kidPrefix := linker.GetLink(ctx, acme.AccountLinkType, "")
kid := jws.Signatures[0].Protected.KeyID
if !strings.HasPrefix(kid, kidPrefix) {
render.Error(w, acme.NewError(acme.ErrorMalformedType,
@ -527,32 +495,14 @@ func jwsFromContext(ctx context.Context) (*jose.JSONWebSignature, error) {
return val, nil
}
// provisionerFromContext searches the context for a provisioner. Returns the
// provisioner or an error.
func provisionerFromContext(ctx context.Context) (acme.Provisioner, error) {
val := ctx.Value(provisionerContextKey)
if val == nil {
return nil, acme.NewErrorISE("provisioner expected in request context")
}
pval, ok := val.(acme.Provisioner)
if !ok || pval == nil {
return nil, acme.NewErrorISE("provisioner in context is not an ACME provisioner")
}
return pval, nil
}
// acmeProvisionerFromContext searches the context for an ACME provisioner. Returns
// pointer to an ACME provisioner or an error.
func acmeProvisionerFromContext(ctx context.Context) (*provisioner.ACME, error) {
prov, err := provisionerFromContext(ctx)
if err != nil {
return nil, err
}
acmeProv, ok := prov.(*provisioner.ACME)
if !ok || acmeProv == nil {
p, ok := acme.MustProvisionerFromContext(ctx).(*provisioner.ACME)
if !ok {
return nil, acme.NewErrorISE("provisioner in context is not an ACME provisioner")
}
return acmeProv, nil
return p, nil
}
// payloadFromContext searches the context for a payload. Returns the payload

@ -70,16 +70,15 @@ var defaultOrderBackdate = time.Minute
// NewOrder ACME api for creating a new order.
func NewOrder(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
db := acme.MustDatabaseFromContext(ctx)
linker := acme.MustLinkerFromContext(ctx)
prov := acme.MustProvisionerFromContext(ctx)
acc, err := accountFromContext(ctx)
if err != nil {
render.Error(w, err)
return
}
prov, err := provisionerFromContext(ctx)
if err != nil {
render.Error(w, err)
return
}
payload, err := payloadFromContext(ctx)
if err != nil {
render.Error(w, err)
@ -136,16 +135,14 @@ func NewOrder(w http.ResponseWriter, r *http.Request) {
o.NotBefore = o.NotBefore.Add(-defaultOrderBackdate)
}
db := acme.MustFromContext(ctx)
if err := db.CreateOrder(ctx, o); err != nil {
render.Error(w, acme.WrapErrorISE(err, "error creating order"))
return
}
opts := optionsFromContext(ctx)
opts.linker.LinkOrder(ctx, o)
linker.LinkOrder(ctx, o)
w.Header().Set("Location", opts.linker.GetLink(ctx, OrderLinkType, o.ID))
w.Header().Set("Location", linker.GetLink(ctx, acme.OrderLinkType, o.ID))
render.JSONStatus(w, o, http.StatusCreated)
}
@ -166,7 +163,7 @@ func newAuthorization(ctx context.Context, az *acme.Authorization) error {
return acme.WrapErrorISE(err, "error generating random alphanumeric ID")
}
db := acme.MustFromContext(ctx)
db := acme.MustDatabaseFromContext(ctx)
az.Challenges = make([]*acme.Challenge, len(chTypes))
for i, typ := range chTypes {
ch := &acme.Challenge{
@ -190,18 +187,16 @@ func newAuthorization(ctx context.Context, az *acme.Authorization) error {
// GetOrder ACME api for retrieving an order.
func GetOrder(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
db := acme.MustDatabaseFromContext(ctx)
linker := acme.MustLinkerFromContext(ctx)
prov := acme.MustProvisionerFromContext(ctx)
acc, err := accountFromContext(ctx)
if err != nil {
render.Error(w, err)
return
}
prov, err := provisionerFromContext(ctx)
if err != nil {
render.Error(w, err)
return
}
db := acme.MustFromContext(ctx)
o, err := db.GetOrder(ctx, chi.URLParam(r, "ordID"))
if err != nil {
render.Error(w, acme.WrapErrorISE(err, "error retrieving order"))
@ -222,26 +217,24 @@ func GetOrder(w http.ResponseWriter, r *http.Request) {
return
}
opts := optionsFromContext(ctx)
opts.linker.LinkOrder(ctx, o)
linker.LinkOrder(ctx, o)
w.Header().Set("Location", opts.linker.GetLink(ctx, OrderLinkType, o.ID))
w.Header().Set("Location", linker.GetLink(ctx, acme.OrderLinkType, o.ID))
render.JSON(w, o)
}
// FinalizeOrder attemptst to finalize an order and create a certificate.
func FinalizeOrder(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
db := acme.MustDatabaseFromContext(ctx)
linker := acme.MustLinkerFromContext(ctx)
prov := acme.MustProvisionerFromContext(ctx)
acc, err := accountFromContext(ctx)
if err != nil {
render.Error(w, err)
return
}
prov, err := provisionerFromContext(ctx)
if err != nil {
render.Error(w, err)
return
}
payload, err := payloadFromContext(ctx)
if err != nil {
render.Error(w, err)
@ -258,7 +251,6 @@ func FinalizeOrder(w http.ResponseWriter, r *http.Request) {
return
}
db := acme.MustFromContext(ctx)
o, err := db.GetOrder(ctx, chi.URLParam(r, "ordID"))
if err != nil {
render.Error(w, acme.WrapErrorISE(err, "error retrieving order"))
@ -281,10 +273,9 @@ func FinalizeOrder(w http.ResponseWriter, r *http.Request) {
return
}
opts := optionsFromContext(ctx)
opts.linker.LinkOrder(ctx, o)
linker.LinkOrder(ctx, o)
w.Header().Set("Location", opts.linker.GetLink(ctx, OrderLinkType, o.ID))
w.Header().Set("Location", linker.GetLink(ctx, acme.OrderLinkType, o.ID))
render.JSON(w, o)
}

@ -28,13 +28,11 @@ type revokePayload struct {
// RevokeCert attempts to revoke a certificate.
func RevokeCert(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
jws, err := jwsFromContext(ctx)
if err != nil {
render.Error(w, err)
return
}
db := acme.MustDatabaseFromContext(ctx)
linker := acme.MustLinkerFromContext(ctx)
prov := acme.MustProvisionerFromContext(ctx)
prov, err := provisionerFromContext(ctx)
jws, err := jwsFromContext(ctx)
if err != nil {
render.Error(w, err)
return
@ -67,7 +65,6 @@ func RevokeCert(w http.ResponseWriter, r *http.Request) {
return
}
db := acme.MustFromContext(ctx)
serial := certToBeRevoked.SerialNumber.String()
dbCert, err := db.GetCertificateBySerial(ctx, serial)
if err != nil {
@ -138,8 +135,7 @@ func RevokeCert(w http.ResponseWriter, r *http.Request) {
}
logRevoke(w, options)
o := optionsFromContext(ctx)
w.Header().Add("Link", link(o.linker.GetLink(ctx, DirectoryLinkType), "index"))
w.Header().Add("Link", link(linker.GetLink(ctx, acme.DirectoryLinkType), "index"))
w.Write(nil)
}

@ -14,7 +14,6 @@ import (
"fmt"
"io"
"net"
"net/http"
"net/url"
"reflect"
"strings"
@ -61,27 +60,28 @@ func (ch *Challenge) ToLog() (interface{}, error) {
// type using the DB interface.
// satisfactorily validated, the 'status' and 'validated' attributes are
// updated.
func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error {
func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey) error {
// If already valid or invalid then return without performing validation.
if ch.Status != StatusPending {
return nil
}
switch ch.Type {
case HTTP01:
return http01Validate(ctx, ch, db, jwk, vo)
return http01Validate(ctx, ch, db, jwk)
case DNS01:
return dns01Validate(ctx, ch, db, jwk, vo)
return dns01Validate(ctx, ch, db, jwk)
case TLSALPN01:
return tlsalpn01Validate(ctx, ch, db, jwk, vo)
return tlsalpn01Validate(ctx, ch, db, jwk)
default:
return NewErrorISE("unexpected challenge type '%s'", ch.Type)
}
}
func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error {
func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey) error {
u := &url.URL{Scheme: "http", Host: http01ChallengeHost(ch.Value), Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)}
resp, err := vo.HTTPGet(u.String())
vc := MustClientFromContext(ctx)
resp, err := vc.Get(u.String())
if err != nil {
return storeError(ctx, db, ch, false, WrapError(ErrorConnectionType, err,
"error doing http GET for url %s", u))
@ -141,7 +141,7 @@ func tlsAlert(err error) uint8 {
return 0
}
func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error {
func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey) error {
config := &tls.Config{
NextProtos: []string{"acme-tls/1"},
// https://tools.ietf.org/html/rfc8737#section-4
@ -154,7 +154,8 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON
hostPort := net.JoinHostPort(ch.Value, "443")
conn, err := vo.TLSDial("tcp", hostPort, config)
vc := MustClientFromContext(ctx)
conn, err := vc.TLSDial("tcp", hostPort, config)
if err != nil {
// With Go 1.17+ tls.Dial fails if there's no overlap between configured
// client and server protocols. When this happens the connection is
@ -253,14 +254,15 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON
"incorrect certificate for tls-alpn-01 challenge: missing acmeValidationV1 extension"))
}
func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error {
func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey) error {
// Normalize domain for wildcard DNS names
// This is done to avoid making TXT lookups for domains like
// _acme-challenge.*.example.com
// Instead perform txt lookup for _acme-challenge.example.com
domain := strings.TrimPrefix(ch.Value, "*.")
txtRecords, err := vo.LookupTxt("_acme-challenge." + domain)
vc := MustClientFromContext(ctx)
txtRecords, err := vc.LookupTxt("_acme-challenge." + domain)
if err != nil {
return storeError(ctx, db, ch, false, WrapError(ErrorDNSType, err,
"error looking up TXT records for domain %s", domain))
@ -376,14 +378,3 @@ func storeError(ctx context.Context, db DB, ch *Challenge, markInvalid bool, err
}
return nil
}
type httpGetter func(string) (*http.Response, error)
type lookupTxt func(string) ([]string, error)
type tlsDialer func(network, addr string, config *tls.Config) (*tls.Conn, error)
// ValidateChallengeOptions are ACME challenge validator functions.
type ValidateChallengeOptions struct {
HTTPGet httpGetter
LookupTxt lookupTxt
TLSDial tlsDialer
}

@ -0,0 +1,79 @@
package acme
import (
"context"
"crypto/tls"
"net"
"net/http"
"time"
)
// Client is the interface used to verify ACME challenges.
type Client interface {
// Get issues an HTTP GET to the specified URL.
Get(url string) (*http.Response, error)
// LookupTXT returns the DNS TXT records for the given domain name.
LookupTxt(name string) ([]string, error)
// TLSDial connects to the given network address using net.Dialer and then
// initiates a TLS handshake, returning the resulting TLS connection.
TLSDial(network, addr string, config *tls.Config) (*tls.Conn, error)
}
type clientKey struct{}
// NewClientContext adds the given client to the context.
func NewClientContext(ctx context.Context, c Client) context.Context {
return context.WithValue(ctx, clientKey{}, c)
}
// ClientFromContext returns the current client from the given context.
func ClientFromContext(ctx context.Context) (c Client, ok bool) {
c, ok = ctx.Value(clientKey{}).(Client)
return
}
// MustClientFromContext returns the current client from the given context. It will
// return a new instance of the client if it does not exist.
func MustClientFromContext(ctx context.Context) Client {
if c, ok := ClientFromContext(ctx); !ok {
return NewClient()
} else {
return c
}
}
type client struct {
http *http.Client
dialer *net.Dialer
}
// NewClient returns an implementation of Client for verifying ACME challenges.
func NewClient() Client {
return &client{
http: &http.Client{
Timeout: 30 * time.Second,
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
},
},
dialer: &net.Dialer{
Timeout: 30 * time.Second,
},
}
}
func (c *client) Get(url string) (*http.Response, error) {
return c.http.Get(url)
}
func (c *client) LookupTxt(name string) ([]string, error) {
return net.LookupTXT(name)
}
func (c *client) TLSDial(network, addr string, config *tls.Config) (*tls.Conn, error) {
return tls.DialWithDialer(c.dialer, network, addr, config)
}

@ -9,6 +9,16 @@ import (
"github.com/smallstep/certificates/authority/provisioner"
)
// Clock that returns time in UTC rounded to seconds.
type Clock struct{}
// Now returns the UTC time rounded to seconds.
func (c *Clock) Now() time.Time {
return time.Now().UTC().Truncate(time.Second)
}
var clock Clock
// CertificateAuthority is the interface implemented by a CA authority.
type CertificateAuthority interface {
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
@ -17,15 +27,42 @@ type CertificateAuthority interface {
LoadProvisionerByName(string) (provisioner.Interface, error)
}
// Clock that returns time in UTC rounded to seconds.
type Clock struct{}
// NewContext adds the given acme components to the context.
func NewContext(ctx context.Context, db DB, client Client, linker Linker, fn PrerequisitesChecker) context.Context {
ctx = NewDatabaseContext(ctx, db)
ctx = NewClientContext(ctx, client)
ctx = NewLinkerContext(ctx, linker)
// Prerequisite checker is optional.
if fn != nil {
ctx = NewPrerequisitesCheckerContext(ctx, fn)
}
return ctx
}
// Now returns the UTC time rounded to seconds.
func (c *Clock) Now() time.Time {
return time.Now().UTC().Truncate(time.Second)
// PrerequisitesChecker is a function that checks if all prerequisites for
// serving ACME are met by the CA configuration.
type PrerequisitesChecker func(ctx context.Context) (bool, error)
// DefaultPrerequisitesChecker is the default PrerequisiteChecker and returns
// always true.
func DefaultPrerequisitesChecker(ctx context.Context) (bool, error) {
return true, nil
}
var clock Clock
type prerequisitesKey struct{}
// NewPrerequisitesCheckerContext adds the given PrerequisitesChecker to the
// context.
func NewPrerequisitesCheckerContext(ctx context.Context, fn PrerequisitesChecker) context.Context {
return context.WithValue(ctx, prerequisitesKey{}, fn)
}
// PrerequisitesCheckerFromContext returns the PrerequisitesChecker in the
// context.
func PrerequisitesCheckerFromContext(ctx context.Context) (PrerequisitesChecker, bool) {
fn, ok := ctx.Value(prerequisitesKey{}).(PrerequisitesChecker)
return fn, ok && fn != nil
}
// Provisioner is an interface that implements a subset of the provisioner.Interface --
// only those methods required by the ACME api/authority.
@ -38,6 +75,29 @@ type Provisioner interface {
GetOptions() *provisioner.Options
}
type provisionerKey struct{}
// NewProvisionerContext adds the given provisioner to the context.
func NewProvisionerContext(ctx context.Context, v Provisioner) context.Context {
return context.WithValue(ctx, provisionerKey{}, v)
}
// ProvisionerFromContext returns the current provisioner from the given context.
func ProvisionerFromContext(ctx context.Context) (v Provisioner, ok bool) {
v, ok = ctx.Value(provisionerKey{}).(Provisioner)
return
}
// MustLinkerFromContext returns the current provisioner from the given context.
// It will panic if it's not in the context.
func MustProvisionerFromContext(ctx context.Context) Provisioner {
if v, ok := ProvisionerFromContext(ctx); !ok {
panic("acme provisioner is not the context")
} else {
return v
}
}
// MockProvisioner for testing
type MockProvisioner struct {
Mret1 interface{}

@ -50,21 +50,21 @@ type DB interface {
type dbKey struct{}
// NewContext adds the given acme database to the context.
func NewContext(ctx context.Context, db DB) context.Context {
// NewDatabaseContext adds the given acme database to the context.
func NewDatabaseContext(ctx context.Context, db DB) context.Context {
return context.WithValue(ctx, dbKey{}, db)
}
// FromContext returns the current acme database from the given context.
func FromContext(ctx context.Context) (db DB, ok bool) {
// DatabaseFromContext returns the current acme database from the given context.
func DatabaseFromContext(ctx context.Context) (db DB, ok bool) {
db, ok = ctx.Value(dbKey{}).(DB)
return
}
// MustFromContext returns the current database from the given context. It
// will panic if it's not in the context.
func MustFromContext(ctx context.Context) DB {
if db, ok := FromContext(ctx); !ok {
// MustDatabaseFromContext returns the current database from the given context.
// It will panic if it's not in the context.
func MustDatabaseFromContext(ctx context.Context) DB {
if db, ok := DatabaseFromContext(ctx); !ok {
panic("acme database is not in the context")
} else {
return db

@ -4,12 +4,98 @@ import (
"context"
"fmt"
"net"
"net/http"
"net/url"
"strings"
"github.com/smallstep/certificates/acme"
"github.com/go-chi/chi"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner"
)
// LinkType captures the link type.
type LinkType int
const (
// NewNonceLinkType new-nonce
NewNonceLinkType LinkType = iota
// NewAccountLinkType new-account
NewAccountLinkType
// AccountLinkType account
AccountLinkType
// OrderLinkType order
OrderLinkType
// NewOrderLinkType new-order
NewOrderLinkType
// OrdersByAccountLinkType list of orders owned by account
OrdersByAccountLinkType
// FinalizeLinkType finalize order
FinalizeLinkType
// NewAuthzLinkType authz
NewAuthzLinkType
// AuthzLinkType new-authz
AuthzLinkType
// ChallengeLinkType challenge
ChallengeLinkType
// CertificateLinkType certificate
CertificateLinkType
// DirectoryLinkType directory
DirectoryLinkType
// RevokeCertLinkType revoke certificate
RevokeCertLinkType
// KeyChangeLinkType key rollover
KeyChangeLinkType
)
func (l LinkType) String() string {
switch l {
case NewNonceLinkType:
return "new-nonce"
case NewAccountLinkType:
return "new-account"
case AccountLinkType:
return "account"
case NewOrderLinkType:
return "new-order"
case OrderLinkType:
return "order"
case NewAuthzLinkType:
return "new-authz"
case AuthzLinkType:
return "authz"
case ChallengeLinkType:
return "challenge"
case CertificateLinkType:
return "certificate"
case DirectoryLinkType:
return "directory"
case RevokeCertLinkType:
return "revoke-cert"
case KeyChangeLinkType:
return "key-change"
default:
return fmt.Sprintf("unexpected LinkType '%d'", int(l))
}
}
func GetUnescapedPathSuffix(typ LinkType, provisionerName string, inputs ...string) string {
switch typ {
case NewNonceLinkType, NewAccountLinkType, NewOrderLinkType, NewAuthzLinkType, DirectoryLinkType, KeyChangeLinkType, RevokeCertLinkType:
return fmt.Sprintf("/%s/%s", provisionerName, typ)
case AccountLinkType, OrderLinkType, AuthzLinkType, CertificateLinkType:
return fmt.Sprintf("/%s/%s/%s", provisionerName, typ, inputs[0])
case ChallengeLinkType:
return fmt.Sprintf("/%s/%s/%s/%s", provisionerName, typ, inputs[0], inputs[1])
case OrdersByAccountLinkType:
return fmt.Sprintf("/%s/%s/%s/orders", provisionerName, AccountLinkType, inputs[0])
case FinalizeLinkType:
return fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLinkType, inputs[0])
default:
return ""
}
}
// NewLinker returns a new Directory type.
func NewLinker(dns, prefix string) Linker {
_, _, err := net.SplitHostPort(dns)
@ -32,12 +118,11 @@ func NewLinker(dns, prefix string) Linker {
// Linker interface for generating links for ACME resources.
type Linker interface {
GetLink(ctx context.Context, typ LinkType, inputs ...string) string
GetUnescapedPathSuffix(typ LinkType, provName string, inputs ...string) string
LinkOrder(ctx context.Context, o *acme.Order)
LinkAccount(ctx context.Context, o *acme.Account)
LinkChallenge(ctx context.Context, o *acme.Challenge, azID string)
LinkAuthorization(ctx context.Context, o *acme.Authorization)
Middleware(http.Handler) http.Handler
LinkOrder(ctx context.Context, o *Order)
LinkAccount(ctx context.Context, o *Account)
LinkChallenge(ctx context.Context, o *Challenge, azID string)
LinkAuthorization(ctx context.Context, o *Authorization)
LinkOrdersByAccountID(ctx context.Context, orders []string)
}
@ -64,127 +149,81 @@ func MustLinkerFromContext(ctx context.Context) Linker {
}
}
type baseURLKey struct{}
func newBaseURLContext(ctx context.Context, r *http.Request) context.Context {
var u *url.URL
if r.Host != "" {
u = &url.URL{Scheme: "https", Host: r.Host}
}
return context.WithValue(ctx, baseURLKey{}, u)
}
func baseURLFromContext(ctx context.Context) *url.URL {
if u, ok := ctx.Value(baseURLKey{}).(*url.URL); ok {
return u
}
return nil
}
// linker generates ACME links.
type linker struct {
prefix string
dns string
}
func (l *linker) GetUnescapedPathSuffix(typ LinkType, provisionerName string, inputs ...string) string {
switch typ {
case NewNonceLinkType, NewAccountLinkType, NewOrderLinkType, NewAuthzLinkType, DirectoryLinkType, KeyChangeLinkType, RevokeCertLinkType:
return fmt.Sprintf("/%s/%s", provisionerName, typ)
case AccountLinkType, OrderLinkType, AuthzLinkType, CertificateLinkType:
return fmt.Sprintf("/%s/%s/%s", provisionerName, typ, inputs[0])
case ChallengeLinkType:
return fmt.Sprintf("/%s/%s/%s/%s", provisionerName, typ, inputs[0], inputs[1])
case OrdersByAccountLinkType:
return fmt.Sprintf("/%s/%s/%s/orders", provisionerName, AccountLinkType, inputs[0])
case FinalizeLinkType:
return fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLinkType, inputs[0])
default:
return ""
}
// Middleware gets the provisioner and current url from the request and sets
// them in the context so we can use the linker to create ACME links.
func (l *linker) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Add base url to the context.
ctx := newBaseURLContext(r.Context(), r)
// Add provisioner to the context.
nameEscaped := chi.URLParam(r, "provisionerID")
name, err := url.PathUnescape(nameEscaped)
if err != nil {
render.Error(w, WrapErrorISE(err, "error url unescaping provisioner name '%s'", nameEscaped))
return
}
p, err := authority.MustFromContext(ctx).LoadProvisionerByName(name)
if err != nil {
render.Error(w, err)
return
}
acmeProv, ok := p.(*provisioner.ACME)
if !ok {
render.Error(w, NewError(ErrorAccountDoesNotExistType, "provisioner must be of type ACME"))
return
}
ctx = NewProvisionerContext(ctx, Provisioner(acmeProv))
next.ServeHTTP(w, r.WithContext(ctx))
})
}
// GetLink is a helper for GetLinkExplicit
// GetLink is a helper for GetLinkExplicit.
func (l *linker) GetLink(ctx context.Context, typ LinkType, inputs ...string) string {
var (
provName string
baseURL = baseURLFromContext(ctx)
u = url.URL{}
)
if p, err := provisionerFromContext(ctx); err == nil && p != nil {
provName = p.GetName()
}
// Copy the baseURL value from the pointer. https://github.com/golang/go/issues/38351
if baseURL != nil {
var u url.URL
if baseURL := baseURLFromContext(ctx); baseURL != nil {
u = *baseURL
}
u.Path = l.GetUnescapedPathSuffix(typ, provName, inputs...)
// If no Scheme is set, then default to https.
if u.Scheme == "" {
u.Scheme = "https"
}
// If no Host is set, then use the default (first DNS attr in the ca.json).
if u.Host == "" {
u.Host = l.dns
}
u.Path = l.prefix + u.Path
p := MustProvisionerFromContext(ctx)
u.Path = l.prefix + GetUnescapedPathSuffix(typ, p.GetName(), inputs...)
return u.String()
}
// LinkType captures the link type.
type LinkType int
const (
// NewNonceLinkType new-nonce
NewNonceLinkType LinkType = iota
// NewAccountLinkType new-account
NewAccountLinkType
// AccountLinkType account
AccountLinkType
// OrderLinkType order
OrderLinkType
// NewOrderLinkType new-order
NewOrderLinkType
// OrdersByAccountLinkType list of orders owned by account
OrdersByAccountLinkType
// FinalizeLinkType finalize order
FinalizeLinkType
// NewAuthzLinkType authz
NewAuthzLinkType
// AuthzLinkType new-authz
AuthzLinkType
// ChallengeLinkType challenge
ChallengeLinkType
// CertificateLinkType certificate
CertificateLinkType
// DirectoryLinkType directory
DirectoryLinkType
// RevokeCertLinkType revoke certificate
RevokeCertLinkType
// KeyChangeLinkType key rollover
KeyChangeLinkType
)
func (l LinkType) String() string {
switch l {
case NewNonceLinkType:
return "new-nonce"
case NewAccountLinkType:
return "new-account"
case AccountLinkType:
return "account"
case NewOrderLinkType:
return "new-order"
case OrderLinkType:
return "order"
case NewAuthzLinkType:
return "new-authz"
case AuthzLinkType:
return "authz"
case ChallengeLinkType:
return "challenge"
case CertificateLinkType:
return "certificate"
case DirectoryLinkType:
return "directory"
case RevokeCertLinkType:
return "revoke-cert"
case KeyChangeLinkType:
return "key-change"
default:
return fmt.Sprintf("unexpected LinkType '%d'", int(l))
}
}
// LinkOrder sets the ACME links required by an ACME order.
func (l *linker) LinkOrder(ctx context.Context, o *acme.Order) {
func (l *linker) LinkOrder(ctx context.Context, o *Order) {
o.AuthorizationURLs = make([]string, len(o.AuthorizationIDs))
for i, azID := range o.AuthorizationIDs {
o.AuthorizationURLs[i] = l.GetLink(ctx, AuthzLinkType, azID)
@ -196,17 +235,17 @@ func (l *linker) LinkOrder(ctx context.Context, o *acme.Order) {
}
// LinkAccount sets the ACME links required by an ACME account.
func (l *linker) LinkAccount(ctx context.Context, acc *acme.Account) {
func (l *linker) LinkAccount(ctx context.Context, acc *Account) {
acc.OrdersURL = l.GetLink(ctx, OrdersByAccountLinkType, acc.ID)
}
// LinkChallenge sets the ACME links required by an ACME challenge.
func (l *linker) LinkChallenge(ctx context.Context, ch *acme.Challenge, azID string) {
func (l *linker) LinkChallenge(ctx context.Context, ch *Challenge, azID string) {
ch.URL = l.GetLink(ctx, ChallengeLinkType, azID, ch.ID)
}
// LinkAuthorization sets the ACME links required by an ACME authorization.
func (l *linker) LinkAuthorization(ctx context.Context, az *acme.Authorization) {
func (l *linker) LinkAuthorization(ctx context.Context, az *Authorization) {
for _, ch := range az.Challenges {
l.LinkChallenge(ctx, ch, az.ID)
}

@ -7,7 +7,6 @@ import (
"testing"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/acme"
)
func TestLinker_GetUnescapedPathSuffix(t *testing.T) {
@ -173,27 +172,27 @@ func TestLinker_LinkOrder(t *testing.T) {
linkerPrefix := "acme"
l := NewLinker("dns", linkerPrefix)
type test struct {
o *acme.Order
validate func(o *acme.Order)
o *Order
validate func(o *Order)
}
var tests = map[string]test{
"no-authz-and-no-cert": {
o: &acme.Order{
o: &Order{
ID: oid,
},
validate: func(o *acme.Order) {
validate: func(o *Order) {
assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid))
assert.Equals(t, o.AuthorizationURLs, []string{})
assert.Equals(t, o.CertificateURL, "")
},
},
"one-authz-and-cert": {
o: &acme.Order{
o: &Order{
ID: oid,
CertificateID: certID,
AuthorizationIDs: []string{"foo"},
},
validate: func(o *acme.Order) {
validate: func(o *Order) {
assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid))
assert.Equals(t, o.AuthorizationURLs, []string{
fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "foo"),
@ -202,12 +201,12 @@ func TestLinker_LinkOrder(t *testing.T) {
},
},
"many-authz": {
o: &acme.Order{
o: &Order{
ID: oid,
CertificateID: certID,
AuthorizationIDs: []string{"foo", "bar", "zap"},
},
validate: func(o *acme.Order) {
validate: func(o *Order) {
assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid))
assert.Equals(t, o.AuthorizationURLs, []string{
fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "foo"),
@ -237,15 +236,15 @@ func TestLinker_LinkAccount(t *testing.T) {
linkerPrefix := "acme"
l := NewLinker("dns", linkerPrefix)
type test struct {
a *acme.Account
validate func(o *acme.Account)
a *Account
validate func(o *Account)
}
var tests = map[string]test{
"ok": {
a: &acme.Account{
a: &Account{
ID: accID,
},
validate: func(a *acme.Account) {
validate: func(a *Account) {
assert.Equals(t, a.OrdersURL, fmt.Sprintf("%s/%s/%s/account/%s/orders", baseURL, linkerPrefix, provName, accID))
},
},
@ -270,15 +269,15 @@ func TestLinker_LinkChallenge(t *testing.T) {
linkerPrefix := "acme"
l := NewLinker("dns", linkerPrefix)
type test struct {
ch *acme.Challenge
validate func(o *acme.Challenge)
ch *Challenge
validate func(o *Challenge)
}
var tests = map[string]test{
"ok": {
ch: &acme.Challenge{
ch: &Challenge{
ID: chID,
},
validate: func(ch *acme.Challenge) {
validate: func(ch *Challenge) {
assert.Equals(t, ch.URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, azID, ch.ID))
},
},
@ -305,20 +304,20 @@ func TestLinker_LinkAuthorization(t *testing.T) {
linkerPrefix := "acme"
l := NewLinker("dns", linkerPrefix)
type test struct {
az *acme.Authorization
validate func(o *acme.Authorization)
az *Authorization
validate func(o *Authorization)
}
var tests = map[string]test{
"ok": {
az: &acme.Authorization{
az: &Authorization{
ID: azID,
Challenges: []*acme.Challenge{
Challenges: []*Challenge{
{ID: chID0},
{ID: chID1},
{ID: chID2},
},
},
validate: func(az *acme.Authorization) {
validate: func(az *Authorization) {
assert.Equals(t, az.Challenges[0].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID0))
assert.Equals(t, az.Challenges[1].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID1))
assert.Equals(t, az.Challenges[2].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID2))

@ -189,30 +189,24 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) {
dns = fmt.Sprintf("%s:%s", dns, port)
}
// ACME Router
prefix := "acme"
// ACME Router is only available if we have a database.
var acmeDB acme.DB
if cfg.DB == nil {
acmeDB = nil
} else {
var acmeLinker acme.Linker
if cfg.DB != nil {
acmeDB, err = acmeNoSQL.New(auth.GetDatabase().(nosql.DB))
if err != nil {
return nil, errors.Wrap(err, "error configuring ACME DB interface")
}
acmeLinker = acme.NewLinker(dns, "acme")
mux.Route("/acme", func(r chi.Router) {
acmeAPI.Route(r)
})
// Use 2.0 because, at the moment, our ACME api is only compatible with v2.0
// of the ACME spec.
mux.Route("/2.0/acme", func(r chi.Router) {
acmeAPI.Route(r)
})
}
acmeOptions := &acmeAPI.HandlerOptions{
Backdate: *cfg.AuthorityConfig.Backdate,
DNS: dns,
Prefix: prefix,
}
mux.Route("/"+prefix, func(r chi.Router) {
acmeAPI.Route(r, acmeOptions)
})
// Use 2.0 because, at the moment, our ACME api is only compatible with v2.0
// of the ACME spec.
mux.Route("/2.0/"+prefix, func(r chi.Router) {
acmeAPI.Route(r, acmeOptions)
})
// Admin API Router
if cfg.AuthorityConfig.EnableAdmin {
@ -280,7 +274,7 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) {
}
// Create context with all the necessary values.
baseContext := buildContext(auth, scepAuthority, acmeDB)
baseContext := buildContext(auth, scepAuthority, acmeDB, acmeLinker)
ca.srv = server.New(cfg.Address, handler, tlsConfig)
ca.srv.BaseContext = func(net.Listener) context.Context {
@ -304,7 +298,7 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) {
}
// buildContext builds the server base context.
func buildContext(a *authority.Authority, scepAuthority *scep.Authority, acmeDB acme.DB) context.Context {
func buildContext(a *authority.Authority, scepAuthority *scep.Authority, acmeDB acme.DB, acmeLinker acme.Linker) context.Context {
ctx := authority.NewContext(context.Background(), a)
if authDB := a.GetDatabase(); authDB != nil {
ctx = db.NewContext(ctx, authDB)
@ -316,7 +310,7 @@ func buildContext(a *authority.Authority, scepAuthority *scep.Authority, acmeDB
ctx = scep.NewContext(ctx, scepAuthority)
}
if acmeDB != nil {
ctx = acme.NewContext(ctx, acmeDB)
ctx = acme.NewContext(ctx, acmeDB, acme.NewClient(), acmeLinker, nil)
}
return ctx
}

Loading…
Cancel
Save