package server import ( "net" "net/http" "strconv" "sync" "time" "git.codelab.vc/pkg/httpx/internal/clock" ) type rateLimitOptions struct { rate float64 burst int keyFunc func(r *http.Request) string clock clock.Clock } // RateLimitOption configures the RateLimit middleware. type RateLimitOption func(*rateLimitOptions) // WithRate sets the token refill rate (tokens per second). func WithRate(tokensPerSecond float64) RateLimitOption { return func(o *rateLimitOptions) { o.rate = tokensPerSecond } } // WithBurst sets the maximum burst size (bucket capacity). func WithBurst(n int) RateLimitOption { return func(o *rateLimitOptions) { o.burst = n } } // WithKeyFunc sets a custom function to extract the rate-limit key from a // request. By default, the client IP address is used. func WithKeyFunc(fn func(r *http.Request) string) RateLimitOption { return func(o *rateLimitOptions) { o.keyFunc = fn } } // withRateLimitClock sets the clock for testing. Not exported. func withRateLimitClock(c clock.Clock) RateLimitOption { return func(o *rateLimitOptions) { o.clock = c } } // RateLimit returns a middleware that limits requests using a per-key token // bucket algorithm. When the limit is exceeded, it returns 429 Too Many // Requests with a Retry-After header. func RateLimit(opts ...RateLimitOption) Middleware { o := &rateLimitOptions{ rate: 10, burst: 20, clock: clock.System(), } for _, opt := range opts { opt(o) } if o.keyFunc == nil { o.keyFunc = clientIP } var buckets sync.Map return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { key := o.keyFunc(r) val, _ := buckets.LoadOrStore(key, &bucket{ tokens: float64(o.burst), lastTime: o.clock.Now(), }) b := val.(*bucket) b.mu.Lock() now := o.clock.Now() elapsed := now.Sub(b.lastTime).Seconds() b.tokens += elapsed * o.rate if b.tokens > float64(o.burst) { b.tokens = float64(o.burst) } b.lastTime = now if b.tokens < 1 { retryAfter := (1 - b.tokens) / o.rate b.mu.Unlock() w.Header().Set("Retry-After", strconv.Itoa(int(retryAfter)+1)) http.Error(w, "Too Many Requests", http.StatusTooManyRequests) return } b.tokens-- b.mu.Unlock() next.ServeHTTP(w, r) }) } } type bucket struct { mu sync.Mutex tokens float64 lastTime time.Time } // clientIP extracts the client IP from the request. It checks // X-Forwarded-For first, then X-Real-Ip, and falls back to RemoteAddr. func clientIP(r *http.Request) string { if xff := r.Header.Get("X-Forwarded-For"); xff != "" { // First IP in the comma-separated list is the original client. if i := indexOf(xff, ','); i > 0 { return xff[:i] } return xff } if xri := r.Header.Get("X-Real-Ip"); xri != "" { return xri } host, _, err := net.SplitHostPort(r.RemoteAddr) if err != nil { return r.RemoteAddr } return host } func indexOf(s string, b byte) int { for i := range len(s) { if s[i] == b { return i } } return -1 }