Files
httpx/circuitbreaker/breaker.go
Aleksey Shakhmatov b07d487e63 Drive circuit breaker state transitions via internal/clock
The Open->HalfOpen promotion used time.Now/time.Since directly, forcing tests
to use real time.Sleep and diverging from the project's clock convention. Add
an unexported withClock option (default clock.System) and replace the real
sleeps in tests with mock-clock Advance, making the transitions deterministic
and the package faster.
2026-05-23 13:47:26 +03:00

177 lines
4.1 KiB
Go

package circuitbreaker
import (
"errors"
"net/http"
"sync"
"time"
"git.codelab.vc/pkg/httpx/middleware"
)
// ErrCircuitOpen is returned by Allow when the breaker rejects a request:
// either it is Open, or it is HalfOpen and the maximum number of concurrent
// probe requests is already in flight. The Transport middleware returns it
// unchanged, so callers can test for it with errors.Is.
var ErrCircuitOpen = errors.New("httpx: circuit breaker is open")
// State is the position of a Breaker in its state machine. The zero value is
// StateClosed.
type State int

// Possible breaker states.
const (
	// StateClosed lets requests through while counting consecutive failures.
	StateClosed State = iota
	// StateOpen rejects every request until the open duration has elapsed.
	StateOpen
	// StateHalfOpen admits a bounded number of probe requests to test whether
	// the endpoint has recovered.
	StateHalfOpen
)

// String implements fmt.Stringer. Values outside the defined states render
// as "unknown".
func (s State) String() string {
	names := [...]string{
		StateClosed:   "closed",
		StateOpen:     "open",
		StateHalfOpen: "half-open",
	}
	if s < 0 || int(s) >= len(names) {
		return "unknown"
	}
	return names[s]
}
// Breaker implements a per-endpoint circuit breaker state machine.
//
// State transitions:
//
//	Closed → Open:     after failureThreshold consecutive failures
//	Open → HalfOpen:   after openDuration passes
//	HalfOpen → Closed: on success
//	HalfOpen → Open:   on failure (timer resets)
//
// Every transition starts a new generation. Each done callback remembers the
// generation in which its request was admitted, and outcomes reported for an
// earlier generation are ignored. Without this, a slow request admitted while
// Closed could finish after the breaker had tripped and been promoted to
// HalfOpen. It would then be counted as a probe, closing or re-tripping the
// breaker on stale data and driving halfOpenCur negative, which lets more
// than halfOpenMax probes through.
type Breaker struct {
	mu          sync.Mutex
	opts        options
	state       State
	generation  uint64    // bumped on every state transition; see doneFunc
	failures    int       // consecutive failure count (Closed state)
	openedAt    time.Time // when the breaker last tripped to Open
	halfOpenCur int       // current in-flight half-open requests
}

// NewBreaker creates a Breaker in the Closed state. It applies opts in order
// on top of the package defaults.
func NewBreaker(opts ...Option) *Breaker {
	o := defaults()
	for _, fn := range opts {
		fn(&o)
	}
	return &Breaker{opts: o}
}

// State returns the current state of the breaker. If the breaker is Open and
// openDuration has elapsed, it is promoted to HalfOpen first.
func (b *Breaker) State() State {
	b.mu.Lock()
	defer b.mu.Unlock()
	return b.stateLocked()
}

// stateLocked returns the effective state, promoting Open → HalfOpen when the
// open duration has elapsed. Caller must hold b.mu.
func (b *Breaker) stateLocked() State {
	if b.state == StateOpen && b.opts.clk.Since(b.openedAt) >= b.opts.openDuration {
		b.transitionLocked(StateHalfOpen)
	}
	return b.state
}

// transitionLocked moves the breaker to state s and starts a new generation,
// which invalidates the done callbacks of every request still in flight.
// It also resets the per-state bookkeeping: the half-open probe counter
// always, the failure count when entering Closed, and the open timestamp
// when entering Open. Caller must hold b.mu.
func (b *Breaker) transitionLocked(s State) {
	b.state = s
	b.generation++
	b.halfOpenCur = 0
	switch s {
	case StateClosed:
		b.failures = 0
	case StateOpen:
		b.openedAt = b.opts.clk.Now()
	}
}

// Allow checks whether a request is permitted. If allowed, it returns a done
// callback that the caller MUST invoke with the result of the request. Only
// the first call to done has any effect, and it is ignored if the breaker has
// changed state since the request was admitted.
//
// Allow returns ErrCircuitOpen when the breaker is Open, or when it is
// HalfOpen and halfOpenMax probe requests are already in flight.
func (b *Breaker) Allow() (done func(success bool), err error) {
	b.mu.Lock()
	defer b.mu.Unlock()

	switch b.stateLocked() {
	case StateClosed:
		// always allow
	case StateOpen:
		return nil, ErrCircuitOpen
	case StateHalfOpen:
		if b.halfOpenCur >= b.opts.halfOpenMax {
			return nil, ErrCircuitOpen
		}
		b.halfOpenCur++
	}
	return b.doneFunc(), nil
}

// doneFunc returns the callback for a single in-flight request, bound to the
// current generation. Caller must hold b.mu when calling doneFunc, but the
// returned function acquires the lock itself.
func (b *Breaker) doneFunc() func(success bool) {
	gen := b.generation
	var once sync.Once
	return func(success bool) {
		once.Do(func() {
			b.mu.Lock()
			defer b.mu.Unlock()
			if gen != b.generation {
				// The breaker has transitioned since this request was
				// admitted; its outcome no longer describes the current
				// state, and any half-open slot it held was already reset.
				return
			}
			b.record(success)
		})
	}
}

// record processes the outcome of a single request that was admitted in the
// current generation (doneFunc filters out stale ones). Caller must hold b.mu.
func (b *Breaker) record(success bool) {
	switch b.state {
	case StateClosed:
		if success {
			b.failures = 0
			return
		}
		b.failures++
		if b.failures >= b.opts.failureThreshold {
			b.tripLocked()
		}
	case StateHalfOpen:
		// Either outcome ends the half-open episode; the transition resets
		// halfOpenCur, so this probe's slot needs no separate release.
		if success {
			b.transitionLocked(StateClosed)
		} else {
			b.tripLocked()
		}
	}
}

// tripLocked transitions to the Open state and records the timestamp.
// Caller must hold b.mu.
func (b *Breaker) tripLocked() {
	b.transitionLocked(StateOpen)
}
// Transport returns a middleware that applies per-host circuit breaking. It
// keeps one *Breaker per target host, keyed by req.URL.Host and created with
// opts on first use, so each host is tracked independently.
//
// A round trip counts as a failure when next returns an error, returns a nil
// response, or returns a response with a 5xx status code. While a host's
// breaker rejects requests, RoundTrip returns ErrCircuitOpen without calling
// next.
//
// NOTE(review): the per-host map is never pruned, so it grows with the number
// of distinct hosts seen — confirm this is acceptable for the callers.
func Transport(opts ...Option) middleware.Middleware {
	var hosts sync.Map // map[string]*Breaker
	return func(next http.RoundTripper) http.RoundTripper {
		return middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
			host := req.URL.Host
			val, ok := hosts.Load(host)
			if !ok {
				// LoadOrStore settles races between concurrent first requests
				// to the same host: all of them end up with the stored Breaker.
				val, _ = hosts.LoadOrStore(host, NewBreaker(opts...))
			}
			cb := val.(*Breaker)

			done, err := cb.Allow()
			if err != nil {
				// http.RoundTripper must close the request body even when it
				// returns an error. next never sees this request, so close
				// the body here; a close error adds nothing to err.
				if req.Body != nil {
					_ = req.Body.Close()
				}
				return nil, err
			}

			resp, rtErr := next.RoundTrip(req)
			done(rtErr == nil && resp != nil && resp.StatusCode < 500)
			return resp, rtErr
		})
	}
}