diff --git a/ca/bootstrap.go b/ca/bootstrap.go index 42087985..d0a74ab3 100644 --- a/ca/bootstrap.go +++ b/ca/bootstrap.go @@ -63,6 +63,11 @@ func BootstrapClient(ctx context.Context, token string, options ...TLSOption) (* return nil, err } + version, err := client.Version() + if err != nil { + return nil, err + } + req, pk, err := CreateSignRequest(token) if err != nil { return nil, err @@ -73,8 +78,14 @@ func BootstrapClient(ctx context.Context, token string, options ...TLSOption) (* return nil, err } - // Make sure the tlsConfig have all supported roots on RootCAs - options = append(options, AddRootsToRootCAs()) + // Make sure the tlsConfig have all supported roots on RootCAs. + // + // The roots request is only supported if identity certificates are not + // required. In all cases the current root is also added after applying all + // options too. + if !version.RequireClientAuthentication { + options = append(options, AddRootsToRootCAs()) + } transport, err := client.Transport(ctx, sign, pk, options...) if err != nil { @@ -125,6 +136,11 @@ func BootstrapServer(ctx context.Context, token string, base *http.Server, optio return nil, err } + version, err := client.Version() + if err != nil { + return nil, err + } + req, pk, err := CreateSignRequest(token) if err != nil { return nil, err @@ -135,8 +151,14 @@ func BootstrapServer(ctx context.Context, token string, base *http.Server, optio return nil, err } - // Make sure the tlsConfig have all supported roots on ClientCAs and RootCAs - options = append(options, AddRootsToCAs()) + // Make sure the tlsConfig have all supported roots on RootCAs. + // + // The roots request is only supported if identity certificates are not + // required. In all cases the current root is also added after applying all + // options too. + if !version.RequireClientAuthentication { + options = append(options, AddRootsToCAs()) + } tlsConfig, err := client.GetServerTLSConfig(ctx, sign, pk, options...) if err != nil { @@ -177,6 +199,11 @@ func BootstrapListener(ctx context.Context, token string, inner net.Listener, op return nil, err } + version, err := client.Version() + if err != nil { + return nil, err + } + req, pk, err := CreateSignRequest(token) if err != nil { return nil, err @@ -187,8 +214,14 @@ func BootstrapListener(ctx context.Context, token string, inner net.Listener, op return nil, err } - // Make sure the tlsConfig have all supported roots on ClientCAs and RootCAs - options = append(options, AddRootsToCAs()) + // Make sure the tlsConfig have all supported roots on RootCAs. + // + // The roots request is only supported if identity certificates are not + // required. In all cases the current root is also added after applying all + // options too. + if !version.RequireClientAuthentication { + options = append(options, AddRootsToCAs()) + } tlsConfig, err := client.GetServerTLSConfig(ctx, sign, pk, options...) if err != nil { diff --git a/ca/bootstrap_test.go b/ca/bootstrap_test.go index 7c1bc908..c440f140 100644 --- a/ca/bootstrap_test.go +++ b/ca/bootstrap_test.go @@ -8,6 +8,7 @@ import ( "net/http" "net/http/httptest" "reflect" + "strings" "sync" "testing" "time" @@ -15,6 +16,7 @@ import ( "github.com/pkg/errors" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/authority" + "github.com/smallstep/certificates/errs" "go.step.sm/crypto/jose" "go.step.sm/crypto/randutil" ) @@ -74,6 +76,30 @@ func startCAServer(configFile string) (*CA, string, error) { return ca, caURL, nil } +func mTLSMiddleware(next http.Handler, nonAuthenticatedPaths ...string) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/version" { + api.JSON(w, api.VersionResponse{ + Version: "test", + RequireClientAuthentication: true, + }) + return + } + + for _, s := range nonAuthenticatedPaths { + if strings.HasPrefix(r.URL.Path, s) || strings.HasPrefix(r.URL.Path, "/1.0"+s) { + next.ServeHTTP(w, r) + } + } + isMTLS := r.TLS != nil && len(r.TLS.PeerCertificates) > 0 + if !isMTLS { + api.WriteError(w, errs.Unauthorized("missing peer certificate")) + } else { + next.ServeHTTP(w, r) + } + }) +} + func generateBootstrapToken(ca, subject, sha string) string { now := time.Now() jwk, err := jose.ReadKey("testdata/secrets/ott_mariano_priv.jwk", jose.WithPassword([]byte("password"))) @@ -171,6 +197,15 @@ func TestBootstrapServerWithoutMTLS(t *testing.T) { token := func() string { return generateBootstrapToken(srv.URL, "subject", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7") } + + mtlsServer := startCABootstrapServer() + next := mtlsServer.Config.Handler + mtlsServer.Config.Handler = mTLSMiddleware(next, "/root/", "/sign") + defer mtlsServer.Close() + mtlsToken := func() string { + return generateBootstrapToken(mtlsServer.URL, "subject", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7") + } + type args struct { ctx context.Context token string @@ -182,6 +217,7 @@ func TestBootstrapServerWithoutMTLS(t *testing.T) { wantErr bool }{ {"ok", args{context.Background(), token(), &http.Server{}}, false}, + {"ok mtls", args{context.Background(), mtlsToken(), &http.Server{}}, false}, {"fail", args{context.Background(), "bad-token", &http.Server{}}, true}, {"fail with TLSConfig", args{context.Background(), token(), &http.Server{TLSConfig: &tls.Config{}}}, true}, } @@ -217,6 +253,15 @@ func TestBootstrapServerWithMTLS(t *testing.T) { token := func() string { return generateBootstrapToken(srv.URL, "subject", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7") } + + mtlsServer := startCABootstrapServer() + next := mtlsServer.Config.Handler + mtlsServer.Config.Handler = mTLSMiddleware(next, "/root/", "/sign") + defer mtlsServer.Close() + mtlsToken := func() string { + return generateBootstrapToken(mtlsServer.URL, "subject", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7") + } + type args struct { ctx context.Context token string @@ -228,6 +273,7 @@ func TestBootstrapServerWithMTLS(t *testing.T) { wantErr bool }{ {"ok", args{context.Background(), token(), &http.Server{}}, false}, + {"ok mtls", args{context.Background(), mtlsToken(), &http.Server{}}, false}, {"fail", args{context.Background(), "bad-token", &http.Server{}}, true}, {"fail with TLSConfig", args{context.Background(), token(), &http.Server{TLSConfig: &tls.Config{}}}, true}, } @@ -263,6 +309,15 @@ func TestBootstrapClient(t *testing.T) { token := func() string { return generateBootstrapToken(srv.URL, "subject", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7") } + + mtlsServer := startCABootstrapServer() + next := mtlsServer.Config.Handler + mtlsServer.Config.Handler = mTLSMiddleware(next, "/root/", "/sign") + defer mtlsServer.Close() + mtlsToken := func() string { + return generateBootstrapToken(mtlsServer.URL, "subject", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7") + } + type args struct { ctx context.Context token string @@ -273,6 +328,7 @@ func TestBootstrapClient(t *testing.T) { wantErr bool }{ {"ok", args{context.Background(), token()}, false}, + {"ok mtls", args{context.Background(), mtlsToken()}, false}, {"fail", args{context.Background(), "bad-token"}, true}, } for _, tt := range tests { @@ -541,6 +597,15 @@ func TestBootstrapListener(t *testing.T) { token := func() string { return generateBootstrapToken(srv.URL, "127.0.0.1", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7") } + + mtlsServer := startCABootstrapServer() + next := mtlsServer.Config.Handler + mtlsServer.Config.Handler = mTLSMiddleware(next, "/root/", "/sign") + defer mtlsServer.Close() + mtlsToken := func() string { + return generateBootstrapToken(mtlsServer.URL, "subject", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7") + } + type args struct { token string } @@ -550,6 +615,7 @@ func TestBootstrapListener(t *testing.T) { wantErr bool }{ {"ok", args{token()}, false}, + {"ok mtls", args{mtlsToken()}, false}, {"fail", args{"bad-token"}, true}, } for _, tt := range tests { diff --git a/ca/tls_options.go b/ca/tls_options.go index b3b2d057..c77b70c3 100644 --- a/ca/tls_options.go +++ b/ca/tls_options.go @@ -115,6 +115,7 @@ func AddRootCA(cert *x509.Certificate) TLSOption { if ctx.Config.RootCAs == nil { ctx.Config.RootCAs = x509.NewCertPool() } + ctx.hasRootCA = true ctx.Config.RootCAs.AddCert(cert) ctx.mutableConfig.AddImmutableRootCACert(cert) return nil @@ -129,6 +130,7 @@ func AddClientCA(cert *x509.Certificate) TLSOption { if ctx.Config.ClientCAs == nil { ctx.Config.ClientCAs = x509.NewCertPool() } + ctx.hasClientCA = true ctx.Config.ClientCAs.AddCert(cert) ctx.mutableConfig.AddImmutableClientCACert(cert) return nil