diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 2f8ab3f43586d..59af5b7095a77 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1081,10 +1081,10 @@ class Analyzer( // Step 2: Pull out the predicates if the plan is resolved. if (current.resolved) { // Make sure the resolved query has the required number of output columns. This is only - // needed for IN expressions. + // needed for Scalar and IN subqueries. if (requiredColumns > 0 && requiredColumns != current.output.size) { - failAnalysis(s"The number of fields in the value ($requiredColumns) does not " + - s"match with the number of columns in the subquery (${current.output.size})") + failAnalysis(s"The number of columns in the subquery (${current.output.size}) " + + s"does not match the required number of columns ($requiredColumns)") } // Pullout predicates and construct a new plan. f.tupled(rewriteSubQuery(current, plans)) @@ -1099,8 +1099,11 @@ class Analyzer( */ private def resolveSubQueries(plan: LogicalPlan, plans: Seq[LogicalPlan]): LogicalPlan = { plan transformExpressions { + case s @ ScalarSubquery(sub, conditions, exprId) + if sub.resolved && conditions.isEmpty && sub.output.size != 1 => + failAnalysis(s"Scalar subquery must return only one column, but got ${sub.output.size}") case s @ ScalarSubquery(sub, _, exprId) if !sub.resolved => - resolveSubQuery(s, plans)(ScalarSubquery(_, _, exprId)) + resolveSubQuery(s, plans, 1)(ScalarSubquery(_, _, exprId)) case e @ Exists(sub, exprId) => resolveSubQuery(e, plans)(PredicateSubquery(_, _, nullAware = false, exprId)) case In(e, Seq(l @ ListQuery(_, exprId))) if e.resolved => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 6e3a14dfb920d..800bf01abd674 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression -import org.apache.spark.sql.catalyst.plans.{Inner, RightOuter, UsingJoin} +import org.apache.spark.sql.catalyst.plans.UsingJoin import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ @@ -60,9 +60,6 @@ trait CheckAnalysis extends PredicateHelper { val from = operator.inputSet.map(_.name).mkString(", ") a.failAnalysis(s"cannot resolve '${a.sql}' given input columns: [$from]") - case ScalarSubquery(_, conditions, _) if conditions.nonEmpty => - failAnalysis("Correlated scalar subqueries are not supported.") - case e: Expression if e.checkInputDataTypes().isFailure => e.checkInputDataTypes() match { case TypeCheckResult.TypeCheckFailure(message) => @@ -104,6 +101,36 @@ trait CheckAnalysis extends PredicateHelper { failAnalysis(s"Window specification $s is not valid because $m") case None => w } + + case s @ ScalarSubquery(query, conditions, _) if conditions.nonEmpty => + // Make sure we are using equi-joins. + conditions.foreach { + case _: EqualTo | _: EqualNullSafe => // ok + case e => failAnalysis( + s"The correlated scalar subquery can only contain equality predicates: $e") + } + + // Make sure correlated scalar subqueries contain one row for every outer row by + // enforcing that they are aggregates which contain exactly one aggregate expressions. + // The analyzer has already checked that subquery contained only one output column, and + // added all the grouping expressions to the aggregate. + def checkAggregate(a: Aggregate): Unit = { + val aggregates = a.expressions.flatMap(_.collect { + case a: AggregateExpression => a + }) + if (aggregates.isEmpty) { + failAnalysis("The output of a correlated scalar subquery must be aggregated") + } + } + + query match { + case a: Aggregate => checkAggregate(a) + case Filter(_, a: Aggregate) => checkAggregate(a) + case Project(_, a: Aggregate) => checkAggregate(a) + case Project(_, Filter(_, a: Aggregate)) => checkAggregate(a) + case fail => failAnalysis(s"Correlated scalar subqueries must be Aggregated: $fail") + } + s } operator match { @@ -220,6 +247,13 @@ trait CheckAnalysis extends PredicateHelper { | but one table has '${firstError.output.length}' columns and another table has | '${s.children.head.output.length}' columns""".stripMargin) + case p if p.expressions.exists(ScalarSubquery.hasCorrelatedScalarSubquery) => + p match { + case _: Filter | _: Aggregate | _: Project => // Ok + case other => failAnalysis( + s"Correlated scalar sub-queries can only be used in a Filter/Aggregate/Project: $p") + } + case p if p.expressions.exists(PredicateSubquery.hasPredicateSubquery) => failAnalysis(s"Predicate sub-queries can only be used in a Filter: $p") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index eed062f8bc180..5001f9a41e07e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -44,6 +44,15 @@ abstract class SubqueryExpression extends Expression { protected def conditionString: String = children.mkString("[", " && ", "]") } +object SubqueryExpression { + def hasCorrelatedSubquery(e: Expression): Boolean = { + e.find { + case e: SubqueryExpression if e.children.nonEmpty => true + case _ => false + }.isDefined + } +} + /** * A subquery that will return only one row and one column. This will be converted into a physical * scalar subquery during planning. @@ -55,28 +64,26 @@ case class ScalarSubquery( children: Seq[Expression] = Seq.empty, exprId: ExprId = NamedExpression.newExprId) extends SubqueryExpression with Unevaluable { - - override def plan: LogicalPlan = SubqueryAlias(toString, query) - override lazy val resolved: Boolean = childrenResolved && query.resolved - - override def dataType: DataType = query.schema.fields.head.dataType - - override def checkInputDataTypes(): TypeCheckResult = { - if (query.schema.length != 1) { - TypeCheckResult.TypeCheckFailure("Scalar subquery must return only one column, but got " + - query.schema.length.toString) - } else { - TypeCheckResult.TypeCheckSuccess - } + override lazy val references: AttributeSet = { + if (query.resolved) super.references -- query.outputSet + else super.references } - + override def dataType: DataType = query.schema.fields.head.dataType override def foldable: Boolean = false override def nullable: Boolean = true - + override def plan: LogicalPlan = SubqueryAlias(toString, query) override def withNewPlan(plan: LogicalPlan): ScalarSubquery = copy(query = plan) + override def toString: String = s"scalar-subquery#${exprId.id} $conditionString" +} - override def toString: String = s"subquery#${exprId.id} $conditionString" +object ScalarSubquery { + def hasCorrelatedScalarSubquery(e: Expression): Boolean = { + e.find { + case e: ScalarSubquery if e.children.nonEmpty => true + case _ => false + }.isDefined + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index e1c969f50f2be..a3ab89dc71145 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import scala.annotation.tailrec import scala.collection.immutable.HashSet +import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf} import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, DistinctAggregationRewriter, EliminateSubqueryAliases, EmptyFunctionRegistry} @@ -100,6 +101,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) EliminateSorts, SimplifyCasts, SimplifyCaseConversionExpressions, + RewriteCorrelatedScalarSubquery, EliminateSerialization) :: Batch("Decimal Optimizations", fixedPoint, DecimalAggregates) :: @@ -1081,7 +1083,7 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper { assert(input.size >= 2) if (input.size == 2) { val (joinConditions, others) = conditions.partition( - e => !PredicateSubquery.hasPredicateSubquery(e)) + e => !SubqueryExpression.hasCorrelatedSubquery(e)) val join = Join(input(0), input(1), Inner, joinConditions.reduceLeftOption(And)) if (others.nonEmpty) { Filter(others.reduceLeft(And), join) @@ -1101,7 +1103,7 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper { val joinedRefs = left.outputSet ++ right.outputSet val (joinConditions, others) = conditions.partition( - e => e.references.subsetOf(joinedRefs) && !PredicateSubquery.hasPredicateSubquery(e)) + e => e.references.subsetOf(joinedRefs) && !SubqueryExpression.hasCorrelatedSubquery(e)) val joined = Join(left, right, Inner, joinConditions.reduceLeftOption(And)) // should not have reference to same logical plan @@ -1134,7 +1136,7 @@ object OuterJoinElimination extends Rule[LogicalPlan] with PredicateHelper { * Returns whether the expression returns null or false when all inputs are nulls. */ private def canFilterOutNull(e: Expression): Boolean = { - if (!e.deterministic || PredicateSubquery.hasPredicateSubquery(e)) return false + if (!e.deterministic || SubqueryExpression.hasCorrelatedSubquery(e)) return false val attributes = e.references.toSeq val emptyRow = new GenericInternalRow(attributes.length) val v = BindReferences.bindReference(e, attributes).eval(emptyRow) @@ -1203,7 +1205,6 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { case f @ Filter(filterCondition, Join(left, right, joinType, joinCondition)) => val (leftFilterConditions, rightFilterConditions, commonFilterCondition) = split(splitConjunctivePredicates(filterCondition), left, right) - joinType match { case Inner => // push down the single side `where` condition into respective sides @@ -1212,7 +1213,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { val newRight = rightFilterConditions. reduceLeftOption(And).map(Filter(_, right)).getOrElse(right) val (newJoinConditions, others) = - commonFilterCondition.partition(e => !PredicateSubquery.hasPredicateSubquery(e)) + commonFilterCondition.partition(e => !SubqueryExpression.hasCorrelatedSubquery(e)) val newJoinCond = (newJoinConditions ++ joinCondition).reduceLeftOption(And) val join = Join(newLeft, newRight, Inner, newJoinCond) @@ -1573,3 +1574,74 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { } } } + +/** + * This rule rewrites correlated [[ScalarSubquery]] expressions into LEFT OUTER joins. + */ +object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { + /** + * Extract all correlated scalar subqueries from an expression. The subqueries are collected using + * the given collector. The expression is rewritten and returned. + */ + private def extractCorrelatedScalarSubqueries[E <: Expression]( + expression: E, + subqueries: ArrayBuffer[ScalarSubquery]): E = { + val newExpression = expression transform { + case s: ScalarSubquery if s.children.nonEmpty => + subqueries += s + s.query.output.head + } + newExpression.asInstanceOf[E] + } + + /** + * Construct a new child plan by left joining the given subqueries to a base plan. + */ + private def constructLeftJoins( + child: LogicalPlan, + subqueries: ArrayBuffer[ScalarSubquery]): LogicalPlan = { + subqueries.foldLeft(child) { + case (currentChild, ScalarSubquery(query, conditions, _)) => + Project( + currentChild.output :+ query.output.head, + Join(currentChild, query, LeftOuter, conditions.reduceOption(And))) + } + } + + /** + * Rewrite [[Filter]], [[Project]] and [[Aggregate]] plans containing correlated scalar + * subqueries. + */ + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case a @ Aggregate(grouping, expressions, child) => + val subqueries = ArrayBuffer.empty[ScalarSubquery] + val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries)) + if (subqueries.nonEmpty) { + // We currently only allow correlated subqueries in an aggregate if they are part of the + // grouping expressions. As a result we need to replace all the scalar subqueries in the + // grouping expressions by their result. + val newGrouping = grouping.map { e => + subqueries.find(_.semanticEquals(e)).map(_.query.output.head).getOrElse(e) + } + Aggregate(newGrouping, newExpressions, constructLeftJoins(child, subqueries)) + } else { + a + } + case p @ Project(expressions, child) => + val subqueries = ArrayBuffer.empty[ScalarSubquery] + val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries)) + if (subqueries.nonEmpty) { + Project(newExpressions, constructLeftJoins(child, subqueries)) + } else { + p + } + case f @ Filter(condition, child) => + val subqueries = ArrayBuffer.empty[ScalarSubquery] + val newCondition = extractCorrelatedScalarSubqueries(condition, subqueries) + if (subqueries.nonEmpty) { + Project(f.output, Filter(newCondition, constructLeftJoins(child, subqueries))) + } else { + f + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 830a7ac77dd6c..7b4615db0661d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -109,7 +109,7 @@ case class Filter(condition: Expression, child: LogicalPlan) override protected def validConstraints: Set[Expression] = { val predicates = splitConjunctivePredicates(condition) - .filterNot(PredicateSubquery.hasPredicateSubquery) + .filterNot(SubqueryExpression.hasCorrelatedSubquery) child.constraints.union(predicates.toSet) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 10bff3d6d82ed..2e88f61d491cd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -111,7 +111,8 @@ class AnalysisErrorSuite extends AnalysisTest { "scalar subquery with 2 columns", testRelation.select( (ScalarSubquery(testRelation.select('a, dateLit.as('b))) + Literal(1)).as('a)), - "Scalar subquery must return only one column, but got 2" :: Nil) + "The number of columns in the subquery (2)" :: + "does not match the required number of columns (1)":: Nil) errorTest( "scalar subquery with no column", @@ -499,12 +500,4 @@ class AnalysisErrorSuite extends AnalysisTest { LocalRelation(a)) assertAnalysisError(plan3, "Accessing outer query column is not allowed in" :: Nil) } - - test("Correlated Scalar Subquery") { - val a = AttributeReference("a", IntegerType)() - val b = AttributeReference("b", IntegerType)() - val sub = Project(Seq(b), Filter(EqualTo(UnresolvedAttribute("a"), b), LocalRelation(b))) - val plan = Project(Seq(a, Alias(ScalarSubquery(sub), "b")()), LocalRelation(a)) - assertAnalysisError(plan, "Correlated scalar subqueries are not supported." :: Nil) - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index ff3f9bb33f9a6..80bb4e05385f4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -234,4 +234,51 @@ class SubquerySuite extends QueryTest with SharedSQLContext { sql("select a from l group by 1 having exists (select 1 from r where d < min(b))"), Row(null) :: Row(1) :: Row(3) :: Nil) } + + test("correlated scalar subquery in where") { + checkAnswer( + sql("select * from l where b < (select max(d) from r where a = c)"), + Row(2, 1.0) :: Row(2, 1.0) :: Nil) + } + + test("correlated scalar subquery in select") { + checkAnswer( + sql("select a, (select sum(b) from l l2 where l2.a = l1.a) sum_b from l l1"), + Row(1, 4.0) :: Row(1, 4.0) :: Row(2, 2.0) :: Row(2, 2.0) :: Row(3, 3.0) :: + Row(null, null) :: Row(null, null) :: Row(6, null) :: Nil) + } + + test("correlated scalar subquery in select (null safe)") { + checkAnswer( + sql("select a, (select sum(b) from l l2 where l2.a <=> l1.a) sum_b from l l1"), + Row(1, 4.0) :: Row(1, 4.0) :: Row(2, 2.0) :: Row(2, 2.0) :: Row(3, 3.0) :: + Row(null, 5.0) :: Row(null, 5.0) :: Row(6, null) :: Nil) + } + + test("correlated scalar subquery in aggregate") { + checkAnswer( + sql("select a, (select sum(d) from r where a = c) sum_d from l l1 group by 1, 2"), + Row(1, null) :: Row(2, 6.0) :: Row(3, 2.0) :: Row(null, null) :: Row(6, null) :: Nil) + } + + test("non-aggregated correlated scalar subquery") { + val msg1 = intercept[AnalysisException] { + sql("select a, (select b from l l2 where l2.a = l1.a) sum_b from l l1") + } + assert(msg1.getMessage.contains("Correlated scalar subqueries must be Aggregated")) + + val msg2 = intercept[AnalysisException] { + sql("select a, (select b from l l2 where l2.a = l1.a group by 1) sum_b from l l1") + } + assert(msg2.getMessage.contains( + "The output of a correlated scalar subquery must be aggregated")) + } + + test("non-equal correlated scalar subquery") { + val msg1 = intercept[AnalysisException] { + sql("select a, (select b from l l2 where l2.a < l1.a) sum_b from l l1") + } + assert(msg1.getMessage.contains( + "The correlated scalar subquery can only contain equality predicates")) + } }