Skip to content

Commit

Permalink
feat: make tiktoken optional (yetone#245)
Browse files Browse the repository at this point in the history
yetone authored Aug 26, 2024
1 parent 3d3a249 commit b874045
Showing 5 changed files with 66 additions and 7 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -23,7 +23,7 @@ Install `avante.nvim` using [lazy.nvim](https://github.com/folke/lazy.nvim):
{
"yetone/avante.nvim",
event = "VeryLazy",
build = "make",
build = "make", -- This is Optional, only if you want to use tiktoken_core to calculate tokens count
opts = {
-- add any opts here
},
@@ -50,7 +50,7 @@ For Windows users, change the build command to the following:
{
"yetone/avante.nvim",
event = "VeryLazy",
build = "powershell -ExecutionPolicy Bypass -File Build-LuaTiktoken.ps1",
build = "powershell -ExecutionPolicy Bypass -File Build-LuaTiktoken.ps1", -- This is Optional, only if you want to use tiktoken_core to calculate tokens count
-- rest of the config
}
```
11 changes: 7 additions & 4 deletions lua/avante/providers/claude.lua
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
local Utils = require("avante.utils")
local Tiktoken = require("avante.tiktoken")
local Tokens = require("avante.utils.tokens")
local P = require("avante.providers")

---@class AvanteProviderFunctor
@@ -13,7 +13,7 @@ M.parse_message = function(opts)
text = string.format("<code>```%s\n%s```</code>", opts.code_lang, opts.code_content),
}

if Tiktoken.count(code_prompt_obj.text) > 1024 then
if Tokens.calculate_tokens(code_prompt_obj.text) > 1024 then
code_prompt_obj.cache_control = { type = "ephemeral" }
end

@@ -31,7 +31,7 @@ M.parse_message = function(opts)
text = string.format("<code>```%s\n%s```</code>", opts.code_lang, opts.selected_code_content),
}

if Tiktoken.count(selected_code_obj.text) > 1024 then
if Tokens.calculate_tokens(selected_code_obj.text) > 1024 then
selected_code_obj.cache_control = { type = "ephemeral" }
end

@@ -50,7 +50,7 @@ M.parse_message = function(opts)
text = user_prompt,
}

if Tiktoken.count(user_prompt_obj.text) > 1024 then
if Tokens.calculate_tokens(user_prompt_obj.text) > 1024 then
user_prompt_obj.cache_control = { type = "ephemeral" }
end

@@ -79,6 +79,9 @@ M.parse_response = function(data_stream, event_state, opts)
end
end

---@param provider AvanteProviderFunctor
---@param code_opts AvantePromptOptions
---@return table
M.parse_curl_args = function(provider, code_opts)
local base, body_opts = P.parse_config(provider)

1 change: 1 addition & 0 deletions lua/avante/providers/init.lua
Original file line number Diff line number Diff line change
@@ -62,6 +62,7 @@ local Dressing = require("avante.ui.dressing")
---@field parse_response_data AvanteResponseParser
---@field parse_curl_args? AvanteCurlArgsParser
---@field parse_stream_data? AvanteStreamParser
---@field parse_api_key fun(): string | nil
---
---@class AvanteProviderFunctor
---@field parse_message AvanteMessageParser
1 change: 0 additions & 1 deletion lua/avante/tiktoken.lua
Original file line number Diff line number Diff line change
@@ -52,7 +52,6 @@ local M = {}
function M.setup(model)
local ok, core = pcall(require, "tiktoken_core")
if not ok then
print("Warn: tiktoken_core is not found!!!!")
return
end

56 changes: 56 additions & 0 deletions lua/avante/utils/tokens.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
--Taken from https://github.com/jackMort/ChatGPT.nvim/blob/main/lua/chatgpt/flows/chat/tokens.lua
local Tiktoken = require("avante.tiktoken")
local Tokens = {}

--[[
cost_per_token
@param {string} token_name
@return {number} cost_per_token
]]
local cost_per_token = {
davinci = 0.000002,
}

--- Calculate the number of tokens in a given text.
-- @param text The text to calculate the number of tokens for.
-- @return The number of tokens in the given text.
function Tokens.calculate_tokens(text)
if Tiktoken.available() then
return Tiktoken.count(text)
end
local tokens = 0
local current_token = ""
for char in text:gmatch(".") do
if char == " " or char == "\n" then
if current_token ~= "" then
tokens = tokens + 1
current_token = ""
end
else
current_token = current_token .. char
end
end
if current_token ~= "" then
tokens = tokens + 1
end
return tokens
end

--- Calculate the cost of a given text in dollars.
-- @param text The text to calculate the cost of.
-- @param model The model to use to calculate the cost.
-- @return The cost of the given text in dollars.
function Tokens.calculate_usage_in_dollars(text, model)
local tokens = Tokens.calculate_tokens(text)
return Tokens.usage_in_dollars(tokens, model)
end

--- Calculate the cost of a given number of tokens in dollars.
-- @param tokens The number of tokens to calculate the cost of.
-- @param model The model to use to calculate the cost.
-- @return The cost of the given number of tokens in dollars.
function Tokens.usage_in_dollars(tokens, model)
return tokens * cost_per_token[model or "davinci"]
end

return Tokens

0 comments on commit b874045

Please sign in to comment.