Merge pull request #1542 from smallstep/herman/webhook-request-id

Propagate request ID when webhook requests are made
pull/1742/head
Herman Slatman 3 months ago committed by GitHub
commit c798735f7e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -281,7 +281,7 @@ type mockCA struct {
MockAreSANsallowed func(ctx context.Context, sans []string) error MockAreSANsallowed func(ctx context.Context, sans []string) error
} }
func (m *mockCA) Sign(*x509.CertificateRequest, provisioner.SignOptions, ...provisioner.SignOption) ([]*x509.Certificate, error) { func (m *mockCA) SignWithContext(context.Context, *x509.CertificateRequest, provisioner.SignOptions, ...provisioner.SignOption) ([]*x509.Certificate, error) {
return nil, nil return nil, nil
} }

@ -21,7 +21,7 @@ var clock Clock
// CertificateAuthority is the interface implemented by a CA authority. // CertificateAuthority is the interface implemented by a CA authority.
type CertificateAuthority interface { type CertificateAuthority interface {
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) SignWithContext(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
AreSANsAllowed(ctx context.Context, sans []string) error AreSANsAllowed(ctx context.Context, sans []string) error
IsRevoked(sn string) (bool, error) IsRevoked(sn string) (bool, error)
Revoke(context.Context, *authority.RevokeOptions) error Revoke(context.Context, *authority.RevokeOptions) error

@ -263,7 +263,7 @@ func (o *Order) Finalize(ctx context.Context, db DB, csr *x509.CertificateReques
signOps = append(signOps, extraOptions...) signOps = append(signOps, extraOptions...)
// Sign a new certificate. // Sign a new certificate.
certChain, err := auth.Sign(csr, provisioner.SignOptions{ certChain, err := auth.SignWithContext(ctx, csr, provisioner.SignOptions{
NotBefore: provisioner.NewTimeDuration(o.NotBefore), NotBefore: provisioner.NewTimeDuration(o.NotBefore),
NotAfter: provisioner.NewTimeDuration(o.NotAfter), NotAfter: provisioner.NewTimeDuration(o.NotAfter),
}, signOps...) }, signOps...)

@ -271,16 +271,16 @@ func TestOrder_UpdateStatus(t *testing.T) {
} }
type mockSignAuth struct { type mockSignAuth struct {
sign func(csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) signWithContext func(ctx context.Context, csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
areSANsAllowed func(ctx context.Context, sans []string) error areSANsAllowed func(ctx context.Context, sans []string) error
loadProvisionerByName func(string) (provisioner.Interface, error) loadProvisionerByName func(string) (provisioner.Interface, error)
ret1, ret2 interface{} ret1, ret2 interface{}
err error err error
} }
func (m *mockSignAuth) Sign(csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { func (m *mockSignAuth) SignWithContext(ctx context.Context, csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
if m.sign != nil { if m.signWithContext != nil {
return m.sign(csr, signOpts, extraOpts...) return m.signWithContext(ctx, csr, signOpts, extraOpts...)
} else if m.err != nil { } else if m.err != nil {
return nil, m.err return nil, m.err
} }
@ -578,7 +578,7 @@ func TestOrder_Finalize(t *testing.T) {
}, },
}, },
ca: &mockSignAuth{ ca: &mockSignAuth{
sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
assert.Equals(t, _csr, csr) assert.Equals(t, _csr, csr)
return nil, errors.New("force") return nil, errors.New("force")
}, },
@ -628,7 +628,7 @@ func TestOrder_Finalize(t *testing.T) {
}, },
}, },
ca: &mockSignAuth{ ca: &mockSignAuth{
sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
assert.Equals(t, _csr, csr) assert.Equals(t, _csr, csr)
return []*x509.Certificate{foo, bar, baz}, nil return []*x509.Certificate{foo, bar, baz}, nil
}, },
@ -685,7 +685,7 @@ func TestOrder_Finalize(t *testing.T) {
}, },
}, },
ca: &mockSignAuth{ ca: &mockSignAuth{
sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
assert.Equals(t, _csr, csr) assert.Equals(t, _csr, csr)
return []*x509.Certificate{foo, bar, baz}, nil return []*x509.Certificate{foo, bar, baz}, nil
}, },
@ -770,7 +770,7 @@ func TestOrder_Finalize(t *testing.T) {
}, },
}, },
ca: &mockSignAuth{ ca: &mockSignAuth{
sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
assert.Equals(t, _csr, csr) assert.Equals(t, _csr, csr)
return []*x509.Certificate{leaf, inter, root}, nil return []*x509.Certificate{leaf, inter, root}, nil
}, },
@ -863,7 +863,7 @@ func TestOrder_Finalize(t *testing.T) {
}, },
}, },
ca: &mockSignAuth{ ca: &mockSignAuth{
sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
assert.Equals(t, _csr, csr) assert.Equals(t, _csr, csr)
return []*x509.Certificate{leaf, inter, root}, nil return []*x509.Certificate{leaf, inter, root}, nil
}, },
@ -973,7 +973,7 @@ func TestOrder_Finalize(t *testing.T) {
// using the mocking functions as a wrapper for actual test helpers generated per test case or per // using the mocking functions as a wrapper for actual test helpers generated per test case or per
// function that's tested. // function that's tested.
ca: &mockSignAuth{ ca: &mockSignAuth{
sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
assert.Equals(t, _csr, csr) assert.Equals(t, _csr, csr)
return []*x509.Certificate{leaf, inter, root}, nil return []*x509.Certificate{leaf, inter, root}, nil
}, },
@ -1044,7 +1044,7 @@ func TestOrder_Finalize(t *testing.T) {
}, },
}, },
ca: &mockSignAuth{ ca: &mockSignAuth{
sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
assert.Equals(t, _csr, csr) assert.Equals(t, _csr, csr)
return []*x509.Certificate{foo, bar, baz}, nil return []*x509.Certificate{foo, bar, baz}, nil
}, },
@ -1108,7 +1108,7 @@ func TestOrder_Finalize(t *testing.T) {
}, },
}, },
ca: &mockSignAuth{ ca: &mockSignAuth{
sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
assert.Equals(t, _csr, csr) assert.Equals(t, _csr, csr)
return []*x509.Certificate{foo, bar, baz}, nil return []*x509.Certificate{foo, bar, baz}, nil
}, },
@ -1175,7 +1175,7 @@ func TestOrder_Finalize(t *testing.T) {
}, },
}, },
ca: &mockSignAuth{ ca: &mockSignAuth{
sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
assert.Equals(t, _csr, csr) assert.Equals(t, _csr, csr)
return []*x509.Certificate{foo, bar, baz}, nil return []*x509.Certificate{foo, bar, baz}, nil
}, },

@ -42,7 +42,7 @@ type Authority interface {
AuthorizeRenewToken(ctx context.Context, ott string) (*x509.Certificate, error) AuthorizeRenewToken(ctx context.Context, ott string) (*x509.Certificate, error)
GetTLSOptions() *config.TLSOptions GetTLSOptions() *config.TLSOptions
Root(shasum string) (*x509.Certificate, error) Root(shasum string) (*x509.Certificate, error)
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) SignWithContext(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
Renew(peer *x509.Certificate) ([]*x509.Certificate, error) Renew(peer *x509.Certificate) ([]*x509.Certificate, error)
RenewContext(ctx context.Context, peer *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) RenewContext(ctx context.Context, peer *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
Rekey(peer *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) Rekey(peer *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)

@ -189,7 +189,7 @@ type mockAuthority struct {
authorizeRenewToken func(ctx context.Context, ott string) (*x509.Certificate, error) authorizeRenewToken func(ctx context.Context, ott string) (*x509.Certificate, error)
getTLSOptions func() *authority.TLSOptions getTLSOptions func() *authority.TLSOptions
root func(shasum string) (*x509.Certificate, error) root func(shasum string) (*x509.Certificate, error)
sign func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) signWithContext func(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
renew func(cert *x509.Certificate) ([]*x509.Certificate, error) renew func(cert *x509.Certificate) ([]*x509.Certificate, error)
rekey func(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) rekey func(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
renewContext func(ctx context.Context, oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) renewContext func(ctx context.Context, oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
@ -251,9 +251,9 @@ func (m *mockAuthority) Root(shasum string) (*x509.Certificate, error) {
return m.ret1.(*x509.Certificate), m.err return m.ret1.(*x509.Certificate), m.err
} }
func (m *mockAuthority) Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { func (m *mockAuthority) SignWithContext(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
if m.sign != nil { if m.signWithContext != nil {
return m.sign(cr, opts, signOpts...) return m.signWithContext(ctx, cr, opts, signOpts...)
} }
return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err
} }

@ -78,7 +78,7 @@ func Sign(w http.ResponseWriter, r *http.Request) {
return return
} }
certChain, err := a.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...) certChain, err := a.SignWithContext(ctx, body.CsrPEM.CertificateRequest, opts, signOpts...)
if err != nil { if err != nil {
render.Error(w, errs.ForbiddenErr(err, "error signing certificate")) render.Error(w, errs.ForbiddenErr(err, "error signing certificate"))
return return

@ -330,7 +330,7 @@ func SSHSign(w http.ResponseWriter, r *http.Request) {
NotAfter: time.Unix(int64(cert.ValidBefore), 0), NotAfter: time.Unix(int64(cert.ValidBefore), 0),
}) })
certChain, err := a.Sign(cr, provisioner.SignOptions{}, signOpts...) certChain, err := a.SignWithContext(ctx, cr, provisioner.SignOptions{}, signOpts...)
if err != nil { if err != nil {
render.Error(w, errs.ForbiddenErr(err, "error signing identity certificate")) render.Error(w, errs.ForbiddenErr(err, "error signing identity certificate"))
return return

@ -325,7 +325,7 @@ func Test_SSHSign(t *testing.T) {
signSSHAddUser: func(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) { signSSHAddUser: func(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) {
return tt.addUserCert, tt.addUserErr return tt.addUserCert, tt.addUserErr
}, },
sign: func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { signWithContext: func(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
return tt.tlsSignCerts, tt.tlsSignErr return tt.tlsSignCerts, tt.tlsSignErr
}, },
}) })

@ -1,6 +1,7 @@
package authority package authority
import ( import (
"context"
"crypto" "crypto"
"crypto/rand" "crypto/rand"
"crypto/sha256" "crypto/sha256"
@ -414,7 +415,7 @@ func TestNewEmbedded_Sign(t *testing.T) {
csr, err := x509.ParseCertificateRequest(cr) csr, err := x509.ParseCertificateRequest(cr)
assert.FatalError(t, err) assert.FatalError(t, err)
cert, err := a.Sign(csr, provisioner.SignOptions{}) cert, err := a.SignWithContext(context.Background(), csr, provisioner.SignOptions{})
assert.FatalError(t, err) assert.FatalError(t, err)
assert.Equals(t, []string{"foo.bar.zar"}, cert[0].DNSNames) assert.Equals(t, []string{"foo.bar.zar"}, cert[0].DNSNames)
assert.Equals(t, crt, cert[1]) assert.Equals(t, crt, cert[1])

@ -1375,7 +1375,7 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) {
} }
generateX5cToken := func(a *Authority, key crypto.Signer, claims jose.Claims, opts ...provisioner.SignOption) (string, *x509.Certificate) { generateX5cToken := func(a *Authority, key crypto.Signer, claims jose.Claims, opts ...provisioner.SignOption) (string, *x509.Certificate) {
chain, err := a.Sign(csr, provisioner.SignOptions{}, opts...) chain, err := a.SignWithContext(ctx, csr, provisioner.SignOptions{}, opts...)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

@ -15,6 +15,7 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/logging"
"github.com/smallstep/certificates/templates" "github.com/smallstep/certificates/templates"
"github.com/smallstep/certificates/webhook" "github.com/smallstep/certificates/webhook"
"go.step.sm/linkedca" "go.step.sm/linkedca"
@ -36,7 +37,7 @@ type WebhookController struct {
// Enrich fetches data from remote servers and adds returned data to the // Enrich fetches data from remote servers and adds returned data to the
// templateData // templateData
func (wc *WebhookController) Enrich(req *webhook.RequestBody) error { func (wc *WebhookController) Enrich(ctx context.Context, req *webhook.RequestBody) error {
if wc == nil { if wc == nil {
return nil return nil
} }
@ -55,7 +56,11 @@ func (wc *WebhookController) Enrich(req *webhook.RequestBody) error {
if !wc.isCertTypeOK(wh) { if !wc.isCertTypeOK(wh) {
continue continue
} }
resp, err := wh.Do(wc.client, req, wc.TemplateData)
whCtx, cancel := context.WithTimeout(ctx, time.Second*10)
defer cancel() //nolint:gocritic // every request canceled with its own timeout
resp, err := wh.DoWithContext(whCtx, wc.client, req, wc.TemplateData)
if err != nil { if err != nil {
return err return err
} }
@ -68,7 +73,7 @@ func (wc *WebhookController) Enrich(req *webhook.RequestBody) error {
} }
// Authorize checks that all remote servers allow the request // Authorize checks that all remote servers allow the request
func (wc *WebhookController) Authorize(req *webhook.RequestBody) error { func (wc *WebhookController) Authorize(ctx context.Context, req *webhook.RequestBody) error {
if wc == nil { if wc == nil {
return nil return nil
} }
@ -87,7 +92,11 @@ func (wc *WebhookController) Authorize(req *webhook.RequestBody) error {
if !wc.isCertTypeOK(wh) { if !wc.isCertTypeOK(wh) {
continue continue
} }
resp, err := wh.Do(wc.client, req, wc.TemplateData)
whCtx, cancel := context.WithTimeout(ctx, time.Second*10)
defer cancel() //nolint:gocritic // every request canceled with its own timeout
resp, err := wh.DoWithContext(whCtx, wc.client, req, wc.TemplateData)
if err != nil { if err != nil {
return err return err
} }
@ -123,13 +132,6 @@ type Webhook struct {
} `json:"-"` } `json:"-"`
} }
func (w *Webhook) Do(client *http.Client, reqBody *webhook.RequestBody, data any) (*webhook.ResponseBody, error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
return w.DoWithContext(ctx, client, reqBody, data)
}
func (w *Webhook) DoWithContext(ctx context.Context, client *http.Client, reqBody *webhook.RequestBody, data any) (*webhook.ResponseBody, error) { func (w *Webhook) DoWithContext(ctx context.Context, client *http.Client, reqBody *webhook.RequestBody, data any) (*webhook.ResponseBody, error) {
tmpl, err := template.New("url").Funcs(templates.StepFuncMap()).Parse(w.URL) tmpl, err := template.New("url").Funcs(templates.StepFuncMap()).Parse(w.URL)
if err != nil { if err != nil {
@ -169,6 +171,11 @@ retry:
return nil, err return nil, err
} }
requestID, ok := logging.GetRequestID(ctx)
if ok {
req.Header.Set("X-Request-ID", requestID)
}
secret, err := base64.StdEncoding.DecodeString(w.Secret) secret, err := base64.StdEncoding.DecodeString(w.Secret)
if err != nil { if err != nil {
return nil, err return nil, err

@ -1,6 +1,7 @@
package provisioner package provisioner
import ( import (
"context"
"crypto/hmac" "crypto/hmac"
"crypto/sha256" "crypto/sha256"
"crypto/tls" "crypto/tls"
@ -8,15 +9,18 @@ import (
"encoding/base64" "encoding/base64"
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"time"
"github.com/pkg/errors" "github.com/smallstep/certificates/logging"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/webhook" "github.com/smallstep/certificates/webhook"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.step.sm/crypto/pemutil" "go.step.sm/crypto/pemutil"
"go.step.sm/crypto/x509util" "go.step.sm/crypto/x509util"
"go.step.sm/linkedca" "go.step.sm/linkedca"
@ -92,19 +96,24 @@ func TestWebhookController_isCertTypeOK(t *testing.T) {
} }
for name, test := range tests { for name, test := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
assert.Equals(t, test.want, test.wc.isCertTypeOK(test.wh)) assert.Equal(t, test.want, test.wc.isCertTypeOK(test.wh))
}) })
} }
} }
// withRequestID is a helper that calls into [logging.WithRequestID] and returns
// a new context with the requestID added to the provided context.
func withRequestID(ctx context.Context, requestID string) context.Context {
return logging.WithRequestID(ctx, requestID)
}
func TestWebhookController_Enrich(t *testing.T) { func TestWebhookController_Enrich(t *testing.T) {
cert, err := pemutil.ReadCertificate("testdata/certs/x5c-leaf.crt", pemutil.WithFirstBlock()) cert, err := pemutil.ReadCertificate("testdata/certs/x5c-leaf.crt", pemutil.WithFirstBlock())
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
type test struct { type test struct {
ctl *WebhookController ctl *WebhookController
ctx context.Context
req *webhook.RequestBody req *webhook.RequestBody
responses []*webhook.ResponseBody responses []*webhook.ResponseBody
expectErr bool expectErr bool
@ -129,6 +138,7 @@ func TestWebhookController_Enrich(t *testing.T) {
webhooks: []*Webhook{{Name: "people", Kind: "ENRICHING"}}, webhooks: []*Webhook{{Name: "people", Kind: "ENRICHING"}},
TemplateData: x509util.TemplateData{}, TemplateData: x509util.TemplateData{},
}, },
ctx: withRequestID(context.Background(), "reqID"),
req: &webhook.RequestBody{}, req: &webhook.RequestBody{},
responses: []*webhook.ResponseBody{{Allow: true, Data: map[string]any{"role": "bar"}}}, responses: []*webhook.ResponseBody{{Allow: true, Data: map[string]any{"role": "bar"}}},
expectErr: false, expectErr: false,
@ -143,6 +153,7 @@ func TestWebhookController_Enrich(t *testing.T) {
}, },
TemplateData: x509util.TemplateData{}, TemplateData: x509util.TemplateData{},
}, },
ctx: withRequestID(context.Background(), "reqID"),
req: &webhook.RequestBody{}, req: &webhook.RequestBody{},
responses: []*webhook.ResponseBody{ responses: []*webhook.ResponseBody{
{Allow: true, Data: map[string]any{"role": "bar"}}, {Allow: true, Data: map[string]any{"role": "bar"}},
@ -166,6 +177,7 @@ func TestWebhookController_Enrich(t *testing.T) {
TemplateData: x509util.TemplateData{}, TemplateData: x509util.TemplateData{},
certType: linkedca.Webhook_X509, certType: linkedca.Webhook_X509,
}, },
ctx: withRequestID(context.Background(), "reqID"),
req: &webhook.RequestBody{}, req: &webhook.RequestBody{},
responses: []*webhook.ResponseBody{ responses: []*webhook.ResponseBody{
{Allow: true, Data: map[string]any{"role": "bar"}}, {Allow: true, Data: map[string]any{"role": "bar"}},
@ -185,14 +197,15 @@ func TestWebhookController_Enrich(t *testing.T) {
TemplateData: x509util.TemplateData{}, TemplateData: x509util.TemplateData{},
options: []webhook.RequestBodyOption{webhook.WithX5CCertificate(cert)}, options: []webhook.RequestBodyOption{webhook.WithX5CCertificate(cert)},
}, },
ctx: withRequestID(context.Background(), "reqID"),
req: &webhook.RequestBody{}, req: &webhook.RequestBody{},
responses: []*webhook.ResponseBody{{Allow: true, Data: map[string]any{"role": "bar"}}}, responses: []*webhook.ResponseBody{{Allow: true, Data: map[string]any{"role": "bar"}}},
expectErr: false, expectErr: false,
expectTemplateData: x509util.TemplateData{"Webhooks": map[string]any{"people": map[string]any{"role": "bar"}}}, expectTemplateData: x509util.TemplateData{"Webhooks": map[string]any{"people": map[string]any{"role": "bar"}}},
assertRequest: func(t *testing.T, req *webhook.RequestBody) { assertRequest: func(t *testing.T, req *webhook.RequestBody) {
key, err := x509.MarshalPKIXPublicKey(cert.PublicKey) key, err := x509.MarshalPKIXPublicKey(cert.PublicKey)
assert.FatalError(t, err) require.NoError(t, err)
assert.Equals(t, &webhook.X5CCertificate{ assert.Equal(t, &webhook.X5CCertificate{
Raw: cert.Raw, Raw: cert.Raw,
PublicKey: key, PublicKey: key,
PublicKeyAlgorithm: cert.PublicKeyAlgorithm.String(), PublicKeyAlgorithm: cert.PublicKeyAlgorithm.String(),
@ -207,6 +220,7 @@ func TestWebhookController_Enrich(t *testing.T) {
webhooks: []*Webhook{{Name: "people", Kind: "ENRICHING"}}, webhooks: []*Webhook{{Name: "people", Kind: "ENRICHING"}},
TemplateData: x509util.TemplateData{}, TemplateData: x509util.TemplateData{},
}, },
ctx: withRequestID(context.Background(), "reqID"),
req: &webhook.RequestBody{}, req: &webhook.RequestBody{},
responses: []*webhook.ResponseBody{{Allow: false}}, responses: []*webhook.ResponseBody{{Allow: false}},
expectErr: true, expectErr: true,
@ -221,6 +235,7 @@ func TestWebhookController_Enrich(t *testing.T) {
PublicKey: []byte("bad"), PublicKey: []byte("bad"),
})}, })},
}, },
ctx: withRequestID(context.Background(), "reqID"),
req: &webhook.RequestBody{}, req: &webhook.RequestBody{},
responses: []*webhook.ResponseBody{{Allow: false}}, responses: []*webhook.ResponseBody{{Allow: false}},
expectErr: true, expectErr: true,
@ -232,19 +247,21 @@ func TestWebhookController_Enrich(t *testing.T) {
for i, wh := range test.ctl.webhooks { for i, wh := range test.ctl.webhooks {
var j = i var j = i
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "reqID", r.Header.Get("X-Request-ID"))
err := json.NewEncoder(w).Encode(test.responses[j]) err := json.NewEncoder(w).Encode(test.responses[j])
assert.FatalError(t, err) require.NoError(t, err)
})) }))
// nolint: gocritic // defer in loop isn't a memory leak // nolint: gocritic // defer in loop isn't a memory leak
defer ts.Close() defer ts.Close()
wh.URL = ts.URL wh.URL = ts.URL
} }
err := test.ctl.Enrich(test.req) err := test.ctl.Enrich(test.ctx, test.req)
if (err != nil) != test.expectErr { if (err != nil) != test.expectErr {
t.Fatalf("Got err %v, want %v", err, test.expectErr) t.Fatalf("Got err %v, want %v", err, test.expectErr)
} }
assert.Equals(t, test.expectTemplateData, test.ctl.TemplateData) assert.Equal(t, test.expectTemplateData, test.ctl.TemplateData)
if test.assertRequest != nil { if test.assertRequest != nil {
test.assertRequest(t, test.req) test.assertRequest(t, test.req)
} }
@ -254,12 +271,11 @@ func TestWebhookController_Enrich(t *testing.T) {
func TestWebhookController_Authorize(t *testing.T) { func TestWebhookController_Authorize(t *testing.T) {
cert, err := pemutil.ReadCertificate("testdata/certs/x5c-leaf.crt", pemutil.WithFirstBlock()) cert, err := pemutil.ReadCertificate("testdata/certs/x5c-leaf.crt", pemutil.WithFirstBlock())
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
type test struct { type test struct {
ctl *WebhookController ctl *WebhookController
ctx context.Context
req *webhook.RequestBody req *webhook.RequestBody
responses []*webhook.ResponseBody responses []*webhook.ResponseBody
expectErr bool expectErr bool
@ -280,6 +296,7 @@ func TestWebhookController_Authorize(t *testing.T) {
client: http.DefaultClient, client: http.DefaultClient,
webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING"}}, webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING"}},
}, },
ctx: withRequestID(context.Background(), "reqID"),
req: &webhook.RequestBody{}, req: &webhook.RequestBody{},
responses: []*webhook.ResponseBody{{Allow: true}}, responses: []*webhook.ResponseBody{{Allow: true}},
expectErr: false, expectErr: false,
@ -290,6 +307,7 @@ func TestWebhookController_Authorize(t *testing.T) {
webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING", CertType: linkedca.Webhook_X509.String()}}, webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING", CertType: linkedca.Webhook_X509.String()}},
certType: linkedca.Webhook_SSH, certType: linkedca.Webhook_SSH,
}, },
ctx: withRequestID(context.Background(), "reqID"),
req: &webhook.RequestBody{}, req: &webhook.RequestBody{},
responses: []*webhook.ResponseBody{{Allow: false}}, responses: []*webhook.ResponseBody{{Allow: false}},
expectErr: false, expectErr: false,
@ -300,13 +318,14 @@ func TestWebhookController_Authorize(t *testing.T) {
webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING"}}, webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING"}},
options: []webhook.RequestBodyOption{webhook.WithX5CCertificate(cert)}, options: []webhook.RequestBodyOption{webhook.WithX5CCertificate(cert)},
}, },
ctx: withRequestID(context.Background(), "reqID"),
req: &webhook.RequestBody{}, req: &webhook.RequestBody{},
responses: []*webhook.ResponseBody{{Allow: true}}, responses: []*webhook.ResponseBody{{Allow: true}},
expectErr: false, expectErr: false,
assertRequest: func(t *testing.T, req *webhook.RequestBody) { assertRequest: func(t *testing.T, req *webhook.RequestBody) {
key, err := x509.MarshalPKIXPublicKey(cert.PublicKey) key, err := x509.MarshalPKIXPublicKey(cert.PublicKey)
assert.FatalError(t, err) require.NoError(t, err)
assert.Equals(t, &webhook.X5CCertificate{ assert.Equal(t, &webhook.X5CCertificate{
Raw: cert.Raw, Raw: cert.Raw,
PublicKey: key, PublicKey: key,
PublicKeyAlgorithm: cert.PublicKeyAlgorithm.String(), PublicKeyAlgorithm: cert.PublicKeyAlgorithm.String(),
@ -320,6 +339,7 @@ func TestWebhookController_Authorize(t *testing.T) {
client: http.DefaultClient, client: http.DefaultClient,
webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING"}}, webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING"}},
}, },
ctx: withRequestID(context.Background(), "reqID"),
req: &webhook.RequestBody{}, req: &webhook.RequestBody{},
responses: []*webhook.ResponseBody{{Allow: false}}, responses: []*webhook.ResponseBody{{Allow: false}},
expectErr: true, expectErr: true,
@ -332,6 +352,7 @@ func TestWebhookController_Authorize(t *testing.T) {
PublicKey: []byte("bad"), PublicKey: []byte("bad"),
})}, })},
}, },
ctx: withRequestID(context.Background(), "reqID"),
req: &webhook.RequestBody{}, req: &webhook.RequestBody{},
responses: []*webhook.ResponseBody{{Allow: false}}, responses: []*webhook.ResponseBody{{Allow: false}},
expectErr: true, expectErr: true,
@ -342,15 +363,17 @@ func TestWebhookController_Authorize(t *testing.T) {
for i, wh := range test.ctl.webhooks { for i, wh := range test.ctl.webhooks {
var j = i var j = i
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "reqID", r.Header.Get("X-Request-ID"))
err := json.NewEncoder(w).Encode(test.responses[j]) err := json.NewEncoder(w).Encode(test.responses[j])
assert.FatalError(t, err) require.NoError(t, err)
})) }))
// nolint: gocritic // defer in loop isn't a memory leak // nolint: gocritic // defer in loop isn't a memory leak
defer ts.Close() defer ts.Close()
wh.URL = ts.URL wh.URL = ts.URL
} }
err := test.ctl.Authorize(test.req) err := test.ctl.Authorize(test.ctx, test.req)
if (err != nil) != test.expectErr { if (err != nil) != test.expectErr {
t.Fatalf("Got err %v, want %v", err, test.expectErr) t.Fatalf("Got err %v, want %v", err, test.expectErr)
} }
@ -366,6 +389,7 @@ func TestWebhook_Do(t *testing.T) {
type test struct { type test struct {
webhook Webhook webhook Webhook
dataArg any dataArg any
requestID string
webhookResponse webhook.ResponseBody webhookResponse webhook.ResponseBody
expectPath string expectPath string
errStatusCode int errStatusCode int
@ -375,6 +399,16 @@ func TestWebhook_Do(t *testing.T) {
} }
tests := map[string]test{ tests := map[string]test{
"ok": { "ok": {
webhook: Webhook{
ID: "abc123",
Secret: "c2VjcmV0Cg==",
},
requestID: "reqID",
webhookResponse: webhook.ResponseBody{
Data: map[string]interface{}{"role": "dba"},
},
},
"ok/no-request-id": {
webhook: Webhook{ webhook: Webhook{
ID: "abc123", ID: "abc123",
Secret: "c2VjcmV0Cg==", Secret: "c2VjcmV0Cg==",
@ -389,6 +423,7 @@ func TestWebhook_Do(t *testing.T) {
Secret: "c2VjcmV0Cg==", Secret: "c2VjcmV0Cg==",
BearerToken: "mytoken", BearerToken: "mytoken",
}, },
requestID: "reqID",
webhookResponse: webhook.ResponseBody{ webhookResponse: webhook.ResponseBody{
Data: map[string]interface{}{"role": "dba"}, Data: map[string]interface{}{"role": "dba"},
}, },
@ -405,6 +440,7 @@ func TestWebhook_Do(t *testing.T) {
Password: "mypass", Password: "mypass",
}, },
}, },
requestID: "reqID",
webhookResponse: webhook.ResponseBody{ webhookResponse: webhook.ResponseBody{
Data: map[string]interface{}{"role": "dba"}, Data: map[string]interface{}{"role": "dba"},
}, },
@ -416,7 +452,8 @@ func TestWebhook_Do(t *testing.T) {
URL: "/users/{{ .username }}?region={{ .region }}", URL: "/users/{{ .username }}?region={{ .region }}",
Secret: "c2VjcmV0Cg==", Secret: "c2VjcmV0Cg==",
}, },
dataArg: map[string]interface{}{"username": "areed", "region": "central"}, requestID: "reqID",
dataArg: map[string]interface{}{"username": "areed", "region": "central"},
webhookResponse: webhook.ResponseBody{ webhookResponse: webhook.ResponseBody{
Data: map[string]interface{}{"role": "dba"}, Data: map[string]interface{}{"role": "dba"},
}, },
@ -451,6 +488,7 @@ func TestWebhook_Do(t *testing.T) {
ID: "abc123", ID: "abc123",
Secret: "c2VjcmV0Cg==", Secret: "c2VjcmV0Cg==",
}, },
requestID: "reqID",
webhookResponse: webhook.ResponseBody{ webhookResponse: webhook.ResponseBody{
Allow: true, Allow: true,
}, },
@ -463,6 +501,7 @@ func TestWebhook_Do(t *testing.T) {
webhookResponse: webhook.ResponseBody{ webhookResponse: webhook.ResponseBody{
Data: map[string]interface{}{"role": "dba"}, Data: map[string]interface{}{"role": "dba"},
}, },
requestID: "reqID",
errStatusCode: 404, errStatusCode: 404,
serverErrMsg: "item not found", serverErrMsg: "item not found",
expectErr: errors.New("Webhook server responded with 404"), expectErr: errors.New("Webhook server responded with 404"),
@ -471,17 +510,20 @@ func TestWebhook_Do(t *testing.T) {
for name, tc := range tests { for name, tc := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
id := r.Header.Get("X-Smallstep-Webhook-ID") if tc.requestID != "" {
assert.Equals(t, tc.webhook.ID, id) assert.Equal(t, tc.requestID, r.Header.Get("X-Request-ID"))
}
assert.Equal(t, tc.webhook.ID, r.Header.Get("X-Smallstep-Webhook-ID"))
sig, err := hex.DecodeString(r.Header.Get("X-Smallstep-Signature")) sig, err := hex.DecodeString(r.Header.Get("X-Smallstep-Signature"))
assert.FatalError(t, err) assert.NoError(t, err)
body, err := io.ReadAll(r.Body) body, err := io.ReadAll(r.Body)
assert.FatalError(t, err) assert.NoError(t, err)
secret, err := base64.StdEncoding.DecodeString(tc.webhook.Secret) secret, err := base64.StdEncoding.DecodeString(tc.webhook.Secret)
assert.FatalError(t, err) assert.NoError(t, err)
h := hmac.New(sha256.New, secret) h := hmac.New(sha256.New, secret)
h.Write(body) h.Write(body)
mac := h.Sum(nil) mac := h.Sum(nil)
@ -490,19 +532,19 @@ func TestWebhook_Do(t *testing.T) {
switch { switch {
case tc.webhook.BearerToken != "": case tc.webhook.BearerToken != "":
ah := fmt.Sprintf("Bearer %s", tc.webhook.BearerToken) ah := fmt.Sprintf("Bearer %s", tc.webhook.BearerToken)
assert.Equals(t, ah, r.Header.Get("Authorization")) assert.Equal(t, ah, r.Header.Get("Authorization"))
case tc.webhook.BasicAuth.Username != "" || tc.webhook.BasicAuth.Password != "": case tc.webhook.BasicAuth.Username != "" || tc.webhook.BasicAuth.Password != "":
whReq, err := http.NewRequest("", "", http.NoBody) whReq, err := http.NewRequest("", "", http.NoBody)
assert.FatalError(t, err) require.NoError(t, err)
whReq.SetBasicAuth(tc.webhook.BasicAuth.Username, tc.webhook.BasicAuth.Password) whReq.SetBasicAuth(tc.webhook.BasicAuth.Username, tc.webhook.BasicAuth.Password)
ah := whReq.Header.Get("Authorization") ah := whReq.Header.Get("Authorization")
assert.Equals(t, ah, whReq.Header.Get("Authorization")) assert.Equal(t, ah, whReq.Header.Get("Authorization"))
default: default:
assert.Equals(t, "", r.Header.Get("Authorization")) assert.Equal(t, "", r.Header.Get("Authorization"))
} }
if tc.expectPath != "" { if tc.expectPath != "" {
assert.Equals(t, tc.expectPath, r.URL.Path+"?"+r.URL.RawQuery) assert.Equal(t, tc.expectPath, r.URL.Path+"?"+r.URL.RawQuery)
} }
if tc.errStatusCode != 0 { if tc.errStatusCode != 0 {
@ -512,26 +554,33 @@ func TestWebhook_Do(t *testing.T) {
reqBody := new(webhook.RequestBody) reqBody := new(webhook.RequestBody)
err = json.Unmarshal(body, reqBody) err = json.Unmarshal(body, reqBody)
assert.FatalError(t, err) require.NoError(t, err)
// assert.Equals(t, tc.expectToken, reqBody.Token)
err = json.NewEncoder(w).Encode(tc.webhookResponse) err = json.NewEncoder(w).Encode(tc.webhookResponse)
assert.FatalError(t, err) require.NoError(t, err)
})) }))
defer ts.Close() defer ts.Close()
tc.webhook.URL = ts.URL + tc.webhook.URL tc.webhook.URL = ts.URL + tc.webhook.URL
reqBody, err := webhook.NewRequestBody(webhook.WithX509CertificateRequest(csr)) reqBody, err := webhook.NewRequestBody(webhook.WithX509CertificateRequest(csr))
assert.FatalError(t, err) require.NoError(t, err)
got, err := tc.webhook.Do(http.DefaultClient, reqBody, tc.dataArg)
ctx := context.Background()
if tc.requestID != "" {
ctx = withRequestID(context.Background(), tc.requestID)
}
ctx, cancel := context.WithTimeout(ctx, time.Second*10)
defer cancel()
got, err := tc.webhook.DoWithContext(ctx, http.DefaultClient, reqBody, tc.dataArg)
if tc.expectErr != nil { if tc.expectErr != nil {
assert.Equals(t, tc.expectErr.Error(), err.Error()) assert.Equal(t, tc.expectErr.Error(), err.Error())
return return
} }
assert.FatalError(t, err) assert.NoError(t, err)
assert.Equals(t, got, &tc.webhookResponse) assert.Equal(t, &tc.webhookResponse, got)
}) })
} }
@ -544,7 +593,7 @@ func TestWebhook_Do(t *testing.T) {
URL: ts.URL, URL: ts.URL,
} }
cert, err := tls.LoadX509KeyPair("testdata/certs/foo.crt", "testdata/secrets/foo.key") cert, err := tls.LoadX509KeyPair("testdata/certs/foo.crt", "testdata/secrets/foo.key")
assert.FatalError(t, err) require.NoError(t, err)
transport := http.DefaultTransport.(*http.Transport).Clone() transport := http.DefaultTransport.(*http.Transport).Clone()
transport.TLSClientConfig = &tls.Config{ transport.TLSClientConfig = &tls.Config{
InsecureSkipVerify: true, InsecureSkipVerify: true,
@ -554,12 +603,19 @@ func TestWebhook_Do(t *testing.T) {
Transport: transport, Transport: transport,
} }
reqBody, err := webhook.NewRequestBody(webhook.WithX509CertificateRequest(csr)) reqBody, err := webhook.NewRequestBody(webhook.WithX509CertificateRequest(csr))
assert.FatalError(t, err) require.NoError(t, err)
_, err = wh.Do(client, reqBody, nil)
assert.FatalError(t, err) ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
_, err = wh.DoWithContext(ctx, client, reqBody, nil)
require.NoError(t, err)
ctx, cancel = context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
wh.DisableTLSClientAuth = true wh.DisableTLSClientAuth = true
_, err = wh.Do(client, reqBody, nil) _, err = wh.DoWithContext(ctx, client, reqBody, nil)
assert.Error(t, err) require.Error(t, err)
}) })
} }

@ -149,7 +149,7 @@ func TestAuthority_LoadProvisionerByCertificate(t *testing.T) {
opts, err := a.Authorize(ctx, token) opts, err := a.Authorize(ctx, token)
require.NoError(t, err) require.NoError(t, err)
opts = append(opts, extraOpts...) opts = append(opts, extraOpts...)
certs, err := a.Sign(csr, provisioner.SignOptions{}, opts...) certs, err := a.SignWithContext(ctx, csr, provisioner.SignOptions{}, opts...)
require.NoError(t, err) require.NoError(t, err)
return certs[0] return certs[0]
} }

@ -152,7 +152,7 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi
return cert, err return cert, err
} }
func (a *Authority) signSSH(_ context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, provisioner.Interface, error) { func (a *Authority) signSSH(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, provisioner.Interface, error) {
var ( var (
certOptions []sshutil.Option certOptions []sshutil.Option
mods []provisioner.SSHCertModifier mods []provisioner.SSHCertModifier
@ -211,7 +211,7 @@ func (a *Authority) signSSH(_ context.Context, key ssh.PublicKey, opts provision
} }
// Call enriching webhooks // Call enriching webhooks
if err := a.callEnrichingWebhooksSSH(prov, webhookCtl, cr); err != nil { if err := a.callEnrichingWebhooksSSH(ctx, prov, webhookCtl, cr); err != nil {
return nil, prov, errs.ApplyOptions( return nil, prov, errs.ApplyOptions(
errs.ForbiddenErr(err, err.Error()), errs.ForbiddenErr(err, err.Error()),
errs.WithKeyVal("signOptions", signOpts), errs.WithKeyVal("signOptions", signOpts),
@ -284,7 +284,7 @@ func (a *Authority) signSSH(_ context.Context, key ssh.PublicKey, opts provision
} }
// Send certificate to webhooks for authorization // Send certificate to webhooks for authorization
if err := a.callAuthorizingWebhooksSSH(prov, webhookCtl, certificate, certTpl); err != nil { if err := a.callAuthorizingWebhooksSSH(ctx, prov, webhookCtl, certificate, certTpl); err != nil {
return nil, prov, errs.ApplyOptions( return nil, prov, errs.ApplyOptions(
errs.ForbiddenErr(err, "authority.SignSSH: error signing certificate"), errs.ForbiddenErr(err, "authority.SignSSH: error signing certificate"),
) )
@ -671,7 +671,7 @@ func (a *Authority) getAddUserCommand(principal string) string {
return strings.ReplaceAll(cmd, "<principal>", principal) return strings.ReplaceAll(cmd, "<principal>", principal)
} }
func (a *Authority) callEnrichingWebhooksSSH(prov provisioner.Interface, webhookCtl webhookController, cr sshutil.CertificateRequest) (err error) { func (a *Authority) callEnrichingWebhooksSSH(ctx context.Context, prov provisioner.Interface, webhookCtl webhookController, cr sshutil.CertificateRequest) (err error) {
if webhookCtl == nil { if webhookCtl == nil {
return return
} }
@ -680,7 +680,7 @@ func (a *Authority) callEnrichingWebhooksSSH(prov provisioner.Interface, webhook
if whEnrichReq, err = webhook.NewRequestBody( if whEnrichReq, err = webhook.NewRequestBody(
webhook.WithSSHCertificateRequest(cr), webhook.WithSSHCertificateRequest(cr),
); err == nil { ); err == nil {
err = webhookCtl.Enrich(whEnrichReq) err = webhookCtl.Enrich(ctx, whEnrichReq)
a.meter.SSHWebhookEnriched(prov, err) a.meter.SSHWebhookEnriched(prov, err)
} }
@ -688,7 +688,7 @@ func (a *Authority) callEnrichingWebhooksSSH(prov provisioner.Interface, webhook
return return
} }
func (a *Authority) callAuthorizingWebhooksSSH(prov provisioner.Interface, webhookCtl webhookController, cert *sshutil.Certificate, certTpl *ssh.Certificate) (err error) { func (a *Authority) callAuthorizingWebhooksSSH(ctx context.Context, prov provisioner.Interface, webhookCtl webhookController, cert *sshutil.Certificate, certTpl *ssh.Certificate) (err error) {
if webhookCtl == nil { if webhookCtl == nil {
return return
} }
@ -697,7 +697,7 @@ func (a *Authority) callAuthorizingWebhooksSSH(prov provisioner.Interface, webho
if whAuthBody, err = webhook.NewRequestBody( if whAuthBody, err = webhook.NewRequestBody(
webhook.WithSSHCertificate(cert, certTpl), webhook.WithSSHCertificate(cert, certTpl),
); err == nil { ); err == nil {
err = webhookCtl.Authorize(whAuthBody) err = webhookCtl.Authorize(ctx, whAuthBody)
a.meter.SSHWebhookAuthorized(prov, err) a.meter.SSHWebhookAuthorized(prov, err)
} }

@ -91,14 +91,23 @@ func withDefaultASN1DN(def *config.ASN1DN) provisioner.CertificateModifierFunc {
} }
} }
// Sign creates a signed certificate from a certificate signing request. // Sign creates a signed certificate from a certificate signing request. It
// creates a new context.Context, and calls into SignWithContext.
//
// Deprecated: Use authority.SignWithContext with an actual context.Context.
func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
chain, prov, err := a.signX509(csr, signOpts, extraOpts...) return a.SignWithContext(context.Background(), csr, signOpts, extraOpts...)
}
// SignWithContext creates a signed certificate from a certificate signing
// request, taking the provided context.Context.
func (a *Authority) SignWithContext(ctx context.Context, csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
chain, prov, err := a.signX509(ctx, csr, signOpts, extraOpts...)
a.meter.X509Signed(prov, err) a.meter.X509Signed(prov, err)
return chain, err return chain, err
} }
func (a *Authority) signX509(csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, provisioner.Interface, error) { func (a *Authority) signX509(ctx context.Context, csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, provisioner.Interface, error) {
var ( var (
certOptions []x509util.Option certOptions []x509util.Option
certValidators []provisioner.CertificateValidator certValidators []provisioner.CertificateValidator
@ -171,7 +180,7 @@ func (a *Authority) signX509(csr *x509.CertificateRequest, signOpts provisioner.
} }
} }
if err := a.callEnrichingWebhooksX509(prov, webhookCtl, attData, csr); err != nil { if err := a.callEnrichingWebhooksX509(ctx, prov, webhookCtl, attData, csr); err != nil {
return nil, prov, errs.ApplyOptions( return nil, prov, errs.ApplyOptions(
errs.ForbiddenErr(err, err.Error()), errs.ForbiddenErr(err, err.Error()),
errs.WithKeyVal("csr", csr), errs.WithKeyVal("csr", csr),
@ -265,7 +274,7 @@ func (a *Authority) signX509(csr *x509.CertificateRequest, signOpts provisioner.
} }
// Send certificate to webhooks for authorization // Send certificate to webhooks for authorization
if err := a.callAuthorizingWebhooksX509(prov, webhookCtl, crt, leaf, attData); err != nil { if err := a.callAuthorizingWebhooksX509(ctx, prov, webhookCtl, crt, leaf, attData); err != nil {
return nil, prov, errs.ApplyOptions( return nil, prov, errs.ApplyOptions(
errs.ForbiddenErr(err, "error creating certificate"), errs.ForbiddenErr(err, "error creating certificate"),
opts..., opts...,
@ -986,7 +995,7 @@ func templatingError(err error) error {
return errors.Wrap(cause, "error applying certificate template") return errors.Wrap(cause, "error applying certificate template")
} }
func (a *Authority) callEnrichingWebhooksX509(prov provisioner.Interface, webhookCtl webhookController, attData *provisioner.AttestationData, csr *x509.CertificateRequest) (err error) { func (a *Authority) callEnrichingWebhooksX509(ctx context.Context, prov provisioner.Interface, webhookCtl webhookController, attData *provisioner.AttestationData, csr *x509.CertificateRequest) (err error) {
if webhookCtl == nil { if webhookCtl == nil {
return return
} }
@ -1003,7 +1012,7 @@ func (a *Authority) callEnrichingWebhooksX509(prov provisioner.Interface, webhoo
webhook.WithX509CertificateRequest(csr), webhook.WithX509CertificateRequest(csr),
webhook.WithAttestationData(attested), webhook.WithAttestationData(attested),
); err == nil { ); err == nil {
err = webhookCtl.Enrich(whEnrichReq) err = webhookCtl.Enrich(ctx, whEnrichReq)
a.meter.X509WebhookEnriched(prov, err) a.meter.X509WebhookEnriched(prov, err)
} }
@ -1011,7 +1020,7 @@ func (a *Authority) callEnrichingWebhooksX509(prov provisioner.Interface, webhoo
return return
} }
func (a *Authority) callAuthorizingWebhooksX509(prov provisioner.Interface, webhookCtl webhookController, cert *x509util.Certificate, leaf *x509.Certificate, attData *provisioner.AttestationData) (err error) { func (a *Authority) callAuthorizingWebhooksX509(ctx context.Context, prov provisioner.Interface, webhookCtl webhookController, cert *x509util.Certificate, leaf *x509.Certificate, attData *provisioner.AttestationData) (err error) {
if webhookCtl == nil { if webhookCtl == nil {
return return
} }
@ -1028,7 +1037,7 @@ func (a *Authority) callAuthorizingWebhooksX509(prov provisioner.Interface, webh
webhook.WithX509Certificate(cert, leaf), webhook.WithX509Certificate(cert, leaf),
webhook.WithAttestationData(attested), webhook.WithAttestationData(attested),
); err == nil { ); err == nil {
err = webhookCtl.Authorize(whAuthBody) err = webhookCtl.Authorize(ctx, whAuthBody)
a.meter.X509WebhookAuthorized(prov, err) a.meter.X509WebhookAuthorized(prov, err)
} }

@ -239,7 +239,7 @@ func (e *testEnforcer) Enforce(cert *x509.Certificate) error {
return nil return nil
} }
func TestAuthority_Sign(t *testing.T) { func TestAuthority_SignWithContext(t *testing.T) {
pub, priv, err := keyutil.GenerateDefaultKeyPair() pub, priv, err := keyutil.GenerateDefaultKeyPair()
require.NoError(t, err) require.NoError(t, err)
@ -848,7 +848,7 @@ ZYtQ9Ot36qc=
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
tc := genTestCase(t) tc := genTestCase(t)
certChain, err := tc.auth.Sign(tc.csr, tc.signOpts, tc.extraOpts...) certChain, err := tc.auth.SignWithContext(context.Background(), tc.csr, tc.signOpts, tc.extraOpts...)
if err != nil { if err != nil {
if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) { if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) {
assert.Nil(t, certChain) assert.Nil(t, certChain)
@ -1797,9 +1797,9 @@ func TestAuthority_constraints(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
_, err = auth.Sign(csr, provisioner.SignOptions{}, templateOption) _, err = auth.SignWithContext(context.Background(), csr, provisioner.SignOptions{}, templateOption)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("Authority.Sign() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("Authority.SignWithContext() error = %v, wantErr %v", err, tt.wantErr)
} }
_, err = auth.Renew(cert) _, err = auth.Renew(cert)

@ -1,8 +1,12 @@
package authority package authority
import "github.com/smallstep/certificates/webhook" import (
"context"
"github.com/smallstep/certificates/webhook"
)
type webhookController interface { type webhookController interface {
Enrich(*webhook.RequestBody) error Enrich(context.Context, *webhook.RequestBody) error
Authorize(*webhook.RequestBody) error Authorize(context.Context, *webhook.RequestBody) error
} }

@ -1,6 +1,8 @@
package authority package authority
import ( import (
"context"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/webhook" "github.com/smallstep/certificates/webhook"
) )
@ -14,7 +16,7 @@ type mockWebhookController struct {
var _ webhookController = &mockWebhookController{} var _ webhookController = &mockWebhookController{}
func (wc *mockWebhookController) Enrich(*webhook.RequestBody) error { func (wc *mockWebhookController) Enrich(context.Context, *webhook.RequestBody) error {
for key, data := range wc.respData { for key, data := range wc.respData {
wc.templateData.SetWebhook(key, data) wc.templateData.SetWebhook(key, data)
} }
@ -22,6 +24,6 @@ func (wc *mockWebhookController) Enrich(*webhook.RequestBody) error {
return wc.enrichErr return wc.enrichErr
} }
func (wc *mockWebhookController) Authorize(*webhook.RequestBody) error { func (wc *mockWebhookController) Authorize(context.Context, *webhook.RequestBody) error {
return wc.authorizeErr return wc.authorizeErr
} }

@ -60,7 +60,7 @@ func MustFromContext(ctx context.Context) *Authority {
// SignAuthority is the interface for a signing authority // SignAuthority is the interface for a signing authority
type SignAuthority interface { type SignAuthority interface {
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) SignWithContext(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
LoadProvisionerByName(string) (provisioner.Interface, error) LoadProvisionerByName(string) (provisioner.Interface, error)
} }
@ -306,7 +306,7 @@ func (a *Authority) SignCSR(ctx context.Context, csr *x509.CertificateRequest, m
} }
signOps = append(signOps, templateOptions) signOps = append(signOps, templateOptions)
certChain, err := a.signAuth.Sign(csr, opts, signOps...) certChain, err := a.signAuth.SignWithContext(ctx, csr, opts, signOps...)
if err != nil { if err != nil {
return nil, fmt.Errorf("error generating certificate for order: %w", err) return nil, fmt.Errorf("error generating certificate for order: %w", err)
} }

Loading…
Cancel
Save