From fa0ef79e21bcb14ce5e19fabf007d5cbb1297a07 Mon Sep 17 00:00:00 2001 From: ziyi chen Date: Thu, 4 Apr 2024 00:21:34 -0400 Subject: [PATCH] Merge duplicate key fix (#3207) --- src/binder/bind/bind_updating_clause.cpp | 3 +- .../updating_clause/bound_merge_clause.h | 54 ++-- .../optimizer/factorization_rewriter.h | 3 +- .../optimizer/logical_operator_visitor.h | 6 + .../planner/operator/logical_accumulate.h | 16 +- .../planner/operator/logical_distinct.h | 53 ++- .../planner/operator/logical_mark_accmulate.h | 35 ++ .../planner/operator/logical_operator.h | 1 + .../operator/persistent/logical_merge.h | 43 ++- src/include/planner/planner.h | 11 +- .../operator/aggregate/aggregate_hash_table.h | 54 ++-- .../operator/aggregate/hash_aggregate.h | 51 +-- .../operator/persistent/insert_executor.h | 9 +- .../processor/operator/persistent/merge.h | 14 +- src/include/processor/plan_mapper.h | 8 +- .../processor/result/mark_hash_table.h | 32 ++ .../agg_key_dependency_optimizer.cpp | 6 +- src/optimizer/factorization_rewriter.cpp | 9 + src/optimizer/logical_operator_visitor.cpp | 6 + .../projection_push_down_optimizer.cpp | 7 +- src/planner/operator/CMakeLists.txt | 1 + src/planner/operator/logical_accumulate.cpp | 2 +- src/planner/operator/logical_distinct.cpp | 18 +- .../operator/logical_mark_accumulate.cpp | 46 +++ src/planner/operator/logical_operator.cpp | 2 + src/planner/plan/CMakeLists.txt | 1 + src/planner/plan/append_distinct.cpp | 4 +- src/planner/plan/append_mark_accumulate.cpp | 20 ++ src/planner/plan/plan_read.cpp | 7 +- src/planner/plan/plan_subquery.cpp | 25 +- src/planner/plan/plan_update.cpp | 52 +-- src/processor/map/CMakeLists.txt | 1 + src/processor/map/map_accumulate.cpp | 2 +- src/processor/map/map_aggregate.cpp | 68 +++- src/processor/map/map_distinct.cpp | 8 +- src/processor/map/map_expressions_scan.cpp | 2 +- src/processor/map/map_mark_accumulate.cpp | 26 ++ src/processor/map/map_merge.cpp | 17 +- src/processor/map/plan_mapper.cpp | 3 + .../aggregate/aggregate_hash_table.cpp | 306 +++++++++--------- .../operator/aggregate/hash_aggregate.cpp | 61 +++- .../operator/persistent/insert_executor.cpp | 8 + src/processor/operator/persistent/merge.cpp | 53 ++- src/processor/result/CMakeLists.txt | 1 + src/processor/result/mark_hash_table.cpp | 51 +++ .../test_files/update_node/merge_tinysnb.test | 35 ++ 46 files changed, 837 insertions(+), 404 deletions(-) create mode 100644 src/include/planner/operator/logical_mark_accmulate.h create mode 100644 src/include/processor/result/mark_hash_table.h create mode 100644 src/planner/operator/logical_mark_accumulate.cpp create mode 100644 src/planner/plan/append_mark_accumulate.cpp create mode 100644 src/processor/map/map_mark_accumulate.cpp create mode 100644 src/processor/result/mark_hash_table.cpp diff --git a/src/binder/bind/bind_updating_clause.cpp b/src/binder/bind/bind_updating_clause.cpp index 91f19c1ce68..d94acc75a4a 100644 --- a/src/binder/bind/bind_updating_clause.cpp +++ b/src/binder/bind/bind_updating_clause.cpp @@ -80,9 +80,10 @@ std::unique_ptr Binder::bindMergeClause( auto boundGraphPattern = bindGraphPattern(mergeClause.getPatternElementsRef()); rewriteMatchPattern(boundGraphPattern); auto createInfos = bindInsertInfos(boundGraphPattern.queryGraphCollection, patternsScope); + auto distinctMark = createVariable("__distinctMark", *LogicalType::BOOL()); auto boundMergeClause = std::make_unique(std::move(boundGraphPattern.queryGraphCollection), - std::move(boundGraphPattern.where), std::move(createInfos)); + std::move(boundGraphPattern.where), std::move(createInfos), std::move(distinctMark)); if (mergeClause.hasOnMatchSetItems()) { for (auto& [lhs, rhs] : mergeClause.getOnMatchSetItemsRef()) { auto setPropertyInfo = bindSetPropertyInfo(lhs.get(), rhs.get()); diff --git a/src/include/binder/query/updating_clause/bound_merge_clause.h b/src/include/binder/query/updating_clause/bound_merge_clause.h index 0e794489dba..fa29b8d4db7 100644 --- a/src/include/binder/query/updating_clause/bound_merge_clause.h +++ b/src/include/binder/query/updating_clause/bound_merge_clause.h @@ -11,91 +11,92 @@ namespace binder { class BoundMergeClause : public BoundUpdatingClause { public: BoundMergeClause(QueryGraphCollection queryGraphCollection, - std::shared_ptr predicate, std::vector insertInfos) - : BoundUpdatingClause{common::ClauseType::MERGE}, queryGraphCollection{std::move( - queryGraphCollection)}, - predicate{std::move(predicate)}, insertInfos{std::move(insertInfos)} {} + std::shared_ptr predicate, std::vector insertInfos, + std::shared_ptr distinctMark) + : BoundUpdatingClause{common::ClauseType::MERGE}, + queryGraphCollection{std::move(queryGraphCollection)}, predicate{std::move(predicate)}, + insertInfos{std::move(insertInfos)}, distinctMark{std::move(distinctMark)} {} - inline const QueryGraphCollection* getQueryGraphCollection() const { - return &queryGraphCollection; - } - inline bool hasPredicate() const { return predicate != nullptr; } - inline std::shared_ptr getPredicate() const { return predicate; } + const QueryGraphCollection* getQueryGraphCollection() const { return &queryGraphCollection; } + bool hasPredicate() const { return predicate != nullptr; } + std::shared_ptr getPredicate() const { return predicate; } - inline const std::vector& getInsertInfosRef() const { return insertInfos; } - inline const std::vector& getOnMatchSetInfosRef() const { + const std::vector& getInsertInfosRef() const { return insertInfos; } + const std::vector& getOnMatchSetInfosRef() const { return onMatchSetPropertyInfos; } - inline const std::vector& getOnCreateSetInfosRef() const { + const std::vector& getOnCreateSetInfosRef() const { return onCreateSetPropertyInfos; } - inline bool hasInsertNodeInfo() const { + bool hasInsertNodeInfo() const { return hasInsertInfo( [](const BoundInsertInfo& info) { return info.tableType == common::TableType::NODE; }); } - inline std::vector getInsertNodeInfos() const { + std::vector getInsertNodeInfos() const { return getInsertInfos( [](const BoundInsertInfo& info) { return info.tableType == common::TableType::NODE; }); } - inline bool hasInsertRelInfo() const { + bool hasInsertRelInfo() const { return hasInsertInfo( [](const BoundInsertInfo& info) { return info.tableType == common::TableType::REL; }); } - inline std::vector getInsertRelInfos() const { + std::vector getInsertRelInfos() const { return getInsertInfos( [](const BoundInsertInfo& info) { return info.tableType == common::TableType::REL; }); } - inline bool hasOnMatchSetNodeInfo() const { + bool hasOnMatchSetNodeInfo() const { return hasOnMatchSetInfo([](const BoundSetPropertyInfo& info) { return info.updateTableType == UpdateTableType::NODE; }); } - inline std::vector getOnMatchSetNodeInfos() const { + std::vector getOnMatchSetNodeInfos() const { return getOnMatchSetInfos([](const BoundSetPropertyInfo& info) { return info.updateTableType == UpdateTableType::NODE; }); } - inline bool hasOnMatchSetRelInfo() const { + bool hasOnMatchSetRelInfo() const { return hasOnMatchSetInfo([](const BoundSetPropertyInfo& info) { return info.updateTableType == UpdateTableType::REL; }); } - inline std::vector getOnMatchSetRelInfos() const { + std::vector getOnMatchSetRelInfos() const { return getOnMatchSetInfos([](const BoundSetPropertyInfo& info) { return info.updateTableType == UpdateTableType::REL; }); } - inline bool hasOnCreateSetNodeInfo() const { + bool hasOnCreateSetNodeInfo() const { return hasOnCreateSetInfo([](const BoundSetPropertyInfo& info) { return info.updateTableType == UpdateTableType::NODE; }); } - inline std::vector getOnCreateSetNodeInfos() const { + std::vector getOnCreateSetNodeInfos() const { return getOnCreateSetInfos([](const BoundSetPropertyInfo& info) { return info.updateTableType == UpdateTableType::NODE; }); } - inline bool hasOnCreateSetRelInfo() const { + bool hasOnCreateSetRelInfo() const { return hasOnCreateSetInfo([](const BoundSetPropertyInfo& info) { return info.updateTableType == UpdateTableType::REL; }); } - inline std::vector getOnCreateSetRelInfos() const { + std::vector getOnCreateSetRelInfos() const { return getOnCreateSetInfos([](const BoundSetPropertyInfo& info) { return info.updateTableType == UpdateTableType::REL; }); } - inline void addOnMatchSetPropertyInfo(BoundSetPropertyInfo setPropertyInfo) { + void addOnMatchSetPropertyInfo(BoundSetPropertyInfo setPropertyInfo) { onMatchSetPropertyInfos.push_back(std::move(setPropertyInfo)); } - inline void addOnCreateSetPropertyInfo(BoundSetPropertyInfo setPropertyInfo) { + void addOnCreateSetPropertyInfo(BoundSetPropertyInfo setPropertyInfo) { onCreateSetPropertyInfos.push_back(std::move(setPropertyInfo)); } + std::shared_ptr getDistinctMark() const { return distinctMark; } + private: bool hasInsertInfo(const std::function& check) const; std::vector getInsertInfos( @@ -121,6 +122,7 @@ class BoundMergeClause : public BoundUpdatingClause { std::vector onMatchSetPropertyInfos; // Update on create std::vector onCreateSetPropertyInfos; + std::shared_ptr distinctMark; }; } // namespace binder diff --git a/src/include/optimizer/factorization_rewriter.h b/src/include/optimizer/factorization_rewriter.h index afdbf05d639..95c978b2ef4 100644 --- a/src/include/optimizer/factorization_rewriter.h +++ b/src/include/optimizer/factorization_rewriter.h @@ -6,7 +6,7 @@ namespace kuzu { namespace optimizer { -class FactorizationRewriter : public LogicalOperatorVisitor { +class FactorizationRewriter final : public LogicalOperatorVisitor { public: void rewrite(planner::LogicalPlan* plan); @@ -19,6 +19,7 @@ class FactorizationRewriter : public LogicalOperatorVisitor { void visitIntersect(planner::LogicalOperator* op) override; void visitProjection(planner::LogicalOperator* op) override; void visitAccumulate(planner::LogicalOperator* op) override; + void visitMarkAccumulate(planner::LogicalOperator*) override; void visitAggregate(planner::LogicalOperator* op) override; void visitOrderBy(planner::LogicalOperator* op) override; void visitLimit(planner::LogicalOperator* op) override; diff --git a/src/include/optimizer/logical_operator_visitor.h b/src/include/optimizer/logical_operator_visitor.h index 6929ebc3058..69c72a5a72d 100644 --- a/src/include/optimizer/logical_operator_visitor.h +++ b/src/include/optimizer/logical_operator_visitor.h @@ -105,6 +105,12 @@ class LogicalOperatorVisitor { return op; } + virtual void visitMarkAccumulate(planner::LogicalOperator* /*op*/) {} + virtual std::shared_ptr visitMarkAccumulateReplace( + std::shared_ptr op) { + return op; + } + virtual void visitDistinct(planner::LogicalOperator* /*op*/) {} virtual std::shared_ptr visitDistinctReplace( std::shared_ptr op) { diff --git a/src/include/planner/operator/logical_accumulate.h b/src/include/planner/operator/logical_accumulate.h index 1f1a3399357..b521200de7a 100644 --- a/src/include/planner/operator/logical_accumulate.h +++ b/src/include/planner/operator/logical_accumulate.h @@ -6,7 +6,7 @@ namespace kuzu { namespace planner { -class LogicalAccumulate : public LogicalOperator { +class LogicalAccumulate final : public LogicalOperator { public: LogicalAccumulate(common::AccumulateType accumulateType, binder::expression_vector flatExprs, std::shared_ptr offset, std::shared_ptr child) @@ -14,20 +14,20 @@ class LogicalAccumulate : public LogicalOperator { accumulateType{accumulateType}, flatExprs{std::move(flatExprs)}, offset{ std::move(offset)} {} - void computeFactorizedSchema() final; - void computeFlatSchema() final; + void computeFactorizedSchema() override; + void computeFlatSchema() override; f_group_pos_set getGroupPositionsToFlatten() const; - inline std::string getExpressionsForPrinting() const final { return std::string{}; } + std::string getExpressionsForPrinting() const override { return {}; } - inline common::AccumulateType getAccumulateType() const { return accumulateType; } - inline binder::expression_vector getExpressionsToAccumulate() const { + common::AccumulateType getAccumulateType() const { return accumulateType; } + binder::expression_vector getPayloads() const { return children[0]->getSchema()->getExpressionsInScope(); } - inline std::shared_ptr getOffset() const { return offset; } + std::shared_ptr getOffset() const { return offset; } - inline std::unique_ptr copy() final { + std::unique_ptr copy() override { return make_unique( accumulateType, flatExprs, offset, children[0]->copy()); } diff --git a/src/include/planner/operator/logical_distinct.h b/src/include/planner/operator/logical_distinct.h index c635c14447f..46ae208b431 100644 --- a/src/include/planner/operator/logical_distinct.h +++ b/src/include/planner/operator/logical_distinct.h @@ -1,56 +1,43 @@ #pragma once #include "planner/operator/logical_operator.h" -#include "planner/operator/schema.h" namespace kuzu { namespace planner { class LogicalDistinct : public LogicalOperator { public: - LogicalDistinct( - binder::expression_vector keyExpressions, std::shared_ptr child) - : LogicalOperator{LogicalOperatorType::DISTINCT, std::move(child)}, - keyExpressions{std::move(keyExpressions)} {} - LogicalDistinct(binder::expression_vector keyExpressions, - binder::expression_vector dependentKeyExpressions, std::shared_ptr child) - : LogicalOperator{LogicalOperatorType::DISTINCT, std::move(child)}, - keyExpressions{std::move(keyExpressions)}, dependentKeyExpressions{ - std::move(dependentKeyExpressions)} {} + LogicalDistinct(binder::expression_vector keys, std::shared_ptr child) + : LogicalDistinct{ + LogicalOperatorType::DISTINCT, keys, binder::expression_vector{}, std::move(child)} {} + LogicalDistinct(LogicalOperatorType type, binder::expression_vector keys, + binder::expression_vector payloads, std::shared_ptr child) + : LogicalOperator{type, std::move(child)}, keys{std::move(keys)}, payloads{std::move( + payloads)} {} void computeFactorizedSchema() override; void computeFlatSchema() override; - f_group_pos_set getGroupsPosToFlatten(); + virtual f_group_pos_set getGroupsPosToFlatten(); std::string getExpressionsForPrinting() const override; - inline binder::expression_vector getKeyExpressions() const { return keyExpressions; } - inline void setKeyExpressions(binder::expression_vector expressions) { - keyExpressions = std::move(expressions); - } - inline binder::expression_vector getDependentKeyExpressions() const { - return dependentKeyExpressions; - } - inline void setDependentKeyExpressions(binder::expression_vector expressions) { - dependentKeyExpressions = std::move(expressions); - } - inline binder::expression_vector getAllDistinctExpressions() const { - binder::expression_vector result; - result.insert(result.end(), keyExpressions.begin(), keyExpressions.end()); - result.insert(result.end(), dependentKeyExpressions.begin(), dependentKeyExpressions.end()); - return result; - } + binder::expression_vector getKeys() const { return keys; } + void setKeys(binder::expression_vector expressions) { keys = std::move(expressions); } + binder::expression_vector getPayloads() const { return payloads; } + void setPayloads(binder::expression_vector expressions) { payloads = std::move(expressions); } std::unique_ptr copy() override { - return make_unique( - keyExpressions, dependentKeyExpressions, children[0]->copy()); + return make_unique(operatorType, keys, payloads, children[0]->copy()); } -private: - binder::expression_vector keyExpressions; - // See logical_aggregate.h for details. - binder::expression_vector dependentKeyExpressions; +protected: + binder::expression_vector getKeysAndPayloads() const; + +protected: + binder::expression_vector keys; + // Payloads meaning additional keys that are functional dependent on the keys above. + binder::expression_vector payloads; }; } // namespace planner diff --git a/src/include/planner/operator/logical_mark_accmulate.h b/src/include/planner/operator/logical_mark_accmulate.h new file mode 100644 index 00000000000..697cb942977 --- /dev/null +++ b/src/include/planner/operator/logical_mark_accmulate.h @@ -0,0 +1,35 @@ +#pragma once + +#include "planner/operator/logical_operator.h" + +namespace kuzu { +namespace planner { + +class LogicalMarkAccumulate final : public LogicalOperator { +public: + LogicalMarkAccumulate(binder::expression_vector keys, std::shared_ptr mark, + std::shared_ptr child) + : LogicalOperator{LogicalOperatorType::MARK_ACCUMULATE, std::move(child)}, + keys{std::move(keys)}, mark{std::move(mark)} {} + + void computeFactorizedSchema() override; + void computeFlatSchema() override; + + f_group_pos_set getGroupsPosToFlatten() const; + + std::string getExpressionsForPrinting() const override { return {}; } + binder::expression_vector getKeys() const { return keys; } + binder::expression_vector getPayloads() const; + std::shared_ptr getMark() const { return mark; } + + std::unique_ptr copy() override { + return std::make_unique(keys, mark, children[0]->copy()); + } + +private: + binder::expression_vector keys; + std::shared_ptr mark; +}; + +} // namespace planner +} // namespace kuzu diff --git a/src/include/planner/operator/logical_operator.h b/src/include/planner/operator/logical_operator.h index 4c65a8ea540..f2ab9daac09 100644 --- a/src/include/planner/operator/logical_operator.h +++ b/src/include/planner/operator/logical_operator.h @@ -34,6 +34,7 @@ enum class LogicalOperatorType : uint8_t { INTERSECT, INSERT, LIMIT, + MARK_ACCUMULATE, MERGE, MULTIPLICITY_REDUCER, NODE_LABEL_FILTER, diff --git a/src/include/planner/operator/persistent/logical_merge.h b/src/include/planner/operator/persistent/logical_merge.h index 93c4f2d9cb9..8762152247f 100644 --- a/src/include/planner/operator/persistent/logical_merge.h +++ b/src/include/planner/operator/persistent/logical_merge.h @@ -9,7 +9,8 @@ namespace planner { class LogicalMerge : public LogicalOperator { public: - LogicalMerge(std::shared_ptr mark, + LogicalMerge(std::shared_ptr existenceMark, + std::shared_ptr distinctMark, std::vector insertNodeInfos, std::vector insertRelInfos, std::vector> onCreateSetNodeInfos, @@ -17,7 +18,8 @@ class LogicalMerge : public LogicalOperator { std::vector> onMatchSetNodeInfos, std::vector> onMatchSetRelInfos, std::shared_ptr child) - : LogicalOperator{LogicalOperatorType::MERGE, std::move(child)}, mark{std::move(mark)}, + : LogicalOperator{LogicalOperatorType::MERGE, std::move(child)}, + existenceMark{std::move(existenceMark)}, distinctMark{std::move(distinctMark)}, insertNodeInfos{std::move(insertNodeInfos)}, insertRelInfos{std::move(insertRelInfos)}, onCreateSetNodeInfos{std::move(onCreateSetNodeInfos)}, onCreateSetRelInfos(std::move(onCreateSetRelInfos)), @@ -27,44 +29,41 @@ class LogicalMerge : public LogicalOperator { void computeFactorizedSchema() final; void computeFlatSchema() final; - inline std::string getExpressionsForPrinting() const final { return std::string(""); } + std::string getExpressionsForPrinting() const final { return {}; } f_group_pos_set getGroupsPosToFlatten(); - inline std::shared_ptr getMark() const { return mark; } - inline const std::vector& getInsertNodeInfosRef() const { - return insertNodeInfos; - } - inline const std::vector& getInsertRelInfosRef() const { - return insertRelInfos; - } - inline const std::vector>& - getOnCreateSetNodeInfosRef() const { + std::shared_ptr getExistenceMark() const { return existenceMark; } + bool hasDistinctMark() const { return distinctMark != nullptr; } + std::shared_ptr getDistinctMark() const { return distinctMark; } + + const std::vector& getInsertNodeInfosRef() const { return insertNodeInfos; } + const std::vector& getInsertRelInfosRef() const { return insertRelInfos; } + const std::vector>& getOnCreateSetNodeInfosRef() const { return onCreateSetNodeInfos; } - inline const std::vector>& - getOnCreateSetRelInfosRef() const { + const std::vector>& getOnCreateSetRelInfosRef() const { return onCreateSetRelInfos; } - inline const std::vector>& - getOnMatchSetNodeInfosRef() const { + const std::vector>& getOnMatchSetNodeInfosRef() const { return onMatchSetNodeInfos; } - inline const std::vector>& - getOnMatchSetRelInfosRef() const { + const std::vector>& getOnMatchSetRelInfosRef() const { return onMatchSetRelInfos; } - inline std::unique_ptr copy() final { - return std::make_unique(mark, copyVector(insertNodeInfos), - copyVector(insertRelInfos), LogicalSetPropertyInfo::copy(onCreateSetNodeInfos), + std::unique_ptr copy() final { + return std::make_unique(existenceMark, distinctMark, + copyVector(insertNodeInfos), copyVector(insertRelInfos), + LogicalSetPropertyInfo::copy(onCreateSetNodeInfos), LogicalSetPropertyInfo::copy(onCreateSetRelInfos), LogicalSetPropertyInfo::copy(onMatchSetNodeInfos), LogicalSetPropertyInfo::copy(onMatchSetRelInfos), children[0]->copy()); } private: - std::shared_ptr mark; + std::shared_ptr existenceMark; + std::shared_ptr distinctMark; // Create infos std::vector insertNodeInfos; std::vector insertRelInfos; diff --git a/src/include/planner/planner.h b/src/include/planner/planner.h index 594b78359f9..5040462cb94 100644 --- a/src/include/planner/planner.h +++ b/src/include/planner/planner.h @@ -111,13 +111,18 @@ class Planner { // Plan subquery void planOptionalMatch(const binder::QueryGraphCollection& queryGraphCollection, - const binder::expression_vector& predicates, LogicalPlan& leftPlan); + const binder::expression_vector& predicates, const binder::expression_vector& corrExprs, + LogicalPlan& leftPlan); void planRegularMatch(const binder::QueryGraphCollection& queryGraphCollection, const binder::expression_vector& predicates, LogicalPlan& leftPlan); void planSubquery(const std::shared_ptr& subquery, LogicalPlan& outerPlan); void planSubqueryIfNecessary( const std::shared_ptr& expression, LogicalPlan& plan); + static binder::expression_vector getCorrelatedExprs( + const binder::QueryGraphCollection& collection, const binder::expression_vector& predicates, + Schema* outerSchema); + // Plan query graphs std::unique_ptr planQueryGraphCollection( const binder::QueryGraphCollection& queryGraphCollection, @@ -246,6 +251,8 @@ class Planner { void appendAccumulate(common::AccumulateType accumulateType, const binder::expression_vector& flatExprs, std::shared_ptr offset, LogicalPlan& plan); + void appendMarkAccumulate(const binder::expression_vector& keys, + std::shared_ptr mark, LogicalPlan& plan); void appendDummyScan(LogicalPlan& plan); @@ -265,7 +272,7 @@ class Planner { void appendScanFile(const binder::BoundFileScanInfo* info, std::shared_ptr offset, LogicalPlan& plan); - void appendDistinct(const binder::expression_vector& expressionsToDistinct, LogicalPlan& plan); + void appendDistinct(const binder::expression_vector& keys, LogicalPlan& plan); std::unique_ptr createUnionPlan( std::vector>& childrenPlans, bool isUnionAll); diff --git a/src/include/processor/operator/aggregate/aggregate_hash_table.h b/src/include/processor/operator/aggregate/aggregate_hash_table.h index 017c5a74b96..3b60b34d651 100644 --- a/src/include/processor/operator/aggregate/aggregate_hash_table.h +++ b/src/include/processor/operator/aggregate/aggregate_hash_table.h @@ -14,6 +14,8 @@ struct HashSlot { // groupKeyN, aggregateState1, ..., aggregateStateN, hashValue]. }; +enum class HashTableType : uint8_t { AGGREGATE_HASH_TABLE = 0, MARK_HASH_TABLE = 1 }; + /** * AggregateHashTable Design * @@ -42,15 +44,15 @@ class AggregateHashTable : public BaseHashTable { AggregateHashTable(storage::MemoryManager& memoryManager, const common::logical_type_vec_t& keysDataTypes, const std::vector>& aggregateFunctions, - uint64_t numEntriesToAllocate) + uint64_t numEntriesToAllocate, std::unique_ptr tableSchema) : AggregateHashTable(memoryManager, keysDataTypes, std::vector(), - aggregateFunctions, numEntriesToAllocate) {} + aggregateFunctions, numEntriesToAllocate, std::move(tableSchema)) {} AggregateHashTable(storage::MemoryManager& memoryManager, std::vector keysDataTypes, std::vector payloadsDataTypes, const std::vector>& aggregateFunctions, - uint64_t numEntriesToAllocate); + uint64_t numEntriesToAllocate, std::unique_ptr tableSchema); uint8_t* getEntry(uint64_t idx) { return factorizedTable->getTuple(idx); } @@ -86,9 +88,26 @@ class AggregateHashTable : public BaseHashTable { void resize(uint64_t newSize); +protected: + virtual uint64_t matchFTEntries(const std::vector& flatKeyVectors, + const std::vector& unFlatKeyVectors, uint64_t numMayMatches, + uint64_t numNoMatches); + + virtual void initializeFTEntries(const std::vector& flatKeyVectors, + const std::vector& unFlatKeyVectors, + const std::vector& dependentKeyVectors, + uint64_t numFTEntriesToInitialize); + + uint64_t matchUnFlatVecWithFTColumn(common::ValueVector* vector, uint64_t numMayMatches, + uint64_t& numNoMatches, uint32_t colIdx); + + uint64_t matchFlatVecWithFTColumn(common::ValueVector* vector, uint64_t numMayMatches, + uint64_t& numNoMatches, uint32_t colIdx); + private: void initializeFT( - const std::vector>& aggregateFunctions); + const std::vector>& aggregateFunctions, + std::unique_ptr tableSchema); void initializeHashTable(uint64_t numEntriesToAllocate); @@ -105,11 +124,6 @@ class AggregateHashTable : public BaseHashTable { void initializeFTEntryWithUnFlatVec( common::ValueVector* unFlatVector, uint64_t numEntriesToInitialize, uint32_t colIdx); - void initializeFTEntries(const std::vector& flatKeyVectors, - const std::vector& unFlatKeyVectors, - const std::vector& dependentKeyVectors, - uint64_t numFTEntriesToInitialize); - uint8_t* createEntryInDistinctHT( const std::vector& groupByHashKeyVectors, common::hash_t hash); @@ -145,16 +159,6 @@ class AggregateHashTable : public BaseHashTable { // are flat. bool matchFlatGroupByKeys(const std::vector& keyVectors, uint8_t* entry); - uint64_t matchUnFlatVecWithFTColumn(common::ValueVector* vector, uint64_t numMayMatches, - uint64_t& numNoMatches, uint32_t colIdx); - - uint64_t matchFlatVecWithFTColumn(common::ValueVector* vector, uint64_t numMayMatches, - uint64_t& numNoMatches, uint32_t colIdx); - - uint64_t matchFTEntries(const std::vector& flatKeyVectors, - const std::vector& unFlatKeyVectors, uint64_t numMayMatches, - uint64_t numNoMatches); - void fillEntryWithInitialNullAggregateState(uint8_t* entry); //! find an uninitialized hash slot for given hash and fill hash slot with block id and offset @@ -205,13 +209,19 @@ class AggregateHashTable : public BaseHashTable { std::unique_ptr& aggregateFunction, common::ValueVector* aggVector, uint64_t multiplicity, uint32_t aggStateOffset); +protected: + uint32_t hashColIdxInFT; + std::unique_ptr mayMatchIdxes; + std::unique_ptr noMatchIdxes; + std::unique_ptr entryIdxesToInitialize; + std::unique_ptr hashSlotsToUpdateAggState; + private: std::vector dependentKeyDataTypes; std::vector> aggregateFunctions; //! special handling of distinct aggregate std::vector> distinctHashTables; - uint32_t hashColIdxInFT; uint32_t hashColOffsetInFT; uint32_t aggStateColOffsetInFT; uint32_t aggStateColIdxInFT; @@ -219,11 +229,7 @@ class AggregateHashTable : public BaseHashTable { uint32_t numBytesForDependentKeys = 0; std::vector updateAggFuncs; // Temporary arrays to hold intermediate results. - std::unique_ptr hashSlotsToUpdateAggState; std::unique_ptr tmpValueIdxes; - std::unique_ptr entryIdxesToInitialize; - std::unique_ptr mayMatchIdxes; - std::unique_ptr noMatchIdxes; std::unique_ptr tmpSlotIdxes; }; diff --git a/src/include/processor/operator/aggregate/hash_aggregate.h b/src/include/processor/operator/aggregate/hash_aggregate.h index 06f2ab9088e..aadee4cd5be 100644 --- a/src/include/processor/operator/aggregate/hash_aggregate.h +++ b/src/include/processor/operator/aggregate/hash_aggregate.h @@ -33,18 +33,42 @@ class HashAggregateSharedState final : public BaseAggregateSharedState { std::unique_ptr globalAggregateHashTable; }; +struct HashAggregateInfo { + std::vector flatKeysPos; + std::vector unFlatKeysPos; + std::vector dependentKeysPos; + std::unique_ptr tableSchema; + HashTableType hashTableType; + + HashAggregateInfo(std::vector flatKeysPos, std::vector unFlatKeysPos, + std::vector dependentKeysPos, std::unique_ptr tableSchema, + HashTableType hashTableType); + HashAggregateInfo(const HashAggregateInfo& other); +}; + +struct HashAggregateLocalState { + std::vector flatKeyVectors; + std::vector unFlatKeyVectors; + std::vector dependentKeyVectors; + common::DataChunkState* leadingState; + std::unique_ptr aggregateHashTable; + + void init(ResultSet& resultSet, main::ClientContext* context, HashAggregateInfo& info, + std::vector>& aggregateFunctions); + void append( + std::vector>& aggregateInputs, uint64_t multiplicity) const; +}; + class HashAggregate : public BaseAggregate { public: HashAggregate(std::unique_ptr resultSetDescriptor, - std::shared_ptr sharedState, std::vector flatKeysPos, - std::vector unFlatKeysPos, std::vector dependentKeysPos, + std::shared_ptr sharedState, HashAggregateInfo aggregateInfo, std::vector> aggregateFunctions, std::vector> aggregateInputInfos, std::unique_ptr child, uint32_t id, const std::string& paramsString) : BaseAggregate{std::move(resultSetDescriptor), std::move(aggregateFunctions), std::move(aggregateInputInfos), std::move(child), id, paramsString}, - flatKeysPos{std::move(flatKeysPos)}, unFlatKeysPos{std::move(unFlatKeysPos)}, - dependentKeysPos{std::move(dependentKeysPos)}, sharedState{std::move(sharedState)} {} + aggregateInfo{std::move(aggregateInfo)}, sharedState{std::move(sharedState)} {} void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override; @@ -52,24 +76,15 @@ class HashAggregate : public BaseAggregate { void finalize(ExecutionContext* context) override; - inline std::unique_ptr clone() override { - return make_unique(resultSetDescriptor->copy(), sharedState, flatKeysPos, - unFlatKeysPos, dependentKeysPos, cloneAggFunctions(), cloneAggInputInfos(), - children[0]->clone(), id, paramsString); + std::unique_ptr clone() override { + return make_unique(resultSetDescriptor->copy(), sharedState, aggregateInfo, + cloneAggFunctions(), cloneAggInputInfos(), children[0]->clone(), id, paramsString); } private: - std::vector flatKeysPos; - std::vector unFlatKeysPos; - std::vector dependentKeysPos; - - std::vector flatKeyVectors; - std::vector unFlatKeyVectors; - std::vector dependentKeyVectors; - common::DataChunkState* leadingState; - + HashAggregateInfo aggregateInfo; + HashAggregateLocalState localState; std::shared_ptr sharedState; - std::unique_ptr localAggregateHashTable; }; } // namespace processor diff --git a/src/include/processor/operator/persistent/insert_executor.h b/src/include/processor/operator/persistent/insert_executor.h index 85f9f9f6196..1643feac1d3 100644 --- a/src/include/processor/operator/persistent/insert_executor.h +++ b/src/include/processor/operator/persistent/insert_executor.h @@ -28,11 +28,14 @@ class NodeInsertExecutor { void insert(transaction::Transaction* transaction, ExecutionContext* context); + void evaluateResult(ExecutionContext* context); + + void writeResult(); + private: NodeInsertExecutor(const NodeInsertExecutor& other); bool checkConfict(transaction::Transaction* transaction); - void writeResult(); private: // Node table to insert. @@ -69,11 +72,11 @@ class RelInsertExecutor { void insert(transaction::Transaction* transaction, ExecutionContext* context); + void writeResult(); + private: RelInsertExecutor(const RelInsertExecutor& other); - void writeResult(); - private: storage::RelsStoreStats* relsStatistics; storage::RelTable* table; diff --git a/src/include/processor/operator/persistent/merge.h b/src/include/processor/operator/persistent/merge.h index 7eea1e3959b..09994e1f05c 100644 --- a/src/include/processor/operator/persistent/merge.h +++ b/src/include/processor/operator/persistent/merge.h @@ -9,7 +9,8 @@ namespace processor { class Merge : public PhysicalOperator { public: - Merge(const DataPos& markPos, std::vector nodeInsertExecutors, + Merge(const DataPos& existenceMark, const DataPos& distinctMark, + std::vector nodeInsertExecutors, std::vector relInsertExecutors, std::vector> onCreateNodeSetExecutors, std::vector> onCreateRelSetExecutors, @@ -17,7 +18,8 @@ class Merge : public PhysicalOperator { std::vector> onMatchRelSetExecutors, std::unique_ptr child, uint32_t id, const std::string& paramsString) : PhysicalOperator{PhysicalOperatorType::MERGE, std::move(child), id, paramsString}, - markPos{markPos}, nodeInsertExecutors{std::move(nodeInsertExecutors)}, + existenceMark{existenceMark}, distinctMark{distinctMark}, nodeInsertExecutors{std::move( + nodeInsertExecutors)}, relInsertExecutors{std::move(relInsertExecutors)}, onCreateNodeSetExecutors{std::move( onCreateNodeSetExecutors)}, onCreateRelSetExecutors{std::move(onCreateRelSetExecutors)}, @@ -31,7 +33,7 @@ class Merge : public PhysicalOperator { bool getNextTuplesInternal(ExecutionContext* context) final; inline std::unique_ptr clone() final { - return std::make_unique(markPos, copyVector(nodeInsertExecutors), + return std::make_unique(existenceMark, distinctMark, copyVector(nodeInsertExecutors), copyVector(relInsertExecutors), NodeSetExecutor::copy(onCreateNodeSetExecutors), RelSetExecutor::copy(onCreateRelSetExecutors), NodeSetExecutor::copy(onMatchNodeSetExecutors), @@ -39,8 +41,10 @@ class Merge : public PhysicalOperator { } private: - DataPos markPos; - common::ValueVector* markVector = nullptr; + DataPos existenceMark; + DataPos distinctMark; + common::ValueVector* existenceVector = nullptr; + common::ValueVector* distinctVector = nullptr; std::vector nodeInsertExecutors; std::vector relInsertExecutors; diff --git a/src/include/processor/plan_mapper.h b/src/include/processor/plan_mapper.h index b4b8c1fe28a..56131e4dbfd 100644 --- a/src/include/processor/plan_mapper.h +++ b/src/include/processor/plan_mapper.h @@ -69,6 +69,7 @@ class PlanMapper { std::unique_ptr mapOrderBy(planner::LogicalOperator* logicalOperator); std::unique_ptr mapUnionAll(planner::LogicalOperator* logicalOperator); std::unique_ptr mapAccumulate(planner::LogicalOperator* logicalOperator); + std::unique_ptr mapMarkAccumulate(planner::LogicalOperator* logicalOperator); std::unique_ptr mapDummyScan(planner::LogicalOperator* logicalOperator); std::unique_ptr mapInsert(planner::LogicalOperator* logicalOperator); std::unique_ptr mapSetNodeProperty(planner::LogicalOperator* logicalOperator); @@ -148,14 +149,13 @@ class PlanMapper { std::unique_ptr createHashBuildInfo(const planner::Schema& buildSideSchema, const binder::expression_vector& keys, const binder::expression_vector& payloads); - std::unique_ptr createHashAggregate( - const binder::expression_vector& keyExpressions, - const binder::expression_vector& dependentKeyExpressions, + std::unique_ptr createHashAggregate(const binder::expression_vector& keys, + const binder::expression_vector& payloads, std::vector> aggregateFunctions, std::vector> aggregateInputInfos, std::vector aggregatesOutputPos, planner::Schema* inSchema, planner::Schema* outSchema, std::unique_ptr prevOperator, - const std::string& paramsString); + const std::string& paramsString, std::shared_ptr markExpression); std::unique_ptr getNodeInsertExecutor( const planner::LogicalInsertInfo* info, const planner::Schema& inSchema, diff --git a/src/include/processor/result/mark_hash_table.h b/src/include/processor/result/mark_hash_table.h new file mode 100644 index 00000000000..5f014ae9819 --- /dev/null +++ b/src/include/processor/result/mark_hash_table.h @@ -0,0 +1,32 @@ +#include "processor/operator/aggregate/aggregate_hash_table.h" + +#pragma once + +namespace kuzu { +namespace processor { + +class MarkHashTable : public AggregateHashTable { + +public: + MarkHashTable(storage::MemoryManager& memoryManager, + std::vector keyDataTypes, + std::vector dependentKeyDataTypes, + const std::vector>& aggregateFunctions, + uint64_t numEntriesToAllocate, std::unique_ptr tableSchema); + + uint64_t matchFTEntries(const std::vector& flatKeyVectors, + const std::vector& unFlatKeyVectors, uint64_t numMayMatches, + uint64_t numNoMatches) override; + + void initializeFTEntries(const std::vector& flatKeyVectors, + const std::vector& unFlatKeyVectors, + const std::vector& dependentKeyVectors, + uint64_t numFTEntriesToInitialize) override; + +private: + std::unordered_set onMatchSlotIdxes; + uint32_t distinctColIdxInFT; +}; + +} // namespace processor +} // namespace kuzu diff --git a/src/optimizer/agg_key_dependency_optimizer.cpp b/src/optimizer/agg_key_dependency_optimizer.cpp index 04dcd7fbf9e..e027e792e31 100644 --- a/src/optimizer/agg_key_dependency_optimizer.cpp +++ b/src/optimizer/agg_key_dependency_optimizer.cpp @@ -33,9 +33,9 @@ void AggKeyDependencyOptimizer::visitAggregate(planner::LogicalOperator* op) { void AggKeyDependencyOptimizer::visitDistinct(planner::LogicalOperator* op) { auto distinct = (LogicalDistinct*)op; - auto [keys, dependentKeys] = resolveKeysAndDependentKeys(distinct->getKeyExpressions()); - distinct->setKeyExpressions(keys); - distinct->setDependentKeyExpressions(dependentKeys); + auto [keys, dependentKeys] = resolveKeysAndDependentKeys(distinct->getKeys()); + distinct->setKeys(keys); + distinct->setPayloads(dependentKeys); } std::pair diff --git a/src/optimizer/factorization_rewriter.cpp b/src/optimizer/factorization_rewriter.cpp index d18e3ad37e4..f87470dd1f1 100644 --- a/src/optimizer/factorization_rewriter.cpp +++ b/src/optimizer/factorization_rewriter.cpp @@ -1,6 +1,7 @@ #include "optimizer/factorization_rewriter.h" #include "binder/expression_visitor.h" +#include "common/cast.h" #include "planner/operator/extend/logical_extend.h" #include "planner/operator/extend/logical_recursive_extend.h" #include "planner/operator/factorization/flatten_resolver.h" @@ -12,6 +13,7 @@ #include "planner/operator/logical_hash_join.h" #include "planner/operator/logical_intersect.h" #include "planner/operator/logical_limit.h" +#include "planner/operator/logical_mark_accmulate.h" #include "planner/operator/logical_order_by.h" #include "planner/operator/logical_projection.h" #include "planner/operator/logical_union.h" @@ -22,6 +24,7 @@ #include "planner/operator/persistent/logical_merge.h" #include "planner/operator/persistent/logical_set.h" +using namespace kuzu::common; using namespace kuzu::binder; using namespace kuzu::planner; @@ -104,6 +107,12 @@ void FactorizationRewriter::visitAccumulate(planner::LogicalOperator* op) { accumulate->setChild(0, appendFlattens(accumulate->getChild(0), groupsPosToFlatten)); } +void FactorizationRewriter::visitMarkAccumulate(planner::LogicalOperator* op) { + auto markAccumulate = ku_dynamic_cast(op); + auto groupsPos = markAccumulate->getGroupsPosToFlatten(); + markAccumulate->setChild(0, appendFlattens(markAccumulate->getChild(0), groupsPos)); +} + void FactorizationRewriter::visitAggregate(planner::LogicalOperator* op) { auto aggregate = (LogicalAggregate*)op; auto groupsPosToFlattenForGroupBy = aggregate->getGroupsPosToFlattenForGroupBy(); diff --git a/src/optimizer/logical_operator_visitor.cpp b/src/optimizer/logical_operator_visitor.cpp index b43c85167b8..9cb1e0fb2cf 100644 --- a/src/optimizer/logical_operator_visitor.cpp +++ b/src/optimizer/logical_operator_visitor.cpp @@ -52,6 +52,9 @@ void LogicalOperatorVisitor::visitOperatorSwitch(planner::LogicalOperator* op) { case LogicalOperatorType::ACCUMULATE: { visitAccumulate(op); } break; + case LogicalOperatorType::MARK_ACCUMULATE: { + visitMarkAccumulate(op); + } break; case LogicalOperatorType::DISTINCT: { visitDistinct(op); } break; @@ -141,6 +144,9 @@ std::shared_ptr LogicalOperatorVisitor::visitOperatorR case LogicalOperatorType::ACCUMULATE: { return visitAccumulateReplace(op); } + case LogicalOperatorType::MARK_ACCUMULATE: { + return visitMarkAccumulateReplace(op); + } case LogicalOperatorType::DISTINCT: { return visitDistinctReplace(op); } diff --git a/src/optimizer/projection_push_down_optimizer.cpp b/src/optimizer/projection_push_down_optimizer.cpp index b6ca4f855af..90d5c7bdf47 100644 --- a/src/optimizer/projection_push_down_optimizer.cpp +++ b/src/optimizer/projection_push_down_optimizer.cpp @@ -66,7 +66,7 @@ void ProjectionPushDownOptimizer::visitAccumulate(planner::LogicalOperator* op) if (accumulate->getAccumulateType() != AccumulateType::REGULAR) { return; } - auto expressionsBeforePruning = accumulate->getExpressionsToAccumulate(); + auto expressionsBeforePruning = accumulate->getPayloads(); auto expressionsAfterPruning = pruneExpressions(expressionsBeforePruning); if (expressionsBeforePruning.size() == expressionsAfterPruning.size()) { return; @@ -200,7 +200,10 @@ void ProjectionPushDownOptimizer::visitDeleteRel(planner::LogicalOperator* op) { // TODO(Xiyang): come back and refactor this after changing insert interface void ProjectionPushDownOptimizer::visitMerge(planner::LogicalOperator* op) { auto merge = (LogicalMerge*)op; - collectExpressionsInUse(merge->getMark()); + if (merge->hasDistinctMark()) { + collectExpressionsInUse(merge->getDistinctMark()); + } + collectExpressionsInUse(merge->getExistenceMark()); for (auto& info : merge->getInsertNodeInfosRef()) { visitInsertInfo(&info); } diff --git a/src/planner/operator/CMakeLists.txt b/src/planner/operator/CMakeLists.txt index a69aaf38a68..fd176ffe117 100644 --- a/src/planner/operator/CMakeLists.txt +++ b/src/planner/operator/CMakeLists.txt @@ -20,6 +20,7 @@ add_library(kuzu_planner_operator logical_in_query_call.cpp logical_intersect.cpp logical_limit.cpp + logical_mark_accumulate.cpp logical_operator.cpp logical_order_by.cpp logical_partitioner.cpp diff --git a/src/planner/operator/logical_accumulate.cpp b/src/planner/operator/logical_accumulate.cpp index 35c402422fd..ffc4dc760ea 100644 --- a/src/planner/operator/logical_accumulate.cpp +++ b/src/planner/operator/logical_accumulate.cpp @@ -9,7 +9,7 @@ namespace planner { void LogicalAccumulate::computeFactorizedSchema() { createEmptySchema(); auto childSchema = children[0]->getSchema(); - SinkOperatorUtil::recomputeSchema(*childSchema, getExpressionsToAccumulate(), *schema); + SinkOperatorUtil::recomputeSchema(*childSchema, getPayloads(), *schema); if (offset != nullptr) { // If we need to generate row offset. Then all expressions must have been flattened and // accumulated. So the new schema should just have one group. diff --git a/src/planner/operator/logical_distinct.cpp b/src/planner/operator/logical_distinct.cpp index 74fde0ae8f9..b27415b7c8e 100644 --- a/src/planner/operator/logical_distinct.cpp +++ b/src/planner/operator/logical_distinct.cpp @@ -1,5 +1,6 @@ #include "planner/operator/logical_distinct.h" +#include "binder/expression/expression_util.h" #include "planner/operator/factorization/flatten_resolver.h" namespace kuzu { @@ -8,7 +9,7 @@ namespace planner { void LogicalDistinct::computeFactorizedSchema() { createEmptySchema(); auto groupPos = schema->createGroup(); - for (auto& expression : getAllDistinctExpressions()) { + for (auto& expression : getKeysAndPayloads()) { schema->insertToGroupAndScope(expression, groupPos); } } @@ -16,7 +17,7 @@ void LogicalDistinct::computeFactorizedSchema() { void LogicalDistinct::computeFlatSchema() { createEmptySchema(); schema->createGroup(); - for (auto& expression : getAllDistinctExpressions()) { + for (auto& expression : getKeysAndPayloads()) { schema->insertToGroupAndScope(expression, 0); } } @@ -24,7 +25,7 @@ void LogicalDistinct::computeFlatSchema() { f_group_pos_set LogicalDistinct::getGroupsPosToFlatten() { f_group_pos_set dependentGroupsPos; auto childSchema = children[0]->getSchema(); - for (auto& expression : getAllDistinctExpressions()) { + for (auto& expression : getKeysAndPayloads()) { for (auto groupPos : childSchema->getDependentGroupsPos(expression)) { dependentGroupsPos.insert(groupPos); } @@ -33,10 +34,13 @@ f_group_pos_set LogicalDistinct::getGroupsPosToFlatten() { } std::string LogicalDistinct::getExpressionsForPrinting() const { - std::string result; - for (auto& expression : getAllDistinctExpressions()) { - result += expression->getUniqueName() + ", "; - } + return binder::ExpressionUtil::toString(getKeysAndPayloads()); +} + +binder::expression_vector LogicalDistinct::getKeysAndPayloads() const { + binder::expression_vector result; + result.insert(result.end(), keys.begin(), keys.end()); + result.insert(result.end(), payloads.begin(), payloads.end()); return result; } diff --git a/src/planner/operator/logical_mark_accumulate.cpp b/src/planner/operator/logical_mark_accumulate.cpp new file mode 100644 index 00000000000..01234d84f81 --- /dev/null +++ b/src/planner/operator/logical_mark_accumulate.cpp @@ -0,0 +1,46 @@ +#include "binder/expression/expression_util.h" +#include "planner/operator/factorization/flatten_resolver.h" +#include "planner/operator/factorization/sink_util.h" +#include "planner/operator/logical_mark_accmulate.h" + +using namespace kuzu::binder; + +namespace kuzu { +namespace planner { + +void LogicalMarkAccumulate::computeFactorizedSchema() { + createEmptySchema(); + auto childSchema = children[0]->getSchema(); + SinkOperatorUtil::recomputeSchema(*childSchema, childSchema->getExpressionsInScope(), *schema); + f_group_pos groupPos; + if (!keys.empty()) { + groupPos = schema->getGroupPos(*keys[0]); + } else { + groupPos = schema->createGroup(); + } + schema->insertToGroupAndScope(mark, groupPos); +} + +void LogicalMarkAccumulate::computeFlatSchema() { + copyChildSchema(0); + schema->insertToGroupAndScope(mark, 0); +} + +f_group_pos_set LogicalMarkAccumulate::getGroupsPosToFlatten() const { + f_group_pos_set dependentGroupsPos; + auto childSchema = children[0]->getSchema(); + for (auto& key : keys) { + for (auto groupPos : childSchema->getDependentGroupsPos(key)) { + dependentGroupsPos.insert(groupPos); + } + } + return factorization::FlattenAll::getGroupsPosToFlatten(dependentGroupsPos, childSchema); +} + +expression_vector LogicalMarkAccumulate::getPayloads() const { + auto exprs = children[0]->getSchema()->getExpressionsInScope(); + return ExpressionUtil::excludeExpressions(exprs, keys); +} + +} // namespace planner +} // namespace kuzu diff --git a/src/planner/operator/logical_operator.cpp b/src/planner/operator/logical_operator.cpp index fca71e61f0c..3bf8c27794e 100644 --- a/src/planner/operator/logical_operator.cpp +++ b/src/planner/operator/logical_operator.cpp @@ -63,6 +63,8 @@ std::string LogicalOperatorUtils::logicalOperatorTypeToString(LogicalOperatorTyp return "INSERT"; case LogicalOperatorType::LIMIT: return "LIMIT"; + case LogicalOperatorType::MARK_ACCUMULATE: + return "MARK_ACCUMULATE"; case LogicalOperatorType::MERGE: return "MERGE"; case LogicalOperatorType::MULTIPLICITY_REDUCER: diff --git a/src/planner/plan/CMakeLists.txt b/src/planner/plan/CMakeLists.txt index 5c658dec138..1c26c928621 100644 --- a/src/planner/plan/CMakeLists.txt +++ b/src/planner/plan/CMakeLists.txt @@ -14,6 +14,7 @@ add_library(kuzu_planner_plan_operator append_in_query_call.cpp append_join.cpp append_limit.cpp + append_mark_accumulate.cpp append_multiplicity_reducer.cpp append_order_by.cpp append_projection.cpp diff --git a/src/planner/plan/append_distinct.cpp b/src/planner/plan/append_distinct.cpp index ec6f620b36e..1fc8f9d06f3 100644 --- a/src/planner/plan/append_distinct.cpp +++ b/src/planner/plan/append_distinct.cpp @@ -6,8 +6,8 @@ using namespace kuzu::binder; namespace kuzu { namespace planner { -void Planner::appendDistinct(const expression_vector& expressionsToDistinct, LogicalPlan& plan) { - auto distinct = make_shared(expressionsToDistinct, plan.getLastOperator()); +void Planner::appendDistinct(const expression_vector& keys, LogicalPlan& plan) { + auto distinct = make_shared(keys, plan.getLastOperator()); appendFlattens(distinct->getGroupsPosToFlatten(), plan); distinct->setChild(0, plan.getLastOperator()); distinct->computeFactorizedSchema(); diff --git a/src/planner/plan/append_mark_accumulate.cpp b/src/planner/plan/append_mark_accumulate.cpp new file mode 100644 index 00000000000..feb67c9e1a5 --- /dev/null +++ b/src/planner/plan/append_mark_accumulate.cpp @@ -0,0 +1,20 @@ +#include "planner/operator/logical_mark_accmulate.h" +#include "planner/planner.h" + +using namespace kuzu::binder; + +namespace kuzu { +namespace planner { + +void Planner::appendMarkAccumulate( + const expression_vector& keys, std::shared_ptr mark, LogicalPlan& plan) { + auto markAccumulate = + std::make_shared(keys, mark, plan.getLastOperator()); + appendFlattens(markAccumulate->getGroupsPosToFlatten(), plan); + markAccumulate->setChild(0, plan.getLastOperator()); + markAccumulate->computeFactorizedSchema(); + plan.setLastOperator(std::move(markAccumulate)); +} + +} // namespace planner +} // namespace kuzu diff --git a/src/planner/plan/plan_read.cpp b/src/planner/plan/plan_read.cpp index a366e1c5d72..010d58789f3 100644 --- a/src/planner/plan/plan_read.cpp +++ b/src/planner/plan/plan_read.cpp @@ -50,7 +50,12 @@ void Planner::planMatchClause(const BoundReadingClause* boundReadingClause, } break; case MatchClauseType::OPTIONAL_MATCH: { for (auto& plan : plans) { - planOptionalMatch(*queryGraphCollection, predicates, *plan); + expression_vector corrExprs; + if (!plan->isEmpty()) { + corrExprs = + getCorrelatedExprs(*queryGraphCollection, predicates, plan->getSchema()); + } + planOptionalMatch(*queryGraphCollection, predicates, corrExprs, *plan); } } break; default: diff --git a/src/planner/plan/plan_subquery.cpp b/src/planner/plan/plan_subquery.cpp index a81e6c1a9d0..67da0e2332a 100644 --- a/src/planner/plan/plan_subquery.cpp +++ b/src/planner/plan/plan_subquery.cpp @@ -9,7 +9,7 @@ using namespace kuzu::common; namespace kuzu { namespace planner { -static expression_vector getCorrelatedExpressions(const QueryGraphCollection& collection, +binder::expression_vector Planner::getCorrelatedExprs(const QueryGraphCollection& collection, const expression_vector& predicates, Schema* outerSchema) { expression_vector result; for (auto& predicate : predicates) { @@ -26,7 +26,8 @@ static expression_vector getCorrelatedExpressions(const QueryGraphCollection& co } void Planner::planOptionalMatch(const QueryGraphCollection& queryGraphCollection, - const expression_vector& predicates, LogicalPlan& leftPlan) { + const expression_vector& predicates, const binder::expression_vector& corrExprs, + LogicalPlan& leftPlan) { if (leftPlan.isEmpty()) { // Optional match is the first clause. No left plan to join. auto plan = planQueryGraphCollection(queryGraphCollection, predicates); @@ -34,9 +35,7 @@ void Planner::planOptionalMatch(const QueryGraphCollection& queryGraphCollection appendAccumulate(AccumulateType::OPTIONAL_, leftPlan); return; } - auto correlatedExpressions = - getCorrelatedExpressions(queryGraphCollection, predicates, leftPlan.getSchema()); - if (correlatedExpressions.empty()) { + if (corrExprs.empty()) { // No join condition, apply cross product. auto rightPlan = planQueryGraphCollection(queryGraphCollection, predicates); if (leftPlan.hasUpdate()) { @@ -46,32 +45,32 @@ void Planner::planOptionalMatch(const QueryGraphCollection& queryGraphCollection } return; } - bool isInternalIDCorrelated = ExpressionUtil::isExpressionsWithDataType( - correlatedExpressions, LogicalTypeID::INTERNAL_ID); + bool isInternalIDCorrelated = + ExpressionUtil::isExpressionsWithDataType(corrExprs, LogicalTypeID::INTERNAL_ID); std::unique_ptr rightPlan; if (isInternalIDCorrelated) { // If all correlated expressions are node IDs. We can trivially unnest by scanning internal // ID in both outer and inner plan as these are fast in-memory operations. For node // properties, we only scan in the outer query. rightPlan = planQueryGraphCollectionInNewContext(SubqueryType::INTERNAL_ID_CORRELATED, - correlatedExpressions, leftPlan.getCardinality(), queryGraphCollection, predicates); + corrExprs, leftPlan.getCardinality(), queryGraphCollection, predicates); } else { // Unnest using ExpressionsScan which scans the accumulated table on probe side. - rightPlan = planQueryGraphCollectionInNewContext(SubqueryType::CORRELATED, - correlatedExpressions, leftPlan.getCardinality(), queryGraphCollection, predicates); - appendAccumulate(AccumulateType::REGULAR, correlatedExpressions, leftPlan); + rightPlan = planQueryGraphCollectionInNewContext(SubqueryType::CORRELATED, corrExprs, + leftPlan.getCardinality(), queryGraphCollection, predicates); + appendAccumulate(AccumulateType::REGULAR, corrExprs, leftPlan); } if (leftPlan.hasUpdate()) { throw RuntimeException(stringFormat("Optional match after update is not supported. Missing " "right outer join implementation.")); } - appendHashJoin(correlatedExpressions, JoinType::LEFT, leftPlan, *rightPlan, leftPlan); + appendHashJoin(corrExprs, JoinType::LEFT, leftPlan, *rightPlan, leftPlan); } void Planner::planRegularMatch(const QueryGraphCollection& queryGraphCollection, const expression_vector& predicates, LogicalPlan& leftPlan) { auto correlatedExpressions = - getCorrelatedExpressions(queryGraphCollection, predicates, leftPlan.getSchema()); + getCorrelatedExprs(queryGraphCollection, predicates, leftPlan.getSchema()); expression_vector predicatesToPushDown, predicatesToPullUp; // E.g. MATCH (a) WITH COUNT(*) AS s MATCH (b) WHERE b.age > s // "b.age > s" should be pulled up after both MATCH clauses are joined. diff --git a/src/planner/plan/plan_update.cpp b/src/planner/plan/plan_update.cpp index be64ccf2195..210d3e95d06 100644 --- a/src/planner/plan/plan_update.cpp +++ b/src/planner/plan/plan_update.cpp @@ -65,23 +65,18 @@ void Planner::planMergeClause(const BoundUpdatingClause* updatingClause, Logical if (mergeClause->hasPredicate()) { predicates = mergeClause->getPredicate()->splitOnAND(); } - planOptionalMatch(*mergeClause->getQueryGraphCollection(), predicates, plan); - std::shared_ptr mark; - auto& createInfos = mergeClause->getInsertInfosRef(); - KU_ASSERT(!createInfos.empty()); - auto& createInfo = createInfos[0]; - switch (createInfo.tableType) { - case TableType::NODE: { - auto node = (NodeExpression*)createInfo.pattern.get(); - mark = node->getInternalID(); - } break; - case TableType::REL: { - auto rel = (RelExpression*)createInfo.pattern.get(); - mark = rel->getInternalIDProperty(); - } break; - default: - KU_UNREACHABLE; + std::shared_ptr distinctMark = nullptr; + expression_vector corrExprs; + if (!plan.isEmpty()) { + distinctMark = mergeClause->getDistinctMark(); + corrExprs = getCorrelatedExprs( + *mergeClause->getQueryGraphCollection(), predicates, plan.getSchema()); + if (corrExprs.size() == 0) { + throw RuntimeException{"Constant key in merge clause is not supported yet."}; + } + appendMarkAccumulate(corrExprs, distinctMark, plan); } + planOptionalMatch(*mergeClause->getQueryGraphCollection(), predicates, corrExprs, plan); std::vector logicalInsertNodeInfos; if (mergeClause->hasInsertNodeInfo()) { auto boundInsertNodeInfos = mergeClause->getInsertNodeInfos(); @@ -119,10 +114,27 @@ void Planner::planMergeClause(const BoundUpdatingClause* updatingClause, Logical logicalOnMatchSetRelInfos.push_back(createLogicalSetPropertyInfo(info)); } } - auto merge = std::make_shared(mark, std::move(logicalInsertNodeInfos), - std::move(logicalInsertRelInfos), std::move(logicalOnCreateSetNodeInfos), - std::move(logicalOnCreateSetRelInfos), std::move(logicalOnMatchSetNodeInfos), - std::move(logicalOnMatchSetRelInfos), plan.getLastOperator()); + std::shared_ptr existenceMark; + auto& createInfos = mergeClause->getInsertInfosRef(); + KU_ASSERT(!createInfos.empty()); + auto& createInfo = createInfos[0]; + switch (createInfo.tableType) { + case TableType::NODE: { + auto node = (NodeExpression*)createInfo.pattern.get(); + existenceMark = node->getInternalID(); + } break; + case TableType::REL: { + auto rel = (RelExpression*)createInfo.pattern.get(); + existenceMark = rel->getInternalIDProperty(); + } break; + default: + KU_UNREACHABLE; + } + auto merge = std::make_shared(existenceMark, distinctMark, + std::move(logicalInsertNodeInfos), std::move(logicalInsertRelInfos), + std::move(logicalOnCreateSetNodeInfos), std::move(logicalOnCreateSetRelInfos), + std::move(logicalOnMatchSetNodeInfos), std::move(logicalOnMatchSetRelInfos), + plan.getLastOperator()); appendFlattens(merge->getGroupsPosToFlatten(), plan); merge->setChild(0, plan.getLastOperator()); merge->computeFactorizedSchema(); diff --git a/src/processor/map/CMakeLists.txt b/src/processor/map/CMakeLists.txt index 8df4e48f0f3..c7aa880fcdc 100644 --- a/src/processor/map/CMakeLists.txt +++ b/src/processor/map/CMakeLists.txt @@ -33,6 +33,7 @@ add_library(kuzu_processor_mapper map_intersect.cpp map_label_filter.cpp map_limit.cpp + map_mark_accumulate.cpp map_merge.cpp map_multiplicity_reducer.cpp map_order_by.cpp diff --git a/src/processor/map/map_accumulate.cpp b/src/processor/map/map_accumulate.cpp index 2cb4d64b994..78f4fb36a8e 100644 --- a/src/processor/map/map_accumulate.cpp +++ b/src/processor/map/map_accumulate.cpp @@ -13,7 +13,7 @@ std::unique_ptr PlanMapper::mapAccumulate(LogicalOperator* log auto outSchema = acc->getSchema(); auto inSchema = acc->getChild(0)->getSchema(); auto prevOperator = mapOperator(acc->getChild(0).get()); - auto expressions = acc->getExpressionsToAccumulate(); + auto expressions = acc->getPayloads(); auto resultCollector = createResultCollector( acc->getAccumulateType(), expressions, inSchema, std::move(prevOperator)); auto table = resultCollector->getResultFactorizedTable(); diff --git a/src/processor/map/map_aggregate.cpp b/src/processor/map/map_aggregate.cpp index 9b5ffc39f6c..cfe8c0d26e2 100644 --- a/src/processor/map/map_aggregate.cpp +++ b/src/processor/map/map_aggregate.cpp @@ -81,7 +81,7 @@ std::unique_ptr PlanMapper::mapAggregate(LogicalOperator* logi return createHashAggregate(logicalAggregate.getKeyExpressions(), logicalAggregate.getDependentKeyExpressions(), std::move(aggregateFunctions), std::move(aggregateInputInfos), std::move(aggregatesOutputPos), inSchema, outSchema, - std::move(prevOperator), paramsString); + std::move(prevOperator), paramsString, nullptr); } else { auto sharedState = make_shared(aggregateFunctions); auto aggregate = @@ -93,28 +93,66 @@ std::unique_ptr PlanMapper::mapAggregate(LogicalOperator* logi } } +static std::unique_ptr getFactorizedTableSchema( + const binder::expression_vector& flatKeys, const binder::expression_vector& unflatKeys, + const binder::expression_vector& payloads, + std::vector>& aggregateFunctions, + std::shared_ptr markExpression) { + auto isUnflat = false; + auto dataChunkPos = 0u; + std::unique_ptr tableSchema = std::make_unique(); + for (auto& flatKey : flatKeys) { + auto size = LogicalTypeUtils::getRowLayoutSize(flatKey->dataType); + tableSchema->appendColumn(std::make_unique(isUnflat, dataChunkPos, size)); + } + for (auto& unflatKey : unflatKeys) { + auto size = LogicalTypeUtils::getRowLayoutSize(unflatKey->dataType); + tableSchema->appendColumn(std::make_unique(isUnflat, dataChunkPos, size)); + } + for (auto& payload : payloads) { + auto size = LogicalTypeUtils::getRowLayoutSize(payload->dataType); + tableSchema->appendColumn(std::make_unique(isUnflat, dataChunkPos, size)); + } + for (auto& aggregateFunc : aggregateFunctions) { + tableSchema->appendColumn(std::make_unique( + isUnflat, dataChunkPos, aggregateFunc->getAggregateStateSize())); + } + if (markExpression != nullptr) { + tableSchema->appendColumn(std::make_unique( + isUnflat, dataChunkPos, LogicalTypeUtils::getRowLayoutSize(markExpression->dataType))); + } + tableSchema->appendColumn( + std::make_unique(isUnflat, dataChunkPos, sizeof(hash_t))); + return tableSchema; +} + std::unique_ptr PlanMapper::createHashAggregate( - const binder::expression_vector& keyExpressions, - const binder::expression_vector& dependentKeyExpressions, + const binder::expression_vector& keys, const binder::expression_vector& payloads, std::vector> aggregateFunctions, std::vector> aggregateInputInfos, std::vector aggregatesOutputPos, planner::Schema* inSchema, planner::Schema* outSchema, - std::unique_ptr prevOperator, const std::string& paramsString) { + std::unique_ptr prevOperator, const std::string& paramsString, + std::shared_ptr markExpression) { auto sharedState = make_shared(aggregateFunctions); - auto flatKeyExpressions = getKeyExpressions(keyExpressions, *inSchema, true /* isFlat */); - auto unFlatKeyExpressions = getKeyExpressions(keyExpressions, *inSchema, false /* isFlat */); + auto flatKeys = getKeyExpressions(keys, *inSchema, true /* isFlat */); + auto unFlatKeys = getKeyExpressions(keys, *inSchema, false /* isFlat */); + auto tableSchema = getFactorizedTableSchema( + flatKeys, unFlatKeys, payloads, aggregateFunctions, markExpression); + HashAggregateInfo aggregateInfo{getExpressionsDataPos(flatKeys, *inSchema), + getExpressionsDataPos(unFlatKeys, *inSchema), getExpressionsDataPos(payloads, *inSchema), + std::move(tableSchema), + markExpression == nullptr ? HashTableType::AGGREGATE_HASH_TABLE : + HashTableType::MARK_HASH_TABLE}; auto aggregate = make_unique(std::make_unique(inSchema), - sharedState, getExpressionsDataPos(flatKeyExpressions, *inSchema), - getExpressionsDataPos(unFlatKeyExpressions, *inSchema), - getExpressionsDataPos(dependentKeyExpressions, *inSchema), std::move(aggregateFunctions), + sharedState, std::move(aggregateInfo), std::move(aggregateFunctions), std::move(aggregateInputInfos), std::move(prevOperator), getOperatorID(), paramsString); binder::expression_vector outputExpressions; - outputExpressions.insert( - outputExpressions.end(), flatKeyExpressions.begin(), flatKeyExpressions.end()); - outputExpressions.insert( - outputExpressions.end(), unFlatKeyExpressions.begin(), unFlatKeyExpressions.end()); - outputExpressions.insert( - outputExpressions.end(), dependentKeyExpressions.begin(), dependentKeyExpressions.end()); + outputExpressions.insert(outputExpressions.end(), flatKeys.begin(), flatKeys.end()); + outputExpressions.insert(outputExpressions.end(), unFlatKeys.begin(), unFlatKeys.end()); + outputExpressions.insert(outputExpressions.end(), payloads.begin(), payloads.end()); + if (markExpression != nullptr) { + outputExpressions.emplace_back(markExpression); + } return std::make_unique(sharedState, getExpressionsDataPos(outputExpressions, *outSchema), std::move(aggregatesOutputPos), std::move(aggregate), getOperatorID(), paramsString); diff --git a/src/processor/map/map_distinct.cpp b/src/processor/map/map_distinct.cpp index 67eb1ab37c4..a1de19d1a49 100644 --- a/src/processor/map/map_distinct.cpp +++ b/src/processor/map/map_distinct.cpp @@ -16,10 +16,10 @@ std::unique_ptr PlanMapper::mapDistinct(LogicalOperator* logic std::vector> emptyAggFunctions; std::vector> emptyAggInputInfos; std::vector emptyAggregatesOutputPos; - return createHashAggregate(logicalDistinct.getKeyExpressions(), - logicalDistinct.getDependentKeyExpressions(), std::move(emptyAggFunctions), - std::move(emptyAggInputInfos), std::move(emptyAggregatesOutputPos), inSchema, outSchema, - std::move(prevOperator), logicalDistinct.getExpressionsForPrinting()); + return createHashAggregate(logicalDistinct.getKeys(), logicalDistinct.getPayloads(), + std::move(emptyAggFunctions), std::move(emptyAggInputInfos), + std::move(emptyAggregatesOutputPos), inSchema, outSchema, std::move(prevOperator), + logicalDistinct.getExpressionsForPrinting(), nullptr /* markExpression */); } } // namespace processor diff --git a/src/processor/map/map_expressions_scan.cpp b/src/processor/map/map_expressions_scan.cpp index 123a3d564a5..4a6558ee012 100644 --- a/src/processor/map/map_expressions_scan.cpp +++ b/src/processor/map/map_expressions_scan.cpp @@ -14,7 +14,7 @@ std::unique_ptr PlanMapper::mapExpressionsScan( auto expressionsScan = (planner::LogicalExpressionsScan*)logicalOperator; auto outerAccumulate = (planner::LogicalAccumulate*)expressionsScan->getOuterAccumulate(); expression_map materializedExpressionToColIdx; - auto materializedExpressions = outerAccumulate->getExpressionsToAccumulate(); + auto materializedExpressions = outerAccumulate->getPayloads(); for (auto i = 0u; i < materializedExpressions.size(); ++i) { materializedExpressionToColIdx.insert({materializedExpressions[i], i}); } diff --git a/src/processor/map/map_mark_accumulate.cpp b/src/processor/map/map_mark_accumulate.cpp new file mode 100644 index 00000000000..96268c08c7d --- /dev/null +++ b/src/processor/map/map_mark_accumulate.cpp @@ -0,0 +1,26 @@ +#include "planner/operator/logical_mark_accmulate.h" +#include "processor/operator/aggregate/aggregate_input.h" +#include "processor/plan_mapper.h" + +using namespace kuzu::planner; +using namespace kuzu::common; + +namespace kuzu { +namespace processor { + +std::unique_ptr PlanMapper::mapMarkAccumulate(LogicalOperator* op) { + auto logicalMarkAccumulate = ku_dynamic_cast(op); + auto keys = logicalMarkAccumulate->getKeys(); + auto payloads = logicalMarkAccumulate->getPayloads(); + auto outSchema = logicalMarkAccumulate->getSchema(); + auto inSchema = logicalMarkAccumulate->getChild(0)->getSchema(); + auto prevOperator = mapOperator(logicalMarkAccumulate->getChild(0).get()); + return createHashAggregate(keys, payloads, + std::vector>{}, + std::vector>{}, std::vector{}, inSchema, + outSchema, std::move(prevOperator), logicalMarkAccumulate->getExpressionsForPrinting(), + logicalMarkAccumulate->getMark()); +} + +} // namespace processor +} // namespace kuzu diff --git a/src/processor/map/map_merge.cpp b/src/processor/map/map_merge.cpp index 2bcc375946f..4b1b4ed3a83 100644 --- a/src/processor/map/map_merge.cpp +++ b/src/processor/map/map_merge.cpp @@ -12,7 +12,12 @@ std::unique_ptr PlanMapper::mapMerge(planner::LogicalOperator* auto outSchema = logicalMerge->getSchema(); auto inSchema = logicalMerge->getChild(0)->getSchema(); auto prevOperator = mapOperator(logicalOperator->getChild(0).get()); - auto markPos = DataPos(inSchema->getExpressionPos(*logicalMerge->getMark())); + auto existenceMarkPos = getDataPos(*logicalMerge->getExistenceMark(), *inSchema); + auto distinctMarkPos = DataPos(); + if (logicalMerge + ->hasDistinctMark()) { // If there is no distinct mark, then every input is distinct. + distinctMarkPos = getDataPos(*logicalMerge->getDistinctMark(), *inSchema); + } std::vector nodeInsertExecutors; for (auto& info : logicalMerge->getInsertNodeInfosRef()) { nodeInsertExecutors.push_back(getNodeInsertExecutor(&info, *inSchema, *outSchema)->copy()); @@ -37,11 +42,11 @@ std::unique_ptr PlanMapper::mapMerge(planner::LogicalOperator* for (auto& info : logicalMerge->getOnMatchSetRelInfosRef()) { onMatchRelSetExecutors.push_back(getRelSetExecutor(info.get(), *inSchema)); } - return std::make_unique(markPos, std::move(nodeInsertExecutors), - std::move(relInsertExecutors), std::move(onCreateNodeSetExecutors), - std::move(onCreateRelSetExecutors), std::move(onMatchNodeSetExecutors), - std::move(onMatchRelSetExecutors), std::move(prevOperator), getOperatorID(), - logicalMerge->getExpressionsForPrinting()); + return std::make_unique(existenceMarkPos, distinctMarkPos, + std::move(nodeInsertExecutors), std::move(relInsertExecutors), + std::move(onCreateNodeSetExecutors), std::move(onCreateRelSetExecutors), + std::move(onMatchNodeSetExecutors), std::move(onMatchRelSetExecutors), + std::move(prevOperator), getOperatorID(), logicalMerge->getExpressionsForPrinting()); } } // namespace processor diff --git a/src/processor/map/plan_mapper.cpp b/src/processor/map/plan_mapper.cpp index cd52ed24967..366b072dc48 100644 --- a/src/processor/map/plan_mapper.cpp +++ b/src/processor/map/plan_mapper.cpp @@ -108,6 +108,9 @@ std::unique_ptr PlanMapper::mapOperator(LogicalOperator* logic case LogicalOperatorType::ACCUMULATE: { physicalOperator = mapAccumulate(logicalOperator); } break; + case LogicalOperatorType::MARK_ACCUMULATE: { + physicalOperator = mapMarkAccumulate(logicalOperator); + } break; case LogicalOperatorType::DUMMY_SCAN: { physicalOperator = mapDummyScan(logicalOperator); } break; diff --git a/src/processor/operator/aggregate/aggregate_hash_table.cpp b/src/processor/operator/aggregate/aggregate_hash_table.cpp index 1b29edc6e4a..d5870141d7f 100644 --- a/src/processor/operator/aggregate/aggregate_hash_table.cpp +++ b/src/processor/operator/aggregate/aggregate_hash_table.cpp @@ -13,10 +13,10 @@ namespace processor { AggregateHashTable::AggregateHashTable(MemoryManager& memoryManager, std::vector keyDataTypes, std::vector dependentKeyDataTypes, const std::vector>& aggregateFunctions, - uint64_t numEntriesToAllocate) + uint64_t numEntriesToAllocate, std::unique_ptr tableSchema) : BaseHashTable{memoryManager, std::move(keyDataTypes)}, dependentKeyDataTypes{ std::move(dependentKeyDataTypes)} { - initializeFT(aggregateFunctions); + initializeFT(aggregateFunctions, std::move(tableSchema)); initializeHashTable(numEntriesToAllocate); distinctHashTables = AggregateHashTableUtils::createDistinctHashTables( memoryManager, this->keyTypes, this->aggregateFunctions); @@ -127,20 +127,14 @@ void AggregateHashTable::finalizeAggregateStates() { } void AggregateHashTable::initializeFT( - const std::vector>& aggFuncs) { - auto isUnflat = false; - auto dataChunkPos = 0u; - std::unique_ptr tableSchema = std::make_unique(); + const std::vector>& aggFuncs, + std::unique_ptr tableSchema) { aggStateColIdxInFT = keyTypes.size() + dependentKeyDataTypes.size(); for (auto& dataType : keyTypes) { - auto size = LogicalTypeUtils::getRowLayoutSize(dataType); - tableSchema->appendColumn(std::make_unique(isUnflat, dataChunkPos, size)); - numBytesForKeys += size; + numBytesForKeys += LogicalTypeUtils::getRowLayoutSize(dataType); } for (auto& dataType : dependentKeyDataTypes) { - auto size = LogicalTypeUtils::getRowLayoutSize(dataType); - tableSchema->appendColumn(std::make_unique(isUnflat, dataChunkPos, size)); - numBytesForDependentKeys += size; + numBytesForDependentKeys += LogicalTypeUtils::getRowLayoutSize(dataType); } aggStateColOffsetInFT = numBytesForKeys + numBytesForDependentKeys; @@ -148,16 +142,12 @@ void AggregateHashTable::initializeFT( updateAggFuncs.reserve(aggFuncs.size()); for (auto i = 0u; i < aggFuncs.size(); i++) { auto& aggFunc = aggFuncs[i]; - tableSchema->appendColumn(std::make_unique( - isUnflat, dataChunkPos, aggFunc->getAggregateStateSize())); aggregateFunctions.push_back(aggFunc->clone()); updateAggFuncs.push_back(aggFunc->isFunctionDistinct() ? &AggregateHashTable::updateDistinctAggState : &AggregateHashTable::updateAggState); } - tableSchema->appendColumn( - std::make_unique(isUnflat, dataChunkPos, sizeof(hash_t))); - hashColIdxInFT = aggStateColIdxInFT + aggFuncs.size(); + hashColIdxInFT = tableSchema->getNumColumns() - 1; hashColOffsetInFT = tableSchema->getColOffset(hashColIdxInFT); factorizedTable = std::make_unique(&memoryManager, std::move(tableSchema)); } @@ -212,6 +202,140 @@ void AggregateHashTable::resize(uint64_t newSize) { } } +uint64_t AggregateHashTable::matchFTEntries(const std::vector& flatKeyVectors, + const std::vector& unFlatKeyVectors, uint64_t numMayMatches, + uint64_t numNoMatches) { + auto colIdx = 0u; + for (auto& flatKeyVector : flatKeyVectors) { + numMayMatches = + matchFlatVecWithFTColumn(flatKeyVector, numMayMatches, numNoMatches, colIdx++); + } + for (auto& unFlatKeyVector : unFlatKeyVectors) { + numMayMatches = + matchUnFlatVecWithFTColumn(unFlatKeyVector, numMayMatches, numNoMatches, colIdx++); + } + return numNoMatches; +} + +void AggregateHashTable::initializeFTEntries(const std::vector& flatKeyVectors, + const std::vector& unFlatKeyVectors, + const std::vector& dependentKeyVectors, uint64_t numFTEntriesToInitialize) { + auto colIdx = 0u; + for (auto flatKeyVector : flatKeyVectors) { + initializeFTEntryWithFlatVec(flatKeyVector, numFTEntriesToInitialize, colIdx++); + } + for (auto unFlatKeyVector : unFlatKeyVectors) { + initializeFTEntryWithUnFlatVec(unFlatKeyVector, numFTEntriesToInitialize, colIdx++); + } + for (auto dependentKeyVector : dependentKeyVectors) { + if (dependentKeyVector->state->isFlat()) { + initializeFTEntryWithFlatVec(dependentKeyVector, numFTEntriesToInitialize, colIdx++); + } else { + initializeFTEntryWithUnFlatVec(dependentKeyVector, numFTEntriesToInitialize, colIdx++); + } + } + for (auto i = 0u; i < numFTEntriesToInitialize; i++) { + auto entryIdx = entryIdxesToInitialize[i]; + auto entry = hashSlotsToUpdateAggState[entryIdx]->entry; + fillEntryWithInitialNullAggregateState(entry); + // Fill the hashValue in the ftEntry. + factorizedTable->updateFlatCellNoNull(entry, hashColIdxInFT, + hashVector->getData() + hashVector->getNumBytesPerValue() * entryIdx); + } +} + +uint64_t AggregateHashTable::matchUnFlatVecWithFTColumn( + ValueVector* vector, uint64_t numMayMatches, uint64_t& numNoMatches, uint32_t colIdx) { + KU_ASSERT(!vector->state->isFlat()); + auto colOffset = factorizedTable->getTableSchema()->getColOffset(colIdx); + uint64_t mayMatchIdx = 0; + if (vector->hasNoNullsGuarantee()) { + if (factorizedTable->hasNoNullGuarantee(colIdx)) { + for (auto i = 0u; i < numMayMatches; i++) { + auto idx = mayMatchIdxes[i]; + if (compareEntryFuncs[colIdx]( + vector, idx, hashSlotsToUpdateAggState[idx]->entry + colOffset)) { + mayMatchIdxes[mayMatchIdx++] = idx; + } else { + noMatchIdxes[numNoMatches++] = idx; + } + } + } else { + for (auto i = 0u; i < numMayMatches; i++) { + auto idx = mayMatchIdxes[i]; + auto isEntryKeyNull = factorizedTable->isNonOverflowColNull( + hashSlotsToUpdateAggState[idx]->entry + + factorizedTable->getTableSchema()->getNullMapOffset(), + colIdx); + if (isEntryKeyNull) { + noMatchIdxes[numNoMatches++] = idx; + continue; + } + if (compareEntryFuncs[colIdx]( + vector, idx, hashSlotsToUpdateAggState[idx]->entry + colOffset)) { + mayMatchIdxes[mayMatchIdx++] = idx; + } else { + noMatchIdxes[numNoMatches++] = idx; + } + } + } + } else { + for (auto i = 0u; i < numMayMatches; i++) { + auto idx = mayMatchIdxes[i]; + auto isKeyVectorNull = vector->isNull(idx); + auto isEntryKeyNull = factorizedTable->isNonOverflowColNull( + hashSlotsToUpdateAggState[idx]->entry + + factorizedTable->getTableSchema()->getNullMapOffset(), + colIdx); + if (isKeyVectorNull && isEntryKeyNull) { + mayMatchIdxes[mayMatchIdx++] = idx; + continue; + } else if (isKeyVectorNull != isEntryKeyNull) { + noMatchIdxes[numNoMatches++] = idx; + continue; + } + + if (compareEntryFuncs[colIdx]( + vector, idx, hashSlotsToUpdateAggState[idx]->entry + colOffset)) { + mayMatchIdxes[mayMatchIdx++] = idx; + } else { + noMatchIdxes[numNoMatches++] = idx; + } + } + } + return mayMatchIdx; +} + +uint64_t AggregateHashTable::matchFlatVecWithFTColumn( + ValueVector* vector, uint64_t numMayMatches, uint64_t& numNoMatches, uint32_t colIdx) { + KU_ASSERT(vector->state->isFlat()); + auto colOffset = factorizedTable->getTableSchema()->getColOffset(colIdx); + uint64_t mayMatchIdx = 0; + auto pos = vector->state->selVector->selectedPositions[0]; + auto isVectorNull = vector->isNull(pos); + for (auto i = 0u; i < numMayMatches; i++) { + auto idx = mayMatchIdxes[i]; + auto isEntryKeyNull = factorizedTable->isNonOverflowColNull( + hashSlotsToUpdateAggState[idx]->entry + + factorizedTable->getTableSchema()->getNullMapOffset(), + colIdx); + if (isEntryKeyNull && isVectorNull) { + mayMatchIdxes[mayMatchIdx++] = idx; + continue; + } else if (isEntryKeyNull != isVectorNull) { + noMatchIdxes[numNoMatches++] = idx; + continue; + } + if (compareEntryFuncs[colIdx]( + vector, pos, hashSlotsToUpdateAggState[idx]->entry + colOffset)) { + mayMatchIdxes[mayMatchIdx++] = idx; + } else { + noMatchIdxes[numNoMatches++] = idx; + } + } + return mayMatchIdx; +} + void AggregateHashTable::initializeFTEntryWithFlatVec( ValueVector* flatVector, uint64_t numEntriesToInitialize, uint32_t colIdx) { KU_ASSERT(flatVector->state->isFlat()); @@ -254,33 +378,6 @@ void AggregateHashTable::initializeFTEntryWithUnFlatVec( } } -void AggregateHashTable::initializeFTEntries(const std::vector& flatKeyVectors, - const std::vector& unFlatKeyVectors, - const std::vector& dependentKeyVectors, uint64_t numFTEntriesToInitialize) { - auto colIdx = 0u; - for (auto flatKeyVector : flatKeyVectors) { - initializeFTEntryWithFlatVec(flatKeyVector, numFTEntriesToInitialize, colIdx++); - } - for (auto unFlatKeyVector : unFlatKeyVectors) { - initializeFTEntryWithUnFlatVec(unFlatKeyVector, numFTEntriesToInitialize, colIdx++); - } - for (auto dependentKeyVector : dependentKeyVectors) { - if (dependentKeyVector->state->isFlat()) { - initializeFTEntryWithFlatVec(dependentKeyVector, numFTEntriesToInitialize, colIdx++); - } else { - initializeFTEntryWithUnFlatVec(dependentKeyVector, numFTEntriesToInitialize, colIdx++); - } - } - for (auto i = 0u; i < numFTEntriesToInitialize; i++) { - auto entryIdx = entryIdxesToInitialize[i]; - auto entry = hashSlotsToUpdateAggState[entryIdx]->entry; - fillEntryWithInitialNullAggregateState(entry); - // Fill the hashValue in the ftEntry. - factorizedTable->updateFlatCellNoNull(entry, hashColIdxInFT, - hashVector->getData() + hashVector->getNumBytesPerValue() * entryIdx); - } -} - uint8_t* AggregateHashTable::createEntryInDistinctHT( const std::vector& groupByHashKeyVectors, hash_t hash) { auto entry = factorizedTable->appendEmptyTuple(); @@ -447,113 +544,6 @@ bool AggregateHashTable::matchFlatGroupByKeys( return true; } -uint64_t AggregateHashTable::matchUnFlatVecWithFTColumn( - ValueVector* vector, uint64_t numMayMatches, uint64_t& numNoMatches, uint32_t colIdx) { - KU_ASSERT(!vector->state->isFlat()); - auto colOffset = factorizedTable->getTableSchema()->getColOffset(colIdx); - uint64_t mayMatchIdx = 0; - if (vector->hasNoNullsGuarantee()) { - if (factorizedTable->hasNoNullGuarantee(colIdx)) { - for (auto i = 0u; i < numMayMatches; i++) { - auto idx = mayMatchIdxes[i]; - if (compareEntryFuncs[colIdx]( - vector, idx, hashSlotsToUpdateAggState[idx]->entry + colOffset)) { - mayMatchIdxes[mayMatchIdx++] = idx; - } else { - noMatchIdxes[numNoMatches++] = idx; - } - } - } else { - for (auto i = 0u; i < numMayMatches; i++) { - auto idx = mayMatchIdxes[i]; - auto isEntryKeyNull = factorizedTable->isNonOverflowColNull( - hashSlotsToUpdateAggState[idx]->entry + - factorizedTable->getTableSchema()->getNullMapOffset(), - colIdx); - if (isEntryKeyNull) { - noMatchIdxes[numNoMatches++] = idx; - continue; - } - if (compareEntryFuncs[colIdx]( - vector, idx, hashSlotsToUpdateAggState[idx]->entry + colOffset)) { - mayMatchIdxes[mayMatchIdx++] = idx; - } else { - noMatchIdxes[numNoMatches++] = idx; - } - } - } - } else { - for (auto i = 0u; i < numMayMatches; i++) { - auto idx = mayMatchIdxes[i]; - auto isKeyVectorNull = vector->isNull(idx); - auto isEntryKeyNull = factorizedTable->isNonOverflowColNull( - hashSlotsToUpdateAggState[idx]->entry + - factorizedTable->getTableSchema()->getNullMapOffset(), - colIdx); - if (isKeyVectorNull && isEntryKeyNull) { - mayMatchIdxes[mayMatchIdx++] = idx; - continue; - } else if (isKeyVectorNull != isEntryKeyNull) { - noMatchIdxes[numNoMatches++] = idx; - continue; - } - - if (compareEntryFuncs[colIdx]( - vector, idx, hashSlotsToUpdateAggState[idx]->entry + colOffset)) { - mayMatchIdxes[mayMatchIdx++] = idx; - } else { - noMatchIdxes[numNoMatches++] = idx; - } - } - } - return mayMatchIdx; -} - -uint64_t AggregateHashTable::matchFlatVecWithFTColumn( - ValueVector* vector, uint64_t numMayMatches, uint64_t& numNoMatches, uint32_t colIdx) { - KU_ASSERT(vector->state->isFlat()); - auto colOffset = factorizedTable->getTableSchema()->getColOffset(colIdx); - uint64_t mayMatchIdx = 0; - auto pos = vector->state->selVector->selectedPositions[0]; - auto isVectorNull = vector->isNull(pos); - for (auto i = 0u; i < numMayMatches; i++) { - auto idx = mayMatchIdxes[i]; - auto isEntryKeyNull = factorizedTable->isNonOverflowColNull( - hashSlotsToUpdateAggState[idx]->entry + - factorizedTable->getTableSchema()->getNullMapOffset(), - colIdx); - if (isEntryKeyNull && isVectorNull) { - mayMatchIdxes[mayMatchIdx++] = idx; - continue; - } else if (isEntryKeyNull != isVectorNull) { - noMatchIdxes[numNoMatches++] = idx; - continue; - } - if (compareEntryFuncs[colIdx]( - vector, pos, hashSlotsToUpdateAggState[idx]->entry + colOffset)) { - mayMatchIdxes[mayMatchIdx++] = idx; - } else { - noMatchIdxes[numNoMatches++] = idx; - } - } - return mayMatchIdx; -} - -uint64_t AggregateHashTable::matchFTEntries(const std::vector& flatKeyVectors, - const std::vector& unFlatKeyVectors, uint64_t numMayMatches, - uint64_t numNoMatches) { - auto colIdx = 0u; - for (auto& flatKeyVector : flatKeyVectors) { - numMayMatches = - matchFlatVecWithFTColumn(flatKeyVector, numMayMatches, numNoMatches, colIdx++); - } - for (auto& unFlatKeyVector : unFlatKeyVectors) { - numMayMatches = - matchUnFlatVecWithFTColumn(unFlatKeyVector, numMayMatches, numNoMatches, colIdx++); - } - return numNoMatches; -} - void AggregateHashTable::fillEntryWithInitialNullAggregateState(uint8_t* entry) { for (auto i = 0u; i < aggregateFunctions.size(); i++) { factorizedTable->updateFlatCellNoNull(entry, aggStateColIdxInFT + i, @@ -766,18 +756,30 @@ void AggregateHashTable::updateBothUnFlatDifferentDCAggVectorState( std::vector> AggregateHashTableUtils::createDistinctHashTables( MemoryManager& memoryManager, const std::vector& groupByKeyDataTypes, const std::vector>& aggregateFunctions) { + // TODO(Xiyang): move the creation of distinct hashtable schema to mapper. std::vector> distinctHTs; for (auto& aggregateFunction : aggregateFunctions) { if (aggregateFunction->isFunctionDistinct()) { std::vector distinctKeysDataTypes(groupByKeyDataTypes.size() + 1); + auto tableSchema = std::make_unique(); for (auto i = 0u; i < groupByKeyDataTypes.size(); i++) { distinctKeysDataTypes[i] = groupByKeyDataTypes[i]; + auto size = LogicalTypeUtils::getRowLayoutSize(distinctKeysDataTypes[i]); + tableSchema->appendColumn(std::make_unique( + false /* isUnflat */, 0 /* dataChunkPos */, size)); } distinctKeysDataTypes[groupByKeyDataTypes.size()] = LogicalType{aggregateFunction->parameterTypeIDs[0]}; + tableSchema->appendColumn( + std::make_unique(false /* isUnflat */, 0 /* dataChunkPos */, + LogicalTypeUtils::getRowLayoutSize( + LogicalType{aggregateFunction->parameterTypeIDs[0]}))); + tableSchema->appendColumn(std::make_unique( + false /* isUnflat */, 0 /* dataChunkPos */, sizeof(hash_t))); std::vector> emptyFunctions; auto ht = std::make_unique(memoryManager, - std::move(distinctKeysDataTypes), emptyFunctions, 0 /* numEntriesToAllocate */); + std::move(distinctKeysDataTypes), emptyFunctions, 0 /* numEntriesToAllocate */, + std::move(tableSchema)); distinctHTs.push_back(std::move(ht)); } else { distinctHTs.push_back(nullptr); diff --git a/src/processor/operator/aggregate/hash_aggregate.cpp b/src/processor/operator/aggregate/hash_aggregate.cpp index 1fc641d76f2..6d27dd7c9bf 100644 --- a/src/processor/operator/aggregate/hash_aggregate.cpp +++ b/src/processor/operator/aggregate/hash_aggregate.cpp @@ -1,6 +1,7 @@ #include "processor/operator/aggregate/hash_aggregate.h" #include "common/utils.h" +#include "processor/result/mark_hash_table.h" using namespace kuzu::common; using namespace kuzu::function; @@ -49,38 +50,70 @@ std::pair HashAggregateSharedState::getNextRangeToRead() { return std::make_pair(startOffset, startOffset + range); } -void HashAggregate::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) { - BaseAggregate::initLocalStateInternal(resultSet, context); +HashAggregateInfo::HashAggregateInfo(std::vector flatKeysPos, + std::vector unFlatKeysPos, std::vector dependentKeysPos, + std::unique_ptr tableSchema, HashTableType hashTableType) + : flatKeysPos{std::move(flatKeysPos)}, unFlatKeysPos{std::move(unFlatKeysPos)}, + dependentKeysPos{std::move(dependentKeysPos)}, tableSchema{std::move(tableSchema)}, + hashTableType{hashTableType} {} + +HashAggregateInfo::HashAggregateInfo(const HashAggregateInfo& other) + : flatKeysPos{other.flatKeysPos}, unFlatKeysPos{other.unFlatKeysPos}, + dependentKeysPos{other.dependentKeysPos}, tableSchema{other.tableSchema->copy()}, + hashTableType{other.hashTableType} {} + +void HashAggregateLocalState::init(ResultSet& resultSet, main::ClientContext* context, + HashAggregateInfo& info, + std::vector>& aggregateFunctions) { std::vector keyDataTypes; - for (auto& pos : flatKeysPos) { - auto vector = resultSet->getValueVector(pos).get(); + for (auto& pos : info.flatKeysPos) { + auto vector = resultSet.getValueVector(pos).get(); flatKeyVectors.push_back(vector); keyDataTypes.push_back(vector->dataType); } - for (auto& pos : unFlatKeysPos) { - auto vector = resultSet->getValueVector(pos).get(); + for (auto& pos : info.unFlatKeysPos) { + auto vector = resultSet.getValueVector(pos).get(); unFlatKeyVectors.push_back(vector); keyDataTypes.push_back(vector->dataType); } std::vector payloadDataTypes; - for (auto& pos : dependentKeysPos) { - auto vector = resultSet->getValueVector(pos).get(); + for (auto& pos : info.dependentKeysPos) { + auto vector = resultSet.getValueVector(pos).get(); dependentKeyVectors.push_back(vector); payloadDataTypes.push_back(vector->dataType); } leadingState = unFlatKeyVectors.empty() ? flatKeyVectors[0]->state.get() : unFlatKeyVectors[0]->state.get(); - localAggregateHashTable = - make_unique(*context->clientContext->getMemoryManager(), keyDataTypes, - payloadDataTypes, aggregateFunctions, 0); + switch (info.hashTableType) { + case HashTableType::AGGREGATE_HASH_TABLE: + aggregateHashTable = std::make_unique(*context->getMemoryManager(), + keyDataTypes, payloadDataTypes, aggregateFunctions, 0, std::move(info.tableSchema)); + break; + case HashTableType::MARK_HASH_TABLE: + aggregateHashTable = std::make_unique(*context->getMemoryManager(), + keyDataTypes, payloadDataTypes, aggregateFunctions, 0, std::move(info.tableSchema)); + break; + default: + KU_UNREACHABLE; + } +} + +void HashAggregateLocalState::append( + std::vector>& aggregateInputs, uint64_t multiplicity) const { + aggregateHashTable->append(flatKeyVectors, unFlatKeyVectors, dependentKeyVectors, leadingState, + aggregateInputs, multiplicity); +} + +void HashAggregate::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) { + BaseAggregate::initLocalStateInternal(resultSet, context); + localState.init(*resultSet, context->clientContext, aggregateInfo, aggregateFunctions); } void HashAggregate::executeInternal(ExecutionContext* context) { while (children[0]->getNextTuple(context)) { - localAggregateHashTable->append(flatKeyVectors, unFlatKeyVectors, dependentKeyVectors, - leadingState, aggregateInputs, resultSet->multiplicity); + localState.append(aggregateInputs, resultSet->multiplicity); } - sharedState->appendAggregateHashTable(std::move(localAggregateHashTable)); + sharedState->appendAggregateHashTable(std::move(localState.aggregateHashTable)); } void HashAggregate::finalize(ExecutionContext* context) { diff --git a/src/processor/operator/persistent/insert_executor.cpp b/src/processor/operator/persistent/insert_executor.cpp index 65e6b25a398..c99d253c8b8 100644 --- a/src/processor/operator/persistent/insert_executor.cpp +++ b/src/processor/operator/persistent/insert_executor.cpp @@ -60,6 +60,14 @@ void NodeInsertExecutor::insert(Transaction* tx, ExecutionContext* context) { writeResult(); } +void NodeInsertExecutor::evaluateResult(ExecutionContext* context) { + for (auto& evaluator : columnDataEvaluators) { + evaluator->evaluate(context->clientContext); + } + nodeIDVector->setNull(nodeIDVector->state->selVector->selectedPositions[0], false); + writeResult(); +} + bool NodeInsertExecutor::checkConfict(Transaction* transaction) { if (conflictAction == ConflictAction::ON_CONFLICT_DO_NOTHING) { auto off = table->validateUniquenessConstraint(transaction, columnDataVectors); diff --git a/src/processor/operator/persistent/merge.cpp b/src/processor/operator/persistent/merge.cpp index 5f179a5b8c8..ccf66d830db 100644 --- a/src/processor/operator/persistent/merge.cpp +++ b/src/processor/operator/persistent/merge.cpp @@ -4,7 +4,10 @@ namespace kuzu { namespace processor { void Merge::initLocalStateInternal(ResultSet* /*resultSet_*/, ExecutionContext* context) { - markVector = resultSet->getValueVector(markPos).get(); + existenceVector = resultSet->getValueVector(existenceMark).get(); + if (distinctMark.isValid()) { + distinctVector = resultSet->getValueVector(distinctMark).get(); + } for (auto& executor : nodeInsertExecutors) { executor.init(resultSet, context); } @@ -29,9 +32,9 @@ bool Merge::getNextTuplesInternal(ExecutionContext* context) { if (!children[0]->getNextTuple(context)) { return false; } - KU_ASSERT(markVector->state->isFlat()); - auto pos = markVector->state->selVector->selectedPositions[0]; - if (!markVector->isNull(pos)) { + KU_ASSERT(existenceVector->state->isFlat()); + auto existencePos = existenceVector->state->selVector->selectedPositions[0]; + if (!existenceVector->isNull(existencePos)) { for (auto& executor : onMatchNodeSetExecutors) { executor->set(context); } @@ -39,17 +42,37 @@ bool Merge::getNextTuplesInternal(ExecutionContext* context) { executor->set(context); } } else { - for (auto& executor : nodeInsertExecutors) { - executor.insert(context->clientContext->getTx(), context); - } - for (auto& executor : relInsertExecutors) { - executor.insert(context->clientContext->getTx(), context); - } - for (auto& executor : onCreateNodeSetExecutors) { - executor->set(context); - } - for (auto& executor : onCreateRelSetExecutors) { - executor->set(context); + // pattern not exist + if (distinctVector != nullptr && + !distinctVector->getValue( + distinctVector->state->selVector->selectedPositions[0])) { + // pattern has been created + for (auto& executor : nodeInsertExecutors) { + executor.evaluateResult(context); + } + for (auto& executor : relInsertExecutors) { + executor.insert(context->clientContext->getTx(), context); + } + for (auto& executor : onMatchNodeSetExecutors) { + executor->set(context); + } + for (auto& executor : onMatchRelSetExecutors) { + executor->set(context); + } + } else { + // do insert and on create + for (auto& executor : nodeInsertExecutors) { + executor.insert(context->clientContext->getTx(), context); + } + for (auto& executor : relInsertExecutors) { + executor.insert(context->clientContext->getTx(), context); + } + for (auto& executor : onCreateNodeSetExecutors) { + executor->set(context); + } + for (auto& executor : onCreateRelSetExecutors) { + executor->set(context); + } } } return true; diff --git a/src/processor/result/CMakeLists.txt b/src/processor/result/CMakeLists.txt index cfcc402ea2f..9bab3f7059f 100644 --- a/src/processor/result/CMakeLists.txt +++ b/src/processor/result/CMakeLists.txt @@ -3,6 +3,7 @@ add_library(kuzu_processor_result base_hash_table.cpp factorized_table.cpp flat_tuple.cpp + mark_hash_table.cpp result_set.cpp result_set_descriptor.cpp ) diff --git a/src/processor/result/mark_hash_table.cpp b/src/processor/result/mark_hash_table.cpp new file mode 100644 index 00000000000..545d4197653 --- /dev/null +++ b/src/processor/result/mark_hash_table.cpp @@ -0,0 +1,51 @@ +#include "processor/result/mark_hash_table.h" + +namespace kuzu { +namespace processor { + +MarkHashTable::MarkHashTable(storage::MemoryManager& memoryManager, + std::vector keyDataTypes, + std::vector dependentKeyDataTypes, + const std::vector>& aggregateFunctions, + uint64_t numEntriesToAllocate, std::unique_ptr tableSchema) + : AggregateHashTable(memoryManager, std::move(keyDataTypes), std::move(dependentKeyDataTypes), + std::move(aggregateFunctions), numEntriesToAllocate, std::move(tableSchema)) { + distinctColIdxInFT = hashColIdxInFT - 1; +} + +uint64_t MarkHashTable::matchFTEntries(const std::vector& flatKeyVectors, + const std::vector& unFlatKeyVectors, uint64_t numMayMatches, + uint64_t numNoMatches) { + auto colIdx = 0u; + for (auto& flatKeyVector : flatKeyVectors) { + numMayMatches = + matchFlatVecWithFTColumn(flatKeyVector, numMayMatches, numNoMatches, colIdx++); + } + for (auto& unFlatKeyVector : unFlatKeyVectors) { + numMayMatches = + matchUnFlatVecWithFTColumn(unFlatKeyVector, numMayMatches, numNoMatches, colIdx++); + } + for (auto i = 0u; i < numMayMatches; i++) { + noMatchIdxes[numNoMatches++] = mayMatchIdxes[i]; + onMatchSlotIdxes.emplace(mayMatchIdxes[i]); + } + return numNoMatches; +} + +void MarkHashTable::initializeFTEntries(const std::vector& flatKeyVectors, + const std::vector& unFlatKeyVectors, + const std::vector& dependentKeyVectors, + uint64_t numFTEntriesToInitialize) { + AggregateHashTable::initializeFTEntries( + flatKeyVectors, unFlatKeyVectors, dependentKeyVectors, numFTEntriesToInitialize); + for (auto i = 0u; i < numFTEntriesToInitialize; i++) { + auto entryIdx = entryIdxesToInitialize[i]; + auto entry = hashSlotsToUpdateAggState[entryIdx]->entry; + auto onMatch = !onMatchSlotIdxes.contains(entryIdx); + onMatchSlotIdxes.erase(entryIdx); + factorizedTable->updateFlatCellNoNull(entry, distinctColIdxInFT, &onMatch /* isOnMatch */); + } +} + +} // namespace processor +} // namespace kuzu diff --git a/test/test_files/update_node/merge_tinysnb.test b/test/test_files/update_node/merge_tinysnb.test index b25b0a0a843..b2e1c8c91c0 100644 --- a/test/test_files/update_node/merge_tinysnb.test +++ b/test/test_files/update_node/merge_tinysnb.test @@ -60,3 +60,38 @@ Runtime exception: Found duplicated primary key value 1, which violates the uniq -STATEMENT MATCH (a:person) RETURN COUNT(*); ---- 1 11 + +-CASE MergeDuplicatedKey +-STATEMENT CREATE NODE TABLE user (ID int64, primary key(ID)) +---- ok +-STATEMENT MATCH (a:person) with a.ID % 4 as result, a.age as age MERGE (u:user {ID: result}) RETURN u.ID, age +---- 8 +0|35 +2|30 +3|45 +1|20 +3|20 +0|25 +1|40 +2|83 +-STATEMENT MATCH (a:user) RETURN a.ID +---- 4 +0 +1 +2 +3 +-STATEMENT MATCH (a:person) with a.ID as id MERGE (u:user {ID: 10}) RETURN u.ID, id +---- error +Runtime exception: Constant key in merge clause is not supported yet. +-STATEMENT CREATE NODE TABLE user1 (ID int64, name string, primary key(ID)) +---- ok +-STATEMENT MATCH (a:person) with a.ID % 7 as result, a.fName as name MERGE (u:user1 {ID: result}) ON MATCH SET u.name = 'match: ' + name ON CREATE SET u.name = 'create: ' + name RETURN u.ID, u.name +---- 8 +0|create: Alice +0|match: Elizabeth +1|create: Farooq +2|create: Bob +2|match: Greg +3|create: Carol +3|match: Hubert Blaine Wolfeschlegelsteinhausenbergerdorff +5|create: Dan