Compare commits
24 Commits
4d47918a66
...
v0.2.0
| Author | SHA1 | Date | |
|---|---|---|---|
| f609b12c2f | |||
| b5259af73e | |||
| 01478be0dc | |||
| b07d487e63 | |||
| 43d3ecfba1 | |||
| e8c4577c6f | |||
| 2d4a06e715 | |||
| b6350185d9 | |||
| 85cdc5e2c9 | |||
| 25beb2f5c2 | |||
| 16ff427c93 | |||
| 138d4b6c6d | |||
| 3aa7536328 | |||
| 89cfc38f0e | |||
| 8a63f142a7 | |||
| 21274c178a | |||
| 49be6f8a7e | |||
| 3395f70abd | |||
| 7a2cef00c3 | |||
| de5bf9a6d9 | |||
| 7f12b0c87a | |||
| 1b322c8c81 | |||
| b40a373675 | |||
| f6384ecbea |
48
.cursorrules
Normal file
48
.cursorrules
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
You are working on `git.codelab.vc/pkg/httpx`, a Go 1.24 HTTP client/server library with zero external dependencies.
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
- Client middleware: `func(http.RoundTripper) http.RoundTripper` — compose with `middleware.Chain`
|
||||||
|
- Server middleware: `func(http.Handler) http.Handler` — compose with `server.Chain`
|
||||||
|
- All configuration uses functional options pattern (`WithXxx` functions)
|
||||||
|
- Chain order for client: Logging → User MW → Retry → Circuit Breaker → Balancer → Transport
|
||||||
|
|
||||||
|
## Package structure
|
||||||
|
|
||||||
|
- `httpx` (root) — Client, request builders (NewJSONRequest, NewFormRequest), error types
|
||||||
|
- `middleware/` — client-side middleware (Logging, Recovery, Auth, Headers, RequestID)
|
||||||
|
- `retry/` — retry middleware with exponential backoff and Retry-After support
|
||||||
|
- `circuitbreaker/` — per-host circuit breaker (sync.Map of host → Breaker)
|
||||||
|
- `balancer/` — load balancing with health checking (RoundRobin, Weighted, Failover)
|
||||||
|
- `server/` — Server, Router, server middleware (RequestID, Recovery, Logging, CORS, RateLimit, MaxBodySize, Timeout), response helpers (WriteJSON, WriteError)
|
||||||
|
- `internal/requestid/` — shared context key (avoids circular import between server and middleware)
|
||||||
|
- `internal/clock/` — deterministic time for tests
|
||||||
|
|
||||||
|
## Code conventions
|
||||||
|
|
||||||
|
- Zero external dependencies — stdlib only, do not add imports outside the module
|
||||||
|
- Functional options: `type Option func(*options)` with `With<Name>` constructors
|
||||||
|
- Test with stdlib only: `testing`, `httptest`, `net/http`. No testify/gomock
|
||||||
|
- Client test helper: `mockTransport(fn)` wrapping `middleware.RoundTripperFunc`
|
||||||
|
- Server test helper: `httptest.NewRecorder`, `httptest.NewRequest`, `waitForAddr(t, srv)`
|
||||||
|
- Thread safety with `sync.Mutex`, `sync.Map`, or `atomic`
|
||||||
|
- Use `internal/clock` for time-dependent tests, not `time.Now()` directly
|
||||||
|
- Sentinel errors in sub-packages, re-exported as aliases in root package
|
||||||
|
|
||||||
|
## When writing new code
|
||||||
|
|
||||||
|
- Client middleware → file in `middleware/`, return `middleware.Middleware`
|
||||||
|
- Server middleware → file in `server/middleware_<name>.go`, return `server.Middleware`
|
||||||
|
- New option → add field to options struct, create `With<Name>` func, apply in constructor
|
||||||
|
- Do NOT import `server` from `middleware` or vice versa (use `internal/requestid` for shared context)
|
||||||
|
- Client.Close() must be called when using WithEndpoints() (stops health checker goroutine)
|
||||||
|
- Request bodies must have GetBody set for retry — use NewJSONRequest/NewFormRequest
|
||||||
|
|
||||||
|
## Commands
|
||||||
|
|
||||||
|
```bash
|
||||||
|
go build ./... # compile
|
||||||
|
go test ./... # test
|
||||||
|
go test -race ./... # test with race detector
|
||||||
|
go vet ./... # static analysis
|
||||||
|
```
|
||||||
39
.gitea/workflows/publish.yml
Normal file
39
.gitea/workflows/publish.yml
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
name: Publish
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
tags: ["v*"]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
publish:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version: "1.24"
|
||||||
|
|
||||||
|
- name: Vet
|
||||||
|
run: go vet ./...
|
||||||
|
|
||||||
|
- name: Test
|
||||||
|
run: go test -race -count=1 ./...
|
||||||
|
|
||||||
|
- name: Publish to Gitea Package Registry
|
||||||
|
run: |
|
||||||
|
VERSION=${GITHUB_REF#refs/tags/}
|
||||||
|
MODULE=$(go list -m)
|
||||||
|
|
||||||
|
# Create module zip with required prefix: module@version/
|
||||||
|
git archive --format=zip --prefix="${MODULE}@${VERSION}/" HEAD -o module.zip
|
||||||
|
|
||||||
|
# Gitea Go Package Registry API
|
||||||
|
curl -s -f \
|
||||||
|
-X PUT \
|
||||||
|
-H "Authorization: token ${{ secrets.PUBLISH_TOKEN }}" \
|
||||||
|
-H "Content-Type: application/zip" \
|
||||||
|
--data-binary @module.zip \
|
||||||
|
"${{ github.server_url }}/api/packages/pkg/go/upload?module=${MODULE}&version=${VERSION}"
|
||||||
|
env:
|
||||||
|
PUBLISH_TOKEN: ${{ secrets.PUBLISH_TOKEN }}
|
||||||
50
.github/copilot-instructions.md
vendored
Normal file
50
.github/copilot-instructions.md
vendored
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
# Copilot instructions — httpx
|
||||||
|
|
||||||
|
## Project
|
||||||
|
|
||||||
|
`git.codelab.vc/pkg/httpx` is a Go 1.24 HTTP client and server library with zero external dependencies.
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
### Client (`httpx` root package)
|
||||||
|
- Middleware type: `func(http.RoundTripper) http.RoundTripper`
|
||||||
|
- Chain assembly (outermost → innermost): Logging → User MW → Retry → Circuit Breaker → Balancer → Transport
|
||||||
|
- Retry wraps CB+Balancer so each attempt can hit a different endpoint
|
||||||
|
- Circuit breaker is per-host (sync.Map of host → Breaker)
|
||||||
|
- Client.Close() required when using WithEndpoints() — stops health checker goroutine
|
||||||
|
|
||||||
|
### Server (`server/` package)
|
||||||
|
- Middleware type: `func(http.Handler) http.Handler`
|
||||||
|
- Router wraps http.ServeMux with groups, prefix routing, Mount for sub-handlers
|
||||||
|
- Defaults() preset: RequestID → Recovery → Logging + production timeouts
|
||||||
|
- Available middleware: RequestID, Recovery, Logging, CORS, RateLimit, MaxBodySize, Timeout
|
||||||
|
- WriteJSON/WriteError for JSON responses
|
||||||
|
|
||||||
|
### Sub-packages
|
||||||
|
- `middleware/` — client-side middleware (Logging, Recovery, Auth, Headers, RequestID)
|
||||||
|
- `retry/` — retry with exponential backoff and Retry-After
|
||||||
|
- `circuitbreaker/` — per-host circuit breaker
|
||||||
|
- `balancer/` — load balancing with health checking
|
||||||
|
- `internal/requestid/` — shared context key between server and middleware
|
||||||
|
- `internal/clock/` — deterministic time for tests
|
||||||
|
|
||||||
|
## Conventions
|
||||||
|
|
||||||
|
- All configuration uses functional options (`WithXxx` functions)
|
||||||
|
- Zero external dependencies — do not add requires to go.mod
|
||||||
|
- Tests use stdlib only (testing, httptest) — no testify or gomock
|
||||||
|
- Thread safety with sync.Mutex, sync.Map, or atomic
|
||||||
|
- Client test mock: `mockTransport(fn)` using `middleware.RoundTripperFunc`
|
||||||
|
- Server test helpers: `httptest.NewRecorder`, `httptest.NewRequest`
|
||||||
|
- Do NOT import server from middleware or vice versa — use internal/requestid for shared context
|
||||||
|
- Sentinel errors in sub-packages, re-exported in root package
|
||||||
|
- Use internal/clock for time-dependent tests
|
||||||
|
|
||||||
|
## Commands
|
||||||
|
|
||||||
|
```bash
|
||||||
|
go build ./... # compile
|
||||||
|
go test ./... # test
|
||||||
|
go test -race ./... # test with race detector
|
||||||
|
go vet ./... # static analysis
|
||||||
|
```
|
||||||
139
AGENTS.md
Normal file
139
AGENTS.md
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
# AGENTS.md — httpx
|
||||||
|
|
||||||
|
Universal guide for AI coding agents working with this codebase.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
`git.codelab.vc/pkg/httpx` is a Go HTTP toolkit with **zero external dependencies** (Go 1.24, stdlib only). It provides:
|
||||||
|
- A composable HTTP **client** with retry, circuit breaking, load balancing
|
||||||
|
- A production-ready HTTP **server** with routing, middleware, graceful shutdown
|
||||||
|
|
||||||
|
## Package map
|
||||||
|
|
||||||
|
```
|
||||||
|
httpx/ Root — Client, request builders, error types
|
||||||
|
├── middleware/ Client-side middleware (RoundTripper wrappers)
|
||||||
|
├── retry/ Retry middleware with backoff
|
||||||
|
├── circuitbreaker/ Per-host circuit breaker
|
||||||
|
├── balancer/ Client-side load balancing + health checking
|
||||||
|
├── server/ Server, Router, server-side middleware, response helpers
|
||||||
|
└── internal/
|
||||||
|
├── requestid/ Shared context key (avoids circular imports)
|
||||||
|
└── clock/ Deterministic time for testing
|
||||||
|
```
|
||||||
|
|
||||||
|
## Middleware chain architecture
|
||||||
|
|
||||||
|
### Client middleware: `func(http.RoundTripper) http.RoundTripper`
|
||||||
|
|
||||||
|
```
|
||||||
|
Request flow (outermost → innermost):
|
||||||
|
|
||||||
|
Logging
|
||||||
|
└→ User Middlewares
|
||||||
|
└→ Retry
|
||||||
|
└→ Circuit Breaker
|
||||||
|
└→ Balancer
|
||||||
|
└→ Base Transport (http.DefaultTransport)
|
||||||
|
```
|
||||||
|
|
||||||
|
Retry wraps CB+Balancer so each attempt can hit a different endpoint.
|
||||||
|
|
||||||
|
### Server middleware: `func(http.Handler) http.Handler`
|
||||||
|
|
||||||
|
```
|
||||||
|
Chain(A, B, C)(handler) == A(B(C(handler)))
|
||||||
|
A is outermost (sees request first, response last)
|
||||||
|
```
|
||||||
|
|
||||||
|
Defaults() preset: `RequestID → Recovery → Logging`
|
||||||
|
|
||||||
|
## Common tasks
|
||||||
|
|
||||||
|
### Add a client middleware
|
||||||
|
|
||||||
|
1. Create file in `middleware/` (or inline)
|
||||||
|
2. Return `middleware.Middleware` (`func(http.RoundTripper) http.RoundTripper`)
|
||||||
|
3. Use `middleware.RoundTripperFunc` for the inner adapter
|
||||||
|
4. Test with `middleware.RoundTripperFunc` as mock transport
|
||||||
|
|
||||||
|
```go
|
||||||
|
func MyMiddleware() middleware.Middleware {
|
||||||
|
return func(next http.RoundTripper) http.RoundTripper {
|
||||||
|
return middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||||
|
// before
|
||||||
|
resp, err := next.RoundTrip(req)
|
||||||
|
// after
|
||||||
|
return resp, err
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Add a server middleware
|
||||||
|
|
||||||
|
1. Create file in `server/` named `middleware_<name>.go`
|
||||||
|
2. Return `server.Middleware` (`func(http.Handler) http.Handler`)
|
||||||
|
3. Use `server.statusWriter` if you need to capture the response status
|
||||||
|
4. Test with `httptest.NewRecorder` + `httptest.NewRequest`
|
||||||
|
|
||||||
|
```go
|
||||||
|
func MyMiddleware() Middleware {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// before
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
// after
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Add a route
|
||||||
|
|
||||||
|
```go
|
||||||
|
r := server.NewRouter()
|
||||||
|
r.HandleFunc("GET /users/{id}", getUser)
|
||||||
|
r.HandleFunc("POST /users", createUser)
|
||||||
|
|
||||||
|
// Group with prefix + middleware
|
||||||
|
api := r.Group("/api/v1", authMiddleware)
|
||||||
|
api.HandleFunc("GET /items", listItems)
|
||||||
|
|
||||||
|
// Mount sub-handler
|
||||||
|
r.Mount("/health", server.HealthHandler())
|
||||||
|
```
|
||||||
|
|
||||||
|
### Add a functional option
|
||||||
|
|
||||||
|
1. Add field to the options struct (`clientOptions` or `serverOptions`)
|
||||||
|
2. Create `With<Name>` function returning `Option`
|
||||||
|
3. Apply the field in the constructor (`New`)
|
||||||
|
|
||||||
|
## Gotchas
|
||||||
|
|
||||||
|
- **Middleware order matters**: Retry wraps CB+Balancer intentionally — each retry attempt can hit a different endpoint and a different circuit breaker
|
||||||
|
- **Circular imports via `internal/`**: Both `server` and `middleware` packages need request ID context. The shared key lives in `internal/requestid` — do NOT import `server` from `middleware` or vice versa
|
||||||
|
- **Client.Close() is required** when using `WithEndpoints()` — the balancer starts a background health checker goroutine that must be stopped
|
||||||
|
- **GetBody for retries**: Request bodies must be replayable. Use `NewJSONRequest`/`NewFormRequest` (they set `GetBody`) or set it manually
|
||||||
|
- **statusWriter.Unwrap()**: Server middleware must not type-assert `http.ResponseWriter` directly — use `http.ResponseController` which calls `Unwrap()` to find `http.Flusher`, `http.Hijacker`, etc.
|
||||||
|
- **No external deps**: This is a zero-dependency library. Do not add any `require` to `go.mod`
|
||||||
|
|
||||||
|
## Commands
|
||||||
|
|
||||||
|
```bash
|
||||||
|
go build ./... # compile
|
||||||
|
go test ./... # all tests
|
||||||
|
go test -race ./... # tests with race detector
|
||||||
|
go test -v -run TestName ./package/ # single test
|
||||||
|
go vet ./... # static analysis
|
||||||
|
```
|
||||||
|
|
||||||
|
## Conventions
|
||||||
|
|
||||||
|
- **Functional options** for all configuration (client and server)
|
||||||
|
- **stdlib only** testing — no testify, no gomock
|
||||||
|
- **Thread safety** — use `sync.Mutex`, `sync.Map`, or `atomic` where needed
|
||||||
|
- **`internal/clock`** — use for deterministic time in tests (never `time.Now()` directly in testable code)
|
||||||
|
- **Test helpers**: `mockTransport(fn)` wrapping `middleware.RoundTripperFunc` (client), `httptest.NewRecorder`/`httptest.NewRequest` (server), `waitForAddr(t, srv)` for server integration tests
|
||||||
|
- **Sentinel errors** live in sub-packages, root package re-exports as aliases
|
||||||
16
CLAUDE.md
16
CLAUDE.md
@@ -22,15 +22,25 @@ go vet ./... # static analysis
|
|||||||
- **Sentinel errors**: canonical values live in sub-packages, root package re-exports as aliases
|
- **Sentinel errors**: canonical values live in sub-packages, root package re-exports as aliases
|
||||||
- **balancer.Transport** returns `(Middleware, *Closer)` — Closer must be tracked for health checker shutdown
|
- **balancer.Transport** returns `(Middleware, *Closer)` — Closer must be tracked for health checker shutdown
|
||||||
- **Client.Close()** stops the health checker goroutine
|
- **Client.Close()** stops the health checker goroutine
|
||||||
|
- **Client.Patch()** — PATCH method, same pattern as Put/Post
|
||||||
|
- **NewFormRequest** — form-encoded request builder (`application/x-www-form-urlencoded`) with `GetBody` for retry
|
||||||
|
- **WithMaxResponseBody** — caps `resp.Body` reads; returns `ErrResponseTooLarge` (not silent truncation) when exceeded
|
||||||
|
- **middleware.RequestID()** — propagates request ID from context to outgoing `X-Request-Id` header
|
||||||
|
- **`internal/requestid`** — shared context key used by both `server` and `middleware` packages to avoid circular imports
|
||||||
|
|
||||||
### Server (`server/`)
|
### Server (`server/`)
|
||||||
- **Core pattern**: middleware is `func(http.Handler) http.Handler`
|
- **Core pattern**: middleware is `func(http.Handler) http.Handler`
|
||||||
- **Server** wraps `http.Server` with `net.Listener`, graceful shutdown via signal handling, lifecycle hooks
|
- **Server** wraps `http.Server` with `net.Listener`, graceful shutdown via signal handling, lifecycle hooks
|
||||||
- **Router** wraps `http.ServeMux` — supports groups with prefix + middleware inheritance, `Mount` for sub-handlers
|
- **Router** wraps `http.ServeMux` — supports groups with prefix + middleware inheritance, `Mount` for sub-handlers, `WithNotFoundHandler` for custom 404
|
||||||
- **Middleware chain** via `Chain(A, B, C)` — A outermost, C innermost (same as client side)
|
- **Middleware chain** via `Chain(A, B, C)` — A outermost, C innermost (same as client side)
|
||||||
- **statusWriter** wraps `http.ResponseWriter` to capture status; implements `Unwrap()` for `http.ResponseController`
|
- **statusWriter** wraps `http.ResponseWriter` to capture status; implements `Unwrap()` for `http.ResponseController`
|
||||||
- **Defaults()** preset: RequestID → Recovery → Logging + production timeouts
|
- **Defaults()** preset: RequestID → Recovery → Logging + production timeouts
|
||||||
- **HealthHandler** exposes `GET /healthz` (liveness) and `GET /readyz` (readiness with pluggable checkers)
|
- **HealthHandler** exposes `GET /healthz` (liveness) and `GET /readyz` (readiness with pluggable checkers)
|
||||||
|
- **CORS** middleware — preflight OPTIONS handling, `AllowOrigins`, `AllowMethods`, `AllowHeaders`, `ExposeHeaders`, `AllowCredentials`, `MaxAge`
|
||||||
|
- **RateLimit** middleware — per-key token bucket (`sync.Map`), keys on `RemoteAddr` by default; `X-Forwarded-For` is honored only via `WithTrustedProxies`; `WithRate`/`WithBurst`/`WithKeyFunc`/`WithMaxKeys`, uses `internal/clock`, idle buckets evicted to bound memory
|
||||||
|
- **MaxBodySize** middleware — wraps `r.Body` via `http.MaxBytesReader`
|
||||||
|
- **Timeout** middleware — wraps `http.TimeoutHandler`, returns 503
|
||||||
|
- **WriteJSON** / **WriteError** — JSON response helpers in `server/respond.go`
|
||||||
|
|
||||||
## Conventions
|
## Conventions
|
||||||
|
|
||||||
@@ -40,3 +50,7 @@ go vet ./... # static analysis
|
|||||||
- No external test frameworks — stdlib only
|
- No external test frameworks — stdlib only
|
||||||
- Thread safety required (`sync.Mutex`/`atomic`)
|
- Thread safety required (`sync.Mutex`/`atomic`)
|
||||||
- `internal/clock` for deterministic time testing
|
- `internal/clock` for deterministic time testing
|
||||||
|
|
||||||
|
## See also
|
||||||
|
|
||||||
|
- `AGENTS.md` — universal AI agent guide with common tasks, gotchas, and ASCII diagrams
|
||||||
|
|||||||
94
README.md
94
README.md
@@ -1,6 +1,6 @@
|
|||||||
# httpx
|
# httpx
|
||||||
|
|
||||||
HTTP client and server toolkit for Go microservices. Client side: retry, load balancing, circuit breaking — all as `http.RoundTripper` middleware. Server side: routing, middleware (request ID, recovery, logging), health checks, graceful shutdown. stdlib only, zero external deps.
|
HTTP client and server toolkit for Go microservices. Client side: retry, load balancing, circuit breaking, request ID propagation, response size limits — all as `http.RoundTripper` middleware. Server side: routing, middleware (request ID, recovery, logging, CORS, rate limiting, body limits, timeouts), health checks, JSON helpers, graceful shutdown. stdlib only, zero external deps.
|
||||||
|
|
||||||
```
|
```
|
||||||
go get git.codelab.vc/pkg/httpx
|
go get git.codelab.vc/pkg/httpx
|
||||||
@@ -29,6 +29,16 @@ if err != nil {
|
|||||||
|
|
||||||
var user User
|
var user User
|
||||||
resp.JSON(&user)
|
resp.JSON(&user)
|
||||||
|
|
||||||
|
// PATCH request
|
||||||
|
resp, err = client.Patch(ctx, "/users/123", strings.NewReader(`{"name":"updated"}`))
|
||||||
|
|
||||||
|
// Form-encoded request (OAuth, webhooks, etc.)
|
||||||
|
req, _ := httpx.NewFormRequest(ctx, http.MethodPost, "/oauth/token", url.Values{
|
||||||
|
"grant_type": {"client_credentials"},
|
||||||
|
"scope": {"read write"},
|
||||||
|
})
|
||||||
|
resp, err = client.Do(ctx, req)
|
||||||
```
|
```
|
||||||
|
|
||||||
## Packages
|
## Packages
|
||||||
@@ -42,7 +52,7 @@ Client middleware is `func(http.RoundTripper) http.RoundTripper`. Use them with
|
|||||||
| `retry` | Exponential/constant backoff, Retry-After support. Idempotent methods only by default. |
|
| `retry` | Exponential/constant backoff, Retry-After support. Idempotent methods only by default. |
|
||||||
| `balancer` | Round robin, failover, weighted random. Optional background health checks. |
|
| `balancer` | Round robin, failover, weighted random. Optional background health checks. |
|
||||||
| `circuitbreaker` | Per-host state machine (closed/open/half-open). Stops hammering dead endpoints. |
|
| `circuitbreaker` | Per-host state machine (closed/open/half-open). Stops hammering dead endpoints. |
|
||||||
| `middleware` | Logging (slog), default headers, bearer/basic auth, panic recovery. |
|
| `middleware` | Logging (slog), default headers, bearer/basic auth, panic recovery, request ID propagation. |
|
||||||
|
|
||||||
### Server
|
### Server
|
||||||
|
|
||||||
@@ -56,6 +66,12 @@ Server middleware is `func(http.Handler) http.Handler`. The `server` package pro
|
|||||||
| `server.Recovery` | Recovers panics, returns 500, logs stack trace. |
|
| `server.Recovery` | Recovers panics, returns 500, logs stack trace. |
|
||||||
| `server.Logging` | Structured request logging (method, path, status, duration, request ID). |
|
| `server.Logging` | Structured request logging (method, path, status, duration, request ID). |
|
||||||
| `server.HealthHandler` | Liveness (`/healthz`) and readiness (`/readyz`) endpoints with pluggable checkers. |
|
| `server.HealthHandler` | Liveness (`/healthz`) and readiness (`/readyz`) endpoints with pluggable checkers. |
|
||||||
|
| `server.CORS` | Cross-origin resource sharing with preflight handling and functional options. |
|
||||||
|
| `server.RateLimit` | Per-key token bucket rate limiting (keys on `RemoteAddr`; `X-Forwarded-For` via `WithTrustedProxies`) with `Retry-After`. |
|
||||||
|
| `server.MaxBodySize` | Limits request body size via `http.MaxBytesReader`. |
|
||||||
|
| `server.Timeout` | Context-based request timeout, returns 503 on expiry. |
|
||||||
|
| `server.WriteJSON` | JSON response helper, sets Content-Type and status. |
|
||||||
|
| `server.WriteError` | JSON error response (`{"error": "..."}`) helper. |
|
||||||
| `server.Defaults` | Production preset: RequestID → Recovery → Logging + sensible timeouts. |
|
| `server.Defaults` | Production preset: RequestID → Recovery → Logging + sensible timeouts. |
|
||||||
|
|
||||||
The client assembles them in this order:
|
The client assembles them in this order:
|
||||||
@@ -111,9 +127,15 @@ httpClient := &http.Client{
|
|||||||
```go
|
```go
|
||||||
logger := slog.Default()
|
logger := slog.Default()
|
||||||
|
|
||||||
r := server.NewRouter()
|
r := server.NewRouter(
|
||||||
r.HandleFunc("GET /hello", func(w http.ResponseWriter, _ *http.Request) {
|
// Custom JSON 404 instead of plain text
|
||||||
w.Write([]byte("world"))
|
server.WithNotFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
server.WriteError(w, 404, "not found")
|
||||||
|
})),
|
||||||
|
)
|
||||||
|
|
||||||
|
r.HandleFunc("GET /hello", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
server.WriteJSON(w, 200, map[string]string{"message": "world"})
|
||||||
})
|
})
|
||||||
|
|
||||||
// Groups with middleware
|
// Groups with middleware
|
||||||
@@ -125,10 +147,70 @@ r.Mount("/", server.HealthHandler(
|
|||||||
func() error { return db.Ping() },
|
func() error { return db.Ping() },
|
||||||
))
|
))
|
||||||
|
|
||||||
srv := server.New(r, server.Defaults(logger)...)
|
srv := server.New(r,
|
||||||
|
append(server.Defaults(logger),
|
||||||
|
// Protection middleware
|
||||||
|
server.WithMiddleware(
|
||||||
|
server.CORS(
|
||||||
|
server.AllowOrigins("https://app.example.com"),
|
||||||
|
server.AllowMethods("GET", "POST", "PUT", "PATCH", "DELETE"),
|
||||||
|
server.AllowHeaders("Authorization", "Content-Type"),
|
||||||
|
server.MaxAge(3600),
|
||||||
|
),
|
||||||
|
server.RateLimit(
|
||||||
|
server.WithRate(100),
|
||||||
|
server.WithBurst(200),
|
||||||
|
),
|
||||||
|
server.MaxBodySize(1<<20), // 1 MB
|
||||||
|
server.Timeout(30*time.Second),
|
||||||
|
),
|
||||||
|
)...,
|
||||||
|
)
|
||||||
log.Fatal(srv.ListenAndServe()) // graceful shutdown on SIGINT/SIGTERM
|
log.Fatal(srv.ListenAndServe()) // graceful shutdown on SIGINT/SIGTERM
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Client request ID propagation
|
||||||
|
|
||||||
|
In microservices, forward the incoming request ID to downstream calls:
|
||||||
|
|
||||||
|
```go
|
||||||
|
client := httpx.New(
|
||||||
|
httpx.WithMiddleware(middleware.RequestID()),
|
||||||
|
)
|
||||||
|
|
||||||
|
// In a server handler — the context already has the request ID from server.RequestID():
|
||||||
|
func handler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// ID is automatically forwarded as X-Request-Id
|
||||||
|
resp, err := client.Get(r.Context(), "https://downstream/api")
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Response body limit
|
||||||
|
|
||||||
|
Protect against OOM from unexpectedly large upstream responses:
|
||||||
|
|
||||||
|
```go
|
||||||
|
client := httpx.New(
|
||||||
|
httpx.WithMaxResponseBody(10 << 20), // 10 MB max
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
Reading a body that exceeds the limit returns `httpx.ErrResponseTooLarge`
|
||||||
|
(checkable with `errors.Is`) rather than silently truncating.
|
||||||
|
|
||||||
|
## Examples
|
||||||
|
|
||||||
|
See the [`examples/`](examples/) directory for runnable programs:
|
||||||
|
|
||||||
|
| Example | Description |
|
||||||
|
|---------|-------------|
|
||||||
|
| [`basic-client`](examples/basic-client/) | HTTP client with retry, timeout, logging, and response size limit |
|
||||||
|
| [`form-request`](examples/form-request/) | Form-encoded POST requests (OAuth, webhooks) |
|
||||||
|
| [`load-balancing`](examples/load-balancing/) | Multi-endpoint client with weighted balancing, circuit breaker, and health checks |
|
||||||
|
| [`server-basic`](examples/server-basic/) | Server with routing, groups, JSON helpers, health checks, and custom 404 |
|
||||||
|
| [`server-protected`](examples/server-protected/) | Production server with CORS, rate limiting, body limits, and timeouts |
|
||||||
|
| [`request-id-propagation`](examples/request-id-propagation/) | Request ID forwarding between server and client for distributed tracing |
|
||||||
|
|
||||||
## Requirements
|
## Requirements
|
||||||
|
|
||||||
Go 1.24+, stdlib only.
|
Go 1.24+, stdlib only.
|
||||||
|
|||||||
@@ -55,12 +55,19 @@ func Transport(endpoints []Endpoint, opts ...Option) (middleware.Middleware, *Cl
|
|||||||
opt(o)
|
opt(o)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Pre-parse endpoint URLs once at construction time.
|
// Pre-parse endpoint URLs once at construction time. A malformed URL is a
|
||||||
|
// configuration error: rather than panicking (which would crash the host
|
||||||
|
// application, often at startup from external config), we capture the
|
||||||
|
// error and surface it from the transport on first use.
|
||||||
parsed := make(map[string]*url.URL, len(endpoints))
|
parsed := make(map[string]*url.URL, len(endpoints))
|
||||||
|
var parseErr error
|
||||||
for _, ep := range endpoints {
|
for _, ep := range endpoints {
|
||||||
u, err := url.Parse(ep.URL)
|
u, err := url.Parse(ep.URL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(fmt.Sprintf("balancer: invalid endpoint URL %q: %v", ep.URL, err))
|
if parseErr == nil {
|
||||||
|
parseErr = fmt.Errorf("balancer: invalid endpoint URL %q: %w", ep.URL, err)
|
||||||
|
}
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
parsed[ep.URL] = u
|
parsed[ep.URL] = u
|
||||||
}
|
}
|
||||||
@@ -73,6 +80,10 @@ func Transport(endpoints []Endpoint, opts ...Option) (middleware.Middleware, *Cl
|
|||||||
|
|
||||||
return func(next http.RoundTripper) http.RoundTripper {
|
return func(next http.RoundTripper) http.RoundTripper {
|
||||||
return middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
return middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||||
|
if parseErr != nil {
|
||||||
|
return nil, parseErr
|
||||||
|
}
|
||||||
|
|
||||||
healthy := endpoints
|
healthy := endpoints
|
||||||
if o.healthChecker != nil {
|
if o.healthChecker != nil {
|
||||||
healthy = o.healthChecker.Healthy(endpoints)
|
healthy = o.healthChecker.Healthy(endpoints)
|
||||||
|
|||||||
@@ -61,6 +61,27 @@ func TestTransport_PicksEndpointAndReplacesURL(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTransport_InvalidEndpointURLReturnsError(t *testing.T) {
|
||||||
|
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||||
|
t.Fatal("base transport should not be reached for an invalid endpoint")
|
||||||
|
return nil, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// A malformed URL must not panic; the error surfaces on first use.
|
||||||
|
mw, closer := Transport([]Endpoint{{URL: "://missing-scheme"}})
|
||||||
|
defer closer.Close()
|
||||||
|
rt := mw(base)
|
||||||
|
|
||||||
|
req, err := http.NewRequest(http.MethodGet, "https://original.example.com/", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := rt.RoundTrip(req); err == nil {
|
||||||
|
t.Fatal("expected an error for invalid endpoint URL, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestTransport_ErrNoHealthyWhenNoEndpoints(t *testing.T) {
|
func TestTransport_ErrNoHealthyWhenNoEndpoints(t *testing.T) {
|
||||||
var endpoints []Endpoint
|
var endpoints []Endpoint
|
||||||
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||||
|
|||||||
34
balancer/doc.go
Normal file
34
balancer/doc.go
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
// Package balancer provides client-side load balancing as HTTP middleware.
|
||||||
|
//
|
||||||
|
// It distributes requests across multiple backend endpoints using pluggable
|
||||||
|
// strategies (round-robin, weighted, failover) with optional health checking.
|
||||||
|
//
|
||||||
|
// # Usage
|
||||||
|
//
|
||||||
|
// mw, closer := balancer.Transport(
|
||||||
|
// []balancer.Endpoint{
|
||||||
|
// {URL: "http://backend1:8080"},
|
||||||
|
// {URL: "http://backend2:8080"},
|
||||||
|
// },
|
||||||
|
// balancer.WithStrategy(balancer.RoundRobin()),
|
||||||
|
// balancer.WithHealthCheck(5 * time.Second),
|
||||||
|
// )
|
||||||
|
// defer closer.Close()
|
||||||
|
// transport := mw(http.DefaultTransport)
|
||||||
|
//
|
||||||
|
// # Strategies
|
||||||
|
//
|
||||||
|
// - RoundRobin — cycles through healthy endpoints
|
||||||
|
// - Weighted — distributes based on endpoint Weight field
|
||||||
|
// - Failover — prefers primary, falls back to secondaries
|
||||||
|
//
|
||||||
|
// # Health checking
|
||||||
|
//
|
||||||
|
// When enabled, a background goroutine periodically probes each endpoint.
|
||||||
|
// The returned Closer must be closed to stop the health checker goroutine.
|
||||||
|
// In httpx.Client, this is handled by Client.Close().
|
||||||
|
//
|
||||||
|
// # Sentinel errors
|
||||||
|
//
|
||||||
|
// ErrNoHealthy is returned when no healthy endpoints are available.
|
||||||
|
package balancer
|
||||||
@@ -10,7 +10,7 @@ import (
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
defaultHealthInterval = 10 * time.Second
|
defaultHealthInterval = 10 * time.Second
|
||||||
defaultHealthPath = "/health"
|
defaultHealthPath = "/healthz"
|
||||||
defaultHealthTimeout = 5 * time.Second
|
defaultHealthTimeout = 5 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
99
balancer/health_test.go
Normal file
99
balancer/health_test.go
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
package balancer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestHealthChecker_InitialProbeClassifiesEndpoints(t *testing.T) {
|
||||||
|
healthy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer healthy.Close()
|
||||||
|
|
||||||
|
unhealthy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusServiceUnavailable)
|
||||||
|
}))
|
||||||
|
defer unhealthy.Close()
|
||||||
|
|
||||||
|
eps := []Endpoint{{URL: healthy.URL}, {URL: unhealthy.URL}}
|
||||||
|
|
||||||
|
hc := newHealthChecker()
|
||||||
|
hc.Start(eps) // runs an initial synchronous probe
|
||||||
|
defer hc.Stop()
|
||||||
|
|
||||||
|
if !hc.IsHealthy(eps[0]) {
|
||||||
|
t.Errorf("healthy endpoint reported unhealthy")
|
||||||
|
}
|
||||||
|
if hc.IsHealthy(eps[1]) {
|
||||||
|
t.Errorf("unhealthy endpoint reported healthy")
|
||||||
|
}
|
||||||
|
|
||||||
|
got := hc.Healthy(eps)
|
||||||
|
if len(got) != 1 || got[0].URL != healthy.URL {
|
||||||
|
t.Errorf("Healthy() = %v, want only %s", got, healthy.URL)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHealthChecker_DetectsRecovery(t *testing.T) {
|
||||||
|
var up atomic.Bool
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
if up.Load() {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusServiceUnavailable)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
eps := []Endpoint{{URL: srv.URL}}
|
||||||
|
|
||||||
|
hc := newHealthChecker()
|
||||||
|
hc.Start(eps)
|
||||||
|
defer hc.Stop()
|
||||||
|
|
||||||
|
if hc.IsHealthy(eps[0]) {
|
||||||
|
t.Fatalf("endpoint should start unhealthy")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recover the backend and force a deterministic re-probe.
|
||||||
|
up.Store(true)
|
||||||
|
hc.probe(context.Background(), eps)
|
||||||
|
|
||||||
|
if !hc.IsHealthy(eps[0]) {
|
||||||
|
t.Fatalf("endpoint should be healthy after recovery")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHealthChecker_StopTerminatesLoop(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
hc := newHealthChecker(WithHealthInterval(time.Millisecond))
|
||||||
|
hc.Start([]Endpoint{{URL: srv.URL}})
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
hc.Stop()
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("Stop did not return within 2s — loop goroutine leaked")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHealthChecker_UnknownEndpointIsUnhealthy(t *testing.T) {
|
||||||
|
hc := newHealthChecker()
|
||||||
|
if hc.IsHealthy(Endpoint{URL: "http://never-probed.example"}) {
|
||||||
|
t.Error("unknown endpoint should be reported unhealthy")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -72,7 +72,7 @@ func (b *Breaker) State() State {
|
|||||||
// stateLocked returns the effective state, promoting Open → HalfOpen when the
|
// stateLocked returns the effective state, promoting Open → HalfOpen when the
|
||||||
// open duration has elapsed. Caller must hold b.mu.
|
// open duration has elapsed. Caller must hold b.mu.
|
||||||
func (b *Breaker) stateLocked() State {
|
func (b *Breaker) stateLocked() State {
|
||||||
if b.state == StateOpen && time.Since(b.openedAt) >= b.opts.openDuration {
|
if b.state == StateOpen && b.opts.clk.Since(b.openedAt) >= b.opts.openDuration {
|
||||||
b.state = StateHalfOpen
|
b.state = StateHalfOpen
|
||||||
b.halfOpenCur = 0
|
b.halfOpenCur = 0
|
||||||
}
|
}
|
||||||
@@ -142,7 +142,7 @@ func (b *Breaker) record(success bool) {
|
|||||||
// tripLocked transitions to the Open state and records the timestamp.
|
// tripLocked transitions to the Open state and records the timestamp.
|
||||||
func (b *Breaker) tripLocked() {
|
func (b *Breaker) tripLocked() {
|
||||||
b.state = StateOpen
|
b.state = StateOpen
|
||||||
b.openedAt = time.Now()
|
b.openedAt = b.opts.clk.Now()
|
||||||
b.halfOpenCur = 0
|
b.halfOpenCur = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"git.codelab.vc/pkg/httpx/internal/clock"
|
||||||
"git.codelab.vc/pkg/httpx/middleware"
|
"git.codelab.vc/pkg/httpx/middleware"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -80,9 +81,11 @@ func TestBreaker_OpenRejectsRequests(t *testing.T) {
|
|||||||
|
|
||||||
func TestBreaker_TransitionsToHalfOpenAfterDuration(t *testing.T) {
|
func TestBreaker_TransitionsToHalfOpenAfterDuration(t *testing.T) {
|
||||||
const openDuration = 50 * time.Millisecond
|
const openDuration = 50 * time.Millisecond
|
||||||
|
clk := clock.Mock(time.Now())
|
||||||
b := NewBreaker(
|
b := NewBreaker(
|
||||||
WithFailureThreshold(1),
|
WithFailureThreshold(1),
|
||||||
WithOpenDuration(openDuration),
|
WithOpenDuration(openDuration),
|
||||||
|
withClock(clk),
|
||||||
)
|
)
|
||||||
|
|
||||||
// Trip the breaker.
|
// Trip the breaker.
|
||||||
@@ -96,8 +99,8 @@ func TestBreaker_TransitionsToHalfOpenAfterDuration(t *testing.T) {
|
|||||||
t.Fatalf("state = %v, want %v", s, StateOpen)
|
t.Fatalf("state = %v, want %v", s, StateOpen)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wait for the open duration to elapse.
|
// Advance past the open duration.
|
||||||
time.Sleep(openDuration + 10*time.Millisecond)
|
clk.Advance(openDuration + time.Millisecond)
|
||||||
|
|
||||||
if s := b.State(); s != StateHalfOpen {
|
if s := b.State(); s != StateHalfOpen {
|
||||||
t.Fatalf("state = %v, want %v", s, StateHalfOpen)
|
t.Fatalf("state = %v, want %v", s, StateHalfOpen)
|
||||||
@@ -106,9 +109,11 @@ func TestBreaker_TransitionsToHalfOpenAfterDuration(t *testing.T) {
|
|||||||
|
|
||||||
func TestBreaker_HalfOpenToClosedOnSuccess(t *testing.T) {
|
func TestBreaker_HalfOpenToClosedOnSuccess(t *testing.T) {
|
||||||
const openDuration = 50 * time.Millisecond
|
const openDuration = 50 * time.Millisecond
|
||||||
|
clk := clock.Mock(time.Now())
|
||||||
b := NewBreaker(
|
b := NewBreaker(
|
||||||
WithFailureThreshold(1),
|
WithFailureThreshold(1),
|
||||||
WithOpenDuration(openDuration),
|
WithOpenDuration(openDuration),
|
||||||
|
withClock(clk),
|
||||||
)
|
)
|
||||||
|
|
||||||
// Trip the breaker.
|
// Trip the breaker.
|
||||||
@@ -118,8 +123,8 @@ func TestBreaker_HalfOpenToClosedOnSuccess(t *testing.T) {
|
|||||||
}
|
}
|
||||||
done(false)
|
done(false)
|
||||||
|
|
||||||
// Wait for half-open.
|
// Advance into half-open.
|
||||||
time.Sleep(openDuration + 10*time.Millisecond)
|
clk.Advance(openDuration + time.Millisecond)
|
||||||
|
|
||||||
// A successful request in half-open should close the breaker.
|
// A successful request in half-open should close the breaker.
|
||||||
done, err = b.Allow()
|
done, err = b.Allow()
|
||||||
@@ -135,9 +140,11 @@ func TestBreaker_HalfOpenToClosedOnSuccess(t *testing.T) {
|
|||||||
|
|
||||||
func TestBreaker_HalfOpenToOpenOnFailure(t *testing.T) {
|
func TestBreaker_HalfOpenToOpenOnFailure(t *testing.T) {
|
||||||
const openDuration = 50 * time.Millisecond
|
const openDuration = 50 * time.Millisecond
|
||||||
|
clk := clock.Mock(time.Now())
|
||||||
b := NewBreaker(
|
b := NewBreaker(
|
||||||
WithFailureThreshold(1),
|
WithFailureThreshold(1),
|
||||||
WithOpenDuration(openDuration),
|
WithOpenDuration(openDuration),
|
||||||
|
withClock(clk),
|
||||||
)
|
)
|
||||||
|
|
||||||
// Trip the breaker.
|
// Trip the breaker.
|
||||||
@@ -147,8 +154,8 @@ func TestBreaker_HalfOpenToOpenOnFailure(t *testing.T) {
|
|||||||
}
|
}
|
||||||
done(false)
|
done(false)
|
||||||
|
|
||||||
// Wait for half-open.
|
// Advance into half-open.
|
||||||
time.Sleep(openDuration + 10*time.Millisecond)
|
clk.Advance(openDuration + time.Millisecond)
|
||||||
|
|
||||||
// A failed request in half-open should re-open the breaker.
|
// A failed request in half-open should re-open the breaker.
|
||||||
done, err = b.Allow()
|
done, err = b.Allow()
|
||||||
|
|||||||
27
circuitbreaker/doc.go
Normal file
27
circuitbreaker/doc.go
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
// Package circuitbreaker provides a per-host circuit breaker as HTTP middleware.
|
||||||
|
//
|
||||||
|
// The circuit breaker monitors request failures and temporarily blocks requests
|
||||||
|
// to unhealthy hosts, allowing them time to recover before retrying.
|
||||||
|
//
|
||||||
|
// # State machine
|
||||||
|
//
|
||||||
|
// - Closed — normal operation, requests pass through
|
||||||
|
// - Open — too many failures, requests are rejected with ErrCircuitOpen
|
||||||
|
// - HalfOpen — after a cooldown period, one probe request is allowed through
|
||||||
|
//
|
||||||
|
// # Usage
|
||||||
|
//
|
||||||
|
// mw := circuitbreaker.Transport(
|
||||||
|
// circuitbreaker.WithThreshold(5),
|
||||||
|
// circuitbreaker.WithTimeout(30 * time.Second),
|
||||||
|
// )
|
||||||
|
// transport := mw(http.DefaultTransport)
|
||||||
|
//
|
||||||
|
// The circuit breaker is per-host: each unique request host gets its own
|
||||||
|
// independent breaker state machine stored in a sync.Map.
|
||||||
|
//
|
||||||
|
// # Sentinel errors
|
||||||
|
//
|
||||||
|
// ErrCircuitOpen is returned when a request is rejected because the circuit
|
||||||
|
// is in the Open state.
|
||||||
|
package circuitbreaker
|
||||||
@@ -1,11 +1,16 @@
|
|||||||
package circuitbreaker
|
package circuitbreaker
|
||||||
|
|
||||||
import "time"
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.codelab.vc/pkg/httpx/internal/clock"
|
||||||
|
)
|
||||||
|
|
||||||
type options struct {
|
type options struct {
|
||||||
failureThreshold int // consecutive failures to trip
|
failureThreshold int // consecutive failures to trip
|
||||||
openDuration time.Duration // how long to stay open before half-open
|
openDuration time.Duration // how long to stay open before half-open
|
||||||
halfOpenMax int // max concurrent requests in half-open
|
halfOpenMax int // max concurrent requests in half-open
|
||||||
|
clk clock.Clock // time source (real by default)
|
||||||
}
|
}
|
||||||
|
|
||||||
func defaults() options {
|
func defaults() options {
|
||||||
@@ -13,12 +18,23 @@ func defaults() options {
|
|||||||
failureThreshold: 5,
|
failureThreshold: 5,
|
||||||
openDuration: 30 * time.Second,
|
openDuration: 30 * time.Second,
|
||||||
halfOpenMax: 1,
|
halfOpenMax: 1,
|
||||||
|
clk: clock.System(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Option configures a Breaker.
|
// Option configures a Breaker.
|
||||||
type Option func(*options)
|
type Option func(*options)
|
||||||
|
|
||||||
|
// withClock sets the clock used for state-transition timing. Unexported; for
|
||||||
|
// deterministic tests.
|
||||||
|
func withClock(c clock.Clock) Option {
|
||||||
|
return func(o *options) {
|
||||||
|
if c != nil {
|
||||||
|
o.clk = c
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// WithFailureThreshold sets the number of consecutive failures required to
|
// WithFailureThreshold sets the number of consecutive failures required to
|
||||||
// trip the breaker from Closed to Open. Default is 5.
|
// trip the breaker from Closed to Open. Default is 5.
|
||||||
func WithFailureThreshold(n int) Option {
|
func WithFailureThreshold(n int) Option {
|
||||||
|
|||||||
27
client.go
27
client.go
@@ -20,6 +20,7 @@ type Client struct {
|
|||||||
baseURL string
|
baseURL string
|
||||||
errorMapper ErrorMapper
|
errorMapper ErrorMapper
|
||||||
balancerCloser *balancer.Closer
|
balancerCloser *balancer.Closer
|
||||||
|
maxResponseBody int64
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new Client with the given options.
|
// New creates a new Client with the given options.
|
||||||
@@ -77,9 +78,10 @@ func New(opts ...Option) *Client {
|
|||||||
Transport: rt,
|
Transport: rt,
|
||||||
Timeout: o.timeout,
|
Timeout: o.timeout,
|
||||||
},
|
},
|
||||||
baseURL: o.baseURL,
|
baseURL: o.baseURL,
|
||||||
errorMapper: o.errorMapper,
|
errorMapper: o.errorMapper,
|
||||||
balancerCloser: balancerCloser,
|
balancerCloser: balancerCloser,
|
||||||
|
maxResponseBody: o.maxResponseBody,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -99,6 +101,16 @@ func (c *Client) Do(ctx context.Context, req *http.Request) (*Response, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if c.maxResponseBody > 0 {
|
||||||
|
// Read one byte past the limit so we can distinguish "exactly at the
|
||||||
|
// limit" (allowed) from "exceeds the limit" (ErrResponseTooLarge).
|
||||||
|
resp.Body = &limitedReadCloser{
|
||||||
|
r: io.LimitReader(resp.Body, c.maxResponseBody+1),
|
||||||
|
c: resp.Body,
|
||||||
|
limit: c.maxResponseBody,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
r := newResponse(resp)
|
r := newResponse(resp)
|
||||||
|
|
||||||
if c.errorMapper != nil {
|
if c.errorMapper != nil {
|
||||||
@@ -142,6 +154,15 @@ func (c *Client) Put(ctx context.Context, url string, body io.Reader) (*Response
|
|||||||
return c.Do(ctx, req)
|
return c.Do(ctx, req)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Patch performs a PATCH request to the given URL with the given body.
|
||||||
|
func (c *Client) Patch(ctx context.Context, url string, body io.Reader) (*Response, error) {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPatch, url, body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return c.Do(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
// Delete performs a DELETE request to the given URL.
|
// Delete performs a DELETE request to the given URL.
|
||||||
func (c *Client) Delete(ctx context.Context, url string) (*Response, error) {
|
func (c *Client) Delete(ctx context.Context, url string) (*Response, error) {
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, url, nil)
|
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, url, nil)
|
||||||
|
|||||||
@@ -12,18 +12,19 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type clientOptions struct {
|
type clientOptions struct {
|
||||||
baseURL string
|
baseURL string
|
||||||
timeout time.Duration
|
timeout time.Duration
|
||||||
transport http.RoundTripper
|
transport http.RoundTripper
|
||||||
logger *slog.Logger
|
logger *slog.Logger
|
||||||
errorMapper ErrorMapper
|
errorMapper ErrorMapper
|
||||||
middlewares []middleware.Middleware
|
middlewares []middleware.Middleware
|
||||||
retryOpts []retry.Option
|
retryOpts []retry.Option
|
||||||
enableRetry bool
|
enableRetry bool
|
||||||
cbOpts []circuitbreaker.Option
|
cbOpts []circuitbreaker.Option
|
||||||
enableCB bool
|
enableCB bool
|
||||||
endpoints []balancer.Endpoint
|
endpoints []balancer.Endpoint
|
||||||
balancerOpts []balancer.Option
|
balancerOpts []balancer.Option
|
||||||
|
maxResponseBody int64
|
||||||
}
|
}
|
||||||
|
|
||||||
// Option configures a Client.
|
// Option configures a Client.
|
||||||
@@ -85,3 +86,11 @@ func WithEndpoints(eps ...balancer.Endpoint) Option {
|
|||||||
func WithBalancer(opts ...balancer.Option) Option {
|
func WithBalancer(opts ...balancer.Option) Option {
|
||||||
return func(o *clientOptions) { o.balancerOpts = opts }
|
return func(o *clientOptions) { o.balancerOpts = opts }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithMaxResponseBody limits the number of bytes read from response bodies
|
||||||
|
// by Response.Bytes (and by extension String, JSON, XML). If the response
|
||||||
|
// body exceeds n bytes, reading returns ErrResponseTooLarge instead of
|
||||||
|
// silently truncating. A value of 0 means no limit (the default).
|
||||||
|
func WithMaxResponseBody(n int64) Option {
|
||||||
|
return func(o *clientOptions) { o.maxResponseBody = n }
|
||||||
|
}
|
||||||
|
|||||||
45
client_patch_test.go
Normal file
45
client_patch_test.go
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
package httpx_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.codelab.vc/pkg/httpx"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestClient_Patch(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodPatch {
|
||||||
|
t.Errorf("expected PATCH, got %s", r.Method)
|
||||||
|
}
|
||||||
|
b, _ := io.ReadAll(r.Body)
|
||||||
|
if string(b) != `{"name":"updated"}` {
|
||||||
|
t.Errorf("expected body %q, got %q", `{"name":"updated"}`, string(b))
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
fmt.Fprint(w, "patched")
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
client := httpx.New()
|
||||||
|
resp, err := client.Patch(context.Background(), srv.URL+"/item/1", strings.NewReader(`{"name":"updated"}`))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
body, err := resp.String()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("reading body: %v", err)
|
||||||
|
}
|
||||||
|
if body != "patched" {
|
||||||
|
t.Errorf("expected body %q, got %q", "patched", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
39
doc.go
Normal file
39
doc.go
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
// Package httpx provides a high-level HTTP client with composable middleware
|
||||||
|
// for retry, circuit breaking, load balancing, structured logging, and more.
|
||||||
|
//
|
||||||
|
// The client is configured via functional options and assembled as a middleware
|
||||||
|
// chain around a standard http.RoundTripper:
|
||||||
|
//
|
||||||
|
// Logging → User Middlewares → Retry → Circuit Breaker → Balancer → Transport
|
||||||
|
//
|
||||||
|
// # Quick start
|
||||||
|
//
|
||||||
|
// client := httpx.New(
|
||||||
|
// httpx.WithBaseURL("https://api.example.com"),
|
||||||
|
// httpx.WithTimeout(10 * time.Second),
|
||||||
|
// httpx.WithRetry(),
|
||||||
|
// httpx.WithCircuitBreaker(),
|
||||||
|
// )
|
||||||
|
// defer client.Close()
|
||||||
|
//
|
||||||
|
// resp, err := client.Get(ctx, "/users/1")
|
||||||
|
//
|
||||||
|
// # Request builders
|
||||||
|
//
|
||||||
|
// NewJSONRequest and NewFormRequest create requests with appropriate
|
||||||
|
// Content-Type headers and GetBody set for retry compatibility.
|
||||||
|
//
|
||||||
|
// # Error handling
|
||||||
|
//
|
||||||
|
// Failed requests return *httpx.Error with structured fields (Op, URL,
|
||||||
|
// StatusCode). Sentinel errors ErrRetryExhausted, ErrCircuitOpen, and
|
||||||
|
// ErrNoHealthy can be checked with errors.Is.
|
||||||
|
//
|
||||||
|
// # Sub-packages
|
||||||
|
//
|
||||||
|
// - middleware — client-side middleware (logging, auth, headers, recovery, request ID)
|
||||||
|
// - retry — configurable retry with backoff and Retry-After support
|
||||||
|
// - circuitbreaker — per-host circuit breaker (closed → open → half-open)
|
||||||
|
// - balancer — client-side load balancing with health checking
|
||||||
|
// - server — production HTTP server with router, middleware, and graceful shutdown
|
||||||
|
package httpx
|
||||||
6
error.go
6
error.go
@@ -1,6 +1,7 @@
|
|||||||
package httpx
|
package httpx
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
@@ -18,6 +19,11 @@ var (
|
|||||||
ErrNoHealthy = balancer.ErrNoHealthy
|
ErrNoHealthy = balancer.ErrNoHealthy
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ErrResponseTooLarge is returned when reading a response body that exceeds
|
||||||
|
// the limit configured via WithMaxResponseBody. Any bytes read up to the
|
||||||
|
// limit are returned alongside the error.
|
||||||
|
var ErrResponseTooLarge = errors.New("httpx: response body exceeds configured limit")
|
||||||
|
|
||||||
// Error provides structured error information for failed HTTP operations.
|
// Error provides structured error information for failed HTTP operations.
|
||||||
type Error struct {
|
type Error struct {
|
||||||
// Op is the operation that failed (e.g. "Get", "Do").
|
// Op is the operation that failed (e.g. "Get", "Do").
|
||||||
|
|||||||
67
examples/basic-client/main.go
Normal file
67
examples/basic-client/main.go
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
// Basic HTTP client with retry, timeout, and structured logging.
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"log/slog"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.codelab.vc/pkg/httpx"
|
||||||
|
"git.codelab.vc/pkg/httpx/middleware"
|
||||||
|
"git.codelab.vc/pkg/httpx/retry"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
client := httpx.New(
|
||||||
|
httpx.WithBaseURL("https://httpbin.org"),
|
||||||
|
httpx.WithTimeout(10*time.Second),
|
||||||
|
httpx.WithRetry(
|
||||||
|
retry.WithMaxAttempts(3),
|
||||||
|
retry.WithBackoff(retry.ExponentialBackoff(100*time.Millisecond, 2*time.Second, true)),
|
||||||
|
),
|
||||||
|
httpx.WithMiddleware(
|
||||||
|
middleware.UserAgent("httpx-example/1.0"),
|
||||||
|
),
|
||||||
|
httpx.WithMaxResponseBody(1<<20), // 1 MB limit
|
||||||
|
httpx.WithLogger(slog.Default()),
|
||||||
|
)
|
||||||
|
defer client.Close()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// GET request.
|
||||||
|
resp, err := client.Get(ctx, "/get")
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := resp.String()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
fmt.Printf("GET /get → %d (%d bytes)\n", resp.StatusCode, len(body))
|
||||||
|
|
||||||
|
// POST with JSON body.
|
||||||
|
type payload struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := httpx.NewJSONRequest(ctx, "POST", "/post", payload{
|
||||||
|
Name: "Alice",
|
||||||
|
Email: "alice@example.com",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err = client.Do(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
fmt.Printf("POST /post → %d\n", resp.StatusCode)
|
||||||
|
}
|
||||||
41
examples/form-request/main.go
Normal file
41
examples/form-request/main.go
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
// Demonstrates form-encoded requests for OAuth token endpoints and similar APIs.
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
|
||||||
|
"git.codelab.vc/pkg/httpx"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
client := httpx.New(
|
||||||
|
httpx.WithBaseURL("https://httpbin.org"),
|
||||||
|
)
|
||||||
|
defer client.Close()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Form-encoded POST, common for OAuth token endpoints.
|
||||||
|
req, err := httpx.NewFormRequest(ctx, http.MethodPost, "/post", url.Values{
|
||||||
|
"grant_type": {"client_credentials"},
|
||||||
|
"client_id": {"my-app"},
|
||||||
|
"client_secret": {"secret"},
|
||||||
|
"scope": {"read write"},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := client.Do(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
body, _ := resp.String()
|
||||||
|
fmt.Printf("Status: %d\nBody: %s\n", resp.StatusCode, body)
|
||||||
|
}
|
||||||
54
examples/load-balancing/main.go
Normal file
54
examples/load-balancing/main.go
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
// Demonstrates load balancing across multiple backend endpoints with
|
||||||
|
// circuit breaking and health-checked failover.
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"log/slog"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.codelab.vc/pkg/httpx"
|
||||||
|
"git.codelab.vc/pkg/httpx/balancer"
|
||||||
|
"git.codelab.vc/pkg/httpx/circuitbreaker"
|
||||||
|
"git.codelab.vc/pkg/httpx/retry"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
client := httpx.New(
|
||||||
|
httpx.WithEndpoints(
|
||||||
|
balancer.Endpoint{URL: "http://localhost:8081", Weight: 3},
|
||||||
|
balancer.Endpoint{URL: "http://localhost:8082", Weight: 1},
|
||||||
|
),
|
||||||
|
httpx.WithBalancer(
|
||||||
|
balancer.WithStrategy(balancer.WeightedRandom()),
|
||||||
|
balancer.WithHealthCheck(
|
||||||
|
balancer.WithHealthInterval(5*time.Second),
|
||||||
|
balancer.WithHealthPath("/healthz"),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
httpx.WithCircuitBreaker(
|
||||||
|
circuitbreaker.WithFailureThreshold(5),
|
||||||
|
circuitbreaker.WithOpenDuration(30*time.Second),
|
||||||
|
),
|
||||||
|
httpx.WithRetry(
|
||||||
|
retry.WithMaxAttempts(3),
|
||||||
|
),
|
||||||
|
httpx.WithLogger(slog.Default()),
|
||||||
|
)
|
||||||
|
defer client.Close()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
for i := range 5 {
|
||||||
|
resp, err := client.Get(ctx, fmt.Sprintf("/api/item/%d", i))
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("request %d failed: %v", i, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
body, _ := resp.String()
|
||||||
|
fmt.Printf("request %d → %d: %s\n", i, resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
}
|
||||||
54
examples/request-id-propagation/main.go
Normal file
54
examples/request-id-propagation/main.go
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
// Demonstrates request ID propagation between server and client.
|
||||||
|
// The server assigns a request ID to incoming requests, and the client
|
||||||
|
// middleware forwards it to downstream services via X-Request-Id header.
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"log/slog"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"git.codelab.vc/pkg/httpx"
|
||||||
|
"git.codelab.vc/pkg/httpx/middleware"
|
||||||
|
"git.codelab.vc/pkg/httpx/server"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
logger := slog.Default()
|
||||||
|
|
||||||
|
// Client that propagates request IDs from context.
|
||||||
|
client := httpx.New(
|
||||||
|
httpx.WithBaseURL("http://localhost:9090"),
|
||||||
|
httpx.WithMiddleware(
|
||||||
|
middleware.RequestID(), // Picks up ID from context, sets X-Request-Id.
|
||||||
|
),
|
||||||
|
)
|
||||||
|
defer client.Close()
|
||||||
|
|
||||||
|
r := server.NewRouter()
|
||||||
|
|
||||||
|
r.HandleFunc("GET /proxy", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// The request ID is in r.Context() thanks to server.RequestID().
|
||||||
|
id := server.RequestIDFromContext(r.Context())
|
||||||
|
logger.Info("handling request", "request_id", id)
|
||||||
|
|
||||||
|
// Client automatically forwards the request ID to downstream.
|
||||||
|
resp, err := client.Get(r.Context(), "/downstream")
|
||||||
|
if err != nil {
|
||||||
|
server.WriteError(w, http.StatusBadGateway, fmt.Sprintf("downstream error: %v", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
body, _ := resp.String()
|
||||||
|
server.WriteJSON(w, http.StatusOK, map[string]string{
|
||||||
|
"request_id": id,
|
||||||
|
"downstream_response": body,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
// Server with RequestID middleware that assigns IDs to incoming requests.
|
||||||
|
srv := server.New(r, server.Defaults(logger)...)
|
||||||
|
log.Fatal(srv.ListenAndServe())
|
||||||
|
}
|
||||||
54
examples/server-basic/main.go
Normal file
54
examples/server-basic/main.go
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
// Basic HTTP server with routing, middleware, health checks, and graceful shutdown.
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"log/slog"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"git.codelab.vc/pkg/httpx/server"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
logger := slog.Default()
|
||||||
|
|
||||||
|
r := server.NewRouter(
|
||||||
|
server.WithNotFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
server.WriteError(w, http.StatusNotFound, "resource not found")
|
||||||
|
})),
|
||||||
|
)
|
||||||
|
|
||||||
|
// Public endpoints.
|
||||||
|
r.HandleFunc("GET /hello", func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
server.WriteJSON(w, http.StatusOK, map[string]string{
|
||||||
|
"message": "Hello, World!",
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
// API group with shared prefix.
|
||||||
|
api := r.Group("/api/v1")
|
||||||
|
api.HandleFunc("GET /users/{id}", getUser)
|
||||||
|
api.HandleFunc("POST /users", createUser)
|
||||||
|
|
||||||
|
// Health checks.
|
||||||
|
r.Mount("/", server.HealthHandler())
|
||||||
|
|
||||||
|
// Server with production defaults (RequestID → Recovery → Logging).
|
||||||
|
srv := server.New(r, server.Defaults(logger)...)
|
||||||
|
log.Fatal(srv.ListenAndServe())
|
||||||
|
}
|
||||||
|
|
||||||
|
func getUser(w http.ResponseWriter, r *http.Request) {
|
||||||
|
id := r.PathValue("id")
|
||||||
|
server.WriteJSON(w, http.StatusOK, map[string]string{
|
||||||
|
"id": id,
|
||||||
|
"name": "Alice",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func createUser(w http.ResponseWriter, r *http.Request) {
|
||||||
|
server.WriteJSON(w, http.StatusCreated, map[string]any{
|
||||||
|
"id": 1,
|
||||||
|
"message": "user created",
|
||||||
|
})
|
||||||
|
}
|
||||||
61
examples/server-protected/main.go
Normal file
61
examples/server-protected/main.go
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
// Production server with CORS, rate limiting, body size limits, and timeouts.
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"log/slog"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.codelab.vc/pkg/httpx/server"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
logger := slog.Default()
|
||||||
|
|
||||||
|
r := server.NewRouter(
|
||||||
|
server.WithNotFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
server.WriteError(w, http.StatusNotFound, "not found")
|
||||||
|
})),
|
||||||
|
)
|
||||||
|
|
||||||
|
r.HandleFunc("GET /api/data", func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
server.WriteJSON(w, http.StatusOK, map[string]string{"status": "ok"})
|
||||||
|
})
|
||||||
|
|
||||||
|
r.HandleFunc("POST /api/upload", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Body is already limited by MaxBodySize middleware.
|
||||||
|
server.WriteJSON(w, http.StatusAccepted, map[string]string{"status": "received"})
|
||||||
|
})
|
||||||
|
|
||||||
|
r.Mount("/", server.HealthHandler())
|
||||||
|
|
||||||
|
srv := server.New(r,
|
||||||
|
append(
|
||||||
|
server.Defaults(logger),
|
||||||
|
server.WithMiddleware(
|
||||||
|
// CORS for browser-facing APIs.
|
||||||
|
server.CORS(
|
||||||
|
server.AllowOrigins("https://app.example.com", "https://admin.example.com"),
|
||||||
|
server.AllowMethods("GET", "POST", "PUT", "PATCH", "DELETE"),
|
||||||
|
server.AllowHeaders("Authorization", "Content-Type"),
|
||||||
|
server.ExposeHeaders("X-Request-Id"),
|
||||||
|
server.AllowCredentials(true),
|
||||||
|
server.MaxAge(3600),
|
||||||
|
),
|
||||||
|
// Rate limit: 100 req/s per IP, burst of 200.
|
||||||
|
server.RateLimit(
|
||||||
|
server.WithRate(100),
|
||||||
|
server.WithBurst(200),
|
||||||
|
),
|
||||||
|
// Limit request body to 1 MB.
|
||||||
|
server.MaxBodySize(1<<20),
|
||||||
|
// Per-request timeout of 30 seconds.
|
||||||
|
server.Timeout(30*time.Second),
|
||||||
|
),
|
||||||
|
server.WithAddr(":8080"),
|
||||||
|
)...,
|
||||||
|
)
|
||||||
|
|
||||||
|
log.Fatal(srv.ListenAndServe())
|
||||||
|
}
|
||||||
19
internal/requestid/requestid.go
Normal file
19
internal/requestid/requestid.go
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
// Package requestid provides a shared context key for request IDs,
|
||||||
|
// allowing both client and server packages to access request IDs
|
||||||
|
// without circular imports.
|
||||||
|
package requestid
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
type key struct{}
|
||||||
|
|
||||||
|
// NewContext returns a context with the given request ID.
|
||||||
|
func NewContext(ctx context.Context, id string) context.Context {
|
||||||
|
return context.WithValue(ctx, key{}, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// FromContext returns the request ID from ctx, or empty string if not set.
|
||||||
|
func FromContext(ctx context.Context) string {
|
||||||
|
id, _ := ctx.Value(key{}).(string)
|
||||||
|
return id
|
||||||
|
}
|
||||||
@@ -15,6 +15,9 @@ func BearerAuth(tokenFunc func(ctx context.Context) (string, error)) Middleware
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
// RoundTrippers must not mutate the caller's request; clone before
|
||||||
|
// setting headers (req.Clone is shallow + a header copy).
|
||||||
|
req = req.Clone(req.Context())
|
||||||
req.Header.Set("Authorization", "Bearer "+token)
|
req.Header.Set("Authorization", "Bearer "+token)
|
||||||
return next.RoundTrip(req)
|
return next.RoundTrip(req)
|
||||||
})
|
})
|
||||||
@@ -26,6 +29,7 @@ func BearerAuth(tokenFunc func(ctx context.Context) (string, error)) Middleware
|
|||||||
func BasicAuth(username, password string) Middleware {
|
func BasicAuth(username, password string) Middleware {
|
||||||
return func(next http.RoundTripper) http.RoundTripper {
|
return func(next http.RoundTripper) http.RoundTripper {
|
||||||
return RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
return RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||||
|
req = req.Clone(req.Context())
|
||||||
req.SetBasicAuth(username, password)
|
req.SetBasicAuth(username, password)
|
||||||
return next.RoundTrip(req)
|
return next.RoundTrip(req)
|
||||||
})
|
})
|
||||||
|
|||||||
28
middleware/doc.go
Normal file
28
middleware/doc.go
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
// Package middleware provides client-side HTTP middleware for use with
|
||||||
|
// httpx.Client or any http.RoundTripper-based transport chain.
|
||||||
|
//
|
||||||
|
// Each middleware is a function of type func(http.RoundTripper) http.RoundTripper.
|
||||||
|
// Compose them with Chain:
|
||||||
|
//
|
||||||
|
// chain := middleware.Chain(
|
||||||
|
// middleware.Logging(logger),
|
||||||
|
// middleware.Recovery(),
|
||||||
|
// middleware.UserAgent("my-service/1.0"),
|
||||||
|
// )
|
||||||
|
// transport := chain(http.DefaultTransport)
|
||||||
|
//
|
||||||
|
// # Available middleware
|
||||||
|
//
|
||||||
|
// - Logging — structured request/response logging via slog
|
||||||
|
// - Recovery — panic recovery, converts panics to errors
|
||||||
|
// - DefaultHeaders — adds default headers to outgoing requests
|
||||||
|
// - UserAgent — sets User-Agent header
|
||||||
|
// - BearerAuth — dynamic Bearer token authentication
|
||||||
|
// - BasicAuth — HTTP Basic authentication
|
||||||
|
// - RequestID — propagates request ID from context to X-Request-Id header
|
||||||
|
//
|
||||||
|
// # RoundTripperFunc
|
||||||
|
//
|
||||||
|
// RoundTripperFunc adapts plain functions to http.RoundTripper, similar to
|
||||||
|
// http.HandlerFunc. Useful for testing and inline middleware.
|
||||||
|
package middleware
|
||||||
@@ -7,10 +7,17 @@ import "net/http"
|
|||||||
func DefaultHeaders(headers http.Header) Middleware {
|
func DefaultHeaders(headers http.Header) Middleware {
|
||||||
return func(next http.RoundTripper) http.RoundTripper {
|
return func(next http.RoundTripper) http.RoundTripper {
|
||||||
return RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
return RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||||
|
// Clone lazily on the first header we actually add, so that
|
||||||
|
// RoundTrippers never mutate the caller's request.
|
||||||
|
cloned := false
|
||||||
for key, values := range headers {
|
for key, values := range headers {
|
||||||
if req.Header.Get(key) != "" {
|
if req.Header.Get(key) != "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if !cloned {
|
||||||
|
req = req.Clone(req.Context())
|
||||||
|
cloned = true
|
||||||
|
}
|
||||||
for _, v := range values {
|
for _, v := range values {
|
||||||
req.Header.Add(key, v)
|
req.Header.Add(key, v)
|
||||||
}
|
}
|
||||||
|
|||||||
23
middleware/requestid.go
Normal file
23
middleware/requestid.go
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"git.codelab.vc/pkg/httpx/internal/requestid"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RequestID returns a middleware that propagates the request ID from the
|
||||||
|
// request context to the outgoing X-Request-Id header. This pairs with
|
||||||
|
// the server.RequestID middleware: the server stores the ID in the context,
|
||||||
|
// and the client middleware forwards it to downstream services.
|
||||||
|
func RequestID() Middleware {
|
||||||
|
return func(next http.RoundTripper) http.RoundTripper {
|
||||||
|
return RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||||
|
if id := requestid.FromContext(req.Context()); id != "" {
|
||||||
|
req = req.Clone(req.Context())
|
||||||
|
req.Header.Set("X-Request-Id", id)
|
||||||
|
}
|
||||||
|
return next.RoundTrip(req)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
69
middleware/requestid_test.go
Normal file
69
middleware/requestid_test.go
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
package middleware_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.codelab.vc/pkg/httpx/internal/requestid"
|
||||||
|
"git.codelab.vc/pkg/httpx/middleware"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRequestID(t *testing.T) {
|
||||||
|
t.Run("propagates ID from context", func(t *testing.T) {
|
||||||
|
var gotHeader string
|
||||||
|
base := middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||||
|
gotHeader = req.Header.Get("X-Request-Id")
|
||||||
|
return &http.Response{StatusCode: http.StatusOK, Body: http.NoBody}, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
mw := middleware.RequestID()(base)
|
||||||
|
|
||||||
|
ctx := requestid.NewContext(context.Background(), "test-id-123")
|
||||||
|
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "http://example.com", nil)
|
||||||
|
_, err := mw.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if gotHeader != "test-id-123" {
|
||||||
|
t.Fatalf("X-Request-Id = %q, want %q", gotHeader, "test-id-123")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("no ID in context skips header", func(t *testing.T) {
|
||||||
|
var gotHeader string
|
||||||
|
base := middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||||
|
gotHeader = req.Header.Get("X-Request-Id")
|
||||||
|
return &http.Response{StatusCode: http.StatusOK, Body: http.NoBody}, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
mw := middleware.RequestID()(base)
|
||||||
|
|
||||||
|
req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://example.com", nil)
|
||||||
|
_, err := mw.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if gotHeader != "" {
|
||||||
|
t.Fatalf("expected no X-Request-Id header, got %q", gotHeader)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("does not mutate original request", func(t *testing.T) {
|
||||||
|
base := middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||||
|
return &http.Response{StatusCode: http.StatusOK, Body: http.NoBody}, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
mw := middleware.RequestID()(base)
|
||||||
|
|
||||||
|
ctx := requestid.NewContext(context.Background(), "test-id")
|
||||||
|
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "http://example.com", nil)
|
||||||
|
_, _ = mw.RoundTrip(req)
|
||||||
|
|
||||||
|
if req.Header.Get("X-Request-Id") != "" {
|
||||||
|
t.Fatal("original request was mutated")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
18
request.go
18
request.go
@@ -7,6 +7,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewRequest creates an http.Request with context. It is a convenience
|
// NewRequest creates an http.Request with context. It is a convenience
|
||||||
@@ -32,3 +33,20 @@ func NewJSONRequest(ctx context.Context, method, url string, body any) (*http.Re
|
|||||||
}
|
}
|
||||||
return req, nil
|
return req, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewFormRequest creates an http.Request with a form-encoded body and
|
||||||
|
// sets Content-Type to application/x-www-form-urlencoded.
|
||||||
|
// The GetBody function is set so that the request can be retried.
|
||||||
|
func NewFormRequest(ctx context.Context, method, rawURL string, values url.Values) (*http.Request, error) {
|
||||||
|
encoded := values.Encode()
|
||||||
|
b := []byte(encoded)
|
||||||
|
req, err := http.NewRequestWithContext(ctx, method, rawURL, bytes.NewReader(b))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
req.GetBody = func() (io.ReadCloser, error) {
|
||||||
|
return io.NopCloser(bytes.NewReader(b)), nil
|
||||||
|
}
|
||||||
|
return req, nil
|
||||||
|
}
|
||||||
|
|||||||
80
request_form_test.go
Normal file
80
request_form_test.go
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
package httpx_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.codelab.vc/pkg/httpx"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewFormRequest(t *testing.T) {
|
||||||
|
t.Run("body is form-encoded", func(t *testing.T) {
|
||||||
|
values := url.Values{"username": {"alice"}, "scope": {"read"}}
|
||||||
|
req, err := httpx.NewFormRequest(context.Background(), http.MethodPost, "http://example.com/token", values)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := io.ReadAll(req.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("reading body: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
parsed, err := url.ParseQuery(string(body))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parsing form: %v", err)
|
||||||
|
}
|
||||||
|
if parsed.Get("username") != "alice" {
|
||||||
|
t.Errorf("username = %q, want %q", parsed.Get("username"), "alice")
|
||||||
|
}
|
||||||
|
if parsed.Get("scope") != "read" {
|
||||||
|
t.Errorf("scope = %q, want %q", parsed.Get("scope"), "read")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("content type is set", func(t *testing.T) {
|
||||||
|
values := url.Values{"key": {"value"}}
|
||||||
|
req, err := httpx.NewFormRequest(context.Background(), http.MethodPost, "http://example.com", values)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ct := req.Header.Get("Content-Type")
|
||||||
|
if ct != "application/x-www-form-urlencoded" {
|
||||||
|
t.Errorf("Content-Type = %q, want %q", ct, "application/x-www-form-urlencoded")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("GetBody works for retry", func(t *testing.T) {
|
||||||
|
values := url.Values{"key": {"value"}}
|
||||||
|
req, err := httpx.NewFormRequest(context.Background(), http.MethodPost, "http://example.com", values)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.GetBody == nil {
|
||||||
|
t.Fatal("GetBody is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
b1, err := io.ReadAll(req.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("reading body: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
body2, err := req.GetBody()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetBody(): %v", err)
|
||||||
|
}
|
||||||
|
b2, err := io.ReadAll(body2)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("reading body2: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(b1) != string(b2) {
|
||||||
|
t.Errorf("GetBody returned different data: %q vs %q", b1, b2)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
24
response.go
24
response.go
@@ -97,3 +97,27 @@ func (r *Response) BodyReader() io.Reader {
|
|||||||
}
|
}
|
||||||
return r.Body
|
return r.Body
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// limitedReadCloser enforces a maximum number of bytes that may be read from
|
||||||
|
// a response body. Reading more than limit bytes returns ErrResponseTooLarge
|
||||||
|
// rather than silently truncating the body. The original body is closed via
|
||||||
|
// the separate Closer.
|
||||||
|
type limitedReadCloser struct {
|
||||||
|
r io.Reader // an io.LimitReader over the original body (limit+1 bytes)
|
||||||
|
c io.Closer // the original body, for Close
|
||||||
|
limit int64
|
||||||
|
read int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *limitedReadCloser) Read(p []byte) (int, error) {
|
||||||
|
n, err := l.r.Read(p)
|
||||||
|
l.read += int64(n)
|
||||||
|
if l.read > l.limit {
|
||||||
|
return n, ErrResponseTooLarge
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *limitedReadCloser) Close() error {
|
||||||
|
return l.c.Close()
|
||||||
|
}
|
||||||
|
|||||||
94
response_limit_test.go
Normal file
94
response_limit_test.go
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
package httpx_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.codelab.vc/pkg/httpx"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestClient_MaxResponseBody(t *testing.T) {
|
||||||
|
t.Run("allows response within limit", func(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
fmt.Fprint(w, "hello")
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
client := httpx.New(httpx.WithMaxResponseBody(1024))
|
||||||
|
resp, err := client.Get(context.Background(), srv.URL+"/")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
body, err := resp.String()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("reading body: %v", err)
|
||||||
|
}
|
||||||
|
if body != "hello" {
|
||||||
|
t.Fatalf("body = %q, want %q", body, "hello")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("returns ErrResponseTooLarge when exceeding limit", func(t *testing.T) {
|
||||||
|
largeBody := strings.Repeat("x", 1000)
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
fmt.Fprint(w, largeBody)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
client := httpx.New(httpx.WithMaxResponseBody(100))
|
||||||
|
resp, err := client.Get(context.Background(), srv.URL+"/")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := resp.Bytes(); !errors.Is(err, httpx.ErrResponseTooLarge) {
|
||||||
|
t.Fatalf("err = %v, want ErrResponseTooLarge", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("allows body exactly at limit", func(t *testing.T) {
|
||||||
|
exact := strings.Repeat("x", 100)
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
fmt.Fprint(w, exact)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
client := httpx.New(httpx.WithMaxResponseBody(100))
|
||||||
|
resp, err := client.Get(context.Background(), srv.URL+"/")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
b, err := resp.Bytes()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("reading body at exact limit: %v", err)
|
||||||
|
}
|
||||||
|
if len(b) != 100 {
|
||||||
|
t.Fatalf("body length = %d, want %d", len(b), 100)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("no limit when zero", func(t *testing.T) {
|
||||||
|
largeBody := strings.Repeat("x", 10000)
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
fmt.Fprint(w, largeBody)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
client := httpx.New()
|
||||||
|
resp, err := client.Get(context.Background(), srv.URL+"/")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
b, err := resp.Bytes()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("reading body: %v", err)
|
||||||
|
}
|
||||||
|
if len(b) != 10000 {
|
||||||
|
t.Fatalf("body length = %d, want %d", len(b), 10000)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -44,8 +44,11 @@ func (b *exponentialBackoff) Delay(attempt int) time.Duration {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if b.withJitter {
|
if b.withJitter {
|
||||||
jitter := time.Duration(rand.Int64N(int64(delay / 2)))
|
// Guard against rand.Int64N panicking on a non-positive argument when
|
||||||
delay += jitter
|
// delay is small enough that delay/2 rounds to zero.
|
||||||
|
if half := int64(delay / 2); half > 0 {
|
||||||
|
delay += time.Duration(rand.Int64N(half))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if delay > b.max {
|
if delay > b.max {
|
||||||
|
|||||||
31
retry/doc.go
Normal file
31
retry/doc.go
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
// Package retry provides configurable HTTP request retry as client middleware.
|
||||||
|
//
|
||||||
|
// The retry middleware wraps an http.RoundTripper and automatically retries
|
||||||
|
// failed requests based on a configurable policy, with exponential backoff
|
||||||
|
// and optional jitter.
|
||||||
|
//
|
||||||
|
// # Usage
|
||||||
|
//
|
||||||
|
// mw := retry.Transport(
|
||||||
|
// retry.WithMaxAttempts(3),
|
||||||
|
// retry.WithBackoff(retry.ExponentialBackoff(100*time.Millisecond, 5*time.Second)),
|
||||||
|
// )
|
||||||
|
// transport := mw(http.DefaultTransport)
|
||||||
|
//
|
||||||
|
// # Retry-After
|
||||||
|
//
|
||||||
|
// The retry middleware respects the Retry-After response header. If a server
|
||||||
|
// returns 429 or 503 with Retry-After, the delay from the header overrides
|
||||||
|
// the backoff strategy.
|
||||||
|
//
|
||||||
|
// # Request bodies
|
||||||
|
//
|
||||||
|
// For requests with bodies to be retried, the request must have GetBody set.
|
||||||
|
// Use httpx.NewJSONRequest or httpx.NewFormRequest which set GetBody
|
||||||
|
// automatically.
|
||||||
|
//
|
||||||
|
// # Sentinel errors
|
||||||
|
//
|
||||||
|
// ErrRetryExhausted is returned when all attempts fail. The original error
|
||||||
|
// is wrapped and accessible via errors.Unwrap.
|
||||||
|
package retry
|
||||||
@@ -1,12 +1,17 @@
|
|||||||
package retry
|
package retry
|
||||||
|
|
||||||
import "time"
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.codelab.vc/pkg/httpx/internal/clock"
|
||||||
|
)
|
||||||
|
|
||||||
type options struct {
|
type options struct {
|
||||||
maxAttempts int // default 3
|
maxAttempts int // default 3
|
||||||
backoff Backoff // default ExponentialBackoff(100ms, 5s, true)
|
backoff Backoff // default ExponentialBackoff(100ms, 5s, true)
|
||||||
policy Policy // default: defaultPolicy (retry on 5xx and network errors)
|
policy Policy // default: defaultPolicy (retry on 5xx and network errors)
|
||||||
retryAfter bool // default true, respect Retry-After header
|
retryAfter bool // default true, respect Retry-After header
|
||||||
|
clk clock.Clock // time source for backoff delays (real by default)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Option configures the retry transport.
|
// Option configures the retry transport.
|
||||||
@@ -18,6 +23,17 @@ func defaults() options {
|
|||||||
backoff: ExponentialBackoff(100*time.Millisecond, 5*time.Second, true),
|
backoff: ExponentialBackoff(100*time.Millisecond, 5*time.Second, true),
|
||||||
policy: defaultPolicy{},
|
policy: defaultPolicy{},
|
||||||
retryAfter: true,
|
retryAfter: true,
|
||||||
|
clk: clock.System(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// withClock sets the clock used for inter-attempt delays. Unexported; for
|
||||||
|
// deterministic tests.
|
||||||
|
func withClock(c clock.Clock) Option {
|
||||||
|
return func(o *options) {
|
||||||
|
if c != nil {
|
||||||
|
o.clk = c
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -37,18 +37,16 @@ func Transport(opts ...Option) middleware.Middleware {
|
|||||||
var exhausted bool
|
var exhausted bool
|
||||||
|
|
||||||
for attempt := range cfg.maxAttempts {
|
for attempt := range cfg.maxAttempts {
|
||||||
// For retries (attempt > 0), restore the request body.
|
// For retries (attempt > 0) the body was consumed by the
|
||||||
if attempt > 0 {
|
// previous attempt; restore it via GetBody. The rewindability
|
||||||
if req.GetBody != nil {
|
// check below guarantees GetBody is set whenever we loop with a
|
||||||
body, bodyErr := req.GetBody()
|
// non-nil body, so this branch is always safe.
|
||||||
if bodyErr != nil {
|
if attempt > 0 && req.GetBody != nil {
|
||||||
return resp, bodyErr
|
body, bodyErr := req.GetBody()
|
||||||
}
|
if bodyErr != nil {
|
||||||
req.Body = body
|
return nil, bodyErr
|
||||||
} else if req.Body != nil {
|
|
||||||
// Body was consumed and cannot be re-created.
|
|
||||||
return resp, err
|
|
||||||
}
|
}
|
||||||
|
req.Body = body
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err = next.RoundTrip(req)
|
resp, err = next.RoundTrip(req)
|
||||||
@@ -64,6 +62,13 @@ func Transport(opts ...Option) middleware.Middleware {
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If the body cannot be rewound, a retry would replay with an
|
||||||
|
// empty body. Return the current result as-is instead of
|
||||||
|
// draining it and looping with a corrupted request.
|
||||||
|
if req.Body != nil && req.GetBody == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
// Compute delay: use backoff or policy delay, whichever is larger.
|
// Compute delay: use backoff or policy delay, whichever is larger.
|
||||||
delay := cfg.backoff.Delay(attempt)
|
delay := cfg.backoff.Delay(attempt)
|
||||||
if policyDelay > delay {
|
if policyDelay > delay {
|
||||||
@@ -84,12 +89,12 @@ func Transport(opts ...Option) middleware.Middleware {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Wait for the delay or context cancellation.
|
// Wait for the delay or context cancellation.
|
||||||
timer := time.NewTimer(delay)
|
timer := cfg.clk.NewTimer(delay)
|
||||||
select {
|
select {
|
||||||
case <-req.Context().Done():
|
case <-req.Context().Done():
|
||||||
timer.Stop()
|
timer.Stop()
|
||||||
return nil, req.Context().Err()
|
return nil, req.Context().Err()
|
||||||
case <-timer.C:
|
case <-timer.C():
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -118,9 +123,9 @@ func (defaultPolicy) ShouldRetry(_ int, req *http.Request, resp *http.Response,
|
|||||||
|
|
||||||
switch resp.StatusCode {
|
switch resp.StatusCode {
|
||||||
case http.StatusTooManyRequests, // 429
|
case http.StatusTooManyRequests, // 429
|
||||||
http.StatusBadGateway, // 502
|
http.StatusBadGateway, // 502
|
||||||
http.StatusServiceUnavailable, // 503
|
http.StatusServiceUnavailable, // 503
|
||||||
http.StatusGatewayTimeout: // 504
|
http.StatusGatewayTimeout: // 504
|
||||||
return true, 0
|
return true, 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"git.codelab.vc/pkg/httpx/internal/clock"
|
||||||
"git.codelab.vc/pkg/httpx/middleware"
|
"git.codelab.vc/pkg/httpx/middleware"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -229,6 +230,83 @@ func TestTransport(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestTransport_BodyNotRewindable verifies that an idempotent request whose
|
||||||
|
// body cannot be replayed (no GetBody) is returned as-is rather than retried
|
||||||
|
// with an empty body or a stale, already-drained response.
|
||||||
|
func TestTransport_BodyNotRewindable(t *testing.T) {
|
||||||
|
var calls atomic.Int32
|
||||||
|
rt := Transport(
|
||||||
|
WithMaxAttempts(3),
|
||||||
|
WithBackoff(ConstantBackoff(time.Millisecond)),
|
||||||
|
)(mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||||
|
calls.Add(1)
|
||||||
|
io.Copy(io.Discard, req.Body) // a real transport consumes the body
|
||||||
|
return statusResponse(http.StatusServiceUnavailable), nil
|
||||||
|
}))
|
||||||
|
|
||||||
|
// PUT is idempotent (the policy would retry a 503), but with GetBody unset
|
||||||
|
// the body cannot be rewound.
|
||||||
|
req, _ := http.NewRequest(http.MethodPut, "http://example.com", strings.NewReader("data"))
|
||||||
|
req.GetBody = nil
|
||||||
|
|
||||||
|
resp, err := rt.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if resp == nil || resp.StatusCode != http.StatusServiceUnavailable {
|
||||||
|
t.Fatalf("expected the original 503 response, got %v", resp)
|
||||||
|
}
|
||||||
|
if got := calls.Load(); got != 1 {
|
||||||
|
t.Fatalf("expected exactly 1 call (no rewind retry), got %d", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestTransport_InjectedClock verifies that backoff delays are driven by the
|
||||||
|
// configured clock, so retries are deterministic without real sleeps.
|
||||||
|
func TestTransport_InjectedClock(t *testing.T) {
|
||||||
|
clk := clock.Mock(time.Now())
|
||||||
|
var calls atomic.Int32
|
||||||
|
rt := Transport(
|
||||||
|
WithMaxAttempts(2),
|
||||||
|
WithBackoff(ConstantBackoff(time.Hour)), // would block forever on a real clock
|
||||||
|
withClock(clk),
|
||||||
|
)(mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||||
|
calls.Add(1)
|
||||||
|
return statusResponse(http.StatusServiceUnavailable), nil
|
||||||
|
}))
|
||||||
|
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
var resp *http.Response
|
||||||
|
var err error
|
||||||
|
go func() {
|
||||||
|
resp, err = rt.RoundTrip(req)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Drive the backoff via the mock clock. Advancing repeatedly is robust
|
||||||
|
// against the timer being created slightly after the first attempt.
|
||||||
|
for {
|
||||||
|
clk.Advance(time.Hour)
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
goto finished
|
||||||
|
case <-time.After(time.Millisecond):
|
||||||
|
}
|
||||||
|
}
|
||||||
|
finished:
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode != http.StatusServiceUnavailable {
|
||||||
|
t.Fatalf("expected 503, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
if got := calls.Load(); got != 2 {
|
||||||
|
t.Fatalf("expected 2 calls, got %d", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// policyFunc adapts a function into a Policy.
|
// policyFunc adapts a function into a Policy.
|
||||||
type policyFunc func(int, *http.Request, *http.Response, error) (bool, time.Duration)
|
type policyFunc func(int, *http.Request, *http.Response, error) (bool, time.Duration)
|
||||||
|
|
||||||
|
|||||||
38
server/doc.go
Normal file
38
server/doc.go
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
// Package server provides a production-ready HTTP server with graceful
|
||||||
|
// shutdown, middleware composition, routing, and JSON response helpers.
|
||||||
|
//
|
||||||
|
// # Server
|
||||||
|
//
|
||||||
|
// Server wraps http.Server with net.Listener, signal-based graceful shutdown
|
||||||
|
// (SIGINT/SIGTERM), and lifecycle hooks. It is configured via functional options:
|
||||||
|
//
|
||||||
|
// srv := server.New(handler,
|
||||||
|
// server.WithAddr(":8080"),
|
||||||
|
// server.Defaults(logger),
|
||||||
|
// )
|
||||||
|
// srv.ListenAndServe()
|
||||||
|
//
|
||||||
|
// # Router
|
||||||
|
//
|
||||||
|
// Router wraps http.ServeMux with middleware groups, prefix-based route groups,
|
||||||
|
// and sub-handler mounting. It supports Go 1.22+ method-based patterns:
|
||||||
|
//
|
||||||
|
// r := server.NewRouter()
|
||||||
|
// r.HandleFunc("GET /users/{id}", getUser)
|
||||||
|
//
|
||||||
|
// api := r.Group("/api/v1", authMiddleware)
|
||||||
|
// api.HandleFunc("GET /items", listItems)
|
||||||
|
//
|
||||||
|
// # Middleware
|
||||||
|
//
|
||||||
|
// Server middleware follows the func(http.Handler) http.Handler pattern.
|
||||||
|
// Available middleware: RequestID, Recovery, Logging, CORS, RateLimit,
|
||||||
|
// MaxBodySize, Timeout. Use Chain to compose them:
|
||||||
|
//
|
||||||
|
// chain := server.Chain(server.RequestID(), server.Recovery(logger), server.Logging(logger))
|
||||||
|
//
|
||||||
|
// # Response helpers
|
||||||
|
//
|
||||||
|
// WriteJSON and WriteError provide JSON response writing with proper
|
||||||
|
// Content-Type headers.
|
||||||
|
package server
|
||||||
@@ -25,8 +25,8 @@ func Chain(mws ...Middleware) Middleware {
|
|||||||
// underlying ResponseWriter's optional interfaces (Flusher, Hijacker, etc.).
|
// underlying ResponseWriter's optional interfaces (Flusher, Hijacker, etc.).
|
||||||
type statusWriter struct {
|
type statusWriter struct {
|
||||||
http.ResponseWriter
|
http.ResponseWriter
|
||||||
status int
|
status int
|
||||||
written bool
|
written bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// WriteHeader captures the status code and delegates to the underlying writer.
|
// WriteHeader captures the status code and delegates to the underlying writer.
|
||||||
|
|||||||
15
server/middleware_bodylimit.go
Normal file
15
server/middleware_bodylimit.go
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import "net/http"
|
||||||
|
|
||||||
|
// MaxBodySize returns a middleware that limits the size of incoming request
|
||||||
|
// bodies. If the body exceeds n bytes, the server returns 413 Request Entity
|
||||||
|
// Too Large. It wraps the body with http.MaxBytesReader.
|
||||||
|
func MaxBodySize(n int64) Middleware {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
r.Body = http.MaxBytesReader(w, r.Body, n)
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
61
server/middleware_bodylimit_test.go
Normal file
61
server/middleware_bodylimit_test.go
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
package server_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.codelab.vc/pkg/httpx/server"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMaxBodySize(t *testing.T) {
|
||||||
|
t.Run("allows body within limit", func(t *testing.T) {
|
||||||
|
handler := server.MaxBodySize(1024)(
|
||||||
|
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
b, err := io.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write(b)
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
|
||||||
|
body := strings.NewReader("hello")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/", body)
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
if w.Body.String() != "hello" {
|
||||||
|
t.Fatalf("got body %q, want %q", w.Body.String(), "hello")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("rejects body exceeding limit", func(t *testing.T) {
|
||||||
|
handler := server.MaxBodySize(5)(
|
||||||
|
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
_, err := io.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "body too large", http.StatusRequestEntityTooLarge)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
|
||||||
|
body := strings.NewReader("this is longer than 5 bytes")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/", body)
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusRequestEntityTooLarge {
|
||||||
|
t.Fatalf("got status %d, want %d", w.Code, http.StatusRequestEntityTooLarge)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
128
server/middleware_cors.go
Normal file
128
server/middleware_cors.go
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type corsOptions struct {
|
||||||
|
allowOrigins []string
|
||||||
|
allowMethods []string
|
||||||
|
allowHeaders []string
|
||||||
|
exposeHeaders []string
|
||||||
|
allowCredentials bool
|
||||||
|
maxAge int
|
||||||
|
}
|
||||||
|
|
||||||
|
// CORSOption configures the CORS middleware.
|
||||||
|
type CORSOption func(*corsOptions)
|
||||||
|
|
||||||
|
// AllowOrigins sets the allowed origins. Use "*" to allow any origin.
|
||||||
|
// Default is no origins (CORS disabled).
|
||||||
|
func AllowOrigins(origins ...string) CORSOption {
|
||||||
|
return func(o *corsOptions) { o.allowOrigins = origins }
|
||||||
|
}
|
||||||
|
|
||||||
|
// AllowMethods sets the allowed HTTP methods for preflight requests.
|
||||||
|
// Default is GET, POST, HEAD.
|
||||||
|
func AllowMethods(methods ...string) CORSOption {
|
||||||
|
return func(o *corsOptions) { o.allowMethods = methods }
|
||||||
|
}
|
||||||
|
|
||||||
|
// AllowHeaders sets the allowed request headers for preflight requests.
|
||||||
|
func AllowHeaders(headers ...string) CORSOption {
|
||||||
|
return func(o *corsOptions) { o.allowHeaders = headers }
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExposeHeaders sets headers that browsers are allowed to access.
|
||||||
|
func ExposeHeaders(headers ...string) CORSOption {
|
||||||
|
return func(o *corsOptions) { o.exposeHeaders = headers }
|
||||||
|
}
|
||||||
|
|
||||||
|
// AllowCredentials indicates whether the response to the request can be
|
||||||
|
// exposed when the credentials flag is true.
|
||||||
|
func AllowCredentials(allow bool) CORSOption {
|
||||||
|
return func(o *corsOptions) { o.allowCredentials = allow }
|
||||||
|
}
|
||||||
|
|
||||||
|
// MaxAge sets the maximum time (in seconds) a preflight result can be cached.
|
||||||
|
func MaxAge(seconds int) CORSOption {
|
||||||
|
return func(o *corsOptions) { o.maxAge = seconds }
|
||||||
|
}
|
||||||
|
|
||||||
|
// CORS returns a middleware that handles Cross-Origin Resource Sharing.
|
||||||
|
// It processes preflight OPTIONS requests and sets the appropriate
|
||||||
|
// Access-Control-* response headers.
|
||||||
|
func CORS(opts ...CORSOption) Middleware {
|
||||||
|
o := &corsOptions{
|
||||||
|
allowMethods: []string{"GET", "POST", "HEAD"},
|
||||||
|
}
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(o)
|
||||||
|
}
|
||||||
|
|
||||||
|
allowedOrigins := make(map[string]struct{}, len(o.allowOrigins))
|
||||||
|
allowAll := false
|
||||||
|
for _, origin := range o.allowOrigins {
|
||||||
|
if origin == "*" {
|
||||||
|
allowAll = true
|
||||||
|
}
|
||||||
|
allowedOrigins[origin] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
methods := strings.Join(o.allowMethods, ", ")
|
||||||
|
headers := strings.Join(o.allowHeaders, ", ")
|
||||||
|
expose := strings.Join(o.exposeHeaders, ", ")
|
||||||
|
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
origin := r.Header.Get("Origin")
|
||||||
|
if origin == "" {
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
allowed := allowAll
|
||||||
|
if !allowed {
|
||||||
|
_, allowed = allowedOrigins[origin]
|
||||||
|
}
|
||||||
|
if !allowed {
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the allowed origin. When credentials are enabled,
|
||||||
|
// we must echo the specific origin, not "*".
|
||||||
|
if allowAll && !o.allowCredentials {
|
||||||
|
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||||
|
} else {
|
||||||
|
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||||||
|
}
|
||||||
|
|
||||||
|
if o.allowCredentials {
|
||||||
|
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||||
|
}
|
||||||
|
if expose != "" {
|
||||||
|
w.Header().Set("Access-Control-Expose-Headers", expose)
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Add("Vary", "Origin")
|
||||||
|
|
||||||
|
// Handle preflight.
|
||||||
|
if r.Method == http.MethodOptions && r.Header.Get("Access-Control-Request-Method") != "" {
|
||||||
|
w.Header().Set("Access-Control-Allow-Methods", methods)
|
||||||
|
if headers != "" {
|
||||||
|
w.Header().Set("Access-Control-Allow-Headers", headers)
|
||||||
|
}
|
||||||
|
if o.maxAge > 0 {
|
||||||
|
w.Header().Set("Access-Control-Max-Age", strconv.Itoa(o.maxAge))
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
143
server/middleware_cors_test.go
Normal file
143
server/middleware_cors_test.go
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
package server_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.codelab.vc/pkg/httpx/server"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCORS(t *testing.T) {
|
||||||
|
okHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("no Origin header passes through", func(t *testing.T) {
|
||||||
|
mw := server.CORS(server.AllowOrigins("*"))(okHandler)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
mw.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Header().Get("Access-Control-Allow-Origin") != "" {
|
||||||
|
t.Fatal("expected no CORS headers without Origin")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("wildcard origin", func(t *testing.T) {
|
||||||
|
mw := server.CORS(server.AllowOrigins("*"))(okHandler)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.Header.Set("Origin", "http://example.com")
|
||||||
|
mw.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if got := w.Header().Get("Access-Control-Allow-Origin"); got != "*" {
|
||||||
|
t.Fatalf("Access-Control-Allow-Origin = %q, want %q", got, "*")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("specific origin allowed", func(t *testing.T) {
|
||||||
|
mw := server.CORS(server.AllowOrigins("http://example.com"))(okHandler)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.Header.Set("Origin", "http://example.com")
|
||||||
|
mw.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if got := w.Header().Get("Access-Control-Allow-Origin"); got != "http://example.com" {
|
||||||
|
t.Fatalf("Access-Control-Allow-Origin = %q, want %q", got, "http://example.com")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("disallowed origin gets no CORS headers", func(t *testing.T) {
|
||||||
|
mw := server.CORS(server.AllowOrigins("http://example.com"))(okHandler)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.Header.Set("Origin", "http://evil.com")
|
||||||
|
mw.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if got := w.Header().Get("Access-Control-Allow-Origin"); got != "" {
|
||||||
|
t.Fatalf("expected no Access-Control-Allow-Origin for disallowed origin, got %q", got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("preflight OPTIONS", func(t *testing.T) {
|
||||||
|
mw := server.CORS(
|
||||||
|
server.AllowOrigins("http://example.com"),
|
||||||
|
server.AllowMethods("GET", "POST", "PUT"),
|
||||||
|
server.AllowHeaders("Authorization", "Content-Type"),
|
||||||
|
server.MaxAge(3600),
|
||||||
|
)(okHandler)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodOptions, "/api/data", nil)
|
||||||
|
req.Header.Set("Origin", "http://example.com")
|
||||||
|
req.Header.Set("Access-Control-Request-Method", "POST")
|
||||||
|
mw.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusNoContent {
|
||||||
|
t.Fatalf("got status %d, want %d", w.Code, http.StatusNoContent)
|
||||||
|
}
|
||||||
|
if got := w.Header().Get("Access-Control-Allow-Methods"); got != "GET, POST, PUT" {
|
||||||
|
t.Fatalf("Access-Control-Allow-Methods = %q, want %q", got, "GET, POST, PUT")
|
||||||
|
}
|
||||||
|
if got := w.Header().Get("Access-Control-Allow-Headers"); got != "Authorization, Content-Type" {
|
||||||
|
t.Fatalf("Access-Control-Allow-Headers = %q, want %q", got, "Authorization, Content-Type")
|
||||||
|
}
|
||||||
|
if got := w.Header().Get("Access-Control-Max-Age"); got != "3600" {
|
||||||
|
t.Fatalf("Access-Control-Max-Age = %q, want %q", got, "3600")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("credentials with specific origin", func(t *testing.T) {
|
||||||
|
mw := server.CORS(
|
||||||
|
server.AllowOrigins("*"),
|
||||||
|
server.AllowCredentials(true),
|
||||||
|
)(okHandler)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.Header.Set("Origin", "http://example.com")
|
||||||
|
mw.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
// With credentials, must echo specific origin even with wildcard config.
|
||||||
|
if got := w.Header().Get("Access-Control-Allow-Origin"); got != "http://example.com" {
|
||||||
|
t.Fatalf("Access-Control-Allow-Origin = %q, want %q", got, "http://example.com")
|
||||||
|
}
|
||||||
|
if got := w.Header().Get("Access-Control-Allow-Credentials"); got != "true" {
|
||||||
|
t.Fatalf("Access-Control-Allow-Credentials = %q, want %q", got, "true")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("expose headers", func(t *testing.T) {
|
||||||
|
mw := server.CORS(
|
||||||
|
server.AllowOrigins("*"),
|
||||||
|
server.ExposeHeaders("X-Custom", "X-Request-Id"),
|
||||||
|
)(okHandler)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.Header.Set("Origin", "http://example.com")
|
||||||
|
mw.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if got := w.Header().Get("Access-Control-Expose-Headers"); got != "X-Custom, X-Request-Id" {
|
||||||
|
t.Fatalf("Access-Control-Expose-Headers = %q, want %q", got, "X-Custom, X-Request-Id")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Vary header is set", func(t *testing.T) {
|
||||||
|
mw := server.CORS(server.AllowOrigins("*"))(okHandler)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.Header.Set("Origin", "http://example.com")
|
||||||
|
mw.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if got := w.Header().Get("Vary"); got != "Origin" {
|
||||||
|
t.Fatalf("Vary = %q, want %q", got, "Origin")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
248
server/middleware_ratelimit.go
Normal file
248
server/middleware_ratelimit.go
Normal file
@@ -0,0 +1,248 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.codelab.vc/pkg/httpx/internal/clock"
|
||||||
|
)
|
||||||
|
|
||||||
|
// defaultMaxKeys bounds the number of distinct rate-limit buckets retained in
|
||||||
|
// memory. When exceeded, fully-refilled (idle) buckets are evicted.
|
||||||
|
const defaultMaxKeys = 1 << 16
|
||||||
|
|
||||||
|
type rateLimitOptions struct {
|
||||||
|
rate float64
|
||||||
|
burst int
|
||||||
|
keyFunc func(r *http.Request) string
|
||||||
|
clock clock.Clock
|
||||||
|
trustedProxies []*net.IPNet
|
||||||
|
maxKeys int
|
||||||
|
}
|
||||||
|
|
||||||
|
// RateLimitOption configures the RateLimit middleware.
|
||||||
|
type RateLimitOption func(*rateLimitOptions)
|
||||||
|
|
||||||
|
// WithRate sets the token refill rate (tokens per second).
|
||||||
|
func WithRate(tokensPerSecond float64) RateLimitOption {
|
||||||
|
return func(o *rateLimitOptions) { o.rate = tokensPerSecond }
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithBurst sets the maximum burst size (bucket capacity).
|
||||||
|
func WithBurst(n int) RateLimitOption {
|
||||||
|
return func(o *rateLimitOptions) { o.burst = n }
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithKeyFunc sets a custom function to extract the rate-limit key from a
|
||||||
|
// request. By default, the client IP from RemoteAddr is used (see
|
||||||
|
// WithTrustedProxies to honor X-Forwarded-For behind a trusted proxy).
|
||||||
|
func WithKeyFunc(fn func(r *http.Request) string) RateLimitOption {
|
||||||
|
return func(o *rateLimitOptions) { o.keyFunc = fn }
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithTrustedProxies enables X-Forwarded-For parsing, but only for requests
|
||||||
|
// whose immediate peer (RemoteAddr) falls within one of the given trusted
|
||||||
|
// CIDR ranges (e.g. "10.0.0.0/8", "192.168.0.0/16"). A bare IP is accepted as
|
||||||
|
// a /32 or /128. When the peer is trusted, the client key is taken from the
|
||||||
|
// right-most X-Forwarded-For entry that is not itself a trusted proxy;
|
||||||
|
// otherwise RemoteAddr is used. Invalid entries are ignored (treated as
|
||||||
|
// untrusted), so a typo can never silently widen trust.
|
||||||
|
//
|
||||||
|
// Without this option the middleware never trusts client-supplied forwarding
|
||||||
|
// headers, which prevents trivial rate-limit bypass and bucket exhaustion via
|
||||||
|
// spoofed headers.
|
||||||
|
func WithTrustedProxies(cidrs ...string) RateLimitOption {
|
||||||
|
return func(o *rateLimitOptions) {
|
||||||
|
for _, c := range cidrs {
|
||||||
|
if _, ipnet, err := net.ParseCIDR(c); err == nil {
|
||||||
|
o.trustedProxies = append(o.trustedProxies, ipnet)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if ip := net.ParseIP(c); ip != nil {
|
||||||
|
bits := 32
|
||||||
|
if ip.To4() == nil {
|
||||||
|
bits = 128
|
||||||
|
}
|
||||||
|
o.trustedProxies = append(o.trustedProxies, &net.IPNet{
|
||||||
|
IP: ip,
|
||||||
|
Mask: net.CIDRMask(bits, bits),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithMaxKeys sets the soft upper bound on the number of distinct buckets
|
||||||
|
// retained in memory. When exceeded, idle (fully-refilled) buckets are
|
||||||
|
// evicted; active buckets are never dropped. Default is 65536.
|
||||||
|
func WithMaxKeys(n int) RateLimitOption {
|
||||||
|
return func(o *rateLimitOptions) {
|
||||||
|
if n > 0 {
|
||||||
|
o.maxKeys = n
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// withRateLimitClock sets the clock for testing. Not exported.
|
||||||
|
func withRateLimitClock(c clock.Clock) RateLimitOption {
|
||||||
|
return func(o *rateLimitOptions) { o.clock = c }
|
||||||
|
}
|
||||||
|
|
||||||
|
// RateLimit returns a middleware that limits requests using a per-key token
|
||||||
|
// bucket algorithm. When the limit is exceeded, it returns 429 Too Many
|
||||||
|
// Requests with a Retry-After header.
|
||||||
|
//
|
||||||
|
// By default the key is the client IP taken from RemoteAddr. Forwarding
|
||||||
|
// headers (X-Forwarded-For) are honored only when WithTrustedProxies is set,
|
||||||
|
// so the limiter cannot be bypassed by spoofing headers.
|
||||||
|
func RateLimit(opts ...RateLimitOption) Middleware {
|
||||||
|
o := &rateLimitOptions{
|
||||||
|
rate: 10,
|
||||||
|
burst: 20,
|
||||||
|
clock: clock.System(),
|
||||||
|
maxKeys: defaultMaxKeys,
|
||||||
|
}
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(o)
|
||||||
|
}
|
||||||
|
if o.keyFunc == nil {
|
||||||
|
o.keyFunc = o.clientKey
|
||||||
|
}
|
||||||
|
|
||||||
|
lim := &limiter{opts: o}
|
||||||
|
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
key := o.keyFunc(r)
|
||||||
|
if allowed, retryAfter := lim.allow(key); !allowed {
|
||||||
|
w.Header().Set("Retry-After", strconv.Itoa(int(retryAfter)+1))
|
||||||
|
http.Error(w, "Too Many Requests", http.StatusTooManyRequests)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// limiter holds the per-key token buckets for one RateLimit middleware.
|
||||||
|
type limiter struct {
|
||||||
|
opts *rateLimitOptions
|
||||||
|
buckets sync.Map // key -> *bucket
|
||||||
|
count atomic.Int64
|
||||||
|
sweeping atomic.Bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// allow reports whether a request for key may proceed. When denied it also
|
||||||
|
// returns the suggested Retry-After delay in seconds.
|
||||||
|
func (l *limiter) allow(key string) (bool, float64) {
|
||||||
|
o := l.opts
|
||||||
|
val, loaded := l.buckets.LoadOrStore(key, &bucket{
|
||||||
|
tokens: float64(o.burst),
|
||||||
|
lastTime: o.clock.Now(),
|
||||||
|
})
|
||||||
|
if !loaded && l.count.Add(1) > int64(o.maxKeys) {
|
||||||
|
l.sweep()
|
||||||
|
}
|
||||||
|
b := val.(*bucket)
|
||||||
|
|
||||||
|
b.mu.Lock()
|
||||||
|
defer b.mu.Unlock()
|
||||||
|
|
||||||
|
now := o.clock.Now()
|
||||||
|
elapsed := now.Sub(b.lastTime).Seconds()
|
||||||
|
b.tokens += elapsed * o.rate
|
||||||
|
if b.tokens > float64(o.burst) {
|
||||||
|
b.tokens = float64(o.burst)
|
||||||
|
}
|
||||||
|
b.lastTime = now
|
||||||
|
|
||||||
|
if b.tokens < 1 {
|
||||||
|
return false, (1 - b.tokens) / o.rate
|
||||||
|
}
|
||||||
|
b.tokens--
|
||||||
|
return true, 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// sweep removes fully-refilled (idle) buckets to bound memory. Only one sweep
|
||||||
|
// runs at a time; buckets that still hold a partial limit are preserved so
|
||||||
|
// that eviction can never reset an active client's allowance.
|
||||||
|
func (l *limiter) sweep() {
|
||||||
|
if !l.sweeping.CompareAndSwap(false, true) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer l.sweeping.Store(false)
|
||||||
|
|
||||||
|
o := l.opts
|
||||||
|
now := o.clock.Now()
|
||||||
|
l.buckets.Range(func(k, v any) bool {
|
||||||
|
b := v.(*bucket)
|
||||||
|
b.mu.Lock()
|
||||||
|
elapsed := now.Sub(b.lastTime).Seconds()
|
||||||
|
full := b.tokens+elapsed*o.rate >= float64(o.burst)
|
||||||
|
b.mu.Unlock()
|
||||||
|
if full {
|
||||||
|
l.buckets.Delete(k)
|
||||||
|
l.count.Add(-1)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
type bucket struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
tokens float64
|
||||||
|
lastTime time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// clientKey derives the rate-limit key from a request. It uses RemoteAddr by
|
||||||
|
// default and only consults X-Forwarded-For when the peer is a configured
|
||||||
|
// trusted proxy (see WithTrustedProxies).
|
||||||
|
func (o *rateLimitOptions) clientKey(r *http.Request) string {
|
||||||
|
remote := remoteIP(r)
|
||||||
|
if len(o.trustedProxies) == 0 || !o.isTrusted(remote) {
|
||||||
|
return remote
|
||||||
|
}
|
||||||
|
// Peer is trusted: walk X-Forwarded-For right-to-left and return the first
|
||||||
|
// address that is not itself a trusted proxy — that is the real client.
|
||||||
|
xff := r.Header.Get("X-Forwarded-For")
|
||||||
|
if xff == "" {
|
||||||
|
return remote
|
||||||
|
}
|
||||||
|
parts := strings.Split(xff, ",")
|
||||||
|
for i := len(parts) - 1; i >= 0; i-- {
|
||||||
|
ip := strings.TrimSpace(parts[i])
|
||||||
|
if ip == "" || o.isTrusted(ip) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return ip
|
||||||
|
}
|
||||||
|
return remote
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *rateLimitOptions) isTrusted(ip string) bool {
|
||||||
|
parsed := net.ParseIP(ip)
|
||||||
|
if parsed == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, n := range o.trustedProxies {
|
||||||
|
if n.Contains(parsed) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// remoteIP returns the host portion of r.RemoteAddr, or the raw value if it
|
||||||
|
// has no port.
|
||||||
|
func remoteIP(r *http.Request) string {
|
||||||
|
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||||
|
if err != nil {
|
||||||
|
return r.RemoteAddr
|
||||||
|
}
|
||||||
|
return host
|
||||||
|
}
|
||||||
62
server/middleware_ratelimit_internal_test.go
Normal file
62
server/middleware_ratelimit_internal_test.go
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.codelab.vc/pkg/httpx/internal/clock"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newTestRequest(remoteAddr, xff string) *http.Request {
|
||||||
|
r := &http.Request{RemoteAddr: remoteAddr, Header: http.Header{}}
|
||||||
|
if xff != "" {
|
||||||
|
r.Header.Set("X-Forwarded-For", xff)
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestLimiterSweepEvictsIdleBuckets verifies that sweep removes fully-refilled
|
||||||
|
// (idle) buckets while preserving buckets that still hold an active limit, so
|
||||||
|
// memory is bounded without resetting live clients' allowances.
|
||||||
|
func TestLimiterSweepEvictsIdleBuckets(t *testing.T) {
|
||||||
|
clk := clock.Mock(time.Now())
|
||||||
|
o := &rateLimitOptions{rate: 1, burst: 5, clock: clk, maxKeys: 1 << 30}
|
||||||
|
lim := &limiter{opts: o}
|
||||||
|
|
||||||
|
// "idle" makes a single request, then time passes so it refills to full.
|
||||||
|
lim.allow("idle")
|
||||||
|
clk.Advance(10 * time.Second)
|
||||||
|
|
||||||
|
// "active" drains its whole burst at the (advanced) current time.
|
||||||
|
for i := 0; i < 6; i++ {
|
||||||
|
lim.allow("active")
|
||||||
|
}
|
||||||
|
|
||||||
|
lim.sweep()
|
||||||
|
|
||||||
|
if _, ok := lim.buckets.Load("idle"); ok {
|
||||||
|
t.Error("fully-refilled idle bucket was not evicted")
|
||||||
|
}
|
||||||
|
if _, ok := lim.buckets.Load("active"); !ok {
|
||||||
|
t.Error("active bucket with a partial limit was wrongly evicted")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestClientKeyTrustedProxy exercises the X-Forwarded-For walk used behind a
|
||||||
|
// trusted proxy, independent of the HTTP layer.
|
||||||
|
func TestClientKeyTrustedProxy(t *testing.T) {
|
||||||
|
o := &rateLimitOptions{}
|
||||||
|
WithTrustedProxies("192.168.0.0/16")(o)
|
||||||
|
|
||||||
|
r := newTestRequest("192.168.1.10:443", "203.0.113.7, 192.168.1.10")
|
||||||
|
if got := o.clientKey(r); got != "203.0.113.7" {
|
||||||
|
t.Fatalf("clientKey = %q, want real client 203.0.113.7", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Untrusted peer: X-Forwarded-For must be ignored entirely.
|
||||||
|
r = newTestRequest("203.0.113.7:443", "10.0.0.1")
|
||||||
|
if got := o.clientKey(r); got != "203.0.113.7" {
|
||||||
|
t.Fatalf("clientKey = %q, want peer 203.0.113.7 (XFF ignored)", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
202
server/middleware_ratelimit_test.go
Normal file
202
server/middleware_ratelimit_test.go
Normal file
@@ -0,0 +1,202 @@
|
|||||||
|
package server_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.codelab.vc/pkg/httpx/server"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRateLimit(t *testing.T) {
|
||||||
|
okHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("allows requests within limit", func(t *testing.T) {
|
||||||
|
mw := server.RateLimit(
|
||||||
|
server.WithRate(100),
|
||||||
|
server.WithBurst(10),
|
||||||
|
)(okHandler)
|
||||||
|
|
||||||
|
for i := range 10 {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.RemoteAddr = "1.2.3.4:1234"
|
||||||
|
mw.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("request %d: got status %d, want %d", i, w.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("rejects when burst exhausted", func(t *testing.T) {
|
||||||
|
mw := server.RateLimit(
|
||||||
|
server.WithRate(1),
|
||||||
|
server.WithBurst(2),
|
||||||
|
)(okHandler)
|
||||||
|
|
||||||
|
// Exhaust burst.
|
||||||
|
for range 2 {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.RemoteAddr = "1.2.3.4:1234"
|
||||||
|
mw.ServeHTTP(w, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Next request should be rejected.
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.RemoteAddr = "1.2.3.4:1234"
|
||||||
|
mw.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusTooManyRequests {
|
||||||
|
t.Fatalf("got status %d, want %d", w.Code, http.StatusTooManyRequests)
|
||||||
|
}
|
||||||
|
if w.Header().Get("Retry-After") == "" {
|
||||||
|
t.Fatal("expected Retry-After header")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("different IPs have independent limits", func(t *testing.T) {
|
||||||
|
mw := server.RateLimit(
|
||||||
|
server.WithRate(1),
|
||||||
|
server.WithBurst(1),
|
||||||
|
)(okHandler)
|
||||||
|
|
||||||
|
// First IP exhausts its limit.
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.RemoteAddr = "1.2.3.4:1234"
|
||||||
|
mw.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
// Second IP should still be allowed.
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
req = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.RemoteAddr = "5.6.7.8:5678"
|
||||||
|
mw.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ignores X-Forwarded-For without trusted proxies", func(t *testing.T) {
|
||||||
|
// By default the limiter keys on RemoteAddr only. A spoofed,
|
||||||
|
// per-request X-Forwarded-For must not let a single peer bypass the
|
||||||
|
// limit by minting a fresh bucket each time.
|
||||||
|
mw := server.RateLimit(
|
||||||
|
server.WithRate(1),
|
||||||
|
server.WithBurst(1),
|
||||||
|
)(okHandler)
|
||||||
|
|
||||||
|
send := func(xff string) int {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.Header.Set("X-Forwarded-For", xff)
|
||||||
|
req.RemoteAddr = "192.168.1.1:1234"
|
||||||
|
mw.ServeHTTP(w, req)
|
||||||
|
return w.Code
|
||||||
|
}
|
||||||
|
|
||||||
|
if code := send("10.0.0.1"); code != http.StatusOK {
|
||||||
|
t.Fatalf("first request: got %d, want %d", code, http.StatusOK)
|
||||||
|
}
|
||||||
|
// Different spoofed XFF, same peer — must still be limited.
|
||||||
|
if code := send("10.0.0.2"); code != http.StatusTooManyRequests {
|
||||||
|
t.Fatalf("spoofed XFF bypassed limit: got %d, want %d", code, http.StatusTooManyRequests)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("honors X-Forwarded-For behind trusted proxy", func(t *testing.T) {
|
||||||
|
mw := server.RateLimit(
|
||||||
|
server.WithRate(1),
|
||||||
|
server.WithBurst(1),
|
||||||
|
server.WithTrustedProxies("192.168.0.0/16"),
|
||||||
|
)(okHandler)
|
||||||
|
|
||||||
|
send := func(xff string) int {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.Header.Set("X-Forwarded-For", xff)
|
||||||
|
req.RemoteAddr = "192.168.1.1:1234" // trusted proxy
|
||||||
|
mw.ServeHTTP(w, req)
|
||||||
|
return w.Code
|
||||||
|
}
|
||||||
|
|
||||||
|
// Real client 10.0.0.1 (left-most), proxy hop 192.168.1.1 (right-most).
|
||||||
|
if code := send("10.0.0.1, 192.168.1.1"); code != http.StatusOK {
|
||||||
|
t.Fatalf("first request: got %d, want %d", code, http.StatusOK)
|
||||||
|
}
|
||||||
|
if code := send("10.0.0.1, 192.168.1.1"); code != http.StatusTooManyRequests {
|
||||||
|
t.Fatalf("same client not limited: got %d, want %d", code, http.StatusTooManyRequests)
|
||||||
|
}
|
||||||
|
// A different real client through the same proxy is independent.
|
||||||
|
if code := send("10.0.0.2, 192.168.1.1"); code != http.StatusOK {
|
||||||
|
t.Fatalf("different client should be allowed: got %d, want %d", code, http.StatusOK)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("custom key function", func(t *testing.T) {
|
||||||
|
mw := server.RateLimit(
|
||||||
|
server.WithRate(1),
|
||||||
|
server.WithBurst(1),
|
||||||
|
server.WithKeyFunc(func(r *http.Request) string {
|
||||||
|
return r.Header.Get("X-API-Key")
|
||||||
|
}),
|
||||||
|
)(okHandler)
|
||||||
|
|
||||||
|
// Exhaust key "abc".
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.Header.Set("X-API-Key", "abc")
|
||||||
|
mw.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
// Same key should be rate limited.
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
req = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.Header.Set("X-API-Key", "abc")
|
||||||
|
mw.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusTooManyRequests {
|
||||||
|
t.Fatalf("got status %d, want %d", w.Code, http.StatusTooManyRequests)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Different key should be allowed.
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
req = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.Header.Set("X-API-Key", "xyz")
|
||||||
|
mw.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("tokens refill over time", func(t *testing.T) {
|
||||||
|
mw := server.RateLimit(
|
||||||
|
server.WithRate(1000), // Very fast refill for test
|
||||||
|
server.WithBurst(1),
|
||||||
|
)(okHandler)
|
||||||
|
|
||||||
|
// Exhaust burst.
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.RemoteAddr = "1.2.3.4:1234"
|
||||||
|
mw.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
// Wait a bit for refill.
|
||||||
|
time.Sleep(5 * time.Millisecond)
|
||||||
|
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
req = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.RemoteAddr = "1.2.3.4:1234"
|
||||||
|
mw.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("got status %d after refill, want %d", w.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -5,13 +5,18 @@ import (
|
|||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
"git.codelab.vc/pkg/httpx/internal/requestid"
|
||||||
)
|
)
|
||||||
|
|
||||||
type requestIDKey struct{}
|
// maxRequestIDLen bounds the length of a client-supplied request ID that we
|
||||||
|
// are willing to propagate.
|
||||||
|
const maxRequestIDLen = 128
|
||||||
|
|
||||||
// RequestID returns a middleware that assigns a unique request ID to each
|
// RequestID returns a middleware that assigns a unique request ID to each
|
||||||
// request. If the incoming request already has an X-Request-Id header, that
|
// request. If the incoming request carries a valid X-Request-Id header, that
|
||||||
// value is used. Otherwise a new UUID v4 is generated via crypto/rand.
|
// value is reused; otherwise (or if the supplied value is empty, too long, or
|
||||||
|
// contains unsafe characters) a new UUID v4 is generated via crypto/rand.
|
||||||
//
|
//
|
||||||
// The request ID is stored in the request context (retrieve with
|
// The request ID is stored in the request context (retrieve with
|
||||||
// RequestIDFromContext) and set on the response X-Request-Id header.
|
// RequestIDFromContext) and set on the response X-Request-Id header.
|
||||||
@@ -19,22 +24,42 @@ func RequestID() Middleware {
|
|||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
id := r.Header.Get("X-Request-Id")
|
id := r.Header.Get("X-Request-Id")
|
||||||
if id == "" {
|
if !validRequestID(id) {
|
||||||
id = newUUID()
|
id = newUUID()
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := context.WithValue(r.Context(), requestIDKey{}, id)
|
ctx := requestid.NewContext(r.Context(), id)
|
||||||
w.Header().Set("X-Request-Id", id)
|
w.Header().Set("X-Request-Id", id)
|
||||||
next.ServeHTTP(w, r.WithContext(ctx))
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// validRequestID reports whether a client-supplied request ID is safe to
|
||||||
|
// propagate: non-empty, within a sane length, and restricted to characters
|
||||||
|
// that cannot forge log lines or split response headers.
|
||||||
|
func validRequestID(id string) bool {
|
||||||
|
if id == "" || len(id) > maxRequestIDLen {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for i := 0; i < len(id); i++ {
|
||||||
|
c := id[i]
|
||||||
|
switch {
|
||||||
|
case c >= 'a' && c <= 'z',
|
||||||
|
c >= 'A' && c <= 'Z',
|
||||||
|
c >= '0' && c <= '9',
|
||||||
|
c == '-', c == '_', c == '.':
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
// RequestIDFromContext returns the request ID from the context, or an empty
|
// RequestIDFromContext returns the request ID from the context, or an empty
|
||||||
// string if none is set.
|
// string if none is set.
|
||||||
func RequestIDFromContext(ctx context.Context) string {
|
func RequestIDFromContext(ctx context.Context) string {
|
||||||
id, _ := ctx.Value(requestIDKey{}).(string)
|
return requestid.FromContext(ctx)
|
||||||
return id
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// newUUID generates a UUID v4 string using crypto/rand.
|
// newUUID generates a UUID v4 string using crypto/rand.
|
||||||
|
|||||||
@@ -214,6 +214,36 @@ func TestRequestID(t *testing.T) {
|
|||||||
t.Fatalf("expected empty, got %q", id)
|
t.Fatalf("expected empty, got %q", id)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("rejects unsafe incoming ID", func(t *testing.T) {
|
||||||
|
cases := map[string]string{
|
||||||
|
"header injection": "abc\r\nX-Injected: 1",
|
||||||
|
"contains space": "has space",
|
||||||
|
"too long": strings.Repeat("a", 200),
|
||||||
|
}
|
||||||
|
for name, badID := range cases {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
var gotID string
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
gotID = server.RequestIDFromContext(r.Context())
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
mw := server.RequestID()(handler)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.Header.Set("X-Request-Id", badID)
|
||||||
|
mw.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if gotID == badID {
|
||||||
|
t.Fatalf("unsafe incoming ID was propagated verbatim: %q", gotID)
|
||||||
|
}
|
||||||
|
if len(gotID) != 36 {
|
||||||
|
t.Fatalf("expected a freshly generated UUID, got %q (len %d)", gotID, len(gotID))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRequestID_UUIDFormat(t *testing.T) {
|
func TestRequestID_UUIDFormat(t *testing.T) {
|
||||||
|
|||||||
15
server/middleware_timeout.go
Normal file
15
server/middleware_timeout.go
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Timeout returns a middleware that limits request processing time.
|
||||||
|
// If the handler does not complete within d, the client receives a
|
||||||
|
// 503 Service Unavailable response. It wraps http.TimeoutHandler.
|
||||||
|
func Timeout(d time.Duration) Middleware {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.TimeoutHandler(next, d, "Service Unavailable\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
49
server/middleware_timeout_test.go
Normal file
49
server/middleware_timeout_test.go
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
package server_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.codelab.vc/pkg/httpx/server"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTimeout(t *testing.T) {
|
||||||
|
t.Run("handler completes within timeout", func(t *testing.T) {
|
||||||
|
handler := server.Timeout(1 * time.Second)(
|
||||||
|
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte("ok"))
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("handler exceeds timeout returns 503", func(t *testing.T) {
|
||||||
|
handler := server.Timeout(10 * time.Millisecond)(
|
||||||
|
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
select {
|
||||||
|
case <-time.After(1 * time.Second):
|
||||||
|
case <-r.Context().Done():
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusServiceUnavailable {
|
||||||
|
t.Fatalf("got status %d, want %d", w.Code, http.StatusServiceUnavailable)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
29
server/respond.go
Normal file
29
server/respond.go
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
// WriteJSON encodes v as JSON and writes it to w with the given status code.
|
||||||
|
// It sets Content-Type to application/json.
|
||||||
|
func WriteJSON(w http.ResponseWriter, status int, v any) error {
|
||||||
|
b, err := json.Marshal(v)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(status)
|
||||||
|
_, err = w.Write(b)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteError writes a JSON error response with the given status code and
|
||||||
|
// message. The response body is {"error": "<message>"}.
|
||||||
|
func WriteError(w http.ResponseWriter, status int, msg string) error {
|
||||||
|
return WriteJSON(w, status, errorBody{Error: msg})
|
||||||
|
}
|
||||||
|
|
||||||
|
type errorBody struct {
|
||||||
|
Error string `json:"error"`
|
||||||
|
}
|
||||||
72
server/respond_test.go
Normal file
72
server/respond_test.go
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
package server_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.codelab.vc/pkg/httpx/server"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestWriteJSON(t *testing.T) {
|
||||||
|
t.Run("writes JSON with status and content type", func(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
type resp struct {
|
||||||
|
ID int `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
err := server.WriteJSON(w, 201, resp{ID: 1, Name: "Alice"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if w.Code != 201 {
|
||||||
|
t.Fatalf("got status %d, want %d", w.Code, 201)
|
||||||
|
}
|
||||||
|
if ct := w.Header().Get("Content-Type"); ct != "application/json" {
|
||||||
|
t.Fatalf("Content-Type = %q, want %q", ct, "application/json")
|
||||||
|
}
|
||||||
|
|
||||||
|
var decoded resp
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &decoded); err != nil {
|
||||||
|
t.Fatalf("unmarshal: %v", err)
|
||||||
|
}
|
||||||
|
if decoded.ID != 1 || decoded.Name != "Alice" {
|
||||||
|
t.Fatalf("got %+v, want {ID:1 Name:Alice}", decoded)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("returns error for unmarshalable input", func(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
err := server.WriteJSON(w, 200, make(chan int))
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for channel type")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteError(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
err := server.WriteError(w, 404, "not found")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if w.Code != 404 {
|
||||||
|
t.Fatalf("got status %d, want %d", w.Code, 404)
|
||||||
|
}
|
||||||
|
|
||||||
|
var body struct {
|
||||||
|
Error string `json:"error"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil {
|
||||||
|
t.Fatalf("unmarshal: %v", err)
|
||||||
|
}
|
||||||
|
if body.Error != "not found" {
|
||||||
|
t.Fatalf("error = %q, want %q", body.Error, "not found")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -9,16 +9,31 @@ import (
|
|||||||
// groups and sub-router mounting. It leverages Go 1.22+ enhanced patterns
|
// groups and sub-router mounting. It leverages Go 1.22+ enhanced patterns
|
||||||
// like "GET /users/{id}".
|
// like "GET /users/{id}".
|
||||||
type Router struct {
|
type Router struct {
|
||||||
mux *http.ServeMux
|
mux *http.ServeMux
|
||||||
prefix string
|
prefix string
|
||||||
middlewares []Middleware
|
middlewares []Middleware
|
||||||
|
notFoundHandler http.Handler
|
||||||
|
}
|
||||||
|
|
||||||
|
// RouterOption configures a Router.
|
||||||
|
type RouterOption func(*Router)
|
||||||
|
|
||||||
|
// WithNotFoundHandler sets a custom handler for requests that don't match
|
||||||
|
// any registered pattern. This is useful for returning JSON 404/405 responses
|
||||||
|
// instead of the default plain text.
|
||||||
|
func WithNotFoundHandler(h http.Handler) RouterOption {
|
||||||
|
return func(r *Router) { r.notFoundHandler = h }
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewRouter creates a new Router backed by a fresh http.ServeMux.
|
// NewRouter creates a new Router backed by a fresh http.ServeMux.
|
||||||
func NewRouter() *Router {
|
func NewRouter(opts ...RouterOption) *Router {
|
||||||
return &Router{
|
r := &Router{
|
||||||
mux: http.NewServeMux(),
|
mux: http.NewServeMux(),
|
||||||
}
|
}
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(r)
|
||||||
|
}
|
||||||
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle registers a handler for the given pattern. The pattern follows
|
// Handle registers a handler for the given pattern. The pattern follows
|
||||||
@@ -63,6 +78,14 @@ func (r *Router) Mount(prefix string, handler http.Handler) {
|
|||||||
|
|
||||||
// ServeHTTP implements http.Handler, making Router usable as a handler.
|
// ServeHTTP implements http.Handler, making Router usable as a handler.
|
||||||
func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||||
|
if r.notFoundHandler != nil {
|
||||||
|
// Use the mux to check for a match. If none, use the custom handler.
|
||||||
|
_, pattern := r.mux.Handler(req)
|
||||||
|
if pattern == "" {
|
||||||
|
r.notFoundHandler.ServeHTTP(w, req)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
r.mux.ServeHTTP(w, req)
|
r.mux.ServeHTTP(w, req)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
70
server/route_notfound_test.go
Normal file
70
server/route_notfound_test.go
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
package server_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.codelab.vc/pkg/httpx/server"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRouter_NotFoundHandler(t *testing.T) {
|
||||||
|
t.Run("custom 404 handler", func(t *testing.T) {
|
||||||
|
notFound := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
json.NewEncoder(w).Encode(map[string]string{"error": "not found"})
|
||||||
|
})
|
||||||
|
|
||||||
|
r := server.NewRouter(server.WithNotFoundHandler(notFound))
|
||||||
|
r.HandleFunc("GET /exists", func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Matched route works normally.
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/exists", nil)
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("matched route: got status %d, want %d", w.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unmatched route uses custom handler.
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
req = httptest.NewRequest(http.MethodGet, "/nope", nil)
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusNotFound {
|
||||||
|
t.Fatalf("not found: got status %d, want %d", w.Code, http.StatusNotFound)
|
||||||
|
}
|
||||||
|
if ct := w.Header().Get("Content-Type"); ct != "application/json" {
|
||||||
|
t.Fatalf("Content-Type = %q, want %q", ct, "application/json")
|
||||||
|
}
|
||||||
|
|
||||||
|
var body map[string]string
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil {
|
||||||
|
t.Fatalf("unmarshal: %v", err)
|
||||||
|
}
|
||||||
|
if body["error"] != "not found" {
|
||||||
|
t.Fatalf("error = %q, want %q", body["error"], "not found")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("default behavior without custom handler", func(t *testing.T) {
|
||||||
|
r := server.NewRouter()
|
||||||
|
r.HandleFunc("GET /exists", func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/nope", nil)
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
// Default ServeMux returns 404.
|
||||||
|
if w.Code != http.StatusNotFound {
|
||||||
|
t.Fatalf("got status %d, want %d", w.Code, http.StatusNotFound)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user