62 lines
1.6 KiB
Go
62 lines
1.6 KiB
Go
package dbx
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
|
|
"github.com/jackc/pgx/v5"
|
|
)
|
|
|
|
// TxFunc is the callback signature for RunTx.
|
|
type TxFunc func(ctx context.Context, tx pgx.Tx) error
|
|
|
|
// RunTx executes fn inside a transaction on the master node with default options.
|
|
// The transaction is committed if fn returns nil, rolled back otherwise.
|
|
// Panics inside fn are caught, the transaction is rolled back, and the panic is re-raised.
|
|
func (c *Cluster) RunTx(ctx context.Context, fn TxFunc) error {
|
|
return c.RunTxOptions(ctx, pgx.TxOptions{}, fn)
|
|
}
|
|
|
|
// RunTxOptions executes fn inside a transaction with the given options.
|
|
func (c *Cluster) RunTxOptions(ctx context.Context, opts pgx.TxOptions, fn TxFunc) error {
|
|
tx, err := c.master.BeginTx(ctx, opts)
|
|
if err != nil {
|
|
return fmt.Errorf("dbx: begin tx: %w", err)
|
|
}
|
|
|
|
return runTx(ctx, tx, fn)
|
|
}
|
|
|
|
func runTx(ctx context.Context, tx pgx.Tx, fn TxFunc) (retErr error) {
|
|
defer func() {
|
|
if p := recover(); p != nil {
|
|
_ = tx.Rollback(ctx)
|
|
panic(p)
|
|
}
|
|
if retErr != nil {
|
|
_ = tx.Rollback(ctx)
|
|
}
|
|
}()
|
|
|
|
if err := fn(ctx, tx); err != nil {
|
|
return err
|
|
}
|
|
return tx.Commit(ctx)
|
|
}
|
|
|
|
// querierKey is the unexported context key under which InjectQuerier stores a
// Querier. Being an unexported struct type, it cannot collide with keys
// defined by other packages.
type querierKey struct{}
|
|
|
|
// InjectQuerier returns a new context carrying q.
|
|
func InjectQuerier(ctx context.Context, q Querier) context.Context {
|
|
return context.WithValue(ctx, querierKey{}, q)
|
|
}
|
|
|
|
// ExtractQuerier returns the Querier from ctx, or fallback if none is present.
|
|
func ExtractQuerier(ctx context.Context, fallback Querier) Querier {
|
|
if q, ok := ctx.Value(querierKey{}).(Querier); ok {
|
|
return q
|
|
}
|
|
return fallback
|
|
}
|