diff --git a/adapter/sqlite/note_dao.go b/adapter/sqlite/note_dao.go index 16c71f1..a8c4769 100644 --- a/adapter/sqlite/note_dao.go +++ b/adapter/sqlite/note_dao.go @@ -2,6 +2,7 @@ package sqlite import ( "database/sql" + "strings" "time" "github.com/mickael-menu/zk/core/note" @@ -163,7 +164,7 @@ func (d *NoteDAO) Find(callback func(note.Match) error, filters ...note.Filter) WHERE notes_fts MATCH ? ORDER BY bm25(notes_fts, 1000.0, 500.0, 1.0) --- ORDER BY rank - `, filter) + `, escapeForFTS5(string(filter))) } }() @@ -202,3 +203,81 @@ func (d *NoteDAO) Find(callback func(note.Match) error, filters ...note.Filter) return nil } + +func escapeForFTS5(query string) string { + quote := false + out := "" + term := "" + + endTerm := func() { + if term == "" { + return + } + switch term { + case "AND", "OR", "NOT": + out += term + default: + isPrefixToken := strings.HasSuffix(term, "*") + if isPrefixToken { + term = strings.TrimSuffix(term, "*") + } + out += `"` + term + `"` + if isPrefixToken { + out += "*" + } + } + term = "" + } + + for _, c := range query { + switch { + case c == '"': + if quote { + endTerm() + } + quote = !quote + + case c == '^' || c == '*': + if term != "" { + term += string(c) + } else { + out += string(c) + } + + case c == '-': + if term == "" { + out += " NOT " + } else { + term += string(c) + } + + case c == ':': + if term != "" && !quote { + out += term + string(c) + term = "" + } else { + term += string(c) + } + + case c == '+': + if term != "" || quote { + term += string(c) + } + + case c == ' ', c == '\t', c == '\n', c == '(', c == ')': + if !quote { + endTerm() + out += string(c) + } else { + term += string(c) + } + + default: + term = term + string(c) + } + } + + endTerm() + + return out +} diff --git a/adapter/sqlite/note_dao_test.go b/adapter/sqlite/note_dao_test.go index 4234248..77eea63 100644 --- a/adapter/sqlite/note_dao_test.go +++ b/adapter/sqlite/note_dao_test.go @@ -214,13 +214,44 @@ func TestNoteDAOFindAll(t *testing.T) { }) } -func testNoteDAOFind(t *testing.T, expected []note.Match) { +func TestNoteDAOFindMatch(t *testing.T) { + expected := []note.Match{ + { + Snippet: "A daily note", + Metadata: note.Metadata{ + Path: "log/2021-01-03.md", + Title: "January 3, 2021", + Body: "A daily note", + WordCount: 3, + Created: time.Date(2020, 11, 22, 16, 27, 45, 0, time.Local), + Modified: time.Date(2020, 11, 22, 16, 27, 45, 0, time.Local), + Checksum: "qwfpgj", + }, + }, + { + Snippet: "A second daily note", + Metadata: note.Metadata{ + Path: "log/2021-01-04.md", + Title: "January 4, 2021", + Body: "A second daily note", + WordCount: 4, + Created: time.Date(2020, 11, 29, 8, 20, 18, 0, time.Local), + Modified: time.Date(2020, 11, 29, 8, 20, 18, 0, time.Local), + Checksum: "arstde", + }, + }, + } + + testNoteDAOFind(t, expected, note.MatchFilter("daily")) +} + +func testNoteDAOFind(t *testing.T, expected []note.Match, filters ...note.Filter) { testNoteDAO(t, func(tx Transaction, dao *NoteDAO) { actual := make([]note.Match, 0) err := dao.Find(func(m note.Match) error { actual = append(actual, m) return nil - }) + }, filters...) assert.Nil(t, err) assert.Equal(t, actual, expected) }) @@ -247,3 +278,61 @@ func queryNoteRow(tx Transaction, where string) (noteRow, error) { `, where)).Scan(&row.Path, &row.Title, &row.Body, &row.WordCount, &row.Checksum, &row.Created, &row.Modified) return row, err } + +func TestEscapeForFTS5(t *testing.T) { + test := func(text, expected string) { + assert.Equal(t, escapeForFTS5(text), expected) + } + + test(`foo`, `"foo"`) + test(`foo bar`, `"foo" "bar"`) + test(`"foo"`, `"foo"`) + test(`"foo bar"`, `"foo bar"`) + test(`"foo bar" qux`, `"foo bar" "qux"`) + + test(`foo AND bar`, `"foo" AND "bar"`) + test(`foo AN bar`, `"foo" "AN" "bar"`) + test(`foo ANT bar`, `"foo" "ANT" "bar"`) + test(`"foo AND bar"`, `"foo AND bar"`) + test(`foo OR bar`, `"foo" OR "bar"`) + test(`foo NOT bar`, `"foo" NOT "bar"`) + test(`(foo AND bar) OR qux`, `("foo" AND "bar") OR "qux"`) + + test(`foo -bar`, `"foo" NOT "bar"`) + test(`"foo -bar"`, `"foo -bar"`) + test(`foo-bar`, `"foo-bar"`) + + test(`foo/bar`, `"foo/bar"`) + test(`foo;bar`, `"foo;bar"`) + test(`foo,bar`, `"foo,bar"`) + test(`foo&bar`, `"foo&bar"`) + test(`foo's bar`, `"foo's" "bar"`) + + test(`foo ba*`, `"foo" "ba"*`) + test(`foo ba* qux`, `"foo" "ba"* "qux"`) + test(`"foo ba"*`, `"foo ba"*`) + test(`(foo ba*)`, `("foo" "ba"*)`) + test(`foo*bar`, `"foo*bar"`) + test(`"foo*bar"`, `"foo*bar"`) + + test(`col:foo bar`, `col:"foo" "bar"`) + test(`foo col:bar`, `"foo" col:"bar"`) + test(`foo "col:bar"`, `"foo" "col:bar"`) + test(`":foo"`, `":foo"`) + test(`-col:foo bar`, ` NOT col:"foo" "bar"`) + test(`col:(foo bar)`, `col:("foo" "bar")`) + + test(`^foo`, `^"foo"`) + test(`^foo bar`, `^"foo" "bar"`) + test(`foo ^bar`, `"foo" ^"bar"`) + test(`^"foo bar"`, `^"foo bar"`) + test(`"foo ^bar"`, `"foo ^bar"`) + test(`col:^foo`, `col:^"foo"`) + + test(`foo + bar`, `"foo" "bar"`) + test(`"foo + bar"`, `"foo + bar"`) + test(`"+foo"`, `"+foo"`) + + // NEAR is not supported + test(`NEAR(foo, bar, 4)`, `"NEAR"("foo," "bar," "4")`) +} diff --git a/cmd/list.go b/cmd/list.go index 86f362a..cd9690c 100644 --- a/cmd/list.go +++ b/cmd/list.go @@ -11,9 +11,9 @@ import ( // List displays notes matching a set of criteria. type List struct { - Path string `arg optional placeholder:"PATH"` - Match string `help:"Terms to search for in the notes" placeholder:"TERMS"` - Format string `help:"Pretty prints the list using the given format" placeholder:"TEMPLATE"` + Path []string `arg optional placeholder:"PATHS"` + Match string `help:"Terms to search for in the notes" placeholder:"TERMS"` + Format string `help:"Pretty prints the list using the given format" placeholder:"TEMPLATE"` } func (cmd *List) Run(container *Container) error {