Skip to content

Commit

Permalink
Merge duplicate key fix (kuzudb#3207)
Browse files Browse the repository at this point in the history
  • Loading branch information
acquamarin authored Apr 4, 2024
1 parent b3c6dc9 commit fa0ef79
Show file tree
Hide file tree
Showing 46 changed files with 837 additions and 404 deletions.
3 changes: 2 additions & 1 deletion src/binder/bind/bind_updating_clause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,10 @@ std::unique_ptr<BoundUpdatingClause> Binder::bindMergeClause(
auto boundGraphPattern = bindGraphPattern(mergeClause.getPatternElementsRef());
rewriteMatchPattern(boundGraphPattern);
auto createInfos = bindInsertInfos(boundGraphPattern.queryGraphCollection, patternsScope);
auto distinctMark = createVariable("__distinctMark", *LogicalType::BOOL());
auto boundMergeClause =
std::make_unique<BoundMergeClause>(std::move(boundGraphPattern.queryGraphCollection),
std::move(boundGraphPattern.where), std::move(createInfos));
std::move(boundGraphPattern.where), std::move(createInfos), std::move(distinctMark));
if (mergeClause.hasOnMatchSetItems()) {
for (auto& [lhs, rhs] : mergeClause.getOnMatchSetItemsRef()) {
auto setPropertyInfo = bindSetPropertyInfo(lhs.get(), rhs.get());
Expand Down
54 changes: 28 additions & 26 deletions src/include/binder/query/updating_clause/bound_merge_clause.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,91 +11,92 @@ namespace binder {
class BoundMergeClause : public BoundUpdatingClause {
public:
BoundMergeClause(QueryGraphCollection queryGraphCollection,
std::shared_ptr<Expression> predicate, std::vector<BoundInsertInfo> insertInfos)
: BoundUpdatingClause{common::ClauseType::MERGE}, queryGraphCollection{std::move(
queryGraphCollection)},
predicate{std::move(predicate)}, insertInfos{std::move(insertInfos)} {}
std::shared_ptr<Expression> predicate, std::vector<BoundInsertInfo> insertInfos,
std::shared_ptr<Expression> distinctMark)
: BoundUpdatingClause{common::ClauseType::MERGE},
queryGraphCollection{std::move(queryGraphCollection)}, predicate{std::move(predicate)},
insertInfos{std::move(insertInfos)}, distinctMark{std::move(distinctMark)} {}

inline const QueryGraphCollection* getQueryGraphCollection() const {
return &queryGraphCollection;
}
inline bool hasPredicate() const { return predicate != nullptr; }
inline std::shared_ptr<Expression> getPredicate() const { return predicate; }
const QueryGraphCollection* getQueryGraphCollection() const { return &queryGraphCollection; }
bool hasPredicate() const { return predicate != nullptr; }
std::shared_ptr<Expression> getPredicate() const { return predicate; }

inline const std::vector<BoundInsertInfo>& getInsertInfosRef() const { return insertInfos; }
inline const std::vector<BoundSetPropertyInfo>& getOnMatchSetInfosRef() const {
const std::vector<BoundInsertInfo>& getInsertInfosRef() const { return insertInfos; }
const std::vector<BoundSetPropertyInfo>& getOnMatchSetInfosRef() const {
return onMatchSetPropertyInfos;
}
inline const std::vector<BoundSetPropertyInfo>& getOnCreateSetInfosRef() const {
const std::vector<BoundSetPropertyInfo>& getOnCreateSetInfosRef() const {
return onCreateSetPropertyInfos;
}

inline bool hasInsertNodeInfo() const {
bool hasInsertNodeInfo() const {
return hasInsertInfo(
[](const BoundInsertInfo& info) { return info.tableType == common::TableType::NODE; });
}
inline std::vector<const BoundInsertInfo*> getInsertNodeInfos() const {
std::vector<const BoundInsertInfo*> getInsertNodeInfos() const {
return getInsertInfos(
[](const BoundInsertInfo& info) { return info.tableType == common::TableType::NODE; });
}
inline bool hasInsertRelInfo() const {
bool hasInsertRelInfo() const {
return hasInsertInfo(
[](const BoundInsertInfo& info) { return info.tableType == common::TableType::REL; });
}
inline std::vector<const BoundInsertInfo*> getInsertRelInfos() const {
std::vector<const BoundInsertInfo*> getInsertRelInfos() const {
return getInsertInfos(
[](const BoundInsertInfo& info) { return info.tableType == common::TableType::REL; });
}

inline bool hasOnMatchSetNodeInfo() const {
bool hasOnMatchSetNodeInfo() const {
return hasOnMatchSetInfo([](const BoundSetPropertyInfo& info) {
return info.updateTableType == UpdateTableType::NODE;
});
}
inline std::vector<const BoundSetPropertyInfo*> getOnMatchSetNodeInfos() const {
std::vector<const BoundSetPropertyInfo*> getOnMatchSetNodeInfos() const {
return getOnMatchSetInfos([](const BoundSetPropertyInfo& info) {
return info.updateTableType == UpdateTableType::NODE;
});
}
inline bool hasOnMatchSetRelInfo() const {
bool hasOnMatchSetRelInfo() const {
return hasOnMatchSetInfo([](const BoundSetPropertyInfo& info) {
return info.updateTableType == UpdateTableType::REL;
});
}
inline std::vector<const BoundSetPropertyInfo*> getOnMatchSetRelInfos() const {
std::vector<const BoundSetPropertyInfo*> getOnMatchSetRelInfos() const {
return getOnMatchSetInfos([](const BoundSetPropertyInfo& info) {
return info.updateTableType == UpdateTableType::REL;
});
}

inline bool hasOnCreateSetNodeInfo() const {
bool hasOnCreateSetNodeInfo() const {
return hasOnCreateSetInfo([](const BoundSetPropertyInfo& info) {
return info.updateTableType == UpdateTableType::NODE;
});
}
inline std::vector<const BoundSetPropertyInfo*> getOnCreateSetNodeInfos() const {
std::vector<const BoundSetPropertyInfo*> getOnCreateSetNodeInfos() const {
return getOnCreateSetInfos([](const BoundSetPropertyInfo& info) {
return info.updateTableType == UpdateTableType::NODE;
});
}
inline bool hasOnCreateSetRelInfo() const {
bool hasOnCreateSetRelInfo() const {
return hasOnCreateSetInfo([](const BoundSetPropertyInfo& info) {
return info.updateTableType == UpdateTableType::REL;
});
}
inline std::vector<const BoundSetPropertyInfo*> getOnCreateSetRelInfos() const {
std::vector<const BoundSetPropertyInfo*> getOnCreateSetRelInfos() const {
return getOnCreateSetInfos([](const BoundSetPropertyInfo& info) {
return info.updateTableType == UpdateTableType::REL;
});
}

inline void addOnMatchSetPropertyInfo(BoundSetPropertyInfo setPropertyInfo) {
void addOnMatchSetPropertyInfo(BoundSetPropertyInfo setPropertyInfo) {
onMatchSetPropertyInfos.push_back(std::move(setPropertyInfo));
}
inline void addOnCreateSetPropertyInfo(BoundSetPropertyInfo setPropertyInfo) {
void addOnCreateSetPropertyInfo(BoundSetPropertyInfo setPropertyInfo) {
onCreateSetPropertyInfos.push_back(std::move(setPropertyInfo));
}

std::shared_ptr<Expression> getDistinctMark() const { return distinctMark; }

private:
bool hasInsertInfo(const std::function<bool(const BoundInsertInfo& info)>& check) const;
std::vector<const BoundInsertInfo*> getInsertInfos(
Expand All @@ -121,6 +122,7 @@ class BoundMergeClause : public BoundUpdatingClause {
std::vector<BoundSetPropertyInfo> onMatchSetPropertyInfos;
// Update on create
std::vector<BoundSetPropertyInfo> onCreateSetPropertyInfos;
std::shared_ptr<Expression> distinctMark;
};

} // namespace binder
Expand Down
3 changes: 2 additions & 1 deletion src/include/optimizer/factorization_rewriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
namespace kuzu {
namespace optimizer {

class FactorizationRewriter : public LogicalOperatorVisitor {
class FactorizationRewriter final : public LogicalOperatorVisitor {
public:
void rewrite(planner::LogicalPlan* plan);

Expand All @@ -19,6 +19,7 @@ class FactorizationRewriter : public LogicalOperatorVisitor {
void visitIntersect(planner::LogicalOperator* op) override;
void visitProjection(planner::LogicalOperator* op) override;
void visitAccumulate(planner::LogicalOperator* op) override;
void visitMarkAccumulate(planner::LogicalOperator*) override;
void visitAggregate(planner::LogicalOperator* op) override;
void visitOrderBy(planner::LogicalOperator* op) override;
void visitLimit(planner::LogicalOperator* op) override;
Expand Down
6 changes: 6 additions & 0 deletions src/include/optimizer/logical_operator_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,12 @@ class LogicalOperatorVisitor {
return op;
}

virtual void visitMarkAccumulate(planner::LogicalOperator* /*op*/) {}
virtual std::shared_ptr<planner::LogicalOperator> visitMarkAccumulateReplace(
std::shared_ptr<planner::LogicalOperator> op) {
return op;
}

virtual void visitDistinct(planner::LogicalOperator* /*op*/) {}
virtual std::shared_ptr<planner::LogicalOperator> visitDistinctReplace(
std::shared_ptr<planner::LogicalOperator> op) {
Expand Down
16 changes: 8 additions & 8 deletions src/include/planner/operator/logical_accumulate.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,28 @@
namespace kuzu {
namespace planner {

class LogicalAccumulate : public LogicalOperator {
class LogicalAccumulate final : public LogicalOperator {
public:
LogicalAccumulate(common::AccumulateType accumulateType, binder::expression_vector flatExprs,
std::shared_ptr<binder::Expression> offset, std::shared_ptr<LogicalOperator> child)
: LogicalOperator{LogicalOperatorType::ACCUMULATE, std::move(child)},
accumulateType{accumulateType}, flatExprs{std::move(flatExprs)}, offset{
std::move(offset)} {}

void computeFactorizedSchema() final;
void computeFlatSchema() final;
void computeFactorizedSchema() override;
void computeFlatSchema() override;

f_group_pos_set getGroupPositionsToFlatten() const;

inline std::string getExpressionsForPrinting() const final { return std::string{}; }
std::string getExpressionsForPrinting() const override { return {}; }

inline common::AccumulateType getAccumulateType() const { return accumulateType; }
inline binder::expression_vector getExpressionsToAccumulate() const {
common::AccumulateType getAccumulateType() const { return accumulateType; }
binder::expression_vector getPayloads() const {
return children[0]->getSchema()->getExpressionsInScope();
}
inline std::shared_ptr<binder::Expression> getOffset() const { return offset; }
std::shared_ptr<binder::Expression> getOffset() const { return offset; }

inline std::unique_ptr<LogicalOperator> copy() final {
std::unique_ptr<LogicalOperator> copy() override {
return make_unique<LogicalAccumulate>(
accumulateType, flatExprs, offset, children[0]->copy());
}
Expand Down
53 changes: 20 additions & 33 deletions src/include/planner/operator/logical_distinct.h
Original file line number Diff line number Diff line change
@@ -1,56 +1,43 @@
#pragma once

#include "planner/operator/logical_operator.h"
#include "planner/operator/schema.h"

namespace kuzu {
namespace planner {

class LogicalDistinct : public LogicalOperator {
public:
LogicalDistinct(
binder::expression_vector keyExpressions, std::shared_ptr<LogicalOperator> child)
: LogicalOperator{LogicalOperatorType::DISTINCT, std::move(child)},
keyExpressions{std::move(keyExpressions)} {}
LogicalDistinct(binder::expression_vector keyExpressions,
binder::expression_vector dependentKeyExpressions, std::shared_ptr<LogicalOperator> child)
: LogicalOperator{LogicalOperatorType::DISTINCT, std::move(child)},
keyExpressions{std::move(keyExpressions)}, dependentKeyExpressions{
std::move(dependentKeyExpressions)} {}
LogicalDistinct(binder::expression_vector keys, std::shared_ptr<LogicalOperator> child)
: LogicalDistinct{
LogicalOperatorType::DISTINCT, keys, binder::expression_vector{}, std::move(child)} {}
LogicalDistinct(LogicalOperatorType type, binder::expression_vector keys,
binder::expression_vector payloads, std::shared_ptr<LogicalOperator> child)
: LogicalOperator{type, std::move(child)}, keys{std::move(keys)}, payloads{std::move(
payloads)} {}

void computeFactorizedSchema() override;
void computeFlatSchema() override;

f_group_pos_set getGroupsPosToFlatten();
virtual f_group_pos_set getGroupsPosToFlatten();

std::string getExpressionsForPrinting() const override;

inline binder::expression_vector getKeyExpressions() const { return keyExpressions; }
inline void setKeyExpressions(binder::expression_vector expressions) {
keyExpressions = std::move(expressions);
}
inline binder::expression_vector getDependentKeyExpressions() const {
return dependentKeyExpressions;
}
inline void setDependentKeyExpressions(binder::expression_vector expressions) {
dependentKeyExpressions = std::move(expressions);
}
inline binder::expression_vector getAllDistinctExpressions() const {
binder::expression_vector result;
result.insert(result.end(), keyExpressions.begin(), keyExpressions.end());
result.insert(result.end(), dependentKeyExpressions.begin(), dependentKeyExpressions.end());
return result;
}
binder::expression_vector getKeys() const { return keys; }
void setKeys(binder::expression_vector expressions) { keys = std::move(expressions); }
binder::expression_vector getPayloads() const { return payloads; }
void setPayloads(binder::expression_vector expressions) { payloads = std::move(expressions); }

std::unique_ptr<LogicalOperator> copy() override {
return make_unique<LogicalDistinct>(
keyExpressions, dependentKeyExpressions, children[0]->copy());
return make_unique<LogicalDistinct>(operatorType, keys, payloads, children[0]->copy());
}

private:
binder::expression_vector keyExpressions;
// See logical_aggregate.h for details.
binder::expression_vector dependentKeyExpressions;
protected:
binder::expression_vector getKeysAndPayloads() const;

protected:
binder::expression_vector keys;
// Payloads meaning additional keys that are functional dependent on the keys above.
binder::expression_vector payloads;
};

} // namespace planner
Expand Down
35 changes: 35 additions & 0 deletions src/include/planner/operator/logical_mark_accmulate.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#pragma once

#include "planner/operator/logical_operator.h"

namespace kuzu {
namespace planner {

class LogicalMarkAccumulate final : public LogicalOperator {
public:
LogicalMarkAccumulate(binder::expression_vector keys, std::shared_ptr<binder::Expression> mark,
std::shared_ptr<LogicalOperator> child)
: LogicalOperator{LogicalOperatorType::MARK_ACCUMULATE, std::move(child)},
keys{std::move(keys)}, mark{std::move(mark)} {}

void computeFactorizedSchema() override;
void computeFlatSchema() override;

f_group_pos_set getGroupsPosToFlatten() const;

std::string getExpressionsForPrinting() const override { return {}; }
binder::expression_vector getKeys() const { return keys; }
binder::expression_vector getPayloads() const;
std::shared_ptr<binder::Expression> getMark() const { return mark; }

std::unique_ptr<LogicalOperator> copy() override {
return std::make_unique<LogicalMarkAccumulate>(keys, mark, children[0]->copy());
}

private:
binder::expression_vector keys;
std::shared_ptr<binder::Expression> mark;
};

} // namespace planner
} // namespace kuzu
1 change: 1 addition & 0 deletions src/include/planner/operator/logical_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ enum class LogicalOperatorType : uint8_t {
INTERSECT,
INSERT,
LIMIT,
MARK_ACCUMULATE,
MERGE,
MULTIPLICITY_REDUCER,
NODE_LABEL_FILTER,
Expand Down
Loading

0 comments on commit fa0ef79

Please sign in to comment.