package server_test

import (
	"net/http"
	"net/http/httptest"
	"testing"
	"time"

	"git.codelab.vc/pkg/httpx/server"
)

// TestRateLimit exercises the server.RateLimit middleware: burst handling,
// per-key isolation, X-Forwarded-For trust rules, custom key functions and
// token refill. Setup requests are asserted as well, so a broken precondition
// fails at the step that broke instead of surfacing as a misleading failure
// further down (or, worse, letting a no-op limiter pass).
func TestRateLimit(t *testing.T) {
	okHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	})

	// peer is the client address used by subtests that key on RemoteAddr.
	const peer = "1.2.3.4:1234"

	// do sends GET / through h and returns the recorded response. A non-empty
	// remoteAddr replaces httptest's default peer address; each entry of hdr
	// is set as a request header.
	do := func(h http.Handler, remoteAddr string, hdr map[string]string) *httptest.ResponseRecorder {
		w := httptest.NewRecorder()
		req := httptest.NewRequest(http.MethodGet, "/", nil)
		if remoteAddr != "" {
			req.RemoteAddr = remoteAddr
		}
		for k, v := range hdr {
			req.Header.Set(k, v)
		}
		h.ServeHTTP(w, req)
		return w
	}

	// expect fails the calling (sub)test immediately when got != want,
	// labelling the failure with the step being checked.
	expect := func(t *testing.T, step string, got, want int) {
		t.Helper()
		if got != want {
			t.Fatalf("%s: got status %d, want %d", step, got, want)
		}
	}

	t.Run("allows requests within limit", func(t *testing.T) {
		mw := server.RateLimit(
			server.WithRate(100),
			server.WithBurst(10),
		)(okHandler)

		for i := range 10 {
			if code := do(mw, peer, nil).Code; code != http.StatusOK {
				t.Fatalf("request %d: got status %d, want %d", i, code, http.StatusOK)
			}
		}
	})

	t.Run("rejects when burst exhausted", func(t *testing.T) {
		mw := server.RateLimit(
			server.WithRate(1),
			server.WithBurst(2),
		)(okHandler)

		// Exhaust burst; each of these must still be admitted.
		for range 2 {
			expect(t, "burst request", do(mw, peer, nil).Code, http.StatusOK)
		}

		// Next request should be rejected and tell the client when to retry.
		w := do(mw, peer, nil)
		expect(t, "request after burst", w.Code, http.StatusTooManyRequests)
		if w.Header().Get("Retry-After") == "" {
			t.Fatal("expected Retry-After header")
		}
	})

	t.Run("different IPs have independent limits", func(t *testing.T) {
		mw := server.RateLimit(
			server.WithRate(1),
			server.WithBurst(1),
		)(okHandler)

		// First IP exhausts its limit — and must actually be limited, otherwise
		// the check below would also pass against a limiter that never limits.
		expect(t, "first IP, first request", do(mw, peer, nil).Code, http.StatusOK)
		expect(t, "first IP, second request", do(mw, peer, nil).Code, http.StatusTooManyRequests)

		// Second IP should still be allowed.
		expect(t, "second IP", do(mw, "5.6.7.8:5678", nil).Code, http.StatusOK)
	})

	t.Run("ignores X-Forwarded-For without trusted proxies", func(t *testing.T) {
		// By default the limiter keys on RemoteAddr only. A spoofed,
		// per-request X-Forwarded-For must not let a single peer bypass the
		// limit by minting a fresh bucket each time.
		mw := server.RateLimit(
			server.WithRate(1),
			server.WithBurst(1),
		)(okHandler)

		send := func(xff string) int {
			return do(mw, "192.168.1.1:1234", map[string]string{"X-Forwarded-For": xff}).Code
		}

		if code := send("10.0.0.1"); code != http.StatusOK {
			t.Fatalf("first request: got %d, want %d", code, http.StatusOK)
		}
		// Different spoofed XFF, same peer — must still be limited.
		if code := send("10.0.0.2"); code != http.StatusTooManyRequests {
			t.Fatalf("spoofed XFF bypassed limit: got %d, want %d", code, http.StatusTooManyRequests)
		}
	})

	t.Run("honors X-Forwarded-For behind trusted proxy", func(t *testing.T) {
		mw := server.RateLimit(
			server.WithRate(1),
			server.WithBurst(1),
			server.WithTrustedProxies("192.168.0.0/16"),
		)(okHandler)

		send := func(xff string) int {
			// 192.168.1.1 falls inside the trusted 192.168.0.0/16 range.
			return do(mw, "192.168.1.1:1234", map[string]string{"X-Forwarded-For": xff}).Code
		}

		// Real client 10.0.0.1 (left-most), proxy hop 192.168.1.1 (right-most).
		if code := send("10.0.0.1, 192.168.1.1"); code != http.StatusOK {
			t.Fatalf("first request: got %d, want %d", code, http.StatusOK)
		}
		if code := send("10.0.0.1, 192.168.1.1"); code != http.StatusTooManyRequests {
			t.Fatalf("same client not limited: got %d, want %d", code, http.StatusTooManyRequests)
		}
		// A different real client through the same proxy is independent.
		if code := send("10.0.0.2, 192.168.1.1"); code != http.StatusOK {
			t.Fatalf("different client should be allowed: got %d, want %d", code, http.StatusOK)
		}
	})

	t.Run("custom key function", func(t *testing.T) {
		mw := server.RateLimit(
			server.WithRate(1),
			server.WithBurst(1),
			server.WithKeyFunc(func(r *http.Request) string {
				return r.Header.Get("X-API-Key")
			}),
		)(okHandler)

		withKey := func(key string) map[string]string {
			return map[string]string{"X-API-Key": key}
		}

		// Exhaust key "abc".
		expect(t, `first "abc" request`, do(mw, "", withKey("abc")).Code, http.StatusOK)

		// Same key should be rate limited.
		expect(t, `second "abc" request`, do(mw, "", withKey("abc")).Code, http.StatusTooManyRequests)

		// Different key should be allowed.
		expect(t, `first "xyz" request`, do(mw, "", withKey("xyz")).Code, http.StatusOK)
	})

	t.Run("tokens refill over time", func(t *testing.T) {
		mw := server.RateLimit(
			server.WithRate(1000), // Very fast refill for test
			server.WithBurst(1),
		)(okHandler)

		// Exhaust burst. An immediate 429 is deliberately not asserted here:
		// at 1000 tokens/s a slow runner could refill before a second call.
		expect(t, "initial request", do(mw, peer, nil).Code, http.StatusOK)

		// Wait a bit for refill (one token per millisecond at this rate).
		time.Sleep(5 * time.Millisecond)

		expect(t, "request after refill", do(mw, peer, nil).Code, http.StatusOK)
	})
}