Improve functional coverage of request ID integration test

pull/1743/head
Herman Slatman 3 months ago
parent 7fd524f70b
commit d392c169fc
No known key found for this signature in database
GPG Key ID: F4D8A44EA0A75A4F

@ -17,13 +17,15 @@ import (
"testing" "testing"
"time" "time"
"github.com/smallstep/certificates/internal/requestid"
"github.com/smallstep/certificates/webhook"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.step.sm/crypto/pemutil" "go.step.sm/crypto/pemutil"
"go.step.sm/crypto/x509util" "go.step.sm/crypto/x509util"
"go.step.sm/linkedca" "go.step.sm/linkedca"
"github.com/smallstep/certificates/internal/requestid"
"github.com/smallstep/certificates/webhook"
) )
func TestWebhookController_isCertTypeOK(t *testing.T) { func TestWebhookController_isCertTypeOK(t *testing.T) {
@ -103,7 +105,8 @@ func TestWebhookController_isCertTypeOK(t *testing.T) {
// withRequestID is a helper that calls into [requestid.NewContext] and returns // withRequestID is a helper that calls into [requestid.NewContext] and returns
// a new context with the requestID added. // a new context with the requestID added.
func withRequestID(ctx context.Context, requestID string) context.Context { func withRequestID(t *testing.T, ctx context.Context, requestID string) context.Context {
t.Helper()
return requestid.NewContext(ctx, requestID) return requestid.NewContext(ctx, requestID)
} }
@ -138,7 +141,7 @@ func TestWebhookController_Enrich(t *testing.T) {
webhooks: []*Webhook{{Name: "people", Kind: "ENRICHING"}}, webhooks: []*Webhook{{Name: "people", Kind: "ENRICHING"}},
TemplateData: x509util.TemplateData{}, TemplateData: x509util.TemplateData{},
}, },
ctx: withRequestID(context.Background(), "reqID"), ctx: withRequestID(t, context.Background(), "reqID"),
req: &webhook.RequestBody{}, req: &webhook.RequestBody{},
responses: []*webhook.ResponseBody{{Allow: true, Data: map[string]any{"role": "bar"}}}, responses: []*webhook.ResponseBody{{Allow: true, Data: map[string]any{"role": "bar"}}},
expectErr: false, expectErr: false,
@ -153,7 +156,7 @@ func TestWebhookController_Enrich(t *testing.T) {
}, },
TemplateData: x509util.TemplateData{}, TemplateData: x509util.TemplateData{},
}, },
ctx: withRequestID(context.Background(), "reqID"), ctx: withRequestID(t, context.Background(), "reqID"),
req: &webhook.RequestBody{}, req: &webhook.RequestBody{},
responses: []*webhook.ResponseBody{ responses: []*webhook.ResponseBody{
{Allow: true, Data: map[string]any{"role": "bar"}}, {Allow: true, Data: map[string]any{"role": "bar"}},
@ -177,7 +180,7 @@ func TestWebhookController_Enrich(t *testing.T) {
TemplateData: x509util.TemplateData{}, TemplateData: x509util.TemplateData{},
certType: linkedca.Webhook_X509, certType: linkedca.Webhook_X509,
}, },
ctx: withRequestID(context.Background(), "reqID"), ctx: withRequestID(t, context.Background(), "reqID"),
req: &webhook.RequestBody{}, req: &webhook.RequestBody{},
responses: []*webhook.ResponseBody{ responses: []*webhook.ResponseBody{
{Allow: true, Data: map[string]any{"role": "bar"}}, {Allow: true, Data: map[string]any{"role": "bar"}},
@ -197,7 +200,7 @@ func TestWebhookController_Enrich(t *testing.T) {
TemplateData: x509util.TemplateData{}, TemplateData: x509util.TemplateData{},
options: []webhook.RequestBodyOption{webhook.WithX5CCertificate(cert)}, options: []webhook.RequestBodyOption{webhook.WithX5CCertificate(cert)},
}, },
ctx: withRequestID(context.Background(), "reqID"), ctx: withRequestID(t, context.Background(), "reqID"),
req: &webhook.RequestBody{}, req: &webhook.RequestBody{},
responses: []*webhook.ResponseBody{{Allow: true, Data: map[string]any{"role": "bar"}}}, responses: []*webhook.ResponseBody{{Allow: true, Data: map[string]any{"role": "bar"}}},
expectErr: false, expectErr: false,
@ -220,7 +223,7 @@ func TestWebhookController_Enrich(t *testing.T) {
webhooks: []*Webhook{{Name: "people", Kind: "ENRICHING"}}, webhooks: []*Webhook{{Name: "people", Kind: "ENRICHING"}},
TemplateData: x509util.TemplateData{}, TemplateData: x509util.TemplateData{},
}, },
ctx: withRequestID(context.Background(), "reqID"), ctx: withRequestID(t, context.Background(), "reqID"),
req: &webhook.RequestBody{}, req: &webhook.RequestBody{},
responses: []*webhook.ResponseBody{{Allow: false}}, responses: []*webhook.ResponseBody{{Allow: false}},
expectErr: true, expectErr: true,
@ -235,7 +238,7 @@ func TestWebhookController_Enrich(t *testing.T) {
PublicKey: []byte("bad"), PublicKey: []byte("bad"),
})}, })},
}, },
ctx: withRequestID(context.Background(), "reqID"), ctx: withRequestID(t, context.Background(), "reqID"),
req: &webhook.RequestBody{}, req: &webhook.RequestBody{},
responses: []*webhook.ResponseBody{{Allow: false}}, responses: []*webhook.ResponseBody{{Allow: false}},
expectErr: true, expectErr: true,
@ -296,7 +299,7 @@ func TestWebhookController_Authorize(t *testing.T) {
client: http.DefaultClient, client: http.DefaultClient,
webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING"}}, webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING"}},
}, },
ctx: withRequestID(context.Background(), "reqID"), ctx: withRequestID(t, context.Background(), "reqID"),
req: &webhook.RequestBody{}, req: &webhook.RequestBody{},
responses: []*webhook.ResponseBody{{Allow: true}}, responses: []*webhook.ResponseBody{{Allow: true}},
expectErr: false, expectErr: false,
@ -307,7 +310,7 @@ func TestWebhookController_Authorize(t *testing.T) {
webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING", CertType: linkedca.Webhook_X509.String()}}, webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING", CertType: linkedca.Webhook_X509.String()}},
certType: linkedca.Webhook_SSH, certType: linkedca.Webhook_SSH,
}, },
ctx: withRequestID(context.Background(), "reqID"), ctx: withRequestID(t, context.Background(), "reqID"),
req: &webhook.RequestBody{}, req: &webhook.RequestBody{},
responses: []*webhook.ResponseBody{{Allow: false}}, responses: []*webhook.ResponseBody{{Allow: false}},
expectErr: false, expectErr: false,
@ -318,7 +321,7 @@ func TestWebhookController_Authorize(t *testing.T) {
webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING"}}, webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING"}},
options: []webhook.RequestBodyOption{webhook.WithX5CCertificate(cert)}, options: []webhook.RequestBodyOption{webhook.WithX5CCertificate(cert)},
}, },
ctx: withRequestID(context.Background(), "reqID"), ctx: withRequestID(t, context.Background(), "reqID"),
req: &webhook.RequestBody{}, req: &webhook.RequestBody{},
responses: []*webhook.ResponseBody{{Allow: true}}, responses: []*webhook.ResponseBody{{Allow: true}},
expectErr: false, expectErr: false,
@ -339,7 +342,7 @@ func TestWebhookController_Authorize(t *testing.T) {
client: http.DefaultClient, client: http.DefaultClient,
webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING"}}, webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING"}},
}, },
ctx: withRequestID(context.Background(), "reqID"), ctx: withRequestID(t, context.Background(), "reqID"),
req: &webhook.RequestBody{}, req: &webhook.RequestBody{},
responses: []*webhook.ResponseBody{{Allow: false}}, responses: []*webhook.ResponseBody{{Allow: false}},
expectErr: true, expectErr: true,
@ -352,7 +355,7 @@ func TestWebhookController_Authorize(t *testing.T) {
PublicKey: []byte("bad"), PublicKey: []byte("bad"),
})}, })},
}, },
ctx: withRequestID(context.Background(), "reqID"), ctx: withRequestID(t, context.Background(), "reqID"),
req: &webhook.RequestBody{}, req: &webhook.RequestBody{},
responses: []*webhook.ResponseBody{{Allow: false}}, responses: []*webhook.ResponseBody{{Allow: false}},
expectErr: true, expectErr: true,
@ -568,7 +571,7 @@ func TestWebhook_Do(t *testing.T) {
ctx := context.Background() ctx := context.Background()
if tc.requestID != "" { if tc.requestID != "" {
ctx = withRequestID(context.Background(), tc.requestID) ctx = withRequestID(t, ctx, tc.requestID)
} }
ctx, cancel := context.WithTimeout(ctx, time.Second*10) ctx, cancel := context.WithTimeout(ctx, time.Second*10)
defer cancel() defer cancel()

@ -2,17 +2,17 @@ package client
import "context" import "context"
type requestIDKey struct{} type contextKey struct{}
// NewRequestIDContext returns a new context with the given request ID added to the // NewRequestIDContext returns a new context with the given request ID added to the
// context. // context.
func NewRequestIDContext(ctx context.Context, requestID string) context.Context { func NewRequestIDContext(ctx context.Context, requestID string) context.Context {
return context.WithValue(ctx, requestIDKey{}, requestID) return context.WithValue(ctx, contextKey{}, requestID)
} }
// RequestIDFromContext returns the request ID from the context if it exists. // RequestIDFromContext returns the request ID from the context if it exists.
// and is not empty. // and is not empty.
func RequestIDFromContext(ctx context.Context) (string, bool) { func RequestIDFromContext(ctx context.Context) (string, bool) {
v, ok := ctx.Value(requestIDKey{}).(string) v, ok := ctx.Value(contextKey{}).(string)
return v, ok && v != "" return v, ok && v != ""
} }

@ -7,6 +7,8 @@ import (
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/require"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/pemutil" "go.step.sm/crypto/pemutil"
"go.step.sm/crypto/x509util" "go.step.sm/crypto/x509util"
@ -41,14 +43,12 @@ func getTestProvisioner(t *testing.T, caURL string) *Provisioner {
} }
func TestNewProvisioner(t *testing.T) { func TestNewProvisioner(t *testing.T) {
ca := startCATestServer() ca := startCATestServer(t)
defer ca.Close() defer ca.Close()
want := getTestProvisioner(t, ca.URL) want := getTestProvisioner(t, ca.URL)
caBundle, err := os.ReadFile("testdata/secrets/root_ca.crt") caBundle, err := os.ReadFile("testdata/secrets/root_ca.crt")
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
type args struct { type args struct {
name string name string

@ -10,6 +10,8 @@ import (
"sort" "sort"
"testing" "testing"
"github.com/stretchr/testify/require"
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api"
) )
@ -196,23 +198,17 @@ func TestAddClientCA(t *testing.T) {
//nolint:gosec // test tls config //nolint:gosec // test tls config
func TestAddRootsToRootCAs(t *testing.T) { func TestAddRootsToRootCAs(t *testing.T) {
ca := startCATestServer() ca := startCATestServer(t)
defer ca.Close() defer ca.Close()
client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt")) client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt"))
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport)) clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport))
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
root, err := os.ReadFile("testdata/secrets/root_ca.crt") root, err := os.ReadFile("testdata/secrets/root_ca.crt")
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
cert := parseCertificate(t, string(root)) cert := parseCertificate(t, string(root))
pool := x509.NewCertPool() pool := x509.NewCertPool()
@ -251,23 +247,17 @@ func TestAddRootsToRootCAs(t *testing.T) {
//nolint:gosec // test tls config //nolint:gosec // test tls config
func TestAddRootsToClientCAs(t *testing.T) { func TestAddRootsToClientCAs(t *testing.T) {
ca := startCATestServer() ca := startCATestServer(t)
defer ca.Close() defer ca.Close()
client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt")) client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt"))
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport)) clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport))
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
root, err := os.ReadFile("testdata/secrets/root_ca.crt") root, err := os.ReadFile("testdata/secrets/root_ca.crt")
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
cert := parseCertificate(t, string(root)) cert := parseCertificate(t, string(root))
pool := x509.NewCertPool() pool := x509.NewCertPool()
@ -306,28 +296,20 @@ func TestAddRootsToClientCAs(t *testing.T) {
//nolint:gosec // test tls config //nolint:gosec // test tls config
func TestAddFederationToRootCAs(t *testing.T) { func TestAddFederationToRootCAs(t *testing.T) {
ca := startCATestServer() ca := startCATestServer(t)
defer ca.Close() defer ca.Close()
client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt")) client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt"))
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport)) clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport))
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
root, err := os.ReadFile("testdata/secrets/root_ca.crt") root, err := os.ReadFile("testdata/secrets/root_ca.crt")
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
federated, err := os.ReadFile("testdata/secrets/federated_ca.crt") federated, err := os.ReadFile("testdata/secrets/federated_ca.crt")
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
crt1 := parseCertificate(t, string(root)) crt1 := parseCertificate(t, string(root))
crt2 := parseCertificate(t, string(federated)) crt2 := parseCertificate(t, string(federated))
@ -371,28 +353,20 @@ func TestAddFederationToRootCAs(t *testing.T) {
//nolint:gosec // test tls config //nolint:gosec // test tls config
func TestAddFederationToClientCAs(t *testing.T) { func TestAddFederationToClientCAs(t *testing.T) {
ca := startCATestServer() ca := startCATestServer(t)
defer ca.Close() defer ca.Close()
client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt")) client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt"))
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport)) clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport))
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
root, err := os.ReadFile("testdata/secrets/root_ca.crt") root, err := os.ReadFile("testdata/secrets/root_ca.crt")
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
federated, err := os.ReadFile("testdata/secrets/federated_ca.crt") federated, err := os.ReadFile("testdata/secrets/federated_ca.crt")
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
crt1 := parseCertificate(t, string(root)) crt1 := parseCertificate(t, string(root))
crt2 := parseCertificate(t, string(federated)) crt2 := parseCertificate(t, string(federated))
@ -436,23 +410,17 @@ func TestAddFederationToClientCAs(t *testing.T) {
//nolint:gosec // test tls config //nolint:gosec // test tls config
func TestAddRootsToCAs(t *testing.T) { func TestAddRootsToCAs(t *testing.T) {
ca := startCATestServer() ca := startCATestServer(t)
defer ca.Close() defer ca.Close()
client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt")) client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt"))
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport)) clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport))
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
root, err := os.ReadFile("testdata/secrets/root_ca.crt") root, err := os.ReadFile("testdata/secrets/root_ca.crt")
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
cert := parseCertificate(t, string(root)) cert := parseCertificate(t, string(root))
pool := x509.NewCertPool() pool := x509.NewCertPool()
@ -491,28 +459,20 @@ func TestAddRootsToCAs(t *testing.T) {
//nolint:gosec // test tls config //nolint:gosec // test tls config
func TestAddFederationToCAs(t *testing.T) { func TestAddFederationToCAs(t *testing.T) {
ca := startCATestServer() ca := startCATestServer(t)
defer ca.Close() defer ca.Close()
client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt")) client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt"))
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport)) clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport))
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
root, err := os.ReadFile("testdata/secrets/root_ca.crt") root, err := os.ReadFile("testdata/secrets/root_ca.crt")
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
federated, err := os.ReadFile("testdata/secrets/federated_ca.crt") federated, err := os.ReadFile("testdata/secrets/federated_ca.crt")
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
crt1 := parseCertificate(t, string(root)) crt1 := parseCertificate(t, string(root))
crt2 := parseCertificate(t, string(federated)) crt2 := parseCertificate(t, string(federated))

@ -17,27 +17,28 @@ import (
"testing" "testing"
"time" "time"
"github.com/smallstep/certificates/api" "github.com/stretchr/testify/require"
"github.com/smallstep/certificates/authority"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/randutil" "go.step.sm/crypto/randutil"
"github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/authority"
) )
func generateOTT(subject string) string { func generateOTT(t *testing.T, subject string) string {
t.Helper()
now := time.Now() now := time.Now()
jwk, err := jose.ReadKey("testdata/secrets/ott_mariano_priv.jwk", jose.WithPassword([]byte("password"))) jwk, err := jose.ReadKey("testdata/secrets/ott_mariano_priv.jwk", jose.WithPassword([]byte("password")))
if err != nil { require.NoError(t, err)
panic(err)
}
opts := new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID) opts := new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID)
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, opts) sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, opts)
if err != nil { require.NoError(t, err)
panic(err)
}
id, err := randutil.ASCII(64) id, err := randutil.ASCII(64)
if err != nil { require.NoError(t, err)
panic(err)
}
cl := struct { cl := struct {
jose.Claims jose.Claims
SANS []string `json:"sans"` SANS []string `json:"sans"`
@ -53,9 +54,8 @@ func generateOTT(subject string) string {
SANS: []string{subject}, SANS: []string{subject},
} }
raw, err := jose.Signed(sig).Claims(cl).CompactSerialize() raw, err := jose.Signed(sig).Claims(cl).CompactSerialize()
if err != nil { require.NoError(t, err)
panic(err)
}
return raw return raw
} }
@ -72,32 +72,28 @@ func startTestServer(baseContext context.Context, tlsConfig *tls.Config, handler
return srv return srv
} }
func startCATestServer() *httptest.Server { func startCATestServer(t *testing.T) *httptest.Server {
config, err := authority.LoadConfiguration("testdata/ca.json") config, err := authority.LoadConfiguration("testdata/ca.json")
if err != nil { require.NoError(t, err)
panic(err)
}
ca, err := New(config) ca, err := New(config)
if err != nil { require.NoError(t, err)
panic(err)
}
// Use a httptest.Server instead // Use a httptest.Server instead
baseContext := buildContext(ca.auth, nil, nil, nil) baseContext := buildContext(ca.auth, nil, nil, nil)
srv := startTestServer(baseContext, ca.srv.TLSConfig, ca.srv.Handler) srv := startTestServer(baseContext, ca.srv.TLSConfig, ca.srv.Handler)
return srv return srv
} }
func sign(domain string) (*Client, *api.SignResponse, crypto.PrivateKey) { func sign(t *testing.T, domain string) (*Client, *api.SignResponse, crypto.PrivateKey) {
srv := startCATestServer() t.Helper()
srv := startCATestServer(t)
defer srv.Close() defer srv.Close()
return signDuration(srv, domain, 0) return signDuration(t, srv, domain, 0)
} }
func signDuration(srv *httptest.Server, domain string, duration time.Duration) (*Client, *api.SignResponse, crypto.PrivateKey) { func signDuration(t *testing.T, srv *httptest.Server, domain string, duration time.Duration) (*Client, *api.SignResponse, crypto.PrivateKey) {
req, pk, err := CreateSignRequest(generateOTT(domain)) t.Helper()
if err != nil { req, pk, err := CreateSignRequest(generateOTT(t, domain))
panic(err) require.NoError(t, err)
}
if duration > 0 { if duration > 0 {
req.NotBefore = api.NewTimeDuration(time.Now()) req.NotBefore = api.NewTimeDuration(time.Now())
@ -105,13 +101,11 @@ func signDuration(srv *httptest.Server, domain string, duration time.Duration) (
} }
client, err := NewClient(srv.URL, WithRootFile("testdata/secrets/root_ca.crt")) client, err := NewClient(srv.URL, WithRootFile("testdata/secrets/root_ca.crt"))
if err != nil { require.NoError(t, err)
panic(err)
}
sr, err := client.Sign(req) sr, err := client.Sign(req)
if err != nil { require.NoError(t, err)
panic(err)
}
return client, sr, pk return client, sr, pk
} }
@ -145,7 +139,7 @@ func serverHandler(t *testing.T, clientDomain string) http.Handler {
func TestClient_GetServerTLSConfig_http(t *testing.T) { func TestClient_GetServerTLSConfig_http(t *testing.T) {
clientDomain := "test.domain" clientDomain := "test.domain"
client, sr, pk := sign("127.0.0.1") client, sr, pk := sign(t, "127.0.0.1")
// Create mTLS server // Create mTLS server
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
@ -212,7 +206,7 @@ func TestClient_GetServerTLSConfig_http(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
client, sr, pk := sign(clientDomain) client, sr, pk := sign(t, clientDomain)
cli := tt.getClient(t, client, sr, pk) cli := tt.getClient(t, client, sr, pk)
if cli == nil { if cli == nil {
return return
@ -246,19 +240,18 @@ func TestClient_GetServerTLSConfig_renew(t *testing.T) {
defer reset() defer reset()
// Start CA // Start CA
ca := startCATestServer() ca := startCATestServer(t)
defer ca.Close() defer ca.Close()
clientDomain := "test.domain" clientDomain := "test.domain"
client, sr, pk := signDuration(ca, "127.0.0.1", 5*time.Second) client, sr, pk := signDuration(t, ca, "127.0.0.1", 5*time.Second)
// Start mTLS server // Start mTLS server
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
tlsConfig, err := client.GetServerTLSConfig(ctx, sr, pk) tlsConfig, err := client.GetServerTLSConfig(ctx, sr, pk)
if err != nil { require.NoError(t, err)
t.Fatalf("Client.GetServerTLSConfig() error = %v", err)
}
srvMTLS := startTestServer(context.Background(), tlsConfig, serverHandler(t, clientDomain)) srvMTLS := startTestServer(context.Background(), tlsConfig, serverHandler(t, clientDomain))
defer srvMTLS.Close() defer srvMTLS.Close()
@ -266,30 +259,26 @@ func TestClient_GetServerTLSConfig_renew(t *testing.T) {
ctx, cancel = context.WithCancel(context.Background()) ctx, cancel = context.WithCancel(context.Background())
defer cancel() defer cancel()
tlsConfig, err = client.GetServerTLSConfig(ctx, sr, pk, VerifyClientCertIfGiven()) tlsConfig, err = client.GetServerTLSConfig(ctx, sr, pk, VerifyClientCertIfGiven())
if err != nil { require.NoError(t, err)
t.Fatalf("Client.GetServerTLSConfig() error = %v", err)
}
srvTLS := startTestServer(context.Background(), tlsConfig, serverHandler(t, clientDomain)) srvTLS := startTestServer(context.Background(), tlsConfig, serverHandler(t, clientDomain))
defer srvTLS.Close() defer srvTLS.Close()
// Transport // Transport
client, sr, pk = signDuration(ca, clientDomain, 5*time.Second) client, sr, pk = signDuration(t, ca, clientDomain, 5*time.Second)
tr1, err := client.Transport(context.Background(), sr, pk) tr1, err := client.Transport(context.Background(), sr, pk)
if err != nil { require.NoError(t, err)
t.Fatalf("Client.Transport() error = %v", err)
}
// Transport with tlsConfig // Transport with tlsConfig
client, sr, pk = signDuration(ca, clientDomain, 5*time.Second) client, sr, pk = signDuration(t, ca, clientDomain, 5*time.Second)
tlsConfig, err = client.GetClientTLSConfig(context.Background(), sr, pk) tlsConfig, err = client.GetClientTLSConfig(context.Background(), sr, pk)
if err != nil { require.NoError(t, err)
t.Fatalf("Client.GetClientTLSConfig() error = %v", err)
}
tr2 := getDefaultTransport(tlsConfig) tr2 := getDefaultTransport(tlsConfig)
// No client cert // No client cert
root, err := RootCertificate(sr) root, err := RootCertificate(sr)
if err != nil { require.NoError(t, err)
t.Fatalf("RootCertificate() error = %v", err)
}
tlsConfig = getDefaultTLSConfig(sr) tlsConfig = getDefaultTLSConfig(sr)
tlsConfig.RootCAs = x509.NewCertPool() tlsConfig.RootCAs = x509.NewCertPool()
tlsConfig.RootCAs.AddCert(root) tlsConfig.RootCAs.AddCert(root)

@ -5,6 +5,7 @@ import (
"net/http" "net/http"
"github.com/rs/xid" "github.com/rs/xid"
"go.step.sm/crypto/randutil" "go.step.sm/crypto/randutil"
) )
@ -74,17 +75,17 @@ func newRequestID() string {
return requestID return requestID
} }
type requestIDKey struct{} type contextKey struct{}
// NewContext returns a new context with the given request ID added to the // NewContext returns a new context with the given request ID added to the
// context. // context.
func NewContext(ctx context.Context, requestID string) context.Context { func NewContext(ctx context.Context, requestID string) context.Context {
return context.WithValue(ctx, requestIDKey{}, requestID) return context.WithValue(ctx, contextKey{}, requestID)
} }
// FromContext returns the request ID from the context if it exists and // FromContext returns the request ID from the context if it exists and
// is not the empty value. // is not the empty value.
func FromContext(ctx context.Context) (string, bool) { func FromContext(ctx context.Context) (string, bool) {
v, ok := ctx.Value(requestIDKey{}).(string) v, ok := ctx.Value(contextKey{}).(string)
return v, ok && v != "" return v, ok && v != ""
} }

@ -19,11 +19,15 @@ func newRequest(t *testing.T) *http.Request {
func Test_Middleware(t *testing.T) { func Test_Middleware(t *testing.T) {
requestWithID := newRequest(t) requestWithID := newRequest(t)
requestWithID.Header.Set("X-Request-Id", "reqID") requestWithID.Header.Set("X-Request-Id", "reqID")
requestWithoutID := newRequest(t) requestWithoutID := newRequest(t)
requestWithEmptyHeader := newRequest(t) requestWithEmptyHeader := newRequest(t)
requestWithEmptyHeader.Header.Set("X-Request-Id", "") requestWithEmptyHeader.Header.Set("X-Request-Id", "")
requestWithSmallstepID := newRequest(t) requestWithSmallstepID := newRequest(t)
requestWithSmallstepID.Header.Set("X-Smallstep-Id", "smallstepID") requestWithSmallstepID.Header.Set("X-Smallstep-Id", "smallstepID")
tests := []struct { tests := []struct {
name string name string
traceHeader string traceHeader string

@ -2,19 +2,19 @@ package userid
import "context" import "context"
type userIDKey struct{} type contextKey struct{}
// NewContext returns a new context with the given user ID added to the // NewContext returns a new context with the given user ID added to the
// context. // context.
// TODO(hs): this doesn't seem to be used / set currently; implement // TODO(hs): this doesn't seem to be used / set currently; implement
// when/where it makes sense. // when/where it makes sense.
func NewContext(ctx context.Context, userID string) context.Context { func NewContext(ctx context.Context, userID string) context.Context {
return context.WithValue(ctx, userIDKey{}, userID) return context.WithValue(ctx, contextKey{}, userID)
} }
// FromContext returns the user ID from the context if it exists // FromContext returns the user ID from the context if it exists
// and is not empty. // and is not empty.
func FromContext(ctx context.Context) (string, bool) { func FromContext(ctx context.Context) (string, bool) {
v, ok := ctx.Value(userIDKey{}).(string) v, ok := ctx.Value(contextKey{}).(string)
return v, ok && v != "" return v, ok && v != ""
} }

@ -9,6 +9,7 @@ import (
"time" "time"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/smallstep/certificates/internal/requestid" "github.com/smallstep/certificates/internal/requestid"
"github.com/smallstep/certificates/internal/userid" "github.com/smallstep/certificates/internal/userid"
) )

@ -9,6 +9,7 @@ import (
"github.com/newrelic/go-agent/v3/newrelic" "github.com/newrelic/go-agent/v3/newrelic"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/internal/requestid" "github.com/smallstep/certificates/internal/requestid"
"github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/logging"
) )

@ -1,132 +0,0 @@
package e2e
import (
"context"
"encoding/json"
"fmt"
"net"
"path/filepath"
"sync"
"testing"
"github.com/smallstep/certificates/authority/config"
"github.com/smallstep/certificates/ca"
"github.com/smallstep/certificates/ca/client"
"github.com/smallstep/certificates/errs"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.step.sm/crypto/minica"
"go.step.sm/crypto/pemutil"
)
// reservePort "reserves" a TCP port by opening a listener on a random
// port and immediately closing it. The port can then be assumed to be
// available for running a server on.
func reservePort(t *testing.T) (host, port string) {
t.Helper()
l, err := net.Listen("tcp", ":0")
require.NoError(t, err)
address := l.Addr().String()
err = l.Close()
require.NoError(t, err)
host, port, err = net.SplitHostPort(address)
require.NoError(t, err)
return
}
func Test_reflectRequestID(t *testing.T) {
dir := t.TempDir()
m, err := minica.New(minica.WithName("Step E2E"))
require.NoError(t, err)
rootFilepath := filepath.Join(dir, "root.crt")
_, err = pemutil.Serialize(m.Root, pemutil.WithFilename(rootFilepath))
require.NoError(t, err)
intermediateCertFilepath := filepath.Join(dir, "intermediate.crt")
_, err = pemutil.Serialize(m.Intermediate, pemutil.WithFilename(intermediateCertFilepath))
require.NoError(t, err)
intermediateKeyFilepath := filepath.Join(dir, "intermediate.key")
_, err = pemutil.Serialize(m.Signer, pemutil.WithFilename(intermediateKeyFilepath))
require.NoError(t, err)
// get a random address to listen on and connect to; currently no nicer way to get one before starting the server
// TODO(hs): find/implement a nicer way to expose the CA URL, similar to how e.g. httptest.Server exposes it?
host, port := reservePort(t)
cfg := &config.Config{
Root: []string{rootFilepath},
IntermediateCert: intermediateCertFilepath,
IntermediateKey: intermediateKeyFilepath,
Address: net.JoinHostPort(host, port), // reuse the address that was just "reserved"
DNSNames: []string{"127.0.0.1", "[::1]", "localhost"},
AuthorityConfig: &config.AuthConfig{
AuthorityID: "stepca-test",
DeploymentType: "standalone-test",
},
Logger: json.RawMessage(`{"format": "text"}`),
}
c, err := ca.New(cfg)
require.NoError(t, err)
// instantiate a client for the CA running at the random address
caClient, err := ca.NewClient(
fmt.Sprintf("https://localhost:%s", port),
ca.WithRootFile(rootFilepath),
)
require.NoError(t, err)
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
err = c.Run()
require.Error(t, err) // expect error when server is stopped
}()
// require OK health response as the baseline
ctx := context.Background()
healthResponse, err := caClient.HealthWithContext(ctx)
require.NoError(t, err)
if assert.NotNil(t, healthResponse) {
require.Equal(t, "ok", healthResponse.Status)
}
// expect an error when retrieving an invalid root
rootResponse, err := caClient.RootWithContext(ctx, "invalid")
if assert.Error(t, err) {
apiErr := &errs.Error{}
if assert.ErrorAs(t, err, &apiErr) {
assert.Equal(t, 404, apiErr.StatusCode())
assert.Equal(t, "The requested resource could not be found. Please see the certificate authority logs for more info.", apiErr.Err.Error())
assert.NotEmpty(t, apiErr.RequestID)
// TODO: include the below error in the JSON? It's currently only output to the CA logs. Also see https://github.com/smallstep/certificates/pull/759
//assert.Equal(t, "/root/invalid was not found: certificate with fingerprint invalid was not found", apiErr.Msg)
}
}
assert.Nil(t, rootResponse)
// expect an error when retrieving an invalid root and provided request ID
rootResponse, err = caClient.RootWithContext(client.NewRequestIDContext(ctx, "reqID"), "invalid")
if assert.Error(t, err) {
apiErr := &errs.Error{}
if assert.ErrorAs(t, err, &apiErr) {
assert.Equal(t, 404, apiErr.StatusCode())
assert.Equal(t, "The requested resource could not be found. Please see the certificate authority logs for more info.", apiErr.Err.Error())
assert.Equal(t, "reqID", apiErr.RequestID)
}
}
assert.Nil(t, rootResponse)
// done testing; stop and wait for the server to quit
err = c.Stop()
require.NoError(t, err)
wg.Wait()
}

@ -0,0 +1,289 @@
package integration
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"fmt"
"net"
"net/http"
"net/http/httptest"
"path/filepath"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/keyutil"
"go.step.sm/crypto/minica"
"go.step.sm/crypto/pemutil"
"go.step.sm/crypto/randutil"
"go.step.sm/crypto/x509util"
"github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/authority/config"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/ca"
"github.com/smallstep/certificates/ca/client"
"github.com/smallstep/certificates/errs"
)
// reservePort "reserves" a TCP port by opening a listener on a random
// port and immediately closing it. The port can then be assumed to be
// available for running a server on.
func reservePort(t *testing.T) (host, port string) {
t.Helper()
l, err := net.Listen("tcp", ":0")
require.NoError(t, err)
address := l.Addr().String()
err = l.Close()
require.NoError(t, err)
host, port, err = net.SplitHostPort(address)
require.NoError(t, err)
return
}
func Test_reflectRequestID(t *testing.T) {
dir := t.TempDir()
m, err := minica.New(minica.WithName("Step E2E"))
require.NoError(t, err)
rootFilepath := filepath.Join(dir, "root.crt")
_, err = pemutil.Serialize(m.Root, pemutil.WithFilename(rootFilepath))
require.NoError(t, err)
intermediateCertFilepath := filepath.Join(dir, "intermediate.crt")
_, err = pemutil.Serialize(m.Intermediate, pemutil.WithFilename(intermediateCertFilepath))
require.NoError(t, err)
intermediateKeyFilepath := filepath.Join(dir, "intermediate.key")
_, err = pemutil.Serialize(m.Signer, pemutil.WithFilename(intermediateKeyFilepath))
require.NoError(t, err)
// get a random address to listen on and connect to; currently no nicer way to get one before starting the server
// TODO(hs): find/implement a nicer way to expose the CA URL, similar to how e.g. httptest.Server exposes it?
host, port := reservePort(t)
authorizingSrv := newAuthorizingServer(t, m)
defer authorizingSrv.Close()
authorizingSrv.StartTLS()
password := []byte("1234")
jwk, jwe, err := jose.GenerateDefaultKeyPair(password)
require.NoError(t, err)
encryptedKey, err := jwe.CompactSerialize()
require.NoError(t, err)
prov := &provisioner.JWK{
ID: "jwk",
Name: "jwk",
Type: "JWK",
Key: jwk,
EncryptedKey: encryptedKey,
Claims: &config.GlobalProvisionerClaims,
Options: &provisioner.Options{
Webhooks: []*provisioner.Webhook{
{
ID: "webhook",
Name: "webhook-test",
URL: fmt.Sprintf("%s/authorize", authorizingSrv.URL),
Kind: "AUTHORIZING",
CertType: "X509",
},
},
},
}
err = prov.Init(provisioner.Config{})
require.NoError(t, err)
cfg := &config.Config{
Root: []string{rootFilepath},
IntermediateCert: intermediateCertFilepath,
IntermediateKey: intermediateKeyFilepath,
Address: net.JoinHostPort(host, port), // reuse the address that was just "reserved"
DNSNames: []string{"127.0.0.1", "[::1]", "localhost"},
AuthorityConfig: &config.AuthConfig{
AuthorityID: "stepca-test",
DeploymentType: "standalone-test",
Provisioners: provisioner.List{prov},
},
Logger: json.RawMessage(`{"format": "text"}`),
}
c, err := ca.New(cfg)
require.NoError(t, err)
// instantiate a client for the CA running at the random address
caClient, err := ca.NewClient(
fmt.Sprintf("https://localhost:%s", port),
ca.WithRootFile(rootFilepath),
)
require.NoError(t, err)
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
err = c.Run()
require.ErrorIs(t, err, http.ErrServerClosed)
}()
// require OK health response as the baseline
ctx := context.Background()
healthResponse, err := caClient.HealthWithContext(ctx)
require.NoError(t, err)
if assert.NotNil(t, healthResponse) {
require.Equal(t, "ok", healthResponse.Status)
}
// expect an error when retrieving an invalid root
rootResponse, err := caClient.RootWithContext(ctx, "invalid")
var firstErr *errs.Error
if assert.ErrorAs(t, err, &firstErr) {
assert.Equal(t, 404, firstErr.StatusCode())
assert.Equal(t, "The requested resource could not be found. Please see the certificate authority logs for more info.", firstErr.Err.Error())
assert.NotEmpty(t, firstErr.RequestID)
// TODO: include the below error in the JSON? It's currently only output to the CA logs. Also see https://github.com/smallstep/certificates/pull/759
//assert.Equal(t, "/root/invalid was not found: certificate with fingerprint invalid was not found", apiErr.Msg)
}
assert.Nil(t, rootResponse)
// expect an error when retrieving an invalid root and provided request ID
rootResponse, err = caClient.RootWithContext(client.NewRequestIDContext(ctx, "reqID"), "invalid")
var secondErr *errs.Error
if assert.ErrorAs(t, err, &secondErr) {
assert.Equal(t, 404, secondErr.StatusCode())
assert.Equal(t, "The requested resource could not be found. Please see the certificate authority logs for more info.", secondErr.Err.Error())
assert.Equal(t, "reqID", secondErr.RequestID)
}
assert.Nil(t, rootResponse)
// prepare a Sign request
subject := "test"
decryptedJWK := decryptPrivateKey(t, jwe, password)
ott := generateOTT(t, decryptedJWK, subject)
signer, err := keyutil.GenerateDefaultSigner()
require.NoError(t, err)
csr, err := x509util.CreateCertificateRequest(subject, []string{subject}, signer)
require.NoError(t, err)
// perform the Sign request using the OTT and CSR
signResponse, err := caClient.SignWithContext(client.NewRequestIDContext(ctx, "signRequestID"), &api.SignRequest{
CsrPEM: api.CertificateRequest{CertificateRequest: csr},
OTT: ott,
NotAfter: api.NewTimeDuration(time.Now().Add(1 * time.Hour)),
NotBefore: api.NewTimeDuration(time.Now().Add(-1 * time.Hour)),
})
assert.NoError(t, err)
// assert a certificate was returned for the subject "test"
if assert.NotNil(t, signResponse) {
assert.Len(t, signResponse.CertChainPEM, 2)
cert, err := x509.ParseCertificate(signResponse.CertChainPEM[0].Raw)
assert.NoError(t, err)
if assert.NotNil(t, cert) {
assert.Equal(t, "test", cert.Subject.CommonName)
assert.Contains(t, cert.DNSNames, "test")
}
}
// done testing; stop and wait for the server to quit
err = c.Stop()
require.NoError(t, err)
wg.Wait()
}
func decryptPrivateKey(t *testing.T, jwe *jose.JSONWebEncryption, pass []byte) *jose.JSONWebKey {
t.Helper()
d, err := jwe.Decrypt(pass)
require.NoError(t, err)
jwk := &jose.JSONWebKey{}
err = json.Unmarshal(d, jwk)
require.NoError(t, err)
return jwk
}
func generateOTT(t *testing.T, jwk *jose.JSONWebKey, subject string) string {
t.Helper()
now := time.Now()
keyID, err := jose.Thumbprint(jwk)
require.NoError(t, err)
opts := new(jose.SignerOptions).WithType("JWT").WithHeader("kid", keyID)
signer, err := jose.NewSigner(jose.SigningKey{Key: jwk.Key}, opts)
require.NoError(t, err)
id, err := randutil.ASCII(64)
require.NoError(t, err)
cl := struct {
jose.Claims
SANS []string `json:"sans"`
}{
Claims: jose.Claims{
ID: id,
Subject: subject,
Issuer: "jwk",
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(time.Minute)),
Audience: []string{"https://127.0.0.1/1.0/sign"},
},
SANS: []string{subject},
}
raw, err := jose.Signed(signer).Claims(cl).CompactSerialize()
require.NoError(t, err)
return raw
}
func newAuthorizingServer(t *testing.T, ca *minica.CA) *httptest.Server {
t.Helper()
key, err := keyutil.GenerateDefaultSigner()
require.NoError(t, err)
csr, err := x509util.CreateCertificateRequest("127.0.0.1", []string{"127.0.0.1"}, key)
require.NoError(t, err)
crt, err := ca.SignCSR(csr)
require.NoError(t, err)
srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if assert.Equal(t, "signRequestID", r.Header.Get("X-Request-Id")) {
json.NewEncoder(w).Encode(struct{ Allow bool }{Allow: true})
w.WriteHeader(http.StatusOK)
return
}
w.WriteHeader(http.StatusBadRequest)
}))
trustedRoots := x509.NewCertPool()
trustedRoots.AddCert(ca.Root)
srv.TLS = &tls.Config{
Certificates: []tls.Certificate{
{
Certificate: [][]byte{crt.Raw, ca.Intermediate.Raw},
PrivateKey: key,
Leaf: crt,
},
},
ClientCAs: trustedRoots,
ClientAuth: tls.RequireAndVerifyClientCert,
ServerName: "localhost",
}
return srv
}
Loading…
Cancel
Save