|
|
|
@ -1,6 +1,8 @@
|
|
|
|
|
package provisioner
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
"fmt"
|
|
|
|
|
"reflect"
|
|
|
|
|
"testing"
|
|
|
|
|
"time"
|
|
|
|
|
|
|
|
|
@ -10,6 +12,32 @@ import (
|
|
|
|
|
"golang.org/x/crypto/ssh"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
func TestSSHOptions_Type(t *testing.T) {
|
|
|
|
|
type fields struct {
|
|
|
|
|
CertType string
|
|
|
|
|
}
|
|
|
|
|
tests := []struct {
|
|
|
|
|
name string
|
|
|
|
|
fields fields
|
|
|
|
|
want uint32
|
|
|
|
|
}{
|
|
|
|
|
{"user", fields{"user"}, 1},
|
|
|
|
|
{"host", fields{"host"}, 2},
|
|
|
|
|
{"empty", fields{""}, 0},
|
|
|
|
|
{"invalid", fields{"invalid"}, 0},
|
|
|
|
|
}
|
|
|
|
|
for _, tt := range tests {
|
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
|
o := SSHOptions{
|
|
|
|
|
CertType: tt.fields.CertType,
|
|
|
|
|
}
|
|
|
|
|
if got := o.Type(); got != tt.want {
|
|
|
|
|
t.Errorf("SSHOptions.Type() = %v, want %v", got, tt.want)
|
|
|
|
|
}
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func Test_sshCertificateDefaultValidator_Valid(t *testing.T) {
|
|
|
|
|
pub, _, err := keys.GenerateDefaultKeyPair()
|
|
|
|
|
assert.FatalError(t, err)
|
|
|
|
@ -272,11 +300,13 @@ func Test_sshCertificateValidityValidator(t *testing.T) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func Test_sshValidityModifier(t *testing.T) {
|
|
|
|
|
n := now()
|
|
|
|
|
n, fn := mockNow()
|
|
|
|
|
defer fn()
|
|
|
|
|
|
|
|
|
|
p, err := generateX5C(nil)
|
|
|
|
|
assert.FatalError(t, err)
|
|
|
|
|
type test struct {
|
|
|
|
|
svm *sshValidityModifier
|
|
|
|
|
svm *sshLimitDuration
|
|
|
|
|
cert *ssh.Certificate
|
|
|
|
|
valid func(*ssh.Certificate)
|
|
|
|
|
err error
|
|
|
|
@ -284,7 +314,7 @@ func Test_sshValidityModifier(t *testing.T) {
|
|
|
|
|
tests := map[string]func() test{
|
|
|
|
|
"fail/type-not-set": func() test {
|
|
|
|
|
return test{
|
|
|
|
|
svm: &sshValidityModifier{Claimer: p.claimer, validBefore: n.Add(6 * time.Hour)},
|
|
|
|
|
svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(6 * time.Hour)},
|
|
|
|
|
cert: &ssh.Certificate{
|
|
|
|
|
ValidAfter: uint64(n.Unix()),
|
|
|
|
|
ValidBefore: uint64(n.Add(8 * time.Hour).Unix()),
|
|
|
|
@ -294,18 +324,18 @@ func Test_sshValidityModifier(t *testing.T) {
|
|
|
|
|
},
|
|
|
|
|
"fail/type-not-recognized": func() test {
|
|
|
|
|
return test{
|
|
|
|
|
svm: &sshValidityModifier{Claimer: p.claimer, validBefore: n.Add(6 * time.Hour)},
|
|
|
|
|
svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(6 * time.Hour)},
|
|
|
|
|
cert: &ssh.Certificate{
|
|
|
|
|
CertType: 4,
|
|
|
|
|
ValidAfter: uint64(n.Unix()),
|
|
|
|
|
ValidBefore: uint64(n.Add(8 * time.Hour).Unix()),
|
|
|
|
|
},
|
|
|
|
|
err: errors.New("unknown ssh certificate type 4"),
|
|
|
|
|
err: errors.New("ssh certificate has an unknown type: 4"),
|
|
|
|
|
}
|
|
|
|
|
},
|
|
|
|
|
"fail/requested-validAfter-after-limit": func() test {
|
|
|
|
|
return test{
|
|
|
|
|
svm: &sshValidityModifier{Claimer: p.claimer, validBefore: n.Add(1 * time.Hour)},
|
|
|
|
|
svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(1 * time.Hour)},
|
|
|
|
|
cert: &ssh.Certificate{
|
|
|
|
|
CertType: 1,
|
|
|
|
|
ValidAfter: uint64(n.Add(2 * time.Hour).Unix()),
|
|
|
|
@ -316,7 +346,7 @@ func Test_sshValidityModifier(t *testing.T) {
|
|
|
|
|
},
|
|
|
|
|
"fail/requested-validBefore-after-limit": func() test {
|
|
|
|
|
return test{
|
|
|
|
|
svm: &sshValidityModifier{Claimer: p.claimer, validBefore: n.Add(1 * time.Hour)},
|
|
|
|
|
svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(1 * time.Hour)},
|
|
|
|
|
cert: &ssh.Certificate{
|
|
|
|
|
CertType: 1,
|
|
|
|
|
ValidAfter: uint64(n.Unix()),
|
|
|
|
@ -325,10 +355,36 @@ func Test_sshValidityModifier(t *testing.T) {
|
|
|
|
|
err: errors.New("provisioning credential expiration ("),
|
|
|
|
|
}
|
|
|
|
|
},
|
|
|
|
|
"ok/no-limit": func() test {
|
|
|
|
|
va, vb := uint64(n.Unix()), uint64(n.Add(16*time.Hour).Unix())
|
|
|
|
|
return test{
|
|
|
|
|
svm: &sshLimitDuration{Claimer: p.claimer},
|
|
|
|
|
cert: &ssh.Certificate{
|
|
|
|
|
CertType: 1,
|
|
|
|
|
},
|
|
|
|
|
valid: func(cert *ssh.Certificate) {
|
|
|
|
|
assert.Equals(t, cert.ValidAfter, va)
|
|
|
|
|
assert.Equals(t, cert.ValidBefore, vb)
|
|
|
|
|
},
|
|
|
|
|
}
|
|
|
|
|
},
|
|
|
|
|
"ok/defaults": func() test {
|
|
|
|
|
va, vb := uint64(n.Unix()), uint64(n.Add(16*time.Hour).Unix())
|
|
|
|
|
return test{
|
|
|
|
|
svm: &sshLimitDuration{Claimer: p.claimer},
|
|
|
|
|
cert: &ssh.Certificate{
|
|
|
|
|
CertType: 1,
|
|
|
|
|
},
|
|
|
|
|
valid: func(cert *ssh.Certificate) {
|
|
|
|
|
assert.Equals(t, cert.ValidAfter, va)
|
|
|
|
|
assert.Equals(t, cert.ValidBefore, vb)
|
|
|
|
|
},
|
|
|
|
|
}
|
|
|
|
|
},
|
|
|
|
|
"ok/valid-requested-validBefore": func() test {
|
|
|
|
|
va, vb := uint64(n.Unix()), uint64(n.Add(2*time.Hour).Unix())
|
|
|
|
|
return test{
|
|
|
|
|
svm: &sshValidityModifier{Claimer: p.claimer, validBefore: n.Add(3 * time.Hour)},
|
|
|
|
|
svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(3 * time.Hour)},
|
|
|
|
|
cert: &ssh.Certificate{
|
|
|
|
|
CertType: 1,
|
|
|
|
|
ValidAfter: va,
|
|
|
|
@ -343,21 +399,21 @@ func Test_sshValidityModifier(t *testing.T) {
|
|
|
|
|
"ok/empty-requested-validBefore-limit-after-default": func() test {
|
|
|
|
|
va := uint64(n.Unix())
|
|
|
|
|
return test{
|
|
|
|
|
svm: &sshValidityModifier{Claimer: p.claimer, validBefore: n.Add(5 * time.Hour)},
|
|
|
|
|
svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(24 * time.Hour)},
|
|
|
|
|
cert: &ssh.Certificate{
|
|
|
|
|
CertType: 1,
|
|
|
|
|
ValidAfter: va,
|
|
|
|
|
},
|
|
|
|
|
valid: func(cert *ssh.Certificate) {
|
|
|
|
|
assert.Equals(t, cert.ValidAfter, va)
|
|
|
|
|
assert.Equals(t, cert.ValidBefore, uint64(n.Add(4*time.Hour).Unix()))
|
|
|
|
|
assert.Equals(t, cert.ValidBefore, uint64(n.Add(16*time.Hour).Unix()))
|
|
|
|
|
},
|
|
|
|
|
}
|
|
|
|
|
},
|
|
|
|
|
"ok/empty-requested-validBefore-limit-before-default": func() test {
|
|
|
|
|
va := uint64(n.Unix())
|
|
|
|
|
return test{
|
|
|
|
|
svm: &sshValidityModifier{Claimer: p.claimer, validBefore: n.Add(3 * time.Hour)},
|
|
|
|
|
svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(3 * time.Hour)},
|
|
|
|
|
cert: &ssh.Certificate{
|
|
|
|
|
CertType: 1,
|
|
|
|
|
ValidAfter: va,
|
|
|
|
@ -372,7 +428,7 @@ func Test_sshValidityModifier(t *testing.T) {
|
|
|
|
|
for name, run := range tests {
|
|
|
|
|
t.Run(name, func(t *testing.T) {
|
|
|
|
|
tt := run()
|
|
|
|
|
if err := tt.svm.Modify(tt.cert); err != nil {
|
|
|
|
|
if err := tt.svm.Option(SSHOptions{}).Modify(tt.cert); err != nil {
|
|
|
|
|
if assert.NotNil(t, tt.err) {
|
|
|
|
|
assert.HasPrefix(t, err.Error(), tt.err.Error())
|
|
|
|
|
}
|
|
|
|
@ -384,3 +440,120 @@ func Test_sshValidityModifier(t *testing.T) {
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func Test_sshModifierFunc_Modify(t *testing.T) {
|
|
|
|
|
type args struct {
|
|
|
|
|
cert *ssh.Certificate
|
|
|
|
|
}
|
|
|
|
|
tests := []struct {
|
|
|
|
|
name string
|
|
|
|
|
f sshModifierFunc
|
|
|
|
|
args args
|
|
|
|
|
wantErr bool
|
|
|
|
|
}{
|
|
|
|
|
{"ok", func(cert *ssh.Certificate) error { return nil }, args{&ssh.Certificate{}}, false},
|
|
|
|
|
{"fail", func(cert *ssh.Certificate) error { return fmt.Errorf("an error") }, args{&ssh.Certificate{}}, true},
|
|
|
|
|
}
|
|
|
|
|
for _, tt := range tests {
|
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
|
if err := tt.f.Modify(tt.args.cert); (err != nil) != tt.wantErr {
|
|
|
|
|
t.Errorf("sshModifierFunc.Modify() error = %v, wantErr %v", err, tt.wantErr)
|
|
|
|
|
}
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func Test_sshDefaultDuration_Option(t *testing.T) {
|
|
|
|
|
tm, fn := mockNow()
|
|
|
|
|
defer fn()
|
|
|
|
|
|
|
|
|
|
newClaimer := func(claims *Claims) *Claimer {
|
|
|
|
|
c, err := NewClaimer(claims, globalProvisionerClaims)
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
return c
|
|
|
|
|
}
|
|
|
|
|
unix := func(d time.Duration) uint64 {
|
|
|
|
|
return uint64(tm.Add(d).Unix())
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
type fields struct {
|
|
|
|
|
Claimer *Claimer
|
|
|
|
|
}
|
|
|
|
|
type args struct {
|
|
|
|
|
o SSHOptions
|
|
|
|
|
cert *ssh.Certificate
|
|
|
|
|
}
|
|
|
|
|
tests := []struct {
|
|
|
|
|
name string
|
|
|
|
|
fields fields
|
|
|
|
|
args args
|
|
|
|
|
want *ssh.Certificate
|
|
|
|
|
wantErr bool
|
|
|
|
|
}{
|
|
|
|
|
{"user", fields{newClaimer(nil)}, args{SSHOptions{}, &ssh.Certificate{CertType: ssh.UserCert}},
|
|
|
|
|
&ssh.Certificate{CertType: ssh.UserCert, ValidAfter: unix(0), ValidBefore: unix(16 * time.Hour)}, false},
|
|
|
|
|
{"host", fields{newClaimer(nil)}, args{SSHOptions{}, &ssh.Certificate{CertType: ssh.HostCert}},
|
|
|
|
|
&ssh.Certificate{CertType: ssh.HostCert, ValidAfter: unix(0), ValidBefore: unix(30 * 24 * time.Hour)}, false},
|
|
|
|
|
{"user claim", fields{newClaimer(&Claims{DefaultUserSSHDur: &Duration{1 * time.Hour}})}, args{SSHOptions{}, &ssh.Certificate{CertType: ssh.UserCert}},
|
|
|
|
|
&ssh.Certificate{CertType: ssh.UserCert, ValidAfter: unix(0), ValidBefore: unix(1 * time.Hour)}, false},
|
|
|
|
|
{"host claim", fields{newClaimer(&Claims{DefaultHostSSHDur: &Duration{1 * time.Hour}})}, args{SSHOptions{}, &ssh.Certificate{CertType: ssh.HostCert}},
|
|
|
|
|
&ssh.Certificate{CertType: ssh.HostCert, ValidAfter: unix(0), ValidBefore: unix(1 * time.Hour)}, false},
|
|
|
|
|
{"user backdate", fields{newClaimer(nil)}, args{SSHOptions{Backdate: 1 * time.Minute}, &ssh.Certificate{CertType: ssh.UserCert}},
|
|
|
|
|
&ssh.Certificate{CertType: ssh.UserCert, ValidAfter: unix(-1 * time.Minute), ValidBefore: unix(16 * time.Hour)}, false},
|
|
|
|
|
{"host backdate", fields{newClaimer(nil)}, args{SSHOptions{Backdate: 1 * time.Minute}, &ssh.Certificate{CertType: ssh.HostCert}},
|
|
|
|
|
&ssh.Certificate{CertType: ssh.HostCert, ValidAfter: unix(-1 * time.Minute), ValidBefore: unix(30 * 24 * time.Hour)}, false},
|
|
|
|
|
{"user validAfter", fields{newClaimer(nil)}, args{SSHOptions{Backdate: 1 * time.Minute}, &ssh.Certificate{CertType: ssh.UserCert, ValidAfter: unix(1 * time.Hour)}},
|
|
|
|
|
&ssh.Certificate{CertType: ssh.UserCert, ValidAfter: unix(time.Minute), ValidBefore: unix(17 * time.Hour)}, false},
|
|
|
|
|
{"user validBefore", fields{newClaimer(nil)}, args{SSHOptions{Backdate: 1 * time.Minute}, &ssh.Certificate{CertType: ssh.UserCert, ValidBefore: unix(1 * time.Hour)}},
|
|
|
|
|
&ssh.Certificate{CertType: ssh.UserCert, ValidAfter: unix(-1 * time.Minute), ValidBefore: unix(time.Hour)}, false},
|
|
|
|
|
{"host validAfter validBefore", fields{newClaimer(nil)}, args{SSHOptions{Backdate: 1 * time.Minute}, &ssh.Certificate{CertType: ssh.HostCert, ValidAfter: unix(1 * time.Minute), ValidBefore: unix(2 * time.Minute)}},
|
|
|
|
|
&ssh.Certificate{CertType: ssh.HostCert, ValidAfter: unix(1 * time.Minute), ValidBefore: unix(2 * time.Minute)}, false},
|
|
|
|
|
{"fail zero", fields{newClaimer(nil)}, args{SSHOptions{}, &ssh.Certificate{}}, &ssh.Certificate{}, true},
|
|
|
|
|
{"fail type", fields{newClaimer(nil)}, args{SSHOptions{}, &ssh.Certificate{CertType: 3}}, &ssh.Certificate{CertType: 3}, true},
|
|
|
|
|
}
|
|
|
|
|
for _, tt := range tests {
|
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
|
m := &sshDefaultDuration{
|
|
|
|
|
Claimer: tt.fields.Claimer,
|
|
|
|
|
}
|
|
|
|
|
v := m.Option(tt.args.o)
|
|
|
|
|
if err := v.Modify(tt.args.cert); (err != nil) != tt.wantErr {
|
|
|
|
|
t.Errorf("sshDefaultDuration.Option() error = %v, wantErr %v", err, tt.wantErr)
|
|
|
|
|
}
|
|
|
|
|
if !reflect.DeepEqual(tt.args.cert, tt.want) {
|
|
|
|
|
t.Errorf("sshDefaultDuration.Option() = %v, want %v", tt.args.cert, tt.want)
|
|
|
|
|
}
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func Test_sshLimitDuration_Option(t *testing.T) {
|
|
|
|
|
type fields struct {
|
|
|
|
|
Claimer *Claimer
|
|
|
|
|
NotAfter time.Time
|
|
|
|
|
}
|
|
|
|
|
type args struct {
|
|
|
|
|
o SSHOptions
|
|
|
|
|
}
|
|
|
|
|
tests := []struct {
|
|
|
|
|
name string
|
|
|
|
|
fields fields
|
|
|
|
|
args args
|
|
|
|
|
want SSHCertificateModifier
|
|
|
|
|
}{
|
|
|
|
|
// TODO: Add test cases.
|
|
|
|
|
}
|
|
|
|
|
for _, tt := range tests {
|
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
|
m := &sshLimitDuration{
|
|
|
|
|
Claimer: tt.fields.Claimer,
|
|
|
|
|
NotAfter: tt.fields.NotAfter,
|
|
|
|
|
}
|
|
|
|
|
if got := m.Option(tt.args.o); !reflect.DeepEqual(got, tt.want) {
|
|
|
|
|
t.Errorf("sshLimitDuration.Option() = %v, want %v", got, tt.want)
|
|
|
|
|
}
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|