Compare commits

..

5 Commits

2113d9cc75 Add publish workflow for Gitea Go Package Registry
All checks were successful
CI / test (push) Successful in 32s
Publish / publish (push) Successful in 29s
Publishes the module to Gitea Package Registry on tag push (v*).
Runs vet and tests before publishing to prevent broken releases.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-23 14:16:02 +03:00
5aa9c783c3 Fix NewTestCluster to skip tests when DB is unreachable
All checks were successful
CI / test (push) Successful in 30s
pgxpool.NewWithConfig does not connect eagerly, so NewCluster succeeds
even without a running database. Added a Ping check after cluster
creation to reliably skip tests when the database cannot be reached.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-23 00:22:02 +03:00
2c9af28548 Add production features: slog adapter, scan helpers, slow query logging, pool stats, tracer passthrough, test tx isolation
Some checks failed
CI / test (push) Failing after 13s
- slog.go: SlogLogger adapts *slog.Logger to dbx.Logger interface
- scan.go: Collect[T] and CollectOne[T] generic helpers using pgx.RowToStructByName
- cluster.go: slow query logging via Config.SlowQueryThreshold (Warn level in queryEnd)
- stats.go: PoolStats with Cluster.Stats() aggregating pool stats across all nodes
- config.go/node.go: NodeConfig.Tracer passthrough for pgx.QueryTracer (OpenTelemetry)
- options.go: WithSlowQueryThreshold and WithTracer functional options
- dbxtest/tx.go: RunInTx runs callback in always-rolled-back transaction for test isolation

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-23 00:19:26 +03:00
7d25e1b73e Add CI workflow, README, CLAUDE.md, AGENTS.md, and .cursorrules
All checks were successful
CI / test (push) Successful in 51s
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-23 00:01:27 +03:00
62df3a2eb3 Add dbx library: PostgreSQL cluster with master/replica routing, retry, health checking
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-23 00:01:15 +03:00
35 changed files with 2526 additions and 0 deletions

.cursorrules (new file)

@@ -0,0 +1,58 @@
You are working on `git.codelab.vc/pkg/dbx`, a Go PostgreSQL cluster library built on pgx/v5.
## Architecture
- Cluster manages master + replicas with method-based routing (no SQL parsing)
- Write ops (Exec, Query, QueryRow, Begin, BeginTx, CopyFrom, SendBatch) → master
- Read ops (ReadQuery, ReadQueryRow) → replicas with master fallback
- Retry with exponential backoff + jitter, iterates nodes then backs off
- Round-robin balancer skips unhealthy nodes
- Background health checker pings all nodes on interval
- RunTx — panic-safe transaction wrapper (recover → rollback → re-panic)
- InjectQuerier/ExtractQuerier — context-based Querier for service layers
## Package structure
- `dbx` (root) — Cluster, Node, Balancer, retry, health, errors, tx, config, options
- `dbx.go` — interfaces: Querier, DB, Logger, MetricsHook
- `cluster.go` — Cluster routing and query execution
- `node.go` — Node wrapping pgxpool.Pool with health state
- `balancer.go` — Balancer interface + RoundRobinBalancer
- `retry.go` — retrier with backoff and node fallback
- `health.go` — background health checker goroutine
- `tx.go` — RunTx, RunTxOptions, InjectQuerier, ExtractQuerier
- `errors.go` — IsRetryable, IsConnectionError, IsConstraintViolation, PgErrorCode
- `config.go` — Config, NodeConfig, PoolConfig, RetryConfig, HealthCheckConfig
- `options.go` — functional options (WithLogger, WithMetrics, WithRetry, WithHealthCheck, WithSlowQueryThreshold, WithTracer)
- `slog.go` — SlogLogger adapting *slog.Logger to dbx.Logger
- `scan.go` — Collect[T], CollectOne[T] generic row scan helpers
- `stats.go` — PoolStats aggregate pool statistics via Cluster.Stats()
- `dbxtest/` — test helpers: NewTestCluster, TestLogger, RunInTx
## Code conventions
- Struct-based Config with defaults() method for zero-value defaults
- Functional options (Option func(*Config)) used via ApplyOptions
- stdlib only testing — no testify, no gomock
- Thread safety with atomic.Bool (Node.healthy, Cluster.closed)
- dbxtest.NewTestCluster skips on unreachable DB, auto-closes via t.Cleanup
- Sentinel errors: ErrNoHealthyNode, ErrClusterClosed, ErrRetryExhausted
- retryError multi-unwrap for errors.Is compatibility
## When writing new code
- New node type → add to Cluster struct, Config, connect in NewCluster, add to `all` for health checking
- New balancer → implement Balancer interface, check IsHealthy(), return nil if no suitable node
- New retry logic → provide RetryConfig.RetryableErrors or extend IsRetryable()
- New metrics hook → add field to MetricsHook, nil-check before calling
- Close() is required — leaking a Cluster leaks goroutines and connections
- No SQL parsing — routing is method-based, Exec with SELECT still goes to master
## Commands
```bash
go build ./... # compile
go test ./... # test
go test -race ./... # test with race detector
go vet ./... # static analysis
```

.gitea/workflows/ci.yml (new file)

@@ -0,0 +1,23 @@
name: CI
on:
  push:
    branches: [main]
  pull_request:
    branches: [main]
jobs:
  test:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-go@v5
        with:
          go-version: "1.24"
      - name: Vet
        run: go vet ./...
      - name: Test
        run: go test -race -count=1 ./...

Publish workflow (new file)

@@ -0,0 +1,39 @@
name: Publish
on:
  push:
    tags: ["v*"]
jobs:
  publish:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-go@v5
        with:
          go-version-file: go.mod
      - name: Vet
        run: go vet ./...
      - name: Test
        run: go test -race -count=1 ./...
      - name: Publish to Gitea Package Registry
        run: |
          VERSION=${GITHUB_REF#refs/tags/}
          MODULE=$(go list -m)
          # Create module zip with required prefix: module@version/
          git archive --format=zip --prefix="${MODULE}@${VERSION}/" HEAD -o module.zip
          # Gitea Go Package Registry API
          curl -s -f \
            -X PUT \
            -H "Authorization: token ${{ secrets.PUBLISH_TOKEN }}" \
            -H "Content-Type: application/zip" \
            --data-binary @module.zip \
            "${{ github.server_url }}/api/packages/pkg/go/upload?module=${MODULE}&version=${VERSION}"
        env:
          PUBLISH_TOKEN: ${{ secrets.PUBLISH_TOKEN }}

AGENTS.md (new file)

@@ -0,0 +1,113 @@
# AGENTS.md — dbx
Universal guide for AI coding agents working with this codebase.
## Overview
`git.codelab.vc/pkg/dbx` is a Go PostgreSQL cluster library built on **pgx/v5**. It provides master/replica routing, automatic retries, load balancing, background health checking, panic-safe transactions, and context-based Querier injection.
## Package map
```
dbx/ Root — Cluster, Node, Balancer, retry, health, errors, tx, config
├── dbx.go Interfaces: Querier, DB, Logger, MetricsHook
├── cluster.go Cluster — routing, write/read operations, slow query logging
├── node.go Node — pgxpool.Pool wrapper with health state, tracer passthrough
├── balancer.go Balancer interface + RoundRobinBalancer
├── retry.go retrier — exponential backoff with jitter and node fallback
├── health.go healthChecker — background goroutine pinging nodes
├── tx.go RunTx, RunTxOptions, InjectQuerier, ExtractQuerier
├── errors.go Error classification (IsRetryable, IsConnectionError, etc.)
├── config.go Config, NodeConfig, PoolConfig, RetryConfig, HealthCheckConfig
├── options.go Functional options (WithLogger, WithMetrics, WithRetry, WithTracer, etc.)
├── slog.go SlogLogger — adapts *slog.Logger to dbx.Logger
├── scan.go Collect[T], CollectOne[T] — generic row scan helpers
├── stats.go PoolStats — aggregate pool statistics via Cluster.Stats()
└── dbxtest/
├── dbxtest.go Test helpers: NewTestCluster, TestLogger
└── tx.go RunInTx — test transaction isolation (always rolled back)
```
## Routing architecture
```
                ┌──────────────┐
                │   Cluster    │
                └──────┬───────┘
        ┌──────────────┴──────────────┐
        │                             │
    Write ops                     Read ops
Exec, Query, QueryRow      ReadQuery, ReadQueryRow
Begin, BeginTx, RunTx
CopyFrom, SendBatch
        │                             │
        ▼                             ▼
  ┌──────────┐           ┌────────────────────────┐
  │  Master  │           │  Balancer → Replicas   │
  └──────────┘           │  fallback → Master     │
                         └────────────────────────┘

Retry loop (retrier.do):
  For each attempt (up to MaxAttempts):
    For each node in [target nodes]:
      if healthy → execute → on retryable error → continue
    Backoff (exponential + jitter)
```
## Common tasks
### Add a new node type (e.g., analytics replica)
1. Add a field to `Cluster` struct (e.g., `analytics []*Node`)
2. Add corresponding config to `Config` struct
3. Connect nodes in `NewCluster`, add to `all` slice for health checking
4. Add routing methods (e.g., `AnalyticsQuery`)
### Customize retry logic
1. Provide `RetryConfig.RetryableErrors` — custom `func(error) bool` classifier
2. Or modify `IsRetryable()` in `errors.go` to add new PG error codes
3. Adjust `MaxAttempts`, `BaseDelay`, `MaxDelay` in `RetryConfig`
### Add a metrics hook
1. Add a new callback field to `MetricsHook` struct in `dbx.go`
2. Call it at the appropriate point (nil-check the hook and the field)
3. See existing hooks in `cluster.go` (queryStart/queryEnd) and `health.go` (OnNodeDown/OnNodeUp)
### Add a new balancer strategy
1. Implement the `Balancer` interface: `Next(nodes []*Node) *Node`
2. Must return `nil` if no suitable node is available
3. Must check `node.IsHealthy()` to skip down nodes
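The three rules above can be sketched with a random-pick balancer. The `Node` and `Balancer` types here are simplified stand-ins so the example is self-contained — the real types live in the `dbx` package.

```go
package main

import (
	"fmt"
	"math/rand"
)

// Stand-ins for dbx.Node / dbx.Balancer (simplified for illustration).
type Node struct {
	name    string
	healthy bool
}

func (n *Node) IsHealthy() bool { return n.healthy }

type Balancer interface {
	Next(nodes []*Node) *Node
}

// RandomBalancer picks a random healthy node.
type RandomBalancer struct{}

func (RandomBalancer) Next(nodes []*Node) *Node {
	healthy := make([]*Node, 0, len(nodes))
	for _, n := range nodes {
		if n.IsHealthy() { // rule 3: skip down nodes
			healthy = append(healthy, n)
		}
	}
	if len(healthy) == 0 {
		return nil // rule 2: nil when no suitable node
	}
	return healthy[rand.Intn(len(healthy))]
}

func main() {
	b := RandomBalancer{}
	nodes := []*Node{{"a", true}, {"b", false}, {"c", true}}
	fmt.Println(b.Next(nodes).name) // "a" or "c" — never the unhealthy "b"
}
```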
## Gotchas
- **Close() is required**: `Cluster.Close()` stops the health checker goroutine and closes all pools. Leaking a Cluster leaks goroutines and connections
- **RunTx panic safety**: `runTx` uses `defer` with `recover()` — it rolls back on panic, then re-panics. Do not catch panics outside `RunTx` expecting the tx to be committed
- **Context-based Querier injection**: `ExtractQuerier` returns the fallback if no Querier is in context. Always pass the cluster/pool as fallback so code works both inside and outside transactions
- **Health checker goroutine**: Starts immediately in `NewCluster`. Uses `time.NewTicker` — the first check happens after one interval, not immediately. Nodes start as healthy (`healthy.Store(true)` in `newNode`)
- **readNodes ordering**: `readNodes()` returns `[replicas..., master]` — the retrier tries replicas first, master is the last fallback
- **errRow for closed cluster**: When cluster is closed, `QueryRow`/`ReadQueryRow` return `errRow{err: ErrClusterClosed}` — the error surfaces on `Scan()`
- **No SQL parsing**: Routing is purely method-based. If you call `Exec` with a SELECT, it still goes to master
## Commands
```bash
go build ./... # compile
go test ./... # all tests
go test -race ./... # tests with race detector
go test -v -run TestName ./... # single test
go vet ./... # static analysis
```
## Conventions
- **Struct-based Config** with `defaults()` method for zero-value defaults
- **Functional options** (`Option func(*Config)`) used via `ApplyOptions` (primarily in dbxtest)
- **stdlib only** testing — no testify, no gomock
- **Thread safety** — `atomic.Bool` for `Node.healthy` and `Cluster.closed`
- **dbxtest helpers** — `NewTestCluster` skips on unreachable DB, auto-closes via `t.Cleanup`; `TestLogger` routes to `testing.T`
- **Sentinel errors** — `ErrNoHealthyNode`, `ErrClusterClosed`, `ErrRetryExhausted`
- **retryError** uses multi-unwrap (`Unwrap() []error`) so both `ErrRetryExhausted` and the last error can be matched with `errors.Is`

CLAUDE.md (new file)

@@ -0,0 +1,53 @@
# CLAUDE.md — dbx
## Commands
```bash
go build ./... # compile
go test ./... # all tests
go test -race ./... # tests with race detector
go test -v -run TestName ./... # single test
go vet ./... # static analysis
```
## Architecture
- **Module**: `git.codelab.vc/pkg/dbx`, Go 1.24, depends on pgx/v5
- **Single package** `dbx` (+ `dbxtest` for test helpers)
### Core patterns
- **Cluster** is the entry point — connects master + replicas, routes writes to master, reads to replicas with master fallback
- **Routing is method-based**: `Exec`/`Query`/`QueryRow`/`Begin`/`BeginTx`/`CopyFrom`/`SendBatch` → master; `ReadQuery`/`ReadQueryRow` → replicas
- **Retry** with exponential backoff + jitter, node fallback; retrier.do() iterates nodes then backs off
- **Balancer** interface (`Next([]*Node) *Node`) — built-in `RoundRobinBalancer` skips unhealthy nodes
- **Health checker** — background goroutine pings all nodes on an interval, flips `Node.healthy` atomic bool
- **RunTx** — panic-safe transaction wrapper: recovers panics, rolls back, re-panics
- **Querier injection** — `InjectQuerier`/`ExtractQuerier` pass `Querier` via context for service layers
- **SlogLogger** — adapts `*slog.Logger` to the `dbx.Logger` interface (`slog.go`)
- **Collect/CollectOne** — generic scan helpers using `pgx.RowToStructByName` (`scan.go`)
- **Slow query logging** — `Config.SlowQueryThreshold` triggers Warn-level logging in `queryEnd`
- **PoolStats** — `Cluster.Stats()` aggregates pool statistics across all nodes (`stats.go`)
- **Tracer passthrough** — `NodeConfig.Tracer` / `WithTracer` sets `pgx.QueryTracer` for OpenTelemetry
- **RunInTx** — test helper that runs a callback in an always-rolled-back transaction (`dbxtest/tx.go`)
### Error classification
- `IsRetryable(err)` — connection errors (class 08), serialization failures (40001), deadlocks (40P01), too_many_connections (53300)
- `IsConnectionError(err)` — PG class 08 + string matching for pgx-wrapped errors
- `IsConstraintViolation(err)` — PG class 23
- `PgErrorCode(err)` — extract raw code from `*pgconn.PgError`
## Conventions
- Struct-based `Config` with a `defaults()` method — `NewCluster` takes a `Config`, not functional options; the `Option` type exists for `ApplyOptions` (used in dbxtest)
- Functional options (`Option func(*Config)`) applied via `ApplyOptions` (e.g., in dbxtest)
- stdlib-only tests — no testify, no gomock
- `atomic.Bool` for thread safety (`Node.healthy`, `Cluster.closed`)
- `dbxtest.NewTestCluster` skips tests when the DB is unreachable, auto-closes via `t.Cleanup`
- `dbxtest.TestLogger` writes to `testing.T` for test log output
- `dbxtest.RunInTx` runs a callback in a transaction that is always rolled back
## See also
- `AGENTS.md` — universal AI agent guide with common tasks, gotchas, and ASCII diagrams

README.md (modified)

@@ -1,2 +1,247 @@
# dbx
PostgreSQL cluster library for Go built on pgx/v5. Master/replica routing, automatic retries with exponential backoff, round-robin load balancing, background health checking, panic-safe transactions, and context-based Querier injection.
```
go get git.codelab.vc/pkg/dbx
```
## Quick start
```go
cluster, err := dbx.NewCluster(ctx, dbx.Config{
    Master: dbx.NodeConfig{
        Name: "master",
        DSN:  "postgres://user:pass@master:5432/mydb",
        Pool: dbx.PoolConfig{MaxConns: 20, MinConns: 5},
    },
    Replicas: []dbx.NodeConfig{
        {Name: "replica-1", DSN: "postgres://user:pass@replica1:5432/mydb"},
        {Name: "replica-2", DSN: "postgres://user:pass@replica2:5432/mydb"},
    },
})
if err != nil {
    log.Fatal(err)
}
defer cluster.Close()

// Write → master
cluster.Exec(ctx, "INSERT INTO users (name) VALUES ($1)", "alice")

// Read → replica with automatic fallback to master
rows, err := cluster.ReadQuery(ctx, "SELECT * FROM users WHERE active = $1", true)

// Transaction → master, panic-safe
cluster.RunTx(ctx, func(ctx context.Context, tx pgx.Tx) error {
    tx.Exec(ctx, "UPDATE accounts SET balance = balance - $1 WHERE id = $2", 100, fromID)
    tx.Exec(ctx, "UPDATE accounts SET balance = balance + $1 WHERE id = $2", 100, toID)
    return nil
})
```
## Components
| Component | What it does |
|-----------|-------------|
| `Cluster` | Entry point. Connects to master + replicas, routes queries, manages lifecycle. |
| `Node` | Wraps `pgxpool.Pool` with health state and a human-readable name. |
| `Balancer` | Interface for replica selection. Built-in: `RoundRobinBalancer`. |
| `retrier` | Exponential backoff with jitter, node fallback, custom error classifiers. |
| `healthChecker` | Background goroutine that pings all nodes on an interval. |
| `Querier` injection | `InjectQuerier` / `ExtractQuerier` — context-based Querier for service layers. |
| `MetricsHook` | Optional callbacks: query start/end, retry, node up/down, replica fallback. |
| `SlogLogger` | Adapts `*slog.Logger` to the `dbx.Logger` interface. |
| `Collect`/`CollectOne` | Generic scan helpers — read rows directly into structs via `pgx.RowToStructByName`. |
| `PoolStats` | Aggregate pool statistics across all nodes via `cluster.Stats()`. |
## Routing
The library uses explicit method-based routing (no SQL parsing):
```
                ┌──────────────┐
                │   Cluster    │
                └──────┬───────┘
        ┌──────────────┴──────────────┐
        │                             │
    Write ops                     Read ops
Exec, Query, QueryRow      ReadQuery, ReadQueryRow
Begin, BeginTx, RunTx
CopyFrom, SendBatch
        │                             │
        ▼                             ▼
  ┌──────────┐           ┌────────────────────────┐
  │  Master  │           │  Balancer → Replicas   │
  └──────────┘           │  fallback → Master     │
                         └────────────────────────┘
```
Direct node access: `cluster.Master()` and `cluster.Replica()` return `DB`.
## Multi-replica setup
```go
cluster, _ := dbx.NewCluster(ctx, dbx.Config{
    Master: dbx.NodeConfig{
        Name: "master",
        DSN:  "postgres://master:5432/mydb",
        Pool: dbx.PoolConfig{MaxConns: 20, MinConns: 5},
    },
    Replicas: []dbx.NodeConfig{
        {Name: "replica-1", DSN: "postgres://replica1:5432/mydb"},
        {Name: "replica-2", DSN: "postgres://replica2:5432/mydb"},
        {Name: "replica-3", DSN: "postgres://replica3:5432/mydb"},
    },
    Retry: dbx.RetryConfig{
        MaxAttempts: 5,
        BaseDelay:   100 * time.Millisecond,
        MaxDelay:    2 * time.Second,
    },
    HealthCheck: dbx.HealthCheckConfig{
        Interval: 3 * time.Second,
        Timeout:  1 * time.Second,
    },
})
defer cluster.Close()
```
## Transactions
`RunTx` is panic-safe — if the callback panics, the transaction is rolled back and the panic is re-raised:
```go
err := cluster.RunTx(ctx, func(ctx context.Context, tx pgx.Tx) error {
    _, err := tx.Exec(ctx, "UPDATE accounts SET balance = balance - $1 WHERE id = $2", amount, fromID)
    if err != nil {
        return err
    }
    _, err = tx.Exec(ctx, "UPDATE accounts SET balance = balance + $1 WHERE id = $2", amount, toID)
    return err
})
```
For custom isolation levels use `RunTxOptions`:
```go
cluster.RunTxOptions(ctx, pgx.TxOptions{
    IsoLevel: pgx.Serializable,
}, fn)
```
## Context-based Querier injection
Pass the Querier through context so service layers work both inside and outside transactions:
```go
// Repository
func CreateUser(ctx context.Context, db dbx.Querier, name string) error {
    q := dbx.ExtractQuerier(ctx, db)
    _, err := q.Exec(ctx, "INSERT INTO users (name) VALUES ($1)", name)
    return err
}

// Outside transaction — uses cluster directly
CreateUser(ctx, cluster, "alice")

// Inside transaction — uses tx
cluster.RunTx(ctx, func(ctx context.Context, tx pgx.Tx) error {
    ctx = dbx.InjectQuerier(ctx, tx)
    return CreateUser(ctx, cluster, "alice") // will use tx from context
})
```
## Error classification
```go
dbx.IsRetryable(err) // connection errors, serialization failures, deadlocks, too_many_connections
dbx.IsConnectionError(err) // PG class 08 + common connection error strings
dbx.IsConstraintViolation(err) // PG class 23 (unique, FK, check violations)
dbx.PgErrorCode(err) // extract raw PG error code
```
Sentinel errors: `ErrNoHealthyNode`, `ErrClusterClosed`, `ErrRetryExhausted`.
## slog integration
```go
cluster, _ := dbx.NewCluster(ctx, dbx.Config{
    Master: dbx.NodeConfig{DSN: "postgres://..."},
    Logger: dbx.NewSlogLogger(slog.Default()),
})
```
## Scan helpers
Generic functions that eliminate row scanning boilerplate:
```go
type User struct {
    ID   int    `db:"id"`
    Name string `db:"name"`
}

users, err := dbx.Collect[User](ctx, cluster, "SELECT id, name FROM users WHERE active = $1", true)
user, err := dbx.CollectOne[User](ctx, cluster, "SELECT id, name FROM users WHERE id = $1", 42)
// returns pgx.ErrNoRows if not found
```
## Slow query logging
```go
cluster, _ := dbx.NewCluster(ctx, dbx.Config{
    Master:             dbx.NodeConfig{DSN: "postgres://..."},
    Logger:             dbx.NewSlogLogger(slog.Default()),
    SlowQueryThreshold: 100 * time.Millisecond,
})
// queries exceeding the threshold are logged at Warn level
```
## Pool stats
```go
stats := cluster.Stats()
fmt.Println(stats.TotalConns, stats.IdleConns, stats.AcquireCount)
// per-node stats: stats.Nodes["master"], stats.Nodes["replica-1"]
```
## OpenTelemetry / pgx tracer
Pass any `pgx.QueryTracer` (e.g., `otelpgx.NewTracer()`) to instrument all queries:
```go
dbx.ApplyOptions(&cfg, dbx.WithTracer(otelpgx.NewTracer()))
```
Or set per-node via `NodeConfig.Tracer`.
## dbxtest helpers
The `dbxtest` package provides test helpers:
```go
func TestMyRepo(t *testing.T) {
    cluster := dbxtest.NewTestCluster(t, dbx.WithLogger(&dbxtest.TestLogger{T: t}))
    // cluster is auto-closed when test finishes
    // skips test if DB is not reachable
}
```
### Transaction isolation for tests
```go
func TestCreateUser(t *testing.T) {
    c := dbxtest.NewTestCluster(t)
    dbxtest.RunInTx(t, c, func(ctx context.Context, q dbx.Querier) {
        // all changes are rolled back after fn returns
        _, _ = q.Exec(ctx, "INSERT INTO users (name) VALUES ($1)", "test")
    })
}
```
Set `DBX_TEST_DSN` env var to override the default DSN (`postgres://postgres:postgres@localhost:5432/dbx_test?sslmode=disable`).
## Requirements
Go 1.24+, [pgx/v5](https://github.com/jackc/pgx).

balancer.go (new file)

@@ -0,0 +1,35 @@
package dbx

import "sync/atomic"

// Balancer selects the next node from a list.
// It must return nil if no suitable node is available.
type Balancer interface {
    Next(nodes []*Node) *Node
}

// RoundRobinBalancer distributes load evenly across healthy nodes.
type RoundRobinBalancer struct {
    counter atomic.Uint64
}

// NewRoundRobinBalancer creates a new round-robin balancer.
func NewRoundRobinBalancer() *RoundRobinBalancer {
    return &RoundRobinBalancer{}
}

// Next returns the next healthy node, or nil if none are healthy.
func (b *RoundRobinBalancer) Next(nodes []*Node) *Node {
    n := len(nodes)
    if n == 0 {
        return nil
    }
    idx := b.counter.Add(1)
    for i := 0; i < n; i++ {
        node := nodes[(int(idx)+i)%n]
        if node.IsHealthy() {
            return node
        }
    }
    return nil
}

balancer_test.go (new file)

@@ -0,0 +1,73 @@
package dbx

import (
    "testing"
)

func TestRoundRobinBalancer_Empty(t *testing.T) {
    b := NewRoundRobinBalancer()
    if n := b.Next(nil); n != nil {
        t.Errorf("expected nil for empty slice, got %v", n)
    }
}

func TestRoundRobinBalancer_AllHealthy(t *testing.T) {
    b := NewRoundRobinBalancer()
    nodes := makeTestNodes("a", "b", "c")
    seen := map[string]int{}
    for range 30 {
        n := b.Next(nodes)
        if n == nil {
            t.Fatal("unexpected nil")
        }
        seen[n.name]++
    }
    if len(seen) != 3 {
        t.Errorf("expected 3 distinct nodes, got %d", len(seen))
    }
    for name, count := range seen {
        if count != 10 {
            t.Errorf("node %s hit %d times, expected 10", name, count)
        }
    }
}

func TestRoundRobinBalancer_SkipsUnhealthy(t *testing.T) {
    b := NewRoundRobinBalancer()
    nodes := makeTestNodes("a", "b", "c")
    nodes[1].healthy.Store(false) // b is down
    for range 20 {
        n := b.Next(nodes)
        if n == nil {
            t.Fatal("unexpected nil")
        }
        if n.name == "b" {
            t.Error("should not return unhealthy node b")
        }
    }
}

func TestRoundRobinBalancer_AllUnhealthy(t *testing.T) {
    b := NewRoundRobinBalancer()
    nodes := makeTestNodes("a", "b")
    for _, n := range nodes {
        n.healthy.Store(false)
    }
    if n := b.Next(nodes); n != nil {
        t.Errorf("expected nil when all unhealthy, got %v", n.name)
    }
}

func makeTestNodes(names ...string) []*Node {
    nodes := make([]*Node, len(names))
    for i, name := range names {
        n := &Node{name: name}
        n.healthy.Store(true)
        nodes[i] = n
    }
    return nodes
}

cluster.go (new file)

@@ -0,0 +1,256 @@
package dbx

import (
    "context"
    "sync/atomic"
    "time"

    "github.com/jackc/pgx/v5"
    "github.com/jackc/pgx/v5/pgconn"
)

// Cluster manages a master node and optional replicas.
// Write operations go to master; read operations are distributed across replicas
// with automatic fallback to master.
type Cluster struct {
    master             *Node
    replicas           []*Node
    all                []*Node // master + replicas for health checker
    balancer           Balancer
    retrier            *retrier
    health             *healthChecker
    logger             Logger
    metrics            *MetricsHook
    slowQueryThreshold time.Duration
    closed             atomic.Bool
}

// NewCluster creates a Cluster, connecting to all configured nodes.
func NewCluster(ctx context.Context, cfg Config) (*Cluster, error) {
    cfg.defaults()
    master, err := connectNode(ctx, cfg.Master)
    if err != nil {
        return nil, err
    }
    replicas := make([]*Node, 0, len(cfg.Replicas))
    for _, rc := range cfg.Replicas {
        r, err := connectNode(ctx, rc)
        if err != nil {
            // close already-connected nodes
            master.Close()
            for _, opened := range replicas {
                opened.Close()
            }
            return nil, err
        }
        replicas = append(replicas, r)
    }
    all := make([]*Node, 0, 1+len(replicas))
    all = append(all, master)
    all = append(all, replicas...)
    c := &Cluster{
        master:             master,
        replicas:           replicas,
        all:                all,
        balancer:           NewRoundRobinBalancer(),
        retrier:            newRetrier(cfg.Retry, cfg.Logger, cfg.Metrics),
        logger:             cfg.Logger,
        metrics:            cfg.Metrics,
        slowQueryThreshold: cfg.SlowQueryThreshold,
    }
    c.health = newHealthChecker(all, cfg.HealthCheck, cfg.Logger, cfg.Metrics)
    c.health.start()
    return c, nil
}

// Master returns the master node as a DB.
func (c *Cluster) Master() DB { return c.master }

// Replica returns a healthy replica as a DB, or falls back to master.
func (c *Cluster) Replica() DB {
    if n := c.balancer.Next(c.replicas); n != nil {
        return n
    }
    return c.master
}

// Close shuts down the health checker and closes all connection pools.
func (c *Cluster) Close() {
    if !c.closed.CompareAndSwap(false, true) {
        return
    }
    c.health.shutdown()
    for _, n := range c.all {
        n.Close()
    }
}

// Ping checks connectivity to the master node.
func (c *Cluster) Ping(ctx context.Context) error {
    if c.closed.Load() {
        return ErrClusterClosed
    }
    return c.master.Ping(ctx)
}

// --- Write operations (always master) ---

func (c *Cluster) Exec(ctx context.Context, sql string, args ...any) (pgconn.CommandTag, error) {
    if c.closed.Load() {
        return pgconn.CommandTag{}, ErrClusterClosed
    }
    var tag pgconn.CommandTag
    err := c.retrier.do(ctx, []*Node{c.master}, func(ctx context.Context, n *Node) error {
        start := time.Now()
        c.queryStart(ctx, n.name, sql)
        var e error
        tag, e = n.Exec(ctx, sql, args...)
        c.queryEnd(ctx, n.name, sql, e, time.Since(start))
        return e
    })
    return tag, err
}

func (c *Cluster) Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) {
    if c.closed.Load() {
        return nil, ErrClusterClosed
    }
    var rows pgx.Rows
    err := c.retrier.do(ctx, []*Node{c.master}, func(ctx context.Context, n *Node) error {
        start := time.Now()
        c.queryStart(ctx, n.name, sql)
        var e error
        rows, e = n.Query(ctx, sql, args...)
        c.queryEnd(ctx, n.name, sql, e, time.Since(start))
        return e
    })
    return rows, err
}

func (c *Cluster) QueryRow(ctx context.Context, sql string, args ...any) pgx.Row {
    if c.closed.Load() {
        return errRow{err: ErrClusterClosed}
    }
    var row pgx.Row
    _ = c.retrier.do(ctx, []*Node{c.master}, func(ctx context.Context, n *Node) error {
        start := time.Now()
        c.queryStart(ctx, n.name, sql)
        row = n.QueryRow(ctx, sql, args...)
        c.queryEnd(ctx, n.name, sql, nil, time.Since(start))
        return nil
    })
    return row
}

func (c *Cluster) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults {
    return c.master.SendBatch(ctx, b)
}

func (c *Cluster) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) {
    if c.closed.Load() {
        return 0, ErrClusterClosed
    }
    return c.master.CopyFrom(ctx, tableName, columnNames, rowSrc)
}

func (c *Cluster) Begin(ctx context.Context) (pgx.Tx, error) {
    if c.closed.Load() {
        return nil, ErrClusterClosed
    }
    return c.master.Begin(ctx)
}

func (c *Cluster) BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, error) {
    if c.closed.Load() {
        return nil, ErrClusterClosed
    }
    return c.master.BeginTx(ctx, txOptions)
}

// --- Read operations (replicas with master fallback) ---

// ReadQuery executes a read query on a replica, falling back to master if no replicas are healthy.
func (c *Cluster) ReadQuery(ctx context.Context, sql string, args ...any) (pgx.Rows, error) {
    if c.closed.Load() {
        return nil, ErrClusterClosed
    }
    nodes := c.readNodes()
    var rows pgx.Rows
    err := c.retrier.do(ctx, nodes, func(ctx context.Context, n *Node) error {
        start := time.Now()
        c.queryStart(ctx, n.name, sql)
        var e error
        rows, e = n.Query(ctx, sql, args...)
        c.queryEnd(ctx, n.name, sql, e, time.Since(start))
        return e
    })
    return rows, err
}

// ReadQueryRow executes a read query returning a single row, using replicas with master fallback.
func (c *Cluster) ReadQueryRow(ctx context.Context, sql string, args ...any) pgx.Row {
    if c.closed.Load() {
        return errRow{err: ErrClusterClosed}
    }
    nodes := c.readNodes()
    var row pgx.Row
    _ = c.retrier.do(ctx, nodes, func(ctx context.Context, n *Node) error {
        start := time.Now()
        c.queryStart(ctx, n.name, sql)
        row = n.QueryRow(ctx, sql, args...)
        c.queryEnd(ctx, n.name, sql, nil, time.Since(start))
        return nil
    })
    return row
}

// readNodes returns replicas followed by master for fallback ordering.
func (c *Cluster) readNodes() []*Node {
    if len(c.replicas) == 0 {
        return []*Node{c.master}
    }
    nodes := make([]*Node, 0, len(c.replicas)+1)
    nodes = append(nodes, c.replicas...)
    nodes = append(nodes, c.master)
    return nodes
}

// --- metrics helpers ---

func (c *Cluster) queryStart(ctx context.Context, node, sql string) {
    if c.metrics != nil && c.metrics.OnQueryStart != nil {
        c.metrics.OnQueryStart(ctx, node, sql)
    }
}

func (c *Cluster) queryEnd(ctx context.Context, node, sql string, err error, d time.Duration) {
    if c.metrics != nil && c.metrics.OnQueryEnd != nil {
        c.metrics.OnQueryEnd(ctx, node, sql, err, d)
    }
    if c.slowQueryThreshold > 0 && d >= c.slowQueryThreshold {
        c.logger.Warn(ctx, "dbx: slow query",
            "node", node,
            "duration", d,
            "sql", sql,
        )
    }
}

// errRow implements pgx.Row for error cases.
type errRow struct {
    err error
}

func (r errRow) Scan(...any) error {
    return r.err
}

// Compile-time check that *Cluster implements DB.
var _ DB = (*Cluster)(nil)

config.go (new file)

@@ -0,0 +1,94 @@
package dbx

import (
    "time"

    "github.com/jackc/pgx/v5"
)

// Config is the top-level configuration for a Cluster.
type Config struct {
    Master             NodeConfig
    Replicas           []NodeConfig
    Retry              RetryConfig
    Logger             Logger
    Metrics            *MetricsHook
    HealthCheck        HealthCheckConfig
    SlowQueryThreshold time.Duration
}

// NodeConfig describes a single database node.
type NodeConfig struct {
    Name   string // human-readable name for logs/metrics, e.g. "master", "replica-1"
    DSN    string
    Pool   PoolConfig
    Tracer pgx.QueryTracer
}

// PoolConfig controls pgxpool.Pool parameters.
type PoolConfig struct {
    MaxConns          int32
    MinConns          int32
    MaxConnLifetime   time.Duration
    MaxConnIdleTime   time.Duration
    HealthCheckPeriod time.Duration
}

// RetryConfig controls retry behaviour.
type RetryConfig struct {
    MaxAttempts     int              // default: 3
    BaseDelay       time.Duration    // default: 50ms
    MaxDelay        time.Duration    // default: 1s
    RetryableErrors func(error) bool // optional custom classifier
}

// HealthCheckConfig controls the background health checker.
type HealthCheckConfig struct {
    Interval time.Duration // default: 5s
    Timeout  time.Duration // default: 2s
}

// defaults fills zero-valued fields with sensible defaults.
func (c *Config) defaults() {
    if c.Logger == nil {
        c.Logger = nopLogger{}
    }
    if c.Retry.MaxAttempts <= 0 {
        c.Retry.MaxAttempts = 3
    }
    if c.Retry.BaseDelay <= 0 {
        c.Retry.BaseDelay = 50 * time.Millisecond
    }
    if c.Retry.MaxDelay <= 0 {
        c.Retry.MaxDelay = time.Second
    }
    if c.HealthCheck.Interval <= 0 {
        c.HealthCheck.Interval = 5 * time.Second
    }
    if c.HealthCheck.Timeout <= 0 {
        c.HealthCheck.Timeout = 2 * time.Second
    }
    if c.Master.Name == "" {
        c.Master.Name = "master"
    }
    for i := range c.Replicas {
        if c.Replicas[i].Name == "" {
            c.Replicas[i].Name = "replica-" + itoa(i+1)
        }
    }
}

// itoa is a minimal int-to-string without importing strconv.
// It assumes n >= 0 (callers only pass replica indices).
func itoa(n int) string {
    if n == 0 {
        return "0"
    }
    buf := [20]byte{}
    i := len(buf)
    for n > 0 {
        i--
        buf[i] = byte('0' + n%10)
        n /= 10
    }
    return string(buf[i:])
}

config_test.go (new file)

@@ -0,0 +1,72 @@
package dbx

import (
    "testing"
    "time"
)

func TestConfigDefaults(t *testing.T) {
    cfg := Config{}
    cfg.defaults()
    if cfg.Logger == nil {
        t.Error("Logger should not be nil after defaults")
    }
    if cfg.Retry.MaxAttempts != 3 {
        t.Errorf("MaxAttempts = %d, want 3", cfg.Retry.MaxAttempts)
    }
    if cfg.Retry.BaseDelay != 50*time.Millisecond {
        t.Errorf("BaseDelay = %v, want 50ms", cfg.Retry.BaseDelay)
    }
    if cfg.Retry.MaxDelay != time.Second {
        t.Errorf("MaxDelay = %v, want 1s", cfg.Retry.MaxDelay)
    }
    if cfg.HealthCheck.Interval != 5*time.Second {
        t.Errorf("HealthCheck.Interval = %v, want 5s", cfg.HealthCheck.Interval)
    }
    if cfg.HealthCheck.Timeout != 2*time.Second {
        t.Errorf("HealthCheck.Timeout = %v, want 2s", cfg.HealthCheck.Timeout)
    }
    if cfg.Master.Name != "master" {
        t.Errorf("Master.Name = %q, want %q", cfg.Master.Name, "master")
    }
}

func TestConfigDefaultsReplicaNames(t *testing.T) {
    cfg := Config{
        Replicas: []NodeConfig{
            {DSN: "a"},
            {Name: "custom", DSN: "b"},
            {DSN: "c"},
        },
    }
    cfg.defaults()
    if cfg.Replicas[0].Name != "replica-1" {
        t.Errorf("Replica[0].Name = %q, want %q", cfg.Replicas[0].Name, "replica-1")
    }
    if cfg.Replicas[1].Name != "custom" {
        t.Errorf("Replica[1].Name = %q, want %q", cfg.Replicas[1].Name, "custom")
    }
    if cfg.Replicas[2].Name != "replica-3" {
        t.Errorf("Replica[2].Name = %q, want %q", cfg.Replicas[2].Name, "replica-3")
    }
}

func TestConfigDefaultsNoOverwrite(t *testing.T) {
    cfg := Config{
        Retry: RetryConfig{
            MaxAttempts: 5,
            BaseDelay:   100 * time.Millisecond,
            MaxDelay:    2 * time.Second,
        },
    }
    cfg.defaults()
    if cfg.Retry.MaxAttempts != 5 {
        t.Errorf("MaxAttempts = %d, want 5", cfg.Retry.MaxAttempts)
    }
    if cfg.Retry.BaseDelay != 100*time.Millisecond {
        t.Errorf("BaseDelay = %v, want 100ms", cfg.Retry.BaseDelay)
    }
}

dbx.go (new file)
@@ -0,0 +1,54 @@
package dbx
import (
"context"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
)
// Querier is the common interface satisfied by pgxpool.Pool, pgx.Tx, and pgx.Conn.
type Querier interface {
Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error)
Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error)
QueryRow(ctx context.Context, sql string, args ...any) pgx.Row
SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults
CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error)
}
// DB extends Querier with transaction support and lifecycle management.
type DB interface {
Querier
Begin(ctx context.Context) (pgx.Tx, error)
BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, error)
Ping(ctx context.Context) error
Close()
}
// Logger is the interface for pluggable structured logging.
type Logger interface {
Debug(ctx context.Context, msg string, fields ...any)
Info(ctx context.Context, msg string, fields ...any)
Warn(ctx context.Context, msg string, fields ...any)
Error(ctx context.Context, msg string, fields ...any)
}
// MetricsHook provides optional callbacks for observability.
// All fields are nil-safe — only set the hooks you need.
type MetricsHook struct {
OnQueryStart func(ctx context.Context, node string, sql string)
OnQueryEnd func(ctx context.Context, node string, sql string, err error, duration time.Duration)
OnRetry func(ctx context.Context, node string, attempt int, err error)
OnNodeDown func(ctx context.Context, node string, err error)
OnNodeUp func(ctx context.Context, node string)
OnReplicaFallback func(ctx context.Context, fromNode string, toNode string)
}
// nopLogger is a Logger that discards all output.
type nopLogger struct{}
func (nopLogger) Debug(context.Context, string, ...any) {}
func (nopLogger) Info(context.Context, string, ...any) {}
func (nopLogger) Warn(context.Context, string, ...any) {}
func (nopLogger) Error(context.Context, string, ...any) {}

dbxtest/dbxtest.go (new file)
@@ -0,0 +1,87 @@
// Package dbxtest provides test helpers for users of the dbx library.
package dbxtest
import (
"context"
"os"
"testing"
"git.codelab.vc/pkg/dbx"
)
const (
// EnvTestDSN is the environment variable for the test database DSN.
EnvTestDSN = "DBX_TEST_DSN"
// DefaultTestDSN is the default DSN used when EnvTestDSN is not set.
DefaultTestDSN = "postgres://postgres:postgres@localhost:5432/dbx_test?sslmode=disable"
)
// TestDSN returns the DSN from the environment or the default.
func TestDSN() string {
if dsn := os.Getenv(EnvTestDSN); dsn != "" {
return dsn
}
return DefaultTestDSN
}
// NewTestCluster creates a Cluster connected to a test database.
// It skips the test if the database is not reachable.
// The cluster is automatically closed when the test finishes.
func NewTestCluster(t testing.TB, opts ...dbx.Option) *dbx.Cluster {
t.Helper()
dsn := TestDSN()
cfg := dbx.Config{
Master: dbx.NodeConfig{
Name: "test-master",
DSN: dsn,
Pool: dbx.PoolConfig{MaxConns: 5, MinConns: 1},
},
}
dbx.ApplyOptions(&cfg, opts...)
ctx := context.Background()
cluster, err := dbx.NewCluster(ctx, cfg)
if err != nil {
t.Skipf("dbxtest: cannot connect to test database: %v", err)
}
if err := cluster.Ping(ctx); err != nil {
cluster.Close()
t.Skipf("dbxtest: cannot reach test database: %v", err)
}
t.Cleanup(func() {
cluster.Close()
})
return cluster
}
// TestLogger is a Logger that writes to testing.T.
type TestLogger struct {
T testing.TB
}
func (l *TestLogger) Debug(_ context.Context, msg string, fields ...any) {
l.T.Helper()
l.T.Logf("[DEBUG] %s %v", msg, fields)
}
func (l *TestLogger) Info(_ context.Context, msg string, fields ...any) {
l.T.Helper()
l.T.Logf("[INFO] %s %v", msg, fields)
}
func (l *TestLogger) Warn(_ context.Context, msg string, fields ...any) {
l.T.Helper()
l.T.Logf("[WARN] %s %v", msg, fields)
}
func (l *TestLogger) Error(_ context.Context, msg string, fields ...any) {
l.T.Helper()
l.T.Logf("[ERROR] %s %v", msg, fields)
}
// Compile-time check.
var _ dbx.Logger = (*TestLogger)(nil)

dbxtest/tx.go (new file)
@@ -0,0 +1,27 @@
package dbxtest
import (
"context"
"testing"
"git.codelab.vc/pkg/dbx"
)
// RunInTx executes fn inside a transaction that is always rolled back.
// This is useful for tests that modify data but should not leave side effects.
// The callback receives a dbx.Querier (not pgx.Tx) so it is compatible with
// InjectQuerier/ExtractQuerier patterns.
func RunInTx(t testing.TB, c *dbx.Cluster, fn func(ctx context.Context, q dbx.Querier)) {
t.Helper()
ctx := context.Background()
tx, err := c.Begin(ctx)
if err != nil {
t.Fatalf("dbxtest.RunInTx: begin: %v", err)
}
defer func() {
_ = tx.Rollback(ctx)
}()
fn(ctx, tx)
}

dbxtest/tx_test.go (new file)
@@ -0,0 +1,50 @@
package dbxtest_test
import (
"context"
"testing"
"git.codelab.vc/pkg/dbx"
"git.codelab.vc/pkg/dbx/dbxtest"
)
func TestRunInTx(t *testing.T) {
c := dbxtest.NewTestCluster(t)
ctx := context.Background()
// Create a table that persists across the test.
_, err := c.Exec(ctx, `CREATE TABLE IF NOT EXISTS test_run_in_tx (id int)`)
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() {
_, _ = c.Exec(context.Background(), `DROP TABLE IF EXISTS test_run_in_tx`)
})
dbxtest.RunInTx(t, c, func(ctx context.Context, q dbx.Querier) {
_, err := q.Exec(ctx, `INSERT INTO test_run_in_tx (id) VALUES (1)`)
if err != nil {
t.Fatal(err)
}
// Row should be visible within the transaction.
var count int
err = q.QueryRow(ctx, `SELECT count(*) FROM test_run_in_tx`).Scan(&count)
if err != nil {
t.Fatal(err)
}
if count != 1 {
t.Fatalf("expected 1 row in tx, got %d", count)
}
})
// After RunInTx returns, the transaction was rolled back; row should not exist.
var count int
err = c.QueryRow(ctx, `SELECT count(*) FROM test_run_in_tx`).Scan(&count)
if err != nil {
t.Fatal(err)
}
if count != 0 {
t.Errorf("expected 0 rows after rollback, got %d", count)
}
}

doc.go (new file)
@@ -0,0 +1,49 @@
// Package dbx provides a production-ready PostgreSQL helper library built on top of pgx/v5.
//
// It manages connection pools, master/replica routing, automatic retries with
// exponential backoff, load balancing across replicas, background health checking,
// and transaction helpers.
//
// # Quick Start
//
// cluster, err := dbx.NewCluster(ctx, dbx.Config{
// Master: dbx.NodeConfig{
// Name: "master",
// DSN: "postgres://user:pass@master:5432/mydb",
// Pool: dbx.PoolConfig{MaxConns: 20, MinConns: 5},
// },
// Replicas: []dbx.NodeConfig{
// {Name: "replica-1", DSN: "postgres://...@replica1:5432/mydb"},
// },
// })
// if err != nil {
// log.Fatal(err)
// }
// defer cluster.Close()
//
// // Write → master
// cluster.Exec(ctx, "INSERT INTO users (name) VALUES ($1)", "alice")
//
// // Read → replica with fallback to master
// rows, _ := cluster.ReadQuery(ctx, "SELECT * FROM users WHERE active = $1", true)
//
// // Transaction → master, panic-safe
// cluster.RunTx(ctx, func(ctx context.Context, tx pgx.Tx) error {
// _, _ = tx.Exec(ctx, "UPDATE accounts SET balance = balance - $1 WHERE id = $2", 100, fromID)
// _, _ = tx.Exec(ctx, "UPDATE accounts SET balance = balance + $1 WHERE id = $2", 100, toID)
// return nil
// })
//
// # Routing
//
// The library uses explicit method-based routing (no SQL parsing):
// - Exec, Query, QueryRow, Begin, BeginTx, CopyFrom, SendBatch → master
// - ReadQuery, ReadQueryRow → replicas with master fallback
// - Master(), Replica() → direct access to specific nodes
//
// # Retry
//
// Retryable errors include connection errors (PG class 08), serialization failures (40001),
// deadlocks (40P01), and too_many_connections (53300). A custom classifier can be provided
// via RetryConfig.RetryableErrors.
package dbx

errors.go (new file)
@@ -0,0 +1,94 @@
package dbx
import (
"errors"
"fmt"
"strings"
"github.com/jackc/pgx/v5/pgconn"
)
// Sentinel errors.
var (
ErrNoHealthyNode = errors.New("dbx: no healthy node available")
ErrClusterClosed = errors.New("dbx: cluster is closed")
ErrRetryExhausted = errors.New("dbx: retry attempts exhausted")
)
// IsRetryable reports whether the error is worth retrying.
// Connection errors, serialization failures, deadlocks, and too_many_connections are retryable.
func IsRetryable(err error) bool {
if err == nil {
return false
}
if IsConnectionError(err) {
return true
}
code := PgErrorCode(err)
switch code {
case "40001", // serialization_failure
"40P01", // deadlock_detected
"53300": // too_many_connections
return true
}
return false
}
// IsConnectionError reports whether the error indicates a connection problem (PG class 08).
func IsConnectionError(err error) bool {
if err == nil {
return false
}
code := PgErrorCode(err)
if strings.HasPrefix(code, "08") {
return true
}
// pgx wraps connection errors that may not have a PG code
msg := err.Error()
for _, s := range []string{
"connection refused",
"connection reset",
"broken pipe",
"EOF",
"no connection",
"conn closed",
"timeout",
} {
if strings.Contains(msg, s) {
return true
}
}
return false
}
// IsConstraintViolation reports whether the error is a constraint violation (PG class 23).
func IsConstraintViolation(err error) bool {
return strings.HasPrefix(PgErrorCode(err), "23")
}
// PgErrorCode extracts the PostgreSQL error code from err, or returns "" if not a PG error.
func PgErrorCode(err error) string {
var pgErr *pgconn.PgError
if errors.As(err, &pgErr) {
return pgErr.Code
}
return ""
}
// retryError wraps the last error with ErrRetryExhausted context.
type retryError struct {
attempts int
last error
}
func (e *retryError) Error() string {
return fmt.Sprintf("%s: %d attempts, last error: %v", ErrRetryExhausted, e.attempts, e.last)
}
func (e *retryError) Unwrap() []error {
return []error{ErrRetryExhausted, e.last}
}
func newRetryError(attempts int, last error) error {
return &retryError{attempts: attempts, last: last}
}

errors_test.go (new file)
@@ -0,0 +1,92 @@
package dbx
import (
"errors"
"testing"
"github.com/jackc/pgx/v5/pgconn"
)
func TestIsRetryable(t *testing.T) {
tests := []struct {
name string
err error
want bool
}{
{"nil", nil, false},
{"generic", errors.New("oops"), false},
{"serialization_failure", &pgconn.PgError{Code: "40001"}, true},
{"deadlock", &pgconn.PgError{Code: "40P01"}, true},
{"too_many_connections", &pgconn.PgError{Code: "53300"}, true},
{"connection_exception", &pgconn.PgError{Code: "08006"}, true},
{"constraint_violation", &pgconn.PgError{Code: "23505"}, false},
{"syntax_error", &pgconn.PgError{Code: "42601"}, false},
{"connection_refused", errors.New("connection refused"), true},
{"EOF", errors.New("unexpected EOF"), true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := IsRetryable(tt.err); got != tt.want {
t.Errorf("IsRetryable(%v) = %v, want %v", tt.err, got, tt.want)
}
})
}
}
func TestIsConnectionError(t *testing.T) {
tests := []struct {
name string
err error
want bool
}{
{"nil", nil, false},
{"pg_class_08", &pgconn.PgError{Code: "08003"}, true},
{"pg_class_23", &pgconn.PgError{Code: "23505"}, false},
{"conn_refused", errors.New("connection refused"), true},
{"broken_pipe", errors.New("write: broken pipe"), true},
{"generic", errors.New("something else"), false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := IsConnectionError(tt.err); got != tt.want {
t.Errorf("IsConnectionError(%v) = %v, want %v", tt.err, got, tt.want)
}
})
}
}
func TestIsConstraintViolation(t *testing.T) {
if IsConstraintViolation(nil) {
t.Error("expected false for nil")
}
if !IsConstraintViolation(&pgconn.PgError{Code: "23505"}) {
t.Error("expected true for unique_violation")
}
if IsConstraintViolation(&pgconn.PgError{Code: "42601"}) {
t.Error("expected false for syntax_error")
}
}
func TestPgErrorCode(t *testing.T) {
if code := PgErrorCode(errors.New("not pg")); code != "" {
t.Errorf("expected empty, got %q", code)
}
if code := PgErrorCode(&pgconn.PgError{Code: "42P01"}); code != "42P01" {
t.Errorf("expected 42P01, got %q", code)
}
}
func TestRetryError(t *testing.T) {
inner := errors.New("conn lost")
err := newRetryError(3, inner)
if !errors.Is(err, ErrRetryExhausted) {
t.Error("should unwrap to ErrRetryExhausted")
}
if !errors.Is(err, inner) {
t.Error("should unwrap to inner error")
}
if err.Error() == "" {
t.Error("error string should not be empty")
}
}

go.mod (new file)
@@ -0,0 +1,13 @@
module git.codelab.vc/pkg/dbx
go 1.25.7
require github.com/jackc/pgx/v5 v5.9.1
require (
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect
golang.org/x/sync v0.17.0 // indirect
golang.org/x/text v0.29.0 // indirect
)

go.sum (new file)
@@ -0,0 +1,26 @@
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.9.1 h1:uwrxJXBnx76nyISkhr33kQLlUqjv7et7b9FjCen/tdc=
github.com/jackc/pgx/v5 v5.9.1/go.mod h1:mal1tBGAFfLHvZzaYh77YS/eC6IX9OWbRV1QIIM0Jn4=
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk=
golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

health.go (new file)
@@ -0,0 +1,90 @@
package dbx
import (
"context"
"time"
)
// healthChecker periodically pings nodes and updates their health state.
type healthChecker struct {
nodes []*Node
cfg HealthCheckConfig
logger Logger
metrics *MetricsHook
stop chan struct{}
done chan struct{}
}
func newHealthChecker(nodes []*Node, cfg HealthCheckConfig, logger Logger, metrics *MetricsHook) *healthChecker {
return &healthChecker{
nodes: nodes,
cfg: cfg,
logger: logger,
metrics: metrics,
stop: make(chan struct{}),
done: make(chan struct{}),
}
}
func (h *healthChecker) start() {
go h.loop()
}
func (h *healthChecker) loop() {
defer close(h.done)
ticker := time.NewTicker(h.cfg.Interval)
defer ticker.Stop()
for {
select {
case <-h.stop:
return
case <-ticker.C:
h.checkAll()
}
}
}
func (h *healthChecker) checkAll() {
for _, node := range h.nodes {
h.checkNode(node)
}
}
func (h *healthChecker) checkNode(n *Node) {
ctx, cancel := context.WithTimeout(context.Background(), h.cfg.Timeout)
defer cancel()
err := n.pool.Ping(ctx)
wasHealthy := n.healthy.Load()
if err != nil {
n.healthy.Store(false)
if wasHealthy {
h.logger.Error(ctx, "dbx: node is down",
"node", n.name,
"error", err,
)
if h.metrics != nil && h.metrics.OnNodeDown != nil {
h.metrics.OnNodeDown(ctx, n.name, err)
}
}
return
}
n.healthy.Store(true)
if !wasHealthy {
h.logger.Info(ctx, "dbx: node is up",
"node", n.name,
)
if h.metrics != nil && h.metrics.OnNodeUp != nil {
h.metrics.OnNodeUp(ctx, n.name)
}
}
}
func (h *healthChecker) shutdown() {
close(h.stop)
<-h.done
}

health_test.go (new file)
@@ -0,0 +1,56 @@
package dbx
import (
"context"
"sync/atomic"
"testing"
"time"
)
func TestHealthChecker_StartStop(t *testing.T) {
nodes := makeTestNodes("a", "b")
hc := newHealthChecker(nodes, HealthCheckConfig{
Interval: 100 * time.Millisecond,
Timeout: 50 * time.Millisecond,
}, nopLogger{}, nil)
hc.start()
// Just verify it can be stopped without deadlock
hc.shutdown()
}
func TestHealthChecker_MetricsCallbacks(t *testing.T) {
var downCalled, upCalled atomic.Int32
metrics := &MetricsHook{
OnNodeDown: func(_ context.Context, node string, err error) {
downCalled.Add(1)
},
OnNodeUp: func(_ context.Context, node string) {
upCalled.Add(1)
},
}
node := &Node{name: "test"}
node.healthy.Store(true)
hc := newHealthChecker([]*Node{node}, HealthCheckConfig{
Interval: time.Hour, // won't tick in test
Timeout: 50 * time.Millisecond,
}, nopLogger{}, metrics)
// checkNode would panic on a nil pool, so callback firing is exercised
// in integration tests with real pools; here we only verify construction
// and the healthy-flag state transitions.
_ = hc
// Verify state transitions
node.healthy.Store(false)
if node.IsHealthy() {
t.Error("expected unhealthy")
}
node.healthy.Store(true)
if !node.IsHealthy() {
t.Error("expected healthy")
}
}

node.go (new file)
@@ -0,0 +1,113 @@
package dbx
import (
"context"
"sync/atomic"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgxpool"
)
// Node wraps a pgxpool.Pool with health state and a human-readable name.
type Node struct {
name string
pool *pgxpool.Pool
healthy atomic.Bool
}
// newNode creates a Node from an existing pool.
func newNode(name string, pool *pgxpool.Pool) *Node {
n := &Node{
name: name,
pool: pool,
}
n.healthy.Store(true)
return n
}
// connectNode parses the NodeConfig, creates a pgxpool.Pool, and returns a Node.
func connectNode(ctx context.Context, cfg NodeConfig) (*Node, error) {
poolCfg, err := pgxpool.ParseConfig(cfg.DSN)
if err != nil {
return nil, err
}
applyPoolConfig(poolCfg, cfg.Pool)
if cfg.Tracer != nil {
poolCfg.ConnConfig.Tracer = cfg.Tracer
}
pool, err := pgxpool.NewWithConfig(ctx, poolCfg)
if err != nil {
return nil, err
}
return newNode(cfg.Name, pool), nil
}
func applyPoolConfig(dst *pgxpool.Config, src PoolConfig) {
if src.MaxConns > 0 {
dst.MaxConns = src.MaxConns
}
if src.MinConns > 0 {
dst.MinConns = src.MinConns
}
if src.MaxConnLifetime > 0 {
dst.MaxConnLifetime = src.MaxConnLifetime
}
if src.MaxConnIdleTime > 0 {
dst.MaxConnIdleTime = src.MaxConnIdleTime
}
if src.HealthCheckPeriod > 0 {
dst.HealthCheckPeriod = src.HealthCheckPeriod
}
}
// Name returns the node's human-readable name.
func (n *Node) Name() string { return n.name }
// IsHealthy reports whether the node is considered healthy.
func (n *Node) IsHealthy() bool { return n.healthy.Load() }
// Pool returns the underlying pgxpool.Pool.
func (n *Node) Pool() *pgxpool.Pool { return n.pool }
// --- DB interface implementation ---
func (n *Node) Exec(ctx context.Context, sql string, args ...any) (pgconn.CommandTag, error) {
return n.pool.Exec(ctx, sql, args...)
}
func (n *Node) Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) {
return n.pool.Query(ctx, sql, args...)
}
func (n *Node) QueryRow(ctx context.Context, sql string, args ...any) pgx.Row {
return n.pool.QueryRow(ctx, sql, args...)
}
func (n *Node) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults {
return n.pool.SendBatch(ctx, b)
}
func (n *Node) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) {
return n.pool.CopyFrom(ctx, tableName, columnNames, rowSrc)
}
func (n *Node) Begin(ctx context.Context) (pgx.Tx, error) {
return n.pool.Begin(ctx)
}
func (n *Node) BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, error) {
return n.pool.BeginTx(ctx, txOptions)
}
func (n *Node) Ping(ctx context.Context) error {
return n.pool.Ping(ctx)
}
func (n *Node) Close() {
n.pool.Close()
}
// Compile-time check that *Node implements DB.
var _ DB = (*Node)(nil)

options.go (new file)
@@ -0,0 +1,64 @@
package dbx
import (
"time"
"github.com/jackc/pgx/v5"
)
// Option is a functional option for NewCluster.
type Option func(*Config)
// WithLogger sets the logger for the cluster.
func WithLogger(l Logger) Option {
return func(c *Config) {
c.Logger = l
}
}
// WithMetrics sets the metrics hook for the cluster.
func WithMetrics(m *MetricsHook) Option {
return func(c *Config) {
c.Metrics = m
}
}
// WithRetry overrides the retry configuration.
func WithRetry(r RetryConfig) Option {
return func(c *Config) {
c.Retry = r
}
}
// WithHealthCheck overrides the health check configuration.
func WithHealthCheck(h HealthCheckConfig) Option {
return func(c *Config) {
c.HealthCheck = h
}
}
// WithSlowQueryThreshold sets the threshold for slow query warnings.
// Queries taking longer than d will be logged at Warn level.
func WithSlowQueryThreshold(d time.Duration) Option {
return func(c *Config) {
c.SlowQueryThreshold = d
}
}
// WithTracer sets the pgx.QueryTracer on master and all replica configs.
// This enables OpenTelemetry integration via libraries like otelpgx.
func WithTracer(t pgx.QueryTracer) Option {
return func(c *Config) {
c.Master.Tracer = t
for i := range c.Replicas {
c.Replicas[i].Tracer = t
}
}
}
// ApplyOptions applies functional options to a Config.
func ApplyOptions(cfg *Config, opts ...Option) {
for _, o := range opts {
o(cfg)
}
}

options_test.go (new file)
@@ -0,0 +1,35 @@
package dbx
import (
"testing"
"time"
)
func TestApplyOptions(t *testing.T) {
cfg := Config{}
logger := nopLogger{}
metrics := &MetricsHook{}
retry := RetryConfig{MaxAttempts: 7}
hc := HealthCheckConfig{Interval: 10 * time.Second}
ApplyOptions(&cfg,
WithLogger(logger),
WithMetrics(metrics),
WithRetry(retry),
WithHealthCheck(hc),
)
if cfg.Logger == nil {
t.Error("Logger should be set")
}
if cfg.Metrics == nil {
t.Error("Metrics should be set")
}
if cfg.Retry.MaxAttempts != 7 {
t.Errorf("Retry.MaxAttempts = %d, want 7", cfg.Retry.MaxAttempts)
}
if cfg.HealthCheck.Interval != 10*time.Second {
t.Errorf("HealthCheck.Interval = %v, want 10s", cfg.HealthCheck.Interval)
}
}

retry.go (new file)
@@ -0,0 +1,98 @@
package dbx
import (
"context"
"math"
"math/rand/v2"
"time"
)
// retrier executes operations with retry and node fallback.
type retrier struct {
cfg RetryConfig
logger Logger
metrics *MetricsHook
}
func newRetrier(cfg RetryConfig, logger Logger, metrics *MetricsHook) *retrier {
return &retrier{cfg: cfg, logger: logger, metrics: metrics}
}
// isRetryable checks the custom classifier first, then falls back to the default.
func (r *retrier) isRetryable(err error) bool {
if r.cfg.RetryableErrors != nil {
return r.cfg.RetryableErrors(err)
}
return IsRetryable(err)
}
// do executes fn on the given nodes in order, retrying on retryable errors.
// For writes, pass a single-element slice with the master.
// For reads, pass [replicas..., master] for fallback.
func (r *retrier) do(ctx context.Context, nodes []*Node, fn func(ctx context.Context, n *Node) error) error {
var lastErr error
for attempt := 0; attempt < r.cfg.MaxAttempts; attempt++ {
if ctx.Err() != nil {
if lastErr != nil {
return lastErr
}
return ctx.Err()
}
for _, node := range nodes {
if !node.IsHealthy() {
continue
}
err := fn(ctx, node)
if err == nil {
return nil
}
lastErr = err
if !r.isRetryable(err) {
return err
}
if r.metrics != nil && r.metrics.OnRetry != nil {
r.metrics.OnRetry(ctx, node.name, attempt+1, err)
}
r.logger.Warn(ctx, "dbx: retryable error",
"node", node.name,
"attempt", attempt+1,
"error", err,
)
}
if attempt < r.cfg.MaxAttempts-1 {
delay := r.backoff(attempt)
t := time.NewTimer(delay)
select {
case <-ctx.Done():
t.Stop()
if lastErr != nil {
return lastErr
}
return ctx.Err()
case <-t.C:
}
}
}
if lastErr == nil {
return ErrNoHealthyNode
}
return newRetryError(r.cfg.MaxAttempts, lastErr)
}
// backoff returns the delay for the given attempt with jitter.
func (r *retrier) backoff(attempt int) time.Duration {
delay := float64(r.cfg.BaseDelay) * math.Pow(2, float64(attempt))
if delay > float64(r.cfg.MaxDelay) {
delay = float64(r.cfg.MaxDelay)
}
// add jitter: 75%-125% of computed delay
jitter := 0.75 + rand.Float64()*0.5
return time.Duration(delay * jitter)
}

retry_test.go (new file)
@@ -0,0 +1,168 @@
package dbx
import (
"context"
"errors"
"testing"
"time"
)
func TestRetrier_Success(t *testing.T) {
r := newRetrier(RetryConfig{MaxAttempts: 3, BaseDelay: time.Millisecond, MaxDelay: 10 * time.Millisecond}, nopLogger{}, nil)
nodes := makeTestNodes("n1")
calls := 0
err := r.do(context.Background(), nodes, func(_ context.Context, n *Node) error {
calls++
return nil
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if calls != 1 {
t.Errorf("calls = %d, want 1", calls)
}
}
func TestRetrier_RetriesOnRetryableError(t *testing.T) {
r := newRetrier(RetryConfig{MaxAttempts: 3, BaseDelay: time.Millisecond, MaxDelay: 10 * time.Millisecond}, nopLogger{}, nil)
nodes := makeTestNodes("n1")
calls := 0
retryableErr := errors.New("connection refused")
err := r.do(context.Background(), nodes, func(_ context.Context, n *Node) error {
calls++
if calls < 3 {
return retryableErr
}
return nil
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if calls != 3 {
t.Errorf("calls = %d, want 3", calls)
}
}
func TestRetrier_NonRetryableError(t *testing.T) {
r := newRetrier(RetryConfig{MaxAttempts: 3, BaseDelay: time.Millisecond, MaxDelay: 10 * time.Millisecond}, nopLogger{}, nil)
nodes := makeTestNodes("n1")
syntaxErr := errors.New("syntax problem")
calls := 0
err := r.do(context.Background(), nodes, func(_ context.Context, n *Node) error {
calls++
return syntaxErr
})
if !errors.Is(err, syntaxErr) {
t.Errorf("expected syntax error, got %v", err)
}
if calls != 1 {
t.Errorf("calls = %d, want 1 (should not retry non-retryable)", calls)
}
}
func TestRetrier_Exhausted(t *testing.T) {
r := newRetrier(RetryConfig{MaxAttempts: 2, BaseDelay: time.Millisecond, MaxDelay: 10 * time.Millisecond}, nopLogger{}, nil)
nodes := makeTestNodes("n1")
err := r.do(context.Background(), nodes, func(_ context.Context, n *Node) error {
return errors.New("connection refused")
})
if !errors.Is(err, ErrRetryExhausted) {
t.Errorf("expected ErrRetryExhausted, got %v", err)
}
}
func TestRetrier_FallbackToNextNode(t *testing.T) {
r := newRetrier(RetryConfig{MaxAttempts: 2, BaseDelay: time.Millisecond, MaxDelay: 10 * time.Millisecond}, nopLogger{}, nil)
nodes := makeTestNodes("replica-1", "master")
visited := []string{}
err := r.do(context.Background(), nodes, func(_ context.Context, n *Node) error {
visited = append(visited, n.name)
if n.name == "replica-1" {
return errors.New("connection refused")
}
return nil
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(visited) < 2 || visited[1] != "master" {
t.Errorf("expected fallback to master, visited: %v", visited)
}
}
func TestRetrier_ContextCanceled(t *testing.T) {
r := newRetrier(RetryConfig{MaxAttempts: 5, BaseDelay: time.Millisecond, MaxDelay: 10 * time.Millisecond}, nopLogger{}, nil)
nodes := makeTestNodes("n1")
ctx, cancel := context.WithCancel(context.Background())
cancel()
err := r.do(ctx, nodes, func(_ context.Context, n *Node) error {
return nil
})
if !errors.Is(err, context.Canceled) {
t.Errorf("expected context.Canceled, got %v", err)
}
}
func TestRetrier_NoHealthyNodes(t *testing.T) {
r := newRetrier(RetryConfig{MaxAttempts: 2, BaseDelay: time.Millisecond, MaxDelay: 10 * time.Millisecond}, nopLogger{}, nil)
nodes := makeTestNodes("n1")
nodes[0].healthy.Store(false)
err := r.do(context.Background(), nodes, func(_ context.Context, n *Node) error {
t.Fatal("should not be called")
return nil
})
if !errors.Is(err, ErrNoHealthyNode) {
t.Errorf("expected ErrNoHealthyNode, got %v", err)
}
}
func TestRetrier_CustomClassifier(t *testing.T) {
custom := func(err error) bool {
return err.Error() == "custom-retry"
}
r := newRetrier(RetryConfig{MaxAttempts: 3, BaseDelay: time.Millisecond, MaxDelay: 10 * time.Millisecond, RetryableErrors: custom}, nopLogger{}, nil)
nodes := makeTestNodes("n1")
calls := 0
err := r.do(context.Background(), nodes, func(_ context.Context, n *Node) error {
calls++
if calls < 2 {
return errors.New("custom-retry")
}
return nil
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if calls != 2 {
t.Errorf("calls = %d, want 2", calls)
}
}
func TestBackoff(t *testing.T) {
r := newRetrier(RetryConfig{BaseDelay: 100 * time.Millisecond, MaxDelay: time.Second}, nopLogger{}, nil)
d0 := r.backoff(0)
if d0 < 50*time.Millisecond || d0 > 150*time.Millisecond {
t.Errorf("backoff(0) = %v, expected ~100ms", d0)
}
d3 := r.backoff(3)
if d3 < 600*time.Millisecond || d3 > 1100*time.Millisecond {
t.Errorf("backoff(3) = %v, expected ~800ms", d3)
}
// Should cap at MaxDelay
d10 := r.backoff(10)
if d10 > 1250*time.Millisecond {
t.Errorf("backoff(10) = %v, should be capped near 1s", d10)
}
}

scan.go (new file)
@@ -0,0 +1,28 @@
package dbx
import (
"context"
"github.com/jackc/pgx/v5"
)
// Collect executes a read query and collects all rows into a slice of T
// using pgx.RowToStructByName. T must be a struct with db tags matching column names.
func Collect[T any](ctx context.Context, c *Cluster, sql string, args ...any) ([]T, error) {
rows, err := c.ReadQuery(ctx, sql, args...)
if err != nil {
return nil, err
}
return pgx.CollectRows(rows, pgx.RowToStructByName[T])
}
// CollectOne executes a read query and collects exactly one row into T.
// It returns pgx.ErrNoRows when no rows are returned and pgx.ErrTooManyRows
// when more than one row is returned.
func CollectOne[T any](ctx context.Context, c *Cluster, sql string, args ...any) (T, error) {
rows, err := c.ReadQuery(ctx, sql, args...)
if err != nil {
var zero T
return zero, err
}
return pgx.CollectExactlyOneRow(rows, pgx.RowToStructByName[T])
}

scan_test.go Normal file

@@ -0,0 +1,81 @@
package dbx_test

import (
	"context"
	"errors"
	"testing"

	"git.codelab.vc/pkg/dbx"
	"git.codelab.vc/pkg/dbx/dbxtest"
	"github.com/jackc/pgx/v5"
)

type scanRow struct {
	ID   int    `db:"id"`
	Name string `db:"name"`
}

func TestCollect(t *testing.T) {
	c := dbxtest.NewTestCluster(t)
	ctx := context.Background()
	_, err := c.Exec(ctx, `CREATE TEMPORARY TABLE test_collect (id int, name text)`)
	if err != nil {
		t.Fatal(err)
	}
	_, err = c.Exec(ctx, `INSERT INTO test_collect (id, name) VALUES (1, 'alice'), (2, 'bob')`)
	if err != nil {
		t.Fatal(err)
	}
	rows, err := dbx.Collect[scanRow](ctx, c, `SELECT id, name FROM test_collect ORDER BY id`)
	if err != nil {
		t.Fatal(err)
	}
	if len(rows) != 2 {
		t.Fatalf("expected 2 rows, got %d", len(rows))
	}
	if rows[0].ID != 1 || rows[0].Name != "alice" {
		t.Errorf("row 0: got %+v", rows[0])
	}
	if rows[1].ID != 2 || rows[1].Name != "bob" {
		t.Errorf("row 1: got %+v", rows[1])
	}
}

func TestCollectOne(t *testing.T) {
	c := dbxtest.NewTestCluster(t)
	ctx := context.Background()
	_, err := c.Exec(ctx, `CREATE TEMPORARY TABLE test_collect_one (id int, name text)`)
	if err != nil {
		t.Fatal(err)
	}
	_, err = c.Exec(ctx, `INSERT INTO test_collect_one (id, name) VALUES (1, 'alice')`)
	if err != nil {
		t.Fatal(err)
	}
	row, err := dbx.CollectOne[scanRow](ctx, c, `SELECT id, name FROM test_collect_one WHERE id = 1`)
	if err != nil {
		t.Fatal(err)
	}
	if row.ID != 1 || row.Name != "alice" {
		t.Errorf("got %+v", row)
	}
}

func TestCollectOneNoRows(t *testing.T) {
	c := dbxtest.NewTestCluster(t)
	ctx := context.Background()
	_, err := c.Exec(ctx, `CREATE TEMPORARY TABLE test_collect_norows (id int, name text)`)
	if err != nil {
		t.Fatal(err)
	}
	_, err = dbx.CollectOne[scanRow](ctx, c, `SELECT id, name FROM test_collect_norows`)
	if !errors.Is(err, pgx.ErrNoRows) {
		t.Errorf("expected pgx.ErrNoRows, got %v", err)
	}
}

slog.go Normal file

@@ -0,0 +1,38 @@
package dbx

import (
	"context"
	"log/slog"
)

// SlogLogger adapts *slog.Logger to the dbx.Logger interface.
type SlogLogger struct {
	Logger *slog.Logger
}

// NewSlogLogger creates a SlogLogger. If l is nil, slog.Default() is used.
func NewSlogLogger(l *slog.Logger) *SlogLogger {
	if l == nil {
		l = slog.Default()
	}
	return &SlogLogger{Logger: l}
}

func (s *SlogLogger) Debug(ctx context.Context, msg string, fields ...any) {
	s.Logger.DebugContext(ctx, msg, fields...)
}

func (s *SlogLogger) Info(ctx context.Context, msg string, fields ...any) {
	s.Logger.InfoContext(ctx, msg, fields...)
}

func (s *SlogLogger) Warn(ctx context.Context, msg string, fields ...any) {
	s.Logger.WarnContext(ctx, msg, fields...)
}

func (s *SlogLogger) Error(ctx context.Context, msg string, fields ...any) {
	s.Logger.ErrorContext(ctx, msg, fields...)
}

// Compile-time check.
var _ Logger = (*SlogLogger)(nil)

slog_test.go Normal file

@@ -0,0 +1,52 @@
package dbx

import (
	"bytes"
	"context"
	"log/slog"
	"strings"
	"testing"
)

func TestSlogLogger(t *testing.T) {
	var buf bytes.Buffer
	h := slog.NewTextHandler(&buf, &slog.HandlerOptions{Level: slog.LevelDebug})
	l := NewSlogLogger(slog.New(h))
	ctx := context.Background()
	l.Debug(ctx, "debug msg", "key", "val1")
	l.Info(ctx, "info msg", "key", "val2")
	l.Warn(ctx, "warn msg", "key", "val3")
	l.Error(ctx, "error msg", "key", "val4")
	out := buf.String()
	for _, want := range []string{
		"level=DEBUG",
		"debug msg",
		"key=val1",
		"level=INFO",
		"info msg",
		"key=val2",
		"level=WARN",
		"warn msg",
		"key=val3",
		"level=ERROR",
		"error msg",
		"key=val4",
	} {
		if !strings.Contains(out, want) {
			t.Errorf("output missing %q\ngot: %s", want, out)
		}
	}
}

func TestNewSlogLoggerNil(t *testing.T) {
	l := NewSlogLogger(nil)
	if l.Logger == nil {
		t.Fatal("expected non-nil logger when passing nil")
	}
	// should not panic
	l.Info(context.Background(), "test")
}

stats.go Normal file

@@ -0,0 +1,42 @@
package dbx

import (
	"time"

	"github.com/jackc/pgx/v5/pgxpool"
)

// PoolStats is an aggregate of pool statistics across all nodes.
type PoolStats struct {
	AcquireCount         int64
	AcquireDuration      time.Duration
	AcquiredConns        int32
	CanceledAcquireCount int64
	ConstructingConns    int32
	EmptyAcquireCount    int64
	IdleConns            int32
	MaxConns             int32
	TotalConns           int32
	Nodes                map[string]*pgxpool.Stat
}

// Stats returns aggregate pool statistics for all nodes in the cluster.
func (c *Cluster) Stats() PoolStats {
	ps := PoolStats{
		Nodes: make(map[string]*pgxpool.Stat, len(c.all)),
	}
	for _, n := range c.all {
		s := n.pool.Stat()
		ps.Nodes[n.name] = s
		ps.AcquireCount += s.AcquireCount()
		ps.AcquireDuration += s.AcquireDuration()
		ps.AcquiredConns += s.AcquiredConns()
		ps.CanceledAcquireCount += s.CanceledAcquireCount()
		ps.ConstructingConns += s.ConstructingConns()
		ps.EmptyAcquireCount += s.EmptyAcquireCount()
		ps.IdleConns += s.IdleConns()
		ps.MaxConns += s.MaxConns()
		ps.TotalConns += s.TotalConns()
	}
	return ps
}
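Stats() is a point-in-time aggregate: each node's pool is snapshotted once, stored in the Nodes map, and the counters summed into the cluster-wide totals. The shape of that aggregation, sketched with plain structs (`nodeStat` and `aggregate` are hypothetical stand-ins for `*pgxpool.Stat` and the real loop):

```go
package main

import "fmt"

// nodeStat stands in for *pgxpool.Stat in this sketch, with
// two representative counters.
type nodeStat struct {
	AcquireCount int64
	MaxConns     int32
}

// aggregate sums per-node snapshots into cluster-wide totals,
// mirroring the loop in Cluster.Stats. Illustration only.
func aggregate(nodes map[string]nodeStat) (acquires int64, maxConns int32) {
	for _, s := range nodes {
		acquires += s.AcquireCount
		maxConns += s.MaxConns
	}
	return acquires, maxConns
}

func main() {
	a, m := aggregate(map[string]nodeStat{
		"master":   {AcquireCount: 10, MaxConns: 4},
		"replica1": {AcquireCount: 7, MaxConns: 4},
	})
	fmt.Println(a, m) // summed across both nodes: 17 8
}
```

Because each node is snapshotted independently, the totals are not atomic across the cluster — counters may move between snapshots, which is acceptable for metrics export.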

stats_test.go Normal file

@@ -0,0 +1,22 @@
package dbx_test

import (
	"testing"

	"git.codelab.vc/pkg/dbx/dbxtest"
)

func TestStats(t *testing.T) {
	c := dbxtest.NewTestCluster(t)
	ps := c.Stats()
	if ps.Nodes == nil {
		t.Fatal("Nodes map is nil")
	}
	if _, ok := ps.Nodes["test-master"]; !ok {
		t.Error("expected test-master in Nodes")
	}
	if ps.MaxConns <= 0 {
		t.Errorf("expected MaxConns > 0, got %d", ps.MaxConns)
	}
}

tx.go Normal file

@@ -0,0 +1,61 @@
package dbx

import (
	"context"
	"fmt"

	"github.com/jackc/pgx/v5"
)

// TxFunc is the callback signature for RunTx.
type TxFunc func(ctx context.Context, tx pgx.Tx) error

// RunTx executes fn inside a transaction on the master node with default options.
// The transaction is committed if fn returns nil, rolled back otherwise.
// Panics inside fn are caught, the transaction is rolled back, and the panic is re-raised.
func (c *Cluster) RunTx(ctx context.Context, fn TxFunc) error {
	return c.RunTxOptions(ctx, pgx.TxOptions{}, fn)
}

// RunTxOptions executes fn inside a transaction with the given options.
func (c *Cluster) RunTxOptions(ctx context.Context, opts pgx.TxOptions, fn TxFunc) error {
	tx, err := c.master.BeginTx(ctx, opts)
	if err != nil {
		return fmt.Errorf("dbx: begin tx: %w", err)
	}
	return runTx(ctx, tx, fn)
}

func runTx(ctx context.Context, tx pgx.Tx, fn TxFunc) (retErr error) {
	defer func() {
		if p := recover(); p != nil {
			_ = tx.Rollback(ctx)
			panic(p)
		}
		if retErr != nil {
			_ = tx.Rollback(ctx)
		}
	}()
	if err := fn(ctx, tx); err != nil {
		return err
	}
	return tx.Commit(ctx)
}

// querierKey is the context key for Querier injection.
type querierKey struct{}

// InjectQuerier returns a new context carrying q.
func InjectQuerier(ctx context.Context, q Querier) context.Context {
	return context.WithValue(ctx, querierKey{}, q)
}

// ExtractQuerier returns the Querier from ctx, or fallback if none is present.
func ExtractQuerier(ctx context.Context, fallback Querier) Querier {
	if q, ok := ctx.Value(querierKey{}).(Querier); ok {
		return q
	}
	return fallback
}

tx_test.go Normal file

@@ -0,0 +1,25 @@
package dbx

import (
	"context"
	"testing"
)

func TestInjectExtractQuerier(t *testing.T) {
	ctx := context.Background()
	fallback := &Node{name: "fallback"}
	// No querier in context → returns fallback
	got := ExtractQuerier(ctx, fallback)
	if got != fallback {
		t.Error("expected fallback when no querier in context")
	}
	// Inject querier → extract it
	injected := &Node{name: "injected"}
	ctx = InjectQuerier(ctx, injected)
	got = ExtractQuerier(ctx, fallback)
	if got != injected {
		t.Error("expected injected querier")
	}
}