diff --git a/logging/context.go b/logging/context.go index ab8464d0..9d7a7071 100644 --- a/logging/context.go +++ b/logging/context.go @@ -21,10 +21,10 @@ func NewRequestID() string { return xid.New().String() } -// defaultRequestIDHeader is the header name used for propagating -// request IDs. If available in an HTTP request, it'll be used instead -// of the X-Smallstep-Id header. -const defaultRequestIDHeader = "X-Request-Id" +// requestIDHeader is the header name used for propagating request IDs. If +// available in an HTTP request, it'll be used instead of the X-Smallstep-Id +// header. It'll always be used in response and set to the request ID. +const requestIDHeader = "X-Request-Id" // RequestID returns a new middleware that obtains the current request ID // and sets it in the context. It first tries to read the request ID from @@ -37,7 +37,7 @@ func RequestID(headerName string) func(next http.Handler) http.Handler { } return func(next http.Handler) http.Handler { fn := func(w http.ResponseWriter, req *http.Request) { - requestID := req.Header.Get(defaultRequestIDHeader) + requestID := req.Header.Get(requestIDHeader) if requestID == "" { requestID = req.Header.Get(headerName) } @@ -47,6 +47,10 @@ func RequestID(headerName string) func(next http.Handler) http.Handler { req.Header.Set(headerName, requestID) } + // immediately set the request ID to be reflected in the response + w.Header().Set(requestIDHeader, requestID) + + // continue down the handler chain ctx := WithRequestID(req.Context(), requestID) next.ServeHTTP(w, req.WithContext(ctx)) } diff --git a/logging/context_test.go b/logging/context_test.go index c519539d..da993f7b 100644 --- a/logging/context_test.go +++ b/logging/context_test.go @@ -33,20 +33,21 @@ func TestRequestID(t *testing.T) { { name: "default-request-id", headerName: defaultTraceIDHeader, - handler: func(_ http.ResponseWriter, r *http.Request) { + handler: func(w http.ResponseWriter, r *http.Request) { assert.Empty(t, r.Header.Get("X-Smallstep-Id")) assert.Equal(t, "reqID", r.Header.Get("X-Request-Id")) reqID, ok := GetRequestID(r.Context()) if assert.True(t, ok) { assert.Equal(t, "reqID", reqID) } + assert.Equal(t, "reqID", w.Header().Get("X-Request-Id")) }, req: requestWithID, }, { name: "no-request-id", headerName: "X-Request-Id", - handler: func(_ http.ResponseWriter, r *http.Request) { + handler: func(w http.ResponseWriter, r *http.Request) { assert.Empty(t, r.Header.Get("X-Smallstep-Id")) value := r.Header.Get("X-Request-Id") assert.NotEmpty(t, value) @@ -54,13 +55,14 @@ func TestRequestID(t *testing.T) { if assert.True(t, ok) { assert.Equal(t, value, reqID) } + assert.Equal(t, value, w.Header().Get("X-Request-Id")) }, req: requestWithoutID, }, { name: "empty-header-name", headerName: "", - handler: func(_ http.ResponseWriter, r *http.Request) { + handler: func(w http.ResponseWriter, r *http.Request) { assert.Empty(t, r.Header.Get("X-Request-Id")) value := r.Header.Get("X-Smallstep-Id") assert.NotEmpty(t, value) @@ -68,19 +70,21 @@ func TestRequestID(t *testing.T) { if assert.True(t, ok) { assert.Equal(t, value, reqID) } + assert.Equal(t, value, w.Header().Get("X-Request-Id")) }, req: requestWithEmptyHeader, }, { name: "fallback-header-name", headerName: defaultTraceIDHeader, - handler: func(_ http.ResponseWriter, r *http.Request) { + handler: func(w http.ResponseWriter, r *http.Request) { assert.Empty(t, r.Header.Get("X-Request-Id")) assert.Equal(t, "smallstepID", r.Header.Get("X-Smallstep-Id")) reqID, ok := GetRequestID(r.Context()) if assert.True(t, ok) { assert.Equal(t, "smallstepID", reqID) } + assert.Equal(t, "smallstepID", w.Header().Get("X-Request-Id")) }, req: requestWithSmallstepID, },