Skip to content

Commit

Permalink
feat: repo map (yetone#496)
Browse files Browse the repository at this point in the history
* feat: repo map

* chore: remove breakline

* chore: remove spaces

* fix: golang public method

* feat: mentions for editing input
  • Loading branch information
yetone authored Sep 23, 2024
1 parent 8dbfe85 commit 8e1018f
Show file tree
Hide file tree
Showing 15 changed files with 1,191 additions and 64 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ For building binary if you wish to build from source, then `cargo` is required.
build = "make",
-- build = "powershell -ExecutionPolicy Bypass -File Build.ps1 -BuildFromSource false" -- for windows
dependencies = {
"nvim-treesitter/nvim-treesitter",
"stevearc/dressing.nvim",
"nvim-lua/plenary.nvim",
"MunifTanjim/nui.nvim",
Expand Down Expand Up @@ -427,7 +428,7 @@ If you have the following structure:
- [x] Slash commands
- [x] Edit the selected block
- [x] Smart Tab (Cursor Flow)
- [ ] Chat with project
- [x] Chat with project (You can use `@codebase` to chat with the whole project)
- [ ] Chat with selected files
## Roadmap
Expand Down
2 changes: 2 additions & 0 deletions crates/avante-templates/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ struct TemplateContext {
ask: bool,
question: String,
code_lang: String,
filepath: String,
file_content: String,
selected_code: Option<String>,
project_context: Option<String>,
Expand All @@ -45,6 +46,7 @@ fn render(state: &State, template: &str, context: TemplateContext) -> LuaResult<
ask => context.ask,
question => context.question,
code_lang => context.code_lang,
filepath => context.filepath,
file_content => context.file_content,
selected_code => context.selected_code,
project_context => context.project_context,
Expand Down
3 changes: 3 additions & 0 deletions lua/avante/llm.lua
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,14 @@ M.stream = function(opts)

Path.prompts.initialize(Path.prompts.get(opts.bufnr))

local filepath = Utils.relative_path(api.nvim_buf_get_name(opts.bufnr))

local template_opts = {
use_xml_format = Provider.use_xml_format,
ask = opts.ask, -- TODO: add mode without ask instruction
question = original_instructions,
code_lang = opts.code_lang,
filepath = filepath,
file_content = opts.file_content,
selected_code = opts.selected_code,
project_context = opts.project_context,
Expand Down
81 changes: 58 additions & 23 deletions lua/avante/path.lua
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,17 @@ end
H.get_mode_file = function(mode) return string.format("custom.%s.avanterules", mode) end

-- History path
local M = {}
local History = {}

-- Returns the Path to the chat history file for the given buffer.
---@param bufnr integer
---@return Path
M.get = function(bufnr) return Path:new(Config.history.storage_path):joinpath(H.filename(bufnr)) end
History.get = function(bufnr) return Path:new(Config.history.storage_path):joinpath(H.filename(bufnr)) end

-- Loads the chat history for the given buffer.
---@param bufnr integer
M.load = function(bufnr)
local history_file = M.get(bufnr)
History.load = function(bufnr)
local history_file = History.get(bufnr)
if history_file:exists() then
local content = history_file:read()
return content ~= nil and vim.json.decode(content) or {}
Expand All @@ -49,29 +49,29 @@ end
-- Saves the chat history for the given buffer.
---@param bufnr integer
---@param history table
M.save = function(bufnr, history)
local history_file = M.get(bufnr)
History.save = function(bufnr, history)
local history_file = History.get(bufnr)
history_file:write(vim.json.encode(history), "w")
end

P.history = M
P.history = History

-- Prompt path
local N = {}
local Prompt = {}

---@class AvanteTemplates
---@field initialize fun(directory: string): nil
---@field render fun(template: string, context: TemplateOptions): string
local templates = nil

N.templates = { planning = nil, editing = nil, suggesting = nil }
Prompt.templates = { planning = nil, editing = nil, suggesting = nil }

-- Creates a directory in the cache path for the given buffer and copies the custom prompts to it.
-- We need to do this beacuse the prompt template engine requires a given directory to load all required files.
-- PERF: Hmm instead of copy to cache, we can also load in globals context, but it requires some work on bindings. (eh maybe?)
---@param bufnr number
---@return string the resulted cache_directory to be loaded with avante_templates
N.get = function(bufnr)
Prompt.get = function(bufnr)
if not P.available() then error("Make sure to build avante (missing avante_templates)", 2) end

-- get root directory of given bufnr
Expand All @@ -85,19 +85,19 @@ N.get = function(bufnr)
local scanner = Scan.scan_dir(directory:absolute(), { depth = 1, add_dirs = true })
for _, entry in ipairs(scanner) do
local file = Path:new(entry)
if entry:find("planning") and N.templates.planning == nil then
N.templates.planning = file:read()
elseif entry:find("editing") and N.templates.editing == nil then
N.templates.editing = file:read()
elseif entry:find("suggesting") and N.templates.suggesting == nil then
N.templates.suggesting = file:read()
if entry:find("planning") and Prompt.templates.planning == nil then
Prompt.templates.planning = file:read()
elseif entry:find("editing") and Prompt.templates.editing == nil then
Prompt.templates.editing = file:read()
elseif entry:find("suggesting") and Prompt.templates.suggesting == nil then
Prompt.templates.suggesting = file:read()
end
end

Path:new(debug.getinfo(1).source:match("@?(.*/)"):gsub("/lua/avante/path.lua$", "") .. "templates")
:copy({ destination = cache_prompt_dir, recursive = true })

vim.iter(N.templates):filter(function(_, v) return v ~= nil end):each(function(k, v)
vim.iter(Prompt.templates):filter(function(_, v) return v ~= nil end):each(function(k, v)
local f = cache_prompt_dir:joinpath(H.get_mode_file(k))
f:write(v, "w")
end)
Expand All @@ -106,22 +106,53 @@ N.get = function(bufnr)
end

---@param mode LlmMode
N.get_file = function(mode)
if N.templates[mode] ~= nil then return H.get_mode_file(mode) end
Prompt.get_file = function(mode)
if Prompt.templates[mode] ~= nil then return H.get_mode_file(mode) end
return string.format("%s.avanterules", mode)
end

---@param path string
---@param opts TemplateOptions
N.render_file = function(path, opts) return templates.render(path, opts) end
Prompt.render_file = function(path, opts) return templates.render(path, opts) end

---@param mode LlmMode
---@param opts TemplateOptions
N.render_mode = function(mode, opts) return templates.render(N.get_file(mode), opts) end
Prompt.render_mode = function(mode, opts) return templates.render(Prompt.get_file(mode), opts) end

N.initialize = function(directory) templates.initialize(directory) end
Prompt.initialize = function(directory) templates.initialize(directory) end

P.prompts = N
P.prompts = Prompt

local RepoMap = {}

-- Get a chat history file name given a buffer
---@param project_root string
---@param ext string
---@return string
RepoMap.filename = function(project_root, ext)
-- Replace path separators with double underscores
local path_with_separators = fn.substitute(project_root, "/", "__", "g")
-- Replace other non-alphanumeric characters with single underscores
return fn.substitute(path_with_separators, "[^A-Za-z0-9._]", "_", "g") .. "." .. ext .. ".repo_map.json"
end

RepoMap.get = function(project_root, ext) return Path:new(P.data_path):joinpath(RepoMap.filename(project_root, ext)) end

RepoMap.save = function(project_root, ext, data)
local file = RepoMap.get(project_root, ext)
file:write(vim.json.encode(data), "w")
end

RepoMap.load = function(project_root, ext)
local file = RepoMap.get(project_root, ext)
if file:exists() then
local content = file:read()
return content ~= nil and vim.json.decode(content) or {}
end
return nil
end

P.repo_map = RepoMap

P.setup = function()
local history_path = Path:new(Config.history.storage_path)
Expand All @@ -132,6 +163,10 @@ P.setup = function()
if not cache_path:exists() then cache_path:mkdir({ parents = true }) end
P.cache_path = cache_path

local data_path = Path:new(vim.fn.stdpath("data") .. "/avante")
if not data_path:exists() then data_path:mkdir({ parents = true }) end
P.data_path = data_path

vim.defer_fn(function()
local ok, module = pcall(require, "avante_templates")
---@cast module AvanteTemplates
Expand Down
26 changes: 25 additions & 1 deletion lua/avante/selection.lua
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
local Utils = require("avante.utils")
local Config = require("avante.config")
local Llm = require("avante.llm")
local Highlights = require("avante.highlights")
local Provider = require("avante.providers")

local api = vim.api
Expand Down Expand Up @@ -391,10 +390,16 @@ function Selection:create_editing_input()
end

local filetype = api.nvim_get_option_value("filetype", { buf = code_bufnr })
local file_ext = api.nvim_buf_get_name(code_bufnr):match("^.+%.(.+)$")

local mentions = Utils.extract_mentions(input)
input = mentions.new_content
local project_context = mentions.enable_project_context and Utils.repo_map.get_repo_map(file_ext) or nil

Llm.stream({
bufnr = code_bufnr,
ask = true,
project_context = vim.json.encode(project_context),
file_content = code_content,
code_lang = filetype,
selected_code = self.selection.content,
Expand Down Expand Up @@ -453,6 +458,25 @@ function Selection:create_editing_input()
end,
})

api.nvim_create_autocmd("InsertEnter", {
group = self.augroup,
buffer = bufnr,
once = true,
desc = "Setup the completion of helpers in the input buffer",
callback = function()
local has_cmp, cmp = pcall(require, "cmp")
if has_cmp then
cmp.register_source("avante_mentions", require("cmp_avante.mentions").new(Utils.get_mentions(), bufnr))
cmp.setup.buffer({
enabled = true,
sources = {
{ name = "avante_mentions" },
},
})
end
end,
})

api.nvim_create_autocmd("User", {
pattern = "AvanteEditSubmitted",
callback = function(ev)
Expand Down
25 changes: 17 additions & 8 deletions lua/avante/sidebar.lua
Original file line number Diff line number Diff line change
Expand Up @@ -441,10 +441,8 @@ local function insert_conflict_contents(bufnr, snippets)

local result = {}
table.insert(result, "<<<<<<< HEAD")
if start_line ~= end_line then
for i = start_line, end_line do
table.insert(result, lines[i])
end
for i = start_line, end_line do
table.insert(result, lines[i])
end
table.insert(result, "=======")

Expand All @@ -460,8 +458,6 @@ local function insert_conflict_contents(bufnr, snippets)
table.insert(result, line)
end

if start_line == end_line then table.insert(result, lines[start_line]) end

table.insert(result, ">>>>>>> Snippet")

api.nvim_buf_set_lines(bufnr, offset + start_line - 1, offset + end_line, false, result)
Expand Down Expand Up @@ -1296,9 +1292,17 @@ function Sidebar:create_input(opts)
Path.history.save(self.code.bufnr, chat_history)
end

local mentions = Utils.extract_mentions(request)
request = mentions.new_content

local file_ext = api.nvim_buf_get_name(self.code.bufnr):match("^.+%.(.+)$")

local project_context = mentions.enable_project_context and Utils.repo_map.get_repo_map(file_ext) or nil

Llm.stream({
bufnr = self.code.bufnr,
ask = opts.ask,
project_context = vim.json.encode(project_context),
file_content = content_with_line_numbers,
code_lang = filetype,
selected_code = selected_code_content_with_line_numbers,
Expand Down Expand Up @@ -1358,9 +1362,9 @@ function Sidebar:create_input(opts)
local function place_sign_at_first_line(bufnr)
local group = "avante_input_prompt_group"

vim.fn.sign_unplace(group, { buffer = bufnr })
fn.sign_unplace(group, { buffer = bufnr })

vim.fn.sign_place(0, group, "AvanteInputPromptSign", bufnr, { lnum = 1 })
fn.sign_place(0, group, "AvanteInputPromptSign", bufnr, { lnum = 1 })
end

place_sign_at_first_line(self.input.bufnr)
Expand All @@ -1387,11 +1391,16 @@ function Sidebar:create_input(opts)
if not self.registered_cmp then
self.registered_cmp = true
cmp.register_source("avante_commands", require("cmp_avante.commands").new(self))
cmp.register_source(
"avante_mentions",
require("cmp_avante.mentions").new(Utils.get_mentions(), self.input.bufnr)
)
end
cmp.setup.buffer({
enabled = true,
sources = {
{ name = "avante_commands" },
{ name = "avante_mentions" },
},
})
end
Expand Down
38 changes: 20 additions & 18 deletions lua/avante/suggestion.lua
Original file line number Diff line number Diff line change
Expand Up @@ -82,24 +82,26 @@ function Suggestion:suggest()
return
end
Utils.debug("full_response: " .. vim.inspect(full_response))
local cursor_row, cursor_col = Utils.get_cursor_pos()
if cursor_row ~= doc.position.row or cursor_col ~= doc.position.col then return end
local ok, suggestions = pcall(vim.json.decode, full_response)
if not ok then
Utils.error("Error while decoding suggestions: " .. full_response, { once = true, title = "Avante" })
return
end
if not suggestions then
Utils.info("No suggestions found", { once = true, title = "Avante" })
return
end
suggestions = vim
.iter(suggestions)
:map(function(s) return { row = s.row, col = s.col, content = Utils.trim_all_line_numbers(s.content) } end)
:totable()
ctx.suggestions = suggestions
ctx.current_suggestion_idx = 1
self:show()
vim.schedule(function()
local cursor_row, cursor_col = Utils.get_cursor_pos()
if cursor_row ~= doc.position.row or cursor_col ~= doc.position.col then return end
local ok, suggestions = pcall(vim.json.decode, full_response)
if not ok then
Utils.error("Error while decoding suggestions: " .. full_response, { once = true, title = "Avante" })
return
end
if not suggestions then
Utils.info("No suggestions found", { once = true, title = "Avante" })
return
end
suggestions = vim
.iter(suggestions)
:map(function(s) return { row = s.row, col = s.col, content = Utils.trim_all_line_numbers(s.content) } end)
:totable()
ctx.suggestions = suggestions
ctx.current_suggestion_idx = 1
self:show()
end)
end,
})
end
Expand Down
4 changes: 4 additions & 0 deletions lua/avante/templates/_context.avanterules
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
{%- if use_xml_format -%}
<filepath>{{filepath}}</filepath>

{%- if selected_code -%}
<context>
```{{code_lang}}
Expand All @@ -19,6 +21,8 @@
</code>
{%- endif %}
{% else %}
FILEPATH: {{filepath}}

{%- if selected_code -%}
CONTEXT:
```{{code_lang}}
Expand Down
2 changes: 1 addition & 1 deletion lua/avante/templates/suggesting.avanterules
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ Your task is to suggest code modifications at the cursor position. Follow these
{% endraw %}

3. When suggesting suggested code:
- DO NOT include three backticks: {%raw%}```{%endraw%} in your suggestion. Treat the suggested code AS IS.
- Each element in the returned list is a COMPLETE and INDEPENDENT code snippet.
- MUST be a valid json format. Don't be lazy!
- Only return the new code to be inserted.
Expand All @@ -29,4 +30,3 @@ Your task is to suggest code modifications at the cursor position. Follow these

Remember to ONLY RETURN the suggested code snippet, without any additional formatting or explanation.
{% endblock %}

Loading

0 comments on commit 8e1018f

Please sign in to comment.