diff --git a/adapter/sqlite/db.go b/adapter/sqlite/db.go index 1188b32..f4e9e4f 100644 --- a/adapter/sqlite/db.go +++ b/adapter/sqlite/db.go @@ -9,7 +9,7 @@ import ( // DB holds the connections to a SQLite database. type DB struct { - *sql.DB + db *sql.DB } // Open creates a new DB instance for the SQLite database at the given path. @@ -23,90 +23,71 @@ func Open(path string) (*DB, error) { // Close terminates the connections to the SQLite database. func (db *DB) Close() error { - err := db.Close() + err := db.db.Close() return errors.Wrap(err, "failed to close the database") } // Migrate upgrades the SQL schema of the database. func (db *DB) Migrate() error { - wrap := errors.Wrapper("database migration failed") - - tx, err := db.Begin() - if err != nil { - return wrap(err) - } - defer tx.Rollback() - - var version int - err = tx.QueryRow("PRAGMA user_version").Scan(&version) - if err != nil { - return wrap(err) - } + err := db.WithTransaction(func(tx Transaction) error { + var version int + err := tx.QueryRow("PRAGMA user_version").Scan(&version) + if err != nil { + return err + } - if version == 0 { - err = execMultiple(tx, []string{ - ` - CREATE TABLE IF NOT EXISTS notes ( - id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, - filename TEXT NOT NULL, - dir TEXT NOT NULL, - title TEXT DEFAULT('') NOT NULL, - body TEXT DEFAULT('') NOT NULL, - word_count INTEGER DEFAULT(0) NOT NULL, - checksum TEXT NOT NULL, - created DATETIME DEFAULT(CURRENT_TIMESTAMP) NOT NULL, - modified DATETIME DEFAULT(CURRENT_TIMESTAMP) NOT NULL, - UNIQUE(filename, dir) - ) - `, - `CREATE INDEX IF NOT EXISTS notes_checksum_idx ON notes(checksum)`, - ` - CREATE VIRTUAL TABLE IF NOT EXISTS notes_fts USING fts5( - title, body, - content = notes, - content_rowid = id, - tokenize = 'porter unicode61 remove_diacritics 1' - ) - `, - // Triggers to keep the FTS index up to date. - ` - CREATE TRIGGER IF NOT EXISTS notes_ai AFTER INSERT ON notes BEGIN - INSERT INTO notes_fts(rowid, title, body) VALUES (new.id, new.title, new.body); - END - `, - ` - CREATE TRIGGER IF NOT EXISTS notes_ad AFTER DELETE ON notes BEGIN - INSERT INTO notes_fts(notes_fts, rowid, title, body) VALUES('delete', old.id, old.title, old.body); - END - `, - ` - CREATE TRIGGER IF NOT EXISTS notes_au AFTER UPDATE ON notes BEGIN - INSERT INTO notes_fts(notes_fts, rowid, title, body) VALUES('delete', old.id, old.title, old.body); - INSERT INTO notes_fts(rowid, title, body) VALUES (new.id, new.title, new.body); - END - `, - `PRAGMA user_version = 1`, - }) - } - if err != nil { - return wrap(err) - } + if version == 0 { + err = tx.ExecStmts([]string{ + ` + CREATE TABLE IF NOT EXISTS notes ( + id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, + filename TEXT NOT NULL, + dir TEXT NOT NULL, + title TEXT DEFAULT('') NOT NULL, + body TEXT DEFAULT('') NOT NULL, + word_count INTEGER DEFAULT(0) NOT NULL, + checksum TEXT NOT NULL, + created DATETIME DEFAULT(CURRENT_TIMESTAMP) NOT NULL, + modified DATETIME DEFAULT(CURRENT_TIMESTAMP) NOT NULL, + UNIQUE(filename, dir) + ) + `, + `CREATE INDEX IF NOT EXISTS notes_checksum_idx ON notes(checksum)`, + ` + CREATE VIRTUAL TABLE IF NOT EXISTS notes_fts USING fts5( + title, body, + content = notes, + content_rowid = id, + tokenize = 'porter unicode61 remove_diacritics 1' + ) + `, + // Triggers to keep the FTS index up to date. + ` + CREATE TRIGGER IF NOT EXISTS notes_ai AFTER INSERT ON notes BEGIN + INSERT INTO notes_fts(rowid, title, body) VALUES (new.id, new.title, new.body); + END + `, + ` + CREATE TRIGGER IF NOT EXISTS notes_ad AFTER DELETE ON notes BEGIN + INSERT INTO notes_fts(notes_fts, rowid, title, body) VALUES('delete', old.id, old.title, old.body); + END + `, + ` + CREATE TRIGGER IF NOT EXISTS notes_au AFTER UPDATE ON notes BEGIN + INSERT INTO notes_fts(notes_fts, rowid, title, body) VALUES('delete', old.id, old.title, old.body); + INSERT INTO notes_fts(rowid, title, body) VALUES (new.id, new.title, new.body); + END + `, + `PRAGMA user_version = 1`, + }) - err = tx.Commit() - if err != nil { - return wrap(err) - } + if err != nil { + return err + } + } - return nil -} + return nil + }) -func execMultiple(tx *sql.Tx, stmts []string) error { - var err error - for _, stmt := range stmts { - if err != nil { - break - } - _, err = tx.Exec(stmt) - } - return err + return errors.Wrap(err, "database migration failed") } diff --git a/adapter/sqlite/indexer.go b/adapter/sqlite/indexer.go index 11637fd..495f1d3 100644 --- a/adapter/sqlite/indexer.go +++ b/adapter/sqlite/indexer.go @@ -13,7 +13,7 @@ import ( // NoteIndexer retrieves and stores notes indexation in the SQLite database. // It implements the Core port note.Indexer. type NoteIndexer struct { - tx *sql.Tx + tx Transaction root string logger util.Logger @@ -24,7 +24,7 @@ type NoteIndexer struct { removeStmt *sql.Stmt } -func NewNoteIndexer(tx *sql.Tx, root string, logger util.Logger) (*NoteIndexer, error) { +func NewNoteIndexer(tx Transaction, root string, logger util.Logger) (*NoteIndexer, error) { indexedStmt, err := tx.Prepare(` SELECT filename, dir, modified from notes ORDER BY dir, filename ASC diff --git a/adapter/sqlite/transaction.go b/adapter/sqlite/transaction.go new file mode 100644 index 0000000..cb64bdd --- /dev/null +++ b/adapter/sqlite/transaction.go @@ -0,0 +1,63 @@ +package sqlite + +import "database/sql" + +// Inspired by https://pseudomuto.com/2018/01/clean-sql-transactions-in-golang/ + +// Transaction is an interface that models the standard transaction in +// database/sql. +// +// To ensure TxFn funcs cannot commit or rollback a transaction (which is +// handled by `WithTransaction`), those methods are not included here. +type Transaction interface { + Exec(query string, args ...interface{}) (sql.Result, error) + ExecStmts(stmts []string) error + Prepare(query string) (*sql.Stmt, error) + Query(query string, args ...interface{}) (*sql.Rows, error) + QueryRow(query string, args ...interface{}) *sql.Row +} + +// txWrapper wraps a native sql.Tx to fully implement the Transaction interface. +type txWrapper struct { + *sql.Tx +} + +func (tx *txWrapper) ExecStmts(stmts []string) error { + var err error + for _, stmt := range stmts { + if err != nil { + break + } + _, err = tx.Exec(stmt) + } + return err +} + +// A Txfn is a function that will be called with an initialized Transaction +// object that can be used for executing statements and queries against a +// database. +type TxFn func(tx Transaction) error + +// WithTransaction creates a new transaction and handles rollback/commit based +// on the error object returned by the TxFn closure. +func (db *DB) WithTransaction(fn TxFn) error { + tx, err := db.db.Begin() + if err != nil { + return err + } + + defer func() { + if p := recover(); p != nil { + // A panic occurred, rollback and repanic. + tx.Rollback() + panic(p) + } else if err != nil { + tx.Rollback() + } else { + err = tx.Commit() + } + }() + + err = fn(&txWrapper{tx}) + return err +} diff --git a/cmd/index.go b/cmd/index.go index 5f7ed34..099410b 100644 --- a/cmd/index.go +++ b/cmd/index.go @@ -26,19 +26,13 @@ func (cmd *Index) Run(container *Container) error { if err != nil { return err } - tx, err := db.Begin() - defer tx.Rollback() - if err != nil { - return err - } - indexer, err := sqlite.NewNoteIndexer(tx, zk.Path, container.Logger) - if err != nil { - return err - } - err = note.Index(*dir, indexer, container.Logger) - if err != nil { - return err - } - return tx.Commit() + return db.WithTransaction(func(tx sqlite.Transaction) error { + indexer, err := sqlite.NewNoteIndexer(tx, zk.Path, container.Logger) + if err != nil { + return err + } + + return note.Index(*dir, indexer, container.Logger) + }) }