Skip to content

Commit

Permalink
feat: enable diagnostics (yetone#891)
Browse files Browse the repository at this point in the history
  • Loading branch information
yetone authored Nov 23, 2024
1 parent d14b229 commit 9042f5f
Show file tree
Hide file tree
Showing 7 changed files with 102 additions and 20 deletions.
2 changes: 2 additions & 0 deletions crates/avante-templates/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ struct TemplateContext {
file_content: String,
selected_code: Option<String>,
project_context: Option<String>,
diagnostics: Option<String>,
}

// Given the file name registered after add, the context table in Lua, resulted in a formatted
Expand All @@ -47,6 +48,7 @@ fn render(state: &State, template: &str, context: TemplateContext) -> LuaResult<
file_content => context.file_content,
selected_code => context.selected_code,
project_context => context.project_context,
diagnostics => context.diagnostics,
})
.map_err(LuaError::external)
.unwrap())
Expand Down
9 changes: 9 additions & 0 deletions lua/avante/llm.lua
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ M.CANCEL_PATTERN = "AvanteLLMEscape"

local group = api.nvim_create_augroup("avante_llm", { clear = true })

---@param opts StreamOptions
---@param Provider AvanteProviderFunctor
M._stream = function(opts, Provider)
-- print opts
local mode = opts.mode or "planning"
Expand Down Expand Up @@ -52,6 +54,7 @@ M._stream = function(opts, Provider)
file_content = opts.file_content,
selected_code = opts.selected_code,
project_context = opts.project_context,
diagnostics = opts.diagnostics,
}

local system_prompt = Path.prompts.render_mode(mode, template_opts)
Expand All @@ -64,6 +67,11 @@ M._stream = function(opts, Provider)
if project_context ~= "" then table.insert(messages, { role = "user", content = project_context }) end
end

if opts.diagnostics ~= nil and opts.diagnostics ~= "" and opts.diagnostics ~= "null" then
local diagnostics = Path.prompts.render_file("_diagnostics.avanterules", template_opts)
if diagnostics ~= "" then table.insert(messages, { role = "user", content = diagnostics }) end
end

local code_context = Path.prompts.render_file("_context.avanterules", template_opts)
if code_context ~= "" then table.insert(messages, { role = "user", content = code_context }) end

Expand Down Expand Up @@ -334,6 +342,7 @@ end
---@field file_content string
---@field selected_code string | nil
---@field project_context string | nil
---@field diagnostics string | nil
---@field history_messages AvanteLLMMessage[]
---
---@class StreamOptions: TemplateOptions
Expand Down
15 changes: 14 additions & 1 deletion lua/avante/range.lua
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ local Range = {}
Range.__index = Range

---@class avante.RangeSelection: table<string, integer>
---@field line number
---@field lnum number
---@field col number

---Create a selection range
Expand All @@ -18,4 +18,17 @@ function Range:new(start, finish)
return instance
end

---Check if the line and column are within the range
---@param lnum number Line number
---@param col number Column number
---@return boolean
function Range:contains(lnum, col)
local start = self.start
local finish = self.finish
if lnum < start.lnum or lnum > finish.lnum then return false end
if lnum == start.lnum and col < start.col then return false end
if lnum == finish.lnum and col > finish.col then return false end
return true
end

return Range
21 changes: 12 additions & 9 deletions lua/avante/selection.lua
Original file line number Diff line number Diff line change
Expand Up @@ -127,15 +127,15 @@ function Selection:create_editing_input()
local end_row
local end_col
if vim.fn.mode() == "V" then
start_row = self.selection.range.start.line - 1
start_row = self.selection.range.start.lnum - 1
start_col = 0
end_row = self.selection.range.finish.line - 1
end_col = #code_lines[self.selection.range.finish.line]
end_row = self.selection.range.finish.lnum - 1
end_col = #code_lines[self.selection.range.finish.lnum]
else
start_row = self.selection.range.start.line - 1
start_row = self.selection.range.start.lnum - 1
start_col = self.selection.range.start.col - 1
end_row = self.selection.range.finish.line - 1
end_col = math.min(self.selection.range.finish.col, #code_lines[self.selection.range.finish.line])
end_row = self.selection.range.finish.lnum - 1
end_col = math.min(self.selection.range.finish.col, #code_lines[self.selection.range.finish.lnum])
end

self.selected_code_extmark_id = api.nvim_buf_set_extmark(code_bufnr, SELECTED_CODE_NAMESPACE, start_row, start_col, {
Expand All @@ -148,10 +148,10 @@ function Selection:create_editing_input()

local submit_input = function(input)
local full_response = ""
local start_line = self.selection.range.start.line
local finish_line = self.selection.range.finish.line
local start_line = self.selection.range.start.lnum
local finish_line = self.selection.range.finish.lnum

local original_first_line_indentation = Utils.get_indentation(code_lines[self.selection.range.start.line])
local original_first_line_indentation = Utils.get_indentation(code_lines[self.selection.range.start.lnum])

local need_prepend_indentation = false

Expand Down Expand Up @@ -201,10 +201,13 @@ function Selection:create_editing_input()
input = mentions.new_content
local project_context = mentions.enable_project_context and RepoMap.get_repo_map(file_ext) or nil

local diagnostics = Utils.get_current_selection_diagnostics()

Llm.stream({
bufnr = code_bufnr,
ask = true,
project_context = vim.json.encode(project_context),
diagnostics = vim.json.encode(diagnostics),
file_content = code_content,
code_lang = filetype,
selected_code = self.selection.content,
Expand Down
26 changes: 20 additions & 6 deletions lua/avante/sidebar.lua
Original file line number Diff line number Diff line change
Expand Up @@ -551,12 +551,13 @@ local function minimize_snippet(original_lines, snippet)
local original_snippet_content = table.concat(original_snippet_lines, "\n")
local snippet_content = snippet.content
local snippet_lines = vim.split(snippet_content, "\n")
---@diagnostic disable-next-line: missing-fields, assign-type-mismatch
local patch = vim.diff(
---@diagnostic disable-next-line: assign-type-mismatch
local patch = vim.diff( ---@type integer[][]
original_snippet_content,
snippet_content,
---@diagnostic disable-next-line: missing-fields
{ algorithm = "histogram", result_type = "indices", ctxlen = vim.o.scrolloff }
) ---@type integer[][]
)
---@type AvanteCodeSnippet[]
local new_snippets = {}
for _, hunk in ipairs(patch) do
Expand Down Expand Up @@ -602,7 +603,6 @@ function Sidebar:apply(current_cursor)
local response, response_start_line = self:get_content_between_separators()
local all_snippets_map = extract_code_snippets_map(response)
all_snippets_map = ensure_snippets_no_overlap(all_snippets_map)
if Config.options.behaviour.minimize_diff then all_snippets_map = self:minimize_snippets(all_snippets_map) end
local selected_snippets_map = {}
if current_cursor then
if self.result and self.result.winid then
Expand All @@ -623,6 +623,10 @@ function Sidebar:apply(current_cursor)
selected_snippets_map = all_snippets_map
end

if Config.options.behaviour.minimize_diff then
selected_snippets_map = self:minimize_snippets(selected_snippets_map)
end

vim.defer_fn(function()
api.nvim_set_current_win(self.code.winid)
for filepath, snippets in pairs(selected_snippets_map) do
Expand Down Expand Up @@ -740,8 +744,8 @@ function Sidebar:render_input(ask)
ask and "Ask" or "Chat with",
icon,
code_filename,
self.code.selection.range.start.line,
self.code.selection.range.finish.line
self.code.selection.range.start.lnum,
self.code.selection.range.finish.lnum
)
end

Expand Down Expand Up @@ -1536,6 +1540,15 @@ function Sidebar:create_input(opts)

local project_context = mentions.enable_project_context and RepoMap.get_repo_map(file_ext) or nil

local diagnostics = nil
if mentions.enable_diagnostics then
if self.selected_code ~= nil then
diagnostics = Utils.get_current_selection_diagnostics()
else
diagnostics = Utils.get_diagnostics(self.code.bufnr)
end
end

local history_messages = {}
for i = #chat_history, 1, -1 do
local entry = chat_history[i]
Expand Down Expand Up @@ -1569,6 +1582,7 @@ function Sidebar:create_input(opts)
bufnr = self.code.bufnr,
ask = opts.ask,
project_context = vim.json.encode(project_context),
diagnostics = vim.json.encode(diagnostics),
history_messages = history_messages,
file_content = content,
code_lang = filetype,
Expand Down
12 changes: 12 additions & 0 deletions lua/avante/templates/_diagnostics.avanterules
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
{%- if use_xml_format -%}
{%- if diagnostics -%}
<diagnostics>
{{diagnostics}}
</diagnostics>
{%- endif %}
{%- else -%}
{%- if diagnostics -%}
DIAGNOSTICS:
{{diagnostics}}
{%- endif %}
{%- endif %}
37 changes: 33 additions & 4 deletions lua/avante/utils/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ function M.get_visual_selection_and_range()
start_col, end_col = end_col, start_col
end
local content = "" -- luacheck: ignore
local range = Range:new({ line = start_line, col = start_col }, { line = end_line, col = end_col })
local range = Range:new({ lnum = start_line, col = start_col }, { lnum = end_line, col = end_col })
-- Check if it's a single-line selection
if start_line == end_line then
-- Get partial content of a single line
Expand Down Expand Up @@ -665,19 +665,25 @@ end
function M.is_first_letter_uppercase(str) return string.match(str, "^[A-Z]") ~= nil end

---@param content string
---@return { new_content: string, enable_project_context: boolean }
---@return { new_content: string, enable_project_context: boolean, enable_diagnostics: boolean }
function M.extract_mentions(content)
-- if content contains @codebase, enable project context and remove @codebase
local new_content = content
local enable_project_context = false
local enable_diagnostics = false
if content:match("@codebase") then
enable_project_context = true
new_content = content:gsub("@codebase", "")
end
return { new_content = new_content, enable_project_context = enable_project_context }
if content:match("@diagnostics") then enable_diagnostics = true end
return {
new_content = new_content,
enable_project_context = enable_project_context,
enable_diagnostics = enable_diagnostics,
}
end

---@alias AvanteMentions "codebase"
---@alias AvanteMentions "codebase" | "diagnostics"
---@alias AvanteMentionCallback fun(args: string, cb?: fun(args: string): nil): nil
---@alias AvanteMention {description: string, command: AvanteMentions, details: string, shorthelp?: string, callback?: AvanteMentionCallback}
---@return AvanteMention[]
Expand All @@ -688,6 +694,11 @@ function M.get_mentions()
command = "codebase",
details = "repo map",
},
{
description = "diagnostics",
command = "diagnostics",
details = "diagnostics",
},
}
end

Expand Down Expand Up @@ -761,4 +772,22 @@ function M.update_buffer_content(bufnr, new_lines)
end
end

function M.get_diagnostics(bufnr)
if bufnr == nil then bufnr = api.nvim_get_current_buf() end
return vim.diagnostic.get(bufnr, { severity = { vim.diagnostic.severity.ERROR, vim.diagnostic.severity.WARN } })
end

function M.get_current_selection_diagnostics()
local selection = M.get_visual_selection_and_range()
if not selection then return {} end
local diagnostics = M.get_diagnostics()
local selection_diagnostics = {}
for _, diagnostic in ipairs(diagnostics) do
if selection.range:contains(diagnostic.lnum, diagnostic.col) then
table.insert(selection_diagnostics, diagnostic)
end
end
return selection_diagnostics
end

return M

0 comments on commit 9042f5f

Please sign in to comment.