Skip to content

Commit

Permalink
Fit sharding optimization for auto parallel llama (PaddlePaddle#8021)
Browse files Browse the repository at this point in the history
* Fit sharding optimization for auto parallel llama

* Add args enable_allreduce_avg_in_gradinent_scale

* Fix CI errors
  • Loading branch information
From00 authored Mar 5, 2024
1 parent 4b1c54b commit e3cb5d2
Showing 1 changed file with 33 additions and 4 deletions.
37 changes: 33 additions & 4 deletions paddlenlp/trainer/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,10 @@ class TrainingArguments:
The paddle sequence parallel strategy. It can reduce the GPU memory of activation to 1/sep, and it is orthogonal to
data parallel, sharding stage1, tensor parallel and pipeline parallel strategy.
)
data_parallel_config (`str`, *optional*)(
Some additional configs which affect data parallel performance, we provide some option to config it.
following config is support:
enable_allreduce_avg_in_gradinent_scale, it replace `allreduce_sum + scale` pattern with `allreduce_avg` when scale gradient in data_parallel, which improve the performance. ONLY supported for auto mode now.
tensor_parallel_config (`str`, *optional*)(
Some additional configs which affect model parallel performance, we provide some option to config it.
following config is support:
Expand Down Expand Up @@ -571,6 +575,16 @@ class TrainingArguments:
)
},
)
data_parallel_config: str = field(
default="",
metadata={
"help": (
"Some additional configs which affect data parallel performance, we provide some option to config it."
"following config is support:\n"
"enable_allreduce_avg_in_gradinent_scale, it replace `allreduce_sum + scale` pattern with `allreduce_avg` when scale gradient in data_parallel, which improve the performance. ONLY supported for auto mode now. \n"
)
},
)
tensor_parallel_config: str = field(
default="",
metadata={
Expand Down Expand Up @@ -951,6 +965,7 @@ def __post_init__(self):
# TODO use paddle.distributed.is_initialized() after paddle 2.4rc
if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized():
strategy = fleet.DistributedStrategy()
assert self.data_parallel_config == "", "data_parallle_config is not supported in hybrid parallel"
if self.pipeline_parallel_degree > 1:
pipeline_parallel_config = set(self.pipeline_parallel_config.split(" "))
for x in pipeline_parallel_config:
Expand Down Expand Up @@ -1165,6 +1180,17 @@ def is_segment_parallel_supported():
warnings.warn("`offload` is not supported NOW!")

strategy = fleet.auto.Strategy()
if self.data_parallel_degree > 1:
data_parallel_config = set(self.data_parallel_config.split(" "))
for x in data_parallel_config:
if len(x) > 0:
if x not in ["enable_allreduce_avg_in_gradinent_scale"]:
raise ValueError(
f"Found unknown data parallel config {x}, accpet config is enable_allreduce_avg_in_gradinent_scale."
)
if "enable_allreduce_avg_in_gradinent_scale" in data_parallel_config:
strategy.gradient_scale_using_allreduce_avg = True

# navie-pp: pipeline_parallel_degree > 1 and gradient_accumulation_steps == 1
if self.pipeline_parallel_degree > 1 and self.gradient_accumulation_steps > 1:
pipeline_parallel_config = set(self.pipeline_parallel_config.split(" "))
Expand Down Expand Up @@ -1254,9 +1280,9 @@ def is_segment_parallel_supported():
for x in sharding_parallel_config:
if len(x) > 0:
if x not in [
# "enable_stage1_tensor_fusion",
# "enable_stage1_overlap",
# "enable_stage2_overlap",
"enable_stage1_tensor_fusion",
"enable_stage1_overlap",
"enable_stage2_overlap",
]:
raise ValueError(
f"Found unknown pipeline mode config {x}, " f"accpet config is reduce_overlap."
Expand All @@ -1266,7 +1292,10 @@ def is_segment_parallel_supported():
"enable_stage1_overlap" in sharding_parallel_config
or "enable_stage2_overlap" in sharding_parallel_config
):
sharding.reduce_overlap = True
sharding.enable_overlap = True

if "enable_stage1_tensor_fusion" in sharding_parallel_config:
sharding.grad_bucket_size_numel = 210355872

if self.bf16 or self.fp16:
amp = strategy.amp
Expand Down

0 comments on commit e3cb5d2

Please sign in to comment.