Add server RateLimit middleware with per-key token bucket
Protects against abuse with configurable rate/burst per client IP. Supports custom key functions, X-Forwarded-For extraction, and Retry-After headers on 429 responses. Uses internal/clock for testability. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
129
server/middleware_ratelimit.go
Normal file
129
server/middleware_ratelimit.go
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.codelab.vc/pkg/httpx/internal/clock"
|
||||||
|
)
|
||||||
|
|
||||||
|
type rateLimitOptions struct {
|
||||||
|
rate float64
|
||||||
|
burst int
|
||||||
|
keyFunc func(r *http.Request) string
|
||||||
|
clock clock.Clock
|
||||||
|
}
|
||||||
|
|
||||||
|
// RateLimitOption configures the RateLimit middleware.
|
||||||
|
type RateLimitOption func(*rateLimitOptions)
|
||||||
|
|
||||||
|
// WithRate sets the token refill rate (tokens per second).
|
||||||
|
func WithRate(tokensPerSecond float64) RateLimitOption {
|
||||||
|
return func(o *rateLimitOptions) { o.rate = tokensPerSecond }
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithBurst sets the maximum burst size (bucket capacity).
|
||||||
|
func WithBurst(n int) RateLimitOption {
|
||||||
|
return func(o *rateLimitOptions) { o.burst = n }
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithKeyFunc sets a custom function to extract the rate-limit key from a
|
||||||
|
// request. By default, the client IP address is used.
|
||||||
|
func WithKeyFunc(fn func(r *http.Request) string) RateLimitOption {
|
||||||
|
return func(o *rateLimitOptions) { o.keyFunc = fn }
|
||||||
|
}
|
||||||
|
|
||||||
|
// withRateLimitClock sets the clock for testing. Not exported.
|
||||||
|
func withRateLimitClock(c clock.Clock) RateLimitOption {
|
||||||
|
return func(o *rateLimitOptions) { o.clock = c }
|
||||||
|
}
|
||||||
|
|
||||||
|
// RateLimit returns a middleware that limits requests using a per-key token
|
||||||
|
// bucket algorithm. When the limit is exceeded, it returns 429 Too Many
|
||||||
|
// Requests with a Retry-After header.
|
||||||
|
func RateLimit(opts ...RateLimitOption) Middleware {
|
||||||
|
o := &rateLimitOptions{
|
||||||
|
rate: 10,
|
||||||
|
burst: 20,
|
||||||
|
clock: clock.System(),
|
||||||
|
}
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(o)
|
||||||
|
}
|
||||||
|
if o.keyFunc == nil {
|
||||||
|
o.keyFunc = clientIP
|
||||||
|
}
|
||||||
|
|
||||||
|
var buckets sync.Map
|
||||||
|
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
key := o.keyFunc(r)
|
||||||
|
val, _ := buckets.LoadOrStore(key, &bucket{
|
||||||
|
tokens: float64(o.burst),
|
||||||
|
lastTime: o.clock.Now(),
|
||||||
|
})
|
||||||
|
b := val.(*bucket)
|
||||||
|
|
||||||
|
b.mu.Lock()
|
||||||
|
now := o.clock.Now()
|
||||||
|
elapsed := now.Sub(b.lastTime).Seconds()
|
||||||
|
b.tokens += elapsed * o.rate
|
||||||
|
if b.tokens > float64(o.burst) {
|
||||||
|
b.tokens = float64(o.burst)
|
||||||
|
}
|
||||||
|
b.lastTime = now
|
||||||
|
|
||||||
|
if b.tokens < 1 {
|
||||||
|
retryAfter := (1 - b.tokens) / o.rate
|
||||||
|
b.mu.Unlock()
|
||||||
|
w.Header().Set("Retry-After", strconv.Itoa(int(retryAfter)+1))
|
||||||
|
http.Error(w, "Too Many Requests", http.StatusTooManyRequests)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
b.tokens--
|
||||||
|
b.mu.Unlock()
|
||||||
|
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type bucket struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
tokens float64
|
||||||
|
lastTime time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// clientIP extracts the client IP from the request. It checks
|
||||||
|
// X-Forwarded-For first, then X-Real-Ip, and falls back to RemoteAddr.
|
||||||
|
func clientIP(r *http.Request) string {
|
||||||
|
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||||
|
// First IP in the comma-separated list is the original client.
|
||||||
|
if i := indexOf(xff, ','); i > 0 {
|
||||||
|
return xff[:i]
|
||||||
|
}
|
||||||
|
return xff
|
||||||
|
}
|
||||||
|
if xri := r.Header.Get("X-Real-Ip"); xri != "" {
|
||||||
|
return xri
|
||||||
|
}
|
||||||
|
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||||
|
if err != nil {
|
||||||
|
return r.RemoteAddr
|
||||||
|
}
|
||||||
|
return host
|
||||||
|
}
|
||||||
|
|
||||||
|
func indexOf(s string, b byte) int {
|
||||||
|
for i := range len(s) {
|
||||||
|
if s[i] == b {
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return -1
|
||||||
|
}
|
||||||
171
server/middleware_ratelimit_test.go
Normal file
171
server/middleware_ratelimit_test.go
Normal file
@@ -0,0 +1,171 @@
|
|||||||
|
package server_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.codelab.vc/pkg/httpx/server"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRateLimit(t *testing.T) {
|
||||||
|
okHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("allows requests within limit", func(t *testing.T) {
|
||||||
|
mw := server.RateLimit(
|
||||||
|
server.WithRate(100),
|
||||||
|
server.WithBurst(10),
|
||||||
|
)(okHandler)
|
||||||
|
|
||||||
|
for i := range 10 {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.RemoteAddr = "1.2.3.4:1234"
|
||||||
|
mw.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("request %d: got status %d, want %d", i, w.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("rejects when burst exhausted", func(t *testing.T) {
|
||||||
|
mw := server.RateLimit(
|
||||||
|
server.WithRate(1),
|
||||||
|
server.WithBurst(2),
|
||||||
|
)(okHandler)
|
||||||
|
|
||||||
|
// Exhaust burst.
|
||||||
|
for range 2 {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.RemoteAddr = "1.2.3.4:1234"
|
||||||
|
mw.ServeHTTP(w, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Next request should be rejected.
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.RemoteAddr = "1.2.3.4:1234"
|
||||||
|
mw.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusTooManyRequests {
|
||||||
|
t.Fatalf("got status %d, want %d", w.Code, http.StatusTooManyRequests)
|
||||||
|
}
|
||||||
|
if w.Header().Get("Retry-After") == "" {
|
||||||
|
t.Fatal("expected Retry-After header")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("different IPs have independent limits", func(t *testing.T) {
|
||||||
|
mw := server.RateLimit(
|
||||||
|
server.WithRate(1),
|
||||||
|
server.WithBurst(1),
|
||||||
|
)(okHandler)
|
||||||
|
|
||||||
|
// First IP exhausts its limit.
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.RemoteAddr = "1.2.3.4:1234"
|
||||||
|
mw.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
// Second IP should still be allowed.
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
req = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.RemoteAddr = "5.6.7.8:5678"
|
||||||
|
mw.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("uses X-Forwarded-For", func(t *testing.T) {
|
||||||
|
mw := server.RateLimit(
|
||||||
|
server.WithRate(1),
|
||||||
|
server.WithBurst(1),
|
||||||
|
)(okHandler)
|
||||||
|
|
||||||
|
// Exhaust limit for 10.0.0.1.
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.Header.Set("X-Forwarded-For", "10.0.0.1, 192.168.1.1")
|
||||||
|
req.RemoteAddr = "192.168.1.1:1234"
|
||||||
|
mw.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
// Same forwarded IP should be rate limited.
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
req = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.Header.Set("X-Forwarded-For", "10.0.0.1, 192.168.1.1")
|
||||||
|
req.RemoteAddr = "192.168.1.1:1234"
|
||||||
|
mw.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusTooManyRequests {
|
||||||
|
t.Fatalf("got status %d, want %d", w.Code, http.StatusTooManyRequests)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("custom key function", func(t *testing.T) {
|
||||||
|
mw := server.RateLimit(
|
||||||
|
server.WithRate(1),
|
||||||
|
server.WithBurst(1),
|
||||||
|
server.WithKeyFunc(func(r *http.Request) string {
|
||||||
|
return r.Header.Get("X-API-Key")
|
||||||
|
}),
|
||||||
|
)(okHandler)
|
||||||
|
|
||||||
|
// Exhaust key "abc".
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.Header.Set("X-API-Key", "abc")
|
||||||
|
mw.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
// Same key should be rate limited.
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
req = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.Header.Set("X-API-Key", "abc")
|
||||||
|
mw.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusTooManyRequests {
|
||||||
|
t.Fatalf("got status %d, want %d", w.Code, http.StatusTooManyRequests)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Different key should be allowed.
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
req = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.Header.Set("X-API-Key", "xyz")
|
||||||
|
mw.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("tokens refill over time", func(t *testing.T) {
|
||||||
|
mw := server.RateLimit(
|
||||||
|
server.WithRate(1000), // Very fast refill for test
|
||||||
|
server.WithBurst(1),
|
||||||
|
)(okHandler)
|
||||||
|
|
||||||
|
// Exhaust burst.
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.RemoteAddr = "1.2.3.4:1234"
|
||||||
|
mw.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
// Wait a bit for refill.
|
||||||
|
time.Sleep(5 * time.Millisecond)
|
||||||
|
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
req = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.RemoteAddr = "1.2.3.4:1234"
|
||||||
|
mw.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("got status %d after refill, want %d", w.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user