package authority import ( "context" "crypto/ecdsa" "crypto/elliptic" "crypto/rand" "crypto/x509" "encoding/base64" "fmt" "net/http" "reflect" "testing" "time" "github.com/pkg/errors" "github.com/smallstep/assert" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/db" "github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/sshutil" "github.com/smallstep/certificates/templates" "github.com/smallstep/cli/jose" "golang.org/x/crypto/ssh" ) type sshTestModifier ssh.Certificate func (m sshTestModifier) Modify(cert *ssh.Certificate) error { if m.CertType != 0 { cert.CertType = m.CertType } if m.KeyId != "" { cert.KeyId = m.KeyId } if m.ValidAfter != 0 { cert.ValidAfter = m.ValidAfter } if m.ValidBefore != 0 { cert.ValidBefore = m.ValidBefore } if len(m.ValidPrincipals) != 0 { cert.ValidPrincipals = m.ValidPrincipals } if m.Permissions.CriticalOptions != nil { cert.Permissions.CriticalOptions = m.Permissions.CriticalOptions } if m.Permissions.Extensions != nil { cert.Permissions.Extensions = m.Permissions.Extensions } return nil } type sshTestCertModifier string func (m sshTestCertModifier) Modify(cert *ssh.Certificate) error { if m == "" { return nil } return fmt.Errorf(string(m)) } type sshTestCertValidator string func (v sshTestCertValidator) Valid(crt *ssh.Certificate, opts provisioner.SSHOptions) error { if v == "" { return nil } return fmt.Errorf(string(v)) } type sshTestOptionsValidator string func (v sshTestOptionsValidator) Valid(opts provisioner.SSHOptions) error { if v == "" { return nil } return fmt.Errorf(string(v)) } type sshTestOptionsModifier string func (m sshTestOptionsModifier) Option(opts provisioner.SSHOptions) provisioner.SSHCertModifier { return sshTestCertModifier(string(m)) } func TestAuthority_SignSSH(t *testing.T) { key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) assert.FatalError(t, err) pub, err := ssh.NewPublicKey(key.Public()) assert.FatalError(t, err) signKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) assert.FatalError(t, err) signer, err := ssh.NewSignerFromKey(signKey) assert.FatalError(t, err) userOptions := sshTestModifier{ CertType: ssh.UserCert, } hostOptions := sshTestModifier{ CertType: ssh.HostCert, } now := time.Now() type fields struct { sshCAUserCertSignKey ssh.Signer sshCAHostCertSignKey ssh.Signer } type args struct { key ssh.PublicKey opts provisioner.SSHOptions signOpts []provisioner.SignOption } type want struct { CertType uint32 Principals []string ValidAfter uint64 ValidBefore uint64 } tests := []struct { name string fields fields args args want want wantErr bool }{ {"ok-user", fields{signer, signer}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{userOptions}}, want{CertType: ssh.UserCert}, false}, {"ok-host", fields{signer, signer}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{hostOptions}}, want{CertType: ssh.HostCert}, false}, {"ok-opts-type-user", fields{signer, signer}, args{pub, provisioner.SSHOptions{CertType: "user"}, []provisioner.SignOption{}}, want{CertType: ssh.UserCert}, false}, {"ok-opts-type-host", fields{signer, signer}, args{pub, provisioner.SSHOptions{CertType: "host"}, []provisioner.SignOption{}}, want{CertType: ssh.HostCert}, false}, {"ok-opts-principals", fields{signer, signer}, args{pub, provisioner.SSHOptions{CertType: "user", Principals: []string{"user"}}, []provisioner.SignOption{}}, want{CertType: ssh.UserCert, Principals: []string{"user"}}, false}, {"ok-opts-principals", fields{signer, signer}, args{pub, provisioner.SSHOptions{CertType: "host", Principals: []string{"foo.test.com", "bar.test.com"}}, []provisioner.SignOption{}}, want{CertType: ssh.HostCert, Principals: []string{"foo.test.com", "bar.test.com"}}, false}, {"ok-opts-valid-after", fields{signer, signer}, args{pub, provisioner.SSHOptions{CertType: "user", ValidAfter: provisioner.NewTimeDuration(now)}, []provisioner.SignOption{}}, want{CertType: ssh.UserCert, ValidAfter: uint64(now.Unix())}, false}, {"ok-opts-valid-before", fields{signer, signer}, args{pub, provisioner.SSHOptions{CertType: "host", ValidBefore: provisioner.NewTimeDuration(now)}, []provisioner.SignOption{}}, want{CertType: ssh.HostCert, ValidBefore: uint64(now.Unix())}, false}, {"ok-cert-validator", fields{signer, signer}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{userOptions, sshTestCertValidator("")}}, want{CertType: ssh.UserCert}, false}, {"ok-cert-modifier", fields{signer, signer}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{userOptions, sshTestCertModifier("")}}, want{CertType: ssh.UserCert}, false}, {"ok-opts-validator", fields{signer, signer}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{userOptions, sshTestOptionsValidator("")}}, want{CertType: ssh.UserCert}, false}, {"ok-opts-modifier", fields{signer, signer}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{userOptions, sshTestOptionsModifier("")}}, want{CertType: ssh.UserCert}, false}, {"fail-opts-type", fields{signer, signer}, args{pub, provisioner.SSHOptions{CertType: "foo"}, []provisioner.SignOption{}}, want{}, true}, {"fail-cert-validator", fields{signer, signer}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{userOptions, sshTestCertValidator("an error")}}, want{}, true}, {"fail-cert-modifier", fields{signer, signer}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{userOptions, sshTestCertModifier("an error")}}, want{}, true}, {"fail-opts-validator", fields{signer, signer}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{userOptions, sshTestOptionsValidator("an error")}}, want{}, true}, {"fail-opts-modifier", fields{signer, signer}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{userOptions, sshTestOptionsModifier("an error")}}, want{}, true}, {"fail-bad-sign-options", fields{signer, signer}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{userOptions, "wrong type"}}, want{}, true}, {"fail-no-user-key", fields{nil, signer}, args{pub, provisioner.SSHOptions{CertType: "user"}, []provisioner.SignOption{}}, want{}, true}, {"fail-no-host-key", fields{signer, nil}, args{pub, provisioner.SSHOptions{CertType: "host"}, []provisioner.SignOption{}}, want{}, true}, {"fail-bad-type", fields{signer, nil}, args{pub, provisioner.SSHOptions{}, []provisioner.SignOption{sshTestModifier{CertType: 0}}}, want{}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := testAuthority(t) a.sshCAUserCertSignKey = tt.fields.sshCAUserCertSignKey a.sshCAHostCertSignKey = tt.fields.sshCAHostCertSignKey got, err := a.SignSSH(context.Background(), tt.args.key, tt.args.opts, tt.args.signOpts...) if (err != nil) != tt.wantErr { t.Errorf("Authority.SignSSH() error = %v, wantErr %v", err, tt.wantErr) return } if err == nil && assert.NotNil(t, got) { assert.Equals(t, tt.want.CertType, got.CertType) assert.Equals(t, tt.want.Principals, got.ValidPrincipals) assert.Equals(t, tt.want.ValidAfter, got.ValidAfter) assert.Equals(t, tt.want.ValidBefore, got.ValidBefore) assert.NotNil(t, got.Key) assert.NotNil(t, got.Nonce) assert.NotEquals(t, 0, got.Serial) assert.NotNil(t, got.Signature) assert.NotNil(t, got.SignatureKey) } }) } } func TestAuthority_SignSSHAddUser(t *testing.T) { key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) assert.FatalError(t, err) pub, err := ssh.NewPublicKey(key.Public()) assert.FatalError(t, err) signKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) assert.FatalError(t, err) signer, err := ssh.NewSignerFromKey(signKey) assert.FatalError(t, err) type fields struct { sshCAUserCertSignKey ssh.Signer sshCAHostCertSignKey ssh.Signer addUserPrincipal string addUserCommand string } type args struct { key ssh.PublicKey subject *ssh.Certificate } type want struct { CertType uint32 Principals []string ValidAfter uint64 ValidBefore uint64 ForceCommand string } now := time.Now() validCert := &ssh.Certificate{ CertType: ssh.UserCert, ValidPrincipals: []string{"user"}, ValidAfter: uint64(now.Unix()), ValidBefore: uint64(now.Add(time.Hour).Unix()), } validWant := want{ CertType: ssh.UserCert, Principals: []string{"provisioner"}, ValidAfter: uint64(now.Unix()), ValidBefore: uint64(now.Add(time.Hour).Unix()), ForceCommand: "sudo useradd -m user; nc -q0 localhost 22", } tests := []struct { name string fields fields args args want want wantErr bool }{ {"ok", fields{signer, signer, "", ""}, args{pub, validCert}, validWant, false}, {"ok-no-host-key", fields{signer, nil, "", ""}, args{pub, validCert}, validWant, false}, {"ok-custom-principal", fields{signer, signer, "my-principal", ""}, args{pub, &ssh.Certificate{CertType: ssh.UserCert, ValidPrincipals: []string{"user"}}}, want{CertType: ssh.UserCert, Principals: []string{"my-principal"}, ForceCommand: "sudo useradd -m user; nc -q0 localhost 22"}, false}, {"ok-custom-command", fields{signer, signer, "", "foo "}, args{pub, &ssh.Certificate{CertType: ssh.UserCert, ValidPrincipals: []string{"user"}}}, want{CertType: ssh.UserCert, Principals: []string{"provisioner"}, ForceCommand: "foo user user"}, false}, {"ok-custom-principal-and-command", fields{signer, signer, "my-principal", "foo "}, args{pub, &ssh.Certificate{CertType: ssh.UserCert, ValidPrincipals: []string{"user"}}}, want{CertType: ssh.UserCert, Principals: []string{"my-principal"}, ForceCommand: "foo user user"}, false}, {"fail-no-user-key", fields{nil, signer, "", ""}, args{pub, validCert}, want{}, true}, {"fail-no-user-cert", fields{signer, signer, "", ""}, args{pub, &ssh.Certificate{CertType: ssh.HostCert, ValidPrincipals: []string{"foo"}}}, want{}, true}, {"fail-no-principals", fields{signer, signer, "", ""}, args{pub, &ssh.Certificate{CertType: ssh.UserCert, ValidPrincipals: []string{}}}, want{}, true}, {"fail-many-principals", fields{signer, signer, "", ""}, args{pub, &ssh.Certificate{CertType: ssh.UserCert, ValidPrincipals: []string{"foo", "bar"}}}, want{}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := testAuthority(t) a.sshCAUserCertSignKey = tt.fields.sshCAUserCertSignKey a.sshCAHostCertSignKey = tt.fields.sshCAHostCertSignKey a.config.SSH = &SSHConfig{ AddUserPrincipal: tt.fields.addUserPrincipal, AddUserCommand: tt.fields.addUserCommand, } got, err := a.SignSSHAddUser(context.Background(), tt.args.key, tt.args.subject) if (err != nil) != tt.wantErr { t.Errorf("Authority.SignSSHAddUser() error = %v, wantErr %v", err, tt.wantErr) return } if err == nil && assert.NotNil(t, got) { assert.Equals(t, tt.want.CertType, got.CertType) assert.Equals(t, tt.want.Principals, got.ValidPrincipals) assert.Equals(t, tt.args.subject.ValidPrincipals[0]+"-"+tt.want.Principals[0], got.KeyId) assert.Equals(t, tt.want.ValidAfter, got.ValidAfter) assert.Equals(t, tt.want.ValidBefore, got.ValidBefore) assert.Equals(t, map[string]string{"force-command": tt.want.ForceCommand}, got.CriticalOptions) assert.Equals(t, nil, got.Extensions) assert.NotNil(t, got.Key) assert.NotNil(t, got.Nonce) assert.NotEquals(t, 0, got.Serial) assert.NotNil(t, got.Signature) assert.NotNil(t, got.SignatureKey) } }) } } func TestAuthority_GetSSHRoots(t *testing.T) { key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) assert.FatalError(t, err) user, err := ssh.NewPublicKey(key.Public()) assert.FatalError(t, err) key, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) assert.FatalError(t, err) host, err := ssh.NewPublicKey(key.Public()) assert.FatalError(t, err) type fields struct { sshCAUserCerts []ssh.PublicKey sshCAHostCerts []ssh.PublicKey } tests := []struct { name string fields fields want *SSHKeys wantErr bool }{ {"ok", fields{[]ssh.PublicKey{user}, []ssh.PublicKey{host}}, &SSHKeys{UserKeys: []ssh.PublicKey{user}, HostKeys: []ssh.PublicKey{host}}, false}, {"nil", fields{}, &SSHKeys{}, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := testAuthority(t) a.sshCAUserCerts = tt.fields.sshCAUserCerts a.sshCAHostCerts = tt.fields.sshCAHostCerts got, err := a.GetSSHRoots(context.Background()) if (err != nil) != tt.wantErr { t.Errorf("Authority.GetSSHRoots() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("Authority.GetSSHRoots() = %v, want %v", got, tt.want) } }) } } func TestAuthority_GetSSHFederation(t *testing.T) { key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) assert.FatalError(t, err) user, err := ssh.NewPublicKey(key.Public()) assert.FatalError(t, err) key, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) assert.FatalError(t, err) host, err := ssh.NewPublicKey(key.Public()) assert.FatalError(t, err) type fields struct { sshCAUserFederatedCerts []ssh.PublicKey sshCAHostFederatedCerts []ssh.PublicKey } tests := []struct { name string fields fields want *SSHKeys wantErr bool }{ {"ok", fields{[]ssh.PublicKey{user}, []ssh.PublicKey{host}}, &SSHKeys{UserKeys: []ssh.PublicKey{user}, HostKeys: []ssh.PublicKey{host}}, false}, {"nil", fields{}, &SSHKeys{}, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := testAuthority(t) a.sshCAUserFederatedCerts = tt.fields.sshCAUserFederatedCerts a.sshCAHostFederatedCerts = tt.fields.sshCAHostFederatedCerts got, err := a.GetSSHFederation(context.Background()) if (err != nil) != tt.wantErr { t.Errorf("Authority.GetSSHFederation() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("Authority.GetSSHFederation() = %v, want %v", got, tt.want) } }) } } func TestAuthority_GetSSHConfig(t *testing.T) { key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) assert.FatalError(t, err) user, err := ssh.NewPublicKey(key.Public()) assert.FatalError(t, err) userSigner, err := ssh.NewSignerFromSigner(key) assert.FatalError(t, err) userB64 := base64.StdEncoding.EncodeToString(user.Marshal()) key, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) assert.FatalError(t, err) host, err := ssh.NewPublicKey(key.Public()) assert.FatalError(t, err) hostSigner, err := ssh.NewSignerFromSigner(key) assert.FatalError(t, err) hostB64 := base64.StdEncoding.EncodeToString(host.Marshal()) tmplConfig := &templates.Templates{ SSH: &templates.SSHTemplates{ User: []templates.Template{ {Name: "known_host.tpl", Type: templates.File, TemplatePath: "./testdata/templates/known_hosts.tpl", Path: "ssh/known_host", Comment: "#"}, }, Host: []templates.Template{ {Name: "ca.tpl", Type: templates.File, TemplatePath: "./testdata/templates/ca.tpl", Path: "/etc/ssh/ca.pub", Comment: "#"}, }, }, Data: map[string]interface{}{ "Step": &templates.Step{ SSH: templates.StepSSH{ UserKey: user, HostKey: host, }, }, }, } userOutput := []templates.Output{ {Name: "known_host.tpl", Type: templates.File, Comment: "#", Path: "ssh/known_host", Content: []byte(fmt.Sprintf("@cert-authority * %s %s", host.Type(), hostB64))}, } hostOutput := []templates.Output{ {Name: "ca.tpl", Type: templates.File, Comment: "#", Path: "/etc/ssh/ca.pub", Content: []byte(user.Type() + " " + userB64)}, } tmplConfigWithUserData := &templates.Templates{ SSH: &templates.SSHTemplates{ User: []templates.Template{ {Name: "include.tpl", Type: templates.File, TemplatePath: "./testdata/templates/include.tpl", Path: "ssh/include", Comment: "#"}, {Name: "config.tpl", Type: templates.File, TemplatePath: "./testdata/templates/config.tpl", Path: "ssh/config", Comment: "#"}, }, Host: []templates.Template{ {Name: "sshd_config.tpl", Type: templates.File, TemplatePath: "./testdata/templates/sshd_config.tpl", Path: "/etc/ssh/sshd_config", Comment: "#"}, }, }, Data: map[string]interface{}{ "Step": &templates.Step{ SSH: templates.StepSSH{ UserKey: user, HostKey: host, }, }, }, } userOutputWithUserData := []templates.Output{ {Name: "include.tpl", Type: templates.File, Comment: "#", Path: "ssh/include", Content: []byte("Host *\n\tInclude /home/user/.step/ssh/config")}, {Name: "config.tpl", Type: templates.File, Comment: "#", Path: "ssh/config", Content: []byte("Match exec \"step ssh check-host %h\"\n\tUserKnownHostsFile /home/user/.step/ssh/known_hosts\n\tProxyCommand step ssh proxycommand %r %h %p\n")}, } hostOutputWithUserData := []templates.Output{ {Name: "sshd_config.tpl", Type: templates.File, Comment: "#", Path: "/etc/ssh/sshd_config", Content: []byte("TrustedUserCAKeys /etc/ssh/ca.pub\nHostCertificate /etc/ssh/ssh_host_ecdsa_key-cert.pub\nHostKey /etc/ssh/ssh_host_ecdsa_key")}, } tmplConfigErr := &templates.Templates{ SSH: &templates.SSHTemplates{ User: []templates.Template{ {Name: "error.tpl", Type: templates.File, TemplatePath: "./testdata/templates/error.tpl", Path: "ssh/error", Comment: "#"}, }, Host: []templates.Template{ {Name: "error.tpl", Type: templates.File, TemplatePath: "./testdata/templates/error.tpl", Path: "ssh/error", Comment: "#"}, }, }, } type fields struct { templates *templates.Templates userSigner ssh.Signer hostSigner ssh.Signer } type args struct { typ string data map[string]string } tests := []struct { name string fields fields args args want []templates.Output wantErr bool }{ {"user", fields{tmplConfig, userSigner, hostSigner}, args{"user", nil}, userOutput, false}, {"user", fields{tmplConfig, userSigner, nil}, args{"user", nil}, userOutput, false}, {"host", fields{tmplConfig, userSigner, hostSigner}, args{"host", nil}, hostOutput, false}, {"host", fields{tmplConfig, nil, hostSigner}, args{"host", nil}, hostOutput, false}, {"userWithData", fields{tmplConfigWithUserData, userSigner, hostSigner}, args{"user", map[string]string{"StepPath": "/home/user/.step"}}, userOutputWithUserData, false}, {"hostWithData", fields{tmplConfigWithUserData, userSigner, hostSigner}, args{"host", map[string]string{"Certificate": "ssh_host_ecdsa_key-cert.pub", "Key": "ssh_host_ecdsa_key"}}, hostOutputWithUserData, false}, {"disabled", fields{tmplConfig, nil, nil}, args{"host", nil}, nil, true}, {"badType", fields{tmplConfig, userSigner, hostSigner}, args{"bad", nil}, nil, true}, {"userError", fields{tmplConfigErr, userSigner, hostSigner}, args{"user", nil}, nil, true}, {"hostError", fields{tmplConfigErr, userSigner, hostSigner}, args{"host", map[string]string{"Function": "foo"}}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := testAuthority(t) a.config.Templates = tt.fields.templates a.sshCAUserCertSignKey = tt.fields.userSigner a.sshCAHostCertSignKey = tt.fields.hostSigner got, err := a.GetSSHConfig(context.Background(), tt.args.typ, tt.args.data) if (err != nil) != tt.wantErr { t.Errorf("Authority.GetSSHConfig() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("Authority.GetSSHConfig() = %v, want %v", got, tt.want) } }) } } func TestAuthority_CheckSSHHost(t *testing.T) { type fields struct { exists bool err error } type args struct { ctx context.Context principal string token string } tests := []struct { name string fields fields args args want bool wantErr bool }{ {"true", fields{true, nil}, args{context.Background(), "foo.internal.com", ""}, true, false}, {"false", fields{false, nil}, args{context.Background(), "foo.internal.com", ""}, false, false}, {"notImplemented", fields{false, db.ErrNotImplemented}, args{context.Background(), "foo.internal.com", ""}, false, true}, {"notImplemented", fields{true, db.ErrNotImplemented}, args{context.Background(), "foo.internal.com", ""}, false, true}, {"internal", fields{false, fmt.Errorf("an error")}, args{context.Background(), "foo.internal.com", ""}, false, true}, {"internal", fields{true, fmt.Errorf("an error")}, args{context.Background(), "foo.internal.com", ""}, false, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := testAuthority(t) a.db = &db.MockAuthDB{ MIsSSHHost: func(_ string) (bool, error) { return tt.fields.exists, tt.fields.err }, } got, err := a.CheckSSHHost(tt.args.ctx, tt.args.principal, tt.args.token) if (err != nil) != tt.wantErr { t.Errorf("Authority.CheckSSHHost() error = %v, wantErr %v", err, tt.wantErr) return } if got != tt.want { t.Errorf("Authority.CheckSSHHost() = %v, want %v", got, tt.want) } }) } } func TestSSHConfig_Validate(t *testing.T) { key, err := jose.GenerateJWK("EC", "P-256", "", "sig", "", 0) assert.FatalError(t, err) tests := []struct { name string sshConfig *SSHConfig wantErr bool }{ {"nil", nil, false}, {"ok", &SSHConfig{Keys: []*SSHPublicKey{{Type: "user", Key: key.Public()}}}, false}, {"ok", &SSHConfig{Keys: []*SSHPublicKey{{Type: "host", Key: key.Public()}}}, false}, {"badType", &SSHConfig{Keys: []*SSHPublicKey{{Type: "bad", Key: key.Public()}}}, true}, {"badKey", &SSHConfig{Keys: []*SSHPublicKey{{Type: "user", Key: *key}}}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if err := tt.sshConfig.Validate(); (err != nil) != tt.wantErr { t.Errorf("SSHConfig.Validate() error = %v, wantErr %v", err, tt.wantErr) } }) } } func TestSSHPublicKey_Validate(t *testing.T) { key, err := jose.GenerateJWK("EC", "P-256", "", "sig", "", 0) assert.FatalError(t, err) type fields struct { Type string Federated bool Key jose.JSONWebKey } tests := []struct { name string fields fields wantErr bool }{ {"user", fields{"user", true, key.Public()}, false}, {"host", fields{"host", false, key.Public()}, false}, {"empty", fields{"", true, key.Public()}, true}, {"badType", fields{"bad", false, key.Public()}, true}, {"badKey", fields{"user", false, *key}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { k := &SSHPublicKey{ Type: tt.fields.Type, Federated: tt.fields.Federated, Key: tt.fields.Key, } if err := k.Validate(); (err != nil) != tt.wantErr { t.Errorf("SSHPublicKey.Validate() error = %v, wantErr %v", err, tt.wantErr) } }) } } func TestSSHPublicKey_PublicKey(t *testing.T) { key, err := jose.GenerateJWK("EC", "P-256", "", "sig", "", 0) assert.FatalError(t, err) pub, err := ssh.NewPublicKey(key.Public().Key) assert.FatalError(t, err) type fields struct { publicKey ssh.PublicKey } tests := []struct { name string fields fields want ssh.PublicKey }{ {"ok", fields{pub}, pub}, {"nil", fields{nil}, nil}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { k := &SSHPublicKey{ publicKey: tt.fields.publicKey, } if got := k.PublicKey(); !reflect.DeepEqual(got, tt.want) { t.Errorf("SSHPublicKey.PublicKey() = %v, want %v", got, tt.want) } }) } } func TestAuthority_GetSSHBastion(t *testing.T) { bastion := &Bastion{ Hostname: "bastion.local", Port: "2222", } type fields struct { config *Config sshBastionFunc func(ctx context.Context, user, hostname string) (*Bastion, error) } type args struct { user string hostname string } tests := []struct { name string fields fields args args want *Bastion wantErr bool }{ {"config", fields{&Config{SSH: &SSHConfig{Bastion: bastion}}, nil}, args{"user", "host.local"}, bastion, false}, {"nil", fields{&Config{SSH: &SSHConfig{Bastion: nil}}, nil}, args{"user", "host.local"}, nil, false}, {"empty", fields{&Config{SSH: &SSHConfig{Bastion: &Bastion{}}}, nil}, args{"user", "host.local"}, nil, false}, {"func", fields{&Config{}, func(_ context.Context, _, _ string) (*Bastion, error) { return bastion, nil }}, args{"user", "host.local"}, bastion, false}, {"func err", fields{&Config{}, func(_ context.Context, _, _ string) (*Bastion, error) { return nil, errors.New("foo") }}, args{"user", "host.local"}, nil, true}, {"error", fields{&Config{SSH: nil}, nil}, args{"user", "host.local"}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := &Authority{ config: tt.fields.config, sshBastionFunc: tt.fields.sshBastionFunc, } got, err := a.GetSSHBastion(context.Background(), tt.args.user, tt.args.hostname) if (err != nil) != tt.wantErr { t.Errorf("Authority.GetSSHBastion() error = %v, wantErr %v", err, tt.wantErr) return } else if err != nil { _, ok := err.(errs.StatusCoder) assert.Fatal(t, ok, "error does not implement StatusCoder interface") } if !reflect.DeepEqual(got, tt.want) { t.Errorf("Authority.GetSSHBastion() = %v, want %v", got, tt.want) } }) } } func TestAuthority_GetSSHHosts(t *testing.T) { a := testAuthority(t) type test struct { getHostsFunc func(context.Context, *x509.Certificate) ([]sshutil.Host, error) auth *Authority cert *x509.Certificate cmp func(got []sshutil.Host) err error code int } tests := map[string]func(t *testing.T) *test{ "fail/getHostsFunc-fail": func(t *testing.T) *test { return &test{ getHostsFunc: func(ctx context.Context, cert *x509.Certificate) ([]sshutil.Host, error) { return nil, errors.New("force") }, cert: &x509.Certificate{}, err: errors.New("getSSHHosts: force"), code: http.StatusInternalServerError, } }, "ok/getHostsFunc-defined": func(t *testing.T) *test { hosts := []sshutil.Host{ {HostID: "1", Hostname: "foo"}, {HostID: "2", Hostname: "bar"}, } return &test{ getHostsFunc: func(ctx context.Context, cert *x509.Certificate) ([]sshutil.Host, error) { return hosts, nil }, cert: &x509.Certificate{}, cmp: func(got []sshutil.Host) { assert.Equals(t, got, hosts) }, } }, "fail/db-get-fail": func(t *testing.T) *test { return &test{ auth: testAuthority(t, WithDatabase(&db.MockAuthDB{ MGetSSHHostPrincipals: func() ([]string, error) { return nil, errors.New("force") }, })), cert: &x509.Certificate{}, err: errors.New("getSSHHosts: force"), code: http.StatusInternalServerError, } }, "ok": func(t *testing.T) *test { return &test{ auth: testAuthority(t, WithDatabase(&db.MockAuthDB{ MGetSSHHostPrincipals: func() ([]string, error) { return []string{"foo", "bar"}, nil }, })), cert: &x509.Certificate{}, cmp: func(got []sshutil.Host) { assert.Equals(t, got, []sshutil.Host{ {Hostname: "foo"}, {Hostname: "bar"}, }) }, } }, } for name, genTestCase := range tests { t.Run(name, func(t *testing.T) { tc := genTestCase(t) auth := tc.auth if auth == nil { auth = a } auth.sshGetHostsFunc = tc.getHostsFunc hosts, err := auth.GetSSHHosts(context.Background(), tc.cert) if err != nil { if assert.NotNil(t, tc.err) { sc, ok := err.(errs.StatusCoder) assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } } else { if assert.Nil(t, tc.err) { tc.cmp(hosts) } } }) } } func TestAuthority_RekeySSH(t *testing.T) { key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) assert.FatalError(t, err) pub, err := ssh.NewPublicKey(key.Public()) assert.FatalError(t, err) signKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) assert.FatalError(t, err) signer, err := ssh.NewSignerFromKey(signKey) assert.FatalError(t, err) userOptions := sshTestModifier{ CertType: ssh.UserCert, } now := time.Now().UTC() a := testAuthority(t) type test struct { auth *Authority userSigner ssh.Signer hostSigner ssh.Signer cert *ssh.Certificate key ssh.PublicKey signOpts []provisioner.SignOption cmpResult func(old, n *ssh.Certificate) err error code int } tests := map[string]func(t *testing.T) *test{ "fail/opts-type": func(t *testing.T) *test { return &test{ userSigner: signer, hostSigner: signer, key: pub, signOpts: []provisioner.SignOption{userOptions}, err: errors.New("rekeySSH; invalid extra option type"), code: http.StatusInternalServerError, } }, "fail/old-cert-validAfter": func(t *testing.T) *test { return &test{ userSigner: signer, hostSigner: signer, cert: &ssh.Certificate{}, key: pub, signOpts: []provisioner.SignOption{}, err: errors.New("rekeySSH; cannot rekey certificate without validity period"), code: http.StatusBadRequest, } }, "fail/old-cert-validBefore": func(t *testing.T) *test { return &test{ userSigner: signer, hostSigner: signer, cert: &ssh.Certificate{ValidAfter: uint64(now.Unix())}, key: pub, signOpts: []provisioner.SignOption{}, err: errors.New("rekeySSH; cannot rekey certificate without validity period"), code: http.StatusBadRequest, } }, "fail/old-cert-no-user-key": func(t *testing.T) *test { return &test{ userSigner: nil, hostSigner: signer, cert: &ssh.Certificate{ValidAfter: uint64(now.Unix()), ValidBefore: uint64(now.Add(10 * time.Minute).Unix()), CertType: ssh.UserCert}, key: pub, signOpts: []provisioner.SignOption{}, err: errors.New("rekeySSH; user certificate signing is not enabled"), code: http.StatusNotImplemented, } }, "fail/old-cert-no-host-key": func(t *testing.T) *test { return &test{ userSigner: signer, hostSigner: nil, cert: &ssh.Certificate{ValidAfter: uint64(now.Unix()), ValidBefore: uint64(now.Add(10 * time.Minute).Unix()), CertType: ssh.HostCert}, key: pub, signOpts: []provisioner.SignOption{}, err: errors.New("rekeySSH; host certificate signing is not enabled"), code: http.StatusNotImplemented, } }, "fail/unexpected-old-cert-type": func(t *testing.T) *test { return &test{ userSigner: signer, hostSigner: signer, cert: &ssh.Certificate{ValidAfter: uint64(now.Unix()), ValidBefore: uint64(now.Add(10 * time.Minute).Unix()), CertType: 0}, key: pub, signOpts: []provisioner.SignOption{}, err: errors.New("rekeySSH; unexpected ssh certificate type: 0"), code: http.StatusBadRequest, } }, "fail/db-store": func(t *testing.T) *test { return &test{ auth: testAuthority(t, WithDatabase(&db.MockAuthDB{ MStoreSSHCertificate: func(cert *ssh.Certificate) error { return errors.New("force") }, })), userSigner: signer, hostSigner: nil, cert: &ssh.Certificate{ValidAfter: uint64(now.Unix()), ValidBefore: uint64(now.Add(10 * time.Minute).Unix()), CertType: ssh.UserCert}, key: pub, signOpts: []provisioner.SignOption{}, err: errors.New("rekeySSH; error storing certificate in db: force"), code: http.StatusInternalServerError, } }, "ok": func(t *testing.T) *test { va1 := now.Add(-24 * time.Hour) vb1 := now.Add(-23 * time.Hour) return &test{ userSigner: signer, hostSigner: nil, cert: &ssh.Certificate{ ValidAfter: uint64(va1.Unix()), ValidBefore: uint64(vb1.Unix()), CertType: ssh.UserCert, ValidPrincipals: []string{"foo", "bar"}, KeyId: "foo", }, key: pub, signOpts: []provisioner.SignOption{}, cmpResult: func(old, n *ssh.Certificate) { assert.Equals(t, n.CertType, old.CertType) assert.Equals(t, n.ValidPrincipals, old.ValidPrincipals) assert.Equals(t, n.KeyId, old.KeyId) assert.True(t, n.ValidAfter > uint64(now.Add(-5*time.Minute).Unix())) assert.True(t, n.ValidAfter < uint64(now.Add(5*time.Minute).Unix())) l8r := now.Add(1 * time.Hour) assert.True(t, n.ValidBefore > uint64(l8r.Add(-5*time.Minute).Unix())) assert.True(t, n.ValidBefore < uint64(l8r.Add(5*time.Minute).Unix())) }, } }, } for name, genTestCase := range tests { t.Run(name, func(t *testing.T) { tc := genTestCase(t) auth := tc.auth if auth == nil { auth = a } a.sshCAUserCertSignKey = tc.userSigner a.sshCAHostCertSignKey = tc.hostSigner cert, err := auth.RekeySSH(context.Background(), tc.cert, tc.key, tc.signOpts...) if err != nil { if assert.NotNil(t, tc.err) { sc, ok := err.(errs.StatusCoder) assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } } else { if assert.Nil(t, tc.err) { tc.cmpResult(tc.cert, cert) } } }) } }