|
|
|
@ -3,15 +3,17 @@ package identity
|
|
|
|
|
import (
|
|
|
|
|
"crypto"
|
|
|
|
|
"crypto/tls"
|
|
|
|
|
"crypto/x509"
|
|
|
|
|
"fmt"
|
|
|
|
|
"io/ioutil"
|
|
|
|
|
"net/http"
|
|
|
|
|
"os"
|
|
|
|
|
"path/filepath"
|
|
|
|
|
"reflect"
|
|
|
|
|
"testing"
|
|
|
|
|
|
|
|
|
|
"github.com/smallstep/cli/crypto/pemutil"
|
|
|
|
|
|
|
|
|
|
"github.com/smallstep/certificates/api"
|
|
|
|
|
"github.com/smallstep/cli/crypto/pemutil"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
func TestLoadDefaultIdentity(t *testing.T) {
|
|
|
|
@ -252,3 +254,97 @@ func TestWriteDefaultIdentity(t *testing.T) {
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
type renewer struct {
|
|
|
|
|
pool *x509.CertPool
|
|
|
|
|
sign *api.SignResponse
|
|
|
|
|
err error
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (r *renewer) GetRootCAs() *x509.CertPool {
|
|
|
|
|
return r.pool
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (r *renewer) Renew(tr http.RoundTripper) (*api.SignResponse, error) {
|
|
|
|
|
return r.sign, r.err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func TestIdentity_Renew(t *testing.T) {
|
|
|
|
|
tmpDir, err := ioutil.TempDir(os.TempDir(), "go-tests")
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
oldIdentityDir := identityDir
|
|
|
|
|
defer func() {
|
|
|
|
|
identityDir = oldIdentityDir
|
|
|
|
|
os.RemoveAll(tmpDir)
|
|
|
|
|
}()
|
|
|
|
|
|
|
|
|
|
certs, err := pemutil.ReadCertificateBundle("testdata/identity/identity.crt")
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ok := &renewer{
|
|
|
|
|
sign: &api.SignResponse{
|
|
|
|
|
ServerPEM: api.Certificate{Certificate: certs[0]},
|
|
|
|
|
CaPEM: api.Certificate{Certificate: certs[1]},
|
|
|
|
|
CertChainPEM: []api.Certificate{
|
|
|
|
|
{Certificate: certs[0]},
|
|
|
|
|
{Certificate: certs[1]},
|
|
|
|
|
},
|
|
|
|
|
},
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
okOld := &renewer{
|
|
|
|
|
sign: &api.SignResponse{
|
|
|
|
|
ServerPEM: api.Certificate{Certificate: certs[0]},
|
|
|
|
|
CaPEM: api.Certificate{Certificate: certs[1]},
|
|
|
|
|
},
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fail := &renewer{
|
|
|
|
|
err: fmt.Errorf("an error"),
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
type fields struct {
|
|
|
|
|
Type string
|
|
|
|
|
Certificate string
|
|
|
|
|
Key string
|
|
|
|
|
}
|
|
|
|
|
type args struct {
|
|
|
|
|
client Renewer
|
|
|
|
|
}
|
|
|
|
|
tests := []struct {
|
|
|
|
|
name string
|
|
|
|
|
prepare func()
|
|
|
|
|
fields fields
|
|
|
|
|
args args
|
|
|
|
|
wantErr bool
|
|
|
|
|
}{
|
|
|
|
|
{"ok", func() {}, fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, args{ok}, false},
|
|
|
|
|
{"ok old", func() {}, fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, args{okOld}, false},
|
|
|
|
|
{"ok disabled", func() {}, fields{}, args{nil}, false},
|
|
|
|
|
{"fail type", func() {}, fields{"foo", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, args{ok}, true},
|
|
|
|
|
{"fail renew", func() {}, fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, args{fail}, true},
|
|
|
|
|
{"fail certificate", func() {}, fields{"mTLS", "testdata/certs/server.crt", "testdata/identity/identity_key"}, args{ok}, true},
|
|
|
|
|
{"fail write identity", func() {
|
|
|
|
|
identityDir = filepath.Join(tmpDir, "bad-dir")
|
|
|
|
|
os.MkdirAll(identityDir, 0600)
|
|
|
|
|
}, fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, args{ok}, true},
|
|
|
|
|
}
|
|
|
|
|
for _, tt := range tests {
|
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
|
tt.prepare()
|
|
|
|
|
i := &Identity{
|
|
|
|
|
Type: tt.fields.Type,
|
|
|
|
|
Certificate: tt.fields.Certificate,
|
|
|
|
|
Key: tt.fields.Key,
|
|
|
|
|
}
|
|
|
|
|
if err := i.Renew(tt.args.client); (err != nil) != tt.wantErr {
|
|
|
|
|
t.Errorf("Identity.Renew() error = %v, wantErr %v", err, tt.wantErr)
|
|
|
|
|
}
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|