diff --git a/ca/bootstrap.go b/ca/bootstrap.go index 029d13c8..577e4aaa 100644 --- a/ca/bootstrap.go +++ b/ca/bootstrap.go @@ -43,6 +43,12 @@ func Bootstrap(token string) (*Client, error) { // Authority. By default the server will kick off a routine that will renew the // certificate after 2/3rd of the certificate's lifetime has expired. // +// Without any extra option the server will be configured for mTLS, it will +// require and verify clients certificates, but options can be used to drop this +// requirement, the most common will be only verify the certs if given with +// ca.VerifyClientCertIfGiven(), or add extra CAs with +// ca.AddClientCA(*x509.Certificate). +// // Usage: // // Default example with certificate rotation. // srv, err := ca.BootstrapServer(context.Background(), token, &http.Server{ @@ -61,60 +67,7 @@ func Bootstrap(token string) (*Client, error) { // return err // } // srv.ListenAndServeTLS("", "") -func BootstrapServer(ctx context.Context, token string, base *http.Server) (*http.Server, error) { - if base.TLSConfig != nil { - return nil, errors.New("server TLSConfig is already set") - } - - client, err := Bootstrap(token) - if err != nil { - return nil, err - } - - req, pk, err := CreateSignRequest(token) - if err != nil { - return nil, err - } - - sign, err := client.Sign(req) - if err != nil { - return nil, err - } - - tlsConfig, err := client.GetServerTLSConfig(ctx, sign, pk) - if err != nil { - return nil, err - } - - base.TLSConfig = tlsConfig - return base, nil -} - -// BootstrapServerWithMTLS is a helper function that using the given token -// returns the given http.Server configured with a TLS certificate signed by the -// Certificate Authority, this server will always require and verify a client -// certificate. By default the server will kick off a routine that will renew -// the certificate after 2/3rd of the certificate's lifetime has expired. -// -// Usage: -// // Default example with certificate rotation. -// srv, err := ca.BootstrapServerWithMTLS(context.Background(), token, &http.Server{ -// Addr: ":443", -// Handler: handler, -// }) -// -// // Example canceling automatic certificate rotation. -// ctx, cancel := context.WithCancel(context.Background()) -// defer cancel() -// srv, err := ca.BootstrapServerWithMTLS(ctx, token, &http.Server{ -// Addr: ":443", -// Handler: handler, -// }) -// if err != nil { -// return err -// } -// srv.ListenAndServeTLS("", "") -func BootstrapServerWithMTLS(ctx context.Context, token string, base *http.Server) (*http.Server, error) { +func BootstrapServer(ctx context.Context, token string, base *http.Server, options ...TLSOption) (*http.Server, error) { if base.TLSConfig != nil { return nil, errors.New("server TLSConfig is already set") } @@ -134,7 +87,7 @@ func BootstrapServerWithMTLS(ctx context.Context, token string, base *http.Serve return nil, err } - tlsConfig, err := client.GetServerMutualTLSConfig(ctx, sign, pk) + tlsConfig, err := client.GetServerTLSConfig(ctx, sign, pk, options...) if err != nil { return nil, err } @@ -161,7 +114,7 @@ func BootstrapServerWithMTLS(ctx context.Context, token string, base *http.Serve // return err // } // resp, err := client.Get("https://internal.smallstep.com") -func BootstrapClient(ctx context.Context, token string) (*http.Client, error) { +func BootstrapClient(ctx context.Context, token string, options ...TLSOption) (*http.Client, error) { client, err := Bootstrap(token) if err != nil { return nil, err @@ -177,7 +130,7 @@ func BootstrapClient(ctx context.Context, token string) (*http.Client, error) { return nil, err } - transport, err := client.Transport(ctx, sign, pk) + transport, err := client.Transport(ctx, sign, pk, options...) if err != nil { return nil, err } diff --git a/ca/bootstrap_test.go b/ca/bootstrap_test.go index 50a3f1f1..241827c6 100644 --- a/ca/bootstrap_test.go +++ b/ca/bootstrap_test.go @@ -124,7 +124,7 @@ func TestBootstrap(t *testing.T) { } } -func TestBootstrapServer(t *testing.T) { +func TestBootstrapServerWithoutMTLS(t *testing.T) { srv := startCABootstrapServer() defer srv.Close() token := func() string { @@ -146,7 +146,7 @@ func TestBootstrapServer(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := BootstrapServer(tt.args.ctx, tt.args.token, tt.args.base) + got, err := BootstrapServer(tt.args.ctx, tt.args.token, tt.args.base, VerifyClientCertIfGiven()) if (err != nil) != tt.wantErr { t.Errorf("BootstrapServer() error = %v, wantErr %v", err, tt.wantErr) return @@ -192,24 +192,24 @@ func TestBootstrapServerWithMTLS(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := BootstrapServerWithMTLS(tt.args.ctx, tt.args.token, tt.args.base) + got, err := BootstrapServer(tt.args.ctx, tt.args.token, tt.args.base) if (err != nil) != tt.wantErr { - t.Errorf("BootstrapServerWithMTLS() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("BootstrapServer() error = %v, wantErr %v", err, tt.wantErr) return } if tt.wantErr { if got != nil { - t.Errorf("BootstrapServerWithMTLS() = %v, want nil", got) + t.Errorf("BootstrapServer() = %v, want nil", got) } } else { expected := &http.Server{ TLSConfig: got.TLSConfig, } if !reflect.DeepEqual(got, expected) { - t.Errorf("BootstrapServerWithMTLS() = %v, want %v", got, expected) + t.Errorf("BootstrapServer() = %v, want %v", got, expected) } if got.TLSConfig == nil || got.TLSConfig.ClientCAs == nil || got.TLSConfig.RootCAs == nil || got.TLSConfig.GetCertificate == nil || got.TLSConfig.GetClientCertificate == nil { - t.Errorf("BootstrapServerWithMTLS() invalid TLSConfig = %#v", got.TLSConfig) + t.Errorf("BootstrapServer() invalid TLSConfig = %#v", got.TLSConfig) } } }) diff --git a/ca/tls.go b/ca/tls.go index 800a6adf..5e8c4118 100644 --- a/ca/tls.go +++ b/ca/tls.go @@ -20,7 +20,7 @@ import ( // GetClientTLSConfig returns a tls.Config for client use configured with the // sign certificate, and a new certificate pool with the sign root certificate. // The client certificate will automatically rotate before expiring. -func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey) (*tls.Config, error) { +func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey, options ...TLSOption) (*tls.Config, error) { cert, err := TLSCertificate(sign, pk) if err != nil { return nil, err @@ -36,10 +36,15 @@ func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse, tlsConfig.GetClientCertificate = renewer.GetClientCertificate tlsConfig.PreferServerCipherSuites = true // Build RootCAs with given root certificate - if pool := c.getCertPool(sign); pool != nil { + if pool := getCertPool(sign); pool != nil { tlsConfig.RootCAs = pool } + // Apply options if given + if err := setTLSOptions(tlsConfig, options); err != nil { + return nil, err + } + // Update renew function with transport tr, err := getDefaultTransport(tlsConfig) if err != nil { @@ -56,7 +61,7 @@ func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse, // sign certificate, and a new certificate pool with the sign root certificate. // The returned tls.Config will only verify the client certificate if provided. // The server certificate will automatically rotate before expiring. -func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey) (*tls.Config, error) { +func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey, options ...TLSOption) (*tls.Config, error) { cert, err := TLSCertificate(sign, pk) if err != nil { return nil, err @@ -74,13 +79,18 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse, tlsConfig.GetClientCertificate = renewer.GetClientCertificate tlsConfig.PreferServerCipherSuites = true // Build RootCAs with given root certificate - if pool := c.getCertPool(sign); pool != nil { + if pool := getCertPool(sign); pool != nil { tlsConfig.ClientCAs = pool - tlsConfig.ClientAuth = tls.VerifyClientCertIfGiven + tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert // Add RootCAs for refresh client tlsConfig.RootCAs = pool } + // Apply options if given + if err := setTLSOptions(tlsConfig, options); err != nil { + return nil, err + } + // Update renew function with transport tr, err := getDefaultTransport(tlsConfig) if err != nil { @@ -93,44 +103,15 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse, return tlsConfig, nil } -// GetServerMutualTLSConfig returns a tls.Config for server use configured with -// the sign certificate, and a new certificate pool with the sign root certificate. -// The returned tls.Config will always require and verify a client certificate. -// The server certificate will automatically rotate before expiring. -func (c *Client) GetServerMutualTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey) (*tls.Config, error) { - tlsConfig, err := c.GetServerTLSConfig(ctx, sign, pk) - if err != nil { - return nil, err - } - tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert - return tlsConfig, nil -} - // Transport returns an http.Transport configured to use the client certificate from the sign response. -func (c *Client) Transport(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey) (*http.Transport, error) { - tlsConfig, err := c.GetClientTLSConfig(ctx, sign, pk) +func (c *Client) Transport(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey, options ...TLSOption) (*http.Transport, error) { + tlsConfig, err := c.GetClientTLSConfig(ctx, sign, pk, options...) if err != nil { return nil, err } return getDefaultTransport(tlsConfig) } -// getCertPool returns the transport x509.CertPool or the one from the sign -// request. -func (c *Client) getCertPool(sign *api.SignResponse) *x509.CertPool { - // Return the transport certPool - if c.certPool != nil { - return c.certPool - } - // Return certificate used in sign request. - if root, err := RootCertificate(sign); err == nil { - pool := x509.NewCertPool() - pool.AddCert(root) - return pool - } - return nil -} - // Certificate returns the server or client certificate from the sign response. func Certificate(sign *api.SignResponse) (*x509.Certificate, error) { if sign.ServerPEM.Certificate == nil { @@ -189,6 +170,17 @@ func TLSCertificate(sign *api.SignResponse, pk crypto.PrivateKey) (*tls.Certific return &cert, nil } +// getCertPool returns the transport x509.CertPool or the one from the sign +// request. +func getCertPool(sign *api.SignResponse) *x509.CertPool { + if root, err := RootCertificate(sign); err == nil { + pool := x509.NewCertPool() + pool.AddCert(root) + return pool + } + return nil +} + func getDefaultTLSConfig(sign *api.SignResponse) *tls.Config { if sign.TLSOptions != nil { return sign.TLSOptions.TLSConfig() diff --git a/ca/tls_options.go b/ca/tls_options.go new file mode 100644 index 00000000..fb0bb20b --- /dev/null +++ b/ca/tls_options.go @@ -0,0 +1,64 @@ +package ca + +import ( + "crypto/tls" + "crypto/x509" +) + +// TLSOption defines the type of a function that modifies a tls.Config. +type TLSOption func(c *tls.Config) error + +// setTLSOptions takes one or more option function and applies them in order to +// a tls.Config. +func setTLSOptions(c *tls.Config, options []TLSOption) error { + for _, opt := range options { + if err := opt(c); err != nil { + return err + } + } + return nil +} + +// RequireAndVerifyClientCert is a tls.Config option used on servers to enforce +// a valid TLS client certificate. This is the default option for mTLS servers. +func RequireAndVerifyClientCert() TLSOption { + return func(c *tls.Config) error { + c.ClientAuth = tls.RequireAndVerifyClientCert + return nil + } +} + +// VerifyClientCertIfGiven is a tls.Config option used on on servers to validate +// a TLS client certificate if it is provided. It does not requires a certificate. +func VerifyClientCertIfGiven() TLSOption { + return func(c *tls.Config) error { + c.ClientAuth = tls.VerifyClientCertIfGiven + return nil + } +} + +// AddRootCA adds to the tls.Config RootCAs the given certificate. RootCAs +// defines the set of root certificate authorities that clients use when +// verifying server certificates. +func AddRootCA(cert *x509.Certificate) TLSOption { + return func(c *tls.Config) error { + if c.RootCAs == nil { + c.RootCAs = x509.NewCertPool() + } + c.RootCAs.AddCert(cert) + return nil + } +} + +// AddClientCA adds to the tls.Config ClientCAs the given certificate. ClientCAs +// defines the set of root certificate authorities that servers use if required +// to verify a client certificate by the policy in ClientAuth. +func AddClientCA(cert *x509.Certificate) TLSOption { + return func(c *tls.Config) error { + if c.ClientCAs == nil { + c.ClientCAs = x509.NewCertPool() + } + c.ClientCAs.AddCert(cert) + return nil + } +} diff --git a/ca/tls_options_test.go b/ca/tls_options_test.go new file mode 100644 index 00000000..896ff72b --- /dev/null +++ b/ca/tls_options_test.go @@ -0,0 +1,137 @@ +package ca + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "reflect" + "testing" +) + +func Test_setTLSOptions(t *testing.T) { + fail := func() TLSOption { + return func(c *tls.Config) error { + return fmt.Errorf("an error") + } + } + type args struct { + c *tls.Config + options []TLSOption + } + tests := []struct { + name string + args args + wantErr bool + }{ + {"ok", args{&tls.Config{}, []TLSOption{RequireAndVerifyClientCert()}}, false}, + {"ok", args{&tls.Config{}, []TLSOption{VerifyClientCertIfGiven()}}, false}, + {"fail", args{&tls.Config{}, []TLSOption{VerifyClientCertIfGiven(), fail()}}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := setTLSOptions(tt.args.c, tt.args.options); (err != nil) != tt.wantErr { + t.Errorf("setTLSOptions() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestRequireAndVerifyClientCert(t *testing.T) { + tests := []struct { + name string + want *tls.Config + }{ + {"ok", &tls.Config{ClientAuth: tls.RequireAndVerifyClientCert}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := &tls.Config{} + if err := RequireAndVerifyClientCert()(got); err != nil { + t.Errorf("RequireAndVerifyClientCert() error = %v", err) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("RequireAndVerifyClientCert() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestVerifyClientCertIfGiven(t *testing.T) { + tests := []struct { + name string + want *tls.Config + }{ + {"ok", &tls.Config{ClientAuth: tls.VerifyClientCertIfGiven}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := &tls.Config{} + if err := VerifyClientCertIfGiven()(got); err != nil { + t.Errorf("VerifyClientCertIfGiven() error = %v", err) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("VerifyClientCertIfGiven() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAddRootCA(t *testing.T) { + cert := parseCertificate(rootPEM) + pool := x509.NewCertPool() + pool.AddCert(cert) + + type args struct { + cert *x509.Certificate + } + tests := []struct { + name string + args args + want *tls.Config + }{ + {"ok", args{cert}, &tls.Config{RootCAs: pool}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := &tls.Config{} + if err := AddRootCA(tt.args.cert)(got); err != nil { + t.Errorf("AddRootCA() error = %v", err) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("AddRootCA() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAddClientCA(t *testing.T) { + cert := parseCertificate(rootPEM) + pool := x509.NewCertPool() + pool.AddCert(cert) + + type args struct { + cert *x509.Certificate + } + tests := []struct { + name string + args args + want *tls.Config + }{ + {"ok", args{cert}, &tls.Config{ClientCAs: pool}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := &tls.Config{} + if err := AddClientCA(tt.args.cert)(got); err != nil { + t.Errorf("AddClientCA() error = %v", err) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("AddClientCA() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/ca/tls_test.go b/ca/tls_test.go index 182310c4..799b496a 100644 --- a/ca/tls_test.go +++ b/ca/tls_test.go @@ -104,15 +104,8 @@ func signDuration(srv *httptest.Server, domain string, duration time.Duration) ( return client, sr, pk } -func TestClient_GetServerTLSConfig_http(t *testing.T) { - client, sr, pk := sign("127.0.0.1") - tlsConfig, err := client.GetServerTLSConfig(context.Background(), sr, pk) - if err != nil { - t.Fatalf("Client.GetServerTLSConfig() error = %v", err) - } - clientDomain := "test.domain" - // Create server with given tls.Config - srv := startTestServer(tlsConfig, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { +func serverHandler(t *testing.T, clientDomain string) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { if req.RequestURI != "/no-cert" { if req.TLS == nil || len(req.TLS.PeerCertificates) == 0 { w.Write([]byte("fail")) @@ -129,245 +122,46 @@ func TestClient_GetServerTLSConfig_http(t *testing.T) { t.Errorf("http.Request.TLS.PeerCertificates[0].DNSNames %v, wants %v", req.TLS.PeerCertificates[0].DNSNames, []string{clientDomain}) return } - } - w.Write([]byte("ok")) - })) - defer srv.Close() - - tests := []struct { - name string - path string - wantErr bool - getClient func(*testing.T, *Client, *api.SignResponse, crypto.PrivateKey) *http.Client - }{ - {"with transport", "", false, func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client { - tr, err := client.Transport(context.Background(), sr, pk) - if err != nil { - t.Errorf("Client.Transport() error = %v", err) - return nil - } - return &http.Client{ - Transport: tr, - } - }}, - {"with tlsConfig", "", false, func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client { - tlsConfig, err := client.GetClientTLSConfig(context.Background(), sr, pk) - if err != nil { - t.Errorf("Client.GetClientTLSConfig() error = %v", err) - return nil - } - tr, err := getDefaultTransport(tlsConfig) - if err != nil { - t.Errorf("getDefaultTransport() error = %v", err) - return nil - } - return &http.Client{ - Transport: tr, - } - }}, - {"ok with no cert", "/no-cert", false, func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client { - root, err := RootCertificate(sr) - if err != nil { - t.Errorf("RootCertificate() error = %v", err) - return nil - } - tlsConfig := getDefaultTLSConfig(sr) - tlsConfig.RootCAs = x509.NewCertPool() - tlsConfig.RootCAs.AddCert(root) - tr, err := getDefaultTransport(tlsConfig) - if err != nil { - t.Errorf("getDefaultTransport() error = %v", err) - return nil - } - return &http.Client{ - Transport: tr, - } - }}, - {"fail with default", "/no-cert", true, func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client { - return &http.Client{} - }}, - } + // Add serial number to check rotation + sum := sha256.Sum256(req.TLS.PeerCertificates[0].Raw) + w.Header().Set("x-fingerprint", hex.EncodeToString(sum[:])) + } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - client, sr, pk := sign(clientDomain) - cli := tt.getClient(t, client, sr, pk) - if cli != nil { - resp, err := cli.Get(srv.URL + tt.path) - if (err != nil) != tt.wantErr { - t.Errorf("http.Client.Get() error = %v, wantErr %v", err, tt.wantErr) - return - } - if tt.wantErr { - return - } - defer resp.Body.Close() - b, err := ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatalf("ioutil.RealAdd() error = %v", err) - } - if !bytes.Equal(b, []byte("ok")) { - t.Errorf("response body unexpected, got %s, want ok", b) - } - } - }) - } + w.Write([]byte("ok")) + }) } -func TestClient_GetServerTLSConfig_renew(t *testing.T) { - if testing.Short() { - t.Skip("skipping test in short mode.") - } - - // Start CA - ca := startCATestServer() - defer ca.Close() - - client, sr, pk := signDuration(ca, "127.0.0.1", 1*time.Minute) - tlsConfig, err := client.GetServerTLSConfig(context.Background(), sr, pk) - if err != nil { - t.Fatalf("Client.GetServerTLSConfig() error = %v", err) - } +func TestClient_GetServerTLSConfig_http(t *testing.T) { clientDomain := "test.domain" - fingerprints := make(map[string]struct{}) - - // Create server with given tls.Config - srv := startTestServer(tlsConfig, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - if req.TLS == nil || len(req.TLS.PeerCertificates) == 0 { - w.Write([]byte("fail")) - t.Error("http.Request.TLS does not have peer certificates") - return - } - if req.TLS.PeerCertificates[0].Subject.CommonName != clientDomain { - w.Write([]byte("fail")) - t.Errorf("http.Request.TLS.PeerCertificates[0].Subject.CommonName = %s, wants %s", req.TLS.PeerCertificates[0].Subject.CommonName, clientDomain) - return - } - if !reflect.DeepEqual(req.TLS.PeerCertificates[0].DNSNames, []string{clientDomain}) { - w.Write([]byte("fail")) - t.Errorf("http.Request.TLS.PeerCertificates[0].DNSNames %v, wants %v", req.TLS.PeerCertificates[0].DNSNames, []string{clientDomain}) - return - } - // Add serial number to check rotation - sum := sha256.Sum256(req.TLS.PeerCertificates[0].Raw) - fingerprints[hex.EncodeToString(sum[:])] = struct{}{} - w.Write([]byte("ok")) - })) - defer srv.Close() + client, sr, pk := sign("127.0.0.1") - // Clients: transport and tlsConfig - client, sr, pk = signDuration(ca, clientDomain, 1*time.Minute) - tr1, err := client.Transport(context.Background(), sr, pk) + // Create mTLS server + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + tlsConfig, err := client.GetServerTLSConfig(ctx, sr, pk) if err != nil { - t.Fatalf("Client.Transport() error = %v", err) - } - client, sr, pk = signDuration(ca, clientDomain, 1*time.Minute) - tlsConfig, err = client.GetClientTLSConfig(context.Background(), sr, pk) - if err != nil { - t.Fatalf("Client.GetClientTLSConfig() error = %v", err) - } - tr2, err := getDefaultTransport(tlsConfig) - if err != nil { - t.Fatalf("getDefaultTransport() error = %v", err) - } - - // Disable keep alives to force TLS handshake - tr1.DisableKeepAlives = true - tr2.DisableKeepAlives = true - - tests := []struct { - name string - client *http.Client - }{ - {"with transport", &http.Client{Transport: tr1}}, - {"with tlsConfig", &http.Client{Transport: tr2}}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - resp, err := tt.client.Get(srv.URL) - if err != nil { - t.Fatalf("http.Client.Get() error = %v", err) - } - defer resp.Body.Close() - b, err := ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatalf("ioutil.RealAdd() error = %v", err) - } - if !bytes.Equal(b, []byte("ok")) { - t.Errorf("response body unexpected, got %s, want ok", b) - } - }) - } - - if l := len(fingerprints); l != 2 { - t.Errorf("number of fingerprints unexpected, got %d, want 4", l) - } - - // Wait for renewal 40s == 1m-1m/3 - log.Printf("Sleeping for %s ...\n", 40*time.Second) - time.Sleep(40 * time.Second) - - for _, tt := range tests { - t.Run("renewed "+tt.name, func(t *testing.T) { - resp, err := tt.client.Get(srv.URL) - if err != nil { - t.Fatalf("http.Client.Get() error = %v", err) - } - defer resp.Body.Close() - b, err := ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatalf("ioutil.RealAdd() error = %v", err) - } - if !bytes.Equal(b, []byte("ok")) { - t.Errorf("response body unexpected, got %s, want ok", b) - } - }) - } - - if l := len(fingerprints); l != 4 { - t.Errorf("number of fingerprints unexpected, got %d, want 4", l) + t.Fatalf("Client.GetServerTLSConfig() error = %v", err) } -} + srvMTLS := startTestServer(tlsConfig, serverHandler(t, clientDomain)) + defer srvMTLS.Close() -func TestClient_GetServerMutualTLSConfig_http(t *testing.T) { - client, sr, pk := sign("127.0.0.1") - tlsConfig, err := client.GetServerMutualTLSConfig(context.Background(), sr, pk) + // Create TLS server + ctx, cancel = context.WithCancel(context.Background()) + defer cancel() + tlsConfig, err = client.GetServerTLSConfig(ctx, sr, pk, VerifyClientCertIfGiven()) if err != nil { t.Fatalf("Client.GetServerTLSConfig() error = %v", err) } - clientDomain := "test.domain" - // Create server with given tls.Config - srv := startTestServer(tlsConfig, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - if req.RequestURI != "/no-cert" { - if req.TLS == nil || len(req.TLS.PeerCertificates) == 0 { - w.Write([]byte("fail")) - t.Error("http.Request.TLS does not have peer certificates") - return - } - if req.TLS.PeerCertificates[0].Subject.CommonName != clientDomain { - w.Write([]byte("fail")) - t.Errorf("http.Request.TLS.PeerCertificates[0].Subject.CommonName = %s, wants %s", req.TLS.PeerCertificates[0].Subject.CommonName, clientDomain) - return - } - if !reflect.DeepEqual(req.TLS.PeerCertificates[0].DNSNames, []string{clientDomain}) { - w.Write([]byte("fail")) - t.Errorf("http.Request.TLS.PeerCertificates[0].DNSNames %v, wants %v", req.TLS.PeerCertificates[0].DNSNames, []string{clientDomain}) - return - } - } - w.Write([]byte("ok")) - })) - defer srv.Close() + srvTLS := startTestServer(tlsConfig, serverHandler(t, clientDomain)) + defer srvTLS.Close() tests := []struct { name string - path string - wantErr bool getClient func(*testing.T, *Client, *api.SignResponse, crypto.PrivateKey) *http.Client + wantErr map[string]bool }{ - {"with transport", "", false, func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client { + {"with transport", func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client { tr, err := client.Transport(context.Background(), sr, pk) if err != nil { t.Errorf("Client.Transport() error = %v", err) @@ -376,8 +170,8 @@ func TestClient_GetServerMutualTLSConfig_http(t *testing.T) { return &http.Client{ Transport: tr, } - }}, - {"with tlsConfig", "", false, func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client { + }, map[string]bool{srvTLS.URL: false, srvMTLS.URL: false}}, + {"with tlsConfig", func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client { tlsConfig, err := client.GetClientTLSConfig(context.Background(), sr, pk) if err != nil { t.Errorf("Client.GetClientTLSConfig() error = %v", err) @@ -391,8 +185,8 @@ func TestClient_GetServerMutualTLSConfig_http(t *testing.T) { return &http.Client{ Transport: tr, } - }}, - {"fail with no cert", "/no-cert", true, func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client { + }, map[string]bool{srvTLS.URL: false, srvMTLS.URL: false}}, + {"with no ClientCert", func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client { root, err := RootCertificate(sr) if err != nil { t.Errorf("RootCertificate() error = %v", err) @@ -410,36 +204,44 @@ func TestClient_GetServerMutualTLSConfig_http(t *testing.T) { return &http.Client{ Transport: tr, } - }}, + }, map[string]bool{srvTLS.URL + "/no-cert": false, srvMTLS.URL + "/no-cert": true}}, + {"fail with default", func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client { + return &http.Client{} + }, map[string]bool{srvTLS.URL + "/no-cert": true, srvMTLS.URL + "/no-cert": true}}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { client, sr, pk := sign(clientDomain) cli := tt.getClient(t, client, sr, pk) - if cli != nil { - resp, err := cli.Get(srv.URL + tt.path) - if (err != nil) != tt.wantErr { - t.Errorf("http.Client.Get() error = %v, wantErr %v", err, tt.wantErr) - return - } - if tt.wantErr { - return - } - defer resp.Body.Close() - b, err := ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatalf("ioutil.RealAdd() error = %v", err) - } - if !bytes.Equal(b, []byte("ok")) { - t.Errorf("response body unexpected, got %s, want ok", b) - } + if cli == nil { + return + } + for path, wantErr := range tt.wantErr { + t.Run(path, func(t *testing.T) { + resp, err := cli.Get(path) + if (err != nil) != wantErr { + t.Errorf("http.Client.Get() error = %v, wantErr %v", err, wantErr) + return + } + if wantErr { + return + } + defer resp.Body.Close() + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatalf("ioutil.RealAdd() error = %v", err) + } + if !bytes.Equal(b, []byte("ok")) { + t.Errorf("response body unexpected, got %s, want ok", b) + } + }) } }) } } -func TestClient_GetServerMutualTLSConfig_renew(t *testing.T) { +func TestClient_GetServerTLSConfig_renew(t *testing.T) { if testing.Short() { t.Skip("skipping test in short mode.") } @@ -448,44 +250,36 @@ func TestClient_GetServerMutualTLSConfig_renew(t *testing.T) { ca := startCATestServer() defer ca.Close() + clientDomain := "test.domain" client, sr, pk := signDuration(ca, "127.0.0.1", 1*time.Minute) - tlsConfig, err := client.GetServerMutualTLSConfig(context.Background(), sr, pk) + + // Start mTLS server + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + tlsConfig, err := client.GetServerTLSConfig(ctx, sr, pk) if err != nil { t.Fatalf("Client.GetServerTLSConfig() error = %v", err) } - clientDomain := "test.domain" - fingerprints := make(map[string]struct{}) - - // Create server with given tls.Config - srv := startTestServer(tlsConfig, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - if req.TLS == nil || len(req.TLS.PeerCertificates) == 0 { - w.Write([]byte("fail")) - t.Error("http.Request.TLS does not have peer certificates") - return - } - if req.TLS.PeerCertificates[0].Subject.CommonName != clientDomain { - w.Write([]byte("fail")) - t.Errorf("http.Request.TLS.PeerCertificates[0].Subject.CommonName = %s, wants %s", req.TLS.PeerCertificates[0].Subject.CommonName, clientDomain) - return - } - if !reflect.DeepEqual(req.TLS.PeerCertificates[0].DNSNames, []string{clientDomain}) { - w.Write([]byte("fail")) - t.Errorf("http.Request.TLS.PeerCertificates[0].DNSNames %v, wants %v", req.TLS.PeerCertificates[0].DNSNames, []string{clientDomain}) - return - } - // Add serial number to check rotation - sum := sha256.Sum256(req.TLS.PeerCertificates[0].Raw) - fingerprints[hex.EncodeToString(sum[:])] = struct{}{} - w.Write([]byte("ok")) - })) - defer srv.Close() + srvMTLS := startTestServer(tlsConfig, serverHandler(t, clientDomain)) + defer srvMTLS.Close() + + // Start TLS server + ctx, cancel = context.WithCancel(context.Background()) + defer cancel() + tlsConfig, err = client.GetServerTLSConfig(ctx, sr, pk, VerifyClientCertIfGiven()) + if err != nil { + t.Fatalf("Client.GetServerTLSConfig() error = %v", err) + } + srvTLS := startTestServer(tlsConfig, serverHandler(t, clientDomain)) + defer srvTLS.Close() - // Clients: transport and tlsConfig + // Transport client, sr, pk = signDuration(ca, clientDomain, 1*time.Minute) tr1, err := client.Transport(context.Background(), sr, pk) if err != nil { t.Fatalf("Client.Transport() error = %v", err) } + // Transport with tlsConfig client, sr, pk = signDuration(ca, clientDomain, 1*time.Minute) tlsConfig, err = client.GetClientTLSConfig(context.Background(), sr, pk) if err != nil { @@ -495,38 +289,82 @@ func TestClient_GetServerMutualTLSConfig_renew(t *testing.T) { if err != nil { t.Fatalf("getDefaultTransport() error = %v", err) } + // No client cert + root, err := RootCertificate(sr) + if err != nil { + t.Fatalf("RootCertificate() error = %v", err) + } + tlsConfig = getDefaultTLSConfig(sr) + tlsConfig.RootCAs = x509.NewCertPool() + tlsConfig.RootCAs.AddCert(root) + tr3, err := getDefaultTransport(tlsConfig) + if err != nil { + t.Fatalf("getDefaultTransport() error = %v", err) + } // Disable keep alives to force TLS handshake tr1.DisableKeepAlives = true tr2.DisableKeepAlives = true + tr3.DisableKeepAlives = true tests := []struct { - name string - client *http.Client + name string + client *http.Client + wantErr map[string]bool }{ - {"with transport", &http.Client{Transport: tr1}}, - {"with tlsConfig", &http.Client{Transport: tr2}}, + {"with transport", &http.Client{Transport: tr1}, map[string]bool{ + srvTLS.URL: false, + srvMTLS.URL: false, + }}, + {"with tlsConfig", &http.Client{Transport: tr2}, map[string]bool{ + srvTLS.URL: false, + srvMTLS.URL: false, + }}, + {"with no ClientCert", &http.Client{Transport: tr3}, map[string]bool{ + srvTLS.URL + "/no-cert": false, + srvMTLS.URL + "/no-cert": true, + }}, + {"fail with default", &http.Client{}, map[string]bool{ + srvTLS.URL + "/no-cert": true, + srvMTLS.URL + "/no-cert": true, + }}, } + // To count different cert fingerprints + fingerprints := map[string]struct{}{} for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - resp, err := tt.client.Get(srv.URL) - if err != nil { - t.Fatalf("http.Client.Get() error = %v", err) - } - defer resp.Body.Close() - b, err := ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatalf("ioutil.RealAdd() error = %v", err) - } - if !bytes.Equal(b, []byte("ok")) { - t.Errorf("response body unexpected, got %s, want ok", b) + for path, wantErr := range tt.wantErr { + t.Run(path, func(t *testing.T) { + resp, err := tt.client.Get(path) + if (err != nil) != wantErr { + t.Errorf("http.Client.Get() error = %v", err) + return + } + if wantErr { + return + } + if fp := resp.Header.Get("x-fingerprint"); fp != "" { + fingerprints[fp] = struct{}{} + } + + defer resp.Body.Close() + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Errorf("ioutil.RealAdd() error = %v", err) + return + } + if !bytes.Equal(b, []byte("ok")) { + t.Errorf("response body unexpected, got %s, want ok", b) + return + } + }) } }) } if l := len(fingerprints); l != 2 { - t.Errorf("number of fingerprints unexpected, got %d, want 4", l) + t.Errorf("number of fingerprints unexpected, got %d, want 2", l) } // Wait for renewal 40s == 1m-1m/3 @@ -535,17 +373,31 @@ func TestClient_GetServerMutualTLSConfig_renew(t *testing.T) { for _, tt := range tests { t.Run("renewed "+tt.name, func(t *testing.T) { - resp, err := tt.client.Get(srv.URL) - if err != nil { - t.Fatalf("http.Client.Get() error = %v", err) - } - defer resp.Body.Close() - b, err := ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatalf("ioutil.RealAdd() error = %v", err) - } - if !bytes.Equal(b, []byte("ok")) { - t.Errorf("response body unexpected, got %s, want ok", b) + for path, wantErr := range tt.wantErr { + t.Run(path, func(t *testing.T) { + resp, err := tt.client.Get(path) + if (err != nil) != wantErr { + t.Errorf("http.Client.Get() error = %v", err) + return + } + if wantErr { + return + } + if fp := resp.Header.Get("x-fingerprint"); fp != "" { + fingerprints[fp] = struct{}{} + } + + defer resp.Body.Close() + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Errorf("ioutil.RealAdd() error = %v", err) + return + } + if !bytes.Equal(b, []byte("ok")) { + t.Errorf("response body unexpected, got %s, want ok", b) + return + } + }) } }) } diff --git a/examples/README.md b/examples/README.md index 36a1f530..6afa645b 100644 --- a/examples/README.md +++ b/examples/README.md @@ -139,13 +139,13 @@ password `password` hardcoded, but you can create your own using `step ca init`. These examples show the use of some other helper methods - simple ways to create TLS configured http.Server and http.Client objects. The methods are -`BootstrapServer`, `BootstrapServerWithMTLS` and `BootstrapClient`. +`BootstrapServer` and `BootstrapClient`. ```go // Get a cancelable context to stop the renewal goroutines and timers. ctx, cancel := context.WithCancel(context.Background()) defer cancel() -// Create an http.Server +// Create an http.Server that requires a client certificate srv, err := ca.BootstrapServer(ctx, token, &http.Server{ Addr: ":8443", Handler: handler, @@ -160,11 +160,11 @@ srv.ListenAndServeTLS("", "") // Get a cancelable context to stop the renewal goroutines and timers. ctx, cancel := context.WithCancel(context.Background()) defer cancel() -// Create an http.Server that requires a client certificate +// Create an http.Server that does not require a client certificate srv, err := ca.BootstrapServerWithMTLS(ctx, token, &http.Server{ Addr: ":8443", Handler: handler, -}) +}, ca.VerifyClientCertIfGiven()) if err != nil { panic(err) } @@ -194,13 +194,13 @@ certificates $ bin/step-ca examples/pki/config/ca.json 2018/11/02 18:29:25 Serving HTTPS on :9000 ... ``` -Next we will start the bootstrap-server and enter `password` prompted for the +Next we will start the bootstrap-tls-server and enter `password` prompted for the provisioner password: ```sh certificates $ export STEPPATH=examples/pki certificates $ export STEP_CA_URL=https://localhost:9000 -certificates $ go run examples/bootstrap-server/server.go $(step ca token localhost) +certificates $ go run examples/bootstrap-tls-server/server.go $(step ca token localhost) ✔ Key ID: DmAtZt2EhmZr_iTJJ387fr4Md2NbzMXGdXQNW1UWPXk (mariano@smallstep.com) Please enter the password to decrypt the provisioner key: Listening on :8443 ... diff --git a/examples/bootstrap-mtls-server/server.go b/examples/bootstrap-mtls-server/server.go index 3527c368..1f9c0901 100644 --- a/examples/bootstrap-mtls-server/server.go +++ b/examples/bootstrap-mtls-server/server.go @@ -22,7 +22,7 @@ func main() { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - srv, err := ca.BootstrapServerWithMTLS(ctx, token, &http.Server{ + srv, err := ca.BootstrapServer(ctx, token, &http.Server{ Addr: ":8443", Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { name := "nobody" diff --git a/examples/bootstrap-server/server.go b/examples/bootstrap-tls-server/server.go similarity index 96% rename from examples/bootstrap-server/server.go rename to examples/bootstrap-tls-server/server.go index 1f9c0901..e91f524b 100644 --- a/examples/bootstrap-server/server.go +++ b/examples/bootstrap-tls-server/server.go @@ -31,7 +31,7 @@ func main() { } w.Write([]byte(fmt.Sprintf("Hello %s at %s!!!", name, time.Now().UTC()))) }), - }) + }, ca.VerifyClientCertIfGiven()) if err != nil { panic(err) }