From 3e792f3569d7a397e2817ac3b66816a3c35feed0 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 20 Jan 2016 16:23:04 -0800 Subject: [PATCH 01/18] generate aggregation with grouping keys --- .../codegen/GenerateMutableProjection.scala | 2 +- .../sql/execution/BufferedRowIterator.java | 6 +- .../sql/execution/WholeStageCodegen.scala | 168 +++++++++----- .../aggregate/TungstenAggregate.scala | 217 +++++++++++++++--- .../spark/sql/execution/basicOperators.scala | 36 +-- .../BenchmarkWholeStageCodegen.scala | 119 +++++++++- .../execution/WholeStageCodegenSuite.scala | 10 + 7 files changed, 439 insertions(+), 119 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index d9fe76133c6ef..ec31db19b94b8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -93,7 +93,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu // Can't call setNullAt on DecimalType, because we need to keep the offset s""" if (this.isNull_$i) { - ${ctx.setColumn("mutableRow", e.dataType, i, null)}; + ${ctx.setColumn("mutableRow", e.dataType, i, "null")}; } else { ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")}; } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java b/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java index b1bbb1da10a39..6acf70dbbad0c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution; +import java.io.IOException; + import scala.collection.Iterator; import org.apache.spark.sql.catalyst.InternalRow; @@ -34,7 +36,7 @@ public class BufferedRowIterator { // used when there is no column in output protected UnsafeRow unsafeRow = new UnsafeRow(0); - public boolean hasNext() { + public boolean hasNext() throws IOException { if (currentRow == null) { processNext(); } @@ -56,7 +58,7 @@ public void setInput(Iterator iter) { * * After it's called, if currentRow is still null, it means no more rows left. */ - protected void processNext() { + protected void processNext() throws IOException { if (input.hasNext()) { currentRow = input.next(); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index 57f4945de9804..89c4d911c3307 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -41,10 +41,15 @@ trait CodegenSupport extends SparkPlan { */ private var parent: CodegenSupport = null + /** + * Returns the RDD of InternalRow which generates the input rows. + */ + def upstream(): RDD[InternalRow] + /** * Returns an input RDD of InternalRow and Java source code to process them. 
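 * (The input rows themselves come from the RDD returned by the upstream() method above;
 * the source returned here only describes how to process them.)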
*/ - def produce(ctx: CodegenContext, parent: CodegenSupport): (RDD[InternalRow], String) = { + def produce(ctx: CodegenContext, parent: CodegenSupport): String = { this.parent = parent doProduce(ctx) } @@ -66,17 +71,40 @@ trait CodegenSupport extends SparkPlan { * # call consume(), wich will call parent.doConsume() * } */ - protected def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) + protected def doProduce(ctx: CodegenContext): String /** * Consume the columns generated from current SparkPlan, call it's parent or create an iterator. */ - protected def consume(ctx: CodegenContext, columns: Seq[ExprCode]): String = { - assert(columns.length == output.length) - parent.doConsume(ctx, this, columns) + def consume( + ctx: CodegenContext, + child: SparkPlan, + input: Seq[ExprCode], + row: String = null): String = { + if (child eq this) { + // This is called by itself, pass to it's parent + if (input != null) { + assert(input.length == output.length) + } + parent.consume(ctx, this, input, row) + } else { + // This is called by child + if (row != null) { + ctx.currentVars = null + ctx.INPUT_ROW = row + val evals = child.output.zipWithIndex.map { case (attr, i) => + BoundReference(i, attr.dataType, attr.nullable).gen(ctx) + } + s""" + | ${evals.map(_.code).mkString("\n")} + | ${doConsume(ctx, child, evals)} + """.stripMargin + } else { + doConsume(ctx, child, input) + } + } } - /** * Generate the Java source code to process the rows from child SparkPlan. * @@ -89,7 +117,9 @@ trait CodegenSupport extends SparkPlan { * # call consume(), which will call parent.doConsume() * } */ - def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String + protected def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { + throw new UnsupportedOperationException + } } @@ -105,24 +135,23 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport { override def supportCodegen: Boolean = true - override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { + override def upstream(): RDD[InternalRow] = { + child.execute() + } + + override def doProduce(ctx: CodegenContext): String = { val exprs = output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, true)) val row = ctx.freshName("row") ctx.INPUT_ROW = row ctx.currentVars = null val columns = exprs.map(_.gen(ctx)) - val code = s""" + s""" | while (input.hasNext()) { | InternalRow $row = (InternalRow) input.next(); | ${columns.map(_.code).mkString("\n")} - | ${consume(ctx, columns)} + | ${consume(ctx, this, columns)} | } """.stripMargin - (child.execute(), code) - } - - def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { - throw new UnsupportedOperationException } override def doExecute(): RDD[InternalRow] = { @@ -165,34 +194,34 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) override def output: Seq[Attribute] = plan.output override def doExecute(): RDD[InternalRow] = { - val ctx = new CodegenContext - val (rdd, code) = plan.produce(ctx, this) - val references = ctx.references.toArray - val source = s""" - public Object generate(Object[] references) { - return new GeneratedIterator(references); - } - class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { + plan.upstream().mapPartitions { iter => + val ctx = new CodegenContext + val code = plan.produce(ctx, this) + val references = ctx.references.toArray + val source = s""" + public Object generate(Object[] 
references) { + return new GeneratedIterator(references); + } + + class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { + + private Object[] references; + ${ctx.declareMutableStates()} + + public GeneratedIterator(Object[] references) { + this.references = references; + ${ctx.initMutableStates()} + } + + protected void processNext() throws java.io.IOException { + $code + } + } + """ + // try to compile, helpful for debug + // println(s"${CodeFormatter.format(source)}") - private Object[] references; - ${ctx.declareMutableStates()} - - public GeneratedIterator(Object[] references) { - this.references = references; - ${ctx.initMutableStates()} - } - - protected void processNext() { - $code - } - } - """ - // try to compile, helpful for debug - // println(s"${CodeFormatter.format(source)}") - CodeGenerator.compile(source) - - rdd.mapPartitions { iter => val clazz = CodeGenerator.compile(source) val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator] buffer.setInput(iter) @@ -203,29 +232,47 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) } } - override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { + override def upstream(): RDD[InternalRow] = { throw new UnsupportedOperationException } - override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { - if (input.nonEmpty) { - val colExprs = output.zipWithIndex.map { case (attr, i) => - BoundReference(i, attr.dataType, attr.nullable) - } - // generate the code to create a UnsafeRow - ctx.currentVars = input - val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false) + override def doProduce(ctx: CodegenContext): String = { + throw new UnsupportedOperationException + } + + override def consume( + ctx: CodegenContext, + child: SparkPlan, + input: Seq[ExprCode], + row: String = null): String = { + + if (row != null) { + // There is an UnsafeRow already s""" - | ${code.code.trim} - | currentRow = ${code.value}; + | currentRow = $row; | return; - """.stripMargin + """.stripMargin } else { - // There is no columns - s""" - | currentRow = unsafeRow; - | return; + assert(input != null) + if (input.nonEmpty) { + val colExprs = output.zipWithIndex.map { case (attr, i) => + BoundReference(i, attr.dataType, attr.nullable) + } + // generate the code to create a UnsafeRow + ctx.currentVars = input + val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false) + s""" + | ${code.code.trim} + | currentRow = ${code.value}; + | return; + """.stripMargin + } else { + // There is no columns + s""" + | currentRow = unsafeRow; + | return; """.stripMargin + } } } @@ -246,7 +293,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) builder.append(simpleString) builder.append("\n") - plan.generateTreeString(depth + 1, lastChildren :+children.isEmpty :+ true, builder) + plan.generateTreeString(depth + 2, lastChildren :+ false :+ true, builder) if (children.nonEmpty) { children.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder)) children.last.generateTreeString(depth + 1, lastChildren :+ true, builder) @@ -291,8 +338,9 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru var inputs = ArrayBuffer[SparkPlan]() val combined = plan.transform { case p if !supportCodegen(p) => - inputs += p - InputAdapter(p) + val input = apply(p) // collapse them recursively + inputs += input + InputAdapter(input) }.asInstanceOf[CodegenSupport] 
WholeStageCodegen(combined, inputs) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 23e54f344d252..9ddf9f959285d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -17,16 +17,18 @@ package org.apache.spark.sql.execution.aggregate +import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan, UnaryNode, UnsafeFixedWidthAggregationMap} import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{DecimalType, StructType} +import org.apache.spark.unsafe.KVIterator case class TungstenAggregate( requiredChildDistributionExpressions: Option[Seq[Expression]], @@ -80,6 +82,7 @@ case class TungstenAggregate( val dataSize = longMetric("dataSize") val spillSize = longMetric("spillSize") + assert(modes.contains(Final) || !sqlContext.conf.wholeStageEnabled) child.execute().mapPartitions { iter => val hasInput = iter.hasNext @@ -114,20 +117,39 @@ case class TungstenAggregate( } } + private val modes = aggregateExpressions.map(_.mode).distinct + override def supportCodegen: Boolean = { - groupingExpressions.isEmpty && - // ImperativeAggregate is not supported right now - !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate]) && - // final aggregation only have one row, do not need to codegen - !aggregateExpressions.exists(e => e.mode == Final || e.mode == Complete) + // ImperativeAggregate is not supported right now + !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate]) && + // final aggregation without grouping keys only have one row, do not need to codegen + (!(modes.contains(Final) || modes.contains(Complete))) + } + + override def upstream(): RDD[InternalRow] = { + child.asInstanceOf[CodegenSupport].upstream() + } + + protected override def doProduce(ctx: CodegenContext): String = { + if (groupingExpressions.isEmpty) { + doProduceWithoutKeys(ctx) + } else { + doProduceWithKeys(ctx) + } + } + + override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { + if (groupingExpressions.isEmpty) { + doConsumeWithoutKeys(ctx, child, input) + } else { + doConsumeWithKeys(ctx, child, input) + } } // The variables used as aggregation buffer private var bufVars: Seq[ExprCode] = _ - private val modes = aggregateExpressions.map(_.mode).distinct - - protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { + private def doProduceWithoutKeys(ctx: CodegenContext): String = { val initAgg = ctx.freshName("initAgg") ctx.addMutableState("boolean", initAgg, s"$initAgg = false;") @@ -146,26 +168,26 @@ case class TungstenAggregate( ExprCode(ev.code + initVars, isNull, value) } - val (rdd, childSource) = child.asInstanceOf[CodegenSupport].produce(ctx, this) - val 
source = - s""" - | if (!$initAgg) { - | $initAgg = true; - | - | // initialize aggregation buffer - | ${bufVars.map(_.code).mkString("\n")} - | - | $childSource - | - | // output the result - | ${consume(ctx, bufVars)} - | } - """.stripMargin - - (rdd, source) + val childSource = child.asInstanceOf[CodegenSupport].produce(ctx, this) + s""" + | if (!$initAgg) { + | $initAgg = true; + | + | // initialize aggregation buffer + | ${bufVars.map(_.code).mkString("\n")} + | + | $childSource + | + | // output the result + | ${consume(ctx, this, bufVars)} + | } + """.stripMargin } - override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { + private def doConsumeWithoutKeys( + ctx: CodegenContext, + child: SparkPlan, + input: Seq[ExprCode]): String = { // only have DeclarativeAggregate val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) // the mode could be only Partial or PartialMerge @@ -194,6 +216,145 @@ case class TungstenAggregate( """.stripMargin } + + def addOjb(ctx: CodegenContext, name: String, obj: Any, className: String = null): String = { + val term = ctx.freshName(name) + val idx = ctx.references.length + ctx.references += obj + val clsName = if (className == null) obj.getClass.getName else className + ctx.addMutableState(clsName, term, s"this.$term = ($clsName) references[$idx];") + term + } + + // The name for HashMap + var hashMapTerm: String = _ + + private def doProduceWithKeys(ctx: CodegenContext): String = { + val initAgg = ctx.freshName("initAgg") + ctx.addMutableState("boolean", initAgg, s"$initAgg = false;") + + // create initialized aggregate buffer + val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) + val initExpr = functions.flatMap(f => f.initialValues) + val initialBuffer = UnsafeProjection.create(initExpr)(EmptyRow) + + // create hashMap + val groupingAttributes = groupingExpressions.map(_.toAttribute) + val groupingKeySchema = StructType.fromAttributes(groupingAttributes) + val bufferAttributes = functions.flatMap(_.aggBufferAttributes) + val bufferSchema = StructType.fromAttributes(bufferAttributes) + val hashMap = new UnsafeFixedWidthAggregationMap( + initialBuffer, + bufferSchema, + groupingKeySchema, + TaskContext.get().taskMemoryManager(), + 1024 * 16, // initial capacity + TaskContext.get().taskMemoryManager().pageSizeBytes, + false // disable tracking of performance metrics + ) + hashMapTerm = addOjb(ctx, "hashMap", hashMap) + + val iterTerm = ctx.freshName("mapIter") + ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, iterTerm, "") + + val unsafeRowJoiner = GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema) + val joinerTerm = addOjb(ctx, "unsafeRowJoiner", unsafeRowJoiner, + classOf[UnsafeRowJoiner].getName) + val resultRow = ctx.freshName("resultRow") + + val childSource = child.asInstanceOf[CodegenSupport].produce(ctx, this) + s""" + | if (!$initAgg) { + | $initAgg = true; + | + | $childSource + | + | $iterTerm = $hashMapTerm.iterator(); + | } + | + | // output the result + | while ($iterTerm.next()) { + | UnsafeRow $resultRow = + | $joinerTerm.join((UnsafeRow) $iterTerm.getKey(), (UnsafeRow) $iterTerm.getValue()); + | ${consume(ctx, this, null, resultRow)} + | } + | $hashMapTerm.free(); + """.stripMargin + } + + private def doConsumeWithKeys( + ctx: CodegenContext, + child: SparkPlan, + input: Seq[ExprCode]): String = { + + // create grouping key + ctx.currentVars = input + val keyCode = 
GenerateUnsafeProjection.createCode( + ctx, groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output))) + val key = keyCode.value + val buffer = ctx.freshName("aggBuffer") + + // only have DeclarativeAggregate + val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) + // the model could be only Partial or PartialMerge + val updateExpr = if (modes.contains(Partial)) { + functions.flatMap(_.updateExpressions) + } else { + functions.flatMap(_.mergeExpressions) + } + + val inputAttr = functions.flatMap(_.aggBufferAttributes) ++ child.output + val boundExpr = updateExpr.map(e => BindReferences.bindReference(e, inputAttr)) + ctx.currentVars = new Array[ExprCode](groupingExpressions.length) ++ input + ctx.INPUT_ROW = buffer + // TODO: support subexpression elimination + val evals = boundExpr.map(_.gen(ctx)) + val updates = evals.zipWithIndex.map { case (ev, i) => + val dt = updateExpr(i).dataType + if (updateExpr(i).nullable) { + if (dt.isInstanceOf[DecimalType]) { + s""" + | if (!${ev.isNull}) { + | ${ctx.setColumn(buffer, dt, i, ev.value)}; + | } else { + | ${ctx.setColumn(buffer, dt, i, "null")}; + | } + """.stripMargin + } else { + s""" + | if (!${ev.isNull}) { + | ${ctx.setColumn(buffer, dt, i, ev.value)}; + | } else { + | $buffer.setNullAt($i); + | } + """.stripMargin + } + } else { + s""" + | ${ctx.setColumn(buffer, dt, i, ev.value)}; + """.stripMargin + } + } + + s""" + | // Aggregate + | + | // generate grouping key + | ${keyCode.code} + | UnsafeRow $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key); + | if ($buffer == null) { + | // failed to allocate the first page + | throw new OutOfMemoryError("No enough memory for aggregation"); + | } + | + | // evaluate aggregate function + | ${evals.map(_.code).mkString("\n")} + | + | // update aggregate buffer + | ${updates.mkString("\n")} + """.stripMargin + } + override def simpleString: String = { val allAggregateExpressions = aggregateExpressions diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 6deb72adad5ec..cd1ab709cab39 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -37,7 +37,11 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) override def output: Seq[Attribute] = projectList.map(_.toAttribute) - protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { + override def upstream(): RDD[InternalRow] = { + child.asInstanceOf[CodegenSupport].upstream() + } + + protected override def doProduce(ctx: CodegenContext): String = { child.asInstanceOf[CodegenSupport].produce(ctx, this) } @@ -49,7 +53,7 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) s""" | ${output.map(_.code).mkString("\n")} | - | ${consume(ctx, output)} + | ${consume(ctx, this, output)} """.stripMargin } @@ -76,7 +80,11 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode wit "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"), "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { + override def upstream(): RDD[InternalRow] = { + child.asInstanceOf[CodegenSupport].upstream() + } + + protected 
override def doProduce(ctx: CodegenContext): String = { child.asInstanceOf[CodegenSupport].produce(ctx, this) } @@ -88,7 +96,7 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode wit s""" | ${eval.code} | if (!${eval.isNull} && ${eval.value}) { - | ${consume(ctx, ctx.currentVars)} + | ${consume(ctx, this, ctx.currentVars)} | } """.stripMargin } @@ -153,7 +161,12 @@ case class Range( output: Seq[Attribute]) extends LeafNode with CodegenSupport { - protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { + override def upstream(): RDD[InternalRow] = { + sqlContext.sparkContext.parallelize(0 until numSlices, numSlices) + .map(i => InternalRow(i)) + } + + protected override def doProduce(ctx: CodegenContext): String = { val initTerm = ctx.freshName("range_initRange") ctx.addMutableState("boolean", initTerm, s"$initTerm = false;") val partitionEnd = ctx.freshName("range_partitionEnd") @@ -172,10 +185,7 @@ case class Range( s"$number > $partitionEnd" } - val rdd = sqlContext.sparkContext.parallelize(0 until numSlices, numSlices) - .map(i => InternalRow(i)) - - val code = s""" + s""" | // initialize Range | if (!$initTerm) { | $initTerm = true; @@ -215,15 +225,9 @@ case class Range( | if ($number < $value ^ ${step}L < 0) { | $overflow = true; | } - | ${consume(ctx, Seq(ev))} + | ${consume(ctx, this, Seq(ev))} | } """.stripMargin - - (rdd, code) - } - - def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { - throw new UnsupportedOperationException } protected override def doExecute(): RDD[InternalRow] = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala index c4aad398bfa54..0937c5a1c92ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -18,7 +18,12 @@ package org.apache.spark.sql.execution import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager} import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.unsafe.Platform +import org.apache.spark.unsafe.hash.Murmur3_x86_32 +import org.apache.spark.unsafe.map.BytesToBytesMap import org.apache.spark.util.Benchmark /** @@ -27,34 +32,124 @@ import org.apache.spark.util.Benchmark * build/sbt "sql/test-only *BenchmarkWholeStageCodegen" */ class BenchmarkWholeStageCodegen extends SparkFunSuite { - def testWholeStage(values: Int): Unit = { - val conf = new SparkConf().setMaster("local[1]").setAppName("benchmark") - val sc = SparkContext.getOrCreate(conf) - val sqlContext = SQLContext.getOrCreate(sc) + lazy val conf = new SparkConf().setMaster("local[1]").setAppName("benchmark") + lazy val sc = SparkContext.getOrCreate(conf) + lazy val sqlContext = SQLContext.getOrCreate(sc) - val benchmark = new Benchmark("Single Int Column Scan", values) + def testWholeStage(values: Int): Unit = { + val benchmark = new Benchmark("rang/filter/aggregate", values) - benchmark.addCase("Without whole stage codegen") { iter => + benchmark.addCase("Without codegen") { iter => sqlContext.setConf("spark.sql.codegen.wholeStage", "false") sqlContext.range(values).filter("(id & 1) = 1").count() } - benchmark.addCase("With whole stage codegen") { iter => + 
benchmark.addCase("With codegen") { iter => sqlContext.setConf("spark.sql.codegen.wholeStage", "true") sqlContext.range(values).filter("(id & 1) = 1").count() } /* Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - Single Int Column Scan: Avg Time(ms) Avg Rate(M/s) Relative Rate + rang/filter/aggregate: Avg Time(ms) Avg Rate(M/s) Relative Rate ------------------------------------------------------------------------------- - Without whole stage codegen 7775.53 26.97 1.00 X - With whole stage codegen 342.15 612.94 22.73 X + Without codegen 7775.53 26.97 1.00 X + With codegen 342.15 612.94 22.73 X */ benchmark.run() } - ignore("benchmark") { - testWholeStage(1024 * 1024 * 200) + def testAggregateWithKey(values: Int): Unit = { + val benchmark = new Benchmark("Aggregate with keys", values) + + benchmark.addCase("Aggregate w/o codegen") { iter => + sqlContext.setConf("spark.sql.codegen.wholeStage", "false") + sqlContext.range(values).selectExpr("(id & 65535) as k").groupBy("k").count().count() + } + benchmark.addCase(s"Aggregate w codegen") { iter => + sqlContext.setConf("spark.sql.codegen.wholeStage", "true") + sqlContext.range(values).selectExpr("(id & 65535) as k").groupBy("k").count().count() + } + + /* + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + Aggregate with keys: Avg Time(ms) Avg Rate(M/s) Relative Rate + ------------------------------------------------------------------------------- + Aggregate w/o codegen 8271.75 6.34 1.00 X + Aggregate w codegen 5066.57 10.35 1.63 X + */ + benchmark.run() + } + + def testBytesToBytesMap(values: Int): Unit = { + val benchmark = new Benchmark("BytesToBytesMap", values) + + benchmark.addCase("hash") { iter => + var i = 0 + val keyBytes = new Array[Byte](16) + val valueBytes = new Array[Byte](16) + val key = new UnsafeRow(1) + key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) + val value = new UnsafeRow(2) + value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) + var s = 0 + while (i < values) { + key.setInt(0, i % 1000) + val h = Murmur3_x86_32.hashUnsafeWords( + key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, 0) + s += h + i += 1 + } + } + + Seq("off", "on").foreach { heap => + benchmark.addCase(s"BytesToBytesMap ($heap Heap)") { iter => + val taskMemoryManager = new TaskMemoryManager( + new StaticMemoryManager( + new SparkConf().set("spark.memory.offHeap.enabled", s"${heap == "off"}") + .set("spark.memory.offHeap.size", "102400000"), + Long.MaxValue, + Long.MaxValue, + 1), + 0) + val map = new BytesToBytesMap(taskMemoryManager, 1024, 64L<<20) + val keyBytes = new Array[Byte](16) + val valueBytes = new Array[Byte](16) + val key = new UnsafeRow(1) + key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) + val value = new UnsafeRow(2) + value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) + var i = 0 + while (i < values) { + key.setInt(0, i % 65536) + val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes) + if (loc.isDefined) { + value.pointTo(loc.getValueAddress.getBaseObject, loc.getValueAddress.getBaseOffset, + loc.getValueLength) + value.setInt(0, value.getInt(0) + 1) + i += 1 + } else { + loc.putNewKey(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, + value.getBaseObject, value.getBaseOffset, value.getSizeInBytes) + } + } + } + } + + /** + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + Aggregate with keys: Avg Time(ms) Avg Rate(M/s) Relative Rate + ------------------------------------------------------------------------------- + hash 662.06 79.19 1.00 X + BytesToBytesMap (off Heap) 2209.42 23.73 0.30 X + 
BytesToBytesMap (on Heap) 2957.68 17.73 0.22 X + */ + benchmark.run() + } + + test("benchmark") { + // testWholeStage(1024 * 1024 * 200) + // testAggregateWithKey(1024 * 1024 * 50) + // testBytesToBytesMap(1024 * 1024 * 50) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 300788c88ab2f..8adaeb69e6dde 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -47,4 +47,14 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[TungstenAggregate]).isDefined) assert(df.collect() === Array(Row(9, 4.5))) } + + test("Aggregate with grouping keys should be included in WholeStageCodegen") { + val df = sqlContext.range(3).groupBy("id").count().orderBy("id") + df.explain() + val plan = df.queryExecution.executedPlan + assert(plan.find(p => + p.isInstanceOf[WholeStageCodegen] && + p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[TungstenAggregate]).isDefined) + assert(df.collect() === Array(Row(0, 1), Row(1, 1), Row(2, 1))) + } } From 2f1a0821f6f6dc7b968f04fdf2ddfc93ec6b6677 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 20 Jan 2016 16:58:40 -0800 Subject: [PATCH 02/18] support Final aggregate --- .../aggregate/TungstenAggregate.scala | 72 ++++++++++++++++--- 1 file changed, 61 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 9ddf9f959285d..17fe039b7c6cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -121,9 +121,7 @@ case class TungstenAggregate( override def supportCodegen: Boolean = { // ImperativeAggregate is not supported right now - !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate]) && - // final aggregation without grouping keys only have one row, do not need to codegen - (!(modes.contains(Final) || modes.contains(Complete))) + !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate]) } override def upstream(): RDD[InternalRow] = { @@ -168,6 +166,17 @@ case class TungstenAggregate( ExprCode(ev.code + initVars, isNull, value) } + // generate variables for output + val resultVars = if (modes.contains(Final)) { + ctx.currentVars = bufVars + val bufferAttrs = functions.flatMap(_.aggBufferAttributes) + resultExpressions.map { e => + BindReferences.bindReference(e, bufferAttrs).gen(ctx) + } + } else { + bufVars.map(ev => ExprCode("", ev.isNull, ev.value)) + } + val childSource = child.asInstanceOf[CodegenSupport].produce(ctx, this) s""" | if (!$initAgg) { @@ -179,7 +188,8 @@ case class TungstenAggregate( | $childSource | | // output the result - | ${consume(ctx, this, bufVars)} + | ${resultVars.map(_.code).mkString("\n")} + | ${consume(ctx, this, resultVars)} | } """.stripMargin } @@ -257,10 +267,50 @@ case class TungstenAggregate( val iterTerm = ctx.freshName("mapIter") ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, iterTerm, "") - val unsafeRowJoiner = GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema) - val joinerTerm = 
addOjb(ctx, "unsafeRowJoiner", unsafeRowJoiner, - classOf[UnsafeRowJoiner].getName) - val resultRow = ctx.freshName("resultRow") + // generate code for output + val keyTerm = ctx.freshName("aggKey") + val bufferTerm = ctx.freshName("aggBuffer") + val outputCode = if (modes.contains(Final)) { + // generate output using resultExpressions + ctx.currentVars = null + ctx.INPUT_ROW = keyTerm + val keyVars = groupingExpressions.zipWithIndex.map { case (e, i) => + BoundReference(i, e.dataType, e.nullable).gen(ctx) + } + ctx.INPUT_ROW = bufferTerm + val bufferVars = bufferAttributes.zipWithIndex.map { case (e, i) => + BoundReference(i, e.dataType, e.nullable).gen(ctx) + + } + ctx.currentVars = keyVars ++ bufferVars + val inputAttrs = groupingAttributes ++ bufferAttributes + val resultVars = resultExpressions.map { e => + BindReferences.bindReference(e, inputAttrs).gen(ctx) + } + s""" + | ${keyVars.map(_.code).mkString("\n")} + | ${bufferVars.map(_.code).mkString("\n")} + | ${resultVars.map(_.code).mkString("\n")} + | ${consume(ctx, this, resultVars)} + """.stripMargin + + } else if (modes.contains(Partial) || modes.contains(PartialMerge)) { + // This should be the last operator in a stage, we should output UnsafeRow directly + val unsafeRowJoiner = GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema) + val joinerTerm = addOjb(ctx, "unsafeRowJoiner", unsafeRowJoiner, + classOf[UnsafeRowJoiner].getName) + val resultRow = ctx.freshName("resultRow") + s""" + | UnsafeRow $resultRow = $joinerTerm.join($keyTerm, $bufferTerm); + | ${consume(ctx, this, null, resultRow)} + """.stripMargin + + } else { + // only grouping key + s""" + | ${consume(ctx, this, null, keyTerm)} + """.stripMargin + } val childSource = child.asInstanceOf[CodegenSupport].produce(ctx, this) s""" @@ -274,9 +324,9 @@ case class TungstenAggregate( | | // output the result | while ($iterTerm.next()) { - | UnsafeRow $resultRow = - | $joinerTerm.join((UnsafeRow) $iterTerm.getKey(), (UnsafeRow) $iterTerm.getValue()); - | ${consume(ctx, this, null, resultRow)} + | UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey(); + | UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getKey(); + | $outputCode | } | $hashMapTerm.free(); """.stripMargin From 7d1bd43aafd7c38120b9508830e7a22db11371b4 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 20 Jan 2016 18:57:22 -0800 Subject: [PATCH 03/18] fix tests --- .../sql/catalyst/expressions/UnsafeRow.java | 4 +- .../sql/execution/WholeStageCodegen.scala | 55 ++++++++------ .../aggregate/TungstenAggregate.scala | 76 +++++++++++++------ .../org/apache/spark/sql/SQLQuerySuite.scala | 7 +- 4 files changed, 91 insertions(+), 51 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 1a351933a366c..69eacc4a3d8d0 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -599,7 +599,9 @@ public String toString() { build.append(java.lang.Long.toHexString(Platform.getLong(baseObject, baseOffset + i))); build.append(','); } - build.deleteCharAt(build.length() - 1); + if (sizeInBytes > 0) { + build.deleteCharAt(build.length() - 1); + } build.append(']'); return build.toString(); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index 89c4d911c3307..4184aa28daaa1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -74,34 +74,41 @@ trait CodegenSupport extends SparkPlan { protected def doProduce(ctx: CodegenContext): String /** - * Consume the columns generated from current SparkPlan, call it's parent or create an iterator. + * Consume the columns generated from current SparkPlan, call it's parent. */ def consume( ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode], row: String = null): String = { - if (child eq this) { - // This is called by itself, pass to it's parent - if (input != null) { - assert(input.length == output.length) + // This is called by itself, pass to it's parent + if (input != null) { + assert(input.length == output.length) + } + parent.consumeChild(ctx, this, input, row) + } + + /** + * Consume the columns generated from it's child, call doConsume() or emit the rows. + */ + def consumeChild( + ctx: CodegenContext, + child: SparkPlan, + input: Seq[ExprCode], + row: String = null): String = { + // This is called by child + if (row != null) { + ctx.currentVars = null + ctx.INPUT_ROW = row + val evals = child.output.zipWithIndex.map { case (attr, i) => + BoundReference(i, attr.dataType, attr.nullable).gen(ctx) } - parent.consume(ctx, this, input, row) + s""" + | ${evals.map(_.code).mkString("\n")} + | ${doConsume(ctx, child, evals)} + """.stripMargin } else { - // This is called by child - if (row != null) { - ctx.currentVars = null - ctx.INPUT_ROW = row - val evals = child.output.zipWithIndex.map { case (attr, i) => - BoundReference(i, attr.dataType, attr.nullable).gen(ctx) - } - s""" - | ${evals.map(_.code).mkString("\n")} - | ${doConsume(ctx, child, evals)} - """.stripMargin - } else { - doConsume(ctx, child, input) - } + doConsume(ctx, child, input) } } @@ -179,9 +186,11 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport { * doProduce() ---> execute() * | * consume() - * doConsume() ------------| + * consumeChild() <-----------| + * | + * doConsume() * | - * doConsume() <----- consume() + * consumeChild() <----- consume() * * SparkPlan A should override doProduce() and doConsume(). 
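 *
 * As a rough sketch (simplified from the Filter operator elsewhere in this series), an
 * operator that only transforms the rows of its child implements the two hooks roughly as
 * follows; `boundCondition` stands in for whatever bound predicate or expressions the
 * operator evaluates:
 *
 * {{{
 *   override def doProduce(ctx: CodegenContext): String =
 *     // ask the child to produce rows; the child will call back into this.consume()
 *     child.asInstanceOf[CodegenSupport].produce(ctx, this)
 *
 *   override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = {
 *     val eval = boundCondition.gen(ctx)
 *     s"""
 *        | ${eval.code}
 *        | if (!${eval.isNull} && ${eval.value}) {
 *        |   ${consume(ctx, input)}   // hand the surviving columns to the parent
 *        | }
 *      """.stripMargin
 *   }
 * }}}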
* @@ -240,7 +249,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) throw new UnsupportedOperationException } - override def consume( + override def consumeChild( ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 17fe039b7c6cd..c7bb1985f4b84 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -82,7 +82,6 @@ case class TungstenAggregate( val dataSize = longMetric("dataSize") val spillSize = longMetric("spillSize") - assert(modes.contains(Final) || !sqlContext.conf.wholeStageEnabled) child.execute().mapPartitions { iter => val hasInput = iter.hasNext @@ -167,14 +166,23 @@ case class TungstenAggregate( } // generate variables for output - val resultVars = if (modes.contains(Final)) { + val (resultVars, genResult) = if (modes.contains(Final)) { ctx.currentVars = bufVars val bufferAttrs = functions.flatMap(_.aggBufferAttributes) - resultExpressions.map { e => + val aggResults = functions.map(_.evaluateExpression).map { e => BindReferences.bindReference(e, bufferAttrs).gen(ctx) } + ctx.currentVars = aggResults + val resultVars = resultExpressions.map { e => + BindReferences.bindReference(e, aggregateAttributes).gen(ctx) + } + (resultVars, s""" + | ${aggResults.map(_.code).mkString("\n")} + | ${resultVars.map(_.code).mkString("\n")} + """.stripMargin) } else { - bufVars.map(ev => ExprCode("", ev.isNull, ev.value)) + // output the aggregate buffer directly + (bufVars, "") } val childSource = child.asInstanceOf[CodegenSupport].produce(ctx, this) @@ -188,7 +196,8 @@ case class TungstenAggregate( | $childSource | | // output the result - | ${resultVars.map(_.code).mkString("\n")} + | $genResult + | | ${consume(ctx, this, resultVars)} | } """.stripMargin @@ -200,19 +209,20 @@ case class TungstenAggregate( input: Seq[ExprCode]): String = { // only have DeclarativeAggregate val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) - // the mode could be only Partial or PartialMerge - val updateExpr = if (modes.contains(Partial)) { - functions.flatMap(_.updateExpressions) - } else { - functions.flatMap(_.mergeExpressions) + val updateExpr = aggregateExpressions.flatMap { e => + e.mode match { + case Partial | Complete => + e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions + case PartialMerge | Final => + e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions + } } val inputAttr = functions.flatMap(_.aggBufferAttributes) ++ child.output - val boundExpr = updateExpr.map(e => BindReferences.bindReference(e, inputAttr)) ctx.currentVars = bufVars ++ input // TODO: support subexpression elimination - val codes = boundExpr.zipWithIndex.map { case (e, i) => - val ev = e.gen(ctx) + val codes = updateExpr.zipWithIndex.map { case (e, i) => + val ev = BindReferences.bindReference[Expression](e, inputAttr).gen(ctx) s""" | ${ev.code} | ${bufVars(i).isNull} = ${ev.isNull}; @@ -280,17 +290,24 @@ case class TungstenAggregate( ctx.INPUT_ROW = bufferTerm val bufferVars = bufferAttributes.zipWithIndex.map { case (e, i) => BoundReference(i, e.dataType, e.nullable).gen(ctx) - } - ctx.currentVars = keyVars ++ bufferVars - val inputAttrs = groupingAttributes ++ 
bufferAttributes + // evaluate the aggregation result + ctx.currentVars = bufferVars + val aggResults = functions.map(_.evaluateExpression).map { e => + BindReferences.bindReference(e, bufferAttributes).gen(ctx) + } + // generate the final result + ctx.currentVars = keyVars ++ aggResults + val inputAttrs = groupingAttributes ++ aggregateAttributes val resultVars = resultExpressions.map { e => BindReferences.bindReference(e, inputAttrs).gen(ctx) } s""" | ${keyVars.map(_.code).mkString("\n")} | ${bufferVars.map(_.code).mkString("\n")} + | ${aggResults.map(_.code).mkString("\n")} | ${resultVars.map(_.code).mkString("\n")} + | | ${consume(ctx, this, resultVars)} """.stripMargin @@ -307,8 +324,14 @@ case class TungstenAggregate( } else { // only grouping key + ctx.INPUT_ROW = keyTerm + ctx.currentVars = null + val eval = resultExpressions.map{ e => + BindReferences.bindReference(e, groupingAttributes).gen(ctx) + } s""" - | ${consume(ctx, this, null, keyTerm)} + | ${eval.map(_.code).mkString("\n")} + | ${consume(ctx, this, eval)} """.stripMargin } @@ -325,7 +348,7 @@ case class TungstenAggregate( | // output the result | while ($iterTerm.next()) { | UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey(); - | UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getKey(); + | UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue(); | $outputCode | } | $hashMapTerm.free(); @@ -346,16 +369,19 @@ case class TungstenAggregate( // only have DeclarativeAggregate val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) - // the model could be only Partial or PartialMerge - val updateExpr = if (modes.contains(Partial)) { - functions.flatMap(_.updateExpressions) - } else { - functions.flatMap(_.mergeExpressions) + val updateExpr = aggregateExpressions.flatMap { e => + e.mode match { + case Partial | Complete => + e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions + case PartialMerge | Final => + e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions + } } - val inputAttr = functions.flatMap(_.aggBufferAttributes) ++ child.output + val bufferAttrs = functions.flatMap(_.aggBufferAttributes) + val inputAttr = bufferAttrs ++ child.output val boundExpr = updateExpr.map(e => BindReferences.bindReference(e, inputAttr)) - ctx.currentVars = new Array[ExprCode](groupingExpressions.length) ++ input + ctx.currentVars = new Array[ExprCode](bufferAttrs.length) ++ input ctx.INPUT_ROW = buffer // TODO: support subexpression elimination val evals = boundExpr.map(_.gen(ctx)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 47308966e92cb..21f11271c9277 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1973,8 +1973,11 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { countAcc.++=(1) x }) - verifyCallCount( - df.groupBy().agg(sum(testUdf($"b") + testUdf($"b") + testUdf($"b"))), Row(3.0), 1) + //TODO: support subexpression elimination in whole stage codegen + withSQLConf("spark.sql.codegen.wholeStage" -> "false") { + verifyCallCount( + df.groupBy().agg(sum(testUdf($"b") + testUdf($"b") + testUdf($"b"))), Row(3.0), 1) + } // Would be nice if semantic equals for `+` understood commutative verifyCallCount( From 788078668795458aa29a55d18e2b23686992df8d Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 20 Jan 2016 22:05:57 -0800 Subject: [PATCH 
04/18] cleanup --- .../expressions/codegen/CodeGenerator.scala | 14 ++++++ .../sql/execution/WholeStageCodegen.scala | 10 +--- .../aggregate/TungstenAggregate.scala | 48 ++++++++----------- .../spark/sql/execution/basicOperators.scala | 6 +-- .../org/apache/spark/sql/SQLQuerySuite.scala | 2 +- .../execution/WholeStageCodegenSuite.scala | 1 - 6 files changed, 39 insertions(+), 42 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 2747c315ad374..455778193a475 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -55,6 +55,20 @@ class CodegenContext { */ val references: mutable.ArrayBuffer[Any] = new mutable.ArrayBuffer[Any]() + /** + * Add an object to `references`, create a class member to access it. + * + * Returns the name of class member. + */ + def addReferenceObj(name: String, obj: Any, className: String = null): String = { + val term = freshName(name) + val idx = references.length + references += obj + val clsName = Option(className).getOrElse(obj.getClass.getName) + addMutableState(clsName, term, s"this.$term = ($clsName) references[$idx];") + term + } + /** * Holding a list of generated columns as input of current operator, will be used by * BoundReference to generate code. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index 4184aa28daaa1..3bb560fa07d7c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -76,12 +76,7 @@ trait CodegenSupport extends SparkPlan { /** * Consume the columns generated from current SparkPlan, call it's parent. 
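 * For example, Project calls consume(ctx, output) once it has generated code for its
 * projection expressions; consume() checks the columns against this plan's output and then
 * hands them to the parent's consumeChild().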
*/ - def consume( - ctx: CodegenContext, - child: SparkPlan, - input: Seq[ExprCode], - row: String = null): String = { - // This is called by itself, pass to it's parent + def consume(ctx: CodegenContext, input: Seq[ExprCode], row: String = null): String = { if (input != null) { assert(input.length == output.length) } @@ -96,7 +91,6 @@ trait CodegenSupport extends SparkPlan { child: SparkPlan, input: Seq[ExprCode], row: String = null): String = { - // This is called by child if (row != null) { ctx.currentVars = null ctx.INPUT_ROW = row @@ -156,7 +150,7 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport { | while (input.hasNext()) { | InternalRow $row = (InternalRow) input.next(); | ${columns.map(_.code).mkString("\n")} - | ${consume(ctx, this, columns)} + | ${consume(ctx, columns)} | } """.stripMargin } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index c7bb1985f4b84..d1bcfe5b50f5f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -116,6 +116,7 @@ case class TungstenAggregate( } } + // all the mode of aggregate expressions private val modes = aggregateExpressions.map(_.mode).distinct override def supportCodegen: Boolean = { @@ -166,12 +167,14 @@ case class TungstenAggregate( } // generate variables for output - val (resultVars, genResult) = if (modes.contains(Final)) { + val (resultVars, genResult) = if (modes.contains(Final) || modes.contains(Complete)) { + // evaluate aggregate results ctx.currentVars = bufVars val bufferAttrs = functions.flatMap(_.aggBufferAttributes) val aggResults = functions.map(_.evaluateExpression).map { e => BindReferences.bindReference(e, bufferAttrs).gen(ctx) } + // evaluate result expressions ctx.currentVars = aggResults val resultVars = resultExpressions.map { e => BindReferences.bindReference(e, aggregateAttributes).gen(ctx) @@ -185,7 +188,6 @@ case class TungstenAggregate( (bufVars, "") } - val childSource = child.asInstanceOf[CodegenSupport].produce(ctx, this) s""" | if (!$initAgg) { | $initAgg = true; @@ -193,12 +195,12 @@ case class TungstenAggregate( | // initialize aggregation buffer | ${bufVars.map(_.code).mkString("\n")} | - | $childSource + | ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} | | // output the result | $genResult | - | ${consume(ctx, this, resultVars)} + | ${consume(ctx, resultVars)} | } """.stripMargin } @@ -209,6 +211,7 @@ case class TungstenAggregate( input: Seq[ExprCode]): String = { // only have DeclarativeAggregate val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) + val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ child.output val updateExpr = aggregateExpressions.flatMap { e => e.mode match { case Partial | Complete => @@ -217,12 +220,10 @@ case class TungstenAggregate( e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions } } - - val inputAttr = functions.flatMap(_.aggBufferAttributes) ++ child.output ctx.currentVars = bufVars ++ input // TODO: support subexpression elimination - val codes = updateExpr.zipWithIndex.map { case (e, i) => - val ev = BindReferences.bindReference[Expression](e, inputAttr).gen(ctx) + val updates = updateExpr.zipWithIndex.map { case (e, i) => + val ev = 
BindReferences.bindReference[Expression](e, inputAttrs).gen(ctx) s""" | ${ev.code} | ${bufVars(i).isNull} = ${ev.isNull}; @@ -232,20 +233,10 @@ case class TungstenAggregate( s""" | // do aggregate and update aggregation buffer - | ${codes.mkString("")} + | ${updates.mkString("")} """.stripMargin } - - def addOjb(ctx: CodegenContext, name: String, obj: Any, className: String = null): String = { - val term = ctx.freshName(name) - val idx = ctx.references.length - ctx.references += obj - val clsName = if (className == null) obj.getClass.getName else className - ctx.addMutableState(clsName, term, s"this.$term = ($clsName) references[$idx];") - term - } - // The name for HashMap var hashMapTerm: String = _ @@ -272,15 +263,16 @@ case class TungstenAggregate( TaskContext.get().taskMemoryManager().pageSizeBytes, false // disable tracking of performance metrics ) - hashMapTerm = addOjb(ctx, "hashMap", hashMap) + hashMapTerm = ctx.addReferenceObj("hashMap", hashMap) + // Create a name for iterator from HashMap val iterTerm = ctx.freshName("mapIter") ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, iterTerm, "") // generate code for output val keyTerm = ctx.freshName("aggKey") val bufferTerm = ctx.freshName("aggBuffer") - val outputCode = if (modes.contains(Final)) { + val outputCode = if (modes.contains(Final) || modes.contains(Complete)) { // generate output using resultExpressions ctx.currentVars = null ctx.INPUT_ROW = keyTerm @@ -308,18 +300,18 @@ case class TungstenAggregate( | ${aggResults.map(_.code).mkString("\n")} | ${resultVars.map(_.code).mkString("\n")} | - | ${consume(ctx, this, resultVars)} + | ${consume(ctx, resultVars)} """.stripMargin } else if (modes.contains(Partial) || modes.contains(PartialMerge)) { // This should be the last operator in a stage, we should output UnsafeRow directly val unsafeRowJoiner = GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema) - val joinerTerm = addOjb(ctx, "unsafeRowJoiner", unsafeRowJoiner, + val joinerTerm = ctx.addReferenceObj("unsafeRowJoiner", unsafeRowJoiner, classOf[UnsafeRowJoiner].getName) val resultRow = ctx.freshName("resultRow") s""" | UnsafeRow $resultRow = $joinerTerm.join($keyTerm, $bufferTerm); - | ${consume(ctx, this, null, resultRow)} + | ${consume(ctx, null, resultRow)} """.stripMargin } else { @@ -331,16 +323,15 @@ case class TungstenAggregate( } s""" | ${eval.map(_.code).mkString("\n")} - | ${consume(ctx, this, eval)} + | ${consume(ctx, eval)} """.stripMargin } - val childSource = child.asInstanceOf[CodegenSupport].produce(ctx, this) s""" | if (!$initAgg) { | $initAgg = true; | - | $childSource + | ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} | | $iterTerm = $hashMapTerm.iterator(); | } @@ -351,6 +342,7 @@ case class TungstenAggregate( | UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue(); | $outputCode | } + | | $hashMapTerm.free(); """.stripMargin } @@ -413,8 +405,6 @@ case class TungstenAggregate( } s""" - | // Aggregate - | | // generate grouping key | ${keyCode.code} | UnsafeRow $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index cd1ab709cab39..f6a142872f72f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -53,7 +53,7 @@ case class Project(projectList: 
Seq[NamedExpression], child: SparkPlan) s""" | ${output.map(_.code).mkString("\n")} | - | ${consume(ctx, this, output)} + | ${consume(ctx, output)} """.stripMargin } @@ -96,7 +96,7 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode wit s""" | ${eval.code} | if (!${eval.isNull} && ${eval.value}) { - | ${consume(ctx, this, ctx.currentVars)} + | ${consume(ctx, ctx.currentVars)} | } """.stripMargin } @@ -225,7 +225,7 @@ case class Range( | if ($number < $value ^ ${step}L < 0) { | $overflow = true; | } - | ${consume(ctx, this, Seq(ev))} + | ${consume(ctx, Seq(ev))} | } """.stripMargin } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 21f11271c9277..d2d271ac93d26 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1973,7 +1973,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { countAcc.++=(1) x }) - //TODO: support subexpression elimination in whole stage codegen + // TODO: support subexpression elimination in whole stage codegen withSQLConf("spark.sql.codegen.wholeStage" -> "false") { verifyCallCount( df.groupBy().agg(sum(testUdf($"b") + testUdf($"b") + testUdf($"b"))), Row(3.0), 1) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 8adaeb69e6dde..c2516509dfbbf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -50,7 +50,6 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { test("Aggregate with grouping keys should be included in WholeStageCodegen") { val df = sqlContext.range(3).groupBy("id").count().orderBy("id") - df.explain() val plan = df.queryExecution.executedPlan assert(plan.find(p => p.isInstanceOf[WholeStageCodegen] && From 407460d15a7dc692b852c035a5473bd335f2c87b Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 21 Jan 2016 00:19:26 -0800 Subject: [PATCH 05/18] generated BroadcastHashJoin --- .../sql/execution/BufferedRowIterator.java | 13 ++- .../sql/execution/WholeStageCodegen.scala | 14 +-- .../aggregate/TungstenAggregate.scala | 2 + .../spark/sql/execution/basicOperators.scala | 2 + .../execution/joins/BroadcastHashJoin.scala | 85 ++++++++++++++++++- .../execution/WholeStageCodegenSuite.scala | 15 +++- 6 files changed, 113 insertions(+), 18 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java b/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java index 6acf70dbbad0c..96274cf9b2b9e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution; import java.io.IOException; +import java.util.LinkedList; import scala.collection.Iterator; @@ -31,22 +32,20 @@ * TODO: replaced it by batched columnar format. 
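 *
 * A generated subclass overrides processNext() and appends the rows it produces to
 * currentRows; hasNext() and next() then simply drain that buffer, so one call to
 * processNext() may emit more than one row.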
*/ public class BufferedRowIterator { - protected InternalRow currentRow; + protected LinkedList currentRows = new LinkedList<>(); protected Iterator input; // used when there is no column in output protected UnsafeRow unsafeRow = new UnsafeRow(0); public boolean hasNext() throws IOException { - if (currentRow == null) { + if (currentRows.isEmpty()) { processNext(); } - return currentRow != null; + return !currentRows.isEmpty(); } public InternalRow next() { - InternalRow r = currentRow; - currentRow = null; - return r; + return currentRows.remove(); } public void setInput(Iterator iter) { @@ -60,7 +59,7 @@ public void setInput(Iterator iter) { */ protected void processNext() throws IOException { if (input.hasNext()) { - currentRow = input.next(); + currentRows.add(input.next()); } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index 3bb560fa07d7c..39a84f3cdd334 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -151,12 +151,15 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport { | InternalRow $row = (InternalRow) input.next(); | ${columns.map(_.code).mkString("\n")} | ${consume(ctx, columns)} + | if (!currentRows.isEmpty()) { + | return; + | } | } """.stripMargin } override def doExecute(): RDD[InternalRow] = { - throw new UnsupportedOperationException + child.execute() } override def simpleString: String = "INPUT" @@ -252,8 +255,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) if (row != null) { // There is an UnsafeRow already s""" - | currentRow = $row; - | return; + | currentRows.add($row.copy()); """.stripMargin } else { assert(input != null) @@ -266,14 +268,12 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false) s""" | ${code.code.trim} - | currentRow = ${code.value}; - | return; + | currentRows.add(${code.value}.copy()); """.stripMargin } else { // There is no columns s""" - | currentRow = unsafeRow; - | return; + | currentRows.add(unsafeRow); """.stripMargin } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index d1bcfe5b50f5f..54e07250d0b90 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -341,6 +341,8 @@ case class TungstenAggregate( | UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey(); | UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue(); | $outputCode + | + | if (!currentRows.isEmpty()) return; | } | | $hashMapTerm.free(); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index f6a142872f72f..fa2f584dad9d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -226,6 +226,8 @@ case class Range( | $overflow = true; | } | ${consume(ctx, Seq(ev))} + | + | if (!currentRows.isEmpty()) return; | } """.stripMargin } diff --git 
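Buffering output in a queue lets a single call to processNext() append zero, one, or many rows, which is exactly what a join needs when one probe row matches several build rows. A simplified Scala model of that contract (not the actual BufferedRowIterator, which is the Java base class of the generated iterator):

    import scala.collection.mutable

    // Simplified contract: processNext() appends 0..n rows; hasNext() keeps pulling
    // from the input until something is buffered or the input is exhausted.
    abstract class ToyBufferedIterator[T](protected val input: Iterator[T]) {
      protected val currentRows = mutable.Queue.empty[T]

      def hasNext: Boolean = {
        while (currentRows.isEmpty && input.hasNext) processNext()
        currentRows.nonEmpty
      }

      def next(): T = currentRows.dequeue()

      protected def processNext(): Unit
    }

    // Example: an operator that emits two rows for every input row.
    class Duplicator(in: Iterator[Int]) extends ToyBufferedIterator[Int](in) {
      protected def processNext(): Unit = {
        val x = input.next()
        currentRows += x
        currentRows += x
      }
    }

    // new Duplicator(Iterator(1, 2)) yields 1, 1, 2, 2.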
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index c9ea579b5e809..bb25379cbc864 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -21,13 +21,16 @@ import scala.concurrent._ import scala.concurrent.duration._ import org.apache.spark.{InternalAccumulator, TaskContext} +import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.{BindReferences, BoundReference, Expression, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection} import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution} -import org.apache.spark.sql.execution.{BinaryNode, SparkPlan, SQLExecution} +import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SQLExecution, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.util.ThreadUtils +import org.apache.spark.util.collection.CompactBuffer /** * Performs an inner hash join of two child relations. When the output RDD of this operator is @@ -42,7 +45,7 @@ case class BroadcastHashJoin( condition: Option[Expression], left: SparkPlan, right: SparkPlan) - extends BinaryNode with HashJoin { + extends BinaryNode with HashJoin with CodegenSupport { override private[sql] lazy val metrics = Map( "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), @@ -118,6 +121,82 @@ case class BroadcastHashJoin( hashJoin(streamedIter, numStreamedRows, hashedRelation, numOutputRows) } } + + // the broadcasted hash relation + private var broadcastRelation: Broadcast[HashedRelation] = _ + // the term for hash relation + private var relationTerm: String = _ + + override def upstream(): RDD[InternalRow] = { + broadcastRelation = Await.result(broadcastFuture, timeout) + streamedPlan.asInstanceOf[CodegenSupport].upstream() + } + + override def doProduce(ctx: CodegenContext): String = { + // create a name for HashRelation + val broadcast = ctx.addReferenceObj("broadcast", broadcastRelation) + relationTerm = ctx.freshName("relation") + // TODO: create specialized HashRelation for single join key + val clsName = classOf[UnsafeHashedRelation].getName + ctx.addMutableState(clsName, relationTerm, s"$relationTerm = ($clsName) $broadcast.value();") + + s""" + | ${streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this)} + """.stripMargin + } + + override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { + ctx.currentVars = input + val keyExpr = streamedKeys.map(BindReferences.bindReference(_, streamedPlan.output)) + val keyVal = GenerateUnsafeProjection.createCode(ctx, keyExpr) + val keyTerm = keyVal.value + val anyNull = if (keyExpr.exists(_.nullable)) s"$keyTerm.anyNull()" else "false" + + val matches = ctx.freshName("matches") + val bufferType = classOf[CompactBuffer[UnsafeRow]].getName + val i = ctx.freshName("i") + val size = ctx.freshName("size") + val row = ctx.freshName("row") + + ctx.currentVars = null + ctx.INPUT_ROW = row + val buildColumns = buildPlan.output.zipWithIndex.map { case (a, i) => + BoundReference(i, a.dataType, 
a.nullable).gen(ctx) + } + val resultVars = buildSide match { + case BuildLeft => buildColumns ++ input + case BuildRight => input ++ buildColumns + } + + val ouputCode = if (condition.isDefined) { + ctx.currentVars = resultVars + val ev = BindReferences.bindReference(condition.get, this.output) + .gen(ctx) + s""" + | ${ev.code} + | if (!${ev.isNull} && ${ev.value}) { + | ${consume(ctx, resultVars)} + | } + """.stripMargin + } else { + consume(ctx, resultVars) + } + + s""" + | // generate join key + | ${keyVal.code} + | // find matches from HashRelation + | $bufferType $matches = $anyNull ? null : ($bufferType) $relationTerm.get($keyTerm); + | if ($matches != null) { + | int $size = $matches.size(); + | for (int $i = 0; $i < $size; $i++) { + | UnsafeRow $row = (UnsafeRow) $matches.apply($i); + | ${buildColumns.map(_.code).mkString("\n")} + | $ouputCode + | } + | } + """.stripMargin + } } object BroadcastHashJoin { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index c2516509dfbbf..0153778b72f6e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -20,8 +20,10 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.execution.aggregate.TungstenAggregate -import org.apache.spark.sql.functions.{avg, col, max} +import org.apache.spark.sql.execution.joins.BroadcastHashJoin +import org.apache.spark.sql.functions.{avg, broadcast, col, max} import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{StringType, IntegerType, StructType} class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { @@ -56,4 +58,15 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[TungstenAggregate]).isDefined) assert(df.collect() === Array(Row(0, 1), Row(1, 1), Row(2, 1))) } + + test("BroadcastHashJoin should be included in WholeStageCodegen") { + val rdd = sqlContext.sparkContext.makeRDD(Seq(Row(1, "1"), Row(1, "1"), Row(2, "2"))) + val schema = new StructType().add("k", IntegerType).add("v", StringType) + val smallDF = sqlContext.createDataFrame(rdd, schema) + val df = sqlContext.range(10).join(broadcast(smallDF), col("k") === col("id")) + assert(df.queryExecution.executedPlan.find(p => + p.isInstanceOf[WholeStageCodegen] && + p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[BroadcastHashJoin]).isDefined) + assert(df.collect() === Array(Row(1, 1, "1"), Row(1, 1, "1"), Row(2, 2, "2"))) + } } From ff04509e41fe2020f6aebfa9e32f3ae6e4448e56 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 22 Jan 2016 22:56:47 -0800 Subject: [PATCH 06/18] improve join --- .../spark/util/collection/CompactBuffer.scala | 2 +- .../sql/execution/WholeStageCodegen.scala | 2 +- .../execution/joins/BroadcastHashJoin.scala | 92 +++++---- .../spark/sql/execution/joins/HashJoin.scala | 35 +++- .../sql/execution/joins/HashedRelation.scala | 189 +++++++++++++++++- .../BenchmarkWholeStageCodegen.scala | 27 +++ 6 files changed, 297 insertions(+), 50 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/collection/CompactBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/CompactBuffer.scala index 4d43d8d5cc8d8..40aa0e1106735 100644 
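Stripped of the code generation, the consume path of the join is a plain probe loop: build the key from the streamed row, look it up in the broadcast relation, and emit one joined row per match that passes the optional condition. A small Scala sketch of that probe over ordinary collections, with a hypothetical row type rather than the generated Java:

    // Toy probe loop for an inner broadcast hash join; `buildSide` plays the role
    // of the broadcast HashedRelation.
    case class ToyRow(fields: Seq[Any])

    def probe(
        streamed: Iterator[ToyRow],
        buildSide: Map[Long, Seq[ToyRow]],
        streamKey: ToyRow => Option[Long],                    // None models a null join key
        condition: (ToyRow, ToyRow) => Boolean = (_, _) => true): Iterator[ToyRow] = {
      streamed.flatMap { streamRow =>
        streamKey(streamRow) match {
          case None => Iterator.empty                          // null keys never match
          case Some(key) =>
            buildSide.getOrElse(key, Seq.empty).iterator
              .filter(buildRow => condition(streamRow, buildRow))
              .map(buildRow => ToyRow(streamRow.fields ++ buildRow.fields))
        }
      }
    }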
--- a/core/src/main/scala/org/apache/spark/util/collection/CompactBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/CompactBuffer.scala @@ -51,7 +51,7 @@ private[spark] class CompactBuffer[T: ClassTag] extends Seq[T] with Serializable } } - private def update(position: Int, value: T): Unit = { + def update(position: Int, value: T): Unit = { if (position < 0 || position >= curSize) { throw new IndexOutOfBoundsException } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index 39a84f3cdd334..8fa43b6acd55f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -226,7 +226,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) } """ // try to compile, helpful for debug - // println(s"${CodeFormatter.format(source)}") + println(s"${CodeFormatter.format(source)}") val clazz = CodeGenerator.compile(source) val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index bb25379cbc864..65525bfd0d28c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -90,8 +90,13 @@ case class BroadcastHashJoin( // The following line doesn't run in a job so we cannot track the metric value. However, we // have already tracked it in the above lines. So here we can use // `SQLMetrics.nullLongMetric` to ignore it. 
- val hashed = HashedRelation( - input.iterator, SQLMetrics.nullLongMetric, buildSideKeyGenerator, input.size) + val hashed = if (canJoinKeyFitWithinLong) { + LongHashedRelation( + input.iterator, SQLMetrics.nullLongMetric, buildSideKeyGenerator, input.size) + } else { + HashedRelation( + input.iterator, SQLMetrics.nullLongMetric, buildSideKeyGenerator, input.size) + } sparkContext.broadcast(hashed) } }(BroadcastHashJoin.broadcastHashJoinExecutionContext) @@ -124,8 +129,6 @@ case class BroadcastHashJoin( // the broadcasted hash relation private var broadcastRelation: Broadcast[HashedRelation] = _ - // the term for hash relation - private var relationTerm: String = _ override def upstream(): RDD[InternalRow] = { broadcastRelation = Await.result(broadcastFuture, timeout) @@ -133,33 +136,30 @@ case class BroadcastHashJoin( } override def doProduce(ctx: CodegenContext): String = { - // create a name for HashRelation - val broadcast = ctx.addReferenceObj("broadcast", broadcastRelation) - relationTerm = ctx.freshName("relation") - // TODO: create specialized HashRelation for single join key - val clsName = classOf[UnsafeHashedRelation].getName - ctx.addMutableState(clsName, relationTerm, s"$relationTerm = ($clsName) $broadcast.value();") - s""" | ${streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this)} """.stripMargin } override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { - ctx.currentVars = input - val keyExpr = streamedKeys.map(BindReferences.bindReference(_, streamedPlan.output)) - val keyVal = GenerateUnsafeProjection.createCode(ctx, keyExpr) - val keyTerm = keyVal.value - val anyNull = if (keyExpr.exists(_.nullable)) s"$keyTerm.anyNull()" else "false" + // create a name for HashRelation + val relationTerm = ctx.addReferenceObj("hashRelation", broadcastRelation.value) - val matches = ctx.freshName("matches") - val bufferType = classOf[CompactBuffer[UnsafeRow]].getName - val i = ctx.freshName("i") - val size = ctx.freshName("size") - val row = ctx.freshName("row") + ctx.currentVars = input + // TODO: filter out null from stream + val (keyVal, anyNull) = if (canJoinKeyFitWithinLong) { + val expr = rewriteKeyExpr(streamedKeys).head + val ev = BindReferences.bindReference(expr, streamedPlan.output).gen(ctx) + (ev, ev.isNull) + } else { + val keyExpr = streamedKeys.map(BindReferences.bindReference(_, streamedPlan.output)) + val ev = GenerateUnsafeProjection.createCode(ctx, keyExpr) + (ev, s"${ev.value}.anyNull()") + } + val matched = ctx.freshName("matched") ctx.currentVars = null - ctx.INPUT_ROW = row + ctx.INPUT_ROW = matched val buildColumns = buildPlan.output.zipWithIndex.map { case (a, i) => BoundReference(i, a.dataType, a.nullable).gen(ctx) } @@ -168,7 +168,7 @@ case class BroadcastHashJoin( case BuildRight => input ++ buildColumns } - val ouputCode = if (condition.isDefined) { + val outputCode = if (condition.isDefined) { ctx.currentVars = resultVars val ev = BindReferences.bindReference(condition.get, this.output) .gen(ctx) @@ -182,20 +182,40 @@ case class BroadcastHashJoin( consume(ctx, resultVars) } - s""" - | // generate join key - | ${keyVal.code} - | // find matches from HashRelation - | $bufferType $matches = $anyNull ? 
null : ($bufferType) $relationTerm.get($keyTerm); - | if ($matches != null) { - | int $size = $matches.size(); - | for (int $i = 0; $i < $size; $i++) { - | UnsafeRow $row = (UnsafeRow) $matches.apply($i); - | ${buildColumns.map(_.code).mkString("\n")} - | $ouputCode - | } - | } + if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) { + s""" + | // generate join key + | ${keyVal.code} + | // find matches from HashRelation + | UnsafeRow $matched = $anyNull ? null : + | (UnsafeRow) $relationTerm.getValue(${keyVal.value}); + | if ($matched != null) { + | ${buildColumns.map(_.code).mkString("\n")} + | $outputCode + | } + """.stripMargin + + } else { + val matches = ctx.freshName("matches") + val bufferType = classOf[CompactBuffer[UnsafeRow]].getName + val i = ctx.freshName("i") + val size = ctx.freshName("size") + s""" + | // generate join key + | ${keyVal.code} + | // find matches from HashRelation + | $bufferType $matches = ${anyNull} ? null : + | ($bufferType) $relationTerm.get(${keyVal.value}); + | if ($matches != null) { + | int $size = $matches.size(); + | for (int $i = 0; $i < $size; $i++) { + | UnsafeRow $matched = (UnsafeRow) $matches.apply($i); + | ${buildColumns.map(_.code).mkString("\n")} + | $outputCode + | } + | } """.stripMargin + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 8ef854001f4de..923dd0803594c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.metric.LongSQLMetric +import org.apache.spark.sql.types.{LongType, IntegralType} trait HashJoin { @@ -47,11 +48,41 @@ trait HashJoin { override def output: Seq[Attribute] = left.output ++ right.output + /** + * Rewrite the key as LongType so we can use getLong(), if they key can fit with a long. 
+ */ + def rewriteKeyExpr(keys: Seq[Expression]): Seq[Expression] = { + var keyExpr: Expression = null + var width = 0 + keys.foreach { e => + e.dataType match { + case dt: IntegralType if dt.defaultSize <= 8 - width => + if (width == 0) { + keyExpr = Cast(e, LongType) + width = dt.defaultSize + } else { + val bits = dt.defaultSize * 8 + keyExpr = BitwiseOr(ShiftLeft(keyExpr, Literal(bits)), + BitwiseAnd(Cast(e, LongType), Literal((1 << bits) - 1))) + width -= bits + } + case other => + return keys + } + } + keyExpr :: Nil + } + + protected val canJoinKeyFitWithinLong: Boolean = { + val key = rewriteKeyExpr(buildKeys) + key.length == 1 && key.head.dataType.isInstanceOf[LongType] + } + protected def buildSideKeyGenerator: Projection = - UnsafeProjection.create(buildKeys, buildPlan.output) + UnsafeProjection.create(rewriteKeyExpr(buildKeys), buildPlan.output) protected def streamSideKeyGenerator: Projection = - UnsafeProjection.create(streamedKeys, streamedPlan.output) + UnsafeProjection.create(rewriteKeyExpr(streamedKeys), streamedPlan.output) @transient private[this] lazy val boundCondition = if (condition.isDefined) { newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index ee7a1bdc343c0..6aecab9d19dfe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -40,6 +40,9 @@ import org.apache.spark.util.collection.CompactBuffer */ private[execution] sealed trait HashedRelation { def get(key: InternalRow): Seq[InternalRow] + def get(key: Long): Seq[InternalRow] = { + throw new UnsupportedOperationException + } // This is a helper method to implement Externalizable, and is used by // GeneralHashedRelation and UniqueKeyHashedRelation @@ -58,11 +61,42 @@ private[execution] sealed trait HashedRelation { } } +private[execution] trait UniqueHashedRelation extends HashedRelation { + + def getValue(key: InternalRow): InternalRow + def getValue(key: Long): InternalRow = { + throw new UnsupportedOperationException + } + + /** + * The buffer re-used in get(). + */ + private val buffer = CompactBuffer[InternalRow](null) + + override def get(key: InternalRow): Seq[InternalRow] = { + val row = getValue(key) + if (row != null) { + buffer.update(0, row) + buffer + } else { + null + } + } + override def get(key: Long): Seq[InternalRow] = { + val row = getValue(key) + if (row != null) { + buffer.update(0, row) + buffer + } else { + null + } + } +} /** * A general [[HashedRelation]] backed by a hash map that maps the key into a sequence of values. */ -private[joins] final class GeneralHashedRelation( +private[joins] class GeneralHashedRelation( private var hashTable: JavaHashMap[InternalRow, CompactBuffer[InternalRow]]) extends HashedRelation with Externalizable { @@ -85,19 +119,14 @@ private[joins] final class GeneralHashedRelation( * A specialized [[HashedRelation]] that maps key into a single value. This implementation * assumes the key is unique. 
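Rewriting the keys this way packs several narrow integral columns into one long, each in its own bit range, so the relation can be probed with a single getLong instead of a composite row key. The same packing for two Int keys, as a standalone Scala example rather than the Expression tree built above:

    // Pack two 32-bit keys into one Long: high 32 bits hold a, low 32 bits hold b.
    def pack(a: Int, b: Int): Long =
      (a.toLong << 32) | (b.toLong & 0xFFFFFFFFL)   // the mask keeps b's bits unsigned

    def unpack(key: Long): (Int, Int) =
      ((key >>> 32).toInt, key.toInt)

    // pack(1, -1) == 0x1FFFFFFFFL and unpack(pack(1, -1)) == (1, -1).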
*/ -private[joins] -final class UniqueKeyHashedRelation(private var hashTable: JavaHashMap[InternalRow, InternalRow]) - extends HashedRelation with Externalizable { +private[joins] class UniqueKeyHashedRelation( + private var hashTable: JavaHashMap[InternalRow, InternalRow]) + extends UniqueHashedRelation with Externalizable { // Needed for serialization (it is public to make Java serialization work) def this() = this(null) - override def get(key: InternalRow): Seq[InternalRow] = { - val v = hashTable.get(key) - if (v eq null) null else CompactBuffer(v) - } - - def getValue(key: InternalRow): InternalRow = hashTable.get(key) + override def getValue(key: InternalRow): InternalRow = hashTable.get(key) override def writeExternal(out: ObjectOutput): Unit = { writeBytes(out, SparkSqlSerializer.serialize(hashTable)) @@ -411,3 +440,143 @@ private[joins] object UnsafeHashedRelation { new UnsafeHashedRelation(hashTable) } } + +private[joins] trait LongHashedRelation extends HashedRelation { + override def get(key: InternalRow): Seq[InternalRow] = { + get(key.getLong(0)) + } +} + +private[joins] final class GeneralLongHashedRelation( + private var hashTable: JavaHashMap[Long, CompactBuffer[InternalRow]]) + extends LongHashedRelation with Externalizable { + + // Needed for serialization (it is public to make Java serialization work) + def this() = this(null) + + override def get(key: Long): Seq[InternalRow] = hashTable.get(key) + + override def writeExternal(out: ObjectOutput): Unit = { + writeBytes(out, SparkSqlSerializer.serialize(hashTable)) + } + + override def readExternal(in: ObjectInput): Unit = { + hashTable = SparkSqlSerializer.deserialize(readBytes(in)) + } +} + +private[joins] final class UniqueLongHashedRelation( + private var hashTable: JavaHashMap[Long, InternalRow]) + extends UniqueHashedRelation with LongHashedRelation with Externalizable { + + // Needed for serialization (it is public to make Java serialization work) + def this() = this(null) + + override def getValue(key: InternalRow): InternalRow = { + getValue(key.getLong(0)) + } + + override def getValue(key: Long): InternalRow = { + hashTable.get(key) + } + + override def writeExternal(out: ObjectOutput): Unit = { + writeBytes(out, SparkSqlSerializer.serialize(hashTable)) + } + + override def readExternal(in: ObjectInput): Unit = { + hashTable = SparkSqlSerializer.deserialize(readBytes(in)) + } +} + +private[joins] final class LongArrayRelation( + private var start: Long, + private var array: Array[InternalRow]) + extends UniqueHashedRelation with LongHashedRelation with Externalizable { + + // Needed for serialization (it is public to make Java serialization work) + def this() = this(0L, null) + + override def getValue(key: InternalRow): InternalRow = { + getValue(key.getLong(0)) + } + + override def getValue(key: Long): InternalRow = { + val idx = key - start + if (idx >= 0 && idx < array.length) { + array(idx.toInt) + } else { + null + } + } + + override def writeExternal(out: ObjectOutput): Unit = { + out.writeLong(start) + writeBytes(out, SparkSqlSerializer.serialize(array)) + } + + override def readExternal(in: ObjectInput): Unit = { + start = in.readLong() + array = SparkSqlSerializer.deserialize(readBytes(in)) + } +} + +private[joins] object LongHashedRelation { + def apply( + input: Iterator[InternalRow], + numInputRows: LongSQLMetric, + keyGenerator: Projection, + sizeEstimate: Int): HashedRelation = { + + // Use a Java hash table here because unsafe maps expect fixed size records + val hashTable = new 
JavaHashMap[Long, CompactBuffer[InternalRow]](sizeEstimate) + + // Create a mapping of buildKeys -> rows + var keyIsUnique = true + var minKey = Long.MaxValue + var maxKey = Long.MinValue + while (input.hasNext) { + val unsafeRow = input.next().asInstanceOf[UnsafeRow] + numInputRows += 1 + val rowKey = keyGenerator(unsafeRow) + if (!rowKey.anyNull) { + val key = rowKey.getLong(0) + minKey = math.min(minKey, key) + maxKey = math.max(maxKey, key) + val existingMatchList = hashTable.get(key) + val matchList = if (existingMatchList == null) { + val newMatchList = new CompactBuffer[InternalRow]() + hashTable.put(key, newMatchList) + newMatchList + } else { + keyIsUnique = false + existingMatchList + } + matchList += unsafeRow.copy() + } + } + + if (keyIsUnique) { + if (maxKey - minKey <= hashTable.size() * 5) { + // The keys are dense enough, so use LongArrayRelation + val array = new Array[InternalRow]((maxKey - minKey).toInt + 1) + val iter = hashTable.entrySet().iterator() + while (iter.hasNext) { + val entry = iter.next() + array(entry.getKey.toInt - minKey.toInt) = entry.getValue()(0) + } + new LongArrayRelation(minKey, array) + } else { + val uniqHashTable = new JavaHashMap[Long, InternalRow](hashTable.size) + val iter = hashTable.entrySet().iterator() + while (iter.hasNext) { + val entry = iter.next() + uniqHashTable.put(entry.getKey, entry.getValue()(0)) + } + new UniqueLongHashedRelation(uniqHashTable) + } + } else { + new GeneralLongHashedRelation(hashTable) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala index 0937c5a1c92ab..a430adacdcb8c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -25,6 +25,7 @@ import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.hash.Murmur3_x86_32 import org.apache.spark.unsafe.map.BytesToBytesMap import org.apache.spark.util.Benchmark +import org.apache.spark.sql.functions._ /** * Benchmark to measure whole stage codegen performance. 
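The build side picks its representation from the key statistics: when every key is unique and the range is dense (span at most five times the number of entries), an array indexed by key - minKey answers lookups with a bounds check plus an array read; otherwise a map is kept. A compact Scala sketch of that choice, with a toy value type and assuming a non-empty set of unique keys:

    // Array-backed relation for dense unique keys, map-backed otherwise.
    sealed trait ToyRelation { def get(key: Long): Option[String] }

    final class ArrayRelation(start: Long, values: Array[String]) extends ToyRelation {
      def get(key: Long): Option[String] = {
        val idx = key - start
        if (idx >= 0 && idx < values.length) Option(values(idx.toInt)) else None
      }
    }

    final class MapRelation(values: Map[Long, String]) extends ToyRelation {
      def get(key: Long): Option[String] = values.get(key)
    }

    def build(entries: Map[Long, String]): ToyRelation = {
      val minKey = entries.keys.min
      val maxKey = entries.keys.max
      if (maxKey - minKey <= entries.size * 5L) {         // dense enough for an array
        val array = new Array[String]((maxKey - minKey).toInt + 1)
        entries.foreach { case (k, v) => array((k - minKey).toInt) = v }
        new ArrayRelation(minKey, array)
      } else {
        new MapRelation(entries)
      }
    }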
@@ -81,6 +82,30 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { benchmark.run() } + def testBroadcastHashJoin(values: Int): Unit = { + val benchmark = new Benchmark("BroadcastHashJoin", values) + + val dim = broadcast(sqlContext.range(1 << 16).selectExpr("id as k", "cast(id as string) as v")) + + benchmark.addCase("BroadcastHashJoin w/o codegen") { iter => + sqlContext.setConf("spark.sql.codegen.wholeStage", "false") + sqlContext.range(values).join(dim, (col("id") % 60000) === col("k")).count() + } + benchmark.addCase(s"BroadcastHashJoin w codegen") { iter => + sqlContext.setConf("spark.sql.codegen.wholeStage", "true") + sqlContext.range(values).join(dim, (col("id") % 60000) === col("k")).count() + } + + /* + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + Aggregate with keys: Avg Time(ms) Avg Rate(M/s) Relative Rate + ------------------------------------------------------------------------------- + Aggregate w/o codegen 13071.57 4.01 1.00 X + Aggregate w codegen 5072.56 10.34 2.58 X + */ + benchmark.run() + } + def testBytesToBytesMap(values: Int): Unit = { val benchmark = new Benchmark("BytesToBytesMap", values) @@ -150,6 +175,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { test("benchmark") { // testWholeStage(1024 * 1024 * 200) // testAggregateWithKey(1024 * 1024 * 50) + testBroadcastHashJoin(1024 * 1024 * 50) + // testBytesToBytesMap(1024 * 1024 * 50) } } From 081a04dbb078f9afa36d0a4eac839658cc48e954 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Sat, 23 Jan 2016 00:13:31 -0800 Subject: [PATCH 07/18] serialize LongArrayRelation --- .../sql/execution/joins/HashedRelation.scala | 47 +++++++++++++++---- 1 file changed, 38 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 6aecab9d19dfe..73f99a9086109 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -448,7 +448,7 @@ private[joins] trait LongHashedRelation extends HashedRelation { } private[joins] final class GeneralLongHashedRelation( - private var hashTable: JavaHashMap[Long, CompactBuffer[InternalRow]]) + private var hashTable: JavaHashMap[Long, CompactBuffer[UnsafeRow]]) extends LongHashedRelation with Externalizable { // Needed for serialization (it is public to make Java serialization work) @@ -490,12 +490,13 @@ private[joins] final class UniqueLongHashedRelation( } private[joins] final class LongArrayRelation( + private var numFields: Int, private var start: Long, - private var array: Array[InternalRow]) + private var array: Array[UnsafeRow]) extends UniqueHashedRelation with LongHashedRelation with Externalizable { // Needed for serialization (it is public to make Java serialization work) - def this() = this(0L, null) + def this() = this(0, 0L, null) override def getValue(key: InternalRow): InternalRow = { getValue(key.getLong(0)) @@ -511,13 +512,39 @@ private[joins] final class LongArrayRelation( } override def writeExternal(out: ObjectOutput): Unit = { + out.writeInt(numFields) out.writeLong(start) - writeBytes(out, SparkSqlSerializer.serialize(array)) + out.writeInt(array.length) + var i = 0 + while (i < array.length) { + if (array(i) != null) { + writeBytes(out, array(i).getBytes) + } else { + out.writeInt(0) + } + i += 1 + } } override def readExternal(in: ObjectInput): Unit = { + numFields = in.readInt() 
start = in.readLong() - array = SparkSqlSerializer.deserialize(readBytes(in)) + val length = in.readInt() + array = new Array[UnsafeRow](length) + var i = 0 + while (i < length) { + val len = in.readInt() + if (len != 0) { + val bytes = new Array[Byte](len) + in.readFully(bytes) + val row = new UnsafeRow(numFields) + row.pointTo(bytes, len) + array(i) = row + } else { + array(i) = null + } + i += 1 + } } } @@ -529,14 +556,16 @@ private[joins] object LongHashedRelation { sizeEstimate: Int): HashedRelation = { // Use a Java hash table here because unsafe maps expect fixed size records - val hashTable = new JavaHashMap[Long, CompactBuffer[InternalRow]](sizeEstimate) + val hashTable = new JavaHashMap[Long, CompactBuffer[UnsafeRow]](sizeEstimate) // Create a mapping of buildKeys -> rows + var numFields = 0 var keyIsUnique = true var minKey = Long.MaxValue var maxKey = Long.MinValue while (input.hasNext) { val unsafeRow = input.next().asInstanceOf[UnsafeRow] + numFields = unsafeRow.numFields() numInputRows += 1 val rowKey = keyGenerator(unsafeRow) if (!rowKey.anyNull) { @@ -545,7 +574,7 @@ private[joins] object LongHashedRelation { maxKey = math.max(maxKey, key) val existingMatchList = hashTable.get(key) val matchList = if (existingMatchList == null) { - val newMatchList = new CompactBuffer[InternalRow]() + val newMatchList = new CompactBuffer[UnsafeRow]() hashTable.put(key, newMatchList) newMatchList } else { @@ -559,13 +588,13 @@ private[joins] object LongHashedRelation { if (keyIsUnique) { if (maxKey - minKey <= hashTable.size() * 5) { // The keys are dense enough, so use LongArrayRelation - val array = new Array[InternalRow]((maxKey - minKey).toInt + 1) + val array = new Array[UnsafeRow]((maxKey - minKey).toInt + 1) val iter = hashTable.entrySet().iterator() while (iter.hasNext) { val entry = iter.next() array(entry.getKey.toInt - minKey.toInt) = entry.getValue()(0) } - new LongArrayRelation(minKey, array) + new LongArrayRelation(numFields, minKey, array) } else { val uniqHashTable = new JavaHashMap[Long, InternalRow](hashTable.size) val iter = hashTable.entrySet().iterator() From 77ba890d896a69e8a84c035bf3dc07947a999cc2 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 25 Jan 2016 15:58:15 -0800 Subject: [PATCH 08/18] fix planner for BroadcastHashJoin --- .../spark/sql/execution/WholeStageCodegen.scala | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index 8fa43b6acd55f..cbd43879c066f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Expression, LeafExpression} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.joins.{BuildRight, BuildLeft, BroadcastHashJoin} /** * An interface for those physical operators that support codegen. 
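The array relation is shipped by writing each slot as a length-prefixed block of row bytes, with a length of zero marking an empty slot. The same scheme over java.io streams, with plain byte arrays standing in for the UnsafeRow payloads:

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

    // Length-prefixed encoding: for each slot write the byte count, then the bytes;
    // a count of zero marks an empty slot.
    def writeSlots(out: DataOutputStream, slots: Array[Array[Byte]]): Unit = {
      out.writeInt(slots.length)
      slots.foreach { bytes =>
        if (bytes == null) out.writeInt(0)
        else { out.writeInt(bytes.length); out.write(bytes) }
      }
    }

    def readSlots(in: DataInputStream): Array[Array[Byte]] = {
      val slots = new Array[Array[Byte]](in.readInt())
      for (i <- slots.indices) {
        val len = in.readInt()
        if (len > 0) {
          val bytes = new Array[Byte](len)
          in.readFully(bytes)
          slots(i) = bytes
        }
      }
      slots
    }

    // Round trip:
    //   val buf = new ByteArrayOutputStream()
    //   writeSlots(new DataOutputStream(buf), Array(Array[Byte](1, 2), null))
    //   readSlots(new DataInputStream(new ByteArrayInputStream(buf.toByteArray)))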
@@ -134,7 +135,7 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport { override def output: Seq[Attribute] = child.output - override def supportCodegen: Boolean = true + override def supportCodegen: Boolean = false override def upstream(): RDD[InternalRow] = { child.execute() @@ -340,6 +341,15 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru var inputs = ArrayBuffer[SparkPlan]() val combined = plan.transform { + // The build side can't be compiled together + case b @ BroadcastHashJoin(_, _, BuildLeft, _, left, right) => + val input = apply(left) + inputs += input + b.copy(left = input) + case b @ BroadcastHashJoin(_, _, BuildRight, _, left, right) => + val input = apply(right) + inputs += input + b.copy(right = input) case p if !supportCodegen(p) => val input = apply(p) // collapse them recursively inputs += input From 9a42b522cb483a3502ee36cea6672c36f2e40b46 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 25 Jan 2016 18:15:59 -0800 Subject: [PATCH 09/18] address comments --- .../sql/execution/WholeStageCodegen.scala | 54 ++-- .../aggregate/TungstenAggregate.scala | 249 ++++++++++-------- .../spark/sql/execution/basicOperators.scala | 90 ++++--- 3 files changed, 216 insertions(+), 177 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index 3bb560fa07d7c..bcb0051db206f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Expression, LeafExpression} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.util.Utils /** * An interface for those physical operators that support codegen. 
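The rule fuses runs of codegen-capable operators into a single stage and deliberately restarts below boundaries such as the broadcast join's build side, which is materialized on its own. A toy version of that collapse over a simplified plan tree, with illustrative node names rather than the real rule:

    // Wrap each maximal run of codegen-capable operators in a Fused node and
    // restart the collapse below any boundary (here, a Shuffle).
    object ToyCollapse {
      sealed trait Plan { def codegen: Boolean }
      case class Scan() extends Plan { val codegen = true }
      case class Select(child: Plan) extends Plan { val codegen = true }
      case class Shuffle(child: Plan) extends Plan { val codegen = false }
      case class Fused(child: Plan) extends Plan { val codegen = false }

      def collapse(plan: Plan): Plan = plan match {
        case p if p.codegen => Fused(fuse(p))            // start a new fused stage here
        case Shuffle(child) => Shuffle(collapse(child))  // boundary: child gets its own stage
        case other          => other
      }

      // Keep fusing while children support codegen; collapse the rest separately.
      private def fuse(plan: Plan): Plan = plan match {
        case Select(child) if child.codegen => Select(fuse(child))
        case Select(child)                  => Select(collapse(child))
        case leaf                           => leaf
      }
    }

    // import ToyCollapse._
    // collapse(Select(Shuffle(Select(Scan())))) == Fused(Select(Shuffle(Fused(Select(Scan())))))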
@@ -197,33 +198,36 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) override def output: Seq[Attribute] = plan.output override def doExecute(): RDD[InternalRow] = { + val ctx = new CodegenContext + val code = plan.produce(ctx, this) + val references = ctx.references.toArray + val source = s""" + public Object generate(Object[] references) { + return new GeneratedIterator(references); + } - plan.upstream().mapPartitions { iter => - val ctx = new CodegenContext - val code = plan.produce(ctx, this) - val references = ctx.references.toArray - val source = s""" - public Object generate(Object[] references) { - return new GeneratedIterator(references); - } - - class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { + class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { - private Object[] references; - ${ctx.declareMutableStates()} + private Object[] references; + ${ctx.declareMutableStates()} + ${ctx.declareAddedFunctions()} - public GeneratedIterator(Object[] references) { - this.references = references; - ${ctx.initMutableStates()} - } + public GeneratedIterator(Object[] references) { + this.references = references; + ${ctx.initMutableStates()} + } - protected void processNext() throws java.io.IOException { - $code - } + protected void processNext() throws java.io.IOException { + $code } - """ - // try to compile, helpful for debug - // println(s"${CodeFormatter.format(source)}") + } + """ + + // try to compile, helpful for debug + // println(s"${CodeFormatter.format(source)}") + CodeGenerator.compile(source) + + plan.upstream().mapPartitions { iter => val clazz = CodeGenerator.compile(source) val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator] @@ -268,13 +272,13 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) | ${code.code.trim} | currentRow = ${code.value}; | return; - """.stripMargin + """.stripMargin } else { // There is no columns s""" | currentRow = unsafeRow; | return; - """.stripMargin + """.stripMargin } } } @@ -336,7 +340,7 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru case plan: CodegenSupport if supportCodegen(plan) && // Whole stage codegen is only useful when there are at least two levels of operators that // support it (save at least one projection/iterator). 
- plan.children.exists(supportCodegen) => + (Utils.isTesting || plan.children.exists(supportCodegen)) => var inputs = ArrayBuffer[SparkPlan]() val combined = plan.transform { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index d1bcfe5b50f5f..882206b9d75a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -157,17 +157,20 @@ case class TungstenAggregate( bufVars = initExpr.map { e => val isNull = ctx.freshName("bufIsNull") val value = ctx.freshName("bufValue") + ctx.addMutableState("boolean", isNull, "") + ctx.addMutableState(ctx.javaType(e.dataType), value, "") // The initial expression should not access any column val ev = e.gen(ctx) - val initVars = s""" - | boolean $isNull = ${ev.isNull}; - | ${ctx.javaType(e.dataType)} $value = ${ev.value}; - """.stripMargin + val initVars = + s""" + $isNull = ${ev.isNull}; + $value = ${ev.value}; + """ ExprCode(ev.code + initVars, isNull, value) } // generate variables for output - val (resultVars, genResult) = if (modes.contains(Final) || modes.contains(Complete)) { + val (resultVars, genResult) = if (modes.contains(Final) |modes.contains(Complete)) { // evaluate aggregate results ctx.currentVars = bufVars val bufferAttrs = functions.flatMap(_.aggBufferAttributes) @@ -180,29 +183,36 @@ case class TungstenAggregate( BindReferences.bindReference(e, aggregateAttributes).gen(ctx) } (resultVars, s""" - | ${aggResults.map(_.code).mkString("\n")} - | ${resultVars.map(_.code).mkString("\n")} - """.stripMargin) + ${aggResults.map(_.code).mkString("\n")} + ${resultVars.map(_.code).mkString("\n")} + """) } else { // output the aggregate buffer directly (bufVars, "") } + val doAgg = ctx.freshName("doAgg") + ctx.addNewFunction(doAgg, + s""" + private void $doAgg() throws java.io.IOException { + // initialize aggregation buffer + ${bufVars.map(_.code).mkString("\n")} + + ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} + } + """) + s""" - | if (!$initAgg) { - | $initAgg = true; - | - | // initialize aggregation buffer - | ${bufVars.map(_.code).mkString("\n")} - | - | ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} - | - | // output the result - | $genResult - | - | ${consume(ctx, resultVars)} - | } - """.stripMargin + if (!$initAgg) { + $initAgg = true; + $doAgg(); + + // output the result + $genResult + + ${consume(ctx, resultVars)} + } + """ } private def doConsumeWithoutKeys( @@ -225,36 +235,36 @@ case class TungstenAggregate( val updates = updateExpr.zipWithIndex.map { case (e, i) => val ev = BindReferences.bindReference[Expression](e, inputAttrs).gen(ctx) s""" - | ${ev.code} - | ${bufVars(i).isNull} = ${ev.isNull}; - | ${bufVars(i).value} = ${ev.value}; - """.stripMargin + ${ev.code} + ${bufVars(i).isNull} = ${ev.isNull}; + ${bufVars(i).value} = ${ev.value}; + """ } s""" - | // do aggregate and update aggregation buffer - | ${updates.mkString("")} - """.stripMargin + // do aggregate and update aggregation buffer + ${updates.mkString("")} + """ } + val groupingAttributes = groupingExpressions.map(_.toAttribute) + val groupingKeySchema = StructType.fromAttributes(groupingAttributes) + val declFunctions = aggregateExpressions.map(_.aggregateFunction) + .filter(_.isInstanceOf[DeclarativeAggregate]) + .map(_.asInstanceOf[DeclarativeAggregate]) + val 
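For aggregation without grouping keys the buffer now lives in mutable fields of the generated class, one isNull/value pair per slot, so it survives the whole input scan before the single result row is produced. A hand-written analogue for a lone nullable SUM column, as a hypothetical sketch:

    // Buffer slots kept as fields, mirroring the generated class state for SUM(col)
    // with no grouping keys; nulls in the input do not contribute.
    class ToySumNoKeys(input: Iterator[java.lang.Long]) {
      private var bufIsNull: Boolean = true
      private var bufValue: Long = 0L

      def result(): Option[Long] = {
        while (input.hasNext) {
          val v = input.next()
          if (v != null) {
            bufValue = (if (bufIsNull) 0L else bufValue) + v.longValue
            bufIsNull = false
          }
        }
        if (bufIsNull) None else Some(bufValue)
      }
    }

    // new ToySumNoKeys(Iterator[java.lang.Long](1L, null, 2L)).result() == Some(3L)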
bufferAttributes = declFunctions.flatMap(_.aggBufferAttributes) + val bufferSchema = StructType.fromAttributes(bufferAttributes) + // The name for HashMap var hashMapTerm: String = _ - private def doProduceWithKeys(ctx: CodegenContext): String = { - val initAgg = ctx.freshName("initAgg") - ctx.addMutableState("boolean", initAgg, s"$initAgg = false;") - + def createHashMap(): UnsafeFixedWidthAggregationMap = { // create initialized aggregate buffer - val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) - val initExpr = functions.flatMap(f => f.initialValues) + val initExpr = declFunctions.flatMap(f => f.initialValues) val initialBuffer = UnsafeProjection.create(initExpr)(EmptyRow) // create hashMap - val groupingAttributes = groupingExpressions.map(_.toAttribute) - val groupingKeySchema = StructType.fromAttributes(groupingAttributes) - val bufferAttributes = functions.flatMap(_.aggBufferAttributes) - val bufferSchema = StructType.fromAttributes(bufferAttributes) - val hashMap = new UnsafeFixedWidthAggregationMap( + new UnsafeFixedWidthAggregationMap( initialBuffer, bufferSchema, groupingKeySchema, @@ -263,7 +273,21 @@ case class TungstenAggregate( TaskContext.get().taskMemoryManager().pageSizeBytes, false // disable tracking of performance metrics ) - hashMapTerm = ctx.addReferenceObj("hashMap", hashMap) + } + + def createUnsafeJoiner(): UnsafeRowJoiner = { + GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema) + } + + private def doProduceWithKeys(ctx: CodegenContext): String = { + val initAgg = ctx.freshName("initAgg") + ctx.addMutableState("boolean", initAgg, s"$initAgg = false;") + + // create hashMap + val thisPlan = ctx.addReferenceObj("tungstenAggregate", this) + hashMapTerm = ctx.freshName("hashMap") + val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName + ctx.addMutableState(hashMapClassName, hashMapTerm, s"$hashMapTerm = $thisPlan.createHashMap();") // Create a name for iterator from HashMap val iterTerm = ctx.freshName("mapIter") @@ -272,7 +296,7 @@ case class TungstenAggregate( // generate code for output val keyTerm = ctx.freshName("aggKey") val bufferTerm = ctx.freshName("aggBuffer") - val outputCode = if (modes.contains(Final) || modes.contains(Complete)) { + val outputCode = if (modes.contains(Final) |modes.contains(Complete)) { // generate output using resultExpressions ctx.currentVars = null ctx.INPUT_ROW = keyTerm @@ -285,7 +309,7 @@ case class TungstenAggregate( } // evaluate the aggregation result ctx.currentVars = bufferVars - val aggResults = functions.map(_.evaluateExpression).map { e => + val aggResults = declFunctions.map(_.evaluateExpression).map { e => BindReferences.bindReference(e, bufferAttributes).gen(ctx) } // generate the final result @@ -295,24 +319,24 @@ case class TungstenAggregate( BindReferences.bindReference(e, inputAttrs).gen(ctx) } s""" - | ${keyVars.map(_.code).mkString("\n")} - | ${bufferVars.map(_.code).mkString("\n")} - | ${aggResults.map(_.code).mkString("\n")} - | ${resultVars.map(_.code).mkString("\n")} - | - | ${consume(ctx, resultVars)} - """.stripMargin - - } else if (modes.contains(Partial) || modes.contains(PartialMerge)) { + ${keyVars.map(_.code).mkString("\n")} + ${bufferVars.map(_.code).mkString("\n")} + ${aggResults.map(_.code).mkString("\n")} + ${resultVars.map(_.code).mkString("\n")} + + ${consume(ctx, resultVars)} + """ + + } else if (modes.contains(Partial) |modes.contains(PartialMerge)) { // This should be the last operator in a stage, we should output 
UnsafeRow directly - val unsafeRowJoiner = GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema) - val joinerTerm = ctx.addReferenceObj("unsafeRowJoiner", unsafeRowJoiner, - classOf[UnsafeRowJoiner].getName) + val joinerTerm = ctx.freshName("unsafeRowJoiner") + ctx.addMutableState(classOf[UnsafeRowJoiner].getName, joinerTerm, + s"$joinerTerm = $thisPlan.createUnsafeJoiner();") val resultRow = ctx.freshName("resultRow") s""" - | UnsafeRow $resultRow = $joinerTerm.join($keyTerm, $bufferTerm); - | ${consume(ctx, null, resultRow)} - """.stripMargin + UnsafeRow $resultRow = $joinerTerm.join($keyTerm, $bufferTerm); + ${consume(ctx, null, resultRow)} + """ } else { // only grouping key @@ -322,29 +346,36 @@ case class TungstenAggregate( BindReferences.bindReference(e, groupingAttributes).gen(ctx) } s""" - | ${eval.map(_.code).mkString("\n")} - | ${consume(ctx, eval)} - """.stripMargin + ${eval.map(_.code).mkString("\n")} + ${consume(ctx, eval)} + """ } + val doAgg = ctx.freshName("doAgg") + ctx.addNewFunction(doAgg, + s""" + private void $doAgg() throws java.io.IOException { + ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} + + $iterTerm = $hashMapTerm.iterator(); + } + """) + s""" - | if (!$initAgg) { - | $initAgg = true; - | - | ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} - | - | $iterTerm = $hashMapTerm.iterator(); - | } - | - | // output the result - | while ($iterTerm.next()) { - | UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey(); - | UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue(); - | $outputCode - | } - | - | $hashMapTerm.free(); - """.stripMargin + if (!$initAgg) { + $initAgg = true; + $doAgg(); + } + + // output the result + while ($iterTerm.next()) { + UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey(); + UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue(); + $outputCode + } + + $hashMapTerm.free(); + """ } private def doConsumeWithKeys( @@ -360,7 +391,6 @@ case class TungstenAggregate( val buffer = ctx.freshName("aggBuffer") // only have DeclarativeAggregate - val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) val updateExpr = aggregateExpressions.flatMap { e => e.mode match { case Partial | Complete => @@ -370,10 +400,9 @@ case class TungstenAggregate( } } - val bufferAttrs = functions.flatMap(_.aggBufferAttributes) - val inputAttr = bufferAttrs ++ child.output + val inputAttr = bufferAttributes ++ child.output val boundExpr = updateExpr.map(e => BindReferences.bindReference(e, inputAttr)) - ctx.currentVars = new Array[ExprCode](bufferAttrs.length) ++ input + ctx.currentVars = new Array[ExprCode](bufferAttributes.length) ++ input ctx.INPUT_ROW = buffer // TODO: support subexpression elimination val evals = boundExpr.map(_.gen(ctx)) @@ -382,43 +411,43 @@ case class TungstenAggregate( if (updateExpr(i).nullable) { if (dt.isInstanceOf[DecimalType]) { s""" - | if (!${ev.isNull}) { - | ${ctx.setColumn(buffer, dt, i, ev.value)}; - | } else { - | ${ctx.setColumn(buffer, dt, i, "null")}; - | } - """.stripMargin + if (!${ev.isNull}) { + ${ctx.setColumn(buffer, dt, i, ev.value)}; + } else { + ${ctx.setColumn(buffer, dt, i, "null")}; + } + """ } else { s""" - | if (!${ev.isNull}) { - | ${ctx.setColumn(buffer, dt, i, ev.value)}; - | } else { - | $buffer.setNullAt($i); - | } - """.stripMargin + if (!${ev.isNull}) { + ${ctx.setColumn(buffer, dt, i, ev.value)}; + } else { + $buffer.setNullAt($i); + } + """ } } else { s""" - | ${ctx.setColumn(buffer, dt, i, ev.value)}; - """.stripMargin + 
${ctx.setColumn(buffer, dt, i, ev.value)}; + """ } } s""" - | // generate grouping key - | ${keyCode.code} - | UnsafeRow $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key); - | if ($buffer == null) { - | // failed to allocate the first page - | throw new OutOfMemoryError("No enough memory for aggregation"); - | } - | - | // evaluate aggregate function - | ${evals.map(_.code).mkString("\n")} - | - | // update aggregate buffer - | ${updates.mkString("\n")} - """.stripMargin + // generate grouping key + ${keyCode.code} + UnsafeRow $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key); + if ($buffer == null) { + // failed to allocate the first page + throw new OutOfMemoryError("No enough memory for aggregation"); + } + + // evaluate aggregate function + ${evals.map(_.code).mkString("\n")} + + // update aggregate buffer + ${updates.mkString("\n")} + """ } override def simpleString: String = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index f6a142872f72f..7a3787659e018 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -185,49 +185,55 @@ case class Range( s"$number > $partitionEnd" } + val initRange = ctx.freshName("initRange") + ctx.addNewFunction(initRange, + s""" + private void $initRange(InternalRow row) { + $BigInt index = $BigInt.valueOf(row.getInt(0)); + $BigInt numSlice = $BigInt.valueOf(${numSlices}L); + $BigInt numElement = $BigInt.valueOf(${numElements.toLong}L); + $BigInt step = $BigInt.valueOf(${step}L); + $BigInt start = $BigInt.valueOf(${start}L); + + $BigInt st = index.multiply(numElement).divide(numSlice).multiply(step).add(start); + if (st.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) { + $number = Long.MAX_VALUE; + } else if (st.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) { + $number = Long.MIN_VALUE; + } else { + $number = st.longValue(); + } + + $BigInt end = index.add($BigInt.ONE).multiply(numElement).divide(numSlice) + .multiply(step).add(start); + if (end.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) { + $partitionEnd = Long.MAX_VALUE; + } else if (end.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) { + $partitionEnd = Long.MIN_VALUE; + } else { + $partitionEnd = end.longValue(); + } + }""") + s""" - | // initialize Range - | if (!$initTerm) { - | $initTerm = true; - | if (input.hasNext()) { - | $BigInt index = $BigInt.valueOf(((InternalRow) input.next()).getInt(0)); - | $BigInt numSlice = $BigInt.valueOf(${numSlices}L); - | $BigInt numElement = $BigInt.valueOf(${numElements.toLong}L); - | $BigInt step = $BigInt.valueOf(${step}L); - | $BigInt start = $BigInt.valueOf(${start}L); - | - | $BigInt st = index.multiply(numElement).divide(numSlice).multiply(step).add(start); - | if (st.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) { - | $number = Long.MAX_VALUE; - | } else if (st.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) { - | $number = Long.MIN_VALUE; - | } else { - | $number = st.longValue(); - | } - | - | $BigInt end = index.add($BigInt.ONE).multiply(numElement).divide(numSlice) - | .multiply(step).add(start); - | if (end.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) { - | $partitionEnd = Long.MAX_VALUE; - | } else if (end.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) { - | $partitionEnd = Long.MIN_VALUE; - | } else { - | $partitionEnd = end.longValue(); - | } - | } else { - | return; - | } - 
| } - | - | while (!$overflow && $checkEnd) { - | long $value = $number; - | $number += ${step}L; - | if ($number < $value ^ ${step}L < 0) { - | $overflow = true; - | } - | ${consume(ctx, Seq(ev))} - | } - """.stripMargin + // initialize Range + if (!$initTerm) { + $initTerm = true; + if (input.hasNext()) { + $initRange((InternalRow) input.next()); + } else { + return; + } + } + + while (!$overflow && $checkEnd) { + long $value = $number; + $number += ${step}L; + if ($number < $value ^ ${step}L < 0) { + $overflow = true; + } + ${consume(ctx, Seq(ev))} + }""" } protected override def doExecute(): RDD[InternalRow] = { From 37bc7f0a34c6bb7d941cfa146af471d1f83ab04a Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 26 Jan 2016 22:57:10 -0800 Subject: [PATCH 10/18] fix tests --- .../sql/execution/WholeStageCodegen.scala | 19 +++- .../org/apache/spark/sql/SQLQuerySuite.scala | 96 +++++++++---------- .../execution/metric/SQLMetricsSuite.scala | 34 +++---- .../apache/spark/sql/test/SQLTestUtils.scala | 2 +- .../sql/util/DataFrameCallbackSuite.scala | 10 +- 5 files changed, 90 insertions(+), 71 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index bcb0051db206f..64241419940ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -22,8 +22,9 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Expression, LeafExpression} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.plans.physical.{UnknownPartitioning, Partitioning} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.util.Utils @@ -134,8 +135,14 @@ trait CodegenSupport extends SparkPlan { case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport { override def output: Seq[Attribute] = child.output + override def outputPartitioning: Partitioning = child.outputPartitioning + override def outputOrdering: Seq[SortOrder] = child.outputOrdering - override def supportCodegen: Boolean = true + override def doPrepare(): Unit = { + child.prepare() + } + + override def supportCodegen: Boolean = false override def upstream(): RDD[InternalRow] = { child.execute() @@ -195,7 +202,15 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport { case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) extends SparkPlan with CodegenSupport { + override def supportCodegen: Boolean = false + override def output: Seq[Attribute] = plan.output + override def outputPartitioning: Partitioning = plan.outputPartitioning + override def outputOrdering: Seq[SortOrder] = plan.outputOrdering + + override def doPrepare(): Unit = { + plan.prepare() + } override def doExecute(): RDD[InternalRow] = { val ctx = new CodegenContext diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index d2d271ac93d26..979a7b4922ea7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala 
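The per-partition bounds are computed in BigInteger so the slice arithmetic cannot overflow before being clamped back into the long range. The same computation with Scala's BigInt, as a standalone sketch:

    // Compute [first, endExclusive) for slice `index` out of `numSlices`, clamping
    // the BigInt results into the Long range, as the generated initRange does.
    def sliceBounds(
        index: Int, numSlices: Int, numElements: BigInt, start: Long, step: Long): (Long, Long) = {
      def clamp(x: BigInt): Long =
        if (x > Long.MaxValue) Long.MaxValue
        else if (x < Long.MinValue) Long.MinValue
        else x.toLong

      val first = BigInt(index) * numElements / numSlices * step + start
      val end   = BigInt(index + 1) * numElements / numSlices * step + start
      (clamp(first), clamp(end))
    }

    // sliceBounds(0, 4, BigInt(10), 0L, 1L) == (0L, 2L)   // first of four slices over range(10)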
@@ -1933,61 +1933,61 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("Common subexpression elimination") { - // select from a table to prevent constant folding. - val df = sql("SELECT a, b from testData2 limit 1") - checkAnswer(df, Row(1, 1)) - - checkAnswer(df.selectExpr("a + 1", "a + 1"), Row(2, 2)) - checkAnswer(df.selectExpr("a + 1", "a + 1 + 1"), Row(2, 3)) - - // This does not work because the expressions get grouped like (a + a) + 1 - checkAnswer(df.selectExpr("a + 1", "a + a + 1"), Row(2, 3)) - checkAnswer(df.selectExpr("a + 1", "a + (a + 1)"), Row(2, 3)) - - // Identity udf that tracks the number of times it is called. - val countAcc = sparkContext.accumulator(0, "CallCount") - sqlContext.udf.register("testUdf", (x: Int) => { - countAcc.++=(1) - x - }) - - // Evaluates df, verifying it is equal to the expectedResult and the accumulator's value - // is correct. - def verifyCallCount(df: DataFrame, expectedResult: Row, expectedCount: Int): Unit = { - countAcc.setValue(0) - checkAnswer(df, expectedResult) - assert(countAcc.value == expectedCount) - } + // TODO: support subexpression elimination in whole stage codegen + withSQLConf("spark.sql.codegen.wholeStage" -> "false") { + // select from a table to prevent constant folding. + val df = sql("SELECT a, b from testData2 limit 1") + checkAnswer(df, Row(1, 1)) + + checkAnswer(df.selectExpr("a + 1", "a + 1"), Row(2, 2)) + checkAnswer(df.selectExpr("a + 1", "a + 1 + 1"), Row(2, 3)) + + // This does not work because the expressions get grouped like (a + a) + 1 + checkAnswer(df.selectExpr("a + 1", "a + a + 1"), Row(2, 3)) + checkAnswer(df.selectExpr("a + 1", "a + (a + 1)"), Row(2, 3)) + + // Identity udf that tracks the number of times it is called. + val countAcc = sparkContext.accumulator(0, "CallCount") + sqlContext.udf.register("testUdf", (x: Int) => { + countAcc.++=(1) + x + }) + + // Evaluates df, verifying it is equal to the expectedResult and the accumulator's value + // is correct. 
+ def verifyCallCount(df: DataFrame, expectedResult: Row, expectedCount: Int): Unit = { + countAcc.setValue(0) + checkAnswer(df, expectedResult) + assert(countAcc.value == expectedCount) + } - verifyCallCount(df.selectExpr("testUdf(a)"), Row(1), 1) - verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) - verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a + 1)"), Row(2, 2), 1) - verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a)"), Row(2, 1), 2) - verifyCallCount( - df.selectExpr("testUdf(a + 1) + testUdf(a + 1)", "testUdf(a + 1)"), Row(4, 2), 1) + verifyCallCount(df.selectExpr("testUdf(a)"), Row(1), 1) + verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) + verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a + 1)"), Row(2, 2), 1) + verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a)"), Row(2, 1), 2) + verifyCallCount( + df.selectExpr("testUdf(a + 1) + testUdf(a + 1)", "testUdf(a + 1)"), Row(4, 2), 1) - verifyCallCount( - df.selectExpr("testUdf(a + 1) + testUdf(1 + b)", "testUdf(a + 1)"), Row(4, 2), 2) + verifyCallCount( + df.selectExpr("testUdf(a + 1) + testUdf(1 + b)", "testUdf(a + 1)"), Row(4, 2), 2) - val testUdf = functions.udf((x: Int) => { - countAcc.++=(1) - x - }) - // TODO: support subexpression elimination in whole stage codegen - withSQLConf("spark.sql.codegen.wholeStage" -> "false") { + val testUdf = functions.udf((x: Int) => { + countAcc.++=(1) + x + }) verifyCallCount( df.groupBy().agg(sum(testUdf($"b") + testUdf($"b") + testUdf($"b"))), Row(3.0), 1) - } - // Would be nice if semantic equals for `+` understood commutative - verifyCallCount( - df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 2) + // Would be nice if semantic equals for `+` understood commutative + verifyCallCount( + df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 2) - // Try disabling it via configuration. - sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "false") - verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 2) - sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "true") - verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) + // Try disabling it via configuration. 
+ sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "false") + verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 2) + sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "true") + verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) + } } test("SPARK-10707: nullability should be correctly propagated through set operations (1)") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 51285431a47ed..35e25cd5d7328 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -327,22 +327,24 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { test("save metrics") { withTempPath { file => - val previousExecutionIds = sqlContext.listener.executionIdToData.keySet - // Assume the execution plan is - // PhysicalRDD(nodeId = 0) - person.select('name).write.format("json").save(file.getAbsolutePath) - sparkContext.listenerBus.waitUntilEmpty(10000) - val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds) - assert(executionIds.size === 1) - val executionId = executionIds.head - val jobs = sqlContext.listener.getExecution(executionId).get.jobs - // Use "<=" because there is a race condition that we may miss some jobs - // TODO Change "<=" to "=" once we fix the race condition that missing the JobStarted event. - assert(jobs.size <= 1) - val metricValues = sqlContext.listener.getExecutionMetrics(executionId) - // Because "save" will create a new DataFrame internally, we cannot get the real metric id. - // However, we still can check the value. - assert(metricValues.values.toSeq === Seq("2")) + withSQLConf("spark.sql.codegen.wholeStage" -> "false") { + val previousExecutionIds = sqlContext.listener.executionIdToData.keySet + // Assume the execution plan is + // PhysicalRDD(nodeId = 0) + person.select('name).write.format("json").save(file.getAbsolutePath) + sparkContext.listenerBus.waitUntilEmpty(10000) + val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds) + assert(executionIds.size === 1) + val executionId = executionIds.head + val jobs = sqlContext.listener.getExecution(executionId).get.jobs + // Use "<=" because there is a race condition that we may miss some jobs + // TODO Change "<=" to "=" once we fix the race condition that missing the JobStarted event. + assert(jobs.size <= 1) + val metricValues = sqlContext.listener.getExecutionMetrics(executionId) + // Because "save" will create a new DataFrame internally, we cannot get the real metric id. + // However, we still can check the value. 
+        assert(metricValues.values.toSeq === Seq("2"))
+      }
     }
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
index 5f73d71d4510a..f05deb49543ee 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -186,7 +186,7 @@ private[sql] trait SQLTestUtils
     val schema = df.schema
     val childRDD = df
       .queryExecution
-      .executedPlan.asInstanceOf[org.apache.spark.sql.execution.Filter]
+      .sparkPlan.asInstanceOf[org.apache.spark.sql.execution.Filter]
       .child
       .execute()
       .map(row => Row.fromSeq(row.copy().toSeq(schema)))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
index b46b0d2f6040a..c8478cf8da13d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
@@ -97,10 +97,12 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
     }
     sqlContext.listenerManager.register(listener)
 
-    val df = Seq(1 -> "a").toDF("i", "j").groupBy("i").count()
-    df.collect()
-    df.collect()
-    Seq(1 -> "a", 2 -> "a").toDF("i", "j").groupBy("i").count().collect()
+    withSQLConf("spark.sql.codegen.wholeStage" -> "false") {
+      val df = Seq(1 -> "a").toDF("i", "j").groupBy("i").count()
+      df.collect()
+      df.collect()
+      Seq(1 -> "a", 2 -> "a").toDF("i", "j").groupBy("i").count().collect()
+    }
 
     assert(metrics.length == 3)
     assert(metrics(0) == 1)

From 3bfdeb2cacebe8567f2d0123c853a20de84fa158 Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Thu, 28 Jan 2016 15:32:34 -0800
Subject: [PATCH 11/18] address comment

---
 .../expressions/codegen/CodeGenerator.scala   | 33 ++++++++++++++
 .../codegen/GenerateMutableProjection.scala   | 27 +-----------
 .../sql/execution/WholeStageCodegen.scala     |  2 +-
 .../aggregate/TungstenAggregate.scala         | 44 ++++++-------------
 4 files changed, 50 insertions(+), 56 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 5f8c87c4f2fa2..21f9198073d74 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -212,6 +212,39 @@ class CodegenContext {
     }
   }
 
+  /**
+   * Update a column in MutableRow from ExprCode.
+   */
+  def updateColumn(
+      row: String,
+      dataType: DataType,
+      ordinal: Int,
+      ev: ExprCode,
+      nullable: Boolean): String = {
+    if (nullable) {
+      // Can't call setNullAt on DecimalType, because we need to keep the offset
+      if (dataType.isInstanceOf[DecimalType]) {
+        s"""
+           if (!${ev.isNull}) {
+             ${setColumn(row, dataType, ordinal, ev.value)};
+           } else {
+             ${setColumn(row, dataType, ordinal, "null")};
+           }
+         """
+      } else {
+        s"""
+           if (!${ev.isNull}) {
+             ${setColumn(row, dataType, ordinal, ev.value)};
+           } else {
+             $row.setNullAt($ordinal);
+           }
+         """
+      }
+    } else {
+      s"""${setColumn(row, dataType, ordinal, ev.value)};"""
+    }
+  }
+
   /**
    * Returns the name used in accessor and setter for a Java primitive type.
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index ec31db19b94b8..5b4dc8df8622b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -88,31 +88,8 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu val updates = validExpr.zip(index).map { case (e, i) => - if (e.nullable) { - if (e.dataType.isInstanceOf[DecimalType]) { - // Can't call setNullAt on DecimalType, because we need to keep the offset - s""" - if (this.isNull_$i) { - ${ctx.setColumn("mutableRow", e.dataType, i, "null")}; - } else { - ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")}; - } - """ - } else { - s""" - if (this.isNull_$i) { - mutableRow.setNullAt($i); - } else { - ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")}; - } - """ - } - } else { - s""" - ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")}; - """ - } - + val ev = ExprCode("", s"this.isNull_$i", s"this.value_$i") + ctx.updateColumn("mutableRow", e.dataType, i, ev, e.nullable) } val allProjections = ctx.splitExpressions(ctx.INPUT_ROW, projectionCodes) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index f90568ac2d583..ef81ba60f049f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -49,7 +49,7 @@ trait CodegenSupport extends SparkPlan { def upstream(): RDD[InternalRow] /** - * Returns an input RDD of InternalRow and Java source code to process them. + * Returns Java source code to process the rows from upstream. */ def produce(ctx: CodegenContext, parent: CodegenSupport): String = { this.parent = parent diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 80feeba186dee..260eee34f09bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -231,16 +231,16 @@ case class TungstenAggregate( val updates = updateExpr.zipWithIndex.map { case (e, i) => val ev = BindReferences.bindReference[Expression](e, inputAttrs).gen(ctx) s""" - ${ev.code} - ${bufVars(i).isNull} = ${ev.isNull}; - ${bufVars(i).value} = ${ev.value}; - """ + | ${ev.code} + | ${bufVars(i).isNull} = ${ev.isNull}; + | ${bufVars(i).value} = ${ev.value}; + """.stripMargin } s""" - // do aggregate and update aggregation buffer - ${updates.mkString("")} - """ + | // do aggregate and update aggregation buffer + | ${updates.mkString("")} + """.stripMargin } val groupingAttributes = groupingExpressions.map(_.toAttribute) @@ -254,6 +254,9 @@ case class TungstenAggregate( // The name for HashMap var hashMapTerm: String = _ + /** + * This is called by generated Java class, should be public. 
+ */ def createHashMap(): UnsafeFixedWidthAggregationMap = { // create initialized aggregate buffer val initExpr = declFunctions.flatMap(f => f.initialValues) @@ -271,6 +274,9 @@ case class TungstenAggregate( ) } + /** + * This is called by generated Java class, should be public. + */ def createUnsafeJoiner(): UnsafeRowJoiner = { GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema) } @@ -401,29 +407,7 @@ case class TungstenAggregate( val evals = boundExpr.map(_.gen(ctx)) val updates = evals.zipWithIndex.map { case (ev, i) => val dt = updateExpr(i).dataType - if (updateExpr(i).nullable) { - if (dt.isInstanceOf[DecimalType]) { - s""" - if (!${ev.isNull}) { - ${ctx.setColumn(buffer, dt, i, ev.value)}; - } else { - ${ctx.setColumn(buffer, dt, i, "null")}; - } - """ - } else { - s""" - if (!${ev.isNull}) { - ${ctx.setColumn(buffer, dt, i, ev.value)}; - } else { - $buffer.setNullAt($i); - } - """ - } - } else { - s""" - ${ctx.setColumn(buffer, dt, i, ev.value)}; - """ - } + ctx.updateColumn(buffer, dt, i, ev, updateExpr(i).nullable) } s""" From be2e53bd3be70604b885a7a219a1a28801d75320 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 29 Jan 2016 11:20:01 -0800 Subject: [PATCH 12/18] minor --- .../aggregate/TungstenAggregate.scala | 28 +++++++++---------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index eba8a3d8a9739..9d97fa8274134 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -218,7 +218,7 @@ case class TungstenAggregate( """.stripMargin } - def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = { + private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = { // only have DeclarativeAggregate val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ child.output @@ -249,16 +249,16 @@ case class TungstenAggregate( """.stripMargin } - val groupingAttributes = groupingExpressions.map(_.toAttribute) - val groupingKeySchema = StructType.fromAttributes(groupingAttributes) - val declFunctions = aggregateExpressions.map(_.aggregateFunction) + private val groupingAttributes = groupingExpressions.map(_.toAttribute) + private val groupingKeySchema = StructType.fromAttributes(groupingAttributes) + private val declFunctions = aggregateExpressions.map(_.aggregateFunction) .filter(_.isInstanceOf[DeclarativeAggregate]) .map(_.asInstanceOf[DeclarativeAggregate]) - val bufferAttributes = declFunctions.flatMap(_.aggBufferAttributes) - val bufferSchema = StructType.fromAttributes(bufferAttributes) + private val bufferAttributes = declFunctions.flatMap(_.aggBufferAttributes) + private val bufferSchema = StructType.fromAttributes(bufferAttributes) // The name for HashMap - var hashMapTerm: String = _ + private var hashMapTerm: String = _ /** * This is called by generated Java class, should be public. 
@@ -292,7 +292,7 @@ case class TungstenAggregate( ctx.addMutableState("boolean", initAgg, s"$initAgg = false;") // create hashMap - val thisPlan = ctx.addReferenceObj("tungstenAggregate", this) + val thisPlan = ctx.addReferenceObj("plan", this) hashMapTerm = ctx.freshName("hashMap") val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName ctx.addMutableState(hashMapClassName, hashMapTerm, s"$hashMapTerm = $thisPlan.createHashMap();") @@ -304,7 +304,7 @@ case class TungstenAggregate( // generate code for output val keyTerm = ctx.freshName("aggKey") val bufferTerm = ctx.freshName("aggBuffer") - val outputCode = if (modes.contains(Final) |modes.contains(Complete)) { + val outputCode = if (modes.contains(Final) || modes.contains(Complete)) { // generate output using resultExpressions ctx.currentVars = null ctx.INPUT_ROW = keyTerm @@ -335,7 +335,7 @@ case class TungstenAggregate( ${consume(ctx, resultVars)} """ - } else if (modes.contains(Partial) |modes.contains(PartialMerge)) { + } else if (modes.contains(Partial) || modes.contains(PartialMerge)) { // This should be the last operator in a stage, we should output UnsafeRow directly val joinerTerm = ctx.freshName("unsafeRowJoiner") ctx.addMutableState(classOf[UnsafeRowJoiner].getName, joinerTerm, @@ -347,7 +347,7 @@ case class TungstenAggregate( """ } else { - // only grouping key + // generate result based on grouping key ctx.INPUT_ROW = keyTerm ctx.currentVars = null val eval = resultExpressions.map{ e => @@ -359,7 +359,7 @@ case class TungstenAggregate( """ } - val doAgg = ctx.freshName("doAgg") + val doAgg = ctx.freshName("doAggregate") ctx.addNewFunction(doAgg, s""" private void $doAgg() throws java.io.IOException { @@ -406,11 +406,10 @@ case class TungstenAggregate( } val inputAttr = bufferAttributes ++ child.output - val boundExpr = updateExpr.map(e => BindReferences.bindReference(e, inputAttr)) ctx.currentVars = new Array[ExprCode](bufferAttributes.length) ++ input ctx.INPUT_ROW = buffer // TODO: support subexpression elimination - val evals = boundExpr.map(_.gen(ctx)) + val evals = updateExpr.map(BindReferences.bindReference(_, inputAttr).gen(ctx)) val updates = evals.zipWithIndex.map { case (ev, i) => val dt = updateExpr(i).dataType ctx.updateColumn(buffer, dt, i, ev, updateExpr(i).nullable) @@ -427,7 +426,6 @@ case class TungstenAggregate( // evaluate aggregate function ${evals.map(_.code).mkString("\n")} - // update aggregate buffer ${updates.mkString("\n")} """ From 9ae4bc2595052008b78e4c47d92f0a45348181ea Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 25 Jan 2016 15:58:15 -0800 Subject: [PATCH 13/18] fix planner for BroadcastHashJoin Conflicts: sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala --- .../apache/spark/sql/execution/WholeStageCodegen.scala | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index c16e214db2af8..14393035c5e74 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, 
BuildLeft, BuildRight} import org.apache.spark.util.Utils /** @@ -363,6 +364,15 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru var inputs = ArrayBuffer[SparkPlan]() val combined = plan.transform { + // The build side can't be compiled together + case b @ BroadcastHashJoin(_, _, BuildLeft, _, left, right) => + val input = apply(left) + inputs += input + b.copy(left = input) + case b @ BroadcastHashJoin(_, _, BuildRight, _, left, right) => + val input = apply(right) + inputs += input + b.copy(right = input) case p if !supportCodegen(p) => val input = apply(p) // collapse them recursively inputs += input From dcf4fdc0b955fb5abd9686447710ecc0a76ce990 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 29 Jan 2016 22:06:52 -0800 Subject: [PATCH 14/18] fix style --- .../apache/spark/sql/execution/joins/BroadcastHashJoin.scala | 4 ++-- .../apache/spark/sql/execution/WholeStageCodegenSuite.scala | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index 8f39ff9ae49bf..288a4f1111f8c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -20,14 +20,14 @@ package org.apache.spark.sql.execution.joins import scala.concurrent._ import scala.concurrent.duration._ -import org.apache.spark.{InternalAccumulator, TaskContext} +import org.apache.spark.TaskContext import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{BindReferences, BoundReference, Expression, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection} import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution} -import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SQLExecution, SparkPlan} +import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SparkPlan, SQLExecution} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.util.ThreadUtils import org.apache.spark.util.collection.CompactBuffer diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 0153778b72f6e..9350205d791d7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.execution.aggregate.TungstenAggregate import org.apache.spark.sql.execution.joins.BroadcastHashJoin import org.apache.spark.sql.functions.{avg, broadcast, col, max} import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{StringType, IntegerType, StructType} +import org.apache.spark.sql.types.{IntegerType, StringType, StructType} class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { From 1ecce29b10ee7ee27577a8955f4a7db4d933ba66 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 29 Jan 2016 22:56:11 -0800 Subject: [PATCH 15/18] add comments --- .../spark/sql/execution/joins/BroadcastHashJoin.scala | 7 +++++-- 1 file changed, 5 
insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index 3d46127b2e7fa..8b275e886c46c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -147,18 +147,21 @@ case class BroadcastHashJoin( } override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { + // generate the key as UnsafeRow ctx.currentVars = input val keyExpr = streamedKeys.map(BindReferences.bindReference(_, streamedPlan.output)) val keyVal = GenerateUnsafeProjection.createCode(ctx, keyExpr) val keyTerm = keyVal.value val anyNull = if (keyExpr.exists(_.nullable)) s"$keyTerm.anyNull()" else "false" + // find the matches from HashedRelation val matches = ctx.freshName("matches") val bufferType = classOf[CompactBuffer[UnsafeRow]].getName val i = ctx.freshName("i") val size = ctx.freshName("size") val row = ctx.freshName("row") + // create variables for output ctx.currentVars = null ctx.INPUT_ROW = row val buildColumns = buildPlan.output.zipWithIndex.map { case (a, i) => @@ -170,9 +173,9 @@ case class BroadcastHashJoin( } val ouputCode = if (condition.isDefined) { + // filter the output via condition ctx.currentVars = resultVars - val ev = BindReferences.bindReference(condition.get, this.output) - .gen(ctx) + val ev = BindReferences.bindReference(condition.get, this.output).gen(ctx) s""" | ${ev.code} | if (!${ev.isNull} && ${ev.value}) { From 0139fdeeefc2038e995c44c7e966e09e30063418 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 29 Jan 2016 23:25:21 -0800 Subject: [PATCH 16/18] fix style --- .../apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala index a708e907db835..36f9d71a7cb1f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -21,11 +21,11 @@ import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager} import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.functions._ import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.hash.Murmur3_x86_32 import org.apache.spark.unsafe.map.BytesToBytesMap import org.apache.spark.util.Benchmark -import org.apache.spark.sql.functions._ /** * Benchmark to measure whole stage codegen performance. 
From c1c0588053af5aa359b6d03bac6c5d0b198c5b69 Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Mon, 1 Feb 2016 10:48:59 -0800
Subject: [PATCH 17/18] update comment

---
 .../org/apache/spark/sql/execution/BufferedRowIterator.java | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java b/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java
index c5bccdeb076e3..367c9cd379c0d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java
@@ -54,7 +54,7 @@ public void setInput(Iterator iter) {
   }
 
   /**
-   * Returns whether it should stop processing next row or not.
+   * Returns whether `processNext()` should stop processing next row from `input` or not.
    */
   protected boolean shouldStop() {
     return !currentRows.isEmpty();

From e0c8c652b86ce9d17bcb5d629e6b55563b5c382b Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Tue, 2 Feb 2016 22:32:12 -0800
Subject: [PATCH 18/18] add comment

---
 .../org/apache/spark/sql/execution/BufferedRowIterator.java | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java b/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java
index 367c9cd379c0d..ea20115770f79 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java
@@ -55,6 +55,8 @@ public void setInput(Iterator iter) {
 
   /**
    * Returns whether `processNext()` should stop processing next row from `input` or not.
+   *
+   * If it returns true, the caller should exit the loop (return from processNext()).
    */
   protected boolean shouldStop() {
     return !currentRows.isEmpty();
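
The two comments above spell out the contract between shouldStop() and the generated processNext(): processNext() buffers produced rows and returns early once shouldStop() reports that a row is ready, and it is invoked again when the caller needs more rows. A minimal sketch of a generated subclass honoring that contract is given below. It is illustrative only and is not code from this patch series; the class name, direct access to the input field, and the currentRows.add(...) call are assumptions for the example.

    import java.io.IOException;

    import org.apache.spark.sql.catalyst.InternalRow;
    import org.apache.spark.sql.execution.BufferedRowIterator;

    // Sketch of a generated iterator honoring the shouldStop() contract.
    // Class name and currentRows.add(...) are assumptions made for illustration.
    class ExampleGeneratedIterator extends BufferedRowIterator {
      @Override
      protected void processNext() throws IOException {
        while (input.hasNext()) {
          InternalRow row = (InternalRow) input.next();
          currentRows.add(row);   // buffer the produced row (assumed accessor)
          if (shouldStop()) {
            // A row is buffered; return so hasNext()/next() can hand it to the caller.
            // processNext() will be called again when more rows are needed.
            return;
          }
        }
      }
    }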