diff --git a/authority/authority.go b/authority/authority.go index f396c588..cc26635e 100644 --- a/authority/authority.go +++ b/authority/authority.go @@ -70,10 +70,12 @@ type Authority struct { startTime time.Time // Custom functions - sshBastionFunc func(ctx context.Context, user, hostname string) (*config.Bastion, error) - sshCheckHostFunc func(ctx context.Context, principal string, tok string, roots []*x509.Certificate) (bool, error) - sshGetHostsFunc func(ctx context.Context, cert *x509.Certificate) ([]config.Host, error) - getIdentityFunc provisioner.GetIdentityFunc + sshBastionFunc func(ctx context.Context, user, hostname string) (*config.Bastion, error) + sshCheckHostFunc func(ctx context.Context, principal string, tok string, roots []*x509.Certificate) (bool, error) + sshGetHostsFunc func(ctx context.Context, cert *x509.Certificate) ([]config.Host, error) + getIdentityFunc provisioner.GetIdentityFunc + authorizeRenewFunc provisioner.AuthorizeRenewFunc + authorizeSSHRenewFunc provisioner.AuthorizeSSHRenewFunc adminMutex sync.RWMutex } diff --git a/authority/authorize_test.go b/authority/authorize_test.go index 74f313e7..b0ab04ec 100644 --- a/authority/authorize_test.go +++ b/authority/authorize_test.go @@ -1011,6 +1011,23 @@ func TestAuthority_authorizeSSHSign(t *testing.T) { } func TestAuthority_authorizeSSHRenew(t *testing.T) { + now := time.Now().UTC() + sshpop := func(a *Authority) (*ssh.Certificate, string) { + p, ok := a.provisioners.Load("sshpop/sshpop") + assert.Fatal(t, ok, "sshpop provisioner not found in test authority") + key, err := pemutil.Read("./testdata/secrets/ssh_host_ca_key") + assert.FatalError(t, err) + signer, ok := key.(crypto.Signer) + assert.Fatal(t, ok, "could not cast ssh signing key to crypto signer") + sshSigner, err := ssh.NewSignerFromSigner(signer) + assert.FatalError(t, err) + cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.HostCert}, sshSigner) + assert.FatalError(t, err) + token, err := generateToken("foo", p.GetName(), testAudiences.SSHRenew[0]+"#sshpop/sshpop", []string{"foo.smallstep.com"}, now, jwk, withSSHPOPFile(cert)) + assert.FatalError(t, err) + return cert, token + } + a := testAuthority(t) jwk, err := jose.ReadKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass"))) @@ -1020,8 +1037,6 @@ func TestAuthority_authorizeSSHRenew(t *testing.T) { (&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", jwk.KeyID)) assert.FatalError(t, err) - now := time.Now().UTC() - validIssuer := "step-cli" type authorizeTest struct { @@ -1058,27 +1073,34 @@ func TestAuthority_authorizeSSHRenew(t *testing.T) { code: http.StatusUnauthorized, } }, + "fail/WithAuthorizeSSHRenewFunc": func(t *testing.T) *authorizeTest { + aa := testAuthority(t, WithAuthorizeSSHRenewFunc(func(ctx context.Context, p *provisioner.Controller, cert *ssh.Certificate) error { + return errs.Forbidden("forbidden") + })) + _, token := sshpop(aa) + return &authorizeTest{ + auth: aa, + token: token, + err: errors.New("authority.authorizeSSHRenew: forbidden"), + code: http.StatusForbidden, + } + }, "ok": func(t *testing.T) *authorizeTest { - key, err := pemutil.Read("./testdata/secrets/ssh_host_ca_key") - assert.FatalError(t, err) - signer, ok := key.(crypto.Signer) - assert.Fatal(t, ok, "could not cast ssh signing key to crypto signer") - sshSigner, err := ssh.NewSignerFromSigner(signer) - assert.FatalError(t, err) - - cert, _jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.HostCert}, sshSigner) - assert.FatalError(t, err) - - p, ok := a.provisioners.Load("sshpop/sshpop") - assert.Fatal(t, ok, "sshpop provisioner not found in test authority") - - tok, err := generateToken("foo", p.GetName(), testAudiences.SSHRenew[0]+"#sshpop/sshpop", - []string{"foo.smallstep.com"}, now, _jwk, withSSHPOPFile(cert)) - assert.FatalError(t, err) - + cert, token := sshpop(a) return &authorizeTest{ auth: a, - token: tok, + token: token, + cert: cert, + } + }, + "ok/WithAuthorizeSSHRenewFunc": func(t *testing.T) *authorizeTest { + aa := testAuthority(t, WithAuthorizeSSHRenewFunc(func(ctx context.Context, p *provisioner.Controller, cert *ssh.Certificate) error { + return nil + })) + cert, token := sshpop(aa) + return &authorizeTest{ + auth: aa, + token: token, cert: cert, } }, diff --git a/authority/options.go b/authority/options.go index f92db99b..a1238b1d 100644 --- a/authority/options.go +++ b/authority/options.go @@ -92,6 +92,24 @@ func WithGetIdentityFunc(fn func(ctx context.Context, p provisioner.Interface, e } } +// WithAuthorizeRenewFunc sets a custom function that authorizes the renewal of +// an X.509 certificate. +func WithAuthorizeRenewFunc(fn func(ctx context.Context, p *provisioner.Controller, cert *x509.Certificate) error) Option { + return func(a *Authority) error { + a.authorizeRenewFunc = fn + return nil + } +} + +// WithAuthorizeSSHRenewFunc sets a custom function that authorizes the renewal +// of a SSH certificate. +func WithAuthorizeSSHRenewFunc(fn func(ctx context.Context, p *provisioner.Controller, cert *ssh.Certificate) error) Option { + return func(a *Authority) error { + a.authorizeSSHRenewFunc = fn + return nil + } +} + // WithSSHBastionFunc sets a custom function to get the bastion for a // given user-host pair. func WithSSHBastionFunc(fn func(ctx context.Context, user, host string) (*config.Bastion, error)) Option { diff --git a/authority/provisioners.go b/authority/provisioners.go index 8dc27c6a..780d12c0 100644 --- a/authority/provisioners.go +++ b/authority/provisioners.go @@ -108,7 +108,9 @@ func (a *Authority) generateProvisionerConfig(ctx context.Context) (provisioner. UserKeys: sshKeys.UserKeys, HostKeys: sshKeys.HostKeys, }, - GetIdentityFunc: a.getIdentityFunc, + GetIdentityFunc: a.getIdentityFunc, + AuthorizeRenewFunc: a.authorizeRenewFunc, + AuthorizeSSHRenewFunc: a.authorizeSSHRenewFunc, }, nil } diff --git a/authority/tls_test.go b/authority/tls_test.go index 07538701..6ccf02ca 100644 --- a/authority/tls_test.go +++ b/authority/tls_test.go @@ -802,6 +802,19 @@ func TestAuthority_Renew(t *testing.T) { code: http.StatusUnauthorized, }, nil }, + "fail/WithAuthorizeRenewFunc": func() (*renewTest, error) { + aa := testAuthority(t, WithAuthorizeRenewFunc(func(ctx context.Context, p *provisioner.Controller, cert *x509.Certificate) error { + return errs.Unauthorized("not authorized") + })) + aa.x509CAService = a.x509CAService + aa.config.AuthorityConfig.Template = a.config.AuthorityConfig.Template + return &renewTest{ + auth: aa, + cert: cert, + err: errors.New("authority.Rekey: authority.authorizeRenew: not authorized"), + code: http.StatusUnauthorized, + }, nil + }, "ok": func() (*renewTest, error) { return &renewTest{ auth: a, @@ -820,6 +833,17 @@ func TestAuthority_Renew(t *testing.T) { cert: cert, }, nil }, + "ok/WithAuthorizeRenewFunc": func() (*renewTest, error) { + aa := testAuthority(t, WithAuthorizeRenewFunc(func(ctx context.Context, p *provisioner.Controller, cert *x509.Certificate) error { + return nil + })) + aa.x509CAService = a.x509CAService + aa.config.AuthorityConfig.Template = a.config.AuthorityConfig.Template + return &renewTest{ + auth: aa, + cert: cert, + }, nil + }, } for name, genTestCase := range tests {