Merge branch 'master' into wire-acme-extensions

pull/1666/head
Herman Slatman 3 months ago
commit 364566bb01
No known key found for this signature in database
GPG Key ID: F4D8A44EA0A75A4F

@ -1,49 +1,62 @@
# Step Certificates # step-ca
`step-ca` is an online certificate authority for secure, automated certificate management. It's the server counterpart to the [`step` CLI tool](https://github.com/smallstep/cli). [![GitHub release](https://img.shields.io/github/release/smallstep/certificates.svg)](https://github.com/smallstep/certificates/releases/latest)
[![Go Report Card](https://goreportcard.com/badge/github.com/smallstep/certificates)](https://goreportcard.com/report/github.com/smallstep/certificates)
[![Build Status](https://github.com/smallstep/certificates/actions/workflows/test.yml/badge.svg)](https://github.com/smallstep/certificates)
[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
[![CLA assistant](https://cla-assistant.io/readme/badge/smallstep/certificates)](https://cla-assistant.io/smallstep/certificates)
You can use it to: `step-ca` is an online certificate authority for secure, automated certificate management for DevOps.
- Issue X.509 certificates for your internal infrastructure: It's the server counterpart to the [`step` CLI tool](https://github.com/smallstep/cli) for working with certificates and keys.
- HTTPS certificates that [work in browsers](https://smallstep.com/blog/step-v0-8-6-valid-HTTPS-certificates-for-dev-pre-prod.html) ([RFC5280](https://tools.ietf.org/html/rfc5280) and [CA/Browser Forum](https://cabforum.org/baseline-requirements-documents/) compliance) Both projects are maintained by [Smallstep Labs](https://smallstep.com).
- TLS certificates for VMs, containers, APIs, mobile clients, database connections, printers, wifi networks, toaster ovens...
- Client certificates to [enable mutual TLS (mTLS)](https://smallstep.com/hello-mtls) in your infra. mTLS is an optional feature in TLS where both client and server authenticate each other. Why add the complexity of a VPN when you can safely use mTLS over the public internet? You can use `step-ca` to:
- Issue HTTPS server and client certificates that [work in browsers](https://smallstep.com/blog/step-v0-8-6-valid-HTTPS-certificates-for-dev-pre-prod.html) ([RFC5280](https://tools.ietf.org/html/rfc5280) and [CA/Browser Forum](https://cabforum.org/baseline-requirements-documents/) compliance)
- Issue TLS certificates for DevOps: VMs, containers, APIs, database connections, Kubernetes pods...
- Issue SSH certificates: - Issue SSH certificates:
- For people, in exchange for single sign-on ID tokens - For people, in exchange for single sign-on identity tokens
- For hosts, in exchange for cloud instance identity documents - For hosts, in exchange for cloud instance identity documents
- Easily automate certificate management: - Easily automate certificate management:
- It's an ACME v2 server - It's an [ACME server](https://smallstep.com/docs/step-ca/acme-basics/) that supports all [popular ACME challenge types](https://smallstep.com/docs/step-ca/acme-basics/#acme-challenge-types)
- It has a JSON API
- It comes with a [Go wrapper](./examples#user-content-basic-client-usage) - It comes with a [Go wrapper](./examples#user-content-basic-client-usage)
- ... and there's a [command-line client](https://github.com/smallstep/cli) you can use in scripts! - ... and there's a [command-line client](https://github.com/smallstep/cli) you can use in scripts!
Whatever your use case, `step-ca` is easy to use and hard to misuse, thanks to [safe, sane defaults](https://smallstep.com/docs/step-ca/certificate-authority-server-production#sane-cryptographic-defaults).
--- ---
**Don't want to run your own CA?** ### Comparison with Smallstep's commercial product
To get up and running quickly, or as an alternative to running your own `step-ca` server, consider creating a [free hosted smallstep Certificate Manager authority](https://info.smallstep.com/certificate-manager-early-access-mvp/).
`step-ca` is optimized for a two-tier PKI serving common DevOps use cases.
As you design your PKI, if you need any of the following, [consider our commerical CA](http://smallstep.com):
- Multiple certificate authorities
- Active revocation (CRL, OSCP)
- Turnkey high-volume, high availability CA
- An API for seamless IaC management of your PKI
- Integrated support for SCEP & NDES, for migrating from legacy Active Directory Certificate Services deployments
- Device identity — cross-platform device inventory and attestation using Secure Enclave & TPM 2.0
- Highly automated PKI — managed certificate renewal, monitoring, TPM-based attested enrollment
- Seamless client deployments of EAP-TLS Wi-Fi, VPN, SSH, and browser certificates
- Jamf, Intune, or other MDM for root distribution and client enrollment
- Web Admin UI — history, issuance, and metrics
- ACME External Account Binding (EAB)
- Deep integration with an identity provider
- Fine-grained, role-based access control
- FIPS-compliant software
- HSM-bound private keys
See our [full feature comparison](https://smallstep.com/step-ca-vs-smallstep-certificate-manager/) for more.
You can [start a free trial](https://smallstep.com/signup) or [set up a call with us](https://go.smallstep.com/request-demo) to learn more.
--- ---
**Questions? Find us in [Discussions](https://github.com/smallstep/certificates/discussions) or [Join our Discord](https://u.step.sm/discord).** **Questions? Find us in [Discussions](https://github.com/smallstep/certificates/discussions) or [Join our Discord](https://u.step.sm/discord).**
[Website](https://smallstep.com/certificates) | [Website](https://smallstep.com/certificates) |
[Documentation](https://smallstep.com/docs) | [Documentation](https://smallstep.com/docs/step-ca) |
[Installation](https://smallstep.com/docs/step-ca/installation) | [Installation](https://smallstep.com/docs/step-ca/installation) |
[Getting Started](https://smallstep.com/docs/step-ca/getting-started) |
[Contributor's Guide](./CONTRIBUTING.md) [Contributor's Guide](./CONTRIBUTING.md)
[![GitHub release](https://img.shields.io/github/release/smallstep/certificates.svg)](https://github.com/smallstep/certificates/releases/latest)
[![Go Report Card](https://goreportcard.com/badge/github.com/smallstep/certificates)](https://goreportcard.com/report/github.com/smallstep/certificates)
[![Build Status](https://github.com/smallstep/certificates/actions/workflows/test.yml/badge.svg)](https://github.com/smallstep/certificates)
[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
[![CLA assistant](https://cla-assistant.io/readme/badge/smallstep/certificates)](https://cla-assistant.io/smallstep/certificates)
[![GitHub stars](https://img.shields.io/github/stars/smallstep/certificates.svg?style=social)](https://github.com/smallstep/certificates/stargazers)
[![Twitter followers](https://img.shields.io/twitter/follow/smallsteplabs.svg?label=Follow&style=social)](https://twitter.com/intent/follow?screen_name=smallsteplabs)
![star us](https://github.com/smallstep/certificates/raw/master/docs/images/star.gif)
## Features ## Features
### 🦾 A fast, stable, flexible private CA ### 🦾 A fast, stable, flexible private CA
@ -52,7 +65,6 @@ Setting up a *public key infrastructure* (PKI) is out of reach for many small te
- Choose key types (RSA, ECDSA, EdDSA) and lifetimes to suit your needs - Choose key types (RSA, ECDSA, EdDSA) and lifetimes to suit your needs
- [Short-lived certificates](https://smallstep.com/blog/passive-revocation.html) with automated enrollment, renewal, and passive revocation - [Short-lived certificates](https://smallstep.com/blog/passive-revocation.html) with automated enrollment, renewal, and passive revocation
- Capable of high availability (HA) deployment using [root federation](https://smallstep.com/blog/step-v0.8.3-federation-root-rotation.html) and/or multiple intermediaries
- Can operate as [an online intermediate CA for an existing root CA](https://smallstep.com/docs/tutorials/intermediate-ca-new-ca) - Can operate as [an online intermediate CA for an existing root CA](https://smallstep.com/docs/tutorials/intermediate-ca-new-ca)
- [Badger, BoltDB, Postgres, and MySQL database backends](https://smallstep.com/docs/step-ca/configuration#databases) - [Badger, BoltDB, Postgres, and MySQL database backends](https://smallstep.com/docs/step-ca/configuration#databases)
@ -127,5 +139,5 @@ and visiting http://localhost:8080.
## Feedback? ## Feedback?
* Tell us what you like and don't like about managing your PKI - we're eager to help solve problems in this space. * Tell us what you like and don't like about managing your PKI - we're eager to help solve problems in this space. [Join our Discord](https://u.step.sm/discord) or [GitHub Discussions](https://github.com/smallstep/certificates/discussions)
* Tell us about a feature you'd like to see! [Add a feature request Issue](https://github.com/smallstep/certificates/issues/new?assignees=&labels=enhancement%2C+needs+triage&template=enhancement.md&title=), [ask on Discussions](https://github.com/smallstep/certificates/discussions), or hit us up on [Twitter](https://twitter.com/smallsteplabs). * Tell us about a feature you'd like to see! [Request a Feature](https://github.com/smallstep/certificates/issues/new?assignees=&labels=enhancement%2C+needs+triage&template=enhancement.md&title=)

@ -281,7 +281,7 @@ type mockCA struct {
MockAreSANsallowed func(ctx context.Context, sans []string) error MockAreSANsallowed func(ctx context.Context, sans []string) error
} }
func (m *mockCA) Sign(*x509.CertificateRequest, provisioner.SignOptions, ...provisioner.SignOption) ([]*x509.Certificate, error) { func (m *mockCA) SignWithContext(context.Context, *x509.CertificateRequest, provisioner.SignOptions, ...provisioner.SignOption) ([]*x509.Certificate, error) {
return nil, nil return nil, nil
} }

@ -21,7 +21,7 @@ var clock Clock
// CertificateAuthority is the interface implemented by a CA authority. // CertificateAuthority is the interface implemented by a CA authority.
type CertificateAuthority interface { type CertificateAuthority interface {
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) SignWithContext(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
AreSANsAllowed(ctx context.Context, sans []string) error AreSANsAllowed(ctx context.Context, sans []string) error
IsRevoked(sn string) (bool, error) IsRevoked(sn string) (bool, error)
Revoke(context.Context, *authority.RevokeOptions) error Revoke(context.Context, *authority.RevokeOptions) error

@ -295,7 +295,7 @@ func (o *Order) Finalize(ctx context.Context, db DB, csr *x509.CertificateReques
signOps = append(signOps, extraOptions...) signOps = append(signOps, extraOptions...)
// Sign a new certificate. // Sign a new certificate.
certChain, err := auth.Sign(csr, provisioner.SignOptions{ certChain, err := auth.SignWithContext(ctx, csr, provisioner.SignOptions{
NotBefore: provisioner.NewTimeDuration(o.NotBefore), NotBefore: provisioner.NewTimeDuration(o.NotBefore),
NotAfter: provisioner.NewTimeDuration(o.NotAfter), NotAfter: provisioner.NewTimeDuration(o.NotAfter),
}, signOps...) }, signOps...)

@ -270,16 +270,16 @@ func TestOrder_UpdateStatus(t *testing.T) {
} }
type mockSignAuth struct { type mockSignAuth struct {
sign func(csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) signWithContext func(ctx context.Context, csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
areSANsAllowed func(ctx context.Context, sans []string) error areSANsAllowed func(ctx context.Context, sans []string) error
loadProvisionerByName func(string) (provisioner.Interface, error) loadProvisionerByName func(string) (provisioner.Interface, error)
ret1, ret2 interface{} ret1, ret2 interface{}
err error err error
} }
func (m *mockSignAuth) Sign(csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { func (m *mockSignAuth) SignWithContext(ctx context.Context, csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
if m.sign != nil { if m.signWithContext != nil {
return m.sign(csr, signOpts, extraOpts...) return m.signWithContext(ctx, csr, signOpts, extraOpts...)
} else if m.err != nil { } else if m.err != nil {
return nil, m.err return nil, m.err
} }
@ -577,7 +577,7 @@ func TestOrder_Finalize(t *testing.T) {
}, },
}, },
ca: &mockSignAuth{ ca: &mockSignAuth{
sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
assert.Equals(t, _csr, csr) assert.Equals(t, _csr, csr)
return nil, errors.New("force") return nil, errors.New("force")
}, },
@ -627,7 +627,7 @@ func TestOrder_Finalize(t *testing.T) {
}, },
}, },
ca: &mockSignAuth{ ca: &mockSignAuth{
sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
assert.Equals(t, _csr, csr) assert.Equals(t, _csr, csr)
return []*x509.Certificate{foo, bar, baz}, nil return []*x509.Certificate{foo, bar, baz}, nil
}, },
@ -684,7 +684,7 @@ func TestOrder_Finalize(t *testing.T) {
}, },
}, },
ca: &mockSignAuth{ ca: &mockSignAuth{
sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
assert.Equals(t, _csr, csr) assert.Equals(t, _csr, csr)
return []*x509.Certificate{foo, bar, baz}, nil return []*x509.Certificate{foo, bar, baz}, nil
}, },
@ -769,7 +769,7 @@ func TestOrder_Finalize(t *testing.T) {
}, },
}, },
ca: &mockSignAuth{ ca: &mockSignAuth{
sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
assert.Equals(t, _csr, csr) assert.Equals(t, _csr, csr)
return []*x509.Certificate{leaf, inter, root}, nil return []*x509.Certificate{leaf, inter, root}, nil
}, },
@ -862,7 +862,7 @@ func TestOrder_Finalize(t *testing.T) {
}, },
}, },
ca: &mockSignAuth{ ca: &mockSignAuth{
sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
assert.Equals(t, _csr, csr) assert.Equals(t, _csr, csr)
return []*x509.Certificate{leaf, inter, root}, nil return []*x509.Certificate{leaf, inter, root}, nil
}, },
@ -972,7 +972,7 @@ func TestOrder_Finalize(t *testing.T) {
// using the mocking functions as a wrapper for actual test helpers generated per test case or per // using the mocking functions as a wrapper for actual test helpers generated per test case or per
// function that's tested. // function that's tested.
ca: &mockSignAuth{ ca: &mockSignAuth{
sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
assert.Equals(t, _csr, csr) assert.Equals(t, _csr, csr)
return []*x509.Certificate{leaf, inter, root}, nil return []*x509.Certificate{leaf, inter, root}, nil
}, },
@ -1043,7 +1043,7 @@ func TestOrder_Finalize(t *testing.T) {
}, },
}, },
ca: &mockSignAuth{ ca: &mockSignAuth{
sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
assert.Equals(t, _csr, csr) assert.Equals(t, _csr, csr)
return []*x509.Certificate{foo, bar, baz}, nil return []*x509.Certificate{foo, bar, baz}, nil
}, },
@ -1107,7 +1107,7 @@ func TestOrder_Finalize(t *testing.T) {
}, },
}, },
ca: &mockSignAuth{ ca: &mockSignAuth{
sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
assert.Equals(t, _csr, csr) assert.Equals(t, _csr, csr)
return []*x509.Certificate{foo, bar, baz}, nil return []*x509.Certificate{foo, bar, baz}, nil
}, },
@ -1174,7 +1174,7 @@ func TestOrder_Finalize(t *testing.T) {
}, },
}, },
ca: &mockSignAuth{ ca: &mockSignAuth{
sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
assert.Equals(t, _csr, csr) assert.Equals(t, _csr, csr)
return []*x509.Certificate{foo, bar, baz}, nil return []*x509.Certificate{foo, bar, baz}, nil
}, },

@ -42,7 +42,7 @@ type Authority interface {
AuthorizeRenewToken(ctx context.Context, ott string) (*x509.Certificate, error) AuthorizeRenewToken(ctx context.Context, ott string) (*x509.Certificate, error)
GetTLSOptions() *config.TLSOptions GetTLSOptions() *config.TLSOptions
Root(shasum string) (*x509.Certificate, error) Root(shasum string) (*x509.Certificate, error)
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) SignWithContext(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
Renew(peer *x509.Certificate) ([]*x509.Certificate, error) Renew(peer *x509.Certificate) ([]*x509.Certificate, error)
RenewContext(ctx context.Context, peer *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) RenewContext(ctx context.Context, peer *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
Rekey(peer *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) Rekey(peer *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)

@ -189,7 +189,7 @@ type mockAuthority struct {
authorizeRenewToken func(ctx context.Context, ott string) (*x509.Certificate, error) authorizeRenewToken func(ctx context.Context, ott string) (*x509.Certificate, error)
getTLSOptions func() *authority.TLSOptions getTLSOptions func() *authority.TLSOptions
root func(shasum string) (*x509.Certificate, error) root func(shasum string) (*x509.Certificate, error)
sign func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) signWithContext func(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
renew func(cert *x509.Certificate) ([]*x509.Certificate, error) renew func(cert *x509.Certificate) ([]*x509.Certificate, error)
rekey func(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) rekey func(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
renewContext func(ctx context.Context, oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) renewContext func(ctx context.Context, oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
@ -251,9 +251,9 @@ func (m *mockAuthority) Root(shasum string) (*x509.Certificate, error) {
return m.ret1.(*x509.Certificate), m.err return m.ret1.(*x509.Certificate), m.err
} }
func (m *mockAuthority) Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { func (m *mockAuthority) SignWithContext(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
if m.sign != nil { if m.signWithContext != nil {
return m.sign(cr, opts, signOpts...) return m.signWithContext(ctx, cr, opts, signOpts...)
} }
return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err
} }
@ -884,16 +884,12 @@ func Test_Sign(t *testing.T) {
CsrPEM: CertificateRequest{csr}, CsrPEM: CertificateRequest{csr},
OTT: "foobarzar", OTT: "foobarzar",
}) })
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
invalid, err := json.Marshal(SignRequest{ invalid, err := json.Marshal(SignRequest{
CsrPEM: CertificateRequest{csr}, CsrPEM: CertificateRequest{csr},
OTT: "", OTT: "",
}) })
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
expected1 := []byte(`{"crt":"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n","certChain":["` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`) expected1 := []byte(`{"crt":"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n","certChain":["` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`)
expected2 := []byte(`{"crt":"` + strings.ReplaceAll(stepCertPEM, "\n", `\n`) + `\n","ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n","certChain":["` + strings.ReplaceAll(stepCertPEM, "\n", `\n`) + `\n","` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`) expected2 := []byte(`{"crt":"` + strings.ReplaceAll(stepCertPEM, "\n", `\n`) + `\n","ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n","certChain":["` + strings.ReplaceAll(stepCertPEM, "\n", `\n`) + `\n","` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`)

@ -78,7 +78,7 @@ func Sign(w http.ResponseWriter, r *http.Request) {
return return
} }
certChain, err := a.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...) certChain, err := a.SignWithContext(ctx, body.CsrPEM.CertificateRequest, opts, signOpts...)
if err != nil { if err != nil {
render.Error(w, errs.ForbiddenErr(err, "error signing certificate")) render.Error(w, errs.ForbiddenErr(err, "error signing certificate"))
return return

@ -330,7 +330,7 @@ func SSHSign(w http.ResponseWriter, r *http.Request) {
NotAfter: time.Unix(int64(cert.ValidBefore), 0), NotAfter: time.Unix(int64(cert.ValidBefore), 0),
}) })
certChain, err := a.Sign(cr, provisioner.SignOptions{}, signOpts...) certChain, err := a.SignWithContext(ctx, cr, provisioner.SignOptions{}, signOpts...)
if err != nil { if err != nil {
render.Error(w, errs.ForbiddenErr(err, "error signing identity certificate")) render.Error(w, errs.ForbiddenErr(err, "error signing identity certificate"))
return return

@ -325,7 +325,7 @@ func Test_SSHSign(t *testing.T) {
signSSHAddUser: func(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) { signSSHAddUser: func(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) {
return tt.addUserCert, tt.addUserErr return tt.addUserCert, tt.addUserErr
}, },
sign: func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { signWithContext: func(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
return tt.tlsSignCerts, tt.tlsSignErr return tt.tlsSignCerts, tt.tlsSignErr
}, },
}) })

@ -1,6 +1,7 @@
package authority package authority
import ( import (
"context"
"crypto" "crypto"
"crypto/rand" "crypto/rand"
"crypto/sha256" "crypto/sha256"
@ -414,7 +415,7 @@ func TestNewEmbedded_Sign(t *testing.T) {
csr, err := x509.ParseCertificateRequest(cr) csr, err := x509.ParseCertificateRequest(cr)
assert.FatalError(t, err) assert.FatalError(t, err)
cert, err := a.Sign(csr, provisioner.SignOptions{}) cert, err := a.SignWithContext(context.Background(), csr, provisioner.SignOptions{})
assert.FatalError(t, err) assert.FatalError(t, err)
assert.Equals(t, []string{"foo.bar.zar"}, cert[0].DNSNames) assert.Equals(t, []string{"foo.bar.zar"}, cert[0].DNSNames)
assert.Equals(t, crt, cert[1]) assert.Equals(t, crt, cert[1])

@ -1375,7 +1375,7 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) {
} }
generateX5cToken := func(a *Authority, key crypto.Signer, claims jose.Claims, opts ...provisioner.SignOption) (string, *x509.Certificate) { generateX5cToken := func(a *Authority, key crypto.Signer, claims jose.Claims, opts ...provisioner.SignOption) (string, *x509.Certificate) {
chain, err := a.Sign(csr, provisioner.SignOptions{}, opts...) chain, err := a.SignWithContext(ctx, csr, provisioner.SignOptions{}, opts...)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

@ -15,6 +15,7 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/internal/requestid"
"github.com/smallstep/certificates/templates" "github.com/smallstep/certificates/templates"
"github.com/smallstep/certificates/webhook" "github.com/smallstep/certificates/webhook"
"go.step.sm/linkedca" "go.step.sm/linkedca"
@ -36,7 +37,7 @@ type WebhookController struct {
// Enrich fetches data from remote servers and adds returned data to the // Enrich fetches data from remote servers and adds returned data to the
// templateData // templateData
func (wc *WebhookController) Enrich(req *webhook.RequestBody) error { func (wc *WebhookController) Enrich(ctx context.Context, req *webhook.RequestBody) error {
if wc == nil { if wc == nil {
return nil return nil
} }
@ -55,7 +56,11 @@ func (wc *WebhookController) Enrich(req *webhook.RequestBody) error {
if !wc.isCertTypeOK(wh) { if !wc.isCertTypeOK(wh) {
continue continue
} }
resp, err := wh.Do(wc.client, req, wc.TemplateData)
whCtx, cancel := context.WithTimeout(ctx, time.Second*10)
defer cancel() //nolint:gocritic // every request canceled with its own timeout
resp, err := wh.DoWithContext(whCtx, wc.client, req, wc.TemplateData)
if err != nil { if err != nil {
return err return err
} }
@ -68,7 +73,7 @@ func (wc *WebhookController) Enrich(req *webhook.RequestBody) error {
} }
// Authorize checks that all remote servers allow the request // Authorize checks that all remote servers allow the request
func (wc *WebhookController) Authorize(req *webhook.RequestBody) error { func (wc *WebhookController) Authorize(ctx context.Context, req *webhook.RequestBody) error {
if wc == nil { if wc == nil {
return nil return nil
} }
@ -87,7 +92,11 @@ func (wc *WebhookController) Authorize(req *webhook.RequestBody) error {
if !wc.isCertTypeOK(wh) { if !wc.isCertTypeOK(wh) {
continue continue
} }
resp, err := wh.Do(wc.client, req, wc.TemplateData)
whCtx, cancel := context.WithTimeout(ctx, time.Second*10)
defer cancel() //nolint:gocritic // every request canceled with its own timeout
resp, err := wh.DoWithContext(whCtx, wc.client, req, wc.TemplateData)
if err != nil { if err != nil {
return err return err
} }
@ -123,13 +132,6 @@ type Webhook struct {
} `json:"-"` } `json:"-"`
} }
func (w *Webhook) Do(client *http.Client, reqBody *webhook.RequestBody, data any) (*webhook.ResponseBody, error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
return w.DoWithContext(ctx, client, reqBody, data)
}
func (w *Webhook) DoWithContext(ctx context.Context, client *http.Client, reqBody *webhook.RequestBody, data any) (*webhook.ResponseBody, error) { func (w *Webhook) DoWithContext(ctx context.Context, client *http.Client, reqBody *webhook.RequestBody, data any) (*webhook.ResponseBody, error) {
tmpl, err := template.New("url").Funcs(templates.StepFuncMap()).Parse(w.URL) tmpl, err := template.New("url").Funcs(templates.StepFuncMap()).Parse(w.URL)
if err != nil { if err != nil {
@ -169,6 +171,10 @@ retry:
return nil, err return nil, err
} }
if requestID, ok := requestid.FromContext(ctx); ok {
req.Header.Set("X-Request-Id", requestID)
}
secret, err := base64.StdEncoding.DecodeString(w.Secret) secret, err := base64.StdEncoding.DecodeString(w.Secret)
if err != nil { if err != nil {
return nil, err return nil, err

@ -1,6 +1,7 @@
package provisioner package provisioner
import ( import (
"context"
"crypto/hmac" "crypto/hmac"
"crypto/sha256" "crypto/sha256"
"crypto/tls" "crypto/tls"
@ -8,18 +9,23 @@ import (
"encoding/base64" "encoding/base64"
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/webhook"
"go.step.sm/crypto/pemutil" "go.step.sm/crypto/pemutil"
"go.step.sm/crypto/x509util" "go.step.sm/crypto/x509util"
"go.step.sm/linkedca" "go.step.sm/linkedca"
"github.com/smallstep/certificates/internal/requestid"
"github.com/smallstep/certificates/webhook"
) )
func TestWebhookController_isCertTypeOK(t *testing.T) { func TestWebhookController_isCertTypeOK(t *testing.T) {
@ -92,19 +98,25 @@ func TestWebhookController_isCertTypeOK(t *testing.T) {
} }
for name, test := range tests { for name, test := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
assert.Equals(t, test.want, test.wc.isCertTypeOK(test.wh)) assert.Equal(t, test.want, test.wc.isCertTypeOK(test.wh))
}) })
} }
} }
// withRequestID is a helper that calls into [requestid.NewContext] and returns
// a new context with the requestID added.
func withRequestID(t *testing.T, ctx context.Context, requestID string) context.Context {
t.Helper()
return requestid.NewContext(ctx, requestID)
}
func TestWebhookController_Enrich(t *testing.T) { func TestWebhookController_Enrich(t *testing.T) {
cert, err := pemutil.ReadCertificate("testdata/certs/x5c-leaf.crt", pemutil.WithFirstBlock()) cert, err := pemutil.ReadCertificate("testdata/certs/x5c-leaf.crt", pemutil.WithFirstBlock())
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
type test struct { type test struct {
ctl *WebhookController ctl *WebhookController
ctx context.Context
req *webhook.RequestBody req *webhook.RequestBody
responses []*webhook.ResponseBody responses []*webhook.ResponseBody
expectErr bool expectErr bool
@ -129,6 +141,7 @@ func TestWebhookController_Enrich(t *testing.T) {
webhooks: []*Webhook{{Name: "people", Kind: "ENRICHING"}}, webhooks: []*Webhook{{Name: "people", Kind: "ENRICHING"}},
TemplateData: x509util.TemplateData{}, TemplateData: x509util.TemplateData{},
}, },
ctx: withRequestID(t, context.Background(), "reqID"),
req: &webhook.RequestBody{}, req: &webhook.RequestBody{},
responses: []*webhook.ResponseBody{{Allow: true, Data: map[string]any{"role": "bar"}}}, responses: []*webhook.ResponseBody{{Allow: true, Data: map[string]any{"role": "bar"}}},
expectErr: false, expectErr: false,
@ -143,6 +156,7 @@ func TestWebhookController_Enrich(t *testing.T) {
}, },
TemplateData: x509util.TemplateData{}, TemplateData: x509util.TemplateData{},
}, },
ctx: withRequestID(t, context.Background(), "reqID"),
req: &webhook.RequestBody{}, req: &webhook.RequestBody{},
responses: []*webhook.ResponseBody{ responses: []*webhook.ResponseBody{
{Allow: true, Data: map[string]any{"role": "bar"}}, {Allow: true, Data: map[string]any{"role": "bar"}},
@ -166,6 +180,7 @@ func TestWebhookController_Enrich(t *testing.T) {
TemplateData: x509util.TemplateData{}, TemplateData: x509util.TemplateData{},
certType: linkedca.Webhook_X509, certType: linkedca.Webhook_X509,
}, },
ctx: withRequestID(t, context.Background(), "reqID"),
req: &webhook.RequestBody{}, req: &webhook.RequestBody{},
responses: []*webhook.ResponseBody{ responses: []*webhook.ResponseBody{
{Allow: true, Data: map[string]any{"role": "bar"}}, {Allow: true, Data: map[string]any{"role": "bar"}},
@ -185,14 +200,15 @@ func TestWebhookController_Enrich(t *testing.T) {
TemplateData: x509util.TemplateData{}, TemplateData: x509util.TemplateData{},
options: []webhook.RequestBodyOption{webhook.WithX5CCertificate(cert)}, options: []webhook.RequestBodyOption{webhook.WithX5CCertificate(cert)},
}, },
ctx: withRequestID(t, context.Background(), "reqID"),
req: &webhook.RequestBody{}, req: &webhook.RequestBody{},
responses: []*webhook.ResponseBody{{Allow: true, Data: map[string]any{"role": "bar"}}}, responses: []*webhook.ResponseBody{{Allow: true, Data: map[string]any{"role": "bar"}}},
expectErr: false, expectErr: false,
expectTemplateData: x509util.TemplateData{"Webhooks": map[string]any{"people": map[string]any{"role": "bar"}}}, expectTemplateData: x509util.TemplateData{"Webhooks": map[string]any{"people": map[string]any{"role": "bar"}}},
assertRequest: func(t *testing.T, req *webhook.RequestBody) { assertRequest: func(t *testing.T, req *webhook.RequestBody) {
key, err := x509.MarshalPKIXPublicKey(cert.PublicKey) key, err := x509.MarshalPKIXPublicKey(cert.PublicKey)
assert.FatalError(t, err) require.NoError(t, err)
assert.Equals(t, &webhook.X5CCertificate{ assert.Equal(t, &webhook.X5CCertificate{
Raw: cert.Raw, Raw: cert.Raw,
PublicKey: key, PublicKey: key,
PublicKeyAlgorithm: cert.PublicKeyAlgorithm.String(), PublicKeyAlgorithm: cert.PublicKeyAlgorithm.String(),
@ -207,6 +223,7 @@ func TestWebhookController_Enrich(t *testing.T) {
webhooks: []*Webhook{{Name: "people", Kind: "ENRICHING"}}, webhooks: []*Webhook{{Name: "people", Kind: "ENRICHING"}},
TemplateData: x509util.TemplateData{}, TemplateData: x509util.TemplateData{},
}, },
ctx: withRequestID(t, context.Background(), "reqID"),
req: &webhook.RequestBody{}, req: &webhook.RequestBody{},
responses: []*webhook.ResponseBody{{Allow: false}}, responses: []*webhook.ResponseBody{{Allow: false}},
expectErr: true, expectErr: true,
@ -221,6 +238,7 @@ func TestWebhookController_Enrich(t *testing.T) {
PublicKey: []byte("bad"), PublicKey: []byte("bad"),
})}, })},
}, },
ctx: withRequestID(t, context.Background(), "reqID"),
req: &webhook.RequestBody{}, req: &webhook.RequestBody{},
responses: []*webhook.ResponseBody{{Allow: false}}, responses: []*webhook.ResponseBody{{Allow: false}},
expectErr: true, expectErr: true,
@ -232,19 +250,21 @@ func TestWebhookController_Enrich(t *testing.T) {
for i, wh := range test.ctl.webhooks { for i, wh := range test.ctl.webhooks {
var j = i var j = i
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "reqID", r.Header.Get("X-Request-ID"))
err := json.NewEncoder(w).Encode(test.responses[j]) err := json.NewEncoder(w).Encode(test.responses[j])
assert.FatalError(t, err) require.NoError(t, err)
})) }))
// nolint: gocritic // defer in loop isn't a memory leak // nolint: gocritic // defer in loop isn't a memory leak
defer ts.Close() defer ts.Close()
wh.URL = ts.URL wh.URL = ts.URL
} }
err := test.ctl.Enrich(test.req) err := test.ctl.Enrich(test.ctx, test.req)
if (err != nil) != test.expectErr { if (err != nil) != test.expectErr {
t.Fatalf("Got err %v, want %v", err, test.expectErr) t.Fatalf("Got err %v, want %v", err, test.expectErr)
} }
assert.Equals(t, test.expectTemplateData, test.ctl.TemplateData) assert.Equal(t, test.expectTemplateData, test.ctl.TemplateData)
if test.assertRequest != nil { if test.assertRequest != nil {
test.assertRequest(t, test.req) test.assertRequest(t, test.req)
} }
@ -254,12 +274,11 @@ func TestWebhookController_Enrich(t *testing.T) {
func TestWebhookController_Authorize(t *testing.T) { func TestWebhookController_Authorize(t *testing.T) {
cert, err := pemutil.ReadCertificate("testdata/certs/x5c-leaf.crt", pemutil.WithFirstBlock()) cert, err := pemutil.ReadCertificate("testdata/certs/x5c-leaf.crt", pemutil.WithFirstBlock())
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
type test struct { type test struct {
ctl *WebhookController ctl *WebhookController
ctx context.Context
req *webhook.RequestBody req *webhook.RequestBody
responses []*webhook.ResponseBody responses []*webhook.ResponseBody
expectErr bool expectErr bool
@ -280,6 +299,7 @@ func TestWebhookController_Authorize(t *testing.T) {
client: http.DefaultClient, client: http.DefaultClient,
webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING"}}, webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING"}},
}, },
ctx: withRequestID(t, context.Background(), "reqID"),
req: &webhook.RequestBody{}, req: &webhook.RequestBody{},
responses: []*webhook.ResponseBody{{Allow: true}}, responses: []*webhook.ResponseBody{{Allow: true}},
expectErr: false, expectErr: false,
@ -290,6 +310,7 @@ func TestWebhookController_Authorize(t *testing.T) {
webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING", CertType: linkedca.Webhook_X509.String()}}, webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING", CertType: linkedca.Webhook_X509.String()}},
certType: linkedca.Webhook_SSH, certType: linkedca.Webhook_SSH,
}, },
ctx: withRequestID(t, context.Background(), "reqID"),
req: &webhook.RequestBody{}, req: &webhook.RequestBody{},
responses: []*webhook.ResponseBody{{Allow: false}}, responses: []*webhook.ResponseBody{{Allow: false}},
expectErr: false, expectErr: false,
@ -300,13 +321,14 @@ func TestWebhookController_Authorize(t *testing.T) {
webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING"}}, webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING"}},
options: []webhook.RequestBodyOption{webhook.WithX5CCertificate(cert)}, options: []webhook.RequestBodyOption{webhook.WithX5CCertificate(cert)},
}, },
ctx: withRequestID(t, context.Background(), "reqID"),
req: &webhook.RequestBody{}, req: &webhook.RequestBody{},
responses: []*webhook.ResponseBody{{Allow: true}}, responses: []*webhook.ResponseBody{{Allow: true}},
expectErr: false, expectErr: false,
assertRequest: func(t *testing.T, req *webhook.RequestBody) { assertRequest: func(t *testing.T, req *webhook.RequestBody) {
key, err := x509.MarshalPKIXPublicKey(cert.PublicKey) key, err := x509.MarshalPKIXPublicKey(cert.PublicKey)
assert.FatalError(t, err) require.NoError(t, err)
assert.Equals(t, &webhook.X5CCertificate{ assert.Equal(t, &webhook.X5CCertificate{
Raw: cert.Raw, Raw: cert.Raw,
PublicKey: key, PublicKey: key,
PublicKeyAlgorithm: cert.PublicKeyAlgorithm.String(), PublicKeyAlgorithm: cert.PublicKeyAlgorithm.String(),
@ -320,6 +342,7 @@ func TestWebhookController_Authorize(t *testing.T) {
client: http.DefaultClient, client: http.DefaultClient,
webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING"}}, webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING"}},
}, },
ctx: withRequestID(t, context.Background(), "reqID"),
req: &webhook.RequestBody{}, req: &webhook.RequestBody{},
responses: []*webhook.ResponseBody{{Allow: false}}, responses: []*webhook.ResponseBody{{Allow: false}},
expectErr: true, expectErr: true,
@ -332,6 +355,7 @@ func TestWebhookController_Authorize(t *testing.T) {
PublicKey: []byte("bad"), PublicKey: []byte("bad"),
})}, })},
}, },
ctx: withRequestID(t, context.Background(), "reqID"),
req: &webhook.RequestBody{}, req: &webhook.RequestBody{},
responses: []*webhook.ResponseBody{{Allow: false}}, responses: []*webhook.ResponseBody{{Allow: false}},
expectErr: true, expectErr: true,
@ -342,15 +366,17 @@ func TestWebhookController_Authorize(t *testing.T) {
for i, wh := range test.ctl.webhooks { for i, wh := range test.ctl.webhooks {
var j = i var j = i
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "reqID", r.Header.Get("X-Request-ID"))
err := json.NewEncoder(w).Encode(test.responses[j]) err := json.NewEncoder(w).Encode(test.responses[j])
assert.FatalError(t, err) require.NoError(t, err)
})) }))
// nolint: gocritic // defer in loop isn't a memory leak // nolint: gocritic // defer in loop isn't a memory leak
defer ts.Close() defer ts.Close()
wh.URL = ts.URL wh.URL = ts.URL
} }
err := test.ctl.Authorize(test.req) err := test.ctl.Authorize(test.ctx, test.req)
if (err != nil) != test.expectErr { if (err != nil) != test.expectErr {
t.Fatalf("Got err %v, want %v", err, test.expectErr) t.Fatalf("Got err %v, want %v", err, test.expectErr)
} }
@ -366,6 +392,7 @@ func TestWebhook_Do(t *testing.T) {
type test struct { type test struct {
webhook Webhook webhook Webhook
dataArg any dataArg any
requestID string
webhookResponse webhook.ResponseBody webhookResponse webhook.ResponseBody
expectPath string expectPath string
errStatusCode int errStatusCode int
@ -375,6 +402,16 @@ func TestWebhook_Do(t *testing.T) {
} }
tests := map[string]test{ tests := map[string]test{
"ok": { "ok": {
webhook: Webhook{
ID: "abc123",
Secret: "c2VjcmV0Cg==",
},
requestID: "reqID",
webhookResponse: webhook.ResponseBody{
Data: map[string]interface{}{"role": "dba"},
},
},
"ok/no-request-id": {
webhook: Webhook{ webhook: Webhook{
ID: "abc123", ID: "abc123",
Secret: "c2VjcmV0Cg==", Secret: "c2VjcmV0Cg==",
@ -389,6 +426,7 @@ func TestWebhook_Do(t *testing.T) {
Secret: "c2VjcmV0Cg==", Secret: "c2VjcmV0Cg==",
BearerToken: "mytoken", BearerToken: "mytoken",
}, },
requestID: "reqID",
webhookResponse: webhook.ResponseBody{ webhookResponse: webhook.ResponseBody{
Data: map[string]interface{}{"role": "dba"}, Data: map[string]interface{}{"role": "dba"},
}, },
@ -405,6 +443,7 @@ func TestWebhook_Do(t *testing.T) {
Password: "mypass", Password: "mypass",
}, },
}, },
requestID: "reqID",
webhookResponse: webhook.ResponseBody{ webhookResponse: webhook.ResponseBody{
Data: map[string]interface{}{"role": "dba"}, Data: map[string]interface{}{"role": "dba"},
}, },
@ -416,7 +455,8 @@ func TestWebhook_Do(t *testing.T) {
URL: "/users/{{ .username }}?region={{ .region }}", URL: "/users/{{ .username }}?region={{ .region }}",
Secret: "c2VjcmV0Cg==", Secret: "c2VjcmV0Cg==",
}, },
dataArg: map[string]interface{}{"username": "areed", "region": "central"}, requestID: "reqID",
dataArg: map[string]interface{}{"username": "areed", "region": "central"},
webhookResponse: webhook.ResponseBody{ webhookResponse: webhook.ResponseBody{
Data: map[string]interface{}{"role": "dba"}, Data: map[string]interface{}{"role": "dba"},
}, },
@ -451,6 +491,7 @@ func TestWebhook_Do(t *testing.T) {
ID: "abc123", ID: "abc123",
Secret: "c2VjcmV0Cg==", Secret: "c2VjcmV0Cg==",
}, },
requestID: "reqID",
webhookResponse: webhook.ResponseBody{ webhookResponse: webhook.ResponseBody{
Allow: true, Allow: true,
}, },
@ -463,6 +504,7 @@ func TestWebhook_Do(t *testing.T) {
webhookResponse: webhook.ResponseBody{ webhookResponse: webhook.ResponseBody{
Data: map[string]interface{}{"role": "dba"}, Data: map[string]interface{}{"role": "dba"},
}, },
requestID: "reqID",
errStatusCode: 404, errStatusCode: 404,
serverErrMsg: "item not found", serverErrMsg: "item not found",
expectErr: errors.New("Webhook server responded with 404"), expectErr: errors.New("Webhook server responded with 404"),
@ -471,17 +513,20 @@ func TestWebhook_Do(t *testing.T) {
for name, tc := range tests { for name, tc := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
id := r.Header.Get("X-Smallstep-Webhook-ID") if tc.requestID != "" {
assert.Equals(t, tc.webhook.ID, id) assert.Equal(t, tc.requestID, r.Header.Get("X-Request-ID"))
}
assert.Equal(t, tc.webhook.ID, r.Header.Get("X-Smallstep-Webhook-ID"))
sig, err := hex.DecodeString(r.Header.Get("X-Smallstep-Signature")) sig, err := hex.DecodeString(r.Header.Get("X-Smallstep-Signature"))
assert.FatalError(t, err) assert.NoError(t, err)
body, err := io.ReadAll(r.Body) body, err := io.ReadAll(r.Body)
assert.FatalError(t, err) assert.NoError(t, err)
secret, err := base64.StdEncoding.DecodeString(tc.webhook.Secret) secret, err := base64.StdEncoding.DecodeString(tc.webhook.Secret)
assert.FatalError(t, err) assert.NoError(t, err)
h := hmac.New(sha256.New, secret) h := hmac.New(sha256.New, secret)
h.Write(body) h.Write(body)
mac := h.Sum(nil) mac := h.Sum(nil)
@ -490,19 +535,19 @@ func TestWebhook_Do(t *testing.T) {
switch { switch {
case tc.webhook.BearerToken != "": case tc.webhook.BearerToken != "":
ah := fmt.Sprintf("Bearer %s", tc.webhook.BearerToken) ah := fmt.Sprintf("Bearer %s", tc.webhook.BearerToken)
assert.Equals(t, ah, r.Header.Get("Authorization")) assert.Equal(t, ah, r.Header.Get("Authorization"))
case tc.webhook.BasicAuth.Username != "" || tc.webhook.BasicAuth.Password != "": case tc.webhook.BasicAuth.Username != "" || tc.webhook.BasicAuth.Password != "":
whReq, err := http.NewRequest("", "", http.NoBody) whReq, err := http.NewRequest("", "", http.NoBody)
assert.FatalError(t, err) require.NoError(t, err)
whReq.SetBasicAuth(tc.webhook.BasicAuth.Username, tc.webhook.BasicAuth.Password) whReq.SetBasicAuth(tc.webhook.BasicAuth.Username, tc.webhook.BasicAuth.Password)
ah := whReq.Header.Get("Authorization") ah := whReq.Header.Get("Authorization")
assert.Equals(t, ah, whReq.Header.Get("Authorization")) assert.Equal(t, ah, whReq.Header.Get("Authorization"))
default: default:
assert.Equals(t, "", r.Header.Get("Authorization")) assert.Equal(t, "", r.Header.Get("Authorization"))
} }
if tc.expectPath != "" { if tc.expectPath != "" {
assert.Equals(t, tc.expectPath, r.URL.Path+"?"+r.URL.RawQuery) assert.Equal(t, tc.expectPath, r.URL.Path+"?"+r.URL.RawQuery)
} }
if tc.errStatusCode != 0 { if tc.errStatusCode != 0 {
@ -512,26 +557,33 @@ func TestWebhook_Do(t *testing.T) {
reqBody := new(webhook.RequestBody) reqBody := new(webhook.RequestBody)
err = json.Unmarshal(body, reqBody) err = json.Unmarshal(body, reqBody)
assert.FatalError(t, err) require.NoError(t, err)
// assert.Equals(t, tc.expectToken, reqBody.Token)
err = json.NewEncoder(w).Encode(tc.webhookResponse) err = json.NewEncoder(w).Encode(tc.webhookResponse)
assert.FatalError(t, err) require.NoError(t, err)
})) }))
defer ts.Close() defer ts.Close()
tc.webhook.URL = ts.URL + tc.webhook.URL tc.webhook.URL = ts.URL + tc.webhook.URL
reqBody, err := webhook.NewRequestBody(webhook.WithX509CertificateRequest(csr)) reqBody, err := webhook.NewRequestBody(webhook.WithX509CertificateRequest(csr))
assert.FatalError(t, err) require.NoError(t, err)
got, err := tc.webhook.Do(http.DefaultClient, reqBody, tc.dataArg)
ctx := context.Background()
if tc.requestID != "" {
ctx = withRequestID(t, ctx, tc.requestID)
}
ctx, cancel := context.WithTimeout(ctx, time.Second*10)
defer cancel()
got, err := tc.webhook.DoWithContext(ctx, http.DefaultClient, reqBody, tc.dataArg)
if tc.expectErr != nil { if tc.expectErr != nil {
assert.Equals(t, tc.expectErr.Error(), err.Error()) assert.Equal(t, tc.expectErr.Error(), err.Error())
return return
} }
assert.FatalError(t, err) assert.NoError(t, err)
assert.Equals(t, got, &tc.webhookResponse) assert.Equal(t, &tc.webhookResponse, got)
}) })
} }
@ -544,7 +596,7 @@ func TestWebhook_Do(t *testing.T) {
URL: ts.URL, URL: ts.URL,
} }
cert, err := tls.LoadX509KeyPair("testdata/certs/foo.crt", "testdata/secrets/foo.key") cert, err := tls.LoadX509KeyPair("testdata/certs/foo.crt", "testdata/secrets/foo.key")
assert.FatalError(t, err) require.NoError(t, err)
transport := http.DefaultTransport.(*http.Transport).Clone() transport := http.DefaultTransport.(*http.Transport).Clone()
transport.TLSClientConfig = &tls.Config{ transport.TLSClientConfig = &tls.Config{
InsecureSkipVerify: true, InsecureSkipVerify: true,
@ -554,12 +606,19 @@ func TestWebhook_Do(t *testing.T) {
Transport: transport, Transport: transport,
} }
reqBody, err := webhook.NewRequestBody(webhook.WithX509CertificateRequest(csr)) reqBody, err := webhook.NewRequestBody(webhook.WithX509CertificateRequest(csr))
assert.FatalError(t, err) require.NoError(t, err)
_, err = wh.Do(client, reqBody, nil)
assert.FatalError(t, err) ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
_, err = wh.DoWithContext(ctx, client, reqBody, nil)
require.NoError(t, err)
ctx, cancel = context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
wh.DisableTLSClientAuth = true wh.DisableTLSClientAuth = true
_, err = wh.Do(client, reqBody, nil) _, err = wh.DoWithContext(ctx, client, reqBody, nil)
assert.Error(t, err) require.Error(t, err)
}) })
} }

@ -149,7 +149,7 @@ func TestAuthority_LoadProvisionerByCertificate(t *testing.T) {
opts, err := a.Authorize(ctx, token) opts, err := a.Authorize(ctx, token)
require.NoError(t, err) require.NoError(t, err)
opts = append(opts, extraOpts...) opts = append(opts, extraOpts...)
certs, err := a.Sign(csr, provisioner.SignOptions{}, opts...) certs, err := a.SignWithContext(ctx, csr, provisioner.SignOptions{}, opts...)
require.NoError(t, err) require.NoError(t, err)
return certs[0] return certs[0]
} }

@ -152,7 +152,7 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi
return cert, err return cert, err
} }
func (a *Authority) signSSH(_ context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, provisioner.Interface, error) { func (a *Authority) signSSH(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, provisioner.Interface, error) {
var ( var (
certOptions []sshutil.Option certOptions []sshutil.Option
mods []provisioner.SSHCertModifier mods []provisioner.SSHCertModifier
@ -211,7 +211,7 @@ func (a *Authority) signSSH(_ context.Context, key ssh.PublicKey, opts provision
} }
// Call enriching webhooks // Call enriching webhooks
if err := a.callEnrichingWebhooksSSH(prov, webhookCtl, cr); err != nil { if err := a.callEnrichingWebhooksSSH(ctx, prov, webhookCtl, cr); err != nil {
return nil, prov, errs.ApplyOptions( return nil, prov, errs.ApplyOptions(
errs.ForbiddenErr(err, err.Error()), errs.ForbiddenErr(err, err.Error()),
errs.WithKeyVal("signOptions", signOpts), errs.WithKeyVal("signOptions", signOpts),
@ -284,7 +284,7 @@ func (a *Authority) signSSH(_ context.Context, key ssh.PublicKey, opts provision
} }
// Send certificate to webhooks for authorization // Send certificate to webhooks for authorization
if err := a.callAuthorizingWebhooksSSH(prov, webhookCtl, certificate, certTpl); err != nil { if err := a.callAuthorizingWebhooksSSH(ctx, prov, webhookCtl, certificate, certTpl); err != nil {
return nil, prov, errs.ApplyOptions( return nil, prov, errs.ApplyOptions(
errs.ForbiddenErr(err, "authority.SignSSH: error signing certificate"), errs.ForbiddenErr(err, "authority.SignSSH: error signing certificate"),
) )
@ -671,35 +671,33 @@ func (a *Authority) getAddUserCommand(principal string) string {
return strings.ReplaceAll(cmd, "<principal>", principal) return strings.ReplaceAll(cmd, "<principal>", principal)
} }
func (a *Authority) callEnrichingWebhooksSSH(prov provisioner.Interface, webhookCtl webhookController, cr sshutil.CertificateRequest) (err error) { func (a *Authority) callEnrichingWebhooksSSH(ctx context.Context, prov provisioner.Interface, webhookCtl webhookController, cr sshutil.CertificateRequest) (err error) {
if webhookCtl == nil { if webhookCtl == nil {
return return
} }
defer func() { a.meter.SSHWebhookEnriched(prov, err) }()
var whEnrichReq *webhook.RequestBody var whEnrichReq *webhook.RequestBody
if whEnrichReq, err = webhook.NewRequestBody( if whEnrichReq, err = webhook.NewRequestBody(
webhook.WithSSHCertificateRequest(cr), webhook.WithSSHCertificateRequest(cr),
); err == nil { ); err == nil {
err = webhookCtl.Enrich(whEnrichReq) err = webhookCtl.Enrich(ctx, whEnrichReq)
a.meter.SSHWebhookEnriched(prov, err)
} }
return return
} }
func (a *Authority) callAuthorizingWebhooksSSH(prov provisioner.Interface, webhookCtl webhookController, cert *sshutil.Certificate, certTpl *ssh.Certificate) (err error) { func (a *Authority) callAuthorizingWebhooksSSH(ctx context.Context, prov provisioner.Interface, webhookCtl webhookController, cert *sshutil.Certificate, certTpl *ssh.Certificate) (err error) {
if webhookCtl == nil { if webhookCtl == nil {
return return
} }
defer func() { a.meter.SSHWebhookAuthorized(prov, err) }()
var whAuthBody *webhook.RequestBody var whAuthBody *webhook.RequestBody
if whAuthBody, err = webhook.NewRequestBody( if whAuthBody, err = webhook.NewRequestBody(
webhook.WithSSHCertificate(cert, certTpl), webhook.WithSSHCertificate(cert, certTpl),
); err == nil { ); err == nil {
err = webhookCtl.Authorize(whAuthBody) err = webhookCtl.Authorize(ctx, whAuthBody)
a.meter.SSHWebhookAuthorized(prov, err)
} }
return return

@ -91,14 +91,23 @@ func withDefaultASN1DN(def *config.ASN1DN) provisioner.CertificateModifierFunc {
} }
} }
// Sign creates a signed certificate from a certificate signing request. // Sign creates a signed certificate from a certificate signing request. It
// creates a new context.Context, and calls into SignWithContext.
//
// Deprecated: Use authority.SignWithContext with an actual context.Context.
func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
chain, prov, err := a.signX509(csr, signOpts, extraOpts...) return a.SignWithContext(context.Background(), csr, signOpts, extraOpts...)
}
// SignWithContext creates a signed certificate from a certificate signing
// request, taking the provided context.Context.
func (a *Authority) SignWithContext(ctx context.Context, csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
chain, prov, err := a.signX509(ctx, csr, signOpts, extraOpts...)
a.meter.X509Signed(prov, err) a.meter.X509Signed(prov, err)
return chain, err return chain, err
} }
func (a *Authority) signX509(csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, provisioner.Interface, error) { func (a *Authority) signX509(ctx context.Context, csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, provisioner.Interface, error) {
var ( var (
certOptions []x509util.Option certOptions []x509util.Option
certValidators []provisioner.CertificateValidator certValidators []provisioner.CertificateValidator
@ -171,7 +180,7 @@ func (a *Authority) signX509(csr *x509.CertificateRequest, signOpts provisioner.
} }
} }
if err := a.callEnrichingWebhooksX509(prov, webhookCtl, attData, csr); err != nil { if err := a.callEnrichingWebhooksX509(ctx, prov, webhookCtl, attData, csr); err != nil {
return nil, prov, errs.ApplyOptions( return nil, prov, errs.ApplyOptions(
errs.ForbiddenErr(err, err.Error()), errs.ForbiddenErr(err, err.Error()),
errs.WithKeyVal("csr", csr), errs.WithKeyVal("csr", csr),
@ -265,7 +274,7 @@ func (a *Authority) signX509(csr *x509.CertificateRequest, signOpts provisioner.
} }
// Send certificate to webhooks for authorization // Send certificate to webhooks for authorization
if err := a.callAuthorizingWebhooksX509(prov, webhookCtl, crt, leaf, attData); err != nil { if err := a.callAuthorizingWebhooksX509(ctx, prov, webhookCtl, crt, leaf, attData); err != nil {
return nil, prov, errs.ApplyOptions( return nil, prov, errs.ApplyOptions(
errs.ForbiddenErr(err, "error creating certificate"), errs.ForbiddenErr(err, "error creating certificate"),
opts..., opts...,
@ -986,10 +995,11 @@ func templatingError(err error) error {
return errors.Wrap(cause, "error applying certificate template") return errors.Wrap(cause, "error applying certificate template")
} }
func (a *Authority) callEnrichingWebhooksX509(prov provisioner.Interface, webhookCtl webhookController, attData *provisioner.AttestationData, csr *x509.CertificateRequest) (err error) { func (a *Authority) callEnrichingWebhooksX509(ctx context.Context, prov provisioner.Interface, webhookCtl webhookController, attData *provisioner.AttestationData, csr *x509.CertificateRequest) (err error) {
if webhookCtl == nil { if webhookCtl == nil {
return return
} }
defer func() { a.meter.X509WebhookEnriched(prov, err) }()
var attested *webhook.AttestationData var attested *webhook.AttestationData
if attData != nil { if attData != nil {
@ -1003,18 +1013,17 @@ func (a *Authority) callEnrichingWebhooksX509(prov provisioner.Interface, webhoo
webhook.WithX509CertificateRequest(csr), webhook.WithX509CertificateRequest(csr),
webhook.WithAttestationData(attested), webhook.WithAttestationData(attested),
); err == nil { ); err == nil {
err = webhookCtl.Enrich(whEnrichReq) err = webhookCtl.Enrich(ctx, whEnrichReq)
a.meter.X509WebhookEnriched(prov, err)
} }
return return
} }
func (a *Authority) callAuthorizingWebhooksX509(prov provisioner.Interface, webhookCtl webhookController, cert *x509util.Certificate, leaf *x509.Certificate, attData *provisioner.AttestationData) (err error) { func (a *Authority) callAuthorizingWebhooksX509(ctx context.Context, prov provisioner.Interface, webhookCtl webhookController, cert *x509util.Certificate, leaf *x509.Certificate, attData *provisioner.AttestationData) (err error) {
if webhookCtl == nil { if webhookCtl == nil {
return return
} }
defer func() { a.meter.X509WebhookAuthorized(prov, err) }()
var attested *webhook.AttestationData var attested *webhook.AttestationData
if attData != nil { if attData != nil {
@ -1028,9 +1037,7 @@ func (a *Authority) callAuthorizingWebhooksX509(prov provisioner.Interface, webh
webhook.WithX509Certificate(cert, leaf), webhook.WithX509Certificate(cert, leaf),
webhook.WithAttestationData(attested), webhook.WithAttestationData(attested),
); err == nil { ); err == nil {
err = webhookCtl.Authorize(whAuthBody) err = webhookCtl.Authorize(ctx, whAuthBody)
a.meter.X509WebhookAuthorized(prov, err)
} }
return return

@ -239,7 +239,7 @@ func (e *testEnforcer) Enforce(cert *x509.Certificate) error {
return nil return nil
} }
func TestAuthority_Sign(t *testing.T) { func TestAuthority_SignWithContext(t *testing.T) {
pub, priv, err := keyutil.GenerateDefaultKeyPair() pub, priv, err := keyutil.GenerateDefaultKeyPair()
require.NoError(t, err) require.NoError(t, err)
@ -848,7 +848,7 @@ ZYtQ9Ot36qc=
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
tc := genTestCase(t) tc := genTestCase(t)
certChain, err := tc.auth.Sign(tc.csr, tc.signOpts, tc.extraOpts...) certChain, err := tc.auth.SignWithContext(context.Background(), tc.csr, tc.signOpts, tc.extraOpts...)
if err != nil { if err != nil {
if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) { if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) {
assert.Nil(t, certChain) assert.Nil(t, certChain)
@ -1797,9 +1797,9 @@ func TestAuthority_constraints(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
_, err = auth.Sign(csr, provisioner.SignOptions{}, templateOption) _, err = auth.SignWithContext(context.Background(), csr, provisioner.SignOptions{}, templateOption)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("Authority.Sign() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("Authority.SignWithContext() error = %v, wantErr %v", err, tt.wantErr)
} }
_, err = auth.Renew(cert) _, err = auth.Renew(cert)

@ -1,8 +1,12 @@
package authority package authority
import "github.com/smallstep/certificates/webhook" import (
"context"
"github.com/smallstep/certificates/webhook"
)
type webhookController interface { type webhookController interface {
Enrich(*webhook.RequestBody) error Enrich(context.Context, *webhook.RequestBody) error
Authorize(*webhook.RequestBody) error Authorize(context.Context, *webhook.RequestBody) error
} }

@ -1,6 +1,8 @@
package authority package authority
import ( import (
"context"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/webhook" "github.com/smallstep/certificates/webhook"
) )
@ -14,7 +16,7 @@ type mockWebhookController struct {
var _ webhookController = &mockWebhookController{} var _ webhookController = &mockWebhookController{}
func (wc *mockWebhookController) Enrich(*webhook.RequestBody) error { func (wc *mockWebhookController) Enrich(context.Context, *webhook.RequestBody) error {
for key, data := range wc.respData { for key, data := range wc.respData {
wc.templateData.SetWebhook(key, data) wc.templateData.SetWebhook(key, data)
} }
@ -22,6 +24,6 @@ func (wc *mockWebhookController) Enrich(*webhook.RequestBody) error {
return wc.enrichErr return wc.enrichErr
} }
func (wc *mockWebhookController) Authorize(*webhook.RequestBody) error { func (wc *mockWebhookController) Authorize(context.Context, *webhook.RequestBody) error {
return wc.authorizeErr return wc.authorizeErr
} }

@ -48,6 +48,7 @@ func NewACMEClient(endpoint string, contact []string, opts ...ClientOption) (*AC
return nil, errors.Wrapf(err, "creating GET request %s failed", endpoint) return nil, errors.Wrapf(err, "creating GET request %s failed", endpoint)
} }
req.Header.Set("User-Agent", UserAgent) req.Header.Set("User-Agent", UserAgent)
enforceRequestID(req)
resp, err := ac.client.Do(req) resp, err := ac.client.Do(req)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "client GET %s failed", endpoint) return nil, errors.Wrapf(err, "client GET %s failed", endpoint)
@ -109,6 +110,7 @@ func (c *ACMEClient) GetNonce() (string, error) {
return "", errors.Wrapf(err, "creating GET request %s failed", c.dir.NewNonce) return "", errors.Wrapf(err, "creating GET request %s failed", c.dir.NewNonce)
} }
req.Header.Set("User-Agent", UserAgent) req.Header.Set("User-Agent", UserAgent)
enforceRequestID(req)
resp, err := c.client.Do(req) resp, err := c.client.Do(req)
if err != nil { if err != nil {
return "", errors.Wrapf(err, "client GET %s failed", c.dir.NewNonce) return "", errors.Wrapf(err, "client GET %s failed", c.dir.NewNonce)
@ -188,6 +190,7 @@ func (c *ACMEClient) post(payload []byte, url string, headerOps ...withHeaderOpt
} }
req.Header.Set("Content-Type", "application/jose+json") req.Header.Set("Content-Type", "application/jose+json")
req.Header.Set("User-Agent", UserAgent) req.Header.Set("User-Agent", UserAgent)
enforceRequestID(req)
resp, err := c.client.Do(req) resp, err := c.client.Do(req)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "client POST %s failed", c.dir.NewOrder) return nil, errors.Wrapf(err, "client POST %s failed", c.dir.NewOrder)

@ -29,6 +29,7 @@ import (
"github.com/smallstep/certificates/cas/apiv1" "github.com/smallstep/certificates/cas/apiv1"
"github.com/smallstep/certificates/db" "github.com/smallstep/certificates/db"
"github.com/smallstep/certificates/internal/metrix" "github.com/smallstep/certificates/internal/metrix"
"github.com/smallstep/certificates/internal/requestid"
"github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/logging"
"github.com/smallstep/certificates/monitoring" "github.com/smallstep/certificates/monitoring"
"github.com/smallstep/certificates/scep" "github.com/smallstep/certificates/scep"
@ -329,15 +330,21 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) {
} }
// Add logger if configured // Add logger if configured
var legacyTraceHeader string
if len(cfg.Logger) > 0 { if len(cfg.Logger) > 0 {
logger, err := logging.New("ca", cfg.Logger) logger, err := logging.New("ca", cfg.Logger)
if err != nil { if err != nil {
return nil, err return nil, err
} }
legacyTraceHeader = logger.GetTraceHeader()
handler = logger.Middleware(handler) handler = logger.Middleware(handler)
insecureHandler = logger.Middleware(insecureHandler) insecureHandler = logger.Middleware(insecureHandler)
} }
// always use request ID middleware; traceHeader is provided for backwards compatibility (for now)
handler = requestid.New(legacyTraceHeader).Middleware(handler)
insecureHandler = requestid.New(legacyTraceHeader).Middleware(insecureHandler)
// Create context with all the necessary values. // Create context with all the necessary values.
baseContext := buildContext(auth, scepAuthority, acmeDB, acmeLinker) baseContext := buildContext(auth, scepAuthority, acmeDB, acmeLinker)

@ -289,6 +289,9 @@ ZEp7knvU2psWRw==
if assert.Equals(t, rr.Code, tc.status) { if assert.Equals(t, rr.Code, tc.status) {
body := &ClosingBuffer{rr.Body} body := &ClosingBuffer{rr.Body}
resp := &http.Response{
Body: body,
}
if rr.Code < http.StatusBadRequest { if rr.Code < http.StatusBadRequest {
var sign api.SignResponse var sign api.SignResponse
assert.FatalError(t, readJSON(body, &sign)) assert.FatalError(t, readJSON(body, &sign))
@ -325,7 +328,7 @@ ZEp7knvU2psWRw==
assert.FatalError(t, err) assert.FatalError(t, err)
assert.Equals(t, intermediate, realIntermediate) assert.Equals(t, intermediate, realIntermediate)
} else { } else {
err := readError(body) err := readError(resp)
if tc.errMsg == "" { if tc.errMsg == "" {
assert.FatalError(t, errors.New("must validate response error")) assert.FatalError(t, errors.New("must validate response error"))
} }
@ -369,6 +372,9 @@ func TestCAProvisioners(t *testing.T) {
if assert.Equals(t, rr.Code, tc.status) { if assert.Equals(t, rr.Code, tc.status) {
body := &ClosingBuffer{rr.Body} body := &ClosingBuffer{rr.Body}
resp := &http.Response{
Body: body,
}
if rr.Code < http.StatusBadRequest { if rr.Code < http.StatusBadRequest {
var resp api.ProvisionersResponse var resp api.ProvisionersResponse
@ -379,7 +385,7 @@ func TestCAProvisioners(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
assert.Equals(t, a, b) assert.Equals(t, a, b)
} else { } else {
err := readError(body) err := readError(resp)
if tc.errMsg == "" { if tc.errMsg == "" {
assert.FatalError(t, errors.New("must validate response error")) assert.FatalError(t, errors.New("must validate response error"))
} }
@ -436,12 +442,15 @@ func TestCAProvisionerEncryptedKey(t *testing.T) {
if assert.Equals(t, rr.Code, tc.status) { if assert.Equals(t, rr.Code, tc.status) {
body := &ClosingBuffer{rr.Body} body := &ClosingBuffer{rr.Body}
resp := &http.Response{
Body: body,
}
if rr.Code < http.StatusBadRequest { if rr.Code < http.StatusBadRequest {
var ek api.ProvisionerKeyResponse var ek api.ProvisionerKeyResponse
assert.FatalError(t, readJSON(body, &ek)) assert.FatalError(t, readJSON(body, &ek))
assert.Equals(t, ek.Key, tc.expectedKey) assert.Equals(t, ek.Key, tc.expectedKey)
} else { } else {
err := readError(body) err := readError(resp)
if tc.errMsg == "" { if tc.errMsg == "" {
assert.FatalError(t, errors.New("must validate response error")) assert.FatalError(t, errors.New("must validate response error"))
} }
@ -498,12 +507,15 @@ func TestCARoot(t *testing.T) {
if assert.Equals(t, rr.Code, tc.status) { if assert.Equals(t, rr.Code, tc.status) {
body := &ClosingBuffer{rr.Body} body := &ClosingBuffer{rr.Body}
resp := &http.Response{
Body: body,
}
if rr.Code < http.StatusBadRequest { if rr.Code < http.StatusBadRequest {
var root api.RootResponse var root api.RootResponse
assert.FatalError(t, readJSON(body, &root)) assert.FatalError(t, readJSON(body, &root))
assert.Equals(t, root.RootPEM.Certificate, rootCrt) assert.Equals(t, root.RootPEM.Certificate, rootCrt)
} else { } else {
err := readError(body) err := readError(resp)
if tc.errMsg == "" { if tc.errMsg == "" {
assert.FatalError(t, errors.New("must validate response error")) assert.FatalError(t, errors.New("must validate response error"))
} }
@ -641,6 +653,9 @@ func TestCARenew(t *testing.T) {
if assert.Equals(t, rr.Code, tc.status) { if assert.Equals(t, rr.Code, tc.status) {
body := &ClosingBuffer{rr.Body} body := &ClosingBuffer{rr.Body}
resp := &http.Response{
Body: body,
}
if rr.Code < http.StatusBadRequest { if rr.Code < http.StatusBadRequest {
var sign api.SignResponse var sign api.SignResponse
assert.FatalError(t, readJSON(body, &sign)) assert.FatalError(t, readJSON(body, &sign))
@ -673,7 +688,7 @@ func TestCARenew(t *testing.T) {
assert.Equals(t, *sign.TLSOptions, authority.DefaultTLSOptions) assert.Equals(t, *sign.TLSOptions, authority.DefaultTLSOptions)
} else { } else {
err := readError(body) err := readError(resp)
if tc.errMsg == "" { if tc.errMsg == "" {
assert.FatalError(t, errors.New("must validate response error")) assert.FatalError(t, errors.New("must validate response error"))
} }

@ -27,12 +27,14 @@ import (
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/ca/client"
"github.com/smallstep/certificates/ca/identity" "github.com/smallstep/certificates/ca/identity"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
"go.step.sm/cli-utils/step" "go.step.sm/cli-utils/step"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/keyutil" "go.step.sm/crypto/keyutil"
"go.step.sm/crypto/pemutil" "go.step.sm/crypto/pemutil"
"go.step.sm/crypto/randutil"
"go.step.sm/crypto/x509util" "go.step.sm/crypto/x509util"
"golang.org/x/net/http2" "golang.org/x/net/http2"
"google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/encoding/protojson"
@ -83,8 +85,7 @@ func (c *uaClient) GetWithContext(ctx context.Context, u string) (*http.Response
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "create GET %s request failed", u) return nil, errors.Wrapf(err, "create GET %s request failed", u)
} }
req.Header.Set("User-Agent", UserAgent) return c.Do(req)
return c.Client.Do(req)
} }
func (c *uaClient) Post(u, contentType string, body io.Reader) (*http.Response, error) { func (c *uaClient) Post(u, contentType string, body io.Reader) (*http.Response, error) {
@ -97,12 +98,43 @@ func (c *uaClient) PostWithContext(ctx context.Context, u, contentType string, b
return nil, errors.Wrapf(err, "create POST %s request failed", u) return nil, errors.Wrapf(err, "create POST %s request failed", u)
} }
req.Header.Set("Content-Type", contentType) req.Header.Set("Content-Type", contentType)
req.Header.Set("User-Agent", UserAgent) return c.Do(req)
return c.Client.Do(req) }
// requestIDHeader is the header name used for propagating request IDs from
// the CA client to the CA and back again.
const requestIDHeader = "X-Request-Id"
// newRequestID generates a new random UUIDv4 request ID. If it fails,
// the request ID will be the empty string.
func newRequestID() string {
requestID, err := randutil.UUIDv4()
if err != nil {
return ""
}
return requestID
}
// enforceRequestID checks if the X-Request-Id HTTP header is filled. If it's
// empty, the context is searched for a request ID. If that's also empty, a new
// request ID is generated.
func enforceRequestID(r *http.Request) {
if requestID := r.Header.Get(requestIDHeader); requestID == "" {
if reqID, ok := client.RequestIDFromContext(r.Context()); ok {
// TODO(hs): ensure the request ID from the context is fresh, and thus hasn't been
// used before by the client (unless it's a retry for the same request)?
requestID = reqID
} else {
requestID = newRequestID()
}
r.Header.Set(requestIDHeader, requestID)
}
} }
func (c *uaClient) Do(req *http.Request) (*http.Response, error) { func (c *uaClient) Do(req *http.Request) (*http.Response, error) {
req.Header.Set("User-Agent", UserAgent) req.Header.Set("User-Agent", UserAgent)
enforceRequestID(req)
return c.Client.Do(req) return c.Client.Do(req)
} }
@ -375,8 +407,8 @@ func getTransportFromSHA256(endpoint, sum string) (http.RoundTripper, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
client := &Client{endpoint: u} caClient := &Client{endpoint: u}
root, err := client.Root(sum) root, err := caClient.Root(sum)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -610,7 +642,7 @@ retry:
retried = true retried = true
goto retry goto retry
} }
return nil, readError(resp.Body) return nil, readError(resp)
} }
var version api.VersionResponse var version api.VersionResponse
if err := readJSON(resp.Body, &version); err != nil { if err := readJSON(resp.Body, &version); err != nil {
@ -640,7 +672,7 @@ retry:
retried = true retried = true
goto retry goto retry
} }
return nil, readError(resp.Body) return nil, readError(resp)
} }
var health api.HealthResponse var health api.HealthResponse
if err := readJSON(resp.Body, &health); err != nil { if err := readJSON(resp.Body, &health); err != nil {
@ -675,7 +707,7 @@ retry:
retried = true retried = true
goto retry goto retry
} }
return nil, readError(resp.Body) return nil, readError(resp)
} }
var root api.RootResponse var root api.RootResponse
if err := readJSON(resp.Body, &root); err != nil { if err := readJSON(resp.Body, &root); err != nil {
@ -714,7 +746,7 @@ retry:
retried = true retried = true
goto retry goto retry
} }
return nil, readError(resp.Body) return nil, readError(resp)
} }
var sign api.SignResponse var sign api.SignResponse
if err := readJSON(resp.Body, &sign); err != nil { if err := readJSON(resp.Body, &sign); err != nil {
@ -737,14 +769,14 @@ func (c *Client) Renew(tr http.RoundTripper) (*api.SignResponse, error) {
func (c *Client) RenewWithContext(ctx context.Context, tr http.RoundTripper) (*api.SignResponse, error) { func (c *Client) RenewWithContext(ctx context.Context, tr http.RoundTripper) (*api.SignResponse, error) {
var retried bool var retried bool
u := c.endpoint.ResolveReference(&url.URL{Path: "/renew"}) u := c.endpoint.ResolveReference(&url.URL{Path: "/renew"})
client := &http.Client{Transport: tr} httpClient := &http.Client{Transport: tr}
retry: retry:
req, err := http.NewRequestWithContext(ctx, "POST", u.String(), http.NoBody) req, err := http.NewRequestWithContext(ctx, "POST", u.String(), http.NoBody)
if err != nil { if err != nil {
return nil, err return nil, err
} }
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
resp, err := client.Do(req) resp, err := httpClient.Do(req)
if err != nil { if err != nil {
return nil, clientError(err) return nil, clientError(err)
} }
@ -753,7 +785,7 @@ retry:
retried = true retried = true
goto retry goto retry
} }
return nil, readError(resp.Body) return nil, readError(resp)
} }
var sign api.SignResponse var sign api.SignResponse
if err := readJSON(resp.Body, &sign); err != nil { if err := readJSON(resp.Body, &sign); err != nil {
@ -790,7 +822,7 @@ retry:
retried = true retried = true
goto retry goto retry
} }
return nil, readError(resp.Body) return nil, readError(resp)
} }
var sign api.SignResponse var sign api.SignResponse
if err := readJSON(resp.Body, &sign); err != nil { if err := readJSON(resp.Body, &sign); err != nil {
@ -814,14 +846,14 @@ func (c *Client) RekeyWithContext(ctx context.Context, req *api.RekeyRequest, tr
return nil, errors.Wrap(err, "error marshaling request") return nil, errors.Wrap(err, "error marshaling request")
} }
u := c.endpoint.ResolveReference(&url.URL{Path: "/rekey"}) u := c.endpoint.ResolveReference(&url.URL{Path: "/rekey"})
client := &http.Client{Transport: tr} httpClient := &http.Client{Transport: tr}
retry: retry:
httpReq, err := http.NewRequestWithContext(ctx, "POST", u.String(), bytes.NewReader(body)) httpReq, err := http.NewRequestWithContext(ctx, "POST", u.String(), bytes.NewReader(body))
if err != nil { if err != nil {
return nil, err return nil, err
} }
httpReq.Header.Set("Content-Type", "application/json") httpReq.Header.Set("Content-Type", "application/json")
resp, err := client.Do(httpReq) resp, err := httpClient.Do(httpReq)
if err != nil { if err != nil {
return nil, clientError(err) return nil, clientError(err)
} }
@ -830,7 +862,7 @@ retry:
retried = true retried = true
goto retry goto retry
} }
return nil, readError(resp.Body) return nil, readError(resp)
} }
var sign api.SignResponse var sign api.SignResponse
if err := readJSON(resp.Body, &sign); err != nil { if err := readJSON(resp.Body, &sign); err != nil {
@ -853,16 +885,16 @@ func (c *Client) RevokeWithContext(ctx context.Context, req *api.RevokeRequest,
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error marshaling request") return nil, errors.Wrap(err, "error marshaling request")
} }
var client *uaClient var uaClient *uaClient
retry: retry:
if tr != nil { if tr != nil {
client = newClient(tr) uaClient = newClient(tr)
} else { } else {
client = c.client uaClient = c.client
} }
u := c.endpoint.ResolveReference(&url.URL{Path: "/revoke"}) u := c.endpoint.ResolveReference(&url.URL{Path: "/revoke"})
resp, err := client.PostWithContext(ctx, u.String(), "application/json", bytes.NewReader(body)) resp, err := uaClient.PostWithContext(ctx, u.String(), "application/json", bytes.NewReader(body))
if err != nil { if err != nil {
return nil, clientError(err) return nil, clientError(err)
} }
@ -871,7 +903,7 @@ retry:
retried = true retried = true
goto retry goto retry
} }
return nil, readError(resp.Body) return nil, readError(resp)
} }
var revoke api.RevokeResponse var revoke api.RevokeResponse
if err := readJSON(resp.Body, &revoke); err != nil { if err := readJSON(resp.Body, &revoke); err != nil {
@ -914,7 +946,7 @@ retry:
retried = true retried = true
goto retry goto retry
} }
return nil, readError(resp.Body) return nil, readError(resp)
} }
var provisioners api.ProvisionersResponse var provisioners api.ProvisionersResponse
if err := readJSON(resp.Body, &provisioners); err != nil { if err := readJSON(resp.Body, &provisioners); err != nil {
@ -946,7 +978,7 @@ retry:
retried = true retried = true
goto retry goto retry
} }
return nil, readError(resp.Body) return nil, readError(resp)
} }
var key api.ProvisionerKeyResponse var key api.ProvisionerKeyResponse
if err := readJSON(resp.Body, &key); err != nil { if err := readJSON(resp.Body, &key); err != nil {
@ -976,7 +1008,7 @@ retry:
retried = true retried = true
goto retry goto retry
} }
return nil, readError(resp.Body) return nil, readError(resp)
} }
var roots api.RootsResponse var roots api.RootsResponse
if err := readJSON(resp.Body, &roots); err != nil { if err := readJSON(resp.Body, &roots); err != nil {
@ -1006,7 +1038,7 @@ retry:
retried = true retried = true
goto retry goto retry
} }
return nil, readError(resp.Body) return nil, readError(resp)
} }
var federation api.FederationResponse var federation api.FederationResponse
if err := readJSON(resp.Body, &federation); err != nil { if err := readJSON(resp.Body, &federation); err != nil {
@ -1040,7 +1072,7 @@ retry:
retried = true retried = true
goto retry goto retry
} }
return nil, readError(resp.Body) return nil, readError(resp)
} }
var sign api.SSHSignResponse var sign api.SSHSignResponse
if err := readJSON(resp.Body, &sign); err != nil { if err := readJSON(resp.Body, &sign); err != nil {
@ -1074,7 +1106,7 @@ retry:
retried = true retried = true
goto retry goto retry
} }
return nil, readError(resp.Body) return nil, readError(resp)
} }
var renew api.SSHRenewResponse var renew api.SSHRenewResponse
if err := readJSON(resp.Body, &renew); err != nil { if err := readJSON(resp.Body, &renew); err != nil {
@ -1108,7 +1140,7 @@ retry:
retried = true retried = true
goto retry goto retry
} }
return nil, readError(resp.Body) return nil, readError(resp)
} }
var rekey api.SSHRekeyResponse var rekey api.SSHRekeyResponse
if err := readJSON(resp.Body, &rekey); err != nil { if err := readJSON(resp.Body, &rekey); err != nil {
@ -1142,7 +1174,7 @@ retry:
retried = true retried = true
goto retry goto retry
} }
return nil, readError(resp.Body) return nil, readError(resp)
} }
var revoke api.SSHRevokeResponse var revoke api.SSHRevokeResponse
if err := readJSON(resp.Body, &revoke); err != nil { if err := readJSON(resp.Body, &revoke); err != nil {
@ -1172,7 +1204,7 @@ retry:
retried = true retried = true
goto retry goto retry
} }
return nil, readError(resp.Body) return nil, readError(resp)
} }
var keys api.SSHRootsResponse var keys api.SSHRootsResponse
if err := readJSON(resp.Body, &keys); err != nil { if err := readJSON(resp.Body, &keys); err != nil {
@ -1202,7 +1234,7 @@ retry:
retried = true retried = true
goto retry goto retry
} }
return nil, readError(resp.Body) return nil, readError(resp)
} }
var keys api.SSHRootsResponse var keys api.SSHRootsResponse
if err := readJSON(resp.Body, &keys); err != nil { if err := readJSON(resp.Body, &keys); err != nil {
@ -1236,7 +1268,7 @@ retry:
retried = true retried = true
goto retry goto retry
} }
return nil, readError(resp.Body) return nil, readError(resp)
} }
var cfg api.SSHConfigResponse var cfg api.SSHConfigResponse
if err := readJSON(resp.Body, &cfg); err != nil { if err := readJSON(resp.Body, &cfg); err != nil {
@ -1275,7 +1307,7 @@ retry:
retried = true retried = true
goto retry goto retry
} }
return nil, readError(resp.Body) return nil, readError(resp)
} }
var check api.SSHCheckPrincipalResponse var check api.SSHCheckPrincipalResponse
if err := readJSON(resp.Body, &check); err != nil { if err := readJSON(resp.Body, &check); err != nil {
@ -1304,7 +1336,7 @@ retry:
retried = true retried = true
goto retry goto retry
} }
return nil, readError(resp.Body) return nil, readError(resp)
} }
var hosts api.SSHGetHostsResponse var hosts api.SSHGetHostsResponse
if err := readJSON(resp.Body, &hosts); err != nil { if err := readJSON(resp.Body, &hosts); err != nil {
@ -1336,7 +1368,7 @@ retry:
retried = true retried = true
goto retry goto retry
} }
return nil, readError(resp.Body) return nil, readError(resp)
} }
var bastion api.SSHBastionResponse var bastion api.SSHBastionResponse
if err := readJSON(resp.Body, &bastion); err != nil { if err := readJSON(resp.Body, &bastion); err != nil {
@ -1504,12 +1536,13 @@ func readProtoJSON(r io.ReadCloser, m proto.Message) error {
return protojson.Unmarshal(data, m) return protojson.Unmarshal(data, m)
} }
func readError(r io.ReadCloser) error { func readError(r *http.Response) error {
defer r.Close() defer r.Body.Close()
apiErr := new(errs.Error) apiErr := new(errs.Error)
if err := json.NewDecoder(r).Decode(apiErr); err != nil { if err := json.NewDecoder(r.Body).Decode(apiErr); err != nil {
return err return fmt.Errorf("failed decoding CA error response: %w", err)
} }
apiErr.RequestID = r.Header.Get("X-Request-Id")
return apiErr return apiErr
} }

@ -0,0 +1,18 @@
package client
import "context"
type contextKey struct{}
// NewRequestIDContext returns a new context with the given request ID added to the
// context.
func NewRequestIDContext(ctx context.Context, requestID string) context.Context {
return context.WithValue(ctx, contextKey{}, requestID)
}
// RequestIDFromContext returns the request ID from the context if it exists.
// and is not empty.
func RequestIDFromContext(ctx context.Context) (string, bool) {
v, ok := ctx.Value(contextKey{}).(string)
return v, ok && v != ""
}

@ -9,24 +9,26 @@ import (
"encoding/json" "encoding/json"
"encoding/pem" "encoding/pem"
"errors" "errors"
"fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"reflect" "reflect"
"strings"
"testing" "testing"
"time" "time"
"go.step.sm/crypto/x509util" "github.com/google/uuid"
"golang.org/x/crypto/ssh"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/api/read" "github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/ca/client"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.step.sm/crypto/x509util"
"golang.org/x/crypto/ssh"
) )
const ( const (
@ -106,52 +108,49 @@ DCbKzWTW8lqVdp9Kyf7XEhhc2R8C5w==
-----END CERTIFICATE REQUEST-----` -----END CERTIFICATE REQUEST-----`
) )
func mustKey() *ecdsa.PrivateKey { func mustKey(t *testing.T) *ecdsa.PrivateKey {
t.Helper()
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil { require.NoError(t, err)
panic(err)
}
return priv return priv
} }
func parseCertificate(data string) *x509.Certificate { func parseCertificate(t *testing.T, data string) *x509.Certificate {
t.Helper()
block, _ := pem.Decode([]byte(data)) block, _ := pem.Decode([]byte(data))
if block == nil { if block == nil {
panic("failed to parse certificate PEM") require.Fail(t, "failed to parse certificate PEM")
return nil
} }
cert, err := x509.ParseCertificate(block.Bytes) cert, err := x509.ParseCertificate(block.Bytes)
if err != nil { require.NoError(t, err, "failed to parse certificate")
panic("failed to parse certificate: " + err.Error())
}
return cert return cert
} }
func parseCertificateRequest(string) *x509.CertificateRequest { func parseCertificateRequest(t *testing.T, csrPEM string) *x509.CertificateRequest {
t.Helper()
block, _ := pem.Decode([]byte(csrPEM)) block, _ := pem.Decode([]byte(csrPEM))
if block == nil { if block == nil {
panic("failed to parse certificate request PEM") require.Fail(t, "failed to parse certificate request PEM")
return nil
} }
csr, err := x509.ParseCertificateRequest(block.Bytes) csr, err := x509.ParseCertificateRequest(block.Bytes)
if err != nil { require.NoError(t, err, "failed to parse certificate request")
panic("failed to parse certificate request: " + err.Error())
}
return csr return csr
} }
func equalJSON(t *testing.T, a, b interface{}) bool { func equalJSON(t *testing.T, a, b interface{}) bool {
t.Helper()
if reflect.DeepEqual(a, b) { if reflect.DeepEqual(a, b) {
return true return true
} }
ab, err := json.Marshal(a) ab, err := json.Marshal(a)
if err != nil { require.NoError(t, err)
t.Error(err)
return false
}
bb, err := json.Marshal(b) bb, err := json.Marshal(b)
if err != nil { require.NoError(t, err)
t.Error(err)
return false
}
return bytes.Equal(ab, bb) return bytes.Equal(ab, bb)
} }
@ -176,32 +175,23 @@ func TestClient_Version(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
if err != nil { require.NoError(t, err)
t.Errorf("NewClient() error = %v", err)
return
}
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
render.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, tt.response, tt.responseCode)
}) })
got, err := c.Version() got, err := c.Version()
if (err != nil) != tt.wantErr { if tt.wantErr {
t.Errorf("Client.Version() error = %v, wantErr %v", err, tt.wantErr) if assert.Error(t, err) {
assert.EqualError(t, err, tt.expectedErr.Error())
}
assert.Nil(t, got)
return return
} }
switch { assert.NoError(t, err)
case err != nil: assert.Equal(t, tt.response, got)
if got != nil {
t.Errorf("Client.Version() = %v, want nil", got)
}
assert.HasPrefix(t, tt.expectedErr.Error(), err.Error())
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Version() = %v, want %v", got, tt.response)
}
}
}) })
} }
} }
@ -226,40 +216,30 @@ func TestClient_Health(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
if err != nil { require.NoError(t, err)
t.Errorf("NewClient() error = %v", err)
return
}
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
render.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, tt.response, tt.responseCode)
}) })
got, err := c.Health() got, err := c.Health()
if (err != nil) != tt.wantErr { if tt.wantErr {
fmt.Printf("%+v", err) if assert.Error(t, err) {
t.Errorf("Client.Health() error = %v, wantErr %v", err, tt.wantErr) assert.EqualError(t, err, tt.expectedErr.Error())
}
assert.Nil(t, got)
return return
} }
switch { assert.NoError(t, err)
case err != nil: assert.Equal(t, tt.response, got)
if got != nil {
t.Errorf("Client.Health() = %v, want nil", got)
}
assert.HasPrefix(t, tt.expectedErr.Error(), err.Error())
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Health() = %v, want %v", got, tt.response)
}
}
}) })
} }
} }
func TestClient_Root(t *testing.T) { func TestClient_Root(t *testing.T) {
ok := &api.RootResponse{ ok := &api.RootResponse{
RootPEM: api.Certificate{Certificate: parseCertificate(rootPEM)}, RootPEM: api.Certificate{Certificate: parseCertificate(t, rootPEM)},
} }
tests := []struct { tests := []struct {
@ -280,10 +260,7 @@ func TestClient_Root(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
if err != nil { require.NoError(t, err)
t.Errorf("NewClient() error = %v", err)
return
}
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
expected := "/root/" + tt.shasum expected := "/root/" + tt.shasum
@ -294,37 +271,31 @@ func TestClient_Root(t *testing.T) {
}) })
got, err := c.Root(tt.shasum) got, err := c.Root(tt.shasum)
if (err != nil) != tt.wantErr { if tt.wantErr {
t.Errorf("Client.Root() error = %v, wantErr %v", err, tt.wantErr) if assert.Error(t, err) {
assert.EqualError(t, err, tt.expectedErr.Error())
}
assert.Nil(t, got)
return return
} }
switch { assert.NoError(t, err)
case err != nil: assert.Equal(t, tt.response, got)
if got != nil {
t.Errorf("Client.Root() = %v, want nil", got)
}
assert.HasPrefix(t, tt.expectedErr.Error(), err.Error())
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Root() = %v, want %v", got, tt.response)
}
}
}) })
} }
} }
func TestClient_Sign(t *testing.T) { func TestClient_Sign(t *testing.T) {
ok := &api.SignResponse{ ok := &api.SignResponse{
ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)}, ServerPEM: api.Certificate{Certificate: parseCertificate(t, certPEM)},
CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)}, CaPEM: api.Certificate{Certificate: parseCertificate(t, rootPEM)},
CertChainPEM: []api.Certificate{ CertChainPEM: []api.Certificate{
{Certificate: parseCertificate(certPEM)}, {Certificate: parseCertificate(t, certPEM)},
{Certificate: parseCertificate(rootPEM)}, {Certificate: parseCertificate(t, rootPEM)},
}, },
} }
request := &api.SignRequest{ request := &api.SignRequest{
CsrPEM: api.CertificateRequest{CertificateRequest: parseCertificateRequest(csrPEM)}, CsrPEM: api.CertificateRequest{CertificateRequest: parseCertificateRequest(t, csrPEM)},
OTT: "the-ott", OTT: "the-ott",
NotBefore: api.NewTimeDuration(time.Now()), NotBefore: api.NewTimeDuration(time.Now()),
NotAfter: api.NewTimeDuration(time.Now().AddDate(0, 1, 0)), NotAfter: api.NewTimeDuration(time.Now().AddDate(0, 1, 0)),
@ -350,16 +321,13 @@ func TestClient_Sign(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
if err != nil { require.NoError(t, err)
t.Errorf("NewClient() error = %v", err)
return
}
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
body := new(api.SignRequest) body := new(api.SignRequest)
if err := read.JSON(req.Body, body); err != nil { if err := read.JSON(req.Body, body); err != nil {
e, ok := tt.response.(error) e, ok := tt.response.(error)
assert.Fatal(t, ok, "response expected to be error type") require.True(t, ok, "response expected to be error type")
render.Error(w, e) render.Error(w, e)
return return
} else if !equalJSON(t, body, tt.request) { } else if !equalJSON(t, body, tt.request) {
@ -375,23 +343,16 @@ func TestClient_Sign(t *testing.T) {
}) })
got, err := c.Sign(tt.request) got, err := c.Sign(tt.request)
if (err != nil) != tt.wantErr { if tt.wantErr {
fmt.Printf("%+v", err) if assert.Error(t, err) {
t.Errorf("Client.Sign() error = %v, wantErr %v", err, tt.wantErr) assert.EqualError(t, err, tt.expectedErr.Error())
}
assert.Nil(t, got)
return return
} }
switch { assert.NoError(t, err)
case err != nil: assert.Equal(t, tt.response, got)
if got != nil {
t.Errorf("Client.Sign() = %v, want nil", got)
}
assert.HasPrefix(t, tt.expectedErr.Error(), err.Error())
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Sign() = %v, want %v", got, tt.response)
}
}
}) })
} }
} }
@ -422,16 +383,13 @@ func TestClient_Revoke(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
if err != nil { require.NoError(t, err)
t.Errorf("NewClient() error = %v", err)
return
}
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
body := new(api.RevokeRequest) body := new(api.RevokeRequest)
if err := read.JSON(req.Body, body); err != nil { if err := read.JSON(req.Body, body); err != nil {
e, ok := tt.response.(error) e, ok := tt.response.(error)
assert.Fatal(t, ok, "response expected to be error type") require.True(t, ok, "response expected to be error type")
render.Error(w, e) render.Error(w, e)
return return
} else if !equalJSON(t, body, tt.request) { } else if !equalJSON(t, body, tt.request) {
@ -447,34 +405,27 @@ func TestClient_Revoke(t *testing.T) {
}) })
got, err := c.Revoke(tt.request, nil) got, err := c.Revoke(tt.request, nil)
if (err != nil) != tt.wantErr { if tt.wantErr {
fmt.Printf("%+v", err) if assert.Error(t, err) {
t.Errorf("Client.Revoke() error = %v, wantErr %v", err, tt.wantErr) assert.True(t, strings.HasPrefix(err.Error(), tt.expectedErr.Error()))
}
assert.Nil(t, got)
return return
} }
switch { assert.NoError(t, err)
case err != nil: assert.Equal(t, tt.response, got)
if got != nil {
t.Errorf("Client.Revoke() = %v, want nil", got)
}
assert.HasPrefix(t, err.Error(), tt.expectedErr.Error())
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Revoke() = %v, want %v", got, tt.response)
}
}
}) })
} }
} }
func TestClient_Renew(t *testing.T) { func TestClient_Renew(t *testing.T) {
ok := &api.SignResponse{ ok := &api.SignResponse{
ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)}, ServerPEM: api.Certificate{Certificate: parseCertificate(t, certPEM)},
CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)}, CaPEM: api.Certificate{Certificate: parseCertificate(t, rootPEM)},
CertChainPEM: []api.Certificate{ CertChainPEM: []api.Certificate{
{Certificate: parseCertificate(certPEM)}, {Certificate: parseCertificate(t, certPEM)},
{Certificate: parseCertificate(rootPEM)}, {Certificate: parseCertificate(t, rootPEM)},
}, },
} }
@ -497,49 +448,38 @@ func TestClient_Renew(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
if err != nil { require.NoError(t, err)
t.Errorf("NewClient() error = %v", err)
return
}
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
render.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, tt.response, tt.responseCode)
}) })
got, err := c.Renew(nil) got, err := c.Renew(nil)
if (err != nil) != tt.wantErr { if tt.wantErr {
fmt.Printf("%+v", err) if assert.Error(t, err) {
t.Errorf("Client.Renew() error = %v, wantErr %v", err, tt.wantErr) var sc render.StatusCodedError
if assert.ErrorAs(t, err, &sc) {
assert.Equal(t, tt.responseCode, sc.StatusCode())
}
assert.True(t, strings.HasPrefix(err.Error(), tt.err.Error()))
}
assert.Nil(t, got)
return return
} }
switch { assert.NoError(t, err)
case err != nil: assert.Equal(t, tt.response, got)
if got != nil {
t.Errorf("Client.Renew() = %v, want nil", got)
}
var sc render.StatusCodedError
if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
assert.Equals(t, sc.StatusCode(), tt.responseCode)
}
assert.HasPrefix(t, err.Error(), tt.err.Error())
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Renew() = %v, want %v", got, tt.response)
}
}
}) })
} }
} }
func TestClient_RenewWithToken(t *testing.T) { func TestClient_RenewWithToken(t *testing.T) {
ok := &api.SignResponse{ ok := &api.SignResponse{
ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)}, ServerPEM: api.Certificate{Certificate: parseCertificate(t, certPEM)},
CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)}, CaPEM: api.Certificate{Certificate: parseCertificate(t, rootPEM)},
CertChainPEM: []api.Certificate{ CertChainPEM: []api.Certificate{
{Certificate: parseCertificate(certPEM)}, {Certificate: parseCertificate(t, certPEM)},
{Certificate: parseCertificate(rootPEM)}, {Certificate: parseCertificate(t, rootPEM)},
}, },
} }
@ -562,10 +502,7 @@ func TestClient_RenewWithToken(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
if err != nil { require.NoError(t, err)
t.Errorf("NewClient() error = %v", err)
return
}
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if req.Header.Get("Authorization") != "Bearer token" { if req.Header.Get("Authorization") != "Bearer token" {
@ -576,44 +513,36 @@ func TestClient_RenewWithToken(t *testing.T) {
}) })
got, err := c.RenewWithToken("token") got, err := c.RenewWithToken("token")
if (err != nil) != tt.wantErr { if tt.wantErr {
fmt.Printf("%+v", err) if assert.Error(t, err) {
t.Errorf("Client.RenewWithToken() error = %v, wantErr %v", err, tt.wantErr) var sc render.StatusCodedError
if assert.ErrorAs(t, err, &sc) {
assert.Equal(t, tt.responseCode, sc.StatusCode())
}
assert.True(t, strings.HasPrefix(err.Error(), tt.err.Error()))
}
assert.Nil(t, got)
return return
} }
switch { assert.NoError(t, err)
case err != nil: assert.Equal(t, tt.response, got)
if got != nil {
t.Errorf("Client.RenewWithToken() = %v, want nil", got)
}
var sc render.StatusCodedError
if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
assert.Equals(t, sc.StatusCode(), tt.responseCode)
}
assert.HasPrefix(t, err.Error(), tt.err.Error())
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.RenewWithToken() = %v, want %v", got, tt.response)
}
}
}) })
} }
} }
func TestClient_Rekey(t *testing.T) { func TestClient_Rekey(t *testing.T) {
ok := &api.SignResponse{ ok := &api.SignResponse{
ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)}, ServerPEM: api.Certificate{Certificate: parseCertificate(t, certPEM)},
CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)}, CaPEM: api.Certificate{Certificate: parseCertificate(t, rootPEM)},
CertChainPEM: []api.Certificate{ CertChainPEM: []api.Certificate{
{Certificate: parseCertificate(certPEM)}, {Certificate: parseCertificate(t, certPEM)},
{Certificate: parseCertificate(rootPEM)}, {Certificate: parseCertificate(t, rootPEM)},
}, },
} }
request := &api.RekeyRequest{ request := &api.RekeyRequest{
CsrPEM: api.CertificateRequest{CertificateRequest: parseCertificateRequest(csrPEM)}, CsrPEM: api.CertificateRequest{CertificateRequest: parseCertificateRequest(t, csrPEM)},
} }
tests := []struct { tests := []struct {
@ -636,38 +565,27 @@ func TestClient_Rekey(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
if err != nil { require.NoError(t, err)
t.Errorf("NewClient() error = %v", err)
return
}
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
render.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, tt.response, tt.responseCode)
}) })
got, err := c.Rekey(tt.request, nil) got, err := c.Rekey(tt.request, nil)
if (err != nil) != tt.wantErr { if tt.wantErr {
fmt.Printf("%+v", err) if assert.Error(t, err) {
t.Errorf("Client.Renew() error = %v, wantErr %v", err, tt.wantErr) var sc render.StatusCodedError
if assert.ErrorAs(t, err, &sc) {
assert.Equal(t, tt.responseCode, sc.StatusCode())
}
assert.True(t, strings.HasPrefix(err.Error(), tt.err.Error()))
}
assert.Nil(t, got)
return return
} }
switch { assert.NoError(t, err)
case err != nil: assert.Equal(t, tt.response, got)
if got != nil {
t.Errorf("Client.Renew() = %v, want nil", got)
}
var sc render.StatusCodedError
if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
assert.Equals(t, sc.StatusCode(), tt.responseCode)
}
assert.HasPrefix(t, err.Error(), tt.err.Error())
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Renew() = %v, want %v", got, tt.response)
}
}
}) })
} }
} }
@ -699,10 +617,7 @@ func TestClient_Provisioners(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
if err != nil { require.NoError(t, err)
t.Errorf("NewClient() error = %v", err)
return
}
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if req.RequestURI != tt.expectedURI { if req.RequestURI != tt.expectedURI {
@ -712,22 +627,16 @@ func TestClient_Provisioners(t *testing.T) {
}) })
got, err := c.Provisioners(tt.args...) got, err := c.Provisioners(tt.args...)
if (err != nil) != tt.wantErr { if tt.wantErr {
t.Errorf("Client.Provisioners() error = %v, wantErr %v", err, tt.wantErr) if assert.Error(t, err) {
assert.True(t, strings.HasPrefix(err.Error(), errs.InternalServerErrorDefaultMsg))
}
assert.Nil(t, got)
return return
} }
switch { assert.NoError(t, err)
case err != nil: assert.Equal(t, tt.response, got)
if got != nil {
t.Errorf("Client.Provisioners() = %v, want nil", got)
}
assert.HasPrefix(t, errs.InternalServerErrorDefaultMsg, err.Error())
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Provisioners() = %v, want %v", got, tt.response)
}
}
}) })
} }
} }
@ -755,10 +664,7 @@ func TestClient_ProvisionerKey(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
if err != nil { require.NoError(t, err)
t.Errorf("NewClient() error = %v", err)
return
}
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
expected := "/provisioners/" + tt.kid + "/encrypted-key" expected := "/provisioners/" + tt.kid + "/encrypted-key"
@ -769,27 +675,20 @@ func TestClient_ProvisionerKey(t *testing.T) {
}) })
got, err := c.ProvisionerKey(tt.kid) got, err := c.ProvisionerKey(tt.kid)
if (err != nil) != tt.wantErr { if tt.wantErr {
t.Errorf("Client.ProvisionerKey() error = %v, wantErr %v", err, tt.wantErr) if assert.Error(t, err) {
var sc render.StatusCodedError
if assert.ErrorAs(t, err, &sc) {
assert.Equal(t, tt.responseCode, sc.StatusCode())
}
assert.True(t, strings.HasPrefix(err.Error(), tt.err.Error()))
}
assert.Nil(t, got)
return return
} }
switch { assert.NoError(t, err)
case err != nil: assert.Equal(t, tt.response, got)
if got != nil {
t.Errorf("Client.ProvisionerKey() = %v, want nil", got)
}
var sc render.StatusCodedError
if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
assert.Equals(t, sc.StatusCode(), tt.responseCode)
}
assert.HasPrefix(t, tt.err.Error(), err.Error())
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.ProvisionerKey() = %v, want %v", got, tt.response)
}
}
}) })
} }
} }
@ -797,7 +696,7 @@ func TestClient_ProvisionerKey(t *testing.T) {
func TestClient_Roots(t *testing.T) { func TestClient_Roots(t *testing.T) {
ok := &api.RootsResponse{ ok := &api.RootsResponse{
Certificates: []api.Certificate{ Certificates: []api.Certificate{
{Certificate: parseCertificate(rootPEM)}, {Certificate: parseCertificate(t, rootPEM)},
}, },
} }
@ -819,37 +718,27 @@ func TestClient_Roots(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
if err != nil { require.NoError(t, err)
t.Errorf("NewClient() error = %v", err)
return
}
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
render.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, tt.response, tt.responseCode)
}) })
got, err := c.Roots() got, err := c.Roots()
if (err != nil) != tt.wantErr { if tt.wantErr {
fmt.Printf("%+v", err) if assert.Error(t, err) {
t.Errorf("Client.Roots() error = %v, wantErr %v", err, tt.wantErr) var sc render.StatusCodedError
if assert.ErrorAs(t, err, &sc) {
assert.Equal(t, tt.responseCode, sc.StatusCode())
}
assert.True(t, strings.HasPrefix(err.Error(), tt.err.Error()))
}
assert.Nil(t, got)
return return
} }
switch { assert.NoError(t, err)
case err != nil: assert.Equal(t, tt.response, got)
if got != nil {
t.Errorf("Client.Roots() = %v, want nil", got)
}
var sc render.StatusCodedError
if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
assert.Equals(t, sc.StatusCode(), tt.responseCode)
}
assert.HasPrefix(t, err.Error(), tt.err.Error())
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Roots() = %v, want %v", got, tt.response)
}
}
}) })
} }
} }
@ -857,7 +746,7 @@ func TestClient_Roots(t *testing.T) {
func TestClient_Federation(t *testing.T) { func TestClient_Federation(t *testing.T) {
ok := &api.FederationResponse{ ok := &api.FederationResponse{
Certificates: []api.Certificate{ Certificates: []api.Certificate{
{Certificate: parseCertificate(rootPEM)}, {Certificate: parseCertificate(t, rootPEM)},
}, },
} }
@ -878,46 +767,34 @@ func TestClient_Federation(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
if err != nil { require.NoError(t, err)
t.Errorf("NewClient() error = %v", err)
return
}
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
render.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, tt.response, tt.responseCode)
}) })
got, err := c.Federation() got, err := c.Federation()
if (err != nil) != tt.wantErr { if tt.wantErr {
fmt.Printf("%+v", err) if assert.Error(t, err) {
t.Errorf("Client.Federation() error = %v, wantErr %v", err, tt.wantErr) var sc render.StatusCodedError
if assert.ErrorAs(t, err, &sc) {
assert.Equal(t, tt.responseCode, sc.StatusCode())
}
assert.True(t, strings.HasPrefix(err.Error(), tt.err.Error()))
}
assert.Nil(t, got)
return return
} }
switch { assert.NoError(t, err)
case err != nil: assert.Equal(t, tt.response, got)
if got != nil {
t.Errorf("Client.Federation() = %v, want nil", got)
}
var sc render.StatusCodedError
if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
assert.Equals(t, sc.StatusCode(), tt.responseCode)
}
assert.HasPrefix(t, tt.err.Error(), err.Error())
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Federation() = %v, want %v", got, tt.response)
}
}
}) })
} }
} }
func TestClient_SSHRoots(t *testing.T) { func TestClient_SSHRoots(t *testing.T) {
key, err := ssh.NewPublicKey(mustKey().Public()) key, err := ssh.NewPublicKey(mustKey(t).Public())
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
ok := &api.SSHRootsResponse{ ok := &api.SSHRootsResponse{
HostKeys: []api.SSHPublicKey{{PublicKey: key}}, HostKeys: []api.SSHPublicKey{{PublicKey: key}},
@ -941,37 +818,27 @@ func TestClient_SSHRoots(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
if err != nil { require.NoError(t, err)
t.Errorf("NewClient() error = %v", err)
return
}
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
render.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, tt.response, tt.responseCode)
}) })
got, err := c.SSHRoots() got, err := c.SSHRoots()
if (err != nil) != tt.wantErr { if tt.wantErr {
fmt.Printf("%+v", err) if assert.Error(t, err) {
t.Errorf("Client.SSHKeys() error = %v, wantErr %v", err, tt.wantErr) var sc render.StatusCodedError
if assert.ErrorAs(t, err, &sc) {
assert.Equal(t, tt.responseCode, sc.StatusCode())
}
assert.True(t, strings.HasPrefix(err.Error(), tt.err.Error()))
}
assert.Nil(t, got)
return return
} }
switch { assert.NoError(t, err)
case err != nil: assert.Equal(t, tt.response, got)
if got != nil {
t.Errorf("Client.SSHKeys() = %v, want nil", got)
}
var sc render.StatusCodedError
if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
assert.Equals(t, sc.StatusCode(), tt.responseCode)
}
assert.HasPrefix(t, tt.err.Error(), err.Error())
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.SSHKeys() = %v, want %v", got, tt.response)
}
}
}) })
} }
} }
@ -1003,13 +870,14 @@ func Test_parseEndpoint(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, err := parseEndpoint(tt.args.endpoint) got, err := parseEndpoint(tt.args.endpoint)
if (err != nil) != tt.wantErr { if tt.wantErr {
t.Errorf("parseEndpoint() error = %v, wantErr %v", err, tt.wantErr) assert.Error(t, err)
assert.Nil(t, got)
return return
} }
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("parseEndpoint() = %v, want %v", got, tt.want) assert.NoError(t, err)
} assert.Equal(t, tt.want, got)
}) })
} }
} }
@ -1042,24 +910,21 @@ func TestClient_RootFingerprint(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
tr := tt.server.Client().Transport tr := tt.server.Client().Transport
c, err := NewClient(tt.server.URL, WithTransport(tr)) c, err := NewClient(tt.server.URL, WithTransport(tr))
if err != nil { require.NoError(t, err)
t.Errorf("NewClient() error = %v", err)
return
}
tt.server.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { tt.server.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
render.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, tt.response, tt.responseCode)
}) })
got, err := c.RootFingerprint() got, err := c.RootFingerprint()
if (err != nil) != tt.wantErr { if tt.wantErr {
fmt.Printf("%+v", err) assert.Error(t, err)
t.Errorf("Client.RootFingerprint() error = %v, wantErr %v", err, tt.wantErr) assert.Empty(t, got)
return return
} }
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Client.RootFingerprint() = %v, want %v", got, tt.want) assert.NoError(t, err)
} assert.Equal(t, tt.want, got)
}) })
} }
} }
@ -1068,12 +933,12 @@ func TestClient_RootFingerprintWithServer(t *testing.T) {
srv := startCABootstrapServer() srv := startCABootstrapServer()
defer srv.Close() defer srv.Close()
client, err := NewClient(srv.URL+"/sign", WithRootFile("testdata/secrets/root_ca.crt")) caClient, err := NewClient(srv.URL+"/sign", WithRootFile("testdata/secrets/root_ca.crt"))
assert.FatalError(t, err) require.NoError(t, err)
fp, err := client.RootFingerprint() fp, err := caClient.RootFingerprint()
assert.FatalError(t, err) assert.NoError(t, err)
assert.Equals(t, "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7", fp) assert.Equal(t, "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7", fp)
} }
func TestClient_SSHBastion(t *testing.T) { func TestClient_SSHBastion(t *testing.T) {
@ -1103,39 +968,29 @@ func TestClient_SSHBastion(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
if err != nil { require.NoError(t, err)
t.Errorf("NewClient() error = %v", err)
return
}
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
render.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, tt.response, tt.responseCode)
}) })
got, err := c.SSHBastion(tt.request) got, err := c.SSHBastion(tt.request)
if (err != nil) != tt.wantErr { if tt.wantErr {
fmt.Printf("%+v", err) if assert.Error(t, err) {
t.Errorf("Client.SSHBastion() error = %v, wantErr %v", err, tt.wantErr) if tt.responseCode != 200 {
return var sc render.StatusCodedError
} if assert.ErrorAs(t, err, &sc) {
assert.Equal(t, tt.responseCode, sc.StatusCode())
switch { }
case err != nil: assert.True(t, strings.HasPrefix(err.Error(), tt.err.Error()))
if got != nil {
t.Errorf("Client.SSHBastion() = %v, want nil", got)
}
if tt.responseCode != 200 {
var sc render.StatusCodedError
if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
assert.Equals(t, sc.StatusCode(), tt.responseCode)
} }
assert.HasPrefix(t, err.Error(), tt.err.Error())
}
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.SSHBastion() = %v, want %v", got, tt.response)
} }
assert.Nil(t, got)
return
} }
assert.NoError(t, err)
assert.Equal(t, tt.response, got)
}) })
} }
} }
@ -1154,13 +1009,60 @@ func TestClient_GetCaURL(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
c, err := NewClient(tt.caURL, WithTransport(http.DefaultTransport)) c, err := NewClient(tt.caURL, WithTransport(http.DefaultTransport))
if err != nil { require.NoError(t, err)
t.Errorf("NewClient() error = %v", err)
return got := c.GetCaURL()
} assert.Equal(t, tt.want, got)
if got := c.GetCaURL(); got != tt.want { })
t.Errorf("Client.GetCaURL() = %v, want %v", got, tt.want) }
}
func Test_enforceRequestID(t *testing.T) {
set := httptest.NewRequest(http.MethodGet, "https://example.com", http.NoBody)
set.Header.Set("X-Request-Id", "already-set")
inContext := httptest.NewRequest(http.MethodGet, "https://example.com", http.NoBody)
inContext = inContext.WithContext(client.NewRequestIDContext(inContext.Context(), "from-context"))
newRequestID := httptest.NewRequest(http.MethodGet, "https://example.com", http.NoBody)
tests := []struct {
name string
r *http.Request
want string
}{
{
name: "set",
r: set,
want: "already-set",
},
{
name: "context",
r: inContext,
want: "from-context",
},
{
name: "new",
r: newRequestID,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
enforceRequestID(tt.r)
v := tt.r.Header.Get("X-Request-Id")
if assert.NotEmpty(t, v) {
if tt.want != "" {
assert.Equal(t, tt.want, v)
}
} }
}) })
} }
} }
func Test_newRequestID(t *testing.T) {
requestID := newRequestID()
u, err := uuid.Parse(requestID)
assert.NoError(t, err)
assert.Equal(t, uuid.Version(0x4), u.Version())
assert.Equal(t, uuid.RFC4122, u.Variant())
assert.Equal(t, requestID, u.String())
}

@ -7,6 +7,8 @@ import (
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/require"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/pemutil" "go.step.sm/crypto/pemutil"
"go.step.sm/crypto/x509util" "go.step.sm/crypto/x509util"
@ -41,14 +43,12 @@ func getTestProvisioner(t *testing.T, caURL string) *Provisioner {
} }
func TestNewProvisioner(t *testing.T) { func TestNewProvisioner(t *testing.T) {
ca := startCATestServer() ca := startCATestServer(t)
defer ca.Close() defer ca.Close()
want := getTestProvisioner(t, ca.URL) want := getTestProvisioner(t, ca.URL)
caBundle, err := os.ReadFile("testdata/secrets/root_ca.crt") caBundle, err := os.ReadFile("testdata/secrets/root_ca.crt")
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
type args struct { type args struct {
name string name string

@ -10,6 +10,8 @@ import (
"sort" "sort"
"testing" "testing"
"github.com/stretchr/testify/require"
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api"
) )
@ -130,7 +132,7 @@ func TestVerifyClientCertIfGiven(t *testing.T) {
//nolint:gosec // test tls config //nolint:gosec // test tls config
func TestAddRootCA(t *testing.T) { func TestAddRootCA(t *testing.T) {
cert := parseCertificate(rootPEM) cert := parseCertificate(t, rootPEM)
pool := x509.NewCertPool() pool := x509.NewCertPool()
pool.AddCert(cert) pool.AddCert(cert)
@ -163,7 +165,7 @@ func TestAddRootCA(t *testing.T) {
//nolint:gosec // test tls config //nolint:gosec // test tls config
func TestAddClientCA(t *testing.T) { func TestAddClientCA(t *testing.T) {
cert := parseCertificate(rootPEM) cert := parseCertificate(t, rootPEM)
pool := x509.NewCertPool() pool := x509.NewCertPool()
pool.AddCert(cert) pool.AddCert(cert)
@ -196,25 +198,19 @@ func TestAddClientCA(t *testing.T) {
//nolint:gosec // test tls config //nolint:gosec // test tls config
func TestAddRootsToRootCAs(t *testing.T) { func TestAddRootsToRootCAs(t *testing.T) {
ca := startCATestServer() ca := startCATestServer(t)
defer ca.Close() defer ca.Close()
client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt")) client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt"))
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport)) clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport))
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
root, err := os.ReadFile("testdata/secrets/root_ca.crt") root, err := os.ReadFile("testdata/secrets/root_ca.crt")
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
cert := parseCertificate(string(root)) cert := parseCertificate(t, string(root))
pool := x509.NewCertPool() pool := x509.NewCertPool()
pool.AddCert(cert) pool.AddCert(cert)
@ -251,25 +247,19 @@ func TestAddRootsToRootCAs(t *testing.T) {
//nolint:gosec // test tls config //nolint:gosec // test tls config
func TestAddRootsToClientCAs(t *testing.T) { func TestAddRootsToClientCAs(t *testing.T) {
ca := startCATestServer() ca := startCATestServer(t)
defer ca.Close() defer ca.Close()
client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt")) client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt"))
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport)) clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport))
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
root, err := os.ReadFile("testdata/secrets/root_ca.crt") root, err := os.ReadFile("testdata/secrets/root_ca.crt")
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
cert := parseCertificate(string(root)) cert := parseCertificate(t, string(root))
pool := x509.NewCertPool() pool := x509.NewCertPool()
pool.AddCert(cert) pool.AddCert(cert)
@ -306,31 +296,23 @@ func TestAddRootsToClientCAs(t *testing.T) {
//nolint:gosec // test tls config //nolint:gosec // test tls config
func TestAddFederationToRootCAs(t *testing.T) { func TestAddFederationToRootCAs(t *testing.T) {
ca := startCATestServer() ca := startCATestServer(t)
defer ca.Close() defer ca.Close()
client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt")) client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt"))
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport)) clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport))
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
root, err := os.ReadFile("testdata/secrets/root_ca.crt") root, err := os.ReadFile("testdata/secrets/root_ca.crt")
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
federated, err := os.ReadFile("testdata/secrets/federated_ca.crt") federated, err := os.ReadFile("testdata/secrets/federated_ca.crt")
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
crt1 := parseCertificate(string(root)) crt1 := parseCertificate(t, string(root))
crt2 := parseCertificate(string(federated)) crt2 := parseCertificate(t, string(federated))
pool := x509.NewCertPool() pool := x509.NewCertPool()
pool.AddCert(crt1) pool.AddCert(crt1)
pool.AddCert(crt2) pool.AddCert(crt2)
@ -371,31 +353,23 @@ func TestAddFederationToRootCAs(t *testing.T) {
//nolint:gosec // test tls config //nolint:gosec // test tls config
func TestAddFederationToClientCAs(t *testing.T) { func TestAddFederationToClientCAs(t *testing.T) {
ca := startCATestServer() ca := startCATestServer(t)
defer ca.Close() defer ca.Close()
client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt")) client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt"))
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport)) clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport))
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
root, err := os.ReadFile("testdata/secrets/root_ca.crt") root, err := os.ReadFile("testdata/secrets/root_ca.crt")
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
federated, err := os.ReadFile("testdata/secrets/federated_ca.crt") federated, err := os.ReadFile("testdata/secrets/federated_ca.crt")
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
crt1 := parseCertificate(string(root)) crt1 := parseCertificate(t, string(root))
crt2 := parseCertificate(string(federated)) crt2 := parseCertificate(t, string(federated))
pool := x509.NewCertPool() pool := x509.NewCertPool()
pool.AddCert(crt1) pool.AddCert(crt1)
pool.AddCert(crt2) pool.AddCert(crt2)
@ -436,25 +410,19 @@ func TestAddFederationToClientCAs(t *testing.T) {
//nolint:gosec // test tls config //nolint:gosec // test tls config
func TestAddRootsToCAs(t *testing.T) { func TestAddRootsToCAs(t *testing.T) {
ca := startCATestServer() ca := startCATestServer(t)
defer ca.Close() defer ca.Close()
client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt")) client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt"))
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport)) clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport))
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
root, err := os.ReadFile("testdata/secrets/root_ca.crt") root, err := os.ReadFile("testdata/secrets/root_ca.crt")
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
cert := parseCertificate(string(root)) cert := parseCertificate(t, string(root))
pool := x509.NewCertPool() pool := x509.NewCertPool()
pool.AddCert(cert) pool.AddCert(cert)
@ -491,31 +459,23 @@ func TestAddRootsToCAs(t *testing.T) {
//nolint:gosec // test tls config //nolint:gosec // test tls config
func TestAddFederationToCAs(t *testing.T) { func TestAddFederationToCAs(t *testing.T) {
ca := startCATestServer() ca := startCATestServer(t)
defer ca.Close() defer ca.Close()
client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt")) client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt"))
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport)) clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport))
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
root, err := os.ReadFile("testdata/secrets/root_ca.crt") root, err := os.ReadFile("testdata/secrets/root_ca.crt")
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
federated, err := os.ReadFile("testdata/secrets/federated_ca.crt") federated, err := os.ReadFile("testdata/secrets/federated_ca.crt")
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
crt1 := parseCertificate(string(root)) crt1 := parseCertificate(t, string(root))
crt2 := parseCertificate(string(federated)) crt2 := parseCertificate(t, string(federated))
pool := x509.NewCertPool() pool := x509.NewCertPool()
pool.AddCert(crt1) pool.AddCert(crt1)
pool.AddCert(crt2) pool.AddCert(crt2)

@ -17,27 +17,28 @@ import (
"testing" "testing"
"time" "time"
"github.com/smallstep/certificates/api" "github.com/stretchr/testify/require"
"github.com/smallstep/certificates/authority"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/randutil" "go.step.sm/crypto/randutil"
"github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/authority"
) )
func generateOTT(subject string) string { func generateOTT(t *testing.T, subject string) string {
t.Helper()
now := time.Now() now := time.Now()
jwk, err := jose.ReadKey("testdata/secrets/ott_mariano_priv.jwk", jose.WithPassword([]byte("password"))) jwk, err := jose.ReadKey("testdata/secrets/ott_mariano_priv.jwk", jose.WithPassword([]byte("password")))
if err != nil { require.NoError(t, err)
panic(err)
}
opts := new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID) opts := new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID)
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, opts) sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, opts)
if err != nil { require.NoError(t, err)
panic(err)
}
id, err := randutil.ASCII(64) id, err := randutil.ASCII(64)
if err != nil { require.NoError(t, err)
panic(err)
}
cl := struct { cl := struct {
jose.Claims jose.Claims
SANS []string `json:"sans"` SANS []string `json:"sans"`
@ -53,9 +54,8 @@ func generateOTT(subject string) string {
SANS: []string{subject}, SANS: []string{subject},
} }
raw, err := jose.Signed(sig).Claims(cl).CompactSerialize() raw, err := jose.Signed(sig).Claims(cl).CompactSerialize()
if err != nil { require.NoError(t, err)
panic(err)
}
return raw return raw
} }
@ -72,32 +72,28 @@ func startTestServer(baseContext context.Context, tlsConfig *tls.Config, handler
return srv return srv
} }
func startCATestServer() *httptest.Server { func startCATestServer(t *testing.T) *httptest.Server {
config, err := authority.LoadConfiguration("testdata/ca.json") config, err := authority.LoadConfiguration("testdata/ca.json")
if err != nil { require.NoError(t, err)
panic(err)
}
ca, err := New(config) ca, err := New(config)
if err != nil { require.NoError(t, err)
panic(err)
}
// Use a httptest.Server instead // Use a httptest.Server instead
baseContext := buildContext(ca.auth, nil, nil, nil) baseContext := buildContext(ca.auth, nil, nil, nil)
srv := startTestServer(baseContext, ca.srv.TLSConfig, ca.srv.Handler) srv := startTestServer(baseContext, ca.srv.TLSConfig, ca.srv.Handler)
return srv return srv
} }
func sign(domain string) (*Client, *api.SignResponse, crypto.PrivateKey) { func sign(t *testing.T, domain string) (*Client, *api.SignResponse, crypto.PrivateKey) {
srv := startCATestServer() t.Helper()
srv := startCATestServer(t)
defer srv.Close() defer srv.Close()
return signDuration(srv, domain, 0) return signDuration(t, srv, domain, 0)
} }
func signDuration(srv *httptest.Server, domain string, duration time.Duration) (*Client, *api.SignResponse, crypto.PrivateKey) { func signDuration(t *testing.T, srv *httptest.Server, domain string, duration time.Duration) (*Client, *api.SignResponse, crypto.PrivateKey) {
req, pk, err := CreateSignRequest(generateOTT(domain)) t.Helper()
if err != nil { req, pk, err := CreateSignRequest(generateOTT(t, domain))
panic(err) require.NoError(t, err)
}
if duration > 0 { if duration > 0 {
req.NotBefore = api.NewTimeDuration(time.Now()) req.NotBefore = api.NewTimeDuration(time.Now())
@ -105,13 +101,11 @@ func signDuration(srv *httptest.Server, domain string, duration time.Duration) (
} }
client, err := NewClient(srv.URL, WithRootFile("testdata/secrets/root_ca.crt")) client, err := NewClient(srv.URL, WithRootFile("testdata/secrets/root_ca.crt"))
if err != nil { require.NoError(t, err)
panic(err)
}
sr, err := client.Sign(req) sr, err := client.Sign(req)
if err != nil { require.NoError(t, err)
panic(err)
}
return client, sr, pk return client, sr, pk
} }
@ -145,7 +139,7 @@ func serverHandler(t *testing.T, clientDomain string) http.Handler {
func TestClient_GetServerTLSConfig_http(t *testing.T) { func TestClient_GetServerTLSConfig_http(t *testing.T) {
clientDomain := "test.domain" clientDomain := "test.domain"
client, sr, pk := sign("127.0.0.1") client, sr, pk := sign(t, "127.0.0.1")
// Create mTLS server // Create mTLS server
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
@ -212,7 +206,7 @@ func TestClient_GetServerTLSConfig_http(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
client, sr, pk := sign(clientDomain) client, sr, pk := sign(t, clientDomain)
cli := tt.getClient(t, client, sr, pk) cli := tt.getClient(t, client, sr, pk)
if cli == nil { if cli == nil {
return return
@ -246,19 +240,18 @@ func TestClient_GetServerTLSConfig_renew(t *testing.T) {
defer reset() defer reset()
// Start CA // Start CA
ca := startCATestServer() ca := startCATestServer(t)
defer ca.Close() defer ca.Close()
clientDomain := "test.domain" clientDomain := "test.domain"
client, sr, pk := signDuration(ca, "127.0.0.1", 5*time.Second) client, sr, pk := signDuration(t, ca, "127.0.0.1", 5*time.Second)
// Start mTLS server // Start mTLS server
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
tlsConfig, err := client.GetServerTLSConfig(ctx, sr, pk) tlsConfig, err := client.GetServerTLSConfig(ctx, sr, pk)
if err != nil { require.NoError(t, err)
t.Fatalf("Client.GetServerTLSConfig() error = %v", err)
}
srvMTLS := startTestServer(context.Background(), tlsConfig, serverHandler(t, clientDomain)) srvMTLS := startTestServer(context.Background(), tlsConfig, serverHandler(t, clientDomain))
defer srvMTLS.Close() defer srvMTLS.Close()
@ -266,30 +259,26 @@ func TestClient_GetServerTLSConfig_renew(t *testing.T) {
ctx, cancel = context.WithCancel(context.Background()) ctx, cancel = context.WithCancel(context.Background())
defer cancel() defer cancel()
tlsConfig, err = client.GetServerTLSConfig(ctx, sr, pk, VerifyClientCertIfGiven()) tlsConfig, err = client.GetServerTLSConfig(ctx, sr, pk, VerifyClientCertIfGiven())
if err != nil { require.NoError(t, err)
t.Fatalf("Client.GetServerTLSConfig() error = %v", err)
}
srvTLS := startTestServer(context.Background(), tlsConfig, serverHandler(t, clientDomain)) srvTLS := startTestServer(context.Background(), tlsConfig, serverHandler(t, clientDomain))
defer srvTLS.Close() defer srvTLS.Close()
// Transport // Transport
client, sr, pk = signDuration(ca, clientDomain, 5*time.Second) client, sr, pk = signDuration(t, ca, clientDomain, 5*time.Second)
tr1, err := client.Transport(context.Background(), sr, pk) tr1, err := client.Transport(context.Background(), sr, pk)
if err != nil { require.NoError(t, err)
t.Fatalf("Client.Transport() error = %v", err)
}
// Transport with tlsConfig // Transport with tlsConfig
client, sr, pk = signDuration(ca, clientDomain, 5*time.Second) client, sr, pk = signDuration(t, ca, clientDomain, 5*time.Second)
tlsConfig, err = client.GetClientTLSConfig(context.Background(), sr, pk) tlsConfig, err = client.GetClientTLSConfig(context.Background(), sr, pk)
if err != nil { require.NoError(t, err)
t.Fatalf("Client.GetClientTLSConfig() error = %v", err)
}
tr2 := getDefaultTransport(tlsConfig) tr2 := getDefaultTransport(tlsConfig)
// No client cert // No client cert
root, err := RootCertificate(sr) root, err := RootCertificate(sr)
if err != nil { require.NoError(t, err)
t.Fatalf("RootCertificate() error = %v", err)
}
tlsConfig = getDefaultTLSConfig(sr) tlsConfig = getDefaultTLSConfig(sr)
tlsConfig.RootCAs = x509.NewCertPool() tlsConfig.RootCAs = x509.NewCertPool()
tlsConfig.RootCAs.AddCert(root) tlsConfig.RootCAs.AddCert(root)
@ -401,13 +390,13 @@ func TestClient_GetServerTLSConfig_renew(t *testing.T) {
} }
func TestCertificate(t *testing.T) { func TestCertificate(t *testing.T) {
cert := parseCertificate(certPEM) cert := parseCertificate(t, certPEM)
ok := &api.SignResponse{ ok := &api.SignResponse{
ServerPEM: api.Certificate{Certificate: cert}, ServerPEM: api.Certificate{Certificate: cert},
CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)}, CaPEM: api.Certificate{Certificate: parseCertificate(t, rootPEM)},
CertChainPEM: []api.Certificate{ CertChainPEM: []api.Certificate{
{Certificate: cert}, {Certificate: cert},
{Certificate: parseCertificate(rootPEM)}, {Certificate: parseCertificate(t, rootPEM)},
}, },
} }
tests := []struct { tests := []struct {
@ -434,12 +423,12 @@ func TestCertificate(t *testing.T) {
} }
func TestIntermediateCertificate(t *testing.T) { func TestIntermediateCertificate(t *testing.T) {
intermediate := parseCertificate(rootPEM) intermediate := parseCertificate(t, rootPEM)
ok := &api.SignResponse{ ok := &api.SignResponse{
ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)}, ServerPEM: api.Certificate{Certificate: parseCertificate(t, certPEM)},
CaPEM: api.Certificate{Certificate: intermediate}, CaPEM: api.Certificate{Certificate: intermediate},
CertChainPEM: []api.Certificate{ CertChainPEM: []api.Certificate{
{Certificate: parseCertificate(certPEM)}, {Certificate: parseCertificate(t, certPEM)},
{Certificate: intermediate}, {Certificate: intermediate},
}, },
} }
@ -467,24 +456,24 @@ func TestIntermediateCertificate(t *testing.T) {
} }
func TestRootCertificateCertificate(t *testing.T) { func TestRootCertificateCertificate(t *testing.T) {
root := parseCertificate(rootPEM) root := parseCertificate(t, rootPEM)
ok := &api.SignResponse{ ok := &api.SignResponse{
ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)}, ServerPEM: api.Certificate{Certificate: parseCertificate(t, certPEM)},
CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)}, CaPEM: api.Certificate{Certificate: parseCertificate(t, rootPEM)},
CertChainPEM: []api.Certificate{ CertChainPEM: []api.Certificate{
{Certificate: parseCertificate(certPEM)}, {Certificate: parseCertificate(t, certPEM)},
{Certificate: parseCertificate(rootPEM)}, {Certificate: parseCertificate(t, rootPEM)},
}, },
TLS: &tls.ConnectionState{VerifiedChains: [][]*x509.Certificate{ TLS: &tls.ConnectionState{VerifiedChains: [][]*x509.Certificate{
{root, root}, {root, root},
}}, }},
} }
noTLS := &api.SignResponse{ noTLS := &api.SignResponse{
ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)}, ServerPEM: api.Certificate{Certificate: parseCertificate(t, certPEM)},
CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)}, CaPEM: api.Certificate{Certificate: parseCertificate(t, rootPEM)},
CertChainPEM: []api.Certificate{ CertChainPEM: []api.Certificate{
{Certificate: parseCertificate(certPEM)}, {Certificate: parseCertificate(t, certPEM)},
{Certificate: parseCertificate(rootPEM)}, {Certificate: parseCertificate(t, rootPEM)},
}, },
} }
tests := []struct { tests := []struct {

@ -49,10 +49,11 @@ func WithKeyVal(key string, val interface{}) Option {
// Error represents the CA API errors. // Error represents the CA API errors.
type Error struct { type Error struct {
Status int Status int
Err error Err error
Msg string Msg string
Details map[string]interface{} Details map[string]interface{}
RequestID string `json:"-"`
} }
// ErrorResponse represents an error in JSON format. // ErrorResponse represents an error in JSON format.

@ -2,8 +2,9 @@ package errs
import ( import (
"fmt" "fmt"
"reflect"
"testing" "testing"
"github.com/stretchr/testify/assert"
) )
func TestError_MarshalJSON(t *testing.T) { func TestError_MarshalJSON(t *testing.T) {
@ -27,13 +28,14 @@ func TestError_MarshalJSON(t *testing.T) {
Err: tt.fields.Err, Err: tt.fields.Err,
} }
got, err := e.MarshalJSON() got, err := e.MarshalJSON()
if (err != nil) != tt.wantErr { if tt.wantErr {
t.Errorf("Error.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr) assert.Error(t, err)
assert.Empty(t, got)
return return
} }
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Error.MarshalJSON() = %s, want %s", got, tt.want) assert.NoError(t, err)
} assert.Equal(t, tt.want, got)
}) })
} }
} }
@ -54,13 +56,14 @@ func TestError_UnmarshalJSON(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
e := new(Error) e := new(Error)
if err := e.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr { err := e.UnmarshalJSON(tt.args.data)
t.Errorf("Error.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) if tt.wantErr {
} assert.Error(t, err)
//nolint:govet // best option return
if !reflect.DeepEqual(tt.expected, e) {
t.Errorf("Error.UnmarshalJSON() wants = %+v, got %+v", tt.expected, e)
} }
assert.NoError(t, err)
assert.Equal(t, tt.expected, e)
}) })
} }
} }

@ -7,12 +7,12 @@ sleep 5
rm -f /var/local/step/root_ca.crt rm -f /var/local/step/root_ca.crt
rm -f /var/local/step/site.crt /var/local/step/site.key rm -f /var/local/step/site.crt /var/local/step/site.key
# Donwload the root certificate # Download the root certificate
step ca root /var/local/step/root_ca.crt step ca root /var/local/step/root_ca.crt
# Get token # Get token
STEP_TOKEN=$(step ca token $COMMON_NAME) STEP_TOKEN=$(step ca token $COMMON_NAME)
# Donwload the root certificate # Download the root certificate
step ca certificate --token $STEP_TOKEN $COMMON_NAME /var/local/step/site.crt /var/local/step/site.key step ca certificate --token $STEP_TOKEN $COMMON_NAME /var/local/step/site.crt /var/local/step/site.key
exec "$@" exec "$@"

@ -9,14 +9,14 @@ require (
github.com/coreos/go-oidc/v3 v3.4.0 github.com/coreos/go-oidc/v3 v3.4.0
github.com/dgraph-io/badger v1.6.2 github.com/dgraph-io/badger v1.6.2
github.com/dgraph-io/badger/v2 v2.2007.4 github.com/dgraph-io/badger/v2 v2.2007.4
github.com/fxamacker/cbor/v2 v2.5.0 github.com/fxamacker/cbor/v2 v2.6.0
github.com/go-chi/chi/v5 v5.0.11 github.com/go-chi/chi/v5 v5.0.11
github.com/go-jose/go-jose/v3 v3.0.1 github.com/go-jose/go-jose/v3 v3.0.2
github.com/golang/mock v1.6.0 github.com/golang/mock v1.6.0
github.com/google/go-cmp v0.6.0 github.com/google/go-cmp v0.6.0
github.com/google/go-tpm v0.9.0 github.com/google/go-tpm v0.9.0
github.com/google/uuid v1.6.0 github.com/google/uuid v1.6.0
github.com/googleapis/gax-go/v2 v2.12.0 github.com/googleapis/gax-go/v2 v2.12.2
github.com/hashicorp/vault/api v1.12.0 github.com/hashicorp/vault/api v1.12.0
github.com/hashicorp/vault/api/auth/approle v0.6.0 github.com/hashicorp/vault/api/auth/approle v0.6.0
github.com/hashicorp/vault/api/auth/kubernetes v0.6.0 github.com/hashicorp/vault/api/auth/kubernetes v0.6.0
@ -40,7 +40,7 @@ require (
golang.org/x/exp v0.0.0-20240103183307-be819d1f06fc golang.org/x/exp v0.0.0-20240103183307-be819d1f06fc
golang.org/x/net v0.21.0 golang.org/x/net v0.21.0
google.golang.org/api v0.165.0 google.golang.org/api v0.165.0
google.golang.org/grpc v1.61.0 google.golang.org/grpc v1.62.0
google.golang.org/protobuf v1.32.0 google.golang.org/protobuf v1.32.0
) )
@ -94,7 +94,7 @@ require (
github.com/go-piv/piv-go v1.11.0 // indirect github.com/go-piv/piv-go v1.11.0 // indirect
github.com/go-sql-driver/mysql v1.7.1 // indirect github.com/go-sql-driver/mysql v1.7.1 // indirect
github.com/golang-jwt/jwt/v5 v5.2.0 // indirect github.com/golang-jwt/jwt/v5 v5.2.0 // indirect
github.com/golang/glog v1.1.2 // indirect github.com/golang/glog v1.2.0 // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/golang/protobuf v1.5.3 // indirect github.com/golang/protobuf v1.5.3 // indirect
github.com/golang/snappy v0.0.4 // indirect github.com/golang/snappy v0.0.4 // indirect
@ -163,7 +163,7 @@ require (
golang.org/x/time v0.5.0 // indirect golang.org/x/time v0.5.0 // indirect
google.golang.org/appengine v1.6.8 // indirect google.golang.org/appengine v1.6.8 // indirect
google.golang.org/genproto v0.0.0-20240125205218-1f4bbc51befe // indirect google.golang.org/genproto v0.0.0-20240125205218-1f4bbc51befe // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20240125205218-1f4bbc51befe // indirect google.golang.org/genproto/googleapis/api v0.0.0-20240205150955-31a09d347014 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240205150955-31a09d347014 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240205150955-31a09d347014 // indirect
gopkg.in/square/go-jose.v2 v2.6.0 // indirect gopkg.in/square/go-jose.v2 v2.6.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect

@ -159,7 +159,7 @@ github.com/cncf/xds/go v0.0.0-20210805033703-aa0b78936158/go.mod h1:eXthEFrGJvWH
github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs=
github.com/cncf/xds/go v0.0.0-20211001041855-01bcc9b48dfe/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20211001041855-01bcc9b48dfe/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs=
github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs=
github.com/cncf/xds/go v0.0.0-20231109132714-523115ebc101 h1:7To3pQ+pZo0i3dsWEbinPNFs5gPSBOsJtx3wTT94VBY= github.com/cncf/xds/go v0.0.0-20231128003011-0fa0005c9caa h1:jQCWAUqqlij9Pgj2i/PB79y4KOPYVyFYdROxgaCwdTQ=
github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I= github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I=
github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ=
github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE=
@ -201,22 +201,23 @@ github.com/envoyproxy/go-control-plane v0.9.9-0.20210512163311-63b5d3c536b0/go.m
github.com/envoyproxy/go-control-plane v0.9.10-0.20210907150352-cf90f659a021/go.mod h1:AFq3mo9L8Lqqiid3OhADV3RfLJnjiw63cSpi+fDTRC0= github.com/envoyproxy/go-control-plane v0.9.10-0.20210907150352-cf90f659a021/go.mod h1:AFq3mo9L8Lqqiid3OhADV3RfLJnjiw63cSpi+fDTRC0=
github.com/envoyproxy/go-control-plane v0.10.2-0.20220325020618-49ff273808a1/go.mod h1:KJwIaB5Mv44NWtYuAOFCVOjcI94vtpEz2JU/D2v6IjE= github.com/envoyproxy/go-control-plane v0.10.2-0.20220325020618-49ff273808a1/go.mod h1:KJwIaB5Mv44NWtYuAOFCVOjcI94vtpEz2JU/D2v6IjE=
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
github.com/envoyproxy/protoc-gen-validate v1.0.2 h1:QkIBuU5k+x7/QXPvPPnWXWlCdaBFApVqftFV6k087DA= github.com/envoyproxy/protoc-gen-validate v1.0.4 h1:gVPz/FMfvh57HdSJQyvBtF00j8JU4zdyUgIUNhlgg0A=
github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4=
github.com/fatih/color v1.13.0 h1:8LOYc1KYPPmyKMuN8QV2DNRWNbLo6LZ0iLs8+mlH53w= github.com/fatih/color v1.13.0 h1:8LOYc1KYPPmyKMuN8QV2DNRWNbLo6LZ0iLs8+mlH53w=
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
github.com/fxamacker/cbor/v2 v2.5.0 h1:oHsG0V/Q6E/wqTS2O1Cozzsy69nqCiguo5Q1a1ADivE= github.com/fxamacker/cbor/v2 v2.6.0 h1:sU6J2usfADwWlYDAFhZBQ6TnLFBHxgesMrQfQgk1tWA=
github.com/fxamacker/cbor/v2 v2.5.0/go.mod h1:TA1xS00nchWmaBnEIxPSE5oHLuJBAVvqrtAnWBwBCVo= github.com/fxamacker/cbor/v2 v2.6.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ=
github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04=
github.com/go-chi/chi/v5 v5.0.11 h1:BnpYbFZ3T3S1WMpD79r7R5ThWX40TaFB7L31Y8xqSwA= github.com/go-chi/chi/v5 v5.0.11 h1:BnpYbFZ3T3S1WMpD79r7R5ThWX40TaFB7L31Y8xqSwA=
github.com/go-chi/chi/v5 v5.0.11/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= github.com/go-chi/chi/v5 v5.0.11/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8=
github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
github.com/go-jose/go-jose/v3 v3.0.1 h1:pWmKFVtt+Jl0vBZTIpz/eAKwsm6LkIxDVVbFHKkchhA=
github.com/go-jose/go-jose/v3 v3.0.1/go.mod h1:RNkWWRld676jZEYoV3+XK8L2ZnNSvIsxFMht0mSX+u8= github.com/go-jose/go-jose/v3 v3.0.1/go.mod h1:RNkWWRld676jZEYoV3+XK8L2ZnNSvIsxFMht0mSX+u8=
github.com/go-jose/go-jose/v3 v3.0.2 h1:2Edjn8Nrb44UvTdp84KU0bBPs1cO7noRCybtS3eJEUQ=
github.com/go-jose/go-jose/v3 v3.0.2/go.mod h1:5b+7YgP7ZICgJDBdfjZaIt+H/9L9T/YQrVfLAMboGkQ=
github.com/go-kit/kit v0.4.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.4.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-kit/kit v0.13.0 h1:OoneCcHKHQ03LfBpoQCUfCluwd2Vt3ohz+kvbJneZAU= github.com/go-kit/kit v0.13.0 h1:OoneCcHKHQ03LfBpoQCUfCluwd2Vt3ohz+kvbJneZAU=
github.com/go-kit/kit v0.13.0/go.mod h1:phqEHMMUbyrCFCTgH48JueqrM3md2HcAZ8N3XE4FKDg= github.com/go-kit/kit v0.13.0/go.mod h1:phqEHMMUbyrCFCTgH48JueqrM3md2HcAZ8N3XE4FKDg=
@ -245,8 +246,8 @@ github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRx
github.com/golang-jwt/jwt/v5 v5.2.0 h1:d/ix8ftRUorsN+5eMIlF4T6J8CAt9rch3My2winC1Jw= github.com/golang-jwt/jwt/v5 v5.2.0 h1:d/ix8ftRUorsN+5eMIlF4T6J8CAt9rch3My2winC1Jw=
github.com/golang-jwt/jwt/v5 v5.2.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang-jwt/jwt/v5 v5.2.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/glog v1.1.2 h1:DVjP2PbBOzHyzA+dn3WhHIq4NdVu3Q+pvivFICf/7fo= github.com/golang/glog v1.2.0 h1:uCdmnmatrKCgMBlM4rMuJZWOkPDqdbZPnrMXDY4gI68=
github.com/golang/glog v1.1.2/go.mod h1:zR+okUeTbrL6EL3xHUDxZuEtGv04p5shwip1+mL/rLQ= github.com/golang/glog v1.2.0/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w=
github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
@ -305,6 +306,7 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE= github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE=
github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-sev-guest v0.9.3 h1:GOJ+EipURdeWFl/YYdgcCxyPeMgQUWlI056iFkBD8UU= github.com/google/go-sev-guest v0.9.3 h1:GOJ+EipURdeWFl/YYdgcCxyPeMgQUWlI056iFkBD8UU=
@ -351,8 +353,8 @@ github.com/googleapis/gax-go/v2 v2.1.1/go.mod h1:hddJymUZASv3XPyGkUpKj8pPO47Rmb0
github.com/googleapis/gax-go/v2 v2.2.0/go.mod h1:as02EH8zWkzwUoLbBaFeQ+arQaj/OthfcblKl4IGNaM= github.com/googleapis/gax-go/v2 v2.2.0/go.mod h1:as02EH8zWkzwUoLbBaFeQ+arQaj/OthfcblKl4IGNaM=
github.com/googleapis/gax-go/v2 v2.3.0/go.mod h1:b8LNqSzNabLiUpXKkY7HAR5jr6bIT99EXz9pXxye9YM= github.com/googleapis/gax-go/v2 v2.3.0/go.mod h1:b8LNqSzNabLiUpXKkY7HAR5jr6bIT99EXz9pXxye9YM=
github.com/googleapis/gax-go/v2 v2.4.0/go.mod h1:XOTVJ59hdnfJLIP/dh8n5CGryZR2LxK9wbMD5+iXC6c= github.com/googleapis/gax-go/v2 v2.4.0/go.mod h1:XOTVJ59hdnfJLIP/dh8n5CGryZR2LxK9wbMD5+iXC6c=
github.com/googleapis/gax-go/v2 v2.12.0 h1:A+gCJKdRfqXkr+BIRGtZLibNXf0m1f9E4HG56etFpas= github.com/googleapis/gax-go/v2 v2.12.2 h1:mhN09QQW1jEWeMF74zGR81R30z4VJzjZsfkUhuHF+DA=
github.com/googleapis/gax-go/v2 v2.12.0/go.mod h1:y+aIqrI5eb1YGMVJfuV3185Ts/D7qKpsEkdD5+I6QGU= github.com/googleapis/gax-go/v2 v2.12.2/go.mod h1:61M8vcyyXR2kqKFxKrfA22jaA8JGF7Dc8App1U3H6jc=
github.com/googleapis/go-type-adapters v1.0.0/go.mod h1:zHW75FOG2aur7gAO2B+MLby+cLsWGBF62rFAi7WjWO4= github.com/googleapis/go-type-adapters v1.0.0/go.mod h1:zHW75FOG2aur7gAO2B+MLby+cLsWGBF62rFAi7WjWO4=
github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw=
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
@ -888,6 +890,7 @@ golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U= golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U=
golang.org/x/term v0.15.0/go.mod h1:BDl952bC7+uMoWR75FIrCDx79TPU9oHkTZ9yRbYOrX0= golang.org/x/term v0.15.0/go.mod h1:BDl952bC7+uMoWR75FIrCDx79TPU9oHkTZ9yRbYOrX0=
golang.org/x/term v0.17.0 h1:mkTF7LCd6WGJNL3K1Ad7kwxNfYAW6a8a8QqtMblp/4U= golang.org/x/term v0.17.0 h1:mkTF7LCd6WGJNL3K1Ad7kwxNfYAW6a8a8QqtMblp/4U=
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
@ -1107,8 +1110,8 @@ google.golang.org/genproto v0.0.0-20220608133413-ed9918b62aac/go.mod h1:KEWEmljW
google.golang.org/genproto v0.0.0-20220616135557-88e70c0c3a90/go.mod h1:KEWEmljWE5zPzLBa/oHl6DaEt9LmfH6WtH1OHIvleBA= google.golang.org/genproto v0.0.0-20220616135557-88e70c0c3a90/go.mod h1:KEWEmljWE5zPzLBa/oHl6DaEt9LmfH6WtH1OHIvleBA=
google.golang.org/genproto v0.0.0-20240125205218-1f4bbc51befe h1:USL2DhxfgRchafRvt/wYyyQNzwgL7ZiURcozOE/Pkvo= google.golang.org/genproto v0.0.0-20240125205218-1f4bbc51befe h1:USL2DhxfgRchafRvt/wYyyQNzwgL7ZiURcozOE/Pkvo=
google.golang.org/genproto v0.0.0-20240125205218-1f4bbc51befe/go.mod h1:cc8bqMqtv9gMOr0zHg2Vzff5ULhhL2IXP4sbcn32Dro= google.golang.org/genproto v0.0.0-20240125205218-1f4bbc51befe/go.mod h1:cc8bqMqtv9gMOr0zHg2Vzff5ULhhL2IXP4sbcn32Dro=
google.golang.org/genproto/googleapis/api v0.0.0-20240125205218-1f4bbc51befe h1:0poefMBYvYbs7g5UkjS6HcxBPaTRAmznle9jnxYoAI8= google.golang.org/genproto/googleapis/api v0.0.0-20240205150955-31a09d347014 h1:x9PwdEgd11LgK+orcck69WVRo7DezSO4VUMPI4xpc8A=
google.golang.org/genproto/googleapis/api v0.0.0-20240125205218-1f4bbc51befe/go.mod h1:4jWUdICTdgc3Ibxmr8nAJiiLHwQBY0UI0XZcEMaFKaA= google.golang.org/genproto/googleapis/api v0.0.0-20240205150955-31a09d347014/go.mod h1:rbHMSEDyoYX62nRVLOCc4Qt1HbsdytAYoVwgjiOhF3I=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240205150955-31a09d347014 h1:FSL3lRCkhaPFxqi0s9o+V4UI2WTzAVOvkgbd4kVV4Wg= google.golang.org/genproto/googleapis/rpc v0.0.0-20240205150955-31a09d347014 h1:FSL3lRCkhaPFxqi0s9o+V4UI2WTzAVOvkgbd4kVV4Wg=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240205150955-31a09d347014/go.mod h1:SaPjaZGWb0lPqs6Ittu0spdfrOArqji4ZdeP5IC/9N4= google.golang.org/genproto/googleapis/rpc v0.0.0-20240205150955-31a09d347014/go.mod h1:SaPjaZGWb0lPqs6Ittu0spdfrOArqji4ZdeP5IC/9N4=
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
@ -1142,8 +1145,8 @@ google.golang.org/grpc v1.45.0/go.mod h1:lN7owxKUQEqMfSyQikvvk5tf/6zMPsrK+ONuO11
google.golang.org/grpc v1.46.0/go.mod h1:vN9eftEi1UMyUsIF80+uQXhHjbXYbm0uXoFCACuMGWk= google.golang.org/grpc v1.46.0/go.mod h1:vN9eftEi1UMyUsIF80+uQXhHjbXYbm0uXoFCACuMGWk=
google.golang.org/grpc v1.46.2/go.mod h1:vN9eftEi1UMyUsIF80+uQXhHjbXYbm0uXoFCACuMGWk= google.golang.org/grpc v1.46.2/go.mod h1:vN9eftEi1UMyUsIF80+uQXhHjbXYbm0uXoFCACuMGWk=
google.golang.org/grpc v1.47.0/go.mod h1:vN9eftEi1UMyUsIF80+uQXhHjbXYbm0uXoFCACuMGWk= google.golang.org/grpc v1.47.0/go.mod h1:vN9eftEi1UMyUsIF80+uQXhHjbXYbm0uXoFCACuMGWk=
google.golang.org/grpc v1.61.0 h1:TOvOcuXn30kRao+gfcvsebNEa5iZIiLkisYEkf7R7o0= google.golang.org/grpc v1.62.0 h1:HQKZ/fa1bXkX1oFOvSjmZEUL8wLSaZTjCcLAlmZRtdk=
google.golang.org/grpc v1.61.0/go.mod h1:VUbo7IFqmF1QtCAstipjG0GIoq49KvMe9+h1jFLBNJs= google.golang.org/grpc v1.62.0/go.mod h1:IWTG0VlJLCh1SkC58F7np9ka9mx/WNkjl4PGJaiq+QE=
google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.1.0/go.mod h1:6Kw0yEErY5E/yWrBtf03jp27GLLJujG4z/JK95pnjjw= google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.1.0/go.mod h1:6Kw0yEErY5E/yWrBtf03jp27GLLJujG4z/JK95pnjjw=
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=

@ -42,9 +42,13 @@ func New() (m *Meter) {
m.ssh.rekeyed, m.ssh.rekeyed,
m.ssh.renewed, m.ssh.renewed,
m.ssh.signed, m.ssh.signed,
m.ssh.webhookAuthorized,
m.ssh.webhookEnriched,
m.x509.rekeyed, m.x509.rekeyed,
m.x509.renewed, m.x509.renewed,
m.x509.signed, m.x509.signed,
m.x509.webhookAuthorized,
m.x509.webhookEnriched,
m.kms.signed, m.kms.signed,
m.kms.errors, m.kms.errors,
) )

@ -0,0 +1,91 @@
package requestid
import (
"context"
"net/http"
"github.com/rs/xid"
"go.step.sm/crypto/randutil"
)
const (
// requestIDHeader is the header name used for propagating request IDs. If
// available in an HTTP request, it'll be used instead of the X-Smallstep-Id
// header. It'll always be used in response and set to the request ID.
requestIDHeader = "X-Request-Id"
// defaultTraceHeader is the default Smallstep tracing header that's currently
// in use. It is used as a fallback to retrieve a request ID from, if the
// "X-Request-Id" request header is not set.
defaultTraceHeader = "X-Smallstep-Id"
)
type Handler struct {
legacyTraceHeader string
}
// New creates a new request ID [handler]. It takes a trace header,
// which is used keep the legacy behavior intact, which relies on the
// X-Smallstep-Id header instead of X-Request-Id.
func New(legacyTraceHeader string) *Handler {
if legacyTraceHeader == "" {
legacyTraceHeader = defaultTraceHeader
}
return &Handler{legacyTraceHeader: legacyTraceHeader}
}
// Middleware wraps an [http.Handler] with request ID extraction
// from the X-Reqeust-Id header by default, or from the X-Smallstep-Id
// header if not set. If both are not set, a new request ID is generated.
// In all cases, the request ID is added to the request context, and
// set to be reflected in the response.
func (h *Handler) Middleware(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, req *http.Request) {
requestID := req.Header.Get(requestIDHeader)
if requestID == "" {
requestID = req.Header.Get(h.legacyTraceHeader)
}
if requestID == "" {
requestID = newRequestID()
req.Header.Set(h.legacyTraceHeader, requestID) // legacy behavior
}
// immediately set the request ID to be reflected in the response
w.Header().Set(requestIDHeader, requestID)
// continue down the handler chain
ctx := NewContext(req.Context(), requestID)
next.ServeHTTP(w, req.WithContext(ctx))
}
return http.HandlerFunc(fn)
}
// newRequestID generates a new random UUIDv4 request ID. If UUIDv4
// generation fails, it'll fallback to generating a random ID using
// github.com/rs/xid.
func newRequestID() string {
requestID, err := randutil.UUIDv4()
if err != nil {
requestID = xid.New().String()
}
return requestID
}
type contextKey struct{}
// NewContext returns a new context with the given request ID added to the
// context.
func NewContext(ctx context.Context, requestID string) context.Context {
return context.WithValue(ctx, contextKey{}, requestID)
}
// FromContext returns the request ID from the context if it exists and
// is not the empty value.
func FromContext(ctx context.Context) (string, bool) {
v, ok := ctx.Value(contextKey{}).(string)
return v, ok && v != ""
}

@ -0,0 +1,105 @@
package requestid
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func newRequest(t *testing.T) *http.Request {
t.Helper()
r, err := http.NewRequest(http.MethodGet, "https://example.com", http.NoBody)
require.NoError(t, err)
return r
}
func Test_Middleware(t *testing.T) {
requestWithID := newRequest(t)
requestWithID.Header.Set("X-Request-Id", "reqID")
requestWithoutID := newRequest(t)
requestWithEmptyHeader := newRequest(t)
requestWithEmptyHeader.Header.Set("X-Request-Id", "")
requestWithSmallstepID := newRequest(t)
requestWithSmallstepID.Header.Set("X-Smallstep-Id", "smallstepID")
tests := []struct {
name string
traceHeader string
next http.HandlerFunc
req *http.Request
}{
{
name: "default-request-id",
traceHeader: defaultTraceHeader,
next: func(w http.ResponseWriter, r *http.Request) {
assert.Empty(t, r.Header.Get("X-Smallstep-Id"))
assert.Equal(t, "reqID", r.Header.Get("X-Request-Id"))
reqID, ok := FromContext(r.Context())
if assert.True(t, ok) {
assert.Equal(t, "reqID", reqID)
}
assert.Equal(t, "reqID", w.Header().Get("X-Request-Id"))
},
req: requestWithID,
},
{
name: "no-request-id",
traceHeader: "X-Request-Id",
next: func(w http.ResponseWriter, r *http.Request) {
assert.Empty(t, r.Header.Get("X-Smallstep-Id"))
value := r.Header.Get("X-Request-Id")
assert.NotEmpty(t, value)
reqID, ok := FromContext(r.Context())
if assert.True(t, ok) {
assert.Equal(t, value, reqID)
}
assert.Equal(t, value, w.Header().Get("X-Request-Id"))
},
req: requestWithoutID,
},
{
name: "empty-header",
traceHeader: "",
next: func(w http.ResponseWriter, r *http.Request) {
assert.Empty(t, r.Header.Get("X-Request-Id"))
value := r.Header.Get("X-Smallstep-Id")
assert.NotEmpty(t, value)
reqID, ok := FromContext(r.Context())
if assert.True(t, ok) {
assert.Equal(t, value, reqID)
}
assert.Equal(t, value, w.Header().Get("X-Request-Id"))
},
req: requestWithEmptyHeader,
},
{
name: "fallback-header-name",
traceHeader: defaultTraceHeader,
next: func(w http.ResponseWriter, r *http.Request) {
assert.Empty(t, r.Header.Get("X-Request-Id"))
assert.Equal(t, "smallstepID", r.Header.Get("X-Smallstep-Id"))
reqID, ok := FromContext(r.Context())
if assert.True(t, ok) {
assert.Equal(t, "smallstepID", reqID)
}
assert.Equal(t, "smallstepID", w.Header().Get("X-Request-Id"))
},
req: requestWithSmallstepID,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
handler := New(tt.traceHeader).Middleware(tt.next)
w := httptest.NewRecorder()
handler.ServeHTTP(w, tt.req)
assert.NotEmpty(t, w.Header().Get("X-Request-Id"))
})
}
}

@ -0,0 +1,20 @@
package userid
import "context"
type contextKey struct{}
// NewContext returns a new context with the given user ID added to the
// context.
// TODO(hs): this doesn't seem to be used / set currently; implement
// when/where it makes sense.
func NewContext(ctx context.Context, userID string) context.Context {
return context.WithValue(ctx, contextKey{}, userID)
}
// FromContext returns the user ID from the context if it exists
// and is not empty.
func FromContext(ctx context.Context) (string, bool) {
v, ok := ctx.Value(contextKey{}).(string)
return v, ok && v != ""
}

@ -1,66 +0,0 @@
package logging
import (
"context"
"net/http"
"github.com/rs/xid"
)
type key int
const (
// RequestIDKey is the context key that should store the request identifier.
RequestIDKey key = iota
// UserIDKey is the context key that should store the user identifier.
UserIDKey
)
// NewRequestID creates a new request id using github.com/rs/xid.
func NewRequestID() string {
return xid.New().String()
}
// RequestID returns a new middleware that gets the given header and sets it
// in the context so it can be written in the logger. If the header does not
// exists or it's the empty string, it uses github.com/rs/xid to create a new
// one.
func RequestID(headerName string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, req *http.Request) {
requestID := req.Header.Get(headerName)
if requestID == "" {
requestID = NewRequestID()
req.Header.Set(headerName, requestID)
}
ctx := WithRequestID(req.Context(), requestID)
next.ServeHTTP(w, req.WithContext(ctx))
}
return http.HandlerFunc(fn)
}
}
// WithRequestID returns a new context with the given requestID added to the
// context.
func WithRequestID(ctx context.Context, requestID string) context.Context {
return context.WithValue(ctx, RequestIDKey, requestID)
}
// GetRequestID returns the request id from the context if it exists.
func GetRequestID(ctx context.Context) (string, bool) {
v, ok := ctx.Value(RequestIDKey).(string)
return v, ok
}
// WithUserID decodes the token, extracts the user from the payload and stores
// it in the context.
func WithUserID(ctx context.Context, userID string) context.Context {
return context.WithValue(ctx, UserIDKey, userID)
}
// GetUserID returns the request id from the context if it exists.
func GetUserID(ctx context.Context) (string, bool) {
v, ok := ctx.Value(UserIDKey).(string)
return v, ok
}

@ -9,6 +9,9 @@ import (
"time" "time"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/smallstep/certificates/internal/requestid"
"github.com/smallstep/certificates/internal/userid"
) )
// LoggerHandler creates a logger handler // LoggerHandler creates a logger handler
@ -29,16 +32,15 @@ type options struct {
// NewLoggerHandler returns the given http.Handler with the logger integrated. // NewLoggerHandler returns the given http.Handler with the logger integrated.
func NewLoggerHandler(name string, logger *Logger, next http.Handler) http.Handler { func NewLoggerHandler(name string, logger *Logger, next http.Handler) http.Handler {
h := RequestID(logger.GetTraceHeader())
onlyTraceHealthEndpoint, _ := strconv.ParseBool(os.Getenv("STEP_LOGGER_ONLY_TRACE_HEALTH_ENDPOINT")) onlyTraceHealthEndpoint, _ := strconv.ParseBool(os.Getenv("STEP_LOGGER_ONLY_TRACE_HEALTH_ENDPOINT"))
return h(&LoggerHandler{ return &LoggerHandler{
name: name, name: name,
logger: logger.GetImpl(), logger: logger.GetImpl(),
options: options{ options: options{
onlyTraceHealthEndpoint: onlyTraceHealthEndpoint, onlyTraceHealthEndpoint: onlyTraceHealthEndpoint,
}, },
next: next, next: next,
}) }
} }
// ServeHTTP implements the http.Handler and call to the handler to log with a // ServeHTTP implements the http.Handler and call to the handler to log with a
@ -54,14 +56,14 @@ func (l *LoggerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// writeEntry writes to the Logger writer the request information in the logger. // writeEntry writes to the Logger writer the request information in the logger.
func (l *LoggerHandler) writeEntry(w ResponseLogger, r *http.Request, t time.Time, d time.Duration) { func (l *LoggerHandler) writeEntry(w ResponseLogger, r *http.Request, t time.Time, d time.Duration) {
var reqID, user string var requestID, userID string
ctx := r.Context() ctx := r.Context()
if v, ok := ctx.Value(RequestIDKey).(string); ok && v != "" { if v, ok := requestid.FromContext(ctx); ok {
reqID = v requestID = v
} }
if v, ok := ctx.Value(UserIDKey).(string); ok && v != "" { if v, ok := userid.FromContext(ctx); ok {
user = v userID = v
} }
// Remote hostname // Remote hostname
@ -85,10 +87,10 @@ func (l *LoggerHandler) writeEntry(w ResponseLogger, r *http.Request, t time.Tim
status := w.StatusCode() status := w.StatusCode()
fields := logrus.Fields{ fields := logrus.Fields{
"request-id": reqID, "request-id": requestID,
"remote-address": addr, "remote-address": addr,
"name": l.name, "name": l.name,
"user-id": user, "user-id": userID,
"time": t.Format(time.RFC3339), "time": t.Format(time.RFC3339),
"duration-ns": d.Nanoseconds(), "duration-ns": d.Nanoseconds(),
"duration": d.String(), "duration": d.String(),

@ -9,6 +9,8 @@ import (
"github.com/newrelic/go-agent/v3/newrelic" "github.com/newrelic/go-agent/v3/newrelic"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/internal/requestid"
"github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/logging"
) )
@ -82,7 +84,7 @@ func newRelicMiddleware(app *newrelic.Application) Middleware {
txn.AddAttribute("httpResponseCode", strconv.Itoa(status)) txn.AddAttribute("httpResponseCode", strconv.Itoa(status))
// Add custom attributes // Add custom attributes
if v, ok := logging.GetRequestID(r.Context()); ok { if v, ok := requestid.FromContext(r.Context()); ok {
txn.AddAttribute("request.id", v) txn.AddAttribute("request.id", v)
} }

@ -60,7 +60,7 @@ func MustFromContext(ctx context.Context) *Authority {
// SignAuthority is the interface for a signing authority // SignAuthority is the interface for a signing authority
type SignAuthority interface { type SignAuthority interface {
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) SignWithContext(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
LoadProvisionerByName(string) (provisioner.Interface, error) LoadProvisionerByName(string) (provisioner.Interface, error)
} }
@ -306,7 +306,7 @@ func (a *Authority) SignCSR(ctx context.Context, csr *x509.CertificateRequest, m
} }
signOps = append(signOps, templateOptions) signOps = append(signOps, templateOptions)
certChain, err := a.signAuth.Sign(csr, opts, signOps...) certChain, err := a.signAuth.SignWithContext(ctx, csr, opts, signOps...)
if err != nil { if err != nil {
return nil, fmt.Errorf("error generating certificate for order: %w", err) return nil, fmt.Errorf("error generating certificate for order: %w", err)
} }

@ -0,0 +1,289 @@
package integration
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"fmt"
"net"
"net/http"
"net/http/httptest"
"path/filepath"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/keyutil"
"go.step.sm/crypto/minica"
"go.step.sm/crypto/pemutil"
"go.step.sm/crypto/randutil"
"go.step.sm/crypto/x509util"
"github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/authority/config"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/ca"
"github.com/smallstep/certificates/ca/client"
"github.com/smallstep/certificates/errs"
)
// reservePort "reserves" a TCP port by opening a listener on a random
// port and immediately closing it. The port can then be assumed to be
// available for running a server on.
func reservePort(t *testing.T) (host, port string) {
t.Helper()
l, err := net.Listen("tcp", ":0")
require.NoError(t, err)
address := l.Addr().String()
err = l.Close()
require.NoError(t, err)
host, port, err = net.SplitHostPort(address)
require.NoError(t, err)
return
}
func Test_reflectRequestID(t *testing.T) {
dir := t.TempDir()
m, err := minica.New(minica.WithName("Step E2E"))
require.NoError(t, err)
rootFilepath := filepath.Join(dir, "root.crt")
_, err = pemutil.Serialize(m.Root, pemutil.WithFilename(rootFilepath))
require.NoError(t, err)
intermediateCertFilepath := filepath.Join(dir, "intermediate.crt")
_, err = pemutil.Serialize(m.Intermediate, pemutil.WithFilename(intermediateCertFilepath))
require.NoError(t, err)
intermediateKeyFilepath := filepath.Join(dir, "intermediate.key")
_, err = pemutil.Serialize(m.Signer, pemutil.WithFilename(intermediateKeyFilepath))
require.NoError(t, err)
// get a random address to listen on and connect to; currently no nicer way to get one before starting the server
// TODO(hs): find/implement a nicer way to expose the CA URL, similar to how e.g. httptest.Server exposes it?
host, port := reservePort(t)
authorizingSrv := newAuthorizingServer(t, m)
defer authorizingSrv.Close()
authorizingSrv.StartTLS()
password := []byte("1234")
jwk, jwe, err := jose.GenerateDefaultKeyPair(password)
require.NoError(t, err)
encryptedKey, err := jwe.CompactSerialize()
require.NoError(t, err)
prov := &provisioner.JWK{
ID: "jwk",
Name: "jwk",
Type: "JWK",
Key: jwk,
EncryptedKey: encryptedKey,
Claims: &config.GlobalProvisionerClaims,
Options: &provisioner.Options{
Webhooks: []*provisioner.Webhook{
{
ID: "webhook",
Name: "webhook-test",
URL: fmt.Sprintf("%s/authorize", authorizingSrv.URL),
Kind: "AUTHORIZING",
CertType: "X509",
},
},
},
}
err = prov.Init(provisioner.Config{})
require.NoError(t, err)
cfg := &config.Config{
Root: []string{rootFilepath},
IntermediateCert: intermediateCertFilepath,
IntermediateKey: intermediateKeyFilepath,
Address: net.JoinHostPort(host, port), // reuse the address that was just "reserved"
DNSNames: []string{"127.0.0.1", "[::1]", "localhost"},
AuthorityConfig: &config.AuthConfig{
AuthorityID: "stepca-test",
DeploymentType: "standalone-test",
Provisioners: provisioner.List{prov},
},
Logger: json.RawMessage(`{"format": "text"}`),
}
c, err := ca.New(cfg)
require.NoError(t, err)
// instantiate a client for the CA running at the random address
caClient, err := ca.NewClient(
fmt.Sprintf("https://localhost:%s", port),
ca.WithRootFile(rootFilepath),
)
require.NoError(t, err)
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
err = c.Run()
require.ErrorIs(t, err, http.ErrServerClosed)
}()
// require OK health response as the baseline
ctx := context.Background()
healthResponse, err := caClient.HealthWithContext(ctx)
require.NoError(t, err)
if assert.NotNil(t, healthResponse) {
require.Equal(t, "ok", healthResponse.Status)
}
// expect an error when retrieving an invalid root
rootResponse, err := caClient.RootWithContext(ctx, "invalid")
var firstErr *errs.Error
if assert.ErrorAs(t, err, &firstErr) {
assert.Equal(t, 404, firstErr.StatusCode())
assert.Equal(t, "The requested resource could not be found. Please see the certificate authority logs for more info.", firstErr.Err.Error())
assert.NotEmpty(t, firstErr.RequestID)
// TODO: include the below error in the JSON? It's currently only output to the CA logs. Also see https://github.com/smallstep/certificates/pull/759
//assert.Equal(t, "/root/invalid was not found: certificate with fingerprint invalid was not found", apiErr.Msg)
}
assert.Nil(t, rootResponse)
// expect an error when retrieving an invalid root and provided request ID
rootResponse, err = caClient.RootWithContext(client.NewRequestIDContext(ctx, "reqID"), "invalid")
var secondErr *errs.Error
if assert.ErrorAs(t, err, &secondErr) {
assert.Equal(t, 404, secondErr.StatusCode())
assert.Equal(t, "The requested resource could not be found. Please see the certificate authority logs for more info.", secondErr.Err.Error())
assert.Equal(t, "reqID", secondErr.RequestID)
}
assert.Nil(t, rootResponse)
// prepare a Sign request
subject := "test"
decryptedJWK := decryptPrivateKey(t, jwe, password)
ott := generateOTT(t, decryptedJWK, subject)
signer, err := keyutil.GenerateDefaultSigner()
require.NoError(t, err)
csr, err := x509util.CreateCertificateRequest(subject, []string{subject}, signer)
require.NoError(t, err)
// perform the Sign request using the OTT and CSR
signResponse, err := caClient.SignWithContext(client.NewRequestIDContext(ctx, "signRequestID"), &api.SignRequest{
CsrPEM: api.CertificateRequest{CertificateRequest: csr},
OTT: ott,
NotAfter: api.NewTimeDuration(time.Now().Add(1 * time.Hour)),
NotBefore: api.NewTimeDuration(time.Now().Add(-1 * time.Hour)),
})
assert.NoError(t, err)
// assert a certificate was returned for the subject "test"
if assert.NotNil(t, signResponse) {
assert.Len(t, signResponse.CertChainPEM, 2)
cert, err := x509.ParseCertificate(signResponse.CertChainPEM[0].Raw)
assert.NoError(t, err)
if assert.NotNil(t, cert) {
assert.Equal(t, "test", cert.Subject.CommonName)
assert.Contains(t, cert.DNSNames, "test")
}
}
// done testing; stop and wait for the server to quit
err = c.Stop()
require.NoError(t, err)
wg.Wait()
}
func decryptPrivateKey(t *testing.T, jwe *jose.JSONWebEncryption, pass []byte) *jose.JSONWebKey {
t.Helper()
d, err := jwe.Decrypt(pass)
require.NoError(t, err)
jwk := &jose.JSONWebKey{}
err = json.Unmarshal(d, jwk)
require.NoError(t, err)
return jwk
}
func generateOTT(t *testing.T, jwk *jose.JSONWebKey, subject string) string {
t.Helper()
now := time.Now()
keyID, err := jose.Thumbprint(jwk)
require.NoError(t, err)
opts := new(jose.SignerOptions).WithType("JWT").WithHeader("kid", keyID)
signer, err := jose.NewSigner(jose.SigningKey{Key: jwk.Key}, opts)
require.NoError(t, err)
id, err := randutil.ASCII(64)
require.NoError(t, err)
cl := struct {
jose.Claims
SANS []string `json:"sans"`
}{
Claims: jose.Claims{
ID: id,
Subject: subject,
Issuer: "jwk",
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(time.Minute)),
Audience: []string{"https://127.0.0.1/1.0/sign"},
},
SANS: []string{subject},
}
raw, err := jose.Signed(signer).Claims(cl).CompactSerialize()
require.NoError(t, err)
return raw
}
func newAuthorizingServer(t *testing.T, mca *minica.CA) *httptest.Server {
t.Helper()
key, err := keyutil.GenerateDefaultSigner()
require.NoError(t, err)
csr, err := x509util.CreateCertificateRequest("127.0.0.1", []string{"127.0.0.1"}, key)
require.NoError(t, err)
crt, err := mca.SignCSR(csr)
require.NoError(t, err)
srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if assert.Equal(t, "signRequestID", r.Header.Get("X-Request-Id")) {
json.NewEncoder(w).Encode(struct{ Allow bool }{Allow: true})
w.WriteHeader(http.StatusOK)
return
}
w.WriteHeader(http.StatusBadRequest)
}))
trustedRoots := x509.NewCertPool()
trustedRoots.AddCert(mca.Root)
srv.TLS = &tls.Config{
Certificates: []tls.Certificate{
{
Certificate: [][]byte{crt.Raw, mca.Intermediate.Raw},
PrivateKey: key,
Leaf: crt,
},
},
ClientCAs: trustedRoots,
ClientAuth: tls.RequireAndVerifyClientCert,
ServerName: "localhost",
}
return srv
}
Loading…
Cancel
Save