Optimize pipeline schedule #94

Merged: 14 commits merged into main from feature/pipeline on Dec 30, 2021

Conversation

ver217 (Member) commented on Dec 29, 2021:

Add pipeline shared module wrapper

This feature is especially useful when training GPT/BERT, where the word embedding layer is used by both the first and the last pipeline stage.

Usage:

# Assumes PipelineGPT1D, PipelineSharedModuleWrapper and _partition_uniform are importable
# from your model/builder code, and that num_layers, num_chunks, kwargs and device are defined.
import torch.nn as nn
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode

pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
rank = gpc.get_global_rank()

# The first and the last pipeline stages share the word embedding.
wrapper = PipelineSharedModuleWrapper([0, pipeline_size - 1])

# Split num_layers uniformly over the pipeline stages; each stage may hold several chunks.
parts = _partition_uniform(num_layers, pipeline_size, num_chunks)[pipeline_rank]
models = []
for start, end in parts:
    kwargs['num_layers'] = end - start
    kwargs['first'] = start == 0
    kwargs['last'] = end == num_layers
    print(f'==> Rank{rank} build layer {start}-{end}, total {num_layers}')
    chunk = PipelineGPT1D(**kwargs).to(device)
    if start == 0:
        # On the first stage, register the word embedding as the shared module.
        wrapper.register_module(chunk.embedding.word_embeddings)
    elif end == num_layers:
        # On the last stage, register the head, which uses the word embedding weights.
        wrapper.register_module(chunk.head)
    models.append(chunk)
if len(models) == 1:
    model = models[0]
else:
    model = nn.ModuleList(models)

PipelineSharedModuleWrapper must be initialized on all ranks, and PipelineSharedModuleWrapper.register_module() should be called only on the ranks that share the module. Modules have to be moved to the corresponding device (according to your distributed backend) before register_module() is called.

Update load_batch() of schedule

We updated the batch-loading rule in the schedule to support GPT/BERT training with pipeline parallelism.

Please make sure that each item your dataset returns is a (data, label) tuple, where data and label are each a torch.Tensor or a dict. When sync_data is set to True in the schedule, the values of the dict must be torch.Tensor. Note that when your dataset returns a dict, its keys must match the argument names of your model.forward() or loss_function.forward().
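For illustration, a minimal sketch of a dataset that follows this rule, assuming a model whose forward() takes input_ids and attention_mask (the class and key names are hypothetical, not part of this PR):

import torch
from torch.utils.data import Dataset

class ToyGPTDataset(Dataset):
    """Returns (data, label) where data is a dict whose keys match model.forward()."""

    def __init__(self, num_samples=1024, seq_len=128, vocab_size=50257):
        self.num_samples = num_samples
        self.seq_len = seq_len
        self.vocab_size = vocab_size

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        input_ids = torch.randint(0, self.vocab_size, (self.seq_len,))
        attention_mask = torch.ones(self.seq_len, dtype=torch.long)
        data = {'input_ids': input_ids, 'attention_mask': attention_mask}
        label = input_ids.clone()  # e.g. language modelling reuses the input ids as labels
        return data, label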

When using pipeline parallelism, the input of the first layer comes from the dataloader. For the other layers, the first argument of forward() is the output of the previous pipeline stage, and the remaining arguments come from the dataloader. Note that each layer can only return one tensor from forward().
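A minimal sketch of what a non-first pipeline stage could look like under this rule (the module and argument names are illustrative only):

import torch.nn as nn

class MiddleStage(nn.Module):
    """Illustrative pipeline stage: the first forward() argument is the previous
    stage's output; the remaining arguments come from the dataloader dict."""

    def __init__(self, hidden_size=768):
        super().__init__()
        self.block = nn.Linear(hidden_size, hidden_size)  # stand-in for transformer layers

    def forward(self, hidden_states, attention_mask=None):
        # Only a single tensor may be returned to the next stage.
        return self.block(hidden_states)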

Optimize GPU memory usage of pipeline schedule

Added an argument (return_output_label) to schedule.forward_backward_step, trainer.fit and trainer.evaluate. When it is set to False, the model output and the labels are not returned, which further reduces GPU memory usage, especially when using pipeline parallelism.

Optimized loss accumulation in the pipeline schedule: the loss is now accumulated via loss.detach() to avoid unexpectedly large memory usage.

Example:

trainer.fit(
    train_dataloader=train_dataloader,
    epochs=num_epochs,
    test_interval=1,
    hooks=hook_list,
    display_progress=True,
    return_output_label=False
)
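As for the loss-accumulation change above, the point is to detach each micro-batch loss before adding it to the running total, so the accumulator does not keep the autograd graph alive. A minimal standalone sketch of the pattern (not the schedule's actual code):

import torch

model = torch.nn.Linear(4, 1)
accum_loss = torch.zeros(1)
for _ in range(8):                      # pretend each iteration is a micro-batch
    out = model(torch.randn(2, 4))
    loss = out.pow(2).mean()
    loss.backward()                     # backward still uses the graph
    accum_loss += loss.detach()         # accumulate a graph-free copy for logging
print(accum_loss / 8)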

Reduce communication of pipeline schedule

Added an argument tensor_shape to PipelineSchedule and InterleavedPipelineSchedule. You can set this argument to a Union[torch.Size, List[int], Tuple[int]] if the tensor shapes transmitted along the pipeline are the same and fixed during training. Setting it further reduces communication.

Example:

schedule = InterleavedPipelineSchedule(num_micro_batches, num_model_chunks,
                                       tensor_shape=tensor_shape)
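For a GPT-like model, the transmitted tensor is typically the hidden states of one micro-batch, so tensor_shape could be built as below (the hyper-parameter names and values are assumptions for illustration; the exact shape depends on your model):

batch_size = 8
num_micro_batches = 4
seq_len = 1024
hidden_size = 768

micro_batch_size = batch_size // num_micro_batches
tensor_shape = (micro_batch_size, seq_len, hidden_size)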

ver217 requested a review from FrankLeeeee on December 29, 2021.
FrankLeeeee (Contributor) commented:

Hi @ver217. In your demo, I do not see PipelineGPT1D in the changed files. Also, if _partition_uniform will be called by the user, it should be public, yet it starts with an underscore. I have one question as well:

colossalai/engine/gradient_handler/_pipeline_parallel_gradient_handler.py

1. On line 39, is division by the world size required before the all-reduce?
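For context on the question: dividing each rank's gradient by the world size and then summing with all-reduce yields the average, the same result (up to floating-point error) as summing first and dividing afterwards. A minimal standalone sketch of that pattern with torch.distributed (not necessarily what the handler does):

import torch
import torch.distributed as dist

def allreduce_mean(grad: torch.Tensor) -> torch.Tensor:
    """Average a gradient tensor across all ranks in the default process group."""
    grad = grad / dist.get_world_size()          # pre-divide on every rank...
    dist.all_reduce(grad, op=dist.ReduceOp.SUM)  # ...then SUM produces the mean
    return grad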

FrankLeeeee merged commit 96780e6 into main on Dec 30, 2021.
ver217 deleted the feature/pipeline branch on January 4, 2022.