// Mirror of https://github.com/smallstep/certificates.git
// (synced 2024-10-31 03:20:16 +00:00); 192 lines, 5.4 KiB, Go.
package ca
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"math/rand"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/pkg/errors"
|
|
)
|
|
|
|
// RenewFunc is the callback a TLSRenewer invokes whenever it needs a fresh
// TLS certificate. It returns the replacement certificate, or an error when
// no certificate could be obtained.
type RenewFunc func() (cert *tls.Certificate, err error)
|
|
|
|
var minCertDuration = time.Minute
|
|
|
|
// TLSRenewer automatically renews a tls certificate using a RenewFunc.
|
|
type TLSRenewer struct {
|
|
renewMutex sync.RWMutex
|
|
RenewCertificate RenewFunc
|
|
cert *tls.Certificate
|
|
timer *time.Timer
|
|
renewBefore time.Duration
|
|
renewJitter time.Duration
|
|
certNotAfter time.Time
|
|
}
|
|
|
|
// tlsRenewerOptions configures a TLSRenewer inside NewTLSRenewer; a non-nil
// error aborts construction.
type tlsRenewerOptions func(*TLSRenewer) error
|
|
|
|
// WithRenewBefore modifies a tlsRenewer by setting the renewBefore attribute.
|
|
func WithRenewBefore(b time.Duration) func(r *TLSRenewer) error {
|
|
return func(r *TLSRenewer) error {
|
|
r.renewBefore = b
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// WithRenewJitter modifies a tlsRenewer by setting the renewJitter attribute.
|
|
func WithRenewJitter(j time.Duration) func(r *TLSRenewer) error {
|
|
return func(r *TLSRenewer) error {
|
|
r.renewJitter = j
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// NewTLSRenewer creates a TLSRenewer for the given cert. It will use the given
|
|
// RenewFunc to get a new certificate when required.
|
|
func NewTLSRenewer(cert *tls.Certificate, fn RenewFunc, opts ...tlsRenewerOptions) (*TLSRenewer, error) {
|
|
r := &TLSRenewer{
|
|
RenewCertificate: fn,
|
|
cert: cert,
|
|
certNotAfter: cert.Leaf.NotAfter.Add(-1 * time.Minute),
|
|
}
|
|
|
|
for _, f := range opts {
|
|
if err := f(r); err != nil {
|
|
return nil, errors.Wrap(err, "error applying options")
|
|
}
|
|
}
|
|
|
|
period := cert.Leaf.NotAfter.Sub(cert.Leaf.NotBefore)
|
|
if period < minCertDuration {
|
|
return nil, errors.Errorf("period must be greater than or equal to %s, but got %v.", minCertDuration, period)
|
|
}
|
|
// By default we will try to renew the cert before 2/3 of the validity
|
|
// period have expired.
|
|
if r.renewBefore == 0 {
|
|
r.renewBefore = period / 3
|
|
}
|
|
// By default we set the jitter to 1/20th of the validity period.
|
|
if r.renewJitter == 0 {
|
|
r.renewJitter = period / 20
|
|
}
|
|
|
|
return r, nil
|
|
}
|
|
|
|
// Run starts the certificate renewer for the given certificate.
|
|
func (r *TLSRenewer) Run() {
|
|
cert := r.getCertificate()
|
|
next := r.nextRenewDuration(cert.Leaf.NotAfter)
|
|
r.renewMutex.Lock()
|
|
r.timer = time.AfterFunc(next, r.renewCertificate)
|
|
r.renewMutex.Unlock()
|
|
}
|
|
|
|
// RunContext starts the certificate renewer for the given certificate.
|
|
func (r *TLSRenewer) RunContext(ctx context.Context) {
|
|
r.Run()
|
|
go func() {
|
|
<-ctx.Done()
|
|
r.Stop()
|
|
}()
|
|
}
|
|
|
|
// Stop prevents the renew timer from firing.
|
|
func (r *TLSRenewer) Stop() bool {
|
|
if r.timer != nil {
|
|
return r.timer.Stop()
|
|
}
|
|
return true
|
|
}
|
|
|
|
// GetCertificate returns the current server certificate.
|
|
//
|
|
// This method is set in the tls.Config GetCertificate property.
|
|
func (r *TLSRenewer) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
|
return r.getCertificate(), nil
|
|
}
|
|
|
|
// GetCertificateForCA returns the current server certificate. It can only be
|
|
// used if the renew function creates the new certificate and do not uses a TLS
|
|
// request. It's intended to be use by the certificate authority server.
|
|
//
|
|
// This method is set in the tls.Config GetCertificate property.
|
|
func (r *TLSRenewer) GetCertificateForCA(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
|
return r.getCertificateForCA(), nil
|
|
}
|
|
|
|
// GetClientCertificate returns the current client certificate.
|
|
//
|
|
// This method is set in the tls.Config GetClientCertificate property.
|
|
func (r *TLSRenewer) GetClientCertificate(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
|
return r.getCertificate(), nil
|
|
}
|
|
|
|
// getCertificate returns the certificate using a read-only lock.
|
|
//
|
|
// Known issue: It cannot renew an expired certificate because the /renew
|
|
// endpoint requires a valid client certificate. The certificate can expire
|
|
// if the timer does not fire e.g. when the CA is run from a laptop that
|
|
// enters sleep mode.
|
|
func (r *TLSRenewer) getCertificate() *tls.Certificate {
|
|
r.renewMutex.RLock()
|
|
cert := r.cert
|
|
r.renewMutex.RUnlock()
|
|
return cert
|
|
}
|
|
|
|
// getCertificateForCA returns the certificate using a read-only lock. It will
|
|
// automatically renew the certificate if it has expired.
|
|
func (r *TLSRenewer) getCertificateForCA() *tls.Certificate {
|
|
r.renewMutex.RLock()
|
|
// Force certificate renewal if the timer didn't run.
|
|
// This is an special case that can happen after a computer sleep.
|
|
if time.Now().After(r.certNotAfter) {
|
|
r.renewMutex.RUnlock()
|
|
r.renewCertificate()
|
|
r.renewMutex.RLock()
|
|
}
|
|
cert := r.cert
|
|
r.renewMutex.RUnlock()
|
|
return cert
|
|
}
|
|
|
|
// setCertificate updates the certificate using a read-write lock. It also
|
|
// updates certNotAfter with 1m of delta; this will force the renewal of the
|
|
// certificate if it is about to expire.
|
|
func (r *TLSRenewer) setCertificate(cert *tls.Certificate) {
|
|
r.renewMutex.Lock()
|
|
r.cert = cert
|
|
r.certNotAfter = cert.Leaf.NotAfter.Add(-1 * time.Minute)
|
|
r.renewMutex.Unlock()
|
|
}
|
|
|
|
func (r *TLSRenewer) renewCertificate() {
|
|
var next time.Duration
|
|
cert, err := r.RenewCertificate()
|
|
if err != nil {
|
|
next = r.renewJitter / 2
|
|
next += time.Duration(rand.Int63n(int64(next)))
|
|
} else {
|
|
r.setCertificate(cert)
|
|
next = r.nextRenewDuration(cert.Leaf.NotAfter)
|
|
}
|
|
r.renewMutex.Lock()
|
|
r.timer.Reset(next)
|
|
r.renewMutex.Unlock()
|
|
}
|
|
|
|
func (r *TLSRenewer) nextRenewDuration(notAfter time.Time) time.Duration {
|
|
d := time.Until(notAfter) - r.renewBefore
|
|
n := rand.Int63n(int64(r.renewJitter))
|
|
d -= time.Duration(n)
|
|
if d < 0 {
|
|
d = 0
|
|
}
|
|
return d
|
|
}
|