Files
httpx/server/middleware_ratelimit.go
Aleksey Shakhmatov 2d4a06e715 Harden RateLimit against X-Forwarded-For spoofing
Key on RemoteAddr by default; honor X-Forwarded-For only when the peer is
a configured trusted proxy (WithTrustedProxies), walking right-to-left to
the first untrusted hop. This closes a trivial rate-limit bypass and the
matching unbounded-bucket DoS via spoofed headers. Add WithMaxKeys with
opportunistic eviction of idle (fully-refilled) buckets to bound memory.
Drop the hand-rolled indexOf in favor of stdlib.
2026-05-23 13:47:08 +03:00

249 lines
6.8 KiB
Go

package server
import (
"net"
"net/http"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"git.codelab.vc/pkg/httpx/internal/clock"
)
// defaultMaxKeys is the default soft cap (65536) on how many distinct
// rate-limit buckets are kept in memory. Past this cap, buckets that have
// fully refilled (idle clients) become eligible for eviction.
const defaultMaxKeys = 65536
// rateLimitOptions is the configuration assembled from RateLimitOption values.
// RateLimit fills in defaults first and then applies each option in order.
type rateLimitOptions struct {
	// rate is the number of tokens added back to a bucket per second.
	rate float64
	// burst is the bucket capacity: the most tokens a bucket can hold.
	burst int
	// keyFunc derives the bucket key for a request; when nil, RateLimit
	// substitutes clientKey.
	keyFunc func(r *http.Request) string
	// clock is the time source used for refills (replaceable in tests).
	clock clock.Clock
	// trustedProxies lists the peer networks whose X-Forwarded-For header
	// is honored when deriving the default key.
	trustedProxies []*net.IPNet
	// maxKeys is the soft cap on retained buckets before idle ones are swept.
	maxKeys int
}
// RateLimitOption configures the RateLimit middleware. Options are applied in
// the order they are passed to RateLimit, so when two options set the same
// field the later one wins (WithTrustedProxies instead accumulates ranges).
type RateLimitOption func(*rateLimitOptions)
// WithRate sets how many tokens are returned to each bucket per second.
// The value is stored as given; it is not validated.
func WithRate(tokensPerSecond float64) RateLimitOption {
	return func(cfg *rateLimitOptions) {
		cfg.rate = tokensPerSecond
	}
}
// WithBurst sets the bucket capacity, i.e. the largest number of requests a
// single key may make back to back before refills matter. The value is stored
// as given; it is not validated.
func WithBurst(n int) RateLimitOption {
	return func(cfg *rateLimitOptions) {
		cfg.burst = n
	}
}
// WithKeyFunc overrides how the rate-limit key is derived from a request.
// Without it (or when fn is nil) the key is the client IP from RemoteAddr;
// see WithTrustedProxies for honoring X-Forwarded-For behind a trusted proxy.
func WithKeyFunc(fn func(r *http.Request) string) RateLimitOption {
	return func(cfg *rateLimitOptions) {
		cfg.keyFunc = fn
	}
}
// WithTrustedProxies turns on X-Forwarded-For handling, restricted to requests
// whose immediate peer (RemoteAddr) lies inside one of the given CIDR ranges
// (for example "10.0.0.0/8" or "192.168.0.0/16"). A bare IP is treated as a
// single-host range (/32 for IPv4, /128 for IPv6). For a trusted peer, the key
// is the right-most X-Forwarded-For entry that is not itself a trusted proxy;
// for any other peer, RemoteAddr is used. Entries that parse neither as a CIDR
// nor as an IP are dropped, so a typo can never widen the trusted set.
//
// If this option is not given, forwarding headers supplied by clients are
// never believed, which blocks rate-limit bypass and bucket exhaustion through
// spoofed headers.
func WithTrustedProxies(cidrs ...string) RateLimitOption {
	return func(cfg *rateLimitOptions) {
		for _, entry := range cidrs {
			_, network, err := net.ParseCIDR(entry)
			if err != nil {
				addr := net.ParseIP(entry)
				if addr == nil {
					continue // neither CIDR nor IP: ignore, never trust
				}
				width := 128
				if addr.To4() != nil {
					width = 32
				}
				network = &net.IPNet{IP: addr, Mask: net.CIDRMask(width, width)}
			}
			cfg.trustedProxies = append(cfg.trustedProxies, network)
		}
	}
}
// WithMaxKeys sets the soft cap on how many distinct buckets are kept in
// memory. Past the cap, idle (fully-refilled) buckets are evicted while active
// ones are kept. Non-positive n is ignored and the default of 65536 remains.
func WithMaxKeys(n int) RateLimitOption {
	return func(cfg *rateLimitOptions) {
		if n <= 0 {
			return
		}
		cfg.maxKeys = n
	}
}
// withRateLimitClock swaps the time source so tests can control refills.
// It is deliberately unexported.
func withRateLimitClock(c clock.Clock) RateLimitOption {
	return func(cfg *rateLimitOptions) {
		cfg.clock = c
	}
}
// RateLimit returns a middleware that limits requests using a per-key token
// bucket algorithm. When the limit is exceeded, it returns 429 Too Many
// Requests with a Retry-After header.
//
// By default the key is the client IP taken from RemoteAddr. Forwarding
// headers (X-Forwarded-For) are honored only when WithTrustedProxies is set,
// so the limiter cannot be bypassed by spoofing headers.
//
// Unless overridden, buckets refill at 10 tokens per second with a burst of 20.
func RateLimit(opts ...RateLimitOption) Middleware {
	o := &rateLimitOptions{
		rate:    10,
		burst:   20,
		clock:   clock.System(),
		maxKeys: defaultMaxKeys,
	}
	for _, opt := range opts {
		opt(o)
	}
	if o.keyFunc == nil {
		o.keyFunc = o.clientKey
	}
	lim := &limiter{opts: o}
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			key := o.keyFunc(r)
			if allowed, retryAfter := lim.allow(key); !allowed {
				w.Header().Set("Retry-After", rateLimitRetryAfter(retryAfter))
				http.Error(w, "Too Many Requests", http.StatusTooManyRequests)
				return
			}
			next.ServeHTTP(w, r)
		})
	}
}

// rateLimitRetryAfter formats a delay in seconds as a Retry-After value,
// rounding up to the next whole second.
//
// A zero or negative refill rate yields an infinite or negative delay, and
// converting a non-finite float to int is implementation-defined in Go. To
// keep the header a valid non-negative integer, the delay is therefore
// clamped to [0, 2^31-1] before conversion.
func rateLimitRetryAfter(seconds float64) string {
	const maxSeconds = 1<<31 - 1 // fits in int on every platform
	// Written as a negated "<" so that NaN and +Inf also take this branch.
	if !(seconds < maxSeconds) {
		return strconv.Itoa(maxSeconds)
	}
	if seconds < 0 {
		seconds = 0
	}
	return strconv.Itoa(int(seconds) + 1)
}
// limiter holds the per-key token buckets for one RateLimit middleware.
type limiter struct {
	opts     *rateLimitOptions
	buckets  sync.Map     // key -> *bucket
	count    atomic.Int64 // number of buckets inserted minus buckets evicted
	sweeping atomic.Bool  // true while a sweep is in progress
}

// allow reports whether a request for key may proceed. When denied it also
// returns the suggested Retry-After delay in seconds.
//
// A bucket insertion that pushes the key count above maxKeys triggers a sweep
// of idle buckets. The sweep runs only after this request has drawn its token.
// A freshly created bucket is full, so sweeping first would evict it
// immediately. Once the table was saturated with active clients, every new key
// would then get a brand-new bucket on each request and bypass the limit.
func (l *limiter) allow(key string) (bool, float64) {
	o := l.opts
	for {
		b, created := l.lookup(key)
		allowed, retryAfter, live := b.take(o)
		if created && l.count.Add(1) > int64(o.maxKeys) {
			l.sweep()
		}
		if live {
			return allowed, retryAfter
		}
		// A concurrent sweep evicted b between lookup and take. Its state is
		// gone, so retry against whatever bucket is now stored for key.
	}
}

// lookup returns the bucket stored for key, inserting a full one if there is
// none. created reports whether this call performed the insertion. The Load
// fast path avoids allocating a throwaway bucket for keys that already exist.
func (l *limiter) lookup(key string) (b *bucket, created bool) {
	if v, ok := l.buckets.Load(key); ok {
		return v.(*bucket), false
	}
	v, loaded := l.buckets.LoadOrStore(key, &bucket{
		tokens:   float64(l.opts.burst),
		lastTime: l.opts.clock.Now(),
	})
	return v.(*bucket), !loaded
}

// sweep removes fully-refilled (idle) buckets to bound memory. Only one sweep
// runs at a time; buckets that still hold a partial limit are preserved so
// that eviction can never reset an active client's allowance.
//
// The idleness check, the evicted mark and the map deletion all happen while
// the bucket's lock is held. A request cannot consume from the bucket between
// the check and the removal, so an active bucket is never dropped. A request
// that fetched the bucket just before removal sees the mark in take and
// retries, instead of spending a token on an orphan.
func (l *limiter) sweep() {
	if !l.sweeping.CompareAndSwap(false, true) {
		return
	}
	defer l.sweeping.Store(false)
	o := l.opts
	now := o.clock.Now()
	burst := float64(o.burst)
	l.buckets.Range(func(k, v any) bool {
		b := v.(*bucket)
		b.mu.Lock()
		defer b.mu.Unlock()
		if b.tokens+now.Sub(b.lastTime).Seconds()*o.rate >= burst {
			b.evicted = true
			l.buckets.Delete(k)
			l.count.Add(-1)
		}
		return true
	})
}

// bucket is the token bucket for a single key. All fields after mu are
// guarded by mu.
type bucket struct {
	mu       sync.Mutex
	tokens   float64   // currently available tokens, capped at burst
	lastTime time.Time // when tokens was last refilled
	evicted  bool      // set by sweep once the bucket is removed from the map
}

// take refills b for the time elapsed since its last update and consumes one
// token if a whole token is available. When the request is denied,
// retryAfter is the number of seconds until one token will be available.
// live is false when b has already been evicted; the caller must then look the
// key up again, and the other results are meaningless.
func (b *bucket) take(o *rateLimitOptions) (allowed bool, retryAfter float64, live bool) {
	b.mu.Lock()
	defer b.mu.Unlock()
	if b.evicted {
		return false, 0, false
	}
	// Read the clock under the lock so lastTime never moves backwards.
	now := o.clock.Now()
	b.tokens += now.Sub(b.lastTime).Seconds() * o.rate
	if burst := float64(o.burst); b.tokens > burst {
		b.tokens = burst
	}
	b.lastTime = now
	if b.tokens < 1 {
		return false, (1 - b.tokens) / o.rate, true
	}
	b.tokens--
	return true, 0, true
}
// clientKey derives the rate-limit key from a request. It uses RemoteAddr by
// default and only consults X-Forwarded-For when the peer is a configured
// trusted proxy (see WithTrustedProxies).
//
// For a trusted peer, every X-Forwarded-For field line is considered, in
// order. Proxies may append their hop as a separate header line rather than
// extending an existing one, and RFC 9110 §5.3 defines the combined value as
// the lines joined with commas. Header.Get would return only the first line,
// which is typically the client-supplied, spoofable one. The combined list is
// walked right to left, and the first entry that is not itself a trusted
// proxy is returned. If no such entry exists, RemoteAddr's host is used.
func (o *rateLimitOptions) clientKey(r *http.Request) string {
	remote := remoteIP(r)
	if len(o.trustedProxies) == 0 || !o.isTrusted(remote) {
		return remote
	}
	lines := r.Header.Values("X-Forwarded-For")
	for i := len(lines) - 1; i >= 0; i-- {
		parts := strings.Split(lines[i], ",")
		for j := len(parts) - 1; j >= 0; j-- {
			ip := strings.TrimSpace(parts[j])
			if ip == "" || o.isTrusted(ip) {
				continue
			}
			return ip
		}
	}
	return remote
}
// isTrusted reports whether ip parses as an IP address inside one of the
// configured trusted proxy networks. A string that does not parse as an IP is
// never trusted.
func (o *rateLimitOptions) isTrusted(ip string) bool {
	if addr := net.ParseIP(ip); addr != nil {
		for _, network := range o.trustedProxies {
			if network.Contains(addr) {
				return true
			}
		}
	}
	return false
}
// remoteIP extracts the host from r.RemoteAddr ("host:port" or "[v6]:port").
// If RemoteAddr cannot be split into host and port, it is returned unchanged.
func remoteIP(r *http.Request) string {
	addr := r.RemoteAddr
	if host, _, err := net.SplitHostPort(addr); err == nil {
		return host
	}
	return addr
}