diff --git a/authority/provisioner/webhook_test.go b/authority/provisioner/webhook_test.go index 4c80796f..90583418 100644 --- a/authority/provisioner/webhook_test.go +++ b/authority/provisioner/webhook_test.go @@ -17,13 +17,15 @@ import ( "testing" "time" - "github.com/smallstep/certificates/internal/requestid" - "github.com/smallstep/certificates/webhook" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.step.sm/crypto/pemutil" "go.step.sm/crypto/x509util" "go.step.sm/linkedca" + + "github.com/smallstep/certificates/internal/requestid" + "github.com/smallstep/certificates/webhook" ) func TestWebhookController_isCertTypeOK(t *testing.T) { @@ -103,7 +105,8 @@ func TestWebhookController_isCertTypeOK(t *testing.T) { // withRequestID is a helper that calls into [requestid.NewContext] and returns // a new context with the requestID added. -func withRequestID(ctx context.Context, requestID string) context.Context { +func withRequestID(t *testing.T, ctx context.Context, requestID string) context.Context { + t.Helper() return requestid.NewContext(ctx, requestID) } @@ -138,7 +141,7 @@ func TestWebhookController_Enrich(t *testing.T) { webhooks: []*Webhook{{Name: "people", Kind: "ENRICHING"}}, TemplateData: x509util.TemplateData{}, }, - ctx: withRequestID(context.Background(), "reqID"), + ctx: withRequestID(t, context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: true, Data: map[string]any{"role": "bar"}}}, expectErr: false, @@ -153,7 +156,7 @@ func TestWebhookController_Enrich(t *testing.T) { }, TemplateData: x509util.TemplateData{}, }, - ctx: withRequestID(context.Background(), "reqID"), + ctx: withRequestID(t, context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{ {Allow: true, Data: map[string]any{"role": "bar"}}, @@ -177,7 +180,7 @@ func TestWebhookController_Enrich(t *testing.T) { TemplateData: x509util.TemplateData{}, certType: linkedca.Webhook_X509, }, - ctx: withRequestID(context.Background(), "reqID"), + ctx: withRequestID(t, context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{ {Allow: true, Data: map[string]any{"role": "bar"}}, @@ -197,7 +200,7 @@ func TestWebhookController_Enrich(t *testing.T) { TemplateData: x509util.TemplateData{}, options: []webhook.RequestBodyOption{webhook.WithX5CCertificate(cert)}, }, - ctx: withRequestID(context.Background(), "reqID"), + ctx: withRequestID(t, context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: true, Data: map[string]any{"role": "bar"}}}, expectErr: false, @@ -220,7 +223,7 @@ func TestWebhookController_Enrich(t *testing.T) { webhooks: []*Webhook{{Name: "people", Kind: "ENRICHING"}}, TemplateData: x509util.TemplateData{}, }, - ctx: withRequestID(context.Background(), "reqID"), + ctx: withRequestID(t, context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: false}}, expectErr: true, @@ -235,7 +238,7 @@ func TestWebhookController_Enrich(t *testing.T) { PublicKey: []byte("bad"), })}, }, - ctx: withRequestID(context.Background(), "reqID"), + ctx: withRequestID(t, context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: false}}, expectErr: true, @@ -296,7 +299,7 @@ func TestWebhookController_Authorize(t *testing.T) { client: http.DefaultClient, webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING"}}, }, - ctx: withRequestID(context.Background(), "reqID"), + ctx: withRequestID(t, context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: true}}, expectErr: false, @@ -307,7 +310,7 @@ func TestWebhookController_Authorize(t *testing.T) { webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING", CertType: linkedca.Webhook_X509.String()}}, certType: linkedca.Webhook_SSH, }, - ctx: withRequestID(context.Background(), "reqID"), + ctx: withRequestID(t, context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: false}}, expectErr: false, @@ -318,7 +321,7 @@ func TestWebhookController_Authorize(t *testing.T) { webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING"}}, options: []webhook.RequestBodyOption{webhook.WithX5CCertificate(cert)}, }, - ctx: withRequestID(context.Background(), "reqID"), + ctx: withRequestID(t, context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: true}}, expectErr: false, @@ -339,7 +342,7 @@ func TestWebhookController_Authorize(t *testing.T) { client: http.DefaultClient, webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING"}}, }, - ctx: withRequestID(context.Background(), "reqID"), + ctx: withRequestID(t, context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: false}}, expectErr: true, @@ -352,7 +355,7 @@ func TestWebhookController_Authorize(t *testing.T) { PublicKey: []byte("bad"), })}, }, - ctx: withRequestID(context.Background(), "reqID"), + ctx: withRequestID(t, context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: false}}, expectErr: true, @@ -568,7 +571,7 @@ func TestWebhook_Do(t *testing.T) { ctx := context.Background() if tc.requestID != "" { - ctx = withRequestID(context.Background(), tc.requestID) + ctx = withRequestID(t, ctx, tc.requestID) } ctx, cancel := context.WithTimeout(ctx, time.Second*10) defer cancel() diff --git a/ca/client/requestid.go b/ca/client/requestid.go index 2bebb7e5..1fb785eb 100644 --- a/ca/client/requestid.go +++ b/ca/client/requestid.go @@ -2,17 +2,17 @@ package client import "context" -type requestIDKey struct{} +type contextKey struct{} // NewRequestIDContext returns a new context with the given request ID added to the // context. func NewRequestIDContext(ctx context.Context, requestID string) context.Context { - return context.WithValue(ctx, requestIDKey{}, requestID) + return context.WithValue(ctx, contextKey{}, requestID) } // RequestIDFromContext returns the request ID from the context if it exists. // and is not empty. func RequestIDFromContext(ctx context.Context) (string, bool) { - v, ok := ctx.Value(requestIDKey{}).(string) + v, ok := ctx.Value(contextKey{}).(string) return v, ok && v != "" } diff --git a/ca/provisioner_test.go b/ca/provisioner_test.go index 39193f3f..5a754f08 100644 --- a/ca/provisioner_test.go +++ b/ca/provisioner_test.go @@ -7,6 +7,8 @@ import ( "testing" "time" + "github.com/stretchr/testify/require" + "go.step.sm/crypto/jose" "go.step.sm/crypto/pemutil" "go.step.sm/crypto/x509util" @@ -41,14 +43,12 @@ func getTestProvisioner(t *testing.T, caURL string) *Provisioner { } func TestNewProvisioner(t *testing.T) { - ca := startCATestServer() + ca := startCATestServer(t) defer ca.Close() want := getTestProvisioner(t, ca.URL) caBundle, err := os.ReadFile("testdata/secrets/root_ca.crt") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) type args struct { name string diff --git a/ca/tls_options_test.go b/ca/tls_options_test.go index c29947ad..4ac6ff85 100644 --- a/ca/tls_options_test.go +++ b/ca/tls_options_test.go @@ -10,6 +10,8 @@ import ( "sort" "testing" + "github.com/stretchr/testify/require" + "github.com/smallstep/certificates/api" ) @@ -196,23 +198,17 @@ func TestAddClientCA(t *testing.T) { //nolint:gosec // test tls config func TestAddRootsToRootCAs(t *testing.T) { - ca := startCATestServer() + ca := startCATestServer(t) defer ca.Close() client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt")) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) root, err := os.ReadFile("testdata/secrets/root_ca.crt") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) cert := parseCertificate(t, string(root)) pool := x509.NewCertPool() @@ -251,23 +247,17 @@ func TestAddRootsToRootCAs(t *testing.T) { //nolint:gosec // test tls config func TestAddRootsToClientCAs(t *testing.T) { - ca := startCATestServer() + ca := startCATestServer(t) defer ca.Close() client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt")) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) root, err := os.ReadFile("testdata/secrets/root_ca.crt") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) cert := parseCertificate(t, string(root)) pool := x509.NewCertPool() @@ -306,28 +296,20 @@ func TestAddRootsToClientCAs(t *testing.T) { //nolint:gosec // test tls config func TestAddFederationToRootCAs(t *testing.T) { - ca := startCATestServer() + ca := startCATestServer(t) defer ca.Close() client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt")) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) root, err := os.ReadFile("testdata/secrets/root_ca.crt") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) federated, err := os.ReadFile("testdata/secrets/federated_ca.crt") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) crt1 := parseCertificate(t, string(root)) crt2 := parseCertificate(t, string(federated)) @@ -371,28 +353,20 @@ func TestAddFederationToRootCAs(t *testing.T) { //nolint:gosec // test tls config func TestAddFederationToClientCAs(t *testing.T) { - ca := startCATestServer() + ca := startCATestServer(t) defer ca.Close() client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt")) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) root, err := os.ReadFile("testdata/secrets/root_ca.crt") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) federated, err := os.ReadFile("testdata/secrets/federated_ca.crt") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) crt1 := parseCertificate(t, string(root)) crt2 := parseCertificate(t, string(federated)) @@ -436,23 +410,17 @@ func TestAddFederationToClientCAs(t *testing.T) { //nolint:gosec // test tls config func TestAddRootsToCAs(t *testing.T) { - ca := startCATestServer() + ca := startCATestServer(t) defer ca.Close() client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt")) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) root, err := os.ReadFile("testdata/secrets/root_ca.crt") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) cert := parseCertificate(t, string(root)) pool := x509.NewCertPool() @@ -491,28 +459,20 @@ func TestAddRootsToCAs(t *testing.T) { //nolint:gosec // test tls config func TestAddFederationToCAs(t *testing.T) { - ca := startCATestServer() + ca := startCATestServer(t) defer ca.Close() client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt")) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) root, err := os.ReadFile("testdata/secrets/root_ca.crt") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) federated, err := os.ReadFile("testdata/secrets/federated_ca.crt") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) crt1 := parseCertificate(t, string(root)) crt2 := parseCertificate(t, string(federated)) diff --git a/ca/tls_test.go b/ca/tls_test.go index a19685ce..d1ce11ea 100644 --- a/ca/tls_test.go +++ b/ca/tls_test.go @@ -17,27 +17,28 @@ import ( "testing" "time" - "github.com/smallstep/certificates/api" - "github.com/smallstep/certificates/authority" + "github.com/stretchr/testify/require" + "go.step.sm/crypto/jose" "go.step.sm/crypto/randutil" + + "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/authority" ) -func generateOTT(subject string) string { +func generateOTT(t *testing.T, subject string) string { + t.Helper() now := time.Now() jwk, err := jose.ReadKey("testdata/secrets/ott_mariano_priv.jwk", jose.WithPassword([]byte("password"))) - if err != nil { - panic(err) - } + require.NoError(t, err) + opts := new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID) sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, opts) - if err != nil { - panic(err) - } + require.NoError(t, err) + id, err := randutil.ASCII(64) - if err != nil { - panic(err) - } + require.NoError(t, err) + cl := struct { jose.Claims SANS []string `json:"sans"` @@ -53,9 +54,8 @@ func generateOTT(subject string) string { SANS: []string{subject}, } raw, err := jose.Signed(sig).Claims(cl).CompactSerialize() - if err != nil { - panic(err) - } + require.NoError(t, err) + return raw } @@ -72,32 +72,28 @@ func startTestServer(baseContext context.Context, tlsConfig *tls.Config, handler return srv } -func startCATestServer() *httptest.Server { +func startCATestServer(t *testing.T) *httptest.Server { config, err := authority.LoadConfiguration("testdata/ca.json") - if err != nil { - panic(err) - } + require.NoError(t, err) ca, err := New(config) - if err != nil { - panic(err) - } + require.NoError(t, err) // Use a httptest.Server instead baseContext := buildContext(ca.auth, nil, nil, nil) srv := startTestServer(baseContext, ca.srv.TLSConfig, ca.srv.Handler) return srv } -func sign(domain string) (*Client, *api.SignResponse, crypto.PrivateKey) { - srv := startCATestServer() +func sign(t *testing.T, domain string) (*Client, *api.SignResponse, crypto.PrivateKey) { + t.Helper() + srv := startCATestServer(t) defer srv.Close() - return signDuration(srv, domain, 0) + return signDuration(t, srv, domain, 0) } -func signDuration(srv *httptest.Server, domain string, duration time.Duration) (*Client, *api.SignResponse, crypto.PrivateKey) { - req, pk, err := CreateSignRequest(generateOTT(domain)) - if err != nil { - panic(err) - } +func signDuration(t *testing.T, srv *httptest.Server, domain string, duration time.Duration) (*Client, *api.SignResponse, crypto.PrivateKey) { + t.Helper() + req, pk, err := CreateSignRequest(generateOTT(t, domain)) + require.NoError(t, err) if duration > 0 { req.NotBefore = api.NewTimeDuration(time.Now()) @@ -105,13 +101,11 @@ func signDuration(srv *httptest.Server, domain string, duration time.Duration) ( } client, err := NewClient(srv.URL, WithRootFile("testdata/secrets/root_ca.crt")) - if err != nil { - panic(err) - } + require.NoError(t, err) + sr, err := client.Sign(req) - if err != nil { - panic(err) - } + require.NoError(t, err) + return client, sr, pk } @@ -145,7 +139,7 @@ func serverHandler(t *testing.T, clientDomain string) http.Handler { func TestClient_GetServerTLSConfig_http(t *testing.T) { clientDomain := "test.domain" - client, sr, pk := sign("127.0.0.1") + client, sr, pk := sign(t, "127.0.0.1") // Create mTLS server ctx, cancel := context.WithCancel(context.Background()) @@ -212,7 +206,7 @@ func TestClient_GetServerTLSConfig_http(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - client, sr, pk := sign(clientDomain) + client, sr, pk := sign(t, clientDomain) cli := tt.getClient(t, client, sr, pk) if cli == nil { return @@ -246,19 +240,18 @@ func TestClient_GetServerTLSConfig_renew(t *testing.T) { defer reset() // Start CA - ca := startCATestServer() + ca := startCATestServer(t) defer ca.Close() clientDomain := "test.domain" - client, sr, pk := signDuration(ca, "127.0.0.1", 5*time.Second) + client, sr, pk := signDuration(t, ca, "127.0.0.1", 5*time.Second) // Start mTLS server ctx, cancel := context.WithCancel(context.Background()) defer cancel() tlsConfig, err := client.GetServerTLSConfig(ctx, sr, pk) - if err != nil { - t.Fatalf("Client.GetServerTLSConfig() error = %v", err) - } + require.NoError(t, err) + srvMTLS := startTestServer(context.Background(), tlsConfig, serverHandler(t, clientDomain)) defer srvMTLS.Close() @@ -266,30 +259,26 @@ func TestClient_GetServerTLSConfig_renew(t *testing.T) { ctx, cancel = context.WithCancel(context.Background()) defer cancel() tlsConfig, err = client.GetServerTLSConfig(ctx, sr, pk, VerifyClientCertIfGiven()) - if err != nil { - t.Fatalf("Client.GetServerTLSConfig() error = %v", err) - } + require.NoError(t, err) + srvTLS := startTestServer(context.Background(), tlsConfig, serverHandler(t, clientDomain)) defer srvTLS.Close() // Transport - client, sr, pk = signDuration(ca, clientDomain, 5*time.Second) + client, sr, pk = signDuration(t, ca, clientDomain, 5*time.Second) tr1, err := client.Transport(context.Background(), sr, pk) - if err != nil { - t.Fatalf("Client.Transport() error = %v", err) - } + require.NoError(t, err) + // Transport with tlsConfig - client, sr, pk = signDuration(ca, clientDomain, 5*time.Second) + client, sr, pk = signDuration(t, ca, clientDomain, 5*time.Second) tlsConfig, err = client.GetClientTLSConfig(context.Background(), sr, pk) - if err != nil { - t.Fatalf("Client.GetClientTLSConfig() error = %v", err) - } + require.NoError(t, err) + tr2 := getDefaultTransport(tlsConfig) // No client cert root, err := RootCertificate(sr) - if err != nil { - t.Fatalf("RootCertificate() error = %v", err) - } + require.NoError(t, err) + tlsConfig = getDefaultTLSConfig(sr) tlsConfig.RootCAs = x509.NewCertPool() tlsConfig.RootCAs.AddCert(root) diff --git a/internal/requestid/requestid.go b/internal/requestid/requestid.go index 7008d469..ace08f16 100644 --- a/internal/requestid/requestid.go +++ b/internal/requestid/requestid.go @@ -5,6 +5,7 @@ import ( "net/http" "github.com/rs/xid" + "go.step.sm/crypto/randutil" ) @@ -74,17 +75,17 @@ func newRequestID() string { return requestID } -type requestIDKey struct{} +type contextKey struct{} // NewContext returns a new context with the given request ID added to the // context. func NewContext(ctx context.Context, requestID string) context.Context { - return context.WithValue(ctx, requestIDKey{}, requestID) + return context.WithValue(ctx, contextKey{}, requestID) } // FromContext returns the request ID from the context if it exists and // is not the empty value. func FromContext(ctx context.Context) (string, bool) { - v, ok := ctx.Value(requestIDKey{}).(string) + v, ok := ctx.Value(contextKey{}).(string) return v, ok && v != "" } diff --git a/internal/requestid/requestid_test.go b/internal/requestid/requestid_test.go index 4d0e872d..84a9021f 100644 --- a/internal/requestid/requestid_test.go +++ b/internal/requestid/requestid_test.go @@ -19,11 +19,15 @@ func newRequest(t *testing.T) *http.Request { func Test_Middleware(t *testing.T) { requestWithID := newRequest(t) requestWithID.Header.Set("X-Request-Id", "reqID") + requestWithoutID := newRequest(t) + requestWithEmptyHeader := newRequest(t) requestWithEmptyHeader.Header.Set("X-Request-Id", "") + requestWithSmallstepID := newRequest(t) requestWithSmallstepID.Header.Set("X-Smallstep-Id", "smallstepID") + tests := []struct { name string traceHeader string diff --git a/internal/userid/userid.go b/internal/userid/userid.go index bab4908f..48087da8 100644 --- a/internal/userid/userid.go +++ b/internal/userid/userid.go @@ -2,19 +2,19 @@ package userid import "context" -type userIDKey struct{} +type contextKey struct{} // NewContext returns a new context with the given user ID added to the // context. // TODO(hs): this doesn't seem to be used / set currently; implement // when/where it makes sense. func NewContext(ctx context.Context, userID string) context.Context { - return context.WithValue(ctx, userIDKey{}, userID) + return context.WithValue(ctx, contextKey{}, userID) } // FromContext returns the user ID from the context if it exists // and is not empty. func FromContext(ctx context.Context) (string, bool) { - v, ok := ctx.Value(userIDKey{}).(string) + v, ok := ctx.Value(contextKey{}).(string) return v, ok && v != "" } diff --git a/logging/handler.go b/logging/handler.go index a29383b2..06fc56d3 100644 --- a/logging/handler.go +++ b/logging/handler.go @@ -9,6 +9,7 @@ import ( "time" "github.com/sirupsen/logrus" + "github.com/smallstep/certificates/internal/requestid" "github.com/smallstep/certificates/internal/userid" ) diff --git a/monitoring/monitoring.go b/monitoring/monitoring.go index 7c88ab3b..2ca2ef54 100644 --- a/monitoring/monitoring.go +++ b/monitoring/monitoring.go @@ -9,6 +9,7 @@ import ( "github.com/newrelic/go-agent/v3/newrelic" "github.com/pkg/errors" + "github.com/smallstep/certificates/internal/requestid" "github.com/smallstep/certificates/logging" ) diff --git a/test/e2e/requestid_test.go b/test/e2e/requestid_test.go deleted file mode 100644 index d2f968c3..00000000 --- a/test/e2e/requestid_test.go +++ /dev/null @@ -1,132 +0,0 @@ -package e2e - -import ( - "context" - "encoding/json" - "fmt" - "net" - "path/filepath" - "sync" - "testing" - - "github.com/smallstep/certificates/authority/config" - "github.com/smallstep/certificates/ca" - "github.com/smallstep/certificates/ca/client" - "github.com/smallstep/certificates/errs" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "go.step.sm/crypto/minica" - "go.step.sm/crypto/pemutil" -) - -// reservePort "reserves" a TCP port by opening a listener on a random -// port and immediately closing it. The port can then be assumed to be -// available for running a server on. -func reservePort(t *testing.T) (host, port string) { - t.Helper() - l, err := net.Listen("tcp", ":0") - require.NoError(t, err) - - address := l.Addr().String() - err = l.Close() - require.NoError(t, err) - - host, port, err = net.SplitHostPort(address) - require.NoError(t, err) - - return -} - -func Test_reflectRequestID(t *testing.T) { - dir := t.TempDir() - m, err := minica.New(minica.WithName("Step E2E")) - require.NoError(t, err) - - rootFilepath := filepath.Join(dir, "root.crt") - _, err = pemutil.Serialize(m.Root, pemutil.WithFilename(rootFilepath)) - require.NoError(t, err) - - intermediateCertFilepath := filepath.Join(dir, "intermediate.crt") - _, err = pemutil.Serialize(m.Intermediate, pemutil.WithFilename(intermediateCertFilepath)) - require.NoError(t, err) - - intermediateKeyFilepath := filepath.Join(dir, "intermediate.key") - _, err = pemutil.Serialize(m.Signer, pemutil.WithFilename(intermediateKeyFilepath)) - require.NoError(t, err) - - // get a random address to listen on and connect to; currently no nicer way to get one before starting the server - // TODO(hs): find/implement a nicer way to expose the CA URL, similar to how e.g. httptest.Server exposes it? - host, port := reservePort(t) - - cfg := &config.Config{ - Root: []string{rootFilepath}, - IntermediateCert: intermediateCertFilepath, - IntermediateKey: intermediateKeyFilepath, - Address: net.JoinHostPort(host, port), // reuse the address that was just "reserved" - DNSNames: []string{"127.0.0.1", "[::1]", "localhost"}, - AuthorityConfig: &config.AuthConfig{ - AuthorityID: "stepca-test", - DeploymentType: "standalone-test", - }, - Logger: json.RawMessage(`{"format": "text"}`), - } - c, err := ca.New(cfg) - require.NoError(t, err) - - // instantiate a client for the CA running at the random address - caClient, err := ca.NewClient( - fmt.Sprintf("https://localhost:%s", port), - ca.WithRootFile(rootFilepath), - ) - require.NoError(t, err) - - var wg sync.WaitGroup - wg.Add(1) - - go func() { - defer wg.Done() - err = c.Run() - require.Error(t, err) // expect error when server is stopped - }() - - // require OK health response as the baseline - ctx := context.Background() - healthResponse, err := caClient.HealthWithContext(ctx) - require.NoError(t, err) - if assert.NotNil(t, healthResponse) { - require.Equal(t, "ok", healthResponse.Status) - } - - // expect an error when retrieving an invalid root - rootResponse, err := caClient.RootWithContext(ctx, "invalid") - if assert.Error(t, err) { - apiErr := &errs.Error{} - if assert.ErrorAs(t, err, &apiErr) { - assert.Equal(t, 404, apiErr.StatusCode()) - assert.Equal(t, "The requested resource could not be found. Please see the certificate authority logs for more info.", apiErr.Err.Error()) - assert.NotEmpty(t, apiErr.RequestID) - - // TODO: include the below error in the JSON? It's currently only output to the CA logs. Also see https://github.com/smallstep/certificates/pull/759 - //assert.Equal(t, "/root/invalid was not found: certificate with fingerprint invalid was not found", apiErr.Msg) - } - } - assert.Nil(t, rootResponse) - - // expect an error when retrieving an invalid root and provided request ID - rootResponse, err = caClient.RootWithContext(client.NewRequestIDContext(ctx, "reqID"), "invalid") - if assert.Error(t, err) { - apiErr := &errs.Error{} - if assert.ErrorAs(t, err, &apiErr) { - assert.Equal(t, 404, apiErr.StatusCode()) - assert.Equal(t, "The requested resource could not be found. Please see the certificate authority logs for more info.", apiErr.Err.Error()) - assert.Equal(t, "reqID", apiErr.RequestID) - } - } - assert.Nil(t, rootResponse) - - // done testing; stop and wait for the server to quit - err = c.Stop() - require.NoError(t, err) - - wg.Wait() -} diff --git a/test/integration/requestid_test.go b/test/integration/requestid_test.go new file mode 100644 index 00000000..f15db12f --- /dev/null +++ b/test/integration/requestid_test.go @@ -0,0 +1,289 @@ +package integration + +import ( + "context" + "crypto/tls" + "crypto/x509" + "encoding/json" + "fmt" + "net" + "net/http" + "net/http/httptest" + "path/filepath" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.step.sm/crypto/jose" + "go.step.sm/crypto/keyutil" + "go.step.sm/crypto/minica" + "go.step.sm/crypto/pemutil" + "go.step.sm/crypto/randutil" + "go.step.sm/crypto/x509util" + + "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/authority/config" + "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/ca" + "github.com/smallstep/certificates/ca/client" + "github.com/smallstep/certificates/errs" +) + +// reservePort "reserves" a TCP port by opening a listener on a random +// port and immediately closing it. The port can then be assumed to be +// available for running a server on. +func reservePort(t *testing.T) (host, port string) { + t.Helper() + l, err := net.Listen("tcp", ":0") + require.NoError(t, err) + + address := l.Addr().String() + err = l.Close() + require.NoError(t, err) + + host, port, err = net.SplitHostPort(address) + require.NoError(t, err) + + return +} + +func Test_reflectRequestID(t *testing.T) { + dir := t.TempDir() + m, err := minica.New(minica.WithName("Step E2E")) + require.NoError(t, err) + + rootFilepath := filepath.Join(dir, "root.crt") + _, err = pemutil.Serialize(m.Root, pemutil.WithFilename(rootFilepath)) + require.NoError(t, err) + + intermediateCertFilepath := filepath.Join(dir, "intermediate.crt") + _, err = pemutil.Serialize(m.Intermediate, pemutil.WithFilename(intermediateCertFilepath)) + require.NoError(t, err) + + intermediateKeyFilepath := filepath.Join(dir, "intermediate.key") + _, err = pemutil.Serialize(m.Signer, pemutil.WithFilename(intermediateKeyFilepath)) + require.NoError(t, err) + + // get a random address to listen on and connect to; currently no nicer way to get one before starting the server + // TODO(hs): find/implement a nicer way to expose the CA URL, similar to how e.g. httptest.Server exposes it? + host, port := reservePort(t) + + authorizingSrv := newAuthorizingServer(t, m) + defer authorizingSrv.Close() + authorizingSrv.StartTLS() + + password := []byte("1234") + jwk, jwe, err := jose.GenerateDefaultKeyPair(password) + require.NoError(t, err) + encryptedKey, err := jwe.CompactSerialize() + require.NoError(t, err) + prov := &provisioner.JWK{ + ID: "jwk", + Name: "jwk", + Type: "JWK", + Key: jwk, + EncryptedKey: encryptedKey, + Claims: &config.GlobalProvisionerClaims, + Options: &provisioner.Options{ + Webhooks: []*provisioner.Webhook{ + { + ID: "webhook", + Name: "webhook-test", + URL: fmt.Sprintf("%s/authorize", authorizingSrv.URL), + Kind: "AUTHORIZING", + CertType: "X509", + }, + }, + }, + } + err = prov.Init(provisioner.Config{}) + require.NoError(t, err) + + cfg := &config.Config{ + Root: []string{rootFilepath}, + IntermediateCert: intermediateCertFilepath, + IntermediateKey: intermediateKeyFilepath, + Address: net.JoinHostPort(host, port), // reuse the address that was just "reserved" + DNSNames: []string{"127.0.0.1", "[::1]", "localhost"}, + AuthorityConfig: &config.AuthConfig{ + AuthorityID: "stepca-test", + DeploymentType: "standalone-test", + Provisioners: provisioner.List{prov}, + }, + Logger: json.RawMessage(`{"format": "text"}`), + } + c, err := ca.New(cfg) + require.NoError(t, err) + + // instantiate a client for the CA running at the random address + caClient, err := ca.NewClient( + fmt.Sprintf("https://localhost:%s", port), + ca.WithRootFile(rootFilepath), + ) + require.NoError(t, err) + + var wg sync.WaitGroup + wg.Add(1) + + go func() { + defer wg.Done() + err = c.Run() + require.ErrorIs(t, err, http.ErrServerClosed) + }() + + // require OK health response as the baseline + ctx := context.Background() + healthResponse, err := caClient.HealthWithContext(ctx) + require.NoError(t, err) + if assert.NotNil(t, healthResponse) { + require.Equal(t, "ok", healthResponse.Status) + } + + // expect an error when retrieving an invalid root + rootResponse, err := caClient.RootWithContext(ctx, "invalid") + var firstErr *errs.Error + if assert.ErrorAs(t, err, &firstErr) { + assert.Equal(t, 404, firstErr.StatusCode()) + assert.Equal(t, "The requested resource could not be found. Please see the certificate authority logs for more info.", firstErr.Err.Error()) + assert.NotEmpty(t, firstErr.RequestID) + + // TODO: include the below error in the JSON? It's currently only output to the CA logs. Also see https://github.com/smallstep/certificates/pull/759 + //assert.Equal(t, "/root/invalid was not found: certificate with fingerprint invalid was not found", apiErr.Msg) + } + assert.Nil(t, rootResponse) + + // expect an error when retrieving an invalid root and provided request ID + rootResponse, err = caClient.RootWithContext(client.NewRequestIDContext(ctx, "reqID"), "invalid") + var secondErr *errs.Error + if assert.ErrorAs(t, err, &secondErr) { + assert.Equal(t, 404, secondErr.StatusCode()) + assert.Equal(t, "The requested resource could not be found. Please see the certificate authority logs for more info.", secondErr.Err.Error()) + assert.Equal(t, "reqID", secondErr.RequestID) + } + assert.Nil(t, rootResponse) + + // prepare a Sign request + subject := "test" + decryptedJWK := decryptPrivateKey(t, jwe, password) + ott := generateOTT(t, decryptedJWK, subject) + + signer, err := keyutil.GenerateDefaultSigner() + require.NoError(t, err) + + csr, err := x509util.CreateCertificateRequest(subject, []string{subject}, signer) + require.NoError(t, err) + + // perform the Sign request using the OTT and CSR + signResponse, err := caClient.SignWithContext(client.NewRequestIDContext(ctx, "signRequestID"), &api.SignRequest{ + CsrPEM: api.CertificateRequest{CertificateRequest: csr}, + OTT: ott, + NotAfter: api.NewTimeDuration(time.Now().Add(1 * time.Hour)), + NotBefore: api.NewTimeDuration(time.Now().Add(-1 * time.Hour)), + }) + assert.NoError(t, err) + + // assert a certificate was returned for the subject "test" + if assert.NotNil(t, signResponse) { + assert.Len(t, signResponse.CertChainPEM, 2) + cert, err := x509.ParseCertificate(signResponse.CertChainPEM[0].Raw) + assert.NoError(t, err) + if assert.NotNil(t, cert) { + assert.Equal(t, "test", cert.Subject.CommonName) + assert.Contains(t, cert.DNSNames, "test") + } + } + + // done testing; stop and wait for the server to quit + err = c.Stop() + require.NoError(t, err) + + wg.Wait() +} + +func decryptPrivateKey(t *testing.T, jwe *jose.JSONWebEncryption, pass []byte) *jose.JSONWebKey { + t.Helper() + d, err := jwe.Decrypt(pass) + require.NoError(t, err) + + jwk := &jose.JSONWebKey{} + err = json.Unmarshal(d, jwk) + require.NoError(t, err) + + return jwk +} + +func generateOTT(t *testing.T, jwk *jose.JSONWebKey, subject string) string { + t.Helper() + now := time.Now() + + keyID, err := jose.Thumbprint(jwk) + require.NoError(t, err) + + opts := new(jose.SignerOptions).WithType("JWT").WithHeader("kid", keyID) + signer, err := jose.NewSigner(jose.SigningKey{Key: jwk.Key}, opts) + require.NoError(t, err) + + id, err := randutil.ASCII(64) + require.NoError(t, err) + + cl := struct { + jose.Claims + SANS []string `json:"sans"` + }{ + Claims: jose.Claims{ + ID: id, + Subject: subject, + Issuer: "jwk", + NotBefore: jose.NewNumericDate(now), + Expiry: jose.NewNumericDate(now.Add(time.Minute)), + Audience: []string{"https://127.0.0.1/1.0/sign"}, + }, + SANS: []string{subject}, + } + raw, err := jose.Signed(signer).Claims(cl).CompactSerialize() + require.NoError(t, err) + + return raw +} + +func newAuthorizingServer(t *testing.T, ca *minica.CA) *httptest.Server { + t.Helper() + + key, err := keyutil.GenerateDefaultSigner() + require.NoError(t, err) + + csr, err := x509util.CreateCertificateRequest("127.0.0.1", []string{"127.0.0.1"}, key) + require.NoError(t, err) + + crt, err := ca.SignCSR(csr) + require.NoError(t, err) + + srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if assert.Equal(t, "signRequestID", r.Header.Get("X-Request-Id")) { + json.NewEncoder(w).Encode(struct{ Allow bool }{Allow: true}) + w.WriteHeader(http.StatusOK) + return + } + + w.WriteHeader(http.StatusBadRequest) + })) + trustedRoots := x509.NewCertPool() + trustedRoots.AddCert(ca.Root) + + srv.TLS = &tls.Config{ + Certificates: []tls.Certificate{ + { + Certificate: [][]byte{crt.Raw, ca.Intermediate.Raw}, + PrivateKey: key, + Leaf: crt, + }, + }, + ClientCAs: trustedRoots, + ClientAuth: tls.RequireAndVerifyClientCert, + ServerName: "localhost", + } + + return srv +}