Files
dbx/tx.go

62 lines
1.6 KiB
Go

package dbx
import (
"context"
"fmt"
"github.com/jackc/pgx/v5"
)
// TxFunc is the callback signature accepted by RunTx and RunTxOptions.
// It is called with the caller's context and the open transaction.
// Returning a non-nil error causes the transaction to be rolled back
// instead of committed.
type TxFunc func(ctx context.Context, tx pgx.Tx) error
// RunTx runs fn inside a transaction opened on the master node, using the
// zero-value pgx.TxOptions (i.e. the default transaction options).
//
// The transaction is committed when fn returns nil and rolled back when it
// returns an error. If fn panics, the transaction is rolled back and the
// panic reaches the caller.
func (c *Cluster) RunTx(ctx context.Context, fn TxFunc) error {
	var defaults pgx.TxOptions // zero value selects the default options
	return c.RunTxOptions(ctx, defaults, fn)
}
// RunTxOptions begins a transaction on the master node with opts and runs fn
// inside it.
//
// If the transaction cannot be started, fn is never called and the failure is
// returned wrapped with a "dbx: begin tx" prefix. Otherwise the result of
// running fn (and finishing the transaction) is returned as-is.
func (c *Cluster) RunTxOptions(ctx context.Context, opts pgx.TxOptions, fn TxFunc) error {
	tx, beginErr := c.master.BeginTx(ctx, opts)
	if beginErr != nil {
		return fmt.Errorf("dbx: begin tx: %w", beginErr)
	}

	// From here on runTx is responsible for committing or rolling back tx.
	return runTx(ctx, tx, fn)
}
// runTx invokes fn with tx and then finishes the transaction: it commits when
// fn returns nil and rolls back in every other case.
//
// The rollback is driven by a "committed" flag rather than by recover() or the
// returned error, so it also runs when fn leaves without returning normally:
// a panic (including panic(nil) on toolchains where recover reports nil for
// it) or runtime.Goexit, which testing.T.FailNow relies on. Without this the
// transaction would be left open in those cases. A panic is not intercepted;
// it keeps unwinding to the caller once the rollback has been attempted.
//
// It returns fn's error unchanged, or the error reported by tx.Commit.
func runTx(ctx context.Context, tx pgx.Tx, fn TxFunc) error {
	committed := false
	defer func() {
		if committed {
			return
		}
		// Best effort: the caller already receives the primary failure (fn's
		// error, the commit error, or the in-flight panic), so the rollback
		// error is intentionally discarded.
		_ = tx.Rollback(ctx)
	}()

	if err := fn(ctx, tx); err != nil {
		return err
	}
	if err := tx.Commit(ctx); err != nil {
		return err
	}
	committed = true
	return nil
}
// querierKey is the unexported context key under which InjectQuerier stores a
// Querier and ExtractQuerier looks it up. Because it is a distinct type
// private to this package, it cannot collide with context keys defined
// elsewhere.
type querierKey struct{}
// InjectQuerier derives a child context from ctx that carries q under this
// package's private key. The original ctx is not modified.
func InjectQuerier(ctx context.Context, q Querier) context.Context {
	key := querierKey{}
	return context.WithValue(ctx, key, q)
}
// ExtractQuerier looks up the Querier stored in ctx under this package's
// private key and returns it. When ctx holds no such value, or the stored
// value is a nil Querier, fallback is returned instead.
func ExtractQuerier(ctx context.Context, fallback Querier) Querier {
	q, ok := ctx.Value(querierKey{}).(Querier)
	if !ok {
		return fallback
	}
	return q
}