Skip to content

Commit

Permalink
fix: #542 use vim.treesitter instead of nvim-treesitter. (#543)
Browse files Browse the repository at this point in the history
  • Loading branch information
dsully authored Dec 16, 2024
1 parent d935c84 commit b741490
Showing 1 changed file with 71 additions and 5 deletions.
76 changes: 71 additions & 5 deletions lua/codecompanion/utils/treesitter.lua
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
local ts_parsers = require("nvim-treesitter.parsers")
local ts_utils = require("nvim-treesitter.ts_utils")
local api = vim.api

local M = {}
Expand All @@ -12,7 +10,13 @@ function M.goto_heading(direction, count)
local cursor = api.nvim_win_get_cursor(0)
local current_row = cursor[1] - 1

local parser = ts_parsers.get_parser(bufnr, "markdown")
local parser = vim.treesitter.get_parser(bufnr, "markdown")

if parser == nil then
vim.notify("Couldn't find the 'markdown' treesitter parser!")
return
end

local root_tree = parser:parse()[1]:root()

local query = vim.treesitter.query.parse("markdown", [[(atx_heading) @heading]])
Expand All @@ -28,7 +32,7 @@ function M.goto_heading(direction, count)
if node_start >= from_row then
table.insert(found_headings, node) -- Collect valid headings in a table
if #found_headings == count then -- Check if we have reached the desired count
ts_utils.goto_node(found_headings[count], false, true)
M.goto_node(found_headings[count], false, true)
return
end
end
Expand All @@ -47,12 +51,74 @@ function M.goto_heading(direction, count)
end
end
if #found_headings >= count then
ts_utils.goto_node(found_headings[#found_headings - count + 1], false, true)
M.goto_node(found_headings[#found_headings - count + 1], false, true)
end
end

-- If we haven't found the desired heading, we can stay at current position,
-- or implement some behavior like wrapping around or signaling an error.
end

-- From nvim-treesitter @ 'master', not the 'main' rewrite.
function M.goto_node(node, goto_end, avoid_set_jump)
if not node then
return
end

if not avoid_set_jump then
vim.cmd("normal! m'")
end

local range = { M.get_vim_range({ node:range() }) }

---@type table<number>
local position
if not goto_end then
position = { range[1], range[2] }
else
position = { range[3], range[4] }
end

-- Enter visual mode if we are in operator pending mode
-- If we don't do this, it will miss the last character.
local mode = vim.api.nvim_get_mode()

if mode.mode == "no" then
vim.cmd("normal! v")
end

-- Position is 1, 0 indexed.
api.nvim_win_set_cursor(0, { position[1], position[2] - 1 })
end

-- Get a compatible vim range (1 index based) from a TS node range.
--
-- TS nodes start with 0 and the end col is ending exclusive.
-- They also treat a EOF/EOL char as a char ending in the first
-- col of the next row.
---comment
---@param range integer[]
---@param buf integer|nil
---@return integer, integer, integer, integer
function M.get_vim_range(range, buf)
---@type integer, integer, integer, integer
local srow, scol, erow, ecol = unpack(range)
srow = srow + 1
scol = scol + 1
erow = erow + 1

if ecol == 0 then
-- Use the value of the last col of the previous row instead.
erow = erow - 1
if not buf or buf == 0 then
ecol = vim.fn.col({ erow, "$" }) - 1
else
ecol = #api.nvim_buf_get_lines(buf, erow - 1, erow, false)[1]
end
ecol = math.max(ecol, 1)
end

return srow, scol, erow, ecol
end

return M

0 comments on commit b741490

Please sign in to comment.