package dbx import ( "context" "errors" "testing" "time" ) func TestRetrier_Success(t *testing.T) { r := newRetrier(RetryConfig{MaxAttempts: 3, BaseDelay: time.Millisecond, MaxDelay: 10 * time.Millisecond}, nopLogger{}, nil) nodes := makeTestNodes("n1") calls := 0 err := r.do(context.Background(), nodes, func(_ context.Context, n *Node) error { calls++ return nil }) if err != nil { t.Fatalf("unexpected error: %v", err) } if calls != 1 { t.Errorf("calls = %d, want 1", calls) } } func TestRetrier_RetriesOnRetryableError(t *testing.T) { r := newRetrier(RetryConfig{MaxAttempts: 3, BaseDelay: time.Millisecond, MaxDelay: 10 * time.Millisecond}, nopLogger{}, nil) nodes := makeTestNodes("n1") calls := 0 retryableErr := errors.New("connection refused") err := r.do(context.Background(), nodes, func(_ context.Context, n *Node) error { calls++ if calls < 3 { return retryableErr } return nil }) if err != nil { t.Fatalf("unexpected error: %v", err) } if calls != 3 { t.Errorf("calls = %d, want 3", calls) } } func TestRetrier_NonRetryableError(t *testing.T) { r := newRetrier(RetryConfig{MaxAttempts: 3, BaseDelay: time.Millisecond, MaxDelay: 10 * time.Millisecond}, nopLogger{}, nil) nodes := makeTestNodes("n1") syntaxErr := errors.New("syntax problem") calls := 0 err := r.do(context.Background(), nodes, func(_ context.Context, n *Node) error { calls++ return syntaxErr }) if !errors.Is(err, syntaxErr) { t.Errorf("expected syntax error, got %v", err) } if calls != 1 { t.Errorf("calls = %d, want 1 (should not retry non-retryable)", calls) } } func TestRetrier_Exhausted(t *testing.T) { r := newRetrier(RetryConfig{MaxAttempts: 2, BaseDelay: time.Millisecond, MaxDelay: 10 * time.Millisecond}, nopLogger{}, nil) nodes := makeTestNodes("n1") err := r.do(context.Background(), nodes, func(_ context.Context, n *Node) error { return errors.New("connection refused") }) if !errors.Is(err, ErrRetryExhausted) { t.Errorf("expected ErrRetryExhausted, got %v", err) } } func TestRetrier_FallbackToNextNode(t *testing.T) { r := newRetrier(RetryConfig{MaxAttempts: 2, BaseDelay: time.Millisecond, MaxDelay: 10 * time.Millisecond}, nopLogger{}, nil) nodes := makeTestNodes("replica-1", "master") visited := []string{} err := r.do(context.Background(), nodes, func(_ context.Context, n *Node) error { visited = append(visited, n.name) if n.name == "replica-1" { return errors.New("connection refused") } return nil }) if err != nil { t.Fatalf("unexpected error: %v", err) } if len(visited) < 2 || visited[1] != "master" { t.Errorf("expected fallback to master, visited: %v", visited) } } func TestRetrier_ContextCanceled(t *testing.T) { r := newRetrier(RetryConfig{MaxAttempts: 5, BaseDelay: time.Millisecond, MaxDelay: 10 * time.Millisecond}, nopLogger{}, nil) nodes := makeTestNodes("n1") ctx, cancel := context.WithCancel(context.Background()) cancel() err := r.do(ctx, nodes, func(_ context.Context, n *Node) error { return nil }) if err != context.Canceled { t.Errorf("expected context.Canceled, got %v", err) } } func TestRetrier_NoHealthyNodes(t *testing.T) { r := newRetrier(RetryConfig{MaxAttempts: 2, BaseDelay: time.Millisecond, MaxDelay: 10 * time.Millisecond}, nopLogger{}, nil) nodes := makeTestNodes("n1") nodes[0].healthy.Store(false) err := r.do(context.Background(), nodes, func(_ context.Context, n *Node) error { t.Fatal("should not be called") return nil }) if !errors.Is(err, ErrNoHealthyNode) { t.Errorf("expected ErrNoHealthyNode, got %v", err) } } func TestRetrier_CustomClassifier(t *testing.T) { custom := func(err error) bool { return err.Error() == "custom-retry" } r := newRetrier(RetryConfig{MaxAttempts: 3, BaseDelay: time.Millisecond, MaxDelay: 10 * time.Millisecond, RetryableErrors: custom}, nopLogger{}, nil) nodes := makeTestNodes("n1") calls := 0 err := r.do(context.Background(), nodes, func(_ context.Context, n *Node) error { calls++ if calls < 2 { return errors.New("custom-retry") } return nil }) if err != nil { t.Fatalf("unexpected error: %v", err) } if calls != 2 { t.Errorf("calls = %d, want 2", calls) } } func TestBackoff(t *testing.T) { r := newRetrier(RetryConfig{BaseDelay: 100 * time.Millisecond, MaxDelay: time.Second}, nopLogger{}, nil) d0 := r.backoff(0) if d0 < 50*time.Millisecond || d0 > 150*time.Millisecond { t.Errorf("backoff(0) = %v, expected ~100ms", d0) } d3 := r.backoff(3) if d3 < 600*time.Millisecond || d3 > 1100*time.Millisecond { t.Errorf("backoff(3) = %v, expected ~800ms", d3) } // Should cap at MaxDelay d10 := r.backoff(10) if d10 > 1250*time.Millisecond { t.Errorf("backoff(10) = %v, should be capped near 1s", d10) } }