package server_test import ( "net/http" "net/http/httptest" "testing" "time" "git.codelab.vc/pkg/httpx/server" ) func TestTimeout(t *testing.T) { t.Run("handler completes within timeout", func(t *testing.T) { handler := server.Timeout(1 * time.Second)( http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) w.Write([]byte("ok")) }), ) w := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/", nil) handler.ServeHTTP(w, req) if w.Code != http.StatusOK { t.Fatalf("got status %d, want %d", w.Code, http.StatusOK) } }) t.Run("handler exceeds timeout returns 503", func(t *testing.T) { handler := server.Timeout(10 * time.Millisecond)( http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { select { case <-time.After(1 * time.Second): case <-r.Context().Done(): } w.WriteHeader(http.StatusOK) }), ) w := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/", nil) handler.ServeHTTP(w, req) if w.Code != http.StatusServiceUnavailable { t.Fatalf("got status %d, want %d", w.Code, http.StatusServiceUnavailable) } }) }