Skip to content

Commit

Permalink
Fix conversation and conversation model (typesense#1869)
Browse files Browse the repository at this point in the history
* Fix conversation models and conversations

* Remove unnecessary op in the test

* Fix persistence of conversations

* Fix conversaton swap

* Move TTL and conversation history logic to conversation models

* Fix conversation tests

* Fix tests
  • Loading branch information
ozanarmagan authored Aug 13, 2024
1 parent fea777f commit 0395857
Show file tree
Hide file tree
Showing 12 changed files with 227 additions and 513 deletions.
14 changes: 5 additions & 9 deletions include/conversation_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,10 @@
#include "store.h"
#include "sole.hpp"
#include "collection_manager.h"

#include "conversation_model_manager.h"

class ConversationManager {
public:

ConversationManager(const ConversationManager&) = delete;
ConversationManager(ConversationManager&&) = delete;
ConversationManager& operator=(const ConversationManager&) = delete;
Expand All @@ -23,14 +22,12 @@ class ConversationManager {
static ConversationManager instance;
return instance;
}
Option<std::string> add_conversation(const nlohmann::json& conversation, const std::string& history_collection, const std::string& id = "");
Option<std::string> add_conversation(const nlohmann::json& conversation, const nlohmann::json& model, const std::string& id = "");
Option<nlohmann::json> get_conversation(const std::string& conversation_id);
static Option<nlohmann::json> truncate_conversation(nlohmann::json conversation, size_t limit);
Option<nlohmann::json> update_conversation(nlohmann::json conversation);
Option<nlohmann::json> delete_conversation(const std::string& conversation_id);
Option<bool> check_conversation_exists(const std::string& conversation_id);
Option<std::unordered_set<std::string>> get_conversation_ids();
Option<nlohmann::json> get_all_conversations();
static constexpr size_t MAX_TOKENS = 3000;
Option<bool> init(ReplicationState* raft_server);
void clear_expired_conversations();
Expand All @@ -45,17 +42,16 @@ class ConversationManager {
Option<bool> add_history_collection(const std::string& collection);
Option<bool> remove_history_collection(const std::string& collection);
Option<Collection*> get_history_collection(const std::string& conversation_id);
Option<bool> initialize_history_collection(const std::string& collection);
private:
ConversationManager() {}
std::mutex conversations_mutex;

ReplicationState* raft_server;
static constexpr size_t CONVERSATION_TTL = 60 * 60 * 24;
size_t TTL_OFFSET = 0;
size_t MAX_CONVERSATIONS_TO_DELETE_ONCE = 5;

std::atomic<bool> quit = false;
std::condition_variable cv;
std::unordered_map<std::string, uint32_t> history_collection_map;
std::unordered_map<std::string, std::string> conversation_mapper;

Option<nlohmann::json> delete_conversation_unsafe(const std::string& conversation_id);
};
6 changes: 6 additions & 0 deletions include/conversation_model_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ class ConversationModelManager
static Option<nlohmann::json> get_all_models();
static Option<int> init(Store* store);
static Option<nlohmann::json> migrate_model(nlohmann::json model);
static std::unordered_set<std::string> get_history_collections();
// For testing Purpose only
static void insert_model_for_testing(const std::string& model_id, nlohmann::json model) {
std::unique_lock lock(models_mutex);
models[model_id] = model;
}
private:
static inline std::unordered_map<std::string, nlohmann::json> models;
static inline std::shared_mutex models_mutex;
Expand Down
7 changes: 0 additions & 7 deletions include/core_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,13 +200,6 @@ void init_api(uint32_t cache_num_entries);
bool post_proxy(const std::shared_ptr<http_req>& req, const std::shared_ptr<http_res>& res);


bool get_conversations(const std::shared_ptr<http_req>& req, const std::shared_ptr<http_res>& res);

bool get_conversation(const std::shared_ptr<http_req>& req, const std::shared_ptr<http_res>& res);

bool del_conversation(const std::shared_ptr<http_req>& req, const std::shared_ptr<http_res>& res);

bool put_conversation(const std::shared_ptr<http_req>& req, const std::shared_ptr<http_res>& res);


bool post_conversation_model(const std::shared_ptr<http_req>& req, const std::shared_ptr<http_res>& res);
Expand Down
4 changes: 1 addition & 3 deletions src/collection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,6 @@ Collection::~Collection() {
VQModelManager::get_instance().delete_model(vq_model->get_model_name());
}
}

ConversationManager::get_instance().remove_history_collection(name);
}

uint32_t Collection::get_next_seq_id() {
Expand Down Expand Up @@ -2920,7 +2918,7 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
conversation_history.push_back(formatted_question_op.get());
conversation_history.push_back(formatted_answer_op.get());

auto add_conversation_op = ConversationManager::get_instance().add_conversation(conversation_history, conversation_model["history_collection"].get<std::string>(), conversation_id);
auto add_conversation_op = ConversationManager::get_instance().add_conversation(conversation_history, conversation_model, conversation_id);
if(!add_conversation_op.ok()) {
return Option<nlohmann::json>(add_conversation_op.code(), add_conversation_op.error());
}
Expand Down
Loading

0 comments on commit 0395857

Please sign in to comment.