polish license #300

Merged
merged 3 commits on Mar 3, 2022
Changes from 1 commit
add more unittest for shard param
feifeibear committed Mar 3, 2022
commit 4ab08286bec52b2213a12a2f11c18fa9788b1f2a
21 changes: 11 additions & 10 deletions colossalai/zero/shard_param/shard_param.py
@@ -4,6 +4,7 @@
 from colossalai.core import global_context as gpc
 import torch.distributed as dist
 from typing import Union, Tuple, Optional
+import numpy


 class ShardParam(object):
@@ -13,12 +14,12 @@ class ShardParam(object):
     """

     def __init__(self,
-                 param: Union[torch.nn.Parameter, Tuple[int, ...]],
+                 other: Union[torch.nn.Parameter, Tuple[int, ...]],
                  process_group: Optional[dist.ProcessGroup] = None,
                  is_sharded: bool = False,
                  device: Optional[torch.device] = None) -> None:
         r"""
-        param: either an existing torch parameter or a tuple, indicate allocate a new param with the tuple as shape.
+        other: either an existing torch parameter or a tuple, indicate allocate a new param with the tuple as shape.
         process_group: the process group storing the shared data.
         is_sharded: is shared the param during __init__.
         device: the device to place param data payload on
@@ -29,23 +30,23 @@ def __init__(self,
         self.is_sharded = False

         # Hijack the data payload of param
-        if isinstance(param, torch.nn.Parameter):
-            self._param_payload = param.data.to(device)
-            self._origin_shape = param.shape
-            self._origin_numel = param.numel()
+        if isinstance(other, torch.nn.Parameter):
+            self._param_payload = other.data.to(device)
+            self._origin_shape = other.shape
+            self._origin_numel = other.numel()
             if is_sharded:
                 self.shard()
-        elif isinstance(param, tuple):
-            self._origin_shape = param.shape
-            self._origin_numel = param.numel()
+        elif isinstance(other, tuple):
+            self._origin_shape = other
+            self._origin_numel = numpy.prod(other)

             # TODO(jiaruifang) can be optimized. Directly allocate payload as the sharded shape.
             assert device is not None, "You have to assign a device to initialize a ShardParam from a shape tuple"
             self._param_payload = torch.empty(self._origin_shape, device=device)
             if is_sharded:
                 self.shard()
         else:
-            raise RuntimeError(f"Initialize ShardParam failed. The 2nd parameter is wrong type {type(param)}")
+            raise RuntimeError(f"Initialize ShardParam failed. The 2nd parameter is wrong type {type(other)}")

         self._payload_numel = None

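The substance of this hunk is that ShardParam's first argument (renamed from param to other) may now be either a live torch.nn.Parameter or a bare shape tuple; in the tuple case the element count comes from numpy.prod and an empty payload tensor has to be allocated on an explicitly given device. Below is a minimal standalone sketch of that branching for illustration only: resolve_origin is a hypothetical helper, not part of ColossalAI, and no process group or sharding is involved.

import numpy
import torch
from typing import Tuple, Union

def resolve_origin(other: Union[torch.nn.Parameter, Tuple[int, ...]],
                   device: torch.device):
    """Mirror the constructor's two paths: a live Parameter or a bare shape tuple."""
    if isinstance(other, torch.nn.Parameter):
        # Reuse the parameter's data payload and remember its original geometry.
        payload = other.data.to(device)
        return payload, tuple(other.shape), other.numel()
    elif isinstance(other, tuple):
        # A tuple only describes a shape, so the element count comes from numpy.prod
        # and a fresh payload is allocated on the explicitly given device.
        payload = torch.empty(other, device=device)
        return payload, other, int(numpy.prod(other))
    raise RuntimeError(f"unsupported type {type(other)}")

# Both paths report the same geometry for a 2x3 parameter / shape.
p = torch.nn.Parameter(torch.randn(2, 3))
payload, shape, numel = resolve_origin(p, torch.device('cpu'))
assert shape == (2, 3) and numel == 6
payload, shape, numel = resolve_origin((2, 3), torch.device('cpu'))
assert payload.shape == (2, 3) and numel == 6
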
11 changes: 11 additions & 0 deletions tests/test_zero_data_parallel/test_shard_param.py
@@ -19,6 +19,17 @@ def run_init_shard_param(rank, world_size, port):
     sparam = ShardParam(param, None, True)
     payload = sparam.payload(torch.device('cuda'))
     assert (list(payload.shape) == [3])
+    del sparam
+
+    param_shape = (2, 3)
+    sparam = ShardParam(param_shape, process_group=None, is_sharded=True, device=torch.device('cpu'))
+    payload = sparam.payload(torch.device('cuda'))
+    assert (list(payload.shape) == [3])
+
+    param_shape = (2, 3)
+    sparam = ShardParam(param_shape, process_group=None, is_sharded=False, device=torch.device('cpu'))
+    payload = sparam.payload(torch.device('cuda'))
+    assert (list(payload.shape) == [2, 3])


 def run_shard_param_check(rank, world_size, port):
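The shape assertions in the new test cases follow from flattening and evenly splitting the payload: a (2, 3) parameter holds 6 elements, so with the world size the test presumably launches with (2, set by the test launcher outside this hunk), each rank's shard keeps 6 / 2 = 3 elements, giving [3] when is_sharded=True and the untouched [2, 3] otherwise. A quick sketch of that arithmetic, hedged on the assumed world size:

import numpy

origin_shape = (2, 3)
world_size = 2                             # assumption: set by the test launcher, not shown in this hunk
numel = int(numpy.prod(origin_shape))      # 6 elements in total
shard_numel = numel // world_size          # 3 elements per rank after an even split
print([shard_numel], list(origin_shape))   # [3] [2, 3], the two shapes asserted above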