Skip to content

Commit

Permalink
[SPARK-13147] [SQL] improve readability of generated code
Browse files Browse the repository at this point in the history
1. try to avoid the suffix (unique id)
2. remove the comment if there is no code generated.
3. re-arrange the order of functions
4. trop the new line for inlined blocks.

Author: Davies Liu <davies@databricks.com>

Closes apache#11032 from davies/better_suffix.
  • Loading branch information
Davies Liu authored and davies committed Feb 3, 2016
1 parent 335f10e commit e86f8f6
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,12 @@ abstract class Expression extends TreeNode[Expression] {
val value = ctx.freshName("value")
val ve = ExprCode("", isNull, value)
ve.code = genCode(ctx, ve)
// Add `this` in the comment.
ve.copy(s"/* ${this.toCommentSafeString} */\n" + ve.code.trim)
if (ve.code != "") {
// Add `this` in the comment.
ve.copy(s"/* ${this.toCommentSafeString} */\n" + ve.code.trim)
} else {
ve
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,24 +156,33 @@ class CodegenContext {
/** The variable name of the input row in generated code. */
final var INPUT_ROW = "i"

private val curId = new java.util.concurrent.atomic.AtomicInteger()
/**
* The map from a variable name to it's next ID.
*/
private val freshNameIds = new mutable.HashMap[String, Int]
freshNameIds += INPUT_ROW -> 1

/**
* A prefix used to generate fresh name.
*/
var freshNamePrefix = ""

/**
* Returns a term name that is unique within this instance of a `CodeGenerator`.
*
* (Since we aren't in a macro context we do not seem to have access to the built in `freshName`
* function.)
* Returns a term name that is unique within this instance of a `CodegenContext`.
*/
def freshName(name: String): String = {
if (freshNamePrefix == "") {
s"$name${curId.getAndIncrement}"
def freshName(name: String): String = synchronized {
val fullName = if (freshNamePrefix == "") {
name
} else {
s"${freshNamePrefix}_$name"
}
if (freshNameIds.contains(fullName)) {
val id = freshNameIds(fullName)
freshNameIds(fullName) = id + 1
s"$fullName$id"
} else {
s"${freshNamePrefix}_$name${curId.getAndIncrement}"
freshNameIds += fullName -> 1
fullName
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,22 +173,26 @@ case class GetArrayStructFields(
override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
val arrayClass = classOf[GenericArrayData].getName
nullSafeCodeGen(ctx, ev, eval => {
val n = ctx.freshName("n")
val values = ctx.freshName("values")
val j = ctx.freshName("j")
val row = ctx.freshName("row")
s"""
final int n = $eval.numElements();
final Object[] values = new Object[n];
for (int j = 0; j < n; j++) {
if ($eval.isNullAt(j)) {
values[j] = null;
final int $n = $eval.numElements();
final Object[] $values = new Object[$n];
for (int $j = 0; $j < $n; $j++) {
if ($eval.isNullAt($j)) {
$values[$j] = null;
} else {
final InternalRow row = $eval.getStruct(j, $numFields);
if (row.isNullAt($ordinal)) {
values[j] = null;
final InternalRow $row = $eval.getStruct($j, $numFields);
if ($row.isNullAt($ordinal)) {
$values[$j] = null;
} else {
values[j] = ${ctx.getValue("row", field.dataType, ordinal.toString)};
$values[$j] = ${ctx.getValue(row, field.dataType, ordinal.toString)};
}
}
}
${ev.value} = new $arrayClass(values);
${ev.value} = new $arrayClass($values);
"""
})
}
Expand Down Expand Up @@ -227,12 +231,13 @@ case class GetArrayItem(child: Expression, ordinal: Expression)

override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
val index = ctx.freshName("index")
s"""
final int index = (int) $eval2;
if (index >= $eval1.numElements() || index < 0 || $eval1.isNullAt(index)) {
final int $index = (int) $eval2;
if ($index >= $eval1.numElements() || $index < 0 || $eval1.isNullAt($index)) {
${ev.isNull} = true;
} else {
${ev.value} = ${ctx.getValue(eval1, dataType, "index")};
${ev.value} = ${ctx.getValue(eval1, dataType, index)};
}
"""
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,8 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
s"""
| while (input.hasNext()) {
| InternalRow $row = (InternalRow) input.next();
| ${columns.map(_.code).mkString("\n")}
| ${consume(ctx, columns)}
| ${columns.map(_.code).mkString("\n").trim}
| ${consume(ctx, columns).trim}
| }
""".stripMargin
}
Expand Down Expand Up @@ -236,15 +236,16 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])

private Object[] references;
${ctx.declareMutableStates()}
${ctx.declareAddedFunctions()}

public GeneratedIterator(Object[] references) {
this.references = references;
${ctx.initMutableStates()}
this.references = references;
${ctx.initMutableStates()}
}

${ctx.declareAddedFunctions()}

protected void processNext() throws java.io.IOException {
$code
${code.trim}
}
}
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -211,9 +211,9 @@ case class TungstenAggregate(
| $doAgg();
|
| // output the result
| $genResult
| ${genResult.trim}
|
| ${consume(ctx, resultVars)}
| ${consume(ctx, resultVars).trim}
| }
""".stripMargin
}
Expand Down Expand Up @@ -242,9 +242,9 @@ case class TungstenAggregate(
}
s"""
| // do aggregate
| ${aggVals.map(_.code).mkString("\n")}
| ${aggVals.map(_.code).mkString("\n").trim}
| // update aggregation buffer
| ${updates.mkString("")}
| ${updates.mkString("\n").trim}
""".stripMargin
}

Expand Down Expand Up @@ -523,7 +523,7 @@ case class TungstenAggregate(
// Finally, sort the spilled aggregate buffers by key, and merge them together for same key.
s"""
// generate grouping key
${keyCode.code}
${keyCode.code.trim}
UnsafeRow $buffer = null;
if ($checkFallback) {
// try to get the buffer from hash map
Expand All @@ -547,9 +547,9 @@ case class TungstenAggregate(
$incCounter

// evaluate aggregate function
${evals.map(_.code).mkString("\n")}
${evals.map(_.code).mkString("\n").trim}
// update aggregate buffer
${updates.mkString("\n")}
${updates.mkString("\n").trim}
"""
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,14 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode wit
BindReferences.bindReference(condition, child.output))
ctx.currentVars = input
val eval = expr.gen(ctx)
val nullCheck = if (expr.nullable) {
s"!${eval.isNull} &&"
} else {
s""
}
s"""
| ${eval.code}
| if (!${eval.isNull} && ${eval.value}) {
| if ($nullCheck ${eval.value}) {
| ${consume(ctx, ctx.currentVars)}
| }
""".stripMargin
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
// These benchmark are skipped in normal build
ignore("benchmark") {
// testWholeStage(200 << 20)
// testStddev(20 << 20)
// testStatFunctions(20 << 20)
// testAggregateWithKey(20 << 20)
// testBytesToBytesMap(1024 * 1024 * 50)
}
Expand Down

0 comments on commit e86f8f6

Please sign in to comment.