From 2d4a06e71546b83f7989f88119a147881b33cb83 Mon Sep 17 00:00:00 2001 From: Aleksey Shakhmatov Date: Sat, 23 May 2026 13:47:08 +0300 Subject: [PATCH] Harden RateLimit against X-Forwarded-For spoofing Key on RemoteAddr by default; honor X-Forwarded-For only when the peer is a configured trusted proxy (WithTrustedProxies), walking right-to-left to the first untrusted hop. This closes a trivial rate-limit bypass and the matching unbounded-bucket DoS via spoofed headers. Add WithMaxKeys with opportunistic eviction of idle (fully-refilled) buckets to bound memory. Drop the hand-rolled indexOf in favor of stdlib. --- server/middleware_ratelimit.go | 221 ++++++++++++++----- server/middleware_ratelimit_internal_test.go | 62 ++++++ server/middleware_ratelimit_test.go | 61 +++-- 3 files changed, 278 insertions(+), 66 deletions(-) create mode 100644 server/middleware_ratelimit_internal_test.go diff --git a/server/middleware_ratelimit.go b/server/middleware_ratelimit.go index a95614a..5361072 100644 --- a/server/middleware_ratelimit.go +++ b/server/middleware_ratelimit.go @@ -4,17 +4,25 @@ import ( "net" "net/http" "strconv" + "strings" "sync" + "sync/atomic" "time" "git.codelab.vc/pkg/httpx/internal/clock" ) +// defaultMaxKeys bounds the number of distinct rate-limit buckets retained in +// memory. When exceeded, fully-refilled (idle) buckets are evicted. +const defaultMaxKeys = 1 << 16 + type rateLimitOptions struct { - rate float64 - burst int - keyFunc func(r *http.Request) string - clock clock.Clock + rate float64 + burst int + keyFunc func(r *http.Request) string + clock clock.Clock + trustedProxies []*net.IPNet + maxKeys int } // RateLimitOption configures the RateLimit middleware. @@ -31,11 +39,55 @@ func WithBurst(n int) RateLimitOption { } // WithKeyFunc sets a custom function to extract the rate-limit key from a -// request. By default, the client IP address is used. +// request. By default, the client IP from RemoteAddr is used (see +// WithTrustedProxies to honor X-Forwarded-For behind a trusted proxy). func WithKeyFunc(fn func(r *http.Request) string) RateLimitOption { return func(o *rateLimitOptions) { o.keyFunc = fn } } +// WithTrustedProxies enables X-Forwarded-For parsing, but only for requests +// whose immediate peer (RemoteAddr) falls within one of the given trusted +// CIDR ranges (e.g. "10.0.0.0/8", "192.168.0.0/16"). A bare IP is accepted as +// a /32 or /128. When the peer is trusted, the client key is taken from the +// right-most X-Forwarded-For entry that is not itself a trusted proxy; +// otherwise RemoteAddr is used. Invalid entries are ignored (treated as +// untrusted), so a typo can never silently widen trust. +// +// Without this option the middleware never trusts client-supplied forwarding +// headers, which prevents trivial rate-limit bypass and bucket exhaustion via +// spoofed headers. +func WithTrustedProxies(cidrs ...string) RateLimitOption { + return func(o *rateLimitOptions) { + for _, c := range cidrs { + if _, ipnet, err := net.ParseCIDR(c); err == nil { + o.trustedProxies = append(o.trustedProxies, ipnet) + continue + } + if ip := net.ParseIP(c); ip != nil { + bits := 32 + if ip.To4() == nil { + bits = 128 + } + o.trustedProxies = append(o.trustedProxies, &net.IPNet{ + IP: ip, + Mask: net.CIDRMask(bits, bits), + }) + } + } + } +} + +// WithMaxKeys sets the soft upper bound on the number of distinct buckets +// retained in memory. When exceeded, idle (fully-refilled) buckets are +// evicted; active buckets are never dropped. Default is 65536. +func WithMaxKeys(n int) RateLimitOption { + return func(o *rateLimitOptions) { + if n > 0 { + o.maxKeys = n + } + } +} + // withRateLimitClock sets the clock for testing. Not exported. func withRateLimitClock(c clock.Clock) RateLimitOption { return func(o *rateLimitOptions) { o.clock = c } @@ -44,86 +96,153 @@ func withRateLimitClock(c clock.Clock) RateLimitOption { // RateLimit returns a middleware that limits requests using a per-key token // bucket algorithm. When the limit is exceeded, it returns 429 Too Many // Requests with a Retry-After header. +// +// By default the key is the client IP taken from RemoteAddr. Forwarding +// headers (X-Forwarded-For) are honored only when WithTrustedProxies is set, +// so the limiter cannot be bypassed by spoofing headers. func RateLimit(opts ...RateLimitOption) Middleware { o := &rateLimitOptions{ - rate: 10, - burst: 20, - clock: clock.System(), + rate: 10, + burst: 20, + clock: clock.System(), + maxKeys: defaultMaxKeys, } for _, opt := range opts { opt(o) } if o.keyFunc == nil { - o.keyFunc = clientIP + o.keyFunc = o.clientKey } - var buckets sync.Map + lim := &limiter{opts: o} return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { key := o.keyFunc(r) - val, _ := buckets.LoadOrStore(key, &bucket{ - tokens: float64(o.burst), - lastTime: o.clock.Now(), - }) - b := val.(*bucket) - - b.mu.Lock() - now := o.clock.Now() - elapsed := now.Sub(b.lastTime).Seconds() - b.tokens += elapsed * o.rate - if b.tokens > float64(o.burst) { - b.tokens = float64(o.burst) - } - b.lastTime = now - - if b.tokens < 1 { - retryAfter := (1 - b.tokens) / o.rate - b.mu.Unlock() + if allowed, retryAfter := lim.allow(key); !allowed { w.Header().Set("Retry-After", strconv.Itoa(int(retryAfter)+1)) http.Error(w, "Too Many Requests", http.StatusTooManyRequests) return } - - b.tokens-- - b.mu.Unlock() - next.ServeHTTP(w, r) }) } } +// limiter holds the per-key token buckets for one RateLimit middleware. +type limiter struct { + opts *rateLimitOptions + buckets sync.Map // key -> *bucket + count atomic.Int64 + sweeping atomic.Bool +} + +// allow reports whether a request for key may proceed. When denied it also +// returns the suggested Retry-After delay in seconds. +func (l *limiter) allow(key string) (bool, float64) { + o := l.opts + val, loaded := l.buckets.LoadOrStore(key, &bucket{ + tokens: float64(o.burst), + lastTime: o.clock.Now(), + }) + if !loaded && l.count.Add(1) > int64(o.maxKeys) { + l.sweep() + } + b := val.(*bucket) + + b.mu.Lock() + defer b.mu.Unlock() + + now := o.clock.Now() + elapsed := now.Sub(b.lastTime).Seconds() + b.tokens += elapsed * o.rate + if b.tokens > float64(o.burst) { + b.tokens = float64(o.burst) + } + b.lastTime = now + + if b.tokens < 1 { + return false, (1 - b.tokens) / o.rate + } + b.tokens-- + return true, 0 +} + +// sweep removes fully-refilled (idle) buckets to bound memory. Only one sweep +// runs at a time; buckets that still hold a partial limit are preserved so +// that eviction can never reset an active client's allowance. +func (l *limiter) sweep() { + if !l.sweeping.CompareAndSwap(false, true) { + return + } + defer l.sweeping.Store(false) + + o := l.opts + now := o.clock.Now() + l.buckets.Range(func(k, v any) bool { + b := v.(*bucket) + b.mu.Lock() + elapsed := now.Sub(b.lastTime).Seconds() + full := b.tokens+elapsed*o.rate >= float64(o.burst) + b.mu.Unlock() + if full { + l.buckets.Delete(k) + l.count.Add(-1) + } + return true + }) +} + type bucket struct { mu sync.Mutex tokens float64 lastTime time.Time } -// clientIP extracts the client IP from the request. It checks -// X-Forwarded-For first, then X-Real-Ip, and falls back to RemoteAddr. -func clientIP(r *http.Request) string { - if xff := r.Header.Get("X-Forwarded-For"); xff != "" { - // First IP in the comma-separated list is the original client. - if i := indexOf(xff, ','); i > 0 { - return xff[:i] +// clientKey derives the rate-limit key from a request. It uses RemoteAddr by +// default and only consults X-Forwarded-For when the peer is a configured +// trusted proxy (see WithTrustedProxies). +func (o *rateLimitOptions) clientKey(r *http.Request) string { + remote := remoteIP(r) + if len(o.trustedProxies) == 0 || !o.isTrusted(remote) { + return remote + } + // Peer is trusted: walk X-Forwarded-For right-to-left and return the first + // address that is not itself a trusted proxy — that is the real client. + xff := r.Header.Get("X-Forwarded-For") + if xff == "" { + return remote + } + parts := strings.Split(xff, ",") + for i := len(parts) - 1; i >= 0; i-- { + ip := strings.TrimSpace(parts[i]) + if ip == "" || o.isTrusted(ip) { + continue } - return xff + return ip } - if xri := r.Header.Get("X-Real-Ip"); xri != "" { - return xri + return remote +} + +func (o *rateLimitOptions) isTrusted(ip string) bool { + parsed := net.ParseIP(ip) + if parsed == nil { + return false } + for _, n := range o.trustedProxies { + if n.Contains(parsed) { + return true + } + } + return false +} + +// remoteIP returns the host portion of r.RemoteAddr, or the raw value if it +// has no port. +func remoteIP(r *http.Request) string { host, _, err := net.SplitHostPort(r.RemoteAddr) if err != nil { return r.RemoteAddr } return host } - -func indexOf(s string, b byte) int { - for i := range len(s) { - if s[i] == b { - return i - } - } - return -1 -} diff --git a/server/middleware_ratelimit_internal_test.go b/server/middleware_ratelimit_internal_test.go new file mode 100644 index 0000000..8df7d43 --- /dev/null +++ b/server/middleware_ratelimit_internal_test.go @@ -0,0 +1,62 @@ +package server + +import ( + "net/http" + "testing" + "time" + + "git.codelab.vc/pkg/httpx/internal/clock" +) + +func newTestRequest(remoteAddr, xff string) *http.Request { + r := &http.Request{RemoteAddr: remoteAddr, Header: http.Header{}} + if xff != "" { + r.Header.Set("X-Forwarded-For", xff) + } + return r +} + +// TestLimiterSweepEvictsIdleBuckets verifies that sweep removes fully-refilled +// (idle) buckets while preserving buckets that still hold an active limit, so +// memory is bounded without resetting live clients' allowances. +func TestLimiterSweepEvictsIdleBuckets(t *testing.T) { + clk := clock.Mock(time.Now()) + o := &rateLimitOptions{rate: 1, burst: 5, clock: clk, maxKeys: 1 << 30} + lim := &limiter{opts: o} + + // "idle" makes a single request, then time passes so it refills to full. + lim.allow("idle") + clk.Advance(10 * time.Second) + + // "active" drains its whole burst at the (advanced) current time. + for i := 0; i < 6; i++ { + lim.allow("active") + } + + lim.sweep() + + if _, ok := lim.buckets.Load("idle"); ok { + t.Error("fully-refilled idle bucket was not evicted") + } + if _, ok := lim.buckets.Load("active"); !ok { + t.Error("active bucket with a partial limit was wrongly evicted") + } +} + +// TestClientKeyTrustedProxy exercises the X-Forwarded-For walk used behind a +// trusted proxy, independent of the HTTP layer. +func TestClientKeyTrustedProxy(t *testing.T) { + o := &rateLimitOptions{} + WithTrustedProxies("192.168.0.0/16")(o) + + r := newTestRequest("192.168.1.10:443", "203.0.113.7, 192.168.1.10") + if got := o.clientKey(r); got != "203.0.113.7" { + t.Fatalf("clientKey = %q, want real client 203.0.113.7", got) + } + + // Untrusted peer: X-Forwarded-For must be ignored entirely. + r = newTestRequest("203.0.113.7:443", "10.0.0.1") + if got := o.clientKey(r); got != "203.0.113.7" { + t.Fatalf("clientKey = %q, want peer 203.0.113.7 (XFF ignored)", got) + } +} diff --git a/server/middleware_ratelimit_test.go b/server/middleware_ratelimit_test.go index dd3bffe..b04128f 100644 --- a/server/middleware_ratelimit_test.go +++ b/server/middleware_ratelimit_test.go @@ -83,28 +83,59 @@ func TestRateLimit(t *testing.T) { } }) - t.Run("uses X-Forwarded-For", func(t *testing.T) { + t.Run("ignores X-Forwarded-For without trusted proxies", func(t *testing.T) { + // By default the limiter keys on RemoteAddr only. A spoofed, + // per-request X-Forwarded-For must not let a single peer bypass the + // limit by minting a fresh bucket each time. mw := server.RateLimit( server.WithRate(1), server.WithBurst(1), )(okHandler) - // Exhaust limit for 10.0.0.1. - w := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/", nil) - req.Header.Set("X-Forwarded-For", "10.0.0.1, 192.168.1.1") - req.RemoteAddr = "192.168.1.1:1234" - mw.ServeHTTP(w, req) + send := func(xff string) int { + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("X-Forwarded-For", xff) + req.RemoteAddr = "192.168.1.1:1234" + mw.ServeHTTP(w, req) + return w.Code + } - // Same forwarded IP should be rate limited. - w = httptest.NewRecorder() - req = httptest.NewRequest(http.MethodGet, "/", nil) - req.Header.Set("X-Forwarded-For", "10.0.0.1, 192.168.1.1") - req.RemoteAddr = "192.168.1.1:1234" - mw.ServeHTTP(w, req) + if code := send("10.0.0.1"); code != http.StatusOK { + t.Fatalf("first request: got %d, want %d", code, http.StatusOK) + } + // Different spoofed XFF, same peer — must still be limited. + if code := send("10.0.0.2"); code != http.StatusTooManyRequests { + t.Fatalf("spoofed XFF bypassed limit: got %d, want %d", code, http.StatusTooManyRequests) + } + }) - if w.Code != http.StatusTooManyRequests { - t.Fatalf("got status %d, want %d", w.Code, http.StatusTooManyRequests) + t.Run("honors X-Forwarded-For behind trusted proxy", func(t *testing.T) { + mw := server.RateLimit( + server.WithRate(1), + server.WithBurst(1), + server.WithTrustedProxies("192.168.0.0/16"), + )(okHandler) + + send := func(xff string) int { + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("X-Forwarded-For", xff) + req.RemoteAddr = "192.168.1.1:1234" // trusted proxy + mw.ServeHTTP(w, req) + return w.Code + } + + // Real client 10.0.0.1 (left-most), proxy hop 192.168.1.1 (right-most). + if code := send("10.0.0.1, 192.168.1.1"); code != http.StatusOK { + t.Fatalf("first request: got %d, want %d", code, http.StatusOK) + } + if code := send("10.0.0.1, 192.168.1.1"); code != http.StatusTooManyRequests { + t.Fatalf("same client not limited: got %d, want %d", code, http.StatusTooManyRequests) + } + // A different real client through the same proxy is independent. + if code := send("10.0.0.2, 192.168.1.1"); code != http.StatusOK { + t.Fatalf("different client should be allowed: got %d, want %d", code, http.StatusOK) } })