Refactor transaction usage

pull/6/head
Mickaël Menu 4 years ago
parent 4ca1595f1f
commit 8467f1aa3a
No known key found for this signature in database
GPG Key ID: 53D73664CD359895

@ -9,7 +9,7 @@ import (
// DB holds the connections to a SQLite database. // DB holds the connections to a SQLite database.
type DB struct { type DB struct {
*sql.DB db *sql.DB
} }
// Open creates a new DB instance for the SQLite database at the given path. // Open creates a new DB instance for the SQLite database at the given path.
@ -23,28 +23,21 @@ func Open(path string) (*DB, error) {
// Close terminates the connections to the SQLite database. // Close terminates the connections to the SQLite database.
func (db *DB) Close() error { func (db *DB) Close() error {
err := db.Close() err := db.db.Close()
return errors.Wrap(err, "failed to close the database") return errors.Wrap(err, "failed to close the database")
} }
// Migrate upgrades the SQL schema of the database. // Migrate upgrades the SQL schema of the database.
func (db *DB) Migrate() error { func (db *DB) Migrate() error {
wrap := errors.Wrapper("database migration failed") err := db.WithTransaction(func(tx Transaction) error {
tx, err := db.Begin()
if err != nil {
return wrap(err)
}
defer tx.Rollback()
var version int var version int
err = tx.QueryRow("PRAGMA user_version").Scan(&version) err := tx.QueryRow("PRAGMA user_version").Scan(&version)
if err != nil { if err != nil {
return wrap(err) return err
} }
if version == 0 { if version == 0 {
err = execMultiple(tx, []string{ err = tx.ExecStmts([]string{
` `
CREATE TABLE IF NOT EXISTS notes ( CREATE TABLE IF NOT EXISTS notes (
id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
@ -87,26 +80,14 @@ func (db *DB) Migrate() error {
`, `,
`PRAGMA user_version = 1`, `PRAGMA user_version = 1`,
}) })
}
if err != nil {
return wrap(err)
}
err = tx.Commit()
if err != nil { if err != nil {
return wrap(err) return err
}
} }
return nil return nil
} })
func execMultiple(tx *sql.Tx, stmts []string) error { return errors.Wrap(err, "database migration failed")
var err error
for _, stmt := range stmts {
if err != nil {
break
}
_, err = tx.Exec(stmt)
}
return err
} }

@ -13,7 +13,7 @@ import (
// NoteIndexer retrieves and stores notes indexation in the SQLite database. // NoteIndexer retrieves and stores notes indexation in the SQLite database.
// It implements the Core port note.Indexer. // It implements the Core port note.Indexer.
type NoteIndexer struct { type NoteIndexer struct {
tx *sql.Tx tx Transaction
root string root string
logger util.Logger logger util.Logger
@ -24,7 +24,7 @@ type NoteIndexer struct {
removeStmt *sql.Stmt removeStmt *sql.Stmt
} }
func NewNoteIndexer(tx *sql.Tx, root string, logger util.Logger) (*NoteIndexer, error) { func NewNoteIndexer(tx Transaction, root string, logger util.Logger) (*NoteIndexer, error) {
indexedStmt, err := tx.Prepare(` indexedStmt, err := tx.Prepare(`
SELECT filename, dir, modified from notes SELECT filename, dir, modified from notes
ORDER BY dir, filename ASC ORDER BY dir, filename ASC

@ -0,0 +1,63 @@
package sqlite
import "database/sql"
// Inspired by https://pseudomuto.com/2018/01/clean-sql-transactions-in-golang/
// Transaction is an interface that models the standard transaction in
// database/sql.
//
// To ensure TxFn funcs cannot commit or rollback a transaction (which is
// handled by `WithTransaction`), those methods are not included here.
type Transaction interface {
Exec(query string, args ...interface{}) (sql.Result, error)
ExecStmts(stmts []string) error
Prepare(query string) (*sql.Stmt, error)
Query(query string, args ...interface{}) (*sql.Rows, error)
QueryRow(query string, args ...interface{}) *sql.Row
}
// txWrapper wraps a native sql.Tx to fully implement the Transaction interface.
type txWrapper struct {
*sql.Tx
}
func (tx *txWrapper) ExecStmts(stmts []string) error {
var err error
for _, stmt := range stmts {
if err != nil {
break
}
_, err = tx.Exec(stmt)
}
return err
}
// A Txfn is a function that will be called with an initialized Transaction
// object that can be used for executing statements and queries against a
// database.
type TxFn func(tx Transaction) error
// WithTransaction creates a new transaction and handles rollback/commit based
// on the error object returned by the TxFn closure.
func (db *DB) WithTransaction(fn TxFn) error {
tx, err := db.db.Begin()
if err != nil {
return err
}
defer func() {
if p := recover(); p != nil {
// A panic occurred, rollback and repanic.
tx.Rollback()
panic(p)
} else if err != nil {
tx.Rollback()
} else {
err = tx.Commit()
}
}()
err = fn(&txWrapper{tx})
return err
}

@ -26,19 +26,13 @@ func (cmd *Index) Run(container *Container) error {
if err != nil { if err != nil {
return err return err
} }
tx, err := db.Begin()
defer tx.Rollback() return db.WithTransaction(func(tx sqlite.Transaction) error {
if err != nil {
return err
}
indexer, err := sqlite.NewNoteIndexer(tx, zk.Path, container.Logger) indexer, err := sqlite.NewNoteIndexer(tx, zk.Path, container.Logger)
if err != nil { if err != nil {
return err return err
} }
err = note.Index(*dir, indexer, container.Logger)
if err != nil {
return err
}
return tx.Commit() return note.Index(*dir, indexer, container.Logger)
})
} }

Loading…
Cancel
Save