Skip to content

Commit

Permalink
refactor(history): use XDG_DATA_STATE as source dir (yetone#209)
Browse files Browse the repository at this point in the history
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
  • Loading branch information
aarnphm authored Aug 25, 2024
1 parent ee88563 commit 3f906e5
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 54 deletions.
3 changes: 3 additions & 0 deletions lua/avante/config.lua
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ M.defaults = {
auto_set_highlight_group = true,
auto_apply_diff_after_generation = false,
},
history = {
storage_path = vim.fn.stdpath("state") .. "/avante",
},
highlights = {
---@type AvanteConflictHighlights
diff = {
Expand Down
49 changes: 49 additions & 0 deletions lua/avante/history.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
local fn, api = vim.fn, vim.api
local Path = require("plenary.path")
local Config = require("avante.config")

local M = {}

local H = {}

---@param bufnr integer
---@return string
H.filename = function(bufnr)
local code_buf_name = api.nvim_buf_get_name(bufnr)
-- Replace path separators with double underscores
local path_with_separators = fn.substitute(code_buf_name, "/", "__", "g")
-- Replace other non-alphanumeric characters with single underscores
return fn.substitute(path_with_separators, "[^A-Za-z0-9._]", "_", "g") .. ".json"
end

---@param bufnr integer
---@return Path
M.get = function(bufnr)
return Path:new(Config.history.storage_path):joinpath(H.filename(bufnr))
end

---@param bufnr integer
M.load = function(bufnr)
local history_file = M.get(bufnr)
if history_file:exists() then
local content = history_file:read()
return content ~= nil and vim.json.decode(content) or {}
end
return {}
end

---@param bufnr integer
---@param history table
M.save = function(bufnr, history)
local history_file = M.get(bufnr)
history_file:write(vim.json.encode(history), "w")
end

M.setup = function()
local history_dir = Path:new(Config.history.storage_path)
if not history_dir:exists() then
history_dir:mkdir({ parents = true })
end
end

return M
1 change: 1 addition & 0 deletions lua/avante/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ function M.setup(opts)
return
end

require("avante.history").setup()
require("avante.highlights").setup()
require("avante.diff").setup()
require("avante.providers").setup()
Expand Down
59 changes: 5 additions & 54 deletions lua/avante/sidebar.lua
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
local api = vim.api
local fn = vim.fn

local Path = require("plenary.path")
local Split = require("nui.split")
local event = require("nui.utils.autocmd").event

local History = require("avante.history")
local Config = require("avante.config")
local Diff = require("avante.diff")
local Llm = require("avante.llm")
Expand Down Expand Up @@ -927,32 +927,6 @@ local function prepend_line_number(content, start_line)
return table.concat(result, "\n")
end

-- Function to get the current project root directory
local function get_project_root()
local current_file = fn.expand("%:p")
local current_dir = fn.fnamemodify(current_file, ":h")
local git_root = vim.fs.root(current_file, { ".git" })
return git_root ~= nil and git_root or current_dir
end

---@param sidebar avante.Sidebar
local function get_chat_history_filename(sidebar)
local code_buf_name = api.nvim_buf_get_name(sidebar.code.bufnr)
local relative_path = fn.fnamemodify(code_buf_name, ":~:.")
-- Replace path separators with double underscores
local path_with_separators = fn.substitute(relative_path, "/", "__", "g")
-- Replace other non-alphanumeric characters with single underscores
return fn.substitute(path_with_separators, "[^A-Za-z0-9._]", "_", "g")
end

-- Function to get the chat history file path
local function get_chat_history_file(sidebar)
local project_root = get_project_root()
local filename = get_chat_history_filename(sidebar)
local history_dir = Path:new(project_root, ".avante_chat_history")
return history_dir:joinpath(filename .. ".json")
end

-- Function to get current timestamp
local function get_timestamp()
return os.date("%Y-%m-%d %H:%M:%S")
Expand All @@ -973,29 +947,6 @@ local function get_chat_record_prefix(timestamp, provider, model, request)
.. "\n\n"
end

-- Function to load chat history
local function load_chat_history(sidebar)
local history_file = get_chat_history_file(sidebar)
if history_file:exists() then
local content = history_file:read()
return fn.json_decode(content)
end
return {}
end

-- Function to save chat history
local function save_chat_history(sidebar, history)
local history_file = get_chat_history_file(sidebar)
local history_dir = history_file:parent()

-- Create the directory if it doesn't exist
if not history_dir:exists() then
history_dir:mkdir({ parents = true })
end

history_file:write(fn.json_encode(history), "w")
end

function Sidebar:update_content_with_history(history)
local content = ""
for idx, entry in ipairs(history) do
Expand Down Expand Up @@ -1082,7 +1033,7 @@ function Sidebar:get_commands()
end,
clear = function(args, cb)
local chat_history = {}
save_chat_history(self, chat_history)
History.save(self.code.bufnr, chat_history)
self:update_content("Chat history cleared", { focus = false, scroll = false })
vim.defer_fn(function()
self:close()
Expand Down Expand Up @@ -1159,7 +1110,7 @@ function Sidebar:create_input()
self.input:unmount()
end

local chat_history = load_chat_history(self)
local chat_history = History.load(self.code.bufnr)

---@param request string
local function handle_submit(request)
Expand Down Expand Up @@ -1276,7 +1227,7 @@ function Sidebar:create_input()
request = request,
response = full_response,
})
save_chat_history(self, chat_history)
History.save(self.code.bufnr, chat_history)
end

Llm.stream(
Expand Down Expand Up @@ -1476,7 +1427,7 @@ function Sidebar:create_floating_window_for_split(opts)
end

function Sidebar:render()
local chat_history = load_chat_history(self)
local chat_history = History.load(self.code.bufnr)

local sidebar_height = api.nvim_win_get_height(self.code.winid)
local selected_code_size = self:get_selected_code_size()
Expand Down

0 comments on commit 3f906e5

Please sign in to comment.