Add retry transport with configurable backoff and Retry-After support

Implements retry middleware as a RoundTripper wrapper:
- Exponential and constant backoff strategies with jitter
- RFC 7231 Retry-After header parsing (seconds and HTTP-date)
- Default policy retries idempotent methods on 429/5xx and network errors
- Body restoration via GetBody, context cancellation, response body cleanup
This commit is contained in:
2026-03-20 14:21:53 +03:00
parent 6b1941fce7
commit 505c7b8c4f
7 changed files with 660 additions and 0 deletions

237
retry/retry_test.go Normal file
View File

@@ -0,0 +1,237 @@
package retry
import (
"bytes"
"context"
"io"
"net/http"
"strings"
"sync/atomic"
"testing"
"time"
"git.codelab.vc/pkg/httpx/middleware"
)
func mockTransport(fn func(*http.Request) (*http.Response, error)) http.RoundTripper {
return middleware.RoundTripperFunc(fn)
}
func okResponse() *http.Response {
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader("")),
Header: make(http.Header),
}
}
func statusResponse(code int) *http.Response {
return &http.Response{
StatusCode: code,
Body: io.NopCloser(strings.NewReader("")),
Header: make(http.Header),
}
}
func TestTransport(t *testing.T) {
t.Run("successful request no retry", func(t *testing.T) {
var calls atomic.Int32
rt := Transport(
WithMaxAttempts(3),
WithBackoff(ConstantBackoff(time.Millisecond)),
)(mockTransport(func(req *http.Request) (*http.Response, error) {
calls.Add(1)
return okResponse(), nil
}))
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
resp, err := rt.RoundTrip(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected 200, got %d", resp.StatusCode)
}
if got := calls.Load(); got != 1 {
t.Fatalf("expected 1 call, got %d", got)
}
})
t.Run("retries on 503 then succeeds", func(t *testing.T) {
var calls atomic.Int32
rt := Transport(
WithMaxAttempts(3),
WithBackoff(ConstantBackoff(time.Millisecond)),
)(mockTransport(func(req *http.Request) (*http.Response, error) {
n := calls.Add(1)
if n < 3 {
return statusResponse(http.StatusServiceUnavailable), nil
}
return okResponse(), nil
}))
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
resp, err := rt.RoundTrip(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected 200, got %d", resp.StatusCode)
}
if got := calls.Load(); got != 3 {
t.Fatalf("expected 3 calls, got %d", got)
}
})
t.Run("does not retry non-idempotent POST", func(t *testing.T) {
var calls atomic.Int32
rt := Transport(
WithMaxAttempts(3),
WithBackoff(ConstantBackoff(time.Millisecond)),
)(mockTransport(func(req *http.Request) (*http.Response, error) {
calls.Add(1)
return statusResponse(http.StatusServiceUnavailable), nil
}))
req, _ := http.NewRequest(http.MethodPost, "http://example.com", strings.NewReader("data"))
resp, err := rt.RoundTrip(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.StatusCode != http.StatusServiceUnavailable {
t.Fatalf("expected 503, got %d", resp.StatusCode)
}
if got := calls.Load(); got != 1 {
t.Fatalf("expected 1 call (no retry for POST), got %d", got)
}
})
t.Run("stops on context cancellation", func(t *testing.T) {
var calls atomic.Int32
ctx, cancel := context.WithCancel(context.Background())
rt := Transport(
WithMaxAttempts(5),
WithBackoff(ConstantBackoff(50*time.Millisecond)),
)(mockTransport(func(req *http.Request) (*http.Response, error) {
n := calls.Add(1)
if n == 1 {
cancel()
}
return statusResponse(http.StatusServiceUnavailable), nil
}))
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "http://example.com", nil)
resp, err := rt.RoundTrip(req)
if err != context.Canceled {
t.Fatalf("expected context.Canceled, got resp=%v err=%v", resp, err)
}
})
t.Run("respects maxAttempts", func(t *testing.T) {
var calls atomic.Int32
rt := Transport(
WithMaxAttempts(2),
WithBackoff(ConstantBackoff(time.Millisecond)),
)(mockTransport(func(req *http.Request) (*http.Response, error) {
calls.Add(1)
return statusResponse(http.StatusBadGateway), nil
}))
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
resp, err := rt.RoundTrip(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.StatusCode != http.StatusBadGateway {
t.Fatalf("expected 502, got %d", resp.StatusCode)
}
if got := calls.Load(); got != 2 {
t.Fatalf("expected 2 calls (maxAttempts=2), got %d", got)
}
})
t.Run("body is restored via GetBody on retry", func(t *testing.T) {
var calls atomic.Int32
var bodies []string
rt := Transport(
WithMaxAttempts(3),
WithBackoff(ConstantBackoff(time.Millisecond)),
)(mockTransport(func(req *http.Request) (*http.Response, error) {
calls.Add(1)
b, _ := io.ReadAll(req.Body)
bodies = append(bodies, string(b))
if len(bodies) < 2 {
return statusResponse(http.StatusServiceUnavailable), nil
}
return okResponse(), nil
}))
bodyContent := "request-body"
body := bytes.NewReader([]byte(bodyContent))
req, _ := http.NewRequest(http.MethodPut, "http://example.com", body)
req.GetBody = func() (io.ReadCloser, error) {
return io.NopCloser(bytes.NewReader([]byte(bodyContent))), nil
}
resp, err := rt.RoundTrip(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected 200, got %d", resp.StatusCode)
}
if got := calls.Load(); got != 2 {
t.Fatalf("expected 2 calls, got %d", got)
}
for i, b := range bodies {
if b != bodyContent {
t.Fatalf("attempt %d: expected body %q, got %q", i, bodyContent, b)
}
}
})
t.Run("custom policy", func(t *testing.T) {
var calls atomic.Int32
// Custom policy: retry only on 418
custom := policyFunc(func(attempt int, req *http.Request, resp *http.Response, err error) (bool, time.Duration) {
if resp != nil && resp.StatusCode == http.StatusTeapot {
return true, 0
}
return false, 0
})
rt := Transport(
WithMaxAttempts(3),
WithBackoff(ConstantBackoff(time.Millisecond)),
WithPolicy(custom),
)(mockTransport(func(req *http.Request) (*http.Response, error) {
n := calls.Add(1)
if n == 1 {
return statusResponse(http.StatusTeapot), nil
}
return okResponse(), nil
}))
req, _ := http.NewRequest(http.MethodPost, "http://example.com", nil)
resp, err := rt.RoundTrip(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected 200, got %d", resp.StatusCode)
}
if got := calls.Load(); got != 2 {
t.Fatalf("expected 2 calls, got %d", got)
}
})
}
// policyFunc adapts a function into a Policy.
type policyFunc func(int, *http.Request, *http.Response, error) (bool, time.Duration)
func (f policyFunc) ShouldRetry(attempt int, req *http.Request, resp *http.Response, err error) (bool, time.Duration) {
return f(attempt, req, resp, err)
}