Add method to renew the identity.

pull/166/head^2
Mariano Cano 5 years ago committed by max furman
parent 9aafe265d0
commit 14e59775bd

@ -8,6 +8,7 @@ import (
"encoding/json"
"encoding/pem"
"io/ioutil"
"net/http"
"os"
"path/filepath"
"strings"
@ -191,6 +192,62 @@ func (i *Identity) TLSCertificate() (tls.Certificate, error) {
}
}
// Renewer is that interface that a renew client must implement.
type Renewer interface {
GetRootCAs() *x509.CertPool
Renew(tr http.RoundTripper) (*api.SignResponse, error)
}
// Renew renews the current identity certificate using a client with a renew
// method.
func (i *Identity) Renew(client Renewer) error {
switch i.Kind() {
case Disabled:
return nil
case MutualTLS:
cert, err := i.TLSCertificate()
if err != nil {
return err
}
tr := http.DefaultTransport.(*http.Transport).Clone()
tr.TLSClientConfig = &tls.Config{
Certificates: []tls.Certificate{cert},
RootCAs: client.GetRootCAs(),
PreferServerCipherSuites: true,
}
sign, err := client.Renew(tr)
if err != nil {
return err
}
if sign.CertChainPEM == nil || len(sign.CertChainPEM) == 0 {
sign.CertChainPEM = []api.Certificate{sign.ServerPEM, sign.CaPEM}
}
// Write certificate
buf := new(bytes.Buffer)
for _, crt := range sign.CertChainPEM {
block := &pem.Block{
Type: "CERTIFICATE",
Bytes: crt.Raw,
}
if err := pem.Encode(buf, block); err != nil {
return errors.Wrap(err, "error encoding identity certificate")
}
}
certFilename := filepath.Join(identityDir, "identity.crt")
if err := ioutil.WriteFile(certFilename, buf.Bytes(), 0600); err != nil {
return errors.Wrap(err, "error writing identity certificate")
}
return nil
default:
return errors.Errorf("unsupported identity type %s", i.Type)
}
}
func fileExists(filename string) error {
info, err := os.Stat(filename)
if err != nil {

@ -3,15 +3,17 @@ package identity
import (
"crypto"
"crypto/tls"
"crypto/x509"
"fmt"
"io/ioutil"
"net/http"
"os"
"path/filepath"
"reflect"
"testing"
"github.com/smallstep/cli/crypto/pemutil"
"github.com/smallstep/certificates/api"
"github.com/smallstep/cli/crypto/pemutil"
)
func TestLoadDefaultIdentity(t *testing.T) {
@ -252,3 +254,97 @@ func TestWriteDefaultIdentity(t *testing.T) {
})
}
}
type renewer struct {
pool *x509.CertPool
sign *api.SignResponse
err error
}
func (r *renewer) GetRootCAs() *x509.CertPool {
return r.pool
}
func (r *renewer) Renew(tr http.RoundTripper) (*api.SignResponse, error) {
return r.sign, r.err
}
func TestIdentity_Renew(t *testing.T) {
tmpDir, err := ioutil.TempDir(os.TempDir(), "go-tests")
if err != nil {
t.Fatal(err)
}
oldIdentityDir := identityDir
defer func() {
identityDir = oldIdentityDir
os.RemoveAll(tmpDir)
}()
certs, err := pemutil.ReadCertificateBundle("testdata/identity/identity.crt")
if err != nil {
t.Fatal(err)
}
ok := &renewer{
sign: &api.SignResponse{
ServerPEM: api.Certificate{Certificate: certs[0]},
CaPEM: api.Certificate{Certificate: certs[1]},
CertChainPEM: []api.Certificate{
{Certificate: certs[0]},
{Certificate: certs[1]},
},
},
}
okOld := &renewer{
sign: &api.SignResponse{
ServerPEM: api.Certificate{Certificate: certs[0]},
CaPEM: api.Certificate{Certificate: certs[1]},
},
}
fail := &renewer{
err: fmt.Errorf("an error"),
}
type fields struct {
Type string
Certificate string
Key string
}
type args struct {
client Renewer
}
tests := []struct {
name string
prepare func()
fields fields
args args
wantErr bool
}{
{"ok", func() {}, fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, args{ok}, false},
{"ok old", func() {}, fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, args{okOld}, false},
{"ok disabled", func() {}, fields{}, args{nil}, false},
{"fail type", func() {}, fields{"foo", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, args{ok}, true},
{"fail renew", func() {}, fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, args{fail}, true},
{"fail certificate", func() {}, fields{"mTLS", "testdata/certs/server.crt", "testdata/identity/identity_key"}, args{ok}, true},
{"fail write identity", func() {
identityDir = filepath.Join(tmpDir, "bad-dir")
os.MkdirAll(identityDir, 0600)
}, fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, args{ok}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.prepare()
i := &Identity{
Type: tt.fields.Type,
Certificate: tt.fields.Certificate,
Key: tt.fields.Key,
}
if err := i.Renew(tt.args.client); (err != nil) != tt.wantErr {
t.Errorf("Identity.Renew() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

Loading…
Cancel
Save