package retry

import (
	"bytes"
	"context"
	"errors"
	"io"
	"net/http"
	"strings"
	"sync/atomic"
	"testing"
	"time"

	"git.codelab.vc/pkg/httpx/internal/clock"
	"git.codelab.vc/pkg/httpx/middleware"
)

// mockTransport wraps fn as an http.RoundTripper so each test can script the
// behaviour of the transport that sits underneath the retry middleware.
func mockTransport(fn func(*http.Request) (*http.Response, error)) http.RoundTripper {
	return middleware.RoundTripperFunc(fn)
}

// okResponse returns a minimal 200 OK response with an empty body.
func okResponse() *http.Response {
	return statusResponse(http.StatusOK)
}

// statusResponse returns a minimal response carrying the given status code,
// an empty (but non-nil) body and an empty header map.
func statusResponse(code int) *http.Response {
	return &http.Response{
		StatusCode: code,
		Body:       io.NopCloser(strings.NewReader("")),
		Header:     make(http.Header),
	}
}

// newTestRequest builds a request against the placeholder test URL and fails
// the calling test immediately if construction fails, so no test carries on
// with a nil request.
func newTestRequest(t *testing.T, method string, body io.Reader) *http.Request {
	t.Helper()
	req, err := http.NewRequest(method, "http://example.com", body)
	if err != nil {
		t.Fatalf("building %s request: %v", method, err)
	}
	return req
}

// TestTransport exercises the retry transport against a scripted underlying
// transport: the no-retry happy path, retrying on 503, refusing to retry a
// POST, context cancellation, the attempt limit, body replay through GetBody
// and a caller-supplied policy.
func TestTransport(t *testing.T) {
	t.Run("successful request no retry", func(t *testing.T) {
		var calls atomic.Int32
		rt := Transport(
			WithMaxAttempts(3),
			WithBackoff(ConstantBackoff(time.Millisecond)),
		)(mockTransport(func(*http.Request) (*http.Response, error) {
			calls.Add(1)
			return okResponse(), nil
		}))

		resp, err := rt.RoundTrip(newTestRequest(t, http.MethodGet, nil))
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		defer resp.Body.Close()

		if resp.StatusCode != http.StatusOK {
			t.Fatalf("expected 200, got %d", resp.StatusCode)
		}
		if got := calls.Load(); got != 1 {
			t.Fatalf("expected 1 call, got %d", got)
		}
	})

	t.Run("retries on 503 then succeeds", func(t *testing.T) {
		var calls atomic.Int32
		rt := Transport(
			WithMaxAttempts(3),
			WithBackoff(ConstantBackoff(time.Millisecond)),
		)(mockTransport(func(*http.Request) (*http.Response, error) {
			// The first two attempts fail with 503; the third succeeds.
			if calls.Add(1) < 3 {
				return statusResponse(http.StatusServiceUnavailable), nil
			}
			return okResponse(), nil
		}))

		resp, err := rt.RoundTrip(newTestRequest(t, http.MethodGet, nil))
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		defer resp.Body.Close()

		if resp.StatusCode != http.StatusOK {
			t.Fatalf("expected 200, got %d", resp.StatusCode)
		}
		if got := calls.Load(); got != 3 {
			t.Fatalf("expected 3 calls, got %d", got)
		}
	})

	t.Run("does not retry non-idempotent POST", func(t *testing.T) {
		var calls atomic.Int32
		rt := Transport(
			WithMaxAttempts(3),
			WithBackoff(ConstantBackoff(time.Millisecond)),
		)(mockTransport(func(*http.Request) (*http.Response, error) {
			calls.Add(1)
			return statusResponse(http.StatusServiceUnavailable), nil
		}))

		req := newTestRequest(t, http.MethodPost, strings.NewReader("data"))
		resp, err := rt.RoundTrip(req)
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		defer resp.Body.Close()

		// The 503 must be handed back untouched after a single attempt.
		if resp.StatusCode != http.StatusServiceUnavailable {
			t.Fatalf("expected 503, got %d", resp.StatusCode)
		}
		if got := calls.Load(); got != 1 {
			t.Fatalf("expected 1 call (no retry for POST), got %d", got)
		}
	})

	t.Run("stops on context cancellation", func(t *testing.T) {
		var calls atomic.Int32
		ctx, cancel := context.WithCancel(context.Background())
		// Release the context even if the underlying transport is never called.
		defer cancel()

		rt := Transport(
			WithMaxAttempts(5),
			WithBackoff(ConstantBackoff(50*time.Millisecond)),
		)(mockTransport(func(*http.Request) (*http.Response, error) {
			// Cancel the request context during the first attempt; the 503
			// returned below would otherwise be eligible for a retry.
			if calls.Add(1) == 1 {
				cancel()
			}
			return statusResponse(http.StatusServiceUnavailable), nil
		}))

		req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://example.com", nil)
		if err != nil {
			t.Fatalf("building request: %v", err)
		}

		resp, err := rt.RoundTrip(req)
		if resp != nil && resp.Body != nil {
			defer resp.Body.Close()
		}
		// errors.Is keeps the assertion valid if the error is ever wrapped.
		if !errors.Is(err, context.Canceled) {
			t.Fatalf("expected context.Canceled, got resp=%v err=%v", resp, err)
		}
	})

	t.Run("respects maxAttempts", func(t *testing.T) {
		var calls atomic.Int32
		rt := Transport(
			WithMaxAttempts(2),
			WithBackoff(ConstantBackoff(time.Millisecond)),
		)(mockTransport(func(*http.Request) (*http.Response, error) {
			calls.Add(1)
			return statusResponse(http.StatusBadGateway), nil
		}))

		resp, err := rt.RoundTrip(newTestRequest(t, http.MethodGet, nil))
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		defer resp.Body.Close()

		// Once the attempts are exhausted the last response is returned.
		if resp.StatusCode != http.StatusBadGateway {
			t.Fatalf("expected 502, got %d", resp.StatusCode)
		}
		if got := calls.Load(); got != 2 {
			t.Fatalf("expected 2 calls (maxAttempts=2), got %d", got)
		}
	})

	t.Run("body is restored via GetBody on retry", func(t *testing.T) {
		const bodyContent = "request-body"

		var calls atomic.Int32
		var bodies []string // request body observed on each attempt, in order
		rt := Transport(
			WithMaxAttempts(3),
			WithBackoff(ConstantBackoff(time.Millisecond)),
		)(mockTransport(func(req *http.Request) (*http.Response, error) {
			calls.Add(1)
			b, err := io.ReadAll(req.Body)
			if err != nil {
				return nil, err
			}
			bodies = append(bodies, string(b))
			// Fail the first attempt only, forcing exactly one retry.
			if len(bodies) < 2 {
				return statusResponse(http.StatusServiceUnavailable), nil
			}
			return okResponse(), nil
		}))

		req := newTestRequest(t, http.MethodPut, bytes.NewReader([]byte(bodyContent)))
		req.GetBody = func() (io.ReadCloser, error) {
			return io.NopCloser(bytes.NewReader([]byte(bodyContent))), nil
		}

		resp, err := rt.RoundTrip(req)
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		defer resp.Body.Close()

		if resp.StatusCode != http.StatusOK {
			t.Fatalf("expected 200, got %d", resp.StatusCode)
		}
		if got := calls.Load(); got != 2 {
			t.Fatalf("expected 2 calls, got %d", got)
		}
		// Every attempt, including the retry, must have seen the full body.
		for i, b := range bodies {
			if b != bodyContent {
				t.Fatalf("attempt %d: expected body %q, got %q", i, bodyContent, b)
			}
		}
	})

	t.Run("custom policy", func(t *testing.T) {
		var calls atomic.Int32

		// Custom policy: retry only on 418.
		custom := policyFunc(func(attempt int, req *http.Request, resp *http.Response, err error) (bool, time.Duration) {
			if resp != nil && resp.StatusCode == http.StatusTeapot {
				return true, 0
			}
			return false, 0
		})

		rt := Transport(
			WithMaxAttempts(3),
			WithBackoff(ConstantBackoff(time.Millisecond)),
			WithPolicy(custom),
		)(mockTransport(func(*http.Request) (*http.Response, error) {
			if calls.Add(1) == 1 {
				return statusResponse(http.StatusTeapot), nil
			}
			return okResponse(), nil
		}))

		// A POST is used on purpose: the custom policy, not the request
		// method, decides whether the 418 is retried.
		resp, err := rt.RoundTrip(newTestRequest(t, http.MethodPost, nil))
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		defer resp.Body.Close()

		if resp.StatusCode != http.StatusOK {
			t.Fatalf("expected 200, got %d", resp.StatusCode)
		}
		if got := calls.Load(); got != 2 {
			t.Fatalf("expected 2 calls, got %d", got)
		}
	})
}

// TestTransport_BodyNotRewindable verifies that an idempotent request whose
// body cannot be replayed (no GetBody) is returned as-is rather than retried
// with an empty body or a stale, already-drained response.
func TestTransport_BodyNotRewindable(t *testing.T) {
	var calls atomic.Int32
	rt := Transport(
		WithMaxAttempts(3),
		WithBackoff(ConstantBackoff(time.Millisecond)),
	)(mockTransport(func(req *http.Request) (*http.Response, error) {
		calls.Add(1)
		// A real transport consumes the body, so drain it here as well.
		if _, err := io.Copy(io.Discard, req.Body); err != nil {
			return nil, err
		}
		return statusResponse(http.StatusServiceUnavailable), nil
	}))

	// PUT is idempotent (the policy would retry a 503), but with GetBody unset
	// the body cannot be rewound. http.NewRequest populates GetBody for a
	// *strings.Reader, so it has to be cleared explicitly.
	req := newTestRequest(t, http.MethodPut, strings.NewReader("data"))
	req.GetBody = nil

	resp, err := rt.RoundTrip(req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if resp == nil || resp.StatusCode != http.StatusServiceUnavailable {
		t.Fatalf("expected the original 503 response, got %v", resp)
	}
	defer resp.Body.Close()

	if got := calls.Load(); got != 1 {
		t.Fatalf("expected exactly 1 call (no rewind retry), got %d", got)
	}
}

// TestTransport_InjectedClock verifies that backoff delays are driven by the
// configured clock, so retries are deterministic without real sleeps.
func TestTransport_InjectedClock(t *testing.T) {
	// Real-time bound on the wait below, so a regression fails with a clear
	// message instead of hanging until the global test timeout.
	const waitLimit = 5 * time.Second

	clk := clock.Mock(time.Now())

	var calls atomic.Int32
	rt := Transport(
		WithMaxAttempts(2),
		// An hour-long backoff would stall the test on a real clock.
		WithBackoff(ConstantBackoff(time.Hour)),
		withClock(clk),
	)(mockTransport(func(*http.Request) (*http.Response, error) {
		calls.Add(1)
		return statusResponse(http.StatusServiceUnavailable), nil
	}))

	req := newTestRequest(t, http.MethodGet, nil)

	// RoundTrip is expected to wait on the mock clock between attempts, so it
	// runs in its own goroutine. resp and err are read only after done closes.
	var (
		resp *http.Response
		err  error
	)
	done := make(chan struct{})
	go func() {
		defer close(done)
		resp, err = rt.RoundTrip(req)
	}()

	// Drive the backoff via the mock clock. Advancing repeatedly is robust
	// against the timer being created slightly after the first attempt.
	deadline := time.After(waitLimit)
	tick := time.NewTicker(time.Millisecond)
	defer tick.Stop()

wait:
	for {
		clk.Advance(time.Hour)
		select {
		case <-done:
			break wait
		case <-deadline:
			t.Fatalf("RoundTrip did not return within %v of advancing the mock clock", waitLimit)
		case <-tick.C:
		}
	}

	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusServiceUnavailable {
		t.Fatalf("expected 503, got %d", resp.StatusCode)
	}
	if got := calls.Load(); got != 2 {
		t.Fatalf("expected 2 calls, got %d", got)
	}
}

// policyFunc adapts a function into a Policy.
type policyFunc func(int, *http.Request, *http.Response, error) (bool, time.Duration)

// ShouldRetry calls f with the arguments unchanged and returns its result.
func (f policyFunc) ShouldRetry(attempt int, req *http.Request, resp *http.Response, err error) (bool, time.Duration) {
	return f(attempt, req, resp, err)
}