169 lines
4.7 KiB
Go
169 lines
4.7 KiB
Go
package dbx
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
func TestRetrier_Success(t *testing.T) {
|
|
r := newRetrier(RetryConfig{MaxAttempts: 3, BaseDelay: time.Millisecond, MaxDelay: 10 * time.Millisecond}, nopLogger{}, nil)
|
|
nodes := makeTestNodes("n1")
|
|
|
|
calls := 0
|
|
err := r.do(context.Background(), nodes, func(_ context.Context, n *Node) error {
|
|
calls++
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if calls != 1 {
|
|
t.Errorf("calls = %d, want 1", calls)
|
|
}
|
|
}
|
|
|
|
func TestRetrier_RetriesOnRetryableError(t *testing.T) {
|
|
r := newRetrier(RetryConfig{MaxAttempts: 3, BaseDelay: time.Millisecond, MaxDelay: 10 * time.Millisecond}, nopLogger{}, nil)
|
|
nodes := makeTestNodes("n1")
|
|
|
|
calls := 0
|
|
retryableErr := errors.New("connection refused")
|
|
err := r.do(context.Background(), nodes, func(_ context.Context, n *Node) error {
|
|
calls++
|
|
if calls < 3 {
|
|
return retryableErr
|
|
}
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if calls != 3 {
|
|
t.Errorf("calls = %d, want 3", calls)
|
|
}
|
|
}
|
|
|
|
func TestRetrier_NonRetryableError(t *testing.T) {
|
|
r := newRetrier(RetryConfig{MaxAttempts: 3, BaseDelay: time.Millisecond, MaxDelay: 10 * time.Millisecond}, nopLogger{}, nil)
|
|
nodes := makeTestNodes("n1")
|
|
|
|
syntaxErr := errors.New("syntax problem")
|
|
calls := 0
|
|
err := r.do(context.Background(), nodes, func(_ context.Context, n *Node) error {
|
|
calls++
|
|
return syntaxErr
|
|
})
|
|
if !errors.Is(err, syntaxErr) {
|
|
t.Errorf("expected syntax error, got %v", err)
|
|
}
|
|
if calls != 1 {
|
|
t.Errorf("calls = %d, want 1 (should not retry non-retryable)", calls)
|
|
}
|
|
}
|
|
|
|
func TestRetrier_Exhausted(t *testing.T) {
|
|
r := newRetrier(RetryConfig{MaxAttempts: 2, BaseDelay: time.Millisecond, MaxDelay: 10 * time.Millisecond}, nopLogger{}, nil)
|
|
nodes := makeTestNodes("n1")
|
|
|
|
err := r.do(context.Background(), nodes, func(_ context.Context, n *Node) error {
|
|
return errors.New("connection refused")
|
|
})
|
|
if !errors.Is(err, ErrRetryExhausted) {
|
|
t.Errorf("expected ErrRetryExhausted, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestRetrier_FallbackToNextNode(t *testing.T) {
|
|
r := newRetrier(RetryConfig{MaxAttempts: 2, BaseDelay: time.Millisecond, MaxDelay: 10 * time.Millisecond}, nopLogger{}, nil)
|
|
nodes := makeTestNodes("replica-1", "master")
|
|
|
|
visited := []string{}
|
|
err := r.do(context.Background(), nodes, func(_ context.Context, n *Node) error {
|
|
visited = append(visited, n.name)
|
|
if n.name == "replica-1" {
|
|
return errors.New("connection refused")
|
|
}
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if len(visited) < 2 || visited[1] != "master" {
|
|
t.Errorf("expected fallback to master, visited: %v", visited)
|
|
}
|
|
}
|
|
|
|
func TestRetrier_ContextCanceled(t *testing.T) {
|
|
r := newRetrier(RetryConfig{MaxAttempts: 5, BaseDelay: time.Millisecond, MaxDelay: 10 * time.Millisecond}, nopLogger{}, nil)
|
|
nodes := makeTestNodes("n1")
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
cancel()
|
|
|
|
err := r.do(ctx, nodes, func(_ context.Context, n *Node) error {
|
|
return nil
|
|
})
|
|
if err != context.Canceled {
|
|
t.Errorf("expected context.Canceled, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestRetrier_NoHealthyNodes(t *testing.T) {
|
|
r := newRetrier(RetryConfig{MaxAttempts: 2, BaseDelay: time.Millisecond, MaxDelay: 10 * time.Millisecond}, nopLogger{}, nil)
|
|
nodes := makeTestNodes("n1")
|
|
nodes[0].healthy.Store(false)
|
|
|
|
err := r.do(context.Background(), nodes, func(_ context.Context, n *Node) error {
|
|
t.Fatal("should not be called")
|
|
return nil
|
|
})
|
|
if !errors.Is(err, ErrNoHealthyNode) {
|
|
t.Errorf("expected ErrNoHealthyNode, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestRetrier_CustomClassifier(t *testing.T) {
|
|
custom := func(err error) bool {
|
|
return err.Error() == "custom-retry"
|
|
}
|
|
r := newRetrier(RetryConfig{MaxAttempts: 3, BaseDelay: time.Millisecond, MaxDelay: 10 * time.Millisecond, RetryableErrors: custom}, nopLogger{}, nil)
|
|
nodes := makeTestNodes("n1")
|
|
|
|
calls := 0
|
|
err := r.do(context.Background(), nodes, func(_ context.Context, n *Node) error {
|
|
calls++
|
|
if calls < 2 {
|
|
return errors.New("custom-retry")
|
|
}
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if calls != 2 {
|
|
t.Errorf("calls = %d, want 2", calls)
|
|
}
|
|
}
|
|
|
|
func TestBackoff(t *testing.T) {
|
|
r := newRetrier(RetryConfig{BaseDelay: 100 * time.Millisecond, MaxDelay: time.Second}, nopLogger{}, nil)
|
|
|
|
d0 := r.backoff(0)
|
|
if d0 < 50*time.Millisecond || d0 > 150*time.Millisecond {
|
|
t.Errorf("backoff(0) = %v, expected ~100ms", d0)
|
|
}
|
|
|
|
d3 := r.backoff(3)
|
|
if d3 < 600*time.Millisecond || d3 > 1100*time.Millisecond {
|
|
t.Errorf("backoff(3) = %v, expected ~800ms", d3)
|
|
}
|
|
|
|
// Should cap at MaxDelay
|
|
d10 := r.backoff(10)
|
|
if d10 > 1250*time.Millisecond {
|
|
t.Errorf("backoff(10) = %v, should be capped near 1s", d10)
|
|
}
|
|
}
|