- Logging: structured slog output with method, URL, status, and duration.
- DefaultHeaders/UserAgent: inject headers without overwriting existing ones.
- BearerAuth/BasicAuth: per-request token resolution and static credentials.
- Recovery: catches panics in the RoundTripper chain.
125 lines
3.2 KiB
Go
125 lines
3.2 KiB
Go
package middleware_test
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"io"
|
|
"log/slog"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
|
|
"git.codelab.vc/pkg/httpx/middleware"
|
|
)
|
|
|
|
// captureHandler is a slog.Handler that captures log records for inspection.
//
// The zero value is ready to use. Handlers derived from it via WithAttrs or
// WithGroup share the root handler's record store, and every record they
// capture carries the accumulated attributes (nested under any open groups),
// so fields added through Logger.With are visible to assertions.
type captureHandler struct {
	mu      sync.Mutex
	records []slog.Record

	root   *captureHandler // owner of records; nil means this handler is the root
	attrs  []slog.Attr     // attrs from WithAttrs, already nested under the groups open at the time
	groups []string        // groups opened by WithGroup, outermost first
}

// Enabled reports true for every level so that all records are captured.
func (h *captureHandler) Enabled(_ context.Context, _ slog.Level) bool { return true }

// Handle stores a copy of r whose attributes are h's accumulated attrs
// followed by r's own attrs nested under h's open groups.
func (h *captureHandler) Handle(_ context.Context, r slog.Record) error {
	rec := slog.NewRecord(r.Time, r.Level, r.Message, r.PC)
	rec.AddAttrs(h.attrs...)

	own := make([]slog.Attr, 0, r.NumAttrs())
	r.Attrs(func(a slog.Attr) bool {
		own = append(own, a)
		return true
	})
	rec.AddAttrs(h.grouped(own)...)

	s := h.store()
	s.mu.Lock()
	defer s.mu.Unlock()
	s.records = append(s.records, rec)
	return nil
}

// WithAttrs returns a handler that adds attrs (nested under h's open groups)
// to every record it captures, in addition to h's own attributes.
func (h *captureHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
	if len(attrs) == 0 {
		return h
	}
	// Copy the caller's slice: slog.GroupValue may retain it.
	own := append([]slog.Attr(nil), attrs...)

	merged := make([]slog.Attr, 0, len(h.attrs)+len(own))
	merged = append(merged, h.attrs...)
	merged = append(merged, h.grouped(own)...)
	return &captureHandler{root: h.store(), attrs: merged, groups: h.groups}
}

// WithGroup returns a handler that nests subsequent attributes under name.
// Per the slog.Handler contract, an empty name returns the receiver.
func (h *captureHandler) WithGroup(name string) slog.Handler {
	if name == "" {
		return h
	}
	groups := make([]string, 0, len(h.groups)+1)
	groups = append(groups, h.groups...)
	groups = append(groups, name)
	return &captureHandler{root: h.store(), attrs: h.attrs, groups: groups}
}

// store returns the handler that owns the captured records.
func (h *captureHandler) store() *captureHandler {
	if h.root != nil {
		return h.root
	}
	return h
}

// grouped nests attrs under h's open groups, outermost first. It returns nil
// for empty attrs so that groups with no attributes are never recorded.
func (h *captureHandler) grouped(attrs []slog.Attr) []slog.Attr {
	if len(attrs) == 0 {
		return nil
	}
	for i := len(h.groups) - 1; i >= 0; i-- {
		attrs = []slog.Attr{{Key: h.groups[i], Value: slog.GroupValue(attrs...)}}
	}
	return attrs
}

// lastRecord returns the most recently captured record. It panics with a
// descriptive message if nothing has been logged yet.
func (h *captureHandler) lastRecord() slog.Record {
	s := h.store()
	s.mu.Lock()
	defer s.mu.Unlock()
	if len(s.records) == 0 {
		panic("captureHandler: no log records captured")
	}
	return s.records[len(s.records)-1]
}
|
|
|
|
func TestLogging(t *testing.T) {
|
|
t.Run("logs method url status duration on success", func(t *testing.T) {
|
|
handler := &captureHandler{}
|
|
logger := slog.New(handler)
|
|
|
|
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
|
return &http.Response{
|
|
StatusCode: http.StatusOK,
|
|
Body: io.NopCloser(strings.NewReader("ok")),
|
|
Header: make(http.Header),
|
|
}, nil
|
|
})
|
|
|
|
transport := middleware.Logging(logger)(base)
|
|
req, _ := http.NewRequest(http.MethodPost, "http://example.com/api", nil)
|
|
resp, err := transport.RoundTrip(req)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if resp.StatusCode != http.StatusOK {
|
|
t.Fatalf("got status %d, want %d", resp.StatusCode, http.StatusOK)
|
|
}
|
|
|
|
rec := handler.lastRecord()
|
|
if rec.Level != slog.LevelInfo {
|
|
t.Errorf("got level %v, want %v", rec.Level, slog.LevelInfo)
|
|
}
|
|
|
|
attrs := map[string]string{}
|
|
rec.Attrs(func(a slog.Attr) bool {
|
|
attrs[a.Key] = a.Value.String()
|
|
return true
|
|
})
|
|
|
|
if attrs["method"] != "POST" {
|
|
t.Errorf("method = %q, want %q", attrs["method"], "POST")
|
|
}
|
|
if attrs["url"] != "http://example.com/api" {
|
|
t.Errorf("url = %q, want %q", attrs["url"], "http://example.com/api")
|
|
}
|
|
if _, ok := attrs["status"]; !ok {
|
|
t.Error("missing status attribute")
|
|
}
|
|
if _, ok := attrs["duration"]; !ok {
|
|
t.Error("missing duration attribute")
|
|
}
|
|
})
|
|
|
|
t.Run("logs error on failure", func(t *testing.T) {
|
|
handler := &captureHandler{}
|
|
logger := slog.New(handler)
|
|
|
|
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
|
return nil, errors.New("connection refused")
|
|
})
|
|
|
|
transport := middleware.Logging(logger)(base)
|
|
req, _ := http.NewRequest(http.MethodGet, "http://example.com/fail", nil)
|
|
_, err := transport.RoundTrip(req)
|
|
if err == nil {
|
|
t.Fatal("expected error, got nil")
|
|
}
|
|
|
|
rec := handler.lastRecord()
|
|
if rec.Level != slog.LevelError {
|
|
t.Errorf("got level %v, want %v", rec.Level, slog.LevelError)
|
|
}
|
|
|
|
attrs := map[string]string{}
|
|
rec.Attrs(func(a slog.Attr) bool {
|
|
attrs[a.Key] = a.Value.String()
|
|
return true
|
|
})
|
|
|
|
if attrs["error"] != "connection refused" {
|
|
t.Errorf("error = %q, want %q", attrs["error"], "connection refused")
|
|
}
|
|
if _, ok := attrs["method"]; !ok {
|
|
t.Error("missing method attribute")
|
|
}
|
|
if _, ok := attrs["url"]; !ok {
|
|
t.Error("missing url attribute")
|
|
}
|
|
})
|
|
}
|