cursor range col/row calculate

pull/21/head
ray-x 3 years ago
parent e7a7e95c85
commit a093d2420b

@ -5,7 +5,7 @@ M.test_fun = function(args)
local fpath = vim.fn.expand('%:p:h') local fpath = vim.fn.expand('%:p:h')
local row, col = unpack(vim.api.nvim_win_get_cursor(0)) local row, col = unpack(vim.api.nvim_win_get_cursor(0))
row, col = row + 1, col + 1 row, col = row, col + 1
local ns = require("go.ts.go").get_func_method_node_at_pos(row, col) local ns = require("go.ts.go").get_func_method_node_at_pos(row, col)
if ns == nil or ns == {} then if ns == nil or ns == {} then
return return

@ -18,7 +18,6 @@ tags.modify = function(...)
require("go.install").install(gomodify) require("go.install").install(gomodify)
local fname = vim.fn.expand("%") -- %:p:h ? %:p local fname = vim.fn.expand("%") -- %:p:h ? %:p
local row, col = unpack(vim.api.nvim_win_get_cursor(0)) local row, col = unpack(vim.api.nvim_win_get_cursor(0))
row, col = row + 1, col + 1
local ns = require("go.ts.go").get_struct_node_at_pos(row, col) local ns = require("go.ts.go").get_struct_node_at_pos(row, col)
if ns == nil or ns == {} then if ns == nil or ns == {} then
return return

@ -1,5 +1,7 @@
local nodes = require("go.ts.nodes") local nodes = require("go.ts.nodes")
local log = require("go.utils").log
M = { M = {
-- query_struct = "(type_spec name:(type_identifier) @definition.struct type: (struct_type))", -- query_struct = "(type_spec name:(type_identifier) @definition.struct type: (struct_type))",
query_package = "(package_clause (package_identifier)@package.name)@package.clause", query_package = "(package_clause (package_identifier)@package.name)@package.clause",
@ -8,7 +10,7 @@ M = {
query_struct_block = [[((type_declaration (type_spec name:(type_identifier) @struct.name type: (struct_type)))@struct.declaration)]], query_struct_block = [[((type_declaration (type_spec name:(type_identifier) @struct.name type: (struct_type)))@struct.declaration)]],
query_em_struct_block = [[(field_declaration name:(field_identifier)@struct.name type: (struct_type)) @struct.declaration]], query_em_struct_block = [[(field_declaration name:(field_identifier)@struct.name type: (struct_type)) @struct.declaration]],
query_struct_block_from_id = [[(((type_spec name:(type_identifier) type: (struct_type)))@block.struct_from_id)]], query_struct_block_from_id = [[(((type_spec name:(type_identifier) type: (struct_type)))@block.struct_from_id)]],
--query_em_struct = "(field_declaration name:(field_identifier) @definition.struct type: (struct_type))", -- query_em_struct = "(field_declaration name:(field_identifier) @definition.struct type: (struct_type))",
query_interface_id = [[((type_declaration (type_spec name:(type_identifier) @interface.name type:(interface_type)))@interface.declaration)]], query_interface_id = [[((type_declaration (type_spec name:(type_identifier) @interface.name type:(interface_type)))@interface.declaration)]],
query_interface_method = [[((method_spec name: (field_identifier)@method.name)@interface.method.declaration)]], query_interface_method = [[((method_spec name: (field_identifier)@method.name)@interface.method.declaration)]],
query_func = "((function_declaration name: (identifier)@function.name) @function.declaration)", query_func = "((function_declaration name: (identifier)@function.name) @function.declaration)",
@ -84,12 +86,7 @@ M = {
)@method.declaration)]] )@method.declaration)]]
} }
function get_name_defaults() function get_name_defaults()
return { return {["func"] = "function", ["if"] = "if", ["else"] = "else", ["for"] = "for"}
["func"] = "function",
["if"] = "if",
["else"] = "else",
["for"] = "for"
}
end end
M.get_struct_node_at_pos = function(row, col) M.get_struct_node_at_pos = function(row, col)
@ -99,6 +96,7 @@ M.get_struct_node_at_pos = function(row, col)
if ns == nil then if ns == nil then
print("struct not found") print("struct not found")
else else
log('struct node', ns)
return ns[#ns] return ns[#ns]
end end
end end

@ -1,5 +1,4 @@
-- part of the code from polarmutex/contextprint.nvim -- part of the code from polarmutex/contextprint.nvim
local ts_utils = require("nvim-treesitter.ts_utils") local ts_utils = require("nvim-treesitter.ts_utils")
local ts_query = require("nvim-treesitter.query") local ts_query = require("nvim-treesitter.query")
local parsers = require("nvim-treesitter.parsers") local parsers = require("nvim-treesitter.parsers")
@ -40,12 +39,9 @@ end
-- @param nodes Array<node_wrapper> -- @param nodes Array<node_wrapper>
-- perf note. I could memoize some values here... -- perf note. I could memoize some values here...
M.sort_nodes = function(nodes) M.sort_nodes = function(nodes)
table.sort( table.sort(nodes, function(a, b)
nodes, return M.count_parents(a) < M.count_parents(b)
function(a, b) end)
return M.count_parents(a) < M.count_parents(b)
end
)
return nodes return nodes
end end
@ -59,12 +55,9 @@ end
-- }] -- }]
M.get_nodes = function(query, lang, defaults, bufnr) M.get_nodes = function(query, lang, defaults, bufnr)
bufnr = bufnr or 0 bufnr = bufnr or 0
local success, parsed_query = local success, parsed_query = pcall(function()
pcall( return vim.treesitter.parse_query(lang, query)
function() end)
return vim.treesitter.parse_query(lang, query)
end
)
if not success then if not success then
return nil return nil
end end
@ -80,45 +73,36 @@ M.get_nodes = function(query, lang, defaults, bufnr)
local declaration_node local declaration_node
local type = "nil" local type = "nil"
local name = "nil" local name = "nil"
locals.recurse_local_nodes( locals.recurse_local_nodes(match, function(_, node, path)
match, local idx = string.find(path, ".", 1, true)
function(_, node, path) local op = string.sub(path, idx + 1, #path)
local idx = string.find(path, ".", 1, true)
local op = string.sub(path, idx + 1, #path) local a1, b1, c1, d1 = ts_utils.get_node_range(node)
local a1, b1, c1, d1 = ts_utils.get_node_range(node) type = string.sub(path, 1, idx - 1)
if name == nil then
type = string.sub(path, 1, idx - 1) name = defaults[type] or "empty"
if name == nil then end
name = defaults[type] or "empty"
end if op == "name" then
name = ts_utils.get_node_text(node)[1]
if op == "name" then elseif op == "declaration" then
name = ts_utils.get_node_text(node)[1] declaration_node = node
elseif op == "declaration" then sRow, sCol, eRow, eCol = node:range()
declaration_node = node sRow = sRow + 1
sRow, sCol, eRow, eCol = node:range() eRow = eRow + 1
sRow = sRow + 1 sCol = sCol + 1
eRow = eRow + 1 eCol = eCol + 1
sCol = sCol + 1
eCol = eCol + 1
end
end end
) end)
if declaration_node ~= nil then if declaration_node ~= nil then
table.insert( table.insert(results, {
results, declaring_node = declaration_node,
{ dim = {s = {r = sRow, c = sCol}, e = {r = eRow, c = eCol}},
declaring_node = declaration_node, name = name,
dim = { type = type
s = {r = sRow, c = sCol}, })
e = {r = eRow, c = eCol}
},
name = name,
type = type
}
)
end end
end end
@ -137,12 +121,9 @@ M.get_all_nodes = function(query, lang, defaults, bufnr, pos_row, pos_col)
bufnr = bufnr or 0 bufnr = bufnr or 0
-- todo a huge number -- todo a huge number
pos_row = pos_row or 30000 pos_row = pos_row or 30000
local success, parsed_query = local success, parsed_query = pcall(function()
pcall( return vim.treesitter.parse_query(lang, query)
function() end)
return vim.treesitter.parse_query(lang, query)
end
)
if not success then if not success then
return nil return nil
end end
@ -163,55 +144,49 @@ M.get_all_nodes = function(query, lang, defaults, bufnr, pos_row, pos_col)
local op = "" local op = ""
local method_receiver = "" local method_receiver = ""
locals.recurse_local_nodes( locals.recurse_local_nodes(match, function(_, node, path)
match, -- local idx = string.find(path, ".", 1, true)
function(_, node, path) local idx = string.find(path, ".[^.]*$") -- find last .
--local idx = string.find(path, ".", 1, true) op = string.sub(path, idx + 1, #path)
local idx = string.find(path, ".[^.]*$") -- find last . local a1, b1, c1, d1 = ts_utils.get_node_range(node)
op = string.sub(path, idx + 1, #path) local dbg_txt = ts_utils.get_node_text(node, bufnr)[1]
local a1, b1, c1, d1 = ts_utils.get_node_range(node) type = string.sub(path, 1, idx - 1)
local dbg_txt = ts_utils.get_node_text(node, bufnr)[1]
type = string.sub(path, 1, idx - 1) ulog("node ", vim.inspect(node),
"\n path: " .. path .. " op: " .. op .. " type: " .. type .. "\n txt: " .. dbg_txt
ulog( "node" .. vim.inspect(node) .. "\n path: " .. path .. " op: " .. op .. " type: " .. type .. "\n txt: " .. dbg_txt .. "\n range: " .. tostring(a1) .. ":" .. tostring(b1) .. " TO " .. tostring(c1) .. ":" .. tostring(d1)) .. "\n range: " .. tostring(a1) .. ":" .. tostring(b1) .. " TO " .. tostring(c1)
-- .. ":" .. tostring(d1))
-- may not handle complex node --
if op == "name" then -- may not handle complex node
-- ulog("node name " .. name) if op == "name" then
name = ts_utils.get_node_text(node, bufnr)[1] -- ulog("node name " .. name)
elseif op == "declaration" or op == "clause" then name = ts_utils.get_node_text(node, bufnr)[1]
declaration_node = node elseif op == "declaration" or op == "clause" then
sRow, sCol, eRow, eCol = node:range() declaration_node = node
sRow = sRow + 1 sRow, sCol, eRow, eCol = node:range()
eRow = eRow + 1 sRow = sRow + 1
sCol = sCol + 1 eRow = eRow + 1
eCol = eCol + 1 sCol = sCol + 1
end eCol = eCol + 1
end end
) end)
if declaration_node ~= nil then if declaration_node ~= nil then
-- ulog(name .. " " .. op) -- ulog(name .. " " .. op)
-- ulog(sRow, pos_row) -- ulog(sRow, pos_row)
if sRow > pos_row then if sRow > pos_row then
ulog("beyond " .. tostring(pos_row)) ulog("beyond " .. tostring(pos_row))
break -- break
end end
table.insert( table.insert(results, {
results, declaring_node = declaration_node,
{ dim = {s = {r = sRow, c = sCol}, e = {r = eRow, c = eCol}},
declaring_node = declaration_node, name = name,
dim = { operator = op,
s = {r = sRow, c = sCol}, type = type
e = {r = eRow, c = eCol} })
},
name = name,
operator = op,
type = type
}
)
end end
end end
-- ulog("total nodes got: " .. tostring(#results)) ulog("total nodes got: " .. tostring(#results))
return results return results
end end
@ -220,7 +195,6 @@ M.nodes_at_cursor = function(query, default, bufnr, row, col)
local ft = vim.api.nvim_buf_get_option(bufnr, "ft") local ft = vim.api.nvim_buf_get_option(bufnr, "ft")
if row == nil or col == nil then if row == nil or col == nil then
row, col = unpack(vim.api.nvim_win_get_cursor(0)) row, col = unpack(vim.api.nvim_win_get_cursor(0))
row, col = row + 1, col + 1
end end
local nodes = M.get_all_nodes(query, ft, default, bufnr, row, col) local nodes = M.get_all_nodes(query, ft, default, bufnr, row, col)
if nodes == nil then if nodes == nil then

@ -1,6 +1,7 @@
local M = {} local M = {}
-- local ulog = require("go.utils").log
M.intersects = function(row, col, sRow, sCol, eRow, eCol) M.intersects = function(row, col, sRow, sCol, eRow, eCol)
-- print(row, col, sRow, sCol, eRow, eCol) -- ulog(row, col, sRow, sCol, eRow, eCol)
if sRow > row or eRow < row then if sRow > row or eRow < row then
return false return false
end end

Loading…
Cancel
Save