From 8d322123a4456d13b439236c877fcd1743233c84 Mon Sep 17 00:00:00 2001 From: Aleksey Shakhmatov Date: Fri, 20 Mar 2026 14:22:07 +0300 Subject: [PATCH] Add load balancer with round-robin, failover, and weighted strategies Implements balancer middleware with URL rewriting per-request: - RoundRobin, Failover, and WeightedRandom endpoint selection strategies - Background HealthChecker with configurable probe interval and path - Thread-safe health state tracking with sync.RWMutex --- balancer/balancer.go | 78 ++++++++++++++ balancer/balancer_test.go | 212 ++++++++++++++++++++++++++++++++++++++ balancer/failover.go | 17 +++ balancer/health.go | 162 +++++++++++++++++++++++++++++ balancer/options.go | 25 +++++ balancer/roundrobin.go | 21 ++++ balancer/weighted.go | 42 ++++++++ 7 files changed, 557 insertions(+) create mode 100644 balancer/balancer.go create mode 100644 balancer/balancer_test.go create mode 100644 balancer/failover.go create mode 100644 balancer/health.go create mode 100644 balancer/options.go create mode 100644 balancer/roundrobin.go create mode 100644 balancer/weighted.go diff --git a/balancer/balancer.go b/balancer/balancer.go new file mode 100644 index 0000000..d561670 --- /dev/null +++ b/balancer/balancer.go @@ -0,0 +1,78 @@ +package balancer + +import ( + "errors" + "net/http" + "net/url" + + "git.codelab.vc/pkg/httpx/middleware" +) + +// ErrNoHealthy is returned when no healthy endpoints are available. +var ErrNoHealthy = errors.New("httpx: no healthy endpoints available") + +// Endpoint represents a backend server that can handle requests. +type Endpoint struct { + URL string + Weight int + Meta map[string]string +} + +// Strategy selects an endpoint from the list of healthy endpoints. +type Strategy interface { + Next(healthy []Endpoint) (Endpoint, error) +} + +// Transport returns a middleware that load-balances requests across the +// provided endpoints using the configured strategy. +// +// For each request the middleware picks an endpoint via the strategy, +// replaces the request URL scheme and host with the endpoint's URL, +// and forwards the request to the underlying RoundTripper. +// +// If active health checking is enabled (WithHealthCheck), a background +// goroutine periodically probes endpoints. Otherwise all endpoints are +// assumed healthy. +func Transport(endpoints []Endpoint, opts ...Option) middleware.Middleware { + o := &options{ + strategy: RoundRobin(), + } + for _, opt := range opts { + opt(o) + } + + if o.healthChecker != nil { + o.healthChecker.Start(endpoints) + } + + return func(next http.RoundTripper) http.RoundTripper { + return middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) { + healthy := endpoints + if o.healthChecker != nil { + healthy = o.healthChecker.Healthy(endpoints) + } + + if len(healthy) == 0 { + return nil, ErrNoHealthy + } + + ep, err := o.strategy.Next(healthy) + if err != nil { + return nil, err + } + + epURL, err := url.Parse(ep.URL) + if err != nil { + return nil, err + } + + // Clone the request URL and replace scheme+host with the endpoint. + r := req.Clone(req.Context()) + r.URL.Scheme = epURL.Scheme + r.URL.Host = epURL.Host + r.Host = epURL.Host + + return next.RoundTrip(r) + }) + } +} diff --git a/balancer/balancer_test.go b/balancer/balancer_test.go new file mode 100644 index 0000000..d659cb3 --- /dev/null +++ b/balancer/balancer_test.go @@ -0,0 +1,212 @@ +package balancer + +import ( + "io" + "math" + "net/http" + "strings" + "testing" + + "git.codelab.vc/pkg/httpx/middleware" +) + +func mockTransport(fn func(*http.Request) (*http.Response, error)) http.RoundTripper { + return middleware.RoundTripperFunc(fn) +} + +func okResponse() *http.Response { + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("")), + Header: make(http.Header), + } +} + +func TestTransport_PicksEndpointAndReplacesURL(t *testing.T) { + endpoints := []Endpoint{ + {URL: "https://backend1.example.com"}, + } + + var captured *http.Request + base := mockTransport(func(req *http.Request) (*http.Response, error) { + captured = req + return okResponse(), nil + }) + + rt := Transport(endpoints)(base) + + req, err := http.NewRequest(http.MethodGet, "https://original.example.com/api/v1/users", nil) + if err != nil { + t.Fatal(err) + } + + resp, err := rt.RoundTrip(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer resp.Body.Close() + + if captured == nil { + t.Fatal("base transport was not called") + } + if captured.URL.Scheme != "https" { + t.Errorf("scheme = %q, want %q", captured.URL.Scheme, "https") + } + if captured.URL.Host != "backend1.example.com" { + t.Errorf("host = %q, want %q", captured.URL.Host, "backend1.example.com") + } + if captured.URL.Path != "/api/v1/users" { + t.Errorf("path = %q, want %q", captured.URL.Path, "/api/v1/users") + } +} + +func TestTransport_ErrNoHealthyWhenNoEndpoints(t *testing.T) { + var endpoints []Endpoint + base := mockTransport(func(req *http.Request) (*http.Response, error) { + t.Fatal("base transport should not be called") + return nil, nil + }) + + rt := Transport(endpoints)(base) + + req, err := http.NewRequest(http.MethodGet, "https://example.com/test", nil) + if err != nil { + t.Fatal(err) + } + + _, err = rt.RoundTrip(req) + if err != ErrNoHealthy { + t.Fatalf("err = %v, want %v", err, ErrNoHealthy) + } +} + +func TestRoundRobin_DistributesEvenly(t *testing.T) { + endpoints := []Endpoint{ + {URL: "https://a.example.com"}, + {URL: "https://b.example.com"}, + {URL: "https://c.example.com"}, + } + + rr := RoundRobin() + counts := make(map[string]int) + + const iterations = 300 + for i := 0; i < iterations; i++ { + ep, err := rr.Next(endpoints) + if err != nil { + t.Fatalf("iteration %d: unexpected error: %v", i, err) + } + counts[ep.URL]++ + } + + expected := iterations / len(endpoints) + for _, ep := range endpoints { + got := counts[ep.URL] + if got != expected { + t.Errorf("endpoint %s: got %d calls, want %d", ep.URL, got, expected) + } + } +} + +func TestRoundRobin_ErrNoHealthy(t *testing.T) { + rr := RoundRobin() + _, err := rr.Next(nil) + if err != ErrNoHealthy { + t.Fatalf("err = %v, want %v", err, ErrNoHealthy) + } +} + +func TestFailover_AlwaysPicksFirst(t *testing.T) { + endpoints := []Endpoint{ + {URL: "https://primary.example.com"}, + {URL: "https://secondary.example.com"}, + {URL: "https://tertiary.example.com"}, + } + + fo := Failover() + + for i := 0; i < 10; i++ { + ep, err := fo.Next(endpoints) + if err != nil { + t.Fatalf("iteration %d: unexpected error: %v", i, err) + } + if ep.URL != "https://primary.example.com" { + t.Errorf("iteration %d: got %q, want %q", i, ep.URL, "https://primary.example.com") + } + } +} + +func TestFailover_ErrNoHealthy(t *testing.T) { + fo := Failover() + _, err := fo.Next(nil) + if err != ErrNoHealthy { + t.Fatalf("err = %v, want %v", err, ErrNoHealthy) + } +} + +func TestWeightedRandom_RespectsWeights(t *testing.T) { + endpoints := []Endpoint{ + {URL: "https://heavy.example.com", Weight: 80}, + {URL: "https://light.example.com", Weight: 20}, + } + + wr := WeightedRandom() + counts := make(map[string]int) + + const iterations = 10000 + for i := 0; i < iterations; i++ { + ep, err := wr.Next(endpoints) + if err != nil { + t.Fatalf("iteration %d: unexpected error: %v", i, err) + } + counts[ep.URL]++ + } + + totalWeight := 0 + for _, ep := range endpoints { + totalWeight += ep.Weight + } + + for _, ep := range endpoints { + got := float64(counts[ep.URL]) / float64(iterations) + want := float64(ep.Weight) / float64(totalWeight) + if math.Abs(got-want) > 0.05 { + t.Errorf("endpoint %s: got ratio %.3f, want ~%.3f (tolerance 0.05)", ep.URL, got, want) + } + } +} + +func TestWeightedRandom_DefaultWeightForZero(t *testing.T) { + endpoints := []Endpoint{ + {URL: "https://a.example.com", Weight: 0}, + {URL: "https://b.example.com", Weight: 0}, + } + + wr := WeightedRandom() + counts := make(map[string]int) + + const iterations = 1000 + for i := 0; i < iterations; i++ { + ep, err := wr.Next(endpoints) + if err != nil { + t.Fatalf("iteration %d: unexpected error: %v", i, err) + } + counts[ep.URL]++ + } + + // With equal default weights, distribution should be roughly even. + for _, ep := range endpoints { + got := float64(counts[ep.URL]) / float64(iterations) + if math.Abs(got-0.5) > 0.1 { + t.Errorf("endpoint %s: got ratio %.3f, want ~0.5 (tolerance 0.1)", ep.URL, got) + } + } +} + +func TestWeightedRandom_ErrNoHealthy(t *testing.T) { + wr := WeightedRandom() + _, err := wr.Next(nil) + if err != ErrNoHealthy { + t.Fatalf("err = %v, want %v", err, ErrNoHealthy) + } +} diff --git a/balancer/failover.go b/balancer/failover.go new file mode 100644 index 0000000..8ec9930 --- /dev/null +++ b/balancer/failover.go @@ -0,0 +1,17 @@ +package balancer + +type failover struct{} + +// Failover returns a strategy that always picks the first healthy endpoint. +// If the primary endpoint is unhealthy, it falls back to the next available +// healthy endpoint in order. +func Failover() Strategy { + return &failover{} +} + +func (f *failover) Next(healthy []Endpoint) (Endpoint, error) { + if len(healthy) == 0 { + return Endpoint{}, ErrNoHealthy + } + return healthy[0], nil +} diff --git a/balancer/health.go b/balancer/health.go new file mode 100644 index 0000000..90db44a --- /dev/null +++ b/balancer/health.go @@ -0,0 +1,162 @@ +package balancer + +import ( + "context" + "net/http" + "sync" + "time" +) + +const ( + defaultHealthInterval = 10 * time.Second + defaultHealthPath = "/health" + defaultHealthTimeout = 5 * time.Second +) + +// HealthOption configures the HealthChecker. +type HealthOption func(*HealthChecker) + +// WithHealthInterval sets the interval between health check probes. +// Default is 10 seconds. +func WithHealthInterval(d time.Duration) HealthOption { + return func(h *HealthChecker) { + h.interval = d + } +} + +// WithHealthPath sets the HTTP path to probe for health checks. +// Default is "/health". +func WithHealthPath(path string) HealthOption { + return func(h *HealthChecker) { + h.path = path + } +} + +// WithHealthTimeout sets the timeout for each health check request. +// Default is 5 seconds. +func WithHealthTimeout(d time.Duration) HealthOption { + return func(h *HealthChecker) { + h.timeout = d + } +} + +// HealthChecker periodically probes endpoints to determine their health status. +type HealthChecker struct { + interval time.Duration + path string + timeout time.Duration + client *http.Client + + mu sync.RWMutex + status map[string]bool + cancel context.CancelFunc + stopped chan struct{} +} + +func newHealthChecker(opts ...HealthOption) *HealthChecker { + h := &HealthChecker{ + interval: defaultHealthInterval, + path: defaultHealthPath, + timeout: defaultHealthTimeout, + status: make(map[string]bool), + } + for _, opt := range opts { + opt(h) + } + h.client = &http.Client{ + Timeout: h.timeout, + } + return h +} + +// Start begins the background health checking loop for the given endpoints. +// All endpoints are initially considered healthy. +func (h *HealthChecker) Start(endpoints []Endpoint) { + h.mu.Lock() + for _, ep := range endpoints { + h.status[ep.URL] = true + } + h.mu.Unlock() + + ctx, cancel := context.WithCancel(context.Background()) + h.cancel = cancel + h.stopped = make(chan struct{}) + + go h.loop(ctx, endpoints) +} + +// Stop terminates the background health checking goroutine and waits for +// it to finish. +func (h *HealthChecker) Stop() { + if h.cancel != nil { + h.cancel() + <-h.stopped + } +} + +// IsHealthy reports whether the given endpoint is currently healthy. +func (h *HealthChecker) IsHealthy(ep Endpoint) bool { + h.mu.RLock() + defer h.mu.RUnlock() + + healthy, ok := h.status[ep.URL] + if !ok { + return false + } + return healthy +} + +// Healthy returns the subset of endpoints that are currently healthy. +func (h *HealthChecker) Healthy(endpoints []Endpoint) []Endpoint { + h.mu.RLock() + defer h.mu.RUnlock() + + var result []Endpoint + for _, ep := range endpoints { + if h.status[ep.URL] { + result = append(result, ep) + } + } + return result +} + +func (h *HealthChecker) loop(ctx context.Context, endpoints []Endpoint) { + defer close(h.stopped) + + ticker := time.NewTicker(h.interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + h.probe(ctx, endpoints) + } + } +} + +func (h *HealthChecker) probe(ctx context.Context, endpoints []Endpoint) { + for _, ep := range endpoints { + healthy := h.check(ctx, ep) + + h.mu.Lock() + h.status[ep.URL] = healthy + h.mu.Unlock() + } +} + +func (h *HealthChecker) check(ctx context.Context, ep Endpoint) bool { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, ep.URL+h.path, nil) + if err != nil { + return false + } + + resp, err := h.client.Do(req) + if err != nil { + return false + } + resp.Body.Close() + + return resp.StatusCode >= 200 && resp.StatusCode < 300 +} diff --git a/balancer/options.go b/balancer/options.go new file mode 100644 index 0000000..d818545 --- /dev/null +++ b/balancer/options.go @@ -0,0 +1,25 @@ +package balancer + +// options holds configuration for the load balancer transport. +type options struct { + strategy Strategy // default RoundRobin + healthChecker *HealthChecker // optional +} + +// Option configures the load balancer transport. +type Option func(*options) + +// WithStrategy sets the endpoint selection strategy. +// If not specified, RoundRobin is used. +func WithStrategy(s Strategy) Option { + return func(o *options) { + o.strategy = s + } +} + +// WithHealthCheck enables active health checking of endpoints. +func WithHealthCheck(opts ...HealthOption) Option { + return func(o *options) { + o.healthChecker = newHealthChecker(opts...) + } +} diff --git a/balancer/roundrobin.go b/balancer/roundrobin.go new file mode 100644 index 0000000..819e0f0 --- /dev/null +++ b/balancer/roundrobin.go @@ -0,0 +1,21 @@ +package balancer + +import "sync/atomic" + +type roundRobin struct { + counter atomic.Uint64 +} + +// RoundRobin returns a strategy that cycles through healthy endpoints +// sequentially using an atomic counter. +func RoundRobin() Strategy { + return &roundRobin{} +} + +func (r *roundRobin) Next(healthy []Endpoint) (Endpoint, error) { + if len(healthy) == 0 { + return Endpoint{}, ErrNoHealthy + } + idx := r.counter.Add(1) - 1 + return healthy[idx%uint64(len(healthy))], nil +} diff --git a/balancer/weighted.go b/balancer/weighted.go new file mode 100644 index 0000000..7472f56 --- /dev/null +++ b/balancer/weighted.go @@ -0,0 +1,42 @@ +package balancer + +import "math/rand/v2" + +type weightedRandom struct{} + +// WeightedRandom returns a strategy that selects endpoints randomly, +// weighted by each endpoint's Weight field. Endpoints with Weight <= 0 +// are treated as having a weight of 1. +func WeightedRandom() Strategy { + return &weightedRandom{} +} + +func (w *weightedRandom) Next(healthy []Endpoint) (Endpoint, error) { + if len(healthy) == 0 { + return Endpoint{}, ErrNoHealthy + } + + totalWeight := 0 + for _, ep := range healthy { + weight := ep.Weight + if weight <= 0 { + weight = 1 + } + totalWeight += weight + } + + r := rand.IntN(totalWeight) + for _, ep := range healthy { + weight := ep.Weight + if weight <= 0 { + weight = 1 + } + r -= weight + if r < 0 { + return ep, nil + } + } + + // Should never reach here, but return last endpoint as a safeguard. + return healthy[len(healthy)-1], nil +}