Compare commits
11 Commits
4d47918a66
...
89cfc38f0e
| Author | SHA1 | Date | |
|---|---|---|---|
| 89cfc38f0e | |||
| 8a63f142a7 | |||
| 21274c178a | |||
| 49be6f8a7e | |||
| 3395f70abd | |||
| 7a2cef00c3 | |||
| de5bf9a6d9 | |||
| 7f12b0c87a | |||
| 1b322c8c81 | |||
| b40a373675 | |||
| f6384ecbea |
12
CLAUDE.md
12
CLAUDE.md
@@ -22,15 +22,25 @@ go vet ./... # static analysis
|
||||
- **Sentinel errors**: canonical values live in sub-packages, root package re-exports as aliases
|
||||
- **balancer.Transport** returns `(Middleware, *Closer)` — Closer must be tracked for health checker shutdown
|
||||
- **Client.Close()** stops the health checker goroutine
|
||||
- **Client.Patch()** — PATCH method, same pattern as Put/Post
|
||||
- **NewFormRequest** — form-encoded request builder (`application/x-www-form-urlencoded`) with `GetBody` for retry
|
||||
- **WithMaxResponseBody** — wraps `resp.Body` with `io.LimitedReader` to prevent OOM
|
||||
- **middleware.RequestID()** — propagates request ID from context to outgoing `X-Request-Id` header
|
||||
- **`internal/requestid`** — shared context key used by both `server` and `middleware` packages to avoid circular imports
|
||||
|
||||
### Server (`server/`)
|
||||
- **Core pattern**: middleware is `func(http.Handler) http.Handler`
|
||||
- **Server** wraps `http.Server` with `net.Listener`, graceful shutdown via signal handling, lifecycle hooks
|
||||
- **Router** wraps `http.ServeMux` — supports groups with prefix + middleware inheritance, `Mount` for sub-handlers
|
||||
- **Router** wraps `http.ServeMux` — supports groups with prefix + middleware inheritance, `Mount` for sub-handlers, `WithNotFoundHandler` for custom 404
|
||||
- **Middleware chain** via `Chain(A, B, C)` — A outermost, C innermost (same as client side)
|
||||
- **statusWriter** wraps `http.ResponseWriter` to capture status; implements `Unwrap()` for `http.ResponseController`
|
||||
- **Defaults()** preset: RequestID → Recovery → Logging + production timeouts
|
||||
- **HealthHandler** exposes `GET /healthz` (liveness) and `GET /readyz` (readiness with pluggable checkers)
|
||||
- **CORS** middleware — preflight OPTIONS handling, `AllowOrigins`, `AllowMethods`, `AllowHeaders`, `ExposeHeaders`, `AllowCredentials`, `MaxAge`
|
||||
- **RateLimit** middleware — per-key token bucket (`sync.Map`), IP from `X-Forwarded-For`, `WithRate`/`WithBurst`/`WithKeyFunc`, uses `internal/clock`
|
||||
- **MaxBodySize** middleware — wraps `r.Body` via `http.MaxBytesReader`
|
||||
- **Timeout** middleware — wraps `http.TimeoutHandler`, returns 503
|
||||
- **WriteJSON** / **WriteError** — JSON response helpers in `server/respond.go`
|
||||
|
||||
## Conventions
|
||||
|
||||
|
||||
78
README.md
78
README.md
@@ -1,6 +1,6 @@
|
||||
# httpx
|
||||
|
||||
HTTP client and server toolkit for Go microservices. Client side: retry, load balancing, circuit breaking — all as `http.RoundTripper` middleware. Server side: routing, middleware (request ID, recovery, logging), health checks, graceful shutdown. stdlib only, zero external deps.
|
||||
HTTP client and server toolkit for Go microservices. Client side: retry, load balancing, circuit breaking, request ID propagation, response size limits — all as `http.RoundTripper` middleware. Server side: routing, middleware (request ID, recovery, logging, CORS, rate limiting, body limits, timeouts), health checks, JSON helpers, graceful shutdown. stdlib only, zero external deps.
|
||||
|
||||
```
|
||||
go get git.codelab.vc/pkg/httpx
|
||||
@@ -29,6 +29,16 @@ if err != nil {
|
||||
|
||||
var user User
|
||||
resp.JSON(&user)
|
||||
|
||||
// PATCH request
|
||||
resp, err = client.Patch(ctx, "/users/123", strings.NewReader(`{"name":"updated"}`))
|
||||
|
||||
// Form-encoded request (OAuth, webhooks, etc.)
|
||||
req, _ := httpx.NewFormRequest(ctx, http.MethodPost, "/oauth/token", url.Values{
|
||||
"grant_type": {"client_credentials"},
|
||||
"scope": {"read write"},
|
||||
})
|
||||
resp, err = client.Do(ctx, req)
|
||||
```
|
||||
|
||||
## Packages
|
||||
@@ -42,7 +52,7 @@ Client middleware is `func(http.RoundTripper) http.RoundTripper`. Use them with
|
||||
| `retry` | Exponential/constant backoff, Retry-After support. Idempotent methods only by default. |
|
||||
| `balancer` | Round robin, failover, weighted random. Optional background health checks. |
|
||||
| `circuitbreaker` | Per-host state machine (closed/open/half-open). Stops hammering dead endpoints. |
|
||||
| `middleware` | Logging (slog), default headers, bearer/basic auth, panic recovery. |
|
||||
| `middleware` | Logging (slog), default headers, bearer/basic auth, panic recovery, request ID propagation. |
|
||||
|
||||
### Server
|
||||
|
||||
@@ -56,6 +66,12 @@ Server middleware is `func(http.Handler) http.Handler`. The `server` package pro
|
||||
| `server.Recovery` | Recovers panics, returns 500, logs stack trace. |
|
||||
| `server.Logging` | Structured request logging (method, path, status, duration, request ID). |
|
||||
| `server.HealthHandler` | Liveness (`/healthz`) and readiness (`/readyz`) endpoints with pluggable checkers. |
|
||||
| `server.CORS` | Cross-origin resource sharing with preflight handling and functional options. |
|
||||
| `server.RateLimit` | Per-key token bucket rate limiting with IP extraction and `Retry-After`. |
|
||||
| `server.MaxBodySize` | Limits request body size via `http.MaxBytesReader`. |
|
||||
| `server.Timeout` | Context-based request timeout, returns 503 on expiry. |
|
||||
| `server.WriteJSON` | JSON response helper, sets Content-Type and status. |
|
||||
| `server.WriteError` | JSON error response (`{"error": "..."}`) helper. |
|
||||
| `server.Defaults` | Production preset: RequestID → Recovery → Logging + sensible timeouts. |
|
||||
|
||||
The client assembles them in this order:
|
||||
@@ -111,9 +127,15 @@ httpClient := &http.Client{
|
||||
```go
|
||||
logger := slog.Default()
|
||||
|
||||
r := server.NewRouter()
|
||||
r.HandleFunc("GET /hello", func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Write([]byte("world"))
|
||||
r := server.NewRouter(
|
||||
// Custom JSON 404 instead of plain text
|
||||
server.WithNotFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
server.WriteError(w, 404, "not found")
|
||||
})),
|
||||
)
|
||||
|
||||
r.HandleFunc("GET /hello", func(w http.ResponseWriter, r *http.Request) {
|
||||
server.WriteJSON(w, 200, map[string]string{"message": "world"})
|
||||
})
|
||||
|
||||
// Groups with middleware
|
||||
@@ -125,10 +147,54 @@ r.Mount("/", server.HealthHandler(
|
||||
func() error { return db.Ping() },
|
||||
))
|
||||
|
||||
srv := server.New(r, server.Defaults(logger)...)
|
||||
srv := server.New(r,
|
||||
append(server.Defaults(logger),
|
||||
// Protection middleware
|
||||
server.WithMiddleware(
|
||||
server.CORS(
|
||||
server.AllowOrigins("https://app.example.com"),
|
||||
server.AllowMethods("GET", "POST", "PUT", "PATCH", "DELETE"),
|
||||
server.AllowHeaders("Authorization", "Content-Type"),
|
||||
server.MaxAge(3600),
|
||||
),
|
||||
server.RateLimit(
|
||||
server.WithRate(100),
|
||||
server.WithBurst(200),
|
||||
),
|
||||
server.MaxBodySize(1<<20), // 1 MB
|
||||
server.Timeout(30*time.Second),
|
||||
),
|
||||
)...,
|
||||
)
|
||||
log.Fatal(srv.ListenAndServe()) // graceful shutdown on SIGINT/SIGTERM
|
||||
```
|
||||
|
||||
## Client request ID propagation
|
||||
|
||||
In microservices, forward the incoming request ID to downstream calls:
|
||||
|
||||
```go
|
||||
client := httpx.New(
|
||||
httpx.WithMiddleware(middleware.RequestID()),
|
||||
)
|
||||
|
||||
// In a server handler — the context already has the request ID from server.RequestID():
|
||||
func handler(w http.ResponseWriter, r *http.Request) {
|
||||
// ID is automatically forwarded as X-Request-Id
|
||||
resp, err := client.Get(r.Context(), "https://downstream/api")
|
||||
}
|
||||
```
|
||||
|
||||
## Response body limit
|
||||
|
||||
Protect against OOM from unexpectedly large upstream responses:
|
||||
|
||||
```go
|
||||
client := httpx.New(
|
||||
httpx.WithMaxResponseBody(10 << 20), // 10 MB max
|
||||
)
|
||||
```
|
||||
|
||||
## Requirements
|
||||
|
||||
Go 1.24+, stdlib only.
|
||||
|
||||
18
client.go
18
client.go
@@ -20,6 +20,7 @@ type Client struct {
|
||||
baseURL string
|
||||
errorMapper ErrorMapper
|
||||
balancerCloser *balancer.Closer
|
||||
maxResponseBody int64
|
||||
}
|
||||
|
||||
// New creates a new Client with the given options.
|
||||
@@ -80,6 +81,7 @@ func New(opts ...Option) *Client {
|
||||
baseURL: o.baseURL,
|
||||
errorMapper: o.errorMapper,
|
||||
balancerCloser: balancerCloser,
|
||||
maxResponseBody: o.maxResponseBody,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -99,6 +101,13 @@ func (c *Client) Do(ctx context.Context, req *http.Request) (*Response, error) {
|
||||
}
|
||||
}
|
||||
|
||||
if c.maxResponseBody > 0 {
|
||||
resp.Body = &limitedReadCloser{
|
||||
R: io.LimitedReader{R: resp.Body, N: c.maxResponseBody},
|
||||
C: resp.Body,
|
||||
}
|
||||
}
|
||||
|
||||
r := newResponse(resp)
|
||||
|
||||
if c.errorMapper != nil {
|
||||
@@ -142,6 +151,15 @@ func (c *Client) Put(ctx context.Context, url string, body io.Reader) (*Response
|
||||
return c.Do(ctx, req)
|
||||
}
|
||||
|
||||
// Patch performs a PATCH request to the given URL with the given body.
|
||||
func (c *Client) Patch(ctx context.Context, url string, body io.Reader) (*Response, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPatch, url, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.Do(ctx, req)
|
||||
}
|
||||
|
||||
// Delete performs a DELETE request to the given URL.
|
||||
func (c *Client) Delete(ctx context.Context, url string) (*Response, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, url, nil)
|
||||
|
||||
@@ -24,6 +24,7 @@ type clientOptions struct {
|
||||
enableCB bool
|
||||
endpoints []balancer.Endpoint
|
||||
balancerOpts []balancer.Option
|
||||
maxResponseBody int64
|
||||
}
|
||||
|
||||
// Option configures a Client.
|
||||
@@ -85,3 +86,11 @@ func WithEndpoints(eps ...balancer.Endpoint) Option {
|
||||
func WithBalancer(opts ...balancer.Option) Option {
|
||||
return func(o *clientOptions) { o.balancerOpts = opts }
|
||||
}
|
||||
|
||||
// WithMaxResponseBody limits the number of bytes read from response bodies
|
||||
// by Response.Bytes (and by extension String, JSON, XML). If the response
|
||||
// body exceeds n bytes, reading stops and returns an error.
|
||||
// A value of 0 means no limit (the default).
|
||||
func WithMaxResponseBody(n int64) Option {
|
||||
return func(o *clientOptions) { o.maxResponseBody = n }
|
||||
}
|
||||
|
||||
45
client_patch_test.go
Normal file
45
client_patch_test.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package httpx_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"git.codelab.vc/pkg/httpx"
|
||||
)
|
||||
|
||||
func TestClient_Patch(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPatch {
|
||||
t.Errorf("expected PATCH, got %s", r.Method)
|
||||
}
|
||||
b, _ := io.ReadAll(r.Body)
|
||||
if string(b) != `{"name":"updated"}` {
|
||||
t.Errorf("expected body %q, got %q", `{"name":"updated"}`, string(b))
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprint(w, "patched")
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
client := httpx.New()
|
||||
resp, err := client.Patch(context.Background(), srv.URL+"/item/1", strings.NewReader(`{"name":"updated"}`))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
body, err := resp.String()
|
||||
if err != nil {
|
||||
t.Fatalf("reading body: %v", err)
|
||||
}
|
||||
if body != "patched" {
|
||||
t.Errorf("expected body %q, got %q", "patched", body)
|
||||
}
|
||||
}
|
||||
19
internal/requestid/requestid.go
Normal file
19
internal/requestid/requestid.go
Normal file
@@ -0,0 +1,19 @@
|
||||
// Package requestid provides a shared context key for request IDs,
|
||||
// allowing both client and server packages to access request IDs
|
||||
// without circular imports.
|
||||
package requestid
|
||||
|
||||
import "context"
|
||||
|
||||
type key struct{}
|
||||
|
||||
// NewContext returns a context with the given request ID.
|
||||
func NewContext(ctx context.Context, id string) context.Context {
|
||||
return context.WithValue(ctx, key{}, id)
|
||||
}
|
||||
|
||||
// FromContext returns the request ID from ctx, or empty string if not set.
|
||||
func FromContext(ctx context.Context) string {
|
||||
id, _ := ctx.Value(key{}).(string)
|
||||
return id
|
||||
}
|
||||
23
middleware/requestid.go
Normal file
23
middleware/requestid.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"git.codelab.vc/pkg/httpx/internal/requestid"
|
||||
)
|
||||
|
||||
// RequestID returns a middleware that propagates the request ID from the
|
||||
// request context to the outgoing X-Request-Id header. This pairs with
|
||||
// the server.RequestID middleware: the server stores the ID in the context,
|
||||
// and the client middleware forwards it to downstream services.
|
||||
func RequestID() Middleware {
|
||||
return func(next http.RoundTripper) http.RoundTripper {
|
||||
return RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
if id := requestid.FromContext(req.Context()); id != "" {
|
||||
req = req.Clone(req.Context())
|
||||
req.Header.Set("X-Request-Id", id)
|
||||
}
|
||||
return next.RoundTrip(req)
|
||||
})
|
||||
}
|
||||
}
|
||||
69
middleware/requestid_test.go
Normal file
69
middleware/requestid_test.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package middleware_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"git.codelab.vc/pkg/httpx/internal/requestid"
|
||||
"git.codelab.vc/pkg/httpx/middleware"
|
||||
)
|
||||
|
||||
func TestRequestID(t *testing.T) {
|
||||
t.Run("propagates ID from context", func(t *testing.T) {
|
||||
var gotHeader string
|
||||
base := middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
gotHeader = req.Header.Get("X-Request-Id")
|
||||
return &http.Response{StatusCode: http.StatusOK, Body: http.NoBody}, nil
|
||||
})
|
||||
|
||||
mw := middleware.RequestID()(base)
|
||||
|
||||
ctx := requestid.NewContext(context.Background(), "test-id-123")
|
||||
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "http://example.com", nil)
|
||||
_, err := mw.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if gotHeader != "test-id-123" {
|
||||
t.Fatalf("X-Request-Id = %q, want %q", gotHeader, "test-id-123")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no ID in context skips header", func(t *testing.T) {
|
||||
var gotHeader string
|
||||
base := middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
gotHeader = req.Header.Get("X-Request-Id")
|
||||
return &http.Response{StatusCode: http.StatusOK, Body: http.NoBody}, nil
|
||||
})
|
||||
|
||||
mw := middleware.RequestID()(base)
|
||||
|
||||
req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://example.com", nil)
|
||||
_, err := mw.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if gotHeader != "" {
|
||||
t.Fatalf("expected no X-Request-Id header, got %q", gotHeader)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("does not mutate original request", func(t *testing.T) {
|
||||
base := middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
return &http.Response{StatusCode: http.StatusOK, Body: http.NoBody}, nil
|
||||
})
|
||||
|
||||
mw := middleware.RequestID()(base)
|
||||
|
||||
ctx := requestid.NewContext(context.Background(), "test-id")
|
||||
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "http://example.com", nil)
|
||||
_, _ = mw.RoundTrip(req)
|
||||
|
||||
if req.Header.Get("X-Request-Id") != "" {
|
||||
t.Fatal("original request was mutated")
|
||||
}
|
||||
})
|
||||
}
|
||||
18
request.go
18
request.go
@@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
// NewRequest creates an http.Request with context. It is a convenience
|
||||
@@ -32,3 +33,20 @@ func NewJSONRequest(ctx context.Context, method, url string, body any) (*http.Re
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// NewFormRequest creates an http.Request with a form-encoded body and
|
||||
// sets Content-Type to application/x-www-form-urlencoded.
|
||||
// The GetBody function is set so that the request can be retried.
|
||||
func NewFormRequest(ctx context.Context, method, rawURL string, values url.Values) (*http.Request, error) {
|
||||
encoded := values.Encode()
|
||||
b := []byte(encoded)
|
||||
req, err := http.NewRequestWithContext(ctx, method, rawURL, bytes.NewReader(b))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.GetBody = func() (io.ReadCloser, error) {
|
||||
return io.NopCloser(bytes.NewReader(b)), nil
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
|
||||
80
request_form_test.go
Normal file
80
request_form_test.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package httpx_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"git.codelab.vc/pkg/httpx"
|
||||
)
|
||||
|
||||
func TestNewFormRequest(t *testing.T) {
|
||||
t.Run("body is form-encoded", func(t *testing.T) {
|
||||
values := url.Values{"username": {"alice"}, "scope": {"read"}}
|
||||
req, err := httpx.NewFormRequest(context.Background(), http.MethodPost, "http://example.com/token", values)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("reading body: %v", err)
|
||||
}
|
||||
|
||||
parsed, err := url.ParseQuery(string(body))
|
||||
if err != nil {
|
||||
t.Fatalf("parsing form: %v", err)
|
||||
}
|
||||
if parsed.Get("username") != "alice" {
|
||||
t.Errorf("username = %q, want %q", parsed.Get("username"), "alice")
|
||||
}
|
||||
if parsed.Get("scope") != "read" {
|
||||
t.Errorf("scope = %q, want %q", parsed.Get("scope"), "read")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("content type is set", func(t *testing.T) {
|
||||
values := url.Values{"key": {"value"}}
|
||||
req, err := httpx.NewFormRequest(context.Background(), http.MethodPost, "http://example.com", values)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
ct := req.Header.Get("Content-Type")
|
||||
if ct != "application/x-www-form-urlencoded" {
|
||||
t.Errorf("Content-Type = %q, want %q", ct, "application/x-www-form-urlencoded")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetBody works for retry", func(t *testing.T) {
|
||||
values := url.Values{"key": {"value"}}
|
||||
req, err := httpx.NewFormRequest(context.Background(), http.MethodPost, "http://example.com", values)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if req.GetBody == nil {
|
||||
t.Fatal("GetBody is nil")
|
||||
}
|
||||
|
||||
b1, err := io.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("reading body: %v", err)
|
||||
}
|
||||
|
||||
body2, err := req.GetBody()
|
||||
if err != nil {
|
||||
t.Fatalf("GetBody(): %v", err)
|
||||
}
|
||||
b2, err := io.ReadAll(body2)
|
||||
if err != nil {
|
||||
t.Fatalf("reading body2: %v", err)
|
||||
}
|
||||
|
||||
if string(b1) != string(b2) {
|
||||
t.Errorf("GetBody returned different data: %q vs %q", b1, b2)
|
||||
}
|
||||
})
|
||||
}
|
||||
15
response.go
15
response.go
@@ -97,3 +97,18 @@ func (r *Response) BodyReader() io.Reader {
|
||||
}
|
||||
return r.Body
|
||||
}
|
||||
|
||||
// limitedReadCloser wraps an io.LimitedReader with a separate Closer
|
||||
// so the original body can be closed.
|
||||
type limitedReadCloser struct {
|
||||
R io.LimitedReader
|
||||
C io.Closer
|
||||
}
|
||||
|
||||
func (l *limitedReadCloser) Read(p []byte) (int, error) {
|
||||
return l.R.Read(p)
|
||||
}
|
||||
|
||||
func (l *limitedReadCloser) Close() error {
|
||||
return l.C.Close()
|
||||
}
|
||||
|
||||
76
response_limit_test.go
Normal file
76
response_limit_test.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package httpx_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"git.codelab.vc/pkg/httpx"
|
||||
)
|
||||
|
||||
func TestClient_MaxResponseBody(t *testing.T) {
|
||||
t.Run("allows response within limit", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
fmt.Fprint(w, "hello")
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
client := httpx.New(httpx.WithMaxResponseBody(1024))
|
||||
resp, err := client.Get(context.Background(), srv.URL+"/")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
body, err := resp.String()
|
||||
if err != nil {
|
||||
t.Fatalf("reading body: %v", err)
|
||||
}
|
||||
if body != "hello" {
|
||||
t.Fatalf("body = %q, want %q", body, "hello")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("truncates response exceeding limit", func(t *testing.T) {
|
||||
largeBody := strings.Repeat("x", 1000)
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
fmt.Fprint(w, largeBody)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
client := httpx.New(httpx.WithMaxResponseBody(100))
|
||||
resp, err := client.Get(context.Background(), srv.URL+"/")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
b, err := resp.Bytes()
|
||||
if err != nil {
|
||||
t.Fatalf("reading body: %v", err)
|
||||
}
|
||||
if len(b) != 100 {
|
||||
t.Fatalf("body length = %d, want %d", len(b), 100)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no limit when zero", func(t *testing.T) {
|
||||
largeBody := strings.Repeat("x", 10000)
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
fmt.Fprint(w, largeBody)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
client := httpx.New()
|
||||
resp, err := client.Get(context.Background(), srv.URL+"/")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
b, err := resp.Bytes()
|
||||
if err != nil {
|
||||
t.Fatalf("reading body: %v", err)
|
||||
}
|
||||
if len(b) != 10000 {
|
||||
t.Fatalf("body length = %d, want %d", len(b), 10000)
|
||||
}
|
||||
})
|
||||
}
|
||||
15
server/middleware_bodylimit.go
Normal file
15
server/middleware_bodylimit.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package server
|
||||
|
||||
import "net/http"
|
||||
|
||||
// MaxBodySize returns a middleware that limits the size of incoming request
|
||||
// bodies. If the body exceeds n bytes, the server returns 413 Request Entity
|
||||
// Too Large. It wraps the body with http.MaxBytesReader.
|
||||
func MaxBodySize(n int64) Middleware {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
r.Body = http.MaxBytesReader(w, r.Body, n)
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
61
server/middleware_bodylimit_test.go
Normal file
61
server/middleware_bodylimit_test.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package server_test
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"git.codelab.vc/pkg/httpx/server"
|
||||
)
|
||||
|
||||
func TestMaxBodySize(t *testing.T) {
|
||||
t.Run("allows body within limit", func(t *testing.T) {
|
||||
handler := server.MaxBodySize(1024)(
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
b, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(b)
|
||||
}),
|
||||
)
|
||||
|
||||
body := strings.NewReader("hello")
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/", body)
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
if w.Body.String() != "hello" {
|
||||
t.Fatalf("got body %q, want %q", w.Body.String(), "hello")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("rejects body exceeding limit", func(t *testing.T) {
|
||||
handler := server.MaxBodySize(5)(
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, "body too large", http.StatusRequestEntityTooLarge)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}),
|
||||
)
|
||||
|
||||
body := strings.NewReader("this is longer than 5 bytes")
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/", body)
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusRequestEntityTooLarge {
|
||||
t.Fatalf("got status %d, want %d", w.Code, http.StatusRequestEntityTooLarge)
|
||||
}
|
||||
})
|
||||
}
|
||||
128
server/middleware_cors.go
Normal file
128
server/middleware_cors.go
Normal file
@@ -0,0 +1,128 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type corsOptions struct {
|
||||
allowOrigins []string
|
||||
allowMethods []string
|
||||
allowHeaders []string
|
||||
exposeHeaders []string
|
||||
allowCredentials bool
|
||||
maxAge int
|
||||
}
|
||||
|
||||
// CORSOption configures the CORS middleware.
|
||||
type CORSOption func(*corsOptions)
|
||||
|
||||
// AllowOrigins sets the allowed origins. Use "*" to allow any origin.
|
||||
// Default is no origins (CORS disabled).
|
||||
func AllowOrigins(origins ...string) CORSOption {
|
||||
return func(o *corsOptions) { o.allowOrigins = origins }
|
||||
}
|
||||
|
||||
// AllowMethods sets the allowed HTTP methods for preflight requests.
|
||||
// Default is GET, POST, HEAD.
|
||||
func AllowMethods(methods ...string) CORSOption {
|
||||
return func(o *corsOptions) { o.allowMethods = methods }
|
||||
}
|
||||
|
||||
// AllowHeaders sets the allowed request headers for preflight requests.
|
||||
func AllowHeaders(headers ...string) CORSOption {
|
||||
return func(o *corsOptions) { o.allowHeaders = headers }
|
||||
}
|
||||
|
||||
// ExposeHeaders sets headers that browsers are allowed to access.
|
||||
func ExposeHeaders(headers ...string) CORSOption {
|
||||
return func(o *corsOptions) { o.exposeHeaders = headers }
|
||||
}
|
||||
|
||||
// AllowCredentials indicates whether the response to the request can be
|
||||
// exposed when the credentials flag is true.
|
||||
func AllowCredentials(allow bool) CORSOption {
|
||||
return func(o *corsOptions) { o.allowCredentials = allow }
|
||||
}
|
||||
|
||||
// MaxAge sets the maximum time (in seconds) a preflight result can be cached.
|
||||
func MaxAge(seconds int) CORSOption {
|
||||
return func(o *corsOptions) { o.maxAge = seconds }
|
||||
}
|
||||
|
||||
// CORS returns a middleware that handles Cross-Origin Resource Sharing.
|
||||
// It processes preflight OPTIONS requests and sets the appropriate
|
||||
// Access-Control-* response headers.
|
||||
func CORS(opts ...CORSOption) Middleware {
|
||||
o := &corsOptions{
|
||||
allowMethods: []string{"GET", "POST", "HEAD"},
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(o)
|
||||
}
|
||||
|
||||
allowedOrigins := make(map[string]struct{}, len(o.allowOrigins))
|
||||
allowAll := false
|
||||
for _, origin := range o.allowOrigins {
|
||||
if origin == "*" {
|
||||
allowAll = true
|
||||
}
|
||||
allowedOrigins[origin] = struct{}{}
|
||||
}
|
||||
|
||||
methods := strings.Join(o.allowMethods, ", ")
|
||||
headers := strings.Join(o.allowHeaders, ", ")
|
||||
expose := strings.Join(o.exposeHeaders, ", ")
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
origin := r.Header.Get("Origin")
|
||||
if origin == "" {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
allowed := allowAll
|
||||
if !allowed {
|
||||
_, allowed = allowedOrigins[origin]
|
||||
}
|
||||
if !allowed {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Set the allowed origin. When credentials are enabled,
|
||||
// we must echo the specific origin, not "*".
|
||||
if allowAll && !o.allowCredentials {
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
} else {
|
||||
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||||
}
|
||||
|
||||
if o.allowCredentials {
|
||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
}
|
||||
if expose != "" {
|
||||
w.Header().Set("Access-Control-Expose-Headers", expose)
|
||||
}
|
||||
|
||||
w.Header().Add("Vary", "Origin")
|
||||
|
||||
// Handle preflight.
|
||||
if r.Method == http.MethodOptions && r.Header.Get("Access-Control-Request-Method") != "" {
|
||||
w.Header().Set("Access-Control-Allow-Methods", methods)
|
||||
if headers != "" {
|
||||
w.Header().Set("Access-Control-Allow-Headers", headers)
|
||||
}
|
||||
if o.maxAge > 0 {
|
||||
w.Header().Set("Access-Control-Max-Age", strconv.Itoa(o.maxAge))
|
||||
}
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
143
server/middleware_cors_test.go
Normal file
143
server/middleware_cors_test.go
Normal file
@@ -0,0 +1,143 @@
|
||||
package server_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"git.codelab.vc/pkg/httpx/server"
|
||||
)
|
||||
|
||||
func TestCORS(t *testing.T) {
|
||||
okHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
t.Run("no Origin header passes through", func(t *testing.T) {
|
||||
mw := server.CORS(server.AllowOrigins("*"))(okHandler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if w.Header().Get("Access-Control-Allow-Origin") != "" {
|
||||
t.Fatal("expected no CORS headers without Origin")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("wildcard origin", func(t *testing.T) {
|
||||
mw := server.CORS(server.AllowOrigins("*"))(okHandler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("Origin", "http://example.com")
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if got := w.Header().Get("Access-Control-Allow-Origin"); got != "*" {
|
||||
t.Fatalf("Access-Control-Allow-Origin = %q, want %q", got, "*")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("specific origin allowed", func(t *testing.T) {
|
||||
mw := server.CORS(server.AllowOrigins("http://example.com"))(okHandler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("Origin", "http://example.com")
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if got := w.Header().Get("Access-Control-Allow-Origin"); got != "http://example.com" {
|
||||
t.Fatalf("Access-Control-Allow-Origin = %q, want %q", got, "http://example.com")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("disallowed origin gets no CORS headers", func(t *testing.T) {
|
||||
mw := server.CORS(server.AllowOrigins("http://example.com"))(okHandler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("Origin", "http://evil.com")
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if got := w.Header().Get("Access-Control-Allow-Origin"); got != "" {
|
||||
t.Fatalf("expected no Access-Control-Allow-Origin for disallowed origin, got %q", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("preflight OPTIONS", func(t *testing.T) {
|
||||
mw := server.CORS(
|
||||
server.AllowOrigins("http://example.com"),
|
||||
server.AllowMethods("GET", "POST", "PUT"),
|
||||
server.AllowHeaders("Authorization", "Content-Type"),
|
||||
server.MaxAge(3600),
|
||||
)(okHandler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodOptions, "/api/data", nil)
|
||||
req.Header.Set("Origin", "http://example.com")
|
||||
req.Header.Set("Access-Control-Request-Method", "POST")
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusNoContent {
|
||||
t.Fatalf("got status %d, want %d", w.Code, http.StatusNoContent)
|
||||
}
|
||||
if got := w.Header().Get("Access-Control-Allow-Methods"); got != "GET, POST, PUT" {
|
||||
t.Fatalf("Access-Control-Allow-Methods = %q, want %q", got, "GET, POST, PUT")
|
||||
}
|
||||
if got := w.Header().Get("Access-Control-Allow-Headers"); got != "Authorization, Content-Type" {
|
||||
t.Fatalf("Access-Control-Allow-Headers = %q, want %q", got, "Authorization, Content-Type")
|
||||
}
|
||||
if got := w.Header().Get("Access-Control-Max-Age"); got != "3600" {
|
||||
t.Fatalf("Access-Control-Max-Age = %q, want %q", got, "3600")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("credentials with specific origin", func(t *testing.T) {
|
||||
mw := server.CORS(
|
||||
server.AllowOrigins("*"),
|
||||
server.AllowCredentials(true),
|
||||
)(okHandler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("Origin", "http://example.com")
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
// With credentials, must echo specific origin even with wildcard config.
|
||||
if got := w.Header().Get("Access-Control-Allow-Origin"); got != "http://example.com" {
|
||||
t.Fatalf("Access-Control-Allow-Origin = %q, want %q", got, "http://example.com")
|
||||
}
|
||||
if got := w.Header().Get("Access-Control-Allow-Credentials"); got != "true" {
|
||||
t.Fatalf("Access-Control-Allow-Credentials = %q, want %q", got, "true")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("expose headers", func(t *testing.T) {
|
||||
mw := server.CORS(
|
||||
server.AllowOrigins("*"),
|
||||
server.ExposeHeaders("X-Custom", "X-Request-Id"),
|
||||
)(okHandler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("Origin", "http://example.com")
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if got := w.Header().Get("Access-Control-Expose-Headers"); got != "X-Custom, X-Request-Id" {
|
||||
t.Fatalf("Access-Control-Expose-Headers = %q, want %q", got, "X-Custom, X-Request-Id")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Vary header is set", func(t *testing.T) {
|
||||
mw := server.CORS(server.AllowOrigins("*"))(okHandler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("Origin", "http://example.com")
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if got := w.Header().Get("Vary"); got != "Origin" {
|
||||
t.Fatalf("Vary = %q, want %q", got, "Origin")
|
||||
}
|
||||
})
|
||||
}
|
||||
129
server/middleware_ratelimit.go
Normal file
129
server/middleware_ratelimit.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.codelab.vc/pkg/httpx/internal/clock"
|
||||
)
|
||||
|
||||
type rateLimitOptions struct {
|
||||
rate float64
|
||||
burst int
|
||||
keyFunc func(r *http.Request) string
|
||||
clock clock.Clock
|
||||
}
|
||||
|
||||
// RateLimitOption configures the RateLimit middleware.
|
||||
type RateLimitOption func(*rateLimitOptions)
|
||||
|
||||
// WithRate sets the token refill rate (tokens per second).
|
||||
func WithRate(tokensPerSecond float64) RateLimitOption {
|
||||
return func(o *rateLimitOptions) { o.rate = tokensPerSecond }
|
||||
}
|
||||
|
||||
// WithBurst sets the maximum burst size (bucket capacity).
|
||||
func WithBurst(n int) RateLimitOption {
|
||||
return func(o *rateLimitOptions) { o.burst = n }
|
||||
}
|
||||
|
||||
// WithKeyFunc sets a custom function to extract the rate-limit key from a
|
||||
// request. By default, the client IP address is used.
|
||||
func WithKeyFunc(fn func(r *http.Request) string) RateLimitOption {
|
||||
return func(o *rateLimitOptions) { o.keyFunc = fn }
|
||||
}
|
||||
|
||||
// withRateLimitClock sets the clock for testing. Not exported.
|
||||
func withRateLimitClock(c clock.Clock) RateLimitOption {
|
||||
return func(o *rateLimitOptions) { o.clock = c }
|
||||
}
|
||||
|
||||
// RateLimit returns a middleware that limits requests using a per-key token
|
||||
// bucket algorithm. When the limit is exceeded, it returns 429 Too Many
|
||||
// Requests with a Retry-After header.
|
||||
func RateLimit(opts ...RateLimitOption) Middleware {
|
||||
o := &rateLimitOptions{
|
||||
rate: 10,
|
||||
burst: 20,
|
||||
clock: clock.System(),
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(o)
|
||||
}
|
||||
if o.keyFunc == nil {
|
||||
o.keyFunc = clientIP
|
||||
}
|
||||
|
||||
var buckets sync.Map
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
key := o.keyFunc(r)
|
||||
val, _ := buckets.LoadOrStore(key, &bucket{
|
||||
tokens: float64(o.burst),
|
||||
lastTime: o.clock.Now(),
|
||||
})
|
||||
b := val.(*bucket)
|
||||
|
||||
b.mu.Lock()
|
||||
now := o.clock.Now()
|
||||
elapsed := now.Sub(b.lastTime).Seconds()
|
||||
b.tokens += elapsed * o.rate
|
||||
if b.tokens > float64(o.burst) {
|
||||
b.tokens = float64(o.burst)
|
||||
}
|
||||
b.lastTime = now
|
||||
|
||||
if b.tokens < 1 {
|
||||
retryAfter := (1 - b.tokens) / o.rate
|
||||
b.mu.Unlock()
|
||||
w.Header().Set("Retry-After", strconv.Itoa(int(retryAfter)+1))
|
||||
http.Error(w, "Too Many Requests", http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
|
||||
b.tokens--
|
||||
b.mu.Unlock()
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type bucket struct {
|
||||
mu sync.Mutex
|
||||
tokens float64
|
||||
lastTime time.Time
|
||||
}
|
||||
|
||||
// clientIP extracts the client IP from the request. It checks
|
||||
// X-Forwarded-For first, then X-Real-Ip, and falls back to RemoteAddr.
|
||||
func clientIP(r *http.Request) string {
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
// First IP in the comma-separated list is the original client.
|
||||
if i := indexOf(xff, ','); i > 0 {
|
||||
return xff[:i]
|
||||
}
|
||||
return xff
|
||||
}
|
||||
if xri := r.Header.Get("X-Real-Ip"); xri != "" {
|
||||
return xri
|
||||
}
|
||||
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
return r.RemoteAddr
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
func indexOf(s string, b byte) int {
|
||||
for i := range len(s) {
|
||||
if s[i] == b {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
171
server/middleware_ratelimit_test.go
Normal file
171
server/middleware_ratelimit_test.go
Normal file
@@ -0,0 +1,171 @@
|
||||
package server_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.codelab.vc/pkg/httpx/server"
|
||||
)
|
||||
|
||||
func TestRateLimit(t *testing.T) {
|
||||
okHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
t.Run("allows requests within limit", func(t *testing.T) {
|
||||
mw := server.RateLimit(
|
||||
server.WithRate(100),
|
||||
server.WithBurst(10),
|
||||
)(okHandler)
|
||||
|
||||
for i := range 10 {
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.RemoteAddr = "1.2.3.4:1234"
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("request %d: got status %d, want %d", i, w.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("rejects when burst exhausted", func(t *testing.T) {
|
||||
mw := server.RateLimit(
|
||||
server.WithRate(1),
|
||||
server.WithBurst(2),
|
||||
)(okHandler)
|
||||
|
||||
// Exhaust burst.
|
||||
for range 2 {
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.RemoteAddr = "1.2.3.4:1234"
|
||||
mw.ServeHTTP(w, req)
|
||||
}
|
||||
|
||||
// Next request should be rejected.
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.RemoteAddr = "1.2.3.4:1234"
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("got status %d, want %d", w.Code, http.StatusTooManyRequests)
|
||||
}
|
||||
if w.Header().Get("Retry-After") == "" {
|
||||
t.Fatal("expected Retry-After header")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("different IPs have independent limits", func(t *testing.T) {
|
||||
mw := server.RateLimit(
|
||||
server.WithRate(1),
|
||||
server.WithBurst(1),
|
||||
)(okHandler)
|
||||
|
||||
// First IP exhausts its limit.
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.RemoteAddr = "1.2.3.4:1234"
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
// Second IP should still be allowed.
|
||||
w = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.RemoteAddr = "5.6.7.8:5678"
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uses X-Forwarded-For", func(t *testing.T) {
|
||||
mw := server.RateLimit(
|
||||
server.WithRate(1),
|
||||
server.WithBurst(1),
|
||||
)(okHandler)
|
||||
|
||||
// Exhaust limit for 10.0.0.1.
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("X-Forwarded-For", "10.0.0.1, 192.168.1.1")
|
||||
req.RemoteAddr = "192.168.1.1:1234"
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
// Same forwarded IP should be rate limited.
|
||||
w = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("X-Forwarded-For", "10.0.0.1, 192.168.1.1")
|
||||
req.RemoteAddr = "192.168.1.1:1234"
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("got status %d, want %d", w.Code, http.StatusTooManyRequests)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("custom key function", func(t *testing.T) {
|
||||
mw := server.RateLimit(
|
||||
server.WithRate(1),
|
||||
server.WithBurst(1),
|
||||
server.WithKeyFunc(func(r *http.Request) string {
|
||||
return r.Header.Get("X-API-Key")
|
||||
}),
|
||||
)(okHandler)
|
||||
|
||||
// Exhaust key "abc".
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("X-API-Key", "abc")
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
// Same key should be rate limited.
|
||||
w = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("X-API-Key", "abc")
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("got status %d, want %d", w.Code, http.StatusTooManyRequests)
|
||||
}
|
||||
|
||||
// Different key should be allowed.
|
||||
w = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("X-API-Key", "xyz")
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("tokens refill over time", func(t *testing.T) {
|
||||
mw := server.RateLimit(
|
||||
server.WithRate(1000), // Very fast refill for test
|
||||
server.WithBurst(1),
|
||||
)(okHandler)
|
||||
|
||||
// Exhaust burst.
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.RemoteAddr = "1.2.3.4:1234"
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
// Wait a bit for refill.
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
|
||||
w = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.RemoteAddr = "1.2.3.4:1234"
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("got status %d after refill, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -5,9 +5,9 @@ import (
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type requestIDKey struct{}
|
||||
"git.codelab.vc/pkg/httpx/internal/requestid"
|
||||
)
|
||||
|
||||
// RequestID returns a middleware that assigns a unique request ID to each
|
||||
// request. If the incoming request already has an X-Request-Id header, that
|
||||
@@ -23,7 +23,7 @@ func RequestID() Middleware {
|
||||
id = newUUID()
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), requestIDKey{}, id)
|
||||
ctx := requestid.NewContext(r.Context(), id)
|
||||
w.Header().Set("X-Request-Id", id)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
@@ -33,8 +33,7 @@ func RequestID() Middleware {
|
||||
// RequestIDFromContext returns the request ID from the context, or an empty
|
||||
// string if none is set.
|
||||
func RequestIDFromContext(ctx context.Context) string {
|
||||
id, _ := ctx.Value(requestIDKey{}).(string)
|
||||
return id
|
||||
return requestid.FromContext(ctx)
|
||||
}
|
||||
|
||||
// newUUID generates a UUID v4 string using crypto/rand.
|
||||
|
||||
15
server/middleware_timeout.go
Normal file
15
server/middleware_timeout.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Timeout returns a middleware that limits request processing time.
|
||||
// If the handler does not complete within d, the client receives a
|
||||
// 503 Service Unavailable response. It wraps http.TimeoutHandler.
|
||||
func Timeout(d time.Duration) Middleware {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.TimeoutHandler(next, d, "Service Unavailable\n")
|
||||
}
|
||||
}
|
||||
49
server/middleware_timeout_test.go
Normal file
49
server/middleware_timeout_test.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package server_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.codelab.vc/pkg/httpx/server"
|
||||
)
|
||||
|
||||
func TestTimeout(t *testing.T) {
|
||||
t.Run("handler completes within timeout", func(t *testing.T) {
|
||||
handler := server.Timeout(1 * time.Second)(
|
||||
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("ok"))
|
||||
}),
|
||||
)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("handler exceeds timeout returns 503", func(t *testing.T) {
|
||||
handler := server.Timeout(10 * time.Millisecond)(
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
select {
|
||||
case <-time.After(1 * time.Second):
|
||||
case <-r.Context().Done():
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}),
|
||||
)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusServiceUnavailable {
|
||||
t.Fatalf("got status %d, want %d", w.Code, http.StatusServiceUnavailable)
|
||||
}
|
||||
})
|
||||
}
|
||||
29
server/respond.go
Normal file
29
server/respond.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// WriteJSON encodes v as JSON and writes it to w with the given status code.
|
||||
// It sets Content-Type to application/json.
|
||||
func WriteJSON(w http.ResponseWriter, status int, v any) error {
|
||||
b, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
_, err = w.Write(b)
|
||||
return err
|
||||
}
|
||||
|
||||
// WriteError writes a JSON error response with the given status code and
|
||||
// message. The response body is {"error": "<message>"}.
|
||||
func WriteError(w http.ResponseWriter, status int, msg string) error {
|
||||
return WriteJSON(w, status, errorBody{Error: msg})
|
||||
}
|
||||
|
||||
type errorBody struct {
|
||||
Error string `json:"error"`
|
||||
}
|
||||
72
server/respond_test.go
Normal file
72
server/respond_test.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package server_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"git.codelab.vc/pkg/httpx/server"
|
||||
)
|
||||
|
||||
func TestWriteJSON(t *testing.T) {
|
||||
t.Run("writes JSON with status and content type", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
type resp struct {
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
err := server.WriteJSON(w, 201, resp{ID: 1, Name: "Alice"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if w.Code != 201 {
|
||||
t.Fatalf("got status %d, want %d", w.Code, 201)
|
||||
}
|
||||
if ct := w.Header().Get("Content-Type"); ct != "application/json" {
|
||||
t.Fatalf("Content-Type = %q, want %q", ct, "application/json")
|
||||
}
|
||||
|
||||
var decoded resp
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &decoded); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if decoded.ID != 1 || decoded.Name != "Alice" {
|
||||
t.Fatalf("got %+v, want {ID:1 Name:Alice}", decoded)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns error for unmarshalable input", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
err := server.WriteJSON(w, 200, make(chan int))
|
||||
if err == nil {
|
||||
t.Fatal("expected error for channel type")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWriteError(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
err := server.WriteError(w, 404, "not found")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if w.Code != 404 {
|
||||
t.Fatalf("got status %d, want %d", w.Code, 404)
|
||||
}
|
||||
|
||||
var body struct {
|
||||
Error string `json:"error"`
|
||||
}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if body.Error != "not found" {
|
||||
t.Fatalf("error = %q, want %q", body.Error, "not found")
|
||||
}
|
||||
}
|
||||
@@ -12,13 +12,28 @@ type Router struct {
|
||||
mux *http.ServeMux
|
||||
prefix string
|
||||
middlewares []Middleware
|
||||
notFoundHandler http.Handler
|
||||
}
|
||||
|
||||
// RouterOption configures a Router.
|
||||
type RouterOption func(*Router)
|
||||
|
||||
// WithNotFoundHandler sets a custom handler for requests that don't match
|
||||
// any registered pattern. This is useful for returning JSON 404/405 responses
|
||||
// instead of the default plain text.
|
||||
func WithNotFoundHandler(h http.Handler) RouterOption {
|
||||
return func(r *Router) { r.notFoundHandler = h }
|
||||
}
|
||||
|
||||
// NewRouter creates a new Router backed by a fresh http.ServeMux.
|
||||
func NewRouter() *Router {
|
||||
return &Router{
|
||||
func NewRouter(opts ...RouterOption) *Router {
|
||||
r := &Router{
|
||||
mux: http.NewServeMux(),
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(r)
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// Handle registers a handler for the given pattern. The pattern follows
|
||||
@@ -63,6 +78,14 @@ func (r *Router) Mount(prefix string, handler http.Handler) {
|
||||
|
||||
// ServeHTTP implements http.Handler, making Router usable as a handler.
|
||||
func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
if r.notFoundHandler != nil {
|
||||
// Use the mux to check for a match. If none, use the custom handler.
|
||||
_, pattern := r.mux.Handler(req)
|
||||
if pattern == "" {
|
||||
r.notFoundHandler.ServeHTTP(w, req)
|
||||
return
|
||||
}
|
||||
}
|
||||
r.mux.ServeHTTP(w, req)
|
||||
}
|
||||
|
||||
|
||||
70
server/route_notfound_test.go
Normal file
70
server/route_notfound_test.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package server_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"git.codelab.vc/pkg/httpx/server"
|
||||
)
|
||||
|
||||
func TestRouter_NotFoundHandler(t *testing.T) {
|
||||
t.Run("custom 404 handler", func(t *testing.T) {
|
||||
notFound := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": "not found"})
|
||||
})
|
||||
|
||||
r := server.NewRouter(server.WithNotFoundHandler(notFound))
|
||||
r.HandleFunc("GET /exists", func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Matched route works normally.
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/exists", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("matched route: got status %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
// Unmatched route uses custom handler.
|
||||
w = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/nope", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Fatalf("not found: got status %d, want %d", w.Code, http.StatusNotFound)
|
||||
}
|
||||
if ct := w.Header().Get("Content-Type"); ct != "application/json" {
|
||||
t.Fatalf("Content-Type = %q, want %q", ct, "application/json")
|
||||
}
|
||||
|
||||
var body map[string]string
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if body["error"] != "not found" {
|
||||
t.Fatalf("error = %q, want %q", body["error"], "not found")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("default behavior without custom handler", func(t *testing.T) {
|
||||
r := server.NewRouter()
|
||||
r.HandleFunc("GET /exists", func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/nope", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
// Default ServeMux returns 404.
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Fatalf("got status %d, want %d", w.Code, http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user