diff --git a/balancer/balancer.go b/balancer/balancer.go index 29f5e53..d50b11f 100644 --- a/balancer/balancer.go +++ b/balancer/balancer.go @@ -55,12 +55,19 @@ func Transport(endpoints []Endpoint, opts ...Option) (middleware.Middleware, *Cl opt(o) } - // Pre-parse endpoint URLs once at construction time. + // Pre-parse endpoint URLs once at construction time. A malformed URL is a + // configuration error: rather than panicking (which would crash the host + // application, often at startup from external config), we capture the + // error and surface it from the transport on first use. parsed := make(map[string]*url.URL, len(endpoints)) + var parseErr error for _, ep := range endpoints { u, err := url.Parse(ep.URL) if err != nil { - panic(fmt.Sprintf("balancer: invalid endpoint URL %q: %v", ep.URL, err)) + if parseErr == nil { + parseErr = fmt.Errorf("balancer: invalid endpoint URL %q: %w", ep.URL, err) + } + continue } parsed[ep.URL] = u } @@ -73,6 +80,10 @@ func Transport(endpoints []Endpoint, opts ...Option) (middleware.Middleware, *Cl return func(next http.RoundTripper) http.RoundTripper { return middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) { + if parseErr != nil { + return nil, parseErr + } + healthy := endpoints if o.healthChecker != nil { healthy = o.healthChecker.Healthy(endpoints) diff --git a/balancer/balancer_test.go b/balancer/balancer_test.go index 2b05ffc..68f8fd0 100644 --- a/balancer/balancer_test.go +++ b/balancer/balancer_test.go @@ -61,6 +61,27 @@ func TestTransport_PicksEndpointAndReplacesURL(t *testing.T) { } } +func TestTransport_InvalidEndpointURLReturnsError(t *testing.T) { + base := mockTransport(func(req *http.Request) (*http.Response, error) { + t.Fatal("base transport should not be reached for an invalid endpoint") + return nil, nil + }) + + // A malformed URL must not panic; the error surfaces on first use. + mw, closer := Transport([]Endpoint{{URL: "://missing-scheme"}}) + defer closer.Close() + rt := mw(base) + + req, err := http.NewRequest(http.MethodGet, "https://original.example.com/", nil) + if err != nil { + t.Fatal(err) + } + + if _, err := rt.RoundTrip(req); err == nil { + t.Fatal("expected an error for invalid endpoint URL, got nil") + } +} + func TestTransport_ErrNoHealthyWhenNoEndpoints(t *testing.T) { var endpoints []Endpoint base := mockTransport(func(req *http.Request) (*http.Response, error) { diff --git a/balancer/health.go b/balancer/health.go index 4ffc7dd..55b7ed5 100644 --- a/balancer/health.go +++ b/balancer/health.go @@ -10,7 +10,7 @@ import ( const ( defaultHealthInterval = 10 * time.Second - defaultHealthPath = "/health" + defaultHealthPath = "/healthz" defaultHealthTimeout = 5 * time.Second ) diff --git a/balancer/health_test.go b/balancer/health_test.go new file mode 100644 index 0000000..e29e470 --- /dev/null +++ b/balancer/health_test.go @@ -0,0 +1,99 @@ +package balancer + +import ( + "context" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" +) + +func TestHealthChecker_InitialProbeClassifiesEndpoints(t *testing.T) { + healthy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer healthy.Close() + + unhealthy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusServiceUnavailable) + })) + defer unhealthy.Close() + + eps := []Endpoint{{URL: healthy.URL}, {URL: unhealthy.URL}} + + hc := newHealthChecker() + hc.Start(eps) // runs an initial synchronous probe + defer hc.Stop() + + if !hc.IsHealthy(eps[0]) { + t.Errorf("healthy endpoint reported unhealthy") + } + if hc.IsHealthy(eps[1]) { + t.Errorf("unhealthy endpoint reported healthy") + } + + got := hc.Healthy(eps) + if len(got) != 1 || got[0].URL != healthy.URL { + t.Errorf("Healthy() = %v, want only %s", got, healthy.URL) + } +} + +func TestHealthChecker_DetectsRecovery(t *testing.T) { + var up atomic.Bool + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + if up.Load() { + w.WriteHeader(http.StatusOK) + return + } + w.WriteHeader(http.StatusServiceUnavailable) + })) + defer srv.Close() + + eps := []Endpoint{{URL: srv.URL}} + + hc := newHealthChecker() + hc.Start(eps) + defer hc.Stop() + + if hc.IsHealthy(eps[0]) { + t.Fatalf("endpoint should start unhealthy") + } + + // Recover the backend and force a deterministic re-probe. + up.Store(true) + hc.probe(context.Background(), eps) + + if !hc.IsHealthy(eps[0]) { + t.Fatalf("endpoint should be healthy after recovery") + } +} + +func TestHealthChecker_StopTerminatesLoop(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + hc := newHealthChecker(WithHealthInterval(time.Millisecond)) + hc.Start([]Endpoint{{URL: srv.URL}}) + + done := make(chan struct{}) + go func() { + hc.Stop() + close(done) + }() + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("Stop did not return within 2s — loop goroutine leaked") + } +} + +func TestHealthChecker_UnknownEndpointIsUnhealthy(t *testing.T) { + hc := newHealthChecker() + if hc.IsHealthy(Endpoint{URL: "http://never-probed.example"}) { + t.Error("unknown endpoint should be reported unhealthy") + } +}