Skip to content

Commit

Permalink
feat(api): enable customizable calls functions (#457)
Browse files Browse the repository at this point in the history
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
  • Loading branch information
aarnphm authored Sep 2, 2024
1 parent d520f09 commit 7266661
Show file tree
Hide file tree
Showing 9 changed files with 290 additions and 266 deletions.
66 changes: 63 additions & 3 deletions lua/avante/api.lua
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,73 @@ local Utils = require("avante.utils")
---@field hint ToggleBind.wrap
---
---@class avante.Api
---@field ask fun(): boolean
---@field edit fun(): nil
---@field ask fun(question:string?): boolean
---@field edit fun(question:string?): nil
---@field refresh fun(): nil
---@field build fun(): boolean
---@field switch_provider fun(target: string): nil
---@field toggle avante.ApiToggle

return setmetatable({}, {
local M = {}

---@param target Provider
M.switch_provider = function(target)
require("avante.providers").refresh(target)
end

---@param question? string
M.ask = function(question)
if not require("avante").toggle() then
return false
end
if question == nil or question == "" then
return true
end
vim.api.nvim_exec_autocmds("User", { pattern = "AvanteInputSubmitted", data = { request = question } })
return true
end

---@param question? string
M.edit = function(question)
local _, selection = require("avante").get()
if not selection then
return
end
selection:create_editing_input()
if question ~= nil or question ~= "" then
vim.api.nvim_exec_autocmds("User", { pattern = "AvanteEditSubmitted", data = { request = question } })
end
end

M.refresh = function()
local sidebar, _ = require("avante").get()
if not sidebar then
return
end
if not sidebar:is_open() then
return
end
local curbuf = vim.api.nvim_get_current_buf()

local focused = sidebar.result.bufnr == curbuf or sidebar.input.bufnr == curbuf
if focused or not sidebar:is_open() then
return
end
local listed = vim.api.nvim_get_option_value("buflisted", { buf = curbuf })

if Utils.is_sidebar_buffer(curbuf) or not listed then
return
end

local curwin = vim.api.nvim_get_current_win()

sidebar:close()
sidebar.code.winid = curwin
sidebar.code.bufnr = curbuf
sidebar:render()
end

return setmetatable(M, {
__index = function(t, k)
local module = require("avante")
---@class AvailableApi: ApiCaller
Expand Down
29 changes: 22 additions & 7 deletions lua/avante/config.lua
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ local Utils = require("avante.utils")
local M = {}

---@class avante.Config
---@field silent_warning boolean will be determined from debug
---@field silent_warning? boolean will be determined from debug
M.defaults = {
debug = false,
---@alias Provider "claude" | "openai" | "azure" | "gemini" | "cohere" | "copilot" | string
---@alias Provider "claude" | "openai" | "azure" | "gemini" | "cohere" | "copilot" | [string]
provider = "claude", -- Only recommend using Claude
---@alias Tokenizer "tiktoken" | "hf"
-- Used for counting tokens and encoding text.
Expand Down Expand Up @@ -111,7 +111,6 @@ M.defaults = {
ours = "co",
theirs = "ct",
all_theirs = "ca",
none = "c0",
both = "cb",
cursor = "cc",
next = "]x",
Expand All @@ -130,6 +129,7 @@ M.defaults = {
edit = "<leader>ae",
refresh = "<leader>ar",
toggle = {
default = "<leader>at",
debug = "<leader>ad",
hint = "<leader>ah",
},
Expand All @@ -151,11 +151,9 @@ M.defaults = {
border = "rounded",
},
},
--- @class AvanteConflictUserConfig
--- @class AvanteConflictConfig
diff = {
autojump = true,
---@type string | fun(): any
list_opener = "copen",
},
--- @class AvanteHintsConfig
hints = {
Expand All @@ -166,11 +164,14 @@ M.defaults = {
---@type avante.Config
M.options = {}

---@class avante.ConflictConfig: AvanteConflictUserConfig
---@class avante.ConflictConfig: AvanteConflictConfig
---@field mappings AvanteConflictMappings
---@field highlights AvanteConflictHighlights
M.diff = {}

---@type Provider[]
M.providers = {}

---@param opts? avante.Config
function M.setup(opts)
vim.validate({ opts = { opts, "table", true } })
Expand All @@ -190,6 +191,16 @@ function M.setup(opts)
-- set silent_warning to true if debug is false
M.options.silent_warning = not M.options.debug
end
M.providers = vim
.iter(M.defaults)
:filter(function(_, value)
return type(value) == "table" and value.endpoint ~= nil
end)
:fold({}, function(acc, k)
acc = vim.list_extend({}, acc)
acc = vim.list_extend(acc, { k })
return acc
end)

vim.validate({ provider = { M.options.provider, "string", false } })

Expand All @@ -205,6 +216,7 @@ function M.setup(opts)
M.options.vendors[k] = type(v) == "function" and v() or v
end
vim.validate({ vendors = { M.options.vendors, "table", true } })
M.providers = vim.list_extend(M.providers, vim.tbl_keys(M.options.vendors))
end
end

Expand All @@ -228,6 +240,9 @@ function M.override(opts)
if next(M.options.vendors) ~= nil then
for k, v in pairs(M.options.vendors) do
M.options.vendors[k] = type(v) == "function" and v() or v
if not vim.tbl_contains(M.providers, k) then
M.providers = vim.list_extend(M.providers, { k })
end
end
vim.validate({ vendors = { M.options.vendors, "table", true } })
end
Expand Down
Loading

0 comments on commit 7266661

Please sign in to comment.