feat: support openai o1-preview
* feat: make o1 models on OpenAI work by handling non-streaming responses and sending the correct parameters

* chore: set temperature automatically when using o1 models
LessComplexity authored Sep 27, 2024
1 parent 64f26b9 commit d74c9d0
Showing 2 changed files with 56 additions and 6 deletions.
lua/avante/llm.lua: 12 additions & 0 deletions
@@ -114,6 +114,10 @@ M.stream = function(opts)
     if data_match then Provider.parse_response(data_match, current_event_state, handler_opts) end
   end
 
+  local function parse_response_without_stream(data)
+    Provider.parse_response_without_stream(data, current_event_state, handler_opts)
+  end
+
   local completed = false
 
   local active_job
@@ -170,6 +174,14 @@ M.stream = function(opts)
           end
         end)
       end
+
+      -- If stream is not enabled, then handle the response here
+      if spec.body.stream == false and result.status == 200 then
+        vim.schedule(function()
+          completed = true
+          parse_response_without_stream(result.body)
+        end)
+      end
     end,
   })

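Why the non-streaming path above is needed: with stream = false, an OpenAI-compatible endpoint returns the whole completion as a single JSON body rather than as a sequence of SSE "data:" chunks, so the existing line-oriented stream parser never sees it. Below is a minimal sketch of what parse_response_without_stream receives and does; the sample payload is an abbreviated, hypothetical example, not verbatim API output.

-- Minimal sketch; hypothetical non-streaming chat completion body (stream = false).
local body = [[{
  "choices": [
    { "message": { "role": "assistant", "content": "Hello" }, "finish_reason": "stop" }
  ]
}]]

-- The full reply lives under choices[1].message rather than choices[1].delta,
-- which is why a separate parser is required for the non-streaming case.
local json = vim.json.decode(body)
if json.choices and json.choices[1] then
  print(json.choices[1].message.content) --> Hello
end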
lua/avante/providers/openai.lua: 44 additions & 6 deletions
@@ -9,7 +9,7 @@ local P = require("avante.providers")
 ---@field created integer
 ---@field model string
 ---@field system_fingerprint string
----@field choices? OpenAIResponseChoice[]
+---@field choices? OpenAIResponseChoice[] | OpenAIResponseChoiceComplete[]
 ---@field usage {prompt_tokens: integer, completion_tokens: integer, total_tokens: integer}
 ---
 ---@class OpenAIResponseChoice
@@ -18,6 +18,12 @@ local P = require("avante.providers")
 ---@field logprobs? integer
 ---@field finish_reason? "stop" | "length"
 ---
+---@class OpenAIResponseChoiceComplete
+---@field message OpenAIMessage
+---@field finish_reason "stop" | "length"
+---@field index integer
+---@field logprobs integer
+---
 ---@class OpenAIMessage
 ---@field role? "user" | "system" | "assistant"
 ---@field content string
@@ -50,10 +56,22 @@ M.parse_message = function(opts)
     end)
   end
 
-  return {
-    { role = "system", content = opts.system_prompt },
-    { role = "user", content = user_content },
-  }
+  local messages = {}
+  local provider = P[Config.provider]
+  local base, _ = P.parse_config(provider)
+
+  -- NOTE: Handle the case where the selected model is an "o1" model:
+  -- "o1" models are "smart" enough to treat the user prompt as a system prompt in this context
+  if base.model and string.find(base.model, "o1") then
+    table.insert(messages, { role = "user", content = opts.system_prompt })
+  else
+    table.insert(messages, { role = "system", content = opts.system_prompt })
+  end
+
+  -- The user message follows the prompt
+  table.insert(messages, { role = "user", content = user_content })
+
+  return messages
 end
 
 M.parse_response = function(data_stream, _, opts)
@@ -75,6 +93,18 @@ M.parse_response = function(data_stream, _, opts)
   end
 end
 
+M.parse_response_without_stream = function(data, _, opts)
+  ---@type OpenAIChatResponse
+  local json = vim.json.decode(data)
+  if json.choices and json.choices[1] then
+    local choice = json.choices[1]
+    if choice.message and choice.message.content then
+      opts.on_chunk(choice.message.content)
+      vim.schedule(function() opts.on_complete(nil) end)
+    end
+  end
+end
+
 M.parse_curl_args = function(provider, code_opts)
   local base, body_opts = P.parse_config(provider)
 
@@ -83,6 +113,14 @@ M.parse_curl_args = function(provider, code_opts)
   }
   if not P.env.is_local("openai") then headers["Authorization"] = "Bearer " .. provider.parse_api_key() end
 
+  -- NOTE: When using "o1" models, send only the parameters they support
+  local stream = true
+  if base.model and string.find(base.model, "o1") then
+    stream = false
+    body_opts.max_tokens = nil
+    body_opts.temperature = 1
+  end
+
   return {
     url = Utils.trim(base.endpoint, { suffix = "/" }) .. "/chat/completions",
     proxy = base.proxy,
@@ -91,7 +129,7 @@ M.parse_curl_args = function(provider, code_opts)
     body = vim.tbl_deep_extend("force", {
       model = base.model,
       messages = M.parse_message(code_opts),
-      stream = true,
+      stream = stream,
     }, body_opts),
   }
 end
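As a usage illustration, the two o1 branches above boil down to the following standalone sketch. The variable names mirror those in the diff, but the literal values assigned to them here are invented for illustration and are not taken from the repository.

-- Standalone sketch mirroring the two o1 branches above; values are hypothetical.
local base = { model = "o1-preview" }
local body_opts = { max_tokens = 4096, temperature = 0.7 }
local system_prompt = "You are a coding assistant."
local user_content = "Explain this diff."

local is_o1 = base.model and string.find(base.model, "o1") ~= nil

-- o1 models do not accept a "system" role, so the system prompt is sent as "user".
local messages = {
  { role = is_o1 and "user" or "system", content = system_prompt },
  { role = "user", content = user_content },
}

-- o1 models also reject streaming and max_tokens, and require temperature = 1.
local stream = true
if is_o1 then
  stream = false
  body_opts.max_tokens = nil
  body_opts.temperature = 1
end

assert(messages[1].role == "user")
assert(stream == false and body_opts.temperature == 1 and body_opts.max_tokens == nil)

Note that temperature is forced to 1 rather than dropped, which matches the "set temperature automatically when using o1 models" bullet in the commit message.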
