
Commit

add sharded checkpoint loading for AutoTP path to reduce the peak memory in initialization stage (microsoft#3102)

* add sharded checkpoint loading for AutoTP path to reduce the peak memory in initialization stage

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* fix GPT-J sharded checkpoint loading problem

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

---------

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
sywangyi and tjruwase authored May 4, 2023
1 parent 0a61d5d commit d10b8ca
Showing 2 changed files with 177 additions and 25 deletions. (Only the diff of deepspeed/inference/engine.py is shown below; the second file's diff did not load.)
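This commit lets the automatic tensor parallelism (AutoTP) path, i.e. deepspeed.init_inference without kernel injection, consume a sharded checkpoint instead of requiring the full state dict up front. A minimal usage sketch (the model name, mp_size, and checkpoint.json path are illustrative placeholders, not taken from this commit):

    import torch
    import deepspeed
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", torch_dtype=torch.float16)
    engine = deepspeed.init_inference(
        model,
        mp_size=2,                         # tensor-parallel degree
        dtype=torch.float16,
        replace_with_kernel_inject=False,  # the AutoTP path exercised by this commit
        checkpoint="checkpoint.json",      # JSON descriptor listing the weight shards
    )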
46 changes: 35 additions & 11 deletions deepspeed/inference/engine.py
@@ -151,9 +151,6 @@ def __init__(self, model, config):
         assert pkg_version.parse(torch.__version__) >= pkg_version.parse("1.10"), \
             "If you want to use cuda graph, please upgrade torch to at least v1.10"
 
-        if config.checkpoint and not config.replace_with_kernel_inject:
-            self._load_checkpoint(config.checkpoint)
-
         # convert model to intended dtype
         if config.dtype:
             self._convert_to_dtype(config)
@@ -173,10 +170,6 @@ def __init__(self, model, config):
         if moe and dist.get_world_size() > 1:
             self._create_ep_parallel_group(config.moe.moe_experts)
 
-        # retain this from the old conditional argument being passed to apply_injection_policy()
-        if not config.replace_with_kernel_inject:
-            config.checkpoint = None
-
         # We only support three modes: 1) user specified policy for tensor-parallelism, 2) kernel injection (replace_with_kernel_inject), and 3) automatic tensor parallelism.
         if self.injection_dict:
             # 1. User specified Tensor Parallelism
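The two deletions above remove the eager, whole-checkpoint load for the non-kernel-injection path; with this commit the checkpoint is instead consumed shard by shard inside the loading code below. An illustrative sketch (not DeepSpeed API) of why that lowers peak memory:

    import torch

    def load_sharded(model, shard_paths):
        # Peak host memory is roughly model size + one shard,
        # instead of model size + the entire checkpoint.
        for path in shard_paths:
            sd = torch.load(path, map_location="cpu")
            model.load_state_dict(sd, strict=False)  # copy only the keys this shard provides
            del sd                                   # drop the shard before reading the next one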
@@ -343,18 +336,38 @@ def load_model_with_checkpoint(self, r_module):
         def load(module, state_dict, prefix):
             args = (state_dict, prefix, {}, True, [], [], error_msgs)
             if hasattr(module, 'weight'):
+                if module.weight.data.is_meta:
+                    # meta tensor cannot be casted or copied to, so we need to replace it with a normal tensor here
+                    module.weight = torch.nn.parameter.Parameter(data=torch.empty_like(module.weight.data,
+                                                                                       device="cpu"),
+                                                                 requires_grad=module.weight.data.requires_grad)
                 if 'query_key_value' in prefix:
                     module.weight = self.mp_replace.strided_copy(module.weight.data,
                                                                  state_dict[prefix + 'weight'],
                                                                  num_splits=3)
                 else:
                     module.weight = self.mp_replace.copy(module.weight.data, state_dict[prefix + 'weight'])
             else:
+                if module.norm.weight.data.is_meta:
+                    # meta tensor cannot be casted or copied to, so we need to replace it with a normal tensor here
+                    module.norm.weight = torch.nn.parameter.Parameter(
+                        data=torch.empty_like(module.norm.weight.data, device="cpu"),
+                        requires_grad=module.norm.weight.data.requires_grad)
                 module.norm.weight = self.mp_replace.copy(module.norm.weight.data, state_dict[prefix + 'weight'])
             if prefix + 'bias' in self.key_list:
                 if hasattr(module, 'norm'):
+                    if module.norm.bias.data.is_meta:
+                        # meta tensor cannot be casted or copied to, so we need to replace it with a normal tensor here
+                        module.norm.bias = torch.nn.parameter.Parameter(
+                            data=torch.empty_like(module.norm.bias.data, device="cpu"),
+                            requires_grad=module.norm.bias.data.requires_grad)
                     module.norm.bias = self.mp_replace.copy(module.norm.bias, state_dict[prefix + 'bias'])
                 else:
+                    if module.bias.data.is_meta:
+                        # meta tensor cannot be casted or copied to, so we need to replace it with a normal tensor here
+                        module.bias = torch.nn.parameter.Parameter(data=torch.empty_like(module.bias.data,
+                                                                                         device="cpu"),
+                                                                   requires_grad=module.bias.data.requires_grad)
                     data = state_dict[prefix + 'bias']
                     data = data.to(get_accelerator().current_device_name())
                     module.bias = self.mp_replace.copy(module.bias, data)
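The is_meta branches added above exist because a meta tensor has shape and dtype but no storage, so it cannot be the destination of a copy; it must first be swapped for a real tensor. A self-contained sketch of the pattern:

    import torch

    w = torch.nn.Parameter(torch.empty(4, 4, device="meta"))
    assert w.is_meta                        # no storage yet; copying into it would fail
    w = torch.nn.Parameter(
        torch.empty_like(w, device="cpu"),  # allocate real storage, same shape/dtype
        requires_grad=w.requires_grad,
    )
    w.data.copy_(torch.randn(4, 4))         # a normal copy now succeeds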
@@ -383,6 +396,15 @@ def load_module_recursive(module, prefix='', level=0):
 
         load_module_recursive(r_module)
 
+        embedding_weight = None
+
+        for n, p in r_module.named_parameters():
+            if "word_embeddings." in n or "embed_tokens." in n or "wte." in n:
+                embedding_weight = p
+        if embedding_weight is not None and hasattr(r_module, "lm_head") and hasattr(
+                r_module.lm_head, "weight") and r_module.lm_head.weight.is_meta:
+            r_module.lm_head.weight = embedding_weight
+
     def _apply_injection_policy(self, config, client_module=None):
         # client_module is only passed when using the injection_dict method.
         checkpoint_dir = config.checkpoint
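The embedding_weight block added above covers models whose lm_head is tied to the input embedding: if every shard has been consumed and lm_head.weight is still a meta tensor (the checkpoint carries no separate weight for it), the already-loaded embedding weight is reused. A quick illustrative check of that tying convention, using GPT-2 as an example (not part of this commit):

    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("gpt2")
    # Tied weights share one storage, so both views report the same data pointer.
    print(model.lm_head.weight.data_ptr() == model.transformer.wte.weight.data_ptr())  # True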
@@ -434,16 +456,18 @@ def _load_checkpoint(self, load_dir, load_module_strict=True, tag=None):
         else:
             sd_loader = SDLoaderFactory.get_sd_loader_json(load_dir, self.checkpoint_engine)
 
-        if type(sd_loader) is list:
-            self.sd = torch.load(sd_loader[0], map_location='cpu')
+        checkpoint = sd_loader['checkpoints']
+
+        if type(checkpoint) is list:
+            self.sd = torch.load(checkpoint[0], map_location='cpu')
             self.key_list = list(self.sd.keys())
 
             self.load_model_with_checkpoint(self.module)
 
-            for i in range(1, len(sd_loader)):
+            for i in range(1, len(checkpoint)):
                 if not dist.is_initialized() or dist.get_rank() == 0:
                     print(f"loading checkpoint ({i})")
-                self.sd = torch.load(sd_loader[i], map_location=get_accelerator().device_name())
+                self.sd = torch.load(checkpoint[i], map_location=get_accelerator().device_name())
                 self.key_list = list(self.sd.keys())
                 self.load_model_with_checkpoint(self.module)
         else:
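For this path, SDLoaderFactory.get_sd_loader_json returns the parsed checkpoint JSON itself, so the shard list is now read from its 'checkpoints' key rather than by treating the loader object as a list. A hedged example of writing such a descriptor (the type value and shard filenames are assumptions, shown for illustration):

    import json

    descriptor = {
        "type": "ds_model",   # a type for which the raw JSON dict is returned
        "version": 1.0,
        "checkpoints": [
            "pytorch_model-00001-of-00002.bin",
            "pytorch_model-00002-of-00002.bin",
        ],
    }
    with open("checkpoint.json", "w") as f:
        json.dump(descriptor, f)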
