package retry

import (
	"bytes"
	"context"
	"errors"
	"io"
	"net/http"
	"strings"
	"sync/atomic"
	"testing"
	"time"

	"git.codelab.vc/pkg/httpx/middleware"
)

// mockTransport adapts fn into an http.RoundTripper so a test can script
// exactly what the wrapped (inner) transport returns on each attempt.
func mockTransport(fn func(*http.Request) (*http.Response, error)) http.RoundTripper {
	return middleware.RoundTripperFunc(fn)
}

// okResponse returns a 200 response with an empty body and empty headers.
func okResponse() *http.Response {
	return statusResponse(http.StatusOK)
}

// statusResponse returns a response with the given status code, an empty
// in-memory body, and a non-nil empty header map.
func statusResponse(code int) *http.Response {
	return &http.Response{
		StatusCode: code,
		Body:       io.NopCloser(strings.NewReader("")),
		Header:     make(http.Header),
	}
}

// newRequest builds a request to a fixed URL. It fails the test immediately
// if construction fails, rather than continuing with a nil request.
func newRequest(t *testing.T, method string, body io.Reader) *http.Request {
	t.Helper()
	req, err := http.NewRequest(method, "http://example.com", body)
	if err != nil {
		t.Fatalf("building %s request: %v", method, err)
	}
	return req
}

// roundTrip sends req through rt. It fails the test on a transport error or
// a nil response, and registers cleanup that closes the response body.
func roundTrip(t *testing.T, rt http.RoundTripper, req *http.Request) *http.Response {
	t.Helper()
	resp, err := rt.RoundTrip(req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if resp == nil {
		t.Fatal("unexpected nil response with nil error")
	}
	// Test bodies are in-memory, so the close error carries no information.
	t.Cleanup(func() { _ = resp.Body.Close() })
	return resp
}

// TestTransport exercises the retry transport against a scripted inner
// transport. It covers success, retry-until-success, method idempotency,
// context cancellation, the attempt cap, body replay via GetBody, and custom
// policies.
func TestTransport(t *testing.T) {
	t.Run("successful request no retry", func(t *testing.T) {
		var calls atomic.Int32
		rt := Transport(
			WithMaxAttempts(3),
			WithBackoff(ConstantBackoff(time.Millisecond)),
		)(mockTransport(func(*http.Request) (*http.Response, error) {
			calls.Add(1)
			return okResponse(), nil
		}))

		resp := roundTrip(t, rt, newRequest(t, http.MethodGet, nil))
		if resp.StatusCode != http.StatusOK {
			t.Fatalf("expected 200, got %d", resp.StatusCode)
		}
		if got := calls.Load(); got != 1 {
			t.Fatalf("expected 1 call, got %d", got)
		}
	})

	t.Run("retries on 503 then succeeds", func(t *testing.T) {
		var calls atomic.Int32
		rt := Transport(
			WithMaxAttempts(3),
			WithBackoff(ConstantBackoff(time.Millisecond)),
		)(mockTransport(func(*http.Request) (*http.Response, error) {
			if calls.Add(1) < 3 {
				return statusResponse(http.StatusServiceUnavailable), nil
			}
			return okResponse(), nil
		}))

		resp := roundTrip(t, rt, newRequest(t, http.MethodGet, nil))
		if resp.StatusCode != http.StatusOK {
			t.Fatalf("expected 200, got %d", resp.StatusCode)
		}
		if got := calls.Load(); got != 3 {
			t.Fatalf("expected 3 calls, got %d", got)
		}
	})

	t.Run("does not retry non-idempotent POST", func(t *testing.T) {
		var calls atomic.Int32
		rt := Transport(
			WithMaxAttempts(3),
			WithBackoff(ConstantBackoff(time.Millisecond)),
		)(mockTransport(func(*http.Request) (*http.Response, error) {
			calls.Add(1)
			return statusResponse(http.StatusServiceUnavailable), nil
		}))

		resp := roundTrip(t, rt, newRequest(t, http.MethodPost, strings.NewReader("data")))
		if resp.StatusCode != http.StatusServiceUnavailable {
			t.Fatalf("expected 503, got %d", resp.StatusCode)
		}
		if got := calls.Load(); got != 1 {
			t.Fatalf("expected 1 call (no retry for POST), got %d", got)
		}
	})

	t.Run("stops on context cancellation", func(t *testing.T) {
		var calls atomic.Int32
		ctx, cancel := context.WithCancel(context.Background())
		defer cancel() // cancel is idempotent; guarantees release on every path
		rt := Transport(
			WithMaxAttempts(5),
			WithBackoff(ConstantBackoff(50*time.Millisecond)),
		)(mockTransport(func(*http.Request) (*http.Response, error) {
			// Cancel during the first attempt so the retry loop sees a done
			// context before it would start the second attempt.
			if calls.Add(1) == 1 {
				cancel()
			}
			return statusResponse(http.StatusServiceUnavailable), nil
		}))

		req := newRequest(t, http.MethodGet, nil).WithContext(ctx)
		resp, err := rt.RoundTrip(req)
		// errors.Is tolerates the transport wrapping the context error.
		if !errors.Is(err, context.Canceled) {
			t.Fatalf("expected context.Canceled, got resp=%v err=%v", resp, err)
		}
		if got := calls.Load(); got != 1 {
			t.Fatalf("expected retries to stop after 1 call, got %d", got)
		}
	})

	t.Run("respects maxAttempts", func(t *testing.T) {
		var calls atomic.Int32
		rt := Transport(
			WithMaxAttempts(2),
			WithBackoff(ConstantBackoff(time.Millisecond)),
		)(mockTransport(func(*http.Request) (*http.Response, error) {
			calls.Add(1)
			return statusResponse(http.StatusBadGateway), nil
		}))

		resp := roundTrip(t, rt, newRequest(t, http.MethodGet, nil))
		if resp.StatusCode != http.StatusBadGateway {
			t.Fatalf("expected 502, got %d", resp.StatusCode)
		}
		if got := calls.Load(); got != 2 {
			t.Fatalf("expected 2 calls (maxAttempts=2), got %d", got)
		}
	})

	t.Run("body is restored via GetBody on retry", func(t *testing.T) {
		const bodyContent = "request-body"
		var (
			calls  atomic.Int32
			bodies []string
		)
		rt := Transport(
			WithMaxAttempts(3),
			WithBackoff(ConstantBackoff(time.Millisecond)),
		)(mockTransport(func(req *http.Request) (*http.Response, error) {
			n := calls.Add(1)
			b, err := io.ReadAll(req.Body)
			if err != nil {
				t.Errorf("attempt %d: reading request body: %v", n, err)
			}
			bodies = append(bodies, string(b))
			if len(bodies) < 2 {
				return statusResponse(http.StatusServiceUnavailable), nil
			}
			return okResponse(), nil
		}))

		req := newRequest(t, http.MethodPut, bytes.NewReader([]byte(bodyContent)))
		// Override the GetBody that NewRequest installs, so the test controls
		// exactly what a replayed attempt receives.
		req.GetBody = func() (io.ReadCloser, error) {
			return io.NopCloser(bytes.NewReader([]byte(bodyContent))), nil
		}

		resp := roundTrip(t, rt, req)
		if resp.StatusCode != http.StatusOK {
			t.Fatalf("expected 200, got %d", resp.StatusCode)
		}
		if got := calls.Load(); got != 2 {
			t.Fatalf("expected 2 calls, got %d", got)
		}
		for i, b := range bodies {
			if b != bodyContent {
				t.Fatalf("attempt %d: expected body %q, got %q", i, bodyContent, b)
			}
		}
	})

	t.Run("custom policy", func(t *testing.T) {
		var calls atomic.Int32
		// Custom policy: retry only on 418. A POST is used to show that the
		// custom policy, not the method check, decides.
		custom := policyFunc(func(attempt int, req *http.Request, resp *http.Response, err error) (bool, time.Duration) {
			if resp != nil && resp.StatusCode == http.StatusTeapot {
				return true, 0
			}
			return false, 0
		})
		rt := Transport(
			WithMaxAttempts(3),
			WithBackoff(ConstantBackoff(time.Millisecond)),
			WithPolicy(custom),
		)(mockTransport(func(*http.Request) (*http.Response, error) {
			if calls.Add(1) == 1 {
				return statusResponse(http.StatusTeapot), nil
			}
			return okResponse(), nil
		}))

		resp := roundTrip(t, rt, newRequest(t, http.MethodPost, nil))
		if resp.StatusCode != http.StatusOK {
			t.Fatalf("expected 200, got %d", resp.StatusCode)
		}
		if got := calls.Load(); got != 2 {
			t.Fatalf("expected 2 calls, got %d", got)
		}
	})
}

// policyFunc adapts a function into a Policy.
type policyFunc func(attempt int, req *http.Request, resp *http.Response, err error) (bool, time.Duration)

// ShouldRetry forwards every argument to fn unchanged and returns fn's
// retry decision and delay exactly as fn produced them.
func (fn policyFunc) ShouldRetry(attempt int, req *http.Request, resp *http.Response, err error) (bool, time.Duration) {
	again, wait := fn(attempt, req, resp, err)
	return again, wait
}