From d1deb7f93066e15e39ff6b9ca0d341e8213981cb Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Thu, 8 Feb 2024 14:10:48 +0100 Subject: [PATCH 1/4] Add `Expires` header to CRL response --- api/api.go | 2 +- api/api_test.go | 6 +++--- api/crl.go | 9 ++++++--- authority/tls.go | 17 +++++++++++++++-- 4 files changed, 25 insertions(+), 9 deletions(-) diff --git a/api/api.go b/api/api.go index 7cf44a11..5d96cc45 100644 --- a/api/api.go +++ b/api/api.go @@ -54,7 +54,7 @@ type Authority interface { GetRoots() ([]*x509.Certificate, error) GetFederation() ([]*x509.Certificate, error) Version() authority.Version - GetCertificateRevocationList() ([]byte, error) + GetCertificateRevocationList() (*authority.CertificateRevocationListInfo, error) } // mustAuthority will be replaced on unit tests. diff --git a/api/api_test.go b/api/api_test.go index b3c01816..a62b34e8 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -200,7 +200,7 @@ type mockAuthority struct { getEncryptedKey func(kid string) (string, error) getRoots func() ([]*x509.Certificate, error) getFederation func() ([]*x509.Certificate, error) - getCRL func() ([]byte, error) + getCRL func() (*authority.CertificateRevocationListInfo, error) signSSH func(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) signSSHAddUser func(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) renewSSH func(ctx context.Context, cert *ssh.Certificate) (*ssh.Certificate, error) @@ -214,12 +214,12 @@ type mockAuthority struct { version func() authority.Version } -func (m *mockAuthority) GetCertificateRevocationList() ([]byte, error) { +func (m *mockAuthority) GetCertificateRevocationList() (*authority.CertificateRevocationListInfo, error) { if m.getCRL != nil { return m.getCRL() } - return m.ret1.([]byte), m.err + return m.ret1.(*authority.CertificateRevocationListInfo), m.err } // TODO: remove once Authorize is deprecated. diff --git a/api/crl.go b/api/crl.go index 6386f34a..7f12c6f8 100644 --- a/api/crl.go +++ b/api/crl.go @@ -3,18 +3,21 @@ package api import ( "encoding/pem" "net/http" + "time" "github.com/smallstep/certificates/api/render" ) // CRL is an HTTP handler that returns the current CRL in DER or PEM format func CRL(w http.ResponseWriter, r *http.Request) { - crlBytes, err := mustAuthority(r.Context()).GetCertificateRevocationList() + crlInfo, err := mustAuthority(r.Context()).GetCertificateRevocationList() if err != nil { render.Error(w, err) return } + w.Header().Add("Expires", crlInfo.ExpiresAt.Format(time.RFC1123)) + _, formatAsPEM := r.URL.Query()["pem"] if formatAsPEM { w.Header().Add("Content-Type", "application/x-pem-file") @@ -22,11 +25,11 @@ func CRL(w http.ResponseWriter, r *http.Request) { _ = pem.Encode(w, &pem.Block{ Type: "X509 CRL", - Bytes: crlBytes, + Bytes: crlInfo.Data, }) } else { w.Header().Add("Content-Type", "application/pkix-crl") w.Header().Add("Content-Disposition", "attachment; filename=\"crl.der\"") - w.Write(crlBytes) + w.Write(crlInfo.Data) } } diff --git a/authority/tls.go b/authority/tls.go index 0dd6eb54..fa170d44 100644 --- a/authority/tls.go +++ b/authority/tls.go @@ -696,9 +696,17 @@ func (a *Authority) revokeSSH(crt *ssh.Certificate, rci *db.RevokedCertificateIn return a.db.RevokeSSH(rci) } +// CertificateRevocationListInfo contains a CRL in DER format and associated metadata. +type CertificateRevocationListInfo struct { + Number int64 + ExpiresAt time.Time + Duration time.Duration + Data []byte +} + // GetCertificateRevocationList will return the currently generated CRL from the DB, or a not implemented // error if the underlying AuthDB does not support CRLs -func (a *Authority) GetCertificateRevocationList() ([]byte, error) { +func (a *Authority) GetCertificateRevocationList() (*CertificateRevocationListInfo, error) { if !a.config.CRL.IsEnabled() { return nil, errs.Wrap(http.StatusNotFound, errors.Errorf("Certificate Revocation Lists are not enabled"), "authority.GetCertificateRevocationList") } @@ -713,7 +721,12 @@ func (a *Authority) GetCertificateRevocationList() ([]byte, error) { return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.GetCertificateRevocationList") } - return crlInfo.DER, nil + return &CertificateRevocationListInfo{ + Number: crlInfo.Number, + ExpiresAt: crlInfo.ExpiresAt, + Duration: crlInfo.Duration, + Data: crlInfo.DER, + }, nil } // GenerateCertificateRevocationList generates a DER representation of a signed CRL and stores it in the From 69f5f8d8eaada72fbf5b30cded594cbf4f3339c7 Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Thu, 8 Feb 2024 14:11:13 +0100 Subject: [PATCH 2/4] Use `stretchr/testify` instead of `smallstep/assert` for tests --- authority/tls_test.go | 354 +++++++++++++++++++++--------------------- 1 file changed, 177 insertions(+), 177 deletions(-) diff --git a/authority/tls_test.go b/authority/tls_test.go index efcb78f8..1fb8411a 100644 --- a/authority/tls_test.go +++ b/authority/tls_test.go @@ -24,7 +24,7 @@ import ( "go.step.sm/crypto/pemutil" "go.step.sm/crypto/x509util" - "github.com/smallstep/assert" + sassert "github.com/smallstep/assert" "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/policy" @@ -33,6 +33,8 @@ import ( "github.com/smallstep/certificates/db" "github.com/smallstep/certificates/errs" "github.com/smallstep/nosql/database" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) var ( @@ -80,25 +82,25 @@ func generateCertificate(t *testing.T, commonName string, sans []string, opts .. t.Helper() priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - assert.FatalError(t, err) + require.NoError(t, err) cr, err := x509util.CreateCertificateRequest(commonName, sans, priv) - assert.FatalError(t, err) + require.NoError(t, err) template, err := x509util.NewCertificate(cr) - assert.FatalError(t, err) + require.NoError(t, err) cert := template.GetCertificate() for _, m := range opts { switch m := m.(type) { case provisioner.CertificateModifierFunc: err = m.Modify(cert, provisioner.SignOptions{}) - assert.FatalError(t, err) + require.NoError(t, err) case signerFunc: cert, err = m(cert, priv.Public()) - assert.FatalError(t, err) + require.NoError(t, err) default: - t.Fatalf("unknown type %T", m) + require.Fail(t, "", "unknown type %T", m) } } @@ -108,36 +110,36 @@ func generateCertificate(t *testing.T, commonName string, sans []string, opts .. func generateRootCertificate(t *testing.T) (*x509.Certificate, crypto.Signer) { t.Helper() priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - assert.FatalError(t, err) + require.NoError(t, err) cr, err := x509util.CreateCertificateRequest("TestRootCA", nil, priv) - assert.FatalError(t, err) + require.NoError(t, err) data := x509util.CreateTemplateData("TestRootCA", nil) template, err := x509util.NewCertificate(cr, x509util.WithTemplate(x509util.DefaultRootTemplate, data)) - assert.FatalError(t, err) + require.NoError(t, err) cert := template.GetCertificate() cert, err = x509util.CreateCertificate(cert, cert, priv.Public(), priv) - assert.FatalError(t, err) + require.NoError(t, err) return cert, priv } func generateIntermidiateCertificate(t *testing.T, issuer *x509.Certificate, signer crypto.Signer) (*x509.Certificate, crypto.Signer) { t.Helper() priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - assert.FatalError(t, err) + require.NoError(t, err) cr, err := x509util.CreateCertificateRequest("TestIntermediateCA", nil, priv) - assert.FatalError(t, err) + require.NoError(t, err) data := x509util.CreateTemplateData("TestIntermediateCA", nil) template, err := x509util.NewCertificate(cr, x509util.WithTemplate(x509util.DefaultRootTemplate, data)) - assert.FatalError(t, err) + require.NoError(t, err) cert := template.GetCertificate() cert, err = x509util.CreateCertificate(cert, issuer, priv.Public(), signer) - assert.FatalError(t, err) + require.NoError(t, err) return cert, priv } @@ -192,9 +194,9 @@ func getCSR(t *testing.T, priv interface{}, opts ...func(*x509.CertificateReques opt(_csr) } csrBytes, err := x509.CreateCertificateRequest(rand.Reader, _csr, priv) - assert.FatalError(t, err) + require.NoError(t, err) csr, err := x509.ParseCertificateRequest(csrBytes) - assert.FatalError(t, err) + require.NoError(t, err) return csr } @@ -239,10 +241,10 @@ func (e *testEnforcer) Enforce(cert *x509.Certificate) error { func TestAuthority_Sign(t *testing.T) { pub, priv, err := keyutil.GenerateDefaultKeyPair() - assert.FatalError(t, err) + require.NoError(t, err) a := testAuthority(t) - assert.FatalError(t, err) + require.NoError(t, err) a.config.AuthorityConfig.Template = &ASN1DN{ Country: "Tazmania", Organization: "Acme Co", @@ -262,12 +264,12 @@ func TestAuthority_Sign(t *testing.T) { // Create a token to get test extra opts. p := a.config.AuthorityConfig.Provisioners[1].(*provisioner.JWK) key, err := jose.ReadKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass"))) - assert.FatalError(t, err) + require.NoError(t, err) token, err := generateToken("smallstep test", "step-cli", testAudiences.Sign[0], []string{"test.smallstep.com"}, time.Now(), key) - assert.FatalError(t, err) + require.NoError(t, err) ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod) extraOpts, err := a.Authorize(ctx, token) - assert.FatalError(t, err) + require.NoError(t, err) type signTest struct { auth *Authority @@ -372,9 +374,9 @@ W5kR63lNVHBHgQmv5mA8YFsfrJHstaz5k727v2LMHEYIf5/3i16d5zhuxUoaPTYr ZYtQ9Ot36qc= -----END CERTIFICATE REQUEST-----` block, _ := pem.Decode([]byte(shortRSAKeyPEM)) - assert.FatalError(t, err) + require.NoError(t, err) csr, err := x509.ParseCertificateRequest(block.Bytes) - assert.FatalError(t, err) + require.NoError(t, err) return &signTest{ auth: a, @@ -413,10 +415,10 @@ ZYtQ9Ot36qc= X509: &provisioner.X509Options{Template: `{{ fail "fail message" }}`}, } testExtraOpts, err := testAuthority.Authorize(ctx, token) - assert.FatalError(t, err) + require.NoError(t, err) testAuthority.db = &db.MockAuthDB{ MStoreCertificate: func(crt *x509.Certificate) error { - assert.Equals(t, crt.Subject.CommonName, "smallstep test") + sassert.Equals(t, crt.Subject.CommonName, "smallstep test") return nil }, } @@ -442,10 +444,10 @@ ZYtQ9Ot36qc= }, } testExtraOpts, err := testAuthority.Authorize(ctx, token) - assert.FatalError(t, err) + require.NoError(t, err) testAuthority.db = &db.MockAuthDB{ MStoreCertificate: func(crt *x509.Certificate) error { - assert.Equals(t, crt.Subject.CommonName, "smallstep test") + sassert.Equals(t, crt.Subject.CommonName, "smallstep test") return nil }, } @@ -471,10 +473,10 @@ ZYtQ9Ot36qc= }, } testExtraOpts, err := testAuthority.Authorize(ctx, token) - assert.FatalError(t, err) + require.NoError(t, err) testAuthority.db = &db.MockAuthDB{ MStoreCertificate: func(crt *x509.Certificate) error { - assert.Equals(t, crt.Subject.CommonName, "smallstep test") + sassert.Equals(t, crt.Subject.CommonName, "smallstep test") return nil }, } @@ -492,7 +494,7 @@ ZYtQ9Ot36qc= aa := testAuthority(t) aa.db = &db.MockAuthDB{ MStoreCertificate: func(crt *x509.Certificate) error { - assert.Equals(t, crt.Subject.CommonName, "smallstep test") + sassert.Equals(t, crt.Subject.CommonName, "smallstep test") return nil }, } @@ -517,7 +519,7 @@ ZYtQ9Ot36qc= })) aa.db = &db.MockAuthDB{ MStoreCertificate: func(crt *x509.Certificate) error { - assert.Equals(t, crt.Subject.CommonName, "smallstep test") + sassert.Equals(t, crt.Subject.CommonName, "smallstep test") return nil }, } @@ -537,7 +539,7 @@ ZYtQ9Ot36qc= aa.db = &db.MockAuthDB{ MStoreCertificate: func(crt *x509.Certificate) error { fmt.Println(crt.Subject) - assert.Equals(t, crt.Subject.CommonName, "smallstep test") + sassert.Equals(t, crt.Subject.CommonName, "smallstep test") return nil }, } @@ -549,7 +551,7 @@ ZYtQ9Ot36qc= }, } engine, err := policy.New(options) - assert.FatalError(t, err) + require.NoError(t, err) aa.policyEngine = engine return &signTest{ auth: aa, @@ -598,7 +600,7 @@ ZYtQ9Ot36qc= _a := testAuthority(t) _a.db = &db.MockAuthDB{ MStoreCertificate: func(crt *x509.Certificate) error { - assert.Equals(t, crt.Subject.CommonName, "smallstep test") + sassert.Equals(t, crt.Subject.CommonName, "smallstep test") return nil }, } @@ -617,7 +619,7 @@ ZYtQ9Ot36qc= bcExt.Id = asn1.ObjectIdentifier{2, 5, 29, 19} bcExt.Critical = false bcExt.Value, err = asn1.Marshal(basicConstraints{IsCA: true, MaxPathLen: 4}) - assert.FatalError(t, err) + require.NoError(t, err) csr := getCSR(t, priv, setExtraExtsCSR([]pkix.Extension{ bcExt, @@ -632,7 +634,7 @@ ZYtQ9Ot36qc= _a := testAuthority(t) _a.db = &db.MockAuthDB{ MStoreCertificate: func(crt *x509.Certificate) error { - assert.Equals(t, crt.Subject.CommonName, "smallstep test") + sassert.Equals(t, crt.Subject.CommonName, "smallstep test") return nil }, } @@ -663,10 +665,10 @@ ZYtQ9Ot36qc= }`}, } testExtraOpts, err := testAuthority.Authorize(ctx, token) - assert.FatalError(t, err) + require.NoError(t, err) testAuthority.db = &db.MockAuthDB{ MStoreCertificate: func(crt *x509.Certificate) error { - assert.Equals(t, crt.Subject.CommonName, "smallstep test") + sassert.Equals(t, crt.Subject.CommonName, "smallstep test") return nil }, } @@ -697,10 +699,10 @@ ZYtQ9Ot36qc= }`}, } testExtraOpts, err := testAuthority.Authorize(ctx, token) - assert.FatalError(t, err) + require.NoError(t, err) testAuthority.db = &db.MockAuthDB{ MStoreCertificate: func(crt *x509.Certificate) error { - assert.Equals(t, crt.Subject.CommonName, "smallstep test") + sassert.Equals(t, crt.Subject.CommonName, "smallstep test") return nil }, } @@ -737,7 +739,7 @@ ZYtQ9Ot36qc= _a.config.AuthorityConfig.Template = &ASN1DN{} _a.db = &db.MockAuthDB{ MStoreCertificate: func(crt *x509.Certificate) error { - assert.Equals(t, crt.Subject, pkix.Name{}) + sassert.Equals(t, crt.Subject, pkix.Name{}) return nil }, } @@ -762,8 +764,8 @@ ZYtQ9Ot36qc= aa.config.AuthorityConfig.Template = a.config.AuthorityConfig.Template aa.db = &db.MockAuthDB{ MStoreCertificate: func(crt *x509.Certificate) error { - assert.Equals(t, crt.Subject.CommonName, "smallstep test") - assert.Equals(t, crt.CRLDistributionPoints, []string{"http://ca.example.org/leaf.crl"}) + sassert.Equals(t, crt.Subject.CommonName, "smallstep test") + sassert.Equals(t, crt.CRLDistributionPoints, []string{"http://ca.example.org/leaf.crl"}) return nil }, } @@ -783,7 +785,7 @@ ZYtQ9Ot36qc= aa.config.AuthorityConfig.Template = a.config.AuthorityConfig.Template aa.db = &db.MockAuthDB{ MStoreCertificate: func(crt *x509.Certificate) error { - assert.Equals(t, crt.Subject.CommonName, "smallstep test") + sassert.Equals(t, crt.Subject.CommonName, "smallstep test") return nil }, } @@ -796,7 +798,7 @@ ZYtQ9Ot36qc= }, } engine, err := policy.New(options) - assert.FatalError(t, err) + require.NoError(t, err) aa.policyEngine = engine return &signTest{ auth: aa, @@ -816,13 +818,13 @@ ZYtQ9Ot36qc= MStoreCertificateChain: func(prov provisioner.Interface, certs ...*x509.Certificate) error { p, ok := prov.(attProvisioner) if assert.True(t, ok) { - assert.Equals(t, &provisioner.AttestationData{ + sassert.Equals(t, &provisioner.AttestationData{ PermanentIdentifier: "1234567890", }, p.AttestationData()) } - if assert.Len(t, 2, certs) { - assert.Equals(t, certs[0].Subject.CommonName, "smallstep test") - assert.Equals(t, certs[1].Subject.CommonName, "smallstep Intermediate CA") + if assert.Len(t, certs, 2) { + sassert.Equals(t, certs[0].Subject.CommonName, "smallstep test") + sassert.Equals(t, certs[1].Subject.CommonName, "smallstep Intermediate CA") } return nil }, @@ -851,26 +853,26 @@ ZYtQ9Ot36qc= if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) { assert.Nil(t, certChain) var sc render.StatusCodedError - assert.Fatal(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") - assert.Equals(t, sc.StatusCode(), tc.code) - assert.HasPrefix(t, err.Error(), tc.err.Error()) + sassert.Fatal(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") + sassert.Equals(t, sc.StatusCode(), tc.code) + sassert.HasPrefix(t, err.Error(), tc.err.Error()) var ctxErr *errs.Error - assert.Fatal(t, errors.As(err, &ctxErr), "error is not of type *errs.Error") - assert.Equals(t, ctxErr.Details["csr"], tc.csr) - assert.Equals(t, ctxErr.Details["signOptions"], tc.signOpts) + sassert.Fatal(t, errors.As(err, &ctxErr), "error is not of type *errs.Error") + sassert.Equals(t, ctxErr.Details["csr"], tc.csr) + sassert.Equals(t, ctxErr.Details["signOptions"], tc.signOpts) } } else { leaf := certChain[0] intermediate := certChain[1] if assert.Nil(t, tc.err) { - assert.Equals(t, leaf.NotBefore, tc.notBefore) - assert.Equals(t, leaf.NotAfter, tc.notAfter) + sassert.Equals(t, leaf.NotBefore, tc.notBefore) + sassert.Equals(t, leaf.NotAfter, tc.notAfter) tmplt := a.config.AuthorityConfig.Template if tc.csr.Subject.CommonName == "" { - assert.Equals(t, leaf.Subject, pkix.Name{}) + sassert.Equals(t, leaf.Subject, pkix.Name{}) } else { - assert.Equals(t, leaf.Subject.String(), + sassert.Equals(t, leaf.Subject.String(), pkix.Name{ Country: []string{tmplt.Country}, Organization: []string{tmplt.Organization}, @@ -879,18 +881,18 @@ ZYtQ9Ot36qc= Province: []string{tmplt.Province}, CommonName: "smallstep test", }.String()) - assert.Equals(t, leaf.DNSNames, []string{"test.smallstep.com"}) + sassert.Equals(t, leaf.DNSNames, []string{"test.smallstep.com"}) } - assert.Equals(t, leaf.Issuer, intermediate.Subject) - assert.Equals(t, leaf.SignatureAlgorithm, x509.ECDSAWithSHA256) - assert.Equals(t, leaf.PublicKeyAlgorithm, x509.ECDSA) - assert.Equals(t, leaf.ExtKeyUsage, []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}) + sassert.Equals(t, leaf.Issuer, intermediate.Subject) + sassert.Equals(t, leaf.SignatureAlgorithm, x509.ECDSAWithSHA256) + sassert.Equals(t, leaf.PublicKeyAlgorithm, x509.ECDSA) + sassert.Equals(t, leaf.ExtKeyUsage, []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}) issuer := getDefaultIssuer(a) subjectKeyID, err := generateSubjectKeyID(pub) - assert.FatalError(t, err) - assert.Equals(t, leaf.SubjectKeyId, subjectKeyID) - assert.Equals(t, leaf.AuthorityKeyId, issuer.SubjectKeyId) + require.NoError(t, err) + sassert.Equals(t, leaf.SubjectKeyId, subjectKeyID) + sassert.Equals(t, leaf.AuthorityKeyId, issuer.SubjectKeyId) // Verify Provisioner OID found := 0 @@ -900,18 +902,18 @@ ZYtQ9Ot36qc= found++ val := stepProvisionerASN1{} _, err := asn1.Unmarshal(ext.Value, &val) - assert.FatalError(t, err) - assert.Equals(t, val.Type, provisionerTypeJWK) - assert.Equals(t, val.Name, []byte(p.Name)) - assert.Equals(t, val.CredentialID, []byte(p.Key.KeyID)) + require.NoError(t, err) + sassert.Equals(t, val.Type, provisionerTypeJWK) + sassert.Equals(t, val.Name, []byte(p.Name)) + sassert.Equals(t, val.CredentialID, []byte(p.Key.KeyID)) // Basic Constraints case ext.Id.Equal(asn1.ObjectIdentifier([]int{2, 5, 29, 19})): val := basicConstraints{} _, err := asn1.Unmarshal(ext.Value, &val) - assert.FatalError(t, err) + require.NoError(t, err) assert.False(t, val.IsCA, false) - assert.Equals(t, val.MaxPathLen, 0) + sassert.Equals(t, val.MaxPathLen, 0) // SAN extension case ext.Id.Equal(asn1.ObjectIdentifier([]int{2, 5, 29, 17})): @@ -922,11 +924,11 @@ ZYtQ9Ot36qc= } } } - assert.Equals(t, found, 1) + sassert.Equals(t, found, 1) realIntermediate, err := x509.ParseCertificate(issuer.Raw) - assert.FatalError(t, err) - assert.Equals(t, intermediate, realIntermediate) - assert.Len(t, tc.extensionsCount, leaf.Extensions) + require.NoError(t, err) + sassert.Equals(t, intermediate, realIntermediate) + assert.Len(t, leaf.Extensions, tc.extensionsCount) } } }) @@ -1056,7 +1058,7 @@ func TestAuthority_Renew(t *testing.T) { for name, genTestCase := range tests { t.Run(name, func(t *testing.T) { tc, err := genTestCase() - assert.FatalError(t, err) + require.NoError(t, err) var certChain []*x509.Certificate if tc.auth != nil { @@ -1068,19 +1070,19 @@ func TestAuthority_Renew(t *testing.T) { if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) { assert.Nil(t, certChain) var sc render.StatusCodedError - assert.Fatal(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") - assert.Equals(t, sc.StatusCode(), tc.code) - assert.HasPrefix(t, err.Error(), tc.err.Error()) + sassert.Fatal(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") + sassert.Equals(t, sc.StatusCode(), tc.code) + sassert.HasPrefix(t, err.Error(), tc.err.Error()) var ctxErr *errs.Error - assert.Fatal(t, errors.As(err, &ctxErr), "error is not of type *errs.Error") - assert.Equals(t, ctxErr.Details["serialNumber"], tc.cert.SerialNumber.String()) + sassert.Fatal(t, errors.As(err, &ctxErr), "error is not of type *errs.Error") + sassert.Equals(t, ctxErr.Details["serialNumber"], tc.cert.SerialNumber.String()) } } else { leaf := certChain[0] intermediate := certChain[1] if assert.Nil(t, tc.err) { - assert.Equals(t, leaf.NotAfter.Sub(leaf.NotBefore), tc.cert.NotAfter.Sub(cert.NotBefore)) + sassert.Equals(t, leaf.NotAfter.Sub(leaf.NotBefore), tc.cert.NotAfter.Sub(cert.NotBefore)) assert.True(t, leaf.NotBefore.After(now.Add(-2*time.Minute))) assert.True(t, leaf.NotBefore.Before(now.Add(time.Minute))) @@ -1090,30 +1092,30 @@ func TestAuthority_Renew(t *testing.T) { assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Hour))) tmplt := a.config.AuthorityConfig.Template - assert.Equals(t, leaf.RawSubject, tc.cert.RawSubject) - assert.Equals(t, leaf.Subject.Country, []string{tmplt.Country}) - assert.Equals(t, leaf.Subject.Organization, []string{tmplt.Organization}) - assert.Equals(t, leaf.Subject.Locality, []string{tmplt.Locality}) - assert.Equals(t, leaf.Subject.StreetAddress, []string{tmplt.StreetAddress}) - assert.Equals(t, leaf.Subject.Province, []string{tmplt.Province}) - assert.Equals(t, leaf.Subject.CommonName, tmplt.CommonName) + sassert.Equals(t, leaf.RawSubject, tc.cert.RawSubject) + sassert.Equals(t, leaf.Subject.Country, []string{tmplt.Country}) + sassert.Equals(t, leaf.Subject.Organization, []string{tmplt.Organization}) + sassert.Equals(t, leaf.Subject.Locality, []string{tmplt.Locality}) + sassert.Equals(t, leaf.Subject.StreetAddress, []string{tmplt.StreetAddress}) + sassert.Equals(t, leaf.Subject.Province, []string{tmplt.Province}) + sassert.Equals(t, leaf.Subject.CommonName, tmplt.CommonName) - assert.Equals(t, leaf.Issuer, intermediate.Subject) + sassert.Equals(t, leaf.Issuer, intermediate.Subject) - assert.Equals(t, leaf.SignatureAlgorithm, x509.ECDSAWithSHA256) - assert.Equals(t, leaf.PublicKeyAlgorithm, x509.ECDSA) - assert.Equals(t, leaf.ExtKeyUsage, + sassert.Equals(t, leaf.SignatureAlgorithm, x509.ECDSAWithSHA256) + sassert.Equals(t, leaf.PublicKeyAlgorithm, x509.ECDSA) + sassert.Equals(t, leaf.ExtKeyUsage, []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}) - assert.Equals(t, leaf.DNSNames, []string{"test.smallstep.com", "test"}) + sassert.Equals(t, leaf.DNSNames, []string{"test.smallstep.com", "test"}) subjectKeyID, err := generateSubjectKeyID(leaf.PublicKey) - assert.FatalError(t, err) - assert.Equals(t, leaf.SubjectKeyId, subjectKeyID) + require.NoError(t, err) + sassert.Equals(t, leaf.SubjectKeyId, subjectKeyID) // We did not change the intermediate before renewing. authIssuer := getDefaultIssuer(tc.auth) if issuer.SerialNumber == authIssuer.SerialNumber { - assert.Equals(t, leaf.AuthorityKeyId, issuer.SubjectKeyId) + sassert.Equals(t, leaf.AuthorityKeyId, issuer.SubjectKeyId) // Compare extensions: they can be in a different order for _, ext1 := range tc.cert.Extensions { //skip SubjectKeyIdentifier @@ -1133,7 +1135,7 @@ func TestAuthority_Renew(t *testing.T) { } } else { // We did change the intermediate before renewing. - assert.Equals(t, leaf.AuthorityKeyId, authIssuer.SubjectKeyId) + sassert.Equals(t, leaf.AuthorityKeyId, authIssuer.SubjectKeyId) // Compare extensions: they can be in a different order for _, ext1 := range tc.cert.Extensions { //skip SubjectKeyIdentifier @@ -1161,8 +1163,8 @@ func TestAuthority_Renew(t *testing.T) { } realIntermediate, err := x509.ParseCertificate(authIssuer.Raw) - assert.FatalError(t, err) - assert.Equals(t, intermediate, realIntermediate) + require.NoError(t, err) + sassert.Equals(t, intermediate, realIntermediate) } } }) @@ -1171,7 +1173,7 @@ func TestAuthority_Renew(t *testing.T) { func TestAuthority_Rekey(t *testing.T) { pub, _, err := keyutil.GenerateDefaultKeyPair() - assert.FatalError(t, err) + require.NoError(t, err) a := testAuthority(t) a.config.AuthorityConfig.Template = &ASN1DN{ @@ -1261,7 +1263,7 @@ func TestAuthority_Rekey(t *testing.T) { for name, genTestCase := range tests { t.Run(name, func(t *testing.T) { tc, err := genTestCase() - assert.FatalError(t, err) + require.NoError(t, err) var certChain []*x509.Certificate if tc.auth != nil { @@ -1273,19 +1275,19 @@ func TestAuthority_Rekey(t *testing.T) { if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) { assert.Nil(t, certChain) var sc render.StatusCodedError - assert.Fatal(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") - assert.Equals(t, sc.StatusCode(), tc.code) - assert.HasPrefix(t, err.Error(), tc.err.Error()) + sassert.Fatal(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") + sassert.Equals(t, sc.StatusCode(), tc.code) + sassert.HasPrefix(t, err.Error(), tc.err.Error()) var ctxErr *errs.Error - assert.Fatal(t, errors.As(err, &ctxErr), "error is not of type *errs.Error") - assert.Equals(t, ctxErr.Details["serialNumber"], tc.cert.SerialNumber.String()) + sassert.Fatal(t, errors.As(err, &ctxErr), "error is not of type *errs.Error") + sassert.Equals(t, ctxErr.Details["serialNumber"], tc.cert.SerialNumber.String()) } } else { leaf := certChain[0] intermediate := certChain[1] if assert.Nil(t, tc.err) { - assert.Equals(t, leaf.NotAfter.Sub(leaf.NotBefore), tc.cert.NotAfter.Sub(cert.NotBefore)) + sassert.Equals(t, leaf.NotAfter.Sub(leaf.NotBefore), tc.cert.NotAfter.Sub(cert.NotBefore)) assert.True(t, leaf.NotBefore.After(now.Add(-2*time.Minute))) assert.True(t, leaf.NotBefore.Before(now.Add(time.Minute))) @@ -1295,7 +1297,7 @@ func TestAuthority_Rekey(t *testing.T) { assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Hour))) tmplt := a.config.AuthorityConfig.Template - assert.Equals(t, leaf.Subject.String(), + sassert.Equals(t, leaf.Subject.String(), pkix.Name{ Country: []string{tmplt.Country}, Organization: []string{tmplt.Organization}, @@ -1304,32 +1306,32 @@ func TestAuthority_Rekey(t *testing.T) { Province: []string{tmplt.Province}, CommonName: tmplt.CommonName, }.String()) - assert.Equals(t, leaf.Issuer, intermediate.Subject) + sassert.Equals(t, leaf.Issuer, intermediate.Subject) - assert.Equals(t, leaf.SignatureAlgorithm, x509.ECDSAWithSHA256) - assert.Equals(t, leaf.PublicKeyAlgorithm, x509.ECDSA) - assert.Equals(t, leaf.ExtKeyUsage, + sassert.Equals(t, leaf.SignatureAlgorithm, x509.ECDSAWithSHA256) + sassert.Equals(t, leaf.PublicKeyAlgorithm, x509.ECDSA) + sassert.Equals(t, leaf.ExtKeyUsage, []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}) - assert.Equals(t, leaf.DNSNames, []string{"test.smallstep.com", "test"}) + sassert.Equals(t, leaf.DNSNames, []string{"test.smallstep.com", "test"}) // Test Public Key and SubjectKeyId expectedPK := tc.pk if tc.pk == nil { expectedPK = cert.PublicKey } - assert.Equals(t, leaf.PublicKey, expectedPK) + sassert.Equals(t, leaf.PublicKey, expectedPK) subjectKeyID, err := generateSubjectKeyID(expectedPK) - assert.FatalError(t, err) - assert.Equals(t, leaf.SubjectKeyId, subjectKeyID) + require.NoError(t, err) + sassert.Equals(t, leaf.SubjectKeyId, subjectKeyID) if tc.pk == nil { - assert.Equals(t, leaf.SubjectKeyId, cert.SubjectKeyId) + sassert.Equals(t, leaf.SubjectKeyId, cert.SubjectKeyId) } // We did not change the intermediate before renewing. authIssuer := getDefaultIssuer(tc.auth) if issuer.SerialNumber == authIssuer.SerialNumber { - assert.Equals(t, leaf.AuthorityKeyId, issuer.SubjectKeyId) + sassert.Equals(t, leaf.AuthorityKeyId, issuer.SubjectKeyId) // Compare extensions: they can be in a different order for _, ext1 := range tc.cert.Extensions { //skip SubjectKeyIdentifier @@ -1349,7 +1351,7 @@ func TestAuthority_Rekey(t *testing.T) { } } else { // We did change the intermediate before renewing. - assert.Equals(t, leaf.AuthorityKeyId, authIssuer.SubjectKeyId) + sassert.Equals(t, leaf.AuthorityKeyId, authIssuer.SubjectKeyId) // Compare extensions: they can be in a different order for _, ext1 := range tc.cert.Extensions { //skip SubjectKeyIdentifier @@ -1377,8 +1379,8 @@ func TestAuthority_Rekey(t *testing.T) { } realIntermediate, err := x509.ParseCertificate(authIssuer.Raw) - assert.FatalError(t, err) - assert.Equals(t, intermediate, realIntermediate) + require.NoError(t, err) + sassert.Equals(t, intermediate, realIntermediate) } } }) @@ -1413,10 +1415,10 @@ func TestAuthority_GetTLSOptions(t *testing.T) { for name, genTestCase := range tests { t.Run(name, func(t *testing.T) { tc, err := genTestCase() - assert.FatalError(t, err) + require.NoError(t, err) opts := tc.auth.GetTLSOptions() - assert.Equals(t, opts, tc.opts) + sassert.Equals(t, opts, tc.opts) }) } } @@ -1429,11 +1431,11 @@ func TestAuthority_Revoke(t *testing.T) { now := time.Now().UTC() jwk, err := jose.ReadKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass"))) - assert.FatalError(t, err) + require.NoError(t, err) sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, (&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", jwk.KeyID)) - assert.FatalError(t, err) + require.NoError(t, err) a := testAuthority(t) @@ -1472,7 +1474,7 @@ func TestAuthority_Revoke(t *testing.T) { ID: "44", } raw, err := jose.Signed(sig).Claims(cl).CompactSerialize() - assert.FatalError(t, err) + require.NoError(t, err) return test{ auth: a, @@ -1486,9 +1488,9 @@ func TestAuthority_Revoke(t *testing.T) { err: errors.New("authority.Revoke; no persistence layer configured"), code: http.StatusNotImplemented, checkErrDetails: func(err *errs.Error) { - assert.Equals(t, err.Details["token"], raw) - assert.Equals(t, err.Details["tokenID"], "44") - assert.Equals(t, err.Details["provisionerID"], "step-cli:4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc") + sassert.Equals(t, err.Details["token"], raw) + sassert.Equals(t, err.Details["tokenID"], "44") + sassert.Equals(t, err.Details["provisionerID"], "step-cli:4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc") }, } }, @@ -1512,7 +1514,7 @@ func TestAuthority_Revoke(t *testing.T) { ID: "44", } raw, err := jose.Signed(sig).Claims(cl).CompactSerialize() - assert.FatalError(t, err) + require.NoError(t, err) return test{ auth: _a, @@ -1526,9 +1528,9 @@ func TestAuthority_Revoke(t *testing.T) { err: errors.New("authority.Revoke: force"), code: http.StatusInternalServerError, checkErrDetails: func(err *errs.Error) { - assert.Equals(t, err.Details["token"], raw) - assert.Equals(t, err.Details["tokenID"], "44") - assert.Equals(t, err.Details["provisionerID"], "step-cli:4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc") + sassert.Equals(t, err.Details["token"], raw) + sassert.Equals(t, err.Details["tokenID"], "44") + sassert.Equals(t, err.Details["provisionerID"], "step-cli:4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc") }, } }, @@ -1552,7 +1554,7 @@ func TestAuthority_Revoke(t *testing.T) { ID: "44", } raw, err := jose.Signed(sig).Claims(cl).CompactSerialize() - assert.FatalError(t, err) + require.NoError(t, err) return test{ auth: _a, @@ -1566,9 +1568,9 @@ func TestAuthority_Revoke(t *testing.T) { err: errors.New("certificate with serial number 'sn' is already revoked"), code: http.StatusBadRequest, checkErrDetails: func(err *errs.Error) { - assert.Equals(t, err.Details["token"], raw) - assert.Equals(t, err.Details["tokenID"], "44") - assert.Equals(t, err.Details["provisionerID"], "step-cli:4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc") + sassert.Equals(t, err.Details["token"], raw) + sassert.Equals(t, err.Details["tokenID"], "44") + sassert.Equals(t, err.Details["provisionerID"], "step-cli:4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc") }, } }, @@ -1591,7 +1593,7 @@ func TestAuthority_Revoke(t *testing.T) { ID: "44", } raw, err := jose.Signed(sig).Claims(cl).CompactSerialize() - assert.FatalError(t, err) + require.NoError(t, err) return test{ auth: _a, ctx: tlsRevokeCtx, @@ -1607,7 +1609,7 @@ func TestAuthority_Revoke(t *testing.T) { _a := testAuthority(t, WithDatabase(&db.MockAuthDB{})) crt, err := pemutil.ReadCertificate("./testdata/certs/foo.crt") - assert.FatalError(t, err) + require.NoError(t, err) return test{ auth: _a, @@ -1625,7 +1627,7 @@ func TestAuthority_Revoke(t *testing.T) { _a := testAuthority(t, WithDatabase(&db.MockAuthDB{})) crt, err := pemutil.ReadCertificate("./testdata/certs/foo.crt") - assert.FatalError(t, err) + require.NoError(t, err) // Filter out provisioner extension. for i, ext := range crt.Extensions { if ext.Id.Equal(asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1}) { @@ -1650,7 +1652,7 @@ func TestAuthority_Revoke(t *testing.T) { _a := testAuthority(t, WithDatabase(&db.MockAuthDB{})) crt, err := pemutil.ReadCertificate("./testdata/certs/foo.crt") - assert.FatalError(t, err) + require.NoError(t, err) return test{ auth: _a, @@ -1683,7 +1685,7 @@ func TestAuthority_Revoke(t *testing.T) { ID: "44", } raw, err := jose.Signed(sig).Claims(cl).CompactSerialize() - assert.FatalError(t, err) + require.NoError(t, err) return test{ auth: a, ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRevokeMethod), @@ -1702,17 +1704,17 @@ func TestAuthority_Revoke(t *testing.T) { if err := tc.auth.Revoke(tc.ctx, tc.opts); err != nil { if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) { var sc render.StatusCodedError - assert.Fatal(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") - assert.Equals(t, sc.StatusCode(), tc.code) - assert.HasPrefix(t, err.Error(), tc.err.Error()) + sassert.Fatal(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") + sassert.Equals(t, sc.StatusCode(), tc.code) + sassert.HasPrefix(t, err.Error(), tc.err.Error()) var ctxErr *errs.Error - assert.Fatal(t, errors.As(err, &ctxErr), "error is not of type *errs.Error") - assert.Equals(t, ctxErr.Details["serialNumber"], tc.opts.Serial) - assert.Equals(t, ctxErr.Details["reasonCode"], tc.opts.ReasonCode) - assert.Equals(t, ctxErr.Details["reason"], tc.opts.Reason) - assert.Equals(t, ctxErr.Details["MTLS"], tc.opts.MTLS) - assert.Equals(t, ctxErr.Details["context"], provisioner.RevokeMethod.String()) + sassert.Fatal(t, errors.As(err, &ctxErr), "error is not of type *errs.Error") + sassert.Equals(t, ctxErr.Details["serialNumber"], tc.opts.Serial) + sassert.Equals(t, ctxErr.Details["reasonCode"], tc.opts.ReasonCode) + sassert.Equals(t, ctxErr.Details["reason"], tc.opts.Reason) + sassert.Equals(t, ctxErr.Details["MTLS"], tc.opts.MTLS) + sassert.Equals(t, ctxErr.Details["context"], provisioner.RevokeMethod.String()) if tc.checkErrDetails != nil { tc.checkErrDetails(ctxErr) @@ -1814,13 +1816,11 @@ func TestAuthority_CRL(t *testing.T) { validIssuer := "step-cli" validAudience := testAudiences.Revoke now := time.Now().UTC() - // jwk, err := jose.ReadKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass"))) - assert.FatalError(t, err) - // + require.NoError(t, err) sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, (&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", jwk.KeyID)) - assert.FatalError(t, err) + require.NoError(t, err) crlCtx := provisioner.NewContextWithMethod(context.Background(), provisioner.RevokeMethod) @@ -1865,7 +1865,7 @@ func TestAuthority_CRL(t *testing.T) { auth: a, ctx: crlCtx, expected: nil, - err: database.ErrNotFound, + err: errors.New("authority.GetCertificateRevocationList: not found"), } }, "ok/crl-full": func() test { @@ -1910,7 +1910,7 @@ func TestAuthority_CRL(t *testing.T) { ID: sn, } raw, err := jose.Signed(sig).Claims(cl).CompactSerialize() - assert.FatalError(t, err) + require.NoError(t, err) err = a.Revoke(crlCtx, &RevokeOptions{ Serial: sn, ReasonCode: reasonCode, @@ -1918,7 +1918,7 @@ func TestAuthority_CRL(t *testing.T) { OTT: raw, }) - assert.FatalError(t, err) + require.NoError(t, err) ex = append(ex, sn) } @@ -1933,22 +1933,22 @@ func TestAuthority_CRL(t *testing.T) { for name, f := range tests { tc := f() t.Run(name, func(t *testing.T) { - if crlBytes, err := tc.auth.GetCertificateRevocationList(); err == nil { - crl, parseErr := x509.ParseRevocationList(crlBytes) - if parseErr != nil { - t.Errorf("x509.ParseCertificateRequest() error = %v, wantErr %v", parseErr, nil) - return - } - - var cmpList []string - for _, c := range crl.RevokedCertificates { - cmpList = append(cmpList, c.SerialNumber.String()) - } - - assert.Equals(t, cmpList, tc.expected) - } else { - assert.NotNil(t, tc.err, err.Error()) + crlInfo, err := tc.auth.GetCertificateRevocationList() + if tc.err != nil { + assert.EqualError(t, err, tc.err.Error()) + assert.Nil(t, crlInfo) + return } + + crl, parseErr := x509.ParseRevocationList(crlInfo.Data) + require.NoError(t, parseErr) + + var cmpList []string + for _, c := range crl.RevokedCertificateEntries { + cmpList = append(cmpList, c.SerialNumber.String()) + } + + assert.Equal(t, tc.expected, cmpList) }) } } From c76dad8a22b81482994b3599280607a5cb990c84 Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Thu, 8 Feb 2024 15:03:46 +0100 Subject: [PATCH 3/4] Improve tests for CRL HTTP handler --- api/api_test.go | 39 --------------------- api/crl.go | 13 ++++++- api/crl_test.go | 93 +++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 105 insertions(+), 40 deletions(-) create mode 100644 api/crl_test.go diff --git a/api/api_test.go b/api/api_test.go index a62b34e8..28944a1e 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -789,45 +789,6 @@ func (m *mockProvisioner) AuthorizeSSHRekey(ctx context.Context, token string) ( return m.ret1.(*ssh.Certificate), m.ret2.([]provisioner.SignOption), m.err } -func Test_CRLGeneration(t *testing.T) { - tests := []struct { - name string - err error - statusCode int - expected []byte - }{ - {"empty", nil, http.StatusOK, nil}, - } - - chiCtx := chi.NewRouteContext() - req := httptest.NewRequest("GET", "http://example.com/crl", http.NoBody) - req = req.WithContext(context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)) - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mockMustAuthority(t, &mockAuthority{ret1: tt.expected, err: tt.err}) - w := httptest.NewRecorder() - CRL(w, req) - res := w.Result() - - if res.StatusCode != tt.statusCode { - t.Errorf("caHandler.CRL StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) - } - - body, err := io.ReadAll(res.Body) - res.Body.Close() - if err != nil { - t.Errorf("caHandler.Root unexpected error = %v", err) - } - if tt.statusCode == 200 { - if !bytes.Equal(bytes.TrimSpace(body), tt.expected) { - t.Errorf("caHandler.Root CRL = %s, wants %s", body, tt.expected) - } - } - }) - } -} - func Test_caHandler_Route(t *testing.T) { type fields struct { Authority Authority diff --git a/api/crl.go b/api/crl.go index 7f12c6f8..a94056ad 100644 --- a/api/crl.go +++ b/api/crl.go @@ -6,6 +6,7 @@ import ( "time" "github.com/smallstep/certificates/api/render" + "github.com/smallstep/certificates/errs" ) // CRL is an HTTP handler that returns the current CRL in DER or PEM format @@ -16,7 +17,17 @@ func CRL(w http.ResponseWriter, r *http.Request) { return } - w.Header().Add("Expires", crlInfo.ExpiresAt.Format(time.RFC1123)) + if crlInfo == nil { + render.Error(w, errs.New(http.StatusInternalServerError, "no CRL available")) + return + } + + expires := crlInfo.ExpiresAt + if expires.IsZero() { + expires = time.Now() + } + + w.Header().Add("Expires", expires.Format(time.RFC1123)) _, formatAsPEM := r.URL.Query()["pem"] if formatAsPEM { diff --git a/api/crl_test.go b/api/crl_test.go new file mode 100644 index 00000000..c1c7a4b0 --- /dev/null +++ b/api/crl_test.go @@ -0,0 +1,93 @@ +package api + +import ( + "bytes" + "context" + "encoding/pem" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/go-chi/chi/v5" + "github.com/pkg/errors" + "github.com/smallstep/certificates/authority" + "github.com/smallstep/certificates/errs" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_CRL(t *testing.T) { + data := []byte{1, 2, 3, 4} + pemData := pem.EncodeToMemory(&pem.Block{ + Type: "X509 CRL", + Bytes: data, + }) + pemData = bytes.TrimSpace(pemData) + emptyPEMData := pem.EncodeToMemory(&pem.Block{ + Type: "X509 CRL", + Bytes: nil, + }) + emptyPEMData = bytes.TrimSpace(emptyPEMData) + tests := []struct { + name string + url string + err error + statusCode int + crlInfo *authority.CertificateRevocationListInfo + expectedBody []byte + expectedHeaders http.Header + expectedErrorJSON string + }{ + {"ok", "http://example.com/crl", nil, http.StatusOK, &authority.CertificateRevocationListInfo{Data: data}, data, http.Header{"Content-Type": []string{"application/pkix-crl"}, "Content-Disposition": []string{`attachment; filename="crl.der"`}}, ""}, + {"ok/pem", "http://example.com/crl?pem=true", nil, http.StatusOK, &authority.CertificateRevocationListInfo{Data: data}, pemData, http.Header{"Content-Type": []string{"application/x-pem-file"}, "Content-Disposition": []string{`attachment; filename="crl.pem"`}}, ""}, + {"ok/empty", "http://example.com/crl", nil, http.StatusOK, &authority.CertificateRevocationListInfo{Data: nil}, nil, http.Header{"Content-Type": []string{"application/pkix-crl"}, "Content-Disposition": []string{`attachment; filename="crl.der"`}}, ""}, + {"ok/empty-pem", "http://example.com/crl?pem=true", nil, http.StatusOK, &authority.CertificateRevocationListInfo{Data: nil}, emptyPEMData, http.Header{"Content-Type": []string{"application/x-pem-file"}, "Content-Disposition": []string{`attachment; filename="crl.pem"`}}, ""}, + {"fail/internal", "http://example.com/crl", errs.Wrap(http.StatusInternalServerError, errors.New("failure"), "authority.GetCertificateRevocationList"), http.StatusInternalServerError, nil, nil, http.Header{}, `{"status":500,"message":"The certificate authority encountered an Internal Server Error. Please see the certificate authority logs for more info."}`}, + {"fail/nil", "http://example.com/crl", nil, http.StatusInternalServerError, nil, nil, http.Header{}, `{"status":500,"message":"no CRL available"}`}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockMustAuthority(t, &mockAuthority{ret1: tt.crlInfo, err: tt.err}) + + chiCtx := chi.NewRouteContext() + req := httptest.NewRequest("GET", tt.url, http.NoBody) + req = req.WithContext(context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)) + w := httptest.NewRecorder() + CRL(w, req) + res := w.Result() + + assert.Equal(t, tt.statusCode, res.StatusCode) + + body, err := io.ReadAll(res.Body) + res.Body.Close() + require.NoError(t, err) + + if tt.statusCode >= 300 { + assert.JSONEq(t, tt.expectedErrorJSON, string(bytes.TrimSpace(body))) + return + } + + // check expected header values + for _, h := range []string{"content-type", "content-disposition"} { + v := tt.expectedHeaders.Get(h) + require.NotEmpty(t, v) + + actual := res.Header.Get(h) + assert.Equal(t, v, actual) + } + + // check expires header value + assert.NotEmpty(t, res.Header.Get("expires")) + t1, err := time.Parse(time.RFC1123, res.Header.Get("expires")) + if assert.NoError(t, err) { + assert.False(t, t1.IsZero()) + } + + // check body contents + assert.Equal(t, tt.expectedBody, bytes.TrimSpace(body)) + }) + } +} From 3dbb4aad3dc338f3d838f2c153019a84064c7eab Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Wed, 14 Feb 2024 10:49:18 +0100 Subject: [PATCH 4/4] Change CRL unavailable case to HTTP 404 --- api/crl.go | 2 +- api/crl_test.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/api/crl.go b/api/crl.go index a94056ad..c10d08ca 100644 --- a/api/crl.go +++ b/api/crl.go @@ -18,7 +18,7 @@ func CRL(w http.ResponseWriter, r *http.Request) { } if crlInfo == nil { - render.Error(w, errs.New(http.StatusInternalServerError, "no CRL available")) + render.Error(w, errs.New(http.StatusNotFound, "no CRL available")) return } diff --git a/api/crl_test.go b/api/crl_test.go index c1c7a4b0..5b194721 100644 --- a/api/crl_test.go +++ b/api/crl_test.go @@ -45,7 +45,7 @@ func Test_CRL(t *testing.T) { {"ok/empty", "http://example.com/crl", nil, http.StatusOK, &authority.CertificateRevocationListInfo{Data: nil}, nil, http.Header{"Content-Type": []string{"application/pkix-crl"}, "Content-Disposition": []string{`attachment; filename="crl.der"`}}, ""}, {"ok/empty-pem", "http://example.com/crl?pem=true", nil, http.StatusOK, &authority.CertificateRevocationListInfo{Data: nil}, emptyPEMData, http.Header{"Content-Type": []string{"application/x-pem-file"}, "Content-Disposition": []string{`attachment; filename="crl.pem"`}}, ""}, {"fail/internal", "http://example.com/crl", errs.Wrap(http.StatusInternalServerError, errors.New("failure"), "authority.GetCertificateRevocationList"), http.StatusInternalServerError, nil, nil, http.Header{}, `{"status":500,"message":"The certificate authority encountered an Internal Server Error. Please see the certificate authority logs for more info."}`}, - {"fail/nil", "http://example.com/crl", nil, http.StatusInternalServerError, nil, nil, http.Header{}, `{"status":500,"message":"no CRL available"}`}, + {"fail/nil", "http://example.com/crl", nil, http.StatusNotFound, nil, nil, http.Header{}, `{"status":404,"message":"no CRL available"}`}, } for _, tt := range tests {