diff --git a/adapter/sqlite/note_dao.go b/adapter/sqlite/note_dao.go
index 16c71f1..a8c4769 100644
--- a/adapter/sqlite/note_dao.go
+++ b/adapter/sqlite/note_dao.go
@@ -2,6 +2,7 @@ package sqlite
import (
"database/sql"
+ "strings"
"time"
"github.com/mickael-menu/zk/core/note"
@@ -163,7 +164,7 @@ func (d *NoteDAO) Find(callback func(note.Match) error, filters ...note.Filter)
WHERE notes_fts MATCH ?
ORDER BY bm25(notes_fts, 1000.0, 500.0, 1.0)
--- ORDER BY rank
- `, filter)
+ `, escapeForFTS5(string(filter)))
}
}()
@@ -202,3 +203,81 @@ func (d *NoteDAO) Find(callback func(note.Match) error, filters ...note.Filter)
return nil
}
+
+func escapeForFTS5(query string) string {
+ quote := false
+ out := ""
+ term := ""
+
+ endTerm := func() {
+ if term == "" {
+ return
+ }
+ switch term {
+ case "AND", "OR", "NOT":
+ out += term
+ default:
+ isPrefixToken := strings.HasSuffix(term, "*")
+ if isPrefixToken {
+ term = strings.TrimSuffix(term, "*")
+ }
+ out += `"` + term + `"`
+ if isPrefixToken {
+ out += "*"
+ }
+ }
+ term = ""
+ }
+
+ for _, c := range query {
+ switch {
+ case c == '"':
+ if quote {
+ endTerm()
+ }
+ quote = !quote
+
+ case c == '^' || c == '*':
+ if term != "" {
+ term += string(c)
+ } else {
+ out += string(c)
+ }
+
+ case c == '-':
+ if term == "" {
+ out += " NOT "
+ } else {
+ term += string(c)
+ }
+
+ case c == ':':
+ if term != "" && !quote {
+ out += term + string(c)
+ term = ""
+ } else {
+ term += string(c)
+ }
+
+ case c == '+':
+ if term != "" || quote {
+ term += string(c)
+ }
+
+ case c == ' ', c == '\t', c == '\n', c == '(', c == ')':
+ if !quote {
+ endTerm()
+ out += string(c)
+ } else {
+ term += string(c)
+ }
+
+ default:
+ term = term + string(c)
+ }
+ }
+
+ endTerm()
+
+ return out
+}
diff --git a/adapter/sqlite/note_dao_test.go b/adapter/sqlite/note_dao_test.go
index 4234248..77eea63 100644
--- a/adapter/sqlite/note_dao_test.go
+++ b/adapter/sqlite/note_dao_test.go
@@ -214,13 +214,44 @@ func TestNoteDAOFindAll(t *testing.T) {
})
}
-func testNoteDAOFind(t *testing.T, expected []note.Match) {
+func TestNoteDAOFindMatch(t *testing.T) {
+ expected := []note.Match{
+ {
+ Snippet: "A daily note",
+ Metadata: note.Metadata{
+ Path: "log/2021-01-03.md",
+ Title: "January 3, 2021",
+ Body: "A daily note",
+ WordCount: 3,
+ Created: time.Date(2020, 11, 22, 16, 27, 45, 0, time.Local),
+ Modified: time.Date(2020, 11, 22, 16, 27, 45, 0, time.Local),
+ Checksum: "qwfpgj",
+ },
+ },
+ {
+ Snippet: "A second daily note",
+ Metadata: note.Metadata{
+ Path: "log/2021-01-04.md",
+ Title: "January 4, 2021",
+ Body: "A second daily note",
+ WordCount: 4,
+ Created: time.Date(2020, 11, 29, 8, 20, 18, 0, time.Local),
+ Modified: time.Date(2020, 11, 29, 8, 20, 18, 0, time.Local),
+ Checksum: "arstde",
+ },
+ },
+ }
+
+ testNoteDAOFind(t, expected, note.MatchFilter("daily"))
+}
+
+func testNoteDAOFind(t *testing.T, expected []note.Match, filters ...note.Filter) {
testNoteDAO(t, func(tx Transaction, dao *NoteDAO) {
actual := make([]note.Match, 0)
err := dao.Find(func(m note.Match) error {
actual = append(actual, m)
return nil
- })
+ }, filters...)
assert.Nil(t, err)
assert.Equal(t, actual, expected)
})
@@ -247,3 +278,61 @@ func queryNoteRow(tx Transaction, where string) (noteRow, error) {
`, where)).Scan(&row.Path, &row.Title, &row.Body, &row.WordCount, &row.Checksum, &row.Created, &row.Modified)
return row, err
}
+
+func TestEscapeForFTS5(t *testing.T) {
+ test := func(text, expected string) {
+ assert.Equal(t, escapeForFTS5(text), expected)
+ }
+
+ test(`foo`, `"foo"`)
+ test(`foo bar`, `"foo" "bar"`)
+ test(`"foo"`, `"foo"`)
+ test(`"foo bar"`, `"foo bar"`)
+ test(`"foo bar" qux`, `"foo bar" "qux"`)
+
+ test(`foo AND bar`, `"foo" AND "bar"`)
+ test(`foo AN bar`, `"foo" "AN" "bar"`)
+ test(`foo ANT bar`, `"foo" "ANT" "bar"`)
+ test(`"foo AND bar"`, `"foo AND bar"`)
+ test(`foo OR bar`, `"foo" OR "bar"`)
+ test(`foo NOT bar`, `"foo" NOT "bar"`)
+ test(`(foo AND bar) OR qux`, `("foo" AND "bar") OR "qux"`)
+
+ test(`foo -bar`, `"foo" NOT "bar"`)
+ test(`"foo -bar"`, `"foo -bar"`)
+ test(`foo-bar`, `"foo-bar"`)
+
+ test(`foo/bar`, `"foo/bar"`)
+ test(`foo;bar`, `"foo;bar"`)
+ test(`foo,bar`, `"foo,bar"`)
+ test(`foo&bar`, `"foo&bar"`)
+ test(`foo's bar`, `"foo's" "bar"`)
+
+ test(`foo ba*`, `"foo" "ba"*`)
+ test(`foo ba* qux`, `"foo" "ba"* "qux"`)
+ test(`"foo ba"*`, `"foo ba"*`)
+ test(`(foo ba*)`, `("foo" "ba"*)`)
+ test(`foo*bar`, `"foo*bar"`)
+ test(`"foo*bar"`, `"foo*bar"`)
+
+ test(`col:foo bar`, `col:"foo" "bar"`)
+ test(`foo col:bar`, `"foo" col:"bar"`)
+ test(`foo "col:bar"`, `"foo" "col:bar"`)
+ test(`":foo"`, `":foo"`)
+ test(`-col:foo bar`, ` NOT col:"foo" "bar"`)
+ test(`col:(foo bar)`, `col:("foo" "bar")`)
+
+ test(`^foo`, `^"foo"`)
+ test(`^foo bar`, `^"foo" "bar"`)
+ test(`foo ^bar`, `"foo" ^"bar"`)
+ test(`^"foo bar"`, `^"foo bar"`)
+ test(`"foo ^bar"`, `"foo ^bar"`)
+ test(`col:^foo`, `col:^"foo"`)
+
+ test(`foo + bar`, `"foo" "bar"`)
+ test(`"foo + bar"`, `"foo + bar"`)
+ test(`"+foo"`, `"+foo"`)
+
+ // NEAR is not supported
+ test(`NEAR(foo, bar, 4)`, `"NEAR"("foo," "bar," "4")`)
+}
diff --git a/cmd/list.go b/cmd/list.go
index 86f362a..cd9690c 100644
--- a/cmd/list.go
+++ b/cmd/list.go
@@ -11,9 +11,9 @@ import (
// List displays notes matching a set of criteria.
type List struct {
- Path string `arg optional placeholder:"PATH"`
- Match string `help:"Terms to search for in the notes" placeholder:"TERMS"`
- Format string `help:"Pretty prints the list using the given format" placeholder:"TEMPLATE"`
+ Path []string `arg optional placeholder:"PATHS"`
+ Match string `help:"Terms to search for in the notes" placeholder:"TERMS"`
+ Format string `help:"Pretty prints the list using the given format" placeholder:"TEMPLATE"`
}
func (cmd *List) Run(container *Container) error {