|
|
|
@ -1,7 +1,6 @@
|
|
|
|
|
package provisioner
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
"fmt"
|
|
|
|
|
"reflect"
|
|
|
|
|
"testing"
|
|
|
|
|
"time"
|
|
|
|
@ -40,7 +39,7 @@ func TestSSHOptions_Type(t *testing.T) {
|
|
|
|
|
|
|
|
|
|
func TestSSHOptions_Modify(t *testing.T) {
|
|
|
|
|
type test struct {
|
|
|
|
|
so *SignSSHOptions
|
|
|
|
|
so SignSSHOptions
|
|
|
|
|
cert *ssh.Certificate
|
|
|
|
|
valid func(*ssh.Certificate)
|
|
|
|
|
err error
|
|
|
|
@ -48,21 +47,21 @@ func TestSSHOptions_Modify(t *testing.T) {
|
|
|
|
|
tests := map[string](func() test){
|
|
|
|
|
"fail/unexpected-cert-type": func() test {
|
|
|
|
|
return test{
|
|
|
|
|
so: &SignSSHOptions{CertType: "foo"},
|
|
|
|
|
so: SignSSHOptions{CertType: "foo"},
|
|
|
|
|
cert: new(ssh.Certificate),
|
|
|
|
|
err: errors.Errorf("ssh certificate has an unknown type - foo"),
|
|
|
|
|
}
|
|
|
|
|
},
|
|
|
|
|
"fail/validAfter-greater-validBefore": func() test {
|
|
|
|
|
return test{
|
|
|
|
|
so: &SignSSHOptions{CertType: "user"},
|
|
|
|
|
so: SignSSHOptions{CertType: "user"},
|
|
|
|
|
cert: &ssh.Certificate{ValidAfter: uint64(15), ValidBefore: uint64(10)},
|
|
|
|
|
err: errors.Errorf("ssh certificate valid after cannot be greater than valid before"),
|
|
|
|
|
}
|
|
|
|
|
},
|
|
|
|
|
"ok/user-cert": func() test {
|
|
|
|
|
return test{
|
|
|
|
|
so: &SignSSHOptions{CertType: "user"},
|
|
|
|
|
so: SignSSHOptions{CertType: "user"},
|
|
|
|
|
cert: new(ssh.Certificate),
|
|
|
|
|
valid: func(cert *ssh.Certificate) {
|
|
|
|
|
assert.Equals(t, cert.CertType, uint32(ssh.UserCert))
|
|
|
|
@ -71,7 +70,7 @@ func TestSSHOptions_Modify(t *testing.T) {
|
|
|
|
|
},
|
|
|
|
|
"ok/host-cert": func() test {
|
|
|
|
|
return test{
|
|
|
|
|
so: &SignSSHOptions{CertType: "host"},
|
|
|
|
|
so: SignSSHOptions{CertType: "host"},
|
|
|
|
|
cert: new(ssh.Certificate),
|
|
|
|
|
valid: func(cert *ssh.Certificate) {
|
|
|
|
|
assert.Equals(t, cert.CertType, uint32(ssh.HostCert))
|
|
|
|
@ -81,7 +80,7 @@ func TestSSHOptions_Modify(t *testing.T) {
|
|
|
|
|
"ok": func() test {
|
|
|
|
|
va := time.Now().Add(5 * time.Minute)
|
|
|
|
|
vb := time.Now().Add(1 * time.Hour)
|
|
|
|
|
so := &SignSSHOptions{CertType: "host", KeyID: "foo", Principals: []string{"foo", "bar"},
|
|
|
|
|
so := SignSSHOptions{CertType: "host", KeyID: "foo", Principals: []string{"foo", "bar"},
|
|
|
|
|
ValidAfter: NewTimeDuration(va), ValidBefore: NewTimeDuration(vb)}
|
|
|
|
|
return test{
|
|
|
|
|
so: so,
|
|
|
|
@ -99,7 +98,7 @@ func TestSSHOptions_Modify(t *testing.T) {
|
|
|
|
|
for name, run := range tests {
|
|
|
|
|
t.Run(name, func(t *testing.T) {
|
|
|
|
|
tc := run()
|
|
|
|
|
if err := tc.so.Modify(tc.cert); err != nil {
|
|
|
|
|
if err := tc.so.Modify(tc.cert, tc.so); err != nil {
|
|
|
|
|
if assert.NotNil(t, tc.err) {
|
|
|
|
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
|
|
|
|
}
|
|
|
|
@ -222,7 +221,7 @@ func Test_sshCertPrincipalsModifier_Modify(t *testing.T) {
|
|
|
|
|
for name, run := range tests {
|
|
|
|
|
t.Run(name, func(t *testing.T) {
|
|
|
|
|
tc := run()
|
|
|
|
|
if assert.Nil(t, tc.modifier.Modify(tc.cert)) {
|
|
|
|
|
if assert.Nil(t, tc.modifier.Modify(tc.cert, SignSSHOptions{})) {
|
|
|
|
|
assert.Equals(t, tc.cert.ValidPrincipals, tc.expected)
|
|
|
|
|
}
|
|
|
|
|
})
|
|
|
|
@ -248,7 +247,7 @@ func Test_sshCertKeyIDModifier_Modify(t *testing.T) {
|
|
|
|
|
for name, run := range tests {
|
|
|
|
|
t.Run(name, func(t *testing.T) {
|
|
|
|
|
tc := run()
|
|
|
|
|
if assert.Nil(t, tc.modifier.Modify(tc.cert)) {
|
|
|
|
|
if assert.Nil(t, tc.modifier.Modify(tc.cert, SignSSHOptions{})) {
|
|
|
|
|
assert.Equals(t, tc.cert.KeyId, tc.expected)
|
|
|
|
|
}
|
|
|
|
|
})
|
|
|
|
@ -287,7 +286,7 @@ func Test_sshCertTypeModifier_Modify(t *testing.T) {
|
|
|
|
|
for name, run := range tests {
|
|
|
|
|
t.Run(name, func(t *testing.T) {
|
|
|
|
|
tc := run()
|
|
|
|
|
if assert.Nil(t, tc.modifier.Modify(tc.cert)) {
|
|
|
|
|
if assert.Nil(t, tc.modifier.Modify(tc.cert, SignSSHOptions{})) {
|
|
|
|
|
assert.Equals(t, tc.cert.CertType, uint32(tc.expected))
|
|
|
|
|
}
|
|
|
|
|
})
|
|
|
|
@ -312,7 +311,7 @@ func Test_sshCertValidAfterModifier_Modify(t *testing.T) {
|
|
|
|
|
for name, run := range tests {
|
|
|
|
|
t.Run(name, func(t *testing.T) {
|
|
|
|
|
tc := run()
|
|
|
|
|
if assert.Nil(t, tc.modifier.Modify(tc.cert)) {
|
|
|
|
|
if assert.Nil(t, tc.modifier.Modify(tc.cert, SignSSHOptions{})) {
|
|
|
|
|
assert.Equals(t, tc.cert.ValidAfter, tc.expected)
|
|
|
|
|
}
|
|
|
|
|
})
|
|
|
|
@ -375,7 +374,7 @@ func Test_sshCertDefaultsModifier_Modify(t *testing.T) {
|
|
|
|
|
for name, run := range tests {
|
|
|
|
|
t.Run(name, func(t *testing.T) {
|
|
|
|
|
tc := run()
|
|
|
|
|
if assert.Nil(t, tc.modifier.Modify(tc.cert)) {
|
|
|
|
|
if assert.Nil(t, tc.modifier.Modify(tc.cert, SignSSHOptions{})) {
|
|
|
|
|
tc.valid(tc.cert)
|
|
|
|
|
}
|
|
|
|
|
})
|
|
|
|
@ -476,7 +475,7 @@ func Test_sshDefaultExtensionModifier_Modify(t *testing.T) {
|
|
|
|
|
for name, run := range tests {
|
|
|
|
|
t.Run(name, func(t *testing.T) {
|
|
|
|
|
tc := run()
|
|
|
|
|
if err := tc.modifier.Modify(tc.cert); err != nil {
|
|
|
|
|
if err := tc.modifier.Modify(tc.cert, SignSSHOptions{}); err != nil {
|
|
|
|
|
if assert.NotNil(t, tc.err) {
|
|
|
|
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
|
|
|
|
}
|
|
|
|
@ -908,7 +907,7 @@ func Test_sshValidityModifier(t *testing.T) {
|
|
|
|
|
for name, run := range tests {
|
|
|
|
|
t.Run(name, func(t *testing.T) {
|
|
|
|
|
tt := run()
|
|
|
|
|
if err := tt.svm.Option(SignSSHOptions{}).Modify(tt.cert); err != nil {
|
|
|
|
|
if err := tt.svm.Modify(tt.cert, SignSSHOptions{}); err != nil {
|
|
|
|
|
if assert.NotNil(t, tt.err) {
|
|
|
|
|
assert.HasPrefix(t, err.Error(), tt.err.Error())
|
|
|
|
|
}
|
|
|
|
@ -921,28 +920,6 @@ func Test_sshValidityModifier(t *testing.T) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func Test_sshModifierFunc_Modify(t *testing.T) {
|
|
|
|
|
type args struct {
|
|
|
|
|
cert *ssh.Certificate
|
|
|
|
|
}
|
|
|
|
|
tests := []struct {
|
|
|
|
|
name string
|
|
|
|
|
f sshModifierFunc
|
|
|
|
|
args args
|
|
|
|
|
wantErr bool
|
|
|
|
|
}{
|
|
|
|
|
{"ok", func(cert *ssh.Certificate) error { return nil }, args{&ssh.Certificate{}}, false},
|
|
|
|
|
{"fail", func(cert *ssh.Certificate) error { return fmt.Errorf("an error") }, args{&ssh.Certificate{}}, true},
|
|
|
|
|
}
|
|
|
|
|
for _, tt := range tests {
|
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
|
if err := tt.f.Modify(tt.args.cert); (err != nil) != tt.wantErr {
|
|
|
|
|
t.Errorf("sshModifierFunc.Modify() error = %v, wantErr %v", err, tt.wantErr)
|
|
|
|
|
}
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func Test_sshDefaultDuration_Option(t *testing.T) {
|
|
|
|
|
tm, fn := mockNow()
|
|
|
|
|
defer fn()
|
|
|
|
@ -998,8 +975,7 @@ func Test_sshDefaultDuration_Option(t *testing.T) {
|
|
|
|
|
m := &sshDefaultDuration{
|
|
|
|
|
Claimer: tt.fields.Claimer,
|
|
|
|
|
}
|
|
|
|
|
v := m.Option(tt.args.o)
|
|
|
|
|
if err := v.Modify(tt.args.cert); (err != nil) != tt.wantErr {
|
|
|
|
|
if err := m.Modify(tt.args.cert, tt.args.o); (err != nil) != tt.wantErr {
|
|
|
|
|
t.Errorf("sshDefaultDuration.Option() error = %v, wantErr %v", err, tt.wantErr)
|
|
|
|
|
}
|
|
|
|
|
if !reflect.DeepEqual(tt.args.cert, tt.want) {
|
|
|
|
@ -1008,32 +984,3 @@ func Test_sshDefaultDuration_Option(t *testing.T) {
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func Test_sshLimitDuration_Option(t *testing.T) {
|
|
|
|
|
type fields struct {
|
|
|
|
|
Claimer *Claimer
|
|
|
|
|
NotAfter time.Time
|
|
|
|
|
}
|
|
|
|
|
type args struct {
|
|
|
|
|
o SignSSHOptions
|
|
|
|
|
}
|
|
|
|
|
tests := []struct {
|
|
|
|
|
name string
|
|
|
|
|
fields fields
|
|
|
|
|
args args
|
|
|
|
|
want SSHCertModifier
|
|
|
|
|
}{
|
|
|
|
|
// TODO: Add test cases.
|
|
|
|
|
}
|
|
|
|
|
for _, tt := range tests {
|
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
|
m := &sshLimitDuration{
|
|
|
|
|
Claimer: tt.fields.Claimer,
|
|
|
|
|
NotAfter: tt.fields.NotAfter,
|
|
|
|
|
}
|
|
|
|
|
if got := m.Option(tt.args.o); !reflect.DeepEqual(got, tt.want) {
|
|
|
|
|
t.Errorf("sshLimitDuration.Option() = %v, want %v", got, tt.want)
|
|
|
|
|
}
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|