Add dbx library: PostgreSQL cluster with master/replica routing, retry, health checking
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
35
balancer.go
Normal file
35
balancer.go
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
package dbx
|
||||||
|
|
||||||
|
import "sync/atomic"
|
||||||
|
|
||||||
|
// Balancer selects the next node from a list.
|
||||||
|
// It must return nil if no suitable node is available.
|
||||||
|
type Balancer interface {
|
||||||
|
Next(nodes []*Node) *Node
|
||||||
|
}
|
||||||
|
|
||||||
|
// RoundRobinBalancer distributes load evenly across healthy nodes.
|
||||||
|
type RoundRobinBalancer struct {
|
||||||
|
counter atomic.Uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRoundRobinBalancer creates a new round-robin balancer.
|
||||||
|
func NewRoundRobinBalancer() *RoundRobinBalancer {
|
||||||
|
return &RoundRobinBalancer{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Next returns the next healthy node, or nil if none are healthy.
|
||||||
|
func (b *RoundRobinBalancer) Next(nodes []*Node) *Node {
|
||||||
|
n := len(nodes)
|
||||||
|
if n == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
idx := b.counter.Add(1)
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
node := nodes[(int(idx)+i)%n]
|
||||||
|
if node.IsHealthy() {
|
||||||
|
return node
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
73
balancer_test.go
Normal file
73
balancer_test.go
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
package dbx
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRoundRobinBalancer_Empty(t *testing.T) {
|
||||||
|
b := NewRoundRobinBalancer()
|
||||||
|
if n := b.Next(nil); n != nil {
|
||||||
|
t.Errorf("expected nil for empty slice, got %v", n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoundRobinBalancer_AllHealthy(t *testing.T) {
|
||||||
|
b := NewRoundRobinBalancer()
|
||||||
|
nodes := makeTestNodes("a", "b", "c")
|
||||||
|
|
||||||
|
seen := map[string]int{}
|
||||||
|
for range 30 {
|
||||||
|
n := b.Next(nodes)
|
||||||
|
if n == nil {
|
||||||
|
t.Fatal("unexpected nil")
|
||||||
|
}
|
||||||
|
seen[n.name]++
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(seen) != 3 {
|
||||||
|
t.Errorf("expected 3 distinct nodes, got %d", len(seen))
|
||||||
|
}
|
||||||
|
for name, count := range seen {
|
||||||
|
if count != 10 {
|
||||||
|
t.Errorf("node %s hit %d times, expected 10", name, count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoundRobinBalancer_SkipsUnhealthy(t *testing.T) {
|
||||||
|
b := NewRoundRobinBalancer()
|
||||||
|
nodes := makeTestNodes("a", "b", "c")
|
||||||
|
nodes[1].healthy.Store(false) // b is down
|
||||||
|
|
||||||
|
for range 20 {
|
||||||
|
n := b.Next(nodes)
|
||||||
|
if n == nil {
|
||||||
|
t.Fatal("unexpected nil")
|
||||||
|
}
|
||||||
|
if n.name == "b" {
|
||||||
|
t.Error("should not return unhealthy node b")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoundRobinBalancer_AllUnhealthy(t *testing.T) {
|
||||||
|
b := NewRoundRobinBalancer()
|
||||||
|
nodes := makeTestNodes("a", "b")
|
||||||
|
for _, n := range nodes {
|
||||||
|
n.healthy.Store(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
if n := b.Next(nodes); n != nil {
|
||||||
|
t.Errorf("expected nil when all unhealthy, got %v", n.name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeTestNodes(names ...string) []*Node {
|
||||||
|
nodes := make([]*Node, len(names))
|
||||||
|
for i, name := range names {
|
||||||
|
n := &Node{name: name}
|
||||||
|
n.healthy.Store(true)
|
||||||
|
nodes[i] = n
|
||||||
|
}
|
||||||
|
return nodes
|
||||||
|
}
|
||||||
247
cluster.go
Normal file
247
cluster.go
Normal file
@@ -0,0 +1,247 @@
|
|||||||
|
package dbx
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5"
|
||||||
|
"github.com/jackc/pgx/v5/pgconn"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Cluster manages a master node and optional replicas.
|
||||||
|
// Write operations go to master; read operations are distributed across replicas
|
||||||
|
// with automatic fallback to master.
|
||||||
|
type Cluster struct {
|
||||||
|
master *Node
|
||||||
|
replicas []*Node
|
||||||
|
all []*Node // master + replicas for health checker
|
||||||
|
|
||||||
|
balancer Balancer
|
||||||
|
retrier *retrier
|
||||||
|
health *healthChecker
|
||||||
|
logger Logger
|
||||||
|
metrics *MetricsHook
|
||||||
|
closed atomic.Bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCluster creates a Cluster, connecting to all configured nodes.
|
||||||
|
func NewCluster(ctx context.Context, cfg Config) (*Cluster, error) {
|
||||||
|
cfg.defaults()
|
||||||
|
|
||||||
|
master, err := connectNode(ctx, cfg.Master)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
replicas := make([]*Node, 0, len(cfg.Replicas))
|
||||||
|
for _, rc := range cfg.Replicas {
|
||||||
|
r, err := connectNode(ctx, rc)
|
||||||
|
if err != nil {
|
||||||
|
// close already-connected nodes
|
||||||
|
master.Close()
|
||||||
|
for _, opened := range replicas {
|
||||||
|
opened.Close()
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
replicas = append(replicas, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
all := make([]*Node, 0, 1+len(replicas))
|
||||||
|
all = append(all, master)
|
||||||
|
all = append(all, replicas...)
|
||||||
|
|
||||||
|
c := &Cluster{
|
||||||
|
master: master,
|
||||||
|
replicas: replicas,
|
||||||
|
all: all,
|
||||||
|
balancer: NewRoundRobinBalancer(),
|
||||||
|
retrier: newRetrier(cfg.Retry, cfg.Logger, cfg.Metrics),
|
||||||
|
logger: cfg.Logger,
|
||||||
|
metrics: cfg.Metrics,
|
||||||
|
}
|
||||||
|
|
||||||
|
c.health = newHealthChecker(all, cfg.HealthCheck, cfg.Logger, cfg.Metrics)
|
||||||
|
c.health.start()
|
||||||
|
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Master returns the master node as a DB.
|
||||||
|
func (c *Cluster) Master() DB { return c.master }
|
||||||
|
|
||||||
|
// Replica returns a healthy replica as a DB, or falls back to master.
|
||||||
|
func (c *Cluster) Replica() DB {
|
||||||
|
if n := c.balancer.Next(c.replicas); n != nil {
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
return c.master
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close shuts down the health checker and closes all connection pools.
|
||||||
|
func (c *Cluster) Close() {
|
||||||
|
if !c.closed.CompareAndSwap(false, true) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.health.shutdown()
|
||||||
|
for _, n := range c.all {
|
||||||
|
n.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ping checks connectivity to the master node.
|
||||||
|
func (c *Cluster) Ping(ctx context.Context) error {
|
||||||
|
if c.closed.Load() {
|
||||||
|
return ErrClusterClosed
|
||||||
|
}
|
||||||
|
return c.master.Ping(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Write operations (always master) ---
|
||||||
|
|
||||||
|
func (c *Cluster) Exec(ctx context.Context, sql string, args ...any) (pgconn.CommandTag, error) {
|
||||||
|
if c.closed.Load() {
|
||||||
|
return pgconn.CommandTag{}, ErrClusterClosed
|
||||||
|
}
|
||||||
|
var tag pgconn.CommandTag
|
||||||
|
err := c.retrier.do(ctx, []*Node{c.master}, func(ctx context.Context, n *Node) error {
|
||||||
|
start := time.Now()
|
||||||
|
c.queryStart(ctx, n.name, sql)
|
||||||
|
var e error
|
||||||
|
tag, e = n.Exec(ctx, sql, args...)
|
||||||
|
c.queryEnd(ctx, n.name, sql, e, time.Since(start))
|
||||||
|
return e
|
||||||
|
})
|
||||||
|
return tag, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Cluster) Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) {
|
||||||
|
if c.closed.Load() {
|
||||||
|
return nil, ErrClusterClosed
|
||||||
|
}
|
||||||
|
var rows pgx.Rows
|
||||||
|
err := c.retrier.do(ctx, []*Node{c.master}, func(ctx context.Context, n *Node) error {
|
||||||
|
start := time.Now()
|
||||||
|
c.queryStart(ctx, n.name, sql)
|
||||||
|
var e error
|
||||||
|
rows, e = n.Query(ctx, sql, args...)
|
||||||
|
c.queryEnd(ctx, n.name, sql, e, time.Since(start))
|
||||||
|
return e
|
||||||
|
})
|
||||||
|
return rows, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Cluster) QueryRow(ctx context.Context, sql string, args ...any) pgx.Row {
|
||||||
|
if c.closed.Load() {
|
||||||
|
return errRow{err: ErrClusterClosed}
|
||||||
|
}
|
||||||
|
var row pgx.Row
|
||||||
|
_ = c.retrier.do(ctx, []*Node{c.master}, func(ctx context.Context, n *Node) error {
|
||||||
|
start := time.Now()
|
||||||
|
c.queryStart(ctx, n.name, sql)
|
||||||
|
row = n.QueryRow(ctx, sql, args...)
|
||||||
|
c.queryEnd(ctx, n.name, sql, nil, time.Since(start))
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
return row
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Cluster) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults {
|
||||||
|
return c.master.SendBatch(ctx, b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Cluster) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) {
|
||||||
|
if c.closed.Load() {
|
||||||
|
return 0, ErrClusterClosed
|
||||||
|
}
|
||||||
|
return c.master.CopyFrom(ctx, tableName, columnNames, rowSrc)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Cluster) Begin(ctx context.Context) (pgx.Tx, error) {
|
||||||
|
if c.closed.Load() {
|
||||||
|
return nil, ErrClusterClosed
|
||||||
|
}
|
||||||
|
return c.master.Begin(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Cluster) BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, error) {
|
||||||
|
if c.closed.Load() {
|
||||||
|
return nil, ErrClusterClosed
|
||||||
|
}
|
||||||
|
return c.master.BeginTx(ctx, txOptions)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Read operations (replicas with master fallback) ---
|
||||||
|
|
||||||
|
// ReadQuery executes a read query on a replica, falling back to master if no replicas are healthy.
|
||||||
|
func (c *Cluster) ReadQuery(ctx context.Context, sql string, args ...any) (pgx.Rows, error) {
|
||||||
|
if c.closed.Load() {
|
||||||
|
return nil, ErrClusterClosed
|
||||||
|
}
|
||||||
|
nodes := c.readNodes()
|
||||||
|
var rows pgx.Rows
|
||||||
|
err := c.retrier.do(ctx, nodes, func(ctx context.Context, n *Node) error {
|
||||||
|
start := time.Now()
|
||||||
|
c.queryStart(ctx, n.name, sql)
|
||||||
|
var e error
|
||||||
|
rows, e = n.Query(ctx, sql, args...)
|
||||||
|
c.queryEnd(ctx, n.name, sql, e, time.Since(start))
|
||||||
|
return e
|
||||||
|
})
|
||||||
|
return rows, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadQueryRow executes a read query returning a single row, using replicas with master fallback.
|
||||||
|
func (c *Cluster) ReadQueryRow(ctx context.Context, sql string, args ...any) pgx.Row {
|
||||||
|
if c.closed.Load() {
|
||||||
|
return errRow{err: ErrClusterClosed}
|
||||||
|
}
|
||||||
|
nodes := c.readNodes()
|
||||||
|
var row pgx.Row
|
||||||
|
_ = c.retrier.do(ctx, nodes, func(ctx context.Context, n *Node) error {
|
||||||
|
start := time.Now()
|
||||||
|
c.queryStart(ctx, n.name, sql)
|
||||||
|
row = n.QueryRow(ctx, sql, args...)
|
||||||
|
c.queryEnd(ctx, n.name, sql, nil, time.Since(start))
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
return row
|
||||||
|
}
|
||||||
|
|
||||||
|
// readNodes returns replicas followed by master for fallback ordering.
|
||||||
|
func (c *Cluster) readNodes() []*Node {
|
||||||
|
if len(c.replicas) == 0 {
|
||||||
|
return []*Node{c.master}
|
||||||
|
}
|
||||||
|
nodes := make([]*Node, 0, len(c.replicas)+1)
|
||||||
|
nodes = append(nodes, c.replicas...)
|
||||||
|
nodes = append(nodes, c.master)
|
||||||
|
return nodes
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- metrics helpers ---
|
||||||
|
|
||||||
|
func (c *Cluster) queryStart(ctx context.Context, node, sql string) {
|
||||||
|
if c.metrics != nil && c.metrics.OnQueryStart != nil {
|
||||||
|
c.metrics.OnQueryStart(ctx, node, sql)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Cluster) queryEnd(ctx context.Context, node, sql string, err error, d time.Duration) {
|
||||||
|
if c.metrics != nil && c.metrics.OnQueryEnd != nil {
|
||||||
|
c.metrics.OnQueryEnd(ctx, node, sql, err, d)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// errRow implements pgx.Row for error cases.
|
||||||
|
type errRow struct {
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r errRow) Scan(...any) error {
|
||||||
|
return r.err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compile-time check that *Cluster implements DB.
|
||||||
|
var _ DB = (*Cluster)(nil)
|
||||||
88
config.go
Normal file
88
config.go
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
package dbx
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
|
// Config is the top-level configuration for a Cluster.
|
||||||
|
type Config struct {
|
||||||
|
Master NodeConfig
|
||||||
|
Replicas []NodeConfig
|
||||||
|
Retry RetryConfig
|
||||||
|
Logger Logger
|
||||||
|
Metrics *MetricsHook
|
||||||
|
HealthCheck HealthCheckConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// NodeConfig describes a single database node.
|
||||||
|
type NodeConfig struct {
|
||||||
|
Name string // human-readable name for logs/metrics, e.g. "master", "replica-1"
|
||||||
|
DSN string
|
||||||
|
Pool PoolConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// PoolConfig controls pgxpool.Pool parameters.
|
||||||
|
type PoolConfig struct {
|
||||||
|
MaxConns int32
|
||||||
|
MinConns int32
|
||||||
|
MaxConnLifetime time.Duration
|
||||||
|
MaxConnIdleTime time.Duration
|
||||||
|
HealthCheckPeriod time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// RetryConfig controls retry behaviour.
|
||||||
|
type RetryConfig struct {
|
||||||
|
MaxAttempts int // default: 3
|
||||||
|
BaseDelay time.Duration // default: 50ms
|
||||||
|
MaxDelay time.Duration // default: 1s
|
||||||
|
RetryableErrors func(error) bool // optional custom classifier
|
||||||
|
}
|
||||||
|
|
||||||
|
// HealthCheckConfig controls the background health checker.
|
||||||
|
type HealthCheckConfig struct {
|
||||||
|
Interval time.Duration // default: 5s
|
||||||
|
Timeout time.Duration // default: 2s
|
||||||
|
}
|
||||||
|
|
||||||
|
// defaults fills zero-valued fields with sensible defaults.
|
||||||
|
func (c *Config) defaults() {
|
||||||
|
if c.Logger == nil {
|
||||||
|
c.Logger = nopLogger{}
|
||||||
|
}
|
||||||
|
if c.Retry.MaxAttempts <= 0 {
|
||||||
|
c.Retry.MaxAttempts = 3
|
||||||
|
}
|
||||||
|
if c.Retry.BaseDelay <= 0 {
|
||||||
|
c.Retry.BaseDelay = 50 * time.Millisecond
|
||||||
|
}
|
||||||
|
if c.Retry.MaxDelay <= 0 {
|
||||||
|
c.Retry.MaxDelay = time.Second
|
||||||
|
}
|
||||||
|
if c.HealthCheck.Interval <= 0 {
|
||||||
|
c.HealthCheck.Interval = 5 * time.Second
|
||||||
|
}
|
||||||
|
if c.HealthCheck.Timeout <= 0 {
|
||||||
|
c.HealthCheck.Timeout = 2 * time.Second
|
||||||
|
}
|
||||||
|
if c.Master.Name == "" {
|
||||||
|
c.Master.Name = "master"
|
||||||
|
}
|
||||||
|
for i := range c.Replicas {
|
||||||
|
if c.Replicas[i].Name == "" {
|
||||||
|
c.Replicas[i].Name = "replica-" + itoa(i+1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// itoa is a minimal int-to-string without importing strconv.
|
||||||
|
func itoa(n int) string {
|
||||||
|
if n == 0 {
|
||||||
|
return "0"
|
||||||
|
}
|
||||||
|
buf := [20]byte{}
|
||||||
|
i := len(buf)
|
||||||
|
for n > 0 {
|
||||||
|
i--
|
||||||
|
buf[i] = byte('0' + n%10)
|
||||||
|
n /= 10
|
||||||
|
}
|
||||||
|
return string(buf[i:])
|
||||||
|
}
|
||||||
72
config_test.go
Normal file
72
config_test.go
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
package dbx
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConfigDefaults(t *testing.T) {
|
||||||
|
cfg := Config{}
|
||||||
|
cfg.defaults()
|
||||||
|
|
||||||
|
if cfg.Logger == nil {
|
||||||
|
t.Error("Logger should not be nil after defaults")
|
||||||
|
}
|
||||||
|
if cfg.Retry.MaxAttempts != 3 {
|
||||||
|
t.Errorf("MaxAttempts = %d, want 3", cfg.Retry.MaxAttempts)
|
||||||
|
}
|
||||||
|
if cfg.Retry.BaseDelay != 50*time.Millisecond {
|
||||||
|
t.Errorf("BaseDelay = %v, want 50ms", cfg.Retry.BaseDelay)
|
||||||
|
}
|
||||||
|
if cfg.Retry.MaxDelay != time.Second {
|
||||||
|
t.Errorf("MaxDelay = %v, want 1s", cfg.Retry.MaxDelay)
|
||||||
|
}
|
||||||
|
if cfg.HealthCheck.Interval != 5*time.Second {
|
||||||
|
t.Errorf("HealthCheck.Interval = %v, want 5s", cfg.HealthCheck.Interval)
|
||||||
|
}
|
||||||
|
if cfg.HealthCheck.Timeout != 2*time.Second {
|
||||||
|
t.Errorf("HealthCheck.Timeout = %v, want 2s", cfg.HealthCheck.Timeout)
|
||||||
|
}
|
||||||
|
if cfg.Master.Name != "master" {
|
||||||
|
t.Errorf("Master.Name = %q, want %q", cfg.Master.Name, "master")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigDefaultsReplicaNames(t *testing.T) {
|
||||||
|
cfg := Config{
|
||||||
|
Replicas: []NodeConfig{
|
||||||
|
{DSN: "a"},
|
||||||
|
{Name: "custom", DSN: "b"},
|
||||||
|
{DSN: "c"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cfg.defaults()
|
||||||
|
|
||||||
|
if cfg.Replicas[0].Name != "replica-1" {
|
||||||
|
t.Errorf("Replica[0].Name = %q, want %q", cfg.Replicas[0].Name, "replica-1")
|
||||||
|
}
|
||||||
|
if cfg.Replicas[1].Name != "custom" {
|
||||||
|
t.Errorf("Replica[1].Name = %q, want %q", cfg.Replicas[1].Name, "custom")
|
||||||
|
}
|
||||||
|
if cfg.Replicas[2].Name != "replica-3" {
|
||||||
|
t.Errorf("Replica[2].Name = %q, want %q", cfg.Replicas[2].Name, "replica-3")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigDefaultsNoOverwrite(t *testing.T) {
|
||||||
|
cfg := Config{
|
||||||
|
Retry: RetryConfig{
|
||||||
|
MaxAttempts: 5,
|
||||||
|
BaseDelay: 100 * time.Millisecond,
|
||||||
|
MaxDelay: 2 * time.Second,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cfg.defaults()
|
||||||
|
|
||||||
|
if cfg.Retry.MaxAttempts != 5 {
|
||||||
|
t.Errorf("MaxAttempts = %d, want 5", cfg.Retry.MaxAttempts)
|
||||||
|
}
|
||||||
|
if cfg.Retry.BaseDelay != 100*time.Millisecond {
|
||||||
|
t.Errorf("BaseDelay = %v, want 100ms", cfg.Retry.BaseDelay)
|
||||||
|
}
|
||||||
|
}
|
||||||
54
dbx.go
Normal file
54
dbx.go
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
package dbx
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5"
|
||||||
|
"github.com/jackc/pgx/v5/pgconn"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Querier is the common interface satisfied by pgxpool.Pool, pgx.Tx, and pgx.Conn.
|
||||||
|
type Querier interface {
|
||||||
|
Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error)
|
||||||
|
Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error)
|
||||||
|
QueryRow(ctx context.Context, sql string, args ...any) pgx.Row
|
||||||
|
SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults
|
||||||
|
CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DB extends Querier with transaction support and lifecycle management.
|
||||||
|
type DB interface {
|
||||||
|
Querier
|
||||||
|
Begin(ctx context.Context) (pgx.Tx, error)
|
||||||
|
BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, error)
|
||||||
|
Ping(ctx context.Context) error
|
||||||
|
Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Logger is the interface for pluggable structured logging.
|
||||||
|
type Logger interface {
|
||||||
|
Debug(ctx context.Context, msg string, fields ...any)
|
||||||
|
Info(ctx context.Context, msg string, fields ...any)
|
||||||
|
Warn(ctx context.Context, msg string, fields ...any)
|
||||||
|
Error(ctx context.Context, msg string, fields ...any)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MetricsHook provides optional callbacks for observability.
|
||||||
|
// All fields are nil-safe — only set the hooks you need.
|
||||||
|
type MetricsHook struct {
|
||||||
|
OnQueryStart func(ctx context.Context, node string, sql string)
|
||||||
|
OnQueryEnd func(ctx context.Context, node string, sql string, err error, duration time.Duration)
|
||||||
|
OnRetry func(ctx context.Context, node string, attempt int, err error)
|
||||||
|
OnNodeDown func(ctx context.Context, node string, err error)
|
||||||
|
OnNodeUp func(ctx context.Context, node string)
|
||||||
|
OnReplicaFallback func(ctx context.Context, fromNode string, toNode string)
|
||||||
|
}
|
||||||
|
|
||||||
|
// nopLogger is a Logger that discards all output.
|
||||||
|
type nopLogger struct{}
|
||||||
|
|
||||||
|
func (nopLogger) Debug(context.Context, string, ...any) {}
|
||||||
|
func (nopLogger) Info(context.Context, string, ...any) {}
|
||||||
|
func (nopLogger) Warn(context.Context, string, ...any) {}
|
||||||
|
func (nopLogger) Error(context.Context, string, ...any) {}
|
||||||
82
dbxtest/dbxtest.go
Normal file
82
dbxtest/dbxtest.go
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
// Package dbxtest provides test helpers for users of the dbx library.
|
||||||
|
package dbxtest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.codelab.vc/pkg/dbx"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// EnvTestDSN is the environment variable for the test database DSN.
|
||||||
|
EnvTestDSN = "DBX_TEST_DSN"
|
||||||
|
// DefaultTestDSN is the default DSN used when EnvTestDSN is not set.
|
||||||
|
DefaultTestDSN = "postgres://postgres:postgres@localhost:5432/dbx_test?sslmode=disable"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestDSN returns the DSN from the environment or the default.
|
||||||
|
func TestDSN() string {
|
||||||
|
if dsn := os.Getenv(EnvTestDSN); dsn != "" {
|
||||||
|
return dsn
|
||||||
|
}
|
||||||
|
return DefaultTestDSN
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTestCluster creates a Cluster connected to a test database.
|
||||||
|
// It skips the test if the database is not reachable.
|
||||||
|
// The cluster is automatically closed when the test finishes.
|
||||||
|
func NewTestCluster(t testing.TB, opts ...dbx.Option) *dbx.Cluster {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
dsn := TestDSN()
|
||||||
|
cfg := dbx.Config{
|
||||||
|
Master: dbx.NodeConfig{
|
||||||
|
Name: "test-master",
|
||||||
|
DSN: dsn,
|
||||||
|
Pool: dbx.PoolConfig{MaxConns: 5, MinConns: 1},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
dbx.ApplyOptions(&cfg, opts...)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
cluster, err := dbx.NewCluster(ctx, cfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Skipf("dbxtest: cannot connect to test database: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
cluster.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
return cluster
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestLogger is a Logger that writes to testing.T.
|
||||||
|
type TestLogger struct {
|
||||||
|
T testing.TB
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *TestLogger) Debug(_ context.Context, msg string, fields ...any) {
|
||||||
|
l.T.Helper()
|
||||||
|
l.T.Logf("[DEBUG] %s %v", msg, fields)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *TestLogger) Info(_ context.Context, msg string, fields ...any) {
|
||||||
|
l.T.Helper()
|
||||||
|
l.T.Logf("[INFO] %s %v", msg, fields)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *TestLogger) Warn(_ context.Context, msg string, fields ...any) {
|
||||||
|
l.T.Helper()
|
||||||
|
l.T.Logf("[WARN] %s %v", msg, fields)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *TestLogger) Error(_ context.Context, msg string, fields ...any) {
|
||||||
|
l.T.Helper()
|
||||||
|
l.T.Logf("[ERROR] %s %v", msg, fields)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compile-time check.
|
||||||
|
var _ dbx.Logger = (*TestLogger)(nil)
|
||||||
49
doc.go
Normal file
49
doc.go
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
// Package dbx provides a production-ready PostgreSQL helper library built on top of pgx/v5.
|
||||||
|
//
|
||||||
|
// It manages connection pools, master/replica routing, automatic retries with
|
||||||
|
// exponential backoff, load balancing across replicas, background health checking,
|
||||||
|
// and transaction helpers.
|
||||||
|
//
|
||||||
|
// # Quick Start
|
||||||
|
//
|
||||||
|
// cluster, err := dbx.NewCluster(ctx, dbx.Config{
|
||||||
|
// Master: dbx.NodeConfig{
|
||||||
|
// Name: "master",
|
||||||
|
// DSN: "postgres://user:pass@master:5432/mydb",
|
||||||
|
// Pool: dbx.PoolConfig{MaxConns: 20, MinConns: 5},
|
||||||
|
// },
|
||||||
|
// Replicas: []dbx.NodeConfig{
|
||||||
|
// {Name: "replica-1", DSN: "postgres://...@replica1:5432/mydb"},
|
||||||
|
// },
|
||||||
|
// })
|
||||||
|
// if err != nil {
|
||||||
|
// log.Fatal(err)
|
||||||
|
// }
|
||||||
|
// defer cluster.Close()
|
||||||
|
//
|
||||||
|
// // Write → master
|
||||||
|
// cluster.Exec(ctx, "INSERT INTO users (name) VALUES ($1)", "alice")
|
||||||
|
//
|
||||||
|
// // Read → replica with fallback to master
|
||||||
|
// rows, _ := cluster.ReadQuery(ctx, "SELECT * FROM users WHERE active = $1", true)
|
||||||
|
//
|
||||||
|
// // Transaction → master, panic-safe
|
||||||
|
// cluster.RunTx(ctx, func(ctx context.Context, tx pgx.Tx) error {
|
||||||
|
// _, _ = tx.Exec(ctx, "UPDATE accounts SET balance = balance - $1 WHERE id = $2", 100, fromID)
|
||||||
|
// _, _ = tx.Exec(ctx, "UPDATE accounts SET balance = balance + $1 WHERE id = $2", 100, toID)
|
||||||
|
// return nil
|
||||||
|
// })
|
||||||
|
//
|
||||||
|
// # Routing
|
||||||
|
//
|
||||||
|
// The library uses explicit method-based routing (no SQL parsing):
|
||||||
|
// - Exec, Query, QueryRow, Begin, BeginTx, CopyFrom, SendBatch → master
|
||||||
|
// - ReadQuery, ReadQueryRow → replicas with master fallback
|
||||||
|
// - Master(), Replica() → direct access to specific nodes
|
||||||
|
//
|
||||||
|
// # Retry
|
||||||
|
//
|
||||||
|
// Retryable errors include connection errors (PG class 08), serialization failures (40001),
|
||||||
|
// deadlocks (40P01), and too_many_connections (53300). A custom classifier can be provided
|
||||||
|
// via RetryConfig.RetryableErrors.
|
||||||
|
package dbx
|
||||||
94
errors.go
Normal file
94
errors.go
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
package dbx
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/pgconn"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Sentinel errors.
|
||||||
|
var (
|
||||||
|
ErrNoHealthyNode = errors.New("dbx: no healthy node available")
|
||||||
|
ErrClusterClosed = errors.New("dbx: cluster is closed")
|
||||||
|
ErrRetryExhausted = errors.New("dbx: retry attempts exhausted")
|
||||||
|
)
|
||||||
|
|
||||||
|
// IsRetryable reports whether the error is worth retrying.
|
||||||
|
// Connection errors, serialization failures, deadlocks, and too_many_connections are retryable.
|
||||||
|
func IsRetryable(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if IsConnectionError(err) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
code := PgErrorCode(err)
|
||||||
|
switch code {
|
||||||
|
case "40001", // serialization_failure
|
||||||
|
"40P01", // deadlock_detected
|
||||||
|
"53300": // too_many_connections
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsConnectionError reports whether the error indicates a connection problem (PG class 08).
|
||||||
|
func IsConnectionError(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
code := PgErrorCode(err)
|
||||||
|
if strings.HasPrefix(code, "08") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
// pgx wraps connection errors that may not have a PG code
|
||||||
|
msg := err.Error()
|
||||||
|
for _, s := range []string{
|
||||||
|
"connection refused",
|
||||||
|
"connection reset",
|
||||||
|
"broken pipe",
|
||||||
|
"EOF",
|
||||||
|
"no connection",
|
||||||
|
"conn closed",
|
||||||
|
"timeout",
|
||||||
|
} {
|
||||||
|
if strings.Contains(msg, s) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsConstraintViolation reports whether the error is a constraint violation (PG class 23).
|
||||||
|
func IsConstraintViolation(err error) bool {
|
||||||
|
return strings.HasPrefix(PgErrorCode(err), "23")
|
||||||
|
}
|
||||||
|
|
||||||
|
// PgErrorCode extracts the PostgreSQL error code from err, or returns "" if not a PG error.
|
||||||
|
func PgErrorCode(err error) string {
|
||||||
|
var pgErr *pgconn.PgError
|
||||||
|
if errors.As(err, &pgErr) {
|
||||||
|
return pgErr.Code
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// retryError wraps the last error with ErrRetryExhausted context.
|
||||||
|
type retryError struct {
|
||||||
|
attempts int
|
||||||
|
last error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *retryError) Error() string {
|
||||||
|
return fmt.Sprintf("%s: %d attempts, last error: %v", ErrRetryExhausted, e.attempts, e.last)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *retryError) Unwrap() []error {
|
||||||
|
return []error{ErrRetryExhausted, e.last}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRetryError(attempts int, last error) error {
|
||||||
|
return &retryError{attempts: attempts, last: last}
|
||||||
|
}
|
||||||
92
errors_test.go
Normal file
92
errors_test.go
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
package dbx
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/pgconn"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestIsRetryable(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
err error
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"nil", nil, false},
|
||||||
|
{"generic", errors.New("oops"), false},
|
||||||
|
{"serialization_failure", &pgconn.PgError{Code: "40001"}, true},
|
||||||
|
{"deadlock", &pgconn.PgError{Code: "40P01"}, true},
|
||||||
|
{"too_many_connections", &pgconn.PgError{Code: "53300"}, true},
|
||||||
|
{"connection_exception", &pgconn.PgError{Code: "08006"}, true},
|
||||||
|
{"constraint_violation", &pgconn.PgError{Code: "23505"}, false},
|
||||||
|
{"syntax_error", &pgconn.PgError{Code: "42601"}, false},
|
||||||
|
{"connection_refused", errors.New("connection refused"), true},
|
||||||
|
{"EOF", errors.New("unexpected EOF"), true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := IsRetryable(tt.err); got != tt.want {
|
||||||
|
t.Errorf("IsRetryable(%v) = %v, want %v", tt.err, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsConnectionError(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
err error
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"nil", nil, false},
|
||||||
|
{"pg_class_08", &pgconn.PgError{Code: "08003"}, true},
|
||||||
|
{"pg_class_23", &pgconn.PgError{Code: "23505"}, false},
|
||||||
|
{"conn_refused", errors.New("connection refused"), true},
|
||||||
|
{"broken_pipe", errors.New("write: broken pipe"), true},
|
||||||
|
{"generic", errors.New("something else"), false},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := IsConnectionError(tt.err); got != tt.want {
|
||||||
|
t.Errorf("IsConnectionError(%v) = %v, want %v", tt.err, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsConstraintViolation(t *testing.T) {
|
||||||
|
if IsConstraintViolation(nil) {
|
||||||
|
t.Error("expected false for nil")
|
||||||
|
}
|
||||||
|
if !IsConstraintViolation(&pgconn.PgError{Code: "23505"}) {
|
||||||
|
t.Error("expected true for unique_violation")
|
||||||
|
}
|
||||||
|
if IsConstraintViolation(&pgconn.PgError{Code: "42601"}) {
|
||||||
|
t.Error("expected false for syntax_error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPgErrorCode(t *testing.T) {
|
||||||
|
if code := PgErrorCode(errors.New("not pg")); code != "" {
|
||||||
|
t.Errorf("expected empty, got %q", code)
|
||||||
|
}
|
||||||
|
if code := PgErrorCode(&pgconn.PgError{Code: "42P01"}); code != "42P01" {
|
||||||
|
t.Errorf("expected 42P01, got %q", code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRetryError(t *testing.T) {
|
||||||
|
inner := errors.New("conn lost")
|
||||||
|
err := newRetryError(3, inner)
|
||||||
|
|
||||||
|
if !errors.Is(err, ErrRetryExhausted) {
|
||||||
|
t.Error("should unwrap to ErrRetryExhausted")
|
||||||
|
}
|
||||||
|
if !errors.Is(err, inner) {
|
||||||
|
t.Error("should unwrap to inner error")
|
||||||
|
}
|
||||||
|
if err.Error() == "" {
|
||||||
|
t.Error("error string should not be empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
13
go.mod
Normal file
13
go.mod
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
module git.codelab.vc/pkg/dbx
|
||||||
|
|
||||||
|
go 1.25.7
|
||||||
|
|
||||||
|
require github.com/jackc/pgx/v5 v5.9.1
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||||
|
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||||
|
github.com/jackc/puddle/v2 v2.2.2 // indirect
|
||||||
|
golang.org/x/sync v0.17.0 // indirect
|
||||||
|
golang.org/x/text v0.29.0 // indirect
|
||||||
|
)
|
||||||
26
go.sum
Normal file
26
go.sum
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||||
|
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||||
|
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||||
|
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||||
|
github.com/jackc/pgx/v5 v5.9.1 h1:uwrxJXBnx76nyISkhr33kQLlUqjv7et7b9FjCen/tdc=
|
||||||
|
github.com/jackc/pgx/v5 v5.9.1/go.mod h1:mal1tBGAFfLHvZzaYh77YS/eC6IX9OWbRV1QIIM0Jn4=
|
||||||
|
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
||||||
|
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
|
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||||
|
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
|
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||||
|
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||||
|
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
|
||||||
|
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||||
|
golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk=
|
||||||
|
golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4=
|
||||||
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
90
health.go
Normal file
90
health.go
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
package dbx
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// healthChecker periodically pings nodes and updates their health state.
|
||||||
|
type healthChecker struct {
|
||||||
|
nodes []*Node
|
||||||
|
cfg HealthCheckConfig
|
||||||
|
logger Logger
|
||||||
|
metrics *MetricsHook
|
||||||
|
stop chan struct{}
|
||||||
|
done chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHealthChecker(nodes []*Node, cfg HealthCheckConfig, logger Logger, metrics *MetricsHook) *healthChecker {
|
||||||
|
return &healthChecker{
|
||||||
|
nodes: nodes,
|
||||||
|
cfg: cfg,
|
||||||
|
logger: logger,
|
||||||
|
metrics: metrics,
|
||||||
|
stop: make(chan struct{}),
|
||||||
|
done: make(chan struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *healthChecker) start() {
|
||||||
|
go h.loop()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *healthChecker) loop() {
|
||||||
|
defer close(h.done)
|
||||||
|
|
||||||
|
ticker := time.NewTicker(h.cfg.Interval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-h.stop:
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
h.checkAll()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *healthChecker) checkAll() {
|
||||||
|
for _, node := range h.nodes {
|
||||||
|
h.checkNode(node)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *healthChecker) checkNode(n *Node) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), h.cfg.Timeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
err := n.pool.Ping(ctx)
|
||||||
|
wasHealthy := n.healthy.Load()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
n.healthy.Store(false)
|
||||||
|
if wasHealthy {
|
||||||
|
h.logger.Error(ctx, "dbx: node is down",
|
||||||
|
"node", n.name,
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
|
if h.metrics != nil && h.metrics.OnNodeDown != nil {
|
||||||
|
h.metrics.OnNodeDown(ctx, n.name, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
n.healthy.Store(true)
|
||||||
|
if !wasHealthy {
|
||||||
|
h.logger.Info(ctx, "dbx: node is up",
|
||||||
|
"node", n.name,
|
||||||
|
)
|
||||||
|
if h.metrics != nil && h.metrics.OnNodeUp != nil {
|
||||||
|
h.metrics.OnNodeUp(ctx, n.name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *healthChecker) shutdown() {
|
||||||
|
close(h.stop)
|
||||||
|
<-h.done
|
||||||
|
}
|
||||||
56
health_test.go
Normal file
56
health_test.go
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
package dbx
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestHealthChecker_StartStop(t *testing.T) {
|
||||||
|
nodes := makeTestNodes("a", "b")
|
||||||
|
hc := newHealthChecker(nodes, HealthCheckConfig{
|
||||||
|
Interval: 100 * time.Millisecond,
|
||||||
|
Timeout: 50 * time.Millisecond,
|
||||||
|
}, nopLogger{}, nil)
|
||||||
|
|
||||||
|
hc.start()
|
||||||
|
// Just verify it can be stopped without deadlock
|
||||||
|
hc.shutdown()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHealthChecker_MetricsCallbacks(t *testing.T) {
|
||||||
|
var downCalled, upCalled atomic.Int32
|
||||||
|
|
||||||
|
metrics := &MetricsHook{
|
||||||
|
OnNodeDown: func(_ context.Context, node string, err error) {
|
||||||
|
downCalled.Add(1)
|
||||||
|
},
|
||||||
|
OnNodeUp: func(_ context.Context, node string) {
|
||||||
|
upCalled.Add(1)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
node := &Node{name: "test"}
|
||||||
|
node.healthy.Store(true)
|
||||||
|
|
||||||
|
hc := newHealthChecker([]*Node{node}, HealthCheckConfig{
|
||||||
|
Interval: time.Hour, // won't tick in test
|
||||||
|
Timeout: 50 * time.Millisecond,
|
||||||
|
}, nopLogger{}, metrics)
|
||||||
|
|
||||||
|
// Simulate a down check: node has no pool, so ping will fail
|
||||||
|
// We can't easily test with a real pool here, but checkNode
|
||||||
|
// will panic on nil pool. In integration tests, we use real pools.
|
||||||
|
_ = hc // just verify construction works
|
||||||
|
|
||||||
|
// Verify state transitions
|
||||||
|
node.healthy.Store(false)
|
||||||
|
if node.IsHealthy() {
|
||||||
|
t.Error("expected unhealthy")
|
||||||
|
}
|
||||||
|
node.healthy.Store(true)
|
||||||
|
if !node.IsHealthy() {
|
||||||
|
t.Error("expected healthy")
|
||||||
|
}
|
||||||
|
}
|
||||||
110
node.go
Normal file
110
node.go
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
package dbx
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5"
|
||||||
|
"github.com/jackc/pgx/v5/pgconn"
|
||||||
|
"github.com/jackc/pgx/v5/pgxpool"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Node wraps a pgxpool.Pool with health state and a human-readable name.
|
||||||
|
type Node struct {
|
||||||
|
name string
|
||||||
|
pool *pgxpool.Pool
|
||||||
|
healthy atomic.Bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// newNode creates a Node from an existing pool.
|
||||||
|
func newNode(name string, pool *pgxpool.Pool) *Node {
|
||||||
|
n := &Node{
|
||||||
|
name: name,
|
||||||
|
pool: pool,
|
||||||
|
}
|
||||||
|
n.healthy.Store(true)
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
// connectNode parses the NodeConfig, creates a pgxpool.Pool, and returns a Node.
|
||||||
|
func connectNode(ctx context.Context, cfg NodeConfig) (*Node, error) {
|
||||||
|
poolCfg, err := pgxpool.ParseConfig(cfg.DSN)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
applyPoolConfig(poolCfg, cfg.Pool)
|
||||||
|
|
||||||
|
pool, err := pgxpool.NewWithConfig(ctx, poolCfg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return newNode(cfg.Name, pool), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyPoolConfig(dst *pgxpool.Config, src PoolConfig) {
|
||||||
|
if src.MaxConns > 0 {
|
||||||
|
dst.MaxConns = src.MaxConns
|
||||||
|
}
|
||||||
|
if src.MinConns > 0 {
|
||||||
|
dst.MinConns = src.MinConns
|
||||||
|
}
|
||||||
|
if src.MaxConnLifetime > 0 {
|
||||||
|
dst.MaxConnLifetime = src.MaxConnLifetime
|
||||||
|
}
|
||||||
|
if src.MaxConnIdleTime > 0 {
|
||||||
|
dst.MaxConnIdleTime = src.MaxConnIdleTime
|
||||||
|
}
|
||||||
|
if src.HealthCheckPeriod > 0 {
|
||||||
|
dst.HealthCheckPeriod = src.HealthCheckPeriod
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Name returns the node's human-readable name.
|
||||||
|
func (n *Node) Name() string { return n.name }
|
||||||
|
|
||||||
|
// IsHealthy reports whether the node is considered healthy.
|
||||||
|
func (n *Node) IsHealthy() bool { return n.healthy.Load() }
|
||||||
|
|
||||||
|
// Pool returns the underlying pgxpool.Pool.
|
||||||
|
func (n *Node) Pool() *pgxpool.Pool { return n.pool }
|
||||||
|
|
||||||
|
// --- DB interface implementation ---
|
||||||
|
|
||||||
|
func (n *Node) Exec(ctx context.Context, sql string, args ...any) (pgconn.CommandTag, error) {
|
||||||
|
return n.pool.Exec(ctx, sql, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Node) Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) {
|
||||||
|
return n.pool.Query(ctx, sql, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Node) QueryRow(ctx context.Context, sql string, args ...any) pgx.Row {
|
||||||
|
return n.pool.QueryRow(ctx, sql, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Node) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults {
|
||||||
|
return n.pool.SendBatch(ctx, b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Node) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) {
|
||||||
|
return n.pool.CopyFrom(ctx, tableName, columnNames, rowSrc)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Node) Begin(ctx context.Context) (pgx.Tx, error) {
|
||||||
|
return n.pool.Begin(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Node) BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, error) {
|
||||||
|
return n.pool.BeginTx(ctx, txOptions)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Node) Ping(ctx context.Context) error {
|
||||||
|
return n.pool.Ping(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Node) Close() {
|
||||||
|
n.pool.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compile-time check that *Node implements DB.
|
||||||
|
var _ DB = (*Node)(nil)
|
||||||
39
options.go
Normal file
39
options.go
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
package dbx
|
||||||
|
|
||||||
|
// Option is a functional option for NewCluster.
|
||||||
|
type Option func(*Config)
|
||||||
|
|
||||||
|
// WithLogger sets the logger for the cluster.
|
||||||
|
func WithLogger(l Logger) Option {
|
||||||
|
return func(c *Config) {
|
||||||
|
c.Logger = l
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithMetrics sets the metrics hook for the cluster.
|
||||||
|
func WithMetrics(m *MetricsHook) Option {
|
||||||
|
return func(c *Config) {
|
||||||
|
c.Metrics = m
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithRetry overrides the retry configuration.
|
||||||
|
func WithRetry(r RetryConfig) Option {
|
||||||
|
return func(c *Config) {
|
||||||
|
c.Retry = r
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithHealthCheck overrides the health check configuration.
|
||||||
|
func WithHealthCheck(h HealthCheckConfig) Option {
|
||||||
|
return func(c *Config) {
|
||||||
|
c.HealthCheck = h
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ApplyOptions applies functional options to a Config.
|
||||||
|
func ApplyOptions(cfg *Config, opts ...Option) {
|
||||||
|
for _, o := range opts {
|
||||||
|
o(cfg)
|
||||||
|
}
|
||||||
|
}
|
||||||
35
options_test.go
Normal file
35
options_test.go
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
package dbx
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestApplyOptions(t *testing.T) {
|
||||||
|
cfg := Config{}
|
||||||
|
|
||||||
|
logger := nopLogger{}
|
||||||
|
metrics := &MetricsHook{}
|
||||||
|
retry := RetryConfig{MaxAttempts: 7}
|
||||||
|
hc := HealthCheckConfig{Interval: 10 * time.Second}
|
||||||
|
|
||||||
|
ApplyOptions(&cfg,
|
||||||
|
WithLogger(logger),
|
||||||
|
WithMetrics(metrics),
|
||||||
|
WithRetry(retry),
|
||||||
|
WithHealthCheck(hc),
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg.Logger == nil {
|
||||||
|
t.Error("Logger should be set")
|
||||||
|
}
|
||||||
|
if cfg.Metrics == nil {
|
||||||
|
t.Error("Metrics should be set")
|
||||||
|
}
|
||||||
|
if cfg.Retry.MaxAttempts != 7 {
|
||||||
|
t.Errorf("Retry.MaxAttempts = %d, want 7", cfg.Retry.MaxAttempts)
|
||||||
|
}
|
||||||
|
if cfg.HealthCheck.Interval != 10*time.Second {
|
||||||
|
t.Errorf("HealthCheck.Interval = %v, want 10s", cfg.HealthCheck.Interval)
|
||||||
|
}
|
||||||
|
}
|
||||||
98
retry.go
Normal file
98
retry.go
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
package dbx
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"math"
|
||||||
|
"math/rand/v2"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// retrier executes operations with retry and node fallback.
|
||||||
|
type retrier struct {
|
||||||
|
cfg RetryConfig
|
||||||
|
logger Logger
|
||||||
|
metrics *MetricsHook
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRetrier(cfg RetryConfig, logger Logger, metrics *MetricsHook) *retrier {
|
||||||
|
return &retrier{cfg: cfg, logger: logger, metrics: metrics}
|
||||||
|
}
|
||||||
|
|
||||||
|
// isRetryable checks the custom classifier first, then falls back to the default.
|
||||||
|
func (r *retrier) isRetryable(err error) bool {
|
||||||
|
if r.cfg.RetryableErrors != nil {
|
||||||
|
return r.cfg.RetryableErrors(err)
|
||||||
|
}
|
||||||
|
return IsRetryable(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// do executes fn on the given nodes in order, retrying on retryable errors.
|
||||||
|
// For writes, pass a single-element slice with the master.
|
||||||
|
// For reads, pass [replicas..., master] for fallback.
|
||||||
|
func (r *retrier) do(ctx context.Context, nodes []*Node, fn func(ctx context.Context, n *Node) error) error {
|
||||||
|
var lastErr error
|
||||||
|
|
||||||
|
for attempt := 0; attempt < r.cfg.MaxAttempts; attempt++ {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
if lastErr != nil {
|
||||||
|
return lastErr
|
||||||
|
}
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, node := range nodes {
|
||||||
|
if !node.IsHealthy() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
err := fn(ctx, node)
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
lastErr = err
|
||||||
|
|
||||||
|
if !r.isRetryable(err) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.metrics != nil && r.metrics.OnRetry != nil {
|
||||||
|
r.metrics.OnRetry(ctx, node.name, attempt+1, err)
|
||||||
|
}
|
||||||
|
r.logger.Warn(ctx, "dbx: retryable error",
|
||||||
|
"node", node.name,
|
||||||
|
"attempt", attempt+1,
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if attempt < r.cfg.MaxAttempts-1 {
|
||||||
|
delay := r.backoff(attempt)
|
||||||
|
t := time.NewTimer(delay)
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
t.Stop()
|
||||||
|
if lastErr != nil {
|
||||||
|
return lastErr
|
||||||
|
}
|
||||||
|
return ctx.Err()
|
||||||
|
case <-t.C:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if lastErr == nil {
|
||||||
|
return ErrNoHealthyNode
|
||||||
|
}
|
||||||
|
return newRetryError(r.cfg.MaxAttempts, lastErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// backoff returns the delay for the given attempt with jitter.
|
||||||
|
func (r *retrier) backoff(attempt int) time.Duration {
|
||||||
|
delay := float64(r.cfg.BaseDelay) * math.Pow(2, float64(attempt))
|
||||||
|
if delay > float64(r.cfg.MaxDelay) {
|
||||||
|
delay = float64(r.cfg.MaxDelay)
|
||||||
|
}
|
||||||
|
// add jitter: 75%-125% of computed delay
|
||||||
|
jitter := 0.75 + rand.Float64()*0.5
|
||||||
|
return time.Duration(delay * jitter)
|
||||||
|
}
|
||||||
168
retry_test.go
Normal file
168
retry_test.go
Normal file
@@ -0,0 +1,168 @@
|
|||||||
|
package dbx
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRetrier_Success(t *testing.T) {
|
||||||
|
r := newRetrier(RetryConfig{MaxAttempts: 3, BaseDelay: time.Millisecond, MaxDelay: 10 * time.Millisecond}, nopLogger{}, nil)
|
||||||
|
nodes := makeTestNodes("n1")
|
||||||
|
|
||||||
|
calls := 0
|
||||||
|
err := r.do(context.Background(), nodes, func(_ context.Context, n *Node) error {
|
||||||
|
calls++
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if calls != 1 {
|
||||||
|
t.Errorf("calls = %d, want 1", calls)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRetrier_RetriesOnRetryableError(t *testing.T) {
|
||||||
|
r := newRetrier(RetryConfig{MaxAttempts: 3, BaseDelay: time.Millisecond, MaxDelay: 10 * time.Millisecond}, nopLogger{}, nil)
|
||||||
|
nodes := makeTestNodes("n1")
|
||||||
|
|
||||||
|
calls := 0
|
||||||
|
retryableErr := errors.New("connection refused")
|
||||||
|
err := r.do(context.Background(), nodes, func(_ context.Context, n *Node) error {
|
||||||
|
calls++
|
||||||
|
if calls < 3 {
|
||||||
|
return retryableErr
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if calls != 3 {
|
||||||
|
t.Errorf("calls = %d, want 3", calls)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRetrier_NonRetryableError(t *testing.T) {
|
||||||
|
r := newRetrier(RetryConfig{MaxAttempts: 3, BaseDelay: time.Millisecond, MaxDelay: 10 * time.Millisecond}, nopLogger{}, nil)
|
||||||
|
nodes := makeTestNodes("n1")
|
||||||
|
|
||||||
|
syntaxErr := errors.New("syntax problem")
|
||||||
|
calls := 0
|
||||||
|
err := r.do(context.Background(), nodes, func(_ context.Context, n *Node) error {
|
||||||
|
calls++
|
||||||
|
return syntaxErr
|
||||||
|
})
|
||||||
|
if !errors.Is(err, syntaxErr) {
|
||||||
|
t.Errorf("expected syntax error, got %v", err)
|
||||||
|
}
|
||||||
|
if calls != 1 {
|
||||||
|
t.Errorf("calls = %d, want 1 (should not retry non-retryable)", calls)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRetrier_Exhausted(t *testing.T) {
|
||||||
|
r := newRetrier(RetryConfig{MaxAttempts: 2, BaseDelay: time.Millisecond, MaxDelay: 10 * time.Millisecond}, nopLogger{}, nil)
|
||||||
|
nodes := makeTestNodes("n1")
|
||||||
|
|
||||||
|
err := r.do(context.Background(), nodes, func(_ context.Context, n *Node) error {
|
||||||
|
return errors.New("connection refused")
|
||||||
|
})
|
||||||
|
if !errors.Is(err, ErrRetryExhausted) {
|
||||||
|
t.Errorf("expected ErrRetryExhausted, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRetrier_FallbackToNextNode(t *testing.T) {
|
||||||
|
r := newRetrier(RetryConfig{MaxAttempts: 2, BaseDelay: time.Millisecond, MaxDelay: 10 * time.Millisecond}, nopLogger{}, nil)
|
||||||
|
nodes := makeTestNodes("replica-1", "master")
|
||||||
|
|
||||||
|
visited := []string{}
|
||||||
|
err := r.do(context.Background(), nodes, func(_ context.Context, n *Node) error {
|
||||||
|
visited = append(visited, n.name)
|
||||||
|
if n.name == "replica-1" {
|
||||||
|
return errors.New("connection refused")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if len(visited) < 2 || visited[1] != "master" {
|
||||||
|
t.Errorf("expected fallback to master, visited: %v", visited)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRetrier_ContextCanceled(t *testing.T) {
|
||||||
|
r := newRetrier(RetryConfig{MaxAttempts: 5, BaseDelay: time.Millisecond, MaxDelay: 10 * time.Millisecond}, nopLogger{}, nil)
|
||||||
|
nodes := makeTestNodes("n1")
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
err := r.do(ctx, nodes, func(_ context.Context, n *Node) error {
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != context.Canceled {
|
||||||
|
t.Errorf("expected context.Canceled, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRetrier_NoHealthyNodes(t *testing.T) {
|
||||||
|
r := newRetrier(RetryConfig{MaxAttempts: 2, BaseDelay: time.Millisecond, MaxDelay: 10 * time.Millisecond}, nopLogger{}, nil)
|
||||||
|
nodes := makeTestNodes("n1")
|
||||||
|
nodes[0].healthy.Store(false)
|
||||||
|
|
||||||
|
err := r.do(context.Background(), nodes, func(_ context.Context, n *Node) error {
|
||||||
|
t.Fatal("should not be called")
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if !errors.Is(err, ErrNoHealthyNode) {
|
||||||
|
t.Errorf("expected ErrNoHealthyNode, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRetrier_CustomClassifier(t *testing.T) {
|
||||||
|
custom := func(err error) bool {
|
||||||
|
return err.Error() == "custom-retry"
|
||||||
|
}
|
||||||
|
r := newRetrier(RetryConfig{MaxAttempts: 3, BaseDelay: time.Millisecond, MaxDelay: 10 * time.Millisecond, RetryableErrors: custom}, nopLogger{}, nil)
|
||||||
|
nodes := makeTestNodes("n1")
|
||||||
|
|
||||||
|
calls := 0
|
||||||
|
err := r.do(context.Background(), nodes, func(_ context.Context, n *Node) error {
|
||||||
|
calls++
|
||||||
|
if calls < 2 {
|
||||||
|
return errors.New("custom-retry")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if calls != 2 {
|
||||||
|
t.Errorf("calls = %d, want 2", calls)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBackoff(t *testing.T) {
|
||||||
|
r := newRetrier(RetryConfig{BaseDelay: 100 * time.Millisecond, MaxDelay: time.Second}, nopLogger{}, nil)
|
||||||
|
|
||||||
|
d0 := r.backoff(0)
|
||||||
|
if d0 < 50*time.Millisecond || d0 > 150*time.Millisecond {
|
||||||
|
t.Errorf("backoff(0) = %v, expected ~100ms", d0)
|
||||||
|
}
|
||||||
|
|
||||||
|
d3 := r.backoff(3)
|
||||||
|
if d3 < 600*time.Millisecond || d3 > 1100*time.Millisecond {
|
||||||
|
t.Errorf("backoff(3) = %v, expected ~800ms", d3)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should cap at MaxDelay
|
||||||
|
d10 := r.backoff(10)
|
||||||
|
if d10 > 1250*time.Millisecond {
|
||||||
|
t.Errorf("backoff(10) = %v, should be capped near 1s", d10)
|
||||||
|
}
|
||||||
|
}
|
||||||
61
tx.go
Normal file
61
tx.go
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
package dbx
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TxFunc is the callback signature for RunTx.
|
||||||
|
type TxFunc func(ctx context.Context, tx pgx.Tx) error
|
||||||
|
|
||||||
|
// RunTx executes fn inside a transaction on the master node with default options.
|
||||||
|
// The transaction is committed if fn returns nil, rolled back otherwise.
|
||||||
|
// Panics inside fn are caught, the transaction is rolled back, and the panic is re-raised.
|
||||||
|
func (c *Cluster) RunTx(ctx context.Context, fn TxFunc) error {
|
||||||
|
return c.RunTxOptions(ctx, pgx.TxOptions{}, fn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RunTxOptions executes fn inside a transaction with the given options.
|
||||||
|
func (c *Cluster) RunTxOptions(ctx context.Context, opts pgx.TxOptions, fn TxFunc) error {
|
||||||
|
tx, err := c.master.BeginTx(ctx, opts)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("dbx: begin tx: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return runTx(ctx, tx, fn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func runTx(ctx context.Context, tx pgx.Tx, fn TxFunc) (retErr error) {
|
||||||
|
defer func() {
|
||||||
|
if p := recover(); p != nil {
|
||||||
|
_ = tx.Rollback(ctx)
|
||||||
|
panic(p)
|
||||||
|
}
|
||||||
|
if retErr != nil {
|
||||||
|
_ = tx.Rollback(ctx)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err := fn(ctx, tx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return tx.Commit(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// querierKey is the context key for Querier injection.
|
||||||
|
type querierKey struct{}
|
||||||
|
|
||||||
|
// InjectQuerier returns a new context carrying q.
|
||||||
|
func InjectQuerier(ctx context.Context, q Querier) context.Context {
|
||||||
|
return context.WithValue(ctx, querierKey{}, q)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtractQuerier returns the Querier from ctx, or fallback if none is present.
|
||||||
|
func ExtractQuerier(ctx context.Context, fallback Querier) Querier {
|
||||||
|
if q, ok := ctx.Value(querierKey{}).(Querier); ok {
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
25
tx_test.go
Normal file
25
tx_test.go
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
package dbx
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestInjectExtractQuerier(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
fallback := &Node{name: "fallback"}
|
||||||
|
|
||||||
|
// No querier in context → returns fallback
|
||||||
|
got := ExtractQuerier(ctx, fallback)
|
||||||
|
if got != fallback {
|
||||||
|
t.Error("expected fallback when no querier in context")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Inject querier → extract it
|
||||||
|
injected := &Node{name: "injected"}
|
||||||
|
ctx = InjectQuerier(ctx, injected)
|
||||||
|
got = ExtractQuerier(ctx, fallback)
|
||||||
|
if got != injected {
|
||||||
|
t.Error("expected injected querier")
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user