package server_test

import (
	"bytes"
	"log/slog"
	"net/http"
	"net/http/httptest"
	"regexp"
	"slices"
	"strings"
	"testing"

	"git.codelab.vc/pkg/httpx/server"
)

// TestChain verifies that server.Chain nests middlewares so the first
// argument is the outermost wrapper, and that an empty chain passes the
// request straight to the handler.
func TestChain(t *testing.T) {
	t.Run("applies middlewares in correct order", func(t *testing.T) {
		var order []string

		mwA := func(next http.Handler) http.Handler {
			return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				order = append(order, "A-before")
				next.ServeHTTP(w, r)
				order = append(order, "A-after")
			})
		}

		mwB := func(next http.Handler) http.Handler {
			return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				order = append(order, "B-before")
				next.ServeHTTP(w, r)
				order = append(order, "B-after")
			})
		}

		handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
			order = append(order, "handler")
			w.WriteHeader(http.StatusOK)
		})

		chained := server.Chain(mwA, mwB)(handler)

		w := httptest.NewRecorder()
		req := httptest.NewRequest(http.MethodGet, "/", nil)
		chained.ServeHTTP(w, req)

		// A is listed first, so it must wrap B, which in turn wraps the handler.
		expected := []string{"A-before", "B-before", "handler", "B-after", "A-after"}
		if !slices.Equal(order, expected) {
			t.Fatalf("got %v, want %v", order, expected)
		}
	})

	t.Run("empty chain returns handler unchanged", func(t *testing.T) {
		called := false
		handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
			called = true
			w.WriteHeader(http.StatusOK)
		})

		chained := server.Chain()(handler)

		w := httptest.NewRecorder()
		req := httptest.NewRequest(http.MethodGet, "/", nil)
		chained.ServeHTTP(w, req)

		if !called {
			t.Fatal("handler was not called")
		}
	})
}

// TestChain_SingleMiddleware checks that a one-element chain invokes the
// middleware and still lets the handler's status reach the client.
func TestChain_SingleMiddleware(t *testing.T) {
	var called bool
	mw := func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			called = true
			next.ServeHTTP(w, r)
		})
	}

	handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	})

	chained := server.Chain(mw)(handler)

	w := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	chained.ServeHTTP(w, req)

	if !called {
		t.Fatal("single middleware was not called")
	}
	if w.Code != http.StatusOK {
		t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
	}
}

// TestStatusWriter_WriteHeaderMultipleCalls checks, via the Logging output,
// that only the first WriteHeader call determines the logged status.
func TestStatusWriter_WriteHeaderMultipleCalls(t *testing.T) {
	handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusCreated)
		w.WriteHeader(http.StatusNotFound) // second call should not change captured status
	})

	var buf bytes.Buffer
	logger := slog.New(slog.NewTextHandler(&buf, nil))
	mw := server.Logging(logger)(handler)

	w := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	mw.ServeHTTP(w, req)

	if !strings.Contains(buf.String(), "status=201") {
		t.Fatalf("expected status=201 (first WriteHeader call captured), got %q", buf.String())
	}
}

// TestStatusWriter_WriteDefaultsTo200 checks that a handler that only calls
// Write (never WriteHeader) is logged with the implicit 200 status.
func TestStatusWriter_WriteDefaultsTo200(t *testing.T) {
	handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		_, _ = w.Write([]byte("hello")) // Write without WriteHeader
	})

	var buf bytes.Buffer
	logger := slog.New(slog.NewTextHandler(&buf, nil))
	mw := server.Logging(logger)(handler)

	w := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	mw.ServeHTTP(w, req)

	if !strings.Contains(buf.String(), "status=200") {
		t.Fatalf("expected status=200 when Write called without WriteHeader, got %q", buf.String())
	}
}

// TestStatusWriter_Unwrap checks that http.ResponseController can reach the
// underlying ResponseRecorder's Flush through the writer that Logging
// installs.
func TestStatusWriter_Unwrap(t *testing.T) {
	handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		rc := http.NewResponseController(w)
		if err := rc.Flush(); err != nil {
			// httptest.ResponseRecorder implements Flusher, so this should succeed
			// if Unwrap works correctly.
			http.Error(w, "flush failed", http.StatusInternalServerError)
			return
		}
		w.WriteHeader(http.StatusOK)
	})

	// Use Logging to wrap in statusWriter
	var buf bytes.Buffer
	logger := slog.New(slog.NewTextHandler(&buf, nil))
	mw := server.Logging(logger)(handler)

	w := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	mw.ServeHTTP(w, req)

	if w.Code == http.StatusInternalServerError {
		t.Fatal("Flush failed — Unwrap likely not exposing underlying Flusher")
	}
	// A nil error from rc.Flush alone does not prove the flush happened: a
	// wrapper with a no-op Flush would also return nil. The recorder sets
	// Flushed only when its own Flush runs.
	if !w.Flushed {
		t.Fatal("Flush returned nil but did not reach the underlying ResponseRecorder")
	}
}

// TestRequestID covers ID generation, preservation of a client-supplied
// X-Request-Id header, and lookup on a context that has no ID.
func TestRequestID(t *testing.T) {
	t.Run("generates ID when not present", func(t *testing.T) {
		var gotID string
		handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			gotID = server.RequestIDFromContext(r.Context())
			w.WriteHeader(http.StatusOK)
		})

		mw := server.RequestID()(handler)

		w := httptest.NewRecorder()
		req := httptest.NewRequest(http.MethodGet, "/", nil)
		mw.ServeHTTP(w, req)

		if gotID == "" {
			t.Fatal("expected request ID in context, got empty")
		}
		if w.Header().Get("X-Request-Id") != gotID {
			t.Fatalf("response header %q != context ID %q", w.Header().Get("X-Request-Id"), gotID)
		}
		// UUID v4 format is 8-4-4-4-12 hex chars (36 total). Only the length
		// is checked here; TestRequestID_UUIDFormat checks the exact format.
		if len(gotID) != 36 {
			t.Fatalf("expected UUID length 36, got %d: %q", len(gotID), gotID)
		}
	})

	t.Run("preserves existing ID", func(t *testing.T) {
		var gotID string
		handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			gotID = server.RequestIDFromContext(r.Context())
			w.WriteHeader(http.StatusOK)
		})

		mw := server.RequestID()(handler)

		w := httptest.NewRecorder()
		req := httptest.NewRequest(http.MethodGet, "/", nil)
		req.Header.Set("X-Request-Id", "custom-123")
		mw.ServeHTTP(w, req)

		if gotID != "custom-123" {
			t.Fatalf("got ID %q, want %q", gotID, "custom-123")
		}
	})

	t.Run("context without ID returns empty", func(t *testing.T) {
		req := httptest.NewRequest(http.MethodGet, "/", nil)
		if id := server.RequestIDFromContext(req.Context()); id != "" {
			t.Fatalf("expected empty, got %q", id)
		}
	})
}

// TestRequestID_UUIDFormat checks that a generated ID is a lowercase UUID v4:
// the version nibble is 4 and the variant nibble is one of 8, 9, a or b.
func TestRequestID_UUIDFormat(t *testing.T) {
	uuidV4Re := regexp.MustCompile(`^[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$`)

	var gotID string
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		gotID = server.RequestIDFromContext(r.Context())
		w.WriteHeader(http.StatusOK)
	})

	mw := server.RequestID()(handler)
	w := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	mw.ServeHTTP(w, req)

	if !uuidV4Re.MatchString(gotID) {
		t.Fatalf("generated ID %q does not match UUID v4 format", gotID)
	}
}

// TestRequestID_Uniqueness checks that many sequential requests each
// receive a distinct generated ID.
func TestRequestID_Uniqueness(t *testing.T) {
	const total = 1000

	seen := make(map[string]struct{}, total)
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		id := server.RequestIDFromContext(r.Context())
		// ServeHTTP is called directly below, so this runs on the test
		// goroutine and t.Fatalf is safe here.
		if _, exists := seen[id]; exists {
			t.Fatalf("duplicate request ID: %q", id)
		}
		seen[id] = struct{}{}
		w.WriteHeader(http.StatusOK)
	})

	mw := server.RequestID()(handler)
	for i := 0; i < total; i++ {
		w := httptest.NewRecorder()
		req := httptest.NewRequest(http.MethodGet, "/", nil)
		mw.ServeHTTP(w, req)
	}

	if len(seen) != total {
		t.Fatalf("expected %d unique IDs, got %d", total, len(seen))
	}
}

// TestRecovery covers the panic-to-500 conversion, logging through a
// configured logger, and the normal path with no panic.
func TestRecovery(t *testing.T) {
	t.Run("recovers from panic and returns 500", func(t *testing.T) {
		handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
			panic("something went wrong")
		})

		mw := server.Recovery()(handler)

		w := httptest.NewRecorder()
		req := httptest.NewRequest(http.MethodGet, "/", nil)
		mw.ServeHTTP(w, req)

		if w.Code != http.StatusInternalServerError {
			t.Fatalf("got status %d, want %d", w.Code, http.StatusInternalServerError)
		}
	})

	t.Run("logs panic with logger", func(t *testing.T) {
		var buf bytes.Buffer
		logger := slog.New(slog.NewTextHandler(&buf, nil))

		handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
			panic("boom")
		})

		mw := server.Recovery(server.WithRecoveryLogger(logger))(handler)

		w := httptest.NewRecorder()
		req := httptest.NewRequest(http.MethodGet, "/", nil)
		mw.ServeHTTP(w, req)

		if !strings.Contains(buf.String(), "panic recovered") {
			t.Fatalf("expected log to contain 'panic recovered', got %q", buf.String())
		}
		if !strings.Contains(buf.String(), "boom") {
			t.Fatalf("expected log to contain 'boom', got %q", buf.String())
		}
	})

	t.Run("passes through without panic", func(t *testing.T) {
		handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
			w.WriteHeader(http.StatusOK)
			_, _ = w.Write([]byte("ok"))
		})

		mw := server.Recovery()(handler)

		w := httptest.NewRecorder()
		req := httptest.NewRequest(http.MethodGet, "/", nil)
		mw.ServeHTTP(w, req)

		if w.Code != http.StatusOK {
			t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
		}
	})
}

// TestRecovery_PanicWithNonString checks that panics with non-string values
// are also recovered, answered with 500, and logged.
func TestRecovery_PanicWithNonString(t *testing.T) {
	tests := []struct {
		name  string
		value any
	}{
		{"integer", 42},
		{"struct", struct{ X int }{X: 1}},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
				panic(tt.value)
			})

			var buf bytes.Buffer
			logger := slog.New(slog.NewTextHandler(&buf, nil))
			mw := server.Recovery(server.WithRecoveryLogger(logger))(handler)

			w := httptest.NewRecorder()
			req := httptest.NewRequest(http.MethodGet, "/test", nil)
			mw.ServeHTTP(w, req)

			if w.Code != http.StatusInternalServerError {
				t.Fatalf("got status %d, want %d", w.Code, http.StatusInternalServerError)
			}
			if !strings.Contains(buf.String(), "panic recovered") {
				t.Fatalf("expected 'panic recovered' in log, got %q", buf.String())
			}
		})
	}
}

// TestRecovery_ResponseBody checks that the recovered response body is the
// plain "Internal Server Error" text. Surrounding whitespace is ignored.
func TestRecovery_ResponseBody(t *testing.T) {
	handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		panic("fail")
	})

	mw := server.Recovery()(handler)
	w := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	mw.ServeHTTP(w, req)

	body := strings.TrimSpace(w.Body.String())
	if body != "Internal Server Error" {
		t.Fatalf("got body %q, want %q", body, "Internal Server Error")
	}
}

// TestRecovery_LogAttributes checks that the panic log line includes the
// method, path and request_id attributes.
func TestRecovery_LogAttributes(t *testing.T) {
	var buf bytes.Buffer
	logger := slog.New(slog.NewTextHandler(&buf, nil))

	// Put RequestID before Recovery so request_id is in context
	handler := server.RequestID()(
		server.Recovery(server.WithRecoveryLogger(logger))(
			http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
				panic("boom")
			}),
		),
	)

	w := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodPost, "/api/test", nil)
	handler.ServeHTTP(w, req)

	logOutput := buf.String()
	for _, attr := range []string{"method=", "path=", "request_id="} {
		if !strings.Contains(logOutput, attr) {
			t.Fatalf("expected %q in log, got %q", attr, logOutput)
		}
	}
}

// TestLogging checks the completion log line's content and that 5xx
// responses are logged at ERROR level.
func TestLogging(t *testing.T) {
	t.Run("logs request details", func(t *testing.T) {
		var buf bytes.Buffer
		logger := slog.New(slog.NewTextHandler(&buf, nil))

		handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
			w.WriteHeader(http.StatusCreated)
		})

		mw := server.Logging(logger)(handler)

		w := httptest.NewRecorder()
		req := httptest.NewRequest(http.MethodPost, "/api/users", nil)
		mw.ServeHTTP(w, req)

		logOutput := buf.String()
		if !strings.Contains(logOutput, "request completed") {
			t.Fatalf("expected 'request completed' in log, got %q", logOutput)
		}
		if !strings.Contains(logOutput, "POST") {
			t.Fatalf("expected method in log, got %q", logOutput)
		}
		if !strings.Contains(logOutput, "/api/users") {
			t.Fatalf("expected path in log, got %q", logOutput)
		}
		if !strings.Contains(logOutput, "status=201") {
			t.Fatalf("expected status=201 in log, got %q", logOutput)
		}
	})

	t.Run("logs error level for 5xx", func(t *testing.T) {
		var buf bytes.Buffer
		logger := slog.New(slog.NewTextHandler(&buf, nil))

		handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
			w.WriteHeader(http.StatusBadGateway)
		})

		mw := server.Logging(logger)(handler)

		w := httptest.NewRecorder()
		req := httptest.NewRequest(http.MethodGet, "/", nil)
		mw.ServeHTTP(w, req)

		logOutput := buf.String()
		if !strings.Contains(logOutput, "level=ERROR") {
			t.Fatalf("expected ERROR level in log, got %q", logOutput)
		}
	})
}

// TestLogging_4xxIsInfoLevel checks that client errors are logged at INFO
// level and not at ERROR level.
func TestLogging_4xxIsInfoLevel(t *testing.T) {
	var buf bytes.Buffer
	logger := slog.New(slog.NewTextHandler(&buf, nil))

	handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusNotFound)
	})

	mw := server.Logging(logger)(handler)
	w := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/missing", nil)
	mw.ServeHTTP(w, req)

	logOutput := buf.String()
	if !strings.Contains(logOutput, "level=INFO") {
		t.Fatalf("expected INFO level for 404, got %q", logOutput)
	}
	if strings.Contains(logOutput, "level=ERROR") {
		t.Fatalf("404 should not be logged as ERROR, got %q", logOutput)
	}
}

// TestLogging_DefaultStatus200 checks, end to end, a handler that only calls
// Write. The log must show status=200, and the client must still get the
// 200 status and the written body through the Logging wrapper.
func TestLogging_DefaultStatus200(t *testing.T) {
	var buf bytes.Buffer
	logger := slog.New(slog.NewTextHandler(&buf, nil))

	handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		_, _ = w.Write([]byte("hello"))
	})

	mw := server.Logging(logger)(handler)
	w := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	mw.ServeHTTP(w, req)

	if !strings.Contains(buf.String(), "status=200") {
		t.Fatalf("expected status=200 in log when handler only calls Write, got %q", buf.String())
	}
	// These two checks are what set this test apart from
	// TestStatusWriter_WriteDefaultsTo200: the wrapper must forward the write.
	if w.Code != http.StatusOK {
		t.Fatalf("client got status %d, want %d", w.Code, http.StatusOK)
	}
	if got := w.Body.String(); got != "hello" {
		t.Fatalf("client got body %q, want %q", got, "hello")
	}
}

// TestLogging_IncludesRequestID checks that, when RequestID wraps Logging,
// the completion log line includes the request_id attribute.
func TestLogging_IncludesRequestID(t *testing.T) {
	var buf bytes.Buffer
	logger := slog.New(slog.NewTextHandler(&buf, nil))

	handler := server.RequestID()(
		server.Logging(logger)(
			http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
				w.WriteHeader(http.StatusOK)
			}),
		),
	)

	w := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	handler.ServeHTTP(w, req)

	if !strings.Contains(buf.String(), "request_id=") {
		t.Fatalf("expected request_id in log output, got %q", buf.String())
	}
}