diff --git a/authority/provisioner/webhook.go b/authority/provisioner/webhook.go index c33dfa23..1e08b8b7 100644 --- a/authority/provisioner/webhook.go +++ b/authority/provisioner/webhook.go @@ -15,7 +15,7 @@ import ( "time" "github.com/pkg/errors" - "github.com/smallstep/certificates/logging" + "github.com/smallstep/certificates/internal/requestid" "github.com/smallstep/certificates/templates" "github.com/smallstep/certificates/webhook" "go.step.sm/linkedca" @@ -171,9 +171,8 @@ retry: return nil, err } - requestID, ok := logging.GetRequestID(ctx) - if ok { - req.Header.Set("X-Request-ID", requestID) + if requestID, ok := requestid.FromContext(ctx); ok { + req.Header.Set("X-Request-Id", requestID) } secret, err := base64.StdEncoding.DecodeString(w.Secret) diff --git a/authority/provisioner/webhook_test.go b/authority/provisioner/webhook_test.go index 60dcdbc7..4c80796f 100644 --- a/authority/provisioner/webhook_test.go +++ b/authority/provisioner/webhook_test.go @@ -17,7 +17,7 @@ import ( "testing" "time" - "github.com/smallstep/certificates/logging" + "github.com/smallstep/certificates/internal/requestid" "github.com/smallstep/certificates/webhook" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -101,10 +101,10 @@ func TestWebhookController_isCertTypeOK(t *testing.T) { } } -// withRequestID is a helper that calls into [logging.WithRequestID] and returns -// a new context with the requestID added to the provided context. +// withRequestID is a helper that calls into [requestid.NewContext] and returns +// a new context with the requestID added. func withRequestID(ctx context.Context, requestID string) context.Context { - return logging.WithRequestID(ctx, requestID) + return requestid.NewContext(ctx, requestID) } func TestWebhookController_Enrich(t *testing.T) { diff --git a/ca/ca.go b/ca/ca.go index 4146466d..ab4a1a9b 100644 --- a/ca/ca.go +++ b/ca/ca.go @@ -29,6 +29,7 @@ import ( "github.com/smallstep/certificates/cas/apiv1" "github.com/smallstep/certificates/db" "github.com/smallstep/certificates/internal/metrix" + "github.com/smallstep/certificates/internal/requestid" "github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/monitoring" "github.com/smallstep/certificates/scep" @@ -329,15 +330,21 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) { } // Add logger if configured + var legacyTraceHeader string if len(cfg.Logger) > 0 { logger, err := logging.New("ca", cfg.Logger) if err != nil { return nil, err } + legacyTraceHeader = logger.GetTraceHeader() handler = logger.Middleware(handler) insecureHandler = logger.Middleware(insecureHandler) } + // always use request ID middleware; traceHeader is provided for backwards compatibility (for now) + handler = requestid.New(legacyTraceHeader).Middleware(handler) + insecureHandler = requestid.New(legacyTraceHeader).Middleware(insecureHandler) + // Create context with all the necessary values. baseContext := buildContext(auth, scepAuthority, acmeDB, acmeLinker) diff --git a/errs/errors_test.go b/errs/errors_test.go index 7b83c8d9..11590d7d 100644 --- a/errs/errors_test.go +++ b/errs/errors_test.go @@ -2,8 +2,9 @@ package errs import ( "fmt" - "reflect" "testing" + + "github.com/stretchr/testify/assert" ) func TestError_MarshalJSON(t *testing.T) { @@ -27,13 +28,14 @@ func TestError_MarshalJSON(t *testing.T) { Err: tt.fields.Err, } got, err := e.MarshalJSON() - if (err != nil) != tt.wantErr { - t.Errorf("Error.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + assert.Error(t, err) + assert.Empty(t, got) return } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("Error.MarshalJSON() = %s, want %s", got, tt.want) - } + + assert.NoError(t, err) + assert.Equal(t, tt.want, got) }) } } @@ -54,13 +56,14 @@ func TestError_UnmarshalJSON(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { e := new(Error) - if err := e.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr { - t.Errorf("Error.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) - } - //nolint:govet // best option - if !reflect.DeepEqual(tt.expected, e) { - t.Errorf("Error.UnmarshalJSON() wants = %+v, got %+v", tt.expected, e) + err := e.UnmarshalJSON(tt.args.data) + if tt.wantErr { + assert.Error(t, err) + return } + + assert.NoError(t, err) + assert.Equal(t, tt.expected, e) }) } } diff --git a/internal/requestid/requestid.go b/internal/requestid/requestid.go new file mode 100644 index 00000000..97f58f8c --- /dev/null +++ b/internal/requestid/requestid.go @@ -0,0 +1,82 @@ +package requestid + +import ( + "context" + "net/http" + + "github.com/rs/xid" +) + +const ( + // requestIDHeader is the header name used for propagating request IDs. If + // available in an HTTP request, it'll be used instead of the X-Smallstep-Id + // header. It'll always be used in response and set to the request ID. + requestIDHeader = "X-Request-Id" + + // defaultTraceHeader is the default Smallstep tracing header that's currently + // in use. It is used as a fallback to retrieve a request ID from, if the + // "X-Request-Id" request header is not set. + defaultTraceHeader = "X-Smallstep-Id" +) + +type Handler struct { + legacyTraceHeader string +} + +// New creates a new request ID [handler]. It takes a trace header, +// which is used keep the legacy behavior intact, which relies on the +// X-Smallstep-Id header instead of X-Request-Id. +func New(legacyTraceHeader string) *Handler { + if legacyTraceHeader == "" { + legacyTraceHeader = defaultTraceHeader + } + + return &Handler{legacyTraceHeader: legacyTraceHeader} +} + +// Middleware wraps an [http.Handler] with request ID extraction +// from the X-Reqeust-Id header by default, or from the X-Smallstep-Id +// header if not set. If both are not set, a new request ID is generated. +// In all cases, the request ID is added to the request context, and +// set to be reflected in the response. +func (h *Handler) Middleware(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, req *http.Request) { + requestID := req.Header.Get(requestIDHeader) + if requestID == "" { + requestID = req.Header.Get(h.legacyTraceHeader) + } + + if requestID == "" { + requestID = newRequestID() + req.Header.Set(h.legacyTraceHeader, requestID) // legacy behavior + } + + // immediately set the request ID to be reflected in the response + w.Header().Set(requestIDHeader, requestID) + + // continue down the handler chain + ctx := NewContext(req.Context(), requestID) + next.ServeHTTP(w, req.WithContext(ctx)) + } + return http.HandlerFunc(fn) +} + +// newRequestID creates a new request ID using github.com/rs/xid. +func newRequestID() string { + return xid.New().String() +} + +type requestIDKey struct{} + +// NewContext returns a new context with the given request ID added to the +// context. +func NewContext(ctx context.Context, requestID string) context.Context { + return context.WithValue(ctx, requestIDKey{}, requestID) +} + +// FromContext returns the request ID from the context if it exists and +// is not the empty value. +func FromContext(ctx context.Context) (string, bool) { + v, ok := ctx.Value(requestIDKey{}).(string) + return v, ok && v != "" +} diff --git a/logging/context_test.go b/internal/requestid/requestid_test.go similarity index 65% rename from logging/context_test.go rename to internal/requestid/requestid_test.go index da993f7b..4d0e872d 100644 --- a/logging/context_test.go +++ b/internal/requestid/requestid_test.go @@ -1,4 +1,4 @@ -package logging +package requestid import ( "net/http" @@ -10,12 +10,13 @@ import ( ) func newRequest(t *testing.T) *http.Request { + t.Helper() r, err := http.NewRequest(http.MethodGet, "https://example.com", http.NoBody) require.NoError(t, err) return r } -func TestRequestID(t *testing.T) { +func Test_Middleware(t *testing.T) { requestWithID := newRequest(t) requestWithID.Header.Set("X-Request-Id", "reqID") requestWithoutID := newRequest(t) @@ -23,20 +24,19 @@ func TestRequestID(t *testing.T) { requestWithEmptyHeader.Header.Set("X-Request-Id", "") requestWithSmallstepID := newRequest(t) requestWithSmallstepID.Header.Set("X-Smallstep-Id", "smallstepID") - tests := []struct { - name string - headerName string - handler http.HandlerFunc - req *http.Request + name string + traceHeader string + next http.HandlerFunc + req *http.Request }{ { - name: "default-request-id", - headerName: defaultTraceIDHeader, - handler: func(w http.ResponseWriter, r *http.Request) { + name: "default-request-id", + traceHeader: defaultTraceHeader, + next: func(w http.ResponseWriter, r *http.Request) { assert.Empty(t, r.Header.Get("X-Smallstep-Id")) assert.Equal(t, "reqID", r.Header.Get("X-Request-Id")) - reqID, ok := GetRequestID(r.Context()) + reqID, ok := FromContext(r.Context()) if assert.True(t, ok) { assert.Equal(t, "reqID", reqID) } @@ -45,13 +45,13 @@ func TestRequestID(t *testing.T) { req: requestWithID, }, { - name: "no-request-id", - headerName: "X-Request-Id", - handler: func(w http.ResponseWriter, r *http.Request) { + name: "no-request-id", + traceHeader: "X-Request-Id", + next: func(w http.ResponseWriter, r *http.Request) { assert.Empty(t, r.Header.Get("X-Smallstep-Id")) value := r.Header.Get("X-Request-Id") assert.NotEmpty(t, value) - reqID, ok := GetRequestID(r.Context()) + reqID, ok := FromContext(r.Context()) if assert.True(t, ok) { assert.Equal(t, value, reqID) } @@ -60,13 +60,13 @@ func TestRequestID(t *testing.T) { req: requestWithoutID, }, { - name: "empty-header-name", - headerName: "", - handler: func(w http.ResponseWriter, r *http.Request) { + name: "empty-header", + traceHeader: "", + next: func(w http.ResponseWriter, r *http.Request) { assert.Empty(t, r.Header.Get("X-Request-Id")) value := r.Header.Get("X-Smallstep-Id") assert.NotEmpty(t, value) - reqID, ok := GetRequestID(r.Context()) + reqID, ok := FromContext(r.Context()) if assert.True(t, ok) { assert.Equal(t, value, reqID) } @@ -75,12 +75,12 @@ func TestRequestID(t *testing.T) { req: requestWithEmptyHeader, }, { - name: "fallback-header-name", - headerName: defaultTraceIDHeader, - handler: func(w http.ResponseWriter, r *http.Request) { + name: "fallback-header-name", + traceHeader: defaultTraceHeader, + next: func(w http.ResponseWriter, r *http.Request) { assert.Empty(t, r.Header.Get("X-Request-Id")) assert.Equal(t, "smallstepID", r.Header.Get("X-Smallstep-Id")) - reqID, ok := GetRequestID(r.Context()) + reqID, ok := FromContext(r.Context()) if assert.True(t, ok) { assert.Equal(t, "smallstepID", reqID) } @@ -91,8 +91,11 @@ func TestRequestID(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := RequestID(tt.headerName) - h(tt.handler).ServeHTTP(httptest.NewRecorder(), tt.req) + handler := New(tt.traceHeader).Middleware(tt.next) + + w := httptest.NewRecorder() + handler.ServeHTTP(w, tt.req) + assert.NotEmpty(t, w.Header().Get("X-Request-Id")) }) } } diff --git a/logging/context.go b/logging/context.go index 9d7a7071..212e2560 100644 --- a/logging/context.go +++ b/logging/context.go @@ -2,82 +2,18 @@ package logging import ( "context" - "net/http" - - "github.com/rs/xid" -) - -type key int - -const ( - // RequestIDKey is the context key that should store the request identifier. - RequestIDKey key = iota - // UserIDKey is the context key that should store the user identifier. - UserIDKey ) -// NewRequestID creates a new request id using github.com/rs/xid. -func NewRequestID() string { - return xid.New().String() -} - -// requestIDHeader is the header name used for propagating request IDs. If -// available in an HTTP request, it'll be used instead of the X-Smallstep-Id -// header. It'll always be used in response and set to the request ID. -const requestIDHeader = "X-Request-Id" - -// RequestID returns a new middleware that obtains the current request ID -// and sets it in the context. It first tries to read the request ID from -// the "X-Request-Id" header. If that's not set, it tries to read it from -// the provided header name. If the header does not exist or its value is -// the empty string, it uses github.com/rs/xid to create a new one. -func RequestID(headerName string) func(next http.Handler) http.Handler { - if headerName == "" { - headerName = defaultTraceIDHeader - } - return func(next http.Handler) http.Handler { - fn := func(w http.ResponseWriter, req *http.Request) { - requestID := req.Header.Get(requestIDHeader) - if requestID == "" { - requestID = req.Header.Get(headerName) - } - - if requestID == "" { - requestID = NewRequestID() - req.Header.Set(headerName, requestID) - } - - // immediately set the request ID to be reflected in the response - w.Header().Set(requestIDHeader, requestID) - - // continue down the handler chain - ctx := WithRequestID(req.Context(), requestID) - next.ServeHTTP(w, req.WithContext(ctx)) - } - return http.HandlerFunc(fn) - } -} - -// WithRequestID returns a new context with the given requestID added to the -// context. -func WithRequestID(ctx context.Context, requestID string) context.Context { - return context.WithValue(ctx, RequestIDKey, requestID) -} - -// GetRequestID returns the request id from the context if it exists. -func GetRequestID(ctx context.Context) (string, bool) { - v, ok := ctx.Value(RequestIDKey).(string) - return v, ok -} +type userIDKey struct{} // WithUserID decodes the token, extracts the user from the payload and stores // it in the context. func WithUserID(ctx context.Context, userID string) context.Context { - return context.WithValue(ctx, UserIDKey, userID) + return context.WithValue(ctx, userIDKey{}, userID) } // GetUserID returns the request id from the context if it exists. func GetUserID(ctx context.Context) (string, bool) { - v, ok := ctx.Value(UserIDKey).(string) - return v, ok + v, ok := ctx.Value(userIDKey{}).(string) + return v, ok && v != "" } diff --git a/logging/handler.go b/logging/handler.go index a8b77d60..77287690 100644 --- a/logging/handler.go +++ b/logging/handler.go @@ -9,6 +9,7 @@ import ( "time" "github.com/sirupsen/logrus" + "github.com/smallstep/certificates/internal/requestid" ) // LoggerHandler creates a logger handler @@ -29,16 +30,15 @@ type options struct { // NewLoggerHandler returns the given http.Handler with the logger integrated. func NewLoggerHandler(name string, logger *Logger, next http.Handler) http.Handler { - h := RequestID(logger.GetTraceHeader()) onlyTraceHealthEndpoint, _ := strconv.ParseBool(os.Getenv("STEP_LOGGER_ONLY_TRACE_HEALTH_ENDPOINT")) - return h(&LoggerHandler{ + return &LoggerHandler{ name: name, logger: logger.GetImpl(), options: options{ onlyTraceHealthEndpoint: onlyTraceHealthEndpoint, }, next: next, - }) + } } // ServeHTTP implements the http.Handler and call to the handler to log with a @@ -54,14 +54,14 @@ func (l *LoggerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // writeEntry writes to the Logger writer the request information in the logger. func (l *LoggerHandler) writeEntry(w ResponseLogger, r *http.Request, t time.Time, d time.Duration) { - var reqID, user string + var requestID, userID string ctx := r.Context() - if v, ok := ctx.Value(RequestIDKey).(string); ok && v != "" { - reqID = v + if v, ok := requestid.FromContext(ctx); ok { + requestID = v } - if v, ok := ctx.Value(UserIDKey).(string); ok && v != "" { - user = v + if v, ok := GetUserID(ctx); ok && v != "" { + userID = v } // Remote hostname @@ -85,10 +85,10 @@ func (l *LoggerHandler) writeEntry(w ResponseLogger, r *http.Request, t time.Tim status := w.StatusCode() fields := logrus.Fields{ - "request-id": reqID, + "request-id": requestID, "remote-address": addr, "name": l.name, - "user-id": user, + "user-id": userID, "time": t.Format(time.RFC3339), "duration-ns": d.Nanoseconds(), "duration": d.String(), diff --git a/monitoring/monitoring.go b/monitoring/monitoring.go index a0d0886b..7c88ab3b 100644 --- a/monitoring/monitoring.go +++ b/monitoring/monitoring.go @@ -9,6 +9,7 @@ import ( "github.com/newrelic/go-agent/v3/newrelic" "github.com/pkg/errors" + "github.com/smallstep/certificates/internal/requestid" "github.com/smallstep/certificates/logging" ) @@ -82,7 +83,7 @@ func newRelicMiddleware(app *newrelic.Application) Middleware { txn.AddAttribute("httpResponseCode", strconv.Itoa(status)) // Add custom attributes - if v, ok := logging.GetRequestID(r.Context()); ok { + if v, ok := requestid.FromContext(r.Context()); ok { txn.AddAttribute("request.id", v) }