diff --git a/acme/api/account.go b/acme/api/account.go index 92c5dbfc..b733c679 100644 --- a/acme/api/account.go +++ b/acme/api/account.go @@ -121,8 +121,7 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { h.linker.LinkAccount(ctx, acc) - w.Header().Set("Location", h.linker.GetLink(r.Context(), AccountLinkType, - true, acc.ID)) + w.Header().Set("Location", h.linker.GetLink(r.Context(), AccountLinkType, acc.ID)) api.JSONStatus(w, acc, httpStatus) } @@ -169,7 +168,7 @@ func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) { h.linker.LinkAccount(ctx, acc) - w.Header().Set("Location", h.linker.GetLink(ctx, AccountLinkType, true, acc.ID)) + w.Header().Set("Location", h.linker.GetLink(ctx, AccountLinkType, acc.ID)) api.JSON(w, acc) } diff --git a/acme/api/handler.go b/acme/api/handler.go index 7d02861e..c1d2d62a 100644 --- a/acme/api/handler.go +++ b/acme/api/handler.go @@ -87,12 +87,12 @@ func NewHandler(ops HandlerOptions) api.RouterHandler { // Route traffic and implement the Router interface. func (h *Handler) Route(r api.Router) { - getLink := h.linker.GetLinkExplicit + getPath := h.linker.GetUnescapedPathSuffix // Standard ACME API - r.MethodFunc("GET", getLink(NewNonceLinkType, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.GetNonce))))) - r.MethodFunc("HEAD", getLink(NewNonceLinkType, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.GetNonce))))) - r.MethodFunc("GET", getLink(DirectoryLinkType, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetDirectory)))) - r.MethodFunc("HEAD", getLink(DirectoryLinkType, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetDirectory)))) + r.MethodFunc("GET", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.GetNonce))))) + r.MethodFunc("HEAD", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.GetNonce))))) + r.MethodFunc("GET", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetDirectory)))) + r.MethodFunc("HEAD", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetDirectory)))) extractPayloadByJWK := func(next nextHTTP) nextHTTP { return h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(h.extractJWK(h.verifyAndExtractJWSPayload(next))))))))) @@ -101,16 +101,16 @@ func (h *Handler) Route(r api.Router) { return h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(h.lookupJWK(h.verifyAndExtractJWSPayload(next))))))))) } - r.MethodFunc("POST", getLink(NewAccountLinkType, "{provisionerID}", false, nil), extractPayloadByJWK(h.NewAccount)) - r.MethodFunc("POST", getLink(AccountLinkType, "{provisionerID}", false, nil, "{accID}"), extractPayloadByKid(h.GetOrUpdateAccount)) - r.MethodFunc("POST", getLink(KeyChangeLinkType, "{provisionerID}", false, nil, "{accID}"), extractPayloadByKid(h.NotImplemented)) - r.MethodFunc("POST", getLink(NewOrderLinkType, "{provisionerID}", false, nil), extractPayloadByKid(h.NewOrder)) - r.MethodFunc("POST", getLink(OrderLinkType, "{provisionerID}", false, nil, "{ordID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrder))) - r.MethodFunc("POST", getLink(OrdersByAccountLinkType, "{provisionerID}", false, nil, "{accID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrdersByAccountID))) - r.MethodFunc("POST", getLink(FinalizeLinkType, "{provisionerID}", false, nil, "{ordID}"), extractPayloadByKid(h.FinalizeOrder)) - r.MethodFunc("POST", getLink(AuthzLinkType, "{provisionerID}", false, nil, "{authzID}"), extractPayloadByKid(h.isPostAsGet(h.GetAuthorization))) - r.MethodFunc("POST", getLink(ChallengeLinkType, "{provisionerID}", false, nil, "{authzID}", "{chID}"), extractPayloadByKid(h.GetChallenge)) - r.MethodFunc("POST", getLink(CertificateLinkType, "{provisionerID}", false, nil, "{certID}"), extractPayloadByKid(h.isPostAsGet(h.GetCertificate))) + r.MethodFunc("POST", getPath(NewAccountLinkType, "{provisionerID}"), extractPayloadByJWK(h.NewAccount)) + r.MethodFunc("POST", getPath(AccountLinkType, "{provisionerID}", "{accID}"), extractPayloadByKid(h.GetOrUpdateAccount)) + r.MethodFunc("POST", getPath(KeyChangeLinkType, "{provisionerID}", "{accID}"), extractPayloadByKid(h.NotImplemented)) + r.MethodFunc("POST", getPath(NewOrderLinkType, "{provisionerID}"), extractPayloadByKid(h.NewOrder)) + r.MethodFunc("POST", getPath(OrderLinkType, "{provisionerID}", "{ordID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrder))) + r.MethodFunc("POST", getPath(OrdersByAccountLinkType, "{provisionerID}", "{accID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrdersByAccountID))) + r.MethodFunc("POST", getPath(FinalizeLinkType, "{provisionerID}", "{ordID}"), extractPayloadByKid(h.FinalizeOrder)) + r.MethodFunc("POST", getPath(AuthzLinkType, "{provisionerID}", "{authzID}"), extractPayloadByKid(h.isPostAsGet(h.GetAuthorization))) + r.MethodFunc("POST", getPath(ChallengeLinkType, "{provisionerID}", "{authzID}", "{chID}"), extractPayloadByKid(h.GetChallenge)) + r.MethodFunc("POST", getPath(CertificateLinkType, "{provisionerID}", "{certID}"), extractPayloadByKid(h.isPostAsGet(h.GetCertificate))) } // GetNonce just sets the right header since a Nonce is added to each response @@ -146,11 +146,11 @@ func (d *Directory) ToLog() (interface{}, error) { func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) { ctx := r.Context() api.JSON(w, &Directory{ - NewNonce: h.linker.GetLink(ctx, NewNonceLinkType, true), - NewAccount: h.linker.GetLink(ctx, NewAccountLinkType, true), - NewOrder: h.linker.GetLink(ctx, NewOrderLinkType, true), - RevokeCert: h.linker.GetLink(ctx, RevokeCertLinkType, true), - KeyChange: h.linker.GetLink(ctx, KeyChangeLinkType, true), + NewNonce: h.linker.GetLink(ctx, NewNonceLinkType), + NewAccount: h.linker.GetLink(ctx, NewAccountLinkType), + NewOrder: h.linker.GetLink(ctx, NewOrderLinkType), + RevokeCert: h.linker.GetLink(ctx, RevokeCertLinkType), + KeyChange: h.linker.GetLink(ctx, KeyChangeLinkType), }) } @@ -185,7 +185,7 @@ func (h *Handler) GetAuthorization(w http.ResponseWriter, r *http.Request) { h.linker.LinkAuthorization(ctx, az) - w.Header().Set("Location", h.linker.GetLink(ctx, AuthzLinkType, true, az.ID)) + w.Header().Set("Location", h.linker.GetLink(ctx, AuthzLinkType, az.ID)) api.JSON(w, az) } @@ -235,8 +235,8 @@ func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) { h.linker.LinkChallenge(ctx, ch, azID) - w.Header().Add("Link", link(h.linker.GetLink(ctx, AuthzLinkType, true, azID), "up")) - w.Header().Set("Location", h.linker.GetLink(ctx, ChallengeLinkType, true, azID, ch.ID)) + w.Header().Add("Link", link(h.linker.GetLink(ctx, AuthzLinkType, azID), "up")) + w.Header().Set("Location", h.linker.GetLink(ctx, ChallengeLinkType, azID, ch.ID)) api.JSON(w, ch) } diff --git a/acme/api/linker.go b/acme/api/linker.go index 702f7433..d4490470 100644 --- a/acme/api/linker.go +++ b/acme/api/linker.go @@ -15,8 +15,8 @@ func NewLinker(dns, prefix string) Linker { // Linker interface for generating links for ACME resources. type Linker interface { - GetLink(ctx context.Context, typ LinkType, abs bool, inputs ...string) string - GetLinkExplicit(typ LinkType, provName string, abs bool, baseURL *url.URL, inputs ...string) string + GetLink(ctx context.Context, typ LinkType, inputs ...string) string + GetUnescapedPathSuffix(typ LinkType, provName string, inputs ...string) string LinkOrder(ctx context.Context, o *acme.Order) LinkAccount(ctx context.Context, o *acme.Account) @@ -31,53 +31,52 @@ type linker struct { dns string } +func (l *linker) GetUnescapedPathSuffix(typ LinkType, provisionerName string, inputs ...string) string { + switch typ { + case NewNonceLinkType, NewAccountLinkType, NewOrderLinkType, NewAuthzLinkType, DirectoryLinkType, KeyChangeLinkType, RevokeCertLinkType: + return fmt.Sprintf("/%s/%s", provisionerName, typ) + case AccountLinkType, OrderLinkType, AuthzLinkType, CertificateLinkType: + return fmt.Sprintf("/%s/%s/%s", provisionerName, typ, inputs[0]) + case ChallengeLinkType: + return fmt.Sprintf("/%s/%s/%s/%s", provisionerName, typ, inputs[0], inputs[1]) + case OrdersByAccountLinkType: + return fmt.Sprintf("/%s/%s/%s/orders", provisionerName, AccountLinkType, inputs[0]) + case FinalizeLinkType: + return fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLinkType, inputs[0]) + default: + return "" + } +} + // GetLink is a helper for GetLinkExplicit -func (l *linker) GetLink(ctx context.Context, typ LinkType, abs bool, inputs ...string) string { - var provName string +func (l *linker) GetLink(ctx context.Context, typ LinkType, inputs ...string) string { + var ( + provName string + baseURL = baseURLFromContext(ctx) + u = url.URL{} + ) if p, err := provisionerFromContext(ctx); err == nil && p != nil { provName = p.GetName() } - return l.GetLinkExplicit(typ, provName, abs, baseURLFromContext(ctx), inputs...) -} - -// GetLinkExplicit returns an absolute or partial path to the given resource and a base -// URL dynamically obtained from the request for which the link is being -// calculated. -func (l *linker) GetLinkExplicit(typ LinkType, provisionerName string, abs bool, baseURL *url.URL, inputs ...string) string { - var u = url.URL{} // Copy the baseURL value from the pointer. https://github.com/golang/go/issues/38351 if baseURL != nil { u = *baseURL } - switch typ { - case NewNonceLinkType, NewAccountLinkType, NewOrderLinkType, NewAuthzLinkType, DirectoryLinkType, KeyChangeLinkType, RevokeCertLinkType: - u.Path = fmt.Sprintf("/%s/%s", provisionerName, typ) - case AccountLinkType, OrderLinkType, AuthzLinkType, CertificateLinkType: - u.Path = fmt.Sprintf("/%s/%s/%s", provisionerName, typ, inputs[0]) - case ChallengeLinkType: - u.Path = fmt.Sprintf("/%s/%s/%s/%s", provisionerName, typ, inputs[0], inputs[1]) - case OrdersByAccountLinkType: - u.Path = fmt.Sprintf("/%s/%s/%s/orders", provisionerName, AccountLinkType, inputs[0]) - case FinalizeLinkType: - u.Path = fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLinkType, inputs[0]) - } - - if abs { - // If no Scheme is set, then default to https. - if u.Scheme == "" { - u.Scheme = "https" - } + u.Path = l.GetUnescapedPathSuffix(typ, provName, inputs...) - // If no Host is set, then use the default (first DNS attr in the ca.json). - if u.Host == "" { - u.Host = l.dns - } + // If no Scheme is set, then default to https. + if u.Scheme == "" { + u.Scheme = "https" + } - u.Path = l.prefix + u.Path - return u.String() + // If no Host is set, then use the default (first DNS attr in the ca.json). + if u.Host == "" { + u.Host = l.dns } - return u.EscapedPath() + + u.Path = l.prefix + u.Path + return u.String() } // LinkType captures the link type. @@ -149,22 +148,22 @@ func (l LinkType) String() string { func (l *linker) LinkOrder(ctx context.Context, o *acme.Order) { o.AuthorizationURLs = make([]string, len(o.AuthorizationIDs)) for i, azID := range o.AuthorizationIDs { - o.AuthorizationURLs[i] = l.GetLink(ctx, AuthzLinkType, true, azID) + o.AuthorizationURLs[i] = l.GetLink(ctx, AuthzLinkType, azID) } - o.FinalizeURL = l.GetLink(ctx, FinalizeLinkType, true, o.ID) + o.FinalizeURL = l.GetLink(ctx, FinalizeLinkType, o.ID) if o.CertificateID != "" { - o.CertificateURL = l.GetLink(ctx, CertificateLinkType, true, o.CertificateID) + o.CertificateURL = l.GetLink(ctx, CertificateLinkType, o.CertificateID) } } // LinkAccount sets the ACME links required by an ACME account. func (l *linker) LinkAccount(ctx context.Context, acc *acme.Account) { - acc.OrdersURL = l.GetLink(ctx, OrdersByAccountLinkType, true, acc.ID) + acc.OrdersURL = l.GetLink(ctx, OrdersByAccountLinkType, acc.ID) } // LinkChallenge sets the ACME links required by an ACME challenge. func (l *linker) LinkChallenge(ctx context.Context, ch *acme.Challenge, azID string) { - ch.URL = l.GetLink(ctx, ChallengeLinkType, true, azID, ch.ID) + ch.URL = l.GetLink(ctx, ChallengeLinkType, azID, ch.ID) } // LinkAuthorization sets the ACME links required by an ACME authorization. @@ -177,6 +176,6 @@ func (l *linker) LinkAuthorization(ctx context.Context, az *acme.Authorization) // LinkOrdersByAccountID converts each order ID to an ACME link. func (l *linker) LinkOrdersByAccountID(ctx context.Context, orders []string) { for i, id := range orders { - orders[i] = l.GetLink(ctx, OrderLinkType, true, id) + orders[i] = l.GetLink(ctx, OrderLinkType, id) } } diff --git a/acme/api/linker_test.go b/acme/api/linker_test.go index 6bb1f739..36e08985 100644 --- a/acme/api/linker_test.go +++ b/acme/api/linker_test.go @@ -10,94 +10,75 @@ import ( "github.com/smallstep/certificates/acme" ) -func TestLinker_GetLink(t *testing.T) { +func TestLinker_GetUnescapedPathSuffix(t *testing.T) { dns := "ca.smallstep.com" prefix := "acme" linker := NewLinker(dns, prefix) - id := "1234" - - prov := newProv() - provName := url.PathEscape(prov.GetName()) - baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - assert.Equals(t, linker.GetLink(ctx, NewNonceLinkType, true), - fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName)) - assert.Equals(t, linker.GetLink(ctx, NewNonceLinkType, false), fmt.Sprintf("/%s/new-nonce", provName)) - - // No provisioner - ctxNoProv := context.WithValue(context.Background(), baseURLContextKey, baseURL) - assert.Equals(t, linker.GetLink(ctxNoProv, NewNonceLinkType, true), - fmt.Sprintf("%s/acme//new-nonce", baseURL.String())) - assert.Equals(t, linker.GetLink(ctxNoProv, NewNonceLinkType, false), "//new-nonce") - - // No baseURL - ctxNoBaseURL := context.WithValue(context.Background(), provisionerContextKey, prov) - assert.Equals(t, linker.GetLink(ctxNoBaseURL, NewNonceLinkType, true), - fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", provName)) - assert.Equals(t, linker.GetLink(ctxNoBaseURL, NewNonceLinkType, false), fmt.Sprintf("/%s/new-nonce", provName)) - - assert.Equals(t, linker.GetLink(ctx, OrderLinkType, true, id), - fmt.Sprintf("%s/acme/%s/order/1234", baseURL.String(), provName)) - assert.Equals(t, linker.GetLink(ctx, OrderLinkType, false, id), fmt.Sprintf("/%s/order/1234", provName)) + getPath := linker.GetUnescapedPathSuffix + + assert.Equals(t, getPath(NewNonceLinkType, "{provisionerID}"), fmt.Sprintf("/{provisionerID}/new-nonce")) + assert.Equals(t, getPath(DirectoryLinkType, "{provisionerID}"), fmt.Sprintf("/{provisionerID}/directory")) + assert.Equals(t, getPath(NewAccountLinkType, "{provisionerID}"), fmt.Sprintf("/{provisionerID}/new-account")) + assert.Equals(t, getPath(AccountLinkType, "{provisionerID}", "{accID}"), fmt.Sprintf("/{provisionerID}/account/{accID}")) + assert.Equals(t, getPath(KeyChangeLinkType, "{provisionerID}"), fmt.Sprintf("/{provisionerID}/key-change")) + assert.Equals(t, getPath(NewOrderLinkType, "{provisionerID}"), fmt.Sprintf("/{provisionerID}/new-order")) + assert.Equals(t, getPath(OrderLinkType, "{provisionerID}", "{ordID}"), fmt.Sprintf("/{provisionerID}/order/{ordID}")) + assert.Equals(t, getPath(OrdersByAccountLinkType, "{provisionerID}", "{accID}"), fmt.Sprintf("/{provisionerID}/account/{accID}/orders")) + assert.Equals(t, getPath(FinalizeLinkType, "{provisionerID}", "{ordID}"), fmt.Sprintf("/{provisionerID}/order/{ordID}/finalize")) + assert.Equals(t, getPath(AuthzLinkType, "{provisionerID}", "{authzID}"), fmt.Sprintf("/{provisionerID}/authz/{authzID}")) + assert.Equals(t, getPath(ChallengeLinkType, "{provisionerID}", "{authzID}", "{chID}"), fmt.Sprintf("/{provisionerID}/challenge/{authzID}/{chID}")) + assert.Equals(t, getPath(CertificateLinkType, "{provisionerID}", "{certID}"), fmt.Sprintf("/{provisionerID}/certificate/{certID}")) } -func TestLinker_GetLinkExplicit(t *testing.T) { +func TestLinker_GetLink(t *testing.T) { dns := "ca.smallstep.com" - baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} prefix := "acme" linker := NewLinker(dns, prefix) id := "1234" prov := newProv() - provName := prov.GetName() - escProvName := url.PathEscape(provName) + escProvName := url.PathEscape(prov.GetName()) + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + + // No provisioner and no BaseURL from request + assert.Equals(t, linker.GetLink(context.Background(), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", "")) + // Provisioner: yes, BaseURL: no + assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), provisionerContextKey, prov), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", escProvName)) + + // Provisioner: no, BaseURL: yes + assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), baseURLContextKey, baseURL), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://test.ca.smallstep.com", "")) - assert.Equals(t, linker.GetLinkExplicit(NewNonceLinkType, provName, true, nil), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", escProvName)) - assert.Equals(t, linker.GetLinkExplicit(NewNonceLinkType, provName, true, &url.URL{}), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", escProvName)) - assert.Equals(t, linker.GetLinkExplicit(NewNonceLinkType, provName, true, &url.URL{Scheme: "http"}), fmt.Sprintf("%s/acme/%s/new-nonce", "http://ca.smallstep.com", escProvName)) - assert.Equals(t, linker.GetLinkExplicit(NewNonceLinkType, provName, true, baseURL), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, escProvName)) - assert.Equals(t, linker.GetLinkExplicit(NewNonceLinkType, provName, false, baseURL), fmt.Sprintf("/%s/new-nonce", escProvName)) + assert.Equals(t, linker.GetLink(ctx, NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, escProvName)) + assert.Equals(t, linker.GetLink(ctx, NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, escProvName)) - assert.Equals(t, linker.GetLinkExplicit(NewAccountLinkType, provName, true, baseURL), fmt.Sprintf("%s/acme/%s/new-account", baseURL, escProvName)) - assert.Equals(t, linker.GetLinkExplicit(NewAccountLinkType, provName, false, baseURL), fmt.Sprintf("/%s/new-account", escProvName)) + assert.Equals(t, linker.GetLink(ctx, NewAccountLinkType), fmt.Sprintf("%s/acme/%s/new-account", baseURL, escProvName)) - assert.Equals(t, linker.GetLinkExplicit(AccountLinkType, provName, true, baseURL, id), fmt.Sprintf("%s/acme/%s/account/1234", baseURL, escProvName)) - assert.Equals(t, linker.GetLinkExplicit(AccountLinkType, provName, false, baseURL, id), fmt.Sprintf("/%s/account/1234", escProvName)) + assert.Equals(t, linker.GetLink(ctx, AccountLinkType, id), fmt.Sprintf("%s/acme/%s/account/1234", baseURL, escProvName)) - assert.Equals(t, linker.GetLinkExplicit(NewOrderLinkType, provName, true, baseURL), fmt.Sprintf("%s/acme/%s/new-order", baseURL, escProvName)) - assert.Equals(t, linker.GetLinkExplicit(NewOrderLinkType, provName, false, baseURL), fmt.Sprintf("/%s/new-order", escProvName)) + assert.Equals(t, linker.GetLink(ctx, NewOrderLinkType), fmt.Sprintf("%s/acme/%s/new-order", baseURL, escProvName)) - assert.Equals(t, linker.GetLinkExplicit(OrderLinkType, provName, true, baseURL, id), fmt.Sprintf("%s/acme/%s/order/1234", baseURL, escProvName)) - assert.Equals(t, linker.GetLinkExplicit(OrderLinkType, provName, false, baseURL, id), fmt.Sprintf("/%s/order/1234", escProvName)) + assert.Equals(t, linker.GetLink(ctx, OrderLinkType, id), fmt.Sprintf("%s/acme/%s/order/1234", baseURL, escProvName)) - assert.Equals(t, linker.GetLinkExplicit(OrdersByAccountLinkType, provName, true, baseURL, id), fmt.Sprintf("%s/acme/%s/account/1234/orders", baseURL, escProvName)) - assert.Equals(t, linker.GetLinkExplicit(OrdersByAccountLinkType, provName, false, baseURL, id), fmt.Sprintf("/%s/account/1234/orders", escProvName)) + assert.Equals(t, linker.GetLink(ctx, OrdersByAccountLinkType, id), fmt.Sprintf("%s/acme/%s/account/1234/orders", baseURL, escProvName)) - assert.Equals(t, linker.GetLinkExplicit(FinalizeLinkType, provName, true, baseURL, id), fmt.Sprintf("%s/acme/%s/order/1234/finalize", baseURL, escProvName)) - assert.Equals(t, linker.GetLinkExplicit(FinalizeLinkType, provName, false, baseURL, id), fmt.Sprintf("/%s/order/1234/finalize", escProvName)) + assert.Equals(t, linker.GetLink(ctx, FinalizeLinkType, id), fmt.Sprintf("%s/acme/%s/order/1234/finalize", baseURL, escProvName)) - assert.Equals(t, linker.GetLinkExplicit(NewAuthzLinkType, provName, true, baseURL), fmt.Sprintf("%s/acme/%s/new-authz", baseURL, escProvName)) - assert.Equals(t, linker.GetLinkExplicit(NewAuthzLinkType, provName, false, baseURL), fmt.Sprintf("/%s/new-authz", escProvName)) + assert.Equals(t, linker.GetLink(ctx, NewAuthzLinkType), fmt.Sprintf("%s/acme/%s/new-authz", baseURL, escProvName)) - assert.Equals(t, linker.GetLinkExplicit(AuthzLinkType, provName, true, baseURL, id), fmt.Sprintf("%s/acme/%s/authz/1234", baseURL, escProvName)) - assert.Equals(t, linker.GetLinkExplicit(AuthzLinkType, provName, false, baseURL, id), fmt.Sprintf("/%s/authz/1234", escProvName)) + assert.Equals(t, linker.GetLink(ctx, AuthzLinkType, id), fmt.Sprintf("%s/acme/%s/authz/1234", baseURL, escProvName)) - assert.Equals(t, linker.GetLinkExplicit(DirectoryLinkType, provName, true, baseURL), fmt.Sprintf("%s/acme/%s/directory", baseURL, escProvName)) - assert.Equals(t, linker.GetLinkExplicit(DirectoryLinkType, provName, false, baseURL), fmt.Sprintf("/%s/directory", escProvName)) + assert.Equals(t, linker.GetLink(ctx, DirectoryLinkType), fmt.Sprintf("%s/acme/%s/directory", baseURL, escProvName)) - assert.Equals(t, linker.GetLinkExplicit(RevokeCertLinkType, provName, true, baseURL, id), fmt.Sprintf("%s/acme/%s/revoke-cert", baseURL, escProvName)) - assert.Equals(t, linker.GetLinkExplicit(RevokeCertLinkType, provName, false, baseURL), fmt.Sprintf("/%s/revoke-cert", escProvName)) + assert.Equals(t, linker.GetLink(ctx, RevokeCertLinkType, id), fmt.Sprintf("%s/acme/%s/revoke-cert", baseURL, escProvName)) - assert.Equals(t, linker.GetLinkExplicit(KeyChangeLinkType, provName, true, baseURL), fmt.Sprintf("%s/acme/%s/key-change", baseURL, escProvName)) - assert.Equals(t, linker.GetLinkExplicit(KeyChangeLinkType, provName, false, baseURL), fmt.Sprintf("/%s/key-change", escProvName)) + assert.Equals(t, linker.GetLink(ctx, KeyChangeLinkType), fmt.Sprintf("%s/acme/%s/key-change", baseURL, escProvName)) - assert.Equals(t, linker.GetLinkExplicit(ChallengeLinkType, provName, true, baseURL, id, id), fmt.Sprintf("%s/acme/%s/challenge/%s/%s", baseURL, escProvName, id, id)) - assert.Equals(t, linker.GetLinkExplicit(ChallengeLinkType, provName, false, baseURL, id, id), fmt.Sprintf("/%s/challenge/%s/%s", escProvName, id, id)) + assert.Equals(t, linker.GetLink(ctx, ChallengeLinkType, id, id), fmt.Sprintf("%s/acme/%s/challenge/%s/%s", baseURL, escProvName, id, id)) - assert.Equals(t, linker.GetLinkExplicit(CertificateLinkType, provName, true, baseURL, id), fmt.Sprintf("%s/acme/%s/certificate/1234", baseURL, escProvName)) - assert.Equals(t, linker.GetLinkExplicit(CertificateLinkType, provName, false, baseURL, id), fmt.Sprintf("/%s/certificate/1234", escProvName)) + assert.Equals(t, linker.GetLink(ctx, CertificateLinkType, id), fmt.Sprintf("%s/acme/%s/certificate/1234", baseURL, escProvName)) } func TestLinker_LinkOrder(t *testing.T) { diff --git a/acme/api/middleware.go b/acme/api/middleware.go index 861876a9..c33aaeda 100644 --- a/acme/api/middleware.go +++ b/acme/api/middleware.go @@ -78,8 +78,7 @@ func (h *Handler) addNonce(next nextHTTP) nextHTTP { // directory index url. func (h *Handler) addDirLink(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { - w.Header().Add("Link", link(h.linker.GetLink(r.Context(), - DirectoryLinkType, true), "index")) + w.Header().Add("Link", link(h.linker.GetLink(r.Context(), DirectoryLinkType), "index")) next(w, r) } } @@ -88,15 +87,23 @@ func (h *Handler) addDirLink(next nextHTTP) nextHTTP { // application/jose+json. func (h *Handler) verifyContentType(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { - ct := r.Header.Get("Content-Type") - var expected []string - if strings.Contains(r.URL.String(), h.linker.GetLink(r.Context(), CertificateLinkType, false, "")) { + var ( + expected []string + provName string + ) + if p, err := provisionerFromContext(r.Context()); err == nil && p != nil { + provName = p.GetName() + } + u := url.URL{Path: h.linker.GetUnescapedPathSuffix(CertificateLinkType, provName, "")} + if strings.Contains(r.URL.String(), u.EscapedPath()) { // GET /certificate requests allow a greater range of content types. expected = []string{"application/jose+json", "application/pkix-cert", "application/pkcs7-mime"} } else { // By default every request should have content-type applictaion/jose+json. expected = []string{"application/jose+json"} } + + ct := r.Header.Get("Content-Type") for _, e := range expected { if ct == e { next(w, r) @@ -314,7 +321,7 @@ func (h *Handler) lookupJWK(next nextHTTP) nextHTTP { return } - kidPrefix := h.linker.GetLink(ctx, AccountLinkType, true, "") + kidPrefix := h.linker.GetLink(ctx, AccountLinkType, "") kid := jws.Signatures[0].Protected.KeyID if !strings.HasPrefix(kid, kidPrefix) { api.WriteError(w, acme.NewError(acme.ErrorMalformedType, diff --git a/acme/api/order.go b/acme/api/order.go index e7a913ab..9d410173 100644 --- a/acme/api/order.go +++ b/acme/api/order.go @@ -136,7 +136,7 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { h.linker.LinkOrder(ctx, o) - w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, true, o.ID)) + w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID)) api.JSONStatus(w, o, http.StatusCreated) } @@ -217,7 +217,7 @@ func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) { h.linker.LinkOrder(ctx, o) - w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, true, o.ID)) + w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID)) api.JSON(w, o) } @@ -272,6 +272,6 @@ func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) { h.linker.LinkOrder(ctx, o) - w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, true, o.ID)) + w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID)) api.JSON(w, o) } diff --git a/ca/ca.go b/ca/ca.go index 7c723c73..356a1f6f 100644 --- a/ca/ca.go +++ b/ca/ca.go @@ -167,16 +167,14 @@ func (ca *CA) Init(config *authority.Config) (*CA, error) { acmeHandler.Route(r) }) - /* - // helpful routine for logging all routes // - walkFunc := func(method string, route string, handler http.Handler, middlewares ...func(http.Handler) http.Handler) error { - fmt.Printf("%s %s\n", method, route) - return nil - } - if err := chi.Walk(mux, walkFunc); err != nil { - fmt.Printf("Logging err: %s\n", err.Error()) - } - */ + // helpful routine for logging all routes // + walkFunc := func(method string, route string, handler http.Handler, middlewares ...func(http.Handler) http.Handler) error { + fmt.Printf("%s %s\n", method, route) + return nil + } + if err := chi.Walk(mux, walkFunc); err != nil { + fmt.Printf("Logging err: %s\n", err.Error()) + } // Add monitoring if configured if len(config.Monitoring) > 0 {