smallstep-certificates/acme/common.go
2024-02-27 14:16:21 +01:00

203 lines
6.5 KiB
Go

package acme
import (
"context"
"crypto/x509"
"time"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner"
)
// Clock that returns time in UTC rounded to seconds.
type Clock struct{}
// Now returns the UTC time rounded to seconds.
func (c *Clock) Now() time.Time {
return time.Now().UTC().Truncate(time.Second)
}
var clock Clock
// CertificateAuthority is the interface implemented by a CA authority.
type CertificateAuthority interface {
SignWithContext(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
AreSANsAllowed(ctx context.Context, sans []string) error
IsRevoked(sn string) (bool, error)
Revoke(context.Context, *authority.RevokeOptions) error
LoadProvisionerByName(string) (provisioner.Interface, error)
}
// NewContext adds the given acme components to the context.
func NewContext(ctx context.Context, db DB, client Client, linker Linker, fn PrerequisitesChecker) context.Context {
ctx = NewDatabaseContext(ctx, db)
ctx = NewClientContext(ctx, client)
ctx = NewLinkerContext(ctx, linker)
// Prerequisite checker is optional.
if fn != nil {
ctx = NewPrerequisitesCheckerContext(ctx, fn)
}
return ctx
}
// PrerequisitesChecker is a function that checks if all prerequisites for
// serving ACME are met by the CA configuration.
type PrerequisitesChecker func(ctx context.Context) (bool, error)
// DefaultPrerequisitesChecker is the default PrerequisiteChecker and returns
// always true.
func DefaultPrerequisitesChecker(context.Context) (bool, error) {
return true, nil
}
type prerequisitesKey struct{}
// NewPrerequisitesCheckerContext adds the given PrerequisitesChecker to the
// context.
func NewPrerequisitesCheckerContext(ctx context.Context, fn PrerequisitesChecker) context.Context {
return context.WithValue(ctx, prerequisitesKey{}, fn)
}
// PrerequisitesCheckerFromContext returns the PrerequisitesChecker in the
// context.
func PrerequisitesCheckerFromContext(ctx context.Context) (PrerequisitesChecker, bool) {
fn, ok := ctx.Value(prerequisitesKey{}).(PrerequisitesChecker)
return fn, ok && fn != nil
}
// Provisioner is an interface that implements a subset of the provisioner.Interface --
// only those methods required by the ACME api/authority.
type Provisioner interface {
AuthorizeOrderIdentifier(ctx context.Context, identifier provisioner.ACMEIdentifier) error
AuthorizeSign(ctx context.Context, token string) ([]provisioner.SignOption, error)
AuthorizeRevoke(ctx context.Context, token string) error
IsChallengeEnabled(ctx context.Context, challenge provisioner.ACMEChallenge) bool
IsAttestationFormatEnabled(ctx context.Context, format provisioner.ACMEAttestationFormat) bool
GetAttestationRoots() (*x509.CertPool, bool)
GetID() string
GetName() string
DefaultTLSCertDuration() time.Duration
GetOptions() *provisioner.Options
}
type provisionerKey struct{}
// NewProvisionerContext adds the given provisioner to the context.
func NewProvisionerContext(ctx context.Context, v Provisioner) context.Context {
return context.WithValue(ctx, provisionerKey{}, v)
}
// ProvisionerFromContext returns the current provisioner from the given context.
func ProvisionerFromContext(ctx context.Context) (v Provisioner, ok bool) {
v, ok = ctx.Value(provisionerKey{}).(Provisioner)
return
}
// MustProvisionerFromContext returns the current provisioner from the given context.
// It will panic if it's not in the context.
func MustProvisionerFromContext(ctx context.Context) Provisioner {
var (
v Provisioner
ok bool
)
if v, ok = ProvisionerFromContext(ctx); !ok {
panic("acme provisioner is not the context")
}
return v
}
// MockProvisioner for testing
type MockProvisioner struct {
Mret1 interface{}
Merr error
MgetID func() string
MgetName func() string
MauthorizeOrderIdentifier func(ctx context.Context, identifier provisioner.ACMEIdentifier) error
MauthorizeSign func(ctx context.Context, ott string) ([]provisioner.SignOption, error)
MauthorizeRevoke func(ctx context.Context, token string) error
MisChallengeEnabled func(ctx context.Context, challenge provisioner.ACMEChallenge) bool
MisAttFormatEnabled func(ctx context.Context, format provisioner.ACMEAttestationFormat) bool
MgetAttestationRoots func() (*x509.CertPool, bool)
MdefaultTLSCertDuration func() time.Duration
MgetOptions func() *provisioner.Options
}
// GetName mock
func (m *MockProvisioner) GetName() string {
if m.MgetName != nil {
return m.MgetName()
}
return m.Mret1.(string)
}
// AuthorizeOrderIdentifiers mock
func (m *MockProvisioner) AuthorizeOrderIdentifier(ctx context.Context, identifier provisioner.ACMEIdentifier) error {
if m.MauthorizeOrderIdentifier != nil {
return m.MauthorizeOrderIdentifier(ctx, identifier)
}
return m.Merr
}
// AuthorizeSign mock
func (m *MockProvisioner) AuthorizeSign(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
if m.MauthorizeSign != nil {
return m.MauthorizeSign(ctx, ott)
}
return m.Mret1.([]provisioner.SignOption), m.Merr
}
// AuthorizeRevoke mock
func (m *MockProvisioner) AuthorizeRevoke(ctx context.Context, token string) error {
if m.MauthorizeRevoke != nil {
return m.MauthorizeRevoke(ctx, token)
}
return m.Merr
}
// IsChallengeEnabled mock
func (m *MockProvisioner) IsChallengeEnabled(ctx context.Context, challenge provisioner.ACMEChallenge) bool {
if m.MisChallengeEnabled != nil {
return m.MisChallengeEnabled(ctx, challenge)
}
return m.Merr == nil
}
// IsAttestationFormatEnabled mock
func (m *MockProvisioner) IsAttestationFormatEnabled(ctx context.Context, format provisioner.ACMEAttestationFormat) bool {
if m.MisAttFormatEnabled != nil {
return m.MisAttFormatEnabled(ctx, format)
}
return m.Merr == nil
}
func (m *MockProvisioner) GetAttestationRoots() (*x509.CertPool, bool) {
if m.MgetAttestationRoots != nil {
return m.MgetAttestationRoots()
}
return m.Mret1.(*x509.CertPool), m.Mret1 != nil
}
// DefaultTLSCertDuration mock
func (m *MockProvisioner) DefaultTLSCertDuration() time.Duration {
if m.MdefaultTLSCertDuration != nil {
return m.MdefaultTLSCertDuration()
}
return m.Mret1.(time.Duration)
}
// GetOptions mock
func (m *MockProvisioner) GetOptions() *provisioner.Options {
if m.MgetOptions != nil {
return m.MgetOptions()
}
return m.Mret1.(*provisioner.Options)
}
// GetID mock
func (m *MockProvisioner) GetID() string {
if m.MgetID != nil {
return m.MgetID()
}
return m.Mret1.(string)
}