diff --git a/api/api.go b/api/api.go index 9e6ee301..16e24bb2 100644 --- a/api/api.go +++ b/api/api.go @@ -25,9 +25,6 @@ import ( "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/logging" - "github.com/smallstep/certificates/templates" - "go.step.sm/linkedca" - "golang.org/x/crypto/ssh" ) // Authority is the interface implemented by a CA authority. @@ -51,21 +48,6 @@ type Authority interface { Version() authority.Version } -type LinkedAuthority interface { // TODO(hs): name is not great; it is related to LinkedCA, though - Authority - IsAdminAPIEnabled() bool - LoadAdminByID(id string) (*linkedca.Admin, bool) - GetAdmins(cursor string, limit int) ([]*linkedca.Admin, string, error) - StoreAdmin(ctx context.Context, adm *linkedca.Admin, prov provisioner.Interface) error - UpdateAdmin(ctx context.Context, id string, nu *linkedca.Admin) (*linkedca.Admin, error) - RemoveAdmin(ctx context.Context, id string) error - AuthorizeAdminToken(r *http.Request, token string) (*linkedca.Admin, error) - StoreProvisioner(ctx context.Context, prov *linkedca.Provisioner) error - LoadProvisionerByID(id string) (provisioner.Interface, error) - UpdateProvisioner(ctx context.Context, nu *linkedca.Provisioner) error - RemoveProvisioner(ctx context.Context, id string) error -} - // TimeDuration is an alias of provisioner.TimeDuration type TimeDuration = provisioner.TimeDuration @@ -475,296 +457,3 @@ func fmtPublicKey(cert *x509.Certificate) string { } return fmt.Sprintf("%s %s", cert.PublicKeyAlgorithm, params) } - -type MockAuthority struct { - ret1, ret2 interface{} - err error - authorizeSign func(ott string) ([]provisioner.SignOption, error) - getTLSOptions func() *authority.TLSOptions - root func(shasum string) (*x509.Certificate, error) - sign func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) - renew func(cert *x509.Certificate) ([]*x509.Certificate, error) - rekey func(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) - loadProvisionerByCertificate func(cert *x509.Certificate) (provisioner.Interface, error) - MockLoadProvisionerByName func(name string) (provisioner.Interface, error) - MockGetProvisioners func(nextCursor string, limit int) (provisioner.List, string, error) - revoke func(context.Context, *authority.RevokeOptions) error - getEncryptedKey func(kid string) (string, error) - getRoots func() ([]*x509.Certificate, error) - getFederation func() ([]*x509.Certificate, error) - signSSH func(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) - signSSHAddUser func(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) - renewSSH func(ctx context.Context, cert *ssh.Certificate) (*ssh.Certificate, error) - rekeySSH func(ctx context.Context, cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) - getSSHHosts func(ctx context.Context, cert *x509.Certificate) ([]authority.Host, error) - getSSHRoots func(ctx context.Context) (*authority.SSHKeys, error) - getSSHFederation func(ctx context.Context) (*authority.SSHKeys, error) - getSSHConfig func(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) - checkSSHHost func(ctx context.Context, principal, token string) (bool, error) - getSSHBastion func(ctx context.Context, user string, hostname string) (*authority.Bastion, error) - version func() authority.Version - - MockRet1, MockRet2 interface{} // TODO: refactor the ret1/ret2 into those two - MockErr error - MockIsAdminAPIEnabled func() bool - MockLoadAdminByID func(id string) (*linkedca.Admin, bool) - MockGetAdmins func(cursor string, limit int) ([]*linkedca.Admin, string, error) - MockStoreAdmin func(ctx context.Context, adm *linkedca.Admin, prov provisioner.Interface) error - MockUpdateAdmin func(ctx context.Context, id string, nu *linkedca.Admin) (*linkedca.Admin, error) - MockRemoveAdmin func(ctx context.Context, id string) error - MockAuthorizeAdminToken func(r *http.Request, token string) (*linkedca.Admin, error) - MockStoreProvisioner func(ctx context.Context, prov *linkedca.Provisioner) error - MockLoadProvisionerByID func(id string) (provisioner.Interface, error) - MockUpdateProvisioner func(ctx context.Context, nu *linkedca.Provisioner) error - MockRemoveProvisioner func(ctx context.Context, id string) error -} - -// TODO: remove once Authorize is deprecated. -func (m *MockAuthority) Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error) { - return m.AuthorizeSign(ott) -} - -func (m *MockAuthority) AuthorizeSign(ott string) ([]provisioner.SignOption, error) { - if m.authorizeSign != nil { - return m.authorizeSign(ott) - } - return m.ret1.([]provisioner.SignOption), m.err -} - -func (m *MockAuthority) GetTLSOptions() *authority.TLSOptions { - if m.getTLSOptions != nil { - return m.getTLSOptions() - } - return m.ret1.(*authority.TLSOptions) -} - -func (m *MockAuthority) Root(shasum string) (*x509.Certificate, error) { - if m.root != nil { - return m.root(shasum) - } - return m.ret1.(*x509.Certificate), m.err -} - -func (m *MockAuthority) Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { - if m.sign != nil { - return m.sign(cr, opts, signOpts...) - } - return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err -} - -func (m *MockAuthority) Renew(cert *x509.Certificate) ([]*x509.Certificate, error) { - if m.renew != nil { - return m.renew(cert) - } - return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err -} - -func (m *MockAuthority) Rekey(oldcert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) { - if m.rekey != nil { - return m.rekey(oldcert, pk) - } - return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err -} - -func (m *MockAuthority) GetProvisioners(nextCursor string, limit int) (provisioner.List, string, error) { - if m.MockGetProvisioners != nil { - return m.MockGetProvisioners(nextCursor, limit) - } - return m.ret1.(provisioner.List), m.ret2.(string), m.err -} - -func (m *MockAuthority) LoadProvisionerByCertificate(cert *x509.Certificate) (provisioner.Interface, error) { - if m.loadProvisionerByCertificate != nil { - return m.loadProvisionerByCertificate(cert) - } - return m.ret1.(provisioner.Interface), m.err -} - -func (m *MockAuthority) LoadProvisionerByName(name string) (provisioner.Interface, error) { - if m.MockLoadProvisionerByName != nil { - return m.MockLoadProvisionerByName(name) - } - return m.ret1.(provisioner.Interface), m.err -} - -func (m *MockAuthority) Revoke(ctx context.Context, opts *authority.RevokeOptions) error { - if m.revoke != nil { - return m.revoke(ctx, opts) - } - return m.err -} - -func (m *MockAuthority) GetEncryptedKey(kid string) (string, error) { - if m.getEncryptedKey != nil { - return m.getEncryptedKey(kid) - } - return m.ret1.(string), m.err -} - -func (m *MockAuthority) GetRoots() ([]*x509.Certificate, error) { - if m.getRoots != nil { - return m.getRoots() - } - return m.ret1.([]*x509.Certificate), m.err -} - -func (m *MockAuthority) GetFederation() ([]*x509.Certificate, error) { - if m.getFederation != nil { - return m.getFederation() - } - return m.ret1.([]*x509.Certificate), m.err -} - -func (m *MockAuthority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { - if m.signSSH != nil { - return m.signSSH(ctx, key, opts, signOpts...) - } - return m.ret1.(*ssh.Certificate), m.err -} - -func (m *MockAuthority) SignSSHAddUser(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) { - if m.signSSHAddUser != nil { - return m.signSSHAddUser(ctx, key, cert) - } - return m.ret1.(*ssh.Certificate), m.err -} - -func (m *MockAuthority) RenewSSH(ctx context.Context, cert *ssh.Certificate) (*ssh.Certificate, error) { - if m.renewSSH != nil { - return m.renewSSH(ctx, cert) - } - return m.ret1.(*ssh.Certificate), m.err -} - -func (m *MockAuthority) RekeySSH(ctx context.Context, cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { - if m.rekeySSH != nil { - return m.rekeySSH(ctx, cert, key, signOpts...) - } - return m.ret1.(*ssh.Certificate), m.err -} - -func (m *MockAuthority) GetSSHHosts(ctx context.Context, cert *x509.Certificate) ([]authority.Host, error) { - if m.getSSHHosts != nil { - return m.getSSHHosts(ctx, cert) - } - return m.ret1.([]authority.Host), m.err -} - -func (m *MockAuthority) GetSSHRoots(ctx context.Context) (*authority.SSHKeys, error) { - if m.getSSHRoots != nil { - return m.getSSHRoots(ctx) - } - return m.ret1.(*authority.SSHKeys), m.err -} - -func (m *MockAuthority) GetSSHFederation(ctx context.Context) (*authority.SSHKeys, error) { - if m.getSSHFederation != nil { - return m.getSSHFederation(ctx) - } - return m.ret1.(*authority.SSHKeys), m.err -} - -func (m *MockAuthority) GetSSHConfig(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) { - if m.getSSHConfig != nil { - return m.getSSHConfig(ctx, typ, data) - } - return m.ret1.([]templates.Output), m.err -} - -func (m *MockAuthority) CheckSSHHost(ctx context.Context, principal, token string) (bool, error) { - if m.checkSSHHost != nil { - return m.checkSSHHost(ctx, principal, token) - } - return m.ret1.(bool), m.err -} - -func (m *MockAuthority) GetSSHBastion(ctx context.Context, user, hostname string) (*authority.Bastion, error) { - if m.getSSHBastion != nil { - return m.getSSHBastion(ctx, user, hostname) - } - return m.ret1.(*authority.Bastion), m.err -} - -func (m *MockAuthority) Version() authority.Version { - if m.version != nil { - return m.version() - } - return m.ret1.(authority.Version) -} - -func (m *MockAuthority) IsAdminAPIEnabled() bool { - if m.MockIsAdminAPIEnabled != nil { - return m.MockIsAdminAPIEnabled() - } - return m.MockRet1.(bool) -} - -func (m *MockAuthority) LoadAdminByID(id string) (*linkedca.Admin, bool) { - if m.MockLoadAdminByID != nil { - return m.MockLoadAdminByID(id) - } - return m.MockRet1.(*linkedca.Admin), m.MockRet2.(bool) -} - -func (m *MockAuthority) GetAdmins(cursor string, limit int) ([]*linkedca.Admin, string, error) { - if m.MockGetAdmins != nil { - return m.MockGetAdmins(cursor, limit) - } - return m.MockRet1.([]*linkedca.Admin), m.MockRet2.(string), m.MockErr -} - -func (m *MockAuthority) StoreAdmin(ctx context.Context, adm *linkedca.Admin, prov provisioner.Interface) error { - if m.MockStoreAdmin != nil { - return m.MockStoreAdmin(ctx, adm, prov) - } - return m.MockErr -} - -func (m *MockAuthority) UpdateAdmin(ctx context.Context, id string, nu *linkedca.Admin) (*linkedca.Admin, error) { - if m.MockUpdateAdmin != nil { - return m.MockUpdateAdmin(ctx, id, nu) - } - return m.MockRet1.(*linkedca.Admin), m.MockErr -} - -func (m *MockAuthority) RemoveAdmin(ctx context.Context, id string) error { - if m.MockRemoveAdmin != nil { - return m.MockRemoveAdmin(ctx, id) - } - return m.MockErr -} - -func (m *MockAuthority) AuthorizeAdminToken(r *http.Request, token string) (*linkedca.Admin, error) { - if m.MockAuthorizeAdminToken != nil { - return m.MockAuthorizeAdminToken(r, token) - } - return m.MockRet1.(*linkedca.Admin), m.MockErr -} - -func (m *MockAuthority) StoreProvisioner(ctx context.Context, prov *linkedca.Provisioner) error { - if m.MockStoreProvisioner != nil { - return m.MockStoreProvisioner(ctx, prov) - } - return m.MockErr -} - -func (m *MockAuthority) LoadProvisionerByID(id string) (provisioner.Interface, error) { - if m.MockLoadProvisionerByID != nil { - return m.MockLoadProvisionerByID(id) - } - return m.MockRet1.(provisioner.Interface), m.MockErr -} - -func (m *MockAuthority) UpdateProvisioner(ctx context.Context, nu *linkedca.Provisioner) error { - if m.MockUpdateProvisioner != nil { - return m.MockUpdateProvisioner(ctx, nu) - } - return m.MockErr -} - -func (m *MockAuthority) RemoveProvisioner(ctx context.Context, id string) error { - if m.MockRemoveProvisioner != nil { - return m.MockRemoveProvisioner(ctx, id) - } - return m.MockErr -} diff --git a/api/api_test.go b/api/api_test.go index 6a845249..c7528f9b 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -3,6 +3,7 @@ package api import ( "bytes" "context" + "crypto" "crypto/dsa" //nolint "crypto/ecdsa" "crypto/ed25519" @@ -31,6 +32,7 @@ import ( "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/logging" + "github.com/smallstep/certificates/templates" "go.step.sm/crypto/jose" "golang.org/x/crypto/ssh" ) @@ -165,6 +167,208 @@ func parseCertificateRequest(data string) *x509.CertificateRequest { return csr } +type mockAuthority struct { + ret1, ret2 interface{} + err error + authorizeSign func(ott string) ([]provisioner.SignOption, error) + getTLSOptions func() *authority.TLSOptions + root func(shasum string) (*x509.Certificate, error) + sign func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) + renew func(cert *x509.Certificate) ([]*x509.Certificate, error) + rekey func(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) + loadProvisionerByCertificate func(cert *x509.Certificate) (provisioner.Interface, error) + loadProvisionerByName func(name string) (provisioner.Interface, error) + getProvisioners func(nextCursor string, limit int) (provisioner.List, string, error) + revoke func(context.Context, *authority.RevokeOptions) error + getEncryptedKey func(kid string) (string, error) + getRoots func() ([]*x509.Certificate, error) + getFederation func() ([]*x509.Certificate, error) + signSSH func(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) + signSSHAddUser func(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) + renewSSH func(ctx context.Context, cert *ssh.Certificate) (*ssh.Certificate, error) + rekeySSH func(ctx context.Context, cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) + getSSHHosts func(ctx context.Context, cert *x509.Certificate) ([]authority.Host, error) + getSSHRoots func(ctx context.Context) (*authority.SSHKeys, error) + getSSHFederation func(ctx context.Context) (*authority.SSHKeys, error) + getSSHConfig func(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) + checkSSHHost func(ctx context.Context, principal, token string) (bool, error) + getSSHBastion func(ctx context.Context, user string, hostname string) (*authority.Bastion, error) + version func() authority.Version +} + +// TODO: remove once Authorize is deprecated. +func (m *mockAuthority) Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error) { + return m.AuthorizeSign(ott) +} + +func (m *mockAuthority) AuthorizeSign(ott string) ([]provisioner.SignOption, error) { + if m.authorizeSign != nil { + return m.authorizeSign(ott) + } + return m.ret1.([]provisioner.SignOption), m.err +} + +func (m *mockAuthority) GetTLSOptions() *authority.TLSOptions { + if m.getTLSOptions != nil { + return m.getTLSOptions() + } + return m.ret1.(*authority.TLSOptions) +} + +func (m *mockAuthority) Root(shasum string) (*x509.Certificate, error) { + if m.root != nil { + return m.root(shasum) + } + return m.ret1.(*x509.Certificate), m.err +} + +func (m *mockAuthority) Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { + if m.sign != nil { + return m.sign(cr, opts, signOpts...) + } + return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err +} + +func (m *mockAuthority) Renew(cert *x509.Certificate) ([]*x509.Certificate, error) { + if m.renew != nil { + return m.renew(cert) + } + return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err +} + +func (m *mockAuthority) Rekey(oldcert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) { + if m.rekey != nil { + return m.rekey(oldcert, pk) + } + return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err +} + +func (m *mockAuthority) GetProvisioners(nextCursor string, limit int) (provisioner.List, string, error) { + if m.getProvisioners != nil { + return m.getProvisioners(nextCursor, limit) + } + return m.ret1.(provisioner.List), m.ret2.(string), m.err +} + +func (m *mockAuthority) LoadProvisionerByCertificate(cert *x509.Certificate) (provisioner.Interface, error) { + if m.loadProvisionerByCertificate != nil { + return m.loadProvisionerByCertificate(cert) + } + return m.ret1.(provisioner.Interface), m.err +} + +func (m *mockAuthority) LoadProvisionerByName(name string) (provisioner.Interface, error) { + if m.loadProvisionerByName != nil { + return m.loadProvisionerByName(name) + } + return m.ret1.(provisioner.Interface), m.err +} + +func (m *mockAuthority) Revoke(ctx context.Context, opts *authority.RevokeOptions) error { + if m.revoke != nil { + return m.revoke(ctx, opts) + } + return m.err +} + +func (m *mockAuthority) GetEncryptedKey(kid string) (string, error) { + if m.getEncryptedKey != nil { + return m.getEncryptedKey(kid) + } + return m.ret1.(string), m.err +} + +func (m *mockAuthority) GetRoots() ([]*x509.Certificate, error) { + if m.getRoots != nil { + return m.getRoots() + } + return m.ret1.([]*x509.Certificate), m.err +} + +func (m *mockAuthority) GetFederation() ([]*x509.Certificate, error) { + if m.getFederation != nil { + return m.getFederation() + } + return m.ret1.([]*x509.Certificate), m.err +} + +func (m *mockAuthority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { + if m.signSSH != nil { + return m.signSSH(ctx, key, opts, signOpts...) + } + return m.ret1.(*ssh.Certificate), m.err +} + +func (m *mockAuthority) SignSSHAddUser(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) { + if m.signSSHAddUser != nil { + return m.signSSHAddUser(ctx, key, cert) + } + return m.ret1.(*ssh.Certificate), m.err +} + +func (m *mockAuthority) RenewSSH(ctx context.Context, cert *ssh.Certificate) (*ssh.Certificate, error) { + if m.renewSSH != nil { + return m.renewSSH(ctx, cert) + } + return m.ret1.(*ssh.Certificate), m.err +} + +func (m *mockAuthority) RekeySSH(ctx context.Context, cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { + if m.rekeySSH != nil { + return m.rekeySSH(ctx, cert, key, signOpts...) + } + return m.ret1.(*ssh.Certificate), m.err +} + +func (m *mockAuthority) GetSSHHosts(ctx context.Context, cert *x509.Certificate) ([]authority.Host, error) { + if m.getSSHHosts != nil { + return m.getSSHHosts(ctx, cert) + } + return m.ret1.([]authority.Host), m.err +} + +func (m *mockAuthority) GetSSHRoots(ctx context.Context) (*authority.SSHKeys, error) { + if m.getSSHRoots != nil { + return m.getSSHRoots(ctx) + } + return m.ret1.(*authority.SSHKeys), m.err +} + +func (m *mockAuthority) GetSSHFederation(ctx context.Context) (*authority.SSHKeys, error) { + if m.getSSHFederation != nil { + return m.getSSHFederation(ctx) + } + return m.ret1.(*authority.SSHKeys), m.err +} + +func (m *mockAuthority) GetSSHConfig(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) { + if m.getSSHConfig != nil { + return m.getSSHConfig(ctx, typ, data) + } + return m.ret1.([]templates.Output), m.err +} + +func (m *mockAuthority) CheckSSHHost(ctx context.Context, principal, token string) (bool, error) { + if m.checkSSHHost != nil { + return m.checkSSHHost(ctx, principal, token) + } + return m.ret1.(bool), m.err +} + +func (m *mockAuthority) GetSSHBastion(ctx context.Context, user, hostname string) (*authority.Bastion, error) { + if m.getSSHBastion != nil { + return m.getSSHBastion(ctx, user, hostname) + } + return m.ret1.(*authority.Bastion), m.err +} + +func (m *mockAuthority) Version() authority.Version { + if m.version != nil { + return m.version() + } + return m.ret1.(authority.Version) +} + func TestNewCertificate(t *testing.T) { cert := parseCertificate(rootPEM) if !reflect.DeepEqual(Certificate{Certificate: cert}, NewCertificate(cert)) { @@ -561,7 +765,7 @@ func Test_caHandler_Route(t *testing.T) { fields fields args args }{ - {"ok", fields{&MockAuthority{}}, args{chi.NewRouter()}}, + {"ok", fields{&mockAuthority{}}, args{chi.NewRouter()}}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -576,7 +780,7 @@ func Test_caHandler_Route(t *testing.T) { func Test_caHandler_Health(t *testing.T) { req := httptest.NewRequest("GET", "http://example.com/health", nil) w := httptest.NewRecorder() - h := New(&MockAuthority{}).(*caHandler) + h := New(&mockAuthority{}).(*caHandler) h.Health(w, req) res := w.Result() @@ -616,7 +820,7 @@ func Test_caHandler_Root(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&MockAuthority{ret1: tt.root, err: tt.err}).(*caHandler) + h := New(&mockAuthority{ret1: tt.root, err: tt.err}).(*caHandler) w := httptest.NewRecorder() h.Root(w, req) res := w.Result() @@ -680,7 +884,7 @@ func Test_caHandler_Sign(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&MockAuthority{ + h := New(&mockAuthority{ ret1: tt.cert, ret2: tt.root, err: tt.signErr, authorizeSign: func(ott string) ([]provisioner.SignOption, error) { return tt.certAttrOpts, tt.autherr @@ -734,7 +938,7 @@ func Test_caHandler_Renew(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&MockAuthority{ + h := New(&mockAuthority{ ret1: tt.cert, ret2: tt.root, err: tt.err, getTLSOptions: func() *authority.TLSOptions { return nil @@ -795,7 +999,7 @@ func Test_caHandler_Rekey(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&MockAuthority{ + h := New(&mockAuthority{ ret1: tt.cert, ret2: tt.root, err: tt.err, getTLSOptions: func() *authority.TLSOptions { return nil @@ -873,9 +1077,9 @@ func Test_caHandler_Provisioners(t *testing.T) { args args statusCode int }{ - {"ok", fields{&MockAuthority{ret1: p, ret2: ""}}, args{httptest.NewRecorder(), req}, 200}, - {"fail", fields{&MockAuthority{ret1: p, ret2: "", err: fmt.Errorf("the error")}}, args{httptest.NewRecorder(), req}, 500}, - {"limit fail", fields{&MockAuthority{ret1: p, ret2: ""}}, args{httptest.NewRecorder(), reqLimitFail}, 400}, + {"ok", fields{&mockAuthority{ret1: p, ret2: ""}}, args{httptest.NewRecorder(), req}, 200}, + {"fail", fields{&mockAuthority{ret1: p, ret2: "", err: fmt.Errorf("the error")}}, args{httptest.NewRecorder(), req}, 500}, + {"limit fail", fields{&mockAuthority{ret1: p, ret2: ""}}, args{httptest.NewRecorder(), reqLimitFail}, 400}, } expected, err := json.Marshal(pr) @@ -950,8 +1154,8 @@ func Test_caHandler_ProvisionerKey(t *testing.T) { args args statusCode int }{ - {"ok", fields{&MockAuthority{ret1: privKey}}, args{httptest.NewRecorder(), req}, 200}, - {"fail", fields{&MockAuthority{ret1: "", err: fmt.Errorf("not found")}}, args{httptest.NewRecorder(), req}, 404}, + {"ok", fields{&mockAuthority{ret1: privKey}}, args{httptest.NewRecorder(), req}, 200}, + {"fail", fields{&mockAuthority{ret1: "", err: fmt.Errorf("not found")}}, args{httptest.NewRecorder(), req}, 404}, } expected := []byte(`{"key":"` + privKey + `"}`) @@ -1010,7 +1214,7 @@ func Test_caHandler_Roots(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&MockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}).(*caHandler) + h := New(&mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}).(*caHandler) req := httptest.NewRequest("GET", "http://example.com/roots", nil) req.TLS = tt.tls w := httptest.NewRecorder() @@ -1056,7 +1260,7 @@ func Test_caHandler_Federation(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&MockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}).(*caHandler) + h := New(&mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}).(*caHandler) req := httptest.NewRequest("GET", "http://example.com/federation", nil) req.TLS = tt.tls w := httptest.NewRecorder() diff --git a/api/revoke_test.go b/api/revoke_test.go index b0eaef3d..4ed4e3fe 100644 --- a/api/revoke_test.go +++ b/api/revoke_test.go @@ -106,7 +106,7 @@ func Test_caHandler_Revoke(t *testing.T) { return test{ input: string(input), statusCode: http.StatusOK, - auth: &MockAuthority{ + auth: &mockAuthority{ authorizeSign: func(ott string) ([]provisioner.SignOption, error) { return nil, nil }, @@ -150,7 +150,7 @@ func Test_caHandler_Revoke(t *testing.T) { input: string(input), statusCode: http.StatusOK, tls: cs, - auth: &MockAuthority{ + auth: &mockAuthority{ authorizeSign: func(ott string) ([]provisioner.SignOption, error) { return nil, nil }, @@ -185,7 +185,7 @@ func Test_caHandler_Revoke(t *testing.T) { return test{ input: string(input), statusCode: http.StatusInternalServerError, - auth: &MockAuthority{ + auth: &mockAuthority{ authorizeSign: func(ott string) ([]provisioner.SignOption, error) { return nil, nil }, @@ -207,7 +207,7 @@ func Test_caHandler_Revoke(t *testing.T) { return test{ input: string(input), statusCode: http.StatusForbidden, - auth: &MockAuthority{ + auth: &mockAuthority{ authorizeSign: func(ott string) ([]provisioner.SignOption, error) { return nil, nil }, diff --git a/api/ssh_test.go b/api/ssh_test.go index df9e2f45..a3d7da0d 100644 --- a/api/ssh_test.go +++ b/api/ssh_test.go @@ -314,7 +314,7 @@ func Test_caHandler_SSHSign(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&MockAuthority{ + h := New(&mockAuthority{ authorizeSign: func(ott string) ([]provisioner.SignOption, error) { return []provisioner.SignOption{}, tt.authErr }, @@ -377,7 +377,7 @@ func Test_caHandler_SSHRoots(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&MockAuthority{ + h := New(&mockAuthority{ getSSHRoots: func(ctx context.Context) (*authority.SSHKeys, error) { return tt.keys, tt.keysErr }, @@ -431,7 +431,7 @@ func Test_caHandler_SSHFederation(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&MockAuthority{ + h := New(&mockAuthority{ getSSHFederation: func(ctx context.Context) (*authority.SSHKeys, error) { return tt.keys, tt.keysErr }, @@ -491,7 +491,7 @@ func Test_caHandler_SSHConfig(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&MockAuthority{ + h := New(&mockAuthority{ getSSHConfig: func(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) { return tt.output, tt.err }, @@ -538,7 +538,7 @@ func Test_caHandler_SSHCheckHost(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&MockAuthority{ + h := New(&mockAuthority{ checkSSHHost: func(ctx context.Context, principal, token string) (bool, error) { return tt.exists, tt.err }, @@ -589,7 +589,7 @@ func Test_caHandler_SSHGetHosts(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&MockAuthority{ + h := New(&mockAuthority{ getSSHHosts: func(context.Context, *x509.Certificate) ([]authority.Host, error) { return tt.hosts, tt.err }, @@ -644,7 +644,7 @@ func Test_caHandler_SSHBastion(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&MockAuthority{ + h := New(&mockAuthority{ getSSHBastion: func(ctx context.Context, user, hostname string) (*authority.Bastion, error) { return tt.bastion, tt.bastionErr }, diff --git a/authority/admin/api/acme.go b/authority/admin/api/acme.go index 48e81ecc..18959acb 100644 --- a/authority/admin/api/acme.go +++ b/authority/admin/api/acme.go @@ -54,7 +54,7 @@ func (h *Handler) requireEABEnabled(next nextHTTP) nextHTTP { // provisionerHasEABEnabled determines if the "requireEAB" setting for an ACME // provisioner is set to true and thus has EAB enabled. -func (h *Handler) provisionerHasEABEnabled(ctx context.Context, provisionerName string) (bool, *admin.Error) { +func (h *Handler) provisionerHasEABEnabled(ctx context.Context, provisionerName string) (bool, error) { var ( p provisioner.Interface err error diff --git a/authority/admin/api/acme_test.go b/authority/admin/api/acme_test.go index 15c581f4..ba956f21 100644 --- a/authority/admin/api/acme_test.go +++ b/authority/admin/api/acme_test.go @@ -17,7 +17,6 @@ import ( "github.com/google/go-cmp/cmp/cmpopts" "github.com/smallstep/assert" "github.com/smallstep/certificates/acme" - "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/provisioner" "go.step.sm/linkedca" @@ -39,7 +38,7 @@ func TestHandler_requireEABEnabled(t *testing.T) { type test struct { ctx context.Context db admin.DB - auth api.LinkedAuthority + auth adminAuthority next nextHTTP err *admin.Error statusCode int @@ -49,7 +48,7 @@ func TestHandler_requireEABEnabled(t *testing.T) { chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("prov", "provName") ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) - auth := &api.MockAuthority{ + auth := &mockAdminAuthority{ MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { assert.Equals(t, "provName", name) return nil, errors.New("force") @@ -68,7 +67,7 @@ func TestHandler_requireEABEnabled(t *testing.T) { chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("prov", "provName") ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) - auth := &api.MockAuthority{ + auth := &mockAdminAuthority{ MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { assert.Equals(t, "provName", name) return &provisioner.MockProvisioner{ @@ -108,7 +107,7 @@ func TestHandler_requireEABEnabled(t *testing.T) { chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("prov", "provName") ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) - auth := &api.MockAuthority{ + auth := &mockAdminAuthority{ MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { assert.Equals(t, "provName", name) return &provisioner.MockProvisioner{ @@ -185,14 +184,14 @@ func TestHandler_requireEABEnabled(t *testing.T) { func TestHandler_provisionerHasEABEnabled(t *testing.T) { type test struct { db admin.DB - auth api.LinkedAuthority + auth adminAuthority provisionerName string want bool err *admin.Error } var tests = map[string]func(t *testing.T) test{ "fail/auth.LoadProvisionerByName": func(t *testing.T) test { - auth := &api.MockAuthority{ + auth := &mockAdminAuthority{ MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { assert.Equals(t, "provName", name) return nil, errors.New("force") @@ -206,7 +205,7 @@ func TestHandler_provisionerHasEABEnabled(t *testing.T) { } }, "fail/db.GetProvisioner": func(t *testing.T) test { - auth := &api.MockAuthority{ + auth := &mockAdminAuthority{ MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { assert.Equals(t, "provName", name) return &provisioner.MockProvisioner{ @@ -231,7 +230,7 @@ func TestHandler_provisionerHasEABEnabled(t *testing.T) { } }, "fail/prov.GetDetails": func(t *testing.T) test { - auth := &api.MockAuthority{ + auth := &mockAdminAuthority{ MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { assert.Equals(t, "provName", name) return &provisioner.MockProvisioner{ @@ -260,7 +259,7 @@ func TestHandler_provisionerHasEABEnabled(t *testing.T) { } }, "fail/details.GetACME": func(t *testing.T) test { - auth := &api.MockAuthority{ + auth := &mockAdminAuthority{ MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { assert.Equals(t, "provName", name) return &provisioner.MockProvisioner{ @@ -293,7 +292,7 @@ func TestHandler_provisionerHasEABEnabled(t *testing.T) { } }, "ok/eab-disabled": func(t *testing.T) test { - auth := &api.MockAuthority{ + auth := &mockAdminAuthority{ MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { assert.Equals(t, "eab-disabled", name) return &provisioner.MockProvisioner{ @@ -327,7 +326,7 @@ func TestHandler_provisionerHasEABEnabled(t *testing.T) { } }, "ok/eab-enabled": func(t *testing.T) test { - auth := &api.MockAuthority{ + auth := &mockAdminAuthority{ MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { assert.Equals(t, "eab-enabled", name) return &provisioner.MockProvisioner{ @@ -375,16 +374,13 @@ func TestHandler_provisionerHasEABEnabled(t *testing.T) { return } if tc.err != nil { - // TODO(hs): the output of the diff seems to be equal to each other; not sure why it's marked as different =/ - // opts := []cmp.Option{cmpopts.EquateErrors()} - // if !cmp.Equal(tc.err, err, opts...) { - // t.Errorf("Handler.provisionerHasEABEnabled() diff =\n%v", cmp.Diff(tc.err, err, opts...)) - // } - assert.Equals(t, tc.err.Type, err.Type) - assert.Equals(t, tc.err.Status, err.Status) - assert.Equals(t, tc.err.StatusCode(), err.StatusCode()) - assert.Equals(t, tc.err.Message, err.Message) - assert.Equals(t, tc.err.Detail, err.Detail) + assert.Type(t, &admin.Error{}, err) + adminError, _ := err.(*admin.Error) + assert.Equals(t, tc.err.Type, adminError.Type) + assert.Equals(t, tc.err.Status, adminError.Status) + assert.Equals(t, tc.err.StatusCode(), adminError.StatusCode()) + assert.Equals(t, tc.err.Message, adminError.Message) + assert.Equals(t, tc.err.Detail, adminError.Detail) return } if got != tc.want { diff --git a/authority/admin/api/admin.go b/authority/admin/api/admin.go index bf79ebcf..7aa66d0f 100644 --- a/authority/admin/api/admin.go +++ b/authority/admin/api/admin.go @@ -1,14 +1,32 @@ package api import ( + "context" "net/http" "github.com/go-chi/chi" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/authority/admin" + "github.com/smallstep/certificates/authority/provisioner" "go.step.sm/linkedca" ) +type adminAuthority interface { + LoadProvisionerByName(string) (provisioner.Interface, error) + GetProvisioners(cursor string, limit int) (provisioner.List, string, error) + IsAdminAPIEnabled() bool + LoadAdminByID(id string) (*linkedca.Admin, bool) + GetAdmins(cursor string, limit int) ([]*linkedca.Admin, string, error) + StoreAdmin(ctx context.Context, adm *linkedca.Admin, prov provisioner.Interface) error + UpdateAdmin(ctx context.Context, id string, nu *linkedca.Admin) (*linkedca.Admin, error) + RemoveAdmin(ctx context.Context, id string) error + AuthorizeAdminToken(r *http.Request, token string) (*linkedca.Admin, error) + StoreProvisioner(ctx context.Context, prov *linkedca.Provisioner) error + LoadProvisionerByID(id string) (provisioner.Interface, error) + UpdateProvisioner(ctx context.Context, nu *linkedca.Provisioner) error + RemoveProvisioner(ctx context.Context, id string) error +} + // CreateAdminRequest represents the body for a CreateAdmin request. type CreateAdminRequest struct { Subject string `json:"subject"` diff --git a/authority/admin/api/admin_test.go b/authority/admin/api/admin_test.go index da044d58..8d223b52 100644 --- a/authority/admin/api/admin_test.go +++ b/authority/admin/api/admin_test.go @@ -15,13 +15,121 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/smallstep/assert" - "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/provisioner" "go.step.sm/linkedca" "google.golang.org/protobuf/types/known/timestamppb" ) +type mockAdminAuthority struct { + MockLoadProvisionerByName func(name string) (provisioner.Interface, error) + MockGetProvisioners func(nextCursor string, limit int) (provisioner.List, string, error) + MockRet1, MockRet2 interface{} // TODO: refactor the ret1/ret2 into those two + MockErr error + MockIsAdminAPIEnabled func() bool + MockLoadAdminByID func(id string) (*linkedca.Admin, bool) + MockGetAdmins func(cursor string, limit int) ([]*linkedca.Admin, string, error) + MockStoreAdmin func(ctx context.Context, adm *linkedca.Admin, prov provisioner.Interface) error + MockUpdateAdmin func(ctx context.Context, id string, nu *linkedca.Admin) (*linkedca.Admin, error) + MockRemoveAdmin func(ctx context.Context, id string) error + MockAuthorizeAdminToken func(r *http.Request, token string) (*linkedca.Admin, error) + MockStoreProvisioner func(ctx context.Context, prov *linkedca.Provisioner) error + MockLoadProvisionerByID func(id string) (provisioner.Interface, error) + MockUpdateProvisioner func(ctx context.Context, nu *linkedca.Provisioner) error + MockRemoveProvisioner func(ctx context.Context, id string) error +} + +func (m *mockAdminAuthority) IsAdminAPIEnabled() bool { + if m.MockIsAdminAPIEnabled != nil { + return m.MockIsAdminAPIEnabled() + } + return m.MockRet1.(bool) +} + +func (m *mockAdminAuthority) LoadProvisionerByName(name string) (provisioner.Interface, error) { + if m.MockLoadProvisionerByName != nil { + return m.MockLoadProvisionerByName(name) + } + return m.MockRet1.(provisioner.Interface), m.MockErr +} + +func (m *mockAdminAuthority) GetProvisioners(nextCursor string, limit int) (provisioner.List, string, error) { + if m.MockGetProvisioners != nil { + return m.MockGetProvisioners(nextCursor, limit) + } + return m.MockRet1.(provisioner.List), m.MockRet2.(string), m.MockErr +} + +func (m *mockAdminAuthority) LoadAdminByID(id string) (*linkedca.Admin, bool) { + if m.MockLoadAdminByID != nil { + return m.MockLoadAdminByID(id) + } + return m.MockRet1.(*linkedca.Admin), m.MockRet2.(bool) +} + +func (m *mockAdminAuthority) GetAdmins(cursor string, limit int) ([]*linkedca.Admin, string, error) { + if m.MockGetAdmins != nil { + return m.MockGetAdmins(cursor, limit) + } + return m.MockRet1.([]*linkedca.Admin), m.MockRet2.(string), m.MockErr +} + +func (m *mockAdminAuthority) StoreAdmin(ctx context.Context, adm *linkedca.Admin, prov provisioner.Interface) error { + if m.MockStoreAdmin != nil { + return m.MockStoreAdmin(ctx, adm, prov) + } + return m.MockErr +} + +func (m *mockAdminAuthority) UpdateAdmin(ctx context.Context, id string, nu *linkedca.Admin) (*linkedca.Admin, error) { + if m.MockUpdateAdmin != nil { + return m.MockUpdateAdmin(ctx, id, nu) + } + return m.MockRet1.(*linkedca.Admin), m.MockErr +} + +func (m *mockAdminAuthority) RemoveAdmin(ctx context.Context, id string) error { + if m.MockRemoveAdmin != nil { + return m.MockRemoveAdmin(ctx, id) + } + return m.MockErr +} + +func (m *mockAdminAuthority) AuthorizeAdminToken(r *http.Request, token string) (*linkedca.Admin, error) { + if m.MockAuthorizeAdminToken != nil { + return m.MockAuthorizeAdminToken(r, token) + } + return m.MockRet1.(*linkedca.Admin), m.MockErr +} + +func (m *mockAdminAuthority) StoreProvisioner(ctx context.Context, prov *linkedca.Provisioner) error { + if m.MockStoreProvisioner != nil { + return m.MockStoreProvisioner(ctx, prov) + } + return m.MockErr +} + +func (m *mockAdminAuthority) LoadProvisionerByID(id string) (provisioner.Interface, error) { + if m.MockLoadProvisionerByID != nil { + return m.MockLoadProvisionerByID(id) + } + return m.MockRet1.(provisioner.Interface), m.MockErr +} + +func (m *mockAdminAuthority) UpdateProvisioner(ctx context.Context, nu *linkedca.Provisioner) error { + if m.MockUpdateProvisioner != nil { + return m.MockUpdateProvisioner(ctx, nu) + } + return m.MockErr +} + +func (m *mockAdminAuthority) RemoveProvisioner(ctx context.Context, id string) error { + if m.MockRemoveProvisioner != nil { + return m.MockRemoveProvisioner(ctx, id) + } + return m.MockErr +} + func TestCreateAdminRequest_Validate(t *testing.T) { type fields struct { Subject string @@ -148,7 +256,7 @@ func TestUpdateAdminRequest_Validate(t *testing.T) { func TestHandler_GetAdmin(t *testing.T) { type test struct { ctx context.Context - auth api.LinkedAuthority + auth adminAuthority statusCode int err *admin.Error adm *linkedca.Admin @@ -158,7 +266,7 @@ func TestHandler_GetAdmin(t *testing.T) { chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("id", "adminID") ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) - auth := &api.MockAuthority{ + auth := &mockAdminAuthority{ MockLoadAdminByID: func(id string) (*linkedca.Admin, bool) { assert.Equals(t, "adminID", id) return nil, false @@ -191,7 +299,7 @@ func TestHandler_GetAdmin(t *testing.T) { CreatedAt: timestamppb.New(createdAt), DeletedAt: timestamppb.New(deletedAt), } - auth := &api.MockAuthority{ + auth := &mockAdminAuthority{ MockLoadAdminByID: func(id string) (*linkedca.Admin, bool) { assert.Equals(t, "adminID", id) return adm, true @@ -254,7 +362,7 @@ func TestHandler_GetAdmin(t *testing.T) { func TestHandler_GetAdmins(t *testing.T) { type test struct { ctx context.Context - auth api.LinkedAuthority + auth adminAuthority req *http.Request statusCode int err *admin.Error @@ -277,7 +385,7 @@ func TestHandler_GetAdmins(t *testing.T) { }, "fail/auth.GetAdmins": func(t *testing.T) test { req := httptest.NewRequest("GET", "/foo", nil) - auth := &api.MockAuthority{ + auth := &mockAdminAuthority{ MockGetAdmins: func(cursor string, limit int) ([]*linkedca.Admin, string, error) { assert.Equals(t, "", cursor) assert.Equals(t, 0, limit) @@ -319,7 +427,7 @@ func TestHandler_GetAdmins(t *testing.T) { CreatedAt: timestamppb.New(createdAt), DeletedAt: timestamppb.New(deletedAt), } - auth := &api.MockAuthority{ + auth := &mockAdminAuthority{ MockGetAdmins: func(cursor string, limit int) ([]*linkedca.Admin, string, error) { assert.Equals(t, "", cursor) assert.Equals(t, 0, limit) @@ -390,7 +498,7 @@ func TestHandler_GetAdmins(t *testing.T) { func TestHandler_CreateAdmin(t *testing.T) { type test struct { ctx context.Context - auth api.LinkedAuthority + auth adminAuthority body []byte statusCode int err *admin.Error @@ -439,7 +547,7 @@ func TestHandler_CreateAdmin(t *testing.T) { } body, err := json.Marshal(req) assert.FatalError(t, err) - auth := &api.MockAuthority{ + auth := &mockAdminAuthority{ MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { assert.Equals(t, "prov", name) return nil, errors.New("force") @@ -466,7 +574,7 @@ func TestHandler_CreateAdmin(t *testing.T) { } body, err := json.Marshal(req) assert.FatalError(t, err) - auth := &api.MockAuthority{ + auth := &mockAdminAuthority{ MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { assert.Equals(t, "prov", name) return &provisioner.ACME{ @@ -501,7 +609,7 @@ func TestHandler_CreateAdmin(t *testing.T) { } body, err := json.Marshal(req) assert.FatalError(t, err) - auth := &api.MockAuthority{ + auth := &mockAdminAuthority{ MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { assert.Equals(t, "prov", name) return &provisioner.ACME{ @@ -576,7 +684,7 @@ func TestHandler_CreateAdmin(t *testing.T) { func TestHandler_DeleteAdmin(t *testing.T) { type test struct { ctx context.Context - auth api.LinkedAuthority + auth adminAuthority statusCode int err *admin.Error } @@ -585,7 +693,7 @@ func TestHandler_DeleteAdmin(t *testing.T) { chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("id", "adminID") ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) - auth := &api.MockAuthority{ + auth := &mockAdminAuthority{ MockRemoveAdmin: func(ctx context.Context, id string) error { assert.Equals(t, "adminID", id) return errors.New("force") @@ -607,7 +715,7 @@ func TestHandler_DeleteAdmin(t *testing.T) { chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("id", "adminID") ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) - auth := &api.MockAuthority{ + auth := &mockAdminAuthority{ MockRemoveAdmin: func(ctx context.Context, id string) error { assert.Equals(t, "adminID", id) return nil @@ -666,7 +774,7 @@ func TestHandler_DeleteAdmin(t *testing.T) { func TestHandler_UpdateAdmin(t *testing.T) { type test struct { ctx context.Context - auth api.LinkedAuthority + auth adminAuthority body []byte statusCode int err *admin.Error @@ -714,7 +822,7 @@ func TestHandler_UpdateAdmin(t *testing.T) { chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("id", "adminID") ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) - auth := &api.MockAuthority{ + auth := &mockAdminAuthority{ MockUpdateAdmin: func(ctx context.Context, id string, nu *linkedca.Admin) (*linkedca.Admin, error) { assert.Equals(t, "adminID", id) assert.Equals(t, linkedca.Admin_ADMIN, nu.Type) @@ -749,7 +857,7 @@ func TestHandler_UpdateAdmin(t *testing.T) { Subject: "admin", Type: linkedca.Admin_SUPER_ADMIN, } - auth := &api.MockAuthority{ + auth := &mockAdminAuthority{ MockUpdateAdmin: func(ctx context.Context, id string, nu *linkedca.Admin) (*linkedca.Admin, error) { assert.Equals(t, "adminID", id) assert.Equals(t, linkedca.Admin_ADMIN, nu.Type) diff --git a/authority/admin/api/handler.go b/authority/admin/api/handler.go index b3ed04bf..fcdb626b 100644 --- a/authority/admin/api/handler.go +++ b/authority/admin/api/handler.go @@ -9,12 +9,12 @@ import ( // Handler is the Admin API request handler. type Handler struct { db admin.DB - auth api.LinkedAuthority // was: *authority.Authority + auth adminAuthority acmeDB acme.DB } // NewHandler returns a new Authority Config Handler. -func NewHandler(auth api.LinkedAuthority, adminDB admin.DB, acmeDB acme.DB) api.RouterHandler { +func NewHandler(auth adminAuthority, adminDB admin.DB, acmeDB acme.DB) api.RouterHandler { return &Handler{ db: adminDB, auth: auth, diff --git a/authority/admin/api/middleware_test.go b/authority/admin/api/middleware_test.go index 3231cc6d..7fb4671a 100644 --- a/authority/admin/api/middleware_test.go +++ b/authority/admin/api/middleware_test.go @@ -13,7 +13,6 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/smallstep/assert" - "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/authority/admin" "go.step.sm/linkedca" "google.golang.org/protobuf/types/known/timestamppb" @@ -22,7 +21,7 @@ import ( func TestHandler_requireAPIEnabled(t *testing.T) { type test struct { ctx context.Context - auth api.LinkedAuthority + auth adminAuthority next nextHTTP err *admin.Error statusCode int @@ -31,7 +30,7 @@ func TestHandler_requireAPIEnabled(t *testing.T) { "fail/auth.IsAdminAPIEnabled": func(t *testing.T) test { return test{ ctx: context.Background(), - auth: &api.MockAuthority{ + auth: &mockAdminAuthority{ MockIsAdminAPIEnabled: func() bool { return false }, @@ -46,7 +45,7 @@ func TestHandler_requireAPIEnabled(t *testing.T) { } }, "ok": func(t *testing.T) test { - auth := &api.MockAuthority{ + auth := &mockAdminAuthority{ MockIsAdminAPIEnabled: func() bool { return true }, @@ -101,7 +100,7 @@ func TestHandler_requireAPIEnabled(t *testing.T) { func TestHandler_extractAuthorizeTokenAdmin(t *testing.T) { type test struct { ctx context.Context - auth api.LinkedAuthority + auth adminAuthority req *http.Request next nextHTTP err *admin.Error @@ -126,7 +125,7 @@ func TestHandler_extractAuthorizeTokenAdmin(t *testing.T) { "fail/auth.AuthorizeAdminToken": func(t *testing.T) test { req := httptest.NewRequest("GET", "/foo", nil) req.Header["Authorization"] = []string{"token"} - auth := &api.MockAuthority{ + auth := &mockAdminAuthority{ MockAuthorizeAdminToken: func(r *http.Request, token string) (*linkedca.Admin, error) { assert.Equals(t, "token", token) return nil, admin.NewError( @@ -162,7 +161,7 @@ func TestHandler_extractAuthorizeTokenAdmin(t *testing.T) { CreatedAt: timestamppb.New(createdAt), DeletedAt: timestamppb.New(deletedAt), } - auth := &api.MockAuthority{ + auth := &mockAdminAuthority{ MockAuthorizeAdminToken: func(r *http.Request, token string) (*linkedca.Admin, error) { assert.Equals(t, "token", token) return admin, nil diff --git a/authority/admin/api/provisioner_test.go b/authority/admin/api/provisioner_test.go index 68a54fe8..6c463590 100644 --- a/authority/admin/api/provisioner_test.go +++ b/authority/admin/api/provisioner_test.go @@ -15,7 +15,6 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/smallstep/assert" - "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/provisioner" "go.step.sm/linkedca" @@ -26,7 +25,7 @@ import ( func TestHandler_GetProvisioner(t *testing.T) { type test struct { ctx context.Context - auth api.LinkedAuthority + auth adminAuthority db admin.DB req *http.Request statusCode int @@ -38,7 +37,7 @@ func TestHandler_GetProvisioner(t *testing.T) { req := httptest.NewRequest("GET", "/foo?id=provID", nil) chiCtx := chi.NewRouteContext() ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) - auth := &api.MockAuthority{ + auth := &mockAdminAuthority{ MockLoadProvisionerByID: func(id string) (provisioner.Interface, error) { assert.Equals(t, "provID", id) return nil, errors.New("force") @@ -62,7 +61,7 @@ func TestHandler_GetProvisioner(t *testing.T) { chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("name", "provName") ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) - auth := &api.MockAuthority{ + auth := &mockAdminAuthority{ MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { assert.Equals(t, "provName", name) return nil, errors.New("force") @@ -86,7 +85,7 @@ func TestHandler_GetProvisioner(t *testing.T) { chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("name", "provName") ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) - auth := &api.MockAuthority{ + auth := &mockAdminAuthority{ MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { assert.Equals(t, "provName", name) return &provisioner.ACME{ @@ -120,7 +119,7 @@ func TestHandler_GetProvisioner(t *testing.T) { chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("name", "provName") ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) - auth := &api.MockAuthority{ + auth := &mockAdminAuthority{ MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { assert.Equals(t, "provName", name) return &provisioner.ACME{ @@ -198,7 +197,7 @@ func TestHandler_GetProvisioner(t *testing.T) { func TestHandler_GetProvisioners(t *testing.T) { type test struct { ctx context.Context - auth api.LinkedAuthority + auth adminAuthority req *http.Request statusCode int err *admin.Error @@ -221,7 +220,7 @@ func TestHandler_GetProvisioners(t *testing.T) { }, "fail/auth.GetProvisioners": func(t *testing.T) test { req := httptest.NewRequest("GET", "/foo", nil) - auth := &api.MockAuthority{ + auth := &mockAdminAuthority{ MockGetProvisioners: func(cursor string, limit int) (provisioner.List, string, error) { assert.Equals(t, "", cursor) assert.Equals(t, 0, limit) @@ -255,7 +254,7 @@ func TestHandler_GetProvisioners(t *testing.T) { RequireEAB: false, }, } - auth := &api.MockAuthority{ + auth := &mockAdminAuthority{ MockGetProvisioners: func(cursor string, limit int) (provisioner.List, string, error) { assert.Equals(t, "", cursor) assert.Equals(t, 0, limit) @@ -324,7 +323,7 @@ func TestHandler_GetProvisioners(t *testing.T) { func TestHandler_CreateProvisioner(t *testing.T) { type test struct { ctx context.Context - auth api.LinkedAuthority + auth adminAuthority body []byte statusCode int err *admin.Error @@ -357,7 +356,7 @@ func TestHandler_CreateProvisioner(t *testing.T) { } body, err := protojson.Marshal(prov) assert.FatalError(t, err) - auth := &api.MockAuthority{ + auth := &mockAdminAuthority{ MockStoreProvisioner: func(ctx context.Context, prov *linkedca.Provisioner) error { assert.Equals(t, "provID", prov.Id) return errors.New("force") @@ -384,7 +383,7 @@ func TestHandler_CreateProvisioner(t *testing.T) { } body, err := protojson.Marshal(prov) assert.FatalError(t, err) - auth := &api.MockAuthority{ + auth := &mockAdminAuthority{ MockStoreProvisioner: func(ctx context.Context, prov *linkedca.Provisioner) error { assert.Equals(t, "provID", prov.Id) return nil @@ -447,7 +446,7 @@ func TestHandler_CreateProvisioner(t *testing.T) { func TestHandler_DeleteProvisioner(t *testing.T) { type test struct { ctx context.Context - auth api.LinkedAuthority + auth adminAuthority req *http.Request statusCode int err *admin.Error @@ -457,7 +456,7 @@ func TestHandler_DeleteProvisioner(t *testing.T) { req := httptest.NewRequest("DELETE", "/foo?id=provID", nil) chiCtx := chi.NewRouteContext() ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) - auth := &api.MockAuthority{ + auth := &mockAdminAuthority{ MockLoadProvisionerByID: func(id string) (provisioner.Interface, error) { assert.Equals(t, "provID", id) return nil, errors.New("force") @@ -481,7 +480,7 @@ func TestHandler_DeleteProvisioner(t *testing.T) { chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("name", "provName") ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) - auth := &api.MockAuthority{ + auth := &mockAdminAuthority{ MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { assert.Equals(t, "provName", name) return nil, errors.New("force") @@ -505,7 +504,7 @@ func TestHandler_DeleteProvisioner(t *testing.T) { chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("name", "provName") ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) - auth := &api.MockAuthority{ + auth := &mockAdminAuthority{ MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { assert.Equals(t, "provName", name) return &provisioner.OIDC{ @@ -537,7 +536,7 @@ func TestHandler_DeleteProvisioner(t *testing.T) { chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("name", "provName") ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) - auth := &api.MockAuthority{ + auth := &mockAdminAuthority{ MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { assert.Equals(t, "provName", name) return &provisioner.OIDC{ @@ -604,7 +603,7 @@ func TestHandler_DeleteProvisioner(t *testing.T) { func TestHandler_UpdateProvisioner(t *testing.T) { type test struct { ctx context.Context - auth api.LinkedAuthority + auth adminAuthority body []byte db admin.DB statusCode int @@ -637,13 +636,9 @@ func TestHandler_UpdateProvisioner(t *testing.T) { } body, err := protojson.Marshal(prov) assert.FatalError(t, err) - auth := &api.MockAuthority{ + auth := &mockAdminAuthority{ MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { assert.Equals(t, "provName", name) - // return &provisioner.OIDC{ - // ID: "provID", - // Name: "provName", - // }, nil return nil, errors.New("force") }, } @@ -671,7 +666,7 @@ func TestHandler_UpdateProvisioner(t *testing.T) { } body, err := protojson.Marshal(prov) assert.FatalError(t, err) - auth := &api.MockAuthority{ + auth := &mockAdminAuthority{ MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { assert.Equals(t, "provName", name) return &provisioner.OIDC{ @@ -711,7 +706,7 @@ func TestHandler_UpdateProvisioner(t *testing.T) { } body, err := protojson.Marshal(prov) assert.FatalError(t, err) - auth := &api.MockAuthority{ + auth := &mockAdminAuthority{ MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { assert.Equals(t, "provName", name) return &provisioner.OIDC{ @@ -754,7 +749,7 @@ func TestHandler_UpdateProvisioner(t *testing.T) { } body, err := protojson.Marshal(prov) assert.FatalError(t, err) - auth := &api.MockAuthority{ + auth := &mockAdminAuthority{ MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { assert.Equals(t, "provName", name) return &provisioner.OIDC{ @@ -799,7 +794,7 @@ func TestHandler_UpdateProvisioner(t *testing.T) { } body, err := protojson.Marshal(prov) assert.FatalError(t, err) - auth := &api.MockAuthority{ + auth := &mockAdminAuthority{ MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { assert.Equals(t, "provName", name) return &provisioner.OIDC{ @@ -847,7 +842,7 @@ func TestHandler_UpdateProvisioner(t *testing.T) { } body, err := protojson.Marshal(prov) assert.FatalError(t, err) - auth := &api.MockAuthority{ + auth := &mockAdminAuthority{ MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { assert.Equals(t, "provName", name) return &provisioner.OIDC{ @@ -898,7 +893,7 @@ func TestHandler_UpdateProvisioner(t *testing.T) { } body, err := protojson.Marshal(prov) assert.FatalError(t, err) - auth := &api.MockAuthority{ + auth := &mockAdminAuthority{ MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { assert.Equals(t, "provName", name) return &provisioner.OIDC{ @@ -952,7 +947,7 @@ func TestHandler_UpdateProvisioner(t *testing.T) { } body, err := protojson.Marshal(prov) assert.FatalError(t, err) - auth := &api.MockAuthority{ + auth := &mockAdminAuthority{ MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { assert.Equals(t, "provName", name) return &provisioner.OIDC{ @@ -1017,7 +1012,7 @@ func TestHandler_UpdateProvisioner(t *testing.T) { } body, err := protojson.Marshal(prov) assert.FatalError(t, err) - auth := &api.MockAuthority{ + auth := &mockAdminAuthority{ MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { assert.Equals(t, "provName", name) return &provisioner.OIDC{