Optimize pipeline schedule (hpcaitech#94)
* add pipeline shared module wrapper and update load batch

* added model parallel process group for amp and clip grad (hpcaitech#86)

* added model parallel process group for amp and clip grad

* update amp and clip with model parallel process group

* remove pipeline_prev/next group (hpcaitech#88)

* micro batch offload

* optimize pipeline gpu memory usage

* pipeline can receive tensor shape (hpcaitech#93)

* optimize pipeline gpu memory usage

* fix grad accumulation step counter

* rename classes and functions

Co-authored-by: Frank Lee <somerlee.9@gmail.com>
ver217 and FrankLeeeee authored Dec 30, 2021
1 parent e5b9f9a commit 96780e6
Showing 29 changed files with 423 additions and 290 deletions.
7 changes: 1 addition & 6 deletions colossalai/amp/naive_amp/_fp16_optimizer.py
@@ -359,12 +359,7 @@ def _unscale_main_grads_and_check_for_nan(self):
         # Update across all model parallel instances.
         torch.distributed.all_reduce(self.found_inf,
                                      op=torch.distributed.ReduceOp.MAX,
-                                     group=gpc.get_group(ParallelMode.TENSOR))
-
-        if is_using_pp():
-            torch.distributed.all_reduce(self.found_inf,
-                                         op=torch.distributed.ReduceOp.MAX,
-                                         group=gpc.get_group(ParallelMode.PIPELINE))
+                                     group=gpc.get_group(ParallelMode.MODEL))

         # Check for nan.
         found_inf_flag = (self.found_inf.item() > 0)
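The change above collapses the overflow check into a single MAX all-reduce over the combined model-parallel group, instead of reducing over the tensor group and then, when pipeline parallelism is on, over the pipeline group again. A minimal standalone sketch of that pattern in plain torch.distributed (model_group is a hypothetical stand-in for gpc.get_group(ParallelMode.MODEL), not Colossal-AI's API):

import torch
import torch.distributed as dist

def step_is_safe(local_grads, model_group):
    """Return True on every rank iff no rank in the same model replica saw inf/nan grads."""
    found_inf = torch.zeros(1, device=local_grads[0].device)
    if any(not torch.isfinite(g).all() for g in local_grads):
        found_inf.fill_(1.0)
    # One MAX all-reduce over the tensor x pipeline (MODEL) group is enough:
    # if any participating rank overflowed, every rank sees found_inf == 1.
    dist.all_reduce(found_inf, op=dist.ReduceOp.MAX, group=model_group)
    return found_inf.item() == 0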
13 changes: 9 additions & 4 deletions colossalai/amp/torch_amp/_grad_scaler.py
@@ -11,6 +11,7 @@
 from colossalai.context import ParallelMode
 import torch.distributed as dist
 from colossalai.core import global_context as gpc
+from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors


 class _MultiDeviceReplicator(object):
@@ -247,10 +248,14 @@ def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16):
                                                 device),
                                             per_device_inv_scale.get(device))
         # For tensor parallel paramters it should be all-reduced over tensor parallel process group
-        if gpc.is_initialized(ParallelMode.TENSOR) and gpc.get_world_size(ParallelMode.TENSOR) > 1:
-            for tensor in per_device_found_inf._per_device_tensors.values():
-                dist.all_reduce(tensor, op=dist.ReduceOp.MAX,
-                                group=gpc.get_group(ParallelMode.TENSOR))
+        if gpc.is_initialized(ParallelMode.MODEL) and gpc.get_world_size(ParallelMode.MODEL) > 1:
+            vals = [val for val in per_device_found_inf._per_device_tensors.values()]
+            coalesced = _flatten_dense_tensors(vals)
+            dist.all_reduce(coalesced,
+                            op=dist.ReduceOp.MAX,
+                            group=gpc.get_group(ParallelMode.MODEL))
+            for buf, synced in zip(vals, _unflatten_dense_tensors(coalesced, vals)):
+                buf.copy_(synced)
         return per_device_found_inf._per_device_tensors

     def unscale_(self, optimizer):
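The rewritten branch also coalesces the per-device found-inf flags so that a single collective covers the whole MODEL group. A self-contained sketch of the flatten / all-reduce / unflatten idiom used here, assuming the tensors share a dtype and device and that group is whatever process group applies:

import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

def allreduce_max_coalesced(tensors, group=None):
    """MAX-all-reduce a list of same-dtype, same-device tensors with one collective call."""
    coalesced = _flatten_dense_tensors(tensors)  # pack everything into one flat buffer
    dist.all_reduce(coalesced, op=dist.ReduceOp.MAX, group=group)
    # copy the reduced values back into the original tensors
    for buf, synced in zip(tensors, _unflatten_dense_tensors(coalesced, tensors)):
        buf.copy_(synced)
    return tensors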
16 changes: 8 additions & 8 deletions colossalai/builder/pipeline.py
@@ -112,7 +112,7 @@ def _binary_search(weights, num):
     return intervals


-def _partition_uniform(num_items, pipeline_parallel_size, num_chunks):
+def partition_uniform(num_items, pipeline_parallel_size, num_chunks):
     assert num_items % num_chunks == 0, \
         "Layer length should be divided by the number of chunks, otherwise parameter method is recomended"

@@ -134,11 +134,11 @@ def _partition_uniform(num_items, pipeline_parallel_size, num_chunks):
     return parts


-def _partition_balanced(weights, pipeline_parallel_size, num_chunks):
+def partition_balanced(weights, pipeline_parallel_size, num_chunks):
     num_total = pipeline_parallel_size * num_chunks
     num_items = len(weights)
     if num_items <= num_total:
-        return _partition_uniform(num_items, pipeline_parallel_size, num_chunks)
+        return partition_uniform(num_items, pipeline_parallel_size, num_chunks)

     intervals = _binary_search(weights, num_total)

@@ -151,7 +151,7 @@ def _partition_balanced(weights, pipeline_parallel_size, num_chunks):
     return parts


-def _count_layer_params(layers):
+def count_layer_params(layers):
     """Count the number of parameters in each layer
     """
     param_counts = [0] * len(layers)
@@ -201,11 +201,11 @@ def build_pipeline_model_from_cfg(config, num_chunks: int = 1, partition_method:
     # Make a partition
     if method == 'layer':
         num_layers = len(layers)
-        parts = _partition_uniform(num_layers, pipeline_parallel_size, num_chunks)
+        parts = partition_uniform(num_layers, pipeline_parallel_size, num_chunks)
     elif method == 'parameter':
-        param_counts = _count_layer_params(layers)
+        param_counts = count_layer_params(layers)
         # print_rank_0(param_counts)
-        parts = _partition_balanced(param_counts, pipeline_parallel_size, num_chunks)
+        parts = partition_balanced(param_counts, pipeline_parallel_size, num_chunks)
     else:
         raise ValueError("Method should be a pre-set string in [layer, parameter]")

@@ -250,7 +250,7 @@ def build_pipeline_model(layers: nn.Sequential, num_chunks: int = 1, verbose: bo
     """
     pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE)
     pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
-    partitions = _partition_uniform(len(layers), pipeline_parallel_size, num_chunks)
+    partitions = partition_uniform(len(layers), pipeline_parallel_size, num_chunks)
     module_list = []
     for start, end in partitions[pipeline_rank]:
         module_list.append(nn.Sequential(*layers[start:end]))
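For readers unfamiliar with the partition helpers being renamed here: build_pipeline_model slices a flat nn.Sequential using the (start, end) index pairs that partition_uniform returns for each pipeline rank. A rough, illustrative re-implementation of the uniform case (not the library's exact code, which also supports the 'parameter'-balanced method shown above):

def uniform_partition(num_items, pipeline_parallel_size, num_chunks):
    """Illustrative sketch: split num_items layers into equal slices,
    returning parts[rank] as a list of (start, end) pairs, one per chunk."""
    total_chunks = pipeline_parallel_size * num_chunks
    assert num_items % total_chunks == 0, "layer count must divide evenly for the uniform method"
    chunk_size = num_items // total_chunks
    parts = [[] for _ in range(pipeline_parallel_size)]
    for chunk in range(num_chunks):
        for rank in range(pipeline_parallel_size):
            start = (chunk * pipeline_parallel_size + rank) * chunk_size
            parts[rank].append((start, start + chunk_size))
    return parts

# e.g. 8 layers over 2 pipeline stages with 1 chunk each:
# uniform_partition(8, 2, 1) -> [[(0, 4)], [(4, 8)]]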
3 changes: 2 additions & 1 deletion colossalai/constants.py
@@ -14,7 +14,8 @@
     '2d': 'Initializer_2D',
     '2.5d': 'Initializer_2p5D',
     '3d': 'Initializer_3D',
-    'sequence': 'Initializer_Sequence'
+    'sequence': 'Initializer_Sequence',
+    'model': 'Initializer_Model'
 }

 # 1D parallel
3 changes: 3 additions & 0 deletions colossalai/context/parallel_context.py
@@ -394,6 +394,9 @@ def init_parallel_groups(self):
         # LSG: init data parallel process group for compatibility with other parallel module such as zero
         pg_init.append(dict(type=INITIALIZER_MAPPING['data']))

+        # LSG: init model parallel process group for compatibility with amp and clip grad
+        pg_init.append(dict(type=INITIALIZER_MAPPING['model']))
+
         if self.pipeline_parallel_size > 1:
             pg_init.append(dict(type=INITIALIZER_MAPPING['pipeline']))
             pg_init.append(dict(type=INITIALIZER_MAPPING['tensor']))
6 changes: 4 additions & 2 deletions colossalai/context/parallel_mode.py
@@ -14,10 +14,12 @@ class ParallelMode(Enum):
     # common parallel
     DATA = 'data'

+    # model parallel - containing tensor and pipeline parallel groups
+    # this is added to facilitate amp and grad clipping in hybrid parallel
+    MODEL = 'model'
+
     # pipeline parallel
     PIPELINE = 'pipe'
-    PIPELINE_PREV = 'pipe_prev'
-    PIPELINE_NEXT = 'pipe_next'

     # containing all ranks in tensor parallel
     TENSOR = 'tensor'
3 changes: 2 additions & 1 deletion colossalai/context/process_group_initializer/__init__.py
@@ -6,10 +6,11 @@
 from .initializer_pipeline import Initializer_Pipeline
 from .initializer_sequence import Initializer_Sequence
 from .initializer_tensor import Initializer_Tensor
+from .initializer_model import Initializer_Model
 from .process_group_initializer import ProcessGroupInitializer

 __all__ = [
     'Initializer_Tensor', 'Initializer_Sequence', 'Initializer_Pipeline',
     'Initializer_Data', 'Initializer_2p5D', 'Initializer_2D', 'Initializer_3D',
-    'Initializer_1D', 'ProcessGroupInitializer'
+    'Initializer_1D', 'ProcessGroupInitializer', 'Initializer_Model'
 ]
43 changes: 43 additions & 0 deletions colossalai/context/process_group_initializer/initializer_model.py
@@ -0,0 +1,43 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import torch.distributed as dist

from colossalai.context import Config
from colossalai.registry import DIST_GROUP_INITIALIZER
from .process_group_initializer import ProcessGroupInitializer
from ..parallel_mode import ParallelMode


@DIST_GROUP_INITIALIZER.register_module
class Initializer_Model(ProcessGroupInitializer):
    '''A ProcessGroupInitializer for model parallelism (model parallel group contains pipeline and tensor parallel groups).
    '''

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model_parallel_size = self.tensor_parallel_size * self.pipeline_parallel_size
        self.num_group = self.world_size // self.model_parallel_size

    def init_dist_group(self):
        '''Initialize 1D tensor parallel groups, and assign local_ranks and groups to each gpu.
        :return: (local_rank, group_world_size, process_group, ranks_in_group, mode)
        :rtype: tuple
        '''
        local_rank = None
        ranks_in_group = None
        process_group = None
        group_world_size = None
        mode = ParallelMode.MODEL

        for i in range(self.num_group):
            ranks = [i * self.model_parallel_size + j for j in range(self.model_parallel_size)]
            group = dist.new_group(ranks)

            if self.rank in ranks:
                local_rank = ranks.index(self.rank)
                group_world_size = len(ranks)
                process_group = group
                ranks_in_group = ranks
        return local_rank, group_world_size, process_group, ranks_in_group, mode
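With the rank layout used by this initializer, each MODEL group is simply a contiguous block of tensor_parallel_size * pipeline_parallel_size ranks, repeated once per data-parallel replica. A small sketch reproducing just that rank arithmetic, with an assumed example configuration of 8 GPUs:

def model_parallel_groups(world_size, tensor_parallel_size, pipeline_parallel_size):
    """Reproduce the rank grouping performed by Initializer_Model above."""
    model_parallel_size = tensor_parallel_size * pipeline_parallel_size
    num_group = world_size // model_parallel_size
    return [[i * model_parallel_size + j for j in range(model_parallel_size)]
            for i in range(num_group)]

# assumed example: 8 GPUs, tensor parallel 2, pipeline parallel 2 (data parallel 2)
print(model_parallel_groups(8, 2, 2))  # [[0, 1, 2, 3], [4, 5, 6, 7]]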
colossalai/context/process_group_initializer/initializer_pipeline.py
@@ -36,28 +36,4 @@ def init_dist_group(self):
                                process_group, ranks_in_group,
                                ParallelMode.PIPELINE)))

-                for k in range(pipe_group_size):
-                    first = pipe_ranks[k]
-                    second = pipe_ranks[(k + 1) % pipe_group_size]
-                    ranks = [first, second]
-                    group = dist.new_group(ranks)
-                    if self.rank == first:
-                        local_rank = 0
-                        group_world_size = 2
-                        process_group = group
-                        ranks_in_group = ranks
-                        dist_settings.append(
-                            tuple((local_rank, group_world_size,
-                                   process_group, ranks_in_group,
-                                   ParallelMode.PIPELINE_NEXT)))
-                    elif self.rank == second:
-                        local_rank = 1
-                        group_world_size = 2
-                        process_group = group
-                        ranks_in_group = ranks
-                        dist_settings.append(
-                            tuple((local_rank, group_world_size,
-                                   process_group, ranks_in_group,
-                                   ParallelMode.PIPELINE_PREV)))
-
         return dist_settings
7 changes: 2 additions & 5 deletions colossalai/engine/_base_engine.py
@@ -2,15 +2,12 @@
 # -*- encoding: utf-8 -*-


-import torch
-from typing import List
 from torch.nn import Module
 from torch.nn.modules.loss import _Loss
 from torch.optim import Optimizer

 from colossalai.builder import build_gradient_handler
 from colossalai.logging import get_dist_logger
-from colossalai.utils import is_using_ddp, is_using_pp
 from torch import Tensor


@@ -84,15 +81,15 @@ def step(self):

    def backward(self, loss: Tensor):
        """Start backward propagation given the loss value computed by a loss function
        :param loss: loss value computed by a loss function
        :type loss: :class:`torch.Tensor`
        """
        return self.optimizer.backward(loss)

    def backward_by_grad(self, tensor, grad):
        """Start backward propagation given the gradient of the output tensor
        :param loss: output tensor
        :type loss: :class:`torch.Tensor`
        :param grad: gradient passed back to the output
4 changes: 3 additions & 1 deletion colossalai/engine/gradient_handler/__init__.py
@@ -1,5 +1,7 @@
 from ._base_gradient_handler import BaseGradientHandler
 from ._data_parallel_gradient_handler import DataParallelGradientHandler
 from ._zero_gradient_handler import ZeROGradientHandler
+from ._pipeline_parallel_gradient_handler import PipelineSharedModuleGradientHandler

-__all__ = ['BaseGradientHandler', 'DataParallelGradientHandler', 'ZeROGradientHandler']
+__all__ = ['BaseGradientHandler', 'DataParallelGradientHandler',
+           'ZeROGradientHandler', 'PipelineSharedModuleGradientHandler']
41 changes: 41 additions & 0 deletions colossalai/engine/gradient_handler/_pipeline_parallel_gradient_handler.py
@@ -0,0 +1,41 @@
#!/usr/bin/env python

import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from colossalai.core import global_context as gpc
from colossalai.registry import GRADIENT_HANDLER
from ._base_gradient_handler import BaseGradientHandler
from collections import defaultdict


@GRADIENT_HANDLER.register_module
class PipelineSharedModuleGradientHandler(BaseGradientHandler):
    """A helper class to handle all-reduce operations in sub parallel groups.
    A all-reduce collective communication will be operated in
    :func:`handle_gradient` among all sub pipeline parallel groups.
    For better performance, it bucketizes the gradients of all parameters that are
    the same type to improve the efficiency of communication.
    """

    def handle_gradient(self):
        """A method running a all-reduce operation in sub pipeline parallel groups.
        """
        if gpc.pipeline_parallel_size > 1:
            # bucketize and all-reduce
            buckets = defaultdict(lambda: defaultdict(list))
            # Pack the buckets.
            for param in self._model.parameters():
                group = getattr(param, 'pipeline_shared_module_pg', None)
                if param.requires_grad and param.grad is not None and group is not None:
                    tp = param.data.type()
                    buckets[group][tp].append(param)

            # For each bucket, all-reduce and copy all-reduced grads.
            for group, group_buckets in buckets.items():
                for tp, bucket in group_buckets.items():
                    grads = [param.grad.data for param in bucket]
                    coalesced = _flatten_dense_tensors(grads)
                    dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=group)
                    for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
                        buf.copy_(synced)
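The handler only acts on parameters carrying a pipeline_shared_module_pg attribute; the pipeline shared-module wrapper added by this PR (not shown in this excerpt) is what attaches it. A hedged illustration of that contract, where shared_group is a hypothetical process group spanning the stages that hold copies of a tied module:

import torch.nn as nn

def tag_shared_parameters(module: nn.Module, shared_group):
    """Mark every parameter of a pipeline-shared module (e.g. a tied embedding)
    so PipelineSharedModuleGradientHandler will all-reduce its gradients."""
    for param in module.parameters():
        # attribute name must match what handle_gradient() looks up above
        param.pipeline_shared_module_pg = shared_group

# After each backward pass, the handler buckets tagged parameters by process group
# and dtype, flattens each bucket, SUM-all-reduces it over shared_group, and copies
# the result back so all copies of the tied module see identical gradients.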