package dbx

import (
	"context"
	"sync/atomic"
	"time"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgconn"
)

// Cluster manages a master node and optional replicas.
// Write operations go to master; read operations are distributed across replicas
// with automatic fallback to master.
type Cluster struct {
	master   *Node
	replicas []*Node
	all      []*Node // master + replicas for health checker

	balancer Balancer
	retrier  *retrier
	health   *healthChecker

	logger             Logger
	metrics            *MetricsHook
	slowQueryThreshold time.Duration

	closed atomic.Bool
}

// NewCluster creates a Cluster, connecting to all configured nodes.
// If any node fails to connect, every pool opened so far is closed
// before the error is returned.
func NewCluster(ctx context.Context, cfg Config) (*Cluster, error) {
	cfg.defaults()

	primary, err := connectNode(ctx, cfg.Master)
	if err != nil {
		return nil, err
	}

	standbys := make([]*Node, 0, len(cfg.Replicas))
	for _, rc := range cfg.Replicas {
		node, err := connectNode(ctx, rc)
		if err != nil {
			// Tear down everything connected so far before bailing out.
			primary.Close()
			for _, open := range standbys {
				open.Close()
			}
			return nil, err
		}
		standbys = append(standbys, node)
	}

	nodes := make([]*Node, 0, 1+len(standbys))
	nodes = append(nodes, primary)
	nodes = append(nodes, standbys...)

	c := &Cluster{
		master:             primary,
		replicas:           standbys,
		all:                nodes,
		balancer:           NewRoundRobinBalancer(),
		retrier:            newRetrier(cfg.Retry, cfg.Logger, cfg.Metrics),
		logger:             cfg.Logger,
		metrics:            cfg.Metrics,
		slowQueryThreshold: cfg.SlowQueryThreshold,
	}
	c.health = newHealthChecker(nodes, cfg.HealthCheck, cfg.Logger, cfg.Metrics)
	c.health.start()
	return c, nil
}

// Master returns the master node as a DB.
func (c *Cluster) Master() DB {
	return c.master
}

// Replica returns a healthy replica as a DB, or falls back to master.
func (c *Cluster) Replica() DB {
	if n := c.balancer.Next(c.replicas); n != nil {
		return n
	}
	return c.master
}

// Close shuts down the health checker and closes all connection pools.
func (c *Cluster) Close() { if !c.closed.CompareAndSwap(false, true) { return } c.health.shutdown() for _, n := range c.all { n.Close() } } // Ping checks connectivity to the master node. func (c *Cluster) Ping(ctx context.Context) error { if c.closed.Load() { return ErrClusterClosed } return c.master.Ping(ctx) } // --- Write operations (always master) --- func (c *Cluster) Exec(ctx context.Context, sql string, args ...any) (pgconn.CommandTag, error) { if c.closed.Load() { return pgconn.CommandTag{}, ErrClusterClosed } var tag pgconn.CommandTag err := c.retrier.do(ctx, []*Node{c.master}, func(ctx context.Context, n *Node) error { start := time.Now() c.queryStart(ctx, n.name, sql) var e error tag, e = n.Exec(ctx, sql, args...) c.queryEnd(ctx, n.name, sql, e, time.Since(start)) return e }) return tag, err } func (c *Cluster) Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) { if c.closed.Load() { return nil, ErrClusterClosed } var rows pgx.Rows err := c.retrier.do(ctx, []*Node{c.master}, func(ctx context.Context, n *Node) error { start := time.Now() c.queryStart(ctx, n.name, sql) var e error rows, e = n.Query(ctx, sql, args...) c.queryEnd(ctx, n.name, sql, e, time.Since(start)) return e }) return rows, err } func (c *Cluster) QueryRow(ctx context.Context, sql string, args ...any) pgx.Row { if c.closed.Load() { return errRow{err: ErrClusterClosed} } var row pgx.Row _ = c.retrier.do(ctx, []*Node{c.master}, func(ctx context.Context, n *Node) error { start := time.Now() c.queryStart(ctx, n.name, sql) row = n.QueryRow(ctx, sql, args...) 
c.queryEnd(ctx, n.name, sql, nil, time.Since(start)) return nil }) return row } func (c *Cluster) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults { return c.master.SendBatch(ctx, b) } func (c *Cluster) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) { if c.closed.Load() { return 0, ErrClusterClosed } return c.master.CopyFrom(ctx, tableName, columnNames, rowSrc) } func (c *Cluster) Begin(ctx context.Context) (pgx.Tx, error) { if c.closed.Load() { return nil, ErrClusterClosed } return c.master.Begin(ctx) } func (c *Cluster) BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, error) { if c.closed.Load() { return nil, ErrClusterClosed } return c.master.BeginTx(ctx, txOptions) } // --- Read operations (replicas with master fallback) --- // ReadQuery executes a read query on a replica, falling back to master if no replicas are healthy. func (c *Cluster) ReadQuery(ctx context.Context, sql string, args ...any) (pgx.Rows, error) { if c.closed.Load() { return nil, ErrClusterClosed } nodes := c.readNodes() var rows pgx.Rows err := c.retrier.do(ctx, nodes, func(ctx context.Context, n *Node) error { start := time.Now() c.queryStart(ctx, n.name, sql) var e error rows, e = n.Query(ctx, sql, args...) c.queryEnd(ctx, n.name, sql, e, time.Since(start)) return e }) return rows, err } // ReadQueryRow executes a read query returning a single row, using replicas with master fallback. func (c *Cluster) ReadQueryRow(ctx context.Context, sql string, args ...any) pgx.Row { if c.closed.Load() { return errRow{err: ErrClusterClosed} } nodes := c.readNodes() var row pgx.Row _ = c.retrier.do(ctx, nodes, func(ctx context.Context, n *Node) error { start := time.Now() c.queryStart(ctx, n.name, sql) row = n.QueryRow(ctx, sql, args...) c.queryEnd(ctx, n.name, sql, nil, time.Since(start)) return nil }) return row } // readNodes returns replicas followed by master for fallback ordering. 
func (c *Cluster) readNodes() []*Node { if len(c.replicas) == 0 { return []*Node{c.master} } nodes := make([]*Node, 0, len(c.replicas)+1) nodes = append(nodes, c.replicas...) nodes = append(nodes, c.master) return nodes } // --- metrics helpers --- func (c *Cluster) queryStart(ctx context.Context, node, sql string) { if c.metrics != nil && c.metrics.OnQueryStart != nil { c.metrics.OnQueryStart(ctx, node, sql) } } func (c *Cluster) queryEnd(ctx context.Context, node, sql string, err error, d time.Duration) { if c.metrics != nil && c.metrics.OnQueryEnd != nil { c.metrics.OnQueryEnd(ctx, node, sql, err, d) } if c.slowQueryThreshold > 0 && d >= c.slowQueryThreshold { c.logger.Warn(ctx, "dbx: slow query", "node", node, "duration", d, "sql", sql, ) } } // errRow implements pgx.Row for error cases. type errRow struct { err error } func (r errRow) Scan(...any) error { return r.err } // Compile-time check that *Cluster implements DB. var _ DB = (*Cluster)(nil)