package balancer

import (
	"errors"
	"io"
	"math"
	"net/http"
	"strings"
	"testing"

	"git.codelab.vc/pkg/httpx/middleware"
)

// mockTransport adapts fn into an http.RoundTripper by way of
// middleware.RoundTripperFunc, so a test can stub the base transport with a
// plain closure.
func mockTransport(fn func(*http.Request) (*http.Response, error)) http.RoundTripper {
	return middleware.RoundTripperFunc(fn)
}

// okResponse returns a new 200 OK response with an empty body and an empty,
// non-nil header map.
func okResponse() *http.Response {
	return &http.Response{
		StatusCode: http.StatusOK,
		Body:       io.NopCloser(strings.NewReader("")),
		Header:     make(http.Header),
	}
}

// TestTransport_PicksEndpointAndReplacesURL checks that, with a single
// configured endpoint, the request reaching the base transport carries the
// endpoint's scheme and host while keeping the original request path.
func TestTransport_PicksEndpointAndReplacesURL(t *testing.T) {
	endpoints := []Endpoint{
		{URL: "https://backend1.example.com"},
	}

	// captured records the request as seen by the base transport.
	var captured *http.Request
	base := mockTransport(func(req *http.Request) (*http.Response, error) {
		captured = req
		return okResponse(), nil
	})

	// NOTE(review): the second value returned by Transport is discarded; its
	// type is not visible from this file. Confirm it needs no handling
	// (error check or cleanup).
	mw, _ := Transport(endpoints)
	rt := mw(base)

	req, err := http.NewRequest(http.MethodGet, "https://original.example.com/api/v1/users", nil)
	if err != nil {
		t.Fatal(err)
	}

	resp, err := rt.RoundTrip(req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	defer resp.Body.Close()

	if captured == nil {
		t.Fatal("base transport was not called")
	}
	if captured.URL.Scheme != "https" {
		t.Errorf("scheme = %q, want %q", captured.URL.Scheme, "https")
	}
	if captured.URL.Host != "backend1.example.com" {
		t.Errorf("host = %q, want %q", captured.URL.Host, "backend1.example.com")
	}
	if captured.URL.Path != "/api/v1/users" {
		t.Errorf("path = %q, want %q", captured.URL.Path, "/api/v1/users")
	}
}

// TestTransport_ErrNoHealthyWhenNoEndpoints checks that a transport built
// from a nil endpoint slice fails the round trip with ErrNoHealthy and never
// invokes the base transport.
func TestTransport_ErrNoHealthyWhenNoEndpoints(t *testing.T) {
	var endpoints []Endpoint

	base := mockTransport(func(req *http.Request) (*http.Response, error) {
		t.Fatal("base transport should not be called")
		return nil, nil
	})

	// NOTE(review): the second value returned by Transport is discarded; its
	// type is not visible from this file. Confirm it needs no handling
	// (error check or cleanup).
	mw, _ := Transport(endpoints)
	rt := mw(base)

	req, err := http.NewRequest(http.MethodGet, "https://example.com/test", nil)
	if err != nil {
		t.Fatal(err)
	}

	_, err = rt.RoundTrip(req)
	// errors.Is matches the bare sentinel and any error that wraps it.
	if !errors.Is(err, ErrNoHealthy) {
		t.Fatalf("err = %v, want %v", err, ErrNoHealthy)
	}
}

// TestRoundRobin_DistributesEvenly checks that 300 consecutive picks over
// three endpoints select each endpoint exactly 100 times.
func TestRoundRobin_DistributesEvenly(t *testing.T) {
	endpoints := []Endpoint{
		{URL: "https://a.example.com"},
		{URL: "https://b.example.com"},
		{URL: "https://c.example.com"},
	}

	rr := RoundRobin()
	counts := make(map[string]int)

	// iterations is a multiple of len(endpoints), so the expected share per
	// endpoint is an exact integer.
	const iterations = 300
	for i := 0; i < iterations; i++ {
		ep, err := rr.Next(endpoints)
		if err != nil {
			t.Fatalf("iteration %d: unexpected error: %v", i, err)
		}
		counts[ep.URL]++
	}

	expected := iterations / len(endpoints)
	for _, ep := range endpoints {
		got := counts[ep.URL]
		if got != expected {
			t.Errorf("endpoint %s: got %d calls, want %d", ep.URL, got, expected)
		}
	}
}

// TestRoundRobin_ErrNoHealthy checks that Next on a nil endpoint slice
// reports ErrNoHealthy.
func TestRoundRobin_ErrNoHealthy(t *testing.T) {
	rr := RoundRobin()
	_, err := rr.Next(nil)
	if !errors.Is(err, ErrNoHealthy) {
		t.Fatalf("err = %v, want %v", err, ErrNoHealthy)
	}
}

// TestFailover_AlwaysPicksFirst checks that ten consecutive picks all return
// the first endpoint in the slice.
func TestFailover_AlwaysPicksFirst(t *testing.T) {
	endpoints := []Endpoint{
		{URL: "https://primary.example.com"},
		{URL: "https://secondary.example.com"},
		{URL: "https://tertiary.example.com"},
	}

	fo := Failover()

	for i := 0; i < 10; i++ {
		ep, err := fo.Next(endpoints)
		if err != nil {
			t.Fatalf("iteration %d: unexpected error: %v", i, err)
		}
		if ep.URL != "https://primary.example.com" {
			t.Errorf("iteration %d: got %q, want %q", i, ep.URL, "https://primary.example.com")
		}
	}
}

// TestFailover_ErrNoHealthy checks that Next on a nil endpoint slice reports
// ErrNoHealthy.
func TestFailover_ErrNoHealthy(t *testing.T) {
	fo := Failover()
	_, err := fo.Next(nil)
	if !errors.Is(err, ErrNoHealthy) {
		t.Fatalf("err = %v, want %v", err, ErrNoHealthy)
	}
}

// TestWeightedRandom_RespectsWeights checks that, over 10000 picks, each
// endpoint's observed selection ratio is within 0.05 of its share of the
// total weight (80/20 here).
func TestWeightedRandom_RespectsWeights(t *testing.T) {
	endpoints := []Endpoint{
		{URL: "https://heavy.example.com", Weight: 80},
		{URL: "https://light.example.com", Weight: 20},
	}

	wr := WeightedRandom()
	counts := make(map[string]int)

	const iterations = 10000
	for i := 0; i < iterations; i++ {
		ep, err := wr.Next(endpoints)
		if err != nil {
			t.Fatalf("iteration %d: unexpected error: %v", i, err)
		}
		counts[ep.URL]++
	}

	totalWeight := 0
	for _, ep := range endpoints {
		totalWeight += ep.Weight
	}

	for _, ep := range endpoints {
		got := float64(counts[ep.URL]) / float64(iterations)
		want := float64(ep.Weight) / float64(totalWeight)
		if math.Abs(got-want) > 0.05 {
			t.Errorf("endpoint %s: got ratio %.3f, want ~%.3f (tolerance 0.05)", ep.URL, got, want)
		}
	}
}

// TestWeightedRandom_DefaultWeightForZero checks that two endpoints that both
// have Weight 0 are still selected, each within 0.1 of a 0.5 ratio over 1000
// picks.
func TestWeightedRandom_DefaultWeightForZero(t *testing.T) {
	endpoints := []Endpoint{
		{URL: "https://a.example.com", Weight: 0},
		{URL: "https://b.example.com", Weight: 0},
	}

	wr := WeightedRandom()
	counts := make(map[string]int)

	const iterations = 1000
	for i := 0; i < iterations; i++ {
		ep, err := wr.Next(endpoints)
		if err != nil {
			t.Fatalf("iteration %d: unexpected error: %v", i, err)
		}
		counts[ep.URL]++
	}

	// With equal default weights, distribution should be roughly even.
	for _, ep := range endpoints {
		got := float64(counts[ep.URL]) / float64(iterations)
		if math.Abs(got-0.5) > 0.1 {
			t.Errorf("endpoint %s: got ratio %.3f, want ~0.5 (tolerance 0.1)", ep.URL, got)
		}
	}
}

// TestWeightedRandom_ErrNoHealthy checks that Next on a nil endpoint slice
// reports ErrNoHealthy.
func TestWeightedRandom_ErrNoHealthy(t *testing.T) {
	wr := WeightedRandom()
	_, err := wr.Next(nil)
	if !errors.Is(err, ErrNoHealthy) {
		t.Fatalf("err = %v, want %v", err, ErrNoHealthy)
	}
}