package balancer

import (
	"errors"
	"fmt"
	"net/http"
	"net/url"

	"git.codelab.vc/pkg/httpx/middleware"
)

// ErrNoHealthy is returned when no healthy endpoints are available.
var ErrNoHealthy = errors.New("httpx: no healthy endpoints available")

// Endpoint represents a backend server that can handle requests.
type Endpoint struct {
	// URL is the endpoint's base URL. Transport applies only its scheme and
	// host to outgoing requests.
	URL string
	// Weight is not read by Transport itself; presumably it is consumed by
	// weighted strategies (confirm against the Strategy implementations).
	Weight int
	// Meta carries caller-defined key/value data. Transport does not read it.
	Meta map[string]string
}

// Strategy selects an endpoint from the list of healthy endpoints.
type Strategy interface {
	// Next returns the endpoint that should serve the next request, or an
	// error if none can be selected.
	Next(healthy []Endpoint) (Endpoint, error)
}

// Closer can be used to shut down resources associated with a balancer
// transport (e.g. background health checker goroutines).
type Closer struct {
	healthChecker *HealthChecker
}

// Close stops the health checker, if one was configured. It is a no-op on a
// nil Closer or when no health checker is attached.
//
// NOTE(review): calling Close more than once is only as safe as
// HealthChecker.Stop is idempotent; confirm against its implementation.
func (c *Closer) Close() {
	if c == nil || c.healthChecker == nil {
		return
	}
	c.healthChecker.Stop()
}

// Transport returns a middleware that load-balances requests across the
// provided endpoints using the configured strategy, together with a Closer
// that stops any background health checking.
//
// For each request the middleware picks an endpoint via the strategy,
// replaces the request URL scheme and host (and the Host field) with the
// endpoint's, and forwards a shallow copy of the request to the underlying
// RoundTripper. The caller's request is never mutated.
//
// If active health checking is enabled (WithHealthCheck), the health checker
// is started at construction time and consulted on every request. Otherwise
// all endpoints are assumed healthy.
//
// The returned RoundTripper fails with ErrNoHealthy when no endpoint is
// available, with the strategy's error when selection fails, and with a
// descriptive error when the strategy yields an endpoint that was not passed
// to Transport. In all of these cases the request body is closed, as the
// http.RoundTripper contract requires.
//
// Transport panics if any endpoint URL cannot be parsed.
func Transport(endpoints []Endpoint, opts ...Option) (middleware.Middleware, *Closer) {
	o := &options{
		strategy: RoundRobin(),
	}
	for _, opt := range opts {
		opt(o)
	}

	// Pre-parse endpoint URLs once at construction time.
	parsed := make(map[string]*url.URL, len(endpoints))
	for _, ep := range endpoints {
		u, err := url.Parse(ep.URL)
		if err != nil {
			panic(fmt.Sprintf("balancer: invalid endpoint URL %q: %v", ep.URL, err))
		}
		parsed[ep.URL] = u
	}

	if o.healthChecker != nil {
		o.healthChecker.Start(endpoints)
	}

	closer := &Closer{healthChecker: o.healthChecker}

	// fail aborts a round trip before the request is forwarded. RoundTrip
	// must close the request body even on errors, and http.Client relies on
	// that, so the body is released here.
	fail := func(req *http.Request, err error) (*http.Response, error) {
		if req.Body != nil {
			// Best-effort: the selection error is the one worth reporting.
			_ = req.Body.Close()
		}
		return nil, err
	}

	return func(next http.RoundTripper) http.RoundTripper {
		return middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
			healthy := endpoints
			if o.healthChecker != nil {
				healthy = o.healthChecker.Healthy(endpoints)
			}
			if len(healthy) == 0 {
				return fail(req, ErrNoHealthy)
			}

			ep, err := o.strategy.Next(healthy)
			if err != nil {
				return fail(req, err)
			}

			// A strategy or health checker may hand back an endpoint that was
			// never registered; report it instead of dereferencing a nil URL.
			epURL, ok := parsed[ep.URL]
			if !ok {
				return fail(req, fmt.Errorf("balancer: strategy returned unknown endpoint %q", ep.URL))
			}

			// Shallow-copy request and URL to avoid mutating the original,
			// without the expense of req.Clone's deep header copy.
			r := *req
			u := *req.URL
			r.URL = &u
			r.URL.Scheme = epURL.Scheme
			r.URL.Host = epURL.Host
			r.Host = epURL.Host

			return next.RoundTrip(&r)
		})
	}, closer
}