Compare commits
7 Commits
b6350185d9
...
f609b12c2f
| Author | SHA1 | Date | |
|---|---|---|---|
| f609b12c2f | |||
| b5259af73e | |||
| 01478be0dc | |||
| b07d487e63 | |||
| 43d3ecfba1 | |||
| e8c4577c6f | |||
| 2d4a06e715 |
@@ -24,7 +24,7 @@ go vet ./... # static analysis
|
||||
- **Client.Close()** stops the health checker goroutine
|
||||
- **Client.Patch()** — PATCH method, same pattern as Put/Post
|
||||
- **NewFormRequest** — form-encoded request builder (`application/x-www-form-urlencoded`) with `GetBody` for retry
|
||||
- **WithMaxResponseBody** — wraps `resp.Body` with `io.LimitedReader` to prevent OOM
|
||||
- **WithMaxResponseBody** — caps `resp.Body` reads; returns `ErrResponseTooLarge` (not silent truncation) when exceeded
|
||||
- **middleware.RequestID()** — propagates request ID from context to outgoing `X-Request-Id` header
|
||||
- **`internal/requestid`** — shared context key used by both `server` and `middleware` packages to avoid circular imports
|
||||
|
||||
@@ -37,7 +37,7 @@ go vet ./... # static analysis
|
||||
- **Defaults()** preset: RequestID → Recovery → Logging + production timeouts
|
||||
- **HealthHandler** exposes `GET /healthz` (liveness) and `GET /readyz` (readiness with pluggable checkers)
|
||||
- **CORS** middleware — preflight OPTIONS handling, `AllowOrigins`, `AllowMethods`, `AllowHeaders`, `ExposeHeaders`, `AllowCredentials`, `MaxAge`
|
||||
- **RateLimit** middleware — per-key token bucket (`sync.Map`), IP from `X-Forwarded-For`, `WithRate`/`WithBurst`/`WithKeyFunc`, uses `internal/clock`
|
||||
- **RateLimit** middleware — per-key token bucket (`sync.Map`), keys on `RemoteAddr` by default; `X-Forwarded-For` is honored only via `WithTrustedProxies`; `WithRate`/`WithBurst`/`WithKeyFunc`/`WithMaxKeys`, uses `internal/clock`, idle buckets evicted to bound memory
|
||||
- **MaxBodySize** middleware — wraps `r.Body` via `http.MaxBytesReader`
|
||||
- **Timeout** middleware — wraps `http.TimeoutHandler`, returns 503
|
||||
- **WriteJSON** / **WriteError** — JSON response helpers in `server/respond.go`
|
||||
|
||||
@@ -67,7 +67,7 @@ Server middleware is `func(http.Handler) http.Handler`. The `server` package pro
|
||||
| `server.Logging` | Structured request logging (method, path, status, duration, request ID). |
|
||||
| `server.HealthHandler` | Liveness (`/healthz`) and readiness (`/readyz`) endpoints with pluggable checkers. |
|
||||
| `server.CORS` | Cross-origin resource sharing with preflight handling and functional options. |
|
||||
| `server.RateLimit` | Per-key token bucket rate limiting with IP extraction and `Retry-After`. |
|
||||
| `server.RateLimit` | Per-key token bucket rate limiting (keys on `RemoteAddr`; `X-Forwarded-For` via `WithTrustedProxies`) with `Retry-After`. |
|
||||
| `server.MaxBodySize` | Limits request body size via `http.MaxBytesReader`. |
|
||||
| `server.Timeout` | Context-based request timeout, returns 503 on expiry. |
|
||||
| `server.WriteJSON` | JSON response helper, sets Content-Type and status. |
|
||||
@@ -195,6 +195,9 @@ client := httpx.New(
|
||||
)
|
||||
```
|
||||
|
||||
Reading a body that exceeds the limit returns `httpx.ErrResponseTooLarge`
|
||||
(checkable with `errors.Is`) rather than silently truncating.
|
||||
|
||||
## Examples
|
||||
|
||||
See the [`examples/`](examples/) directory for runnable programs:
|
||||
|
||||
@@ -55,12 +55,19 @@ func Transport(endpoints []Endpoint, opts ...Option) (middleware.Middleware, *Cl
|
||||
opt(o)
|
||||
}
|
||||
|
||||
// Pre-parse endpoint URLs once at construction time.
|
||||
// Pre-parse endpoint URLs once at construction time. A malformed URL is a
|
||||
// configuration error: rather than panicking (which would crash the host
|
||||
// application, often at startup from external config), we capture the
|
||||
// error and surface it from the transport on first use.
|
||||
parsed := make(map[string]*url.URL, len(endpoints))
|
||||
var parseErr error
|
||||
for _, ep := range endpoints {
|
||||
u, err := url.Parse(ep.URL)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("balancer: invalid endpoint URL %q: %v", ep.URL, err))
|
||||
if parseErr == nil {
|
||||
parseErr = fmt.Errorf("balancer: invalid endpoint URL %q: %w", ep.URL, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
parsed[ep.URL] = u
|
||||
}
|
||||
@@ -73,6 +80,10 @@ func Transport(endpoints []Endpoint, opts ...Option) (middleware.Middleware, *Cl
|
||||
|
||||
return func(next http.RoundTripper) http.RoundTripper {
|
||||
return middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
if parseErr != nil {
|
||||
return nil, parseErr
|
||||
}
|
||||
|
||||
healthy := endpoints
|
||||
if o.healthChecker != nil {
|
||||
healthy = o.healthChecker.Healthy(endpoints)
|
||||
|
||||
@@ -61,6 +61,27 @@ func TestTransport_PicksEndpointAndReplacesURL(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransport_InvalidEndpointURLReturnsError(t *testing.T) {
|
||||
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||
t.Fatal("base transport should not be reached for an invalid endpoint")
|
||||
return nil, nil
|
||||
})
|
||||
|
||||
// A malformed URL must not panic; the error surfaces on first use.
|
||||
mw, closer := Transport([]Endpoint{{URL: "://missing-scheme"}})
|
||||
defer closer.Close()
|
||||
rt := mw(base)
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, "https://original.example.com/", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if _, err := rt.RoundTrip(req); err == nil {
|
||||
t.Fatal("expected an error for invalid endpoint URL, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransport_ErrNoHealthyWhenNoEndpoints(t *testing.T) {
|
||||
var endpoints []Endpoint
|
||||
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
|
||||
const (
|
||||
defaultHealthInterval = 10 * time.Second
|
||||
defaultHealthPath = "/health"
|
||||
defaultHealthPath = "/healthz"
|
||||
defaultHealthTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
|
||||
99
balancer/health_test.go
Normal file
99
balancer/health_test.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package balancer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestHealthChecker_InitialProbeClassifiesEndpoints(t *testing.T) {
|
||||
healthy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer healthy.Close()
|
||||
|
||||
unhealthy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
}))
|
||||
defer unhealthy.Close()
|
||||
|
||||
eps := []Endpoint{{URL: healthy.URL}, {URL: unhealthy.URL}}
|
||||
|
||||
hc := newHealthChecker()
|
||||
hc.Start(eps) // runs an initial synchronous probe
|
||||
defer hc.Stop()
|
||||
|
||||
if !hc.IsHealthy(eps[0]) {
|
||||
t.Errorf("healthy endpoint reported unhealthy")
|
||||
}
|
||||
if hc.IsHealthy(eps[1]) {
|
||||
t.Errorf("unhealthy endpoint reported healthy")
|
||||
}
|
||||
|
||||
got := hc.Healthy(eps)
|
||||
if len(got) != 1 || got[0].URL != healthy.URL {
|
||||
t.Errorf("Healthy() = %v, want only %s", got, healthy.URL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealthChecker_DetectsRecovery(t *testing.T) {
|
||||
var up atomic.Bool
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
if up.Load() {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
eps := []Endpoint{{URL: srv.URL}}
|
||||
|
||||
hc := newHealthChecker()
|
||||
hc.Start(eps)
|
||||
defer hc.Stop()
|
||||
|
||||
if hc.IsHealthy(eps[0]) {
|
||||
t.Fatalf("endpoint should start unhealthy")
|
||||
}
|
||||
|
||||
// Recover the backend and force a deterministic re-probe.
|
||||
up.Store(true)
|
||||
hc.probe(context.Background(), eps)
|
||||
|
||||
if !hc.IsHealthy(eps[0]) {
|
||||
t.Fatalf("endpoint should be healthy after recovery")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealthChecker_StopTerminatesLoop(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
hc := newHealthChecker(WithHealthInterval(time.Millisecond))
|
||||
hc.Start([]Endpoint{{URL: srv.URL}})
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
hc.Stop()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Stop did not return within 2s — loop goroutine leaked")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealthChecker_UnknownEndpointIsUnhealthy(t *testing.T) {
|
||||
hc := newHealthChecker()
|
||||
if hc.IsHealthy(Endpoint{URL: "http://never-probed.example"}) {
|
||||
t.Error("unknown endpoint should be reported unhealthy")
|
||||
}
|
||||
}
|
||||
@@ -72,7 +72,7 @@ func (b *Breaker) State() State {
|
||||
// stateLocked returns the effective state, promoting Open → HalfOpen when the
|
||||
// open duration has elapsed. Caller must hold b.mu.
|
||||
func (b *Breaker) stateLocked() State {
|
||||
if b.state == StateOpen && time.Since(b.openedAt) >= b.opts.openDuration {
|
||||
if b.state == StateOpen && b.opts.clk.Since(b.openedAt) >= b.opts.openDuration {
|
||||
b.state = StateHalfOpen
|
||||
b.halfOpenCur = 0
|
||||
}
|
||||
@@ -142,7 +142,7 @@ func (b *Breaker) record(success bool) {
|
||||
// tripLocked transitions to the Open state and records the timestamp.
|
||||
func (b *Breaker) tripLocked() {
|
||||
b.state = StateOpen
|
||||
b.openedAt = time.Now()
|
||||
b.openedAt = b.opts.clk.Now()
|
||||
b.halfOpenCur = 0
|
||||
}
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.codelab.vc/pkg/httpx/internal/clock"
|
||||
"git.codelab.vc/pkg/httpx/middleware"
|
||||
)
|
||||
|
||||
@@ -80,9 +81,11 @@ func TestBreaker_OpenRejectsRequests(t *testing.T) {
|
||||
|
||||
func TestBreaker_TransitionsToHalfOpenAfterDuration(t *testing.T) {
|
||||
const openDuration = 50 * time.Millisecond
|
||||
clk := clock.Mock(time.Now())
|
||||
b := NewBreaker(
|
||||
WithFailureThreshold(1),
|
||||
WithOpenDuration(openDuration),
|
||||
withClock(clk),
|
||||
)
|
||||
|
||||
// Trip the breaker.
|
||||
@@ -96,8 +99,8 @@ func TestBreaker_TransitionsToHalfOpenAfterDuration(t *testing.T) {
|
||||
t.Fatalf("state = %v, want %v", s, StateOpen)
|
||||
}
|
||||
|
||||
// Wait for the open duration to elapse.
|
||||
time.Sleep(openDuration + 10*time.Millisecond)
|
||||
// Advance past the open duration.
|
||||
clk.Advance(openDuration + time.Millisecond)
|
||||
|
||||
if s := b.State(); s != StateHalfOpen {
|
||||
t.Fatalf("state = %v, want %v", s, StateHalfOpen)
|
||||
@@ -106,9 +109,11 @@ func TestBreaker_TransitionsToHalfOpenAfterDuration(t *testing.T) {
|
||||
|
||||
func TestBreaker_HalfOpenToClosedOnSuccess(t *testing.T) {
|
||||
const openDuration = 50 * time.Millisecond
|
||||
clk := clock.Mock(time.Now())
|
||||
b := NewBreaker(
|
||||
WithFailureThreshold(1),
|
||||
WithOpenDuration(openDuration),
|
||||
withClock(clk),
|
||||
)
|
||||
|
||||
// Trip the breaker.
|
||||
@@ -118,8 +123,8 @@ func TestBreaker_HalfOpenToClosedOnSuccess(t *testing.T) {
|
||||
}
|
||||
done(false)
|
||||
|
||||
// Wait for half-open.
|
||||
time.Sleep(openDuration + 10*time.Millisecond)
|
||||
// Advance into half-open.
|
||||
clk.Advance(openDuration + time.Millisecond)
|
||||
|
||||
// A successful request in half-open should close the breaker.
|
||||
done, err = b.Allow()
|
||||
@@ -135,9 +140,11 @@ func TestBreaker_HalfOpenToClosedOnSuccess(t *testing.T) {
|
||||
|
||||
func TestBreaker_HalfOpenToOpenOnFailure(t *testing.T) {
|
||||
const openDuration = 50 * time.Millisecond
|
||||
clk := clock.Mock(time.Now())
|
||||
b := NewBreaker(
|
||||
WithFailureThreshold(1),
|
||||
WithOpenDuration(openDuration),
|
||||
withClock(clk),
|
||||
)
|
||||
|
||||
// Trip the breaker.
|
||||
@@ -147,8 +154,8 @@ func TestBreaker_HalfOpenToOpenOnFailure(t *testing.T) {
|
||||
}
|
||||
done(false)
|
||||
|
||||
// Wait for half-open.
|
||||
time.Sleep(openDuration + 10*time.Millisecond)
|
||||
// Advance into half-open.
|
||||
clk.Advance(openDuration + time.Millisecond)
|
||||
|
||||
// A failed request in half-open should re-open the breaker.
|
||||
done, err = b.Allow()
|
||||
|
||||
@@ -1,11 +1,16 @@
|
||||
package circuitbreaker
|
||||
|
||||
import "time"
|
||||
import (
|
||||
"time"
|
||||
|
||||
"git.codelab.vc/pkg/httpx/internal/clock"
|
||||
)
|
||||
|
||||
type options struct {
|
||||
failureThreshold int // consecutive failures to trip
|
||||
openDuration time.Duration // how long to stay open before half-open
|
||||
halfOpenMax int // max concurrent requests in half-open
|
||||
clk clock.Clock // time source (real by default)
|
||||
}
|
||||
|
||||
func defaults() options {
|
||||
@@ -13,12 +18,23 @@ func defaults() options {
|
||||
failureThreshold: 5,
|
||||
openDuration: 30 * time.Second,
|
||||
halfOpenMax: 1,
|
||||
clk: clock.System(),
|
||||
}
|
||||
}
|
||||
|
||||
// Option configures a Breaker.
|
||||
type Option func(*options)
|
||||
|
||||
// withClock sets the clock used for state-transition timing. Unexported; for
|
||||
// deterministic tests.
|
||||
func withClock(c clock.Clock) Option {
|
||||
return func(o *options) {
|
||||
if c != nil {
|
||||
o.clk = c
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WithFailureThreshold sets the number of consecutive failures required to
|
||||
// trip the breaker from Closed to Open. Default is 5.
|
||||
func WithFailureThreshold(n int) Option {
|
||||
|
||||
@@ -102,9 +102,12 @@ func (c *Client) Do(ctx context.Context, req *http.Request) (*Response, error) {
|
||||
}
|
||||
|
||||
if c.maxResponseBody > 0 {
|
||||
// Read one byte past the limit so we can distinguish "exactly at the
|
||||
// limit" (allowed) from "exceeds the limit" (ErrResponseTooLarge).
|
||||
resp.Body = &limitedReadCloser{
|
||||
R: io.LimitedReader{R: resp.Body, N: c.maxResponseBody},
|
||||
C: resp.Body,
|
||||
r: io.LimitReader(resp.Body, c.maxResponseBody+1),
|
||||
c: resp.Body,
|
||||
limit: c.maxResponseBody,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -89,8 +89,8 @@ func WithBalancer(opts ...balancer.Option) Option {
|
||||
|
||||
// WithMaxResponseBody limits the number of bytes read from response bodies
|
||||
// by Response.Bytes (and by extension String, JSON, XML). If the response
|
||||
// body exceeds n bytes, reading stops and returns an error.
|
||||
// A value of 0 means no limit (the default).
|
||||
// body exceeds n bytes, reading returns ErrResponseTooLarge instead of
|
||||
// silently truncating. A value of 0 means no limit (the default).
|
||||
func WithMaxResponseBody(n int64) Option {
|
||||
return func(o *clientOptions) { o.maxResponseBody = n }
|
||||
}
|
||||
|
||||
6
error.go
6
error.go
@@ -1,6 +1,7 @@
|
||||
package httpx
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
@@ -18,6 +19,11 @@ var (
|
||||
ErrNoHealthy = balancer.ErrNoHealthy
|
||||
)
|
||||
|
||||
// ErrResponseTooLarge is returned when reading a response body that exceeds
|
||||
// the limit configured via WithMaxResponseBody. Any bytes read up to the
|
||||
// limit are returned alongside the error.
|
||||
var ErrResponseTooLarge = errors.New("httpx: response body exceeds configured limit")
|
||||
|
||||
// Error provides structured error information for failed HTTP operations.
|
||||
type Error struct {
|
||||
// Op is the operation that failed (e.g. "Get", "Do").
|
||||
|
||||
@@ -15,6 +15,9 @@ func BearerAuth(tokenFunc func(ctx context.Context) (string, error)) Middleware
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// RoundTrippers must not mutate the caller's request; clone before
|
||||
// setting headers (req.Clone is shallow + a header copy).
|
||||
req = req.Clone(req.Context())
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
return next.RoundTrip(req)
|
||||
})
|
||||
@@ -26,6 +29,7 @@ func BearerAuth(tokenFunc func(ctx context.Context) (string, error)) Middleware
|
||||
func BasicAuth(username, password string) Middleware {
|
||||
return func(next http.RoundTripper) http.RoundTripper {
|
||||
return RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
req = req.Clone(req.Context())
|
||||
req.SetBasicAuth(username, password)
|
||||
return next.RoundTrip(req)
|
||||
})
|
||||
|
||||
@@ -7,10 +7,17 @@ import "net/http"
|
||||
func DefaultHeaders(headers http.Header) Middleware {
|
||||
return func(next http.RoundTripper) http.RoundTripper {
|
||||
return RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
// Clone lazily on the first header we actually add, so that
|
||||
// RoundTrippers never mutate the caller's request.
|
||||
cloned := false
|
||||
for key, values := range headers {
|
||||
if req.Header.Get(key) != "" {
|
||||
continue
|
||||
}
|
||||
if !cloned {
|
||||
req = req.Clone(req.Context())
|
||||
cloned = true
|
||||
}
|
||||
for _, v := range values {
|
||||
req.Header.Add(key, v)
|
||||
}
|
||||
|
||||
21
response.go
21
response.go
@@ -98,17 +98,26 @@ func (r *Response) BodyReader() io.Reader {
|
||||
return r.Body
|
||||
}
|
||||
|
||||
// limitedReadCloser wraps an io.LimitedReader with a separate Closer
|
||||
// so the original body can be closed.
|
||||
// limitedReadCloser enforces a maximum number of bytes that may be read from
|
||||
// a response body. Reading more than limit bytes returns ErrResponseTooLarge
|
||||
// rather than silently truncating the body. The original body is closed via
|
||||
// the separate Closer.
|
||||
type limitedReadCloser struct {
|
||||
R io.LimitedReader
|
||||
C io.Closer
|
||||
r io.Reader // an io.LimitReader over the original body (limit+1 bytes)
|
||||
c io.Closer // the original body, for Close
|
||||
limit int64
|
||||
read int64
|
||||
}
|
||||
|
||||
func (l *limitedReadCloser) Read(p []byte) (int, error) {
|
||||
return l.R.Read(p)
|
||||
n, err := l.r.Read(p)
|
||||
l.read += int64(n)
|
||||
if l.read > l.limit {
|
||||
return n, ErrResponseTooLarge
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (l *limitedReadCloser) Close() error {
|
||||
return l.C.Close()
|
||||
return l.c.Close()
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package httpx_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@@ -32,13 +33,30 @@ func TestClient_MaxResponseBody(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("truncates response exceeding limit", func(t *testing.T) {
|
||||
t.Run("returns ErrResponseTooLarge when exceeding limit", func(t *testing.T) {
|
||||
largeBody := strings.Repeat("x", 1000)
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
fmt.Fprint(w, largeBody)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
client := httpx.New(httpx.WithMaxResponseBody(100))
|
||||
resp, err := client.Get(context.Background(), srv.URL+"/")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if _, err := resp.Bytes(); !errors.Is(err, httpx.ErrResponseTooLarge) {
|
||||
t.Fatalf("err = %v, want ErrResponseTooLarge", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("allows body exactly at limit", func(t *testing.T) {
|
||||
exact := strings.Repeat("x", 100)
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
fmt.Fprint(w, exact)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
client := httpx.New(httpx.WithMaxResponseBody(100))
|
||||
resp, err := client.Get(context.Background(), srv.URL+"/")
|
||||
if err != nil {
|
||||
@@ -46,7 +64,7 @@ func TestClient_MaxResponseBody(t *testing.T) {
|
||||
}
|
||||
b, err := resp.Bytes()
|
||||
if err != nil {
|
||||
t.Fatalf("reading body: %v", err)
|
||||
t.Fatalf("reading body at exact limit: %v", err)
|
||||
}
|
||||
if len(b) != 100 {
|
||||
t.Fatalf("body length = %d, want %d", len(b), 100)
|
||||
|
||||
@@ -44,8 +44,11 @@ func (b *exponentialBackoff) Delay(attempt int) time.Duration {
|
||||
}
|
||||
|
||||
if b.withJitter {
|
||||
jitter := time.Duration(rand.Int64N(int64(delay / 2)))
|
||||
delay += jitter
|
||||
// Guard against rand.Int64N panicking on a non-positive argument when
|
||||
// delay is small enough that delay/2 rounds to zero.
|
||||
if half := int64(delay / 2); half > 0 {
|
||||
delay += time.Duration(rand.Int64N(half))
|
||||
}
|
||||
}
|
||||
|
||||
if delay > b.max {
|
||||
|
||||
@@ -1,12 +1,17 @@
|
||||
package retry
|
||||
|
||||
import "time"
|
||||
import (
|
||||
"time"
|
||||
|
||||
"git.codelab.vc/pkg/httpx/internal/clock"
|
||||
)
|
||||
|
||||
type options struct {
|
||||
maxAttempts int // default 3
|
||||
backoff Backoff // default ExponentialBackoff(100ms, 5s, true)
|
||||
policy Policy // default: defaultPolicy (retry on 5xx and network errors)
|
||||
retryAfter bool // default true, respect Retry-After header
|
||||
maxAttempts int // default 3
|
||||
backoff Backoff // default ExponentialBackoff(100ms, 5s, true)
|
||||
policy Policy // default: defaultPolicy (retry on 5xx and network errors)
|
||||
retryAfter bool // default true, respect Retry-After header
|
||||
clk clock.Clock // time source for backoff delays (real by default)
|
||||
}
|
||||
|
||||
// Option configures the retry transport.
|
||||
@@ -18,6 +23,17 @@ func defaults() options {
|
||||
backoff: ExponentialBackoff(100*time.Millisecond, 5*time.Second, true),
|
||||
policy: defaultPolicy{},
|
||||
retryAfter: true,
|
||||
clk: clock.System(),
|
||||
}
|
||||
}
|
||||
|
||||
// withClock sets the clock used for inter-attempt delays. Unexported; for
|
||||
// deterministic tests.
|
||||
func withClock(c clock.Clock) Option {
|
||||
return func(o *options) {
|
||||
if c != nil {
|
||||
o.clk = c
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -37,18 +37,16 @@ func Transport(opts ...Option) middleware.Middleware {
|
||||
var exhausted bool
|
||||
|
||||
for attempt := range cfg.maxAttempts {
|
||||
// For retries (attempt > 0), restore the request body.
|
||||
if attempt > 0 {
|
||||
if req.GetBody != nil {
|
||||
body, bodyErr := req.GetBody()
|
||||
if bodyErr != nil {
|
||||
return resp, bodyErr
|
||||
}
|
||||
req.Body = body
|
||||
} else if req.Body != nil {
|
||||
// Body was consumed and cannot be re-created.
|
||||
return resp, err
|
||||
// For retries (attempt > 0) the body was consumed by the
|
||||
// previous attempt; restore it via GetBody. The rewindability
|
||||
// check below guarantees GetBody is set whenever we loop with a
|
||||
// non-nil body, so this branch is always safe.
|
||||
if attempt > 0 && req.GetBody != nil {
|
||||
body, bodyErr := req.GetBody()
|
||||
if bodyErr != nil {
|
||||
return nil, bodyErr
|
||||
}
|
||||
req.Body = body
|
||||
}
|
||||
|
||||
resp, err = next.RoundTrip(req)
|
||||
@@ -64,6 +62,13 @@ func Transport(opts ...Option) middleware.Middleware {
|
||||
break
|
||||
}
|
||||
|
||||
// If the body cannot be rewound, a retry would replay with an
|
||||
// empty body. Return the current result as-is instead of
|
||||
// draining it and looping with a corrupted request.
|
||||
if req.Body != nil && req.GetBody == nil {
|
||||
break
|
||||
}
|
||||
|
||||
// Compute delay: use backoff or policy delay, whichever is larger.
|
||||
delay := cfg.backoff.Delay(attempt)
|
||||
if policyDelay > delay {
|
||||
@@ -84,12 +89,12 @@ func Transport(opts ...Option) middleware.Middleware {
|
||||
}
|
||||
|
||||
// Wait for the delay or context cancellation.
|
||||
timer := time.NewTimer(delay)
|
||||
timer := cfg.clk.NewTimer(delay)
|
||||
select {
|
||||
case <-req.Context().Done():
|
||||
timer.Stop()
|
||||
return nil, req.Context().Err()
|
||||
case <-timer.C:
|
||||
case <-timer.C():
|
||||
}
|
||||
}
|
||||
|
||||
@@ -118,9 +123,9 @@ func (defaultPolicy) ShouldRetry(_ int, req *http.Request, resp *http.Response,
|
||||
|
||||
switch resp.StatusCode {
|
||||
case http.StatusTooManyRequests, // 429
|
||||
http.StatusBadGateway, // 502
|
||||
http.StatusBadGateway, // 502
|
||||
http.StatusServiceUnavailable, // 503
|
||||
http.StatusGatewayTimeout: // 504
|
||||
http.StatusGatewayTimeout: // 504
|
||||
return true, 0
|
||||
}
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.codelab.vc/pkg/httpx/internal/clock"
|
||||
"git.codelab.vc/pkg/httpx/middleware"
|
||||
)
|
||||
|
||||
@@ -229,6 +230,83 @@ func TestTransport(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
// TestTransport_BodyNotRewindable verifies that an idempotent request whose
|
||||
// body cannot be replayed (no GetBody) is returned as-is rather than retried
|
||||
// with an empty body or a stale, already-drained response.
|
||||
func TestTransport_BodyNotRewindable(t *testing.T) {
|
||||
var calls atomic.Int32
|
||||
rt := Transport(
|
||||
WithMaxAttempts(3),
|
||||
WithBackoff(ConstantBackoff(time.Millisecond)),
|
||||
)(mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||
calls.Add(1)
|
||||
io.Copy(io.Discard, req.Body) // a real transport consumes the body
|
||||
return statusResponse(http.StatusServiceUnavailable), nil
|
||||
}))
|
||||
|
||||
// PUT is idempotent (the policy would retry a 503), but with GetBody unset
|
||||
// the body cannot be rewound.
|
||||
req, _ := http.NewRequest(http.MethodPut, "http://example.com", strings.NewReader("data"))
|
||||
req.GetBody = nil
|
||||
|
||||
resp, err := rt.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp == nil || resp.StatusCode != http.StatusServiceUnavailable {
|
||||
t.Fatalf("expected the original 503 response, got %v", resp)
|
||||
}
|
||||
if got := calls.Load(); got != 1 {
|
||||
t.Fatalf("expected exactly 1 call (no rewind retry), got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTransport_InjectedClock verifies that backoff delays are driven by the
|
||||
// configured clock, so retries are deterministic without real sleeps.
|
||||
func TestTransport_InjectedClock(t *testing.T) {
|
||||
clk := clock.Mock(time.Now())
|
||||
var calls atomic.Int32
|
||||
rt := Transport(
|
||||
WithMaxAttempts(2),
|
||||
WithBackoff(ConstantBackoff(time.Hour)), // would block forever on a real clock
|
||||
withClock(clk),
|
||||
)(mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||
calls.Add(1)
|
||||
return statusResponse(http.StatusServiceUnavailable), nil
|
||||
}))
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||
|
||||
done := make(chan struct{})
|
||||
var resp *http.Response
|
||||
var err error
|
||||
go func() {
|
||||
resp, err = rt.RoundTrip(req)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
// Drive the backoff via the mock clock. Advancing repeatedly is robust
|
||||
// against the timer being created slightly after the first attempt.
|
||||
for {
|
||||
clk.Advance(time.Hour)
|
||||
select {
|
||||
case <-done:
|
||||
goto finished
|
||||
case <-time.After(time.Millisecond):
|
||||
}
|
||||
}
|
||||
finished:
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusServiceUnavailable {
|
||||
t.Fatalf("expected 503, got %d", resp.StatusCode)
|
||||
}
|
||||
if got := calls.Load(); got != 2 {
|
||||
t.Fatalf("expected 2 calls, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
// policyFunc adapts a function into a Policy.
|
||||
type policyFunc func(int, *http.Request, *http.Response, error) (bool, time.Duration)
|
||||
|
||||
|
||||
@@ -25,8 +25,8 @@ func Chain(mws ...Middleware) Middleware {
|
||||
// underlying ResponseWriter's optional interfaces (Flusher, Hijacker, etc.).
|
||||
type statusWriter struct {
|
||||
http.ResponseWriter
|
||||
status int
|
||||
written bool
|
||||
status int
|
||||
written bool
|
||||
}
|
||||
|
||||
// WriteHeader captures the status code and delegates to the underlying writer.
|
||||
|
||||
@@ -4,17 +4,25 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"git.codelab.vc/pkg/httpx/internal/clock"
|
||||
)
|
||||
|
||||
// defaultMaxKeys bounds the number of distinct rate-limit buckets retained in
|
||||
// memory. When exceeded, fully-refilled (idle) buckets are evicted.
|
||||
const defaultMaxKeys = 1 << 16
|
||||
|
||||
type rateLimitOptions struct {
|
||||
rate float64
|
||||
burst int
|
||||
keyFunc func(r *http.Request) string
|
||||
clock clock.Clock
|
||||
rate float64
|
||||
burst int
|
||||
keyFunc func(r *http.Request) string
|
||||
clock clock.Clock
|
||||
trustedProxies []*net.IPNet
|
||||
maxKeys int
|
||||
}
|
||||
|
||||
// RateLimitOption configures the RateLimit middleware.
|
||||
@@ -31,11 +39,55 @@ func WithBurst(n int) RateLimitOption {
|
||||
}
|
||||
|
||||
// WithKeyFunc sets a custom function to extract the rate-limit key from a
|
||||
// request. By default, the client IP address is used.
|
||||
// request. By default, the client IP from RemoteAddr is used (see
|
||||
// WithTrustedProxies to honor X-Forwarded-For behind a trusted proxy).
|
||||
func WithKeyFunc(fn func(r *http.Request) string) RateLimitOption {
|
||||
return func(o *rateLimitOptions) { o.keyFunc = fn }
|
||||
}
|
||||
|
||||
// WithTrustedProxies enables X-Forwarded-For parsing, but only for requests
|
||||
// whose immediate peer (RemoteAddr) falls within one of the given trusted
|
||||
// CIDR ranges (e.g. "10.0.0.0/8", "192.168.0.0/16"). A bare IP is accepted as
|
||||
// a /32 or /128. When the peer is trusted, the client key is taken from the
|
||||
// right-most X-Forwarded-For entry that is not itself a trusted proxy;
|
||||
// otherwise RemoteAddr is used. Invalid entries are ignored (treated as
|
||||
// untrusted), so a typo can never silently widen trust.
|
||||
//
|
||||
// Without this option the middleware never trusts client-supplied forwarding
|
||||
// headers, which prevents trivial rate-limit bypass and bucket exhaustion via
|
||||
// spoofed headers.
|
||||
func WithTrustedProxies(cidrs ...string) RateLimitOption {
|
||||
return func(o *rateLimitOptions) {
|
||||
for _, c := range cidrs {
|
||||
if _, ipnet, err := net.ParseCIDR(c); err == nil {
|
||||
o.trustedProxies = append(o.trustedProxies, ipnet)
|
||||
continue
|
||||
}
|
||||
if ip := net.ParseIP(c); ip != nil {
|
||||
bits := 32
|
||||
if ip.To4() == nil {
|
||||
bits = 128
|
||||
}
|
||||
o.trustedProxies = append(o.trustedProxies, &net.IPNet{
|
||||
IP: ip,
|
||||
Mask: net.CIDRMask(bits, bits),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WithMaxKeys sets the soft upper bound on the number of distinct buckets
|
||||
// retained in memory. When exceeded, idle (fully-refilled) buckets are
|
||||
// evicted; active buckets are never dropped. Default is 65536.
|
||||
func WithMaxKeys(n int) RateLimitOption {
|
||||
return func(o *rateLimitOptions) {
|
||||
if n > 0 {
|
||||
o.maxKeys = n
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// withRateLimitClock sets the clock for testing. Not exported.
|
||||
func withRateLimitClock(c clock.Clock) RateLimitOption {
|
||||
return func(o *rateLimitOptions) { o.clock = c }
|
||||
@@ -44,86 +96,153 @@ func withRateLimitClock(c clock.Clock) RateLimitOption {
|
||||
// RateLimit returns a middleware that limits requests using a per-key token
|
||||
// bucket algorithm. When the limit is exceeded, it returns 429 Too Many
|
||||
// Requests with a Retry-After header.
|
||||
//
|
||||
// By default the key is the client IP taken from RemoteAddr. Forwarding
|
||||
// headers (X-Forwarded-For) are honored only when WithTrustedProxies is set,
|
||||
// so the limiter cannot be bypassed by spoofing headers.
|
||||
func RateLimit(opts ...RateLimitOption) Middleware {
|
||||
o := &rateLimitOptions{
|
||||
rate: 10,
|
||||
burst: 20,
|
||||
clock: clock.System(),
|
||||
rate: 10,
|
||||
burst: 20,
|
||||
clock: clock.System(),
|
||||
maxKeys: defaultMaxKeys,
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(o)
|
||||
}
|
||||
if o.keyFunc == nil {
|
||||
o.keyFunc = clientIP
|
||||
o.keyFunc = o.clientKey
|
||||
}
|
||||
|
||||
var buckets sync.Map
|
||||
lim := &limiter{opts: o}
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
key := o.keyFunc(r)
|
||||
val, _ := buckets.LoadOrStore(key, &bucket{
|
||||
tokens: float64(o.burst),
|
||||
lastTime: o.clock.Now(),
|
||||
})
|
||||
b := val.(*bucket)
|
||||
|
||||
b.mu.Lock()
|
||||
now := o.clock.Now()
|
||||
elapsed := now.Sub(b.lastTime).Seconds()
|
||||
b.tokens += elapsed * o.rate
|
||||
if b.tokens > float64(o.burst) {
|
||||
b.tokens = float64(o.burst)
|
||||
}
|
||||
b.lastTime = now
|
||||
|
||||
if b.tokens < 1 {
|
||||
retryAfter := (1 - b.tokens) / o.rate
|
||||
b.mu.Unlock()
|
||||
if allowed, retryAfter := lim.allow(key); !allowed {
|
||||
w.Header().Set("Retry-After", strconv.Itoa(int(retryAfter)+1))
|
||||
http.Error(w, "Too Many Requests", http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
|
||||
b.tokens--
|
||||
b.mu.Unlock()
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// limiter holds the per-key token buckets for one RateLimit middleware.
|
||||
type limiter struct {
|
||||
opts *rateLimitOptions
|
||||
buckets sync.Map // key -> *bucket
|
||||
count atomic.Int64
|
||||
sweeping atomic.Bool
|
||||
}
|
||||
|
||||
// allow reports whether a request for key may proceed. When denied it also
|
||||
// returns the suggested Retry-After delay in seconds.
|
||||
func (l *limiter) allow(key string) (bool, float64) {
|
||||
o := l.opts
|
||||
val, loaded := l.buckets.LoadOrStore(key, &bucket{
|
||||
tokens: float64(o.burst),
|
||||
lastTime: o.clock.Now(),
|
||||
})
|
||||
if !loaded && l.count.Add(1) > int64(o.maxKeys) {
|
||||
l.sweep()
|
||||
}
|
||||
b := val.(*bucket)
|
||||
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
now := o.clock.Now()
|
||||
elapsed := now.Sub(b.lastTime).Seconds()
|
||||
b.tokens += elapsed * o.rate
|
||||
if b.tokens > float64(o.burst) {
|
||||
b.tokens = float64(o.burst)
|
||||
}
|
||||
b.lastTime = now
|
||||
|
||||
if b.tokens < 1 {
|
||||
return false, (1 - b.tokens) / o.rate
|
||||
}
|
||||
b.tokens--
|
||||
return true, 0
|
||||
}
|
||||
|
||||
// sweep removes fully-refilled (idle) buckets to bound memory. Only one sweep
|
||||
// runs at a time; buckets that still hold a partial limit are preserved so
|
||||
// that eviction can never reset an active client's allowance.
|
||||
func (l *limiter) sweep() {
|
||||
if !l.sweeping.CompareAndSwap(false, true) {
|
||||
return
|
||||
}
|
||||
defer l.sweeping.Store(false)
|
||||
|
||||
o := l.opts
|
||||
now := o.clock.Now()
|
||||
l.buckets.Range(func(k, v any) bool {
|
||||
b := v.(*bucket)
|
||||
b.mu.Lock()
|
||||
elapsed := now.Sub(b.lastTime).Seconds()
|
||||
full := b.tokens+elapsed*o.rate >= float64(o.burst)
|
||||
b.mu.Unlock()
|
||||
if full {
|
||||
l.buckets.Delete(k)
|
||||
l.count.Add(-1)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
type bucket struct {
|
||||
mu sync.Mutex
|
||||
tokens float64
|
||||
lastTime time.Time
|
||||
}
|
||||
|
||||
// clientIP extracts the client IP from the request. It checks
|
||||
// X-Forwarded-For first, then X-Real-Ip, and falls back to RemoteAddr.
|
||||
func clientIP(r *http.Request) string {
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
// First IP in the comma-separated list is the original client.
|
||||
if i := indexOf(xff, ','); i > 0 {
|
||||
return xff[:i]
|
||||
// clientKey derives the rate-limit key from a request. It uses RemoteAddr by
|
||||
// default and only consults X-Forwarded-For when the peer is a configured
|
||||
// trusted proxy (see WithTrustedProxies).
|
||||
func (o *rateLimitOptions) clientKey(r *http.Request) string {
|
||||
remote := remoteIP(r)
|
||||
if len(o.trustedProxies) == 0 || !o.isTrusted(remote) {
|
||||
return remote
|
||||
}
|
||||
// Peer is trusted: walk X-Forwarded-For right-to-left and return the first
|
||||
// address that is not itself a trusted proxy — that is the real client.
|
||||
xff := r.Header.Get("X-Forwarded-For")
|
||||
if xff == "" {
|
||||
return remote
|
||||
}
|
||||
parts := strings.Split(xff, ",")
|
||||
for i := len(parts) - 1; i >= 0; i-- {
|
||||
ip := strings.TrimSpace(parts[i])
|
||||
if ip == "" || o.isTrusted(ip) {
|
||||
continue
|
||||
}
|
||||
return xff
|
||||
return ip
|
||||
}
|
||||
if xri := r.Header.Get("X-Real-Ip"); xri != "" {
|
||||
return xri
|
||||
return remote
|
||||
}
|
||||
|
||||
func (o *rateLimitOptions) isTrusted(ip string) bool {
|
||||
parsed := net.ParseIP(ip)
|
||||
if parsed == nil {
|
||||
return false
|
||||
}
|
||||
for _, n := range o.trustedProxies {
|
||||
if n.Contains(parsed) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// remoteIP returns the host portion of r.RemoteAddr, or the raw value if it
|
||||
// has no port.
|
||||
func remoteIP(r *http.Request) string {
|
||||
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
return r.RemoteAddr
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
func indexOf(s string, b byte) int {
|
||||
for i := range len(s) {
|
||||
if s[i] == b {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
62
server/middleware_ratelimit_internal_test.go
Normal file
62
server/middleware_ratelimit_internal_test.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.codelab.vc/pkg/httpx/internal/clock"
|
||||
)
|
||||
|
||||
func newTestRequest(remoteAddr, xff string) *http.Request {
|
||||
r := &http.Request{RemoteAddr: remoteAddr, Header: http.Header{}}
|
||||
if xff != "" {
|
||||
r.Header.Set("X-Forwarded-For", xff)
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// TestLimiterSweepEvictsIdleBuckets verifies that sweep removes fully-refilled
|
||||
// (idle) buckets while preserving buckets that still hold an active limit, so
|
||||
// memory is bounded without resetting live clients' allowances.
|
||||
func TestLimiterSweepEvictsIdleBuckets(t *testing.T) {
|
||||
clk := clock.Mock(time.Now())
|
||||
o := &rateLimitOptions{rate: 1, burst: 5, clock: clk, maxKeys: 1 << 30}
|
||||
lim := &limiter{opts: o}
|
||||
|
||||
// "idle" makes a single request, then time passes so it refills to full.
|
||||
lim.allow("idle")
|
||||
clk.Advance(10 * time.Second)
|
||||
|
||||
// "active" drains its whole burst at the (advanced) current time.
|
||||
for i := 0; i < 6; i++ {
|
||||
lim.allow("active")
|
||||
}
|
||||
|
||||
lim.sweep()
|
||||
|
||||
if _, ok := lim.buckets.Load("idle"); ok {
|
||||
t.Error("fully-refilled idle bucket was not evicted")
|
||||
}
|
||||
if _, ok := lim.buckets.Load("active"); !ok {
|
||||
t.Error("active bucket with a partial limit was wrongly evicted")
|
||||
}
|
||||
}
|
||||
|
||||
// TestClientKeyTrustedProxy exercises the X-Forwarded-For walk used behind a
|
||||
// trusted proxy, independent of the HTTP layer.
|
||||
func TestClientKeyTrustedProxy(t *testing.T) {
|
||||
o := &rateLimitOptions{}
|
||||
WithTrustedProxies("192.168.0.0/16")(o)
|
||||
|
||||
r := newTestRequest("192.168.1.10:443", "203.0.113.7, 192.168.1.10")
|
||||
if got := o.clientKey(r); got != "203.0.113.7" {
|
||||
t.Fatalf("clientKey = %q, want real client 203.0.113.7", got)
|
||||
}
|
||||
|
||||
// Untrusted peer: X-Forwarded-For must be ignored entirely.
|
||||
r = newTestRequest("203.0.113.7:443", "10.0.0.1")
|
||||
if got := o.clientKey(r); got != "203.0.113.7" {
|
||||
t.Fatalf("clientKey = %q, want peer 203.0.113.7 (XFF ignored)", got)
|
||||
}
|
||||
}
|
||||
@@ -83,28 +83,59 @@ func TestRateLimit(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uses X-Forwarded-For", func(t *testing.T) {
|
||||
t.Run("ignores X-Forwarded-For without trusted proxies", func(t *testing.T) {
|
||||
// By default the limiter keys on RemoteAddr only. A spoofed,
|
||||
// per-request X-Forwarded-For must not let a single peer bypass the
|
||||
// limit by minting a fresh bucket each time.
|
||||
mw := server.RateLimit(
|
||||
server.WithRate(1),
|
||||
server.WithBurst(1),
|
||||
)(okHandler)
|
||||
|
||||
// Exhaust limit for 10.0.0.1.
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("X-Forwarded-For", "10.0.0.1, 192.168.1.1")
|
||||
req.RemoteAddr = "192.168.1.1:1234"
|
||||
mw.ServeHTTP(w, req)
|
||||
send := func(xff string) int {
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("X-Forwarded-For", xff)
|
||||
req.RemoteAddr = "192.168.1.1:1234"
|
||||
mw.ServeHTTP(w, req)
|
||||
return w.Code
|
||||
}
|
||||
|
||||
// Same forwarded IP should be rate limited.
|
||||
w = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("X-Forwarded-For", "10.0.0.1, 192.168.1.1")
|
||||
req.RemoteAddr = "192.168.1.1:1234"
|
||||
mw.ServeHTTP(w, req)
|
||||
if code := send("10.0.0.1"); code != http.StatusOK {
|
||||
t.Fatalf("first request: got %d, want %d", code, http.StatusOK)
|
||||
}
|
||||
// Different spoofed XFF, same peer — must still be limited.
|
||||
if code := send("10.0.0.2"); code != http.StatusTooManyRequests {
|
||||
t.Fatalf("spoofed XFF bypassed limit: got %d, want %d", code, http.StatusTooManyRequests)
|
||||
}
|
||||
})
|
||||
|
||||
if w.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("got status %d, want %d", w.Code, http.StatusTooManyRequests)
|
||||
t.Run("honors X-Forwarded-For behind trusted proxy", func(t *testing.T) {
|
||||
mw := server.RateLimit(
|
||||
server.WithRate(1),
|
||||
server.WithBurst(1),
|
||||
server.WithTrustedProxies("192.168.0.0/16"),
|
||||
)(okHandler)
|
||||
|
||||
send := func(xff string) int {
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("X-Forwarded-For", xff)
|
||||
req.RemoteAddr = "192.168.1.1:1234" // trusted proxy
|
||||
mw.ServeHTTP(w, req)
|
||||
return w.Code
|
||||
}
|
||||
|
||||
// Real client 10.0.0.1 (left-most), proxy hop 192.168.1.1 (right-most).
|
||||
if code := send("10.0.0.1, 192.168.1.1"); code != http.StatusOK {
|
||||
t.Fatalf("first request: got %d, want %d", code, http.StatusOK)
|
||||
}
|
||||
if code := send("10.0.0.1, 192.168.1.1"); code != http.StatusTooManyRequests {
|
||||
t.Fatalf("same client not limited: got %d, want %d", code, http.StatusTooManyRequests)
|
||||
}
|
||||
// A different real client through the same proxy is independent.
|
||||
if code := send("10.0.0.2, 192.168.1.1"); code != http.StatusOK {
|
||||
t.Fatalf("different client should be allowed: got %d, want %d", code, http.StatusOK)
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
@@ -9,9 +9,14 @@ import (
|
||||
"git.codelab.vc/pkg/httpx/internal/requestid"
|
||||
)
|
||||
|
||||
// maxRequestIDLen bounds the length of a client-supplied request ID that we
|
||||
// are willing to propagate.
|
||||
const maxRequestIDLen = 128
|
||||
|
||||
// RequestID returns a middleware that assigns a unique request ID to each
|
||||
// request. If the incoming request already has an X-Request-Id header, that
|
||||
// value is used. Otherwise a new UUID v4 is generated via crypto/rand.
|
||||
// request. If the incoming request carries a valid X-Request-Id header, that
|
||||
// value is reused; otherwise (or if the supplied value is empty, too long, or
|
||||
// contains unsafe characters) a new UUID v4 is generated via crypto/rand.
|
||||
//
|
||||
// The request ID is stored in the request context (retrieve with
|
||||
// RequestIDFromContext) and set on the response X-Request-Id header.
|
||||
@@ -19,7 +24,7 @@ func RequestID() Middleware {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
id := r.Header.Get("X-Request-Id")
|
||||
if id == "" {
|
||||
if !validRequestID(id) {
|
||||
id = newUUID()
|
||||
}
|
||||
|
||||
@@ -30,6 +35,27 @@ func RequestID() Middleware {
|
||||
}
|
||||
}
|
||||
|
||||
// validRequestID reports whether a client-supplied request ID is safe to
|
||||
// propagate: non-empty, within a sane length, and restricted to characters
|
||||
// that cannot forge log lines or split response headers.
|
||||
func validRequestID(id string) bool {
|
||||
if id == "" || len(id) > maxRequestIDLen {
|
||||
return false
|
||||
}
|
||||
for i := 0; i < len(id); i++ {
|
||||
c := id[i]
|
||||
switch {
|
||||
case c >= 'a' && c <= 'z',
|
||||
c >= 'A' && c <= 'Z',
|
||||
c >= '0' && c <= '9',
|
||||
c == '-', c == '_', c == '.':
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// RequestIDFromContext returns the request ID from the context, or an empty
|
||||
// string if none is set.
|
||||
func RequestIDFromContext(ctx context.Context) string {
|
||||
|
||||
@@ -214,6 +214,36 @@ func TestRequestID(t *testing.T) {
|
||||
t.Fatalf("expected empty, got %q", id)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("rejects unsafe incoming ID", func(t *testing.T) {
|
||||
cases := map[string]string{
|
||||
"header injection": "abc\r\nX-Injected: 1",
|
||||
"contains space": "has space",
|
||||
"too long": strings.Repeat("a", 200),
|
||||
}
|
||||
for name, badID := range cases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
var gotID string
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotID = server.RequestIDFromContext(r.Context())
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
mw := server.RequestID()(handler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("X-Request-Id", badID)
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if gotID == badID {
|
||||
t.Fatalf("unsafe incoming ID was propagated verbatim: %q", gotID)
|
||||
}
|
||||
if len(gotID) != 36 {
|
||||
t.Fatalf("expected a freshly generated UUID, got %q (len %d)", gotID, len(gotID))
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRequestID_UUIDFormat(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user