[SPARK-12798] [SQL] generated BroadcastHashJoin
A row from the stream side can match multiple rows on the build side, and the loop over those matched rows must not be interrupted while emitting a row. We therefore buffer the output rows in a linked list and check the termination condition in the producer loop (for example, Range or Aggregate).

Author: Davies Liu <davies@databricks.com>

Closes apache#10989 from davies/gen_join.
Davies Liu authored and davies committed Feb 3, 2016
1 parent e9eb248 commit c4feec2
Showing 7 changed files with 169 additions and 20 deletions.
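
To illustrate the approach described in the commit message before diving into the diff, here is a minimal, self-contained Java analogue (hand-written for illustration; the class name and `matchesPerRow` are hypothetical, not Spark's generated code). One input value produces several buffered output rows, mirroring a stream-side row that matches multiple build-side rows: the inner loop always runs to completion, and `shouldStop()` only interrupts the producer loop between input rows.

import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;

// Minimal analogue of the buffering pattern in this commit (not Spark's
// BufferedRowIterator): each input value may produce several output rows,
// so outputs are buffered and the producer loop is only interrupted
// between input rows.
public class BufferedJoinLikeIterator {
  private final LinkedList<String> currentRows = new LinkedList<>();
  private final Iterator<Integer> input;
  private final int matchesPerRow; // stand-in for the number of build-side matches

  BufferedJoinLikeIterator(Iterator<Integer> input, int matchesPerRow) {
    this.input = input;
    this.matchesPerRow = matchesPerRow;
  }

  private boolean shouldStop() {
    return !currentRows.isEmpty();
  }

  private void processNext() {
    while (input.hasNext()) { // producer loop (e.g. Range or Aggregate)
      int id = input.next();
      for (int i = 0; i < matchesPerRow; i++) { // matched-rows loop runs to completion
        currentRows.add(id + "-" + i);
      }
      if (shouldStop()) return; // stop only after all matches were buffered
    }
  }

  public boolean hasNext() {
    if (currentRows.isEmpty()) {
      processNext();
    }
    return !currentRows.isEmpty();
  }

  public String next() {
    return currentRows.remove();
  }

  public static void main(String[] args) {
    List<Integer> streamed = Arrays.asList(1, 2, 3);
    BufferedJoinLikeIterator it = new BufferedJoinLikeIterator(streamed.iterator(), 2);
    while (it.hasNext()) {
      System.out.println(it.next()); // 1-0, 1-1, 2-0, 2-1, 3-0, 3-1
    }
  }
}

The BufferedRowIterator changes in the first file below implement the same buffering contract, and the generated operators (Range, TungstenAggregate, BroadcastHashJoin) plug their loops into processNext().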
@@ -18,9 +18,11 @@
package org.apache.spark.sql.execution;

import java.io.IOException;
import java.util.LinkedList;

import scala.collection.Iterator;

import org.apache.spark.TaskContext;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;

@@ -31,36 +33,50 @@
* TODO: replace it with a batched columnar format.
*/
public class BufferedRowIterator {
protected InternalRow currentRow;
protected LinkedList<InternalRow> currentRows = new LinkedList<>();
protected Iterator<InternalRow> input;
// used when there are no columns in the output
protected UnsafeRow unsafeRow = new UnsafeRow(0);

public boolean hasNext() throws IOException {
if (currentRow == null) {
if (currentRows.isEmpty()) {
processNext();
}
return currentRow != null;
return !currentRows.isEmpty();
}

public InternalRow next() {
InternalRow r = currentRow;
currentRow = null;
return r;
return currentRows.remove();
}

public void setInput(Iterator<InternalRow> iter) {
input = iter;
}

/**
* Returns whether `processNext()` should stop processing the next row from `input`.
*
* If it returns true, the caller should exit the loop (return from processNext()).
*/
protected boolean shouldStop() {
return !currentRows.isEmpty();
}

/**
* Increase the peak execution memory for current task.
*/
protected void incPeakExecutionMemory(long size) {
TaskContext.get().taskMetrics().incPeakExecutionMemory(size);
}

/**
* Processes the input until at least one row has been appended to currentRows.
*
* After it returns, if currentRows is still empty, there are no more rows left.
*/
protected void processNext() throws IOException {
if (input.hasNext()) {
currentRow = input.next();
currentRows.add(input.next());
}
}
}
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.aggregate.TungstenAggregate
import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, BuildLeft, BuildRight}
import org.apache.spark.util.Utils

/**
@@ -172,6 +173,9 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
| InternalRow $row = (InternalRow) input.next();
| ${columns.map(_.code).mkString("\n").trim}
| ${consume(ctx, columns).trim}
| if (shouldStop()) {
| return;
| }
| }
""".stripMargin
}
@@ -283,8 +287,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
if (row != null) {
// There is an UnsafeRow already
s"""
| currentRow = $row;
| return;
| currentRows.add($row.copy());
""".stripMargin
} else {
assert(input != null)
@@ -297,14 +300,12 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
s"""
| ${code.code.trim}
| currentRow = ${code.value};
| return;
| currentRows.add(${code.value}.copy());
""".stripMargin
} else {
// There are no columns
s"""
| currentRow = unsafeRow;
| return;
| currentRows.add(unsafeRow);
""".stripMargin
}
}
@@ -371,6 +372,11 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru

var inputs = ArrayBuffer[SparkPlan]()
val combined = plan.transform {
// The build side can't be compiled into the same stage as the stream side
case b @ BroadcastHashJoin(_, _, BuildLeft, _, left, right) =>
b.copy(left = apply(left))
case b @ BroadcastHashJoin(_, _, BuildRight, _, left, right) =>
b.copy(right = apply(right))
case p if !supportCodegen(p) =>
val input = apply(p) // collapse them recursively
inputs += input
@@ -471,6 +471,8 @@ case class TungstenAggregate(
UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey();
UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue();
$outputCode

if (shouldStop()) return;
}

$iterTerm.close();
@@ -480,7 +482,7 @@
"""
}

private def doConsumeWithKeys( ctx: CodegenContext, input: Seq[ExprCode]): String = {
private def doConsumeWithKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = {

// create grouping key
ctx.currentVars = input
@@ -237,6 +237,8 @@ case class Range(
| $overflow = true;
| }
| ${consume(ctx, Seq(ev))}
|
| if (shouldStop()) return;
| }
""".stripMargin
}
@@ -20,14 +20,17 @@ package org.apache.spark.sql.execution.joins
import scala.concurrent._
import scala.concurrent.duration._

import org.apache.spark.{InternalAccumulator, TaskContext}
import org.apache.spark.TaskContext
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.{BindReferences, BoundReference, Expression, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection}
import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution}
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.util.ThreadUtils
import org.apache.spark.util.collection.CompactBuffer

/**
* Performs an inner hash join of two child relations. When the output RDD of this operator is
@@ -42,7 +45,7 @@ case class BroadcastHashJoin(
condition: Option[Expression],
left: SparkPlan,
right: SparkPlan)
extends BinaryNode with HashJoin {
extends BinaryNode with HashJoin with CodegenSupport {

override private[sql] lazy val metrics = Map(
"numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"),
@@ -117,6 +120,87 @@ case class BroadcastHashJoin(
hashJoin(streamedIter, numStreamedRows, hashedRelation, numOutputRows)
}
}

// the term for hash relation
private var relationTerm: String = _

override def upstream(): RDD[InternalRow] = {
streamedPlan.asInstanceOf[CodegenSupport].upstream()
}

override def doProduce(ctx: CodegenContext): String = {
// create a name for HashRelation
val broadcastRelation = Await.result(broadcastFuture, timeout)
val broadcast = ctx.addReferenceObj("broadcast", broadcastRelation)
relationTerm = ctx.freshName("relation")
// TODO: create specialized HashRelation for single join key
val clsName = classOf[UnsafeHashedRelation].getName
ctx.addMutableState(clsName, relationTerm,
s"""
| $relationTerm = ($clsName) $broadcast.value();
| incPeakExecutionMemory($relationTerm.getUnsafeSize());
""".stripMargin)

s"""
| ${streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this)}
""".stripMargin
}

override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
// generate the key as UnsafeRow
ctx.currentVars = input
val keyExpr = streamedKeys.map(BindReferences.bindReference(_, streamedPlan.output))
val keyVal = GenerateUnsafeProjection.createCode(ctx, keyExpr)
val keyTerm = keyVal.value
val anyNull = if (keyExpr.exists(_.nullable)) s"$keyTerm.anyNull()" else "false"

// find the matches from HashedRelation
val matches = ctx.freshName("matches")
val bufferType = classOf[CompactBuffer[UnsafeRow]].getName
val i = ctx.freshName("i")
val size = ctx.freshName("size")
val row = ctx.freshName("row")

// create variables for output
ctx.currentVars = null
ctx.INPUT_ROW = row
val buildColumns = buildPlan.output.zipWithIndex.map { case (a, i) =>
BoundReference(i, a.dataType, a.nullable).gen(ctx)
}
val resultVars = buildSide match {
case BuildLeft => buildColumns ++ input
case BuildRight => input ++ buildColumns
}

val outputCode = if (condition.isDefined) {
// filter the output via condition
ctx.currentVars = resultVars
val ev = BindReferences.bindReference(condition.get, this.output).gen(ctx)
s"""
| ${ev.code}
| if (!${ev.isNull} && ${ev.value}) {
| ${consume(ctx, resultVars)}
| }
""".stripMargin
} else {
consume(ctx, resultVars)
}

s"""
| // generate join key
| ${keyVal.code}
| // find matches from HashRelation
| $bufferType $matches = $anyNull ? null : ($bufferType) $relationTerm.get($keyTerm);
| if ($matches != null) {
| int $size = $matches.size();
| for (int $i = 0; $i < $size; $i++) {
| UnsafeRow $row = (UnsafeRow) $matches.apply($i);
| ${buildColumns.map(_.code).mkString("\n")}
| $outputCode
| }
| }
""".stripMargin
}
}

object BroadcastHashJoin {
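
The `doConsume()` template above emits per-row Java into the streamed side's producer loop: project the join key, look it up in the broadcast `UnsafeHashedRelation`, iterate over every matched build row, and optionally gate each emit on the join condition. Below is a self-contained plain-Java analogue of that generated loop, with a `HashMap` standing in for the hashed relation; the class name, keys, and the condition are hypothetical, not Spark code.

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Plain-Java analogue of the loop emitted by doConsume(): not Spark code.
public class GeneratedJoinLoopSketch {
  public static void main(String[] args) {
    // stand-in for the broadcast UnsafeHashedRelation: join key -> matched build-side rows
    Map<Long, List<String>> relation = new HashMap<>();
    relation.put(1L, Arrays.asList("a", "b")); // key 1 matches two build rows
    relation.put(2L, Arrays.asList("c"));

    List<Long> streamedKeys = Arrays.asList(1L, 2L, null, 3L);
    List<String> output = new ArrayList<>();

    for (Long key : streamedKeys) {
      // the generated code short-circuits on key.anyNull() instead of a null check
      List<String> matches = (key == null) ? null : relation.get(key);
      if (matches != null) {
        int size = matches.size();
        for (int i = 0; i < size; i++) {
          String buildRow = matches.get(i);
          // with a join condition, the emit is wrapped in an if (!isNull && value) check
          if (!buildRow.equals("c")) { // hypothetical condition
            output.add(key + "/" + buildRow); // consume(ctx, resultVars)
          }
        }
      }
    }
    System.out.println(output); // [1/a, 1/b]
  }
}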
@@ -21,6 +21,7 @@ import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager}
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.functions._
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.hash.Murmur3_x86_32
import org.apache.spark.unsafe.map.BytesToBytesMap
@@ -130,6 +131,30 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
benchmark.run()
}

def testBroadcastHashJoin(values: Int): Unit = {
val benchmark = new Benchmark("BroadcastHashJoin", values)

val dim = broadcast(sqlContext.range(1 << 16).selectExpr("id as k", "cast(id as string) as v"))

benchmark.addCase("BroadcastHashJoin w/o codegen") { iter =>
sqlContext.setConf("spark.sql.codegen.wholeStage", "false")
sqlContext.range(values).join(dim, (col("id") % 60000) === col("k")).count()
}
benchmark.addCase(s"BroadcastHashJoin w codegen") { iter =>
sqlContext.setConf("spark.sql.codegen.wholeStage", "true")
sqlContext.range(values).join(dim, (col("id") % 60000) === col("k")).count()
}

/*
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
BroadcastHashJoin: Avg Time(ms) Avg Rate(M/s) Relative Rate
-------------------------------------------------------------------------------
BroadcastHashJoin w/o codegen 3053.41 3.43 1.00 X
BroadcastHashJoin w codegen 1028.40 10.20 2.97 X
*/
benchmark.run()
}

def testBytesToBytesMap(values: Int): Unit = {
val benchmark = new Benchmark("BytesToBytesMap", values)

@@ -201,6 +226,7 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
// testWholeStage(200 << 20)
// testStatFunctions(20 << 20)
// testAggregateWithKey(20 << 20)
// testBytesToBytesMap(1024 * 1024 * 50)
// testBytesToBytesMap(50 << 20)
// testBroadcastHashJoin(10 << 20)
}
}
@@ -20,8 +20,10 @@ package org.apache.spark.sql.execution
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.execution.aggregate.TungstenAggregate
import org.apache.spark.sql.functions.{avg, col, max}
import org.apache.spark.sql.execution.joins.BroadcastHashJoin
import org.apache.spark.sql.functions.{avg, broadcast, col, max}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {

@@ -56,4 +58,15 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[TungstenAggregate]).isDefined)
assert(df.collect() === Array(Row(0, 1), Row(1, 1), Row(2, 1)))
}

test("BroadcastHashJoin should be included in WholeStageCodegen") {
val rdd = sqlContext.sparkContext.makeRDD(Seq(Row(1, "1"), Row(1, "1"), Row(2, "2")))
val schema = new StructType().add("k", IntegerType).add("v", StringType)
val smallDF = sqlContext.createDataFrame(rdd, schema)
val df = sqlContext.range(10).join(broadcast(smallDF), col("k") === col("id"))
assert(df.queryExecution.executedPlan.find(p =>
p.isInstanceOf[WholeStageCodegen] &&
p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[BroadcastHashJoin]).isDefined)
assert(df.collect() === Array(Row(1, 1, "1"), Row(1, 1, "1"), Row(2, 2, "2")))
}
}
