diff --git a/circuitbreaker/breaker.go b/circuitbreaker/breaker.go
new file mode 100644
index 0000000..68a2695
--- /dev/null
+++ b/circuitbreaker/breaker.go
@@ -0,0 +1,173 @@
+package circuitbreaker
+
+import (
+	"errors"
+	"net/http"
+	"sync"
+	"time"
+
+	"git.codelab.vc/pkg/httpx/middleware"
+)
+
+// ErrCircuitOpen is the error Allow returns when it rejects a request: the
+// breaker is Open, or it is HalfOpen and every probe slot is already taken.
+var ErrCircuitOpen = errors.New("httpx: circuit breaker is open")
+
+// State identifies one of the three phases of the breaker state machine.
+type State int
+
+const (
+	StateClosed   State = iota // normal operation
+	StateOpen                  // failing, reject requests
+	StateHalfOpen              // testing recovery
+)
+
+// String returns the lowercase name of s, or "unknown" for any other value.
+func (s State) String() string {
+	names := [...]string{
+		StateClosed:   "closed",
+		StateOpen:     "open",
+		StateHalfOpen: "half-open",
+	}
+	if s < 0 || int(s) >= len(names) {
+		return "unknown"
+	}
+	return names[s]
+}
+
+// Breaker is the circuit breaker state machine for a single endpoint.
+//
+// State transitions:
+//
+//	Closed → Open:     after failureThreshold consecutive failures
+//	Open → HalfOpen:   after openDuration passes
+//	HalfOpen → Closed: on success
+//	HalfOpen → Open:   on failure (timer resets)
+type Breaker struct {
+	mu   sync.Mutex
+	opts options
+
+	state       State
+	failures    int // consecutive failure count (Closed state)
+	openedAt    time.Time
+	halfOpenCur int // current in-flight half-open requests
+}
+
+// NewBreaker builds a Breaker from the defaults, then applies opts in order.
+func NewBreaker(opts ...Option) *Breaker {
+	b := &Breaker{opts: defaults()}
+	for _, apply := range opts {
+		apply(&b.opts)
+	}
+	return b
+}
+
+// State returns the effective state; an expired Open is reported as HalfOpen.
+func (b *Breaker) State() State {
+	b.mu.Lock()
+	defer b.mu.Unlock()
+	return b.stateLocked()
+}
+
+// stateLocked returns the effective state, promoting Open → HalfOpen when the
+// open duration has elapsed. Caller must hold b.mu.
+func (b *Breaker) stateLocked() State {
+	// There is no background timer: the Open → HalfOpen promotion happens
+	// lazily, the first time the state is observed after openDuration.
+	if b.state == StateOpen && time.Since(b.openedAt) >= b.opts.openDuration {
+		b.state = StateHalfOpen
+		b.halfOpenCur = 0
+	}
+	return b.state
+}
+
+// Allow checks whether a request is permitted. If allowed it returns a done
+// callback that the caller MUST invoke with the result of the request; only
+// the first invocation of the callback is recorded. Allow returns
+// ErrCircuitOpen when the breaker is Open, or when it is HalfOpen and
+// halfOpenMax probe requests are already in flight.
+func (b *Breaker) Allow() (done func(success bool), err error) {
+	b.mu.Lock()
+	defer b.mu.Unlock()
+
+	switch b.stateLocked() {
+	case StateClosed:
+		// always allow
+	case StateOpen:
+		return nil, ErrCircuitOpen
+	case StateHalfOpen:
+		if b.halfOpenCur >= b.opts.halfOpenMax {
+			return nil, ErrCircuitOpen
+		}
+		b.halfOpenCur++
+	}
+
+	return b.doneFunc(), nil
+}
+
+// doneFunc returns the callback for a single in-flight request. Caller must
+// hold b.mu when calling doneFunc, but the returned function acquires the lock
+// itself.
+//
+// The callback remembers the state and openedAt timestamp in effect at
+// admission. openedAt changes on every trip, so together the two identify the
+// period the request was admitted in, and a result that arrives after the
+// breaker has moved on is dropped. Without this check a slow request admitted
+// while Closed could finish during HalfOpen, be counted as a probe, drive
+// halfOpenCur negative and close the breaker without any probe succeeding.
+func (b *Breaker) doneFunc() func(success bool) {
+	admitted, openedAt := b.state, b.openedAt
+	var once sync.Once
+	return func(success bool) {
+		once.Do(func() {
+			b.mu.Lock()
+			defer b.mu.Unlock()
+			if b.state != admitted || !b.openedAt.Equal(openedAt) {
+				return // stale: belongs to an earlier period
+			}
+			b.record(success)
+		})
+	}
+}
+
+// record processes the outcome of a single request. Caller must hold b.mu.
+//
+// It switches on the raw b.state, so an outcome reported while the breaker is
+// Open is ignored.
+func (b *Breaker) record(success bool) {
+	switch b.state {
+	case StateClosed:
+		if success {
+			// Only consecutive failures count, so any success starts over.
+			b.failures = 0
+			return
+		}
+		b.failures++
+		if b.failures >= b.opts.failureThreshold {
+			b.tripLocked()
+		}
+
+	case StateHalfOpen:
+		// The probe is finished either way, so free its slot.
+		b.halfOpenCur--
+		if success {
+			b.state = StateClosed
+			b.failures = 0
+		} else {
+			b.tripLocked()
+		}
+	}
+}
+
+// tripLocked transitions to the Open state and records the timestamp, which
+// restarts the openDuration wait. Caller must hold b.mu.
+func (b *Breaker) tripLocked() {
+	b.state = StateOpen
+	b.openedAt = time.Now()
+	b.halfOpenCur = 0
+}
+
+// Transport returns a middleware that applies per-host circuit breaking. It
+// maintains an internal map of host → *Breaker so each target host is tracked
+// independently. A request counts as a failure when the wrapped RoundTripper
+// returns an error or a response with a status code of 500 or above; a
+// rejected request returns ErrCircuitOpen without calling the wrapped
+// RoundTripper.
+func Transport(opts ...Option) middleware.Middleware {
+	var hosts sync.Map // map[string]*Breaker
+
+	// breakerFor returns the breaker for host, creating it on first use.
+	breakerFor := func(host string) *Breaker {
+		// Fast path: once a host is known, skip building a Breaker that
+		// LoadOrStore would only throw away.
+		if v, ok := hosts.Load(host); ok {
+			return v.(*Breaker)
+		}
+		v, _ := hosts.LoadOrStore(host, NewBreaker(opts...))
+		return v.(*Breaker)
+	}
+
+	return func(next http.RoundTripper) http.RoundTripper {
+		return middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
+			done, err := breakerFor(req.URL.Host).Allow()
+			if err != nil {
+				return nil, err
+			}
+			// done records only its first call, so this is a no-op on the
+			// normal path. If next.RoundTrip panics it reports a failure,
+			// so a half-open probe slot is not leaked.
+			defer done(false)
+
+			resp, rtErr := next.RoundTrip(req)
+			// Transport errors and 5xx responses count as failures.
+			done(rtErr == nil && resp != nil && resp.StatusCode < 500)
+
+			return resp, rtErr
+		})
+	}
+}
diff --git a/circuitbreaker/breaker_test.go b/circuitbreaker/breaker_test.go
new file mode 100644
index 0000000..0388a37
--- /dev/null
+++ b/circuitbreaker/breaker_test.go
@@ -0,0 +1,252 @@
+package circuitbreaker
+
+import (
+	"errors"
+	"io"
+	"net/http"
+	"strings"
+	"testing"
+	"time"
+
+	"git.codelab.vc/pkg/httpx/middleware"
+)
+
+// mockTransport adapts fn into an http.RoundTripper for use as the base transport.
+func mockTransport(fn func(*http.Request) (*http.Response, error)) http.RoundTripper {
+	return middleware.RoundTripperFunc(fn)
+}
+
+// okResponse returns a 200 response with an empty body.
+func okResponse() *http.Response {
+	return &http.Response{
+		StatusCode: http.StatusOK,
+		Body:       io.NopCloser(strings.NewReader("")),
+		Header:     make(http.Header),
+	}
+}
+
+// errResponse returns a response with the given status code and an empty body.
+func errResponse(code int) *http.Response {
+	return &http.Response{
+		StatusCode: code,
+		Body:       io.NopCloser(strings.NewReader("")),
+		Header:     make(http.Header),
+	}
+}
+
+func TestBreaker_StartsInClosedState(t *testing.T) {
+	b := NewBreaker()
+	if s := b.State(); s != StateClosed {
+		t.Fatalf("state = %v, want %v", s, StateClosed)
+	}
+}
+
+func TestBreaker_TransitionsToOpenAfterThreshold(t *testing.T) {
+	const threshold = 3
+	b := NewBreaker(
+		WithFailureThreshold(threshold),
+		WithOpenDuration(time.Hour), // long duration so it stays open
+	)
+
+	for i := 0; i < threshold; i++ {
+		done, err := b.Allow()
+		if err != nil {
+			t.Fatalf("iteration %d: Allow returned error: %v", i, err)
+		}
+		done(false)
+	}
+
+	if s := b.State(); s != StateOpen {
+		t.Fatalf("state = %v, want %v", s, StateOpen)
+	}
+}
+
+func TestBreaker_OpenRejectsRequests(t *testing.T) {
+	b := NewBreaker(
+		WithFailureThreshold(1),
+		WithOpenDuration(time.Hour),
+	)
+
+	// Trip the breaker.
+	done, err := b.Allow()
+	if err != nil {
+		t.Fatalf("Allow returned error: %v", err)
+	}
+	done(false)
+
+	// Subsequent requests should be rejected.
+	_, err = b.Allow()
+	if !errors.Is(err, ErrCircuitOpen) {
+		t.Fatalf("err = %v, want %v", err, ErrCircuitOpen)
+	}
+}
+
+func TestBreaker_TransitionsToHalfOpenAfterDuration(t *testing.T) {
+	const openDuration = 50 * time.Millisecond
+	b := NewBreaker(
+		WithFailureThreshold(1),
+		WithOpenDuration(openDuration),
+	)
+
+	// Trip the breaker.
+	done, err := b.Allow()
+	if err != nil {
+		t.Fatal(err)
+	}
+	done(false)
+
+	if s := b.State(); s != StateOpen {
+		t.Fatalf("state = %v, want %v", s, StateOpen)
+	}
+
+	// Wait for the open duration to elapse.
+	time.Sleep(openDuration + 10*time.Millisecond)
+
+	if s := b.State(); s != StateHalfOpen {
+		t.Fatalf("state = %v, want %v", s, StateHalfOpen)
+	}
+}
+
+func TestBreaker_HalfOpenToClosedOnSuccess(t *testing.T) {
+	const openDuration = 50 * time.Millisecond
+	b := NewBreaker(
+		WithFailureThreshold(1),
+		WithOpenDuration(openDuration),
+	)
+
+	// Trip the breaker.
+	done, err := b.Allow()
+	if err != nil {
+		t.Fatal(err)
+	}
+	done(false)
+
+	// Wait for half-open.
+	time.Sleep(openDuration + 10*time.Millisecond)
+
+	// A successful request in half-open should close the breaker.
+	done, err = b.Allow()
+	if err != nil {
+		t.Fatalf("Allow in half-open returned error: %v", err)
+	}
+	done(true)
+
+	if s := b.State(); s != StateClosed {
+		t.Fatalf("state = %v, want %v", s, StateClosed)
+	}
+}
+
+func TestBreaker_HalfOpenToOpenOnFailure(t *testing.T) {
+	const openDuration = 50 * time.Millisecond
+	b := NewBreaker(
+		WithFailureThreshold(1),
+		WithOpenDuration(openDuration),
+	)
+
+	// Trip the breaker.
+	done, err := b.Allow()
+	if err != nil {
+		t.Fatal(err)
+	}
+	done(false)
+
+	// Wait for half-open.
+	time.Sleep(openDuration + 10*time.Millisecond)
+
+	// A failed request in half-open should re-open the breaker.
+	done, err = b.Allow()
+	if err != nil {
+		t.Fatalf("Allow in half-open returned error: %v", err)
+	}
+	done(false)
+
+	if s := b.State(); s != StateOpen {
+		t.Fatalf("state = %v, want %v", s, StateOpen)
+	}
+}
+
+func TestTransport_PerHostBreakers(t *testing.T) {
+	const threshold = 2
+
+	base := mockTransport(func(req *http.Request) (*http.Response, error) {
+		if req.URL.Host == "failing.example.com" {
+			return errResponse(http.StatusInternalServerError), nil
+		}
+		return okResponse(), nil
+	})
+
+	rt := Transport(
+		WithFailureThreshold(threshold),
+		WithOpenDuration(time.Hour),
+	)(base)
+
+	t.Run("failing host trips breaker", func(t *testing.T) {
+		for i := 0; i < threshold; i++ {
+			req, err := http.NewRequest(http.MethodGet, "https://failing.example.com/test", nil)
+			if err != nil {
+				t.Fatal(err)
+			}
+			resp, err := rt.RoundTrip(req)
+			if err != nil {
+				t.Fatalf("iteration %d: unexpected error: %v", i, err)
+			}
+			resp.Body.Close()
+		}
+
+		// Next request to failing host should be rejected.
+		req, err := http.NewRequest(http.MethodGet, "https://failing.example.com/test", nil)
+		if err != nil {
+			t.Fatal(err)
+		}
+		_, err = rt.RoundTrip(req)
+		if !errors.Is(err, ErrCircuitOpen) {
+			t.Fatalf("err = %v, want %v", err, ErrCircuitOpen)
+		}
+	})
+
+	t.Run("healthy host is unaffected", func(t *testing.T) {
+		req, err := http.NewRequest(http.MethodGet, "https://healthy.example.com/test", nil)
+		if err != nil {
+			t.Fatal(err)
+		}
+		resp, err := rt.RoundTrip(req)
+		if err != nil {
+			t.Fatalf("unexpected error: %v", err)
+		}
+		resp.Body.Close()
+		if resp.StatusCode != http.StatusOK {
+			t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK)
+		}
+	})
+}
+
+func TestTransport_SuccessResetsFailures(t *testing.T) {
+	callCount := 0
+	base := mockTransport(func(req *http.Request) (*http.Response, error) {
+		callCount++
+		// Fail on odd calls, succeed on even.
+		if callCount%2 == 1 {
+			return errResponse(http.StatusInternalServerError), nil
+		}
+		return okResponse(), nil
+	})
+
+	rt := Transport(
+		WithFailureThreshold(3),
+		WithOpenDuration(time.Hour),
+	)(base)
+
+	// Alternate fail/success — should never trip because successes reset the
+	// consecutive failure counter.
+	for i := 0; i < 10; i++ {
+		req, err := http.NewRequest(http.MethodGet, "https://host.example.com/test", nil)
+		if err != nil {
+			t.Fatal(err)
+		}
+		resp, err := rt.RoundTrip(req)
+		if err != nil {
+			t.Fatalf("iteration %d: unexpected error (circuit should not be open): %v", i, err)
+		}
+		resp.Body.Close()
+	}
+}
diff --git a/circuitbreaker/options.go b/circuitbreaker/options.go
new file mode 100644
index 0000000..57bd770
--- /dev/null
+++ b/circuitbreaker/options.go
@@ -0,0 +1,52 @@
+package circuitbreaker
+
+import "time"
+
+// options holds the tunable parameters of a Breaker.
+type options struct {
+	failureThreshold int           // consecutive failures to trip
+	openDuration     time.Duration // how long to stay open before half-open
+	halfOpenMax      int           // max concurrent requests in half-open
+}
+
+// defaults returns the values used when no Option overrides them.
+func defaults() options {
+	return options{
+		failureThreshold: 5,
+		openDuration:     30 * time.Second,
+		halfOpenMax:      1,
+	}
+}
+
+// Option configures a Breaker.
+type Option func(*options)
+
+// WithFailureThreshold sets the number of consecutive failures required to
+// trip the breaker from Closed to Open. Default is 5; n <= 0 is ignored.
+func WithFailureThreshold(n int) Option {
+	return func(o *options) {
+		if n > 0 {
+			o.failureThreshold = n
+		}
+	}
+}
+
+// WithOpenDuration sets how long the breaker stays in the Open state before
+// transitioning to HalfOpen. Default is 30s; d <= 0 is ignored.
+func WithOpenDuration(d time.Duration) Option {
+	return func(o *options) {
+		if d > 0 {
+			o.openDuration = d
+		}
+	}
+}
+
+// WithHalfOpenMax sets the maximum number of concurrent probe requests
+// allowed in the HalfOpen state. Default is 1; n <= 0 is ignored.
+func WithHalfOpenMax(n int) Option {
+	return func(o *options) {
+		if n > 0 {
+			o.halfOpenMax = n
+		}
+	}
+}