Skip to content

Commit

Permalink
[zero] polish low level optimizer (#2473)
Browse files Browse the repository at this point in the history
  • Loading branch information
1SAA authored Jan 13, 2023
1 parent 8b7495d commit a5dc425
Show file tree
Hide file tree
Showing 7 changed files with 95 additions and 124 deletions.
30 changes: 10 additions & 20 deletions colossalai/zero/sharded_optim/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,11 @@ def split_half_float_double(tensor_list):
return buckets


def reduce_tensor_dp_group(tensor, dtype=None, dst_rank=None, pg: Optional[ProcessGroup] = None):
def reduce_tensor_dp_group(tensor: torch.Tensor,
dtype: Optional[torch.dtype] = None,
dst_local_rank: Optional[int] = None,
dst_global_rank: Optional[int] = None,
group: Optional[dist.ProcessGroup] = None):
"""
Reduce the tensor in the data parallel process group
Expand All @@ -128,36 +132,22 @@ def reduce_tensor_dp_group(tensor, dtype=None, dst_rank=None, pg: Optional[Proce
else:
tensor_to_reduce = tensor

if isinstance(pg, ProcessGroup):
group = pg.dp_process_group()
world_size = pg.dp_world_size()
else:
world_size = gpc.get_world_size(ParallelMode.DATA)
group = gpc.get_group(ParallelMode.DATA)

world_size = dist.get_world_size(group=group)
tensor_to_reduce.div_(world_size)

# if rank is None, all reduce will be used
# else, reduce is used
use_all_reduce = dst_rank is None
use_all_reduce = dst_local_rank is None

if use_all_reduce:
dist.all_reduce(tensor_to_reduce, group=group)
else:
if pg is not None:
ranks_in_group = pg.dp_rank_list()
else:
ranks_in_group = gpc.get_ranks_in_group(ParallelMode.DATA)
global_rank = ranks_in_group[dst_rank]
dist.reduce(tensor=tensor_to_reduce, dst=global_rank, group=group)
dist.reduce(tensor=tensor_to_reduce, dst=dst_global_rank, group=group)

# recover the original dtype
if tensor.dtype != dtype and tensor is not tensor_to_reduce:
if pg is not None:
local_rank = pg.dp_local_rank()
else:
local_rank = gpc.get_local_rank(ParallelMode.DATA)
if use_all_reduce or dst_rank == local_rank:
local_rank = dist.get_rank(group=group)
if use_all_reduce or dst_local_rank == local_rank:
tensor.copy_(tensor_to_reduce)

return tensor
Expand Down
17 changes: 5 additions & 12 deletions colossalai/zero/sharded_optim/bookkeeping/base_store.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,12 @@
from typing import Optional

from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.tensor import ProcessGroup
import torch.distributed as dist
from torch.distributed import ProcessGroup


class BaseStore:

def __init__(self, pg: Optional[ProcessGroup] = None):
if isinstance(pg, ProcessGroup):
self._world_size = pg.dp_world_size()
self._local_rank = pg.dp_local_rank()
else:
self._world_size = gpc.get_world_size(ParallelMode.DATA)
self._local_rank = gpc.get_local_rank(ParallelMode.DATA)
def __init__(self, torch_pg: ProcessGroup):
self._world_size = dist.get_world_size(group=torch_pg)
self._local_rank = dist.get_rank(group=torch_pg)

@property
def world_size(self):
Expand Down
8 changes: 3 additions & 5 deletions colossalai/zero/sharded_optim/bookkeeping/bucket_store.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
from typing import Optional

from colossalai.tensor import ProcessGroup
from torch.distributed import ProcessGroup

from .base_store import BaseStore


class BucketStore(BaseStore):

def __init__(self, pg: Optional[ProcessGroup] = None):
super().__init__(pg)
def __init__(self, torch_pg: ProcessGroup):
super().__init__(torch_pg)
self._grads = dict()
self._params = dict()
self._num_elements_in_bucket = dict()
Expand Down
9 changes: 4 additions & 5 deletions colossalai/zero/sharded_optim/bookkeeping/parameter_store.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
from typing import List, Optional
from typing import List

from torch import Tensor

from colossalai.tensor import ProcessGroup
from torch.distributed import ProcessGroup

from .base_store import BaseStore


class ParameterStore(BaseStore):

def __init__(self, pg: Optional[ProcessGroup] = None):
super().__init__(pg)
def __init__(self, torch_pg: ProcessGroup):
super().__init__(torch_pg)
# param partitioning data structures
self._fp16_param_to_rank = dict()
self._rank_groupid_to_fp16_param_list = dict()
Expand Down
135 changes: 69 additions & 66 deletions colossalai/zero/sharded_optim/low_level_optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.tensor import ProcessGroup
from colossalai.tensor import ColoParameter, ProcessGroup
from colossalai.utils.cuda import get_current_device

from ._utils import (
Expand All @@ -34,32 +34,21 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
def __init__(
self,
optimizer: Optimizer,
pg: Optional[ProcessGroup] = None,
# grad scaler config
initial_scale=2**16,
min_scale=1,
growth_factor=2,
backoff_factor=0.5,
growth_interval=2000,
hysteresis=2,
initial_scale: int = 2**16, # grad scaler config
min_scale: int = 1,
growth_factor: float = 2.,
backoff_factor: float = .5,
growth_interval: int = 2000,
hysteresis: int = 2,
max_scale: int = 2**24,

# grad clipping
clip_grad_norm=0.0,
verbose=False,

# communication
reduce_bucket_size=1024 * 1024,
communication_dtype=None,
overlap_communication=False,

# stage 2
partition_grad=False,
# cpu offload
cpu_offload=False,

# forced dtype
forced_dtype=None):
clip_grad_norm: float = 0.0, # grad clipping
verbose: bool = False,
reduce_bucket_size: int = 1024 * 1024, # communication
communication_dtype: Optional[torch.dtype] = None,
overlap_communication: bool = False,
partition_grad: bool = False, # stage 2
cpu_offload: bool = False, # cpu offload
forced_dtype: Optional[torch.dtype] = None):

# TODO: add support for
# 1. fp16 master weights
Expand All @@ -76,31 +65,30 @@ def __init__(

self._cpu_offload = cpu_offload

self._pg = pg
if isinstance(pg, ProcessGroup):
self._local_rank = pg.dp_local_rank()
self._world_size = pg.dp_world_size()
self._dp_group = pg.dp_process_group()
if pg.tp_world_size() > 1:
self._mp_group = pg.tp_process_group()
else:
self._mp_group = None
elif pg is None:
colo_pg = self._search_colo_process_group()
if isinstance(colo_pg, ProcessGroup):
self._local_rank = colo_pg.dp_local_rank()
self._world_size = colo_pg.dp_world_size()
self._dp_global_ranks = colo_pg.get_ranks_in_dp()
self._dp_torch_group = colo_pg.dp_process_group()
self._mp_torch_group = None
if colo_pg.tp_world_size() > 1:
self._mp_torch_group = colo_pg.tp_process_group()
elif colo_pg is None:
dp_parallel_mode = ParallelMode.DATA
mp_parallel_mode = ParallelMode.MODEL

self._dp_parallel_mode = dp_parallel_mode
self._mp_parallel_mode = mp_parallel_mode
self._local_rank = gpc.get_local_rank(dp_parallel_mode)
self._world_size = gpc.get_world_size(dp_parallel_mode)

self._dp_group = gpc.get_group(dp_parallel_mode)
self._dp_global_ranks = gpc.get_ranks_in_group(dp_parallel_mode)
self._dp_torch_group = gpc.get_group(dp_parallel_mode)
self._mp_torch_group = None
if gpc.is_initialized(mp_parallel_mode) and gpc.get_world_size(mp_parallel_mode) > 1:
self._mp_group = gpc.get_group(mp_parallel_mode)
else:
self._mp_group = None
self._mp_torch_group = gpc.get_group(mp_parallel_mode)
else:
raise TypeError(f"pg should be None or a ProcesGroup")
raise NotImplementedError
# fp16 and fp32 params for mixed precision training
self._fp16_param_groups = dict()
self._fp32_flat_param_groups_of_current_rank = dict()
Expand Down Expand Up @@ -136,14 +124,9 @@ def __init__(

# ParameterStore will manage the tensor buffers used for zero
# it will not manage the tensors used by mixed precision training
if self._pg is not None:
self._param_store = ParameterStore(self._pg)
self._grad_store = GradientStore(self._pg)
self._bucket_store = BucketStore(self._pg)
else:
self._param_store = ParameterStore(self._dp_parallel_mode)
self._grad_store = GradientStore(self._dp_parallel_mode)
self._bucket_store = BucketStore(self._dp_parallel_mode)
self._param_store = ParameterStore(self._dp_torch_group)
self._grad_store = GradientStore(self._dp_torch_group)
self._bucket_store = BucketStore(self._dp_torch_group)

# iterate over the param group in the optimizer
# partition these param groups for data parallel training
Expand Down Expand Up @@ -224,6 +207,30 @@ def loss_scale(self):
def num_param_groups(self):
return len(self._fp16_param_groups)

def _sanity_checks(self):
assert torch.cuda.is_available(), 'CUDA is required'
for param_group in self.optim.param_groups:
group_params = param_group['params']
for param in group_params:
assert param.dtype == self._dtype, \
f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`"

def _search_colo_process_group(self):
colo_flag = False
colo_pg = None
for param_group in self.optim.param_groups:
group_params = param_group['params']
for param in group_params:
if isinstance(param, ColoParameter):
colo_flag = True
if colo_pg is None:
colo_pg = param.get_process_group()
else:
assert colo_pg == param.get_process_group(), "All parameters should be in a same process group"
elif colo_flag:
raise RuntimeError("All parameters should be ColoParameter if you use ColoParameter.")
return colo_pg

def _partition_param_list(self, param_list):
params_per_rank = [[] for _ in range(self._world_size)]
numel_per_rank = [0 for _ in range(self._world_size)]
Expand All @@ -241,14 +248,6 @@ def _partition_param_list(self, param_list):
self._logger.info(f'Number of elements on ranks: {numel_per_rank}', ranks=[0])
return params_per_rank

def _sanity_checks(self):
assert torch.cuda.is_available(), 'CUDA is required'
for param_group in self.optim.param_groups:
group_params = param_group['params']
for param in group_params:
assert param.dtype == self._dtype, \
f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`"

###########################################################
# Backward Reduction Hook
###########################################################
Expand Down Expand Up @@ -384,10 +383,14 @@ def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank):

with torch.cuda.stream(stream):
flat = bucket.flatten()
reduce_global_rank = None
if reduce_rank is not None:
reduce_global_rank = self._dp_global_ranks[reduce_rank]
reduced_flat = reduce_tensor_dp_group(tensor=flat,
dtype=self._communication_dtype,
dst_rank=reduce_rank,
pg=self._pg)
dst_local_rank=reduce_rank,
dst_global_rank=reduce_global_rank,
group=self._dp_torch_group)

# update the reduced tensor
if reduce_rank is None or reduce_rank == self._local_rank:
Expand Down Expand Up @@ -456,8 +459,8 @@ def step(self, closure=None):
norm_group = compute_norm(gradients=self._grad_store._averaged_gradients[group_id],
params=self._param_store.get_fp16_params_by_rank_group(group_id=group_id,
rank=self._local_rank),
dp_group=self._dp_group,
mp_group=self._mp_group)
dp_group=self._dp_torch_group,
mp_group=self._mp_torch_group)
norm_groups.append(norm_group)

# create flat gradient for the flat fp32 params
Expand Down Expand Up @@ -497,7 +500,7 @@ def step(self, closure=None):
for group_id in range(self.num_param_groups):
for rank in range(self._world_size):
fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id)
handle = dist.broadcast(fp16_param, src=rank, group=self._dp_group, async_op=True)
handle = dist.broadcast(fp16_param, src=rank, group=self._dp_torch_group, async_op=True)
handles.append(handle)

for handle in handles:
Expand All @@ -519,11 +522,11 @@ def _check_overflow(self):
break

# all-reduce across dp group
dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._dp_group)
dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._dp_torch_group)

# all-reduce over model parallel group
if self._mp_group:
dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._mp_group)
if self._mp_torch_group:
dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._mp_torch_group)

if self._found_overflow.item() > 0:
return True
Expand Down
12 changes: 3 additions & 9 deletions tests/test_zero/low_level_zero/test_grad_acc.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,15 @@ def exam_zero_1_2_grad_acc():
# create model
zero1_model = TestModel().cuda()
zero2_model = copy.deepcopy(zero1_model)
pg = ProcessGroup()
# create optimizer
zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1)
zero2_optimizer = torch.optim.Adam(zero2_model.parameters(), lr=1)
zero1_optimizer = LowLevelZeroOptimizer(zero1_optimizer,
pg=pg,
overlap_communication=True,
initial_scale=32,
clip_grad_norm=1.0,
verbose=True)
zero2_optimizer = LowLevelZeroOptimizer(zero2_optimizer,
pg=pg,
overlap_communication=True,
partition_grad=True,
initial_scale=32,
Expand Down Expand Up @@ -86,7 +83,7 @@ def fwd_bwd_func(number, cur_data):
assert torch.equal(z1p.data, z2p.data)


def exam_zero_1_grad_acc(use_pg=True):
def exam_zero_1_grad_acc():
local_rank = torch.distributed.get_rank()
grad_scale = 32
seed_all(2008)
Expand All @@ -105,9 +102,7 @@ def exam_zero_1_grad_acc(use_pg=True):
# we only test stage 1 here
# in `check_sharded_param_consistency.py`, we will test whether
# level 1 and 2 will produce exactly the same results
pg = ProcessGroup() if use_pg else None #ProcessGroup()
zero_optimizer = LowLevelZeroOptimizer(zero_optimizer,
pg=pg,
overlap_communication=False,
initial_scale=grad_scale,
reduce_bucket_size=262144,
Expand Down Expand Up @@ -158,9 +153,8 @@ def fwd_bwd_func(number, cur_data, check_flag):
def run_dist(rank, world_size, port):
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')

exam_zero_1_grad_acc(True)
exam_zero_1_grad_acc(False)
# exam_zero_1_2_grad_acc()
exam_zero_1_grad_acc()
exam_zero_1_2_grad_acc()


@pytest.mark.dist
Expand Down
Loading

0 comments on commit a5dc425

Please sign in to comment.