// Package vaultcas implements a certificate authority service backed by
// Hashicorp Vault.
package vaultcas

import (
	"bytes"
	"context"
	"crypto/sha256"
	"crypto/x509"
	"encoding/hex"
	"encoding/json"
	"encoding/pem"
	"errors"
	"fmt"
	"math/big"
	"strings"
	"time"

	"github.com/smallstep/certificates/cas/apiv1"

	vault "github.com/hashicorp/vault/api"
	auth "github.com/hashicorp/vault/api/auth/approle"
)

// init registers the Vault CAS constructor under the apiv1.VaultCAS type so
// that it can be instantiated through the apiv1 registry.
func init() {
	apiv1.Register(apiv1.VaultCAS, func(ctx context.Context, opts apiv1.Options) (apiv1.CertificateAuthorityService, error) {
		return New(ctx, opts)
	})
}

// VaultOptions defines the configuration options added using the
// apiv1.Options.Config field.
type VaultOptions struct {
	// PKI is the mount path of the Vault PKI engine; loadOptions defaults it
	// to "pki" when empty.
	PKI string `json:"pki,omitempty"`
	// PKIRoleDefault is the fallback PKI role; loadOptions defaults it to
	// "default" when empty.
	PKIRoleDefault string `json:"pkiRoleDefault,omitempty"`
	// PKIRoleRSA is the role used to sign CSRs with RSA keys; it falls back to
	// PKIRoleDefault when empty.
	PKIRoleRSA string `json:"pkiRoleRSA,omitempty"`
	// PKIRoleEC is the role used to sign CSRs with ECDSA keys; it falls back
	// to PKIRoleDefault when empty.
	PKIRoleEC string `json:"pkiRoleEC,omitempty"`
	// PKIRoleEd25519 is the role used to sign CSRs with Ed25519 keys; it falls
	// back to PKIRoleDefault when empty.
	PKIRoleEd25519 string `json:"pkiRoleEd25519,omitempty"`
	// RoleID is the AppRole role identifier; loadOptions rejects an empty
	// value.
	RoleID string `json:"roleID,omitempty"`
	// SecretID is the AppRole secret; loadOptions requires one of FromEnv,
	// FromFile or FromString to be set.
	SecretID auth.SecretID `json:"secretID,omitempty"`
	// AppRole is the mount path of the AppRole auth method; loadOptions
	// defaults it to "auth/approle" when empty.
	AppRole string `json:"appRole,omitempty"`
	// IsWrappingToken makes New add the auth.WithWrappingToken login option.
	IsWrappingToken bool `json:"isWrappingToken,omitempty"`
}

// VaultCAS implements a Certificate Authority Service using Hashicorp Vault.
type VaultCAS struct { client *vault.Client config VaultOptions fingerprint string } type certBundle struct { leaf *x509.Certificate intermediates []*x509.Certificate root *x509.Certificate } // New creates a new CertificateAuthorityService implementation // using Hashicorp Vault func New(ctx context.Context, opts apiv1.Options) (*VaultCAS, error) { if opts.CertificateAuthority == "" { return nil, errors.New("vaultCAS 'certificateAuthority' cannot be empty") } if opts.CertificateAuthorityFingerprint == "" { return nil, errors.New("vaultCAS 'certificateAuthorityFingerprint' cannot be empty") } vc, err := loadOptions(opts.Config) if err != nil { return nil, err } config := vault.DefaultConfig() config.Address = opts.CertificateAuthority client, err := vault.NewClient(config) if err != nil { return nil, fmt.Errorf("unable to initialize vault client: %w", err) } var appRoleAuth *auth.AppRoleAuth if vc.IsWrappingToken { appRoleAuth, err = auth.NewAppRoleAuth( vc.RoleID, &vc.SecretID, auth.WithWrappingToken(), auth.WithMountPath(vc.AppRole), ) } else { appRoleAuth, err = auth.NewAppRoleAuth( vc.RoleID, &vc.SecretID, auth.WithMountPath(vc.AppRole), ) } if err != nil { return nil, fmt.Errorf("unable to initialize AppRole auth method: %w", err) } authInfo, err := client.Auth().Login(ctx, appRoleAuth) if err != nil { return nil, fmt.Errorf("unable to login to AppRole auth method: %w", err) } if authInfo == nil { return nil, errors.New("no auth info was returned after login") } return &VaultCAS{ client: client, config: *vc, fingerprint: opts.CertificateAuthorityFingerprint, }, nil } // CreateCertificate signs a new certificate using Hashicorp Vault. 
func (v *VaultCAS) CreateCertificate(req *apiv1.CreateCertificateRequest) (*apiv1.CreateCertificateResponse, error) { switch { case req.CSR == nil: return nil, errors.New("createCertificate `csr` cannot be nil") case req.Lifetime == 0: return nil, errors.New("createCertificate `lifetime` cannot be 0") } cert, chain, err := v.createCertificate(req.CSR, req.Lifetime) if err != nil { return nil, err } return &apiv1.CreateCertificateResponse{ Certificate: cert, CertificateChain: chain, }, nil } // GetCertificateAuthority returns the root certificate of the certificate // authority using the configured fingerprint. func (v *VaultCAS) GetCertificateAuthority(req *apiv1.GetCertificateAuthorityRequest) (*apiv1.GetCertificateAuthorityResponse, error) { secret, err := v.client.Logical().Read(v.config.PKI + "/cert/ca_chain") if err != nil { return nil, fmt.Errorf("error reading ca chain: %w", err) } if secret == nil { return nil, errors.New("error reading ca chain: response is empty") } chain, ok := secret.Data["certificate"].(string) if !ok { return nil, errors.New("error unmarshaling vault response: certificate not found") } cert, err := getCertificateBundle(chain) if err != nil { return nil, err } if cert.root == nil { return nil, errors.New("error unmarshaling vault response: root certificate not found") } sum := sha256.Sum256(cert.root.Raw) if !strings.EqualFold(v.fingerprint, strings.ToLower(hex.EncodeToString(sum[:]))) { return nil, errors.New("error verifying vault root: fingerprint does not match") } return &apiv1.GetCertificateAuthorityResponse{ RootCertificate: cert.root, }, nil } // RenewCertificate will always return a non-implemented error as renewals // are not supported yet. func (v *VaultCAS) RenewCertificate(req *apiv1.RenewCertificateRequest) (*apiv1.RenewCertificateResponse, error) { return nil, apiv1.ErrNotImplemented{Message: "vaultCAS does not support renewals"} } // RevokeCertificate revokes a certificate by serial number. 
func (v *VaultCAS) RevokeCertificate(req *apiv1.RevokeCertificateRequest) (*apiv1.RevokeCertificateResponse, error) { if req.SerialNumber == "" && req.Certificate == nil { return nil, errors.New("revokeCertificate `serialNumber` or `certificate` are required") } var sn *big.Int if req.SerialNumber != "" { var ok bool if sn, ok = new(big.Int).SetString(req.SerialNumber, 10); !ok { return nil, fmt.Errorf("error parsing serialNumber: %v cannot be converted to big.Int", req.SerialNumber) } } else { sn = req.Certificate.SerialNumber } vaultReq := map[string]interface{}{ "serial_number": formatSerialNumber(sn), } _, err := v.client.Logical().Write(v.config.PKI+"/revoke/", vaultReq) if err != nil { return nil, fmt.Errorf("error revoking certificate: %w", err) } return &apiv1.RevokeCertificateResponse{ Certificate: req.Certificate, CertificateChain: nil, }, nil } func (v *VaultCAS) createCertificate(cr *x509.CertificateRequest, lifetime time.Duration) (*x509.Certificate, []*x509.Certificate, error) { var vaultPKIRole string switch { case cr.PublicKeyAlgorithm == x509.RSA: vaultPKIRole = v.config.PKIRoleRSA case cr.PublicKeyAlgorithm == x509.ECDSA: vaultPKIRole = v.config.PKIRoleEC case cr.PublicKeyAlgorithm == x509.Ed25519: vaultPKIRole = v.config.PKIRoleEd25519 default: return nil, nil, fmt.Errorf("unsupported public key algorithm %v", cr.PublicKeyAlgorithm) } vaultReq := map[string]interface{}{ "csr": string(pem.EncodeToMemory(&pem.Block{ Type: "CERTIFICATE REQUEST", Bytes: cr.Raw, })), "format": "pem_bundle", "ttl": lifetime.Seconds(), } secret, err := v.client.Logical().Write(v.config.PKI+"/sign/"+vaultPKIRole, vaultReq) if err != nil { return nil, nil, fmt.Errorf("error signing certificate: %w", err) } if secret == nil { return nil, nil, errors.New("error signing certificate: response is empty") } chain, ok := secret.Data["certificate"].(string) if !ok { return nil, nil, errors.New("error unmarshaling vault response: certificate not found") } cert, err := 
getCertificateBundle(chain) if err != nil { return nil, nil, err } // Return certificate and certificate chain return cert.leaf, cert.intermediates, nil } func loadOptions(config json.RawMessage) (*VaultOptions, error) { var vc *VaultOptions err := json.Unmarshal(config, &vc) if err != nil { return nil, fmt.Errorf("error decoding vaultCAS config: %w", err) } if vc.PKI == "" { vc.PKI = "pki" // use default pki vault name } if vc.PKIRoleDefault == "" { vc.PKIRoleDefault = "default" // use default pki role name } if vc.PKIRoleRSA == "" { vc.PKIRoleRSA = vc.PKIRoleDefault } if vc.PKIRoleEC == "" { vc.PKIRoleEC = vc.PKIRoleDefault } if vc.PKIRoleEd25519 == "" { vc.PKIRoleEd25519 = vc.PKIRoleDefault } if vc.RoleID == "" { return nil, errors.New("vaultCAS config options must define `roleID`") } if vc.SecretID.FromEnv == "" && vc.SecretID.FromFile == "" && vc.SecretID.FromString == "" { return nil, errors.New("vaultCAS config options must define `secretID` object with one of `FromEnv`, `FromFile` or `FromString`") } if vc.PKI == "" { vc.PKI = "pki" // use default pki vault name } if vc.AppRole == "" { vc.AppRole = "auth/approle" } return vc, nil } func parseCertificates(pemCert string) []*x509.Certificate { var certs []*x509.Certificate rest := []byte(pemCert) var block *pem.Block for { block, rest = pem.Decode(rest) if block == nil { break } cert, err := x509.ParseCertificate(block.Bytes) if err != nil { break } certs = append(certs, cert) } return certs } func getCertificateBundle(chain string) (*certBundle, error) { var root *x509.Certificate var leaf *x509.Certificate var intermediates []*x509.Certificate for _, cert := range parseCertificates(chain) { switch { case isRoot(cert): root = cert case cert.BasicConstraintsValid && cert.IsCA: intermediates = append(intermediates, cert) default: leaf = cert } } certificate := &certBundle{ root: root, leaf: leaf, intermediates: intermediates, } return certificate, nil } // isRoot returns true if the given certificate is a root 
// certificate.
//
// A certificate counts as a root when its basic constraints are valid, it is
// marked as a CA, and its signature verifies against its own public key.
func isRoot(cert *x509.Certificate) bool {
	if !cert.BasicConstraintsValid || !cert.IsCA {
		return false
	}
	return cert.CheckSignatureFrom(cert) == nil
}

// formatSerialNumber renders the big-endian bytes of sn as two-digit
// lowercase hexadecimal octets joined by dashes, e.g. "01-02-ab". A zero
// serial number has no bytes and yields the empty string.
func formatSerialNumber(sn *big.Int) string {
	octets := sn.Bytes()

	var out bytes.Buffer
	for i := range octets {
		if i > 0 {
			out.WriteString("-")
		}
		out.WriteString(hex.EncodeToString(octets[i : i+1]))
	}
	return out.String()
}