diff --git a/api/api_test.go b/api/api_test.go index 6232dde5..1938e300 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -417,17 +417,22 @@ func TestSignRequest_Validate(t *testing.T) { } type mockProvisioner struct { - ret1, ret2, ret3 interface{} - err error - getID func() string - getTokenID func(string) (string, error) - getName func() string - getType func() provisioner.Type - getEncryptedKey func() (string, string, bool) - init func(provisioner.Config) error - authorizeRevoke func(ott string) error - authorizeSign func(ctx context.Context, ott string) ([]provisioner.SignOption, error) - authorizeRenewal func(*x509.Certificate) error + ret1, ret2, ret3 interface{} + err error + getID func() string + getTokenID func(string) (string, error) + getName func() string + getType func() provisioner.Type + getEncryptedKey func() (string, string, bool) + init func(provisioner.Config) error + authorizeRenew func(ctx context.Context, cert *x509.Certificate) error + authorizeRevoke func(ctx context.Context, token string) error + authorizeSign func(ctx context.Context, ott string) ([]provisioner.SignOption, error) + authorizeRenewal func(*x509.Certificate) error + authorizeSSHSign func(ctx context.Context, token string) ([]provisioner.SignOption, error) + authorizeSSHRevoke func(ctx context.Context, token string) error + authorizeSSHRenew func(ctx context.Context, token string) (*ssh.Certificate, error) + authorizeSSHRekey func(ctx context.Context, token string) (*ssh.Certificate, []provisioner.SignOption, error) } func (m *mockProvisioner) GetID() string { @@ -475,9 +480,16 @@ func (m *mockProvisioner) Init(c provisioner.Config) error { return m.err } -func (m *mockProvisioner) AuthorizeRevoke(ott string) error { +func (m *mockProvisioner) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { + if m.authorizeRenew != nil { + return m.authorizeRenew(ctx, cert) + } + return m.err +} + +func (m *mockProvisioner) AuthorizeRevoke(ctx context.Context, token string) error { if m.authorizeRevoke != nil { - return m.authorizeRevoke(ott) + return m.authorizeRevoke(ctx, token) } return m.err } @@ -496,6 +508,31 @@ func (m *mockProvisioner) AuthorizeRenewal(c *x509.Certificate) error { return m.err } +func (m *mockProvisioner) AuthorizeSSHSign(ctx context.Context, token string) ([]provisioner.SignOption, error) { + if m.authorizeSSHSign != nil { + return m.authorizeSSHSign(ctx, token) + } + return m.ret1.([]provisioner.SignOption), m.err +} +func (m *mockProvisioner) AuthorizeSSHRevoke(ctx context.Context, token string) error { + if m.authorizeSSHRevoke != nil { + return m.authorizeSSHRevoke(ctx, token) + } + return m.err +} +func (m *mockProvisioner) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Certificate, error) { + if m.authorizeSSHRenew != nil { + return m.authorizeSSHRenew(ctx, token) + } + return m.ret1.(*ssh.Certificate), m.err +} +func (m *mockProvisioner) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []provisioner.SignOption, error) { + if m.authorizeSSHRekey != nil { + return m.authorizeSSHRekey(ctx, token) + } + return m.ret1.(*ssh.Certificate), m.ret2.([]provisioner.SignOption), m.err +} + type mockAuthority struct { ret1, ret2 interface{} err error @@ -509,10 +546,13 @@ type mockAuthority struct { loadProvisionerByCertificate func(cert *x509.Certificate) (provisioner.Interface, error) loadProvisionerByID func(provID string) (provisioner.Interface, error) getProvisioners func(nextCursor string, limit int) (provisioner.List, string, error) - revoke func(*authority.RevokeOptions) error + revoke func(context.Context, *authority.RevokeOptions) error getEncryptedKey func(kid string) (string, error) getRoots func() ([]*x509.Certificate, error) getFederation func() ([]*x509.Certificate, error) + renewSSH func(cert *ssh.Certificate) (*ssh.Certificate, error) + rekeySSH func(cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) + getSSHHosts func() ([]string, error) getSSHRoots func() (*authority.SSHKeys, error) getSSHFederation func() (*authority.SSHKeys, error) getSSHConfig func(typ string, data map[string]string) ([]templates.Output, error) @@ -594,9 +634,9 @@ func (m *mockAuthority) LoadProvisionerByID(provID string) (provisioner.Interfac return m.ret1.(provisioner.Interface), m.err } -func (m *mockAuthority) Revoke(opts *authority.RevokeOptions) error { +func (m *mockAuthority) Revoke(ctx context.Context, opts *authority.RevokeOptions) error { if m.revoke != nil { - return m.revoke(opts) + return m.revoke(ctx, opts) } return m.err } @@ -622,6 +662,27 @@ func (m *mockAuthority) GetFederation() ([]*x509.Certificate, error) { return m.ret1.([]*x509.Certificate), m.err } +func (m *mockAuthority) RenewSSH(cert *ssh.Certificate) (*ssh.Certificate, error) { + if m.renewSSH != nil { + return m.renewSSH(cert) + } + return m.ret1.(*ssh.Certificate), m.err +} + +func (m *mockAuthority) RekeySSH(cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { + if m.rekeySSH != nil { + return m.rekeySSH(cert, key, signOpts...) + } + return m.ret1.(*ssh.Certificate), m.err +} + +func (m *mockAuthority) GetSSHHosts() ([]string, error) { + if m.getSSHHosts != nil { + return m.getSSHHosts() + } + return m.ret1.([]string), m.err +} + func (m *mockAuthority) GetSSHRoots() (*authority.SSHKeys, error) { if m.getSSHRoots != nil { return m.getSSHRoots() diff --git a/api/revoke_test.go b/api/revoke_test.go index 477d90e8..9aa37d1a 100644 --- a/api/revoke_test.go +++ b/api/revoke_test.go @@ -2,6 +2,7 @@ package api import ( "bytes" + "context" "crypto/tls" "crypto/x509" "encoding/json" @@ -105,7 +106,10 @@ func Test_caHandler_Revoke(t *testing.T) { input: string(input), statusCode: http.StatusOK, auth: &mockAuthority{ - revoke: func(opts *authority.RevokeOptions) error { + authorizeSign: func(ott string) ([]provisioner.SignOption, error) { + return nil, nil + }, + revoke: func(ctx context.Context, opts *authority.RevokeOptions) error { assert.True(t, opts.PassiveOnly) assert.False(t, opts.MTLS) assert.Equals(t, opts.Serial, "sn") @@ -146,7 +150,10 @@ func Test_caHandler_Revoke(t *testing.T) { statusCode: http.StatusOK, tls: cs, auth: &mockAuthority{ - revoke: func(ri *authority.RevokeOptions) error { + authorizeSign: func(ott string) ([]provisioner.SignOption, error) { + return nil, nil + }, + revoke: func(ctx context.Context, ri *authority.RevokeOptions) error { assert.True(t, ri.PassiveOnly) assert.True(t, ri.MTLS) assert.Equals(t, ri.Serial, "1404354960355712309") @@ -178,7 +185,10 @@ func Test_caHandler_Revoke(t *testing.T) { input: string(input), statusCode: http.StatusInternalServerError, auth: &mockAuthority{ - revoke: func(opts *authority.RevokeOptions) error { + authorizeSign: func(ott string) ([]provisioner.SignOption, error) { + return nil, nil + }, + revoke: func(ctx context.Context, opts *authority.RevokeOptions) error { return InternalServerError(errors.New("force")) }, }, @@ -197,7 +207,10 @@ func Test_caHandler_Revoke(t *testing.T) { input: string(input), statusCode: http.StatusForbidden, auth: &mockAuthority{ - revoke: func(opts *authority.RevokeOptions) error { + authorizeSign: func(ott string) ([]provisioner.SignOption, error) { + return nil, nil + }, + revoke: func(ctx context.Context, opts *authority.RevokeOptions) error { return errors.New("force") }, }, diff --git a/api/ssh_test.go b/api/ssh_test.go index 075428c0..e4e2fd9b 100644 --- a/api/ssh_test.go +++ b/api/ssh_test.go @@ -432,7 +432,7 @@ func Test_caHandler_SSHFederation(t *testing.T) { func Test_caHandler_SSHConfig(t *testing.T) { userOutput := []templates.Output{ - {Name: "config.tpl", Type: templates.File, Comment: "#", Path: "ssh/config", Content: []byte("UserKnownHostsFile /home/user/.step/config/ssh/known_hosts")}, + {Name: "config.tpl", Type: templates.File, Comment: "#", Path: "ssh/config", Content: []byte("UserKnownHostsFile /home/user/.step/ssh/known_hosts")}, {Name: "known_host.tpl", Type: templates.File, Comment: "#", Path: "ssh/known_host", Content: []byte("@cert-authority * ecdsa-sha2-nistp256 AAAA...=")}, } hostOutput := []templates.Output{