[Distributed] Add dp_gradient_sync_after_accumulate (PaddlePaddle#8045)
* add dp_gradient_sync_after_accumulate

* recover run_pretrain_auto_static.sh

* recover

* use gradient_sync_after_accumulate

* change name under DP_OPTIMIZATION

* add doc for data_parallel_config

* update notes

* add note and recover codes

* add note
AndSonder authored Mar 8, 2024
1 parent 95c8b24 commit 67964cf
Showing 2 changed files with 15 additions and 1 deletion.
10 changes: 10 additions & 0 deletions docs/trainer.md
@@ -552,6 +552,16 @@ Trainer is a simple but feature-complete Paddle training and evaluation module, and
enable_delay_scale_loss, accumulate gradients until the optimizer step and divide all gradients by the inner pipeline accumulation step, instead of dividing the loss by the accumulation step directly.
enable_dp_comm_overlap, fuse data parallel gradient communication.
--data_parallel_config
For data parallelism, some options affect training performance. They are grouped here and passed in as a single str.
The following options are supported:
enable_allreduce_avg_in_gradinent_scale: in data parallelism, replace the `allreduce_sum + scale` pattern with `allreduce_avg` when scaling gradients, which improves performance. ONLY supported for auto mode now.
gradient_sync_after_accumulate: when gradient accumulation is enabled, move gradient synchronization from the backward stage to the optimizer step, which reduces the number of syncs and improves performance but increases memory usage. ONLY supported for auto mode now.
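A minimal usage sketch (not part of this commit), assuming a standard PaddleNLP TrainingArguments setup; output_dir and the accumulation value are illustrative. The trainer splits the config string on spaces, so both options go into one string:

from paddlenlp.trainer import TrainingArguments

args = TrainingArguments(
    output_dir="./checkpoints",  # hypothetical output path
    # Space-separated option string; both options only take effect in auto parallel mode.
    data_parallel_config="enable_allreduce_avg_in_gradinent_scale gradient_sync_after_accumulate",
    gradient_accumulation_steps=8,  # gradient_sync_after_accumulate only matters when accumulation > 1
)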
--recompute
Whether to use recomputation during training, which can save GPU memory.
6 changes: 5 additions & 1 deletion paddlenlp/trainer/training_args.py
@@ -234,6 +234,7 @@ class TrainingArguments:
Some additional configs which affect data parallel performance; we provide some options to configure them.
The following configs are supported:
enable_allreduce_avg_in_gradinent_scale, replace the `allreduce_sum + scale` pattern with `allreduce_avg` when scaling gradients in data parallel, which improves performance. ONLY supported for auto mode now.
gradient_sync_after_accumulate, move gradient synchronization from the backward stage to the optimizer step when gradient accumulation is enabled, which reduces the number of syncs and improves performance but increases memory usage. ONLY supported for auto mode now.
tensor_parallel_config (`str`, *optional*):
Some additional configs which affect model parallel performance; we provide some options to configure them.
The following configs are supported:
@@ -582,6 +583,7 @@ class TrainingArguments:
"Some additional configs which affect data parallel performance, we provide some option to config it."
"following config is support:\n"
"enable_allreduce_avg_in_gradinent_scale, it replace `allreduce_sum + scale` pattern with `allreduce_avg` when scale gradient in data_parallel, which improve the performance. ONLY supported for auto mode now. \n"
"gradient_sync_after_accumulate, move gradient sync operations from backward into optimizer step when gradient accumulate enabling, which reduce the sync times to improve performance, but will increase the memory usage. ONLY supported for auto mode now. \n"
)
},
)
@@ -1184,12 +1186,14 @@ def is_segment_parallel_supported():
data_parallel_config = set(self.data_parallel_config.split(" "))
for x in data_parallel_config:
    if len(x) > 0:
        if x not in ["enable_allreduce_avg_in_gradinent_scale"]:
        if x not in ["enable_allreduce_avg_in_gradinent_scale", "gradient_sync_after_accumulate"]:
            raise ValueError(
                f"Found unknown data parallel config {x}, accepted configs are enable_allreduce_avg_in_gradinent_scale and gradient_sync_after_accumulate."
            )
if "enable_allreduce_avg_in_gradinent_scale" in data_parallel_config:
    strategy.gradient_scale_using_allreduce_avg = True
if "gradient_sync_after_accumulate" in data_parallel_config:
    strategy.dp_optimization.gradient_sync_after_accumulate = True
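For intuition only, a hedged sketch (plain Python pseudocode, not Paddle internals; all_reduce, model.grads, and both step functions are hypothetical names) of what moving the data-parallel gradient sync from backward to the optimizer step changes when gradients are accumulated:

def step_sync_in_backward(micro_batches, model, optimizer, all_reduce):
    # Default behavior: gradients are synchronized across data-parallel ranks
    # after every micro-batch backward pass.
    for batch in micro_batches:
        loss = model(batch)
        loss.backward()
        all_reduce(model.grads)  # one sync per micro-batch
    optimizer.step()
    optimizer.clear_grad()

def step_sync_after_accumulate(micro_batches, model, optimizer, all_reduce):
    # gradient_sync_after_accumulate: gradients accumulate locally and are
    # synchronized once right before the optimizer update; fewer syncs,
    # but the unsynchronized accumulated gradients are held longer in memory.
    for batch in micro_batches:
        loss = model(batch)
        loss.backward()  # no cross-rank sync here
    all_reduce(model.grads)  # single sync per optimizer step
    optimizer.step()
    optimizer.clear_grad()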

# naive-pp: pipeline_parallel_degree > 1 and gradient_accumulation_steps == 1
if self.pipeline_parallel_degree > 1 and self.gradient_accumulation_steps > 1:
