diff --git a/acme/api/revoke_test.go b/acme/api/revoke_test.go index b01aff57..85b9a032 100644 --- a/acme/api/revoke_test.go +++ b/acme/api/revoke_test.go @@ -281,7 +281,7 @@ type mockCA struct { MockAreSANsallowed func(ctx context.Context, sans []string) error } -func (m *mockCA) Sign(*x509.CertificateRequest, provisioner.SignOptions, ...provisioner.SignOption) ([]*x509.Certificate, error) { +func (m *mockCA) SignWithContext(context.Context, *x509.CertificateRequest, provisioner.SignOptions, ...provisioner.SignOption) ([]*x509.Certificate, error) { return nil, nil } diff --git a/acme/common.go b/acme/common.go index 7cce25fd..e86b23e9 100644 --- a/acme/common.go +++ b/acme/common.go @@ -21,7 +21,7 @@ var clock Clock // CertificateAuthority is the interface implemented by a CA authority. type CertificateAuthority interface { - Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) + SignWithContext(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) AreSANsAllowed(ctx context.Context, sans []string) error IsRevoked(sn string) (bool, error) Revoke(context.Context, *authority.RevokeOptions) error diff --git a/acme/order.go b/acme/order.go index 8dfcf97a..5a86c2c8 100644 --- a/acme/order.go +++ b/acme/order.go @@ -263,7 +263,7 @@ func (o *Order) Finalize(ctx context.Context, db DB, csr *x509.CertificateReques signOps = append(signOps, extraOptions...) // Sign a new certificate. - certChain, err := auth.Sign(csr, provisioner.SignOptions{ + certChain, err := auth.SignWithContext(ctx, csr, provisioner.SignOptions{ NotBefore: provisioner.NewTimeDuration(o.NotBefore), NotAfter: provisioner.NewTimeDuration(o.NotAfter), }, signOps...) diff --git a/acme/order_test.go b/acme/order_test.go index 2851bb19..07372af0 100644 --- a/acme/order_test.go +++ b/acme/order_test.go @@ -271,16 +271,16 @@ func TestOrder_UpdateStatus(t *testing.T) { } type mockSignAuth struct { - sign func(csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) + signWithContext func(ctx context.Context, csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) areSANsAllowed func(ctx context.Context, sans []string) error loadProvisionerByName func(string) (provisioner.Interface, error) ret1, ret2 interface{} err error } -func (m *mockSignAuth) Sign(csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { - if m.sign != nil { - return m.sign(csr, signOpts, extraOpts...) +func (m *mockSignAuth) SignWithContext(ctx context.Context, csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { + if m.signWithContext != nil { + return m.signWithContext(ctx, csr, signOpts, extraOpts...) } else if m.err != nil { return nil, m.err } @@ -578,7 +578,7 @@ func TestOrder_Finalize(t *testing.T) { }, }, ca: &mockSignAuth{ - sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { + signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { assert.Equals(t, _csr, csr) return nil, errors.New("force") }, @@ -628,7 +628,7 @@ func TestOrder_Finalize(t *testing.T) { }, }, ca: &mockSignAuth{ - sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { + signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { assert.Equals(t, _csr, csr) return []*x509.Certificate{foo, bar, baz}, nil }, @@ -685,7 +685,7 @@ func TestOrder_Finalize(t *testing.T) { }, }, ca: &mockSignAuth{ - sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { + signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { assert.Equals(t, _csr, csr) return []*x509.Certificate{foo, bar, baz}, nil }, @@ -770,7 +770,7 @@ func TestOrder_Finalize(t *testing.T) { }, }, ca: &mockSignAuth{ - sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { + signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { assert.Equals(t, _csr, csr) return []*x509.Certificate{leaf, inter, root}, nil }, @@ -863,7 +863,7 @@ func TestOrder_Finalize(t *testing.T) { }, }, ca: &mockSignAuth{ - sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { + signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { assert.Equals(t, _csr, csr) return []*x509.Certificate{leaf, inter, root}, nil }, @@ -973,7 +973,7 @@ func TestOrder_Finalize(t *testing.T) { // using the mocking functions as a wrapper for actual test helpers generated per test case or per // function that's tested. ca: &mockSignAuth{ - sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { + signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { assert.Equals(t, _csr, csr) return []*x509.Certificate{leaf, inter, root}, nil }, @@ -1044,7 +1044,7 @@ func TestOrder_Finalize(t *testing.T) { }, }, ca: &mockSignAuth{ - sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { + signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { assert.Equals(t, _csr, csr) return []*x509.Certificate{foo, bar, baz}, nil }, @@ -1108,7 +1108,7 @@ func TestOrder_Finalize(t *testing.T) { }, }, ca: &mockSignAuth{ - sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { + signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { assert.Equals(t, _csr, csr) return []*x509.Certificate{foo, bar, baz}, nil }, @@ -1175,7 +1175,7 @@ func TestOrder_Finalize(t *testing.T) { }, }, ca: &mockSignAuth{ - sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { + signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { assert.Equals(t, _csr, csr) return []*x509.Certificate{foo, bar, baz}, nil }, diff --git a/api/api.go b/api/api.go index 5d96cc45..a12e7e19 100644 --- a/api/api.go +++ b/api/api.go @@ -42,7 +42,7 @@ type Authority interface { AuthorizeRenewToken(ctx context.Context, ott string) (*x509.Certificate, error) GetTLSOptions() *config.TLSOptions Root(shasum string) (*x509.Certificate, error) - Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) + SignWithContext(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) Renew(peer *x509.Certificate) ([]*x509.Certificate, error) RenewContext(ctx context.Context, peer *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) Rekey(peer *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) diff --git a/api/api_test.go b/api/api_test.go index 28944a1e..cf988593 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -189,7 +189,7 @@ type mockAuthority struct { authorizeRenewToken func(ctx context.Context, ott string) (*x509.Certificate, error) getTLSOptions func() *authority.TLSOptions root func(shasum string) (*x509.Certificate, error) - sign func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) + signWithContext func(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) renew func(cert *x509.Certificate) ([]*x509.Certificate, error) rekey func(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) renewContext func(ctx context.Context, oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) @@ -251,9 +251,9 @@ func (m *mockAuthority) Root(shasum string) (*x509.Certificate, error) { return m.ret1.(*x509.Certificate), m.err } -func (m *mockAuthority) Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { - if m.sign != nil { - return m.sign(cr, opts, signOpts...) +func (m *mockAuthority) SignWithContext(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { + if m.signWithContext != nil { + return m.signWithContext(ctx, cr, opts, signOpts...) } return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err } diff --git a/api/sign.go b/api/sign.go index c0c83ce2..26b3c396 100644 --- a/api/sign.go +++ b/api/sign.go @@ -78,7 +78,7 @@ func Sign(w http.ResponseWriter, r *http.Request) { return } - certChain, err := a.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...) + certChain, err := a.SignWithContext(ctx, body.CsrPEM.CertificateRequest, opts, signOpts...) if err != nil { render.Error(w, errs.ForbiddenErr(err, "error signing certificate")) return diff --git a/api/ssh.go b/api/ssh.go index 9d0bbc14..08294c71 100644 --- a/api/ssh.go +++ b/api/ssh.go @@ -330,7 +330,7 @@ func SSHSign(w http.ResponseWriter, r *http.Request) { NotAfter: time.Unix(int64(cert.ValidBefore), 0), }) - certChain, err := a.Sign(cr, provisioner.SignOptions{}, signOpts...) + certChain, err := a.SignWithContext(ctx, cr, provisioner.SignOptions{}, signOpts...) if err != nil { render.Error(w, errs.ForbiddenErr(err, "error signing identity certificate")) return diff --git a/api/ssh_test.go b/api/ssh_test.go index 57dd6775..2b90dc12 100644 --- a/api/ssh_test.go +++ b/api/ssh_test.go @@ -325,7 +325,7 @@ func Test_SSHSign(t *testing.T) { signSSHAddUser: func(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) { return tt.addUserCert, tt.addUserErr }, - sign: func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { + signWithContext: func(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { return tt.tlsSignCerts, tt.tlsSignErr }, }) diff --git a/authority/authority_test.go b/authority/authority_test.go index 45c7cd86..3787dab7 100644 --- a/authority/authority_test.go +++ b/authority/authority_test.go @@ -1,6 +1,7 @@ package authority import ( + "context" "crypto" "crypto/rand" "crypto/sha256" @@ -414,7 +415,7 @@ func TestNewEmbedded_Sign(t *testing.T) { csr, err := x509.ParseCertificateRequest(cr) assert.FatalError(t, err) - cert, err := a.Sign(csr, provisioner.SignOptions{}) + cert, err := a.SignWithContext(context.Background(), csr, provisioner.SignOptions{}) assert.FatalError(t, err) assert.Equals(t, []string{"foo.bar.zar"}, cert[0].DNSNames) assert.Equals(t, crt, cert[1]) diff --git a/authority/authorize_test.go b/authority/authorize_test.go index 3d748f69..8f3c1ae2 100644 --- a/authority/authorize_test.go +++ b/authority/authorize_test.go @@ -1375,7 +1375,7 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) { } generateX5cToken := func(a *Authority, key crypto.Signer, claims jose.Claims, opts ...provisioner.SignOption) (string, *x509.Certificate) { - chain, err := a.Sign(csr, provisioner.SignOptions{}, opts...) + chain, err := a.SignWithContext(ctx, csr, provisioner.SignOptions{}, opts...) if err != nil { t.Fatal(err) } diff --git a/authority/provisioner/webhook.go b/authority/provisioner/webhook.go index 4b517bb6..c33dfa23 100644 --- a/authority/provisioner/webhook.go +++ b/authority/provisioner/webhook.go @@ -15,6 +15,7 @@ import ( "time" "github.com/pkg/errors" + "github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/templates" "github.com/smallstep/certificates/webhook" "go.step.sm/linkedca" @@ -36,7 +37,7 @@ type WebhookController struct { // Enrich fetches data from remote servers and adds returned data to the // templateData -func (wc *WebhookController) Enrich(req *webhook.RequestBody) error { +func (wc *WebhookController) Enrich(ctx context.Context, req *webhook.RequestBody) error { if wc == nil { return nil } @@ -55,7 +56,11 @@ func (wc *WebhookController) Enrich(req *webhook.RequestBody) error { if !wc.isCertTypeOK(wh) { continue } - resp, err := wh.Do(wc.client, req, wc.TemplateData) + + whCtx, cancel := context.WithTimeout(ctx, time.Second*10) + defer cancel() //nolint:gocritic // every request canceled with its own timeout + + resp, err := wh.DoWithContext(whCtx, wc.client, req, wc.TemplateData) if err != nil { return err } @@ -68,7 +73,7 @@ func (wc *WebhookController) Enrich(req *webhook.RequestBody) error { } // Authorize checks that all remote servers allow the request -func (wc *WebhookController) Authorize(req *webhook.RequestBody) error { +func (wc *WebhookController) Authorize(ctx context.Context, req *webhook.RequestBody) error { if wc == nil { return nil } @@ -87,7 +92,11 @@ func (wc *WebhookController) Authorize(req *webhook.RequestBody) error { if !wc.isCertTypeOK(wh) { continue } - resp, err := wh.Do(wc.client, req, wc.TemplateData) + + whCtx, cancel := context.WithTimeout(ctx, time.Second*10) + defer cancel() //nolint:gocritic // every request canceled with its own timeout + + resp, err := wh.DoWithContext(whCtx, wc.client, req, wc.TemplateData) if err != nil { return err } @@ -123,13 +132,6 @@ type Webhook struct { } `json:"-"` } -func (w *Webhook) Do(client *http.Client, reqBody *webhook.RequestBody, data any) (*webhook.ResponseBody, error) { - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) - defer cancel() - - return w.DoWithContext(ctx, client, reqBody, data) -} - func (w *Webhook) DoWithContext(ctx context.Context, client *http.Client, reqBody *webhook.RequestBody, data any) (*webhook.ResponseBody, error) { tmpl, err := template.New("url").Funcs(templates.StepFuncMap()).Parse(w.URL) if err != nil { @@ -169,6 +171,11 @@ retry: return nil, err } + requestID, ok := logging.GetRequestID(ctx) + if ok { + req.Header.Set("X-Request-ID", requestID) + } + secret, err := base64.StdEncoding.DecodeString(w.Secret) if err != nil { return nil, err diff --git a/authority/provisioner/webhook_test.go b/authority/provisioner/webhook_test.go index 9a2b62f0..60dcdbc7 100644 --- a/authority/provisioner/webhook_test.go +++ b/authority/provisioner/webhook_test.go @@ -1,6 +1,7 @@ package provisioner import ( + "context" "crypto/hmac" "crypto/sha256" "crypto/tls" @@ -8,15 +9,18 @@ import ( "encoding/base64" "encoding/hex" "encoding/json" + "errors" "fmt" "io" "net/http" "net/http/httptest" "testing" + "time" - "github.com/pkg/errors" - "github.com/smallstep/assert" + "github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/webhook" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "go.step.sm/crypto/pemutil" "go.step.sm/crypto/x509util" "go.step.sm/linkedca" @@ -92,19 +96,24 @@ func TestWebhookController_isCertTypeOK(t *testing.T) { } for name, test := range tests { t.Run(name, func(t *testing.T) { - assert.Equals(t, test.want, test.wc.isCertTypeOK(test.wh)) + assert.Equal(t, test.want, test.wc.isCertTypeOK(test.wh)) }) } } +// withRequestID is a helper that calls into [logging.WithRequestID] and returns +// a new context with the requestID added to the provided context. +func withRequestID(ctx context.Context, requestID string) context.Context { + return logging.WithRequestID(ctx, requestID) +} + func TestWebhookController_Enrich(t *testing.T) { cert, err := pemutil.ReadCertificate("testdata/certs/x5c-leaf.crt", pemutil.WithFirstBlock()) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) type test struct { ctl *WebhookController + ctx context.Context req *webhook.RequestBody responses []*webhook.ResponseBody expectErr bool @@ -129,6 +138,7 @@ func TestWebhookController_Enrich(t *testing.T) { webhooks: []*Webhook{{Name: "people", Kind: "ENRICHING"}}, TemplateData: x509util.TemplateData{}, }, + ctx: withRequestID(context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: true, Data: map[string]any{"role": "bar"}}}, expectErr: false, @@ -143,6 +153,7 @@ func TestWebhookController_Enrich(t *testing.T) { }, TemplateData: x509util.TemplateData{}, }, + ctx: withRequestID(context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{ {Allow: true, Data: map[string]any{"role": "bar"}}, @@ -166,6 +177,7 @@ func TestWebhookController_Enrich(t *testing.T) { TemplateData: x509util.TemplateData{}, certType: linkedca.Webhook_X509, }, + ctx: withRequestID(context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{ {Allow: true, Data: map[string]any{"role": "bar"}}, @@ -185,14 +197,15 @@ func TestWebhookController_Enrich(t *testing.T) { TemplateData: x509util.TemplateData{}, options: []webhook.RequestBodyOption{webhook.WithX5CCertificate(cert)}, }, + ctx: withRequestID(context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: true, Data: map[string]any{"role": "bar"}}}, expectErr: false, expectTemplateData: x509util.TemplateData{"Webhooks": map[string]any{"people": map[string]any{"role": "bar"}}}, assertRequest: func(t *testing.T, req *webhook.RequestBody) { key, err := x509.MarshalPKIXPublicKey(cert.PublicKey) - assert.FatalError(t, err) - assert.Equals(t, &webhook.X5CCertificate{ + require.NoError(t, err) + assert.Equal(t, &webhook.X5CCertificate{ Raw: cert.Raw, PublicKey: key, PublicKeyAlgorithm: cert.PublicKeyAlgorithm.String(), @@ -207,6 +220,7 @@ func TestWebhookController_Enrich(t *testing.T) { webhooks: []*Webhook{{Name: "people", Kind: "ENRICHING"}}, TemplateData: x509util.TemplateData{}, }, + ctx: withRequestID(context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: false}}, expectErr: true, @@ -221,6 +235,7 @@ func TestWebhookController_Enrich(t *testing.T) { PublicKey: []byte("bad"), })}, }, + ctx: withRequestID(context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: false}}, expectErr: true, @@ -232,19 +247,21 @@ func TestWebhookController_Enrich(t *testing.T) { for i, wh := range test.ctl.webhooks { var j = i ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "reqID", r.Header.Get("X-Request-ID")) + err := json.NewEncoder(w).Encode(test.responses[j]) - assert.FatalError(t, err) + require.NoError(t, err) })) // nolint: gocritic // defer in loop isn't a memory leak defer ts.Close() wh.URL = ts.URL } - err := test.ctl.Enrich(test.req) + err := test.ctl.Enrich(test.ctx, test.req) if (err != nil) != test.expectErr { t.Fatalf("Got err %v, want %v", err, test.expectErr) } - assert.Equals(t, test.expectTemplateData, test.ctl.TemplateData) + assert.Equal(t, test.expectTemplateData, test.ctl.TemplateData) if test.assertRequest != nil { test.assertRequest(t, test.req) } @@ -254,12 +271,11 @@ func TestWebhookController_Enrich(t *testing.T) { func TestWebhookController_Authorize(t *testing.T) { cert, err := pemutil.ReadCertificate("testdata/certs/x5c-leaf.crt", pemutil.WithFirstBlock()) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) type test struct { ctl *WebhookController + ctx context.Context req *webhook.RequestBody responses []*webhook.ResponseBody expectErr bool @@ -280,6 +296,7 @@ func TestWebhookController_Authorize(t *testing.T) { client: http.DefaultClient, webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING"}}, }, + ctx: withRequestID(context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: true}}, expectErr: false, @@ -290,6 +307,7 @@ func TestWebhookController_Authorize(t *testing.T) { webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING", CertType: linkedca.Webhook_X509.String()}}, certType: linkedca.Webhook_SSH, }, + ctx: withRequestID(context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: false}}, expectErr: false, @@ -300,13 +318,14 @@ func TestWebhookController_Authorize(t *testing.T) { webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING"}}, options: []webhook.RequestBodyOption{webhook.WithX5CCertificate(cert)}, }, + ctx: withRequestID(context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: true}}, expectErr: false, assertRequest: func(t *testing.T, req *webhook.RequestBody) { key, err := x509.MarshalPKIXPublicKey(cert.PublicKey) - assert.FatalError(t, err) - assert.Equals(t, &webhook.X5CCertificate{ + require.NoError(t, err) + assert.Equal(t, &webhook.X5CCertificate{ Raw: cert.Raw, PublicKey: key, PublicKeyAlgorithm: cert.PublicKeyAlgorithm.String(), @@ -320,6 +339,7 @@ func TestWebhookController_Authorize(t *testing.T) { client: http.DefaultClient, webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING"}}, }, + ctx: withRequestID(context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: false}}, expectErr: true, @@ -332,6 +352,7 @@ func TestWebhookController_Authorize(t *testing.T) { PublicKey: []byte("bad"), })}, }, + ctx: withRequestID(context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: false}}, expectErr: true, @@ -342,15 +363,17 @@ func TestWebhookController_Authorize(t *testing.T) { for i, wh := range test.ctl.webhooks { var j = i ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "reqID", r.Header.Get("X-Request-ID")) + err := json.NewEncoder(w).Encode(test.responses[j]) - assert.FatalError(t, err) + require.NoError(t, err) })) // nolint: gocritic // defer in loop isn't a memory leak defer ts.Close() wh.URL = ts.URL } - err := test.ctl.Authorize(test.req) + err := test.ctl.Authorize(test.ctx, test.req) if (err != nil) != test.expectErr { t.Fatalf("Got err %v, want %v", err, test.expectErr) } @@ -366,6 +389,7 @@ func TestWebhook_Do(t *testing.T) { type test struct { webhook Webhook dataArg any + requestID string webhookResponse webhook.ResponseBody expectPath string errStatusCode int @@ -375,6 +399,16 @@ func TestWebhook_Do(t *testing.T) { } tests := map[string]test{ "ok": { + webhook: Webhook{ + ID: "abc123", + Secret: "c2VjcmV0Cg==", + }, + requestID: "reqID", + webhookResponse: webhook.ResponseBody{ + Data: map[string]interface{}{"role": "dba"}, + }, + }, + "ok/no-request-id": { webhook: Webhook{ ID: "abc123", Secret: "c2VjcmV0Cg==", @@ -389,6 +423,7 @@ func TestWebhook_Do(t *testing.T) { Secret: "c2VjcmV0Cg==", BearerToken: "mytoken", }, + requestID: "reqID", webhookResponse: webhook.ResponseBody{ Data: map[string]interface{}{"role": "dba"}, }, @@ -405,6 +440,7 @@ func TestWebhook_Do(t *testing.T) { Password: "mypass", }, }, + requestID: "reqID", webhookResponse: webhook.ResponseBody{ Data: map[string]interface{}{"role": "dba"}, }, @@ -416,7 +452,8 @@ func TestWebhook_Do(t *testing.T) { URL: "/users/{{ .username }}?region={{ .region }}", Secret: "c2VjcmV0Cg==", }, - dataArg: map[string]interface{}{"username": "areed", "region": "central"}, + requestID: "reqID", + dataArg: map[string]interface{}{"username": "areed", "region": "central"}, webhookResponse: webhook.ResponseBody{ Data: map[string]interface{}{"role": "dba"}, }, @@ -451,6 +488,7 @@ func TestWebhook_Do(t *testing.T) { ID: "abc123", Secret: "c2VjcmV0Cg==", }, + requestID: "reqID", webhookResponse: webhook.ResponseBody{ Allow: true, }, @@ -463,6 +501,7 @@ func TestWebhook_Do(t *testing.T) { webhookResponse: webhook.ResponseBody{ Data: map[string]interface{}{"role": "dba"}, }, + requestID: "reqID", errStatusCode: 404, serverErrMsg: "item not found", expectErr: errors.New("Webhook server responded with 404"), @@ -471,17 +510,20 @@ func TestWebhook_Do(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - id := r.Header.Get("X-Smallstep-Webhook-ID") - assert.Equals(t, tc.webhook.ID, id) + if tc.requestID != "" { + assert.Equal(t, tc.requestID, r.Header.Get("X-Request-ID")) + } + + assert.Equal(t, tc.webhook.ID, r.Header.Get("X-Smallstep-Webhook-ID")) sig, err := hex.DecodeString(r.Header.Get("X-Smallstep-Signature")) - assert.FatalError(t, err) + assert.NoError(t, err) body, err := io.ReadAll(r.Body) - assert.FatalError(t, err) + assert.NoError(t, err) secret, err := base64.StdEncoding.DecodeString(tc.webhook.Secret) - assert.FatalError(t, err) + assert.NoError(t, err) h := hmac.New(sha256.New, secret) h.Write(body) mac := h.Sum(nil) @@ -490,19 +532,19 @@ func TestWebhook_Do(t *testing.T) { switch { case tc.webhook.BearerToken != "": ah := fmt.Sprintf("Bearer %s", tc.webhook.BearerToken) - assert.Equals(t, ah, r.Header.Get("Authorization")) + assert.Equal(t, ah, r.Header.Get("Authorization")) case tc.webhook.BasicAuth.Username != "" || tc.webhook.BasicAuth.Password != "": whReq, err := http.NewRequest("", "", http.NoBody) - assert.FatalError(t, err) + require.NoError(t, err) whReq.SetBasicAuth(tc.webhook.BasicAuth.Username, tc.webhook.BasicAuth.Password) ah := whReq.Header.Get("Authorization") - assert.Equals(t, ah, whReq.Header.Get("Authorization")) + assert.Equal(t, ah, whReq.Header.Get("Authorization")) default: - assert.Equals(t, "", r.Header.Get("Authorization")) + assert.Equal(t, "", r.Header.Get("Authorization")) } if tc.expectPath != "" { - assert.Equals(t, tc.expectPath, r.URL.Path+"?"+r.URL.RawQuery) + assert.Equal(t, tc.expectPath, r.URL.Path+"?"+r.URL.RawQuery) } if tc.errStatusCode != 0 { @@ -512,26 +554,33 @@ func TestWebhook_Do(t *testing.T) { reqBody := new(webhook.RequestBody) err = json.Unmarshal(body, reqBody) - assert.FatalError(t, err) - // assert.Equals(t, tc.expectToken, reqBody.Token) + require.NoError(t, err) err = json.NewEncoder(w).Encode(tc.webhookResponse) - assert.FatalError(t, err) + require.NoError(t, err) })) defer ts.Close() tc.webhook.URL = ts.URL + tc.webhook.URL reqBody, err := webhook.NewRequestBody(webhook.WithX509CertificateRequest(csr)) - assert.FatalError(t, err) - got, err := tc.webhook.Do(http.DefaultClient, reqBody, tc.dataArg) + require.NoError(t, err) + + ctx := context.Background() + if tc.requestID != "" { + ctx = withRequestID(context.Background(), tc.requestID) + } + ctx, cancel := context.WithTimeout(ctx, time.Second*10) + defer cancel() + + got, err := tc.webhook.DoWithContext(ctx, http.DefaultClient, reqBody, tc.dataArg) if tc.expectErr != nil { - assert.Equals(t, tc.expectErr.Error(), err.Error()) + assert.Equal(t, tc.expectErr.Error(), err.Error()) return } - assert.FatalError(t, err) + assert.NoError(t, err) - assert.Equals(t, got, &tc.webhookResponse) + assert.Equal(t, &tc.webhookResponse, got) }) } @@ -544,7 +593,7 @@ func TestWebhook_Do(t *testing.T) { URL: ts.URL, } cert, err := tls.LoadX509KeyPair("testdata/certs/foo.crt", "testdata/secrets/foo.key") - assert.FatalError(t, err) + require.NoError(t, err) transport := http.DefaultTransport.(*http.Transport).Clone() transport.TLSClientConfig = &tls.Config{ InsecureSkipVerify: true, @@ -554,12 +603,19 @@ func TestWebhook_Do(t *testing.T) { Transport: transport, } reqBody, err := webhook.NewRequestBody(webhook.WithX509CertificateRequest(csr)) - assert.FatalError(t, err) - _, err = wh.Do(client, reqBody, nil) - assert.FatalError(t, err) + require.NoError(t, err) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + + _, err = wh.DoWithContext(ctx, client, reqBody, nil) + require.NoError(t, err) + + ctx, cancel = context.WithTimeout(context.Background(), time.Second*10) + defer cancel() wh.DisableTLSClientAuth = true - _, err = wh.Do(client, reqBody, nil) - assert.Error(t, err) + _, err = wh.DoWithContext(ctx, client, reqBody, nil) + require.Error(t, err) }) } diff --git a/authority/provisioners_test.go b/authority/provisioners_test.go index f6af6f54..f62f8127 100644 --- a/authority/provisioners_test.go +++ b/authority/provisioners_test.go @@ -149,7 +149,7 @@ func TestAuthority_LoadProvisionerByCertificate(t *testing.T) { opts, err := a.Authorize(ctx, token) require.NoError(t, err) opts = append(opts, extraOpts...) - certs, err := a.Sign(csr, provisioner.SignOptions{}, opts...) + certs, err := a.SignWithContext(ctx, csr, provisioner.SignOptions{}, opts...) require.NoError(t, err) return certs[0] } diff --git a/authority/ssh.go b/authority/ssh.go index 756e376e..26e8eebc 100644 --- a/authority/ssh.go +++ b/authority/ssh.go @@ -152,7 +152,7 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi return cert, err } -func (a *Authority) signSSH(_ context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, provisioner.Interface, error) { +func (a *Authority) signSSH(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, provisioner.Interface, error) { var ( certOptions []sshutil.Option mods []provisioner.SSHCertModifier @@ -211,7 +211,7 @@ func (a *Authority) signSSH(_ context.Context, key ssh.PublicKey, opts provision } // Call enriching webhooks - if err := a.callEnrichingWebhooksSSH(prov, webhookCtl, cr); err != nil { + if err := a.callEnrichingWebhooksSSH(ctx, prov, webhookCtl, cr); err != nil { return nil, prov, errs.ApplyOptions( errs.ForbiddenErr(err, err.Error()), errs.WithKeyVal("signOptions", signOpts), @@ -284,7 +284,7 @@ func (a *Authority) signSSH(_ context.Context, key ssh.PublicKey, opts provision } // Send certificate to webhooks for authorization - if err := a.callAuthorizingWebhooksSSH(prov, webhookCtl, certificate, certTpl); err != nil { + if err := a.callAuthorizingWebhooksSSH(ctx, prov, webhookCtl, certificate, certTpl); err != nil { return nil, prov, errs.ApplyOptions( errs.ForbiddenErr(err, "authority.SignSSH: error signing certificate"), ) @@ -671,7 +671,7 @@ func (a *Authority) getAddUserCommand(principal string) string { return strings.ReplaceAll(cmd, "", principal) } -func (a *Authority) callEnrichingWebhooksSSH(prov provisioner.Interface, webhookCtl webhookController, cr sshutil.CertificateRequest) (err error) { +func (a *Authority) callEnrichingWebhooksSSH(ctx context.Context, prov provisioner.Interface, webhookCtl webhookController, cr sshutil.CertificateRequest) (err error) { if webhookCtl == nil { return } @@ -680,7 +680,7 @@ func (a *Authority) callEnrichingWebhooksSSH(prov provisioner.Interface, webhook if whEnrichReq, err = webhook.NewRequestBody( webhook.WithSSHCertificateRequest(cr), ); err == nil { - err = webhookCtl.Enrich(whEnrichReq) + err = webhookCtl.Enrich(ctx, whEnrichReq) a.meter.SSHWebhookEnriched(prov, err) } @@ -688,7 +688,7 @@ func (a *Authority) callEnrichingWebhooksSSH(prov provisioner.Interface, webhook return } -func (a *Authority) callAuthorizingWebhooksSSH(prov provisioner.Interface, webhookCtl webhookController, cert *sshutil.Certificate, certTpl *ssh.Certificate) (err error) { +func (a *Authority) callAuthorizingWebhooksSSH(ctx context.Context, prov provisioner.Interface, webhookCtl webhookController, cert *sshutil.Certificate, certTpl *ssh.Certificate) (err error) { if webhookCtl == nil { return } @@ -697,7 +697,7 @@ func (a *Authority) callAuthorizingWebhooksSSH(prov provisioner.Interface, webho if whAuthBody, err = webhook.NewRequestBody( webhook.WithSSHCertificate(cert, certTpl), ); err == nil { - err = webhookCtl.Authorize(whAuthBody) + err = webhookCtl.Authorize(ctx, whAuthBody) a.meter.SSHWebhookAuthorized(prov, err) } diff --git a/authority/tls.go b/authority/tls.go index fa170d44..082513c8 100644 --- a/authority/tls.go +++ b/authority/tls.go @@ -91,14 +91,23 @@ func withDefaultASN1DN(def *config.ASN1DN) provisioner.CertificateModifierFunc { } } -// Sign creates a signed certificate from a certificate signing request. +// Sign creates a signed certificate from a certificate signing request. It +// creates a new context.Context, and calls into SignWithContext. +// +// Deprecated: Use authority.SignWithContext with an actual context.Context. func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { - chain, prov, err := a.signX509(csr, signOpts, extraOpts...) + return a.SignWithContext(context.Background(), csr, signOpts, extraOpts...) +} + +// SignWithContext creates a signed certificate from a certificate signing +// request, taking the provided context.Context. +func (a *Authority) SignWithContext(ctx context.Context, csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { + chain, prov, err := a.signX509(ctx, csr, signOpts, extraOpts...) a.meter.X509Signed(prov, err) return chain, err } -func (a *Authority) signX509(csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, provisioner.Interface, error) { +func (a *Authority) signX509(ctx context.Context, csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, provisioner.Interface, error) { var ( certOptions []x509util.Option certValidators []provisioner.CertificateValidator @@ -171,7 +180,7 @@ func (a *Authority) signX509(csr *x509.CertificateRequest, signOpts provisioner. } } - if err := a.callEnrichingWebhooksX509(prov, webhookCtl, attData, csr); err != nil { + if err := a.callEnrichingWebhooksX509(ctx, prov, webhookCtl, attData, csr); err != nil { return nil, prov, errs.ApplyOptions( errs.ForbiddenErr(err, err.Error()), errs.WithKeyVal("csr", csr), @@ -265,7 +274,7 @@ func (a *Authority) signX509(csr *x509.CertificateRequest, signOpts provisioner. } // Send certificate to webhooks for authorization - if err := a.callAuthorizingWebhooksX509(prov, webhookCtl, crt, leaf, attData); err != nil { + if err := a.callAuthorizingWebhooksX509(ctx, prov, webhookCtl, crt, leaf, attData); err != nil { return nil, prov, errs.ApplyOptions( errs.ForbiddenErr(err, "error creating certificate"), opts..., @@ -986,7 +995,7 @@ func templatingError(err error) error { return errors.Wrap(cause, "error applying certificate template") } -func (a *Authority) callEnrichingWebhooksX509(prov provisioner.Interface, webhookCtl webhookController, attData *provisioner.AttestationData, csr *x509.CertificateRequest) (err error) { +func (a *Authority) callEnrichingWebhooksX509(ctx context.Context, prov provisioner.Interface, webhookCtl webhookController, attData *provisioner.AttestationData, csr *x509.CertificateRequest) (err error) { if webhookCtl == nil { return } @@ -1003,7 +1012,7 @@ func (a *Authority) callEnrichingWebhooksX509(prov provisioner.Interface, webhoo webhook.WithX509CertificateRequest(csr), webhook.WithAttestationData(attested), ); err == nil { - err = webhookCtl.Enrich(whEnrichReq) + err = webhookCtl.Enrich(ctx, whEnrichReq) a.meter.X509WebhookEnriched(prov, err) } @@ -1011,7 +1020,7 @@ func (a *Authority) callEnrichingWebhooksX509(prov provisioner.Interface, webhoo return } -func (a *Authority) callAuthorizingWebhooksX509(prov provisioner.Interface, webhookCtl webhookController, cert *x509util.Certificate, leaf *x509.Certificate, attData *provisioner.AttestationData) (err error) { +func (a *Authority) callAuthorizingWebhooksX509(ctx context.Context, prov provisioner.Interface, webhookCtl webhookController, cert *x509util.Certificate, leaf *x509.Certificate, attData *provisioner.AttestationData) (err error) { if webhookCtl == nil { return } @@ -1028,7 +1037,7 @@ func (a *Authority) callAuthorizingWebhooksX509(prov provisioner.Interface, webh webhook.WithX509Certificate(cert, leaf), webhook.WithAttestationData(attested), ); err == nil { - err = webhookCtl.Authorize(whAuthBody) + err = webhookCtl.Authorize(ctx, whAuthBody) a.meter.X509WebhookAuthorized(prov, err) } diff --git a/authority/tls_test.go b/authority/tls_test.go index 1fb8411a..b481ca68 100644 --- a/authority/tls_test.go +++ b/authority/tls_test.go @@ -239,7 +239,7 @@ func (e *testEnforcer) Enforce(cert *x509.Certificate) error { return nil } -func TestAuthority_Sign(t *testing.T) { +func TestAuthority_SignWithContext(t *testing.T) { pub, priv, err := keyutil.GenerateDefaultKeyPair() require.NoError(t, err) @@ -848,7 +848,7 @@ ZYtQ9Ot36qc= t.Run(name, func(t *testing.T) { tc := genTestCase(t) - certChain, err := tc.auth.Sign(tc.csr, tc.signOpts, tc.extraOpts...) + certChain, err := tc.auth.SignWithContext(context.Background(), tc.csr, tc.signOpts, tc.extraOpts...) if err != nil { if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) { assert.Nil(t, certChain) @@ -1797,9 +1797,9 @@ func TestAuthority_constraints(t *testing.T) { t.Fatal(err) } - _, err = auth.Sign(csr, provisioner.SignOptions{}, templateOption) + _, err = auth.SignWithContext(context.Background(), csr, provisioner.SignOptions{}, templateOption) if (err != nil) != tt.wantErr { - t.Errorf("Authority.Sign() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("Authority.SignWithContext() error = %v, wantErr %v", err, tt.wantErr) } _, err = auth.Renew(cert) diff --git a/authority/webhook.go b/authority/webhook.go index d887e077..29e3e6c3 100644 --- a/authority/webhook.go +++ b/authority/webhook.go @@ -1,8 +1,12 @@ package authority -import "github.com/smallstep/certificates/webhook" +import ( + "context" + + "github.com/smallstep/certificates/webhook" +) type webhookController interface { - Enrich(*webhook.RequestBody) error - Authorize(*webhook.RequestBody) error + Enrich(context.Context, *webhook.RequestBody) error + Authorize(context.Context, *webhook.RequestBody) error } diff --git a/authority/webhook_test.go b/authority/webhook_test.go index 0e713af7..75b59f63 100644 --- a/authority/webhook_test.go +++ b/authority/webhook_test.go @@ -1,6 +1,8 @@ package authority import ( + "context" + "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/webhook" ) @@ -14,7 +16,7 @@ type mockWebhookController struct { var _ webhookController = &mockWebhookController{} -func (wc *mockWebhookController) Enrich(*webhook.RequestBody) error { +func (wc *mockWebhookController) Enrich(context.Context, *webhook.RequestBody) error { for key, data := range wc.respData { wc.templateData.SetWebhook(key, data) } @@ -22,6 +24,6 @@ func (wc *mockWebhookController) Enrich(*webhook.RequestBody) error { return wc.enrichErr } -func (wc *mockWebhookController) Authorize(*webhook.RequestBody) error { +func (wc *mockWebhookController) Authorize(context.Context, *webhook.RequestBody) error { return wc.authorizeErr } diff --git a/scep/authority.go b/scep/authority.go index 1d156752..8ed065fb 100644 --- a/scep/authority.go +++ b/scep/authority.go @@ -60,7 +60,7 @@ func MustFromContext(ctx context.Context) *Authority { // SignAuthority is the interface for a signing authority type SignAuthority interface { - Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) + SignWithContext(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) LoadProvisionerByName(string) (provisioner.Interface, error) } @@ -306,7 +306,7 @@ func (a *Authority) SignCSR(ctx context.Context, csr *x509.CertificateRequest, m } signOps = append(signOps, templateOptions) - certChain, err := a.signAuth.Sign(csr, opts, signOps...) + certChain, err := a.signAuth.SignWithContext(ctx, csr, opts, signOps...) if err != nil { return nil, fmt.Errorf("error generating certificate for order: %w", err) }