// Package retry provides an HTTP round-tripper middleware that retries
// failed requests according to a pluggable Policy, a backoff strategy,
// and (optionally) the server's Retry-After header.
package retry

import (
	"errors"
	"fmt"
	"io"
	"net/http"
	"time"

	"git.codelab.vc/pkg/httpx/middleware"
)

// ErrRetryExhausted is returned when all retry attempts have been exhausted
// and the last attempt also failed.
var ErrRetryExhausted = errors.New("httpx: all retry attempts exhausted")

// Policy decides whether a failed request should be retried.
type Policy interface {
	// ShouldRetry reports whether the request should be retried. The extra
	// duration, if non-zero, is a policy-suggested delay that overrides the
	// backoff strategy.
	ShouldRetry(attempt int, req *http.Request, resp *http.Response, err error) (bool, time.Duration)
}

// Transport returns a middleware that retries failed requests according to
// the provided options.
//
// A request body is replayed via req.GetBody; if the body was consumed and
// cannot be re-created, retrying stops and the latest response is returned
// intact. The final error is wrapped with ErrRetryExhausted when every
// attempt was used and the last one failed.
func Transport(opts ...Option) middleware.Middleware {
	cfg := defaults()
	for _, o := range opts {
		o(&cfg)
	}
	return func(next http.RoundTripper) http.RoundTripper {
		return middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
			var (
				resp      *http.Response
				err       error
				exhausted bool
			)
			for attempt := range cfg.maxAttempts {
				resp, err = next.RoundTrip(req)

				// Last attempt — return whatever we got.
				if attempt == cfg.maxAttempts-1 {
					exhausted = true
					break
				}

				shouldRetry, policyDelay := cfg.policy.ShouldRetry(attempt, req, resp, err)
				if !shouldRetry {
					break
				}

				// Restore the request body for the next attempt BEFORE the
				// current response is drained, so that a request whose body
				// cannot be replayed hands the caller the response intact
				// (previously a drained, closed body was returned here).
				if req.GetBody != nil {
					body, bodyErr := req.GetBody()
					if bodyErr != nil {
						return resp, bodyErr
					}
					req.Body = body
				} else if req.Body != nil {
					// Body was consumed and cannot be re-created; give up
					// and return the latest result untouched.
					break
				}

				// Compute delay: use backoff or policy delay, whichever is larger.
				delay := cfg.backoff.Delay(attempt)
				if policyDelay > delay {
					delay = policyDelay
				}

				// Respect Retry-After header if enabled.
				if cfg.retryAfter && resp != nil {
					if ra, ok := ParseRetryAfter(resp); ok && ra > delay {
						delay = ra
					}
				}

				// Drain and close the response body so the underlying
				// connection can be reused. Best-effort: a drain error is
				// irrelevant since we are discarding this response anyway.
				if resp != nil {
					_, _ = io.Copy(io.Discard, resp.Body)
					resp.Body.Close()
				}

				// Wait for the delay or context cancellation.
				timer := time.NewTimer(delay)
				select {
				case <-req.Context().Done():
					timer.Stop()
					return nil, req.Context().Err()
				case <-timer.C:
				}
			}

			// Wrap with ErrRetryExhausted only when all attempts were used.
			if exhausted && err != nil {
				err = fmt.Errorf("%w: %w", ErrRetryExhausted, err)
			}
			return resp, err
		})
	}
}

// defaultPolicy retries on network errors, 429 (Too Many Requests), and the
// transient 5xx statuses 502, 503, and 504. It refuses to retry
// non-idempotent methods.
type defaultPolicy struct{}

// ShouldRetry implements Policy. It never suggests an extra delay.
func (defaultPolicy) ShouldRetry(_ int, req *http.Request, resp *http.Response, err error) (bool, time.Duration) {
	if !isIdempotent(req.Method) {
		return false, 0
	}
	// Network error — always retry idempotent requests.
	if err != nil {
		return true, 0
	}
	switch resp.StatusCode {
	case http.StatusTooManyRequests, // 429
		http.StatusBadGateway,         // 502
		http.StatusServiceUnavailable, // 503
		http.StatusGatewayTimeout:     // 504
		return true, 0
	}
	return false, 0
}

// isIdempotent reports whether the HTTP method is safe to retry.
func isIdempotent(method string) bool {
	switch method {
	case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodPut:
		return true
	}
	return false
}