[zero] adapt zero hooks for unsharded module (hpcaitech#699)
1SAA authored Apr 8, 2022
1 parent 896ade1 commit ee112fe
Showing 12 changed files with 70 additions and 58 deletions.
59 changes: 39 additions & 20 deletions colossalai/engine/ophooks/zero_hook.py
@@ -36,6 +36,7 @@ def __init__(self,
self._stateful_tensor_mgr = stateful_tensor_mgr

def pre_fwd_exec(self, module: torch.nn.Module, *args):

for param in module.parameters(recurse=False):
param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)

@@ -45,12 +46,15 @@ def pre_fwd_exec(self, module: torch.nn.Module, *args):
for param in module.parameters(recurse=False):
colo_model_data_tensor_move_inline(param.colo_attr.sharded_data_tensor, self.computing_device)

tensor_list = []
for param in module.parameters(recurse=False):
assert hasattr(param, 'colo_attr')
tensor_list.append(param.colo_attr.sharded_data_tensor)
self.shard_strategy.gather(tensor_list, self.process_group)
# gather sharded parameters
if module.param_is_sharded:
tensor_list = []
for param in module.parameters(recurse=False):
assert hasattr(param, 'colo_attr')
tensor_list.append(param.colo_attr.sharded_data_tensor)
self.shard_strategy.gather(tensor_list, self.process_group)

# record memory statistics
if self._memstarts_collector:
self._memstarts_collector.sample_memstats()

@@ -59,18 +63,25 @@ def pre_fwd_exec(self, module: torch.nn.Module, *args):
assert param.data.device.type == 'cuda', f"PRE FWD param.data must be on CUDA"

def post_fwd_exec(self, module: torch.nn.Module, *args):

# change tensor state to HOLD_AFTER_FWD
for param in module.parameters(recurse=False):
param.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_FWD)

tensor_list = []
for param in module.parameters(recurse=False):
assert hasattr(param, 'colo_attr')
tensor_list.append(param.colo_attr.sharded_data_tensor)
self.shard_strategy.shard(tensor_list, self.process_group)
# shard gathered parameters
if module.param_is_sharded:
tensor_list = []
for param in module.parameters(recurse=False):
assert hasattr(param, 'colo_attr')
tensor_list.append(param.colo_attr.sharded_data_tensor)
self.shard_strategy.shard(tensor_list, self.process_group)

# remove torch payload
for param in module.parameters(recurse=False):
param.colo_attr.remove_torch_payload()

def pre_bwd_exec(self, module: torch.nn.Module, input, output):

for param in module.parameters(recurse=False):
param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)

@@ -80,12 +91,15 @@ def pre_bwd_exec(self, module: torch.nn.Module, input, output):
for param in module.parameters(recurse=False):
colo_model_data_tensor_move_inline(param.colo_attr.sharded_data_tensor, self.computing_device)

tensor_list = []
for param in module.parameters(recurse=False):
assert hasattr(param, 'colo_attr')
tensor_list.append(param.colo_attr.sharded_data_tensor)
self.shard_strategy.gather(tensor_list, self.process_group)
# gather sharded parameters
if module.param_is_sharded:
tensor_list = []
for param in module.parameters(recurse=False):
assert hasattr(param, 'colo_attr')
tensor_list.append(param.colo_attr.sharded_data_tensor)
self.shard_strategy.gather(tensor_list, self.process_group)

# record memory statistics
if self._memstarts_collector:
self._memstarts_collector.sample_memstats()

@@ -94,15 +108,20 @@ def pre_bwd_exec(self, module: torch.nn.Module, input, output):
assert param.data.device.type == 'cuda', f"PRE BWD param.data must be on CUDA"

def post_bwd_exec(self, module: torch.nn.Module, input):

# change tensor state to HOLD_AFTER_BWD
for param in module.parameters(recurse=False):
param.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)

tensor_list = []
for param in module.parameters(recurse=False):
assert hasattr(param, 'colo_attr')
tensor_list.append(param.colo_attr.sharded_data_tensor)
self.shard_strategy.shard(tensor_list, self.process_group)
# shard gathered parameters
if module.param_is_sharded:
tensor_list = []
for param in module.parameters(recurse=False):
assert hasattr(param, 'colo_attr')
tensor_list.append(param.colo_attr.sharded_data_tensor)
self.shard_strategy.shard(tensor_list, self.process_group)

# remove torch payload
for param in module.parameters(recurse=False):
param.colo_attr.remove_torch_payload()

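To summarize the zero_hook.py change: parameters are gathered before and re-sharded after each forward/backward step only when module.param_is_sharded is true, so modules whose parameters were deliberately left unsharded (for example MoE experts and gates) pass through the hooks untouched. The companion change in sharded_model_v2.py below stops filtering those modules out of hook registration, so every module now receives the hook and the guard lives inside it. A minimal, self-contained sketch of that control flow (ToyShardStrategy and everything except the param_is_sharded flag are stand-ins, not the real ColossalAI API):

import torch.nn as nn


class ToyShardStrategy:
    """Stand-in for a ZeRO shard strategy; it only records what it was asked to do."""

    def __init__(self):
        self.calls = []

    def gather(self, tensor_list, process_group=None):
        self.calls.append(('gather', len(tensor_list)))

    def shard(self, tensor_list, process_group=None):
        self.calls.append(('shard', len(tensor_list)))


def pre_fwd_exec(module: nn.Module, strategy: ToyShardStrategy, process_group=None):
    # Gather only if this module's parameters are actually sharded;
    # unsharded modules keep their full parameters and are skipped.
    if getattr(module, 'param_is_sharded', False):
        strategy.gather(list(module.parameters(recurse=False)), process_group)


def post_fwd_exec(module: nn.Module, strategy: ToyShardStrategy, process_group=None):
    # Mirror image of pre_fwd_exec: only re-shard what was gathered.
    if getattr(module, 'param_is_sharded', False):
        strategy.shard(list(module.parameters(recurse=False)), process_group)


if __name__ == '__main__':
    strategy = ToyShardStrategy()
    sharded = nn.Linear(4, 4)
    sharded.param_is_sharded = True      # pretend ZeRO sharded this module
    unsharded = nn.Linear(4, 4)
    unsharded.param_is_sharded = False   # e.g. an MoE expert kept whole

    for m in (sharded, unsharded):
        pre_fwd_exec(m, strategy)
        post_fwd_exec(m, strategy)

    print(strategy.calls)   # [('gather', 2), ('shard', 2)] -- only the sharded Linear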
24 changes: 9 additions & 15 deletions colossalai/zero/init_ctx/init_context.py
@@ -135,8 +135,7 @@ def __init__(self,

super().__init__()
self.shard_strategy = shard_strategy
self.sharded_param_list = []
self.unshard_param_list = []
self.param_list = []
self.model_numel_tensor = model_numel_tensor
self.seed = seed
self.dp_process_group = gpc.get_group(ParallelMode.DATA)
@@ -210,19 +209,15 @@ def _pre_context_exec(self):
def _post_context_exec(self):
"""The callback function when exiting context.
"""
for param in self.sharded_param_list:
assert hasattr(param, 'colo_attr')
param.colo_attr.remove_torch_payload()

del self.sharded_param_list

# broadcast replicated no-shard parameters
src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[0]
for param in self.unshard_param_list:
for param in self.param_list:
assert hasattr(param, 'colo_attr')
if param.is_replicated:
if not param.colo_attr.param_is_sharded and param.is_replicated:
dist.broadcast(tensor=param.data, src=src_rank, group=self.dp_process_group)
param.colo_attr.remove_torch_payload()

del self.unshard_param_list
del self.param_list

nn.init._calculate_fan_in_and_fan_out = self.nn_fanin_fanout
torch.set_rng_state(self.cpu_rng_state)
@@ -264,10 +259,9 @@ def half_fn(t: torch.Tensor):

if self.shard_param:
self.shard_strategy.shard([param.colo_attr.sharded_data_tensor], self.dp_process_group)
param.data = param.colo_attr.sharded_data_tensor.payload
self.sharded_param_list.append(param)
else:
self.unshard_param_list.append(param)
param.data = param.colo_attr.sharded_data_tensor.payload # set param.data to payload

self.param_list.append(param)

# We must cast buffers
# If we use BN, buffers may be on CPU and Float
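Net effect in init_context.py: the separate sharded_param_list and unshard_param_list collapse into a single param_list, every parameter loses its torch payload when the context exits, and replicated parameters that stay unsharded are first broadcast from the lowest data-parallel rank so all ranks hold identical values. A condensed sketch of that exit path, reusing the names from the diff (not runnable on its own: it assumes an initialised torch.distributed process group and parameters carrying the colo_attr interface shown above):

import torch.distributed as dist


def post_context_exec(param_list, dp_process_group, src_rank):
    for param in param_list:
        assert hasattr(param, 'colo_attr')
        if not param.colo_attr.param_is_sharded and param.is_replicated:
            # keep replicated, unsharded parameters identical across DP ranks
            dist.broadcast(tensor=param.data, src=src_rank, group=dp_process_group)
        # drop the torch payload for every parameter, sharded or not
        param.colo_attr.remove_torch_payload()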
6 changes: 2 additions & 4 deletions colossalai/zero/sharded_model/sharded_model_v2.py
@@ -121,7 +121,7 @@ def __init__(self,
self._ophook_list = [
ZeroHook(self.shard_strategy, self._memstats_collector, self._stateful_tensor_mgr, self.process_group)
]
register_ophooks_recursively(self.module, self._ophook_list, filter_fn=lambda m: not m.param_is_sharded)
register_ophooks_recursively(self.module, self._ophook_list)
self.param_hook_mgr = BaseParamHookMgr(self.sharded_params)
self.param_hook_mgr.register_backward_hooks(self._grad_post_backward_hook)

@@ -366,14 +366,12 @@ def _reduce_scatter_callback(self, param: Parameter, reduced_grad: torch.Tensor):

def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
self.shard_strategy.gather([p.colo_attr.sharded_data_tensor for p in self.sharded_params], self.process_group)
prev_params = {}
for p in self.sharded_params:
prev_params[p] = p.data
p.data = p.colo_attr.sharded_data_tensor.payload
gathered_state_dict = self.module.state_dict(destination, prefix, keep_vars)
self.shard_strategy.shard([p.colo_attr.sharded_data_tensor for p in self.sharded_params], self.process_group)
for p in self.sharded_params:
p.data = prev_params[p]
p.colo_attr.remove_torch_payload()
return gathered_state_dict

def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True):
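The state_dict change in sharded_model_v2.py removes the prev_params cache: instead of restoring the old param.data after gathering, param.data is pointed at the gathered payload, the state dict is taken, the parameters are re-sharded, and the torch payload is removed again, leaving param.data as the empty sentinel rather than a stale full-size copy. A sketch of that sequence with the names from the diff (it assumes a ShardedModelV2-like object, so it is illustrative rather than standalone):

def state_dict_sketch(self, destination=None, prefix='', keep_vars=False):
    tensors = [p.colo_attr.sharded_data_tensor for p in self.sharded_params]
    self.shard_strategy.gather(tensors, self.process_group)
    for p in self.sharded_params:
        p.data = p.colo_attr.sharded_data_tensor.payload    # expose the full tensor
    gathered_state_dict = self.module.state_dict(destination, prefix, keep_vars)
    self.shard_strategy.shard(tensors, self.process_group)
    for p in self.sharded_params:
        p.colo_attr.remove_torch_payload()                  # back to the empty sentinel
    return gathered_state_dict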
8 changes: 3 additions & 5 deletions colossalai/zero/sharded_optim/sharded_optim_v2.py
@@ -268,10 +268,7 @@ def _zero_grad(self, recover_data: bool = False):
p.data = self.master_params[p].payload
p.colo_attr.sharded_data_tensor.reset_payload(
colo_model_tensor_clone(p.half(), torch.cuda.current_device()))

if not p.colo_attr.param_is_sharded:
# FIXME(hhc): add hook for unsharded parameters
p.data = p.colo_attr.sharded_data_tensor.payload
p.colo_attr.remove_torch_payload()

def sync_grad(self):
pass
@@ -351,10 +348,11 @@ def _copy_master_param_to_param_fp16(self):
# TODO() optimize this line CPU (fp32) -> GPU (fp16)
p.colo_attr.sharded_data_tensor.reset_payload(
colo_model_tensor_clone(p.half(), p.colo_attr.sharded_data_tensor.device))
p.colo_attr.remove_torch_payload()

if not is_param_sharded and not self.keep_unshard:
# We gather full fp16 param here
self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
p.data = p.colo_attr.sharded_data_tensor.payload

self.master_params[p].trans_state(TensorState.HOLD)
p.colo_attr.saved_grad.set_null()
7 changes: 6 additions & 1 deletion colossalai/zero/sharded_param/sharded_param.py
@@ -5,6 +5,11 @@
from .tensorful_state import StatefulTensor, TensorState
from typing import List

# use this tensor as empty data point for parameters
# we do not want users to use param.data when its torch payload is removed
# the empty tensor is expected to raise an error when it gets used
FAKE_EMPTY_TENSOR = torch.BoolTensor([], device='cpu')


class ShardedParamV2(object):

@@ -29,7 +34,7 @@ def get_payload_tensors(self) -> List[StatefulTensor]:
return [self._sharded_data_tensor]

def remove_torch_payload(self):
self.param.data = torch.empty([], dtype=self.param.dtype, device=self.param.device)
self.param.data = FAKE_EMPTY_TENSOR.to(self._sharded_data_tensor.device, self._sharded_data_tensor.dtype)

@property
def sharded_data_tensor(self):
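The sentinel added in sharded_param.py is a zero-element bool tensor: it occupies no storage, and any real use of param.data after the payload is removed fails immediately. That is exactly what the updated test_shard_param.py assertions below check (param.data.numel() == 0, and no extra 4 bytes of dummy storage in the memory accounting). A quick standalone illustration (the device argument from the diff is omitted here):

import torch

FAKE_EMPTY_TENSOR = torch.BoolTensor([])                # zero elements, bool dtype

param = torch.nn.Parameter(torch.randn(2, 3))
param.data = FAKE_EMPTY_TENSOR.to(dtype=param.dtype)    # mimic remove_torch_payload

print(param.data.numel())                               # 0
print(param.data.numel() * param.data.element_size())   # 0 bytes of payload
# Using it as a real tensor fails loudly, e.g. param.data.view(2, 3)
# raises a RuntimeError because a 0-element tensor cannot be viewed as (2, 3).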
1 change: 0 additions & 1 deletion tests/test_moe/test_moe_zero_init.py
@@ -66,7 +66,6 @@ def run_moe_zero_init(init_device_type, shard_strategy_class):
# the parameters in moe experts and its gate should not be sharded
if ('experts' in name) or ('gate' in name) or ('residual_combine' in name):
assert not param.colo_attr.sharded_data_tensor.is_sharded
assert param.colo_attr.sharded_data_tensor.data_ptr() == param.data.data_ptr()
else:
assert param.colo_attr.sharded_data_tensor.is_sharded

2 changes: 1 addition & 1 deletion tests/test_moe/test_moe_zero_model.py
@@ -37,7 +37,7 @@ def run_model_test(enable_autocast, shard_strategy_class):
# check whether parameters are identical in ddp
for name, p in zero_model.named_parameters():
if not p.colo_attr.param_is_sharded and p.is_replicated:
assert_equal_in_group(p.data)
assert_equal_in_group(p.colo_attr.sharded_data_tensor.payload)

model = MoeModel().half()
col_model_deepcopy(zero_model, model)
4 changes: 2 additions & 2 deletions tests/test_moe/test_moe_zero_optim.py
@@ -74,7 +74,7 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g
# check whether parameters are identical in ddp
for name, p in zero_model.named_parameters():
if not p.colo_attr.param_is_sharded and p.is_replicated:
assert_equal_in_group(p.data.to(get_current_device()))
assert_equal_in_group(p.colo_attr.sharded_data_tensor.payload.to(get_current_device()))

model = MoeModel().half()
col_model_deepcopy(zero_model, model)
@@ -99,7 +99,7 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g
for (n, p), zp in zip(apex_model.named_parameters(), zero_model.parameters()):
if 'gate' in n:
p.data = p.float()
p.data.copy_(zp.data)
p.data.copy_(zp.colo_attr.sharded_data_tensor.payload)

for i, (data, label) in enumerate(train_dataloader):
if i > 5:
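The test updates above and below all follow from the same rule: once a parameter's torch payload has been removed, p.data is just the empty sentinel, so tests read values through p.colo_attr.sharded_data_tensor.payload instead of p.data. A toy illustration of that access pattern (ToyColoAttr and ToyStatefulTensor are fabricated stand-ins, not ColossalAI types):

import torch


class ToyStatefulTensor:
    def __init__(self, tensor):
        self.payload = tensor


class ToyColoAttr:
    def __init__(self, param):
        self.param = param
        self.sharded_data_tensor = ToyStatefulTensor(param.data.clone())

    def remove_torch_payload(self):
        # replace param.data with the zero-element sentinel
        self.param.data = torch.BoolTensor([]).to(dtype=self.sharded_data_tensor.payload.dtype)


p = torch.nn.Parameter(torch.randn(2, 3))
p.colo_attr = ToyColoAttr(p)                 # payload keeps the real values
p.colo_attr.remove_torch_payload()

print(p.data.numel())                                   # 0: sentinel only
print(p.colo_attr.sharded_data_tensor.payload.shape)    # torch.Size([2, 3])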
7 changes: 3 additions & 4 deletions tests/test_zero_data_parallel/common.py
@@ -126,16 +126,15 @@ def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=
rank = dist.get_rank()
for (name, p), (zero_name, zero_p) in zip(model.named_parameters(), zero_model.named_parameters()):
if zero_p.colo_attr.param_is_sharded:
if reuse_fp16_shard:
zero_p = zero_p.data.to(p.device).float()
else:
zero_p = zero_p.colo_attr.sharded_data_tensor.payload.to(p.device).float()
zero_p = zero_p.colo_attr.sharded_data_tensor.payload.to(p.device).float()
chunks = torch.flatten(p).chunk(dist.get_world_size())
if rank >= len(chunks):
continue
p = chunks[rank].float()
if zero_p.size(0) > p.size(0):
zero_p = zero_p[:p.size(0)]
else:
zero_p = zero_p.colo_attr.sharded_data_tensor.payload

assert p.dtype == zero_p.dtype
assert allclose(p, zero_p, loose=loose), f'{p} vs {zero_p}'
2 changes: 1 addition & 1 deletion tests/test_zero_data_parallel/test_shard_model_v2.py
@@ -21,7 +21,7 @@


@parameterize("enable_autocast", [True])
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
@parameterize("shard_strategy_class", [BucketTensorShardStrategy])
def run_model_test(enable_autocast, shard_strategy_class):
test_models = ['repeated_computed_layers', 'resnet18', 'bert', 'no_leaf_module']
shard_strategy = shard_strategy_class()
6 changes: 3 additions & 3 deletions tests/test_zero_data_parallel/test_shard_param.py
@@ -58,15 +58,15 @@ def _run_shard_param_v2(rank, world_size, port):
assert cpu_mem_use == 2 * 3 * 4 * 2, f"cpu_mem_use: {cpu_mem_use}"

sparam.remove_torch_payload()
assert (param.data.numel() == 1)
assert (param.data.numel() == 0)
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
# 4 is size of dummy tensor of param.data
assert cpu_mem_use == 2 * 3 * 4 * 2 + 4
assert cpu_mem_use == 2 * 3 * 4 * 2

sparam.saved_grad = StatefulTensor(torch.randn(2, 3))
sparam.remove_torch_payload()
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
assert cpu_mem_use == 2 * 3 * 4 * 2 + 4
assert cpu_mem_use == 2 * 3 * 4 * 2
assert cuda_mem_use == 0

# append a grad to torch param
2 changes: 1 addition & 1 deletion tests/test_zero_data_parallel/test_state_dict.py
@@ -56,4 +56,4 @@ def test_zero_state_dict(world_size):


if __name__ == '__main__':
test_zero_state_dict(2, TensorShardStrategy)
test_zero_state_dict(2)
