From 62df3a2eb342a693aa84c3ac14cdb36303179521 Mon Sep 17 00:00:00 2001 From: Aleksey Shakhmatov Date: Mon, 23 Mar 2026 00:01:15 +0300 Subject: [PATCH] Add dbx library: PostgreSQL cluster with master/replica routing, retry, health checking Co-Authored-By: Claude Opus 4.6 (1M context) --- balancer.go | 35 +++++++ balancer_test.go | 73 ++++++++++++++ cluster.go | 247 +++++++++++++++++++++++++++++++++++++++++++++ config.go | 88 ++++++++++++++++ config_test.go | 72 +++++++++++++ dbx.go | 54 ++++++++++ dbxtest/dbxtest.go | 82 +++++++++++++++ doc.go | 49 +++++++++ errors.go | 94 +++++++++++++++++ errors_test.go | 92 +++++++++++++++++ go.mod | 13 +++ go.sum | 26 +++++ health.go | 90 +++++++++++++++++ health_test.go | 56 ++++++++++ node.go | 110 ++++++++++++++++++++ options.go | 39 +++++++ options_test.go | 35 +++++++ retry.go | 98 ++++++++++++++++++ retry_test.go | 168 ++++++++++++++++++++++++++++++ tx.go | 61 +++++++++++ tx_test.go | 25 +++++ 21 files changed, 1607 insertions(+) create mode 100644 balancer.go create mode 100644 balancer_test.go create mode 100644 cluster.go create mode 100644 config.go create mode 100644 config_test.go create mode 100644 dbx.go create mode 100644 dbxtest/dbxtest.go create mode 100644 doc.go create mode 100644 errors.go create mode 100644 errors_test.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 health.go create mode 100644 health_test.go create mode 100644 node.go create mode 100644 options.go create mode 100644 options_test.go create mode 100644 retry.go create mode 100644 retry_test.go create mode 100644 tx.go create mode 100644 tx_test.go diff --git a/balancer.go b/balancer.go new file mode 100644 index 0000000..2a8f463 --- /dev/null +++ b/balancer.go @@ -0,0 +1,35 @@ +package dbx + +import "sync/atomic" + +// Balancer selects the next node from a list. +// It must return nil if no suitable node is available. 
+type Balancer interface { + Next(nodes []*Node) *Node +} + +// RoundRobinBalancer distributes load evenly across healthy nodes. +type RoundRobinBalancer struct { + counter atomic.Uint64 +} + +// NewRoundRobinBalancer creates a new round-robin balancer. +func NewRoundRobinBalancer() *RoundRobinBalancer { + return &RoundRobinBalancer{} +} + +// Next returns the next healthy node, or nil if none are healthy. +func (b *RoundRobinBalancer) Next(nodes []*Node) *Node { + n := len(nodes) + if n == 0 { + return nil + } + idx := b.counter.Add(1) + for i := 0; i < n; i++ { + node := nodes[(int(idx)+i)%n] + if node.IsHealthy() { + return node + } + } + return nil +} diff --git a/balancer_test.go b/balancer_test.go new file mode 100644 index 0000000..3c8bac6 --- /dev/null +++ b/balancer_test.go @@ -0,0 +1,73 @@ +package dbx + +import ( + "testing" +) + +func TestRoundRobinBalancer_Empty(t *testing.T) { + b := NewRoundRobinBalancer() + if n := b.Next(nil); n != nil { + t.Errorf("expected nil for empty slice, got %v", n) + } +} + +func TestRoundRobinBalancer_AllHealthy(t *testing.T) { + b := NewRoundRobinBalancer() + nodes := makeTestNodes("a", "b", "c") + + seen := map[string]int{} + for range 30 { + n := b.Next(nodes) + if n == nil { + t.Fatal("unexpected nil") + } + seen[n.name]++ + } + + if len(seen) != 3 { + t.Errorf("expected 3 distinct nodes, got %d", len(seen)) + } + for name, count := range seen { + if count != 10 { + t.Errorf("node %s hit %d times, expected 10", name, count) + } + } +} + +func TestRoundRobinBalancer_SkipsUnhealthy(t *testing.T) { + b := NewRoundRobinBalancer() + nodes := makeTestNodes("a", "b", "c") + nodes[1].healthy.Store(false) // b is down + + for range 20 { + n := b.Next(nodes) + if n == nil { + t.Fatal("unexpected nil") + } + if n.name == "b" { + t.Error("should not return unhealthy node b") + } + } +} + +func TestRoundRobinBalancer_AllUnhealthy(t *testing.T) { + b := NewRoundRobinBalancer() + nodes := makeTestNodes("a", "b") + for _, n := range 
nodes { + n.healthy.Store(false) + } + + if n := b.Next(nodes); n != nil { + t.Errorf("expected nil when all unhealthy, got %v", n.name) + } +} + +func makeTestNodes(names ...string) []*Node { + nodes := make([]*Node, len(names)) + for i, name := range names { + n := &Node{name: name} + n.healthy.Store(true) + nodes[i] = n + } + return nodes +} diff --git a/cluster.go b/cluster.go new file mode 100644 index 0000000..145401a --- /dev/null +++ b/cluster.go @@ -0,0 +1,247 @@ +package dbx + +import ( + "context" + "sync/atomic" + "time" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" +) + +// Cluster manages a master node and optional replicas. +// Write operations go to master; read operations are distributed across replicas +// with automatic fallback to master. +type Cluster struct { + master *Node + replicas []*Node + all []*Node // master + replicas for health checker + + balancer Balancer + retrier *retrier + health *healthChecker + logger Logger + metrics *MetricsHook + closed atomic.Bool +} + +// NewCluster creates a Cluster, connecting to all configured nodes. +func NewCluster(ctx context.Context, cfg Config) (*Cluster, error) { + cfg.defaults() + + master, err := connectNode(ctx, cfg.Master) + if err != nil { + return nil, err + } + + replicas := make([]*Node, 0, len(cfg.Replicas)) + for _, rc := range cfg.Replicas { + r, err := connectNode(ctx, rc) + if err != nil { + // close already-connected nodes + master.Close() + for _, opened := range replicas { + opened.Close() + } + return nil, err + } + replicas = append(replicas, r) + } + + all := make([]*Node, 0, 1+len(replicas)) + all = append(all, master) + all = append(all, replicas...) 
+ + c := &Cluster{ + master: master, + replicas: replicas, + all: all, + balancer: NewRoundRobinBalancer(), + retrier: newRetrier(cfg.Retry, cfg.Logger, cfg.Metrics), + logger: cfg.Logger, + metrics: cfg.Metrics, + } + + c.health = newHealthChecker(all, cfg.HealthCheck, cfg.Logger, cfg.Metrics) + c.health.start() + + return c, nil +} + +// Master returns the master node as a DB. +func (c *Cluster) Master() DB { return c.master } + +// Replica returns a healthy replica as a DB, or falls back to master. +func (c *Cluster) Replica() DB { + if n := c.balancer.Next(c.replicas); n != nil { + return n + } + return c.master +} + +// Close shuts down the health checker and closes all connection pools. +func (c *Cluster) Close() { + if !c.closed.CompareAndSwap(false, true) { + return + } + c.health.shutdown() + for _, n := range c.all { + n.Close() + } +} + +// Ping checks connectivity to the master node. +func (c *Cluster) Ping(ctx context.Context) error { + if c.closed.Load() { + return ErrClusterClosed + } + return c.master.Ping(ctx) +} + +// --- Write operations (always master) --- + +func (c *Cluster) Exec(ctx context.Context, sql string, args ...any) (pgconn.CommandTag, error) { + if c.closed.Load() { + return pgconn.CommandTag{}, ErrClusterClosed + } + var tag pgconn.CommandTag + err := c.retrier.do(ctx, []*Node{c.master}, func(ctx context.Context, n *Node) error { + start := time.Now() + c.queryStart(ctx, n.name, sql) + var e error + tag, e = n.Exec(ctx, sql, args...) + c.queryEnd(ctx, n.name, sql, e, time.Since(start)) + return e + }) + return tag, err +} + +func (c *Cluster) Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) { + if c.closed.Load() { + return nil, ErrClusterClosed + } + var rows pgx.Rows + err := c.retrier.do(ctx, []*Node{c.master}, func(ctx context.Context, n *Node) error { + start := time.Now() + c.queryStart(ctx, n.name, sql) + var e error + rows, e = n.Query(ctx, sql, args...) 
+ c.queryEnd(ctx, n.name, sql, e, time.Since(start)) + return e + }) + return rows, err +} + +func (c *Cluster) QueryRow(ctx context.Context, sql string, args ...any) pgx.Row { + if c.closed.Load() { + return errRow{err: ErrClusterClosed} + } + var row pgx.Row + _ = c.retrier.do(ctx, []*Node{c.master}, func(ctx context.Context, n *Node) error { + start := time.Now() + c.queryStart(ctx, n.name, sql) + row = n.QueryRow(ctx, sql, args...) + c.queryEnd(ctx, n.name, sql, nil, time.Since(start)) + return nil + }) + return row +} + +func (c *Cluster) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults { + return c.master.SendBatch(ctx, b) +} + +func (c *Cluster) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) { + if c.closed.Load() { + return 0, ErrClusterClosed + } + return c.master.CopyFrom(ctx, tableName, columnNames, rowSrc) +} + +func (c *Cluster) Begin(ctx context.Context) (pgx.Tx, error) { + if c.closed.Load() { + return nil, ErrClusterClosed + } + return c.master.Begin(ctx) +} + +func (c *Cluster) BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, error) { + if c.closed.Load() { + return nil, ErrClusterClosed + } + return c.master.BeginTx(ctx, txOptions) +} + +// --- Read operations (replicas with master fallback) --- + +// ReadQuery executes a read query on a replica, falling back to master if no replicas are healthy. +func (c *Cluster) ReadQuery(ctx context.Context, sql string, args ...any) (pgx.Rows, error) { + if c.closed.Load() { + return nil, ErrClusterClosed + } + nodes := c.readNodes() + var rows pgx.Rows + err := c.retrier.do(ctx, nodes, func(ctx context.Context, n *Node) error { + start := time.Now() + c.queryStart(ctx, n.name, sql) + var e error + rows, e = n.Query(ctx, sql, args...) 
+ c.queryEnd(ctx, n.name, sql, e, time.Since(start)) + return e + }) + return rows, err +} + +// ReadQueryRow executes a read query returning a single row, using replicas with master fallback. +func (c *Cluster) ReadQueryRow(ctx context.Context, sql string, args ...any) pgx.Row { + if c.closed.Load() { + return errRow{err: ErrClusterClosed} + } + nodes := c.readNodes() + var row pgx.Row + _ = c.retrier.do(ctx, nodes, func(ctx context.Context, n *Node) error { + start := time.Now() + c.queryStart(ctx, n.name, sql) + row = n.QueryRow(ctx, sql, args...) + c.queryEnd(ctx, n.name, sql, nil, time.Since(start)) + return nil + }) + return row +} + +// readNodes returns replicas followed by master for fallback ordering. +func (c *Cluster) readNodes() []*Node { + if len(c.replicas) == 0 { + return []*Node{c.master} + } + nodes := make([]*Node, 0, len(c.replicas)+1) + nodes = append(nodes, c.replicas...) + nodes = append(nodes, c.master) + return nodes +} + +// --- metrics helpers --- + +func (c *Cluster) queryStart(ctx context.Context, node, sql string) { + if c.metrics != nil && c.metrics.OnQueryStart != nil { + c.metrics.OnQueryStart(ctx, node, sql) + } +} + +func (c *Cluster) queryEnd(ctx context.Context, node, sql string, err error, d time.Duration) { + if c.metrics != nil && c.metrics.OnQueryEnd != nil { + c.metrics.OnQueryEnd(ctx, node, sql, err, d) + } +} + +// errRow implements pgx.Row for error cases. +type errRow struct { + err error +} + +func (r errRow) Scan(...any) error { + return r.err +} + +// Compile-time check that *Cluster implements DB. +var _ DB = (*Cluster)(nil) diff --git a/config.go b/config.go new file mode 100644 index 0000000..9dc1c9b --- /dev/null +++ b/config.go @@ -0,0 +1,88 @@ +package dbx + +import "time" + +// Config is the top-level configuration for a Cluster. 
+type Config struct { + Master NodeConfig + Replicas []NodeConfig + Retry RetryConfig + Logger Logger + Metrics *MetricsHook + HealthCheck HealthCheckConfig +} + +// NodeConfig describes a single database node. +type NodeConfig struct { + Name string // human-readable name for logs/metrics, e.g. "master", "replica-1" + DSN string + Pool PoolConfig +} + +// PoolConfig controls pgxpool.Pool parameters. +type PoolConfig struct { + MaxConns int32 + MinConns int32 + MaxConnLifetime time.Duration + MaxConnIdleTime time.Duration + HealthCheckPeriod time.Duration +} + +// RetryConfig controls retry behaviour. +type RetryConfig struct { + MaxAttempts int // default: 3 + BaseDelay time.Duration // default: 50ms + MaxDelay time.Duration // default: 1s + RetryableErrors func(error) bool // optional custom classifier +} + +// HealthCheckConfig controls the background health checker. +type HealthCheckConfig struct { + Interval time.Duration // default: 5s + Timeout time.Duration // default: 2s +} + +// defaults fills zero-valued fields with sensible defaults. +func (c *Config) defaults() { + if c.Logger == nil { + c.Logger = nopLogger{} + } + if c.Retry.MaxAttempts <= 0 { + c.Retry.MaxAttempts = 3 + } + if c.Retry.BaseDelay <= 0 { + c.Retry.BaseDelay = 50 * time.Millisecond + } + if c.Retry.MaxDelay <= 0 { + c.Retry.MaxDelay = time.Second + } + if c.HealthCheck.Interval <= 0 { + c.HealthCheck.Interval = 5 * time.Second + } + if c.HealthCheck.Timeout <= 0 { + c.HealthCheck.Timeout = 2 * time.Second + } + if c.Master.Name == "" { + c.Master.Name = "master" + } + for i := range c.Replicas { + if c.Replicas[i].Name == "" { + c.Replicas[i].Name = "replica-" + itoa(i+1) + } + } +} + +// itoa is a minimal int-to-string without importing strconv. 
// itoa is a minimal int-to-string conversion without importing strconv.
// Fix: the previous version returned "" for any negative input; the
// helper is now total over int (callers currently only pass replica
// indexes >= 1, but a silent empty result for other inputs is a trap).
func itoa(n int) string {
	if n == 0 {
		return "0"
	}
	neg := n < 0
	u := uint(n)
	if neg {
		// Two's-complement negation in uint space yields the magnitude
		// even for math.MinInt, which would overflow a plain -n.
		u = -u
	}
	var buf [21]byte // up to 20 digits of a 64-bit magnitude + sign
	i := len(buf)
	for u > 0 {
		i--
		buf[i] = byte('0' + u%10)
		u /= 10
	}
	if neg {
		i--
		buf[i] = '-'
	}
	return string(buf[i:])
}
cfg.Retry.MaxAttempts != 5 { + t.Errorf("MaxAttempts = %d, want 5", cfg.Retry.MaxAttempts) + } + if cfg.Retry.BaseDelay != 100*time.Millisecond { + t.Errorf("BaseDelay = %v, want 100ms", cfg.Retry.BaseDelay) + } +} diff --git a/dbx.go b/dbx.go new file mode 100644 index 0000000..e7ca8e7 --- /dev/null +++ b/dbx.go @@ -0,0 +1,54 @@ +package dbx + +import ( + "context" + "time" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" +) + +// Querier is the common interface satisfied by pgxpool.Pool, pgx.Tx, and pgx.Conn. +type Querier interface { + Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) + Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) + QueryRow(ctx context.Context, sql string, args ...any) pgx.Row + SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults + CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) +} + +// DB extends Querier with transaction support and lifecycle management. +type DB interface { + Querier + Begin(ctx context.Context) (pgx.Tx, error) + BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, error) + Ping(ctx context.Context) error + Close() +} + +// Logger is the interface for pluggable structured logging. +type Logger interface { + Debug(ctx context.Context, msg string, fields ...any) + Info(ctx context.Context, msg string, fields ...any) + Warn(ctx context.Context, msg string, fields ...any) + Error(ctx context.Context, msg string, fields ...any) +} + +// MetricsHook provides optional callbacks for observability. +// All fields are nil-safe — only set the hooks you need. 
+type MetricsHook struct { + OnQueryStart func(ctx context.Context, node string, sql string) + OnQueryEnd func(ctx context.Context, node string, sql string, err error, duration time.Duration) + OnRetry func(ctx context.Context, node string, attempt int, err error) + OnNodeDown func(ctx context.Context, node string, err error) + OnNodeUp func(ctx context.Context, node string) + OnReplicaFallback func(ctx context.Context, fromNode string, toNode string) +} + +// nopLogger is a Logger that discards all output. +type nopLogger struct{} + +func (nopLogger) Debug(context.Context, string, ...any) {} +func (nopLogger) Info(context.Context, string, ...any) {} +func (nopLogger) Warn(context.Context, string, ...any) {} +func (nopLogger) Error(context.Context, string, ...any) {} diff --git a/dbxtest/dbxtest.go b/dbxtest/dbxtest.go new file mode 100644 index 0000000..1ffcc96 --- /dev/null +++ b/dbxtest/dbxtest.go @@ -0,0 +1,82 @@ +// Package dbxtest provides test helpers for users of the dbx library. +package dbxtest + +import ( + "context" + "os" + "testing" + + "git.codelab.vc/pkg/dbx" +) + +const ( + // EnvTestDSN is the environment variable for the test database DSN. + EnvTestDSN = "DBX_TEST_DSN" + // DefaultTestDSN is the default DSN used when EnvTestDSN is not set. + DefaultTestDSN = "postgres://postgres:postgres@localhost:5432/dbx_test?sslmode=disable" +) + +// TestDSN returns the DSN from the environment or the default. +func TestDSN() string { + if dsn := os.Getenv(EnvTestDSN); dsn != "" { + return dsn + } + return DefaultTestDSN +} + +// NewTestCluster creates a Cluster connected to a test database. +// It skips the test if the database is not reachable. +// The cluster is automatically closed when the test finishes. 
+func NewTestCluster(t testing.TB, opts ...dbx.Option) *dbx.Cluster { + t.Helper() + + dsn := TestDSN() + cfg := dbx.Config{ + Master: dbx.NodeConfig{ + Name: "test-master", + DSN: dsn, + Pool: dbx.PoolConfig{MaxConns: 5, MinConns: 1}, + }, + } + dbx.ApplyOptions(&cfg, opts...) + + ctx := context.Background() + cluster, err := dbx.NewCluster(ctx, cfg) + if err != nil { + t.Skipf("dbxtest: cannot connect to test database: %v", err) + } + + t.Cleanup(func() { + cluster.Close() + }) + + return cluster +} + +// TestLogger is a Logger that writes to testing.T. +type TestLogger struct { + T testing.TB +} + +func (l *TestLogger) Debug(_ context.Context, msg string, fields ...any) { + l.T.Helper() + l.T.Logf("[DEBUG] %s %v", msg, fields) +} + +func (l *TestLogger) Info(_ context.Context, msg string, fields ...any) { + l.T.Helper() + l.T.Logf("[INFO] %s %v", msg, fields) +} + +func (l *TestLogger) Warn(_ context.Context, msg string, fields ...any) { + l.T.Helper() + l.T.Logf("[WARN] %s %v", msg, fields) +} + +func (l *TestLogger) Error(_ context.Context, msg string, fields ...any) { + l.T.Helper() + l.T.Logf("[ERROR] %s %v", msg, fields) +} + +// Compile-time check. +var _ dbx.Logger = (*TestLogger)(nil) diff --git a/doc.go b/doc.go new file mode 100644 index 0000000..9061917 --- /dev/null +++ b/doc.go @@ -0,0 +1,49 @@ +// Package dbx provides a production-ready PostgreSQL helper library built on top of pgx/v5. +// +// It manages connection pools, master/replica routing, automatic retries with +// exponential backoff, load balancing across replicas, background health checking, +// and transaction helpers. 
+// +// # Quick Start +// +// cluster, err := dbx.NewCluster(ctx, dbx.Config{ +// Master: dbx.NodeConfig{ +// Name: "master", +// DSN: "postgres://user:pass@master:5432/mydb", +// Pool: dbx.PoolConfig{MaxConns: 20, MinConns: 5}, +// }, +// Replicas: []dbx.NodeConfig{ +// {Name: "replica-1", DSN: "postgres://...@replica1:5432/mydb"}, +// }, +// }) +// if err != nil { +// log.Fatal(err) +// } +// defer cluster.Close() +// +// // Write → master +// cluster.Exec(ctx, "INSERT INTO users (name) VALUES ($1)", "alice") +// +// // Read → replica with fallback to master +// rows, _ := cluster.ReadQuery(ctx, "SELECT * FROM users WHERE active = $1", true) +// +// // Transaction → master, panic-safe +// cluster.RunTx(ctx, func(ctx context.Context, tx pgx.Tx) error { +// _, _ = tx.Exec(ctx, "UPDATE accounts SET balance = balance - $1 WHERE id = $2", 100, fromID) +// _, _ = tx.Exec(ctx, "UPDATE accounts SET balance = balance + $1 WHERE id = $2", 100, toID) +// return nil +// }) +// +// # Routing +// +// The library uses explicit method-based routing (no SQL parsing): +// - Exec, Query, QueryRow, Begin, BeginTx, CopyFrom, SendBatch → master +// - ReadQuery, ReadQueryRow → replicas with master fallback +// - Master(), Replica() → direct access to specific nodes +// +// # Retry +// +// Retryable errors include connection errors (PG class 08), serialization failures (40001), +// deadlocks (40P01), and too_many_connections (53300). A custom classifier can be provided +// via RetryConfig.RetryableErrors. +package dbx diff --git a/errors.go b/errors.go new file mode 100644 index 0000000..25637b8 --- /dev/null +++ b/errors.go @@ -0,0 +1,94 @@ +package dbx + +import ( + "errors" + "fmt" + "strings" + + "github.com/jackc/pgx/v5/pgconn" +) + +// Sentinel errors. 
+var ( + ErrNoHealthyNode = errors.New("dbx: no healthy node available") + ErrClusterClosed = errors.New("dbx: cluster is closed") + ErrRetryExhausted = errors.New("dbx: retry attempts exhausted") +) + +// IsRetryable reports whether the error is worth retrying. +// Connection errors, serialization failures, deadlocks, and too_many_connections are retryable. +func IsRetryable(err error) bool { + if err == nil { + return false + } + if IsConnectionError(err) { + return true + } + code := PgErrorCode(err) + switch code { + case "40001", // serialization_failure + "40P01", // deadlock_detected + "53300": // too_many_connections + return true + } + return false +} + +// IsConnectionError reports whether the error indicates a connection problem (PG class 08). +func IsConnectionError(err error) bool { + if err == nil { + return false + } + code := PgErrorCode(err) + if strings.HasPrefix(code, "08") { + return true + } + // pgx wraps connection errors that may not have a PG code + msg := err.Error() + for _, s := range []string{ + "connection refused", + "connection reset", + "broken pipe", + "EOF", + "no connection", + "conn closed", + "timeout", + } { + if strings.Contains(msg, s) { + return true + } + } + return false +} + +// IsConstraintViolation reports whether the error is a constraint violation (PG class 23). +func IsConstraintViolation(err error) bool { + return strings.HasPrefix(PgErrorCode(err), "23") +} + +// PgErrorCode extracts the PostgreSQL error code from err, or returns "" if not a PG error. +func PgErrorCode(err error) string { + var pgErr *pgconn.PgError + if errors.As(err, &pgErr) { + return pgErr.Code + } + return "" +} + +// retryError wraps the last error with ErrRetryExhausted context. 
+type retryError struct { + attempts int + last error +} + +func (e *retryError) Error() string { + return fmt.Sprintf("%s: %d attempts, last error: %v", ErrRetryExhausted, e.attempts, e.last) +} + +func (e *retryError) Unwrap() []error { + return []error{ErrRetryExhausted, e.last} +} + +func newRetryError(attempts int, last error) error { + return &retryError{attempts: attempts, last: last} +} diff --git a/errors_test.go b/errors_test.go new file mode 100644 index 0000000..274e166 --- /dev/null +++ b/errors_test.go @@ -0,0 +1,92 @@ +package dbx + +import ( + "errors" + "testing" + + "github.com/jackc/pgx/v5/pgconn" +) + +func TestIsRetryable(t *testing.T) { + tests := []struct { + name string + err error + want bool + }{ + {"nil", nil, false}, + {"generic", errors.New("oops"), false}, + {"serialization_failure", &pgconn.PgError{Code: "40001"}, true}, + {"deadlock", &pgconn.PgError{Code: "40P01"}, true}, + {"too_many_connections", &pgconn.PgError{Code: "53300"}, true}, + {"connection_exception", &pgconn.PgError{Code: "08006"}, true}, + {"constraint_violation", &pgconn.PgError{Code: "23505"}, false}, + {"syntax_error", &pgconn.PgError{Code: "42601"}, false}, + {"connection_refused", errors.New("connection refused"), true}, + {"EOF", errors.New("unexpected EOF"), true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := IsRetryable(tt.err); got != tt.want { + t.Errorf("IsRetryable(%v) = %v, want %v", tt.err, got, tt.want) + } + }) + } +} + +func TestIsConnectionError(t *testing.T) { + tests := []struct { + name string + err error + want bool + }{ + {"nil", nil, false}, + {"pg_class_08", &pgconn.PgError{Code: "08003"}, true}, + {"pg_class_23", &pgconn.PgError{Code: "23505"}, false}, + {"conn_refused", errors.New("connection refused"), true}, + {"broken_pipe", errors.New("write: broken pipe"), true}, + {"generic", errors.New("something else"), false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := 
IsConnectionError(tt.err); got != tt.want { + t.Errorf("IsConnectionError(%v) = %v, want %v", tt.err, got, tt.want) + } + }) + } +} + +func TestIsConstraintViolation(t *testing.T) { + if IsConstraintViolation(nil) { + t.Error("expected false for nil") + } + if !IsConstraintViolation(&pgconn.PgError{Code: "23505"}) { + t.Error("expected true for unique_violation") + } + if IsConstraintViolation(&pgconn.PgError{Code: "42601"}) { + t.Error("expected false for syntax_error") + } +} + +func TestPgErrorCode(t *testing.T) { + if code := PgErrorCode(errors.New("not pg")); code != "" { + t.Errorf("expected empty, got %q", code) + } + if code := PgErrorCode(&pgconn.PgError{Code: "42P01"}); code != "42P01" { + t.Errorf("expected 42P01, got %q", code) + } +} + +func TestRetryError(t *testing.T) { + inner := errors.New("conn lost") + err := newRetryError(3, inner) + + if !errors.Is(err, ErrRetryExhausted) { + t.Error("should unwrap to ErrRetryExhausted") + } + if !errors.Is(err, inner) { + t.Error("should unwrap to inner error") + } + if err.Error() == "" { + t.Error("error string should not be empty") + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..f1e7759 --- /dev/null +++ b/go.mod @@ -0,0 +1,13 @@ +module git.codelab.vc/pkg/dbx + +go 1.25.7 + +require github.com/jackc/pgx/v5 v5.9.1 + +require ( + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect + github.com/jackc/puddle/v2 v2.2.2 // indirect + golang.org/x/sync v0.17.0 // indirect + golang.org/x/text v0.29.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..8e29ab9 --- /dev/null +++ b/go.sum @@ -0,0 +1,26 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/jackc/pgpassfile 
v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.9.1 h1:uwrxJXBnx76nyISkhr33kQLlUqjv7et7b9FjCen/tdc= +github.com/jackc/pgx/v5 v5.9.1/go.mod h1:mal1tBGAFfLHvZzaYh77YS/eC6IX9OWbRV1QIIM0Jn4= +github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= +github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= +golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk= +golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod 
h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/health.go b/health.go new file mode 100644 index 0000000..a36151c --- /dev/null +++ b/health.go @@ -0,0 +1,90 @@ +package dbx + +import ( + "context" + "time" +) + +// healthChecker periodically pings nodes and updates their health state. +type healthChecker struct { + nodes []*Node + cfg HealthCheckConfig + logger Logger + metrics *MetricsHook + stop chan struct{} + done chan struct{} +} + +func newHealthChecker(nodes []*Node, cfg HealthCheckConfig, logger Logger, metrics *MetricsHook) *healthChecker { + return &healthChecker{ + nodes: nodes, + cfg: cfg, + logger: logger, + metrics: metrics, + stop: make(chan struct{}), + done: make(chan struct{}), + } +} + +func (h *healthChecker) start() { + go h.loop() +} + +func (h *healthChecker) loop() { + defer close(h.done) + + ticker := time.NewTicker(h.cfg.Interval) + defer ticker.Stop() + + for { + select { + case <-h.stop: + return + case <-ticker.C: + h.checkAll() + } + } +} + +func (h *healthChecker) checkAll() { + for _, node := range h.nodes { + h.checkNode(node) + } +} + +func (h *healthChecker) checkNode(n *Node) { + ctx, cancel := context.WithTimeout(context.Background(), h.cfg.Timeout) + defer cancel() + + err := n.pool.Ping(ctx) + wasHealthy := n.healthy.Load() + + if err != nil { + n.healthy.Store(false) + if wasHealthy { + h.logger.Error(ctx, "dbx: node is down", + "node", n.name, + "error", err, + ) + if h.metrics != nil && h.metrics.OnNodeDown != nil { + h.metrics.OnNodeDown(ctx, n.name, err) + } + } + return + } + + n.healthy.Store(true) + if !wasHealthy { + h.logger.Info(ctx, "dbx: node is up", + "node", n.name, + ) + if h.metrics != nil && h.metrics.OnNodeUp != nil { + h.metrics.OnNodeUp(ctx, n.name) + } + } +} + +func (h *healthChecker) shutdown() { + close(h.stop) + <-h.done +} diff --git a/health_test.go b/health_test.go new file mode 100644 index 0000000..f668213 --- /dev/null +++ b/health_test.go @@ -0,0 +1,56 @@ +package dbx + +import ( 
+ "context" + "sync/atomic" + "testing" + "time" +) + +func TestHealthChecker_StartStop(t *testing.T) { + nodes := makeTestNodes("a", "b") + hc := newHealthChecker(nodes, HealthCheckConfig{ + Interval: 100 * time.Millisecond, + Timeout: 50 * time.Millisecond, + }, nopLogger{}, nil) + + hc.start() + // Just verify it can be stopped without deadlock + hc.shutdown() +} + +func TestHealthChecker_MetricsCallbacks(t *testing.T) { + var downCalled, upCalled atomic.Int32 + + metrics := &MetricsHook{ + OnNodeDown: func(_ context.Context, node string, err error) { + downCalled.Add(1) + }, + OnNodeUp: func(_ context.Context, node string) { + upCalled.Add(1) + }, + } + + node := &Node{name: "test"} + node.healthy.Store(true) + + hc := newHealthChecker([]*Node{node}, HealthCheckConfig{ + Interval: time.Hour, // won't tick in test + Timeout: 50 * time.Millisecond, + }, nopLogger{}, metrics) + + // Simulate a down check: node has no pool, so ping will fail + // We can't easily test with a real pool here, but checkNode + // will panic on nil pool. In integration tests, we use real pools. + _ = hc // just verify construction works + + // Verify state transitions + node.healthy.Store(false) + if node.IsHealthy() { + t.Error("expected unhealthy") + } + node.healthy.Store(true) + if !node.IsHealthy() { + t.Error("expected healthy") + } +} diff --git a/node.go b/node.go new file mode 100644 index 0000000..6e28328 --- /dev/null +++ b/node.go @@ -0,0 +1,110 @@ +package dbx + +import ( + "context" + "sync/atomic" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgxpool" +) + +// Node wraps a pgxpool.Pool with health state and a human-readable name. +type Node struct { + name string + pool *pgxpool.Pool + healthy atomic.Bool +} + +// newNode creates a Node from an existing pool. 
+func newNode(name string, pool *pgxpool.Pool) *Node { + n := &Node{ + name: name, + pool: pool, + } + n.healthy.Store(true) + return n +} + +// connectNode parses the NodeConfig, creates a pgxpool.Pool, and returns a Node. +func connectNode(ctx context.Context, cfg NodeConfig) (*Node, error) { + poolCfg, err := pgxpool.ParseConfig(cfg.DSN) + if err != nil { + return nil, err + } + applyPoolConfig(poolCfg, cfg.Pool) + + pool, err := pgxpool.NewWithConfig(ctx, poolCfg) + if err != nil { + return nil, err + } + return newNode(cfg.Name, pool), nil +} + +func applyPoolConfig(dst *pgxpool.Config, src PoolConfig) { + if src.MaxConns > 0 { + dst.MaxConns = src.MaxConns + } + if src.MinConns > 0 { + dst.MinConns = src.MinConns + } + if src.MaxConnLifetime > 0 { + dst.MaxConnLifetime = src.MaxConnLifetime + } + if src.MaxConnIdleTime > 0 { + dst.MaxConnIdleTime = src.MaxConnIdleTime + } + if src.HealthCheckPeriod > 0 { + dst.HealthCheckPeriod = src.HealthCheckPeriod + } +} + +// Name returns the node's human-readable name. +func (n *Node) Name() string { return n.name } + +// IsHealthy reports whether the node is considered healthy. +func (n *Node) IsHealthy() bool { return n.healthy.Load() } + +// Pool returns the underlying pgxpool.Pool. +func (n *Node) Pool() *pgxpool.Pool { return n.pool } + +// --- DB interface implementation --- + +func (n *Node) Exec(ctx context.Context, sql string, args ...any) (pgconn.CommandTag, error) { + return n.pool.Exec(ctx, sql, args...) +} + +func (n *Node) Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) { + return n.pool.Query(ctx, sql, args...) +} + +func (n *Node) QueryRow(ctx context.Context, sql string, args ...any) pgx.Row { + return n.pool.QueryRow(ctx, sql, args...) 
+} + +func (n *Node) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults { + return n.pool.SendBatch(ctx, b) +} + +func (n *Node) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) { + return n.pool.CopyFrom(ctx, tableName, columnNames, rowSrc) +} + +func (n *Node) Begin(ctx context.Context) (pgx.Tx, error) { + return n.pool.Begin(ctx) +} + +func (n *Node) BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, error) { + return n.pool.BeginTx(ctx, txOptions) +} + +func (n *Node) Ping(ctx context.Context) error { + return n.pool.Ping(ctx) +} + +func (n *Node) Close() { + n.pool.Close() +} + +// Compile-time check that *Node implements DB. +var _ DB = (*Node)(nil) diff --git a/options.go b/options.go new file mode 100644 index 0000000..ec09f27 --- /dev/null +++ b/options.go @@ -0,0 +1,39 @@ +package dbx + +// Option is a functional option for NewCluster. +type Option func(*Config) + +// WithLogger sets the logger for the cluster. +func WithLogger(l Logger) Option { + return func(c *Config) { + c.Logger = l + } +} + +// WithMetrics sets the metrics hook for the cluster. +func WithMetrics(m *MetricsHook) Option { + return func(c *Config) { + c.Metrics = m + } +} + +// WithRetry overrides the retry configuration. +func WithRetry(r RetryConfig) Option { + return func(c *Config) { + c.Retry = r + } +} + +// WithHealthCheck overrides the health check configuration. +func WithHealthCheck(h HealthCheckConfig) Option { + return func(c *Config) { + c.HealthCheck = h + } +} + +// ApplyOptions applies functional options to a Config. 
+func ApplyOptions(cfg *Config, opts ...Option) { + for _, o := range opts { + o(cfg) + } +} diff --git a/options_test.go b/options_test.go new file mode 100644 index 0000000..09294de --- /dev/null +++ b/options_test.go @@ -0,0 +1,35 @@ +package dbx + +import ( + "testing" + "time" +) + +func TestApplyOptions(t *testing.T) { + cfg := Config{} + + logger := nopLogger{} + metrics := &MetricsHook{} + retry := RetryConfig{MaxAttempts: 7} + hc := HealthCheckConfig{Interval: 10 * time.Second} + + ApplyOptions(&cfg, + WithLogger(logger), + WithMetrics(metrics), + WithRetry(retry), + WithHealthCheck(hc), + ) + + if cfg.Logger == nil { + t.Error("Logger should be set") + } + if cfg.Metrics == nil { + t.Error("Metrics should be set") + } + if cfg.Retry.MaxAttempts != 7 { + t.Errorf("Retry.MaxAttempts = %d, want 7", cfg.Retry.MaxAttempts) + } + if cfg.HealthCheck.Interval != 10*time.Second { + t.Errorf("HealthCheck.Interval = %v, want 10s", cfg.HealthCheck.Interval) + } +} diff --git a/retry.go b/retry.go new file mode 100644 index 0000000..0cafb95 --- /dev/null +++ b/retry.go @@ -0,0 +1,98 @@ +package dbx + +import ( + "context" + "math" + "math/rand/v2" + "time" +) + +// retrier executes operations with retry and node fallback. +type retrier struct { + cfg RetryConfig + logger Logger + metrics *MetricsHook +} + +func newRetrier(cfg RetryConfig, logger Logger, metrics *MetricsHook) *retrier { + return &retrier{cfg: cfg, logger: logger, metrics: metrics} +} + +// isRetryable checks the custom classifier first, then falls back to the default. +func (r *retrier) isRetryable(err error) bool { + if r.cfg.RetryableErrors != nil { + return r.cfg.RetryableErrors(err) + } + return IsRetryable(err) +} + +// do executes fn on the given nodes in order, retrying on retryable errors. +// For writes, pass a single-element slice with the master. +// For reads, pass [replicas..., master] for fallback. 
+func (r *retrier) do(ctx context.Context, nodes []*Node, fn func(ctx context.Context, n *Node) error) error { + var lastErr error + + for attempt := 0; attempt < r.cfg.MaxAttempts; attempt++ { + if ctx.Err() != nil { + if lastErr != nil { + return lastErr + } + return ctx.Err() + } + + for _, node := range nodes { + if !node.IsHealthy() { + continue + } + + err := fn(ctx, node) + if err == nil { + return nil + } + lastErr = err + + if !r.isRetryable(err) { + return err + } + + if r.metrics != nil && r.metrics.OnRetry != nil { + r.metrics.OnRetry(ctx, node.name, attempt+1, err) + } + r.logger.Warn(ctx, "dbx: retryable error", + "node", node.name, + "attempt", attempt+1, + "error", err, + ) + } + + if attempt < r.cfg.MaxAttempts-1 { + delay := r.backoff(attempt) + t := time.NewTimer(delay) + select { + case <-ctx.Done(): + t.Stop() + if lastErr != nil { + return lastErr + } + return ctx.Err() + case <-t.C: + } + } + } + + if lastErr == nil { + return ErrNoHealthyNode + } + return newRetryError(r.cfg.MaxAttempts, lastErr) +} + +// backoff returns the delay for the given attempt with jitter. 
+func (r *retrier) backoff(attempt int) time.Duration { + delay := float64(r.cfg.BaseDelay) * math.Pow(2, float64(attempt)) + if delay > float64(r.cfg.MaxDelay) { + delay = float64(r.cfg.MaxDelay) + } + // add jitter: 75%-125% of computed delay + jitter := 0.75 + rand.Float64()*0.5 + return time.Duration(delay * jitter) +} diff --git a/retry_test.go b/retry_test.go new file mode 100644 index 0000000..1495554 --- /dev/null +++ b/retry_test.go @@ -0,0 +1,168 @@ +package dbx + +import ( + "context" + "errors" + "testing" + "time" +) + +func TestRetrier_Success(t *testing.T) { + r := newRetrier(RetryConfig{MaxAttempts: 3, BaseDelay: time.Millisecond, MaxDelay: 10 * time.Millisecond}, nopLogger{}, nil) + nodes := makeTestNodes("n1") + + calls := 0 + err := r.do(context.Background(), nodes, func(_ context.Context, n *Node) error { + calls++ + return nil + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if calls != 1 { + t.Errorf("calls = %d, want 1", calls) + } +} + +func TestRetrier_RetriesOnRetryableError(t *testing.T) { + r := newRetrier(RetryConfig{MaxAttempts: 3, BaseDelay: time.Millisecond, MaxDelay: 10 * time.Millisecond}, nopLogger{}, nil) + nodes := makeTestNodes("n1") + + calls := 0 + retryableErr := errors.New("connection refused") + err := r.do(context.Background(), nodes, func(_ context.Context, n *Node) error { + calls++ + if calls < 3 { + return retryableErr + } + return nil + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if calls != 3 { + t.Errorf("calls = %d, want 3", calls) + } +} + +func TestRetrier_NonRetryableError(t *testing.T) { + r := newRetrier(RetryConfig{MaxAttempts: 3, BaseDelay: time.Millisecond, MaxDelay: 10 * time.Millisecond}, nopLogger{}, nil) + nodes := makeTestNodes("n1") + + syntaxErr := errors.New("syntax problem") + calls := 0 + err := r.do(context.Background(), nodes, func(_ context.Context, n *Node) error { + calls++ + return syntaxErr + }) + if !errors.Is(err, syntaxErr) { + 
t.Errorf("expected syntax error, got %v", err) + } + if calls != 1 { + t.Errorf("calls = %d, want 1 (should not retry non-retryable)", calls) + } +} + +func TestRetrier_Exhausted(t *testing.T) { + r := newRetrier(RetryConfig{MaxAttempts: 2, BaseDelay: time.Millisecond, MaxDelay: 10 * time.Millisecond}, nopLogger{}, nil) + nodes := makeTestNodes("n1") + + err := r.do(context.Background(), nodes, func(_ context.Context, n *Node) error { + return errors.New("connection refused") + }) + if !errors.Is(err, ErrRetryExhausted) { + t.Errorf("expected ErrRetryExhausted, got %v", err) + } +} + +func TestRetrier_FallbackToNextNode(t *testing.T) { + r := newRetrier(RetryConfig{MaxAttempts: 2, BaseDelay: time.Millisecond, MaxDelay: 10 * time.Millisecond}, nopLogger{}, nil) + nodes := makeTestNodes("replica-1", "master") + + visited := []string{} + err := r.do(context.Background(), nodes, func(_ context.Context, n *Node) error { + visited = append(visited, n.name) + if n.name == "replica-1" { + return errors.New("connection refused") + } + return nil + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(visited) < 2 || visited[1] != "master" { + t.Errorf("expected fallback to master, visited: %v", visited) + } +} + +func TestRetrier_ContextCanceled(t *testing.T) { + r := newRetrier(RetryConfig{MaxAttempts: 5, BaseDelay: time.Millisecond, MaxDelay: 10 * time.Millisecond}, nopLogger{}, nil) + nodes := makeTestNodes("n1") + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err := r.do(ctx, nodes, func(_ context.Context, n *Node) error { + return nil + }) + if err != context.Canceled { + t.Errorf("expected context.Canceled, got %v", err) + } +} + +func TestRetrier_NoHealthyNodes(t *testing.T) { + r := newRetrier(RetryConfig{MaxAttempts: 2, BaseDelay: time.Millisecond, MaxDelay: 10 * time.Millisecond}, nopLogger{}, nil) + nodes := makeTestNodes("n1") + nodes[0].healthy.Store(false) + + err := r.do(context.Background(), nodes, func(_ 
context.Context, n *Node) error { + t.Fatal("should not be called") + return nil + }) + if !errors.Is(err, ErrNoHealthyNode) { + t.Errorf("expected ErrNoHealthyNode, got %v", err) + } +} + +func TestRetrier_CustomClassifier(t *testing.T) { + custom := func(err error) bool { + return err.Error() == "custom-retry" + } + r := newRetrier(RetryConfig{MaxAttempts: 3, BaseDelay: time.Millisecond, MaxDelay: 10 * time.Millisecond, RetryableErrors: custom}, nopLogger{}, nil) + nodes := makeTestNodes("n1") + + calls := 0 + err := r.do(context.Background(), nodes, func(_ context.Context, n *Node) error { + calls++ + if calls < 2 { + return errors.New("custom-retry") + } + return nil + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if calls != 2 { + t.Errorf("calls = %d, want 2", calls) + } +} + +func TestBackoff(t *testing.T) { + r := newRetrier(RetryConfig{BaseDelay: 100 * time.Millisecond, MaxDelay: time.Second}, nopLogger{}, nil) + + d0 := r.backoff(0) + if d0 < 50*time.Millisecond || d0 > 150*time.Millisecond { + t.Errorf("backoff(0) = %v, expected ~100ms", d0) + } + + d3 := r.backoff(3) + if d3 < 600*time.Millisecond || d3 > 1100*time.Millisecond { + t.Errorf("backoff(3) = %v, expected ~800ms", d3) + } + + // Should cap at MaxDelay + d10 := r.backoff(10) + if d10 > 1250*time.Millisecond { + t.Errorf("backoff(10) = %v, should be capped near 1s", d10) + } +} diff --git a/tx.go b/tx.go new file mode 100644 index 0000000..402de5f --- /dev/null +++ b/tx.go @@ -0,0 +1,61 @@ +package dbx + +import ( + "context" + "fmt" + + "github.com/jackc/pgx/v5" +) + +// TxFunc is the callback signature for RunTx. +type TxFunc func(ctx context.Context, tx pgx.Tx) error + +// RunTx executes fn inside a transaction on the master node with default options. +// The transaction is committed if fn returns nil, rolled back otherwise. +// Panics inside fn are caught, the transaction is rolled back, and the panic is re-raised. 
+func (c *Cluster) RunTx(ctx context.Context, fn TxFunc) error { + return c.RunTxOptions(ctx, pgx.TxOptions{}, fn) +} + +// RunTxOptions executes fn inside a transaction with the given options. +func (c *Cluster) RunTxOptions(ctx context.Context, opts pgx.TxOptions, fn TxFunc) error { + tx, err := c.master.BeginTx(ctx, opts) + if err != nil { + return fmt.Errorf("dbx: begin tx: %w", err) + } + + return runTx(ctx, tx, fn) +} + +func runTx(ctx context.Context, tx pgx.Tx, fn TxFunc) (retErr error) { + defer func() { + if p := recover(); p != nil { + _ = tx.Rollback(ctx) + panic(p) + } + if retErr != nil { + _ = tx.Rollback(ctx) + } + }() + + if err := fn(ctx, tx); err != nil { + return err + } + return tx.Commit(ctx) +} + +// querierKey is the context key for Querier injection. +type querierKey struct{} + +// InjectQuerier returns a new context carrying q. +func InjectQuerier(ctx context.Context, q Querier) context.Context { + return context.WithValue(ctx, querierKey{}, q) +} + +// ExtractQuerier returns the Querier from ctx, or fallback if none is present. +func ExtractQuerier(ctx context.Context, fallback Querier) Querier { + if q, ok := ctx.Value(querierKey{}).(Querier); ok { + return q + } + return fallback +} diff --git a/tx_test.go b/tx_test.go new file mode 100644 index 0000000..5a065fe --- /dev/null +++ b/tx_test.go @@ -0,0 +1,25 @@ +package dbx + +import ( + "context" + "testing" +) + +func TestInjectExtractQuerier(t *testing.T) { + ctx := context.Background() + fallback := &Node{name: "fallback"} + + // No querier in context → returns fallback + got := ExtractQuerier(ctx, fallback) + if got != fallback { + t.Error("expected fallback when no querier in context") + } + + // Inject querier → extract it + injected := &Node{name: "injected"} + ctx = InjectQuerier(ctx, injected) + got = ExtractQuerier(ctx, fallback) + if got != injected { + t.Error("expected injected querier") + } +}