package circuitbreaker

import (
	"errors"
	"io"
	"net/http"
	"strings"
	"testing"
	"time"

	"git.codelab.vc/pkg/httpx/middleware"
)

// mockTransport adapts fn into an http.RoundTripper by way of
// middleware.RoundTripperFunc, so tests can script transport behavior.
func mockTransport(fn func(*http.Request) (*http.Response, error)) http.RoundTripper {
	return middleware.RoundTripperFunc(fn)
}

// okResponse returns a 200 response with an empty body and empty headers.
func okResponse() *http.Response {
	return errResponse(http.StatusOK)
}

// errResponse returns a response carrying the given status code, an empty
// body and an empty (non-nil) header map.
func errResponse(code int) *http.Response {
	resp := &http.Response{
		StatusCode: code,
		Header:     http.Header{},
	}
	resp.Body = io.NopCloser(strings.NewReader(""))
	return resp
}

// TestBreaker_StartsInClosedState checks that a breaker built with no options
// reports StateClosed.
func TestBreaker_StartsInClosedState(t *testing.T) {
	got := NewBreaker().State()
	if got != StateClosed {
		t.Fatalf("state = %v, want %v", got, StateClosed)
	}
}

// TestBreaker_TransitionsToOpenAfterThreshold reports `threshold` failures in
// a row and expects the breaker to end up in StateOpen.
func TestBreaker_TransitionsToOpenAfterThreshold(t *testing.T) {
	const threshold = 3

	// The hour-long open duration keeps the breaker from leaving the open
	// state while the test is running.
	breaker := NewBreaker(
		WithFailureThreshold(threshold),
		WithOpenDuration(time.Hour),
	)

	for attempt := 0; attempt < threshold; attempt++ {
		report, err := breaker.Allow()
		if err != nil {
			t.Fatalf("iteration %d: Allow returned error: %v", attempt, err)
		}
		report(false)
	}

	got := breaker.State()
	if got != StateOpen {
		t.Fatalf("state = %v, want %v", got, StateOpen)
	}
}

// TestBreaker_OpenRejectsRequests trips a breaker with a threshold of one and
// expects the next Allow call to fail with ErrCircuitOpen.
func TestBreaker_OpenRejectsRequests(t *testing.T) {
	breaker := NewBreaker(
		WithFailureThreshold(1),
		WithOpenDuration(time.Hour),
	)

	// A single reported failure is enough to trip the breaker.
	report, err := breaker.Allow()
	if err != nil {
		t.Fatalf("Allow returned error: %v", err)
	}
	report(false)

	// From here on, Allow is expected to refuse.
	if _, err := breaker.Allow(); !errors.Is(err, ErrCircuitOpen) {
		t.Fatalf("err = %v, want %v", err, ErrCircuitOpen)
	}
}

// TestBreaker_TransitionsToHalfOpenAfterDuration trips the breaker, sleeps
// slightly longer than the open duration and expects StateHalfOpen.
func TestBreaker_TransitionsToHalfOpenAfterDuration(t *testing.T) {
	const (
		openDuration = 50 * time.Millisecond
		margin       = 10 * time.Millisecond
	)

	breaker := NewBreaker(
		WithFailureThreshold(1),
		WithOpenDuration(openDuration),
	)

	// Trip the breaker with one failure.
	report, err := breaker.Allow()
	if err != nil {
		t.Fatal(err)
	}
	report(false)

	if got := breaker.State(); got != StateOpen {
		t.Fatalf("state = %v, want %v", got, StateOpen)
	}

	// Let the open duration pass, plus a small margin.
	time.Sleep(openDuration + margin)

	if got := breaker.State(); got != StateHalfOpen {
		t.Fatalf("state = %v, want %v", got, StateHalfOpen)
	}
}

// TestBreaker_HalfOpenToClosedOnSuccess expects a success reported after the
// open duration has elapsed to bring the breaker back to StateClosed.
func TestBreaker_HalfOpenToClosedOnSuccess(t *testing.T) {
	const (
		openDuration = 50 * time.Millisecond
		margin       = 10 * time.Millisecond
	)

	breaker := NewBreaker(
		WithFailureThreshold(1),
		WithOpenDuration(openDuration),
	)

	// Trip the breaker with one failure.
	report, err := breaker.Allow()
	if err != nil {
		t.Fatal(err)
	}
	report(false)

	// Sleep past the open duration so the next request is a half-open probe.
	time.Sleep(openDuration + margin)

	// The probe succeeds, which should close the breaker.
	probe, err := breaker.Allow()
	if err != nil {
		t.Fatalf("Allow in half-open returned error: %v", err)
	}
	probe(true)

	got := breaker.State()
	if got != StateClosed {
		t.Fatalf("state = %v, want %v", got, StateClosed)
	}
}

// TestBreaker_HalfOpenToOpenOnFailure expects a failure reported after the
// open duration has elapsed to send the breaker back to StateOpen.
func TestBreaker_HalfOpenToOpenOnFailure(t *testing.T) {
	const (
		openDuration = 50 * time.Millisecond
		margin       = 10 * time.Millisecond
	)

	breaker := NewBreaker(
		WithFailureThreshold(1),
		WithOpenDuration(openDuration),
	)

	// Trip the breaker with one failure.
	report, err := breaker.Allow()
	if err != nil {
		t.Fatal(err)
	}
	report(false)

	// Sleep past the open duration so the next request is a half-open probe.
	time.Sleep(openDuration + margin)

	// The probe fails, which should re-open the breaker.
	probe, err := breaker.Allow()
	if err != nil {
		t.Fatalf("Allow in half-open returned error: %v", err)
	}
	probe(false)

	got := breaker.State()
	if got != StateOpen {
		t.Fatalf("state = %v, want %v", got, StateOpen)
	}
}

// TestTransport_PerHostBreakers drives one wrapped transport against two
// hosts: one that always answers 500 and one that always answers 200. The
// failing host is expected to be rejected with ErrCircuitOpen after
// `threshold` requests, while the healthy host keeps being served.
func TestTransport_PerHostBreakers(t *testing.T) {
	const (
		threshold  = 2
		failingURL = "https://failing.example.com/test"
		healthyURL = "https://healthy.example.com/test"
	)

	base := mockTransport(func(req *http.Request) (*http.Response, error) {
		if req.URL.Host != "failing.example.com" {
			return okResponse(), nil
		}
		return errResponse(http.StatusInternalServerError), nil
	})

	wrap := Transport(
		WithFailureThreshold(threshold),
		WithOpenDuration(time.Hour),
	)
	rt := wrap(base)

	t.Run("failing host trips breaker", func(t *testing.T) {
		for attempt := 0; attempt < threshold; attempt++ {
			req, err := http.NewRequest(http.MethodGet, failingURL, nil)
			if err != nil {
				t.Fatal(err)
			}

			resp, err := rt.RoundTrip(req)
			if err != nil {
				t.Fatalf("iteration %d: unexpected error: %v", attempt, err)
			}
			resp.Body.Close()
		}

		// One request past the threshold: this one should be refused.
		req, err := http.NewRequest(http.MethodGet, failingURL, nil)
		if err != nil {
			t.Fatal(err)
		}
		if _, err := rt.RoundTrip(req); !errors.Is(err, ErrCircuitOpen) {
			t.Fatalf("err = %v, want %v", err, ErrCircuitOpen)
		}
	})

	t.Run("healthy host is unaffected", func(t *testing.T) {
		req, err := http.NewRequest(http.MethodGet, healthyURL, nil)
		if err != nil {
			t.Fatal(err)
		}

		resp, err := rt.RoundTrip(req)
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		resp.Body.Close()

		if got := resp.StatusCode; got != http.StatusOK {
			t.Errorf("status = %d, want %d", got, http.StatusOK)
		}
	})
}

// TestTransport_SuccessResetsFailures alternates 500 and 200 responses from
// the base transport and expects that a threshold of three is never reached,
// i.e. that no request is rejected by the breaker.
func TestTransport_SuccessResetsFailures(t *testing.T) {
	const requests = 10

	calls := 0
	base := mockTransport(func(req *http.Request) (*http.Response, error) {
		calls++
		// Even-numbered calls succeed; odd-numbered calls (starting with the
		// first) answer with a 500.
		if calls%2 == 0 {
			return okResponse(), nil
		}
		return errResponse(http.StatusInternalServerError), nil
	})

	wrap := Transport(
		WithFailureThreshold(3),
		WithOpenDuration(time.Hour),
	)
	rt := wrap(base)

	// With failures and successes interleaved, the test expects every request
	// to get through: an error here means the circuit opened.
	for attempt := 0; attempt < requests; attempt++ {
		req, err := http.NewRequest(http.MethodGet, "https://host.example.com/test", nil)
		if err != nil {
			t.Fatal(err)
		}

		resp, err := rt.RoundTrip(req)
		if err != nil {
			t.Fatalf("iteration %d: unexpected error (circuit should not be open): %v", attempt, err)
		}
		resp.Body.Close()
	}
}