From 059973f557df6d1abd34ac4d669402313c9b03fe Mon Sep 17 00:00:00 2001 From: ray-x Date: Sat, 11 Nov 2023 15:43:20 +1100 Subject: [PATCH] Major updates for treesitter folding. Write how scope and fold level are calculated --- lua/navigator/codeAction.lua | 6 +- lua/navigator/diagnostics.lua | 12 +- lua/navigator/foldts.lua | 204 +++++++++++++++------------------- lua/navigator/util.lua | 5 +- 4 files changed, 101 insertions(+), 126 deletions(-) diff --git a/lua/navigator/codeAction.lua b/lua/navigator/codeAction.lua index 4038735..f47c1d8 100644 --- a/lua/navigator/codeAction.lua +++ b/lua/navigator/codeAction.lua @@ -57,7 +57,7 @@ local function _update_sign(line) if code_action[winid].lightbulb_line and code_action[winid].lightbulb_line > 0 then vim.fn.sign_unplace(sign_group, { id = code_action[winid].lightbulb_line, buffer = '%' }) - log('sign removed', line) + trace('sign removed', line) end if line then @@ -70,7 +70,7 @@ local function _update_sign(line) { lnum = line + 1, priority = config.lsp.code_action.sign_priority } ) code_action[winid].lightbulb_line = id - log('sign updated', id) + log('sign updated', id, line, sign_group, sign_name) end end @@ -122,7 +122,7 @@ function code_action:render_action_virtual_text(line, diagnostics) end vim.defer_fn(function() - log('clear vt') + trace('clear vt') if config.lsp.code_action.virtual_text then _update_virtual_text(nil) end diff --git a/lua/navigator/diagnostics.lua b/lua/navigator/diagnostics.lua index 894d336..b2965db 100644 --- a/lua/navigator/diagnostics.lua +++ b/lua/navigator/diagnostics.lua @@ -1,4 +1,5 @@ local gui = require('navigator.gui') +local uv = vim.uv or vim.loop local diagnostic = vim.diagnostic or vim.lsp.diagnostic -- local hide = diagnostic.hide or diagnostic.clear local util = require('navigator.util') @@ -48,7 +49,6 @@ local function error_marker(result, ctx, config) end local async - local uv = vim.uv or vim.loop async = uv.new_async(vim.schedule_wrap(function() if vim.tbl_isempty(result.diagnostics) then return @@ -119,7 +119,7 @@ local function error_marker(result, ctx, config) diags[i].range = { start = { line = diags[i].lnum } } end end - + local ratio = wheight / total_num table.sort(diags, function(a, b) return a.range.start.line < b.range.start.line end) @@ -131,7 +131,7 @@ local function error_marker(result, ctx, config) end if diag.range and diag.range.start and diag.range.start.line then p = diag.range.start.line + 1 -- convert to 1 based - p = util.round(p * wheight / math.max(wheight, total_num)) + p = util.round(p * ratio, ratio) trace('pos: ', diag.range.start.line, p) if pos[#pos] and pos[#pos].line == p then local bar = _NgConfigValues.lsp.diagnostic_scrollbar_sign[2] @@ -151,9 +151,7 @@ local function error_marker(result, ctx, config) trace('pos, line:', p, diag.severity, diag.range) end - if not vim.tbl_isempty(pos) then - api.nvim_buf_clear_namespace(bufnr, _NG_VT_DIAG_NS, 0, -1) - end + api.nvim_buf_clear_namespace(bufnr, _NG_VT_DIAG_NS, 0, -1) for _, s in pairs(pos) do local hl = 'ErrorMsg' if type(s.severity) == 'number' then @@ -209,7 +207,7 @@ local diag_hdlr = function(err, result, ctx, config) if mode ~= 'n' and config.update_in_insert == false then trace('skip sign update in insert mode') end - local cwd = vim.loop.cwd() + local cwd = uv.cwd() local ft = vim.bo.filetype if M.diagnostic_list[ft] == nil then M.diagnostic_list[vim.bo.filetype] = {} diff --git a/lua/navigator/foldts.lua b/lua/navigator/foldts.lua index cc764f7..e2094d3 100644 --- a/lua/navigator/foldts.lua +++ b/lua/navigator/foldts.lua @@ -2,6 +2,7 @@ local log = require('navigator.util').log local trace = require('navigator.util').trace +trace = log local api = vim.api local tsutils = require('nvim-treesitter.ts_utils') local query = require('nvim-treesitter.query') @@ -46,7 +47,7 @@ function NG_custom_fold_text() local tabspace = string.rep(' ', vim.o.tabstop) s = s:gsub('\t', tabspace) end - s = s:gsub('^ ', prefix) + s = s:gsub('^ ', prefix) -- replace prefix with two spaces if s ~= spaces[1] then spaces[1] = s spaces[2] = { '@keyword' } @@ -86,20 +87,42 @@ function M.setup_fold() api.nvim_win_set_option(current_window, 'foldexpr', 'folding#ngfoldexpr()') end -local function get_fold_level(levels, lnum) - local prev_l = levels[lnum] - local prev_ln - if prev_l:find('>') then - prev_ln = tonumber(prev_l:sub(2)) - else - prev_ln = tonumber(prev_l) +local function is_comment(line_number) + local node = get_node_at_line(line_number) + trace(node, node:type()) + if not node then + return false end - return prev_ln + local node_type = node:type() + trace(node_type) + return node_type == 'comment' or node_type == 'comment_block' end --- This is cached on buf tick to avoid computing that multiple times --- Especially not for every line in the file when `zx` is hit -local folds_levels = tsutils.memoize_by_buf_tick(function(bufnr) +local function get_comment_scopes(total_lines) + local comment_scopes = {} + local comment_start = nil + + for line = 0, total_lines - 1 do + if is_comment(line + 1) then + if not comment_start then + comment_start = line + end + elseif comment_start then + if line - comment_start > 2 then -- More than 2 lines + table.insert(comment_scopes, { comment_start, line }) + end + comment_start = nil + end + end + + -- Handle case where file ends with a multiline comment + if comment_start and total_lines - comment_start > 2 then + table.insert(comment_scopes, { comment_start, total_lines }) + end + trace(comment_scopes) + return comment_scopes +end +local function indent_levels(scopes, total_lines) local max_fold_level = api.nvim_win_get_option(0, 'foldnestmax') local trim_level = function(level) if level > max_fold_level then @@ -108,6 +131,39 @@ local folds_levels = tsutils.memoize_by_buf_tick(function(bufnr) return level end + local events = {} + local prev = { -1, -1 } + for _, scope in ipairs(scopes) do + if not (prev[1] == scope[1] and prev[2] == scope[2]) then + events[scope[1]] = (events[scope[1]] or 0) + 1 + events[scope[2]] = (events[scope[2]] or 0) - 1 + end + prev = scope + end + trace(events) + + local currentIndent = 0 + local indentLevels = {} + local prevIndentLevel = 0 + local levels = {} + for line = 0, total_lines - 1 do + if events[line] then + currentIndent = currentIndent + events[line] + end + indentLevels[line] = currentIndent + + local indentSymbol = indentLevels[line] > prevIndentLevel and '>' or ' ' + trace('Line ' .. line .. ': ' .. indentSymbol .. indentLevels[line]) + levels[line + 1] = indentSymbol .. tostring(trim_level(indentLevels[line])) + prevIndentLevel = indentLevels[line] + end + trace(levels) + return levels +end + +-- This is cached on buf tick to avoid computing that multiple times +-- Especially not for every line in the file when `zx` is hit +local folds_levels = tsutils.memoize_by_buf_tick(function(bufnr) local parser = parsers.get_parser(bufnr) if not parser then @@ -124,16 +180,24 @@ local folds_levels = tsutils.memoize_by_buf_tick(function(bufnr) end) -- start..stop is an inclusive range + + ---@type table local start_counts = {} + ---@type table local stop_counts = {} local prev_start = -1 local prev_stop = -1 local min_fold_lines = api.nvim_win_get_option(0, 'foldminlines') - - for _, node in ipairs(matches) do - local start, _, stop, stop_col = node.node:range() + local scopes = {} + for _, match in ipairs(matches) do + local start, stop, stop_col ---@type integer, integer, integer + if match.metadata and match.metadata.range then + start, _, stop, stop_col = unpack(match.metadata.range) ---@type integer, integer, integer, integer + else + start, _, stop, stop_col = match.node:range() ---@type integer, integer, integer, integer + end if stop_col == 0 then stop = stop - 1 @@ -141,116 +205,29 @@ local folds_levels = tsutils.memoize_by_buf_tick(function(bufnr) local fold_length = stop - start + 1 local should_fold = fold_length > min_fold_lines - -- Fold only multiline nodes that are not exactly the same as previously met folds -- Checking against just the previously found fold is sufficient if nodes -- are returned in preorder or postorder when traversing tree if should_fold and not (start == prev_start and stop == prev_stop) then start_counts[start] = (start_counts[start] or 0) + 1 stop_counts[stop] = (stop_counts[stop] or 0) + 1 + -- trace('fold scope', start, stop, match.node:type()) prev_start = start prev_stop = stop + table.insert(scopes, { start, stop }) end end - trace(start_counts) - trace(stop_counts) - - local levels = {} - local current_level = 0 - - -- We now have the list of fold opening and closing, fill the gaps and mark where fold start - local pre_node - for lnum = 0, api.nvim_buf_line_count(bufnr) do - local node, _ = get_node_at_line(lnum + 1) - local comment = node:type() == 'comment' - - local next_node, _ = get_node_at_line(lnum + 1) - local next_comment = node and node:type() == 'comment' - local last_trimmed_level = trim_level(current_level) - current_level = current_level + (start_counts[lnum] or 0) - local trimmed_level = trim_level(current_level) - local current_level2 = current_level - (stop_counts[lnum] or 0) - local next_trimmed_level = trim_level(current_level2) - - trace(lnum, node:type(), node, last_trimmed_level, trimmed_level, next_trimmed_level) - if comment then - trace('comment node', trimmed_level) - -- if trimmed_level == 0 then - -- trimmed_level = 1 - -- end - - levels[lnum + 1] = tostring(trimmed_level + 2) - if pre_node and pre_node:type() ~= 'comment' then - levels[lnum + 1] = '>' .. tostring(trimmed_level + 2) - end - if next_node and next_node:type() ~= 'comment' then - levels[lnum + 1] = tostring(trimmed_level + 1) - end - else - -- Determine if it's the start/end of a fold - -- NB: vim's fold-expr interface does not have a mechanism to indicate that - -- two (or more) folds start at this line, so it cannot distinguish between - -- ( \n ( \n )) \n (( \n ) \n ) - -- versus - -- ( \n ( \n ) \n ( \n ) \n ) - -- If it did have such a mechansim, (trimmed_level - last_trimmed_level) - -- would be the correct number of starts to pass on. - if trimmed_level - last_trimmed_level > 0 then - if levels[lnum + 1] ~= '>' .. tostring(trimmed_level) then - levels[lnum + 1] = tostring(trimmed_level) -- hack do not fold current line as it is first in fold range - end - levels[lnum + 2] = '>' .. tostring(trimmed_level + 1) -- dirty hack fold start from next line - trace('fold start') - elseif trimmed_level - next_trimmed_level > 0 then -- last line in fold range - -- Ending marks tend to confuse vim more than it helps, particularly when - -- the fold level changes by at least 2; we can uncomment this if - -- vim's behavior gets fixed. - - trace('fold end') - if levels[lnum + 1] then - trace('already set reset as fold is ending', levels[lnum + 1]) - levels[lnum + 1] = tostring(trimmed_level + 1) - else - local prev_ln = get_fold_level(levels, lnum) - 1 - if prev_ln == 0 then - prev_ln = 1 - end - levels[lnum + 1] = tostring(prev_ln) - end - -- levels[lnum + 1] = tostring(trimmed_level + 1) - -- else - current_level = current_level - 1 - else - trace('same') - if pre_node and pre_node:type() == 'comment' then - local prev_ln = get_fold_level(levels, lnum) - 1 - levels[lnum + 1] = tostring(prev_ln) - else - local n = math.max(trimmed_level, 1) - if lnum > 1 then - if levels[lnum + 1] then - trace('already set', levels[lnum + 1]) - else - local prev_l = levels[lnum] - if prev_l:find('>') then - levels[lnum + 1] = prev_l:sub(2) - else - levels[lnum + 1] = prev_l - end - end - else - levels[lnum + 1] = tostring(n) - end - end - end - trace(levels) + local total_lines = api.nvim_buf_line_count(bufnr) + local comment_scopes = get_comment_scopes(total_lines) + scopes = vim.list_extend(scopes, comment_scopes) + table.sort(scopes, function(a, b) + if a[1] == b[1] then + return a[2] < b[2] end - pre_node = node - end - trace(levels) - return levels + return a[1] < b[1] + end) + return indent_levels(scopes, total_lines) end) - function M.get_fold_indic(lnum) if not parsers.has_parser() or not lnum then return '0' @@ -268,7 +245,6 @@ function M.get_fold_indic(lnum) return '0' end local levels = folds_levels(buf) or {} - -- trace(lnum, levels[lnum]) -- TODO: comment it out in master return levels[lnum] or '0' end diff --git a/lua/navigator/util.lua b/lua/navigator/util.lua index 02bbb14..31fcf5d 100644 --- a/lua/navigator/util.lua +++ b/lua/navigator/util.lua @@ -30,8 +30,9 @@ M.path_cur = function() end end -M.round = function(x) - return math.max(0, math.floor(x - 0.5)) +M.round = function(x, r) + r = r or 0.5 + return math.max(0, math.floor(x - r)) end function M.get_data_from_file(filename, startLine)