A Unified Checkpoint System Design #3339
Replies: 9 comments 11 replies
-
Great! We have received a lot of issues about this :(
-
@YuliangLiu0306 perhaps you want to suggest some methods to manage the …
-
One question left to think about is how to save optimizer states for sharded tensors, e.g. in auto parallel and ZeRO?
-
Related issue #3250: the optimizer weights should be gathered before the save function is called.
-
Assume we only have 4 linear layers, none of which has a bias, and the linear layer in yellow is a DTensor. The checkpoint will look like this (index.json is not shown):
-
This design implementation will be tracked in the Kanban https://github.com/orgs/hpcaitech/projects/19. |
-
A possible sharded optimizer checkpoint
A state dict of an optimizer may look like this: …
There are three types of file:
Index file
File name may be like … File content may be like:
{
    "param_groups": "pytorch_optim_group.bin",
    "weight_map": {
        "0": "pytorch_optim-00001.bin",
        "1": "pytorch_optim-00002.bin"
    }
}
Group file
Generally speaking, … File name may be like … File content may be like:
[{'lr': 0.001,
  'betas': (0.9, 0.999),
  'eps': 1e-08,
  'weight_decay': 0,
  'amsgrad': False,
  'maximize': False,
  'foreach': None,
  'capturable': False,
  'params': [0]},
 {'lr': 0.001,
  'betas': (0.9, 0.999),
  'eps': 1e-08,
  'weight_decay': 0,
  'amsgrad': False,
  'maximize': False,
  'foreach': None,
  'capturable': False,
  'params': [1]}]
It saves …
State files
Optimizer states may be large, so we need to shard them. File name may be like … File content may be like:
{1: {'step': tensor(1.),
     'exp_avg': tensor([0.0750, 0.0381, 0.0591, 0.0473, 0.0298, 0.0659, 0.0052, 0.0653, 0.0714,
                        0.0618, 0.0388, 0.0288, 0.0140, 0.0349, 0.0391, 0.0459, 0.0867, 0.0453,
                        0.0629, 0.0130]),
     'exp_avg_sq': tensor([5.6205e-04, 1.4529e-04, 3.4900e-04, 2.2345e-04, 8.8624e-05, 4.3449e-04,
                           2.6612e-06, 4.2593e-04, 5.0988e-04, 3.8246e-04, 1.5073e-04, 8.2851e-05,
                           1.9700e-05, 1.2154e-04, 1.5302e-04, 2.1050e-04, 7.5246e-04, 2.0536e-04,
                           3.9539e-04, 1.6980e-05])}}
It saves part of …
Shortcomings
As the key of each param is a number, we don't know which model the param belongs to. If pipeline parallelism is used, it is hard to recover the sharded state dict.
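To make the three file types concrete, here is a minimal sketch using only the Python standard library. Assumptions: plain floats and lists stand in for torch tensors, pickle stands in for torch.save, the function names are hypothetical, and the index file name optimizer.index.json is a made-up placeholder (the example name above was not preserved). The group and state file names follow the index example above.

```python
import json
import os
import pickle
import tempfile

INDEX_FILE = "optimizer.index.json"  # hypothetical placeholder name
GROUP_FILE = "pytorch_optim_group.bin"


def save_sharded_optimizer(state_dict, checkpoint_dir, params_per_shard=1):
    """Split an optimizer state dict ({"state": ..., "param_groups": ...})
    into a group file, numbered state files and an index file."""
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Group file: hyperparameters of all param groups, saved once.
    with open(os.path.join(checkpoint_dir, GROUP_FILE), "wb") as f:
        pickle.dump(state_dict["param_groups"], f)

    # State files: per-parameter states (step, exp_avg, ...), chunked.
    param_ids = sorted(state_dict["state"])
    weight_map = {}
    for shard_idx, start in enumerate(range(0, len(param_ids), params_per_shard)):
        chunk = param_ids[start:start + params_per_shard]
        shard_file = f"pytorch_optim-{shard_idx + 1:05d}.bin"
        with open(os.path.join(checkpoint_dir, shard_file), "wb") as f:
            pickle.dump({pid: state_dict["state"][pid] for pid in chunk}, f)
        for pid in chunk:
            weight_map[str(pid)] = shard_file  # JSON keys must be strings

    # Index file: where the groups live and which file holds each param's state.
    index = {"param_groups": GROUP_FILE, "weight_map": weight_map}
    with open(os.path.join(checkpoint_dir, INDEX_FILE), "w") as f:
        json.dump(index, f, indent=2)
    return index


def load_sharded_optimizer(checkpoint_dir):
    """Rebuild the full optimizer state dict from the three file types."""
    with open(os.path.join(checkpoint_dir, INDEX_FILE)) as f:
        index = json.load(f)
    with open(os.path.join(checkpoint_dir, index["param_groups"]), "rb") as f:
        param_groups = pickle.load(f)
    state = {}
    # Each state file is read once, even if it holds several params.
    for shard_file in sorted(set(index["weight_map"].values())):
        with open(os.path.join(checkpoint_dir, shard_file), "rb") as f:
            state.update(pickle.load(f))
    return {"state": state, "param_groups": param_groups}


if __name__ == "__main__":
    # Plain floats stand in for torch tensors.
    example = {
        "state": {0: {"step": 1.0, "exp_avg": [0.075, 0.038]},
                  1: {"step": 1.0, "exp_avg": [0.063, 0.013]}},
        "param_groups": [{"lr": 0.001, "params": [0]}, {"lr": 0.001, "params": [1]}],
    }
    with tempfile.TemporaryDirectory() as tmp:
        print(save_sharded_optimizer(example, tmp)["weight_map"])
        assert load_sharded_optimizer(tmp) == example
```

With one parameter per shard, the resulting weight_map matches the index example above. Because the keys are bare parameter indices, this layout inherits the shortcoming noted above: nothing in the files records which model or pipeline stage a parameter belongs to.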
-
Hello @FrankLeeeee, I'm trying to enable the Hugging Face remote class in ColossalAI. So far everything works, but I just found that the 3D-parallel CheckpointIO is not implemented in example/llama: booster/pulgin/three_dim_parallel.py. I want to implement it; which example or tutorial can I follow?
-
What's the mechanism for saving and loading checkpoints during multi-node training? Does rank 1 collect all the data and save it, or do all the nodes save the sharded model while rank 1 only saves the index JSON?
-
Overview
As we develop new features for the Colossal-AI system, we find it difficult to save/load model/optimizer checkpoints. This is because each feature has to handle the save/load logic on its own, without a common protocol. This is undesirable: a model trained with one feature should not be loadable only with that same feature. Such a restriction limits the usage of the checkpoint and hinders integration with the community.
Therefore, it is important to design a unified checkpoint system as a protocol. Some important factors should be considered.
Background
First of all, we should understand which use cases this unified system will cater to. Let's assume we have a model like this:
When we train this model, it can be placed over GPUs in different ways. Let's assume we are training with 2 GPUs.
Therefore, the unified checkpoint system should support at least the features mentioned above.
Currently, there are mainly two ways to save/load model checkpoints:
1. Save the whole model's state_dict() into a single file. For example, PyTorch saves the whole model into a single file called model.pth.
2. Shard the model into multiple files, where an index.json file is used to specify which parameter goes into which file. An example is https://huggingface.co/facebook/opt-66b/blob/main/pytorch_model.bin.index.json.
Design
According to the information mentioned above, what we need to support can be expressed as a matrix:
Therefore, we will have the following APIs to cover these usages:
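We can focus on the load_model and save_model methods. save_model has two arguments to define its checkpoint format, shard and gather_dtensor. If shard is True, the model weights will be split into multiple files along with an index.json file; otherwise, the model weights will be saved in a single file. gather_dtensor controls whether dtensors are gathered into global tensors before saving.
To better explain the outcome of different cases, I will use file structures to illustrate: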
1. shard = False, gather_dtensor = True
2. shard = True, gather_dtensor = True
3. shard = False, gather_dtensor = False. The dtensors will be stored in an individual folder and each tensor shard is numbered. The saved tensor format can be:
The file structure will look like:
The weight map key-value pair for a dtensor in the index.json file will look like linear_weight: linear_weight.bin.*
4. shard = True, gather_dtensor = False
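When shard = True and dtensors are gathered, the result is the same kind of layout as the Hugging Face example in the Background section: several weight files plus an index.json weight_map. A minimal sketch of such a sharded save and the matching load, using only the Python standard library. Assumptions: flat lists stand in for tensors, the shard limit is counted in elements rather than bytes, pickle stands in for torch.save, and the function names are hypothetical. The file naming mirrors Hugging Face's pytorch_model-XXXXX-of-YYYYY.bin convention.

```python
import json
import os
import pickle
import tempfile

INDEX_FILE = "pytorch_model.bin.index.json"


def save_sharded_model(state_dict, checkpoint_dir, max_shard_numel):
    """Greedily pack parameters into shard files and write an index.json
    whose weight_map records which file holds each parameter."""
    os.makedirs(checkpoint_dir, exist_ok=True)
    shards, current, current_numel = [], {}, 0
    for name, value in state_dict.items():
        numel = len(value)  # flat lists stand in for tensors here
        # Close the current shard if this parameter would overflow it.
        if current and current_numel + numel > max_shard_numel:
            shards.append(current)
            current, current_numel = {}, 0
        current[name] = value
        current_numel += numel
    if current:
        shards.append(current)

    weight_map = {}
    for i, shard in enumerate(shards, start=1):
        file_name = f"pytorch_model-{i:05d}-of-{len(shards):05d}.bin"
        with open(os.path.join(checkpoint_dir, file_name), "wb") as f:
            pickle.dump(shard, f)
        weight_map.update({name: file_name for name in shard})

    with open(os.path.join(checkpoint_dir, INDEX_FILE), "w") as f:
        json.dump({"weight_map": weight_map}, f, indent=2)
    return weight_map


def load_sharded_model(checkpoint_dir):
    """Read index.json, then load each shard file once and merge them."""
    with open(os.path.join(checkpoint_dir, INDEX_FILE)) as f:
        weight_map = json.load(f)["weight_map"]
    state_dict = {}
    for file_name in sorted(set(weight_map.values())):
        with open(os.path.join(checkpoint_dir, file_name), "rb") as f:
            state_dict.update(pickle.load(f))
    return state_dict


if __name__ == "__main__":
    model = {"linear1.weight": [0.1, 0.2, 0.3],
             "linear2.weight": [0.4, 0.5],
             "linear3.weight": [0.6]}
    with tempfile.TemporaryDirectory() as tmp:
        print(save_sharded_model(model, tmp, max_shard_numel=4))
        assert load_sharded_model(tmp) == model
```

The greedy packing keeps each parameter whole, so a single parameter larger than the limit still gets its own file. Splitting one parameter across files is exactly what the gather_dtensor = False cases handle instead.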
When loading models, we only support the cases shard = True, gather_dtensor = False and shard = False, gather_dtensor = False. Therefore, merging dtensors into a global tensor can be done offline via our CLI.