Harden RateLimit against X-Forwarded-For spoofing

Key on RemoteAddr by default; honor X-Forwarded-For only when the peer is
a configured trusted proxy (WithTrustedProxies), walking right-to-left to
the first untrusted hop. This closes a trivial rate-limit bypass and the
matching unbounded-bucket DoS via spoofed headers. Add WithMaxKeys with
opportunistic eviction of idle (fully-refilled) buckets to bound memory.
Drop the hand-rolled indexOf in favor of stdlib.
This commit is contained in:
2026-05-23 13:47:08 +03:00
parent b6350185d9
commit 2d4a06e715
3 changed files with 278 additions and 66 deletions

View File

@@ -4,17 +4,25 @@ import (
"net" "net"
"net/http" "net/http"
"strconv" "strconv"
"strings"
"sync" "sync"
"sync/atomic"
"time" "time"
"git.codelab.vc/pkg/httpx/internal/clock" "git.codelab.vc/pkg/httpx/internal/clock"
) )
// defaultMaxKeys bounds the number of distinct rate-limit buckets retained in
// memory. When exceeded, fully-refilled (idle) buckets are evicted.
const defaultMaxKeys = 1 << 16
type rateLimitOptions struct { type rateLimitOptions struct {
rate float64 rate float64
burst int burst int
keyFunc func(r *http.Request) string keyFunc func(r *http.Request) string
clock clock.Clock clock clock.Clock
trustedProxies []*net.IPNet
maxKeys int
} }
// RateLimitOption configures the RateLimit middleware. // RateLimitOption configures the RateLimit middleware.
@@ -31,11 +39,55 @@ func WithBurst(n int) RateLimitOption {
} }
// WithKeyFunc sets a custom function to extract the rate-limit key from a // WithKeyFunc sets a custom function to extract the rate-limit key from a
// request. By default, the client IP address is used. // request. By default, the client IP from RemoteAddr is used (see
// WithTrustedProxies to honor X-Forwarded-For behind a trusted proxy).
func WithKeyFunc(fn func(r *http.Request) string) RateLimitOption { func WithKeyFunc(fn func(r *http.Request) string) RateLimitOption {
return func(o *rateLimitOptions) { o.keyFunc = fn } return func(o *rateLimitOptions) { o.keyFunc = fn }
} }
// WithTrustedProxies enables X-Forwarded-For parsing, but only for requests
// whose immediate peer (RemoteAddr) falls within one of the given trusted
// CIDR ranges (e.g. "10.0.0.0/8", "192.168.0.0/16"). A bare IP is accepted as
// a /32 or /128. When the peer is trusted, the client key is taken from the
// right-most X-Forwarded-For entry that is not itself a trusted proxy;
// otherwise RemoteAddr is used. Invalid entries are ignored (treated as
// untrusted), so a typo can never silently widen trust.
//
// Without this option the middleware never trusts client-supplied forwarding
// headers, which prevents trivial rate-limit bypass and bucket exhaustion via
// spoofed headers.
func WithTrustedProxies(cidrs ...string) RateLimitOption {
return func(o *rateLimitOptions) {
for _, c := range cidrs {
if _, ipnet, err := net.ParseCIDR(c); err == nil {
o.trustedProxies = append(o.trustedProxies, ipnet)
continue
}
if ip := net.ParseIP(c); ip != nil {
bits := 32
if ip.To4() == nil {
bits = 128
}
o.trustedProxies = append(o.trustedProxies, &net.IPNet{
IP: ip,
Mask: net.CIDRMask(bits, bits),
})
}
}
}
}
// WithMaxKeys sets the soft upper bound on the number of distinct buckets
// retained in memory. When exceeded, idle (fully-refilled) buckets are
// evicted; active buckets are never dropped. Default is 65536.
func WithMaxKeys(n int) RateLimitOption {
return func(o *rateLimitOptions) {
if n > 0 {
o.maxKeys = n
}
}
}
// withRateLimitClock sets the clock for testing. Not exported. // withRateLimitClock sets the clock for testing. Not exported.
func withRateLimitClock(c clock.Clock) RateLimitOption { func withRateLimitClock(c clock.Clock) RateLimitOption {
return func(o *rateLimitOptions) { o.clock = c } return func(o *rateLimitOptions) { o.clock = c }
@@ -44,86 +96,153 @@ func withRateLimitClock(c clock.Clock) RateLimitOption {
// RateLimit returns a middleware that limits requests using a per-key token // RateLimit returns a middleware that limits requests using a per-key token
// bucket algorithm. When the limit is exceeded, it returns 429 Too Many // bucket algorithm. When the limit is exceeded, it returns 429 Too Many
// Requests with a Retry-After header. // Requests with a Retry-After header.
//
// By default the key is the client IP taken from RemoteAddr. Forwarding
// headers (X-Forwarded-For) are honored only when WithTrustedProxies is set,
// so the limiter cannot be bypassed by spoofing headers.
func RateLimit(opts ...RateLimitOption) Middleware { func RateLimit(opts ...RateLimitOption) Middleware {
o := &rateLimitOptions{ o := &rateLimitOptions{
rate: 10, rate: 10,
burst: 20, burst: 20,
clock: clock.System(), clock: clock.System(),
maxKeys: defaultMaxKeys,
} }
for _, opt := range opts { for _, opt := range opts {
opt(o) opt(o)
} }
if o.keyFunc == nil { if o.keyFunc == nil {
o.keyFunc = clientIP o.keyFunc = o.clientKey
} }
var buckets sync.Map lim := &limiter{opts: o}
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
key := o.keyFunc(r) key := o.keyFunc(r)
val, _ := buckets.LoadOrStore(key, &bucket{ if allowed, retryAfter := lim.allow(key); !allowed {
tokens: float64(o.burst),
lastTime: o.clock.Now(),
})
b := val.(*bucket)
b.mu.Lock()
now := o.clock.Now()
elapsed := now.Sub(b.lastTime).Seconds()
b.tokens += elapsed * o.rate
if b.tokens > float64(o.burst) {
b.tokens = float64(o.burst)
}
b.lastTime = now
if b.tokens < 1 {
retryAfter := (1 - b.tokens) / o.rate
b.mu.Unlock()
w.Header().Set("Retry-After", strconv.Itoa(int(retryAfter)+1)) w.Header().Set("Retry-After", strconv.Itoa(int(retryAfter)+1))
http.Error(w, "Too Many Requests", http.StatusTooManyRequests) http.Error(w, "Too Many Requests", http.StatusTooManyRequests)
return return
} }
b.tokens--
b.mu.Unlock()
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
}) })
} }
} }
// limiter holds the per-key token buckets for one RateLimit middleware.
type limiter struct {
opts *rateLimitOptions
buckets sync.Map // key -> *bucket
count atomic.Int64
sweeping atomic.Bool
}
// allow reports whether a request for key may proceed. When denied it also
// returns the suggested Retry-After delay in seconds.
func (l *limiter) allow(key string) (bool, float64) {
o := l.opts
val, loaded := l.buckets.LoadOrStore(key, &bucket{
tokens: float64(o.burst),
lastTime: o.clock.Now(),
})
if !loaded && l.count.Add(1) > int64(o.maxKeys) {
l.sweep()
}
b := val.(*bucket)
b.mu.Lock()
defer b.mu.Unlock()
now := o.clock.Now()
elapsed := now.Sub(b.lastTime).Seconds()
b.tokens += elapsed * o.rate
if b.tokens > float64(o.burst) {
b.tokens = float64(o.burst)
}
b.lastTime = now
if b.tokens < 1 {
return false, (1 - b.tokens) / o.rate
}
b.tokens--
return true, 0
}
// sweep removes fully-refilled (idle) buckets to bound memory. Only one sweep
// runs at a time; buckets that still hold a partial limit are preserved so
// that eviction can never reset an active client's allowance.
func (l *limiter) sweep() {
if !l.sweeping.CompareAndSwap(false, true) {
return
}
defer l.sweeping.Store(false)
o := l.opts
now := o.clock.Now()
l.buckets.Range(func(k, v any) bool {
b := v.(*bucket)
b.mu.Lock()
elapsed := now.Sub(b.lastTime).Seconds()
full := b.tokens+elapsed*o.rate >= float64(o.burst)
b.mu.Unlock()
if full {
l.buckets.Delete(k)
l.count.Add(-1)
}
return true
})
}
type bucket struct { type bucket struct {
mu sync.Mutex mu sync.Mutex
tokens float64 tokens float64
lastTime time.Time lastTime time.Time
} }
// clientIP extracts the client IP from the request. It checks // clientKey derives the rate-limit key from a request. It uses RemoteAddr by
// X-Forwarded-For first, then X-Real-Ip, and falls back to RemoteAddr. // default and only consults X-Forwarded-For when the peer is a configured
func clientIP(r *http.Request) string { // trusted proxy (see WithTrustedProxies).
if xff := r.Header.Get("X-Forwarded-For"); xff != "" { func (o *rateLimitOptions) clientKey(r *http.Request) string {
// First IP in the comma-separated list is the original client. remote := remoteIP(r)
if i := indexOf(xff, ','); i > 0 { if len(o.trustedProxies) == 0 || !o.isTrusted(remote) {
return xff[:i] return remote
}
// Peer is trusted: walk X-Forwarded-For right-to-left and return the first
// address that is not itself a trusted proxy — that is the real client.
xff := r.Header.Get("X-Forwarded-For")
if xff == "" {
return remote
}
parts := strings.Split(xff, ",")
for i := len(parts) - 1; i >= 0; i-- {
ip := strings.TrimSpace(parts[i])
if ip == "" || o.isTrusted(ip) {
continue
} }
return xff return ip
} }
if xri := r.Header.Get("X-Real-Ip"); xri != "" { return remote
return xri }
func (o *rateLimitOptions) isTrusted(ip string) bool {
parsed := net.ParseIP(ip)
if parsed == nil {
return false
} }
for _, n := range o.trustedProxies {
if n.Contains(parsed) {
return true
}
}
return false
}
// remoteIP returns the host portion of r.RemoteAddr, or the raw value if it
// has no port.
func remoteIP(r *http.Request) string {
host, _, err := net.SplitHostPort(r.RemoteAddr) host, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil { if err != nil {
return r.RemoteAddr return r.RemoteAddr
} }
return host return host
} }
func indexOf(s string, b byte) int {
for i := range len(s) {
if s[i] == b {
return i
}
}
return -1
}

View File

@@ -0,0 +1,62 @@
package server
import (
"net/http"
"testing"
"time"
"git.codelab.vc/pkg/httpx/internal/clock"
)
func newTestRequest(remoteAddr, xff string) *http.Request {
r := &http.Request{RemoteAddr: remoteAddr, Header: http.Header{}}
if xff != "" {
r.Header.Set("X-Forwarded-For", xff)
}
return r
}
// TestLimiterSweepEvictsIdleBuckets verifies that sweep removes fully-refilled
// (idle) buckets while preserving buckets that still hold an active limit, so
// memory is bounded without resetting live clients' allowances.
func TestLimiterSweepEvictsIdleBuckets(t *testing.T) {
clk := clock.Mock(time.Now())
o := &rateLimitOptions{rate: 1, burst: 5, clock: clk, maxKeys: 1 << 30}
lim := &limiter{opts: o}
// "idle" makes a single request, then time passes so it refills to full.
lim.allow("idle")
clk.Advance(10 * time.Second)
// "active" drains its whole burst at the (advanced) current time.
for i := 0; i < 6; i++ {
lim.allow("active")
}
lim.sweep()
if _, ok := lim.buckets.Load("idle"); ok {
t.Error("fully-refilled idle bucket was not evicted")
}
if _, ok := lim.buckets.Load("active"); !ok {
t.Error("active bucket with a partial limit was wrongly evicted")
}
}
// TestClientKeyTrustedProxy exercises the X-Forwarded-For walk used behind a
// trusted proxy, independent of the HTTP layer.
func TestClientKeyTrustedProxy(t *testing.T) {
o := &rateLimitOptions{}
WithTrustedProxies("192.168.0.0/16")(o)
r := newTestRequest("192.168.1.10:443", "203.0.113.7, 192.168.1.10")
if got := o.clientKey(r); got != "203.0.113.7" {
t.Fatalf("clientKey = %q, want real client 203.0.113.7", got)
}
// Untrusted peer: X-Forwarded-For must be ignored entirely.
r = newTestRequest("203.0.113.7:443", "10.0.0.1")
if got := o.clientKey(r); got != "203.0.113.7" {
t.Fatalf("clientKey = %q, want peer 203.0.113.7 (XFF ignored)", got)
}
}

View File

@@ -83,28 +83,59 @@ func TestRateLimit(t *testing.T) {
} }
}) })
t.Run("uses X-Forwarded-For", func(t *testing.T) { t.Run("ignores X-Forwarded-For without trusted proxies", func(t *testing.T) {
// By default the limiter keys on RemoteAddr only. A spoofed,
// per-request X-Forwarded-For must not let a single peer bypass the
// limit by minting a fresh bucket each time.
mw := server.RateLimit( mw := server.RateLimit(
server.WithRate(1), server.WithRate(1),
server.WithBurst(1), server.WithBurst(1),
)(okHandler) )(okHandler)
// Exhaust limit for 10.0.0.1. send := func(xff string) int {
w := httptest.NewRecorder() w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("X-Forwarded-For", "10.0.0.1, 192.168.1.1") req.Header.Set("X-Forwarded-For", xff)
req.RemoteAddr = "192.168.1.1:1234" req.RemoteAddr = "192.168.1.1:1234"
mw.ServeHTTP(w, req) mw.ServeHTTP(w, req)
return w.Code
}
// Same forwarded IP should be rate limited. if code := send("10.0.0.1"); code != http.StatusOK {
w = httptest.NewRecorder() t.Fatalf("first request: got %d, want %d", code, http.StatusOK)
req = httptest.NewRequest(http.MethodGet, "/", nil) }
req.Header.Set("X-Forwarded-For", "10.0.0.1, 192.168.1.1") // Different spoofed XFF, same peer — must still be limited.
req.RemoteAddr = "192.168.1.1:1234" if code := send("10.0.0.2"); code != http.StatusTooManyRequests {
mw.ServeHTTP(w, req) t.Fatalf("spoofed XFF bypassed limit: got %d, want %d", code, http.StatusTooManyRequests)
}
})
if w.Code != http.StatusTooManyRequests { t.Run("honors X-Forwarded-For behind trusted proxy", func(t *testing.T) {
t.Fatalf("got status %d, want %d", w.Code, http.StatusTooManyRequests) mw := server.RateLimit(
server.WithRate(1),
server.WithBurst(1),
server.WithTrustedProxies("192.168.0.0/16"),
)(okHandler)
send := func(xff string) int {
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("X-Forwarded-For", xff)
req.RemoteAddr = "192.168.1.1:1234" // trusted proxy
mw.ServeHTTP(w, req)
return w.Code
}
// Real client 10.0.0.1 (left-most), proxy hop 192.168.1.1 (right-most).
if code := send("10.0.0.1, 192.168.1.1"); code != http.StatusOK {
t.Fatalf("first request: got %d, want %d", code, http.StatusOK)
}
if code := send("10.0.0.1, 192.168.1.1"); code != http.StatusTooManyRequests {
t.Fatalf("same client not limited: got %d, want %d", code, http.StatusTooManyRequests)
}
// A different real client through the same proxy is independent.
if code := send("10.0.0.2, 192.168.1.1"); code != http.StatusOK {
t.Fatalf("different client should be allowed: got %d, want %d", code, http.StatusOK)
} }
}) })