Add per-host circuit breaker with three-state machine
Implements a circuit breaker as a RoundTripper middleware:
- Closed → Open after a configurable number of consecutive failures
- Open → HalfOpen after a configurable duration
- HalfOpen → Closed on success, back to Open on failure
- Per-host tracking via sync.Map for independent endpoint isolation
This commit is contained in:
173
circuitbreaker/breaker.go
Normal file
173
circuitbreaker/breaker.go
Normal file
@@ -0,0 +1,173 @@
|
|||||||
|
package circuitbreaker
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.codelab.vc/pkg/httpx/middleware"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
	// ErrCircuitOpen is the sentinel error Allow returns when it rejects a
	// request: either the breaker is Open, or it is HalfOpen and the maximum
	// number of concurrent probe requests is already in flight.
	ErrCircuitOpen = errors.New("httpx: circuit breaker is open")
)
|
||||||
|
|
||||||
|
// State identifies which phase of the circuit breaker state machine a
// Breaker is currently in.
type State int

// StateClosed is the zero value of State.
const (
	StateClosed   State = iota // requests flow normally
	StateOpen                  // requests are rejected
	StateHalfOpen              // recovery is being probed
)

// String returns the lowercase name of the state ("closed", "open" or
// "half-open"); any value outside the declared constants yields "unknown".
func (s State) String() string {
	names := [...]string{
		StateClosed:   "closed",
		StateOpen:     "open",
		StateHalfOpen: "half-open",
	}
	if s < StateClosed || int(s) >= len(names) {
		return "unknown"
	}
	return names[s]
}
|
||||||
|
|
||||||
|
// Breaker implements a per-endpoint circuit breaker state machine.
|
||||||
|
//
|
||||||
|
// State transitions:
|
||||||
|
//
|
||||||
|
// Closed → Open: after failureThreshold consecutive failures
|
||||||
|
// Open → HalfOpen: after openDuration passes
|
||||||
|
// HalfOpen → Closed: on success
|
||||||
|
// HalfOpen → Open: on failure (timer resets)
|
||||||
|
type Breaker struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
opts options
|
||||||
|
|
||||||
|
state State
|
||||||
|
failures int // consecutive failure count (Closed state)
|
||||||
|
openedAt time.Time
|
||||||
|
halfOpenCur int // current in-flight half-open requests
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewBreaker creates a Breaker with the given options.
|
||||||
|
func NewBreaker(opts ...Option) *Breaker {
|
||||||
|
o := defaults()
|
||||||
|
for _, fn := range opts {
|
||||||
|
fn(&o)
|
||||||
|
}
|
||||||
|
return &Breaker{opts: o}
|
||||||
|
}
|
||||||
|
|
||||||
|
// State returns the current state of the breaker.
|
||||||
|
func (b *Breaker) State() State {
|
||||||
|
b.mu.Lock()
|
||||||
|
defer b.mu.Unlock()
|
||||||
|
return b.stateLocked()
|
||||||
|
}
|
||||||
|
|
||||||
|
// stateLocked returns the effective state, promoting Open → HalfOpen when the
|
||||||
|
// open duration has elapsed. Caller must hold b.mu.
|
||||||
|
func (b *Breaker) stateLocked() State {
|
||||||
|
if b.state == StateOpen && time.Since(b.openedAt) >= b.opts.openDuration {
|
||||||
|
b.state = StateHalfOpen
|
||||||
|
b.halfOpenCur = 0
|
||||||
|
}
|
||||||
|
return b.state
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allow checks whether a request is permitted. If allowed it returns a done
|
||||||
|
// callback that the caller MUST invoke with the result of the request. If the
|
||||||
|
// breaker is open, it returns ErrCircuitOpen.
|
||||||
|
func (b *Breaker) Allow() (done func(success bool), err error) {
|
||||||
|
b.mu.Lock()
|
||||||
|
defer b.mu.Unlock()
|
||||||
|
|
||||||
|
switch b.stateLocked() {
|
||||||
|
case StateClosed:
|
||||||
|
// always allow
|
||||||
|
case StateOpen:
|
||||||
|
return nil, ErrCircuitOpen
|
||||||
|
case StateHalfOpen:
|
||||||
|
if b.halfOpenCur >= b.opts.halfOpenMax {
|
||||||
|
return nil, ErrCircuitOpen
|
||||||
|
}
|
||||||
|
b.halfOpenCur++
|
||||||
|
}
|
||||||
|
|
||||||
|
return b.doneFunc(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// doneFunc returns the callback for a single in-flight request. Caller must
|
||||||
|
// hold b.mu when calling doneFunc, but the returned function acquires the lock
|
||||||
|
// itself.
|
||||||
|
func (b *Breaker) doneFunc() func(success bool) {
|
||||||
|
var once sync.Once
|
||||||
|
return func(success bool) {
|
||||||
|
once.Do(func() {
|
||||||
|
b.mu.Lock()
|
||||||
|
defer b.mu.Unlock()
|
||||||
|
b.record(success)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// record processes the outcome of a single request. Caller must hold b.mu.
|
||||||
|
func (b *Breaker) record(success bool) {
|
||||||
|
switch b.state {
|
||||||
|
case StateClosed:
|
||||||
|
if success {
|
||||||
|
b.failures = 0
|
||||||
|
return
|
||||||
|
}
|
||||||
|
b.failures++
|
||||||
|
if b.failures >= b.opts.failureThreshold {
|
||||||
|
b.tripLocked()
|
||||||
|
}
|
||||||
|
|
||||||
|
case StateHalfOpen:
|
||||||
|
b.halfOpenCur--
|
||||||
|
if success {
|
||||||
|
b.state = StateClosed
|
||||||
|
b.failures = 0
|
||||||
|
} else {
|
||||||
|
b.tripLocked()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// tripLocked transitions to the Open state and records the timestamp.
|
||||||
|
func (b *Breaker) tripLocked() {
|
||||||
|
b.state = StateOpen
|
||||||
|
b.openedAt = time.Now()
|
||||||
|
b.halfOpenCur = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Transport returns a middleware that applies per-host circuit breaking. It
|
||||||
|
// maintains an internal map of host → *Breaker so each target host is tracked
|
||||||
|
// independently.
|
||||||
|
func Transport(opts ...Option) middleware.Middleware {
|
||||||
|
var hosts sync.Map // map[string]*Breaker
|
||||||
|
|
||||||
|
return func(next http.RoundTripper) http.RoundTripper {
|
||||||
|
return middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||||
|
host := req.URL.Host
|
||||||
|
|
||||||
|
val, _ := hosts.LoadOrStore(host, NewBreaker(opts...))
|
||||||
|
cb := val.(*Breaker)
|
||||||
|
|
||||||
|
done, err := cb.Allow()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, rtErr := next.RoundTrip(req)
|
||||||
|
done(rtErr == nil && resp != nil && resp.StatusCode < 500)
|
||||||
|
|
||||||
|
return resp, rtErr
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
249
circuitbreaker/breaker_test.go
Normal file
249
circuitbreaker/breaker_test.go
Normal file
@@ -0,0 +1,249 @@
|
|||||||
|
package circuitbreaker
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.codelab.vc/pkg/httpx/middleware"
|
||||||
|
)
|
||||||
|
|
||||||
|
func mockTransport(fn func(*http.Request) (*http.Response, error)) http.RoundTripper {
|
||||||
|
return middleware.RoundTripperFunc(fn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// okResponse builds a 200 OK response with an empty body and an empty,
// non-nil header map.
func okResponse() *http.Response {
	resp := new(http.Response)
	resp.StatusCode = http.StatusOK
	resp.Header = http.Header{}
	resp.Body = io.NopCloser(strings.NewReader(""))
	return resp
}
|
||||||
|
|
||||||
|
// errResponse builds a response carrying the given status code, with an empty
// body and an empty, non-nil header map.
func errResponse(code int) *http.Response {
	resp := new(http.Response)
	resp.StatusCode = code
	resp.Header = http.Header{}
	resp.Body = io.NopCloser(strings.NewReader(""))
	return resp
}
|
||||||
|
|
||||||
|
func TestBreaker_StartsInClosedState(t *testing.T) {
|
||||||
|
b := NewBreaker()
|
||||||
|
if s := b.State(); s != StateClosed {
|
||||||
|
t.Fatalf("state = %v, want %v", s, StateClosed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBreaker_TransitionsToOpenAfterThreshold(t *testing.T) {
|
||||||
|
const threshold = 3
|
||||||
|
b := NewBreaker(
|
||||||
|
WithFailureThreshold(threshold),
|
||||||
|
WithOpenDuration(time.Hour), // long duration so it stays open
|
||||||
|
)
|
||||||
|
|
||||||
|
for i := 0; i < threshold; i++ {
|
||||||
|
done, err := b.Allow()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("iteration %d: Allow returned error: %v", i, err)
|
||||||
|
}
|
||||||
|
done(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s := b.State(); s != StateOpen {
|
||||||
|
t.Fatalf("state = %v, want %v", s, StateOpen)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBreaker_OpenRejectsRequests(t *testing.T) {
|
||||||
|
b := NewBreaker(
|
||||||
|
WithFailureThreshold(1),
|
||||||
|
WithOpenDuration(time.Hour),
|
||||||
|
)
|
||||||
|
|
||||||
|
// Trip the breaker.
|
||||||
|
done, err := b.Allow()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Allow returned error: %v", err)
|
||||||
|
}
|
||||||
|
done(false)
|
||||||
|
|
||||||
|
// Subsequent requests should be rejected.
|
||||||
|
_, err = b.Allow()
|
||||||
|
if !errors.Is(err, ErrCircuitOpen) {
|
||||||
|
t.Fatalf("err = %v, want %v", err, ErrCircuitOpen)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBreaker_TransitionsToHalfOpenAfterDuration(t *testing.T) {
|
||||||
|
const openDuration = 50 * time.Millisecond
|
||||||
|
b := NewBreaker(
|
||||||
|
WithFailureThreshold(1),
|
||||||
|
WithOpenDuration(openDuration),
|
||||||
|
)
|
||||||
|
|
||||||
|
// Trip the breaker.
|
||||||
|
done, err := b.Allow()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
done(false)
|
||||||
|
|
||||||
|
if s := b.State(); s != StateOpen {
|
||||||
|
t.Fatalf("state = %v, want %v", s, StateOpen)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for the open duration to elapse.
|
||||||
|
time.Sleep(openDuration + 10*time.Millisecond)
|
||||||
|
|
||||||
|
if s := b.State(); s != StateHalfOpen {
|
||||||
|
t.Fatalf("state = %v, want %v", s, StateHalfOpen)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBreaker_HalfOpenToClosedOnSuccess(t *testing.T) {
|
||||||
|
const openDuration = 50 * time.Millisecond
|
||||||
|
b := NewBreaker(
|
||||||
|
WithFailureThreshold(1),
|
||||||
|
WithOpenDuration(openDuration),
|
||||||
|
)
|
||||||
|
|
||||||
|
// Trip the breaker.
|
||||||
|
done, err := b.Allow()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
done(false)
|
||||||
|
|
||||||
|
// Wait for half-open.
|
||||||
|
time.Sleep(openDuration + 10*time.Millisecond)
|
||||||
|
|
||||||
|
// A successful request in half-open should close the breaker.
|
||||||
|
done, err = b.Allow()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Allow in half-open returned error: %v", err)
|
||||||
|
}
|
||||||
|
done(true)
|
||||||
|
|
||||||
|
if s := b.State(); s != StateClosed {
|
||||||
|
t.Fatalf("state = %v, want %v", s, StateClosed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBreaker_HalfOpenToOpenOnFailure(t *testing.T) {
|
||||||
|
const openDuration = 50 * time.Millisecond
|
||||||
|
b := NewBreaker(
|
||||||
|
WithFailureThreshold(1),
|
||||||
|
WithOpenDuration(openDuration),
|
||||||
|
)
|
||||||
|
|
||||||
|
// Trip the breaker.
|
||||||
|
done, err := b.Allow()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
done(false)
|
||||||
|
|
||||||
|
// Wait for half-open.
|
||||||
|
time.Sleep(openDuration + 10*time.Millisecond)
|
||||||
|
|
||||||
|
// A failed request in half-open should re-open the breaker.
|
||||||
|
done, err = b.Allow()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Allow in half-open returned error: %v", err)
|
||||||
|
}
|
||||||
|
done(false)
|
||||||
|
|
||||||
|
if s := b.State(); s != StateOpen {
|
||||||
|
t.Fatalf("state = %v, want %v", s, StateOpen)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTransport_PerHostBreakers(t *testing.T) {
|
||||||
|
const threshold = 2
|
||||||
|
|
||||||
|
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||||
|
if req.URL.Host == "failing.example.com" {
|
||||||
|
return errResponse(http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
return okResponse(), nil
|
||||||
|
})
|
||||||
|
|
||||||
|
rt := Transport(
|
||||||
|
WithFailureThreshold(threshold),
|
||||||
|
WithOpenDuration(time.Hour),
|
||||||
|
)(base)
|
||||||
|
|
||||||
|
t.Run("failing host trips breaker", func(t *testing.T) {
|
||||||
|
for i := 0; i < threshold; i++ {
|
||||||
|
req, err := http.NewRequest(http.MethodGet, "https://failing.example.com/test", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
resp, err := rt.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("iteration %d: unexpected error: %v", i, err)
|
||||||
|
}
|
||||||
|
resp.Body.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Next request to failing host should be rejected.
|
||||||
|
req, err := http.NewRequest(http.MethodGet, "https://failing.example.com/test", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
_, err = rt.RoundTrip(req)
|
||||||
|
if !errors.Is(err, ErrCircuitOpen) {
|
||||||
|
t.Fatalf("err = %v, want %v", err, ErrCircuitOpen)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("healthy host is unaffected", func(t *testing.T) {
|
||||||
|
req, err := http.NewRequest(http.MethodGet, "https://healthy.example.com/test", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
resp, err := rt.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
resp.Body.Close()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTransport_SuccessResetsFailures(t *testing.T) {
|
||||||
|
callCount := 0
|
||||||
|
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||||
|
callCount++
|
||||||
|
// Fail on odd calls, succeed on even.
|
||||||
|
if callCount%2 == 1 {
|
||||||
|
return errResponse(http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
return okResponse(), nil
|
||||||
|
})
|
||||||
|
|
||||||
|
rt := Transport(
|
||||||
|
WithFailureThreshold(3),
|
||||||
|
WithOpenDuration(time.Hour),
|
||||||
|
)(base)
|
||||||
|
|
||||||
|
// Alternate fail/success — should never trip because successes reset the
|
||||||
|
// consecutive failure counter.
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
req, err := http.NewRequest(http.MethodGet, "https://host.example.com/test", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
resp, err := rt.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("iteration %d: unexpected error (circuit should not be open): %v", i, err)
|
||||||
|
}
|
||||||
|
resp.Body.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
50
circuitbreaker/options.go
Normal file
50
circuitbreaker/options.go
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
package circuitbreaker
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
|
// options holds the tunable parameters of a Breaker.
type options struct {
	failureThreshold int           // consecutive failures that trip Closed → Open
	openDuration     time.Duration // time spent Open before moving to HalfOpen
	halfOpenMax      int           // cap on concurrent requests while HalfOpen
}

// defaults returns the configuration used when no Option overrides it:
// 5 consecutive failures, a 30s open period and a single half-open request.
func defaults() options {
	var o options
	o.failureThreshold = 5
	o.openDuration = 30 * time.Second
	o.halfOpenMax = 1
	return o
}

// Option configures a Breaker.
type Option func(*options)

// WithFailureThreshold sets how many consecutive failures trip the breaker
// from Closed to Open. The default is 5. Values of n that are zero or
// negative are ignored and leave the current setting in place.
func WithFailureThreshold(n int) Option {
	return func(o *options) {
		if n <= 0 {
			return
		}
		o.failureThreshold = n
	}
}

// WithOpenDuration sets how long the breaker remains Open before it moves to
// HalfOpen. The default is 30s. Durations that are zero or negative are
// ignored and leave the current setting in place.
func WithOpenDuration(d time.Duration) Option {
	return func(o *options) {
		if d <= 0 {
			return
		}
		o.openDuration = d
	}
}

// WithHalfOpenMax sets the maximum number of concurrent probe requests
// admitted while the breaker is HalfOpen. The default is 1. Values of n that
// are zero or negative are ignored and leave the current setting in place.
func WithHalfOpenMax(n int) Option {
	return func(o *options) {
		if n <= 0 {
			return
		}
		o.halfOpenMax = n
	}
}
|
||||||
Reference in New Issue
Block a user