smallstep-certificates/ca/client.go
2022-09-20 22:12:08 -07:00

1356 lines
39 KiB
Go

package ca
import (
"bytes"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/hex"
"encoding/json"
"encoding/pem"
"io"
"net/http"
"net/url"
"os"
"path/filepath"
"strconv"
"strings"
"github.com/pkg/errors"
"github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/ca/identity"
"github.com/smallstep/certificates/errs"
"go.step.sm/cli-utils/step"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/keyutil"
"go.step.sm/crypto/pemutil"
"go.step.sm/crypto/x509util"
"golang.org/x/net/http2"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"gopkg.in/square/go-jose.v2/jwt"
)
// DisableIdentity is a global variable to disable the identity. When it is
// true, clientOptions.applyDefaultIdentity returns immediately, so the default
// identity is not loaded and no client certificate is configured from it.
var DisableIdentity = false
// UserAgent will set the User-Agent header in the client requests. The value is
// applied by the Get, Post and Do methods of uaClient.
var UserAgent = "step-http-client/1.0"
// uaClient wraps an *http.Client. Its Get, Post and Do methods set the
// User-Agent header to the package-level UserAgent before sending the request.
type uaClient struct {
// Client is the wrapped HTTP client that performs the requests.
Client *http.Client
}
// newClient returns a uaClient whose underlying http.Client sends its requests
// through the given transport.
func newClient(transport http.RoundTripper) *uaClient {
	httpClient := &http.Client{Transport: transport}
	return &uaClient{Client: httpClient}
}
// newInsecureClient returns a uaClient whose transport is built by
// getDefaultTransport from a TLS configuration that skips verification of the
// server certificate.
//
//nolint:gosec // used in bootstrap protocol
func newInsecureClient() *uaClient {
	insecure := &tls.Config{InsecureSkipVerify: true}
	return &uaClient{
		Client: &http.Client{Transport: getDefaultTransport(insecure)},
	}
}
// GetTransport returns the http.RoundTripper currently used by the wrapped
// http.Client.
func (c *uaClient) GetTransport() http.RoundTripper {
	inner := c.Client
	return inner.Transport
}
// SetTransport replaces the http.RoundTripper used by the wrapped http.Client.
func (c *uaClient) SetTransport(tr http.RoundTripper) {
	inner := c.Client
	inner.Transport = tr
}
// Get issues a GET request with an empty body to the URL u, with the
// User-Agent header set to UserAgent. It returns a wrapped error if the
// request cannot be created.
func (c *uaClient) Get(u string) (*http.Response, error) {
	request, err := http.NewRequest(http.MethodGet, u, http.NoBody)
	if err != nil {
		return nil, errors.Wrapf(err, "new request GET %s failed", u)
	}
	request.Header.Set("User-Agent", UserAgent)
	return c.Client.Do(request)
}
// Post issues a POST request to the URL u with the given body. The
// Content-Type header is set to contentType and the User-Agent header to
// UserAgent. An error creating the request is returned unwrapped.
func (c *uaClient) Post(u, contentType string, body io.Reader) (*http.Response, error) {
	request, err := http.NewRequest(http.MethodPost, u, body)
	if err != nil {
		return nil, err
	}
	headers := request.Header
	headers.Set("Content-Type", contentType)
	headers.Set("User-Agent", UserAgent)
	return c.Client.Do(request)
}
// Do sends req through the wrapped http.Client after overwriting its
// User-Agent header with UserAgent. The header is modified on req itself.
func (c *uaClient) Do(req *http.Request) (*http.Response, error) {
	headers := req.Header
	headers.Set("User-Agent", UserAgent)
	httpClient := c.Client
	return httpClient.Do(req)
}
// RetryFunc defines the method used to retry a request. If it returns true, the
// request will be retried once. The function receives the HTTP status code of
// the response; the Client methods only consult it for responses whose status
// code is 400 or higher.
type RetryFunc func(code int) bool
// ClientOption is the type of options passed to the Client constructor. Each
// option mutates the given clientOptions and returns an error if it cannot be
// applied; the first error aborts the construction of the Client.
type ClientOption func(o *clientOptions) error
// clientOptions accumulates the settings produced by the ClientOption
// functions. getTransport turns them into the transport used by the Client.
type clientOptions struct {
// transport is an explicit transport, set by WithTransport or WithInsecure.
transport http.RoundTripper
// rootSHA256 is the root certificate fingerprint set by WithRootSHA256.
rootSHA256 string
// rootFilename is the path of the root certificate set by WithRootFile.
rootFilename string
// rootBundle holds the PEM root certificates set by WithCABundle.
rootBundle []byte
// certificate is the TLS client certificate, set by WithCertificate or taken
// from the default identity.
certificate tls.Certificate
// getClientCertificate is the callback obtained from the default identity; it
// is copied into tls.Config.GetClientCertificate by getTransport.
getClientCertificate func(*tls.CertificateRequestInfo) (*tls.Certificate, error)
// retryFunc is the retry policy set by WithRetryFunc.
retryFunc RetryFunc
// The x5c fields are populated by WithAdminX5C: the key as a JWK, the
// validated chain as strings, the leaf certificate and the subject derived
// from the leaf. x5cCertFile is not assigned anywhere in the visible code.
x5cJWK *jose.JSONWebKey
x5cCertFile string
x5cCertStrs []string
x5cCert *x509.Certificate
x5cSubject string
}
// apply first loads the default identity (if enabled) and then runs every
// option in order, stopping at and returning the first error.
func (o *clientOptions) apply(opts []ClientOption) (err error) {
	o.applyDefaultIdentity()
	for i := range opts {
		if err = opts[i](o); err != nil {
			return err
		}
	}
	return nil
}
// applyDefaultIdentity configures the client certificate from the default
// identity. It does nothing when DisableIdentity is true, and it silently
// leaves the options untouched if the identity cannot be loaded, does not
// validate, or cannot produce a TLS certificate.
func (o *clientOptions) applyDefaultIdentity() {
	if DisableIdentity {
		return
	}

	// Every failure below is deliberately ignored: it only means that there
	// is no usable identity.
	id, err := identity.LoadDefaultIdentity()
	if err != nil || id.Validate() != nil {
		return
	}
	cert, err := id.TLSCertificate()
	if err != nil {
		return
	}

	o.certificate, o.getClientCertificate = cert, id.GetClientCertificateFunc()
}
// checkTransport reports an error when a transport, a root file, a root
// fingerprint or a CA bundle has already been configured, so that only one way
// of building the transport is used.
func (o *clientOptions) checkTransport() error {
	switch {
	case o.transport != nil, o.rootFilename != "", o.rootSHA256 != "", o.rootBundle != nil:
		return errors.New("multiple transport methods have been configured")
	default:
		return nil
	}
}
// getTransport builds the transport described by the options. The sources are
// evaluated in a fixed order (explicit transport, root file, root fingerprint,
// CA bundle) and a later source replaces the result of an earlier one. If none
// is configured, the file returned by getRootCAPath is used when it exists;
// otherwise an error is returned.
//
// When a client certificate is configured it is attached to the transport's
// TLS configuration, unless that configuration already has certificates or a
// GetClientCertificate callback. Only *http.Transport and *http2.Transport are
// supported in that case.
func (o *clientOptions) getTransport(endpoint string) (tr http.RoundTripper, err error) {
	tr = o.transport
	if name := o.rootFilename; name != "" {
		if tr, err = getTransportFromFile(name); err != nil {
			return nil, err
		}
	}
	if sum := o.rootSHA256; sum != "" {
		if tr, err = getTransportFromSHA256(endpoint, sum); err != nil {
			return nil, err
		}
	}
	if bundle := o.rootBundle; bundle != nil {
		if tr, err = getTransportFromCABundle(bundle); err != nil {
			return nil, err
		}
	}

	// As the last option attempt to load the default root ca
	if tr == nil {
		defaultRoot := getRootCAPath()
		if _, statErr := os.Stat(defaultRoot); statErr == nil {
			if tr, err = getTransportFromFile(defaultRoot); err != nil {
				return nil, err
			}
		}
		if tr == nil {
			return nil, errors.New("a transport, a root cert, or a root sha256 must be used")
		}
	}

	// Without a client certificate the transport is used as is.
	if o.certificate.Certificate == nil {
		return tr, nil
	}

	// withClientCert returns cfg (or a fresh configuration if cfg is nil) with
	// the client certificate added when no other one is configured.
	withClientCert := func(cfg *tls.Config) *tls.Config {
		if cfg == nil {
			cfg = &tls.Config{
				MinVersion: tls.VersionTLS12,
			}
		}
		if len(cfg.Certificates) == 0 && cfg.GetClientCertificate == nil {
			cfg.Certificates = []tls.Certificate{o.certificate}
			cfg.GetClientCertificate = o.getClientCertificate
		}
		return cfg
	}

	switch t := tr.(type) {
	case *http.Transport:
		t.TLSClientConfig = withClientCert(t.TLSClientConfig)
	case *http2.Transport:
		t.TLSClientConfig = withClientCert(t.TLSClientConfig)
	default:
		return nil, errors.Errorf("unsupported transport type %T", t)
	}
	return tr, nil
}
// WithTransport configures the Client to use the given transport. The option
// fails if another way of creating the transport was configured before it.
func WithTransport(tr http.RoundTripper) ClientOption {
	return func(o *clientOptions) error {
		err := o.checkTransport()
		if err == nil {
			o.transport = tr
		}
		return err
	}
}
// WithInsecure configures an *http.Transport that skips TLS verification of
// the server certificate. Unlike the other transport options it does not check
// whether a transport source was already configured.
func WithInsecure() ClientOption {
	return func(o *clientOptions) error {
		cfg := &tls.Config{
			MinVersion: tls.VersionTLS12,
			//nolint:gosec // insecure option
			InsecureSkipVerify: true,
		}
		o.transport = &http.Transport{
			Proxy:           http.ProxyFromEnvironment,
			TLSClientConfig: cfg,
		}
		return nil
	}
}
// WithRootFile records filename as the root certificate file from which the
// transport will be created. The option fails if another way of creating the
// transport was configured before it.
func WithRootFile(filename string) ClientOption {
	return func(o *clientOptions) error {
		err := o.checkTransport()
		if err == nil {
			o.rootFilename = filename
		}
		return err
	}
}
// WithRootSHA256 records the fingerprint of the root certificate. The
// transport is later created by retrieving the root certificate with an
// insecure client and checking it against this fingerprint. The option fails
// if another way of creating the transport was configured before it.
func WithRootSHA256(sum string) ClientOption {
	return func(o *clientOptions) error {
		err := o.checkTransport()
		if err == nil {
			o.rootSHA256 = sum
		}
		return err
	}
}
// WithCABundle records the PEM bundle of root certificates from which the
// transport will be created. The option fails if another way of creating the
// transport was configured before it.
func WithCABundle(bundle []byte) ClientOption {
	return func(o *clientOptions) error {
		err := o.checkTransport()
		if err == nil {
			o.rootBundle = bundle
		}
		return err
	}
}
// WithCertificate will set the given certificate as the TLS client certificate
// in the client.
//
// Options are applied after the default identity has been loaded, and the
// identity installs a GetClientCertificate callback. crypto/tls ignores
// Config.Certificates whenever GetClientCertificate is set, so the callback is
// cleared here; otherwise the identity certificate would silently be used
// instead of cert.
func WithCertificate(cert tls.Certificate) ClientOption {
	return func(o *clientOptions) error {
		o.certificate = cert
		// Drop the identity callback so that cert is the certificate presented.
		o.getClientCertificate = nil
		return nil
	}
}
// WithAdminX5C configures the certificate chain and private key used for the
// x5c header of the client.
//
// The key is serialized to PEM and parsed back as a JWK; when passwordFile is
// not empty it is passed to the parser as the password file option. The chain
// is then checked against the key with jose.ValidateX5C. On success the first
// certificate of the chain is stored as the leaf and the subject is taken from
// its common name, or else its first DNS name, or else its first email
// address.
func WithAdminX5C(certs []*x509.Certificate, key interface{}, passwordFile string) ClientOption {
	return func(o *clientOptions) error {
		var keyOpts []jose.Option
		if passwordFile != "" {
			keyOpts = append(keyOpts, jose.WithPasswordFile(passwordFile))
		}

		// Round-trip the key through PEM to obtain a JWK.
		block, err := pemutil.Serialize(key)
		if err != nil {
			return errors.Wrap(err, "error serializing private key")
		}
		if o.x5cJWK, err = jose.ParseKey(pem.EncodeToMemory(block), keyOpts...); err != nil {
			return err
		}
		if o.x5cCertStrs, err = jose.ValidateX5C(certs, o.x5cJWK.Key); err != nil {
			return errors.Wrap(err, "error validating x5c certificate chain and key for use in x5c header")
		}

		leaf := certs[0]
		o.x5cCert = leaf
		if cn := leaf.Subject.CommonName; cn != "" {
			o.x5cSubject = cn
		} else if len(leaf.DNSNames) > 0 {
			o.x5cSubject = leaf.DNSNames[0]
		} else if len(leaf.EmailAddresses) > 0 {
			o.x5cSubject = leaf.EmailAddresses[0]
		}
		return nil
	}
}
// WithRetryFunc sets the function that decides, from the status code of a
// failed response, whether the request is retried.
func WithRetryFunc(fn RetryFunc) ClientOption {
	setRetry := func(o *clientOptions) error {
		o.retryFunc = fn
		return nil
	}
	return setRetry
}
// getTransportFromFile reads the PEM certificates in filename and returns the
// default transport configured to trust them as root CAs. It fails if the
// file cannot be read or contains no certificates.
func getTransportFromFile(filename string) (http.RoundTripper, error) {
	pemData, err := os.ReadFile(filename)
	if err != nil {
		return nil, errors.Wrapf(err, "error reading %s", filename)
	}

	roots := x509.NewCertPool()
	if ok := roots.AppendCertsFromPEM(pemData); !ok {
		return nil, errors.Errorf("error parsing %s: no certificates found", filename)
	}

	cfg := &tls.Config{
		MinVersion:               tls.VersionTLS12,
		PreferServerCipherSuites: true,
		RootCAs:                  roots,
	}
	return getDefaultTransport(cfg), nil
}
// getTransportFromSHA256 downloads the root certificate with fingerprint sum
// from the CA at endpoint, using Client.Root, and returns the default
// transport configured to trust that certificate as the only root CA.
func getTransportFromSHA256(endpoint, sum string) (http.RoundTripper, error) {
	caURL, err := parseEndpoint(endpoint)
	if err != nil {
		return nil, err
	}

	// A bare Client is enough: Root performs the request with its own client.
	bootstrap := &Client{endpoint: caURL}
	root, err := bootstrap.Root(sum)
	if err != nil {
		return nil, err
	}

	roots := x509.NewCertPool()
	roots.AddCert(root.RootPEM.Certificate)
	cfg := &tls.Config{
		MinVersion:               tls.VersionTLS12,
		PreferServerCipherSuites: true,
		RootCAs:                  roots,
	}
	return getDefaultTransport(cfg), nil
}
// getTransportFromCABundle returns the default transport configured to trust
// the PEM certificates in bundle as root CAs. It fails if the bundle contains
// no certificates.
func getTransportFromCABundle(bundle []byte) (http.RoundTripper, error) {
	roots := x509.NewCertPool()
	if ok := roots.AppendCertsFromPEM(bundle); !ok {
		return nil, errors.New("error parsing ca bundle: no certificates found")
	}

	cfg := &tls.Config{
		MinVersion:               tls.VersionTLS12,
		PreferServerCipherSuites: true,
		RootCAs:                  roots,
	}
	return getDefaultTransport(cfg), nil
}
// parseEndpoint parses and validates the given endpoint. It accepts complete
// URLs like https://ca.smallstep.com[:port][/path] as well as incomplete ones
// like ca.smallstep.com[:port][/path]; a missing scheme defaults to https.
//
// URLs are generally parsed as
//
//	[scheme:][//[userinfo@]host][/]path[?query][#fragment]
//
// but a URL that does not start with a slash after the scheme is interpreted
// as
//
//	scheme:opaque[?query][#fragment]
func parseEndpoint(endpoint string) (*url.URL, error) {
	u, err := url.Parse(endpoint)
	if err != nil {
		return nil, errors.Wrapf(err, "error parsing endpoint '%s'", endpoint)
	}

	// An opaque URL means the endpoint looks like ca.smallstep.com:443 or
	// ca.smallstep.com:443/1.0/sign: the host was taken for the scheme, so
	// parse it again with an explicit one.
	if u.Opaque != "" {
		return parseEndpoint("https://" + endpoint)
	}

	if u.Scheme == "" {
		u.Scheme = "https"
	}
	if u.Host != "" {
		return u, nil
	}
	if u.Path == "" {
		return nil, errors.Errorf("error parsing endpoint: url '%s' is not valid", endpoint)
	}

	// The endpoint looks like ca.smallstep.com or ca.smallstep.com/1.0/sign:
	// everything before the first slash is the host, the rest is the path.
	host, rest, _ := strings.Cut(u.Path, "/")
	u.Host, u.Path = host, rest
	return parseEndpoint(u.String())
}
// ProvisionerOption is the type of options passed to the Provisioner method.
// An option mutates the given ProvisionerOptions and may return an error.
type ProvisionerOption func(o *ProvisionerOptions) error

// ProvisionerOptions stores options for the provisioner CRUD API. Fields left
// at their zero value are omitted from the request query.
type ProvisionerOptions struct {
	Cursor string
	Limit  int
	ID     string
	Name   string
}

// Apply runs the given options in order on o, stopping at and returning the
// first error.
func (o *ProvisionerOptions) Apply(opts []ProvisionerOption) (err error) {
	for i := range opts {
		if err = opts[i](o); err != nil {
			return err
		}
	}
	return nil
}

// rawQuery encodes the options as a URL query string. Empty strings and a
// limit that is not positive are left out.
func (o *ProvisionerOptions) rawQuery() string {
	query := make(url.Values)
	add := func(key, value string) {
		if value != "" {
			query.Set(key, value)
		}
	}

	add("cursor", o.Cursor)
	if o.Limit > 0 {
		add("limit", strconv.Itoa(o.Limit))
	}
	add("id", o.ID)
	add("name", o.Name)
	return query.Encode()
}

// WithProvisionerCursor will request the provisioners starting with the given
// cursor.
func WithProvisionerCursor(cursor string) ProvisionerOption {
	return func(o *ProvisionerOptions) error {
		o.Cursor = cursor
		return nil
	}
}

// WithProvisionerLimit will request the given number of provisioners.
func WithProvisionerLimit(limit int) ProvisionerOption {
	return func(o *ProvisionerOptions) error {
		o.Limit = limit
		return nil
	}
}

// WithProvisionerID will request the provisioner with the given ID.
func WithProvisionerID(id string) ProvisionerOption {
	return func(o *ProvisionerOptions) error {
		o.ID = id
		return nil
	}
}

// WithProvisionerName will request the provisioner with the given name.
func WithProvisionerName(name string) ProvisionerOption {
	return func(o *ProvisionerOptions) error {
		o.Name = name
		return nil
	}
}
// Client implements an HTTP client for the CA server.
type Client struct {
// client performs the HTTP requests and sets the User-Agent header.
client *uaClient
// endpoint is the parsed CA URL; request paths are resolved against it.
endpoint *url.URL
// retryFunc, when set, decides whether a response with a status code of 400
// or higher is retried.
retryFunc RetryFunc
// opts are the options given to NewClient; retryOnError applies them again to
// rebuild the transport before a retry.
opts []ClientOption
}
// NewClient creates a new Client for the CA at endpoint. The endpoint is
// validated with parseEndpoint, the options are applied, and the transport is
// built from them; any failure is returned as an error.
func NewClient(endpoint string, opts ...ClientOption) (*Client, error) {
	caURL, err := parseEndpoint(endpoint)
	if err != nil {
		return nil, err
	}

	// Retrieve transport from options.
	var o clientOptions
	if err := o.apply(opts); err != nil {
		return nil, err
	}
	transport, err := o.getTransport(endpoint)
	if err != nil {
		return nil, err
	}

	c := &Client{
		client:    newClient(transport),
		endpoint:  caURL,
		retryFunc: o.retryFunc,
		opts:      opts,
	}
	return c, nil
}
// retryOnError reports whether the request that produced r should be sent
// again. That is the case when a retry function is configured, it accepts the
// status code of r, and a new transport can be built from the client options.
// Before returning true it closes the body of r and installs the new transport
// in the client.
func (c *Client) retryOnError(r *http.Response) bool {
	if c.retryFunc == nil || !c.retryFunc(r.StatusCode) {
		return false
	}

	// Re-apply the options to obtain a fresh transport for the retry.
	o := new(clientOptions)
	if o.apply(c.opts) != nil {
		return false
	}
	transport, err := o.getTransport(c.endpoint.String())
	if err != nil {
		return false
	}

	r.Body.Close()
	c.client.SetTransport(transport)
	return true
}
// GetCaURL returns the string form of the CA URL the client was configured
// with.
func (c *Client) GetCaURL() string {
	caURL := c.endpoint
	return caURL.String()
}
// GetRootCAs returns the RootCAs certificate pool of the TLS configuration of
// the current transport. It returns nil when the transport is neither an
// *http.Transport nor an *http2.Transport, or when it has no TLS
// configuration.
func (c *Client) GetRootCAs() *x509.CertPool {
	switch t := c.client.GetTransport().(type) {
	case *http.Transport:
		if cfg := t.TLSClientConfig; cfg != nil {
			return cfg.RootCAs
		}
	case *http2.Transport:
		if cfg := t.TLSClientConfig; cfg != nil {
			return cfg.RootCAs
		}
	}
	return nil
}
// SetTransport replaces the transport of the internal HTTP client with tr.
func (c *Client) SetTransport(tr http.RoundTripper) {
	inner := c.client
	inner.SetTransport(tr)
}
// Version sends GET /version to the CA and returns the decoded
// api.VersionResponse. A response with a status code of 400 or higher is
// retried at most once, if retryOnError allows it; otherwise the error read
// from the response body is returned.
func (c *Client) Version() (*api.VersionResponse, error) {
	target := c.endpoint.ResolveReference(&url.URL{Path: "/version"})
	for attempt := 0; ; attempt++ {
		resp, err := c.client.Get(target.String())
		if err != nil {
			return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Version; client GET %s failed", target)
		}
		if resp.StatusCode < 400 {
			version := new(api.VersionResponse)
			if err := readJSON(resp.Body, version); err != nil {
				return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Version; error reading %s", target)
			}
			return version, nil
		}
		// Only the first failure may be retried.
		if attempt > 0 || !c.retryOnError(resp) {
			return nil, readError(resp.Body)
		}
	}
}
// Health sends GET /health to the CA and returns the decoded
// api.HealthResponse. A response with a status code of 400 or higher is
// retried at most once, if retryOnError allows it; otherwise the error read
// from the response body is returned.
func (c *Client) Health() (*api.HealthResponse, error) {
	target := c.endpoint.ResolveReference(&url.URL{Path: "/health"})
	for attempt := 0; ; attempt++ {
		resp, err := c.client.Get(target.String())
		if err != nil {
			return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Health; client GET %s failed", target)
		}
		if resp.StatusCode < 400 {
			health := new(api.HealthResponse)
			if err := readJSON(resp.Body, health); err != nil {
				return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Health; error reading %s", target)
			}
			return health, nil
		}
		// Only the first failure may be retried.
		if attempt > 0 || !c.retryOnError(resp) {
			return nil, readError(resp.Body)
		}
	}
}
// Root performs the root request to the CA with the given SHA256 and returns
// the api.RootResponse struct. It uses an insecure client, but it checks the
// resulting root certificate with the given SHA256, returning an error if they
// do not match.
//
// Dashes are removed from sha256Sum and it is lower-cased before it is used in
// the request path. A response with a status code of 400 or higher is retried
// at most once, if retryOnError allows it. Because the response is received
// over an unverified connection, a response that carries no certificate is
// rejected with an error instead of being dereferenced.
func (c *Client) Root(sha256Sum string) (*api.RootResponse, error) {
	sha256Sum = strings.ToLower(strings.ReplaceAll(sha256Sum, "-", ""))
	u := c.endpoint.ResolveReference(&url.URL{Path: "/root/" + sha256Sum})
	for attempt := 0; ; attempt++ {
		resp, err := newInsecureClient().Get(u.String())
		if err != nil {
			return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Root; client GET %s failed", u)
		}
		if resp.StatusCode >= 400 {
			// Only the first failure may be retried.
			if attempt == 0 && c.retryOnError(resp) {
				continue
			}
			return nil, readError(resp.Body)
		}

		var root api.RootResponse
		if err := readJSON(resp.Body, &root); err != nil {
			return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Root; error reading %s", u)
		}
		// Raw is reached through the certificate pointer; without this guard a
		// response without a certificate would cause a nil pointer panic.
		if root.RootPEM.Certificate == nil {
			return nil, errs.BadRequest("client.Root; response does not contain a root certificate")
		}
		// verify the sha256
		sum := sha256.Sum256(root.RootPEM.Raw)
		if !strings.EqualFold(sha256Sum, hex.EncodeToString(sum[:])) {
			return nil, errs.BadRequest("root certificate fingerprint does not match")
		}
		return &root, nil
	}
}
// Sign sends req as JSON with POST /sign and returns the decoded
// api.SignResponse, with its TLS field set to the connection state of the
// response. A response with a status code of 400 or higher is retried at most
// once, if retryOnError allows it.
func (c *Client) Sign(req *api.SignRequest) (*api.SignResponse, error) {
	payload, err := json.Marshal(req)
	if err != nil {
		return nil, errs.Wrap(http.StatusInternalServerError, err, "client.Sign; error marshaling request")
	}
	target := c.endpoint.ResolveReference(&url.URL{Path: "/sign"})
	for attempt := 0; ; attempt++ {
		resp, err := c.client.Post(target.String(), "application/json", bytes.NewReader(payload))
		if err != nil {
			return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Sign; client POST %s failed", target)
		}
		if resp.StatusCode < 400 {
			sign := new(api.SignResponse)
			if err := readJSON(resp.Body, sign); err != nil {
				return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Sign; error reading %s", target)
			}
			// Add tls.ConnectionState:
			// We'll extract the root certificate from the verified chains
			sign.TLS = resp.TLS
			return sign, nil
		}
		// Only the first failure may be retried.
		if attempt > 0 || !c.retryOnError(resp) {
			return nil, readError(resp.Body)
		}
	}
}
// Renew sends POST /renew with an empty body through a plain http.Client that
// uses the transport tr, and returns the decoded api.SignResponse. A response
// with a status code of 400 or higher is retried at most once, if
// retryOnError allows it.
func (c *Client) Renew(tr http.RoundTripper) (*api.SignResponse, error) {
	target := c.endpoint.ResolveReference(&url.URL{Path: "/renew"})
	httpClient := &http.Client{Transport: tr}
	for attempt := 0; ; attempt++ {
		resp, err := httpClient.Post(target.String(), "application/json", http.NoBody)
		if err != nil {
			return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Renew; client POST %s failed", target)
		}
		if resp.StatusCode < 400 {
			sign := new(api.SignResponse)
			if err := readJSON(resp.Body, sign); err != nil {
				return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Renew; error reading %s", target)
			}
			return sign, nil
		}
		// Only the first failure may be retried.
		if attempt > 0 || !c.retryOnError(resp) {
			return nil, readError(resp.Body)
		}
	}
}
// RenewWithToken sends POST /renew with an empty body and the given token in
// a Bearer Authorization header, and returns the decoded api.SignResponse.
// This method is generally used to renew an expired certificate. A response
// with a status code of 400 or higher is retried at most once, if
// retryOnError allows it; the same request object is sent again.
func (c *Client) RenewWithToken(token string) (*api.SignResponse, error) {
	target := c.endpoint.ResolveReference(&url.URL{Path: "/renew"})
	request, err := http.NewRequest("POST", target.String(), http.NoBody)
	if err != nil {
		return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.RenewWithToken; error creating request")
	}
	request.Header.Add("Authorization", "Bearer "+token)
	for attempt := 0; ; attempt++ {
		resp, err := c.client.Do(request)
		if err != nil {
			return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.RenewWithToken; client POST %s failed", target)
		}
		if resp.StatusCode < 400 {
			sign := new(api.SignResponse)
			if err := readJSON(resp.Body, sign); err != nil {
				return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.RenewWithToken; error reading %s", target)
			}
			return sign, nil
		}
		// Only the first failure may be retried.
		if attempt > 0 || !c.retryOnError(resp) {
			return nil, readError(resp.Body)
		}
	}
}
// Rekey sends req as JSON with POST /rekey through a plain http.Client that
// uses the transport tr, and returns the decoded api.SignResponse. A response
// with a status code of 400 or higher is retried at most once, if
// retryOnError allows it.
func (c *Client) Rekey(req *api.RekeyRequest, tr http.RoundTripper) (*api.SignResponse, error) {
	payload, err := json.Marshal(req)
	if err != nil {
		return nil, errors.Wrap(err, "error marshaling request")
	}
	target := c.endpoint.ResolveReference(&url.URL{Path: "/rekey"})
	httpClient := &http.Client{Transport: tr}
	for attempt := 0; ; attempt++ {
		resp, err := httpClient.Post(target.String(), "application/json", bytes.NewReader(payload))
		if err != nil {
			return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Rekey; client POST %s failed", target)
		}
		if resp.StatusCode < 400 {
			sign := new(api.SignResponse)
			if err := readJSON(resp.Body, sign); err != nil {
				return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Rekey; error reading %s", target)
			}
			return sign, nil
		}
		// Only the first failure may be retried.
		if attempt > 0 || !c.retryOnError(resp) {
			return nil, readError(resp.Body)
		}
	}
}
// Revoke sends req as JSON with POST /revoke and returns the decoded
// api.RevokeResponse. When tr is not nil the request is sent through a new
// client built on that transport; otherwise the client's own HTTP client is
// used. A response with a status code of 400 or higher is retried at most
// once, if retryOnError allows it.
func (c *Client) Revoke(req *api.RevokeRequest, tr http.RoundTripper) (*api.RevokeResponse, error) {
	payload, err := json.Marshal(req)
	if err != nil {
		return nil, errors.Wrap(err, "error marshaling request")
	}
	for attempt := 0; ; attempt++ {
		// The client is selected again on every attempt.
		var sender *uaClient
		if tr == nil {
			sender = c.client
		} else {
			sender = newClient(tr)
		}
		target := c.endpoint.ResolveReference(&url.URL{Path: "/revoke"})
		resp, err := sender.Post(target.String(), "application/json", bytes.NewReader(payload))
		if err != nil {
			return nil, errors.Wrapf(err, "client POST %s failed", target)
		}
		if resp.StatusCode < 400 {
			revoke := new(api.RevokeResponse)
			if err := readJSON(resp.Body, revoke); err != nil {
				return nil, errors.Wrapf(err, "error reading %s", target)
			}
			return revoke, nil
		}
		// Only the first failure may be retried.
		if attempt > 0 || !c.retryOnError(resp) {
			return nil, readError(resp.Body)
		}
	}
}
// Provisioners sends GET /provisioners to the CA and returns the decoded
// api.ProvisionersResponse.
//
// The given options are encoded in the query string; WithProvisionerCursor and
// WithProvisionerLimit can be used to paginate the provisioners. A response
// with a status code of 400 or higher is retried at most once, if
// retryOnError allows it.
func (c *Client) Provisioners(opts ...ProvisionerOption) (*api.ProvisionersResponse, error) {
	var po ProvisionerOptions
	if err := po.Apply(opts); err != nil {
		return nil, err
	}
	target := c.endpoint.ResolveReference(&url.URL{
		Path:     "/provisioners",
		RawQuery: po.rawQuery(),
	})
	for attempt := 0; ; attempt++ {
		resp, err := c.client.Get(target.String())
		if err != nil {
			return nil, errors.Wrapf(err, "client GET %s failed", target)
		}
		if resp.StatusCode < 400 {
			provisioners := new(api.ProvisionersResponse)
			if err := readJSON(resp.Body, provisioners); err != nil {
				return nil, errors.Wrapf(err, "error reading %s", target)
			}
			return provisioners, nil
		}
		// Only the first failure may be retried.
		if attempt > 0 || !c.retryOnError(resp) {
			return nil, readError(resp.Body)
		}
	}
}
// ProvisionerKey sends GET /provisioners/<kid>/encrypted-key to the CA and
// returns the decoded api.ProvisionerKeyResponse for the provisioner with the
// given kid. A response with a status code of 400 or higher is retried at most
// once, if retryOnError allows it.
func (c *Client) ProvisionerKey(kid string) (*api.ProvisionerKeyResponse, error) {
	target := c.endpoint.ResolveReference(&url.URL{Path: "/provisioners/" + kid + "/encrypted-key"})
	for attempt := 0; ; attempt++ {
		resp, err := c.client.Get(target.String())
		if err != nil {
			return nil, errors.Wrapf(err, "client GET %s failed", target)
		}
		if resp.StatusCode < 400 {
			key := new(api.ProvisionerKeyResponse)
			if err := readJSON(resp.Body, key); err != nil {
				return nil, errors.Wrapf(err, "error reading %s", target)
			}
			return key, nil
		}
		// Only the first failure may be retried.
		if attempt > 0 || !c.retryOnError(resp) {
			return nil, readError(resp.Body)
		}
	}
}
// Roots sends GET /roots to the CA and returns the decoded api.RootsResponse.
// A response with a status code of 400 or higher is retried at most once, if
// retryOnError allows it.
func (c *Client) Roots() (*api.RootsResponse, error) {
	target := c.endpoint.ResolveReference(&url.URL{Path: "/roots"})
	for attempt := 0; ; attempt++ {
		resp, err := c.client.Get(target.String())
		if err != nil {
			return nil, errors.Wrapf(err, "client GET %s failed", target)
		}
		if resp.StatusCode < 400 {
			roots := new(api.RootsResponse)
			if err := readJSON(resp.Body, roots); err != nil {
				return nil, errors.Wrapf(err, "error reading %s", target)
			}
			return roots, nil
		}
		// Only the first failure may be retried.
		if attempt > 0 || !c.retryOnError(resp) {
			return nil, readError(resp.Body)
		}
	}
}
// Federation sends GET /federation to the CA and returns the decoded
// api.FederationResponse. A response with a status code of 400 or higher is
// retried at most once, if retryOnError allows it.
func (c *Client) Federation() (*api.FederationResponse, error) {
	target := c.endpoint.ResolveReference(&url.URL{Path: "/federation"})
	for attempt := 0; ; attempt++ {
		resp, err := c.client.Get(target.String())
		if err != nil {
			return nil, errors.Wrapf(err, "client GET %s failed", target)
		}
		if resp.StatusCode < 400 {
			federation := new(api.FederationResponse)
			if err := readJSON(resp.Body, federation); err != nil {
				return nil, errors.Wrapf(err, "error reading %s", target)
			}
			return federation, nil
		}
		// Only the first failure may be retried.
		if attempt > 0 || !c.retryOnError(resp) {
			return nil, readError(resp.Body)
		}
	}
}
// SSHSign sends req as JSON with POST /ssh/sign and returns the decoded
// api.SSHSignResponse. A response with a status code of 400 or higher is
// retried at most once, if retryOnError allows it.
func (c *Client) SSHSign(req *api.SSHSignRequest) (*api.SSHSignResponse, error) {
	payload, err := json.Marshal(req)
	if err != nil {
		return nil, errors.Wrap(err, "error marshaling request")
	}
	target := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/sign"})
	for attempt := 0; ; attempt++ {
		resp, err := c.client.Post(target.String(), "application/json", bytes.NewReader(payload))
		if err != nil {
			return nil, errors.Wrapf(err, "client POST %s failed", target)
		}
		if resp.StatusCode < 400 {
			sign := new(api.SSHSignResponse)
			if err := readJSON(resp.Body, sign); err != nil {
				return nil, errors.Wrapf(err, "error reading %s", target)
			}
			return sign, nil
		}
		// Only the first failure may be retried.
		if attempt > 0 || !c.retryOnError(resp) {
			return nil, readError(resp.Body)
		}
	}
}
// SSHRenew sends req as JSON with POST /ssh/renew and returns the decoded
// api.SSHRenewResponse. A response with a status code of 400 or higher is
// retried at most once, if retryOnError allows it.
func (c *Client) SSHRenew(req *api.SSHRenewRequest) (*api.SSHRenewResponse, error) {
	payload, err := json.Marshal(req)
	if err != nil {
		return nil, errors.Wrap(err, "error marshaling request")
	}
	target := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/renew"})
	for attempt := 0; ; attempt++ {
		resp, err := c.client.Post(target.String(), "application/json", bytes.NewReader(payload))
		if err != nil {
			return nil, errors.Wrapf(err, "client POST %s failed", target)
		}
		if resp.StatusCode < 400 {
			renew := new(api.SSHRenewResponse)
			if err := readJSON(resp.Body, renew); err != nil {
				return nil, errors.Wrapf(err, "error reading %s", target)
			}
			return renew, nil
		}
		// Only the first failure may be retried.
		if attempt > 0 || !c.retryOnError(resp) {
			return nil, readError(resp.Body)
		}
	}
}
// SSHRekey sends req as JSON with POST /ssh/rekey and returns the decoded
// api.SSHRekeyResponse. A response with a status code of 400 or higher is
// retried at most once, if retryOnError allows it.
func (c *Client) SSHRekey(req *api.SSHRekeyRequest) (*api.SSHRekeyResponse, error) {
	payload, err := json.Marshal(req)
	if err != nil {
		return nil, errors.Wrap(err, "error marshaling request")
	}
	target := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/rekey"})
	for attempt := 0; ; attempt++ {
		resp, err := c.client.Post(target.String(), "application/json", bytes.NewReader(payload))
		if err != nil {
			return nil, errors.Wrapf(err, "client POST %s failed", target)
		}
		if resp.StatusCode < 400 {
			rekey := new(api.SSHRekeyResponse)
			if err := readJSON(resp.Body, rekey); err != nil {
				return nil, errors.Wrapf(err, "error reading %s", target)
			}
			return rekey, nil
		}
		// Only the first failure may be retried.
		if attempt > 0 || !c.retryOnError(resp) {
			return nil, readError(resp.Body)
		}
	}
}
// SSHRevoke sends req as JSON with POST /ssh/revoke and returns the decoded
// api.SSHRevokeResponse. A response with a status code of 400 or higher is
// retried at most once, if retryOnError allows it.
func (c *Client) SSHRevoke(req *api.SSHRevokeRequest) (*api.SSHRevokeResponse, error) {
	payload, err := json.Marshal(req)
	if err != nil {
		return nil, errors.Wrap(err, "error marshaling request")
	}
	target := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/revoke"})
	for attempt := 0; ; attempt++ {
		resp, err := c.client.Post(target.String(), "application/json", bytes.NewReader(payload))
		if err != nil {
			return nil, errors.Wrapf(err, "client POST %s failed", target)
		}
		if resp.StatusCode < 400 {
			revoke := new(api.SSHRevokeResponse)
			if err := readJSON(resp.Body, revoke); err != nil {
				return nil, errors.Wrapf(err, "error reading %s", target)
			}
			return revoke, nil
		}
		// Only the first failure may be retried.
		if attempt > 0 || !c.retryOnError(resp) {
			return nil, readError(resp.Body)
		}
	}
}
// SSHRoots sends GET /ssh/roots to the CA and returns the decoded
// api.SSHRootsResponse. A response with a status code of 400 or higher is
// retried at most once, if retryOnError allows it.
func (c *Client) SSHRoots() (*api.SSHRootsResponse, error) {
	target := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/roots"})
	for attempt := 0; ; attempt++ {
		resp, err := c.client.Get(target.String())
		if err != nil {
			return nil, errors.Wrapf(err, "client GET %s failed", target)
		}
		if resp.StatusCode < 400 {
			keys := new(api.SSHRootsResponse)
			if err := readJSON(resp.Body, keys); err != nil {
				return nil, errors.Wrapf(err, "error reading %s", target)
			}
			return keys, nil
		}
		// Only the first failure may be retried.
		if attempt > 0 || !c.retryOnError(resp) {
			return nil, readError(resp.Body)
		}
	}
}
// SSHFederation sends GET /ssh/federation to the CA and returns the decoded
// api.SSHRootsResponse. A response with a status code of 400 or higher is
// retried at most once, if retryOnError allows it.
func (c *Client) SSHFederation() (*api.SSHRootsResponse, error) {
	target := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/federation"})
	for attempt := 0; ; attempt++ {
		resp, err := c.client.Get(target.String())
		if err != nil {
			return nil, errors.Wrapf(err, "client GET %s failed", target)
		}
		if resp.StatusCode < 400 {
			keys := new(api.SSHRootsResponse)
			if err := readJSON(resp.Body, keys); err != nil {
				return nil, errors.Wrapf(err, "error reading %s", target)
			}
			return keys, nil
		}
		// Only the first failure may be retried.
		if attempt > 0 || !c.retryOnError(resp) {
			return nil, readError(resp.Body)
		}
	}
}
// SSHConfig sends req as JSON with POST /ssh/config to get the ssh
// configuration templates, and returns the decoded api.SSHConfigResponse. A
// response with a status code of 400 or higher is retried at most once, if
// retryOnError allows it.
func (c *Client) SSHConfig(req *api.SSHConfigRequest) (*api.SSHConfigResponse, error) {
	payload, err := json.Marshal(req)
	if err != nil {
		return nil, errors.Wrap(err, "error marshaling request")
	}
	target := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/config"})
	for attempt := 0; ; attempt++ {
		resp, err := c.client.Post(target.String(), "application/json", bytes.NewReader(payload))
		if err != nil {
			return nil, errors.Wrapf(err, "client POST %s failed", target)
		}
		if resp.StatusCode < 400 {
			cfg := new(api.SSHConfigResponse)
			if err := readJSON(resp.Body, cfg); err != nil {
				return nil, errors.Wrapf(err, "error reading %s", target)
			}
			return cfg, nil
		}
		// Only the first failure may be retried.
		if attempt > 0 || !c.retryOnError(resp) {
			return nil, readError(resp.Body)
		}
	}
}
// SSHCheckHost sends POST /ssh/check-host with a request of type
// provisioner.SSHHostCert for the given principal and token, and returns the
// decoded api.SSHCheckPrincipalResponse. A response with a status code of 400
// or higher is retried at most once, if retryOnError allows it.
func (c *Client) SSHCheckHost(principal, token string) (*api.SSHCheckPrincipalResponse, error) {
	request := &api.SSHCheckPrincipalRequest{
		Type:      provisioner.SSHHostCert,
		Principal: principal,
		Token:     token,
	}
	payload, err := json.Marshal(request)
	if err != nil {
		return nil, errs.Wrap(http.StatusInternalServerError, err, "error marshaling request",
			errs.WithMessage("Failed to marshal the check-host request"))
	}
	target := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/check-host"})
	for attempt := 0; ; attempt++ {
		resp, err := c.client.Post(target.String(), "application/json", bytes.NewReader(payload))
		if err != nil {
			return nil, errs.Wrapf(http.StatusInternalServerError, err, "client POST %s failed",
				[]any{target, errs.WithMessage("Failed to perform POST request to %s", target)}...)
		}
		if resp.StatusCode < 400 {
			check := new(api.SSHCheckPrincipalResponse)
			if err := readJSON(resp.Body, check); err != nil {
				return nil, errs.Wrapf(http.StatusInternalServerError, err, "error reading %s response",
					[]any{target, errs.WithMessage("Failed to parse response from /ssh/check-host endpoint")}...)
			}
			return check, nil
		}
		// Only the first failure may be retried.
		if attempt > 0 || !c.retryOnError(resp) {
			return nil, readError(resp.Body)
		}
	}
}
// SSHGetHosts sends GET /ssh/hosts to the CA and returns the decoded
// api.SSHGetHostsResponse. A response with a status code of 400 or higher is
// retried at most once, if retryOnError allows it.
func (c *Client) SSHGetHosts() (*api.SSHGetHostsResponse, error) {
	target := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/hosts"})
	for attempt := 0; ; attempt++ {
		resp, err := c.client.Get(target.String())
		if err != nil {
			return nil, errors.Wrapf(err, "client GET %s failed", target)
		}
		if resp.StatusCode < 400 {
			hosts := new(api.SSHGetHostsResponse)
			if err := readJSON(resp.Body, hosts); err != nil {
				return nil, errors.Wrapf(err, "error reading %s", target)
			}
			return hosts, nil
		}
		// Only the first failure may be retried.
		if attempt > 0 || !c.retryOnError(resp) {
			return nil, readError(resp.Body)
		}
	}
}
// SSHBastion sends the given request as JSON in a POST /ssh/bastion request to
// the CA and returns the decoded response.
//
// A response with a status code of 400 or above is passed to retryOnError at
// most once: if it reports true the request is sent again, otherwise the error
// decoded from the response body is returned.
func (c *Client) SSHBastion(req *api.SSHBastionRequest) (*api.SSHBastionResponse, error) {
	payload, err := json.Marshal(req)
	if err != nil {
		return nil, errors.Wrap(err, "client.SSHBastion; error marshaling request")
	}

	u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/bastion"})
	target := u.String()

	// attempt 0 is the initial request; only that one may trigger a retry.
	for attempt := 0; ; attempt++ {
		resp, err := c.client.Post(target, "application/json", bytes.NewReader(payload))
		if err != nil {
			return nil, errors.Wrapf(err, "client.SSHBastion; client POST %s failed", u)
		}

		if resp.StatusCode < 400 {
			var bastion api.SSHBastionResponse
			if err := readJSON(resp.Body, &bastion); err != nil {
				return nil, errors.Wrapf(err, "client.SSHBastion; error reading %s", u)
			}
			return &bastion, nil
		}

		if attempt > 0 || !c.retryOnError(resp) {
			return nil, readError(resp.Body)
		}
	}
}
// RootFingerprint is a helper method that returns the fingerprint of the
// current root certificate. It performs a GET /health request and fingerprints
// the last certificate of the last TLS verified chain of that connection.
//
// An error is returned if the request fails or if the response carries no
// verified chains.
func (c *Client) RootFingerprint() (string, error) {
	u := c.endpoint.ResolveReference(&url.URL{Path: "/health"})
	resp, err := c.client.Get(u.String())
	if err != nil {
		return "", errors.Wrapf(err, "client GET %s failed", u)
	}
	defer resp.Body.Close()

	state := resp.TLS
	if state == nil || len(state.VerifiedChains) == 0 {
		return "", errors.New("missing verified chains")
	}

	chains := state.VerifiedChains
	chain := chains[len(chains)-1]
	if len(chain) == 0 {
		return "", errors.New("missing verified chains")
	}

	// The final certificate of the chain is the one that gets fingerprinted.
	root := chain[len(chain)-1]
	return x509util.Fingerprint(root), nil
}
// CreateSignRequest is a helper function that, given an x509 OTT, returns a
// simple sign request together with the private key used to sign its CSR.
//
// The CSR subject common name is the token subject, and its SANs are the SANs
// listed in the token claims plus the email claim when present. A new ECDSA
// P-256 key is generated on every call.
func CreateSignRequest(ott string) (*api.SignRequest, crypto.PrivateKey, error) {
	tok, err := jwt.ParseSigned(ott)
	if err != nil {
		return nil, nil, errors.Wrap(err, "error parsing ott")
	}

	// The claims are decoded without verifying the token signature; they are
	// only used here to fill in the CSR.
	var claims authority.Claims
	if err = tok.UnsafeClaimsWithoutVerification(&claims); err != nil {
		return nil, nil, errors.Wrap(err, "error parsing ott")
	}

	signer, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		return nil, nil, errors.Wrap(err, "error generating key")
	}

	dnsNames, ips, emails, uris := x509util.SplitSANs(claims.SANs)
	if claims.Email != "" {
		emails = append(emails, claims.Email)
	}

	der, err := x509.CreateCertificateRequest(rand.Reader, &x509.CertificateRequest{
		Subject:            pkix.Name{CommonName: claims.Subject},
		SignatureAlgorithm: x509.ECDSAWithSHA256,
		DNSNames:           dnsNames,
		IPAddresses:        ips,
		EmailAddresses:     emails,
		URIs:               uris,
	}, signer)
	if err != nil {
		return nil, nil, errors.Wrap(err, "error creating certificate request")
	}

	// Parse the DER back and check its signature before handing it out.
	parsed, err := x509.ParseCertificateRequest(der)
	if err != nil {
		return nil, nil, errors.Wrap(err, "error parsing certificate request")
	}
	if err = parsed.CheckSignature(); err != nil {
		return nil, nil, errors.Wrap(err, "error signing certificate request")
	}

	signReq := &api.SignRequest{
		CsrPEM: api.CertificateRequest{CertificateRequest: parsed},
		OTT:    ott,
	}
	return signReq, signer, nil
}
// CreateCertificateRequest creates a new CSR with the given common name and
// SANs, signed with a freshly generated default key. If no SAN is provided,
// the common name is also used as the only SAN. The generated key is returned
// alongside the CSR.
func CreateCertificateRequest(commonName string, sans ...string) (*api.CertificateRequest, crypto.PrivateKey, error) {
	signer, err := keyutil.GenerateDefaultKey()
	if err != nil {
		return nil, nil, err
	}

	return createCertificateRequest(commonName, sans, signer)
}
// CreateIdentityRequest returns a new CSR to create the identity. If a default
// identity can be loaded and its key can be read, that private key is reused;
// in every other case a new key is generated through CreateCertificateRequest.
func CreateIdentityRequest(commonName string, sans ...string) (*api.CertificateRequest, crypto.PrivateKey, error) {
	// Failing to load the identity is not an error here, it only means there
	// is no key to reuse.
	current, err := identity.LoadDefaultIdentity()
	if err != nil || current.Key == "" {
		return CreateCertificateRequest(commonName, sans...)
	}

	existingKey, err := pemutil.Read(current.Key)
	if err != nil || existingKey == nil {
		return CreateCertificateRequest(commonName, sans...)
	}

	return createCertificateRequest(commonName, sans, existingKey)
}
// LoadDefaultIdentity is a wrapper for identity.LoadDefaultIdentity; it
// returns that function's results unchanged.
func LoadDefaultIdentity() (*identity.Identity, error) {
	id, err := identity.LoadDefaultIdentity()
	return id, err
}
// WriteDefaultIdentity is a wrapper for identity.WriteDefaultIdentity; it
// passes the certificate chain and key through and returns the resulting
// error, if any.
func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) error {
	if err := identity.WriteDefaultIdentity(certChain, key); err != nil {
		return err
	}
	return nil
}
// createCertificateRequest builds a CSR for commonName signed with key. When
// sans is empty the common name is used as the only SAN. The CSR is parsed
// back and its signature checked before it is returned together with the same
// key that was passed in.
func createCertificateRequest(commonName string, sans []string, key crypto.PrivateKey) (*api.CertificateRequest, crypto.PrivateKey, error) {
	names := sans
	if len(names) == 0 {
		names = []string{commonName}
	}
	dnsNames, ips, emails, uris := x509util.SplitSANs(names)

	var tmpl x509.CertificateRequest
	tmpl.Subject = pkix.Name{CommonName: commonName}
	tmpl.DNSNames = dnsNames
	tmpl.IPAddresses = ips
	tmpl.EmailAddresses = emails
	tmpl.URIs = uris

	der, err := x509.CreateCertificateRequest(rand.Reader, &tmpl, key)
	if err != nil {
		return nil, nil, err
	}

	parsed, err := x509.ParseCertificateRequest(der)
	if err != nil {
		return nil, nil, err
	}
	if err = parsed.CheckSignature(); err != nil {
		return nil, nil, err
	}

	return &api.CertificateRequest{CertificateRequest: parsed}, key, nil
}
// getRootCAPath returns the path where the root CA certificate is stored:
// certs/root_ca.crt under the directory reported by step.Path().
func getRootCAPath() string {
	base := step.Path()
	return filepath.Join(base, "certs", "root_ca.crt")
}
// readJSON decodes a single JSON value from r into v. The reader is always
// closed before returning, whether or not decoding succeeded.
func readJSON(r io.ReadCloser, v interface{}) error {
	defer r.Close()

	dec := json.NewDecoder(r)
	return dec.Decode(v)
}
// readProtoJSON reads all of r and unmarshals it into m using protojson. The
// reader is always closed before returning; a read failure is returned as is
// without attempting to unmarshal.
func readProtoJSON(r io.ReadCloser, m proto.Message) error {
	defer r.Close()

	payload, err := io.ReadAll(r)
	if err != nil {
		return err
	}

	return protojson.Unmarshal(payload, m)
}
// readError decodes the JSON body in r into an errs.Error and returns it as
// the error value. If the body cannot be decoded, the decoding error is
// returned instead. The reader is always closed before returning.
func readError(r io.ReadCloser) error {
	defer r.Close()

	var apiErr errs.Error
	dec := json.NewDecoder(r)
	if err := dec.Decode(&apiErr); err != nil {
		return err
	}

	return &apiErr
}