diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHBucketSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHBucketSuite.scala
index 7bccb6dfb598..90e09e75f1ff 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHBucketSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHBucketSuite.scala
@@ -40,6 +40,7 @@ class GlutenClickHouseTPCHBucketSuite
   override protected val queriesResults: String = rootPath + "bucket-queries-output"
 
   override protected def sparkConf: SparkConf = {
+    import org.apache.gluten.backendsapi.clickhouse.CHConf._
     super.sparkConf
       .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.ColumnarShuffleManager")
       .set("spark.io.compression.codec", "LZ4")
@@ -47,9 +48,7 @@ class GlutenClickHouseTPCHBucketSuite
       .set("spark.sql.autoBroadcastJoinThreshold", "-1") // for test bucket join
       .set("spark.sql.adaptive.enabled", "true")
       .set("spark.gluten.sql.columnar.backend.ch.shuffle.hash.algorithm", "sparkMurmurHash3_32")
-      .set(
-        "spark.gluten.sql.columnar.backend.ch.runtime_config.enable_grace_aggregate_spill_test",
-        "true")
+      .setCHConfig("enable_grace_aggregate_spill_test", "true")
   }
 
   override protected val createNullableTables = true
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpcds/GlutenClickHouseTPCDSParquetSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpcds/GlutenClickHouseTPCDSParquetSuite.scala
index 0ba7de90c670..f0025cf30cad 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpcds/GlutenClickHouseTPCDSParquetSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpcds/GlutenClickHouseTPCDSParquetSuite.scala
@@ -30,6 +30,7 @@ class GlutenClickHouseTPCDSParquetSuite extends GlutenClickHouseTPCDSAbstractSui
 
   /** Run Gluten + ClickHouse Backend with SortShuffleManager */
   override protected def sparkConf: SparkConf = {
+    import org.apache.gluten.backendsapi.clickhouse.CHConf._
     super.sparkConf
       .set("spark.shuffle.manager", "sort")
       .set("spark.io.compression.codec", "snappy")
@@ -38,9 +39,7 @@ class GlutenClickHouseTPCDSParquetSuite extends GlutenClickHouseTPCDSAbstractSui
       .set("spark.memory.offHeap.size", "4g")
       .set("spark.gluten.sql.validation.logLevel", "ERROR")
       .set("spark.gluten.sql.validation.printStackOnFailure", "true")
-      .set(
-        "spark.gluten.sql.columnar.backend.ch.runtime_config.enable_grace_aggregate_spill_test",
-        "true")
+      .setCHConfig("enable_grace_aggregate_spill_test", "true")
   }
 
   executeTPCDSTest(false)
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHColumnarShuffleParquetAQESuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHColumnarShuffleParquetAQESuite.scala
index ad9cb854d97b..b4186fee66aa 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHColumnarShuffleParquetAQESuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHColumnarShuffleParquetAQESuite.scala
@@ -51,6 +51,7 @@ class GlutenClickHouseTPCHColumnarShuffleParquetAQESuite
       .set("spark.sql.adaptive.enabled", "true")
       .set("spark.gluten.sql.columnar.backend.ch.shuffle.hash.algorithm", "sparkMurmurHash3_32")
       .setCHConfig("enable_streaming_aggregating", true)
+      .set(GlutenConfig.COLUMNAR_CH_SHUFFLE_SPILL_THRESHOLD.key, (1024 * 1024).toString)
   }
 
   override protected def createTPCHNotNullTables(): Unit = {
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHParquetBucketSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHParquetBucketSuite.scala
index a257e2ed5094..7a927bf23a49 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHParquetBucketSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHParquetBucketSuite.scala
@@ -46,6 +46,7 @@ class GlutenClickHouseTPCHParquetBucketSuite
   protected val bucketTableDataPath: String = basePath + "/tpch-parquet-bucket"
 
   override protected def sparkConf: SparkConf = {
+    import org.apache.gluten.backendsapi.clickhouse.CHConf._
     super.sparkConf
       .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.ColumnarShuffleManager")
      .set("spark.io.compression.codec", "LZ4")
@@ -53,9 +54,7 @@ class GlutenClickHouseTPCHParquetBucketSuite
       .set("spark.sql.autoBroadcastJoinThreshold", "-1") // for test bucket join
       .set("spark.sql.adaptive.enabled", "true")
       .set("spark.gluten.sql.columnar.backend.ch.shuffle.hash.algorithm", "sparkMurmurHash3_32")
-      .set(
-        "spark.gluten.sql.columnar.backend.ch.runtime_config.enable_grace_aggregate_spill_test",
-        "true")
+      .setCHConfig("enable_grace_aggregate_spill_test", "true")
   }
 
   override protected val createNullableTables = true
diff --git a/cpp-ch/clickhouse.version b/cpp-ch/clickhouse.version
index 003f11113345..edb13fdc5715 100644
--- a/cpp-ch/clickhouse.version
+++ b/cpp-ch/clickhouse.version
@@ -1,3 +1,3 @@
 CH_ORG=Kyligence
-CH_BRANCH=rebase_ch/20241118
-CH_COMMIT=a5944dfb7b3
+CH_BRANCH=rebase_ch/20241129
+CH_COMMIT=101ba3f944d1
diff --git a/cpp-ch/local-engine/Common/CHUtil.cpp b/cpp-ch/local-engine/Common/CHUtil.cpp
index 88c5303c50e4..310b39d3e594 100644
--- a/cpp-ch/local-engine/Common/CHUtil.cpp
+++ b/cpp-ch/local-engine/Common/CHUtil.cpp
@@ -76,6 +76,12 @@ namespace DB
 {
+namespace ServerSetting
+{
+extern const ServerSettingsString primary_index_cache_policy;
+extern const ServerSettingsUInt64 primary_index_cache_size;
+extern const ServerSettingsDouble primary_index_cache_size_ratio;
+}
 namespace Setting
 {
 extern const SettingsUInt64 prefer_external_sort_block_bytes;
@@ -712,11 +718,13 @@ void BackendInitializerUtil::initSettings(const SparkConfigs::ConfigMap & spark_
     settings.set("input_format_parquet_enable_row_group_prefetch", false);
     settings.set("output_format_parquet_use_custom_encoder", false);
 
-    /// update per https://github.com/ClickHouse/ClickHouse/pull/71539
+    /// Set false after https://github.com/ClickHouse/ClickHouse/pull/71539
     /// if true, we can't get correct metrics for the query
     settings[Setting::query_plan_merge_filters] = false;
+
     /// We now set BuildQueryPipelineSettings according to config.
-    settings[Setting::compile_expressions] = true;
+    // TODO: FIXME. Set false after https://github.com/ClickHouse/ClickHouse/pull/70598.
+    settings[Setting::compile_expressions] = false;
     settings[Setting::short_circuit_function_evaluation] = ShortCircuitFunctionEvaluation::DISABLE;
     ///
@@ -820,6 +828,10 @@ void BackendInitializerUtil::initContexts(DB::Context::ConfigurationPtr config)
     /// Make sure global_context and shared_context are constructed only once.
     if (auto global_context = QueryContext::globalMutableContext(); !global_context)
     {
+        ServerSettings server_settings;
+        server_settings.loadSettingsFromConfig(*config);
+
+        auto log = getLogger("CHUtil");
         global_context = QueryContext::createGlobal();
         global_context->makeGlobalContext();
         global_context->setConfig(config);
@@ -844,10 +856,16 @@ void BackendInitializerUtil::initContexts(DB::Context::ConfigurationPtr config)
     size_t mark_cache_size = config->getUInt64("mark_cache_size", DEFAULT_MARK_CACHE_MAX_SIZE);
     double mark_cache_size_ratio = config->getDouble("mark_cache_size_ratio", DEFAULT_MARK_CACHE_SIZE_RATIO);
     if (!mark_cache_size)
-        LOG_ERROR(&Poco::Logger::get("CHUtil"), "Too low mark cache size will lead to severe performance degradation.");
-
+        LOG_ERROR(log, "Mark cache is disabled, it will lead to severe performance degradation.");
+    LOG_INFO(log, "Mark cache size set to {}.", formatReadableSizeWithBinarySuffix(mark_cache_size));
     global_context->setMarkCache(mark_cache_policy, mark_cache_size, mark_cache_size_ratio);
 
+    String primary_index_cache_policy = server_settings[ServerSetting::primary_index_cache_policy];
+    size_t primary_index_cache_size = server_settings[ServerSetting::primary_index_cache_size];
+    double primary_index_cache_size_ratio = server_settings[ServerSetting::primary_index_cache_size_ratio];
+    LOG_INFO(log, "Primary index cache size set to {}.", formatReadableSizeWithBinarySuffix(primary_index_cache_size));
+    global_context->setPrimaryIndexCache(primary_index_cache_policy, primary_index_cache_size, primary_index_cache_size_ratio);
+
     String index_uncompressed_cache_policy = config->getString("index_uncompressed_cache_policy", DEFAULT_INDEX_UNCOMPRESSED_CACHE_POLICY);
     size_t index_uncompressed_cache_size
diff --git a/cpp-ch/local-engine/Common/GlutenSignalHandler.cpp b/cpp-ch/local-engine/Common/GlutenSignalHandler.cpp
index 44c43fcb65aa..712d8ddcf5cc 100644
--- a/cpp-ch/local-engine/Common/GlutenSignalHandler.cpp
+++ b/cpp-ch/local-engine/Common/GlutenSignalHandler.cpp
@@ -104,7 +104,7 @@ static void writeSignalIDtoSignalPipe(int sig)
     char buf[signal_pipe_buf_size];
     WriteBufferFromFileDescriptor out(writeFD(), signal_pipe_buf_size, buf);
     writeBinary(sig, out);
-    out.next();
+    out.finalize();
 
     errno = saved_errno;
 }
@@ -251,9 +251,7 @@ class SignalListener : public Poco::Runnable
                 query = thread_ptr->getQueryForLog();
 
                 if (auto logs_queue = thread_ptr->getInternalTextLogsQueue())
-                {
                     CurrentThread::attachInternalTextLogsQueue(logs_queue, LogsLevel::trace);
-                }
             }
 
             std::string signal_description = "Unknown signal";
diff --git a/cpp-ch/local-engine/Functions/FunctionGreatestLeast.h b/cpp-ch/local-engine/Functions/FunctionGreatestLeast.h
index 6930c1d75b79..e9b66df84ef0 100644
--- a/cpp-ch/local-engine/Functions/FunctionGreatestLeast.h
+++ b/cpp-ch/local-engine/Functions/FunctionGreatestLeast.h
@@ -14,9 +14,12 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
 */
-#include 
-#include 
+#pragma once
 #include 
+#include 
+#include 
+#include 
+#include 
 
 namespace DB
 {
@@ -64,7 +67,7 @@ class FunctionGreatestestLeast : public DB::FunctionLeastGreatestGeneric
             else
             {
                 auto cmp_result = converted_columns[arg]->compareAt(row_num, row_num, *converted_columns[best_arg], 1);
-                if (cmp_result < 0) 
+                if (cmp_result < 0)
                     best_arg = arg;
             }
         }
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionArrayJoin.cpp b/cpp-ch/local-engine/Functions/SparkFunctionArrayJoin.cpp
index 4c2847d9f92a..bf65b253479b 100644
--- a/cpp-ch/local-engine/Functions/SparkFunctionArrayJoin.cpp
+++ b/cpp-ch/local-engine/Functions/SparkFunctionArrayJoin.cpp
@@ -14,13 +14,16 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
 */
-#include 
-#include 
-#include 
+#include 
+#include 
 #include 
-#include 
-#include 
+#include 
+#include 
 #include 
+#include 
+#include 
+#include 
+#include 
 
 using namespace DB;
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionArraysOverlap.cpp b/cpp-ch/local-engine/Functions/SparkFunctionArraysOverlap.cpp
index e43b52823175..ea841632a984 100644
--- a/cpp-ch/local-engine/Functions/SparkFunctionArraysOverlap.cpp
+++ b/cpp-ch/local-engine/Functions/SparkFunctionArraysOverlap.cpp
@@ -14,12 +14,13 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
 */
-#include 
+#include 
 #include 
-#include 
-#include 
-#include 
+#include 
+#include 
 #include 
+#include 
+#include 
 
 using namespace DB;
@@ -92,7 +93,7 @@ class SparkFunctionArraysOverlap : public IFunction
                 {
                     res_data[i] = 1;
                     null_map_data[i] = 0;
-                    break; 
+                    break;
                 }
             }
         }
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionCheckDecimalOverflow.cpp b/cpp-ch/local-engine/Functions/SparkFunctionCheckDecimalOverflow.cpp
index c75d25b6ef80..8b5a7eff65db 100644
--- a/cpp-ch/local-engine/Functions/SparkFunctionCheckDecimalOverflow.cpp
+++ b/cpp-ch/local-engine/Functions/SparkFunctionCheckDecimalOverflow.cpp
@@ -19,6 +19,7 @@
 #include 
 #include 
 #include 
+#include 
 #include 
 #include 
 #include 
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionDecimalBinaryArithmetic.cpp b/cpp-ch/local-engine/Functions/SparkFunctionDecimalBinaryArithmetic.cpp
index 830fc0e65287..f89943fc7a3d 100644
--- a/cpp-ch/local-engine/Functions/SparkFunctionDecimalBinaryArithmetic.cpp
+++ b/cpp-ch/local-engine/Functions/SparkFunctionDecimalBinaryArithmetic.cpp
@@ -19,6 +19,7 @@
 #include 
 #include 
 #include 
+#include 
 #include 
 #include 
 #include 
@@ -26,8 +27,6 @@
 #include 
 #include 
 #include 
-#include 
-#include 
 #include 
 
 namespace DB
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionMakeDecimal.cpp b/cpp-ch/local-engine/Functions/SparkFunctionMakeDecimal.cpp
index 795e2b0be329..f136f587c539 100644
--- a/cpp-ch/local-engine/Functions/SparkFunctionMakeDecimal.cpp
+++ b/cpp-ch/local-engine/Functions/SparkFunctionMakeDecimal.cpp
@@ -15,12 +15,12 @@
  * limitations under the License.
 */
 #include 
+#include 
 #include 
 #include 
 #include 
 #include 
-#include "SparkFunctionCheckDecimalOverflow.h"
-
+#include 
 
 namespace DB
 {
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionMapToString.h b/cpp-ch/local-engine/Functions/SparkFunctionMapToString.h
index 3f8a0c97dc07..5541245244a7 100644
--- a/cpp-ch/local-engine/Functions/SparkFunctionMapToString.h
+++ b/cpp-ch/local-engine/Functions/SparkFunctionMapToString.h
@@ -14,17 +14,18 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
 */
-
+#pragma once
 #include 
+#include 
 #include 
 #include 
-#include 
 #include 
 #include 
-#include 
 #include 
+#include 
 #include 
 #include 
+#include 
 #include 
 
 namespace DB
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionSplitByRegexp.cpp b/cpp-ch/local-engine/Functions/SparkFunctionSplitByRegexp.cpp
index 66f37c62033f..1868c40c0f31 100644
--- a/cpp-ch/local-engine/Functions/SparkFunctionSplitByRegexp.cpp
+++ b/cpp-ch/local-engine/Functions/SparkFunctionSplitByRegexp.cpp
@@ -20,8 +20,8 @@
 #include 
 #include 
 #include 
+#include 
 #include 
-#include 
 #include 
 #include 
diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
index 95086121d4cb..820a99ad3bfb 100644
--- a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
+++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
@@ -20,6 +20,7 @@
 #include 
 #include 
 #include 
+#include 
 #include 
 #include 
 #include 
@@ -160,15 +161,17 @@ void SerializedPlanParser::adjustOutput(const DB::QueryPlanPtr & query_plan, con
         else
         {
             need_final_project = true;
-            bool need_const = origin_column.column && isColumnConst(*origin_column.column);
-            if (need_const)
+            if (origin_column.column && isColumnConst(*origin_column.column))
             {
+                /// For a const column, we need to cast it individually. Otherwise, the const column will be converted to a full column in
+                /// ActionsDAG::makeConvertingActions.
+                /// Note: creating final_column with the Field of origin_column will cause an exception in some cases.
                 const DB::ContextPtr context = DB::CurrentThread::get().getQueryContext();
                 const FunctionOverloadResolverPtr & cast_resolver = FunctionFactory::instance().get("CAST", context);
                 const DataTypePtr string_type = std::make_shared();
                 ColumnWithTypeAndName to_type_column = {string_type->createColumnConst(1, final_type->getName()), string_type, "__cast_const__"};
                 FunctionBasePtr cast_function = cast_resolver->build({origin_column, to_type_column});
-                ColumnPtr const_col = ColumnConst::create(cast_function->execute({origin_column, to_type_column}, final_type, 1), 1);
+                ColumnPtr const_col = ColumnConst::create(cast_function->execute({origin_column, to_type_column}, final_type, 1, false), 1);
                 ColumnWithTypeAndName final_column(const_col, final_type, origin_column.name);
                 final_columns.emplace_back(std::move(final_column));
             }
@@ -310,7 +313,7 @@ DB::QueryPipelineBuilderPtr SerializedPlanParser::buildQueryPipeline(DB::QueryPl
     BuildQueryPipelineSettings build_settings = BuildQueryPipelineSettings::fromContext(context);
     build_settings.process_list_element = query_status;
     build_settings.progress_callback = nullptr;
-    return query_plan.buildQueryPipeline(optimization_settings,build_settings);
+    return query_plan.buildQueryPipeline(optimization_settings, build_settings);
 }
 
 std::unique_ptr SerializedPlanParser::createExecutor(const std::string_view plan)
diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/arrayExcept.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayExcept.cpp
index 8857fbbf5df6..a9a0f305a0ac 100644
--- a/cpp-ch/local-engine/Parser/scalar_function_parser/arrayExcept.cpp
+++ b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayExcept.cpp
@@ -15,6 +15,7 @@
  * limitations under the License.
 */
 #include 
+#include 
 #include 
 #include 
 #include 
diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/arrayRemove.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayRemove.cpp
index f45544cfa4d1..27bd9f84a9e6 100644
--- a/cpp-ch/local-engine/Parser/scalar_function_parser/arrayRemove.cpp
+++ b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayRemove.cpp
@@ -15,11 +15,11 @@
  * limitations under the License.
 */
 #include 
+#include 
 #include 
 #include 
 #include 
 #include 
-
 namespace DB
 {
 namespace ErrorCodes
diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/arrayRepeat.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayRepeat.cpp
index 581fd6f66589..04f4f64e7bd2 100644
--- a/cpp-ch/local-engine/Parser/scalar_function_parser/arrayRepeat.cpp
+++ b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayRepeat.cpp
@@ -16,6 +16,7 @@
 * limitations under the License.
 */
 #include 
+#include 
 #include 
 #include 
 #include 
@@ -56,7 +57,7 @@ class FunctionParserArrayRepeat : public FunctionParser
         const auto * const_zero_node = addColumnToActionsDAG(actions_dag, n_not_null_arg->result_type, {0});
         const auto * greatest_node = toFunctionNode(actions_dag, "greatest", {n_not_null_arg, const_zero_node});
         const auto * range_node = toFunctionNode(actions_dag, "range", {greatest_node});
-        const auto & range_type = assert_cast(*removeNullable(range_node->result_type));
+        const auto & range_type = assert_cast(*removeNullable(range_node->result_type));
 
         // Create lambda function x -> elem
         ActionsDAG lambda_actions_dag;
diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/bitLength.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/bitLength.cpp
index ebd89f8fa8e8..d663052a8ea7 100644
--- a/cpp-ch/local-engine/Parser/scalar_function_parser/bitLength.cpp
+++ b/cpp-ch/local-engine/Parser/scalar_function_parser/bitLength.cpp
@@ -15,8 +15,9 @@
 * limitations under the License.
 */
 
-#include 
+#include 
 #include 
+#include 
 
 namespace DB
 {
@@ -57,7 +58,7 @@ class FunctionParserBitLength : public FunctionParser
         const auto * const_eight_node = addColumnToActionsDAG(actions_dag, std::make_shared(), 8);
         const auto * result_node = toFunctionNode(actions_dag, "multiply", {octet_length_node, const_eight_node});
 
-        return convertNodeTypeIfNeeded(substrait_func, result_node, actions_dag);;
+        return convertNodeTypeIfNeeded(substrait_func, result_node, actions_dag);
     }
 };
diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/elementAt.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/elementAt.cpp
index 1586870541b6..80797dedcab5 100644
--- a/cpp-ch/local-engine/Parser/scalar_function_parser/elementAt.cpp
+++ b/cpp-ch/local-engine/Parser/scalar_function_parser/elementAt.cpp
@@ -16,7 +16,9 @@
 */
 
 #include 
+#include 
 #include 
+#include 
 #include 
 
 namespace DB
diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/length.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/length.cpp
index 34a7348b9ac6..b4e4ad119fdc 100644
--- a/cpp-ch/local-engine/Parser/scalar_function_parser/length.cpp
+++ b/cpp-ch/local-engine/Parser/scalar_function_parser/length.cpp
@@ -15,9 +15,10 @@
 * limitations under the License.
 */
-#include 
+#include 
 #include 
 #include 
+#include 
 
 namespace DB
 {
@@ -71,7 +72,7 @@ class FunctionParserLength : public FunctionParser
         else
             result_node = toFunctionNode(actions_dag, "char_length", {new_arg});
 
-        return convertNodeTypeIfNeeded(substrait_func, result_node, actions_dag);;
+        return convertNodeTypeIfNeeded(substrait_func, result_node, actions_dag);
     }
 };
diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/locate.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/locate.cpp
index c4f0234957af..a9ad4d9aee4a 100644
--- a/cpp-ch/local-engine/Parser/scalar_function_parser/locate.cpp
+++ b/cpp-ch/local-engine/Parser/scalar_function_parser/locate.cpp
@@ -15,10 +15,11 @@
 * limitations under the License.
 */
 
+#include 
+#include 
 #include 
 #include 
 #include 
-#include 
 
 namespace DB
 {
diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/octetLength.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/octetLength.cpp
index 1dd61587e011..d57d8d661b50 100644
--- a/cpp-ch/local-engine/Parser/scalar_function_parser/octetLength.cpp
+++ b/cpp-ch/local-engine/Parser/scalar_function_parser/octetLength.cpp
@@ -15,8 +15,9 @@
 * limitations under the License.
 */
 
-#include 
+#include 
 #include 
+#include 
 
 namespace DB
 {
@@ -52,7 +53,7 @@ class FunctionParserOctetLength : public FunctionParser
             new_arg = toFunctionNode(actions_dag, "CAST", {arg, string_type_node});
         }
         const auto * octet_length_node = toFunctionNode(actions_dag, "octet_length", {new_arg});
-        return convertNodeTypeIfNeeded(substrait_func, octet_length_node, actions_dag);;
+        return convertNodeTypeIfNeeded(substrait_func, octet_length_node, actions_dag);
     }
 };
diff --git a/cpp-ch/local-engine/Shuffle/PartitionWriter.cpp b/cpp-ch/local-engine/Shuffle/PartitionWriter.cpp
index aa9e15b9f036..43459f20c5b3 100644
--- a/cpp-ch/local-engine/Shuffle/PartitionWriter.cpp
+++ b/cpp-ch/local-engine/Shuffle/PartitionWriter.cpp
@@ -172,6 +172,8 @@ size_t LocalPartitionWriter::evictPartitions()
             split_result->total_compress_time += compressed_output.getCompressTime();
             split_result->total_write_time += compressed_output.getWriteTime();
             split_result->total_serialize_time += serialization_time_watch.elapsedNanoseconds();
+            compressed_output.finalize();
+            output.finalize();
         };
 
         Stopwatch spill_time_watch;
@@ -342,6 +344,8 @@ size_t MemorySortLocalPartitionWriter::evictPartitions()
             split_result->total_compress_time += compressed_output.getCompressTime();
             split_result->total_io_time += compressed_output.getWriteTime();
             split_result->total_serialize_time += serialization_time_watch.elapsedNanoseconds();
+            compressed_output.finalize();
+            output.finalize();
         };
 
         Stopwatch spill_time_watch;
@@ -428,6 +432,8 @@ size_t MemorySortCelebornPartitionWriter::evictPartitions()
             split_result->total_compress_time += compressed_output.getCompressTime();
             split_result->total_io_time += compressed_output.getWriteTime();
             split_result->total_serialize_time += serialization_time_watch.elapsedNanoseconds();
+            compressed_output.finalize();
+            output.finalize();
         };
 
         Stopwatch spill_time_watch;
diff --git a/cpp-ch/local-engine/Shuffle/SelectorBuilder.cpp b/cpp-ch/local-engine/Shuffle/SelectorBuilder.cpp
index 02baa4a9c09c..ab4cfc18c89d 100644
--- a/cpp-ch/local-engine/Shuffle/SelectorBuilder.cpp
+++ b/cpp-ch/local-engine/Shuffle/SelectorBuilder.cpp
@@ -17,6 +17,7 @@
 #include "SelectorBuilder.h"
 #include 
 #include 
+#include 
 #include 
 #include 
 #include 
diff --git a/cpp-ch/local-engine/Shuffle/ShuffleWriter.cpp b/cpp-ch/local-engine/Shuffle/ShuffleWriter.cpp
index 8aa624ff9979..7167dabfad55 100644
--- a/cpp-ch/local-engine/Shuffle/ShuffleWriter.cpp
+++ b/cpp-ch/local-engine/Shuffle/ShuffleWriter.cpp
@@ -53,23 +53,20 @@ void ShuffleWriter::write(const Block & block)
         native_writer->write(block);
     }
 }
-void ShuffleWriter::flush()
+void ShuffleWriter::flush() const
 {
     if (native_writer)
-    {
         native_writer->flush();
-    }
 }
+
 ShuffleWriter::~ShuffleWriter()
 {
     if (native_writer)
-    {
         native_writer->flush();
-        if (compression_enable)
-        {
-            compressed_out->finalize();
-        }
-        write_buffer->finalize();
-    }
+
+    if (compression_enable)
+        compressed_out->finalize();
+
+    write_buffer->finalize();
 }
 }
diff --git a/cpp-ch/local-engine/Shuffle/ShuffleWriter.h b/cpp-ch/local-engine/Shuffle/ShuffleWriter.h
index 541e93e0347c..94886210c1d2 100644
--- a/cpp-ch/local-engine/Shuffle/ShuffleWriter.h
+++ b/cpp-ch/local-engine/Shuffle/ShuffleWriter.h
@@ -27,7 +27,7 @@ class ShuffleWriter
         jobject output_stream, jbyteArray buffer, const std::string & codecStr, jint level, bool enable_compression, size_t customize_buffer_size);
     virtual ~ShuffleWriter();
     void write(const DB::Block & block);
-    void flush();
+    void flush() const;
 
 private:
     std::unique_ptr compressed_out;
diff --git a/cpp-ch/local-engine/Shuffle/SparkExchangeSink.cpp b/cpp-ch/local-engine/Shuffle/SparkExchangeSink.cpp
index a78d615be62b..c40b474e7a8b 100644
--- a/cpp-ch/local-engine/Shuffle/SparkExchangeSink.cpp
+++ b/cpp-ch/local-engine/Shuffle/SparkExchangeSink.cpp
@@ -16,15 +16,15 @@
 */
 #include "SparkExchangeSink.h"
+#include 
 #include 
+#include 
 #include 
-#include 
+#include 
+#include 
 #include 
+#include 
 #include 
-#include 
-#include 
-#include 
-#include 
 
 
 namespace DB
@@ -74,7 +74,7 @@ void SparkExchangeSink::consume(Chunk chunk)
 void SparkExchangeSink::onFinish()
 {
     Stopwatch wall_time;
-    if (!dynamic_cast(partition_writer.get()))
+    if (!dynamic_cast(partition_writer.get()))
     {
         partition_writer->evictPartitions();
     }
@@ -222,8 +222,7 @@ void SparkExchangeManager::finish()
         std::vector extra_datas;
         for (const auto & writer : partition_writers)
         {
-            LocalPartitionWriter * local_partition_writer = dynamic_cast(writer.get());
-            if (local_partition_writer)
+            if (LocalPartitionWriter * local_partition_writer = dynamic_cast(writer.get()))
             {
                 extra_datas.emplace_back(local_partition_writer->getExtraData());
             }
@@ -232,12 +231,13 @@ void SparkExchangeManager::finish()
         chassert(extra_datas.size() == partition_writers.size());
         WriteBufferFromFile output(options.data_file, options.io_buffer_size);
         split_result.partition_lengths = mergeSpills(output, infos, extra_datas);
+        output.finalize();
     }
 
     split_result.wall_time += wall_time.elapsedNanoseconds();
 }
 
-void checkPartitionLengths(const std::vector & partition_length,size_t partition_num)
+void checkPartitionLengths(const std::vector & partition_length, size_t partition_num)
 {
     if (partition_num != partition_length.size())
     {
@@ -284,7 +284,7 @@ void SparkExchangeManager::mergeSplitResult()
 std::vector SparkExchangeManager::gatherAllSpillInfo() const
 {
     std::vector res;
-    for (const auto& writer : partition_writers)
+    for (const auto & writer : partition_writers)
     {
         if (Spillable * spillable = dynamic_cast(writer.get()))
         {
diff --git a/cpp-ch/local-engine/Storages/MergeTree/MergeSparkMergeTreeTask.cpp b/cpp-ch/local-engine/Storages/MergeTree/MergeSparkMergeTreeTask.cpp
index cecb6308745c..ee6930e4de51 100644
--- a/cpp-ch/local-engine/Storages/MergeTree/MergeSparkMergeTreeTask.cpp
+++ b/cpp-ch/local-engine/Storages/MergeTree/MergeSparkMergeTreeTask.cpp
@@ -15,11 +15,11 @@
 * limitations under the License.
 */
 #include "MergeSparkMergeTreeTask.h"
 
-#include 
 #include 
 #include 
 #include 
+#include 
 #include 
 #include 
 #include 
@@ -94,6 +94,12 @@ bool MergeSparkMergeTreeTask::executeStep()
 }
 
 
+void MergeSparkMergeTreeTask::cancel() noexcept
+{
+    if (merge_task)
+        merge_task->cancel();
+}
+
 void MergeSparkMergeTreeTask::prepare()
 {
     future_part = merge_mutate_entry->future_part;
diff --git a/cpp-ch/local-engine/Storages/MergeTree/MergeSparkMergeTreeTask.h b/cpp-ch/local-engine/Storages/MergeTree/MergeSparkMergeTreeTask.h
index ac167da3fb49..60b3328f0d1b 100644
--- a/cpp-ch/local-engine/Storages/MergeTree/MergeSparkMergeTreeTask.h
+++ b/cpp-ch/local-engine/Storages/MergeTree/MergeSparkMergeTreeTask.h
@@ -65,6 +65,7 @@ class MergeSparkMergeTreeTask : public IExecutableTask
         txn_holder = std::move(txn_holder_);
         txn = std::move(txn_);
     }
+    void cancel() noexcept override;
 
 private:
     void prepare();
@@ -116,7 +117,7 @@ class MergeSparkMergeTreeTask : public IExecutableTask
 
 using MergeSparkMergeTreeTaskPtr = std::shared_ptr;
 
 
-[[ maybe_unused ]] static void executeHere(MergeSparkMergeTreeTaskPtr task)
+[[maybe_unused]] static void executeHere(MergeSparkMergeTreeTaskPtr task)
 {
     while (task->executeStep()) {}
 }
diff --git a/cpp-ch/local-engine/Storages/MergeTree/MetaDataHelper.cpp b/cpp-ch/local-engine/Storages/MergeTree/MetaDataHelper.cpp
index 958421022ba2..84dbc3a8d3bb 100644
--- a/cpp-ch/local-engine/Storages/MergeTree/MetaDataHelper.cpp
+++ b/cpp-ch/local-engine/Storages/MergeTree/MetaDataHelper.cpp
@@ -190,6 +190,7 @@ void restoreMetaData(
             auto item_path = part_path / item.first;
             auto out = metadata_disk->writeFile(item_path);
             out->write(item.second.data(), item.second.size());
+            out->finalize();
         }
     };
     thread_pool.scheduleOrThrow(job);
diff --git a/cpp-ch/local-engine/Storages/MergeTree/SparkStorageMergeTree.cpp b/cpp-ch/local-engine/Storages/MergeTree/SparkStorageMergeTree.cpp
index 17587e5200ef..5669489f5477 100644
--- a/cpp-ch/local-engine/Storages/MergeTree/SparkStorageMergeTree.cpp
+++ b/cpp-ch/local-engine/Storages/MergeTree/SparkStorageMergeTree.cpp
@@ -504,7 +504,6 @@ MergeTreeDataWriter::TemporaryPart SparkMergeTreeDataWriter::writeTempPart(
         txn ? txn->tid : Tx::PrehistoricTID,
         false,
         false,
-        false,
         context->getWriteSettings());
 
     out->writeWithPermutation(block, perm_ptr);
diff --git a/cpp-ch/local-engine/Storages/Output/NormalFileWriter.cpp b/cpp-ch/local-engine/Storages/Output/NormalFileWriter.cpp
index e5a2d89f26c8..ad2e3abf7b52 100644
--- a/cpp-ch/local-engine/Storages/Output/NormalFileWriter.cpp
+++ b/cpp-ch/local-engine/Storages/Output/NormalFileWriter.cpp
@@ -191,7 +191,11 @@ void NormalFileWriter::close()
     /// When insert into a table with empty dataset, NormalFileWriter::consume would be never called.
     /// So we need to skip when writer is nullptr.
     if (writer)
+    {
         writer->finish();
+        assert(output_format);
+        output_format->finalizeOutput();
+    }
 }
 
 OutputFormatFilePtr createOutputFormatFile(
diff --git a/cpp-ch/local-engine/Storages/Output/NormalFileWriter.h b/cpp-ch/local-engine/Storages/Output/NormalFileWriter.h
index d55703741845..8cfe079d92c5 100644
--- a/cpp-ch/local-engine/Storages/Output/NormalFileWriter.h
+++ b/cpp-ch/local-engine/Storages/Output/NormalFileWriter.h
@@ -293,9 +293,7 @@ class SubstraitFileSink final : public DB::SinkToStorage
     {
         if (output_format_) [[unlikely]]
         {
-            output_format_->output->finalize();
-            output_format_->output->flush();
-            output_format_->write_buffer->finalize();
+            output_format_->finalizeOutput();
             assert(delta_stats_.row_count > 0);
             if (stats_)
                 stats_->collectStats(relative_path_, partition_id_, delta_stats_);
diff --git a/cpp-ch/local-engine/Storages/Output/OutputFormatFile.h b/cpp-ch/local-engine/Storages/Output/OutputFormatFile.h
index e94923f77a43..915f9a7e7efa 100644
--- a/cpp-ch/local-engine/Storages/Output/OutputFormatFile.h
+++ b/cpp-ch/local-engine/Storages/Output/OutputFormatFile.h
@@ -29,9 +29,14 @@ class OutputFormatFile
 public:
     struct OutputFormat
     {
-    public:
         DB::OutputFormatPtr output;
         std::unique_ptr write_buffer;
+        void finalizeOutput() const
+        {
+            output->finalize();
+            output->flush();
+            write_buffer->finalize();
+        }
     };
     using OutputFormatPtr = std::shared_ptr;
diff --git a/cpp-ch/local-engine/tests/benchmark_cast_float_function.cpp b/cpp-ch/local-engine/tests/benchmark_cast_float_function.cpp
index 4ef9b5771af8..a50bcf170eff 100644
--- a/cpp-ch/local-engine/tests/benchmark_cast_float_function.cpp
+++ b/cpp-ch/local-engine/tests/benchmark_cast_float_function.cpp
@@ -52,7 +52,7 @@ static void BM_CHCastFloatToInt(benchmark::State & state)
     args.emplace_back(type_name_col);
     auto executable = function->build(args);
     for (auto _ : state) [[maybe_unused]]
-        auto result = executable->execute(block.getColumnsWithTypeAndName(), executable->getResultType(), block.rows());
+        auto result = executable->execute(block.getColumnsWithTypeAndName(), executable->getResultType(), block.rows(), false);
 }
 
 static void BM_SparkCastFloatToInt(benchmark::State & state)
@@ -63,7 +63,7 @@ static void BM_SparkCastFloatToInt(benchmark::State & state)
     Block block = createDataBlock(30000000);
     auto executable = function->build(block.getColumnsWithTypeAndName());
     for (auto _ : state) [[maybe_unused]]
-        auto result = executable->execute(block.getColumnsWithTypeAndName(), executable->getResultType(), block.rows());
+        auto result = executable->execute(block.getColumnsWithTypeAndName(), executable->getResultType(), block.rows(), false);
 }
 
 BENCHMARK(BM_CHCastFloatToInt)->Unit(benchmark::kMillisecond)->Iterations(100);
diff --git a/cpp-ch/local-engine/tests/benchmark_local_engine.cpp b/cpp-ch/local-engine/tests/benchmark_local_engine.cpp
index 13e74abaee1f..eacfb1781b26 100644
--- a/cpp-ch/local-engine/tests/benchmark_local_engine.cpp
+++ b/cpp-ch/local-engine/tests/benchmark_local_engine.cpp
@@ -846,7 +846,7 @@ QueryPlanPtr joinPlan(QueryPlanPtr left, QueryPlanPtr right, String left_key, St
     auto hash_join = std::make_shared(join, right->getCurrentHeader());
 
     QueryPlanStepPtr join_step
-        = std::make_unique(left->getCurrentHeader(), right->getCurrentHeader(), hash_join, block_size, 1, false);
+        = std::make_unique(left->getCurrentHeader(), right->getCurrentHeader(), hash_join, block_size, 0, 1, false);
 
     std::vector plans;
     plans.emplace_back(std::move(left));
diff --git a/cpp-ch/local-engine/tests/benchmark_spark_floor_function.cpp b/cpp-ch/local-engine/tests/benchmark_spark_floor_function.cpp
index ef961f21cbb6..a672fdee350a 100644
--- a/cpp-ch/local-engine/tests/benchmark_spark_floor_function.cpp
+++ b/cpp-ch/local-engine/tests/benchmark_spark_floor_function.cpp
@@ -66,7 +66,7 @@ static void BM_CHFloorFunction_For_Int64(benchmark::State & state)
     auto executable = function->build(int64_block.getColumnsWithTypeAndName());
     for (auto _ : state)
     {
-        auto result = executable->execute(int64_block.getColumnsWithTypeAndName(), executable->getResultType(), int64_block.rows());
+        auto result = executable->execute(int64_block.getColumnsWithTypeAndName(), executable->getResultType(), int64_block.rows(), false);
         benchmark::DoNotOptimize(result);
     }
 }
@@ -80,7 +80,7 @@ static void BM_CHFloorFunction_For_Float64(benchmark::State & state)
     auto executable = function->build(float64_block.getColumnsWithTypeAndName());
     for (auto _ : state)
     {
-        auto result = executable->execute(float64_block.getColumnsWithTypeAndName(), executable->getResultType(), float64_block.rows());
+        auto result = executable->execute(float64_block.getColumnsWithTypeAndName(), executable->getResultType(), float64_block.rows(), false);
         benchmark::DoNotOptimize(result);
     }
 }
@@ -94,7 +94,7 @@ static void BM_SparkFloorFunction_For_Int64(benchmark::State & state)
     auto executable = function->build(int64_block.getColumnsWithTypeAndName());
     for (auto _ : state)
     {
-        auto result = executable->execute(int64_block.getColumnsWithTypeAndName(), executable->getResultType(), int64_block.rows());
+        auto result = executable->execute(int64_block.getColumnsWithTypeAndName(), executable->getResultType(), int64_block.rows(), false);
         benchmark::DoNotOptimize(result);
     }
 }
@@ -108,7 +108,7 @@ static void BM_SparkFloorFunction_For_Float64(benchmark::State & state)
     auto executable = function->build(float64_block.getColumnsWithTypeAndName());
     for (auto _ : state)
     {
-        auto result = executable->execute(float64_block.getColumnsWithTypeAndName(), executable->getResultType(), float64_block.rows());
+        auto result = executable->execute(float64_block.getColumnsWithTypeAndName(), executable->getResultType(), float64_block.rows(), false);
         benchmark::DoNotOptimize(result);
     }
 }
diff --git a/cpp-ch/local-engine/tests/benchmark_to_datetime_function.cpp b/cpp-ch/local-engine/tests/benchmark_to_datetime_function.cpp
index c72125163351..49f9dde989e9 100644
--- a/cpp-ch/local-engine/tests/benchmark_to_datetime_function.cpp
+++ b/cpp-ch/local-engine/tests/benchmark_to_datetime_function.cpp
@@ -45,7 +45,7 @@ static void BM_CHParseDateTime64(benchmark::State & state)
     Block block = createDataBlock(30000000);
     auto executable = function->build(block.getColumnsWithTypeAndName());
     for (auto _ : state) [[maybe_unused]]
-        auto result = executable->execute(block.getColumnsWithTypeAndName(), executable->getResultType(), block.rows());
+        auto result = executable->execute(block.getColumnsWithTypeAndName(), executable->getResultType(), block.rows(), false);
 }
 
 
@@ -57,7 +57,7 @@ static void BM_SparkParseDateTime64(benchmark::State & state)
     Block block = createDataBlock(30000000);
     auto executable = function->build(block.getColumnsWithTypeAndName());
     for (auto _ : state) [[maybe_unused]]
-        auto result = executable->execute(block.getColumnsWithTypeAndName(), executable->getResultType(), block.rows());
+        auto result = executable->execute(block.getColumnsWithTypeAndName(), executable->getResultType(), block.rows(), false);
 }
 
 BENCHMARK(BM_CHParseDateTime64)->Unit(benchmark::kMillisecond)->Iterations(50);
diff --git a/cpp-ch/local-engine/tests/benchmark_unix_timestamp_function.cpp b/cpp-ch/local-engine/tests/benchmark_unix_timestamp_function.cpp
index e7abfda7a2b2..a7dc3ffa2b91 100644
--- a/cpp-ch/local-engine/tests/benchmark_unix_timestamp_function.cpp
+++ b/cpp-ch/local-engine/tests/benchmark_unix_timestamp_function.cpp
@@ -49,7 +49,7 @@ static void BM_CHUnixTimestamp_For_Date32(benchmark::State & state)
     Block block = createDataBlock("Date32", 30000000);
     auto executable = function->build(block.getColumnsWithTypeAndName());
     for (auto _ : state) [[maybe_unused]]
-        auto result = executable->execute(block.getColumnsWithTypeAndName(), executable->getResultType(), block.rows());
+        auto result = executable->execute(block.getColumnsWithTypeAndName(), executable->getResultType(), block.rows(), false);
 }
 
 static void BM_CHUnixTimestamp_For_Date(benchmark::State & state)
@@ -60,7 +60,7 @@ static void BM_CHUnixTimestamp_For_Date(benchmark::State & state)
     Block block = createDataBlock("Date", 30000000);
     auto executable = function->build(block.getColumnsWithTypeAndName());
     for (auto _ : state) [[maybe_unused]]
-        auto result = executable->execute(block.getColumnsWithTypeAndName(), executable->getResultType(), block.rows());
+        auto result = executable->execute(block.getColumnsWithTypeAndName(), executable->getResultType(), block.rows(), false);
 }
 
 static void BM_SparkUnixTimestamp_For_Date32(benchmark::State & state)
@@ -71,7 +71,7 @@ static void BM_SparkUnixTimestamp_For_Date32(benchmark::State & state)
     Block block = createDataBlock("Date32", 30000000);
     auto executable = function->build(block.getColumnsWithTypeAndName());
     for (auto _ : state) [[maybe_unused]]
-        auto result = executable->execute(block.getColumnsWithTypeAndName(), executable->getResultType(), block.rows());
+        auto result = executable->execute(block.getColumnsWithTypeAndName(), executable->getResultType(), block.rows(), false);
 }
 
 static void BM_SparkUnixTimestamp_For_Date(benchmark::State & state)
@@ -82,7 +82,7 @@ static void BM_SparkUnixTimestamp_For_Date(benchmark::State & state)
     Block block = createDataBlock("Date", 30000000);
     auto executable = function->build(block.getColumnsWithTypeAndName());
     for (auto _ : state) [[maybe_unused]]
-        auto result = executable->execute(block.getColumnsWithTypeAndName(), executable->getResultType(), block.rows());
+        auto result = executable->execute(block.getColumnsWithTypeAndName(), executable->getResultType(), block.rows(), false);
 }
 
 BENCHMARK(BM_CHUnixTimestamp_For_Date32)->Unit(benchmark::kMillisecond)->Iterations(100);
diff --git a/cpp-ch/local-engine/tests/gluten_test_util.h b/cpp-ch/local-engine/tests/gluten_test_util.h
index 799a6d7967dc..a61612662961 100644
--- a/cpp-ch/local-engine/tests/gluten_test_util.h
+++ b/cpp-ch/local-engine/tests/gluten_test_util.h
@@ -16,9 +16,9 @@
 */
 
 #pragma once
+#include "testConfig.h"
 
 #include 
-#include 
 #include 
 #include 
 #include 
diff --git a/cpp-ch/local-engine/tests/gtest_ch_functions.cpp b/cpp-ch/local-engine/tests/gtest_ch_functions.cpp
index e905bc1787fa..3b91e0799404 100644
--- a/cpp-ch/local-engine/tests/gtest_ch_functions.cpp
+++ b/cpp-ch/local-engine/tests/gtest_ch_functions.cpp
@@ -47,7 +47,7 @@ TEST(TestFuntion, Hash)
     std::cerr << "input:\n";
     debug::headBlock(block);
     auto executable = function->build(block.getColumnsWithTypeAndName());
-    auto result = executable->execute(block.getColumnsWithTypeAndName(), executable->getResultType(), block.rows());
+    auto result = executable->execute(block.getColumnsWithTypeAndName(), executable->getResultType(), block.rows(), false);
     std::cerr << "output:\n";
     debug::headColumn(result);
     ASSERT_EQ(result->getUInt(0), result->getUInt(1));
@@ -89,7 +89,7 @@ TEST(TestFunction, In)
     std::cerr << "input:\n";
     debug::headBlock(block);
     auto executable = function->build(block.getColumnsWithTypeAndName());
-    auto result = executable->execute(block.getColumnsWithTypeAndName(), executable->getResultType(), block.rows());
+    auto result = executable->execute(block.getColumnsWithTypeAndName(), executable->getResultType(), block.rows(), false);
     std::cerr << "output:\n";
     debug::headColumn(result);
     ASSERT_EQ(result->getUInt(3), 0);
@@ -133,7 +133,7 @@ TEST(TestFunction, NotIn1)
     std::cerr << "input:\n";
     debug::headBlock(block);
     auto executable = function->build(block.getColumnsWithTypeAndName());
-    auto result = executable->execute(block.getColumnsWithTypeAndName(), executable->getResultType(), block.rows());
+    auto result = executable->execute(block.getColumnsWithTypeAndName(), executable->getResultType(), block.rows(), false);
     std::cerr << "output:\n";
     debug::headColumn(result);
     ASSERT_EQ(result->getUInt(3), 1);
@@ -176,14 +176,14 @@ TEST(TestFunction, NotIn2)
     std::cerr << "input:\n";
     debug::headBlock(block);
     auto executable = function->build(block.getColumnsWithTypeAndName());
-    auto result = executable->execute(block.getColumnsWithTypeAndName(), executable->getResultType(), block.rows());
+    auto result = executable->execute(block.getColumnsWithTypeAndName(), executable->getResultType(), block.rows(), false);
 
     auto function_not = factory.get("not", local_engine::QueryContext::globalContext());
     auto type_bool = DataTypeFactory::instance().get("UInt8");
     ColumnsWithTypeAndName columns2 = {ColumnWithTypeAndName(result, type_bool, "string0")};
     Block block2(columns2);
     auto executable2 = function_not->build(block2.getColumnsWithTypeAndName());
-    auto result2 = executable2->execute(block2.getColumnsWithTypeAndName(), executable2->getResultType(), block2.rows());
+    auto result2 = executable2->execute(block2.getColumnsWithTypeAndName(), executable2->getResultType(), block2.rows(), false);
     std::cerr << "output:\n";
     debug::headColumn(result2);
     ASSERT_EQ(result2->getUInt(3), 1);
diff --git a/cpp-ch/local-engine/tests/gtest_local_engine.cpp b/cpp-ch/local-engine/tests/gtest_local_engine.cpp
index 5f9b6f280e58..06e94e051b86 100644
--- a/cpp-ch/local-engine/tests/gtest_local_engine.cpp
+++ b/cpp-ch/local-engine/tests/gtest_local_engine.cpp
@@ -67,6 +67,7 @@ TEST(ReadBufferFromFile, seekBackwards)
         WriteBufferFromFile out(tmp_file->path());
         for (size_t i = 0; i < N; ++i)
             writeIntBinary(i, out);
+        out.finalize();
     }
 
     ReadBufferFromFile in(tmp_file->path(), BUF_SIZE);
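
Two recurring patterns run through this patch. On the Scala side, hand-spelled `spark.gluten.sql.columnar.backend.ch.runtime_config.*` keys are replaced by the `setCHConfig` helper imported from `CHConf`, which, as the replaced lines show, expands a short key into the same prefixed form. On the C++ side, most hunks retrofit two upstream ClickHouse API changes: built functions' `execute` now takes a trailing boolean (the `, false` added at every call site, evidently a dry-run flag), and write buffers must be finalized explicitly rather than left to their destructors. Below is a minimal sketch of that second discipline, mirroring the `gtest_local_engine.cpp` hunk; `writeSomeNumbers` is a hypothetical helper for illustration, not code from this patch.

    // Sketch only, not patch code: the explicit write-then-finalize pattern
    // that the finalize() additions above adopt.
    #include <IO/WriteBufferFromFile.h>
    #include <IO/WriteHelpers.h>
    #include <string>

    void writeSomeNumbers(const std::string & path) // hypothetical helper
    {
        DB::WriteBufferFromFile out(path);
        for (size_t i = 0; i < 100; ++i)
            DB::writeIntBinary(i, out);
        // Flush and mark the buffer finished; skipping this can lose the
        // buffered tail, which is what the hunks above guard against.
        out.finalize();
    }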