[SPARK-40175][CORE][SQL][MLLIB][DSTREAM][R] Optimize the performance of `keys.zip(values).toMap` code pattern

### What changes were proposed in this pull request?
This PR introduces two new `toMap` methods in `o.a.spark.util.collection.Utils` that build the result with a manual `while` loop, and uses them to optimize the performance of the `keys.zip(values).toMap` code pattern across Spark.
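
As an illustration (not part of the patch; the variable names below are made up, but the helper signatures match the new `Utils` methods in the diff):

```scala
import org.apache.spark.util.collection.Utils

val keys = Seq("a", "b", "c")
val values = Seq(1, 2, 3)

// Old pattern: zip first materializes a collection of (key, value) tuples,
// which toMap then walks again to build the map.
val viaZip: Map[String, Int] = keys.zip(values).toMap

// New pattern: a single pass over both iterators feeds the map builder directly.
val viaUtils: Map[String, Int] = Utils.toMap(keys, values)

// Java-facing call sites (e.g. SerDe) use the java.util.Map variant instead of .asJava.
val viaUtilsJava: java.util.Map[String, Int] = Utils.toJavaMap(keys, values)
```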

### Why are the changes needed?
Performance improvement: the manual loop builds the map in a single pass and avoids materializing the intermediate collection of tuples that `zip` creates before `toMap` runs.
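
A rough way to see the effect (not the benchmark used for this patch; the input size and timing harness below are illustrative only):

```scala
import org.apache.spark.util.collection.Utils

def time[T](label: String)(body: => T): T = {
  val start = System.nanoTime()
  val result = body
  println(f"$label: ${(System.nanoTime() - start) / 1e6}%.1f ms")
  result
}

val keys = (0 until 1000000).map(i => s"k$i")
val values = 0 until 1000000

// zip + toMap materializes an intermediate collection of a million pairs before building the map.
time("zip + toMap")(keys.zip(values).toMap)
// The new helper walks both iterators once and appends straight into the map builder.
time("Utils.toMap")(Utils.toMap(keys, values))
```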

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Pass GitHub Actions

Closes apache#37876 from LuciferYang/SPARK-40175.

Authored-by: yangjie01 <yangjie01@baidu.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
LuciferYang authored and cloud-fan committed Sep 16, 2022
1 parent 128479f commit 8b6b3be
Showing 17 changed files with 67 additions and 21 deletions.
5 changes: 3 additions & 2 deletions core/src/main/scala/org/apache/spark/api/r/SerDe.scala
@@ -21,9 +21,10 @@ import java.io.{DataInputStream, DataOutputStream}
import java.nio.charset.StandardCharsets
import java.sql.{Date, Time, Timestamp}

import scala.collection.JavaConverters._
import scala.collection.mutable

import org.apache.spark.util.collection.Utils

/**
* Utility functions to serialize, deserialize objects to / from R
*/
@@ -236,7 +237,7 @@ private[spark] object SerDe {
val keys = readArray(in, jvmObjectTracker).asInstanceOf[Array[Object]]
val values = readList(in, jvmObjectTracker)

keys.zip(values).toMap.asJava
Utils.toJavaMap(keys, values)
} else {
new java.util.HashMap[Object, Object]()
}
4 changes: 3 additions & 1 deletion core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -74,6 +74,7 @@ import org.apache.spark.launcher.SparkLauncher
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.serializer.{DeserializationStream, SerializationStream, Serializer, SerializerInstance}
import org.apache.spark.status.api.v1.{StackTrace, ThreadStackTrace}
import org.apache.spark.util.collection.{Utils => CUtils}
import org.apache.spark.util.io.ChunkedByteBufferOutputStream

/** CallSite represents a place in user code. It can have a short and a long form. */
@@ -1718,7 +1719,8 @@ private[spark] object Utils extends Logging {
assert(files.length == fileLengths.length)
val startIndex = math.max(start, 0)
val endIndex = math.min(end, fileLengths.sum)
val fileToLength = files.zip(fileLengths).toMap
val fileToLength = CUtils.toMap(files, fileLengths)

logDebug("Log files: \n" + fileToLength.mkString("\n"))

val stringBuffer = new StringBuffer((endIndex - startIndex).toInt)
29 changes: 29 additions & 0 deletions core/src/main/scala/org/apache/spark/util/collection/Utils.scala
@@ -17,7 +17,10 @@

package org.apache.spark.util.collection

import java.util.Collections

import scala.collection.JavaConverters._
import scala.collection.immutable

import com.google.common.collect.{Iterators => GuavaIterators, Ordering => GuavaOrdering}

@@ -62,4 +65,30 @@ private[spark] object Utils {
*/
def sequenceToOption[T](input: Seq[Option[T]]): Option[Seq[T]] =
if (input.forall(_.isDefined)) Some(input.flatten) else None

/**
* Same function as `keys.zip(values).toMap`, but has perf gain.
*/
def toMap[K, V](keys: Iterable[K], values: Iterable[V]): Map[K, V] = {
val builder = immutable.Map.newBuilder[K, V]
val keyIter = keys.iterator
val valueIter = values.iterator
while (keyIter.hasNext && valueIter.hasNext) {
builder += (keyIter.next(), valueIter.next()).asInstanceOf[(K, V)]
}
builder.result()
}

/**
* Same function as `keys.zip(values).toMap.asJava`, but has perf gain.
*/
def toJavaMap[K, V](keys: Iterable[K], values: Iterable[V]): java.util.Map[K, V] = {
val map = new java.util.HashMap[K, V]()
val keyIter = keys.iterator
val valueIter = values.iterator
while (keyIter.hasNext && valueIter.hasNext) {
map.put(keyIter.next(), valueIter.next())
}
Collections.unmodifiableMap(map)
}
}
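
A brief usage sketch of the two helpers above (values are illustrative). Like `zip`, both methods stop at the shorter of the two inputs, and `toJavaMap` returns an unmodifiable view:

```scala
import org.apache.spark.util.collection.Utils

val m = Utils.toMap(Seq("a", "b", "c"), Seq(1, 2))       // Map(a -> 1, b -> 2): truncates like zip
val jm = Utils.toJavaMap(Seq("x", "y"), Seq(10, 20, 30)) // {x=10, y=20}
// jm.put("z", 40) would throw UnsupportedOperationException (Collections.unmodifiableMap).
```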
@@ -26,6 +26,7 @@ import org.apache.spark.annotation.Since
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.util.collection.Utils

/**
* Evaluator for ranking algorithms.
@@ -155,7 +156,7 @@ class RankingMetrics[T: ClassTag] @Since("1.2.0") (predictionAndLabels: RDD[_ <:
rdd.map { case (pred, lab, rel) =>
val useBinary = rel.isEmpty
val labSet = lab.toSet
val relMap = lab.zip(rel).toMap
val relMap = Utils.toMap(lab, rel)
if (useBinary && lab.size != rel.size) {
logWarning(
"# of ground truth set and # of relevance value set should be equal, " +
@@ -35,6 +35,7 @@ import org.apache.spark.sql.types._
import org.apache.spark.sql.types.DayTimeIntervalType._
import org.apache.spark.sql.types.YearMonthIntervalType._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.collection.Utils

/**
* Functions to convert Scala types to Catalyst types and vice versa.
@@ -229,7 +230,7 @@ object CatalystTypeConverters {
val convertedValues =
if (isPrimitive(valueType)) values else values.map(valueConverter.toScala)

convertedKeys.zip(convertedValues).toMap
Utils.toMap(convertedKeys, convertedValues)
}
}

@@ -58,6 +58,7 @@ import org.apache.spark.sql.types.DayTimeIntervalType.DAY
import org.apache.spark.sql.util.{CaseInsensitiveStringMap, SchemaUtils}
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
import org.apache.spark.util.Utils
import org.apache.spark.util.collection.{Utils => CUtils}

/**
* A trivial [[Analyzer]] with a dummy [[SessionCatalog]], [[EmptyFunctionRegistry]] and
@@ -3457,7 +3458,7 @@ class Analyzer(override val catalogManager: CatalogManager)
throw QueryCompilationErrors.writeTableWithMismatchedColumnsError(
cols.size, query.output.size, query)
}
val nameToQueryExpr = cols.zip(query.output).toMap
val nameToQueryExpr = CUtils.toMap(cols, query.output)
// Static partition columns in the table output should not appear in the column list
// they will be handled in another rule ResolveInsertInto
val reordered = tableOutput.flatMap { nameToQueryExpr.get(_).orElse(None) }
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.trees.TreePattern.OUTER_REFERENCE
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.util.collection.Utils

/**
* Decorrelate the inner query by eliminating outer references and create domain joins.
@@ -346,7 +347,7 @@ object DecorrelateInnerQuery extends PredicateHelper {
val domains = attributes.map(_.newInstance())
// A placeholder to be rewritten into domain join.
val domainJoin = DomainJoin(domains, plan)
val outerReferenceMap = attributes.zip(domains).toMap
val outerReferenceMap = Utils.toMap(attributes, domains)
// Build join conditions between domain attributes and outer references.
// EqualNullSafe is used to make sure null key can be joined together. Note
// outer referenced attributes can be changed during the outer query optimization.
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.CTE
import org.apache.spark.util.collection.Utils

/**
* Infer predicates and column pruning for [[CTERelationDef]] from its reference points, and push
@@ -71,7 +72,7 @@ object PushdownPredicatesAndPruneColumnsForCTEDef extends Rule[LogicalPlan] {

case PhysicalOperation(projects, predicates, ref: CTERelationRef) =>
val (cteDef, precedence, preds, attrs) = cteMap(ref.cteId)
val attrMapping = ref.output.zip(cteDef.output).map{ case (r, d) => r -> d }.toMap
val attrMapping = Utils.toMap(ref.output, cteDef.output)
val newPredicates = if (isTruePredicate(preds)) {
preds
} else {
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LogicalPl
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.AGGREGATE
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.util.collection.Utils

/**
* This rule rewrites an aggregate query with distinct aggregations into an expanded double
@@ -265,7 +266,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {

// Setup expand & aggregate operators for distinct aggregate expressions.
val distinctAggChildAttrLookup = distinctAggChildAttrMap.toMap
val distinctAggFilterAttrLookup = distinctAggFilters.zip(maxConds.map(_.toAttribute)).toMap
val distinctAggFilterAttrLookup = Utils.toMap(distinctAggFilters, maxConds.map(_.toAttribute))
val distinctAggOperatorMap = distinctAggGroups.toSeq.zipWithIndex.map {
case ((group, expressions), i) =>
val id = Literal(i + 1)
@@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.util

import java.util.{Map => JavaMap}

import org.apache.spark.util.collection.Utils

/**
* A simple `MapData` implementation which is backed by 2 arrays.
*
@@ -129,20 +131,19 @@ object ArrayBasedMapData {
def toScalaMap(map: ArrayBasedMapData): Map[Any, Any] = {
val keys = map.keyArray.asInstanceOf[GenericArrayData].array
val values = map.valueArray.asInstanceOf[GenericArrayData].array
keys.zip(values).toMap
Utils.toMap(keys, values)
}

def toScalaMap(keys: Array[Any], values: Array[Any]): Map[Any, Any] = {
keys.zip(values).toMap
Utils.toMap(keys, values)
}

def toScalaMap(keys: scala.collection.Seq[Any],
values: scala.collection.Seq[Any]): Map[Any, Any] = {
keys.zip(values).toMap
Utils.toMap(keys, values)
}

def toJavaMap(keys: Array[Any], values: Array[Any]): java.util.Map[Any, Any] = {
import scala.collection.JavaConverters._
keys.zip(values).toMap.asJava
Utils.toJavaMap(keys, values)
}
}
@@ -33,6 +33,7 @@ import org.apache.spark.sql.types._
import org.apache.spark.sql.types.DayTimeIntervalType._
import org.apache.spark.sql.types.YearMonthIntervalType.YEAR
import org.apache.spark.unsafe.types.CalendarInterval
import org.apache.spark.util.collection.Utils
/**
* Random data generators for Spark SQL DataTypes. These generators do not generate uniformly random
* values; instead, they're biased to return "interesting" values (such as maximum / minimum values)
@@ -340,7 +341,7 @@ object RandomDataGenerator {
count += 1
}
val values = Seq.fill(keys.size)(valueGenerator())
keys.zip(values).toMap
Utils.toMap(keys, values)
}
}
case StructType(fields) =>
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection, UnknownPartitioning}
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.util.collection.Utils

object ExternalRDD {

@@ -106,7 +107,7 @@ case class LogicalRDD(
session :: originStats :: originConstraints :: Nil

override def newInstance(): LogicalRDD.this.type = {
val rewrite = output.zip(output.map(_.newInstance())).toMap
val rewrite = Utils.toMap(output, output.map(_.newInstance()))

val rewrittenPartitioning = outputPartitioning match {
case p: Expression =>
@@ -25,6 +25,7 @@ import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.Utils
import org.apache.spark.util.collection.{Utils => CUtils}

/**
* Utility functions used by the query planner to convert our plan to new aggregation code path.
@@ -218,7 +219,7 @@ object AggUtils {
}

// 3. Create an Aggregate operator for partial aggregation (for distinct)
val distinctColumnAttributeLookup = distinctExpressions.zip(distinctAttributes).toMap
val distinctColumnAttributeLookup = CUtils.toMap(distinctExpressions, distinctAttributes)
val rewrittenDistinctFunctions = functionsWithDistinct.map {
// Children of an AggregateFunction with DISTINCT keyword has already
// been evaluated. At here, we need to replace original children
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.{And, EqualTo, Literal}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.util.PartitioningUtils
import org.apache.spark.util.collection.Utils

/**
* Analyzes a given set of partitions to generate per-partition statistics, which will be used in
@@ -147,7 +148,7 @@ case class AnalyzePartitionCommand(
r.get(i).toString
}
}
val spec = tableMeta.partitionColumnNames.zip(partitionColumnValues).toMap
val spec = Utils.toMap(tableMeta.partitionColumnNames, partitionColumnValues)
val count = BigInt(r.getLong(partitionColumns.size))
(spec, count)
}.toMap
@@ -28,6 +28,7 @@ import org.apache.spark.sql.execution.datasources.DataSourceStrategy
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.collection.Utils

object PushDownUtils {
/**
@@ -203,7 +204,7 @@ object PushDownUtils {
def toOutputAttrs(
schema: StructType,
relation: DataSourceV2Relation): Seq[AttributeReference] = {
val nameToAttr = relation.output.map(_.name).zip(relation.output).toMap
val nameToAttr = Utils.toMap(relation.output.map(_.name), relation.output)
val cleaned = CharVarcharUtils.replaceCharVarcharWithStringInSchema(schema)
cleaned.toAttributes.map {
// we have to keep the attribute id during transformation
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
import org.apache.spark.sql.execution.UnaryExecNode
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.Utils

/**
* Holds common logic for window operators
@@ -69,7 +70,7 @@ trait WindowExecBase extends UnaryExecNode {
// Results of window expressions will be on the right side of child's output
BoundReference(child.output.size + i, e.dataType, e.nullable)
}
val unboundToRefMap = expressions.zip(references).toMap
val unboundToRefMap = Utils.toMap(expressions, references)
val patchedWindowExpression = windowExpression.map(_.transform(unboundToRefMap))
UnsafeProjection.create(
child.output ++ patchedWindowExpression,
@@ -22,6 +22,7 @@ import scala.collection.mutable

import org.apache.spark.scheduler.{ExecutorCacheTaskLocation, TaskLocation}
import org.apache.spark.streaming.receiver.Receiver
import org.apache.spark.util.collection.Utils

/**
* A class that tries to schedule receivers with evenly distributed. There are two phases for
@@ -135,7 +136,7 @@ private[streaming] class ReceiverSchedulingPolicy {
leastScheduledExecutors += executor
}

receivers.map(_.streamId).zip(scheduledLocations.map(_.toSeq)).toMap
Utils.toMap(receivers.map(_.streamId), scheduledLocations.map(_.toSeq))
}

/**
