Add SignWithContext method to authority and mocks

This commit is contained in:
Herman Slatman 2023-09-19 16:17:36 +02:00
parent b2301ea127
commit 4e06bdbc51
No known key found for this signature in database
GPG Key ID: F4D8A44EA0A75A4F
11 changed files with 69 additions and 31 deletions

View File

@ -285,6 +285,10 @@ func (m *mockCA) Sign(*x509.CertificateRequest, provisioner.SignOptions, ...prov
return nil, nil return nil, nil
} }
func (m *mockCA) SignWithContext(context.Context, *x509.CertificateRequest, provisioner.SignOptions, ...provisioner.SignOption) ([]*x509.Certificate, error) {
return nil, nil
}
func (m *mockCA) AreSANsAllowed(ctx context.Context, sans []string) error { func (m *mockCA) AreSANsAllowed(ctx context.Context, sans []string) error {
if m.MockAreSANsallowed != nil { if m.MockAreSANsallowed != nil {
return m.MockAreSANsallowed(ctx, sans) return m.MockAreSANsallowed(ctx, sans)

View File

@ -22,6 +22,7 @@ var clock Clock
// CertificateAuthority is the interface implemented by a CA authority. // CertificateAuthority is the interface implemented by a CA authority.
type CertificateAuthority interface { type CertificateAuthority interface {
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
SignWithContext(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
AreSANsAllowed(ctx context.Context, sans []string) error AreSANsAllowed(ctx context.Context, sans []string) error
IsRevoked(sn string) (bool, error) IsRevoked(sn string) (bool, error)
Revoke(context.Context, *authority.RevokeOptions) error Revoke(context.Context, *authority.RevokeOptions) error

View File

@ -272,6 +272,7 @@ func TestOrder_UpdateStatus(t *testing.T) {
type mockSignAuth struct { type mockSignAuth struct {
sign func(csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) sign func(csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
signWithContext func(ctx context.Context, csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
areSANsAllowed func(ctx context.Context, sans []string) error areSANsAllowed func(ctx context.Context, sans []string) error
loadProvisionerByName func(string) (provisioner.Interface, error) loadProvisionerByName func(string) (provisioner.Interface, error)
ret1, ret2 interface{} ret1, ret2 interface{}
@ -287,6 +288,15 @@ func (m *mockSignAuth) Sign(csr *x509.CertificateRequest, signOpts provisioner.S
return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err
} }
func (m *mockSignAuth) SignWithContext(ctx context.Context, csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
if m.signWithContext != nil {
return m.signWithContext(ctx, csr, signOpts, extraOpts...)
} else if m.err != nil {
return nil, m.err
}
return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err
}
func (m *mockSignAuth) AreSANsAllowed(ctx context.Context, sans []string) error { func (m *mockSignAuth) AreSANsAllowed(ctx context.Context, sans []string) error {
if m.areSANsAllowed != nil { if m.areSANsAllowed != nil {
return m.areSANsAllowed(ctx, sans) return m.areSANsAllowed(ctx, sans)

View File

@ -42,6 +42,7 @@ type Authority interface {
GetTLSOptions() *config.TLSOptions GetTLSOptions() *config.TLSOptions
Root(shasum string) (*x509.Certificate, error) Root(shasum string) (*x509.Certificate, error)
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
SignWithContext(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
Renew(peer *x509.Certificate) ([]*x509.Certificate, error) Renew(peer *x509.Certificate) ([]*x509.Certificate, error)
RenewContext(ctx context.Context, peer *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) RenewContext(ctx context.Context, peer *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
Rekey(peer *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) Rekey(peer *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)

View File

@ -193,6 +193,7 @@ type mockAuthority struct {
getTLSOptions func() *authority.TLSOptions getTLSOptions func() *authority.TLSOptions
root func(shasum string) (*x509.Certificate, error) root func(shasum string) (*x509.Certificate, error)
sign func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) sign func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
signWithContext func(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
renew func(cert *x509.Certificate) ([]*x509.Certificate, error) renew func(cert *x509.Certificate) ([]*x509.Certificate, error)
rekey func(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) rekey func(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
renewContext func(ctx context.Context, oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) renewContext func(ctx context.Context, oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
@ -261,6 +262,13 @@ func (m *mockAuthority) Sign(cr *x509.CertificateRequest, opts provisioner.SignO
return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err
} }
func (m *mockAuthority) SignWithContext(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
if m.signWithContext != nil {
return m.signWithContext(ctx, cr, opts, signOpts...)
}
return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err
}
func (m *mockAuthority) Renew(cert *x509.Certificate) ([]*x509.Certificate, error) { func (m *mockAuthority) Renew(cert *x509.Certificate) ([]*x509.Certificate, error) {
if m.renew != nil { if m.renew != nil {
return m.renew(cert) return m.renew(cert)

View File

@ -37,7 +37,7 @@ type WebhookController struct {
// Enrich fetches data from remote servers and adds returned data to the // Enrich fetches data from remote servers and adds returned data to the
// templateData // templateData
func (wc *WebhookController) Enrich(req *webhook.RequestBody) error { func (wc *WebhookController) Enrich(ctx context.Context, req *webhook.RequestBody) error {
if wc == nil { if wc == nil {
return nil return nil
} }
@ -56,11 +56,11 @@ func (wc *WebhookController) Enrich(req *webhook.RequestBody) error {
if !wc.isCertTypeOK(wh) { if !wc.isCertTypeOK(wh) {
continue continue
} }
// TODO(hs): propagate context from above
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
resp, err := wh.DoWithContext(ctx, wc.client, req, wc.TemplateData) whCtx, cancel := context.WithTimeout(ctx, time.Second*10)
defer cancel() //nolint:gocritic // every request canceled with its own timeout
resp, err := wh.DoWithContext(whCtx, wc.client, req, wc.TemplateData)
if err != nil { if err != nil {
return err return err
} }
@ -73,7 +73,7 @@ func (wc *WebhookController) Enrich(req *webhook.RequestBody) error {
} }
// Authorize checks that all remote servers allow the request // Authorize checks that all remote servers allow the request
func (wc *WebhookController) Authorize(req *webhook.RequestBody) error { func (wc *WebhookController) Authorize(ctx context.Context, req *webhook.RequestBody) error {
if wc == nil { if wc == nil {
return nil return nil
} }
@ -93,11 +93,10 @@ func (wc *WebhookController) Authorize(req *webhook.RequestBody) error {
continue continue
} }
// TODO(hs): propagate context from above whCtx, cancel := context.WithTimeout(ctx, time.Second*10)
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() //nolint:gocritic // every request canceled with its own timeout
defer cancel()
resp, err := wh.DoWithContext(ctx, wc.client, req, wc.TemplateData) resp, err := wh.DoWithContext(whCtx, wc.client, req, wc.TemplateData)
if err != nil { if err != nil {
return err return err
} }

View File

@ -242,7 +242,7 @@ func TestWebhookController_Enrich(t *testing.T) {
wh.URL = ts.URL wh.URL = ts.URL
} }
err := test.ctl.Enrich(test.req) err := test.ctl.Enrich(context.Background(), test.req)
if (err != nil) != test.expectErr { if (err != nil) != test.expectErr {
t.Fatalf("Got err %v, want %v", err, test.expectErr) t.Fatalf("Got err %v, want %v", err, test.expectErr)
} }
@ -352,7 +352,7 @@ func TestWebhookController_Authorize(t *testing.T) {
wh.URL = ts.URL wh.URL = ts.URL
} }
err := test.ctl.Authorize(test.req) err := test.ctl.Authorize(context.Background(), test.req)
if (err != nil) != test.expectErr { if (err != nil) != test.expectErr {
t.Fatalf("Got err %v, want %v", err, test.expectErr) t.Fatalf("Got err %v, want %v", err, test.expectErr)
} }

View File

@ -146,7 +146,7 @@ func (a *Authority) GetSSHBastion(ctx context.Context, user, hostname string) (*
} }
// SignSSH creates a signed SSH certificate with the given public key and options. // SignSSH creates a signed SSH certificate with the given public key and options.
func (a *Authority) SignSSH(_ context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
var ( var (
certOptions []sshutil.Option certOptions []sshutil.Option
mods []provisioner.SSHCertModifier mods []provisioner.SSHCertModifier
@ -205,7 +205,7 @@ func (a *Authority) SignSSH(_ context.Context, key ssh.PublicKey, opts provision
} }
// Call enriching webhooks // Call enriching webhooks
if err := callEnrichingWebhooksSSH(webhookCtl, cr); err != nil { if err := callEnrichingWebhooksSSH(ctx, webhookCtl, cr); err != nil {
return nil, errs.ApplyOptions( return nil, errs.ApplyOptions(
errs.ForbiddenErr(err, err.Error()), errs.ForbiddenErr(err, err.Error()),
errs.WithKeyVal("signOptions", signOpts), errs.WithKeyVal("signOptions", signOpts),
@ -277,7 +277,7 @@ func (a *Authority) SignSSH(_ context.Context, key ssh.PublicKey, opts provision
} }
// Send certificate to webhooks for authorization // Send certificate to webhooks for authorization
if err := callAuthorizingWebhooksSSH(webhookCtl, certificate, certTpl); err != nil { if err := callAuthorizingWebhooksSSH(ctx, webhookCtl, certificate, certTpl); err != nil {
return nil, errs.ApplyOptions( return nil, errs.ApplyOptions(
errs.ForbiddenErr(err, "authority.SignSSH: error signing certificate"), errs.ForbiddenErr(err, "authority.SignSSH: error signing certificate"),
) )
@ -653,7 +653,7 @@ func (a *Authority) getAddUserCommand(principal string) string {
return strings.ReplaceAll(cmd, "<principal>", principal) return strings.ReplaceAll(cmd, "<principal>", principal)
} }
func callEnrichingWebhooksSSH(webhookCtl webhookController, cr sshutil.CertificateRequest) error { func callEnrichingWebhooksSSH(ctx context.Context, webhookCtl webhookController, cr sshutil.CertificateRequest) error {
if webhookCtl == nil { if webhookCtl == nil {
return nil return nil
} }
@ -663,10 +663,10 @@ func callEnrichingWebhooksSSH(webhookCtl webhookController, cr sshutil.Certifica
if err != nil { if err != nil {
return err return err
} }
return webhookCtl.Enrich(whEnrichReq) return webhookCtl.Enrich(ctx, whEnrichReq)
} }
func callAuthorizingWebhooksSSH(webhookCtl webhookController, cert *sshutil.Certificate, certTpl *ssh.Certificate) error { func callAuthorizingWebhooksSSH(ctx context.Context, webhookCtl webhookController, cert *sshutil.Certificate, certTpl *ssh.Certificate) error {
if webhookCtl == nil { if webhookCtl == nil {
return nil return nil
} }
@ -676,5 +676,5 @@ func callAuthorizingWebhooksSSH(webhookCtl webhookController, cert *sshutil.Cert
if err != nil { if err != nil {
return err return err
} }
return webhookCtl.Authorize(whAuthBody) return webhookCtl.Authorize(ctx, whAuthBody)
} }

View File

@ -91,8 +91,17 @@ func withDefaultASN1DN(def *config.ASN1DN) provisioner.CertificateModifierFunc {
} }
} }
// Sign creates a signed certificate from a certificate signing request. // Sign creates a signed certificate from a certificate signing request. It
// creates a new context.Context, and calls into SignWithContext.
//
// Deprecated: Use authority.SignWithContext with an actual context.Context.
func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
return a.SignWithContext(context.Background(), csr, signOpts, extraOpts...)
}
// SignWithContext creates a signed certificate from a certificate signing request,
// taking the provided context.Context.
func (a *Authority) SignWithContext(ctx context.Context, csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
var ( var (
certOptions []x509util.Option certOptions []x509util.Option
certValidators []provisioner.CertificateValidator certValidators []provisioner.CertificateValidator
@ -163,7 +172,7 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign
} }
} }
if err := callEnrichingWebhooksX509(webhookCtl, attData, csr); err != nil { if err := callEnrichingWebhooksX509(ctx, webhookCtl, attData, csr); err != nil {
return nil, errs.ApplyOptions( return nil, errs.ApplyOptions(
errs.ForbiddenErr(err, err.Error()), errs.ForbiddenErr(err, err.Error()),
errs.WithKeyVal("csr", csr), errs.WithKeyVal("csr", csr),
@ -256,7 +265,7 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign
} }
// Send certificate to webhooks for authorization // Send certificate to webhooks for authorization
if err := callAuthorizingWebhooksX509(webhookCtl, cert, leaf, attData); err != nil { if err := callAuthorizingWebhooksX509(ctx, webhookCtl, cert, leaf, attData); err != nil {
return nil, errs.ApplyOptions( return nil, errs.ApplyOptions(
errs.ForbiddenErr(err, "error creating certificate"), errs.ForbiddenErr(err, "error creating certificate"),
opts..., opts...,
@ -952,7 +961,7 @@ func templatingError(err error) error {
return errors.Wrap(cause, "error applying certificate template") return errors.Wrap(cause, "error applying certificate template")
} }
func callEnrichingWebhooksX509(webhookCtl webhookController, attData *provisioner.AttestationData, csr *x509.CertificateRequest) error { func callEnrichingWebhooksX509(ctx context.Context, webhookCtl webhookController, attData *provisioner.AttestationData, csr *x509.CertificateRequest) error {
if webhookCtl == nil { if webhookCtl == nil {
return nil return nil
} }
@ -969,10 +978,10 @@ func callEnrichingWebhooksX509(webhookCtl webhookController, attData *provisione
if err != nil { if err != nil {
return err return err
} }
return webhookCtl.Enrich(whEnrichReq) return webhookCtl.Enrich(ctx, whEnrichReq)
} }
func callAuthorizingWebhooksX509(webhookCtl webhookController, cert *x509util.Certificate, leaf *x509.Certificate, attData *provisioner.AttestationData) error { func callAuthorizingWebhooksX509(ctx context.Context, webhookCtl webhookController, cert *x509util.Certificate, leaf *x509.Certificate, attData *provisioner.AttestationData) error {
if webhookCtl == nil { if webhookCtl == nil {
return nil return nil
} }
@ -989,5 +998,5 @@ func callAuthorizingWebhooksX509(webhookCtl webhookController, cert *x509util.Ce
if err != nil { if err != nil {
return err return err
} }
return webhookCtl.Authorize(whAuthBody) return webhookCtl.Authorize(ctx, whAuthBody)
} }

View File

@ -1,8 +1,12 @@
package authority package authority
import "github.com/smallstep/certificates/webhook" import (
"context"
"github.com/smallstep/certificates/webhook"
)
type webhookController interface { type webhookController interface {
Enrich(*webhook.RequestBody) error Enrich(context.Context, *webhook.RequestBody) error
Authorize(*webhook.RequestBody) error Authorize(context.Context, *webhook.RequestBody) error
} }

View File

@ -1,6 +1,8 @@
package authority package authority
import ( import (
"context"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/webhook" "github.com/smallstep/certificates/webhook"
) )
@ -14,7 +16,7 @@ type mockWebhookController struct {
var _ webhookController = &mockWebhookController{} var _ webhookController = &mockWebhookController{}
func (wc *mockWebhookController) Enrich(*webhook.RequestBody) error { func (wc *mockWebhookController) Enrich(context.Context, *webhook.RequestBody) error {
for key, data := range wc.respData { for key, data := range wc.respData {
wc.templateData.SetWebhook(key, data) wc.templateData.SetWebhook(key, data)
} }
@ -22,6 +24,6 @@ func (wc *mockWebhookController) Enrich(*webhook.RequestBody) error {
return wc.enrichErr return wc.enrichErr
} }
func (wc *mockWebhookController) Authorize(*webhook.RequestBody) error { func (wc *mockWebhookController) Authorize(context.Context, *webhook.RequestBody) error {
return wc.authorizeErr return wc.authorizeErr
} }