Files
dbx/retry_test.go

169 lines
4.7 KiB
Go

package dbx
import (
"context"
"errors"
"testing"
"time"
)
func TestRetrier_Success(t *testing.T) {
r := newRetrier(RetryConfig{MaxAttempts: 3, BaseDelay: time.Millisecond, MaxDelay: 10 * time.Millisecond}, nopLogger{}, nil)
nodes := makeTestNodes("n1")
calls := 0
err := r.do(context.Background(), nodes, func(_ context.Context, n *Node) error {
calls++
return nil
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if calls != 1 {
t.Errorf("calls = %d, want 1", calls)
}
}
func TestRetrier_RetriesOnRetryableError(t *testing.T) {
r := newRetrier(RetryConfig{MaxAttempts: 3, BaseDelay: time.Millisecond, MaxDelay: 10 * time.Millisecond}, nopLogger{}, nil)
nodes := makeTestNodes("n1")
calls := 0
retryableErr := errors.New("connection refused")
err := r.do(context.Background(), nodes, func(_ context.Context, n *Node) error {
calls++
if calls < 3 {
return retryableErr
}
return nil
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if calls != 3 {
t.Errorf("calls = %d, want 3", calls)
}
}
// TestRetrier_NonRetryableError checks that an error the default classifier
// does not consider retryable is returned unchanged (matchable with
// errors.Is) after a single attempt.
func TestRetrier_NonRetryableError(t *testing.T) {
	r := newRetrier(RetryConfig{
		MaxAttempts: 3,
		BaseDelay:   time.Millisecond,
		MaxDelay:    10 * time.Millisecond,
	}, nopLogger{}, nil)
	nodes := makeTestNodes("n1")

	permanent := errors.New("syntax problem")
	var attempts int
	failPermanently := func(_ context.Context, _ *Node) error {
		attempts++
		return permanent
	}

	err := r.do(context.Background(), nodes, failPermanently)
	if !errors.Is(err, permanent) {
		t.Errorf("expected syntax error, got %v", err)
	}
	if attempts != 1 {
		t.Errorf("calls = %d, want 1 (should not retry non-retryable)", attempts)
	}
}
// TestRetrier_Exhausted checks that when every attempt fails with a retryable
// error, do gives up and returns an error matching ErrRetryExhausted.
func TestRetrier_Exhausted(t *testing.T) {
	r := newRetrier(RetryConfig{
		MaxAttempts: 2,
		BaseDelay:   time.Millisecond,
		MaxDelay:    10 * time.Millisecond,
	}, nopLogger{}, nil)
	nodes := makeTestNodes("n1")

	// A fresh transient error on every call, so no attempt can succeed.
	alwaysFail := func(context.Context, *Node) error {
		return errors.New("connection refused")
	}

	if err := r.do(context.Background(), nodes, alwaysFail); !errors.Is(err, ErrRetryExhausted) {
		t.Errorf("expected ErrRetryExhausted, got %v", err)
	}
}
func TestRetrier_FallbackToNextNode(t *testing.T) {
r := newRetrier(RetryConfig{MaxAttempts: 2, BaseDelay: time.Millisecond, MaxDelay: 10 * time.Millisecond}, nopLogger{}, nil)
nodes := makeTestNodes("replica-1", "master")
visited := []string{}
err := r.do(context.Background(), nodes, func(_ context.Context, n *Node) error {
visited = append(visited, n.name)
if n.name == "replica-1" {
return errors.New("connection refused")
}
return nil
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(visited) < 2 || visited[1] != "master" {
t.Errorf("expected fallback to master, visited: %v", visited)
}
}
// TestRetrier_ContextCanceled checks that do reports context.Canceled when the
// caller's context is already canceled before do is invoked.
func TestRetrier_ContextCanceled(t *testing.T) {
	r := newRetrier(RetryConfig{MaxAttempts: 5, BaseDelay: time.Millisecond, MaxDelay: 10 * time.Millisecond}, nopLogger{}, nil)
	nodes := makeTestNodes("n1")

	ctx, cancel := context.WithCancel(context.Background())
	cancel() // cancel up front so do sees a dead context from the start

	err := r.do(ctx, nodes, func(_ context.Context, _ *Node) error {
		return nil
	})

	// Use errors.Is rather than ==: a retrier may legitimately wrap ctx.Err()
	// with extra detail (via %w), and wrapped cancellation must still count.
	if !errors.Is(err, context.Canceled) {
		t.Errorf("expected context.Canceled, got %v", err)
	}
}
// TestRetrier_NoHealthyNodes checks that when the only node is marked
// unhealthy, do never invokes the operation and returns an error matching
// ErrNoHealthyNode.
func TestRetrier_NoHealthyNodes(t *testing.T) {
	r := newRetrier(RetryConfig{
		MaxAttempts: 2,
		BaseDelay:   time.Millisecond,
		MaxDelay:    10 * time.Millisecond,
	}, nopLogger{}, nil)

	nodes := makeTestNodes("n1")
	nodes[0].healthy.Store(false) // leave nothing eligible to route to

	mustNotRun := func(context.Context, *Node) error {
		t.Fatal("should not be called")
		return nil
	}

	if err := r.do(context.Background(), nodes, mustNotRun); !errors.Is(err, ErrNoHealthyNode) {
		t.Errorf("expected ErrNoHealthyNode, got %v", err)
	}
}
// TestRetrier_CustomClassifier checks that RetryConfig.RetryableErrors is
// consulted: an error accepted by the custom classifier is retried, and the
// operation succeeds on its second attempt.
func TestRetrier_CustomClassifier(t *testing.T) {
	// Classify by error identity (errors.Is) rather than comparing
	// err.Error() strings: string matching breaks if the retrier wraps the
	// error before classifying it, and calling Error() on a nil error panics.
	errCustomRetry := errors.New("custom-retry")
	custom := func(err error) bool {
		return errors.Is(err, errCustomRetry)
	}
	r := newRetrier(RetryConfig{MaxAttempts: 3, BaseDelay: time.Millisecond, MaxDelay: 10 * time.Millisecond, RetryableErrors: custom}, nopLogger{}, nil)
	nodes := makeTestNodes("n1")

	calls := 0
	err := r.do(context.Background(), nodes, func(_ context.Context, _ *Node) error {
		calls++
		if calls < 2 {
			return errCustomRetry
		}
		return nil
	})
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if calls != 2 {
		t.Errorf("calls = %d, want 2", calls)
	}
}
// TestBackoff checks the backoff schedule: the delay for attempt k is roughly
// BaseDelay*2^k and is clamped to about MaxDelay. The tolerance windows leave
// room for jitter.
func TestBackoff(t *testing.T) {
	r := newRetrier(RetryConfig{BaseDelay: 100 * time.Millisecond, MaxDelay: time.Second}, nopLogger{}, nil)

	if d0 := r.backoff(0); d0 < 50*time.Millisecond || d0 > 150*time.Millisecond {
		t.Errorf("backoff(0) = %v, expected ~100ms", d0)
	}

	if d3 := r.backoff(3); d3 < 600*time.Millisecond || d3 > 1100*time.Millisecond {
		t.Errorf("backoff(3) = %v, expected ~800ms", d3)
	}

	// Should cap at MaxDelay. Uncapped, attempt 10 would be 100ms*2^10 (~102s).
	// Check both bounds: an upper bound alone would also accept a cap that
	// collapses to zero or goes negative. The lower bound uses the same
	// 50%-of-nominal tolerance as the backoff(0) check above.
	if d10 := r.backoff(10); d10 < 500*time.Millisecond || d10 > 1250*time.Millisecond {
		t.Errorf("backoff(10) = %v, should be capped near 1s", d10)
	}
}