package dbx

import (
	"context"
	"fmt"

	"github.com/jackc/pgx/v5"
)

// TxFunc is the callback signature for RunTx and RunTxOptions. It receives the
// transaction to run its statements on. Returning a non-nil error causes the
// transaction to be rolled back; returning nil causes it to be committed.
// RunTx/RunTxOptions manage the transaction's lifetime, so fn should not
// commit or roll back tx itself.
type TxFunc func(ctx context.Context, tx pgx.Tx) error

// RunTx executes fn inside a transaction on the master node with default options.
// The transaction is committed if fn returns nil, rolled back otherwise.
// If fn panics, the transaction is rolled back and the panic continues to
// propagate to the caller.
func (c *Cluster) RunTx(ctx context.Context, fn TxFunc) error {
	return c.RunTxOptions(ctx, pgx.TxOptions{}, fn)
}

// RunTxOptions is like RunTx but begins the transaction on the master node
// with the given options. Errors from beginning or committing the transaction
// are wrapped with a "dbx:" prefix (use errors.Is/errors.As to inspect them);
// an error returned by fn is passed through unchanged.
func (c *Cluster) RunTxOptions(ctx context.Context, opts pgx.TxOptions, fn TxFunc) error {
	tx, err := c.master.BeginTx(ctx, opts)
	if err != nil {
		return fmt.Errorf("dbx: begin tx: %w", err)
	}
	return runTx(ctx, tx, fn)
}

// runTx runs fn on tx and commits if fn returns nil. On every other exit path
// (fn returning an error, Commit failing, fn panicking, or fn calling
// runtime.Goexit) tx is rolled back by the deferred cleanup.
//
// A commit flag is used instead of recover(): recover() returns nil during
// runtime.Goexit (and, before Go 1.21, for panic(nil)), so a recover-based
// cleanup would skip the rollback and leave the transaction open. Not
// recovering also lets a panic propagate as-is instead of being re-raised.
func runTx(ctx context.Context, tx pgx.Tx, fn TxFunc) error {
	committed := false
	defer func() {
		if committed {
			return
		}
		// Best-effort: the caller needs fn's error, the commit error, or the
		// in-flight panic, not a secondary rollback failure, so the rollback
		// error is deliberately discarded.
		// NOTE(review): the rollback reuses ctx; if ctx is already canceled the
		// rollback itself may fail. Consider context.WithoutCancel (Go 1.21+)
		// if the module's Go version allows it.
		_ = tx.Rollback(ctx)
	}()

	if err := fn(ctx, tx); err != nil {
		return err
	}
	if err := tx.Commit(ctx); err != nil {
		return fmt.Errorf("dbx: commit tx: %w", err)
	}
	committed = true
	return nil
}

// querierKey is the unexported context key for Querier injection; being an
// unexported type, it cannot collide with keys defined in other packages.
type querierKey struct{}

// InjectQuerier returns a new context derived from ctx that carries q.
// Injecting a nil q has no effect on ExtractQuerier, which then returns its
// fallback.
func InjectQuerier(ctx context.Context, q Querier) context.Context {
	return context.WithValue(ctx, querierKey{}, q)
}

// ExtractQuerier returns the Querier previously stored in ctx by
// InjectQuerier, or fallback if none is present.
func ExtractQuerier(ctx context.Context, fallback Querier) Querier {
	if q, ok := ctx.Value(querierKey{}).(Querier); ok {
		return q
	}
	return fallback
}