package ca

import (
	"context"
	"crypto"
	"crypto/ecdsa"
	"crypto/ed25519"
	"crypto/rsa"
	"crypto/tls"
	"crypto/x509"
	"encoding/pem"
	"net"
	"net/http"
	"os"
	"time"

	"github.com/pkg/errors"
	"github.com/smallstep/certificates/api"
	"github.com/smallstep/certificates/ca/identity"
)

// mTLSDialContext will hold the dial context function to use in
// getDefaultTransport. It stays nil unless STEP_TLS_TUNNEL is set.
var mTLSDialContext func() func(ctx context.Context, network, address string) (net.Conn, error)

func init() {
	// STEP_TLS_TUNNEL is an environment variable that can be set to do a TLS
	// over (m)TLS tunnel to step-ca using identity-like credentials. The value
	// is a path to a json file with the tunnel host, certificate, key and root
	// used to create the (m)TLS tunnel.
	//
	// The configuration should look like:
	//   {
	//     "type": "tTLS",
	//     "host": "tunnel.example.com:443",
	//     "crt": "/path/to/tunnel.crt",
	//     "key": "/path/to/tunnel.key",
	//     "root": "/path/to/tunnel-root.crt"
	//   }
	//
	// This feature is EXPERIMENTAL and might change at any time.
	path := os.Getenv("STEP_TLS_TUNNEL")
	if path == "" {
		return
	}

	// Any problem with the tunnel configuration aborts start-up: the variable
	// was set explicitly, so the configuration errors are not ignored.
	id, err := identity.LoadIdentity(path)
	if err != nil {
		panic(err)
	}
	if err := id.Validate(); err != nil {
		panic(err)
	}
	host, port, err := net.SplitHostPort(id.Host)
	if err != nil {
		panic(err)
	}
	pool, err := id.GetCertPool()
	if err != nil {
		panic(err)
	}

	mTLSDialContext = func() func(ctx context.Context, network, address string) (net.Conn, error) {
		d := &tls.Dialer{
			NetDialer: getDefaultDialer(),
			Config: &tls.Config{
				// Pin the same minimum version used by getDefaultTLSConfig
				// instead of relying on the toolchain default.
				MinVersion:           tls.VersionTLS12,
				RootCAs:              pool,
				GetClientCertificate: id.GetClientCertificateFunc(),
			},
		}
		// The requested network and address are ignored on purpose: every
		// connection goes through the configured tunnel host.
		return func(ctx context.Context, network, address string) (net.Conn, error) {
			return d.DialContext(ctx, "tcp", net.JoinHostPort(host, port))
		}
	}
}

// GetClientTLSConfig returns a tls.Config for client use configured with the
// sign certificate, and a new certificate pool with the sign root certificate.
// The client certificate will automatically rotate before expiring.
func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey, options ...TLSOption) (*tls.Config, error) { tlsConfig, _, err := c.getClientTLSConfig(ctx, sign, pk, options) if err != nil { return nil, err } return tlsConfig, nil } func (c *Client) getClientTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey, options []TLSOption) (*tls.Config, *http.Transport, error) { cert, err := TLSCertificate(sign, pk) if err != nil { return nil, nil, err } renewer, err := NewTLSRenewer(cert, nil) if err != nil { return nil, nil, err } tlsConfig := getDefaultTLSConfig(sign) // Note that with GetClientCertificate tlsConfig.Certificates is not used. // Without tlsConfig.Certificates there's not need to use tlsConfig.BuildNameToCertificate() tlsConfig.GetClientCertificate = renewer.GetClientCertificate // Apply options and initialize mutable tls.Config tlsCtx := newTLSOptionCtx(c, tlsConfig, sign) if err := tlsCtx.apply(options); err != nil { return nil, nil, err } tr := getDefaultTransport(tlsConfig) // Use mutable tls.Config on renew tr.DialTLS = c.buildDialTLS(tlsCtx) // nolint:staticcheck,gocritic // tr.DialTLSContext = c.buildDialTLSContext(tlsCtx) renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk) // Update client transport c.SetTransport(tr) // Start renewer renewer.RunContext(ctx) return tlsConfig, tr, nil } // GetServerTLSConfig returns a tls.Config for server use configured with the // sign certificate, and a new certificate pool with the sign root certificate. // The returned tls.Config will only verify the client certificate if provided. // The server certificate will automatically rotate before expiring. 
func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey, options ...TLSOption) (*tls.Config, error) { cert, err := TLSCertificate(sign, pk) if err != nil { return nil, err } renewer, err := NewTLSRenewer(cert, nil) if err != nil { return nil, err } tlsConfig := getDefaultTLSConfig(sign) // Note that GetCertificate will only be called if the client supplies SNI // information or if tlsConfig.Certificates is empty. // Without tlsConfig.Certificates there's not need to use tlsConfig.BuildNameToCertificate() tlsConfig.GetCertificate = renewer.GetCertificate tlsConfig.GetClientCertificate = renewer.GetClientCertificate tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert // Apply options and initialize mutable tls.Config tlsCtx := newTLSOptionCtx(c, tlsConfig, sign) if err := tlsCtx.apply(options); err != nil { return nil, err } // GetConfigForClient allows seamless root and federated roots rotation. // If the return of the callback is not-nil, it will use the returned // tls.Config instead of the default one. tlsConfig.GetConfigForClient = c.buildGetConfigForClient(tlsCtx) // Update renew function with transport tr := getDefaultTransport(tlsConfig) // Use mutable tls.Config on renew tr.DialTLS = c.buildDialTLS(tlsCtx) // nolint:staticcheck,gocritic // tr.DialTLSContext = c.buildDialTLSContext(tlsCtx) renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk) // Update client transport c.SetTransport(tr) // Start renewer renewer.RunContext(ctx) return tlsConfig, nil } // Transport returns an http.Transport configured to use the client certificate from the sign response. func (c *Client) Transport(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey, options ...TLSOption) (*http.Transport, error) { _, tr, err := c.getClientTLSConfig(ctx, sign, pk, options) if err != nil { return nil, err } return tr, nil } // buildGetConfigForClient returns an implementation of GetConfigForClient // callback in tls.Config. 
// // If the implementation returns a nil tls.Config, the original Config will be // used, but if it's non-nil, the returned Config will be used to handle this // connection. func (c *Client) buildGetConfigForClient(ctx *TLSOptionCtx) func(*tls.ClientHelloInfo) (*tls.Config, error) { return func(*tls.ClientHelloInfo) (*tls.Config, error) { return ctx.mutableConfig.TLSConfig(), nil } } // buildDialTLS returns an implementation of DialTLS callback in http.Transport. func (c *Client) buildDialTLS(ctx *TLSOptionCtx) func(network, addr string) (net.Conn, error) { return func(network, addr string) (net.Conn, error) { return tls.DialWithDialer(getDefaultDialer(), network, addr, ctx.mutableConfig.TLSConfig()) } } // buildDialTLSContext returns an implementation of DialTLSContext callback in http.Transport. // nolint:unused,gocritic func (c *Client) buildDialTLSContext(tlsCtx *TLSOptionCtx) func(ctx context.Context, network, addr string) (net.Conn, error) { return func(ctx context.Context, network, addr string) (net.Conn, error) { d := getDefaultDialer() // TLS dialers do not support context, but we can use the context // deadline if it is set. if t, ok := ctx.Deadline(); ok { d.Deadline = t } return tls.DialWithDialer(d, network, addr, tlsCtx.mutableConfig.TLSConfig()) } } // Certificate returns the server or client certificate from the sign response. func Certificate(sign *api.SignResponse) (*x509.Certificate, error) { if sign.ServerPEM.Certificate == nil { return nil, errors.New("ca: certificate does not exist") } return sign.ServerPEM.Certificate, nil } // IntermediateCertificate returns the CA intermediate certificate from the sign // response. func IntermediateCertificate(sign *api.SignResponse) (*x509.Certificate, error) { if sign.CaPEM.Certificate == nil { return nil, errors.New("ca: certificate does not exist") } return sign.CaPEM.Certificate, nil } // RootCertificate returns the root certificate from the sign response. 
func RootCertificate(sign *api.SignResponse) (*x509.Certificate, error) { if sign == nil || sign.TLS == nil || len(sign.TLS.VerifiedChains) == 0 { return nil, errors.New("ca: certificate does not exist") } lastChain := sign.TLS.VerifiedChains[len(sign.TLS.VerifiedChains)-1] if len(lastChain) == 0 { return nil, errors.New("ca: certificate does not exist") } return lastChain[len(lastChain)-1], nil } // TLSCertificate creates a new TLS certificate from the sign response and the // private key used. func TLSCertificate(sign *api.SignResponse, pk crypto.PrivateKey) (*tls.Certificate, error) { certPEM, err := getPEM(sign.ServerPEM) if err != nil { return nil, err } caPEM, err := getPEM(sign.CaPEM) if err != nil { return nil, err } keyPEM, err := getPEM(pk) if err != nil { return nil, err } // nolint:gocritic // using a new variable for clarity chain := append(certPEM, caPEM...) cert, err := tls.X509KeyPair(chain, keyPEM) if err != nil { return nil, errors.Wrap(err, "error creating tls certificate") } leaf, err := x509.ParseCertificate(cert.Certificate[0]) if err != nil { return nil, errors.Wrap(err, "error parsing tls certificate") } cert.Leaf = leaf return &cert, nil } func getDefaultTLSConfig(sign *api.SignResponse) *tls.Config { if sign.TLSOptions != nil { return sign.TLSOptions.TLSConfig() } return &tls.Config{ MinVersion: tls.VersionTLS12, } } // getDefaultDialer returns a new dialer with the default configuration. func getDefaultDialer() *net.Dialer { // With the KeepAlive parameter set to 0, it will be use Golang's default. return &net.Dialer{ Timeout: 30 * time.Second, } } // getDefaultTransport returns an http.Transport with the same parameters than // http.DefaultTransport, but adds the given tls.Config and configures the // transport for HTTP/2. 
func getDefaultTransport(tlsConfig *tls.Config) *http.Transport { var dialContext func(ctx context.Context, network string, addr string) (net.Conn, error) if mTLSDialContext == nil { d := getDefaultDialer() dialContext = d.DialContext } else { dialContext = mTLSDialContext() } return &http.Transport{ Proxy: http.ProxyFromEnvironment, DialContext: dialContext, ForceAttemptHTTP2: true, MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, TLSClientConfig: tlsConfig, } } func getPEM(i interface{}) ([]byte, error) { block := new(pem.Block) switch i := i.(type) { case api.Certificate: block.Type = "CERTIFICATE" block.Bytes = i.Raw case *x509.Certificate: block.Type = "CERTIFICATE" block.Bytes = i.Raw case *rsa.PrivateKey: block.Type = "RSA PRIVATE KEY" block.Bytes = x509.MarshalPKCS1PrivateKey(i) case *ecdsa.PrivateKey: var err error block.Type = "EC PRIVATE KEY" block.Bytes, err = x509.MarshalECPrivateKey(i) if err != nil { return nil, errors.Wrap(err, "error marshaling private key") } case ed25519.PrivateKey: var err error block.Type = "PRIVATE KEY" block.Bytes, err = x509.MarshalPKCS8PrivateKey(i) if err != nil { return nil, errors.Wrap(err, "error marshaling private key") } default: return nil, errors.Errorf("unsupported key type %T", i) } return pem.EncodeToMemory(block), nil } func getRenewFunc(ctx *TLSOptionCtx, client *Client, tr *http.Transport, pk crypto.PrivateKey) RenewFunc { return func() (*tls.Certificate, error) { // Get updated list of roots if err := ctx.applyRenew(); err != nil { return nil, err } // Get new certificate sign, err := client.Renew(tr) if err != nil { return nil, err } return TLSCertificate(sign, pk) } }