Add server RateLimit middleware with per-key token bucket

Protects against abuse with configurable rate/burst per client IP.
Supports custom key functions, X-Forwarded-For extraction, and
Retry-After headers on 429 responses. Uses internal/clock for
testability.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-22 21:47:51 +03:00
parent 7a2cef00c3
commit 3395f70abd
2 changed files with 300 additions and 0 deletions

View File

@@ -0,0 +1,129 @@
package server
import (
"net"
"net/http"
"strconv"
"sync"
"time"
"git.codelab.vc/pkg/httpx/internal/clock"
)
type rateLimitOptions struct {
rate float64
burst int
keyFunc func(r *http.Request) string
clock clock.Clock
}
// RateLimitOption configures the RateLimit middleware.
type RateLimitOption func(*rateLimitOptions)
// WithRate sets the token refill rate (tokens per second).
func WithRate(tokensPerSecond float64) RateLimitOption {
return func(o *rateLimitOptions) { o.rate = tokensPerSecond }
}
// WithBurst sets the maximum burst size (bucket capacity).
func WithBurst(n int) RateLimitOption {
return func(o *rateLimitOptions) { o.burst = n }
}
// WithKeyFunc sets a custom function to extract the rate-limit key from a
// request. By default, the client IP address is used.
func WithKeyFunc(fn func(r *http.Request) string) RateLimitOption {
return func(o *rateLimitOptions) { o.keyFunc = fn }
}
// withRateLimitClock sets the clock for testing. Not exported.
func withRateLimitClock(c clock.Clock) RateLimitOption {
return func(o *rateLimitOptions) { o.clock = c }
}
// RateLimit returns a middleware that limits requests using a per-key token
// bucket algorithm. When the limit is exceeded, it returns 429 Too Many
// Requests with a Retry-After header.
func RateLimit(opts ...RateLimitOption) Middleware {
o := &rateLimitOptions{
rate: 10,
burst: 20,
clock: clock.System(),
}
for _, opt := range opts {
opt(o)
}
if o.keyFunc == nil {
o.keyFunc = clientIP
}
var buckets sync.Map
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
key := o.keyFunc(r)
val, _ := buckets.LoadOrStore(key, &bucket{
tokens: float64(o.burst),
lastTime: o.clock.Now(),
})
b := val.(*bucket)
b.mu.Lock()
now := o.clock.Now()
elapsed := now.Sub(b.lastTime).Seconds()
b.tokens += elapsed * o.rate
if b.tokens > float64(o.burst) {
b.tokens = float64(o.burst)
}
b.lastTime = now
if b.tokens < 1 {
retryAfter := (1 - b.tokens) / o.rate
b.mu.Unlock()
w.Header().Set("Retry-After", strconv.Itoa(int(retryAfter)+1))
http.Error(w, "Too Many Requests", http.StatusTooManyRequests)
return
}
b.tokens--
b.mu.Unlock()
next.ServeHTTP(w, r)
})
}
}
type bucket struct {
mu sync.Mutex
tokens float64
lastTime time.Time
}
// clientIP extracts the client IP from the request. It checks
// X-Forwarded-For first, then X-Real-Ip, and falls back to RemoteAddr.
func clientIP(r *http.Request) string {
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
// First IP in the comma-separated list is the original client.
if i := indexOf(xff, ','); i > 0 {
return xff[:i]
}
return xff
}
if xri := r.Header.Get("X-Real-Ip"); xri != "" {
return xri
}
host, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
return r.RemoteAddr
}
return host
}
func indexOf(s string, b byte) int {
for i := range len(s) {
if s[i] == b {
return i
}
}
return -1
}