diff --git a/acme/order.go b/acme/order.go index d4a4c300..1fa0809e 100644 --- a/acme/order.go +++ b/acme/order.go @@ -200,6 +200,10 @@ func (o *Order) sans(csr *x509.CertificateRequest) ([]x509util.SubjectAlternativ var sans []x509util.SubjectAlternativeName + if len(csr.EmailAddresses) > 0 || len(csr.URIs) > 0 { + return sans, NewError(ErrorBadCSRType, "Only DNS names and IP addresses are allowed") + } + // order the DNS names and IP addresses, so that they can be compared against the canonicalized CSR orderNames := make([]string, numberOfIdentifierType(DNS, o.Identifiers)) orderIPs := make([]net.IP, numberOfIdentifierType(IP, o.Identifiers)) @@ -279,7 +283,9 @@ func numberOfIdentifierType(typ IdentifierType, ids []Identifier) int { // canonicalize canonicalizes a CSR so that it can be compared against an Order // NOTE: this effectively changes the order of SANs in the CSR, which may be OK, -// but may not be expected. +// but may not be expected. It also adds a Subject Common Name to either the IP +// addresses or DNS names slice, depending on whether it can be parsed as an IP +// or not. This might result in an additional SAN in the final certificate. func canonicalize(csr *x509.CertificateRequest) (canonicalized *x509.CertificateRequest) { // for clarity only; we're operating on the same object by pointer @@ -289,16 +295,20 @@ func canonicalize(csr *x509.CertificateRequest) (canonicalized *x509.Certificate // identifiers as the initial newOrder request. Identifiers of type "dns" // MUST appear either in the commonName portion of the requested subject // name or in an extensionRequest attribute [RFC2985] requesting a - // subjectAltName extension, or both. - // TODO(hs): we might want to check if the CommonName is in fact a DNS (and cannot - // be parsed as IP). This is related to https://github.com/smallstep/cli/pull/576 - // (ACME IP SANS) + // subjectAltName extension, or both. Subject Common Names that can be + // parsed as an IP are included as an IP address for the equality check. + // If these were excluded, a certificate could contain an IP as the + // common name without having been challenged. if csr.Subject.CommonName != "" { - // nolint:gocritic - canonicalized.DNSNames = append(csr.DNSNames, csr.Subject.CommonName) + if ip := net.ParseIP(csr.Subject.CommonName); ip != nil { + canonicalized.IPAddresses = append(canonicalized.IPAddresses, ip) + } else { + canonicalized.DNSNames = append(canonicalized.DNSNames, csr.Subject.CommonName) + } } - canonicalized.DNSNames = uniqueSortedLowerNames(csr.DNSNames) - canonicalized.IPAddresses = uniqueSortedIPs(csr.IPAddresses) + + canonicalized.DNSNames = uniqueSortedLowerNames(canonicalized.DNSNames) + canonicalized.IPAddresses = uniqueSortedIPs(canonicalized.IPAddresses) return canonicalized } @@ -340,7 +350,10 @@ func uniqueSortedIPs(ips []net.IP) (unique []net.IP) { } ipEntryMap := make(map[string]entry, len(ips)) for _, ip := range ips { - ipEntryMap[ip.String()] = entry{ip: ip} + // reparsing the IP results in the IP being represented using 16 bytes + // for both IPv4 as well as IPv6, even when the ips slice contains IPs that + // are represented by 4 bytes. This ensures a fair comparison and thus ordering. + ipEntryMap[ip.String()] = entry{ip: net.ParseIP(ip.String())} } unique = make([]net.IP, 0, len(ipEntryMap)) for _, entry := range ipEntryMap { diff --git a/acme/order_test.go b/acme/order_test.go index a90982a6..493b40b7 100644 --- a/acme/order_test.go +++ b/acme/order_test.go @@ -6,10 +6,12 @@ import ( "crypto/x509/pkix" "encoding/json" "net" + "net/url" "reflect" "testing" "time" + "github.com/google/go-cmp/cmp" "github.com/pkg/errors" "github.com/smallstep/assert" "github.com/smallstep/certificates/authority" @@ -825,71 +827,92 @@ func Test_uniqueSortedIPs(t *testing.T) { ips []net.IP } tests := []struct { - name string - args args - wantUnique []net.IP + name string + args args + want []net.IP }{ { name: "ok/empty", args: args{ ips: []net.IP{}, }, - wantUnique: []net.IP{}, + want: []net.IP{}, }, { name: "ok/single-ipv4", args: args{ ips: []net.IP{net.ParseIP("192.168.42.42")}, }, - wantUnique: []net.IP{net.ParseIP("192.168.42.42")}, + want: []net.IP{net.ParseIP("192.168.42.42")}, }, { name: "ok/multiple-ipv4", args: args{ - ips: []net.IP{net.ParseIP("192.168.42.42"), net.ParseIP("192.168.42.10"), net.ParseIP("192.168.42.1")}, + ips: []net.IP{net.ParseIP("192.168.42.42"), net.ParseIP("192.168.42.10"), net.ParseIP("192.168.42.1"), net.ParseIP("127.0.0.1")}, + }, + want: []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("192.168.42.1"), net.ParseIP("192.168.42.10"), net.ParseIP("192.168.42.42")}, + }, { + name: "ok/multiple-ipv4-with-varying-byte-representations", + args: args{ + ips: []net.IP{net.ParseIP("192.168.42.42"), net.ParseIP("192.168.42.10"), net.ParseIP("192.168.42.1"), []byte{0x7f, 0x0, 0x0, 0x1}}, }, - wantUnique: []net.IP{net.ParseIP("192.168.42.1"), net.ParseIP("192.168.42.10"), net.ParseIP("192.168.42.42")}, + want: []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("192.168.42.1"), net.ParseIP("192.168.42.10"), net.ParseIP("192.168.42.42")}, }, { name: "ok/unique-ipv4", args: args{ ips: []net.IP{net.ParseIP("192.168.42.42"), net.ParseIP("192.168.42.42")}, }, - wantUnique: []net.IP{net.ParseIP("192.168.42.42")}, + want: []net.IP{net.ParseIP("192.168.42.42")}, }, { name: "ok/single-ipv6", args: args{ ips: []net.IP{net.ParseIP("2001:db8::30")}, }, - wantUnique: []net.IP{net.ParseIP("2001:db8::30")}, + want: []net.IP{net.ParseIP("2001:db8::30")}, }, { name: "ok/multiple-ipv6", args: args{ ips: []net.IP{net.ParseIP("2001:db8::30"), net.ParseIP("2001:db8::20"), net.ParseIP("2001:db8::10")}, }, - wantUnique: []net.IP{net.ParseIP("2001:db8::10"), net.ParseIP("2001:db8::20"), net.ParseIP("2001:db8::30")}, + want: []net.IP{net.ParseIP("2001:db8::10"), net.ParseIP("2001:db8::20"), net.ParseIP("2001:db8::30")}, }, { name: "ok/unique-ipv6", args: args{ ips: []net.IP{net.ParseIP("2001:db8::1"), net.ParseIP("2001:db8::1")}, }, - wantUnique: []net.IP{net.ParseIP("2001:db8::1")}, + want: []net.IP{net.ParseIP("2001:db8::1")}, }, { name: "ok/mixed-ipv4-and-ipv6", args: args{ ips: []net.IP{net.ParseIP("2001:db8::1"), net.ParseIP("2001:db8::1"), net.ParseIP("192.168.42.42"), net.ParseIP("192.168.42.42")}, }, - wantUnique: []net.IP{net.ParseIP("192.168.42.42"), net.ParseIP("2001:db8::1")}, + want: []net.IP{net.ParseIP("192.168.42.42"), net.ParseIP("2001:db8::1")}, + }, + { + name: "ok/mixed-ipv4-and-ipv6-and-varying-byte-representations", + args: args{ + ips: []net.IP{net.ParseIP("2001:db8::1"), net.ParseIP("2001:db8::1"), net.ParseIP("192.168.42.42"), net.ParseIP("192.168.42.42"), []byte{0x7f, 0x0, 0x0, 0x1}}, + }, + want: []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("192.168.42.42"), net.ParseIP("2001:db8::1")}, + }, + { + name: "ok/mixed-ipv4-and-ipv6-and-more-varying-byte-representations", + args: args{ + ips: []net.IP{net.ParseIP("2001:db8::1"), net.ParseIP("2001:db8::1"), net.ParseIP("192.168.42.42"), net.ParseIP("2001:db8::2"), net.ParseIP("192.168.42.42"), []byte{0x7f, 0x0, 0x0, 0x1}, []byte{0x7f, 0x0, 0x0, 0x1}, []byte{0x7f, 0x0, 0x0, 0x2}}, + }, + want: []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("127.0.0.2"), net.ParseIP("192.168.42.42"), net.ParseIP("2001:db8::1"), net.ParseIP("2001:db8::2")}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if gotUnique := uniqueSortedIPs(tt.args.ips); !reflect.DeepEqual(gotUnique, tt.wantUnique) { - t.Errorf("uniqueSortedIPs() = %v, want %v", gotUnique, tt.wantUnique) + got := uniqueSortedIPs(tt.args.ips) + if !cmp.Equal(tt.want, got) { + t.Errorf("uniqueSortedIPs() diff =\n%s", cmp.Diff(tt.want, got)) } }) } @@ -1122,9 +1145,9 @@ func Test_canonicalize(t *testing.T) { csr *x509.CertificateRequest } tests := []struct { - name string - args args - wantCanonicalized *x509.CertificateRequest + name string + args args + want *x509.CertificateRequest }{ { name: "ok/dns", @@ -1133,7 +1156,7 @@ func Test_canonicalize(t *testing.T) { DNSNames: []string{"www.example.com", "example.com"}, }, }, - wantCanonicalized: &x509.CertificateRequest{ + want: &x509.CertificateRequest{ DNSNames: []string{"example.com", "www.example.com"}, IPAddresses: []net.IP{}, }, @@ -1148,7 +1171,7 @@ func Test_canonicalize(t *testing.T) { DNSNames: []string{"www.example.com"}, }, }, - wantCanonicalized: &x509.CertificateRequest{ + want: &x509.CertificateRequest{ Subject: pkix.Name{ CommonName: "example.com", }, @@ -1163,7 +1186,7 @@ func Test_canonicalize(t *testing.T) { IPAddresses: []net.IP{net.ParseIP("192.168.43.42"), net.ParseIP("192.168.42.42")}, }, }, - wantCanonicalized: &x509.CertificateRequest{ + want: &x509.CertificateRequest{ DNSNames: []string{}, IPAddresses: []net.IP{net.ParseIP("192.168.42.42"), net.ParseIP("192.168.43.42")}, }, @@ -1176,7 +1199,7 @@ func Test_canonicalize(t *testing.T) { IPAddresses: []net.IP{net.ParseIP("192.168.43.42"), net.ParseIP("192.168.42.42")}, }, }, - wantCanonicalized: &x509.CertificateRequest{ + want: &x509.CertificateRequest{ DNSNames: []string{"example.com", "www.example.com"}, IPAddresses: []net.IP{net.ParseIP("192.168.42.42"), net.ParseIP("192.168.43.42")}, }, @@ -1192,7 +1215,7 @@ func Test_canonicalize(t *testing.T) { IPAddresses: []net.IP{net.ParseIP("192.168.43.42"), net.ParseIP("192.168.42.42")}, }, }, - wantCanonicalized: &x509.CertificateRequest{ + want: &x509.CertificateRequest{ Subject: pkix.Name{ CommonName: "example.com", }, @@ -1200,11 +1223,31 @@ func Test_canonicalize(t *testing.T) { IPAddresses: []net.IP{net.ParseIP("192.168.42.42"), net.ParseIP("192.168.43.42")}, }, }, + { + name: "ok/ip-common-name", + args: args{ + csr: &x509.CertificateRequest{ + Subject: pkix.Name{ + CommonName: "127.0.0.1", + }, + DNSNames: []string{"example.com"}, + IPAddresses: []net.IP{net.ParseIP("192.168.43.42"), net.ParseIP("192.168.42.42")}, + }, + }, + want: &x509.CertificateRequest{ + Subject: pkix.Name{ + CommonName: "127.0.0.1", + }, + DNSNames: []string{"example.com"}, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("192.168.42.42"), net.ParseIP("192.168.43.42")}, + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if gotCanonicalized := canonicalize(tt.args.csr); !reflect.DeepEqual(gotCanonicalized, tt.wantCanonicalized) { - t.Errorf("canonicalize() = %v, want %v", gotCanonicalized, tt.wantCanonicalized) + got := canonicalize(tt.args.csr) + if !cmp.Equal(tt.want, got) { + t.Errorf("canonicalize() diff =\n%s", cmp.Diff(tt.want, got)) } }) } @@ -1238,6 +1281,39 @@ func TestOrder_sans(t *testing.T) { }, err: nil, }, + { + name: "fail/invalid-alternative-name-email", + fields: fields{ + Identifiers: []Identifier{}, + }, + csr: &x509.CertificateRequest{ + Subject: pkix.Name{ + CommonName: "foo.internal", + }, + EmailAddresses: []string{"test@example.com"}, + }, + want: []x509util.SubjectAlternativeName{}, + err: NewError(ErrorBadCSRType, "Only DNS names and IP addresses are allowed"), + }, + { + name: "fail/invalid-alternative-name-uri", + fields: fields{ + Identifiers: []Identifier{}, + }, + csr: &x509.CertificateRequest{ + Subject: pkix.Name{ + CommonName: "foo.internal", + }, + URIs: []*url.URL{ + { + Scheme: "https://", + Host: "smallstep.com", + }, + }, + }, + want: []x509util.SubjectAlternativeName{}, + err: NewError(ErrorBadCSRType, "Only DNS names and IP addresses are allowed"), + }, { name: "fail/error-names-length-mismatch", fields: fields{ diff --git a/ca/bootstrap.go b/ca/bootstrap.go index 42087985..0e0f0fe3 100644 --- a/ca/bootstrap.go +++ b/ca/bootstrap.go @@ -2,12 +2,14 @@ package ca import ( "context" + "crypto" "crypto/tls" "net" "net/http" "strings" "github.com/pkg/errors" + "github.com/smallstep/certificates/api" "go.step.sm/crypto/jose" ) @@ -58,25 +60,21 @@ func Bootstrap(token string) (*Client, error) { // } // resp, err := client.Get("https://internal.smallstep.com") func BootstrapClient(ctx context.Context, token string, options ...TLSOption) (*http.Client, error) { - client, err := Bootstrap(token) + b, err := createBootstrap(token) if err != nil { return nil, err } - req, pk, err := CreateSignRequest(token) - if err != nil { - return nil, err + // Make sure the tlsConfig has all supported roots on RootCAs. + // + // The roots request is only supported if identity certificates are not + // required. In all cases the current root is also added after applying all + // options too. + if !b.RequireClientAuth { + options = append(options, AddRootsToRootCAs()) } - sign, err := client.Sign(req) - if err != nil { - return nil, err - } - - // Make sure the tlsConfig have all supported roots on RootCAs - options = append(options, AddRootsToRootCAs()) - - transport, err := client.Transport(ctx, sign, pk, options...) + transport, err := b.Client.Transport(ctx, b.SignResponse, b.PrivateKey, options...) if err != nil { return nil, err } @@ -120,25 +118,21 @@ func BootstrapServer(ctx context.Context, token string, base *http.Server, optio return nil, errors.New("server TLSConfig is already set") } - client, err := Bootstrap(token) + b, err := createBootstrap(token) if err != nil { return nil, err } - req, pk, err := CreateSignRequest(token) - if err != nil { - return nil, err - } - - sign, err := client.Sign(req) - if err != nil { - return nil, err + // Make sure the tlsConfig has all supported roots on RootCAs. + // + // The roots request is only supported if identity certificates are not + // required. In all cases the current root is also added after applying all + // options too. + if !b.RequireClientAuth { + options = append(options, AddRootsToCAs()) } - // Make sure the tlsConfig have all supported roots on ClientCAs and RootCAs - options = append(options, AddRootsToCAs()) - - tlsConfig, err := client.GetServerTLSConfig(ctx, sign, pk, options...) + tlsConfig, err := b.Client.GetServerTLSConfig(ctx, b.SignResponse, b.PrivateKey, options...) if err != nil { return nil, err } @@ -172,28 +166,60 @@ func BootstrapServer(ctx context.Context, token string, base *http.Server, optio // ... // register services // srv.Serve(lis) func BootstrapListener(ctx context.Context, token string, inner net.Listener, options ...TLSOption) (net.Listener, error) { - client, err := Bootstrap(token) + b, err := createBootstrap(token) if err != nil { return nil, err } - req, pk, err := CreateSignRequest(token) + // Make sure the tlsConfig has all supported roots on RootCAs. + // + // The roots request is only supported if identity certificates are not + // required. In all cases the current root is also added after applying all + // options too. + if !b.RequireClientAuth { + options = append(options, AddRootsToCAs()) + } + + tlsConfig, err := b.Client.GetServerTLSConfig(ctx, b.SignResponse, b.PrivateKey, options...) if err != nil { return nil, err } - sign, err := client.Sign(req) + return tls.NewListener(inner, tlsConfig), nil +} + +type bootstrap struct { + Client *Client + RequireClientAuth bool + SignResponse *api.SignResponse + PrivateKey crypto.PrivateKey +} + +func createBootstrap(token string) (*bootstrap, error) { + client, err := Bootstrap(token) if err != nil { return nil, err } - // Make sure the tlsConfig have all supported roots on ClientCAs and RootCAs - options = append(options, AddRootsToCAs()) + version, err := client.Version() + if err != nil { + return nil, err + } - tlsConfig, err := client.GetServerTLSConfig(ctx, sign, pk, options...) + req, pk, err := CreateSignRequest(token) if err != nil { return nil, err } - return tls.NewListener(inner, tlsConfig), nil + sign, err := client.Sign(req) + if err != nil { + return nil, err + } + + return &bootstrap{ + Client: client, + RequireClientAuth: version.RequireClientAuthentication, + SignResponse: sign, + PrivateKey: pk, + }, nil } diff --git a/ca/bootstrap_test.go b/ca/bootstrap_test.go index 7c1bc908..9482d657 100644 --- a/ca/bootstrap_test.go +++ b/ca/bootstrap_test.go @@ -8,6 +8,7 @@ import ( "net/http" "net/http/httptest" "reflect" + "strings" "sync" "testing" "time" @@ -15,6 +16,7 @@ import ( "github.com/pkg/errors" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/authority" + "github.com/smallstep/certificates/errs" "go.step.sm/crypto/jose" "go.step.sm/crypto/randutil" ) @@ -74,6 +76,30 @@ func startCAServer(configFile string) (*CA, string, error) { return ca, caURL, nil } +func mTLSMiddleware(next http.Handler, nonAuthenticatedPaths ...string) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/version" { + api.JSON(w, api.VersionResponse{ + Version: "test", + RequireClientAuthentication: true, + }) + return + } + + for _, s := range nonAuthenticatedPaths { + if strings.HasPrefix(r.URL.Path, s) || strings.HasPrefix(r.URL.Path, "/1.0"+s) { + next.ServeHTTP(w, r) + } + } + isMTLS := r.TLS != nil && len(r.TLS.PeerCertificates) > 0 + if !isMTLS { + api.WriteError(w, errs.Unauthorized("missing peer certificate")) + } else { + next.ServeHTTP(w, r) + } + }) +} + func generateBootstrapToken(ca, subject, sha string) string { now := time.Now() jwk, err := jose.ReadKey("testdata/secrets/ott_mariano_priv.jwk", jose.WithPassword([]byte("password"))) @@ -171,6 +197,15 @@ func TestBootstrapServerWithoutMTLS(t *testing.T) { token := func() string { return generateBootstrapToken(srv.URL, "subject", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7") } + + mtlsServer := startCABootstrapServer() + next := mtlsServer.Config.Handler + mtlsServer.Config.Handler = mTLSMiddleware(next, "/root/", "/sign") + defer mtlsServer.Close() + mtlsToken := func() string { + return generateBootstrapToken(mtlsServer.URL, "subject", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7") + } + type args struct { ctx context.Context token string @@ -182,6 +217,7 @@ func TestBootstrapServerWithoutMTLS(t *testing.T) { wantErr bool }{ {"ok", args{context.Background(), token(), &http.Server{}}, false}, + {"ok mtls", args{context.Background(), mtlsToken(), &http.Server{}}, false}, {"fail", args{context.Background(), "bad-token", &http.Server{}}, true}, {"fail with TLSConfig", args{context.Background(), token(), &http.Server{TLSConfig: &tls.Config{}}}, true}, } @@ -217,6 +253,15 @@ func TestBootstrapServerWithMTLS(t *testing.T) { token := func() string { return generateBootstrapToken(srv.URL, "subject", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7") } + + mtlsServer := startCABootstrapServer() + next := mtlsServer.Config.Handler + mtlsServer.Config.Handler = mTLSMiddleware(next, "/root/", "/sign") + defer mtlsServer.Close() + mtlsToken := func() string { + return generateBootstrapToken(mtlsServer.URL, "subject", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7") + } + type args struct { ctx context.Context token string @@ -228,6 +273,7 @@ func TestBootstrapServerWithMTLS(t *testing.T) { wantErr bool }{ {"ok", args{context.Background(), token(), &http.Server{}}, false}, + {"ok mtls", args{context.Background(), mtlsToken(), &http.Server{}}, false}, {"fail", args{context.Background(), "bad-token", &http.Server{}}, true}, {"fail with TLSConfig", args{context.Background(), token(), &http.Server{TLSConfig: &tls.Config{}}}, true}, } @@ -263,6 +309,15 @@ func TestBootstrapClient(t *testing.T) { token := func() string { return generateBootstrapToken(srv.URL, "subject", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7") } + + mtlsServer := startCABootstrapServer() + next := mtlsServer.Config.Handler + mtlsServer.Config.Handler = mTLSMiddleware(next, "/root/", "/sign") + defer mtlsServer.Close() + mtlsToken := func() string { + return generateBootstrapToken(mtlsServer.URL, "subject", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7") + } + type args struct { ctx context.Context token string @@ -273,6 +328,7 @@ func TestBootstrapClient(t *testing.T) { wantErr bool }{ {"ok", args{context.Background(), token()}, false}, + {"ok mtls", args{context.Background(), mtlsToken()}, false}, {"fail", args{context.Background(), "bad-token"}, true}, } for _, tt := range tests { @@ -541,6 +597,15 @@ func TestBootstrapListener(t *testing.T) { token := func() string { return generateBootstrapToken(srv.URL, "127.0.0.1", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7") } + + mtlsServer := startCABootstrapServer() + next := mtlsServer.Config.Handler + mtlsServer.Config.Handler = mTLSMiddleware(next, "/root/", "/sign") + defer mtlsServer.Close() + mtlsToken := func() string { + return generateBootstrapToken(mtlsServer.URL, "127.0.0.1", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7") + } + type args struct { token string } @@ -550,6 +615,7 @@ func TestBootstrapListener(t *testing.T) { wantErr bool }{ {"ok", args{token()}, false}, + {"ok mtls", args{mtlsToken()}, false}, {"fail", args{"bad-token"}, true}, } for _, tt := range tests { diff --git a/ca/tls_options.go b/ca/tls_options.go index b3b2d057..c77b70c3 100644 --- a/ca/tls_options.go +++ b/ca/tls_options.go @@ -115,6 +115,7 @@ func AddRootCA(cert *x509.Certificate) TLSOption { if ctx.Config.RootCAs == nil { ctx.Config.RootCAs = x509.NewCertPool() } + ctx.hasRootCA = true ctx.Config.RootCAs.AddCert(cert) ctx.mutableConfig.AddImmutableRootCACert(cert) return nil @@ -129,6 +130,7 @@ func AddClientCA(cert *x509.Certificate) TLSOption { if ctx.Config.ClientCAs == nil { ctx.Config.ClientCAs = x509.NewCertPool() } + ctx.hasClientCA = true ctx.Config.ClientCAs.AddCert(cert) ctx.mutableConfig.AddImmutableClientCACert(cert) return nil diff --git a/kms/pkcs11/pkcs11.go b/kms/pkcs11/pkcs11.go index cec05d33..c0e06408 100644 --- a/kms/pkcs11/pkcs11.go +++ b/kms/pkcs11/pkcs11.go @@ -7,6 +7,7 @@ import ( "context" "crypto" "crypto/elliptic" + "crypto/rsa" "crypto/x509" "encoding/hex" "fmt" @@ -142,8 +143,7 @@ func (k *PKCS11) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyRespons }, nil } -// CreateSigner creates a signer using the key present in the PKCS#11 MODULE signature -// slot. +// CreateSigner creates a signer using a key present in the PKCS#11 module. func (k *PKCS11) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer, error) { if req.SigningKey == "" { return nil, errors.New("createSignerRequest 'signingKey' cannot be empty") @@ -157,6 +157,27 @@ func (k *PKCS11) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer, er return signer, nil } +// CreateDecrypter creates a decrypter using a key present in the PKCS#11 +// module. +func (k *PKCS11) CreateDecrypter(req *apiv1.CreateDecrypterRequest) (crypto.Decrypter, error) { + if req.DecryptionKey == "" { + return nil, errors.New("createDecrypterRequest 'decryptionKey' cannot be empty") + } + + signer, err := findSigner(k.p11, req.DecryptionKey) + if err != nil { + return nil, errors.Wrap(err, "createDecrypterRequest failed") + } + + // Only RSA keys will implement the Decrypter interface. + if _, ok := signer.Public().(*rsa.PublicKey); ok { + if dec, ok := signer.(crypto.Decrypter); ok { + return dec, nil + } + } + return nil, errors.New("createDecrypterRequest failed: signer does not implement crypto.Decrypter") +} + // LoadCertificate implements kms.CertificateManager and loads a certificate // from the YubiKey. func (k *PKCS11) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Certificate, error) { diff --git a/kms/pkcs11/pkcs11_test.go b/kms/pkcs11/pkcs11_test.go index 409cfb3f..06edd048 100644 --- a/kms/pkcs11/pkcs11_test.go +++ b/kms/pkcs11/pkcs11_test.go @@ -4,6 +4,7 @@ package pkcs11 import ( + "bytes" "context" "crypto" "crypto/ecdsa" @@ -491,6 +492,86 @@ func TestPKCS11_CreateSigner(t *testing.T) { } } +func TestPKCS11_CreateDecrypter(t *testing.T) { + k := setupPKCS11(t) + data := []byte("buggy-coheir-RUBRIC-rabbet-liberal-eaglet-khartoum-stagger") + + type args struct { + req *apiv1.CreateDecrypterRequest + } + tests := []struct { + name string + args args + wantErr bool + }{ + {"RSA", args{&apiv1.CreateDecrypterRequest{ + DecryptionKey: "pkcs11:id=7371;object=rsa-key", + }}, false}, + {"RSA PSS", args{&apiv1.CreateDecrypterRequest{ + DecryptionKey: "pkcs11:id=7372;object=rsa-pss-key", + }}, false}, + {"ECDSA P256", args{&apiv1.CreateDecrypterRequest{ + DecryptionKey: "pkcs11:id=7373;object=ecdsa-p256-key", + }}, true}, + {"ECDSA P384", args{&apiv1.CreateDecrypterRequest{ + DecryptionKey: "pkcs11:id=7374;object=ecdsa-p384-key", + }}, true}, + {"ECDSA P521", args{&apiv1.CreateDecrypterRequest{ + DecryptionKey: "pkcs11:id=7375;object=ecdsa-p521-key", + }}, true}, + {"fail DecryptionKey", args{&apiv1.CreateDecrypterRequest{ + DecryptionKey: "", + }}, true}, + {"fail uri", args{&apiv1.CreateDecrypterRequest{ + DecryptionKey: "https:id=7375;object=ecdsa-p521-key", + }}, true}, + {"fail FindKeyPair", args{&apiv1.CreateDecrypterRequest{ + DecryptionKey: "pkcs11:foo=bar", + }}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := k.CreateDecrypter(tt.args.req) + if (err != nil) != tt.wantErr { + t.Errorf("PKCS11.CreateDecrypter() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if got != nil { + pub := got.Public().(*rsa.PublicKey) + // PKCS#1 v1.5 + enc, err := rsa.EncryptPKCS1v15(rand.Reader, pub, data) + if err != nil { + t.Errorf("rsa.EncryptPKCS1v15() error = %v", err) + return + } + dec, err := got.Decrypt(rand.Reader, enc, nil) + if err != nil { + t.Errorf("PKCS1v15.Decrypt() error = %v", err) + } else if !bytes.Equal(dec, data) { + t.Errorf("PKCS1v15.Decrypt() failed got = %s, want = %s", dec, data) + } + + // RSA-OAEP + enc, err = rsa.EncryptOAEP(crypto.SHA256.New(), rand.Reader, pub, data, []byte("label")) + if err != nil { + t.Errorf("rsa.EncryptOAEP() error = %v", err) + return + } + dec, err = got.Decrypt(rand.Reader, enc, &rsa.OAEPOptions{ + Hash: crypto.SHA256, + Label: []byte("label"), + }) + if err != nil { + t.Errorf("RSA-OAEP.Decrypt() error = %v", err) + } else if !bytes.Equal(dec, data) { + t.Errorf("RSA-OAEP.Decrypt() RSA-OAEP failed got = %s, want = %s", dec, data) + } + } + }) + } +} + func TestPKCS11_LoadCertificate(t *testing.T) { k := setupPKCS11(t)