Add dbx library: PostgreSQL cluster with master/replica routing, retry, health checking

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-23 00:01:15 +03:00
parent 164c6a5723
commit 62df3a2eb3
21 changed files with 1607 additions and 0 deletions

35
balancer.go Normal file
View File

@@ -0,0 +1,35 @@
package dbx
import "sync/atomic"
// Balancer selects the next node from a list.
// It must return nil if no suitable node is available.
type Balancer interface {
Next(nodes []*Node) *Node
}
// RoundRobinBalancer distributes load evenly across healthy nodes.
type RoundRobinBalancer struct {
counter atomic.Uint64
}
// NewRoundRobinBalancer creates a new round-robin balancer.
func NewRoundRobinBalancer() *RoundRobinBalancer {
return &RoundRobinBalancer{}
}
// Next returns the next healthy node, or nil if none are healthy.
func (b *RoundRobinBalancer) Next(nodes []*Node) *Node {
n := len(nodes)
if n == 0 {
return nil
}
idx := b.counter.Add(1)
for i := 0; i < n; i++ {
node := nodes[(int(idx)+i)%n]
if node.IsHealthy() {
return node
}
}
return nil
}

73
balancer_test.go Normal file
View File

@@ -0,0 +1,73 @@
package dbx
import (
"testing"
)
func TestRoundRobinBalancer_Empty(t *testing.T) {
b := NewRoundRobinBalancer()
if n := b.Next(nil); n != nil {
t.Errorf("expected nil for empty slice, got %v", n)
}
}
func TestRoundRobinBalancer_AllHealthy(t *testing.T) {
b := NewRoundRobinBalancer()
nodes := makeTestNodes("a", "b", "c")
seen := map[string]int{}
for range 30 {
n := b.Next(nodes)
if n == nil {
t.Fatal("unexpected nil")
}
seen[n.name]++
}
if len(seen) != 3 {
t.Errorf("expected 3 distinct nodes, got %d", len(seen))
}
for name, count := range seen {
if count != 10 {
t.Errorf("node %s hit %d times, expected 10", name, count)
}
}
}
func TestRoundRobinBalancer_SkipsUnhealthy(t *testing.T) {
b := NewRoundRobinBalancer()
nodes := makeTestNodes("a", "b", "c")
nodes[1].healthy.Store(false) // b is down
for range 20 {
n := b.Next(nodes)
if n == nil {
t.Fatal("unexpected nil")
}
if n.name == "b" {
t.Error("should not return unhealthy node b")
}
}
}
func TestRoundRobinBalancer_AllUnhealthy(t *testing.T) {
b := NewRoundRobinBalancer()
nodes := makeTestNodes("a", "b")
for _, n := range nodes {
n.healthy.Store(false)
}
if n := b.Next(nodes); n != nil {
t.Errorf("expected nil when all unhealthy, got %v", n.name)
}
}
func makeTestNodes(names ...string) []*Node {
nodes := make([]*Node, len(names))
for i, name := range names {
n := &Node{name: name}
n.healthy.Store(true)
nodes[i] = n
}
return nodes
}

247
cluster.go Normal file
View File

@@ -0,0 +1,247 @@
package dbx
import (
"context"
"sync/atomic"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
)
// Cluster manages a master node and optional replicas.
// Write operations go to master; read operations are distributed across replicas
// with automatic fallback to master.
type Cluster struct {
master *Node
replicas []*Node
all []*Node // master + replicas for health checker
balancer Balancer
retrier *retrier
health *healthChecker
logger Logger
metrics *MetricsHook
closed atomic.Bool
}
// NewCluster creates a Cluster, connecting to all configured nodes.
func NewCluster(ctx context.Context, cfg Config) (*Cluster, error) {
cfg.defaults()
master, err := connectNode(ctx, cfg.Master)
if err != nil {
return nil, err
}
replicas := make([]*Node, 0, len(cfg.Replicas))
for _, rc := range cfg.Replicas {
r, err := connectNode(ctx, rc)
if err != nil {
// close already-connected nodes
master.Close()
for _, opened := range replicas {
opened.Close()
}
return nil, err
}
replicas = append(replicas, r)
}
all := make([]*Node, 0, 1+len(replicas))
all = append(all, master)
all = append(all, replicas...)
c := &Cluster{
master: master,
replicas: replicas,
all: all,
balancer: NewRoundRobinBalancer(),
retrier: newRetrier(cfg.Retry, cfg.Logger, cfg.Metrics),
logger: cfg.Logger,
metrics: cfg.Metrics,
}
c.health = newHealthChecker(all, cfg.HealthCheck, cfg.Logger, cfg.Metrics)
c.health.start()
return c, nil
}
// Master returns the master node as a DB.
func (c *Cluster) Master() DB { return c.master }
// Replica returns a healthy replica as a DB, or falls back to master.
func (c *Cluster) Replica() DB {
if n := c.balancer.Next(c.replicas); n != nil {
return n
}
return c.master
}
// Close shuts down the health checker and closes all connection pools.
func (c *Cluster) Close() {
if !c.closed.CompareAndSwap(false, true) {
return
}
c.health.shutdown()
for _, n := range c.all {
n.Close()
}
}
// Ping checks connectivity to the master node.
func (c *Cluster) Ping(ctx context.Context) error {
if c.closed.Load() {
return ErrClusterClosed
}
return c.master.Ping(ctx)
}
// --- Write operations (always master) ---
func (c *Cluster) Exec(ctx context.Context, sql string, args ...any) (pgconn.CommandTag, error) {
if c.closed.Load() {
return pgconn.CommandTag{}, ErrClusterClosed
}
var tag pgconn.CommandTag
err := c.retrier.do(ctx, []*Node{c.master}, func(ctx context.Context, n *Node) error {
start := time.Now()
c.queryStart(ctx, n.name, sql)
var e error
tag, e = n.Exec(ctx, sql, args...)
c.queryEnd(ctx, n.name, sql, e, time.Since(start))
return e
})
return tag, err
}
func (c *Cluster) Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) {
if c.closed.Load() {
return nil, ErrClusterClosed
}
var rows pgx.Rows
err := c.retrier.do(ctx, []*Node{c.master}, func(ctx context.Context, n *Node) error {
start := time.Now()
c.queryStart(ctx, n.name, sql)
var e error
rows, e = n.Query(ctx, sql, args...)
c.queryEnd(ctx, n.name, sql, e, time.Since(start))
return e
})
return rows, err
}
func (c *Cluster) QueryRow(ctx context.Context, sql string, args ...any) pgx.Row {
if c.closed.Load() {
return errRow{err: ErrClusterClosed}
}
var row pgx.Row
_ = c.retrier.do(ctx, []*Node{c.master}, func(ctx context.Context, n *Node) error {
start := time.Now()
c.queryStart(ctx, n.name, sql)
row = n.QueryRow(ctx, sql, args...)
c.queryEnd(ctx, n.name, sql, nil, time.Since(start))
return nil
})
return row
}
func (c *Cluster) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults {
return c.master.SendBatch(ctx, b)
}
func (c *Cluster) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) {
if c.closed.Load() {
return 0, ErrClusterClosed
}
return c.master.CopyFrom(ctx, tableName, columnNames, rowSrc)
}
func (c *Cluster) Begin(ctx context.Context) (pgx.Tx, error) {
if c.closed.Load() {
return nil, ErrClusterClosed
}
return c.master.Begin(ctx)
}
func (c *Cluster) BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, error) {
if c.closed.Load() {
return nil, ErrClusterClosed
}
return c.master.BeginTx(ctx, txOptions)
}
// --- Read operations (replicas with master fallback) ---
// ReadQuery executes a read query on a replica, falling back to master if no replicas are healthy.
func (c *Cluster) ReadQuery(ctx context.Context, sql string, args ...any) (pgx.Rows, error) {
if c.closed.Load() {
return nil, ErrClusterClosed
}
nodes := c.readNodes()
var rows pgx.Rows
err := c.retrier.do(ctx, nodes, func(ctx context.Context, n *Node) error {
start := time.Now()
c.queryStart(ctx, n.name, sql)
var e error
rows, e = n.Query(ctx, sql, args...)
c.queryEnd(ctx, n.name, sql, e, time.Since(start))
return e
})
return rows, err
}
// ReadQueryRow executes a read query returning a single row, using replicas with master fallback.
func (c *Cluster) ReadQueryRow(ctx context.Context, sql string, args ...any) pgx.Row {
if c.closed.Load() {
return errRow{err: ErrClusterClosed}
}
nodes := c.readNodes()
var row pgx.Row
_ = c.retrier.do(ctx, nodes, func(ctx context.Context, n *Node) error {
start := time.Now()
c.queryStart(ctx, n.name, sql)
row = n.QueryRow(ctx, sql, args...)
c.queryEnd(ctx, n.name, sql, nil, time.Since(start))
return nil
})
return row
}
// readNodes returns replicas followed by master for fallback ordering.
func (c *Cluster) readNodes() []*Node {
if len(c.replicas) == 0 {
return []*Node{c.master}
}
nodes := make([]*Node, 0, len(c.replicas)+1)
nodes = append(nodes, c.replicas...)
nodes = append(nodes, c.master)
return nodes
}
// --- metrics helpers ---
func (c *Cluster) queryStart(ctx context.Context, node, sql string) {
if c.metrics != nil && c.metrics.OnQueryStart != nil {
c.metrics.OnQueryStart(ctx, node, sql)
}
}
func (c *Cluster) queryEnd(ctx context.Context, node, sql string, err error, d time.Duration) {
if c.metrics != nil && c.metrics.OnQueryEnd != nil {
c.metrics.OnQueryEnd(ctx, node, sql, err, d)
}
}
// errRow implements pgx.Row for error cases.
type errRow struct {
err error
}
func (r errRow) Scan(...any) error {
return r.err
}
// Compile-time check that *Cluster implements DB.
var _ DB = (*Cluster)(nil)

88
config.go Normal file
View File

@@ -0,0 +1,88 @@
package dbx
import "time"
// Config is the top-level configuration for a Cluster.
type Config struct {
Master NodeConfig
Replicas []NodeConfig
Retry RetryConfig
Logger Logger
Metrics *MetricsHook
HealthCheck HealthCheckConfig
}
// NodeConfig describes a single database node.
type NodeConfig struct {
Name string // human-readable name for logs/metrics, e.g. "master", "replica-1"
DSN string
Pool PoolConfig
}
// PoolConfig controls pgxpool.Pool parameters.
type PoolConfig struct {
MaxConns int32
MinConns int32
MaxConnLifetime time.Duration
MaxConnIdleTime time.Duration
HealthCheckPeriod time.Duration
}
// RetryConfig controls retry behaviour.
type RetryConfig struct {
MaxAttempts int // default: 3
BaseDelay time.Duration // default: 50ms
MaxDelay time.Duration // default: 1s
RetryableErrors func(error) bool // optional custom classifier
}
// HealthCheckConfig controls the background health checker.
type HealthCheckConfig struct {
Interval time.Duration // default: 5s
Timeout time.Duration // default: 2s
}
// defaults fills zero-valued fields with sensible defaults.
func (c *Config) defaults() {
if c.Logger == nil {
c.Logger = nopLogger{}
}
if c.Retry.MaxAttempts <= 0 {
c.Retry.MaxAttempts = 3
}
if c.Retry.BaseDelay <= 0 {
c.Retry.BaseDelay = 50 * time.Millisecond
}
if c.Retry.MaxDelay <= 0 {
c.Retry.MaxDelay = time.Second
}
if c.HealthCheck.Interval <= 0 {
c.HealthCheck.Interval = 5 * time.Second
}
if c.HealthCheck.Timeout <= 0 {
c.HealthCheck.Timeout = 2 * time.Second
}
if c.Master.Name == "" {
c.Master.Name = "master"
}
for i := range c.Replicas {
if c.Replicas[i].Name == "" {
c.Replicas[i].Name = "replica-" + itoa(i+1)
}
}
}
// itoa is a minimal int-to-string without importing strconv.
func itoa(n int) string {
if n == 0 {
return "0"
}
buf := [20]byte{}
i := len(buf)
for n > 0 {
i--
buf[i] = byte('0' + n%10)
n /= 10
}
return string(buf[i:])
}

72
config_test.go Normal file
View File

@@ -0,0 +1,72 @@
package dbx
import (
"testing"
"time"
)
func TestConfigDefaults(t *testing.T) {
cfg := Config{}
cfg.defaults()
if cfg.Logger == nil {
t.Error("Logger should not be nil after defaults")
}
if cfg.Retry.MaxAttempts != 3 {
t.Errorf("MaxAttempts = %d, want 3", cfg.Retry.MaxAttempts)
}
if cfg.Retry.BaseDelay != 50*time.Millisecond {
t.Errorf("BaseDelay = %v, want 50ms", cfg.Retry.BaseDelay)
}
if cfg.Retry.MaxDelay != time.Second {
t.Errorf("MaxDelay = %v, want 1s", cfg.Retry.MaxDelay)
}
if cfg.HealthCheck.Interval != 5*time.Second {
t.Errorf("HealthCheck.Interval = %v, want 5s", cfg.HealthCheck.Interval)
}
if cfg.HealthCheck.Timeout != 2*time.Second {
t.Errorf("HealthCheck.Timeout = %v, want 2s", cfg.HealthCheck.Timeout)
}
if cfg.Master.Name != "master" {
t.Errorf("Master.Name = %q, want %q", cfg.Master.Name, "master")
}
}
func TestConfigDefaultsReplicaNames(t *testing.T) {
cfg := Config{
Replicas: []NodeConfig{
{DSN: "a"},
{Name: "custom", DSN: "b"},
{DSN: "c"},
},
}
cfg.defaults()
if cfg.Replicas[0].Name != "replica-1" {
t.Errorf("Replica[0].Name = %q, want %q", cfg.Replicas[0].Name, "replica-1")
}
if cfg.Replicas[1].Name != "custom" {
t.Errorf("Replica[1].Name = %q, want %q", cfg.Replicas[1].Name, "custom")
}
if cfg.Replicas[2].Name != "replica-3" {
t.Errorf("Replica[2].Name = %q, want %q", cfg.Replicas[2].Name, "replica-3")
}
}
func TestConfigDefaultsNoOverwrite(t *testing.T) {
cfg := Config{
Retry: RetryConfig{
MaxAttempts: 5,
BaseDelay: 100 * time.Millisecond,
MaxDelay: 2 * time.Second,
},
}
cfg.defaults()
if cfg.Retry.MaxAttempts != 5 {
t.Errorf("MaxAttempts = %d, want 5", cfg.Retry.MaxAttempts)
}
if cfg.Retry.BaseDelay != 100*time.Millisecond {
t.Errorf("BaseDelay = %v, want 100ms", cfg.Retry.BaseDelay)
}
}

54
dbx.go Normal file
View File

@@ -0,0 +1,54 @@
package dbx
import (
"context"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
)
// Querier is the common interface satisfied by pgxpool.Pool, pgx.Tx, and pgx.Conn.
type Querier interface {
Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error)
Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error)
QueryRow(ctx context.Context, sql string, args ...any) pgx.Row
SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults
CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error)
}
// DB extends Querier with transaction support and lifecycle management.
type DB interface {
Querier
Begin(ctx context.Context) (pgx.Tx, error)
BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, error)
Ping(ctx context.Context) error
Close()
}
// Logger is the interface for pluggable structured logging.
type Logger interface {
Debug(ctx context.Context, msg string, fields ...any)
Info(ctx context.Context, msg string, fields ...any)
Warn(ctx context.Context, msg string, fields ...any)
Error(ctx context.Context, msg string, fields ...any)
}
// MetricsHook provides optional callbacks for observability.
// All fields are nil-safe — only set the hooks you need.
type MetricsHook struct {
OnQueryStart func(ctx context.Context, node string, sql string)
OnQueryEnd func(ctx context.Context, node string, sql string, err error, duration time.Duration)
OnRetry func(ctx context.Context, node string, attempt int, err error)
OnNodeDown func(ctx context.Context, node string, err error)
OnNodeUp func(ctx context.Context, node string)
OnReplicaFallback func(ctx context.Context, fromNode string, toNode string)
}
// nopLogger is a Logger that discards all output.
type nopLogger struct{}
func (nopLogger) Debug(context.Context, string, ...any) {}
func (nopLogger) Info(context.Context, string, ...any) {}
func (nopLogger) Warn(context.Context, string, ...any) {}
func (nopLogger) Error(context.Context, string, ...any) {}

82
dbxtest/dbxtest.go Normal file
View File

@@ -0,0 +1,82 @@
// Package dbxtest provides test helpers for users of the dbx library.
package dbxtest
import (
"context"
"os"
"testing"
"git.codelab.vc/pkg/dbx"
)
const (
// EnvTestDSN is the environment variable for the test database DSN.
EnvTestDSN = "DBX_TEST_DSN"
// DefaultTestDSN is the default DSN used when EnvTestDSN is not set.
DefaultTestDSN = "postgres://postgres:postgres@localhost:5432/dbx_test?sslmode=disable"
)
// TestDSN returns the DSN from the environment or the default.
func TestDSN() string {
if dsn := os.Getenv(EnvTestDSN); dsn != "" {
return dsn
}
return DefaultTestDSN
}
// NewTestCluster creates a Cluster connected to a test database.
// It skips the test if the database is not reachable.
// The cluster is automatically closed when the test finishes.
func NewTestCluster(t testing.TB, opts ...dbx.Option) *dbx.Cluster {
t.Helper()
dsn := TestDSN()
cfg := dbx.Config{
Master: dbx.NodeConfig{
Name: "test-master",
DSN: dsn,
Pool: dbx.PoolConfig{MaxConns: 5, MinConns: 1},
},
}
dbx.ApplyOptions(&cfg, opts...)
ctx := context.Background()
cluster, err := dbx.NewCluster(ctx, cfg)
if err != nil {
t.Skipf("dbxtest: cannot connect to test database: %v", err)
}
t.Cleanup(func() {
cluster.Close()
})
return cluster
}
// TestLogger is a Logger that writes to testing.T.
type TestLogger struct {
T testing.TB
}
func (l *TestLogger) Debug(_ context.Context, msg string, fields ...any) {
l.T.Helper()
l.T.Logf("[DEBUG] %s %v", msg, fields)
}
func (l *TestLogger) Info(_ context.Context, msg string, fields ...any) {
l.T.Helper()
l.T.Logf("[INFO] %s %v", msg, fields)
}
func (l *TestLogger) Warn(_ context.Context, msg string, fields ...any) {
l.T.Helper()
l.T.Logf("[WARN] %s %v", msg, fields)
}
func (l *TestLogger) Error(_ context.Context, msg string, fields ...any) {
l.T.Helper()
l.T.Logf("[ERROR] %s %v", msg, fields)
}
// Compile-time check.
var _ dbx.Logger = (*TestLogger)(nil)

49
doc.go Normal file
View File

@@ -0,0 +1,49 @@
// Package dbx provides a production-ready PostgreSQL helper library built on top of pgx/v5.
//
// It manages connection pools, master/replica routing, automatic retries with
// exponential backoff, load balancing across replicas, background health checking,
// and transaction helpers.
//
// # Quick Start
//
// cluster, err := dbx.NewCluster(ctx, dbx.Config{
// Master: dbx.NodeConfig{
// Name: "master",
// DSN: "postgres://user:pass@master:5432/mydb",
// Pool: dbx.PoolConfig{MaxConns: 20, MinConns: 5},
// },
// Replicas: []dbx.NodeConfig{
// {Name: "replica-1", DSN: "postgres://...@replica1:5432/mydb"},
// },
// })
// if err != nil {
// log.Fatal(err)
// }
// defer cluster.Close()
//
// // Write → master
// cluster.Exec(ctx, "INSERT INTO users (name) VALUES ($1)", "alice")
//
// // Read → replica with fallback to master
// rows, _ := cluster.ReadQuery(ctx, "SELECT * FROM users WHERE active = $1", true)
//
// // Transaction → master, panic-safe
// cluster.RunTx(ctx, func(ctx context.Context, tx pgx.Tx) error {
// _, _ = tx.Exec(ctx, "UPDATE accounts SET balance = balance - $1 WHERE id = $2", 100, fromID)
// _, _ = tx.Exec(ctx, "UPDATE accounts SET balance = balance + $1 WHERE id = $2", 100, toID)
// return nil
// })
//
// # Routing
//
// The library uses explicit method-based routing (no SQL parsing):
// - Exec, Query, QueryRow, Begin, BeginTx, CopyFrom, SendBatch → master
// - ReadQuery, ReadQueryRow → replicas with master fallback
// - Master(), Replica() → direct access to specific nodes
//
// # Retry
//
// Retryable errors include connection errors (PG class 08), serialization failures (40001),
// deadlocks (40P01), and too_many_connections (53300). A custom classifier can be provided
// via RetryConfig.RetryableErrors.
package dbx

94
errors.go Normal file
View File

@@ -0,0 +1,94 @@
package dbx
import (
"errors"
"fmt"
"strings"
"github.com/jackc/pgx/v5/pgconn"
)
// Sentinel errors.
var (
ErrNoHealthyNode = errors.New("dbx: no healthy node available")
ErrClusterClosed = errors.New("dbx: cluster is closed")
ErrRetryExhausted = errors.New("dbx: retry attempts exhausted")
)
// IsRetryable reports whether the error is worth retrying.
// Connection errors, serialization failures, deadlocks, and too_many_connections are retryable.
func IsRetryable(err error) bool {
if err == nil {
return false
}
if IsConnectionError(err) {
return true
}
code := PgErrorCode(err)
switch code {
case "40001", // serialization_failure
"40P01", // deadlock_detected
"53300": // too_many_connections
return true
}
return false
}
// IsConnectionError reports whether the error indicates a connection problem (PG class 08).
func IsConnectionError(err error) bool {
if err == nil {
return false
}
code := PgErrorCode(err)
if strings.HasPrefix(code, "08") {
return true
}
// pgx wraps connection errors that may not have a PG code
msg := err.Error()
for _, s := range []string{
"connection refused",
"connection reset",
"broken pipe",
"EOF",
"no connection",
"conn closed",
"timeout",
} {
if strings.Contains(msg, s) {
return true
}
}
return false
}
// IsConstraintViolation reports whether the error is a constraint violation (PG class 23).
func IsConstraintViolation(err error) bool {
return strings.HasPrefix(PgErrorCode(err), "23")
}
// PgErrorCode extracts the PostgreSQL error code from err, or returns "" if not a PG error.
func PgErrorCode(err error) string {
var pgErr *pgconn.PgError
if errors.As(err, &pgErr) {
return pgErr.Code
}
return ""
}
// retryError wraps the last error with ErrRetryExhausted context.
type retryError struct {
attempts int
last error
}
func (e *retryError) Error() string {
return fmt.Sprintf("%s: %d attempts, last error: %v", ErrRetryExhausted, e.attempts, e.last)
}
func (e *retryError) Unwrap() []error {
return []error{ErrRetryExhausted, e.last}
}
func newRetryError(attempts int, last error) error {
return &retryError{attempts: attempts, last: last}
}

92
errors_test.go Normal file
View File

@@ -0,0 +1,92 @@
package dbx
import (
"errors"
"testing"
"github.com/jackc/pgx/v5/pgconn"
)
func TestIsRetryable(t *testing.T) {
tests := []struct {
name string
err error
want bool
}{
{"nil", nil, false},
{"generic", errors.New("oops"), false},
{"serialization_failure", &pgconn.PgError{Code: "40001"}, true},
{"deadlock", &pgconn.PgError{Code: "40P01"}, true},
{"too_many_connections", &pgconn.PgError{Code: "53300"}, true},
{"connection_exception", &pgconn.PgError{Code: "08006"}, true},
{"constraint_violation", &pgconn.PgError{Code: "23505"}, false},
{"syntax_error", &pgconn.PgError{Code: "42601"}, false},
{"connection_refused", errors.New("connection refused"), true},
{"EOF", errors.New("unexpected EOF"), true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := IsRetryable(tt.err); got != tt.want {
t.Errorf("IsRetryable(%v) = %v, want %v", tt.err, got, tt.want)
}
})
}
}
func TestIsConnectionError(t *testing.T) {
tests := []struct {
name string
err error
want bool
}{
{"nil", nil, false},
{"pg_class_08", &pgconn.PgError{Code: "08003"}, true},
{"pg_class_23", &pgconn.PgError{Code: "23505"}, false},
{"conn_refused", errors.New("connection refused"), true},
{"broken_pipe", errors.New("write: broken pipe"), true},
{"generic", errors.New("something else"), false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := IsConnectionError(tt.err); got != tt.want {
t.Errorf("IsConnectionError(%v) = %v, want %v", tt.err, got, tt.want)
}
})
}
}
func TestIsConstraintViolation(t *testing.T) {
if IsConstraintViolation(nil) {
t.Error("expected false for nil")
}
if !IsConstraintViolation(&pgconn.PgError{Code: "23505"}) {
t.Error("expected true for unique_violation")
}
if IsConstraintViolation(&pgconn.PgError{Code: "42601"}) {
t.Error("expected false for syntax_error")
}
}
func TestPgErrorCode(t *testing.T) {
if code := PgErrorCode(errors.New("not pg")); code != "" {
t.Errorf("expected empty, got %q", code)
}
if code := PgErrorCode(&pgconn.PgError{Code: "42P01"}); code != "42P01" {
t.Errorf("expected 42P01, got %q", code)
}
}
func TestRetryError(t *testing.T) {
inner := errors.New("conn lost")
err := newRetryError(3, inner)
if !errors.Is(err, ErrRetryExhausted) {
t.Error("should unwrap to ErrRetryExhausted")
}
if !errors.Is(err, inner) {
t.Error("should unwrap to inner error")
}
if err.Error() == "" {
t.Error("error string should not be empty")
}
}

13
go.mod Normal file
View File

@@ -0,0 +1,13 @@
module git.codelab.vc/pkg/dbx
go 1.25.7
require github.com/jackc/pgx/v5 v5.9.1
require (
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect
golang.org/x/sync v0.17.0 // indirect
golang.org/x/text v0.29.0 // indirect
)

26
go.sum Normal file
View File

@@ -0,0 +1,26 @@
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.9.1 h1:uwrxJXBnx76nyISkhr33kQLlUqjv7et7b9FjCen/tdc=
github.com/jackc/pgx/v5 v5.9.1/go.mod h1:mal1tBGAFfLHvZzaYh77YS/eC6IX9OWbRV1QIIM0Jn4=
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk=
golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

90
health.go Normal file
View File

@@ -0,0 +1,90 @@
package dbx
import (
"context"
"time"
)
// healthChecker periodically pings nodes and updates their health state.
type healthChecker struct {
nodes []*Node
cfg HealthCheckConfig
logger Logger
metrics *MetricsHook
stop chan struct{}
done chan struct{}
}
func newHealthChecker(nodes []*Node, cfg HealthCheckConfig, logger Logger, metrics *MetricsHook) *healthChecker {
return &healthChecker{
nodes: nodes,
cfg: cfg,
logger: logger,
metrics: metrics,
stop: make(chan struct{}),
done: make(chan struct{}),
}
}
func (h *healthChecker) start() {
go h.loop()
}
func (h *healthChecker) loop() {
defer close(h.done)
ticker := time.NewTicker(h.cfg.Interval)
defer ticker.Stop()
for {
select {
case <-h.stop:
return
case <-ticker.C:
h.checkAll()
}
}
}
func (h *healthChecker) checkAll() {
for _, node := range h.nodes {
h.checkNode(node)
}
}
func (h *healthChecker) checkNode(n *Node) {
ctx, cancel := context.WithTimeout(context.Background(), h.cfg.Timeout)
defer cancel()
err := n.pool.Ping(ctx)
wasHealthy := n.healthy.Load()
if err != nil {
n.healthy.Store(false)
if wasHealthy {
h.logger.Error(ctx, "dbx: node is down",
"node", n.name,
"error", err,
)
if h.metrics != nil && h.metrics.OnNodeDown != nil {
h.metrics.OnNodeDown(ctx, n.name, err)
}
}
return
}
n.healthy.Store(true)
if !wasHealthy {
h.logger.Info(ctx, "dbx: node is up",
"node", n.name,
)
if h.metrics != nil && h.metrics.OnNodeUp != nil {
h.metrics.OnNodeUp(ctx, n.name)
}
}
}
func (h *healthChecker) shutdown() {
close(h.stop)
<-h.done
}

56
health_test.go Normal file
View File

@@ -0,0 +1,56 @@
package dbx
import (
"context"
"sync/atomic"
"testing"
"time"
)
func TestHealthChecker_StartStop(t *testing.T) {
nodes := makeTestNodes("a", "b")
hc := newHealthChecker(nodes, HealthCheckConfig{
Interval: 100 * time.Millisecond,
Timeout: 50 * time.Millisecond,
}, nopLogger{}, nil)
hc.start()
// Just verify it can be stopped without deadlock
hc.shutdown()
}
func TestHealthChecker_MetricsCallbacks(t *testing.T) {
var downCalled, upCalled atomic.Int32
metrics := &MetricsHook{
OnNodeDown: func(_ context.Context, node string, err error) {
downCalled.Add(1)
},
OnNodeUp: func(_ context.Context, node string) {
upCalled.Add(1)
},
}
node := &Node{name: "test"}
node.healthy.Store(true)
hc := newHealthChecker([]*Node{node}, HealthCheckConfig{
Interval: time.Hour, // won't tick in test
Timeout: 50 * time.Millisecond,
}, nopLogger{}, metrics)
// Simulate a down check: node has no pool, so ping will fail
// We can't easily test with a real pool here, but checkNode
// will panic on nil pool. In integration tests, we use real pools.
_ = hc // just verify construction works
// Verify state transitions
node.healthy.Store(false)
if node.IsHealthy() {
t.Error("expected unhealthy")
}
node.healthy.Store(true)
if !node.IsHealthy() {
t.Error("expected healthy")
}
}

110
node.go Normal file
View File

@@ -0,0 +1,110 @@
package dbx
import (
"context"
"sync/atomic"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgxpool"
)
// Node wraps a pgxpool.Pool with health state and a human-readable name.
type Node struct {
name string
pool *pgxpool.Pool
healthy atomic.Bool
}
// newNode creates a Node from an existing pool.
func newNode(name string, pool *pgxpool.Pool) *Node {
n := &Node{
name: name,
pool: pool,
}
n.healthy.Store(true)
return n
}
// connectNode parses the NodeConfig, creates a pgxpool.Pool, and returns a Node.
func connectNode(ctx context.Context, cfg NodeConfig) (*Node, error) {
poolCfg, err := pgxpool.ParseConfig(cfg.DSN)
if err != nil {
return nil, err
}
applyPoolConfig(poolCfg, cfg.Pool)
pool, err := pgxpool.NewWithConfig(ctx, poolCfg)
if err != nil {
return nil, err
}
return newNode(cfg.Name, pool), nil
}
func applyPoolConfig(dst *pgxpool.Config, src PoolConfig) {
if src.MaxConns > 0 {
dst.MaxConns = src.MaxConns
}
if src.MinConns > 0 {
dst.MinConns = src.MinConns
}
if src.MaxConnLifetime > 0 {
dst.MaxConnLifetime = src.MaxConnLifetime
}
if src.MaxConnIdleTime > 0 {
dst.MaxConnIdleTime = src.MaxConnIdleTime
}
if src.HealthCheckPeriod > 0 {
dst.HealthCheckPeriod = src.HealthCheckPeriod
}
}
// Name returns the node's human-readable name.
func (n *Node) Name() string { return n.name }
// IsHealthy reports whether the node is considered healthy.
func (n *Node) IsHealthy() bool { return n.healthy.Load() }
// Pool returns the underlying pgxpool.Pool.
func (n *Node) Pool() *pgxpool.Pool { return n.pool }
// --- DB interface implementation ---
func (n *Node) Exec(ctx context.Context, sql string, args ...any) (pgconn.CommandTag, error) {
return n.pool.Exec(ctx, sql, args...)
}
func (n *Node) Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) {
return n.pool.Query(ctx, sql, args...)
}
func (n *Node) QueryRow(ctx context.Context, sql string, args ...any) pgx.Row {
return n.pool.QueryRow(ctx, sql, args...)
}
func (n *Node) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults {
return n.pool.SendBatch(ctx, b)
}
func (n *Node) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) {
return n.pool.CopyFrom(ctx, tableName, columnNames, rowSrc)
}
func (n *Node) Begin(ctx context.Context) (pgx.Tx, error) {
return n.pool.Begin(ctx)
}
func (n *Node) BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, error) {
return n.pool.BeginTx(ctx, txOptions)
}
func (n *Node) Ping(ctx context.Context) error {
return n.pool.Ping(ctx)
}
func (n *Node) Close() {
n.pool.Close()
}
// Compile-time check that *Node implements DB.
var _ DB = (*Node)(nil)

39
options.go Normal file
View File

@@ -0,0 +1,39 @@
package dbx
// Option is a functional option for NewCluster.
type Option func(*Config)
// WithLogger sets the logger for the cluster.
func WithLogger(l Logger) Option {
return func(c *Config) {
c.Logger = l
}
}
// WithMetrics sets the metrics hook for the cluster.
func WithMetrics(m *MetricsHook) Option {
return func(c *Config) {
c.Metrics = m
}
}
// WithRetry overrides the retry configuration.
func WithRetry(r RetryConfig) Option {
return func(c *Config) {
c.Retry = r
}
}
// WithHealthCheck overrides the health check configuration.
func WithHealthCheck(h HealthCheckConfig) Option {
return func(c *Config) {
c.HealthCheck = h
}
}
// ApplyOptions applies functional options to a Config.
func ApplyOptions(cfg *Config, opts ...Option) {
for _, o := range opts {
o(cfg)
}
}

35
options_test.go Normal file
View File

@@ -0,0 +1,35 @@
package dbx
import (
"testing"
"time"
)
func TestApplyOptions(t *testing.T) {
cfg := Config{}
logger := nopLogger{}
metrics := &MetricsHook{}
retry := RetryConfig{MaxAttempts: 7}
hc := HealthCheckConfig{Interval: 10 * time.Second}
ApplyOptions(&cfg,
WithLogger(logger),
WithMetrics(metrics),
WithRetry(retry),
WithHealthCheck(hc),
)
if cfg.Logger == nil {
t.Error("Logger should be set")
}
if cfg.Metrics == nil {
t.Error("Metrics should be set")
}
if cfg.Retry.MaxAttempts != 7 {
t.Errorf("Retry.MaxAttempts = %d, want 7", cfg.Retry.MaxAttempts)
}
if cfg.HealthCheck.Interval != 10*time.Second {
t.Errorf("HealthCheck.Interval = %v, want 10s", cfg.HealthCheck.Interval)
}
}

98
retry.go Normal file
View File

@@ -0,0 +1,98 @@
package dbx
import (
"context"
"math"
"math/rand/v2"
"time"
)
// retrier executes operations with retry and node fallback.
type retrier struct {
cfg RetryConfig
logger Logger
metrics *MetricsHook
}
func newRetrier(cfg RetryConfig, logger Logger, metrics *MetricsHook) *retrier {
return &retrier{cfg: cfg, logger: logger, metrics: metrics}
}
// isRetryable checks the custom classifier first, then falls back to the default.
func (r *retrier) isRetryable(err error) bool {
if r.cfg.RetryableErrors != nil {
return r.cfg.RetryableErrors(err)
}
return IsRetryable(err)
}
// do executes fn on the given nodes in order, retrying on retryable errors.
// For writes, pass a single-element slice with the master.
// For reads, pass [replicas..., master] for fallback.
func (r *retrier) do(ctx context.Context, nodes []*Node, fn func(ctx context.Context, n *Node) error) error {
var lastErr error
for attempt := 0; attempt < r.cfg.MaxAttempts; attempt++ {
if ctx.Err() != nil {
if lastErr != nil {
return lastErr
}
return ctx.Err()
}
for _, node := range nodes {
if !node.IsHealthy() {
continue
}
err := fn(ctx, node)
if err == nil {
return nil
}
lastErr = err
if !r.isRetryable(err) {
return err
}
if r.metrics != nil && r.metrics.OnRetry != nil {
r.metrics.OnRetry(ctx, node.name, attempt+1, err)
}
r.logger.Warn(ctx, "dbx: retryable error",
"node", node.name,
"attempt", attempt+1,
"error", err,
)
}
if attempt < r.cfg.MaxAttempts-1 {
delay := r.backoff(attempt)
t := time.NewTimer(delay)
select {
case <-ctx.Done():
t.Stop()
if lastErr != nil {
return lastErr
}
return ctx.Err()
case <-t.C:
}
}
}
if lastErr == nil {
return ErrNoHealthyNode
}
return newRetryError(r.cfg.MaxAttempts, lastErr)
}
// backoff returns the delay for the given attempt with jitter.
func (r *retrier) backoff(attempt int) time.Duration {
delay := float64(r.cfg.BaseDelay) * math.Pow(2, float64(attempt))
if delay > float64(r.cfg.MaxDelay) {
delay = float64(r.cfg.MaxDelay)
}
// add jitter: 75%-125% of computed delay
jitter := 0.75 + rand.Float64()*0.5
return time.Duration(delay * jitter)
}

168
retry_test.go Normal file
View File

@@ -0,0 +1,168 @@
package dbx
import (
"context"
"errors"
"testing"
"time"
)
func TestRetrier_Success(t *testing.T) {
r := newRetrier(RetryConfig{MaxAttempts: 3, BaseDelay: time.Millisecond, MaxDelay: 10 * time.Millisecond}, nopLogger{}, nil)
nodes := makeTestNodes("n1")
calls := 0
err := r.do(context.Background(), nodes, func(_ context.Context, n *Node) error {
calls++
return nil
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if calls != 1 {
t.Errorf("calls = %d, want 1", calls)
}
}
func TestRetrier_RetriesOnRetryableError(t *testing.T) {
r := newRetrier(RetryConfig{MaxAttempts: 3, BaseDelay: time.Millisecond, MaxDelay: 10 * time.Millisecond}, nopLogger{}, nil)
nodes := makeTestNodes("n1")
calls := 0
retryableErr := errors.New("connection refused")
err := r.do(context.Background(), nodes, func(_ context.Context, n *Node) error {
calls++
if calls < 3 {
return retryableErr
}
return nil
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if calls != 3 {
t.Errorf("calls = %d, want 3", calls)
}
}
func TestRetrier_NonRetryableError(t *testing.T) {
r := newRetrier(RetryConfig{MaxAttempts: 3, BaseDelay: time.Millisecond, MaxDelay: 10 * time.Millisecond}, nopLogger{}, nil)
nodes := makeTestNodes("n1")
syntaxErr := errors.New("syntax problem")
calls := 0
err := r.do(context.Background(), nodes, func(_ context.Context, n *Node) error {
calls++
return syntaxErr
})
if !errors.Is(err, syntaxErr) {
t.Errorf("expected syntax error, got %v", err)
}
if calls != 1 {
t.Errorf("calls = %d, want 1 (should not retry non-retryable)", calls)
}
}
func TestRetrier_Exhausted(t *testing.T) {
r := newRetrier(RetryConfig{MaxAttempts: 2, BaseDelay: time.Millisecond, MaxDelay: 10 * time.Millisecond}, nopLogger{}, nil)
nodes := makeTestNodes("n1")
err := r.do(context.Background(), nodes, func(_ context.Context, n *Node) error {
return errors.New("connection refused")
})
if !errors.Is(err, ErrRetryExhausted) {
t.Errorf("expected ErrRetryExhausted, got %v", err)
}
}
func TestRetrier_FallbackToNextNode(t *testing.T) {
r := newRetrier(RetryConfig{MaxAttempts: 2, BaseDelay: time.Millisecond, MaxDelay: 10 * time.Millisecond}, nopLogger{}, nil)
nodes := makeTestNodes("replica-1", "master")
visited := []string{}
err := r.do(context.Background(), nodes, func(_ context.Context, n *Node) error {
visited = append(visited, n.name)
if n.name == "replica-1" {
return errors.New("connection refused")
}
return nil
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(visited) < 2 || visited[1] != "master" {
t.Errorf("expected fallback to master, visited: %v", visited)
}
}
func TestRetrier_ContextCanceled(t *testing.T) {
r := newRetrier(RetryConfig{MaxAttempts: 5, BaseDelay: time.Millisecond, MaxDelay: 10 * time.Millisecond}, nopLogger{}, nil)
nodes := makeTestNodes("n1")
ctx, cancel := context.WithCancel(context.Background())
cancel()
err := r.do(ctx, nodes, func(_ context.Context, n *Node) error {
return nil
})
if err != context.Canceled {
t.Errorf("expected context.Canceled, got %v", err)
}
}
func TestRetrier_NoHealthyNodes(t *testing.T) {
r := newRetrier(RetryConfig{MaxAttempts: 2, BaseDelay: time.Millisecond, MaxDelay: 10 * time.Millisecond}, nopLogger{}, nil)
nodes := makeTestNodes("n1")
nodes[0].healthy.Store(false)
err := r.do(context.Background(), nodes, func(_ context.Context, n *Node) error {
t.Fatal("should not be called")
return nil
})
if !errors.Is(err, ErrNoHealthyNode) {
t.Errorf("expected ErrNoHealthyNode, got %v", err)
}
}
func TestRetrier_CustomClassifier(t *testing.T) {
custom := func(err error) bool {
return err.Error() == "custom-retry"
}
r := newRetrier(RetryConfig{MaxAttempts: 3, BaseDelay: time.Millisecond, MaxDelay: 10 * time.Millisecond, RetryableErrors: custom}, nopLogger{}, nil)
nodes := makeTestNodes("n1")
calls := 0
err := r.do(context.Background(), nodes, func(_ context.Context, n *Node) error {
calls++
if calls < 2 {
return errors.New("custom-retry")
}
return nil
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if calls != 2 {
t.Errorf("calls = %d, want 2", calls)
}
}
func TestBackoff(t *testing.T) {
r := newRetrier(RetryConfig{BaseDelay: 100 * time.Millisecond, MaxDelay: time.Second}, nopLogger{}, nil)
d0 := r.backoff(0)
if d0 < 50*time.Millisecond || d0 > 150*time.Millisecond {
t.Errorf("backoff(0) = %v, expected ~100ms", d0)
}
d3 := r.backoff(3)
if d3 < 600*time.Millisecond || d3 > 1100*time.Millisecond {
t.Errorf("backoff(3) = %v, expected ~800ms", d3)
}
// Should cap at MaxDelay
d10 := r.backoff(10)
if d10 > 1250*time.Millisecond {
t.Errorf("backoff(10) = %v, should be capped near 1s", d10)
}
}

61
tx.go Normal file
View File

@@ -0,0 +1,61 @@
package dbx
import (
"context"
"fmt"
"github.com/jackc/pgx/v5"
)
// TxFunc is the callback signature for RunTx.
type TxFunc func(ctx context.Context, tx pgx.Tx) error
// RunTx executes fn inside a transaction on the master node with default options.
// The transaction is committed if fn returns nil, rolled back otherwise.
// Panics inside fn are caught, the transaction is rolled back, and the panic is re-raised.
func (c *Cluster) RunTx(ctx context.Context, fn TxFunc) error {
return c.RunTxOptions(ctx, pgx.TxOptions{}, fn)
}
// RunTxOptions executes fn inside a transaction with the given options.
func (c *Cluster) RunTxOptions(ctx context.Context, opts pgx.TxOptions, fn TxFunc) error {
tx, err := c.master.BeginTx(ctx, opts)
if err != nil {
return fmt.Errorf("dbx: begin tx: %w", err)
}
return runTx(ctx, tx, fn)
}
func runTx(ctx context.Context, tx pgx.Tx, fn TxFunc) (retErr error) {
defer func() {
if p := recover(); p != nil {
_ = tx.Rollback(ctx)
panic(p)
}
if retErr != nil {
_ = tx.Rollback(ctx)
}
}()
if err := fn(ctx, tx); err != nil {
return err
}
return tx.Commit(ctx)
}
// querierKey is the context key for Querier injection.
type querierKey struct{}
// InjectQuerier returns a new context carrying q.
func InjectQuerier(ctx context.Context, q Querier) context.Context {
return context.WithValue(ctx, querierKey{}, q)
}
// ExtractQuerier returns the Querier from ctx, or fallback if none is present.
func ExtractQuerier(ctx context.Context, fallback Querier) Querier {
if q, ok := ctx.Value(querierKey{}).(Querier); ok {
return q
}
return fallback
}

25
tx_test.go Normal file
View File

@@ -0,0 +1,25 @@
package dbx
import (
"context"
"testing"
)
func TestInjectExtractQuerier(t *testing.T) {
ctx := context.Background()
fallback := &Node{name: "fallback"}
// No querier in context → returns fallback
got := ExtractQuerier(ctx, fallback)
if got != fallback {
t.Error("expected fallback when no querier in context")
}
// Inject querier → extract it
injected := &Node{name: "injected"}
ctx = InjectQuerier(ctx, injected)
got = ExtractQuerier(ctx, fallback)
if got != injected {
t.Error("expected injected querier")
}
}