You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
zk/internal/adapter/sqlite/transaction.go

69 lines
1.7 KiB
Go

package sqlite
import "database/sql"
// Inspired by https://pseudomuto.com/2018/01/clean-sql-transactions-in-golang/

// Transaction is the subset of a database/sql transaction that a TxFn may use
// to run statements and queries.
//
// Commit and Rollback are intentionally not part of this interface: ending the
// transaction is the job of WithTransaction, so a TxFn cannot commit or roll
// back through it.
type Transaction interface {
	// Query has the same signature as (*sql.Tx).Query.
	Query(query string, args ...interface{}) (*sql.Rows, error)
	// QueryRow has the same signature as (*sql.Tx).QueryRow.
	QueryRow(query string, args ...interface{}) *sql.Row

	// Exec has the same signature as (*sql.Tx).Exec.
	Exec(query string, args ...interface{}) (sql.Result, error)
	// ExecStmts executes each statement in order, stopping at the first error.
	ExecStmts(stmts []string) error

	// Prepare has the same signature as (*sql.Tx).Prepare.
	Prepare(query string) (*sql.Stmt, error)
	// PrepareLazy returns a *LazyStmt for query; see LazyStmt for when the
	// underlying statement is actually prepared.
	PrepareLazy(query string) *LazyStmt
}
// txWrapper adapts a native *sql.Tx to the Transaction interface by adding the
// PrepareLazy and ExecStmts helpers. Exec, Prepare, Query and QueryRow are
// promoted from the embedded *sql.Tx.
type txWrapper struct {
	*sql.Tx
}

// PrepareLazy returns a LazyStmt for query bound to this transaction.
// NOTE(review): when the statement is really prepared is decided by
// NewLazyStmt (presumably on first use) — confirm in its implementation.
func (tx *txWrapper) PrepareLazy(query string) *LazyStmt {
	return NewLazyStmt(tx.Tx, query)
}

// ExecStmts executes stmts in order within the transaction.
//
// It stops at the first statement that fails and returns that error
// unchanged; later statements are not executed. A nil or empty slice is a
// no-op that returns nil.
func (tx *txWrapper) ExecStmts(stmts []string) error {
	for _, stmt := range stmts {
		if _, err := tx.Exec(stmt); err != nil {
			return err
		}
	}
	return nil
}
// TxFn is a unit of work executed inside a transaction. WithTransaction calls
// it with a Transaction for running statements and queries; returning a
// non-nil error makes WithTransaction roll the transaction back instead of
// committing it.
type TxFn func(Transaction) error
// WithTransaction begins a new transaction, runs fn inside it, and then ends
// the transaction based on fn's outcome:
//
//   - fn returns nil: the transaction is committed and any Commit error is
//     returned to the caller.
//   - fn returns an error: the transaction is rolled back and fn's error is
//     returned.
//   - fn panics: the transaction is rolled back and the panic is re-raised.
//
// If the transaction cannot be started, that error is returned and fn is not
// called.
func (db *DB) WithTransaction(fn TxFn) (err error) {
	tx, err := db.db.Begin()
	if err != nil {
		return err
	}

	// err is a named result so this deferred function can both observe fn's
	// outcome and report a Commit failure. With an unnamed result, the value
	// produced by `return` is fixed before the defer runs, and a failed commit
	// would be silently reported as success.
	defer func() {
		if p := recover(); p != nil {
			// Best-effort rollback; the panic is the failure to surface.
			_ = tx.Rollback()
			panic(p)
		}
		if err != nil {
			// Best-effort rollback; fn's error is what the caller needs.
			_ = tx.Rollback()
			return
		}
		err = tx.Commit()
	}()

	return fn(&txWrapper{tx})
}