Skip to content

Commit

Permalink
fix(llm): persistent key check for override class (yetone#158)
Browse files Browse the repository at this point in the history
* fix(llm): make sure to allow passing custom module

Signed-off-by: Aaron Pham <contact@aarnphm.xyz>

* fix: correct custom class

Signed-off-by: Aaron Pham <contact@aarnphm.xyz>

* fix: correct attribute

Signed-off-by: Aaron Pham <contact@aarnphm.xyz>

---------

Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
  • Loading branch information
aarnphm authored Aug 23, 2024
1 parent 49fabfc commit 6475407
Show file tree
Hide file tree
Showing 9 changed files with 24 additions and 26 deletions.
3 changes: 2 additions & 1 deletion lua/avante/config.lua
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,8 @@ M.get_provider = function(provider)
end
end

M.BASE_PROVIDER_KEYS = { "endpoint", "model", "local", "deployment", "api_version", "proxy", "allow_insecure" }
M.BASE_PROVIDER_KEYS =
{ "endpoint", "model", "local", "deployment", "api_version", "proxy", "allow_insecure", "api_key_name" }

---@return {width: integer, height: integer}
function M.get_sidebar_layout_options()
Expand Down
6 changes: 3 additions & 3 deletions lua/avante/providers/azure.lua
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ local O = require("avante.providers").openai
---@class AvanteProviderFunctor
local M = {}

M.API_KEY = "AZURE_OPENAI_API_KEY"
M.api_key_name = "AZURE_OPENAI_API_KEY"

M.has = function()
return os.getenv(M.API_KEY) and true or false
return os.getenv(M.api_key_name) and true or false
end

M.parse_message = O.parse_message
Expand All @@ -22,7 +22,7 @@ M.parse_curl_args = function(provider, code_opts)
["Content-Type"] = "application/json",
}
if not P.env.is_local("azure") then
headers["api-key"] = os.getenv(M.API_KEY)
headers["api-key"] = os.getenv(base.api_key_name or M.api_key_name)
end

return {
Expand Down
6 changes: 3 additions & 3 deletions lua/avante/providers/claude.lua
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ local P = require("avante.providers")
---@class AvanteProviderFunctor
local M = {}

M.API_KEY = "ANTHROPIC_API_KEY"
M.api_key_name = "ANTHROPIC_API_KEY"

M.has = function()
return os.getenv(M.API_KEY) and true or false
return os.getenv(M.api_key_name) and true or false
end

M.parse_message = function(opts)
Expand Down Expand Up @@ -93,7 +93,7 @@ M.parse_curl_args = function(provider, code_opts)
["anthropic-beta"] = "prompt-caching-2024-07-31",
}
if not P.env.is_local("claude") then
headers["x-api-key"] = os.getenv(M.API_KEY)
headers["x-api-key"] = os.getenv(base.api_key_name or M.api_key_name)
end

return {
Expand Down
2 changes: 1 addition & 1 deletion lua/avante/providers/copilot.lua
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ H.generate_headers = function(token, sessionid, machineid)
return headers
end

M.API_KEY = P.AVANTE_INTERNAL_KEY
M.api_key_name = P.AVANTE_INTERNAL_KEY

M.has = function()
if Utils.has("copilot.lua") or Utils.has("copilot.vim") or H.find_config_path() then
Expand Down
6 changes: 3 additions & 3 deletions lua/avante/providers/deepseek.lua
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ local O = require("avante.providers").openai
---@class AvanteProviderFunctor
local M = {}

M.API_KEY = "DEEPSEEK_API_KEY"
M.api_key_name = "DEEPSEEK_API_KEY"

M.has = function()
return os.getenv(M.API_KEY) and true or false
return os.getenv(M.api_key_name) and true or false
end

M.parse_message = O.parse_message
Expand All @@ -22,7 +22,7 @@ M.parse_curl_args = function(provider, code_opts)
["Content-Type"] = "application/json",
}
if not P.env.is_local("deepseek") then
headers["Authorization"] = "Bearer " .. os.getenv(M.API_KEY)
headers["Authorization"] = "Bearer " .. os.getenv(base.api_key_name or M.api_key_name)
end

return {
Expand Down
6 changes: 3 additions & 3 deletions lua/avante/providers/gemini.lua
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ local P = require("avante.providers")
---@class AvanteProviderFunctor
local M = {}

M.API_KEY = "GROQ_API_KEY"
M.api_key_name = "GROQ_API_KEY"

M.has = function()
return os.getenv(M.API_KEY) and true or false
return os.getenv(M.api_key_name) and true or false
end

M.parse_message = function(opts)
Expand Down Expand Up @@ -68,7 +68,7 @@ M.parse_curl_args = function(provider, code_opts)
.. "/"
.. base.model
.. ":streamGenerateContent?alt=sse&key="
.. os.getenv(M.API_KEY),
.. os.getenv(base.api_key_name or M.api_key_name),
proxy = base.proxy,
insecure = base.allow_insecure,
headers = { ["Content-Type"] = "application/json" },
Expand Down
6 changes: 3 additions & 3 deletions lua/avante/providers/groq.lua
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ local O = require("avante.providers").openai
---@class AvanteProviderFunctor
local M = {}

M.API_KEY = "GROQ_API_KEY"
M.api_key_name = "GROQ_API_KEY"

M.has = function()
return os.getenv(M.API_KEY) and true or false
return os.getenv(M.api_key_name) and true or false
end

M.parse_message = O.parse_message
Expand All @@ -22,7 +22,7 @@ M.parse_curl_args = function(provider, code_opts)
["Content-Type"] = "application/json",
}
if not P.env.is_local("groq") then
headers["Authorization"] = "Bearer " .. os.getenv(M.API_KEY)
headers["Authorization"] = "Bearer " .. os.getenv(base.api_key_name or M.api_key_name)
end

return {
Expand Down
9 changes: 3 additions & 6 deletions lua/avante/providers/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ local Dressing = require("avante.ui.dressing")
---@field parse_curl_args AvanteCurlArgsParser
---@field setup? fun(): nil
---@field has fun(): boolean
---@field API_KEY string
---@field api_key_name string
---@field parse_stream_data? AvanteStreamParser
---
---@class avante.Providers
Expand All @@ -102,13 +102,10 @@ setmetatable(M, {
local v = Config.vendors[k]

-- Patch from vendors similar to supported providers.
---@type AvanteProviderFunctor
t[k] = setmetatable({}, { __index = v })
t[k].API_KEY = v.api_key_name
-- Hack for aliasing and makes it sane for us.
t[k].parse_response = v.parse_response_data
t[k].has = function()
return os.getenv(v.api_key_name) and true or false
end

return t[k]
end
Expand All @@ -130,7 +127,7 @@ E._once = false
---@param opts {refresh: boolean, provider: AvanteProviderFunctor}
---@private
E.setup = function(opts)
local var = opts.provider.API_KEY
local var = opts.provider.api_key_name

if var == M.AVANTE_INTERNAL_KEY then
return
Expand Down
6 changes: 3 additions & 3 deletions lua/avante/providers/openai.lua
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ local P = require("avante.providers")
---@class AvanteProviderFunctor
local M = {}

M.API_KEY = "OPENAI_API_KEY"
M.api_key_name = "OPENAI_API_KEY"

M.has = function()
return os.getenv(M.API_KEY) and true or false
return os.getenv(M.api_key_name) and true or false
end

M.parse_message = function(opts)
Expand Down Expand Up @@ -91,7 +91,7 @@ M.parse_curl_args = function(provider, code_opts)
["Content-Type"] = "application/json",
}
if not P.env.is_local("openai") then
headers["Authorization"] = "Bearer " .. os.getenv(M.API_KEY)
headers["Authorization"] = "Bearer " .. os.getenv(base.api_key_name or M.api_key_name)
end

return {
Expand Down

0 comments on commit 6475407

Please sign in to comment.