Skip to content

Commit

Permalink
Added onDisconnected callback called after calling db:disconnect.
Browse files Browse the repository at this point in the history
  • Loading branch information
FredyH committed Sep 9, 2024
1 parent bf3ff37 commit 95f9750
Show file tree
Hide file tree
Showing 7 changed files with 52 additions and 7 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@
*.zip
cmake-build-debug
.idea
.vs
.vs
.cache
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ Database:connect()
Database:disconnect(shouldWait)
-- Returns nothing
-- disconnects from the database and waits for all queries to finish if shouldWait is true
-- This function calls the onDisconnected callback if it existed on the database before the database was connected.

Database:query( sql )
-- Returns [Query]
Expand Down Expand Up @@ -149,6 +150,11 @@ Database.onConnected( db )
Database.onConnectionFailed( db, err )
-- Called when the connection to the MySQL server fails, [String] err is why.

Database.onDisconnected( db )
-- Called after Database.disconnect has been called and all queries have finished executing
-- Note: You have to set this callback before calling Database:connect() or it will not be called.


-- Query/PreparedQuery object (transactions also inherit all functions, some have no effect though)

-- Functions
Expand Down
32 changes: 28 additions & 4 deletions src/lua/LuaDatabase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,12 @@ MYSQLOO_LUA_FUNCTION(connect) {
LUA->Push(1);
database->m_tableReference = LuaReferenceCreate(LUA);
}

LUA->ReferencePush(database->m_tableReference);
LUA->GetField(-1, "onDisconnected");
database->m_hasOnDisconnected = LUA->IsType(-1, GarrysMod::Lua::Type::Function);
LUA->Pop(2); // callback, table

database->m_database->connect();
return 0;
}
Expand Down Expand Up @@ -351,7 +357,6 @@ void LuaDatabase::think(ILuaBase *LUA) {
LUA->ReferencePush(this->m_tableReference);
pcallWithErrorReporter(LUA, 1);
}
LUA->Pop(); //Callback function
} else {
LUA->GetField(-1, "onConnectionFailed");
if (LUA->GetType(-1) == GarrysMod::Lua::Type::Function) {
Expand All @@ -360,18 +365,37 @@ void LuaDatabase::think(ILuaBase *LUA) {
LUA->PushString(error.c_str());
pcallWithErrorReporter(LUA, 2);
}
LUA->Pop(); //Callback function
}
LUA->Pop(); // DB Table

LuaReferenceFree(LUA, this->m_tableReference);
this->m_tableReference = 0;
if (!this->m_hasOnDisconnected) {
// Only free the table reference if we do not have an onDisconnected callback.
// Otherwise, it will be freed after the onDisconnected callback was called.
LuaReferenceFree(LUA, this->m_tableReference);
this->m_tableReference = 0;
}
}

//Run callbacks of finished queries
auto finishedQueries = database->takeFinishedQueries();
for (auto &pair: finishedQueries) {
LuaQuery::runCallback(LUA, pair.first, pair.second);
}

if (database->wasDisconnected() && this->m_hasOnDisconnected && this->m_tableReference != 0) {
this->m_hasOnDisconnected = false;

LUA->ReferencePush(this->m_tableReference);

LUA->GetField(-1, "onDisconnected");
if (LUA->GetType(-1) == GarrysMod::Lua::Type::Function) {
LUA->ReferencePush(this->m_tableReference);
pcallWithErrorReporter(LUA, 1);
}
LUA->Pop(1); // DB Table

LuaReferenceFree(LUA, this->m_tableReference);
}
}

void LuaDatabase::onDestroyedByLua(ILuaBase *LUA) {
Expand Down
1 change: 1 addition & 0 deletions src/lua/LuaDatabase.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class LuaDatabase : public LuaObject {
void think(ILuaBase *LUA);

int m_tableReference = 0;
bool m_hasOnDisconnected = false;
std::shared_ptr<Database> m_database;
bool m_dbCallbackRan = false;

Expand Down
5 changes: 5 additions & 0 deletions src/lua/LuaObject.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,11 @@ LUA_FUNCTION(errorReporter) {
return 1;
}

/**
* Similar to LUA->PCall but also uses an error reporter and prints the
* error to the console using ErrorNoHalt (if it exists).
* Consumes the function and all nargs arguments on the stack, does not return any values.
*/
void LuaObject::pcallWithErrorReporter(ILuaBase *LUA, int nargs) {
LUA->PushCFunction(errorReporter);
int errorHandlerIndex = LUA->Top() - nargs - 1;
Expand Down
9 changes: 8 additions & 1 deletion src/mysql/Database.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,13 @@ void Database::disconnect(bool wait) {
if (wait && m_thread.joinable()) {
m_thread.join();
}
disconnected = true;
}

/*
* Returns true after the database has been fully disconnected and no more queries are in the queue.
*/
bool Database::wasDisconnected() {
return disconnected;
}

/* Returns the status of the database, constants can be found in GMModule
Expand Down Expand Up @@ -361,6 +367,7 @@ void Database::connectRun() {
if (m_status == DATABASE_CONNECTED) {
m_status = DATABASE_NOT_CONNECTED;
}
disconnected = true;
});
{
auto connectionSignaler = finally([&] { m_connectWakeupVariable.notify_one(); });
Expand Down
3 changes: 2 additions & 1 deletion src/mysql/Database.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ class Database : public std::enable_shared_from_this<Database> {
return finishedQueries.clear();
}

bool wasDisconnected();
private:
Database(std::string host, std::string username, std::string pw, std::string database, unsigned int port,
std::string unixSocket);
Expand Down Expand Up @@ -158,10 +159,10 @@ class Database : public std::enable_shared_from_this<Database> {
bool shouldAutoReconnect = true;
bool useMultiStatements = true;
bool startedConnecting = false;
bool disconnected = false;
bool m_canWait = false;
std::pair<std::shared_ptr<IQuery>, std::shared_ptr<IQueryData>> m_waitingQuery = {nullptr, nullptr};
std::atomic<bool> m_success{true};
std::atomic<bool> disconnected { false };
std::atomic<bool> m_connectionDone{false};
std::atomic<bool> cachePreparedStatements{true};
std::condition_variable m_queryWakeupVariable{};
Expand Down

0 comments on commit 95f9750

Please sign in to comment.