Add retry transport with configurable backoff and Retry-After support

Implements retry middleware as a RoundTripper wrapper:
- Exponential and constant backoff strategies with jitter
- RFC 7231 Retry-After header parsing (seconds and HTTP-date)
- Default policy retries idempotent methods on 429/502/503/504 and network errors
- Body restoration via GetBody, context cancellation, response body cleanup
This commit is contained in:
2026-03-20 14:21:53 +03:00
parent 6b1941fce7
commit 505c7b8c4f
7 changed files with 660 additions and 0 deletions

64
retry/backoff.go Normal file
View File

@@ -0,0 +1,64 @@
package retry
import (
"math/rand/v2"
"time"
)
// Backoff computes the delay before the next retry attempt.
//
// ExponentialBackoff and ConstantBackoff return the implementations provided
// by this package.
type Backoff interface {
// Delay returns the wait duration for the given attempt number (zero-based).
// The retry transport passes the index of the attempt that just finished, so
// the wait before the first retry is Delay(0).
Delay(attempt int) time.Duration
}
// ExponentialBackoff builds a Backoff whose delay doubles with every attempt:
// attempt n waits base * 2^n, never more than max. With withJitter set, a
// random extra in [0, delay*0.5) is added on top, and the sum is still capped
// at max.
func ExponentialBackoff(base, max time.Duration, withJitter bool) Backoff {
	b := new(exponentialBackoff)
	b.base, b.max, b.withJitter = base, max, withJitter
	return b
}
// ConstantBackoff builds a Backoff that yields d for every attempt,
// regardless of the attempt number.
func ConstantBackoff(d time.Duration) Backoff {
	var fixed constantBackoff
	fixed.delay = d
	return fixed
}
type exponentialBackoff struct {
base time.Duration
max time.Duration
withJitter bool
}
func (b *exponentialBackoff) Delay(attempt int) time.Duration {
delay := b.base
for range attempt {
delay *= 2
if delay >= b.max {
delay = b.max
break
}
}
if b.withJitter {
jitter := time.Duration(rand.Int64N(int64(delay / 2)))
delay += jitter
}
if delay > b.max {
delay = b.max
}
return delay
}
// constantBackoff is the Backoff implementation behind ConstantBackoff.
type constantBackoff struct {
	delay time.Duration // returned unchanged for every attempt
}

// Delay ignores the attempt number and reports the configured delay.
func (c constantBackoff) Delay(int) time.Duration {
	return c.delay
}

77
retry/backoff_test.go Normal file
View File

@@ -0,0 +1,77 @@
package retry
import (
"testing"
"time"
)
// TestExponentialBackoff covers the three observable properties of the
// exponential strategy: doubling, the upper cap, and the jitter bounds.
func TestExponentialBackoff(t *testing.T) {
	t.Run("doubles each attempt", func(t *testing.T) {
		const base = 100 * time.Millisecond
		strategy := ExponentialBackoff(base, 10*time.Second, false)

		// The expected delay is base * 2^attempt for attempts 0 through 4.
		expected := base
		for attempt := 0; attempt < 5; attempt++ {
			if got := strategy.Delay(attempt); got != expected {
				t.Errorf("attempt %d: expected %v, got %v", attempt, expected, got)
			}
			expected *= 2
		}
	})

	t.Run("caps at max", func(t *testing.T) {
		const limit = 500 * time.Millisecond
		strategy := ExponentialBackoff(100*time.Millisecond, limit, false)

		// From attempt 3 on the uncapped value (800ms and up) exceeds the limit.
		for _, attempt := range []int{3, 4, 10} {
			if got := strategy.Delay(attempt); got != limit {
				t.Errorf("attempt %d: expected cap at 500ms, got %v", attempt, got)
			}
		}
	})

	t.Run("with jitter adds randomness", func(t *testing.T) {
		const (
			base    = 100 * time.Millisecond
			ceiling = base + base/2 // base plus the largest possible jitter
		)
		strategy := ExponentialBackoff(base, 10*time.Second, true)

		// Every sample must stay within [base, ceiling], and the samples
		// must not all be identical.
		distinct := map[time.Duration]struct{}{}
		for i := 0; i < 20; i++ {
			got := strategy.Delay(0)
			switch {
			case got < base:
				t.Fatalf("delay %v is less than base %v", got, base)
			case got > ceiling:
				t.Fatalf("delay %v exceeds expected max %v", got, ceiling)
			}
			distinct[got] = struct{}{}
		}
		if len(distinct) < 2 {
			t.Errorf("expected jitter to produce varying delays, got %d unique values", len(distinct))
		}
	})
}
// TestConstantBackoff checks that the constant strategy ignores the attempt
// number.
func TestConstantBackoff(t *testing.T) {
	t.Run("always returns same value", func(t *testing.T) {
		const want = 250 * time.Millisecond
		strategy := ConstantBackoff(want)

		attempts := [...]int{0, 1, 2, 5, 100}
		for i := range attempts {
			attempt := attempts[i]
			if got := strategy.Delay(attempt); got != want {
				t.Errorf("attempt %d: expected %v, got %v", attempt, want, got)
			}
		}
	})
}

56
retry/options.go Normal file
View File

@@ -0,0 +1,56 @@
package retry
import "time"
// options holds the resolved configuration of the retry transport. It starts
// from defaults and is then modified by each Option in turn.
type options struct {
maxAttempts int // total attempts including the first; default 3
backoff Backoff // default ExponentialBackoff(100ms, 5s, true)
policy Policy // default: defaultPolicy (idempotent methods only; retries network errors and 429/502/503/504)
retryAfter bool // default true, respect Retry-After header
}
// Option configures the retry transport. Options are applied in the order
// they are passed to Transport, so a later option overrides an earlier one
// that sets the same field.
type Option func(*options)
// defaults returns the configuration used when no Option overrides it: three
// attempts, jittered exponential backoff from 100ms up to 5s, the default
// policy, and Retry-After handling enabled.
func defaults() options {
	var cfg options
	cfg.maxAttempts = 3
	cfg.backoff = ExponentialBackoff(100*time.Millisecond, 5*time.Second, true)
	cfg.policy = defaultPolicy{}
	cfg.retryAfter = true
	return cfg
}
// WithMaxAttempts sets the maximum number of attempts (including the first).
// Any value below 1 is clamped to 1, which disables retries.
func WithMaxAttempts(n int) Option {
	attempts := max(n, 1)
	return func(cfg *options) {
		cfg.maxAttempts = attempts
	}
}
// WithBackoff sets the backoff strategy used to compute delays between retries.
//
// A nil strategy is ignored and the previously configured one is kept: the
// transport calls Delay on the strategy before every retry, so storing nil
// would panic on the first retried request.
func WithBackoff(b Backoff) Option {
	return func(o *options) {
		if b == nil {
			return
		}
		o.backoff = b
	}
}
// WithPolicy sets the retry policy that decides whether to retry a request.
//
// A nil policy is ignored and the previously configured one is kept: the
// transport calls ShouldRetry on the policy after every attempt but the
// last, so storing nil would panic at request time.
func WithPolicy(p Policy) Option {
	return func(o *options) {
		if p == nil {
			return
		}
		o.policy = p
	}
}
// WithRetryAfter switches handling of the Retry-After response header on or
// off. While it is on, a valid Retry-After value replaces the computed delay
// whenever it is the longer of the two.
func WithRetryAfter(enable bool) Option {
	return func(cfg *options) { cfg.retryAfter = enable }
}

125
retry/retry.go Normal file
View File

@@ -0,0 +1,125 @@
package retry
import (
"io"
"net/http"
"time"
"git.codelab.vc/pkg/httpx/middleware"
)
// Policy decides whether a failed request should be retried.
type Policy interface {
// ShouldRetry reports whether the request should be retried. attempt is the
// zero-based index of the attempt that just finished; resp and err are
// whatever the wrapped RoundTripper returned for it, so resp may be nil.
// The extra duration, if non-zero, is a policy-suggested delay: the
// transport waits for it instead of the backoff delay only when it is the
// larger of the two.
ShouldRetry(attempt int, req *http.Request, resp *http.Response, err error) (bool, time.Duration)
}
// Transport returns a middleware that retries failed requests according to
// the provided options.
//
// At most maxAttempts calls are made to the wrapped RoundTripper. After every
// attempt but the last, the policy decides whether to retry. If it does, the
// previous response body is drained and closed, and the transport waits for
// the largest of the backoff delay, the policy delay and (when enabled) the
// Retry-After delay. If the request context is done during the wait, the
// context's error is returned with a nil response.
//
// A request whose body cannot be replayed (a non-nil Body without GetBody) is
// never retried; the result of its first attempt is returned untouched. The
// caller's request is not modified: retries are sent on a shallow copy that
// carries a fresh body obtained from GetBody.
func Transport(opts ...Option) middleware.Middleware {
	cfg := defaults()
	for _, o := range opts {
		o(&cfg)
	}

	return func(next http.RoundTripper) http.RoundTripper {
		return middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
			ctx := req.Context()

			// A body that is present but has no GetBody can be sent only once.
			replayable := req.Body == nil || req.Body == http.NoBody || req.GetBody != nil

			attemptReq := req
			var resp *http.Response
			var err error

			for attempt := range cfg.maxAttempts {
				// For retries, send a copy of the request with a fresh body
				// so the caller's request is left as it was handed to us.
				if attempt > 0 && req.GetBody != nil {
					body, bodyErr := req.GetBody()
					if bodyErr != nil {
						// The previous response is already closed, and a
						// RoundTripper must not return a response together
						// with an error.
						return nil, bodyErr
					}
					retryReq := *req
					retryReq.Body = body
					attemptReq = &retryReq
				}

				resp, err = next.RoundTrip(attemptReq)

				// Last attempt — return whatever we got.
				if attempt == cfg.maxAttempts-1 {
					break
				}

				shouldRetry, policyDelay := cfg.policy.ShouldRetry(attempt, attemptReq, resp, err)
				if !shouldRetry {
					break
				}

				// Decide this before touching the response: once its body is
				// drained below it can no longer be handed to the caller.
				if !replayable {
					break
				}

				// Compute delay: use backoff or policy delay, whichever is larger.
				delay := max(cfg.backoff.Delay(attempt), policyDelay)

				// Respect Retry-After header if enabled.
				if cfg.retryAfter && resp != nil {
					if ra, ok := ParseRetryAfter(resp); ok && ra > delay {
						delay = ra
					}
				}

				// Drain and close the response body to release the connection.
				if resp != nil && resp.Body != nil {
					// Best effort: a failed drain only costs connection
					// reuse, and the response is being discarded anyway.
					_, _ = io.Copy(io.Discard, resp.Body)
					_ = resp.Body.Close()
				}

				// Wait for the delay or context cancellation.
				timer := time.NewTimer(delay)
				select {
				case <-ctx.Done():
					timer.Stop()
					return nil, ctx.Err()
				case <-timer.C:
				}
			}

			return resp, err
		})
	}
}
// defaultPolicy retries idempotent requests that failed with a transport
// error or received a 429, 502, 503 or 504 response. Every other status,
// including 500, is returned to the caller as-is, and non-idempotent methods
// are never retried.
type defaultPolicy struct{}

// ShouldRetry implements Policy. It never suggests a delay of its own, so the
// returned duration is always zero.
func (defaultPolicy) ShouldRetry(_ int, req *http.Request, resp *http.Response, err error) (bool, time.Duration) {
	if !isIdempotent(req.Method) {
		return false, 0
	}

	// Network error — always retry idempotent requests.
	if err != nil {
		return true, 0
	}

	// A RoundTripper should never return (nil, nil); guard against a
	// misbehaving one instead of dereferencing a nil response below.
	if resp == nil {
		return false, 0
	}

	switch resp.StatusCode {
	case http.StatusTooManyRequests, // 429
		http.StatusBadGateway,         // 502
		http.StatusServiceUnavailable, // 503
		http.StatusGatewayTimeout:     // 504
		return true, 0
	}
	return false, 0
}
// isIdempotent reports whether the HTTP method is safe to retry, i.e. whether
// repeating the request has the same intended effect as sending it once.
// These are the idempotent methods of RFC 9110 section 9.2.2: the safe
// methods GET, HEAD, OPTIONS and TRACE, plus PUT and DELETE. The empty
// method is accepted as well because net/http treats it as GET on client
// requests.
func isIdempotent(method string) bool {
	switch method {
	case "", // empty means GET for client requests
		http.MethodGet,
		http.MethodHead,
		http.MethodOptions,
		http.MethodTrace,
		http.MethodPut,
		http.MethodDelete:
		return true
	}
	return false
}

43
retry/retry_after.go Normal file
View File

@@ -0,0 +1,43 @@
package retry
import (
"net/http"
"strconv"
"time"
)
// ParseRetryAfter extracts the delay from a Retry-After header (RFC 7231).
// It supports both the delay-seconds format ("120") and the HTTP-date format
// ("Fri, 31 Dec 1999 23:59:59 GMT"). Returns the duration and true if the
// header was present and valid; otherwise returns 0 and false.
//
// A negative delay-seconds value is treated as invalid. A delay-seconds value
// too large for a time.Duration is clamped to the maximum duration. An
// HTTP-date in the past yields a zero delay with ok set to true. HTTP-dates
// are measured against the local clock.
func ParseRetryAfter(resp *http.Response) (time.Duration, bool) {
	if resp == nil {
		return 0, false
	}

	val := resp.Header.Get("Retry-After")
	if val == "" {
		return 0, false
	}

	// Try delay-seconds first (most common).
	if seconds, err := strconv.ParseInt(val, 10, 64); err == nil {
		if seconds < 0 {
			return 0, false
		}
		// Multiplying by time.Second overflows int64 beyond roughly 292
		// years' worth of seconds and would produce a negative or garbage
		// delay, so clamp instead.
		const maxDuration = time.Duration(1<<63 - 1)
		if seconds > int64(maxDuration/time.Second) {
			return maxDuration, true
		}
		return time.Duration(seconds) * time.Second, true
	}

	// Try HTTP-date format (RFC 7231 section 7.1.1.1).
	t, err := http.ParseTime(val)
	if err != nil {
		return 0, false
	}

	delay := time.Until(t)
	if delay < 0 {
		// The date is in the past; no need to wait.
		return 0, true
	}
	return delay, true
}

58
retry/retry_after_test.go Normal file
View File

@@ -0,0 +1,58 @@
package retry
import (
"net/http"
"testing"
"time"
)
// TestParseRetryAfter exercises both Retry-After formats (delay-seconds and
// HTTP-date) as well as the inputs that must be rejected.
func TestParseRetryAfter(t *testing.T) {
	// withRetryAfter builds a response carrying only the given header value.
	withRetryAfter := func(value string) *http.Response {
		return &http.Response{
			Header: http.Header{"Retry-After": []string{value}},
		}
	}

	t.Run("seconds format", func(t *testing.T) {
		d, ok := ParseRetryAfter(withRetryAfter("120"))
		if !ok {
			t.Fatal("expected ok=true")
		}
		if d != 120*time.Second {
			t.Fatalf("expected 120s, got %v", d)
		}
	})

	t.Run("empty header", func(t *testing.T) {
		resp := &http.Response{
			Header: make(http.Header),
		}
		d, ok := ParseRetryAfter(resp)
		if ok {
			t.Fatal("expected ok=false for empty header")
		}
		if d != 0 {
			t.Fatalf("expected 0, got %v", d)
		}
	})

	t.Run("nil response", func(t *testing.T) {
		d, ok := ParseRetryAfter(nil)
		if ok {
			t.Fatal("expected ok=false for nil response")
		}
		if d != 0 {
			t.Fatalf("expected 0, got %v", d)
		}
	})

	t.Run("negative value", func(t *testing.T) {
		d, ok := ParseRetryAfter(withRetryAfter("-5"))
		if ok {
			t.Fatal("expected ok=false for negative value")
		}
		if d != 0 {
			t.Fatalf("expected 0, got %v", d)
		}
	})

	t.Run("HTTP-date in the future", func(t *testing.T) {
		when := time.Now().Add(time.Hour)
		d, ok := ParseRetryAfter(withRetryAfter(when.UTC().Format(http.TimeFormat)))
		if !ok {
			t.Fatal("expected ok=true for HTTP-date")
		}
		// HTTP-date has one-second resolution, so the parsed delay ends up
		// slightly below one hour.
		if d <= 59*time.Minute || d > time.Hour {
			t.Fatalf("expected a delay of about 1h, got %v", d)
		}
	})

	t.Run("HTTP-date in the past", func(t *testing.T) {
		when := time.Now().Add(-time.Hour)
		d, ok := ParseRetryAfter(withRetryAfter(when.UTC().Format(http.TimeFormat)))
		if !ok {
			t.Fatal("expected ok=true for past HTTP-date")
		}
		if d != 0 {
			t.Fatalf("expected 0 for past HTTP-date, got %v", d)
		}
	})

	t.Run("unparseable value", func(t *testing.T) {
		d, ok := ParseRetryAfter(withRetryAfter("soon"))
		if ok {
			t.Fatal("expected ok=false for unparseable value")
		}
		if d != 0 {
			t.Fatalf("expected 0, got %v", d)
		}
	})
}

237
retry/retry_test.go Normal file
View File

@@ -0,0 +1,237 @@
package retry
import (
"bytes"
"context"
"io"
"net/http"
"strings"
"sync/atomic"
"testing"
"time"
"git.codelab.vc/pkg/httpx/middleware"
)
// mockTransport adapts fn into an http.RoundTripper that stands in for the
// real network transport in tests.
func mockTransport(fn func(*http.Request) (*http.Response, error)) http.RoundTripper {
	var rt http.RoundTripper = middleware.RoundTripperFunc(fn)
	return rt
}
// okResponse builds a minimal 200 response with an empty body and no headers.
func okResponse() *http.Response {
	resp := new(http.Response)
	resp.StatusCode = http.StatusOK
	resp.Header = http.Header{}
	resp.Body = io.NopCloser(strings.NewReader(""))
	return resp
}
// statusResponse builds a minimal response with the given status code, an
// empty body and no headers.
func statusResponse(code int) *http.Response {
	resp := new(http.Response)
	resp.StatusCode = code
	resp.Header = http.Header{}
	resp.Body = io.NopCloser(strings.NewReader(""))
	return resp
}
// TestTransport drives the retry middleware against a mock RoundTripper that
// counts its invocations. Each subtest uses a short constant backoff so the
// waits between attempts stay small.
func TestTransport(t *testing.T) {
// A 200 response is not retried: exactly one call reaches the mock.
t.Run("successful request no retry", func(t *testing.T) {
var calls atomic.Int32
rt := Transport(
WithMaxAttempts(3),
WithBackoff(ConstantBackoff(time.Millisecond)),
)(mockTransport(func(req *http.Request) (*http.Response, error) {
calls.Add(1)
return okResponse(), nil
}))
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
resp, err := rt.RoundTrip(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected 200, got %d", resp.StatusCode)
}
if got := calls.Load(); got != 1 {
t.Fatalf("expected 1 call, got %d", got)
}
})
// The mock answers 503 twice and then 200; the caller must see the 200
// after three calls in total.
t.Run("retries on 503 then succeeds", func(t *testing.T) {
var calls atomic.Int32
rt := Transport(
WithMaxAttempts(3),
WithBackoff(ConstantBackoff(time.Millisecond)),
)(mockTransport(func(req *http.Request) (*http.Response, error) {
n := calls.Add(1)
if n < 3 {
return statusResponse(http.StatusServiceUnavailable), nil
}
return okResponse(), nil
}))
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
resp, err := rt.RoundTrip(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected 200, got %d", resp.StatusCode)
}
if got := calls.Load(); got != 3 {
t.Fatalf("expected 3 calls, got %d", got)
}
})
// With the default policy a POST is never retried, so the 503 is handed
// back after a single call.
t.Run("does not retry non-idempotent POST", func(t *testing.T) {
var calls atomic.Int32
rt := Transport(
WithMaxAttempts(3),
WithBackoff(ConstantBackoff(time.Millisecond)),
)(mockTransport(func(req *http.Request) (*http.Response, error) {
calls.Add(1)
return statusResponse(http.StatusServiceUnavailable), nil
}))
req, _ := http.NewRequest(http.MethodPost, "http://example.com", strings.NewReader("data"))
resp, err := rt.RoundTrip(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.StatusCode != http.StatusServiceUnavailable {
t.Fatalf("expected 503, got %d", resp.StatusCode)
}
if got := calls.Load(); got != 1 {
t.Fatalf("expected 1 call (no retry for POST), got %d", got)
}
})
// The mock cancels the request context during the first call, so the
// transport is expected to abandon the 50ms wait and return
// context.Canceled.
t.Run("stops on context cancellation", func(t *testing.T) {
var calls atomic.Int32
ctx, cancel := context.WithCancel(context.Background())
rt := Transport(
WithMaxAttempts(5),
WithBackoff(ConstantBackoff(50*time.Millisecond)),
)(mockTransport(func(req *http.Request) (*http.Response, error) {
n := calls.Add(1)
if n == 1 {
cancel()
}
return statusResponse(http.StatusServiceUnavailable), nil
}))
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "http://example.com", nil)
resp, err := rt.RoundTrip(req)
if err != context.Canceled {
t.Fatalf("expected context.Canceled, got resp=%v err=%v", resp, err)
}
})
// The mock always answers 502; with maxAttempts=2 it must be called
// exactly twice and the last 502 is returned.
t.Run("respects maxAttempts", func(t *testing.T) {
var calls atomic.Int32
rt := Transport(
WithMaxAttempts(2),
WithBackoff(ConstantBackoff(time.Millisecond)),
)(mockTransport(func(req *http.Request) (*http.Response, error) {
calls.Add(1)
return statusResponse(http.StatusBadGateway), nil
}))
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
resp, err := rt.RoundTrip(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.StatusCode != http.StatusBadGateway {
t.Fatalf("expected 502, got %d", resp.StatusCode)
}
if got := calls.Load(); got != 2 {
t.Fatalf("expected 2 calls, got %d", got)
}
})
// The mock reads the whole request body on every call and records it.
// The first call gets a 503, the second a 200; both recorded bodies must
// equal the original content, which requires the body to be re-created
// through GetBody for the retry.
t.Run("body is restored via GetBody on retry", func(t *testing.T) {
var calls atomic.Int32
var bodies []string
rt := Transport(
WithMaxAttempts(3),
WithBackoff(ConstantBackoff(time.Millisecond)),
)(mockTransport(func(req *http.Request) (*http.Response, error) {
calls.Add(1)
b, _ := io.ReadAll(req.Body)
bodies = append(bodies, string(b))
if len(bodies) < 2 {
return statusResponse(http.StatusServiceUnavailable), nil
}
return okResponse(), nil
}))
bodyContent := "request-body"
body := bytes.NewReader([]byte(bodyContent))
req, _ := http.NewRequest(http.MethodPut, "http://example.com", body)
// Set GetBody explicitly so each retry is given a fresh reader over the
// same content.
req.GetBody = func() (io.ReadCloser, error) {
return io.NopCloser(bytes.NewReader([]byte(bodyContent))), nil
}
resp, err := rt.RoundTrip(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected 200, got %d", resp.StatusCode)
}
if got := calls.Load(); got != 2 {
t.Fatalf("expected 2 calls, got %d", got)
}
for i, b := range bodies {
if b != bodyContent {
t.Fatalf("attempt %d: expected body %q, got %q", i, bodyContent, b)
}
}
})
// A policy supplied through WithPolicy replaces the default one. The
// request is a POST, which the default policy would refuse to retry, yet
// the custom policy retries it after the 418.
t.Run("custom policy", func(t *testing.T) {
var calls atomic.Int32
// Custom policy: retry only on 418
custom := policyFunc(func(attempt int, req *http.Request, resp *http.Response, err error) (bool, time.Duration) {
if resp != nil && resp.StatusCode == http.StatusTeapot {
return true, 0
}
return false, 0
})
rt := Transport(
WithMaxAttempts(3),
WithBackoff(ConstantBackoff(time.Millisecond)),
WithPolicy(custom),
)(mockTransport(func(req *http.Request) (*http.Response, error) {
n := calls.Add(1)
if n == 1 {
return statusResponse(http.StatusTeapot), nil
}
return okResponse(), nil
}))
req, _ := http.NewRequest(http.MethodPost, "http://example.com", nil)
resp, err := rt.RoundTrip(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected 200, got %d", resp.StatusCode)
}
if got := calls.Load(); got != 2 {
t.Fatalf("expected 2 calls, got %d", got)
}
})
}
// policyFunc lets a plain function act as a Policy in tests.
type policyFunc func(int, *http.Request, *http.Response, error) (bool, time.Duration)

// ShouldRetry forwards all arguments to the wrapped function and returns its
// results unchanged.
func (fn policyFunc) ShouldRetry(attempt int, req *http.Request, resp *http.Response, err error) (retry bool, delay time.Duration) {
	retry, delay = fn(attempt, req, resp, err)
	return retry, delay
}