package dbx import ( "errors" "testing" "github.com/jackc/pgx/v5/pgconn" ) func TestIsRetryable(t *testing.T) { tests := []struct { name string err error want bool }{ {"nil", nil, false}, {"generic", errors.New("oops"), false}, {"serialization_failure", &pgconn.PgError{Code: "40001"}, true}, {"deadlock", &pgconn.PgError{Code: "40P01"}, true}, {"too_many_connections", &pgconn.PgError{Code: "53300"}, true}, {"connection_exception", &pgconn.PgError{Code: "08006"}, true}, {"constraint_violation", &pgconn.PgError{Code: "23505"}, false}, {"syntax_error", &pgconn.PgError{Code: "42601"}, false}, {"connection_refused", errors.New("connection refused"), true}, {"EOF", errors.New("unexpected EOF"), true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := IsRetryable(tt.err); got != tt.want { t.Errorf("IsRetryable(%v) = %v, want %v", tt.err, got, tt.want) } }) } } func TestIsConnectionError(t *testing.T) { tests := []struct { name string err error want bool }{ {"nil", nil, false}, {"pg_class_08", &pgconn.PgError{Code: "08003"}, true}, {"pg_class_23", &pgconn.PgError{Code: "23505"}, false}, {"conn_refused", errors.New("connection refused"), true}, {"broken_pipe", errors.New("write: broken pipe"), true}, {"generic", errors.New("something else"), false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := IsConnectionError(tt.err); got != tt.want { t.Errorf("IsConnectionError(%v) = %v, want %v", tt.err, got, tt.want) } }) } } func TestIsConstraintViolation(t *testing.T) { if IsConstraintViolation(nil) { t.Error("expected false for nil") } if !IsConstraintViolation(&pgconn.PgError{Code: "23505"}) { t.Error("expected true for unique_violation") } if IsConstraintViolation(&pgconn.PgError{Code: "42601"}) { t.Error("expected false for syntax_error") } } func TestPgErrorCode(t *testing.T) { if code := PgErrorCode(errors.New("not pg")); code != "" { t.Errorf("expected empty, got %q", code) } if code := PgErrorCode(&pgconn.PgError{Code: "42P01"}); code != "42P01" { t.Errorf("expected 42P01, got %q", code) } } func TestRetryError(t *testing.T) { inner := errors.New("conn lost") err := newRetryError(3, inner) if !errors.Is(err, ErrRetryExhausted) { t.Error("should unwrap to ErrRetryExhausted") } if !errors.Is(err, inner) { t.Error("should unwrap to inner error") } if err.Error() == "" { t.Error("error string should not be empty") } }