Decouple request ID middleware from logging middleware

pull/1743/head
Herman Slatman 2 months ago
parent 535e2a96d5
commit 7e5f10927f
No known key found for this signature in database
GPG Key ID: F4D8A44EA0A75A4F

@ -15,7 +15,7 @@ import (
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/logging"
"github.com/smallstep/certificates/internal/requestid"
"github.com/smallstep/certificates/templates"
"github.com/smallstep/certificates/webhook"
"go.step.sm/linkedca"
@ -171,9 +171,8 @@ retry:
return nil, err
}
requestID, ok := logging.GetRequestID(ctx)
if ok {
req.Header.Set("X-Request-ID", requestID)
if requestID, ok := requestid.FromContext(ctx); ok {
req.Header.Set("X-Request-Id", requestID)
}
secret, err := base64.StdEncoding.DecodeString(w.Secret)

@ -17,7 +17,7 @@ import (
"testing"
"time"
"github.com/smallstep/certificates/logging"
"github.com/smallstep/certificates/internal/requestid"
"github.com/smallstep/certificates/webhook"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -101,10 +101,10 @@ func TestWebhookController_isCertTypeOK(t *testing.T) {
}
}
// withRequestID is a helper that calls into [logging.WithRequestID] and returns
// a new context with the requestID added to the provided context.
// withRequestID is a helper that calls into [requestid.NewContext] and returns
// a new context with the requestID added.
func withRequestID(ctx context.Context, requestID string) context.Context {
return logging.WithRequestID(ctx, requestID)
return requestid.NewContext(ctx, requestID)
}
func TestWebhookController_Enrich(t *testing.T) {

@ -29,6 +29,7 @@ import (
"github.com/smallstep/certificates/cas/apiv1"
"github.com/smallstep/certificates/db"
"github.com/smallstep/certificates/internal/metrix"
"github.com/smallstep/certificates/internal/requestid"
"github.com/smallstep/certificates/logging"
"github.com/smallstep/certificates/monitoring"
"github.com/smallstep/certificates/scep"
@ -329,15 +330,21 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) {
}
// Add logger if configured
var legacyTraceHeader string
if len(cfg.Logger) > 0 {
logger, err := logging.New("ca", cfg.Logger)
if err != nil {
return nil, err
}
legacyTraceHeader = logger.GetTraceHeader()
handler = logger.Middleware(handler)
insecureHandler = logger.Middleware(insecureHandler)
}
// always use request ID middleware; traceHeader is provided for backwards compatibility (for now)
handler = requestid.New(legacyTraceHeader).Middleware(handler)
insecureHandler = requestid.New(legacyTraceHeader).Middleware(insecureHandler)
// Create context with all the necessary values.
baseContext := buildContext(auth, scepAuthority, acmeDB, acmeLinker)

@ -2,8 +2,9 @@ package errs
import (
"fmt"
"reflect"
"testing"
"github.com/stretchr/testify/assert"
)
func TestError_MarshalJSON(t *testing.T) {
@ -27,13 +28,14 @@ func TestError_MarshalJSON(t *testing.T) {
Err: tt.fields.Err,
}
got, err := e.MarshalJSON()
if (err != nil) != tt.wantErr {
t.Errorf("Error.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
if tt.wantErr {
assert.Error(t, err)
assert.Empty(t, got)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Error.MarshalJSON() = %s, want %s", got, tt.want)
}
assert.NoError(t, err)
assert.Equal(t, tt.want, got)
})
}
}
@ -54,13 +56,14 @@ func TestError_UnmarshalJSON(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
e := new(Error)
if err := e.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr {
t.Errorf("Error.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
}
//nolint:govet // best option
if !reflect.DeepEqual(tt.expected, e) {
t.Errorf("Error.UnmarshalJSON() wants = %+v, got %+v", tt.expected, e)
err := e.UnmarshalJSON(tt.args.data)
if tt.wantErr {
assert.Error(t, err)
return
}
assert.NoError(t, err)
assert.Equal(t, tt.expected, e)
})
}
}

@ -0,0 +1,82 @@
package requestid
import (
"context"
"net/http"
"github.com/rs/xid"
)
const (
// requestIDHeader is the header name used for propagating request IDs. If
// available in an HTTP request, it'll be used instead of the X-Smallstep-Id
// header. It'll always be used in response and set to the request ID.
requestIDHeader = "X-Request-Id"
// defaultTraceHeader is the default Smallstep tracing header that's currently
// in use. It is used as a fallback to retrieve a request ID from, if the
// "X-Request-Id" request header is not set.
defaultTraceHeader = "X-Smallstep-Id"
)
type Handler struct {
legacyTraceHeader string
}
// New creates a new request ID [handler]. It takes a trace header,
// which is used keep the legacy behavior intact, which relies on the
// X-Smallstep-Id header instead of X-Request-Id.
func New(legacyTraceHeader string) *Handler {
if legacyTraceHeader == "" {
legacyTraceHeader = defaultTraceHeader
}
return &Handler{legacyTraceHeader: legacyTraceHeader}
}
// Middleware wraps an [http.Handler] with request ID extraction
// from the X-Reqeust-Id header by default, or from the X-Smallstep-Id
// header if not set. If both are not set, a new request ID is generated.
// In all cases, the request ID is added to the request context, and
// set to be reflected in the response.
func (h *Handler) Middleware(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, req *http.Request) {
requestID := req.Header.Get(requestIDHeader)
if requestID == "" {
requestID = req.Header.Get(h.legacyTraceHeader)
}
if requestID == "" {
requestID = newRequestID()
req.Header.Set(h.legacyTraceHeader, requestID) // legacy behavior
}
// immediately set the request ID to be reflected in the response
w.Header().Set(requestIDHeader, requestID)
// continue down the handler chain
ctx := NewContext(req.Context(), requestID)
next.ServeHTTP(w, req.WithContext(ctx))
}
return http.HandlerFunc(fn)
}
// newRequestID creates a new request ID using github.com/rs/xid.
func newRequestID() string {
return xid.New().String()
}
type requestIDKey struct{}
// NewContext returns a new context with the given request ID added to the
// context.
func NewContext(ctx context.Context, requestID string) context.Context {
return context.WithValue(ctx, requestIDKey{}, requestID)
}
// FromContext returns the request ID from the context if it exists and
// is not the empty value.
func FromContext(ctx context.Context) (string, bool) {
v, ok := ctx.Value(requestIDKey{}).(string)
return v, ok && v != ""
}

@ -1,4 +1,4 @@
package logging
package requestid
import (
"net/http"
@ -10,12 +10,13 @@ import (
)
func newRequest(t *testing.T) *http.Request {
t.Helper()
r, err := http.NewRequest(http.MethodGet, "https://example.com", http.NoBody)
require.NoError(t, err)
return r
}
func TestRequestID(t *testing.T) {
func Test_Middleware(t *testing.T) {
requestWithID := newRequest(t)
requestWithID.Header.Set("X-Request-Id", "reqID")
requestWithoutID := newRequest(t)
@ -23,20 +24,19 @@ func TestRequestID(t *testing.T) {
requestWithEmptyHeader.Header.Set("X-Request-Id", "")
requestWithSmallstepID := newRequest(t)
requestWithSmallstepID.Header.Set("X-Smallstep-Id", "smallstepID")
tests := []struct {
name string
headerName string
handler http.HandlerFunc
req *http.Request
name string
traceHeader string
next http.HandlerFunc
req *http.Request
}{
{
name: "default-request-id",
headerName: defaultTraceIDHeader,
handler: func(w http.ResponseWriter, r *http.Request) {
name: "default-request-id",
traceHeader: defaultTraceHeader,
next: func(w http.ResponseWriter, r *http.Request) {
assert.Empty(t, r.Header.Get("X-Smallstep-Id"))
assert.Equal(t, "reqID", r.Header.Get("X-Request-Id"))
reqID, ok := GetRequestID(r.Context())
reqID, ok := FromContext(r.Context())
if assert.True(t, ok) {
assert.Equal(t, "reqID", reqID)
}
@ -45,13 +45,13 @@ func TestRequestID(t *testing.T) {
req: requestWithID,
},
{
name: "no-request-id",
headerName: "X-Request-Id",
handler: func(w http.ResponseWriter, r *http.Request) {
name: "no-request-id",
traceHeader: "X-Request-Id",
next: func(w http.ResponseWriter, r *http.Request) {
assert.Empty(t, r.Header.Get("X-Smallstep-Id"))
value := r.Header.Get("X-Request-Id")
assert.NotEmpty(t, value)
reqID, ok := GetRequestID(r.Context())
reqID, ok := FromContext(r.Context())
if assert.True(t, ok) {
assert.Equal(t, value, reqID)
}
@ -60,13 +60,13 @@ func TestRequestID(t *testing.T) {
req: requestWithoutID,
},
{
name: "empty-header-name",
headerName: "",
handler: func(w http.ResponseWriter, r *http.Request) {
name: "empty-header",
traceHeader: "",
next: func(w http.ResponseWriter, r *http.Request) {
assert.Empty(t, r.Header.Get("X-Request-Id"))
value := r.Header.Get("X-Smallstep-Id")
assert.NotEmpty(t, value)
reqID, ok := GetRequestID(r.Context())
reqID, ok := FromContext(r.Context())
if assert.True(t, ok) {
assert.Equal(t, value, reqID)
}
@ -75,12 +75,12 @@ func TestRequestID(t *testing.T) {
req: requestWithEmptyHeader,
},
{
name: "fallback-header-name",
headerName: defaultTraceIDHeader,
handler: func(w http.ResponseWriter, r *http.Request) {
name: "fallback-header-name",
traceHeader: defaultTraceHeader,
next: func(w http.ResponseWriter, r *http.Request) {
assert.Empty(t, r.Header.Get("X-Request-Id"))
assert.Equal(t, "smallstepID", r.Header.Get("X-Smallstep-Id"))
reqID, ok := GetRequestID(r.Context())
reqID, ok := FromContext(r.Context())
if assert.True(t, ok) {
assert.Equal(t, "smallstepID", reqID)
}
@ -91,8 +91,11 @@ func TestRequestID(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := RequestID(tt.headerName)
h(tt.handler).ServeHTTP(httptest.NewRecorder(), tt.req)
handler := New(tt.traceHeader).Middleware(tt.next)
w := httptest.NewRecorder()
handler.ServeHTTP(w, tt.req)
assert.NotEmpty(t, w.Header().Get("X-Request-Id"))
})
}
}

@ -2,82 +2,18 @@ package logging
import (
"context"
"net/http"
"github.com/rs/xid"
)
type key int
const (
// RequestIDKey is the context key that should store the request identifier.
RequestIDKey key = iota
// UserIDKey is the context key that should store the user identifier.
UserIDKey
)
// NewRequestID creates a new request id using github.com/rs/xid.
func NewRequestID() string {
return xid.New().String()
}
// requestIDHeader is the header name used for propagating request IDs. If
// available in an HTTP request, it'll be used instead of the X-Smallstep-Id
// header. It'll always be used in response and set to the request ID.
const requestIDHeader = "X-Request-Id"
// RequestID returns a new middleware that obtains the current request ID
// and sets it in the context. It first tries to read the request ID from
// the "X-Request-Id" header. If that's not set, it tries to read it from
// the provided header name. If the header does not exist or its value is
// the empty string, it uses github.com/rs/xid to create a new one.
func RequestID(headerName string) func(next http.Handler) http.Handler {
if headerName == "" {
headerName = defaultTraceIDHeader
}
return func(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, req *http.Request) {
requestID := req.Header.Get(requestIDHeader)
if requestID == "" {
requestID = req.Header.Get(headerName)
}
if requestID == "" {
requestID = NewRequestID()
req.Header.Set(headerName, requestID)
}
// immediately set the request ID to be reflected in the response
w.Header().Set(requestIDHeader, requestID)
// continue down the handler chain
ctx := WithRequestID(req.Context(), requestID)
next.ServeHTTP(w, req.WithContext(ctx))
}
return http.HandlerFunc(fn)
}
}
// WithRequestID returns a new context with the given requestID added to the
// context.
func WithRequestID(ctx context.Context, requestID string) context.Context {
return context.WithValue(ctx, RequestIDKey, requestID)
}
// GetRequestID returns the request id from the context if it exists.
func GetRequestID(ctx context.Context) (string, bool) {
v, ok := ctx.Value(RequestIDKey).(string)
return v, ok
}
type userIDKey struct{}
// WithUserID decodes the token, extracts the user from the payload and stores
// it in the context.
func WithUserID(ctx context.Context, userID string) context.Context {
return context.WithValue(ctx, UserIDKey, userID)
return context.WithValue(ctx, userIDKey{}, userID)
}
// GetUserID returns the request id from the context if it exists.
func GetUserID(ctx context.Context) (string, bool) {
v, ok := ctx.Value(UserIDKey).(string)
return v, ok
v, ok := ctx.Value(userIDKey{}).(string)
return v, ok && v != ""
}

@ -9,6 +9,7 @@ import (
"time"
"github.com/sirupsen/logrus"
"github.com/smallstep/certificates/internal/requestid"
)
// LoggerHandler creates a logger handler
@ -29,16 +30,15 @@ type options struct {
// NewLoggerHandler returns the given http.Handler with the logger integrated.
func NewLoggerHandler(name string, logger *Logger, next http.Handler) http.Handler {
h := RequestID(logger.GetTraceHeader())
onlyTraceHealthEndpoint, _ := strconv.ParseBool(os.Getenv("STEP_LOGGER_ONLY_TRACE_HEALTH_ENDPOINT"))
return h(&LoggerHandler{
return &LoggerHandler{
name: name,
logger: logger.GetImpl(),
options: options{
onlyTraceHealthEndpoint: onlyTraceHealthEndpoint,
},
next: next,
})
}
}
// ServeHTTP implements the http.Handler and call to the handler to log with a
@ -54,14 +54,14 @@ func (l *LoggerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// writeEntry writes to the Logger writer the request information in the logger.
func (l *LoggerHandler) writeEntry(w ResponseLogger, r *http.Request, t time.Time, d time.Duration) {
var reqID, user string
var requestID, userID string
ctx := r.Context()
if v, ok := ctx.Value(RequestIDKey).(string); ok && v != "" {
reqID = v
if v, ok := requestid.FromContext(ctx); ok {
requestID = v
}
if v, ok := ctx.Value(UserIDKey).(string); ok && v != "" {
user = v
if v, ok := GetUserID(ctx); ok && v != "" {
userID = v
}
// Remote hostname
@ -85,10 +85,10 @@ func (l *LoggerHandler) writeEntry(w ResponseLogger, r *http.Request, t time.Tim
status := w.StatusCode()
fields := logrus.Fields{
"request-id": reqID,
"request-id": requestID,
"remote-address": addr,
"name": l.name,
"user-id": user,
"user-id": userID,
"time": t.Format(time.RFC3339),
"duration-ns": d.Nanoseconds(),
"duration": d.String(),

@ -9,6 +9,7 @@ import (
"github.com/newrelic/go-agent/v3/newrelic"
"github.com/pkg/errors"
"github.com/smallstep/certificates/internal/requestid"
"github.com/smallstep/certificates/logging"
)
@ -82,7 +83,7 @@ func newRelicMiddleware(app *newrelic.Application) Middleware {
txn.AddAttribute("httpResponseCode", strconv.Itoa(status))
// Add custom attributes
if v, ok := logging.GetRequestID(r.Context()); ok {
if v, ok := requestid.FromContext(r.Context()); ok {
txn.AddAttribute("request.id", v)
}

Loading…
Cancel
Save