Key on RemoteAddr by default; honor X-Forwarded-For only when the peer is a configured trusted proxy (WithTrustedProxies), walking right-to-left to the first untrusted hop. This closes a trivial rate-limit bypass and the matching unbounded-bucket DoS via spoofed headers. Add WithMaxKeys with opportunistic eviction of idle (fully-refilled) buckets to bound memory. Drop the hand-rolled indexOf in favor of stdlib.
249 lines
6.8 KiB
Go
249 lines
6.8 KiB
Go
package server
|
|
|
|
import (
|
|
"net"
|
|
"net/http"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"git.codelab.vc/pkg/httpx/internal/clock"
|
|
)
|
|
|
|
// defaultMaxKeys is the default soft cap on how many distinct rate-limit
// buckets one RateLimit middleware keeps in memory. Once the cap is exceeded,
// buckets that have fully refilled (idle buckets) become eligible for eviction.
const defaultMaxKeys = 65536 // 1 << 16
|
|
|
|
type rateLimitOptions struct {
|
|
rate float64
|
|
burst int
|
|
keyFunc func(r *http.Request) string
|
|
clock clock.Clock
|
|
trustedProxies []*net.IPNet
|
|
maxKeys int
|
|
}
|
|
|
|
// RateLimitOption configures the RateLimit middleware.
|
|
type RateLimitOption func(*rateLimitOptions)
|
|
|
|
// WithRate sets the token refill rate (tokens per second).
|
|
func WithRate(tokensPerSecond float64) RateLimitOption {
|
|
return func(o *rateLimitOptions) { o.rate = tokensPerSecond }
|
|
}
|
|
|
|
// WithBurst sets the maximum burst size (bucket capacity).
|
|
func WithBurst(n int) RateLimitOption {
|
|
return func(o *rateLimitOptions) { o.burst = n }
|
|
}
|
|
|
|
// WithKeyFunc sets a custom function to extract the rate-limit key from a
|
|
// request. By default, the client IP from RemoteAddr is used (see
|
|
// WithTrustedProxies to honor X-Forwarded-For behind a trusted proxy).
|
|
func WithKeyFunc(fn func(r *http.Request) string) RateLimitOption {
|
|
return func(o *rateLimitOptions) { o.keyFunc = fn }
|
|
}
|
|
|
|
// WithTrustedProxies enables X-Forwarded-For parsing, but only for requests
|
|
// whose immediate peer (RemoteAddr) falls within one of the given trusted
|
|
// CIDR ranges (e.g. "10.0.0.0/8", "192.168.0.0/16"). A bare IP is accepted as
|
|
// a /32 or /128. When the peer is trusted, the client key is taken from the
|
|
// right-most X-Forwarded-For entry that is not itself a trusted proxy;
|
|
// otherwise RemoteAddr is used. Invalid entries are ignored (treated as
|
|
// untrusted), so a typo can never silently widen trust.
|
|
//
|
|
// Without this option the middleware never trusts client-supplied forwarding
|
|
// headers, which prevents trivial rate-limit bypass and bucket exhaustion via
|
|
// spoofed headers.
|
|
func WithTrustedProxies(cidrs ...string) RateLimitOption {
|
|
return func(o *rateLimitOptions) {
|
|
for _, c := range cidrs {
|
|
if _, ipnet, err := net.ParseCIDR(c); err == nil {
|
|
o.trustedProxies = append(o.trustedProxies, ipnet)
|
|
continue
|
|
}
|
|
if ip := net.ParseIP(c); ip != nil {
|
|
bits := 32
|
|
if ip.To4() == nil {
|
|
bits = 128
|
|
}
|
|
o.trustedProxies = append(o.trustedProxies, &net.IPNet{
|
|
IP: ip,
|
|
Mask: net.CIDRMask(bits, bits),
|
|
})
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// WithMaxKeys sets the soft upper bound on the number of distinct buckets
|
|
// retained in memory. When exceeded, idle (fully-refilled) buckets are
|
|
// evicted; active buckets are never dropped. Default is 65536.
|
|
func WithMaxKeys(n int) RateLimitOption {
|
|
return func(o *rateLimitOptions) {
|
|
if n > 0 {
|
|
o.maxKeys = n
|
|
}
|
|
}
|
|
}
|
|
|
|
// withRateLimitClock sets the clock for testing. Not exported.
|
|
func withRateLimitClock(c clock.Clock) RateLimitOption {
|
|
return func(o *rateLimitOptions) { o.clock = c }
|
|
}
|
|
|
|
// RateLimit returns a middleware that limits requests using a per-key token
|
|
// bucket algorithm. When the limit is exceeded, it returns 429 Too Many
|
|
// Requests with a Retry-After header.
|
|
//
|
|
// By default the key is the client IP taken from RemoteAddr. Forwarding
|
|
// headers (X-Forwarded-For) are honored only when WithTrustedProxies is set,
|
|
// so the limiter cannot be bypassed by spoofing headers.
|
|
func RateLimit(opts ...RateLimitOption) Middleware {
|
|
o := &rateLimitOptions{
|
|
rate: 10,
|
|
burst: 20,
|
|
clock: clock.System(),
|
|
maxKeys: defaultMaxKeys,
|
|
}
|
|
for _, opt := range opts {
|
|
opt(o)
|
|
}
|
|
if o.keyFunc == nil {
|
|
o.keyFunc = o.clientKey
|
|
}
|
|
|
|
lim := &limiter{opts: o}
|
|
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
key := o.keyFunc(r)
|
|
if allowed, retryAfter := lim.allow(key); !allowed {
|
|
w.Header().Set("Retry-After", strconv.Itoa(int(retryAfter)+1))
|
|
http.Error(w, "Too Many Requests", http.StatusTooManyRequests)
|
|
return
|
|
}
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
}
|
|
|
|
// limiter holds the per-key token buckets for one RateLimit middleware.
|
|
type limiter struct {
|
|
opts *rateLimitOptions
|
|
buckets sync.Map // key -> *bucket
|
|
count atomic.Int64
|
|
sweeping atomic.Bool
|
|
}
|
|
|
|
// allow reports whether a request for key may proceed. When denied it also
|
|
// returns the suggested Retry-After delay in seconds.
|
|
func (l *limiter) allow(key string) (bool, float64) {
|
|
o := l.opts
|
|
val, loaded := l.buckets.LoadOrStore(key, &bucket{
|
|
tokens: float64(o.burst),
|
|
lastTime: o.clock.Now(),
|
|
})
|
|
if !loaded && l.count.Add(1) > int64(o.maxKeys) {
|
|
l.sweep()
|
|
}
|
|
b := val.(*bucket)
|
|
|
|
b.mu.Lock()
|
|
defer b.mu.Unlock()
|
|
|
|
now := o.clock.Now()
|
|
elapsed := now.Sub(b.lastTime).Seconds()
|
|
b.tokens += elapsed * o.rate
|
|
if b.tokens > float64(o.burst) {
|
|
b.tokens = float64(o.burst)
|
|
}
|
|
b.lastTime = now
|
|
|
|
if b.tokens < 1 {
|
|
return false, (1 - b.tokens) / o.rate
|
|
}
|
|
b.tokens--
|
|
return true, 0
|
|
}
|
|
|
|
// sweep removes fully-refilled (idle) buckets to bound memory. Only one sweep
|
|
// runs at a time; buckets that still hold a partial limit are preserved so
|
|
// that eviction can never reset an active client's allowance.
|
|
func (l *limiter) sweep() {
|
|
if !l.sweeping.CompareAndSwap(false, true) {
|
|
return
|
|
}
|
|
defer l.sweeping.Store(false)
|
|
|
|
o := l.opts
|
|
now := o.clock.Now()
|
|
l.buckets.Range(func(k, v any) bool {
|
|
b := v.(*bucket)
|
|
b.mu.Lock()
|
|
elapsed := now.Sub(b.lastTime).Seconds()
|
|
full := b.tokens+elapsed*o.rate >= float64(o.burst)
|
|
b.mu.Unlock()
|
|
if full {
|
|
l.buckets.Delete(k)
|
|
l.count.Add(-1)
|
|
}
|
|
return true
|
|
})
|
|
}
|
|
|
|
// bucket is the token-bucket state for a single rate-limit key.
type bucket struct {
	mu       sync.Mutex // guards tokens and lastTime
	tokens   float64    // tokens currently available, never above burst
	lastTime time.Time  // when tokens was last refilled
}
|
|
|
|
// clientKey derives the rate-limit key from a request. It uses RemoteAddr by
|
|
// default and only consults X-Forwarded-For when the peer is a configured
|
|
// trusted proxy (see WithTrustedProxies).
|
|
func (o *rateLimitOptions) clientKey(r *http.Request) string {
|
|
remote := remoteIP(r)
|
|
if len(o.trustedProxies) == 0 || !o.isTrusted(remote) {
|
|
return remote
|
|
}
|
|
// Peer is trusted: walk X-Forwarded-For right-to-left and return the first
|
|
// address that is not itself a trusted proxy — that is the real client.
|
|
xff := r.Header.Get("X-Forwarded-For")
|
|
if xff == "" {
|
|
return remote
|
|
}
|
|
parts := strings.Split(xff, ",")
|
|
for i := len(parts) - 1; i >= 0; i-- {
|
|
ip := strings.TrimSpace(parts[i])
|
|
if ip == "" || o.isTrusted(ip) {
|
|
continue
|
|
}
|
|
return ip
|
|
}
|
|
return remote
|
|
}
|
|
|
|
func (o *rateLimitOptions) isTrusted(ip string) bool {
|
|
parsed := net.ParseIP(ip)
|
|
if parsed == nil {
|
|
return false
|
|
}
|
|
for _, n := range o.trustedProxies {
|
|
if n.Contains(parsed) {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// remoteIP strips the port from r.RemoteAddr and returns the host part. If
// RemoteAddr cannot be split into host and port, it is returned verbatim.
func remoteIP(r *http.Request) string {
	if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
		return host
	}
	return r.RemoteAddr
}
|