Skip to content

Commit

Permalink
Merge pull request #3010 from kuzudb/add-client-config
Browse files Browse the repository at this point in the history
Abstract client config
  • Loading branch information
andyfengHKU authored Mar 8, 2024
2 parents b7e3bc7 + b2f50ac commit c554a20
Show file tree
Hide file tree
Showing 11 changed files with 143 additions and 121 deletions.
14 changes: 7 additions & 7 deletions src/binder/bind/bind_graph_pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -454,20 +454,20 @@ std::pair<uint64_t, uint64_t> Binder::bindVariableLengthRelBound(
function::CastString::operation(
ku_string_t{recursiveInfo->lowerBound.c_str(), recursiveInfo->lowerBound.length()},
lowerBound);
auto upperBound = clientContext->varLengthExtendMaxDepth;
auto maxDepth = clientContext->getClientConfig()->varLengthMaxDepth;
auto upperBound = maxDepth;
if (!recursiveInfo->upperBound.empty()) {
function::CastString::operation(
ku_string_t{recursiveInfo->upperBound.c_str(), recursiveInfo->upperBound.length()},
upperBound);
}
if (lowerBound > upperBound) {
throw BinderException(
"Lower bound of rel " + relPattern.getVariableName() + " is greater than upperBound.");
throw BinderException(stringFormat(
"Lower bound of rel {} is greater than upperBound.", relPattern.getVariableName()));
}
if (upperBound > clientContext->varLengthExtendMaxDepth) {
throw BinderException(
"Upper bound of rel " + relPattern.getVariableName() +
" exceeds maximum: " + std::to_string(clientContext->varLengthExtendMaxDepth) + ".");
if (upperBound > maxDepth) {
throw BinderException(stringFormat("Upper bound of rel {} exceeds maximum: {}.",
relPattern.getVariableName(), std::to_string(maxDepth)));
}
if ((relPattern.getRelType() == QueryRelType::ALL_SHORTEST ||
relPattern.getRelType() == QueryRelType::SHORTEST) &&
Expand Down
2 changes: 1 addition & 1 deletion src/common/task_system/task_scheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ void TaskScheduler::scheduleTaskAndWaitOrError(
taskLck.unlock();
break;
}
if (context->clientContext->isTimeOutEnabled()) {
if (context->clientContext->hasTimeout()) {
timeout = context->clientContext->getTimeoutRemainingInMS();
if (timeout == 0) {
context->clientContext->interrupt();
Expand Down
6 changes: 4 additions & 2 deletions src/include/common/constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,11 @@ struct PlannerKnobs {
static constexpr uint64_t SIP_RATIO = 5;
};

struct ClientContextConstants {
// We disable query timeout by default.
struct ClientConfigDefault {
// 0 means timeout is disabled by default.
static constexpr uint64_t TIMEOUT_IN_MS = 0;
static constexpr uint32_t VAR_LENGTH_MAX_DEPTH = 30;
static constexpr bool ENABLE_SEMI_MASK = true;
};

struct OrderByConstants {
Expand Down
4 changes: 2 additions & 2 deletions src/include/common/timer.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@ class Timer {
finished = true;
}

double getDuration() {
double getDuration() const {
if (finished) {
auto duration = stopTime - startTime;
return (double)std::chrono::duration_cast<std::chrono::microseconds>(duration).count();
}
throw Exception("Timer is still running.");
}

uint64_t getElapsedTimeInMS() {
uint64_t getElapsedTimeInMS() const {
auto now = std::chrono::high_resolution_clock::now();
auto duration = now - startTime;
auto count = std::chrono::duration_cast<std::chrono::milliseconds>(duration).count();
Expand Down
24 changes: 24 additions & 0 deletions src/include/main/client_config.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#pragma once

#include <string>

namespace kuzu {
namespace main {

struct ClientConfig {
// System home directory.
std::string homeDirectory;
// File search path.
std::string fileSearchPath;
// If using semi mask in join.
bool enableSemiMask;
// Number of threads for execution.
uint64_t numThreads;
// Timeout (milliseconds)
uint64_t timeoutInMS;
// variable length maximum depth
uint32_t varLengthMaxDepth;
};

} // namespace main
} // namespace kuzu
85 changes: 35 additions & 50 deletions src/include/main/client_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <memory>
#include <mutex>

#include "client_config.h"
#include "common/timer.h"
#include "common/types/value/value.h"
#include "function/scalar_function.h"
Expand Down Expand Up @@ -47,77 +48,59 @@ class ClientContext {
friend class Connection;
friend class binder::Binder;
friend class binder::ExpressionBinder;
friend class testing::TinySnbDDLTest;
friend class testing::TinySnbCopyCSVTransactionTest;
friend struct ThreadsSetting;
friend struct TimeoutSetting;
friend struct VarLengthExtendMaxDepthSetting;
friend struct EnableSemiMaskSetting;
friend struct HomeDirectorySetting;
friend struct FileSearchPathSetting;

public:
explicit ClientContext(Database* database);

inline void interrupt() { activeQuery.interrupted = true; }

bool isInterrupted() const { return activeQuery.interrupted; }

inline bool isTimeOutEnabled() const { return timeoutInMS != 0; }

inline uint64_t getTimeoutRemainingInMS() {
KU_ASSERT(isTimeOutEnabled());
auto elapsed = activeQuery.timer.getElapsedTimeInMS();
return elapsed >= timeoutInMS ? 0 : timeoutInMS - elapsed;
}

inline bool isEnableSemiMask() const { return enableSemiMask; }

void startTimingIfEnabled();

// Client config
const ClientConfig* getClientConfig() const { return &config; }
ClientConfig* getClientConfigUnsafe() { return &config; }
KUZU_API common::Value getCurrentSetting(const std::string& optionName);
// Timer and timeout
void interrupt() { activeQuery.interrupted = true; }
bool interrupted() const { return activeQuery.interrupted; }
bool hasTimeout() const { return config.timeoutInMS != 0; }
void setQueryTimeOut(uint64_t timeoutInMS);
uint64_t getQueryTimeOut() const;
void startTimer();
uint64_t getTimeoutRemainingInMS() const;
void resetActiveQuery() { activeQuery.reset(); }

// Parallelism
void setMaxNumThreadForExec(uint64_t numThreads);
uint64_t getMaxNumThreadForExec() const;

// Transaction.
transaction::Transaction* getTx() const;
KUZU_API transaction::TransactionContext* getTransactionContext() const;

// Replace function.
inline bool hasReplaceFunc() { return replaceFunc != nullptr; }
inline void setReplaceFunc(replace_func_t func) { replaceFunc = func; }

// Extension
KUZU_API void setExtensionOption(std::string name, common::Value value);

common::RandomEngine* getRandomEngine() { return randomEngine.get(); }

common::VirtualFileSystem* getVFSUnsafe() const;

std::string getExtensionDir() const;

// Environment.
KUZU_API std::string getEnvVariable(const std::string& name);

// Database component getters.
KUZU_API Database* getDatabase() const { return database; }
storage::StorageManager* getStorageManager();
storage::MemoryManager* getMemoryManager();
catalog::Catalog* getCatalog();
common::VirtualFileSystem* getVFSUnsafe() const;
common::RandomEngine* getRandomEngine();

KUZU_API std::string getEnvVariable(const std::string& name);

// Query.
std::unique_ptr<PreparedStatement> prepare(std::string_view query);

void setQueryTimeOut(uint64_t timeoutInMS);

uint64_t getQueryTimeOut();

void setMaxNumThreadForExec(uint64_t numThreads);

uint64_t getMaxNumThreadForExec();

KUZU_API std::unique_ptr<QueryResult> executeWithParams(PreparedStatement* preparedStatement,
std::unordered_map<std::string, std::unique_ptr<common::Value>> inputParams);

std::unique_ptr<QueryResult> query(std::string_view queryStatement);

void runQuery(std::string query);

private:
inline void resetActiveQuery() { activeQuery.reset(); }

std::unique_ptr<QueryResult> query(
std::string_view query, std::string_view encodedJoin, bool enumerateAllPlans = true);

Expand Down Expand Up @@ -152,17 +135,19 @@ class ClientContext {

void commitUDFTrx(bool isAutoCommitTrx);

uint64_t numThreadsForExecution;
// Client side configurable settings.
ClientConfig config;
// Current query.
ActiveQuery activeQuery;
uint64_t timeoutInMS;
uint32_t varLengthExtendMaxDepth;
// Transaction context.
std::unique_ptr<transaction::TransactionContext> transactionContext;
bool enableSemiMask;
// Replace external object as pointer Value;
replace_func_t replaceFunc;
// Extension configurable settings.
std::unordered_map<std::string, common::Value> extensionOptionValues;
// Random generator for UUID.
std::unique_ptr<common::RandomEngine> randomEngine;
std::string homeDirectory;
std::string fileSearchPath;
// Attached database.
Database* database;
std::mutex mtx;
};
Expand Down
25 changes: 12 additions & 13 deletions src/include/main/settings.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ struct ThreadsSetting {
static constexpr const common::LogicalTypeID inputType = common::LogicalTypeID::INT64;
static void setContext(ClientContext* context, const common::Value& parameter) {
KU_ASSERT(parameter.getDataType()->getLogicalTypeID() == common::LogicalTypeID::INT64);
context->numThreadsForExecution = parameter.getValue<int64_t>();
context->getClientConfigUnsafe()->numThreads = parameter.getValue<int64_t>();
}
static common::Value getSetting(ClientContext* context) {
return common::Value(context->numThreadsForExecution);
return common::Value(context->getClientConfig()->numThreads);
}
};

Expand All @@ -23,11 +23,10 @@ struct TimeoutSetting {
static constexpr const common::LogicalTypeID inputType = common::LogicalTypeID::INT64;
static void setContext(ClientContext* context, const common::Value& parameter) {
KU_ASSERT(parameter.getDataType()->getLogicalTypeID() == common::LogicalTypeID::INT64);
context->timeoutInMS = parameter.getValue<int64_t>();
context->startTimingIfEnabled();
context->getClientConfigUnsafe()->timeoutInMS = parameter.getValue<int64_t>();
}
static common::Value getSetting(ClientContext* context) {
return common::Value(context->timeoutInMS);
return common::Value(context->getClientConfig()->timeoutInMS);
}
};

Expand All @@ -36,10 +35,10 @@ struct VarLengthExtendMaxDepthSetting {
static constexpr const common::LogicalTypeID inputType = common::LogicalTypeID::INT64;
static void setContext(ClientContext* context, const common::Value& parameter) {
KU_ASSERT(parameter.getDataType()->getLogicalTypeID() == common::LogicalTypeID::INT64);
context->varLengthExtendMaxDepth = parameter.getValue<int64_t>();
context->getClientConfigUnsafe()->varLengthMaxDepth = parameter.getValue<int64_t>();
}
static common::Value getSetting(ClientContext* context) {
return common::Value(context->varLengthExtendMaxDepth);
return common::Value(context->getClientConfig()->varLengthMaxDepth);
}
};

Expand All @@ -48,10 +47,10 @@ struct EnableSemiMaskSetting {
static constexpr const common::LogicalTypeID inputType = common::LogicalTypeID::BOOL;
static void setContext(ClientContext* context, const common::Value& parameter) {
KU_ASSERT(parameter.getDataType()->getLogicalTypeID() == common::LogicalTypeID::BOOL);
context->enableSemiMask = parameter.getValue<bool>();
context->getClientConfigUnsafe()->enableSemiMask = parameter.getValue<bool>();
}
static common::Value getSetting(ClientContext* context) {
return common::Value(context->enableSemiMask);
return common::Value(context->getClientConfig()->enableSemiMask);
}
};

Expand All @@ -60,10 +59,10 @@ struct HomeDirectorySetting {
static constexpr const common::LogicalTypeID inputType = common::LogicalTypeID::STRING;
static void setContext(ClientContext* context, const common::Value& parameter) {
KU_ASSERT(parameter.getDataType()->getLogicalTypeID() == common::LogicalTypeID::STRING);
context->homeDirectory = parameter.getValue<std::string>();
context->getClientConfigUnsafe()->homeDirectory = parameter.getValue<std::string>();
}
static common::Value getSetting(ClientContext* context) {
return common::Value::createValue(context->homeDirectory);
return common::Value::createValue(context->getClientConfig()->homeDirectory);
}
};

Expand All @@ -72,10 +71,10 @@ struct FileSearchPathSetting {
static constexpr const common::LogicalTypeID inputType = common::LogicalTypeID::STRING;
static void setContext(ClientContext* context, const common::Value& parameter) {
KU_ASSERT(parameter.getDataType()->getLogicalTypeID() == common::LogicalTypeID::STRING);
context->fileSearchPath = parameter.getValue<std::string>();
context->getClientConfigUnsafe()->fileSearchPath = parameter.getValue<std::string>();
}
static common::Value getSetting(ClientContext* context) {
return common::Value::createValue(context->fileSearchPath);
return common::Value::createValue(context->getClientConfig()->fileSearchPath);
}
};

Expand Down
11 changes: 1 addition & 10 deletions src/include/processor/operator/physical_operator.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#pragma once

#include "common/exception/interrupt.h"
#include "processor/execution_context.h"
#include "processor/result/result_set.h"

Expand Down Expand Up @@ -130,15 +129,7 @@ class PhysicalOperator {
// Local state is initialized for each thread.
void initLocalState(ResultSet* resultSet, ExecutionContext* context);

inline bool getNextTuple(ExecutionContext* context) {
if (context->clientContext->isInterrupted()) {
throw common::InterruptException{};
}
metrics->executionTime.start();
auto result = getNextTuplesInternal(context);
metrics->executionTime.stop();
return result;
}
bool getNextTuple(ExecutionContext* context);

std::unordered_map<std::string, std::string> getProfilerKeyValAttributes(
common::Profiler& profiler) const;
Expand Down
Loading

0 comments on commit c554a20

Please sign in to comment.