Skip to content

Commit

Permalink
[CELEBORN-1271] Fix unregisterShuffle with celeborn.client.spark.fetc…
Browse files Browse the repository at this point in the history
…h.throwsFetchFailure disabled

### What changes were proposed in this pull request?
Per https://issues.apache.org/jira/browse/CELEBORN-1271,
fix the bug in SparkShuffleManager.unregisterShuffle that occurs when celeborn.client.spark.fetch.throwsFetchFailure=false.

### Why are the changes needed?
The bug prevents shuffle data from being cleaned up by unregisterShuffle.

### Does this PR introduce _any_ user-facing change?
no

### How was this patch tested?
Manually tested.

Closes apache#2305 from ErikFang/CELEBORN-1271-fix-unregisterShuffle.

Authored-by: Erik.fang <fmerik@gmail.com>
Signed-off-by: waitinfuture <zky.zhoukeyong@alibaba-inc.com>
  • Loading branch information
ErikFang authored and waitinfuture committed Feb 29, 2024
1 parent d5a1bcd commit 6d9fbf5
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,8 @@ public boolean unregisterShuffle(int appShuffleId) {
}
// For Spark driver side trigger unregister shuffle.
if (lifecycleManager != null) {
lifecycleManager.unregisterAppShuffle(appShuffleId);
lifecycleManager.unregisterAppShuffle(
appShuffleId, celebornConf.clientFetchThrowsFetchFailure());
}
// For Spark executor side cleanup shuffle related info.
if (shuffleClient != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,8 @@ public boolean unregisterShuffle(int appShuffleId) {
}
// For Spark driver side trigger unregister shuffle.
if (lifecycleManager != null) {
lifecycleManager.unregisterAppShuffle(appShuffleId);
lifecycleManager.unregisterAppShuffle(
appShuffleId, celebornConf.clientFetchThrowsFetchFailure());
}
// For Spark executor side cleanup shuffle related info.
if (shuffleClient != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,10 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
def workerSnapshots(shuffleId: Int): util.Map[WorkerInfo, ShufflePartitionLocationInfo] =
shuffleAllocatedWorkers.get(shuffleId)

/**
 * Test-only accessor that hands back the `unregisterShuffleTime` map itself (not a copy).
 *
 * NOTE(review): keys are presumably shuffle ids and values unregister timestamps — confirm
 * against the field's declaration.
 */
@VisibleForTesting
def getUnregisterShuffleTime(): ConcurrentHashMap[Int, Long] = unregisterShuffleTime

val newMapFunc: function.Function[Int, ConcurrentHashMap[Int, PartitionLocation]] =
new util.function.Function[Int, ConcurrentHashMap[Int, PartitionLocation]]() {
override def apply(s: Int): ConcurrentHashMap[Int, PartitionLocation] = {
Expand Down Expand Up @@ -969,16 +973,20 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
logInfo(s"Unregister for $shuffleId success.")
}

// NOTE(review): this span is a rendered diff hunk without +/- markers. It appears to
// interleave the removed one-argument implementation with the added two-argument one,
// so it is not compilable as shown — confirm against the actual file.
//
// Unregisters everything associated with an application-level shuffle id:
// logs the start, drops the entry from appShuffleDeterminateMap, then calls
// unregisterShuffle for the relevant shuffle id(s).
def unregisterAppShuffle(appShuffleId: Int): Unit = {
// Two-argument form: hasMapping selects between the shuffleIdMapping lookup and
// unregistering appShuffleId directly (see the if/else below).
def unregisterAppShuffle(appShuffleId: Int, hasMapping: Boolean): Unit = {
logInfo(s"Unregister appShuffleId $appShuffleId starts...")
appShuffleDeterminateMap.remove(appShuffleId)
// Unconditional mapping lookup (presumably the pre-change body — TODO confirm).
val shuffleIds = shuffleIdMapping.remove(appShuffleId)
if (shuffleIds != null) {
shuffleIds.synchronized(
shuffleIds.values.map {
case (shuffleId, _) =>
unregisterShuffle(shuffleId)
})
// With a mapping: remove the appShuffleId entry and unregister every shuffle id it held.
if (hasMapping) {
val shuffleIds = shuffleIdMapping.remove(appShuffleId)
// A null result means no mapping was recorded for this appShuffleId; nothing to unregister.
if (shuffleIds != null) {
// Iterate while holding the monitor of the removed mapping object.
shuffleIds.synchronized(
shuffleIds.values.map {
case (shuffleId, _) =>
unregisterShuffle(shuffleId)
})
}
} else {
// Without a mapping: the appShuffleId is passed straight to unregisterShuffle.
unregisterShuffle(appShuffleId)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ package org.apache.celeborn.tests.spark
import java.io.File
import java.util.concurrent.atomic.AtomicBoolean

import org.apache.spark.{SparkConf, TaskContext}
import org.apache.spark.{SparkConf, SparkContextHelper, TaskContext}
import org.apache.spark.shuffle.ShuffleHandle
import org.apache.spark.shuffle.celeborn.{CelebornShuffleHandle, ShuffleManagerHook, SparkUtils, TestCelebornShuffleManager}
import org.apache.spark.shuffle.celeborn.{CelebornShuffleHandle, ShuffleManagerHook, SparkShuffleManager, SparkUtils, TestCelebornShuffleManager}
import org.apache.spark.sql.SparkSession
import org.scalatest.BeforeAndAfterEach
import org.scalatest.funsuite.AnyFunSuite
Expand Down Expand Up @@ -123,6 +123,46 @@ class CelebornFetchFailureSuite extends AnyFunSuite
assert(elem._2.mkString(",").equals(value))
}

val shuffleMgr = SparkContextHelper.env
.shuffleManager
.asInstanceOf[TestCelebornShuffleManager]
val lifecycleManager = shuffleMgr.getLifecycleManager

shuffleMgr.unregisterShuffle(0)
assert(lifecycleManager.getUnregisterShuffleTime().containsKey(0))
assert(lifecycleManager.getUnregisterShuffleTime().containsKey(1))

sparkSession.stop()
}

// Runs a groupByKey job with spark.celeborn.client.spark.fetch.throwsFetchFailure set to
// "false", then calls unregisterShuffle(0) on the shuffle manager and checks that the
// LifecycleManager's unregister-time map contains key 0.
test("celeborn spark integration test - unregister shuffle with throwsFetchFailure disabled") {
  val sparkConf = new SparkConf().setAppName("rss-demo").setMaster("local[2,3]")
  val sparkSession = SparkSession.builder()
    .config(updateSparkConf(sparkConf, ShuffleMode.HASH))
    .config("spark.sql.shuffle.partitions", 2)
    .config("spark.celeborn.shuffle.forceFallback.partition.enabled", false)
    .config("spark.celeborn.shuffle.enabled", "true")
    .config("spark.celeborn.client.spark.fetch.throwsFetchFailure", "false")
    .getOrCreate()

  // Stop the session in a finally block: previously a failing assertion skipped
  // sparkSession.stop() and left the session running after the test.
  try {
    val value = Range(1, 10000).mkString(",")
    val tuples = sparkSession.sparkContext.parallelize(1 to 10000, 2)
      .map { i => (i, value) }.groupByKey(16).collect()

    // verify result
    assert(tuples.length == 10000)
    for (elem <- tuples) {
      assert(elem._2.mkString(",").equals(value))
    }

    val shuffleMgr = SparkContextHelper.env
      .shuffleManager
      .asInstanceOf[SparkShuffleManager]
    val lifecycleManager = shuffleMgr.getLifecycleManager

    // With throwsFetchFailure disabled, unregistering shuffle 0 must still be recorded.
    shuffleMgr.unregisterShuffle(0)
    assert(lifecycleManager.getUnregisterShuffleTime().containsKey(0))
  } finally {
    sparkSession.stop()
  }
}

Expand Down

0 comments on commit 6d9fbf5

Please sign in to comment.