Key on RemoteAddr by default; honor X-Forwarded-For only when the peer is a configured trusted proxy (WithTrustedProxies), walking the header right-to-left and keying on the first untrusted hop. This closes a trivial rate-limit bypass via spoofed headers and the matching DoS from unbounded bucket creation. Add WithMaxKeys, with opportunistic eviction of idle (fully refilled) buckets, to bound memory. Drop the hand-rolled indexOf in favor of the standard library.
203 lines
5.6 KiB
Go
203 lines
5.6 KiB
Go
package server_test
|
|
|
|
import (
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
"time"
|
|
|
|
"git.codelab.vc/pkg/httpx/server"
|
|
)
|
|
|
|
func TestRateLimit(t *testing.T) {
|
|
okHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
t.Run("allows requests within limit", func(t *testing.T) {
|
|
mw := server.RateLimit(
|
|
server.WithRate(100),
|
|
server.WithBurst(10),
|
|
)(okHandler)
|
|
|
|
for i := range 10 {
|
|
w := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.RemoteAddr = "1.2.3.4:1234"
|
|
mw.ServeHTTP(w, req)
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Fatalf("request %d: got status %d, want %d", i, w.Code, http.StatusOK)
|
|
}
|
|
}
|
|
})
|
|
|
|
t.Run("rejects when burst exhausted", func(t *testing.T) {
|
|
mw := server.RateLimit(
|
|
server.WithRate(1),
|
|
server.WithBurst(2),
|
|
)(okHandler)
|
|
|
|
// Exhaust burst.
|
|
for range 2 {
|
|
w := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.RemoteAddr = "1.2.3.4:1234"
|
|
mw.ServeHTTP(w, req)
|
|
}
|
|
|
|
// Next request should be rejected.
|
|
w := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.RemoteAddr = "1.2.3.4:1234"
|
|
mw.ServeHTTP(w, req)
|
|
|
|
if w.Code != http.StatusTooManyRequests {
|
|
t.Fatalf("got status %d, want %d", w.Code, http.StatusTooManyRequests)
|
|
}
|
|
if w.Header().Get("Retry-After") == "" {
|
|
t.Fatal("expected Retry-After header")
|
|
}
|
|
})
|
|
|
|
t.Run("different IPs have independent limits", func(t *testing.T) {
|
|
mw := server.RateLimit(
|
|
server.WithRate(1),
|
|
server.WithBurst(1),
|
|
)(okHandler)
|
|
|
|
// First IP exhausts its limit.
|
|
w := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.RemoteAddr = "1.2.3.4:1234"
|
|
mw.ServeHTTP(w, req)
|
|
|
|
// Second IP should still be allowed.
|
|
w = httptest.NewRecorder()
|
|
req = httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.RemoteAddr = "5.6.7.8:5678"
|
|
mw.ServeHTTP(w, req)
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
|
|
}
|
|
})
|
|
|
|
t.Run("ignores X-Forwarded-For without trusted proxies", func(t *testing.T) {
|
|
// By default the limiter keys on RemoteAddr only. A spoofed,
|
|
// per-request X-Forwarded-For must not let a single peer bypass the
|
|
// limit by minting a fresh bucket each time.
|
|
mw := server.RateLimit(
|
|
server.WithRate(1),
|
|
server.WithBurst(1),
|
|
)(okHandler)
|
|
|
|
send := func(xff string) int {
|
|
w := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.Header.Set("X-Forwarded-For", xff)
|
|
req.RemoteAddr = "192.168.1.1:1234"
|
|
mw.ServeHTTP(w, req)
|
|
return w.Code
|
|
}
|
|
|
|
if code := send("10.0.0.1"); code != http.StatusOK {
|
|
t.Fatalf("first request: got %d, want %d", code, http.StatusOK)
|
|
}
|
|
// Different spoofed XFF, same peer — must still be limited.
|
|
if code := send("10.0.0.2"); code != http.StatusTooManyRequests {
|
|
t.Fatalf("spoofed XFF bypassed limit: got %d, want %d", code, http.StatusTooManyRequests)
|
|
}
|
|
})
|
|
|
|
t.Run("honors X-Forwarded-For behind trusted proxy", func(t *testing.T) {
|
|
mw := server.RateLimit(
|
|
server.WithRate(1),
|
|
server.WithBurst(1),
|
|
server.WithTrustedProxies("192.168.0.0/16"),
|
|
)(okHandler)
|
|
|
|
send := func(xff string) int {
|
|
w := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.Header.Set("X-Forwarded-For", xff)
|
|
req.RemoteAddr = "192.168.1.1:1234" // trusted proxy
|
|
mw.ServeHTTP(w, req)
|
|
return w.Code
|
|
}
|
|
|
|
// Real client 10.0.0.1 (left-most), proxy hop 192.168.1.1 (right-most).
|
|
if code := send("10.0.0.1, 192.168.1.1"); code != http.StatusOK {
|
|
t.Fatalf("first request: got %d, want %d", code, http.StatusOK)
|
|
}
|
|
if code := send("10.0.0.1, 192.168.1.1"); code != http.StatusTooManyRequests {
|
|
t.Fatalf("same client not limited: got %d, want %d", code, http.StatusTooManyRequests)
|
|
}
|
|
// A different real client through the same proxy is independent.
|
|
if code := send("10.0.0.2, 192.168.1.1"); code != http.StatusOK {
|
|
t.Fatalf("different client should be allowed: got %d, want %d", code, http.StatusOK)
|
|
}
|
|
})
|
|
|
|
t.Run("custom key function", func(t *testing.T) {
|
|
mw := server.RateLimit(
|
|
server.WithRate(1),
|
|
server.WithBurst(1),
|
|
server.WithKeyFunc(func(r *http.Request) string {
|
|
return r.Header.Get("X-API-Key")
|
|
}),
|
|
)(okHandler)
|
|
|
|
// Exhaust key "abc".
|
|
w := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.Header.Set("X-API-Key", "abc")
|
|
mw.ServeHTTP(w, req)
|
|
|
|
// Same key should be rate limited.
|
|
w = httptest.NewRecorder()
|
|
req = httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.Header.Set("X-API-Key", "abc")
|
|
mw.ServeHTTP(w, req)
|
|
|
|
if w.Code != http.StatusTooManyRequests {
|
|
t.Fatalf("got status %d, want %d", w.Code, http.StatusTooManyRequests)
|
|
}
|
|
|
|
// Different key should be allowed.
|
|
w = httptest.NewRecorder()
|
|
req = httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.Header.Set("X-API-Key", "xyz")
|
|
mw.ServeHTTP(w, req)
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
|
|
}
|
|
})
|
|
|
|
t.Run("tokens refill over time", func(t *testing.T) {
|
|
mw := server.RateLimit(
|
|
server.WithRate(1000), // Very fast refill for test
|
|
server.WithBurst(1),
|
|
)(okHandler)
|
|
|
|
// Exhaust burst.
|
|
w := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.RemoteAddr = "1.2.3.4:1234"
|
|
mw.ServeHTTP(w, req)
|
|
|
|
// Wait a bit for refill.
|
|
time.Sleep(5 * time.Millisecond)
|
|
|
|
w = httptest.NewRecorder()
|
|
req = httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.RemoteAddr = "1.2.3.4:1234"
|
|
mw.ServeHTTP(w, req)
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Fatalf("got status %d after refill, want %d", w.Code, http.StatusOK)
|
|
}
|
|
})
|
|
}
|