package middleware_test import ( "context" "errors" "io" "log/slog" "net/http" "strings" "sync" "testing" "git.codelab.vc/pkg/httpx/middleware" ) // captureHandler is a slog.Handler that captures log records for inspection. type captureHandler struct { mu sync.Mutex records []slog.Record } func (h *captureHandler) Enabled(_ context.Context, _ slog.Level) bool { return true } func (h *captureHandler) Handle(_ context.Context, r slog.Record) error { h.mu.Lock() defer h.mu.Unlock() h.records = append(h.records, r) return nil } func (h *captureHandler) WithAttrs(_ []slog.Attr) slog.Handler { return h } func (h *captureHandler) WithGroup(_ string) slog.Handler { return h } func (h *captureHandler) lastRecord() slog.Record { h.mu.Lock() defer h.mu.Unlock() return h.records[len(h.records)-1] } func TestLogging(t *testing.T) { t.Run("logs method url status duration on success", func(t *testing.T) { handler := &captureHandler{} logger := slog.New(handler) base := mockTransport(func(req *http.Request) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader("ok")), Header: make(http.Header), }, nil }) transport := middleware.Logging(logger)(base) req, _ := http.NewRequest(http.MethodPost, "http://example.com/api", nil) resp, err := transport.RoundTrip(req) if err != nil { t.Fatalf("unexpected error: %v", err) } if resp.StatusCode != http.StatusOK { t.Fatalf("got status %d, want %d", resp.StatusCode, http.StatusOK) } rec := handler.lastRecord() if rec.Level != slog.LevelInfo { t.Errorf("got level %v, want %v", rec.Level, slog.LevelInfo) } attrs := map[string]string{} rec.Attrs(func(a slog.Attr) bool { attrs[a.Key] = a.Value.String() return true }) if attrs["method"] != "POST" { t.Errorf("method = %q, want %q", attrs["method"], "POST") } if attrs["url"] != "http://example.com/api" { t.Errorf("url = %q, want %q", attrs["url"], "http://example.com/api") } if _, ok := attrs["status"]; !ok { t.Error("missing status attribute") } if _, ok := attrs["duration"]; !ok { t.Error("missing duration attribute") } }) t.Run("logs error on failure", func(t *testing.T) { handler := &captureHandler{} logger := slog.New(handler) base := mockTransport(func(req *http.Request) (*http.Response, error) { return nil, errors.New("connection refused") }) transport := middleware.Logging(logger)(base) req, _ := http.NewRequest(http.MethodGet, "http://example.com/fail", nil) _, err := transport.RoundTrip(req) if err == nil { t.Fatal("expected error, got nil") } rec := handler.lastRecord() if rec.Level != slog.LevelError { t.Errorf("got level %v, want %v", rec.Level, slog.LevelError) } attrs := map[string]string{} rec.Attrs(func(a slog.Attr) bool { attrs[a.Key] = a.Value.String() return true }) if attrs["error"] != "connection refused" { t.Errorf("error = %q, want %q", attrs["error"], "connection refused") } if _, ok := attrs["method"]; !ok { t.Error("missing method attribute") } if _, ok := attrs["url"]; !ok { t.Error("missing url attribute") } }) }