diff --git a/ca/identity/identity.go b/ca/identity/identity.go index 48ed66e6..fa286a50 100644 --- a/ca/identity/identity.go +++ b/ca/identity/identity.go @@ -8,6 +8,7 @@ import ( "encoding/json" "encoding/pem" "io/ioutil" + "net/http" "os" "path/filepath" "strings" @@ -191,6 +192,62 @@ func (i *Identity) TLSCertificate() (tls.Certificate, error) { } } +// Renewer is that interface that a renew client must implement. +type Renewer interface { + GetRootCAs() *x509.CertPool + Renew(tr http.RoundTripper) (*api.SignResponse, error) +} + +// Renew renews the current identity certificate using a client with a renew +// method. +func (i *Identity) Renew(client Renewer) error { + switch i.Kind() { + case Disabled: + return nil + case MutualTLS: + cert, err := i.TLSCertificate() + if err != nil { + return err + } + + tr := http.DefaultTransport.(*http.Transport).Clone() + tr.TLSClientConfig = &tls.Config{ + Certificates: []tls.Certificate{cert}, + RootCAs: client.GetRootCAs(), + PreferServerCipherSuites: true, + } + + sign, err := client.Renew(tr) + if err != nil { + return err + } + + if sign.CertChainPEM == nil || len(sign.CertChainPEM) == 0 { + sign.CertChainPEM = []api.Certificate{sign.ServerPEM, sign.CaPEM} + } + + // Write certificate + buf := new(bytes.Buffer) + for _, crt := range sign.CertChainPEM { + block := &pem.Block{ + Type: "CERTIFICATE", + Bytes: crt.Raw, + } + if err := pem.Encode(buf, block); err != nil { + return errors.Wrap(err, "error encoding identity certificate") + } + } + certFilename := filepath.Join(identityDir, "identity.crt") + if err := ioutil.WriteFile(certFilename, buf.Bytes(), 0600); err != nil { + return errors.Wrap(err, "error writing identity certificate") + } + + return nil + default: + return errors.Errorf("unsupported identity type %s", i.Type) + } +} + func fileExists(filename string) error { info, err := os.Stat(filename) if err != nil { diff --git a/ca/identity/identity_test.go b/ca/identity/identity_test.go index 1a73afdb..3c04f982 100644 --- a/ca/identity/identity_test.go +++ b/ca/identity/identity_test.go @@ -3,15 +3,17 @@ package identity import ( "crypto" "crypto/tls" + "crypto/x509" + "fmt" "io/ioutil" + "net/http" "os" "path/filepath" "reflect" "testing" - "github.com/smallstep/cli/crypto/pemutil" - "github.com/smallstep/certificates/api" + "github.com/smallstep/cli/crypto/pemutil" ) func TestLoadDefaultIdentity(t *testing.T) { @@ -252,3 +254,97 @@ func TestWriteDefaultIdentity(t *testing.T) { }) } } + +type renewer struct { + pool *x509.CertPool + sign *api.SignResponse + err error +} + +func (r *renewer) GetRootCAs() *x509.CertPool { + return r.pool +} + +func (r *renewer) Renew(tr http.RoundTripper) (*api.SignResponse, error) { + return r.sign, r.err +} + +func TestIdentity_Renew(t *testing.T) { + tmpDir, err := ioutil.TempDir(os.TempDir(), "go-tests") + if err != nil { + t.Fatal(err) + } + + oldIdentityDir := identityDir + defer func() { + identityDir = oldIdentityDir + os.RemoveAll(tmpDir) + }() + + certs, err := pemutil.ReadCertificateBundle("testdata/identity/identity.crt") + if err != nil { + t.Fatal(err) + } + + ok := &renewer{ + sign: &api.SignResponse{ + ServerPEM: api.Certificate{Certificate: certs[0]}, + CaPEM: api.Certificate{Certificate: certs[1]}, + CertChainPEM: []api.Certificate{ + {Certificate: certs[0]}, + {Certificate: certs[1]}, + }, + }, + } + + okOld := &renewer{ + sign: &api.SignResponse{ + ServerPEM: api.Certificate{Certificate: certs[0]}, + CaPEM: api.Certificate{Certificate: certs[1]}, + }, + } + + fail := &renewer{ + err: fmt.Errorf("an error"), + } + + type fields struct { + Type string + Certificate string + Key string + } + type args struct { + client Renewer + } + tests := []struct { + name string + prepare func() + fields fields + args args + wantErr bool + }{ + {"ok", func() {}, fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, args{ok}, false}, + {"ok old", func() {}, fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, args{okOld}, false}, + {"ok disabled", func() {}, fields{}, args{nil}, false}, + {"fail type", func() {}, fields{"foo", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, args{ok}, true}, + {"fail renew", func() {}, fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, args{fail}, true}, + {"fail certificate", func() {}, fields{"mTLS", "testdata/certs/server.crt", "testdata/identity/identity_key"}, args{ok}, true}, + {"fail write identity", func() { + identityDir = filepath.Join(tmpDir, "bad-dir") + os.MkdirAll(identityDir, 0600) + }, fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, args{ok}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.prepare() + i := &Identity{ + Type: tt.fields.Type, + Certificate: tt.fields.Certificate, + Key: tt.fields.Key, + } + if err := i.Renew(tt.args.client); (err != nil) != tt.wantErr { + t.Errorf("Identity.Renew() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +}