cursor range col/row calculate

This commit is contained in:
ray-x 2021-07-17 01:58:44 +10:00
parent e7a7e95c85
commit a093d2420b
5 changed files with 75 additions and 103 deletions

View File

@ -5,7 +5,7 @@ M.test_fun = function(args)
local fpath = vim.fn.expand('%:p:h')
local row, col = unpack(vim.api.nvim_win_get_cursor(0))
row, col = row + 1, col + 1
row, col = row, col + 1
local ns = require("go.ts.go").get_func_method_node_at_pos(row, col)
if ns == nil or ns == {} then
return

View File

@ -18,7 +18,6 @@ tags.modify = function(...)
require("go.install").install(gomodify)
local fname = vim.fn.expand("%") -- %:p:h ? %:p
local row, col = unpack(vim.api.nvim_win_get_cursor(0))
row, col = row + 1, col + 1
local ns = require("go.ts.go").get_struct_node_at_pos(row, col)
if ns == nil or ns == {} then
return

View File

@ -1,5 +1,7 @@
local nodes = require("go.ts.nodes")
local log = require("go.utils").log
M = {
-- query_struct = "(type_spec name:(type_identifier) @definition.struct type: (struct_type))",
query_package = "(package_clause (package_identifier)@package.name)@package.clause",
@ -8,7 +10,7 @@ M = {
query_struct_block = [[((type_declaration (type_spec name:(type_identifier) @struct.name type: (struct_type)))@struct.declaration)]],
query_em_struct_block = [[(field_declaration name:(field_identifier)@struct.name type: (struct_type)) @struct.declaration]],
query_struct_block_from_id = [[(((type_spec name:(type_identifier) type: (struct_type)))@block.struct_from_id)]],
--query_em_struct = "(field_declaration name:(field_identifier) @definition.struct type: (struct_type))",
-- query_em_struct = "(field_declaration name:(field_identifier) @definition.struct type: (struct_type))",
query_interface_id = [[((type_declaration (type_spec name:(type_identifier) @interface.name type:(interface_type)))@interface.declaration)]],
query_interface_method = [[((method_spec name: (field_identifier)@method.name)@interface.method.declaration)]],
query_func = "((function_declaration name: (identifier)@function.name) @function.declaration)",
@ -84,12 +86,7 @@ M = {
)@method.declaration)]]
}
function get_name_defaults()
return {
["func"] = "function",
["if"] = "if",
["else"] = "else",
["for"] = "for"
}
return {["func"] = "function", ["if"] = "if", ["else"] = "else", ["for"] = "for"}
end
M.get_struct_node_at_pos = function(row, col)
@ -99,6 +96,7 @@ M.get_struct_node_at_pos = function(row, col)
if ns == nil then
print("struct not found")
else
log('struct node', ns)
return ns[#ns]
end
end

View File

@ -1,5 +1,4 @@
-- part of the code from polarmutex/contextprint.nvim
local ts_utils = require("nvim-treesitter.ts_utils")
local ts_query = require("nvim-treesitter.query")
local parsers = require("nvim-treesitter.parsers")
@ -40,12 +39,9 @@ end
-- @param nodes Array<node_wrapper>
-- perf note. I could memoize some values here...
M.sort_nodes = function(nodes)
table.sort(
nodes,
function(a, b)
return M.count_parents(a) < M.count_parents(b)
end
)
table.sort(nodes, function(a, b)
return M.count_parents(a) < M.count_parents(b)
end)
return nodes
end
@ -59,12 +55,9 @@ end
-- }]
M.get_nodes = function(query, lang, defaults, bufnr)
bufnr = bufnr or 0
local success, parsed_query =
pcall(
function()
return vim.treesitter.parse_query(lang, query)
end
)
local success, parsed_query = pcall(function()
return vim.treesitter.parse_query(lang, query)
end)
if not success then
return nil
end
@ -80,45 +73,36 @@ M.get_nodes = function(query, lang, defaults, bufnr)
local declaration_node
local type = "nil"
local name = "nil"
locals.recurse_local_nodes(
match,
function(_, node, path)
local idx = string.find(path, ".", 1, true)
local op = string.sub(path, idx + 1, #path)
locals.recurse_local_nodes(match, function(_, node, path)
local idx = string.find(path, ".", 1, true)
local op = string.sub(path, idx + 1, #path)
local a1, b1, c1, d1 = ts_utils.get_node_range(node)
local a1, b1, c1, d1 = ts_utils.get_node_range(node)
type = string.sub(path, 1, idx - 1)
if name == nil then
name = defaults[type] or "empty"
end
if op == "name" then
name = ts_utils.get_node_text(node)[1]
elseif op == "declaration" then
declaration_node = node
sRow, sCol, eRow, eCol = node:range()
sRow = sRow + 1
eRow = eRow + 1
sCol = sCol + 1
eCol = eCol + 1
end
type = string.sub(path, 1, idx - 1)
if name == nil then
name = defaults[type] or "empty"
end
)
if op == "name" then
name = ts_utils.get_node_text(node)[1]
elseif op == "declaration" then
declaration_node = node
sRow, sCol, eRow, eCol = node:range()
sRow = sRow + 1
eRow = eRow + 1
sCol = sCol + 1
eCol = eCol + 1
end
end)
if declaration_node ~= nil then
table.insert(
results,
{
declaring_node = declaration_node,
dim = {
s = {r = sRow, c = sCol},
e = {r = eRow, c = eCol}
},
name = name,
type = type
}
)
table.insert(results, {
declaring_node = declaration_node,
dim = {s = {r = sRow, c = sCol}, e = {r = eRow, c = eCol}},
name = name,
type = type
})
end
end
@ -137,12 +121,9 @@ M.get_all_nodes = function(query, lang, defaults, bufnr, pos_row, pos_col)
bufnr = bufnr or 0
-- todo a huge number
pos_row = pos_row or 30000
local success, parsed_query =
pcall(
function()
return vim.treesitter.parse_query(lang, query)
end
)
local success, parsed_query = pcall(function()
return vim.treesitter.parse_query(lang, query)
end)
if not success then
return nil
end
@ -163,55 +144,49 @@ M.get_all_nodes = function(query, lang, defaults, bufnr, pos_row, pos_col)
local op = ""
local method_receiver = ""
locals.recurse_local_nodes(
match,
function(_, node, path)
--local idx = string.find(path, ".", 1, true)
local idx = string.find(path, ".[^.]*$") -- find last .
op = string.sub(path, idx + 1, #path)
local a1, b1, c1, d1 = ts_utils.get_node_range(node)
local dbg_txt = ts_utils.get_node_text(node, bufnr)[1]
type = string.sub(path, 1, idx - 1)
locals.recurse_local_nodes(match, function(_, node, path)
-- local idx = string.find(path, ".", 1, true)
local idx = string.find(path, ".[^.]*$") -- find last .
op = string.sub(path, idx + 1, #path)
local a1, b1, c1, d1 = ts_utils.get_node_range(node)
local dbg_txt = ts_utils.get_node_text(node, bufnr)[1]
type = string.sub(path, 1, idx - 1)
ulog( "node" .. vim.inspect(node) .. "\n path: " .. path .. " op: " .. op .. " type: " .. type .. "\n txt: " .. dbg_txt .. "\n range: " .. tostring(a1) .. ":" .. tostring(b1) .. " TO " .. tostring(c1) .. ":" .. tostring(d1))
--
-- may not handle complex node
if op == "name" then
-- ulog("node name " .. name)
name = ts_utils.get_node_text(node, bufnr)[1]
elseif op == "declaration" or op == "clause" then
declaration_node = node
sRow, sCol, eRow, eCol = node:range()
sRow = sRow + 1
eRow = eRow + 1
sCol = sCol + 1
eCol = eCol + 1
end
ulog("node ", vim.inspect(node),
"\n path: " .. path .. " op: " .. op .. " type: " .. type .. "\n txt: " .. dbg_txt
.. "\n range: " .. tostring(a1) .. ":" .. tostring(b1) .. " TO " .. tostring(c1)
.. ":" .. tostring(d1))
--
-- may not handle complex node
if op == "name" then
-- ulog("node name " .. name)
name = ts_utils.get_node_text(node, bufnr)[1]
elseif op == "declaration" or op == "clause" then
declaration_node = node
sRow, sCol, eRow, eCol = node:range()
sRow = sRow + 1
eRow = eRow + 1
sCol = sCol + 1
eCol = eCol + 1
end
)
end)
if declaration_node ~= nil then
-- ulog(name .. " " .. op)
-- ulog(sRow, pos_row)
if sRow > pos_row then
ulog("beyond " .. tostring(pos_row))
break
-- break
end
table.insert(
results,
{
declaring_node = declaration_node,
dim = {
s = {r = sRow, c = sCol},
e = {r = eRow, c = eCol}
},
name = name,
operator = op,
type = type
}
)
table.insert(results, {
declaring_node = declaration_node,
dim = {s = {r = sRow, c = sCol}, e = {r = eRow, c = eCol}},
name = name,
operator = op,
type = type
})
end
end
-- ulog("total nodes got: " .. tostring(#results))
ulog("total nodes got: " .. tostring(#results))
return results
end
@ -220,7 +195,6 @@ M.nodes_at_cursor = function(query, default, bufnr, row, col)
local ft = vim.api.nvim_buf_get_option(bufnr, "ft")
if row == nil or col == nil then
row, col = unpack(vim.api.nvim_win_get_cursor(0))
row, col = row + 1, col + 1
end
local nodes = M.get_all_nodes(query, ft, default, bufnr, row, col)
if nodes == nil then

View File

@ -1,6 +1,7 @@
local M = {}
-- local ulog = require("go.utils").log
M.intersects = function(row, col, sRow, sCol, eRow, eCol)
-- print(row, col, sRow, sCol, eRow, eCol)
-- ulog(row, col, sRow, sCol, eRow, eCol)
if sRow > row or eRow < row then
return false
end