Improve CA client request ID handling

pull/1743/head
Herman Slatman 3 months ago
parent 06696e6492
commit 532b9df0a3
No known key found for this signature in database
GPG Key ID: F4D8A44EA0A75A4F

@ -109,9 +109,8 @@ const requestIDHeader = "X-Request-Id"
// empty, the context is searched for a request ID. If that's also empty, a new // empty, the context is searched for a request ID. If that's also empty, a new
// request ID is generated. // request ID is generated.
func enforceRequestID(r *http.Request) { func enforceRequestID(r *http.Request) {
requestID := r.Header.Get(requestIDHeader) if requestID := r.Header.Get(requestIDHeader); requestID == "" {
if requestID == "" { if reqID, ok := client.RequestIDFromContext(r.Context()); ok {
if reqID, ok := client.GetRequestID(r.Context()); ok && reqID != "" {
// TODO(hs): ensure the request ID from the context is fresh, and thus hasn't been // TODO(hs): ensure the request ID from the context is fresh, and thus hasn't been
// used before by the client (unless it's a retry for the same request)? // used before by the client (unless it's a retry for the same request)?
requestID = reqID requestID = reqID
@ -759,14 +758,14 @@ func (c *Client) Renew(tr http.RoundTripper) (*api.SignResponse, error) {
func (c *Client) RenewWithContext(ctx context.Context, tr http.RoundTripper) (*api.SignResponse, error) { func (c *Client) RenewWithContext(ctx context.Context, tr http.RoundTripper) (*api.SignResponse, error) {
var retried bool var retried bool
u := c.endpoint.ResolveReference(&url.URL{Path: "/renew"}) u := c.endpoint.ResolveReference(&url.URL{Path: "/renew"})
caClient := &http.Client{Transport: tr} httpClient := &http.Client{Transport: tr}
retry: retry:
req, err := http.NewRequestWithContext(ctx, "POST", u.String(), http.NoBody) req, err := http.NewRequestWithContext(ctx, "POST", u.String(), http.NoBody)
if err != nil { if err != nil {
return nil, err return nil, err
} }
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
resp, err := caClient.Do(req) resp, err := httpClient.Do(req)
if err != nil { if err != nil {
return nil, clientError(err) return nil, clientError(err)
} }
@ -836,14 +835,14 @@ func (c *Client) RekeyWithContext(ctx context.Context, req *api.RekeyRequest, tr
return nil, errors.Wrap(err, "error marshaling request") return nil, errors.Wrap(err, "error marshaling request")
} }
u := c.endpoint.ResolveReference(&url.URL{Path: "/rekey"}) u := c.endpoint.ResolveReference(&url.URL{Path: "/rekey"})
caClient := &http.Client{Transport: tr} httpClient := &http.Client{Transport: tr}
retry: retry:
httpReq, err := http.NewRequestWithContext(ctx, "POST", u.String(), bytes.NewReader(body)) httpReq, err := http.NewRequestWithContext(ctx, "POST", u.String(), bytes.NewReader(body))
if err != nil { if err != nil {
return nil, err return nil, err
} }
httpReq.Header.Set("Content-Type", "application/json") httpReq.Header.Set("Content-Type", "application/json")
resp, err := caClient.Do(httpReq) resp, err := httpClient.Do(httpReq)
if err != nil { if err != nil {
return nil, clientError(err) return nil, clientError(err)
} }
@ -1530,7 +1529,7 @@ func readError(r *http.Response) error {
defer r.Body.Close() defer r.Body.Close()
apiErr := new(errs.Error) apiErr := new(errs.Error)
if err := json.NewDecoder(r.Body).Decode(apiErr); err != nil { if err := json.NewDecoder(r.Body).Decode(apiErr); err != nil {
return err return fmt.Errorf("failed decoding CA error response: %w", err)
} }
apiErr.RequestID = r.Header.Get("X-Request-Id") apiErr.RequestID = r.Header.Get("X-Request-Id")
return apiErr return apiErr

@ -4,14 +4,15 @@ import "context"
type requestIDKey struct{} type requestIDKey struct{}
// WithRequestID returns a new context with the given requestID added to the // NewRequestIDContext returns a new context with the given request ID added to the
// context. // context.
func WithRequestID(ctx context.Context, requestID string) context.Context { func NewRequestIDContext(ctx context.Context, requestID string) context.Context {
return context.WithValue(ctx, requestIDKey{}, requestID) return context.WithValue(ctx, requestIDKey{}, requestID)
} }
// GetRequestID returns the request id from the context if it exists. // RequestIDFromContext returns the request ID from the context if it exists.
func GetRequestID(ctx context.Context) (string, bool) { // and is not empty.
func RequestIDFromContext(ctx context.Context) (string, bool) {
v, ok := ctx.Value(requestIDKey{}).(string) v, ok := ctx.Value(requestIDKey{}).(string)
return v, ok return v, ok && v != ""
} }

@ -17,16 +17,17 @@ import (
"testing" "testing"
"time" "time"
"go.step.sm/crypto/x509util" sassert "github.com/smallstep/assert"
"golang.org/x/crypto/ssh"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/api/read" "github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/ca/client"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
"github.com/stretchr/testify/assert"
"go.step.sm/crypto/x509util"
"golang.org/x/crypto/ssh"
) )
const ( const (
@ -196,7 +197,7 @@ func TestClient_Version(t *testing.T) {
if got != nil { if got != nil {
t.Errorf("Client.Version() = %v, want nil", got) t.Errorf("Client.Version() = %v, want nil", got)
} }
assert.HasPrefix(t, tt.expectedErr.Error(), err.Error()) sassert.HasPrefix(t, tt.expectedErr.Error(), err.Error())
default: default:
if !reflect.DeepEqual(got, tt.response) { if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Version() = %v, want %v", got, tt.response) t.Errorf("Client.Version() = %v, want %v", got, tt.response)
@ -247,7 +248,7 @@ func TestClient_Health(t *testing.T) {
if got != nil { if got != nil {
t.Errorf("Client.Health() = %v, want nil", got) t.Errorf("Client.Health() = %v, want nil", got)
} }
assert.HasPrefix(t, tt.expectedErr.Error(), err.Error()) sassert.HasPrefix(t, tt.expectedErr.Error(), err.Error())
default: default:
if !reflect.DeepEqual(got, tt.response) { if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Health() = %v, want %v", got, tt.response) t.Errorf("Client.Health() = %v, want %v", got, tt.response)
@ -304,7 +305,7 @@ func TestClient_Root(t *testing.T) {
if got != nil { if got != nil {
t.Errorf("Client.Root() = %v, want nil", got) t.Errorf("Client.Root() = %v, want nil", got)
} }
assert.HasPrefix(t, tt.expectedErr.Error(), err.Error()) sassert.HasPrefix(t, tt.expectedErr.Error(), err.Error())
default: default:
if !reflect.DeepEqual(got, tt.response) { if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Root() = %v, want %v", got, tt.response) t.Errorf("Client.Root() = %v, want %v", got, tt.response)
@ -359,7 +360,7 @@ func TestClient_Sign(t *testing.T) {
body := new(api.SignRequest) body := new(api.SignRequest)
if err := read.JSON(req.Body, body); err != nil { if err := read.JSON(req.Body, body); err != nil {
e, ok := tt.response.(error) e, ok := tt.response.(error)
assert.Fatal(t, ok, "response expected to be error type") sassert.Fatal(t, ok, "response expected to be error type")
render.Error(w, e) render.Error(w, e)
return return
} else if !equalJSON(t, body, tt.request) { } else if !equalJSON(t, body, tt.request) {
@ -386,7 +387,7 @@ func TestClient_Sign(t *testing.T) {
if got != nil { if got != nil {
t.Errorf("Client.Sign() = %v, want nil", got) t.Errorf("Client.Sign() = %v, want nil", got)
} }
assert.HasPrefix(t, tt.expectedErr.Error(), err.Error()) sassert.HasPrefix(t, tt.expectedErr.Error(), err.Error())
default: default:
if !reflect.DeepEqual(got, tt.response) { if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Sign() = %v, want %v", got, tt.response) t.Errorf("Client.Sign() = %v, want %v", got, tt.response)
@ -431,7 +432,7 @@ func TestClient_Revoke(t *testing.T) {
body := new(api.RevokeRequest) body := new(api.RevokeRequest)
if err := read.JSON(req.Body, body); err != nil { if err := read.JSON(req.Body, body); err != nil {
e, ok := tt.response.(error) e, ok := tt.response.(error)
assert.Fatal(t, ok, "response expected to be error type") sassert.Fatal(t, ok, "response expected to be error type")
render.Error(w, e) render.Error(w, e)
return return
} else if !equalJSON(t, body, tt.request) { } else if !equalJSON(t, body, tt.request) {
@ -458,7 +459,7 @@ func TestClient_Revoke(t *testing.T) {
if got != nil { if got != nil {
t.Errorf("Client.Revoke() = %v, want nil", got) t.Errorf("Client.Revoke() = %v, want nil", got)
} }
assert.HasPrefix(t, err.Error(), tt.expectedErr.Error()) sassert.HasPrefix(t, err.Error(), tt.expectedErr.Error())
default: default:
if !reflect.DeepEqual(got, tt.response) { if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Revoke() = %v, want %v", got, tt.response) t.Errorf("Client.Revoke() = %v, want %v", got, tt.response)
@ -520,10 +521,10 @@ func TestClient_Renew(t *testing.T) {
} }
var sc render.StatusCodedError var sc render.StatusCodedError
if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { if sassert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
assert.Equals(t, sc.StatusCode(), tt.responseCode) sassert.Equals(t, sc.StatusCode(), tt.responseCode)
} }
assert.HasPrefix(t, err.Error(), tt.err.Error()) sassert.HasPrefix(t, err.Error(), tt.err.Error())
default: default:
if !reflect.DeepEqual(got, tt.response) { if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Renew() = %v, want %v", got, tt.response) t.Errorf("Client.Renew() = %v, want %v", got, tt.response)
@ -589,10 +590,10 @@ func TestClient_RenewWithToken(t *testing.T) {
} }
var sc render.StatusCodedError var sc render.StatusCodedError
if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { if sassert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
assert.Equals(t, sc.StatusCode(), tt.responseCode) sassert.Equals(t, sc.StatusCode(), tt.responseCode)
} }
assert.HasPrefix(t, err.Error(), tt.err.Error()) sassert.HasPrefix(t, err.Error(), tt.err.Error())
default: default:
if !reflect.DeepEqual(got, tt.response) { if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.RenewWithToken() = %v, want %v", got, tt.response) t.Errorf("Client.RenewWithToken() = %v, want %v", got, tt.response)
@ -659,10 +660,10 @@ func TestClient_Rekey(t *testing.T) {
} }
var sc render.StatusCodedError var sc render.StatusCodedError
if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { if sassert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
assert.Equals(t, sc.StatusCode(), tt.responseCode) sassert.Equals(t, sc.StatusCode(), tt.responseCode)
} }
assert.HasPrefix(t, err.Error(), tt.err.Error()) sassert.HasPrefix(t, err.Error(), tt.err.Error())
default: default:
if !reflect.DeepEqual(got, tt.response) { if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Renew() = %v, want %v", got, tt.response) t.Errorf("Client.Renew() = %v, want %v", got, tt.response)
@ -722,7 +723,7 @@ func TestClient_Provisioners(t *testing.T) {
if got != nil { if got != nil {
t.Errorf("Client.Provisioners() = %v, want nil", got) t.Errorf("Client.Provisioners() = %v, want nil", got)
} }
assert.HasPrefix(t, errs.InternalServerErrorDefaultMsg, err.Error()) sassert.HasPrefix(t, errs.InternalServerErrorDefaultMsg, err.Error())
default: default:
if !reflect.DeepEqual(got, tt.response) { if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Provisioners() = %v, want %v", got, tt.response) t.Errorf("Client.Provisioners() = %v, want %v", got, tt.response)
@ -781,10 +782,10 @@ func TestClient_ProvisionerKey(t *testing.T) {
} }
var sc render.StatusCodedError var sc render.StatusCodedError
if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { if sassert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
assert.Equals(t, sc.StatusCode(), tt.responseCode) sassert.Equals(t, sc.StatusCode(), tt.responseCode)
} }
assert.HasPrefix(t, tt.err.Error(), err.Error()) sassert.HasPrefix(t, tt.err.Error(), err.Error())
default: default:
if !reflect.DeepEqual(got, tt.response) { if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.ProvisionerKey() = %v, want %v", got, tt.response) t.Errorf("Client.ProvisionerKey() = %v, want %v", got, tt.response)
@ -841,10 +842,10 @@ func TestClient_Roots(t *testing.T) {
t.Errorf("Client.Roots() = %v, want nil", got) t.Errorf("Client.Roots() = %v, want nil", got)
} }
var sc render.StatusCodedError var sc render.StatusCodedError
if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { if sassert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
assert.Equals(t, sc.StatusCode(), tt.responseCode) sassert.Equals(t, sc.StatusCode(), tt.responseCode)
} }
assert.HasPrefix(t, err.Error(), tt.err.Error()) sassert.HasPrefix(t, err.Error(), tt.err.Error())
default: default:
if !reflect.DeepEqual(got, tt.response) { if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Roots() = %v, want %v", got, tt.response) t.Errorf("Client.Roots() = %v, want %v", got, tt.response)
@ -900,10 +901,10 @@ func TestClient_Federation(t *testing.T) {
t.Errorf("Client.Federation() = %v, want nil", got) t.Errorf("Client.Federation() = %v, want nil", got)
} }
var sc render.StatusCodedError var sc render.StatusCodedError
if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { if sassert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
assert.Equals(t, sc.StatusCode(), tt.responseCode) sassert.Equals(t, sc.StatusCode(), tt.responseCode)
} }
assert.HasPrefix(t, tt.err.Error(), err.Error()) sassert.HasPrefix(t, tt.err.Error(), err.Error())
default: default:
if !reflect.DeepEqual(got, tt.response) { if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Federation() = %v, want %v", got, tt.response) t.Errorf("Client.Federation() = %v, want %v", got, tt.response)
@ -963,10 +964,10 @@ func TestClient_SSHRoots(t *testing.T) {
t.Errorf("Client.SSHKeys() = %v, want nil", got) t.Errorf("Client.SSHKeys() = %v, want nil", got)
} }
var sc render.StatusCodedError var sc render.StatusCodedError
if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { if sassert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
assert.Equals(t, sc.StatusCode(), tt.responseCode) sassert.Equals(t, sc.StatusCode(), tt.responseCode)
} }
assert.HasPrefix(t, tt.err.Error(), err.Error()) sassert.HasPrefix(t, tt.err.Error(), err.Error())
default: default:
if !reflect.DeepEqual(got, tt.response) { if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.SSHKeys() = %v, want %v", got, tt.response) t.Errorf("Client.SSHKeys() = %v, want %v", got, tt.response)
@ -1069,11 +1070,11 @@ func TestClient_RootFingerprintWithServer(t *testing.T) {
defer srv.Close() defer srv.Close()
client, err := NewClient(srv.URL+"/sign", WithRootFile("testdata/secrets/root_ca.crt")) client, err := NewClient(srv.URL+"/sign", WithRootFile("testdata/secrets/root_ca.crt"))
assert.FatalError(t, err) sassert.FatalError(t, err)
fp, err := client.RootFingerprint() fp, err := client.RootFingerprint()
assert.FatalError(t, err) sassert.FatalError(t, err)
assert.Equals(t, "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7", fp) sassert.Equals(t, "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7", fp)
} }
func TestClient_SSHBastion(t *testing.T) { func TestClient_SSHBastion(t *testing.T) {
@ -1126,10 +1127,10 @@ func TestClient_SSHBastion(t *testing.T) {
} }
if tt.responseCode != 200 { if tt.responseCode != 200 {
var sc render.StatusCodedError var sc render.StatusCodedError
if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { if sassert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
assert.Equals(t, sc.StatusCode(), tt.responseCode) sassert.Equals(t, sc.StatusCode(), tt.responseCode)
} }
assert.HasPrefix(t, err.Error(), tt.err.Error()) sassert.HasPrefix(t, err.Error(), tt.err.Error())
} }
default: default:
if !reflect.DeepEqual(got, tt.response) { if !reflect.DeepEqual(got, tt.response) {
@ -1164,3 +1165,44 @@ func TestClient_GetCaURL(t *testing.T) {
}) })
} }
} }
func Test_enforceRequestID(t *testing.T) {
set := httptest.NewRequest(http.MethodGet, "https://example.com", http.NoBody)
set.Header.Set("X-Request-Id", "already-set")
inContext := httptest.NewRequest(http.MethodGet, "https://example.com", http.NoBody)
inContext = inContext.WithContext(client.NewRequestIDContext(inContext.Context(), "from-context"))
new := httptest.NewRequest(http.MethodGet, "https://example.com", http.NoBody)
tests := []struct {
name string
r *http.Request
want string
}{
{
name: "set",
r: set,
want: "already-set",
},
{
name: "context",
r: inContext,
want: "from-context",
},
{
name: "new",
r: new,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
enforceRequestID(tt.r)
v := tt.r.Header.Get("X-Request-Id")
if assert.NotEmpty(t, v) {
if tt.want != "" {
assert.Equal(t, tt.want, v)
}
}
})
}
}

@ -113,7 +113,7 @@ func Test_reflectRequestID(t *testing.T) {
assert.Nil(t, rootResponse) assert.Nil(t, rootResponse)
// expect an error when retrieving an invalid root and provided request ID // expect an error when retrieving an invalid root and provided request ID
rootResponse, err = caClient.RootWithContext(client.WithRequestID(ctx, "reqID"), "invalid") rootResponse, err = caClient.RootWithContext(client.NewRequestIDContext(ctx, "reqID"), "invalid")
if assert.Error(t, err) { if assert.Error(t, err) {
apiErr := &errs.Error{} apiErr := &errs.Error{}
if assert.ErrorAs(t, err, &apiErr) { if assert.ErrorAs(t, err, &apiErr) {

Loading…
Cancel
Save