Compare commits
29 Commits
f2a4a4fccc
...
v0.1.0
| Author | SHA1 | Date | |
|---|---|---|---|
| b6350185d9 | |||
| 85cdc5e2c9 | |||
| 25beb2f5c2 | |||
| 16ff427c93 | |||
| 138d4b6c6d | |||
| 3aa7536328 | |||
| 89cfc38f0e | |||
| 8a63f142a7 | |||
| 21274c178a | |||
| 49be6f8a7e | |||
| 3395f70abd | |||
| 7a2cef00c3 | |||
| de5bf9a6d9 | |||
| 7f12b0c87a | |||
| 1b322c8c81 | |||
| b40a373675 | |||
| f6384ecbea | |||
| 4d47918a66 | |||
| 7fae6247d5 | |||
| cea75d198b | |||
| 6b901c931e | |||
| d260abc393 | |||
| 5cfd1a7400 | |||
| f9a05f5c57 | |||
| a90c4cd7fa | |||
| 8d322123a4 | |||
| 2ca930236d | |||
| 505c7b8c4f | |||
| 6b1941fce7 |
48
.cursorrules
Normal file
48
.cursorrules
Normal file
@@ -0,0 +1,48 @@
|
||||
You are working on `git.codelab.vc/pkg/httpx`, a Go 1.24 HTTP client/server library with zero external dependencies.
|
||||
|
||||
## Architecture
|
||||
|
||||
- Client middleware: `func(http.RoundTripper) http.RoundTripper` — compose with `middleware.Chain`
|
||||
- Server middleware: `func(http.Handler) http.Handler` — compose with `server.Chain`
|
||||
- All configuration uses functional options pattern (`WithXxx` functions)
|
||||
- Chain order for client: Logging → User MW → Retry → Circuit Breaker → Balancer → Transport
|
||||
|
||||
## Package structure
|
||||
|
||||
- `httpx` (root) — Client, request builders (NewJSONRequest, NewFormRequest), error types
|
||||
- `middleware/` — client-side middleware (Logging, Recovery, Auth, Headers, RequestID)
|
||||
- `retry/` — retry middleware with exponential backoff and Retry-After support
|
||||
- `circuitbreaker/` — per-host circuit breaker (sync.Map of host → Breaker)
|
||||
- `balancer/` — load balancing with health checking (RoundRobin, Weighted, Failover)
|
||||
- `server/` — Server, Router, server middleware (RequestID, Recovery, Logging, CORS, RateLimit, MaxBodySize, Timeout), response helpers (WriteJSON, WriteError)
|
||||
- `internal/requestid/` — shared context key (avoids circular import between server and middleware)
|
||||
- `internal/clock/` — deterministic time for tests
|
||||
|
||||
## Code conventions
|
||||
|
||||
- Zero external dependencies — stdlib only, do not add imports outside the module
|
||||
- Functional options: `type Option func(*options)` with `With<Name>` constructors
|
||||
- Test with stdlib only: `testing`, `httptest`, `net/http`. No testify/gomock
|
||||
- Client test helper: `mockTransport(fn)` wrapping `middleware.RoundTripperFunc`
|
||||
- Server test helper: `httptest.NewRecorder`, `httptest.NewRequest`, `waitForAddr(t, srv)`
|
||||
- Thread safety with `sync.Mutex`, `sync.Map`, or `atomic`
|
||||
- Use `internal/clock` for time-dependent tests, not `time.Now()` directly
|
||||
- Sentinel errors in sub-packages, re-exported as aliases in root package
|
||||
|
||||
## When writing new code
|
||||
|
||||
- Client middleware → file in `middleware/`, return `middleware.Middleware`
|
||||
- Server middleware → file in `server/middleware_<name>.go`, return `server.Middleware`
|
||||
- New option → add field to options struct, create `With<Name>` func, apply in constructor
|
||||
- Do NOT import `server` from `middleware` or vice versa (use `internal/requestid` for shared context)
|
||||
- Client.Close() must be called when using WithEndpoints() (stops health checker goroutine)
|
||||
- Request bodies must have GetBody set for retry — use NewJSONRequest/NewFormRequest
|
||||
|
||||
## Commands
|
||||
|
||||
```bash
|
||||
go build ./... # compile
|
||||
go test ./... # test
|
||||
go test -race ./... # test with race detector
|
||||
go vet ./... # static analysis
|
||||
```
|
||||
23
.gitea/workflows/ci.yml
Normal file
23
.gitea/workflows/ci.yml
Normal file
@@ -0,0 +1,23 @@
|
||||
name: CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
pull_request:
|
||||
branches: [main]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.24"
|
||||
|
||||
- name: Vet
|
||||
run: go vet ./...
|
||||
|
||||
- name: Test
|
||||
run: go test -race -count=1 ./...
|
||||
39
.gitea/workflows/publish.yml
Normal file
39
.gitea/workflows/publish.yml
Normal file
@@ -0,0 +1,39 @@
|
||||
name: Publish
|
||||
|
||||
on:
|
||||
push:
|
||||
tags: ["v*"]
|
||||
|
||||
jobs:
|
||||
publish:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.24"
|
||||
|
||||
- name: Vet
|
||||
run: go vet ./...
|
||||
|
||||
- name: Test
|
||||
run: go test -race -count=1 ./...
|
||||
|
||||
- name: Publish to Gitea Package Registry
|
||||
run: |
|
||||
VERSION=${GITHUB_REF#refs/tags/}
|
||||
MODULE=$(go list -m)
|
||||
|
||||
# Create module zip with required prefix: module@version/
|
||||
git archive --format=zip --prefix="${MODULE}@${VERSION}/" HEAD -o module.zip
|
||||
|
||||
# Gitea Go Package Registry API
|
||||
curl -s -f \
|
||||
-X PUT \
|
||||
-H "Authorization: token ${{ secrets.PUBLISH_TOKEN }}" \
|
||||
-H "Content-Type: application/zip" \
|
||||
--data-binary @module.zip \
|
||||
"${{ github.server_url }}/api/packages/pkg/go/upload?module=${MODULE}&version=${VERSION}"
|
||||
env:
|
||||
PUBLISH_TOKEN: ${{ secrets.PUBLISH_TOKEN }}
|
||||
50
.github/copilot-instructions.md
vendored
Normal file
50
.github/copilot-instructions.md
vendored
Normal file
@@ -0,0 +1,50 @@
|
||||
# Copilot instructions — httpx
|
||||
|
||||
## Project
|
||||
|
||||
`git.codelab.vc/pkg/httpx` is a Go 1.24 HTTP client and server library with zero external dependencies.
|
||||
|
||||
## Architecture
|
||||
|
||||
### Client (`httpx` root package)
|
||||
- Middleware type: `func(http.RoundTripper) http.RoundTripper`
|
||||
- Chain assembly (outermost → innermost): Logging → User MW → Retry → Circuit Breaker → Balancer → Transport
|
||||
- Retry wraps CB+Balancer so each attempt can hit a different endpoint
|
||||
- Circuit breaker is per-host (sync.Map of host → Breaker)
|
||||
- Client.Close() required when using WithEndpoints() — stops health checker goroutine
|
||||
|
||||
### Server (`server/` package)
|
||||
- Middleware type: `func(http.Handler) http.Handler`
|
||||
- Router wraps http.ServeMux with groups, prefix routing, Mount for sub-handlers
|
||||
- Defaults() preset: RequestID → Recovery → Logging + production timeouts
|
||||
- Available middleware: RequestID, Recovery, Logging, CORS, RateLimit, MaxBodySize, Timeout
|
||||
- WriteJSON/WriteError for JSON responses
|
||||
|
||||
### Sub-packages
|
||||
- `middleware/` — client-side middleware (Logging, Recovery, Auth, Headers, RequestID)
|
||||
- `retry/` — retry with exponential backoff and Retry-After
|
||||
- `circuitbreaker/` — per-host circuit breaker
|
||||
- `balancer/` — load balancing with health checking
|
||||
- `internal/requestid/` — shared context key between server and middleware
|
||||
- `internal/clock/` — deterministic time for tests
|
||||
|
||||
## Conventions
|
||||
|
||||
- All configuration uses functional options (`WithXxx` functions)
|
||||
- Zero external dependencies — do not add requires to go.mod
|
||||
- Tests use stdlib only (testing, httptest) — no testify or gomock
|
||||
- Thread safety with sync.Mutex, sync.Map, or atomic
|
||||
- Client test mock: `mockTransport(fn)` using `middleware.RoundTripperFunc`
|
||||
- Server test helpers: `httptest.NewRecorder`, `httptest.NewRequest`
|
||||
- Do NOT import server from middleware or vice versa — use internal/requestid for shared context
|
||||
- Sentinel errors in sub-packages, re-exported in root package
|
||||
- Use internal/clock for time-dependent tests
|
||||
|
||||
## Commands
|
||||
|
||||
```bash
|
||||
go build ./... # compile
|
||||
go test ./... # test
|
||||
go test -race ./... # test with race detector
|
||||
go vet ./... # static analysis
|
||||
```
|
||||
139
AGENTS.md
Normal file
139
AGENTS.md
Normal file
@@ -0,0 +1,139 @@
|
||||
# AGENTS.md — httpx
|
||||
|
||||
Universal guide for AI coding agents working with this codebase.
|
||||
|
||||
## Overview
|
||||
|
||||
`git.codelab.vc/pkg/httpx` is a Go HTTP toolkit with **zero external dependencies** (Go 1.24, stdlib only). It provides:
|
||||
- A composable HTTP **client** with retry, circuit breaking, load balancing
|
||||
- A production-ready HTTP **server** with routing, middleware, graceful shutdown
|
||||
|
||||
## Package map
|
||||
|
||||
```
|
||||
httpx/ Root — Client, request builders, error types
|
||||
├── middleware/ Client-side middleware (RoundTripper wrappers)
|
||||
├── retry/ Retry middleware with backoff
|
||||
├── circuitbreaker/ Per-host circuit breaker
|
||||
├── balancer/ Client-side load balancing + health checking
|
||||
├── server/ Server, Router, server-side middleware, response helpers
|
||||
└── internal/
|
||||
├── requestid/ Shared context key (avoids circular imports)
|
||||
└── clock/ Deterministic time for testing
|
||||
```
|
||||
|
||||
## Middleware chain architecture
|
||||
|
||||
### Client middleware: `func(http.RoundTripper) http.RoundTripper`
|
||||
|
||||
```
|
||||
Request flow (outermost → innermost):
|
||||
|
||||
Logging
|
||||
└→ User Middlewares
|
||||
└→ Retry
|
||||
└→ Circuit Breaker
|
||||
└→ Balancer
|
||||
└→ Base Transport (http.DefaultTransport)
|
||||
```
|
||||
|
||||
Retry wraps CB+Balancer so each attempt can hit a different endpoint.
|
||||
|
||||
### Server middleware: `func(http.Handler) http.Handler`
|
||||
|
||||
```
|
||||
Chain(A, B, C)(handler) == A(B(C(handler)))
|
||||
A is outermost (sees request first, response last)
|
||||
```
|
||||
|
||||
Defaults() preset: `RequestID → Recovery → Logging`
|
||||
|
||||
## Common tasks
|
||||
|
||||
### Add a client middleware
|
||||
|
||||
1. Create file in `middleware/` (or inline)
|
||||
2. Return `middleware.Middleware` (`func(http.RoundTripper) http.RoundTripper`)
|
||||
3. Use `middleware.RoundTripperFunc` for the inner adapter
|
||||
4. Test with `middleware.RoundTripperFunc` as mock transport
|
||||
|
||||
```go
|
||||
func MyMiddleware() middleware.Middleware {
|
||||
return func(next http.RoundTripper) http.RoundTripper {
|
||||
return middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
// before
|
||||
resp, err := next.RoundTrip(req)
|
||||
// after
|
||||
return resp, err
|
||||
})
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Add a server middleware
|
||||
|
||||
1. Create file in `server/` named `middleware_<name>.go`
|
||||
2. Return `server.Middleware` (`func(http.Handler) http.Handler`)
|
||||
3. Use `server.statusWriter` if you need to capture the response status
|
||||
4. Test with `httptest.NewRecorder` + `httptest.NewRequest`
|
||||
|
||||
```go
|
||||
func MyMiddleware() Middleware {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// before
|
||||
next.ServeHTTP(w, r)
|
||||
// after
|
||||
})
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Add a route
|
||||
|
||||
```go
|
||||
r := server.NewRouter()
|
||||
r.HandleFunc("GET /users/{id}", getUser)
|
||||
r.HandleFunc("POST /users", createUser)
|
||||
|
||||
// Group with prefix + middleware
|
||||
api := r.Group("/api/v1", authMiddleware)
|
||||
api.HandleFunc("GET /items", listItems)
|
||||
|
||||
// Mount sub-handler
|
||||
r.Mount("/health", server.HealthHandler())
|
||||
```
|
||||
|
||||
### Add a functional option
|
||||
|
||||
1. Add field to the options struct (`clientOptions` or `serverOptions`)
|
||||
2. Create `With<Name>` function returning `Option`
|
||||
3. Apply the field in the constructor (`New`)
|
||||
|
||||
## Gotchas
|
||||
|
||||
- **Middleware order matters**: Retry wraps CB+Balancer intentionally — each retry attempt can hit a different endpoint and a different circuit breaker
|
||||
- **Circular imports via `internal/`**: Both `server` and `middleware` packages need request ID context. The shared key lives in `internal/requestid` — do NOT import `server` from `middleware` or vice versa
|
||||
- **Client.Close() is required** when using `WithEndpoints()` — the balancer starts a background health checker goroutine that must be stopped
|
||||
- **GetBody for retries**: Request bodies must be replayable. Use `NewJSONRequest`/`NewFormRequest` (they set `GetBody`) or set it manually
|
||||
- **statusWriter.Unwrap()**: Server middleware must not type-assert `http.ResponseWriter` directly — use `http.ResponseController` which calls `Unwrap()` to find `http.Flusher`, `http.Hijacker`, etc.
|
||||
- **No external deps**: This is a zero-dependency library. Do not add any `require` to `go.mod`
|
||||
|
||||
## Commands
|
||||
|
||||
```bash
|
||||
go build ./... # compile
|
||||
go test ./... # all tests
|
||||
go test -race ./... # tests with race detector
|
||||
go test -v -run TestName ./package/ # single test
|
||||
go vet ./... # static analysis
|
||||
```
|
||||
|
||||
## Conventions
|
||||
|
||||
- **Functional options** for all configuration (client and server)
|
||||
- **stdlib only** testing — no testify, no gomock
|
||||
- **Thread safety** — use `sync.Mutex`, `sync.Map`, or `atomic` where needed
|
||||
- **`internal/clock`** — use for deterministic time in tests (never `time.Now()` directly in testable code)
|
||||
- **Test helpers**: `mockTransport(fn)` wrapping `middleware.RoundTripperFunc` (client), `httptest.NewRecorder`/`httptest.NewRequest` (server), `waitForAddr(t, srv)` for server integration tests
|
||||
- **Sentinel errors** live in sub-packages, root package re-exports as aliases
|
||||
56
CLAUDE.md
Normal file
56
CLAUDE.md
Normal file
@@ -0,0 +1,56 @@
|
||||
# CLAUDE.md — httpx
|
||||
|
||||
## Commands
|
||||
|
||||
```bash
|
||||
go build ./... # compile
|
||||
go test ./... # all tests
|
||||
go test -race ./... # tests with race detector
|
||||
go test -v -run TestName ./package/ # single test
|
||||
go vet ./... # static analysis
|
||||
```
|
||||
|
||||
## Architecture
|
||||
|
||||
- **Module**: `git.codelab.vc/pkg/httpx`, Go 1.24, zero external dependencies
|
||||
|
||||
### Client
|
||||
- **Core pattern**: middleware is `func(http.RoundTripper) http.RoundTripper`
|
||||
- **Chain assembly order** (client.go): Logging → User MW → Retry → CB → Balancer → Transport
|
||||
- Retry wraps CB+Balancer so each attempt can hit a different endpoint
|
||||
- **Circuit breaker** is per-host (`sync.Map` of host → Breaker)
|
||||
- **Sentinel errors**: canonical values live in sub-packages, root package re-exports as aliases
|
||||
- **balancer.Transport** returns `(Middleware, *Closer)` — Closer must be tracked for health checker shutdown
|
||||
- **Client.Close()** stops the health checker goroutine
|
||||
- **Client.Patch()** — PATCH method, same pattern as Put/Post
|
||||
- **NewFormRequest** — form-encoded request builder (`application/x-www-form-urlencoded`) with `GetBody` for retry
|
||||
- **WithMaxResponseBody** — wraps `resp.Body` with `io.LimitedReader` to prevent OOM
|
||||
- **middleware.RequestID()** — propagates request ID from context to outgoing `X-Request-Id` header
|
||||
- **`internal/requestid`** — shared context key used by both `server` and `middleware` packages to avoid circular imports
|
||||
|
||||
### Server (`server/`)
|
||||
- **Core pattern**: middleware is `func(http.Handler) http.Handler`
|
||||
- **Server** wraps `http.Server` with `net.Listener`, graceful shutdown via signal handling, lifecycle hooks
|
||||
- **Router** wraps `http.ServeMux` — supports groups with prefix + middleware inheritance, `Mount` for sub-handlers, `WithNotFoundHandler` for custom 404
|
||||
- **Middleware chain** via `Chain(A, B, C)` — A outermost, C innermost (same as client side)
|
||||
- **statusWriter** wraps `http.ResponseWriter` to capture status; implements `Unwrap()` for `http.ResponseController`
|
||||
- **Defaults()** preset: RequestID → Recovery → Logging + production timeouts
|
||||
- **HealthHandler** exposes `GET /healthz` (liveness) and `GET /readyz` (readiness with pluggable checkers)
|
||||
- **CORS** middleware — preflight OPTIONS handling, `AllowOrigins`, `AllowMethods`, `AllowHeaders`, `ExposeHeaders`, `AllowCredentials`, `MaxAge`
|
||||
- **RateLimit** middleware — per-key token bucket (`sync.Map`), IP from `X-Forwarded-For`, `WithRate`/`WithBurst`/`WithKeyFunc`, uses `internal/clock`
|
||||
- **MaxBodySize** middleware — wraps `r.Body` via `http.MaxBytesReader`
|
||||
- **Timeout** middleware — wraps `http.TimeoutHandler`, returns 503
|
||||
- **WriteJSON** / **WriteError** — JSON response helpers in `server/respond.go`
|
||||
|
||||
## Conventions
|
||||
|
||||
- Functional options for all configuration (client and server)
|
||||
- Test helpers: `mockTransport(fn)` wrapping `middleware.RoundTripperFunc` (client), `httptest.NewRecorder`/`httptest.NewRequest` (server)
|
||||
- Server tests use `waitForAddr(t, srv)` helper to poll until server is ready
|
||||
- No external test frameworks — stdlib only
|
||||
- Thread safety required (`sync.Mutex`/`atomic`)
|
||||
- `internal/clock` for deterministic time testing
|
||||
|
||||
## See also
|
||||
|
||||
- `AGENTS.md` — universal AI agent guide with common tasks, gotchas, and ASCII diagrams
|
||||
211
README.md
211
README.md
@@ -1,2 +1,213 @@
|
||||
# httpx
|
||||
|
||||
HTTP client and server toolkit for Go microservices. Client side: retry, load balancing, circuit breaking, request ID propagation, response size limits — all as `http.RoundTripper` middleware. Server side: routing, middleware (request ID, recovery, logging, CORS, rate limiting, body limits, timeouts), health checks, JSON helpers, graceful shutdown. stdlib only, zero external deps.
|
||||
|
||||
```
|
||||
go get git.codelab.vc/pkg/httpx
|
||||
```
|
||||
|
||||
## Quick start
|
||||
|
||||
```go
|
||||
client := httpx.New(
|
||||
httpx.WithBaseURL("https://api.example.com"),
|
||||
httpx.WithTimeout(10*time.Second),
|
||||
httpx.WithRetry(retry.WithMaxAttempts(3)),
|
||||
httpx.WithMiddleware(
|
||||
middleware.UserAgent("my-service/1.0"),
|
||||
middleware.BearerAuth(func(ctx context.Context) (string, error) {
|
||||
return os.Getenv("API_TOKEN"), nil
|
||||
}),
|
||||
),
|
||||
)
|
||||
defer client.Close()
|
||||
|
||||
resp, err := client.Get(ctx, "/users/123")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
var user User
|
||||
resp.JSON(&user)
|
||||
|
||||
// PATCH request
|
||||
resp, err = client.Patch(ctx, "/users/123", strings.NewReader(`{"name":"updated"}`))
|
||||
|
||||
// Form-encoded request (OAuth, webhooks, etc.)
|
||||
req, _ := httpx.NewFormRequest(ctx, http.MethodPost, "/oauth/token", url.Values{
|
||||
"grant_type": {"client_credentials"},
|
||||
"scope": {"read write"},
|
||||
})
|
||||
resp, err = client.Do(ctx, req)
|
||||
```
|
||||
|
||||
## Packages
|
||||
|
||||
### Client
|
||||
|
||||
Client middleware is `func(http.RoundTripper) http.RoundTripper`. Use them with `httpx.Client` or plug into a plain `http.Client`.
|
||||
|
||||
| Package | What it does |
|
||||
|---------|-------------|
|
||||
| `retry` | Exponential/constant backoff, Retry-After support. Idempotent methods only by default. |
|
||||
| `balancer` | Round robin, failover, weighted random. Optional background health checks. |
|
||||
| `circuitbreaker` | Per-host state machine (closed/open/half-open). Stops hammering dead endpoints. |
|
||||
| `middleware` | Logging (slog), default headers, bearer/basic auth, panic recovery, request ID propagation. |
|
||||
|
||||
### Server
|
||||
|
||||
Server middleware is `func(http.Handler) http.Handler`. The `server` package provides a production-ready HTTP server.
|
||||
|
||||
| Component | What it does |
|
||||
|-----------|-------------|
|
||||
| `server.Server` | Wraps `http.Server` with graceful shutdown, signal handling, lifecycle logging. |
|
||||
| `server.Router` | Lightweight wrapper around `http.ServeMux` with groups, prefix routing, sub-router mounting. |
|
||||
| `server.RequestID` | Assigns/propagates `X-Request-Id` (UUID v4 via `crypto/rand`). |
|
||||
| `server.Recovery` | Recovers panics, returns 500, logs stack trace. |
|
||||
| `server.Logging` | Structured request logging (method, path, status, duration, request ID). |
|
||||
| `server.HealthHandler` | Liveness (`/healthz`) and readiness (`/readyz`) endpoints with pluggable checkers. |
|
||||
| `server.CORS` | Cross-origin resource sharing with preflight handling and functional options. |
|
||||
| `server.RateLimit` | Per-key token bucket rate limiting with IP extraction and `Retry-After`. |
|
||||
| `server.MaxBodySize` | Limits request body size via `http.MaxBytesReader`. |
|
||||
| `server.Timeout` | Context-based request timeout, returns 503 on expiry. |
|
||||
| `server.WriteJSON` | JSON response helper, sets Content-Type and status. |
|
||||
| `server.WriteError` | JSON error response (`{"error": "..."}`) helper. |
|
||||
| `server.Defaults` | Production preset: RequestID → Recovery → Logging + sensible timeouts. |
|
||||
|
||||
The client assembles them in this order:
|
||||
|
||||
```
|
||||
Request → Logging → Your Middleware → Retry → Circuit Breaker → Balancer → Transport
|
||||
```
|
||||
|
||||
Retry wraps the circuit breaker and balancer, so each attempt can pick a different endpoint.
|
||||
|
||||
## Multi-DC setup
|
||||
|
||||
```go
|
||||
client := httpx.New(
|
||||
httpx.WithEndpoints(
|
||||
balancer.Endpoint{URL: "https://dc1.api.internal", Weight: 3},
|
||||
balancer.Endpoint{URL: "https://dc2.api.internal", Weight: 1},
|
||||
),
|
||||
httpx.WithBalancer(balancer.WithStrategy(balancer.WeightedRandom())),
|
||||
httpx.WithRetry(retry.WithMaxAttempts(4)),
|
||||
httpx.WithCircuitBreaker(circuitbreaker.WithFailureThreshold(5)),
|
||||
httpx.WithLogger(slog.Default()),
|
||||
)
|
||||
defer client.Close()
|
||||
```
|
||||
|
||||
## Standalone usage
|
||||
|
||||
Each component works with any `http.Client`, no need for the full wrapper:
|
||||
|
||||
```go
|
||||
// Just retry, nothing else
|
||||
transport := retry.Transport(retry.WithMaxAttempts(3))
|
||||
httpClient := &http.Client{
|
||||
Transport: transport(http.DefaultTransport),
|
||||
}
|
||||
```
|
||||
|
||||
```go
|
||||
// Chain a few middlewares together
|
||||
chain := middleware.Chain(
|
||||
middleware.Logging(slog.Default()),
|
||||
middleware.UserAgent("my-service/1.0"),
|
||||
retry.Transport(retry.WithMaxAttempts(2)),
|
||||
)
|
||||
httpClient := &http.Client{
|
||||
Transport: chain(http.DefaultTransport),
|
||||
}
|
||||
```
|
||||
|
||||
## Server
|
||||
|
||||
```go
|
||||
logger := slog.Default()
|
||||
|
||||
r := server.NewRouter(
|
||||
// Custom JSON 404 instead of plain text
|
||||
server.WithNotFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
server.WriteError(w, 404, "not found")
|
||||
})),
|
||||
)
|
||||
|
||||
r.HandleFunc("GET /hello", func(w http.ResponseWriter, r *http.Request) {
|
||||
server.WriteJSON(w, 200, map[string]string{"message": "world"})
|
||||
})
|
||||
|
||||
// Groups with middleware
|
||||
api := r.Group("/api/v1", authMiddleware)
|
||||
api.HandleFunc("GET /users/{id}", getUser)
|
||||
|
||||
// Health checks
|
||||
r.Mount("/", server.HealthHandler(
|
||||
func() error { return db.Ping() },
|
||||
))
|
||||
|
||||
srv := server.New(r,
|
||||
append(server.Defaults(logger),
|
||||
// Protection middleware
|
||||
server.WithMiddleware(
|
||||
server.CORS(
|
||||
server.AllowOrigins("https://app.example.com"),
|
||||
server.AllowMethods("GET", "POST", "PUT", "PATCH", "DELETE"),
|
||||
server.AllowHeaders("Authorization", "Content-Type"),
|
||||
server.MaxAge(3600),
|
||||
),
|
||||
server.RateLimit(
|
||||
server.WithRate(100),
|
||||
server.WithBurst(200),
|
||||
),
|
||||
server.MaxBodySize(1<<20), // 1 MB
|
||||
server.Timeout(30*time.Second),
|
||||
),
|
||||
)...,
|
||||
)
|
||||
log.Fatal(srv.ListenAndServe()) // graceful shutdown on SIGINT/SIGTERM
|
||||
```
|
||||
|
||||
## Client request ID propagation
|
||||
|
||||
In microservices, forward the incoming request ID to downstream calls:
|
||||
|
||||
```go
|
||||
client := httpx.New(
|
||||
httpx.WithMiddleware(middleware.RequestID()),
|
||||
)
|
||||
|
||||
// In a server handler — the context already has the request ID from server.RequestID():
|
||||
func handler(w http.ResponseWriter, r *http.Request) {
|
||||
// ID is automatically forwarded as X-Request-Id
|
||||
resp, err := client.Get(r.Context(), "https://downstream/api")
|
||||
}
|
||||
```
|
||||
|
||||
## Response body limit
|
||||
|
||||
Protect against OOM from unexpectedly large upstream responses:
|
||||
|
||||
```go
|
||||
client := httpx.New(
|
||||
httpx.WithMaxResponseBody(10 << 20), // 10 MB max
|
||||
)
|
||||
```
|
||||
|
||||
## Examples
|
||||
|
||||
See the [`examples/`](examples/) directory for runnable programs:
|
||||
|
||||
| Example | Description |
|
||||
|---------|-------------|
|
||||
| [`basic-client`](examples/basic-client/) | HTTP client with retry, timeout, logging, and response size limit |
|
||||
| [`form-request`](examples/form-request/) | Form-encoded POST requests (OAuth, webhooks) |
|
||||
| [`load-balancing`](examples/load-balancing/) | Multi-endpoint client with weighted balancing, circuit breaker, and health checks |
|
||||
| [`server-basic`](examples/server-basic/) | Server with routing, groups, JSON helpers, health checks, and custom 404 |
|
||||
| [`server-protected`](examples/server-protected/) | Production server with CORS, rate limiting, body limits, and timeouts |
|
||||
| [`request-id-propagation`](examples/request-id-propagation/) | Request ID forwarding between server and client for distributed tracing |
|
||||
|
||||
## Requirements
|
||||
|
||||
Go 1.24+, stdlib only.
|
||||
|
||||
104
balancer/balancer.go
Normal file
104
balancer/balancer.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package balancer
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"git.codelab.vc/pkg/httpx/middleware"
|
||||
)
|
||||
|
||||
// ErrNoHealthy is returned when no healthy endpoints are available.
|
||||
var ErrNoHealthy = errors.New("httpx: no healthy endpoints available")
|
||||
|
||||
// Endpoint represents a backend server that can handle requests.
|
||||
type Endpoint struct {
|
||||
URL string
|
||||
Weight int
|
||||
Meta map[string]string
|
||||
}
|
||||
|
||||
// Strategy selects an endpoint from the list of healthy endpoints.
|
||||
type Strategy interface {
|
||||
Next(healthy []Endpoint) (Endpoint, error)
|
||||
}
|
||||
|
||||
// Closer can be used to shut down resources associated with a balancer
|
||||
// transport (e.g. background health checker goroutines).
|
||||
type Closer struct {
|
||||
healthChecker *HealthChecker
|
||||
}
|
||||
|
||||
// Close stops background goroutines. Safe to call multiple times.
|
||||
func (c *Closer) Close() {
|
||||
if c.healthChecker != nil {
|
||||
c.healthChecker.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
// Transport returns a middleware that load-balances requests across the
|
||||
// provided endpoints using the configured strategy.
|
||||
//
|
||||
// For each request the middleware picks an endpoint via the strategy,
|
||||
// replaces the request URL scheme and host with the endpoint's URL,
|
||||
// and forwards the request to the underlying RoundTripper.
|
||||
//
|
||||
// If active health checking is enabled (WithHealthCheck), a background
|
||||
// goroutine periodically probes endpoints. Otherwise all endpoints are
|
||||
// assumed healthy.
|
||||
func Transport(endpoints []Endpoint, opts ...Option) (middleware.Middleware, *Closer) {
|
||||
o := &options{
|
||||
strategy: RoundRobin(),
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(o)
|
||||
}
|
||||
|
||||
// Pre-parse endpoint URLs once at construction time.
|
||||
parsed := make(map[string]*url.URL, len(endpoints))
|
||||
for _, ep := range endpoints {
|
||||
u, err := url.Parse(ep.URL)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("balancer: invalid endpoint URL %q: %v", ep.URL, err))
|
||||
}
|
||||
parsed[ep.URL] = u
|
||||
}
|
||||
|
||||
if o.healthChecker != nil {
|
||||
o.healthChecker.Start(endpoints)
|
||||
}
|
||||
|
||||
closer := &Closer{healthChecker: o.healthChecker}
|
||||
|
||||
return func(next http.RoundTripper) http.RoundTripper {
|
||||
return middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
healthy := endpoints
|
||||
if o.healthChecker != nil {
|
||||
healthy = o.healthChecker.Healthy(endpoints)
|
||||
}
|
||||
|
||||
if len(healthy) == 0 {
|
||||
return nil, ErrNoHealthy
|
||||
}
|
||||
|
||||
ep, err := o.strategy.Next(healthy)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
epURL := parsed[ep.URL]
|
||||
|
||||
// Shallow-copy request and URL to avoid mutating the original,
|
||||
// without the expense of req.Clone's deep header copy.
|
||||
r := *req
|
||||
u := *req.URL
|
||||
r.URL = &u
|
||||
r.URL.Scheme = epURL.Scheme
|
||||
r.URL.Host = epURL.Host
|
||||
r.Host = epURL.Host
|
||||
|
||||
return next.RoundTrip(&r)
|
||||
})
|
||||
}, closer
|
||||
}
|
||||
214
balancer/balancer_test.go
Normal file
214
balancer/balancer_test.go
Normal file
@@ -0,0 +1,214 @@
|
||||
package balancer
|
||||
|
||||
import (
|
||||
"io"
|
||||
"math"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"git.codelab.vc/pkg/httpx/middleware"
|
||||
)
|
||||
|
||||
func mockTransport(fn func(*http.Request) (*http.Response, error)) http.RoundTripper {
|
||||
return middleware.RoundTripperFunc(fn)
|
||||
}
|
||||
|
||||
func okResponse() *http.Response {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader("")),
|
||||
Header: make(http.Header),
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransport_PicksEndpointAndReplacesURL(t *testing.T) {
|
||||
endpoints := []Endpoint{
|
||||
{URL: "https://backend1.example.com"},
|
||||
}
|
||||
|
||||
var captured *http.Request
|
||||
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||
captured = req
|
||||
return okResponse(), nil
|
||||
})
|
||||
|
||||
mw, _ := Transport(endpoints)
|
||||
rt := mw(base)
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, "https://original.example.com/api/v1/users", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
resp, err := rt.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if captured == nil {
|
||||
t.Fatal("base transport was not called")
|
||||
}
|
||||
if captured.URL.Scheme != "https" {
|
||||
t.Errorf("scheme = %q, want %q", captured.URL.Scheme, "https")
|
||||
}
|
||||
if captured.URL.Host != "backend1.example.com" {
|
||||
t.Errorf("host = %q, want %q", captured.URL.Host, "backend1.example.com")
|
||||
}
|
||||
if captured.URL.Path != "/api/v1/users" {
|
||||
t.Errorf("path = %q, want %q", captured.URL.Path, "/api/v1/users")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransport_ErrNoHealthyWhenNoEndpoints(t *testing.T) {
|
||||
var endpoints []Endpoint
|
||||
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||
t.Fatal("base transport should not be called")
|
||||
return nil, nil
|
||||
})
|
||||
|
||||
mw, _ := Transport(endpoints)
|
||||
rt := mw(base)
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, "https://example.com/test", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = rt.RoundTrip(req)
|
||||
if err != ErrNoHealthy {
|
||||
t.Fatalf("err = %v, want %v", err, ErrNoHealthy)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoundRobin_DistributesEvenly(t *testing.T) {
|
||||
endpoints := []Endpoint{
|
||||
{URL: "https://a.example.com"},
|
||||
{URL: "https://b.example.com"},
|
||||
{URL: "https://c.example.com"},
|
||||
}
|
||||
|
||||
rr := RoundRobin()
|
||||
counts := make(map[string]int)
|
||||
|
||||
const iterations = 300
|
||||
for i := 0; i < iterations; i++ {
|
||||
ep, err := rr.Next(endpoints)
|
||||
if err != nil {
|
||||
t.Fatalf("iteration %d: unexpected error: %v", i, err)
|
||||
}
|
||||
counts[ep.URL]++
|
||||
}
|
||||
|
||||
expected := iterations / len(endpoints)
|
||||
for _, ep := range endpoints {
|
||||
got := counts[ep.URL]
|
||||
if got != expected {
|
||||
t.Errorf("endpoint %s: got %d calls, want %d", ep.URL, got, expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoundRobin_ErrNoHealthy(t *testing.T) {
|
||||
rr := RoundRobin()
|
||||
_, err := rr.Next(nil)
|
||||
if err != ErrNoHealthy {
|
||||
t.Fatalf("err = %v, want %v", err, ErrNoHealthy)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFailover_AlwaysPicksFirst(t *testing.T) {
|
||||
endpoints := []Endpoint{
|
||||
{URL: "https://primary.example.com"},
|
||||
{URL: "https://secondary.example.com"},
|
||||
{URL: "https://tertiary.example.com"},
|
||||
}
|
||||
|
||||
fo := Failover()
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
ep, err := fo.Next(endpoints)
|
||||
if err != nil {
|
||||
t.Fatalf("iteration %d: unexpected error: %v", i, err)
|
||||
}
|
||||
if ep.URL != "https://primary.example.com" {
|
||||
t.Errorf("iteration %d: got %q, want %q", i, ep.URL, "https://primary.example.com")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFailover_ErrNoHealthy(t *testing.T) {
|
||||
fo := Failover()
|
||||
_, err := fo.Next(nil)
|
||||
if err != ErrNoHealthy {
|
||||
t.Fatalf("err = %v, want %v", err, ErrNoHealthy)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWeightedRandom_RespectsWeights(t *testing.T) {
|
||||
endpoints := []Endpoint{
|
||||
{URL: "https://heavy.example.com", Weight: 80},
|
||||
{URL: "https://light.example.com", Weight: 20},
|
||||
}
|
||||
|
||||
wr := WeightedRandom()
|
||||
counts := make(map[string]int)
|
||||
|
||||
const iterations = 10000
|
||||
for i := 0; i < iterations; i++ {
|
||||
ep, err := wr.Next(endpoints)
|
||||
if err != nil {
|
||||
t.Fatalf("iteration %d: unexpected error: %v", i, err)
|
||||
}
|
||||
counts[ep.URL]++
|
||||
}
|
||||
|
||||
totalWeight := 0
|
||||
for _, ep := range endpoints {
|
||||
totalWeight += ep.Weight
|
||||
}
|
||||
|
||||
for _, ep := range endpoints {
|
||||
got := float64(counts[ep.URL]) / float64(iterations)
|
||||
want := float64(ep.Weight) / float64(totalWeight)
|
||||
if math.Abs(got-want) > 0.05 {
|
||||
t.Errorf("endpoint %s: got ratio %.3f, want ~%.3f (tolerance 0.05)", ep.URL, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWeightedRandom_DefaultWeightForZero(t *testing.T) {
|
||||
endpoints := []Endpoint{
|
||||
{URL: "https://a.example.com", Weight: 0},
|
||||
{URL: "https://b.example.com", Weight: 0},
|
||||
}
|
||||
|
||||
wr := WeightedRandom()
|
||||
counts := make(map[string]int)
|
||||
|
||||
const iterations = 1000
|
||||
for i := 0; i < iterations; i++ {
|
||||
ep, err := wr.Next(endpoints)
|
||||
if err != nil {
|
||||
t.Fatalf("iteration %d: unexpected error: %v", i, err)
|
||||
}
|
||||
counts[ep.URL]++
|
||||
}
|
||||
|
||||
// With equal default weights, distribution should be roughly even.
|
||||
for _, ep := range endpoints {
|
||||
got := float64(counts[ep.URL]) / float64(iterations)
|
||||
if math.Abs(got-0.5) > 0.1 {
|
||||
t.Errorf("endpoint %s: got ratio %.3f, want ~0.5 (tolerance 0.1)", ep.URL, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWeightedRandom_ErrNoHealthy(t *testing.T) {
|
||||
wr := WeightedRandom()
|
||||
_, err := wr.Next(nil)
|
||||
if err != ErrNoHealthy {
|
||||
t.Fatalf("err = %v, want %v", err, ErrNoHealthy)
|
||||
}
|
||||
}
|
||||
34
balancer/doc.go
Normal file
34
balancer/doc.go
Normal file
@@ -0,0 +1,34 @@
|
||||
// Package balancer provides client-side load balancing as HTTP middleware.
|
||||
//
|
||||
// It distributes requests across multiple backend endpoints using pluggable
|
||||
// strategies (round-robin, weighted, failover) with optional health checking.
|
||||
//
|
||||
// # Usage
|
||||
//
|
||||
// mw, closer := balancer.Transport(
|
||||
// []balancer.Endpoint{
|
||||
// {URL: "http://backend1:8080"},
|
||||
// {URL: "http://backend2:8080"},
|
||||
// },
|
||||
// balancer.WithStrategy(balancer.RoundRobin()),
|
||||
// balancer.WithHealthCheck(5 * time.Second),
|
||||
// )
|
||||
// defer closer.Close()
|
||||
// transport := mw(http.DefaultTransport)
|
||||
//
|
||||
// # Strategies
|
||||
//
|
||||
// - RoundRobin — cycles through healthy endpoints
|
||||
// - Weighted — distributes based on endpoint Weight field
|
||||
// - Failover — prefers primary, falls back to secondaries
|
||||
//
|
||||
// # Health checking
|
||||
//
|
||||
// When enabled, a background goroutine periodically probes each endpoint.
|
||||
// The returned Closer must be closed to stop the health checker goroutine.
|
||||
// In httpx.Client, this is handled by Client.Close().
|
||||
//
|
||||
// # Sentinel errors
|
||||
//
|
||||
// ErrNoHealthy is returned when no healthy endpoints are available.
|
||||
package balancer
|
||||
17
balancer/failover.go
Normal file
17
balancer/failover.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package balancer
|
||||
|
||||
type failover struct{}
|
||||
|
||||
// Failover returns a strategy that always picks the first healthy endpoint.
|
||||
// If the primary endpoint is unhealthy, it falls back to the next available
|
||||
// healthy endpoint in order.
|
||||
func Failover() Strategy {
|
||||
return &failover{}
|
||||
}
|
||||
|
||||
func (f *failover) Next(healthy []Endpoint) (Endpoint, error) {
|
||||
if len(healthy) == 0 {
|
||||
return Endpoint{}, ErrNoHealthy
|
||||
}
|
||||
return healthy[0], nil
|
||||
}
|
||||
174
balancer/health.go
Normal file
174
balancer/health.go
Normal file
@@ -0,0 +1,174 @@
|
||||
package balancer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultHealthInterval = 10 * time.Second
|
||||
defaultHealthPath = "/health"
|
||||
defaultHealthTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
// HealthOption configures the HealthChecker.
|
||||
type HealthOption func(*HealthChecker)
|
||||
|
||||
// WithHealthInterval sets the interval between health check probes.
|
||||
// Default is 10 seconds.
|
||||
func WithHealthInterval(d time.Duration) HealthOption {
|
||||
return func(h *HealthChecker) {
|
||||
h.interval = d
|
||||
}
|
||||
}
|
||||
|
||||
// WithHealthPath sets the HTTP path to probe for health checks.
|
||||
// Default is "/health".
|
||||
func WithHealthPath(path string) HealthOption {
|
||||
return func(h *HealthChecker) {
|
||||
h.path = path
|
||||
}
|
||||
}
|
||||
|
||||
// WithHealthTimeout sets the timeout for each health check request.
|
||||
// Default is 5 seconds.
|
||||
func WithHealthTimeout(d time.Duration) HealthOption {
|
||||
return func(h *HealthChecker) {
|
||||
h.timeout = d
|
||||
}
|
||||
}
|
||||
|
||||
// HealthChecker periodically probes endpoints to determine their health status.
|
||||
type HealthChecker struct {
|
||||
interval time.Duration
|
||||
path string
|
||||
timeout time.Duration
|
||||
client *http.Client
|
||||
|
||||
mu sync.RWMutex
|
||||
status map[string]bool
|
||||
cancel context.CancelFunc
|
||||
stopped chan struct{}
|
||||
}
|
||||
|
||||
func newHealthChecker(opts ...HealthOption) *HealthChecker {
|
||||
h := &HealthChecker{
|
||||
interval: defaultHealthInterval,
|
||||
path: defaultHealthPath,
|
||||
timeout: defaultHealthTimeout,
|
||||
status: make(map[string]bool),
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(h)
|
||||
}
|
||||
h.client = &http.Client{
|
||||
Timeout: h.timeout,
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
// Start begins the background health checking loop for the given endpoints.
|
||||
// An initial probe is run synchronously so that unhealthy endpoints are
|
||||
// detected before the first request.
|
||||
func (h *HealthChecker) Start(endpoints []Endpoint) {
|
||||
// Mark all healthy as a safe default, then immediately probe.
|
||||
h.mu.Lock()
|
||||
for _, ep := range endpoints {
|
||||
h.status[ep.URL] = true
|
||||
}
|
||||
h.mu.Unlock()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
h.cancel = cancel
|
||||
h.stopped = make(chan struct{})
|
||||
|
||||
// Run initial probe synchronously so callers don't hit stale state.
|
||||
h.probe(ctx, endpoints)
|
||||
|
||||
go h.loop(ctx, endpoints)
|
||||
}
|
||||
|
||||
// Stop terminates the background health checking goroutine and waits for
|
||||
// it to finish.
|
||||
func (h *HealthChecker) Stop() {
|
||||
if h.cancel != nil {
|
||||
h.cancel()
|
||||
<-h.stopped
|
||||
}
|
||||
}
|
||||
|
||||
// IsHealthy reports whether the given endpoint is currently healthy.
|
||||
func (h *HealthChecker) IsHealthy(ep Endpoint) bool {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
healthy, ok := h.status[ep.URL]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return healthy
|
||||
}
|
||||
|
||||
// Healthy returns the subset of endpoints that are currently healthy.
|
||||
func (h *HealthChecker) Healthy(endpoints []Endpoint) []Endpoint {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
result := make([]Endpoint, 0, len(endpoints))
|
||||
for _, ep := range endpoints {
|
||||
if h.status[ep.URL] {
|
||||
result = append(result, ep)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (h *HealthChecker) loop(ctx context.Context, endpoints []Endpoint) {
|
||||
defer close(h.stopped)
|
||||
|
||||
ticker := time.NewTicker(h.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
h.probe(ctx, endpoints)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *HealthChecker) probe(ctx context.Context, endpoints []Endpoint) {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(len(endpoints))
|
||||
for _, ep := range endpoints {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
healthy := h.check(ctx, ep)
|
||||
h.mu.Lock()
|
||||
h.status[ep.URL] = healthy
|
||||
h.mu.Unlock()
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (h *HealthChecker) check(ctx context.Context, ep Endpoint) bool {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, ep.URL+h.path, nil)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
resp, err := h.client.Do(req)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
io.Copy(io.Discard, resp.Body)
|
||||
resp.Body.Close()
|
||||
|
||||
return resp.StatusCode >= 200 && resp.StatusCode < 300
|
||||
}
|
||||
25
balancer/options.go
Normal file
25
balancer/options.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package balancer
|
||||
|
||||
// options holds configuration for the load balancer transport.
|
||||
type options struct {
|
||||
strategy Strategy // default RoundRobin
|
||||
healthChecker *HealthChecker // optional
|
||||
}
|
||||
|
||||
// Option configures the load balancer transport.
|
||||
type Option func(*options)
|
||||
|
||||
// WithStrategy sets the endpoint selection strategy.
|
||||
// If not specified, RoundRobin is used.
|
||||
func WithStrategy(s Strategy) Option {
|
||||
return func(o *options) {
|
||||
o.strategy = s
|
||||
}
|
||||
}
|
||||
|
||||
// WithHealthCheck enables active health checking of endpoints.
|
||||
func WithHealthCheck(opts ...HealthOption) Option {
|
||||
return func(o *options) {
|
||||
o.healthChecker = newHealthChecker(opts...)
|
||||
}
|
||||
}
|
||||
21
balancer/roundrobin.go
Normal file
21
balancer/roundrobin.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package balancer
|
||||
|
||||
import "sync/atomic"
|
||||
|
||||
type roundRobin struct {
|
||||
counter atomic.Uint64
|
||||
}
|
||||
|
||||
// RoundRobin returns a strategy that cycles through healthy endpoints
|
||||
// sequentially using an atomic counter.
|
||||
func RoundRobin() Strategy {
|
||||
return &roundRobin{}
|
||||
}
|
||||
|
||||
func (r *roundRobin) Next(healthy []Endpoint) (Endpoint, error) {
|
||||
if len(healthy) == 0 {
|
||||
return Endpoint{}, ErrNoHealthy
|
||||
}
|
||||
idx := r.counter.Add(1) - 1
|
||||
return healthy[idx%uint64(len(healthy))], nil
|
||||
}
|
||||
42
balancer/weighted.go
Normal file
42
balancer/weighted.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package balancer
|
||||
|
||||
import "math/rand/v2"
|
||||
|
||||
type weightedRandom struct{}
|
||||
|
||||
// WeightedRandom returns a strategy that selects endpoints randomly,
|
||||
// weighted by each endpoint's Weight field. Endpoints with Weight <= 0
|
||||
// are treated as having a weight of 1.
|
||||
func WeightedRandom() Strategy {
|
||||
return &weightedRandom{}
|
||||
}
|
||||
|
||||
func (w *weightedRandom) Next(healthy []Endpoint) (Endpoint, error) {
|
||||
if len(healthy) == 0 {
|
||||
return Endpoint{}, ErrNoHealthy
|
||||
}
|
||||
|
||||
totalWeight := 0
|
||||
for _, ep := range healthy {
|
||||
weight := ep.Weight
|
||||
if weight <= 0 {
|
||||
weight = 1
|
||||
}
|
||||
totalWeight += weight
|
||||
}
|
||||
|
||||
r := rand.IntN(totalWeight)
|
||||
for _, ep := range healthy {
|
||||
weight := ep.Weight
|
||||
if weight <= 0 {
|
||||
weight = 1
|
||||
}
|
||||
r -= weight
|
||||
if r < 0 {
|
||||
return ep, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Should never reach here, but return last endpoint as a safeguard.
|
||||
return healthy[len(healthy)-1], nil
|
||||
}
|
||||
176
circuitbreaker/breaker.go
Normal file
176
circuitbreaker/breaker.go
Normal file
@@ -0,0 +1,176 @@
|
||||
package circuitbreaker
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.codelab.vc/pkg/httpx/middleware"
|
||||
)
|
||||
|
||||
// ErrCircuitOpen is returned by Allow when the breaker is in the Open state.
|
||||
var ErrCircuitOpen = errors.New("httpx: circuit breaker is open")
|
||||
|
||||
// State represents the current state of a circuit breaker.
|
||||
type State int
|
||||
|
||||
const (
|
||||
StateClosed State = iota // normal operation
|
||||
StateOpen // failing, reject requests
|
||||
StateHalfOpen // testing recovery
|
||||
)
|
||||
|
||||
// String returns a human-readable name for the state.
|
||||
func (s State) String() string {
|
||||
switch s {
|
||||
case StateClosed:
|
||||
return "closed"
|
||||
case StateOpen:
|
||||
return "open"
|
||||
case StateHalfOpen:
|
||||
return "half-open"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// Breaker implements a per-endpoint circuit breaker state machine.
|
||||
//
|
||||
// State transitions:
|
||||
//
|
||||
// Closed → Open: after failureThreshold consecutive failures
|
||||
// Open → HalfOpen: after openDuration passes
|
||||
// HalfOpen → Closed: on success
|
||||
// HalfOpen → Open: on failure (timer resets)
|
||||
type Breaker struct {
|
||||
mu sync.Mutex
|
||||
opts options
|
||||
|
||||
state State
|
||||
failures int // consecutive failure count (Closed state)
|
||||
openedAt time.Time
|
||||
halfOpenCur int // current in-flight half-open requests
|
||||
}
|
||||
|
||||
// NewBreaker creates a Breaker with the given options.
|
||||
func NewBreaker(opts ...Option) *Breaker {
|
||||
o := defaults()
|
||||
for _, fn := range opts {
|
||||
fn(&o)
|
||||
}
|
||||
return &Breaker{opts: o}
|
||||
}
|
||||
|
||||
// State returns the current state of the breaker.
|
||||
func (b *Breaker) State() State {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
return b.stateLocked()
|
||||
}
|
||||
|
||||
// stateLocked returns the effective state, promoting Open → HalfOpen when the
|
||||
// open duration has elapsed. Caller must hold b.mu.
|
||||
func (b *Breaker) stateLocked() State {
|
||||
if b.state == StateOpen && time.Since(b.openedAt) >= b.opts.openDuration {
|
||||
b.state = StateHalfOpen
|
||||
b.halfOpenCur = 0
|
||||
}
|
||||
return b.state
|
||||
}
|
||||
|
||||
// Allow checks whether a request is permitted. If allowed it returns a done
|
||||
// callback that the caller MUST invoke with the result of the request. If the
|
||||
// breaker is open, it returns ErrCircuitOpen.
|
||||
func (b *Breaker) Allow() (done func(success bool), err error) {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
switch b.stateLocked() {
|
||||
case StateClosed:
|
||||
// always allow
|
||||
case StateOpen:
|
||||
return nil, ErrCircuitOpen
|
||||
case StateHalfOpen:
|
||||
if b.halfOpenCur >= b.opts.halfOpenMax {
|
||||
return nil, ErrCircuitOpen
|
||||
}
|
||||
b.halfOpenCur++
|
||||
}
|
||||
|
||||
return b.doneFunc(), nil
|
||||
}
|
||||
|
||||
// doneFunc returns the callback for a single in-flight request. Caller must
|
||||
// hold b.mu when calling doneFunc, but the returned function acquires the lock
|
||||
// itself.
|
||||
func (b *Breaker) doneFunc() func(success bool) {
|
||||
var once sync.Once
|
||||
return func(success bool) {
|
||||
once.Do(func() {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
b.record(success)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// record processes the outcome of a single request. Caller must hold b.mu.
|
||||
func (b *Breaker) record(success bool) {
|
||||
switch b.state {
|
||||
case StateClosed:
|
||||
if success {
|
||||
b.failures = 0
|
||||
return
|
||||
}
|
||||
b.failures++
|
||||
if b.failures >= b.opts.failureThreshold {
|
||||
b.tripLocked()
|
||||
}
|
||||
|
||||
case StateHalfOpen:
|
||||
b.halfOpenCur--
|
||||
if success {
|
||||
b.state = StateClosed
|
||||
b.failures = 0
|
||||
} else {
|
||||
b.tripLocked()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// tripLocked transitions to the Open state and records the timestamp.
|
||||
func (b *Breaker) tripLocked() {
|
||||
b.state = StateOpen
|
||||
b.openedAt = time.Now()
|
||||
b.halfOpenCur = 0
|
||||
}
|
||||
|
||||
// Transport returns a middleware that applies per-host circuit breaking. It
|
||||
// maintains an internal map of host → *Breaker so each target host is tracked
|
||||
// independently.
|
||||
func Transport(opts ...Option) middleware.Middleware {
|
||||
var hosts sync.Map // map[string]*Breaker
|
||||
|
||||
return func(next http.RoundTripper) http.RoundTripper {
|
||||
return middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
host := req.URL.Host
|
||||
|
||||
val, ok := hosts.Load(host)
|
||||
if !ok {
|
||||
val, _ = hosts.LoadOrStore(host, NewBreaker(opts...))
|
||||
}
|
||||
cb := val.(*Breaker)
|
||||
|
||||
done, err := cb.Allow()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, rtErr := next.RoundTrip(req)
|
||||
done(rtErr == nil && resp != nil && resp.StatusCode < 500)
|
||||
|
||||
return resp, rtErr
|
||||
})
|
||||
}
|
||||
}
|
||||
249
circuitbreaker/breaker_test.go
Normal file
249
circuitbreaker/breaker_test.go
Normal file
@@ -0,0 +1,249 @@
|
||||
package circuitbreaker
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.codelab.vc/pkg/httpx/middleware"
|
||||
)
|
||||
|
||||
func mockTransport(fn func(*http.Request) (*http.Response, error)) http.RoundTripper {
|
||||
return middleware.RoundTripperFunc(fn)
|
||||
}
|
||||
|
||||
func okResponse() *http.Response {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader("")),
|
||||
Header: make(http.Header),
|
||||
}
|
||||
}
|
||||
|
||||
func errResponse(code int) *http.Response {
|
||||
return &http.Response{
|
||||
StatusCode: code,
|
||||
Body: io.NopCloser(strings.NewReader("")),
|
||||
Header: make(http.Header),
|
||||
}
|
||||
}
|
||||
|
||||
func TestBreaker_StartsInClosedState(t *testing.T) {
|
||||
b := NewBreaker()
|
||||
if s := b.State(); s != StateClosed {
|
||||
t.Fatalf("state = %v, want %v", s, StateClosed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBreaker_TransitionsToOpenAfterThreshold(t *testing.T) {
|
||||
const threshold = 3
|
||||
b := NewBreaker(
|
||||
WithFailureThreshold(threshold),
|
||||
WithOpenDuration(time.Hour), // long duration so it stays open
|
||||
)
|
||||
|
||||
for i := 0; i < threshold; i++ {
|
||||
done, err := b.Allow()
|
||||
if err != nil {
|
||||
t.Fatalf("iteration %d: Allow returned error: %v", i, err)
|
||||
}
|
||||
done(false)
|
||||
}
|
||||
|
||||
if s := b.State(); s != StateOpen {
|
||||
t.Fatalf("state = %v, want %v", s, StateOpen)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBreaker_OpenRejectsRequests(t *testing.T) {
|
||||
b := NewBreaker(
|
||||
WithFailureThreshold(1),
|
||||
WithOpenDuration(time.Hour),
|
||||
)
|
||||
|
||||
// Trip the breaker.
|
||||
done, err := b.Allow()
|
||||
if err != nil {
|
||||
t.Fatalf("Allow returned error: %v", err)
|
||||
}
|
||||
done(false)
|
||||
|
||||
// Subsequent requests should be rejected.
|
||||
_, err = b.Allow()
|
||||
if !errors.Is(err, ErrCircuitOpen) {
|
||||
t.Fatalf("err = %v, want %v", err, ErrCircuitOpen)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBreaker_TransitionsToHalfOpenAfterDuration(t *testing.T) {
|
||||
const openDuration = 50 * time.Millisecond
|
||||
b := NewBreaker(
|
||||
WithFailureThreshold(1),
|
||||
WithOpenDuration(openDuration),
|
||||
)
|
||||
|
||||
// Trip the breaker.
|
||||
done, err := b.Allow()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
done(false)
|
||||
|
||||
if s := b.State(); s != StateOpen {
|
||||
t.Fatalf("state = %v, want %v", s, StateOpen)
|
||||
}
|
||||
|
||||
// Wait for the open duration to elapse.
|
||||
time.Sleep(openDuration + 10*time.Millisecond)
|
||||
|
||||
if s := b.State(); s != StateHalfOpen {
|
||||
t.Fatalf("state = %v, want %v", s, StateHalfOpen)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBreaker_HalfOpenToClosedOnSuccess(t *testing.T) {
|
||||
const openDuration = 50 * time.Millisecond
|
||||
b := NewBreaker(
|
||||
WithFailureThreshold(1),
|
||||
WithOpenDuration(openDuration),
|
||||
)
|
||||
|
||||
// Trip the breaker.
|
||||
done, err := b.Allow()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
done(false)
|
||||
|
||||
// Wait for half-open.
|
||||
time.Sleep(openDuration + 10*time.Millisecond)
|
||||
|
||||
// A successful request in half-open should close the breaker.
|
||||
done, err = b.Allow()
|
||||
if err != nil {
|
||||
t.Fatalf("Allow in half-open returned error: %v", err)
|
||||
}
|
||||
done(true)
|
||||
|
||||
if s := b.State(); s != StateClosed {
|
||||
t.Fatalf("state = %v, want %v", s, StateClosed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBreaker_HalfOpenToOpenOnFailure(t *testing.T) {
|
||||
const openDuration = 50 * time.Millisecond
|
||||
b := NewBreaker(
|
||||
WithFailureThreshold(1),
|
||||
WithOpenDuration(openDuration),
|
||||
)
|
||||
|
||||
// Trip the breaker.
|
||||
done, err := b.Allow()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
done(false)
|
||||
|
||||
// Wait for half-open.
|
||||
time.Sleep(openDuration + 10*time.Millisecond)
|
||||
|
||||
// A failed request in half-open should re-open the breaker.
|
||||
done, err = b.Allow()
|
||||
if err != nil {
|
||||
t.Fatalf("Allow in half-open returned error: %v", err)
|
||||
}
|
||||
done(false)
|
||||
|
||||
if s := b.State(); s != StateOpen {
|
||||
t.Fatalf("state = %v, want %v", s, StateOpen)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransport_PerHostBreakers(t *testing.T) {
|
||||
const threshold = 2
|
||||
|
||||
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||
if req.URL.Host == "failing.example.com" {
|
||||
return errResponse(http.StatusInternalServerError), nil
|
||||
}
|
||||
return okResponse(), nil
|
||||
})
|
||||
|
||||
rt := Transport(
|
||||
WithFailureThreshold(threshold),
|
||||
WithOpenDuration(time.Hour),
|
||||
)(base)
|
||||
|
||||
t.Run("failing host trips breaker", func(t *testing.T) {
|
||||
for i := 0; i < threshold; i++ {
|
||||
req, err := http.NewRequest(http.MethodGet, "https://failing.example.com/test", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
resp, err := rt.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatalf("iteration %d: unexpected error: %v", i, err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
}
|
||||
|
||||
// Next request to failing host should be rejected.
|
||||
req, err := http.NewRequest(http.MethodGet, "https://failing.example.com/test", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = rt.RoundTrip(req)
|
||||
if !errors.Is(err, ErrCircuitOpen) {
|
||||
t.Fatalf("err = %v, want %v", err, ErrCircuitOpen)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("healthy host is unaffected", func(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodGet, "https://healthy.example.com/test", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
resp, err := rt.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestTransport_SuccessResetsFailures(t *testing.T) {
|
||||
callCount := 0
|
||||
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||
callCount++
|
||||
// Fail on odd calls, succeed on even.
|
||||
if callCount%2 == 1 {
|
||||
return errResponse(http.StatusInternalServerError), nil
|
||||
}
|
||||
return okResponse(), nil
|
||||
})
|
||||
|
||||
rt := Transport(
|
||||
WithFailureThreshold(3),
|
||||
WithOpenDuration(time.Hour),
|
||||
)(base)
|
||||
|
||||
// Alternate fail/success — should never trip because successes reset the
|
||||
// consecutive failure counter.
|
||||
for i := 0; i < 10; i++ {
|
||||
req, err := http.NewRequest(http.MethodGet, "https://host.example.com/test", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
resp, err := rt.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatalf("iteration %d: unexpected error (circuit should not be open): %v", i, err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
}
|
||||
}
|
||||
27
circuitbreaker/doc.go
Normal file
27
circuitbreaker/doc.go
Normal file
@@ -0,0 +1,27 @@
|
||||
// Package circuitbreaker provides a per-host circuit breaker as HTTP middleware.
|
||||
//
|
||||
// The circuit breaker monitors request failures and temporarily blocks requests
|
||||
// to unhealthy hosts, allowing them time to recover before retrying.
|
||||
//
|
||||
// # State machine
|
||||
//
|
||||
// - Closed — normal operation, requests pass through
|
||||
// - Open — too many failures, requests are rejected with ErrCircuitOpen
|
||||
// - HalfOpen — after a cooldown period, one probe request is allowed through
|
||||
//
|
||||
// # Usage
|
||||
//
|
||||
// mw := circuitbreaker.Transport(
|
||||
// circuitbreaker.WithThreshold(5),
|
||||
// circuitbreaker.WithTimeout(30 * time.Second),
|
||||
// )
|
||||
// transport := mw(http.DefaultTransport)
|
||||
//
|
||||
// The circuit breaker is per-host: each unique request host gets its own
|
||||
// independent breaker state machine stored in a sync.Map.
|
||||
//
|
||||
// # Sentinel errors
|
||||
//
|
||||
// ErrCircuitOpen is returned when a request is rejected because the circuit
|
||||
// is in the Open state.
|
||||
package circuitbreaker
|
||||
50
circuitbreaker/options.go
Normal file
50
circuitbreaker/options.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package circuitbreaker
|
||||
|
||||
import "time"
|
||||
|
||||
type options struct {
|
||||
failureThreshold int // consecutive failures to trip
|
||||
openDuration time.Duration // how long to stay open before half-open
|
||||
halfOpenMax int // max concurrent requests in half-open
|
||||
}
|
||||
|
||||
func defaults() options {
|
||||
return options{
|
||||
failureThreshold: 5,
|
||||
openDuration: 30 * time.Second,
|
||||
halfOpenMax: 1,
|
||||
}
|
||||
}
|
||||
|
||||
// Option configures a Breaker.
|
||||
type Option func(*options)
|
||||
|
||||
// WithFailureThreshold sets the number of consecutive failures required to
|
||||
// trip the breaker from Closed to Open. Default is 5.
|
||||
func WithFailureThreshold(n int) Option {
|
||||
return func(o *options) {
|
||||
if n > 0 {
|
||||
o.failureThreshold = n
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WithOpenDuration sets how long the breaker stays in the Open state before
|
||||
// transitioning to HalfOpen. Default is 30s.
|
||||
func WithOpenDuration(d time.Duration) Option {
|
||||
return func(o *options) {
|
||||
if d > 0 {
|
||||
o.openDuration = d
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WithHalfOpenMax sets the maximum number of concurrent probe requests
|
||||
// allowed while the breaker is in the HalfOpen state. Default is 1.
|
||||
func WithHalfOpenMax(n int) Option {
|
||||
return func(o *options) {
|
||||
if n > 0 {
|
||||
o.halfOpenMax = n
|
||||
}
|
||||
}
|
||||
}
|
||||
204
client.go
Normal file
204
client.go
Normal file
@@ -0,0 +1,204 @@
|
||||
package httpx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"git.codelab.vc/pkg/httpx/balancer"
|
||||
"git.codelab.vc/pkg/httpx/circuitbreaker"
|
||||
"git.codelab.vc/pkg/httpx/middleware"
|
||||
"git.codelab.vc/pkg/httpx/retry"
|
||||
)
|
||||
|
||||
// Client is a high-level HTTP client that composes middleware for retry,
|
||||
// circuit breaking, load balancing, logging, and more.
|
||||
type Client struct {
|
||||
httpClient *http.Client
|
||||
baseURL string
|
||||
errorMapper ErrorMapper
|
||||
balancerCloser *balancer.Closer
|
||||
maxResponseBody int64
|
||||
}
|
||||
|
||||
// New creates a new Client with the given options.
|
||||
//
|
||||
// The middleware chain is assembled as (outermost → innermost):
|
||||
//
|
||||
// Logging → User Middlewares → Retry → Circuit Breaker → Balancer → Base Transport
|
||||
func New(opts ...Option) *Client {
|
||||
o := &clientOptions{
|
||||
transport: http.DefaultTransport,
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(o)
|
||||
}
|
||||
|
||||
// Build the middleware chain from inside out.
|
||||
var chain []middleware.Middleware
|
||||
|
||||
// Balancer (innermost, wraps base transport).
|
||||
var balancerCloser *balancer.Closer
|
||||
if len(o.endpoints) > 0 {
|
||||
var mw middleware.Middleware
|
||||
mw, balancerCloser = balancer.Transport(o.endpoints, o.balancerOpts...)
|
||||
chain = append(chain, mw)
|
||||
}
|
||||
|
||||
// Circuit breaker wraps balancer.
|
||||
if o.enableCB {
|
||||
chain = append(chain, circuitbreaker.Transport(o.cbOpts...))
|
||||
}
|
||||
|
||||
// Retry wraps circuit breaker + balancer.
|
||||
if o.enableRetry {
|
||||
chain = append(chain, retry.Transport(o.retryOpts...))
|
||||
}
|
||||
|
||||
// User middlewares.
|
||||
for i := len(o.middlewares) - 1; i >= 0; i-- {
|
||||
chain = append(chain, o.middlewares[i])
|
||||
}
|
||||
|
||||
// Logging (outermost).
|
||||
if o.logger != nil {
|
||||
chain = append(chain, middleware.Logging(o.logger))
|
||||
}
|
||||
|
||||
// Assemble: chain[last] is outermost.
|
||||
rt := o.transport
|
||||
for _, mw := range chain {
|
||||
rt = mw(rt)
|
||||
}
|
||||
|
||||
return &Client{
|
||||
httpClient: &http.Client{
|
||||
Transport: rt,
|
||||
Timeout: o.timeout,
|
||||
},
|
||||
baseURL: o.baseURL,
|
||||
errorMapper: o.errorMapper,
|
||||
balancerCloser: balancerCloser,
|
||||
maxResponseBody: o.maxResponseBody,
|
||||
}
|
||||
}
|
||||
|
||||
// Do executes an HTTP request.
|
||||
func (c *Client) Do(ctx context.Context, req *http.Request) (*Response, error) {
|
||||
req = req.WithContext(ctx)
|
||||
if err := c.resolveURL(req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, &Error{
|
||||
Op: req.Method,
|
||||
URL: req.URL.String(),
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
if c.maxResponseBody > 0 {
|
||||
resp.Body = &limitedReadCloser{
|
||||
R: io.LimitedReader{R: resp.Body, N: c.maxResponseBody},
|
||||
C: resp.Body,
|
||||
}
|
||||
}
|
||||
|
||||
r := newResponse(resp)
|
||||
|
||||
if c.errorMapper != nil {
|
||||
if mapErr := c.errorMapper(resp); mapErr != nil {
|
||||
return r, &Error{
|
||||
Op: req.Method,
|
||||
URL: req.URL.String(),
|
||||
StatusCode: resp.StatusCode,
|
||||
Err: mapErr,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// Get performs a GET request to the given URL.
|
||||
func (c *Client) Get(ctx context.Context, url string) (*Response, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.Do(ctx, req)
|
||||
}
|
||||
|
||||
// Post performs a POST request to the given URL with the given body.
|
||||
func (c *Client) Post(ctx context.Context, url string, body io.Reader) (*Response, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.Do(ctx, req)
|
||||
}
|
||||
|
||||
// Put performs a PUT request to the given URL with the given body.
|
||||
func (c *Client) Put(ctx context.Context, url string, body io.Reader) (*Response, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPut, url, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.Do(ctx, req)
|
||||
}
|
||||
|
||||
// Patch performs a PATCH request to the given URL with the given body.
|
||||
func (c *Client) Patch(ctx context.Context, url string, body io.Reader) (*Response, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPatch, url, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.Do(ctx, req)
|
||||
}
|
||||
|
||||
// Delete performs a DELETE request to the given URL.
|
||||
func (c *Client) Delete(ctx context.Context, url string) (*Response, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.Do(ctx, req)
|
||||
}
|
||||
|
||||
// Close releases resources associated with the Client, such as background
|
||||
// health checker goroutines. It is safe to call multiple times.
|
||||
func (c *Client) Close() {
|
||||
if c.balancerCloser != nil {
|
||||
c.balancerCloser.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// HTTPClient returns the underlying *http.Client for advanced use cases.
|
||||
// Mutating the returned client may bypass the configured middleware chain.
|
||||
func (c *Client) HTTPClient() *http.Client {
|
||||
return c.httpClient
|
||||
}
|
||||
|
||||
func (c *Client) resolveURL(req *http.Request) error {
|
||||
if c.baseURL == "" {
|
||||
return nil
|
||||
}
|
||||
// Only resolve relative URLs (no scheme).
|
||||
if req.URL.Scheme == "" && req.URL.Host == "" {
|
||||
path := req.URL.String()
|
||||
if !strings.HasPrefix(path, "/") {
|
||||
path = "/" + path
|
||||
}
|
||||
base := strings.TrimRight(c.baseURL, "/")
|
||||
u, err := req.URL.Parse(base + path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("httpx: resolving URL %q with base %q: %w", path, c.baseURL, err)
|
||||
}
|
||||
req.URL = u
|
||||
}
|
||||
return nil
|
||||
}
|
||||
96
client_options.go
Normal file
96
client_options.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package httpx
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"git.codelab.vc/pkg/httpx/balancer"
|
||||
"git.codelab.vc/pkg/httpx/circuitbreaker"
|
||||
"git.codelab.vc/pkg/httpx/middleware"
|
||||
"git.codelab.vc/pkg/httpx/retry"
|
||||
)
|
||||
|
||||
type clientOptions struct {
|
||||
baseURL string
|
||||
timeout time.Duration
|
||||
transport http.RoundTripper
|
||||
logger *slog.Logger
|
||||
errorMapper ErrorMapper
|
||||
middlewares []middleware.Middleware
|
||||
retryOpts []retry.Option
|
||||
enableRetry bool
|
||||
cbOpts []circuitbreaker.Option
|
||||
enableCB bool
|
||||
endpoints []balancer.Endpoint
|
||||
balancerOpts []balancer.Option
|
||||
maxResponseBody int64
|
||||
}
|
||||
|
||||
// Option configures a Client.
|
||||
type Option func(*clientOptions)
|
||||
|
||||
// WithBaseURL sets the base URL prepended to all relative request paths.
|
||||
func WithBaseURL(url string) Option {
|
||||
return func(o *clientOptions) { o.baseURL = url }
|
||||
}
|
||||
|
||||
// WithTimeout sets the overall request timeout.
|
||||
func WithTimeout(d time.Duration) Option {
|
||||
return func(o *clientOptions) { o.timeout = d }
|
||||
}
|
||||
|
||||
// WithTransport sets the base http.RoundTripper. Defaults to http.DefaultTransport.
|
||||
func WithTransport(rt http.RoundTripper) Option {
|
||||
return func(o *clientOptions) { o.transport = rt }
|
||||
}
|
||||
|
||||
// WithLogger enables structured logging of requests and responses.
|
||||
func WithLogger(l *slog.Logger) Option {
|
||||
return func(o *clientOptions) { o.logger = l }
|
||||
}
|
||||
|
||||
// WithErrorMapper sets a function that maps HTTP responses to errors.
|
||||
func WithErrorMapper(m ErrorMapper) Option {
|
||||
return func(o *clientOptions) { o.errorMapper = m }
|
||||
}
|
||||
|
||||
// WithMiddleware appends user middlewares to the chain.
|
||||
// These run between logging and retry in the middleware stack.
|
||||
func WithMiddleware(mws ...middleware.Middleware) Option {
|
||||
return func(o *clientOptions) { o.middlewares = append(o.middlewares, mws...) }
|
||||
}
|
||||
|
||||
// WithRetry enables retry with the given options.
|
||||
func WithRetry(opts ...retry.Option) Option {
|
||||
return func(o *clientOptions) {
|
||||
o.enableRetry = true
|
||||
o.retryOpts = opts
|
||||
}
|
||||
}
|
||||
|
||||
// WithCircuitBreaker enables per-host circuit breaking.
|
||||
func WithCircuitBreaker(opts ...circuitbreaker.Option) Option {
|
||||
return func(o *clientOptions) {
|
||||
o.enableCB = true
|
||||
o.cbOpts = opts
|
||||
}
|
||||
}
|
||||
|
||||
// WithEndpoints sets the endpoints for load balancing.
|
||||
func WithEndpoints(eps ...balancer.Endpoint) Option {
|
||||
return func(o *clientOptions) { o.endpoints = eps }
|
||||
}
|
||||
|
||||
// WithBalancer configures the load balancer strategy and options.
|
||||
func WithBalancer(opts ...balancer.Option) Option {
|
||||
return func(o *clientOptions) { o.balancerOpts = opts }
|
||||
}
|
||||
|
||||
// WithMaxResponseBody limits the number of bytes read from response bodies
|
||||
// by Response.Bytes (and by extension String, JSON, XML). If the response
|
||||
// body exceeds n bytes, reading stops and returns an error.
|
||||
// A value of 0 means no limit (the default).
|
||||
func WithMaxResponseBody(n int64) Option {
|
||||
return func(o *clientOptions) { o.maxResponseBody = n }
|
||||
}
|
||||
45
client_patch_test.go
Normal file
45
client_patch_test.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package httpx_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"git.codelab.vc/pkg/httpx"
|
||||
)
|
||||
|
||||
func TestClient_Patch(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPatch {
|
||||
t.Errorf("expected PATCH, got %s", r.Method)
|
||||
}
|
||||
b, _ := io.ReadAll(r.Body)
|
||||
if string(b) != `{"name":"updated"}` {
|
||||
t.Errorf("expected body %q, got %q", `{"name":"updated"}`, string(b))
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprint(w, "patched")
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
client := httpx.New()
|
||||
resp, err := client.Patch(context.Background(), srv.URL+"/item/1", strings.NewReader(`{"name":"updated"}`))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
body, err := resp.String()
|
||||
if err != nil {
|
||||
t.Fatalf("reading body: %v", err)
|
||||
}
|
||||
if body != "patched" {
|
||||
t.Errorf("expected body %q, got %q", "patched", body)
|
||||
}
|
||||
}
|
||||
311
client_test.go
Normal file
311
client_test.go
Normal file
@@ -0,0 +1,311 @@
|
||||
package httpx_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.codelab.vc/pkg/httpx"
|
||||
"git.codelab.vc/pkg/httpx/balancer"
|
||||
"git.codelab.vc/pkg/httpx/middleware"
|
||||
"git.codelab.vc/pkg/httpx/retry"
|
||||
)
|
||||
|
||||
func TestClient_Get(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
t.Errorf("expected GET, got %s", r.Method)
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprint(w, "hello")
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
client := httpx.New()
|
||||
resp, err := client.Get(context.Background(), srv.URL+"/test")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
body, err := resp.String()
|
||||
if err != nil {
|
||||
t.Fatalf("reading body: %v", err)
|
||||
}
|
||||
if body != "hello" {
|
||||
t.Errorf("expected body %q, got %q", "hello", body)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_Post(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
t.Errorf("expected POST, got %s", r.Method)
|
||||
}
|
||||
b, _ := io.ReadAll(r.Body)
|
||||
if string(b) != "request-body" {
|
||||
t.Errorf("expected body %q, got %q", "request-body", string(b))
|
||||
}
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
fmt.Fprint(w, "created")
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
client := httpx.New()
|
||||
resp, err := client.Post(context.Background(), srv.URL+"/items", strings.NewReader("request-body"))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusCreated {
|
||||
t.Errorf("expected status 201, got %d", resp.StatusCode)
|
||||
}
|
||||
body, err := resp.String()
|
||||
if err != nil {
|
||||
t.Fatalf("reading body: %v", err)
|
||||
}
|
||||
if body != "created" {
|
||||
t.Errorf("expected body %q, got %q", "created", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_BaseURL(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/api/v1/users" {
|
||||
t.Errorf("expected path /api/v1/users, got %s", r.URL.Path)
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
client := httpx.New(httpx.WithBaseURL(srv.URL + "/api/v1"))
|
||||
|
||||
// Use a relative path (no scheme/host).
|
||||
resp, err := client.Get(context.Background(), "/users")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_WithMiddleware(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
val := r.Header.Get("X-Custom-Header")
|
||||
if val != "test-value" {
|
||||
t.Errorf("expected header X-Custom-Header=%q, got %q", "test-value", val)
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
addHeader := func(next http.RoundTripper) http.RoundTripper {
|
||||
return middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
req = req.Clone(req.Context())
|
||||
req.Header.Set("X-Custom-Header", "test-value")
|
||||
return next.RoundTrip(req)
|
||||
})
|
||||
}
|
||||
|
||||
client := httpx.New(httpx.WithMiddleware(addHeader))
|
||||
|
||||
resp, err := client.Get(context.Background(), srv.URL+"/test")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_RetryIntegration(t *testing.T) {
|
||||
var calls atomic.Int32
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
n := calls.Add(1)
|
||||
if n <= 2 {
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprint(w, "success")
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
client := httpx.New(
|
||||
httpx.WithRetry(
|
||||
retry.WithMaxAttempts(3),
|
||||
retry.WithBackoff(retry.ConstantBackoff(1*time.Millisecond)),
|
||||
),
|
||||
)
|
||||
|
||||
resp, err := client.Get(context.Background(), srv.URL+"/flaky")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := resp.String()
|
||||
if err != nil {
|
||||
t.Fatalf("reading body: %v", err)
|
||||
}
|
||||
if body != "success" {
|
||||
t.Errorf("expected body %q, got %q", "success", body)
|
||||
}
|
||||
|
||||
if got := calls.Load(); got != 3 {
|
||||
t.Errorf("expected 3 total requests, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_BalancerIntegration(t *testing.T) {
|
||||
var hits1, hits2 atomic.Int32
|
||||
|
||||
srv1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
hits1.Add(1)
|
||||
fmt.Fprint(w, "server1")
|
||||
}))
|
||||
defer srv1.Close()
|
||||
|
||||
srv2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
hits2.Add(1)
|
||||
fmt.Fprint(w, "server2")
|
||||
}))
|
||||
defer srv2.Close()
|
||||
|
||||
client := httpx.New(
|
||||
httpx.WithEndpoints(
|
||||
balancer.Endpoint{URL: srv1.URL},
|
||||
balancer.Endpoint{URL: srv2.URL},
|
||||
),
|
||||
)
|
||||
|
||||
const totalRequests = 6
|
||||
for i := range totalRequests {
|
||||
resp, err := client.Get(context.Background(), fmt.Sprintf("/item/%d", i))
|
||||
if err != nil {
|
||||
t.Fatalf("request %d: unexpected error: %v", i, err)
|
||||
}
|
||||
resp.Close()
|
||||
}
|
||||
|
||||
h1 := hits1.Load()
|
||||
h2 := hits2.Load()
|
||||
|
||||
if h1+h2 != totalRequests {
|
||||
t.Errorf("expected %d total hits, got %d", totalRequests, h1+h2)
|
||||
}
|
||||
if h1 == 0 || h2 == 0 {
|
||||
t.Errorf("expected requests distributed across both servers, got server1=%d server2=%d", h1, h2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_ErrorMapper(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprint(w, "not found")
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
mapper := func(resp *http.Response) error {
|
||||
if resp.StatusCode >= 400 {
|
||||
return fmt.Errorf("HTTP %d", resp.StatusCode)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
client := httpx.New(httpx.WithErrorMapper(mapper))
|
||||
|
||||
resp, err := client.Get(context.Background(), srv.URL+"/missing")
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
|
||||
// The response should still be returned alongside the error.
|
||||
if resp == nil {
|
||||
t.Fatal("expected non-nil response even on mapped error")
|
||||
}
|
||||
if resp.StatusCode != http.StatusNotFound {
|
||||
t.Errorf("expected status 404, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Verify the error message contains the status code.
|
||||
if !strings.Contains(err.Error(), "404") {
|
||||
t.Errorf("expected error to contain 404, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_JSON(t *testing.T) {
|
||||
type reqPayload struct {
|
||||
Name string `json:"name"`
|
||||
Age int `json:"age"`
|
||||
}
|
||||
type respPayload struct {
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if ct := r.Header.Get("Content-Type"); ct != "application/json" {
|
||||
t.Errorf("expected Content-Type application/json, got %q", ct)
|
||||
}
|
||||
|
||||
var p reqPayload
|
||||
if err := json.NewDecoder(r.Body).Decode(&p); err != nil {
|
||||
t.Errorf("decoding request body: %v", err)
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if p.Name != "Alice" || p.Age != 30 {
|
||||
t.Errorf("unexpected payload: %+v", p)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(respPayload{ID: 1, Name: p.Name})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
client := httpx.New()
|
||||
|
||||
req, err := httpx.NewJSONRequest(context.Background(), http.MethodPost, srv.URL+"/users", reqPayload{
|
||||
Name: "Alice",
|
||||
Age: 30,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("creating JSON request: %v", err)
|
||||
}
|
||||
|
||||
resp, err := client.Do(context.Background(), req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
var result respPayload
|
||||
if err := resp.JSON(&result); err != nil {
|
||||
t.Fatalf("decoding JSON response: %v", err)
|
||||
}
|
||||
if result.ID != 1 {
|
||||
t.Errorf("expected ID 1, got %d", result.ID)
|
||||
}
|
||||
if result.Name != "Alice" {
|
||||
t.Errorf("expected Name %q, got %q", "Alice", result.Name)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure slog import is used (referenced in imports for completeness with the spec).
|
||||
var _ = slog.Default
|
||||
39
doc.go
Normal file
39
doc.go
Normal file
@@ -0,0 +1,39 @@
|
||||
// Package httpx provides a high-level HTTP client with composable middleware
|
||||
// for retry, circuit breaking, load balancing, structured logging, and more.
|
||||
//
|
||||
// The client is configured via functional options and assembled as a middleware
|
||||
// chain around a standard http.RoundTripper:
|
||||
//
|
||||
// Logging → User Middlewares → Retry → Circuit Breaker → Balancer → Transport
|
||||
//
|
||||
// # Quick start
|
||||
//
|
||||
// client := httpx.New(
|
||||
// httpx.WithBaseURL("https://api.example.com"),
|
||||
// httpx.WithTimeout(10 * time.Second),
|
||||
// httpx.WithRetry(),
|
||||
// httpx.WithCircuitBreaker(),
|
||||
// )
|
||||
// defer client.Close()
|
||||
//
|
||||
// resp, err := client.Get(ctx, "/users/1")
|
||||
//
|
||||
// # Request builders
|
||||
//
|
||||
// NewJSONRequest and NewFormRequest create requests with appropriate
|
||||
// Content-Type headers and GetBody set for retry compatibility.
|
||||
//
|
||||
// # Error handling
|
||||
//
|
||||
// Failed requests return *httpx.Error with structured fields (Op, URL,
|
||||
// StatusCode). Sentinel errors ErrRetryExhausted, ErrCircuitOpen, and
|
||||
// ErrNoHealthy can be checked with errors.Is.
|
||||
//
|
||||
// # Sub-packages
|
||||
//
|
||||
// - middleware — client-side middleware (logging, auth, headers, recovery, request ID)
|
||||
// - retry — configurable retry with backoff and Retry-After support
|
||||
// - circuitbreaker — per-host circuit breaker (closed → open → half-open)
|
||||
// - balancer — client-side load balancing with health checking
|
||||
// - server — production HTTP server with router, middleware, and graceful shutdown
|
||||
package httpx
|
||||
49
error.go
Normal file
49
error.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package httpx
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"git.codelab.vc/pkg/httpx/balancer"
|
||||
"git.codelab.vc/pkg/httpx/circuitbreaker"
|
||||
"git.codelab.vc/pkg/httpx/retry"
|
||||
)
|
||||
|
||||
// Sentinel errors returned by httpx components.
|
||||
// These are aliases for the canonical errors defined in sub-packages,
|
||||
// so that errors.Is works regardless of which import the caller uses.
|
||||
var (
|
||||
ErrRetryExhausted = retry.ErrRetryExhausted
|
||||
ErrCircuitOpen = circuitbreaker.ErrCircuitOpen
|
||||
ErrNoHealthy = balancer.ErrNoHealthy
|
||||
)
|
||||
|
||||
// Error provides structured error information for failed HTTP operations.
|
||||
type Error struct {
|
||||
// Op is the operation that failed (e.g. "Get", "Do").
|
||||
Op string
|
||||
// URL is the originally-requested URL.
|
||||
URL string
|
||||
// Endpoint is the resolved endpoint URL (after balancing).
|
||||
Endpoint string
|
||||
// StatusCode is the HTTP status code, if a response was received.
|
||||
StatusCode int
|
||||
// Retries is the number of retry attempts made.
|
||||
Retries int
|
||||
// Err is the underlying error.
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e *Error) Error() string {
|
||||
if e.Endpoint != "" && e.Endpoint != e.URL {
|
||||
return fmt.Sprintf("httpx: %s %s (endpoint %s): %v", e.Op, e.URL, e.Endpoint, e.Err)
|
||||
}
|
||||
return fmt.Sprintf("httpx: %s %s: %v", e.Op, e.URL, e.Err)
|
||||
}
|
||||
|
||||
func (e *Error) Unwrap() error { return e.Err }
|
||||
|
||||
// ErrorMapper maps an HTTP response to an error. If the response is
|
||||
// acceptable, the mapper should return nil. Used by Client to convert
|
||||
// non-successful HTTP responses into Go errors.
|
||||
type ErrorMapper func(resp *http.Response) error
|
||||
90
error_test.go
Normal file
90
error_test.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package httpx_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"git.codelab.vc/pkg/httpx"
|
||||
)
|
||||
|
||||
func TestError(t *testing.T) {
|
||||
t.Run("formats without endpoint", func(t *testing.T) {
|
||||
inner := errors.New("connection refused")
|
||||
e := &httpx.Error{
|
||||
Op: "Get",
|
||||
URL: "http://example.com/api",
|
||||
Err: inner,
|
||||
}
|
||||
|
||||
want := "httpx: Get http://example.com/api: connection refused"
|
||||
if got := e.Error(); got != want {
|
||||
t.Errorf("got %q, want %q", got, want)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("formats with endpoint different from url", func(t *testing.T) {
|
||||
inner := errors.New("timeout")
|
||||
e := &httpx.Error{
|
||||
Op: "Do",
|
||||
URL: "http://example.com/api",
|
||||
Endpoint: "http://node1.example.com/api",
|
||||
Err: inner,
|
||||
}
|
||||
|
||||
want := "httpx: Do http://example.com/api (endpoint http://node1.example.com/api): timeout"
|
||||
if got := e.Error(); got != want {
|
||||
t.Errorf("got %q, want %q", got, want)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("formats with endpoint same as url", func(t *testing.T) {
|
||||
inner := errors.New("not found")
|
||||
e := &httpx.Error{
|
||||
Op: "Get",
|
||||
URL: "http://example.com/api",
|
||||
Endpoint: "http://example.com/api",
|
||||
Err: inner,
|
||||
}
|
||||
|
||||
want := "httpx: Get http://example.com/api: not found"
|
||||
if got := e.Error(); got != want {
|
||||
t.Errorf("got %q, want %q", got, want)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("unwrap returns inner error", func(t *testing.T) {
|
||||
inner := errors.New("underlying")
|
||||
e := &httpx.Error{Op: "Get", URL: "http://example.com", Err: inner}
|
||||
|
||||
if got := e.Unwrap(); got != inner {
|
||||
t.Errorf("Unwrap() = %v, want %v", got, inner)
|
||||
}
|
||||
|
||||
if !errors.Is(e, inner) {
|
||||
t.Error("errors.Is should find the inner error")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSentinelErrors(t *testing.T) {
|
||||
t.Run("ErrRetryExhausted", func(t *testing.T) {
|
||||
if httpx.ErrRetryExhausted == nil {
|
||||
t.Fatal("ErrRetryExhausted is nil")
|
||||
}
|
||||
if httpx.ErrRetryExhausted.Error() == "" {
|
||||
t.Fatal("ErrRetryExhausted has empty message")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ErrCircuitOpen", func(t *testing.T) {
|
||||
if httpx.ErrCircuitOpen == nil {
|
||||
t.Fatal("ErrCircuitOpen is nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ErrNoHealthy", func(t *testing.T) {
|
||||
if httpx.ErrNoHealthy == nil {
|
||||
t.Fatal("ErrNoHealthy is nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
67
examples/basic-client/main.go
Normal file
67
examples/basic-client/main.go
Normal file
@@ -0,0 +1,67 @@
|
||||
// Basic HTTP client with retry, timeout, and structured logging.
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"git.codelab.vc/pkg/httpx"
|
||||
"git.codelab.vc/pkg/httpx/middleware"
|
||||
"git.codelab.vc/pkg/httpx/retry"
|
||||
)
|
||||
|
||||
func main() {
|
||||
client := httpx.New(
|
||||
httpx.WithBaseURL("https://httpbin.org"),
|
||||
httpx.WithTimeout(10*time.Second),
|
||||
httpx.WithRetry(
|
||||
retry.WithMaxAttempts(3),
|
||||
retry.WithBackoff(retry.ExponentialBackoff(100*time.Millisecond, 2*time.Second, true)),
|
||||
),
|
||||
httpx.WithMiddleware(
|
||||
middleware.UserAgent("httpx-example/1.0"),
|
||||
),
|
||||
httpx.WithMaxResponseBody(1<<20), // 1 MB limit
|
||||
httpx.WithLogger(slog.Default()),
|
||||
)
|
||||
defer client.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// GET request.
|
||||
resp, err := client.Get(ctx, "/get")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
body, err := resp.String()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Printf("GET /get → %d (%d bytes)\n", resp.StatusCode, len(body))
|
||||
|
||||
// POST with JSON body.
|
||||
type payload struct {
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
req, err := httpx.NewJSONRequest(ctx, "POST", "/post", payload{
|
||||
Name: "Alice",
|
||||
Email: "alice@example.com",
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
resp, err = client.Do(ctx, req)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer resp.Close()
|
||||
|
||||
fmt.Printf("POST /post → %d\n", resp.StatusCode)
|
||||
}
|
||||
41
examples/form-request/main.go
Normal file
41
examples/form-request/main.go
Normal file
@@ -0,0 +1,41 @@
|
||||
// Demonstrates form-encoded requests for OAuth token endpoints and similar APIs.
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"git.codelab.vc/pkg/httpx"
|
||||
)
|
||||
|
||||
func main() {
|
||||
client := httpx.New(
|
||||
httpx.WithBaseURL("https://httpbin.org"),
|
||||
)
|
||||
defer client.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Form-encoded POST, common for OAuth token endpoints.
|
||||
req, err := httpx.NewFormRequest(ctx, http.MethodPost, "/post", url.Values{
|
||||
"grant_type": {"client_credentials"},
|
||||
"client_id": {"my-app"},
|
||||
"client_secret": {"secret"},
|
||||
"scope": {"read write"},
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
resp, err := client.Do(ctx, req)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer resp.Close()
|
||||
|
||||
body, _ := resp.String()
|
||||
fmt.Printf("Status: %d\nBody: %s\n", resp.StatusCode, body)
|
||||
}
|
||||
54
examples/load-balancing/main.go
Normal file
54
examples/load-balancing/main.go
Normal file
@@ -0,0 +1,54 @@
|
||||
// Demonstrates load balancing across multiple backend endpoints with
|
||||
// circuit breaking and health-checked failover.
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"git.codelab.vc/pkg/httpx"
|
||||
"git.codelab.vc/pkg/httpx/balancer"
|
||||
"git.codelab.vc/pkg/httpx/circuitbreaker"
|
||||
"git.codelab.vc/pkg/httpx/retry"
|
||||
)
|
||||
|
||||
func main() {
|
||||
client := httpx.New(
|
||||
httpx.WithEndpoints(
|
||||
balancer.Endpoint{URL: "http://localhost:8081", Weight: 3},
|
||||
balancer.Endpoint{URL: "http://localhost:8082", Weight: 1},
|
||||
),
|
||||
httpx.WithBalancer(
|
||||
balancer.WithStrategy(balancer.WeightedRandom()),
|
||||
balancer.WithHealthCheck(
|
||||
balancer.WithHealthInterval(5*time.Second),
|
||||
balancer.WithHealthPath("/healthz"),
|
||||
),
|
||||
),
|
||||
httpx.WithCircuitBreaker(
|
||||
circuitbreaker.WithFailureThreshold(5),
|
||||
circuitbreaker.WithOpenDuration(30*time.Second),
|
||||
),
|
||||
httpx.WithRetry(
|
||||
retry.WithMaxAttempts(3),
|
||||
),
|
||||
httpx.WithLogger(slog.Default()),
|
||||
)
|
||||
defer client.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
for i := range 5 {
|
||||
resp, err := client.Get(ctx, fmt.Sprintf("/api/item/%d", i))
|
||||
if err != nil {
|
||||
log.Printf("request %d failed: %v", i, err)
|
||||
continue
|
||||
}
|
||||
|
||||
body, _ := resp.String()
|
||||
fmt.Printf("request %d → %d: %s\n", i, resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
54
examples/request-id-propagation/main.go
Normal file
54
examples/request-id-propagation/main.go
Normal file
@@ -0,0 +1,54 @@
|
||||
// Demonstrates request ID propagation between server and client.
|
||||
// The server assigns a request ID to incoming requests, and the client
|
||||
// middleware forwards it to downstream services via X-Request-Id header.
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
|
||||
"git.codelab.vc/pkg/httpx"
|
||||
"git.codelab.vc/pkg/httpx/middleware"
|
||||
"git.codelab.vc/pkg/httpx/server"
|
||||
)
|
||||
|
||||
func main() {
|
||||
logger := slog.Default()
|
||||
|
||||
// Client that propagates request IDs from context.
|
||||
client := httpx.New(
|
||||
httpx.WithBaseURL("http://localhost:9090"),
|
||||
httpx.WithMiddleware(
|
||||
middleware.RequestID(), // Picks up ID from context, sets X-Request-Id.
|
||||
),
|
||||
)
|
||||
defer client.Close()
|
||||
|
||||
r := server.NewRouter()
|
||||
|
||||
r.HandleFunc("GET /proxy", func(w http.ResponseWriter, r *http.Request) {
|
||||
// The request ID is in r.Context() thanks to server.RequestID().
|
||||
id := server.RequestIDFromContext(r.Context())
|
||||
logger.Info("handling request", "request_id", id)
|
||||
|
||||
// Client automatically forwards the request ID to downstream.
|
||||
resp, err := client.Get(r.Context(), "/downstream")
|
||||
if err != nil {
|
||||
server.WriteError(w, http.StatusBadGateway, fmt.Sprintf("downstream error: %v", err))
|
||||
return
|
||||
}
|
||||
defer resp.Close()
|
||||
|
||||
body, _ := resp.String()
|
||||
server.WriteJSON(w, http.StatusOK, map[string]string{
|
||||
"request_id": id,
|
||||
"downstream_response": body,
|
||||
})
|
||||
})
|
||||
|
||||
// Server with RequestID middleware that assigns IDs to incoming requests.
|
||||
srv := server.New(r, server.Defaults(logger)...)
|
||||
log.Fatal(srv.ListenAndServe())
|
||||
}
|
||||
54
examples/server-basic/main.go
Normal file
54
examples/server-basic/main.go
Normal file
@@ -0,0 +1,54 @@
|
||||
// Basic HTTP server with routing, middleware, health checks, and graceful shutdown.
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
|
||||
"git.codelab.vc/pkg/httpx/server"
|
||||
)
|
||||
|
||||
func main() {
|
||||
logger := slog.Default()
|
||||
|
||||
r := server.NewRouter(
|
||||
server.WithNotFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
server.WriteError(w, http.StatusNotFound, "resource not found")
|
||||
})),
|
||||
)
|
||||
|
||||
// Public endpoints.
|
||||
r.HandleFunc("GET /hello", func(w http.ResponseWriter, _ *http.Request) {
|
||||
server.WriteJSON(w, http.StatusOK, map[string]string{
|
||||
"message": "Hello, World!",
|
||||
})
|
||||
})
|
||||
|
||||
// API group with shared prefix.
|
||||
api := r.Group("/api/v1")
|
||||
api.HandleFunc("GET /users/{id}", getUser)
|
||||
api.HandleFunc("POST /users", createUser)
|
||||
|
||||
// Health checks.
|
||||
r.Mount("/", server.HealthHandler())
|
||||
|
||||
// Server with production defaults (RequestID → Recovery → Logging).
|
||||
srv := server.New(r, server.Defaults(logger)...)
|
||||
log.Fatal(srv.ListenAndServe())
|
||||
}
|
||||
|
||||
func getUser(w http.ResponseWriter, r *http.Request) {
|
||||
id := r.PathValue("id")
|
||||
server.WriteJSON(w, http.StatusOK, map[string]string{
|
||||
"id": id,
|
||||
"name": "Alice",
|
||||
})
|
||||
}
|
||||
|
||||
func createUser(w http.ResponseWriter, r *http.Request) {
|
||||
server.WriteJSON(w, http.StatusCreated, map[string]any{
|
||||
"id": 1,
|
||||
"message": "user created",
|
||||
})
|
||||
}
|
||||
61
examples/server-protected/main.go
Normal file
61
examples/server-protected/main.go
Normal file
@@ -0,0 +1,61 @@
|
||||
// Production server with CORS, rate limiting, body size limits, and timeouts.
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"git.codelab.vc/pkg/httpx/server"
|
||||
)
|
||||
|
||||
func main() {
|
||||
logger := slog.Default()
|
||||
|
||||
r := server.NewRouter(
|
||||
server.WithNotFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
server.WriteError(w, http.StatusNotFound, "not found")
|
||||
})),
|
||||
)
|
||||
|
||||
r.HandleFunc("GET /api/data", func(w http.ResponseWriter, _ *http.Request) {
|
||||
server.WriteJSON(w, http.StatusOK, map[string]string{"status": "ok"})
|
||||
})
|
||||
|
||||
r.HandleFunc("POST /api/upload", func(w http.ResponseWriter, r *http.Request) {
|
||||
// Body is already limited by MaxBodySize middleware.
|
||||
server.WriteJSON(w, http.StatusAccepted, map[string]string{"status": "received"})
|
||||
})
|
||||
|
||||
r.Mount("/", server.HealthHandler())
|
||||
|
||||
srv := server.New(r,
|
||||
append(
|
||||
server.Defaults(logger),
|
||||
server.WithMiddleware(
|
||||
// CORS for browser-facing APIs.
|
||||
server.CORS(
|
||||
server.AllowOrigins("https://app.example.com", "https://admin.example.com"),
|
||||
server.AllowMethods("GET", "POST", "PUT", "PATCH", "DELETE"),
|
||||
server.AllowHeaders("Authorization", "Content-Type"),
|
||||
server.ExposeHeaders("X-Request-Id"),
|
||||
server.AllowCredentials(true),
|
||||
server.MaxAge(3600),
|
||||
),
|
||||
// Rate limit: 100 req/s per IP, burst of 200.
|
||||
server.RateLimit(
|
||||
server.WithRate(100),
|
||||
server.WithBurst(200),
|
||||
),
|
||||
// Limit request body to 1 MB.
|
||||
server.MaxBodySize(1<<20),
|
||||
// Per-request timeout of 30 seconds.
|
||||
server.Timeout(30*time.Second),
|
||||
),
|
||||
server.WithAddr(":8080"),
|
||||
)...,
|
||||
)
|
||||
|
||||
log.Fatal(srv.ListenAndServe())
|
||||
}
|
||||
175
internal/clock/clock.go
Normal file
175
internal/clock/clock.go
Normal file
@@ -0,0 +1,175 @@
|
||||
package clock
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Clock abstracts time operations for deterministic testing.
|
||||
type Clock interface {
|
||||
Now() time.Time
|
||||
Since(t time.Time) time.Duration
|
||||
NewTimer(d time.Duration) Timer
|
||||
After(d time.Duration) <-chan time.Time
|
||||
}
|
||||
|
||||
// Timer abstracts time.Timer for testability.
|
||||
type Timer interface {
|
||||
C() <-chan time.Time
|
||||
Stop() bool
|
||||
Reset(d time.Duration) bool
|
||||
}
|
||||
|
||||
// System returns a Clock backed by the real system time.
|
||||
func System() Clock { return systemClock{} }
|
||||
|
||||
type systemClock struct{}
|
||||
|
||||
func (systemClock) Now() time.Time { return time.Now() }
|
||||
func (systemClock) Since(t time.Time) time.Duration { return time.Since(t) }
|
||||
func (systemClock) NewTimer(d time.Duration) Timer { return &systemTimer{t: time.NewTimer(d)} }
|
||||
func (systemClock) After(d time.Duration) <-chan time.Time { return time.After(d) }
|
||||
|
||||
type systemTimer struct{ t *time.Timer }
|
||||
|
||||
func (s *systemTimer) C() <-chan time.Time { return s.t.C }
|
||||
func (s *systemTimer) Stop() bool { return s.t.Stop() }
|
||||
func (s *systemTimer) Reset(d time.Duration) bool { return s.t.Reset(d) }
|
||||
|
||||
// Mock returns a manually-controlled Clock for tests.
|
||||
func Mock(now time.Time) *MockClock {
|
||||
return &MockClock{now: now}
|
||||
}
|
||||
|
||||
// MockClock is a deterministic clock for testing.
|
||||
type MockClock struct {
|
||||
mu sync.Mutex
|
||||
now time.Time
|
||||
timers []*mockTimer
|
||||
}
|
||||
|
||||
func (m *MockClock) Now() time.Time {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.now
|
||||
}
|
||||
|
||||
func (m *MockClock) Since(t time.Time) time.Duration {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.now.Sub(t)
|
||||
}
|
||||
|
||||
func (m *MockClock) NewTimer(d time.Duration) Timer {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
t := &mockTimer{
|
||||
clock: m,
|
||||
ch: make(chan time.Time, 1),
|
||||
deadline: m.now.Add(d),
|
||||
active: true,
|
||||
}
|
||||
m.timers = append(m.timers, t)
|
||||
if d <= 0 {
|
||||
t.fire(m.now)
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
func (m *MockClock) After(d time.Duration) <-chan time.Time {
|
||||
return m.NewTimer(d).C()
|
||||
}
|
||||
|
||||
// Advance moves the clock forward by d and fires any expired timers.
|
||||
func (m *MockClock) Advance(d time.Duration) {
|
||||
m.mu.Lock()
|
||||
m.now = m.now.Add(d)
|
||||
now := m.now
|
||||
m.mu.Unlock()
|
||||
|
||||
m.fireExpired(now)
|
||||
}
|
||||
|
||||
// Set sets the clock to an absolute time and fires any expired timers.
|
||||
func (m *MockClock) Set(t time.Time) {
|
||||
m.mu.Lock()
|
||||
m.now = t
|
||||
now := m.now
|
||||
m.mu.Unlock()
|
||||
|
||||
m.fireExpired(now)
|
||||
}
|
||||
|
||||
// fireExpired fires all active timers whose deadline has passed, then
|
||||
// removes inactive timers to prevent unbounded growth.
|
||||
func (m *MockClock) fireExpired(now time.Time) {
|
||||
m.mu.Lock()
|
||||
timers := m.timers
|
||||
m.mu.Unlock()
|
||||
|
||||
for _, t := range timers {
|
||||
t.mu.Lock()
|
||||
if t.active && !now.Before(t.deadline) {
|
||||
t.fire(now)
|
||||
}
|
||||
t.mu.Unlock()
|
||||
}
|
||||
|
||||
// Compact: remove inactive timers. Use a new slice to avoid aliasing
|
||||
// the backing array (NewTimer may have appended between snapshots).
|
||||
m.mu.Lock()
|
||||
n := len(m.timers)
|
||||
active := make([]*mockTimer, 0, n)
|
||||
for _, t := range m.timers {
|
||||
t.mu.Lock()
|
||||
keep := t.active
|
||||
t.mu.Unlock()
|
||||
if keep {
|
||||
active = append(active, t)
|
||||
}
|
||||
}
|
||||
m.timers = active
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
type mockTimer struct {
|
||||
mu sync.Mutex
|
||||
clock *MockClock
|
||||
ch chan time.Time
|
||||
deadline time.Time
|
||||
active bool
|
||||
}
|
||||
|
||||
func (t *mockTimer) C() <-chan time.Time { return t.ch }
|
||||
|
||||
func (t *mockTimer) Stop() bool {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
was := t.active
|
||||
t.active = false
|
||||
return was
|
||||
}
|
||||
|
||||
func (t *mockTimer) Reset(d time.Duration) bool {
|
||||
// Acquire clock lock first to match the lock ordering in fireExpired
|
||||
// (clock.mu → t.mu), preventing deadlock.
|
||||
t.clock.mu.Lock()
|
||||
deadline := t.clock.now.Add(d)
|
||||
t.clock.mu.Unlock()
|
||||
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
was := t.active
|
||||
t.active = true
|
||||
t.deadline = deadline
|
||||
return was
|
||||
}
|
||||
|
||||
// fire sends the time on the channel. Caller must hold t.mu.
|
||||
func (t *mockTimer) fire(now time.Time) {
|
||||
t.active = false
|
||||
select {
|
||||
case t.ch <- now:
|
||||
default:
|
||||
}
|
||||
}
|
||||
19
internal/requestid/requestid.go
Normal file
19
internal/requestid/requestid.go
Normal file
@@ -0,0 +1,19 @@
|
||||
// Package requestid provides a shared context key for request IDs,
|
||||
// allowing both client and server packages to access request IDs
|
||||
// without circular imports.
|
||||
package requestid
|
||||
|
||||
import "context"
|
||||
|
||||
type key struct{}
|
||||
|
||||
// NewContext returns a context with the given request ID.
|
||||
func NewContext(ctx context.Context, id string) context.Context {
|
||||
return context.WithValue(ctx, key{}, id)
|
||||
}
|
||||
|
||||
// FromContext returns the request ID from ctx, or empty string if not set.
|
||||
func FromContext(ctx context.Context) string {
|
||||
id, _ := ctx.Value(key{}).(string)
|
||||
return id
|
||||
}
|
||||
33
middleware/auth.go
Normal file
33
middleware/auth.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// BearerAuth returns a middleware that sets the Authorization header to a
|
||||
// Bearer token obtained by calling tokenFunc on each request. If tokenFunc
|
||||
// returns an error, the request is not sent and the error is returned.
|
||||
func BearerAuth(tokenFunc func(ctx context.Context) (string, error)) Middleware {
|
||||
return func(next http.RoundTripper) http.RoundTripper {
|
||||
return RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
token, err := tokenFunc(req.Context())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
return next.RoundTrip(req)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BasicAuth returns a middleware that sets HTTP Basic Authentication
|
||||
// credentials on every outgoing request.
|
||||
func BasicAuth(username, password string) Middleware {
|
||||
return func(next http.RoundTripper) http.RoundTripper {
|
||||
return RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
req.SetBasicAuth(username, password)
|
||||
return next.RoundTrip(req)
|
||||
})
|
||||
}
|
||||
}
|
||||
89
middleware/auth_test.go
Normal file
89
middleware/auth_test.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package middleware_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"git.codelab.vc/pkg/httpx/middleware"
|
||||
)
|
||||
|
||||
func TestBearerAuth(t *testing.T) {
|
||||
t.Run("sets authorization header", func(t *testing.T) {
|
||||
var captured http.Header
|
||||
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||
captured = req.Header.Clone()
|
||||
return okResponse(), nil
|
||||
})
|
||||
|
||||
tokenFunc := func(_ context.Context) (string, error) {
|
||||
return "my-secret-token", nil
|
||||
}
|
||||
|
||||
transport := middleware.BearerAuth(tokenFunc)(base)
|
||||
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||
|
||||
_, err := transport.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
want := "Bearer my-secret-token"
|
||||
if got := captured.Get("Authorization"); got != want {
|
||||
t.Errorf("Authorization = %q, want %q", got, want)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns error when tokenFunc fails", func(t *testing.T) {
|
||||
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||
t.Fatal("base transport should not be called")
|
||||
return nil, nil
|
||||
})
|
||||
|
||||
tokenErr := errors.New("token expired")
|
||||
tokenFunc := func(_ context.Context) (string, error) {
|
||||
return "", tokenErr
|
||||
}
|
||||
|
||||
transport := middleware.BearerAuth(tokenFunc)(base)
|
||||
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||
|
||||
_, err := transport.RoundTrip(req)
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if !errors.Is(err, tokenErr) {
|
||||
t.Errorf("got error %v, want %v", err, tokenErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestBasicAuth(t *testing.T) {
|
||||
t.Run("sets basic auth header", func(t *testing.T) {
|
||||
var capturedReq *http.Request
|
||||
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||
capturedReq = req
|
||||
return okResponse(), nil
|
||||
})
|
||||
|
||||
transport := middleware.BasicAuth("user", "pass")(base)
|
||||
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||
|
||||
_, err := transport.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
username, password, ok := capturedReq.BasicAuth()
|
||||
if !ok {
|
||||
t.Fatal("BasicAuth() returned ok=false")
|
||||
}
|
||||
if username != "user" {
|
||||
t.Errorf("username = %q, want %q", username, "user")
|
||||
}
|
||||
if password != "pass" {
|
||||
t.Errorf("password = %q, want %q", password, "pass")
|
||||
}
|
||||
})
|
||||
}
|
||||
28
middleware/doc.go
Normal file
28
middleware/doc.go
Normal file
@@ -0,0 +1,28 @@
|
||||
// Package middleware provides client-side HTTP middleware for use with
|
||||
// httpx.Client or any http.RoundTripper-based transport chain.
|
||||
//
|
||||
// Each middleware is a function of type func(http.RoundTripper) http.RoundTripper.
|
||||
// Compose them with Chain:
|
||||
//
|
||||
// chain := middleware.Chain(
|
||||
// middleware.Logging(logger),
|
||||
// middleware.Recovery(),
|
||||
// middleware.UserAgent("my-service/1.0"),
|
||||
// )
|
||||
// transport := chain(http.DefaultTransport)
|
||||
//
|
||||
// # Available middleware
|
||||
//
|
||||
// - Logging — structured request/response logging via slog
|
||||
// - Recovery — panic recovery, converts panics to errors
|
||||
// - DefaultHeaders — adds default headers to outgoing requests
|
||||
// - UserAgent — sets User-Agent header
|
||||
// - BearerAuth — dynamic Bearer token authentication
|
||||
// - BasicAuth — HTTP Basic authentication
|
||||
// - RequestID — propagates request ID from context to X-Request-Id header
|
||||
//
|
||||
// # RoundTripperFunc
|
||||
//
|
||||
// RoundTripperFunc adapts plain functions to http.RoundTripper, similar to
|
||||
// http.HandlerFunc. Useful for testing and inline middleware.
|
||||
package middleware
|
||||
29
middleware/headers.go
Normal file
29
middleware/headers.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package middleware
|
||||
|
||||
import "net/http"
|
||||
|
||||
// DefaultHeaders returns a middleware that adds the given headers to every
|
||||
// outgoing request. Existing headers on the request are not overwritten.
|
||||
func DefaultHeaders(headers http.Header) Middleware {
|
||||
return func(next http.RoundTripper) http.RoundTripper {
|
||||
return RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
for key, values := range headers {
|
||||
if req.Header.Get(key) != "" {
|
||||
continue
|
||||
}
|
||||
for _, v := range values {
|
||||
req.Header.Add(key, v)
|
||||
}
|
||||
}
|
||||
return next.RoundTrip(req)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// UserAgent returns a middleware that sets the User-Agent header on every
|
||||
// outgoing request, unless one is already present.
|
||||
func UserAgent(ua string) Middleware {
|
||||
return DefaultHeaders(http.Header{
|
||||
"User-Agent": {ua},
|
||||
})
|
||||
}
|
||||
107
middleware/headers_test.go
Normal file
107
middleware/headers_test.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package middleware_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"git.codelab.vc/pkg/httpx/middleware"
|
||||
)
|
||||
|
||||
func TestDefaultHeaders(t *testing.T) {
|
||||
t.Run("adds headers without overwriting existing", func(t *testing.T) {
|
||||
defaults := http.Header{
|
||||
"X-Custom": {"default-value"},
|
||||
"X-Untouched": {"from-middleware"},
|
||||
}
|
||||
|
||||
var captured http.Header
|
||||
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||
captured = req.Header.Clone()
|
||||
return okResponse(), nil
|
||||
})
|
||||
|
||||
transport := middleware.DefaultHeaders(defaults)(base)
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||
req.Header.Set("X-Custom", "request-value")
|
||||
|
||||
_, err := transport.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if got := captured.Get("X-Custom"); got != "request-value" {
|
||||
t.Errorf("X-Custom = %q, want %q (should not overwrite)", got, "request-value")
|
||||
}
|
||||
if got := captured.Get("X-Untouched"); got != "from-middleware" {
|
||||
t.Errorf("X-Untouched = %q, want %q", got, "from-middleware")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("adds headers when absent", func(t *testing.T) {
|
||||
defaults := http.Header{
|
||||
"Accept": {"application/json"},
|
||||
}
|
||||
|
||||
var captured http.Header
|
||||
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||
captured = req.Header.Clone()
|
||||
return okResponse(), nil
|
||||
})
|
||||
|
||||
transport := middleware.DefaultHeaders(defaults)(base)
|
||||
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||
|
||||
_, err := transport.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if got := captured.Get("Accept"); got != "application/json" {
|
||||
t.Errorf("Accept = %q, want %q", got, "application/json")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestUserAgent(t *testing.T) {
|
||||
t.Run("sets user agent header", func(t *testing.T) {
|
||||
var captured http.Header
|
||||
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||
captured = req.Header.Clone()
|
||||
return okResponse(), nil
|
||||
})
|
||||
|
||||
transport := middleware.UserAgent("httpx/1.0")(base)
|
||||
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||
|
||||
_, err := transport.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if got := captured.Get("User-Agent"); got != "httpx/1.0" {
|
||||
t.Errorf("User-Agent = %q, want %q", got, "httpx/1.0")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("does not overwrite existing user agent", func(t *testing.T) {
|
||||
var captured http.Header
|
||||
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||
captured = req.Header.Clone()
|
||||
return okResponse(), nil
|
||||
})
|
||||
|
||||
transport := middleware.UserAgent("httpx/1.0")(base)
|
||||
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||
req.Header.Set("User-Agent", "custom-agent")
|
||||
|
||||
_, err := transport.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if got := captured.Get("User-Agent"); got != "custom-agent" {
|
||||
t.Errorf("User-Agent = %q, want %q", got, "custom-agent")
|
||||
}
|
||||
})
|
||||
}
|
||||
38
middleware/logging.go
Normal file
38
middleware/logging.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Logging returns a middleware that logs each request's method, URL, status
|
||||
// code, duration, and error (if any) using the provided structured logger.
|
||||
// Successful responses are logged at Info level; errors at Error level.
|
||||
func Logging(logger *slog.Logger) Middleware {
|
||||
return func(next http.RoundTripper) http.RoundTripper {
|
||||
return RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
start := time.Now()
|
||||
|
||||
resp, err := next.RoundTrip(req)
|
||||
|
||||
duration := time.Since(start)
|
||||
attrs := []slog.Attr{
|
||||
slog.String("method", req.Method),
|
||||
slog.String("url", req.URL.String()),
|
||||
slog.Duration("duration", duration),
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
attrs = append(attrs, slog.String("error", err.Error()))
|
||||
logger.LogAttrs(req.Context(), slog.LevelError, "request failed", attrs...)
|
||||
return resp, err
|
||||
}
|
||||
|
||||
attrs = append(attrs, slog.Int("status", resp.StatusCode))
|
||||
logger.LogAttrs(req.Context(), slog.LevelInfo, "request completed", attrs...)
|
||||
|
||||
return resp, nil
|
||||
})
|
||||
}
|
||||
}
|
||||
124
middleware/logging_test.go
Normal file
124
middleware/logging_test.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package middleware_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"git.codelab.vc/pkg/httpx/middleware"
|
||||
)
|
||||
|
||||
// captureHandler is a slog.Handler that captures log records for inspection.
|
||||
type captureHandler struct {
|
||||
mu sync.Mutex
|
||||
records []slog.Record
|
||||
}
|
||||
|
||||
func (h *captureHandler) Enabled(_ context.Context, _ slog.Level) bool { return true }
|
||||
|
||||
func (h *captureHandler) Handle(_ context.Context, r slog.Record) error {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
h.records = append(h.records, r)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *captureHandler) WithAttrs(_ []slog.Attr) slog.Handler { return h }
|
||||
func (h *captureHandler) WithGroup(_ string) slog.Handler { return h }
|
||||
|
||||
func (h *captureHandler) lastRecord() slog.Record {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
return h.records[len(h.records)-1]
|
||||
}
|
||||
|
||||
func TestLogging(t *testing.T) {
|
||||
t.Run("logs method url status duration on success", func(t *testing.T) {
|
||||
handler := &captureHandler{}
|
||||
logger := slog.New(handler)
|
||||
|
||||
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader("ok")),
|
||||
Header: make(http.Header),
|
||||
}, nil
|
||||
})
|
||||
|
||||
transport := middleware.Logging(logger)(base)
|
||||
req, _ := http.NewRequest(http.MethodPost, "http://example.com/api", nil)
|
||||
resp, err := transport.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("got status %d, want %d", resp.StatusCode, http.StatusOK)
|
||||
}
|
||||
|
||||
rec := handler.lastRecord()
|
||||
if rec.Level != slog.LevelInfo {
|
||||
t.Errorf("got level %v, want %v", rec.Level, slog.LevelInfo)
|
||||
}
|
||||
|
||||
attrs := map[string]string{}
|
||||
rec.Attrs(func(a slog.Attr) bool {
|
||||
attrs[a.Key] = a.Value.String()
|
||||
return true
|
||||
})
|
||||
|
||||
if attrs["method"] != "POST" {
|
||||
t.Errorf("method = %q, want %q", attrs["method"], "POST")
|
||||
}
|
||||
if attrs["url"] != "http://example.com/api" {
|
||||
t.Errorf("url = %q, want %q", attrs["url"], "http://example.com/api")
|
||||
}
|
||||
if _, ok := attrs["status"]; !ok {
|
||||
t.Error("missing status attribute")
|
||||
}
|
||||
if _, ok := attrs["duration"]; !ok {
|
||||
t.Error("missing duration attribute")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("logs error on failure", func(t *testing.T) {
|
||||
handler := &captureHandler{}
|
||||
logger := slog.New(handler)
|
||||
|
||||
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||
return nil, errors.New("connection refused")
|
||||
})
|
||||
|
||||
transport := middleware.Logging(logger)(base)
|
||||
req, _ := http.NewRequest(http.MethodGet, "http://example.com/fail", nil)
|
||||
_, err := transport.RoundTrip(req)
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
|
||||
rec := handler.lastRecord()
|
||||
if rec.Level != slog.LevelError {
|
||||
t.Errorf("got level %v, want %v", rec.Level, slog.LevelError)
|
||||
}
|
||||
|
||||
attrs := map[string]string{}
|
||||
rec.Attrs(func(a slog.Attr) bool {
|
||||
attrs[a.Key] = a.Value.String()
|
||||
return true
|
||||
})
|
||||
|
||||
if attrs["error"] != "connection refused" {
|
||||
t.Errorf("error = %q, want %q", attrs["error"], "connection refused")
|
||||
}
|
||||
if _, ok := attrs["method"]; !ok {
|
||||
t.Error("missing method attribute")
|
||||
}
|
||||
if _, ok := attrs["url"]; !ok {
|
||||
t.Error("missing url attribute")
|
||||
}
|
||||
})
|
||||
}
|
||||
29
middleware/middleware.go
Normal file
29
middleware/middleware.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package middleware
|
||||
|
||||
import "net/http"
|
||||
|
||||
// Middleware wraps an http.RoundTripper to add behavior.
|
||||
// This is the fundamental building block of the httpx library.
|
||||
type Middleware func(http.RoundTripper) http.RoundTripper
|
||||
|
||||
// Chain composes middlewares so that Chain(A, B, C)(base) == A(B(C(base))).
|
||||
// Middlewares are applied from right to left: C wraps base first, then B wraps
|
||||
// the result, then A wraps last. This means A is the outermost layer and sees
|
||||
// every request first.
|
||||
func Chain(mws ...Middleware) Middleware {
|
||||
return func(rt http.RoundTripper) http.RoundTripper {
|
||||
for i := len(mws) - 1; i >= 0; i-- {
|
||||
rt = mws[i](rt)
|
||||
}
|
||||
return rt
|
||||
}
|
||||
}
|
||||
|
||||
// RoundTripperFunc is an adapter to allow the use of ordinary functions as
|
||||
// http.RoundTripper. It works exactly like http.HandlerFunc for handlers.
|
||||
type RoundTripperFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
// RoundTrip implements http.RoundTripper.
|
||||
func (f RoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return f(req)
|
||||
}
|
||||
115
middleware/middleware_test.go
Normal file
115
middleware/middleware_test.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package middleware_test
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"git.codelab.vc/pkg/httpx/middleware"
|
||||
)
|
||||
|
||||
func mockTransport(fn func(*http.Request) (*http.Response, error)) http.RoundTripper {
|
||||
return middleware.RoundTripperFunc(fn)
|
||||
}
|
||||
|
||||
func okResponse() *http.Response {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader("ok")),
|
||||
Header: make(http.Header),
|
||||
}
|
||||
}
|
||||
|
||||
func TestChain(t *testing.T) {
|
||||
t.Run("applies middlewares in correct order", func(t *testing.T) {
|
||||
var order []string
|
||||
|
||||
mwA := func(next http.RoundTripper) http.RoundTripper {
|
||||
return mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||
order = append(order, "A-before")
|
||||
resp, err := next.RoundTrip(req)
|
||||
order = append(order, "A-after")
|
||||
return resp, err
|
||||
})
|
||||
}
|
||||
|
||||
mwB := func(next http.RoundTripper) http.RoundTripper {
|
||||
return mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||
order = append(order, "B-before")
|
||||
resp, err := next.RoundTrip(req)
|
||||
order = append(order, "B-after")
|
||||
return resp, err
|
||||
})
|
||||
}
|
||||
|
||||
mwC := func(next http.RoundTripper) http.RoundTripper {
|
||||
return mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||
order = append(order, "C-before")
|
||||
resp, err := next.RoundTrip(req)
|
||||
order = append(order, "C-after")
|
||||
return resp, err
|
||||
})
|
||||
}
|
||||
|
||||
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||
order = append(order, "base")
|
||||
return okResponse(), nil
|
||||
})
|
||||
|
||||
chained := middleware.Chain(mwA, mwB, mwC)(base)
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||
_, err := chained.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
expected := []string{"A-before", "B-before", "C-before", "base", "C-after", "B-after", "A-after"}
|
||||
if len(order) != len(expected) {
|
||||
t.Fatalf("got %v, want %v", order, expected)
|
||||
}
|
||||
for i, v := range expected {
|
||||
if order[i] != v {
|
||||
t.Fatalf("order[%d] = %q, want %q", i, order[i], v)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty chain returns base transport", func(t *testing.T) {
|
||||
called := false
|
||||
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||
called = true
|
||||
return okResponse(), nil
|
||||
})
|
||||
|
||||
chained := middleware.Chain()(base)
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||
_, err := chained.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !called {
|
||||
t.Fatal("base transport was not called")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRoundTripperFunc(t *testing.T) {
|
||||
t.Run("implements RoundTripper", func(t *testing.T) {
|
||||
fn := middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
return okResponse(), nil
|
||||
})
|
||||
|
||||
var rt http.RoundTripper = fn
|
||||
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||
resp, err := rt.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("got status %d, want %d", resp.StatusCode, http.StatusOK)
|
||||
}
|
||||
})
|
||||
}
|
||||
22
middleware/recovery.go
Normal file
22
middleware/recovery.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// Recovery returns a middleware that recovers from panics in the inner
|
||||
// RoundTripper chain. A recovered panic is converted to an error wrapping
|
||||
// the panic value.
|
||||
func Recovery() Middleware {
|
||||
return func(next http.RoundTripper) http.RoundTripper {
|
||||
return RoundTripperFunc(func(req *http.Request) (resp *http.Response, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("panic recovered in round trip: %v", r)
|
||||
}
|
||||
}()
|
||||
return next.RoundTrip(req)
|
||||
})
|
||||
}
|
||||
}
|
||||
51
middleware/recovery_test.go
Normal file
51
middleware/recovery_test.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package middleware_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"git.codelab.vc/pkg/httpx/middleware"
|
||||
)
|
||||
|
||||
func TestRecovery(t *testing.T) {
|
||||
t.Run("recovers from panic and returns error", func(t *testing.T) {
|
||||
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||
panic("something went wrong")
|
||||
})
|
||||
|
||||
transport := middleware.Recovery()(base)
|
||||
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||
|
||||
resp, err := transport.RoundTrip(req)
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if resp != nil {
|
||||
t.Errorf("expected nil response, got %v", resp)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "panic recovered") {
|
||||
t.Errorf("error = %q, want it to contain %q", err.Error(), "panic recovered")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "something went wrong") {
|
||||
t.Errorf("error = %q, want it to contain %q", err.Error(), "something went wrong")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("passes through normal responses", func(t *testing.T) {
|
||||
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||
return okResponse(), nil
|
||||
})
|
||||
|
||||
transport := middleware.Recovery()(base)
|
||||
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||
|
||||
resp, err := transport.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("got status %d, want %d", resp.StatusCode, http.StatusOK)
|
||||
}
|
||||
})
|
||||
}
|
||||
23
middleware/requestid.go
Normal file
23
middleware/requestid.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"git.codelab.vc/pkg/httpx/internal/requestid"
|
||||
)
|
||||
|
||||
// RequestID returns a middleware that propagates the request ID from the
|
||||
// request context to the outgoing X-Request-Id header. This pairs with
|
||||
// the server.RequestID middleware: the server stores the ID in the context,
|
||||
// and the client middleware forwards it to downstream services.
|
||||
func RequestID() Middleware {
|
||||
return func(next http.RoundTripper) http.RoundTripper {
|
||||
return RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
if id := requestid.FromContext(req.Context()); id != "" {
|
||||
req = req.Clone(req.Context())
|
||||
req.Header.Set("X-Request-Id", id)
|
||||
}
|
||||
return next.RoundTrip(req)
|
||||
})
|
||||
}
|
||||
}
|
||||
69
middleware/requestid_test.go
Normal file
69
middleware/requestid_test.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package middleware_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"git.codelab.vc/pkg/httpx/internal/requestid"
|
||||
"git.codelab.vc/pkg/httpx/middleware"
|
||||
)
|
||||
|
||||
func TestRequestID(t *testing.T) {
|
||||
t.Run("propagates ID from context", func(t *testing.T) {
|
||||
var gotHeader string
|
||||
base := middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
gotHeader = req.Header.Get("X-Request-Id")
|
||||
return &http.Response{StatusCode: http.StatusOK, Body: http.NoBody}, nil
|
||||
})
|
||||
|
||||
mw := middleware.RequestID()(base)
|
||||
|
||||
ctx := requestid.NewContext(context.Background(), "test-id-123")
|
||||
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "http://example.com", nil)
|
||||
_, err := mw.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if gotHeader != "test-id-123" {
|
||||
t.Fatalf("X-Request-Id = %q, want %q", gotHeader, "test-id-123")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no ID in context skips header", func(t *testing.T) {
|
||||
var gotHeader string
|
||||
base := middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
gotHeader = req.Header.Get("X-Request-Id")
|
||||
return &http.Response{StatusCode: http.StatusOK, Body: http.NoBody}, nil
|
||||
})
|
||||
|
||||
mw := middleware.RequestID()(base)
|
||||
|
||||
req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://example.com", nil)
|
||||
_, err := mw.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if gotHeader != "" {
|
||||
t.Fatalf("expected no X-Request-Id header, got %q", gotHeader)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("does not mutate original request", func(t *testing.T) {
|
||||
base := middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
return &http.Response{StatusCode: http.StatusOK, Body: http.NoBody}, nil
|
||||
})
|
||||
|
||||
mw := middleware.RequestID()(base)
|
||||
|
||||
ctx := requestid.NewContext(context.Background(), "test-id")
|
||||
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "http://example.com", nil)
|
||||
_, _ = mw.RoundTrip(req)
|
||||
|
||||
if req.Header.Get("X-Request-Id") != "" {
|
||||
t.Fatal("original request was mutated")
|
||||
}
|
||||
})
|
||||
}
|
||||
52
request.go
Normal file
52
request.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package httpx
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
// NewRequest creates an http.Request with context. It is a convenience
|
||||
// wrapper around http.NewRequestWithContext.
|
||||
func NewRequest(ctx context.Context, method, url string, body io.Reader) (*http.Request, error) {
|
||||
return http.NewRequestWithContext(ctx, method, url, body)
|
||||
}
|
||||
|
||||
// NewJSONRequest creates an http.Request with a JSON-encoded body and
|
||||
// sets Content-Type to application/json.
|
||||
func NewJSONRequest(ctx context.Context, method, url string, body any) (*http.Request, error) {
|
||||
b, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("httpx: encoding JSON body: %w", err)
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, method, url, bytes.NewReader(b))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.GetBody = func() (io.ReadCloser, error) {
|
||||
return io.NopCloser(bytes.NewReader(b)), nil
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// NewFormRequest creates an http.Request with a form-encoded body and
|
||||
// sets Content-Type to application/x-www-form-urlencoded.
|
||||
// The GetBody function is set so that the request can be retried.
|
||||
func NewFormRequest(ctx context.Context, method, rawURL string, values url.Values) (*http.Request, error) {
|
||||
encoded := values.Encode()
|
||||
b := []byte(encoded)
|
||||
req, err := http.NewRequestWithContext(ctx, method, rawURL, bytes.NewReader(b))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.GetBody = func() (io.ReadCloser, error) {
|
||||
return io.NopCloser(bytes.NewReader(b)), nil
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
80
request_form_test.go
Normal file
80
request_form_test.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package httpx_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"git.codelab.vc/pkg/httpx"
|
||||
)
|
||||
|
||||
func TestNewFormRequest(t *testing.T) {
|
||||
t.Run("body is form-encoded", func(t *testing.T) {
|
||||
values := url.Values{"username": {"alice"}, "scope": {"read"}}
|
||||
req, err := httpx.NewFormRequest(context.Background(), http.MethodPost, "http://example.com/token", values)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("reading body: %v", err)
|
||||
}
|
||||
|
||||
parsed, err := url.ParseQuery(string(body))
|
||||
if err != nil {
|
||||
t.Fatalf("parsing form: %v", err)
|
||||
}
|
||||
if parsed.Get("username") != "alice" {
|
||||
t.Errorf("username = %q, want %q", parsed.Get("username"), "alice")
|
||||
}
|
||||
if parsed.Get("scope") != "read" {
|
||||
t.Errorf("scope = %q, want %q", parsed.Get("scope"), "read")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("content type is set", func(t *testing.T) {
|
||||
values := url.Values{"key": {"value"}}
|
||||
req, err := httpx.NewFormRequest(context.Background(), http.MethodPost, "http://example.com", values)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
ct := req.Header.Get("Content-Type")
|
||||
if ct != "application/x-www-form-urlencoded" {
|
||||
t.Errorf("Content-Type = %q, want %q", ct, "application/x-www-form-urlencoded")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetBody works for retry", func(t *testing.T) {
|
||||
values := url.Values{"key": {"value"}}
|
||||
req, err := httpx.NewFormRequest(context.Background(), http.MethodPost, "http://example.com", values)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if req.GetBody == nil {
|
||||
t.Fatal("GetBody is nil")
|
||||
}
|
||||
|
||||
b1, err := io.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("reading body: %v", err)
|
||||
}
|
||||
|
||||
body2, err := req.GetBody()
|
||||
if err != nil {
|
||||
t.Fatalf("GetBody(): %v", err)
|
||||
}
|
||||
b2, err := io.ReadAll(body2)
|
||||
if err != nil {
|
||||
t.Fatalf("reading body2: %v", err)
|
||||
}
|
||||
|
||||
if string(b1) != string(b2) {
|
||||
t.Errorf("GetBody returned different data: %q vs %q", b1, b2)
|
||||
}
|
||||
})
|
||||
}
|
||||
78
request_test.go
Normal file
78
request_test.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package httpx_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"git.codelab.vc/pkg/httpx"
|
||||
)
|
||||
|
||||
func TestNewJSONRequest(t *testing.T) {
|
||||
t.Run("body is JSON encoded", func(t *testing.T) {
|
||||
payload := map[string]string{"key": "value"}
|
||||
req, err := httpx.NewJSONRequest(context.Background(), http.MethodPost, "http://example.com", payload)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("reading body: %v", err)
|
||||
}
|
||||
|
||||
var decoded map[string]string
|
||||
if err := json.Unmarshal(body, &decoded); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if decoded["key"] != "value" {
|
||||
t.Errorf("decoded[key] = %q, want %q", decoded["key"], "value")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("content type is set", func(t *testing.T) {
|
||||
req, err := httpx.NewJSONRequest(context.Background(), http.MethodPost, "http://example.com", "test")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
ct := req.Header.Get("Content-Type")
|
||||
if ct != "application/json" {
|
||||
t.Errorf("Content-Type = %q, want %q", ct, "application/json")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetBody works", func(t *testing.T) {
|
||||
payload := map[string]int{"num": 123}
|
||||
req, err := httpx.NewJSONRequest(context.Background(), http.MethodPost, "http://example.com", payload)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if req.GetBody == nil {
|
||||
t.Fatal("GetBody is nil")
|
||||
}
|
||||
|
||||
// Read body first time
|
||||
b1, err := io.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("reading body: %v", err)
|
||||
}
|
||||
|
||||
// Get a fresh body
|
||||
body2, err := req.GetBody()
|
||||
if err != nil {
|
||||
t.Fatalf("GetBody(): %v", err)
|
||||
}
|
||||
b2, err := io.ReadAll(body2)
|
||||
if err != nil {
|
||||
t.Fatalf("reading body2: %v", err)
|
||||
}
|
||||
|
||||
if string(b1) != string(b2) {
|
||||
t.Errorf("GetBody returned different data: %q vs %q", b1, b2)
|
||||
}
|
||||
})
|
||||
}
|
||||
114
response.go
Normal file
114
response.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package httpx
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// Response wraps http.Response with convenience methods.
|
||||
type Response struct {
|
||||
*http.Response
|
||||
body []byte
|
||||
read bool
|
||||
}
|
||||
|
||||
func newResponse(resp *http.Response) *Response {
|
||||
return &Response{Response: resp}
|
||||
}
|
||||
|
||||
// Bytes reads and returns the entire response body.
|
||||
// The body is cached so subsequent calls return the same data.
|
||||
func (r *Response) Bytes() ([]byte, error) {
|
||||
if r.read {
|
||||
return r.body, nil
|
||||
}
|
||||
defer r.Body.Close()
|
||||
b, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r.body = b
|
||||
r.read = true
|
||||
return b, nil
|
||||
}
|
||||
|
||||
// String reads the response body and returns it as a string.
|
||||
func (r *Response) String() (string, error) {
|
||||
b, err := r.Bytes()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
// JSON decodes the response body as JSON into v.
|
||||
func (r *Response) JSON(v any) error {
|
||||
b, err := r.Bytes()
|
||||
if err != nil {
|
||||
return fmt.Errorf("httpx: reading response body: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(b, v); err != nil {
|
||||
return fmt.Errorf("httpx: decoding JSON: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// XML decodes the response body as XML into v.
|
||||
func (r *Response) XML(v any) error {
|
||||
b, err := r.Bytes()
|
||||
if err != nil {
|
||||
return fmt.Errorf("httpx: reading response body: %w", err)
|
||||
}
|
||||
if err := xml.Unmarshal(b, v); err != nil {
|
||||
return fmt.Errorf("httpx: decoding XML: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsSuccess returns true if the status code is in the 2xx range.
|
||||
func (r *Response) IsSuccess() bool {
|
||||
return r.StatusCode >= 200 && r.StatusCode < 300
|
||||
}
|
||||
|
||||
// IsError returns true if the status code is 4xx or 5xx.
|
||||
func (r *Response) IsError() bool {
|
||||
return r.StatusCode >= 400
|
||||
}
|
||||
|
||||
// Close drains and closes the response body.
|
||||
func (r *Response) Close() error {
|
||||
if r.read {
|
||||
return nil
|
||||
}
|
||||
_, _ = io.Copy(io.Discard, r.Body)
|
||||
return r.Body.Close()
|
||||
}
|
||||
|
||||
// BodyReader returns a reader for the response body.
|
||||
// If the body has already been read via Bytes/String/JSON/XML,
|
||||
// returns a reader over the cached bytes.
|
||||
func (r *Response) BodyReader() io.Reader {
|
||||
if r.read {
|
||||
return bytes.NewReader(r.body)
|
||||
}
|
||||
return r.Body
|
||||
}
|
||||
|
||||
// limitedReadCloser wraps an io.LimitedReader with a separate Closer
|
||||
// so the original body can be closed.
|
||||
type limitedReadCloser struct {
|
||||
R io.LimitedReader
|
||||
C io.Closer
|
||||
}
|
||||
|
||||
func (l *limitedReadCloser) Read(p []byte) (int, error) {
|
||||
return l.R.Read(p)
|
||||
}
|
||||
|
||||
func (l *limitedReadCloser) Close() error {
|
||||
return l.C.Close()
|
||||
}
|
||||
76
response_limit_test.go
Normal file
76
response_limit_test.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package httpx_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"git.codelab.vc/pkg/httpx"
|
||||
)
|
||||
|
||||
func TestClient_MaxResponseBody(t *testing.T) {
|
||||
t.Run("allows response within limit", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
fmt.Fprint(w, "hello")
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
client := httpx.New(httpx.WithMaxResponseBody(1024))
|
||||
resp, err := client.Get(context.Background(), srv.URL+"/")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
body, err := resp.String()
|
||||
if err != nil {
|
||||
t.Fatalf("reading body: %v", err)
|
||||
}
|
||||
if body != "hello" {
|
||||
t.Fatalf("body = %q, want %q", body, "hello")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("truncates response exceeding limit", func(t *testing.T) {
|
||||
largeBody := strings.Repeat("x", 1000)
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
fmt.Fprint(w, largeBody)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
client := httpx.New(httpx.WithMaxResponseBody(100))
|
||||
resp, err := client.Get(context.Background(), srv.URL+"/")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
b, err := resp.Bytes()
|
||||
if err != nil {
|
||||
t.Fatalf("reading body: %v", err)
|
||||
}
|
||||
if len(b) != 100 {
|
||||
t.Fatalf("body length = %d, want %d", len(b), 100)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no limit when zero", func(t *testing.T) {
|
||||
largeBody := strings.Repeat("x", 10000)
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
fmt.Fprint(w, largeBody)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
client := httpx.New()
|
||||
resp, err := client.Get(context.Background(), srv.URL+"/")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
b, err := resp.Bytes()
|
||||
if err != nil {
|
||||
t.Fatalf("reading body: %v", err)
|
||||
}
|
||||
if len(b) != 10000 {
|
||||
t.Fatalf("body length = %d, want %d", len(b), 10000)
|
||||
}
|
||||
})
|
||||
}
|
||||
96
response_test.go
Normal file
96
response_test.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package httpx
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func makeTestResponse(statusCode int, body string) *Response {
|
||||
return newResponse(&http.Response{
|
||||
StatusCode: statusCode,
|
||||
Body: io.NopCloser(strings.NewReader(body)),
|
||||
Header: make(http.Header),
|
||||
})
|
||||
}
|
||||
|
||||
func TestResponse(t *testing.T) {
|
||||
t.Run("Bytes returns body", func(t *testing.T) {
|
||||
r := makeTestResponse(200, "hello world")
|
||||
b, err := r.Bytes()
|
||||
if err != nil {
|
||||
t.Fatalf("Bytes() error: %v", err)
|
||||
}
|
||||
if string(b) != "hello world" {
|
||||
t.Errorf("Bytes() = %q, want %q", string(b), "hello world")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("body caching returns same data", func(t *testing.T) {
|
||||
r := makeTestResponse(200, "cached body")
|
||||
b1, err := r.Bytes()
|
||||
if err != nil {
|
||||
t.Fatalf("first Bytes() error: %v", err)
|
||||
}
|
||||
b2, err := r.Bytes()
|
||||
if err != nil {
|
||||
t.Fatalf("second Bytes() error: %v", err)
|
||||
}
|
||||
if string(b1) != string(b2) {
|
||||
t.Errorf("Bytes() returned different data: %q vs %q", b1, b2)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("String returns body as string", func(t *testing.T) {
|
||||
r := makeTestResponse(200, "string body")
|
||||
s, err := r.String()
|
||||
if err != nil {
|
||||
t.Fatalf("String() error: %v", err)
|
||||
}
|
||||
if s != "string body" {
|
||||
t.Errorf("String() = %q, want %q", s, "string body")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("JSON decodes body", func(t *testing.T) {
|
||||
r := makeTestResponse(200, `{"name":"test","value":42}`)
|
||||
var result struct {
|
||||
Name string `json:"name"`
|
||||
Value int `json:"value"`
|
||||
}
|
||||
if err := r.JSON(&result); err != nil {
|
||||
t.Fatalf("JSON() error: %v", err)
|
||||
}
|
||||
if result.Name != "test" {
|
||||
t.Errorf("Name = %q, want %q", result.Name, "test")
|
||||
}
|
||||
if result.Value != 42 {
|
||||
t.Errorf("Value = %d, want %d", result.Value, 42)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("IsSuccess for 2xx", func(t *testing.T) {
|
||||
for _, code := range []int{200, 201, 204, 299} {
|
||||
r := makeTestResponse(code, "")
|
||||
if !r.IsSuccess() {
|
||||
t.Errorf("IsSuccess() = false for status %d", code)
|
||||
}
|
||||
if r.IsError() {
|
||||
t.Errorf("IsError() = true for status %d", code)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("IsError for 4xx and 5xx", func(t *testing.T) {
|
||||
for _, code := range []int{400, 404, 500, 503} {
|
||||
r := makeTestResponse(code, "")
|
||||
if !r.IsError() {
|
||||
t.Errorf("IsError() = false for status %d", code)
|
||||
}
|
||||
if r.IsSuccess() {
|
||||
t.Errorf("IsSuccess() = true for status %d", code)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
64
retry/backoff.go
Normal file
64
retry/backoff.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package retry
|
||||
|
||||
import (
|
||||
"math/rand/v2"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Backoff computes the delay before the next retry attempt.
|
||||
type Backoff interface {
|
||||
// Delay returns the wait duration for the given attempt number (zero-based).
|
||||
Delay(attempt int) time.Duration
|
||||
}
|
||||
|
||||
// ExponentialBackoff returns a Backoff that doubles the delay on each attempt.
|
||||
// The delay is calculated as base * 2^attempt, capped at max. When withJitter
|
||||
// is true, a random duration in [0, delay*0.5) is added.
|
||||
func ExponentialBackoff(base, max time.Duration, withJitter bool) Backoff {
|
||||
return &exponentialBackoff{
|
||||
base: base,
|
||||
max: max,
|
||||
withJitter: withJitter,
|
||||
}
|
||||
}
|
||||
|
||||
// ConstantBackoff returns a Backoff that always returns the same delay.
|
||||
func ConstantBackoff(d time.Duration) Backoff {
|
||||
return constantBackoff{delay: d}
|
||||
}
|
||||
|
||||
type exponentialBackoff struct {
|
||||
base time.Duration
|
||||
max time.Duration
|
||||
withJitter bool
|
||||
}
|
||||
|
||||
func (b *exponentialBackoff) Delay(attempt int) time.Duration {
|
||||
delay := b.base
|
||||
for range attempt {
|
||||
delay *= 2
|
||||
if delay >= b.max {
|
||||
delay = b.max
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if b.withJitter {
|
||||
jitter := time.Duration(rand.Int64N(int64(delay / 2)))
|
||||
delay += jitter
|
||||
}
|
||||
|
||||
if delay > b.max {
|
||||
delay = b.max
|
||||
}
|
||||
|
||||
return delay
|
||||
}
|
||||
|
||||
type constantBackoff struct {
|
||||
delay time.Duration
|
||||
}
|
||||
|
||||
func (b constantBackoff) Delay(_ int) time.Duration {
|
||||
return b.delay
|
||||
}
|
||||
77
retry/backoff_test.go
Normal file
77
retry/backoff_test.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package retry
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestExponentialBackoff(t *testing.T) {
|
||||
t.Run("doubles each attempt", func(t *testing.T) {
|
||||
b := ExponentialBackoff(100*time.Millisecond, 10*time.Second, false)
|
||||
|
||||
want := []time.Duration{
|
||||
100 * time.Millisecond, // attempt 0: base
|
||||
200 * time.Millisecond, // attempt 1: base*2
|
||||
400 * time.Millisecond, // attempt 2: base*4
|
||||
800 * time.Millisecond, // attempt 3: base*8
|
||||
1600 * time.Millisecond, // attempt 4: base*16
|
||||
}
|
||||
|
||||
for i, expected := range want {
|
||||
got := b.Delay(i)
|
||||
if got != expected {
|
||||
t.Errorf("attempt %d: expected %v, got %v", i, expected, got)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("caps at max", func(t *testing.T) {
|
||||
b := ExponentialBackoff(100*time.Millisecond, 500*time.Millisecond, false)
|
||||
|
||||
// attempt 0: 100ms, 1: 200ms, 2: 400ms, 3: 500ms (capped), 4: 500ms
|
||||
for _, attempt := range []int{3, 4, 10} {
|
||||
got := b.Delay(attempt)
|
||||
if got != 500*time.Millisecond {
|
||||
t.Errorf("attempt %d: expected cap at 500ms, got %v", attempt, got)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with jitter adds randomness", func(t *testing.T) {
|
||||
base := 100 * time.Millisecond
|
||||
b := ExponentialBackoff(base, 10*time.Second, true)
|
||||
|
||||
// Run multiple times; with jitter, delay >= base for attempt 0.
|
||||
// Also verify not all values are identical (randomness).
|
||||
seen := make(map[time.Duration]bool)
|
||||
for range 20 {
|
||||
d := b.Delay(0)
|
||||
if d < base {
|
||||
t.Fatalf("delay %v is less than base %v", d, base)
|
||||
}
|
||||
// With jitter: delay = base + rand in [0, base/2), so max is base*1.5
|
||||
maxExpected := base + base/2
|
||||
if d > maxExpected {
|
||||
t.Fatalf("delay %v exceeds expected max %v", d, maxExpected)
|
||||
}
|
||||
seen[d] = true
|
||||
}
|
||||
if len(seen) < 2 {
|
||||
t.Errorf("expected jitter to produce varying delays, got %d unique values", len(seen))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestConstantBackoff(t *testing.T) {
|
||||
t.Run("always returns same value", func(t *testing.T) {
|
||||
d := 250 * time.Millisecond
|
||||
b := ConstantBackoff(d)
|
||||
|
||||
for _, attempt := range []int{0, 1, 2, 5, 100} {
|
||||
got := b.Delay(attempt)
|
||||
if got != d {
|
||||
t.Errorf("attempt %d: expected %v, got %v", attempt, d, got)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
31
retry/doc.go
Normal file
31
retry/doc.go
Normal file
@@ -0,0 +1,31 @@
|
||||
// Package retry provides configurable HTTP request retry as client middleware.
|
||||
//
|
||||
// The retry middleware wraps an http.RoundTripper and automatically retries
|
||||
// failed requests based on a configurable policy, with exponential backoff
|
||||
// and optional jitter.
|
||||
//
|
||||
// # Usage
|
||||
//
|
||||
// mw := retry.Transport(
|
||||
// retry.WithMaxAttempts(3),
|
||||
// retry.WithBackoff(retry.ExponentialBackoff(100*time.Millisecond, 5*time.Second)),
|
||||
// )
|
||||
// transport := mw(http.DefaultTransport)
|
||||
//
|
||||
// # Retry-After
|
||||
//
|
||||
// The retry middleware respects the Retry-After response header. If a server
|
||||
// returns 429 or 503 with Retry-After, the delay from the header overrides
|
||||
// the backoff strategy.
|
||||
//
|
||||
// # Request bodies
|
||||
//
|
||||
// For requests with bodies to be retried, the request must have GetBody set.
|
||||
// Use httpx.NewJSONRequest or httpx.NewFormRequest which set GetBody
|
||||
// automatically.
|
||||
//
|
||||
// # Sentinel errors
|
||||
//
|
||||
// ErrRetryExhausted is returned when all attempts fail. The original error
|
||||
// is wrapped and accessible via errors.Unwrap.
|
||||
package retry
|
||||
56
retry/options.go
Normal file
56
retry/options.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package retry
|
||||
|
||||
import "time"
|
||||
|
||||
type options struct {
|
||||
maxAttempts int // default 3
|
||||
backoff Backoff // default ExponentialBackoff(100ms, 5s, true)
|
||||
policy Policy // default: defaultPolicy (retry on 5xx and network errors)
|
||||
retryAfter bool // default true, respect Retry-After header
|
||||
}
|
||||
|
||||
// Option configures the retry transport.
|
||||
type Option func(*options)
|
||||
|
||||
func defaults() options {
|
||||
return options{
|
||||
maxAttempts: 3,
|
||||
backoff: ExponentialBackoff(100*time.Millisecond, 5*time.Second, true),
|
||||
policy: defaultPolicy{},
|
||||
retryAfter: true,
|
||||
}
|
||||
}
|
||||
|
||||
// WithMaxAttempts sets the maximum number of attempts (including the first).
|
||||
// Values less than 1 are treated as 1 (no retries).
|
||||
func WithMaxAttempts(n int) Option {
|
||||
return func(o *options) {
|
||||
if n < 1 {
|
||||
n = 1
|
||||
}
|
||||
o.maxAttempts = n
|
||||
}
|
||||
}
|
||||
|
||||
// WithBackoff sets the backoff strategy used to compute delays between retries.
|
||||
func WithBackoff(b Backoff) Option {
|
||||
return func(o *options) {
|
||||
o.backoff = b
|
||||
}
|
||||
}
|
||||
|
||||
// WithPolicy sets the retry policy that decides whether to retry a request.
|
||||
func WithPolicy(p Policy) Option {
|
||||
return func(o *options) {
|
||||
o.policy = p
|
||||
}
|
||||
}
|
||||
|
||||
// WithRetryAfter controls whether the Retry-After response header is respected.
|
||||
// When enabled and present, the Retry-After delay is used if it exceeds the
|
||||
// backoff delay.
|
||||
func WithRetryAfter(enable bool) Option {
|
||||
return func(o *options) {
|
||||
o.retryAfter = enable
|
||||
}
|
||||
}
|
||||
137
retry/retry.go
Normal file
137
retry/retry.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package retry
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"git.codelab.vc/pkg/httpx/middleware"
|
||||
)
|
||||
|
||||
// ErrRetryExhausted is returned when all retry attempts have been exhausted
|
||||
// and the last attempt also failed.
|
||||
var ErrRetryExhausted = errors.New("httpx: all retry attempts exhausted")
|
||||
|
||||
// Policy decides whether a failed request should be retried.
|
||||
type Policy interface {
|
||||
// ShouldRetry reports whether the request should be retried. The extra
|
||||
// duration, if non-zero, is a policy-suggested delay that overrides the
|
||||
// backoff strategy.
|
||||
ShouldRetry(attempt int, req *http.Request, resp *http.Response, err error) (bool, time.Duration)
|
||||
}
|
||||
|
||||
// Transport returns a middleware that retries failed requests according to
|
||||
// the provided options.
|
||||
func Transport(opts ...Option) middleware.Middleware {
|
||||
cfg := defaults()
|
||||
for _, o := range opts {
|
||||
o(&cfg)
|
||||
}
|
||||
|
||||
return func(next http.RoundTripper) http.RoundTripper {
|
||||
return middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
var resp *http.Response
|
||||
var err error
|
||||
var exhausted bool
|
||||
|
||||
for attempt := range cfg.maxAttempts {
|
||||
// For retries (attempt > 0), restore the request body.
|
||||
if attempt > 0 {
|
||||
if req.GetBody != nil {
|
||||
body, bodyErr := req.GetBody()
|
||||
if bodyErr != nil {
|
||||
return resp, bodyErr
|
||||
}
|
||||
req.Body = body
|
||||
} else if req.Body != nil {
|
||||
// Body was consumed and cannot be re-created.
|
||||
return resp, err
|
||||
}
|
||||
}
|
||||
|
||||
resp, err = next.RoundTrip(req)
|
||||
|
||||
// Last attempt — return whatever we got.
|
||||
if attempt == cfg.maxAttempts-1 {
|
||||
exhausted = true
|
||||
break
|
||||
}
|
||||
|
||||
shouldRetry, policyDelay := cfg.policy.ShouldRetry(attempt, req, resp, err)
|
||||
if !shouldRetry {
|
||||
break
|
||||
}
|
||||
|
||||
// Compute delay: use backoff or policy delay, whichever is larger.
|
||||
delay := cfg.backoff.Delay(attempt)
|
||||
if policyDelay > delay {
|
||||
delay = policyDelay
|
||||
}
|
||||
|
||||
// Respect Retry-After header if enabled.
|
||||
if cfg.retryAfter && resp != nil {
|
||||
if ra, ok := ParseRetryAfter(resp); ok && ra > delay {
|
||||
delay = ra
|
||||
}
|
||||
}
|
||||
|
||||
// Drain and close the response body to release the connection.
|
||||
if resp != nil {
|
||||
io.Copy(io.Discard, resp.Body)
|
||||
resp.Body.Close()
|
||||
}
|
||||
|
||||
// Wait for the delay or context cancellation.
|
||||
timer := time.NewTimer(delay)
|
||||
select {
|
||||
case <-req.Context().Done():
|
||||
timer.Stop()
|
||||
return nil, req.Context().Err()
|
||||
case <-timer.C:
|
||||
}
|
||||
}
|
||||
|
||||
// Wrap with ErrRetryExhausted only when all attempts were used.
|
||||
if exhausted && err != nil {
|
||||
err = fmt.Errorf("%w: %w", ErrRetryExhausted, err)
|
||||
}
|
||||
return resp, err
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// defaultPolicy retries on network errors, 429, and 5xx server errors.
|
||||
// It refuses to retry non-idempotent methods.
|
||||
type defaultPolicy struct{}
|
||||
|
||||
func (defaultPolicy) ShouldRetry(_ int, req *http.Request, resp *http.Response, err error) (bool, time.Duration) {
|
||||
if !isIdempotent(req.Method) {
|
||||
return false, 0
|
||||
}
|
||||
|
||||
// Network error — always retry idempotent requests.
|
||||
if err != nil {
|
||||
return true, 0
|
||||
}
|
||||
|
||||
switch resp.StatusCode {
|
||||
case http.StatusTooManyRequests, // 429
|
||||
http.StatusBadGateway, // 502
|
||||
http.StatusServiceUnavailable, // 503
|
||||
http.StatusGatewayTimeout: // 504
|
||||
return true, 0
|
||||
}
|
||||
|
||||
return false, 0
|
||||
}
|
||||
|
||||
// isIdempotent reports whether the HTTP method is safe to retry.
|
||||
func isIdempotent(method string) bool {
|
||||
switch method {
|
||||
case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodPut:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
43
retry/retry_after.go
Normal file
43
retry/retry_after.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package retry
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ParseRetryAfter extracts the delay from a Retry-After header (RFC 7231).
|
||||
// It supports both the delay-seconds format ("120") and the HTTP-date format
|
||||
// ("Fri, 31 Dec 1999 23:59:59 GMT"). Returns the duration and true if the
|
||||
// header was present and valid; otherwise returns 0 and false.
|
||||
func ParseRetryAfter(resp *http.Response) (time.Duration, bool) {
|
||||
if resp == nil {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
val := resp.Header.Get("Retry-After")
|
||||
if val == "" {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// Try delay-seconds first (most common).
|
||||
if seconds, err := strconv.ParseInt(val, 10, 64); err == nil {
|
||||
if seconds < 0 {
|
||||
return 0, false
|
||||
}
|
||||
return time.Duration(seconds) * time.Second, true
|
||||
}
|
||||
|
||||
// Try HTTP-date format (RFC 7231 section 7.1.1.1).
|
||||
t, err := http.ParseTime(val)
|
||||
if err != nil {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
delay := time.Until(t)
|
||||
if delay < 0 {
|
||||
// The date is in the past; no need to wait.
|
||||
return 0, true
|
||||
}
|
||||
return delay, true
|
||||
}
|
||||
58
retry/retry_after_test.go
Normal file
58
retry/retry_after_test.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package retry
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestParseRetryAfter(t *testing.T) {
|
||||
t.Run("seconds format", func(t *testing.T) {
|
||||
resp := &http.Response{
|
||||
Header: http.Header{"Retry-After": []string{"120"}},
|
||||
}
|
||||
d, ok := ParseRetryAfter(resp)
|
||||
if !ok {
|
||||
t.Fatal("expected ok=true")
|
||||
}
|
||||
if d != 120*time.Second {
|
||||
t.Fatalf("expected 120s, got %v", d)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty header", func(t *testing.T) {
|
||||
resp := &http.Response{
|
||||
Header: make(http.Header),
|
||||
}
|
||||
d, ok := ParseRetryAfter(resp)
|
||||
if ok {
|
||||
t.Fatal("expected ok=false for empty header")
|
||||
}
|
||||
if d != 0 {
|
||||
t.Fatalf("expected 0, got %v", d)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("nil response", func(t *testing.T) {
|
||||
d, ok := ParseRetryAfter(nil)
|
||||
if ok {
|
||||
t.Fatal("expected ok=false for nil response")
|
||||
}
|
||||
if d != 0 {
|
||||
t.Fatalf("expected 0, got %v", d)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("negative value", func(t *testing.T) {
|
||||
resp := &http.Response{
|
||||
Header: http.Header{"Retry-After": []string{"-5"}},
|
||||
}
|
||||
d, ok := ParseRetryAfter(resp)
|
||||
if ok {
|
||||
t.Fatal("expected ok=false for negative value")
|
||||
}
|
||||
if d != 0 {
|
||||
t.Fatalf("expected 0, got %v", d)
|
||||
}
|
||||
})
|
||||
}
|
||||
237
retry/retry_test.go
Normal file
237
retry/retry_test.go
Normal file
@@ -0,0 +1,237 @@
|
||||
package retry
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.codelab.vc/pkg/httpx/middleware"
|
||||
)
|
||||
|
||||
func mockTransport(fn func(*http.Request) (*http.Response, error)) http.RoundTripper {
|
||||
return middleware.RoundTripperFunc(fn)
|
||||
}
|
||||
|
||||
func okResponse() *http.Response {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader("")),
|
||||
Header: make(http.Header),
|
||||
}
|
||||
}
|
||||
|
||||
func statusResponse(code int) *http.Response {
|
||||
return &http.Response{
|
||||
StatusCode: code,
|
||||
Body: io.NopCloser(strings.NewReader("")),
|
||||
Header: make(http.Header),
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransport(t *testing.T) {
|
||||
t.Run("successful request no retry", func(t *testing.T) {
|
||||
var calls atomic.Int32
|
||||
rt := Transport(
|
||||
WithMaxAttempts(3),
|
||||
WithBackoff(ConstantBackoff(time.Millisecond)),
|
||||
)(mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||
calls.Add(1)
|
||||
return okResponse(), nil
|
||||
}))
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||
resp, err := rt.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", resp.StatusCode)
|
||||
}
|
||||
if got := calls.Load(); got != 1 {
|
||||
t.Fatalf("expected 1 call, got %d", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("retries on 503 then succeeds", func(t *testing.T) {
|
||||
var calls atomic.Int32
|
||||
rt := Transport(
|
||||
WithMaxAttempts(3),
|
||||
WithBackoff(ConstantBackoff(time.Millisecond)),
|
||||
)(mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||
n := calls.Add(1)
|
||||
if n < 3 {
|
||||
return statusResponse(http.StatusServiceUnavailable), nil
|
||||
}
|
||||
return okResponse(), nil
|
||||
}))
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||
resp, err := rt.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", resp.StatusCode)
|
||||
}
|
||||
if got := calls.Load(); got != 3 {
|
||||
t.Fatalf("expected 3 calls, got %d", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("does not retry non-idempotent POST", func(t *testing.T) {
|
||||
var calls atomic.Int32
|
||||
rt := Transport(
|
||||
WithMaxAttempts(3),
|
||||
WithBackoff(ConstantBackoff(time.Millisecond)),
|
||||
)(mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||
calls.Add(1)
|
||||
return statusResponse(http.StatusServiceUnavailable), nil
|
||||
}))
|
||||
|
||||
req, _ := http.NewRequest(http.MethodPost, "http://example.com", strings.NewReader("data"))
|
||||
resp, err := rt.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusServiceUnavailable {
|
||||
t.Fatalf("expected 503, got %d", resp.StatusCode)
|
||||
}
|
||||
if got := calls.Load(); got != 1 {
|
||||
t.Fatalf("expected 1 call (no retry for POST), got %d", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("stops on context cancellation", func(t *testing.T) {
|
||||
var calls atomic.Int32
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
rt := Transport(
|
||||
WithMaxAttempts(5),
|
||||
WithBackoff(ConstantBackoff(50*time.Millisecond)),
|
||||
)(mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||
n := calls.Add(1)
|
||||
if n == 1 {
|
||||
cancel()
|
||||
}
|
||||
return statusResponse(http.StatusServiceUnavailable), nil
|
||||
}))
|
||||
|
||||
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "http://example.com", nil)
|
||||
resp, err := rt.RoundTrip(req)
|
||||
if err != context.Canceled {
|
||||
t.Fatalf("expected context.Canceled, got resp=%v err=%v", resp, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("respects maxAttempts", func(t *testing.T) {
|
||||
var calls atomic.Int32
|
||||
rt := Transport(
|
||||
WithMaxAttempts(2),
|
||||
WithBackoff(ConstantBackoff(time.Millisecond)),
|
||||
)(mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||
calls.Add(1)
|
||||
return statusResponse(http.StatusBadGateway), nil
|
||||
}))
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||
resp, err := rt.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusBadGateway {
|
||||
t.Fatalf("expected 502, got %d", resp.StatusCode)
|
||||
}
|
||||
if got := calls.Load(); got != 2 {
|
||||
t.Fatalf("expected 2 calls (maxAttempts=2), got %d", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("body is restored via GetBody on retry", func(t *testing.T) {
|
||||
var calls atomic.Int32
|
||||
var bodies []string
|
||||
|
||||
rt := Transport(
|
||||
WithMaxAttempts(3),
|
||||
WithBackoff(ConstantBackoff(time.Millisecond)),
|
||||
)(mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||
calls.Add(1)
|
||||
b, _ := io.ReadAll(req.Body)
|
||||
bodies = append(bodies, string(b))
|
||||
if len(bodies) < 2 {
|
||||
return statusResponse(http.StatusServiceUnavailable), nil
|
||||
}
|
||||
return okResponse(), nil
|
||||
}))
|
||||
|
||||
bodyContent := "request-body"
|
||||
body := bytes.NewReader([]byte(bodyContent))
|
||||
req, _ := http.NewRequest(http.MethodPut, "http://example.com", body)
|
||||
req.GetBody = func() (io.ReadCloser, error) {
|
||||
return io.NopCloser(bytes.NewReader([]byte(bodyContent))), nil
|
||||
}
|
||||
|
||||
resp, err := rt.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", resp.StatusCode)
|
||||
}
|
||||
if got := calls.Load(); got != 2 {
|
||||
t.Fatalf("expected 2 calls, got %d", got)
|
||||
}
|
||||
for i, b := range bodies {
|
||||
if b != bodyContent {
|
||||
t.Fatalf("attempt %d: expected body %q, got %q", i, bodyContent, b)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("custom policy", func(t *testing.T) {
|
||||
var calls atomic.Int32
|
||||
|
||||
// Custom policy: retry only on 418
|
||||
custom := policyFunc(func(attempt int, req *http.Request, resp *http.Response, err error) (bool, time.Duration) {
|
||||
if resp != nil && resp.StatusCode == http.StatusTeapot {
|
||||
return true, 0
|
||||
}
|
||||
return false, 0
|
||||
})
|
||||
|
||||
rt := Transport(
|
||||
WithMaxAttempts(3),
|
||||
WithBackoff(ConstantBackoff(time.Millisecond)),
|
||||
WithPolicy(custom),
|
||||
)(mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||
n := calls.Add(1)
|
||||
if n == 1 {
|
||||
return statusResponse(http.StatusTeapot), nil
|
||||
}
|
||||
return okResponse(), nil
|
||||
}))
|
||||
|
||||
req, _ := http.NewRequest(http.MethodPost, "http://example.com", nil)
|
||||
resp, err := rt.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", resp.StatusCode)
|
||||
}
|
||||
if got := calls.Load(); got != 2 {
|
||||
t.Fatalf("expected 2 calls, got %d", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// policyFunc adapts a function into a Policy.
|
||||
type policyFunc func(int, *http.Request, *http.Response, error) (bool, time.Duration)
|
||||
|
||||
func (f policyFunc) ShouldRetry(attempt int, req *http.Request, resp *http.Response, err error) (bool, time.Duration) {
|
||||
return f(attempt, req, resp, err)
|
||||
}
|
||||
38
server/doc.go
Normal file
38
server/doc.go
Normal file
@@ -0,0 +1,38 @@
|
||||
// Package server provides a production-ready HTTP server with graceful
|
||||
// shutdown, middleware composition, routing, and JSON response helpers.
|
||||
//
|
||||
// # Server
|
||||
//
|
||||
// Server wraps http.Server with net.Listener, signal-based graceful shutdown
|
||||
// (SIGINT/SIGTERM), and lifecycle hooks. It is configured via functional options:
|
||||
//
|
||||
// srv := server.New(handler,
|
||||
// server.WithAddr(":8080"),
|
||||
// server.Defaults(logger),
|
||||
// )
|
||||
// srv.ListenAndServe()
|
||||
//
|
||||
// # Router
|
||||
//
|
||||
// Router wraps http.ServeMux with middleware groups, prefix-based route groups,
|
||||
// and sub-handler mounting. It supports Go 1.22+ method-based patterns:
|
||||
//
|
||||
// r := server.NewRouter()
|
||||
// r.HandleFunc("GET /users/{id}", getUser)
|
||||
//
|
||||
// api := r.Group("/api/v1", authMiddleware)
|
||||
// api.HandleFunc("GET /items", listItems)
|
||||
//
|
||||
// # Middleware
|
||||
//
|
||||
// Server middleware follows the func(http.Handler) http.Handler pattern.
|
||||
// Available middleware: RequestID, Recovery, Logging, CORS, RateLimit,
|
||||
// MaxBodySize, Timeout. Use Chain to compose them:
|
||||
//
|
||||
// chain := server.Chain(server.RequestID(), server.Recovery(logger), server.Logging(logger))
|
||||
//
|
||||
// # Response helpers
|
||||
//
|
||||
// WriteJSON and WriteError provide JSON response writing with proper
|
||||
// Content-Type headers.
|
||||
package server
|
||||
55
server/health.go
Normal file
55
server/health.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// ReadinessChecker is a function that reports whether a dependency is ready.
|
||||
// Return nil if healthy, or an error describing the problem.
|
||||
type ReadinessChecker func() error
|
||||
|
||||
// HealthHandler returns an http.Handler that exposes liveness and readiness
|
||||
// endpoints:
|
||||
//
|
||||
// - GET /healthz — liveness check, always returns 200 OK
|
||||
// - GET /readyz — readiness check, returns 200 if all checkers pass, 503 otherwise
|
||||
func HealthHandler(checkers ...ReadinessChecker) http.Handler {
|
||||
mux := http.NewServeMux()
|
||||
|
||||
mux.HandleFunc("GET /healthz", func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_ = json.NewEncoder(w).Encode(healthResponse{Status: "ok"})
|
||||
})
|
||||
|
||||
mux.HandleFunc("GET /readyz", func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
var errs []string
|
||||
for _, check := range checkers {
|
||||
if err := check(); err != nil {
|
||||
errs = append(errs, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
if len(errs) > 0 {
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
_ = json.NewEncoder(w).Encode(healthResponse{
|
||||
Status: "unavailable",
|
||||
Errors: errs,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_ = json.NewEncoder(w).Encode(healthResponse{Status: "ok"})
|
||||
})
|
||||
|
||||
return mux
|
||||
}
|
||||
|
||||
type healthResponse struct {
|
||||
Status string `json:"status"`
|
||||
Errors []string `json:"errors,omitempty"`
|
||||
}
|
||||
166
server/health_test.go
Normal file
166
server/health_test.go
Normal file
@@ -0,0 +1,166 @@
|
||||
package server_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"git.codelab.vc/pkg/httpx/server"
|
||||
)
|
||||
|
||||
func TestHealthHandler(t *testing.T) {
|
||||
t.Run("liveness always returns 200", func(t *testing.T) {
|
||||
h := server.HealthHandler()
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/healthz", nil)
|
||||
h.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
var resp map[string]any
|
||||
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||
t.Fatalf("decode failed: %v", err)
|
||||
}
|
||||
if resp["status"] != "ok" {
|
||||
t.Fatalf("got status %q, want %q", resp["status"], "ok")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("readiness returns 200 when all checks pass", func(t *testing.T) {
|
||||
h := server.HealthHandler(
|
||||
func() error { return nil },
|
||||
func() error { return nil },
|
||||
)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/readyz", nil)
|
||||
h.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("readiness returns 503 when a check fails", func(t *testing.T) {
|
||||
h := server.HealthHandler(
|
||||
func() error { return nil },
|
||||
func() error { return errors.New("db down") },
|
||||
)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/readyz", nil)
|
||||
h.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusServiceUnavailable {
|
||||
t.Fatalf("got status %d, want %d", w.Code, http.StatusServiceUnavailable)
|
||||
}
|
||||
|
||||
var resp map[string]any
|
||||
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||
t.Fatalf("decode failed: %v", err)
|
||||
}
|
||||
if resp["status"] != "unavailable" {
|
||||
t.Fatalf("got status %q, want %q", resp["status"], "unavailable")
|
||||
}
|
||||
errs, ok := resp["errors"].([]any)
|
||||
if !ok || len(errs) != 1 {
|
||||
t.Fatalf("expected 1 error, got %v", resp["errors"])
|
||||
}
|
||||
if errs[0] != "db down" {
|
||||
t.Fatalf("got error %q, want %q", errs[0], "db down")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("readiness returns 200 with no checkers", func(t *testing.T) {
|
||||
h := server.HealthHandler()
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/readyz", nil)
|
||||
h.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestHealth_MultipleFailingCheckers(t *testing.T) {
|
||||
h := server.HealthHandler(
|
||||
func() error { return errors.New("db down") },
|
||||
func() error { return errors.New("cache down") },
|
||||
)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/readyz", nil)
|
||||
h.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusServiceUnavailable {
|
||||
t.Fatalf("got status %d, want %d", w.Code, http.StatusServiceUnavailable)
|
||||
}
|
||||
|
||||
var resp map[string]any
|
||||
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||
t.Fatalf("decode failed: %v", err)
|
||||
}
|
||||
|
||||
errs, ok := resp["errors"].([]any)
|
||||
if !ok || len(errs) != 2 {
|
||||
t.Fatalf("expected 2 errors, got %v", resp["errors"])
|
||||
}
|
||||
|
||||
errStrs := make(map[string]bool)
|
||||
for _, e := range errs {
|
||||
errStrs[e.(string)] = true
|
||||
}
|
||||
if !errStrs["db down"] || !errStrs["cache down"] {
|
||||
t.Fatalf("expected 'db down' and 'cache down', got %v", errs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealth_LivenessContentType(t *testing.T) {
|
||||
h := server.HealthHandler()
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/healthz", nil)
|
||||
h.ServeHTTP(w, req)
|
||||
|
||||
ct := w.Header().Get("Content-Type")
|
||||
if ct != "application/json" {
|
||||
t.Fatalf("got Content-Type %q, want %q", ct, "application/json")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealth_ReadinessContentType(t *testing.T) {
|
||||
h := server.HealthHandler()
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/readyz", nil)
|
||||
h.ServeHTTP(w, req)
|
||||
|
||||
ct := w.Header().Get("Content-Type")
|
||||
if ct != "application/json" {
|
||||
t.Fatalf("got Content-Type %q, want %q", ct, "application/json")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealth_PostMethodNotAllowed(t *testing.T) {
|
||||
h := server.HealthHandler()
|
||||
|
||||
for _, path := range []string{"/healthz", "/readyz"} {
|
||||
t.Run("POST "+path, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, path, nil)
|
||||
h.ServeHTTP(w, req)
|
||||
|
||||
// ServeMux with "GET /healthz" pattern should reject POST.
|
||||
if w.Code == http.StatusOK {
|
||||
t.Fatalf("POST %s should not return 200, got %d", path, w.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
56
server/middleware.go
Normal file
56
server/middleware.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package server
|
||||
|
||||
import "net/http"
|
||||
|
||||
// Middleware wraps an http.Handler to add behavior.
|
||||
// This is the server-side counterpart of the client middleware type
|
||||
// func(http.RoundTripper) http.RoundTripper.
|
||||
type Middleware func(http.Handler) http.Handler
|
||||
|
||||
// Chain composes middlewares so that Chain(A, B, C)(handler) == A(B(C(handler))).
|
||||
// Middlewares are applied from right to left: C wraps handler first, then B wraps
|
||||
// the result, then A wraps last. This means A is the outermost layer and sees
|
||||
// every request first.
|
||||
func Chain(mws ...Middleware) Middleware {
|
||||
return func(h http.Handler) http.Handler {
|
||||
for i := len(mws) - 1; i >= 0; i-- {
|
||||
h = mws[i](h)
|
||||
}
|
||||
return h
|
||||
}
|
||||
}
|
||||
|
||||
// statusWriter wraps http.ResponseWriter to capture the response status code.
|
||||
// It implements Unwrap() so that http.ResponseController can access the
|
||||
// underlying ResponseWriter's optional interfaces (Flusher, Hijacker, etc.).
|
||||
type statusWriter struct {
|
||||
http.ResponseWriter
|
||||
status int
|
||||
written bool
|
||||
}
|
||||
|
||||
// WriteHeader captures the status code and delegates to the underlying writer.
|
||||
func (w *statusWriter) WriteHeader(code int) {
|
||||
if !w.written {
|
||||
w.status = code
|
||||
w.written = true
|
||||
}
|
||||
w.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
|
||||
// Write delegates to the underlying writer, defaulting status to 200 if
|
||||
// WriteHeader was not called explicitly.
|
||||
func (w *statusWriter) Write(b []byte) (int, error) {
|
||||
if !w.written {
|
||||
w.status = http.StatusOK
|
||||
w.written = true
|
||||
}
|
||||
return w.ResponseWriter.Write(b)
|
||||
}
|
||||
|
||||
// Unwrap returns the underlying ResponseWriter. This is required for
|
||||
// http.ResponseController to detect optional interfaces like http.Flusher
|
||||
// and http.Hijacker on the original writer.
|
||||
func (w *statusWriter) Unwrap() http.ResponseWriter {
|
||||
return w.ResponseWriter
|
||||
}
|
||||
15
server/middleware_bodylimit.go
Normal file
15
server/middleware_bodylimit.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package server
|
||||
|
||||
import "net/http"
|
||||
|
||||
// MaxBodySize returns a middleware that limits the size of incoming request
|
||||
// bodies. If the body exceeds n bytes, the server returns 413 Request Entity
|
||||
// Too Large. It wraps the body with http.MaxBytesReader.
|
||||
func MaxBodySize(n int64) Middleware {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
r.Body = http.MaxBytesReader(w, r.Body, n)
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
61
server/middleware_bodylimit_test.go
Normal file
61
server/middleware_bodylimit_test.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package server_test
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"git.codelab.vc/pkg/httpx/server"
|
||||
)
|
||||
|
||||
func TestMaxBodySize(t *testing.T) {
|
||||
t.Run("allows body within limit", func(t *testing.T) {
|
||||
handler := server.MaxBodySize(1024)(
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
b, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(b)
|
||||
}),
|
||||
)
|
||||
|
||||
body := strings.NewReader("hello")
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/", body)
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
if w.Body.String() != "hello" {
|
||||
t.Fatalf("got body %q, want %q", w.Body.String(), "hello")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("rejects body exceeding limit", func(t *testing.T) {
|
||||
handler := server.MaxBodySize(5)(
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, "body too large", http.StatusRequestEntityTooLarge)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}),
|
||||
)
|
||||
|
||||
body := strings.NewReader("this is longer than 5 bytes")
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/", body)
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusRequestEntityTooLarge {
|
||||
t.Fatalf("got status %d, want %d", w.Code, http.StatusRequestEntityTooLarge)
|
||||
}
|
||||
})
|
||||
}
|
||||
128
server/middleware_cors.go
Normal file
128
server/middleware_cors.go
Normal file
@@ -0,0 +1,128 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type corsOptions struct {
|
||||
allowOrigins []string
|
||||
allowMethods []string
|
||||
allowHeaders []string
|
||||
exposeHeaders []string
|
||||
allowCredentials bool
|
||||
maxAge int
|
||||
}
|
||||
|
||||
// CORSOption configures the CORS middleware.
|
||||
type CORSOption func(*corsOptions)
|
||||
|
||||
// AllowOrigins sets the allowed origins. Use "*" to allow any origin.
|
||||
// Default is no origins (CORS disabled).
|
||||
func AllowOrigins(origins ...string) CORSOption {
|
||||
return func(o *corsOptions) { o.allowOrigins = origins }
|
||||
}
|
||||
|
||||
// AllowMethods sets the allowed HTTP methods for preflight requests.
|
||||
// Default is GET, POST, HEAD.
|
||||
func AllowMethods(methods ...string) CORSOption {
|
||||
return func(o *corsOptions) { o.allowMethods = methods }
|
||||
}
|
||||
|
||||
// AllowHeaders sets the allowed request headers for preflight requests.
|
||||
func AllowHeaders(headers ...string) CORSOption {
|
||||
return func(o *corsOptions) { o.allowHeaders = headers }
|
||||
}
|
||||
|
||||
// ExposeHeaders sets headers that browsers are allowed to access.
|
||||
func ExposeHeaders(headers ...string) CORSOption {
|
||||
return func(o *corsOptions) { o.exposeHeaders = headers }
|
||||
}
|
||||
|
||||
// AllowCredentials indicates whether the response to the request can be
|
||||
// exposed when the credentials flag is true.
|
||||
func AllowCredentials(allow bool) CORSOption {
|
||||
return func(o *corsOptions) { o.allowCredentials = allow }
|
||||
}
|
||||
|
||||
// MaxAge sets the maximum time (in seconds) a preflight result can be cached.
|
||||
func MaxAge(seconds int) CORSOption {
|
||||
return func(o *corsOptions) { o.maxAge = seconds }
|
||||
}
|
||||
|
||||
// CORS returns a middleware that handles Cross-Origin Resource Sharing.
|
||||
// It processes preflight OPTIONS requests and sets the appropriate
|
||||
// Access-Control-* response headers.
|
||||
func CORS(opts ...CORSOption) Middleware {
|
||||
o := &corsOptions{
|
||||
allowMethods: []string{"GET", "POST", "HEAD"},
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(o)
|
||||
}
|
||||
|
||||
allowedOrigins := make(map[string]struct{}, len(o.allowOrigins))
|
||||
allowAll := false
|
||||
for _, origin := range o.allowOrigins {
|
||||
if origin == "*" {
|
||||
allowAll = true
|
||||
}
|
||||
allowedOrigins[origin] = struct{}{}
|
||||
}
|
||||
|
||||
methods := strings.Join(o.allowMethods, ", ")
|
||||
headers := strings.Join(o.allowHeaders, ", ")
|
||||
expose := strings.Join(o.exposeHeaders, ", ")
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
origin := r.Header.Get("Origin")
|
||||
if origin == "" {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
allowed := allowAll
|
||||
if !allowed {
|
||||
_, allowed = allowedOrigins[origin]
|
||||
}
|
||||
if !allowed {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Set the allowed origin. When credentials are enabled,
|
||||
// we must echo the specific origin, not "*".
|
||||
if allowAll && !o.allowCredentials {
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
} else {
|
||||
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||||
}
|
||||
|
||||
if o.allowCredentials {
|
||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
}
|
||||
if expose != "" {
|
||||
w.Header().Set("Access-Control-Expose-Headers", expose)
|
||||
}
|
||||
|
||||
w.Header().Add("Vary", "Origin")
|
||||
|
||||
// Handle preflight.
|
||||
if r.Method == http.MethodOptions && r.Header.Get("Access-Control-Request-Method") != "" {
|
||||
w.Header().Set("Access-Control-Allow-Methods", methods)
|
||||
if headers != "" {
|
||||
w.Header().Set("Access-Control-Allow-Headers", headers)
|
||||
}
|
||||
if o.maxAge > 0 {
|
||||
w.Header().Set("Access-Control-Max-Age", strconv.Itoa(o.maxAge))
|
||||
}
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
143
server/middleware_cors_test.go
Normal file
143
server/middleware_cors_test.go
Normal file
@@ -0,0 +1,143 @@
|
||||
package server_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"git.codelab.vc/pkg/httpx/server"
|
||||
)
|
||||
|
||||
func TestCORS(t *testing.T) {
|
||||
okHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
t.Run("no Origin header passes through", func(t *testing.T) {
|
||||
mw := server.CORS(server.AllowOrigins("*"))(okHandler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if w.Header().Get("Access-Control-Allow-Origin") != "" {
|
||||
t.Fatal("expected no CORS headers without Origin")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("wildcard origin", func(t *testing.T) {
|
||||
mw := server.CORS(server.AllowOrigins("*"))(okHandler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("Origin", "http://example.com")
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if got := w.Header().Get("Access-Control-Allow-Origin"); got != "*" {
|
||||
t.Fatalf("Access-Control-Allow-Origin = %q, want %q", got, "*")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("specific origin allowed", func(t *testing.T) {
|
||||
mw := server.CORS(server.AllowOrigins("http://example.com"))(okHandler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("Origin", "http://example.com")
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if got := w.Header().Get("Access-Control-Allow-Origin"); got != "http://example.com" {
|
||||
t.Fatalf("Access-Control-Allow-Origin = %q, want %q", got, "http://example.com")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("disallowed origin gets no CORS headers", func(t *testing.T) {
|
||||
mw := server.CORS(server.AllowOrigins("http://example.com"))(okHandler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("Origin", "http://evil.com")
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if got := w.Header().Get("Access-Control-Allow-Origin"); got != "" {
|
||||
t.Fatalf("expected no Access-Control-Allow-Origin for disallowed origin, got %q", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("preflight OPTIONS", func(t *testing.T) {
|
||||
mw := server.CORS(
|
||||
server.AllowOrigins("http://example.com"),
|
||||
server.AllowMethods("GET", "POST", "PUT"),
|
||||
server.AllowHeaders("Authorization", "Content-Type"),
|
||||
server.MaxAge(3600),
|
||||
)(okHandler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodOptions, "/api/data", nil)
|
||||
req.Header.Set("Origin", "http://example.com")
|
||||
req.Header.Set("Access-Control-Request-Method", "POST")
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusNoContent {
|
||||
t.Fatalf("got status %d, want %d", w.Code, http.StatusNoContent)
|
||||
}
|
||||
if got := w.Header().Get("Access-Control-Allow-Methods"); got != "GET, POST, PUT" {
|
||||
t.Fatalf("Access-Control-Allow-Methods = %q, want %q", got, "GET, POST, PUT")
|
||||
}
|
||||
if got := w.Header().Get("Access-Control-Allow-Headers"); got != "Authorization, Content-Type" {
|
||||
t.Fatalf("Access-Control-Allow-Headers = %q, want %q", got, "Authorization, Content-Type")
|
||||
}
|
||||
if got := w.Header().Get("Access-Control-Max-Age"); got != "3600" {
|
||||
t.Fatalf("Access-Control-Max-Age = %q, want %q", got, "3600")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("credentials with specific origin", func(t *testing.T) {
|
||||
mw := server.CORS(
|
||||
server.AllowOrigins("*"),
|
||||
server.AllowCredentials(true),
|
||||
)(okHandler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("Origin", "http://example.com")
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
// With credentials, must echo specific origin even with wildcard config.
|
||||
if got := w.Header().Get("Access-Control-Allow-Origin"); got != "http://example.com" {
|
||||
t.Fatalf("Access-Control-Allow-Origin = %q, want %q", got, "http://example.com")
|
||||
}
|
||||
if got := w.Header().Get("Access-Control-Allow-Credentials"); got != "true" {
|
||||
t.Fatalf("Access-Control-Allow-Credentials = %q, want %q", got, "true")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("expose headers", func(t *testing.T) {
|
||||
mw := server.CORS(
|
||||
server.AllowOrigins("*"),
|
||||
server.ExposeHeaders("X-Custom", "X-Request-Id"),
|
||||
)(okHandler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("Origin", "http://example.com")
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if got := w.Header().Get("Access-Control-Expose-Headers"); got != "X-Custom, X-Request-Id" {
|
||||
t.Fatalf("Access-Control-Expose-Headers = %q, want %q", got, "X-Custom, X-Request-Id")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Vary header is set", func(t *testing.T) {
|
||||
mw := server.CORS(server.AllowOrigins("*"))(okHandler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("Origin", "http://example.com")
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if got := w.Header().Get("Vary"); got != "Origin" {
|
||||
t.Fatalf("Vary = %q, want %q", got, "Origin")
|
||||
}
|
||||
})
|
||||
}
|
||||
39
server/middleware_logging.go
Normal file
39
server/middleware_logging.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Logging returns a middleware that logs each request's method, path,
|
||||
// status code, duration, and request ID using the provided structured logger.
|
||||
func Logging(logger *slog.Logger) Middleware {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
|
||||
sw := &statusWriter{ResponseWriter: w, status: http.StatusOK}
|
||||
next.ServeHTTP(sw, r)
|
||||
|
||||
duration := time.Since(start)
|
||||
attrs := []slog.Attr{
|
||||
slog.String("method", r.Method),
|
||||
slog.String("path", r.URL.Path),
|
||||
slog.Int("status", sw.status),
|
||||
slog.Duration("duration", duration),
|
||||
}
|
||||
|
||||
if id := RequestIDFromContext(r.Context()); id != "" {
|
||||
attrs = append(attrs, slog.String("request_id", id))
|
||||
}
|
||||
|
||||
level := slog.LevelInfo
|
||||
if sw.status >= http.StatusInternalServerError {
|
||||
level = slog.LevelError
|
||||
}
|
||||
|
||||
logger.LogAttrs(r.Context(), level, "request completed", attrs...)
|
||||
})
|
||||
}
|
||||
}
|
||||
129
server/middleware_ratelimit.go
Normal file
129
server/middleware_ratelimit.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.codelab.vc/pkg/httpx/internal/clock"
|
||||
)
|
||||
|
||||
type rateLimitOptions struct {
|
||||
rate float64
|
||||
burst int
|
||||
keyFunc func(r *http.Request) string
|
||||
clock clock.Clock
|
||||
}
|
||||
|
||||
// RateLimitOption configures the RateLimit middleware.
|
||||
type RateLimitOption func(*rateLimitOptions)
|
||||
|
||||
// WithRate sets the token refill rate (tokens per second).
|
||||
func WithRate(tokensPerSecond float64) RateLimitOption {
|
||||
return func(o *rateLimitOptions) { o.rate = tokensPerSecond }
|
||||
}
|
||||
|
||||
// WithBurst sets the maximum burst size (bucket capacity).
|
||||
func WithBurst(n int) RateLimitOption {
|
||||
return func(o *rateLimitOptions) { o.burst = n }
|
||||
}
|
||||
|
||||
// WithKeyFunc sets a custom function to extract the rate-limit key from a
|
||||
// request. By default, the client IP address is used.
|
||||
func WithKeyFunc(fn func(r *http.Request) string) RateLimitOption {
|
||||
return func(o *rateLimitOptions) { o.keyFunc = fn }
|
||||
}
|
||||
|
||||
// withRateLimitClock sets the clock for testing. Not exported.
|
||||
func withRateLimitClock(c clock.Clock) RateLimitOption {
|
||||
return func(o *rateLimitOptions) { o.clock = c }
|
||||
}
|
||||
|
||||
// RateLimit returns a middleware that limits requests using a per-key token
|
||||
// bucket algorithm. When the limit is exceeded, it returns 429 Too Many
|
||||
// Requests with a Retry-After header.
|
||||
func RateLimit(opts ...RateLimitOption) Middleware {
|
||||
o := &rateLimitOptions{
|
||||
rate: 10,
|
||||
burst: 20,
|
||||
clock: clock.System(),
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(o)
|
||||
}
|
||||
if o.keyFunc == nil {
|
||||
o.keyFunc = clientIP
|
||||
}
|
||||
|
||||
var buckets sync.Map
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
key := o.keyFunc(r)
|
||||
val, _ := buckets.LoadOrStore(key, &bucket{
|
||||
tokens: float64(o.burst),
|
||||
lastTime: o.clock.Now(),
|
||||
})
|
||||
b := val.(*bucket)
|
||||
|
||||
b.mu.Lock()
|
||||
now := o.clock.Now()
|
||||
elapsed := now.Sub(b.lastTime).Seconds()
|
||||
b.tokens += elapsed * o.rate
|
||||
if b.tokens > float64(o.burst) {
|
||||
b.tokens = float64(o.burst)
|
||||
}
|
||||
b.lastTime = now
|
||||
|
||||
if b.tokens < 1 {
|
||||
retryAfter := (1 - b.tokens) / o.rate
|
||||
b.mu.Unlock()
|
||||
w.Header().Set("Retry-After", strconv.Itoa(int(retryAfter)+1))
|
||||
http.Error(w, "Too Many Requests", http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
|
||||
b.tokens--
|
||||
b.mu.Unlock()
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type bucket struct {
|
||||
mu sync.Mutex
|
||||
tokens float64
|
||||
lastTime time.Time
|
||||
}
|
||||
|
||||
// clientIP extracts the client IP from the request. It checks
|
||||
// X-Forwarded-For first, then X-Real-Ip, and falls back to RemoteAddr.
|
||||
func clientIP(r *http.Request) string {
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
// First IP in the comma-separated list is the original client.
|
||||
if i := indexOf(xff, ','); i > 0 {
|
||||
return xff[:i]
|
||||
}
|
||||
return xff
|
||||
}
|
||||
if xri := r.Header.Get("X-Real-Ip"); xri != "" {
|
||||
return xri
|
||||
}
|
||||
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
return r.RemoteAddr
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
func indexOf(s string, b byte) int {
|
||||
for i := range len(s) {
|
||||
if s[i] == b {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
171
server/middleware_ratelimit_test.go
Normal file
171
server/middleware_ratelimit_test.go
Normal file
@@ -0,0 +1,171 @@
|
||||
package server_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.codelab.vc/pkg/httpx/server"
|
||||
)
|
||||
|
||||
func TestRateLimit(t *testing.T) {
|
||||
okHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
t.Run("allows requests within limit", func(t *testing.T) {
|
||||
mw := server.RateLimit(
|
||||
server.WithRate(100),
|
||||
server.WithBurst(10),
|
||||
)(okHandler)
|
||||
|
||||
for i := range 10 {
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.RemoteAddr = "1.2.3.4:1234"
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("request %d: got status %d, want %d", i, w.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("rejects when burst exhausted", func(t *testing.T) {
|
||||
mw := server.RateLimit(
|
||||
server.WithRate(1),
|
||||
server.WithBurst(2),
|
||||
)(okHandler)
|
||||
|
||||
// Exhaust burst.
|
||||
for range 2 {
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.RemoteAddr = "1.2.3.4:1234"
|
||||
mw.ServeHTTP(w, req)
|
||||
}
|
||||
|
||||
// Next request should be rejected.
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.RemoteAddr = "1.2.3.4:1234"
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("got status %d, want %d", w.Code, http.StatusTooManyRequests)
|
||||
}
|
||||
if w.Header().Get("Retry-After") == "" {
|
||||
t.Fatal("expected Retry-After header")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("different IPs have independent limits", func(t *testing.T) {
|
||||
mw := server.RateLimit(
|
||||
server.WithRate(1),
|
||||
server.WithBurst(1),
|
||||
)(okHandler)
|
||||
|
||||
// First IP exhausts its limit.
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.RemoteAddr = "1.2.3.4:1234"
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
// Second IP should still be allowed.
|
||||
w = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.RemoteAddr = "5.6.7.8:5678"
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uses X-Forwarded-For", func(t *testing.T) {
|
||||
mw := server.RateLimit(
|
||||
server.WithRate(1),
|
||||
server.WithBurst(1),
|
||||
)(okHandler)
|
||||
|
||||
// Exhaust limit for 10.0.0.1.
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("X-Forwarded-For", "10.0.0.1, 192.168.1.1")
|
||||
req.RemoteAddr = "192.168.1.1:1234"
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
// Same forwarded IP should be rate limited.
|
||||
w = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("X-Forwarded-For", "10.0.0.1, 192.168.1.1")
|
||||
req.RemoteAddr = "192.168.1.1:1234"
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("got status %d, want %d", w.Code, http.StatusTooManyRequests)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("custom key function", func(t *testing.T) {
|
||||
mw := server.RateLimit(
|
||||
server.WithRate(1),
|
||||
server.WithBurst(1),
|
||||
server.WithKeyFunc(func(r *http.Request) string {
|
||||
return r.Header.Get("X-API-Key")
|
||||
}),
|
||||
)(okHandler)
|
||||
|
||||
// Exhaust key "abc".
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("X-API-Key", "abc")
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
// Same key should be rate limited.
|
||||
w = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("X-API-Key", "abc")
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("got status %d, want %d", w.Code, http.StatusTooManyRequests)
|
||||
}
|
||||
|
||||
// Different key should be allowed.
|
||||
w = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("X-API-Key", "xyz")
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("tokens refill over time", func(t *testing.T) {
|
||||
mw := server.RateLimit(
|
||||
server.WithRate(1000), // Very fast refill for test
|
||||
server.WithBurst(1),
|
||||
)(okHandler)
|
||||
|
||||
// Exhaust burst.
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.RemoteAddr = "1.2.3.4:1234"
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
// Wait a bit for refill.
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
|
||||
w = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.RemoteAddr = "1.2.3.4:1234"
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("got status %d after refill, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
})
|
||||
}
|
||||
50
server/middleware_recovery.go
Normal file
50
server/middleware_recovery.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"runtime/debug"
|
||||
)
|
||||
|
||||
// RecoveryOption configures the Recovery middleware.
|
||||
type RecoveryOption func(*recoveryOptions)
|
||||
|
||||
type recoveryOptions struct {
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// WithRecoveryLogger sets the logger for the Recovery middleware.
|
||||
// If not set, panics are recovered silently (500 is still returned).
|
||||
func WithRecoveryLogger(l *slog.Logger) RecoveryOption {
|
||||
return func(o *recoveryOptions) { o.logger = l }
|
||||
}
|
||||
|
||||
// Recovery returns a middleware that recovers from panics in downstream
|
||||
// handlers. A recovered panic results in a 500 Internal Server Error
|
||||
// response and is logged (if a logger is configured) with the stack trace.
|
||||
func Recovery(opts ...RecoveryOption) Middleware {
|
||||
o := &recoveryOptions{}
|
||||
for _, opt := range opts {
|
||||
opt(o)
|
||||
}
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
defer func() {
|
||||
if v := recover(); v != nil {
|
||||
if o.logger != nil {
|
||||
o.logger.LogAttrs(r.Context(), slog.LevelError, "panic recovered",
|
||||
slog.Any("panic", v),
|
||||
slog.String("stack", string(debug.Stack())),
|
||||
slog.String("method", r.Method),
|
||||
slog.String("path", r.URL.Path),
|
||||
slog.String("request_id", RequestIDFromContext(r.Context())),
|
||||
)
|
||||
}
|
||||
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
}
|
||||
}()
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
51
server/middleware_requestid.go
Normal file
51
server/middleware_requestid.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"git.codelab.vc/pkg/httpx/internal/requestid"
|
||||
)
|
||||
|
||||
// RequestID returns a middleware that assigns a unique request ID to each
|
||||
// request. If the incoming request already has an X-Request-Id header, that
|
||||
// value is used. Otherwise a new UUID v4 is generated via crypto/rand.
|
||||
//
|
||||
// The request ID is stored in the request context (retrieve with
|
||||
// RequestIDFromContext) and set on the response X-Request-Id header.
|
||||
func RequestID() Middleware {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
id := r.Header.Get("X-Request-Id")
|
||||
if id == "" {
|
||||
id = newUUID()
|
||||
}
|
||||
|
||||
ctx := requestid.NewContext(r.Context(), id)
|
||||
w.Header().Set("X-Request-Id", id)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// RequestIDFromContext returns the request ID from the context, or an empty
|
||||
// string if none is set.
|
||||
func RequestIDFromContext(ctx context.Context) string {
|
||||
return requestid.FromContext(ctx)
|
||||
}
|
||||
|
||||
// newUUID generates a UUID v4 string using crypto/rand.
|
||||
func newUUID() string {
|
||||
var uuid [16]byte
|
||||
_, _ = rand.Read(uuid[:])
|
||||
|
||||
// Set version 4 (bits 12-15 of time_hi_and_version).
|
||||
uuid[6] = (uuid[6] & 0x0f) | 0x40
|
||||
// Set variant bits (10xx).
|
||||
uuid[8] = (uuid[8] & 0x3f) | 0x80
|
||||
|
||||
return fmt.Sprintf("%x-%x-%x-%x-%x",
|
||||
uuid[0:4], uuid[4:6], uuid[6:8], uuid[8:10], uuid[10:16])
|
||||
}
|
||||
508
server/middleware_test.go
Normal file
508
server/middleware_test.go
Normal file
@@ -0,0 +1,508 @@
|
||||
package server_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"git.codelab.vc/pkg/httpx/server"
|
||||
)
|
||||
|
||||
func TestChain(t *testing.T) {
|
||||
t.Run("applies middlewares in correct order", func(t *testing.T) {
|
||||
var order []string
|
||||
|
||||
mwA := func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
order = append(order, "A-before")
|
||||
next.ServeHTTP(w, r)
|
||||
order = append(order, "A-after")
|
||||
})
|
||||
}
|
||||
|
||||
mwB := func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
order = append(order, "B-before")
|
||||
next.ServeHTTP(w, r)
|
||||
order = append(order, "B-after")
|
||||
})
|
||||
}
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
order = append(order, "handler")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
chained := server.Chain(mwA, mwB)(handler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
chained.ServeHTTP(w, req)
|
||||
|
||||
expected := []string{"A-before", "B-before", "handler", "B-after", "A-after"}
|
||||
if len(order) != len(expected) {
|
||||
t.Fatalf("got %v, want %v", order, expected)
|
||||
}
|
||||
for i, v := range expected {
|
||||
if order[i] != v {
|
||||
t.Fatalf("order[%d] = %q, want %q", i, order[i], v)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty chain returns handler unchanged", func(t *testing.T) {
|
||||
called := false
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
called = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
chained := server.Chain()(handler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
chained.ServeHTTP(w, req)
|
||||
|
||||
if !called {
|
||||
t.Fatal("handler was not called")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestChain_SingleMiddleware(t *testing.T) {
|
||||
var called bool
|
||||
mw := func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
called = true
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
chained := server.Chain(mw)(handler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
chained.ServeHTTP(w, req)
|
||||
|
||||
if !called {
|
||||
t.Fatal("single middleware was not called")
|
||||
}
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStatusWriter_WriteHeaderMultipleCalls(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
w.WriteHeader(http.StatusNotFound) // second call should not change captured status
|
||||
})
|
||||
|
||||
var buf bytes.Buffer
|
||||
logger := slog.New(slog.NewTextHandler(&buf, nil))
|
||||
mw := server.Logging(logger)(handler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if !strings.Contains(buf.String(), "status=201") {
|
||||
t.Fatalf("expected status=201 (first WriteHeader call captured), got %q", buf.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestStatusWriter_WriteDefaultsTo200(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
_, _ = w.Write([]byte("hello")) // Write without WriteHeader
|
||||
})
|
||||
|
||||
var buf bytes.Buffer
|
||||
logger := slog.New(slog.NewTextHandler(&buf, nil))
|
||||
mw := server.Logging(logger)(handler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if !strings.Contains(buf.String(), "status=200") {
|
||||
t.Fatalf("expected status=200 when Write called without WriteHeader, got %q", buf.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestStatusWriter_Unwrap(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
rc := http.NewResponseController(w)
|
||||
if err := rc.Flush(); err != nil {
|
||||
// httptest.ResponseRecorder implements Flusher, so this should succeed
|
||||
// if Unwrap works correctly.
|
||||
http.Error(w, "flush failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Use Logging to wrap in statusWriter
|
||||
var buf bytes.Buffer
|
||||
logger := slog.New(slog.NewTextHandler(&buf, nil))
|
||||
mw := server.Logging(logger)(handler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if w.Code == http.StatusInternalServerError {
|
||||
t.Fatal("Flush failed — Unwrap likely not exposing underlying Flusher")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestID(t *testing.T) {
|
||||
t.Run("generates ID when not present", func(t *testing.T) {
|
||||
var gotID string
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotID = server.RequestIDFromContext(r.Context())
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
mw := server.RequestID()(handler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if gotID == "" {
|
||||
t.Fatal("expected request ID in context, got empty")
|
||||
}
|
||||
if w.Header().Get("X-Request-Id") != gotID {
|
||||
t.Fatalf("response header %q != context ID %q", w.Header().Get("X-Request-Id"), gotID)
|
||||
}
|
||||
// UUID v4 format: 8-4-4-4-12 hex chars.
|
||||
if len(gotID) != 36 {
|
||||
t.Fatalf("expected UUID length 36, got %d: %q", len(gotID), gotID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("preserves existing ID", func(t *testing.T) {
|
||||
var gotID string
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotID = server.RequestIDFromContext(r.Context())
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
mw := server.RequestID()(handler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("X-Request-Id", "custom-123")
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if gotID != "custom-123" {
|
||||
t.Fatalf("got ID %q, want %q", gotID, "custom-123")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("context without ID returns empty", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
if id := server.RequestIDFromContext(req.Context()); id != "" {
|
||||
t.Fatalf("expected empty, got %q", id)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRequestID_UUIDFormat(t *testing.T) {
|
||||
uuidV4Re := regexp.MustCompile(`^[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$`)
|
||||
|
||||
var gotID string
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotID = server.RequestIDFromContext(r.Context())
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
mw := server.RequestID()(handler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if !uuidV4Re.MatchString(gotID) {
|
||||
t.Fatalf("generated ID %q does not match UUID v4 format", gotID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestID_Uniqueness(t *testing.T) {
|
||||
seen := make(map[string]struct{}, 1000)
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
id := server.RequestIDFromContext(r.Context())
|
||||
if _, exists := seen[id]; exists {
|
||||
t.Fatalf("duplicate request ID: %q", id)
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
mw := server.RequestID()(handler)
|
||||
|
||||
for i := 0; i < 1000; i++ {
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
mw.ServeHTTP(w, req)
|
||||
}
|
||||
|
||||
if len(seen) != 1000 {
|
||||
t.Fatalf("expected 1000 unique IDs, got %d", len(seen))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecovery(t *testing.T) {
|
||||
t.Run("recovers from panic and returns 500", func(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
panic("something went wrong")
|
||||
})
|
||||
|
||||
mw := server.Recovery()(handler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Fatalf("got status %d, want %d", w.Code, http.StatusInternalServerError)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("logs panic with logger", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := slog.New(slog.NewTextHandler(&buf, nil))
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
panic("boom")
|
||||
})
|
||||
|
||||
mw := server.Recovery(server.WithRecoveryLogger(logger))(handler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if !strings.Contains(buf.String(), "panic recovered") {
|
||||
t.Fatalf("expected log to contain 'panic recovered', got %q", buf.String())
|
||||
}
|
||||
if !strings.Contains(buf.String(), "boom") {
|
||||
t.Fatalf("expected log to contain 'boom', got %q", buf.String())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("passes through without panic", func(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("ok"))
|
||||
})
|
||||
|
||||
mw := server.Recovery()(handler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRecovery_PanicWithNonString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
value any
|
||||
}{
|
||||
{"integer", 42},
|
||||
{"struct", struct{ X int }{X: 1}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
panic(tt.value)
|
||||
})
|
||||
|
||||
var buf bytes.Buffer
|
||||
logger := slog.New(slog.NewTextHandler(&buf, nil))
|
||||
mw := server.Recovery(server.WithRecoveryLogger(logger))(handler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Fatalf("got status %d, want %d", w.Code, http.StatusInternalServerError)
|
||||
}
|
||||
if !strings.Contains(buf.String(), "panic recovered") {
|
||||
t.Fatalf("expected 'panic recovered' in log, got %q", buf.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecovery_ResponseBody(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
panic("fail")
|
||||
})
|
||||
|
||||
mw := server.Recovery()(handler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
body := strings.TrimSpace(w.Body.String())
|
||||
if body != "Internal Server Error" {
|
||||
t.Fatalf("got body %q, want %q", body, "Internal Server Error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecovery_LogAttributes(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := slog.New(slog.NewTextHandler(&buf, nil))
|
||||
|
||||
// Put RequestID before Recovery so request_id is in context
|
||||
handler := server.RequestID()(
|
||||
server.Recovery(server.WithRecoveryLogger(logger))(
|
||||
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
panic("boom")
|
||||
}),
|
||||
),
|
||||
)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/test", nil)
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
logOutput := buf.String()
|
||||
for _, attr := range []string{"method=", "path=", "request_id="} {
|
||||
if !strings.Contains(logOutput, attr) {
|
||||
t.Fatalf("expected %q in log, got %q", attr, logOutput)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogging(t *testing.T) {
|
||||
t.Run("logs request details", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := slog.New(slog.NewTextHandler(&buf, nil))
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
})
|
||||
|
||||
mw := server.Logging(logger)(handler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/users", nil)
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
logOutput := buf.String()
|
||||
if !strings.Contains(logOutput, "request completed") {
|
||||
t.Fatalf("expected 'request completed' in log, got %q", logOutput)
|
||||
}
|
||||
if !strings.Contains(logOutput, "POST") {
|
||||
t.Fatalf("expected method in log, got %q", logOutput)
|
||||
}
|
||||
if !strings.Contains(logOutput, "/api/users") {
|
||||
t.Fatalf("expected path in log, got %q", logOutput)
|
||||
}
|
||||
if !strings.Contains(logOutput, "status=201") {
|
||||
t.Fatalf("expected status=201 in log, got %q", logOutput)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("logs error level for 5xx", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := slog.New(slog.NewTextHandler(&buf, nil))
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusBadGateway)
|
||||
})
|
||||
|
||||
mw := server.Logging(logger)(handler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
logOutput := buf.String()
|
||||
if !strings.Contains(logOutput, "level=ERROR") {
|
||||
t.Fatalf("expected ERROR level in log, got %q", logOutput)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestLogging_4xxIsInfoLevel(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := slog.New(slog.NewTextHandler(&buf, nil))
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
})
|
||||
|
||||
mw := server.Logging(logger)(handler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/missing", nil)
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
logOutput := buf.String()
|
||||
if !strings.Contains(logOutput, "level=INFO") {
|
||||
t.Fatalf("expected INFO level for 404, got %q", logOutput)
|
||||
}
|
||||
if strings.Contains(logOutput, "level=ERROR") {
|
||||
t.Fatalf("404 should not be logged as ERROR, got %q", logOutput)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogging_DefaultStatus200(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := slog.New(slog.NewTextHandler(&buf, nil))
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
_, _ = w.Write([]byte("hello"))
|
||||
})
|
||||
|
||||
mw := server.Logging(logger)(handler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if !strings.Contains(buf.String(), "status=200") {
|
||||
t.Fatalf("expected status=200 in log when handler only calls Write, got %q", buf.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogging_IncludesRequestID(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := slog.New(slog.NewTextHandler(&buf, nil))
|
||||
|
||||
handler := server.RequestID()(
|
||||
server.Logging(logger)(
|
||||
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}),
|
||||
),
|
||||
)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if !strings.Contains(buf.String(), "request_id=") {
|
||||
t.Fatalf("expected request_id in log output, got %q", buf.String())
|
||||
}
|
||||
}
|
||||
15
server/middleware_timeout.go
Normal file
15
server/middleware_timeout.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Timeout returns a middleware that limits request processing time.
|
||||
// If the handler does not complete within d, the client receives a
|
||||
// 503 Service Unavailable response. It wraps http.TimeoutHandler.
|
||||
func Timeout(d time.Duration) Middleware {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.TimeoutHandler(next, d, "Service Unavailable\n")
|
||||
}
|
||||
}
|
||||
49
server/middleware_timeout_test.go
Normal file
49
server/middleware_timeout_test.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package server_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.codelab.vc/pkg/httpx/server"
|
||||
)
|
||||
|
||||
func TestTimeout(t *testing.T) {
|
||||
t.Run("handler completes within timeout", func(t *testing.T) {
|
||||
handler := server.Timeout(1 * time.Second)(
|
||||
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("ok"))
|
||||
}),
|
||||
)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("handler exceeds timeout returns 503", func(t *testing.T) {
|
||||
handler := server.Timeout(10 * time.Millisecond)(
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
select {
|
||||
case <-time.After(1 * time.Second):
|
||||
case <-r.Context().Done():
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}),
|
||||
)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusServiceUnavailable {
|
||||
t.Fatalf("got status %d, want %d", w.Code, http.StatusServiceUnavailable)
|
||||
}
|
||||
})
|
||||
}
|
||||
89
server/options.go
Normal file
89
server/options.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"time"
|
||||
)
|
||||
|
||||
type serverOptions struct {
|
||||
addr string
|
||||
readTimeout time.Duration
|
||||
readHeaderTimeout time.Duration
|
||||
writeTimeout time.Duration
|
||||
idleTimeout time.Duration
|
||||
shutdownTimeout time.Duration
|
||||
logger *slog.Logger
|
||||
middlewares []Middleware
|
||||
onShutdown []func()
|
||||
}
|
||||
|
||||
// Option configures a Server.
|
||||
type Option func(*serverOptions)
|
||||
|
||||
// WithAddr sets the listen address. Default is ":8080".
|
||||
func WithAddr(addr string) Option {
|
||||
return func(o *serverOptions) { o.addr = addr }
|
||||
}
|
||||
|
||||
// WithReadTimeout sets the maximum duration for reading the entire request.
|
||||
func WithReadTimeout(d time.Duration) Option {
|
||||
return func(o *serverOptions) { o.readTimeout = d }
|
||||
}
|
||||
|
||||
// WithReadHeaderTimeout sets the maximum duration for reading request headers.
|
||||
func WithReadHeaderTimeout(d time.Duration) Option {
|
||||
return func(o *serverOptions) { o.readHeaderTimeout = d }
|
||||
}
|
||||
|
||||
// WithWriteTimeout sets the maximum duration before timing out writes of the response.
|
||||
func WithWriteTimeout(d time.Duration) Option {
|
||||
return func(o *serverOptions) { o.writeTimeout = d }
|
||||
}
|
||||
|
||||
// WithIdleTimeout sets the maximum amount of time to wait for the next request
|
||||
// when keep-alives are enabled.
|
||||
func WithIdleTimeout(d time.Duration) Option {
|
||||
return func(o *serverOptions) { o.idleTimeout = d }
|
||||
}
|
||||
|
||||
// WithShutdownTimeout sets the maximum duration to wait for active connections
|
||||
// to close during graceful shutdown. Default is 15 seconds.
|
||||
func WithShutdownTimeout(d time.Duration) Option {
|
||||
return func(o *serverOptions) { o.shutdownTimeout = d }
|
||||
}
|
||||
|
||||
// WithLogger sets the structured logger used by the server for lifecycle events.
|
||||
func WithLogger(l *slog.Logger) Option {
|
||||
return func(o *serverOptions) { o.logger = l }
|
||||
}
|
||||
|
||||
// WithMiddleware appends server middlewares to the chain.
|
||||
// These are applied to the handler in the order given.
|
||||
func WithMiddleware(mws ...Middleware) Option {
|
||||
return func(o *serverOptions) { o.middlewares = append(o.middlewares, mws...) }
|
||||
}
|
||||
|
||||
// WithOnShutdown registers a function to be called during graceful shutdown,
|
||||
// before the HTTP server begins draining connections.
|
||||
func WithOnShutdown(fn func()) Option {
|
||||
return func(o *serverOptions) { o.onShutdown = append(o.onShutdown, fn) }
|
||||
}
|
||||
|
||||
// Defaults returns a production-ready set of options including standard
|
||||
// middleware (RequestID, Recovery, Logging), sensible timeouts, and the
|
||||
// provided logger.
|
||||
//
|
||||
// Middleware order: RequestID → Recovery → Logging → user handler.
|
||||
func Defaults(logger *slog.Logger) []Option {
|
||||
return []Option{
|
||||
WithReadHeaderTimeout(10 * time.Second),
|
||||
WithIdleTimeout(120 * time.Second),
|
||||
WithShutdownTimeout(15 * time.Second),
|
||||
WithLogger(logger),
|
||||
WithMiddleware(
|
||||
RequestID(),
|
||||
Recovery(WithRecoveryLogger(logger)),
|
||||
Logging(logger),
|
||||
),
|
||||
}
|
||||
}
|
||||
29
server/respond.go
Normal file
29
server/respond.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// WriteJSON encodes v as JSON and writes it to w with the given status code.
|
||||
// It sets Content-Type to application/json.
|
||||
func WriteJSON(w http.ResponseWriter, status int, v any) error {
|
||||
b, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
_, err = w.Write(b)
|
||||
return err
|
||||
}
|
||||
|
||||
// WriteError writes a JSON error response with the given status code and
|
||||
// message. The response body is {"error": "<message>"}.
|
||||
func WriteError(w http.ResponseWriter, status int, msg string) error {
|
||||
return WriteJSON(w, status, errorBody{Error: msg})
|
||||
}
|
||||
|
||||
type errorBody struct {
|
||||
Error string `json:"error"`
|
||||
}
|
||||
72
server/respond_test.go
Normal file
72
server/respond_test.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package server_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"git.codelab.vc/pkg/httpx/server"
|
||||
)
|
||||
|
||||
func TestWriteJSON(t *testing.T) {
|
||||
t.Run("writes JSON with status and content type", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
type resp struct {
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
err := server.WriteJSON(w, 201, resp{ID: 1, Name: "Alice"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if w.Code != 201 {
|
||||
t.Fatalf("got status %d, want %d", w.Code, 201)
|
||||
}
|
||||
if ct := w.Header().Get("Content-Type"); ct != "application/json" {
|
||||
t.Fatalf("Content-Type = %q, want %q", ct, "application/json")
|
||||
}
|
||||
|
||||
var decoded resp
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &decoded); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if decoded.ID != 1 || decoded.Name != "Alice" {
|
||||
t.Fatalf("got %+v, want {ID:1 Name:Alice}", decoded)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns error for unmarshalable input", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
err := server.WriteJSON(w, 200, make(chan int))
|
||||
if err == nil {
|
||||
t.Fatal("expected error for channel type")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWriteError(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
err := server.WriteError(w, 404, "not found")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if w.Code != 404 {
|
||||
t.Fatalf("got status %d, want %d", w.Code, 404)
|
||||
}
|
||||
|
||||
var body struct {
|
||||
Error string `json:"error"`
|
||||
}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if body.Error != "not found" {
|
||||
t.Fatalf("error = %q, want %q", body.Error, "not found")
|
||||
}
|
||||
}
|
||||
126
server/route.go
Normal file
126
server/route.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Router is a lightweight wrapper around http.ServeMux that adds middleware
|
||||
// groups and sub-router mounting. It leverages Go 1.22+ enhanced patterns
|
||||
// like "GET /users/{id}".
|
||||
type Router struct {
|
||||
mux *http.ServeMux
|
||||
prefix string
|
||||
middlewares []Middleware
|
||||
notFoundHandler http.Handler
|
||||
}
|
||||
|
||||
// RouterOption configures a Router.
|
||||
type RouterOption func(*Router)
|
||||
|
||||
// WithNotFoundHandler sets a custom handler for requests that don't match
|
||||
// any registered pattern. This is useful for returning JSON 404/405 responses
|
||||
// instead of the default plain text.
|
||||
func WithNotFoundHandler(h http.Handler) RouterOption {
|
||||
return func(r *Router) { r.notFoundHandler = h }
|
||||
}
|
||||
|
||||
// NewRouter creates a new Router backed by a fresh http.ServeMux.
|
||||
func NewRouter(opts ...RouterOption) *Router {
|
||||
r := &Router{
|
||||
mux: http.NewServeMux(),
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(r)
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// Handle registers a handler for the given pattern. The pattern follows
|
||||
// http.ServeMux conventions, including method-based patterns like "GET /users".
|
||||
func (r *Router) Handle(pattern string, handler http.Handler) {
|
||||
if len(r.middlewares) > 0 {
|
||||
handler = Chain(r.middlewares...)(handler)
|
||||
}
|
||||
r.mux.Handle(r.prefixedPattern(pattern), handler)
|
||||
}
|
||||
|
||||
// HandleFunc registers a handler function for the given pattern.
|
||||
func (r *Router) HandleFunc(pattern string, fn http.HandlerFunc) {
|
||||
r.Handle(pattern, fn)
|
||||
}
|
||||
|
||||
// Group creates a sub-router with a shared prefix and optional middleware.
|
||||
// Patterns registered on the group are prefixed automatically. The group
|
||||
// shares the underlying ServeMux with the parent router.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// api := router.Group("/api/v1", authMiddleware)
|
||||
// api.HandleFunc("GET /users", listUsers) // registers "GET /api/v1/users"
|
||||
func (r *Router) Group(prefix string, mws ...Middleware) *Router {
|
||||
return &Router{
|
||||
mux: r.mux,
|
||||
prefix: r.prefix + prefix,
|
||||
middlewares: append(r.middlewaresSnapshot(), mws...),
|
||||
}
|
||||
}
|
||||
|
||||
// Mount attaches an http.Handler under the given prefix. All requests
|
||||
// starting with prefix are forwarded to the handler with the prefix stripped.
|
||||
func (r *Router) Mount(prefix string, handler http.Handler) {
|
||||
full := r.prefix + prefix
|
||||
if !strings.HasSuffix(full, "/") {
|
||||
full += "/"
|
||||
}
|
||||
r.mux.Handle(full, http.StripPrefix(strings.TrimSuffix(full, "/"), handler))
|
||||
}
|
||||
|
||||
// ServeHTTP implements http.Handler, making Router usable as a handler.
|
||||
func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
if r.notFoundHandler != nil {
|
||||
// Use the mux to check for a match. If none, use the custom handler.
|
||||
_, pattern := r.mux.Handler(req)
|
||||
if pattern == "" {
|
||||
r.notFoundHandler.ServeHTTP(w, req)
|
||||
return
|
||||
}
|
||||
}
|
||||
r.mux.ServeHTTP(w, req)
|
||||
}
|
||||
|
||||
// prefixedPattern inserts the router prefix into a pattern. It is aware of
|
||||
// method prefixes: "GET /users" with prefix "/api" becomes "GET /api/users".
|
||||
func (r *Router) prefixedPattern(pattern string) string {
|
||||
if r.prefix == "" {
|
||||
return pattern
|
||||
}
|
||||
|
||||
// Split method prefix if present: "GET /users" → method="GET ", path="/users"
|
||||
method, path, hasMethod := splitMethodPattern(pattern)
|
||||
|
||||
path = r.prefix + path
|
||||
|
||||
if hasMethod {
|
||||
return method + path
|
||||
}
|
||||
return path
|
||||
}
|
||||
|
||||
// splitMethodPattern splits "GET /path" into ("GET ", "/path", true).
|
||||
// If there is no method prefix, returns ("", pattern, false).
|
||||
func splitMethodPattern(pattern string) (method, path string, hasMethod bool) {
|
||||
if idx := strings.IndexByte(pattern, ' '); idx >= 0 {
|
||||
return pattern[:idx+1], pattern[idx+1:], true
|
||||
}
|
||||
return "", pattern, false
|
||||
}
|
||||
|
||||
func (r *Router) middlewaresSnapshot() []Middleware {
|
||||
if len(r.middlewares) == 0 {
|
||||
return nil
|
||||
}
|
||||
cp := make([]Middleware, len(r.middlewares))
|
||||
copy(cp, r.middlewares)
|
||||
return cp
|
||||
}
|
||||
70
server/route_notfound_test.go
Normal file
70
server/route_notfound_test.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package server_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"git.codelab.vc/pkg/httpx/server"
|
||||
)
|
||||
|
||||
func TestRouter_NotFoundHandler(t *testing.T) {
|
||||
t.Run("custom 404 handler", func(t *testing.T) {
|
||||
notFound := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": "not found"})
|
||||
})
|
||||
|
||||
r := server.NewRouter(server.WithNotFoundHandler(notFound))
|
||||
r.HandleFunc("GET /exists", func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Matched route works normally.
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/exists", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("matched route: got status %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
// Unmatched route uses custom handler.
|
||||
w = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/nope", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Fatalf("not found: got status %d, want %d", w.Code, http.StatusNotFound)
|
||||
}
|
||||
if ct := w.Header().Get("Content-Type"); ct != "application/json" {
|
||||
t.Fatalf("Content-Type = %q, want %q", ct, "application/json")
|
||||
}
|
||||
|
||||
var body map[string]string
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if body["error"] != "not found" {
|
||||
t.Fatalf("error = %q, want %q", body["error"], "not found")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("default behavior without custom handler", func(t *testing.T) {
|
||||
r := server.NewRouter()
|
||||
r.HandleFunc("GET /exists", func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/nope", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
// Default ServeMux returns 404.
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Fatalf("got status %d, want %d", w.Code, http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
}
|
||||
336
server/route_test.go
Normal file
336
server/route_test.go
Normal file
@@ -0,0 +1,336 @@
|
||||
package server_test
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"git.codelab.vc/pkg/httpx/server"
|
||||
)
|
||||
|
||||
func TestRouter(t *testing.T) {
|
||||
t.Run("basic route", func(t *testing.T) {
|
||||
r := server.NewRouter()
|
||||
r.HandleFunc("GET /hello", func(w http.ResponseWriter, _ *http.Request) {
|
||||
_, _ = w.Write([]byte("world"))
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/hello", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
if body := w.Body.String(); body != "world" {
|
||||
t.Fatalf("got body %q, want %q", body, "world")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Handle with http.Handler", func(t *testing.T) {
|
||||
r := server.NewRouter()
|
||||
r.Handle("GET /ping", http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
_, _ = w.Write([]byte("pong"))
|
||||
}))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/ping", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if body := w.Body.String(); body != "pong" {
|
||||
t.Fatalf("got body %q, want %q", body, "pong")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("path parameter", func(t *testing.T) {
|
||||
r := server.NewRouter()
|
||||
r.HandleFunc("GET /users/{id}", func(w http.ResponseWriter, req *http.Request) {
|
||||
_, _ = w.Write([]byte("user:" + req.PathValue("id")))
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/users/42", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if body := w.Body.String(); body != "user:42" {
|
||||
t.Fatalf("got body %q, want %q", body, "user:42")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRouterGroup(t *testing.T) {
|
||||
t.Run("prefix is applied", func(t *testing.T) {
|
||||
r := server.NewRouter()
|
||||
api := r.Group("/api/v1")
|
||||
api.HandleFunc("GET /users", func(w http.ResponseWriter, _ *http.Request) {
|
||||
_, _ = w.Write([]byte("users"))
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/users", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
if body := w.Body.String(); body != "users" {
|
||||
t.Fatalf("got body %q, want %q", body, "users")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("nested groups", func(t *testing.T) {
|
||||
r := server.NewRouter()
|
||||
api := r.Group("/api")
|
||||
v1 := api.Group("/v1")
|
||||
v1.HandleFunc("GET /items", func(w http.ResponseWriter, _ *http.Request) {
|
||||
_, _ = w.Write([]byte("items"))
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/items", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if body := w.Body.String(); body != "items" {
|
||||
t.Fatalf("got body %q, want %q", body, "items")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("group middleware", func(t *testing.T) {
|
||||
var mwCalled bool
|
||||
mw := func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mwCalled = true
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
r := server.NewRouter()
|
||||
g := r.Group("/admin", mw)
|
||||
g.HandleFunc("GET /dashboard", func(w http.ResponseWriter, _ *http.Request) {
|
||||
_, _ = w.Write([]byte("ok"))
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if !mwCalled {
|
||||
t.Fatal("group middleware was not called")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRouterMount(t *testing.T) {
|
||||
t.Run("mounts sub-handler with prefix stripping", func(t *testing.T) {
|
||||
sub := http.NewServeMux()
|
||||
sub.HandleFunc("GET /info", func(w http.ResponseWriter, _ *http.Request) {
|
||||
_, _ = w.Write([]byte("info"))
|
||||
})
|
||||
|
||||
r := server.NewRouter()
|
||||
r.Mount("/sub", sub)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/sub/info", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
body, _ := io.ReadAll(w.Body)
|
||||
if string(body) != "info" {
|
||||
t.Fatalf("got body %q, want %q", body, "info")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("mount with trailing slash", func(t *testing.T) {
|
||||
sub := http.NewServeMux()
|
||||
sub.HandleFunc("GET /data", func(w http.ResponseWriter, _ *http.Request) {
|
||||
_, _ = w.Write([]byte("data"))
|
||||
})
|
||||
|
||||
r := server.NewRouter()
|
||||
r.Mount("/sub/", sub)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/sub/data", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
body, _ := io.ReadAll(w.Body)
|
||||
if string(body) != "data" {
|
||||
t.Fatalf("got body %q, want %q", body, "data")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRouter_PatternWithoutMethod(t *testing.T) {
|
||||
r := server.NewRouter()
|
||||
r.HandleFunc("/static/", func(w http.ResponseWriter, _ *http.Request) {
|
||||
_, _ = w.Write([]byte("static"))
|
||||
})
|
||||
|
||||
for _, method := range []string{http.MethodGet, http.MethodPost} {
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(method, "/static/file.css", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("%s /static/file.css: got status %d, want %d", method, w.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_GroupEmptyPrefix(t *testing.T) {
|
||||
r := server.NewRouter()
|
||||
g := r.Group("")
|
||||
g.HandleFunc("GET /hello", func(w http.ResponseWriter, _ *http.Request) {
|
||||
_, _ = w.Write([]byte("hello"))
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/hello", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
if body := w.Body.String(); body != "hello" {
|
||||
t.Fatalf("got body %q, want %q", body, "hello")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_GroupInheritsMiddleware(t *testing.T) {
|
||||
var order []string
|
||||
|
||||
parentMW := func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
order = append(order, "parent")
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
childMW := func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
order = append(order, "child")
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
r := server.NewRouter()
|
||||
parent := r.Group("/api", parentMW)
|
||||
child := parent.Group("/v1", childMW)
|
||||
child.HandleFunc("GET /items", func(w http.ResponseWriter, _ *http.Request) {
|
||||
order = append(order, "handler")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/items", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
expected := []string{"parent", "child", "handler"}
|
||||
if len(order) != len(expected) {
|
||||
t.Fatalf("got %v, want %v", order, expected)
|
||||
}
|
||||
for i, v := range expected {
|
||||
if order[i] != v {
|
||||
t.Fatalf("order[%d] = %q, want %q", i, order[i], v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_GroupMiddlewareOrder(t *testing.T) {
|
||||
var order []string
|
||||
|
||||
mwA := func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
order = append(order, "A")
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
mwB := func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
order = append(order, "B")
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
r := server.NewRouter()
|
||||
g := r.Group("/api", mwA)
|
||||
sub := g.Group("/v1", mwB)
|
||||
sub.HandleFunc("GET /test", func(w http.ResponseWriter, _ *http.Request) {
|
||||
order = append(order, "handler")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/test", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
// Parent MW (A) should run before child MW (B), then handler.
|
||||
expected := []string{"A", "B", "handler"}
|
||||
if len(order) != len(expected) {
|
||||
t.Fatalf("got %v, want %v", order, expected)
|
||||
}
|
||||
for i, v := range expected {
|
||||
if order[i] != v {
|
||||
t.Fatalf("order[%d] = %q, want %q", i, order[i], v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_PathParamWithGroup(t *testing.T) {
|
||||
r := server.NewRouter()
|
||||
api := r.Group("/api")
|
||||
api.HandleFunc("GET /users/{id}", func(w http.ResponseWriter, req *http.Request) {
|
||||
_, _ = w.Write([]byte("id=" + req.PathValue("id")))
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/users/42", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
if body := w.Body.String(); body != "id=42" {
|
||||
t.Fatalf("got body %q, want %q", body, "id=42")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_MiddlewareNotAppliedToOtherRoutes(t *testing.T) {
|
||||
var mwCalled bool
|
||||
mw := func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mwCalled = true
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
r := server.NewRouter()
|
||||
|
||||
// Add middleware only to /admin group.
|
||||
admin := r.Group("/admin", mw)
|
||||
admin.HandleFunc("GET /dashboard", func(w http.ResponseWriter, _ *http.Request) {
|
||||
_, _ = w.Write([]byte("admin"))
|
||||
})
|
||||
|
||||
// Route outside the group.
|
||||
r.HandleFunc("GET /public", func(w http.ResponseWriter, _ *http.Request) {
|
||||
_, _ = w.Write([]byte("public"))
|
||||
})
|
||||
|
||||
// Request to /public should NOT trigger group middleware.
|
||||
mwCalled = false
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/public", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if mwCalled {
|
||||
t.Fatal("group middleware should not be called for routes outside the group")
|
||||
}
|
||||
if w.Body.String() != "public" {
|
||||
t.Fatalf("got body %q, want %q", w.Body.String(), "public")
|
||||
}
|
||||
}
|
||||
173
server/server.go
Normal file
173
server/server.go
Normal file
@@ -0,0 +1,173 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"os/signal"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Server is a production-ready HTTP server with graceful shutdown,
|
||||
// middleware support, and signal handling.
|
||||
type Server struct {
|
||||
httpServer *http.Server
|
||||
listener net.Listener
|
||||
addr atomic.Value
|
||||
logger *slog.Logger
|
||||
shutdownTimeout time.Duration
|
||||
onShutdown []func()
|
||||
listenAddr string
|
||||
}
|
||||
|
||||
// New creates a new Server that will serve the given handler with the
|
||||
// provided options. Middleware from options is applied to the handler.
|
||||
func New(handler http.Handler, opts ...Option) *Server {
|
||||
o := &serverOptions{
|
||||
addr: ":8080",
|
||||
shutdownTimeout: 15 * time.Second,
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(o)
|
||||
}
|
||||
|
||||
// Apply middleware chain to the handler.
|
||||
if len(o.middlewares) > 0 {
|
||||
handler = Chain(o.middlewares...)(handler)
|
||||
}
|
||||
|
||||
srv := &Server{
|
||||
httpServer: &http.Server{
|
||||
Handler: handler,
|
||||
ReadTimeout: o.readTimeout,
|
||||
ReadHeaderTimeout: o.readHeaderTimeout,
|
||||
WriteTimeout: o.writeTimeout,
|
||||
IdleTimeout: o.idleTimeout,
|
||||
},
|
||||
logger: o.logger,
|
||||
shutdownTimeout: o.shutdownTimeout,
|
||||
onShutdown: o.onShutdown,
|
||||
listenAddr: o.addr,
|
||||
}
|
||||
|
||||
return srv
|
||||
}
|
||||
|
||||
// ListenAndServe starts the server and blocks until a SIGINT or SIGTERM
|
||||
// signal is received. It then performs a graceful shutdown within the
|
||||
// configured shutdown timeout.
|
||||
//
|
||||
// Returns nil on clean shutdown or an error if listen/shutdown fails.
|
||||
func (s *Server) ListenAndServe() error {
|
||||
ln, err := net.Listen("tcp", s.listenAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.listener = ln
|
||||
s.addr.Store(ln.Addr().String())
|
||||
|
||||
s.log("server started", slog.String("addr", ln.Addr().String()))
|
||||
|
||||
// Wait for signal in context.
|
||||
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
|
||||
defer stop()
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
if err := s.httpServer.Serve(ln); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
errCh <- err
|
||||
}
|
||||
close(errCh)
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
stop()
|
||||
return s.shutdown()
|
||||
}
|
||||
}
|
||||
|
||||
// ListenAndServeTLS starts the server with TLS and blocks until a signal
|
||||
// is received.
|
||||
func (s *Server) ListenAndServeTLS(certFile, keyFile string) error {
|
||||
ln, err := net.Listen("tcp", s.listenAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.listener = ln
|
||||
s.addr.Store(ln.Addr().String())
|
||||
|
||||
s.log("server started (TLS)", slog.String("addr", ln.Addr().String()))
|
||||
|
||||
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
|
||||
defer stop()
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
if err := s.httpServer.ServeTLS(ln, certFile, keyFile); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
errCh <- err
|
||||
}
|
||||
close(errCh)
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
stop()
|
||||
return s.shutdown()
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown gracefully shuts down the server. It calls any registered
|
||||
// onShutdown hooks, then waits for active connections to drain within
|
||||
// the shutdown timeout.
|
||||
func (s *Server) Shutdown(ctx context.Context) error {
|
||||
s.runOnShutdown()
|
||||
return s.httpServer.Shutdown(ctx)
|
||||
}
|
||||
|
||||
// Addr returns the listener address after the server has started.
|
||||
// Returns an empty string if the server has not started yet.
|
||||
func (s *Server) Addr() string {
|
||||
v := s.addr.Load()
|
||||
if v == nil {
|
||||
return ""
|
||||
}
|
||||
return v.(string)
|
||||
}
|
||||
|
||||
func (s *Server) shutdown() error {
|
||||
s.log("shutting down")
|
||||
|
||||
s.runOnShutdown()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), s.shutdownTimeout)
|
||||
defer cancel()
|
||||
|
||||
if err := s.httpServer.Shutdown(ctx); err != nil {
|
||||
s.log("shutdown error", slog.String("error", err.Error()))
|
||||
return err
|
||||
}
|
||||
|
||||
s.log("server stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) runOnShutdown() {
|
||||
for _, fn := range s.onShutdown {
|
||||
fn()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) log(msg string, attrs ...slog.Attr) {
|
||||
if s.logger != nil {
|
||||
s.logger.LogAttrs(context.Background(), slog.LevelInfo, msg, attrs...)
|
||||
}
|
||||
}
|
||||
268
server/server_test.go
Normal file
268
server/server_test.go
Normal file
@@ -0,0 +1,268 @@
|
||||
package server_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.codelab.vc/pkg/httpx/server"
|
||||
)
|
||||
|
||||
func TestServerLifecycle(t *testing.T) {
|
||||
t.Run("starts and serves requests", func(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("hello"))
|
||||
})
|
||||
|
||||
srv := server.New(handler, server.WithAddr(":0"))
|
||||
|
||||
// Start in background and wait for addr.
|
||||
errCh := make(chan error, 1)
|
||||
go func() { errCh <- srv.ListenAndServe() }()
|
||||
|
||||
waitForAddr(t, srv)
|
||||
|
||||
resp, err := http.Get("http://" + srv.Addr())
|
||||
if err != nil {
|
||||
t.Fatalf("GET failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if string(body) != "hello" {
|
||||
t.Fatalf("got body %q, want %q", body, "hello")
|
||||
}
|
||||
|
||||
// Shutdown.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := srv.Shutdown(ctx); err != nil {
|
||||
t.Fatalf("shutdown failed: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("addr returns empty before start", func(t *testing.T) {
|
||||
srv := server.New(http.NotFoundHandler())
|
||||
if addr := srv.Addr(); addr != "" {
|
||||
t.Fatalf("got addr %q before start, want empty", addr)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGracefulShutdown(t *testing.T) {
|
||||
t.Run("calls onShutdown hooks", func(t *testing.T) {
|
||||
called := false
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
srv := server.New(handler,
|
||||
server.WithAddr(":0"),
|
||||
server.WithOnShutdown(func() { called = true }),
|
||||
)
|
||||
|
||||
go func() { _ = srv.ListenAndServe() }()
|
||||
waitForAddr(t, srv)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := srv.Shutdown(ctx); err != nil {
|
||||
t.Fatalf("shutdown failed: %v", err)
|
||||
}
|
||||
|
||||
if !called {
|
||||
t.Fatal("onShutdown hook was not called")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestServerWithMiddleware(t *testing.T) {
|
||||
t.Run("applies middleware from options", func(t *testing.T) {
|
||||
var called bool
|
||||
mw := func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
called = true
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
srv := server.New(handler,
|
||||
server.WithAddr(":0"),
|
||||
server.WithMiddleware(mw),
|
||||
)
|
||||
|
||||
go func() { _ = srv.ListenAndServe() }()
|
||||
waitForAddr(t, srv)
|
||||
|
||||
resp, err := http.Get("http://" + srv.Addr())
|
||||
if err != nil {
|
||||
t.Fatalf("GET failed: %v", err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
if !called {
|
||||
t.Fatal("middleware was not called")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = srv.Shutdown(ctx)
|
||||
})
|
||||
}
|
||||
|
||||
func TestServerDefaults(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := slog.New(slog.NewTextHandler(&buf, nil))
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
srv := server.New(handler, append(server.Defaults(logger), server.WithAddr(":0"))...)
|
||||
|
||||
go func() { _ = srv.ListenAndServe() }()
|
||||
waitForAddr(t, srv)
|
||||
|
||||
resp, err := http.Get("http://" + srv.Addr())
|
||||
if err != nil {
|
||||
t.Fatalf("GET failed: %v", err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
// Defaults includes RequestID middleware, so response should have X-Request-Id.
|
||||
if resp.Header.Get("X-Request-Id") == "" {
|
||||
t.Fatal("expected X-Request-Id header from Defaults middleware")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = srv.Shutdown(ctx)
|
||||
}
|
||||
|
||||
func TestServerListenError(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Use an invalid address to trigger a listen error.
|
||||
srv := server.New(handler, server.WithAddr(":-1"))
|
||||
|
||||
err := srv.ListenAndServe()
|
||||
if err == nil {
|
||||
t.Fatal("expected error from invalid address, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerMultipleOnShutdownHooks(t *testing.T) {
|
||||
var calls []int
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
srv := server.New(handler,
|
||||
server.WithAddr(":0"),
|
||||
server.WithOnShutdown(func() { calls = append(calls, 1) }),
|
||||
server.WithOnShutdown(func() { calls = append(calls, 2) }),
|
||||
server.WithOnShutdown(func() { calls = append(calls, 3) }),
|
||||
)
|
||||
|
||||
go func() { _ = srv.ListenAndServe() }()
|
||||
waitForAddr(t, srv)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := srv.Shutdown(ctx); err != nil {
|
||||
t.Fatalf("shutdown failed: %v", err)
|
||||
}
|
||||
|
||||
if len(calls) != 3 {
|
||||
t.Fatalf("expected 3 hooks called, got %d: %v", len(calls), calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerShutdownWithLogger(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := slog.New(slog.NewTextHandler(&buf, nil))
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
srv := server.New(handler,
|
||||
server.WithAddr(":0"),
|
||||
server.WithLogger(logger),
|
||||
)
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() { errCh <- srv.ListenAndServe() }()
|
||||
waitForAddr(t, srv)
|
||||
|
||||
// Send SIGINT to trigger graceful shutdown via ListenAndServe's signal handler.
|
||||
// Instead, use Shutdown directly and check log from server start.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = srv.Shutdown(ctx)
|
||||
|
||||
// The server logs "server started" on ListenAndServe.
|
||||
logOutput := buf.String()
|
||||
if !strings.Contains(logOutput, "server started") {
|
||||
t.Fatalf("expected 'server started' in log, got %q", logOutput)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerOptions(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Verify options don't panic and server starts correctly.
|
||||
srv := server.New(handler,
|
||||
server.WithAddr(":0"),
|
||||
server.WithReadTimeout(5*time.Second),
|
||||
server.WithReadHeaderTimeout(3*time.Second),
|
||||
server.WithWriteTimeout(10*time.Second),
|
||||
server.WithIdleTimeout(60*time.Second),
|
||||
server.WithShutdownTimeout(5*time.Second),
|
||||
)
|
||||
|
||||
go func() { _ = srv.ListenAndServe() }()
|
||||
waitForAddr(t, srv)
|
||||
|
||||
resp, err := http.Get("http://" + srv.Addr())
|
||||
if err != nil {
|
||||
t.Fatalf("GET failed: %v", err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("got status %d, want %d", resp.StatusCode, http.StatusOK)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = srv.Shutdown(ctx)
|
||||
}
|
||||
|
||||
// waitForAddr polls until the server's Addr() is non-empty.
|
||||
func waitForAddr(t *testing.T, srv *server.Server) {
|
||||
t.Helper()
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
if srv.Addr() != "" {
|
||||
return
|
||||
}
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
}
|
||||
t.Fatal("server did not start in time")
|
||||
}
|
||||
Reference in New Issue
Block a user