package server import ( "net" "net/http" "strconv" "strings" "sync" "sync/atomic" "time" "git.codelab.vc/pkg/httpx/internal/clock" ) // defaultMaxKeys bounds the number of distinct rate-limit buckets retained in // memory. When exceeded, fully-refilled (idle) buckets are evicted. const defaultMaxKeys = 1 << 16 type rateLimitOptions struct { rate float64 burst int keyFunc func(r *http.Request) string clock clock.Clock trustedProxies []*net.IPNet maxKeys int } // RateLimitOption configures the RateLimit middleware. type RateLimitOption func(*rateLimitOptions) // WithRate sets the token refill rate (tokens per second). func WithRate(tokensPerSecond float64) RateLimitOption { return func(o *rateLimitOptions) { o.rate = tokensPerSecond } } // WithBurst sets the maximum burst size (bucket capacity). func WithBurst(n int) RateLimitOption { return func(o *rateLimitOptions) { o.burst = n } } // WithKeyFunc sets a custom function to extract the rate-limit key from a // request. By default, the client IP from RemoteAddr is used (see // WithTrustedProxies to honor X-Forwarded-For behind a trusted proxy). func WithKeyFunc(fn func(r *http.Request) string) RateLimitOption { return func(o *rateLimitOptions) { o.keyFunc = fn } } // WithTrustedProxies enables X-Forwarded-For parsing, but only for requests // whose immediate peer (RemoteAddr) falls within one of the given trusted // CIDR ranges (e.g. "10.0.0.0/8", "192.168.0.0/16"). A bare IP is accepted as // a /32 or /128. When the peer is trusted, the client key is taken from the // right-most X-Forwarded-For entry that is not itself a trusted proxy; // otherwise RemoteAddr is used. Invalid entries are ignored (treated as // untrusted), so a typo can never silently widen trust. // // Without this option the middleware never trusts client-supplied forwarding // headers, which prevents trivial rate-limit bypass and bucket exhaustion via // spoofed headers. func WithTrustedProxies(cidrs ...string) RateLimitOption { return func(o *rateLimitOptions) { for _, c := range cidrs { if _, ipnet, err := net.ParseCIDR(c); err == nil { o.trustedProxies = append(o.trustedProxies, ipnet) continue } if ip := net.ParseIP(c); ip != nil { bits := 32 if ip.To4() == nil { bits = 128 } o.trustedProxies = append(o.trustedProxies, &net.IPNet{ IP: ip, Mask: net.CIDRMask(bits, bits), }) } } } } // WithMaxKeys sets the soft upper bound on the number of distinct buckets // retained in memory. When exceeded, idle (fully-refilled) buckets are // evicted; active buckets are never dropped. Default is 65536. func WithMaxKeys(n int) RateLimitOption { return func(o *rateLimitOptions) { if n > 0 { o.maxKeys = n } } } // withRateLimitClock sets the clock for testing. Not exported. func withRateLimitClock(c clock.Clock) RateLimitOption { return func(o *rateLimitOptions) { o.clock = c } } // RateLimit returns a middleware that limits requests using a per-key token // bucket algorithm. When the limit is exceeded, it returns 429 Too Many // Requests with a Retry-After header. // // By default the key is the client IP taken from RemoteAddr. Forwarding // headers (X-Forwarded-For) are honored only when WithTrustedProxies is set, // so the limiter cannot be bypassed by spoofing headers. func RateLimit(opts ...RateLimitOption) Middleware { o := &rateLimitOptions{ rate: 10, burst: 20, clock: clock.System(), maxKeys: defaultMaxKeys, } for _, opt := range opts { opt(o) } if o.keyFunc == nil { o.keyFunc = o.clientKey } lim := &limiter{opts: o} return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { key := o.keyFunc(r) if allowed, retryAfter := lim.allow(key); !allowed { w.Header().Set("Retry-After", strconv.Itoa(int(retryAfter)+1)) http.Error(w, "Too Many Requests", http.StatusTooManyRequests) return } next.ServeHTTP(w, r) }) } } // limiter holds the per-key token buckets for one RateLimit middleware. type limiter struct { opts *rateLimitOptions buckets sync.Map // key -> *bucket count atomic.Int64 sweeping atomic.Bool } // allow reports whether a request for key may proceed. When denied it also // returns the suggested Retry-After delay in seconds. func (l *limiter) allow(key string) (bool, float64) { o := l.opts val, loaded := l.buckets.LoadOrStore(key, &bucket{ tokens: float64(o.burst), lastTime: o.clock.Now(), }) if !loaded && l.count.Add(1) > int64(o.maxKeys) { l.sweep() } b := val.(*bucket) b.mu.Lock() defer b.mu.Unlock() now := o.clock.Now() elapsed := now.Sub(b.lastTime).Seconds() b.tokens += elapsed * o.rate if b.tokens > float64(o.burst) { b.tokens = float64(o.burst) } b.lastTime = now if b.tokens < 1 { return false, (1 - b.tokens) / o.rate } b.tokens-- return true, 0 } // sweep removes fully-refilled (idle) buckets to bound memory. Only one sweep // runs at a time; buckets that still hold a partial limit are preserved so // that eviction can never reset an active client's allowance. func (l *limiter) sweep() { if !l.sweeping.CompareAndSwap(false, true) { return } defer l.sweeping.Store(false) o := l.opts now := o.clock.Now() l.buckets.Range(func(k, v any) bool { b := v.(*bucket) b.mu.Lock() elapsed := now.Sub(b.lastTime).Seconds() full := b.tokens+elapsed*o.rate >= float64(o.burst) b.mu.Unlock() if full { l.buckets.Delete(k) l.count.Add(-1) } return true }) } type bucket struct { mu sync.Mutex tokens float64 lastTime time.Time } // clientKey derives the rate-limit key from a request. It uses RemoteAddr by // default and only consults X-Forwarded-For when the peer is a configured // trusted proxy (see WithTrustedProxies). func (o *rateLimitOptions) clientKey(r *http.Request) string { remote := remoteIP(r) if len(o.trustedProxies) == 0 || !o.isTrusted(remote) { return remote } // Peer is trusted: walk X-Forwarded-For right-to-left and return the first // address that is not itself a trusted proxy — that is the real client. xff := r.Header.Get("X-Forwarded-For") if xff == "" { return remote } parts := strings.Split(xff, ",") for i := len(parts) - 1; i >= 0; i-- { ip := strings.TrimSpace(parts[i]) if ip == "" || o.isTrusted(ip) { continue } return ip } return remote } func (o *rateLimitOptions) isTrusted(ip string) bool { parsed := net.ParseIP(ip) if parsed == nil { return false } for _, n := range o.trustedProxies { if n.Contains(parsed) { return true } } return false } // remoteIP returns the host portion of r.RemoteAddr, or the raw value if it // has no port. func remoteIP(r *http.Request) string { host, _, err := net.SplitHostPort(r.RemoteAddr) if err != nil { return r.RemoteAddr } return host }