Optimize pipeline schedule #94
Merged
Conversation
add pipeline shared module wrapper and update load batch
* added model parallel process group for amp and clip grad
* update amp and clip with model parallel process group
micro batch offload
optimize pipeline gpu memory usage
ver217 requested review from FrankLeeeee and removed request for FrankLeeeee on December 29, 2021 05:48
ver217 requested review from FrankLeeeee and removed request for FrankLeeeee on December 29, 2021 08:55
ver217 requested review from FrankLeeeee and removed request for FrankLeeeee on December 30, 2021 04:27
Hi @ver217. In your demo, I do not see
ver217 requested review from FrankLeeeee and removed request for FrankLeeeee on December 30, 2021 07:32
FrankLeeeee approved these changes on December 30, 2021
Add pipeline shared module wrapper
This feature is especially useful when training GPT/BERT, whose word embedding layer is used at both the first and the last pipeline stage.
Usage: `PipelineSharedModuleWrapper` must be initialized on all ranks, and `PipelineSharedModuleWrapper.register_module()` should be called on the ranks which share the module. Modules have to be moved to the corresponding device, depending on your distributed backend, before calling `register_module()`.
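For example, tying a word embedding between the first and the last pipeline stage might look like the sketch below (the import path, rank layout and embedding sizes are illustrative assumptions, not taken verbatim from this PR):

```python
# Sketch only: import path, rank layout and embedding sizes are assumptions.
import torch.nn as nn
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from colossalai.nn.layer.wrapper import PipelineSharedModuleWrapper  # path may differ by version

num_stages = gpc.get_world_size(ParallelMode.PIPELINE)
rank = gpc.get_local_rank(ParallelMode.PIPELINE)

# 1) Initialize the wrapper on ALL pipeline ranks, listing the ranks that share the
#    module (here the first and the last stage, as with tied GPT/BERT word embeddings).
wrapper = PipelineSharedModuleWrapper([0, num_stages - 1])

# 2) Only the sharing ranks build the module, move it to the proper device for the
#    distributed backend, and then register it.
if rank in (0, num_stages - 1):
    word_embedding = nn.Embedding(50304, 1024).cuda()
    wrapper.register_module(word_embedding)
```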
Update `load_batch()` of schedule
We update the batch-loading rule of the schedule to support GPT/BERT training with pipeline parallelism.
Now please make sure that the item your dataset returns is a tuple of `(data, label)`, and that the data and the label are each a `torch.Tensor` or a `dict`. When you set `sync_data` to True in the schedule, you must make sure the values of the `dict` are `torch.Tensor`. Note that when your dataset returns a `dict`, its keys must match the argument names of your `model.forward()` or `loss_function.forward()`.

When using pipeline parallelism, the input of the first layer comes from the dataloader. For the other layers, the first argument of `forward()` is the output of the previous pipeline stage, and the remaining arguments come from the dataloader. Note that each layer can only return one tensor in `forward()`.
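For instance, a dataset feeding a BERT-like pipeline could return a `(data, label)` tuple where `data` is a dict whose keys match the `forward()` argument names; the field names and sizes below are purely illustrative:

```python
import torch
from torch.utils.data import Dataset

class ToyBertDataset(Dataset):
    """Each item is a (data, label) tuple; the dict keys ('input_ids',
    'attention_mask') must match the argument names of model.forward()
    or loss_function.forward()."""

    def __init__(self, num_samples=1024, seq_len=128, vocab_size=30522):
        self.num_samples = num_samples
        self.seq_len = seq_len
        self.vocab_size = vocab_size

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        data = {
            'input_ids': torch.randint(0, self.vocab_size, (self.seq_len,)),
            'attention_mask': torch.ones(self.seq_len, dtype=torch.long),
        }
        label = torch.randint(0, 2, (1,)).squeeze()  # labels must also be torch.Tensor (or dict)
        return data, label

# A matching first-stage signature would be, e.g.:
#     def forward(self, input_ids, attention_mask): ...
# Later stages receive the previous stage's output as the first positional
# argument, followed by the remaining dataloader fields.
```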
Optimize GPU memory usage of pipeline schedule
Add an argument, `return_output_label`, to `schedule.forward_backward_step`, `trainer.fit` and `trainer.evaluate`. When it is set to False, the model output and the labels won't be returned, which can further reduce GPU memory usage, especially when using pipeline parallelism.

Optimize loss accumulation in the pipeline schedule: accumulate `loss.detach()` rather than `loss` itself, to avoid unexpectedly large memory usage. Example:
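A toy sketch of the accumulation pattern (the model, shapes and loop below are illustrative stand-ins for the schedule's internal micro-batch loop, not its actual code):

```python
import torch
import torch.nn as nn

# Stand-ins for one pipeline stage's micro-batch loop.
model = nn.Linear(16, 4)
criterion = nn.CrossEntropyLoss()
num_micro_batches = 4

accum_loss = torch.zeros(1)
for _ in range(num_micro_batches):
    data = torch.randn(8, 16)
    label = torch.randint(0, 4, (8,))
    loss = criterion(model(data), label) / num_micro_batches
    loss.backward()
    # Accumulate a detached copy: summing `loss` itself would keep every
    # micro-batch's autograd graph alive and inflate memory usage.
    accum_loss += loss.detach()
```

When the raw outputs and labels are not needed, passing `return_output_label=False` to `trainer.fit` / `trainer.evaluate` gives the additional memory saving described above.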
Reduce communication of pipeline schedule
Add an argument `tensor_shape` to `PipelineSchedule` and `InterleavedPipelineSchedule`. You can set this argument to a `Union[torch.Size, List[int], Tuple[int]]` if the tensor shapes transmitted along the pipeline are the same and fixed during training. By setting this, the communication will be further reduced. Example:
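A hedged sketch of how the new argument could be used (the import path and the other constructor arguments are assumptions; only `tensor_shape` is the argument described above):

```python
# Sketch only: import path and other constructor arguments are assumptions.
from colossalai.engine.schedule import PipelineSchedule

# Every tensor exchanged between stages has the same fixed shape
# (micro_batch_size, seq_len, hidden_size), which presumably lets the schedule
# skip shape negotiation before each send/recv.
schedule = PipelineSchedule(num_microbatches=4,
                            tensor_shape=(4, 1024, 1536))
```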