334 lines
11 KiB
Lua
334 lines
11 KiB
Lua
local nodes = require('go.ts.nodes')
|
|
|
|
local tsutil = require('nvim-treesitter.ts_utils')
|
|
local log = require('go.utils').log
|
|
local warn = require('go.utils').warn
|
|
local info = require('go.utils').info
|
|
local debug = require('go.utils').debug
|
|
-- debug = log
|
|
|
|
-- Tree-sitter query strings for Go source, keyed by purpose. The strings are
-- consumed by the lookup helpers defined on this same table further down.
local M = {
  -- `type X struct {...}`: captures the type name.
  query_struct = '(type_spec name:(type_identifier) @definition.struct type: (struct_type))',
  -- `package x`: captures the package identifier and the whole clause.
  query_package = '(package_clause (package_identifier)@package.name)@package.clause',
  -- Same shape as query_struct but without the `type:` field label.
  query_struct_id = '(type_spec name:(type_identifier) @definition.struct (struct_type))',
  -- Struct-typed field inside another declaration: captures the field name.
  query_em_struct_id = '(field_declaration name:(field_identifier) @definition.struct (struct_type))',
  -- Whole `type X struct {...}` declaration plus its name.
  query_struct_block = [[((type_declaration (type_spec name:(type_identifier) @struct.name type: (struct_type)))@struct.declaration)]],
  -- query_type_declaration = [[((type_declaration (type_spec name:(type_identifier)@type_decl.name type:(type_identifier)@type_decl.type))@type_decl.declaration)]], -- rename to gotype so not confuse with type
  -- Any type declaration: captures only the declared name.
  query_type_declaration = [[((type_declaration (type_spec name:(type_identifier)@type_decl.name)))]],
  -- Struct-typed field: captures the field name and the whole field declaration.
  query_em_struct_block = [[(field_declaration name:(field_identifier)@struct.name type: (struct_type)) @struct.declaration]],
  -- The type_spec of a struct type, captured as one block.
  query_struct_block_from_id = [[(((type_spec name:(type_identifier) type: (struct_type)))@block.struct_from_id)]],
  -- query_em_struct = "(field_declaration name:(field_identifier) @definition.struct type: (struct_type))",
  -- `type X interface {...}`: captures the name and the whole declaration.
  query_interface_id = [[((type_declaration (type_spec name:(type_identifier) @interface.name type:(interface_type)))@interface.declaration)]],
  -- A method_spec node: captures its name and the whole spec.
  query_interface_method = [[((method_spec name: (field_identifier)@method.name)@interface.method.declaration)]],
  -- Top-level function declaration: captures the name and the declaration.
  query_func = '((function_declaration name: (identifier)@function.name) @function.declaration)',
  -- Method with a named, non-pointer (type_identifier) receiver.
  query_method = '(method_declaration receiver: (parameter_list (parameter_declaration name:(identifier)@method.receiver.name type:(type_identifier)@method.receiver.type)) name:(field_identifier)@method.name)@method.declaration',
  -- Any method that has a body: captures the receiver list and the method name.
  query_method_name = [[((method_declaration
     receiver: (parameter_list)@method.receiver
     name: (field_identifier)@method.name
     body:(block))@method.declaration)]],
  -- Pointer-receiver method; the pattern names no `result:` field.
  query_method_void = [[((method_declaration
     receiver: (parameter_list
       (parameter_declaration
         name: (identifier)@method.receiver.name
         type: (pointer_type)@method.receiver.type)
       )
     name: (field_identifier)@method.name
     parameters: (parameter_list)@method.parameter
     body:(block)
  )@method.declaration)]],
  -- Pointer-receiver method whose result is a parameter_list.
  query_method_multi_ret = [[(method_declaration
     receiver: (parameter_list
       (parameter_declaration
         name: (identifier)@method.receiver.name
         type: (pointer_type)@method.receiver.type)
       )
     name: (field_identifier)@method.name
     parameters: (parameter_list)@method.parameter
     result: (parameter_list)@method.result
     body:(block)
     )@method.declaration]],
  -- Pointer-receiver method whose result is a single type_identifier.
  query_method_single_ret = [[((method_declaration
     receiver: (parameter_list
       (parameter_declaration
         name: (identifier)@method.receiver.name
         type: (pointer_type)@method.receiver.type)
       )
     name: (field_identifier)@method.name
     parameters: (parameter_list)@method.parameter
     result: (type_identifier)@method.result
     body:(block)
     )@method.declaration)]],
  -- The three `query_tr_*` variants mirror the three above, but match a
  -- type_identifier receiver instead of a pointer_type receiver.
  query_tr_method_void = [[((method_declaration
     receiver: (parameter_list
       (parameter_declaration
         name: (identifier)@method.receiver.name
         type: (type_identifier)@method.receiver.type)
       )
     name: (field_identifier)@method.name
     parameters: (parameter_list)@method.parameter
     body:(block)
  )@method.declaration)]],
  query_tr_method_multi_ret = [[((method_declaration
     receiver: (parameter_list
       (parameter_declaration
         name: (identifier)@method.receiver.name
         type: (type_identifier)@method.receiver.type)
       )
     name: (field_identifier)@method.name
     parameters: (parameter_list)@method.parameter
     result: (parameter_list)@method.result
     body:(block)
     )@method.declaration)]],
  query_tr_method_single_ret = [[((method_declaration
     receiver: (parameter_list
       (parameter_declaration
         name: (identifier)@method.receiver.name
         type: (type_identifier)@method.receiver.type)
       )
     name: (field_identifier)@method.name
     parameters: (parameter_list)@method.parameter
     result: (type_identifier)@method.result
     body:(block)
     )@method.declaration)]],
  -- Function whose name contains "Test" and which takes a parameter of type
  -- pointer-to-qualified-type (package.Name).
  query_test_func = [[((function_declaration name: (identifier) @test_name
        parameters: (parameter_list
            (parameter_declaration
                     name: (identifier)
                     type: (pointer_type
                         (qualified_type
                          package: (package_identifier) @_param_package
                          name: (type_identifier) @_param_name))))
         ) @testfunc
      (#contains? @test_name "Test"))]],
  -- Element of a composite literal whose first keyed element holds an
  -- interpreted string; captures that string and the enclosing element.
  query_tbl_testcase_node = [[ ( literal_value (
      literal_element (
        literal_value .(
          keyed_element
            (literal_element (identifier))
            (literal_element (interpreted_string_literal) @test.name)
         )
       ) @test.block
    ))
  ]],
  -- Call of a selector named "Run" with a string literal and a func literal
  -- as arguments; captures the string and the whole call.
  query_sub_testcase_node = [[ (call_expression
    (selector_expression
      (field_identifier) @method.name)
    (argument_list
      (interpreted_string_literal) @tc.name
      (func_literal) )
    (#eq? @method.name "Run")
  ) @tc.run ]],
  -- Any interpreted (double-quoted) string literal.
  query_string_literal = [[((interpreted_string_literal) @string.value)]],
}
|
|
|
|
--- Build the defaults table passed as the second argument to
-- `nodes.nodes_at_cursor` by the lookup helpers in this module.
-- A new table is created on every call, so callers may mutate the result.
-- @return table mapping 'func', 'if', 'else' and 'for' to their names
local function get_name_defaults()
  local defaults = {}
  defaults['func'] = 'function'
  defaults['if'] = 'if'
  defaults['else'] = 'else'
  defaults['for'] = 'for'
  return defaults
end
|
|
|
|
--- Look up the struct declaration at the cursor.
-- Runs the top-level struct query and the embedded (field) struct query
-- together through `nodes.nodes_at_cursor`.
-- @param bufnr buffer handle; the current buffer is used when nil
-- @return the last entry of the result list, or nothing when the lookup
--         returned nil
M.get_struct_node_at_pos = function(bufnr)
  local target_buf = bufnr or vim.api.nvim_get_current_buf()
  local struct_query = table.concat({ M.query_struct_block, M.query_em_struct_block }, ' ')
  local found = nodes.nodes_at_cursor(struct_query, get_name_defaults(), target_buf)
  if found ~= nil then
    log('struct node', found)
    return found[#found]
  end
  debug('struct not found')
end
|
|
|
|
--- Look up the type declaration at the cursor using `M.query_type_declaration`.
-- @param bufnr buffer handle; the current buffer is used when nil
-- @return the last entry of the list returned by `nodes.nodes_at_cursor`,
--         or nothing when the lookup returned nil
M.get_type_node_at_pos = function(bufnr)
  local target_buf = bufnr or vim.api.nvim_get_current_buf()
  local found = nodes.nodes_at_cursor(M.query_type_declaration, get_name_defaults(), target_buf)
  if found ~= nil then
    log('type node', found)
    return found[#found]
  end
  debug('type not found')
end
|
|
|
|
--- Look up the interface declaration at the cursor using `M.query_interface_id`.
-- @param bufnr buffer handle; the current buffer is used when nil
-- @return the last entry of the list returned by `nodes.nodes_at_cursor`,
--         or nothing when the lookup returned nil
M.get_interface_node_at_pos = function(bufnr)
  local target_buf = bufnr or vim.api.nvim_get_current_buf()
  local found = nodes.nodes_at_cursor(M.query_interface_id, get_name_defaults(), target_buf)
  if found ~= nil then
    return found[#found]
  end
  debug('interface not found')
end
|
|
|
|
--- Look up the interface method spec at the cursor using
-- `M.query_interface_method`.
-- Unlike the struct/type/interface lookups, a miss is reported through
-- `warn` rather than `debug`.
-- @param bufnr buffer handle; the current buffer is used when nil
-- @return the last entry of the list returned by `nodes.nodes_at_cursor`,
--         or nothing when the lookup returned nil
M.get_interface_method_node_at_pos = function(bufnr)
  local target_buf = bufnr or vim.api.nvim_get_current_buf()
  local found = nodes.nodes_at_cursor(M.query_interface_method, get_name_defaults(), target_buf)
  if found ~= nil then
    return found[#found]
  end
  warn('interface method not found')
end
|
|
|
|
--- Look up the function or method declaration at the cursor.
-- Combines `M.query_func` and `M.query_method_name` into one query.
-- A miss is silent: the previous code had a second `ns == nil` check that
-- would have warned, but it was unreachable behind the early return, so the
-- effective behavior (return nil without a message) is kept.
-- @param bufnr buffer handle; the current buffer is used when nil
-- @return the last entry of the list returned by `nodes.nodes_at_cursor`,
--         or nil when the lookup returned nil
M.get_func_method_node_at_pos = function(bufnr)
  local query = M.query_func .. ' ' .. M.query_method_name
  local bufn = bufnr or vim.api.nvim_get_current_buf()

  local ns = nodes.nodes_at_cursor(query, get_name_defaults(), bufn)
  if ns == nil then
    return nil
  end
  return ns[#ns]
end
|
|
|
|
--- Return the name of the table-driven test case whose literal contains the
-- cursor line.
-- Runs `M.query_tbl_testcase_node` over the first parsed tree of the buffer
-- and returns the text of the `test.name` capture of the first match whose
-- `test.block` capture spans the cursor row.
-- @param bufnr buffer handle; the current buffer is used when nil
-- @return first line of the captured name node's text (as returned by
--         `tsutil.get_node_text`), or nil when no block contains the cursor
--         or no parser is available
M.get_tbl_testcase_node_name = function(bufnr)
  local bufn = bufnr or vim.api.nvim_get_current_buf()
  local ok, parser = pcall(vim.treesitter.get_parser, bufn)
  if not ok or not parser then
    return
  end
  local tree = parser:parse()[1]
  if not tree then
    return
  end

  local tbl_case_query = vim.treesitter.query.parse('go', M.query_tbl_testcase_node)

  -- nvim_win_get_cursor reports a 1-based row while TSNode:range() is
  -- 0-based; convert once so the containment test below compares like with like.
  local curr_row = vim.api.nvim_win_get_cursor(0)[1] - 1

  for _, match in tbl_case_query:iter_matches(tree:root(), bufn, 0, -1) do
    -- Collect both captures first so the result does not depend on the order
    -- in which pairs() happens to visit them.
    local name_node, block_node
    for id, node in pairs(match) do
      -- Newer Neovim releases hand back a list of nodes per capture id
      -- instead of a single node (a TSNode is userdata, never a table).
      if type(node) == 'table' then
        node = node[1]
      end
      local capture = tbl_case_query.captures[id]
      if capture == 'test.name' then
        name_node = node
      elseif capture == 'test.block' then
        block_node = node
      end
    end

    if block_node then
      local start_row, _, end_row, _ = block_node:range()
      if curr_row >= start_row and curr_row <= end_row then
        if name_node then
          return tsutil.get_node_text(name_node, bufn)[1]
        end
        return nil
      end
    end
  end
  return nil
end
|
|
|
|
--- Return the name argument of the `Run(...)` call that contains the cursor.
-- Iterates the captures of `M.query_sub_testcase_node`; every `tc.run`
-- capture updates whether the cursor row lies inside that call, and the
-- first `tc.name` capture seen while inside is returned.
-- @param bufnr buffer handle; the current buffer is used when nil
-- @return first line of the captured string node's text (as returned by
--         `tsutil.get_node_text`), or nil when the cursor is in no such call
--         or no parser is available
M.get_sub_testcase_name = function(bufnr)
  local bufn = bufnr or vim.api.nvim_get_current_buf()
  local sub_case_query = vim.treesitter.query.parse('go', M.query_sub_testcase_node)

  local ok, parser = pcall(vim.treesitter.get_parser, bufn)
  if not ok or not parser then
    return
  end
  local tree = parser:parse()[1]

  -- nvim_win_get_cursor reports a 1-based row while TSNode:range() is
  -- 0-based; convert so the range check is not shifted by one line.
  local curr_row = vim.api.nvim_win_get_cursor(0)[1] - 1

  local is_inside_test = false
  for id, node in sub_case_query:iter_captures(tree:root(), bufn, 0, -1) do
    local name = sub_case_query.captures[id]
    if name == 'tc.run' then
      -- tc.run is seen before the tc.name of the same call, so it decides
      -- whether the following tc.name capture is the one we want.
      local start_row, _, end_row, _ = node:range()
      is_inside_test = curr_row >= start_row and curr_row <= end_row
    elseif name == 'tc.name' and is_inside_test then
      return tsutil.get_node_text(node, bufn)[1]
    end
  end
  return nil
end
|
|
|
|
--- Look up the interpreted string literal at the cursor using
-- `M.query_string_literal`.
-- The extra 'value' argument is forwarded to `nodes.nodes_at_cursor` as its
-- fourth parameter.
-- @param bufnr buffer handle; the current buffer is used when nil
-- @return the last entry of the list returned by `nodes.nodes_at_cursor`,
--         or nothing when the lookup returned nil
M.get_string_node = function(bufnr)
  local query = M.query_string_literal
  local bufn = bufnr or vim.api.nvim_get_current_buf()
  local ns = nodes.nodes_at_cursor(query, get_name_defaults(), bufn, 'value')
  if ns == nil then
    -- Messages previously said "struct" (copied from the struct lookup).
    debug('string not found')
  else
    log('string node', ns[#ns])
    return ns[#ns]
  end
end
|
|
|
|
--- Return the node under the cursor when it is an `import_spec` or a direct
-- child of one.
-- Note that the cursor node itself is returned, not its `import_spec` parent.
-- @param bufnr accepted for interface compatibility; the node always comes
--        from `tsutil.get_node_at_cursor()` and this argument is not consulted
-- @return the cursor node, or nothing when it is not part of an import spec
M.get_import_node_at_pos = function(bufnr)
  local cur_node = tsutil.get_node_at_cursor()
  if not cur_node then
    return
  end

  if cur_node:type() == 'import_spec' then
    return cur_node
  end

  -- A node without a parent has parent() == nil; guard before calling
  -- :type() on it.
  local parent = cur_node:parent()
  if parent and parent:type() == 'import_spec' then
    return cur_node
  end
end
|
|
|
|
--- Return the import text at the cursor with all double quotes removed.
-- The node comes from `M.get_import_node_at_pos`; its text is read from the
-- current buffer via `go.utils.get_node_text`.
-- @param bufnr forwarded to `M.get_import_node_at_pos`
-- @return the unquoted text, or nothing when the cursor is not on an import
M.get_module_at_pos = function(bufnr)
  local import_node = M.get_import_node_at_pos(bufnr)
  if not import_node then
    return
  end

  local text = require('go.utils').get_node_text(import_node, vim.api.nvim_get_current_buf())
  -- Assign to a single local so gsub's second result (the count) is dropped.
  local unquoted = string.gsub(text, '"', '')
  return unquoted
end
|
|
|
|
--- Look up the package clause at the cursor using `M.query_package`.
-- The lookup is skipped entirely when the cursor is below line 10
-- (presumably because the package clause is expected near the top of the
-- file; the threshold is hard-coded).
-- A miss is silent: the previous second `ns == nil` check with its warning
-- was unreachable behind the early return and has been removed.
-- @param bufnr buffer handle; the current buffer is used when nil
-- @return the last entry of the list returned by `nodes.nodes_at_cursor`;
--         nil when the lookup returned nil; nothing when the cursor is past
--         line 10
M.get_package_node_at_pos = function(bufnr)
  -- Only the (1-based) cursor row matters here; the column was never used.
  local row = vim.api.nvim_win_get_cursor(0)[1]
  if row > 10 then
    return
  end

  local bufn = bufnr or vim.api.nvim_get_current_buf()
  local ns = nodes.nodes_at_cursor(M.query_package, get_name_defaults(), bufn)
  if ns == nil then
    return nil
  end
  return ns[#ns]
end
|
|
|
|
--- Report whether the cursor is inside a function or method declaration.
-- Walks from the node under the cursor up through its ancestors looking for
-- a `function_declaration` or `method_declaration` node.
-- @return true when such an ancestor (or the node itself) is found; false
--         otherwise, including when nvim-treesitter's ts_utils cannot be
--         loaded or there is no node at the cursor
function M.in_func()
  local loaded, ts_utils = pcall(require, 'nvim-treesitter.ts_utils')
  if not loaded then
    return false
  end

  local node = ts_utils.get_node_at_cursor()
  while node do
    local kind = node:type()
    if kind == 'function_declaration' or kind == 'method_declaration' then
      return true
    end
    node = node:parent()
  end
  return false
end
|
|
|
|
return M
|