diff --git a/cas/apiv1/services.go b/cas/apiv1/services.go index d4dd3c8c..cf9a5470 100644 --- a/cas/apiv1/services.go +++ b/cas/apiv1/services.go @@ -1,6 +1,7 @@ package apiv1 import ( + "crypto/x509" "net/http" "strings" ) @@ -26,6 +27,12 @@ type CertificateAuthorityCreator interface { CreateCertificateAuthority(req *CreateCertificateAuthorityRequest) (*CreateCertificateAuthorityResponse, error) } +// SignatureAlgorithmGetter is an optional implementation in a crypto.Signer +// that returns the SignatureAlgorithm to use. +type SignatureAlgorithmGetter interface { + SignatureAlgorithm() x509.SignatureAlgorithm +} + // Type represents the CAS type used. type Type string diff --git a/cas/softcas/softcas.go b/cas/softcas/softcas.go index 21760490..f3b2d051 100644 --- a/cas/softcas/softcas.go +++ b/cas/softcas/softcas.go @@ -68,7 +68,7 @@ func (c *SoftCAS) CreateCertificate(req *apiv1.CreateCertificateRequest) (*apiv1 } req.Template.Issuer = c.CertificateChain[0].Subject - cert, err := x509util.CreateCertificate(req.Template, c.CertificateChain[0], req.Template.PublicKey, c.Signer) + cert, err := createCertificate(req.Template, c.CertificateChain[0], req.Template.PublicKey, c.Signer) if err != nil { return nil, err } @@ -93,7 +93,7 @@ func (c *SoftCAS) RenewCertificate(req *apiv1.RenewCertificateRequest) (*apiv1.R req.Template.NotAfter = t.Add(req.Lifetime) req.Template.Issuer = c.CertificateChain[0].Subject - cert, err := x509util.CreateCertificate(req.Template, c.CertificateChain[0], req.Template.PublicKey, c.Signer) + cert, err := createCertificate(req.Template, c.CertificateChain[0], req.Template.PublicKey, c.Signer) if err != nil { return nil, err } @@ -150,12 +150,12 @@ func (c *SoftCAS) CreateCertificateAuthority(req *apiv1.CreateCertificateAuthori var cert *x509.Certificate switch req.Type { case apiv1.RootCA: - cert, err = x509util.CreateCertificate(req.Template, req.Template, signer.Public(), signer) + cert, err = createCertificate(req.Template, req.Template, signer.Public(), signer) if err != nil { return nil, err } case apiv1.IntermediateCA: - cert, err = x509util.CreateCertificate(req.Template, req.Parent.Certificate, signer.Public(), req.Parent.Signer) + cert, err = createCertificate(req.Template, req.Parent.Certificate, signer.Public(), req.Parent.Signer) if err != nil { return nil, err } @@ -210,3 +210,16 @@ func (c *SoftCAS) createSigner(req *kmsapi.CreateSignerRequest) (crypto.Signer, } return c.KeyManager.CreateSigner(req) } + +// createCertificate sets the SignatureAlgorithm of the template if necessary +// and calls x509util.CreateCertificate. +func createCertificate(template, parent *x509.Certificate, pub crypto.PublicKey, signer crypto.Signer) (*x509.Certificate, error) { + // Signers can specify the signature algorithm. This is specially important + // when x509.CreateCertificates attempts to validate a RSAPSS signature. + if template.SignatureAlgorithm == 0 { + if sa, ok := signer.(apiv1.SignatureAlgorithmGetter); ok { + template.SignatureAlgorithm = sa.SignatureAlgorithm() + } + } + return x509util.CreateCertificate(template, parent, pub, signer) +} diff --git a/cas/softcas/softcas_test.go b/cas/softcas/softcas_test.go index 092a0337..c8e1a8e9 100644 --- a/cas/softcas/softcas_test.go +++ b/cas/softcas/softcas_test.go @@ -75,6 +75,15 @@ var ( testSignedIntermediateTemplate = mustSign(testIntermediateTemplate, testSignedRootTemplate, testNow, testNow.Add(24*time.Hour)) ) +type signatureAlgorithmSigner struct { + crypto.Signer + algorithm x509.SignatureAlgorithm +} + +func (s *signatureAlgorithmSigner) SignatureAlgorithm() x509.SignatureAlgorithm { + return s.algorithm +} + type mockKeyManager struct { signer crypto.Signer errGetPublicKey error @@ -247,6 +256,13 @@ func TestSoftCAS_CreateCertificate(t *testing.T) { tmplNoSerial := *testTemplate tmplNoSerial.SerialNumber = nil + saTemplate := *testSignedTemplate + saTemplate.SignatureAlgorithm = 0 + saSigner := &signatureAlgorithmSigner{ + Signer: testSigner, + algorithm: x509.PureEd25519, + } + type fields struct { Issuer *x509.Certificate Signer crypto.Signer @@ -267,6 +283,12 @@ func TestSoftCAS_CreateCertificate(t *testing.T) { Certificate: testSignedTemplate, CertificateChain: []*x509.Certificate{testIssuer}, }, false}, + {"ok signature algorithm", fields{testIssuer, saSigner}, args{&apiv1.CreateCertificateRequest{ + Template: &saTemplate, Lifetime: 24 * time.Hour, + }}, &apiv1.CreateCertificateResponse{ + Certificate: testSignedTemplate, + CertificateChain: []*x509.Certificate{testIssuer}, + }, false}, {"ok with notBefore", fields{testIssuer, testSigner}, args{&apiv1.CreateCertificateRequest{ Template: &tmplNotBefore, Lifetime: 24 * time.Hour, }}, &apiv1.CreateCertificateResponse{ @@ -316,6 +338,11 @@ func TestSoftCAS_RenewCertificate(t *testing.T) { tmplNoSerial := *testTemplate tmplNoSerial.SerialNumber = nil + saSigner := &signatureAlgorithmSigner{ + Signer: testSigner, + algorithm: x509.PureEd25519, + } + type fields struct { Issuer *x509.Certificate Signer crypto.Signer @@ -336,6 +363,12 @@ func TestSoftCAS_RenewCertificate(t *testing.T) { Certificate: testSignedTemplate, CertificateChain: []*x509.Certificate{testIssuer}, }, false}, + {"ok signature algorithm", fields{testIssuer, saSigner}, args{&apiv1.RenewCertificateRequest{ + Template: testTemplate, Lifetime: 24 * time.Hour, + }}, &apiv1.RenewCertificateResponse{ + Certificate: testSignedTemplate, + CertificateChain: []*x509.Certificate{testIssuer}, + }, false}, {"fail template", fields{testIssuer, testSigner}, args{&apiv1.RenewCertificateRequest{Lifetime: 24 * time.Hour}}, nil, true}, {"fail lifetime", fields{testIssuer, testSigner}, args{&apiv1.RenewCertificateRequest{Template: testTemplate}}, nil, true}, {"fail CreateCertificate", fields{testIssuer, testSigner}, args{&apiv1.RenewCertificateRequest{ @@ -425,6 +458,11 @@ func Test_now(t *testing.T) { func TestSoftCAS_CreateCertificateAuthority(t *testing.T) { mockNow(t) + saSigner := &signatureAlgorithmSigner{ + Signer: testSigner, + algorithm: x509.PureEd25519, + } + type fields struct { Issuer *x509.Certificate Signer crypto.Signer @@ -467,6 +505,17 @@ func TestSoftCAS_CreateCertificateAuthority(t *testing.T) { PrivateKey: testSigner, Signer: testSigner, }, false}, + {"ok signature algorithm", fields{nil, nil, &mockKeyManager{signer: saSigner}}, args{&apiv1.CreateCertificateAuthorityRequest{ + Type: apiv1.RootCA, + Template: testRootTemplate, + Lifetime: 24 * time.Hour, + }}, &apiv1.CreateCertificateAuthorityResponse{ + Name: "Test Root CA", + Certificate: testSignedRootTemplate, + PublicKey: testSignedRootTemplate.PublicKey, + PrivateKey: saSigner, + Signer: saSigner, + }, false}, {"fail template", fields{nil, nil, &mockKeyManager{}}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.RootCA, Lifetime: 24 * time.Hour, diff --git a/kms/cloudkms/cloudkms.go b/kms/cloudkms/cloudkms.go index f4c656d3..65d06048 100644 --- a/kms/cloudkms/cloudkms.go +++ b/kms/cloudkms/cloudkms.go @@ -3,6 +3,7 @@ package cloudkms import ( "context" "crypto" + "crypto/x509" "log" "strings" "time" @@ -63,6 +64,19 @@ var signatureAlgorithmMapping = map[apiv1.SignatureAlgorithm]interface{}{ apiv1.ECDSAWithSHA384: kmspb.CryptoKeyVersion_EC_SIGN_P384_SHA384, } +var cryptoKeyVersionMapping = map[kmspb.CryptoKeyVersion_CryptoKeyVersionAlgorithm]x509.SignatureAlgorithm{ + kmspb.CryptoKeyVersion_EC_SIGN_P256_SHA256: x509.ECDSAWithSHA256, + kmspb.CryptoKeyVersion_EC_SIGN_P384_SHA384: x509.ECDSAWithSHA384, + kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_2048_SHA256: x509.SHA256WithRSA, + kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_3072_SHA256: x509.SHA256WithRSA, + kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_4096_SHA256: x509.SHA256WithRSA, + kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_4096_SHA512: x509.SHA512WithRSA, + kmspb.CryptoKeyVersion_RSA_SIGN_PSS_2048_SHA256: x509.SHA256WithRSAPSS, + kmspb.CryptoKeyVersion_RSA_SIGN_PSS_3072_SHA256: x509.SHA256WithRSAPSS, + kmspb.CryptoKeyVersion_RSA_SIGN_PSS_4096_SHA256: x509.SHA256WithRSAPSS, + kmspb.CryptoKeyVersion_RSA_SIGN_PSS_4096_SHA512: x509.SHA512WithRSAPSS, +} + // KeyManagementClient defines the methods on KeyManagementClient that this // package will use. This interface will be used for unit testing. type KeyManagementClient interface { diff --git a/kms/cloudkms/signer.go b/kms/cloudkms/signer.go index 686aca25..5a5443cf 100644 --- a/kms/cloudkms/signer.go +++ b/kms/cloudkms/signer.go @@ -2,6 +2,7 @@ package cloudkms import ( "crypto" + "crypto/x509" "io" "github.com/pkg/errors" @@ -13,6 +14,7 @@ import ( type Signer struct { client KeyManagementClient signingKey string + algorithm x509.SignatureAlgorithm publicKey crypto.PublicKey } @@ -40,7 +42,7 @@ func (s *Signer) preloadKey(signingKey string) error { if err != nil { return errors.Wrap(err, "cloudKMS GetPublicKey failed") } - + s.algorithm = cryptoKeyVersionMapping[response.Algorithm] s.publicKey, err = pemutil.ParseKey([]byte(response.Pem)) return err } @@ -84,3 +86,10 @@ func (s *Signer) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([] return response.Signature, nil } + +// SignatureAlgorithm returns the algorithm that must be specified in a +// certificate to sign. This is specially important to distinguish RSA and +// RSAPSS schemas. +func (s *Signer) SignatureAlgorithm() x509.SignatureAlgorithm { + return s.algorithm +} diff --git a/kms/cloudkms/signer_test.go b/kms/cloudkms/signer_test.go index fa730fe3..a8f964f1 100644 --- a/kms/cloudkms/signer_test.go +++ b/kms/cloudkms/signer_test.go @@ -4,6 +4,7 @@ import ( "context" "crypto" "crypto/rand" + "crypto/x509" "fmt" "io" "io/ioutil" @@ -156,3 +157,79 @@ func Test_signer_Sign(t *testing.T) { }) } } + +func TestSigner_SignatureAlgorithm(t *testing.T) { + pemBytes, err := ioutil.ReadFile("testdata/pub.pem") + if err != nil { + t.Fatal(err) + } + + client := &MockClient{ + getPublicKey: func(_ context.Context, req *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) { + var algorithm kmspb.CryptoKeyVersion_CryptoKeyVersionAlgorithm + switch req.Name { + case "ECDSA-SHA256": + algorithm = kmspb.CryptoKeyVersion_EC_SIGN_P256_SHA256 + case "ECDSA-SHA384": + algorithm = kmspb.CryptoKeyVersion_EC_SIGN_P384_SHA384 + case "SHA256-RSA-2048": + algorithm = kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_2048_SHA256 + case "SHA256-RSA-3072": + algorithm = kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_3072_SHA256 + case "SHA256-RSA-4096": + algorithm = kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_4096_SHA256 + case "SHA512-RSA-4096": + algorithm = kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_4096_SHA512 + case "SHA256-RSAPSS-2048": + algorithm = kmspb.CryptoKeyVersion_RSA_SIGN_PSS_2048_SHA256 + case "SHA256-RSAPSS-3072": + algorithm = kmspb.CryptoKeyVersion_RSA_SIGN_PSS_3072_SHA256 + case "SHA256-RSAPSS-4096": + algorithm = kmspb.CryptoKeyVersion_RSA_SIGN_PSS_4096_SHA256 + case "SHA512-RSAPSS-4096": + algorithm = kmspb.CryptoKeyVersion_RSA_SIGN_PSS_4096_SHA512 + } + return &kmspb.PublicKey{ + Pem: string(pemBytes), + Algorithm: algorithm, + }, nil + }, + } + + if err != nil { + t.Fatal(err) + } + + type fields struct { + client KeyManagementClient + signingKey string + } + tests := []struct { + name string + fields fields + want x509.SignatureAlgorithm + }{ + {"ECDSA-SHA256", fields{client, "ECDSA-SHA256"}, x509.ECDSAWithSHA256}, + {"ECDSA-SHA384", fields{client, "ECDSA-SHA384"}, x509.ECDSAWithSHA384}, + {"SHA256-RSA-2048", fields{client, "SHA256-RSA-2048"}, x509.SHA256WithRSA}, + {"SHA256-RSA-3072", fields{client, "SHA256-RSA-3072"}, x509.SHA256WithRSA}, + {"SHA256-RSA-4096", fields{client, "SHA256-RSA-4096"}, x509.SHA256WithRSA}, + {"SHA512-RSA-4096", fields{client, "SHA512-RSA-4096"}, x509.SHA512WithRSA}, + {"SHA256-RSAPSS-2048", fields{client, "SHA256-RSAPSS-2048"}, x509.SHA256WithRSAPSS}, + {"SHA256-RSAPSS-3072", fields{client, "SHA256-RSAPSS-3072"}, x509.SHA256WithRSAPSS}, + {"SHA256-RSAPSS-4096", fields{client, "SHA256-RSAPSS-4096"}, x509.SHA256WithRSAPSS}, + {"SHA512-RSAPSS-4096", fields{client, "SHA512-RSAPSS-4096"}, x509.SHA512WithRSAPSS}, + {"unknown", fields{client, "UNKNOWN"}, x509.UnknownSignatureAlgorithm}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + signer, err := NewSigner(tt.fields.client, tt.fields.signingKey) + if err != nil { + t.Errorf("NewSigner() error = %v", err) + } + if got := signer.SignatureAlgorithm(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("Signer.SignatureAlgorithm() = %v, want %v", got, tt.want) + } + }) + } +}