package circuitbreaker

import (
	"errors"
	"net/http"
	"sync"
	"time"

	"git.codelab.vc/pkg/httpx/middleware"
)

// ErrCircuitOpen is the error Allow reports when it rejects a request. This
// happens when the breaker is Open, or when it is HalfOpen and every probe
// slot is already in use.
var ErrCircuitOpen = errors.New("httpx: circuit breaker is open")

// State is one position of the circuit breaker state machine.
type State int

// Breaker states. The zero value, StateClosed, is the state of a newly
// created Breaker.
const (
	// StateClosed lets every request through while counting consecutive
	// failures.
	StateClosed State = iota
	// StateOpen rejects every request until the open duration elapses.
	StateOpen
	// StateHalfOpen admits a limited number of requests to test whether the
	// endpoint has recovered.
	StateHalfOpen
)

// String implements fmt.Stringer. It returns "closed", "open" or
// "half-open", and "unknown" for any value outside the defined set.
func (s State) String() string {
	name := "unknown"
	switch s {
	case StateClosed:
		name = "closed"
	case StateOpen:
		name = "open"
	case StateHalfOpen:
		name = "half-open"
	}
	return name
}

// Breaker is a circuit breaker for a single endpoint. Every method takes mu
// before touching state, so a Breaker may be shared between goroutines.
//
// Transitions:
//
//	Closed   → Open:     failureThreshold consecutive failures
//	Open     → HalfOpen: openDuration has elapsed since the breaker opened
//	HalfOpen → Closed:   a successful result
//	HalfOpen → Open:     a failed result (the open timer restarts)
//
// While HalfOpen, at most halfOpenMax requests may be in flight at once.
type Breaker struct {
	mu          sync.Mutex
	opts        options
	state       State
	failures    int       // consecutive failures observed while Closed
	openedAt    time.Time // when the breaker last tripped to Open
	halfOpenCur int       // requests currently in flight while HalfOpen
}

// NewBreaker returns a Breaker in the Closed state. Its configuration comes
// from defaults() with each of opts applied in order.
func NewBreaker(opts ...Option) *Breaker {
	cfg := defaults()
	for _, apply := range opts {
		apply(&cfg)
	}
	b := new(Breaker)
	b.opts = cfg
	return b
}

// State reports the breaker's current state. Calling it may itself move an
// Open breaker to HalfOpen once openDuration has elapsed.
func (b *Breaker) State() State {
	b.mu.Lock()
	defer b.mu.Unlock()
	return b.stateLocked()
}

// stateLocked returns the effective state. If the breaker is Open and
// openDuration has elapsed, it first promotes it to HalfOpen with all probe
// slots free. The caller must hold b.mu.
func (b *Breaker) stateLocked() State {
	if b.state != StateOpen {
		return b.state
	}
	if time.Since(b.openedAt) < b.opts.openDuration {
		// Still inside the open window.
		return StateOpen
	}
	b.state = StateHalfOpen
	b.halfOpenCur = 0
	return b.state
}

// Allow asks the breaker for permission to send one request. If the request
// may proceed, Allow returns a done callback. The caller MUST invoke it with
// the request's outcome (true for success); calls after the first are
// ignored. If the breaker is Open, or HalfOpen with no free probe slot, Allow
// returns a nil callback and ErrCircuitOpen.
func (b *Breaker) Allow() (done func(success bool), err error) { b.mu.Lock() defer b.mu.Unlock() switch b.stateLocked() { case StateClosed: // always allow case StateOpen: return nil, ErrCircuitOpen case StateHalfOpen: if b.halfOpenCur >= b.opts.halfOpenMax { return nil, ErrCircuitOpen } b.halfOpenCur++ } return b.doneFunc(), nil } // doneFunc returns the callback for a single in-flight request. Caller must // hold b.mu when calling doneFunc, but the returned function acquires the lock // itself. func (b *Breaker) doneFunc() func(success bool) { var once sync.Once return func(success bool) { once.Do(func() { b.mu.Lock() defer b.mu.Unlock() b.record(success) }) } } // record processes the outcome of a single request. Caller must hold b.mu. func (b *Breaker) record(success bool) { switch b.state { case StateClosed: if success { b.failures = 0 return } b.failures++ if b.failures >= b.opts.failureThreshold { b.tripLocked() } case StateHalfOpen: b.halfOpenCur-- if success { b.state = StateClosed b.failures = 0 } else { b.tripLocked() } } } // tripLocked transitions to the Open state and records the timestamp. func (b *Breaker) tripLocked() { b.state = StateOpen b.openedAt = time.Now() b.halfOpenCur = 0 } // Transport returns a middleware that applies per-host circuit breaking. It // maintains an internal map of host → *Breaker so each target host is tracked // independently. func Transport(opts ...Option) middleware.Middleware { var hosts sync.Map // map[string]*Breaker return func(next http.RoundTripper) http.RoundTripper { return middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) { host := req.URL.Host val, _ := hosts.LoadOrStore(host, NewBreaker(opts...)) cb := val.(*Breaker) done, err := cb.Allow() if err != nil { return nil, err } resp, rtErr := next.RoundTrip(req) done(rtErr == nil && resp != nil && resp.StatusCode < 500) return resp, rtErr }) } }