diff --git a/authority/authority.go b/authority/authority.go index b6fcdf23..d003d533 100644 --- a/authority/authority.go +++ b/authority/authority.go @@ -50,6 +50,7 @@ type Authority struct { rootX509CertPool *x509.CertPool federatedX509Certs []*x509.Certificate certificates *sync.Map + x509Enforcers []provisioner.CertificateEnforcer // SCEP CA scepService *scep.Service diff --git a/authority/options.go b/authority/options.go index a18d40e2..f92db99b 100644 --- a/authority/options.go +++ b/authority/options.go @@ -241,6 +241,15 @@ func WithLinkedCAToken(token string) Option { } } +// WithX509Enforcers is an option that allows to define custom certificate +// modifiers that will be processed just before the signing of the certificate. +func WithX509Enforcers(ces ...provisioner.CertificateEnforcer) Option { + return func(a *Authority) error { + a.x509Enforcers = ces + return nil + } +} + func readCertificateBundle(pemCerts []byte) ([]*x509.Certificate, error) { var block *pem.Block var certs []*x509.Certificate diff --git a/authority/tls.go b/authority/tls.go index cc049655..546a9399 100644 --- a/authority/tls.go +++ b/authority/tls.go @@ -180,6 +180,17 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign } } + // Process injected modifiers after validation + for _, m := range a.x509Enforcers { + if err := m.Enforce(leaf); err != nil { + return nil, errs.ApplyOptions( + errs.ForbiddenErr(err, "error creating certificate"), + opts..., + ) + } + } + + // Sign certificate lifetime := leaf.NotAfter.Sub(leaf.NotBefore.Add(signOpts.Backdate)) resp, err := a.x509CAService.CreateCertificate(&casapi.CreateCertificateRequest{ Template: leaf, diff --git a/authority/tls_test.go b/authority/tls_test.go index 3a0c999e..aeadaf0f 100644 --- a/authority/tls_test.go +++ b/authority/tls_test.go @@ -205,6 +205,17 @@ type basicConstraints struct { MaxPathLen int `asn1:"optional,default:-1"` } +type testEnforcer struct { + enforcer func(*x509.Certificate) error +} + +func (e *testEnforcer) Enforce(cert *x509.Certificate) error { + if e.enforcer != nil { + return e.enforcer(cert) + } + return nil +} + func TestAuthority_Sign(t *testing.T) { pub, priv, err := keyutil.GenerateDefaultKeyPair() assert.FatalError(t, err) @@ -238,14 +249,15 @@ func TestAuthority_Sign(t *testing.T) { assert.FatalError(t, err) type signTest struct { - auth *Authority - csr *x509.CertificateRequest - signOpts provisioner.SignOptions - extraOpts []provisioner.SignOption - notBefore time.Time - notAfter time.Time - err error - code int + auth *Authority + csr *x509.CertificateRequest + signOpts provisioner.SignOptions + extraOpts []provisioner.SignOption + notBefore time.Time + notAfter time.Time + extensionsCount int + err error + code int } tests := map[string]func(*testing.T) *signTest{ "fail invalid signature": func(t *testing.T) *signTest { @@ -454,22 +466,66 @@ ZYtQ9Ot36qc= code: http.StatusInternalServerError, } }, - "ok": func(t *testing.T) *signTest { + "fail with provisioner enforcer": func(t *testing.T) *signTest { csr := getCSR(t, priv) - _a := testAuthority(t) - _a.db = &db.MockAuthDB{ + aa := testAuthority(t) + aa.db = &db.MockAuthDB{ MStoreCertificate: func(crt *x509.Certificate) error { assert.Equals(t, crt.Subject.CommonName, "smallstep test") return nil }, } + return &signTest{ - auth: a, + auth: aa, + csr: csr, + extraOpts: append(extraOpts, &testEnforcer{ + enforcer: func(crt *x509.Certificate) error { return fmt.Errorf("an error") }, + }), + signOpts: signOpts, + err: errors.New("error creating certificate"), + code: http.StatusForbidden, + } + }, + "fail with custom enforcer": func(t *testing.T) *signTest { + csr := getCSR(t, priv) + aa := testAuthority(t, WithX509Enforcers(&testEnforcer{ + enforcer: func(cert *x509.Certificate) error { + return fmt.Errorf("an error") + }, + })) + aa.db = &db.MockAuthDB{ + MStoreCertificate: func(crt *x509.Certificate) error { + assert.Equals(t, crt.Subject.CommonName, "smallstep test") + return nil + }, + } + return &signTest{ + auth: aa, csr: csr, extraOpts: extraOpts, signOpts: signOpts, - notBefore: signOpts.NotBefore.Time().Truncate(time.Second), - notAfter: signOpts.NotAfter.Time().Truncate(time.Second), + err: errors.New("error creating certificate"), + code: http.StatusForbidden, + } + }, + "ok": func(t *testing.T) *signTest { + csr := getCSR(t, priv) + _a := testAuthority(t) + _a.db = &db.MockAuthDB{ + MStoreCertificate: func(crt *x509.Certificate) error { + assert.Equals(t, crt.Subject.CommonName, "smallstep test") + return nil + }, + } + return &signTest{ + auth: a, + csr: csr, + extraOpts: extraOpts, + signOpts: signOpts, + notBefore: signOpts.NotBefore.Time().Truncate(time.Second), + notAfter: signOpts.NotAfter.Time().Truncate(time.Second), + extensionsCount: 6, } }, "ok with enforced modifier": func(t *testing.T) *signTest { @@ -497,12 +553,13 @@ ZYtQ9Ot36qc= }, } return &signTest{ - auth: a, - csr: csr, - extraOpts: enforcedExtraOptions, - signOpts: signOpts, - notBefore: now.Truncate(time.Second), - notAfter: now.Add(365 * 24 * time.Hour).Truncate(time.Second), + auth: a, + csr: csr, + extraOpts: enforcedExtraOptions, + signOpts: signOpts, + notBefore: now.Truncate(time.Second), + notAfter: now.Add(365 * 24 * time.Hour).Truncate(time.Second), + extensionsCount: 6, } }, "ok with custom template": func(t *testing.T) *signTest { @@ -530,12 +587,13 @@ ZYtQ9Ot36qc= }, } return &signTest{ - auth: testAuthority, - csr: csr, - extraOpts: testExtraOpts, - signOpts: signOpts, - notBefore: signOpts.NotBefore.Time().Truncate(time.Second), - notAfter: signOpts.NotAfter.Time().Truncate(time.Second), + auth: testAuthority, + csr: csr, + extraOpts: testExtraOpts, + signOpts: signOpts, + notBefore: signOpts.NotBefore.Time().Truncate(time.Second), + notAfter: signOpts.NotAfter.Time().Truncate(time.Second), + extensionsCount: 6, } }, "ok/csr with no template critical SAN extension": func(t *testing.T) *signTest { @@ -558,12 +616,39 @@ ZYtQ9Ot36qc= }, } return &signTest{ - auth: _a, - csr: csr, - extraOpts: enforcedExtraOptions, - signOpts: provisioner.SignOptions{}, - notBefore: now.Truncate(time.Second), - notAfter: now.Add(365 * 24 * time.Hour).Truncate(time.Second), + auth: _a, + csr: csr, + extraOpts: enforcedExtraOptions, + signOpts: provisioner.SignOptions{}, + notBefore: now.Truncate(time.Second), + notAfter: now.Add(365 * 24 * time.Hour).Truncate(time.Second), + extensionsCount: 5, + } + }, + "ok with custom enforcer": func(t *testing.T) *signTest { + csr := getCSR(t, priv) + aa := testAuthority(t, WithX509Enforcers(&testEnforcer{ + enforcer: func(cert *x509.Certificate) error { + cert.CRLDistributionPoints = []string{"http://ca.example.org/leaf.crl"} + return nil + }, + })) + aa.config.AuthorityConfig.Template = a.config.AuthorityConfig.Template + aa.db = &db.MockAuthDB{ + MStoreCertificate: func(crt *x509.Certificate) error { + assert.Equals(t, crt.Subject.CommonName, "smallstep test") + assert.Equals(t, crt.CRLDistributionPoints, []string{"http://ca.example.org/leaf.crl"}) + return nil + }, + } + return &signTest{ + auth: aa, + csr: csr, + extraOpts: extraOpts, + signOpts: signOpts, + notBefore: signOpts.NotBefore.Time().Truncate(time.Second), + notAfter: signOpts.NotAfter.Time().Truncate(time.Second), + extensionsCount: 7, } }, } @@ -645,9 +730,6 @@ ZYtQ9Ot36qc= // Empty CSR subject test does not use any provisioner extensions. // So provisioner ID ext will be missing. found = 1 - assert.Len(t, 5, leaf.Extensions) - } else { - assert.Len(t, 6, leaf.Extensions) } } } @@ -655,6 +737,7 @@ ZYtQ9Ot36qc= realIntermediate, err := x509.ParseCertificate(issuer.Raw) assert.FatalError(t, err) assert.Equals(t, intermediate, realIntermediate) + assert.Len(t, tc.extensionsCount, leaf.Extensions) } } })