You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
193 lines
5.0 KiB
Lua
193 lines
5.0 KiB
Lua
-- Dependencies and module-level state for the playground printer.
local parsers = require("nvim-treesitter.parsers")
local api = vim.api

local M = {}

-- Id of the namespace registered under "treesitter/highlighter". This lookup
-- runs once, when this module is loaded; indexing the returned map yields nil
-- if no namespace with that name exists at that moment.
local treesitter_namespace = api.nvim_get_namespaces()["treesitter/highlighter"]

-- Namespaces owned by the playground for its virtual-text annotations.
local virt_text_id = api.nvim_create_namespace("TSPlaygroundHlGroups")
local lang_virt_text_id = api.nvim_create_namespace("TSPlaygroundLangGroups")
|
|
|
|
--- Return the extmarks of the treesitter highlighter namespace that lie in a
-- region of a buffer, with their details.
-- The namespace id captured at module load can be nil when the highlighter
-- had not registered its namespace yet, so it is (re)resolved here on demand
-- and cached once found.
-- @param bufnr number: buffer handle
-- @param start table: { row, col } start of the region (0-indexed)
-- @param end_ table: { row, col } end of the region (0-indexed; the API
--   treats this position as inclusive)
-- @return table: list of { id, row, col, details } items, or an empty list
--   when the highlighter namespace does not exist
local function get_extmarks(bufnr, start, end_)
  if not treesitter_namespace then
    treesitter_namespace = api.nvim_get_namespaces()["treesitter/highlighter"]
  end

  -- Without the namespace there is nothing to query; passing nil to the API
  -- would raise instead.
  if not treesitter_namespace then
    return {}
  end

  return api.nvim_buf_get_extmarks(bufnr, treesitter_namespace, start, end_, { details = true })
end
|
|
|
|
--- Collect the highlight groups of treesitter-highlighter extmarks that start
-- inside a node's range.
-- nvim_buf_get_extmarks treats its end position as inclusive, while
-- TSNode:range() returns an exclusive end. Extmarks starting exactly at the
-- node's end therefore belong to the following text and are skipped (except
-- for zero-width nodes, where start and end coincide).
-- @param bufnr number: buffer handle
-- @param node userdata: treesitter node
-- @return table: list of highlight group names (possibly empty)
local function get_hl_group_for_node(bufnr, node)
  local start_row, start_col, end_row, end_col = node:range()
  local extmarks = get_extmarks(bufnr, { start_row, start_col }, { end_row, end_col })
  local is_zero_width = start_row == end_row and start_col == end_col
  local groups = {}

  for _, ext in ipairs(extmarks) do
    -- Each item is { extmark_id, row, col, details }.
    local row, col, details = ext[2], ext[3], ext[4]
    local starts_inside = is_zero_width
      or row < end_row
      or (row == end_row and col < end_col)

    if starts_inside and details and details.hl_group then
      groups[#groups + 1] = details.hl_group
    end
  end

  return groups
end
|
|
|
|
--- Walk a node's children depth-first and append one entry per included
-- child. A child is included when it is named, or when
-- options.include_anonymous_nodes is set; the children of a child that is not
-- included are not visited. `root` itself is never added.
-- @param root userdata: node whose descendants are flattened
-- @param results table|nil: list to append to (a new list when nil)
-- @param level number|nil: depth recorded on root's direct children (default 0)
-- @param language_tree table: stored as-is on every entry
-- @param options table: reads include_anonymous_nodes, include_hl_groups, bufnr
-- @return table: the list the entries were appended to
local function flatten_node(root, results, level, language_tree, options)
  local depth = level or 0
  local out = results or {}

  for child, field_name in root:iter_children() do
    local included = child:named() or options.include_anonymous_nodes

    if included then
      -- Highlight groups are only looked up when requested and a buffer is known.
      local hl_groups
      if options.include_hl_groups and options.bufnr then
        hl_groups = get_hl_group_for_node(options.bufnr, child) or {}
      else
        hl_groups = {}
      end

      out[#out + 1] = {
        level = depth,
        node = child,
        field = field_name,
        language_tree = language_tree,
        hl_groups = hl_groups,
      }

      flatten_node(child, out, depth + 1, language_tree, options)
    end
  end

  return out
end
|
|
|
|
--- Report whether a node's range encloses another range. Both boundaries are
-- inclusive, so a node whose range equals `range` contains it.
-- @param node userdata: object whose :range() returns start_row, start_col,
--   end_row, end_col
-- @param range table: { start_row, start_col, end_row, end_col }
-- @return boolean
local function node_contains(node, range)
  local node_srow, node_scol, node_erow, node_ecol = node:range()
  local srow, scol, erow, ecol = range[1], range[2], range[3], range[4]

  -- The node must begin on an earlier row, or on the same row at a column
  -- that is not after the range's start column.
  local starts_ok
  if node_srow == srow then
    starts_ok = node_scol <= scol
  else
    starts_ok = node_srow < srow
  end

  -- The node must end on a later row, or on the same row at a column that is
  -- not before the range's end column.
  local ends_ok
  if node_erow == erow then
    ends_ok = node_ecol >= ecol
  else
    ends_ok = node_erow > erow
  end

  return starts_ok and ends_ok
end
|
|
|
|
--- Flatten every tree of a LanguageTree into a list of node entries and,
-- unless options.suppress_injected_languages is set, recurse into its child
-- language trees (visited in `pairs` order) using the same list.
--
-- Each tree's entries are spliced in directly after the deepest entry already
-- in the list whose node encloses the tree's root range, one level deeper
-- than that entry. When no entry encloses it, the entries are appended at the
-- end and flatten_node receives a nil level.
-- @param lang_tree table: LanguageTree to flatten
-- @param results table|nil: existing entries to splice into (new list if nil)
-- @param options table: forwarded to flatten_node; suppress_injected_languages
--   is read here
-- @return table: the list of entries
local function flatten_lang_tree(lang_tree, results, options)
  local entries = results or {}

  for _, tree in ipairs(lang_tree:trees()) do
    local root = tree:root()

    -- Pick the enclosing entry with the greatest level. An enclosing entry
    -- that appears later at a shallower level is less specific, so the
    -- search stops there.
    local parent, parent_index = nil, nil
    for idx, candidate in ipairs(entries) do
      if node_contains(candidate.node, { root:range() }) then
        if parent and candidate.level < parent.level then
          break
        end
        parent, parent_index = candidate, idx
      end
    end

    local child_level = nil
    if parent then
      child_level = parent.level + 1
    end
    local splice_at = (parent_index or #entries) + 1

    local flattened = flatten_node(root, nil, child_level, lang_tree, options)
    for offset, entry in ipairs(flattened) do
      table.insert(entries, splice_at + offset - 1, entry)
    end
  end

  if not options.suppress_injected_languages then
    for _, child_tree in pairs(lang_tree:children()) do
      flatten_lang_tree(child_tree, entries, options)
    end
  end

  return entries
end
|
|
|
|
--- Flatten the syntax tree(s) of a buffer into a list of node entries.
-- @param bufnr number|nil: buffer handle (defaults to the current buffer)
-- @param lang_tree table|nil: LanguageTree to flatten (defaults to
--   parsers.get_parser(bufnr))
-- @param options table|nil: include_anonymous_nodes, include_hl_groups,
--   suppress_injected_languages and bufnr; the caller's table is copied and
--   never modified
-- @return table: list of node entries; empty when no language tree is found
function M.process(bufnr, lang_tree, options)
  bufnr = bufnr or api.nvim_get_current_buf()
  lang_tree = lang_tree or parsers.get_parser(bufnr)

  if not lang_tree then return {} end

  -- Work on a shallow copy: writing bufnr into the caller's table would make
  -- `bufnr or ...` below keep the first buffer forever if that table is reused.
  local opts = {}
  for key, value in pairs(options or {}) do
    opts[key] = value
  end
  opts.bufnr = opts.bufnr or bufnr

  return flatten_lang_tree(lang_tree, nil, opts)
end
|
|
|
|
--- Render one node entry as a line of the tree view:
-- `<indent>[field: ]<type> [start_row, start_col] - [end_row, end_col]`,
-- indented by one space per level. Anonymous nodes are shown in double
-- quotes with newline characters written as `\n`.
-- @param node_entry table: entry with node, level and optional field
-- @return string
function M.print_entry(node_entry)
  local node = node_entry.node
  local pad = string.rep(" ", node_entry.level)
  local label = node:type()

  if not node:named() then
    -- Anonymous nodes are literal tokens; quote them and keep them on one line.
    label = string.format([["%s"]], label)
    label = (label:gsub("\n", "\\n"))
  end

  local field = node_entry.field
  if not field then
    return string.format("%s%s [%d, %d] - [%d, %d]", pad, label, node:range())
  end

  return string.format("%s%s: %s [%d, %d] - [%d, %d]", pad, field, label, node:range())
end
|
|
|
|
--- Render each entry with M.print_entry.
-- @param node_entries table: list of node entries
-- @return table: list of rendered lines, in entry order
function M.print_entries(node_entries)
  local lines = {}

  for idx, entry in ipairs(node_entries) do
    lines[idx] = M.print_entry(entry)
  end

  return lines
end
|
|
|
|
--- Show each entry's highlight groups as virtual text on that entry's line.
-- Group names are joined with " / " (no separator after the last one), and
-- each chunk is highlighted with the group it names.
-- @param bufnr number: buffer showing the entries, assumed to hold one line
--   per entry starting at line 0 (as M.print_entries produces) — confirm
--   against caller
-- @param node_entries table: list of node entries (reads hl_groups)
function M.print_hl_groups(bufnr, node_entries)
  for i, node_entry in ipairs(node_entries) do
    local hl_groups = node_entry.hl_groups or {}
    local chunks = {}

    for j, hl_group in ipairs(hl_groups) do
      -- Compare against the number of groups (not the group name's length)
      -- so only the last chunk omits the separator.
      local text = hl_group
      if j < #hl_groups then
        text = hl_group .. " / "
      end
      chunks[#chunks + 1] = { text, hl_group }
    end

    -- Entries are 1-indexed while buffer lines are 0-indexed.
    api.nvim_buf_set_virtual_text(bufnr, virt_text_id, i - 1, chunks, {})
  end
end
|
|
|
|
--- Show the language of each entry's LanguageTree as virtual text; the i-th
-- entry is annotated on buffer line i - 1.
-- @param bufnr number: buffer showing the entries
-- @param node_entries table: list of node entries (reads language_tree)
function M.print_language(bufnr, node_entries)
  for index, entry in ipairs(node_entries) do
    local line = index - 1
    local chunks = { { entry.language_tree:lang() } }
    api.nvim_buf_set_virtual_text(bufnr, lang_virt_text_id, line, chunks, {})
  end
end
|
|
|
|
--- Clear, over the whole buffer, the namespace used by M.print_hl_groups.
-- @param bufnr number: buffer handle
function M.remove_hl_groups(bufnr)
  local first_line, last_line = 0, -1 -- -1: through the end of the buffer
  api.nvim_buf_clear_namespace(bufnr, virt_text_id, first_line, last_line)
end
|
|
|
|
--- Clear, over the whole buffer, the namespace used by M.print_language.
-- @param bufnr number: buffer handle
function M.remove_language(bufnr)
  local first_line, last_line = 0, -1 -- -1: through the end of the buffer
  api.nvim_buf_clear_namespace(bufnr, lang_virt_text_id, first_line, last_line)
end
|
|
|
|
-- Export the module table (process, print_* and remove_* functions).
return M
|