diff --git a/cas/stepcas/stepcas.go b/cas/stepcas/stepcas.go index 9f94c6ae..51c5f687 100644 --- a/cas/stepcas/stepcas.go +++ b/cas/stepcas/stepcas.go @@ -71,6 +71,8 @@ func (s *StepCAS) CreateCertificate(req *apiv1.CreateCertificateRequest) (*apiv1 switch { case req.CSR == nil: return nil, errors.New("createCertificateRequest `csr` cannot be nil") + case req.Template == nil: + return nil, errors.New("createCertificateRequest `template` cannot be nil") case req.Lifetime == 0: return nil, errors.New("createCertificateRequest `lifetime` cannot be 0") } @@ -87,7 +89,7 @@ func (s *StepCAS) CreateCertificate(req *apiv1.CreateCertificateRequest) (*apiv1 info.ProvisionerName = p.Name } - cert, chain, err := s.createCertificate(req.CSR, req.Lifetime, info) + cert, chain, err := s.createCertificate(req.CSR, req.Template, req.Lifetime, info) if err != nil { return nil, err } @@ -167,18 +169,18 @@ func (s *StepCAS) GetCertificateAuthority(*apiv1.GetCertificateAuthorityRequest) }, nil } -func (s *StepCAS) createCertificate(cr *x509.CertificateRequest, lifetime time.Duration, raInfo *raInfo) (*x509.Certificate, []*x509.Certificate, error) { - sans := make([]string, 0, len(cr.DNSNames)+len(cr.EmailAddresses)+len(cr.IPAddresses)+len(cr.URIs)) - sans = append(sans, cr.DNSNames...) - sans = append(sans, cr.EmailAddresses...) - for _, ip := range cr.IPAddresses { +func (s *StepCAS) createCertificate(cr *x509.CertificateRequest, template *x509.Certificate, lifetime time.Duration, raInfo *raInfo) (*x509.Certificate, []*x509.Certificate, error) { + sans := make([]string, 0, len(template.DNSNames)+len(template.EmailAddresses)+len(template.IPAddresses)+len(template.URIs)) + sans = append(sans, template.DNSNames...) + sans = append(sans, template.EmailAddresses...) + for _, ip := range template.IPAddresses { sans = append(sans, ip.String()) } - for _, u := range cr.URIs { + for _, u := range template.URIs { sans = append(sans, u.String()) } - commonName := cr.Subject.CommonName + commonName := template.Subject.CommonName if commonName == "" && len(sans) > 0 { commonName = sans[0] } diff --git a/cas/stepcas/stepcas_test.go b/cas/stepcas/stepcas_test.go index b9dd9abd..f7746da0 100644 --- a/cas/stepcas/stepcas_test.go +++ b/cas/stepcas/stepcas_test.go @@ -23,6 +23,7 @@ import ( "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/ca" "github.com/smallstep/certificates/cas/apiv1" + "github.com/stretchr/testify/require" "go.step.sm/crypto/jose" "go.step.sm/crypto/pemutil" "go.step.sm/crypto/randutil" @@ -631,6 +632,17 @@ func TestStepCAS_CreateCertificate(t *testing.T) { jwkEnc := testJWKIssuer(t, caURL, testPassword) x5cBad := testX5CIssuer(t, caURL, "bad-password") + testTemplate := &x509.Certificate{ + Subject: testCR.Subject, + DNSNames: testCR.DNSNames, + EmailAddresses: testCR.EmailAddresses, + IPAddresses: testCR.IPAddresses, + URIs: testCR.URIs, + } + + testOtherCR, err := x509util.CreateCertificateRequest("Test Certificate", []string{"test.example.com"}, testKey) + require.NoError(t, err) + type fields struct { iss stepIssuer client *ca.Client @@ -648,6 +660,15 @@ func TestStepCAS_CreateCertificate(t *testing.T) { }{ {"ok", fields{x5c, client, testRootFingerprint}, args{&apiv1.CreateCertificateRequest{ CSR: testCR, + Template: testTemplate, + Lifetime: time.Hour, + }}, &apiv1.CreateCertificateResponse{ + Certificate: testCrt, + CertificateChain: []*x509.Certificate{testIssCrt}, + }, false}, + {"ok with different CSR", fields{x5c, client, testRootFingerprint}, args{&apiv1.CreateCertificateRequest{ + CSR: testOtherCR, + Template: testTemplate, Lifetime: time.Hour, }}, &apiv1.CreateCertificateResponse{ Certificate: testCrt, @@ -655,6 +676,7 @@ func TestStepCAS_CreateCertificate(t *testing.T) { }, false}, {"ok with password", fields{x5cEnc, client, testRootFingerprint}, args{&apiv1.CreateCertificateRequest{ CSR: testCR, + Template: testTemplate, Lifetime: time.Hour, }}, &apiv1.CreateCertificateResponse{ Certificate: testCrt, @@ -662,6 +684,7 @@ func TestStepCAS_CreateCertificate(t *testing.T) { }, false}, {"ok jwk", fields{jwk, client, testRootFingerprint}, args{&apiv1.CreateCertificateRequest{ CSR: testCR, + Template: testTemplate, Lifetime: time.Hour, }}, &apiv1.CreateCertificateResponse{ Certificate: testCrt, @@ -669,6 +692,7 @@ func TestStepCAS_CreateCertificate(t *testing.T) { }, false}, {"ok jwk with password", fields{jwkEnc, client, testRootFingerprint}, args{&apiv1.CreateCertificateRequest{ CSR: testCR, + Template: testTemplate, Lifetime: time.Hour, }}, &apiv1.CreateCertificateResponse{ Certificate: testCrt, @@ -676,6 +700,7 @@ func TestStepCAS_CreateCertificate(t *testing.T) { }, false}, {"ok with provisioner", fields{jwk, client, testRootFingerprint}, args{&apiv1.CreateCertificateRequest{ CSR: testCR, + Template: testTemplate, Lifetime: time.Hour, Provisioner: &apiv1.ProvisionerInfo{ID: "provisioner-id", Type: "ACME"}, }}, &apiv1.CreateCertificateResponse{ @@ -684,6 +709,7 @@ func TestStepCAS_CreateCertificate(t *testing.T) { }, false}, {"ok with server cert", fields{jwk, client, testRootFingerprint}, args{&apiv1.CreateCertificateRequest{ CSR: testCR, + Template: testTemplate, Lifetime: time.Hour, IsCAServerCert: true, }}, &apiv1.CreateCertificateResponse{ @@ -692,6 +718,12 @@ func TestStepCAS_CreateCertificate(t *testing.T) { }, false}, {"fail CSR", fields{x5c, client, testRootFingerprint}, args{&apiv1.CreateCertificateRequest{ CSR: nil, + Template: testTemplate, + Lifetime: time.Hour, + }}, nil, true}, + {"fail Template", fields{x5c, client, testRootFingerprint}, args{&apiv1.CreateCertificateRequest{ + CSR: testCR, + Template: nil, Lifetime: time.Hour, }}, nil, true}, {"fail lifetime", fields{x5c, client, testRootFingerprint}, args{&apiv1.CreateCertificateRequest{