diff --git a/acme/api/handler.go b/acme/api/handler.go index 04680656..4b916404 100644 --- a/acme/api/handler.go +++ b/acme/api/handler.go @@ -84,7 +84,7 @@ func newOptionsContext(ctx context.Context, o *HandlerOptions) context.Context { func optionsFromContext(ctx context.Context) *HandlerOptions { o, ok := ctx.Value(optionsKey{}).(*HandlerOptions) if !ok { - panic("handler options are not in the context") + panic("acme options are not in the context") } return o } diff --git a/scep/api/api.go b/scep/api/api.go index 31f0f10d..0d62904d 100644 --- a/scep/api/api.go +++ b/scep/api/api.go @@ -48,29 +48,32 @@ type response struct { } // handler is the SCEP request handler. -type handler struct { - auth *scep.Authority +type handler struct{} + +// Route traffic and implement the Router interface. +// +// Deprecated: use scep.Route(r api.Router) +func (h *handler) Route(r api.Router) { + Route(r) } // New returns a new SCEP API router. +// +// Deprecated: use scep.Route(r api.Router) func New(auth *scep.Authority) api.RouterHandler { - return &handler{ - auth: auth, - } + return &handler{} } // Route traffic and implement the Router interface. -func (h *handler) Route(r api.Router) { - getLink := h.auth.GetLinkExplicit - r.MethodFunc(http.MethodGet, getLink("{provisionerName}/*", false, nil), h.lookupProvisioner(h.Get)) - r.MethodFunc(http.MethodGet, getLink("{provisionerName}", false, nil), h.lookupProvisioner(h.Get)) - r.MethodFunc(http.MethodPost, getLink("{provisionerName}/*", false, nil), h.lookupProvisioner(h.Post)) - r.MethodFunc(http.MethodPost, getLink("{provisionerName}", false, nil), h.lookupProvisioner(h.Post)) +func Route(r api.Router) { + r.MethodFunc(http.MethodGet, "/{provisionerName}/*", lookupProvisioner(Get)) + r.MethodFunc(http.MethodGet, "/{provisionerName}", lookupProvisioner(Get)) + r.MethodFunc(http.MethodPost, "/{provisionerName}/*", lookupProvisioner(Post)) + r.MethodFunc(http.MethodPost, "/{provisionerName}", lookupProvisioner(Post)) } // Get handles all SCEP GET requests -func (h *handler) Get(w http.ResponseWriter, r *http.Request) { - +func Get(w http.ResponseWriter, r *http.Request) { req, err := decodeRequest(r) if err != nil { fail(w, fmt.Errorf("invalid scep get request: %w", err)) @@ -82,9 +85,9 @@ func (h *handler) Get(w http.ResponseWriter, r *http.Request) { switch req.Operation { case opnGetCACert: - res, err = h.GetCACert(ctx) + res, err = GetCACert(ctx) case opnGetCACaps: - res, err = h.GetCACaps(ctx) + res, err = GetCACaps(ctx) case opnPKIOperation: // TODO: implement the GET for PKI operation? Default CACAPS doesn't specify this is in use, though default: @@ -100,20 +103,17 @@ func (h *handler) Get(w http.ResponseWriter, r *http.Request) { } // Post handles all SCEP POST requests -func (h *handler) Post(w http.ResponseWriter, r *http.Request) { - +func Post(w http.ResponseWriter, r *http.Request) { req, err := decodeRequest(r) if err != nil { fail(w, fmt.Errorf("invalid scep post request: %w", err)) return } - ctx := r.Context() var res response - switch req.Operation { case opnPKIOperation: - res, err = h.PKIOperation(ctx, req) + res, err = PKIOperation(r.Context(), req) default: err = fmt.Errorf("unknown operation: %s", req.Operation) } @@ -127,7 +127,6 @@ func (h *handler) Post(w http.ResponseWriter, r *http.Request) { } func decodeRequest(r *http.Request) (request, error) { - defer r.Body.Close() method := r.Method @@ -179,9 +178,8 @@ func decodeRequest(r *http.Request) (request, error) { // lookupProvisioner loads the provisioner associated with the request. // Responds 404 if the provisioner does not exist. -func (h *handler) lookupProvisioner(next http.HandlerFunc) http.HandlerFunc { +func lookupProvisioner(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - name := chi.URLParam(r, "provisionerName") provisionerName, err := url.PathUnescape(name) if err != nil { @@ -189,7 +187,9 @@ func (h *handler) lookupProvisioner(next http.HandlerFunc) http.HandlerFunc { return } - p, err := h.auth.LoadProvisionerByName(provisionerName) + ctx := r.Context() + auth := scep.MustFromContext(ctx) + p, err := auth.LoadProvisionerByName(provisionerName) if err != nil { fail(w, err) return @@ -201,16 +201,15 @@ func (h *handler) lookupProvisioner(next http.HandlerFunc) http.HandlerFunc { return } - ctx := r.Context() ctx = context.WithValue(ctx, scep.ProvisionerContextKey, scep.Provisioner(prov)) next(w, r.WithContext(ctx)) } } // GetCACert returns the CA certificates in a SCEP response -func (h *handler) GetCACert(ctx context.Context) (response, error) { - - certs, err := h.auth.GetCACertificates(ctx) +func GetCACert(ctx context.Context) (response, error) { + auth := scep.MustFromContext(ctx) + certs, err := auth.GetCACertificates(ctx) if err != nil { return response{}, err } @@ -241,9 +240,9 @@ func (h *handler) GetCACert(ctx context.Context) (response, error) { } // GetCACaps returns the CA capabilities in a SCEP response -func (h *handler) GetCACaps(ctx context.Context) (response, error) { - - caps := h.auth.GetCACaps(ctx) +func GetCACaps(ctx context.Context) (response, error) { + auth := scep.MustFromContext(ctx) + caps := auth.GetCACaps(ctx) res := response{ Operation: opnGetCACaps, @@ -254,8 +253,7 @@ func (h *handler) GetCACaps(ctx context.Context) (response, error) { } // PKIOperation performs PKI operations and returns a SCEP response -func (h *handler) PKIOperation(ctx context.Context, req request) (response, error) { - +func PKIOperation(ctx context.Context, req request) (response, error) { // parse the message using microscep implementation microMsg, err := microscep.ParsePKIMessage(req.Message) if err != nil { @@ -280,7 +278,8 @@ func (h *handler) PKIOperation(ctx context.Context, req request) (response, erro P7: p7, } - if err := h.auth.DecryptPKIEnvelope(ctx, msg); err != nil { + auth := scep.MustFromContext(ctx) + if err := auth.DecryptPKIEnvelope(ctx, msg); err != nil { return response{}, err } @@ -293,13 +292,13 @@ func (h *handler) PKIOperation(ctx context.Context, req request) (response, erro // a certificate exists; then it will use RenewalReq. Adding the challenge check here may be a small breaking change for clients. // We'll have to see how it works out. if msg.MessageType == microscep.PKCSReq || msg.MessageType == microscep.RenewalReq { - challengeMatches, err := h.auth.MatchChallengePassword(ctx, msg.CSRReqMessage.ChallengePassword) + challengeMatches, err := auth.MatchChallengePassword(ctx, msg.CSRReqMessage.ChallengePassword) if err != nil { - return h.createFailureResponse(ctx, csr, msg, microscep.BadRequest, errors.New("error when checking password")) + return createFailureResponse(ctx, csr, msg, microscep.BadRequest, errors.New("error when checking password")) } if !challengeMatches { // TODO: can this be returned safely to the client? In the end, if the password was correct, that gains a bit of info too. - return h.createFailureResponse(ctx, csr, msg, microscep.BadRequest, errors.New("wrong password provided")) + return createFailureResponse(ctx, csr, msg, microscep.BadRequest, errors.New("wrong password provided")) } } @@ -311,9 +310,9 @@ func (h *handler) PKIOperation(ctx context.Context, req request) (response, erro // Authentication by the (self-signed) certificate with an optional challenge is required; supporting renewals incl. verification // of the client cert is not. - certRep, err := h.auth.SignCSR(ctx, csr, msg) + certRep, err := auth.SignCSR(ctx, csr, msg) if err != nil { - return h.createFailureResponse(ctx, csr, msg, microscep.BadRequest, fmt.Errorf("error when signing new certificate: %w", err)) + return createFailureResponse(ctx, csr, msg, microscep.BadRequest, fmt.Errorf("error when signing new certificate: %w", err)) } res := response{ @@ -350,8 +349,9 @@ func fail(w http.ResponseWriter, err error) { http.Error(w, err.Error(), http.StatusInternalServerError) } -func (h *handler) createFailureResponse(ctx context.Context, csr *x509.CertificateRequest, msg *scep.PKIMessage, info microscep.FailInfo, failError error) (response, error) { - certRepMsg, err := h.auth.CreateFailureResponse(ctx, csr, msg, scep.FailInfoName(info), failError.Error()) +func createFailureResponse(ctx context.Context, csr *x509.CertificateRequest, msg *scep.PKIMessage, info microscep.FailInfo, failError error) (response, error) { + auth := scep.MustFromContext(ctx) + certRepMsg, err := auth.CreateFailureResponse(ctx, csr, msg, scep.FailInfoName(info), failError.Error()) if err != nil { return response{}, err }