Compare commits
8 Commits
f2a4a4fccc
...
d260abc393
| Author | SHA1 | Date | |
|---|---|---|---|
| d260abc393 | |||
| 5cfd1a7400 | |||
| f9a05f5c57 | |||
| a90c4cd7fa | |||
| 8d322123a4 | |||
| 2ca930236d | |||
| 505c7b8c4f | |||
| 6b1941fce7 |
93
README.md
93
README.md
@@ -1,2 +1,95 @@
|
|||||||
# httpx
|
# httpx
|
||||||
|
|
||||||
|
HTTP client for Go microservices. Retry, load balancing, circuit breaking, all as `http.RoundTripper` middleware. stdlib only, zero external deps.
|
||||||
|
|
||||||
|
```
|
||||||
|
go get git.codelab.vc/pkg/httpx
|
||||||
|
```
|
||||||
|
|
||||||
|
## Quick start
|
||||||
|
|
||||||
|
```go
|
||||||
|
client := httpx.New(
|
||||||
|
httpx.WithBaseURL("https://api.example.com"),
|
||||||
|
httpx.WithTimeout(10*time.Second),
|
||||||
|
httpx.WithRetry(retry.WithMaxAttempts(3)),
|
||||||
|
httpx.WithMiddleware(
|
||||||
|
middleware.UserAgent("my-service/1.0"),
|
||||||
|
middleware.BearerAuth(func(ctx context.Context) (string, error) {
|
||||||
|
return os.Getenv("API_TOKEN"), nil
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
defer client.Close()
|
||||||
|
|
||||||
|
resp, err := client.Get(ctx, "/users/123")
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var user User
|
||||||
|
resp.JSON(&user)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Packages
|
||||||
|
|
||||||
|
Everything is a `func(http.RoundTripper) http.RoundTripper`. Use them with `httpx.Client` or plug into a plain `http.Client`.
|
||||||
|
|
||||||
|
| Package | What it does |
|
||||||
|
|---------|-------------|
|
||||||
|
| `retry` | Exponential/constant backoff, Retry-After support. Idempotent methods only by default. |
|
||||||
|
| `balancer` | Round robin, failover, weighted random. Optional background health checks. |
|
||||||
|
| `circuitbreaker` | Per-host state machine (closed/open/half-open). Stops hammering dead endpoints. |
|
||||||
|
| `middleware` | Logging (slog), default headers, bearer/basic auth, panic recovery. |
|
||||||
|
|
||||||
|
The client assembles them in this order:
|
||||||
|
|
||||||
|
```
|
||||||
|
Request → Logging → Your Middleware → Retry → Circuit Breaker → Balancer → Transport
|
||||||
|
```
|
||||||
|
|
||||||
|
Retry wraps the circuit breaker and balancer, so each attempt can pick a different endpoint.
|
||||||
|
|
||||||
|
## Multi-DC setup
|
||||||
|
|
||||||
|
```go
|
||||||
|
client := httpx.New(
|
||||||
|
httpx.WithEndpoints(
|
||||||
|
balancer.Endpoint{URL: "https://dc1.api.internal", Weight: 3},
|
||||||
|
balancer.Endpoint{URL: "https://dc2.api.internal", Weight: 1},
|
||||||
|
),
|
||||||
|
httpx.WithBalancer(balancer.WithStrategy(balancer.WeightedRandom())),
|
||||||
|
httpx.WithRetry(retry.WithMaxAttempts(4)),
|
||||||
|
httpx.WithCircuitBreaker(circuitbreaker.WithFailureThreshold(5)),
|
||||||
|
httpx.WithLogger(slog.Default()),
|
||||||
|
)
|
||||||
|
defer client.Close()
|
||||||
|
```
|
||||||
|
|
||||||
|
## Standalone usage
|
||||||
|
|
||||||
|
Each component works with any `http.Client`, no need for the full wrapper:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Just retry, nothing else
|
||||||
|
transport := retry.Transport(retry.WithMaxAttempts(3))
|
||||||
|
httpClient := &http.Client{
|
||||||
|
Transport: transport(http.DefaultTransport),
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Chain a few middlewares together
|
||||||
|
chain := middleware.Chain(
|
||||||
|
middleware.Logging(slog.Default()),
|
||||||
|
middleware.UserAgent("my-service/1.0"),
|
||||||
|
retry.Transport(retry.WithMaxAttempts(2)),
|
||||||
|
)
|
||||||
|
httpClient := &http.Client{
|
||||||
|
Transport: chain(http.DefaultTransport),
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Requirements
|
||||||
|
|
||||||
|
Go 1.24+, stdlib only.
|
||||||
|
|||||||
104
balancer/balancer.go
Normal file
104
balancer/balancer.go
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
package balancer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
|
||||||
|
"git.codelab.vc/pkg/httpx/middleware"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrNoHealthy is returned when no healthy endpoints are available.
|
||||||
|
var ErrNoHealthy = errors.New("httpx: no healthy endpoints available")
|
||||||
|
|
||||||
|
// Endpoint represents a backend server that can handle requests.
|
||||||
|
type Endpoint struct {
|
||||||
|
URL string
|
||||||
|
Weight int
|
||||||
|
Meta map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Strategy selects an endpoint from the list of healthy endpoints.
|
||||||
|
type Strategy interface {
|
||||||
|
Next(healthy []Endpoint) (Endpoint, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Closer can be used to shut down resources associated with a balancer
|
||||||
|
// transport (e.g. background health checker goroutines).
|
||||||
|
type Closer struct {
|
||||||
|
healthChecker *HealthChecker
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close stops background goroutines. Safe to call multiple times.
|
||||||
|
func (c *Closer) Close() {
|
||||||
|
if c.healthChecker != nil {
|
||||||
|
c.healthChecker.Stop()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Transport returns a middleware that load-balances requests across the
|
||||||
|
// provided endpoints using the configured strategy.
|
||||||
|
//
|
||||||
|
// For each request the middleware picks an endpoint via the strategy,
|
||||||
|
// replaces the request URL scheme and host with the endpoint's URL,
|
||||||
|
// and forwards the request to the underlying RoundTripper.
|
||||||
|
//
|
||||||
|
// If active health checking is enabled (WithHealthCheck), a background
|
||||||
|
// goroutine periodically probes endpoints. Otherwise all endpoints are
|
||||||
|
// assumed healthy.
|
||||||
|
func Transport(endpoints []Endpoint, opts ...Option) (middleware.Middleware, *Closer) {
|
||||||
|
o := &options{
|
||||||
|
strategy: RoundRobin(),
|
||||||
|
}
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(o)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pre-parse endpoint URLs once at construction time.
|
||||||
|
parsed := make(map[string]*url.URL, len(endpoints))
|
||||||
|
for _, ep := range endpoints {
|
||||||
|
u, err := url.Parse(ep.URL)
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Sprintf("balancer: invalid endpoint URL %q: %v", ep.URL, err))
|
||||||
|
}
|
||||||
|
parsed[ep.URL] = u
|
||||||
|
}
|
||||||
|
|
||||||
|
if o.healthChecker != nil {
|
||||||
|
o.healthChecker.Start(endpoints)
|
||||||
|
}
|
||||||
|
|
||||||
|
closer := &Closer{healthChecker: o.healthChecker}
|
||||||
|
|
||||||
|
return func(next http.RoundTripper) http.RoundTripper {
|
||||||
|
return middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||||
|
healthy := endpoints
|
||||||
|
if o.healthChecker != nil {
|
||||||
|
healthy = o.healthChecker.Healthy(endpoints)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(healthy) == 0 {
|
||||||
|
return nil, ErrNoHealthy
|
||||||
|
}
|
||||||
|
|
||||||
|
ep, err := o.strategy.Next(healthy)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
epURL := parsed[ep.URL]
|
||||||
|
|
||||||
|
// Shallow-copy request and URL to avoid mutating the original,
|
||||||
|
// without the expense of req.Clone's deep header copy.
|
||||||
|
r := *req
|
||||||
|
u := *req.URL
|
||||||
|
r.URL = &u
|
||||||
|
r.URL.Scheme = epURL.Scheme
|
||||||
|
r.URL.Host = epURL.Host
|
||||||
|
r.Host = epURL.Host
|
||||||
|
|
||||||
|
return next.RoundTrip(&r)
|
||||||
|
})
|
||||||
|
}, closer
|
||||||
|
}
|
||||||
214
balancer/balancer_test.go
Normal file
214
balancer/balancer_test.go
Normal file
@@ -0,0 +1,214 @@
|
|||||||
|
package balancer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"math"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.codelab.vc/pkg/httpx/middleware"
|
||||||
|
)
|
||||||
|
|
||||||
|
func mockTransport(fn func(*http.Request) (*http.Response, error)) http.RoundTripper {
|
||||||
|
return middleware.RoundTripperFunc(fn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func okResponse() *http.Response {
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Body: io.NopCloser(strings.NewReader("")),
|
||||||
|
Header: make(http.Header),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTransport_PicksEndpointAndReplacesURL(t *testing.T) {
|
||||||
|
endpoints := []Endpoint{
|
||||||
|
{URL: "https://backend1.example.com"},
|
||||||
|
}
|
||||||
|
|
||||||
|
var captured *http.Request
|
||||||
|
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||||
|
captured = req
|
||||||
|
return okResponse(), nil
|
||||||
|
})
|
||||||
|
|
||||||
|
mw, _ := Transport(endpoints)
|
||||||
|
rt := mw(base)
|
||||||
|
|
||||||
|
req, err := http.NewRequest(http.MethodGet, "https://original.example.com/api/v1/users", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := rt.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if captured == nil {
|
||||||
|
t.Fatal("base transport was not called")
|
||||||
|
}
|
||||||
|
if captured.URL.Scheme != "https" {
|
||||||
|
t.Errorf("scheme = %q, want %q", captured.URL.Scheme, "https")
|
||||||
|
}
|
||||||
|
if captured.URL.Host != "backend1.example.com" {
|
||||||
|
t.Errorf("host = %q, want %q", captured.URL.Host, "backend1.example.com")
|
||||||
|
}
|
||||||
|
if captured.URL.Path != "/api/v1/users" {
|
||||||
|
t.Errorf("path = %q, want %q", captured.URL.Path, "/api/v1/users")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTransport_ErrNoHealthyWhenNoEndpoints(t *testing.T) {
|
||||||
|
var endpoints []Endpoint
|
||||||
|
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||||
|
t.Fatal("base transport should not be called")
|
||||||
|
return nil, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
mw, _ := Transport(endpoints)
|
||||||
|
rt := mw(base)
|
||||||
|
|
||||||
|
req, err := http.NewRequest(http.MethodGet, "https://example.com/test", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = rt.RoundTrip(req)
|
||||||
|
if err != ErrNoHealthy {
|
||||||
|
t.Fatalf("err = %v, want %v", err, ErrNoHealthy)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoundRobin_DistributesEvenly(t *testing.T) {
|
||||||
|
endpoints := []Endpoint{
|
||||||
|
{URL: "https://a.example.com"},
|
||||||
|
{URL: "https://b.example.com"},
|
||||||
|
{URL: "https://c.example.com"},
|
||||||
|
}
|
||||||
|
|
||||||
|
rr := RoundRobin()
|
||||||
|
counts := make(map[string]int)
|
||||||
|
|
||||||
|
const iterations = 300
|
||||||
|
for i := 0; i < iterations; i++ {
|
||||||
|
ep, err := rr.Next(endpoints)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("iteration %d: unexpected error: %v", i, err)
|
||||||
|
}
|
||||||
|
counts[ep.URL]++
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := iterations / len(endpoints)
|
||||||
|
for _, ep := range endpoints {
|
||||||
|
got := counts[ep.URL]
|
||||||
|
if got != expected {
|
||||||
|
t.Errorf("endpoint %s: got %d calls, want %d", ep.URL, got, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoundRobin_ErrNoHealthy(t *testing.T) {
|
||||||
|
rr := RoundRobin()
|
||||||
|
_, err := rr.Next(nil)
|
||||||
|
if err != ErrNoHealthy {
|
||||||
|
t.Fatalf("err = %v, want %v", err, ErrNoHealthy)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFailover_AlwaysPicksFirst(t *testing.T) {
|
||||||
|
endpoints := []Endpoint{
|
||||||
|
{URL: "https://primary.example.com"},
|
||||||
|
{URL: "https://secondary.example.com"},
|
||||||
|
{URL: "https://tertiary.example.com"},
|
||||||
|
}
|
||||||
|
|
||||||
|
fo := Failover()
|
||||||
|
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
ep, err := fo.Next(endpoints)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("iteration %d: unexpected error: %v", i, err)
|
||||||
|
}
|
||||||
|
if ep.URL != "https://primary.example.com" {
|
||||||
|
t.Errorf("iteration %d: got %q, want %q", i, ep.URL, "https://primary.example.com")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFailover_ErrNoHealthy(t *testing.T) {
|
||||||
|
fo := Failover()
|
||||||
|
_, err := fo.Next(nil)
|
||||||
|
if err != ErrNoHealthy {
|
||||||
|
t.Fatalf("err = %v, want %v", err, ErrNoHealthy)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWeightedRandom_RespectsWeights(t *testing.T) {
|
||||||
|
endpoints := []Endpoint{
|
||||||
|
{URL: "https://heavy.example.com", Weight: 80},
|
||||||
|
{URL: "https://light.example.com", Weight: 20},
|
||||||
|
}
|
||||||
|
|
||||||
|
wr := WeightedRandom()
|
||||||
|
counts := make(map[string]int)
|
||||||
|
|
||||||
|
const iterations = 10000
|
||||||
|
for i := 0; i < iterations; i++ {
|
||||||
|
ep, err := wr.Next(endpoints)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("iteration %d: unexpected error: %v", i, err)
|
||||||
|
}
|
||||||
|
counts[ep.URL]++
|
||||||
|
}
|
||||||
|
|
||||||
|
totalWeight := 0
|
||||||
|
for _, ep := range endpoints {
|
||||||
|
totalWeight += ep.Weight
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ep := range endpoints {
|
||||||
|
got := float64(counts[ep.URL]) / float64(iterations)
|
||||||
|
want := float64(ep.Weight) / float64(totalWeight)
|
||||||
|
if math.Abs(got-want) > 0.05 {
|
||||||
|
t.Errorf("endpoint %s: got ratio %.3f, want ~%.3f (tolerance 0.05)", ep.URL, got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWeightedRandom_DefaultWeightForZero(t *testing.T) {
|
||||||
|
endpoints := []Endpoint{
|
||||||
|
{URL: "https://a.example.com", Weight: 0},
|
||||||
|
{URL: "https://b.example.com", Weight: 0},
|
||||||
|
}
|
||||||
|
|
||||||
|
wr := WeightedRandom()
|
||||||
|
counts := make(map[string]int)
|
||||||
|
|
||||||
|
const iterations = 1000
|
||||||
|
for i := 0; i < iterations; i++ {
|
||||||
|
ep, err := wr.Next(endpoints)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("iteration %d: unexpected error: %v", i, err)
|
||||||
|
}
|
||||||
|
counts[ep.URL]++
|
||||||
|
}
|
||||||
|
|
||||||
|
// With equal default weights, distribution should be roughly even.
|
||||||
|
for _, ep := range endpoints {
|
||||||
|
got := float64(counts[ep.URL]) / float64(iterations)
|
||||||
|
if math.Abs(got-0.5) > 0.1 {
|
||||||
|
t.Errorf("endpoint %s: got ratio %.3f, want ~0.5 (tolerance 0.1)", ep.URL, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWeightedRandom_ErrNoHealthy(t *testing.T) {
|
||||||
|
wr := WeightedRandom()
|
||||||
|
_, err := wr.Next(nil)
|
||||||
|
if err != ErrNoHealthy {
|
||||||
|
t.Fatalf("err = %v, want %v", err, ErrNoHealthy)
|
||||||
|
}
|
||||||
|
}
|
||||||
17
balancer/failover.go
Normal file
17
balancer/failover.go
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
package balancer
|
||||||
|
|
||||||
|
type failover struct{}
|
||||||
|
|
||||||
|
// Failover returns a strategy that always picks the first healthy endpoint.
|
||||||
|
// If the primary endpoint is unhealthy, it falls back to the next available
|
||||||
|
// healthy endpoint in order.
|
||||||
|
func Failover() Strategy {
|
||||||
|
return &failover{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *failover) Next(healthy []Endpoint) (Endpoint, error) {
|
||||||
|
if len(healthy) == 0 {
|
||||||
|
return Endpoint{}, ErrNoHealthy
|
||||||
|
}
|
||||||
|
return healthy[0], nil
|
||||||
|
}
|
||||||
174
balancer/health.go
Normal file
174
balancer/health.go
Normal file
@@ -0,0 +1,174 @@
|
|||||||
|
package balancer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultHealthInterval = 10 * time.Second
|
||||||
|
defaultHealthPath = "/health"
|
||||||
|
defaultHealthTimeout = 5 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
// HealthOption configures the HealthChecker.
|
||||||
|
type HealthOption func(*HealthChecker)
|
||||||
|
|
||||||
|
// WithHealthInterval sets the interval between health check probes.
|
||||||
|
// Default is 10 seconds.
|
||||||
|
func WithHealthInterval(d time.Duration) HealthOption {
|
||||||
|
return func(h *HealthChecker) {
|
||||||
|
h.interval = d
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithHealthPath sets the HTTP path to probe for health checks.
|
||||||
|
// Default is "/health".
|
||||||
|
func WithHealthPath(path string) HealthOption {
|
||||||
|
return func(h *HealthChecker) {
|
||||||
|
h.path = path
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithHealthTimeout sets the timeout for each health check request.
|
||||||
|
// Default is 5 seconds.
|
||||||
|
func WithHealthTimeout(d time.Duration) HealthOption {
|
||||||
|
return func(h *HealthChecker) {
|
||||||
|
h.timeout = d
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// HealthChecker periodically probes endpoints to determine their health status.
|
||||||
|
type HealthChecker struct {
|
||||||
|
interval time.Duration
|
||||||
|
path string
|
||||||
|
timeout time.Duration
|
||||||
|
client *http.Client
|
||||||
|
|
||||||
|
mu sync.RWMutex
|
||||||
|
status map[string]bool
|
||||||
|
cancel context.CancelFunc
|
||||||
|
stopped chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHealthChecker(opts ...HealthOption) *HealthChecker {
|
||||||
|
h := &HealthChecker{
|
||||||
|
interval: defaultHealthInterval,
|
||||||
|
path: defaultHealthPath,
|
||||||
|
timeout: defaultHealthTimeout,
|
||||||
|
status: make(map[string]bool),
|
||||||
|
}
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(h)
|
||||||
|
}
|
||||||
|
h.client = &http.Client{
|
||||||
|
Timeout: h.timeout,
|
||||||
|
}
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start begins the background health checking loop for the given endpoints.
|
||||||
|
// An initial probe is run synchronously so that unhealthy endpoints are
|
||||||
|
// detected before the first request.
|
||||||
|
func (h *HealthChecker) Start(endpoints []Endpoint) {
|
||||||
|
// Mark all healthy as a safe default, then immediately probe.
|
||||||
|
h.mu.Lock()
|
||||||
|
for _, ep := range endpoints {
|
||||||
|
h.status[ep.URL] = true
|
||||||
|
}
|
||||||
|
h.mu.Unlock()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
h.cancel = cancel
|
||||||
|
h.stopped = make(chan struct{})
|
||||||
|
|
||||||
|
// Run initial probe synchronously so callers don't hit stale state.
|
||||||
|
h.probe(ctx, endpoints)
|
||||||
|
|
||||||
|
go h.loop(ctx, endpoints)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop terminates the background health checking goroutine and waits for
|
||||||
|
// it to finish.
|
||||||
|
func (h *HealthChecker) Stop() {
|
||||||
|
if h.cancel != nil {
|
||||||
|
h.cancel()
|
||||||
|
<-h.stopped
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsHealthy reports whether the given endpoint is currently healthy.
|
||||||
|
func (h *HealthChecker) IsHealthy(ep Endpoint) bool {
|
||||||
|
h.mu.RLock()
|
||||||
|
defer h.mu.RUnlock()
|
||||||
|
|
||||||
|
healthy, ok := h.status[ep.URL]
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return healthy
|
||||||
|
}
|
||||||
|
|
||||||
|
// Healthy returns the subset of endpoints that are currently healthy.
|
||||||
|
func (h *HealthChecker) Healthy(endpoints []Endpoint) []Endpoint {
|
||||||
|
h.mu.RLock()
|
||||||
|
defer h.mu.RUnlock()
|
||||||
|
|
||||||
|
result := make([]Endpoint, 0, len(endpoints))
|
||||||
|
for _, ep := range endpoints {
|
||||||
|
if h.status[ep.URL] {
|
||||||
|
result = append(result, ep)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HealthChecker) loop(ctx context.Context, endpoints []Endpoint) {
|
||||||
|
defer close(h.stopped)
|
||||||
|
|
||||||
|
ticker := time.NewTicker(h.interval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
h.probe(ctx, endpoints)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HealthChecker) probe(ctx context.Context, endpoints []Endpoint) {
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(len(endpoints))
|
||||||
|
for _, ep := range endpoints {
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
healthy := h.check(ctx, ep)
|
||||||
|
h.mu.Lock()
|
||||||
|
h.status[ep.URL] = healthy
|
||||||
|
h.mu.Unlock()
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HealthChecker) check(ctx context.Context, ep Endpoint) bool {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, ep.URL+h.path, nil)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := h.client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
io.Copy(io.Discard, resp.Body)
|
||||||
|
resp.Body.Close()
|
||||||
|
|
||||||
|
return resp.StatusCode >= 200 && resp.StatusCode < 300
|
||||||
|
}
|
||||||
25
balancer/options.go
Normal file
25
balancer/options.go
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
package balancer
|
||||||
|
|
||||||
|
// options holds configuration for the load balancer transport.
|
||||||
|
type options struct {
|
||||||
|
strategy Strategy // default RoundRobin
|
||||||
|
healthChecker *HealthChecker // optional
|
||||||
|
}
|
||||||
|
|
||||||
|
// Option configures the load balancer transport.
|
||||||
|
type Option func(*options)
|
||||||
|
|
||||||
|
// WithStrategy sets the endpoint selection strategy.
|
||||||
|
// If not specified, RoundRobin is used.
|
||||||
|
func WithStrategy(s Strategy) Option {
|
||||||
|
return func(o *options) {
|
||||||
|
o.strategy = s
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithHealthCheck enables active health checking of endpoints.
|
||||||
|
func WithHealthCheck(opts ...HealthOption) Option {
|
||||||
|
return func(o *options) {
|
||||||
|
o.healthChecker = newHealthChecker(opts...)
|
||||||
|
}
|
||||||
|
}
|
||||||
21
balancer/roundrobin.go
Normal file
21
balancer/roundrobin.go
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
package balancer
|
||||||
|
|
||||||
|
import "sync/atomic"
|
||||||
|
|
||||||
|
type roundRobin struct {
|
||||||
|
counter atomic.Uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
// RoundRobin returns a strategy that cycles through healthy endpoints
|
||||||
|
// sequentially using an atomic counter.
|
||||||
|
func RoundRobin() Strategy {
|
||||||
|
return &roundRobin{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *roundRobin) Next(healthy []Endpoint) (Endpoint, error) {
|
||||||
|
if len(healthy) == 0 {
|
||||||
|
return Endpoint{}, ErrNoHealthy
|
||||||
|
}
|
||||||
|
idx := r.counter.Add(1) - 1
|
||||||
|
return healthy[idx%uint64(len(healthy))], nil
|
||||||
|
}
|
||||||
42
balancer/weighted.go
Normal file
42
balancer/weighted.go
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
package balancer
|
||||||
|
|
||||||
|
import "math/rand/v2"
|
||||||
|
|
||||||
|
type weightedRandom struct{}
|
||||||
|
|
||||||
|
// WeightedRandom returns a strategy that selects endpoints randomly,
|
||||||
|
// weighted by each endpoint's Weight field. Endpoints with Weight <= 0
|
||||||
|
// are treated as having a weight of 1.
|
||||||
|
func WeightedRandom() Strategy {
|
||||||
|
return &weightedRandom{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *weightedRandom) Next(healthy []Endpoint) (Endpoint, error) {
|
||||||
|
if len(healthy) == 0 {
|
||||||
|
return Endpoint{}, ErrNoHealthy
|
||||||
|
}
|
||||||
|
|
||||||
|
totalWeight := 0
|
||||||
|
for _, ep := range healthy {
|
||||||
|
weight := ep.Weight
|
||||||
|
if weight <= 0 {
|
||||||
|
weight = 1
|
||||||
|
}
|
||||||
|
totalWeight += weight
|
||||||
|
}
|
||||||
|
|
||||||
|
r := rand.IntN(totalWeight)
|
||||||
|
for _, ep := range healthy {
|
||||||
|
weight := ep.Weight
|
||||||
|
if weight <= 0 {
|
||||||
|
weight = 1
|
||||||
|
}
|
||||||
|
r -= weight
|
||||||
|
if r < 0 {
|
||||||
|
return ep, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should never reach here, but return last endpoint as a safeguard.
|
||||||
|
return healthy[len(healthy)-1], nil
|
||||||
|
}
|
||||||
176
circuitbreaker/breaker.go
Normal file
176
circuitbreaker/breaker.go
Normal file
@@ -0,0 +1,176 @@
|
|||||||
|
package circuitbreaker
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.codelab.vc/pkg/httpx/middleware"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrCircuitOpen is returned by Allow when the breaker is in the Open state.
|
||||||
|
var ErrCircuitOpen = errors.New("httpx: circuit breaker is open")
|
||||||
|
|
||||||
|
// State represents the current state of a circuit breaker.
|
||||||
|
type State int
|
||||||
|
|
||||||
|
const (
|
||||||
|
StateClosed State = iota // normal operation
|
||||||
|
StateOpen // failing, reject requests
|
||||||
|
StateHalfOpen // testing recovery
|
||||||
|
)
|
||||||
|
|
||||||
|
// String returns a human-readable name for the state.
|
||||||
|
func (s State) String() string {
|
||||||
|
switch s {
|
||||||
|
case StateClosed:
|
||||||
|
return "closed"
|
||||||
|
case StateOpen:
|
||||||
|
return "open"
|
||||||
|
case StateHalfOpen:
|
||||||
|
return "half-open"
|
||||||
|
default:
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Breaker implements a per-endpoint circuit breaker state machine.
|
||||||
|
//
|
||||||
|
// State transitions:
|
||||||
|
//
|
||||||
|
// Closed → Open: after failureThreshold consecutive failures
|
||||||
|
// Open → HalfOpen: after openDuration passes
|
||||||
|
// HalfOpen → Closed: on success
|
||||||
|
// HalfOpen → Open: on failure (timer resets)
|
||||||
|
type Breaker struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
opts options
|
||||||
|
|
||||||
|
state State
|
||||||
|
failures int // consecutive failure count (Closed state)
|
||||||
|
openedAt time.Time
|
||||||
|
halfOpenCur int // current in-flight half-open requests
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewBreaker creates a Breaker with the given options.
|
||||||
|
func NewBreaker(opts ...Option) *Breaker {
|
||||||
|
o := defaults()
|
||||||
|
for _, fn := range opts {
|
||||||
|
fn(&o)
|
||||||
|
}
|
||||||
|
return &Breaker{opts: o}
|
||||||
|
}
|
||||||
|
|
||||||
|
// State returns the current state of the breaker.
|
||||||
|
func (b *Breaker) State() State {
|
||||||
|
b.mu.Lock()
|
||||||
|
defer b.mu.Unlock()
|
||||||
|
return b.stateLocked()
|
||||||
|
}
|
||||||
|
|
||||||
|
// stateLocked returns the effective state, promoting Open → HalfOpen when the
|
||||||
|
// open duration has elapsed. Caller must hold b.mu.
|
||||||
|
func (b *Breaker) stateLocked() State {
|
||||||
|
if b.state == StateOpen && time.Since(b.openedAt) >= b.opts.openDuration {
|
||||||
|
b.state = StateHalfOpen
|
||||||
|
b.halfOpenCur = 0
|
||||||
|
}
|
||||||
|
return b.state
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allow checks whether a request is permitted. If allowed it returns a done
|
||||||
|
// callback that the caller MUST invoke with the result of the request. If the
|
||||||
|
// breaker is open, it returns ErrCircuitOpen.
|
||||||
|
func (b *Breaker) Allow() (done func(success bool), err error) {
|
||||||
|
b.mu.Lock()
|
||||||
|
defer b.mu.Unlock()
|
||||||
|
|
||||||
|
switch b.stateLocked() {
|
||||||
|
case StateClosed:
|
||||||
|
// always allow
|
||||||
|
case StateOpen:
|
||||||
|
return nil, ErrCircuitOpen
|
||||||
|
case StateHalfOpen:
|
||||||
|
if b.halfOpenCur >= b.opts.halfOpenMax {
|
||||||
|
return nil, ErrCircuitOpen
|
||||||
|
}
|
||||||
|
b.halfOpenCur++
|
||||||
|
}
|
||||||
|
|
||||||
|
return b.doneFunc(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// doneFunc returns the callback for a single in-flight request. Caller must
|
||||||
|
// hold b.mu when calling doneFunc, but the returned function acquires the lock
|
||||||
|
// itself.
|
||||||
|
func (b *Breaker) doneFunc() func(success bool) {
|
||||||
|
var once sync.Once
|
||||||
|
return func(success bool) {
|
||||||
|
once.Do(func() {
|
||||||
|
b.mu.Lock()
|
||||||
|
defer b.mu.Unlock()
|
||||||
|
b.record(success)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// record processes the outcome of a single request. Caller must hold b.mu.
|
||||||
|
func (b *Breaker) record(success bool) {
|
||||||
|
switch b.state {
|
||||||
|
case StateClosed:
|
||||||
|
if success {
|
||||||
|
b.failures = 0
|
||||||
|
return
|
||||||
|
}
|
||||||
|
b.failures++
|
||||||
|
if b.failures >= b.opts.failureThreshold {
|
||||||
|
b.tripLocked()
|
||||||
|
}
|
||||||
|
|
||||||
|
case StateHalfOpen:
|
||||||
|
b.halfOpenCur--
|
||||||
|
if success {
|
||||||
|
b.state = StateClosed
|
||||||
|
b.failures = 0
|
||||||
|
} else {
|
||||||
|
b.tripLocked()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// tripLocked transitions to the Open state and records the timestamp.
|
||||||
|
func (b *Breaker) tripLocked() {
|
||||||
|
b.state = StateOpen
|
||||||
|
b.openedAt = time.Now()
|
||||||
|
b.halfOpenCur = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Transport returns a middleware that applies per-host circuit breaking. It
|
||||||
|
// maintains an internal map of host → *Breaker so each target host is tracked
|
||||||
|
// independently.
|
||||||
|
func Transport(opts ...Option) middleware.Middleware {
|
||||||
|
var hosts sync.Map // map[string]*Breaker
|
||||||
|
|
||||||
|
return func(next http.RoundTripper) http.RoundTripper {
|
||||||
|
return middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||||
|
host := req.URL.Host
|
||||||
|
|
||||||
|
val, ok := hosts.Load(host)
|
||||||
|
if !ok {
|
||||||
|
val, _ = hosts.LoadOrStore(host, NewBreaker(opts...))
|
||||||
|
}
|
||||||
|
cb := val.(*Breaker)
|
||||||
|
|
||||||
|
done, err := cb.Allow()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, rtErr := next.RoundTrip(req)
|
||||||
|
done(rtErr == nil && resp != nil && resp.StatusCode < 500)
|
||||||
|
|
||||||
|
return resp, rtErr
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
249
circuitbreaker/breaker_test.go
Normal file
249
circuitbreaker/breaker_test.go
Normal file
@@ -0,0 +1,249 @@
|
|||||||
|
package circuitbreaker
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.codelab.vc/pkg/httpx/middleware"
|
||||||
|
)
|
||||||
|
|
||||||
|
func mockTransport(fn func(*http.Request) (*http.Response, error)) http.RoundTripper {
|
||||||
|
return middleware.RoundTripperFunc(fn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func okResponse() *http.Response {
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Body: io.NopCloser(strings.NewReader("")),
|
||||||
|
Header: make(http.Header),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func errResponse(code int) *http.Response {
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: code,
|
||||||
|
Body: io.NopCloser(strings.NewReader("")),
|
||||||
|
Header: make(http.Header),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBreaker_StartsInClosedState(t *testing.T) {
|
||||||
|
b := NewBreaker()
|
||||||
|
if s := b.State(); s != StateClosed {
|
||||||
|
t.Fatalf("state = %v, want %v", s, StateClosed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBreaker_TransitionsToOpenAfterThreshold(t *testing.T) {
|
||||||
|
const threshold = 3
|
||||||
|
b := NewBreaker(
|
||||||
|
WithFailureThreshold(threshold),
|
||||||
|
WithOpenDuration(time.Hour), // long duration so it stays open
|
||||||
|
)
|
||||||
|
|
||||||
|
for i := 0; i < threshold; i++ {
|
||||||
|
done, err := b.Allow()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("iteration %d: Allow returned error: %v", i, err)
|
||||||
|
}
|
||||||
|
done(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s := b.State(); s != StateOpen {
|
||||||
|
t.Fatalf("state = %v, want %v", s, StateOpen)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBreaker_OpenRejectsRequests(t *testing.T) {
|
||||||
|
b := NewBreaker(
|
||||||
|
WithFailureThreshold(1),
|
||||||
|
WithOpenDuration(time.Hour),
|
||||||
|
)
|
||||||
|
|
||||||
|
// Trip the breaker.
|
||||||
|
done, err := b.Allow()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Allow returned error: %v", err)
|
||||||
|
}
|
||||||
|
done(false)
|
||||||
|
|
||||||
|
// Subsequent requests should be rejected.
|
||||||
|
_, err = b.Allow()
|
||||||
|
if !errors.Is(err, ErrCircuitOpen) {
|
||||||
|
t.Fatalf("err = %v, want %v", err, ErrCircuitOpen)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBreaker_TransitionsToHalfOpenAfterDuration(t *testing.T) {
|
||||||
|
const openDuration = 50 * time.Millisecond
|
||||||
|
b := NewBreaker(
|
||||||
|
WithFailureThreshold(1),
|
||||||
|
WithOpenDuration(openDuration),
|
||||||
|
)
|
||||||
|
|
||||||
|
// Trip the breaker.
|
||||||
|
done, err := b.Allow()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
done(false)
|
||||||
|
|
||||||
|
if s := b.State(); s != StateOpen {
|
||||||
|
t.Fatalf("state = %v, want %v", s, StateOpen)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for the open duration to elapse.
|
||||||
|
time.Sleep(openDuration + 10*time.Millisecond)
|
||||||
|
|
||||||
|
if s := b.State(); s != StateHalfOpen {
|
||||||
|
t.Fatalf("state = %v, want %v", s, StateHalfOpen)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBreaker_HalfOpenToClosedOnSuccess(t *testing.T) {
|
||||||
|
const openDuration = 50 * time.Millisecond
|
||||||
|
b := NewBreaker(
|
||||||
|
WithFailureThreshold(1),
|
||||||
|
WithOpenDuration(openDuration),
|
||||||
|
)
|
||||||
|
|
||||||
|
// Trip the breaker.
|
||||||
|
done, err := b.Allow()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
done(false)
|
||||||
|
|
||||||
|
// Wait for half-open.
|
||||||
|
time.Sleep(openDuration + 10*time.Millisecond)
|
||||||
|
|
||||||
|
// A successful request in half-open should close the breaker.
|
||||||
|
done, err = b.Allow()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Allow in half-open returned error: %v", err)
|
||||||
|
}
|
||||||
|
done(true)
|
||||||
|
|
||||||
|
if s := b.State(); s != StateClosed {
|
||||||
|
t.Fatalf("state = %v, want %v", s, StateClosed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBreaker_HalfOpenToOpenOnFailure(t *testing.T) {
|
||||||
|
const openDuration = 50 * time.Millisecond
|
||||||
|
b := NewBreaker(
|
||||||
|
WithFailureThreshold(1),
|
||||||
|
WithOpenDuration(openDuration),
|
||||||
|
)
|
||||||
|
|
||||||
|
// Trip the breaker.
|
||||||
|
done, err := b.Allow()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
done(false)
|
||||||
|
|
||||||
|
// Wait for half-open.
|
||||||
|
time.Sleep(openDuration + 10*time.Millisecond)
|
||||||
|
|
||||||
|
// A failed request in half-open should re-open the breaker.
|
||||||
|
done, err = b.Allow()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Allow in half-open returned error: %v", err)
|
||||||
|
}
|
||||||
|
done(false)
|
||||||
|
|
||||||
|
if s := b.State(); s != StateOpen {
|
||||||
|
t.Fatalf("state = %v, want %v", s, StateOpen)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTransport_PerHostBreakers(t *testing.T) {
|
||||||
|
const threshold = 2
|
||||||
|
|
||||||
|
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||||
|
if req.URL.Host == "failing.example.com" {
|
||||||
|
return errResponse(http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
return okResponse(), nil
|
||||||
|
})
|
||||||
|
|
||||||
|
rt := Transport(
|
||||||
|
WithFailureThreshold(threshold),
|
||||||
|
WithOpenDuration(time.Hour),
|
||||||
|
)(base)
|
||||||
|
|
||||||
|
t.Run("failing host trips breaker", func(t *testing.T) {
|
||||||
|
for i := 0; i < threshold; i++ {
|
||||||
|
req, err := http.NewRequest(http.MethodGet, "https://failing.example.com/test", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
resp, err := rt.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("iteration %d: unexpected error: %v", i, err)
|
||||||
|
}
|
||||||
|
resp.Body.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Next request to failing host should be rejected.
|
||||||
|
req, err := http.NewRequest(http.MethodGet, "https://failing.example.com/test", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
_, err = rt.RoundTrip(req)
|
||||||
|
if !errors.Is(err, ErrCircuitOpen) {
|
||||||
|
t.Fatalf("err = %v, want %v", err, ErrCircuitOpen)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("healthy host is unaffected", func(t *testing.T) {
|
||||||
|
req, err := http.NewRequest(http.MethodGet, "https://healthy.example.com/test", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
resp, err := rt.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
resp.Body.Close()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTransport_SuccessResetsFailures(t *testing.T) {
|
||||||
|
callCount := 0
|
||||||
|
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||||
|
callCount++
|
||||||
|
// Fail on odd calls, succeed on even.
|
||||||
|
if callCount%2 == 1 {
|
||||||
|
return errResponse(http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
return okResponse(), nil
|
||||||
|
})
|
||||||
|
|
||||||
|
rt := Transport(
|
||||||
|
WithFailureThreshold(3),
|
||||||
|
WithOpenDuration(time.Hour),
|
||||||
|
)(base)
|
||||||
|
|
||||||
|
// Alternate fail/success — should never trip because successes reset the
|
||||||
|
// consecutive failure counter.
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
req, err := http.NewRequest(http.MethodGet, "https://host.example.com/test", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
resp, err := rt.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("iteration %d: unexpected error (circuit should not be open): %v", i, err)
|
||||||
|
}
|
||||||
|
resp.Body.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
50
circuitbreaker/options.go
Normal file
50
circuitbreaker/options.go
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
package circuitbreaker
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
|
type options struct {
|
||||||
|
failureThreshold int // consecutive failures to trip
|
||||||
|
openDuration time.Duration // how long to stay open before half-open
|
||||||
|
halfOpenMax int // max concurrent requests in half-open
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaults() options {
|
||||||
|
return options{
|
||||||
|
failureThreshold: 5,
|
||||||
|
openDuration: 30 * time.Second,
|
||||||
|
halfOpenMax: 1,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Option configures a Breaker.
|
||||||
|
type Option func(*options)
|
||||||
|
|
||||||
|
// WithFailureThreshold sets the number of consecutive failures required to
|
||||||
|
// trip the breaker from Closed to Open. Default is 5.
|
||||||
|
func WithFailureThreshold(n int) Option {
|
||||||
|
return func(o *options) {
|
||||||
|
if n > 0 {
|
||||||
|
o.failureThreshold = n
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithOpenDuration sets how long the breaker stays in the Open state before
|
||||||
|
// transitioning to HalfOpen. Default is 30s.
|
||||||
|
func WithOpenDuration(d time.Duration) Option {
|
||||||
|
return func(o *options) {
|
||||||
|
if d > 0 {
|
||||||
|
o.openDuration = d
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithHalfOpenMax sets the maximum number of concurrent probe requests
|
||||||
|
// allowed while the breaker is in the HalfOpen state. Default is 1.
|
||||||
|
func WithHalfOpenMax(n int) Option {
|
||||||
|
return func(o *options) {
|
||||||
|
if n > 0 {
|
||||||
|
o.halfOpenMax = n
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
186
client.go
Normal file
186
client.go
Normal file
@@ -0,0 +1,186 @@
|
|||||||
|
package httpx
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"git.codelab.vc/pkg/httpx/balancer"
|
||||||
|
"git.codelab.vc/pkg/httpx/circuitbreaker"
|
||||||
|
"git.codelab.vc/pkg/httpx/middleware"
|
||||||
|
"git.codelab.vc/pkg/httpx/retry"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Client is a high-level HTTP client that composes middleware for retry,
|
||||||
|
// circuit breaking, load balancing, logging, and more.
|
||||||
|
type Client struct {
|
||||||
|
httpClient *http.Client
|
||||||
|
baseURL string
|
||||||
|
errorMapper ErrorMapper
|
||||||
|
balancerCloser *balancer.Closer
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a new Client with the given options.
|
||||||
|
//
|
||||||
|
// The middleware chain is assembled as (outermost → innermost):
|
||||||
|
//
|
||||||
|
// Logging → User Middlewares → Retry → Circuit Breaker → Balancer → Base Transport
|
||||||
|
func New(opts ...Option) *Client {
|
||||||
|
o := &clientOptions{
|
||||||
|
transport: http.DefaultTransport,
|
||||||
|
}
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(o)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build the middleware chain from inside out.
|
||||||
|
var chain []middleware.Middleware
|
||||||
|
|
||||||
|
// Balancer (innermost, wraps base transport).
|
||||||
|
var balancerCloser *balancer.Closer
|
||||||
|
if len(o.endpoints) > 0 {
|
||||||
|
var mw middleware.Middleware
|
||||||
|
mw, balancerCloser = balancer.Transport(o.endpoints, o.balancerOpts...)
|
||||||
|
chain = append(chain, mw)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Circuit breaker wraps balancer.
|
||||||
|
if o.enableCB {
|
||||||
|
chain = append(chain, circuitbreaker.Transport(o.cbOpts...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Retry wraps circuit breaker + balancer.
|
||||||
|
if o.enableRetry {
|
||||||
|
chain = append(chain, retry.Transport(o.retryOpts...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// User middlewares.
|
||||||
|
for i := len(o.middlewares) - 1; i >= 0; i-- {
|
||||||
|
chain = append(chain, o.middlewares[i])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Logging (outermost).
|
||||||
|
if o.logger != nil {
|
||||||
|
chain = append(chain, middleware.Logging(o.logger))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assemble: chain[last] is outermost.
|
||||||
|
rt := o.transport
|
||||||
|
for _, mw := range chain {
|
||||||
|
rt = mw(rt)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Client{
|
||||||
|
httpClient: &http.Client{
|
||||||
|
Transport: rt,
|
||||||
|
Timeout: o.timeout,
|
||||||
|
},
|
||||||
|
baseURL: o.baseURL,
|
||||||
|
errorMapper: o.errorMapper,
|
||||||
|
balancerCloser: balancerCloser,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do executes an HTTP request.
|
||||||
|
func (c *Client) Do(ctx context.Context, req *http.Request) (*Response, error) {
|
||||||
|
req = req.WithContext(ctx)
|
||||||
|
if err := c.resolveURL(req); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &Error{
|
||||||
|
Op: req.Method,
|
||||||
|
URL: req.URL.String(),
|
||||||
|
Err: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
r := newResponse(resp)
|
||||||
|
|
||||||
|
if c.errorMapper != nil {
|
||||||
|
if mapErr := c.errorMapper(resp); mapErr != nil {
|
||||||
|
return r, &Error{
|
||||||
|
Op: req.Method,
|
||||||
|
URL: req.URL.String(),
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
Err: mapErr,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return r, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get performs a GET request to the given URL.
|
||||||
|
func (c *Client) Get(ctx context.Context, url string) (*Response, error) {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return c.Do(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Post performs a POST request to the given URL with the given body.
|
||||||
|
func (c *Client) Post(ctx context.Context, url string, body io.Reader) (*Response, error) {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return c.Do(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Put performs a PUT request to the given URL with the given body.
|
||||||
|
func (c *Client) Put(ctx context.Context, url string, body io.Reader) (*Response, error) {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPut, url, body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return c.Do(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete performs a DELETE request to the given URL.
|
||||||
|
func (c *Client) Delete(ctx context.Context, url string) (*Response, error) {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return c.Do(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close releases resources associated with the Client, such as background
|
||||||
|
// health checker goroutines. It is safe to call multiple times.
|
||||||
|
func (c *Client) Close() {
|
||||||
|
if c.balancerCloser != nil {
|
||||||
|
c.balancerCloser.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// HTTPClient returns the underlying *http.Client for advanced use cases.
|
||||||
|
// Mutating the returned client may bypass the configured middleware chain.
|
||||||
|
func (c *Client) HTTPClient() *http.Client {
|
||||||
|
return c.httpClient
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) resolveURL(req *http.Request) error {
|
||||||
|
if c.baseURL == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// Only resolve relative URLs (no scheme).
|
||||||
|
if req.URL.Scheme == "" && req.URL.Host == "" {
|
||||||
|
path := req.URL.String()
|
||||||
|
if !strings.HasPrefix(path, "/") {
|
||||||
|
path = "/" + path
|
||||||
|
}
|
||||||
|
base := strings.TrimRight(c.baseURL, "/")
|
||||||
|
u, err := req.URL.Parse(base + path)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("httpx: resolving URL %q with base %q: %w", path, c.baseURL, err)
|
||||||
|
}
|
||||||
|
req.URL = u
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
87
client_options.go
Normal file
87
client_options.go
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
package httpx
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log/slog"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.codelab.vc/pkg/httpx/balancer"
|
||||||
|
"git.codelab.vc/pkg/httpx/circuitbreaker"
|
||||||
|
"git.codelab.vc/pkg/httpx/middleware"
|
||||||
|
"git.codelab.vc/pkg/httpx/retry"
|
||||||
|
)
|
||||||
|
|
||||||
|
type clientOptions struct {
|
||||||
|
baseURL string
|
||||||
|
timeout time.Duration
|
||||||
|
transport http.RoundTripper
|
||||||
|
logger *slog.Logger
|
||||||
|
errorMapper ErrorMapper
|
||||||
|
middlewares []middleware.Middleware
|
||||||
|
retryOpts []retry.Option
|
||||||
|
enableRetry bool
|
||||||
|
cbOpts []circuitbreaker.Option
|
||||||
|
enableCB bool
|
||||||
|
endpoints []balancer.Endpoint
|
||||||
|
balancerOpts []balancer.Option
|
||||||
|
}
|
||||||
|
|
||||||
|
// Option configures a Client.
|
||||||
|
type Option func(*clientOptions)
|
||||||
|
|
||||||
|
// WithBaseURL sets the base URL prepended to all relative request paths.
|
||||||
|
func WithBaseURL(url string) Option {
|
||||||
|
return func(o *clientOptions) { o.baseURL = url }
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithTimeout sets the overall request timeout.
|
||||||
|
func WithTimeout(d time.Duration) Option {
|
||||||
|
return func(o *clientOptions) { o.timeout = d }
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithTransport sets the base http.RoundTripper. Defaults to http.DefaultTransport.
|
||||||
|
func WithTransport(rt http.RoundTripper) Option {
|
||||||
|
return func(o *clientOptions) { o.transport = rt }
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithLogger enables structured logging of requests and responses.
|
||||||
|
func WithLogger(l *slog.Logger) Option {
|
||||||
|
return func(o *clientOptions) { o.logger = l }
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithErrorMapper sets a function that maps HTTP responses to errors.
|
||||||
|
func WithErrorMapper(m ErrorMapper) Option {
|
||||||
|
return func(o *clientOptions) { o.errorMapper = m }
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithMiddleware appends user middlewares to the chain.
|
||||||
|
// These run between logging and retry in the middleware stack.
|
||||||
|
func WithMiddleware(mws ...middleware.Middleware) Option {
|
||||||
|
return func(o *clientOptions) { o.middlewares = append(o.middlewares, mws...) }
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithRetry enables retry with the given options.
|
||||||
|
func WithRetry(opts ...retry.Option) Option {
|
||||||
|
return func(o *clientOptions) {
|
||||||
|
o.enableRetry = true
|
||||||
|
o.retryOpts = opts
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithCircuitBreaker enables per-host circuit breaking.
|
||||||
|
func WithCircuitBreaker(opts ...circuitbreaker.Option) Option {
|
||||||
|
return func(o *clientOptions) {
|
||||||
|
o.enableCB = true
|
||||||
|
o.cbOpts = opts
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithEndpoints sets the endpoints for load balancing.
|
||||||
|
func WithEndpoints(eps ...balancer.Endpoint) Option {
|
||||||
|
return func(o *clientOptions) { o.endpoints = eps }
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithBalancer configures the load balancer strategy and options.
|
||||||
|
func WithBalancer(opts ...balancer.Option) Option {
|
||||||
|
return func(o *clientOptions) { o.balancerOpts = opts }
|
||||||
|
}
|
||||||
311
client_test.go
Normal file
311
client_test.go
Normal file
@@ -0,0 +1,311 @@
|
|||||||
|
package httpx_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log/slog"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.codelab.vc/pkg/httpx"
|
||||||
|
"git.codelab.vc/pkg/httpx/balancer"
|
||||||
|
"git.codelab.vc/pkg/httpx/middleware"
|
||||||
|
"git.codelab.vc/pkg/httpx/retry"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestClient_Get(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodGet {
|
||||||
|
t.Errorf("expected GET, got %s", r.Method)
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
fmt.Fprint(w, "hello")
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
client := httpx.New()
|
||||||
|
resp, err := client.Get(context.Background(), srv.URL+"/test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := resp.String()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("reading body: %v", err)
|
||||||
|
}
|
||||||
|
if body != "hello" {
|
||||||
|
t.Errorf("expected body %q, got %q", "hello", body)
|
||||||
|
}
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClient_Post(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
t.Errorf("expected POST, got %s", r.Method)
|
||||||
|
}
|
||||||
|
b, _ := io.ReadAll(r.Body)
|
||||||
|
if string(b) != "request-body" {
|
||||||
|
t.Errorf("expected body %q, got %q", "request-body", string(b))
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusCreated)
|
||||||
|
fmt.Fprint(w, "created")
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
client := httpx.New()
|
||||||
|
resp, err := client.Post(context.Background(), srv.URL+"/items", strings.NewReader("request-body"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusCreated {
|
||||||
|
t.Errorf("expected status 201, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
body, err := resp.String()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("reading body: %v", err)
|
||||||
|
}
|
||||||
|
if body != "created" {
|
||||||
|
t.Errorf("expected body %q, got %q", "created", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClient_BaseURL(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path != "/api/v1/users" {
|
||||||
|
t.Errorf("expected path /api/v1/users, got %s", r.URL.Path)
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
client := httpx.New(httpx.WithBaseURL(srv.URL + "/api/v1"))
|
||||||
|
|
||||||
|
// Use a relative path (no scheme/host).
|
||||||
|
resp, err := client.Get(context.Background(), "/users")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClient_WithMiddleware(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
val := r.Header.Get("X-Custom-Header")
|
||||||
|
if val != "test-value" {
|
||||||
|
t.Errorf("expected header X-Custom-Header=%q, got %q", "test-value", val)
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
addHeader := func(next http.RoundTripper) http.RoundTripper {
|
||||||
|
return middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||||
|
req = req.Clone(req.Context())
|
||||||
|
req.Header.Set("X-Custom-Header", "test-value")
|
||||||
|
return next.RoundTrip(req)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
client := httpx.New(httpx.WithMiddleware(addHeader))
|
||||||
|
|
||||||
|
resp, err := client.Get(context.Background(), srv.URL+"/test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClient_RetryIntegration(t *testing.T) {
|
||||||
|
var calls atomic.Int32
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
n := calls.Add(1)
|
||||||
|
if n <= 2 {
|
||||||
|
w.WriteHeader(http.StatusServiceUnavailable)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
fmt.Fprint(w, "success")
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
client := httpx.New(
|
||||||
|
httpx.WithRetry(
|
||||||
|
retry.WithMaxAttempts(3),
|
||||||
|
retry.WithBackoff(retry.ConstantBackoff(1*time.Millisecond)),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
resp, err := client.Get(context.Background(), srv.URL+"/flaky")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := resp.String()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("reading body: %v", err)
|
||||||
|
}
|
||||||
|
if body != "success" {
|
||||||
|
t.Errorf("expected body %q, got %q", "success", body)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := calls.Load(); got != 3 {
|
||||||
|
t.Errorf("expected 3 total requests, got %d", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClient_BalancerIntegration(t *testing.T) {
|
||||||
|
var hits1, hits2 atomic.Int32
|
||||||
|
|
||||||
|
srv1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
hits1.Add(1)
|
||||||
|
fmt.Fprint(w, "server1")
|
||||||
|
}))
|
||||||
|
defer srv1.Close()
|
||||||
|
|
||||||
|
srv2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
hits2.Add(1)
|
||||||
|
fmt.Fprint(w, "server2")
|
||||||
|
}))
|
||||||
|
defer srv2.Close()
|
||||||
|
|
||||||
|
client := httpx.New(
|
||||||
|
httpx.WithEndpoints(
|
||||||
|
balancer.Endpoint{URL: srv1.URL},
|
||||||
|
balancer.Endpoint{URL: srv2.URL},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
const totalRequests = 6
|
||||||
|
for i := range totalRequests {
|
||||||
|
resp, err := client.Get(context.Background(), fmt.Sprintf("/item/%d", i))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("request %d: unexpected error: %v", i, err)
|
||||||
|
}
|
||||||
|
resp.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
h1 := hits1.Load()
|
||||||
|
h2 := hits2.Load()
|
||||||
|
|
||||||
|
if h1+h2 != totalRequests {
|
||||||
|
t.Errorf("expected %d total hits, got %d", totalRequests, h1+h2)
|
||||||
|
}
|
||||||
|
if h1 == 0 || h2 == 0 {
|
||||||
|
t.Errorf("expected requests distributed across both servers, got server1=%d server2=%d", h1, h2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClient_ErrorMapper(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
fmt.Fprint(w, "not found")
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
mapper := func(resp *http.Response) error {
|
||||||
|
if resp.StatusCode >= 400 {
|
||||||
|
return fmt.Errorf("HTTP %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
client := httpx.New(httpx.WithErrorMapper(mapper))
|
||||||
|
|
||||||
|
resp, err := client.Get(context.Background(), srv.URL+"/missing")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// The response should still be returned alongside the error.
|
||||||
|
if resp == nil {
|
||||||
|
t.Fatal("expected non-nil response even on mapped error")
|
||||||
|
}
|
||||||
|
if resp.StatusCode != http.StatusNotFound {
|
||||||
|
t.Errorf("expected status 404, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the error message contains the status code.
|
||||||
|
if !strings.Contains(err.Error(), "404") {
|
||||||
|
t.Errorf("expected error to contain 404, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClient_JSON(t *testing.T) {
|
||||||
|
type reqPayload struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Age int `json:"age"`
|
||||||
|
}
|
||||||
|
type respPayload struct {
|
||||||
|
ID int `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if ct := r.Header.Get("Content-Type"); ct != "application/json" {
|
||||||
|
t.Errorf("expected Content-Type application/json, got %q", ct)
|
||||||
|
}
|
||||||
|
|
||||||
|
var p reqPayload
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&p); err != nil {
|
||||||
|
t.Errorf("decoding request body: %v", err)
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if p.Name != "Alice" || p.Age != 30 {
|
||||||
|
t.Errorf("unexpected payload: %+v", p)
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
json.NewEncoder(w).Encode(respPayload{ID: 1, Name: p.Name})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
client := httpx.New()
|
||||||
|
|
||||||
|
req, err := httpx.NewJSONRequest(context.Background(), http.MethodPost, srv.URL+"/users", reqPayload{
|
||||||
|
Name: "Alice",
|
||||||
|
Age: 30,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("creating JSON request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := client.Do(context.Background(), req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result respPayload
|
||||||
|
if err := resp.JSON(&result); err != nil {
|
||||||
|
t.Fatalf("decoding JSON response: %v", err)
|
||||||
|
}
|
||||||
|
if result.ID != 1 {
|
||||||
|
t.Errorf("expected ID 1, got %d", result.ID)
|
||||||
|
}
|
||||||
|
if result.Name != "Alice" {
|
||||||
|
t.Errorf("expected Name %q, got %q", "Alice", result.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure slog import is used (referenced in imports for completeness with the spec).
|
||||||
|
var _ = slog.Default
|
||||||
49
error.go
Normal file
49
error.go
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
package httpx
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"git.codelab.vc/pkg/httpx/balancer"
|
||||||
|
"git.codelab.vc/pkg/httpx/circuitbreaker"
|
||||||
|
"git.codelab.vc/pkg/httpx/retry"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Sentinel errors returned by httpx components.
|
||||||
|
// These are aliases for the canonical errors defined in sub-packages,
|
||||||
|
// so that errors.Is works regardless of which import the caller uses.
|
||||||
|
var (
|
||||||
|
ErrRetryExhausted = retry.ErrRetryExhausted
|
||||||
|
ErrCircuitOpen = circuitbreaker.ErrCircuitOpen
|
||||||
|
ErrNoHealthy = balancer.ErrNoHealthy
|
||||||
|
)
|
||||||
|
|
||||||
|
// Error provides structured error information for failed HTTP operations.
|
||||||
|
type Error struct {
|
||||||
|
// Op is the operation that failed (e.g. "Get", "Do").
|
||||||
|
Op string
|
||||||
|
// URL is the originally-requested URL.
|
||||||
|
URL string
|
||||||
|
// Endpoint is the resolved endpoint URL (after balancing).
|
||||||
|
Endpoint string
|
||||||
|
// StatusCode is the HTTP status code, if a response was received.
|
||||||
|
StatusCode int
|
||||||
|
// Retries is the number of retry attempts made.
|
||||||
|
Retries int
|
||||||
|
// Err is the underlying error.
|
||||||
|
Err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Error) Error() string {
|
||||||
|
if e.Endpoint != "" && e.Endpoint != e.URL {
|
||||||
|
return fmt.Sprintf("httpx: %s %s (endpoint %s): %v", e.Op, e.URL, e.Endpoint, e.Err)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("httpx: %s %s: %v", e.Op, e.URL, e.Err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Error) Unwrap() error { return e.Err }
|
||||||
|
|
||||||
|
// ErrorMapper maps an HTTP response to an error. If the response is
|
||||||
|
// acceptable, the mapper should return nil. Used by Client to convert
|
||||||
|
// non-successful HTTP responses into Go errors.
|
||||||
|
type ErrorMapper func(resp *http.Response) error
|
||||||
90
error_test.go
Normal file
90
error_test.go
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
package httpx_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.codelab.vc/pkg/httpx"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestError(t *testing.T) {
|
||||||
|
t.Run("formats without endpoint", func(t *testing.T) {
|
||||||
|
inner := errors.New("connection refused")
|
||||||
|
e := &httpx.Error{
|
||||||
|
Op: "Get",
|
||||||
|
URL: "http://example.com/api",
|
||||||
|
Err: inner,
|
||||||
|
}
|
||||||
|
|
||||||
|
want := "httpx: Get http://example.com/api: connection refused"
|
||||||
|
if got := e.Error(); got != want {
|
||||||
|
t.Errorf("got %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("formats with endpoint different from url", func(t *testing.T) {
|
||||||
|
inner := errors.New("timeout")
|
||||||
|
e := &httpx.Error{
|
||||||
|
Op: "Do",
|
||||||
|
URL: "http://example.com/api",
|
||||||
|
Endpoint: "http://node1.example.com/api",
|
||||||
|
Err: inner,
|
||||||
|
}
|
||||||
|
|
||||||
|
want := "httpx: Do http://example.com/api (endpoint http://node1.example.com/api): timeout"
|
||||||
|
if got := e.Error(); got != want {
|
||||||
|
t.Errorf("got %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("formats with endpoint same as url", func(t *testing.T) {
|
||||||
|
inner := errors.New("not found")
|
||||||
|
e := &httpx.Error{
|
||||||
|
Op: "Get",
|
||||||
|
URL: "http://example.com/api",
|
||||||
|
Endpoint: "http://example.com/api",
|
||||||
|
Err: inner,
|
||||||
|
}
|
||||||
|
|
||||||
|
want := "httpx: Get http://example.com/api: not found"
|
||||||
|
if got := e.Error(); got != want {
|
||||||
|
t.Errorf("got %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("unwrap returns inner error", func(t *testing.T) {
|
||||||
|
inner := errors.New("underlying")
|
||||||
|
e := &httpx.Error{Op: "Get", URL: "http://example.com", Err: inner}
|
||||||
|
|
||||||
|
if got := e.Unwrap(); got != inner {
|
||||||
|
t.Errorf("Unwrap() = %v, want %v", got, inner)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !errors.Is(e, inner) {
|
||||||
|
t.Error("errors.Is should find the inner error")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSentinelErrors(t *testing.T) {
|
||||||
|
t.Run("ErrRetryExhausted", func(t *testing.T) {
|
||||||
|
if httpx.ErrRetryExhausted == nil {
|
||||||
|
t.Fatal("ErrRetryExhausted is nil")
|
||||||
|
}
|
||||||
|
if httpx.ErrRetryExhausted.Error() == "" {
|
||||||
|
t.Fatal("ErrRetryExhausted has empty message")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ErrCircuitOpen", func(t *testing.T) {
|
||||||
|
if httpx.ErrCircuitOpen == nil {
|
||||||
|
t.Fatal("ErrCircuitOpen is nil")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ErrNoHealthy", func(t *testing.T) {
|
||||||
|
if httpx.ErrNoHealthy == nil {
|
||||||
|
t.Fatal("ErrNoHealthy is nil")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
175
internal/clock/clock.go
Normal file
175
internal/clock/clock.go
Normal file
@@ -0,0 +1,175 @@
|
|||||||
|
package clock
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Clock abstracts time operations for deterministic testing.
|
||||||
|
type Clock interface {
|
||||||
|
Now() time.Time
|
||||||
|
Since(t time.Time) time.Duration
|
||||||
|
NewTimer(d time.Duration) Timer
|
||||||
|
After(d time.Duration) <-chan time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// Timer abstracts time.Timer for testability.
|
||||||
|
type Timer interface {
|
||||||
|
C() <-chan time.Time
|
||||||
|
Stop() bool
|
||||||
|
Reset(d time.Duration) bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// System returns a Clock backed by the real system time.
|
||||||
|
func System() Clock { return systemClock{} }
|
||||||
|
|
||||||
|
type systemClock struct{}
|
||||||
|
|
||||||
|
func (systemClock) Now() time.Time { return time.Now() }
|
||||||
|
func (systemClock) Since(t time.Time) time.Duration { return time.Since(t) }
|
||||||
|
func (systemClock) NewTimer(d time.Duration) Timer { return &systemTimer{t: time.NewTimer(d)} }
|
||||||
|
func (systemClock) After(d time.Duration) <-chan time.Time { return time.After(d) }
|
||||||
|
|
||||||
|
type systemTimer struct{ t *time.Timer }
|
||||||
|
|
||||||
|
func (s *systemTimer) C() <-chan time.Time { return s.t.C }
|
||||||
|
func (s *systemTimer) Stop() bool { return s.t.Stop() }
|
||||||
|
func (s *systemTimer) Reset(d time.Duration) bool { return s.t.Reset(d) }
|
||||||
|
|
||||||
|
// Mock returns a manually-controlled Clock for tests.
|
||||||
|
func Mock(now time.Time) *MockClock {
|
||||||
|
return &MockClock{now: now}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockClock is a deterministic clock for testing.
|
||||||
|
type MockClock struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
now time.Time
|
||||||
|
timers []*mockTimer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockClock) Now() time.Time {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
return m.now
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockClock) Since(t time.Time) time.Duration {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
return m.now.Sub(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockClock) NewTimer(d time.Duration) Timer {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
t := &mockTimer{
|
||||||
|
clock: m,
|
||||||
|
ch: make(chan time.Time, 1),
|
||||||
|
deadline: m.now.Add(d),
|
||||||
|
active: true,
|
||||||
|
}
|
||||||
|
m.timers = append(m.timers, t)
|
||||||
|
if d <= 0 {
|
||||||
|
t.fire(m.now)
|
||||||
|
}
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockClock) After(d time.Duration) <-chan time.Time {
|
||||||
|
return m.NewTimer(d).C()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Advance moves the clock forward by d and fires any expired timers.
|
||||||
|
func (m *MockClock) Advance(d time.Duration) {
|
||||||
|
m.mu.Lock()
|
||||||
|
m.now = m.now.Add(d)
|
||||||
|
now := m.now
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
m.fireExpired(now)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set sets the clock to an absolute time and fires any expired timers.
|
||||||
|
func (m *MockClock) Set(t time.Time) {
|
||||||
|
m.mu.Lock()
|
||||||
|
m.now = t
|
||||||
|
now := m.now
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
m.fireExpired(now)
|
||||||
|
}
|
||||||
|
|
||||||
|
// fireExpired fires all active timers whose deadline has passed, then
|
||||||
|
// removes inactive timers to prevent unbounded growth.
|
||||||
|
func (m *MockClock) fireExpired(now time.Time) {
|
||||||
|
m.mu.Lock()
|
||||||
|
timers := m.timers
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
for _, t := range timers {
|
||||||
|
t.mu.Lock()
|
||||||
|
if t.active && !now.Before(t.deadline) {
|
||||||
|
t.fire(now)
|
||||||
|
}
|
||||||
|
t.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compact: remove inactive timers. Use a new slice to avoid aliasing
|
||||||
|
// the backing array (NewTimer may have appended between snapshots).
|
||||||
|
m.mu.Lock()
|
||||||
|
n := len(m.timers)
|
||||||
|
active := make([]*mockTimer, 0, n)
|
||||||
|
for _, t := range m.timers {
|
||||||
|
t.mu.Lock()
|
||||||
|
keep := t.active
|
||||||
|
t.mu.Unlock()
|
||||||
|
if keep {
|
||||||
|
active = append(active, t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
m.timers = active
|
||||||
|
m.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockTimer struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
clock *MockClock
|
||||||
|
ch chan time.Time
|
||||||
|
deadline time.Time
|
||||||
|
active bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *mockTimer) C() <-chan time.Time { return t.ch }
|
||||||
|
|
||||||
|
func (t *mockTimer) Stop() bool {
|
||||||
|
t.mu.Lock()
|
||||||
|
defer t.mu.Unlock()
|
||||||
|
was := t.active
|
||||||
|
t.active = false
|
||||||
|
return was
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *mockTimer) Reset(d time.Duration) bool {
|
||||||
|
// Acquire clock lock first to match the lock ordering in fireExpired
|
||||||
|
// (clock.mu → t.mu), preventing deadlock.
|
||||||
|
t.clock.mu.Lock()
|
||||||
|
deadline := t.clock.now.Add(d)
|
||||||
|
t.clock.mu.Unlock()
|
||||||
|
|
||||||
|
t.mu.Lock()
|
||||||
|
defer t.mu.Unlock()
|
||||||
|
was := t.active
|
||||||
|
t.active = true
|
||||||
|
t.deadline = deadline
|
||||||
|
return was
|
||||||
|
}
|
||||||
|
|
||||||
|
// fire sends the time on the channel. Caller must hold t.mu.
|
||||||
|
func (t *mockTimer) fire(now time.Time) {
|
||||||
|
t.active = false
|
||||||
|
select {
|
||||||
|
case t.ch <- now:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
33
middleware/auth.go
Normal file
33
middleware/auth.go
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
// BearerAuth returns a middleware that sets the Authorization header to a
|
||||||
|
// Bearer token obtained by calling tokenFunc on each request. If tokenFunc
|
||||||
|
// returns an error, the request is not sent and the error is returned.
|
||||||
|
func BearerAuth(tokenFunc func(ctx context.Context) (string, error)) Middleware {
|
||||||
|
return func(next http.RoundTripper) http.RoundTripper {
|
||||||
|
return RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||||
|
token, err := tokenFunc(req.Context())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Header.Set("Authorization", "Bearer "+token)
|
||||||
|
return next.RoundTrip(req)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BasicAuth returns a middleware that sets HTTP Basic Authentication
|
||||||
|
// credentials on every outgoing request.
|
||||||
|
func BasicAuth(username, password string) Middleware {
|
||||||
|
return func(next http.RoundTripper) http.RoundTripper {
|
||||||
|
return RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||||
|
req.SetBasicAuth(username, password)
|
||||||
|
return next.RoundTrip(req)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
89
middleware/auth_test.go
Normal file
89
middleware/auth_test.go
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
package middleware_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.codelab.vc/pkg/httpx/middleware"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBearerAuth(t *testing.T) {
|
||||||
|
t.Run("sets authorization header", func(t *testing.T) {
|
||||||
|
var captured http.Header
|
||||||
|
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||||
|
captured = req.Header.Clone()
|
||||||
|
return okResponse(), nil
|
||||||
|
})
|
||||||
|
|
||||||
|
tokenFunc := func(_ context.Context) (string, error) {
|
||||||
|
return "my-secret-token", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
transport := middleware.BearerAuth(tokenFunc)(base)
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||||
|
|
||||||
|
_, err := transport.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
want := "Bearer my-secret-token"
|
||||||
|
if got := captured.Get("Authorization"); got != want {
|
||||||
|
t.Errorf("Authorization = %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("returns error when tokenFunc fails", func(t *testing.T) {
|
||||||
|
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||||
|
t.Fatal("base transport should not be called")
|
||||||
|
return nil, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
tokenErr := errors.New("token expired")
|
||||||
|
tokenFunc := func(_ context.Context) (string, error) {
|
||||||
|
return "", tokenErr
|
||||||
|
}
|
||||||
|
|
||||||
|
transport := middleware.BearerAuth(tokenFunc)(base)
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||||
|
|
||||||
|
_, err := transport.RoundTrip(req)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
if !errors.Is(err, tokenErr) {
|
||||||
|
t.Errorf("got error %v, want %v", err, tokenErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBasicAuth(t *testing.T) {
|
||||||
|
t.Run("sets basic auth header", func(t *testing.T) {
|
||||||
|
var capturedReq *http.Request
|
||||||
|
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||||
|
capturedReq = req
|
||||||
|
return okResponse(), nil
|
||||||
|
})
|
||||||
|
|
||||||
|
transport := middleware.BasicAuth("user", "pass")(base)
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||||
|
|
||||||
|
_, err := transport.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
username, password, ok := capturedReq.BasicAuth()
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("BasicAuth() returned ok=false")
|
||||||
|
}
|
||||||
|
if username != "user" {
|
||||||
|
t.Errorf("username = %q, want %q", username, "user")
|
||||||
|
}
|
||||||
|
if password != "pass" {
|
||||||
|
t.Errorf("password = %q, want %q", password, "pass")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
29
middleware/headers.go
Normal file
29
middleware/headers.go
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import "net/http"
|
||||||
|
|
||||||
|
// DefaultHeaders returns a middleware that adds the given headers to every
|
||||||
|
// outgoing request. Existing headers on the request are not overwritten.
|
||||||
|
func DefaultHeaders(headers http.Header) Middleware {
|
||||||
|
return func(next http.RoundTripper) http.RoundTripper {
|
||||||
|
return RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||||
|
for key, values := range headers {
|
||||||
|
if req.Header.Get(key) != "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, v := range values {
|
||||||
|
req.Header.Add(key, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return next.RoundTrip(req)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserAgent returns a middleware that sets the User-Agent header on every
|
||||||
|
// outgoing request, unless one is already present.
|
||||||
|
func UserAgent(ua string) Middleware {
|
||||||
|
return DefaultHeaders(http.Header{
|
||||||
|
"User-Agent": {ua},
|
||||||
|
})
|
||||||
|
}
|
||||||
107
middleware/headers_test.go
Normal file
107
middleware/headers_test.go
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
package middleware_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.codelab.vc/pkg/httpx/middleware"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDefaultHeaders(t *testing.T) {
|
||||||
|
t.Run("adds headers without overwriting existing", func(t *testing.T) {
|
||||||
|
defaults := http.Header{
|
||||||
|
"X-Custom": {"default-value"},
|
||||||
|
"X-Untouched": {"from-middleware"},
|
||||||
|
}
|
||||||
|
|
||||||
|
var captured http.Header
|
||||||
|
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||||
|
captured = req.Header.Clone()
|
||||||
|
return okResponse(), nil
|
||||||
|
})
|
||||||
|
|
||||||
|
transport := middleware.DefaultHeaders(defaults)(base)
|
||||||
|
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||||
|
req.Header.Set("X-Custom", "request-value")
|
||||||
|
|
||||||
|
_, err := transport.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := captured.Get("X-Custom"); got != "request-value" {
|
||||||
|
t.Errorf("X-Custom = %q, want %q (should not overwrite)", got, "request-value")
|
||||||
|
}
|
||||||
|
if got := captured.Get("X-Untouched"); got != "from-middleware" {
|
||||||
|
t.Errorf("X-Untouched = %q, want %q", got, "from-middleware")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("adds headers when absent", func(t *testing.T) {
|
||||||
|
defaults := http.Header{
|
||||||
|
"Accept": {"application/json"},
|
||||||
|
}
|
||||||
|
|
||||||
|
var captured http.Header
|
||||||
|
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||||
|
captured = req.Header.Clone()
|
||||||
|
return okResponse(), nil
|
||||||
|
})
|
||||||
|
|
||||||
|
transport := middleware.DefaultHeaders(defaults)(base)
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||||
|
|
||||||
|
_, err := transport.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := captured.Get("Accept"); got != "application/json" {
|
||||||
|
t.Errorf("Accept = %q, want %q", got, "application/json")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUserAgent(t *testing.T) {
|
||||||
|
t.Run("sets user agent header", func(t *testing.T) {
|
||||||
|
var captured http.Header
|
||||||
|
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||||
|
captured = req.Header.Clone()
|
||||||
|
return okResponse(), nil
|
||||||
|
})
|
||||||
|
|
||||||
|
transport := middleware.UserAgent("httpx/1.0")(base)
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||||
|
|
||||||
|
_, err := transport.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := captured.Get("User-Agent"); got != "httpx/1.0" {
|
||||||
|
t.Errorf("User-Agent = %q, want %q", got, "httpx/1.0")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("does not overwrite existing user agent", func(t *testing.T) {
|
||||||
|
var captured http.Header
|
||||||
|
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||||
|
captured = req.Header.Clone()
|
||||||
|
return okResponse(), nil
|
||||||
|
})
|
||||||
|
|
||||||
|
transport := middleware.UserAgent("httpx/1.0")(base)
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||||
|
req.Header.Set("User-Agent", "custom-agent")
|
||||||
|
|
||||||
|
_, err := transport.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := captured.Get("User-Agent"); got != "custom-agent" {
|
||||||
|
t.Errorf("User-Agent = %q, want %q", got, "custom-agent")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
38
middleware/logging.go
Normal file
38
middleware/logging.go
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log/slog"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Logging returns a middleware that logs each request's method, URL, status
|
||||||
|
// code, duration, and error (if any) using the provided structured logger.
|
||||||
|
// Successful responses are logged at Info level; errors at Error level.
|
||||||
|
func Logging(logger *slog.Logger) Middleware {
|
||||||
|
return func(next http.RoundTripper) http.RoundTripper {
|
||||||
|
return RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||||
|
start := time.Now()
|
||||||
|
|
||||||
|
resp, err := next.RoundTrip(req)
|
||||||
|
|
||||||
|
duration := time.Since(start)
|
||||||
|
attrs := []slog.Attr{
|
||||||
|
slog.String("method", req.Method),
|
||||||
|
slog.String("url", req.URL.String()),
|
||||||
|
slog.Duration("duration", duration),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
attrs = append(attrs, slog.String("error", err.Error()))
|
||||||
|
logger.LogAttrs(req.Context(), slog.LevelError, "request failed", attrs...)
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
attrs = append(attrs, slog.Int("status", resp.StatusCode))
|
||||||
|
logger.LogAttrs(req.Context(), slog.LevelInfo, "request completed", attrs...)
|
||||||
|
|
||||||
|
return resp, nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
124
middleware/logging_test.go
Normal file
124
middleware/logging_test.go
Normal file
@@ -0,0 +1,124 @@
|
|||||||
|
package middleware_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"log/slog"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.codelab.vc/pkg/httpx/middleware"
|
||||||
|
)
|
||||||
|
|
||||||
|
// captureHandler is a slog.Handler that captures log records for inspection.
|
||||||
|
type captureHandler struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
records []slog.Record
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *captureHandler) Enabled(_ context.Context, _ slog.Level) bool { return true }
|
||||||
|
|
||||||
|
func (h *captureHandler) Handle(_ context.Context, r slog.Record) error {
|
||||||
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
|
h.records = append(h.records, r)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *captureHandler) WithAttrs(_ []slog.Attr) slog.Handler { return h }
|
||||||
|
func (h *captureHandler) WithGroup(_ string) slog.Handler { return h }
|
||||||
|
|
||||||
|
func (h *captureHandler) lastRecord() slog.Record {
|
||||||
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
|
return h.records[len(h.records)-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLogging(t *testing.T) {
|
||||||
|
t.Run("logs method url status duration on success", func(t *testing.T) {
|
||||||
|
handler := &captureHandler{}
|
||||||
|
logger := slog.New(handler)
|
||||||
|
|
||||||
|
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Body: io.NopCloser(strings.NewReader("ok")),
|
||||||
|
Header: make(http.Header),
|
||||||
|
}, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
transport := middleware.Logging(logger)(base)
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "http://example.com/api", nil)
|
||||||
|
resp, err := transport.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("got status %d, want %d", resp.StatusCode, http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
rec := handler.lastRecord()
|
||||||
|
if rec.Level != slog.LevelInfo {
|
||||||
|
t.Errorf("got level %v, want %v", rec.Level, slog.LevelInfo)
|
||||||
|
}
|
||||||
|
|
||||||
|
attrs := map[string]string{}
|
||||||
|
rec.Attrs(func(a slog.Attr) bool {
|
||||||
|
attrs[a.Key] = a.Value.String()
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
if attrs["method"] != "POST" {
|
||||||
|
t.Errorf("method = %q, want %q", attrs["method"], "POST")
|
||||||
|
}
|
||||||
|
if attrs["url"] != "http://example.com/api" {
|
||||||
|
t.Errorf("url = %q, want %q", attrs["url"], "http://example.com/api")
|
||||||
|
}
|
||||||
|
if _, ok := attrs["status"]; !ok {
|
||||||
|
t.Error("missing status attribute")
|
||||||
|
}
|
||||||
|
if _, ok := attrs["duration"]; !ok {
|
||||||
|
t.Error("missing duration attribute")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("logs error on failure", func(t *testing.T) {
|
||||||
|
handler := &captureHandler{}
|
||||||
|
logger := slog.New(handler)
|
||||||
|
|
||||||
|
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||||
|
return nil, errors.New("connection refused")
|
||||||
|
})
|
||||||
|
|
||||||
|
transport := middleware.Logging(logger)(base)
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "http://example.com/fail", nil)
|
||||||
|
_, err := transport.RoundTrip(req)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
rec := handler.lastRecord()
|
||||||
|
if rec.Level != slog.LevelError {
|
||||||
|
t.Errorf("got level %v, want %v", rec.Level, slog.LevelError)
|
||||||
|
}
|
||||||
|
|
||||||
|
attrs := map[string]string{}
|
||||||
|
rec.Attrs(func(a slog.Attr) bool {
|
||||||
|
attrs[a.Key] = a.Value.String()
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
if attrs["error"] != "connection refused" {
|
||||||
|
t.Errorf("error = %q, want %q", attrs["error"], "connection refused")
|
||||||
|
}
|
||||||
|
if _, ok := attrs["method"]; !ok {
|
||||||
|
t.Error("missing method attribute")
|
||||||
|
}
|
||||||
|
if _, ok := attrs["url"]; !ok {
|
||||||
|
t.Error("missing url attribute")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
29
middleware/middleware.go
Normal file
29
middleware/middleware.go
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import "net/http"
|
||||||
|
|
||||||
|
// Middleware wraps an http.RoundTripper to add behavior.
|
||||||
|
// This is the fundamental building block of the httpx library.
|
||||||
|
type Middleware func(http.RoundTripper) http.RoundTripper
|
||||||
|
|
||||||
|
// Chain composes middlewares so that Chain(A, B, C)(base) == A(B(C(base))).
|
||||||
|
// Middlewares are applied from right to left: C wraps base first, then B wraps
|
||||||
|
// the result, then A wraps last. This means A is the outermost layer and sees
|
||||||
|
// every request first.
|
||||||
|
func Chain(mws ...Middleware) Middleware {
|
||||||
|
return func(rt http.RoundTripper) http.RoundTripper {
|
||||||
|
for i := len(mws) - 1; i >= 0; i-- {
|
||||||
|
rt = mws[i](rt)
|
||||||
|
}
|
||||||
|
return rt
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RoundTripperFunc is an adapter to allow the use of ordinary functions as
|
||||||
|
// http.RoundTripper. It works exactly like http.HandlerFunc for handlers.
|
||||||
|
type RoundTripperFunc func(*http.Request) (*http.Response, error)
|
||||||
|
|
||||||
|
// RoundTrip implements http.RoundTripper.
|
||||||
|
func (f RoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
return f(req)
|
||||||
|
}
|
||||||
115
middleware/middleware_test.go
Normal file
115
middleware/middleware_test.go
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
package middleware_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.codelab.vc/pkg/httpx/middleware"
|
||||||
|
)
|
||||||
|
|
||||||
|
func mockTransport(fn func(*http.Request) (*http.Response, error)) http.RoundTripper {
|
||||||
|
return middleware.RoundTripperFunc(fn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func okResponse() *http.Response {
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Body: io.NopCloser(strings.NewReader("ok")),
|
||||||
|
Header: make(http.Header),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChain(t *testing.T) {
|
||||||
|
t.Run("applies middlewares in correct order", func(t *testing.T) {
|
||||||
|
var order []string
|
||||||
|
|
||||||
|
mwA := func(next http.RoundTripper) http.RoundTripper {
|
||||||
|
return mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||||
|
order = append(order, "A-before")
|
||||||
|
resp, err := next.RoundTrip(req)
|
||||||
|
order = append(order, "A-after")
|
||||||
|
return resp, err
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
mwB := func(next http.RoundTripper) http.RoundTripper {
|
||||||
|
return mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||||
|
order = append(order, "B-before")
|
||||||
|
resp, err := next.RoundTrip(req)
|
||||||
|
order = append(order, "B-after")
|
||||||
|
return resp, err
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
mwC := func(next http.RoundTripper) http.RoundTripper {
|
||||||
|
return mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||||
|
order = append(order, "C-before")
|
||||||
|
resp, err := next.RoundTrip(req)
|
||||||
|
order = append(order, "C-after")
|
||||||
|
return resp, err
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||||
|
order = append(order, "base")
|
||||||
|
return okResponse(), nil
|
||||||
|
})
|
||||||
|
|
||||||
|
chained := middleware.Chain(mwA, mwB, mwC)(base)
|
||||||
|
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||||
|
_, err := chained.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := []string{"A-before", "B-before", "C-before", "base", "C-after", "B-after", "A-after"}
|
||||||
|
if len(order) != len(expected) {
|
||||||
|
t.Fatalf("got %v, want %v", order, expected)
|
||||||
|
}
|
||||||
|
for i, v := range expected {
|
||||||
|
if order[i] != v {
|
||||||
|
t.Fatalf("order[%d] = %q, want %q", i, order[i], v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("empty chain returns base transport", func(t *testing.T) {
|
||||||
|
called := false
|
||||||
|
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||||
|
called = true
|
||||||
|
return okResponse(), nil
|
||||||
|
})
|
||||||
|
|
||||||
|
chained := middleware.Chain()(base)
|
||||||
|
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||||
|
_, err := chained.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if !called {
|
||||||
|
t.Fatal("base transport was not called")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoundTripperFunc(t *testing.T) {
|
||||||
|
t.Run("implements RoundTripper", func(t *testing.T) {
|
||||||
|
fn := middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||||
|
return okResponse(), nil
|
||||||
|
})
|
||||||
|
|
||||||
|
var rt http.RoundTripper = fn
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||||
|
resp, err := rt.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("got status %d, want %d", resp.StatusCode, http.StatusOK)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
22
middleware/recovery.go
Normal file
22
middleware/recovery.go
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Recovery returns a middleware that recovers from panics in the inner
|
||||||
|
// RoundTripper chain. A recovered panic is converted to an error wrapping
|
||||||
|
// the panic value.
|
||||||
|
func Recovery() Middleware {
|
||||||
|
return func(next http.RoundTripper) http.RoundTripper {
|
||||||
|
return RoundTripperFunc(func(req *http.Request) (resp *http.Response, err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = fmt.Errorf("panic recovered in round trip: %v", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return next.RoundTrip(req)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
51
middleware/recovery_test.go
Normal file
51
middleware/recovery_test.go
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
package middleware_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.codelab.vc/pkg/httpx/middleware"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRecovery(t *testing.T) {
|
||||||
|
t.Run("recovers from panic and returns error", func(t *testing.T) {
|
||||||
|
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||||
|
panic("something went wrong")
|
||||||
|
})
|
||||||
|
|
||||||
|
transport := middleware.Recovery()(base)
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||||
|
|
||||||
|
resp, err := transport.RoundTrip(req)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
if resp != nil {
|
||||||
|
t.Errorf("expected nil response, got %v", resp)
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "panic recovered") {
|
||||||
|
t.Errorf("error = %q, want it to contain %q", err.Error(), "panic recovered")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "something went wrong") {
|
||||||
|
t.Errorf("error = %q, want it to contain %q", err.Error(), "something went wrong")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("passes through normal responses", func(t *testing.T) {
|
||||||
|
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||||
|
return okResponse(), nil
|
||||||
|
})
|
||||||
|
|
||||||
|
transport := middleware.Recovery()(base)
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||||
|
|
||||||
|
resp, err := transport.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("got status %d, want %d", resp.StatusCode, http.StatusOK)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
34
request.go
Normal file
34
request.go
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
package httpx
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewRequest creates an http.Request with context. It is a convenience
|
||||||
|
// wrapper around http.NewRequestWithContext.
|
||||||
|
func NewRequest(ctx context.Context, method, url string, body io.Reader) (*http.Request, error) {
|
||||||
|
return http.NewRequestWithContext(ctx, method, url, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewJSONRequest creates an http.Request with a JSON-encoded body and
|
||||||
|
// sets Content-Type to application/json.
|
||||||
|
func NewJSONRequest(ctx context.Context, method, url string, body any) (*http.Request, error) {
|
||||||
|
b, err := json.Marshal(body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("httpx: encoding JSON body: %w", err)
|
||||||
|
}
|
||||||
|
req, err := http.NewRequestWithContext(ctx, method, url, bytes.NewReader(b))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.GetBody = func() (io.ReadCloser, error) {
|
||||||
|
return io.NopCloser(bytes.NewReader(b)), nil
|
||||||
|
}
|
||||||
|
return req, nil
|
||||||
|
}
|
||||||
78
request_test.go
Normal file
78
request_test.go
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
package httpx_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.codelab.vc/pkg/httpx"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewJSONRequest(t *testing.T) {
|
||||||
|
t.Run("body is JSON encoded", func(t *testing.T) {
|
||||||
|
payload := map[string]string{"key": "value"}
|
||||||
|
req, err := httpx.NewJSONRequest(context.Background(), http.MethodPost, "http://example.com", payload)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := io.ReadAll(req.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("reading body: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var decoded map[string]string
|
||||||
|
if err := json.Unmarshal(body, &decoded); err != nil {
|
||||||
|
t.Fatalf("unmarshal: %v", err)
|
||||||
|
}
|
||||||
|
if decoded["key"] != "value" {
|
||||||
|
t.Errorf("decoded[key] = %q, want %q", decoded["key"], "value")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("content type is set", func(t *testing.T) {
|
||||||
|
req, err := httpx.NewJSONRequest(context.Background(), http.MethodPost, "http://example.com", "test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ct := req.Header.Get("Content-Type")
|
||||||
|
if ct != "application/json" {
|
||||||
|
t.Errorf("Content-Type = %q, want %q", ct, "application/json")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("GetBody works", func(t *testing.T) {
|
||||||
|
payload := map[string]int{"num": 123}
|
||||||
|
req, err := httpx.NewJSONRequest(context.Background(), http.MethodPost, "http://example.com", payload)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.GetBody == nil {
|
||||||
|
t.Fatal("GetBody is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read body first time
|
||||||
|
b1, err := io.ReadAll(req.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("reading body: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get a fresh body
|
||||||
|
body2, err := req.GetBody()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetBody(): %v", err)
|
||||||
|
}
|
||||||
|
b2, err := io.ReadAll(body2)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("reading body2: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(b1) != string(b2) {
|
||||||
|
t.Errorf("GetBody returned different data: %q vs %q", b1, b2)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
99
response.go
Normal file
99
response.go
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
package httpx
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"encoding/xml"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Response wraps http.Response with convenience methods.
|
||||||
|
type Response struct {
|
||||||
|
*http.Response
|
||||||
|
body []byte
|
||||||
|
read bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func newResponse(resp *http.Response) *Response {
|
||||||
|
return &Response{Response: resp}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bytes reads and returns the entire response body.
|
||||||
|
// The body is cached so subsequent calls return the same data.
|
||||||
|
func (r *Response) Bytes() ([]byte, error) {
|
||||||
|
if r.read {
|
||||||
|
return r.body, nil
|
||||||
|
}
|
||||||
|
defer r.Body.Close()
|
||||||
|
b, err := io.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
r.body = b
|
||||||
|
r.read = true
|
||||||
|
return b, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// String reads the response body and returns it as a string.
|
||||||
|
func (r *Response) String() (string, error) {
|
||||||
|
b, err := r.Bytes()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return string(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// JSON decodes the response body as JSON into v.
|
||||||
|
func (r *Response) JSON(v any) error {
|
||||||
|
b, err := r.Bytes()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("httpx: reading response body: %w", err)
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(b, v); err != nil {
|
||||||
|
return fmt.Errorf("httpx: decoding JSON: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// XML decodes the response body as XML into v.
|
||||||
|
func (r *Response) XML(v any) error {
|
||||||
|
b, err := r.Bytes()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("httpx: reading response body: %w", err)
|
||||||
|
}
|
||||||
|
if err := xml.Unmarshal(b, v); err != nil {
|
||||||
|
return fmt.Errorf("httpx: decoding XML: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsSuccess returns true if the status code is in the 2xx range.
|
||||||
|
func (r *Response) IsSuccess() bool {
|
||||||
|
return r.StatusCode >= 200 && r.StatusCode < 300
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsError returns true if the status code is 4xx or 5xx.
|
||||||
|
func (r *Response) IsError() bool {
|
||||||
|
return r.StatusCode >= 400
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close drains and closes the response body.
|
||||||
|
func (r *Response) Close() error {
|
||||||
|
if r.read {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
_, _ = io.Copy(io.Discard, r.Body)
|
||||||
|
return r.Body.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// BodyReader returns a reader for the response body.
|
||||||
|
// If the body has already been read via Bytes/String/JSON/XML,
|
||||||
|
// returns a reader over the cached bytes.
|
||||||
|
func (r *Response) BodyReader() io.Reader {
|
||||||
|
if r.read {
|
||||||
|
return bytes.NewReader(r.body)
|
||||||
|
}
|
||||||
|
return r.Body
|
||||||
|
}
|
||||||
96
response_test.go
Normal file
96
response_test.go
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
package httpx
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func makeTestResponse(statusCode int, body string) *Response {
|
||||||
|
return newResponse(&http.Response{
|
||||||
|
StatusCode: statusCode,
|
||||||
|
Body: io.NopCloser(strings.NewReader(body)),
|
||||||
|
Header: make(http.Header),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponse(t *testing.T) {
|
||||||
|
t.Run("Bytes returns body", func(t *testing.T) {
|
||||||
|
r := makeTestResponse(200, "hello world")
|
||||||
|
b, err := r.Bytes()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Bytes() error: %v", err)
|
||||||
|
}
|
||||||
|
if string(b) != "hello world" {
|
||||||
|
t.Errorf("Bytes() = %q, want %q", string(b), "hello world")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("body caching returns same data", func(t *testing.T) {
|
||||||
|
r := makeTestResponse(200, "cached body")
|
||||||
|
b1, err := r.Bytes()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("first Bytes() error: %v", err)
|
||||||
|
}
|
||||||
|
b2, err := r.Bytes()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("second Bytes() error: %v", err)
|
||||||
|
}
|
||||||
|
if string(b1) != string(b2) {
|
||||||
|
t.Errorf("Bytes() returned different data: %q vs %q", b1, b2)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("String returns body as string", func(t *testing.T) {
|
||||||
|
r := makeTestResponse(200, "string body")
|
||||||
|
s, err := r.String()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("String() error: %v", err)
|
||||||
|
}
|
||||||
|
if s != "string body" {
|
||||||
|
t.Errorf("String() = %q, want %q", s, "string body")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("JSON decodes body", func(t *testing.T) {
|
||||||
|
r := makeTestResponse(200, `{"name":"test","value":42}`)
|
||||||
|
var result struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Value int `json:"value"`
|
||||||
|
}
|
||||||
|
if err := r.JSON(&result); err != nil {
|
||||||
|
t.Fatalf("JSON() error: %v", err)
|
||||||
|
}
|
||||||
|
if result.Name != "test" {
|
||||||
|
t.Errorf("Name = %q, want %q", result.Name, "test")
|
||||||
|
}
|
||||||
|
if result.Value != 42 {
|
||||||
|
t.Errorf("Value = %d, want %d", result.Value, 42)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("IsSuccess for 2xx", func(t *testing.T) {
|
||||||
|
for _, code := range []int{200, 201, 204, 299} {
|
||||||
|
r := makeTestResponse(code, "")
|
||||||
|
if !r.IsSuccess() {
|
||||||
|
t.Errorf("IsSuccess() = false for status %d", code)
|
||||||
|
}
|
||||||
|
if r.IsError() {
|
||||||
|
t.Errorf("IsError() = true for status %d", code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("IsError for 4xx and 5xx", func(t *testing.T) {
|
||||||
|
for _, code := range []int{400, 404, 500, 503} {
|
||||||
|
r := makeTestResponse(code, "")
|
||||||
|
if !r.IsError() {
|
||||||
|
t.Errorf("IsError() = false for status %d", code)
|
||||||
|
}
|
||||||
|
if r.IsSuccess() {
|
||||||
|
t.Errorf("IsSuccess() = true for status %d", code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
64
retry/backoff.go
Normal file
64
retry/backoff.go
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
package retry
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math/rand/v2"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Backoff computes the delay before the next retry attempt.
|
||||||
|
type Backoff interface {
|
||||||
|
// Delay returns the wait duration for the given attempt number (zero-based).
|
||||||
|
Delay(attempt int) time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExponentialBackoff returns a Backoff that doubles the delay on each attempt.
|
||||||
|
// The delay is calculated as base * 2^attempt, capped at max. When withJitter
|
||||||
|
// is true, a random duration in [0, delay*0.5) is added.
|
||||||
|
func ExponentialBackoff(base, max time.Duration, withJitter bool) Backoff {
|
||||||
|
return &exponentialBackoff{
|
||||||
|
base: base,
|
||||||
|
max: max,
|
||||||
|
withJitter: withJitter,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConstantBackoff returns a Backoff that always returns the same delay.
|
||||||
|
func ConstantBackoff(d time.Duration) Backoff {
|
||||||
|
return constantBackoff{delay: d}
|
||||||
|
}
|
||||||
|
|
||||||
|
type exponentialBackoff struct {
|
||||||
|
base time.Duration
|
||||||
|
max time.Duration
|
||||||
|
withJitter bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *exponentialBackoff) Delay(attempt int) time.Duration {
|
||||||
|
delay := b.base
|
||||||
|
for range attempt {
|
||||||
|
delay *= 2
|
||||||
|
if delay >= b.max {
|
||||||
|
delay = b.max
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if b.withJitter {
|
||||||
|
jitter := time.Duration(rand.Int64N(int64(delay / 2)))
|
||||||
|
delay += jitter
|
||||||
|
}
|
||||||
|
|
||||||
|
if delay > b.max {
|
||||||
|
delay = b.max
|
||||||
|
}
|
||||||
|
|
||||||
|
return delay
|
||||||
|
}
|
||||||
|
|
||||||
|
type constantBackoff struct {
|
||||||
|
delay time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b constantBackoff) Delay(_ int) time.Duration {
|
||||||
|
return b.delay
|
||||||
|
}
|
||||||
77
retry/backoff_test.go
Normal file
77
retry/backoff_test.go
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
package retry
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestExponentialBackoff(t *testing.T) {
|
||||||
|
t.Run("doubles each attempt", func(t *testing.T) {
|
||||||
|
b := ExponentialBackoff(100*time.Millisecond, 10*time.Second, false)
|
||||||
|
|
||||||
|
want := []time.Duration{
|
||||||
|
100 * time.Millisecond, // attempt 0: base
|
||||||
|
200 * time.Millisecond, // attempt 1: base*2
|
||||||
|
400 * time.Millisecond, // attempt 2: base*4
|
||||||
|
800 * time.Millisecond, // attempt 3: base*8
|
||||||
|
1600 * time.Millisecond, // attempt 4: base*16
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, expected := range want {
|
||||||
|
got := b.Delay(i)
|
||||||
|
if got != expected {
|
||||||
|
t.Errorf("attempt %d: expected %v, got %v", i, expected, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("caps at max", func(t *testing.T) {
|
||||||
|
b := ExponentialBackoff(100*time.Millisecond, 500*time.Millisecond, false)
|
||||||
|
|
||||||
|
// attempt 0: 100ms, 1: 200ms, 2: 400ms, 3: 500ms (capped), 4: 500ms
|
||||||
|
for _, attempt := range []int{3, 4, 10} {
|
||||||
|
got := b.Delay(attempt)
|
||||||
|
if got != 500*time.Millisecond {
|
||||||
|
t.Errorf("attempt %d: expected cap at 500ms, got %v", attempt, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("with jitter adds randomness", func(t *testing.T) {
|
||||||
|
base := 100 * time.Millisecond
|
||||||
|
b := ExponentialBackoff(base, 10*time.Second, true)
|
||||||
|
|
||||||
|
// Run multiple times; with jitter, delay >= base for attempt 0.
|
||||||
|
// Also verify not all values are identical (randomness).
|
||||||
|
seen := make(map[time.Duration]bool)
|
||||||
|
for range 20 {
|
||||||
|
d := b.Delay(0)
|
||||||
|
if d < base {
|
||||||
|
t.Fatalf("delay %v is less than base %v", d, base)
|
||||||
|
}
|
||||||
|
// With jitter: delay = base + rand in [0, base/2), so max is base*1.5
|
||||||
|
maxExpected := base + base/2
|
||||||
|
if d > maxExpected {
|
||||||
|
t.Fatalf("delay %v exceeds expected max %v", d, maxExpected)
|
||||||
|
}
|
||||||
|
seen[d] = true
|
||||||
|
}
|
||||||
|
if len(seen) < 2 {
|
||||||
|
t.Errorf("expected jitter to produce varying delays, got %d unique values", len(seen))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConstantBackoff(t *testing.T) {
|
||||||
|
t.Run("always returns same value", func(t *testing.T) {
|
||||||
|
d := 250 * time.Millisecond
|
||||||
|
b := ConstantBackoff(d)
|
||||||
|
|
||||||
|
for _, attempt := range []int{0, 1, 2, 5, 100} {
|
||||||
|
got := b.Delay(attempt)
|
||||||
|
if got != d {
|
||||||
|
t.Errorf("attempt %d: expected %v, got %v", attempt, d, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
56
retry/options.go
Normal file
56
retry/options.go
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
package retry
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
|
type options struct {
|
||||||
|
maxAttempts int // default 3
|
||||||
|
backoff Backoff // default ExponentialBackoff(100ms, 5s, true)
|
||||||
|
policy Policy // default: defaultPolicy (retry on 5xx and network errors)
|
||||||
|
retryAfter bool // default true, respect Retry-After header
|
||||||
|
}
|
||||||
|
|
||||||
|
// Option configures the retry transport.
|
||||||
|
type Option func(*options)
|
||||||
|
|
||||||
|
func defaults() options {
|
||||||
|
return options{
|
||||||
|
maxAttempts: 3,
|
||||||
|
backoff: ExponentialBackoff(100*time.Millisecond, 5*time.Second, true),
|
||||||
|
policy: defaultPolicy{},
|
||||||
|
retryAfter: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithMaxAttempts sets the maximum number of attempts (including the first).
|
||||||
|
// Values less than 1 are treated as 1 (no retries).
|
||||||
|
func WithMaxAttempts(n int) Option {
|
||||||
|
return func(o *options) {
|
||||||
|
if n < 1 {
|
||||||
|
n = 1
|
||||||
|
}
|
||||||
|
o.maxAttempts = n
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithBackoff sets the backoff strategy used to compute delays between retries.
|
||||||
|
func WithBackoff(b Backoff) Option {
|
||||||
|
return func(o *options) {
|
||||||
|
o.backoff = b
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithPolicy sets the retry policy that decides whether to retry a request.
|
||||||
|
func WithPolicy(p Policy) Option {
|
||||||
|
return func(o *options) {
|
||||||
|
o.policy = p
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithRetryAfter controls whether the Retry-After response header is respected.
|
||||||
|
// When enabled and present, the Retry-After delay is used if it exceeds the
|
||||||
|
// backoff delay.
|
||||||
|
func WithRetryAfter(enable bool) Option {
|
||||||
|
return func(o *options) {
|
||||||
|
o.retryAfter = enable
|
||||||
|
}
|
||||||
|
}
|
||||||
137
retry/retry.go
Normal file
137
retry/retry.go
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
package retry
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.codelab.vc/pkg/httpx/middleware"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrRetryExhausted is returned when all retry attempts have been exhausted
|
||||||
|
// and the last attempt also failed.
|
||||||
|
var ErrRetryExhausted = errors.New("httpx: all retry attempts exhausted")
|
||||||
|
|
||||||
|
// Policy decides whether a failed request should be retried.
|
||||||
|
type Policy interface {
|
||||||
|
// ShouldRetry reports whether the request should be retried. The extra
|
||||||
|
// duration, if non-zero, is a policy-suggested delay that overrides the
|
||||||
|
// backoff strategy.
|
||||||
|
ShouldRetry(attempt int, req *http.Request, resp *http.Response, err error) (bool, time.Duration)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Transport returns a middleware that retries failed requests according to
|
||||||
|
// the provided options.
|
||||||
|
func Transport(opts ...Option) middleware.Middleware {
|
||||||
|
cfg := defaults()
|
||||||
|
for _, o := range opts {
|
||||||
|
o(&cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
return func(next http.RoundTripper) http.RoundTripper {
|
||||||
|
return middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||||
|
var resp *http.Response
|
||||||
|
var err error
|
||||||
|
var exhausted bool
|
||||||
|
|
||||||
|
for attempt := range cfg.maxAttempts {
|
||||||
|
// For retries (attempt > 0), restore the request body.
|
||||||
|
if attempt > 0 {
|
||||||
|
if req.GetBody != nil {
|
||||||
|
body, bodyErr := req.GetBody()
|
||||||
|
if bodyErr != nil {
|
||||||
|
return resp, bodyErr
|
||||||
|
}
|
||||||
|
req.Body = body
|
||||||
|
} else if req.Body != nil {
|
||||||
|
// Body was consumed and cannot be re-created.
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err = next.RoundTrip(req)
|
||||||
|
|
||||||
|
// Last attempt — return whatever we got.
|
||||||
|
if attempt == cfg.maxAttempts-1 {
|
||||||
|
exhausted = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
shouldRetry, policyDelay := cfg.policy.ShouldRetry(attempt, req, resp, err)
|
||||||
|
if !shouldRetry {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute delay: use backoff or policy delay, whichever is larger.
|
||||||
|
delay := cfg.backoff.Delay(attempt)
|
||||||
|
if policyDelay > delay {
|
||||||
|
delay = policyDelay
|
||||||
|
}
|
||||||
|
|
||||||
|
// Respect Retry-After header if enabled.
|
||||||
|
if cfg.retryAfter && resp != nil {
|
||||||
|
if ra, ok := ParseRetryAfter(resp); ok && ra > delay {
|
||||||
|
delay = ra
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Drain and close the response body to release the connection.
|
||||||
|
if resp != nil {
|
||||||
|
io.Copy(io.Discard, resp.Body)
|
||||||
|
resp.Body.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for the delay or context cancellation.
|
||||||
|
timer := time.NewTimer(delay)
|
||||||
|
select {
|
||||||
|
case <-req.Context().Done():
|
||||||
|
timer.Stop()
|
||||||
|
return nil, req.Context().Err()
|
||||||
|
case <-timer.C:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wrap with ErrRetryExhausted only when all attempts were used.
|
||||||
|
if exhausted && err != nil {
|
||||||
|
err = fmt.Errorf("%w: %w", ErrRetryExhausted, err)
|
||||||
|
}
|
||||||
|
return resp, err
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// defaultPolicy retries on network errors, 429, and 5xx server errors.
|
||||||
|
// It refuses to retry non-idempotent methods.
|
||||||
|
type defaultPolicy struct{}
|
||||||
|
|
||||||
|
func (defaultPolicy) ShouldRetry(_ int, req *http.Request, resp *http.Response, err error) (bool, time.Duration) {
|
||||||
|
if !isIdempotent(req.Method) {
|
||||||
|
return false, 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Network error — always retry idempotent requests.
|
||||||
|
if err != nil {
|
||||||
|
return true, 0
|
||||||
|
}
|
||||||
|
|
||||||
|
switch resp.StatusCode {
|
||||||
|
case http.StatusTooManyRequests, // 429
|
||||||
|
http.StatusBadGateway, // 502
|
||||||
|
http.StatusServiceUnavailable, // 503
|
||||||
|
http.StatusGatewayTimeout: // 504
|
||||||
|
return true, 0
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// isIdempotent reports whether the HTTP method is safe to retry.
|
||||||
|
func isIdempotent(method string) bool {
|
||||||
|
switch method {
|
||||||
|
case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodPut:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
43
retry/retry_after.go
Normal file
43
retry/retry_after.go
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
package retry
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ParseRetryAfter extracts the delay from a Retry-After header (RFC 7231).
|
||||||
|
// It supports both the delay-seconds format ("120") and the HTTP-date format
|
||||||
|
// ("Fri, 31 Dec 1999 23:59:59 GMT"). Returns the duration and true if the
|
||||||
|
// header was present and valid; otherwise returns 0 and false.
|
||||||
|
func ParseRetryAfter(resp *http.Response) (time.Duration, bool) {
|
||||||
|
if resp == nil {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
val := resp.Header.Get("Retry-After")
|
||||||
|
if val == "" {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try delay-seconds first (most common).
|
||||||
|
if seconds, err := strconv.ParseInt(val, 10, 64); err == nil {
|
||||||
|
if seconds < 0 {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
return time.Duration(seconds) * time.Second, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try HTTP-date format (RFC 7231 section 7.1.1.1).
|
||||||
|
t, err := http.ParseTime(val)
|
||||||
|
if err != nil {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
delay := time.Until(t)
|
||||||
|
if delay < 0 {
|
||||||
|
// The date is in the past; no need to wait.
|
||||||
|
return 0, true
|
||||||
|
}
|
||||||
|
return delay, true
|
||||||
|
}
|
||||||
58
retry/retry_after_test.go
Normal file
58
retry/retry_after_test.go
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
package retry
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestParseRetryAfter(t *testing.T) {
|
||||||
|
t.Run("seconds format", func(t *testing.T) {
|
||||||
|
resp := &http.Response{
|
||||||
|
Header: http.Header{"Retry-After": []string{"120"}},
|
||||||
|
}
|
||||||
|
d, ok := ParseRetryAfter(resp)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected ok=true")
|
||||||
|
}
|
||||||
|
if d != 120*time.Second {
|
||||||
|
t.Fatalf("expected 120s, got %v", d)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("empty header", func(t *testing.T) {
|
||||||
|
resp := &http.Response{
|
||||||
|
Header: make(http.Header),
|
||||||
|
}
|
||||||
|
d, ok := ParseRetryAfter(resp)
|
||||||
|
if ok {
|
||||||
|
t.Fatal("expected ok=false for empty header")
|
||||||
|
}
|
||||||
|
if d != 0 {
|
||||||
|
t.Fatalf("expected 0, got %v", d)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("nil response", func(t *testing.T) {
|
||||||
|
d, ok := ParseRetryAfter(nil)
|
||||||
|
if ok {
|
||||||
|
t.Fatal("expected ok=false for nil response")
|
||||||
|
}
|
||||||
|
if d != 0 {
|
||||||
|
t.Fatalf("expected 0, got %v", d)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("negative value", func(t *testing.T) {
|
||||||
|
resp := &http.Response{
|
||||||
|
Header: http.Header{"Retry-After": []string{"-5"}},
|
||||||
|
}
|
||||||
|
d, ok := ParseRetryAfter(resp)
|
||||||
|
if ok {
|
||||||
|
t.Fatal("expected ok=false for negative value")
|
||||||
|
}
|
||||||
|
if d != 0 {
|
||||||
|
t.Fatalf("expected 0, got %v", d)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
237
retry/retry_test.go
Normal file
237
retry/retry_test.go
Normal file
@@ -0,0 +1,237 @@
|
|||||||
|
package retry
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.codelab.vc/pkg/httpx/middleware"
|
||||||
|
)
|
||||||
|
|
||||||
|
func mockTransport(fn func(*http.Request) (*http.Response, error)) http.RoundTripper {
|
||||||
|
return middleware.RoundTripperFunc(fn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func okResponse() *http.Response {
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Body: io.NopCloser(strings.NewReader("")),
|
||||||
|
Header: make(http.Header),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func statusResponse(code int) *http.Response {
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: code,
|
||||||
|
Body: io.NopCloser(strings.NewReader("")),
|
||||||
|
Header: make(http.Header),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTransport(t *testing.T) {
|
||||||
|
t.Run("successful request no retry", func(t *testing.T) {
|
||||||
|
var calls atomic.Int32
|
||||||
|
rt := Transport(
|
||||||
|
WithMaxAttempts(3),
|
||||||
|
WithBackoff(ConstantBackoff(time.Millisecond)),
|
||||||
|
)(mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||||
|
calls.Add(1)
|
||||||
|
return okResponse(), nil
|
||||||
|
}))
|
||||||
|
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||||
|
resp, err := rt.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
if got := calls.Load(); got != 1 {
|
||||||
|
t.Fatalf("expected 1 call, got %d", got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("retries on 503 then succeeds", func(t *testing.T) {
|
||||||
|
var calls atomic.Int32
|
||||||
|
rt := Transport(
|
||||||
|
WithMaxAttempts(3),
|
||||||
|
WithBackoff(ConstantBackoff(time.Millisecond)),
|
||||||
|
)(mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||||
|
n := calls.Add(1)
|
||||||
|
if n < 3 {
|
||||||
|
return statusResponse(http.StatusServiceUnavailable), nil
|
||||||
|
}
|
||||||
|
return okResponse(), nil
|
||||||
|
}))
|
||||||
|
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||||
|
resp, err := rt.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
if got := calls.Load(); got != 3 {
|
||||||
|
t.Fatalf("expected 3 calls, got %d", got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("does not retry non-idempotent POST", func(t *testing.T) {
|
||||||
|
var calls atomic.Int32
|
||||||
|
rt := Transport(
|
||||||
|
WithMaxAttempts(3),
|
||||||
|
WithBackoff(ConstantBackoff(time.Millisecond)),
|
||||||
|
)(mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||||
|
calls.Add(1)
|
||||||
|
return statusResponse(http.StatusServiceUnavailable), nil
|
||||||
|
}))
|
||||||
|
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "http://example.com", strings.NewReader("data"))
|
||||||
|
resp, err := rt.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode != http.StatusServiceUnavailable {
|
||||||
|
t.Fatalf("expected 503, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
if got := calls.Load(); got != 1 {
|
||||||
|
t.Fatalf("expected 1 call (no retry for POST), got %d", got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("stops on context cancellation", func(t *testing.T) {
|
||||||
|
var calls atomic.Int32
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
rt := Transport(
|
||||||
|
WithMaxAttempts(5),
|
||||||
|
WithBackoff(ConstantBackoff(50*time.Millisecond)),
|
||||||
|
)(mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||||
|
n := calls.Add(1)
|
||||||
|
if n == 1 {
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
return statusResponse(http.StatusServiceUnavailable), nil
|
||||||
|
}))
|
||||||
|
|
||||||
|
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "http://example.com", nil)
|
||||||
|
resp, err := rt.RoundTrip(req)
|
||||||
|
if err != context.Canceled {
|
||||||
|
t.Fatalf("expected context.Canceled, got resp=%v err=%v", resp, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("respects maxAttempts", func(t *testing.T) {
|
||||||
|
var calls atomic.Int32
|
||||||
|
rt := Transport(
|
||||||
|
WithMaxAttempts(2),
|
||||||
|
WithBackoff(ConstantBackoff(time.Millisecond)),
|
||||||
|
)(mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||||
|
calls.Add(1)
|
||||||
|
return statusResponse(http.StatusBadGateway), nil
|
||||||
|
}))
|
||||||
|
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||||
|
resp, err := rt.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode != http.StatusBadGateway {
|
||||||
|
t.Fatalf("expected 502, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
if got := calls.Load(); got != 2 {
|
||||||
|
t.Fatalf("expected 2 calls (maxAttempts=2), got %d", got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("body is restored via GetBody on retry", func(t *testing.T) {
|
||||||
|
var calls atomic.Int32
|
||||||
|
var bodies []string
|
||||||
|
|
||||||
|
rt := Transport(
|
||||||
|
WithMaxAttempts(3),
|
||||||
|
WithBackoff(ConstantBackoff(time.Millisecond)),
|
||||||
|
)(mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||||
|
calls.Add(1)
|
||||||
|
b, _ := io.ReadAll(req.Body)
|
||||||
|
bodies = append(bodies, string(b))
|
||||||
|
if len(bodies) < 2 {
|
||||||
|
return statusResponse(http.StatusServiceUnavailable), nil
|
||||||
|
}
|
||||||
|
return okResponse(), nil
|
||||||
|
}))
|
||||||
|
|
||||||
|
bodyContent := "request-body"
|
||||||
|
body := bytes.NewReader([]byte(bodyContent))
|
||||||
|
req, _ := http.NewRequest(http.MethodPut, "http://example.com", body)
|
||||||
|
req.GetBody = func() (io.ReadCloser, error) {
|
||||||
|
return io.NopCloser(bytes.NewReader([]byte(bodyContent))), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := rt.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
if got := calls.Load(); got != 2 {
|
||||||
|
t.Fatalf("expected 2 calls, got %d", got)
|
||||||
|
}
|
||||||
|
for i, b := range bodies {
|
||||||
|
if b != bodyContent {
|
||||||
|
t.Fatalf("attempt %d: expected body %q, got %q", i, bodyContent, b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("custom policy", func(t *testing.T) {
|
||||||
|
var calls atomic.Int32
|
||||||
|
|
||||||
|
// Custom policy: retry only on 418
|
||||||
|
custom := policyFunc(func(attempt int, req *http.Request, resp *http.Response, err error) (bool, time.Duration) {
|
||||||
|
if resp != nil && resp.StatusCode == http.StatusTeapot {
|
||||||
|
return true, 0
|
||||||
|
}
|
||||||
|
return false, 0
|
||||||
|
})
|
||||||
|
|
||||||
|
rt := Transport(
|
||||||
|
WithMaxAttempts(3),
|
||||||
|
WithBackoff(ConstantBackoff(time.Millisecond)),
|
||||||
|
WithPolicy(custom),
|
||||||
|
)(mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||||
|
n := calls.Add(1)
|
||||||
|
if n == 1 {
|
||||||
|
return statusResponse(http.StatusTeapot), nil
|
||||||
|
}
|
||||||
|
return okResponse(), nil
|
||||||
|
}))
|
||||||
|
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "http://example.com", nil)
|
||||||
|
resp, err := rt.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
if got := calls.Load(); got != 2 {
|
||||||
|
t.Fatalf("expected 2 calls, got %d", got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// policyFunc adapts a function into a Policy.
|
||||||
|
type policyFunc func(int, *http.Request, *http.Response, error) (bool, time.Duration)
|
||||||
|
|
||||||
|
func (f policyFunc) ShouldRetry(attempt int, req *http.Request, resp *http.Response, err error) (bool, time.Duration) {
|
||||||
|
return f(attempt, req, resp, err)
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user