[autoparallel] test compatibility for gemini and auto parallel (hpcai…
YuliangLiu0306 authored Feb 15, 2023
1 parent d701ef8 commit 7fa6be4
Showing 3 changed files with 212 additions and 4 deletions.
10 changes: 6 additions & 4 deletions colossalai/auto_parallel/passes/runtime_preparation_pass.py
@@ -377,8 +377,9 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh, o
      # TODO: build a ColoParamter class to manager the distributed parameters
      # we could use .data here, because all the operations just happen before the real training
      # loop, so we don't need to track these operations in the autograd graph.
-     param.data = shape_consistency_manager.apply_for_autoparallel_runtime(
-         param.data, param.sharding_spec, target_sharding_spec).detach().clone()
+     param = torch.nn.Parameter(
+         shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec,
+                                                                  target_sharding_spec).detach().clone())

      setattr(target_module, name, param)
      comm_actions = node.best_strategy.communication_actions
@@ -432,8 +433,9 @@ def hook_fn(grad):
      # TODO: build a ColoParamter class to manager the distributed parameters
      # we could use .data here, because all the operations just happen before the real training
      # loop, so we don't need to track these operations in the autograd graph.
-     target.data = shape_consistency_manager.apply_for_autoparallel_runtime(
-         target.data, target.sharding_spec, target_sharding_spec).detach().clone()
+     target = torch.nn.Parameter(
+         shape_consistency_manager.apply_for_autoparallel_runtime(target.data, target.sharding_spec,
+                                                                  target_sharding_spec).detach().clone())

      assert hasattr(target_module, atoms[-1])
      setattr(target_module, atoms[-1], target)
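The substantive change in runtime_preparation_pass.py is that the resharded weight is now re-registered on the module as a fresh `torch.nn.Parameter` instead of being written back through `.data`. Below is a minimal, self-contained sketch of the difference; it uses a plain `torch.nn.Linear` and a hand-made row shard as stand-ins for the auto-parallel machinery, so none of it is ColossalAI code.

```python
import torch

module = torch.nn.Linear(4, 16, bias=False)

# Stand-in for the output of shape_consistency_manager.apply_for_autoparallel_runtime:
# this rank keeps only the first half of the weight rows.
sharded = module.weight.detach().narrow(0, 0, 8).clone()

# Old behaviour: mutate the existing Parameter in place.
# module.weight.data = sharded

# New behaviour: build a fresh leaf Parameter from the resharded tensor and
# re-register it, so wrappers applied afterwards (DDP, Gemini/ZeRO) see a
# parameter whose shape already matches the local shard.
new_param = torch.nn.Parameter(sharded)
setattr(module, 'weight', new_param)

assert module.weight.shape == (8, 4)
assert module.weight.is_leaf and module.weight.requires_grad
```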
@@ -0,0 +1,98 @@
import copy
from functools import partial

import pytest
import torch
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port


class MLP(torch.nn.Module):

    def __init__(self, in_features):
        super().__init__()
        self.linear_1 = torch.nn.Linear(in_features, 4 * in_features, bias=False)
        self.linear_2 = torch.nn.Linear(4 * in_features, in_features, bias=False)

    def forward(self, x):
        x = self.linear_1(x)
        x = self.linear_2(x)

        return x


def check_compatibility_with_ddp(rank, world_size, port):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    model = MLP(4).cuda()
    input = torch.rand(4, 4).cuda()
    output_compare = model(input)
    loss_compare = output_compare.sum()
    loss_compare.backward()
    grad_compare = copy.deepcopy(model.linear_1.weight.grad)

    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    # [[0, 1]
    #  [2, 3]]
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
    meta_args = {'x': torch.rand(4, 4).to('meta')}
    gm, solution = initialize_model(model,
                                    meta_args=meta_args,
                                    device_mesh=device_mesh,
                                    return_solution=True,
                                    solver_preference='tp',
                                    shard_option='shard_last_axis')

    msg = '| TP strategy combination chosen by auto-parallel solver |'
    msg_length = len(msg)
    if rank == 0:
        print('=' * msg_length)
        print(msg)
        print('=' * msg_length)
        for strategy in solution:
            print(strategy)
        print('=' * msg_length)

    dp_process_group = None
    for (ranks, process_group_handle) in device_mesh.process_groups_dict[0]:
        if rank in ranks:
            dp_process_group = process_group_handle
    assert dp_process_group is not None

    gm = DDP(gm, process_group=dp_process_group)
    output = gm(input)

    assert_close(output, output_compare)
    print(f'output on rank{rank} is correct')
    loss = output.sum()

    loss.backward()

    if rank in (0, 2):
        assert_close(gm.module.module.linear_1.weight.grad, grad_compare.narrow(0, 0, 8))

    if rank in (1, 3):
        assert_close(gm.module.module.linear_1.weight.grad, grad_compare.narrow(0, 8, 8))

    print(f'gradient on rank{rank} is correct')


@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_compatibility_with_ddp():
    world_size = 4
    run_func = partial(check_compatibility_with_ddp, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_compatibility_with_ddp()
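For readers checking the gradient assertions above: `linear_1.weight` has shape (16, 4), and the assertions imply it is split row-wise into two tensor-parallel halves, with ranks {0, 2} holding rows [0:8] and ranks {1, 3} holding rows [8:16], so {0, 2} and {1, 3} act as the data-parallel groups (the columns of the [[0, 1], [2, 3]] mesh). A single-process sketch of that bookkeeping, not part of the commit:

```python
import torch

full_grad = torch.randn(16, 4)    # stands in for grad_compare on a single process
shards = {0: full_grad.narrow(0, 0, 8), 1: full_grad.narrow(0, 8, 8)}

for rank in range(4):
    tp_index = rank % 2           # ranks 0, 2 -> shard 0; ranks 1, 3 -> shard 1
    expected = shards[tp_index]
    assert expected.shape == torch.Size([8, 4])
```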
@@ -0,0 +1,108 @@
import copy
from functools import partial

import pytest
import torch
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
from colossalai.tensor.process_group import ProcessGroup
from colossalai.testing import assert_close, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port, get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext, post_process_colo_init_ctx


class MLP(torch.nn.Module):

    def __init__(self, in_features):
        super().__init__()
        self.linear_1 = torch.nn.Linear(in_features, 4 * in_features, bias=False)
        self.linear_2 = torch.nn.Linear(4 * in_features, in_features, bias=False)

    def forward(self, x):
        x = self.linear_1(x)
        x = self.linear_2(x)

        return x


def check_auto_parallel_with_gemini(rank, world_size, port):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    model = MLP(4).half().cuda()

    input = torch.rand(4, 4).half().cuda()
    output_compare = model(input)
    loss_compare = output_compare.sum()
    loss_compare.backward()
    grad_compare = copy.deepcopy(model.linear_1.weight.grad)

    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    # [[0, 1]
    #  [2, 3]]
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
    meta_args = {'x': torch.rand(4, 4).half().to('meta')}
    gm, solution = initialize_model(model,
                                    meta_args=meta_args,
                                    device_mesh=device_mesh,
                                    return_solution=True,
                                    solver_preference='tp',
                                    shard_option='shard_last_axis')

    if rank == 0:
        msg = '| TP strategy combination chosen by auto-parallel solver |'
        msg_length = len(msg)
        print('=' * msg_length)
        print(msg)
        print('=' * msg_length)
        for strategy in solution:
            print(strategy)
        print('=' * msg_length)

    dp_process_group = ProcessGroup(rank=rank, ranks=[0, 1, 2, 3], tp_degree=2, dp_degree=2)
    gemini_config = dict(strict_ddp_mode=False,
                         device=get_current_device(),
                         placement_policy='cpu',
                         pin_memory=True,
                         search_range_mb=128)

    post_process_colo_init_ctx(gm, device=get_current_device(), default_pg=dp_process_group)
    gm = zero_model_wrapper(gm, zero_stage=3, gemini_config=gemini_config)
    optimizer = HybridAdam(gm.parameters(), betas=(0, 0))
    optimizer = zero_optim_wrapper(gm, optimizer, initial_scale=1)
    output = gm(input)
    assert_close(output, output_compare)
    print(f'output on rank{rank} is correct')
    loss = output.sum()
    optimizer.zero_grad()
    optimizer.backward(loss)
    optimizer.step()

    if rank in (0, 2):
        assert_close(list(optimizer.optim.state.values())[0]['exp_avg'].half(), grad_compare.narrow(0, 0, 8).flatten())

    if rank in (1, 3):
        assert_close(list(optimizer.optim.state.values())[0]['exp_avg'].half(), grad_compare.narrow(0, 8, 8).flatten())

    print(f'gradient on rank{rank} is correct')


@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_auto_parallel_with_gemini():
    world_size = 4
    run_func = partial(check_auto_parallel_with_gemini, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_auto_parallel_with_gemini()
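A detail worth noting in the Gemini test: `HybridAdam` is built with `betas=(0, 0)`, so after a single step the optimizer's first-moment buffer `exp_avg` equals the gradient itself, which is what makes comparing `exp_avg` against the reference gradient shard meaningful. A minimal sketch of that identity with stock `torch.optim.Adam` (assuming HybridAdam follows the same first-moment update rule):

```python
import torch

# With beta1 = 0 the first moment after one step is
# exp_avg = beta1 * 0 + (1 - beta1) * grad = grad.
p = torch.nn.Parameter(torch.ones(4))
opt = torch.optim.Adam([p], lr=1e-3, betas=(0.0, 0.0))

p.grad = torch.tensor([1.0, 2.0, 3.0, 4.0])
opt.step()

assert torch.allclose(opt.state[p]['exp_avg'], p.grad)
```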
