Add reflection of request ID in `X-Request-Id` response header

pull/1743/head
Herman Slatman 3 months ago
parent c1c2e73475
commit a58f5956e3
No known key found for this signature in database
GPG Key ID: F4D8A44EA0A75A4F

@ -21,10 +21,10 @@ func NewRequestID() string {
return xid.New().String()
}
// defaultRequestIDHeader is the header name used for propagating
// request IDs. If available in an HTTP request, it'll be used instead
// of the X-Smallstep-Id header.
const defaultRequestIDHeader = "X-Request-Id"
// requestIDHeader is the header name used for propagating request IDs. If
// available in an HTTP request, it'll be used instead of the X-Smallstep-Id
// header. It'll always be used in response and set to the request ID.
const requestIDHeader = "X-Request-Id"
// RequestID returns a new middleware that obtains the current request ID
// and sets it in the context. It first tries to read the request ID from
@ -37,7 +37,7 @@ func RequestID(headerName string) func(next http.Handler) http.Handler {
}
return func(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, req *http.Request) {
requestID := req.Header.Get(defaultRequestIDHeader)
requestID := req.Header.Get(requestIDHeader)
if requestID == "" {
requestID = req.Header.Get(headerName)
}
@ -47,6 +47,10 @@ func RequestID(headerName string) func(next http.Handler) http.Handler {
req.Header.Set(headerName, requestID)
}
// immediately set the request ID to be reflected in the response
w.Header().Set(requestIDHeader, requestID)
// continue down the handler chain
ctx := WithRequestID(req.Context(), requestID)
next.ServeHTTP(w, req.WithContext(ctx))
}

@ -33,20 +33,21 @@ func TestRequestID(t *testing.T) {
{
name: "default-request-id",
headerName: defaultTraceIDHeader,
handler: func(_ http.ResponseWriter, r *http.Request) {
handler: func(w http.ResponseWriter, r *http.Request) {
assert.Empty(t, r.Header.Get("X-Smallstep-Id"))
assert.Equal(t, "reqID", r.Header.Get("X-Request-Id"))
reqID, ok := GetRequestID(r.Context())
if assert.True(t, ok) {
assert.Equal(t, "reqID", reqID)
}
assert.Equal(t, "reqID", w.Header().Get("X-Request-Id"))
},
req: requestWithID,
},
{
name: "no-request-id",
headerName: "X-Request-Id",
handler: func(_ http.ResponseWriter, r *http.Request) {
handler: func(w http.ResponseWriter, r *http.Request) {
assert.Empty(t, r.Header.Get("X-Smallstep-Id"))
value := r.Header.Get("X-Request-Id")
assert.NotEmpty(t, value)
@ -54,13 +55,14 @@ func TestRequestID(t *testing.T) {
if assert.True(t, ok) {
assert.Equal(t, value, reqID)
}
assert.Equal(t, value, w.Header().Get("X-Request-Id"))
},
req: requestWithoutID,
},
{
name: "empty-header-name",
headerName: "",
handler: func(_ http.ResponseWriter, r *http.Request) {
handler: func(w http.ResponseWriter, r *http.Request) {
assert.Empty(t, r.Header.Get("X-Request-Id"))
value := r.Header.Get("X-Smallstep-Id")
assert.NotEmpty(t, value)
@ -68,19 +70,21 @@ func TestRequestID(t *testing.T) {
if assert.True(t, ok) {
assert.Equal(t, value, reqID)
}
assert.Equal(t, value, w.Header().Get("X-Request-Id"))
},
req: requestWithEmptyHeader,
},
{
name: "fallback-header-name",
headerName: defaultTraceIDHeader,
handler: func(_ http.ResponseWriter, r *http.Request) {
handler: func(w http.ResponseWriter, r *http.Request) {
assert.Empty(t, r.Header.Get("X-Request-Id"))
assert.Equal(t, "smallstepID", r.Header.Get("X-Smallstep-Id"))
reqID, ok := GetRequestID(r.Context())
if assert.True(t, ok) {
assert.Equal(t, "smallstepID", reqID)
}
assert.Equal(t, "smallstepID", w.Header().Get("X-Request-Id"))
},
req: requestWithSmallstepID,
},

Loading…
Cancel
Save