package identity import ( "bytes" "crypto" "crypto/tls" "crypto/x509" "encoding/json" "encoding/pem" "net/http" "os" "path/filepath" "strings" "time" "github.com/pkg/errors" "github.com/smallstep/certificates/api" "go.step.sm/cli-utils/step" "go.step.sm/crypto/pemutil" ) // Type represents the different types of identity files. type Type string // Disabled represents a disabled identity type const Disabled Type = "" // MutualTLS represents the identity using mTLS. const MutualTLS Type = "mTLS" // TunnelTLS represents an identity using a (m)TLS tunnel. // // TunnelTLS can be optionally configured with client certificates and a root // file with the CAs to trust. By default it will use the system truststore // instead of the CA truststore. const TunnelTLS Type = "tTLS" // DefaultLeeway is the duration for matching not before claims. const DefaultLeeway = 1 * time.Minute var ( identityDir = step.IdentityPath configDir = step.ConfigPath // IdentityFile contains a pointer to a function that outputs the location of // the identity file. IdentityFile = step.IdentityFile // DefaultsFile contains a prointer a function that outputs the location of the // defaults configuration file. DefaultsFile = step.DefaultsFile ) // Identity represents the identity file that can be used to authenticate with // the CA. type Identity struct { Type string `json:"type"` Certificate string `json:"crt"` Key string `json:"key"` // Host is the tunnel host for a TunnelTLS (tTLS) identity. Host string `json:"host,omitempty"` // Root is the CA bundle of root CAs used in TunnelTLS to trust the // certificate of the host. Root string `json:"root,omitempty"` } // LoadIdentity loads an identity present in the given filename. func LoadIdentity(filename string) (*Identity, error) { b, err := os.ReadFile(filename) if err != nil { return nil, errors.Wrapf(err, "error reading %s", filename) } identity := new(Identity) if err := json.Unmarshal(b, &identity); err != nil { return nil, errors.Wrapf(err, "error unmarshaling %s", filename) } return identity, nil } // LoadDefaultIdentity loads the default identity. func LoadDefaultIdentity() (*Identity, error) { return LoadIdentity(IdentityFile()) } // WriteDefaultIdentity writes the given certificates and key and the // identity.json pointing to the new files. func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) error { if err := os.MkdirAll(configDir(), 0700); err != nil { return errors.Wrap(err, "error creating config directory") } identityDir := identityDir() if err := os.MkdirAll(identityDir, 0700); err != nil { return errors.Wrap(err, "error creating identity directory") } certFilename := filepath.Join(identityDir, "identity.crt") keyFilename := filepath.Join(identityDir, "identity_key") // Write certificate if err := writeCertificate(certFilename, certChain); err != nil { return err } // Write key buf := new(bytes.Buffer) block, err := pemutil.Serialize(key) if err != nil { return err } if err := pem.Encode(buf, block); err != nil { return errors.Wrap(err, "error encoding identity key") } if err := os.WriteFile(keyFilename, buf.Bytes(), 0600); err != nil { return errors.Wrap(err, "error writing identity certificate") } // Write identity.json buf.Reset() enc := json.NewEncoder(buf) enc.SetIndent("", " ") if err := enc.Encode(Identity{ Type: string(MutualTLS), Certificate: certFilename, Key: keyFilename, }); err != nil { return errors.Wrap(err, "error writing identity json") } if err := os.WriteFile(IdentityFile(), buf.Bytes(), 0600); err != nil { return errors.Wrap(err, "error writing identity certificate") } return nil } // WriteIdentityCertificate writes the identity certificate to disk. func WriteIdentityCertificate(certChain []api.Certificate) error { filename := filepath.Join(identityDir(), "identity.crt") return writeCertificate(filename, certChain) } // writeCertificate writes the given certificate on disk. func writeCertificate(filename string, certChain []api.Certificate) error { buf := new(bytes.Buffer) for _, crt := range certChain { block := &pem.Block{ Type: "CERTIFICATE", Bytes: crt.Raw, } if err := pem.Encode(buf, block); err != nil { return errors.Wrap(err, "error encoding certificate") } } if err := os.WriteFile(filename, buf.Bytes(), 0600); err != nil { return errors.Wrap(err, "error writing certificate") } return nil } // Kind returns the type for the given identity. func (i *Identity) Kind() Type { switch strings.ToLower(i.Type) { case "": return Disabled case "mtls": return MutualTLS case "ttls": return TunnelTLS default: return Type(i.Type) } } // Validate validates the identity object. func (i *Identity) Validate() error { switch i.Kind() { case Disabled: return nil case MutualTLS: if i.Certificate == "" { return errors.New("identity.crt cannot be empty") } if i.Key == "" { return errors.New("identity.key cannot be empty") } if err := fileExists(i.Certificate); err != nil { return err } return fileExists(i.Key) case TunnelTLS: if i.Host == "" { return errors.New("tunnel.host cannot be empty") } if i.Certificate != "" { if err := fileExists(i.Certificate); err != nil { return err } if i.Key == "" { return errors.New("tunnel.key cannot be empty") } if err := fileExists(i.Key); err != nil { return err } } if i.Root != "" { if err := fileExists(i.Root); err != nil { return err } } return nil default: return errors.Errorf("unsupported identity type %s", i.Type) } } // TLSCertificate returns a tls.Certificate for the identity. func (i *Identity) TLSCertificate() (tls.Certificate, error) { fail := func(err error) (tls.Certificate, error) { return tls.Certificate{}, err } switch i.Kind() { case Disabled: return tls.Certificate{}, nil case MutualTLS, TunnelTLS: crt, err := tls.LoadX509KeyPair(i.Certificate, i.Key) if err != nil { return fail(errors.Wrap(err, "error creating identity certificate")) } // Check if certificate is expired. x509Cert, err := x509.ParseCertificate(crt.Certificate[0]) if err != nil { return fail(errors.Wrap(err, "error creating identity certificate")) } now := time.Now().Truncate(time.Second) if now.Add(DefaultLeeway).Before(x509Cert.NotBefore) { return fail(errors.New("certificate is not yet valid")) } if now.After(x509Cert.NotAfter) { return fail(errors.New("certificate is already expired")) } return crt, nil default: return fail(errors.Errorf("unsupported identity type %s", i.Type)) } } // GetClientCertificateFunc returns a method that can be used as the // GetClientCertificate property in a tls.Config. func (i *Identity) GetClientCertificateFunc() func(*tls.CertificateRequestInfo) (*tls.Certificate, error) { return func(*tls.CertificateRequestInfo) (*tls.Certificate, error) { crt, err := tls.LoadX509KeyPair(i.Certificate, i.Key) if err != nil { return nil, errors.Wrap(err, "error loading identity certificate") } return &crt, nil } } // GetCertPool returns a x509.CertPool if the identity defines a custom root. func (i *Identity) GetCertPool() (*x509.CertPool, error) { if i.Root == "" { return nil, nil } b, err := os.ReadFile(i.Root) if err != nil { return nil, errors.Wrap(err, "error reading identity root") } pool := x509.NewCertPool() if !pool.AppendCertsFromPEM(b) { return nil, errors.Errorf("error pasing identity root: %s does not contain any certificate", i.Root) } return pool, nil } // Renewer is that interface that a renew client must implement. type Renewer interface { GetRootCAs() *x509.CertPool Renew(tr http.RoundTripper) (*api.SignResponse, error) } // Renew renews the current identity certificate using a client with a renew // method. func (i *Identity) Renew(client Renewer) error { switch i.Kind() { case Disabled: return nil case MutualTLS, TunnelTLS: cert, err := i.TLSCertificate() if err != nil { return err } tr := http.DefaultTransport.(*http.Transport).Clone() tr.TLSClientConfig = &tls.Config{ Certificates: []tls.Certificate{cert}, RootCAs: client.GetRootCAs(), MinVersion: tls.VersionTLS12, PreferServerCipherSuites: true, } sign, err := client.Renew(tr) if err != nil { return err } if sign.CertChainPEM == nil || len(sign.CertChainPEM) == 0 { sign.CertChainPEM = []api.Certificate{sign.ServerPEM, sign.CaPEM} } // Write certificate buf := new(bytes.Buffer) for _, crt := range sign.CertChainPEM { block := &pem.Block{ Type: "CERTIFICATE", Bytes: crt.Raw, } if err := pem.Encode(buf, block); err != nil { return errors.Wrap(err, "error encoding identity certificate") } } certFilename := filepath.Join(identityDir(), "identity.crt") if err := os.WriteFile(certFilename, buf.Bytes(), 0600); err != nil { return errors.Wrap(err, "error writing identity certificate") } return nil default: return errors.Errorf("unsupported identity type %s", i.Type) } } func fileExists(filename string) error { info, err := os.Stat(filename) if err != nil { return errors.Wrapf(err, "error reading %s", filename) } if info.IsDir() { return errors.Errorf("error reading %s: file is a directory", filename) } return nil }