915d5a02b9
vim.treesitter.set_query is deprecated. Using vim.treesitter.query.set instead.
269 lines
6.6 KiB
Lua
269 lines
6.6 KiB
Lua
-- first version: from https://github.com/arsham/shark
|
|
local ls = require('luasnip')
|
|
local fmt = require('luasnip.extras.fmt').fmt
|
|
local ok, ts_utils = pcall(require, 'nvim-treesitter.ts_utils')
|
|
if not ok then
|
|
local ok, packer = pcall(require, 'packer')
|
|
if ok then
|
|
require('packer').loader('nvim-treesitter')
|
|
else
|
|
local ok, lazy = pcall(require, 'lazy')
|
|
if lazy then
|
|
lazy.load({ plugins = { 'nvim-treesitter' } })
|
|
else
|
|
error('Please install nvim-treesitter')
|
|
end
|
|
end
|
|
ts_utils = require('nvim-treesitter.ts_utils')
|
|
end
|
|
local ts_locals = require('nvim-treesitter.locals')
|
|
local rep = require('luasnip.extras').rep
|
|
local ai = require('luasnip.nodes.absolute_indexer')
|
|
|
|
local M = {}
|
|
|
|
M.go_err_snippet = function(args, _, _, spec)
|
|
local err_name = args[1][1]
|
|
local index = spec and spec.index or nil
|
|
local msg = spec and spec[1] or ''
|
|
if spec and spec[2] then
|
|
err_name = err_name .. spec[2]
|
|
end
|
|
if err_name == 'nil' then
|
|
return ls.sn(index, ls.sn(nil, ls.i(1, 'nil')))
|
|
end
|
|
return ls.sn(index, {
|
|
ls.c(1, {
|
|
ls.sn(nil, fmt('errors.Wrap({}, "{}")', { ls.t(err_name), ls.i(1, msg) })),
|
|
ls.sn(nil, fmt('errors.Wrapf({}, "{}", {})', { ls.t(err_name), ls.i(1, msg), ls.i(2) })),
|
|
ls.sn(
|
|
nil,
|
|
fmt('internal.GrpcError({},\n\t\tcodes.{}, "{}", "{}", {})', {
|
|
ls.t(err_name),
|
|
ls.i(1, 'Internal'),
|
|
ls.i(2, 'Description'),
|
|
ls.i(3, 'Field'),
|
|
ls.i(4, 'fields'),
|
|
})
|
|
),
|
|
ls.t(err_name),
|
|
}),
|
|
})
|
|
end
|
|
|
|
---Transform makes a node from the given text.
|
|
local function transform(text, info)
|
|
local string_sn = function(template, default)
|
|
info.index = info.index + 1
|
|
return ls.sn(info.index, fmt(template, ls.i(1, default)))
|
|
end
|
|
local new_sn = function(default)
|
|
return string_sn('{}', default)
|
|
end
|
|
|
|
-- cutting the name if exists.
|
|
if text:find([[^[^\[]*string$]]) then
|
|
text = 'string'
|
|
elseif text:find('^[^%[]*map%[[^%]]+') then
|
|
text = 'map'
|
|
elseif text:find('%[%]') then
|
|
text = 'slice'
|
|
elseif text:find([[ ?chan +[%a%d]+]]) then
|
|
return ls.t('nil')
|
|
end
|
|
|
|
-- separating the type from the name if exists.
|
|
local type = text:match([[^[%a%d]+ ([%a%d]+)$]])
|
|
if type then
|
|
text = type
|
|
end
|
|
|
|
if text == 'int' or text == 'int64' or text == 'int32' then
|
|
return new_sn('0')
|
|
elseif text == 'float32' or text == 'float64' then
|
|
return new_sn('0')
|
|
elseif text == 'error' then
|
|
if not info then
|
|
return ls.t('err')
|
|
end
|
|
|
|
info.index = info.index + 1
|
|
return M.go_err_snippet({ { info.err_name } }, nil, nil, { index = info.index })
|
|
elseif text == 'bool' then
|
|
info.index = info.index + 1
|
|
return ls.c(info.index, { ls.i(1, 'false'), ls.i(2, 'true') })
|
|
elseif text == 'string' then
|
|
return string_sn('"{}"', '')
|
|
elseif text == 'map' or text == 'slice' then
|
|
return ls.t('nil')
|
|
elseif string.find(text, '*', 1, true) then
|
|
return new_sn('nil')
|
|
end
|
|
|
|
text = text:match('[^ ]+$')
|
|
if text == 'context.Context' then
|
|
text = 'context.Background()'
|
|
else
|
|
-- when the type is concrete
|
|
text = text .. '{}'
|
|
end
|
|
|
|
return ls.t(text)
|
|
end
|
|
|
|
local get_node_text = require('go.utils').get_node_text
|
|
|
|
local handlers = {
|
|
parameter_list = function(node, info)
|
|
local result = {}
|
|
|
|
local count = node:named_child_count()
|
|
for idx = 0, count - 1 do
|
|
table.insert(result, transform(get_node_text(node:named_child(idx), 0), info))
|
|
if idx ~= count - 1 then
|
|
table.insert(result, ls.t({ ', ' }))
|
|
end
|
|
end
|
|
|
|
return result
|
|
end,
|
|
|
|
type_identifier = function(node, info)
|
|
local text = get_node_text(node, 0)
|
|
return { transform(text, info) }
|
|
end,
|
|
}
|
|
|
|
local query_is_set = false
|
|
|
|
local function set_query()
|
|
if query_is_set then
|
|
return
|
|
end
|
|
query_is_set = true
|
|
vim.treesitter.query.set(
|
|
'go',
|
|
'LuaSnip_Result',
|
|
[[
|
|
[
|
|
(method_declaration result: (_) @id)
|
|
(function_declaration result: (_) @id)
|
|
(func_literal result: (_) @id)
|
|
]
|
|
]]
|
|
)
|
|
end
|
|
|
|
local function return_value_nodes(info)
|
|
set_query()
|
|
local cursor_node = ts_utils.get_node_at_cursor()
|
|
local scope_tree = ts_locals.get_scope_tree(cursor_node, 0)
|
|
|
|
local function_node
|
|
for _, scope in ipairs(scope_tree) do
|
|
if
|
|
scope:type() == 'function_declaration'
|
|
or scope:type() == 'method_declaration'
|
|
or scope:type() == 'func_literal'
|
|
then
|
|
function_node = scope
|
|
break
|
|
end
|
|
end
|
|
|
|
if not function_node then
|
|
return
|
|
end
|
|
|
|
local query = (vim.fn.has('nvim-0.9') == 1) and vim.treesitter.query.get('go', 'LuaSnip_Result')
|
|
for _, node in query:iter_captures(function_node, 0) do
|
|
if handlers[node:type()] then
|
|
return handlers[node:type()](node, info)
|
|
end
|
|
end
|
|
return ls.t({ '' })
|
|
end
|
|
|
|
local is_in_function = require('go.ts.go').in_func()
|
|
|
|
---Transforms the given arguments into nodes wrapped in a snippet node.
|
|
M.make_return_nodes = function(args)
|
|
local info = { index = 0, err_name = args[1][1] }
|
|
|
|
return ls.sn(nil, return_value_nodes(info))
|
|
end
|
|
M.make_default_return_nodes = function()
|
|
local info = { index = 0, err_name = 'nil' }
|
|
|
|
return ls.sn(nil, return_value_nodes(info))
|
|
end
|
|
|
|
M.fill_return = function()
|
|
local info = { index = 0, err_name = 'nil' }
|
|
return ls.sn(nil, return_value_nodes(info))
|
|
end
|
|
|
|
---Runs the command in shell.
|
|
-- @param command string
|
|
-- @return table
|
|
M.shell = require('go.utils').run_command
|
|
|
|
M.last_lua_module_section = function(args)
|
|
local text = args[1][1] or ''
|
|
local split = vim.split(text, '.', { plain = true })
|
|
|
|
local options = {}
|
|
for len = 0, #split - 1 do
|
|
local node = ls.t(table.concat(vim.list_slice(split, #split - len, #split), '_'))
|
|
table.insert(options, node)
|
|
end
|
|
|
|
return ls.sn(nil, {
|
|
ls.c(1, options),
|
|
})
|
|
end
|
|
|
|
function M.is_in_test_file()
|
|
local filename = vim.fn.expand('%:p')
|
|
return vim.endswith(filename, '_test.go')
|
|
end
|
|
|
|
function M.is_in_test_function()
|
|
return M.is_in_test_file() and is_in_function()
|
|
end
|
|
|
|
M.create_t_run = function(args)
|
|
return ls.sn(1, {
|
|
ls.c(1, {
|
|
ls.t({ '' }),
|
|
ls.sn(
|
|
nil,
|
|
fmt('\tt.Run("{}", {}{})\n{}', {
|
|
ls.i(1, 'Case'),
|
|
ls.t(args[1]),
|
|
rep(1),
|
|
ls.d(2, M.create_t_run, ai[1]),
|
|
})
|
|
),
|
|
}),
|
|
})
|
|
end
|
|
|
|
M.mirror_t_run_funcs = function(args)
|
|
local strs = {}
|
|
for _, v in ipairs(args[1]) do
|
|
local name = v:match('^%s*t%.Run%s*%(%s*".*", (.*)%)')
|
|
if name then
|
|
local node = string.format('func %s(t *testing.T) {{\n\tt.Parallel()\n}}\n\n', name)
|
|
table.insert(strs, node)
|
|
end
|
|
end
|
|
local str = table.concat(strs, '')
|
|
if #str == 0 then
|
|
return ls.sn(1, ls.t(''))
|
|
end
|
|
return ls.sn(1, fmt(str, {}))
|
|
end
|
|
|
|
return M
|