[SPARK-43780][SQL] Support correlated references in join predicates for scalar and lateral subqueries

### What changes were proposed in this pull request?

This PR adds support for subqueries that involve joins with correlated references in the join predicates, e.g.

```
select * from t0 join lateral (select * from t1 join t2 on t1a = t2a and t1a = t0a);
```

(full example in https://issues.apache.org/jira/browse/SPARK-43780)

Currently we only handle scalar and lateral subqueries.
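
For comparison, here is a sketch of a scalar-subquery variant of the same pattern, reusing the hypothetical tables and columns from the lateral example above (the column `t1b` is made up for illustration):

```
select * from t0
where t0a = (select max(t1b) from t1 join t2 on t1a = t2a and t1a = t0a);
```

As in the lateral example, the correlated reference `t0a` appears inside the join predicate of the subquery rather than in a plain filter.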

### Why are the changes needed?

This is valid SQL that is not yet supported by Spark SQL.

### Does this PR introduce _any_ user-facing change?

Yes, previously unsupported queries become supported.

### How was this patch tested?

SQL query tests (golden files) and unit tests in DecorrelateInnerQuerySuite.

Closes apache#41301 from agubichev/spark-43780-corr-predicate.

Authored-by: Andrey Gubichev <andrey.gubichev@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
agubichev authored and cloud-fan committed Aug 15, 2023
1 parent f7002fb commit 420e687
Showing 13 changed files with 605 additions and 11 deletions.
@@ -1173,6 +1173,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
def canHostOuter(plan: LogicalPlan): Boolean = plan match {
case _: Filter => true
case _: Project => usingDecorrelateInnerQueryFramework
case _: Join => usingDecorrelateInnerQueryFramework
case _ => false
}

@@ -804,18 +804,88 @@ object DecorrelateInnerQuery extends PredicateHelper {
(d.copy(child = newChild), joinCond, outerReferenceMap)

case j @ Join(left, right, joinType, condition, _) =>
val outerReferences = collectOuterReferences(j.expressions)
// Join condition containing outer references is not supported.
assert(outerReferences.isEmpty, s"Correlated column is not allowed in join: $j")
val newOuterReferences = parentOuterReferences ++ outerReferences
val shouldPushToLeft = joinType match {
// Given 'condition', computes the tuple of
// (correlated, uncorrelated, equalityCond, predicates, equivalences).
// 'correlated' and 'uncorrelated' are the conjuncts with (resp. without)
// outer (correlated) references. Furthermore, correlated conjuncts are split
// into 'equalityCond' (those that are equalities) and the rest ('predicates').
// 'equivalences' track equivalent attributes given 'equalityCond'.
// The split is only performed if 'shouldDecorrelatePredicates' is true.
// The input parameter 'isInnerJoin' is set to true for INNER joins and helps
// determine whether some predicates can be lifted up from the join (this is only
// valid for inner joins).
// Example: For a 'condition' A = outer(X) AND B > outer(Y) AND C = D, the output
// would be:
// correlated = (A = outer(X), B > outer(Y))
// uncorrelated = (C = D)
// equalityCond = (A = outer(X))
// predicates = (B > outer(Y))
// equivalences: (A -> outer(X))
def splitCorrelatedPredicate(
condition: Option[Expression],
isInnerJoin: Boolean,
shouldDecorrelatePredicates: Boolean):
(Seq[Expression], Seq[Expression], Seq[Expression],
Seq[Expression], AttributeMap[Attribute]) = {
// Similar to Filters above, we split the join condition (if present) into correlated
// and uncorrelated predicates, and separately handle joins under set and aggregation
// operations.
if (shouldDecorrelatePredicates) {
val conditions =
if (condition.isDefined) splitConjunctivePredicates(condition.get)
else Seq.empty[Expression]
val (correlated, uncorrelated) = conditions.partition(containsOuter)
var equivalences =
if (underSetOp) AttributeMap.empty[Attribute]
else collectEquivalentOuterReferences(correlated)
var (equalityCond, predicates) =
if (underSetOp) (Seq.empty[Expression], correlated)
else correlated.partition(canPullUpOverAgg)
// Fully preserve the join predicate for non-inner joins.
if (!isInnerJoin) {
predicates = correlated
equalityCond = Seq.empty[Expression]
equivalences = AttributeMap.empty[Attribute]
}
(correlated, uncorrelated, equalityCond, predicates, equivalences)
} else {
(Seq.empty[Expression],
if (condition.isEmpty) Seq.empty[Expression] else Seq(condition.get),
Seq.empty[Expression],
Seq.empty[Expression],
AttributeMap.empty[Attribute])
}
}

val shouldDecorrelatePredicates =
SQLConf.get.getConf(SQLConf.DECORRELATE_JOIN_PREDICATE_ENABLED)
if (!shouldDecorrelatePredicates) {
val outerReferences = collectOuterReferences(j.expressions)
// Join condition containing outer references is not supported.
assert(outerReferences.isEmpty, s"Correlated column is not allowed in join: $j")
}
val (correlated, uncorrelated, equalityCond, predicates, equivalences) =
splitCorrelatedPredicate(condition, joinType == Inner, shouldDecorrelatePredicates)
val outerReferences = collectOuterReferences(j.expressions) ++
collectOuterReferences(predicates)
val newOuterReferences =
parentOuterReferences ++ outerReferences -- equivalences.keySet
var shouldPushToLeft = joinType match {
case LeftOuter | LeftSemiOrAnti(_) | FullOuter => true
case _ => hasOuterReferences(left)
}
val shouldPushToRight = joinType match {
case RightOuter | FullOuter => true
case _ => hasOuterReferences(right)
}
if (shouldDecorrelatePredicates && !shouldPushToLeft && !shouldPushToRight
&& !predicates.isEmpty) {
// Neither left nor right children of the join have correlations, but the join
// predicate does, and the correlations can not be replaced via equivalences.
// Introduce a domain join on the left side of the join
// (chosen arbitrarily) to provide values for the correlated attribute reference.
shouldPushToLeft = true;
}
val (newLeft, leftJoinCond, leftOuterReferenceMap) = if (shouldPushToLeft) {
decorrelate(left, newOuterReferences, aggregated, underSetOp)
} else {
@@ -826,8 +896,13 @@ object DecorrelateInnerQuery extends PredicateHelper {
} else {
(right, Nil, AttributeMap.empty[Attribute])
}
val newOuterReferenceMap = leftOuterReferenceMap ++ rightOuterReferenceMap
val newJoinCond = leftJoinCond ++ rightJoinCond
val newOuterReferenceMap = leftOuterReferenceMap ++ rightOuterReferenceMap ++
equivalences
val newCorrelated =
if (shouldDecorrelatePredicates) {
replaceOuterReferences(correlated, newOuterReferenceMap)
} else Seq.empty[Expression]
val newJoinCond = leftJoinCond ++ rightJoinCond ++ equalityCond
// If we push the dependent join to both sides, we can augment the join condition
// such that both sides are matched on the domain attributes. For example,
// - Left Map: {outer(c1) = c1}
@@ -836,7 +911,8 @@ object DecorrelateInnerQuery extends PredicateHelper {
val augmentedConditions = leftOuterReferenceMap.flatMap {
case (outer, inner) => rightOuterReferenceMap.get(outer).map(EqualNullSafe(inner, _))
}
val newCondition = (condition ++ augmentedConditions).reduceOption(And)
val newCondition = (newCorrelated ++ uncorrelated
++ augmentedConditions).reduceOption(And)
val newJoin = j.copy(left = newLeft, right = newRight, condition = newCondition)
(newJoin, newJoinCond, newOuterReferenceMap)

@@ -4370,6 +4370,16 @@ object SQLConf {
.checkValue(_ >= 0, "The threshold of cached local relations must not be negative")
.createWithDefault(64 * 1024 * 1024)

val DECORRELATE_JOIN_PREDICATE_ENABLED =
buildConf("spark.sql.optimizer.decorrelateJoinPredicate.enabled")
.internal()
.doc("Decorrelate scalar and lateral subqueries with correlated references in join " +
"predicates. This configuration is only effective when " +
s"'${DECORRELATE_INNER_QUERY_ENABLED.key}' is true.")
.version("4.0.0")
.booleanConf
.createWithDefault(true)
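
The flag defaults to true. As a sketch (assuming the standard `SET` syntax for Spark SQL configurations), it can be disabled per session to fall back to the previous behavior of rejecting correlated references in join predicates:

```
-- Assumed session-level toggle; reverts to the pre-change behavior.
SET spark.sql.optimizer.decorrelateJoinPredicate.enabled = false;
```

The same key can presumably also be set through the usual runtime configuration API when constructing the session.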

/**
* Holds information about keys that have been deprecated.
*
@@ -35,10 +35,13 @@ class DecorrelateInnerQuerySuite extends PlanTest {
val a3 = AttributeReference("a3", IntegerType)()
val b3 = AttributeReference("b3", IntegerType)()
val c3 = AttributeReference("c3", IntegerType)()
val a4 = AttributeReference("a4", IntegerType)()
val b4 = AttributeReference("b4", IntegerType)()
val t0 = OneRowRelation()
val testRelation = LocalRelation(a, b, c)
val testRelation2 = LocalRelation(x, y, z)
val testRelation3 = LocalRelation(a3, b3, c3)
val testRelation4 = LocalRelation(a4, b4)

private def hasOuterReferences(plan: LogicalPlan): Boolean = {
plan.exists(_.expressions.exists(SubExprUtils.containsOuter))
@@ -198,12 +201,15 @@ class DecorrelateInnerQuerySuite extends PlanTest {
val innerPlan =
Join(
testRelation.as("t1"),
Filter(OuterReference(y) === 3, testRelation),
Filter(OuterReference(y) === b3, testRelation3),
Inner,
Some(OuterReference(x) === a),
JoinHint.NONE)
val error = intercept[AssertionError] { DecorrelateInnerQuery(innerPlan, outerPlan.select()) }
assert(error.getMessage.contains("Correlated column is not allowed in join"))
val correctAnswer =
Join(
testRelation.as("t1"), testRelation3,
Inner, Some(a === a), JoinHint.NONE)
check(innerPlan, outerPlan, correctAnswer, Seq(b3 === y, x === a))
}

test("correlated values in project") {
@@ -454,4 +460,125 @@ class DecorrelateInnerQuerySuite extends PlanTest {
DomainJoin(Seq(x), testRelation))))
check(innerPlan, outerPlan, correctAnswer, Seq(x <=> x))
}

test("SPARK-43780: aggregation in subquery with correlated equi-join") {
// Join in the subquery is on equi-predicates, so all the correlated references can be
// substituted by equivalent ones from the outer query, and domain join is not needed.
val outerPlan = testRelation
val innerPlan =
Aggregate(
Seq.empty[Expression], Seq(Alias(count(Literal(1)), "a")()),
Project(Seq(x, y, a3, b3),
Join(testRelation2, testRelation3, Inner,
Some(And(x === a3, y === OuterReference(a))), JoinHint.NONE)))

val correctAnswer =
Aggregate(
Seq(y), Seq(Alias(count(Literal(1)), "a")(), y),
Project(Seq(x, y, a3, b3),
Join(testRelation2, testRelation3, Inner, Some(And(y === y, x === a3)), JoinHint.NONE)))
check(innerPlan, outerPlan, correctAnswer, Seq(y === a))
}

test("SPARK-43780: aggregation in subquery with correlated non-equi-join") {
// Join in the subquery is on non-equi-predicate, so we introduce a DomainJoin.
val outerPlan = testRelation
val innerPlan =
Aggregate(
Seq.empty[Expression], Seq(Alias(count(Literal(1)), "a")()),
Project(Seq(x, y, a3, b3),
Join(testRelation2, testRelation3, Inner,
Some(And(x === a3, y > OuterReference(a))), JoinHint.NONE)))
val correctAnswer =
Aggregate(
Seq(a), Seq(Alias(count(Literal(1)), "a")(), a),
Project(Seq(x, y, a3, b3, a),
Join(
DomainJoin(Seq(a), testRelation2),
testRelation3, Inner, Some(And(x === a3, y > a)), JoinHint.NONE)))
check(innerPlan, outerPlan, correctAnswer, Seq(a <=> a))
}

test("SPARK-43780: aggregation in subquery with correlated left join") {
// Join in the subquery is on equi-predicates, so all the correlated references can be
// substituted by equivalent ones from the outer query, and domain join is not needed.
val outerPlan = testRelation
val innerPlan =
Aggregate(
Seq.empty[Expression], Seq(Alias(count(Literal(1)), "a")()),
Project(Seq(x, y, a3, b3),
Join(testRelation2, testRelation3, LeftOuter,
Some(And(x === a3, y === OuterReference(a))), JoinHint.NONE)))

val correctAnswer =
Aggregate(
Seq(a), Seq(Alias(count(Literal(1)), "a")(), a),
Project(Seq(x, y, a3, b3, a),
Join(DomainJoin(Seq(a), testRelation2), testRelation3, LeftOuter,
Some(And(y === a, x === a3)), JoinHint.NONE)))
check(innerPlan, outerPlan, correctAnswer, Seq(a <=> a))
}

test("SPARK-43780: aggregation in subquery with correlated left join, " +
"correlation over right side") {
// Same as above, but the join predicate connects the outer reference and the column from the
// right (optional) side of the left join. Domain join is still not needed.
val outerPlan = testRelation
val innerPlan =
Aggregate(
Seq.empty[Expression], Seq(Alias(count(Literal(1)), "a")()),
Project(Seq(x, y, a3, b3),
Join(testRelation2, testRelation3, LeftOuter,
Some(And(x === a3, b3 === OuterReference(b))), JoinHint.NONE)))

val correctAnswer =
Aggregate(
Seq(b), Seq(Alias(count(Literal(1)), "a")(), b),
Project(Seq(x, y, a3, b3, b),
Join(DomainJoin(Seq(b), testRelation2), testRelation3, LeftOuter,
Some(And(b === b3, x === a3)), JoinHint.NONE)))
check(innerPlan, outerPlan, correctAnswer, Seq(b <=> b))
}

test("SPARK-43780: correlated left join preserves the join predicates") {
// Left outer join preserves both predicates after being decorrelated.
val outerPlan = testRelation
val innerPlan =
Filter(
IsNotNull(c3),
Project(Seq(x, y, a3, b3, c3),
Join(testRelation2, testRelation3, LeftOuter,
Some(And(x === a3, b3 === OuterReference(b))), JoinHint.NONE)))

val correctAnswer =
Filter(
IsNotNull(c3),
Project(Seq(x, y, a3, b3, c3, b),
Join(DomainJoin(Seq(b), testRelation2), testRelation3, LeftOuter,
Some(And(x === a3, b === b3)), JoinHint.NONE)))
check(innerPlan, outerPlan, correctAnswer, Seq(b <=> b))
}

test("SPARK-43780: union all in subquery with correlated join") {
val outerPlan = testRelation
val innerPlan =
Union(
Seq(Project(Seq(x, b3),
Join(testRelation2, testRelation3, Inner,
Some(And(x === a3, y === OuterReference(a))), JoinHint.NONE)),
Project(Seq(a4, b4),
testRelation4)))
val correctAnswer =
Union(
Seq(Project(Seq(x, b3, a),
Project(Seq(x, b3, a),
Join(
DomainJoin(Seq(a), testRelation2),
testRelation3, Inner,
Some(And(x === a3, y === a)), JoinHint.NONE))),
Project(Seq(a4, b4, a),
DomainJoin(Seq(a),
Project(Seq(a4, b4), testRelation4)))))
check(innerPlan, outerPlan, correctAnswer, Seq(a <=> a))
}
}
@@ -795,6 +795,72 @@ Project [c1#x, c2#x]
+- LocalRelation [col1#x, col2#x]


-- !query
SELECT * FROM t1 JOIN lateral (SELECT * FROM t2 JOIN t4 ON t2.c1 = t4.c1 AND t2.c1 = t1.c1)
-- !query analysis
Project [c1#x, c2#x, c1#x, c2#x, c1#x, c2#x]
+- LateralJoin lateral-subquery#x [c1#x], Inner
: +- SubqueryAlias __auto_generated_subquery_name
: +- Project [c1#x, c2#x, c1#x, c2#x]
: +- Join Inner, ((c1#x = c1#x) AND (c1#x = outer(c1#x)))
: :- SubqueryAlias spark_catalog.default.t2
: : +- View (`spark_catalog`.`default`.`t2`, [c1#x,c2#x])
: : +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
: : +- LocalRelation [col1#x, col2#x]
: +- SubqueryAlias spark_catalog.default.t4
: +- View (`spark_catalog`.`default`.`t4`, [c1#x,c2#x])
: +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
: +- LocalRelation [col1#x, col2#x]
+- SubqueryAlias spark_catalog.default.t1
+- View (`spark_catalog`.`default`.`t1`, [c1#x,c2#x])
+- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
+- LocalRelation [col1#x, col2#x]


-- !query
SELECT * FROM t1 JOIN lateral (SELECT * FROM t2 JOIN t4 ON t2.c1 != t4.c1 AND t2.c1 != t1.c1)
-- !query analysis
Project [c1#x, c2#x, c1#x, c2#x, c1#x, c2#x]
+- LateralJoin lateral-subquery#x [c1#x], Inner
: +- SubqueryAlias __auto_generated_subquery_name
: +- Project [c1#x, c2#x, c1#x, c2#x]
: +- Join Inner, (NOT (c1#x = c1#x) AND NOT (c1#x = outer(c1#x)))
: :- SubqueryAlias spark_catalog.default.t2
: : +- View (`spark_catalog`.`default`.`t2`, [c1#x,c2#x])
: : +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
: : +- LocalRelation [col1#x, col2#x]
: +- SubqueryAlias spark_catalog.default.t4
: +- View (`spark_catalog`.`default`.`t4`, [c1#x,c2#x])
: +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
: +- LocalRelation [col1#x, col2#x]
+- SubqueryAlias spark_catalog.default.t1
+- View (`spark_catalog`.`default`.`t1`, [c1#x,c2#x])
+- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
+- LocalRelation [col1#x, col2#x]


-- !query
SELECT * FROM t1 LEFT JOIN lateral (SELECT * FROM t4 LEFT JOIN t2 ON t2.c1 = t4.c1 AND t2.c1 = t1.c1)
-- !query analysis
Project [c1#x, c2#x, c1#x, c2#x, c1#x, c2#x]
+- LateralJoin lateral-subquery#x [c1#x], LeftOuter
: +- SubqueryAlias __auto_generated_subquery_name
: +- Project [c1#x, c2#x, c1#x, c2#x]
: +- Join LeftOuter, ((c1#x = c1#x) AND (c1#x = outer(c1#x)))
: :- SubqueryAlias spark_catalog.default.t4
: : +- View (`spark_catalog`.`default`.`t4`, [c1#x,c2#x])
: : +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
: : +- LocalRelation [col1#x, col2#x]
: +- SubqueryAlias spark_catalog.default.t2
: +- View (`spark_catalog`.`default`.`t2`, [c1#x,c2#x])
: +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
: +- LocalRelation [col1#x, col2#x]
+- SubqueryAlias spark_catalog.default.t1
+- View (`spark_catalog`.`default`.`t1`, [c1#x,c2#x])
+- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
+- LocalRelation [col1#x, col2#x]


-- !query
SELECT * FROM t1, LATERAL (SELECT COUNT(*) cnt FROM t2 WHERE c1 = t1.c1)
-- !query analysis
