Compare commits

...

20 Commits

Author SHA1 Message Date
b6350185d9 Fix publish workflow: use git archive instead of go install
All checks were successful
CI / test (push) Successful in 51s
Publish / publish (push) Successful in 51s
golang.org/x/mod/zip is a library, not a CLI tool. Use git archive
with --prefix to create the module zip in the format Gitea expects.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-23 10:39:35 +03:00
85cdc5e2c9 Add publish workflow for Gitea Go Package Registry
Some checks failed
CI / test (push) Successful in 32s
Publish / publish (push) Failing after 37s
Publishes the module to Gitea Package Registry on tag push (v*).
Runs vet and tests before publishing to prevent broken releases.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-23 10:25:19 +03:00
25beb2f5c2 Add AGENTS.md reference to CLAUDE.md
All checks were successful
CI / test (push) Successful in 31s
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-22 22:20:42 +03:00
16ff427c93 Add AI agent configuration files (AGENTS.md, .cursorrules, copilot-instructions)
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-22 22:20:38 +03:00
138d4b6c6d Add package-level doc comments for go doc and gopls
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-22 22:20:33 +03:00
3aa7536328 Add examples/ with runnable usage demos for all major features
All checks were successful
CI / test (push) Successful in 31s
Six examples covering the full API surface:
- basic-client: retry, timeout, logging, response size limit
- form-request: form-encoded POST for OAuth/webhooks
- load-balancing: weighted endpoints, circuit breaker, health checks
- server-basic: routing, groups, JSON helpers, health, custom 404
- server-protected: CORS, rate limiting, body limits, timeouts
- request-id-propagation: cross-service request ID forwarding

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-22 22:05:08 +03:00
89cfc38f0e Update documentation with new client and server features
All checks were successful
CI / test (push) Successful in 30s
README: add PATCH, NewFormRequest, CORS, RateLimit, MaxBodySize,
Timeout, WriteJSON/WriteError, request ID propagation, response body
limit, and custom 404 handler examples. CLAUDE.md: document new
architecture details for all added components.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-22 21:55:15 +03:00
8a63f142a7 Add WithNotFoundHandler option for custom 404 responses on Router
Allows configuring a custom handler for unmatched routes, enabling
consistent JSON error responses instead of ServeMux's default plain
text. NewRouter now accepts RouterOption functional options.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-22 21:48:13 +03:00
21274c178a Add WithMaxResponseBody option to prevent client-side OOM
Wraps response body with io.LimitedReader when configured, preventing
unbounded reads from io.ReadAll in Response.Bytes(). Protects against
upstream services returning unexpectedly large responses.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-22 21:48:05 +03:00
49be6f8a7e Add client RequestID middleware for cross-service propagation
Introduces internal/requestid package with shared context key to avoid
circular imports between server and middleware packages. Server's
RequestID middleware now uses the shared key. Client middleware picks up
the ID from context and sets X-Request-Id on outgoing requests.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-22 21:47:58 +03:00
3395f70abd Add server RateLimit middleware with per-key token bucket
Protects against abuse with configurable rate/burst per client IP.
Supports custom key functions, X-Forwarded-For extraction, and
Retry-After headers on 429 responses. Uses internal/clock for
testability.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-22 21:47:51 +03:00
7a2cef00c3 Add server WriteJSON and WriteError response helpers
Eliminates repeated marshal-set-header-write boilerplate in handlers.
WriteError produces consistent {"error": "..."} JSON responses.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-22 21:47:45 +03:00
de5bf9a6d9 Add server CORS middleware with preflight handling
Supports AllowOrigins, AllowMethods, AllowHeaders, ExposeHeaders,
AllowCredentials, and MaxAge options. Handles preflight OPTIONS requests
correctly, including Vary header and credential-aware origin echoing.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-22 21:47:39 +03:00
7f12b0c87a Add server Timeout middleware for context-based request deadlines
Wraps http.TimeoutHandler to return 503 when handlers exceed the
configured duration. Unlike http.Server.WriteTimeout, this allows
handlers to complete gracefully via context cancellation.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-22 21:47:33 +03:00
1b322c8c81 Add server MaxBodySize middleware to prevent memory exhaustion
Wraps request body with http.MaxBytesReader to limit incoming payload
size. Without this, any endpoint accepting a body is vulnerable to
large uploads consuming all available memory.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-22 21:47:26 +03:00
b40a373675 Add NewFormRequest for form-encoded HTTP requests
Creates requests with application/x-www-form-urlencoded body from
url.Values. Supports GetBody for retry compatibility, following the
same pattern as NewJSONRequest.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-22 21:47:19 +03:00
f6384ecbea Add Client.Patch method for PATCH HTTP requests
Follows the same pattern as Put/Post, accepting context, URL, and body.
Closes an obvious gap in the REST client API.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-22 21:47:11 +03:00
4d47918a66 Update documentation with server package details
All checks were successful
CI / test (push) Successful in 30s
Add server package description, component table, and usage example to
README. Document server architecture, middleware chain, and test
conventions in CLAUDE.md.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-21 22:56:12 +03:00
7fae6247d5 Add comprehensive test coverage for server/ package
All checks were successful
CI / test (push) Successful in 30s
Cover edge cases: statusWriter multi-call/default/unwrap, UUID v4 format
and uniqueness, non-string panics, recovery body and log attributes,
4xx log level, default status in logging, request ID propagation,
server defaults/options/listen-error/multiple-hooks/logger, router
groups with empty prefix/inherited middleware/ordering/path params/
isolation, mount trailing slash, health content-type and POST rejection.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-21 13:55:22 +03:00
cea75d198b Add production-ready HTTP server package with routing, health checks, and middleware
Introduces server/ sub-package as the server-side companion to the existing Client.
Includes Router (over http.ServeMux with groups and mounting), graceful shutdown with
signal handling, health endpoints (/healthz, /readyz), and built-in middlewares
(RequestID, Recovery, Logging). Zero external dependencies.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-21 13:41:54 +03:00
51 changed files with 4139 additions and 20 deletions

48
.cursorrules Normal file
View File

@@ -0,0 +1,48 @@
You are working on `git.codelab.vc/pkg/httpx`, a Go 1.24 HTTP client/server library with zero external dependencies.
## Architecture
- Client middleware: `func(http.RoundTripper) http.RoundTripper` — compose with `middleware.Chain`
- Server middleware: `func(http.Handler) http.Handler` — compose with `server.Chain`
- All configuration uses functional options pattern (`WithXxx` functions)
- Chain order for client: Logging → User MW → Retry → Circuit Breaker → Balancer → Transport
## Package structure
- `httpx` (root) — Client, request builders (NewJSONRequest, NewFormRequest), error types
- `middleware/` — client-side middleware (Logging, Recovery, Auth, Headers, RequestID)
- `retry/` — retry middleware with exponential backoff and Retry-After support
- `circuitbreaker/` — per-host circuit breaker (sync.Map of host → Breaker)
- `balancer/` — load balancing with health checking (RoundRobin, Weighted, Failover)
- `server/` — Server, Router, server middleware (RequestID, Recovery, Logging, CORS, RateLimit, MaxBodySize, Timeout), response helpers (WriteJSON, WriteError)
- `internal/requestid/` — shared context key (avoids circular import between server and middleware)
- `internal/clock/` — deterministic time for tests
## Code conventions
- Zero external dependencies — stdlib only, do not add imports outside the module
- Functional options: `type Option func(*options)` with `With<Name>` constructors
- Test with stdlib only: `testing`, `httptest`, `net/http`. No testify/gomock
- Client test helper: `mockTransport(fn)` wrapping `middleware.RoundTripperFunc`
- Server test helper: `httptest.NewRecorder`, `httptest.NewRequest`, `waitForAddr(t, srv)`
- Thread safety with `sync.Mutex`, `sync.Map`, or `atomic`
- Use `internal/clock` for time-dependent tests, not `time.Now()` directly
- Sentinel errors in sub-packages, re-exported as aliases in root package
## When writing new code
- Client middleware → file in `middleware/`, return `middleware.Middleware`
- Server middleware → file in `server/middleware_<name>.go`, return `server.Middleware`
- New option → add field to options struct, create `With<Name>` func, apply in constructor
- Do NOT import `server` from `middleware` or vice versa (use `internal/requestid` for shared context)
- Client.Close() must be called when using WithEndpoints() (stops health checker goroutine)
- Request bodies must have GetBody set for retry — use NewJSONRequest/NewFormRequest
## Commands
```bash
go build ./... # compile
go test ./... # test
go test -race ./... # test with race detector
go vet ./... # static analysis
```

View File

@@ -0,0 +1,39 @@
name: Publish
on:
push:
tags: ["v*"]
jobs:
publish:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version: "1.24"
- name: Vet
run: go vet ./...
- name: Test
run: go test -race -count=1 ./...
- name: Publish to Gitea Package Registry
run: |
VERSION=${GITHUB_REF#refs/tags/}
MODULE=$(go list -m)
# Create module zip with required prefix: module@version/
git archive --format=zip --prefix="${MODULE}@${VERSION}/" HEAD -o module.zip
# Gitea Go Package Registry API
curl -s -f \
-X PUT \
-H "Authorization: token ${{ secrets.PUBLISH_TOKEN }}" \
-H "Content-Type: application/zip" \
--data-binary @module.zip \
"${{ github.server_url }}/api/packages/pkg/go/upload?module=${MODULE}&version=${VERSION}"
env:
PUBLISH_TOKEN: ${{ secrets.PUBLISH_TOKEN }}

50
.github/copilot-instructions.md vendored Normal file
View File

@@ -0,0 +1,50 @@
# Copilot instructions — httpx
## Project
`git.codelab.vc/pkg/httpx` is a Go 1.24 HTTP client and server library with zero external dependencies.
## Architecture
### Client (`httpx` root package)
- Middleware type: `func(http.RoundTripper) http.RoundTripper`
- Chain assembly (outermost → innermost): Logging → User MW → Retry → Circuit Breaker → Balancer → Transport
- Retry wraps CB+Balancer so each attempt can hit a different endpoint
- Circuit breaker is per-host (sync.Map of host → Breaker)
- Client.Close() required when using WithEndpoints() — stops health checker goroutine
### Server (`server/` package)
- Middleware type: `func(http.Handler) http.Handler`
- Router wraps http.ServeMux with groups, prefix routing, Mount for sub-handlers
- Defaults() preset: RequestID → Recovery → Logging + production timeouts
- Available middleware: RequestID, Recovery, Logging, CORS, RateLimit, MaxBodySize, Timeout
- WriteJSON/WriteError for JSON responses
### Sub-packages
- `middleware/` — client-side middleware (Logging, Recovery, Auth, Headers, RequestID)
- `retry/` — retry with exponential backoff and Retry-After
- `circuitbreaker/` — per-host circuit breaker
- `balancer/` — load balancing with health checking
- `internal/requestid/` — shared context key between server and middleware
- `internal/clock/` — deterministic time for tests
## Conventions
- All configuration uses functional options (`WithXxx` functions)
- Zero external dependencies — do not add requires to go.mod
- Tests use stdlib only (testing, httptest) — no testify or gomock
- Thread safety with sync.Mutex, sync.Map, or atomic
- Client test mock: `mockTransport(fn)` using `middleware.RoundTripperFunc`
- Server test helpers: `httptest.NewRecorder`, `httptest.NewRequest`
- Do NOT import server from middleware or vice versa — use internal/requestid for shared context
- Sentinel errors in sub-packages, re-exported in root package
- Use internal/clock for time-dependent tests
## Commands
```bash
go build ./... # compile
go test ./... # test
go test -race ./... # test with race detector
go vet ./... # static analysis
```

139
AGENTS.md Normal file
View File

@@ -0,0 +1,139 @@
# AGENTS.md — httpx
Universal guide for AI coding agents working with this codebase.
## Overview
`git.codelab.vc/pkg/httpx` is a Go HTTP toolkit with **zero external dependencies** (Go 1.24, stdlib only). It provides:
- A composable HTTP **client** with retry, circuit breaking, load balancing
- A production-ready HTTP **server** with routing, middleware, graceful shutdown
## Package map
```
httpx/ Root — Client, request builders, error types
├── middleware/ Client-side middleware (RoundTripper wrappers)
├── retry/ Retry middleware with backoff
├── circuitbreaker/ Per-host circuit breaker
├── balancer/ Client-side load balancing + health checking
├── server/ Server, Router, server-side middleware, response helpers
└── internal/
├── requestid/ Shared context key (avoids circular imports)
└── clock/ Deterministic time for testing
```
## Middleware chain architecture
### Client middleware: `func(http.RoundTripper) http.RoundTripper`
```
Request flow (outermost → innermost):
Logging
└→ User Middlewares
└→ Retry
└→ Circuit Breaker
└→ Balancer
└→ Base Transport (http.DefaultTransport)
```
Retry wraps CB+Balancer so each attempt can hit a different endpoint.
### Server middleware: `func(http.Handler) http.Handler`
```
Chain(A, B, C)(handler) == A(B(C(handler)))
A is outermost (sees request first, response last)
```
Defaults() preset: `RequestID → Recovery → Logging`
## Common tasks
### Add a client middleware
1. Create file in `middleware/` (or inline)
2. Return `middleware.Middleware` (`func(http.RoundTripper) http.RoundTripper`)
3. Use `middleware.RoundTripperFunc` for the inner adapter
4. Test with `middleware.RoundTripperFunc` as mock transport
```go
func MyMiddleware() middleware.Middleware {
return func(next http.RoundTripper) http.RoundTripper {
return middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
// before
resp, err := next.RoundTrip(req)
// after
return resp, err
})
}
}
```
### Add a server middleware
1. Create file in `server/` named `middleware_<name>.go`
2. Return `server.Middleware` (`func(http.Handler) http.Handler`)
3. Use `server.statusWriter` if you need to capture the response status
4. Test with `httptest.NewRecorder` + `httptest.NewRequest`
```go
func MyMiddleware() Middleware {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// before
next.ServeHTTP(w, r)
// after
})
}
}
```
### Add a route
```go
r := server.NewRouter()
r.HandleFunc("GET /users/{id}", getUser)
r.HandleFunc("POST /users", createUser)
// Group with prefix + middleware
api := r.Group("/api/v1", authMiddleware)
api.HandleFunc("GET /items", listItems)
// Mount sub-handler
r.Mount("/health", server.HealthHandler())
```
### Add a functional option
1. Add field to the options struct (`clientOptions` or `serverOptions`)
2. Create `With<Name>` function returning `Option`
3. Apply the field in the constructor (`New`)
## Gotchas
- **Middleware order matters**: Retry wraps CB+Balancer intentionally — each retry attempt can hit a different endpoint and a different circuit breaker
- **Circular imports via `internal/`**: Both `server` and `middleware` packages need request ID context. The shared key lives in `internal/requestid` — do NOT import `server` from `middleware` or vice versa
- **Client.Close() is required** when using `WithEndpoints()` — the balancer starts a background health checker goroutine that must be stopped
- **GetBody for retries**: Request bodies must be replayable. Use `NewJSONRequest`/`NewFormRequest` (they set `GetBody`) or set it manually
- **statusWriter.Unwrap()**: Server middleware must not type-assert `http.ResponseWriter` directly — use `http.ResponseController` which calls `Unwrap()` to find `http.Flusher`, `http.Hijacker`, etc.
- **No external deps**: This is a zero-dependency library. Do not add any `require` to `go.mod`
## Commands
```bash
go build ./... # compile
go test ./... # all tests
go test -race ./... # tests with race detector
go test -v -run TestName ./package/ # single test
go vet ./... # static analysis
```
## Conventions
- **Functional options** for all configuration (client and server)
- **stdlib only** testing — no testify, no gomock
- **Thread safety** — use `sync.Mutex`, `sync.Map`, or `atomic` where needed
- **`internal/clock`** — use for deterministic time in tests (never `time.Now()` directly in testable code)
- **Test helpers**: `mockTransport(fn)` wrapping `middleware.RoundTripperFunc` (client), `httptest.NewRecorder`/`httptest.NewRequest` (server), `waitForAddr(t, srv)` for server integration tests
- **Sentinel errors** live in sub-packages, root package re-exports as aliases

View File

@@ -13,6 +13,8 @@ go vet ./... # static analysis
## Architecture
- **Module**: `git.codelab.vc/pkg/httpx`, Go 1.24, zero external dependencies
### Client
- **Core pattern**: middleware is `func(http.RoundTripper) http.RoundTripper`
- **Chain assembly order** (client.go): Logging → User MW → Retry → CB → Balancer → Transport
- Retry wraps CB+Balancer so each attempt can hit a different endpoint
@@ -20,11 +22,35 @@ go vet ./... # static analysis
- **Sentinel errors**: canonical values live in sub-packages, root package re-exports as aliases
- **balancer.Transport** returns `(Middleware, *Closer)` — Closer must be tracked for health checker shutdown
- **Client.Close()** stops the health checker goroutine
- **Client.Patch()** — PATCH method, same pattern as Put/Post
- **NewFormRequest** — form-encoded request builder (`application/x-www-form-urlencoded`) with `GetBody` for retry
- **WithMaxResponseBody** — wraps `resp.Body` with `io.LimitedReader` to prevent OOM
- **middleware.RequestID()** — propagates request ID from context to outgoing `X-Request-Id` header
- **`internal/requestid`** — shared context key used by both `server` and `middleware` packages to avoid circular imports
### Server (`server/`)
- **Core pattern**: middleware is `func(http.Handler) http.Handler`
- **Server** wraps `http.Server` with `net.Listener`, graceful shutdown via signal handling, lifecycle hooks
- **Router** wraps `http.ServeMux` — supports groups with prefix + middleware inheritance, `Mount` for sub-handlers, `WithNotFoundHandler` for custom 404
- **Middleware chain** via `Chain(A, B, C)` — A outermost, C innermost (same as client side)
- **statusWriter** wraps `http.ResponseWriter` to capture status; implements `Unwrap()` for `http.ResponseController`
- **Defaults()** preset: RequestID → Recovery → Logging + production timeouts
- **HealthHandler** exposes `GET /healthz` (liveness) and `GET /readyz` (readiness with pluggable checkers)
- **CORS** middleware — preflight OPTIONS handling, `AllowOrigins`, `AllowMethods`, `AllowHeaders`, `ExposeHeaders`, `AllowCredentials`, `MaxAge`
- **RateLimit** middleware — per-key token bucket (`sync.Map`), IP from `X-Forwarded-For`, `WithRate`/`WithBurst`/`WithKeyFunc`, uses `internal/clock`
- **MaxBodySize** middleware — wraps `r.Body` via `http.MaxBytesReader`
- **Timeout** middleware — wraps `http.TimeoutHandler`, returns 503
- **WriteJSON** / **WriteError** — JSON response helpers in `server/respond.go`
## Conventions
- Functional options for all configuration
- Test helpers: `mockTransport(fn)` wrapping `middleware.RoundTripperFunc`
- Functional options for all configuration (client and server)
- Test helpers: `mockTransport(fn)` wrapping `middleware.RoundTripperFunc` (client), `httptest.NewRecorder`/`httptest.NewRequest` (server)
- Server tests use `waitForAddr(t, srv)` helper to poll until server is ready
- No external test frameworks — stdlib only
- Thread safety required (`sync.Mutex`/`atomic`)
- `internal/clock` for deterministic time testing
## See also
- `AGENTS.md` — universal AI agent guide with common tasks, gotchas, and ASCII diagrams

124
README.md
View File

@@ -1,6 +1,6 @@
# httpx
HTTP client for Go microservices. Retry, load balancing, circuit breaking, all as `http.RoundTripper` middleware. stdlib only, zero external deps.
HTTP client and server toolkit for Go microservices. Client side: retry, load balancing, circuit breaking, request ID propagation, response size limits — all as `http.RoundTripper` middleware. Server side: routing, middleware (request ID, recovery, logging, CORS, rate limiting, body limits, timeouts), health checks, JSON helpers, graceful shutdown. stdlib only, zero external deps.
```
go get git.codelab.vc/pkg/httpx
@@ -29,18 +29,50 @@ if err != nil {
var user User
resp.JSON(&user)
// PATCH request
resp, err = client.Patch(ctx, "/users/123", strings.NewReader(`{"name":"updated"}`))
// Form-encoded request (OAuth, webhooks, etc.)
req, _ := httpx.NewFormRequest(ctx, http.MethodPost, "/oauth/token", url.Values{
"grant_type": {"client_credentials"},
"scope": {"read write"},
})
resp, err = client.Do(ctx, req)
```
## Packages
Everything is a `func(http.RoundTripper) http.RoundTripper`. Use them with `httpx.Client` or plug into a plain `http.Client`.
### Client
Client middleware is `func(http.RoundTripper) http.RoundTripper`. Use them with `httpx.Client` or plug into a plain `http.Client`.
| Package | What it does |
|---------|-------------|
| `retry` | Exponential/constant backoff, Retry-After support. Idempotent methods only by default. |
| `balancer` | Round robin, failover, weighted random. Optional background health checks. |
| `circuitbreaker` | Per-host state machine (closed/open/half-open). Stops hammering dead endpoints. |
| `middleware` | Logging (slog), default headers, bearer/basic auth, panic recovery. |
| `middleware` | Logging (slog), default headers, bearer/basic auth, panic recovery, request ID propagation. |
### Server
Server middleware is `func(http.Handler) http.Handler`. The `server` package provides a production-ready HTTP server.
| Component | What it does |
|-----------|-------------|
| `server.Server` | Wraps `http.Server` with graceful shutdown, signal handling, lifecycle logging. |
| `server.Router` | Lightweight wrapper around `http.ServeMux` with groups, prefix routing, sub-router mounting. |
| `server.RequestID` | Assigns/propagates `X-Request-Id` (UUID v4 via `crypto/rand`). |
| `server.Recovery` | Recovers panics, returns 500, logs stack trace. |
| `server.Logging` | Structured request logging (method, path, status, duration, request ID). |
| `server.HealthHandler` | Liveness (`/healthz`) and readiness (`/readyz`) endpoints with pluggable checkers. |
| `server.CORS` | Cross-origin resource sharing with preflight handling and functional options. |
| `server.RateLimit` | Per-key token bucket rate limiting with IP extraction and `Retry-After`. |
| `server.MaxBodySize` | Limits request body size via `http.MaxBytesReader`. |
| `server.Timeout` | Context-based request timeout, returns 503 on expiry. |
| `server.WriteJSON` | JSON response helper, sets Content-Type and status. |
| `server.WriteError` | JSON error response (`{"error": "..."}`) helper. |
| `server.Defaults` | Production preset: RequestID → Recovery → Logging + sensible timeouts. |
The client assembles them in this order:
@@ -90,6 +122,92 @@ httpClient := &http.Client{
}
```
## Server
```go
logger := slog.Default()
r := server.NewRouter(
// Custom JSON 404 instead of plain text
server.WithNotFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
server.WriteError(w, 404, "not found")
})),
)
r.HandleFunc("GET /hello", func(w http.ResponseWriter, r *http.Request) {
server.WriteJSON(w, 200, map[string]string{"message": "world"})
})
// Groups with middleware
api := r.Group("/api/v1", authMiddleware)
api.HandleFunc("GET /users/{id}", getUser)
// Health checks
r.Mount("/", server.HealthHandler(
func() error { return db.Ping() },
))
srv := server.New(r,
append(server.Defaults(logger),
// Protection middleware
server.WithMiddleware(
server.CORS(
server.AllowOrigins("https://app.example.com"),
server.AllowMethods("GET", "POST", "PUT", "PATCH", "DELETE"),
server.AllowHeaders("Authorization", "Content-Type"),
server.MaxAge(3600),
),
server.RateLimit(
server.WithRate(100),
server.WithBurst(200),
),
server.MaxBodySize(1<<20), // 1 MB
server.Timeout(30*time.Second),
),
)...,
)
log.Fatal(srv.ListenAndServe()) // graceful shutdown on SIGINT/SIGTERM
```
## Client request ID propagation
In microservices, forward the incoming request ID to downstream calls:
```go
client := httpx.New(
httpx.WithMiddleware(middleware.RequestID()),
)
// In a server handler — the context already has the request ID from server.RequestID():
func handler(w http.ResponseWriter, r *http.Request) {
// ID is automatically forwarded as X-Request-Id
resp, err := client.Get(r.Context(), "https://downstream/api")
}
```
## Response body limit
Protect against OOM from unexpectedly large upstream responses:
```go
client := httpx.New(
httpx.WithMaxResponseBody(10 << 20), // 10 MB max
)
```
## Examples
See the [`examples/`](examples/) directory for runnable programs:
| Example | Description |
|---------|-------------|
| [`basic-client`](examples/basic-client/) | HTTP client with retry, timeout, logging, and response size limit |
| [`form-request`](examples/form-request/) | Form-encoded POST requests (OAuth, webhooks) |
| [`load-balancing`](examples/load-balancing/) | Multi-endpoint client with weighted balancing, circuit breaker, and health checks |
| [`server-basic`](examples/server-basic/) | Server with routing, groups, JSON helpers, health checks, and custom 404 |
| [`server-protected`](examples/server-protected/) | Production server with CORS, rate limiting, body limits, and timeouts |
| [`request-id-propagation`](examples/request-id-propagation/) | Request ID forwarding between server and client for distributed tracing |
## Requirements
Go 1.24+, stdlib only.

34
balancer/doc.go Normal file
View File

@@ -0,0 +1,34 @@
// Package balancer provides client-side load balancing as HTTP middleware.
//
// It distributes requests across multiple backend endpoints using pluggable
// strategies (round-robin, weighted, failover) with optional health checking.
//
// # Usage
//
// mw, closer := balancer.Transport(
// []balancer.Endpoint{
// {URL: "http://backend1:8080"},
// {URL: "http://backend2:8080"},
// },
// balancer.WithStrategy(balancer.RoundRobin()),
// balancer.WithHealthCheck(5 * time.Second),
// )
// defer closer.Close()
// transport := mw(http.DefaultTransport)
//
// # Strategies
//
// - RoundRobin — cycles through healthy endpoints
// - Weighted — distributes based on endpoint Weight field
// - Failover — prefers primary, falls back to secondaries
//
// # Health checking
//
// When enabled, a background goroutine periodically probes each endpoint.
// The returned Closer must be closed to stop the health checker goroutine.
// In httpx.Client, this is handled by Client.Close().
//
// # Sentinel errors
//
// ErrNoHealthy is returned when no healthy endpoints are available.
package balancer

27
circuitbreaker/doc.go Normal file
View File

@@ -0,0 +1,27 @@
// Package circuitbreaker provides a per-host circuit breaker as HTTP middleware.
//
// The circuit breaker monitors request failures and temporarily blocks requests
// to unhealthy hosts, allowing them time to recover before retrying.
//
// # State machine
//
// - Closed — normal operation, requests pass through
// - Open — too many failures, requests are rejected with ErrCircuitOpen
// - HalfOpen — after a cooldown period, one probe request is allowed through
//
// # Usage
//
// mw := circuitbreaker.Transport(
// circuitbreaker.WithThreshold(5),
// circuitbreaker.WithTimeout(30 * time.Second),
// )
// transport := mw(http.DefaultTransport)
//
// The circuit breaker is per-host: each unique request host gets its own
// independent breaker state machine stored in a sync.Map.
//
// # Sentinel errors
//
// ErrCircuitOpen is returned when a request is rejected because the circuit
// is in the Open state.
package circuitbreaker

View File

@@ -20,6 +20,7 @@ type Client struct {
baseURL string
errorMapper ErrorMapper
balancerCloser *balancer.Closer
maxResponseBody int64
}
// New creates a new Client with the given options.
@@ -80,6 +81,7 @@ func New(opts ...Option) *Client {
baseURL: o.baseURL,
errorMapper: o.errorMapper,
balancerCloser: balancerCloser,
maxResponseBody: o.maxResponseBody,
}
}
@@ -99,6 +101,13 @@ func (c *Client) Do(ctx context.Context, req *http.Request) (*Response, error) {
}
}
if c.maxResponseBody > 0 {
resp.Body = &limitedReadCloser{
R: io.LimitedReader{R: resp.Body, N: c.maxResponseBody},
C: resp.Body,
}
}
r := newResponse(resp)
if c.errorMapper != nil {
@@ -142,6 +151,15 @@ func (c *Client) Put(ctx context.Context, url string, body io.Reader) (*Response
return c.Do(ctx, req)
}
// Patch performs a PATCH request to the given URL with the given body.
func (c *Client) Patch(ctx context.Context, url string, body io.Reader) (*Response, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodPatch, url, body)
if err != nil {
return nil, err
}
return c.Do(ctx, req)
}
// Delete performs a DELETE request to the given URL.
func (c *Client) Delete(ctx context.Context, url string) (*Response, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, url, nil)

View File

@@ -24,6 +24,7 @@ type clientOptions struct {
enableCB bool
endpoints []balancer.Endpoint
balancerOpts []balancer.Option
maxResponseBody int64
}
// Option configures a Client.
@@ -85,3 +86,11 @@ func WithEndpoints(eps ...balancer.Endpoint) Option {
func WithBalancer(opts ...balancer.Option) Option {
return func(o *clientOptions) { o.balancerOpts = opts }
}
// WithMaxResponseBody limits the number of bytes read from response bodies
// by Response.Bytes (and by extension String, JSON, XML). If the response
// body exceeds n bytes, reading stops and returns an error.
// A value of 0 means no limit (the default).
func WithMaxResponseBody(n int64) Option {
return func(o *clientOptions) { o.maxResponseBody = n }
}

45
client_patch_test.go Normal file
View File

@@ -0,0 +1,45 @@
package httpx_test
import (
"context"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"git.codelab.vc/pkg/httpx"
)
func TestClient_Patch(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPatch {
t.Errorf("expected PATCH, got %s", r.Method)
}
b, _ := io.ReadAll(r.Body)
if string(b) != `{"name":"updated"}` {
t.Errorf("expected body %q, got %q", `{"name":"updated"}`, string(b))
}
w.WriteHeader(http.StatusOK)
fmt.Fprint(w, "patched")
}))
defer srv.Close()
client := httpx.New()
resp, err := client.Patch(context.Background(), srv.URL+"/item/1", strings.NewReader(`{"name":"updated"}`))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.StatusCode != http.StatusOK {
t.Errorf("expected status 200, got %d", resp.StatusCode)
}
body, err := resp.String()
if err != nil {
t.Fatalf("reading body: %v", err)
}
if body != "patched" {
t.Errorf("expected body %q, got %q", "patched", body)
}
}

39
doc.go Normal file
View File

@@ -0,0 +1,39 @@
// Package httpx provides a high-level HTTP client with composable middleware
// for retry, circuit breaking, load balancing, structured logging, and more.
//
// The client is configured via functional options and assembled as a middleware
// chain around a standard http.RoundTripper:
//
// Logging → User Middlewares → Retry → Circuit Breaker → Balancer → Transport
//
// # Quick start
//
// client := httpx.New(
// httpx.WithBaseURL("https://api.example.com"),
// httpx.WithTimeout(10 * time.Second),
// httpx.WithRetry(),
// httpx.WithCircuitBreaker(),
// )
// defer client.Close()
//
// resp, err := client.Get(ctx, "/users/1")
//
// # Request builders
//
// NewJSONRequest and NewFormRequest create requests with appropriate
// Content-Type headers and GetBody set for retry compatibility.
//
// # Error handling
//
// Failed requests return *httpx.Error with structured fields (Op, URL,
// StatusCode). Sentinel errors ErrRetryExhausted, ErrCircuitOpen, and
// ErrNoHealthy can be checked with errors.Is.
//
// # Sub-packages
//
// - middleware — client-side middleware (logging, auth, headers, recovery, request ID)
// - retry — configurable retry with backoff and Retry-After support
// - circuitbreaker — per-host circuit breaker (closed → open → half-open)
// - balancer — client-side load balancing with health checking
// - server — production HTTP server with router, middleware, and graceful shutdown
package httpx

View File

@@ -0,0 +1,67 @@
// Basic HTTP client with retry, timeout, and structured logging.
package main
import (
"context"
"fmt"
"log"
"log/slog"
"time"
"git.codelab.vc/pkg/httpx"
"git.codelab.vc/pkg/httpx/middleware"
"git.codelab.vc/pkg/httpx/retry"
)
func main() {
client := httpx.New(
httpx.WithBaseURL("https://httpbin.org"),
httpx.WithTimeout(10*time.Second),
httpx.WithRetry(
retry.WithMaxAttempts(3),
retry.WithBackoff(retry.ExponentialBackoff(100*time.Millisecond, 2*time.Second, true)),
),
httpx.WithMiddleware(
middleware.UserAgent("httpx-example/1.0"),
),
httpx.WithMaxResponseBody(1<<20), // 1 MB limit
httpx.WithLogger(slog.Default()),
)
defer client.Close()
ctx := context.Background()
// GET request.
resp, err := client.Get(ctx, "/get")
if err != nil {
log.Fatal(err)
}
body, err := resp.String()
if err != nil {
log.Fatal(err)
}
fmt.Printf("GET /get → %d (%d bytes)\n", resp.StatusCode, len(body))
// POST with JSON body.
type payload struct {
Name string `json:"name"`
Email string `json:"email"`
}
req, err := httpx.NewJSONRequest(ctx, "POST", "/post", payload{
Name: "Alice",
Email: "alice@example.com",
})
if err != nil {
log.Fatal(err)
}
resp, err = client.Do(ctx, req)
if err != nil {
log.Fatal(err)
}
defer resp.Close()
fmt.Printf("POST /post → %d\n", resp.StatusCode)
}

View File

@@ -0,0 +1,41 @@
// Demonstrates form-encoded requests for OAuth token endpoints and similar APIs.
package main
import (
"context"
"fmt"
"log"
"net/http"
"net/url"
"git.codelab.vc/pkg/httpx"
)
func main() {
client := httpx.New(
httpx.WithBaseURL("https://httpbin.org"),
)
defer client.Close()
ctx := context.Background()
// Form-encoded POST, common for OAuth token endpoints.
req, err := httpx.NewFormRequest(ctx, http.MethodPost, "/post", url.Values{
"grant_type": {"client_credentials"},
"client_id": {"my-app"},
"client_secret": {"secret"},
"scope": {"read write"},
})
if err != nil {
log.Fatal(err)
}
resp, err := client.Do(ctx, req)
if err != nil {
log.Fatal(err)
}
defer resp.Close()
body, _ := resp.String()
fmt.Printf("Status: %d\nBody: %s\n", resp.StatusCode, body)
}

View File

@@ -0,0 +1,54 @@
// Demonstrates load balancing across multiple backend endpoints with
// circuit breaking and health-checked failover.
package main
import (
"context"
"fmt"
"log"
"log/slog"
"time"
"git.codelab.vc/pkg/httpx"
"git.codelab.vc/pkg/httpx/balancer"
"git.codelab.vc/pkg/httpx/circuitbreaker"
"git.codelab.vc/pkg/httpx/retry"
)
func main() {
client := httpx.New(
httpx.WithEndpoints(
balancer.Endpoint{URL: "http://localhost:8081", Weight: 3},
balancer.Endpoint{URL: "http://localhost:8082", Weight: 1},
),
httpx.WithBalancer(
balancer.WithStrategy(balancer.WeightedRandom()),
balancer.WithHealthCheck(
balancer.WithHealthInterval(5*time.Second),
balancer.WithHealthPath("/healthz"),
),
),
httpx.WithCircuitBreaker(
circuitbreaker.WithFailureThreshold(5),
circuitbreaker.WithOpenDuration(30*time.Second),
),
httpx.WithRetry(
retry.WithMaxAttempts(3),
),
httpx.WithLogger(slog.Default()),
)
defer client.Close()
ctx := context.Background()
for i := range 5 {
resp, err := client.Get(ctx, fmt.Sprintf("/api/item/%d", i))
if err != nil {
log.Printf("request %d failed: %v", i, err)
continue
}
body, _ := resp.String()
fmt.Printf("request %d → %d: %s\n", i, resp.StatusCode, body)
}
}

View File

@@ -0,0 +1,54 @@
// Demonstrates request ID propagation between server and client.
// The server assigns a request ID to incoming requests, and the client
// middleware forwards it to downstream services via X-Request-Id header.
package main
import (
"fmt"
"log"
"log/slog"
"net/http"
"git.codelab.vc/pkg/httpx"
"git.codelab.vc/pkg/httpx/middleware"
"git.codelab.vc/pkg/httpx/server"
)
func main() {
logger := slog.Default()
// Client that propagates request IDs from context.
client := httpx.New(
httpx.WithBaseURL("http://localhost:9090"),
httpx.WithMiddleware(
middleware.RequestID(), // Picks up ID from context, sets X-Request-Id.
),
)
defer client.Close()
r := server.NewRouter()
r.HandleFunc("GET /proxy", func(w http.ResponseWriter, r *http.Request) {
// The request ID is in r.Context() thanks to server.RequestID().
id := server.RequestIDFromContext(r.Context())
logger.Info("handling request", "request_id", id)
// Client automatically forwards the request ID to downstream.
resp, err := client.Get(r.Context(), "/downstream")
if err != nil {
server.WriteError(w, http.StatusBadGateway, fmt.Sprintf("downstream error: %v", err))
return
}
defer resp.Close()
body, _ := resp.String()
server.WriteJSON(w, http.StatusOK, map[string]string{
"request_id": id,
"downstream_response": body,
})
})
// Server with RequestID middleware that assigns IDs to incoming requests.
srv := server.New(r, server.Defaults(logger)...)
log.Fatal(srv.ListenAndServe())
}

View File

@@ -0,0 +1,54 @@
// Basic HTTP server with routing, middleware, health checks, and graceful shutdown.
package main
import (
"log"
"log/slog"
"net/http"
"git.codelab.vc/pkg/httpx/server"
)
func main() {
logger := slog.Default()
r := server.NewRouter(
server.WithNotFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
server.WriteError(w, http.StatusNotFound, "resource not found")
})),
)
// Public endpoints.
r.HandleFunc("GET /hello", func(w http.ResponseWriter, _ *http.Request) {
server.WriteJSON(w, http.StatusOK, map[string]string{
"message": "Hello, World!",
})
})
// API group with shared prefix.
api := r.Group("/api/v1")
api.HandleFunc("GET /users/{id}", getUser)
api.HandleFunc("POST /users", createUser)
// Health checks.
r.Mount("/", server.HealthHandler())
// Server with production defaults (RequestID → Recovery → Logging).
srv := server.New(r, server.Defaults(logger)...)
log.Fatal(srv.ListenAndServe())
}
func getUser(w http.ResponseWriter, r *http.Request) {
id := r.PathValue("id")
server.WriteJSON(w, http.StatusOK, map[string]string{
"id": id,
"name": "Alice",
})
}
func createUser(w http.ResponseWriter, r *http.Request) {
server.WriteJSON(w, http.StatusCreated, map[string]any{
"id": 1,
"message": "user created",
})
}

View File

@@ -0,0 +1,61 @@
// Production server with CORS, rate limiting, body size limits, and timeouts.
package main
import (
"log"
"log/slog"
"net/http"
"time"
"git.codelab.vc/pkg/httpx/server"
)
func main() {
logger := slog.Default()
r := server.NewRouter(
server.WithNotFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
server.WriteError(w, http.StatusNotFound, "not found")
})),
)
r.HandleFunc("GET /api/data", func(w http.ResponseWriter, _ *http.Request) {
server.WriteJSON(w, http.StatusOK, map[string]string{"status": "ok"})
})
r.HandleFunc("POST /api/upload", func(w http.ResponseWriter, r *http.Request) {
// Body is already limited by MaxBodySize middleware.
server.WriteJSON(w, http.StatusAccepted, map[string]string{"status": "received"})
})
r.Mount("/", server.HealthHandler())
srv := server.New(r,
append(
server.Defaults(logger),
server.WithMiddleware(
// CORS for browser-facing APIs.
server.CORS(
server.AllowOrigins("https://app.example.com", "https://admin.example.com"),
server.AllowMethods("GET", "POST", "PUT", "PATCH", "DELETE"),
server.AllowHeaders("Authorization", "Content-Type"),
server.ExposeHeaders("X-Request-Id"),
server.AllowCredentials(true),
server.MaxAge(3600),
),
// Rate limit: 100 req/s per IP, burst of 200.
server.RateLimit(
server.WithRate(100),
server.WithBurst(200),
),
// Limit request body to 1 MB.
server.MaxBodySize(1<<20),
// Per-request timeout of 30 seconds.
server.Timeout(30*time.Second),
),
server.WithAddr(":8080"),
)...,
)
log.Fatal(srv.ListenAndServe())
}

View File

@@ -0,0 +1,19 @@
// Package requestid provides a shared context key for request IDs,
// allowing both client and server packages to access request IDs
// without circular imports.
package requestid
import "context"
type key struct{}
// NewContext returns a context with the given request ID.
func NewContext(ctx context.Context, id string) context.Context {
return context.WithValue(ctx, key{}, id)
}
// FromContext returns the request ID from ctx, or empty string if not set.
func FromContext(ctx context.Context) string {
id, _ := ctx.Value(key{}).(string)
return id
}

28
middleware/doc.go Normal file
View File

@@ -0,0 +1,28 @@
// Package middleware provides client-side HTTP middleware for use with
// httpx.Client or any http.RoundTripper-based transport chain.
//
// Each middleware is a function of type func(http.RoundTripper) http.RoundTripper.
// Compose them with Chain:
//
// chain := middleware.Chain(
// middleware.Logging(logger),
// middleware.Recovery(),
// middleware.UserAgent("my-service/1.0"),
// )
// transport := chain(http.DefaultTransport)
//
// # Available middleware
//
// - Logging — structured request/response logging via slog
// - Recovery — panic recovery, converts panics to errors
// - DefaultHeaders — adds default headers to outgoing requests
// - UserAgent — sets User-Agent header
// - BearerAuth — dynamic Bearer token authentication
// - BasicAuth — HTTP Basic authentication
// - RequestID — propagates request ID from context to X-Request-Id header
//
// # RoundTripperFunc
//
// RoundTripperFunc adapts plain functions to http.RoundTripper, similar to
// http.HandlerFunc. Useful for testing and inline middleware.
package middleware

23
middleware/requestid.go Normal file
View File

@@ -0,0 +1,23 @@
package middleware
import (
"net/http"
"git.codelab.vc/pkg/httpx/internal/requestid"
)
// RequestID returns a middleware that propagates the request ID from the
// request context to the outgoing X-Request-Id header. This pairs with
// the server.RequestID middleware: the server stores the ID in the context,
// and the client middleware forwards it to downstream services.
func RequestID() Middleware {
return func(next http.RoundTripper) http.RoundTripper {
return RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
if id := requestid.FromContext(req.Context()); id != "" {
req = req.Clone(req.Context())
req.Header.Set("X-Request-Id", id)
}
return next.RoundTrip(req)
})
}
}

View File

@@ -0,0 +1,69 @@
package middleware_test
import (
"context"
"net/http"
"testing"
"git.codelab.vc/pkg/httpx/internal/requestid"
"git.codelab.vc/pkg/httpx/middleware"
)
func TestRequestID(t *testing.T) {
t.Run("propagates ID from context", func(t *testing.T) {
var gotHeader string
base := middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
gotHeader = req.Header.Get("X-Request-Id")
return &http.Response{StatusCode: http.StatusOK, Body: http.NoBody}, nil
})
mw := middleware.RequestID()(base)
ctx := requestid.NewContext(context.Background(), "test-id-123")
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "http://example.com", nil)
_, err := mw.RoundTrip(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if gotHeader != "test-id-123" {
t.Fatalf("X-Request-Id = %q, want %q", gotHeader, "test-id-123")
}
})
t.Run("no ID in context skips header", func(t *testing.T) {
var gotHeader string
base := middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
gotHeader = req.Header.Get("X-Request-Id")
return &http.Response{StatusCode: http.StatusOK, Body: http.NoBody}, nil
})
mw := middleware.RequestID()(base)
req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://example.com", nil)
_, err := mw.RoundTrip(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if gotHeader != "" {
t.Fatalf("expected no X-Request-Id header, got %q", gotHeader)
}
})
t.Run("does not mutate original request", func(t *testing.T) {
base := middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
return &http.Response{StatusCode: http.StatusOK, Body: http.NoBody}, nil
})
mw := middleware.RequestID()(base)
ctx := requestid.NewContext(context.Background(), "test-id")
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "http://example.com", nil)
_, _ = mw.RoundTrip(req)
if req.Header.Get("X-Request-Id") != "" {
t.Fatal("original request was mutated")
}
})
}

View File

@@ -7,6 +7,7 @@ import (
"fmt"
"io"
"net/http"
"net/url"
)
// NewRequest creates an http.Request with context. It is a convenience
@@ -32,3 +33,20 @@ func NewJSONRequest(ctx context.Context, method, url string, body any) (*http.Re
}
return req, nil
}
// NewFormRequest creates an http.Request with a form-encoded body and
// sets Content-Type to application/x-www-form-urlencoded.
// The GetBody function is set so that the request can be retried.
func NewFormRequest(ctx context.Context, method, rawURL string, values url.Values) (*http.Request, error) {
encoded := values.Encode()
b := []byte(encoded)
req, err := http.NewRequestWithContext(ctx, method, rawURL, bytes.NewReader(b))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.GetBody = func() (io.ReadCloser, error) {
return io.NopCloser(bytes.NewReader(b)), nil
}
return req, nil
}

80
request_form_test.go Normal file
View File

@@ -0,0 +1,80 @@
package httpx_test
import (
"context"
"io"
"net/http"
"net/url"
"testing"
"git.codelab.vc/pkg/httpx"
)
func TestNewFormRequest(t *testing.T) {
t.Run("body is form-encoded", func(t *testing.T) {
values := url.Values{"username": {"alice"}, "scope": {"read"}}
req, err := httpx.NewFormRequest(context.Background(), http.MethodPost, "http://example.com/token", values)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
body, err := io.ReadAll(req.Body)
if err != nil {
t.Fatalf("reading body: %v", err)
}
parsed, err := url.ParseQuery(string(body))
if err != nil {
t.Fatalf("parsing form: %v", err)
}
if parsed.Get("username") != "alice" {
t.Errorf("username = %q, want %q", parsed.Get("username"), "alice")
}
if parsed.Get("scope") != "read" {
t.Errorf("scope = %q, want %q", parsed.Get("scope"), "read")
}
})
t.Run("content type is set", func(t *testing.T) {
values := url.Values{"key": {"value"}}
req, err := httpx.NewFormRequest(context.Background(), http.MethodPost, "http://example.com", values)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
ct := req.Header.Get("Content-Type")
if ct != "application/x-www-form-urlencoded" {
t.Errorf("Content-Type = %q, want %q", ct, "application/x-www-form-urlencoded")
}
})
t.Run("GetBody works for retry", func(t *testing.T) {
values := url.Values{"key": {"value"}}
req, err := httpx.NewFormRequest(context.Background(), http.MethodPost, "http://example.com", values)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if req.GetBody == nil {
t.Fatal("GetBody is nil")
}
b1, err := io.ReadAll(req.Body)
if err != nil {
t.Fatalf("reading body: %v", err)
}
body2, err := req.GetBody()
if err != nil {
t.Fatalf("GetBody(): %v", err)
}
b2, err := io.ReadAll(body2)
if err != nil {
t.Fatalf("reading body2: %v", err)
}
if string(b1) != string(b2) {
t.Errorf("GetBody returned different data: %q vs %q", b1, b2)
}
})
}

View File

@@ -97,3 +97,18 @@ func (r *Response) BodyReader() io.Reader {
}
return r.Body
}
// limitedReadCloser wraps an io.LimitedReader with a separate Closer
// so the original body can be closed.
type limitedReadCloser struct {
R io.LimitedReader
C io.Closer
}
func (l *limitedReadCloser) Read(p []byte) (int, error) {
return l.R.Read(p)
}
func (l *limitedReadCloser) Close() error {
return l.C.Close()
}

76
response_limit_test.go Normal file
View File

@@ -0,0 +1,76 @@
package httpx_test
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"git.codelab.vc/pkg/httpx"
)
func TestClient_MaxResponseBody(t *testing.T) {
t.Run("allows response within limit", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
fmt.Fprint(w, "hello")
}))
defer srv.Close()
client := httpx.New(httpx.WithMaxResponseBody(1024))
resp, err := client.Get(context.Background(), srv.URL+"/")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
body, err := resp.String()
if err != nil {
t.Fatalf("reading body: %v", err)
}
if body != "hello" {
t.Fatalf("body = %q, want %q", body, "hello")
}
})
t.Run("truncates response exceeding limit", func(t *testing.T) {
largeBody := strings.Repeat("x", 1000)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
fmt.Fprint(w, largeBody)
}))
defer srv.Close()
client := httpx.New(httpx.WithMaxResponseBody(100))
resp, err := client.Get(context.Background(), srv.URL+"/")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
b, err := resp.Bytes()
if err != nil {
t.Fatalf("reading body: %v", err)
}
if len(b) != 100 {
t.Fatalf("body length = %d, want %d", len(b), 100)
}
})
t.Run("no limit when zero", func(t *testing.T) {
largeBody := strings.Repeat("x", 10000)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
fmt.Fprint(w, largeBody)
}))
defer srv.Close()
client := httpx.New()
resp, err := client.Get(context.Background(), srv.URL+"/")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
b, err := resp.Bytes()
if err != nil {
t.Fatalf("reading body: %v", err)
}
if len(b) != 10000 {
t.Fatalf("body length = %d, want %d", len(b), 10000)
}
})
}

31
retry/doc.go Normal file
View File

@@ -0,0 +1,31 @@
// Package retry provides configurable HTTP request retry as client middleware.
//
// The retry middleware wraps an http.RoundTripper and automatically retries
// failed requests based on a configurable policy, with exponential backoff
// and optional jitter.
//
// # Usage
//
// mw := retry.Transport(
// retry.WithMaxAttempts(3),
// retry.WithBackoff(retry.ExponentialBackoff(100*time.Millisecond, 5*time.Second)),
// )
// transport := mw(http.DefaultTransport)
//
// # Retry-After
//
// The retry middleware respects the Retry-After response header. If a server
// returns 429 or 503 with Retry-After, the delay from the header overrides
// the backoff strategy.
//
// # Request bodies
//
// For requests with bodies to be retried, the request must have GetBody set.
// Use httpx.NewJSONRequest or httpx.NewFormRequest which set GetBody
// automatically.
//
// # Sentinel errors
//
// ErrRetryExhausted is returned when all attempts fail. The original error
// is wrapped and accessible via errors.Unwrap.
package retry

38
server/doc.go Normal file
View File

@@ -0,0 +1,38 @@
// Package server provides a production-ready HTTP server with graceful
// shutdown, middleware composition, routing, and JSON response helpers.
//
// # Server
//
// Server wraps http.Server with net.Listener, signal-based graceful shutdown
// (SIGINT/SIGTERM), and lifecycle hooks. It is configured via functional options:
//
// srv := server.New(handler,
// server.WithAddr(":8080"),
// server.Defaults(logger),
// )
// srv.ListenAndServe()
//
// # Router
//
// Router wraps http.ServeMux with middleware groups, prefix-based route groups,
// and sub-handler mounting. It supports Go 1.22+ method-based patterns:
//
// r := server.NewRouter()
// r.HandleFunc("GET /users/{id}", getUser)
//
// api := r.Group("/api/v1", authMiddleware)
// api.HandleFunc("GET /items", listItems)
//
// # Middleware
//
// Server middleware follows the func(http.Handler) http.Handler pattern.
// Available middleware: RequestID, Recovery, Logging, CORS, RateLimit,
// MaxBodySize, Timeout. Use Chain to compose them:
//
// chain := server.Chain(server.RequestID(), server.Recovery(logger), server.Logging(logger))
//
// # Response helpers
//
// WriteJSON and WriteError provide JSON response writing with proper
// Content-Type headers.
package server

55
server/health.go Normal file
View File

@@ -0,0 +1,55 @@
package server
import (
"encoding/json"
"net/http"
)
// ReadinessChecker is a function that reports whether a dependency is ready.
// Return nil if healthy, or an error describing the problem.
type ReadinessChecker func() error
// HealthHandler returns an http.Handler that exposes liveness and readiness
// endpoints:
//
// - GET /healthz — liveness check, always returns 200 OK
// - GET /readyz — readiness check, returns 200 if all checkers pass, 503 otherwise
func HealthHandler(checkers ...ReadinessChecker) http.Handler {
mux := http.NewServeMux()
mux.HandleFunc("GET /healthz", func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(healthResponse{Status: "ok"})
})
mux.HandleFunc("GET /readyz", func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")
var errs []string
for _, check := range checkers {
if err := check(); err != nil {
errs = append(errs, err.Error())
}
}
if len(errs) > 0 {
w.WriteHeader(http.StatusServiceUnavailable)
_ = json.NewEncoder(w).Encode(healthResponse{
Status: "unavailable",
Errors: errs,
})
return
}
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(healthResponse{Status: "ok"})
})
return mux
}
type healthResponse struct {
Status string `json:"status"`
Errors []string `json:"errors,omitempty"`
}

166
server/health_test.go Normal file
View File

@@ -0,0 +1,166 @@
package server_test
import (
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"testing"
"git.codelab.vc/pkg/httpx/server"
)
func TestHealthHandler(t *testing.T) {
t.Run("liveness always returns 200", func(t *testing.T) {
h := server.HealthHandler()
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/healthz", nil)
h.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
}
var resp map[string]any
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatalf("decode failed: %v", err)
}
if resp["status"] != "ok" {
t.Fatalf("got status %q, want %q", resp["status"], "ok")
}
})
t.Run("readiness returns 200 when all checks pass", func(t *testing.T) {
h := server.HealthHandler(
func() error { return nil },
func() error { return nil },
)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/readyz", nil)
h.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
}
})
t.Run("readiness returns 503 when a check fails", func(t *testing.T) {
h := server.HealthHandler(
func() error { return nil },
func() error { return errors.New("db down") },
)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/readyz", nil)
h.ServeHTTP(w, req)
if w.Code != http.StatusServiceUnavailable {
t.Fatalf("got status %d, want %d", w.Code, http.StatusServiceUnavailable)
}
var resp map[string]any
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatalf("decode failed: %v", err)
}
if resp["status"] != "unavailable" {
t.Fatalf("got status %q, want %q", resp["status"], "unavailable")
}
errs, ok := resp["errors"].([]any)
if !ok || len(errs) != 1 {
t.Fatalf("expected 1 error, got %v", resp["errors"])
}
if errs[0] != "db down" {
t.Fatalf("got error %q, want %q", errs[0], "db down")
}
})
t.Run("readiness returns 200 with no checkers", func(t *testing.T) {
h := server.HealthHandler()
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/readyz", nil)
h.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
}
})
}
func TestHealth_MultipleFailingCheckers(t *testing.T) {
h := server.HealthHandler(
func() error { return errors.New("db down") },
func() error { return errors.New("cache down") },
)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/readyz", nil)
h.ServeHTTP(w, req)
if w.Code != http.StatusServiceUnavailable {
t.Fatalf("got status %d, want %d", w.Code, http.StatusServiceUnavailable)
}
var resp map[string]any
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatalf("decode failed: %v", err)
}
errs, ok := resp["errors"].([]any)
if !ok || len(errs) != 2 {
t.Fatalf("expected 2 errors, got %v", resp["errors"])
}
errStrs := make(map[string]bool)
for _, e := range errs {
errStrs[e.(string)] = true
}
if !errStrs["db down"] || !errStrs["cache down"] {
t.Fatalf("expected 'db down' and 'cache down', got %v", errs)
}
}
func TestHealth_LivenessContentType(t *testing.T) {
h := server.HealthHandler()
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/healthz", nil)
h.ServeHTTP(w, req)
ct := w.Header().Get("Content-Type")
if ct != "application/json" {
t.Fatalf("got Content-Type %q, want %q", ct, "application/json")
}
}
func TestHealth_ReadinessContentType(t *testing.T) {
h := server.HealthHandler()
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/readyz", nil)
h.ServeHTTP(w, req)
ct := w.Header().Get("Content-Type")
if ct != "application/json" {
t.Fatalf("got Content-Type %q, want %q", ct, "application/json")
}
}
func TestHealth_PostMethodNotAllowed(t *testing.T) {
h := server.HealthHandler()
for _, path := range []string{"/healthz", "/readyz"} {
t.Run("POST "+path, func(t *testing.T) {
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, path, nil)
h.ServeHTTP(w, req)
// ServeMux with "GET /healthz" pattern should reject POST.
if w.Code == http.StatusOK {
t.Fatalf("POST %s should not return 200, got %d", path, w.Code)
}
})
}
}

56
server/middleware.go Normal file
View File

@@ -0,0 +1,56 @@
package server
import "net/http"
// Middleware wraps an http.Handler to add behavior.
// This is the server-side counterpart of the client middleware type
// func(http.RoundTripper) http.RoundTripper.
type Middleware func(http.Handler) http.Handler
// Chain composes middlewares so that Chain(A, B, C)(handler) == A(B(C(handler))).
// Middlewares are applied from right to left: C wraps handler first, then B wraps
// the result, then A wraps last. This means A is the outermost layer and sees
// every request first.
func Chain(mws ...Middleware) Middleware {
return func(h http.Handler) http.Handler {
for i := len(mws) - 1; i >= 0; i-- {
h = mws[i](h)
}
return h
}
}
// statusWriter wraps http.ResponseWriter to capture the response status code.
// It implements Unwrap() so that http.ResponseController can access the
// underlying ResponseWriter's optional interfaces (Flusher, Hijacker, etc.).
type statusWriter struct {
http.ResponseWriter
status int
written bool
}
// WriteHeader captures the status code and delegates to the underlying writer.
func (w *statusWriter) WriteHeader(code int) {
if !w.written {
w.status = code
w.written = true
}
w.ResponseWriter.WriteHeader(code)
}
// Write delegates to the underlying writer, defaulting status to 200 if
// WriteHeader was not called explicitly.
func (w *statusWriter) Write(b []byte) (int, error) {
if !w.written {
w.status = http.StatusOK
w.written = true
}
return w.ResponseWriter.Write(b)
}
// Unwrap returns the underlying ResponseWriter. This is required for
// http.ResponseController to detect optional interfaces like http.Flusher
// and http.Hijacker on the original writer.
func (w *statusWriter) Unwrap() http.ResponseWriter {
return w.ResponseWriter
}

View File

@@ -0,0 +1,15 @@
package server
import "net/http"
// MaxBodySize returns a middleware that limits the size of incoming request
// bodies. If the body exceeds n bytes, the server returns 413 Request Entity
// Too Large. It wraps the body with http.MaxBytesReader.
func MaxBodySize(n int64) Middleware {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, n)
next.ServeHTTP(w, r)
})
}
}

View File

@@ -0,0 +1,61 @@
package server_test
import (
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"git.codelab.vc/pkg/httpx/server"
)
func TestMaxBodySize(t *testing.T) {
t.Run("allows body within limit", func(t *testing.T) {
handler := server.MaxBodySize(1024)(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
b, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
w.WriteHeader(http.StatusOK)
w.Write(b)
}),
)
body := strings.NewReader("hello")
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/", body)
handler.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
}
if w.Body.String() != "hello" {
t.Fatalf("got body %q, want %q", w.Body.String(), "hello")
}
})
t.Run("rejects body exceeding limit", func(t *testing.T) {
handler := server.MaxBodySize(5)(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, "body too large", http.StatusRequestEntityTooLarge)
return
}
w.WriteHeader(http.StatusOK)
}),
)
body := strings.NewReader("this is longer than 5 bytes")
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/", body)
handler.ServeHTTP(w, req)
if w.Code != http.StatusRequestEntityTooLarge {
t.Fatalf("got status %d, want %d", w.Code, http.StatusRequestEntityTooLarge)
}
})
}

128
server/middleware_cors.go Normal file
View File

@@ -0,0 +1,128 @@
package server
import (
"net/http"
"strconv"
"strings"
)
type corsOptions struct {
allowOrigins []string
allowMethods []string
allowHeaders []string
exposeHeaders []string
allowCredentials bool
maxAge int
}
// CORSOption configures the CORS middleware.
type CORSOption func(*corsOptions)
// AllowOrigins sets the allowed origins. Use "*" to allow any origin.
// Default is no origins (CORS disabled).
func AllowOrigins(origins ...string) CORSOption {
return func(o *corsOptions) { o.allowOrigins = origins }
}
// AllowMethods sets the allowed HTTP methods for preflight requests.
// Default is GET, POST, HEAD.
func AllowMethods(methods ...string) CORSOption {
return func(o *corsOptions) { o.allowMethods = methods }
}
// AllowHeaders sets the allowed request headers for preflight requests.
func AllowHeaders(headers ...string) CORSOption {
return func(o *corsOptions) { o.allowHeaders = headers }
}
// ExposeHeaders sets headers that browsers are allowed to access.
func ExposeHeaders(headers ...string) CORSOption {
return func(o *corsOptions) { o.exposeHeaders = headers }
}
// AllowCredentials indicates whether the response to the request can be
// exposed when the credentials flag is true.
func AllowCredentials(allow bool) CORSOption {
return func(o *corsOptions) { o.allowCredentials = allow }
}
// MaxAge sets the maximum time (in seconds) a preflight result can be cached.
func MaxAge(seconds int) CORSOption {
return func(o *corsOptions) { o.maxAge = seconds }
}
// CORS returns a middleware that handles Cross-Origin Resource Sharing.
// It processes preflight OPTIONS requests and sets the appropriate
// Access-Control-* response headers.
func CORS(opts ...CORSOption) Middleware {
o := &corsOptions{
allowMethods: []string{"GET", "POST", "HEAD"},
}
for _, opt := range opts {
opt(o)
}
allowedOrigins := make(map[string]struct{}, len(o.allowOrigins))
allowAll := false
for _, origin := range o.allowOrigins {
if origin == "*" {
allowAll = true
}
allowedOrigins[origin] = struct{}{}
}
methods := strings.Join(o.allowMethods, ", ")
headers := strings.Join(o.allowHeaders, ", ")
expose := strings.Join(o.exposeHeaders, ", ")
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
origin := r.Header.Get("Origin")
if origin == "" {
next.ServeHTTP(w, r)
return
}
allowed := allowAll
if !allowed {
_, allowed = allowedOrigins[origin]
}
if !allowed {
next.ServeHTTP(w, r)
return
}
// Set the allowed origin. When credentials are enabled,
// we must echo the specific origin, not "*".
if allowAll && !o.allowCredentials {
w.Header().Set("Access-Control-Allow-Origin", "*")
} else {
w.Header().Set("Access-Control-Allow-Origin", origin)
}
if o.allowCredentials {
w.Header().Set("Access-Control-Allow-Credentials", "true")
}
if expose != "" {
w.Header().Set("Access-Control-Expose-Headers", expose)
}
w.Header().Add("Vary", "Origin")
// Handle preflight.
if r.Method == http.MethodOptions && r.Header.Get("Access-Control-Request-Method") != "" {
w.Header().Set("Access-Control-Allow-Methods", methods)
if headers != "" {
w.Header().Set("Access-Control-Allow-Headers", headers)
}
if o.maxAge > 0 {
w.Header().Set("Access-Control-Max-Age", strconv.Itoa(o.maxAge))
}
w.WriteHeader(http.StatusNoContent)
return
}
next.ServeHTTP(w, r)
})
}
}

View File

@@ -0,0 +1,143 @@
package server_test
import (
"net/http"
"net/http/httptest"
"testing"
"git.codelab.vc/pkg/httpx/server"
)
func TestCORS(t *testing.T) {
okHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
})
t.Run("no Origin header passes through", func(t *testing.T) {
mw := server.CORS(server.AllowOrigins("*"))(okHandler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
mw.ServeHTTP(w, req)
if w.Header().Get("Access-Control-Allow-Origin") != "" {
t.Fatal("expected no CORS headers without Origin")
}
})
t.Run("wildcard origin", func(t *testing.T) {
mw := server.CORS(server.AllowOrigins("*"))(okHandler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("Origin", "http://example.com")
mw.ServeHTTP(w, req)
if got := w.Header().Get("Access-Control-Allow-Origin"); got != "*" {
t.Fatalf("Access-Control-Allow-Origin = %q, want %q", got, "*")
}
})
t.Run("specific origin allowed", func(t *testing.T) {
mw := server.CORS(server.AllowOrigins("http://example.com"))(okHandler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("Origin", "http://example.com")
mw.ServeHTTP(w, req)
if got := w.Header().Get("Access-Control-Allow-Origin"); got != "http://example.com" {
t.Fatalf("Access-Control-Allow-Origin = %q, want %q", got, "http://example.com")
}
})
t.Run("disallowed origin gets no CORS headers", func(t *testing.T) {
mw := server.CORS(server.AllowOrigins("http://example.com"))(okHandler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("Origin", "http://evil.com")
mw.ServeHTTP(w, req)
if got := w.Header().Get("Access-Control-Allow-Origin"); got != "" {
t.Fatalf("expected no Access-Control-Allow-Origin for disallowed origin, got %q", got)
}
})
t.Run("preflight OPTIONS", func(t *testing.T) {
mw := server.CORS(
server.AllowOrigins("http://example.com"),
server.AllowMethods("GET", "POST", "PUT"),
server.AllowHeaders("Authorization", "Content-Type"),
server.MaxAge(3600),
)(okHandler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodOptions, "/api/data", nil)
req.Header.Set("Origin", "http://example.com")
req.Header.Set("Access-Control-Request-Method", "POST")
mw.ServeHTTP(w, req)
if w.Code != http.StatusNoContent {
t.Fatalf("got status %d, want %d", w.Code, http.StatusNoContent)
}
if got := w.Header().Get("Access-Control-Allow-Methods"); got != "GET, POST, PUT" {
t.Fatalf("Access-Control-Allow-Methods = %q, want %q", got, "GET, POST, PUT")
}
if got := w.Header().Get("Access-Control-Allow-Headers"); got != "Authorization, Content-Type" {
t.Fatalf("Access-Control-Allow-Headers = %q, want %q", got, "Authorization, Content-Type")
}
if got := w.Header().Get("Access-Control-Max-Age"); got != "3600" {
t.Fatalf("Access-Control-Max-Age = %q, want %q", got, "3600")
}
})
t.Run("credentials with specific origin", func(t *testing.T) {
mw := server.CORS(
server.AllowOrigins("*"),
server.AllowCredentials(true),
)(okHandler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("Origin", "http://example.com")
mw.ServeHTTP(w, req)
// With credentials, must echo specific origin even with wildcard config.
if got := w.Header().Get("Access-Control-Allow-Origin"); got != "http://example.com" {
t.Fatalf("Access-Control-Allow-Origin = %q, want %q", got, "http://example.com")
}
if got := w.Header().Get("Access-Control-Allow-Credentials"); got != "true" {
t.Fatalf("Access-Control-Allow-Credentials = %q, want %q", got, "true")
}
})
t.Run("expose headers", func(t *testing.T) {
mw := server.CORS(
server.AllowOrigins("*"),
server.ExposeHeaders("X-Custom", "X-Request-Id"),
)(okHandler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("Origin", "http://example.com")
mw.ServeHTTP(w, req)
if got := w.Header().Get("Access-Control-Expose-Headers"); got != "X-Custom, X-Request-Id" {
t.Fatalf("Access-Control-Expose-Headers = %q, want %q", got, "X-Custom, X-Request-Id")
}
})
t.Run("Vary header is set", func(t *testing.T) {
mw := server.CORS(server.AllowOrigins("*"))(okHandler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("Origin", "http://example.com")
mw.ServeHTTP(w, req)
if got := w.Header().Get("Vary"); got != "Origin" {
t.Fatalf("Vary = %q, want %q", got, "Origin")
}
})
}

View File

@@ -0,0 +1,39 @@
package server
import (
"log/slog"
"net/http"
"time"
)
// Logging returns a middleware that logs each request's method, path,
// status code, duration, and request ID using the provided structured logger.
func Logging(logger *slog.Logger) Middleware {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
sw := &statusWriter{ResponseWriter: w, status: http.StatusOK}
next.ServeHTTP(sw, r)
duration := time.Since(start)
attrs := []slog.Attr{
slog.String("method", r.Method),
slog.String("path", r.URL.Path),
slog.Int("status", sw.status),
slog.Duration("duration", duration),
}
if id := RequestIDFromContext(r.Context()); id != "" {
attrs = append(attrs, slog.String("request_id", id))
}
level := slog.LevelInfo
if sw.status >= http.StatusInternalServerError {
level = slog.LevelError
}
logger.LogAttrs(r.Context(), level, "request completed", attrs...)
})
}
}

View File

@@ -0,0 +1,129 @@
package server
import (
"net"
"net/http"
"strconv"
"sync"
"time"
"git.codelab.vc/pkg/httpx/internal/clock"
)
type rateLimitOptions struct {
rate float64
burst int
keyFunc func(r *http.Request) string
clock clock.Clock
}
// RateLimitOption configures the RateLimit middleware.
type RateLimitOption func(*rateLimitOptions)
// WithRate sets the token refill rate (tokens per second).
func WithRate(tokensPerSecond float64) RateLimitOption {
return func(o *rateLimitOptions) { o.rate = tokensPerSecond }
}
// WithBurst sets the maximum burst size (bucket capacity).
func WithBurst(n int) RateLimitOption {
return func(o *rateLimitOptions) { o.burst = n }
}
// WithKeyFunc sets a custom function to extract the rate-limit key from a
// request. By default, the client IP address is used.
func WithKeyFunc(fn func(r *http.Request) string) RateLimitOption {
return func(o *rateLimitOptions) { o.keyFunc = fn }
}
// withRateLimitClock sets the clock for testing. Not exported.
func withRateLimitClock(c clock.Clock) RateLimitOption {
return func(o *rateLimitOptions) { o.clock = c }
}
// RateLimit returns a middleware that limits requests using a per-key token
// bucket algorithm. When the limit is exceeded, it returns 429 Too Many
// Requests with a Retry-After header.
func RateLimit(opts ...RateLimitOption) Middleware {
o := &rateLimitOptions{
rate: 10,
burst: 20,
clock: clock.System(),
}
for _, opt := range opts {
opt(o)
}
if o.keyFunc == nil {
o.keyFunc = clientIP
}
var buckets sync.Map
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
key := o.keyFunc(r)
val, _ := buckets.LoadOrStore(key, &bucket{
tokens: float64(o.burst),
lastTime: o.clock.Now(),
})
b := val.(*bucket)
b.mu.Lock()
now := o.clock.Now()
elapsed := now.Sub(b.lastTime).Seconds()
b.tokens += elapsed * o.rate
if b.tokens > float64(o.burst) {
b.tokens = float64(o.burst)
}
b.lastTime = now
if b.tokens < 1 {
retryAfter := (1 - b.tokens) / o.rate
b.mu.Unlock()
w.Header().Set("Retry-After", strconv.Itoa(int(retryAfter)+1))
http.Error(w, "Too Many Requests", http.StatusTooManyRequests)
return
}
b.tokens--
b.mu.Unlock()
next.ServeHTTP(w, r)
})
}
}
type bucket struct {
mu sync.Mutex
tokens float64
lastTime time.Time
}
// clientIP extracts the client IP from the request. It checks
// X-Forwarded-For first, then X-Real-Ip, and falls back to RemoteAddr.
func clientIP(r *http.Request) string {
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
// First IP in the comma-separated list is the original client.
if i := indexOf(xff, ','); i > 0 {
return xff[:i]
}
return xff
}
if xri := r.Header.Get("X-Real-Ip"); xri != "" {
return xri
}
host, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
return r.RemoteAddr
}
return host
}
func indexOf(s string, b byte) int {
for i := range len(s) {
if s[i] == b {
return i
}
}
return -1
}

View File

@@ -0,0 +1,171 @@
package server_test
import (
"net/http"
"net/http/httptest"
"testing"
"time"
"git.codelab.vc/pkg/httpx/server"
)
func TestRateLimit(t *testing.T) {
okHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
})
t.Run("allows requests within limit", func(t *testing.T) {
mw := server.RateLimit(
server.WithRate(100),
server.WithBurst(10),
)(okHandler)
for i := range 10 {
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.RemoteAddr = "1.2.3.4:1234"
mw.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("request %d: got status %d, want %d", i, w.Code, http.StatusOK)
}
}
})
t.Run("rejects when burst exhausted", func(t *testing.T) {
mw := server.RateLimit(
server.WithRate(1),
server.WithBurst(2),
)(okHandler)
// Exhaust burst.
for range 2 {
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.RemoteAddr = "1.2.3.4:1234"
mw.ServeHTTP(w, req)
}
// Next request should be rejected.
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.RemoteAddr = "1.2.3.4:1234"
mw.ServeHTTP(w, req)
if w.Code != http.StatusTooManyRequests {
t.Fatalf("got status %d, want %d", w.Code, http.StatusTooManyRequests)
}
if w.Header().Get("Retry-After") == "" {
t.Fatal("expected Retry-After header")
}
})
t.Run("different IPs have independent limits", func(t *testing.T) {
mw := server.RateLimit(
server.WithRate(1),
server.WithBurst(1),
)(okHandler)
// First IP exhausts its limit.
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.RemoteAddr = "1.2.3.4:1234"
mw.ServeHTTP(w, req)
// Second IP should still be allowed.
w = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/", nil)
req.RemoteAddr = "5.6.7.8:5678"
mw.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
}
})
t.Run("uses X-Forwarded-For", func(t *testing.T) {
mw := server.RateLimit(
server.WithRate(1),
server.WithBurst(1),
)(okHandler)
// Exhaust limit for 10.0.0.1.
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("X-Forwarded-For", "10.0.0.1, 192.168.1.1")
req.RemoteAddr = "192.168.1.1:1234"
mw.ServeHTTP(w, req)
// Same forwarded IP should be rate limited.
w = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("X-Forwarded-For", "10.0.0.1, 192.168.1.1")
req.RemoteAddr = "192.168.1.1:1234"
mw.ServeHTTP(w, req)
if w.Code != http.StatusTooManyRequests {
t.Fatalf("got status %d, want %d", w.Code, http.StatusTooManyRequests)
}
})
t.Run("custom key function", func(t *testing.T) {
mw := server.RateLimit(
server.WithRate(1),
server.WithBurst(1),
server.WithKeyFunc(func(r *http.Request) string {
return r.Header.Get("X-API-Key")
}),
)(okHandler)
// Exhaust key "abc".
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("X-API-Key", "abc")
mw.ServeHTTP(w, req)
// Same key should be rate limited.
w = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("X-API-Key", "abc")
mw.ServeHTTP(w, req)
if w.Code != http.StatusTooManyRequests {
t.Fatalf("got status %d, want %d", w.Code, http.StatusTooManyRequests)
}
// Different key should be allowed.
w = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("X-API-Key", "xyz")
mw.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
}
})
t.Run("tokens refill over time", func(t *testing.T) {
mw := server.RateLimit(
server.WithRate(1000), // Very fast refill for test
server.WithBurst(1),
)(okHandler)
// Exhaust burst.
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.RemoteAddr = "1.2.3.4:1234"
mw.ServeHTTP(w, req)
// Wait a bit for refill.
time.Sleep(5 * time.Millisecond)
w = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/", nil)
req.RemoteAddr = "1.2.3.4:1234"
mw.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("got status %d after refill, want %d", w.Code, http.StatusOK)
}
})
}

View File

@@ -0,0 +1,50 @@
package server
import (
"log/slog"
"net/http"
"runtime/debug"
)
// RecoveryOption configures the Recovery middleware.
type RecoveryOption func(*recoveryOptions)
type recoveryOptions struct {
logger *slog.Logger
}
// WithRecoveryLogger sets the logger for the Recovery middleware.
// If not set, panics are recovered silently (500 is still returned).
func WithRecoveryLogger(l *slog.Logger) RecoveryOption {
return func(o *recoveryOptions) { o.logger = l }
}
// Recovery returns a middleware that recovers from panics in downstream
// handlers. A recovered panic results in a 500 Internal Server Error
// response and is logged (if a logger is configured) with the stack trace.
func Recovery(opts ...RecoveryOption) Middleware {
o := &recoveryOptions{}
for _, opt := range opts {
opt(o)
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer func() {
if v := recover(); v != nil {
if o.logger != nil {
o.logger.LogAttrs(r.Context(), slog.LevelError, "panic recovered",
slog.Any("panic", v),
slog.String("stack", string(debug.Stack())),
slog.String("method", r.Method),
slog.String("path", r.URL.Path),
slog.String("request_id", RequestIDFromContext(r.Context())),
)
}
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
}
}()
next.ServeHTTP(w, r)
})
}
}

View File

@@ -0,0 +1,51 @@
package server
import (
"context"
"crypto/rand"
"fmt"
"net/http"
"git.codelab.vc/pkg/httpx/internal/requestid"
)
// RequestID returns a middleware that assigns a unique request ID to each
// request. If the incoming request already has an X-Request-Id header, that
// value is used. Otherwise a new UUID v4 is generated via crypto/rand.
//
// The request ID is stored in the request context (retrieve with
// RequestIDFromContext) and set on the response X-Request-Id header.
func RequestID() Middleware {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
id := r.Header.Get("X-Request-Id")
if id == "" {
id = newUUID()
}
ctx := requestid.NewContext(r.Context(), id)
w.Header().Set("X-Request-Id", id)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
// RequestIDFromContext returns the request ID from the context, or an empty
// string if none is set.
func RequestIDFromContext(ctx context.Context) string {
return requestid.FromContext(ctx)
}
// newUUID generates a UUID v4 string using crypto/rand.
func newUUID() string {
var uuid [16]byte
_, _ = rand.Read(uuid[:])
// Set version 4 (bits 12-15 of time_hi_and_version).
uuid[6] = (uuid[6] & 0x0f) | 0x40
// Set variant bits (10xx).
uuid[8] = (uuid[8] & 0x3f) | 0x80
return fmt.Sprintf("%x-%x-%x-%x-%x",
uuid[0:4], uuid[4:6], uuid[6:8], uuid[8:10], uuid[10:16])
}

508
server/middleware_test.go Normal file
View File

@@ -0,0 +1,508 @@
package server_test
import (
"bytes"
"log/slog"
"net/http"
"net/http/httptest"
"regexp"
"strings"
"testing"
"git.codelab.vc/pkg/httpx/server"
)
func TestChain(t *testing.T) {
t.Run("applies middlewares in correct order", func(t *testing.T) {
var order []string
mwA := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
order = append(order, "A-before")
next.ServeHTTP(w, r)
order = append(order, "A-after")
})
}
mwB := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
order = append(order, "B-before")
next.ServeHTTP(w, r)
order = append(order, "B-after")
})
}
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
order = append(order, "handler")
w.WriteHeader(http.StatusOK)
})
chained := server.Chain(mwA, mwB)(handler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
chained.ServeHTTP(w, req)
expected := []string{"A-before", "B-before", "handler", "B-after", "A-after"}
if len(order) != len(expected) {
t.Fatalf("got %v, want %v", order, expected)
}
for i, v := range expected {
if order[i] != v {
t.Fatalf("order[%d] = %q, want %q", i, order[i], v)
}
}
})
t.Run("empty chain returns handler unchanged", func(t *testing.T) {
called := false
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
called = true
w.WriteHeader(http.StatusOK)
})
chained := server.Chain()(handler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
chained.ServeHTTP(w, req)
if !called {
t.Fatal("handler was not called")
}
})
}
func TestChain_SingleMiddleware(t *testing.T) {
var called bool
mw := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
called = true
next.ServeHTTP(w, r)
})
}
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
})
chained := server.Chain(mw)(handler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
chained.ServeHTTP(w, req)
if !called {
t.Fatal("single middleware was not called")
}
if w.Code != http.StatusOK {
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
}
}
func TestStatusWriter_WriteHeaderMultipleCalls(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusCreated)
w.WriteHeader(http.StatusNotFound) // second call should not change captured status
})
var buf bytes.Buffer
logger := slog.New(slog.NewTextHandler(&buf, nil))
mw := server.Logging(logger)(handler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
mw.ServeHTTP(w, req)
if !strings.Contains(buf.String(), "status=201") {
t.Fatalf("expected status=201 (first WriteHeader call captured), got %q", buf.String())
}
}
func TestStatusWriter_WriteDefaultsTo200(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
_, _ = w.Write([]byte("hello")) // Write without WriteHeader
})
var buf bytes.Buffer
logger := slog.New(slog.NewTextHandler(&buf, nil))
mw := server.Logging(logger)(handler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
mw.ServeHTTP(w, req)
if !strings.Contains(buf.String(), "status=200") {
t.Fatalf("expected status=200 when Write called without WriteHeader, got %q", buf.String())
}
}
func TestStatusWriter_Unwrap(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
rc := http.NewResponseController(w)
if err := rc.Flush(); err != nil {
// httptest.ResponseRecorder implements Flusher, so this should succeed
// if Unwrap works correctly.
http.Error(w, "flush failed", http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusOK)
})
// Use Logging to wrap in statusWriter
var buf bytes.Buffer
logger := slog.New(slog.NewTextHandler(&buf, nil))
mw := server.Logging(logger)(handler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
mw.ServeHTTP(w, req)
if w.Code == http.StatusInternalServerError {
t.Fatal("Flush failed — Unwrap likely not exposing underlying Flusher")
}
}
func TestRequestID(t *testing.T) {
t.Run("generates ID when not present", func(t *testing.T) {
var gotID string
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotID = server.RequestIDFromContext(r.Context())
w.WriteHeader(http.StatusOK)
})
mw := server.RequestID()(handler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
mw.ServeHTTP(w, req)
if gotID == "" {
t.Fatal("expected request ID in context, got empty")
}
if w.Header().Get("X-Request-Id") != gotID {
t.Fatalf("response header %q != context ID %q", w.Header().Get("X-Request-Id"), gotID)
}
// UUID v4 format: 8-4-4-4-12 hex chars.
if len(gotID) != 36 {
t.Fatalf("expected UUID length 36, got %d: %q", len(gotID), gotID)
}
})
t.Run("preserves existing ID", func(t *testing.T) {
var gotID string
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotID = server.RequestIDFromContext(r.Context())
w.WriteHeader(http.StatusOK)
})
mw := server.RequestID()(handler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("X-Request-Id", "custom-123")
mw.ServeHTTP(w, req)
if gotID != "custom-123" {
t.Fatalf("got ID %q, want %q", gotID, "custom-123")
}
})
t.Run("context without ID returns empty", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
if id := server.RequestIDFromContext(req.Context()); id != "" {
t.Fatalf("expected empty, got %q", id)
}
})
}
func TestRequestID_UUIDFormat(t *testing.T) {
uuidV4Re := regexp.MustCompile(`^[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$`)
var gotID string
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotID = server.RequestIDFromContext(r.Context())
w.WriteHeader(http.StatusOK)
})
mw := server.RequestID()(handler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
mw.ServeHTTP(w, req)
if !uuidV4Re.MatchString(gotID) {
t.Fatalf("generated ID %q does not match UUID v4 format", gotID)
}
}
func TestRequestID_Uniqueness(t *testing.T) {
seen := make(map[string]struct{}, 1000)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
id := server.RequestIDFromContext(r.Context())
if _, exists := seen[id]; exists {
t.Fatalf("duplicate request ID: %q", id)
}
seen[id] = struct{}{}
w.WriteHeader(http.StatusOK)
})
mw := server.RequestID()(handler)
for i := 0; i < 1000; i++ {
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
mw.ServeHTTP(w, req)
}
if len(seen) != 1000 {
t.Fatalf("expected 1000 unique IDs, got %d", len(seen))
}
}
func TestRecovery(t *testing.T) {
t.Run("recovers from panic and returns 500", func(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
panic("something went wrong")
})
mw := server.Recovery()(handler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
mw.ServeHTTP(w, req)
if w.Code != http.StatusInternalServerError {
t.Fatalf("got status %d, want %d", w.Code, http.StatusInternalServerError)
}
})
t.Run("logs panic with logger", func(t *testing.T) {
var buf bytes.Buffer
logger := slog.New(slog.NewTextHandler(&buf, nil))
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
panic("boom")
})
mw := server.Recovery(server.WithRecoveryLogger(logger))(handler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
mw.ServeHTTP(w, req)
if !strings.Contains(buf.String(), "panic recovered") {
t.Fatalf("expected log to contain 'panic recovered', got %q", buf.String())
}
if !strings.Contains(buf.String(), "boom") {
t.Fatalf("expected log to contain 'boom', got %q", buf.String())
}
})
t.Run("passes through without panic", func(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
})
mw := server.Recovery()(handler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
mw.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
}
})
}
func TestRecovery_PanicWithNonString(t *testing.T) {
tests := []struct {
name string
value any
}{
{"integer", 42},
{"struct", struct{ X int }{X: 1}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
panic(tt.value)
})
var buf bytes.Buffer
logger := slog.New(slog.NewTextHandler(&buf, nil))
mw := server.Recovery(server.WithRecoveryLogger(logger))(handler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/test", nil)
mw.ServeHTTP(w, req)
if w.Code != http.StatusInternalServerError {
t.Fatalf("got status %d, want %d", w.Code, http.StatusInternalServerError)
}
if !strings.Contains(buf.String(), "panic recovered") {
t.Fatalf("expected 'panic recovered' in log, got %q", buf.String())
}
})
}
}
func TestRecovery_ResponseBody(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
panic("fail")
})
mw := server.Recovery()(handler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
mw.ServeHTTP(w, req)
body := strings.TrimSpace(w.Body.String())
if body != "Internal Server Error" {
t.Fatalf("got body %q, want %q", body, "Internal Server Error")
}
}
func TestRecovery_LogAttributes(t *testing.T) {
var buf bytes.Buffer
logger := slog.New(slog.NewTextHandler(&buf, nil))
// Put RequestID before Recovery so request_id is in context
handler := server.RequestID()(
server.Recovery(server.WithRecoveryLogger(logger))(
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
panic("boom")
}),
),
)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/test", nil)
handler.ServeHTTP(w, req)
logOutput := buf.String()
for _, attr := range []string{"method=", "path=", "request_id="} {
if !strings.Contains(logOutput, attr) {
t.Fatalf("expected %q in log, got %q", attr, logOutput)
}
}
}
func TestLogging(t *testing.T) {
t.Run("logs request details", func(t *testing.T) {
var buf bytes.Buffer
logger := slog.New(slog.NewTextHandler(&buf, nil))
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusCreated)
})
mw := server.Logging(logger)(handler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/users", nil)
mw.ServeHTTP(w, req)
logOutput := buf.String()
if !strings.Contains(logOutput, "request completed") {
t.Fatalf("expected 'request completed' in log, got %q", logOutput)
}
if !strings.Contains(logOutput, "POST") {
t.Fatalf("expected method in log, got %q", logOutput)
}
if !strings.Contains(logOutput, "/api/users") {
t.Fatalf("expected path in log, got %q", logOutput)
}
if !strings.Contains(logOutput, "status=201") {
t.Fatalf("expected status=201 in log, got %q", logOutput)
}
})
t.Run("logs error level for 5xx", func(t *testing.T) {
var buf bytes.Buffer
logger := slog.New(slog.NewTextHandler(&buf, nil))
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusBadGateway)
})
mw := server.Logging(logger)(handler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
mw.ServeHTTP(w, req)
logOutput := buf.String()
if !strings.Contains(logOutput, "level=ERROR") {
t.Fatalf("expected ERROR level in log, got %q", logOutput)
}
})
}
func TestLogging_4xxIsInfoLevel(t *testing.T) {
var buf bytes.Buffer
logger := slog.New(slog.NewTextHandler(&buf, nil))
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusNotFound)
})
mw := server.Logging(logger)(handler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/missing", nil)
mw.ServeHTTP(w, req)
logOutput := buf.String()
if !strings.Contains(logOutput, "level=INFO") {
t.Fatalf("expected INFO level for 404, got %q", logOutput)
}
if strings.Contains(logOutput, "level=ERROR") {
t.Fatalf("404 should not be logged as ERROR, got %q", logOutput)
}
}
func TestLogging_DefaultStatus200(t *testing.T) {
var buf bytes.Buffer
logger := slog.New(slog.NewTextHandler(&buf, nil))
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
_, _ = w.Write([]byte("hello"))
})
mw := server.Logging(logger)(handler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
mw.ServeHTTP(w, req)
if !strings.Contains(buf.String(), "status=200") {
t.Fatalf("expected status=200 in log when handler only calls Write, got %q", buf.String())
}
}
func TestLogging_IncludesRequestID(t *testing.T) {
var buf bytes.Buffer
logger := slog.New(slog.NewTextHandler(&buf, nil))
handler := server.RequestID()(
server.Logging(logger)(
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
}),
),
)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
handler.ServeHTTP(w, req)
if !strings.Contains(buf.String(), "request_id=") {
t.Fatalf("expected request_id in log output, got %q", buf.String())
}
}

View File

@@ -0,0 +1,15 @@
package server
import (
"net/http"
"time"
)
// Timeout returns a middleware that limits request processing time.
// If the handler does not complete within d, the client receives a
// 503 Service Unavailable response. It wraps http.TimeoutHandler.
func Timeout(d time.Duration) Middleware {
return func(next http.Handler) http.Handler {
return http.TimeoutHandler(next, d, "Service Unavailable\n")
}
}

View File

@@ -0,0 +1,49 @@
package server_test
import (
"net/http"
"net/http/httptest"
"testing"
"time"
"git.codelab.vc/pkg/httpx/server"
)
func TestTimeout(t *testing.T) {
t.Run("handler completes within timeout", func(t *testing.T) {
handler := server.Timeout(1 * time.Second)(
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("ok"))
}),
)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
handler.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
}
})
t.Run("handler exceeds timeout returns 503", func(t *testing.T) {
handler := server.Timeout(10 * time.Millisecond)(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
select {
case <-time.After(1 * time.Second):
case <-r.Context().Done():
}
w.WriteHeader(http.StatusOK)
}),
)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
handler.ServeHTTP(w, req)
if w.Code != http.StatusServiceUnavailable {
t.Fatalf("got status %d, want %d", w.Code, http.StatusServiceUnavailable)
}
})
}

89
server/options.go Normal file
View File

@@ -0,0 +1,89 @@
package server
import (
"log/slog"
"time"
)
type serverOptions struct {
addr string
readTimeout time.Duration
readHeaderTimeout time.Duration
writeTimeout time.Duration
idleTimeout time.Duration
shutdownTimeout time.Duration
logger *slog.Logger
middlewares []Middleware
onShutdown []func()
}
// Option configures a Server.
type Option func(*serverOptions)
// WithAddr sets the listen address. Default is ":8080".
func WithAddr(addr string) Option {
return func(o *serverOptions) { o.addr = addr }
}
// WithReadTimeout sets the maximum duration for reading the entire request.
func WithReadTimeout(d time.Duration) Option {
return func(o *serverOptions) { o.readTimeout = d }
}
// WithReadHeaderTimeout sets the maximum duration for reading request headers.
func WithReadHeaderTimeout(d time.Duration) Option {
return func(o *serverOptions) { o.readHeaderTimeout = d }
}
// WithWriteTimeout sets the maximum duration before timing out writes of the response.
func WithWriteTimeout(d time.Duration) Option {
return func(o *serverOptions) { o.writeTimeout = d }
}
// WithIdleTimeout sets the maximum amount of time to wait for the next request
// when keep-alives are enabled.
func WithIdleTimeout(d time.Duration) Option {
return func(o *serverOptions) { o.idleTimeout = d }
}
// WithShutdownTimeout sets the maximum duration to wait for active connections
// to close during graceful shutdown. Default is 15 seconds.
func WithShutdownTimeout(d time.Duration) Option {
return func(o *serverOptions) { o.shutdownTimeout = d }
}
// WithLogger sets the structured logger used by the server for lifecycle events.
func WithLogger(l *slog.Logger) Option {
return func(o *serverOptions) { o.logger = l }
}
// WithMiddleware appends server middlewares to the chain.
// These are applied to the handler in the order given.
func WithMiddleware(mws ...Middleware) Option {
return func(o *serverOptions) { o.middlewares = append(o.middlewares, mws...) }
}
// WithOnShutdown registers a function to be called during graceful shutdown,
// before the HTTP server begins draining connections.
func WithOnShutdown(fn func()) Option {
return func(o *serverOptions) { o.onShutdown = append(o.onShutdown, fn) }
}
// Defaults returns a production-ready set of options including standard
// middleware (RequestID, Recovery, Logging), sensible timeouts, and the
// provided logger.
//
// Middleware order: RequestID → Recovery → Logging → user handler.
func Defaults(logger *slog.Logger) []Option {
return []Option{
WithReadHeaderTimeout(10 * time.Second),
WithIdleTimeout(120 * time.Second),
WithShutdownTimeout(15 * time.Second),
WithLogger(logger),
WithMiddleware(
RequestID(),
Recovery(WithRecoveryLogger(logger)),
Logging(logger),
),
}
}

29
server/respond.go Normal file
View File

@@ -0,0 +1,29 @@
package server
import (
"encoding/json"
"net/http"
)
// WriteJSON encodes v as JSON and writes it to w with the given status code.
// It sets Content-Type to application/json.
func WriteJSON(w http.ResponseWriter, status int, v any) error {
b, err := json.Marshal(v)
if err != nil {
return err
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
_, err = w.Write(b)
return err
}
// WriteError writes a JSON error response with the given status code and
// message. The response body is {"error": "<message>"}.
func WriteError(w http.ResponseWriter, status int, msg string) error {
return WriteJSON(w, status, errorBody{Error: msg})
}
type errorBody struct {
Error string `json:"error"`
}

72
server/respond_test.go Normal file
View File

@@ -0,0 +1,72 @@
package server_test
import (
"encoding/json"
"net/http/httptest"
"testing"
"git.codelab.vc/pkg/httpx/server"
)
func TestWriteJSON(t *testing.T) {
t.Run("writes JSON with status and content type", func(t *testing.T) {
w := httptest.NewRecorder()
type resp struct {
ID int `json:"id"`
Name string `json:"name"`
}
err := server.WriteJSON(w, 201, resp{ID: 1, Name: "Alice"})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if w.Code != 201 {
t.Fatalf("got status %d, want %d", w.Code, 201)
}
if ct := w.Header().Get("Content-Type"); ct != "application/json" {
t.Fatalf("Content-Type = %q, want %q", ct, "application/json")
}
var decoded resp
if err := json.Unmarshal(w.Body.Bytes(), &decoded); err != nil {
t.Fatalf("unmarshal: %v", err)
}
if decoded.ID != 1 || decoded.Name != "Alice" {
t.Fatalf("got %+v, want {ID:1 Name:Alice}", decoded)
}
})
t.Run("returns error for unmarshalable input", func(t *testing.T) {
w := httptest.NewRecorder()
err := server.WriteJSON(w, 200, make(chan int))
if err == nil {
t.Fatal("expected error for channel type")
}
})
}
func TestWriteError(t *testing.T) {
w := httptest.NewRecorder()
err := server.WriteError(w, 404, "not found")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if w.Code != 404 {
t.Fatalf("got status %d, want %d", w.Code, 404)
}
var body struct {
Error string `json:"error"`
}
if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil {
t.Fatalf("unmarshal: %v", err)
}
if body.Error != "not found" {
t.Fatalf("error = %q, want %q", body.Error, "not found")
}
}

126
server/route.go Normal file
View File

@@ -0,0 +1,126 @@
package server
import (
"net/http"
"strings"
)
// Router is a lightweight wrapper around http.ServeMux that adds middleware
// groups and sub-router mounting. It leverages Go 1.22+ enhanced patterns
// like "GET /users/{id}".
type Router struct {
mux *http.ServeMux
prefix string
middlewares []Middleware
notFoundHandler http.Handler
}
// RouterOption configures a Router.
type RouterOption func(*Router)
// WithNotFoundHandler sets a custom handler for requests that don't match
// any registered pattern. This is useful for returning JSON 404/405 responses
// instead of the default plain text.
func WithNotFoundHandler(h http.Handler) RouterOption {
return func(r *Router) { r.notFoundHandler = h }
}
// NewRouter creates a new Router backed by a fresh http.ServeMux.
func NewRouter(opts ...RouterOption) *Router {
r := &Router{
mux: http.NewServeMux(),
}
for _, opt := range opts {
opt(r)
}
return r
}
// Handle registers a handler for the given pattern. The pattern follows
// http.ServeMux conventions, including method-based patterns like "GET /users".
func (r *Router) Handle(pattern string, handler http.Handler) {
if len(r.middlewares) > 0 {
handler = Chain(r.middlewares...)(handler)
}
r.mux.Handle(r.prefixedPattern(pattern), handler)
}
// HandleFunc registers a handler function for the given pattern.
func (r *Router) HandleFunc(pattern string, fn http.HandlerFunc) {
r.Handle(pattern, fn)
}
// Group creates a sub-router with a shared prefix and optional middleware.
// Patterns registered on the group are prefixed automatically. The group
// shares the underlying ServeMux with the parent router.
//
// Example:
//
// api := router.Group("/api/v1", authMiddleware)
// api.HandleFunc("GET /users", listUsers) // registers "GET /api/v1/users"
func (r *Router) Group(prefix string, mws ...Middleware) *Router {
return &Router{
mux: r.mux,
prefix: r.prefix + prefix,
middlewares: append(r.middlewaresSnapshot(), mws...),
}
}
// Mount attaches an http.Handler under the given prefix. All requests
// starting with prefix are forwarded to the handler with the prefix stripped.
func (r *Router) Mount(prefix string, handler http.Handler) {
full := r.prefix + prefix
if !strings.HasSuffix(full, "/") {
full += "/"
}
r.mux.Handle(full, http.StripPrefix(strings.TrimSuffix(full, "/"), handler))
}
// ServeHTTP implements http.Handler, making Router usable as a handler.
func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if r.notFoundHandler != nil {
// Use the mux to check for a match. If none, use the custom handler.
_, pattern := r.mux.Handler(req)
if pattern == "" {
r.notFoundHandler.ServeHTTP(w, req)
return
}
}
r.mux.ServeHTTP(w, req)
}
// prefixedPattern inserts the router prefix into a pattern. It is aware of
// method prefixes: "GET /users" with prefix "/api" becomes "GET /api/users".
func (r *Router) prefixedPattern(pattern string) string {
if r.prefix == "" {
return pattern
}
// Split method prefix if present: "GET /users" → method="GET ", path="/users"
method, path, hasMethod := splitMethodPattern(pattern)
path = r.prefix + path
if hasMethod {
return method + path
}
return path
}
// splitMethodPattern splits "GET /path" into ("GET ", "/path", true).
// If there is no method prefix, returns ("", pattern, false).
func splitMethodPattern(pattern string) (method, path string, hasMethod bool) {
if idx := strings.IndexByte(pattern, ' '); idx >= 0 {
return pattern[:idx+1], pattern[idx+1:], true
}
return "", pattern, false
}
func (r *Router) middlewaresSnapshot() []Middleware {
if len(r.middlewares) == 0 {
return nil
}
cp := make([]Middleware, len(r.middlewares))
copy(cp, r.middlewares)
return cp
}

View File

@@ -0,0 +1,70 @@
package server_test
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"git.codelab.vc/pkg/httpx/server"
)
func TestRouter_NotFoundHandler(t *testing.T) {
t.Run("custom 404 handler", func(t *testing.T) {
notFound := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusNotFound)
json.NewEncoder(w).Encode(map[string]string{"error": "not found"})
})
r := server.NewRouter(server.WithNotFoundHandler(notFound))
r.HandleFunc("GET /exists", func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
})
// Matched route works normally.
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/exists", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("matched route: got status %d, want %d", w.Code, http.StatusOK)
}
// Unmatched route uses custom handler.
w = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/nope", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Fatalf("not found: got status %d, want %d", w.Code, http.StatusNotFound)
}
if ct := w.Header().Get("Content-Type"); ct != "application/json" {
t.Fatalf("Content-Type = %q, want %q", ct, "application/json")
}
var body map[string]string
if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil {
t.Fatalf("unmarshal: %v", err)
}
if body["error"] != "not found" {
t.Fatalf("error = %q, want %q", body["error"], "not found")
}
})
t.Run("default behavior without custom handler", func(t *testing.T) {
r := server.NewRouter()
r.HandleFunc("GET /exists", func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/nope", nil)
r.ServeHTTP(w, req)
// Default ServeMux returns 404.
if w.Code != http.StatusNotFound {
t.Fatalf("got status %d, want %d", w.Code, http.StatusNotFound)
}
})
}

336
server/route_test.go Normal file
View File

@@ -0,0 +1,336 @@
package server_test
import (
"io"
"net/http"
"net/http/httptest"
"testing"
"git.codelab.vc/pkg/httpx/server"
)
func TestRouter(t *testing.T) {
t.Run("basic route", func(t *testing.T) {
r := server.NewRouter()
r.HandleFunc("GET /hello", func(w http.ResponseWriter, _ *http.Request) {
_, _ = w.Write([]byte("world"))
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/hello", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
}
if body := w.Body.String(); body != "world" {
t.Fatalf("got body %q, want %q", body, "world")
}
})
t.Run("Handle with http.Handler", func(t *testing.T) {
r := server.NewRouter()
r.Handle("GET /ping", http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
_, _ = w.Write([]byte("pong"))
}))
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/ping", nil)
r.ServeHTTP(w, req)
if body := w.Body.String(); body != "pong" {
t.Fatalf("got body %q, want %q", body, "pong")
}
})
t.Run("path parameter", func(t *testing.T) {
r := server.NewRouter()
r.HandleFunc("GET /users/{id}", func(w http.ResponseWriter, req *http.Request) {
_, _ = w.Write([]byte("user:" + req.PathValue("id")))
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/users/42", nil)
r.ServeHTTP(w, req)
if body := w.Body.String(); body != "user:42" {
t.Fatalf("got body %q, want %q", body, "user:42")
}
})
}
func TestRouterGroup(t *testing.T) {
t.Run("prefix is applied", func(t *testing.T) {
r := server.NewRouter()
api := r.Group("/api/v1")
api.HandleFunc("GET /users", func(w http.ResponseWriter, _ *http.Request) {
_, _ = w.Write([]byte("users"))
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/users", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
}
if body := w.Body.String(); body != "users" {
t.Fatalf("got body %q, want %q", body, "users")
}
})
t.Run("nested groups", func(t *testing.T) {
r := server.NewRouter()
api := r.Group("/api")
v1 := api.Group("/v1")
v1.HandleFunc("GET /items", func(w http.ResponseWriter, _ *http.Request) {
_, _ = w.Write([]byte("items"))
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/items", nil)
r.ServeHTTP(w, req)
if body := w.Body.String(); body != "items" {
t.Fatalf("got body %q, want %q", body, "items")
}
})
t.Run("group middleware", func(t *testing.T) {
var mwCalled bool
mw := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mwCalled = true
next.ServeHTTP(w, r)
})
}
r := server.NewRouter()
g := r.Group("/admin", mw)
g.HandleFunc("GET /dashboard", func(w http.ResponseWriter, _ *http.Request) {
_, _ = w.Write([]byte("ok"))
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard", nil)
r.ServeHTTP(w, req)
if !mwCalled {
t.Fatal("group middleware was not called")
}
})
}
func TestRouterMount(t *testing.T) {
t.Run("mounts sub-handler with prefix stripping", func(t *testing.T) {
sub := http.NewServeMux()
sub.HandleFunc("GET /info", func(w http.ResponseWriter, _ *http.Request) {
_, _ = w.Write([]byte("info"))
})
r := server.NewRouter()
r.Mount("/sub", sub)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/sub/info", nil)
r.ServeHTTP(w, req)
body, _ := io.ReadAll(w.Body)
if string(body) != "info" {
t.Fatalf("got body %q, want %q", body, "info")
}
})
t.Run("mount with trailing slash", func(t *testing.T) {
sub := http.NewServeMux()
sub.HandleFunc("GET /data", func(w http.ResponseWriter, _ *http.Request) {
_, _ = w.Write([]byte("data"))
})
r := server.NewRouter()
r.Mount("/sub/", sub)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/sub/data", nil)
r.ServeHTTP(w, req)
body, _ := io.ReadAll(w.Body)
if string(body) != "data" {
t.Fatalf("got body %q, want %q", body, "data")
}
})
}
func TestRouter_PatternWithoutMethod(t *testing.T) {
r := server.NewRouter()
r.HandleFunc("/static/", func(w http.ResponseWriter, _ *http.Request) {
_, _ = w.Write([]byte("static"))
})
for _, method := range []string{http.MethodGet, http.MethodPost} {
w := httptest.NewRecorder()
req := httptest.NewRequest(method, "/static/file.css", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("%s /static/file.css: got status %d, want %d", method, w.Code, http.StatusOK)
}
}
}
func TestRouter_GroupEmptyPrefix(t *testing.T) {
r := server.NewRouter()
g := r.Group("")
g.HandleFunc("GET /hello", func(w http.ResponseWriter, _ *http.Request) {
_, _ = w.Write([]byte("hello"))
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/hello", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
}
if body := w.Body.String(); body != "hello" {
t.Fatalf("got body %q, want %q", body, "hello")
}
}
func TestRouter_GroupInheritsMiddleware(t *testing.T) {
var order []string
parentMW := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
order = append(order, "parent")
next.ServeHTTP(w, r)
})
}
childMW := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
order = append(order, "child")
next.ServeHTTP(w, r)
})
}
r := server.NewRouter()
parent := r.Group("/api", parentMW)
child := parent.Group("/v1", childMW)
child.HandleFunc("GET /items", func(w http.ResponseWriter, _ *http.Request) {
order = append(order, "handler")
w.WriteHeader(http.StatusOK)
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/items", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
}
expected := []string{"parent", "child", "handler"}
if len(order) != len(expected) {
t.Fatalf("got %v, want %v", order, expected)
}
for i, v := range expected {
if order[i] != v {
t.Fatalf("order[%d] = %q, want %q", i, order[i], v)
}
}
}
func TestRouter_GroupMiddlewareOrder(t *testing.T) {
var order []string
mwA := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
order = append(order, "A")
next.ServeHTTP(w, r)
})
}
mwB := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
order = append(order, "B")
next.ServeHTTP(w, r)
})
}
r := server.NewRouter()
g := r.Group("/api", mwA)
sub := g.Group("/v1", mwB)
sub.HandleFunc("GET /test", func(w http.ResponseWriter, _ *http.Request) {
order = append(order, "handler")
w.WriteHeader(http.StatusOK)
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/test", nil)
r.ServeHTTP(w, req)
// Parent MW (A) should run before child MW (B), then handler.
expected := []string{"A", "B", "handler"}
if len(order) != len(expected) {
t.Fatalf("got %v, want %v", order, expected)
}
for i, v := range expected {
if order[i] != v {
t.Fatalf("order[%d] = %q, want %q", i, order[i], v)
}
}
}
func TestRouter_PathParamWithGroup(t *testing.T) {
r := server.NewRouter()
api := r.Group("/api")
api.HandleFunc("GET /users/{id}", func(w http.ResponseWriter, req *http.Request) {
_, _ = w.Write([]byte("id=" + req.PathValue("id")))
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/users/42", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
}
if body := w.Body.String(); body != "id=42" {
t.Fatalf("got body %q, want %q", body, "id=42")
}
}
func TestRouter_MiddlewareNotAppliedToOtherRoutes(t *testing.T) {
var mwCalled bool
mw := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mwCalled = true
next.ServeHTTP(w, r)
})
}
r := server.NewRouter()
// Add middleware only to /admin group.
admin := r.Group("/admin", mw)
admin.HandleFunc("GET /dashboard", func(w http.ResponseWriter, _ *http.Request) {
_, _ = w.Write([]byte("admin"))
})
// Route outside the group.
r.HandleFunc("GET /public", func(w http.ResponseWriter, _ *http.Request) {
_, _ = w.Write([]byte("public"))
})
// Request to /public should NOT trigger group middleware.
mwCalled = false
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/public", nil)
r.ServeHTTP(w, req)
if mwCalled {
t.Fatal("group middleware should not be called for routes outside the group")
}
if w.Body.String() != "public" {
t.Fatalf("got body %q, want %q", w.Body.String(), "public")
}
}

173
server/server.go Normal file
View File

@@ -0,0 +1,173 @@
package server
import (
"context"
"errors"
"log/slog"
"net"
"net/http"
"os/signal"
"sync/atomic"
"syscall"
"time"
)
// Server is a production-ready HTTP server with graceful shutdown,
// middleware support, and signal handling.
type Server struct {
httpServer *http.Server
listener net.Listener
addr atomic.Value
logger *slog.Logger
shutdownTimeout time.Duration
onShutdown []func()
listenAddr string
}
// New creates a new Server that will serve the given handler with the
// provided options. Middleware from options is applied to the handler.
func New(handler http.Handler, opts ...Option) *Server {
o := &serverOptions{
addr: ":8080",
shutdownTimeout: 15 * time.Second,
}
for _, opt := range opts {
opt(o)
}
// Apply middleware chain to the handler.
if len(o.middlewares) > 0 {
handler = Chain(o.middlewares...)(handler)
}
srv := &Server{
httpServer: &http.Server{
Handler: handler,
ReadTimeout: o.readTimeout,
ReadHeaderTimeout: o.readHeaderTimeout,
WriteTimeout: o.writeTimeout,
IdleTimeout: o.idleTimeout,
},
logger: o.logger,
shutdownTimeout: o.shutdownTimeout,
onShutdown: o.onShutdown,
listenAddr: o.addr,
}
return srv
}
// ListenAndServe starts the server and blocks until a SIGINT or SIGTERM
// signal is received. It then performs a graceful shutdown within the
// configured shutdown timeout.
//
// Returns nil on clean shutdown or an error if listen/shutdown fails.
func (s *Server) ListenAndServe() error {
ln, err := net.Listen("tcp", s.listenAddr)
if err != nil {
return err
}
s.listener = ln
s.addr.Store(ln.Addr().String())
s.log("server started", slog.String("addr", ln.Addr().String()))
// Wait for signal in context.
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer stop()
errCh := make(chan error, 1)
go func() {
if err := s.httpServer.Serve(ln); err != nil && !errors.Is(err, http.ErrServerClosed) {
errCh <- err
}
close(errCh)
}()
select {
case err := <-errCh:
return err
case <-ctx.Done():
stop()
return s.shutdown()
}
}
// ListenAndServeTLS starts the server with TLS and blocks until a signal
// is received.
func (s *Server) ListenAndServeTLS(certFile, keyFile string) error {
ln, err := net.Listen("tcp", s.listenAddr)
if err != nil {
return err
}
s.listener = ln
s.addr.Store(ln.Addr().String())
s.log("server started (TLS)", slog.String("addr", ln.Addr().String()))
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer stop()
errCh := make(chan error, 1)
go func() {
if err := s.httpServer.ServeTLS(ln, certFile, keyFile); err != nil && !errors.Is(err, http.ErrServerClosed) {
errCh <- err
}
close(errCh)
}()
select {
case err := <-errCh:
return err
case <-ctx.Done():
stop()
return s.shutdown()
}
}
// Shutdown gracefully shuts down the server. It calls any registered
// onShutdown hooks, then waits for active connections to drain within
// the shutdown timeout.
func (s *Server) Shutdown(ctx context.Context) error {
s.runOnShutdown()
return s.httpServer.Shutdown(ctx)
}
// Addr returns the listener address after the server has started.
// Returns an empty string if the server has not started yet.
func (s *Server) Addr() string {
v := s.addr.Load()
if v == nil {
return ""
}
return v.(string)
}
func (s *Server) shutdown() error {
s.log("shutting down")
s.runOnShutdown()
ctx, cancel := context.WithTimeout(context.Background(), s.shutdownTimeout)
defer cancel()
if err := s.httpServer.Shutdown(ctx); err != nil {
s.log("shutdown error", slog.String("error", err.Error()))
return err
}
s.log("server stopped")
return nil
}
func (s *Server) runOnShutdown() {
for _, fn := range s.onShutdown {
fn()
}
}
func (s *Server) log(msg string, attrs ...slog.Attr) {
if s.logger != nil {
s.logger.LogAttrs(context.Background(), slog.LevelInfo, msg, attrs...)
}
}

268
server/server_test.go Normal file
View File

@@ -0,0 +1,268 @@
package server_test
import (
"bytes"
"context"
"io"
"log/slog"
"net/http"
"strings"
"testing"
"time"
"git.codelab.vc/pkg/httpx/server"
)
func TestServerLifecycle(t *testing.T) {
t.Run("starts and serves requests", func(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("hello"))
})
srv := server.New(handler, server.WithAddr(":0"))
// Start in background and wait for addr.
errCh := make(chan error, 1)
go func() { errCh <- srv.ListenAndServe() }()
waitForAddr(t, srv)
resp, err := http.Get("http://" + srv.Addr())
if err != nil {
t.Fatalf("GET failed: %v", err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
if string(body) != "hello" {
t.Fatalf("got body %q, want %q", body, "hello")
}
// Shutdown.
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := srv.Shutdown(ctx); err != nil {
t.Fatalf("shutdown failed: %v", err)
}
})
t.Run("addr returns empty before start", func(t *testing.T) {
srv := server.New(http.NotFoundHandler())
if addr := srv.Addr(); addr != "" {
t.Fatalf("got addr %q before start, want empty", addr)
}
})
}
func TestGracefulShutdown(t *testing.T) {
t.Run("calls onShutdown hooks", func(t *testing.T) {
called := false
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
})
srv := server.New(handler,
server.WithAddr(":0"),
server.WithOnShutdown(func() { called = true }),
)
go func() { _ = srv.ListenAndServe() }()
waitForAddr(t, srv)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := srv.Shutdown(ctx); err != nil {
t.Fatalf("shutdown failed: %v", err)
}
if !called {
t.Fatal("onShutdown hook was not called")
}
})
}
func TestServerWithMiddleware(t *testing.T) {
t.Run("applies middleware from options", func(t *testing.T) {
var called bool
mw := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
called = true
next.ServeHTTP(w, r)
})
}
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
})
srv := server.New(handler,
server.WithAddr(":0"),
server.WithMiddleware(mw),
)
go func() { _ = srv.ListenAndServe() }()
waitForAddr(t, srv)
resp, err := http.Get("http://" + srv.Addr())
if err != nil {
t.Fatalf("GET failed: %v", err)
}
resp.Body.Close()
if !called {
t.Fatal("middleware was not called")
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = srv.Shutdown(ctx)
})
}
func TestServerDefaults(t *testing.T) {
var buf bytes.Buffer
logger := slog.New(slog.NewTextHandler(&buf, nil))
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
})
srv := server.New(handler, append(server.Defaults(logger), server.WithAddr(":0"))...)
go func() { _ = srv.ListenAndServe() }()
waitForAddr(t, srv)
resp, err := http.Get("http://" + srv.Addr())
if err != nil {
t.Fatalf("GET failed: %v", err)
}
resp.Body.Close()
// Defaults includes RequestID middleware, so response should have X-Request-Id.
if resp.Header.Get("X-Request-Id") == "" {
t.Fatal("expected X-Request-Id header from Defaults middleware")
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = srv.Shutdown(ctx)
}
func TestServerListenError(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
})
// Use an invalid address to trigger a listen error.
srv := server.New(handler, server.WithAddr(":-1"))
err := srv.ListenAndServe()
if err == nil {
t.Fatal("expected error from invalid address, got nil")
}
}
func TestServerMultipleOnShutdownHooks(t *testing.T) {
var calls []int
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
})
srv := server.New(handler,
server.WithAddr(":0"),
server.WithOnShutdown(func() { calls = append(calls, 1) }),
server.WithOnShutdown(func() { calls = append(calls, 2) }),
server.WithOnShutdown(func() { calls = append(calls, 3) }),
)
go func() { _ = srv.ListenAndServe() }()
waitForAddr(t, srv)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := srv.Shutdown(ctx); err != nil {
t.Fatalf("shutdown failed: %v", err)
}
if len(calls) != 3 {
t.Fatalf("expected 3 hooks called, got %d: %v", len(calls), calls)
}
}
func TestServerShutdownWithLogger(t *testing.T) {
var buf bytes.Buffer
logger := slog.New(slog.NewTextHandler(&buf, nil))
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
})
srv := server.New(handler,
server.WithAddr(":0"),
server.WithLogger(logger),
)
errCh := make(chan error, 1)
go func() { errCh <- srv.ListenAndServe() }()
waitForAddr(t, srv)
// Send SIGINT to trigger graceful shutdown via ListenAndServe's signal handler.
// Instead, use Shutdown directly and check log from server start.
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = srv.Shutdown(ctx)
// The server logs "server started" on ListenAndServe.
logOutput := buf.String()
if !strings.Contains(logOutput, "server started") {
t.Fatalf("expected 'server started' in log, got %q", logOutput)
}
}
func TestServerOptions(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
})
// Verify options don't panic and server starts correctly.
srv := server.New(handler,
server.WithAddr(":0"),
server.WithReadTimeout(5*time.Second),
server.WithReadHeaderTimeout(3*time.Second),
server.WithWriteTimeout(10*time.Second),
server.WithIdleTimeout(60*time.Second),
server.WithShutdownTimeout(5*time.Second),
)
go func() { _ = srv.ListenAndServe() }()
waitForAddr(t, srv)
resp, err := http.Get("http://" + srv.Addr())
if err != nil {
t.Fatalf("GET failed: %v", err)
}
resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("got status %d, want %d", resp.StatusCode, http.StatusOK)
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = srv.Shutdown(ctx)
}
// waitForAddr polls until the server's Addr() is non-empty.
func waitForAddr(t *testing.T, srv *server.Server) {
t.Helper()
deadline := time.Now().Add(2 * time.Second)
for time.Now().Before(deadline) {
if srv.Addr() != "" {
return
}
time.Sleep(5 * time.Millisecond)
}
t.Fatal("server did not start in time")
}