Skip to content

Commit

Permalink
feat: write to multiple files (yetone#720)
Browse files Browse the repository at this point in the history
  • Loading branch information
yetone authored Oct 14, 2024
1 parent 347d9be commit b19573c
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 58 deletions.
141 changes: 83 additions & 58 deletions lua/avante/sidebar.lua
Original file line number Diff line number Diff line change
Expand Up @@ -298,17 +298,20 @@ end
---@field explanation string
---@field start_line_in_response_buf integer
---@field end_line_in_response_buf integer
---@field filepath string

---@param response_content string
---@return AvanteCodeSnippet[]
local function extract_code_snippets(response_content)
---@return table<string, AvanteCodeSnippet[]>
local function extract_code_snippets_map(response_content)
local snippets = {}
local current_snippet = {}
local in_code_block = false
local lang, start_line, end_line, start_line_in_response_buf
local explanation = ""

for idx, line in ipairs(vim.split(response_content, "\n")) do
local lines = vim.split(response_content, "\n")

for idx, line in ipairs(lines) do
local _, start_line_str, end_line_str =
line:match("^%s*(%d*)[%.%)%s]*[Aa]?n?d?%s*[Rr]eplace%s+[Ll]ines:?%s*(%d+)%-(%d+)")
if start_line_str ~= nil and end_line_str ~= nil then
Expand Down Expand Up @@ -337,6 +340,7 @@ local function extract_code_snippets(response_content)
explanation = explanation,
start_line_in_response_buf = start_line_in_response_buf,
end_line_in_response_buf = idx,
filepath = lines[start_line_in_response_buf - 2],
}
table.insert(snippets, snippet)
end
Expand All @@ -357,48 +361,62 @@ local function extract_code_snippets(response_content)
end
end

return snippets
local snippets_map = {}
for _, snippet in ipairs(snippets) do
snippets_map[snippet.filepath] = snippets_map[snippet.filepath] or {}
table.insert(snippets_map[snippet.filepath], snippet)
end

return snippets_map
end

---@param snippets AvanteCodeSnippet[]
---@return AvanteCodeSnippet[]
local function ensure_snippets_no_overlap(original_content, snippets)
table.sort(snippets, function(a, b) return a.range[1] < b.range[1] end)
---@param snippets_map table<string, AvanteCodeSnippet[]>
---@return table<string, AvanteCodeSnippet[]>
local function ensure_snippets_no_overlap(snippets_map)
local new_snippets_map = {}

local original_lines = vim.split(original_content, "\n")
for filepath, snippets in pairs(snippets_map) do
table.sort(snippets, function(a, b) return a.range[1] < b.range[1] end)

local result = {}
local last_end_line = 0
for _, snippet in ipairs(snippets) do
if snippet.range[1] > last_end_line then
table.insert(result, snippet)
last_end_line = snippet.range[2]
else
local snippet_lines = vim.split(snippet.content, "\n")
-- Trim the overlapping part
local new_start_line = nil
for i = snippet.range[1], math.min(snippet.range[2], last_end_line) do
if
Utils.remove_indentation(original_lines[i])
== Utils.remove_indentation(snippet_lines[i - snippet.range[1] + 1])
then
new_start_line = i + 1
else
break
end
end
if new_start_line ~= nil then
snippet.content = table.concat(vim.list_slice(snippet_lines, new_start_line - snippet.range[1] + 1), "\n")
snippet.range[1] = new_start_line
table.insert(result, snippet)
local original_content = ""
if Utils.file.exists(filepath) then original_content = Utils.file.read_content(filepath) or "" end

local original_lines = vim.split(original_content, "\n")

local new_snippets = {}
local last_end_line = 0
for _, snippet in ipairs(snippets) do
if snippet.range[1] > last_end_line then
table.insert(new_snippets, snippet)
last_end_line = snippet.range[2]
else
Utils.error("Failed to ensure snippets no overlap", { once = true, title = "Avante" })
local snippet_lines = vim.split(snippet.content, "\n")
-- Trim the overlapping part
local new_start_line = nil
for i = snippet.range[1], math.min(snippet.range[2], last_end_line) do
if
Utils.remove_indentation(original_lines[i])
== Utils.remove_indentation(snippet_lines[i - snippet.range[1] + 1])
then
new_start_line = i + 1
else
break
end
end
if new_start_line ~= nil then
snippet.content = table.concat(vim.list_slice(snippet_lines, new_start_line - snippet.range[1] + 1), "\n")
snippet.range[1] = new_start_line
table.insert(new_snippets, snippet)
last_end_line = snippet.range[2]
else
Utils.error("Failed to ensure snippets no overlap", { once = true, title = "Avante" })
end
end
end
new_snippets_map[filepath] = new_snippets
end

return result
return new_snippets_map
end

local function insert_conflict_contents(bufnr, snippets)
Expand Down Expand Up @@ -494,40 +512,47 @@ end

---@param current_cursor boolean
function Sidebar:apply(current_cursor)
local content = table.concat(Utils.get_buf_lines(0, -1, self.code.bufnr), "\n")
local response, response_start_line = self:get_content_between_separators()
local all_snippets = extract_code_snippets(response)
all_snippets = ensure_snippets_no_overlap(content, all_snippets)
local selected_snippets = {}
local all_snippets_map = extract_code_snippets_map(response)
all_snippets_map = ensure_snippets_no_overlap(all_snippets_map)
local selected_snippets_map = {}
if current_cursor then
if self.result and self.result.winid then
local cursor_line = Utils.get_cursor_pos(self.result.winid)
for _, snippet in ipairs(all_snippets) do
if
cursor_line >= snippet.start_line_in_response_buf + response_start_line - 1
and cursor_line <= snippet.end_line_in_response_buf + response_start_line - 1
then
selected_snippets = { snippet }
break
for filepath, snippets in pairs(all_snippets_map) do
for _, snippet in ipairs(snippets) do
if
cursor_line >= snippet.start_line_in_response_buf + response_start_line - 1
and cursor_line <= snippet.end_line_in_response_buf + response_start_line - 1
then
selected_snippets_map[filepath] = { snippet }
break
end
end
end
end
else
selected_snippets = all_snippets
selected_snippets_map = all_snippets_map
end

vim.defer_fn(function()
insert_conflict_contents(self.code.bufnr, selected_snippets)

api.nvim_set_current_win(self.code.winid)
api.nvim_feedkeys(api.nvim_replace_termcodes("<Esc>", true, false, true), "n", true)
Diff.add_visited_buffer(self.code.bufnr)
Diff.process(self.code.bufnr)
api.nvim_win_set_cursor(self.code.winid, { 1, 0 })
vim.defer_fn(function()
Diff.find_next("ours")
vim.cmd("normal! zz")
end, 100)
for filepath, snippets in pairs(selected_snippets_map) do
local bufnr = Utils.get_opened_buffer(filepath)
if not bufnr then bufnr = Utils.create_new_buffer_with_file(filepath) end
insert_conflict_contents(bufnr, snippets)
local winid = Utils.get_winid(bufnr)
if not winid then goto continue end
api.nvim_set_current_win(winid)
api.nvim_feedkeys(api.nvim_replace_termcodes("<Esc>", true, false, true), "n", true)
Diff.add_visited_buffer(bufnr)
Diff.process(bufnr)
api.nvim_win_set_cursor(winid, { 1, 0 })
vim.defer_fn(function()
Diff.find_next("ours")
vim.cmd("normal! zz")
end, 100)
::continue::
end
end, 10)
end

Expand Down
5 changes: 5 additions & 0 deletions lua/avante/utils/file.lua
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,9 @@ function M.read_content(filepath)
return nil
end

function M.exists(filepath)
local stat = vim.loop.fs_stat(filepath)
return stat ~= nil
end

return M
27 changes: 27 additions & 0 deletions lua/avante/utils/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,12 @@ function M.get_win_options(winid, opt_name, key)
end
end

function M.get_winid(bufnr)
for _, winid in ipairs(api.nvim_list_wins()) do
if api.nvim_win_get_buf(winid) == bufnr then return winid end
end
end

function M.unlock_buf(bufnr)
vim.bo[bufnr].modified = false
vim.bo[bufnr].modifiable = true
Expand Down Expand Up @@ -655,4 +661,25 @@ function M.get_mentions()
}
end

function M.get_opened_buffer(filepath)
for _, buf in ipairs(api.nvim_list_bufs()) do
if fn.buflisted(buf) == 1 and fn.bufname(buf) == filepath then return buf end
end
return nil
end

function M.create_new_buffer_with_file(filepath)
local buf = api.nvim_create_buf(false, true)

api.nvim_buf_set_name(buf, filepath)

api.nvim_set_option_value("buftype", "", { buf = buf })

api.nvim_set_current_buf(buf)

vim.cmd("edit " .. filepath)

return buf
end

return M

0 comments on commit b19573c

Please sign in to comment.