-- part of the code from polarmutex/contextprint.nvim
local ts_utils = require("nvim-treesitter.ts_utils")
local ts_query = require("nvim-treesitter.query")
local parsers = require("nvim-treesitter.parsers")
local locals = require("nvim-treesitter.locals")
local utils = require("go.ts.utils")
local ulog = require("go.utils").log

local M = {}

-- A "node wrapper", as produced by get_nodes / get_all_nodes, is a table:
--   {
--     declaring_node = tsnode,
--     dim  = { s = { r, c }, e = { r, c } },  -- node:range() values, each + 1
--     name = string,
--     type = string,
--     operator = string,                      -- get_all_nodes only
--   }

-- Compile `query` for `lang`.
-- Prefers vim.treesitter.query.parse and falls back to the older
-- vim.treesitter.parse_query when the former is not available.
-- @return the parsed query, or nil when compilation fails (the error is logged)
local function compile_query(lang, query)
  local ts = vim.treesitter
  local parse = (ts.query and ts.query.parse) or ts.parse_query
  local ok, parsed = pcall(parse, lang, query)
  if not ok then
    ulog("failed to parse query for " .. tostring(lang) .. ": " .. tostring(parsed))
    return nil
  end
  return parsed
end

-- Return the root node of the first tree parsed for `bufnr`/`lang`,
-- or nil when no parser or no tree is available.
local function buffer_root(bufnr, lang)
  local parser = parsers.get_parser(bufnr, lang)
  if parser == nil then
    return nil
  end
  local tree = parser:parse()[1]
  if tree == nil then
    return nil
  end
  return tree:root()
end

-- Split `path` around the dot located at index `dot`.
-- @return the text before the dot, the text after the dot
local function split_path(path, dot)
  return string.sub(path, 1, dot - 1), string.sub(path, dot + 1)
end

-- Return node:range() with every component incremented by one.
local function one_based_range(node)
  local s_row, s_col, e_row, e_col = node:range()
  return s_row + 1, s_col + 1, e_row + 1, e_col + 1
end

-- Keep only the wrappers whose `dim` intersects the given position,
-- as decided by utils.intersects.
-- @param nodes Array of node wrappers
-- @param row, col position to test (same base as the wrappers' `dim`)
-- @return a new array holding the intersecting wrappers, in input order
M.intersect_nodes = function(nodes, row, col)
  local found = {}
  for idx = 1, #nodes do
    local node = nodes[idx]
    local s, e = node.dim.s, node.dim.e
    if utils.intersects(row, col, s.r, s.c, e.r, e.c) then
      found[#found + 1] = node
    end
  end
  return found
end

-- Count the nodes on the parent chain of a wrapper's declaring node,
-- the declaring node itself included (0 when declaring_node is nil).
-- @param node a node wrapper
M.count_parents = function(node)
  local count = 0
  local n = node.declaring_node
  while n ~= nil do
    n = n:parent()
    count = count + 1
  end
  return count
end

-- Sort wrappers in place by ascending M.count_parents, i.e. outermost first.
-- @param nodes Array of node wrappers
-- @return the same (now sorted) array
M.sort_nodes = function(nodes)
  -- Walk each parent chain once up front instead of inside every comparison.
  local depth = {}
  for idx = 1, #nodes do
    local wrapper = nodes[idx]
    depth[wrapper] = M.count_parents(wrapper)
  end
  table.sort(
    nodes,
    function(a, b)
      return depth[a] < depth[b]
    end
  )
  return nodes
end

-- Run `query` over the whole buffer and collect one wrapper per match that
-- has a "<type>.declaration" capture.
--
-- Capture paths are split at their FIRST dot: the part before it becomes the
-- wrapper's `type`, the part after it selects the action ("name" records the
-- node text, "declaration" records the node and its range).
--
-- @param query    query source text
-- @param lang     tree-sitter language name
-- @param defaults optional table mapping type -> fallback name, used when a
--                 match has no usable name capture; without a matching entry
--                 the name is the string "nil"
-- @param bufnr    buffer handle (defaults to 0)
-- @return array of node wrappers, or nil when the query cannot be compiled or
--         no syntax tree is available
M.get_nodes = function(query, lang, defaults, bufnr)
  bufnr = bufnr or 0
  local parsed_query = compile_query(lang, query)
  if parsed_query == nil then
    return nil
  end

  local root = buffer_root(bufnr, lang)
  if root == nil then
    return nil
  end
  local start_row, _, end_row, _ = root:range()

  local results = {}
  for match in ts_query.iter_prepared_matches(parsed_query, root, bufnr, start_row, end_row) do
    local sRow, sCol, eRow, eCol
    local declaration_node
    local kind = "nil"
    local name -- stays nil until a "name" capture is seen

    locals.recurse_local_nodes(
      match,
      function(_, node, path)
        local dot = type(path) == "string" and string.find(path, ".", 1, true)
        if not dot then
          -- A path without a dot cannot be split into type and action.
          return
        end
        local op
        kind, op = split_path(path, dot)
        if op == "name" then
          name = ts_utils.get_node_text(node, bufnr)[1]
        elseif op == "declaration" then
          declaration_node = node
          sRow, sCol, eRow, eCol = one_based_range(node)
        end
      end
    )

    if declaration_node ~= nil then
      if name == nil then
        name = (type(defaults) == "table" and defaults[kind]) or "nil"
      end
      results[#results + 1] = {
        declaring_node = declaration_node,
        dim = {
          s = {r = sRow, c = sCol},
          e = {r = eRow, c = eCol}
        },
        name = name,
        type = kind
      }
    end
  end

  return results
end

-- Like M.get_nodes, with these differences visible in the code below:
--   * capture paths are split at their LAST dot;
--   * both "declaration" and "clause" captures mark the declaring node;
--   * each wrapper also carries `operator` (the last action seen in the match);
--   * iteration stops at the first match that starts after `pos_row`;
--   * every capture is written to the debug log.
--
-- @param query    query source text
-- @param lang     tree-sitter language name
-- @param defaults currently unused; kept for signature compatibility
-- @param bufnr    buffer handle (defaults to 0)
-- @param pos_row  last start row of interest, same base as `dim`
--                 (defaults to no limit)
-- @param pos_col  currently unused; kept for signature compatibility
-- @return array of node wrappers, or nil when the query cannot be compiled or
--         no syntax tree is available
M.get_all_nodes = function(query, lang, defaults, bufnr, pos_row, pos_col)
  bufnr = bufnr or 0
  pos_row = pos_row or math.huge
  local parsed_query = compile_query(lang, query)
  if parsed_query == nil then
    return nil
  end

  local root = buffer_root(bufnr, lang)
  if root == nil then
    return nil
  end
  local start_row, _, end_row, _ = root:range()

  local results = {}
  for match in ts_query.iter_prepared_matches(parsed_query, root, bufnr, start_row, end_row) do
    local sRow, sCol, eRow, eCol
    local declaration_node
    local kind = ""
    local name = ""
    local op = ""

    locals.recurse_local_nodes(
      match,
      function(_, node, path)
        -- "%." matches a literal dot; the tail "[^.]*$" pins it to the last one.
        local dot = type(path) == "string" and string.find(path, "%.[^.]*$")
        if not dot then
          return
        end
        kind, op = split_path(path, dot)

        local a1, b1, c1, d1 = ts_utils.get_node_range(node)
        local text = ts_utils.get_node_text(node, bufnr)[1]
        ulog(
          "node" .. vim.inspect(node) ..
          "\n path: " .. path ..
          " op: " .. op ..
          " type: " .. kind ..
          "\n txt: " .. tostring(text) ..
          "\n range: " .. tostring(a1) .. ":" .. tostring(b1) ..
          " TO " .. tostring(c1) .. ":" .. tostring(d1)
        )

        -- may not handle complex node
        if op == "name" then
          name = text
        elseif op == "declaration" or op == "clause" then
          declaration_node = node
          sRow, sCol, eRow, eCol = one_based_range(node)
        end
      end
    )

    if declaration_node ~= nil then
      if sRow > pos_row then
        ulog("beyond " .. tostring(pos_row))
        break
      end
      results[#results + 1] = {
        declaring_node = declaration_node,
        dim = {
          s = {r = sRow, c = sCol},
          e = {r = eRow, c = eCol}
        },
        name = name,
        operator = op,
        type = kind
      }
    end
  end

  return results
end

-- Collect the wrappers matched by `query` that contain a position,
-- sorted outermost first.
--
-- @param query   query source text
-- @param default forwarded to M.get_all_nodes
-- @param bufnr   buffer handle (defaults to the current buffer); its "ft"
--                option is used as the tree-sitter language
-- @param row, col position to look up, in the same base as the wrappers'
--                `dim`; when either is nil the cursor of the current window
--                is used instead
-- @return array of node wrappers, or nil (after printing a message) when
--         nothing was found
M.nodes_at_cursor = function(query, default, bufnr, row, col)
  bufnr = bufnr or vim.api.nvim_get_current_buf()
  local ft = vim.api.nvim_buf_get_option(bufnr, "ft")
  if row == nil or col == nil then
    row, col = unpack(vim.api.nvim_win_get_cursor(0))
    -- nvim_win_get_cursor is (1,0)-indexed: the row already matches the
    -- "+ 1" rows stored in `dim`, so only the column needs shifting.
    col = col + 1
  end

  local nodes = M.get_all_nodes(query, ft, default, bufnr, row, col)
  if nodes == nil then
    print("Unable to find any nodes.  Is your query correct?")
    return nil
  end

  nodes = M.sort_nodes(M.intersect_nodes(nodes, row, col))
  if nodes == nil or #nodes == 0 then
    print("Unable to find any nodes at pos. " .. tostring(row) .. ":" .. tostring(col))
    return nil
  end

  return nodes
end

return M