diff --git a/ca/identity/client_test.go b/ca/identity/client_test.go index 4cbcc3a2..9ab14e94 100644 --- a/ca/identity/client_test.go +++ b/ca/identity/client_test.go @@ -76,11 +76,38 @@ func TestLoadClient(t *testing.T) { want *Client wantErr bool }{ - {"ok", func() { IdentityFile = "testdata/config/identity.json"; DefaultsFile = "testdata/config/defaults.json" }, expected, false}, - {"fail identity", func() { IdentityFile = "testdata/config/missing.json"; DefaultsFile = "testdata/config/defaults.json" }, nil, true}, - {"fail identity", func() { IdentityFile = "testdata/config/fail.json"; DefaultsFile = "testdata/config/defaults.json" }, nil, true}, - {"fail defaults", func() { IdentityFile = "testdata/config/identity.json"; DefaultsFile = "testdata/config/missing.json" }, nil, true}, - {"fail defaults", func() { IdentityFile = "testdata/config/identity.json"; DefaultsFile = "testdata/config/fail.json" }, nil, true}, + {"ok", func() { + IdentityFile = "testdata/config/identity.json" + DefaultsFile = "testdata/config/defaults.json" + }, expected, false}, + {"fail identity", func() { + IdentityFile = "testdata/config/missing.json" + DefaultsFile = "testdata/config/defaults.json" + }, nil, true}, + {"fail identity", func() { + IdentityFile = "testdata/config/fail.json" + DefaultsFile = "testdata/config/defaults.json" + }, nil, true}, + {"fail defaults", func() { + IdentityFile = "testdata/config/identity.json" + DefaultsFile = "testdata/config/missing.json" + }, nil, true}, + {"fail defaults", func() { + IdentityFile = "testdata/config/identity.json" + DefaultsFile = "testdata/config/fail.json" + }, nil, true}, + {"fail ca", func() { + IdentityFile = "testdata/config/identity.json" + DefaultsFile = "testdata/config/badca.json" + }, nil, true}, + {"fail root", func() { + IdentityFile = "testdata/config/identity.json" + DefaultsFile = "testdata/config/badroot.json" + }, nil, true}, + {"fail type", func() { + IdentityFile = "testdata/config/badIdentity.json" + DefaultsFile = "testdata/config/defaults.json" + }, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/ca/identity/identity.go b/ca/identity/identity.go index 35736236..48ed66e6 100644 --- a/ca/identity/identity.go +++ b/ca/identity/identity.go @@ -58,21 +58,26 @@ func LoadDefaultIdentity() (*Identity, error) { return identity, nil } +// configDir and identityDir are used in WriteDefaultIdentity for testing +// purposes. +var ( + configDir = filepath.Join(config.StepPath(), "config") + identityDir = filepath.Join(config.StepPath(), "identity") +) + // WriteDefaultIdentity writes the given certificates and key and the // identity.json pointing to the new files. func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) error { - base := filepath.Join(config.StepPath(), "config") - if err := os.MkdirAll(base, 0700); err != nil { + if err := os.MkdirAll(configDir, 0700); err != nil { return errors.Wrap(err, "error creating config directory") } - base = filepath.Join(config.StepPath(), "identity") - if err := os.MkdirAll(base, 0700); err != nil { + if err := os.MkdirAll(identityDir, 0700); err != nil { return errors.Wrap(err, "error creating identity directory") } - certFilename := filepath.Join(base, "identity.crt") - keyFilename := filepath.Join(base, "identity_key") + certFilename := filepath.Join(identityDir, "identity.crt") + keyFilename := filepath.Join(identityDir, "identity_key") // Write certificate buf := new(bytes.Buffer) diff --git a/ca/identity/identity_test.go b/ca/identity/identity_test.go index 58f5db71..1a73afdb 100644 --- a/ca/identity/identity_test.go +++ b/ca/identity/identity_test.go @@ -1,9 +1,17 @@ package identity import ( + "crypto" "crypto/tls" + "io/ioutil" + "os" + "path/filepath" "reflect" "testing" + + "github.com/smallstep/cli/crypto/pemutil" + + "github.com/smallstep/certificates/api" ) func TestLoadDefaultIdentity(t *testing.T) { @@ -164,3 +172,83 @@ func Test_fileExists(t *testing.T) { }) } } + +func TestWriteDefaultIdentity(t *testing.T) { + tmpDir, err := ioutil.TempDir(os.TempDir(), "go-tests") + if err != nil { + t.Fatal(err) + } + + oldConfigDir := configDir + oldIdentityDir := identityDir + oldIdentityFile := IdentityFile + defer func() { + configDir = oldConfigDir + identityDir = oldIdentityDir + IdentityFile = oldIdentityFile + os.RemoveAll(tmpDir) + }() + + certs, err := pemutil.ReadCertificateBundle("testdata/identity/identity.crt") + if err != nil { + t.Fatal(err) + } + key, err := pemutil.Read("testdata/identity/identity_key") + if err != nil { + t.Fatal(err) + } + + var certChain []api.Certificate + for _, c := range certs { + certChain = append(certChain, api.Certificate{Certificate: c}) + } + + configDir = filepath.Join(tmpDir, "config") + identityDir = filepath.Join(tmpDir, "identity") + IdentityFile = filepath.Join(tmpDir, "config", "identity.json") + + type args struct { + certChain []api.Certificate + key crypto.PrivateKey + } + tests := []struct { + name string + prepare func() + args args + wantErr bool + }{ + {"ok", func() {}, args{certChain, key}, false}, + {"fail mkdir config", func() { + configDir = filepath.Join(tmpDir, "identity", "identity.crt") + identityDir = filepath.Join(tmpDir, "identity") + }, args{certChain, key}, true}, + {"fail mkdir identity", func() { + configDir = filepath.Join(tmpDir, "config") + identityDir = filepath.Join(tmpDir, "identity", "identity.crt") + }, args{certChain, key}, true}, + {"fail certificate", func() { + configDir = filepath.Join(tmpDir, "config") + identityDir = filepath.Join(tmpDir, "bad-dir") + os.MkdirAll(identityDir, 0600) + }, args{certChain, key}, true}, + {"fail key", func() { + configDir = filepath.Join(tmpDir, "config") + identityDir = filepath.Join(tmpDir, "identity") + }, args{certChain, "badKey"}, true}, + {"fail write identity", func() { + configDir = filepath.Join(tmpDir, "bad-dir") + identityDir = filepath.Join(tmpDir, "identity") + IdentityFile = filepath.Join(configDir, "identity.json") + os.MkdirAll(configDir, 0600) + }, args{certChain, key}, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.prepare() + if err := WriteDefaultIdentity(tt.args.certChain, tt.args.key); (err != nil) != tt.wantErr { + t.Errorf("WriteDefaultIdentity() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/ca/identity/testdata/config/badIdentity.json b/ca/identity/testdata/config/badIdentity.json new file mode 100644 index 00000000..f1a73ecd --- /dev/null +++ b/ca/identity/testdata/config/badIdentity.json @@ -0,0 +1,5 @@ +{ + "type": "", + "crt": "testdata/identity/identity.crt", + "key": "testdata/identity/identity_key" +} \ No newline at end of file diff --git a/ca/identity/testdata/config/badca.json b/ca/identity/testdata/config/badca.json new file mode 100644 index 00000000..29327ffb --- /dev/null +++ b/ca/identity/testdata/config/badca.json @@ -0,0 +1,6 @@ +{ + "ca-url": ":", + "ca-config": "testdata/config/ca.json", + "fingerprint": "9dc35eef23a234b2520516a3169090d7ec2fc61323bdd6e4fde08bcfec5d0931", + "root": "testdata/certs/root_ca.crt" +} \ No newline at end of file diff --git a/ca/identity/testdata/config/badroot.json b/ca/identity/testdata/config/badroot.json new file mode 100644 index 00000000..50e86d5e --- /dev/null +++ b/ca/identity/testdata/config/badroot.json @@ -0,0 +1,6 @@ +{ + "ca-url": "https://127.0.0.1", + "ca-config": "testdata/config/ca.json", + "fingerprint": "9dc35eef23a234b2520516a3169090d7ec2fc61323bdd6e4fde08bcfec5d0931", + "root": "testdata/certs/missing.crt" +} \ No newline at end of file