Add comprehensive test coverage for server/ package
All checks were successful
CI / test (push) Successful in 30s
All checks were successful
CI / test (push) Successful in 30s
Cover edge cases: statusWriter multi-call/default/unwrap, UUID v4 format and uniqueness, non-string panics, recovery body and log attributes, 4xx log level, default status in logging, request ID propagation, server defaults/options/listen-error/multiple-hooks/logger, router groups with empty prefix/inherited middleware/ordering/path params/ isolation, mount trailing slash, health content-type and POST rejection. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -88,3 +88,79 @@ func TestHealthHandler(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestHealth_MultipleFailingCheckers(t *testing.T) {
|
||||
h := server.HealthHandler(
|
||||
func() error { return errors.New("db down") },
|
||||
func() error { return errors.New("cache down") },
|
||||
)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/readyz", nil)
|
||||
h.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusServiceUnavailable {
|
||||
t.Fatalf("got status %d, want %d", w.Code, http.StatusServiceUnavailable)
|
||||
}
|
||||
|
||||
var resp map[string]any
|
||||
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||
t.Fatalf("decode failed: %v", err)
|
||||
}
|
||||
|
||||
errs, ok := resp["errors"].([]any)
|
||||
if !ok || len(errs) != 2 {
|
||||
t.Fatalf("expected 2 errors, got %v", resp["errors"])
|
||||
}
|
||||
|
||||
errStrs := make(map[string]bool)
|
||||
for _, e := range errs {
|
||||
errStrs[e.(string)] = true
|
||||
}
|
||||
if !errStrs["db down"] || !errStrs["cache down"] {
|
||||
t.Fatalf("expected 'db down' and 'cache down', got %v", errs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealth_LivenessContentType(t *testing.T) {
|
||||
h := server.HealthHandler()
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/healthz", nil)
|
||||
h.ServeHTTP(w, req)
|
||||
|
||||
ct := w.Header().Get("Content-Type")
|
||||
if ct != "application/json" {
|
||||
t.Fatalf("got Content-Type %q, want %q", ct, "application/json")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealth_ReadinessContentType(t *testing.T) {
|
||||
h := server.HealthHandler()
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/readyz", nil)
|
||||
h.ServeHTTP(w, req)
|
||||
|
||||
ct := w.Header().Get("Content-Type")
|
||||
if ct != "application/json" {
|
||||
t.Fatalf("got Content-Type %q, want %q", ct, "application/json")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealth_PostMethodNotAllowed(t *testing.T) {
|
||||
h := server.HealthHandler()
|
||||
|
||||
for _, path := range []string{"/healthz", "/readyz"} {
|
||||
t.Run("POST "+path, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, path, nil)
|
||||
h.ServeHTTP(w, req)
|
||||
|
||||
// ServeMux with "GET /healthz" pattern should reject POST.
|
||||
if w.Code == http.StatusOK {
|
||||
t.Fatalf("POST %s should not return 200, got %d", path, w.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -72,6 +73,96 @@ func TestChain(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestChain_SingleMiddleware(t *testing.T) {
|
||||
var called bool
|
||||
mw := func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
called = true
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
chained := server.Chain(mw)(handler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
chained.ServeHTTP(w, req)
|
||||
|
||||
if !called {
|
||||
t.Fatal("single middleware was not called")
|
||||
}
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStatusWriter_WriteHeaderMultipleCalls(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
w.WriteHeader(http.StatusNotFound) // second call should not change captured status
|
||||
})
|
||||
|
||||
var buf bytes.Buffer
|
||||
logger := slog.New(slog.NewTextHandler(&buf, nil))
|
||||
mw := server.Logging(logger)(handler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if !strings.Contains(buf.String(), "status=201") {
|
||||
t.Fatalf("expected status=201 (first WriteHeader call captured), got %q", buf.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestStatusWriter_WriteDefaultsTo200(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
_, _ = w.Write([]byte("hello")) // Write without WriteHeader
|
||||
})
|
||||
|
||||
var buf bytes.Buffer
|
||||
logger := slog.New(slog.NewTextHandler(&buf, nil))
|
||||
mw := server.Logging(logger)(handler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if !strings.Contains(buf.String(), "status=200") {
|
||||
t.Fatalf("expected status=200 when Write called without WriteHeader, got %q", buf.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestStatusWriter_Unwrap(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
rc := http.NewResponseController(w)
|
||||
if err := rc.Flush(); err != nil {
|
||||
// httptest.ResponseRecorder implements Flusher, so this should succeed
|
||||
// if Unwrap works correctly.
|
||||
http.Error(w, "flush failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Use Logging to wrap in statusWriter
|
||||
var buf bytes.Buffer
|
||||
logger := slog.New(slog.NewTextHandler(&buf, nil))
|
||||
mw := server.Logging(logger)(handler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if w.Code == http.StatusInternalServerError {
|
||||
t.Fatal("Flush failed — Unwrap likely not exposing underlying Flusher")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestID(t *testing.T) {
|
||||
t.Run("generates ID when not present", func(t *testing.T) {
|
||||
var gotID string
|
||||
@@ -125,6 +216,51 @@ func TestRequestID(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestRequestID_UUIDFormat(t *testing.T) {
|
||||
uuidV4Re := regexp.MustCompile(`^[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$`)
|
||||
|
||||
var gotID string
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotID = server.RequestIDFromContext(r.Context())
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
mw := server.RequestID()(handler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if !uuidV4Re.MatchString(gotID) {
|
||||
t.Fatalf("generated ID %q does not match UUID v4 format", gotID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestID_Uniqueness(t *testing.T) {
|
||||
seen := make(map[string]struct{}, 1000)
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
id := server.RequestIDFromContext(r.Context())
|
||||
if _, exists := seen[id]; exists {
|
||||
t.Fatalf("duplicate request ID: %q", id)
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
mw := server.RequestID()(handler)
|
||||
|
||||
for i := 0; i < 1000; i++ {
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
mw.ServeHTTP(w, req)
|
||||
}
|
||||
|
||||
if len(seen) != 1000 {
|
||||
t.Fatalf("expected 1000 unique IDs, got %d", len(seen))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecovery(t *testing.T) {
|
||||
t.Run("recovers from panic and returns 500", func(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
@@ -182,6 +318,81 @@ func TestRecovery(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestRecovery_PanicWithNonString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
value any
|
||||
}{
|
||||
{"integer", 42},
|
||||
{"struct", struct{ X int }{X: 1}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
panic(tt.value)
|
||||
})
|
||||
|
||||
var buf bytes.Buffer
|
||||
logger := slog.New(slog.NewTextHandler(&buf, nil))
|
||||
mw := server.Recovery(server.WithRecoveryLogger(logger))(handler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Fatalf("got status %d, want %d", w.Code, http.StatusInternalServerError)
|
||||
}
|
||||
if !strings.Contains(buf.String(), "panic recovered") {
|
||||
t.Fatalf("expected 'panic recovered' in log, got %q", buf.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecovery_ResponseBody(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
panic("fail")
|
||||
})
|
||||
|
||||
mw := server.Recovery()(handler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
body := strings.TrimSpace(w.Body.String())
|
||||
if body != "Internal Server Error" {
|
||||
t.Fatalf("got body %q, want %q", body, "Internal Server Error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecovery_LogAttributes(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := slog.New(slog.NewTextHandler(&buf, nil))
|
||||
|
||||
// Put RequestID before Recovery so request_id is in context
|
||||
handler := server.RequestID()(
|
||||
server.Recovery(server.WithRecoveryLogger(logger))(
|
||||
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
panic("boom")
|
||||
}),
|
||||
),
|
||||
)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/test", nil)
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
logOutput := buf.String()
|
||||
for _, attr := range []string{"method=", "path=", "request_id="} {
|
||||
if !strings.Contains(logOutput, attr) {
|
||||
t.Fatalf("expected %q in log, got %q", attr, logOutput)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogging(t *testing.T) {
|
||||
t.Run("logs request details", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
@@ -232,3 +443,66 @@ func TestLogging(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestLogging_4xxIsInfoLevel(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := slog.New(slog.NewTextHandler(&buf, nil))
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
})
|
||||
|
||||
mw := server.Logging(logger)(handler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/missing", nil)
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
logOutput := buf.String()
|
||||
if !strings.Contains(logOutput, "level=INFO") {
|
||||
t.Fatalf("expected INFO level for 404, got %q", logOutput)
|
||||
}
|
||||
if strings.Contains(logOutput, "level=ERROR") {
|
||||
t.Fatalf("404 should not be logged as ERROR, got %q", logOutput)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogging_DefaultStatus200(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := slog.New(slog.NewTextHandler(&buf, nil))
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
_, _ = w.Write([]byte("hello"))
|
||||
})
|
||||
|
||||
mw := server.Logging(logger)(handler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if !strings.Contains(buf.String(), "status=200") {
|
||||
t.Fatalf("expected status=200 in log when handler only calls Write, got %q", buf.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogging_IncludesRequestID(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := slog.New(slog.NewTextHandler(&buf, nil))
|
||||
|
||||
handler := server.RequestID()(
|
||||
server.Logging(logger)(
|
||||
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}),
|
||||
),
|
||||
)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if !strings.Contains(buf.String(), "request_id=") {
|
||||
t.Fatalf("expected request_id in log output, got %q", buf.String())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -140,4 +140,197 @@ func TestRouterMount(t *testing.T) {
|
||||
t.Fatalf("got body %q, want %q", body, "info")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("mount with trailing slash", func(t *testing.T) {
|
||||
sub := http.NewServeMux()
|
||||
sub.HandleFunc("GET /data", func(w http.ResponseWriter, _ *http.Request) {
|
||||
_, _ = w.Write([]byte("data"))
|
||||
})
|
||||
|
||||
r := server.NewRouter()
|
||||
r.Mount("/sub/", sub)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/sub/data", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
body, _ := io.ReadAll(w.Body)
|
||||
if string(body) != "data" {
|
||||
t.Fatalf("got body %q, want %q", body, "data")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRouter_PatternWithoutMethod(t *testing.T) {
|
||||
r := server.NewRouter()
|
||||
r.HandleFunc("/static/", func(w http.ResponseWriter, _ *http.Request) {
|
||||
_, _ = w.Write([]byte("static"))
|
||||
})
|
||||
|
||||
for _, method := range []string{http.MethodGet, http.MethodPost} {
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(method, "/static/file.css", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("%s /static/file.css: got status %d, want %d", method, w.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_GroupEmptyPrefix(t *testing.T) {
|
||||
r := server.NewRouter()
|
||||
g := r.Group("")
|
||||
g.HandleFunc("GET /hello", func(w http.ResponseWriter, _ *http.Request) {
|
||||
_, _ = w.Write([]byte("hello"))
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/hello", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
if body := w.Body.String(); body != "hello" {
|
||||
t.Fatalf("got body %q, want %q", body, "hello")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_GroupInheritsMiddleware(t *testing.T) {
|
||||
var order []string
|
||||
|
||||
parentMW := func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
order = append(order, "parent")
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
childMW := func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
order = append(order, "child")
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
r := server.NewRouter()
|
||||
parent := r.Group("/api", parentMW)
|
||||
child := parent.Group("/v1", childMW)
|
||||
child.HandleFunc("GET /items", func(w http.ResponseWriter, _ *http.Request) {
|
||||
order = append(order, "handler")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/items", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
expected := []string{"parent", "child", "handler"}
|
||||
if len(order) != len(expected) {
|
||||
t.Fatalf("got %v, want %v", order, expected)
|
||||
}
|
||||
for i, v := range expected {
|
||||
if order[i] != v {
|
||||
t.Fatalf("order[%d] = %q, want %q", i, order[i], v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_GroupMiddlewareOrder(t *testing.T) {
|
||||
var order []string
|
||||
|
||||
mwA := func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
order = append(order, "A")
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
mwB := func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
order = append(order, "B")
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
r := server.NewRouter()
|
||||
g := r.Group("/api", mwA)
|
||||
sub := g.Group("/v1", mwB)
|
||||
sub.HandleFunc("GET /test", func(w http.ResponseWriter, _ *http.Request) {
|
||||
order = append(order, "handler")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/test", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
// Parent MW (A) should run before child MW (B), then handler.
|
||||
expected := []string{"A", "B", "handler"}
|
||||
if len(order) != len(expected) {
|
||||
t.Fatalf("got %v, want %v", order, expected)
|
||||
}
|
||||
for i, v := range expected {
|
||||
if order[i] != v {
|
||||
t.Fatalf("order[%d] = %q, want %q", i, order[i], v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_PathParamWithGroup(t *testing.T) {
|
||||
r := server.NewRouter()
|
||||
api := r.Group("/api")
|
||||
api.HandleFunc("GET /users/{id}", func(w http.ResponseWriter, req *http.Request) {
|
||||
_, _ = w.Write([]byte("id=" + req.PathValue("id")))
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/users/42", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
if body := w.Body.String(); body != "id=42" {
|
||||
t.Fatalf("got body %q, want %q", body, "id=42")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_MiddlewareNotAppliedToOtherRoutes(t *testing.T) {
|
||||
var mwCalled bool
|
||||
mw := func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mwCalled = true
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
r := server.NewRouter()
|
||||
|
||||
// Add middleware only to /admin group.
|
||||
admin := r.Group("/admin", mw)
|
||||
admin.HandleFunc("GET /dashboard", func(w http.ResponseWriter, _ *http.Request) {
|
||||
_, _ = w.Write([]byte("admin"))
|
||||
})
|
||||
|
||||
// Route outside the group.
|
||||
r.HandleFunc("GET /public", func(w http.ResponseWriter, _ *http.Request) {
|
||||
_, _ = w.Write([]byte("public"))
|
||||
})
|
||||
|
||||
// Request to /public should NOT trigger group middleware.
|
||||
mwCalled = false
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/public", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if mwCalled {
|
||||
t.Fatal("group middleware should not be called for routes outside the group")
|
||||
}
|
||||
if w.Body.String() != "public" {
|
||||
t.Fatalf("got body %q, want %q", w.Body.String(), "public")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
package server_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -117,6 +120,140 @@ func TestServerWithMiddleware(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestServerDefaults(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := slog.New(slog.NewTextHandler(&buf, nil))
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
srv := server.New(handler, append(server.Defaults(logger), server.WithAddr(":0"))...)
|
||||
|
||||
go func() { _ = srv.ListenAndServe() }()
|
||||
waitForAddr(t, srv)
|
||||
|
||||
resp, err := http.Get("http://" + srv.Addr())
|
||||
if err != nil {
|
||||
t.Fatalf("GET failed: %v", err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
// Defaults includes RequestID middleware, so response should have X-Request-Id.
|
||||
if resp.Header.Get("X-Request-Id") == "" {
|
||||
t.Fatal("expected X-Request-Id header from Defaults middleware")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = srv.Shutdown(ctx)
|
||||
}
|
||||
|
||||
func TestServerListenError(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Use an invalid address to trigger a listen error.
|
||||
srv := server.New(handler, server.WithAddr(":-1"))
|
||||
|
||||
err := srv.ListenAndServe()
|
||||
if err == nil {
|
||||
t.Fatal("expected error from invalid address, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerMultipleOnShutdownHooks(t *testing.T) {
|
||||
var calls []int
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
srv := server.New(handler,
|
||||
server.WithAddr(":0"),
|
||||
server.WithOnShutdown(func() { calls = append(calls, 1) }),
|
||||
server.WithOnShutdown(func() { calls = append(calls, 2) }),
|
||||
server.WithOnShutdown(func() { calls = append(calls, 3) }),
|
||||
)
|
||||
|
||||
go func() { _ = srv.ListenAndServe() }()
|
||||
waitForAddr(t, srv)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := srv.Shutdown(ctx); err != nil {
|
||||
t.Fatalf("shutdown failed: %v", err)
|
||||
}
|
||||
|
||||
if len(calls) != 3 {
|
||||
t.Fatalf("expected 3 hooks called, got %d: %v", len(calls), calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerShutdownWithLogger(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := slog.New(slog.NewTextHandler(&buf, nil))
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
srv := server.New(handler,
|
||||
server.WithAddr(":0"),
|
||||
server.WithLogger(logger),
|
||||
)
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() { errCh <- srv.ListenAndServe() }()
|
||||
waitForAddr(t, srv)
|
||||
|
||||
// Send SIGINT to trigger graceful shutdown via ListenAndServe's signal handler.
|
||||
// Instead, use Shutdown directly and check log from server start.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = srv.Shutdown(ctx)
|
||||
|
||||
// The server logs "server started" on ListenAndServe.
|
||||
logOutput := buf.String()
|
||||
if !strings.Contains(logOutput, "server started") {
|
||||
t.Fatalf("expected 'server started' in log, got %q", logOutput)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerOptions(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Verify options don't panic and server starts correctly.
|
||||
srv := server.New(handler,
|
||||
server.WithAddr(":0"),
|
||||
server.WithReadTimeout(5*time.Second),
|
||||
server.WithReadHeaderTimeout(3*time.Second),
|
||||
server.WithWriteTimeout(10*time.Second),
|
||||
server.WithIdleTimeout(60*time.Second),
|
||||
server.WithShutdownTimeout(5*time.Second),
|
||||
)
|
||||
|
||||
go func() { _ = srv.ListenAndServe() }()
|
||||
waitForAddr(t, srv)
|
||||
|
||||
resp, err := http.Get("http://" + srv.Addr())
|
||||
if err != nil {
|
||||
t.Fatalf("GET failed: %v", err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("got status %d, want %d", resp.StatusCode, http.StatusOK)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = srv.Shutdown(ctx)
|
||||
}
|
||||
|
||||
// waitForAddr polls until the server's Addr() is non-empty.
|
||||
func waitForAddr(t *testing.T, srv *server.Server) {
|
||||
t.Helper()
|
||||
|
||||
Reference in New Issue
Block a user