Compare commits

..

29 Commits

Author SHA1 Message Date
b6350185d9 Fix publish workflow: use git archive instead of go install
All checks were successful
CI / test (push) Successful in 51s
Publish / publish (push) Successful in 51s
golang.org/x/mod/zip is a library, not a CLI tool. Use git archive
with --prefix to create the module zip in the format Gitea expects.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-23 10:39:35 +03:00
85cdc5e2c9 Add publish workflow for Gitea Go Package Registry
Some checks failed
CI / test (push) Successful in 32s
Publish / publish (push) Failing after 37s
Publishes the module to Gitea Package Registry on tag push (v*).
Runs vet and tests before publishing to prevent broken releases.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-23 10:25:19 +03:00
25beb2f5c2 Add AGENTS.md reference to CLAUDE.md
All checks were successful
CI / test (push) Successful in 31s
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-22 22:20:42 +03:00
16ff427c93 Add AI agent configuration files (AGENTS.md, .cursorrules, copilot-instructions)
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-22 22:20:38 +03:00
138d4b6c6d Add package-level doc comments for go doc and gopls
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-22 22:20:33 +03:00
3aa7536328 Add examples/ with runnable usage demos for all major features
All checks were successful
CI / test (push) Successful in 31s
Six examples covering the full API surface:
- basic-client: retry, timeout, logging, response size limit
- form-request: form-encoded POST for OAuth/webhooks
- load-balancing: weighted endpoints, circuit breaker, health checks
- server-basic: routing, groups, JSON helpers, health, custom 404
- server-protected: CORS, rate limiting, body limits, timeouts
- request-id-propagation: cross-service request ID forwarding

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-22 22:05:08 +03:00
89cfc38f0e Update documentation with new client and server features
All checks were successful
CI / test (push) Successful in 30s
README: add PATCH, NewFormRequest, CORS, RateLimit, MaxBodySize,
Timeout, WriteJSON/WriteError, request ID propagation, response body
limit, and custom 404 handler examples. CLAUDE.md: document new
architecture details for all added components.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-22 21:55:15 +03:00
8a63f142a7 Add WithNotFoundHandler option for custom 404 responses on Router
Allows configuring a custom handler for unmatched routes, enabling
consistent JSON error responses instead of ServeMux's default plain
text. NewRouter now accepts RouterOption functional options.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-22 21:48:13 +03:00
21274c178a Add WithMaxResponseBody option to prevent client-side OOM
Wraps response body with io.LimitedReader when configured, preventing
unbounded reads from io.ReadAll in Response.Bytes(). Protects against
upstream services returning unexpectedly large responses.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-22 21:48:05 +03:00
49be6f8a7e Add client RequestID middleware for cross-service propagation
Introduces internal/requestid package with shared context key to avoid
circular imports between server and middleware packages. Server's
RequestID middleware now uses the shared key. Client middleware picks up
the ID from context and sets X-Request-Id on outgoing requests.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-22 21:47:58 +03:00
3395f70abd Add server RateLimit middleware with per-key token bucket
Protects against abuse with configurable rate/burst per client IP.
Supports custom key functions, X-Forwarded-For extraction, and
Retry-After headers on 429 responses. Uses internal/clock for
testability.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-22 21:47:51 +03:00
7a2cef00c3 Add server WriteJSON and WriteError response helpers
Eliminates repeated marshal-set-header-write boilerplate in handlers.
WriteError produces consistent {"error": "..."} JSON responses.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-22 21:47:45 +03:00
de5bf9a6d9 Add server CORS middleware with preflight handling
Supports AllowOrigins, AllowMethods, AllowHeaders, ExposeHeaders,
AllowCredentials, and MaxAge options. Handles preflight OPTIONS requests
correctly, including Vary header and credential-aware origin echoing.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-22 21:47:39 +03:00
7f12b0c87a Add server Timeout middleware for context-based request deadlines
Wraps http.TimeoutHandler to return 503 when handlers exceed the
configured duration. Unlike http.Server.WriteTimeout, this allows
handlers to complete gracefully via context cancellation.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-22 21:47:33 +03:00
1b322c8c81 Add server MaxBodySize middleware to prevent memory exhaustion
Wraps request body with http.MaxBytesReader to limit incoming payload
size. Without this, any endpoint accepting a body is vulnerable to
large uploads consuming all available memory.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-22 21:47:26 +03:00
b40a373675 Add NewFormRequest for form-encoded HTTP requests
Creates requests with application/x-www-form-urlencoded body from
url.Values. Supports GetBody for retry compatibility, following the
same pattern as NewJSONRequest.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-22 21:47:19 +03:00
f6384ecbea Add Client.Patch method for PATCH HTTP requests
Follows the same pattern as Put/Post, accepting context, URL, and body.
Closes an obvious gap in the REST client API.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-22 21:47:11 +03:00
4d47918a66 Update documentation with server package details
All checks were successful
CI / test (push) Successful in 30s
Add server package description, component table, and usage example to
README. Document server architecture, middleware chain, and test
conventions in CLAUDE.md.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-21 22:56:12 +03:00
7fae6247d5 Add comprehensive test coverage for server/ package
All checks were successful
CI / test (push) Successful in 30s
Cover edge cases: statusWriter multi-call/default/unwrap, UUID v4 format
and uniqueness, non-string panics, recovery body and log attributes,
4xx log level, default status in logging, request ID propagation,
server defaults/options/listen-error/multiple-hooks/logger, router
groups with empty prefix/inherited middleware/ordering/path params/
isolation, mount trailing slash, health content-type and POST rejection.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-21 13:55:22 +03:00
cea75d198b Add production-ready HTTP server package with routing, health checks, and middleware
Introduces server/ sub-package as the server-side companion to the existing Client.
Includes Router (over http.ServeMux with groups and mounting), graceful shutdown with
signal handling, health endpoints (/healthz, /readyz), and built-in middlewares
(RequestID, Recovery, Logging). Zero external dependencies.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-21 13:41:54 +03:00
6b901c931e Add CLAUDE.md and Gitea CI workflow
All checks were successful
CI / test (push) Successful in 1m7s
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-21 13:06:30 +03:00
d260abc393 Add README with usage examples and package overview 2026-03-20 15:23:13 +03:00
5cfd1a7400 Fix sentinel error aliasing, hot-path allocations, and resource leaks
- Deduplicate sentinel errors: httpx.ErrNoHealthy, ErrCircuitOpen, and
  ErrRetryExhausted are now aliases to the canonical sub-package values
  so errors.Is works across package boundaries
- Retry transport returns ErrRetryExhausted only when all attempts are
  actually exhausted, not on early policy exit
- Balancer: pre-parse endpoint URLs at construction, replace req.Clone
  with cheap shallow struct copy to avoid per-request allocations
- Circuit breaker: Load before LoadOrStore to avoid allocating a Breaker
  on every request for known hosts
- Health checker: drain response body before close for connection reuse,
  probe endpoints concurrently, run initial probe synchronously in Start
- Client: add Close() to shut down health checker goroutine, propagate
  URL resolution errors instead of silently discarding them
- MockClock: fix lock ordering in Reset (clock.mu before t.mu), fix
  timer slice compaction to avoid backing-array aliasing, extract
  fireExpired to deduplicate Advance/Set
2026-03-20 15:21:32 +03:00
f9a05f5c57 Add Client with response wrapper, request helpers, and full middleware assembly
Implements the top-level httpx.Client that composes the full chain:
  Logging → User Middlewares → Retry → Circuit Breaker → Balancer → Transport

- Response wrapper with JSON/XML/Bytes decoding and body caching
- NewJSONRequest helper with Content-Type and GetBody support
- Functional options: WithBaseURL, WithTimeout, WithRetry, WithEndpoints, etc.
- Integration tests covering retry, balancing, error mapping, and JSON round-trips
2026-03-20 14:22:22 +03:00
a90c4cd7fa Add standard middlewares: logging, headers, auth, and panic recovery
- Logging: structured slog output with method, URL, status, duration
- DefaultHeaders/UserAgent: inject headers without overwriting existing
- BearerAuth/BasicAuth: per-request token resolution and static credentials
- Recovery: catches panics in the RoundTripper chain
2026-03-20 14:22:14 +03:00
8d322123a4 Add load balancer with round-robin, failover, and weighted strategies
Implements balancer middleware with URL rewriting per-request:
- RoundRobin, Failover, and WeightedRandom endpoint selection strategies
- Background HealthChecker with configurable probe interval and path
- Thread-safe health state tracking with sync.RWMutex
2026-03-20 14:22:07 +03:00
2ca930236d Add per-host circuit breaker with three-state machine
Implements circuit breaker as a RoundTripper middleware:
- Closed → Open after consecutive failure threshold
- Open → HalfOpen after configurable duration
- HalfOpen → Closed on success, back to Open on failure
- Per-host tracking via sync.Map for independent endpoint isolation
2026-03-20 14:22:00 +03:00
505c7b8c4f Add retry transport with configurable backoff and Retry-After support
Implements retry middleware as a RoundTripper wrapper:
- Exponential and constant backoff strategies with jitter
- RFC 7231 Retry-After header parsing (seconds and HTTP-date)
- Default policy retries idempotent methods on 429/5xx and network errors
- Body restoration via GetBody, context cancellation, response body cleanup
2026-03-20 14:21:53 +03:00
6b1941fce7 Add foundation: middleware type, error types, and internal clock
Introduce the core building blocks for the httpx library:
- middleware.Middleware type and Chain() composer
- Error struct with sentinel errors (ErrRetryExhausted, ErrCircuitOpen, ErrNoHealthy)
- internal/clock package with Clock interface and MockClock for deterministic testing
2026-03-20 14:21:43 +03:00
86 changed files with 7854 additions and 0 deletions

48
.cursorrules Normal file
View File

@@ -0,0 +1,48 @@
You are working on `git.codelab.vc/pkg/httpx`, a Go 1.24 HTTP client/server library with zero external dependencies.
## Architecture
- Client middleware: `func(http.RoundTripper) http.RoundTripper` — compose with `middleware.Chain`
- Server middleware: `func(http.Handler) http.Handler` — compose with `server.Chain`
- All configuration uses functional options pattern (`WithXxx` functions)
- Chain order for client: Logging → User MW → Retry → Circuit Breaker → Balancer → Transport
## Package structure
- `httpx` (root) — Client, request builders (NewJSONRequest, NewFormRequest), error types
- `middleware/` — client-side middleware (Logging, Recovery, Auth, Headers, RequestID)
- `retry/` — retry middleware with exponential backoff and Retry-After support
- `circuitbreaker/` — per-host circuit breaker (sync.Map of host → Breaker)
- `balancer/` — load balancing with health checking (RoundRobin, Weighted, Failover)
- `server/` — Server, Router, server middleware (RequestID, Recovery, Logging, CORS, RateLimit, MaxBodySize, Timeout), response helpers (WriteJSON, WriteError)
- `internal/requestid/` — shared context key (avoids circular import between server and middleware)
- `internal/clock/` — deterministic time for tests
## Code conventions
- Zero external dependencies — stdlib only, do not add imports outside the module
- Functional options: `type Option func(*options)` with `With<Name>` constructors
- Test with stdlib only: `testing`, `httptest`, `net/http`. No testify/gomock
- Client test helper: `mockTransport(fn)` wrapping `middleware.RoundTripperFunc`
- Server test helper: `httptest.NewRecorder`, `httptest.NewRequest`, `waitForAddr(t, srv)`
- Thread safety with `sync.Mutex`, `sync.Map`, or `atomic`
- Use `internal/clock` for time-dependent tests, not `time.Now()` directly
- Sentinel errors in sub-packages, re-exported as aliases in root package
## When writing new code
- Client middleware → file in `middleware/`, return `middleware.Middleware`
- Server middleware → file in `server/middleware_<name>.go`, return `server.Middleware`
- New option → add field to options struct, create `With<Name>` func, apply in constructor
- Do NOT import `server` from `middleware` or vice versa (use `internal/requestid` for shared context)
- Client.Close() must be called when using WithEndpoints() (stops health checker goroutine)
- Request bodies must have GetBody set for retry — use NewJSONRequest/NewFormRequest
## Commands
```bash
go build ./... # compile
go test ./... # test
go test -race ./... # test with race detector
go vet ./... # static analysis
```

23
.gitea/workflows/ci.yml Normal file
View File

@@ -0,0 +1,23 @@
name: CI
on:
push:
branches: [main]
pull_request:
branches: [main]
jobs:
test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version: "1.24"
- name: Vet
run: go vet ./...
- name: Test
run: go test -race -count=1 ./...

View File

@@ -0,0 +1,39 @@
name: Publish
on:
push:
tags: ["v*"]
jobs:
publish:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version: "1.24"
- name: Vet
run: go vet ./...
- name: Test
run: go test -race -count=1 ./...
- name: Publish to Gitea Package Registry
run: |
VERSION=${GITHUB_REF#refs/tags/}
MODULE=$(go list -m)
# Create module zip with required prefix: module@version/
git archive --format=zip --prefix="${MODULE}@${VERSION}/" HEAD -o module.zip
# Gitea Go Package Registry API
curl -s -f \
-X PUT \
-H "Authorization: token ${{ secrets.PUBLISH_TOKEN }}" \
-H "Content-Type: application/zip" \
--data-binary @module.zip \
"${{ github.server_url }}/api/packages/pkg/go/upload?module=${MODULE}&version=${VERSION}"
env:
PUBLISH_TOKEN: ${{ secrets.PUBLISH_TOKEN }}

50
.github/copilot-instructions.md vendored Normal file
View File

@@ -0,0 +1,50 @@
# Copilot instructions — httpx
## Project
`git.codelab.vc/pkg/httpx` is a Go 1.24 HTTP client and server library with zero external dependencies.
## Architecture
### Client (`httpx` root package)
- Middleware type: `func(http.RoundTripper) http.RoundTripper`
- Chain assembly (outermost → innermost): Logging → User MW → Retry → Circuit Breaker → Balancer → Transport
- Retry wraps CB+Balancer so each attempt can hit a different endpoint
- Circuit breaker is per-host (sync.Map of host → Breaker)
- Client.Close() required when using WithEndpoints() — stops health checker goroutine
### Server (`server/` package)
- Middleware type: `func(http.Handler) http.Handler`
- Router wraps http.ServeMux with groups, prefix routing, Mount for sub-handlers
- Defaults() preset: RequestID → Recovery → Logging + production timeouts
- Available middleware: RequestID, Recovery, Logging, CORS, RateLimit, MaxBodySize, Timeout
- WriteJSON/WriteError for JSON responses
### Sub-packages
- `middleware/` — client-side middleware (Logging, Recovery, Auth, Headers, RequestID)
- `retry/` — retry with exponential backoff and Retry-After
- `circuitbreaker/` — per-host circuit breaker
- `balancer/` — load balancing with health checking
- `internal/requestid/` — shared context key between server and middleware
- `internal/clock/` — deterministic time for tests
## Conventions
- All configuration uses functional options (`WithXxx` functions)
- Zero external dependencies — do not add requires to go.mod
- Tests use stdlib only (testing, httptest) — no testify or gomock
- Thread safety with sync.Mutex, sync.Map, or atomic
- Client test mock: `mockTransport(fn)` using `middleware.RoundTripperFunc`
- Server test helpers: `httptest.NewRecorder`, `httptest.NewRequest`
- Do NOT import server from middleware or vice versa — use internal/requestid for shared context
- Sentinel errors in sub-packages, re-exported in root package
- Use internal/clock for time-dependent tests
## Commands
```bash
go build ./... # compile
go test ./... # test
go test -race ./... # test with race detector
go vet ./... # static analysis
```

139
AGENTS.md Normal file
View File

@@ -0,0 +1,139 @@
# AGENTS.md — httpx
Universal guide for AI coding agents working with this codebase.
## Overview
`git.codelab.vc/pkg/httpx` is a Go HTTP toolkit with **zero external dependencies** (Go 1.24, stdlib only). It provides:
- A composable HTTP **client** with retry, circuit breaking, load balancing
- A production-ready HTTP **server** with routing, middleware, graceful shutdown
## Package map
```
httpx/ Root — Client, request builders, error types
├── middleware/ Client-side middleware (RoundTripper wrappers)
├── retry/ Retry middleware with backoff
├── circuitbreaker/ Per-host circuit breaker
├── balancer/ Client-side load balancing + health checking
├── server/ Server, Router, server-side middleware, response helpers
└── internal/
├── requestid/ Shared context key (avoids circular imports)
└── clock/ Deterministic time for testing
```
## Middleware chain architecture
### Client middleware: `func(http.RoundTripper) http.RoundTripper`
```
Request flow (outermost → innermost):
Logging
└→ User Middlewares
└→ Retry
└→ Circuit Breaker
└→ Balancer
└→ Base Transport (http.DefaultTransport)
```
Retry wraps CB+Balancer so each attempt can hit a different endpoint.
### Server middleware: `func(http.Handler) http.Handler`
```
Chain(A, B, C)(handler) == A(B(C(handler)))
A is outermost (sees request first, response last)
```
Defaults() preset: `RequestID → Recovery → Logging`
## Common tasks
### Add a client middleware
1. Create file in `middleware/` (or inline)
2. Return `middleware.Middleware` (`func(http.RoundTripper) http.RoundTripper`)
3. Use `middleware.RoundTripperFunc` for the inner adapter
4. Test with `middleware.RoundTripperFunc` as mock transport
```go
func MyMiddleware() middleware.Middleware {
return func(next http.RoundTripper) http.RoundTripper {
return middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
// before
resp, err := next.RoundTrip(req)
// after
return resp, err
})
}
}
```
### Add a server middleware
1. Create file in `server/` named `middleware_<name>.go`
2. Return `server.Middleware` (`func(http.Handler) http.Handler`)
3. Use `server.statusWriter` if you need to capture the response status
4. Test with `httptest.NewRecorder` + `httptest.NewRequest`
```go
func MyMiddleware() Middleware {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// before
next.ServeHTTP(w, r)
// after
})
}
}
```
### Add a route
```go
r := server.NewRouter()
r.HandleFunc("GET /users/{id}", getUser)
r.HandleFunc("POST /users", createUser)
// Group with prefix + middleware
api := r.Group("/api/v1", authMiddleware)
api.HandleFunc("GET /items", listItems)
// Mount sub-handler
r.Mount("/health", server.HealthHandler())
```
### Add a functional option
1. Add field to the options struct (`clientOptions` or `serverOptions`)
2. Create `With<Name>` function returning `Option`
3. Apply the field in the constructor (`New`)
## Gotchas
- **Middleware order matters**: Retry wraps CB+Balancer intentionally — each retry attempt can hit a different endpoint and a different circuit breaker
- **Circular imports via `internal/`**: Both `server` and `middleware` packages need request ID context. The shared key lives in `internal/requestid` — do NOT import `server` from `middleware` or vice versa
- **Client.Close() is required** when using `WithEndpoints()` — the balancer starts a background health checker goroutine that must be stopped
- **GetBody for retries**: Request bodies must be replayable. Use `NewJSONRequest`/`NewFormRequest` (they set `GetBody`) or set it manually
- **statusWriter.Unwrap()**: Server middleware must not type-assert `http.ResponseWriter` directly — use `http.ResponseController` which calls `Unwrap()` to find `http.Flusher`, `http.Hijacker`, etc.
- **No external deps**: This is a zero-dependency library. Do not add any `require` to `go.mod`
## Commands
```bash
go build ./... # compile
go test ./... # all tests
go test -race ./... # tests with race detector
go test -v -run TestName ./package/ # single test
go vet ./... # static analysis
```
## Conventions
- **Functional options** for all configuration (client and server)
- **stdlib only** testing — no testify, no gomock
- **Thread safety** — use `sync.Mutex`, `sync.Map`, or `atomic` where needed
- **`internal/clock`** — use for deterministic time in tests (never `time.Now()` directly in testable code)
- **Test helpers**: `mockTransport(fn)` wrapping `middleware.RoundTripperFunc` (client), `httptest.NewRecorder`/`httptest.NewRequest` (server), `waitForAddr(t, srv)` for server integration tests
- **Sentinel errors** live in sub-packages, root package re-exports as aliases

56
CLAUDE.md Normal file
View File

@@ -0,0 +1,56 @@
# CLAUDE.md — httpx
## Commands
```bash
go build ./... # compile
go test ./... # all tests
go test -race ./... # tests with race detector
go test -v -run TestName ./package/ # single test
go vet ./... # static analysis
```
## Architecture
- **Module**: `git.codelab.vc/pkg/httpx`, Go 1.24, zero external dependencies
### Client
- **Core pattern**: middleware is `func(http.RoundTripper) http.RoundTripper`
- **Chain assembly order** (client.go): Logging → User MW → Retry → CB → Balancer → Transport
- Retry wraps CB+Balancer so each attempt can hit a different endpoint
- **Circuit breaker** is per-host (`sync.Map` of host → Breaker)
- **Sentinel errors**: canonical values live in sub-packages, root package re-exports as aliases
- **balancer.Transport** returns `(Middleware, *Closer)` — Closer must be tracked for health checker shutdown
- **Client.Close()** stops the health checker goroutine
- **Client.Patch()** — PATCH method, same pattern as Put/Post
- **NewFormRequest** — form-encoded request builder (`application/x-www-form-urlencoded`) with `GetBody` for retry
- **WithMaxResponseBody** — wraps `resp.Body` with `io.LimitedReader` to prevent OOM
- **middleware.RequestID()** — propagates request ID from context to outgoing `X-Request-Id` header
- **`internal/requestid`** — shared context key used by both `server` and `middleware` packages to avoid circular imports
### Server (`server/`)
- **Core pattern**: middleware is `func(http.Handler) http.Handler`
- **Server** wraps `http.Server` with `net.Listener`, graceful shutdown via signal handling, lifecycle hooks
- **Router** wraps `http.ServeMux` — supports groups with prefix + middleware inheritance, `Mount` for sub-handlers, `WithNotFoundHandler` for custom 404
- **Middleware chain** via `Chain(A, B, C)` — A outermost, C innermost (same as client side)
- **statusWriter** wraps `http.ResponseWriter` to capture status; implements `Unwrap()` for `http.ResponseController`
- **Defaults()** preset: RequestID → Recovery → Logging + production timeouts
- **HealthHandler** exposes `GET /healthz` (liveness) and `GET /readyz` (readiness with pluggable checkers)
- **CORS** middleware — preflight OPTIONS handling, `AllowOrigins`, `AllowMethods`, `AllowHeaders`, `ExposeHeaders`, `AllowCredentials`, `MaxAge`
- **RateLimit** middleware — per-key token bucket (`sync.Map`), IP from `X-Forwarded-For`, `WithRate`/`WithBurst`/`WithKeyFunc`, uses `internal/clock`
- **MaxBodySize** middleware — wraps `r.Body` via `http.MaxBytesReader`
- **Timeout** middleware — wraps `http.TimeoutHandler`, returns 503
- **WriteJSON** / **WriteError** — JSON response helpers in `server/respond.go`
## Conventions
- Functional options for all configuration (client and server)
- Test helpers: `mockTransport(fn)` wrapping `middleware.RoundTripperFunc` (client), `httptest.NewRecorder`/`httptest.NewRequest` (server)
- Server tests use `waitForAddr(t, srv)` helper to poll until server is ready
- No external test frameworks — stdlib only
- Thread safety required (`sync.Mutex`/`atomic`)
- `internal/clock` for deterministic time testing
## See also
- `AGENTS.md` — universal AI agent guide with common tasks, gotchas, and ASCII diagrams

211
README.md
View File

@@ -1,2 +1,213 @@
# httpx # httpx
HTTP client and server toolkit for Go microservices. Client side: retry, load balancing, circuit breaking, request ID propagation, response size limits — all as `http.RoundTripper` middleware. Server side: routing, middleware (request ID, recovery, logging, CORS, rate limiting, body limits, timeouts), health checks, JSON helpers, graceful shutdown. stdlib only, zero external deps.
```
go get git.codelab.vc/pkg/httpx
```
## Quick start
```go
client := httpx.New(
httpx.WithBaseURL("https://api.example.com"),
httpx.WithTimeout(10*time.Second),
httpx.WithRetry(retry.WithMaxAttempts(3)),
httpx.WithMiddleware(
middleware.UserAgent("my-service/1.0"),
middleware.BearerAuth(func(ctx context.Context) (string, error) {
return os.Getenv("API_TOKEN"), nil
}),
),
)
defer client.Close()
resp, err := client.Get(ctx, "/users/123")
if err != nil {
log.Fatal(err)
}
var user User
resp.JSON(&user)
// PATCH request
resp, err = client.Patch(ctx, "/users/123", strings.NewReader(`{"name":"updated"}`))
// Form-encoded request (OAuth, webhooks, etc.)
req, _ := httpx.NewFormRequest(ctx, http.MethodPost, "/oauth/token", url.Values{
"grant_type": {"client_credentials"},
"scope": {"read write"},
})
resp, err = client.Do(ctx, req)
```
## Packages
### Client
Client middleware is `func(http.RoundTripper) http.RoundTripper`. Use them with `httpx.Client` or plug into a plain `http.Client`.
| Package | What it does |
|---------|-------------|
| `retry` | Exponential/constant backoff, Retry-After support. Idempotent methods only by default. |
| `balancer` | Round robin, failover, weighted random. Optional background health checks. |
| `circuitbreaker` | Per-host state machine (closed/open/half-open). Stops hammering dead endpoints. |
| `middleware` | Logging (slog), default headers, bearer/basic auth, panic recovery, request ID propagation. |
### Server
Server middleware is `func(http.Handler) http.Handler`. The `server` package provides a production-ready HTTP server.
| Component | What it does |
|-----------|-------------|
| `server.Server` | Wraps `http.Server` with graceful shutdown, signal handling, lifecycle logging. |
| `server.Router` | Lightweight wrapper around `http.ServeMux` with groups, prefix routing, sub-router mounting. |
| `server.RequestID` | Assigns/propagates `X-Request-Id` (UUID v4 via `crypto/rand`). |
| `server.Recovery` | Recovers panics, returns 500, logs stack trace. |
| `server.Logging` | Structured request logging (method, path, status, duration, request ID). |
| `server.HealthHandler` | Liveness (`/healthz`) and readiness (`/readyz`) endpoints with pluggable checkers. |
| `server.CORS` | Cross-origin resource sharing with preflight handling and functional options. |
| `server.RateLimit` | Per-key token bucket rate limiting with IP extraction and `Retry-After`. |
| `server.MaxBodySize` | Limits request body size via `http.MaxBytesReader`. |
| `server.Timeout` | Context-based request timeout, returns 503 on expiry. |
| `server.WriteJSON` | JSON response helper, sets Content-Type and status. |
| `server.WriteError` | JSON error response (`{"error": "..."}`) helper. |
| `server.Defaults` | Production preset: RequestID → Recovery → Logging + sensible timeouts. |
The client assembles them in this order:
```
Request → Logging → Your Middleware → Retry → Circuit Breaker → Balancer → Transport
```
Retry wraps the circuit breaker and balancer, so each attempt can pick a different endpoint.
## Multi-DC setup
```go
client := httpx.New(
httpx.WithEndpoints(
balancer.Endpoint{URL: "https://dc1.api.internal", Weight: 3},
balancer.Endpoint{URL: "https://dc2.api.internal", Weight: 1},
),
httpx.WithBalancer(balancer.WithStrategy(balancer.WeightedRandom())),
httpx.WithRetry(retry.WithMaxAttempts(4)),
httpx.WithCircuitBreaker(circuitbreaker.WithFailureThreshold(5)),
httpx.WithLogger(slog.Default()),
)
defer client.Close()
```
## Standalone usage
Each component works with any `http.Client`, no need for the full wrapper:
```go
// Just retry, nothing else
transport := retry.Transport(retry.WithMaxAttempts(3))
httpClient := &http.Client{
Transport: transport(http.DefaultTransport),
}
```
```go
// Chain a few middlewares together
chain := middleware.Chain(
middleware.Logging(slog.Default()),
middleware.UserAgent("my-service/1.0"),
retry.Transport(retry.WithMaxAttempts(2)),
)
httpClient := &http.Client{
Transport: chain(http.DefaultTransport),
}
```
## Server
```go
logger := slog.Default()
r := server.NewRouter(
// Custom JSON 404 instead of plain text
server.WithNotFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
server.WriteError(w, 404, "not found")
})),
)
r.HandleFunc("GET /hello", func(w http.ResponseWriter, r *http.Request) {
server.WriteJSON(w, 200, map[string]string{"message": "world"})
})
// Groups with middleware
api := r.Group("/api/v1", authMiddleware)
api.HandleFunc("GET /users/{id}", getUser)
// Health checks
r.Mount("/", server.HealthHandler(
func() error { return db.Ping() },
))
srv := server.New(r,
append(server.Defaults(logger),
// Protection middleware
server.WithMiddleware(
server.CORS(
server.AllowOrigins("https://app.example.com"),
server.AllowMethods("GET", "POST", "PUT", "PATCH", "DELETE"),
server.AllowHeaders("Authorization", "Content-Type"),
server.MaxAge(3600),
),
server.RateLimit(
server.WithRate(100),
server.WithBurst(200),
),
server.MaxBodySize(1<<20), // 1 MB
server.Timeout(30*time.Second),
),
)...,
)
log.Fatal(srv.ListenAndServe()) // graceful shutdown on SIGINT/SIGTERM
```
## Client request ID propagation
In microservices, forward the incoming request ID to downstream calls:
```go
client := httpx.New(
httpx.WithMiddleware(middleware.RequestID()),
)
// In a server handler — the context already has the request ID from server.RequestID():
func handler(w http.ResponseWriter, r *http.Request) {
// ID is automatically forwarded as X-Request-Id
resp, err := client.Get(r.Context(), "https://downstream/api")
}
```
## Response body limit
Protect against OOM from unexpectedly large upstream responses:
```go
client := httpx.New(
httpx.WithMaxResponseBody(10 << 20), // 10 MB max
)
```
## Examples
See the [`examples/`](examples/) directory for runnable programs:
| Example | Description |
|---------|-------------|
| [`basic-client`](examples/basic-client/) | HTTP client with retry, timeout, logging, and response size limit |
| [`form-request`](examples/form-request/) | Form-encoded POST requests (OAuth, webhooks) |
| [`load-balancing`](examples/load-balancing/) | Multi-endpoint client with weighted balancing, circuit breaker, and health checks |
| [`server-basic`](examples/server-basic/) | Server with routing, groups, JSON helpers, health checks, and custom 404 |
| [`server-protected`](examples/server-protected/) | Production server with CORS, rate limiting, body limits, and timeouts |
| [`request-id-propagation`](examples/request-id-propagation/) | Request ID forwarding between server and client for distributed tracing |
## Requirements
Go 1.24+, stdlib only.

104
balancer/balancer.go Normal file
View File

@@ -0,0 +1,104 @@
package balancer
import (
"errors"
"fmt"
"net/http"
"net/url"
"git.codelab.vc/pkg/httpx/middleware"
)
// ErrNoHealthy is returned when no healthy endpoints are available.
var ErrNoHealthy = errors.New("httpx: no healthy endpoints available")
// Endpoint represents a backend server that can handle requests.
type Endpoint struct {
URL string
Weight int
Meta map[string]string
}
// Strategy selects an endpoint from the list of healthy endpoints.
type Strategy interface {
Next(healthy []Endpoint) (Endpoint, error)
}
// Closer can be used to shut down resources associated with a balancer
// transport (e.g. background health checker goroutines).
type Closer struct {
healthChecker *HealthChecker
}
// Close stops background goroutines. Safe to call multiple times.
func (c *Closer) Close() {
if c.healthChecker != nil {
c.healthChecker.Stop()
}
}
// Transport returns a middleware that load-balances requests across the
// provided endpoints using the configured strategy.
//
// For each request the middleware picks an endpoint via the strategy,
// replaces the request URL scheme and host with the endpoint's URL,
// and forwards the request to the underlying RoundTripper.
//
// If active health checking is enabled (WithHealthCheck), a background
// goroutine periodically probes endpoints. Otherwise all endpoints are
// assumed healthy.
func Transport(endpoints []Endpoint, opts ...Option) (middleware.Middleware, *Closer) {
o := &options{
strategy: RoundRobin(),
}
for _, opt := range opts {
opt(o)
}
// Pre-parse endpoint URLs once at construction time.
parsed := make(map[string]*url.URL, len(endpoints))
for _, ep := range endpoints {
u, err := url.Parse(ep.URL)
if err != nil {
panic(fmt.Sprintf("balancer: invalid endpoint URL %q: %v", ep.URL, err))
}
parsed[ep.URL] = u
}
if o.healthChecker != nil {
o.healthChecker.Start(endpoints)
}
closer := &Closer{healthChecker: o.healthChecker}
return func(next http.RoundTripper) http.RoundTripper {
return middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
healthy := endpoints
if o.healthChecker != nil {
healthy = o.healthChecker.Healthy(endpoints)
}
if len(healthy) == 0 {
return nil, ErrNoHealthy
}
ep, err := o.strategy.Next(healthy)
if err != nil {
return nil, err
}
epURL := parsed[ep.URL]
// Shallow-copy request and URL to avoid mutating the original,
// without the expense of req.Clone's deep header copy.
r := *req
u := *req.URL
r.URL = &u
r.URL.Scheme = epURL.Scheme
r.URL.Host = epURL.Host
r.Host = epURL.Host
return next.RoundTrip(&r)
})
}, closer
}

214
balancer/balancer_test.go Normal file
View File

@@ -0,0 +1,214 @@
package balancer
import (
"io"
"math"
"net/http"
"strings"
"testing"
"git.codelab.vc/pkg/httpx/middleware"
)
func mockTransport(fn func(*http.Request) (*http.Response, error)) http.RoundTripper {
return middleware.RoundTripperFunc(fn)
}
func okResponse() *http.Response {
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader("")),
Header: make(http.Header),
}
}
func TestTransport_PicksEndpointAndReplacesURL(t *testing.T) {
endpoints := []Endpoint{
{URL: "https://backend1.example.com"},
}
var captured *http.Request
base := mockTransport(func(req *http.Request) (*http.Response, error) {
captured = req
return okResponse(), nil
})
mw, _ := Transport(endpoints)
rt := mw(base)
req, err := http.NewRequest(http.MethodGet, "https://original.example.com/api/v1/users", nil)
if err != nil {
t.Fatal(err)
}
resp, err := rt.RoundTrip(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer resp.Body.Close()
if captured == nil {
t.Fatal("base transport was not called")
}
if captured.URL.Scheme != "https" {
t.Errorf("scheme = %q, want %q", captured.URL.Scheme, "https")
}
if captured.URL.Host != "backend1.example.com" {
t.Errorf("host = %q, want %q", captured.URL.Host, "backend1.example.com")
}
if captured.URL.Path != "/api/v1/users" {
t.Errorf("path = %q, want %q", captured.URL.Path, "/api/v1/users")
}
}
func TestTransport_ErrNoHealthyWhenNoEndpoints(t *testing.T) {
var endpoints []Endpoint
base := mockTransport(func(req *http.Request) (*http.Response, error) {
t.Fatal("base transport should not be called")
return nil, nil
})
mw, _ := Transport(endpoints)
rt := mw(base)
req, err := http.NewRequest(http.MethodGet, "https://example.com/test", nil)
if err != nil {
t.Fatal(err)
}
_, err = rt.RoundTrip(req)
if err != ErrNoHealthy {
t.Fatalf("err = %v, want %v", err, ErrNoHealthy)
}
}
func TestRoundRobin_DistributesEvenly(t *testing.T) {
endpoints := []Endpoint{
{URL: "https://a.example.com"},
{URL: "https://b.example.com"},
{URL: "https://c.example.com"},
}
rr := RoundRobin()
counts := make(map[string]int)
const iterations = 300
for i := 0; i < iterations; i++ {
ep, err := rr.Next(endpoints)
if err != nil {
t.Fatalf("iteration %d: unexpected error: %v", i, err)
}
counts[ep.URL]++
}
expected := iterations / len(endpoints)
for _, ep := range endpoints {
got := counts[ep.URL]
if got != expected {
t.Errorf("endpoint %s: got %d calls, want %d", ep.URL, got, expected)
}
}
}
func TestRoundRobin_ErrNoHealthy(t *testing.T) {
rr := RoundRobin()
_, err := rr.Next(nil)
if err != ErrNoHealthy {
t.Fatalf("err = %v, want %v", err, ErrNoHealthy)
}
}
func TestFailover_AlwaysPicksFirst(t *testing.T) {
endpoints := []Endpoint{
{URL: "https://primary.example.com"},
{URL: "https://secondary.example.com"},
{URL: "https://tertiary.example.com"},
}
fo := Failover()
for i := 0; i < 10; i++ {
ep, err := fo.Next(endpoints)
if err != nil {
t.Fatalf("iteration %d: unexpected error: %v", i, err)
}
if ep.URL != "https://primary.example.com" {
t.Errorf("iteration %d: got %q, want %q", i, ep.URL, "https://primary.example.com")
}
}
}
func TestFailover_ErrNoHealthy(t *testing.T) {
fo := Failover()
_, err := fo.Next(nil)
if err != ErrNoHealthy {
t.Fatalf("err = %v, want %v", err, ErrNoHealthy)
}
}
func TestWeightedRandom_RespectsWeights(t *testing.T) {
endpoints := []Endpoint{
{URL: "https://heavy.example.com", Weight: 80},
{URL: "https://light.example.com", Weight: 20},
}
wr := WeightedRandom()
counts := make(map[string]int)
const iterations = 10000
for i := 0; i < iterations; i++ {
ep, err := wr.Next(endpoints)
if err != nil {
t.Fatalf("iteration %d: unexpected error: %v", i, err)
}
counts[ep.URL]++
}
totalWeight := 0
for _, ep := range endpoints {
totalWeight += ep.Weight
}
for _, ep := range endpoints {
got := float64(counts[ep.URL]) / float64(iterations)
want := float64(ep.Weight) / float64(totalWeight)
if math.Abs(got-want) > 0.05 {
t.Errorf("endpoint %s: got ratio %.3f, want ~%.3f (tolerance 0.05)", ep.URL, got, want)
}
}
}
func TestWeightedRandom_DefaultWeightForZero(t *testing.T) {
endpoints := []Endpoint{
{URL: "https://a.example.com", Weight: 0},
{URL: "https://b.example.com", Weight: 0},
}
wr := WeightedRandom()
counts := make(map[string]int)
const iterations = 1000
for i := 0; i < iterations; i++ {
ep, err := wr.Next(endpoints)
if err != nil {
t.Fatalf("iteration %d: unexpected error: %v", i, err)
}
counts[ep.URL]++
}
// With equal default weights, distribution should be roughly even.
for _, ep := range endpoints {
got := float64(counts[ep.URL]) / float64(iterations)
if math.Abs(got-0.5) > 0.1 {
t.Errorf("endpoint %s: got ratio %.3f, want ~0.5 (tolerance 0.1)", ep.URL, got)
}
}
}
func TestWeightedRandom_ErrNoHealthy(t *testing.T) {
wr := WeightedRandom()
_, err := wr.Next(nil)
if err != ErrNoHealthy {
t.Fatalf("err = %v, want %v", err, ErrNoHealthy)
}
}

34
balancer/doc.go Normal file
View File

@@ -0,0 +1,34 @@
// Package balancer provides client-side load balancing as HTTP middleware.
//
// It distributes requests across multiple backend endpoints using pluggable
// strategies (round-robin, weighted, failover) with optional health checking.
//
// # Usage
//
// mw, closer := balancer.Transport(
// []balancer.Endpoint{
// {URL: "http://backend1:8080"},
// {URL: "http://backend2:8080"},
// },
// balancer.WithStrategy(balancer.RoundRobin()),
// balancer.WithHealthCheck(5 * time.Second),
// )
// defer closer.Close()
// transport := mw(http.DefaultTransport)
//
// # Strategies
//
// - RoundRobin — cycles through healthy endpoints
// - Weighted — distributes based on endpoint Weight field
// - Failover — prefers primary, falls back to secondaries
//
// # Health checking
//
// When enabled, a background goroutine periodically probes each endpoint.
// The returned Closer must be closed to stop the health checker goroutine.
// In httpx.Client, this is handled by Client.Close().
//
// # Sentinel errors
//
// ErrNoHealthy is returned when no healthy endpoints are available.
package balancer

17
balancer/failover.go Normal file
View File

@@ -0,0 +1,17 @@
package balancer
type failover struct{}
// Failover returns a strategy that always picks the first healthy endpoint.
// If the primary endpoint is unhealthy, it falls back to the next available
// healthy endpoint in order.
func Failover() Strategy {
return &failover{}
}
func (f *failover) Next(healthy []Endpoint) (Endpoint, error) {
if len(healthy) == 0 {
return Endpoint{}, ErrNoHealthy
}
return healthy[0], nil
}

174
balancer/health.go Normal file
View File

@@ -0,0 +1,174 @@
package balancer
import (
"context"
"io"
"net/http"
"sync"
"time"
)
const (
defaultHealthInterval = 10 * time.Second
defaultHealthPath = "/health"
defaultHealthTimeout = 5 * time.Second
)
// HealthOption configures the HealthChecker.
type HealthOption func(*HealthChecker)
// WithHealthInterval sets the interval between health check probes.
// Default is 10 seconds.
func WithHealthInterval(d time.Duration) HealthOption {
return func(h *HealthChecker) {
h.interval = d
}
}
// WithHealthPath sets the HTTP path to probe for health checks.
// Default is "/health".
func WithHealthPath(path string) HealthOption {
return func(h *HealthChecker) {
h.path = path
}
}
// WithHealthTimeout sets the timeout for each health check request.
// Default is 5 seconds.
func WithHealthTimeout(d time.Duration) HealthOption {
return func(h *HealthChecker) {
h.timeout = d
}
}
// HealthChecker periodically probes endpoints to determine their health status.
type HealthChecker struct {
interval time.Duration
path string
timeout time.Duration
client *http.Client
mu sync.RWMutex
status map[string]bool
cancel context.CancelFunc
stopped chan struct{}
}
func newHealthChecker(opts ...HealthOption) *HealthChecker {
h := &HealthChecker{
interval: defaultHealthInterval,
path: defaultHealthPath,
timeout: defaultHealthTimeout,
status: make(map[string]bool),
}
for _, opt := range opts {
opt(h)
}
h.client = &http.Client{
Timeout: h.timeout,
}
return h
}
// Start begins the background health checking loop for the given endpoints.
// An initial probe is run synchronously so that unhealthy endpoints are
// detected before the first request.
func (h *HealthChecker) Start(endpoints []Endpoint) {
// Mark all healthy as a safe default, then immediately probe.
h.mu.Lock()
for _, ep := range endpoints {
h.status[ep.URL] = true
}
h.mu.Unlock()
ctx, cancel := context.WithCancel(context.Background())
h.cancel = cancel
h.stopped = make(chan struct{})
// Run initial probe synchronously so callers don't hit stale state.
h.probe(ctx, endpoints)
go h.loop(ctx, endpoints)
}
// Stop terminates the background health checking goroutine and waits for
// it to finish.
func (h *HealthChecker) Stop() {
if h.cancel != nil {
h.cancel()
<-h.stopped
}
}
// IsHealthy reports whether the given endpoint is currently healthy.
func (h *HealthChecker) IsHealthy(ep Endpoint) bool {
h.mu.RLock()
defer h.mu.RUnlock()
healthy, ok := h.status[ep.URL]
if !ok {
return false
}
return healthy
}
// Healthy returns the subset of endpoints that are currently healthy.
func (h *HealthChecker) Healthy(endpoints []Endpoint) []Endpoint {
h.mu.RLock()
defer h.mu.RUnlock()
result := make([]Endpoint, 0, len(endpoints))
for _, ep := range endpoints {
if h.status[ep.URL] {
result = append(result, ep)
}
}
return result
}
func (h *HealthChecker) loop(ctx context.Context, endpoints []Endpoint) {
defer close(h.stopped)
ticker := time.NewTicker(h.interval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
h.probe(ctx, endpoints)
}
}
}
func (h *HealthChecker) probe(ctx context.Context, endpoints []Endpoint) {
var wg sync.WaitGroup
wg.Add(len(endpoints))
for _, ep := range endpoints {
go func() {
defer wg.Done()
healthy := h.check(ctx, ep)
h.mu.Lock()
h.status[ep.URL] = healthy
h.mu.Unlock()
}()
}
wg.Wait()
}
func (h *HealthChecker) check(ctx context.Context, ep Endpoint) bool {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, ep.URL+h.path, nil)
if err != nil {
return false
}
resp, err := h.client.Do(req)
if err != nil {
return false
}
io.Copy(io.Discard, resp.Body)
resp.Body.Close()
return resp.StatusCode >= 200 && resp.StatusCode < 300
}

25
balancer/options.go Normal file
View File

@@ -0,0 +1,25 @@
package balancer
// options holds configuration for the load balancer transport.
type options struct {
strategy Strategy // default RoundRobin
healthChecker *HealthChecker // optional
}
// Option configures the load balancer transport.
type Option func(*options)
// WithStrategy sets the endpoint selection strategy.
// If not specified, RoundRobin is used.
func WithStrategy(s Strategy) Option {
return func(o *options) {
o.strategy = s
}
}
// WithHealthCheck enables active health checking of endpoints.
func WithHealthCheck(opts ...HealthOption) Option {
return func(o *options) {
o.healthChecker = newHealthChecker(opts...)
}
}

21
balancer/roundrobin.go Normal file
View File

@@ -0,0 +1,21 @@
package balancer
import "sync/atomic"
type roundRobin struct {
counter atomic.Uint64
}
// RoundRobin returns a strategy that cycles through healthy endpoints
// sequentially using an atomic counter.
func RoundRobin() Strategy {
return &roundRobin{}
}
func (r *roundRobin) Next(healthy []Endpoint) (Endpoint, error) {
if len(healthy) == 0 {
return Endpoint{}, ErrNoHealthy
}
idx := r.counter.Add(1) - 1
return healthy[idx%uint64(len(healthy))], nil
}

42
balancer/weighted.go Normal file
View File

@@ -0,0 +1,42 @@
package balancer
import "math/rand/v2"
type weightedRandom struct{}
// WeightedRandom returns a strategy that selects endpoints randomly,
// weighted by each endpoint's Weight field. Endpoints with Weight <= 0
// are treated as having a weight of 1.
func WeightedRandom() Strategy {
return &weightedRandom{}
}
func (w *weightedRandom) Next(healthy []Endpoint) (Endpoint, error) {
if len(healthy) == 0 {
return Endpoint{}, ErrNoHealthy
}
totalWeight := 0
for _, ep := range healthy {
weight := ep.Weight
if weight <= 0 {
weight = 1
}
totalWeight += weight
}
r := rand.IntN(totalWeight)
for _, ep := range healthy {
weight := ep.Weight
if weight <= 0 {
weight = 1
}
r -= weight
if r < 0 {
return ep, nil
}
}
// Should never reach here, but return last endpoint as a safeguard.
return healthy[len(healthy)-1], nil
}

176
circuitbreaker/breaker.go Normal file
View File

@@ -0,0 +1,176 @@
package circuitbreaker
import (
"errors"
"net/http"
"sync"
"time"
"git.codelab.vc/pkg/httpx/middleware"
)
// ErrCircuitOpen is returned by Allow when the breaker is in the Open state.
var ErrCircuitOpen = errors.New("httpx: circuit breaker is open")
// State represents the current state of a circuit breaker.
type State int
const (
StateClosed State = iota // normal operation
StateOpen // failing, reject requests
StateHalfOpen // testing recovery
)
// String returns a human-readable name for the state.
func (s State) String() string {
switch s {
case StateClosed:
return "closed"
case StateOpen:
return "open"
case StateHalfOpen:
return "half-open"
default:
return "unknown"
}
}
// Breaker implements a per-endpoint circuit breaker state machine.
//
// State transitions:
//
// Closed → Open: after failureThreshold consecutive failures
// Open → HalfOpen: after openDuration passes
// HalfOpen → Closed: on success
// HalfOpen → Open: on failure (timer resets)
type Breaker struct {
mu sync.Mutex
opts options
state State
failures int // consecutive failure count (Closed state)
openedAt time.Time
halfOpenCur int // current in-flight half-open requests
}
// NewBreaker creates a Breaker with the given options.
func NewBreaker(opts ...Option) *Breaker {
o := defaults()
for _, fn := range opts {
fn(&o)
}
return &Breaker{opts: o}
}
// State returns the current state of the breaker.
func (b *Breaker) State() State {
b.mu.Lock()
defer b.mu.Unlock()
return b.stateLocked()
}
// stateLocked returns the effective state, promoting Open → HalfOpen when the
// open duration has elapsed. Caller must hold b.mu.
func (b *Breaker) stateLocked() State {
if b.state == StateOpen && time.Since(b.openedAt) >= b.opts.openDuration {
b.state = StateHalfOpen
b.halfOpenCur = 0
}
return b.state
}
// Allow checks whether a request is permitted. If allowed it returns a done
// callback that the caller MUST invoke with the result of the request. If the
// breaker is open, it returns ErrCircuitOpen.
func (b *Breaker) Allow() (done func(success bool), err error) {
b.mu.Lock()
defer b.mu.Unlock()
switch b.stateLocked() {
case StateClosed:
// always allow
case StateOpen:
return nil, ErrCircuitOpen
case StateHalfOpen:
if b.halfOpenCur >= b.opts.halfOpenMax {
return nil, ErrCircuitOpen
}
b.halfOpenCur++
}
return b.doneFunc(), nil
}
// doneFunc returns the callback for a single in-flight request. Caller must
// hold b.mu when calling doneFunc, but the returned function acquires the lock
// itself.
func (b *Breaker) doneFunc() func(success bool) {
var once sync.Once
return func(success bool) {
once.Do(func() {
b.mu.Lock()
defer b.mu.Unlock()
b.record(success)
})
}
}
// record processes the outcome of a single request. Caller must hold b.mu.
func (b *Breaker) record(success bool) {
switch b.state {
case StateClosed:
if success {
b.failures = 0
return
}
b.failures++
if b.failures >= b.opts.failureThreshold {
b.tripLocked()
}
case StateHalfOpen:
b.halfOpenCur--
if success {
b.state = StateClosed
b.failures = 0
} else {
b.tripLocked()
}
}
}
// tripLocked transitions to the Open state and records the timestamp.
func (b *Breaker) tripLocked() {
b.state = StateOpen
b.openedAt = time.Now()
b.halfOpenCur = 0
}
// Transport returns a middleware that applies per-host circuit breaking. It
// maintains an internal map of host → *Breaker so each target host is tracked
// independently.
func Transport(opts ...Option) middleware.Middleware {
var hosts sync.Map // map[string]*Breaker
return func(next http.RoundTripper) http.RoundTripper {
return middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
host := req.URL.Host
val, ok := hosts.Load(host)
if !ok {
val, _ = hosts.LoadOrStore(host, NewBreaker(opts...))
}
cb := val.(*Breaker)
done, err := cb.Allow()
if err != nil {
return nil, err
}
resp, rtErr := next.RoundTrip(req)
done(rtErr == nil && resp != nil && resp.StatusCode < 500)
return resp, rtErr
})
}
}

View File

@@ -0,0 +1,249 @@
package circuitbreaker
import (
"errors"
"io"
"net/http"
"strings"
"testing"
"time"
"git.codelab.vc/pkg/httpx/middleware"
)
func mockTransport(fn func(*http.Request) (*http.Response, error)) http.RoundTripper {
return middleware.RoundTripperFunc(fn)
}
func okResponse() *http.Response {
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader("")),
Header: make(http.Header),
}
}
func errResponse(code int) *http.Response {
return &http.Response{
StatusCode: code,
Body: io.NopCloser(strings.NewReader("")),
Header: make(http.Header),
}
}
func TestBreaker_StartsInClosedState(t *testing.T) {
b := NewBreaker()
if s := b.State(); s != StateClosed {
t.Fatalf("state = %v, want %v", s, StateClosed)
}
}
func TestBreaker_TransitionsToOpenAfterThreshold(t *testing.T) {
const threshold = 3
b := NewBreaker(
WithFailureThreshold(threshold),
WithOpenDuration(time.Hour), // long duration so it stays open
)
for i := 0; i < threshold; i++ {
done, err := b.Allow()
if err != nil {
t.Fatalf("iteration %d: Allow returned error: %v", i, err)
}
done(false)
}
if s := b.State(); s != StateOpen {
t.Fatalf("state = %v, want %v", s, StateOpen)
}
}
func TestBreaker_OpenRejectsRequests(t *testing.T) {
b := NewBreaker(
WithFailureThreshold(1),
WithOpenDuration(time.Hour),
)
// Trip the breaker.
done, err := b.Allow()
if err != nil {
t.Fatalf("Allow returned error: %v", err)
}
done(false)
// Subsequent requests should be rejected.
_, err = b.Allow()
if !errors.Is(err, ErrCircuitOpen) {
t.Fatalf("err = %v, want %v", err, ErrCircuitOpen)
}
}
func TestBreaker_TransitionsToHalfOpenAfterDuration(t *testing.T) {
const openDuration = 50 * time.Millisecond
b := NewBreaker(
WithFailureThreshold(1),
WithOpenDuration(openDuration),
)
// Trip the breaker.
done, err := b.Allow()
if err != nil {
t.Fatal(err)
}
done(false)
if s := b.State(); s != StateOpen {
t.Fatalf("state = %v, want %v", s, StateOpen)
}
// Wait for the open duration to elapse.
time.Sleep(openDuration + 10*time.Millisecond)
if s := b.State(); s != StateHalfOpen {
t.Fatalf("state = %v, want %v", s, StateHalfOpen)
}
}
func TestBreaker_HalfOpenToClosedOnSuccess(t *testing.T) {
const openDuration = 50 * time.Millisecond
b := NewBreaker(
WithFailureThreshold(1),
WithOpenDuration(openDuration),
)
// Trip the breaker.
done, err := b.Allow()
if err != nil {
t.Fatal(err)
}
done(false)
// Wait for half-open.
time.Sleep(openDuration + 10*time.Millisecond)
// A successful request in half-open should close the breaker.
done, err = b.Allow()
if err != nil {
t.Fatalf("Allow in half-open returned error: %v", err)
}
done(true)
if s := b.State(); s != StateClosed {
t.Fatalf("state = %v, want %v", s, StateClosed)
}
}
func TestBreaker_HalfOpenToOpenOnFailure(t *testing.T) {
const openDuration = 50 * time.Millisecond
b := NewBreaker(
WithFailureThreshold(1),
WithOpenDuration(openDuration),
)
// Trip the breaker.
done, err := b.Allow()
if err != nil {
t.Fatal(err)
}
done(false)
// Wait for half-open.
time.Sleep(openDuration + 10*time.Millisecond)
// A failed request in half-open should re-open the breaker.
done, err = b.Allow()
if err != nil {
t.Fatalf("Allow in half-open returned error: %v", err)
}
done(false)
if s := b.State(); s != StateOpen {
t.Fatalf("state = %v, want %v", s, StateOpen)
}
}
func TestTransport_PerHostBreakers(t *testing.T) {
const threshold = 2
base := mockTransport(func(req *http.Request) (*http.Response, error) {
if req.URL.Host == "failing.example.com" {
return errResponse(http.StatusInternalServerError), nil
}
return okResponse(), nil
})
rt := Transport(
WithFailureThreshold(threshold),
WithOpenDuration(time.Hour),
)(base)
t.Run("failing host trips breaker", func(t *testing.T) {
for i := 0; i < threshold; i++ {
req, err := http.NewRequest(http.MethodGet, "https://failing.example.com/test", nil)
if err != nil {
t.Fatal(err)
}
resp, err := rt.RoundTrip(req)
if err != nil {
t.Fatalf("iteration %d: unexpected error: %v", i, err)
}
resp.Body.Close()
}
// Next request to failing host should be rejected.
req, err := http.NewRequest(http.MethodGet, "https://failing.example.com/test", nil)
if err != nil {
t.Fatal(err)
}
_, err = rt.RoundTrip(req)
if !errors.Is(err, ErrCircuitOpen) {
t.Fatalf("err = %v, want %v", err, ErrCircuitOpen)
}
})
t.Run("healthy host is unaffected", func(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, "https://healthy.example.com/test", nil)
if err != nil {
t.Fatal(err)
}
resp, err := rt.RoundTrip(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK)
}
})
}
func TestTransport_SuccessResetsFailures(t *testing.T) {
callCount := 0
base := mockTransport(func(req *http.Request) (*http.Response, error) {
callCount++
// Fail on odd calls, succeed on even.
if callCount%2 == 1 {
return errResponse(http.StatusInternalServerError), nil
}
return okResponse(), nil
})
rt := Transport(
WithFailureThreshold(3),
WithOpenDuration(time.Hour),
)(base)
// Alternate fail/success — should never trip because successes reset the
// consecutive failure counter.
for i := 0; i < 10; i++ {
req, err := http.NewRequest(http.MethodGet, "https://host.example.com/test", nil)
if err != nil {
t.Fatal(err)
}
resp, err := rt.RoundTrip(req)
if err != nil {
t.Fatalf("iteration %d: unexpected error (circuit should not be open): %v", i, err)
}
resp.Body.Close()
}
}

27
circuitbreaker/doc.go Normal file
View File

@@ -0,0 +1,27 @@
// Package circuitbreaker provides a per-host circuit breaker as HTTP middleware.
//
// The circuit breaker monitors request failures and temporarily blocks requests
// to unhealthy hosts, allowing them time to recover before retrying.
//
// # State machine
//
// - Closed — normal operation, requests pass through
// - Open — too many failures, requests are rejected with ErrCircuitOpen
// - HalfOpen — after a cooldown period, one probe request is allowed through
//
// # Usage
//
// mw := circuitbreaker.Transport(
// circuitbreaker.WithThreshold(5),
// circuitbreaker.WithTimeout(30 * time.Second),
// )
// transport := mw(http.DefaultTransport)
//
// The circuit breaker is per-host: each unique request host gets its own
// independent breaker state machine stored in a sync.Map.
//
// # Sentinel errors
//
// ErrCircuitOpen is returned when a request is rejected because the circuit
// is in the Open state.
package circuitbreaker

50
circuitbreaker/options.go Normal file
View File

@@ -0,0 +1,50 @@
package circuitbreaker
import "time"
type options struct {
failureThreshold int // consecutive failures to trip
openDuration time.Duration // how long to stay open before half-open
halfOpenMax int // max concurrent requests in half-open
}
func defaults() options {
return options{
failureThreshold: 5,
openDuration: 30 * time.Second,
halfOpenMax: 1,
}
}
// Option configures a Breaker.
type Option func(*options)
// WithFailureThreshold sets the number of consecutive failures required to
// trip the breaker from Closed to Open. Default is 5.
func WithFailureThreshold(n int) Option {
return func(o *options) {
if n > 0 {
o.failureThreshold = n
}
}
}
// WithOpenDuration sets how long the breaker stays in the Open state before
// transitioning to HalfOpen. Default is 30s.
func WithOpenDuration(d time.Duration) Option {
return func(o *options) {
if d > 0 {
o.openDuration = d
}
}
}
// WithHalfOpenMax sets the maximum number of concurrent probe requests
// allowed while the breaker is in the HalfOpen state. Default is 1.
func WithHalfOpenMax(n int) Option {
return func(o *options) {
if n > 0 {
o.halfOpenMax = n
}
}
}

204
client.go Normal file
View File

@@ -0,0 +1,204 @@
package httpx
import (
"context"
"fmt"
"io"
"net/http"
"strings"
"git.codelab.vc/pkg/httpx/balancer"
"git.codelab.vc/pkg/httpx/circuitbreaker"
"git.codelab.vc/pkg/httpx/middleware"
"git.codelab.vc/pkg/httpx/retry"
)
// Client is a high-level HTTP client that composes middleware for retry,
// circuit breaking, load balancing, logging, and more.
type Client struct {
httpClient *http.Client
baseURL string
errorMapper ErrorMapper
balancerCloser *balancer.Closer
maxResponseBody int64
}
// New creates a new Client with the given options.
//
// The middleware chain is assembled as (outermost → innermost):
//
// Logging → User Middlewares → Retry → Circuit Breaker → Balancer → Base Transport
func New(opts ...Option) *Client {
o := &clientOptions{
transport: http.DefaultTransport,
}
for _, opt := range opts {
opt(o)
}
// Build the middleware chain from inside out.
var chain []middleware.Middleware
// Balancer (innermost, wraps base transport).
var balancerCloser *balancer.Closer
if len(o.endpoints) > 0 {
var mw middleware.Middleware
mw, balancerCloser = balancer.Transport(o.endpoints, o.balancerOpts...)
chain = append(chain, mw)
}
// Circuit breaker wraps balancer.
if o.enableCB {
chain = append(chain, circuitbreaker.Transport(o.cbOpts...))
}
// Retry wraps circuit breaker + balancer.
if o.enableRetry {
chain = append(chain, retry.Transport(o.retryOpts...))
}
// User middlewares.
for i := len(o.middlewares) - 1; i >= 0; i-- {
chain = append(chain, o.middlewares[i])
}
// Logging (outermost).
if o.logger != nil {
chain = append(chain, middleware.Logging(o.logger))
}
// Assemble: chain[last] is outermost.
rt := o.transport
for _, mw := range chain {
rt = mw(rt)
}
return &Client{
httpClient: &http.Client{
Transport: rt,
Timeout: o.timeout,
},
baseURL: o.baseURL,
errorMapper: o.errorMapper,
balancerCloser: balancerCloser,
maxResponseBody: o.maxResponseBody,
}
}
// Do executes an HTTP request.
func (c *Client) Do(ctx context.Context, req *http.Request) (*Response, error) {
req = req.WithContext(ctx)
if err := c.resolveURL(req); err != nil {
return nil, err
}
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, &Error{
Op: req.Method,
URL: req.URL.String(),
Err: err,
}
}
if c.maxResponseBody > 0 {
resp.Body = &limitedReadCloser{
R: io.LimitedReader{R: resp.Body, N: c.maxResponseBody},
C: resp.Body,
}
}
r := newResponse(resp)
if c.errorMapper != nil {
if mapErr := c.errorMapper(resp); mapErr != nil {
return r, &Error{
Op: req.Method,
URL: req.URL.String(),
StatusCode: resp.StatusCode,
Err: mapErr,
}
}
}
return r, nil
}
// Get performs a GET request to the given URL.
func (c *Client) Get(ctx context.Context, url string) (*Response, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, err
}
return c.Do(ctx, req)
}
// Post performs a POST request to the given URL with the given body.
func (c *Client) Post(ctx context.Context, url string, body io.Reader) (*Response, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, body)
if err != nil {
return nil, err
}
return c.Do(ctx, req)
}
// Put performs a PUT request to the given URL with the given body.
func (c *Client) Put(ctx context.Context, url string, body io.Reader) (*Response, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodPut, url, body)
if err != nil {
return nil, err
}
return c.Do(ctx, req)
}
// Patch performs a PATCH request to the given URL with the given body.
func (c *Client) Patch(ctx context.Context, url string, body io.Reader) (*Response, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodPatch, url, body)
if err != nil {
return nil, err
}
return c.Do(ctx, req)
}
// Delete performs a DELETE request to the given URL.
func (c *Client) Delete(ctx context.Context, url string) (*Response, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, url, nil)
if err != nil {
return nil, err
}
return c.Do(ctx, req)
}
// Close releases resources associated with the Client, such as background
// health checker goroutines. It is safe to call multiple times.
func (c *Client) Close() {
if c.balancerCloser != nil {
c.balancerCloser.Close()
}
}
// HTTPClient returns the underlying *http.Client for advanced use cases.
// Mutating the returned client may bypass the configured middleware chain.
func (c *Client) HTTPClient() *http.Client {
return c.httpClient
}
func (c *Client) resolveURL(req *http.Request) error {
if c.baseURL == "" {
return nil
}
// Only resolve relative URLs (no scheme).
if req.URL.Scheme == "" && req.URL.Host == "" {
path := req.URL.String()
if !strings.HasPrefix(path, "/") {
path = "/" + path
}
base := strings.TrimRight(c.baseURL, "/")
u, err := req.URL.Parse(base + path)
if err != nil {
return fmt.Errorf("httpx: resolving URL %q with base %q: %w", path, c.baseURL, err)
}
req.URL = u
}
return nil
}

96
client_options.go Normal file
View File

@@ -0,0 +1,96 @@
package httpx
import (
"log/slog"
"net/http"
"time"
"git.codelab.vc/pkg/httpx/balancer"
"git.codelab.vc/pkg/httpx/circuitbreaker"
"git.codelab.vc/pkg/httpx/middleware"
"git.codelab.vc/pkg/httpx/retry"
)
type clientOptions struct {
baseURL string
timeout time.Duration
transport http.RoundTripper
logger *slog.Logger
errorMapper ErrorMapper
middlewares []middleware.Middleware
retryOpts []retry.Option
enableRetry bool
cbOpts []circuitbreaker.Option
enableCB bool
endpoints []balancer.Endpoint
balancerOpts []balancer.Option
maxResponseBody int64
}
// Option configures a Client.
type Option func(*clientOptions)
// WithBaseURL sets the base URL prepended to all relative request paths.
func WithBaseURL(url string) Option {
return func(o *clientOptions) { o.baseURL = url }
}
// WithTimeout sets the overall request timeout.
func WithTimeout(d time.Duration) Option {
return func(o *clientOptions) { o.timeout = d }
}
// WithTransport sets the base http.RoundTripper. Defaults to http.DefaultTransport.
func WithTransport(rt http.RoundTripper) Option {
return func(o *clientOptions) { o.transport = rt }
}
// WithLogger enables structured logging of requests and responses.
func WithLogger(l *slog.Logger) Option {
return func(o *clientOptions) { o.logger = l }
}
// WithErrorMapper sets a function that maps HTTP responses to errors.
func WithErrorMapper(m ErrorMapper) Option {
return func(o *clientOptions) { o.errorMapper = m }
}
// WithMiddleware appends user middlewares to the chain.
// These run between logging and retry in the middleware stack.
func WithMiddleware(mws ...middleware.Middleware) Option {
return func(o *clientOptions) { o.middlewares = append(o.middlewares, mws...) }
}
// WithRetry enables retry with the given options.
func WithRetry(opts ...retry.Option) Option {
return func(o *clientOptions) {
o.enableRetry = true
o.retryOpts = opts
}
}
// WithCircuitBreaker enables per-host circuit breaking.
func WithCircuitBreaker(opts ...circuitbreaker.Option) Option {
return func(o *clientOptions) {
o.enableCB = true
o.cbOpts = opts
}
}
// WithEndpoints sets the endpoints for load balancing.
func WithEndpoints(eps ...balancer.Endpoint) Option {
return func(o *clientOptions) { o.endpoints = eps }
}
// WithBalancer configures the load balancer strategy and options.
func WithBalancer(opts ...balancer.Option) Option {
return func(o *clientOptions) { o.balancerOpts = opts }
}
// WithMaxResponseBody limits the number of bytes read from response bodies
// by Response.Bytes (and by extension String, JSON, XML). If the response
// body exceeds n bytes, reading stops and returns an error.
// A value of 0 means no limit (the default).
func WithMaxResponseBody(n int64) Option {
return func(o *clientOptions) { o.maxResponseBody = n }
}

45
client_patch_test.go Normal file
View File

@@ -0,0 +1,45 @@
package httpx_test
import (
"context"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"git.codelab.vc/pkg/httpx"
)
func TestClient_Patch(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPatch {
t.Errorf("expected PATCH, got %s", r.Method)
}
b, _ := io.ReadAll(r.Body)
if string(b) != `{"name":"updated"}` {
t.Errorf("expected body %q, got %q", `{"name":"updated"}`, string(b))
}
w.WriteHeader(http.StatusOK)
fmt.Fprint(w, "patched")
}))
defer srv.Close()
client := httpx.New()
resp, err := client.Patch(context.Background(), srv.URL+"/item/1", strings.NewReader(`{"name":"updated"}`))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.StatusCode != http.StatusOK {
t.Errorf("expected status 200, got %d", resp.StatusCode)
}
body, err := resp.String()
if err != nil {
t.Fatalf("reading body: %v", err)
}
if body != "patched" {
t.Errorf("expected body %q, got %q", "patched", body)
}
}

311
client_test.go Normal file
View File

@@ -0,0 +1,311 @@
package httpx_test
import (
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"net/http/httptest"
"strings"
"sync/atomic"
"testing"
"time"
"git.codelab.vc/pkg/httpx"
"git.codelab.vc/pkg/httpx/balancer"
"git.codelab.vc/pkg/httpx/middleware"
"git.codelab.vc/pkg/httpx/retry"
)
func TestClient_Get(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
t.Errorf("expected GET, got %s", r.Method)
}
w.WriteHeader(http.StatusOK)
fmt.Fprint(w, "hello")
}))
defer srv.Close()
client := httpx.New()
resp, err := client.Get(context.Background(), srv.URL+"/test")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
body, err := resp.String()
if err != nil {
t.Fatalf("reading body: %v", err)
}
if body != "hello" {
t.Errorf("expected body %q, got %q", "hello", body)
}
if resp.StatusCode != http.StatusOK {
t.Errorf("expected status 200, got %d", resp.StatusCode)
}
}
func TestClient_Post(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Errorf("expected POST, got %s", r.Method)
}
b, _ := io.ReadAll(r.Body)
if string(b) != "request-body" {
t.Errorf("expected body %q, got %q", "request-body", string(b))
}
w.WriteHeader(http.StatusCreated)
fmt.Fprint(w, "created")
}))
defer srv.Close()
client := httpx.New()
resp, err := client.Post(context.Background(), srv.URL+"/items", strings.NewReader("request-body"))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.StatusCode != http.StatusCreated {
t.Errorf("expected status 201, got %d", resp.StatusCode)
}
body, err := resp.String()
if err != nil {
t.Fatalf("reading body: %v", err)
}
if body != "created" {
t.Errorf("expected body %q, got %q", "created", body)
}
}
func TestClient_BaseURL(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/v1/users" {
t.Errorf("expected path /api/v1/users, got %s", r.URL.Path)
}
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()
client := httpx.New(httpx.WithBaseURL(srv.URL + "/api/v1"))
// Use a relative path (no scheme/host).
resp, err := client.Get(context.Background(), "/users")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.StatusCode != http.StatusOK {
t.Errorf("expected status 200, got %d", resp.StatusCode)
}
}
func TestClient_WithMiddleware(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
val := r.Header.Get("X-Custom-Header")
if val != "test-value" {
t.Errorf("expected header X-Custom-Header=%q, got %q", "test-value", val)
}
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()
addHeader := func(next http.RoundTripper) http.RoundTripper {
return middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
req = req.Clone(req.Context())
req.Header.Set("X-Custom-Header", "test-value")
return next.RoundTrip(req)
})
}
client := httpx.New(httpx.WithMiddleware(addHeader))
resp, err := client.Get(context.Background(), srv.URL+"/test")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.StatusCode != http.StatusOK {
t.Errorf("expected status 200, got %d", resp.StatusCode)
}
}
func TestClient_RetryIntegration(t *testing.T) {
var calls atomic.Int32
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
n := calls.Add(1)
if n <= 2 {
w.WriteHeader(http.StatusServiceUnavailable)
return
}
w.WriteHeader(http.StatusOK)
fmt.Fprint(w, "success")
}))
defer srv.Close()
client := httpx.New(
httpx.WithRetry(
retry.WithMaxAttempts(3),
retry.WithBackoff(retry.ConstantBackoff(1*time.Millisecond)),
),
)
resp, err := client.Get(context.Background(), srv.URL+"/flaky")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.StatusCode != http.StatusOK {
t.Errorf("expected status 200, got %d", resp.StatusCode)
}
body, err := resp.String()
if err != nil {
t.Fatalf("reading body: %v", err)
}
if body != "success" {
t.Errorf("expected body %q, got %q", "success", body)
}
if got := calls.Load(); got != 3 {
t.Errorf("expected 3 total requests, got %d", got)
}
}
func TestClient_BalancerIntegration(t *testing.T) {
var hits1, hits2 atomic.Int32
srv1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
hits1.Add(1)
fmt.Fprint(w, "server1")
}))
defer srv1.Close()
srv2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
hits2.Add(1)
fmt.Fprint(w, "server2")
}))
defer srv2.Close()
client := httpx.New(
httpx.WithEndpoints(
balancer.Endpoint{URL: srv1.URL},
balancer.Endpoint{URL: srv2.URL},
),
)
const totalRequests = 6
for i := range totalRequests {
resp, err := client.Get(context.Background(), fmt.Sprintf("/item/%d", i))
if err != nil {
t.Fatalf("request %d: unexpected error: %v", i, err)
}
resp.Close()
}
h1 := hits1.Load()
h2 := hits2.Load()
if h1+h2 != totalRequests {
t.Errorf("expected %d total hits, got %d", totalRequests, h1+h2)
}
if h1 == 0 || h2 == 0 {
t.Errorf("expected requests distributed across both servers, got server1=%d server2=%d", h1, h2)
}
}
func TestClient_ErrorMapper(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
fmt.Fprint(w, "not found")
}))
defer srv.Close()
mapper := func(resp *http.Response) error {
if resp.StatusCode >= 400 {
return fmt.Errorf("HTTP %d", resp.StatusCode)
}
return nil
}
client := httpx.New(httpx.WithErrorMapper(mapper))
resp, err := client.Get(context.Background(), srv.URL+"/missing")
if err == nil {
t.Fatal("expected error, got nil")
}
// The response should still be returned alongside the error.
if resp == nil {
t.Fatal("expected non-nil response even on mapped error")
}
if resp.StatusCode != http.StatusNotFound {
t.Errorf("expected status 404, got %d", resp.StatusCode)
}
// Verify the error message contains the status code.
if !strings.Contains(err.Error(), "404") {
t.Errorf("expected error to contain 404, got: %v", err)
}
}
func TestClient_JSON(t *testing.T) {
type reqPayload struct {
Name string `json:"name"`
Age int `json:"age"`
}
type respPayload struct {
ID int `json:"id"`
Name string `json:"name"`
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if ct := r.Header.Get("Content-Type"); ct != "application/json" {
t.Errorf("expected Content-Type application/json, got %q", ct)
}
var p reqPayload
if err := json.NewDecoder(r.Body).Decode(&p); err != nil {
t.Errorf("decoding request body: %v", err)
w.WriteHeader(http.StatusBadRequest)
return
}
if p.Name != "Alice" || p.Age != 30 {
t.Errorf("unexpected payload: %+v", p)
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(respPayload{ID: 1, Name: p.Name})
}))
defer srv.Close()
client := httpx.New()
req, err := httpx.NewJSONRequest(context.Background(), http.MethodPost, srv.URL+"/users", reqPayload{
Name: "Alice",
Age: 30,
})
if err != nil {
t.Fatalf("creating JSON request: %v", err)
}
resp, err := client.Do(context.Background(), req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
var result respPayload
if err := resp.JSON(&result); err != nil {
t.Fatalf("decoding JSON response: %v", err)
}
if result.ID != 1 {
t.Errorf("expected ID 1, got %d", result.ID)
}
if result.Name != "Alice" {
t.Errorf("expected Name %q, got %q", "Alice", result.Name)
}
}
// Ensure slog import is used (referenced in imports for completeness with the spec).
var _ = slog.Default

39
doc.go Normal file
View File

@@ -0,0 +1,39 @@
// Package httpx provides a high-level HTTP client with composable middleware
// for retry, circuit breaking, load balancing, structured logging, and more.
//
// The client is configured via functional options and assembled as a middleware
// chain around a standard http.RoundTripper:
//
// Logging → User Middlewares → Retry → Circuit Breaker → Balancer → Transport
//
// # Quick start
//
// client := httpx.New(
// httpx.WithBaseURL("https://api.example.com"),
// httpx.WithTimeout(10 * time.Second),
// httpx.WithRetry(),
// httpx.WithCircuitBreaker(),
// )
// defer client.Close()
//
// resp, err := client.Get(ctx, "/users/1")
//
// # Request builders
//
// NewJSONRequest and NewFormRequest create requests with appropriate
// Content-Type headers and GetBody set for retry compatibility.
//
// # Error handling
//
// Failed requests return *httpx.Error with structured fields (Op, URL,
// StatusCode). Sentinel errors ErrRetryExhausted, ErrCircuitOpen, and
// ErrNoHealthy can be checked with errors.Is.
//
// # Sub-packages
//
// - middleware — client-side middleware (logging, auth, headers, recovery, request ID)
// - retry — configurable retry with backoff and Retry-After support
// - circuitbreaker — per-host circuit breaker (closed → open → half-open)
// - balancer — client-side load balancing with health checking
// - server — production HTTP server with router, middleware, and graceful shutdown
package httpx

49
error.go Normal file
View File

@@ -0,0 +1,49 @@
package httpx
import (
"fmt"
"net/http"
"git.codelab.vc/pkg/httpx/balancer"
"git.codelab.vc/pkg/httpx/circuitbreaker"
"git.codelab.vc/pkg/httpx/retry"
)
// Sentinel errors returned by httpx components.
// These are aliases for the canonical errors defined in sub-packages,
// so that errors.Is works regardless of which import the caller uses.
var (
ErrRetryExhausted = retry.ErrRetryExhausted
ErrCircuitOpen = circuitbreaker.ErrCircuitOpen
ErrNoHealthy = balancer.ErrNoHealthy
)
// Error provides structured error information for failed HTTP operations.
type Error struct {
// Op is the operation that failed (e.g. "Get", "Do").
Op string
// URL is the originally-requested URL.
URL string
// Endpoint is the resolved endpoint URL (after balancing).
Endpoint string
// StatusCode is the HTTP status code, if a response was received.
StatusCode int
// Retries is the number of retry attempts made.
Retries int
// Err is the underlying error.
Err error
}
func (e *Error) Error() string {
if e.Endpoint != "" && e.Endpoint != e.URL {
return fmt.Sprintf("httpx: %s %s (endpoint %s): %v", e.Op, e.URL, e.Endpoint, e.Err)
}
return fmt.Sprintf("httpx: %s %s: %v", e.Op, e.URL, e.Err)
}
func (e *Error) Unwrap() error { return e.Err }
// ErrorMapper maps an HTTP response to an error. If the response is
// acceptable, the mapper should return nil. Used by Client to convert
// non-successful HTTP responses into Go errors.
type ErrorMapper func(resp *http.Response) error

90
error_test.go Normal file
View File

@@ -0,0 +1,90 @@
package httpx_test
import (
"errors"
"testing"
"git.codelab.vc/pkg/httpx"
)
func TestError(t *testing.T) {
t.Run("formats without endpoint", func(t *testing.T) {
inner := errors.New("connection refused")
e := &httpx.Error{
Op: "Get",
URL: "http://example.com/api",
Err: inner,
}
want := "httpx: Get http://example.com/api: connection refused"
if got := e.Error(); got != want {
t.Errorf("got %q, want %q", got, want)
}
})
t.Run("formats with endpoint different from url", func(t *testing.T) {
inner := errors.New("timeout")
e := &httpx.Error{
Op: "Do",
URL: "http://example.com/api",
Endpoint: "http://node1.example.com/api",
Err: inner,
}
want := "httpx: Do http://example.com/api (endpoint http://node1.example.com/api): timeout"
if got := e.Error(); got != want {
t.Errorf("got %q, want %q", got, want)
}
})
t.Run("formats with endpoint same as url", func(t *testing.T) {
inner := errors.New("not found")
e := &httpx.Error{
Op: "Get",
URL: "http://example.com/api",
Endpoint: "http://example.com/api",
Err: inner,
}
want := "httpx: Get http://example.com/api: not found"
if got := e.Error(); got != want {
t.Errorf("got %q, want %q", got, want)
}
})
t.Run("unwrap returns inner error", func(t *testing.T) {
inner := errors.New("underlying")
e := &httpx.Error{Op: "Get", URL: "http://example.com", Err: inner}
if got := e.Unwrap(); got != inner {
t.Errorf("Unwrap() = %v, want %v", got, inner)
}
if !errors.Is(e, inner) {
t.Error("errors.Is should find the inner error")
}
})
}
func TestSentinelErrors(t *testing.T) {
t.Run("ErrRetryExhausted", func(t *testing.T) {
if httpx.ErrRetryExhausted == nil {
t.Fatal("ErrRetryExhausted is nil")
}
if httpx.ErrRetryExhausted.Error() == "" {
t.Fatal("ErrRetryExhausted has empty message")
}
})
t.Run("ErrCircuitOpen", func(t *testing.T) {
if httpx.ErrCircuitOpen == nil {
t.Fatal("ErrCircuitOpen is nil")
}
})
t.Run("ErrNoHealthy", func(t *testing.T) {
if httpx.ErrNoHealthy == nil {
t.Fatal("ErrNoHealthy is nil")
}
})
}

View File

@@ -0,0 +1,67 @@
// Basic HTTP client with retry, timeout, and structured logging.
package main
import (
"context"
"fmt"
"log"
"log/slog"
"time"
"git.codelab.vc/pkg/httpx"
"git.codelab.vc/pkg/httpx/middleware"
"git.codelab.vc/pkg/httpx/retry"
)
func main() {
client := httpx.New(
httpx.WithBaseURL("https://httpbin.org"),
httpx.WithTimeout(10*time.Second),
httpx.WithRetry(
retry.WithMaxAttempts(3),
retry.WithBackoff(retry.ExponentialBackoff(100*time.Millisecond, 2*time.Second, true)),
),
httpx.WithMiddleware(
middleware.UserAgent("httpx-example/1.0"),
),
httpx.WithMaxResponseBody(1<<20), // 1 MB limit
httpx.WithLogger(slog.Default()),
)
defer client.Close()
ctx := context.Background()
// GET request.
resp, err := client.Get(ctx, "/get")
if err != nil {
log.Fatal(err)
}
body, err := resp.String()
if err != nil {
log.Fatal(err)
}
fmt.Printf("GET /get → %d (%d bytes)\n", resp.StatusCode, len(body))
// POST with JSON body.
type payload struct {
Name string `json:"name"`
Email string `json:"email"`
}
req, err := httpx.NewJSONRequest(ctx, "POST", "/post", payload{
Name: "Alice",
Email: "alice@example.com",
})
if err != nil {
log.Fatal(err)
}
resp, err = client.Do(ctx, req)
if err != nil {
log.Fatal(err)
}
defer resp.Close()
fmt.Printf("POST /post → %d\n", resp.StatusCode)
}

View File

@@ -0,0 +1,41 @@
// Demonstrates form-encoded requests for OAuth token endpoints and similar APIs.
package main
import (
"context"
"fmt"
"log"
"net/http"
"net/url"
"git.codelab.vc/pkg/httpx"
)
func main() {
client := httpx.New(
httpx.WithBaseURL("https://httpbin.org"),
)
defer client.Close()
ctx := context.Background()
// Form-encoded POST, common for OAuth token endpoints.
req, err := httpx.NewFormRequest(ctx, http.MethodPost, "/post", url.Values{
"grant_type": {"client_credentials"},
"client_id": {"my-app"},
"client_secret": {"secret"},
"scope": {"read write"},
})
if err != nil {
log.Fatal(err)
}
resp, err := client.Do(ctx, req)
if err != nil {
log.Fatal(err)
}
defer resp.Close()
body, _ := resp.String()
fmt.Printf("Status: %d\nBody: %s\n", resp.StatusCode, body)
}

View File

@@ -0,0 +1,54 @@
// Demonstrates load balancing across multiple backend endpoints with
// circuit breaking and health-checked failover.
package main
import (
"context"
"fmt"
"log"
"log/slog"
"time"
"git.codelab.vc/pkg/httpx"
"git.codelab.vc/pkg/httpx/balancer"
"git.codelab.vc/pkg/httpx/circuitbreaker"
"git.codelab.vc/pkg/httpx/retry"
)
func main() {
client := httpx.New(
httpx.WithEndpoints(
balancer.Endpoint{URL: "http://localhost:8081", Weight: 3},
balancer.Endpoint{URL: "http://localhost:8082", Weight: 1},
),
httpx.WithBalancer(
balancer.WithStrategy(balancer.WeightedRandom()),
balancer.WithHealthCheck(
balancer.WithHealthInterval(5*time.Second),
balancer.WithHealthPath("/healthz"),
),
),
httpx.WithCircuitBreaker(
circuitbreaker.WithFailureThreshold(5),
circuitbreaker.WithOpenDuration(30*time.Second),
),
httpx.WithRetry(
retry.WithMaxAttempts(3),
),
httpx.WithLogger(slog.Default()),
)
defer client.Close()
ctx := context.Background()
for i := range 5 {
resp, err := client.Get(ctx, fmt.Sprintf("/api/item/%d", i))
if err != nil {
log.Printf("request %d failed: %v", i, err)
continue
}
body, _ := resp.String()
fmt.Printf("request %d → %d: %s\n", i, resp.StatusCode, body)
}
}

View File

@@ -0,0 +1,54 @@
// Demonstrates request ID propagation between server and client.
// The server assigns a request ID to incoming requests, and the client
// middleware forwards it to downstream services via X-Request-Id header.
package main
import (
"fmt"
"log"
"log/slog"
"net/http"
"git.codelab.vc/pkg/httpx"
"git.codelab.vc/pkg/httpx/middleware"
"git.codelab.vc/pkg/httpx/server"
)
func main() {
logger := slog.Default()
// Client that propagates request IDs from context.
client := httpx.New(
httpx.WithBaseURL("http://localhost:9090"),
httpx.WithMiddleware(
middleware.RequestID(), // Picks up ID from context, sets X-Request-Id.
),
)
defer client.Close()
r := server.NewRouter()
r.HandleFunc("GET /proxy", func(w http.ResponseWriter, r *http.Request) {
// The request ID is in r.Context() thanks to server.RequestID().
id := server.RequestIDFromContext(r.Context())
logger.Info("handling request", "request_id", id)
// Client automatically forwards the request ID to downstream.
resp, err := client.Get(r.Context(), "/downstream")
if err != nil {
server.WriteError(w, http.StatusBadGateway, fmt.Sprintf("downstream error: %v", err))
return
}
defer resp.Close()
body, _ := resp.String()
server.WriteJSON(w, http.StatusOK, map[string]string{
"request_id": id,
"downstream_response": body,
})
})
// Server with RequestID middleware that assigns IDs to incoming requests.
srv := server.New(r, server.Defaults(logger)...)
log.Fatal(srv.ListenAndServe())
}

View File

@@ -0,0 +1,54 @@
// Basic HTTP server with routing, middleware, health checks, and graceful shutdown.
package main
import (
"log"
"log/slog"
"net/http"
"git.codelab.vc/pkg/httpx/server"
)
func main() {
logger := slog.Default()
r := server.NewRouter(
server.WithNotFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
server.WriteError(w, http.StatusNotFound, "resource not found")
})),
)
// Public endpoints.
r.HandleFunc("GET /hello", func(w http.ResponseWriter, _ *http.Request) {
server.WriteJSON(w, http.StatusOK, map[string]string{
"message": "Hello, World!",
})
})
// API group with shared prefix.
api := r.Group("/api/v1")
api.HandleFunc("GET /users/{id}", getUser)
api.HandleFunc("POST /users", createUser)
// Health checks.
r.Mount("/", server.HealthHandler())
// Server with production defaults (RequestID → Recovery → Logging).
srv := server.New(r, server.Defaults(logger)...)
log.Fatal(srv.ListenAndServe())
}
func getUser(w http.ResponseWriter, r *http.Request) {
id := r.PathValue("id")
server.WriteJSON(w, http.StatusOK, map[string]string{
"id": id,
"name": "Alice",
})
}
func createUser(w http.ResponseWriter, r *http.Request) {
server.WriteJSON(w, http.StatusCreated, map[string]any{
"id": 1,
"message": "user created",
})
}

View File

@@ -0,0 +1,61 @@
// Production server with CORS, rate limiting, body size limits, and timeouts.
package main
import (
"log"
"log/slog"
"net/http"
"time"
"git.codelab.vc/pkg/httpx/server"
)
func main() {
logger := slog.Default()
r := server.NewRouter(
server.WithNotFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
server.WriteError(w, http.StatusNotFound, "not found")
})),
)
r.HandleFunc("GET /api/data", func(w http.ResponseWriter, _ *http.Request) {
server.WriteJSON(w, http.StatusOK, map[string]string{"status": "ok"})
})
r.HandleFunc("POST /api/upload", func(w http.ResponseWriter, r *http.Request) {
// Body is already limited by MaxBodySize middleware.
server.WriteJSON(w, http.StatusAccepted, map[string]string{"status": "received"})
})
r.Mount("/", server.HealthHandler())
srv := server.New(r,
append(
server.Defaults(logger),
server.WithMiddleware(
// CORS for browser-facing APIs.
server.CORS(
server.AllowOrigins("https://app.example.com", "https://admin.example.com"),
server.AllowMethods("GET", "POST", "PUT", "PATCH", "DELETE"),
server.AllowHeaders("Authorization", "Content-Type"),
server.ExposeHeaders("X-Request-Id"),
server.AllowCredentials(true),
server.MaxAge(3600),
),
// Rate limit: 100 req/s per IP, burst of 200.
server.RateLimit(
server.WithRate(100),
server.WithBurst(200),
),
// Limit request body to 1 MB.
server.MaxBodySize(1<<20),
// Per-request timeout of 30 seconds.
server.Timeout(30*time.Second),
),
server.WithAddr(":8080"),
)...,
)
log.Fatal(srv.ListenAndServe())
}

3
go.mod Normal file
View File

@@ -0,0 +1,3 @@
module git.codelab.vc/pkg/httpx
go 1.24

175
internal/clock/clock.go Normal file
View File

@@ -0,0 +1,175 @@
package clock
import (
"sync"
"time"
)
// Clock abstracts time operations for deterministic testing.
type Clock interface {
Now() time.Time
Since(t time.Time) time.Duration
NewTimer(d time.Duration) Timer
After(d time.Duration) <-chan time.Time
}
// Timer abstracts time.Timer for testability.
type Timer interface {
C() <-chan time.Time
Stop() bool
Reset(d time.Duration) bool
}
// System returns a Clock backed by the real system time.
func System() Clock { return systemClock{} }
type systemClock struct{}
func (systemClock) Now() time.Time { return time.Now() }
func (systemClock) Since(t time.Time) time.Duration { return time.Since(t) }
func (systemClock) NewTimer(d time.Duration) Timer { return &systemTimer{t: time.NewTimer(d)} }
func (systemClock) After(d time.Duration) <-chan time.Time { return time.After(d) }
type systemTimer struct{ t *time.Timer }
func (s *systemTimer) C() <-chan time.Time { return s.t.C }
func (s *systemTimer) Stop() bool { return s.t.Stop() }
func (s *systemTimer) Reset(d time.Duration) bool { return s.t.Reset(d) }
// Mock returns a manually-controlled Clock for tests.
func Mock(now time.Time) *MockClock {
return &MockClock{now: now}
}
// MockClock is a deterministic clock for testing.
type MockClock struct {
mu sync.Mutex
now time.Time
timers []*mockTimer
}
func (m *MockClock) Now() time.Time {
m.mu.Lock()
defer m.mu.Unlock()
return m.now
}
func (m *MockClock) Since(t time.Time) time.Duration {
m.mu.Lock()
defer m.mu.Unlock()
return m.now.Sub(t)
}
func (m *MockClock) NewTimer(d time.Duration) Timer {
m.mu.Lock()
defer m.mu.Unlock()
t := &mockTimer{
clock: m,
ch: make(chan time.Time, 1),
deadline: m.now.Add(d),
active: true,
}
m.timers = append(m.timers, t)
if d <= 0 {
t.fire(m.now)
}
return t
}
func (m *MockClock) After(d time.Duration) <-chan time.Time {
return m.NewTimer(d).C()
}
// Advance moves the clock forward by d and fires any expired timers.
func (m *MockClock) Advance(d time.Duration) {
m.mu.Lock()
m.now = m.now.Add(d)
now := m.now
m.mu.Unlock()
m.fireExpired(now)
}
// Set sets the clock to an absolute time and fires any expired timers.
func (m *MockClock) Set(t time.Time) {
m.mu.Lock()
m.now = t
now := m.now
m.mu.Unlock()
m.fireExpired(now)
}
// fireExpired fires all active timers whose deadline has passed, then
// removes inactive timers to prevent unbounded growth.
func (m *MockClock) fireExpired(now time.Time) {
m.mu.Lock()
timers := m.timers
m.mu.Unlock()
for _, t := range timers {
t.mu.Lock()
if t.active && !now.Before(t.deadline) {
t.fire(now)
}
t.mu.Unlock()
}
// Compact: remove inactive timers. Use a new slice to avoid aliasing
// the backing array (NewTimer may have appended between snapshots).
m.mu.Lock()
n := len(m.timers)
active := make([]*mockTimer, 0, n)
for _, t := range m.timers {
t.mu.Lock()
keep := t.active
t.mu.Unlock()
if keep {
active = append(active, t)
}
}
m.timers = active
m.mu.Unlock()
}
type mockTimer struct {
mu sync.Mutex
clock *MockClock
ch chan time.Time
deadline time.Time
active bool
}
func (t *mockTimer) C() <-chan time.Time { return t.ch }
func (t *mockTimer) Stop() bool {
t.mu.Lock()
defer t.mu.Unlock()
was := t.active
t.active = false
return was
}
func (t *mockTimer) Reset(d time.Duration) bool {
// Acquire clock lock first to match the lock ordering in fireExpired
// (clock.mu → t.mu), preventing deadlock.
t.clock.mu.Lock()
deadline := t.clock.now.Add(d)
t.clock.mu.Unlock()
t.mu.Lock()
defer t.mu.Unlock()
was := t.active
t.active = true
t.deadline = deadline
return was
}
// fire sends the time on the channel. Caller must hold t.mu.
func (t *mockTimer) fire(now time.Time) {
t.active = false
select {
case t.ch <- now:
default:
}
}

View File

@@ -0,0 +1,19 @@
// Package requestid provides a shared context key for request IDs,
// allowing both client and server packages to access request IDs
// without circular imports.
package requestid
import "context"
type key struct{}
// NewContext returns a context with the given request ID.
func NewContext(ctx context.Context, id string) context.Context {
return context.WithValue(ctx, key{}, id)
}
// FromContext returns the request ID from ctx, or empty string if not set.
func FromContext(ctx context.Context) string {
id, _ := ctx.Value(key{}).(string)
return id
}

33
middleware/auth.go Normal file
View File

@@ -0,0 +1,33 @@
package middleware
import (
"context"
"net/http"
)
// BearerAuth returns a middleware that sets the Authorization header to a
// Bearer token obtained by calling tokenFunc on each request. If tokenFunc
// returns an error, the request is not sent and the error is returned.
func BearerAuth(tokenFunc func(ctx context.Context) (string, error)) Middleware {
return func(next http.RoundTripper) http.RoundTripper {
return RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
token, err := tokenFunc(req.Context())
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+token)
return next.RoundTrip(req)
})
}
}
// BasicAuth returns a middleware that sets HTTP Basic Authentication
// credentials on every outgoing request.
func BasicAuth(username, password string) Middleware {
return func(next http.RoundTripper) http.RoundTripper {
return RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
req.SetBasicAuth(username, password)
return next.RoundTrip(req)
})
}
}

89
middleware/auth_test.go Normal file
View File

@@ -0,0 +1,89 @@
package middleware_test
import (
"context"
"errors"
"net/http"
"testing"
"git.codelab.vc/pkg/httpx/middleware"
)
func TestBearerAuth(t *testing.T) {
t.Run("sets authorization header", func(t *testing.T) {
var captured http.Header
base := mockTransport(func(req *http.Request) (*http.Response, error) {
captured = req.Header.Clone()
return okResponse(), nil
})
tokenFunc := func(_ context.Context) (string, error) {
return "my-secret-token", nil
}
transport := middleware.BearerAuth(tokenFunc)(base)
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
_, err := transport.RoundTrip(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
want := "Bearer my-secret-token"
if got := captured.Get("Authorization"); got != want {
t.Errorf("Authorization = %q, want %q", got, want)
}
})
t.Run("returns error when tokenFunc fails", func(t *testing.T) {
base := mockTransport(func(req *http.Request) (*http.Response, error) {
t.Fatal("base transport should not be called")
return nil, nil
})
tokenErr := errors.New("token expired")
tokenFunc := func(_ context.Context) (string, error) {
return "", tokenErr
}
transport := middleware.BearerAuth(tokenFunc)(base)
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
_, err := transport.RoundTrip(req)
if err == nil {
t.Fatal("expected error, got nil")
}
if !errors.Is(err, tokenErr) {
t.Errorf("got error %v, want %v", err, tokenErr)
}
})
}
func TestBasicAuth(t *testing.T) {
t.Run("sets basic auth header", func(t *testing.T) {
var capturedReq *http.Request
base := mockTransport(func(req *http.Request) (*http.Response, error) {
capturedReq = req
return okResponse(), nil
})
transport := middleware.BasicAuth("user", "pass")(base)
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
_, err := transport.RoundTrip(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
username, password, ok := capturedReq.BasicAuth()
if !ok {
t.Fatal("BasicAuth() returned ok=false")
}
if username != "user" {
t.Errorf("username = %q, want %q", username, "user")
}
if password != "pass" {
t.Errorf("password = %q, want %q", password, "pass")
}
})
}

28
middleware/doc.go Normal file
View File

@@ -0,0 +1,28 @@
// Package middleware provides client-side HTTP middleware for use with
// httpx.Client or any http.RoundTripper-based transport chain.
//
// Each middleware is a function of type func(http.RoundTripper) http.RoundTripper.
// Compose them with Chain:
//
// chain := middleware.Chain(
// middleware.Logging(logger),
// middleware.Recovery(),
// middleware.UserAgent("my-service/1.0"),
// )
// transport := chain(http.DefaultTransport)
//
// # Available middleware
//
// - Logging — structured request/response logging via slog
// - Recovery — panic recovery, converts panics to errors
// - DefaultHeaders — adds default headers to outgoing requests
// - UserAgent — sets User-Agent header
// - BearerAuth — dynamic Bearer token authentication
// - BasicAuth — HTTP Basic authentication
// - RequestID — propagates request ID from context to X-Request-Id header
//
// # RoundTripperFunc
//
// RoundTripperFunc adapts plain functions to http.RoundTripper, similar to
// http.HandlerFunc. Useful for testing and inline middleware.
package middleware

29
middleware/headers.go Normal file
View File

@@ -0,0 +1,29 @@
package middleware
import "net/http"
// DefaultHeaders returns a middleware that adds the given headers to every
// outgoing request. Existing headers on the request are not overwritten.
func DefaultHeaders(headers http.Header) Middleware {
return func(next http.RoundTripper) http.RoundTripper {
return RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
for key, values := range headers {
if req.Header.Get(key) != "" {
continue
}
for _, v := range values {
req.Header.Add(key, v)
}
}
return next.RoundTrip(req)
})
}
}
// UserAgent returns a middleware that sets the User-Agent header on every
// outgoing request, unless one is already present.
func UserAgent(ua string) Middleware {
return DefaultHeaders(http.Header{
"User-Agent": {ua},
})
}

107
middleware/headers_test.go Normal file
View File

@@ -0,0 +1,107 @@
package middleware_test
import (
"net/http"
"testing"
"git.codelab.vc/pkg/httpx/middleware"
)
func TestDefaultHeaders(t *testing.T) {
t.Run("adds headers without overwriting existing", func(t *testing.T) {
defaults := http.Header{
"X-Custom": {"default-value"},
"X-Untouched": {"from-middleware"},
}
var captured http.Header
base := mockTransport(func(req *http.Request) (*http.Response, error) {
captured = req.Header.Clone()
return okResponse(), nil
})
transport := middleware.DefaultHeaders(defaults)(base)
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
req.Header.Set("X-Custom", "request-value")
_, err := transport.RoundTrip(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got := captured.Get("X-Custom"); got != "request-value" {
t.Errorf("X-Custom = %q, want %q (should not overwrite)", got, "request-value")
}
if got := captured.Get("X-Untouched"); got != "from-middleware" {
t.Errorf("X-Untouched = %q, want %q", got, "from-middleware")
}
})
t.Run("adds headers when absent", func(t *testing.T) {
defaults := http.Header{
"Accept": {"application/json"},
}
var captured http.Header
base := mockTransport(func(req *http.Request) (*http.Response, error) {
captured = req.Header.Clone()
return okResponse(), nil
})
transport := middleware.DefaultHeaders(defaults)(base)
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
_, err := transport.RoundTrip(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got := captured.Get("Accept"); got != "application/json" {
t.Errorf("Accept = %q, want %q", got, "application/json")
}
})
}
func TestUserAgent(t *testing.T) {
t.Run("sets user agent header", func(t *testing.T) {
var captured http.Header
base := mockTransport(func(req *http.Request) (*http.Response, error) {
captured = req.Header.Clone()
return okResponse(), nil
})
transport := middleware.UserAgent("httpx/1.0")(base)
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
_, err := transport.RoundTrip(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got := captured.Get("User-Agent"); got != "httpx/1.0" {
t.Errorf("User-Agent = %q, want %q", got, "httpx/1.0")
}
})
t.Run("does not overwrite existing user agent", func(t *testing.T) {
var captured http.Header
base := mockTransport(func(req *http.Request) (*http.Response, error) {
captured = req.Header.Clone()
return okResponse(), nil
})
transport := middleware.UserAgent("httpx/1.0")(base)
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
req.Header.Set("User-Agent", "custom-agent")
_, err := transport.RoundTrip(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got := captured.Get("User-Agent"); got != "custom-agent" {
t.Errorf("User-Agent = %q, want %q", got, "custom-agent")
}
})
}

38
middleware/logging.go Normal file
View File

@@ -0,0 +1,38 @@
package middleware
import (
"log/slog"
"net/http"
"time"
)
// Logging returns a middleware that logs each request's method, URL, status
// code, duration, and error (if any) using the provided structured logger.
// Successful responses are logged at Info level; errors at Error level.
func Logging(logger *slog.Logger) Middleware {
return func(next http.RoundTripper) http.RoundTripper {
return RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
start := time.Now()
resp, err := next.RoundTrip(req)
duration := time.Since(start)
attrs := []slog.Attr{
slog.String("method", req.Method),
slog.String("url", req.URL.String()),
slog.Duration("duration", duration),
}
if err != nil {
attrs = append(attrs, slog.String("error", err.Error()))
logger.LogAttrs(req.Context(), slog.LevelError, "request failed", attrs...)
return resp, err
}
attrs = append(attrs, slog.Int("status", resp.StatusCode))
logger.LogAttrs(req.Context(), slog.LevelInfo, "request completed", attrs...)
return resp, nil
})
}
}

124
middleware/logging_test.go Normal file
View File

@@ -0,0 +1,124 @@
package middleware_test
import (
"context"
"errors"
"io"
"log/slog"
"net/http"
"strings"
"sync"
"testing"
"git.codelab.vc/pkg/httpx/middleware"
)
// captureHandler is a slog.Handler that captures log records for inspection.
type captureHandler struct {
mu sync.Mutex
records []slog.Record
}
func (h *captureHandler) Enabled(_ context.Context, _ slog.Level) bool { return true }
func (h *captureHandler) Handle(_ context.Context, r slog.Record) error {
h.mu.Lock()
defer h.mu.Unlock()
h.records = append(h.records, r)
return nil
}
func (h *captureHandler) WithAttrs(_ []slog.Attr) slog.Handler { return h }
func (h *captureHandler) WithGroup(_ string) slog.Handler { return h }
func (h *captureHandler) lastRecord() slog.Record {
h.mu.Lock()
defer h.mu.Unlock()
return h.records[len(h.records)-1]
}
func TestLogging(t *testing.T) {
t.Run("logs method url status duration on success", func(t *testing.T) {
handler := &captureHandler{}
logger := slog.New(handler)
base := mockTransport(func(req *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader("ok")),
Header: make(http.Header),
}, nil
})
transport := middleware.Logging(logger)(base)
req, _ := http.NewRequest(http.MethodPost, "http://example.com/api", nil)
resp, err := transport.RoundTrip(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.StatusCode != http.StatusOK {
t.Fatalf("got status %d, want %d", resp.StatusCode, http.StatusOK)
}
rec := handler.lastRecord()
if rec.Level != slog.LevelInfo {
t.Errorf("got level %v, want %v", rec.Level, slog.LevelInfo)
}
attrs := map[string]string{}
rec.Attrs(func(a slog.Attr) bool {
attrs[a.Key] = a.Value.String()
return true
})
if attrs["method"] != "POST" {
t.Errorf("method = %q, want %q", attrs["method"], "POST")
}
if attrs["url"] != "http://example.com/api" {
t.Errorf("url = %q, want %q", attrs["url"], "http://example.com/api")
}
if _, ok := attrs["status"]; !ok {
t.Error("missing status attribute")
}
if _, ok := attrs["duration"]; !ok {
t.Error("missing duration attribute")
}
})
t.Run("logs error on failure", func(t *testing.T) {
handler := &captureHandler{}
logger := slog.New(handler)
base := mockTransport(func(req *http.Request) (*http.Response, error) {
return nil, errors.New("connection refused")
})
transport := middleware.Logging(logger)(base)
req, _ := http.NewRequest(http.MethodGet, "http://example.com/fail", nil)
_, err := transport.RoundTrip(req)
if err == nil {
t.Fatal("expected error, got nil")
}
rec := handler.lastRecord()
if rec.Level != slog.LevelError {
t.Errorf("got level %v, want %v", rec.Level, slog.LevelError)
}
attrs := map[string]string{}
rec.Attrs(func(a slog.Attr) bool {
attrs[a.Key] = a.Value.String()
return true
})
if attrs["error"] != "connection refused" {
t.Errorf("error = %q, want %q", attrs["error"], "connection refused")
}
if _, ok := attrs["method"]; !ok {
t.Error("missing method attribute")
}
if _, ok := attrs["url"]; !ok {
t.Error("missing url attribute")
}
})
}

29
middleware/middleware.go Normal file
View File

@@ -0,0 +1,29 @@
package middleware
import "net/http"
// Middleware wraps an http.RoundTripper to add behavior.
// This is the fundamental building block of the httpx library.
type Middleware func(http.RoundTripper) http.RoundTripper
// Chain composes middlewares so that Chain(A, B, C)(base) == A(B(C(base))).
// Middlewares are applied from right to left: C wraps base first, then B wraps
// the result, then A wraps last. This means A is the outermost layer and sees
// every request first.
func Chain(mws ...Middleware) Middleware {
return func(rt http.RoundTripper) http.RoundTripper {
for i := len(mws) - 1; i >= 0; i-- {
rt = mws[i](rt)
}
return rt
}
}
// RoundTripperFunc is an adapter to allow the use of ordinary functions as
// http.RoundTripper. It works exactly like http.HandlerFunc for handlers.
type RoundTripperFunc func(*http.Request) (*http.Response, error)
// RoundTrip implements http.RoundTripper.
func (f RoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req)
}

View File

@@ -0,0 +1,115 @@
package middleware_test
import (
"io"
"net/http"
"strings"
"testing"
"git.codelab.vc/pkg/httpx/middleware"
)
func mockTransport(fn func(*http.Request) (*http.Response, error)) http.RoundTripper {
return middleware.RoundTripperFunc(fn)
}
func okResponse() *http.Response {
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader("ok")),
Header: make(http.Header),
}
}
func TestChain(t *testing.T) {
t.Run("applies middlewares in correct order", func(t *testing.T) {
var order []string
mwA := func(next http.RoundTripper) http.RoundTripper {
return mockTransport(func(req *http.Request) (*http.Response, error) {
order = append(order, "A-before")
resp, err := next.RoundTrip(req)
order = append(order, "A-after")
return resp, err
})
}
mwB := func(next http.RoundTripper) http.RoundTripper {
return mockTransport(func(req *http.Request) (*http.Response, error) {
order = append(order, "B-before")
resp, err := next.RoundTrip(req)
order = append(order, "B-after")
return resp, err
})
}
mwC := func(next http.RoundTripper) http.RoundTripper {
return mockTransport(func(req *http.Request) (*http.Response, error) {
order = append(order, "C-before")
resp, err := next.RoundTrip(req)
order = append(order, "C-after")
return resp, err
})
}
base := mockTransport(func(req *http.Request) (*http.Response, error) {
order = append(order, "base")
return okResponse(), nil
})
chained := middleware.Chain(mwA, mwB, mwC)(base)
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
_, err := chained.RoundTrip(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
expected := []string{"A-before", "B-before", "C-before", "base", "C-after", "B-after", "A-after"}
if len(order) != len(expected) {
t.Fatalf("got %v, want %v", order, expected)
}
for i, v := range expected {
if order[i] != v {
t.Fatalf("order[%d] = %q, want %q", i, order[i], v)
}
}
})
t.Run("empty chain returns base transport", func(t *testing.T) {
called := false
base := mockTransport(func(req *http.Request) (*http.Response, error) {
called = true
return okResponse(), nil
})
chained := middleware.Chain()(base)
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
_, err := chained.RoundTrip(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !called {
t.Fatal("base transport was not called")
}
})
}
func TestRoundTripperFunc(t *testing.T) {
t.Run("implements RoundTripper", func(t *testing.T) {
fn := middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
return okResponse(), nil
})
var rt http.RoundTripper = fn
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
resp, err := rt.RoundTrip(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.StatusCode != http.StatusOK {
t.Fatalf("got status %d, want %d", resp.StatusCode, http.StatusOK)
}
})
}

22
middleware/recovery.go Normal file
View File

@@ -0,0 +1,22 @@
package middleware
import (
"fmt"
"net/http"
)
// Recovery returns a middleware that recovers from panics in the inner
// RoundTripper chain. A recovered panic is converted to an error wrapping
// the panic value.
func Recovery() Middleware {
return func(next http.RoundTripper) http.RoundTripper {
return RoundTripperFunc(func(req *http.Request) (resp *http.Response, err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic recovered in round trip: %v", r)
}
}()
return next.RoundTrip(req)
})
}
}

View File

@@ -0,0 +1,51 @@
package middleware_test
import (
"net/http"
"strings"
"testing"
"git.codelab.vc/pkg/httpx/middleware"
)
func TestRecovery(t *testing.T) {
t.Run("recovers from panic and returns error", func(t *testing.T) {
base := mockTransport(func(req *http.Request) (*http.Response, error) {
panic("something went wrong")
})
transport := middleware.Recovery()(base)
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
resp, err := transport.RoundTrip(req)
if err == nil {
t.Fatal("expected error, got nil")
}
if resp != nil {
t.Errorf("expected nil response, got %v", resp)
}
if !strings.Contains(err.Error(), "panic recovered") {
t.Errorf("error = %q, want it to contain %q", err.Error(), "panic recovered")
}
if !strings.Contains(err.Error(), "something went wrong") {
t.Errorf("error = %q, want it to contain %q", err.Error(), "something went wrong")
}
})
t.Run("passes through normal responses", func(t *testing.T) {
base := mockTransport(func(req *http.Request) (*http.Response, error) {
return okResponse(), nil
})
transport := middleware.Recovery()(base)
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
resp, err := transport.RoundTrip(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.StatusCode != http.StatusOK {
t.Errorf("got status %d, want %d", resp.StatusCode, http.StatusOK)
}
})
}

23
middleware/requestid.go Normal file
View File

@@ -0,0 +1,23 @@
package middleware
import (
"net/http"
"git.codelab.vc/pkg/httpx/internal/requestid"
)
// RequestID returns a middleware that propagates the request ID from the
// request context to the outgoing X-Request-Id header. This pairs with
// the server.RequestID middleware: the server stores the ID in the context,
// and the client middleware forwards it to downstream services.
func RequestID() Middleware {
return func(next http.RoundTripper) http.RoundTripper {
return RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
if id := requestid.FromContext(req.Context()); id != "" {
req = req.Clone(req.Context())
req.Header.Set("X-Request-Id", id)
}
return next.RoundTrip(req)
})
}
}

View File

@@ -0,0 +1,69 @@
package middleware_test
import (
"context"
"net/http"
"testing"
"git.codelab.vc/pkg/httpx/internal/requestid"
"git.codelab.vc/pkg/httpx/middleware"
)
func TestRequestID(t *testing.T) {
t.Run("propagates ID from context", func(t *testing.T) {
var gotHeader string
base := middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
gotHeader = req.Header.Get("X-Request-Id")
return &http.Response{StatusCode: http.StatusOK, Body: http.NoBody}, nil
})
mw := middleware.RequestID()(base)
ctx := requestid.NewContext(context.Background(), "test-id-123")
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "http://example.com", nil)
_, err := mw.RoundTrip(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if gotHeader != "test-id-123" {
t.Fatalf("X-Request-Id = %q, want %q", gotHeader, "test-id-123")
}
})
t.Run("no ID in context skips header", func(t *testing.T) {
var gotHeader string
base := middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
gotHeader = req.Header.Get("X-Request-Id")
return &http.Response{StatusCode: http.StatusOK, Body: http.NoBody}, nil
})
mw := middleware.RequestID()(base)
req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://example.com", nil)
_, err := mw.RoundTrip(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if gotHeader != "" {
t.Fatalf("expected no X-Request-Id header, got %q", gotHeader)
}
})
t.Run("does not mutate original request", func(t *testing.T) {
base := middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
return &http.Response{StatusCode: http.StatusOK, Body: http.NoBody}, nil
})
mw := middleware.RequestID()(base)
ctx := requestid.NewContext(context.Background(), "test-id")
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "http://example.com", nil)
_, _ = mw.RoundTrip(req)
if req.Header.Get("X-Request-Id") != "" {
t.Fatal("original request was mutated")
}
})
}

52
request.go Normal file
View File

@@ -0,0 +1,52 @@
package httpx
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
)
// NewRequest creates an http.Request with context. It is a convenience
// wrapper around http.NewRequestWithContext.
func NewRequest(ctx context.Context, method, url string, body io.Reader) (*http.Request, error) {
return http.NewRequestWithContext(ctx, method, url, body)
}
// NewJSONRequest creates an http.Request with a JSON-encoded body and
// sets Content-Type to application/json.
func NewJSONRequest(ctx context.Context, method, url string, body any) (*http.Request, error) {
b, err := json.Marshal(body)
if err != nil {
return nil, fmt.Errorf("httpx: encoding JSON body: %w", err)
}
req, err := http.NewRequestWithContext(ctx, method, url, bytes.NewReader(b))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.GetBody = func() (io.ReadCloser, error) {
return io.NopCloser(bytes.NewReader(b)), nil
}
return req, nil
}
// NewFormRequest creates an http.Request with a form-encoded body and
// sets Content-Type to application/x-www-form-urlencoded.
// The GetBody function is set so that the request can be retried.
func NewFormRequest(ctx context.Context, method, rawURL string, values url.Values) (*http.Request, error) {
encoded := values.Encode()
b := []byte(encoded)
req, err := http.NewRequestWithContext(ctx, method, rawURL, bytes.NewReader(b))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.GetBody = func() (io.ReadCloser, error) {
return io.NopCloser(bytes.NewReader(b)), nil
}
return req, nil
}

80
request_form_test.go Normal file
View File

@@ -0,0 +1,80 @@
package httpx_test
import (
"context"
"io"
"net/http"
"net/url"
"testing"
"git.codelab.vc/pkg/httpx"
)
func TestNewFormRequest(t *testing.T) {
t.Run("body is form-encoded", func(t *testing.T) {
values := url.Values{"username": {"alice"}, "scope": {"read"}}
req, err := httpx.NewFormRequest(context.Background(), http.MethodPost, "http://example.com/token", values)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
body, err := io.ReadAll(req.Body)
if err != nil {
t.Fatalf("reading body: %v", err)
}
parsed, err := url.ParseQuery(string(body))
if err != nil {
t.Fatalf("parsing form: %v", err)
}
if parsed.Get("username") != "alice" {
t.Errorf("username = %q, want %q", parsed.Get("username"), "alice")
}
if parsed.Get("scope") != "read" {
t.Errorf("scope = %q, want %q", parsed.Get("scope"), "read")
}
})
t.Run("content type is set", func(t *testing.T) {
values := url.Values{"key": {"value"}}
req, err := httpx.NewFormRequest(context.Background(), http.MethodPost, "http://example.com", values)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
ct := req.Header.Get("Content-Type")
if ct != "application/x-www-form-urlencoded" {
t.Errorf("Content-Type = %q, want %q", ct, "application/x-www-form-urlencoded")
}
})
t.Run("GetBody works for retry", func(t *testing.T) {
values := url.Values{"key": {"value"}}
req, err := httpx.NewFormRequest(context.Background(), http.MethodPost, "http://example.com", values)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if req.GetBody == nil {
t.Fatal("GetBody is nil")
}
b1, err := io.ReadAll(req.Body)
if err != nil {
t.Fatalf("reading body: %v", err)
}
body2, err := req.GetBody()
if err != nil {
t.Fatalf("GetBody(): %v", err)
}
b2, err := io.ReadAll(body2)
if err != nil {
t.Fatalf("reading body2: %v", err)
}
if string(b1) != string(b2) {
t.Errorf("GetBody returned different data: %q vs %q", b1, b2)
}
})
}

78
request_test.go Normal file
View File

@@ -0,0 +1,78 @@
package httpx_test
import (
"context"
"encoding/json"
"io"
"net/http"
"testing"
"git.codelab.vc/pkg/httpx"
)
func TestNewJSONRequest(t *testing.T) {
t.Run("body is JSON encoded", func(t *testing.T) {
payload := map[string]string{"key": "value"}
req, err := httpx.NewJSONRequest(context.Background(), http.MethodPost, "http://example.com", payload)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
body, err := io.ReadAll(req.Body)
if err != nil {
t.Fatalf("reading body: %v", err)
}
var decoded map[string]string
if err := json.Unmarshal(body, &decoded); err != nil {
t.Fatalf("unmarshal: %v", err)
}
if decoded["key"] != "value" {
t.Errorf("decoded[key] = %q, want %q", decoded["key"], "value")
}
})
t.Run("content type is set", func(t *testing.T) {
req, err := httpx.NewJSONRequest(context.Background(), http.MethodPost, "http://example.com", "test")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
ct := req.Header.Get("Content-Type")
if ct != "application/json" {
t.Errorf("Content-Type = %q, want %q", ct, "application/json")
}
})
t.Run("GetBody works", func(t *testing.T) {
payload := map[string]int{"num": 123}
req, err := httpx.NewJSONRequest(context.Background(), http.MethodPost, "http://example.com", payload)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if req.GetBody == nil {
t.Fatal("GetBody is nil")
}
// Read body first time
b1, err := io.ReadAll(req.Body)
if err != nil {
t.Fatalf("reading body: %v", err)
}
// Get a fresh body
body2, err := req.GetBody()
if err != nil {
t.Fatalf("GetBody(): %v", err)
}
b2, err := io.ReadAll(body2)
if err != nil {
t.Fatalf("reading body2: %v", err)
}
if string(b1) != string(b2) {
t.Errorf("GetBody returned different data: %q vs %q", b1, b2)
}
})
}

114
response.go Normal file
View File

@@ -0,0 +1,114 @@
package httpx
import (
"bytes"
"encoding/json"
"encoding/xml"
"fmt"
"io"
"net/http"
)
// Response wraps http.Response with convenience methods.
type Response struct {
*http.Response
body []byte
read bool
}
func newResponse(resp *http.Response) *Response {
return &Response{Response: resp}
}
// Bytes reads and returns the entire response body.
// The body is cached so subsequent calls return the same data.
func (r *Response) Bytes() ([]byte, error) {
if r.read {
return r.body, nil
}
defer r.Body.Close()
b, err := io.ReadAll(r.Body)
if err != nil {
return nil, err
}
r.body = b
r.read = true
return b, nil
}
// String reads the response body and returns it as a string.
func (r *Response) String() (string, error) {
b, err := r.Bytes()
if err != nil {
return "", err
}
return string(b), nil
}
// JSON decodes the response body as JSON into v.
func (r *Response) JSON(v any) error {
b, err := r.Bytes()
if err != nil {
return fmt.Errorf("httpx: reading response body: %w", err)
}
if err := json.Unmarshal(b, v); err != nil {
return fmt.Errorf("httpx: decoding JSON: %w", err)
}
return nil
}
// XML decodes the response body as XML into v.
func (r *Response) XML(v any) error {
b, err := r.Bytes()
if err != nil {
return fmt.Errorf("httpx: reading response body: %w", err)
}
if err := xml.Unmarshal(b, v); err != nil {
return fmt.Errorf("httpx: decoding XML: %w", err)
}
return nil
}
// IsSuccess returns true if the status code is in the 2xx range.
func (r *Response) IsSuccess() bool {
return r.StatusCode >= 200 && r.StatusCode < 300
}
// IsError returns true if the status code is 4xx or 5xx.
func (r *Response) IsError() bool {
return r.StatusCode >= 400
}
// Close drains and closes the response body.
func (r *Response) Close() error {
if r.read {
return nil
}
_, _ = io.Copy(io.Discard, r.Body)
return r.Body.Close()
}
// BodyReader returns a reader for the response body.
// If the body has already been read via Bytes/String/JSON/XML,
// returns a reader over the cached bytes.
func (r *Response) BodyReader() io.Reader {
if r.read {
return bytes.NewReader(r.body)
}
return r.Body
}
// limitedReadCloser wraps an io.LimitedReader with a separate Closer
// so the original body can be closed.
type limitedReadCloser struct {
R io.LimitedReader
C io.Closer
}
func (l *limitedReadCloser) Read(p []byte) (int, error) {
return l.R.Read(p)
}
func (l *limitedReadCloser) Close() error {
return l.C.Close()
}

76
response_limit_test.go Normal file
View File

@@ -0,0 +1,76 @@
package httpx_test
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"git.codelab.vc/pkg/httpx"
)
func TestClient_MaxResponseBody(t *testing.T) {
t.Run("allows response within limit", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
fmt.Fprint(w, "hello")
}))
defer srv.Close()
client := httpx.New(httpx.WithMaxResponseBody(1024))
resp, err := client.Get(context.Background(), srv.URL+"/")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
body, err := resp.String()
if err != nil {
t.Fatalf("reading body: %v", err)
}
if body != "hello" {
t.Fatalf("body = %q, want %q", body, "hello")
}
})
t.Run("truncates response exceeding limit", func(t *testing.T) {
largeBody := strings.Repeat("x", 1000)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
fmt.Fprint(w, largeBody)
}))
defer srv.Close()
client := httpx.New(httpx.WithMaxResponseBody(100))
resp, err := client.Get(context.Background(), srv.URL+"/")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
b, err := resp.Bytes()
if err != nil {
t.Fatalf("reading body: %v", err)
}
if len(b) != 100 {
t.Fatalf("body length = %d, want %d", len(b), 100)
}
})
t.Run("no limit when zero", func(t *testing.T) {
largeBody := strings.Repeat("x", 10000)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
fmt.Fprint(w, largeBody)
}))
defer srv.Close()
client := httpx.New()
resp, err := client.Get(context.Background(), srv.URL+"/")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
b, err := resp.Bytes()
if err != nil {
t.Fatalf("reading body: %v", err)
}
if len(b) != 10000 {
t.Fatalf("body length = %d, want %d", len(b), 10000)
}
})
}

96
response_test.go Normal file
View File

@@ -0,0 +1,96 @@
package httpx
import (
"io"
"net/http"
"strings"
"testing"
)
func makeTestResponse(statusCode int, body string) *Response {
return newResponse(&http.Response{
StatusCode: statusCode,
Body: io.NopCloser(strings.NewReader(body)),
Header: make(http.Header),
})
}
func TestResponse(t *testing.T) {
t.Run("Bytes returns body", func(t *testing.T) {
r := makeTestResponse(200, "hello world")
b, err := r.Bytes()
if err != nil {
t.Fatalf("Bytes() error: %v", err)
}
if string(b) != "hello world" {
t.Errorf("Bytes() = %q, want %q", string(b), "hello world")
}
})
t.Run("body caching returns same data", func(t *testing.T) {
r := makeTestResponse(200, "cached body")
b1, err := r.Bytes()
if err != nil {
t.Fatalf("first Bytes() error: %v", err)
}
b2, err := r.Bytes()
if err != nil {
t.Fatalf("second Bytes() error: %v", err)
}
if string(b1) != string(b2) {
t.Errorf("Bytes() returned different data: %q vs %q", b1, b2)
}
})
t.Run("String returns body as string", func(t *testing.T) {
r := makeTestResponse(200, "string body")
s, err := r.String()
if err != nil {
t.Fatalf("String() error: %v", err)
}
if s != "string body" {
t.Errorf("String() = %q, want %q", s, "string body")
}
})
t.Run("JSON decodes body", func(t *testing.T) {
r := makeTestResponse(200, `{"name":"test","value":42}`)
var result struct {
Name string `json:"name"`
Value int `json:"value"`
}
if err := r.JSON(&result); err != nil {
t.Fatalf("JSON() error: %v", err)
}
if result.Name != "test" {
t.Errorf("Name = %q, want %q", result.Name, "test")
}
if result.Value != 42 {
t.Errorf("Value = %d, want %d", result.Value, 42)
}
})
t.Run("IsSuccess for 2xx", func(t *testing.T) {
for _, code := range []int{200, 201, 204, 299} {
r := makeTestResponse(code, "")
if !r.IsSuccess() {
t.Errorf("IsSuccess() = false for status %d", code)
}
if r.IsError() {
t.Errorf("IsError() = true for status %d", code)
}
}
})
t.Run("IsError for 4xx and 5xx", func(t *testing.T) {
for _, code := range []int{400, 404, 500, 503} {
r := makeTestResponse(code, "")
if !r.IsError() {
t.Errorf("IsError() = false for status %d", code)
}
if r.IsSuccess() {
t.Errorf("IsSuccess() = true for status %d", code)
}
}
})
}

64
retry/backoff.go Normal file
View File

@@ -0,0 +1,64 @@
package retry
import (
"math/rand/v2"
"time"
)
// Backoff computes the delay before the next retry attempt.
type Backoff interface {
// Delay returns the wait duration for the given attempt number (zero-based).
Delay(attempt int) time.Duration
}
// ExponentialBackoff returns a Backoff that doubles the delay on each attempt.
// The delay is calculated as base * 2^attempt, capped at max. When withJitter
// is true, a random duration in [0, delay*0.5) is added.
func ExponentialBackoff(base, max time.Duration, withJitter bool) Backoff {
return &exponentialBackoff{
base: base,
max: max,
withJitter: withJitter,
}
}
// ConstantBackoff returns a Backoff that always returns the same delay.
func ConstantBackoff(d time.Duration) Backoff {
return constantBackoff{delay: d}
}
type exponentialBackoff struct {
base time.Duration
max time.Duration
withJitter bool
}
func (b *exponentialBackoff) Delay(attempt int) time.Duration {
delay := b.base
for range attempt {
delay *= 2
if delay >= b.max {
delay = b.max
break
}
}
if b.withJitter {
jitter := time.Duration(rand.Int64N(int64(delay / 2)))
delay += jitter
}
if delay > b.max {
delay = b.max
}
return delay
}
type constantBackoff struct {
delay time.Duration
}
func (b constantBackoff) Delay(_ int) time.Duration {
return b.delay
}

77
retry/backoff_test.go Normal file
View File

@@ -0,0 +1,77 @@
package retry
import (
"testing"
"time"
)
func TestExponentialBackoff(t *testing.T) {
t.Run("doubles each attempt", func(t *testing.T) {
b := ExponentialBackoff(100*time.Millisecond, 10*time.Second, false)
want := []time.Duration{
100 * time.Millisecond, // attempt 0: base
200 * time.Millisecond, // attempt 1: base*2
400 * time.Millisecond, // attempt 2: base*4
800 * time.Millisecond, // attempt 3: base*8
1600 * time.Millisecond, // attempt 4: base*16
}
for i, expected := range want {
got := b.Delay(i)
if got != expected {
t.Errorf("attempt %d: expected %v, got %v", i, expected, got)
}
}
})
t.Run("caps at max", func(t *testing.T) {
b := ExponentialBackoff(100*time.Millisecond, 500*time.Millisecond, false)
// attempt 0: 100ms, 1: 200ms, 2: 400ms, 3: 500ms (capped), 4: 500ms
for _, attempt := range []int{3, 4, 10} {
got := b.Delay(attempt)
if got != 500*time.Millisecond {
t.Errorf("attempt %d: expected cap at 500ms, got %v", attempt, got)
}
}
})
t.Run("with jitter adds randomness", func(t *testing.T) {
base := 100 * time.Millisecond
b := ExponentialBackoff(base, 10*time.Second, true)
// Run multiple times; with jitter, delay >= base for attempt 0.
// Also verify not all values are identical (randomness).
seen := make(map[time.Duration]bool)
for range 20 {
d := b.Delay(0)
if d < base {
t.Fatalf("delay %v is less than base %v", d, base)
}
// With jitter: delay = base + rand in [0, base/2), so max is base*1.5
maxExpected := base + base/2
if d > maxExpected {
t.Fatalf("delay %v exceeds expected max %v", d, maxExpected)
}
seen[d] = true
}
if len(seen) < 2 {
t.Errorf("expected jitter to produce varying delays, got %d unique values", len(seen))
}
})
}
func TestConstantBackoff(t *testing.T) {
t.Run("always returns same value", func(t *testing.T) {
d := 250 * time.Millisecond
b := ConstantBackoff(d)
for _, attempt := range []int{0, 1, 2, 5, 100} {
got := b.Delay(attempt)
if got != d {
t.Errorf("attempt %d: expected %v, got %v", attempt, d, got)
}
}
})
}

31
retry/doc.go Normal file
View File

@@ -0,0 +1,31 @@
// Package retry provides configurable HTTP request retry as client middleware.
//
// The retry middleware wraps an http.RoundTripper and automatically retries
// failed requests based on a configurable policy, with exponential backoff
// and optional jitter.
//
// # Usage
//
// mw := retry.Transport(
// retry.WithMaxAttempts(3),
// retry.WithBackoff(retry.ExponentialBackoff(100*time.Millisecond, 5*time.Second)),
// )
// transport := mw(http.DefaultTransport)
//
// # Retry-After
//
// The retry middleware respects the Retry-After response header. If a server
// returns 429 or 503 with Retry-After, the delay from the header overrides
// the backoff strategy.
//
// # Request bodies
//
// For requests with bodies to be retried, the request must have GetBody set.
// Use httpx.NewJSONRequest or httpx.NewFormRequest which set GetBody
// automatically.
//
// # Sentinel errors
//
// ErrRetryExhausted is returned when all attempts fail. The original error
// is wrapped and accessible via errors.Unwrap.
package retry

56
retry/options.go Normal file
View File

@@ -0,0 +1,56 @@
package retry
import "time"
type options struct {
maxAttempts int // default 3
backoff Backoff // default ExponentialBackoff(100ms, 5s, true)
policy Policy // default: defaultPolicy (retry on 5xx and network errors)
retryAfter bool // default true, respect Retry-After header
}
// Option configures the retry transport.
type Option func(*options)
func defaults() options {
return options{
maxAttempts: 3,
backoff: ExponentialBackoff(100*time.Millisecond, 5*time.Second, true),
policy: defaultPolicy{},
retryAfter: true,
}
}
// WithMaxAttempts sets the maximum number of attempts (including the first).
// Values less than 1 are treated as 1 (no retries).
func WithMaxAttempts(n int) Option {
return func(o *options) {
if n < 1 {
n = 1
}
o.maxAttempts = n
}
}
// WithBackoff sets the backoff strategy used to compute delays between retries.
func WithBackoff(b Backoff) Option {
return func(o *options) {
o.backoff = b
}
}
// WithPolicy sets the retry policy that decides whether to retry a request.
func WithPolicy(p Policy) Option {
return func(o *options) {
o.policy = p
}
}
// WithRetryAfter controls whether the Retry-After response header is respected.
// When enabled and present, the Retry-After delay is used if it exceeds the
// backoff delay.
func WithRetryAfter(enable bool) Option {
return func(o *options) {
o.retryAfter = enable
}
}

137
retry/retry.go Normal file
View File

@@ -0,0 +1,137 @@
package retry
import (
"errors"
"fmt"
"io"
"net/http"
"time"
"git.codelab.vc/pkg/httpx/middleware"
)
// ErrRetryExhausted is returned when all retry attempts have been exhausted
// and the last attempt also failed.
var ErrRetryExhausted = errors.New("httpx: all retry attempts exhausted")
// Policy decides whether a failed request should be retried.
type Policy interface {
// ShouldRetry reports whether the request should be retried. The extra
// duration, if non-zero, is a policy-suggested delay that overrides the
// backoff strategy.
ShouldRetry(attempt int, req *http.Request, resp *http.Response, err error) (bool, time.Duration)
}
// Transport returns a middleware that retries failed requests according to
// the provided options.
func Transport(opts ...Option) middleware.Middleware {
cfg := defaults()
for _, o := range opts {
o(&cfg)
}
return func(next http.RoundTripper) http.RoundTripper {
return middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
var resp *http.Response
var err error
var exhausted bool
for attempt := range cfg.maxAttempts {
// For retries (attempt > 0), restore the request body.
if attempt > 0 {
if req.GetBody != nil {
body, bodyErr := req.GetBody()
if bodyErr != nil {
return resp, bodyErr
}
req.Body = body
} else if req.Body != nil {
// Body was consumed and cannot be re-created.
return resp, err
}
}
resp, err = next.RoundTrip(req)
// Last attempt — return whatever we got.
if attempt == cfg.maxAttempts-1 {
exhausted = true
break
}
shouldRetry, policyDelay := cfg.policy.ShouldRetry(attempt, req, resp, err)
if !shouldRetry {
break
}
// Compute delay: use backoff or policy delay, whichever is larger.
delay := cfg.backoff.Delay(attempt)
if policyDelay > delay {
delay = policyDelay
}
// Respect Retry-After header if enabled.
if cfg.retryAfter && resp != nil {
if ra, ok := ParseRetryAfter(resp); ok && ra > delay {
delay = ra
}
}
// Drain and close the response body to release the connection.
if resp != nil {
io.Copy(io.Discard, resp.Body)
resp.Body.Close()
}
// Wait for the delay or context cancellation.
timer := time.NewTimer(delay)
select {
case <-req.Context().Done():
timer.Stop()
return nil, req.Context().Err()
case <-timer.C:
}
}
// Wrap with ErrRetryExhausted only when all attempts were used.
if exhausted && err != nil {
err = fmt.Errorf("%w: %w", ErrRetryExhausted, err)
}
return resp, err
})
}
}
// defaultPolicy retries on network errors, 429, and 5xx server errors.
// It refuses to retry non-idempotent methods.
type defaultPolicy struct{}
func (defaultPolicy) ShouldRetry(_ int, req *http.Request, resp *http.Response, err error) (bool, time.Duration) {
if !isIdempotent(req.Method) {
return false, 0
}
// Network error — always retry idempotent requests.
if err != nil {
return true, 0
}
switch resp.StatusCode {
case http.StatusTooManyRequests, // 429
http.StatusBadGateway, // 502
http.StatusServiceUnavailable, // 503
http.StatusGatewayTimeout: // 504
return true, 0
}
return false, 0
}
// isIdempotent reports whether the HTTP method is safe to retry.
func isIdempotent(method string) bool {
switch method {
case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodPut:
return true
}
return false
}

43
retry/retry_after.go Normal file
View File

@@ -0,0 +1,43 @@
package retry
import (
"net/http"
"strconv"
"time"
)
// ParseRetryAfter extracts the delay from a Retry-After header (RFC 7231).
// It supports both the delay-seconds format ("120") and the HTTP-date format
// ("Fri, 31 Dec 1999 23:59:59 GMT"). Returns the duration and true if the
// header was present and valid; otherwise returns 0 and false.
func ParseRetryAfter(resp *http.Response) (time.Duration, bool) {
if resp == nil {
return 0, false
}
val := resp.Header.Get("Retry-After")
if val == "" {
return 0, false
}
// Try delay-seconds first (most common).
if seconds, err := strconv.ParseInt(val, 10, 64); err == nil {
if seconds < 0 {
return 0, false
}
return time.Duration(seconds) * time.Second, true
}
// Try HTTP-date format (RFC 7231 section 7.1.1.1).
t, err := http.ParseTime(val)
if err != nil {
return 0, false
}
delay := time.Until(t)
if delay < 0 {
// The date is in the past; no need to wait.
return 0, true
}
return delay, true
}

58
retry/retry_after_test.go Normal file
View File

@@ -0,0 +1,58 @@
package retry
import (
"net/http"
"testing"
"time"
)
func TestParseRetryAfter(t *testing.T) {
t.Run("seconds format", func(t *testing.T) {
resp := &http.Response{
Header: http.Header{"Retry-After": []string{"120"}},
}
d, ok := ParseRetryAfter(resp)
if !ok {
t.Fatal("expected ok=true")
}
if d != 120*time.Second {
t.Fatalf("expected 120s, got %v", d)
}
})
t.Run("empty header", func(t *testing.T) {
resp := &http.Response{
Header: make(http.Header),
}
d, ok := ParseRetryAfter(resp)
if ok {
t.Fatal("expected ok=false for empty header")
}
if d != 0 {
t.Fatalf("expected 0, got %v", d)
}
})
t.Run("nil response", func(t *testing.T) {
d, ok := ParseRetryAfter(nil)
if ok {
t.Fatal("expected ok=false for nil response")
}
if d != 0 {
t.Fatalf("expected 0, got %v", d)
}
})
t.Run("negative value", func(t *testing.T) {
resp := &http.Response{
Header: http.Header{"Retry-After": []string{"-5"}},
}
d, ok := ParseRetryAfter(resp)
if ok {
t.Fatal("expected ok=false for negative value")
}
if d != 0 {
t.Fatalf("expected 0, got %v", d)
}
})
}

237
retry/retry_test.go Normal file
View File

@@ -0,0 +1,237 @@
package retry
import (
"bytes"
"context"
"io"
"net/http"
"strings"
"sync/atomic"
"testing"
"time"
"git.codelab.vc/pkg/httpx/middleware"
)
func mockTransport(fn func(*http.Request) (*http.Response, error)) http.RoundTripper {
return middleware.RoundTripperFunc(fn)
}
func okResponse() *http.Response {
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader("")),
Header: make(http.Header),
}
}
func statusResponse(code int) *http.Response {
return &http.Response{
StatusCode: code,
Body: io.NopCloser(strings.NewReader("")),
Header: make(http.Header),
}
}
func TestTransport(t *testing.T) {
t.Run("successful request no retry", func(t *testing.T) {
var calls atomic.Int32
rt := Transport(
WithMaxAttempts(3),
WithBackoff(ConstantBackoff(time.Millisecond)),
)(mockTransport(func(req *http.Request) (*http.Response, error) {
calls.Add(1)
return okResponse(), nil
}))
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
resp, err := rt.RoundTrip(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected 200, got %d", resp.StatusCode)
}
if got := calls.Load(); got != 1 {
t.Fatalf("expected 1 call, got %d", got)
}
})
t.Run("retries on 503 then succeeds", func(t *testing.T) {
var calls atomic.Int32
rt := Transport(
WithMaxAttempts(3),
WithBackoff(ConstantBackoff(time.Millisecond)),
)(mockTransport(func(req *http.Request) (*http.Response, error) {
n := calls.Add(1)
if n < 3 {
return statusResponse(http.StatusServiceUnavailable), nil
}
return okResponse(), nil
}))
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
resp, err := rt.RoundTrip(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected 200, got %d", resp.StatusCode)
}
if got := calls.Load(); got != 3 {
t.Fatalf("expected 3 calls, got %d", got)
}
})
t.Run("does not retry non-idempotent POST", func(t *testing.T) {
var calls atomic.Int32
rt := Transport(
WithMaxAttempts(3),
WithBackoff(ConstantBackoff(time.Millisecond)),
)(mockTransport(func(req *http.Request) (*http.Response, error) {
calls.Add(1)
return statusResponse(http.StatusServiceUnavailable), nil
}))
req, _ := http.NewRequest(http.MethodPost, "http://example.com", strings.NewReader("data"))
resp, err := rt.RoundTrip(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.StatusCode != http.StatusServiceUnavailable {
t.Fatalf("expected 503, got %d", resp.StatusCode)
}
if got := calls.Load(); got != 1 {
t.Fatalf("expected 1 call (no retry for POST), got %d", got)
}
})
t.Run("stops on context cancellation", func(t *testing.T) {
var calls atomic.Int32
ctx, cancel := context.WithCancel(context.Background())
rt := Transport(
WithMaxAttempts(5),
WithBackoff(ConstantBackoff(50*time.Millisecond)),
)(mockTransport(func(req *http.Request) (*http.Response, error) {
n := calls.Add(1)
if n == 1 {
cancel()
}
return statusResponse(http.StatusServiceUnavailable), nil
}))
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "http://example.com", nil)
resp, err := rt.RoundTrip(req)
if err != context.Canceled {
t.Fatalf("expected context.Canceled, got resp=%v err=%v", resp, err)
}
})
t.Run("respects maxAttempts", func(t *testing.T) {
var calls atomic.Int32
rt := Transport(
WithMaxAttempts(2),
WithBackoff(ConstantBackoff(time.Millisecond)),
)(mockTransport(func(req *http.Request) (*http.Response, error) {
calls.Add(1)
return statusResponse(http.StatusBadGateway), nil
}))
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
resp, err := rt.RoundTrip(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.StatusCode != http.StatusBadGateway {
t.Fatalf("expected 502, got %d", resp.StatusCode)
}
if got := calls.Load(); got != 2 {
t.Fatalf("expected 2 calls (maxAttempts=2), got %d", got)
}
})
t.Run("body is restored via GetBody on retry", func(t *testing.T) {
var calls atomic.Int32
var bodies []string
rt := Transport(
WithMaxAttempts(3),
WithBackoff(ConstantBackoff(time.Millisecond)),
)(mockTransport(func(req *http.Request) (*http.Response, error) {
calls.Add(1)
b, _ := io.ReadAll(req.Body)
bodies = append(bodies, string(b))
if len(bodies) < 2 {
return statusResponse(http.StatusServiceUnavailable), nil
}
return okResponse(), nil
}))
bodyContent := "request-body"
body := bytes.NewReader([]byte(bodyContent))
req, _ := http.NewRequest(http.MethodPut, "http://example.com", body)
req.GetBody = func() (io.ReadCloser, error) {
return io.NopCloser(bytes.NewReader([]byte(bodyContent))), nil
}
resp, err := rt.RoundTrip(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected 200, got %d", resp.StatusCode)
}
if got := calls.Load(); got != 2 {
t.Fatalf("expected 2 calls, got %d", got)
}
for i, b := range bodies {
if b != bodyContent {
t.Fatalf("attempt %d: expected body %q, got %q", i, bodyContent, b)
}
}
})
t.Run("custom policy", func(t *testing.T) {
var calls atomic.Int32
// Custom policy: retry only on 418
custom := policyFunc(func(attempt int, req *http.Request, resp *http.Response, err error) (bool, time.Duration) {
if resp != nil && resp.StatusCode == http.StatusTeapot {
return true, 0
}
return false, 0
})
rt := Transport(
WithMaxAttempts(3),
WithBackoff(ConstantBackoff(time.Millisecond)),
WithPolicy(custom),
)(mockTransport(func(req *http.Request) (*http.Response, error) {
n := calls.Add(1)
if n == 1 {
return statusResponse(http.StatusTeapot), nil
}
return okResponse(), nil
}))
req, _ := http.NewRequest(http.MethodPost, "http://example.com", nil)
resp, err := rt.RoundTrip(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected 200, got %d", resp.StatusCode)
}
if got := calls.Load(); got != 2 {
t.Fatalf("expected 2 calls, got %d", got)
}
})
}
// policyFunc adapts a function into a Policy.
type policyFunc func(int, *http.Request, *http.Response, error) (bool, time.Duration)
func (f policyFunc) ShouldRetry(attempt int, req *http.Request, resp *http.Response, err error) (bool, time.Duration) {
return f(attempt, req, resp, err)
}

38
server/doc.go Normal file
View File

@@ -0,0 +1,38 @@
// Package server provides a production-ready HTTP server with graceful
// shutdown, middleware composition, routing, and JSON response helpers.
//
// # Server
//
// Server wraps http.Server with net.Listener, signal-based graceful shutdown
// (SIGINT/SIGTERM), and lifecycle hooks. It is configured via functional options:
//
// srv := server.New(handler,
// server.WithAddr(":8080"),
// server.Defaults(logger),
// )
// srv.ListenAndServe()
//
// # Router
//
// Router wraps http.ServeMux with middleware groups, prefix-based route groups,
// and sub-handler mounting. It supports Go 1.22+ method-based patterns:
//
// r := server.NewRouter()
// r.HandleFunc("GET /users/{id}", getUser)
//
// api := r.Group("/api/v1", authMiddleware)
// api.HandleFunc("GET /items", listItems)
//
// # Middleware
//
// Server middleware follows the func(http.Handler) http.Handler pattern.
// Available middleware: RequestID, Recovery, Logging, CORS, RateLimit,
// MaxBodySize, Timeout. Use Chain to compose them:
//
// chain := server.Chain(server.RequestID(), server.Recovery(logger), server.Logging(logger))
//
// # Response helpers
//
// WriteJSON and WriteError provide JSON response writing with proper
// Content-Type headers.
package server

55
server/health.go Normal file
View File

@@ -0,0 +1,55 @@
package server
import (
"encoding/json"
"net/http"
)
// ReadinessChecker is a function that reports whether a dependency is ready.
// Return nil if healthy, or an error describing the problem.
type ReadinessChecker func() error
// HealthHandler returns an http.Handler that exposes liveness and readiness
// endpoints:
//
// - GET /healthz — liveness check, always returns 200 OK
// - GET /readyz — readiness check, returns 200 if all checkers pass, 503 otherwise
func HealthHandler(checkers ...ReadinessChecker) http.Handler {
mux := http.NewServeMux()
mux.HandleFunc("GET /healthz", func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(healthResponse{Status: "ok"})
})
mux.HandleFunc("GET /readyz", func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")
var errs []string
for _, check := range checkers {
if err := check(); err != nil {
errs = append(errs, err.Error())
}
}
if len(errs) > 0 {
w.WriteHeader(http.StatusServiceUnavailable)
_ = json.NewEncoder(w).Encode(healthResponse{
Status: "unavailable",
Errors: errs,
})
return
}
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(healthResponse{Status: "ok"})
})
return mux
}
type healthResponse struct {
Status string `json:"status"`
Errors []string `json:"errors,omitempty"`
}

166
server/health_test.go Normal file
View File

@@ -0,0 +1,166 @@
package server_test
import (
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"testing"
"git.codelab.vc/pkg/httpx/server"
)
func TestHealthHandler(t *testing.T) {
t.Run("liveness always returns 200", func(t *testing.T) {
h := server.HealthHandler()
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/healthz", nil)
h.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
}
var resp map[string]any
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatalf("decode failed: %v", err)
}
if resp["status"] != "ok" {
t.Fatalf("got status %q, want %q", resp["status"], "ok")
}
})
t.Run("readiness returns 200 when all checks pass", func(t *testing.T) {
h := server.HealthHandler(
func() error { return nil },
func() error { return nil },
)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/readyz", nil)
h.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
}
})
t.Run("readiness returns 503 when a check fails", func(t *testing.T) {
h := server.HealthHandler(
func() error { return nil },
func() error { return errors.New("db down") },
)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/readyz", nil)
h.ServeHTTP(w, req)
if w.Code != http.StatusServiceUnavailable {
t.Fatalf("got status %d, want %d", w.Code, http.StatusServiceUnavailable)
}
var resp map[string]any
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatalf("decode failed: %v", err)
}
if resp["status"] != "unavailable" {
t.Fatalf("got status %q, want %q", resp["status"], "unavailable")
}
errs, ok := resp["errors"].([]any)
if !ok || len(errs) != 1 {
t.Fatalf("expected 1 error, got %v", resp["errors"])
}
if errs[0] != "db down" {
t.Fatalf("got error %q, want %q", errs[0], "db down")
}
})
t.Run("readiness returns 200 with no checkers", func(t *testing.T) {
h := server.HealthHandler()
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/readyz", nil)
h.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
}
})
}
func TestHealth_MultipleFailingCheckers(t *testing.T) {
h := server.HealthHandler(
func() error { return errors.New("db down") },
func() error { return errors.New("cache down") },
)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/readyz", nil)
h.ServeHTTP(w, req)
if w.Code != http.StatusServiceUnavailable {
t.Fatalf("got status %d, want %d", w.Code, http.StatusServiceUnavailable)
}
var resp map[string]any
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatalf("decode failed: %v", err)
}
errs, ok := resp["errors"].([]any)
if !ok || len(errs) != 2 {
t.Fatalf("expected 2 errors, got %v", resp["errors"])
}
errStrs := make(map[string]bool)
for _, e := range errs {
errStrs[e.(string)] = true
}
if !errStrs["db down"] || !errStrs["cache down"] {
t.Fatalf("expected 'db down' and 'cache down', got %v", errs)
}
}
func TestHealth_LivenessContentType(t *testing.T) {
h := server.HealthHandler()
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/healthz", nil)
h.ServeHTTP(w, req)
ct := w.Header().Get("Content-Type")
if ct != "application/json" {
t.Fatalf("got Content-Type %q, want %q", ct, "application/json")
}
}
func TestHealth_ReadinessContentType(t *testing.T) {
h := server.HealthHandler()
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/readyz", nil)
h.ServeHTTP(w, req)
ct := w.Header().Get("Content-Type")
if ct != "application/json" {
t.Fatalf("got Content-Type %q, want %q", ct, "application/json")
}
}
func TestHealth_PostMethodNotAllowed(t *testing.T) {
h := server.HealthHandler()
for _, path := range []string{"/healthz", "/readyz"} {
t.Run("POST "+path, func(t *testing.T) {
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, path, nil)
h.ServeHTTP(w, req)
// ServeMux with "GET /healthz" pattern should reject POST.
if w.Code == http.StatusOK {
t.Fatalf("POST %s should not return 200, got %d", path, w.Code)
}
})
}
}

56
server/middleware.go Normal file
View File

@@ -0,0 +1,56 @@
package server
import "net/http"
// Middleware wraps an http.Handler to add behavior.
// This is the server-side counterpart of the client middleware type
// func(http.RoundTripper) http.RoundTripper.
type Middleware func(http.Handler) http.Handler
// Chain composes middlewares so that Chain(A, B, C)(handler) == A(B(C(handler))).
// Middlewares are applied from right to left: C wraps handler first, then B wraps
// the result, then A wraps last. This means A is the outermost layer and sees
// every request first.
func Chain(mws ...Middleware) Middleware {
return func(h http.Handler) http.Handler {
for i := len(mws) - 1; i >= 0; i-- {
h = mws[i](h)
}
return h
}
}
// statusWriter wraps http.ResponseWriter to capture the response status code.
// It implements Unwrap() so that http.ResponseController can access the
// underlying ResponseWriter's optional interfaces (Flusher, Hijacker, etc.).
type statusWriter struct {
http.ResponseWriter
status int
written bool
}
// WriteHeader captures the status code and delegates to the underlying writer.
func (w *statusWriter) WriteHeader(code int) {
if !w.written {
w.status = code
w.written = true
}
w.ResponseWriter.WriteHeader(code)
}
// Write delegates to the underlying writer, defaulting status to 200 if
// WriteHeader was not called explicitly.
func (w *statusWriter) Write(b []byte) (int, error) {
if !w.written {
w.status = http.StatusOK
w.written = true
}
return w.ResponseWriter.Write(b)
}
// Unwrap returns the underlying ResponseWriter. This is required for
// http.ResponseController to detect optional interfaces like http.Flusher
// and http.Hijacker on the original writer.
func (w *statusWriter) Unwrap() http.ResponseWriter {
return w.ResponseWriter
}

View File

@@ -0,0 +1,15 @@
package server
import "net/http"
// MaxBodySize returns a middleware that limits the size of incoming request
// bodies. If the body exceeds n bytes, the server returns 413 Request Entity
// Too Large. It wraps the body with http.MaxBytesReader.
func MaxBodySize(n int64) Middleware {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, n)
next.ServeHTTP(w, r)
})
}
}

View File

@@ -0,0 +1,61 @@
package server_test
import (
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"git.codelab.vc/pkg/httpx/server"
)
func TestMaxBodySize(t *testing.T) {
t.Run("allows body within limit", func(t *testing.T) {
handler := server.MaxBodySize(1024)(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
b, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
w.WriteHeader(http.StatusOK)
w.Write(b)
}),
)
body := strings.NewReader("hello")
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/", body)
handler.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
}
if w.Body.String() != "hello" {
t.Fatalf("got body %q, want %q", w.Body.String(), "hello")
}
})
t.Run("rejects body exceeding limit", func(t *testing.T) {
handler := server.MaxBodySize(5)(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, "body too large", http.StatusRequestEntityTooLarge)
return
}
w.WriteHeader(http.StatusOK)
}),
)
body := strings.NewReader("this is longer than 5 bytes")
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/", body)
handler.ServeHTTP(w, req)
if w.Code != http.StatusRequestEntityTooLarge {
t.Fatalf("got status %d, want %d", w.Code, http.StatusRequestEntityTooLarge)
}
})
}

128
server/middleware_cors.go Normal file
View File

@@ -0,0 +1,128 @@
package server
import (
"net/http"
"strconv"
"strings"
)
type corsOptions struct {
allowOrigins []string
allowMethods []string
allowHeaders []string
exposeHeaders []string
allowCredentials bool
maxAge int
}
// CORSOption configures the CORS middleware.
type CORSOption func(*corsOptions)
// AllowOrigins sets the allowed origins. Use "*" to allow any origin.
// Default is no origins (CORS disabled).
func AllowOrigins(origins ...string) CORSOption {
return func(o *corsOptions) { o.allowOrigins = origins }
}
// AllowMethods sets the allowed HTTP methods for preflight requests.
// Default is GET, POST, HEAD.
func AllowMethods(methods ...string) CORSOption {
return func(o *corsOptions) { o.allowMethods = methods }
}
// AllowHeaders sets the allowed request headers for preflight requests.
func AllowHeaders(headers ...string) CORSOption {
return func(o *corsOptions) { o.allowHeaders = headers }
}
// ExposeHeaders sets headers that browsers are allowed to access.
func ExposeHeaders(headers ...string) CORSOption {
return func(o *corsOptions) { o.exposeHeaders = headers }
}
// AllowCredentials indicates whether the response to the request can be
// exposed when the credentials flag is true.
func AllowCredentials(allow bool) CORSOption {
return func(o *corsOptions) { o.allowCredentials = allow }
}
// MaxAge sets the maximum time (in seconds) a preflight result can be cached.
func MaxAge(seconds int) CORSOption {
return func(o *corsOptions) { o.maxAge = seconds }
}
// CORS returns a middleware that handles Cross-Origin Resource Sharing.
// It processes preflight OPTIONS requests and sets the appropriate
// Access-Control-* response headers.
func CORS(opts ...CORSOption) Middleware {
o := &corsOptions{
allowMethods: []string{"GET", "POST", "HEAD"},
}
for _, opt := range opts {
opt(o)
}
allowedOrigins := make(map[string]struct{}, len(o.allowOrigins))
allowAll := false
for _, origin := range o.allowOrigins {
if origin == "*" {
allowAll = true
}
allowedOrigins[origin] = struct{}{}
}
methods := strings.Join(o.allowMethods, ", ")
headers := strings.Join(o.allowHeaders, ", ")
expose := strings.Join(o.exposeHeaders, ", ")
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
origin := r.Header.Get("Origin")
if origin == "" {
next.ServeHTTP(w, r)
return
}
allowed := allowAll
if !allowed {
_, allowed = allowedOrigins[origin]
}
if !allowed {
next.ServeHTTP(w, r)
return
}
// Set the allowed origin. When credentials are enabled,
// we must echo the specific origin, not "*".
if allowAll && !o.allowCredentials {
w.Header().Set("Access-Control-Allow-Origin", "*")
} else {
w.Header().Set("Access-Control-Allow-Origin", origin)
}
if o.allowCredentials {
w.Header().Set("Access-Control-Allow-Credentials", "true")
}
if expose != "" {
w.Header().Set("Access-Control-Expose-Headers", expose)
}
w.Header().Add("Vary", "Origin")
// Handle preflight.
if r.Method == http.MethodOptions && r.Header.Get("Access-Control-Request-Method") != "" {
w.Header().Set("Access-Control-Allow-Methods", methods)
if headers != "" {
w.Header().Set("Access-Control-Allow-Headers", headers)
}
if o.maxAge > 0 {
w.Header().Set("Access-Control-Max-Age", strconv.Itoa(o.maxAge))
}
w.WriteHeader(http.StatusNoContent)
return
}
next.ServeHTTP(w, r)
})
}
}

View File

@@ -0,0 +1,143 @@
package server_test
import (
"net/http"
"net/http/httptest"
"testing"
"git.codelab.vc/pkg/httpx/server"
)
func TestCORS(t *testing.T) {
okHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
})
t.Run("no Origin header passes through", func(t *testing.T) {
mw := server.CORS(server.AllowOrigins("*"))(okHandler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
mw.ServeHTTP(w, req)
if w.Header().Get("Access-Control-Allow-Origin") != "" {
t.Fatal("expected no CORS headers without Origin")
}
})
t.Run("wildcard origin", func(t *testing.T) {
mw := server.CORS(server.AllowOrigins("*"))(okHandler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("Origin", "http://example.com")
mw.ServeHTTP(w, req)
if got := w.Header().Get("Access-Control-Allow-Origin"); got != "*" {
t.Fatalf("Access-Control-Allow-Origin = %q, want %q", got, "*")
}
})
t.Run("specific origin allowed", func(t *testing.T) {
mw := server.CORS(server.AllowOrigins("http://example.com"))(okHandler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("Origin", "http://example.com")
mw.ServeHTTP(w, req)
if got := w.Header().Get("Access-Control-Allow-Origin"); got != "http://example.com" {
t.Fatalf("Access-Control-Allow-Origin = %q, want %q", got, "http://example.com")
}
})
t.Run("disallowed origin gets no CORS headers", func(t *testing.T) {
mw := server.CORS(server.AllowOrigins("http://example.com"))(okHandler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("Origin", "http://evil.com")
mw.ServeHTTP(w, req)
if got := w.Header().Get("Access-Control-Allow-Origin"); got != "" {
t.Fatalf("expected no Access-Control-Allow-Origin for disallowed origin, got %q", got)
}
})
t.Run("preflight OPTIONS", func(t *testing.T) {
mw := server.CORS(
server.AllowOrigins("http://example.com"),
server.AllowMethods("GET", "POST", "PUT"),
server.AllowHeaders("Authorization", "Content-Type"),
server.MaxAge(3600),
)(okHandler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodOptions, "/api/data", nil)
req.Header.Set("Origin", "http://example.com")
req.Header.Set("Access-Control-Request-Method", "POST")
mw.ServeHTTP(w, req)
if w.Code != http.StatusNoContent {
t.Fatalf("got status %d, want %d", w.Code, http.StatusNoContent)
}
if got := w.Header().Get("Access-Control-Allow-Methods"); got != "GET, POST, PUT" {
t.Fatalf("Access-Control-Allow-Methods = %q, want %q", got, "GET, POST, PUT")
}
if got := w.Header().Get("Access-Control-Allow-Headers"); got != "Authorization, Content-Type" {
t.Fatalf("Access-Control-Allow-Headers = %q, want %q", got, "Authorization, Content-Type")
}
if got := w.Header().Get("Access-Control-Max-Age"); got != "3600" {
t.Fatalf("Access-Control-Max-Age = %q, want %q", got, "3600")
}
})
t.Run("credentials with specific origin", func(t *testing.T) {
mw := server.CORS(
server.AllowOrigins("*"),
server.AllowCredentials(true),
)(okHandler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("Origin", "http://example.com")
mw.ServeHTTP(w, req)
// With credentials, must echo specific origin even with wildcard config.
if got := w.Header().Get("Access-Control-Allow-Origin"); got != "http://example.com" {
t.Fatalf("Access-Control-Allow-Origin = %q, want %q", got, "http://example.com")
}
if got := w.Header().Get("Access-Control-Allow-Credentials"); got != "true" {
t.Fatalf("Access-Control-Allow-Credentials = %q, want %q", got, "true")
}
})
t.Run("expose headers", func(t *testing.T) {
mw := server.CORS(
server.AllowOrigins("*"),
server.ExposeHeaders("X-Custom", "X-Request-Id"),
)(okHandler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("Origin", "http://example.com")
mw.ServeHTTP(w, req)
if got := w.Header().Get("Access-Control-Expose-Headers"); got != "X-Custom, X-Request-Id" {
t.Fatalf("Access-Control-Expose-Headers = %q, want %q", got, "X-Custom, X-Request-Id")
}
})
t.Run("Vary header is set", func(t *testing.T) {
mw := server.CORS(server.AllowOrigins("*"))(okHandler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("Origin", "http://example.com")
mw.ServeHTTP(w, req)
if got := w.Header().Get("Vary"); got != "Origin" {
t.Fatalf("Vary = %q, want %q", got, "Origin")
}
})
}

View File

@@ -0,0 +1,39 @@
package server
import (
"log/slog"
"net/http"
"time"
)
// Logging returns a middleware that logs each request's method, path,
// status code, duration, and request ID using the provided structured logger.
func Logging(logger *slog.Logger) Middleware {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
sw := &statusWriter{ResponseWriter: w, status: http.StatusOK}
next.ServeHTTP(sw, r)
duration := time.Since(start)
attrs := []slog.Attr{
slog.String("method", r.Method),
slog.String("path", r.URL.Path),
slog.Int("status", sw.status),
slog.Duration("duration", duration),
}
if id := RequestIDFromContext(r.Context()); id != "" {
attrs = append(attrs, slog.String("request_id", id))
}
level := slog.LevelInfo
if sw.status >= http.StatusInternalServerError {
level = slog.LevelError
}
logger.LogAttrs(r.Context(), level, "request completed", attrs...)
})
}
}

View File

@@ -0,0 +1,129 @@
package server
import (
"net"
"net/http"
"strconv"
"sync"
"time"
"git.codelab.vc/pkg/httpx/internal/clock"
)
type rateLimitOptions struct {
rate float64
burst int
keyFunc func(r *http.Request) string
clock clock.Clock
}
// RateLimitOption configures the RateLimit middleware.
type RateLimitOption func(*rateLimitOptions)
// WithRate sets the token refill rate (tokens per second).
func WithRate(tokensPerSecond float64) RateLimitOption {
return func(o *rateLimitOptions) { o.rate = tokensPerSecond }
}
// WithBurst sets the maximum burst size (bucket capacity).
func WithBurst(n int) RateLimitOption {
return func(o *rateLimitOptions) { o.burst = n }
}
// WithKeyFunc sets a custom function to extract the rate-limit key from a
// request. By default, the client IP address is used.
func WithKeyFunc(fn func(r *http.Request) string) RateLimitOption {
return func(o *rateLimitOptions) { o.keyFunc = fn }
}
// withRateLimitClock sets the clock for testing. Not exported.
func withRateLimitClock(c clock.Clock) RateLimitOption {
return func(o *rateLimitOptions) { o.clock = c }
}
// RateLimit returns a middleware that limits requests using a per-key token
// bucket algorithm. When the limit is exceeded, it returns 429 Too Many
// Requests with a Retry-After header.
func RateLimit(opts ...RateLimitOption) Middleware {
o := &rateLimitOptions{
rate: 10,
burst: 20,
clock: clock.System(),
}
for _, opt := range opts {
opt(o)
}
if o.keyFunc == nil {
o.keyFunc = clientIP
}
var buckets sync.Map
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
key := o.keyFunc(r)
val, _ := buckets.LoadOrStore(key, &bucket{
tokens: float64(o.burst),
lastTime: o.clock.Now(),
})
b := val.(*bucket)
b.mu.Lock()
now := o.clock.Now()
elapsed := now.Sub(b.lastTime).Seconds()
b.tokens += elapsed * o.rate
if b.tokens > float64(o.burst) {
b.tokens = float64(o.burst)
}
b.lastTime = now
if b.tokens < 1 {
retryAfter := (1 - b.tokens) / o.rate
b.mu.Unlock()
w.Header().Set("Retry-After", strconv.Itoa(int(retryAfter)+1))
http.Error(w, "Too Many Requests", http.StatusTooManyRequests)
return
}
b.tokens--
b.mu.Unlock()
next.ServeHTTP(w, r)
})
}
}
type bucket struct {
mu sync.Mutex
tokens float64
lastTime time.Time
}
// clientIP extracts the client IP from the request. It checks
// X-Forwarded-For first, then X-Real-Ip, and falls back to RemoteAddr.
func clientIP(r *http.Request) string {
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
// First IP in the comma-separated list is the original client.
if i := indexOf(xff, ','); i > 0 {
return xff[:i]
}
return xff
}
if xri := r.Header.Get("X-Real-Ip"); xri != "" {
return xri
}
host, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
return r.RemoteAddr
}
return host
}
func indexOf(s string, b byte) int {
for i := range len(s) {
if s[i] == b {
return i
}
}
return -1
}

View File

@@ -0,0 +1,171 @@
package server_test
import (
"net/http"
"net/http/httptest"
"testing"
"time"
"git.codelab.vc/pkg/httpx/server"
)
func TestRateLimit(t *testing.T) {
okHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
})
t.Run("allows requests within limit", func(t *testing.T) {
mw := server.RateLimit(
server.WithRate(100),
server.WithBurst(10),
)(okHandler)
for i := range 10 {
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.RemoteAddr = "1.2.3.4:1234"
mw.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("request %d: got status %d, want %d", i, w.Code, http.StatusOK)
}
}
})
t.Run("rejects when burst exhausted", func(t *testing.T) {
mw := server.RateLimit(
server.WithRate(1),
server.WithBurst(2),
)(okHandler)
// Exhaust burst.
for range 2 {
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.RemoteAddr = "1.2.3.4:1234"
mw.ServeHTTP(w, req)
}
// Next request should be rejected.
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.RemoteAddr = "1.2.3.4:1234"
mw.ServeHTTP(w, req)
if w.Code != http.StatusTooManyRequests {
t.Fatalf("got status %d, want %d", w.Code, http.StatusTooManyRequests)
}
if w.Header().Get("Retry-After") == "" {
t.Fatal("expected Retry-After header")
}
})
t.Run("different IPs have independent limits", func(t *testing.T) {
mw := server.RateLimit(
server.WithRate(1),
server.WithBurst(1),
)(okHandler)
// First IP exhausts its limit.
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.RemoteAddr = "1.2.3.4:1234"
mw.ServeHTTP(w, req)
// Second IP should still be allowed.
w = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/", nil)
req.RemoteAddr = "5.6.7.8:5678"
mw.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
}
})
t.Run("uses X-Forwarded-For", func(t *testing.T) {
mw := server.RateLimit(
server.WithRate(1),
server.WithBurst(1),
)(okHandler)
// Exhaust limit for 10.0.0.1.
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("X-Forwarded-For", "10.0.0.1, 192.168.1.1")
req.RemoteAddr = "192.168.1.1:1234"
mw.ServeHTTP(w, req)
// Same forwarded IP should be rate limited.
w = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("X-Forwarded-For", "10.0.0.1, 192.168.1.1")
req.RemoteAddr = "192.168.1.1:1234"
mw.ServeHTTP(w, req)
if w.Code != http.StatusTooManyRequests {
t.Fatalf("got status %d, want %d", w.Code, http.StatusTooManyRequests)
}
})
t.Run("custom key function", func(t *testing.T) {
mw := server.RateLimit(
server.WithRate(1),
server.WithBurst(1),
server.WithKeyFunc(func(r *http.Request) string {
return r.Header.Get("X-API-Key")
}),
)(okHandler)
// Exhaust key "abc".
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("X-API-Key", "abc")
mw.ServeHTTP(w, req)
// Same key should be rate limited.
w = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("X-API-Key", "abc")
mw.ServeHTTP(w, req)
if w.Code != http.StatusTooManyRequests {
t.Fatalf("got status %d, want %d", w.Code, http.StatusTooManyRequests)
}
// Different key should be allowed.
w = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("X-API-Key", "xyz")
mw.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
}
})
t.Run("tokens refill over time", func(t *testing.T) {
mw := server.RateLimit(
server.WithRate(1000), // Very fast refill for test
server.WithBurst(1),
)(okHandler)
// Exhaust burst.
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.RemoteAddr = "1.2.3.4:1234"
mw.ServeHTTP(w, req)
// Wait a bit for refill.
time.Sleep(5 * time.Millisecond)
w = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/", nil)
req.RemoteAddr = "1.2.3.4:1234"
mw.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("got status %d after refill, want %d", w.Code, http.StatusOK)
}
})
}

View File

@@ -0,0 +1,50 @@
package server
import (
"log/slog"
"net/http"
"runtime/debug"
)
// RecoveryOption configures the Recovery middleware.
type RecoveryOption func(*recoveryOptions)
type recoveryOptions struct {
logger *slog.Logger
}
// WithRecoveryLogger sets the logger for the Recovery middleware.
// If not set, panics are recovered silently (500 is still returned).
func WithRecoveryLogger(l *slog.Logger) RecoveryOption {
return func(o *recoveryOptions) { o.logger = l }
}
// Recovery returns a middleware that recovers from panics in downstream
// handlers. A recovered panic results in a 500 Internal Server Error
// response and is logged (if a logger is configured) with the stack trace.
func Recovery(opts ...RecoveryOption) Middleware {
o := &recoveryOptions{}
for _, opt := range opts {
opt(o)
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer func() {
if v := recover(); v != nil {
if o.logger != nil {
o.logger.LogAttrs(r.Context(), slog.LevelError, "panic recovered",
slog.Any("panic", v),
slog.String("stack", string(debug.Stack())),
slog.String("method", r.Method),
slog.String("path", r.URL.Path),
slog.String("request_id", RequestIDFromContext(r.Context())),
)
}
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
}
}()
next.ServeHTTP(w, r)
})
}
}

View File

@@ -0,0 +1,51 @@
package server
import (
"context"
"crypto/rand"
"fmt"
"net/http"
"git.codelab.vc/pkg/httpx/internal/requestid"
)
// RequestID returns a middleware that assigns a unique request ID to each
// request. If the incoming request already has an X-Request-Id header, that
// value is used. Otherwise a new UUID v4 is generated via crypto/rand.
//
// The request ID is stored in the request context (retrieve with
// RequestIDFromContext) and set on the response X-Request-Id header.
func RequestID() Middleware {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
id := r.Header.Get("X-Request-Id")
if id == "" {
id = newUUID()
}
ctx := requestid.NewContext(r.Context(), id)
w.Header().Set("X-Request-Id", id)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
// RequestIDFromContext returns the request ID from the context, or an empty
// string if none is set.
func RequestIDFromContext(ctx context.Context) string {
return requestid.FromContext(ctx)
}
// newUUID generates a UUID v4 string using crypto/rand.
func newUUID() string {
var uuid [16]byte
_, _ = rand.Read(uuid[:])
// Set version 4 (bits 12-15 of time_hi_and_version).
uuid[6] = (uuid[6] & 0x0f) | 0x40
// Set variant bits (10xx).
uuid[8] = (uuid[8] & 0x3f) | 0x80
return fmt.Sprintf("%x-%x-%x-%x-%x",
uuid[0:4], uuid[4:6], uuid[6:8], uuid[8:10], uuid[10:16])
}

508
server/middleware_test.go Normal file
View File

@@ -0,0 +1,508 @@
package server_test
import (
"bytes"
"log/slog"
"net/http"
"net/http/httptest"
"regexp"
"strings"
"testing"
"git.codelab.vc/pkg/httpx/server"
)
func TestChain(t *testing.T) {
t.Run("applies middlewares in correct order", func(t *testing.T) {
var order []string
mwA := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
order = append(order, "A-before")
next.ServeHTTP(w, r)
order = append(order, "A-after")
})
}
mwB := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
order = append(order, "B-before")
next.ServeHTTP(w, r)
order = append(order, "B-after")
})
}
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
order = append(order, "handler")
w.WriteHeader(http.StatusOK)
})
chained := server.Chain(mwA, mwB)(handler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
chained.ServeHTTP(w, req)
expected := []string{"A-before", "B-before", "handler", "B-after", "A-after"}
if len(order) != len(expected) {
t.Fatalf("got %v, want %v", order, expected)
}
for i, v := range expected {
if order[i] != v {
t.Fatalf("order[%d] = %q, want %q", i, order[i], v)
}
}
})
t.Run("empty chain returns handler unchanged", func(t *testing.T) {
called := false
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
called = true
w.WriteHeader(http.StatusOK)
})
chained := server.Chain()(handler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
chained.ServeHTTP(w, req)
if !called {
t.Fatal("handler was not called")
}
})
}
func TestChain_SingleMiddleware(t *testing.T) {
var called bool
mw := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
called = true
next.ServeHTTP(w, r)
})
}
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
})
chained := server.Chain(mw)(handler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
chained.ServeHTTP(w, req)
if !called {
t.Fatal("single middleware was not called")
}
if w.Code != http.StatusOK {
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
}
}
func TestStatusWriter_WriteHeaderMultipleCalls(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusCreated)
w.WriteHeader(http.StatusNotFound) // second call should not change captured status
})
var buf bytes.Buffer
logger := slog.New(slog.NewTextHandler(&buf, nil))
mw := server.Logging(logger)(handler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
mw.ServeHTTP(w, req)
if !strings.Contains(buf.String(), "status=201") {
t.Fatalf("expected status=201 (first WriteHeader call captured), got %q", buf.String())
}
}
func TestStatusWriter_WriteDefaultsTo200(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
_, _ = w.Write([]byte("hello")) // Write without WriteHeader
})
var buf bytes.Buffer
logger := slog.New(slog.NewTextHandler(&buf, nil))
mw := server.Logging(logger)(handler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
mw.ServeHTTP(w, req)
if !strings.Contains(buf.String(), "status=200") {
t.Fatalf("expected status=200 when Write called without WriteHeader, got %q", buf.String())
}
}
func TestStatusWriter_Unwrap(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
rc := http.NewResponseController(w)
if err := rc.Flush(); err != nil {
// httptest.ResponseRecorder implements Flusher, so this should succeed
// if Unwrap works correctly.
http.Error(w, "flush failed", http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusOK)
})
// Use Logging to wrap in statusWriter
var buf bytes.Buffer
logger := slog.New(slog.NewTextHandler(&buf, nil))
mw := server.Logging(logger)(handler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
mw.ServeHTTP(w, req)
if w.Code == http.StatusInternalServerError {
t.Fatal("Flush failed — Unwrap likely not exposing underlying Flusher")
}
}
func TestRequestID(t *testing.T) {
t.Run("generates ID when not present", func(t *testing.T) {
var gotID string
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotID = server.RequestIDFromContext(r.Context())
w.WriteHeader(http.StatusOK)
})
mw := server.RequestID()(handler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
mw.ServeHTTP(w, req)
if gotID == "" {
t.Fatal("expected request ID in context, got empty")
}
if w.Header().Get("X-Request-Id") != gotID {
t.Fatalf("response header %q != context ID %q", w.Header().Get("X-Request-Id"), gotID)
}
// UUID v4 format: 8-4-4-4-12 hex chars.
if len(gotID) != 36 {
t.Fatalf("expected UUID length 36, got %d: %q", len(gotID), gotID)
}
})
t.Run("preserves existing ID", func(t *testing.T) {
var gotID string
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotID = server.RequestIDFromContext(r.Context())
w.WriteHeader(http.StatusOK)
})
mw := server.RequestID()(handler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("X-Request-Id", "custom-123")
mw.ServeHTTP(w, req)
if gotID != "custom-123" {
t.Fatalf("got ID %q, want %q", gotID, "custom-123")
}
})
t.Run("context without ID returns empty", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
if id := server.RequestIDFromContext(req.Context()); id != "" {
t.Fatalf("expected empty, got %q", id)
}
})
}
func TestRequestID_UUIDFormat(t *testing.T) {
uuidV4Re := regexp.MustCompile(`^[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$`)
var gotID string
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotID = server.RequestIDFromContext(r.Context())
w.WriteHeader(http.StatusOK)
})
mw := server.RequestID()(handler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
mw.ServeHTTP(w, req)
if !uuidV4Re.MatchString(gotID) {
t.Fatalf("generated ID %q does not match UUID v4 format", gotID)
}
}
func TestRequestID_Uniqueness(t *testing.T) {
seen := make(map[string]struct{}, 1000)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
id := server.RequestIDFromContext(r.Context())
if _, exists := seen[id]; exists {
t.Fatalf("duplicate request ID: %q", id)
}
seen[id] = struct{}{}
w.WriteHeader(http.StatusOK)
})
mw := server.RequestID()(handler)
for i := 0; i < 1000; i++ {
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
mw.ServeHTTP(w, req)
}
if len(seen) != 1000 {
t.Fatalf("expected 1000 unique IDs, got %d", len(seen))
}
}
func TestRecovery(t *testing.T) {
t.Run("recovers from panic and returns 500", func(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
panic("something went wrong")
})
mw := server.Recovery()(handler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
mw.ServeHTTP(w, req)
if w.Code != http.StatusInternalServerError {
t.Fatalf("got status %d, want %d", w.Code, http.StatusInternalServerError)
}
})
t.Run("logs panic with logger", func(t *testing.T) {
var buf bytes.Buffer
logger := slog.New(slog.NewTextHandler(&buf, nil))
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
panic("boom")
})
mw := server.Recovery(server.WithRecoveryLogger(logger))(handler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
mw.ServeHTTP(w, req)
if !strings.Contains(buf.String(), "panic recovered") {
t.Fatalf("expected log to contain 'panic recovered', got %q", buf.String())
}
if !strings.Contains(buf.String(), "boom") {
t.Fatalf("expected log to contain 'boom', got %q", buf.String())
}
})
t.Run("passes through without panic", func(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
})
mw := server.Recovery()(handler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
mw.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
}
})
}
func TestRecovery_PanicWithNonString(t *testing.T) {
tests := []struct {
name string
value any
}{
{"integer", 42},
{"struct", struct{ X int }{X: 1}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
panic(tt.value)
})
var buf bytes.Buffer
logger := slog.New(slog.NewTextHandler(&buf, nil))
mw := server.Recovery(server.WithRecoveryLogger(logger))(handler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/test", nil)
mw.ServeHTTP(w, req)
if w.Code != http.StatusInternalServerError {
t.Fatalf("got status %d, want %d", w.Code, http.StatusInternalServerError)
}
if !strings.Contains(buf.String(), "panic recovered") {
t.Fatalf("expected 'panic recovered' in log, got %q", buf.String())
}
})
}
}
func TestRecovery_ResponseBody(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
panic("fail")
})
mw := server.Recovery()(handler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
mw.ServeHTTP(w, req)
body := strings.TrimSpace(w.Body.String())
if body != "Internal Server Error" {
t.Fatalf("got body %q, want %q", body, "Internal Server Error")
}
}
func TestRecovery_LogAttributes(t *testing.T) {
var buf bytes.Buffer
logger := slog.New(slog.NewTextHandler(&buf, nil))
// Put RequestID before Recovery so request_id is in context
handler := server.RequestID()(
server.Recovery(server.WithRecoveryLogger(logger))(
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
panic("boom")
}),
),
)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/test", nil)
handler.ServeHTTP(w, req)
logOutput := buf.String()
for _, attr := range []string{"method=", "path=", "request_id="} {
if !strings.Contains(logOutput, attr) {
t.Fatalf("expected %q in log, got %q", attr, logOutput)
}
}
}
func TestLogging(t *testing.T) {
t.Run("logs request details", func(t *testing.T) {
var buf bytes.Buffer
logger := slog.New(slog.NewTextHandler(&buf, nil))
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusCreated)
})
mw := server.Logging(logger)(handler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/users", nil)
mw.ServeHTTP(w, req)
logOutput := buf.String()
if !strings.Contains(logOutput, "request completed") {
t.Fatalf("expected 'request completed' in log, got %q", logOutput)
}
if !strings.Contains(logOutput, "POST") {
t.Fatalf("expected method in log, got %q", logOutput)
}
if !strings.Contains(logOutput, "/api/users") {
t.Fatalf("expected path in log, got %q", logOutput)
}
if !strings.Contains(logOutput, "status=201") {
t.Fatalf("expected status=201 in log, got %q", logOutput)
}
})
t.Run("logs error level for 5xx", func(t *testing.T) {
var buf bytes.Buffer
logger := slog.New(slog.NewTextHandler(&buf, nil))
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusBadGateway)
})
mw := server.Logging(logger)(handler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
mw.ServeHTTP(w, req)
logOutput := buf.String()
if !strings.Contains(logOutput, "level=ERROR") {
t.Fatalf("expected ERROR level in log, got %q", logOutput)
}
})
}
func TestLogging_4xxIsInfoLevel(t *testing.T) {
var buf bytes.Buffer
logger := slog.New(slog.NewTextHandler(&buf, nil))
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusNotFound)
})
mw := server.Logging(logger)(handler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/missing", nil)
mw.ServeHTTP(w, req)
logOutput := buf.String()
if !strings.Contains(logOutput, "level=INFO") {
t.Fatalf("expected INFO level for 404, got %q", logOutput)
}
if strings.Contains(logOutput, "level=ERROR") {
t.Fatalf("404 should not be logged as ERROR, got %q", logOutput)
}
}
func TestLogging_DefaultStatus200(t *testing.T) {
var buf bytes.Buffer
logger := slog.New(slog.NewTextHandler(&buf, nil))
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
_, _ = w.Write([]byte("hello"))
})
mw := server.Logging(logger)(handler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
mw.ServeHTTP(w, req)
if !strings.Contains(buf.String(), "status=200") {
t.Fatalf("expected status=200 in log when handler only calls Write, got %q", buf.String())
}
}
func TestLogging_IncludesRequestID(t *testing.T) {
var buf bytes.Buffer
logger := slog.New(slog.NewTextHandler(&buf, nil))
handler := server.RequestID()(
server.Logging(logger)(
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
}),
),
)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
handler.ServeHTTP(w, req)
if !strings.Contains(buf.String(), "request_id=") {
t.Fatalf("expected request_id in log output, got %q", buf.String())
}
}

View File

@@ -0,0 +1,15 @@
package server
import (
"net/http"
"time"
)
// Timeout returns a middleware that limits request processing time.
// If the handler does not complete within d, the client receives a
// 503 Service Unavailable response. It wraps http.TimeoutHandler.
func Timeout(d time.Duration) Middleware {
return func(next http.Handler) http.Handler {
return http.TimeoutHandler(next, d, "Service Unavailable\n")
}
}

View File

@@ -0,0 +1,49 @@
package server_test
import (
"net/http"
"net/http/httptest"
"testing"
"time"
"git.codelab.vc/pkg/httpx/server"
)
func TestTimeout(t *testing.T) {
t.Run("handler completes within timeout", func(t *testing.T) {
handler := server.Timeout(1 * time.Second)(
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("ok"))
}),
)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
handler.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
}
})
t.Run("handler exceeds timeout returns 503", func(t *testing.T) {
handler := server.Timeout(10 * time.Millisecond)(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
select {
case <-time.After(1 * time.Second):
case <-r.Context().Done():
}
w.WriteHeader(http.StatusOK)
}),
)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
handler.ServeHTTP(w, req)
if w.Code != http.StatusServiceUnavailable {
t.Fatalf("got status %d, want %d", w.Code, http.StatusServiceUnavailable)
}
})
}

89
server/options.go Normal file
View File

@@ -0,0 +1,89 @@
package server
import (
"log/slog"
"time"
)
type serverOptions struct {
addr string
readTimeout time.Duration
readHeaderTimeout time.Duration
writeTimeout time.Duration
idleTimeout time.Duration
shutdownTimeout time.Duration
logger *slog.Logger
middlewares []Middleware
onShutdown []func()
}
// Option configures a Server.
type Option func(*serverOptions)
// WithAddr sets the listen address. Default is ":8080".
func WithAddr(addr string) Option {
return func(o *serverOptions) { o.addr = addr }
}
// WithReadTimeout sets the maximum duration for reading the entire request.
func WithReadTimeout(d time.Duration) Option {
return func(o *serverOptions) { o.readTimeout = d }
}
// WithReadHeaderTimeout sets the maximum duration for reading request headers.
func WithReadHeaderTimeout(d time.Duration) Option {
return func(o *serverOptions) { o.readHeaderTimeout = d }
}
// WithWriteTimeout sets the maximum duration before timing out writes of the response.
func WithWriteTimeout(d time.Duration) Option {
return func(o *serverOptions) { o.writeTimeout = d }
}
// WithIdleTimeout sets the maximum amount of time to wait for the next request
// when keep-alives are enabled.
func WithIdleTimeout(d time.Duration) Option {
return func(o *serverOptions) { o.idleTimeout = d }
}
// WithShutdownTimeout sets the maximum duration to wait for active connections
// to close during graceful shutdown. Default is 15 seconds.
func WithShutdownTimeout(d time.Duration) Option {
return func(o *serverOptions) { o.shutdownTimeout = d }
}
// WithLogger sets the structured logger used by the server for lifecycle events.
func WithLogger(l *slog.Logger) Option {
return func(o *serverOptions) { o.logger = l }
}
// WithMiddleware appends server middlewares to the chain.
// These are applied to the handler in the order given.
func WithMiddleware(mws ...Middleware) Option {
return func(o *serverOptions) { o.middlewares = append(o.middlewares, mws...) }
}
// WithOnShutdown registers a function to be called during graceful shutdown,
// before the HTTP server begins draining connections.
func WithOnShutdown(fn func()) Option {
return func(o *serverOptions) { o.onShutdown = append(o.onShutdown, fn) }
}
// Defaults returns a production-ready set of options including standard
// middleware (RequestID, Recovery, Logging), sensible timeouts, and the
// provided logger.
//
// Middleware order: RequestID → Recovery → Logging → user handler.
func Defaults(logger *slog.Logger) []Option {
return []Option{
WithReadHeaderTimeout(10 * time.Second),
WithIdleTimeout(120 * time.Second),
WithShutdownTimeout(15 * time.Second),
WithLogger(logger),
WithMiddleware(
RequestID(),
Recovery(WithRecoveryLogger(logger)),
Logging(logger),
),
}
}

29
server/respond.go Normal file
View File

@@ -0,0 +1,29 @@
package server
import (
"encoding/json"
"net/http"
)
// WriteJSON encodes v as JSON and writes it to w with the given status code.
// It sets Content-Type to application/json.
func WriteJSON(w http.ResponseWriter, status int, v any) error {
b, err := json.Marshal(v)
if err != nil {
return err
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
_, err = w.Write(b)
return err
}
// WriteError writes a JSON error response with the given status code and
// message. The response body is {"error": "<message>"}.
func WriteError(w http.ResponseWriter, status int, msg string) error {
return WriteJSON(w, status, errorBody{Error: msg})
}
type errorBody struct {
Error string `json:"error"`
}

72
server/respond_test.go Normal file
View File

@@ -0,0 +1,72 @@
package server_test
import (
"encoding/json"
"net/http/httptest"
"testing"
"git.codelab.vc/pkg/httpx/server"
)
func TestWriteJSON(t *testing.T) {
t.Run("writes JSON with status and content type", func(t *testing.T) {
w := httptest.NewRecorder()
type resp struct {
ID int `json:"id"`
Name string `json:"name"`
}
err := server.WriteJSON(w, 201, resp{ID: 1, Name: "Alice"})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if w.Code != 201 {
t.Fatalf("got status %d, want %d", w.Code, 201)
}
if ct := w.Header().Get("Content-Type"); ct != "application/json" {
t.Fatalf("Content-Type = %q, want %q", ct, "application/json")
}
var decoded resp
if err := json.Unmarshal(w.Body.Bytes(), &decoded); err != nil {
t.Fatalf("unmarshal: %v", err)
}
if decoded.ID != 1 || decoded.Name != "Alice" {
t.Fatalf("got %+v, want {ID:1 Name:Alice}", decoded)
}
})
t.Run("returns error for unmarshalable input", func(t *testing.T) {
w := httptest.NewRecorder()
err := server.WriteJSON(w, 200, make(chan int))
if err == nil {
t.Fatal("expected error for channel type")
}
})
}
func TestWriteError(t *testing.T) {
w := httptest.NewRecorder()
err := server.WriteError(w, 404, "not found")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if w.Code != 404 {
t.Fatalf("got status %d, want %d", w.Code, 404)
}
var body struct {
Error string `json:"error"`
}
if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil {
t.Fatalf("unmarshal: %v", err)
}
if body.Error != "not found" {
t.Fatalf("error = %q, want %q", body.Error, "not found")
}
}

126
server/route.go Normal file
View File

@@ -0,0 +1,126 @@
package server
import (
"net/http"
"strings"
)
// Router is a lightweight wrapper around http.ServeMux that adds middleware
// groups and sub-router mounting. It leverages Go 1.22+ enhanced patterns
// like "GET /users/{id}".
type Router struct {
mux *http.ServeMux
prefix string
middlewares []Middleware
notFoundHandler http.Handler
}
// RouterOption configures a Router.
type RouterOption func(*Router)
// WithNotFoundHandler sets a custom handler for requests that don't match
// any registered pattern. This is useful for returning JSON 404/405 responses
// instead of the default plain text.
func WithNotFoundHandler(h http.Handler) RouterOption {
return func(r *Router) { r.notFoundHandler = h }
}
// NewRouter creates a new Router backed by a fresh http.ServeMux.
func NewRouter(opts ...RouterOption) *Router {
r := &Router{
mux: http.NewServeMux(),
}
for _, opt := range opts {
opt(r)
}
return r
}
// Handle registers a handler for the given pattern. The pattern follows
// http.ServeMux conventions, including method-based patterns like "GET /users".
func (r *Router) Handle(pattern string, handler http.Handler) {
if len(r.middlewares) > 0 {
handler = Chain(r.middlewares...)(handler)
}
r.mux.Handle(r.prefixedPattern(pattern), handler)
}
// HandleFunc registers a handler function for the given pattern.
func (r *Router) HandleFunc(pattern string, fn http.HandlerFunc) {
r.Handle(pattern, fn)
}
// Group creates a sub-router with a shared prefix and optional middleware.
// Patterns registered on the group are prefixed automatically. The group
// shares the underlying ServeMux with the parent router.
//
// Example:
//
// api := router.Group("/api/v1", authMiddleware)
// api.HandleFunc("GET /users", listUsers) // registers "GET /api/v1/users"
func (r *Router) Group(prefix string, mws ...Middleware) *Router {
return &Router{
mux: r.mux,
prefix: r.prefix + prefix,
middlewares: append(r.middlewaresSnapshot(), mws...),
}
}
// Mount attaches an http.Handler under the given prefix. All requests
// starting with prefix are forwarded to the handler with the prefix stripped.
func (r *Router) Mount(prefix string, handler http.Handler) {
full := r.prefix + prefix
if !strings.HasSuffix(full, "/") {
full += "/"
}
r.mux.Handle(full, http.StripPrefix(strings.TrimSuffix(full, "/"), handler))
}
// ServeHTTP implements http.Handler, making Router usable as a handler.
func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if r.notFoundHandler != nil {
// Use the mux to check for a match. If none, use the custom handler.
_, pattern := r.mux.Handler(req)
if pattern == "" {
r.notFoundHandler.ServeHTTP(w, req)
return
}
}
r.mux.ServeHTTP(w, req)
}
// prefixedPattern inserts the router prefix into a pattern. It is aware of
// method prefixes: "GET /users" with prefix "/api" becomes "GET /api/users".
func (r *Router) prefixedPattern(pattern string) string {
if r.prefix == "" {
return pattern
}
// Split method prefix if present: "GET /users" → method="GET ", path="/users"
method, path, hasMethod := splitMethodPattern(pattern)
path = r.prefix + path
if hasMethod {
return method + path
}
return path
}
// splitMethodPattern splits "GET /path" into ("GET ", "/path", true).
// If there is no method prefix, returns ("", pattern, false).
func splitMethodPattern(pattern string) (method, path string, hasMethod bool) {
if idx := strings.IndexByte(pattern, ' '); idx >= 0 {
return pattern[:idx+1], pattern[idx+1:], true
}
return "", pattern, false
}
func (r *Router) middlewaresSnapshot() []Middleware {
if len(r.middlewares) == 0 {
return nil
}
cp := make([]Middleware, len(r.middlewares))
copy(cp, r.middlewares)
return cp
}

View File

@@ -0,0 +1,70 @@
package server_test
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"git.codelab.vc/pkg/httpx/server"
)
func TestRouter_NotFoundHandler(t *testing.T) {
t.Run("custom 404 handler", func(t *testing.T) {
notFound := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusNotFound)
json.NewEncoder(w).Encode(map[string]string{"error": "not found"})
})
r := server.NewRouter(server.WithNotFoundHandler(notFound))
r.HandleFunc("GET /exists", func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
})
// Matched route works normally.
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/exists", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("matched route: got status %d, want %d", w.Code, http.StatusOK)
}
// Unmatched route uses custom handler.
w = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/nope", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Fatalf("not found: got status %d, want %d", w.Code, http.StatusNotFound)
}
if ct := w.Header().Get("Content-Type"); ct != "application/json" {
t.Fatalf("Content-Type = %q, want %q", ct, "application/json")
}
var body map[string]string
if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil {
t.Fatalf("unmarshal: %v", err)
}
if body["error"] != "not found" {
t.Fatalf("error = %q, want %q", body["error"], "not found")
}
})
t.Run("default behavior without custom handler", func(t *testing.T) {
r := server.NewRouter()
r.HandleFunc("GET /exists", func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/nope", nil)
r.ServeHTTP(w, req)
// Default ServeMux returns 404.
if w.Code != http.StatusNotFound {
t.Fatalf("got status %d, want %d", w.Code, http.StatusNotFound)
}
})
}

336
server/route_test.go Normal file
View File

@@ -0,0 +1,336 @@
package server_test
import (
"io"
"net/http"
"net/http/httptest"
"testing"
"git.codelab.vc/pkg/httpx/server"
)
func TestRouter(t *testing.T) {
t.Run("basic route", func(t *testing.T) {
r := server.NewRouter()
r.HandleFunc("GET /hello", func(w http.ResponseWriter, _ *http.Request) {
_, _ = w.Write([]byte("world"))
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/hello", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
}
if body := w.Body.String(); body != "world" {
t.Fatalf("got body %q, want %q", body, "world")
}
})
t.Run("Handle with http.Handler", func(t *testing.T) {
r := server.NewRouter()
r.Handle("GET /ping", http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
_, _ = w.Write([]byte("pong"))
}))
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/ping", nil)
r.ServeHTTP(w, req)
if body := w.Body.String(); body != "pong" {
t.Fatalf("got body %q, want %q", body, "pong")
}
})
t.Run("path parameter", func(t *testing.T) {
r := server.NewRouter()
r.HandleFunc("GET /users/{id}", func(w http.ResponseWriter, req *http.Request) {
_, _ = w.Write([]byte("user:" + req.PathValue("id")))
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/users/42", nil)
r.ServeHTTP(w, req)
if body := w.Body.String(); body != "user:42" {
t.Fatalf("got body %q, want %q", body, "user:42")
}
})
}
func TestRouterGroup(t *testing.T) {
t.Run("prefix is applied", func(t *testing.T) {
r := server.NewRouter()
api := r.Group("/api/v1")
api.HandleFunc("GET /users", func(w http.ResponseWriter, _ *http.Request) {
_, _ = w.Write([]byte("users"))
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/users", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
}
if body := w.Body.String(); body != "users" {
t.Fatalf("got body %q, want %q", body, "users")
}
})
t.Run("nested groups", func(t *testing.T) {
r := server.NewRouter()
api := r.Group("/api")
v1 := api.Group("/v1")
v1.HandleFunc("GET /items", func(w http.ResponseWriter, _ *http.Request) {
_, _ = w.Write([]byte("items"))
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/items", nil)
r.ServeHTTP(w, req)
if body := w.Body.String(); body != "items" {
t.Fatalf("got body %q, want %q", body, "items")
}
})
t.Run("group middleware", func(t *testing.T) {
var mwCalled bool
mw := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mwCalled = true
next.ServeHTTP(w, r)
})
}
r := server.NewRouter()
g := r.Group("/admin", mw)
g.HandleFunc("GET /dashboard", func(w http.ResponseWriter, _ *http.Request) {
_, _ = w.Write([]byte("ok"))
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard", nil)
r.ServeHTTP(w, req)
if !mwCalled {
t.Fatal("group middleware was not called")
}
})
}
func TestRouterMount(t *testing.T) {
t.Run("mounts sub-handler with prefix stripping", func(t *testing.T) {
sub := http.NewServeMux()
sub.HandleFunc("GET /info", func(w http.ResponseWriter, _ *http.Request) {
_, _ = w.Write([]byte("info"))
})
r := server.NewRouter()
r.Mount("/sub", sub)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/sub/info", nil)
r.ServeHTTP(w, req)
body, _ := io.ReadAll(w.Body)
if string(body) != "info" {
t.Fatalf("got body %q, want %q", body, "info")
}
})
t.Run("mount with trailing slash", func(t *testing.T) {
sub := http.NewServeMux()
sub.HandleFunc("GET /data", func(w http.ResponseWriter, _ *http.Request) {
_, _ = w.Write([]byte("data"))
})
r := server.NewRouter()
r.Mount("/sub/", sub)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/sub/data", nil)
r.ServeHTTP(w, req)
body, _ := io.ReadAll(w.Body)
if string(body) != "data" {
t.Fatalf("got body %q, want %q", body, "data")
}
})
}
func TestRouter_PatternWithoutMethod(t *testing.T) {
r := server.NewRouter()
r.HandleFunc("/static/", func(w http.ResponseWriter, _ *http.Request) {
_, _ = w.Write([]byte("static"))
})
for _, method := range []string{http.MethodGet, http.MethodPost} {
w := httptest.NewRecorder()
req := httptest.NewRequest(method, "/static/file.css", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("%s /static/file.css: got status %d, want %d", method, w.Code, http.StatusOK)
}
}
}
func TestRouter_GroupEmptyPrefix(t *testing.T) {
r := server.NewRouter()
g := r.Group("")
g.HandleFunc("GET /hello", func(w http.ResponseWriter, _ *http.Request) {
_, _ = w.Write([]byte("hello"))
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/hello", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
}
if body := w.Body.String(); body != "hello" {
t.Fatalf("got body %q, want %q", body, "hello")
}
}
func TestRouter_GroupInheritsMiddleware(t *testing.T) {
var order []string
parentMW := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
order = append(order, "parent")
next.ServeHTTP(w, r)
})
}
childMW := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
order = append(order, "child")
next.ServeHTTP(w, r)
})
}
r := server.NewRouter()
parent := r.Group("/api", parentMW)
child := parent.Group("/v1", childMW)
child.HandleFunc("GET /items", func(w http.ResponseWriter, _ *http.Request) {
order = append(order, "handler")
w.WriteHeader(http.StatusOK)
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/items", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
}
expected := []string{"parent", "child", "handler"}
if len(order) != len(expected) {
t.Fatalf("got %v, want %v", order, expected)
}
for i, v := range expected {
if order[i] != v {
t.Fatalf("order[%d] = %q, want %q", i, order[i], v)
}
}
}
func TestRouter_GroupMiddlewareOrder(t *testing.T) {
var order []string
mwA := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
order = append(order, "A")
next.ServeHTTP(w, r)
})
}
mwB := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
order = append(order, "B")
next.ServeHTTP(w, r)
})
}
r := server.NewRouter()
g := r.Group("/api", mwA)
sub := g.Group("/v1", mwB)
sub.HandleFunc("GET /test", func(w http.ResponseWriter, _ *http.Request) {
order = append(order, "handler")
w.WriteHeader(http.StatusOK)
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/test", nil)
r.ServeHTTP(w, req)
// Parent MW (A) should run before child MW (B), then handler.
expected := []string{"A", "B", "handler"}
if len(order) != len(expected) {
t.Fatalf("got %v, want %v", order, expected)
}
for i, v := range expected {
if order[i] != v {
t.Fatalf("order[%d] = %q, want %q", i, order[i], v)
}
}
}
func TestRouter_PathParamWithGroup(t *testing.T) {
r := server.NewRouter()
api := r.Group("/api")
api.HandleFunc("GET /users/{id}", func(w http.ResponseWriter, req *http.Request) {
_, _ = w.Write([]byte("id=" + req.PathValue("id")))
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/users/42", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
}
if body := w.Body.String(); body != "id=42" {
t.Fatalf("got body %q, want %q", body, "id=42")
}
}
func TestRouter_MiddlewareNotAppliedToOtherRoutes(t *testing.T) {
var mwCalled bool
mw := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mwCalled = true
next.ServeHTTP(w, r)
})
}
r := server.NewRouter()
// Add middleware only to /admin group.
admin := r.Group("/admin", mw)
admin.HandleFunc("GET /dashboard", func(w http.ResponseWriter, _ *http.Request) {
_, _ = w.Write([]byte("admin"))
})
// Route outside the group.
r.HandleFunc("GET /public", func(w http.ResponseWriter, _ *http.Request) {
_, _ = w.Write([]byte("public"))
})
// Request to /public should NOT trigger group middleware.
mwCalled = false
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/public", nil)
r.ServeHTTP(w, req)
if mwCalled {
t.Fatal("group middleware should not be called for routes outside the group")
}
if w.Body.String() != "public" {
t.Fatalf("got body %q, want %q", w.Body.String(), "public")
}
}

173
server/server.go Normal file
View File

@@ -0,0 +1,173 @@
package server
import (
"context"
"errors"
"log/slog"
"net"
"net/http"
"os/signal"
"sync/atomic"
"syscall"
"time"
)
// Server is a production-ready HTTP server with graceful shutdown,
// middleware support, and signal handling.
type Server struct {
httpServer *http.Server
listener net.Listener
addr atomic.Value
logger *slog.Logger
shutdownTimeout time.Duration
onShutdown []func()
listenAddr string
}
// New creates a new Server that will serve the given handler with the
// provided options. Middleware from options is applied to the handler.
func New(handler http.Handler, opts ...Option) *Server {
o := &serverOptions{
addr: ":8080",
shutdownTimeout: 15 * time.Second,
}
for _, opt := range opts {
opt(o)
}
// Apply middleware chain to the handler.
if len(o.middlewares) > 0 {
handler = Chain(o.middlewares...)(handler)
}
srv := &Server{
httpServer: &http.Server{
Handler: handler,
ReadTimeout: o.readTimeout,
ReadHeaderTimeout: o.readHeaderTimeout,
WriteTimeout: o.writeTimeout,
IdleTimeout: o.idleTimeout,
},
logger: o.logger,
shutdownTimeout: o.shutdownTimeout,
onShutdown: o.onShutdown,
listenAddr: o.addr,
}
return srv
}
// ListenAndServe starts the server and blocks until a SIGINT or SIGTERM
// signal is received. It then performs a graceful shutdown within the
// configured shutdown timeout.
//
// Returns nil on clean shutdown or an error if listen/shutdown fails.
func (s *Server) ListenAndServe() error {
ln, err := net.Listen("tcp", s.listenAddr)
if err != nil {
return err
}
s.listener = ln
s.addr.Store(ln.Addr().String())
s.log("server started", slog.String("addr", ln.Addr().String()))
// Wait for signal in context.
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer stop()
errCh := make(chan error, 1)
go func() {
if err := s.httpServer.Serve(ln); err != nil && !errors.Is(err, http.ErrServerClosed) {
errCh <- err
}
close(errCh)
}()
select {
case err := <-errCh:
return err
case <-ctx.Done():
stop()
return s.shutdown()
}
}
// ListenAndServeTLS starts the server with TLS and blocks until a signal
// is received.
func (s *Server) ListenAndServeTLS(certFile, keyFile string) error {
ln, err := net.Listen("tcp", s.listenAddr)
if err != nil {
return err
}
s.listener = ln
s.addr.Store(ln.Addr().String())
s.log("server started (TLS)", slog.String("addr", ln.Addr().String()))
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer stop()
errCh := make(chan error, 1)
go func() {
if err := s.httpServer.ServeTLS(ln, certFile, keyFile); err != nil && !errors.Is(err, http.ErrServerClosed) {
errCh <- err
}
close(errCh)
}()
select {
case err := <-errCh:
return err
case <-ctx.Done():
stop()
return s.shutdown()
}
}
// Shutdown gracefully shuts down the server. It calls any registered
// onShutdown hooks, then waits for active connections to drain within
// the shutdown timeout.
func (s *Server) Shutdown(ctx context.Context) error {
s.runOnShutdown()
return s.httpServer.Shutdown(ctx)
}
// Addr returns the listener address after the server has started.
// Returns an empty string if the server has not started yet.
func (s *Server) Addr() string {
v := s.addr.Load()
if v == nil {
return ""
}
return v.(string)
}
func (s *Server) shutdown() error {
s.log("shutting down")
s.runOnShutdown()
ctx, cancel := context.WithTimeout(context.Background(), s.shutdownTimeout)
defer cancel()
if err := s.httpServer.Shutdown(ctx); err != nil {
s.log("shutdown error", slog.String("error", err.Error()))
return err
}
s.log("server stopped")
return nil
}
func (s *Server) runOnShutdown() {
for _, fn := range s.onShutdown {
fn()
}
}
func (s *Server) log(msg string, attrs ...slog.Attr) {
if s.logger != nil {
s.logger.LogAttrs(context.Background(), slog.LevelInfo, msg, attrs...)
}
}

268
server/server_test.go Normal file
View File

@@ -0,0 +1,268 @@
package server_test
import (
"bytes"
"context"
"io"
"log/slog"
"net/http"
"strings"
"testing"
"time"
"git.codelab.vc/pkg/httpx/server"
)
func TestServerLifecycle(t *testing.T) {
t.Run("starts and serves requests", func(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("hello"))
})
srv := server.New(handler, server.WithAddr(":0"))
// Start in background and wait for addr.
errCh := make(chan error, 1)
go func() { errCh <- srv.ListenAndServe() }()
waitForAddr(t, srv)
resp, err := http.Get("http://" + srv.Addr())
if err != nil {
t.Fatalf("GET failed: %v", err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
if string(body) != "hello" {
t.Fatalf("got body %q, want %q", body, "hello")
}
// Shutdown.
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := srv.Shutdown(ctx); err != nil {
t.Fatalf("shutdown failed: %v", err)
}
})
t.Run("addr returns empty before start", func(t *testing.T) {
srv := server.New(http.NotFoundHandler())
if addr := srv.Addr(); addr != "" {
t.Fatalf("got addr %q before start, want empty", addr)
}
})
}
func TestGracefulShutdown(t *testing.T) {
t.Run("calls onShutdown hooks", func(t *testing.T) {
called := false
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
})
srv := server.New(handler,
server.WithAddr(":0"),
server.WithOnShutdown(func() { called = true }),
)
go func() { _ = srv.ListenAndServe() }()
waitForAddr(t, srv)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := srv.Shutdown(ctx); err != nil {
t.Fatalf("shutdown failed: %v", err)
}
if !called {
t.Fatal("onShutdown hook was not called")
}
})
}
func TestServerWithMiddleware(t *testing.T) {
t.Run("applies middleware from options", func(t *testing.T) {
var called bool
mw := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
called = true
next.ServeHTTP(w, r)
})
}
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
})
srv := server.New(handler,
server.WithAddr(":0"),
server.WithMiddleware(mw),
)
go func() { _ = srv.ListenAndServe() }()
waitForAddr(t, srv)
resp, err := http.Get("http://" + srv.Addr())
if err != nil {
t.Fatalf("GET failed: %v", err)
}
resp.Body.Close()
if !called {
t.Fatal("middleware was not called")
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = srv.Shutdown(ctx)
})
}
func TestServerDefaults(t *testing.T) {
var buf bytes.Buffer
logger := slog.New(slog.NewTextHandler(&buf, nil))
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
})
srv := server.New(handler, append(server.Defaults(logger), server.WithAddr(":0"))...)
go func() { _ = srv.ListenAndServe() }()
waitForAddr(t, srv)
resp, err := http.Get("http://" + srv.Addr())
if err != nil {
t.Fatalf("GET failed: %v", err)
}
resp.Body.Close()
// Defaults includes RequestID middleware, so response should have X-Request-Id.
if resp.Header.Get("X-Request-Id") == "" {
t.Fatal("expected X-Request-Id header from Defaults middleware")
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = srv.Shutdown(ctx)
}
func TestServerListenError(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
})
// Use an invalid address to trigger a listen error.
srv := server.New(handler, server.WithAddr(":-1"))
err := srv.ListenAndServe()
if err == nil {
t.Fatal("expected error from invalid address, got nil")
}
}
func TestServerMultipleOnShutdownHooks(t *testing.T) {
var calls []int
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
})
srv := server.New(handler,
server.WithAddr(":0"),
server.WithOnShutdown(func() { calls = append(calls, 1) }),
server.WithOnShutdown(func() { calls = append(calls, 2) }),
server.WithOnShutdown(func() { calls = append(calls, 3) }),
)
go func() { _ = srv.ListenAndServe() }()
waitForAddr(t, srv)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := srv.Shutdown(ctx); err != nil {
t.Fatalf("shutdown failed: %v", err)
}
if len(calls) != 3 {
t.Fatalf("expected 3 hooks called, got %d: %v", len(calls), calls)
}
}
func TestServerShutdownWithLogger(t *testing.T) {
var buf bytes.Buffer
logger := slog.New(slog.NewTextHandler(&buf, nil))
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
})
srv := server.New(handler,
server.WithAddr(":0"),
server.WithLogger(logger),
)
errCh := make(chan error, 1)
go func() { errCh <- srv.ListenAndServe() }()
waitForAddr(t, srv)
// Send SIGINT to trigger graceful shutdown via ListenAndServe's signal handler.
// Instead, use Shutdown directly and check log from server start.
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = srv.Shutdown(ctx)
// The server logs "server started" on ListenAndServe.
logOutput := buf.String()
if !strings.Contains(logOutput, "server started") {
t.Fatalf("expected 'server started' in log, got %q", logOutput)
}
}
func TestServerOptions(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
})
// Verify options don't panic and server starts correctly.
srv := server.New(handler,
server.WithAddr(":0"),
server.WithReadTimeout(5*time.Second),
server.WithReadHeaderTimeout(3*time.Second),
server.WithWriteTimeout(10*time.Second),
server.WithIdleTimeout(60*time.Second),
server.WithShutdownTimeout(5*time.Second),
)
go func() { _ = srv.ListenAndServe() }()
waitForAddr(t, srv)
resp, err := http.Get("http://" + srv.Addr())
if err != nil {
t.Fatalf("GET failed: %v", err)
}
resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("got status %d, want %d", resp.StatusCode, http.StatusOK)
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = srv.Shutdown(ctx)
}
// waitForAddr polls until the server's Addr() is non-empty.
func waitForAddr(t *testing.T, srv *server.Server) {
t.Helper()
deadline := time.Now().Add(2 * time.Second)
for time.Now().Before(deadline) {
if srv.Addr() != "" {
return
}
time.Sleep(5 * time.Millisecond)
}
t.Fatal("server did not start in time")
}