Compare commits

...

7 Commits

Author SHA1 Message Date
f609b12c2f Update CLAUDE.md and README for revised behavior
All checks were successful
CI / test (push) Successful in 38s
Publish / publish (push) Successful in 35s
Document RateLimit's RemoteAddr-by-default keying and WithTrustedProxies, and
that WithMaxResponseBody returns ErrResponseTooLarge rather than truncating.
2026-05-23 13:47:43 +03:00
b5259af73e Honor RoundTripper contract in middleware; validate incoming X-Request-Id
BearerAuth, BasicAuth and DefaultHeaders mutated the caller's request, which
violates the RoundTripper contract and risks races on shared/retried requests;
clone before writing headers (matching RequestID). Validate the incoming
X-Request-Id (length and character set) before propagating it to logs and the
response header, preventing log forging and header splitting from a
client-controlled value.
2026-05-23 13:47:38 +03:00
01478be0dc Replace balancer panic with deferred error; test HealthChecker
A malformed endpoint URL panicked inside Transport, crashing the host app
(often at startup from external config). Capture the parse error and surface
it from the transport on first use instead. Add the previously untested
HealthChecker coverage (initial probe, recovery, Stop termination, unknown
endpoint), raising balancer coverage from ~41% to ~87%. Default the health
probe path to /healthz to match this library's own server.
2026-05-23 13:47:33 +03:00
b07d487e63 Drive circuit breaker state transitions via internal/clock
The Open->HalfOpen promotion used time.Now/time.Since directly, forcing tests
to use real time.Sleep and diverging from the project's clock convention. Add
an unexported withClock option (default clock.System) and replace the real
sleeps in tests with mock-clock Advance, making the transitions deterministic
and the package faster.
2026-05-23 13:47:26 +03:00
43d3ecfba1 Fix retry body replay and jitter panic; drive delays via internal/clock
Gate the retry decision on body rewindability: an idempotent request whose
body cannot be replayed (no GetBody) is now returned as-is instead of looping
with an empty body or surfacing a stale, already-drained response. Guard
ExponentialBackoff against rand.Int64N panicking when delay/2 rounds to zero.
Use internal/clock for inter-attempt delays so retry timing is consistent with
the rest of the codebase and testable without real sleeps.
2026-05-23 13:47:18 +03:00
e8c4577c6f Return ErrResponseTooLarge instead of truncating response body
WithMaxResponseBody wrapped the body in io.LimitedReader, which returns EOF
at the cap, so Bytes/JSON/XML silently returned a truncated body with a nil
error despite the documented contract. Read one byte past the limit and
return the new ErrResponseTooLarge sentinel when exceeded; bodies exactly at
the limit still succeed.
2026-05-23 13:47:13 +03:00
2d4a06e715 Harden RateLimit against X-Forwarded-For spoofing
Key on RemoteAddr by default; honor X-Forwarded-For only when the peer is
a configured trusted proxy (WithTrustedProxies), walking right-to-left to
the first untrusted hop. This closes a trivial rate-limit bypass and the
matching unbounded-bucket DoS via spoofed headers. Add WithMaxKeys with
opportunistic eviction of idle (fully-refilled) buckets to bound memory.
Drop the hand-rolled indexOf in favor of stdlib.
2026-05-23 13:47:08 +03:00
26 changed files with 694 additions and 120 deletions

View File

@@ -24,7 +24,7 @@ go vet ./... # static analysis
- **Client.Close()** stops the health checker goroutine - **Client.Close()** stops the health checker goroutine
- **Client.Patch()** — PATCH method, same pattern as Put/Post - **Client.Patch()** — PATCH method, same pattern as Put/Post
- **NewFormRequest** — form-encoded request builder (`application/x-www-form-urlencoded`) with `GetBody` for retry - **NewFormRequest** — form-encoded request builder (`application/x-www-form-urlencoded`) with `GetBody` for retry
- **WithMaxResponseBody** — wraps `resp.Body` with `io.LimitedReader` to prevent OOM - **WithMaxResponseBody** — caps `resp.Body` reads; returns `ErrResponseTooLarge` (not silent truncation) when exceeded
- **middleware.RequestID()** — propagates request ID from context to outgoing `X-Request-Id` header - **middleware.RequestID()** — propagates request ID from context to outgoing `X-Request-Id` header
- **`internal/requestid`** — shared context key used by both `server` and `middleware` packages to avoid circular imports - **`internal/requestid`** — shared context key used by both `server` and `middleware` packages to avoid circular imports
@@ -37,7 +37,7 @@ go vet ./... # static analysis
- **Defaults()** preset: RequestID → Recovery → Logging + production timeouts - **Defaults()** preset: RequestID → Recovery → Logging + production timeouts
- **HealthHandler** exposes `GET /healthz` (liveness) and `GET /readyz` (readiness with pluggable checkers) - **HealthHandler** exposes `GET /healthz` (liveness) and `GET /readyz` (readiness with pluggable checkers)
- **CORS** middleware — preflight OPTIONS handling, `AllowOrigins`, `AllowMethods`, `AllowHeaders`, `ExposeHeaders`, `AllowCredentials`, `MaxAge` - **CORS** middleware — preflight OPTIONS handling, `AllowOrigins`, `AllowMethods`, `AllowHeaders`, `ExposeHeaders`, `AllowCredentials`, `MaxAge`
- **RateLimit** middleware — per-key token bucket (`sync.Map`), IP from `X-Forwarded-For`, `WithRate`/`WithBurst`/`WithKeyFunc`, uses `internal/clock` - **RateLimit** middleware — per-key token bucket (`sync.Map`), keys on `RemoteAddr` by default; `X-Forwarded-For` is honored only via `WithTrustedProxies`; `WithRate`/`WithBurst`/`WithKeyFunc`/`WithMaxKeys`, uses `internal/clock`, idle buckets evicted to bound memory
- **MaxBodySize** middleware — wraps `r.Body` via `http.MaxBytesReader` - **MaxBodySize** middleware — wraps `r.Body` via `http.MaxBytesReader`
- **Timeout** middleware — wraps `http.TimeoutHandler`, returns 503 - **Timeout** middleware — wraps `http.TimeoutHandler`, returns 503
- **WriteJSON** / **WriteError** — JSON response helpers in `server/respond.go` - **WriteJSON** / **WriteError** — JSON response helpers in `server/respond.go`

View File

@@ -67,7 +67,7 @@ Server middleware is `func(http.Handler) http.Handler`. The `server` package pro
| `server.Logging` | Structured request logging (method, path, status, duration, request ID). | | `server.Logging` | Structured request logging (method, path, status, duration, request ID). |
| `server.HealthHandler` | Liveness (`/healthz`) and readiness (`/readyz`) endpoints with pluggable checkers. | | `server.HealthHandler` | Liveness (`/healthz`) and readiness (`/readyz`) endpoints with pluggable checkers. |
| `server.CORS` | Cross-origin resource sharing with preflight handling and functional options. | | `server.CORS` | Cross-origin resource sharing with preflight handling and functional options. |
| `server.RateLimit` | Per-key token bucket rate limiting with IP extraction and `Retry-After`. | | `server.RateLimit` | Per-key token bucket rate limiting (keys on `RemoteAddr`; `X-Forwarded-For` via `WithTrustedProxies`) with `Retry-After`. |
| `server.MaxBodySize` | Limits request body size via `http.MaxBytesReader`. | | `server.MaxBodySize` | Limits request body size via `http.MaxBytesReader`. |
| `server.Timeout` | Context-based request timeout, returns 503 on expiry. | | `server.Timeout` | Context-based request timeout, returns 503 on expiry. |
| `server.WriteJSON` | JSON response helper, sets Content-Type and status. | | `server.WriteJSON` | JSON response helper, sets Content-Type and status. |
@@ -195,6 +195,9 @@ client := httpx.New(
) )
``` ```
Reading a body that exceeds the limit returns `httpx.ErrResponseTooLarge`
(checkable with `errors.Is`) rather than silently truncating.
## Examples ## Examples
See the [`examples/`](examples/) directory for runnable programs: See the [`examples/`](examples/) directory for runnable programs:

View File

@@ -55,12 +55,19 @@ func Transport(endpoints []Endpoint, opts ...Option) (middleware.Middleware, *Cl
opt(o) opt(o)
} }
// Pre-parse endpoint URLs once at construction time. // Pre-parse endpoint URLs once at construction time. A malformed URL is a
// configuration error: rather than panicking (which would crash the host
// application, often at startup from external config), we capture the
// error and surface it from the transport on first use.
parsed := make(map[string]*url.URL, len(endpoints)) parsed := make(map[string]*url.URL, len(endpoints))
var parseErr error
for _, ep := range endpoints { for _, ep := range endpoints {
u, err := url.Parse(ep.URL) u, err := url.Parse(ep.URL)
if err != nil { if err != nil {
panic(fmt.Sprintf("balancer: invalid endpoint URL %q: %v", ep.URL, err)) if parseErr == nil {
parseErr = fmt.Errorf("balancer: invalid endpoint URL %q: %w", ep.URL, err)
}
continue
} }
parsed[ep.URL] = u parsed[ep.URL] = u
} }
@@ -73,6 +80,10 @@ func Transport(endpoints []Endpoint, opts ...Option) (middleware.Middleware, *Cl
return func(next http.RoundTripper) http.RoundTripper { return func(next http.RoundTripper) http.RoundTripper {
return middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) { return middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
if parseErr != nil {
return nil, parseErr
}
healthy := endpoints healthy := endpoints
if o.healthChecker != nil { if o.healthChecker != nil {
healthy = o.healthChecker.Healthy(endpoints) healthy = o.healthChecker.Healthy(endpoints)

View File

@@ -61,6 +61,27 @@ func TestTransport_PicksEndpointAndReplacesURL(t *testing.T) {
} }
} }
func TestTransport_InvalidEndpointURLReturnsError(t *testing.T) {
base := mockTransport(func(req *http.Request) (*http.Response, error) {
t.Fatal("base transport should not be reached for an invalid endpoint")
return nil, nil
})
// A malformed URL must not panic; the error surfaces on first use.
mw, closer := Transport([]Endpoint{{URL: "://missing-scheme"}})
defer closer.Close()
rt := mw(base)
req, err := http.NewRequest(http.MethodGet, "https://original.example.com/", nil)
if err != nil {
t.Fatal(err)
}
if _, err := rt.RoundTrip(req); err == nil {
t.Fatal("expected an error for invalid endpoint URL, got nil")
}
}
func TestTransport_ErrNoHealthyWhenNoEndpoints(t *testing.T) { func TestTransport_ErrNoHealthyWhenNoEndpoints(t *testing.T) {
var endpoints []Endpoint var endpoints []Endpoint
base := mockTransport(func(req *http.Request) (*http.Response, error) { base := mockTransport(func(req *http.Request) (*http.Response, error) {

View File

@@ -10,7 +10,7 @@ import (
const ( const (
defaultHealthInterval = 10 * time.Second defaultHealthInterval = 10 * time.Second
defaultHealthPath = "/health" defaultHealthPath = "/healthz"
defaultHealthTimeout = 5 * time.Second defaultHealthTimeout = 5 * time.Second
) )

99
balancer/health_test.go Normal file
View File

@@ -0,0 +1,99 @@
package balancer
import (
"context"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"time"
)
func TestHealthChecker_InitialProbeClassifiesEndpoints(t *testing.T) {
healthy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer healthy.Close()
unhealthy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusServiceUnavailable)
}))
defer unhealthy.Close()
eps := []Endpoint{{URL: healthy.URL}, {URL: unhealthy.URL}}
hc := newHealthChecker()
hc.Start(eps) // runs an initial synchronous probe
defer hc.Stop()
if !hc.IsHealthy(eps[0]) {
t.Errorf("healthy endpoint reported unhealthy")
}
if hc.IsHealthy(eps[1]) {
t.Errorf("unhealthy endpoint reported healthy")
}
got := hc.Healthy(eps)
if len(got) != 1 || got[0].URL != healthy.URL {
t.Errorf("Healthy() = %v, want only %s", got, healthy.URL)
}
}
func TestHealthChecker_DetectsRecovery(t *testing.T) {
var up atomic.Bool
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
if up.Load() {
w.WriteHeader(http.StatusOK)
return
}
w.WriteHeader(http.StatusServiceUnavailable)
}))
defer srv.Close()
eps := []Endpoint{{URL: srv.URL}}
hc := newHealthChecker()
hc.Start(eps)
defer hc.Stop()
if hc.IsHealthy(eps[0]) {
t.Fatalf("endpoint should start unhealthy")
}
// Recover the backend and force a deterministic re-probe.
up.Store(true)
hc.probe(context.Background(), eps)
if !hc.IsHealthy(eps[0]) {
t.Fatalf("endpoint should be healthy after recovery")
}
}
func TestHealthChecker_StopTerminatesLoop(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()
hc := newHealthChecker(WithHealthInterval(time.Millisecond))
hc.Start([]Endpoint{{URL: srv.URL}})
done := make(chan struct{})
go func() {
hc.Stop()
close(done)
}()
select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatal("Stop did not return within 2s — loop goroutine leaked")
}
}
func TestHealthChecker_UnknownEndpointIsUnhealthy(t *testing.T) {
hc := newHealthChecker()
if hc.IsHealthy(Endpoint{URL: "http://never-probed.example"}) {
t.Error("unknown endpoint should be reported unhealthy")
}
}

View File

@@ -72,7 +72,7 @@ func (b *Breaker) State() State {
// stateLocked returns the effective state, promoting Open → HalfOpen when the // stateLocked returns the effective state, promoting Open → HalfOpen when the
// open duration has elapsed. Caller must hold b.mu. // open duration has elapsed. Caller must hold b.mu.
func (b *Breaker) stateLocked() State { func (b *Breaker) stateLocked() State {
if b.state == StateOpen && time.Since(b.openedAt) >= b.opts.openDuration { if b.state == StateOpen && b.opts.clk.Since(b.openedAt) >= b.opts.openDuration {
b.state = StateHalfOpen b.state = StateHalfOpen
b.halfOpenCur = 0 b.halfOpenCur = 0
} }
@@ -142,7 +142,7 @@ func (b *Breaker) record(success bool) {
// tripLocked transitions to the Open state and records the timestamp. // tripLocked transitions to the Open state and records the timestamp.
func (b *Breaker) tripLocked() { func (b *Breaker) tripLocked() {
b.state = StateOpen b.state = StateOpen
b.openedAt = time.Now() b.openedAt = b.opts.clk.Now()
b.halfOpenCur = 0 b.halfOpenCur = 0
} }

View File

@@ -8,6 +8,7 @@ import (
"testing" "testing"
"time" "time"
"git.codelab.vc/pkg/httpx/internal/clock"
"git.codelab.vc/pkg/httpx/middleware" "git.codelab.vc/pkg/httpx/middleware"
) )
@@ -80,9 +81,11 @@ func TestBreaker_OpenRejectsRequests(t *testing.T) {
func TestBreaker_TransitionsToHalfOpenAfterDuration(t *testing.T) { func TestBreaker_TransitionsToHalfOpenAfterDuration(t *testing.T) {
const openDuration = 50 * time.Millisecond const openDuration = 50 * time.Millisecond
clk := clock.Mock(time.Now())
b := NewBreaker( b := NewBreaker(
WithFailureThreshold(1), WithFailureThreshold(1),
WithOpenDuration(openDuration), WithOpenDuration(openDuration),
withClock(clk),
) )
// Trip the breaker. // Trip the breaker.
@@ -96,8 +99,8 @@ func TestBreaker_TransitionsToHalfOpenAfterDuration(t *testing.T) {
t.Fatalf("state = %v, want %v", s, StateOpen) t.Fatalf("state = %v, want %v", s, StateOpen)
} }
// Wait for the open duration to elapse. // Advance past the open duration.
time.Sleep(openDuration + 10*time.Millisecond) clk.Advance(openDuration + time.Millisecond)
if s := b.State(); s != StateHalfOpen { if s := b.State(); s != StateHalfOpen {
t.Fatalf("state = %v, want %v", s, StateHalfOpen) t.Fatalf("state = %v, want %v", s, StateHalfOpen)
@@ -106,9 +109,11 @@ func TestBreaker_TransitionsToHalfOpenAfterDuration(t *testing.T) {
func TestBreaker_HalfOpenToClosedOnSuccess(t *testing.T) { func TestBreaker_HalfOpenToClosedOnSuccess(t *testing.T) {
const openDuration = 50 * time.Millisecond const openDuration = 50 * time.Millisecond
clk := clock.Mock(time.Now())
b := NewBreaker( b := NewBreaker(
WithFailureThreshold(1), WithFailureThreshold(1),
WithOpenDuration(openDuration), WithOpenDuration(openDuration),
withClock(clk),
) )
// Trip the breaker. // Trip the breaker.
@@ -118,8 +123,8 @@ func TestBreaker_HalfOpenToClosedOnSuccess(t *testing.T) {
} }
done(false) done(false)
// Wait for half-open. // Advance into half-open.
time.Sleep(openDuration + 10*time.Millisecond) clk.Advance(openDuration + time.Millisecond)
// A successful request in half-open should close the breaker. // A successful request in half-open should close the breaker.
done, err = b.Allow() done, err = b.Allow()
@@ -135,9 +140,11 @@ func TestBreaker_HalfOpenToClosedOnSuccess(t *testing.T) {
func TestBreaker_HalfOpenToOpenOnFailure(t *testing.T) { func TestBreaker_HalfOpenToOpenOnFailure(t *testing.T) {
const openDuration = 50 * time.Millisecond const openDuration = 50 * time.Millisecond
clk := clock.Mock(time.Now())
b := NewBreaker( b := NewBreaker(
WithFailureThreshold(1), WithFailureThreshold(1),
WithOpenDuration(openDuration), WithOpenDuration(openDuration),
withClock(clk),
) )
// Trip the breaker. // Trip the breaker.
@@ -147,8 +154,8 @@ func TestBreaker_HalfOpenToOpenOnFailure(t *testing.T) {
} }
done(false) done(false)
// Wait for half-open. // Advance into half-open.
time.Sleep(openDuration + 10*time.Millisecond) clk.Advance(openDuration + time.Millisecond)
// A failed request in half-open should re-open the breaker. // A failed request in half-open should re-open the breaker.
done, err = b.Allow() done, err = b.Allow()

View File

@@ -1,11 +1,16 @@
package circuitbreaker package circuitbreaker
import "time" import (
"time"
"git.codelab.vc/pkg/httpx/internal/clock"
)
type options struct { type options struct {
failureThreshold int // consecutive failures to trip failureThreshold int // consecutive failures to trip
openDuration time.Duration // how long to stay open before half-open openDuration time.Duration // how long to stay open before half-open
halfOpenMax int // max concurrent requests in half-open halfOpenMax int // max concurrent requests in half-open
clk clock.Clock // time source (real by default)
} }
func defaults() options { func defaults() options {
@@ -13,12 +18,23 @@ func defaults() options {
failureThreshold: 5, failureThreshold: 5,
openDuration: 30 * time.Second, openDuration: 30 * time.Second,
halfOpenMax: 1, halfOpenMax: 1,
clk: clock.System(),
} }
} }
// Option configures a Breaker. // Option configures a Breaker.
type Option func(*options) type Option func(*options)
// withClock sets the clock used for state-transition timing. Unexported; for
// deterministic tests.
func withClock(c clock.Clock) Option {
return func(o *options) {
if c != nil {
o.clk = c
}
}
}
// WithFailureThreshold sets the number of consecutive failures required to // WithFailureThreshold sets the number of consecutive failures required to
// trip the breaker from Closed to Open. Default is 5. // trip the breaker from Closed to Open. Default is 5.
func WithFailureThreshold(n int) Option { func WithFailureThreshold(n int) Option {

View File

@@ -102,9 +102,12 @@ func (c *Client) Do(ctx context.Context, req *http.Request) (*Response, error) {
} }
if c.maxResponseBody > 0 { if c.maxResponseBody > 0 {
// Read one byte past the limit so we can distinguish "exactly at the
// limit" (allowed) from "exceeds the limit" (ErrResponseTooLarge).
resp.Body = &limitedReadCloser{ resp.Body = &limitedReadCloser{
R: io.LimitedReader{R: resp.Body, N: c.maxResponseBody}, r: io.LimitReader(resp.Body, c.maxResponseBody+1),
C: resp.Body, c: resp.Body,
limit: c.maxResponseBody,
} }
} }

View File

@@ -89,8 +89,8 @@ func WithBalancer(opts ...balancer.Option) Option {
// WithMaxResponseBody limits the number of bytes read from response bodies // WithMaxResponseBody limits the number of bytes read from response bodies
// by Response.Bytes (and by extension String, JSON, XML). If the response // by Response.Bytes (and by extension String, JSON, XML). If the response
// body exceeds n bytes, reading stops and returns an error. // body exceeds n bytes, reading returns ErrResponseTooLarge instead of
// A value of 0 means no limit (the default). // silently truncating. A value of 0 means no limit (the default).
func WithMaxResponseBody(n int64) Option { func WithMaxResponseBody(n int64) Option {
return func(o *clientOptions) { o.maxResponseBody = n } return func(o *clientOptions) { o.maxResponseBody = n }
} }

View File

@@ -1,6 +1,7 @@
package httpx package httpx
import ( import (
"errors"
"fmt" "fmt"
"net/http" "net/http"
@@ -18,6 +19,11 @@ var (
ErrNoHealthy = balancer.ErrNoHealthy ErrNoHealthy = balancer.ErrNoHealthy
) )
// ErrResponseTooLarge is returned when reading a response body that exceeds
// the limit configured via WithMaxResponseBody. Any bytes read up to the
// limit are returned alongside the error.
var ErrResponseTooLarge = errors.New("httpx: response body exceeds configured limit")
// Error provides structured error information for failed HTTP operations. // Error provides structured error information for failed HTTP operations.
type Error struct { type Error struct {
// Op is the operation that failed (e.g. "Get", "Do"). // Op is the operation that failed (e.g. "Get", "Do").

View File

@@ -15,6 +15,9 @@ func BearerAuth(tokenFunc func(ctx context.Context) (string, error)) Middleware
if err != nil { if err != nil {
return nil, err return nil, err
} }
// RoundTrippers must not mutate the caller's request; clone before
// setting headers (req.Clone is shallow + a header copy).
req = req.Clone(req.Context())
req.Header.Set("Authorization", "Bearer "+token) req.Header.Set("Authorization", "Bearer "+token)
return next.RoundTrip(req) return next.RoundTrip(req)
}) })
@@ -26,6 +29,7 @@ func BearerAuth(tokenFunc func(ctx context.Context) (string, error)) Middleware
func BasicAuth(username, password string) Middleware { func BasicAuth(username, password string) Middleware {
return func(next http.RoundTripper) http.RoundTripper { return func(next http.RoundTripper) http.RoundTripper {
return RoundTripperFunc(func(req *http.Request) (*http.Response, error) { return RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
req = req.Clone(req.Context())
req.SetBasicAuth(username, password) req.SetBasicAuth(username, password)
return next.RoundTrip(req) return next.RoundTrip(req)
}) })

View File

@@ -7,10 +7,17 @@ import "net/http"
func DefaultHeaders(headers http.Header) Middleware { func DefaultHeaders(headers http.Header) Middleware {
return func(next http.RoundTripper) http.RoundTripper { return func(next http.RoundTripper) http.RoundTripper {
return RoundTripperFunc(func(req *http.Request) (*http.Response, error) { return RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
// Clone lazily on the first header we actually add, so that
// RoundTrippers never mutate the caller's request.
cloned := false
for key, values := range headers { for key, values := range headers {
if req.Header.Get(key) != "" { if req.Header.Get(key) != "" {
continue continue
} }
if !cloned {
req = req.Clone(req.Context())
cloned = true
}
for _, v := range values { for _, v := range values {
req.Header.Add(key, v) req.Header.Add(key, v)
} }

View File

@@ -98,17 +98,26 @@ func (r *Response) BodyReader() io.Reader {
return r.Body return r.Body
} }
// limitedReadCloser wraps an io.LimitedReader with a separate Closer // limitedReadCloser enforces a maximum number of bytes that may be read from
// so the original body can be closed. // a response body. Reading more than limit bytes returns ErrResponseTooLarge
// rather than silently truncating the body. The original body is closed via
// the separate Closer.
type limitedReadCloser struct { type limitedReadCloser struct {
R io.LimitedReader r io.Reader // an io.LimitReader over the original body (limit+1 bytes)
C io.Closer c io.Closer // the original body, for Close
limit int64
read int64
} }
func (l *limitedReadCloser) Read(p []byte) (int, error) { func (l *limitedReadCloser) Read(p []byte) (int, error) {
return l.R.Read(p) n, err := l.r.Read(p)
l.read += int64(n)
if l.read > l.limit {
return n, ErrResponseTooLarge
}
return n, err
} }
func (l *limitedReadCloser) Close() error { func (l *limitedReadCloser) Close() error {
return l.C.Close() return l.c.Close()
} }

View File

@@ -2,6 +2,7 @@ package httpx_test
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@@ -32,13 +33,30 @@ func TestClient_MaxResponseBody(t *testing.T) {
} }
}) })
t.Run("truncates response exceeding limit", func(t *testing.T) { t.Run("returns ErrResponseTooLarge when exceeding limit", func(t *testing.T) {
largeBody := strings.Repeat("x", 1000) largeBody := strings.Repeat("x", 1000)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
fmt.Fprint(w, largeBody) fmt.Fprint(w, largeBody)
})) }))
defer srv.Close() defer srv.Close()
client := httpx.New(httpx.WithMaxResponseBody(100))
resp, err := client.Get(context.Background(), srv.URL+"/")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if _, err := resp.Bytes(); !errors.Is(err, httpx.ErrResponseTooLarge) {
t.Fatalf("err = %v, want ErrResponseTooLarge", err)
}
})
t.Run("allows body exactly at limit", func(t *testing.T) {
exact := strings.Repeat("x", 100)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
fmt.Fprint(w, exact)
}))
defer srv.Close()
client := httpx.New(httpx.WithMaxResponseBody(100)) client := httpx.New(httpx.WithMaxResponseBody(100))
resp, err := client.Get(context.Background(), srv.URL+"/") resp, err := client.Get(context.Background(), srv.URL+"/")
if err != nil { if err != nil {
@@ -46,7 +64,7 @@ func TestClient_MaxResponseBody(t *testing.T) {
} }
b, err := resp.Bytes() b, err := resp.Bytes()
if err != nil { if err != nil {
t.Fatalf("reading body: %v", err) t.Fatalf("reading body at exact limit: %v", err)
} }
if len(b) != 100 { if len(b) != 100 {
t.Fatalf("body length = %d, want %d", len(b), 100) t.Fatalf("body length = %d, want %d", len(b), 100)

View File

@@ -44,8 +44,11 @@ func (b *exponentialBackoff) Delay(attempt int) time.Duration {
} }
if b.withJitter { if b.withJitter {
jitter := time.Duration(rand.Int64N(int64(delay / 2))) // Guard against rand.Int64N panicking on a non-positive argument when
delay += jitter // delay is small enough that delay/2 rounds to zero.
if half := int64(delay / 2); half > 0 {
delay += time.Duration(rand.Int64N(half))
}
} }
if delay > b.max { if delay > b.max {

View File

@@ -1,12 +1,17 @@
package retry package retry
import "time" import (
"time"
"git.codelab.vc/pkg/httpx/internal/clock"
)
type options struct { type options struct {
maxAttempts int // default 3 maxAttempts int // default 3
backoff Backoff // default ExponentialBackoff(100ms, 5s, true) backoff Backoff // default ExponentialBackoff(100ms, 5s, true)
policy Policy // default: defaultPolicy (retry on 5xx and network errors) policy Policy // default: defaultPolicy (retry on 5xx and network errors)
retryAfter bool // default true, respect Retry-After header retryAfter bool // default true, respect Retry-After header
clk clock.Clock // time source for backoff delays (real by default)
} }
// Option configures the retry transport. // Option configures the retry transport.
@@ -18,6 +23,17 @@ func defaults() options {
backoff: ExponentialBackoff(100*time.Millisecond, 5*time.Second, true), backoff: ExponentialBackoff(100*time.Millisecond, 5*time.Second, true),
policy: defaultPolicy{}, policy: defaultPolicy{},
retryAfter: true, retryAfter: true,
clk: clock.System(),
}
}
// withClock sets the clock used for inter-attempt delays. Unexported; for
// deterministic tests.
func withClock(c clock.Clock) Option {
return func(o *options) {
if c != nil {
o.clk = c
}
} }
} }

View File

@@ -37,18 +37,16 @@ func Transport(opts ...Option) middleware.Middleware {
var exhausted bool var exhausted bool
for attempt := range cfg.maxAttempts { for attempt := range cfg.maxAttempts {
// For retries (attempt > 0), restore the request body. // For retries (attempt > 0) the body was consumed by the
if attempt > 0 { // previous attempt; restore it via GetBody. The rewindability
if req.GetBody != nil { // check below guarantees GetBody is set whenever we loop with a
body, bodyErr := req.GetBody() // non-nil body, so this branch is always safe.
if bodyErr != nil { if attempt > 0 && req.GetBody != nil {
return resp, bodyErr body, bodyErr := req.GetBody()
} if bodyErr != nil {
req.Body = body return nil, bodyErr
} else if req.Body != nil {
// Body was consumed and cannot be re-created.
return resp, err
} }
req.Body = body
} }
resp, err = next.RoundTrip(req) resp, err = next.RoundTrip(req)
@@ -64,6 +62,13 @@ func Transport(opts ...Option) middleware.Middleware {
break break
} }
// If the body cannot be rewound, a retry would replay with an
// empty body. Return the current result as-is instead of
// draining it and looping with a corrupted request.
if req.Body != nil && req.GetBody == nil {
break
}
// Compute delay: use backoff or policy delay, whichever is larger. // Compute delay: use backoff or policy delay, whichever is larger.
delay := cfg.backoff.Delay(attempt) delay := cfg.backoff.Delay(attempt)
if policyDelay > delay { if policyDelay > delay {
@@ -84,12 +89,12 @@ func Transport(opts ...Option) middleware.Middleware {
} }
// Wait for the delay or context cancellation. // Wait for the delay or context cancellation.
timer := time.NewTimer(delay) timer := cfg.clk.NewTimer(delay)
select { select {
case <-req.Context().Done(): case <-req.Context().Done():
timer.Stop() timer.Stop()
return nil, req.Context().Err() return nil, req.Context().Err()
case <-timer.C: case <-timer.C():
} }
} }
@@ -118,9 +123,9 @@ func (defaultPolicy) ShouldRetry(_ int, req *http.Request, resp *http.Response,
switch resp.StatusCode { switch resp.StatusCode {
case http.StatusTooManyRequests, // 429 case http.StatusTooManyRequests, // 429
http.StatusBadGateway, // 502 http.StatusBadGateway, // 502
http.StatusServiceUnavailable, // 503 http.StatusServiceUnavailable, // 503
http.StatusGatewayTimeout: // 504 http.StatusGatewayTimeout: // 504
return true, 0 return true, 0
} }

View File

@@ -10,6 +10,7 @@ import (
"testing" "testing"
"time" "time"
"git.codelab.vc/pkg/httpx/internal/clock"
"git.codelab.vc/pkg/httpx/middleware" "git.codelab.vc/pkg/httpx/middleware"
) )
@@ -229,6 +230,83 @@ func TestTransport(t *testing.T) {
}) })
} }
// TestTransport_BodyNotRewindable verifies that an idempotent request whose
// body cannot be replayed (no GetBody) is returned as-is rather than retried
// with an empty body or a stale, already-drained response.
func TestTransport_BodyNotRewindable(t *testing.T) {
var calls atomic.Int32
rt := Transport(
WithMaxAttempts(3),
WithBackoff(ConstantBackoff(time.Millisecond)),
)(mockTransport(func(req *http.Request) (*http.Response, error) {
calls.Add(1)
io.Copy(io.Discard, req.Body) // a real transport consumes the body
return statusResponse(http.StatusServiceUnavailable), nil
}))
// PUT is idempotent (the policy would retry a 503), but with GetBody unset
// the body cannot be rewound.
req, _ := http.NewRequest(http.MethodPut, "http://example.com", strings.NewReader("data"))
req.GetBody = nil
resp, err := rt.RoundTrip(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp == nil || resp.StatusCode != http.StatusServiceUnavailable {
t.Fatalf("expected the original 503 response, got %v", resp)
}
if got := calls.Load(); got != 1 {
t.Fatalf("expected exactly 1 call (no rewind retry), got %d", got)
}
}
// TestTransport_InjectedClock verifies that backoff delays are driven by the
// configured clock, so retries are deterministic without real sleeps.
func TestTransport_InjectedClock(t *testing.T) {
clk := clock.Mock(time.Now())
var calls atomic.Int32
rt := Transport(
WithMaxAttempts(2),
WithBackoff(ConstantBackoff(time.Hour)), // would block forever on a real clock
withClock(clk),
)(mockTransport(func(req *http.Request) (*http.Response, error) {
calls.Add(1)
return statusResponse(http.StatusServiceUnavailable), nil
}))
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
done := make(chan struct{})
var resp *http.Response
var err error
go func() {
resp, err = rt.RoundTrip(req)
close(done)
}()
// Drive the backoff via the mock clock. Advancing repeatedly is robust
// against the timer being created slightly after the first attempt.
for {
clk.Advance(time.Hour)
select {
case <-done:
goto finished
case <-time.After(time.Millisecond):
}
}
finished:
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.StatusCode != http.StatusServiceUnavailable {
t.Fatalf("expected 503, got %d", resp.StatusCode)
}
if got := calls.Load(); got != 2 {
t.Fatalf("expected 2 calls, got %d", got)
}
}
// policyFunc adapts a function into a Policy. // policyFunc adapts a function into a Policy.
type policyFunc func(int, *http.Request, *http.Response, error) (bool, time.Duration) type policyFunc func(int, *http.Request, *http.Response, error) (bool, time.Duration)

View File

@@ -25,8 +25,8 @@ func Chain(mws ...Middleware) Middleware {
// underlying ResponseWriter's optional interfaces (Flusher, Hijacker, etc.). // underlying ResponseWriter's optional interfaces (Flusher, Hijacker, etc.).
type statusWriter struct { type statusWriter struct {
http.ResponseWriter http.ResponseWriter
status int status int
written bool written bool
} }
// WriteHeader captures the status code and delegates to the underlying writer. // WriteHeader captures the status code and delegates to the underlying writer.

View File

@@ -4,17 +4,25 @@ import (
"net" "net"
"net/http" "net/http"
"strconv" "strconv"
"strings"
"sync" "sync"
"sync/atomic"
"time" "time"
"git.codelab.vc/pkg/httpx/internal/clock" "git.codelab.vc/pkg/httpx/internal/clock"
) )
// defaultMaxKeys bounds the number of distinct rate-limit buckets retained in
// memory. When exceeded, fully-refilled (idle) buckets are evicted.
const defaultMaxKeys = 1 << 16
type rateLimitOptions struct { type rateLimitOptions struct {
rate float64 rate float64
burst int burst int
keyFunc func(r *http.Request) string keyFunc func(r *http.Request) string
clock clock.Clock clock clock.Clock
trustedProxies []*net.IPNet
maxKeys int
} }
// RateLimitOption configures the RateLimit middleware. // RateLimitOption configures the RateLimit middleware.
@@ -31,11 +39,55 @@ func WithBurst(n int) RateLimitOption {
} }
// WithKeyFunc sets a custom function to extract the rate-limit key from a // WithKeyFunc sets a custom function to extract the rate-limit key from a
// request. By default, the client IP address is used. // request. By default, the client IP from RemoteAddr is used (see
// WithTrustedProxies to honor X-Forwarded-For behind a trusted proxy).
func WithKeyFunc(fn func(r *http.Request) string) RateLimitOption { func WithKeyFunc(fn func(r *http.Request) string) RateLimitOption {
return func(o *rateLimitOptions) { o.keyFunc = fn } return func(o *rateLimitOptions) { o.keyFunc = fn }
} }
// WithTrustedProxies enables X-Forwarded-For parsing, but only for requests
// whose immediate peer (RemoteAddr) falls within one of the given trusted
// CIDR ranges (e.g. "10.0.0.0/8", "192.168.0.0/16"). A bare IP is accepted as
// a /32 or /128. When the peer is trusted, the client key is taken from the
// right-most X-Forwarded-For entry that is not itself a trusted proxy;
// otherwise RemoteAddr is used. Invalid entries are ignored (treated as
// untrusted), so a typo can never silently widen trust.
//
// Without this option the middleware never trusts client-supplied forwarding
// headers, which prevents trivial rate-limit bypass and bucket exhaustion via
// spoofed headers.
func WithTrustedProxies(cidrs ...string) RateLimitOption {
return func(o *rateLimitOptions) {
for _, c := range cidrs {
if _, ipnet, err := net.ParseCIDR(c); err == nil {
o.trustedProxies = append(o.trustedProxies, ipnet)
continue
}
if ip := net.ParseIP(c); ip != nil {
bits := 32
if ip.To4() == nil {
bits = 128
}
o.trustedProxies = append(o.trustedProxies, &net.IPNet{
IP: ip,
Mask: net.CIDRMask(bits, bits),
})
}
}
}
}
// WithMaxKeys sets the soft upper bound on the number of distinct buckets
// retained in memory. When exceeded, idle (fully-refilled) buckets are
// evicted; active buckets are never dropped. Default is 65536.
func WithMaxKeys(n int) RateLimitOption {
return func(o *rateLimitOptions) {
if n > 0 {
o.maxKeys = n
}
}
}
// withRateLimitClock sets the clock for testing. Not exported. // withRateLimitClock sets the clock for testing. Not exported.
func withRateLimitClock(c clock.Clock) RateLimitOption { func withRateLimitClock(c clock.Clock) RateLimitOption {
return func(o *rateLimitOptions) { o.clock = c } return func(o *rateLimitOptions) { o.clock = c }
@@ -44,86 +96,153 @@ func withRateLimitClock(c clock.Clock) RateLimitOption {
// RateLimit returns a middleware that limits requests using a per-key token // RateLimit returns a middleware that limits requests using a per-key token
// bucket algorithm. When the limit is exceeded, it returns 429 Too Many // bucket algorithm. When the limit is exceeded, it returns 429 Too Many
// Requests with a Retry-After header. // Requests with a Retry-After header.
//
// By default the key is the client IP taken from RemoteAddr. Forwarding
// headers (X-Forwarded-For) are honored only when WithTrustedProxies is set,
// so the limiter cannot be bypassed by spoofing headers.
func RateLimit(opts ...RateLimitOption) Middleware { func RateLimit(opts ...RateLimitOption) Middleware {
o := &rateLimitOptions{ o := &rateLimitOptions{
rate: 10, rate: 10,
burst: 20, burst: 20,
clock: clock.System(), clock: clock.System(),
maxKeys: defaultMaxKeys,
} }
for _, opt := range opts { for _, opt := range opts {
opt(o) opt(o)
} }
if o.keyFunc == nil { if o.keyFunc == nil {
o.keyFunc = clientIP o.keyFunc = o.clientKey
} }
var buckets sync.Map lim := &limiter{opts: o}
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
key := o.keyFunc(r) key := o.keyFunc(r)
val, _ := buckets.LoadOrStore(key, &bucket{ if allowed, retryAfter := lim.allow(key); !allowed {
tokens: float64(o.burst),
lastTime: o.clock.Now(),
})
b := val.(*bucket)
b.mu.Lock()
now := o.clock.Now()
elapsed := now.Sub(b.lastTime).Seconds()
b.tokens += elapsed * o.rate
if b.tokens > float64(o.burst) {
b.tokens = float64(o.burst)
}
b.lastTime = now
if b.tokens < 1 {
retryAfter := (1 - b.tokens) / o.rate
b.mu.Unlock()
w.Header().Set("Retry-After", strconv.Itoa(int(retryAfter)+1)) w.Header().Set("Retry-After", strconv.Itoa(int(retryAfter)+1))
http.Error(w, "Too Many Requests", http.StatusTooManyRequests) http.Error(w, "Too Many Requests", http.StatusTooManyRequests)
return return
} }
b.tokens--
b.mu.Unlock()
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
}) })
} }
} }
// limiter holds the per-key token buckets for one RateLimit middleware.
type limiter struct {
opts *rateLimitOptions
buckets sync.Map // key -> *bucket
count atomic.Int64
sweeping atomic.Bool
}
// allow reports whether a request for key may proceed. When denied it also
// returns the suggested Retry-After delay in seconds.
func (l *limiter) allow(key string) (bool, float64) {
o := l.opts
val, loaded := l.buckets.LoadOrStore(key, &bucket{
tokens: float64(o.burst),
lastTime: o.clock.Now(),
})
if !loaded && l.count.Add(1) > int64(o.maxKeys) {
l.sweep()
}
b := val.(*bucket)
b.mu.Lock()
defer b.mu.Unlock()
now := o.clock.Now()
elapsed := now.Sub(b.lastTime).Seconds()
b.tokens += elapsed * o.rate
if b.tokens > float64(o.burst) {
b.tokens = float64(o.burst)
}
b.lastTime = now
if b.tokens < 1 {
return false, (1 - b.tokens) / o.rate
}
b.tokens--
return true, 0
}
// sweep removes fully-refilled (idle) buckets to bound memory. Only one sweep
// runs at a time; buckets that still hold a partial limit are preserved so
// that eviction can never reset an active client's allowance.
func (l *limiter) sweep() {
if !l.sweeping.CompareAndSwap(false, true) {
return
}
defer l.sweeping.Store(false)
o := l.opts
now := o.clock.Now()
l.buckets.Range(func(k, v any) bool {
b := v.(*bucket)
b.mu.Lock()
elapsed := now.Sub(b.lastTime).Seconds()
full := b.tokens+elapsed*o.rate >= float64(o.burst)
b.mu.Unlock()
if full {
l.buckets.Delete(k)
l.count.Add(-1)
}
return true
})
}
type bucket struct { type bucket struct {
mu sync.Mutex mu sync.Mutex
tokens float64 tokens float64
lastTime time.Time lastTime time.Time
} }
// clientIP extracts the client IP from the request. It checks // clientKey derives the rate-limit key from a request. It uses RemoteAddr by
// X-Forwarded-For first, then X-Real-Ip, and falls back to RemoteAddr. // default and only consults X-Forwarded-For when the peer is a configured
func clientIP(r *http.Request) string { // trusted proxy (see WithTrustedProxies).
if xff := r.Header.Get("X-Forwarded-For"); xff != "" { func (o *rateLimitOptions) clientKey(r *http.Request) string {
// First IP in the comma-separated list is the original client. remote := remoteIP(r)
if i := indexOf(xff, ','); i > 0 { if len(o.trustedProxies) == 0 || !o.isTrusted(remote) {
return xff[:i] return remote
}
// Peer is trusted: walk X-Forwarded-For right-to-left and return the first
// address that is not itself a trusted proxy — that is the real client.
xff := r.Header.Get("X-Forwarded-For")
if xff == "" {
return remote
}
parts := strings.Split(xff, ",")
for i := len(parts) - 1; i >= 0; i-- {
ip := strings.TrimSpace(parts[i])
if ip == "" || o.isTrusted(ip) {
continue
} }
return xff return ip
} }
if xri := r.Header.Get("X-Real-Ip"); xri != "" { return remote
return xri }
func (o *rateLimitOptions) isTrusted(ip string) bool {
parsed := net.ParseIP(ip)
if parsed == nil {
return false
} }
for _, n := range o.trustedProxies {
if n.Contains(parsed) {
return true
}
}
return false
}
// remoteIP returns the host portion of r.RemoteAddr, or the raw value if it
// has no port.
func remoteIP(r *http.Request) string {
host, _, err := net.SplitHostPort(r.RemoteAddr) host, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil { if err != nil {
return r.RemoteAddr return r.RemoteAddr
} }
return host return host
} }
func indexOf(s string, b byte) int {
for i := range len(s) {
if s[i] == b {
return i
}
}
return -1
}

View File

@@ -0,0 +1,62 @@
package server
import (
"net/http"
"testing"
"time"
"git.codelab.vc/pkg/httpx/internal/clock"
)
func newTestRequest(remoteAddr, xff string) *http.Request {
r := &http.Request{RemoteAddr: remoteAddr, Header: http.Header{}}
if xff != "" {
r.Header.Set("X-Forwarded-For", xff)
}
return r
}
// TestLimiterSweepEvictsIdleBuckets verifies that sweep removes fully-refilled
// (idle) buckets while preserving buckets that still hold an active limit, so
// memory is bounded without resetting live clients' allowances.
func TestLimiterSweepEvictsIdleBuckets(t *testing.T) {
clk := clock.Mock(time.Now())
o := &rateLimitOptions{rate: 1, burst: 5, clock: clk, maxKeys: 1 << 30}
lim := &limiter{opts: o}
// "idle" makes a single request, then time passes so it refills to full.
lim.allow("idle")
clk.Advance(10 * time.Second)
// "active" drains its whole burst at the (advanced) current time.
for i := 0; i < 6; i++ {
lim.allow("active")
}
lim.sweep()
if _, ok := lim.buckets.Load("idle"); ok {
t.Error("fully-refilled idle bucket was not evicted")
}
if _, ok := lim.buckets.Load("active"); !ok {
t.Error("active bucket with a partial limit was wrongly evicted")
}
}
// TestClientKeyTrustedProxy exercises the X-Forwarded-For walk used behind a
// trusted proxy, independent of the HTTP layer.
func TestClientKeyTrustedProxy(t *testing.T) {
o := &rateLimitOptions{}
WithTrustedProxies("192.168.0.0/16")(o)
r := newTestRequest("192.168.1.10:443", "203.0.113.7, 192.168.1.10")
if got := o.clientKey(r); got != "203.0.113.7" {
t.Fatalf("clientKey = %q, want real client 203.0.113.7", got)
}
// Untrusted peer: X-Forwarded-For must be ignored entirely.
r = newTestRequest("203.0.113.7:443", "10.0.0.1")
if got := o.clientKey(r); got != "203.0.113.7" {
t.Fatalf("clientKey = %q, want peer 203.0.113.7 (XFF ignored)", got)
}
}

View File

@@ -83,28 +83,59 @@ func TestRateLimit(t *testing.T) {
} }
}) })
t.Run("uses X-Forwarded-For", func(t *testing.T) { t.Run("ignores X-Forwarded-For without trusted proxies", func(t *testing.T) {
// By default the limiter keys on RemoteAddr only. A spoofed,
// per-request X-Forwarded-For must not let a single peer bypass the
// limit by minting a fresh bucket each time.
mw := server.RateLimit( mw := server.RateLimit(
server.WithRate(1), server.WithRate(1),
server.WithBurst(1), server.WithBurst(1),
)(okHandler) )(okHandler)
// Exhaust limit for 10.0.0.1. send := func(xff string) int {
w := httptest.NewRecorder() w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("X-Forwarded-For", "10.0.0.1, 192.168.1.1") req.Header.Set("X-Forwarded-For", xff)
req.RemoteAddr = "192.168.1.1:1234" req.RemoteAddr = "192.168.1.1:1234"
mw.ServeHTTP(w, req) mw.ServeHTTP(w, req)
return w.Code
}
// Same forwarded IP should be rate limited. if code := send("10.0.0.1"); code != http.StatusOK {
w = httptest.NewRecorder() t.Fatalf("first request: got %d, want %d", code, http.StatusOK)
req = httptest.NewRequest(http.MethodGet, "/", nil) }
req.Header.Set("X-Forwarded-For", "10.0.0.1, 192.168.1.1") // Different spoofed XFF, same peer — must still be limited.
req.RemoteAddr = "192.168.1.1:1234" if code := send("10.0.0.2"); code != http.StatusTooManyRequests {
mw.ServeHTTP(w, req) t.Fatalf("spoofed XFF bypassed limit: got %d, want %d", code, http.StatusTooManyRequests)
}
})
if w.Code != http.StatusTooManyRequests { t.Run("honors X-Forwarded-For behind trusted proxy", func(t *testing.T) {
t.Fatalf("got status %d, want %d", w.Code, http.StatusTooManyRequests) mw := server.RateLimit(
server.WithRate(1),
server.WithBurst(1),
server.WithTrustedProxies("192.168.0.0/16"),
)(okHandler)
send := func(xff string) int {
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("X-Forwarded-For", xff)
req.RemoteAddr = "192.168.1.1:1234" // trusted proxy
mw.ServeHTTP(w, req)
return w.Code
}
// Real client 10.0.0.1 (left-most), proxy hop 192.168.1.1 (right-most).
if code := send("10.0.0.1, 192.168.1.1"); code != http.StatusOK {
t.Fatalf("first request: got %d, want %d", code, http.StatusOK)
}
if code := send("10.0.0.1, 192.168.1.1"); code != http.StatusTooManyRequests {
t.Fatalf("same client not limited: got %d, want %d", code, http.StatusTooManyRequests)
}
// A different real client through the same proxy is independent.
if code := send("10.0.0.2, 192.168.1.1"); code != http.StatusOK {
t.Fatalf("different client should be allowed: got %d, want %d", code, http.StatusOK)
} }
}) })

View File

@@ -9,9 +9,14 @@ import (
"git.codelab.vc/pkg/httpx/internal/requestid" "git.codelab.vc/pkg/httpx/internal/requestid"
) )
// maxRequestIDLen bounds the length of a client-supplied request ID that we
// are willing to propagate.
const maxRequestIDLen = 128
// RequestID returns a middleware that assigns a unique request ID to each // RequestID returns a middleware that assigns a unique request ID to each
// request. If the incoming request already has an X-Request-Id header, that // request. If the incoming request carries a valid X-Request-Id header, that
// value is used. Otherwise a new UUID v4 is generated via crypto/rand. // value is reused; otherwise (or if the supplied value is empty, too long, or
// contains unsafe characters) a new UUID v4 is generated via crypto/rand.
// //
// The request ID is stored in the request context (retrieve with // The request ID is stored in the request context (retrieve with
// RequestIDFromContext) and set on the response X-Request-Id header. // RequestIDFromContext) and set on the response X-Request-Id header.
@@ -19,7 +24,7 @@ func RequestID() Middleware {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
id := r.Header.Get("X-Request-Id") id := r.Header.Get("X-Request-Id")
if id == "" { if !validRequestID(id) {
id = newUUID() id = newUUID()
} }
@@ -30,6 +35,27 @@ func RequestID() Middleware {
} }
} }
// validRequestID reports whether a client-supplied request ID is safe to
// propagate: non-empty, within a sane length, and restricted to characters
// that cannot forge log lines or split response headers.
func validRequestID(id string) bool {
if id == "" || len(id) > maxRequestIDLen {
return false
}
for i := 0; i < len(id); i++ {
c := id[i]
switch {
case c >= 'a' && c <= 'z',
c >= 'A' && c <= 'Z',
c >= '0' && c <= '9',
c == '-', c == '_', c == '.':
default:
return false
}
}
return true
}
// RequestIDFromContext returns the request ID from the context, or an empty // RequestIDFromContext returns the request ID from the context, or an empty
// string if none is set. // string if none is set.
func RequestIDFromContext(ctx context.Context) string { func RequestIDFromContext(ctx context.Context) string {

View File

@@ -214,6 +214,36 @@ func TestRequestID(t *testing.T) {
t.Fatalf("expected empty, got %q", id) t.Fatalf("expected empty, got %q", id)
} }
}) })
t.Run("rejects unsafe incoming ID", func(t *testing.T) {
cases := map[string]string{
"header injection": "abc\r\nX-Injected: 1",
"contains space": "has space",
"too long": strings.Repeat("a", 200),
}
for name, badID := range cases {
t.Run(name, func(t *testing.T) {
var gotID string
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotID = server.RequestIDFromContext(r.Context())
w.WriteHeader(http.StatusOK)
})
mw := server.RequestID()(handler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("X-Request-Id", badID)
mw.ServeHTTP(w, req)
if gotID == badID {
t.Fatalf("unsafe incoming ID was propagated verbatim: %q", gotID)
}
if len(gotID) != 36 {
t.Fatalf("expected a freshly generated UUID, got %q (len %d)", gotID, len(gotID))
}
})
}
})
} }
func TestRequestID_UUIDFormat(t *testing.T) { func TestRequestID_UUIDFormat(t *testing.T) {