[SPARK-45071][SQL] Optimize the processing speed of `BinaryArithmetic#dataType` when processing multi-column data

### What changes were proposed in this pull request?

Since `BinaryArithmetic#dataType` recursively recomputes the data type of every node in the expression tree on each call, the driver becomes very slow when an expression combines many columns.

For example, the following code:
```scala
import spark.implicits._
import scala.util.Random
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.{expr, sum}
import org.apache.spark.sql.types.{StructType, StructField, IntegerType}

val N = 30
val M = 100

val columns = Seq.fill(N)(Random.alphanumeric.take(8).mkString)
val data = Seq.fill(M)(Seq.fill(N)(Random.nextInt(16) - 5))

val schema = StructType(columns.map(StructField(_, IntegerType)))
val rdd = spark.sparkContext.parallelize(data.map(Row.fromSeq(_)))
val df = spark.createDataFrame(rdd, schema)
val colExprs = columns.map(sum(_))

// Generate a new column that sums the other 30 columns.
df.withColumn("new_col_sum", expr(columns.mkString(" + ")))
```

On Spark 3.4 the driver takes a few minutes to execute this code, whereas on Spark 3.2 it takes only a few seconds. Related issue: [SPARK-39316](apache#36698)
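
One way to observe the slowdown from `spark-shell` (a sketch: DataFrames are analyzed eagerly, so `withColumn` alone exercises `dataType`, and `spark.time` prints the wall-clock time of the enclosed block):

```scala
// Reuses `df` and `columns` from the snippet above; only analysis is
// timed here, no Spark job runs.
spark.time {
  df.withColumn("new_col_sum", expr(columns.mkString(" + ")))
}
```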

### Why are the changes needed?

Optimize the processing speed of `BinaryArithmetic#dataType` when processing multi-column data: the result type is computed once and cached in a `lazy val` instead of being recomputed on every call to `dataType`.
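
A minimal sketch (plain Scala, not the Spark sources; `Expr`, `Leaf`, and `Add` are illustrative names) of the pattern the patch applies: memoize a recursively computed attribute in a `lazy val` so each node pays the recursion cost at most once, however many times the analyzer asks for it.

```scala
sealed trait Expr { def dataType: String }

case class Leaf(tpe: String) extends Expr {
  def dataType: String = tpe
}

case class Add(left: Expr, right: Expr) extends Expr {
  // Without the cache, every call to dataType walks the whole subtree
  // again; with it, the recursion runs at most once per node.
  private lazy val cachedType: String =
    if (left.dataType == right.dataType) left.dataType else "widened"
  def dataType: String = cachedType
}

// A left-deep chain like ((c1 + c2) + c3) + ... built from 30 leaves,
// mirroring the shape of expr("c1 + c2 + ... + c30").
val chain = (1 to 30).map(_ => Leaf("int"): Expr).reduceLeft(Add(_, _))
println(chain.dataType) // "int", each node's type computed exactly once
```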

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Manual testing.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#42804 from zzzzming95/SPARK-45071.

Authored-by: zzzzming95 <505306252@qq.com>
Signed-off-by: Yuming Wang <yumwang@ebay.com>
zzzzming95 authored and wangyum committed Sep 6, 2023
1 parent ba35140 commit 16e813c
Showing 1 changed file with 7 additions and 5 deletions.
```diff
@@ -219,6 +219,12 @@ abstract class BinaryArithmetic extends BinaryOperator
 
   protected val evalMode: EvalMode.Value
 
+  private lazy val internalDataType: DataType = (left.dataType, right.dataType) match {
+    case (DecimalType.Fixed(p1, s1), DecimalType.Fixed(p2, s2)) =>
+      resultDecimalType(p1, s1, p2, s2)
+    case _ => left.dataType
+  }
+
   protected def failOnError: Boolean = evalMode match {
     // The TRY mode executes as if it would fail on errors, except that it would capture the errors
     // and return null results.
@@ -234,11 +240,7 @@ abstract class BinaryArithmetic extends BinaryOperator
     case _ => super.checkInputDataTypes()
   }
 
-  override def dataType: DataType = (left.dataType, right.dataType) match {
-    case (DecimalType.Fixed(p1, s1), DecimalType.Fixed(p2, s2)) =>
-      resultDecimalType(p1, s1, p2, s2)
-    case _ => left.dataType
-  }
+  override def dataType: DataType = internalDataType
 
   // When `spark.sql.decimalOperations.allowPrecisionLoss` is set to true, if the precision / scale
   // needed are out of the range of available values, the scale is reduced up to 6, in order to
```
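
Because Catalyst expression trees are immutable (transformation rules build new nodes rather than mutating existing ones), a node's children, and therefore its result type, can never change after construction, so the `lazy val` cache cannot go stale.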
