diff --git a/api/api.go b/api/api.go index fa7602e1..f20df474 100644 --- a/api/api.go +++ b/api/api.go @@ -46,7 +46,6 @@ type Authority interface { GetRoots() (federation []*x509.Certificate, err error) GetFederation() ([]*x509.Certificate, error) Version() authority.Version - GenerateCertificateRevocationList() error GetCertificateRevocationList() ([]byte, error) } diff --git a/api/api_test.go b/api/api_test.go index aef6db77..7ce54d73 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -580,7 +580,7 @@ type mockAuthority struct { version func() authority.Version } -func (m *mockAuthority) GenerateCertificateRevocationList(force bool) ([]byte, error) { +func (m *mockAuthority) GetCertificateRevocationList() ([]byte, error) { panic("implement me") } diff --git a/api/crl.go b/api/crl.go index c90a77f9..45024470 100644 --- a/api/crl.go +++ b/api/crl.go @@ -4,6 +4,7 @@ import ( "encoding/pem" "fmt" "github.com/pkg/errors" + "github.com/smallstep/certificates/errs" "net/http" ) @@ -14,6 +15,14 @@ func (h *caHandler) CRL(w http.ResponseWriter, r *http.Request) { _, formatAsPEM := r.URL.Query()["pem"] if err != nil { + + caErr, isCaErr := err.(*errs.Error) + + if isCaErr { + http.Error(w, caErr.Msg, caErr.Status) + return + } + w.WriteHeader(500) _, err = fmt.Fprintf(w, "%v\n", err) if err != nil { @@ -22,15 +31,6 @@ func (h *caHandler) CRL(w http.ResponseWriter, r *http.Request) { return } - if crlBytes == nil { - w.WriteHeader(404) - _, err = fmt.Fprintln(w, "No CRL available") - if err != nil { - panic(errors.Wrap(err, "error writing http response")) - } - return - } - if formatAsPEM { pemBytes := pem.EncodeToMemory(&pem.Block{ Type: "X509 CRL", diff --git a/authority/authority.go b/authority/authority.go index 5f52d04e..7414a0d4 100644 --- a/authority/authority.go +++ b/authority/authority.go @@ -66,7 +66,7 @@ type Authority struct { sshCAHostFederatedCerts []ssh.PublicKey // CRL vars - crlChannel chan int + crlTicker *time.Ticker // Do not re-initialize initOnce bool @@ -586,6 +586,10 @@ func (a *Authority) IsAdminAPIEnabled() bool { // Shutdown safely shuts down any clients, databases, etc. held by the Authority. func (a *Authority) Shutdown() error { + if a.crlTicker != nil { + a.crlTicker.Stop() + } + if err := a.keyManager.Close(); err != nil { log.Printf("error closing the key manager: %v", err) } @@ -594,6 +598,11 @@ func (a *Authority) Shutdown() error { // CloseForReload closes internal services, to allow a safe reload. func (a *Authority) CloseForReload() { + + if a.crlTicker != nil { + a.crlTicker.Stop() + } + if err := a.keyManager.Close(); err != nil { log.Printf("error closing the key manager: %v", err) } @@ -655,12 +664,12 @@ func (a *Authority) startCRLGenerator() error { if tickerDuration <= 0 { panic(fmt.Sprintf("ERROR: Addition of jitter to CRL generation time %v creates a negative duration (%v). Use a CRL generation time of longer than 1 minute.", a.config.CRL.CacheDuration, tickerDuration)) } - crlTicker := time.NewTicker(tickerDuration) + a.crlTicker = time.NewTicker(tickerDuration) go func() { for { select { - case <-crlTicker.C: + case <-a.crlTicker.C: log.Println("Regenerating CRL") err := a.GenerateCertificateRevocationList() if err != nil { diff --git a/authority/tls.go b/authority/tls.go index ba1c4939..44da6b3a 100644 --- a/authority/tls.go +++ b/authority/tls.go @@ -365,13 +365,15 @@ func (a *Authority) Revoke(ctx context.Context, revokeOpts *RevokeOptions) error err error ) - // Attempt to get the certificate expiry using the serial number. - cert, err := a.db.GetCertificate(revokeOpts.Serial) + if revokeOpts.Crt == nil { + // Attempt to get the certificate expiry using the serial number. + cert, err := a.db.GetCertificate(revokeOpts.Serial) - // Revocation of a certificate not in the database may be requested, so fill in the expiry only - // if we can - if err == nil { - rci.ExpiresAt = cert.NotAfter + // Revocation of a certificate not in the database may be requested, so fill in the expiry only + // if we can + if err == nil { + rci.ExpiresAt = cert.NotAfter + } } // If not mTLS then get the TokenID of the token. diff --git a/db/db.go b/db/db.go index 21cea901..883f54ed 100644 --- a/db/db.go +++ b/db/db.go @@ -215,13 +215,15 @@ func (db *DB) GetRevokedCertificates() (*[]RevokedCertificateInfo, error) { return nil, err } var revokedCerts []RevokedCertificateInfo + now := time.Now().UTC() + for _, e := range entries { var data RevokedCertificateInfo if err := json.Unmarshal(e.Value, &data); err != nil { return nil, err } - if !data.ExpiresAt.IsZero() && data.ExpiresAt.After(time.Now().UTC()) { + if !data.ExpiresAt.IsZero() && data.ExpiresAt.After(now) { revokedCerts = append(revokedCerts, data) } else if data.ExpiresAt.IsZero() { cert, err := db.GetCertificate(data.Serial) @@ -232,7 +234,7 @@ func (db *DB) GetRevokedCertificates() (*[]RevokedCertificateInfo, error) { continue } - if cert.NotAfter.After(time.Now().UTC()) { + if cert.NotAfter.After(now) { revokedCerts = append(revokedCerts, data) } }