From 532b9df0a3cbf312ef0e54aa8b350c00309e6bab Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Wed, 28 Feb 2024 13:57:37 +0100 Subject: [PATCH] Improve CA client request ID handling --- ca/client.go | 15 +++-- ca/client/requestid.go | 11 ++-- ca/client_test.go | 120 +++++++++++++++++++++++++------------ test/e2e/requestid_test.go | 2 +- 4 files changed, 95 insertions(+), 53 deletions(-) diff --git a/ca/client.go b/ca/client.go index 0c0f9907..9e245cd7 100644 --- a/ca/client.go +++ b/ca/client.go @@ -109,9 +109,8 @@ const requestIDHeader = "X-Request-Id" // empty, the context is searched for a request ID. If that's also empty, a new // request ID is generated. func enforceRequestID(r *http.Request) { - requestID := r.Header.Get(requestIDHeader) - if requestID == "" { - if reqID, ok := client.GetRequestID(r.Context()); ok && reqID != "" { + if requestID := r.Header.Get(requestIDHeader); requestID == "" { + if reqID, ok := client.RequestIDFromContext(r.Context()); ok { // TODO(hs): ensure the request ID from the context is fresh, and thus hasn't been // used before by the client (unless it's a retry for the same request)? requestID = reqID @@ -759,14 +758,14 @@ func (c *Client) Renew(tr http.RoundTripper) (*api.SignResponse, error) { func (c *Client) RenewWithContext(ctx context.Context, tr http.RoundTripper) (*api.SignResponse, error) { var retried bool u := c.endpoint.ResolveReference(&url.URL{Path: "/renew"}) - caClient := &http.Client{Transport: tr} + httpClient := &http.Client{Transport: tr} retry: req, err := http.NewRequestWithContext(ctx, "POST", u.String(), http.NoBody) if err != nil { return nil, err } req.Header.Set("Content-Type", "application/json") - resp, err := caClient.Do(req) + resp, err := httpClient.Do(req) if err != nil { return nil, clientError(err) } @@ -836,14 +835,14 @@ func (c *Client) RekeyWithContext(ctx context.Context, req *api.RekeyRequest, tr return nil, errors.Wrap(err, "error marshaling request") } u := c.endpoint.ResolveReference(&url.URL{Path: "/rekey"}) - caClient := &http.Client{Transport: tr} + httpClient := &http.Client{Transport: tr} retry: httpReq, err := http.NewRequestWithContext(ctx, "POST", u.String(), bytes.NewReader(body)) if err != nil { return nil, err } httpReq.Header.Set("Content-Type", "application/json") - resp, err := caClient.Do(httpReq) + resp, err := httpClient.Do(httpReq) if err != nil { return nil, clientError(err) } @@ -1530,7 +1529,7 @@ func readError(r *http.Response) error { defer r.Body.Close() apiErr := new(errs.Error) if err := json.NewDecoder(r.Body).Decode(apiErr); err != nil { - return err + return fmt.Errorf("failed decoding CA error response: %w", err) } apiErr.RequestID = r.Header.Get("X-Request-Id") return apiErr diff --git a/ca/client/requestid.go b/ca/client/requestid.go index de92f8c0..2bebb7e5 100644 --- a/ca/client/requestid.go +++ b/ca/client/requestid.go @@ -4,14 +4,15 @@ import "context" type requestIDKey struct{} -// WithRequestID returns a new context with the given requestID added to the +// NewRequestIDContext returns a new context with the given request ID added to the // context. -func WithRequestID(ctx context.Context, requestID string) context.Context { +func NewRequestIDContext(ctx context.Context, requestID string) context.Context { return context.WithValue(ctx, requestIDKey{}, requestID) } -// GetRequestID returns the request id from the context if it exists. -func GetRequestID(ctx context.Context) (string, bool) { +// RequestIDFromContext returns the request ID from the context if it exists. +// and is not empty. +func RequestIDFromContext(ctx context.Context) (string, bool) { v, ok := ctx.Value(requestIDKey{}).(string) - return v, ok + return v, ok && v != "" } diff --git a/ca/client_test.go b/ca/client_test.go index 6292e3ea..6fe8a135 100644 --- a/ca/client_test.go +++ b/ca/client_test.go @@ -17,16 +17,17 @@ import ( "testing" "time" - "go.step.sm/crypto/x509util" - "golang.org/x/crypto/ssh" - - "github.com/smallstep/assert" + sassert "github.com/smallstep/assert" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api/read" "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/ca/client" "github.com/smallstep/certificates/errs" + "github.com/stretchr/testify/assert" + "go.step.sm/crypto/x509util" + "golang.org/x/crypto/ssh" ) const ( @@ -196,7 +197,7 @@ func TestClient_Version(t *testing.T) { if got != nil { t.Errorf("Client.Version() = %v, want nil", got) } - assert.HasPrefix(t, tt.expectedErr.Error(), err.Error()) + sassert.HasPrefix(t, tt.expectedErr.Error(), err.Error()) default: if !reflect.DeepEqual(got, tt.response) { t.Errorf("Client.Version() = %v, want %v", got, tt.response) @@ -247,7 +248,7 @@ func TestClient_Health(t *testing.T) { if got != nil { t.Errorf("Client.Health() = %v, want nil", got) } - assert.HasPrefix(t, tt.expectedErr.Error(), err.Error()) + sassert.HasPrefix(t, tt.expectedErr.Error(), err.Error()) default: if !reflect.DeepEqual(got, tt.response) { t.Errorf("Client.Health() = %v, want %v", got, tt.response) @@ -304,7 +305,7 @@ func TestClient_Root(t *testing.T) { if got != nil { t.Errorf("Client.Root() = %v, want nil", got) } - assert.HasPrefix(t, tt.expectedErr.Error(), err.Error()) + sassert.HasPrefix(t, tt.expectedErr.Error(), err.Error()) default: if !reflect.DeepEqual(got, tt.response) { t.Errorf("Client.Root() = %v, want %v", got, tt.response) @@ -359,7 +360,7 @@ func TestClient_Sign(t *testing.T) { body := new(api.SignRequest) if err := read.JSON(req.Body, body); err != nil { e, ok := tt.response.(error) - assert.Fatal(t, ok, "response expected to be error type") + sassert.Fatal(t, ok, "response expected to be error type") render.Error(w, e) return } else if !equalJSON(t, body, tt.request) { @@ -386,7 +387,7 @@ func TestClient_Sign(t *testing.T) { if got != nil { t.Errorf("Client.Sign() = %v, want nil", got) } - assert.HasPrefix(t, tt.expectedErr.Error(), err.Error()) + sassert.HasPrefix(t, tt.expectedErr.Error(), err.Error()) default: if !reflect.DeepEqual(got, tt.response) { t.Errorf("Client.Sign() = %v, want %v", got, tt.response) @@ -431,7 +432,7 @@ func TestClient_Revoke(t *testing.T) { body := new(api.RevokeRequest) if err := read.JSON(req.Body, body); err != nil { e, ok := tt.response.(error) - assert.Fatal(t, ok, "response expected to be error type") + sassert.Fatal(t, ok, "response expected to be error type") render.Error(w, e) return } else if !equalJSON(t, body, tt.request) { @@ -458,7 +459,7 @@ func TestClient_Revoke(t *testing.T) { if got != nil { t.Errorf("Client.Revoke() = %v, want nil", got) } - assert.HasPrefix(t, err.Error(), tt.expectedErr.Error()) + sassert.HasPrefix(t, err.Error(), tt.expectedErr.Error()) default: if !reflect.DeepEqual(got, tt.response) { t.Errorf("Client.Revoke() = %v, want %v", got, tt.response) @@ -520,10 +521,10 @@ func TestClient_Renew(t *testing.T) { } var sc render.StatusCodedError - if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { - assert.Equals(t, sc.StatusCode(), tt.responseCode) + if sassert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { + sassert.Equals(t, sc.StatusCode(), tt.responseCode) } - assert.HasPrefix(t, err.Error(), tt.err.Error()) + sassert.HasPrefix(t, err.Error(), tt.err.Error()) default: if !reflect.DeepEqual(got, tt.response) { t.Errorf("Client.Renew() = %v, want %v", got, tt.response) @@ -589,10 +590,10 @@ func TestClient_RenewWithToken(t *testing.T) { } var sc render.StatusCodedError - if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { - assert.Equals(t, sc.StatusCode(), tt.responseCode) + if sassert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { + sassert.Equals(t, sc.StatusCode(), tt.responseCode) } - assert.HasPrefix(t, err.Error(), tt.err.Error()) + sassert.HasPrefix(t, err.Error(), tt.err.Error()) default: if !reflect.DeepEqual(got, tt.response) { t.Errorf("Client.RenewWithToken() = %v, want %v", got, tt.response) @@ -659,10 +660,10 @@ func TestClient_Rekey(t *testing.T) { } var sc render.StatusCodedError - if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { - assert.Equals(t, sc.StatusCode(), tt.responseCode) + if sassert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { + sassert.Equals(t, sc.StatusCode(), tt.responseCode) } - assert.HasPrefix(t, err.Error(), tt.err.Error()) + sassert.HasPrefix(t, err.Error(), tt.err.Error()) default: if !reflect.DeepEqual(got, tt.response) { t.Errorf("Client.Renew() = %v, want %v", got, tt.response) @@ -722,7 +723,7 @@ func TestClient_Provisioners(t *testing.T) { if got != nil { t.Errorf("Client.Provisioners() = %v, want nil", got) } - assert.HasPrefix(t, errs.InternalServerErrorDefaultMsg, err.Error()) + sassert.HasPrefix(t, errs.InternalServerErrorDefaultMsg, err.Error()) default: if !reflect.DeepEqual(got, tt.response) { t.Errorf("Client.Provisioners() = %v, want %v", got, tt.response) @@ -781,10 +782,10 @@ func TestClient_ProvisionerKey(t *testing.T) { } var sc render.StatusCodedError - if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { - assert.Equals(t, sc.StatusCode(), tt.responseCode) + if sassert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { + sassert.Equals(t, sc.StatusCode(), tt.responseCode) } - assert.HasPrefix(t, tt.err.Error(), err.Error()) + sassert.HasPrefix(t, tt.err.Error(), err.Error()) default: if !reflect.DeepEqual(got, tt.response) { t.Errorf("Client.ProvisionerKey() = %v, want %v", got, tt.response) @@ -841,10 +842,10 @@ func TestClient_Roots(t *testing.T) { t.Errorf("Client.Roots() = %v, want nil", got) } var sc render.StatusCodedError - if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { - assert.Equals(t, sc.StatusCode(), tt.responseCode) + if sassert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { + sassert.Equals(t, sc.StatusCode(), tt.responseCode) } - assert.HasPrefix(t, err.Error(), tt.err.Error()) + sassert.HasPrefix(t, err.Error(), tt.err.Error()) default: if !reflect.DeepEqual(got, tt.response) { t.Errorf("Client.Roots() = %v, want %v", got, tt.response) @@ -900,10 +901,10 @@ func TestClient_Federation(t *testing.T) { t.Errorf("Client.Federation() = %v, want nil", got) } var sc render.StatusCodedError - if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { - assert.Equals(t, sc.StatusCode(), tt.responseCode) + if sassert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { + sassert.Equals(t, sc.StatusCode(), tt.responseCode) } - assert.HasPrefix(t, tt.err.Error(), err.Error()) + sassert.HasPrefix(t, tt.err.Error(), err.Error()) default: if !reflect.DeepEqual(got, tt.response) { t.Errorf("Client.Federation() = %v, want %v", got, tt.response) @@ -963,10 +964,10 @@ func TestClient_SSHRoots(t *testing.T) { t.Errorf("Client.SSHKeys() = %v, want nil", got) } var sc render.StatusCodedError - if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { - assert.Equals(t, sc.StatusCode(), tt.responseCode) + if sassert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { + sassert.Equals(t, sc.StatusCode(), tt.responseCode) } - assert.HasPrefix(t, tt.err.Error(), err.Error()) + sassert.HasPrefix(t, tt.err.Error(), err.Error()) default: if !reflect.DeepEqual(got, tt.response) { t.Errorf("Client.SSHKeys() = %v, want %v", got, tt.response) @@ -1069,11 +1070,11 @@ func TestClient_RootFingerprintWithServer(t *testing.T) { defer srv.Close() client, err := NewClient(srv.URL+"/sign", WithRootFile("testdata/secrets/root_ca.crt")) - assert.FatalError(t, err) + sassert.FatalError(t, err) fp, err := client.RootFingerprint() - assert.FatalError(t, err) - assert.Equals(t, "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7", fp) + sassert.FatalError(t, err) + sassert.Equals(t, "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7", fp) } func TestClient_SSHBastion(t *testing.T) { @@ -1126,10 +1127,10 @@ func TestClient_SSHBastion(t *testing.T) { } if tt.responseCode != 200 { var sc render.StatusCodedError - if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { - assert.Equals(t, sc.StatusCode(), tt.responseCode) + if sassert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { + sassert.Equals(t, sc.StatusCode(), tt.responseCode) } - assert.HasPrefix(t, err.Error(), tt.err.Error()) + sassert.HasPrefix(t, err.Error(), tt.err.Error()) } default: if !reflect.DeepEqual(got, tt.response) { @@ -1164,3 +1165,44 @@ func TestClient_GetCaURL(t *testing.T) { }) } } + +func Test_enforceRequestID(t *testing.T) { + set := httptest.NewRequest(http.MethodGet, "https://example.com", http.NoBody) + set.Header.Set("X-Request-Id", "already-set") + inContext := httptest.NewRequest(http.MethodGet, "https://example.com", http.NoBody) + inContext = inContext.WithContext(client.NewRequestIDContext(inContext.Context(), "from-context")) + new := httptest.NewRequest(http.MethodGet, "https://example.com", http.NoBody) + + tests := []struct { + name string + r *http.Request + want string + }{ + { + name: "set", + r: set, + want: "already-set", + }, + { + name: "context", + r: inContext, + want: "from-context", + }, + { + name: "new", + r: new, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + enforceRequestID(tt.r) + + v := tt.r.Header.Get("X-Request-Id") + if assert.NotEmpty(t, v) { + if tt.want != "" { + assert.Equal(t, tt.want, v) + } + } + }) + } +} diff --git a/test/e2e/requestid_test.go b/test/e2e/requestid_test.go index 62b2feb1..d2f968c3 100644 --- a/test/e2e/requestid_test.go +++ b/test/e2e/requestid_test.go @@ -113,7 +113,7 @@ func Test_reflectRequestID(t *testing.T) { assert.Nil(t, rootResponse) // expect an error when retrieving an invalid root and provided request ID - rootResponse, err = caClient.RootWithContext(client.WithRequestID(ctx, "reqID"), "invalid") + rootResponse, err = caClient.RootWithContext(client.NewRequestIDContext(ctx, "reqID"), "invalid") if assert.Error(t, err) { apiErr := &errs.Error{} if assert.ErrorAs(t, err, &apiErr) {