diff --git a/adapter/sqlite/note_dao.go b/adapter/sqlite/note_dao.go index dcc4c6e..d7c7865 100644 --- a/adapter/sqlite/note_dao.go +++ b/adapter/sqlite/note_dao.go @@ -147,85 +147,7 @@ func (d *NoteDAO) exists(path string) (bool, error) { } func (d *NoteDAO) Find(opts note.FinderOpts, callback func(note.Match) error) (int, error) { - rows, err := func() (*sql.Rows, error) { - snippetCol := `""` - whereExprs := make([]string, 0) - orderTerms := make([]string, 0) - args := make([]interface{}, 0) - - for _, filter := range opts.Filters { - switch filter := filter.(type) { - - case note.MatchFilter: - snippetCol = `snippet(notes_fts, 2, '', '', '…', 20) as snippet` - orderTerms = append(orderTerms, `bm25(notes_fts, 1000.0, 500.0, 1.0)`) - whereExprs = append(whereExprs, "notes_fts MATCH ?") - args = append(args, fts5.ConvertQuery(string(filter))) - - case note.PathFilter: - if len(filter) == 0 { - break - } - globs := make([]string, 0) - for _, path := range filter { - globs = append(globs, "n.path GLOB ?") - args = append(args, path+"*") - } - whereExprs = append(whereExprs, strings.Join(globs, " OR ")) - - case note.ExcludePathFilter: - if len(filter) == 0 { - break - } - globs := make([]string, 0) - for _, path := range filter { - globs = append(globs, "n.path NOT GLOB ?") - args = append(args, path+"*") - } - whereExprs = append(whereExprs, strings.Join(globs, " AND ")) - - case note.DateFilter: - value := "?" - field := "n." + dateField(filter) - op, ignoreTime := dateDirection(filter) - if ignoreTime { - field = "date(" + field + ")" - value = "date(?)" - } - - whereExprs = append(whereExprs, fmt.Sprintf("%s %s %s", field, op, value)) - args = append(args, filter.Date) - - default: - panic(fmt.Sprintf("%v: unknown filter type", filter)) - } - } - - for _, sorter := range opts.Sorters { - orderTerms = append(orderTerms, orderTerm(sorter)) - } - orderTerms = append(orderTerms, `n.title ASC`) - - query := "SELECT n.id, n.path, n.title, n.body, n.word_count, n.created, n.modified, n.checksum, " + snippetCol - - query += ` -FROM notes n -JOIN notes_fts -ON n.id = notes_fts.rowid` - - if len(whereExprs) > 0 { - query += "\nWHERE " + strings.Join(whereExprs, "\nAND ") - } - - query += "\nORDER BY " + strings.Join(orderTerms, ", ") - - if opts.Limit > 0 { - query += fmt.Sprintf("\nLIMIT %d", opts.Limit) - } - - return d.tx.Query(query, args...) - }() - + rows, err := d.findRows(opts) if err != nil { return 0, err } @@ -265,6 +187,94 @@ ON n.id = notes_fts.rowid` return count, nil } +type findQuery struct { + SnippetCol string + WhereExprs []string + OrderTerms []string + Args []interface{} +} + +func (d *NoteDAO) findRows(opts note.FinderOpts) (*sql.Rows, error) { + snippetCol := `""` + whereExprs := make([]string, 0) + orderTerms := make([]string, 0) + args := make([]interface{}, 0) + + for _, filter := range opts.Filters { + switch filter := filter.(type) { + + case note.MatchFilter: + snippetCol = `snippet(notes_fts, 2, '', '', '…', 20) as snippet` + orderTerms = append(orderTerms, `bm25(notes_fts, 1000.0, 500.0, 1.0)`) + whereExprs = append(whereExprs, "notes_fts MATCH ?") + args = append(args, fts5.ConvertQuery(string(filter))) + + case note.PathFilter: + if len(filter) == 0 { + break + } + globs := make([]string, 0) + for _, path := range filter { + globs = append(globs, "n.path GLOB ?") + args = append(args, path+"*") + } + whereExprs = append(whereExprs, strings.Join(globs, " OR ")) + + case note.ExcludePathFilter: + if len(filter) == 0 { + break + } + globs := make([]string, 0) + for _, path := range filter { + globs = append(globs, "n.path NOT GLOB ?") + args = append(args, path+"*") + } + whereExprs = append(whereExprs, strings.Join(globs, " AND ")) + + case note.DateFilter: + value := "?" + field := "n." + dateField(filter) + op, ignoreTime := dateDirection(filter) + if ignoreTime { + field = "date(" + field + ")" + value = "date(?)" + } + + whereExprs = append(whereExprs, fmt.Sprintf("%s %s %s", field, op, value)) + args = append(args, filter.Date) + + default: + panic(fmt.Sprintf("%v: unknown filter type", filter)) + } + } + + for _, sorter := range opts.Sorters { + orderTerms = append(orderTerms, orderTerm(sorter)) + } + orderTerms = append(orderTerms, `n.title ASC`) + + query := "SELECT n.id, n.path, n.title, n.body, n.word_count, n.created, n.modified, n.checksum, " + snippetCol + + query += ` +FROM notes n +JOIN notes_fts +ON n.id = notes_fts.rowid` + + if len(whereExprs) > 0 { + query += "\nWHERE " + strings.Join(whereExprs, "\nAND ") + } + + query += "\nORDER BY " + strings.Join(orderTerms, ", ") + + if opts.Limit > 0 { + query += fmt.Sprintf("\nLIMIT %d", opts.Limit) + } + + // fmt.Println(query) + // fmt.Println(args) + return d.tx.Query(query, args...) +} + func dateField(filter note.DateFilter) string { switch filter.Field { case note.DateCreated: diff --git a/cmd/container.go b/cmd/container.go index 634c2ef..119773e 100644 --- a/cmd/container.go +++ b/cmd/container.go @@ -1,11 +1,15 @@ package cmd import ( + "io" + "github.com/mickael-menu/zk/adapter/handlebars" "github.com/mickael-menu/zk/adapter/sqlite" "github.com/mickael-menu/zk/adapter/tty" + "github.com/mickael-menu/zk/core/zk" "github.com/mickael-menu/zk/util" "github.com/mickael-menu/zk/util/date" + "github.com/mickael-menu/zk/util/pager" ) type Container struct { @@ -43,3 +47,25 @@ func (c *Container) Database(path string) (*sqlite.DB, error) { err = db.Migrate() return db, err } + +// Paginate creates an auto-closing io.Writer which will be automatically +// paginated if noPager is false, using the user's pager. +// +// You can write to the pager only in the run callback. +func (c *Container) Paginate(noPager bool, config zk.Config, run func(out io.Writer) error) error { + pager, err := c.pager(noPager || config.NoPager, config) + if err != nil { + return err + } + err = run(pager) + pager.Close() + return err +} + +func (c *Container) pager(noPager bool, config zk.Config) (*pager.Pager, error) { + if noPager { + return pager.PassthroughPager, nil + } else { + return pager.New(config.Pager, c.Logger) + } +} diff --git a/cmd/list.go b/cmd/list.go index a908d72..1637fe8 100644 --- a/cmd/list.go +++ b/cmd/list.go @@ -2,6 +2,7 @@ package cmd import ( "fmt" + "io" "time" "github.com/mickael-menu/zk/adapter/sqlite" @@ -9,7 +10,6 @@ import ( "github.com/mickael-menu/zk/core/zk" "github.com/mickael-menu/zk/util/errors" "github.com/mickael-menu/zk/util/opt" - "github.com/mickael-menu/zk/util/pager" "github.com/mickael-menu/zk/util/strings" "github.com/tj/go-naturaldate" ) @@ -58,17 +58,11 @@ func (cmd *List) Run(container *Container) error { Templates: container.TemplateLoader(zk.Config.Lang), } - p := pager.PassthroughPager - if !cmd.NoPager { - p, err = pager.New(logger) - if err != nil { - return err - } - } - - count, err := note.List(*opts, deps, p.WriteString) - - p.Close() + count := 0 + err = container.Paginate(cmd.NoPager, zk.Config, func(out io.Writer) error { + count, err = note.List(*opts, deps, out) + return err + }) if err == nil { fmt.Printf("\nFound %d %s\n", count, strings.Pluralize("result", count)) diff --git a/core/note/list.go b/core/note/list.go index 3ddd321..0f9abce 100644 --- a/core/note/list.go +++ b/core/note/list.go @@ -1,6 +1,8 @@ package note import ( + "fmt" + "io" "os" "path/filepath" "regexp" @@ -24,7 +26,7 @@ type ListDeps struct { // List finds notes matching given criteria and formats them according to user // preference. -func List(opts ListOpts, deps ListDeps, callback func(formattedNote string) error) (int, error) { +func List(opts ListOpts, deps ListDeps, out io.Writer) (int, error) { templ := matchTemplate(opts.Format) template, err := deps.Templates.Load(templ) if err != nil { @@ -40,7 +42,9 @@ func List(opts ListOpts, deps ListDeps, callback func(formattedNote string) erro if err != nil { return err } - return callback(res) + + _, err = fmt.Fprintln(out, res) + return err }) } diff --git a/core/zk/config.go b/core/zk/config.go index f90f168..45b7069 100644 --- a/core/zk/config.go +++ b/core/zk/config.go @@ -11,8 +11,10 @@ import ( // Config holds the global user configuration. type Config struct { DirConfig - Dirs map[string]DirConfig - Editor opt.String + Dirs map[string]DirConfig + Editor opt.String + Pager opt.String + NoPager bool } // DirConfig holds the user configuration for a given directory. @@ -115,6 +117,8 @@ func ParseConfig(content []byte, templatesDir string) (*Config, error) { DirConfig: root, Dirs: make(map[string]DirConfig), Editor: opt.NewNotEmptyString(hcl.Editor), + Pager: opt.NewNotEmptyString(hcl.Pager), + NoPager: hcl.NoPager, } for _, dirHCL := range hcl.Dirs { @@ -183,6 +187,8 @@ type hclConfig struct { Extra map[string]string `hcl:"extra,optional"` Dirs []hclDirConfig `hcl:"dir,block"` Editor string `hcl:"editor,optional"` + Pager string `hcl:"pager,optional"` + NoPager bool `hcl:"no-pager,optional"` } type hclDirConfig struct { diff --git a/core/zk/config_test.go b/core/zk/config_test.go index 171ba9c..dedfa30 100644 --- a/core/zk/config_test.go +++ b/core/zk/config_test.go @@ -5,8 +5,8 @@ import ( "testing" "github.com/google/go-cmp/cmp" - "github.com/mickael-menu/zk/util/test/assert" "github.com/mickael-menu/zk/util/opt" + "github.com/mickael-menu/zk/util/test/assert" ) func TestParseDefaultConfig(t *testing.T) { @@ -43,6 +43,8 @@ func TestParseComplete(t *testing.T) { conf, err := ParseConfig([]byte(` // Comment editor = "vim" + pager = "less" + no-pager = true filename = "{{id}}.note" extension = "txt" template = "default.note" @@ -130,7 +132,9 @@ func TestParseComplete(t *testing.T) { }, }, }, - Editor: opt.NewString("vim"), + Editor: opt.NewString("vim"), + Pager: opt.NewString("less"), + NoPager: true, }) } diff --git a/util/opt/opt.go b/util/opt/opt.go index d6f54ba..d1973e2 100644 --- a/util/opt/opt.go +++ b/util/opt/opt.go @@ -56,7 +56,8 @@ func (s String) Unwrap() string { } func (s String) Equal(other String) bool { - return s.value == other.value || *s.value == *other.value + return s.value == other.value || + (s.value != nil && other.value != nil && *s.value == *other.value) } func (s String) String() string { diff --git a/util/pager/pager.go b/util/pager/pager.go index 6adb509..4fef972 100644 --- a/util/pager/pager.go +++ b/util/pager/pager.go @@ -31,10 +31,10 @@ var PassthroughPager = &Pager{ } // New creates a pager.Pager to be used to write a paginated text to the TTY. -func New(logger util.Logger) (*Pager, error) { +func New(pagerCmd opt.String, logger util.Logger) (*Pager, error) { wrap := errors.Wrapper("failed to paginate the output, try again with --no-pager or fix your PAGER environment variable") - pagerCmd := locatePager() + pagerCmd = selectPagerCmd(pagerCmd) if pagerCmd.IsNull() { return PassthroughPager, nil } @@ -98,17 +98,24 @@ func (p *Pager) WriteString(text string) error { return err } -func locatePager() opt.String { +// selectPagerCmd returns the paging command meant to be run. +// +// By order of precedence: ZK_PAGER, config.pager, PAGER then the default +// pagers. +func selectPagerCmd(userPager opt.String) opt.String { return osutil.GetOptEnv("ZK_PAGER"). + Or(userPager). Or(osutil.GetOptEnv("PAGER")). - Or(locateDefaultPager()) + Or(selectDefaultPager()) } var defaultPagers = []string{ "less -FIRX", "more -R", } -func locateDefaultPager() opt.String { +// selectDefaultPager returns the first pager in the list of defaultPagers +// available on the execution paths. +func selectDefaultPager() opt.String { for _, pager := range defaultPagers { parts, err := shellquote.Split(pager) if err != nil {