diff --git a/server/middleware_ratelimit.go b/server/middleware_ratelimit.go new file mode 100644 index 0000000..a95614a --- /dev/null +++ b/server/middleware_ratelimit.go @@ -0,0 +1,129 @@ +package server + +import ( + "net" + "net/http" + "strconv" + "sync" + "time" + + "git.codelab.vc/pkg/httpx/internal/clock" +) + +type rateLimitOptions struct { + rate float64 + burst int + keyFunc func(r *http.Request) string + clock clock.Clock +} + +// RateLimitOption configures the RateLimit middleware. +type RateLimitOption func(*rateLimitOptions) + +// WithRate sets the token refill rate (tokens per second). +func WithRate(tokensPerSecond float64) RateLimitOption { + return func(o *rateLimitOptions) { o.rate = tokensPerSecond } +} + +// WithBurst sets the maximum burst size (bucket capacity). +func WithBurst(n int) RateLimitOption { + return func(o *rateLimitOptions) { o.burst = n } +} + +// WithKeyFunc sets a custom function to extract the rate-limit key from a +// request. By default, the client IP address is used. +func WithKeyFunc(fn func(r *http.Request) string) RateLimitOption { + return func(o *rateLimitOptions) { o.keyFunc = fn } +} + +// withRateLimitClock sets the clock for testing. Not exported. +func withRateLimitClock(c clock.Clock) RateLimitOption { + return func(o *rateLimitOptions) { o.clock = c } +} + +// RateLimit returns a middleware that limits requests using a per-key token +// bucket algorithm. When the limit is exceeded, it returns 429 Too Many +// Requests with a Retry-After header. +func RateLimit(opts ...RateLimitOption) Middleware { + o := &rateLimitOptions{ + rate: 10, + burst: 20, + clock: clock.System(), + } + for _, opt := range opts { + opt(o) + } + if o.keyFunc == nil { + o.keyFunc = clientIP + } + + var buckets sync.Map + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + key := o.keyFunc(r) + val, _ := buckets.LoadOrStore(key, &bucket{ + tokens: float64(o.burst), + lastTime: o.clock.Now(), + }) + b := val.(*bucket) + + b.mu.Lock() + now := o.clock.Now() + elapsed := now.Sub(b.lastTime).Seconds() + b.tokens += elapsed * o.rate + if b.tokens > float64(o.burst) { + b.tokens = float64(o.burst) + } + b.lastTime = now + + if b.tokens < 1 { + retryAfter := (1 - b.tokens) / o.rate + b.mu.Unlock() + w.Header().Set("Retry-After", strconv.Itoa(int(retryAfter)+1)) + http.Error(w, "Too Many Requests", http.StatusTooManyRequests) + return + } + + b.tokens-- + b.mu.Unlock() + + next.ServeHTTP(w, r) + }) + } +} + +type bucket struct { + mu sync.Mutex + tokens float64 + lastTime time.Time +} + +// clientIP extracts the client IP from the request. It checks +// X-Forwarded-For first, then X-Real-Ip, and falls back to RemoteAddr. +func clientIP(r *http.Request) string { + if xff := r.Header.Get("X-Forwarded-For"); xff != "" { + // First IP in the comma-separated list is the original client. + if i := indexOf(xff, ','); i > 0 { + return xff[:i] + } + return xff + } + if xri := r.Header.Get("X-Real-Ip"); xri != "" { + return xri + } + host, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + return r.RemoteAddr + } + return host +} + +func indexOf(s string, b byte) int { + for i := range len(s) { + if s[i] == b { + return i + } + } + return -1 +} diff --git a/server/middleware_ratelimit_test.go b/server/middleware_ratelimit_test.go new file mode 100644 index 0000000..dd3bffe --- /dev/null +++ b/server/middleware_ratelimit_test.go @@ -0,0 +1,171 @@ +package server_test + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" + + "git.codelab.vc/pkg/httpx/server" +) + +func TestRateLimit(t *testing.T) { + okHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + t.Run("allows requests within limit", func(t *testing.T) { + mw := server.RateLimit( + server.WithRate(100), + server.WithBurst(10), + )(okHandler) + + for i := range 10 { + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.RemoteAddr = "1.2.3.4:1234" + mw.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("request %d: got status %d, want %d", i, w.Code, http.StatusOK) + } + } + }) + + t.Run("rejects when burst exhausted", func(t *testing.T) { + mw := server.RateLimit( + server.WithRate(1), + server.WithBurst(2), + )(okHandler) + + // Exhaust burst. + for range 2 { + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.RemoteAddr = "1.2.3.4:1234" + mw.ServeHTTP(w, req) + } + + // Next request should be rejected. + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.RemoteAddr = "1.2.3.4:1234" + mw.ServeHTTP(w, req) + + if w.Code != http.StatusTooManyRequests { + t.Fatalf("got status %d, want %d", w.Code, http.StatusTooManyRequests) + } + if w.Header().Get("Retry-After") == "" { + t.Fatal("expected Retry-After header") + } + }) + + t.Run("different IPs have independent limits", func(t *testing.T) { + mw := server.RateLimit( + server.WithRate(1), + server.WithBurst(1), + )(okHandler) + + // First IP exhausts its limit. + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.RemoteAddr = "1.2.3.4:1234" + mw.ServeHTTP(w, req) + + // Second IP should still be allowed. + w = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/", nil) + req.RemoteAddr = "5.6.7.8:5678" + mw.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("got status %d, want %d", w.Code, http.StatusOK) + } + }) + + t.Run("uses X-Forwarded-For", func(t *testing.T) { + mw := server.RateLimit( + server.WithRate(1), + server.WithBurst(1), + )(okHandler) + + // Exhaust limit for 10.0.0.1. + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("X-Forwarded-For", "10.0.0.1, 192.168.1.1") + req.RemoteAddr = "192.168.1.1:1234" + mw.ServeHTTP(w, req) + + // Same forwarded IP should be rate limited. + w = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("X-Forwarded-For", "10.0.0.1, 192.168.1.1") + req.RemoteAddr = "192.168.1.1:1234" + mw.ServeHTTP(w, req) + + if w.Code != http.StatusTooManyRequests { + t.Fatalf("got status %d, want %d", w.Code, http.StatusTooManyRequests) + } + }) + + t.Run("custom key function", func(t *testing.T) { + mw := server.RateLimit( + server.WithRate(1), + server.WithBurst(1), + server.WithKeyFunc(func(r *http.Request) string { + return r.Header.Get("X-API-Key") + }), + )(okHandler) + + // Exhaust key "abc". + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("X-API-Key", "abc") + mw.ServeHTTP(w, req) + + // Same key should be rate limited. + w = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("X-API-Key", "abc") + mw.ServeHTTP(w, req) + + if w.Code != http.StatusTooManyRequests { + t.Fatalf("got status %d, want %d", w.Code, http.StatusTooManyRequests) + } + + // Different key should be allowed. + w = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("X-API-Key", "xyz") + mw.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("got status %d, want %d", w.Code, http.StatusOK) + } + }) + + t.Run("tokens refill over time", func(t *testing.T) { + mw := server.RateLimit( + server.WithRate(1000), // Very fast refill for test + server.WithBurst(1), + )(okHandler) + + // Exhaust burst. + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.RemoteAddr = "1.2.3.4:1234" + mw.ServeHTTP(w, req) + + // Wait a bit for refill. + time.Sleep(5 * time.Millisecond) + + w = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/", nil) + req.RemoteAddr = "1.2.3.4:1234" + mw.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("got status %d after refill, want %d", w.Code, http.StatusOK) + } + }) +}