Use a function as the error logger

This commit addresses comment in the code review. Now, instead of
injecting an slog.Logger we can inject any method that can use a more
flexible implementation.
pull/1849/head
Mariano Cano 1 month ago
parent f3f484cee2
commit c0b7c33a58
No known key found for this signature in database

@ -4,26 +4,32 @@ package log
import (
"context"
"fmt"
"log/slog"
"net/http"
"os"
"github.com/pkg/errors"
)
// ErrorKey is the logging attribute key for error values.
const ErrorKey = "error"
type errorLoggerKey struct{}
type loggerKey struct{}
// ErrorLogger is the function type used to log errors.
type ErrorLogger func(http.ResponseWriter, *http.Request, error)
// NewContext creates a new context with the given slog.Logger.
func NewContext(ctx context.Context, logger *slog.Logger) context.Context {
return context.WithValue(ctx, loggerKey{}, logger)
func (fn ErrorLogger) call(w http.ResponseWriter, r *http.Request, err error) {
if fn == nil {
return
}
fn(w, r, err)
}
// FromContext returns the logger from the given context.
func FromContext(ctx context.Context) (l *slog.Logger, ok bool) {
l, ok = ctx.Value(loggerKey{}).(*slog.Logger)
// WithErrorLogger returns a new context with the given error logger.
func WithErrorLogger(ctx context.Context, fn ErrorLogger) context.Context {
return context.WithValue(ctx, errorLoggerKey{}, fn)
}
// ErrorLoggerFromContext returns an error logger from the context.
func ErrorLoggerFromContext(ctx context.Context) (fn ErrorLogger) {
fn, _ = ctx.Value(errorLoggerKey{}).(ErrorLogger)
return
}
@ -45,13 +51,10 @@ type fieldCarrier interface {
// Error adds to the response writer the given error if it implements
// logging.ResponseLogger. If it does not implement it, then writes the error
// using the log package.
func Error(rw http.ResponseWriter, r *http.Request, err error) {
ctx := r.Context()
if logger, ok := FromContext(ctx); ok && err != nil {
logger.ErrorContext(ctx, "request failed", slog.Any(ErrorKey, err))
}
func Error(w http.ResponseWriter, r *http.Request, err error) {
ErrorLoggerFromContext(r.Context()).call(w, r, err)
fc, ok := rw.(fieldCarrier)
fc, ok := w.(fieldCarrier)
if !ok {
return
}

@ -33,7 +33,11 @@ func TestError(t *testing.T) {
var buf bytes.Buffer
logger := slog.New(slog.NewJSONHandler(&buf, &slog.HandlerOptions{}))
req := httptest.NewRequest("GET", "/test", http.NoBody)
reqWithLogger := req.WithContext(NewContext(req.Context(), logger))
reqWithLogger := req.WithContext(WithErrorLogger(req.Context(), func(w http.ResponseWriter, r *http.Request, err error) {
if err != nil {
logger.ErrorContext(r.Context(), "request failed", slog.Any("error", err))
}
}))
tests := []struct {
name string

Loading…
Cancel
Save