Merge pull request #1542 from smallstep/herman/webhook-request-id

Propagate request ID when webhook requests are made
This commit is contained in:
Herman Slatman 2024-02-27 20:07:37 +01:00 committed by GitHub
commit c798735f7e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 187 additions and 108 deletions

View File

@ -281,7 +281,7 @@ type mockCA struct {
MockAreSANsallowed func(ctx context.Context, sans []string) error
}
func (m *mockCA) Sign(*x509.CertificateRequest, provisioner.SignOptions, ...provisioner.SignOption) ([]*x509.Certificate, error) {
func (m *mockCA) SignWithContext(context.Context, *x509.CertificateRequest, provisioner.SignOptions, ...provisioner.SignOption) ([]*x509.Certificate, error) {
return nil, nil
}

View File

@ -21,7 +21,7 @@ var clock Clock
// CertificateAuthority is the interface implemented by a CA authority.
type CertificateAuthority interface {
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
SignWithContext(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
AreSANsAllowed(ctx context.Context, sans []string) error
IsRevoked(sn string) (bool, error)
Revoke(context.Context, *authority.RevokeOptions) error

View File

@ -263,7 +263,7 @@ func (o *Order) Finalize(ctx context.Context, db DB, csr *x509.CertificateReques
signOps = append(signOps, extraOptions...)
// Sign a new certificate.
certChain, err := auth.Sign(csr, provisioner.SignOptions{
certChain, err := auth.SignWithContext(ctx, csr, provisioner.SignOptions{
NotBefore: provisioner.NewTimeDuration(o.NotBefore),
NotAfter: provisioner.NewTimeDuration(o.NotAfter),
}, signOps...)

View File

@ -271,16 +271,16 @@ func TestOrder_UpdateStatus(t *testing.T) {
}
type mockSignAuth struct {
sign func(csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
signWithContext func(ctx context.Context, csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
areSANsAllowed func(ctx context.Context, sans []string) error
loadProvisionerByName func(string) (provisioner.Interface, error)
ret1, ret2 interface{}
err error
}
func (m *mockSignAuth) Sign(csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
if m.sign != nil {
return m.sign(csr, signOpts, extraOpts...)
func (m *mockSignAuth) SignWithContext(ctx context.Context, csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
if m.signWithContext != nil {
return m.signWithContext(ctx, csr, signOpts, extraOpts...)
} else if m.err != nil {
return nil, m.err
}
@ -578,7 +578,7 @@ func TestOrder_Finalize(t *testing.T) {
},
},
ca: &mockSignAuth{
sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
assert.Equals(t, _csr, csr)
return nil, errors.New("force")
},
@ -628,7 +628,7 @@ func TestOrder_Finalize(t *testing.T) {
},
},
ca: &mockSignAuth{
sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
assert.Equals(t, _csr, csr)
return []*x509.Certificate{foo, bar, baz}, nil
},
@ -685,7 +685,7 @@ func TestOrder_Finalize(t *testing.T) {
},
},
ca: &mockSignAuth{
sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
assert.Equals(t, _csr, csr)
return []*x509.Certificate{foo, bar, baz}, nil
},
@ -770,7 +770,7 @@ func TestOrder_Finalize(t *testing.T) {
},
},
ca: &mockSignAuth{
sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
assert.Equals(t, _csr, csr)
return []*x509.Certificate{leaf, inter, root}, nil
},
@ -863,7 +863,7 @@ func TestOrder_Finalize(t *testing.T) {
},
},
ca: &mockSignAuth{
sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
assert.Equals(t, _csr, csr)
return []*x509.Certificate{leaf, inter, root}, nil
},
@ -973,7 +973,7 @@ func TestOrder_Finalize(t *testing.T) {
// using the mocking functions as a wrapper for actual test helpers generated per test case or per
// function that's tested.
ca: &mockSignAuth{
sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
assert.Equals(t, _csr, csr)
return []*x509.Certificate{leaf, inter, root}, nil
},
@ -1044,7 +1044,7 @@ func TestOrder_Finalize(t *testing.T) {
},
},
ca: &mockSignAuth{
sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
assert.Equals(t, _csr, csr)
return []*x509.Certificate{foo, bar, baz}, nil
},
@ -1108,7 +1108,7 @@ func TestOrder_Finalize(t *testing.T) {
},
},
ca: &mockSignAuth{
sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
assert.Equals(t, _csr, csr)
return []*x509.Certificate{foo, bar, baz}, nil
},
@ -1175,7 +1175,7 @@ func TestOrder_Finalize(t *testing.T) {
},
},
ca: &mockSignAuth{
sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
assert.Equals(t, _csr, csr)
return []*x509.Certificate{foo, bar, baz}, nil
},

View File

@ -42,7 +42,7 @@ type Authority interface {
AuthorizeRenewToken(ctx context.Context, ott string) (*x509.Certificate, error)
GetTLSOptions() *config.TLSOptions
Root(shasum string) (*x509.Certificate, error)
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
SignWithContext(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
Renew(peer *x509.Certificate) ([]*x509.Certificate, error)
RenewContext(ctx context.Context, peer *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
Rekey(peer *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)

View File

@ -189,7 +189,7 @@ type mockAuthority struct {
authorizeRenewToken func(ctx context.Context, ott string) (*x509.Certificate, error)
getTLSOptions func() *authority.TLSOptions
root func(shasum string) (*x509.Certificate, error)
sign func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
signWithContext func(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
renew func(cert *x509.Certificate) ([]*x509.Certificate, error)
rekey func(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
renewContext func(ctx context.Context, oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
@ -251,9 +251,9 @@ func (m *mockAuthority) Root(shasum string) (*x509.Certificate, error) {
return m.ret1.(*x509.Certificate), m.err
}
func (m *mockAuthority) Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
if m.sign != nil {
return m.sign(cr, opts, signOpts...)
func (m *mockAuthority) SignWithContext(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
if m.signWithContext != nil {
return m.signWithContext(ctx, cr, opts, signOpts...)
}
return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err
}

View File

@ -78,7 +78,7 @@ func Sign(w http.ResponseWriter, r *http.Request) {
return
}
certChain, err := a.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...)
certChain, err := a.SignWithContext(ctx, body.CsrPEM.CertificateRequest, opts, signOpts...)
if err != nil {
render.Error(w, errs.ForbiddenErr(err, "error signing certificate"))
return

View File

@ -330,7 +330,7 @@ func SSHSign(w http.ResponseWriter, r *http.Request) {
NotAfter: time.Unix(int64(cert.ValidBefore), 0),
})
certChain, err := a.Sign(cr, provisioner.SignOptions{}, signOpts...)
certChain, err := a.SignWithContext(ctx, cr, provisioner.SignOptions{}, signOpts...)
if err != nil {
render.Error(w, errs.ForbiddenErr(err, "error signing identity certificate"))
return

View File

@ -325,7 +325,7 @@ func Test_SSHSign(t *testing.T) {
signSSHAddUser: func(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) {
return tt.addUserCert, tt.addUserErr
},
sign: func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
signWithContext: func(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
return tt.tlsSignCerts, tt.tlsSignErr
},
})

View File

@ -1,6 +1,7 @@
package authority
import (
"context"
"crypto"
"crypto/rand"
"crypto/sha256"
@ -414,7 +415,7 @@ func TestNewEmbedded_Sign(t *testing.T) {
csr, err := x509.ParseCertificateRequest(cr)
assert.FatalError(t, err)
cert, err := a.Sign(csr, provisioner.SignOptions{})
cert, err := a.SignWithContext(context.Background(), csr, provisioner.SignOptions{})
assert.FatalError(t, err)
assert.Equals(t, []string{"foo.bar.zar"}, cert[0].DNSNames)
assert.Equals(t, crt, cert[1])

View File

@ -1375,7 +1375,7 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) {
}
generateX5cToken := func(a *Authority, key crypto.Signer, claims jose.Claims, opts ...provisioner.SignOption) (string, *x509.Certificate) {
chain, err := a.Sign(csr, provisioner.SignOptions{}, opts...)
chain, err := a.SignWithContext(ctx, csr, provisioner.SignOptions{}, opts...)
if err != nil {
t.Fatal(err)
}

View File

@ -15,6 +15,7 @@ import (
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/logging"
"github.com/smallstep/certificates/templates"
"github.com/smallstep/certificates/webhook"
"go.step.sm/linkedca"
@ -36,7 +37,7 @@ type WebhookController struct {
// Enrich fetches data from remote servers and adds returned data to the
// templateData
func (wc *WebhookController) Enrich(req *webhook.RequestBody) error {
func (wc *WebhookController) Enrich(ctx context.Context, req *webhook.RequestBody) error {
if wc == nil {
return nil
}
@ -55,7 +56,11 @@ func (wc *WebhookController) Enrich(req *webhook.RequestBody) error {
if !wc.isCertTypeOK(wh) {
continue
}
resp, err := wh.Do(wc.client, req, wc.TemplateData)
whCtx, cancel := context.WithTimeout(ctx, time.Second*10)
defer cancel() //nolint:gocritic // every request canceled with its own timeout
resp, err := wh.DoWithContext(whCtx, wc.client, req, wc.TemplateData)
if err != nil {
return err
}
@ -68,7 +73,7 @@ func (wc *WebhookController) Enrich(req *webhook.RequestBody) error {
}
// Authorize checks that all remote servers allow the request
func (wc *WebhookController) Authorize(req *webhook.RequestBody) error {
func (wc *WebhookController) Authorize(ctx context.Context, req *webhook.RequestBody) error {
if wc == nil {
return nil
}
@ -87,7 +92,11 @@ func (wc *WebhookController) Authorize(req *webhook.RequestBody) error {
if !wc.isCertTypeOK(wh) {
continue
}
resp, err := wh.Do(wc.client, req, wc.TemplateData)
whCtx, cancel := context.WithTimeout(ctx, time.Second*10)
defer cancel() //nolint:gocritic // every request canceled with its own timeout
resp, err := wh.DoWithContext(whCtx, wc.client, req, wc.TemplateData)
if err != nil {
return err
}
@ -123,13 +132,6 @@ type Webhook struct {
} `json:"-"`
}
func (w *Webhook) Do(client *http.Client, reqBody *webhook.RequestBody, data any) (*webhook.ResponseBody, error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
return w.DoWithContext(ctx, client, reqBody, data)
}
func (w *Webhook) DoWithContext(ctx context.Context, client *http.Client, reqBody *webhook.RequestBody, data any) (*webhook.ResponseBody, error) {
tmpl, err := template.New("url").Funcs(templates.StepFuncMap()).Parse(w.URL)
if err != nil {
@ -169,6 +171,11 @@ retry:
return nil, err
}
requestID, ok := logging.GetRequestID(ctx)
if ok {
req.Header.Set("X-Request-ID", requestID)
}
secret, err := base64.StdEncoding.DecodeString(w.Secret)
if err != nil {
return nil, err

View File

@ -1,6 +1,7 @@
package provisioner
import (
"context"
"crypto/hmac"
"crypto/sha256"
"crypto/tls"
@ -8,15 +9,18 @@ import (
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/logging"
"github.com/smallstep/certificates/webhook"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.step.sm/crypto/pemutil"
"go.step.sm/crypto/x509util"
"go.step.sm/linkedca"
@ -92,19 +96,24 @@ func TestWebhookController_isCertTypeOK(t *testing.T) {
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
assert.Equals(t, test.want, test.wc.isCertTypeOK(test.wh))
assert.Equal(t, test.want, test.wc.isCertTypeOK(test.wh))
})
}
}
// withRequestID is a helper that calls into [logging.WithRequestID] and returns
// a new context with the requestID added to the provided context.
func withRequestID(ctx context.Context, requestID string) context.Context {
return logging.WithRequestID(ctx, requestID)
}
func TestWebhookController_Enrich(t *testing.T) {
cert, err := pemutil.ReadCertificate("testdata/certs/x5c-leaf.crt", pemutil.WithFirstBlock())
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
type test struct {
ctl *WebhookController
ctx context.Context
req *webhook.RequestBody
responses []*webhook.ResponseBody
expectErr bool
@ -129,6 +138,7 @@ func TestWebhookController_Enrich(t *testing.T) {
webhooks: []*Webhook{{Name: "people", Kind: "ENRICHING"}},
TemplateData: x509util.TemplateData{},
},
ctx: withRequestID(context.Background(), "reqID"),
req: &webhook.RequestBody{},
responses: []*webhook.ResponseBody{{Allow: true, Data: map[string]any{"role": "bar"}}},
expectErr: false,
@ -143,6 +153,7 @@ func TestWebhookController_Enrich(t *testing.T) {
},
TemplateData: x509util.TemplateData{},
},
ctx: withRequestID(context.Background(), "reqID"),
req: &webhook.RequestBody{},
responses: []*webhook.ResponseBody{
{Allow: true, Data: map[string]any{"role": "bar"}},
@ -166,6 +177,7 @@ func TestWebhookController_Enrich(t *testing.T) {
TemplateData: x509util.TemplateData{},
certType: linkedca.Webhook_X509,
},
ctx: withRequestID(context.Background(), "reqID"),
req: &webhook.RequestBody{},
responses: []*webhook.ResponseBody{
{Allow: true, Data: map[string]any{"role": "bar"}},
@ -185,14 +197,15 @@ func TestWebhookController_Enrich(t *testing.T) {
TemplateData: x509util.TemplateData{},
options: []webhook.RequestBodyOption{webhook.WithX5CCertificate(cert)},
},
ctx: withRequestID(context.Background(), "reqID"),
req: &webhook.RequestBody{},
responses: []*webhook.ResponseBody{{Allow: true, Data: map[string]any{"role": "bar"}}},
expectErr: false,
expectTemplateData: x509util.TemplateData{"Webhooks": map[string]any{"people": map[string]any{"role": "bar"}}},
assertRequest: func(t *testing.T, req *webhook.RequestBody) {
key, err := x509.MarshalPKIXPublicKey(cert.PublicKey)
assert.FatalError(t, err)
assert.Equals(t, &webhook.X5CCertificate{
require.NoError(t, err)
assert.Equal(t, &webhook.X5CCertificate{
Raw: cert.Raw,
PublicKey: key,
PublicKeyAlgorithm: cert.PublicKeyAlgorithm.String(),
@ -207,6 +220,7 @@ func TestWebhookController_Enrich(t *testing.T) {
webhooks: []*Webhook{{Name: "people", Kind: "ENRICHING"}},
TemplateData: x509util.TemplateData{},
},
ctx: withRequestID(context.Background(), "reqID"),
req: &webhook.RequestBody{},
responses: []*webhook.ResponseBody{{Allow: false}},
expectErr: true,
@ -221,6 +235,7 @@ func TestWebhookController_Enrich(t *testing.T) {
PublicKey: []byte("bad"),
})},
},
ctx: withRequestID(context.Background(), "reqID"),
req: &webhook.RequestBody{},
responses: []*webhook.ResponseBody{{Allow: false}},
expectErr: true,
@ -232,19 +247,21 @@ func TestWebhookController_Enrich(t *testing.T) {
for i, wh := range test.ctl.webhooks {
var j = i
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "reqID", r.Header.Get("X-Request-ID"))
err := json.NewEncoder(w).Encode(test.responses[j])
assert.FatalError(t, err)
require.NoError(t, err)
}))
// nolint: gocritic // defer in loop isn't a memory leak
defer ts.Close()
wh.URL = ts.URL
}
err := test.ctl.Enrich(test.req)
err := test.ctl.Enrich(test.ctx, test.req)
if (err != nil) != test.expectErr {
t.Fatalf("Got err %v, want %v", err, test.expectErr)
}
assert.Equals(t, test.expectTemplateData, test.ctl.TemplateData)
assert.Equal(t, test.expectTemplateData, test.ctl.TemplateData)
if test.assertRequest != nil {
test.assertRequest(t, test.req)
}
@ -254,12 +271,11 @@ func TestWebhookController_Enrich(t *testing.T) {
func TestWebhookController_Authorize(t *testing.T) {
cert, err := pemutil.ReadCertificate("testdata/certs/x5c-leaf.crt", pemutil.WithFirstBlock())
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
type test struct {
ctl *WebhookController
ctx context.Context
req *webhook.RequestBody
responses []*webhook.ResponseBody
expectErr bool
@ -280,6 +296,7 @@ func TestWebhookController_Authorize(t *testing.T) {
client: http.DefaultClient,
webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING"}},
},
ctx: withRequestID(context.Background(), "reqID"),
req: &webhook.RequestBody{},
responses: []*webhook.ResponseBody{{Allow: true}},
expectErr: false,
@ -290,6 +307,7 @@ func TestWebhookController_Authorize(t *testing.T) {
webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING", CertType: linkedca.Webhook_X509.String()}},
certType: linkedca.Webhook_SSH,
},
ctx: withRequestID(context.Background(), "reqID"),
req: &webhook.RequestBody{},
responses: []*webhook.ResponseBody{{Allow: false}},
expectErr: false,
@ -300,13 +318,14 @@ func TestWebhookController_Authorize(t *testing.T) {
webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING"}},
options: []webhook.RequestBodyOption{webhook.WithX5CCertificate(cert)},
},
ctx: withRequestID(context.Background(), "reqID"),
req: &webhook.RequestBody{},
responses: []*webhook.ResponseBody{{Allow: true}},
expectErr: false,
assertRequest: func(t *testing.T, req *webhook.RequestBody) {
key, err := x509.MarshalPKIXPublicKey(cert.PublicKey)
assert.FatalError(t, err)
assert.Equals(t, &webhook.X5CCertificate{
require.NoError(t, err)
assert.Equal(t, &webhook.X5CCertificate{
Raw: cert.Raw,
PublicKey: key,
PublicKeyAlgorithm: cert.PublicKeyAlgorithm.String(),
@ -320,6 +339,7 @@ func TestWebhookController_Authorize(t *testing.T) {
client: http.DefaultClient,
webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING"}},
},
ctx: withRequestID(context.Background(), "reqID"),
req: &webhook.RequestBody{},
responses: []*webhook.ResponseBody{{Allow: false}},
expectErr: true,
@ -332,6 +352,7 @@ func TestWebhookController_Authorize(t *testing.T) {
PublicKey: []byte("bad"),
})},
},
ctx: withRequestID(context.Background(), "reqID"),
req: &webhook.RequestBody{},
responses: []*webhook.ResponseBody{{Allow: false}},
expectErr: true,
@ -342,15 +363,17 @@ func TestWebhookController_Authorize(t *testing.T) {
for i, wh := range test.ctl.webhooks {
var j = i
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "reqID", r.Header.Get("X-Request-ID"))
err := json.NewEncoder(w).Encode(test.responses[j])
assert.FatalError(t, err)
require.NoError(t, err)
}))
// nolint: gocritic // defer in loop isn't a memory leak
defer ts.Close()
wh.URL = ts.URL
}
err := test.ctl.Authorize(test.req)
err := test.ctl.Authorize(test.ctx, test.req)
if (err != nil) != test.expectErr {
t.Fatalf("Got err %v, want %v", err, test.expectErr)
}
@ -366,6 +389,7 @@ func TestWebhook_Do(t *testing.T) {
type test struct {
webhook Webhook
dataArg any
requestID string
webhookResponse webhook.ResponseBody
expectPath string
errStatusCode int
@ -375,6 +399,16 @@ func TestWebhook_Do(t *testing.T) {
}
tests := map[string]test{
"ok": {
webhook: Webhook{
ID: "abc123",
Secret: "c2VjcmV0Cg==",
},
requestID: "reqID",
webhookResponse: webhook.ResponseBody{
Data: map[string]interface{}{"role": "dba"},
},
},
"ok/no-request-id": {
webhook: Webhook{
ID: "abc123",
Secret: "c2VjcmV0Cg==",
@ -389,6 +423,7 @@ func TestWebhook_Do(t *testing.T) {
Secret: "c2VjcmV0Cg==",
BearerToken: "mytoken",
},
requestID: "reqID",
webhookResponse: webhook.ResponseBody{
Data: map[string]interface{}{"role": "dba"},
},
@ -405,6 +440,7 @@ func TestWebhook_Do(t *testing.T) {
Password: "mypass",
},
},
requestID: "reqID",
webhookResponse: webhook.ResponseBody{
Data: map[string]interface{}{"role": "dba"},
},
@ -416,7 +452,8 @@ func TestWebhook_Do(t *testing.T) {
URL: "/users/{{ .username }}?region={{ .region }}",
Secret: "c2VjcmV0Cg==",
},
dataArg: map[string]interface{}{"username": "areed", "region": "central"},
requestID: "reqID",
dataArg: map[string]interface{}{"username": "areed", "region": "central"},
webhookResponse: webhook.ResponseBody{
Data: map[string]interface{}{"role": "dba"},
},
@ -451,6 +488,7 @@ func TestWebhook_Do(t *testing.T) {
ID: "abc123",
Secret: "c2VjcmV0Cg==",
},
requestID: "reqID",
webhookResponse: webhook.ResponseBody{
Allow: true,
},
@ -463,6 +501,7 @@ func TestWebhook_Do(t *testing.T) {
webhookResponse: webhook.ResponseBody{
Data: map[string]interface{}{"role": "dba"},
},
requestID: "reqID",
errStatusCode: 404,
serverErrMsg: "item not found",
expectErr: errors.New("Webhook server responded with 404"),
@ -471,17 +510,20 @@ func TestWebhook_Do(t *testing.T) {
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
id := r.Header.Get("X-Smallstep-Webhook-ID")
assert.Equals(t, tc.webhook.ID, id)
if tc.requestID != "" {
assert.Equal(t, tc.requestID, r.Header.Get("X-Request-ID"))
}
assert.Equal(t, tc.webhook.ID, r.Header.Get("X-Smallstep-Webhook-ID"))
sig, err := hex.DecodeString(r.Header.Get("X-Smallstep-Signature"))
assert.FatalError(t, err)
assert.NoError(t, err)
body, err := io.ReadAll(r.Body)
assert.FatalError(t, err)
assert.NoError(t, err)
secret, err := base64.StdEncoding.DecodeString(tc.webhook.Secret)
assert.FatalError(t, err)
assert.NoError(t, err)
h := hmac.New(sha256.New, secret)
h.Write(body)
mac := h.Sum(nil)
@ -490,19 +532,19 @@ func TestWebhook_Do(t *testing.T) {
switch {
case tc.webhook.BearerToken != "":
ah := fmt.Sprintf("Bearer %s", tc.webhook.BearerToken)
assert.Equals(t, ah, r.Header.Get("Authorization"))
assert.Equal(t, ah, r.Header.Get("Authorization"))
case tc.webhook.BasicAuth.Username != "" || tc.webhook.BasicAuth.Password != "":
whReq, err := http.NewRequest("", "", http.NoBody)
assert.FatalError(t, err)
require.NoError(t, err)
whReq.SetBasicAuth(tc.webhook.BasicAuth.Username, tc.webhook.BasicAuth.Password)
ah := whReq.Header.Get("Authorization")
assert.Equals(t, ah, whReq.Header.Get("Authorization"))
assert.Equal(t, ah, whReq.Header.Get("Authorization"))
default:
assert.Equals(t, "", r.Header.Get("Authorization"))
assert.Equal(t, "", r.Header.Get("Authorization"))
}
if tc.expectPath != "" {
assert.Equals(t, tc.expectPath, r.URL.Path+"?"+r.URL.RawQuery)
assert.Equal(t, tc.expectPath, r.URL.Path+"?"+r.URL.RawQuery)
}
if tc.errStatusCode != 0 {
@ -512,26 +554,33 @@ func TestWebhook_Do(t *testing.T) {
reqBody := new(webhook.RequestBody)
err = json.Unmarshal(body, reqBody)
assert.FatalError(t, err)
// assert.Equals(t, tc.expectToken, reqBody.Token)
require.NoError(t, err)
err = json.NewEncoder(w).Encode(tc.webhookResponse)
assert.FatalError(t, err)
require.NoError(t, err)
}))
defer ts.Close()
tc.webhook.URL = ts.URL + tc.webhook.URL
reqBody, err := webhook.NewRequestBody(webhook.WithX509CertificateRequest(csr))
assert.FatalError(t, err)
got, err := tc.webhook.Do(http.DefaultClient, reqBody, tc.dataArg)
require.NoError(t, err)
ctx := context.Background()
if tc.requestID != "" {
ctx = withRequestID(context.Background(), tc.requestID)
}
ctx, cancel := context.WithTimeout(ctx, time.Second*10)
defer cancel()
got, err := tc.webhook.DoWithContext(ctx, http.DefaultClient, reqBody, tc.dataArg)
if tc.expectErr != nil {
assert.Equals(t, tc.expectErr.Error(), err.Error())
assert.Equal(t, tc.expectErr.Error(), err.Error())
return
}
assert.FatalError(t, err)
assert.NoError(t, err)
assert.Equals(t, got, &tc.webhookResponse)
assert.Equal(t, &tc.webhookResponse, got)
})
}
@ -544,7 +593,7 @@ func TestWebhook_Do(t *testing.T) {
URL: ts.URL,
}
cert, err := tls.LoadX509KeyPair("testdata/certs/foo.crt", "testdata/secrets/foo.key")
assert.FatalError(t, err)
require.NoError(t, err)
transport := http.DefaultTransport.(*http.Transport).Clone()
transport.TLSClientConfig = &tls.Config{
InsecureSkipVerify: true,
@ -554,12 +603,19 @@ func TestWebhook_Do(t *testing.T) {
Transport: transport,
}
reqBody, err := webhook.NewRequestBody(webhook.WithX509CertificateRequest(csr))
assert.FatalError(t, err)
_, err = wh.Do(client, reqBody, nil)
assert.FatalError(t, err)
require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
_, err = wh.DoWithContext(ctx, client, reqBody, nil)
require.NoError(t, err)
ctx, cancel = context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
wh.DisableTLSClientAuth = true
_, err = wh.Do(client, reqBody, nil)
assert.Error(t, err)
_, err = wh.DoWithContext(ctx, client, reqBody, nil)
require.Error(t, err)
})
}

View File

@ -149,7 +149,7 @@ func TestAuthority_LoadProvisionerByCertificate(t *testing.T) {
opts, err := a.Authorize(ctx, token)
require.NoError(t, err)
opts = append(opts, extraOpts...)
certs, err := a.Sign(csr, provisioner.SignOptions{}, opts...)
certs, err := a.SignWithContext(ctx, csr, provisioner.SignOptions{}, opts...)
require.NoError(t, err)
return certs[0]
}

View File

@ -152,7 +152,7 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi
return cert, err
}
func (a *Authority) signSSH(_ context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, provisioner.Interface, error) {
func (a *Authority) signSSH(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, provisioner.Interface, error) {
var (
certOptions []sshutil.Option
mods []provisioner.SSHCertModifier
@ -211,7 +211,7 @@ func (a *Authority) signSSH(_ context.Context, key ssh.PublicKey, opts provision
}
// Call enriching webhooks
if err := a.callEnrichingWebhooksSSH(prov, webhookCtl, cr); err != nil {
if err := a.callEnrichingWebhooksSSH(ctx, prov, webhookCtl, cr); err != nil {
return nil, prov, errs.ApplyOptions(
errs.ForbiddenErr(err, err.Error()),
errs.WithKeyVal("signOptions", signOpts),
@ -284,7 +284,7 @@ func (a *Authority) signSSH(_ context.Context, key ssh.PublicKey, opts provision
}
// Send certificate to webhooks for authorization
if err := a.callAuthorizingWebhooksSSH(prov, webhookCtl, certificate, certTpl); err != nil {
if err := a.callAuthorizingWebhooksSSH(ctx, prov, webhookCtl, certificate, certTpl); err != nil {
return nil, prov, errs.ApplyOptions(
errs.ForbiddenErr(err, "authority.SignSSH: error signing certificate"),
)
@ -671,7 +671,7 @@ func (a *Authority) getAddUserCommand(principal string) string {
return strings.ReplaceAll(cmd, "<principal>", principal)
}
func (a *Authority) callEnrichingWebhooksSSH(prov provisioner.Interface, webhookCtl webhookController, cr sshutil.CertificateRequest) (err error) {
func (a *Authority) callEnrichingWebhooksSSH(ctx context.Context, prov provisioner.Interface, webhookCtl webhookController, cr sshutil.CertificateRequest) (err error) {
if webhookCtl == nil {
return
}
@ -680,7 +680,7 @@ func (a *Authority) callEnrichingWebhooksSSH(prov provisioner.Interface, webhook
if whEnrichReq, err = webhook.NewRequestBody(
webhook.WithSSHCertificateRequest(cr),
); err == nil {
err = webhookCtl.Enrich(whEnrichReq)
err = webhookCtl.Enrich(ctx, whEnrichReq)
a.meter.SSHWebhookEnriched(prov, err)
}
@ -688,7 +688,7 @@ func (a *Authority) callEnrichingWebhooksSSH(prov provisioner.Interface, webhook
return
}
func (a *Authority) callAuthorizingWebhooksSSH(prov provisioner.Interface, webhookCtl webhookController, cert *sshutil.Certificate, certTpl *ssh.Certificate) (err error) {
func (a *Authority) callAuthorizingWebhooksSSH(ctx context.Context, prov provisioner.Interface, webhookCtl webhookController, cert *sshutil.Certificate, certTpl *ssh.Certificate) (err error) {
if webhookCtl == nil {
return
}
@ -697,7 +697,7 @@ func (a *Authority) callAuthorizingWebhooksSSH(prov provisioner.Interface, webho
if whAuthBody, err = webhook.NewRequestBody(
webhook.WithSSHCertificate(cert, certTpl),
); err == nil {
err = webhookCtl.Authorize(whAuthBody)
err = webhookCtl.Authorize(ctx, whAuthBody)
a.meter.SSHWebhookAuthorized(prov, err)
}

View File

@ -91,14 +91,23 @@ func withDefaultASN1DN(def *config.ASN1DN) provisioner.CertificateModifierFunc {
}
}
// Sign creates a signed certificate from a certificate signing request.
// Sign creates a signed certificate from a certificate signing request. It
// creates a new context.Context, and calls into SignWithContext.
//
// Deprecated: Use authority.SignWithContext with an actual context.Context.
func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
chain, prov, err := a.signX509(csr, signOpts, extraOpts...)
return a.SignWithContext(context.Background(), csr, signOpts, extraOpts...)
}
// SignWithContext creates a signed certificate from a certificate signing
// request, taking the provided context.Context.
func (a *Authority) SignWithContext(ctx context.Context, csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
chain, prov, err := a.signX509(ctx, csr, signOpts, extraOpts...)
a.meter.X509Signed(prov, err)
return chain, err
}
func (a *Authority) signX509(csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, provisioner.Interface, error) {
func (a *Authority) signX509(ctx context.Context, csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, provisioner.Interface, error) {
var (
certOptions []x509util.Option
certValidators []provisioner.CertificateValidator
@ -171,7 +180,7 @@ func (a *Authority) signX509(csr *x509.CertificateRequest, signOpts provisioner.
}
}
if err := a.callEnrichingWebhooksX509(prov, webhookCtl, attData, csr); err != nil {
if err := a.callEnrichingWebhooksX509(ctx, prov, webhookCtl, attData, csr); err != nil {
return nil, prov, errs.ApplyOptions(
errs.ForbiddenErr(err, err.Error()),
errs.WithKeyVal("csr", csr),
@ -265,7 +274,7 @@ func (a *Authority) signX509(csr *x509.CertificateRequest, signOpts provisioner.
}
// Send certificate to webhooks for authorization
if err := a.callAuthorizingWebhooksX509(prov, webhookCtl, crt, leaf, attData); err != nil {
if err := a.callAuthorizingWebhooksX509(ctx, prov, webhookCtl, crt, leaf, attData); err != nil {
return nil, prov, errs.ApplyOptions(
errs.ForbiddenErr(err, "error creating certificate"),
opts...,
@ -986,7 +995,7 @@ func templatingError(err error) error {
return errors.Wrap(cause, "error applying certificate template")
}
func (a *Authority) callEnrichingWebhooksX509(prov provisioner.Interface, webhookCtl webhookController, attData *provisioner.AttestationData, csr *x509.CertificateRequest) (err error) {
func (a *Authority) callEnrichingWebhooksX509(ctx context.Context, prov provisioner.Interface, webhookCtl webhookController, attData *provisioner.AttestationData, csr *x509.CertificateRequest) (err error) {
if webhookCtl == nil {
return
}
@ -1003,7 +1012,7 @@ func (a *Authority) callEnrichingWebhooksX509(prov provisioner.Interface, webhoo
webhook.WithX509CertificateRequest(csr),
webhook.WithAttestationData(attested),
); err == nil {
err = webhookCtl.Enrich(whEnrichReq)
err = webhookCtl.Enrich(ctx, whEnrichReq)
a.meter.X509WebhookEnriched(prov, err)
}
@ -1011,7 +1020,7 @@ func (a *Authority) callEnrichingWebhooksX509(prov provisioner.Interface, webhoo
return
}
func (a *Authority) callAuthorizingWebhooksX509(prov provisioner.Interface, webhookCtl webhookController, cert *x509util.Certificate, leaf *x509.Certificate, attData *provisioner.AttestationData) (err error) {
func (a *Authority) callAuthorizingWebhooksX509(ctx context.Context, prov provisioner.Interface, webhookCtl webhookController, cert *x509util.Certificate, leaf *x509.Certificate, attData *provisioner.AttestationData) (err error) {
if webhookCtl == nil {
return
}
@ -1028,7 +1037,7 @@ func (a *Authority) callAuthorizingWebhooksX509(prov provisioner.Interface, webh
webhook.WithX509Certificate(cert, leaf),
webhook.WithAttestationData(attested),
); err == nil {
err = webhookCtl.Authorize(whAuthBody)
err = webhookCtl.Authorize(ctx, whAuthBody)
a.meter.X509WebhookAuthorized(prov, err)
}

View File

@ -239,7 +239,7 @@ func (e *testEnforcer) Enforce(cert *x509.Certificate) error {
return nil
}
func TestAuthority_Sign(t *testing.T) {
func TestAuthority_SignWithContext(t *testing.T) {
pub, priv, err := keyutil.GenerateDefaultKeyPair()
require.NoError(t, err)
@ -848,7 +848,7 @@ ZYtQ9Ot36qc=
t.Run(name, func(t *testing.T) {
tc := genTestCase(t)
certChain, err := tc.auth.Sign(tc.csr, tc.signOpts, tc.extraOpts...)
certChain, err := tc.auth.SignWithContext(context.Background(), tc.csr, tc.signOpts, tc.extraOpts...)
if err != nil {
if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) {
assert.Nil(t, certChain)
@ -1797,9 +1797,9 @@ func TestAuthority_constraints(t *testing.T) {
t.Fatal(err)
}
_, err = auth.Sign(csr, provisioner.SignOptions{}, templateOption)
_, err = auth.SignWithContext(context.Background(), csr, provisioner.SignOptions{}, templateOption)
if (err != nil) != tt.wantErr {
t.Errorf("Authority.Sign() error = %v, wantErr %v", err, tt.wantErr)
t.Errorf("Authority.SignWithContext() error = %v, wantErr %v", err, tt.wantErr)
}
_, err = auth.Renew(cert)

View File

@ -1,8 +1,12 @@
package authority
import "github.com/smallstep/certificates/webhook"
import (
"context"
"github.com/smallstep/certificates/webhook"
)
type webhookController interface {
Enrich(*webhook.RequestBody) error
Authorize(*webhook.RequestBody) error
Enrich(context.Context, *webhook.RequestBody) error
Authorize(context.Context, *webhook.RequestBody) error
}

View File

@ -1,6 +1,8 @@
package authority
import (
"context"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/webhook"
)
@ -14,7 +16,7 @@ type mockWebhookController struct {
var _ webhookController = &mockWebhookController{}
func (wc *mockWebhookController) Enrich(*webhook.RequestBody) error {
func (wc *mockWebhookController) Enrich(context.Context, *webhook.RequestBody) error {
for key, data := range wc.respData {
wc.templateData.SetWebhook(key, data)
}
@ -22,6 +24,6 @@ func (wc *mockWebhookController) Enrich(*webhook.RequestBody) error {
return wc.enrichErr
}
func (wc *mockWebhookController) Authorize(*webhook.RequestBody) error {
func (wc *mockWebhookController) Authorize(context.Context, *webhook.RequestBody) error {
return wc.authorizeErr
}

View File

@ -60,7 +60,7 @@ func MustFromContext(ctx context.Context) *Authority {
// SignAuthority is the interface for a signing authority
type SignAuthority interface {
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
SignWithContext(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
LoadProvisionerByName(string) (provisioner.Interface, error)
}
@ -306,7 +306,7 @@ func (a *Authority) SignCSR(ctx context.Context, csr *x509.CertificateRequest, m
}
signOps = append(signOps, templateOptions)
certChain, err := a.signAuth.Sign(csr, opts, signOps...)
certChain, err := a.signAuth.SignWithContext(ctx, csr, opts, signOps...)
if err != nil {
return nil, fmt.Errorf("error generating certificate for order: %w", err)
}