diff --git a/lua/avante/config.lua b/lua/avante/config.lua index 2ed7d2e15..2096ee00b 100644 --- a/lua/avante/config.lua +++ b/lua/avante/config.lua @@ -204,7 +204,8 @@ M.get_provider = function(provider) end end -M.BASE_PROVIDER_KEYS = { "endpoint", "model", "local", "deployment", "api_version", "proxy", "allow_insecure" } +M.BASE_PROVIDER_KEYS = + { "endpoint", "model", "local", "deployment", "api_version", "proxy", "allow_insecure", "api_key_name" } ---@return {width: integer, height: integer} function M.get_sidebar_layout_options() diff --git a/lua/avante/providers/azure.lua b/lua/avante/providers/azure.lua index 108865504..7763fab2b 100644 --- a/lua/avante/providers/azure.lua +++ b/lua/avante/providers/azure.lua @@ -6,10 +6,10 @@ local O = require("avante.providers").openai ---@class AvanteProviderFunctor local M = {} -M.API_KEY = "AZURE_OPENAI_API_KEY" +M.api_key_name = "AZURE_OPENAI_API_KEY" M.has = function() - return os.getenv(M.API_KEY) and true or false + return os.getenv(M.api_key_name) and true or false end M.parse_message = O.parse_message @@ -22,7 +22,7 @@ M.parse_curl_args = function(provider, code_opts) ["Content-Type"] = "application/json", } if not P.env.is_local("azure") then - headers["api-key"] = os.getenv(M.API_KEY) + headers["api-key"] = os.getenv(base.api_key_name or M.api_key_name) end return { diff --git a/lua/avante/providers/claude.lua b/lua/avante/providers/claude.lua index 8cd770ea2..dff439b33 100644 --- a/lua/avante/providers/claude.lua +++ b/lua/avante/providers/claude.lua @@ -6,10 +6,10 @@ local P = require("avante.providers") ---@class AvanteProviderFunctor local M = {} -M.API_KEY = "ANTHROPIC_API_KEY" +M.api_key_name = "ANTHROPIC_API_KEY" M.has = function() - return os.getenv(M.API_KEY) and true or false + return os.getenv(M.api_key_name) and true or false end M.parse_message = function(opts) @@ -93,7 +93,7 @@ M.parse_curl_args = function(provider, code_opts) ["anthropic-beta"] = "prompt-caching-2024-07-31", } if not P.env.is_local("claude") then - headers["x-api-key"] = os.getenv(M.API_KEY) + headers["x-api-key"] = os.getenv(base.api_key_name or M.api_key_name) end return { diff --git a/lua/avante/providers/copilot.lua b/lua/avante/providers/copilot.lua index af6374ab1..6c49a39aa 100644 --- a/lua/avante/providers/copilot.lua +++ b/lua/avante/providers/copilot.lua @@ -142,7 +142,7 @@ H.generate_headers = function(token, sessionid, machineid) return headers end -M.API_KEY = P.AVANTE_INTERNAL_KEY +M.api_key_name = P.AVANTE_INTERNAL_KEY M.has = function() if Utils.has("copilot.lua") or Utils.has("copilot.vim") or H.find_config_path() then diff --git a/lua/avante/providers/deepseek.lua b/lua/avante/providers/deepseek.lua index effef50ef..965411342 100644 --- a/lua/avante/providers/deepseek.lua +++ b/lua/avante/providers/deepseek.lua @@ -6,10 +6,10 @@ local O = require("avante.providers").openai ---@class AvanteProviderFunctor local M = {} -M.API_KEY = "DEEPSEEK_API_KEY" +M.api_key_name = "DEEPSEEK_API_KEY" M.has = function() - return os.getenv(M.API_KEY) and true or false + return os.getenv(M.api_key_name) and true or false end M.parse_message = O.parse_message @@ -22,7 +22,7 @@ M.parse_curl_args = function(provider, code_opts) ["Content-Type"] = "application/json", } if not P.env.is_local("deepseek") then - headers["Authorization"] = "Bearer " .. os.getenv(M.API_KEY) + headers["Authorization"] = "Bearer " .. os.getenv(base.api_key_name or M.api_key_name) end return { diff --git a/lua/avante/providers/gemini.lua b/lua/avante/providers/gemini.lua index 5aee0e6da..58f6df0fc 100644 --- a/lua/avante/providers/gemini.lua +++ b/lua/avante/providers/gemini.lua @@ -5,10 +5,10 @@ local P = require("avante.providers") ---@class AvanteProviderFunctor local M = {} -M.API_KEY = "GROQ_API_KEY" +M.api_key_name = "GROQ_API_KEY" M.has = function() - return os.getenv(M.API_KEY) and true or false + return os.getenv(M.api_key_name) and true or false end M.parse_message = function(opts) @@ -68,7 +68,7 @@ M.parse_curl_args = function(provider, code_opts) .. "/" .. base.model .. ":streamGenerateContent?alt=sse&key=" - .. os.getenv(M.API_KEY), + .. os.getenv(base.api_key_name or M.api_key_name), proxy = base.proxy, insecure = base.allow_insecure, headers = { ["Content-Type"] = "application/json" }, diff --git a/lua/avante/providers/groq.lua b/lua/avante/providers/groq.lua index 779d932a7..abf6d5163 100644 --- a/lua/avante/providers/groq.lua +++ b/lua/avante/providers/groq.lua @@ -6,10 +6,10 @@ local O = require("avante.providers").openai ---@class AvanteProviderFunctor local M = {} -M.API_KEY = "GROQ_API_KEY" +M.api_key_name = "GROQ_API_KEY" M.has = function() - return os.getenv(M.API_KEY) and true or false + return os.getenv(M.api_key_name) and true or false end M.parse_message = O.parse_message @@ -22,7 +22,7 @@ M.parse_curl_args = function(provider, code_opts) ["Content-Type"] = "application/json", } if not P.env.is_local("groq") then - headers["Authorization"] = "Bearer " .. os.getenv(M.API_KEY) + headers["Authorization"] = "Bearer " .. os.getenv(base.api_key_name or M.api_key_name) end return { diff --git a/lua/avante/providers/init.lua b/lua/avante/providers/init.lua index 7523cdfcc..80287c265 100644 --- a/lua/avante/providers/init.lua +++ b/lua/avante/providers/init.lua @@ -80,7 +80,7 @@ local Dressing = require("avante.ui.dressing") ---@field parse_curl_args AvanteCurlArgsParser ---@field setup? fun(): nil ---@field has fun(): boolean ----@field API_KEY string +---@field api_key_name string ---@field parse_stream_data? AvanteStreamParser --- ---@class avante.Providers @@ -102,13 +102,10 @@ setmetatable(M, { local v = Config.vendors[k] -- Patch from vendors similar to supported providers. + ---@type AvanteProviderFunctor t[k] = setmetatable({}, { __index = v }) - t[k].API_KEY = v.api_key_name -- Hack for aliasing and makes it sane for us. t[k].parse_response = v.parse_response_data - t[k].has = function() - return os.getenv(v.api_key_name) and true or false - end return t[k] end @@ -130,7 +127,7 @@ E._once = false ---@param opts {refresh: boolean, provider: AvanteProviderFunctor} ---@private E.setup = function(opts) - local var = opts.provider.API_KEY + local var = opts.provider.api_key_name if var == M.AVANTE_INTERNAL_KEY then return diff --git a/lua/avante/providers/openai.lua b/lua/avante/providers/openai.lua index d2a3dc7fe..72e2b76c7 100644 --- a/lua/avante/providers/openai.lua +++ b/lua/avante/providers/openai.lua @@ -24,10 +24,10 @@ local P = require("avante.providers") ---@class AvanteProviderFunctor local M = {} -M.API_KEY = "OPENAI_API_KEY" +M.api_key_name = "OPENAI_API_KEY" M.has = function() - return os.getenv(M.API_KEY) and true or false + return os.getenv(M.api_key_name) and true or false end M.parse_message = function(opts) @@ -91,7 +91,7 @@ M.parse_curl_args = function(provider, code_opts) ["Content-Type"] = "application/json", } if not P.env.is_local("openai") then - headers["Authorization"] = "Bearer " .. os.getenv(M.API_KEY) + headers["Authorization"] = "Bearer " .. os.getenv(base.api_key_name or M.api_key_name) end return {