Harden RateLimit against X-Forwarded-For spoofing
Key on RemoteAddr by default; honor X-Forwarded-For only when the peer is a configured trusted proxy (WithTrustedProxies), walking right-to-left to the first untrusted hop. This closes a trivial rate-limit bypass and the matching unbounded-bucket DoS via spoofed headers. Add WithMaxKeys with opportunistic eviction of idle (fully-refilled) buckets to bound memory. Drop the hand-rolled indexOf in favor of stdlib.
This commit is contained in:
@@ -4,17 +4,25 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.codelab.vc/pkg/httpx/internal/clock"
|
"git.codelab.vc/pkg/httpx/internal/clock"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// defaultMaxKeys bounds the number of distinct rate-limit buckets retained in
|
||||||
|
// memory. When exceeded, fully-refilled (idle) buckets are evicted.
|
||||||
|
const defaultMaxKeys = 1 << 16
|
||||||
|
|
||||||
type rateLimitOptions struct {
|
type rateLimitOptions struct {
|
||||||
rate float64
|
rate float64
|
||||||
burst int
|
burst int
|
||||||
keyFunc func(r *http.Request) string
|
keyFunc func(r *http.Request) string
|
||||||
clock clock.Clock
|
clock clock.Clock
|
||||||
|
trustedProxies []*net.IPNet
|
||||||
|
maxKeys int
|
||||||
}
|
}
|
||||||
|
|
||||||
// RateLimitOption configures the RateLimit middleware.
|
// RateLimitOption configures the RateLimit middleware.
|
||||||
@@ -31,11 +39,55 @@ func WithBurst(n int) RateLimitOption {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// WithKeyFunc sets a custom function to extract the rate-limit key from a
|
// WithKeyFunc sets a custom function to extract the rate-limit key from a
|
||||||
// request. By default, the client IP address is used.
|
// request. By default, the client IP from RemoteAddr is used (see
|
||||||
|
// WithTrustedProxies to honor X-Forwarded-For behind a trusted proxy).
|
||||||
func WithKeyFunc(fn func(r *http.Request) string) RateLimitOption {
|
func WithKeyFunc(fn func(r *http.Request) string) RateLimitOption {
|
||||||
return func(o *rateLimitOptions) { o.keyFunc = fn }
|
return func(o *rateLimitOptions) { o.keyFunc = fn }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithTrustedProxies enables X-Forwarded-For parsing, but only for requests
|
||||||
|
// whose immediate peer (RemoteAddr) falls within one of the given trusted
|
||||||
|
// CIDR ranges (e.g. "10.0.0.0/8", "192.168.0.0/16"). A bare IP is accepted as
|
||||||
|
// a /32 or /128. When the peer is trusted, the client key is taken from the
|
||||||
|
// right-most X-Forwarded-For entry that is not itself a trusted proxy;
|
||||||
|
// otherwise RemoteAddr is used. Invalid entries are ignored (treated as
|
||||||
|
// untrusted), so a typo can never silently widen trust.
|
||||||
|
//
|
||||||
|
// Without this option the middleware never trusts client-supplied forwarding
|
||||||
|
// headers, which prevents trivial rate-limit bypass and bucket exhaustion via
|
||||||
|
// spoofed headers.
|
||||||
|
func WithTrustedProxies(cidrs ...string) RateLimitOption {
|
||||||
|
return func(o *rateLimitOptions) {
|
||||||
|
for _, c := range cidrs {
|
||||||
|
if _, ipnet, err := net.ParseCIDR(c); err == nil {
|
||||||
|
o.trustedProxies = append(o.trustedProxies, ipnet)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if ip := net.ParseIP(c); ip != nil {
|
||||||
|
bits := 32
|
||||||
|
if ip.To4() == nil {
|
||||||
|
bits = 128
|
||||||
|
}
|
||||||
|
o.trustedProxies = append(o.trustedProxies, &net.IPNet{
|
||||||
|
IP: ip,
|
||||||
|
Mask: net.CIDRMask(bits, bits),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithMaxKeys sets the soft upper bound on the number of distinct buckets
|
||||||
|
// retained in memory. When exceeded, idle (fully-refilled) buckets are
|
||||||
|
// evicted; active buckets are never dropped. Default is 65536.
|
||||||
|
func WithMaxKeys(n int) RateLimitOption {
|
||||||
|
return func(o *rateLimitOptions) {
|
||||||
|
if n > 0 {
|
||||||
|
o.maxKeys = n
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// withRateLimitClock sets the clock for testing. Not exported.
|
// withRateLimitClock sets the clock for testing. Not exported.
|
||||||
func withRateLimitClock(c clock.Clock) RateLimitOption {
|
func withRateLimitClock(c clock.Clock) RateLimitOption {
|
||||||
return func(o *rateLimitOptions) { o.clock = c }
|
return func(o *rateLimitOptions) { o.clock = c }
|
||||||
@@ -44,86 +96,153 @@ func withRateLimitClock(c clock.Clock) RateLimitOption {
|
|||||||
// RateLimit returns a middleware that limits requests using a per-key token
|
// RateLimit returns a middleware that limits requests using a per-key token
|
||||||
// bucket algorithm. When the limit is exceeded, it returns 429 Too Many
|
// bucket algorithm. When the limit is exceeded, it returns 429 Too Many
|
||||||
// Requests with a Retry-After header.
|
// Requests with a Retry-After header.
|
||||||
|
//
|
||||||
|
// By default the key is the client IP taken from RemoteAddr. Forwarding
|
||||||
|
// headers (X-Forwarded-For) are honored only when WithTrustedProxies is set,
|
||||||
|
// so the limiter cannot be bypassed by spoofing headers.
|
||||||
func RateLimit(opts ...RateLimitOption) Middleware {
|
func RateLimit(opts ...RateLimitOption) Middleware {
|
||||||
o := &rateLimitOptions{
|
o := &rateLimitOptions{
|
||||||
rate: 10,
|
rate: 10,
|
||||||
burst: 20,
|
burst: 20,
|
||||||
clock: clock.System(),
|
clock: clock.System(),
|
||||||
|
maxKeys: defaultMaxKeys,
|
||||||
}
|
}
|
||||||
for _, opt := range opts {
|
for _, opt := range opts {
|
||||||
opt(o)
|
opt(o)
|
||||||
}
|
}
|
||||||
if o.keyFunc == nil {
|
if o.keyFunc == nil {
|
||||||
o.keyFunc = clientIP
|
o.keyFunc = o.clientKey
|
||||||
}
|
}
|
||||||
|
|
||||||
var buckets sync.Map
|
lim := &limiter{opts: o}
|
||||||
|
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
key := o.keyFunc(r)
|
key := o.keyFunc(r)
|
||||||
val, _ := buckets.LoadOrStore(key, &bucket{
|
if allowed, retryAfter := lim.allow(key); !allowed {
|
||||||
tokens: float64(o.burst),
|
|
||||||
lastTime: o.clock.Now(),
|
|
||||||
})
|
|
||||||
b := val.(*bucket)
|
|
||||||
|
|
||||||
b.mu.Lock()
|
|
||||||
now := o.clock.Now()
|
|
||||||
elapsed := now.Sub(b.lastTime).Seconds()
|
|
||||||
b.tokens += elapsed * o.rate
|
|
||||||
if b.tokens > float64(o.burst) {
|
|
||||||
b.tokens = float64(o.burst)
|
|
||||||
}
|
|
||||||
b.lastTime = now
|
|
||||||
|
|
||||||
if b.tokens < 1 {
|
|
||||||
retryAfter := (1 - b.tokens) / o.rate
|
|
||||||
b.mu.Unlock()
|
|
||||||
w.Header().Set("Retry-After", strconv.Itoa(int(retryAfter)+1))
|
w.Header().Set("Retry-After", strconv.Itoa(int(retryAfter)+1))
|
||||||
http.Error(w, "Too Many Requests", http.StatusTooManyRequests)
|
http.Error(w, "Too Many Requests", http.StatusTooManyRequests)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
b.tokens--
|
|
||||||
b.mu.Unlock()
|
|
||||||
|
|
||||||
next.ServeHTTP(w, r)
|
next.ServeHTTP(w, r)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// limiter holds the per-key token buckets for one RateLimit middleware.
|
||||||
|
type limiter struct {
|
||||||
|
opts *rateLimitOptions
|
||||||
|
buckets sync.Map // key -> *bucket
|
||||||
|
count atomic.Int64
|
||||||
|
sweeping atomic.Bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// allow reports whether a request for key may proceed. When denied it also
|
||||||
|
// returns the suggested Retry-After delay in seconds.
|
||||||
|
func (l *limiter) allow(key string) (bool, float64) {
|
||||||
|
o := l.opts
|
||||||
|
val, loaded := l.buckets.LoadOrStore(key, &bucket{
|
||||||
|
tokens: float64(o.burst),
|
||||||
|
lastTime: o.clock.Now(),
|
||||||
|
})
|
||||||
|
if !loaded && l.count.Add(1) > int64(o.maxKeys) {
|
||||||
|
l.sweep()
|
||||||
|
}
|
||||||
|
b := val.(*bucket)
|
||||||
|
|
||||||
|
b.mu.Lock()
|
||||||
|
defer b.mu.Unlock()
|
||||||
|
|
||||||
|
now := o.clock.Now()
|
||||||
|
elapsed := now.Sub(b.lastTime).Seconds()
|
||||||
|
b.tokens += elapsed * o.rate
|
||||||
|
if b.tokens > float64(o.burst) {
|
||||||
|
b.tokens = float64(o.burst)
|
||||||
|
}
|
||||||
|
b.lastTime = now
|
||||||
|
|
||||||
|
if b.tokens < 1 {
|
||||||
|
return false, (1 - b.tokens) / o.rate
|
||||||
|
}
|
||||||
|
b.tokens--
|
||||||
|
return true, 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// sweep removes fully-refilled (idle) buckets to bound memory. Only one sweep
|
||||||
|
// runs at a time; buckets that still hold a partial limit are preserved so
|
||||||
|
// that eviction can never reset an active client's allowance.
|
||||||
|
func (l *limiter) sweep() {
|
||||||
|
if !l.sweeping.CompareAndSwap(false, true) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer l.sweeping.Store(false)
|
||||||
|
|
||||||
|
o := l.opts
|
||||||
|
now := o.clock.Now()
|
||||||
|
l.buckets.Range(func(k, v any) bool {
|
||||||
|
b := v.(*bucket)
|
||||||
|
b.mu.Lock()
|
||||||
|
elapsed := now.Sub(b.lastTime).Seconds()
|
||||||
|
full := b.tokens+elapsed*o.rate >= float64(o.burst)
|
||||||
|
b.mu.Unlock()
|
||||||
|
if full {
|
||||||
|
l.buckets.Delete(k)
|
||||||
|
l.count.Add(-1)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
type bucket struct {
|
type bucket struct {
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
tokens float64
|
tokens float64
|
||||||
lastTime time.Time
|
lastTime time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
// clientIP extracts the client IP from the request. It checks
|
// clientKey derives the rate-limit key from a request. It uses RemoteAddr by
|
||||||
// X-Forwarded-For first, then X-Real-Ip, and falls back to RemoteAddr.
|
// default and only consults X-Forwarded-For when the peer is a configured
|
||||||
func clientIP(r *http.Request) string {
|
// trusted proxy (see WithTrustedProxies).
|
||||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
func (o *rateLimitOptions) clientKey(r *http.Request) string {
|
||||||
// First IP in the comma-separated list is the original client.
|
remote := remoteIP(r)
|
||||||
if i := indexOf(xff, ','); i > 0 {
|
if len(o.trustedProxies) == 0 || !o.isTrusted(remote) {
|
||||||
return xff[:i]
|
return remote
|
||||||
|
}
|
||||||
|
// Peer is trusted: walk X-Forwarded-For right-to-left and return the first
|
||||||
|
// address that is not itself a trusted proxy — that is the real client.
|
||||||
|
xff := r.Header.Get("X-Forwarded-For")
|
||||||
|
if xff == "" {
|
||||||
|
return remote
|
||||||
|
}
|
||||||
|
parts := strings.Split(xff, ",")
|
||||||
|
for i := len(parts) - 1; i >= 0; i-- {
|
||||||
|
ip := strings.TrimSpace(parts[i])
|
||||||
|
if ip == "" || o.isTrusted(ip) {
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
return xff
|
return ip
|
||||||
}
|
}
|
||||||
if xri := r.Header.Get("X-Real-Ip"); xri != "" {
|
return remote
|
||||||
return xri
|
}
|
||||||
|
|
||||||
|
func (o *rateLimitOptions) isTrusted(ip string) bool {
|
||||||
|
parsed := net.ParseIP(ip)
|
||||||
|
if parsed == nil {
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
for _, n := range o.trustedProxies {
|
||||||
|
if n.Contains(parsed) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// remoteIP returns the host portion of r.RemoteAddr, or the raw value if it
|
||||||
|
// has no port.
|
||||||
|
func remoteIP(r *http.Request) string {
|
||||||
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return r.RemoteAddr
|
return r.RemoteAddr
|
||||||
}
|
}
|
||||||
return host
|
return host
|
||||||
}
|
}
|
||||||
|
|
||||||
func indexOf(s string, b byte) int {
|
|
||||||
for i := range len(s) {
|
|
||||||
if s[i] == b {
|
|
||||||
return i
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return -1
|
|
||||||
}
|
|
||||||
|
|||||||
62
server/middleware_ratelimit_internal_test.go
Normal file
62
server/middleware_ratelimit_internal_test.go
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.codelab.vc/pkg/httpx/internal/clock"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newTestRequest(remoteAddr, xff string) *http.Request {
|
||||||
|
r := &http.Request{RemoteAddr: remoteAddr, Header: http.Header{}}
|
||||||
|
if xff != "" {
|
||||||
|
r.Header.Set("X-Forwarded-For", xff)
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestLimiterSweepEvictsIdleBuckets verifies that sweep removes fully-refilled
|
||||||
|
// (idle) buckets while preserving buckets that still hold an active limit, so
|
||||||
|
// memory is bounded without resetting live clients' allowances.
|
||||||
|
func TestLimiterSweepEvictsIdleBuckets(t *testing.T) {
|
||||||
|
clk := clock.Mock(time.Now())
|
||||||
|
o := &rateLimitOptions{rate: 1, burst: 5, clock: clk, maxKeys: 1 << 30}
|
||||||
|
lim := &limiter{opts: o}
|
||||||
|
|
||||||
|
// "idle" makes a single request, then time passes so it refills to full.
|
||||||
|
lim.allow("idle")
|
||||||
|
clk.Advance(10 * time.Second)
|
||||||
|
|
||||||
|
// "active" drains its whole burst at the (advanced) current time.
|
||||||
|
for i := 0; i < 6; i++ {
|
||||||
|
lim.allow("active")
|
||||||
|
}
|
||||||
|
|
||||||
|
lim.sweep()
|
||||||
|
|
||||||
|
if _, ok := lim.buckets.Load("idle"); ok {
|
||||||
|
t.Error("fully-refilled idle bucket was not evicted")
|
||||||
|
}
|
||||||
|
if _, ok := lim.buckets.Load("active"); !ok {
|
||||||
|
t.Error("active bucket with a partial limit was wrongly evicted")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestClientKeyTrustedProxy exercises the X-Forwarded-For walk used behind a
|
||||||
|
// trusted proxy, independent of the HTTP layer.
|
||||||
|
func TestClientKeyTrustedProxy(t *testing.T) {
|
||||||
|
o := &rateLimitOptions{}
|
||||||
|
WithTrustedProxies("192.168.0.0/16")(o)
|
||||||
|
|
||||||
|
r := newTestRequest("192.168.1.10:443", "203.0.113.7, 192.168.1.10")
|
||||||
|
if got := o.clientKey(r); got != "203.0.113.7" {
|
||||||
|
t.Fatalf("clientKey = %q, want real client 203.0.113.7", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Untrusted peer: X-Forwarded-For must be ignored entirely.
|
||||||
|
r = newTestRequest("203.0.113.7:443", "10.0.0.1")
|
||||||
|
if got := o.clientKey(r); got != "203.0.113.7" {
|
||||||
|
t.Fatalf("clientKey = %q, want peer 203.0.113.7 (XFF ignored)", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -83,28 +83,59 @@ func TestRateLimit(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("uses X-Forwarded-For", func(t *testing.T) {
|
t.Run("ignores X-Forwarded-For without trusted proxies", func(t *testing.T) {
|
||||||
|
// By default the limiter keys on RemoteAddr only. A spoofed,
|
||||||
|
// per-request X-Forwarded-For must not let a single peer bypass the
|
||||||
|
// limit by minting a fresh bucket each time.
|
||||||
mw := server.RateLimit(
|
mw := server.RateLimit(
|
||||||
server.WithRate(1),
|
server.WithRate(1),
|
||||||
server.WithBurst(1),
|
server.WithBurst(1),
|
||||||
)(okHandler)
|
)(okHandler)
|
||||||
|
|
||||||
// Exhaust limit for 10.0.0.1.
|
send := func(xff string) int {
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
req.Header.Set("X-Forwarded-For", "10.0.0.1, 192.168.1.1")
|
req.Header.Set("X-Forwarded-For", xff)
|
||||||
req.RemoteAddr = "192.168.1.1:1234"
|
req.RemoteAddr = "192.168.1.1:1234"
|
||||||
mw.ServeHTTP(w, req)
|
mw.ServeHTTP(w, req)
|
||||||
|
return w.Code
|
||||||
|
}
|
||||||
|
|
||||||
// Same forwarded IP should be rate limited.
|
if code := send("10.0.0.1"); code != http.StatusOK {
|
||||||
w = httptest.NewRecorder()
|
t.Fatalf("first request: got %d, want %d", code, http.StatusOK)
|
||||||
req = httptest.NewRequest(http.MethodGet, "/", nil)
|
}
|
||||||
req.Header.Set("X-Forwarded-For", "10.0.0.1, 192.168.1.1")
|
// Different spoofed XFF, same peer — must still be limited.
|
||||||
req.RemoteAddr = "192.168.1.1:1234"
|
if code := send("10.0.0.2"); code != http.StatusTooManyRequests {
|
||||||
mw.ServeHTTP(w, req)
|
t.Fatalf("spoofed XFF bypassed limit: got %d, want %d", code, http.StatusTooManyRequests)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
if w.Code != http.StatusTooManyRequests {
|
t.Run("honors X-Forwarded-For behind trusted proxy", func(t *testing.T) {
|
||||||
t.Fatalf("got status %d, want %d", w.Code, http.StatusTooManyRequests)
|
mw := server.RateLimit(
|
||||||
|
server.WithRate(1),
|
||||||
|
server.WithBurst(1),
|
||||||
|
server.WithTrustedProxies("192.168.0.0/16"),
|
||||||
|
)(okHandler)
|
||||||
|
|
||||||
|
send := func(xff string) int {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.Header.Set("X-Forwarded-For", xff)
|
||||||
|
req.RemoteAddr = "192.168.1.1:1234" // trusted proxy
|
||||||
|
mw.ServeHTTP(w, req)
|
||||||
|
return w.Code
|
||||||
|
}
|
||||||
|
|
||||||
|
// Real client 10.0.0.1 (left-most), proxy hop 192.168.1.1 (right-most).
|
||||||
|
if code := send("10.0.0.1, 192.168.1.1"); code != http.StatusOK {
|
||||||
|
t.Fatalf("first request: got %d, want %d", code, http.StatusOK)
|
||||||
|
}
|
||||||
|
if code := send("10.0.0.1, 192.168.1.1"); code != http.StatusTooManyRequests {
|
||||||
|
t.Fatalf("same client not limited: got %d, want %d", code, http.StatusTooManyRequests)
|
||||||
|
}
|
||||||
|
// A different real client through the same proxy is independent.
|
||||||
|
if code := send("10.0.0.2, 192.168.1.1"); code != http.StatusOK {
|
||||||
|
t.Fatalf("different client should be allowed: got %d, want %d", code, http.StatusOK)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user