diff --git a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleMaster.java b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleMaster.java index 0e9d9aa80db..35ef341ebb0 100644 --- a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleMaster.java +++ b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleMaster.java @@ -19,6 +19,7 @@ import java.util.HashSet; import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; @@ -54,6 +55,7 @@ import org.apache.celeborn.common.exception.CelebornIOException; import org.apache.celeborn.common.util.JavaUtils; import org.apache.celeborn.common.util.ThreadUtils; +import org.apache.celeborn.plugin.flink.fallback.ShuffleFallbackPolicy; import org.apache.celeborn.plugin.flink.fallback.ShuffleFallbackPolicyRunner; import org.apache.celeborn.plugin.flink.utils.FlinkUtils; @@ -63,8 +65,8 @@ public class RemoteShuffleMaster implements ShuffleMaster { private final ShuffleMasterContext shuffleMasterContext; // Flink JobId -> Celeborn register shuffleIds private final Map> jobShuffleIds = JavaUtils.newConcurrentHashMap(); - private final ConcurrentHashMap.KeySetView nettyJobIds = - ConcurrentHashMap.newKeySet(); + private final ConcurrentHashMap jobFallbackPolicies = + JavaUtils.newConcurrentHashMap(); private String celebornAppId; private volatile LifecycleManager lifecycleManager; private final ShuffleTaskInfo shuffleTaskInfo = new ShuffleTaskInfo(); @@ -106,18 +108,21 @@ public void registerJob(JobShuffleContext context) { } try { - if (nettyShuffleServiceFactory != null - && ShuffleFallbackPolicyRunner.applyFallbackPolicies(context, conf, lifecycleManager)) { - LOG.warn("Fallback to vanilla Flink NettyShuffleMaster for job: {}.", jobID); - nettyJobIds.add(jobID); - nettyShuffleMaster().registerJob(context); - } else { - Set previousShuffleIds = jobShuffleIds.putIfAbsent(jobID, new HashSet<>()); - if (previousShuffleIds != null) { - throw new RuntimeException("Duplicated registration job: " + jobID); + if (nettyShuffleServiceFactory != null) { + Optional shuffleFallbackPolicy = + ShuffleFallbackPolicyRunner.getActivatedFallbackPolicy(context, conf, lifecycleManager); + if (shuffleFallbackPolicy.isPresent()) { + LOG.warn("Fallback to vanilla Flink NettyShuffleMaster for job: {}.", jobID); + jobFallbackPolicies.put(jobID, shuffleFallbackPolicy.get().getClass().getName()); + nettyShuffleMaster().registerJob(context); + return; } - shuffleResourceTracker.registerJob(context); } + Set previousShuffleIds = jobShuffleIds.putIfAbsent(jobID, new HashSet<>()); + if (previousShuffleIds != null) { + throw new RuntimeException("Duplicated registration job: " + jobID); + } + shuffleResourceTracker.registerJob(context); } catch (CelebornIOException e) { throw new RuntimeException(e); } @@ -126,7 +131,7 @@ public void registerJob(JobShuffleContext context) { @Override public void unregisterJob(JobID jobID) { LOG.info("Unregister job: {}.", jobID); - if (nettyJobIds.remove(jobID)) { + if (jobFallbackPolicies.remove(jobID) != null) { nettyShuffleMaster().unregisterJob(jobID); return; } @@ -152,8 +157,13 @@ public CompletableFuture registerPartitionWithProducer( JobID jobID, PartitionDescriptor partitionDescriptor, ProducerDescriptor producerDescriptor) { return CompletableFuture.supplyAsync( () -> { - if (nettyJobIds.contains(jobID)) { + lifecycleManager.shuffleCount().increment(); + String jobFallbackPolicy = jobFallbackPolicies.get(jobID); + if (jobFallbackPolicy != null) { try { + lifecycleManager + .shuffleFallbackCounts() + .compute(jobFallbackPolicy, (key, value) -> value == null ? 1L : value + 1L); return nettyShuffleMaster() .registerPartitionWithProducer(jobID, partitionDescriptor, producerDescriptor) .get(); @@ -270,7 +280,7 @@ public MemorySize computeShuffleMemorySizeForTask( @Override public void close() throws Exception { try { - nettyJobIds.clear(); + jobFallbackPolicies.clear(); jobShuffleIds.clear(); LifecycleManager manager = lifecycleManager; if (null != manager) { @@ -318,7 +328,12 @@ private NettyShuffleMaster nettyShuffleMaster() { } @VisibleForTesting - public ConcurrentHashMap.KeySetView nettyJobIds() { - return nettyJobIds; + public LifecycleManager lifecycleManager() { + return lifecycleManager; + } + + @VisibleForTesting + public ConcurrentHashMap jobFallbackPolicies() { + return jobFallbackPolicies; } } diff --git a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/fallback/ShuffleFallbackPolicyRunner.java b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/fallback/ShuffleFallbackPolicyRunner.java index cfff6aaa2b4..009fddc96f3 100644 --- a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/fallback/ShuffleFallbackPolicyRunner.java +++ b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/fallback/ShuffleFallbackPolicyRunner.java @@ -32,7 +32,7 @@ public class ShuffleFallbackPolicyRunner { private static final List FALLBACK_POLICIES = ShuffleFallbackPolicyFactory.getShuffleFallbackPolicies(); - public static boolean applyFallbackPolicies( + public static Optional getActivatedFallbackPolicy( JobShuffleContext shuffleContext, CelebornConf celebornConf, LifecycleManager lifecycleManager) @@ -44,11 +44,11 @@ public static boolean applyFallbackPolicies( shuffleFallbackPolicy.needFallback( shuffleContext, celebornConf, lifecycleManager)) .findFirst(); - boolean needFallback = fallbackPolicy.isPresent(); - if (needFallback && FallbackPolicy.NEVER.equals(celebornConf.flinkShuffleFallbackPolicy())) { + if (fallbackPolicy.isPresent() + && FallbackPolicy.NEVER.equals(celebornConf.flinkShuffleFallbackPolicy())) { throw new CelebornIOException( "Fallback to flink built-in shuffle implementation is prohibited."); } - return needFallback; + return fallbackPolicy; } } diff --git a/client-flink/flink-1.14/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleMasterSuiteJ.java b/client-flink/flink-1.14/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleMasterSuiteJ.java index 2b53f533ccf..19416032698 100644 --- a/client-flink/flink-1.14/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleMasterSuiteJ.java +++ b/client-flink/flink-1.14/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleMasterSuiteJ.java @@ -49,6 +49,7 @@ import org.apache.celeborn.common.CelebornConf; import org.apache.celeborn.common.protocol.FallbackPolicy; import org.apache.celeborn.common.util.Utils$; +import org.apache.celeborn.plugin.flink.fallback.ForceFallbackPolicy; import org.apache.celeborn.plugin.flink.utils.FlinkUtils; public class RemoteShuffleMasterSuiteJ { @@ -91,9 +92,9 @@ public void testRegisterJobWithForceFallbackPolicy() { JobID jobID = JobID.generate(); JobShuffleContext jobShuffleContext = createJobShuffleContext(jobID); remoteShuffleMaster.registerJob(jobShuffleContext); - Assert.assertTrue(remoteShuffleMaster.nettyJobIds().contains(jobID)); + Assert.assertTrue(remoteShuffleMaster.jobFallbackPolicies().containsKey(jobID)); remoteShuffleMaster.unregisterJob(jobShuffleContext.getJobId()); - Assert.assertTrue(remoteShuffleMaster.nettyJobIds().isEmpty()); + Assert.assertTrue(remoteShuffleMaster.jobFallbackPolicies().isEmpty()); } @Test @@ -111,6 +112,7 @@ public void testRegisterPartitionWithProducer() remoteShuffleMaster .registerPartitionWithProducer(jobID, partitionDescriptor, producerDescriptor) .get(); + Assert.assertEquals(1, remoteShuffleMaster.lifecycleManager().shuffleCount().sum()); ShuffleResource shuffleResource = remoteShuffleDescriptor.getShuffleResource(); ShuffleResourceDescriptor mapPartitionShuffleDescriptor = shuffleResource.getMapPartitionShuffleDescriptor(); @@ -128,6 +130,7 @@ public void testRegisterPartitionWithProducer() remoteShuffleMaster .registerPartitionWithProducer(jobID, partitionDescriptor, producerDescriptor) .get(); + Assert.assertEquals(2, remoteShuffleMaster.lifecycleManager().shuffleCount().sum()); mapPartitionShuffleDescriptor = remoteShuffleDescriptor.getShuffleResource().getMapPartitionShuffleDescriptor(); Assert.assertEquals(0, mapPartitionShuffleDescriptor.getShuffleId()); @@ -140,6 +143,7 @@ public void testRegisterPartitionWithProducer() remoteShuffleMaster .registerPartitionWithProducer(jobID, partitionDescriptor, producerDescriptor) .get(); + Assert.assertEquals(3, remoteShuffleMaster.lifecycleManager().shuffleCount().sum()); mapPartitionShuffleDescriptor = remoteShuffleDescriptor.getShuffleResource().getMapPartitionShuffleDescriptor(); Assert.assertEquals(0, mapPartitionShuffleDescriptor.getShuffleId()); @@ -147,6 +151,32 @@ public void testRegisterPartitionWithProducer() Assert.assertEquals(1, mapPartitionShuffleDescriptor.getMapId()); } + @Test + public void testRegisterPartitionWithProducerForForceFallbackPolicy() + throws UnknownHostException, ExecutionException, InterruptedException { + configuration.setString( + CelebornConf.FLINK_SHUFFLE_FALLBACK_POLICY().key(), FallbackPolicy.ALWAYS.name()); + remoteShuffleMaster = createShuffleMaster(configuration, new NettyShuffleServiceFactory()); + JobID jobID = JobID.generate(); + JobShuffleContext jobShuffleContext = createJobShuffleContext(jobID); + remoteShuffleMaster.registerJob(jobShuffleContext); + + IntermediateDataSetID intermediateDataSetID = new IntermediateDataSetID(); + PartitionDescriptor partitionDescriptor = createPartitionDescriptor(intermediateDataSetID, 0); + ProducerDescriptor producerDescriptor = createProducerDescriptor(); + ShuffleDescriptor shuffleDescriptor = + remoteShuffleMaster + .registerPartitionWithProducer(jobID, partitionDescriptor, producerDescriptor) + .get(); + Assert.assertTrue(shuffleDescriptor instanceof NettyShuffleDescriptor); + Assert.assertEquals(1, remoteShuffleMaster.lifecycleManager().shuffleCount().sum()); + Map shuffleFallbackCounts = + remoteShuffleMaster.lifecycleManager().shuffleFallbackCounts(); + Assert.assertEquals(1, shuffleFallbackCounts.size()); + Assert.assertEquals( + 1L, shuffleFallbackCounts.get(ForceFallbackPolicy.class.getName()).longValue()); + } + @Test public void testRegisterMultipleJobs() throws UnknownHostException, ExecutionException, InterruptedException { diff --git a/client-flink/flink-1.15/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleMasterSuiteJ.java b/client-flink/flink-1.15/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleMasterSuiteJ.java index 2b53f533ccf..19416032698 100644 --- a/client-flink/flink-1.15/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleMasterSuiteJ.java +++ b/client-flink/flink-1.15/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleMasterSuiteJ.java @@ -49,6 +49,7 @@ import org.apache.celeborn.common.CelebornConf; import org.apache.celeborn.common.protocol.FallbackPolicy; import org.apache.celeborn.common.util.Utils$; +import org.apache.celeborn.plugin.flink.fallback.ForceFallbackPolicy; import org.apache.celeborn.plugin.flink.utils.FlinkUtils; public class RemoteShuffleMasterSuiteJ { @@ -91,9 +92,9 @@ public void testRegisterJobWithForceFallbackPolicy() { JobID jobID = JobID.generate(); JobShuffleContext jobShuffleContext = createJobShuffleContext(jobID); remoteShuffleMaster.registerJob(jobShuffleContext); - Assert.assertTrue(remoteShuffleMaster.nettyJobIds().contains(jobID)); + Assert.assertTrue(remoteShuffleMaster.jobFallbackPolicies().containsKey(jobID)); remoteShuffleMaster.unregisterJob(jobShuffleContext.getJobId()); - Assert.assertTrue(remoteShuffleMaster.nettyJobIds().isEmpty()); + Assert.assertTrue(remoteShuffleMaster.jobFallbackPolicies().isEmpty()); } @Test @@ -111,6 +112,7 @@ public void testRegisterPartitionWithProducer() remoteShuffleMaster .registerPartitionWithProducer(jobID, partitionDescriptor, producerDescriptor) .get(); + Assert.assertEquals(1, remoteShuffleMaster.lifecycleManager().shuffleCount().sum()); ShuffleResource shuffleResource = remoteShuffleDescriptor.getShuffleResource(); ShuffleResourceDescriptor mapPartitionShuffleDescriptor = shuffleResource.getMapPartitionShuffleDescriptor(); @@ -128,6 +130,7 @@ public void testRegisterPartitionWithProducer() remoteShuffleMaster .registerPartitionWithProducer(jobID, partitionDescriptor, producerDescriptor) .get(); + Assert.assertEquals(2, remoteShuffleMaster.lifecycleManager().shuffleCount().sum()); mapPartitionShuffleDescriptor = remoteShuffleDescriptor.getShuffleResource().getMapPartitionShuffleDescriptor(); Assert.assertEquals(0, mapPartitionShuffleDescriptor.getShuffleId()); @@ -140,6 +143,7 @@ public void testRegisterPartitionWithProducer() remoteShuffleMaster .registerPartitionWithProducer(jobID, partitionDescriptor, producerDescriptor) .get(); + Assert.assertEquals(3, remoteShuffleMaster.lifecycleManager().shuffleCount().sum()); mapPartitionShuffleDescriptor = remoteShuffleDescriptor.getShuffleResource().getMapPartitionShuffleDescriptor(); Assert.assertEquals(0, mapPartitionShuffleDescriptor.getShuffleId()); @@ -147,6 +151,32 @@ public void testRegisterPartitionWithProducer() Assert.assertEquals(1, mapPartitionShuffleDescriptor.getMapId()); } + @Test + public void testRegisterPartitionWithProducerForForceFallbackPolicy() + throws UnknownHostException, ExecutionException, InterruptedException { + configuration.setString( + CelebornConf.FLINK_SHUFFLE_FALLBACK_POLICY().key(), FallbackPolicy.ALWAYS.name()); + remoteShuffleMaster = createShuffleMaster(configuration, new NettyShuffleServiceFactory()); + JobID jobID = JobID.generate(); + JobShuffleContext jobShuffleContext = createJobShuffleContext(jobID); + remoteShuffleMaster.registerJob(jobShuffleContext); + + IntermediateDataSetID intermediateDataSetID = new IntermediateDataSetID(); + PartitionDescriptor partitionDescriptor = createPartitionDescriptor(intermediateDataSetID, 0); + ProducerDescriptor producerDescriptor = createProducerDescriptor(); + ShuffleDescriptor shuffleDescriptor = + remoteShuffleMaster + .registerPartitionWithProducer(jobID, partitionDescriptor, producerDescriptor) + .get(); + Assert.assertTrue(shuffleDescriptor instanceof NettyShuffleDescriptor); + Assert.assertEquals(1, remoteShuffleMaster.lifecycleManager().shuffleCount().sum()); + Map shuffleFallbackCounts = + remoteShuffleMaster.lifecycleManager().shuffleFallbackCounts(); + Assert.assertEquals(1, shuffleFallbackCounts.size()); + Assert.assertEquals( + 1L, shuffleFallbackCounts.get(ForceFallbackPolicy.class.getName()).longValue()); + } + @Test public void testRegisterMultipleJobs() throws UnknownHostException, ExecutionException, InterruptedException { diff --git a/client-flink/flink-1.16/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleMasterSuiteJ.java b/client-flink/flink-1.16/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleMasterSuiteJ.java index 525c67bc497..9eb36efe932 100644 --- a/client-flink/flink-1.16/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleMasterSuiteJ.java +++ b/client-flink/flink-1.16/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleMasterSuiteJ.java @@ -42,8 +42,10 @@ import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID; import org.apache.flink.runtime.shuffle.JobShuffleContext; +import org.apache.flink.runtime.shuffle.NettyShuffleDescriptor; import org.apache.flink.runtime.shuffle.PartitionDescriptor; import org.apache.flink.runtime.shuffle.ProducerDescriptor; +import org.apache.flink.runtime.shuffle.ShuffleDescriptor; import org.apache.flink.runtime.shuffle.ShuffleMasterContext; import org.apache.flink.runtime.shuffle.TaskInputsOutputsDescriptor; import org.junit.After; @@ -56,6 +58,7 @@ import org.apache.celeborn.common.CelebornConf; import org.apache.celeborn.common.protocol.FallbackPolicy; import org.apache.celeborn.common.util.Utils$; +import org.apache.celeborn.plugin.flink.fallback.ForceFallbackPolicy; import org.apache.celeborn.plugin.flink.utils.FlinkUtils; public class RemoteShuffleMasterSuiteJ { @@ -98,9 +101,9 @@ public void testRegisterJobWithForceFallbackPolicy() { JobID jobID = JobID.generate(); JobShuffleContext jobShuffleContext = createJobShuffleContext(jobID); remoteShuffleMaster.registerJob(jobShuffleContext); - Assert.assertTrue(remoteShuffleMaster.nettyJobIds().contains(jobID)); + Assert.assertTrue(remoteShuffleMaster.jobFallbackPolicies().containsKey(jobID)); remoteShuffleMaster.unregisterJob(jobShuffleContext.getJobId()); - Assert.assertTrue(remoteShuffleMaster.nettyJobIds().isEmpty()); + Assert.assertTrue(remoteShuffleMaster.jobFallbackPolicies().isEmpty()); } @Test @@ -118,6 +121,7 @@ public void testRegisterPartitionWithProducer() remoteShuffleMaster .registerPartitionWithProducer(jobID, partitionDescriptor, producerDescriptor) .get(); + Assert.assertEquals(1, remoteShuffleMaster.lifecycleManager().shuffleCount().sum()); ShuffleResource shuffleResource = remoteShuffleDescriptor.getShuffleResource(); ShuffleResourceDescriptor mapPartitionShuffleDescriptor = shuffleResource.getMapPartitionShuffleDescriptor(); @@ -135,6 +139,7 @@ public void testRegisterPartitionWithProducer() remoteShuffleMaster .registerPartitionWithProducer(jobID, partitionDescriptor, producerDescriptor) .get(); + Assert.assertEquals(2, remoteShuffleMaster.lifecycleManager().shuffleCount().sum()); mapPartitionShuffleDescriptor = remoteShuffleDescriptor.getShuffleResource().getMapPartitionShuffleDescriptor(); Assert.assertEquals(0, mapPartitionShuffleDescriptor.getShuffleId()); @@ -147,6 +152,7 @@ public void testRegisterPartitionWithProducer() remoteShuffleMaster .registerPartitionWithProducer(jobID, partitionDescriptor, producerDescriptor) .get(); + Assert.assertEquals(3, remoteShuffleMaster.lifecycleManager().shuffleCount().sum()); mapPartitionShuffleDescriptor = remoteShuffleDescriptor.getShuffleResource().getMapPartitionShuffleDescriptor(); Assert.assertEquals(0, mapPartitionShuffleDescriptor.getShuffleId()); @@ -154,6 +160,32 @@ public void testRegisterPartitionWithProducer() Assert.assertEquals(1, mapPartitionShuffleDescriptor.getMapId()); } + @Test + public void testRegisterPartitionWithProducerForForceFallbackPolicy() + throws UnknownHostException, ExecutionException, InterruptedException { + configuration.setString( + CelebornConf.FLINK_SHUFFLE_FALLBACK_POLICY().key(), FallbackPolicy.ALWAYS.name()); + remoteShuffleMaster = createShuffleMaster(configuration, new NettyShuffleServiceFactory()); + JobID jobID = JobID.generate(); + JobShuffleContext jobShuffleContext = createJobShuffleContext(jobID); + remoteShuffleMaster.registerJob(jobShuffleContext); + + IntermediateDataSetID intermediateDataSetID = new IntermediateDataSetID(); + PartitionDescriptor partitionDescriptor = createPartitionDescriptor(intermediateDataSetID, 0); + ProducerDescriptor producerDescriptor = createProducerDescriptor(); + ShuffleDescriptor shuffleDescriptor = + remoteShuffleMaster + .registerPartitionWithProducer(jobID, partitionDescriptor, producerDescriptor) + .get(); + Assert.assertTrue(shuffleDescriptor instanceof NettyShuffleDescriptor); + Assert.assertEquals(1, remoteShuffleMaster.lifecycleManager().shuffleCount().sum()); + Map shuffleFallbackCounts = + remoteShuffleMaster.lifecycleManager().shuffleFallbackCounts(); + Assert.assertEquals(1, shuffleFallbackCounts.size()); + Assert.assertEquals( + 1L, shuffleFallbackCounts.get(ForceFallbackPolicy.class.getName()).longValue()); + } + @Test public void testRegisterMultipleJobs() throws UnknownHostException, ExecutionException, InterruptedException { diff --git a/client-flink/flink-1.17/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleMasterSuiteJ.java b/client-flink/flink-1.17/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleMasterSuiteJ.java index cf240eb7395..be70f0f4ad7 100644 --- a/client-flink/flink-1.17/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleMasterSuiteJ.java +++ b/client-flink/flink-1.17/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleMasterSuiteJ.java @@ -42,8 +42,10 @@ import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID; import org.apache.flink.runtime.shuffle.JobShuffleContext; +import org.apache.flink.runtime.shuffle.NettyShuffleDescriptor; import org.apache.flink.runtime.shuffle.PartitionDescriptor; import org.apache.flink.runtime.shuffle.ProducerDescriptor; +import org.apache.flink.runtime.shuffle.ShuffleDescriptor; import org.apache.flink.runtime.shuffle.ShuffleMasterContext; import org.apache.flink.runtime.shuffle.TaskInputsOutputsDescriptor; import org.junit.After; @@ -56,6 +58,7 @@ import org.apache.celeborn.common.CelebornConf; import org.apache.celeborn.common.protocol.FallbackPolicy; import org.apache.celeborn.common.util.Utils$; +import org.apache.celeborn.plugin.flink.fallback.ForceFallbackPolicy; import org.apache.celeborn.plugin.flink.utils.FlinkUtils; public class RemoteShuffleMasterSuiteJ { @@ -98,9 +101,9 @@ public void testRegisterJobWithForceFallbackPolicy() { JobID jobID = JobID.generate(); JobShuffleContext jobShuffleContext = createJobShuffleContext(jobID); remoteShuffleMaster.registerJob(jobShuffleContext); - Assert.assertTrue(remoteShuffleMaster.nettyJobIds().contains(jobID)); + Assert.assertTrue(remoteShuffleMaster.jobFallbackPolicies().containsKey(jobID)); remoteShuffleMaster.unregisterJob(jobShuffleContext.getJobId()); - Assert.assertTrue(remoteShuffleMaster.nettyJobIds().isEmpty()); + Assert.assertTrue(remoteShuffleMaster.jobFallbackPolicies().isEmpty()); } @Test @@ -118,6 +121,7 @@ public void testRegisterPartitionWithProducer() remoteShuffleMaster .registerPartitionWithProducer(jobID, partitionDescriptor, producerDescriptor) .get(); + Assert.assertEquals(1, remoteShuffleMaster.lifecycleManager().shuffleCount().sum()); ShuffleResource shuffleResource = remoteShuffleDescriptor.getShuffleResource(); ShuffleResourceDescriptor mapPartitionShuffleDescriptor = shuffleResource.getMapPartitionShuffleDescriptor(); @@ -135,6 +139,7 @@ public void testRegisterPartitionWithProducer() remoteShuffleMaster .registerPartitionWithProducer(jobID, partitionDescriptor, producerDescriptor) .get(); + Assert.assertEquals(2, remoteShuffleMaster.lifecycleManager().shuffleCount().sum()); mapPartitionShuffleDescriptor = remoteShuffleDescriptor.getShuffleResource().getMapPartitionShuffleDescriptor(); Assert.assertEquals(0, mapPartitionShuffleDescriptor.getShuffleId()); @@ -147,6 +152,7 @@ public void testRegisterPartitionWithProducer() remoteShuffleMaster .registerPartitionWithProducer(jobID, partitionDescriptor, producerDescriptor) .get(); + Assert.assertEquals(3, remoteShuffleMaster.lifecycleManager().shuffleCount().sum()); mapPartitionShuffleDescriptor = remoteShuffleDescriptor.getShuffleResource().getMapPartitionShuffleDescriptor(); Assert.assertEquals(0, mapPartitionShuffleDescriptor.getShuffleId()); @@ -154,6 +160,32 @@ public void testRegisterPartitionWithProducer() Assert.assertEquals(1, mapPartitionShuffleDescriptor.getMapId()); } + @Test + public void testRegisterPartitionWithProducerForForceFallbackPolicy() + throws UnknownHostException, ExecutionException, InterruptedException { + configuration.setString( + CelebornConf.FLINK_SHUFFLE_FALLBACK_POLICY().key(), FallbackPolicy.ALWAYS.name()); + remoteShuffleMaster = createShuffleMaster(configuration, new NettyShuffleServiceFactory()); + JobID jobID = JobID.generate(); + JobShuffleContext jobShuffleContext = createJobShuffleContext(jobID); + remoteShuffleMaster.registerJob(jobShuffleContext); + + IntermediateDataSetID intermediateDataSetID = new IntermediateDataSetID(); + PartitionDescriptor partitionDescriptor = createPartitionDescriptor(intermediateDataSetID, 0); + ProducerDescriptor producerDescriptor = createProducerDescriptor(); + ShuffleDescriptor shuffleDescriptor = + remoteShuffleMaster + .registerPartitionWithProducer(jobID, partitionDescriptor, producerDescriptor) + .get(); + Assert.assertTrue(shuffleDescriptor instanceof NettyShuffleDescriptor); + Assert.assertEquals(1, remoteShuffleMaster.lifecycleManager().shuffleCount().sum()); + Map shuffleFallbackCounts = + remoteShuffleMaster.lifecycleManager().shuffleFallbackCounts(); + Assert.assertEquals(1, shuffleFallbackCounts.size()); + Assert.assertEquals( + 1L, shuffleFallbackCounts.get(ForceFallbackPolicy.class.getName()).longValue()); + } + @Test public void testRegisterMultipleJobs() throws UnknownHostException, ExecutionException, InterruptedException { diff --git a/client-flink/flink-1.18/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleMasterSuiteJ.java b/client-flink/flink-1.18/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleMasterSuiteJ.java index cf240eb7395..be70f0f4ad7 100644 --- a/client-flink/flink-1.18/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleMasterSuiteJ.java +++ b/client-flink/flink-1.18/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleMasterSuiteJ.java @@ -42,8 +42,10 @@ import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID; import org.apache.flink.runtime.shuffle.JobShuffleContext; +import org.apache.flink.runtime.shuffle.NettyShuffleDescriptor; import org.apache.flink.runtime.shuffle.PartitionDescriptor; import org.apache.flink.runtime.shuffle.ProducerDescriptor; +import org.apache.flink.runtime.shuffle.ShuffleDescriptor; import org.apache.flink.runtime.shuffle.ShuffleMasterContext; import org.apache.flink.runtime.shuffle.TaskInputsOutputsDescriptor; import org.junit.After; @@ -56,6 +58,7 @@ import org.apache.celeborn.common.CelebornConf; import org.apache.celeborn.common.protocol.FallbackPolicy; import org.apache.celeborn.common.util.Utils$; +import org.apache.celeborn.plugin.flink.fallback.ForceFallbackPolicy; import org.apache.celeborn.plugin.flink.utils.FlinkUtils; public class RemoteShuffleMasterSuiteJ { @@ -98,9 +101,9 @@ public void testRegisterJobWithForceFallbackPolicy() { JobID jobID = JobID.generate(); JobShuffleContext jobShuffleContext = createJobShuffleContext(jobID); remoteShuffleMaster.registerJob(jobShuffleContext); - Assert.assertTrue(remoteShuffleMaster.nettyJobIds().contains(jobID)); + Assert.assertTrue(remoteShuffleMaster.jobFallbackPolicies().containsKey(jobID)); remoteShuffleMaster.unregisterJob(jobShuffleContext.getJobId()); - Assert.assertTrue(remoteShuffleMaster.nettyJobIds().isEmpty()); + Assert.assertTrue(remoteShuffleMaster.jobFallbackPolicies().isEmpty()); } @Test @@ -118,6 +121,7 @@ public void testRegisterPartitionWithProducer() remoteShuffleMaster .registerPartitionWithProducer(jobID, partitionDescriptor, producerDescriptor) .get(); + Assert.assertEquals(1, remoteShuffleMaster.lifecycleManager().shuffleCount().sum()); ShuffleResource shuffleResource = remoteShuffleDescriptor.getShuffleResource(); ShuffleResourceDescriptor mapPartitionShuffleDescriptor = shuffleResource.getMapPartitionShuffleDescriptor(); @@ -135,6 +139,7 @@ public void testRegisterPartitionWithProducer() remoteShuffleMaster .registerPartitionWithProducer(jobID, partitionDescriptor, producerDescriptor) .get(); + Assert.assertEquals(2, remoteShuffleMaster.lifecycleManager().shuffleCount().sum()); mapPartitionShuffleDescriptor = remoteShuffleDescriptor.getShuffleResource().getMapPartitionShuffleDescriptor(); Assert.assertEquals(0, mapPartitionShuffleDescriptor.getShuffleId()); @@ -147,6 +152,7 @@ public void testRegisterPartitionWithProducer() remoteShuffleMaster .registerPartitionWithProducer(jobID, partitionDescriptor, producerDescriptor) .get(); + Assert.assertEquals(3, remoteShuffleMaster.lifecycleManager().shuffleCount().sum()); mapPartitionShuffleDescriptor = remoteShuffleDescriptor.getShuffleResource().getMapPartitionShuffleDescriptor(); Assert.assertEquals(0, mapPartitionShuffleDescriptor.getShuffleId()); @@ -154,6 +160,32 @@ public void testRegisterPartitionWithProducer() Assert.assertEquals(1, mapPartitionShuffleDescriptor.getMapId()); } + @Test + public void testRegisterPartitionWithProducerForForceFallbackPolicy() + throws UnknownHostException, ExecutionException, InterruptedException { + configuration.setString( + CelebornConf.FLINK_SHUFFLE_FALLBACK_POLICY().key(), FallbackPolicy.ALWAYS.name()); + remoteShuffleMaster = createShuffleMaster(configuration, new NettyShuffleServiceFactory()); + JobID jobID = JobID.generate(); + JobShuffleContext jobShuffleContext = createJobShuffleContext(jobID); + remoteShuffleMaster.registerJob(jobShuffleContext); + + IntermediateDataSetID intermediateDataSetID = new IntermediateDataSetID(); + PartitionDescriptor partitionDescriptor = createPartitionDescriptor(intermediateDataSetID, 0); + ProducerDescriptor producerDescriptor = createProducerDescriptor(); + ShuffleDescriptor shuffleDescriptor = + remoteShuffleMaster + .registerPartitionWithProducer(jobID, partitionDescriptor, producerDescriptor) + .get(); + Assert.assertTrue(shuffleDescriptor instanceof NettyShuffleDescriptor); + Assert.assertEquals(1, remoteShuffleMaster.lifecycleManager().shuffleCount().sum()); + Map shuffleFallbackCounts = + remoteShuffleMaster.lifecycleManager().shuffleFallbackCounts(); + Assert.assertEquals(1, shuffleFallbackCounts.size()); + Assert.assertEquals( + 1L, shuffleFallbackCounts.get(ForceFallbackPolicy.class.getName()).longValue()); + } + @Test public void testRegisterMultipleJobs() throws UnknownHostException, ExecutionException, InterruptedException { diff --git a/client-flink/flink-1.19/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleMasterSuiteJ.java b/client-flink/flink-1.19/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleMasterSuiteJ.java index 01e2bc8eba1..451f6001bec 100644 --- a/client-flink/flink-1.19/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleMasterSuiteJ.java +++ b/client-flink/flink-1.19/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleMasterSuiteJ.java @@ -42,8 +42,10 @@ import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID; import org.apache.flink.runtime.shuffle.JobShuffleContext; +import org.apache.flink.runtime.shuffle.NettyShuffleDescriptor; import org.apache.flink.runtime.shuffle.PartitionDescriptor; import org.apache.flink.runtime.shuffle.ProducerDescriptor; +import org.apache.flink.runtime.shuffle.ShuffleDescriptor; import org.apache.flink.runtime.shuffle.ShuffleMasterContext; import org.apache.flink.runtime.shuffle.TaskInputsOutputsDescriptor; import org.junit.After; @@ -56,6 +58,7 @@ import org.apache.celeborn.common.CelebornConf; import org.apache.celeborn.common.protocol.FallbackPolicy; import org.apache.celeborn.common.util.Utils$; +import org.apache.celeborn.plugin.flink.fallback.ForceFallbackPolicy; import org.apache.celeborn.plugin.flink.utils.FlinkUtils; public class RemoteShuffleMasterSuiteJ { @@ -98,9 +101,9 @@ public void testRegisterJobWithForceFallbackPolicy() { JobID jobID = JobID.generate(); JobShuffleContext jobShuffleContext = createJobShuffleContext(jobID); remoteShuffleMaster.registerJob(jobShuffleContext); - Assert.assertTrue(remoteShuffleMaster.nettyJobIds().contains(jobID)); + Assert.assertTrue(remoteShuffleMaster.jobFallbackPolicies().containsKey(jobID)); remoteShuffleMaster.unregisterJob(jobShuffleContext.getJobId()); - Assert.assertTrue(remoteShuffleMaster.nettyJobIds().isEmpty()); + Assert.assertTrue(remoteShuffleMaster.jobFallbackPolicies().isEmpty()); } @Test @@ -118,6 +121,7 @@ public void testRegisterPartitionWithProducer() remoteShuffleMaster .registerPartitionWithProducer(jobID, partitionDescriptor, producerDescriptor) .get(); + Assert.assertEquals(1, remoteShuffleMaster.lifecycleManager().shuffleCount().sum()); ShuffleResource shuffleResource = remoteShuffleDescriptor.getShuffleResource(); ShuffleResourceDescriptor mapPartitionShuffleDescriptor = shuffleResource.getMapPartitionShuffleDescriptor(); @@ -135,6 +139,7 @@ public void testRegisterPartitionWithProducer() remoteShuffleMaster .registerPartitionWithProducer(jobID, partitionDescriptor, producerDescriptor) .get(); + Assert.assertEquals(2, remoteShuffleMaster.lifecycleManager().shuffleCount().sum()); mapPartitionShuffleDescriptor = remoteShuffleDescriptor.getShuffleResource().getMapPartitionShuffleDescriptor(); Assert.assertEquals(0, mapPartitionShuffleDescriptor.getShuffleId()); @@ -147,6 +152,7 @@ public void testRegisterPartitionWithProducer() remoteShuffleMaster .registerPartitionWithProducer(jobID, partitionDescriptor, producerDescriptor) .get(); + Assert.assertEquals(3, remoteShuffleMaster.lifecycleManager().shuffleCount().sum()); mapPartitionShuffleDescriptor = remoteShuffleDescriptor.getShuffleResource().getMapPartitionShuffleDescriptor(); Assert.assertEquals(0, mapPartitionShuffleDescriptor.getShuffleId()); @@ -154,6 +160,32 @@ public void testRegisterPartitionWithProducer() Assert.assertEquals(1, mapPartitionShuffleDescriptor.getMapId()); } + @Test + public void testRegisterPartitionWithProducerForForceFallbackPolicy() + throws UnknownHostException, ExecutionException, InterruptedException { + configuration.setString( + CelebornConf.FLINK_SHUFFLE_FALLBACK_POLICY().key(), FallbackPolicy.ALWAYS.name()); + remoteShuffleMaster = createShuffleMaster(configuration, new NettyShuffleServiceFactory()); + JobID jobID = JobID.generate(); + JobShuffleContext jobShuffleContext = createJobShuffleContext(jobID); + remoteShuffleMaster.registerJob(jobShuffleContext); + + IntermediateDataSetID intermediateDataSetID = new IntermediateDataSetID(); + PartitionDescriptor partitionDescriptor = createPartitionDescriptor(intermediateDataSetID, 0); + ProducerDescriptor producerDescriptor = createProducerDescriptor(); + ShuffleDescriptor shuffleDescriptor = + remoteShuffleMaster + .registerPartitionWithProducer(jobID, partitionDescriptor, producerDescriptor) + .get(); + Assert.assertTrue(shuffleDescriptor instanceof NettyShuffleDescriptor); + Assert.assertEquals(1, remoteShuffleMaster.lifecycleManager().shuffleCount().sum()); + Map shuffleFallbackCounts = + remoteShuffleMaster.lifecycleManager().shuffleFallbackCounts(); + Assert.assertEquals(1, shuffleFallbackCounts.size()); + Assert.assertEquals( + 1L, shuffleFallbackCounts.get(ForceFallbackPolicy.class.getName()).longValue()); + } + @Test public void testRegisterMultipleJobs() throws UnknownHostException, ExecutionException, InterruptedException { diff --git a/client-flink/flink-1.20/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleMasterSuiteJ.java b/client-flink/flink-1.20/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleMasterSuiteJ.java index da88d772fd1..4b3f81aa480 100644 --- a/client-flink/flink-1.20/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleMasterSuiteJ.java +++ b/client-flink/flink-1.20/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleMasterSuiteJ.java @@ -44,9 +44,11 @@ import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID; import org.apache.flink.runtime.shuffle.JobShuffleContext; +import org.apache.flink.runtime.shuffle.NettyShuffleDescriptor; import org.apache.flink.runtime.shuffle.PartitionDescriptor; import org.apache.flink.runtime.shuffle.PartitionWithMetrics; import org.apache.flink.runtime.shuffle.ProducerDescriptor; +import org.apache.flink.runtime.shuffle.ShuffleDescriptor; import org.apache.flink.runtime.shuffle.ShuffleMasterContext; import org.apache.flink.runtime.shuffle.TaskInputsOutputsDescriptor; import org.junit.After; @@ -59,6 +61,7 @@ import org.apache.celeborn.common.CelebornConf; import org.apache.celeborn.common.protocol.FallbackPolicy; import org.apache.celeborn.common.util.Utils$; +import org.apache.celeborn.plugin.flink.fallback.ForceFallbackPolicy; import org.apache.celeborn.plugin.flink.utils.FlinkUtils; public class RemoteShuffleMasterSuiteJ { @@ -101,9 +104,9 @@ public void testRegisterJobWithForceFallbackPolicy() { JobID jobID = JobID.generate(); JobShuffleContext jobShuffleContext = createJobShuffleContext(jobID); remoteShuffleMaster.registerJob(jobShuffleContext); - Assert.assertTrue(remoteShuffleMaster.nettyJobIds().contains(jobID)); + Assert.assertTrue(remoteShuffleMaster.jobFallbackPolicies().containsKey(jobID)); remoteShuffleMaster.unregisterJob(jobShuffleContext.getJobId()); - Assert.assertTrue(remoteShuffleMaster.nettyJobIds().isEmpty()); + Assert.assertTrue(remoteShuffleMaster.jobFallbackPolicies().isEmpty()); } @Test @@ -121,6 +124,7 @@ public void testRegisterPartitionWithProducer() remoteShuffleMaster .registerPartitionWithProducer(jobID, partitionDescriptor, producerDescriptor) .get(); + Assert.assertEquals(1, remoteShuffleMaster.lifecycleManager().shuffleCount().sum()); ShuffleResource shuffleResource = remoteShuffleDescriptor.getShuffleResource(); ShuffleResourceDescriptor mapPartitionShuffleDescriptor = shuffleResource.getMapPartitionShuffleDescriptor(); @@ -138,6 +142,7 @@ public void testRegisterPartitionWithProducer() remoteShuffleMaster .registerPartitionWithProducer(jobID, partitionDescriptor, producerDescriptor) .get(); + Assert.assertEquals(2, remoteShuffleMaster.lifecycleManager().shuffleCount().sum()); mapPartitionShuffleDescriptor = remoteShuffleDescriptor.getShuffleResource().getMapPartitionShuffleDescriptor(); Assert.assertEquals(0, mapPartitionShuffleDescriptor.getShuffleId()); @@ -150,6 +155,7 @@ public void testRegisterPartitionWithProducer() remoteShuffleMaster .registerPartitionWithProducer(jobID, partitionDescriptor, producerDescriptor) .get(); + Assert.assertEquals(3, remoteShuffleMaster.lifecycleManager().shuffleCount().sum()); mapPartitionShuffleDescriptor = remoteShuffleDescriptor.getShuffleResource().getMapPartitionShuffleDescriptor(); Assert.assertEquals(0, mapPartitionShuffleDescriptor.getShuffleId()); @@ -157,6 +163,32 @@ public void testRegisterPartitionWithProducer() Assert.assertEquals(1, mapPartitionShuffleDescriptor.getMapId()); } + @Test + public void testRegisterPartitionWithProducerForForceFallbackPolicy() + throws UnknownHostException, ExecutionException, InterruptedException { + configuration.setString( + CelebornConf.FLINK_SHUFFLE_FALLBACK_POLICY().key(), FallbackPolicy.ALWAYS.name()); + remoteShuffleMaster = createShuffleMaster(configuration, new NettyShuffleServiceFactory()); + JobID jobID = JobID.generate(); + JobShuffleContext jobShuffleContext = createJobShuffleContext(jobID); + remoteShuffleMaster.registerJob(jobShuffleContext); + + IntermediateDataSetID intermediateDataSetID = new IntermediateDataSetID(); + PartitionDescriptor partitionDescriptor = createPartitionDescriptor(intermediateDataSetID, 0); + ProducerDescriptor producerDescriptor = createProducerDescriptor(); + ShuffleDescriptor shuffleDescriptor = + remoteShuffleMaster + .registerPartitionWithProducer(jobID, partitionDescriptor, producerDescriptor) + .get(); + Assert.assertTrue(shuffleDescriptor instanceof NettyShuffleDescriptor); + Assert.assertEquals(1, remoteShuffleMaster.lifecycleManager().shuffleCount().sum()); + Map shuffleFallbackCounts = + remoteShuffleMaster.lifecycleManager().shuffleFallbackCounts(); + Assert.assertEquals(1, shuffleFallbackCounts.size()); + Assert.assertEquals( + 1L, shuffleFallbackCounts.get(ForceFallbackPolicy.class.getName()).longValue()); + } + @Test public void testRegisterMultipleJobs() throws UnknownHostException, ExecutionException, InterruptedException {