Add server RateLimit middleware with per-key token bucket
Protects against abuse with configurable rate/burst per client IP. Supports custom key functions, X-Forwarded-For extraction, and Retry-After headers on 429 responses. Uses internal/clock for testability. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
171
server/middleware_ratelimit_test.go
Normal file
171
server/middleware_ratelimit_test.go
Normal file
@@ -0,0 +1,171 @@
|
||||
package server_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.codelab.vc/pkg/httpx/server"
|
||||
)
|
||||
|
||||
func TestRateLimit(t *testing.T) {
|
||||
okHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
t.Run("allows requests within limit", func(t *testing.T) {
|
||||
mw := server.RateLimit(
|
||||
server.WithRate(100),
|
||||
server.WithBurst(10),
|
||||
)(okHandler)
|
||||
|
||||
for i := range 10 {
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.RemoteAddr = "1.2.3.4:1234"
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("request %d: got status %d, want %d", i, w.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("rejects when burst exhausted", func(t *testing.T) {
|
||||
mw := server.RateLimit(
|
||||
server.WithRate(1),
|
||||
server.WithBurst(2),
|
||||
)(okHandler)
|
||||
|
||||
// Exhaust burst.
|
||||
for range 2 {
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.RemoteAddr = "1.2.3.4:1234"
|
||||
mw.ServeHTTP(w, req)
|
||||
}
|
||||
|
||||
// Next request should be rejected.
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.RemoteAddr = "1.2.3.4:1234"
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("got status %d, want %d", w.Code, http.StatusTooManyRequests)
|
||||
}
|
||||
if w.Header().Get("Retry-After") == "" {
|
||||
t.Fatal("expected Retry-After header")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("different IPs have independent limits", func(t *testing.T) {
|
||||
mw := server.RateLimit(
|
||||
server.WithRate(1),
|
||||
server.WithBurst(1),
|
||||
)(okHandler)
|
||||
|
||||
// First IP exhausts its limit.
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.RemoteAddr = "1.2.3.4:1234"
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
// Second IP should still be allowed.
|
||||
w = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.RemoteAddr = "5.6.7.8:5678"
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uses X-Forwarded-For", func(t *testing.T) {
|
||||
mw := server.RateLimit(
|
||||
server.WithRate(1),
|
||||
server.WithBurst(1),
|
||||
)(okHandler)
|
||||
|
||||
// Exhaust limit for 10.0.0.1.
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("X-Forwarded-For", "10.0.0.1, 192.168.1.1")
|
||||
req.RemoteAddr = "192.168.1.1:1234"
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
// Same forwarded IP should be rate limited.
|
||||
w = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("X-Forwarded-For", "10.0.0.1, 192.168.1.1")
|
||||
req.RemoteAddr = "192.168.1.1:1234"
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("got status %d, want %d", w.Code, http.StatusTooManyRequests)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("custom key function", func(t *testing.T) {
|
||||
mw := server.RateLimit(
|
||||
server.WithRate(1),
|
||||
server.WithBurst(1),
|
||||
server.WithKeyFunc(func(r *http.Request) string {
|
||||
return r.Header.Get("X-API-Key")
|
||||
}),
|
||||
)(okHandler)
|
||||
|
||||
// Exhaust key "abc".
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("X-API-Key", "abc")
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
// Same key should be rate limited.
|
||||
w = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("X-API-Key", "abc")
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("got status %d, want %d", w.Code, http.StatusTooManyRequests)
|
||||
}
|
||||
|
||||
// Different key should be allowed.
|
||||
w = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("X-API-Key", "xyz")
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("tokens refill over time", func(t *testing.T) {
|
||||
mw := server.RateLimit(
|
||||
server.WithRate(1000), // Very fast refill for test
|
||||
server.WithBurst(1),
|
||||
)(okHandler)
|
||||
|
||||
// Exhaust burst.
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.RemoteAddr = "1.2.3.4:1234"
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
// Wait a bit for refill.
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
|
||||
w = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.RemoteAddr = "1.2.3.4:1234"
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("got status %d after refill, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user