Implements retry middleware as an http.RoundTripper wrapper:
- Exponential and constant backoff strategies with jitter
- RFC 7231 Retry-After header parsing (delay-seconds and HTTP-date forms)
- A default policy that retries idempotent methods on 429/5xx responses and network errors
- Request body restoration via GetBody, context cancellation, and response body cleanup
238 lines
6.5 KiB
Go
238 lines
6.5 KiB
Go
package retry
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"git.codelab.vc/pkg/httpx/middleware"
|
|
)
|
|
|
|
func mockTransport(fn func(*http.Request) (*http.Response, error)) http.RoundTripper {
|
|
return middleware.RoundTripperFunc(fn)
|
|
}
|
|
|
|
// okResponse returns a freshly allocated 200 OK reply whose body yields no
// bytes and whose header map is empty but non-nil.
func okResponse() *http.Response {
	emptyBody := io.NopCloser(strings.NewReader(""))
	return &http.Response{
		Header:     http.Header{},
		Body:       emptyBody,
		StatusCode: http.StatusOK,
	}
}
|
|
|
|
// statusResponse returns a minimal reply carrying the given status code, an
// empty readable body, and an initialized (empty) header map.
func statusResponse(code int) *http.Response {
	resp := &http.Response{StatusCode: code}
	resp.Header = http.Header{}
	resp.Body = io.NopCloser(strings.NewReader(""))
	return resp
}
|
|
|
|
func TestTransport(t *testing.T) {
|
|
t.Run("successful request no retry", func(t *testing.T) {
|
|
var calls atomic.Int32
|
|
rt := Transport(
|
|
WithMaxAttempts(3),
|
|
WithBackoff(ConstantBackoff(time.Millisecond)),
|
|
)(mockTransport(func(req *http.Request) (*http.Response, error) {
|
|
calls.Add(1)
|
|
return okResponse(), nil
|
|
}))
|
|
|
|
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
|
resp, err := rt.RoundTrip(req)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if resp.StatusCode != http.StatusOK {
|
|
t.Fatalf("expected 200, got %d", resp.StatusCode)
|
|
}
|
|
if got := calls.Load(); got != 1 {
|
|
t.Fatalf("expected 1 call, got %d", got)
|
|
}
|
|
})
|
|
|
|
t.Run("retries on 503 then succeeds", func(t *testing.T) {
|
|
var calls atomic.Int32
|
|
rt := Transport(
|
|
WithMaxAttempts(3),
|
|
WithBackoff(ConstantBackoff(time.Millisecond)),
|
|
)(mockTransport(func(req *http.Request) (*http.Response, error) {
|
|
n := calls.Add(1)
|
|
if n < 3 {
|
|
return statusResponse(http.StatusServiceUnavailable), nil
|
|
}
|
|
return okResponse(), nil
|
|
}))
|
|
|
|
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
|
resp, err := rt.RoundTrip(req)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if resp.StatusCode != http.StatusOK {
|
|
t.Fatalf("expected 200, got %d", resp.StatusCode)
|
|
}
|
|
if got := calls.Load(); got != 3 {
|
|
t.Fatalf("expected 3 calls, got %d", got)
|
|
}
|
|
})
|
|
|
|
t.Run("does not retry non-idempotent POST", func(t *testing.T) {
|
|
var calls atomic.Int32
|
|
rt := Transport(
|
|
WithMaxAttempts(3),
|
|
WithBackoff(ConstantBackoff(time.Millisecond)),
|
|
)(mockTransport(func(req *http.Request) (*http.Response, error) {
|
|
calls.Add(1)
|
|
return statusResponse(http.StatusServiceUnavailable), nil
|
|
}))
|
|
|
|
req, _ := http.NewRequest(http.MethodPost, "http://example.com", strings.NewReader("data"))
|
|
resp, err := rt.RoundTrip(req)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if resp.StatusCode != http.StatusServiceUnavailable {
|
|
t.Fatalf("expected 503, got %d", resp.StatusCode)
|
|
}
|
|
if got := calls.Load(); got != 1 {
|
|
t.Fatalf("expected 1 call (no retry for POST), got %d", got)
|
|
}
|
|
})
|
|
|
|
t.Run("stops on context cancellation", func(t *testing.T) {
|
|
var calls atomic.Int32
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
rt := Transport(
|
|
WithMaxAttempts(5),
|
|
WithBackoff(ConstantBackoff(50*time.Millisecond)),
|
|
)(mockTransport(func(req *http.Request) (*http.Response, error) {
|
|
n := calls.Add(1)
|
|
if n == 1 {
|
|
cancel()
|
|
}
|
|
return statusResponse(http.StatusServiceUnavailable), nil
|
|
}))
|
|
|
|
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "http://example.com", nil)
|
|
resp, err := rt.RoundTrip(req)
|
|
if err != context.Canceled {
|
|
t.Fatalf("expected context.Canceled, got resp=%v err=%v", resp, err)
|
|
}
|
|
})
|
|
|
|
t.Run("respects maxAttempts", func(t *testing.T) {
|
|
var calls atomic.Int32
|
|
rt := Transport(
|
|
WithMaxAttempts(2),
|
|
WithBackoff(ConstantBackoff(time.Millisecond)),
|
|
)(mockTransport(func(req *http.Request) (*http.Response, error) {
|
|
calls.Add(1)
|
|
return statusResponse(http.StatusBadGateway), nil
|
|
}))
|
|
|
|
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
|
resp, err := rt.RoundTrip(req)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if resp.StatusCode != http.StatusBadGateway {
|
|
t.Fatalf("expected 502, got %d", resp.StatusCode)
|
|
}
|
|
if got := calls.Load(); got != 2 {
|
|
t.Fatalf("expected 2 calls (maxAttempts=2), got %d", got)
|
|
}
|
|
})
|
|
|
|
t.Run("body is restored via GetBody on retry", func(t *testing.T) {
|
|
var calls atomic.Int32
|
|
var bodies []string
|
|
|
|
rt := Transport(
|
|
WithMaxAttempts(3),
|
|
WithBackoff(ConstantBackoff(time.Millisecond)),
|
|
)(mockTransport(func(req *http.Request) (*http.Response, error) {
|
|
calls.Add(1)
|
|
b, _ := io.ReadAll(req.Body)
|
|
bodies = append(bodies, string(b))
|
|
if len(bodies) < 2 {
|
|
return statusResponse(http.StatusServiceUnavailable), nil
|
|
}
|
|
return okResponse(), nil
|
|
}))
|
|
|
|
bodyContent := "request-body"
|
|
body := bytes.NewReader([]byte(bodyContent))
|
|
req, _ := http.NewRequest(http.MethodPut, "http://example.com", body)
|
|
req.GetBody = func() (io.ReadCloser, error) {
|
|
return io.NopCloser(bytes.NewReader([]byte(bodyContent))), nil
|
|
}
|
|
|
|
resp, err := rt.RoundTrip(req)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if resp.StatusCode != http.StatusOK {
|
|
t.Fatalf("expected 200, got %d", resp.StatusCode)
|
|
}
|
|
if got := calls.Load(); got != 2 {
|
|
t.Fatalf("expected 2 calls, got %d", got)
|
|
}
|
|
for i, b := range bodies {
|
|
if b != bodyContent {
|
|
t.Fatalf("attempt %d: expected body %q, got %q", i, bodyContent, b)
|
|
}
|
|
}
|
|
})
|
|
|
|
t.Run("custom policy", func(t *testing.T) {
|
|
var calls atomic.Int32
|
|
|
|
// Custom policy: retry only on 418
|
|
custom := policyFunc(func(attempt int, req *http.Request, resp *http.Response, err error) (bool, time.Duration) {
|
|
if resp != nil && resp.StatusCode == http.StatusTeapot {
|
|
return true, 0
|
|
}
|
|
return false, 0
|
|
})
|
|
|
|
rt := Transport(
|
|
WithMaxAttempts(3),
|
|
WithBackoff(ConstantBackoff(time.Millisecond)),
|
|
WithPolicy(custom),
|
|
)(mockTransport(func(req *http.Request) (*http.Response, error) {
|
|
n := calls.Add(1)
|
|
if n == 1 {
|
|
return statusResponse(http.StatusTeapot), nil
|
|
}
|
|
return okResponse(), nil
|
|
}))
|
|
|
|
req, _ := http.NewRequest(http.MethodPost, "http://example.com", nil)
|
|
resp, err := rt.RoundTrip(req)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if resp.StatusCode != http.StatusOK {
|
|
t.Fatalf("expected 200, got %d", resp.StatusCode)
|
|
}
|
|
if got := calls.Load(); got != 2 {
|
|
t.Fatalf("expected 2 calls, got %d", got)
|
|
}
|
|
})
|
|
}
|
|
|
|
// policyFunc lets an ordinary function serve as a retry Policy, so subtests
// can declare ad-hoc retry rules inline.
type policyFunc func(attempt int, req *http.Request, resp *http.Response, err error) (bool, time.Duration)

// ShouldRetry satisfies Policy by forwarding every argument to the wrapped
// function and returning its decision and delay unchanged.
func (fn policyFunc) ShouldRetry(attempt int, req *http.Request, resp *http.Response, err error) (bool, time.Duration) {
	retry, delay := fn(attempt, req, resp, err)
	return retry, delay
}
|