package server_test

import (
	"net/http"
	"net/http/httptest"
	"testing"
	"time"

	"git.codelab.vc/pkg/httpx/server"
)

// TestRateLimit exercises the server.RateLimit middleware: granting the
// configured burst, rejecting with 429 plus a Retry-After header once the
// burst is spent, per-client isolation, keying on X-Forwarded-For, custom
// key functions, and token refill over time.
func TestRateLimit(t *testing.T) {
	okHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	})

	// serve sends GET / through h and returns the recorded response.
	// An empty remoteAddr keeps httptest's default client address; every
	// entry of headers is set on the request.
	serve := func(h http.Handler, remoteAddr string, headers map[string]string) *httptest.ResponseRecorder {
		req := httptest.NewRequest(http.MethodGet, "/", nil)
		if remoteAddr != "" {
			req.RemoteAddr = remoteAddr
		}
		for k, v := range headers {
			req.Header.Set(k, v)
		}
		w := httptest.NewRecorder()
		h.ServeHTTP(w, req)
		return w
	}

	// wantStatus fails the calling (sub)test when w did not record want.
	wantStatus := func(t *testing.T, w *httptest.ResponseRecorder, want int, what string) {
		t.Helper()
		if w.Code != want {
			t.Fatalf("%s: got status %d, want %d", what, w.Code, want)
		}
	}

	t.Run("allows requests within limit", func(t *testing.T) {
		mw := server.RateLimit(
			server.WithRate(100),
			server.WithBurst(10),
		)(okHandler)

		for i := range 10 {
			if w := serve(mw, "1.2.3.4:1234", nil); w.Code != http.StatusOK {
				t.Fatalf("request %d: got status %d, want %d", i, w.Code, http.StatusOK)
			}
		}
	})

	t.Run("rejects when burst exhausted", func(t *testing.T) {
		mw := server.RateLimit(
			server.WithRate(1),
			server.WithBurst(2),
		)(okHandler)

		// The whole burst must be granted...
		for i := range 2 {
			if w := serve(mw, "1.2.3.4:1234", nil); w.Code != http.StatusOK {
				t.Fatalf("burst request %d: got status %d, want %d", i, w.Code, http.StatusOK)
			}
		}

		// ...and the next request rejected with a Retry-After hint.
		w := serve(mw, "1.2.3.4:1234", nil)
		wantStatus(t, w, http.StatusTooManyRequests, "request after burst")
		if w.Header().Get("Retry-After") == "" {
			t.Fatal("expected Retry-After header")
		}
	})

	t.Run("different IPs have independent limits", func(t *testing.T) {
		mw := server.RateLimit(
			server.WithRate(1),
			server.WithBurst(1),
		)(okHandler)

		wantStatus(t, serve(mw, "1.2.3.4:1234", nil), http.StatusOK, "first IP, first request")
		// Prove the first IP really is exhausted; otherwise the second-IP
		// check below would also pass for a limiter that never rejects.
		wantStatus(t, serve(mw, "1.2.3.4:1234", nil), http.StatusTooManyRequests, "first IP, second request")
		wantStatus(t, serve(mw, "5.6.7.8:5678", nil), http.StatusOK, "second IP")
	})

	t.Run("uses X-Forwarded-For", func(t *testing.T) {
		mw := server.RateLimit(
			server.WithRate(1),
			server.WithBurst(1),
		)(okHandler)
		xff := map[string]string{"X-Forwarded-For": "10.0.0.1, 192.168.1.1"}

		// Exhaust the limit for forwarded client 10.0.0.1.
		wantStatus(t, serve(mw, "192.168.1.1:1234", xff), http.StatusOK, "first forwarded request")

		// The same forwarded client arriving via a different proxy address
		// must still be limited. A limiter keyed on RemoteAddr instead of
		// X-Forwarded-For would allow this request.
		wantStatus(t, serve(mw, "192.168.1.2:1234", xff), http.StatusTooManyRequests,
			"same forwarded client via another proxy")
	})

	t.Run("custom key function", func(t *testing.T) {
		mw := server.RateLimit(
			server.WithRate(1),
			server.WithBurst(1),
			server.WithKeyFunc(func(r *http.Request) string {
				return r.Header.Get("X-API-Key")
			}),
		)(okHandler)
		abc := map[string]string{"X-API-Key": "abc"}

		// All requests share httptest's default RemoteAddr, so only the
		// custom key can tell them apart.
		wantStatus(t, serve(mw, "", abc), http.StatusOK, "first request for key abc")
		wantStatus(t, serve(mw, "", abc), http.StatusTooManyRequests, "second request for key abc")
		wantStatus(t, serve(mw, "", map[string]string{"X-API-Key": "xyz"}), http.StatusOK, "first request for key xyz")
	})

	t.Run("tokens refill over time", func(t *testing.T) {
		mw := server.RateLimit(
			server.WithRate(1000), // Very fast refill for test
			server.WithBurst(1),
		)(okHandler)

		// Exhaust the single-token burst.
		wantStatus(t, serve(mw, "1.2.3.4:1234", nil), http.StatusOK, "initial request")

		// At 1000 tokens/s, 5ms refills the bucket several times over.
		time.Sleep(5 * time.Millisecond)

		wantStatus(t, serve(mw, "1.2.3.4:1234", nil), http.StatusOK, "request after refill")
	})
}