From 01478be0dce17f887d1995ca999bbd394b302b4e Mon Sep 17 00:00:00 2001 From: Aleksey Shakhmatov Date: Sat, 23 May 2026 13:47:33 +0300 Subject: [PATCH] Replace balancer panic with deferred error; test HealthChecker A malformed endpoint URL panicked inside Transport, crashing the host app (often at startup from external config). Capture the parse error and surface it from the transport on first use instead. Add the previously untested HealthChecker coverage (initial probe, recovery, Stop termination, unknown endpoint), raising balancer coverage from ~41% to ~87%. Default the health probe path to /healthz to match this library's own server. --- balancer/balancer.go | 15 +++++- balancer/balancer_test.go | 21 +++++++++ balancer/health.go | 2 +- balancer/health_test.go | 99 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 134 insertions(+), 3 deletions(-) create mode 100644 balancer/health_test.go diff --git a/balancer/balancer.go b/balancer/balancer.go index 29f5e53..d50b11f 100644 --- a/balancer/balancer.go +++ b/balancer/balancer.go @@ -55,12 +55,19 @@ func Transport(endpoints []Endpoint, opts ...Option) (middleware.Middleware, *Cl opt(o) } - // Pre-parse endpoint URLs once at construction time. + // Pre-parse endpoint URLs once at construction time. A malformed URL is a + // configuration error: rather than panicking (which would crash the host + // application, often at startup from external config), we capture the + // error and surface it from the transport on first use. parsed := make(map[string]*url.URL, len(endpoints)) + var parseErr error for _, ep := range endpoints { u, err := url.Parse(ep.URL) if err != nil { - panic(fmt.Sprintf("balancer: invalid endpoint URL %q: %v", ep.URL, err)) + if parseErr == nil { + parseErr = fmt.Errorf("balancer: invalid endpoint URL %q: %w", ep.URL, err) + } + continue } parsed[ep.URL] = u } @@ -73,6 +80,10 @@ func Transport(endpoints []Endpoint, opts ...Option) (middleware.Middleware, *Cl return func(next http.RoundTripper) http.RoundTripper { return middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) { + if parseErr != nil { + return nil, parseErr + } + healthy := endpoints if o.healthChecker != nil { healthy = o.healthChecker.Healthy(endpoints) diff --git a/balancer/balancer_test.go b/balancer/balancer_test.go index 2b05ffc..68f8fd0 100644 --- a/balancer/balancer_test.go +++ b/balancer/balancer_test.go @@ -61,6 +61,27 @@ func TestTransport_PicksEndpointAndReplacesURL(t *testing.T) { } } +func TestTransport_InvalidEndpointURLReturnsError(t *testing.T) { + base := mockTransport(func(req *http.Request) (*http.Response, error) { + t.Fatal("base transport should not be reached for an invalid endpoint") + return nil, nil + }) + + // A malformed URL must not panic; the error surfaces on first use. + mw, closer := Transport([]Endpoint{{URL: "://missing-scheme"}}) + defer closer.Close() + rt := mw(base) + + req, err := http.NewRequest(http.MethodGet, "https://original.example.com/", nil) + if err != nil { + t.Fatal(err) + } + + if _, err := rt.RoundTrip(req); err == nil { + t.Fatal("expected an error for invalid endpoint URL, got nil") + } +} + func TestTransport_ErrNoHealthyWhenNoEndpoints(t *testing.T) { var endpoints []Endpoint base := mockTransport(func(req *http.Request) (*http.Response, error) { diff --git a/balancer/health.go b/balancer/health.go index 4ffc7dd..55b7ed5 100644 --- a/balancer/health.go +++ b/balancer/health.go @@ -10,7 +10,7 @@ import ( const ( defaultHealthInterval = 10 * time.Second - defaultHealthPath = "/health" + defaultHealthPath = "/healthz" defaultHealthTimeout = 5 * time.Second ) diff --git a/balancer/health_test.go b/balancer/health_test.go new file mode 100644 index 0000000..e29e470 --- /dev/null +++ b/balancer/health_test.go @@ -0,0 +1,99 @@ +package balancer + +import ( + "context" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" +) + +func TestHealthChecker_InitialProbeClassifiesEndpoints(t *testing.T) { + healthy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer healthy.Close() + + unhealthy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusServiceUnavailable) + })) + defer unhealthy.Close() + + eps := []Endpoint{{URL: healthy.URL}, {URL: unhealthy.URL}} + + hc := newHealthChecker() + hc.Start(eps) // runs an initial synchronous probe + defer hc.Stop() + + if !hc.IsHealthy(eps[0]) { + t.Errorf("healthy endpoint reported unhealthy") + } + if hc.IsHealthy(eps[1]) { + t.Errorf("unhealthy endpoint reported healthy") + } + + got := hc.Healthy(eps) + if len(got) != 1 || got[0].URL != healthy.URL { + t.Errorf("Healthy() = %v, want only %s", got, healthy.URL) + } +} + +func TestHealthChecker_DetectsRecovery(t *testing.T) { + var up atomic.Bool + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + if up.Load() { + w.WriteHeader(http.StatusOK) + return + } + w.WriteHeader(http.StatusServiceUnavailable) + })) + defer srv.Close() + + eps := []Endpoint{{URL: srv.URL}} + + hc := newHealthChecker() + hc.Start(eps) + defer hc.Stop() + + if hc.IsHealthy(eps[0]) { + t.Fatalf("endpoint should start unhealthy") + } + + // Recover the backend and force a deterministic re-probe. + up.Store(true) + hc.probe(context.Background(), eps) + + if !hc.IsHealthy(eps[0]) { + t.Fatalf("endpoint should be healthy after recovery") + } +} + +func TestHealthChecker_StopTerminatesLoop(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + hc := newHealthChecker(WithHealthInterval(time.Millisecond)) + hc.Start([]Endpoint{{URL: srv.URL}}) + + done := make(chan struct{}) + go func() { + hc.Stop() + close(done) + }() + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("Stop did not return within 2s — loop goroutine leaked") + } +} + +func TestHealthChecker_UnknownEndpointIsUnhealthy(t *testing.T) { + hc := newHealthChecker() + if hc.IsHealthy(Endpoint{URL: "http://never-probed.example"}) { + t.Error("unknown endpoint should be reported unhealthy") + } +}