Skip to content

Commit

Permalink
Fix recursive mutex lock while migrating conversation models (typesen…
Browse files Browse the repository at this point in the history
…se#1734)

* Fix recursive mutex lock while migrating conversation models

* Add `add_model_unsafe` and guard for `remove_conversation_collection`

* undo unrelated change

* undo unrelated changes
  • Loading branch information
ozanarmagan authored May 16, 2024
1 parent 52b9d5a commit 89bcd38
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 4 deletions.
2 changes: 2 additions & 0 deletions include/conversation_model_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,7 @@ class ConversationModelManager
static inline Store* store;
static const std::string get_model_key(const std::string& model_id);
static Option<Collection*> get_default_conversation_collection();
static Option<nlohmann::json> delete_model_unsafe(const std::string& model_id);
static Option<nlohmann::json> add_model_unsafe(nlohmann::json model, const std::string& model_id);

};
20 changes: 16 additions & 4 deletions src/conversation_model_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ Option<nlohmann::json> ConversationModelManager::get_model(const std::string& mo

Option<nlohmann::json> ConversationModelManager::add_model(nlohmann::json model, const std::string& model_id) {
std::unique_lock lock(models_mutex);

return add_model_unsafe(model, model_id);
}

Option<nlohmann::json> ConversationModelManager::add_model_unsafe(nlohmann::json model, const std::string& model_id) {
auto validate_res = ConversationModel::validate_model(model);
if (!validate_res.ok()) {
return Option<nlohmann::json>(validate_res.code(), validate_res.error());
Expand All @@ -37,6 +42,11 @@ Option<nlohmann::json> ConversationModelManager::add_model(nlohmann::json model,

Option<nlohmann::json> ConversationModelManager::delete_model(const std::string& model_id) {
std::unique_lock lock(models_mutex);

return delete_model_unsafe(model_id);
}

Option<nlohmann::json> ConversationModelManager::delete_model_unsafe(const std::string& model_id) {
auto it = models.find(model_id);
if (it == models.end()) {
return Option<nlohmann::json>(404, "Model not found");
Expand All @@ -47,7 +57,9 @@ Option<nlohmann::json> ConversationModelManager::delete_model(const std::string&
auto model_key = get_model_key(model_id);
bool delete_op = store->remove(model_key);

ConversationManager::get_instance().remove_conversation_collection(model["conversation_collection"]);
if(model.count("conversation_collection") != 0) {
ConversationManager::get_instance().remove_conversation_collection(model["conversation_collection"].get<std::string>());
}
models.erase(it);
return Option<nlohmann::json>(model);
}
Expand Down Expand Up @@ -105,7 +117,7 @@ Option<int> ConversationModelManager::init(Store* store) {
std::string model_id = model_json["id"];
models[model_id] = model_json;
if(model_json.count("conversation_collection") == 0) {
auto delete_op = delete_model(model_id);
auto delete_op = delete_model_unsafe(model_id);
if(!delete_op.ok()) {
return Option<int>(delete_op.code(), delete_op.error());
}
Expand All @@ -115,7 +127,7 @@ Option<int> ConversationModelManager::init(Store* store) {
}
model_json = migrate_op.get();
}
ConversationManager::get_instance().add_conversation_collection(model_json["conversation_collection"]);
ConversationManager::get_instance().add_conversation_collection(model_json["conversation_collection"].get<std::string>());
loaded_models++;
}

Expand Down Expand Up @@ -180,7 +192,7 @@ Option<nlohmann::json> ConversationModelManager::migrate_model(nlohmann::json mo
return Option<nlohmann::json>(default_collection.code(), default_collection.error());
}
model["conversation_collection"] = default_collection.get()->get_name();
auto add_res = add_model(model, model_id);
auto add_res = add_model_unsafe(model, model_id);
if(!add_res.ok()) {
return Option<nlohmann::json>(add_res.code(), add_res.error());
}
Expand Down

0 comments on commit 89bcd38

Please sign in to comment.