Skip to content

Commit

Permalink
Added progress for in_query_call operators (kuzudb#3120)
Browse files Browse the repository at this point in the history
I have read and agree to the CLA of the Kuzu repository.
  • Loading branch information
MSebanc authored Mar 28, 2024
1 parent 37b58bb commit 20e5cbb
Show file tree
Hide file tree
Showing 16 changed files with 169 additions and 45 deletions.
10 changes: 8 additions & 2 deletions src/common/task_system/progress_bar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,23 @@ void ProgressBar::finishPipeline() {
return;
}
numPipelinesFinished++;
if (printing) {
std::cout << "\033[1A\033[2K\033[1B";
}
// This ensures that the progress bar is updated back to 0% after a pipeline is finished.
prevCurPipelineProgress = -0.01;
updateProgress(0.0);
}

void ProgressBar::updateProgress(double curPipelineProgress) {
// Only update the progress bar if the progress has changed by at least 1%.
if (!trackProgress || curPipelineProgress - prevCurPipelineProgress < 0.01) {
if (!trackProgress) {
return;
}
std::lock_guard<std::mutex> lock(progressBarLock);
// Only update the progress bar if the progress has changed by at least 1%.
if (curPipelineProgress - prevCurPipelineProgress < 0.01) {
return;
}
prevCurPipelineProgress = curPipelineProgress;
if (printing) {
std::cout << "\033[2A";
Expand Down
3 changes: 2 additions & 1 deletion src/include/function/table/scan_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,11 @@ struct ScanSharedState : public BaseScanSharedState {

struct ScanFileSharedState : public ScanSharedState {
main::ClientContext* context;
uint64_t totalSize;

ScanFileSharedState(
common::ReaderConfig readerConfig, uint64_t numRows, main::ClientContext* context)
: ScanSharedState{std::move(readerConfig), numRows}, context{context} {}
: ScanSharedState{std::move(readerConfig), numRows}, context{context}, totalSize{0} {}
};

} // namespace function
Expand Down
8 changes: 8 additions & 0 deletions src/include/function/table_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,15 @@ using table_func_init_shared_t =
using table_func_init_local_t = std::function<std::unique_ptr<TableFuncLocalState>(
TableFunctionInitInput&, TableFuncSharedState*, storage::MemoryManager*)>;
using table_func_can_parallel_t = std::function<bool()>;
using table_func_progress_t = std::function<double(TableFuncSharedState* sharedState)>;

struct TableFunction final : public Function {
table_func_t tableFunc;
table_func_bind_t bindFunc;
table_func_init_shared_t initSharedStateFunc;
table_func_init_local_t initLocalStateFunc;
table_func_can_parallel_t canParallelFunc = [] { return true; };
table_func_progress_t progressFunc = [](TableFuncSharedState* /*sharedState*/) { return 0.0; };

TableFunction()
: Function{}, tableFunc{nullptr}, bindFunc{nullptr}, initSharedStateFunc{nullptr},
Expand All @@ -80,6 +82,12 @@ struct TableFunction final : public Function {
: Function{FunctionType::TABLE, std::move(name), std::move(inputTypes)},
tableFunc{tableFunc}, bindFunc{bindFunc}, initSharedStateFunc{initSharedFunc},
initLocalStateFunc{initLocalFunc} {}
TableFunction(std::string name, table_func_t tableFunc, table_func_bind_t bindFunc,
table_func_init_shared_t initSharedFunc, table_func_init_local_t initLocalFunc,
table_func_progress_t progressFunc, std::vector<common::LogicalTypeID> inputTypes)
: Function{FunctionType::TABLE, std::move(name), std::move(inputTypes)},
tableFunc{tableFunc}, bindFunc{bindFunc}, initSharedStateFunc{initSharedFunc},
initLocalStateFunc{initLocalFunc}, progressFunc{progressFunc} {}

inline std::string signatureToString() const override {
return common::LogicalTypeUtils::toString(parameterTypeIDs);
Expand Down
2 changes: 2 additions & 0 deletions src/include/processor/operator/call/in_query_call.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ class InQueryCall : public PhysicalOperator {

bool getNextTuplesInternal(ExecutionContext* context) override;

double getProgress(ExecutionContext* context) const override;

std::unique_ptr<PhysicalOperator> clone() override {
return std::make_unique<InQueryCall>(info.copy(), sharedState, id, paramsString);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ class BaseCSVReader {

uint64_t countRows();
bool isEOF() const;
uint64_t getFileSize();
// Get the file offset of the current buffer position.
uint64_t getFileOffset() const;

protected:
template<typename Driver>
Expand Down Expand Up @@ -56,8 +59,6 @@ class BaseCSVReader {

inline bool isNewLine(char c) { return c == '\n' || c == '\r'; }

// Get the file offset of the current buffer position.
uint64_t getFileOffset() const;
uint64_t getLineNumber();

protected:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,12 @@ struct ParallelCSVScanSharedState final : public function::ScanFileSharedState {
explicit ParallelCSVScanSharedState(common::ReaderConfig readerConfig, uint64_t numRows,
uint64_t numColumns, main::ClientContext* context, common::CSVReaderConfig csvReaderConfig)
: ScanFileSharedState{std::move(readerConfig), numRows, context}, numColumns{numColumns},
csvReaderConfig{std::move(csvReaderConfig)} {}
numBlocksReadByFiles{0}, csvReaderConfig{std::move(csvReaderConfig)} {}

void setFileComplete(uint64_t completedFileIdx);

uint64_t numColumns;
uint64_t numBlocksReadByFiles = 0;
common::CSVReaderConfig csvReaderConfig;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,13 @@ class SerialCSVReader final : public BaseCSVReader {
struct SerialCSVScanSharedState final : public function::ScanFileSharedState {
std::unique_ptr<SerialCSVReader> reader;
uint64_t numColumns;
uint64_t totalReadSizeByFile;
common::CSVReaderConfig csvReaderConfig;

SerialCSVScanSharedState(common::ReaderConfig readerConfig, uint64_t numRows,
uint64_t numColumns, common::CSVReaderConfig csvReaderConfig, main::ClientContext* context)
: ScanFileSharedState{std::move(readerConfig), numRows, context}, numColumns{numColumns},
csvReaderConfig{std::move(csvReaderConfig)} {
totalReadSizeByFile{0}, csvReaderConfig{std::move(csvReaderConfig)} {
initReader(context);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ struct ParquetScanSharedState final : public function::ScanFileSharedState {
const common::ReaderConfig readerConfig, uint64_t numRows, main::ClientContext* context);

std::vector<std::unique_ptr<ParquetReader>> readers;
uint64_t totalRowsGroups;
uint64_t numBlocksReadByFiles;
};

struct ParquetScanLocalState final : public function::TableFuncLocalState {
Expand Down
4 changes: 4 additions & 0 deletions src/processor/operator/call/in_query_call.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,5 +65,9 @@ bool InQueryCall::getNextTuplesInternal(ExecutionContext*) {
return numTuplesScanned != 0;
}

// Reports this operator's progress by passing the table function's shared state to the
// function's registered progress callback and returning whatever the callback computes.
// The execution context argument is not used.
double InQueryCall::getProgress(ExecutionContext* /*context*/) const {
    auto* funcSharedState = sharedState->funcState.get();
    return info.function.progressFunc(funcSharedState);
}

} // namespace processor
} // namespace kuzu
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ bool BaseCSVReader::isEOF() const {
return getFileOffset() >= fileInfo->getFileSize();
}

// Returns the size of the file behind this reader, as reported by fileInfo.
uint64_t BaseCSVReader::getFileSize() {
    const auto fileSize = fileInfo->getFileSize();
    return fileSize;
}

template<typename Driver>
void BaseCSVReader::addValue(Driver& driver, uint64_t rowNum, column_id_t columnIdx,
std::string_view strVal, std::vector<uint64_t>& escapePositions) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,15 +109,18 @@ bool ParallelCSVReader::finishedBlock() const {
// Marks the file at completedFileIdx as finished. While holding the shared-state lock: if that
// file is still the current one, its block count is added to numBlocksReadByFiles, the block
// counter is reset, and fileIdx moves on to the next file. A call whose index no longer matches
// the current file changes nothing.
void ParallelCSVScanSharedState::setFileComplete(uint64_t completedFileIdx) {
    std::lock_guard<std::mutex> guard{lock};
    if (completedFileIdx != fileIdx) {
        return;
    }
    numBlocksReadByFiles += blockIdx;
    blockIdx = 0;
    ++fileIdx;
}

static offset_t tableFunc(TableFuncInput& input, TableFuncOutput& output) {
auto& outputChunk = output.dataChunk;
auto parallelCSVLocalState = reinterpret_cast<ParallelCSVLocalState*>(input.localState);
auto parallelCSVSharedState = reinterpret_cast<ParallelCSVScanSharedState*>(input.sharedState);
auto parallelCSVLocalState =
ku_dynamic_cast<TableFuncLocalState*, ParallelCSVLocalState*>(input.localState);
auto parallelCSVSharedState =
ku_dynamic_cast<TableFuncSharedState*, ParallelCSVScanSharedState*>(input.sharedState);
do {
if (parallelCSVLocalState->reader != nullptr &&
parallelCSVLocalState->reader->hasMoreToRead()) {
Expand Down Expand Up @@ -152,7 +155,7 @@ static offset_t tableFunc(TableFuncInput& input, TableFuncOutput& output) {

static std::unique_ptr<TableFuncBindData> bindFunc(
main::ClientContext* /*context*/, TableFuncBindInput* input) {
auto scanInput = reinterpret_cast<ScanTableFuncBindInput*>(input);
auto scanInput = ku_dynamic_cast<TableFuncBindInput*, ScanTableFuncBindInput*>(input);
std::vector<std::string> detectedColumnNames;
std::vector<LogicalType> detectedColumnTypes;
SerialCSVScan::bindColumns(scanInput, detectedColumnNames, detectedColumnTypes);
Expand All @@ -165,28 +168,48 @@ static std::unique_ptr<TableFuncBindData> bindFunc(
}

static std::unique_ptr<TableFuncSharedState> initSharedState(TableFunctionInitInput& input) {
auto bindData = reinterpret_cast<ScanBindData*>(input.bindData);
auto bindData = ku_dynamic_cast<TableFuncBindData*, ScanBindData*>(input.bindData);
auto csvConfig = CSVReaderConfig::construct(bindData->config.options);
row_idx_t numRows = 0;
return std::make_unique<ParallelCSVScanSharedState>(bindData->config.copy(), numRows,
bindData->columnNames.size(), bindData->context, csvConfig.copy());
auto sharedState = std::make_unique<ParallelCSVScanSharedState>(bindData->config.copy(),
numRows, bindData->columnNames.size(), bindData->context, csvConfig.copy());
for (auto filePath : sharedState->readerConfig.filePaths) {
auto reader = std::make_unique<ParallelCSVReader>(filePath,
sharedState->csvReaderConfig.option.copy(), sharedState->numColumns,
sharedState->context);
sharedState->totalSize += reader->getFileSize();
}
return sharedState;
}

static std::unique_ptr<TableFuncLocalState> initLocalState(TableFunctionInitInput& /*input*/,
TableFuncSharedState* state, storage::MemoryManager* /*mm*/) {
auto localState = std::make_unique<ParallelCSVLocalState>();
auto sharedState = reinterpret_cast<ParallelCSVScanSharedState*>(state);
auto sharedState = ku_dynamic_cast<TableFuncSharedState*, ParallelCSVScanSharedState*>(state);
localState->reader = std::make_unique<ParallelCSVReader>(sharedState->readerConfig.filePaths[0],
sharedState->csvReaderConfig.option.copy(), sharedState->numColumns, sharedState->context);
localState->fileIdx = 0;
return localState;
}

// Estimates the progress of a parallel CSV scan as a fraction in [0.0, 1.0].
//
// Returns 1.0 once fileIdx has moved past the last configured file, and 0.0 when the recorded
// total size is zero. Otherwise the amount read is approximated as the number of blocks counted
// so far (blocks of already completed files plus the current file's block index) multiplied by
// CopyConstants::PARALLEL_BLOCK_SIZE, divided by totalSize.
//
// NOTE(review): fileIdx, blockIdx and numBlocksReadByFiles are read here without taking the
// lock that setFileComplete holds while updating them -- confirm a slightly stale value is
// acceptable for progress reporting.
static double progressFunc(TableFuncSharedState* sharedState) {
    auto state = ku_dynamic_cast<TableFuncSharedState*, ParallelCSVScanSharedState*>(sharedState);
    if (state->fileIdx >= state->readerConfig.getNumFiles()) {
        return 1.0;
    }
    if (state->totalSize == 0) {
        return 0.0;
    }
    uint64_t totalReadSize =
        (state->numBlocksReadByFiles + state->blockIdx) * CopyConstants::PARALLEL_BLOCK_SIZE;
    // Blocks are counted whole, so the estimate can overshoot the real byte total (e.g. a file
    // smaller than one block). Cap the result so callers never see more than 100%.
    if (totalReadSize >= state->totalSize) {
        return 1.0;
    }
    return static_cast<double>(totalReadSize) / state->totalSize;
}

function_set ParallelCSVScan::getFunctionSet() {
function_set functionSet;
functionSet.push_back(
std::make_unique<TableFunction>(READ_CSV_PARALLEL_FUNC_NAME, tableFunc, bindFunc,
initSharedState, initLocalState, std::vector<LogicalTypeID>{LogicalTypeID::STRING}));
functionSet.push_back(std::make_unique<TableFunction>(READ_CSV_PARALLEL_FUNC_NAME, tableFunc,
bindFunc, initSharedState, initLocalState, progressFunc,
std::vector<LogicalTypeID>{LogicalTypeID::STRING}));
return functionSet;
}

Expand Down
34 changes: 27 additions & 7 deletions src/processor/operator/persistent/reader/csv/serial_csv_reader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ void SerialCSVScanSharedState::read(DataChunk& outputChunk) {
if (numRows > 0) {
return;
}
totalReadSizeByFile += reader->getFileSize();
fileIdx++;
initReader(context);
} while (true);
Expand All @@ -68,7 +69,8 @@ void SerialCSVScanSharedState::initReader(main::ClientContext* context) {
}

static common::offset_t tableFunc(TableFuncInput& input, TableFuncOutput& output) {
auto serialCSVScanSharedState = reinterpret_cast<SerialCSVScanSharedState*>(input.sharedState);
auto serialCSVScanSharedState =
ku_dynamic_cast<TableFuncSharedState*, SerialCSVScanSharedState*>(input.sharedState);
serialCSVScanSharedState->read(output.dataChunk);
return output.dataChunk.state->selVector->selectedSize;
}
Expand Down Expand Up @@ -99,7 +101,7 @@ void SerialCSVScan::bindColumns(const ScanTableFuncBindInput* bindInput,

static std::unique_ptr<TableFuncBindData> bindFunc(
main::ClientContext* /*context*/, TableFuncBindInput* input) {
auto scanInput = reinterpret_cast<ScanTableFuncBindInput*>(input);
auto scanInput = ku_dynamic_cast<TableFuncBindInput*, ScanTableFuncBindInput*>(input);
std::vector<std::string> detectedColumnNames;
std::vector<LogicalType> detectedColumnTypes;
SerialCSVScan::bindColumns(scanInput, detectedColumnNames, detectedColumnTypes);
Expand All @@ -112,23 +114,41 @@ static std::unique_ptr<TableFuncBindData> bindFunc(
}

static std::unique_ptr<TableFuncSharedState> initSharedState(TableFunctionInitInput& input) {
auto bindData = reinterpret_cast<ScanBindData*>(input.bindData);
auto bindData = ku_dynamic_cast<TableFuncBindData*, ScanBindData*>(input.bindData);
auto csvConfig = CSVReaderConfig::construct(bindData->config.options);
row_idx_t numRows = 0;
return std::make_unique<SerialCSVScanSharedState>(bindData->config.copy(), numRows,
auto sharedState = std::make_unique<SerialCSVScanSharedState>(bindData->config.copy(), numRows,
bindData->columnNames.size(), csvConfig.copy(), bindData->context);
for (auto filePath : sharedState->readerConfig.filePaths) {
auto reader =
std::make_unique<SerialCSVReader>(filePath, sharedState->csvReaderConfig.option.copy(),
sharedState->numColumns, sharedState->context);
sharedState->totalSize += reader->getFileSize();
}
return sharedState;
}

// Creates the local state for the serial CSV scan: a default-constructed TableFuncLocalState.
// None of the three arguments is used.
static std::unique_ptr<TableFuncLocalState> initLocalState(TableFunctionInitInput& /*input*/,
    TableFuncSharedState* /*state*/, storage::MemoryManager* /*mm*/) {
    auto localState = std::make_unique<TableFuncLocalState>();
    return localState;
}

// Estimates the progress of a serial CSV scan as a fraction of the total input size.
//
// Returns 1.0 once fileIdx has moved past the last configured file, and 0.0 when the recorded
// total size is zero. Otherwise the amount read is the size of the files already finished
// (totalReadSizeByFile) plus the current reader's file offset, divided by totalSize.
//
// The finished check comes first so that a completed scan reports 100% even when every input
// file is empty, matching the order used by the parallel CSV progress function.
static double progressFunc(TableFuncSharedState* sharedState) {
    auto state = ku_dynamic_cast<TableFuncSharedState*, SerialCSVScanSharedState*>(sharedState);
    if (state->fileIdx >= state->readerConfig.getNumFiles()) {
        return 1.0;
    }
    if (state->totalSize == 0) {
        return 0.0;
    }
    uint64_t totalReadSize = state->totalReadSizeByFile + state->reader->getFileOffset();
    return static_cast<double>(totalReadSize) / state->totalSize;
}

function_set SerialCSVScan::getFunctionSet() {
function_set functionSet;
functionSet.push_back(
std::make_unique<TableFunction>(READ_CSV_SERIAL_FUNC_NAME, tableFunc, bindFunc,
initSharedState, initLocalState, std::vector<LogicalTypeID>{LogicalTypeID::STRING}));
functionSet.push_back(std::make_unique<TableFunction>(READ_CSV_SERIAL_FUNC_NAME, tableFunc,
bindFunc, initSharedState, initLocalState, progressFunc,
std::vector<LogicalTypeID>{LogicalTypeID::STRING}));
return functionSet;
}

Expand Down
Loading

0 comments on commit 20e5cbb

Please sign in to comment.