package vaultcas import ( "bytes" "context" "crypto/sha256" "crypto/x509" "encoding/hex" "encoding/json" "encoding/pem" "errors" "fmt" "math/big" "strings" "time" "github.com/smallstep/certificates/cas/apiv1" "github.com/smallstep/certificates/cas/vaultcas/auth/approle" "github.com/smallstep/certificates/cas/vaultcas/auth/kubernetes" vault "github.com/hashicorp/vault/api" ) func init() { apiv1.Register(apiv1.VaultCAS, func(ctx context.Context, opts apiv1.Options) (apiv1.CertificateAuthorityService, error) { return New(ctx, opts) }) } // VaultOptions defines the configuration options added using the // apiv1.Options.Config field. type VaultOptions struct { PKIMountPath string `json:"pkiMountPath,omitempty"` PKIRoleDefault string `json:"pkiRoleDefault,omitempty"` PKIRoleRSA string `json:"pkiRoleRSA,omitempty"` PKIRoleEC string `json:"pkiRoleEC,omitempty"` PKIRoleEd25519 string `json:"pkiRoleEd25519,omitempty"` AuthType string `json:"authType,omitempty"` AuthMountPath string `json:"authMountPath,omitempty"` Namespace string `json:"namespace,omitempty"` AuthOptions json.RawMessage `json:"authOptions,omitempty"` } // VaultCAS implements a Certificate Authority Service using Hashicorp Vault. type VaultCAS struct { client *vault.Client config VaultOptions fingerprint string } type certBundle struct { leaf *x509.Certificate intermediates []*x509.Certificate root *x509.Certificate } // New creates a new CertificateAuthorityService implementation // using Hashicorp Vault func New(ctx context.Context, opts apiv1.Options) (*VaultCAS, error) { if opts.CertificateAuthority == "" { return nil, errors.New("vaultCAS 'certificateAuthority' cannot be empty") } if opts.CertificateAuthorityFingerprint == "" { return nil, errors.New("vaultCAS 'certificateAuthorityFingerprint' cannot be empty") } vc, err := loadOptions(opts.Config) if err != nil { return nil, err } config := vault.DefaultConfig() config.Address = opts.CertificateAuthority client, err := vault.NewClient(config) if err != nil { return nil, fmt.Errorf("unable to initialize vault client: %w", err) } var method vault.AuthMethod switch vc.AuthType { case "kubernetes": method, err = kubernetes.NewKubernetesAuthMethod(vc.AuthMountPath, vc.AuthOptions) case "approle": method, err = approle.NewApproleAuthMethod(vc.AuthMountPath, vc.AuthOptions) default: return nil, fmt.Errorf("unknown auth type: %s, only 'kubernetes' and 'approle' currently supported", vc.AuthType) } if err != nil { return nil, fmt.Errorf("unable to configure %s auth method: %w", vc.AuthType, err) } if vc.Namespace != "" { client.SetNamespace(vc.Namespace) } authInfo, err := client.Auth().Login(ctx, method) if err != nil { return nil, fmt.Errorf("unable to login to %s auth method: %w", vc.AuthType, err) } if authInfo == nil { return nil, errors.New("no auth info was returned after login") } return &VaultCAS{ client: client, config: *vc, fingerprint: opts.CertificateAuthorityFingerprint, }, nil } // CreateCertificate signs a new certificate using Hashicorp Vault. func (v *VaultCAS) CreateCertificate(req *apiv1.CreateCertificateRequest) (*apiv1.CreateCertificateResponse, error) { switch { case req.CSR == nil: return nil, errors.New("createCertificate `csr` cannot be nil") case req.Lifetime == 0: return nil, errors.New("createCertificate `lifetime` cannot be 0") } cert, chain, err := v.createCertificate(req.CSR, req.Lifetime) if err != nil { return nil, err } return &apiv1.CreateCertificateResponse{ Certificate: cert, CertificateChain: chain, }, nil } // GetCertificateAuthority returns the root certificate of the certificate // authority using the configured fingerprint. func (v *VaultCAS) GetCertificateAuthority(*apiv1.GetCertificateAuthorityRequest) (*apiv1.GetCertificateAuthorityResponse, error) { secret, err := v.client.Logical().Read(v.config.PKIMountPath + "/cert/ca_chain") if err != nil { return nil, fmt.Errorf("error reading ca chain: %w", err) } if secret == nil { return nil, errors.New("error reading ca chain: response is empty") } chain, ok := secret.Data["certificate"].(string) if !ok { return nil, errors.New("error unmarshaling vault response: certificate not found") } cert, err := getCertificateBundle(chain) if err != nil { return nil, err } if cert.root == nil { return nil, errors.New("error unmarshaling vault response: root certificate not found") } sum := sha256.Sum256(cert.root.Raw) if !strings.EqualFold(v.fingerprint, strings.ToLower(hex.EncodeToString(sum[:]))) { return nil, errors.New("error verifying vault root: fingerprint does not match") } return &apiv1.GetCertificateAuthorityResponse{ RootCertificate: cert.root, }, nil } // RenewCertificate will always return a non-implemented error as renewals // are not supported yet. func (v *VaultCAS) RenewCertificate(*apiv1.RenewCertificateRequest) (*apiv1.RenewCertificateResponse, error) { return nil, apiv1.NotImplementedError{Message: "vaultCAS does not support renewals"} } // RevokeCertificate revokes a certificate by serial number. func (v *VaultCAS) RevokeCertificate(req *apiv1.RevokeCertificateRequest) (*apiv1.RevokeCertificateResponse, error) { if req.SerialNumber == "" && req.Certificate == nil { return nil, errors.New("revokeCertificate `serialNumber` or `certificate` are required") } var sn *big.Int if req.SerialNumber != "" { var ok bool if sn, ok = new(big.Int).SetString(req.SerialNumber, 10); !ok { return nil, fmt.Errorf("error parsing serialNumber: %v cannot be converted to big.Int", req.SerialNumber) } } else { sn = req.Certificate.SerialNumber } vaultReq := map[string]interface{}{ "serial_number": formatSerialNumber(sn), } _, err := v.client.Logical().Write(v.config.PKIMountPath+"/revoke/", vaultReq) if err != nil { return nil, fmt.Errorf("error revoking certificate: %w", err) } return &apiv1.RevokeCertificateResponse{ Certificate: req.Certificate, CertificateChain: nil, }, nil } func (v *VaultCAS) createCertificate(cr *x509.CertificateRequest, lifetime time.Duration) (*x509.Certificate, []*x509.Certificate, error) { var vaultPKIRole string switch { case cr.PublicKeyAlgorithm == x509.RSA: vaultPKIRole = v.config.PKIRoleRSA case cr.PublicKeyAlgorithm == x509.ECDSA: vaultPKIRole = v.config.PKIRoleEC case cr.PublicKeyAlgorithm == x509.Ed25519: vaultPKIRole = v.config.PKIRoleEd25519 default: return nil, nil, fmt.Errorf("unsupported public key algorithm %v", cr.PublicKeyAlgorithm) } vaultReq := map[string]interface{}{ "csr": string(pem.EncodeToMemory(&pem.Block{ Type: "CERTIFICATE REQUEST", Bytes: cr.Raw, })), "format": "pem_bundle", "ttl": lifetime.String(), } secret, err := v.client.Logical().Write(v.config.PKIMountPath+"/sign/"+vaultPKIRole, vaultReq) if err != nil { return nil, nil, fmt.Errorf("error signing certificate: %w", err) } if secret == nil { return nil, nil, errors.New("error signing certificate: response is empty") } chain, ok := secret.Data["certificate"].(string) if !ok { return nil, nil, errors.New("error unmarshaling vault response: certificate not found") } cert, err := getCertificateBundle(chain) if err != nil { return nil, nil, err } // Return certificate and certificate chain return cert.leaf, cert.intermediates, nil } func loadOptions(config json.RawMessage) (*VaultOptions, error) { // setup default values vc := VaultOptions{ PKIMountPath: "pki", PKIRoleDefault: "default", } err := json.Unmarshal(config, &vc) if err != nil { return nil, fmt.Errorf("error decoding vaultCAS config: %w", err) } if vc.PKIRoleRSA == "" { vc.PKIRoleRSA = vc.PKIRoleDefault } if vc.PKIRoleEC == "" { vc.PKIRoleEC = vc.PKIRoleDefault } if vc.PKIRoleEd25519 == "" { vc.PKIRoleEd25519 = vc.PKIRoleDefault } return &vc, nil } func parseCertificates(pemCert string) []*x509.Certificate { var certs []*x509.Certificate rest := []byte(pemCert) var block *pem.Block for { block, rest = pem.Decode(rest) if block == nil { break } cert, err := x509.ParseCertificate(block.Bytes) if err != nil { break } certs = append(certs, cert) } return certs } func getCertificateBundle(chain string) (*certBundle, error) { var root *x509.Certificate var leaf *x509.Certificate var intermediates []*x509.Certificate for _, cert := range parseCertificates(chain) { switch { case isRoot(cert): root = cert case cert.BasicConstraintsValid && cert.IsCA: intermediates = append(intermediates, cert) default: leaf = cert } } certificate := &certBundle{ root: root, leaf: leaf, intermediates: intermediates, } return certificate, nil } // isRoot returns true if the given certificate is a root certificate. func isRoot(cert *x509.Certificate) bool { if cert.BasicConstraintsValid && cert.IsCA { return cert.CheckSignatureFrom(cert) == nil } return false } // formatSerialNumber formats a serial number to a dash-separated hexadecimal // string. func formatSerialNumber(sn *big.Int) string { var ret bytes.Buffer for _, b := range sn.Bytes() { if ret.Len() > 0 { ret.WriteString("-") } ret.WriteString(hex.EncodeToString([]byte{b})) } return ret.String() }