|
|
@ -1,9 +1,17 @@
|
|
|
|
package identity
|
|
|
|
package identity
|
|
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
import (
|
|
|
|
|
|
|
|
"crypto"
|
|
|
|
"crypto/tls"
|
|
|
|
"crypto/tls"
|
|
|
|
|
|
|
|
"io/ioutil"
|
|
|
|
|
|
|
|
"os"
|
|
|
|
|
|
|
|
"path/filepath"
|
|
|
|
"reflect"
|
|
|
|
"reflect"
|
|
|
|
"testing"
|
|
|
|
"testing"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"github.com/smallstep/cli/crypto/pemutil"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"github.com/smallstep/certificates/api"
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
func TestLoadDefaultIdentity(t *testing.T) {
|
|
|
|
func TestLoadDefaultIdentity(t *testing.T) {
|
|
|
@ -164,3 +172,83 @@ func Test_fileExists(t *testing.T) {
|
|
|
|
})
|
|
|
|
})
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
func TestWriteDefaultIdentity(t *testing.T) {
|
|
|
|
|
|
|
|
tmpDir, err := ioutil.TempDir(os.TempDir(), "go-tests")
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
|
|
t.Fatal(err)
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
oldConfigDir := configDir
|
|
|
|
|
|
|
|
oldIdentityDir := identityDir
|
|
|
|
|
|
|
|
oldIdentityFile := IdentityFile
|
|
|
|
|
|
|
|
defer func() {
|
|
|
|
|
|
|
|
configDir = oldConfigDir
|
|
|
|
|
|
|
|
identityDir = oldIdentityDir
|
|
|
|
|
|
|
|
IdentityFile = oldIdentityFile
|
|
|
|
|
|
|
|
os.RemoveAll(tmpDir)
|
|
|
|
|
|
|
|
}()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
certs, err := pemutil.ReadCertificateBundle("testdata/identity/identity.crt")
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
|
|
t.Fatal(err)
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
key, err := pemutil.Read("testdata/identity/identity_key")
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
|
|
t.Fatal(err)
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
var certChain []api.Certificate
|
|
|
|
|
|
|
|
for _, c := range certs {
|
|
|
|
|
|
|
|
certChain = append(certChain, api.Certificate{Certificate: c})
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
configDir = filepath.Join(tmpDir, "config")
|
|
|
|
|
|
|
|
identityDir = filepath.Join(tmpDir, "identity")
|
|
|
|
|
|
|
|
IdentityFile = filepath.Join(tmpDir, "config", "identity.json")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
type args struct {
|
|
|
|
|
|
|
|
certChain []api.Certificate
|
|
|
|
|
|
|
|
key crypto.PrivateKey
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
tests := []struct {
|
|
|
|
|
|
|
|
name string
|
|
|
|
|
|
|
|
prepare func()
|
|
|
|
|
|
|
|
args args
|
|
|
|
|
|
|
|
wantErr bool
|
|
|
|
|
|
|
|
}{
|
|
|
|
|
|
|
|
{"ok", func() {}, args{certChain, key}, false},
|
|
|
|
|
|
|
|
{"fail mkdir config", func() {
|
|
|
|
|
|
|
|
configDir = filepath.Join(tmpDir, "identity", "identity.crt")
|
|
|
|
|
|
|
|
identityDir = filepath.Join(tmpDir, "identity")
|
|
|
|
|
|
|
|
}, args{certChain, key}, true},
|
|
|
|
|
|
|
|
{"fail mkdir identity", func() {
|
|
|
|
|
|
|
|
configDir = filepath.Join(tmpDir, "config")
|
|
|
|
|
|
|
|
identityDir = filepath.Join(tmpDir, "identity", "identity.crt")
|
|
|
|
|
|
|
|
}, args{certChain, key}, true},
|
|
|
|
|
|
|
|
{"fail certificate", func() {
|
|
|
|
|
|
|
|
configDir = filepath.Join(tmpDir, "config")
|
|
|
|
|
|
|
|
identityDir = filepath.Join(tmpDir, "bad-dir")
|
|
|
|
|
|
|
|
os.MkdirAll(identityDir, 0600)
|
|
|
|
|
|
|
|
}, args{certChain, key}, true},
|
|
|
|
|
|
|
|
{"fail key", func() {
|
|
|
|
|
|
|
|
configDir = filepath.Join(tmpDir, "config")
|
|
|
|
|
|
|
|
identityDir = filepath.Join(tmpDir, "identity")
|
|
|
|
|
|
|
|
}, args{certChain, "badKey"}, true},
|
|
|
|
|
|
|
|
{"fail write identity", func() {
|
|
|
|
|
|
|
|
configDir = filepath.Join(tmpDir, "bad-dir")
|
|
|
|
|
|
|
|
identityDir = filepath.Join(tmpDir, "identity")
|
|
|
|
|
|
|
|
IdentityFile = filepath.Join(configDir, "identity.json")
|
|
|
|
|
|
|
|
os.MkdirAll(configDir, 0600)
|
|
|
|
|
|
|
|
}, args{certChain, key}, true},
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for _, tt := range tests {
|
|
|
|
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
|
|
|
|
tt.prepare()
|
|
|
|
|
|
|
|
if err := WriteDefaultIdentity(tt.args.certChain, tt.args.key); (err != nil) != tt.wantErr {
|
|
|
|
|
|
|
|
t.Errorf("WriteDefaultIdentity() error = %v, wantErr %v", err, tt.wantErr)
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|