From a90c4cd7faab905985d65ace0f116016d10a2f43 Mon Sep 17 00:00:00 2001 From: Aleksey Shakhmatov Date: Fri, 20 Mar 2026 14:22:14 +0300 Subject: [PATCH] Add standard middlewares: logging, headers, auth, and panic recovery - Logging: structured slog output with method, URL, status, duration - DefaultHeaders/UserAgent: inject headers without overwriting existing - BearerAuth/BasicAuth: per-request token resolution and static credentials - Recovery: catches panics in the RoundTripper chain --- middleware/auth.go | 33 +++++++++ middleware/auth_test.go | 89 ++++++++++++++++++++++++ middleware/headers.go | 29 ++++++++ middleware/headers_test.go | 107 +++++++++++++++++++++++++++++ middleware/logging.go | 38 +++++++++++ middleware/logging_test.go | 124 ++++++++++++++++++++++++++++++++++ middleware/middleware_test.go | 115 +++++++++++++++++++++++++++++++ middleware/recovery.go | 22 ++++++ middleware/recovery_test.go | 51 ++++++++++++++ 9 files changed, 608 insertions(+) create mode 100644 middleware/auth.go create mode 100644 middleware/auth_test.go create mode 100644 middleware/headers.go create mode 100644 middleware/headers_test.go create mode 100644 middleware/logging.go create mode 100644 middleware/logging_test.go create mode 100644 middleware/middleware_test.go create mode 100644 middleware/recovery.go create mode 100644 middleware/recovery_test.go diff --git a/middleware/auth.go b/middleware/auth.go new file mode 100644 index 0000000..0564a40 --- /dev/null +++ b/middleware/auth.go @@ -0,0 +1,33 @@ +package middleware + +import ( + "context" + "net/http" +) + +// BearerAuth returns a middleware that sets the Authorization header to a +// Bearer token obtained by calling tokenFunc on each request. If tokenFunc +// returns an error, the request is not sent and the error is returned. +func BearerAuth(tokenFunc func(ctx context.Context) (string, error)) Middleware { + return func(next http.RoundTripper) http.RoundTripper { + return RoundTripperFunc(func(req *http.Request) (*http.Response, error) { + token, err := tokenFunc(req.Context()) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer "+token) + return next.RoundTrip(req) + }) + } +} + +// BasicAuth returns a middleware that sets HTTP Basic Authentication +// credentials on every outgoing request. +func BasicAuth(username, password string) Middleware { + return func(next http.RoundTripper) http.RoundTripper { + return RoundTripperFunc(func(req *http.Request) (*http.Response, error) { + req.SetBasicAuth(username, password) + return next.RoundTrip(req) + }) + } +} diff --git a/middleware/auth_test.go b/middleware/auth_test.go new file mode 100644 index 0000000..af64891 --- /dev/null +++ b/middleware/auth_test.go @@ -0,0 +1,89 @@ +package middleware_test + +import ( + "context" + "errors" + "net/http" + "testing" + + "git.codelab.vc/pkg/httpx/middleware" +) + +func TestBearerAuth(t *testing.T) { + t.Run("sets authorization header", func(t *testing.T) { + var captured http.Header + base := mockTransport(func(req *http.Request) (*http.Response, error) { + captured = req.Header.Clone() + return okResponse(), nil + }) + + tokenFunc := func(_ context.Context) (string, error) { + return "my-secret-token", nil + } + + transport := middleware.BearerAuth(tokenFunc)(base) + req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil) + + _, err := transport.RoundTrip(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + want := "Bearer my-secret-token" + if got := captured.Get("Authorization"); got != want { + t.Errorf("Authorization = %q, want %q", got, want) + } + }) + + t.Run("returns error when tokenFunc fails", func(t *testing.T) { + base := mockTransport(func(req *http.Request) (*http.Response, error) { + t.Fatal("base transport should not be called") + return nil, nil + }) + + tokenErr := errors.New("token expired") + tokenFunc := func(_ context.Context) (string, error) { + return "", tokenErr + } + + transport := middleware.BearerAuth(tokenFunc)(base) + req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil) + + _, err := transport.RoundTrip(req) + if err == nil { + t.Fatal("expected error, got nil") + } + if !errors.Is(err, tokenErr) { + t.Errorf("got error %v, want %v", err, tokenErr) + } + }) +} + +func TestBasicAuth(t *testing.T) { + t.Run("sets basic auth header", func(t *testing.T) { + var capturedReq *http.Request + base := mockTransport(func(req *http.Request) (*http.Response, error) { + capturedReq = req + return okResponse(), nil + }) + + transport := middleware.BasicAuth("user", "pass")(base) + req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil) + + _, err := transport.RoundTrip(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + username, password, ok := capturedReq.BasicAuth() + if !ok { + t.Fatal("BasicAuth() returned ok=false") + } + if username != "user" { + t.Errorf("username = %q, want %q", username, "user") + } + if password != "pass" { + t.Errorf("password = %q, want %q", password, "pass") + } + }) +} diff --git a/middleware/headers.go b/middleware/headers.go new file mode 100644 index 0000000..895f890 --- /dev/null +++ b/middleware/headers.go @@ -0,0 +1,29 @@ +package middleware + +import "net/http" + +// DefaultHeaders returns a middleware that adds the given headers to every +// outgoing request. Existing headers on the request are not overwritten. +func DefaultHeaders(headers http.Header) Middleware { + return func(next http.RoundTripper) http.RoundTripper { + return RoundTripperFunc(func(req *http.Request) (*http.Response, error) { + for key, values := range headers { + if req.Header.Get(key) != "" { + continue + } + for _, v := range values { + req.Header.Add(key, v) + } + } + return next.RoundTrip(req) + }) + } +} + +// UserAgent returns a middleware that sets the User-Agent header on every +// outgoing request, unless one is already present. +func UserAgent(ua string) Middleware { + return DefaultHeaders(http.Header{ + "User-Agent": {ua}, + }) +} diff --git a/middleware/headers_test.go b/middleware/headers_test.go new file mode 100644 index 0000000..bf9bcbf --- /dev/null +++ b/middleware/headers_test.go @@ -0,0 +1,107 @@ +package middleware_test + +import ( + "net/http" + "testing" + + "git.codelab.vc/pkg/httpx/middleware" +) + +func TestDefaultHeaders(t *testing.T) { + t.Run("adds headers without overwriting existing", func(t *testing.T) { + defaults := http.Header{ + "X-Custom": {"default-value"}, + "X-Untouched": {"from-middleware"}, + } + + var captured http.Header + base := mockTransport(func(req *http.Request) (*http.Response, error) { + captured = req.Header.Clone() + return okResponse(), nil + }) + + transport := middleware.DefaultHeaders(defaults)(base) + + req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil) + req.Header.Set("X-Custom", "request-value") + + _, err := transport.RoundTrip(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if got := captured.Get("X-Custom"); got != "request-value" { + t.Errorf("X-Custom = %q, want %q (should not overwrite)", got, "request-value") + } + if got := captured.Get("X-Untouched"); got != "from-middleware" { + t.Errorf("X-Untouched = %q, want %q", got, "from-middleware") + } + }) + + t.Run("adds headers when absent", func(t *testing.T) { + defaults := http.Header{ + "Accept": {"application/json"}, + } + + var captured http.Header + base := mockTransport(func(req *http.Request) (*http.Response, error) { + captured = req.Header.Clone() + return okResponse(), nil + }) + + transport := middleware.DefaultHeaders(defaults)(base) + req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil) + + _, err := transport.RoundTrip(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if got := captured.Get("Accept"); got != "application/json" { + t.Errorf("Accept = %q, want %q", got, "application/json") + } + }) +} + +func TestUserAgent(t *testing.T) { + t.Run("sets user agent header", func(t *testing.T) { + var captured http.Header + base := mockTransport(func(req *http.Request) (*http.Response, error) { + captured = req.Header.Clone() + return okResponse(), nil + }) + + transport := middleware.UserAgent("httpx/1.0")(base) + req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil) + + _, err := transport.RoundTrip(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if got := captured.Get("User-Agent"); got != "httpx/1.0" { + t.Errorf("User-Agent = %q, want %q", got, "httpx/1.0") + } + }) + + t.Run("does not overwrite existing user agent", func(t *testing.T) { + var captured http.Header + base := mockTransport(func(req *http.Request) (*http.Response, error) { + captured = req.Header.Clone() + return okResponse(), nil + }) + + transport := middleware.UserAgent("httpx/1.0")(base) + req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil) + req.Header.Set("User-Agent", "custom-agent") + + _, err := transport.RoundTrip(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if got := captured.Get("User-Agent"); got != "custom-agent" { + t.Errorf("User-Agent = %q, want %q", got, "custom-agent") + } + }) +} diff --git a/middleware/logging.go b/middleware/logging.go new file mode 100644 index 0000000..c215c28 --- /dev/null +++ b/middleware/logging.go @@ -0,0 +1,38 @@ +package middleware + +import ( + "log/slog" + "net/http" + "time" +) + +// Logging returns a middleware that logs each request's method, URL, status +// code, duration, and error (if any) using the provided structured logger. +// Successful responses are logged at Info level; errors at Error level. +func Logging(logger *slog.Logger) Middleware { + return func(next http.RoundTripper) http.RoundTripper { + return RoundTripperFunc(func(req *http.Request) (*http.Response, error) { + start := time.Now() + + resp, err := next.RoundTrip(req) + + duration := time.Since(start) + attrs := []slog.Attr{ + slog.String("method", req.Method), + slog.String("url", req.URL.String()), + slog.Duration("duration", duration), + } + + if err != nil { + attrs = append(attrs, slog.String("error", err.Error())) + logger.LogAttrs(req.Context(), slog.LevelError, "request failed", attrs...) + return resp, err + } + + attrs = append(attrs, slog.Int("status", resp.StatusCode)) + logger.LogAttrs(req.Context(), slog.LevelInfo, "request completed", attrs...) + + return resp, nil + }) + } +} diff --git a/middleware/logging_test.go b/middleware/logging_test.go new file mode 100644 index 0000000..87cdb59 --- /dev/null +++ b/middleware/logging_test.go @@ -0,0 +1,124 @@ +package middleware_test + +import ( + "context" + "errors" + "io" + "log/slog" + "net/http" + "strings" + "sync" + "testing" + + "git.codelab.vc/pkg/httpx/middleware" +) + +// captureHandler is a slog.Handler that captures log records for inspection. +type captureHandler struct { + mu sync.Mutex + records []slog.Record +} + +func (h *captureHandler) Enabled(_ context.Context, _ slog.Level) bool { return true } + +func (h *captureHandler) Handle(_ context.Context, r slog.Record) error { + h.mu.Lock() + defer h.mu.Unlock() + h.records = append(h.records, r) + return nil +} + +func (h *captureHandler) WithAttrs(_ []slog.Attr) slog.Handler { return h } +func (h *captureHandler) WithGroup(_ string) slog.Handler { return h } + +func (h *captureHandler) lastRecord() slog.Record { + h.mu.Lock() + defer h.mu.Unlock() + return h.records[len(h.records)-1] +} + +func TestLogging(t *testing.T) { + t.Run("logs method url status duration on success", func(t *testing.T) { + handler := &captureHandler{} + logger := slog.New(handler) + + base := mockTransport(func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("ok")), + Header: make(http.Header), + }, nil + }) + + transport := middleware.Logging(logger)(base) + req, _ := http.NewRequest(http.MethodPost, "http://example.com/api", nil) + resp, err := transport.RoundTrip(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Fatalf("got status %d, want %d", resp.StatusCode, http.StatusOK) + } + + rec := handler.lastRecord() + if rec.Level != slog.LevelInfo { + t.Errorf("got level %v, want %v", rec.Level, slog.LevelInfo) + } + + attrs := map[string]string{} + rec.Attrs(func(a slog.Attr) bool { + attrs[a.Key] = a.Value.String() + return true + }) + + if attrs["method"] != "POST" { + t.Errorf("method = %q, want %q", attrs["method"], "POST") + } + if attrs["url"] != "http://example.com/api" { + t.Errorf("url = %q, want %q", attrs["url"], "http://example.com/api") + } + if _, ok := attrs["status"]; !ok { + t.Error("missing status attribute") + } + if _, ok := attrs["duration"]; !ok { + t.Error("missing duration attribute") + } + }) + + t.Run("logs error on failure", func(t *testing.T) { + handler := &captureHandler{} + logger := slog.New(handler) + + base := mockTransport(func(req *http.Request) (*http.Response, error) { + return nil, errors.New("connection refused") + }) + + transport := middleware.Logging(logger)(base) + req, _ := http.NewRequest(http.MethodGet, "http://example.com/fail", nil) + _, err := transport.RoundTrip(req) + if err == nil { + t.Fatal("expected error, got nil") + } + + rec := handler.lastRecord() + if rec.Level != slog.LevelError { + t.Errorf("got level %v, want %v", rec.Level, slog.LevelError) + } + + attrs := map[string]string{} + rec.Attrs(func(a slog.Attr) bool { + attrs[a.Key] = a.Value.String() + return true + }) + + if attrs["error"] != "connection refused" { + t.Errorf("error = %q, want %q", attrs["error"], "connection refused") + } + if _, ok := attrs["method"]; !ok { + t.Error("missing method attribute") + } + if _, ok := attrs["url"]; !ok { + t.Error("missing url attribute") + } + }) +} diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go new file mode 100644 index 0000000..96e54b8 --- /dev/null +++ b/middleware/middleware_test.go @@ -0,0 +1,115 @@ +package middleware_test + +import ( + "io" + "net/http" + "strings" + "testing" + + "git.codelab.vc/pkg/httpx/middleware" +) + +func mockTransport(fn func(*http.Request) (*http.Response, error)) http.RoundTripper { + return middleware.RoundTripperFunc(fn) +} + +func okResponse() *http.Response { + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("ok")), + Header: make(http.Header), + } +} + +func TestChain(t *testing.T) { + t.Run("applies middlewares in correct order", func(t *testing.T) { + var order []string + + mwA := func(next http.RoundTripper) http.RoundTripper { + return mockTransport(func(req *http.Request) (*http.Response, error) { + order = append(order, "A-before") + resp, err := next.RoundTrip(req) + order = append(order, "A-after") + return resp, err + }) + } + + mwB := func(next http.RoundTripper) http.RoundTripper { + return mockTransport(func(req *http.Request) (*http.Response, error) { + order = append(order, "B-before") + resp, err := next.RoundTrip(req) + order = append(order, "B-after") + return resp, err + }) + } + + mwC := func(next http.RoundTripper) http.RoundTripper { + return mockTransport(func(req *http.Request) (*http.Response, error) { + order = append(order, "C-before") + resp, err := next.RoundTrip(req) + order = append(order, "C-after") + return resp, err + }) + } + + base := mockTransport(func(req *http.Request) (*http.Response, error) { + order = append(order, "base") + return okResponse(), nil + }) + + chained := middleware.Chain(mwA, mwB, mwC)(base) + + req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil) + _, err := chained.RoundTrip(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + expected := []string{"A-before", "B-before", "C-before", "base", "C-after", "B-after", "A-after"} + if len(order) != len(expected) { + t.Fatalf("got %v, want %v", order, expected) + } + for i, v := range expected { + if order[i] != v { + t.Fatalf("order[%d] = %q, want %q", i, order[i], v) + } + } + }) + + t.Run("empty chain returns base transport", func(t *testing.T) { + called := false + base := mockTransport(func(req *http.Request) (*http.Response, error) { + called = true + return okResponse(), nil + }) + + chained := middleware.Chain()(base) + + req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil) + _, err := chained.RoundTrip(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !called { + t.Fatal("base transport was not called") + } + }) +} + +func TestRoundTripperFunc(t *testing.T) { + t.Run("implements RoundTripper", func(t *testing.T) { + fn := middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) { + return okResponse(), nil + }) + + var rt http.RoundTripper = fn + req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil) + resp, err := rt.RoundTrip(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Fatalf("got status %d, want %d", resp.StatusCode, http.StatusOK) + } + }) +} diff --git a/middleware/recovery.go b/middleware/recovery.go new file mode 100644 index 0000000..da54994 --- /dev/null +++ b/middleware/recovery.go @@ -0,0 +1,22 @@ +package middleware + +import ( + "fmt" + "net/http" +) + +// Recovery returns a middleware that recovers from panics in the inner +// RoundTripper chain. A recovered panic is converted to an error wrapping +// the panic value. +func Recovery() Middleware { + return func(next http.RoundTripper) http.RoundTripper { + return RoundTripperFunc(func(req *http.Request) (resp *http.Response, err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("panic recovered in round trip: %v", r) + } + }() + return next.RoundTrip(req) + }) + } +} diff --git a/middleware/recovery_test.go b/middleware/recovery_test.go new file mode 100644 index 0000000..bd0144e --- /dev/null +++ b/middleware/recovery_test.go @@ -0,0 +1,51 @@ +package middleware_test + +import ( + "net/http" + "strings" + "testing" + + "git.codelab.vc/pkg/httpx/middleware" +) + +func TestRecovery(t *testing.T) { + t.Run("recovers from panic and returns error", func(t *testing.T) { + base := mockTransport(func(req *http.Request) (*http.Response, error) { + panic("something went wrong") + }) + + transport := middleware.Recovery()(base) + req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil) + + resp, err := transport.RoundTrip(req) + if err == nil { + t.Fatal("expected error, got nil") + } + if resp != nil { + t.Errorf("expected nil response, got %v", resp) + } + if !strings.Contains(err.Error(), "panic recovered") { + t.Errorf("error = %q, want it to contain %q", err.Error(), "panic recovered") + } + if !strings.Contains(err.Error(), "something went wrong") { + t.Errorf("error = %q, want it to contain %q", err.Error(), "something went wrong") + } + }) + + t.Run("passes through normal responses", func(t *testing.T) { + base := mockTransport(func(req *http.Request) (*http.Response, error) { + return okResponse(), nil + }) + + transport := middleware.Recovery()(base) + req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil) + + resp, err := transport.RoundTrip(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Errorf("got status %d, want %d", resp.StatusCode, http.StatusOK) + } + }) +}