[autoparallel] fix parameters sharding bug (hpcaitech#2716)
YuliangLiu0306 authored Feb 15, 2023
1 parent 2045d45 commit 5b24987
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions colossalai/auto_parallel/passes/runtime_preparation_pass.py
@@ -426,8 +426,9 @@ def _shard_param(param, target_sharding_spec):
             # we could use .data here, because all the operations just happen before the real training
             # loop, so we don't need to track these operations in the autograd graph.
             param = torch.nn.Parameter(
-                shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec,
-                                                                         target_sharding_spec).detach().clone())
+                shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec,
+                                                                         target_sharding_spec).detach().clone())
+        return param
 
     for node in nodes:
         if node.op == 'call_module':
@@ -438,7 +439,7 @@ def _shard_param(param, target_sharding_spec):
             setattr(target_module, 'processed', True)
             for name, param in target_module.named_parameters():
                 target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
-                _shard_param(param, target_sharding_spec)
+                param = _shard_param(param, target_sharding_spec)
 
                 setattr(target_module, name, param)
                 _add_hook_for_grad_communication(node, param)
@@ -469,7 +470,7 @@ def _shard_param(param, target_sharding_spec):
             target = getattr(target_module, atoms[-1])
 
             target_sharding_spec = node.sharding_spec
-            _shard_param(target, target_sharding_spec)
+            target = _shard_param(target, target_sharding_spec)
 
             assert hasattr(target_module, atoms[-1])
             setattr(target_module, atoms[-1], target)
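
Why the call sites had to change: `_shard_param` rebinds its local `param` (or `target`) argument to a freshly constructed `torch.nn.Parameter`, and rebinding a Python argument never propagates to the caller. Before this commit the return value was discarded, so the module kept its unsharded weight. Below is a minimal, self-contained sketch of the same pattern, using a hypothetical `shard_param` helper in place of `_shard_param` and `shape_consistency_manager.apply_for_autoparallel_runtime`; the slicing logic is only illustrative.

import torch
import torch.nn as nn


def shard_param(param: nn.Parameter, world_size: int, rank: int) -> nn.Parameter:
    # Hypothetical stand-in for _shard_param: keep this rank's slice along dim 0
    # (the real pass converts between sharding specs instead) and wrap it in a
    # new Parameter. Rebinding the local name `param` does not affect the caller,
    # which is exactly why the return value must be used.
    shard = param.data.chunk(world_size, dim=0)[rank]
    param = nn.Parameter(shard.detach().clone())
    return param


module = nn.Linear(4, 4, bias=False)

# Pre-fix pattern: return value ignored, the module still holds the full weight.
shard_param(module.weight, world_size=2, rank=0)
assert module.weight.shape == (4, 4)

# Post-fix pattern (this commit): capture the result and re-register it.
new_weight = shard_param(module.weight, world_size=2, rank=0)
setattr(module, 'weight', new_weight)
assert module.weight.shape == (2, 4)

The same rule explains both hunks above: the sharded parameter returned by `_shard_param` is assigned back and then re-registered on the module with `setattr`.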
