diff --git a/logging/context.go b/logging/context.go index b24b3638..ab8464d0 100644 --- a/logging/context.go +++ b/logging/context.go @@ -21,14 +21,27 @@ func NewRequestID() string { return xid.New().String() } -// RequestID returns a new middleware that gets the given header and sets it -// in the context so it can be written in the logger. If the header does not -// exists or it's the empty string, it uses github.com/rs/xid to create a new -// one. +// defaultRequestIDHeader is the header name used for propagating +// request IDs. If available in an HTTP request, it'll be used instead +// of the X-Smallstep-Id header. +const defaultRequestIDHeader = "X-Request-Id" + +// RequestID returns a new middleware that obtains the current request ID +// and sets it in the context. It first tries to read the request ID from +// the "X-Request-Id" header. If that's not set, it tries to read it from +// the provided header name. If the header does not exist or its value is +// the empty string, it uses github.com/rs/xid to create a new one. func RequestID(headerName string) func(next http.Handler) http.Handler { + if headerName == "" { + headerName = defaultTraceIDHeader + } return func(next http.Handler) http.Handler { fn := func(w http.ResponseWriter, req *http.Request) { - requestID := req.Header.Get(headerName) + requestID := req.Header.Get(defaultRequestIDHeader) + if requestID == "" { + requestID = req.Header.Get(headerName) + } + if requestID == "" { requestID = NewRequestID() req.Header.Set(headerName, requestID) diff --git a/logging/context_test.go b/logging/context_test.go new file mode 100644 index 00000000..c519539d --- /dev/null +++ b/logging/context_test.go @@ -0,0 +1,94 @@ +package logging + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func newRequest(t *testing.T) *http.Request { + r, err := http.NewRequest(http.MethodGet, "https://example.com", http.NoBody) + require.NoError(t, err) + return r +} + +func TestRequestID(t *testing.T) { + requestWithID := newRequest(t) + requestWithID.Header.Set("X-Request-Id", "reqID") + requestWithoutID := newRequest(t) + requestWithEmptyHeader := newRequest(t) + requestWithEmptyHeader.Header.Set("X-Request-Id", "") + requestWithSmallstepID := newRequest(t) + requestWithSmallstepID.Header.Set("X-Smallstep-Id", "smallstepID") + + tests := []struct { + name string + headerName string + handler http.HandlerFunc + req *http.Request + }{ + { + name: "default-request-id", + headerName: defaultTraceIDHeader, + handler: func(_ http.ResponseWriter, r *http.Request) { + assert.Empty(t, r.Header.Get("X-Smallstep-Id")) + assert.Equal(t, "reqID", r.Header.Get("X-Request-Id")) + reqID, ok := GetRequestID(r.Context()) + if assert.True(t, ok) { + assert.Equal(t, "reqID", reqID) + } + }, + req: requestWithID, + }, + { + name: "no-request-id", + headerName: "X-Request-Id", + handler: func(_ http.ResponseWriter, r *http.Request) { + assert.Empty(t, r.Header.Get("X-Smallstep-Id")) + value := r.Header.Get("X-Request-Id") + assert.NotEmpty(t, value) + reqID, ok := GetRequestID(r.Context()) + if assert.True(t, ok) { + assert.Equal(t, value, reqID) + } + }, + req: requestWithoutID, + }, + { + name: "empty-header-name", + headerName: "", + handler: func(_ http.ResponseWriter, r *http.Request) { + assert.Empty(t, r.Header.Get("X-Request-Id")) + value := r.Header.Get("X-Smallstep-Id") + assert.NotEmpty(t, value) + reqID, ok := GetRequestID(r.Context()) + if assert.True(t, ok) { + assert.Equal(t, value, reqID) + } + }, + req: requestWithEmptyHeader, + }, + { + name: "fallback-header-name", + headerName: defaultTraceIDHeader, + handler: func(_ http.ResponseWriter, r *http.Request) { + assert.Empty(t, r.Header.Get("X-Request-Id")) + assert.Equal(t, "smallstepID", r.Header.Get("X-Smallstep-Id")) + reqID, ok := GetRequestID(r.Context()) + if assert.True(t, ok) { + assert.Equal(t, "smallstepID", reqID) + } + }, + req: requestWithSmallstepID, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := RequestID(tt.headerName) + h(tt.handler).ServeHTTP(httptest.NewRecorder(), tt.req) + }) + } +}