Fix sentinel error aliasing, hot-path allocations, and resource leaks
- Deduplicate sentinel errors: httpx.ErrNoHealthy, ErrCircuitOpen, and ErrRetryExhausted are now aliases to the canonical sub-package values so errors.Is works across package boundaries - Retry transport returns ErrRetryExhausted only when all attempts are actually exhausted, not on early policy exit - Balancer: pre-parse endpoint URLs at construction, replace req.Clone with cheap shallow struct copy to avoid per-request allocations - Circuit breaker: Load before LoadOrStore to avoid allocating a Breaker on every request for known hosts - Health checker: drain response body before close for connection reuse, probe endpoints concurrently, run initial probe synchronously in Start - Client: add Close() to shut down health checker goroutine, propagate URL resolution errors instead of silently discarding them - MockClock: fix lock ordering in Reset (clock.mu before t.mu), fix timer slice compaction to avoid backing-array aliasing, extract fireExpired to deduplicate Advance/Set
This commit is contained in:
@@ -2,6 +2,7 @@ package balancer
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
|
||||||
@@ -23,6 +24,19 @@ type Strategy interface {
|
|||||||
Next(healthy []Endpoint) (Endpoint, error)
|
Next(healthy []Endpoint) (Endpoint, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Closer can be used to shut down resources associated with a balancer
|
||||||
|
// transport (e.g. background health checker goroutines).
|
||||||
|
type Closer struct {
|
||||||
|
healthChecker *HealthChecker
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close stops background goroutines. Safe to call multiple times.
|
||||||
|
func (c *Closer) Close() {
|
||||||
|
if c.healthChecker != nil {
|
||||||
|
c.healthChecker.Stop()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Transport returns a middleware that load-balances requests across the
|
// Transport returns a middleware that load-balances requests across the
|
||||||
// provided endpoints using the configured strategy.
|
// provided endpoints using the configured strategy.
|
||||||
//
|
//
|
||||||
@@ -33,7 +47,7 @@ type Strategy interface {
|
|||||||
// If active health checking is enabled (WithHealthCheck), a background
|
// If active health checking is enabled (WithHealthCheck), a background
|
||||||
// goroutine periodically probes endpoints. Otherwise all endpoints are
|
// goroutine periodically probes endpoints. Otherwise all endpoints are
|
||||||
// assumed healthy.
|
// assumed healthy.
|
||||||
func Transport(endpoints []Endpoint, opts ...Option) middleware.Middleware {
|
func Transport(endpoints []Endpoint, opts ...Option) (middleware.Middleware, *Closer) {
|
||||||
o := &options{
|
o := &options{
|
||||||
strategy: RoundRobin(),
|
strategy: RoundRobin(),
|
||||||
}
|
}
|
||||||
@@ -41,10 +55,22 @@ func Transport(endpoints []Endpoint, opts ...Option) middleware.Middleware {
|
|||||||
opt(o)
|
opt(o)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Pre-parse endpoint URLs once at construction time.
|
||||||
|
parsed := make(map[string]*url.URL, len(endpoints))
|
||||||
|
for _, ep := range endpoints {
|
||||||
|
u, err := url.Parse(ep.URL)
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Sprintf("balancer: invalid endpoint URL %q: %v", ep.URL, err))
|
||||||
|
}
|
||||||
|
parsed[ep.URL] = u
|
||||||
|
}
|
||||||
|
|
||||||
if o.healthChecker != nil {
|
if o.healthChecker != nil {
|
||||||
o.healthChecker.Start(endpoints)
|
o.healthChecker.Start(endpoints)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
closer := &Closer{healthChecker: o.healthChecker}
|
||||||
|
|
||||||
return func(next http.RoundTripper) http.RoundTripper {
|
return func(next http.RoundTripper) http.RoundTripper {
|
||||||
return middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
return middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||||
healthy := endpoints
|
healthy := endpoints
|
||||||
@@ -61,18 +87,18 @@ func Transport(endpoints []Endpoint, opts ...Option) middleware.Middleware {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
epURL, err := url.Parse(ep.URL)
|
epURL := parsed[ep.URL]
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clone the request URL and replace scheme+host with the endpoint.
|
// Shallow-copy request and URL to avoid mutating the original,
|
||||||
r := req.Clone(req.Context())
|
// without the expense of req.Clone's deep header copy.
|
||||||
|
r := *req
|
||||||
|
u := *req.URL
|
||||||
|
r.URL = &u
|
||||||
r.URL.Scheme = epURL.Scheme
|
r.URL.Scheme = epURL.Scheme
|
||||||
r.URL.Host = epURL.Host
|
r.URL.Host = epURL.Host
|
||||||
r.Host = epURL.Host
|
r.Host = epURL.Host
|
||||||
|
|
||||||
return next.RoundTrip(r)
|
return next.RoundTrip(&r)
|
||||||
})
|
})
|
||||||
}
|
}, closer
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -33,7 +33,8 @@ func TestTransport_PicksEndpointAndReplacesURL(t *testing.T) {
|
|||||||
return okResponse(), nil
|
return okResponse(), nil
|
||||||
})
|
})
|
||||||
|
|
||||||
rt := Transport(endpoints)(base)
|
mw, _ := Transport(endpoints)
|
||||||
|
rt := mw(base)
|
||||||
|
|
||||||
req, err := http.NewRequest(http.MethodGet, "https://original.example.com/api/v1/users", nil)
|
req, err := http.NewRequest(http.MethodGet, "https://original.example.com/api/v1/users", nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -67,7 +68,8 @@ func TestTransport_ErrNoHealthyWhenNoEndpoints(t *testing.T) {
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
})
|
})
|
||||||
|
|
||||||
rt := Transport(endpoints)(base)
|
mw, _ := Transport(endpoints)
|
||||||
|
rt := mw(base)
|
||||||
|
|
||||||
req, err := http.NewRequest(http.MethodGet, "https://example.com/test", nil)
|
req, err := http.NewRequest(http.MethodGet, "https://example.com/test", nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package balancer
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@@ -70,8 +71,10 @@ func newHealthChecker(opts ...HealthOption) *HealthChecker {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Start begins the background health checking loop for the given endpoints.
|
// Start begins the background health checking loop for the given endpoints.
|
||||||
// All endpoints are initially considered healthy.
|
// An initial probe is run synchronously so that unhealthy endpoints are
|
||||||
|
// detected before the first request.
|
||||||
func (h *HealthChecker) Start(endpoints []Endpoint) {
|
func (h *HealthChecker) Start(endpoints []Endpoint) {
|
||||||
|
// Mark all healthy as a safe default, then immediately probe.
|
||||||
h.mu.Lock()
|
h.mu.Lock()
|
||||||
for _, ep := range endpoints {
|
for _, ep := range endpoints {
|
||||||
h.status[ep.URL] = true
|
h.status[ep.URL] = true
|
||||||
@@ -82,6 +85,9 @@ func (h *HealthChecker) Start(endpoints []Endpoint) {
|
|||||||
h.cancel = cancel
|
h.cancel = cancel
|
||||||
h.stopped = make(chan struct{})
|
h.stopped = make(chan struct{})
|
||||||
|
|
||||||
|
// Run initial probe synchronously so callers don't hit stale state.
|
||||||
|
h.probe(ctx, endpoints)
|
||||||
|
|
||||||
go h.loop(ctx, endpoints)
|
go h.loop(ctx, endpoints)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -111,7 +117,7 @@ func (h *HealthChecker) Healthy(endpoints []Endpoint) []Endpoint {
|
|||||||
h.mu.RLock()
|
h.mu.RLock()
|
||||||
defer h.mu.RUnlock()
|
defer h.mu.RUnlock()
|
||||||
|
|
||||||
var result []Endpoint
|
result := make([]Endpoint, 0, len(endpoints))
|
||||||
for _, ep := range endpoints {
|
for _, ep := range endpoints {
|
||||||
if h.status[ep.URL] {
|
if h.status[ep.URL] {
|
||||||
result = append(result, ep)
|
result = append(result, ep)
|
||||||
@@ -137,13 +143,18 @@ func (h *HealthChecker) loop(ctx context.Context, endpoints []Endpoint) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *HealthChecker) probe(ctx context.Context, endpoints []Endpoint) {
|
func (h *HealthChecker) probe(ctx context.Context, endpoints []Endpoint) {
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(len(endpoints))
|
||||||
for _, ep := range endpoints {
|
for _, ep := range endpoints {
|
||||||
healthy := h.check(ctx, ep)
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
h.mu.Lock()
|
healthy := h.check(ctx, ep)
|
||||||
h.status[ep.URL] = healthy
|
h.mu.Lock()
|
||||||
h.mu.Unlock()
|
h.status[ep.URL] = healthy
|
||||||
|
h.mu.Unlock()
|
||||||
|
}()
|
||||||
}
|
}
|
||||||
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *HealthChecker) check(ctx context.Context, ep Endpoint) bool {
|
func (h *HealthChecker) check(ctx context.Context, ep Endpoint) bool {
|
||||||
@@ -156,6 +167,7 @@ func (h *HealthChecker) check(ctx context.Context, ep Endpoint) bool {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
io.Copy(io.Discard, resp.Body)
|
||||||
resp.Body.Close()
|
resp.Body.Close()
|
||||||
|
|
||||||
return resp.StatusCode >= 200 && resp.StatusCode < 300
|
return resp.StatusCode >= 200 && resp.StatusCode < 300
|
||||||
|
|||||||
@@ -156,7 +156,10 @@ func Transport(opts ...Option) middleware.Middleware {
|
|||||||
return middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
return middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||||
host := req.URL.Host
|
host := req.URL.Host
|
||||||
|
|
||||||
val, _ := hosts.LoadOrStore(host, NewBreaker(opts...))
|
val, ok := hosts.Load(host)
|
||||||
|
if !ok {
|
||||||
|
val, _ = hosts.LoadOrStore(host, NewBreaker(opts...))
|
||||||
|
}
|
||||||
cb := val.(*Breaker)
|
cb := val.(*Breaker)
|
||||||
|
|
||||||
done, err := cb.Allow()
|
done, err := cb.Allow()
|
||||||
|
|||||||
42
client.go
42
client.go
@@ -2,6 +2,7 @@ package httpx
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -15,9 +16,10 @@ import (
|
|||||||
// Client is a high-level HTTP client that composes middleware for retry,
|
// Client is a high-level HTTP client that composes middleware for retry,
|
||||||
// circuit breaking, load balancing, logging, and more.
|
// circuit breaking, load balancing, logging, and more.
|
||||||
type Client struct {
|
type Client struct {
|
||||||
httpClient *http.Client
|
httpClient *http.Client
|
||||||
baseURL string
|
baseURL string
|
||||||
errorMapper ErrorMapper
|
errorMapper ErrorMapper
|
||||||
|
balancerCloser *balancer.Closer
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new Client with the given options.
|
// New creates a new Client with the given options.
|
||||||
@@ -37,8 +39,11 @@ func New(opts ...Option) *Client {
|
|||||||
var chain []middleware.Middleware
|
var chain []middleware.Middleware
|
||||||
|
|
||||||
// Balancer (innermost, wraps base transport).
|
// Balancer (innermost, wraps base transport).
|
||||||
|
var balancerCloser *balancer.Closer
|
||||||
if len(o.endpoints) > 0 {
|
if len(o.endpoints) > 0 {
|
||||||
chain = append(chain, balancer.Transport(o.endpoints, o.balancerOpts...))
|
var mw middleware.Middleware
|
||||||
|
mw, balancerCloser = balancer.Transport(o.endpoints, o.balancerOpts...)
|
||||||
|
chain = append(chain, mw)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Circuit breaker wraps balancer.
|
// Circuit breaker wraps balancer.
|
||||||
@@ -72,15 +77,18 @@ func New(opts ...Option) *Client {
|
|||||||
Transport: rt,
|
Transport: rt,
|
||||||
Timeout: o.timeout,
|
Timeout: o.timeout,
|
||||||
},
|
},
|
||||||
baseURL: o.baseURL,
|
baseURL: o.baseURL,
|
||||||
errorMapper: o.errorMapper,
|
errorMapper: o.errorMapper,
|
||||||
|
balancerCloser: balancerCloser,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Do executes an HTTP request.
|
// Do executes an HTTP request.
|
||||||
func (c *Client) Do(ctx context.Context, req *http.Request) (*Response, error) {
|
func (c *Client) Do(ctx context.Context, req *http.Request) (*Response, error) {
|
||||||
req = req.WithContext(ctx)
|
req = req.WithContext(ctx)
|
||||||
c.resolveURL(req)
|
if err := c.resolveURL(req); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
resp, err := c.httpClient.Do(req)
|
resp, err := c.httpClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -143,14 +151,23 @@ func (c *Client) Delete(ctx context.Context, url string) (*Response, error) {
|
|||||||
return c.Do(ctx, req)
|
return c.Do(ctx, req)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Close releases resources associated with the Client, such as background
|
||||||
|
// health checker goroutines. It is safe to call multiple times.
|
||||||
|
func (c *Client) Close() {
|
||||||
|
if c.balancerCloser != nil {
|
||||||
|
c.balancerCloser.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// HTTPClient returns the underlying *http.Client for advanced use cases.
|
// HTTPClient returns the underlying *http.Client for advanced use cases.
|
||||||
|
// Mutating the returned client may bypass the configured middleware chain.
|
||||||
func (c *Client) HTTPClient() *http.Client {
|
func (c *Client) HTTPClient() *http.Client {
|
||||||
return c.httpClient
|
return c.httpClient
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) resolveURL(req *http.Request) {
|
func (c *Client) resolveURL(req *http.Request) error {
|
||||||
if c.baseURL == "" {
|
if c.baseURL == "" {
|
||||||
return
|
return nil
|
||||||
}
|
}
|
||||||
// Only resolve relative URLs (no scheme).
|
// Only resolve relative URLs (no scheme).
|
||||||
if req.URL.Scheme == "" && req.URL.Host == "" {
|
if req.URL.Scheme == "" && req.URL.Host == "" {
|
||||||
@@ -159,6 +176,11 @@ func (c *Client) resolveURL(req *http.Request) {
|
|||||||
path = "/" + path
|
path = "/" + path
|
||||||
}
|
}
|
||||||
base := strings.TrimRight(c.baseURL, "/")
|
base := strings.TrimRight(c.baseURL, "/")
|
||||||
req.URL, _ = req.URL.Parse(base + path)
|
u, err := req.URL.Parse(base + path)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("httpx: resolving URL %q with base %q: %w", path, c.baseURL, err)
|
||||||
|
}
|
||||||
|
req.URL = u
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
13
error.go
13
error.go
@@ -1,16 +1,21 @@
|
|||||||
package httpx
|
package httpx
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
"git.codelab.vc/pkg/httpx/balancer"
|
||||||
|
"git.codelab.vc/pkg/httpx/circuitbreaker"
|
||||||
|
"git.codelab.vc/pkg/httpx/retry"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Sentinel errors returned by httpx components.
|
// Sentinel errors returned by httpx components.
|
||||||
|
// These are aliases for the canonical errors defined in sub-packages,
|
||||||
|
// so that errors.Is works regardless of which import the caller uses.
|
||||||
var (
|
var (
|
||||||
ErrRetryExhausted = errors.New("httpx: all retry attempts exhausted")
|
ErrRetryExhausted = retry.ErrRetryExhausted
|
||||||
ErrCircuitOpen = errors.New("httpx: circuit breaker is open")
|
ErrCircuitOpen = circuitbreaker.ErrCircuitOpen
|
||||||
ErrNoHealthy = errors.New("httpx: no healthy endpoints available")
|
ErrNoHealthy = balancer.ErrNoHealthy
|
||||||
)
|
)
|
||||||
|
|
||||||
// Error provides structured error information for failed HTTP operations.
|
// Error provides structured error information for failed HTTP operations.
|
||||||
|
|||||||
@@ -64,6 +64,7 @@ func (m *MockClock) NewTimer(d time.Duration) Timer {
|
|||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
t := &mockTimer{
|
t := &mockTimer{
|
||||||
|
clock: m,
|
||||||
ch: make(chan time.Time, 1),
|
ch: make(chan time.Time, 1),
|
||||||
deadline: m.now.Add(d),
|
deadline: m.now.Add(d),
|
||||||
active: true,
|
active: true,
|
||||||
@@ -84,6 +85,25 @@ func (m *MockClock) Advance(d time.Duration) {
|
|||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
m.now = m.now.Add(d)
|
m.now = m.now.Add(d)
|
||||||
now := m.now
|
now := m.now
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
m.fireExpired(now)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set sets the clock to an absolute time and fires any expired timers.
|
||||||
|
func (m *MockClock) Set(t time.Time) {
|
||||||
|
m.mu.Lock()
|
||||||
|
m.now = t
|
||||||
|
now := m.now
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
m.fireExpired(now)
|
||||||
|
}
|
||||||
|
|
||||||
|
// fireExpired fires all active timers whose deadline has passed, then
|
||||||
|
// removes inactive timers to prevent unbounded growth.
|
||||||
|
func (m *MockClock) fireExpired(now time.Time) {
|
||||||
|
m.mu.Lock()
|
||||||
timers := m.timers
|
timers := m.timers
|
||||||
m.mu.Unlock()
|
m.mu.Unlock()
|
||||||
|
|
||||||
@@ -94,27 +114,27 @@ func (m *MockClock) Advance(d time.Duration) {
|
|||||||
}
|
}
|
||||||
t.mu.Unlock()
|
t.mu.Unlock()
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// Set sets the clock to an absolute time and fires any expired timers.
|
// Compact: remove inactive timers. Use a new slice to avoid aliasing
|
||||||
func (m *MockClock) Set(t time.Time) {
|
// the backing array (NewTimer may have appended between snapshots).
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
m.now = t
|
n := len(m.timers)
|
||||||
now := m.now
|
active := make([]*mockTimer, 0, n)
|
||||||
timers := m.timers
|
for _, t := range m.timers {
|
||||||
m.mu.Unlock()
|
t.mu.Lock()
|
||||||
|
keep := t.active
|
||||||
for _, tmr := range timers {
|
t.mu.Unlock()
|
||||||
tmr.mu.Lock()
|
if keep {
|
||||||
if tmr.active && !now.Before(tmr.deadline) {
|
active = append(active, t)
|
||||||
tmr.fire(now)
|
|
||||||
}
|
}
|
||||||
tmr.mu.Unlock()
|
|
||||||
}
|
}
|
||||||
|
m.timers = active
|
||||||
|
m.mu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
type mockTimer struct {
|
type mockTimer struct {
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
|
clock *MockClock
|
||||||
ch chan time.Time
|
ch chan time.Time
|
||||||
deadline time.Time
|
deadline time.Time
|
||||||
active bool
|
active bool
|
||||||
@@ -131,12 +151,17 @@ func (t *mockTimer) Stop() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t *mockTimer) Reset(d time.Duration) bool {
|
func (t *mockTimer) Reset(d time.Duration) bool {
|
||||||
|
// Acquire clock lock first to match the lock ordering in fireExpired
|
||||||
|
// (clock.mu → t.mu), preventing deadlock.
|
||||||
|
t.clock.mu.Lock()
|
||||||
|
deadline := t.clock.now.Add(d)
|
||||||
|
t.clock.mu.Unlock()
|
||||||
|
|
||||||
t.mu.Lock()
|
t.mu.Lock()
|
||||||
defer t.mu.Unlock()
|
defer t.mu.Unlock()
|
||||||
was := t.active
|
was := t.active
|
||||||
t.active = true
|
t.active = true
|
||||||
// Note: deadline will be recalculated on next Advance
|
t.deadline = deadline
|
||||||
t.deadline = time.Now().Add(d) // placeholder; mock users should use Advance
|
|
||||||
return was
|
return was
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
package retry
|
package retry
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
@@ -8,6 +10,10 @@ import (
|
|||||||
"git.codelab.vc/pkg/httpx/middleware"
|
"git.codelab.vc/pkg/httpx/middleware"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ErrRetryExhausted is returned when all retry attempts have been exhausted
|
||||||
|
// and the last attempt also failed.
|
||||||
|
var ErrRetryExhausted = errors.New("httpx: all retry attempts exhausted")
|
||||||
|
|
||||||
// Policy decides whether a failed request should be retried.
|
// Policy decides whether a failed request should be retried.
|
||||||
type Policy interface {
|
type Policy interface {
|
||||||
// ShouldRetry reports whether the request should be retried. The extra
|
// ShouldRetry reports whether the request should be retried. The extra
|
||||||
@@ -28,6 +34,7 @@ func Transport(opts ...Option) middleware.Middleware {
|
|||||||
return middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
return middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||||
var resp *http.Response
|
var resp *http.Response
|
||||||
var err error
|
var err error
|
||||||
|
var exhausted bool
|
||||||
|
|
||||||
for attempt := range cfg.maxAttempts {
|
for attempt := range cfg.maxAttempts {
|
||||||
// For retries (attempt > 0), restore the request body.
|
// For retries (attempt > 0), restore the request body.
|
||||||
@@ -48,6 +55,7 @@ func Transport(opts ...Option) middleware.Middleware {
|
|||||||
|
|
||||||
// Last attempt — return whatever we got.
|
// Last attempt — return whatever we got.
|
||||||
if attempt == cfg.maxAttempts-1 {
|
if attempt == cfg.maxAttempts-1 {
|
||||||
|
exhausted = true
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -85,6 +93,10 @@ func Transport(opts ...Option) middleware.Middleware {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Wrap with ErrRetryExhausted only when all attempts were used.
|
||||||
|
if exhausted && err != nil {
|
||||||
|
err = fmt.Errorf("%w: %w", ErrRetryExhausted, err)
|
||||||
|
}
|
||||||
return resp, err
|
return resp, err
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user