Gate the retry decision on body rewindability: an idempotent request whose body cannot be replayed (no GetBody) is now returned as-is instead of looping with an empty body or surfacing a stale, already-drained response. Guard ExponentialBackoff against rand.Int64N panicking when delay/2 rounds to zero. Use internal/clock for inter-attempt delays so retry timing is consistent with the rest of the codebase and testable without real sleeps.
316 lines
8.9 KiB
Go
316 lines
8.9 KiB
Go
package retry
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"git.codelab.vc/pkg/httpx/internal/clock"
|
|
"git.codelab.vc/pkg/httpx/middleware"
|
|
)
|
|
|
|
func mockTransport(fn func(*http.Request) (*http.Response, error)) http.RoundTripper {
|
|
return middleware.RoundTripperFunc(fn)
|
|
}
|
|
|
|
func okResponse() *http.Response {
|
|
return &http.Response{
|
|
StatusCode: http.StatusOK,
|
|
Body: io.NopCloser(strings.NewReader("")),
|
|
Header: make(http.Header),
|
|
}
|
|
}
|
|
|
|
func statusResponse(code int) *http.Response {
|
|
return &http.Response{
|
|
StatusCode: code,
|
|
Body: io.NopCloser(strings.NewReader("")),
|
|
Header: make(http.Header),
|
|
}
|
|
}
|
|
|
|
func TestTransport(t *testing.T) {
|
|
t.Run("successful request no retry", func(t *testing.T) {
|
|
var calls atomic.Int32
|
|
rt := Transport(
|
|
WithMaxAttempts(3),
|
|
WithBackoff(ConstantBackoff(time.Millisecond)),
|
|
)(mockTransport(func(req *http.Request) (*http.Response, error) {
|
|
calls.Add(1)
|
|
return okResponse(), nil
|
|
}))
|
|
|
|
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
|
resp, err := rt.RoundTrip(req)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if resp.StatusCode != http.StatusOK {
|
|
t.Fatalf("expected 200, got %d", resp.StatusCode)
|
|
}
|
|
if got := calls.Load(); got != 1 {
|
|
t.Fatalf("expected 1 call, got %d", got)
|
|
}
|
|
})
|
|
|
|
t.Run("retries on 503 then succeeds", func(t *testing.T) {
|
|
var calls atomic.Int32
|
|
rt := Transport(
|
|
WithMaxAttempts(3),
|
|
WithBackoff(ConstantBackoff(time.Millisecond)),
|
|
)(mockTransport(func(req *http.Request) (*http.Response, error) {
|
|
n := calls.Add(1)
|
|
if n < 3 {
|
|
return statusResponse(http.StatusServiceUnavailable), nil
|
|
}
|
|
return okResponse(), nil
|
|
}))
|
|
|
|
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
|
resp, err := rt.RoundTrip(req)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if resp.StatusCode != http.StatusOK {
|
|
t.Fatalf("expected 200, got %d", resp.StatusCode)
|
|
}
|
|
if got := calls.Load(); got != 3 {
|
|
t.Fatalf("expected 3 calls, got %d", got)
|
|
}
|
|
})
|
|
|
|
t.Run("does not retry non-idempotent POST", func(t *testing.T) {
|
|
var calls atomic.Int32
|
|
rt := Transport(
|
|
WithMaxAttempts(3),
|
|
WithBackoff(ConstantBackoff(time.Millisecond)),
|
|
)(mockTransport(func(req *http.Request) (*http.Response, error) {
|
|
calls.Add(1)
|
|
return statusResponse(http.StatusServiceUnavailable), nil
|
|
}))
|
|
|
|
req, _ := http.NewRequest(http.MethodPost, "http://example.com", strings.NewReader("data"))
|
|
resp, err := rt.RoundTrip(req)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if resp.StatusCode != http.StatusServiceUnavailable {
|
|
t.Fatalf("expected 503, got %d", resp.StatusCode)
|
|
}
|
|
if got := calls.Load(); got != 1 {
|
|
t.Fatalf("expected 1 call (no retry for POST), got %d", got)
|
|
}
|
|
})
|
|
|
|
t.Run("stops on context cancellation", func(t *testing.T) {
|
|
var calls atomic.Int32
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
rt := Transport(
|
|
WithMaxAttempts(5),
|
|
WithBackoff(ConstantBackoff(50*time.Millisecond)),
|
|
)(mockTransport(func(req *http.Request) (*http.Response, error) {
|
|
n := calls.Add(1)
|
|
if n == 1 {
|
|
cancel()
|
|
}
|
|
return statusResponse(http.StatusServiceUnavailable), nil
|
|
}))
|
|
|
|
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "http://example.com", nil)
|
|
resp, err := rt.RoundTrip(req)
|
|
if err != context.Canceled {
|
|
t.Fatalf("expected context.Canceled, got resp=%v err=%v", resp, err)
|
|
}
|
|
})
|
|
|
|
t.Run("respects maxAttempts", func(t *testing.T) {
|
|
var calls atomic.Int32
|
|
rt := Transport(
|
|
WithMaxAttempts(2),
|
|
WithBackoff(ConstantBackoff(time.Millisecond)),
|
|
)(mockTransport(func(req *http.Request) (*http.Response, error) {
|
|
calls.Add(1)
|
|
return statusResponse(http.StatusBadGateway), nil
|
|
}))
|
|
|
|
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
|
resp, err := rt.RoundTrip(req)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if resp.StatusCode != http.StatusBadGateway {
|
|
t.Fatalf("expected 502, got %d", resp.StatusCode)
|
|
}
|
|
if got := calls.Load(); got != 2 {
|
|
t.Fatalf("expected 2 calls (maxAttempts=2), got %d", got)
|
|
}
|
|
})
|
|
|
|
t.Run("body is restored via GetBody on retry", func(t *testing.T) {
|
|
var calls atomic.Int32
|
|
var bodies []string
|
|
|
|
rt := Transport(
|
|
WithMaxAttempts(3),
|
|
WithBackoff(ConstantBackoff(time.Millisecond)),
|
|
)(mockTransport(func(req *http.Request) (*http.Response, error) {
|
|
calls.Add(1)
|
|
b, _ := io.ReadAll(req.Body)
|
|
bodies = append(bodies, string(b))
|
|
if len(bodies) < 2 {
|
|
return statusResponse(http.StatusServiceUnavailable), nil
|
|
}
|
|
return okResponse(), nil
|
|
}))
|
|
|
|
bodyContent := "request-body"
|
|
body := bytes.NewReader([]byte(bodyContent))
|
|
req, _ := http.NewRequest(http.MethodPut, "http://example.com", body)
|
|
req.GetBody = func() (io.ReadCloser, error) {
|
|
return io.NopCloser(bytes.NewReader([]byte(bodyContent))), nil
|
|
}
|
|
|
|
resp, err := rt.RoundTrip(req)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if resp.StatusCode != http.StatusOK {
|
|
t.Fatalf("expected 200, got %d", resp.StatusCode)
|
|
}
|
|
if got := calls.Load(); got != 2 {
|
|
t.Fatalf("expected 2 calls, got %d", got)
|
|
}
|
|
for i, b := range bodies {
|
|
if b != bodyContent {
|
|
t.Fatalf("attempt %d: expected body %q, got %q", i, bodyContent, b)
|
|
}
|
|
}
|
|
})
|
|
|
|
t.Run("custom policy", func(t *testing.T) {
|
|
var calls atomic.Int32
|
|
|
|
// Custom policy: retry only on 418
|
|
custom := policyFunc(func(attempt int, req *http.Request, resp *http.Response, err error) (bool, time.Duration) {
|
|
if resp != nil && resp.StatusCode == http.StatusTeapot {
|
|
return true, 0
|
|
}
|
|
return false, 0
|
|
})
|
|
|
|
rt := Transport(
|
|
WithMaxAttempts(3),
|
|
WithBackoff(ConstantBackoff(time.Millisecond)),
|
|
WithPolicy(custom),
|
|
)(mockTransport(func(req *http.Request) (*http.Response, error) {
|
|
n := calls.Add(1)
|
|
if n == 1 {
|
|
return statusResponse(http.StatusTeapot), nil
|
|
}
|
|
return okResponse(), nil
|
|
}))
|
|
|
|
req, _ := http.NewRequest(http.MethodPost, "http://example.com", nil)
|
|
resp, err := rt.RoundTrip(req)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if resp.StatusCode != http.StatusOK {
|
|
t.Fatalf("expected 200, got %d", resp.StatusCode)
|
|
}
|
|
if got := calls.Load(); got != 2 {
|
|
t.Fatalf("expected 2 calls, got %d", got)
|
|
}
|
|
})
|
|
}
|
|
|
|
// TestTransport_BodyNotRewindable verifies that an idempotent request whose
|
|
// body cannot be replayed (no GetBody) is returned as-is rather than retried
|
|
// with an empty body or a stale, already-drained response.
|
|
func TestTransport_BodyNotRewindable(t *testing.T) {
|
|
var calls atomic.Int32
|
|
rt := Transport(
|
|
WithMaxAttempts(3),
|
|
WithBackoff(ConstantBackoff(time.Millisecond)),
|
|
)(mockTransport(func(req *http.Request) (*http.Response, error) {
|
|
calls.Add(1)
|
|
io.Copy(io.Discard, req.Body) // a real transport consumes the body
|
|
return statusResponse(http.StatusServiceUnavailable), nil
|
|
}))
|
|
|
|
// PUT is idempotent (the policy would retry a 503), but with GetBody unset
|
|
// the body cannot be rewound.
|
|
req, _ := http.NewRequest(http.MethodPut, "http://example.com", strings.NewReader("data"))
|
|
req.GetBody = nil
|
|
|
|
resp, err := rt.RoundTrip(req)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if resp == nil || resp.StatusCode != http.StatusServiceUnavailable {
|
|
t.Fatalf("expected the original 503 response, got %v", resp)
|
|
}
|
|
if got := calls.Load(); got != 1 {
|
|
t.Fatalf("expected exactly 1 call (no rewind retry), got %d", got)
|
|
}
|
|
}
|
|
|
|
// TestTransport_InjectedClock verifies that backoff delays are driven by the
|
|
// configured clock, so retries are deterministic without real sleeps.
|
|
func TestTransport_InjectedClock(t *testing.T) {
|
|
clk := clock.Mock(time.Now())
|
|
var calls atomic.Int32
|
|
rt := Transport(
|
|
WithMaxAttempts(2),
|
|
WithBackoff(ConstantBackoff(time.Hour)), // would block forever on a real clock
|
|
withClock(clk),
|
|
)(mockTransport(func(req *http.Request) (*http.Response, error) {
|
|
calls.Add(1)
|
|
return statusResponse(http.StatusServiceUnavailable), nil
|
|
}))
|
|
|
|
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
|
|
|
done := make(chan struct{})
|
|
var resp *http.Response
|
|
var err error
|
|
go func() {
|
|
resp, err = rt.RoundTrip(req)
|
|
close(done)
|
|
}()
|
|
|
|
// Drive the backoff via the mock clock. Advancing repeatedly is robust
|
|
// against the timer being created slightly after the first attempt.
|
|
for {
|
|
clk.Advance(time.Hour)
|
|
select {
|
|
case <-done:
|
|
goto finished
|
|
case <-time.After(time.Millisecond):
|
|
}
|
|
}
|
|
finished:
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if resp.StatusCode != http.StatusServiceUnavailable {
|
|
t.Fatalf("expected 503, got %d", resp.StatusCode)
|
|
}
|
|
if got := calls.Load(); got != 2 {
|
|
t.Fatalf("expected 2 calls, got %d", got)
|
|
}
|
|
}
|
|
|
|
// policyFunc adapts a function into a Policy.
|
|
type policyFunc func(int, *http.Request, *http.Response, error) (bool, time.Duration)
|
|
|
|
func (f policyFunc) ShouldRetry(attempt int, req *http.Request, resp *http.Response, err error) (bool, time.Duration) {
|
|
return f(attempt, req, resp, err)
|
|
}
|