Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-12798] [SQL] generated BroadcastHashJoin #10989

Closed
wants to merge 27 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
generate aggregation with grouping keys
  • Loading branch information
Davies Liu committed Jan 21, 2016
commit 3e792f3569d7a397e2817ac3b66816a3c35feed0
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
// Can't call setNullAt on DecimalType, because we need to keep the offset
s"""
if (this.isNull_$i) {
${ctx.setColumn("mutableRow", e.dataType, i, null)};
${ctx.setColumn("mutableRow", e.dataType, i, "null")};
} else {
${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.sql.execution;

import java.io.IOException;

import scala.collection.Iterator;

import org.apache.spark.sql.catalyst.InternalRow;
Expand All @@ -34,7 +36,7 @@ public class BufferedRowIterator {
// used when there is no column in output
protected UnsafeRow unsafeRow = new UnsafeRow(0);

public boolean hasNext() {
public boolean hasNext() throws IOException {
if (currentRow == null) {
processNext();
}
Expand All @@ -56,7 +58,7 @@ public void setInput(Iterator<InternalRow> iter) {
*
* After it's called, if currentRow is still null, it means no more rows left.
*/
protected void processNext() {
protected void processNext() throws IOException {
if (input.hasNext()) {
currentRow = input.next();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,15 @@ trait CodegenSupport extends SparkPlan {
*/
private var parent: CodegenSupport = null

/**
* Returns the RDD of InternalRow which generates the input rows.
*/
def upstream(): RDD[InternalRow]

/**
* Returns an input RDD of InternalRow and Java source code to process them.
*/
def produce(ctx: CodegenContext, parent: CodegenSupport): (RDD[InternalRow], String) = {
def produce(ctx: CodegenContext, parent: CodegenSupport): String = {
this.parent = parent
doProduce(ctx)
}
Expand All @@ -66,17 +71,40 @@ trait CodegenSupport extends SparkPlan {
* # call consume(), which will call parent.doConsume()
* }
*/
protected def doProduce(ctx: CodegenContext): (RDD[InternalRow], String)
protected def doProduce(ctx: CodegenContext): String

/**
* Consume the columns generated from the current SparkPlan, call its parent or create an iterator.
*/
protected def consume(ctx: CodegenContext, columns: Seq[ExprCode]): String = {
assert(columns.length == output.length)
parent.doConsume(ctx, this, columns)
def consume(
ctx: CodegenContext,
child: SparkPlan,
input: Seq[ExprCode],
row: String = null): String = {
if (child eq this) {
// This is called by itself, pass to its parent
if (input != null) {
assert(input.length == output.length)
}
parent.consume(ctx, this, input, row)
} else {
// This is called by child
if (row != null) {
ctx.currentVars = null
ctx.INPUT_ROW = row
val evals = child.output.zipWithIndex.map { case (attr, i) =>
BoundReference(i, attr.dataType, attr.nullable).gen(ctx)
}
s"""
| ${evals.map(_.code).mkString("\n")}
| ${doConsume(ctx, child, evals)}
""".stripMargin
} else {
doConsume(ctx, child, input)
}
}
}


/**
* Generate the Java source code to process the rows from child SparkPlan.
*
Expand All @@ -89,7 +117,9 @@ trait CodegenSupport extends SparkPlan {
* # call consume(), which will call parent.doConsume()
* }
*/
def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String
protected def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = {
throw new UnsupportedOperationException
}
}


Expand All @@ -105,24 +135,23 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {

override def supportCodegen: Boolean = true

override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = {
override def upstream(): RDD[InternalRow] = {
child.execute()
}

override def doProduce(ctx: CodegenContext): String = {
val exprs = output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, true))
val row = ctx.freshName("row")
ctx.INPUT_ROW = row
ctx.currentVars = null
val columns = exprs.map(_.gen(ctx))
val code = s"""
s"""
| while (input.hasNext()) {
| InternalRow $row = (InternalRow) input.next();
| ${columns.map(_.code).mkString("\n")}
| ${consume(ctx, columns)}
| ${consume(ctx, this, columns)}
| }
""".stripMargin
(child.execute(), code)
}

def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = {
throw new UnsupportedOperationException
}

override def doExecute(): RDD[InternalRow] = {
Expand Down Expand Up @@ -165,34 +194,34 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
override def output: Seq[Attribute] = plan.output

override def doExecute(): RDD[InternalRow] = {
val ctx = new CodegenContext
val (rdd, code) = plan.produce(ctx, this)
val references = ctx.references.toArray
val source = s"""
public Object generate(Object[] references) {
return new GeneratedIterator(references);
}

class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
plan.upstream().mapPartitions { iter =>
val ctx = new CodegenContext
val code = plan.produce(ctx, this)
val references = ctx.references.toArray
val source = s"""
public Object generate(Object[] references) {
return new GeneratedIterator(references);
}

class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {

private Object[] references;
${ctx.declareMutableStates()}

public GeneratedIterator(Object[] references) {
this.references = references;
${ctx.initMutableStates()}
}

protected void processNext() throws java.io.IOException {
$code
}
}
"""
// try to compile, helpful for debug
// println(s"${CodeFormatter.format(source)}")

private Object[] references;
${ctx.declareMutableStates()}

public GeneratedIterator(Object[] references) {
this.references = references;
${ctx.initMutableStates()}
}

protected void processNext() {
$code
}
}
"""
// try to compile, helpful for debug
// println(s"${CodeFormatter.format(source)}")
CodeGenerator.compile(source)

rdd.mapPartitions { iter =>
val clazz = CodeGenerator.compile(source)
val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
buffer.setInput(iter)
Expand All @@ -203,29 +232,47 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
}
}

override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = {
override def upstream(): RDD[InternalRow] = {
throw new UnsupportedOperationException
}

override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = {
if (input.nonEmpty) {
val colExprs = output.zipWithIndex.map { case (attr, i) =>
BoundReference(i, attr.dataType, attr.nullable)
}
// generate the code to create a UnsafeRow
ctx.currentVars = input
val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
override def doProduce(ctx: CodegenContext): String = {
throw new UnsupportedOperationException
}

override def consume(
ctx: CodegenContext,
child: SparkPlan,
input: Seq[ExprCode],
row: String = null): String = {

if (row != null) {
// There is an UnsafeRow already
s"""
| ${code.code.trim}
| currentRow = ${code.value};
| currentRow = $row;
| return;
""".stripMargin
""".stripMargin
} else {
// There are no columns
s"""
| currentRow = unsafeRow;
| return;
assert(input != null)
if (input.nonEmpty) {
val colExprs = output.zipWithIndex.map { case (attr, i) =>
BoundReference(i, attr.dataType, attr.nullable)
}
// generate the code to create a UnsafeRow
ctx.currentVars = input
val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
s"""
| ${code.code.trim}
| currentRow = ${code.value};
| return;
""".stripMargin
} else {
// There are no columns
s"""
| currentRow = unsafeRow;
| return;
""".stripMargin
}
}
}

Expand All @@ -246,7 +293,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
builder.append(simpleString)
builder.append("\n")

plan.generateTreeString(depth + 1, lastChildren :+children.isEmpty :+ true, builder)
plan.generateTreeString(depth + 2, lastChildren :+ false :+ true, builder)
if (children.nonEmpty) {
children.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder))
children.last.generateTreeString(depth + 1, lastChildren :+ true, builder)
Expand Down Expand Up @@ -291,8 +338,9 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru
var inputs = ArrayBuffer[SparkPlan]()
val combined = plan.transform {
case p if !supportCodegen(p) =>
inputs += p
InputAdapter(p)
val input = apply(p) // collapse them recursively
inputs += input
InputAdapter(input)
}.asInstanceOf[CodegenSupport]
WholeStageCodegen(combined, inputs)
}
Expand Down
Loading