Skip to content

Commit

Permalink
[SPARK-8305] [SPARK-8190] [SQL] improve codegen
Browse files Browse the repository at this point in the history
This PR fix a few small issues about codgen:

1. cast decimal to boolean
2. do not inline literal with null
3. improve SpecificRow.equals()
4. test expressions with optimized express
5. fix compare with BinaryType

cc rxin chenghao-intel

Author: Davies Liu <davies@databricks.com>

Closes apache#6755 from davies/fix_codegen and squashes the following commits:

ef27343 [Davies Liu] address comments
6617ea6 [Davies Liu] fix scala tyle
70b7dda [Davies Liu] improve codegen
  • Loading branch information
Davies Liu authored and rxin committed Jun 11, 2015
1 parent 424b007 commit 1191c3e
Show file tree
Hide file tree
Showing 14 changed files with 141 additions and 129 deletions.
21 changes: 21 additions & 0 deletions sql/catalyst/src/main/scala/org/apache/spark/sql/BaseRow.java
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,27 @@ public int fieldIndex(String name) {
throw new UnsupportedOperationException();
}

/**
* A generic version of Row.equals(Row), which is used for tests.
*/
@Override
public boolean equals(Object other) {
if (other instanceof Row) {
Row row = (Row) other;
int n = size();
if (n != row.size()) {
return false;
}
for (int i = 0; i < n; i ++) {
if (isNullAt(i) != row.isNullAt(i) || (!isNullAt(i) && !get(i).equals(row.get(i)))) {
return false;
}
}
return true;
}
return false;
}

@Override
public Row copy() {
final int n = size();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
case ByteType =>
buildCast[Byte](_, _ != 0)
case DecimalType() =>
buildCast[Decimal](_, _ != 0)
buildCast[Decimal](_, _ != Decimal(0))
case DoubleType =>
buildCast[Double](_, _ != 0)
case FloatType =>
Expand Down Expand Up @@ -454,7 +454,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
case (BooleanType, dt: NumericType) =>
defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c ? 1 : 0)")
case (dt: DecimalType, BooleanType) =>
defineCodeGen(ctx, ev, c => s"$c.isZero()")
defineCodeGen(ctx, ev, c => s"!$c.isZero()")
case (dt: NumericType, BooleanType) =>
defineCodeGen(ctx, ev, c => s"$c != 0")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,15 +161,23 @@ class CodeGenContext {
}

/**
* Returns a function to generate equal expression in Java
* Generate code for equal expression in Java
*/
def equalFunc(dataType: DataType): ((String, String) => String) = dataType match {
case BinaryType => { case (eval1, eval2) =>
s"java.util.Arrays.equals($eval1, $eval2)" }
case IntegerType | BooleanType | LongType | DoubleType | FloatType | ShortType | ByteType =>
{ case (eval1, eval2) => s"$eval1 == $eval2" }
case other =>
{ case (eval1, eval2) => s"$eval1.equals($eval2)" }
def genEqual(dataType: DataType, c1: String, c2: String): String = dataType match {
case BinaryType => s"java.util.Arrays.equals($c1, $c2)"
case dt: DataType if isPrimitiveType(dt) => s"$c1 == $c2"
case other => s"$c1.equals($c2)"
}

/**
* Generate code for compare expression in Java
*/
def genComp(dataType: DataType, c1: String, c2: String): String = dataType match {
// Use signum() to keep any small difference bwteen float/double
case FloatType | DoubleType => s"(int)java.lang.Math.signum($c1 - $c2)"
case dt: DataType if isPrimitiveType(dt) => s"(int)($c1 - $c2)"
case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)"
case other => s"$c1.compare($c2)"
}

/**
Expand All @@ -182,6 +190,16 @@ class CodeGenContext {
* Returns true if the data type has a special accessor and setter in [[Row]].
*/
def isNativeType(dt: DataType): Boolean = nativeTypes.contains(dt)

/**
* List of data types who's Java type is primitive type
*/
val primitiveTypes = nativeTypes ++ Seq(DateType, TimestampType)

/**
* Returns true if the Java type is primitive type
*/
def isPrimitiveType(dt: DataType): Boolean = primitiveTypes.contains(dt)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
}
"""


logDebug(s"code for ${expressions.mkString(",")}:\n$code")

val c = compile(code)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import org.apache.spark.Logging
import org.apache.spark.annotation.Private
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.{BinaryType, NumericType}

/**
* Inherits some default implementation for Java from `Ordering[Row]`
Expand Down Expand Up @@ -55,39 +54,6 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit
val evalA = order.child.gen(ctx)
val evalB = order.child.gen(ctx)
val asc = order.direction == Ascending
val compare = order.child.dataType match {
case BinaryType =>
s"""
{
byte[] x = ${if (asc) evalA.primitive else evalB.primitive};
byte[] y = ${if (!asc) evalB.primitive else evalA.primitive};
int j = 0;
while (j < x.length && j < y.length) {
if (x[j] != y[j]) return x[j] - y[j];
j = j + 1;
}
int d = x.length - y.length;
if (d != 0) {
return d;
}
}"""
case _: NumericType =>
s"""
if (${evalA.primitive} != ${evalB.primitive}) {
if (${evalA.primitive} > ${evalB.primitive}) {
return ${if (asc) "1" else "-1"};
} else {
return ${if (asc) "-1" else "1"};
}
}"""
case _ =>
s"""
int comp = ${evalA.primitive}.compare(${evalB.primitive});
if (comp != 0) {
return ${if (asc) "comp" else "-comp"};
}"""
}

s"""
i = $a;
${evalA.code}
Expand All @@ -100,7 +66,10 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit
} else if (${evalB.isNull}) {
return ${if (order.direction == Ascending) "1" else "-1"};
} else {
$compare
int comp = ${ctx.genComp(order.child.dataType, evalA.primitive, evalB.primitive)};
if (comp != 0) {
return ${if (asc) "comp" else "-comp"};
}
}
"""
}.mkString("\n")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,14 +72,12 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
}.mkString("\n ")

val specificAccessorFunctions = ctx.nativeTypes.map { dataType =>
val cases = expressions.zipWithIndex.map {
case (e, i) if e.dataType == dataType
|| dataType == IntegerType && e.dataType == DateType
|| dataType == LongType && e.dataType == TimestampType =>
s"case $i: return c$i;"
case _ => ""
val cases = expressions.zipWithIndex.flatMap {
case (e, i) if ctx.javaType(e.dataType) == ctx.javaType(dataType) =>
List(s"case $i: return c$i;")
case _ => Nil
}.mkString("\n ")
if (cases.count(_ != '\n') > 0) {
if (cases.length > 0) {
s"""
@Override
public ${ctx.javaType(dataType)} ${ctx.accessorForType(dataType)}(int i) {
Expand All @@ -89,29 +87,30 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
switch (i) {
$cases
}
return ${ctx.defaultValue(dataType)};
throw new IllegalArgumentException("Invalid index: " + i
+ " in ${ctx.accessorForType(dataType)}");
}"""
} else {
""
}
}.mkString("\n")

val specificMutatorFunctions = ctx.nativeTypes.map { dataType =>
val cases = expressions.zipWithIndex.map {
case (e, i) if e.dataType == dataType
|| dataType == IntegerType && e.dataType == DateType
|| dataType == LongType && e.dataType == TimestampType =>
s"case $i: { c$i = value; return; }"
case _ => ""
}.mkString("\n")
if (cases.count(_ != '\n') > 0) {
val cases = expressions.zipWithIndex.flatMap {
case (e, i) if ctx.javaType(e.dataType) == ctx.javaType(dataType) =>
List(s"case $i: { c$i = value; return; }")
case _ => Nil
}.mkString("\n ")
if (cases.length > 0) {
s"""
@Override
public void ${ctx.mutatorForType(dataType)}(int i, ${ctx.javaType(dataType)} value) {
nullBits[i] = false;
switch (i) {
$cases
}
throw new IllegalArgumentException("Invalid index: " + i +
" in ${ctx.mutatorForType(dataType)}");
}"""
} else {
""
Expand Down Expand Up @@ -139,9 +138,10 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {

val columnChecks = expressions.zipWithIndex.map { case (e, i) =>
s"""
if (isNullAt($i) != row.isNullAt($i) || !isNullAt($i) && !get($i).equals(row.get($i))) {
return false;
}
if (nullBits[$i] != row.nullBits[$i] ||
(!nullBits[$i] && !(${ctx.genEqual(e.dataType, s"c$i", s"row.c$i")}))) {
return false;
}
"""
}.mkString("\n")

Expand Down Expand Up @@ -174,7 +174,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
}

public int size() { return ${expressions.length};}
private boolean[] nullBits = new boolean[${expressions.length}];
protected boolean[] nullBits = new boolean[${expressions.length}];
public void setNullAt(int i) { nullBits[i] = true; }
public boolean isNullAt(int i) { return nullBits[i]; }

Expand Down Expand Up @@ -207,9 +207,8 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {

@Override
public boolean equals(Object other) {
if (other instanceof Row) {
Row row = (Row) other;
if (row.length() != size()) return false;
if (other instanceof SpecificRow) {
SpecificRow row = (SpecificRow) other;
$columnChecks
return true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW
${cond.code}
if (${keyEval.isNull} && ${cond.isNull} ||
!${keyEval.isNull} && !${cond.isNull}
&& ${ctx.equalFunc(key.dataType)(keyEval.primitive, cond.primitive)}) {
&& ${ctx.genEqual(key.dataType, keyEval.primitive, cond.primitive)}) {
$got = true;
${res.code}
${ev.isNull} = ${res.isNull};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,7 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres
// change the isNull and primitive to consts, to inline them
if (value == null) {
ev.isNull = "true"
ev.primitive = ctx.defaultValue(dataType)
""
s"final ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};"
} else {
dataType match {
case BooleanType =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -250,16 +250,11 @@ abstract class BinaryComparison extends BinaryExpression with Predicate {
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
left.dataType match {
case dt: NumericType if ctx.isNativeType(dt) => defineCodeGen (ctx, ev, {
(c1, c3) => s"$c1 $symbol $c3"
})
case DateType | TimestampType => defineCodeGen (ctx, ev, {
(c1, c3) => s"$c1 $symbol $c3"
})
case other => defineCodeGen (ctx, ev, {
(c1, c2) => s"$c1.compare($c2) $symbol 0"
})
if (ctx.isPrimitiveType(left.dataType)) {
// faster version
defineCodeGen(ctx, ev, (c1, c2) => s"$c1 $symbol $c2")
} else {
defineCodeGen(ctx, ev, (c1, c2) => s"${ctx.genComp(left.dataType, c1, c2)} $symbol 0")
}
}

Expand All @@ -280,8 +275,9 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison
if (left.dataType != BinaryType) l == r
else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]])
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
defineCodeGen(ctx, ev, ctx.equalFunc(left.dataType))
defineCodeGen(ctx, ev, (c1, c2) => ctx.genEqual(left.dataType, c1, c2))
}
}

Expand All @@ -307,7 +303,7 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val eval1 = left.gen(ctx)
val eval2 = right.gen(ctx)
val equalCode = ctx.equalFunc(left.dataType)(eval1.primitive, eval2.primitive)
val equalCode = ctx.genEqual(left.dataType, eval1.primitive, eval2.primitive)
ev.isNull = "false"
eval1.code + eval2.code + s"""
boolean ${ev.primitive} = (${eval1.isNull} && ${eval2.isNull}) ||
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,12 @@ object TypeUtils {

def getOrdering(t: DataType): Ordering[Any] =
t.asInstanceOf[AtomicType].ordering.asInstanceOf[Ordering[Any]]

def compareBinary(x: Array[Byte], y: Array[Byte]): Int = {
for (i <- 0 until x.length; if i < y.length) {
val res = x(i).compareTo(y(i))
if (res != 0) return res
}
x.length - y.length
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import scala.reflect.runtime.universe.typeTag

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.catalyst.ScalaReflectionLock
import org.apache.spark.sql.catalyst.util.TypeUtils


/**
Expand All @@ -43,11 +44,7 @@ class BinaryType private() extends AtomicType {

private[sql] val ordering = new Ordering[InternalType] {
def compare(x: Array[Byte], y: Array[Byte]): Int = {
for (i <- 0 until x.length; if i < y.length) {
val res = x(i).compareTo(y(i))
if (res != 0) return res
}
x.length - y.length
TypeUtils.compareBinary(x, y)
}
}

Expand Down
Loading

0 comments on commit 1191c3e

Please sign in to comment.