From c0b7c33a58807ff2fee7f654f847af78a9f08107 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Thu, 16 May 2024 11:00:36 -0700 Subject: [PATCH] Use a function as the error logger This commit addresses comment in the code review. Now, instead of injecting an slog.Logger we can inject any method that can use a more flexible implementation. --- api/log/log.go | 35 +++++++++++++++++++---------------- api/log/log_test.go | 6 +++++- 2 files changed, 24 insertions(+), 17 deletions(-) diff --git a/api/log/log.go b/api/log/log.go index b01c8404..6cc61a77 100644 --- a/api/log/log.go +++ b/api/log/log.go @@ -4,26 +4,32 @@ package log import ( "context" "fmt" - "log/slog" "net/http" "os" "github.com/pkg/errors" ) -// ErrorKey is the logging attribute key for error values. -const ErrorKey = "error" +type errorLoggerKey struct{} -type loggerKey struct{} +// ErrorLogger is the function type used to log errors. +type ErrorLogger func(http.ResponseWriter, *http.Request, error) -// NewContext creates a new context with the given slog.Logger. -func NewContext(ctx context.Context, logger *slog.Logger) context.Context { - return context.WithValue(ctx, loggerKey{}, logger) +func (fn ErrorLogger) call(w http.ResponseWriter, r *http.Request, err error) { + if fn == nil { + return + } + fn(w, r, err) } -// FromContext returns the logger from the given context. -func FromContext(ctx context.Context) (l *slog.Logger, ok bool) { - l, ok = ctx.Value(loggerKey{}).(*slog.Logger) +// WithErrorLogger returns a new context with the given error logger. +func WithErrorLogger(ctx context.Context, fn ErrorLogger) context.Context { + return context.WithValue(ctx, errorLoggerKey{}, fn) +} + +// ErrorLoggerFromContext returns an error logger from the context. +func ErrorLoggerFromContext(ctx context.Context) (fn ErrorLogger) { + fn, _ = ctx.Value(errorLoggerKey{}).(ErrorLogger) return } @@ -45,13 +51,10 @@ type fieldCarrier interface { // Error adds to the response writer the given error if it implements // logging.ResponseLogger. If it does not implement it, then writes the error // using the log package. -func Error(rw http.ResponseWriter, r *http.Request, err error) { - ctx := r.Context() - if logger, ok := FromContext(ctx); ok && err != nil { - logger.ErrorContext(ctx, "request failed", slog.Any(ErrorKey, err)) - } +func Error(w http.ResponseWriter, r *http.Request, err error) { + ErrorLoggerFromContext(r.Context()).call(w, r, err) - fc, ok := rw.(fieldCarrier) + fc, ok := w.(fieldCarrier) if !ok { return } diff --git a/api/log/log_test.go b/api/log/log_test.go index b1de3710..e1da274f 100644 --- a/api/log/log_test.go +++ b/api/log/log_test.go @@ -33,7 +33,11 @@ func TestError(t *testing.T) { var buf bytes.Buffer logger := slog.New(slog.NewJSONHandler(&buf, &slog.HandlerOptions{})) req := httptest.NewRequest("GET", "/test", http.NoBody) - reqWithLogger := req.WithContext(NewContext(req.Context(), logger)) + reqWithLogger := req.WithContext(WithErrorLogger(req.Context(), func(w http.ResponseWriter, r *http.Request, err error) { + if err != nil { + logger.ErrorContext(r.Context(), "request failed", slog.Any("error", err)) + } + })) tests := []struct { name string