Add retry transport with configurable backoff and Retry-After support
Implements retry middleware as a RoundTripper wrapper: - Exponential and constant backoff strategies with jitter - RFC 7231 Retry-After header parsing (seconds and HTTP-date) - Default policy retries idempotent methods on 429/5xx and network errors - Body restoration via GetBody, context cancellation, response body cleanup
This commit is contained in:
64
retry/backoff.go
Normal file
64
retry/backoff.go
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
package retry
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math/rand/v2"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Backoff computes the delay before the next retry attempt.
|
||||||
|
type Backoff interface {
|
||||||
|
// Delay returns the wait duration for the given attempt number (zero-based).
|
||||||
|
Delay(attempt int) time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExponentialBackoff returns a Backoff that doubles the delay on each attempt.
|
||||||
|
// The delay is calculated as base * 2^attempt, capped at max. When withJitter
|
||||||
|
// is true, a random duration in [0, delay*0.5) is added.
|
||||||
|
func ExponentialBackoff(base, max time.Duration, withJitter bool) Backoff {
|
||||||
|
return &exponentialBackoff{
|
||||||
|
base: base,
|
||||||
|
max: max,
|
||||||
|
withJitter: withJitter,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConstantBackoff returns a Backoff that always returns the same delay.
|
||||||
|
func ConstantBackoff(d time.Duration) Backoff {
|
||||||
|
return constantBackoff{delay: d}
|
||||||
|
}
|
||||||
|
|
||||||
|
type exponentialBackoff struct {
|
||||||
|
base time.Duration
|
||||||
|
max time.Duration
|
||||||
|
withJitter bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *exponentialBackoff) Delay(attempt int) time.Duration {
|
||||||
|
delay := b.base
|
||||||
|
for range attempt {
|
||||||
|
delay *= 2
|
||||||
|
if delay >= b.max {
|
||||||
|
delay = b.max
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if b.withJitter {
|
||||||
|
jitter := time.Duration(rand.Int64N(int64(delay / 2)))
|
||||||
|
delay += jitter
|
||||||
|
}
|
||||||
|
|
||||||
|
if delay > b.max {
|
||||||
|
delay = b.max
|
||||||
|
}
|
||||||
|
|
||||||
|
return delay
|
||||||
|
}
|
||||||
|
|
||||||
|
type constantBackoff struct {
|
||||||
|
delay time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b constantBackoff) Delay(_ int) time.Duration {
|
||||||
|
return b.delay
|
||||||
|
}
|
||||||
77
retry/backoff_test.go
Normal file
77
retry/backoff_test.go
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
package retry
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestExponentialBackoff(t *testing.T) {
|
||||||
|
t.Run("doubles each attempt", func(t *testing.T) {
|
||||||
|
b := ExponentialBackoff(100*time.Millisecond, 10*time.Second, false)
|
||||||
|
|
||||||
|
want := []time.Duration{
|
||||||
|
100 * time.Millisecond, // attempt 0: base
|
||||||
|
200 * time.Millisecond, // attempt 1: base*2
|
||||||
|
400 * time.Millisecond, // attempt 2: base*4
|
||||||
|
800 * time.Millisecond, // attempt 3: base*8
|
||||||
|
1600 * time.Millisecond, // attempt 4: base*16
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, expected := range want {
|
||||||
|
got := b.Delay(i)
|
||||||
|
if got != expected {
|
||||||
|
t.Errorf("attempt %d: expected %v, got %v", i, expected, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("caps at max", func(t *testing.T) {
|
||||||
|
b := ExponentialBackoff(100*time.Millisecond, 500*time.Millisecond, false)
|
||||||
|
|
||||||
|
// attempt 0: 100ms, 1: 200ms, 2: 400ms, 3: 500ms (capped), 4: 500ms
|
||||||
|
for _, attempt := range []int{3, 4, 10} {
|
||||||
|
got := b.Delay(attempt)
|
||||||
|
if got != 500*time.Millisecond {
|
||||||
|
t.Errorf("attempt %d: expected cap at 500ms, got %v", attempt, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("with jitter adds randomness", func(t *testing.T) {
|
||||||
|
base := 100 * time.Millisecond
|
||||||
|
b := ExponentialBackoff(base, 10*time.Second, true)
|
||||||
|
|
||||||
|
// Run multiple times; with jitter, delay >= base for attempt 0.
|
||||||
|
// Also verify not all values are identical (randomness).
|
||||||
|
seen := make(map[time.Duration]bool)
|
||||||
|
for range 20 {
|
||||||
|
d := b.Delay(0)
|
||||||
|
if d < base {
|
||||||
|
t.Fatalf("delay %v is less than base %v", d, base)
|
||||||
|
}
|
||||||
|
// With jitter: delay = base + rand in [0, base/2), so max is base*1.5
|
||||||
|
maxExpected := base + base/2
|
||||||
|
if d > maxExpected {
|
||||||
|
t.Fatalf("delay %v exceeds expected max %v", d, maxExpected)
|
||||||
|
}
|
||||||
|
seen[d] = true
|
||||||
|
}
|
||||||
|
if len(seen) < 2 {
|
||||||
|
t.Errorf("expected jitter to produce varying delays, got %d unique values", len(seen))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConstantBackoff(t *testing.T) {
|
||||||
|
t.Run("always returns same value", func(t *testing.T) {
|
||||||
|
d := 250 * time.Millisecond
|
||||||
|
b := ConstantBackoff(d)
|
||||||
|
|
||||||
|
for _, attempt := range []int{0, 1, 2, 5, 100} {
|
||||||
|
got := b.Delay(attempt)
|
||||||
|
if got != d {
|
||||||
|
t.Errorf("attempt %d: expected %v, got %v", attempt, d, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
56
retry/options.go
Normal file
56
retry/options.go
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
package retry
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
|
type options struct {
|
||||||
|
maxAttempts int // default 3
|
||||||
|
backoff Backoff // default ExponentialBackoff(100ms, 5s, true)
|
||||||
|
policy Policy // default: defaultPolicy (retry on 5xx and network errors)
|
||||||
|
retryAfter bool // default true, respect Retry-After header
|
||||||
|
}
|
||||||
|
|
||||||
|
// Option configures the retry transport.
|
||||||
|
type Option func(*options)
|
||||||
|
|
||||||
|
func defaults() options {
|
||||||
|
return options{
|
||||||
|
maxAttempts: 3,
|
||||||
|
backoff: ExponentialBackoff(100*time.Millisecond, 5*time.Second, true),
|
||||||
|
policy: defaultPolicy{},
|
||||||
|
retryAfter: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithMaxAttempts sets the maximum number of attempts (including the first).
|
||||||
|
// Values less than 1 are treated as 1 (no retries).
|
||||||
|
func WithMaxAttempts(n int) Option {
|
||||||
|
return func(o *options) {
|
||||||
|
if n < 1 {
|
||||||
|
n = 1
|
||||||
|
}
|
||||||
|
o.maxAttempts = n
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithBackoff sets the backoff strategy used to compute delays between retries.
|
||||||
|
func WithBackoff(b Backoff) Option {
|
||||||
|
return func(o *options) {
|
||||||
|
o.backoff = b
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithPolicy sets the retry policy that decides whether to retry a request.
|
||||||
|
func WithPolicy(p Policy) Option {
|
||||||
|
return func(o *options) {
|
||||||
|
o.policy = p
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithRetryAfter controls whether the Retry-After response header is respected.
|
||||||
|
// When enabled and present, the Retry-After delay is used if it exceeds the
|
||||||
|
// backoff delay.
|
||||||
|
func WithRetryAfter(enable bool) Option {
|
||||||
|
return func(o *options) {
|
||||||
|
o.retryAfter = enable
|
||||||
|
}
|
||||||
|
}
|
||||||
125
retry/retry.go
Normal file
125
retry/retry.go
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
package retry
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.codelab.vc/pkg/httpx/middleware"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Policy decides whether a failed request should be retried.
|
||||||
|
type Policy interface {
|
||||||
|
// ShouldRetry reports whether the request should be retried. The extra
|
||||||
|
// duration, if non-zero, is a policy-suggested delay that overrides the
|
||||||
|
// backoff strategy.
|
||||||
|
ShouldRetry(attempt int, req *http.Request, resp *http.Response, err error) (bool, time.Duration)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Transport returns a middleware that retries failed requests according to
|
||||||
|
// the provided options.
|
||||||
|
func Transport(opts ...Option) middleware.Middleware {
|
||||||
|
cfg := defaults()
|
||||||
|
for _, o := range opts {
|
||||||
|
o(&cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
return func(next http.RoundTripper) http.RoundTripper {
|
||||||
|
return middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||||
|
var resp *http.Response
|
||||||
|
var err error
|
||||||
|
|
||||||
|
for attempt := range cfg.maxAttempts {
|
||||||
|
// For retries (attempt > 0), restore the request body.
|
||||||
|
if attempt > 0 {
|
||||||
|
if req.GetBody != nil {
|
||||||
|
body, bodyErr := req.GetBody()
|
||||||
|
if bodyErr != nil {
|
||||||
|
return resp, bodyErr
|
||||||
|
}
|
||||||
|
req.Body = body
|
||||||
|
} else if req.Body != nil {
|
||||||
|
// Body was consumed and cannot be re-created.
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err = next.RoundTrip(req)
|
||||||
|
|
||||||
|
// Last attempt — return whatever we got.
|
||||||
|
if attempt == cfg.maxAttempts-1 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
shouldRetry, policyDelay := cfg.policy.ShouldRetry(attempt, req, resp, err)
|
||||||
|
if !shouldRetry {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute delay: use backoff or policy delay, whichever is larger.
|
||||||
|
delay := cfg.backoff.Delay(attempt)
|
||||||
|
if policyDelay > delay {
|
||||||
|
delay = policyDelay
|
||||||
|
}
|
||||||
|
|
||||||
|
// Respect Retry-After header if enabled.
|
||||||
|
if cfg.retryAfter && resp != nil {
|
||||||
|
if ra, ok := ParseRetryAfter(resp); ok && ra > delay {
|
||||||
|
delay = ra
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Drain and close the response body to release the connection.
|
||||||
|
if resp != nil {
|
||||||
|
io.Copy(io.Discard, resp.Body)
|
||||||
|
resp.Body.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for the delay or context cancellation.
|
||||||
|
timer := time.NewTimer(delay)
|
||||||
|
select {
|
||||||
|
case <-req.Context().Done():
|
||||||
|
timer.Stop()
|
||||||
|
return nil, req.Context().Err()
|
||||||
|
case <-timer.C:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp, err
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// defaultPolicy retries on network errors, 429, and 5xx server errors.
|
||||||
|
// It refuses to retry non-idempotent methods.
|
||||||
|
type defaultPolicy struct{}
|
||||||
|
|
||||||
|
func (defaultPolicy) ShouldRetry(_ int, req *http.Request, resp *http.Response, err error) (bool, time.Duration) {
|
||||||
|
if !isIdempotent(req.Method) {
|
||||||
|
return false, 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Network error — always retry idempotent requests.
|
||||||
|
if err != nil {
|
||||||
|
return true, 0
|
||||||
|
}
|
||||||
|
|
||||||
|
switch resp.StatusCode {
|
||||||
|
case http.StatusTooManyRequests, // 429
|
||||||
|
http.StatusBadGateway, // 502
|
||||||
|
http.StatusServiceUnavailable, // 503
|
||||||
|
http.StatusGatewayTimeout: // 504
|
||||||
|
return true, 0
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// isIdempotent reports whether the HTTP method is safe to retry.
|
||||||
|
func isIdempotent(method string) bool {
|
||||||
|
switch method {
|
||||||
|
case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodPut:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
43
retry/retry_after.go
Normal file
43
retry/retry_after.go
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
package retry
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ParseRetryAfter extracts the delay from a Retry-After header (RFC 7231).
|
||||||
|
// It supports both the delay-seconds format ("120") and the HTTP-date format
|
||||||
|
// ("Fri, 31 Dec 1999 23:59:59 GMT"). Returns the duration and true if the
|
||||||
|
// header was present and valid; otherwise returns 0 and false.
|
||||||
|
func ParseRetryAfter(resp *http.Response) (time.Duration, bool) {
|
||||||
|
if resp == nil {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
val := resp.Header.Get("Retry-After")
|
||||||
|
if val == "" {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try delay-seconds first (most common).
|
||||||
|
if seconds, err := strconv.ParseInt(val, 10, 64); err == nil {
|
||||||
|
if seconds < 0 {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
return time.Duration(seconds) * time.Second, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try HTTP-date format (RFC 7231 section 7.1.1.1).
|
||||||
|
t, err := http.ParseTime(val)
|
||||||
|
if err != nil {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
delay := time.Until(t)
|
||||||
|
if delay < 0 {
|
||||||
|
// The date is in the past; no need to wait.
|
||||||
|
return 0, true
|
||||||
|
}
|
||||||
|
return delay, true
|
||||||
|
}
|
||||||
58
retry/retry_after_test.go
Normal file
58
retry/retry_after_test.go
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
package retry
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestParseRetryAfter(t *testing.T) {
|
||||||
|
t.Run("seconds format", func(t *testing.T) {
|
||||||
|
resp := &http.Response{
|
||||||
|
Header: http.Header{"Retry-After": []string{"120"}},
|
||||||
|
}
|
||||||
|
d, ok := ParseRetryAfter(resp)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected ok=true")
|
||||||
|
}
|
||||||
|
if d != 120*time.Second {
|
||||||
|
t.Fatalf("expected 120s, got %v", d)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("empty header", func(t *testing.T) {
|
||||||
|
resp := &http.Response{
|
||||||
|
Header: make(http.Header),
|
||||||
|
}
|
||||||
|
d, ok := ParseRetryAfter(resp)
|
||||||
|
if ok {
|
||||||
|
t.Fatal("expected ok=false for empty header")
|
||||||
|
}
|
||||||
|
if d != 0 {
|
||||||
|
t.Fatalf("expected 0, got %v", d)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("nil response", func(t *testing.T) {
|
||||||
|
d, ok := ParseRetryAfter(nil)
|
||||||
|
if ok {
|
||||||
|
t.Fatal("expected ok=false for nil response")
|
||||||
|
}
|
||||||
|
if d != 0 {
|
||||||
|
t.Fatalf("expected 0, got %v", d)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("negative value", func(t *testing.T) {
|
||||||
|
resp := &http.Response{
|
||||||
|
Header: http.Header{"Retry-After": []string{"-5"}},
|
||||||
|
}
|
||||||
|
d, ok := ParseRetryAfter(resp)
|
||||||
|
if ok {
|
||||||
|
t.Fatal("expected ok=false for negative value")
|
||||||
|
}
|
||||||
|
if d != 0 {
|
||||||
|
t.Fatalf("expected 0, got %v", d)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
237
retry/retry_test.go
Normal file
237
retry/retry_test.go
Normal file
@@ -0,0 +1,237 @@
|
|||||||
|
package retry
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.codelab.vc/pkg/httpx/middleware"
|
||||||
|
)
|
||||||
|
|
||||||
|
func mockTransport(fn func(*http.Request) (*http.Response, error)) http.RoundTripper {
|
||||||
|
return middleware.RoundTripperFunc(fn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func okResponse() *http.Response {
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Body: io.NopCloser(strings.NewReader("")),
|
||||||
|
Header: make(http.Header),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func statusResponse(code int) *http.Response {
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: code,
|
||||||
|
Body: io.NopCloser(strings.NewReader("")),
|
||||||
|
Header: make(http.Header),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTransport(t *testing.T) {
|
||||||
|
t.Run("successful request no retry", func(t *testing.T) {
|
||||||
|
var calls atomic.Int32
|
||||||
|
rt := Transport(
|
||||||
|
WithMaxAttempts(3),
|
||||||
|
WithBackoff(ConstantBackoff(time.Millisecond)),
|
||||||
|
)(mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||||
|
calls.Add(1)
|
||||||
|
return okResponse(), nil
|
||||||
|
}))
|
||||||
|
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||||
|
resp, err := rt.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
if got := calls.Load(); got != 1 {
|
||||||
|
t.Fatalf("expected 1 call, got %d", got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("retries on 503 then succeeds", func(t *testing.T) {
|
||||||
|
var calls atomic.Int32
|
||||||
|
rt := Transport(
|
||||||
|
WithMaxAttempts(3),
|
||||||
|
WithBackoff(ConstantBackoff(time.Millisecond)),
|
||||||
|
)(mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||||
|
n := calls.Add(1)
|
||||||
|
if n < 3 {
|
||||||
|
return statusResponse(http.StatusServiceUnavailable), nil
|
||||||
|
}
|
||||||
|
return okResponse(), nil
|
||||||
|
}))
|
||||||
|
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||||
|
resp, err := rt.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
if got := calls.Load(); got != 3 {
|
||||||
|
t.Fatalf("expected 3 calls, got %d", got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("does not retry non-idempotent POST", func(t *testing.T) {
|
||||||
|
var calls atomic.Int32
|
||||||
|
rt := Transport(
|
||||||
|
WithMaxAttempts(3),
|
||||||
|
WithBackoff(ConstantBackoff(time.Millisecond)),
|
||||||
|
)(mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||||
|
calls.Add(1)
|
||||||
|
return statusResponse(http.StatusServiceUnavailable), nil
|
||||||
|
}))
|
||||||
|
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "http://example.com", strings.NewReader("data"))
|
||||||
|
resp, err := rt.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode != http.StatusServiceUnavailable {
|
||||||
|
t.Fatalf("expected 503, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
if got := calls.Load(); got != 1 {
|
||||||
|
t.Fatalf("expected 1 call (no retry for POST), got %d", got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("stops on context cancellation", func(t *testing.T) {
|
||||||
|
var calls atomic.Int32
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
rt := Transport(
|
||||||
|
WithMaxAttempts(5),
|
||||||
|
WithBackoff(ConstantBackoff(50*time.Millisecond)),
|
||||||
|
)(mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||||
|
n := calls.Add(1)
|
||||||
|
if n == 1 {
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
return statusResponse(http.StatusServiceUnavailable), nil
|
||||||
|
}))
|
||||||
|
|
||||||
|
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "http://example.com", nil)
|
||||||
|
resp, err := rt.RoundTrip(req)
|
||||||
|
if err != context.Canceled {
|
||||||
|
t.Fatalf("expected context.Canceled, got resp=%v err=%v", resp, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("respects maxAttempts", func(t *testing.T) {
|
||||||
|
var calls atomic.Int32
|
||||||
|
rt := Transport(
|
||||||
|
WithMaxAttempts(2),
|
||||||
|
WithBackoff(ConstantBackoff(time.Millisecond)),
|
||||||
|
)(mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||||
|
calls.Add(1)
|
||||||
|
return statusResponse(http.StatusBadGateway), nil
|
||||||
|
}))
|
||||||
|
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||||
|
resp, err := rt.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode != http.StatusBadGateway {
|
||||||
|
t.Fatalf("expected 502, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
if got := calls.Load(); got != 2 {
|
||||||
|
t.Fatalf("expected 2 calls (maxAttempts=2), got %d", got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("body is restored via GetBody on retry", func(t *testing.T) {
|
||||||
|
var calls atomic.Int32
|
||||||
|
var bodies []string
|
||||||
|
|
||||||
|
rt := Transport(
|
||||||
|
WithMaxAttempts(3),
|
||||||
|
WithBackoff(ConstantBackoff(time.Millisecond)),
|
||||||
|
)(mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||||
|
calls.Add(1)
|
||||||
|
b, _ := io.ReadAll(req.Body)
|
||||||
|
bodies = append(bodies, string(b))
|
||||||
|
if len(bodies) < 2 {
|
||||||
|
return statusResponse(http.StatusServiceUnavailable), nil
|
||||||
|
}
|
||||||
|
return okResponse(), nil
|
||||||
|
}))
|
||||||
|
|
||||||
|
bodyContent := "request-body"
|
||||||
|
body := bytes.NewReader([]byte(bodyContent))
|
||||||
|
req, _ := http.NewRequest(http.MethodPut, "http://example.com", body)
|
||||||
|
req.GetBody = func() (io.ReadCloser, error) {
|
||||||
|
return io.NopCloser(bytes.NewReader([]byte(bodyContent))), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := rt.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
if got := calls.Load(); got != 2 {
|
||||||
|
t.Fatalf("expected 2 calls, got %d", got)
|
||||||
|
}
|
||||||
|
for i, b := range bodies {
|
||||||
|
if b != bodyContent {
|
||||||
|
t.Fatalf("attempt %d: expected body %q, got %q", i, bodyContent, b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("custom policy", func(t *testing.T) {
|
||||||
|
var calls atomic.Int32
|
||||||
|
|
||||||
|
// Custom policy: retry only on 418
|
||||||
|
custom := policyFunc(func(attempt int, req *http.Request, resp *http.Response, err error) (bool, time.Duration) {
|
||||||
|
if resp != nil && resp.StatusCode == http.StatusTeapot {
|
||||||
|
return true, 0
|
||||||
|
}
|
||||||
|
return false, 0
|
||||||
|
})
|
||||||
|
|
||||||
|
rt := Transport(
|
||||||
|
WithMaxAttempts(3),
|
||||||
|
WithBackoff(ConstantBackoff(time.Millisecond)),
|
||||||
|
WithPolicy(custom),
|
||||||
|
)(mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||||
|
n := calls.Add(1)
|
||||||
|
if n == 1 {
|
||||||
|
return statusResponse(http.StatusTeapot), nil
|
||||||
|
}
|
||||||
|
return okResponse(), nil
|
||||||
|
}))
|
||||||
|
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "http://example.com", nil)
|
||||||
|
resp, err := rt.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
if got := calls.Load(); got != 2 {
|
||||||
|
t.Fatalf("expected 2 calls, got %d", got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// policyFunc adapts a function into a Policy.
|
||||||
|
type policyFunc func(int, *http.Request, *http.Response, error) (bool, time.Duration)
|
||||||
|
|
||||||
|
func (f policyFunc) ShouldRetry(attempt int, req *http.Request, resp *http.Response, err error) (bool, time.Duration) {
|
||||||
|
return f(attempt, req, resp, err)
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user