package provisioner import ( "reflect" "testing" "time" "github.com/pkg/errors" "github.com/smallstep/assert" "go.step.sm/crypto/keyutil" "golang.org/x/crypto/ssh" ) func TestSSHOptions_Type(t *testing.T) { type fields struct { CertType string } tests := []struct { name string fields fields want uint32 }{ {"user", fields{"user"}, 1}, {"host", fields{"host"}, 2}, {"empty", fields{""}, 0}, {"invalid", fields{"invalid"}, 0}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { o := SignSSHOptions{ CertType: tt.fields.CertType, } if got := o.Type(); got != tt.want { t.Errorf("SSHOptions.Type() = %v, want %v", got, tt.want) } }) } } func TestSSHOptions_Modify(t *testing.T) { type test struct { so SignSSHOptions cert *ssh.Certificate valid func(*ssh.Certificate) err error } tests := map[string]func() test{ "fail/unexpected-cert-type": func() test { return test{ so: SignSSHOptions{CertType: "foo"}, cert: new(ssh.Certificate), err: errors.Errorf("ssh certificate has an unknown type 'foo'"), } }, "fail/validAfter-greater-validBefore": func() test { return test{ so: SignSSHOptions{CertType: "user"}, cert: &ssh.Certificate{ValidAfter: uint64(15), ValidBefore: uint64(10)}, err: errors.Errorf("ssh certificate validAfter cannot be greater than validBefore"), } }, "ok/user-cert": func() test { return test{ so: SignSSHOptions{CertType: "user"}, cert: new(ssh.Certificate), valid: func(cert *ssh.Certificate) { assert.Equals(t, cert.CertType, uint32(ssh.UserCert)) }, } }, "ok/host-cert": func() test { return test{ so: SignSSHOptions{CertType: "host"}, cert: new(ssh.Certificate), valid: func(cert *ssh.Certificate) { assert.Equals(t, cert.CertType, uint32(ssh.HostCert)) }, } }, "ok": func() test { va := time.Now().Add(5 * time.Minute) vb := time.Now().Add(1 * time.Hour) so := SignSSHOptions{CertType: "host", KeyID: "foo", Principals: []string{"foo", "bar"}, ValidAfter: NewTimeDuration(va), ValidBefore: NewTimeDuration(vb)} return test{ so: so, cert: new(ssh.Certificate), valid: func(cert *ssh.Certificate) { assert.Equals(t, cert.CertType, uint32(ssh.HostCert)) assert.Equals(t, cert.KeyId, so.KeyID) assert.Equals(t, cert.ValidPrincipals, so.Principals) assert.Equals(t, cert.ValidAfter, uint64(so.ValidAfter.RelativeTime(time.Now()).Unix())) assert.Equals(t, cert.ValidBefore, uint64(so.ValidBefore.RelativeTime(time.Now()).Unix())) }, } }, } for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run() if err := tc.so.Modify(tc.cert, tc.so); err != nil { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } else { if assert.Nil(t, tc.err) { tc.valid(tc.cert) } } }) } } func TestSSHOptions_Match(t *testing.T) { type test struct { so SignSSHOptions cmp SignSSHOptions err error } tests := map[string]func() test{ "fail/cert-type": func() test { return test{ so: SignSSHOptions{CertType: "foo"}, cmp: SignSSHOptions{CertType: "bar"}, err: errors.Errorf("ssh certificate type does not match - got bar, want foo"), } }, "fail/pricipals": func() test { return test{ so: SignSSHOptions{Principals: []string{"foo"}}, cmp: SignSSHOptions{Principals: []string{"bar"}}, err: errors.Errorf("ssh certificate principals does not match - got [bar], want [foo]"), } }, "fail/validAfter": func() test { return test{ so: SignSSHOptions{ValidAfter: NewTimeDuration(time.Now().Add(1 * time.Minute))}, cmp: SignSSHOptions{ValidAfter: NewTimeDuration(time.Now().Add(5 * time.Minute))}, err: errors.Errorf("ssh certificate validAfter does not match"), } }, "fail/validBefore": func() test { return test{ so: SignSSHOptions{ValidBefore: NewTimeDuration(time.Now().Add(1 * time.Minute))}, cmp: SignSSHOptions{ValidBefore: NewTimeDuration(time.Now().Add(5 * time.Minute))}, err: errors.Errorf("ssh certificate validBefore does not match"), } }, "ok/original-empty": func() test { return test{ so: SignSSHOptions{}, cmp: SignSSHOptions{ CertType: "foo", Principals: []string{"foo"}, ValidAfter: NewTimeDuration(time.Now().Add(1 * time.Minute)), ValidBefore: NewTimeDuration(time.Now().Add(5 * time.Minute)), }, } }, "ok/cmp-empty": func() test { return test{ cmp: SignSSHOptions{}, so: SignSSHOptions{ CertType: "foo", Principals: []string{"foo"}, ValidAfter: NewTimeDuration(time.Now().Add(1 * time.Minute)), ValidBefore: NewTimeDuration(time.Now().Add(5 * time.Minute)), }, } }, "ok/equal": func() test { n := time.Now() va := NewTimeDuration(n.Add(1 * time.Minute)) vb := NewTimeDuration(n.Add(5 * time.Minute)) return test{ cmp: SignSSHOptions{ CertType: "foo", Principals: []string{"foo"}, ValidAfter: va, ValidBefore: vb, }, so: SignSSHOptions{ CertType: "foo", Principals: []string{"foo"}, ValidAfter: va, ValidBefore: vb, }, } }, } for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run() if err := tc.so.match(tc.cmp); err != nil { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } else { assert.Nil(t, tc.err) } }) } } func Test_sshCertValidAfterModifier_Modify(t *testing.T) { type test struct { modifier sshCertValidAfterModifier cert *ssh.Certificate expected uint64 } tests := map[string]func() test{ "ok": func() test { return test{ modifier: sshCertValidAfterModifier(15), cert: new(ssh.Certificate), expected: 15, } }, } for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run() if assert.Nil(t, tc.modifier.Modify(tc.cert, SignSSHOptions{})) { assert.Equals(t, tc.cert.ValidAfter, tc.expected) } }) } } func Test_sshCertDefaultValidator_Valid(t *testing.T) { pub, _, err := keyutil.GenerateDefaultKeyPair() assert.FatalError(t, err) sshPub, err := ssh.NewPublicKey(pub) assert.FatalError(t, err) v := sshCertDefaultValidator{} tests := []struct { name string cert *ssh.Certificate err error }{ { "fail/zero-nonce", &ssh.Certificate{}, errors.New("ssh certificate nonce cannot be empty"), }, { "fail/nil-key", &ssh.Certificate{Nonce: []byte("foo")}, errors.New("ssh certificate key cannot be nil"), }, { "fail/zero-serial", &ssh.Certificate{Nonce: []byte("foo"), Key: sshPub}, errors.New("ssh certificate serial cannot be 0"), }, { "fail/unexpected-cert-type", // UserCert = 1, HostCert = 2 &ssh.Certificate{Nonce: []byte("foo"), Key: sshPub, CertType: 3, Serial: 1}, errors.New("ssh certificate has an unknown type '3'"), }, { "fail/empty-cert-key-id", &ssh.Certificate{Nonce: []byte("foo"), Key: sshPub, Serial: 1, CertType: 1}, errors.New("ssh certificate key id cannot be empty"), }, { "fail/zero-validAfter", &ssh.Certificate{ Nonce: []byte("foo"), Key: sshPub, Serial: 1, CertType: 1, KeyId: "foo", ValidPrincipals: []string{"foo"}, ValidAfter: 0, }, errors.New("ssh certificate validAfter cannot be 0"), }, { "fail/validBefore-past", &ssh.Certificate{ Nonce: []byte("foo"), Key: sshPub, Serial: 1, CertType: 1, KeyId: "foo", ValidPrincipals: []string{"foo"}, ValidAfter: uint64(time.Now().Add(-10 * time.Minute).Unix()), ValidBefore: uint64(time.Now().Add(-5 * time.Minute).Unix()), }, errors.New("ssh certificate validBefore cannot be in the past"), }, { "fail/validAfter-after-validBefore", &ssh.Certificate{ Nonce: []byte("foo"), Key: sshPub, Serial: 1, CertType: 1, KeyId: "foo", ValidPrincipals: []string{"foo"}, ValidAfter: uint64(time.Now().Add(15 * time.Minute).Unix()), ValidBefore: uint64(time.Now().Add(10 * time.Minute).Unix()), }, errors.New("ssh certificate validBefore cannot be before validAfter"), }, { "fail/nil-signature-key", &ssh.Certificate{ Nonce: []byte("foo"), Key: sshPub, Serial: 1, CertType: 1, KeyId: "foo", ValidPrincipals: []string{"foo"}, ValidAfter: uint64(time.Now().Unix()), ValidBefore: uint64(time.Now().Add(10 * time.Minute).Unix()), Permissions: ssh.Permissions{ Extensions: map[string]string{"foo": "bar"}, }, }, errors.New("ssh certificate signature key cannot be nil"), }, { "fail/nil-signature", &ssh.Certificate{ Nonce: []byte("foo"), Key: sshPub, Serial: 1, CertType: 1, KeyId: "foo", ValidPrincipals: []string{"foo"}, ValidAfter: uint64(time.Now().Unix()), ValidBefore: uint64(time.Now().Add(10 * time.Minute).Unix()), Permissions: ssh.Permissions{ Extensions: map[string]string{"foo": "bar"}, }, SignatureKey: sshPub, }, errors.New("ssh certificate signature cannot be nil"), }, { "ok/userCert", &ssh.Certificate{ Nonce: []byte("foo"), Key: sshPub, Serial: 1, CertType: 1, KeyId: "foo", ValidPrincipals: []string{"foo"}, ValidAfter: uint64(time.Now().Unix()), ValidBefore: uint64(time.Now().Add(10 * time.Minute).Unix()), Permissions: ssh.Permissions{ Extensions: map[string]string{"foo": "bar"}, }, SignatureKey: sshPub, Signature: &ssh.Signature{}, }, nil, }, { "ok/hostCert", &ssh.Certificate{ Nonce: []byte("foo"), Key: sshPub, Serial: 1, CertType: 2, KeyId: "foo", ValidPrincipals: []string{"foo"}, ValidAfter: uint64(time.Now().Unix()), ValidBefore: uint64(time.Now().Add(10 * time.Minute).Unix()), SignatureKey: sshPub, Signature: &ssh.Signature{}, }, nil, }, { "ok/emptyPrincipals", &ssh.Certificate{ Nonce: []byte("foo"), Key: sshPub, Serial: 1, CertType: 1, KeyId: "foo", ValidPrincipals: []string{}, ValidAfter: uint64(time.Now().Unix()), ValidBefore: uint64(time.Now().Add(10 * time.Minute).Unix()), SignatureKey: sshPub, Signature: &ssh.Signature{}, }, nil, }, { "ok/empty-extensions", &ssh.Certificate{ Nonce: []byte("foo"), Key: sshPub, Serial: 1, CertType: 1, KeyId: "foo", ValidPrincipals: []string{}, ValidAfter: uint64(time.Now().Unix()), ValidBefore: uint64(time.Now().Add(10 * time.Minute).Unix()), SignatureKey: sshPub, Signature: &ssh.Signature{}, }, nil, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if err := v.Valid(tt.cert, SignSSHOptions{}); err != nil { if assert.NotNil(t, tt.err) { assert.HasPrefix(t, err.Error(), tt.err.Error()) } } else { assert.Nil(t, tt.err) } }) } } func Test_sshCertValidityValidator(t *testing.T) { p, err := generateX5C(nil) assert.FatalError(t, err) v := sshCertValidityValidator{p.ctl.Claimer} n := now() tests := []struct { name string cert *ssh.Certificate opts SignSSHOptions err error }{ { "fail/validAfter-0", &ssh.Certificate{CertType: ssh.UserCert}, SignSSHOptions{}, errors.New("ssh certificate validAfter cannot be 0"), }, { "fail/validBefore-in-past", &ssh.Certificate{CertType: ssh.UserCert, ValidAfter: uint64(now().Unix()), ValidBefore: uint64(now().Add(-time.Minute).Unix())}, SignSSHOptions{}, errors.New("ssh certificate validBefore cannot be in the past"), }, { "fail/validBefore-before-validAfter", &ssh.Certificate{CertType: ssh.UserCert, ValidAfter: uint64(now().Add(5 * time.Minute).Unix()), ValidBefore: uint64(now().Add(3 * time.Minute).Unix())}, SignSSHOptions{}, errors.New("ssh certificate validBefore cannot be before validAfter"), }, { "fail/cert-type-not-set", &ssh.Certificate{ValidAfter: uint64(now().Unix()), ValidBefore: uint64(now().Add(10 * time.Minute).Unix())}, SignSSHOptions{}, errors.New("ssh certificate type has not been set"), }, { "fail/unexpected-cert-type", &ssh.Certificate{ CertType: 3, ValidAfter: uint64(now().Unix()), ValidBefore: uint64(now().Add(10 * time.Minute).Unix()), }, SignSSHOptions{}, errors.New("ssh certificate has an unknown type '3'"), }, { "fail/durationmax", &ssh.Certificate{ CertType: 1, ValidAfter: uint64(n.Unix()), ValidBefore: uint64(n.Add(48 * time.Hour).Unix()), }, SignSSHOptions{Backdate: time.Second}, errors.New("requested duration of 48h0m0s is greater than maximum accepted duration for selected provisioner of 24h0m1s"), }, { "ok/duration-exactly-max", &ssh.Certificate{ CertType: 1, ValidAfter: uint64(n.Unix()), ValidBefore: uint64(n.Add(24*time.Hour + time.Second).Unix()), }, SignSSHOptions{Backdate: time.Second}, nil, }, { "ok", &ssh.Certificate{ CertType: 1, ValidAfter: uint64(now().Unix()), ValidBefore: uint64(now().Add(8 * time.Hour).Unix()), }, SignSSHOptions{Backdate: time.Second}, nil, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if err := v.Valid(tt.cert, tt.opts); err != nil { if assert.NotNil(t, tt.err) { assert.HasPrefix(t, err.Error(), tt.err.Error()) } } else { assert.Nil(t, tt.err) } }) } } func Test_sshValidityModifier(t *testing.T) { n, fn := mockNow() defer fn() p, err := generateX5C(nil) assert.FatalError(t, err) type test struct { svm *sshLimitDuration cert *ssh.Certificate valid func(*ssh.Certificate) err error } tests := map[string]func() test{ "fail/type-not-set": func() test { return test{ svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(6 * time.Hour)}, cert: &ssh.Certificate{ ValidAfter: uint64(n.Unix()), ValidBefore: uint64(n.Add(8 * time.Hour).Unix()), }, err: errors.New("ssh certificate type has not been set"), } }, "fail/type-not-recognized": func() test { return test{ svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(6 * time.Hour)}, cert: &ssh.Certificate{ CertType: 4, ValidAfter: uint64(n.Unix()), ValidBefore: uint64(n.Add(8 * time.Hour).Unix()), }, err: errors.New("ssh certificate has an unknown type: 4"), } }, "fail/requested-validAfter-after-limit": func() test { return test{ svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(1 * time.Hour)}, cert: &ssh.Certificate{ CertType: 1, ValidAfter: uint64(n.Add(2 * time.Hour).Unix()), ValidBefore: uint64(n.Add(8 * time.Hour).Unix()), }, err: errors.Errorf("provisioning credential expiration ("), } }, "fail/requested-validBefore-after-limit": func() test { return test{ svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(1 * time.Hour)}, cert: &ssh.Certificate{ CertType: 1, ValidAfter: uint64(n.Unix()), ValidBefore: uint64(n.Add(2 * time.Hour).Unix()), }, err: errors.New("provisioning credential expiration ("), } }, "ok/no-limit": func() test { va, vb := uint64(n.Unix()), uint64(n.Add(16*time.Hour).Unix()) return test{ svm: &sshLimitDuration{Claimer: p.ctl.Claimer}, cert: &ssh.Certificate{ CertType: 1, }, valid: func(cert *ssh.Certificate) { assert.Equals(t, cert.ValidAfter, va) assert.Equals(t, cert.ValidBefore, vb) }, } }, "ok/defaults": func() test { va, vb := uint64(n.Unix()), uint64(n.Add(16*time.Hour).Unix()) return test{ svm: &sshLimitDuration{Claimer: p.ctl.Claimer}, cert: &ssh.Certificate{ CertType: 1, }, valid: func(cert *ssh.Certificate) { assert.Equals(t, cert.ValidAfter, va) assert.Equals(t, cert.ValidBefore, vb) }, } }, "ok/valid-requested-validBefore": func() test { va, vb := uint64(n.Unix()), uint64(n.Add(2*time.Hour).Unix()) return test{ svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(3 * time.Hour)}, cert: &ssh.Certificate{ CertType: 1, ValidAfter: va, ValidBefore: vb, }, valid: func(cert *ssh.Certificate) { assert.Equals(t, cert.ValidAfter, va) assert.Equals(t, cert.ValidBefore, vb) }, } }, "ok/empty-requested-validBefore-limit-after-default": func() test { va := uint64(n.Unix()) return test{ svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(24 * time.Hour)}, cert: &ssh.Certificate{ CertType: 1, ValidAfter: va, }, valid: func(cert *ssh.Certificate) { assert.Equals(t, cert.ValidAfter, va) assert.Equals(t, cert.ValidBefore, uint64(n.Add(16*time.Hour).Unix())) }, } }, "ok/empty-requested-validBefore-limit-before-default": func() test { va := uint64(n.Unix()) return test{ svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(3 * time.Hour)}, cert: &ssh.Certificate{ CertType: 1, ValidAfter: va, }, valid: func(cert *ssh.Certificate) { assert.Equals(t, cert.ValidAfter, va) assert.Equals(t, cert.ValidBefore, uint64(n.Add(3*time.Hour).Unix())) }, } }, } for name, run := range tests { t.Run(name, func(t *testing.T) { tt := run() if err := tt.svm.Modify(tt.cert, SignSSHOptions{}); err != nil { if assert.NotNil(t, tt.err) { assert.HasPrefix(t, err.Error(), tt.err.Error()) } } else { if assert.Nil(t, tt.err) { tt.valid(tt.cert) } } }) } } func Test_sshDefaultDuration_Option(t *testing.T) { tm, fn := mockNow() defer fn() newClaimer := func(claims *Claims) *Claimer { c, err := NewClaimer(claims, globalProvisionerClaims) if err != nil { t.Fatal(err) } return c } unix := func(d time.Duration) uint64 { return uint64(tm.Add(d).Unix()) } type fields struct { Claimer *Claimer } type args struct { o SignSSHOptions cert *ssh.Certificate } tests := []struct { name string fields fields args args want *ssh.Certificate wantErr bool }{ {"user", fields{newClaimer(nil)}, args{SignSSHOptions{}, &ssh.Certificate{CertType: ssh.UserCert}}, &ssh.Certificate{CertType: ssh.UserCert, ValidAfter: unix(0), ValidBefore: unix(16 * time.Hour)}, false}, {"host", fields{newClaimer(nil)}, args{SignSSHOptions{}, &ssh.Certificate{CertType: ssh.HostCert}}, &ssh.Certificate{CertType: ssh.HostCert, ValidAfter: unix(0), ValidBefore: unix(30 * 24 * time.Hour)}, false}, {"user claim", fields{newClaimer(&Claims{DefaultUserSSHDur: &Duration{1 * time.Hour}})}, args{SignSSHOptions{}, &ssh.Certificate{CertType: ssh.UserCert}}, &ssh.Certificate{CertType: ssh.UserCert, ValidAfter: unix(0), ValidBefore: unix(1 * time.Hour)}, false}, {"host claim", fields{newClaimer(&Claims{DefaultHostSSHDur: &Duration{1 * time.Hour}})}, args{SignSSHOptions{}, &ssh.Certificate{CertType: ssh.HostCert}}, &ssh.Certificate{CertType: ssh.HostCert, ValidAfter: unix(0), ValidBefore: unix(1 * time.Hour)}, false}, {"user backdate", fields{newClaimer(nil)}, args{SignSSHOptions{Backdate: 1 * time.Minute}, &ssh.Certificate{CertType: ssh.UserCert}}, &ssh.Certificate{CertType: ssh.UserCert, ValidAfter: unix(-1 * time.Minute), ValidBefore: unix(16 * time.Hour)}, false}, {"host backdate", fields{newClaimer(nil)}, args{SignSSHOptions{Backdate: 1 * time.Minute}, &ssh.Certificate{CertType: ssh.HostCert}}, &ssh.Certificate{CertType: ssh.HostCert, ValidAfter: unix(-1 * time.Minute), ValidBefore: unix(30 * 24 * time.Hour)}, false}, {"user validAfter", fields{newClaimer(nil)}, args{SignSSHOptions{Backdate: 1 * time.Minute}, &ssh.Certificate{CertType: ssh.UserCert, ValidAfter: unix(1 * time.Hour)}}, &ssh.Certificate{CertType: ssh.UserCert, ValidAfter: unix(time.Hour), ValidBefore: unix(17 * time.Hour)}, false}, {"user validBefore", fields{newClaimer(nil)}, args{SignSSHOptions{Backdate: 1 * time.Minute}, &ssh.Certificate{CertType: ssh.UserCert, ValidBefore: unix(1 * time.Hour)}}, &ssh.Certificate{CertType: ssh.UserCert, ValidAfter: unix(-1 * time.Minute), ValidBefore: unix(time.Hour)}, false}, {"host validAfter validBefore", fields{newClaimer(nil)}, args{SignSSHOptions{Backdate: 1 * time.Minute}, &ssh.Certificate{CertType: ssh.HostCert, ValidAfter: unix(1 * time.Minute), ValidBefore: unix(2 * time.Minute)}}, &ssh.Certificate{CertType: ssh.HostCert, ValidAfter: unix(1 * time.Minute), ValidBefore: unix(2 * time.Minute)}, false}, {"fail zero", fields{newClaimer(nil)}, args{SignSSHOptions{}, &ssh.Certificate{}}, &ssh.Certificate{}, true}, {"fail type", fields{newClaimer(nil)}, args{SignSSHOptions{}, &ssh.Certificate{CertType: 3}}, &ssh.Certificate{CertType: 3}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { m := &sshDefaultDuration{ Claimer: tt.fields.Claimer, } if err := m.Modify(tt.args.cert, tt.args.o); (err != nil) != tt.wantErr { t.Errorf("sshDefaultDuration.Option() error = %v, wantErr %v", err, tt.wantErr) } if !reflect.DeepEqual(tt.args.cert, tt.want) { t.Errorf("sshDefaultDuration.Option() = %v, want %v", tt.args.cert, tt.want) } }) } }