[SPARK-14785] [SQL] Support correlated scalar subqueries
## What changes were proposed in this pull request?
In this PR we add support for correlated scalar subqueries. An example of such a query is:
```SQL
select * from tbl1 a where a.value > (select max(value) from tbl2 b where b.key = a.key)
```
The implementation adds the `RewriteCorrelatedScalarSubquery` rule to the Optimizer. This rule rewrites these subqueries into `LEFT OUTER` joins. It currently supports rewrites for `Project`, `Aggregate`, and `Filter` logical plans.
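As a sketch of what this rewrite produces for the example query above (the rule builds a Catalyst `LEFT OUTER` join and `Project`, not literal SQL; the `max_value` alias is only for illustration):
```SQL
select a.*
from tbl1 a
left outer join (
  select key, max(value) as max_value from tbl2 group by key
) b on b.key = a.key
where a.value > b.max_value
```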

I could not find well-defined semantics for the use of scalar subqueries in an `Aggregate`. The current implementation evaluates the scalar subquery *before* aggregation. This means that you either have to make the scalar subquery part of the grouping expressions, or aggregate its result further. I am open to suggestions on this.
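For example, a correlated scalar subquery inside an aggregating query is accepted when its result is also a grouping expression, as in this query from the new `SubquerySuite` tests (the `l` and `r` tables are the suite's test data):
```SQL
select a, (select sum(d) from r where a = c) sum_d
from l l1
group by 1, 2
```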

The implementation currently guarantees that a correlated scalar subquery returns a single value per outer row by enforcing that the subquery is aggregated and that its output column is wrapped in an `AggregateExpression`.
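Concretely, using the `l` test table from `SubquerySuite`, the first query below is accepted because the correlated subquery is aggregated, while the second is rejected during analysis with "Correlated scalar subqueries must be Aggregated":
```SQL
-- accepted: the subquery produces one aggregated value per outer row
select a, (select sum(b) from l l2 where l2.a = l1.a) sum_b from l l1;

-- rejected: the subquery is not aggregated
select a, (select b from l l2 where l2.a = l1.a) sum_b from l l1;
```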

## How was this patch tested?
Added tests to `SubquerySuite`.

Author: Herman van Hovell <hvanhovell@questtec.nl>

Closes apache#12822 from hvanhovell/SPARK-14785.
hvanhovell authored and davies committed May 2, 2016
1 parent 917d05f commit f362363
Showing 7 changed files with 195 additions and 39 deletions.
@@ -1081,10 +1081,10 @@ class Analyzer(
// Step 2: Pull out the predicates if the plan is resolved.
if (current.resolved) {
// Make sure the resolved query has the required number of output columns. This is only
// needed for IN expressions.
// needed for Scalar and IN subqueries.
if (requiredColumns > 0 && requiredColumns != current.output.size) {
failAnalysis(s"The number of fields in the value ($requiredColumns) does not " +
s"match with the number of columns in the subquery (${current.output.size})")
failAnalysis(s"The number of columns in the subquery (${current.output.size}) " +
s"does not match the required number of columns ($requiredColumns)")
}
// Pull out predicates and construct a new plan.
f.tupled(rewriteSubQuery(current, plans))
@@ -1099,8 +1099,11 @@
*/
private def resolveSubQueries(plan: LogicalPlan, plans: Seq[LogicalPlan]): LogicalPlan = {
plan transformExpressions {
case s @ ScalarSubquery(sub, conditions, exprId)
if sub.resolved && conditions.isEmpty && sub.output.size != 1 =>
failAnalysis(s"Scalar subquery must return only one column, but got ${sub.output.size}")
case s @ ScalarSubquery(sub, _, exprId) if !sub.resolved =>
resolveSubQuery(s, plans)(ScalarSubquery(_, _, exprId))
resolveSubQuery(s, plans, 1)(ScalarSubquery(_, _, exprId))
case e @ Exists(sub, exprId) =>
resolveSubQuery(e, plans)(PredicateSubquery(_, _, nullAware = false, exprId))
case In(e, Seq(l @ ListQuery(_, exprId))) if e.resolved =>
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.{Inner, RightOuter, UsingJoin}
import org.apache.spark.sql.catalyst.plans.UsingJoin
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types._

@@ -60,9 +60,6 @@ trait CheckAnalysis extends PredicateHelper {
val from = operator.inputSet.map(_.name).mkString(", ")
a.failAnalysis(s"cannot resolve '${a.sql}' given input columns: [$from]")

case ScalarSubquery(_, conditions, _) if conditions.nonEmpty =>
failAnalysis("Correlated scalar subqueries are not supported.")

case e: Expression if e.checkInputDataTypes().isFailure =>
e.checkInputDataTypes() match {
case TypeCheckResult.TypeCheckFailure(message) =>
@@ -104,6 +101,36 @@ trait CheckAnalysis extends PredicateHelper {
failAnalysis(s"Window specification $s is not valid because $m")
case None => w
}

case s @ ScalarSubquery(query, conditions, _) if conditions.nonEmpty =>
// Make sure we are using equi-joins.
conditions.foreach {
case _: EqualTo | _: EqualNullSafe => // ok
case e => failAnalysis(
s"The correlated scalar subquery can only contain equality predicates: $e")
}

// Make sure correlated scalar subqueries contain one row for every outer row by
// enforcing that they are aggregates which contain exactly one aggregate expression.
// The analyzer has already checked that the subquery contains only one output column, and
// added all the grouping expressions to the aggregate.
def checkAggregate(a: Aggregate): Unit = {
val aggregates = a.expressions.flatMap(_.collect {
case a: AggregateExpression => a
})
if (aggregates.isEmpty) {
failAnalysis("The output of a correlated scalar subquery must be aggregated")
}
}

query match {
case a: Aggregate => checkAggregate(a)
case Filter(_, a: Aggregate) => checkAggregate(a)
case Project(_, a: Aggregate) => checkAggregate(a)
case Project(_, Filter(_, a: Aggregate)) => checkAggregate(a)
case fail => failAnalysis(s"Correlated scalar subqueries must be Aggregated: $fail")
}
s
}

operator match {
@@ -220,6 +247,13 @@ trait CheckAnalysis extends PredicateHelper {
| but one table has '${firstError.output.length}' columns and another table has
| '${s.children.head.output.length}' columns""".stripMargin)

case p if p.expressions.exists(ScalarSubquery.hasCorrelatedScalarSubquery) =>
p match {
case _: Filter | _: Aggregate | _: Project => // Ok
case other => failAnalysis(
s"Correlated scalar sub-queries can only be used in a Filter/Aggregate/Project: $p")
}

case p if p.expressions.exists(PredicateSubquery.hasPredicateSubquery) =>
failAnalysis(s"Predicate sub-queries can only be used in a Filter: $p")

@@ -44,6 +44,15 @@ abstract class SubqueryExpression extends Expression {
protected def conditionString: String = children.mkString("[", " && ", "]")
}

object SubqueryExpression {
def hasCorrelatedSubquery(e: Expression): Boolean = {
e.find {
case e: SubqueryExpression if e.children.nonEmpty => true
case _ => false
}.isDefined
}
}

/**
* A subquery that will return only one row and one column. This will be converted into a physical
* scalar subquery during planning.
@@ -55,28 +64,26 @@ case class ScalarSubquery(
children: Seq[Expression] = Seq.empty,
exprId: ExprId = NamedExpression.newExprId)
extends SubqueryExpression with Unevaluable {

override def plan: LogicalPlan = SubqueryAlias(toString, query)

override lazy val resolved: Boolean = childrenResolved && query.resolved

override def dataType: DataType = query.schema.fields.head.dataType

override def checkInputDataTypes(): TypeCheckResult = {
if (query.schema.length != 1) {
TypeCheckResult.TypeCheckFailure("Scalar subquery must return only one column, but got " +
query.schema.length.toString)
} else {
TypeCheckResult.TypeCheckSuccess
}
override lazy val references: AttributeSet = {
if (query.resolved) super.references -- query.outputSet
else super.references
}

override def dataType: DataType = query.schema.fields.head.dataType
override def foldable: Boolean = false
override def nullable: Boolean = true

override def plan: LogicalPlan = SubqueryAlias(toString, query)
override def withNewPlan(plan: LogicalPlan): ScalarSubquery = copy(query = plan)
override def toString: String = s"scalar-subquery#${exprId.id} $conditionString"
}

override def toString: String = s"subquery#${exprId.id} $conditionString"
object ScalarSubquery {
def hasCorrelatedScalarSubquery(e: Expression): Boolean = {
e.find {
case e: ScalarSubquery if e.children.nonEmpty => true
case _ => false
}.isDefined
}
}

/**
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer

import scala.annotation.tailrec
import scala.collection.immutable.HashSet
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf}
import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, DistinctAggregationRewriter, EliminateSubqueryAliases, EmptyFunctionRegistry}
@@ -100,6 +101,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
EliminateSorts,
SimplifyCasts,
SimplifyCaseConversionExpressions,
RewriteCorrelatedScalarSubquery,
EliminateSerialization) ::
Batch("Decimal Optimizations", fixedPoint,
DecimalAggregates) ::
@@ -1081,7 +1083,7 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
assert(input.size >= 2)
if (input.size == 2) {
val (joinConditions, others) = conditions.partition(
e => !PredicateSubquery.hasPredicateSubquery(e))
e => !SubqueryExpression.hasCorrelatedSubquery(e))
val join = Join(input(0), input(1), Inner, joinConditions.reduceLeftOption(And))
if (others.nonEmpty) {
Filter(others.reduceLeft(And), join)
@@ -1101,7 +1103,7 @@

val joinedRefs = left.outputSet ++ right.outputSet
val (joinConditions, others) = conditions.partition(
e => e.references.subsetOf(joinedRefs) && !PredicateSubquery.hasPredicateSubquery(e))
e => e.references.subsetOf(joinedRefs) && !SubqueryExpression.hasCorrelatedSubquery(e))
val joined = Join(left, right, Inner, joinConditions.reduceLeftOption(And))

// should not have reference to same logical plan
@@ -1134,7 +1136,7 @@ object OuterJoinElimination extends Rule[LogicalPlan] with PredicateHelper {
* Returns whether the expression returns null or false when all inputs are nulls.
*/
private def canFilterOutNull(e: Expression): Boolean = {
if (!e.deterministic || PredicateSubquery.hasPredicateSubquery(e)) return false
if (!e.deterministic || SubqueryExpression.hasCorrelatedSubquery(e)) return false
val attributes = e.references.toSeq
val emptyRow = new GenericInternalRow(attributes.length)
val v = BindReferences.bindReference(e, attributes).eval(emptyRow)
@@ -1203,7 +1205,6 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
case f @ Filter(filterCondition, Join(left, right, joinType, joinCondition)) =>
val (leftFilterConditions, rightFilterConditions, commonFilterCondition) =
split(splitConjunctivePredicates(filterCondition), left, right)

joinType match {
case Inner =>
// push down the single side `where` condition into respective sides
@@ -1212,7 +1213,7 @@
val newRight = rightFilterConditions.
reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
val (newJoinConditions, others) =
commonFilterCondition.partition(e => !PredicateSubquery.hasPredicateSubquery(e))
commonFilterCondition.partition(e => !SubqueryExpression.hasCorrelatedSubquery(e))
val newJoinCond = (newJoinConditions ++ joinCondition).reduceLeftOption(And)

val join = Join(newLeft, newRight, Inner, newJoinCond)
@@ -1573,3 +1574,74 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
}
}
}

/**
* This rule rewrites correlated [[ScalarSubquery]] expressions into LEFT OUTER joins.
*/
object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
/**
* Extract all correlated scalar subqueries from an expression. The subqueries are collected using
* the given collector. The expression is rewritten and returned.
*/
private def extractCorrelatedScalarSubqueries[E <: Expression](
expression: E,
subqueries: ArrayBuffer[ScalarSubquery]): E = {
val newExpression = expression transform {
case s: ScalarSubquery if s.children.nonEmpty =>
subqueries += s
s.query.output.head
}
newExpression.asInstanceOf[E]
}

/**
* Construct a new child plan by left joining the given subqueries to a base plan.
*/
private def constructLeftJoins(
child: LogicalPlan,
subqueries: ArrayBuffer[ScalarSubquery]): LogicalPlan = {
subqueries.foldLeft(child) {
case (currentChild, ScalarSubquery(query, conditions, _)) =>
Project(
currentChild.output :+ query.output.head,
Join(currentChild, query, LeftOuter, conditions.reduceOption(And)))
}
}

/**
* Rewrite [[Filter]], [[Project]] and [[Aggregate]] plans containing correlated scalar
* subqueries.
*/
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case a @ Aggregate(grouping, expressions, child) =>
val subqueries = ArrayBuffer.empty[ScalarSubquery]
val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
if (subqueries.nonEmpty) {
// We currently only allow correlated subqueries in an aggregate if they are part of the
// grouping expressions. As a result we need to replace all the scalar subqueries in the
// grouping expressions by their result.
val newGrouping = grouping.map { e =>
subqueries.find(_.semanticEquals(e)).map(_.query.output.head).getOrElse(e)
}
Aggregate(newGrouping, newExpressions, constructLeftJoins(child, subqueries))
} else {
a
}
case p @ Project(expressions, child) =>
val subqueries = ArrayBuffer.empty[ScalarSubquery]
val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
if (subqueries.nonEmpty) {
Project(newExpressions, constructLeftJoins(child, subqueries))
} else {
p
}
case f @ Filter(condition, child) =>
val subqueries = ArrayBuffer.empty[ScalarSubquery]
val newCondition = extractCorrelatedScalarSubqueries(condition, subqueries)
if (subqueries.nonEmpty) {
Project(f.output, Filter(newCondition, constructLeftJoins(child, subqueries)))
} else {
f
}
}
}
@@ -109,7 +109,7 @@ case class Filter(condition: Expression, child: LogicalPlan)

override protected def validConstraints: Set[Expression] = {
val predicates = splitConjunctivePredicates(condition)
.filterNot(PredicateSubquery.hasPredicateSubquery)
.filterNot(SubqueryExpression.hasCorrelatedSubquery)
child.constraints.union(predicates.toSet)
}
}
@@ -111,7 +111,8 @@ class AnalysisErrorSuite extends AnalysisTest {
"scalar subquery with 2 columns",
testRelation.select(
(ScalarSubquery(testRelation.select('a, dateLit.as('b))) + Literal(1)).as('a)),
"Scalar subquery must return only one column, but got 2" :: Nil)
"The number of columns in the subquery (2)" ::
"does not match the required number of columns (1)":: Nil)

errorTest(
"scalar subquery with no column",
@@ -499,12 +500,4 @@ class AnalysisErrorSuite extends AnalysisTest {
LocalRelation(a))
assertAnalysisError(plan3, "Accessing outer query column is not allowed in" :: Nil)
}

test("Correlated Scalar Subquery") {
val a = AttributeReference("a", IntegerType)()
val b = AttributeReference("b", IntegerType)()
val sub = Project(Seq(b), Filter(EqualTo(UnresolvedAttribute("a"), b), LocalRelation(b)))
val plan = Project(Seq(a, Alias(ScalarSubquery(sub), "b")()), LocalRelation(a))
assertAnalysisError(plan, "Correlated scalar subqueries are not supported." :: Nil)
}
}
47 changes: 47 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -234,4 +234,51 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
sql("select a from l group by 1 having exists (select 1 from r where d < min(b))"),
Row(null) :: Row(1) :: Row(3) :: Nil)
}

test("correlated scalar subquery in where") {
checkAnswer(
sql("select * from l where b < (select max(d) from r where a = c)"),
Row(2, 1.0) :: Row(2, 1.0) :: Nil)
}

test("correlated scalar subquery in select") {
checkAnswer(
sql("select a, (select sum(b) from l l2 where l2.a = l1.a) sum_b from l l1"),
Row(1, 4.0) :: Row(1, 4.0) :: Row(2, 2.0) :: Row(2, 2.0) :: Row(3, 3.0) ::
Row(null, null) :: Row(null, null) :: Row(6, null) :: Nil)
}

test("correlated scalar subquery in select (null safe)") {
checkAnswer(
sql("select a, (select sum(b) from l l2 where l2.a <=> l1.a) sum_b from l l1"),
Row(1, 4.0) :: Row(1, 4.0) :: Row(2, 2.0) :: Row(2, 2.0) :: Row(3, 3.0) ::
Row(null, 5.0) :: Row(null, 5.0) :: Row(6, null) :: Nil)
}

test("correlated scalar subquery in aggregate") {
checkAnswer(
sql("select a, (select sum(d) from r where a = c) sum_d from l l1 group by 1, 2"),
Row(1, null) :: Row(2, 6.0) :: Row(3, 2.0) :: Row(null, null) :: Row(6, null) :: Nil)
}

test("non-aggregated correlated scalar subquery") {
val msg1 = intercept[AnalysisException] {
sql("select a, (select b from l l2 where l2.a = l1.a) sum_b from l l1")
}
assert(msg1.getMessage.contains("Correlated scalar subqueries must be Aggregated"))

val msg2 = intercept[AnalysisException] {
sql("select a, (select b from l l2 where l2.a = l1.a group by 1) sum_b from l l1")
}
assert(msg2.getMessage.contains(
"The output of a correlated scalar subquery must be aggregated"))
}

test("non-equal correlated scalar subquery") {
val msg1 = intercept[AnalysisException] {
sql("select a, (select b from l l2 where l2.a < l1.a) sum_b from l l1")
}
assert(msg1.getMessage.contains(
"The correlated scalar subquery can only contain equality predicates"))
}
}
