Use mTLS by default on SDK methods.

Add options to modify the tls.Config for different configurations.
Fixes #7
pull/11/head
Mariano Cano 6 years ago
parent bb03aadddf
commit d872f09910

@ -43,6 +43,12 @@ func Bootstrap(token string) (*Client, error) {
// Authority. By default the server will kick off a routine that will renew the
// certificate after 2/3rd of the certificate's lifetime has expired.
//
// Without any extra option the server will be configured for mTLS, it will
// require and verify clients certificates, but options can be used to drop this
// requirement, the most common will be only verify the certs if given with
// ca.VerifyClientCertIfGiven(), or add extra CAs with
// ca.AddClientCA(*x509.Certificate).
//
// Usage:
// // Default example with certificate rotation.
// srv, err := ca.BootstrapServer(context.Background(), token, &http.Server{
@ -61,60 +67,7 @@ func Bootstrap(token string) (*Client, error) {
// return err
// }
// srv.ListenAndServeTLS("", "")
func BootstrapServer(ctx context.Context, token string, base *http.Server) (*http.Server, error) {
if base.TLSConfig != nil {
return nil, errors.New("server TLSConfig is already set")
}
client, err := Bootstrap(token)
if err != nil {
return nil, err
}
req, pk, err := CreateSignRequest(token)
if err != nil {
return nil, err
}
sign, err := client.Sign(req)
if err != nil {
return nil, err
}
tlsConfig, err := client.GetServerTLSConfig(ctx, sign, pk)
if err != nil {
return nil, err
}
base.TLSConfig = tlsConfig
return base, nil
}
// BootstrapServerWithMTLS is a helper function that using the given token
// returns the given http.Server configured with a TLS certificate signed by the
// Certificate Authority, this server will always require and verify a client
// certificate. By default the server will kick off a routine that will renew
// the certificate after 2/3rd of the certificate's lifetime has expired.
//
// Usage:
// // Default example with certificate rotation.
// srv, err := ca.BootstrapServerWithMTLS(context.Background(), token, &http.Server{
// Addr: ":443",
// Handler: handler,
// })
//
// // Example canceling automatic certificate rotation.
// ctx, cancel := context.WithCancel(context.Background())
// defer cancel()
// srv, err := ca.BootstrapServerWithMTLS(ctx, token, &http.Server{
// Addr: ":443",
// Handler: handler,
// })
// if err != nil {
// return err
// }
// srv.ListenAndServeTLS("", "")
func BootstrapServerWithMTLS(ctx context.Context, token string, base *http.Server) (*http.Server, error) {
func BootstrapServer(ctx context.Context, token string, base *http.Server, options ...TLSOption) (*http.Server, error) {
if base.TLSConfig != nil {
return nil, errors.New("server TLSConfig is already set")
}
@ -134,7 +87,7 @@ func BootstrapServerWithMTLS(ctx context.Context, token string, base *http.Serve
return nil, err
}
tlsConfig, err := client.GetServerMutualTLSConfig(ctx, sign, pk)
tlsConfig, err := client.GetServerTLSConfig(ctx, sign, pk, options...)
if err != nil {
return nil, err
}
@ -161,7 +114,7 @@ func BootstrapServerWithMTLS(ctx context.Context, token string, base *http.Serve
// return err
// }
// resp, err := client.Get("https://internal.smallstep.com")
func BootstrapClient(ctx context.Context, token string) (*http.Client, error) {
func BootstrapClient(ctx context.Context, token string, options ...TLSOption) (*http.Client, error) {
client, err := Bootstrap(token)
if err != nil {
return nil, err
@ -177,7 +130,7 @@ func BootstrapClient(ctx context.Context, token string) (*http.Client, error) {
return nil, err
}
transport, err := client.Transport(ctx, sign, pk)
transport, err := client.Transport(ctx, sign, pk, options...)
if err != nil {
return nil, err
}

@ -124,7 +124,7 @@ func TestBootstrap(t *testing.T) {
}
}
func TestBootstrapServer(t *testing.T) {
func TestBootstrapServerWithoutMTLS(t *testing.T) {
srv := startCABootstrapServer()
defer srv.Close()
token := func() string {
@ -146,7 +146,7 @@ func TestBootstrapServer(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := BootstrapServer(tt.args.ctx, tt.args.token, tt.args.base)
got, err := BootstrapServer(tt.args.ctx, tt.args.token, tt.args.base, VerifyClientCertIfGiven())
if (err != nil) != tt.wantErr {
t.Errorf("BootstrapServer() error = %v, wantErr %v", err, tt.wantErr)
return
@ -192,24 +192,24 @@ func TestBootstrapServerWithMTLS(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := BootstrapServerWithMTLS(tt.args.ctx, tt.args.token, tt.args.base)
got, err := BootstrapServer(tt.args.ctx, tt.args.token, tt.args.base)
if (err != nil) != tt.wantErr {
t.Errorf("BootstrapServerWithMTLS() error = %v, wantErr %v", err, tt.wantErr)
t.Errorf("BootstrapServer() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.wantErr {
if got != nil {
t.Errorf("BootstrapServerWithMTLS() = %v, want nil", got)
t.Errorf("BootstrapServer() = %v, want nil", got)
}
} else {
expected := &http.Server{
TLSConfig: got.TLSConfig,
}
if !reflect.DeepEqual(got, expected) {
t.Errorf("BootstrapServerWithMTLS() = %v, want %v", got, expected)
t.Errorf("BootstrapServer() = %v, want %v", got, expected)
}
if got.TLSConfig == nil || got.TLSConfig.ClientCAs == nil || got.TLSConfig.RootCAs == nil || got.TLSConfig.GetCertificate == nil || got.TLSConfig.GetClientCertificate == nil {
t.Errorf("BootstrapServerWithMTLS() invalid TLSConfig = %#v", got.TLSConfig)
t.Errorf("BootstrapServer() invalid TLSConfig = %#v", got.TLSConfig)
}
}
})

@ -20,7 +20,7 @@ import (
// GetClientTLSConfig returns a tls.Config for client use configured with the
// sign certificate, and a new certificate pool with the sign root certificate.
// The client certificate will automatically rotate before expiring.
func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey) (*tls.Config, error) {
func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey, options ...TLSOption) (*tls.Config, error) {
cert, err := TLSCertificate(sign, pk)
if err != nil {
return nil, err
@ -36,10 +36,15 @@ func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse,
tlsConfig.GetClientCertificate = renewer.GetClientCertificate
tlsConfig.PreferServerCipherSuites = true
// Build RootCAs with given root certificate
if pool := c.getCertPool(sign); pool != nil {
if pool := getCertPool(sign); pool != nil {
tlsConfig.RootCAs = pool
}
// Apply options if given
if err := setTLSOptions(tlsConfig, options); err != nil {
return nil, err
}
// Update renew function with transport
tr, err := getDefaultTransport(tlsConfig)
if err != nil {
@ -56,7 +61,7 @@ func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse,
// sign certificate, and a new certificate pool with the sign root certificate.
// The returned tls.Config will only verify the client certificate if provided.
// The server certificate will automatically rotate before expiring.
func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey) (*tls.Config, error) {
func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey, options ...TLSOption) (*tls.Config, error) {
cert, err := TLSCertificate(sign, pk)
if err != nil {
return nil, err
@ -74,13 +79,18 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse,
tlsConfig.GetClientCertificate = renewer.GetClientCertificate
tlsConfig.PreferServerCipherSuites = true
// Build RootCAs with given root certificate
if pool := c.getCertPool(sign); pool != nil {
if pool := getCertPool(sign); pool != nil {
tlsConfig.ClientCAs = pool
tlsConfig.ClientAuth = tls.VerifyClientCertIfGiven
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
// Add RootCAs for refresh client
tlsConfig.RootCAs = pool
}
// Apply options if given
if err := setTLSOptions(tlsConfig, options); err != nil {
return nil, err
}
// Update renew function with transport
tr, err := getDefaultTransport(tlsConfig)
if err != nil {
@ -93,44 +103,15 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse,
return tlsConfig, nil
}
// GetServerMutualTLSConfig returns a tls.Config for server use configured with
// the sign certificate, and a new certificate pool with the sign root certificate.
// The returned tls.Config will always require and verify a client certificate.
// The server certificate will automatically rotate before expiring.
func (c *Client) GetServerMutualTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey) (*tls.Config, error) {
tlsConfig, err := c.GetServerTLSConfig(ctx, sign, pk)
if err != nil {
return nil, err
}
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
return tlsConfig, nil
}
// Transport returns an http.Transport configured to use the client certificate from the sign response.
func (c *Client) Transport(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey) (*http.Transport, error) {
tlsConfig, err := c.GetClientTLSConfig(ctx, sign, pk)
func (c *Client) Transport(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey, options ...TLSOption) (*http.Transport, error) {
tlsConfig, err := c.GetClientTLSConfig(ctx, sign, pk, options...)
if err != nil {
return nil, err
}
return getDefaultTransport(tlsConfig)
}
// getCertPool returns the transport x509.CertPool or the one from the sign
// request.
func (c *Client) getCertPool(sign *api.SignResponse) *x509.CertPool {
// Return the transport certPool
if c.certPool != nil {
return c.certPool
}
// Return certificate used in sign request.
if root, err := RootCertificate(sign); err == nil {
pool := x509.NewCertPool()
pool.AddCert(root)
return pool
}
return nil
}
// Certificate returns the server or client certificate from the sign response.
func Certificate(sign *api.SignResponse) (*x509.Certificate, error) {
if sign.ServerPEM.Certificate == nil {
@ -189,6 +170,17 @@ func TLSCertificate(sign *api.SignResponse, pk crypto.PrivateKey) (*tls.Certific
return &cert, nil
}
// getCertPool returns the transport x509.CertPool or the one from the sign
// request.
func getCertPool(sign *api.SignResponse) *x509.CertPool {
if root, err := RootCertificate(sign); err == nil {
pool := x509.NewCertPool()
pool.AddCert(root)
return pool
}
return nil
}
func getDefaultTLSConfig(sign *api.SignResponse) *tls.Config {
if sign.TLSOptions != nil {
return sign.TLSOptions.TLSConfig()

@ -0,0 +1,64 @@
package ca
import (
"crypto/tls"
"crypto/x509"
)
// TLSOption defines the type of a function that modifies a tls.Config.
type TLSOption func(c *tls.Config) error
// setTLSOptions takes one or more option function and applies them in order to
// a tls.Config.
func setTLSOptions(c *tls.Config, options []TLSOption) error {
for _, opt := range options {
if err := opt(c); err != nil {
return err
}
}
return nil
}
// RequireAndVerifyClientCert is a tls.Config option used on servers to enforce
// a valid TLS client certificate. This is the default option for mTLS servers.
func RequireAndVerifyClientCert() TLSOption {
return func(c *tls.Config) error {
c.ClientAuth = tls.RequireAndVerifyClientCert
return nil
}
}
// VerifyClientCertIfGiven is a tls.Config option used on on servers to validate
// a TLS client certificate if it is provided. It does not requires a certificate.
func VerifyClientCertIfGiven() TLSOption {
return func(c *tls.Config) error {
c.ClientAuth = tls.VerifyClientCertIfGiven
return nil
}
}
// AddRootCA adds to the tls.Config RootCAs the given certificate. RootCAs
// defines the set of root certificate authorities that clients use when
// verifying server certificates.
func AddRootCA(cert *x509.Certificate) TLSOption {
return func(c *tls.Config) error {
if c.RootCAs == nil {
c.RootCAs = x509.NewCertPool()
}
c.RootCAs.AddCert(cert)
return nil
}
}
// AddClientCA adds to the tls.Config ClientCAs the given certificate. ClientCAs
// defines the set of root certificate authorities that servers use if required
// to verify a client certificate by the policy in ClientAuth.
func AddClientCA(cert *x509.Certificate) TLSOption {
return func(c *tls.Config) error {
if c.ClientCAs == nil {
c.ClientCAs = x509.NewCertPool()
}
c.ClientCAs.AddCert(cert)
return nil
}
}

@ -0,0 +1,137 @@
package ca
import (
"crypto/tls"
"crypto/x509"
"fmt"
"reflect"
"testing"
)
func Test_setTLSOptions(t *testing.T) {
fail := func() TLSOption {
return func(c *tls.Config) error {
return fmt.Errorf("an error")
}
}
type args struct {
c *tls.Config
options []TLSOption
}
tests := []struct {
name string
args args
wantErr bool
}{
{"ok", args{&tls.Config{}, []TLSOption{RequireAndVerifyClientCert()}}, false},
{"ok", args{&tls.Config{}, []TLSOption{VerifyClientCertIfGiven()}}, false},
{"fail", args{&tls.Config{}, []TLSOption{VerifyClientCertIfGiven(), fail()}}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := setTLSOptions(tt.args.c, tt.args.options); (err != nil) != tt.wantErr {
t.Errorf("setTLSOptions() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestRequireAndVerifyClientCert(t *testing.T) {
tests := []struct {
name string
want *tls.Config
}{
{"ok", &tls.Config{ClientAuth: tls.RequireAndVerifyClientCert}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := &tls.Config{}
if err := RequireAndVerifyClientCert()(got); err != nil {
t.Errorf("RequireAndVerifyClientCert() error = %v", err)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("RequireAndVerifyClientCert() = %v, want %v", got, tt.want)
}
})
}
}
func TestVerifyClientCertIfGiven(t *testing.T) {
tests := []struct {
name string
want *tls.Config
}{
{"ok", &tls.Config{ClientAuth: tls.VerifyClientCertIfGiven}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := &tls.Config{}
if err := VerifyClientCertIfGiven()(got); err != nil {
t.Errorf("VerifyClientCertIfGiven() error = %v", err)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("VerifyClientCertIfGiven() = %v, want %v", got, tt.want)
}
})
}
}
func TestAddRootCA(t *testing.T) {
cert := parseCertificate(rootPEM)
pool := x509.NewCertPool()
pool.AddCert(cert)
type args struct {
cert *x509.Certificate
}
tests := []struct {
name string
args args
want *tls.Config
}{
{"ok", args{cert}, &tls.Config{RootCAs: pool}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := &tls.Config{}
if err := AddRootCA(tt.args.cert)(got); err != nil {
t.Errorf("AddRootCA() error = %v", err)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("AddRootCA() = %v, want %v", got, tt.want)
}
})
}
}
func TestAddClientCA(t *testing.T) {
cert := parseCertificate(rootPEM)
pool := x509.NewCertPool()
pool.AddCert(cert)
type args struct {
cert *x509.Certificate
}
tests := []struct {
name string
args args
want *tls.Config
}{
{"ok", args{cert}, &tls.Config{ClientCAs: pool}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := &tls.Config{}
if err := AddClientCA(tt.args.cert)(got); err != nil {
t.Errorf("AddClientCA() error = %v", err)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("AddClientCA() = %v, want %v", got, tt.want)
}
})
}
}

@ -104,15 +104,8 @@ func signDuration(srv *httptest.Server, domain string, duration time.Duration) (
return client, sr, pk
}
func TestClient_GetServerTLSConfig_http(t *testing.T) {
client, sr, pk := sign("127.0.0.1")
tlsConfig, err := client.GetServerTLSConfig(context.Background(), sr, pk)
if err != nil {
t.Fatalf("Client.GetServerTLSConfig() error = %v", err)
}
clientDomain := "test.domain"
// Create server with given tls.Config
srv := startTestServer(tlsConfig, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
func serverHandler(t *testing.T, clientDomain string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if req.RequestURI != "/no-cert" {
if req.TLS == nil || len(req.TLS.PeerCertificates) == 0 {
w.Write([]byte("fail"))
@ -129,245 +122,46 @@ func TestClient_GetServerTLSConfig_http(t *testing.T) {
t.Errorf("http.Request.TLS.PeerCertificates[0].DNSNames %v, wants %v", req.TLS.PeerCertificates[0].DNSNames, []string{clientDomain})
return
}
}
w.Write([]byte("ok"))
}))
defer srv.Close()
tests := []struct {
name string
path string
wantErr bool
getClient func(*testing.T, *Client, *api.SignResponse, crypto.PrivateKey) *http.Client
}{
{"with transport", "", false, func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
tr, err := client.Transport(context.Background(), sr, pk)
if err != nil {
t.Errorf("Client.Transport() error = %v", err)
return nil
}
return &http.Client{
Transport: tr,
}
}},
{"with tlsConfig", "", false, func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
tlsConfig, err := client.GetClientTLSConfig(context.Background(), sr, pk)
if err != nil {
t.Errorf("Client.GetClientTLSConfig() error = %v", err)
return nil
}
tr, err := getDefaultTransport(tlsConfig)
if err != nil {
t.Errorf("getDefaultTransport() error = %v", err)
return nil
}
return &http.Client{
Transport: tr,
}
}},
{"ok with no cert", "/no-cert", false, func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
root, err := RootCertificate(sr)
if err != nil {
t.Errorf("RootCertificate() error = %v", err)
return nil
}
tlsConfig := getDefaultTLSConfig(sr)
tlsConfig.RootCAs = x509.NewCertPool()
tlsConfig.RootCAs.AddCert(root)
tr, err := getDefaultTransport(tlsConfig)
if err != nil {
t.Errorf("getDefaultTransport() error = %v", err)
return nil
}
return &http.Client{
Transport: tr,
}
}},
{"fail with default", "/no-cert", true, func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
return &http.Client{}
}},
}
// Add serial number to check rotation
sum := sha256.Sum256(req.TLS.PeerCertificates[0].Raw)
w.Header().Set("x-fingerprint", hex.EncodeToString(sum[:]))
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
client, sr, pk := sign(clientDomain)
cli := tt.getClient(t, client, sr, pk)
if cli != nil {
resp, err := cli.Get(srv.URL + tt.path)
if (err != nil) != tt.wantErr {
t.Errorf("http.Client.Get() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.wantErr {
return
}
defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatalf("ioutil.RealAdd() error = %v", err)
}
if !bytes.Equal(b, []byte("ok")) {
t.Errorf("response body unexpected, got %s, want ok", b)
}
}
})
}
w.Write([]byte("ok"))
})
}
func TestClient_GetServerTLSConfig_renew(t *testing.T) {
if testing.Short() {
t.Skip("skipping test in short mode.")
}
// Start CA
ca := startCATestServer()
defer ca.Close()
client, sr, pk := signDuration(ca, "127.0.0.1", 1*time.Minute)
tlsConfig, err := client.GetServerTLSConfig(context.Background(), sr, pk)
if err != nil {
t.Fatalf("Client.GetServerTLSConfig() error = %v", err)
}
func TestClient_GetServerTLSConfig_http(t *testing.T) {
clientDomain := "test.domain"
fingerprints := make(map[string]struct{})
// Create server with given tls.Config
srv := startTestServer(tlsConfig, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if req.TLS == nil || len(req.TLS.PeerCertificates) == 0 {
w.Write([]byte("fail"))
t.Error("http.Request.TLS does not have peer certificates")
return
}
if req.TLS.PeerCertificates[0].Subject.CommonName != clientDomain {
w.Write([]byte("fail"))
t.Errorf("http.Request.TLS.PeerCertificates[0].Subject.CommonName = %s, wants %s", req.TLS.PeerCertificates[0].Subject.CommonName, clientDomain)
return
}
if !reflect.DeepEqual(req.TLS.PeerCertificates[0].DNSNames, []string{clientDomain}) {
w.Write([]byte("fail"))
t.Errorf("http.Request.TLS.PeerCertificates[0].DNSNames %v, wants %v", req.TLS.PeerCertificates[0].DNSNames, []string{clientDomain})
return
}
// Add serial number to check rotation
sum := sha256.Sum256(req.TLS.PeerCertificates[0].Raw)
fingerprints[hex.EncodeToString(sum[:])] = struct{}{}
w.Write([]byte("ok"))
}))
defer srv.Close()
client, sr, pk := sign("127.0.0.1")
// Clients: transport and tlsConfig
client, sr, pk = signDuration(ca, clientDomain, 1*time.Minute)
tr1, err := client.Transport(context.Background(), sr, pk)
// Create mTLS server
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
tlsConfig, err := client.GetServerTLSConfig(ctx, sr, pk)
if err != nil {
t.Fatalf("Client.Transport() error = %v", err)
}
client, sr, pk = signDuration(ca, clientDomain, 1*time.Minute)
tlsConfig, err = client.GetClientTLSConfig(context.Background(), sr, pk)
if err != nil {
t.Fatalf("Client.GetClientTLSConfig() error = %v", err)
}
tr2, err := getDefaultTransport(tlsConfig)
if err != nil {
t.Fatalf("getDefaultTransport() error = %v", err)
}
// Disable keep alives to force TLS handshake
tr1.DisableKeepAlives = true
tr2.DisableKeepAlives = true
tests := []struct {
name string
client *http.Client
}{
{"with transport", &http.Client{Transport: tr1}},
{"with tlsConfig", &http.Client{Transport: tr2}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resp, err := tt.client.Get(srv.URL)
if err != nil {
t.Fatalf("http.Client.Get() error = %v", err)
}
defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatalf("ioutil.RealAdd() error = %v", err)
}
if !bytes.Equal(b, []byte("ok")) {
t.Errorf("response body unexpected, got %s, want ok", b)
}
})
}
if l := len(fingerprints); l != 2 {
t.Errorf("number of fingerprints unexpected, got %d, want 4", l)
}
// Wait for renewal 40s == 1m-1m/3
log.Printf("Sleeping for %s ...\n", 40*time.Second)
time.Sleep(40 * time.Second)
for _, tt := range tests {
t.Run("renewed "+tt.name, func(t *testing.T) {
resp, err := tt.client.Get(srv.URL)
if err != nil {
t.Fatalf("http.Client.Get() error = %v", err)
}
defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatalf("ioutil.RealAdd() error = %v", err)
}
if !bytes.Equal(b, []byte("ok")) {
t.Errorf("response body unexpected, got %s, want ok", b)
}
})
}
if l := len(fingerprints); l != 4 {
t.Errorf("number of fingerprints unexpected, got %d, want 4", l)
t.Fatalf("Client.GetServerTLSConfig() error = %v", err)
}
}
srvMTLS := startTestServer(tlsConfig, serverHandler(t, clientDomain))
defer srvMTLS.Close()
func TestClient_GetServerMutualTLSConfig_http(t *testing.T) {
client, sr, pk := sign("127.0.0.1")
tlsConfig, err := client.GetServerMutualTLSConfig(context.Background(), sr, pk)
// Create TLS server
ctx, cancel = context.WithCancel(context.Background())
defer cancel()
tlsConfig, err = client.GetServerTLSConfig(ctx, sr, pk, VerifyClientCertIfGiven())
if err != nil {
t.Fatalf("Client.GetServerTLSConfig() error = %v", err)
}
clientDomain := "test.domain"
// Create server with given tls.Config
srv := startTestServer(tlsConfig, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if req.RequestURI != "/no-cert" {
if req.TLS == nil || len(req.TLS.PeerCertificates) == 0 {
w.Write([]byte("fail"))
t.Error("http.Request.TLS does not have peer certificates")
return
}
if req.TLS.PeerCertificates[0].Subject.CommonName != clientDomain {
w.Write([]byte("fail"))
t.Errorf("http.Request.TLS.PeerCertificates[0].Subject.CommonName = %s, wants %s", req.TLS.PeerCertificates[0].Subject.CommonName, clientDomain)
return
}
if !reflect.DeepEqual(req.TLS.PeerCertificates[0].DNSNames, []string{clientDomain}) {
w.Write([]byte("fail"))
t.Errorf("http.Request.TLS.PeerCertificates[0].DNSNames %v, wants %v", req.TLS.PeerCertificates[0].DNSNames, []string{clientDomain})
return
}
}
w.Write([]byte("ok"))
}))
defer srv.Close()
srvTLS := startTestServer(tlsConfig, serverHandler(t, clientDomain))
defer srvTLS.Close()
tests := []struct {
name string
path string
wantErr bool
getClient func(*testing.T, *Client, *api.SignResponse, crypto.PrivateKey) *http.Client
wantErr map[string]bool
}{
{"with transport", "", false, func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
{"with transport", func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
tr, err := client.Transport(context.Background(), sr, pk)
if err != nil {
t.Errorf("Client.Transport() error = %v", err)
@ -376,8 +170,8 @@ func TestClient_GetServerMutualTLSConfig_http(t *testing.T) {
return &http.Client{
Transport: tr,
}
}},
{"with tlsConfig", "", false, func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
}, map[string]bool{srvTLS.URL: false, srvMTLS.URL: false}},
{"with tlsConfig", func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
tlsConfig, err := client.GetClientTLSConfig(context.Background(), sr, pk)
if err != nil {
t.Errorf("Client.GetClientTLSConfig() error = %v", err)
@ -391,8 +185,8 @@ func TestClient_GetServerMutualTLSConfig_http(t *testing.T) {
return &http.Client{
Transport: tr,
}
}},
{"fail with no cert", "/no-cert", true, func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
}, map[string]bool{srvTLS.URL: false, srvMTLS.URL: false}},
{"with no ClientCert", func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
root, err := RootCertificate(sr)
if err != nil {
t.Errorf("RootCertificate() error = %v", err)
@ -410,36 +204,44 @@ func TestClient_GetServerMutualTLSConfig_http(t *testing.T) {
return &http.Client{
Transport: tr,
}
}},
}, map[string]bool{srvTLS.URL + "/no-cert": false, srvMTLS.URL + "/no-cert": true}},
{"fail with default", func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
return &http.Client{}
}, map[string]bool{srvTLS.URL + "/no-cert": true, srvMTLS.URL + "/no-cert": true}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
client, sr, pk := sign(clientDomain)
cli := tt.getClient(t, client, sr, pk)
if cli != nil {
resp, err := cli.Get(srv.URL + tt.path)
if (err != nil) != tt.wantErr {
t.Errorf("http.Client.Get() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.wantErr {
return
}
defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatalf("ioutil.RealAdd() error = %v", err)
}
if !bytes.Equal(b, []byte("ok")) {
t.Errorf("response body unexpected, got %s, want ok", b)
}
if cli == nil {
return
}
for path, wantErr := range tt.wantErr {
t.Run(path, func(t *testing.T) {
resp, err := cli.Get(path)
if (err != nil) != wantErr {
t.Errorf("http.Client.Get() error = %v, wantErr %v", err, wantErr)
return
}
if wantErr {
return
}
defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatalf("ioutil.RealAdd() error = %v", err)
}
if !bytes.Equal(b, []byte("ok")) {
t.Errorf("response body unexpected, got %s, want ok", b)
}
})
}
})
}
}
func TestClient_GetServerMutualTLSConfig_renew(t *testing.T) {
func TestClient_GetServerTLSConfig_renew(t *testing.T) {
if testing.Short() {
t.Skip("skipping test in short mode.")
}
@ -448,44 +250,36 @@ func TestClient_GetServerMutualTLSConfig_renew(t *testing.T) {
ca := startCATestServer()
defer ca.Close()
clientDomain := "test.domain"
client, sr, pk := signDuration(ca, "127.0.0.1", 1*time.Minute)
tlsConfig, err := client.GetServerMutualTLSConfig(context.Background(), sr, pk)
// Start mTLS server
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
tlsConfig, err := client.GetServerTLSConfig(ctx, sr, pk)
if err != nil {
t.Fatalf("Client.GetServerTLSConfig() error = %v", err)
}
clientDomain := "test.domain"
fingerprints := make(map[string]struct{})
// Create server with given tls.Config
srv := startTestServer(tlsConfig, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if req.TLS == nil || len(req.TLS.PeerCertificates) == 0 {
w.Write([]byte("fail"))
t.Error("http.Request.TLS does not have peer certificates")
return
}
if req.TLS.PeerCertificates[0].Subject.CommonName != clientDomain {
w.Write([]byte("fail"))
t.Errorf("http.Request.TLS.PeerCertificates[0].Subject.CommonName = %s, wants %s", req.TLS.PeerCertificates[0].Subject.CommonName, clientDomain)
return
}
if !reflect.DeepEqual(req.TLS.PeerCertificates[0].DNSNames, []string{clientDomain}) {
w.Write([]byte("fail"))
t.Errorf("http.Request.TLS.PeerCertificates[0].DNSNames %v, wants %v", req.TLS.PeerCertificates[0].DNSNames, []string{clientDomain})
return
}
// Add serial number to check rotation
sum := sha256.Sum256(req.TLS.PeerCertificates[0].Raw)
fingerprints[hex.EncodeToString(sum[:])] = struct{}{}
w.Write([]byte("ok"))
}))
defer srv.Close()
srvMTLS := startTestServer(tlsConfig, serverHandler(t, clientDomain))
defer srvMTLS.Close()
// Start TLS server
ctx, cancel = context.WithCancel(context.Background())
defer cancel()
tlsConfig, err = client.GetServerTLSConfig(ctx, sr, pk, VerifyClientCertIfGiven())
if err != nil {
t.Fatalf("Client.GetServerTLSConfig() error = %v", err)
}
srvTLS := startTestServer(tlsConfig, serverHandler(t, clientDomain))
defer srvTLS.Close()
// Clients: transport and tlsConfig
// Transport
client, sr, pk = signDuration(ca, clientDomain, 1*time.Minute)
tr1, err := client.Transport(context.Background(), sr, pk)
if err != nil {
t.Fatalf("Client.Transport() error = %v", err)
}
// Transport with tlsConfig
client, sr, pk = signDuration(ca, clientDomain, 1*time.Minute)
tlsConfig, err = client.GetClientTLSConfig(context.Background(), sr, pk)
if err != nil {
@ -495,38 +289,82 @@ func TestClient_GetServerMutualTLSConfig_renew(t *testing.T) {
if err != nil {
t.Fatalf("getDefaultTransport() error = %v", err)
}
// No client cert
root, err := RootCertificate(sr)
if err != nil {
t.Fatalf("RootCertificate() error = %v", err)
}
tlsConfig = getDefaultTLSConfig(sr)
tlsConfig.RootCAs = x509.NewCertPool()
tlsConfig.RootCAs.AddCert(root)
tr3, err := getDefaultTransport(tlsConfig)
if err != nil {
t.Fatalf("getDefaultTransport() error = %v", err)
}
// Disable keep alives to force TLS handshake
tr1.DisableKeepAlives = true
tr2.DisableKeepAlives = true
tr3.DisableKeepAlives = true
tests := []struct {
name string
client *http.Client
name string
client *http.Client
wantErr map[string]bool
}{
{"with transport", &http.Client{Transport: tr1}},
{"with tlsConfig", &http.Client{Transport: tr2}},
{"with transport", &http.Client{Transport: tr1}, map[string]bool{
srvTLS.URL: false,
srvMTLS.URL: false,
}},
{"with tlsConfig", &http.Client{Transport: tr2}, map[string]bool{
srvTLS.URL: false,
srvMTLS.URL: false,
}},
{"with no ClientCert", &http.Client{Transport: tr3}, map[string]bool{
srvTLS.URL + "/no-cert": false,
srvMTLS.URL + "/no-cert": true,
}},
{"fail with default", &http.Client{}, map[string]bool{
srvTLS.URL + "/no-cert": true,
srvMTLS.URL + "/no-cert": true,
}},
}
// To count different cert fingerprints
fingerprints := map[string]struct{}{}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resp, err := tt.client.Get(srv.URL)
if err != nil {
t.Fatalf("http.Client.Get() error = %v", err)
}
defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatalf("ioutil.RealAdd() error = %v", err)
}
if !bytes.Equal(b, []byte("ok")) {
t.Errorf("response body unexpected, got %s, want ok", b)
for path, wantErr := range tt.wantErr {
t.Run(path, func(t *testing.T) {
resp, err := tt.client.Get(path)
if (err != nil) != wantErr {
t.Errorf("http.Client.Get() error = %v", err)
return
}
if wantErr {
return
}
if fp := resp.Header.Get("x-fingerprint"); fp != "" {
fingerprints[fp] = struct{}{}
}
defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Errorf("ioutil.RealAdd() error = %v", err)
return
}
if !bytes.Equal(b, []byte("ok")) {
t.Errorf("response body unexpected, got %s, want ok", b)
return
}
})
}
})
}
if l := len(fingerprints); l != 2 {
t.Errorf("number of fingerprints unexpected, got %d, want 4", l)
t.Errorf("number of fingerprints unexpected, got %d, want 2", l)
}
// Wait for renewal 40s == 1m-1m/3
@ -535,17 +373,31 @@ func TestClient_GetServerMutualTLSConfig_renew(t *testing.T) {
for _, tt := range tests {
t.Run("renewed "+tt.name, func(t *testing.T) {
resp, err := tt.client.Get(srv.URL)
if err != nil {
t.Fatalf("http.Client.Get() error = %v", err)
}
defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatalf("ioutil.RealAdd() error = %v", err)
}
if !bytes.Equal(b, []byte("ok")) {
t.Errorf("response body unexpected, got %s, want ok", b)
for path, wantErr := range tt.wantErr {
t.Run(path, func(t *testing.T) {
resp, err := tt.client.Get(path)
if (err != nil) != wantErr {
t.Errorf("http.Client.Get() error = %v", err)
return
}
if wantErr {
return
}
if fp := resp.Header.Get("x-fingerprint"); fp != "" {
fingerprints[fp] = struct{}{}
}
defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Errorf("ioutil.RealAdd() error = %v", err)
return
}
if !bytes.Equal(b, []byte("ok")) {
t.Errorf("response body unexpected, got %s, want ok", b)
return
}
})
}
})
}

@ -139,13 +139,13 @@ password `password` hardcoded, but you can create your own using `step ca init`.
These examples show the use of some other helper methods - simple ways to
create TLS configured http.Server and http.Client objects. The methods are
`BootstrapServer`, `BootstrapServerWithMTLS` and `BootstrapClient`.
`BootstrapServer` and `BootstrapClient`.
```go
// Get a cancelable context to stop the renewal goroutines and timers.
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Create an http.Server
// Create an http.Server that requires a client certificate
srv, err := ca.BootstrapServer(ctx, token, &http.Server{
Addr: ":8443",
Handler: handler,
@ -160,11 +160,11 @@ srv.ListenAndServeTLS("", "")
// Get a cancelable context to stop the renewal goroutines and timers.
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Create an http.Server that requires a client certificate
// Create an http.Server that does not require a client certificate
srv, err := ca.BootstrapServerWithMTLS(ctx, token, &http.Server{
Addr: ":8443",
Handler: handler,
})
}, ca.VerifyClientCertIfGiven())
if err != nil {
panic(err)
}
@ -194,13 +194,13 @@ certificates $ bin/step-ca examples/pki/config/ca.json
2018/11/02 18:29:25 Serving HTTPS on :9000 ...
```
Next we will start the bootstrap-server and enter `password` prompted for the
Next we will start the bootstrap-tls-server and enter `password` prompted for the
provisioner password:
```sh
certificates $ export STEPPATH=examples/pki
certificates $ export STEP_CA_URL=https://localhost:9000
certificates $ go run examples/bootstrap-server/server.go $(step ca token localhost)
certificates $ go run examples/bootstrap-tls-server/server.go $(step ca token localhost)
✔ Key ID: DmAtZt2EhmZr_iTJJ387fr4Md2NbzMXGdXQNW1UWPXk (mariano@smallstep.com)
Please enter the password to decrypt the provisioner key:
Listening on :8443 ...

@ -22,7 +22,7 @@ func main() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
srv, err := ca.BootstrapServerWithMTLS(ctx, token, &http.Server{
srv, err := ca.BootstrapServer(ctx, token, &http.Server{
Addr: ":8443",
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
name := "nobody"

@ -31,7 +31,7 @@ func main() {
}
w.Write([]byte(fmt.Sprintf("Hello %s at %s!!!", name, time.Now().UTC())))
}),
})
}, ca.VerifyClientCertIfGiven())
if err != nil {
panic(err)
}
Loading…
Cancel
Save