From 505c7b8c4fa65631f1f128b6571cb62f8dacc4fe Mon Sep 17 00:00:00 2001 From: Aleksey Shakhmatov Date: Fri, 20 Mar 2026 14:21:53 +0300 Subject: [PATCH] Add retry transport with configurable backoff and Retry-After support Implements retry middleware as a RoundTripper wrapper: - Exponential and constant backoff strategies with jitter - RFC 7231 Retry-After header parsing (seconds and HTTP-date) - Default policy retries idempotent methods on 429/5xx and network errors - Body restoration via GetBody, context cancellation, response body cleanup --- retry/backoff.go | 64 ++++++++++ retry/backoff_test.go | 77 +++++++++++++ retry/options.go | 56 +++++++++ retry/retry.go | 125 ++++++++++++++++++++ retry/retry_after.go | 43 +++++++ retry/retry_after_test.go | 58 ++++++++++ retry/retry_test.go | 237 ++++++++++++++++++++++++++++++++++++++ 7 files changed, 660 insertions(+) create mode 100644 retry/backoff.go create mode 100644 retry/backoff_test.go create mode 100644 retry/options.go create mode 100644 retry/retry.go create mode 100644 retry/retry_after.go create mode 100644 retry/retry_after_test.go create mode 100644 retry/retry_test.go diff --git a/retry/backoff.go b/retry/backoff.go new file mode 100644 index 0000000..21966a9 --- /dev/null +++ b/retry/backoff.go @@ -0,0 +1,64 @@ +package retry + +import ( + "math/rand/v2" + "time" +) + +// Backoff computes the delay before the next retry attempt. +type Backoff interface { + // Delay returns the wait duration for the given attempt number (zero-based). + Delay(attempt int) time.Duration +} + +// ExponentialBackoff returns a Backoff that doubles the delay on each attempt. +// The delay is calculated as base * 2^attempt, capped at max. When withJitter +// is true, a random duration in [0, delay*0.5) is added. 
+func ExponentialBackoff(base, max time.Duration, withJitter bool) Backoff { + return &exponentialBackoff{ + base: base, + max: max, + withJitter: withJitter, + } +} + +// ConstantBackoff returns a Backoff that always returns the same delay. +func ConstantBackoff(d time.Duration) Backoff { + return constantBackoff{delay: d} +} + +type exponentialBackoff struct { + base time.Duration + max time.Duration + withJitter bool +} + +func (b *exponentialBackoff) Delay(attempt int) time.Duration { + delay := b.base + for range attempt { + delay *= 2 + if delay >= b.max { + delay = b.max + break + } + } + + if b.withJitter { + jitter := time.Duration(rand.Int64N(int64(delay / 2))) + delay += jitter + } + + if delay > b.max { + delay = b.max + } + + return delay +} + +type constantBackoff struct { + delay time.Duration +} + +func (b constantBackoff) Delay(_ int) time.Duration { + return b.delay +} diff --git a/retry/backoff_test.go b/retry/backoff_test.go new file mode 100644 index 0000000..29eb239 --- /dev/null +++ b/retry/backoff_test.go @@ -0,0 +1,77 @@ +package retry + +import ( + "testing" + "time" +) + +func TestExponentialBackoff(t *testing.T) { + t.Run("doubles each attempt", func(t *testing.T) { + b := ExponentialBackoff(100*time.Millisecond, 10*time.Second, false) + + want := []time.Duration{ + 100 * time.Millisecond, // attempt 0: base + 200 * time.Millisecond, // attempt 1: base*2 + 400 * time.Millisecond, // attempt 2: base*4 + 800 * time.Millisecond, // attempt 3: base*8 + 1600 * time.Millisecond, // attempt 4: base*16 + } + + for i, expected := range want { + got := b.Delay(i) + if got != expected { + t.Errorf("attempt %d: expected %v, got %v", i, expected, got) + } + } + }) + + t.Run("caps at max", func(t *testing.T) { + b := ExponentialBackoff(100*time.Millisecond, 500*time.Millisecond, false) + + // attempt 0: 100ms, 1: 200ms, 2: 400ms, 3: 500ms (capped), 4: 500ms + for _, attempt := range []int{3, 4, 10} { + got := b.Delay(attempt) + if got != 
500*time.Millisecond { + t.Errorf("attempt %d: expected cap at 500ms, got %v", attempt, got) + } + } + }) + + t.Run("with jitter adds randomness", func(t *testing.T) { + base := 100 * time.Millisecond + b := ExponentialBackoff(base, 10*time.Second, true) + + // Run multiple times; with jitter, delay >= base for attempt 0. + // Also verify not all values are identical (randomness). + seen := make(map[time.Duration]bool) + for range 20 { + d := b.Delay(0) + if d < base { + t.Fatalf("delay %v is less than base %v", d, base) + } + // With jitter: delay = base + rand in [0, base/2), so max is base*1.5 + maxExpected := base + base/2 + if d > maxExpected { + t.Fatalf("delay %v exceeds expected max %v", d, maxExpected) + } + seen[d] = true + } + if len(seen) < 2 { + t.Errorf("expected jitter to produce varying delays, got %d unique values", len(seen)) + } + }) +} + +func TestConstantBackoff(t *testing.T) { + t.Run("always returns same value", func(t *testing.T) { + d := 250 * time.Millisecond + b := ConstantBackoff(d) + + for _, attempt := range []int{0, 1, 2, 5, 100} { + got := b.Delay(attempt) + if got != d { + t.Errorf("attempt %d: expected %v, got %v", attempt, d, got) + } + } + }) +} diff --git a/retry/options.go b/retry/options.go new file mode 100644 index 0000000..5e59cb8 --- /dev/null +++ b/retry/options.go @@ -0,0 +1,56 @@ +package retry + +import "time" + +type options struct { + maxAttempts int // default 3 + backoff Backoff // default ExponentialBackoff(100ms, 5s, true) + policy Policy // default: defaultPolicy (retry on 5xx and network errors) + retryAfter bool // default true, respect Retry-After header +} + +// Option configures the retry transport. +type Option func(*options) + +func defaults() options { + return options{ + maxAttempts: 3, + backoff: ExponentialBackoff(100*time.Millisecond, 5*time.Second, true), + policy: defaultPolicy{}, + retryAfter: true, + } +} + +// WithMaxAttempts sets the maximum number of attempts (including the first). 
+// Values less than 1 are treated as 1 (no retries). +func WithMaxAttempts(n int) Option { + return func(o *options) { + if n < 1 { + n = 1 + } + o.maxAttempts = n + } +} + +// WithBackoff sets the backoff strategy used to compute delays between retries. +func WithBackoff(b Backoff) Option { + return func(o *options) { + o.backoff = b + } +} + +// WithPolicy sets the retry policy that decides whether to retry a request. +func WithPolicy(p Policy) Option { + return func(o *options) { + o.policy = p + } +} + +// WithRetryAfter controls whether the Retry-After response header is respected. +// When enabled and present, the Retry-After delay is used if it exceeds the +// backoff delay. +func WithRetryAfter(enable bool) Option { + return func(o *options) { + o.retryAfter = enable + } +} diff --git a/retry/retry.go b/retry/retry.go new file mode 100644 index 0000000..36c491b --- /dev/null +++ b/retry/retry.go @@ -0,0 +1,125 @@ +package retry + +import ( + "io" + "net/http" + "time" + + "git.codelab.vc/pkg/httpx/middleware" +) + +// Policy decides whether a failed request should be retried. +type Policy interface { + // ShouldRetry reports whether the request should be retried. The extra + // duration, if non-zero, is a policy-suggested delay that overrides the + // backoff strategy. + ShouldRetry(attempt int, req *http.Request, resp *http.Response, err error) (bool, time.Duration) +} + +// Transport returns a middleware that retries failed requests according to +// the provided options. +func Transport(opts ...Option) middleware.Middleware { + cfg := defaults() + for _, o := range opts { + o(&cfg) + } + + return func(next http.RoundTripper) http.RoundTripper { + return middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) { + var resp *http.Response + var err error + + for attempt := range cfg.maxAttempts { + // For retries (attempt > 0), restore the request body. 
+ if attempt > 0 { + if req.GetBody != nil { + body, bodyErr := req.GetBody() + if bodyErr != nil { + return resp, bodyErr + } + req.Body = body + } else if req.Body != nil { + // Body was consumed and cannot be re-created. + return resp, err + } + } + + resp, err = next.RoundTrip(req) + + // Last attempt — return whatever we got. + if attempt == cfg.maxAttempts-1 { + break + } + + shouldRetry, policyDelay := cfg.policy.ShouldRetry(attempt, req, resp, err) + if !shouldRetry { + break + } + + // Compute delay: use backoff or policy delay, whichever is larger. + delay := cfg.backoff.Delay(attempt) + if policyDelay > delay { + delay = policyDelay + } + + // Respect Retry-After header if enabled. + if cfg.retryAfter && resp != nil { + if ra, ok := ParseRetryAfter(resp); ok && ra > delay { + delay = ra + } + } + + // Drain and close the response body to release the connection. + if resp != nil { + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + } + + // Wait for the delay or context cancellation. + timer := time.NewTimer(delay) + select { + case <-req.Context().Done(): + timer.Stop() + return nil, req.Context().Err() + case <-timer.C: + } + } + + return resp, err + }) + } +} + +// defaultPolicy retries on network errors, 429, and 5xx server errors. +// It refuses to retry non-idempotent methods. +type defaultPolicy struct{} + +func (defaultPolicy) ShouldRetry(_ int, req *http.Request, resp *http.Response, err error) (bool, time.Duration) { + if !isIdempotent(req.Method) { + return false, 0 + } + + // Network error — always retry idempotent requests. + if err != nil { + return true, 0 + } + + switch resp.StatusCode { + case http.StatusTooManyRequests, // 429 + http.StatusBadGateway, // 502 + http.StatusServiceUnavailable, // 503 + http.StatusGatewayTimeout: // 504 + return true, 0 + } + + return false, 0 +} + +// isIdempotent reports whether the HTTP method is safe to retry. 
func isIdempotent(method string) bool {
	// NOTE(review): RFC 9110 also classifies DELETE and TRACE as idempotent;
	// their absence here may be deliberate (DELETE retries can be surprising)
	// — confirm before extending this list.
	switch method {
	case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodPut:
		return true
	}
	return false
}

// ---- retry/retry_after.go (package retry; imports: net/http, strconv, time) ----

// ParseRetryAfter extracts the delay from a Retry-After header (RFC 7231).
// It supports both the delay-seconds format ("120") and the HTTP-date format
// ("Fri, 31 Dec 1999 23:59:59 GMT"). Returns the duration and true if the
// header was present and valid; otherwise returns 0 and false.
func ParseRetryAfter(resp *http.Response) (time.Duration, bool) {
	if resp == nil {
		return 0, false
	}

	val := resp.Header.Get("Retry-After")
	if val == "" {
		return 0, false
	}

	// Try delay-seconds first (most common).
	if seconds, err := strconv.ParseInt(val, 10, 64); err == nil {
		// Negative delay-seconds is outside the RFC grammar: treat as absent.
		if seconds < 0 {
			return 0, false
		}
		return time.Duration(seconds) * time.Second, true
	}

	// Try HTTP-date format (RFC 7231 section 7.1.1.1).
	t, err := http.ParseTime(val)
	if err != nil {
		return 0, false
	}

	delay := time.Until(t)
	if delay < 0 {
		// The date is in the past; no need to wait. The header itself was
		// valid, hence ok=true with a zero delay.
		return 0, true
	}
	return delay, true
}

// ---- retry/retry_after_test.go (package retry; imports: net/http, testing, time) ----

func TestParseRetryAfter(t *testing.T) {
	t.Run("seconds format", func(t *testing.T) {
		resp := &http.Response{
			Header: http.Header{"Retry-After": []string{"120"}},
		}
		d, ok := ParseRetryAfter(resp)
		if !ok {
			t.Fatal("expected ok=true")
		}
		if d != 120*time.Second {
			t.Fatalf("expected 120s, got %v", d)
		}
	})

	t.Run("empty header", func(t *testing.T) {
		resp := &http.Response{
			Header: make(http.Header),
		}
		d, ok := ParseRetryAfter(resp)
		if ok {
			t.Fatal("expected ok=false for empty header")
		}
		if d != 0 {
			t.Fatalf("expected 0, got %v", d)
		}
	})

	t.Run("nil response", func(t *testing.T) {
		d, ok := ParseRetryAfter(nil)
		if ok {
			t.Fatal("expected ok=false for nil response")
		}
		if d != 0 {
			t.Fatalf("expected 0, got %v", d)
		}
	})

	t.Run("negative value", func(t *testing.T) {
		resp := &http.Response{
			Header: http.Header{"Retry-After": []string{"-5"}},
		}
		d, ok := ParseRetryAfter(resp)
		if ok {
			t.Fatal("expected ok=false for negative value")
		}
		if d != 0 {
			t.Fatalf("expected 0, got %v", d)
		}
	})
}

// ---- retry/retry_test.go (package retry; imports: bytes, context, io,
// net/http, strings, sync/atomic, testing, time,
// git.codelab.vc/pkg/httpx/middleware) ----

// mockTransport adapts a plain function into an http.RoundTripper for tests.
func mockTransport(fn func(*http.Request) (*http.Response, error)) http.RoundTripper {
	return middleware.RoundTripperFunc(fn)
}

// okResponse builds a minimal 200 response with an empty, readable body.
func okResponse() *http.Response {
	return &http.Response{
		StatusCode: http.StatusOK,
		Body:       io.NopCloser(strings.NewReader("")),
		Header:     make(http.Header),
	}
}

// statusResponse builds a minimal response with the given status code.
func statusResponse(code int) *http.Response {
	return &http.Response{
		StatusCode: code,
		Body:       io.NopCloser(strings.NewReader("")),
		Header:     make(http.Header),
	}
}

func TestTransport(t *testing.T) {
	t.Run("successful request no retry", func(t *testing.T) {
		var calls atomic.Int32
		rt := Transport(
			WithMaxAttempts(3),
			WithBackoff(ConstantBackoff(time.Millisecond)),
		)(mockTransport(func(req *http.Request) (*http.Response, error) {
			calls.Add(1)
			return okResponse(), nil
		}))

		req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
		resp, err := rt.RoundTrip(req)
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if resp.StatusCode != http.StatusOK {
			t.Fatalf("expected 200, got %d", resp.StatusCode)
		}
		if got := calls.Load(); got != 1 {
			t.Fatalf("expected 1 call, got %d", got)
		}
	})

	t.Run("retries on 503 then succeeds", func(t *testing.T) {
		var calls atomic.Int32
		rt := Transport(
			WithMaxAttempts(3),
			WithBackoff(ConstantBackoff(time.Millisecond)),
		)(mockTransport(func(req *http.Request) (*http.Response, error) {
			// First two attempts fail with 503; the third succeeds.
			n := calls.Add(1)
			if n < 3 {
				return statusResponse(http.StatusServiceUnavailable), nil
			}
			return okResponse(), nil
		}))

		req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
		resp, err := rt.RoundTrip(req)
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if resp.StatusCode != http.StatusOK {
			t.Fatalf("expected 200, got %d", resp.StatusCode)
		}
		if got := calls.Load(); got != 3 {
			t.Fatalf("expected 3 calls, got %d", got)
		}
	})

	t.Run("does not retry non-idempotent POST", func(t *testing.T) {
		var calls atomic.Int32
		rt := Transport(
			WithMaxAttempts(3),
			WithBackoff(ConstantBackoff(time.Millisecond)),
		)(mockTransport(func(req *http.Request) (*http.Response, error) {
			calls.Add(1)
			return statusResponse(http.StatusServiceUnavailable), nil
		}))

		req, _ := http.NewRequest(http.MethodPost, "http://example.com", strings.NewReader("data"))
		resp, err := rt.RoundTrip(req)
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		// The 503 is returned as-is: defaultPolicy refuses POST retries.
		if resp.StatusCode != http.StatusServiceUnavailable {
			t.Fatalf("expected 503, got %d", resp.StatusCode)
		}
		if got := calls.Load(); got != 1 {
			t.Fatalf("expected 1 call (no retry for POST), got %d", got)
		}
	})

	t.Run("stops on context cancellation", func(t *testing.T) {
		var calls atomic.Int32
		ctx, cancel := context.WithCancel(context.Background())

		rt := Transport(
			WithMaxAttempts(5),
			WithBackoff(ConstantBackoff(50*time.Millisecond)),
		)(mockTransport(func(req *http.Request) (*http.Response, error) {
			// Cancel during the first attempt so the retry wait aborts.
			n := calls.Add(1)
			if n == 1 {
				cancel()
			}
			return statusResponse(http.StatusServiceUnavailable), nil
		}))

		req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "http://example.com", nil)
		resp, err := rt.RoundTrip(req)
		if err != context.Canceled {
			t.Fatalf("expected context.Canceled, got resp=%v err=%v", resp, err)
		}
	})

	t.Run("respects maxAttempts", func(t *testing.T) {
		var calls atomic.Int32
		rt := Transport(
			WithMaxAttempts(2),
			WithBackoff(ConstantBackoff(time.Millisecond)),
		)(mockTransport(func(req *http.Request) (*http.Response, error) {
			calls.Add(1)
			return statusResponse(http.StatusBadGateway), nil
		}))

		req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
		resp, err := rt.RoundTrip(req)
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		// The last attempt's 502 is returned once the budget is exhausted.
		if resp.StatusCode != http.StatusBadGateway {
			t.Fatalf("expected 502, got %d", resp.StatusCode)
		}
		if got := calls.Load(); got != 2 {
			t.Fatalf("expected 2 calls (maxAttempts=2), got %d", got)
		}
	})

	t.Run("body is restored via GetBody on retry", func(t *testing.T) {
		var calls atomic.Int32
		var bodies []string

		rt := Transport(
			WithMaxAttempts(3),
			WithBackoff(ConstantBackoff(time.Millisecond)),
		)(mockTransport(func(req *http.Request) (*http.Response, error) {
			// Consume the body on every attempt; a successful retry must
			// therefore re-create it via GetBody.
			calls.Add(1)
			b, _ := io.ReadAll(req.Body)
			bodies = append(bodies, string(b))
			if len(bodies) < 2 {
				return statusResponse(http.StatusServiceUnavailable), nil
			}
			return okResponse(), nil
		}))

		bodyContent := "request-body"
		body := bytes.NewReader([]byte(bodyContent))
		req, _ := http.NewRequest(http.MethodPut, "http://example.com", body)
		req.GetBody = func() (io.ReadCloser, error) {
			return io.NopCloser(bytes.NewReader([]byte(bodyContent))), nil
		}

		resp, err := rt.RoundTrip(req)
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if resp.StatusCode != http.StatusOK {
			t.Fatalf("expected 200, got %d", resp.StatusCode)
		}
		if got := calls.Load(); got != 2 {
			t.Fatalf("expected 2 calls, got %d", got)
		}
		// Every attempt must have seen the full, identical body.
		for i, b := range bodies {
			if b != bodyContent {
				t.Fatalf("attempt %d: expected body %q, got %q", i, bodyContent, b)
			}
		}
	})

	t.Run("custom policy", func(t *testing.T) {
		var calls atomic.Int32

		// Custom policy: retry only on 418
		custom := policyFunc(func(attempt int, req *http.Request, resp *http.Response, err error) (bool, time.Duration) {
			if resp != nil && resp.StatusCode == http.StatusTeapot {
				return true, 0
			}
			return false, 0
		})

		rt := Transport(
			WithMaxAttempts(3),
			WithBackoff(ConstantBackoff(time.Millisecond)),
			WithPolicy(custom),
		)(mockTransport(func(req *http.Request) (*http.Response, error) {
			n := calls.Add(1)
			if n == 1 {
				return statusResponse(http.StatusTeapot), nil
			}
			return okResponse(), nil
		}))

		// POST is retried here because the custom policy replaces the
		// default idempotency check entirely.
		req, _ := http.NewRequest(http.MethodPost, "http://example.com", nil)
		resp, err := rt.RoundTrip(req)
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if resp.StatusCode != http.StatusOK {
			t.Fatalf("expected 200, got %d", resp.StatusCode)
		}
		if got := calls.Load(); got != 2 {
			t.Fatalf("expected 2 calls, got %d", got)
		}
	})
}

// policyFunc adapts a function into a Policy.
+type policyFunc func(int, *http.Request, *http.Response, error) (bool, time.Duration) + +func (f policyFunc) ShouldRetry(attempt int, req *http.Request, resp *http.Response, err error) (bool, time.Duration) { + return f(attempt, req, resp, err) +}