Files
httpx/server/middleware_test.go
Aleksey Shakhmatov 7fae6247d5
All checks were successful
CI / test (push) Successful in 30s
Add comprehensive test coverage for server/ package
Cover edge cases: statusWriter multi-call/default/unwrap, UUID v4 format
and uniqueness, non-string panics, recovery body and log attributes,
4xx log level, default status in logging, request ID propagation,
server defaults/options/listen-error/multiple-hooks/logger, router
groups with empty prefix/inherited middleware/ordering/path params/
isolation, mount trailing slash, health content-type and POST rejection.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-21 13:55:22 +03:00

509 lines
14 KiB
Go

package server_test
import (
"bytes"
"log/slog"
"net/http"
"net/http/httptest"
"regexp"
"strings"
"testing"
"git.codelab.vc/pkg/httpx/server"
)
// TestChain verifies that Chain composes middlewares so that its first
// argument becomes the outermost wrapper, and that an empty Chain simply
// passes requests through to the wrapped handler.
func TestChain(t *testing.T) {
	t.Run("applies middlewares in correct order", func(t *testing.T) {
		var trace []string
		// tracing builds a middleware that records entry and exit under label.
		tracing := func(label string) func(http.Handler) http.Handler {
			return func(next http.Handler) http.Handler {
				return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
					trace = append(trace, label+"-before")
					next.ServeHTTP(w, r)
					trace = append(trace, label+"-after")
				})
			}
		}
		inner := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
			trace = append(trace, "handler")
			w.WriteHeader(http.StatusOK)
		})

		server.Chain(tracing("A"), tracing("B"))(inner).ServeHTTP(
			httptest.NewRecorder(),
			httptest.NewRequest(http.MethodGet, "/", nil),
		)

		want := []string{"A-before", "B-before", "handler", "B-after", "A-after"}
		if len(trace) != len(want) {
			t.Fatalf("got %v, want %v", trace, want)
		}
		for i := range want {
			if trace[i] != want[i] {
				t.Fatalf("order[%d] = %q, want %q", i, trace[i], want[i])
			}
		}
	})

	t.Run("empty chain returns handler unchanged", func(t *testing.T) {
		invoked := false
		passthrough := server.Chain()
		passthrough(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
			invoked = true
			w.WriteHeader(http.StatusOK)
		})).ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/", nil))

		if !invoked {
			t.Fatal("handler was not called")
		}
	})
}
// TestChain_SingleMiddleware checks that a one-element Chain runs its
// middleware and still reaches the wrapped handler (observed via its 200).
func TestChain_SingleMiddleware(t *testing.T) {
	middlewareRan := false
	marker := func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			middlewareRan = true
			next.ServeHTTP(w, r)
		})
	}
	okHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	})

	rec := httptest.NewRecorder()
	server.Chain(marker)(okHandler).ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))

	switch {
	case !middlewareRan:
		t.Fatal("single middleware was not called")
	case rec.Code != http.StatusOK:
		t.Fatalf("got status %d, want %d", rec.Code, http.StatusOK)
	}
}
// TestStatusWriter_WriteHeaderMultipleCalls ensures that when a handler calls
// WriteHeader twice, the Logging middleware reports the status from the first
// call.
func TestStatusWriter_WriteHeaderMultipleCalls(t *testing.T) {
	var logs bytes.Buffer
	wrapped := server.Logging(slog.New(slog.NewTextHandler(&logs, nil)))(
		http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
			w.WriteHeader(http.StatusCreated)
			// The second call must not overwrite the status already captured.
			w.WriteHeader(http.StatusNotFound)
		}),
	)
	wrapped.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/", nil))

	if out := logs.String(); !strings.Contains(out, "status=201") {
		t.Fatalf("expected status=201 (first WriteHeader call captured), got %q", out)
	}
}
// TestStatusWriter_WriteDefaultsTo200 checks that a handler which only calls
// Write (never WriteHeader) is logged with status 200.
func TestStatusWriter_WriteDefaultsTo200(t *testing.T) {
	var logBuf bytes.Buffer
	logged := server.Logging(slog.New(slog.NewTextHandler(&logBuf, nil)))
	rec := httptest.NewRecorder()
	logged(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		// Deliberately no WriteHeader before the body write.
		_, _ = w.Write([]byte("hello"))
	})).ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))

	got := logBuf.String()
	if !strings.Contains(got, "status=200") {
		t.Fatalf("expected status=200 when Write called without WriteHeader, got %q", got)
	}
}
// TestStatusWriter_Unwrap checks that the writer installed by Logging lets
// http.ResponseController reach the underlying writer's Flush.
//
// httptest.ResponseRecorder supports flushing and records it in its Flushed
// field. The test therefore asserts two things: ResponseController.Flush
// returns no error, and the flush actually reached the recorder. The second
// check catches a wrapper whose Flush silently does nothing.
func TestStatusWriter_Unwrap(t *testing.T) {
	var flushErr error
	handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		flushErr = http.NewResponseController(w).Flush()
	})
	// Logging wraps the ResponseWriter in its statusWriter.
	var buf bytes.Buffer
	logger := slog.New(slog.NewTextHandler(&buf, nil))
	mw := server.Logging(logger)(handler)
	w := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	mw.ServeHTTP(w, req)
	if flushErr != nil {
		t.Fatalf("Flush through statusWriter failed: %v (Unwrap likely not exposing underlying Flusher)", flushErr)
	}
	if !w.Flushed {
		t.Fatal("Flush did not reach the underlying ResponseRecorder")
	}
}
// TestRequestID covers the RequestID middleware:
//   - generating an ID and echoing it in the X-Request-Id response header;
//   - propagating a client-supplied X-Request-Id into the context;
//   - RequestIDFromContext returning "" for a context the middleware never saw.
func TestRequestID(t *testing.T) {
	t.Run("generates ID when not present", func(t *testing.T) {
		var ctxID string
		rec := httptest.NewRecorder()
		server.RequestID()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			ctxID = server.RequestIDFromContext(r.Context())
			w.WriteHeader(http.StatusOK)
		})).ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))

		if ctxID == "" {
			t.Fatal("expected request ID in context, got empty")
		}
		if hdr := rec.Header().Get("X-Request-Id"); hdr != ctxID {
			t.Fatalf("response header %q != context ID %q", hdr, ctxID)
		}
		// Only the canonical UUID string length (32 hex digits + 4 dashes) is
		// checked here; TestRequestID_UUIDFormat validates the full layout.
		if n := len(ctxID); n != 36 {
			t.Fatalf("expected UUID length 36, got %d: %q", n, ctxID)
		}
	})

	t.Run("preserves existing ID", func(t *testing.T) {
		const incoming = "custom-123"
		var ctxID string
		req := httptest.NewRequest(http.MethodGet, "/", nil)
		req.Header.Set("X-Request-Id", incoming)
		server.RequestID()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			ctxID = server.RequestIDFromContext(r.Context())
			w.WriteHeader(http.StatusOK)
		})).ServeHTTP(httptest.NewRecorder(), req)

		if ctxID != incoming {
			t.Fatalf("got ID %q, want %q", ctxID, incoming)
		}
	})

	t.Run("context without ID returns empty", func(t *testing.T) {
		ctx := httptest.NewRequest(http.MethodGet, "/", nil).Context()
		if id := server.RequestIDFromContext(ctx); id != "" {
			t.Fatalf("expected empty, got %q", id)
		}
	})
}
// TestRequestID_UUIDFormat checks that a generated ID is a lowercase
// version-4 UUID. The regexp requires version nibble 4 and variant bits 10xx
// (first hex digit of the fourth group in [89ab]).
func TestRequestID_UUIDFormat(t *testing.T) {
	var generated string
	server.RequestID()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		generated = server.RequestIDFromContext(r.Context())
		w.WriteHeader(http.StatusOK)
	})).ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/", nil))

	pattern := regexp.MustCompile(`^[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$`)
	if !pattern.MatchString(generated) {
		t.Fatalf("generated ID %q does not match UUID v4 format", generated)
	}
}
// TestRequestID_Uniqueness sends many requests through RequestID. It fails on
// the first request whose context carries an empty or already-seen ID.
// The handler runs synchronously on the test goroutine, so calling
// t.Fatal(f) from inside it is safe.
func TestRequestID_Uniqueness(t *testing.T) {
	const requests = 1000
	seen := make(map[string]struct{}, requests)
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		id := server.RequestIDFromContext(r.Context())
		// Without this check, a missing ID would be stored as "" and then
		// misreported as a duplicate on the second request.
		if id == "" {
			t.Fatal("expected request ID in context, got empty")
		}
		if _, exists := seen[id]; exists {
			t.Fatalf("duplicate request ID: %q", id)
		}
		seen[id] = struct{}{}
		w.WriteHeader(http.StatusOK)
	})
	mw := server.RequestID()(handler)
	for i := 0; i < requests; i++ {
		mw.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/", nil))
	}
	if len(seen) != requests {
		t.Fatalf("expected %d unique IDs, got %d", requests, len(seen))
	}
}
// TestRecovery exercises the Recovery middleware's basic paths:
//   - turning a panic into a 500 response;
//   - logging the recovered panic through a supplied logger;
//   - leaving non-panicking handlers alone.
func TestRecovery(t *testing.T) {
	// get builds a fresh GET / request for each subtest.
	get := func() *http.Request { return httptest.NewRequest(http.MethodGet, "/", nil) }

	t.Run("recovers from panic and returns 500", func(t *testing.T) {
		rec := httptest.NewRecorder()
		server.Recovery()(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
			panic("something went wrong")
		})).ServeHTTP(rec, get())

		if rec.Code != http.StatusInternalServerError {
			t.Fatalf("got status %d, want %d", rec.Code, http.StatusInternalServerError)
		}
	})

	t.Run("logs panic with logger", func(t *testing.T) {
		var out bytes.Buffer
		guarded := server.Recovery(server.WithRecoveryLogger(slog.New(slog.NewTextHandler(&out, nil))))
		guarded(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
			panic("boom")
		})).ServeHTTP(httptest.NewRecorder(), get())

		// Both the fixed log message and the panic value must appear.
		for _, want := range []string{"panic recovered", "boom"} {
			if !strings.Contains(out.String(), want) {
				t.Fatalf("expected log to contain '%s', got %q", want, out.String())
			}
		}
	})

	t.Run("passes through without panic", func(t *testing.T) {
		rec := httptest.NewRecorder()
		server.Recovery()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
			w.WriteHeader(http.StatusOK)
			_, _ = w.Write([]byte("ok"))
		})).ServeHTTP(rec, get())

		if rec.Code != http.StatusOK {
			t.Fatalf("got status %d, want %d", rec.Code, http.StatusOK)
		}
	})
}
// TestRecovery_PanicWithNonString checks that panics carrying non-string
// values (an int and an anonymous struct) are recovered, answered with 500,
// and logged.
func TestRecovery_PanicWithNonString(t *testing.T) {
	for _, tc := range []struct {
		label string
		val   any
	}{
		{label: "integer", val: 42},
		{label: "struct", val: struct{ X int }{X: 1}},
	} {
		panicValue := tc.val
		t.Run(tc.label, func(t *testing.T) {
			var logBuf bytes.Buffer
			logger := slog.New(slog.NewTextHandler(&logBuf, nil))
			rec := httptest.NewRecorder()
			server.Recovery(server.WithRecoveryLogger(logger))(
				http.HandlerFunc(func(http.ResponseWriter, *http.Request) { panic(panicValue) }),
			).ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/test", nil))

			if rec.Code != http.StatusInternalServerError {
				t.Fatalf("got status %d, want %d", rec.Code, http.StatusInternalServerError)
			}
			if logged := logBuf.String(); !strings.Contains(logged, "panic recovered") {
				t.Fatalf("expected 'panic recovered' in log, got %q", logged)
			}
		})
	}
}
// TestRecovery_ResponseBody checks the body Recovery writes after a panic:
// the plain "Internal Server Error" text, ignoring surrounding whitespace.
func TestRecovery_ResponseBody(t *testing.T) {
	const wantBody = "Internal Server Error"
	rec := httptest.NewRecorder()
	server.Recovery()(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
		panic("fail")
	})).ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))

	if got := strings.TrimSpace(rec.Body.String()); got != wantBody {
		t.Fatalf("got body %q, want %q", got, wantBody)
	}
}
// TestRecovery_LogAttributes checks that Recovery's panic log carries the
// request method, path and request ID with their actual values.
//
// The assertions match full key=value pairs rather than bare keys. slog's
// TextHandler renders an empty value as key="", which a bare "key=" check
// would wrongly accept.
func TestRecovery_LogAttributes(t *testing.T) {
	var buf bytes.Buffer
	logger := slog.New(slog.NewTextHandler(&buf, nil))
	var requestID string
	// RequestID wraps Recovery, so the ID is already in the context when the
	// panic is logged.
	handler := server.RequestID()(
		server.Recovery(server.WithRecoveryLogger(logger))(
			http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				requestID = server.RequestIDFromContext(r.Context())
				panic("boom")
			}),
		),
	)
	w := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodPost, "/api/test", nil)
	handler.ServeHTTP(w, req)
	if requestID == "" {
		t.Fatal("expected request ID in context, got empty")
	}
	logOutput := buf.String()
	for _, attr := range []string{"method=POST", "path=/api/test", "request_id=" + requestID} {
		if !strings.Contains(logOutput, attr) {
			t.Fatalf("expected %q in log, got %q", attr, logOutput)
		}
	}
}
// TestLogging checks the Logging middleware's completion record. For a normal
// request it must include the message, method, path and status. For a 5xx
// response it must be emitted at ERROR level.
func TestLogging(t *testing.T) {
	// serve runs a handler that only sets status, wrapped by Logging, and
	// returns the captured log text.
	serve := func(method, target string, status int) string {
		var buf bytes.Buffer
		mw := server.Logging(slog.New(slog.NewTextHandler(&buf, nil)))(
			http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
				w.WriteHeader(status)
			}),
		)
		mw.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(method, target, nil))
		return buf.String()
	}

	t.Run("logs request details", func(t *testing.T) {
		logOutput := serve(http.MethodPost, "/api/users", http.StatusCreated)
		checks := []struct{ needle, failFmt string }{
			{"request completed", "expected 'request completed' in log, got %q"},
			{"POST", "expected method in log, got %q"},
			{"/api/users", "expected path in log, got %q"},
			{"status=201", "expected status=201 in log, got %q"},
		}
		for _, c := range checks {
			if !strings.Contains(logOutput, c.needle) {
				t.Fatalf(c.failFmt, logOutput)
			}
		}
	})

	t.Run("logs error level for 5xx", func(t *testing.T) {
		logOutput := serve(http.MethodGet, "/", http.StatusBadGateway)
		if !strings.Contains(logOutput, "level=ERROR") {
			t.Fatalf("expected ERROR level in log, got %q", logOutput)
		}
	})
}
// TestLogging_4xxIsInfoLevel checks that client errors (here a 404) are logged
// at INFO level and never at ERROR.
func TestLogging_4xxIsInfoLevel(t *testing.T) {
	var sink bytes.Buffer
	notFound := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusNotFound)
	})
	server.Logging(slog.New(slog.NewTextHandler(&sink, nil)))(notFound).
		ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/missing", nil))

	out := sink.String()
	switch {
	case !strings.Contains(out, "level=INFO"):
		t.Fatalf("expected INFO level for 404, got %q", out)
	case strings.Contains(out, "level=ERROR"):
		t.Fatalf("404 should not be logged as ERROR, got %q", out)
	}
}
// TestLogging_DefaultStatus200 checks that Logging records status 200 for a
// handler that writes a body without ever calling WriteHeader.
func TestLogging_DefaultStatus200(t *testing.T) {
	var sink bytes.Buffer
	bodyOnly := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		_, _ = w.Write([]byte("hello"))
	})
	wrapped := server.Logging(slog.New(slog.NewTextHandler(&sink, nil)))(bodyOnly)
	wrapped.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/", nil))

	if out := sink.String(); !strings.Contains(out, "status=200") {
		t.Fatalf("expected status=200 in log when handler only calls Write, got %q", out)
	}
}
// TestLogging_IncludesRequestID checks that, when RequestID wraps Logging, the
// completion log carries the same request ID the handler saw in its context.
//
// The assertion matches the full request_id=<id> pair. slog's TextHandler
// renders an empty value as request_id="", which a bare "request_id=" check
// would wrongly accept.
func TestLogging_IncludesRequestID(t *testing.T) {
	var buf bytes.Buffer
	logger := slog.New(slog.NewTextHandler(&buf, nil))
	var requestID string
	handler := server.RequestID()(
		server.Logging(logger)(
			http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				requestID = server.RequestIDFromContext(r.Context())
				w.WriteHeader(http.StatusOK)
			}),
		),
	)
	w := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	handler.ServeHTTP(w, req)
	if requestID == "" {
		t.Fatal("expected request ID in context, got empty")
	}
	if want := "request_id=" + requestID; !strings.Contains(buf.String(), want) {
		t.Fatalf("expected %q in log output, got %q", want, buf.String())
	}
}