Design of Optimizer Checkpointing for Gemini Plugin #4140
Fridge003 started this conversation in Development | Core
Background
- ZeRO: paper
- Chunk+Gemini: paper
- Checkpoint System Design: github discussion
Problems
Gemini is developed as a parallel training strategy based on the ZeRO algorithm. To improve communication efficiency, Gemini assigns each model parameter to a chunk, where each chunk is filled with a fixed number of tensor elements. Data are transmitted chunk by chunk, so that bandwidth is utilized more efficiently and data locality is enhanced.
However, while developing checkpointing features (save & load) for Gemini's optimizer, we found that the mechanism of chunks might induce the following issues:
1. The state_dict of an optimizer needs to map each model parameter to an integer ID. When initializing an optimizer, PyTorch configures its param_groups member variable according to the list of parameters passed in by the user; param_groups decides the mapping between parameters and integer IDs (determined by the order of appearance in the passed-in parameter list), as well as the hyperparameters adopted by each parameter (see the example after this list). However, the Gemini algorithm modifies the param_groups configured by PyTorch before training, which complicates parameter management.
2. Every rank calls step() but only computes the shard of optimizer states kept on its local device, so the states of the same parameter can be distributed among different devices.
3. The saved checkpoint should be agnostic to the device configuration, so its file layout cannot simply follow the per-device shards.
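For reference, here is a minimal standalone example of the integer-ID mapping in a plain PyTorch optimizer state_dict (standard PyTorch behaviour, independent of Gemini):

```python
import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Run one forward/backward/step so that per-parameter states are created.
model(torch.randn(1, 4)).sum().backward()
optimizer.step()

state_dict = optimizer.state_dict()
# 'params' holds integer IDs, not tensors; the IDs follow the order in which
# the parameters were passed to the optimizer.
print(state_dict['param_groups'][0]['params'])   # [0, 1]
# 'state' maps those integer IDs to the per-parameter optimizer states.
print(sorted(state_dict['state'][0].keys()))     # ['exp_avg', 'exp_avg_sq', 'step']
```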
Solutions
Management of param_groups
We assume that the configuration of the passed-in model parameters and hyperparameters doesn't change between saving and loading. This loose assumption is the foundation of our design; in fact, the checkpointing system of PyTorch is also based on it.
Gemini calls the self.__init__optimizer() method during initialization of the ZeroOptimizer class (the optimizer wrapper class it uses). This method modifies the self.param_groups set by PyTorch: the value corresponding to the key 'params' in each param_group is replaced with a fake_param_list, where each fake_param is a dummy tensor, and parameters not stored on the local device are wiped out from self.param_groups. So the original implementation of Gemini discards the information in param_groups, leading to the first issue mentioned above.
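A simplified illustration of the modification described above (a sketch of the idea only, not ColossalAI's actual code; is_local_param is a placeholder for Gemini's check of whether a parameter is kept on the local device):

```python
import torch

def init_optimizer_original_sketch(param_groups, is_local_param):
    """Sketch of the original behaviour: real parameters are replaced by dummy
    'fake' parameters, and non-local parameters are dropped entirely."""
    for group in param_groups:
        fake_param_list = []
        for real_param in group['params']:
            if not is_local_param(real_param):
                continue  # parameters stored on other devices are wiped out
            # a dummy tensor stands in for the chunk-managed parameter
            fake_param_list.append(torch.nn.Parameter(torch.empty(0)))
        # the original ordering/ID information of param_groups is lost here
        group['params'] = fake_param_list
```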
To address this issue, we maintain several member variables in the ZeroOptimizer class; a sketch of these variables, and of the for loop that fills them, is given below.
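The following is an illustrative skeleton under the naming used in this post, not the real ColossalAI ZeroOptimizer; is_local_param is again a placeholder:

```python
import torch


class ZeroOptimizerSketch:
    """Illustrative skeleton mirroring the bookkeeping described in this post."""

    def __init__(self, optim: torch.optim.Optimizer, is_local_param):
        self.optim = optim
        # integer parameter ID -> original parameter object (known on every rank)
        self.id_to_real_params = {}
        # integer parameter ID -> fake (dummy) parameter, only for local parameters
        self.id_to_fake_params = {}
        # backup of the original param_groups: hyperparameters plus ordered ID lists
        self.param_groups_backup = []
        self.__init__optimizer(is_local_param)

    def __init__optimizer(self, is_local_param):
        param_id = 0
        for group in self.optim.param_groups:
            # keep the hyperparameters, but record parameters by integer ID
            group_backup = {k: v for k, v in group.items() if k != 'params'}
            group_backup['params'] = []
            fake_param_list = []
            for real_param in group['params']:
                # IDs follow the order of appearance in the original param_groups
                self.id_to_real_params[param_id] = real_param
                group_backup['params'].append(param_id)
                if is_local_param(real_param):
                    # dummy tensor standing in for the chunk-managed parameter
                    fake_param = torch.nn.Parameter(torch.empty(0))
                    fake_param_list.append(fake_param)
                    self.id_to_fake_params[param_id] = fake_param
                param_id += 1
            group['params'] = fake_param_list
            self.param_groups_backup.append(group_backup)
```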
In the for loop of the modified self.__init__optimizer() method (see the sketch above), each parameter is traversed in the order of the original param_groups. We can naturally obtain their IDs and the mapping from ID to parameter object (recorded in self.id_to_real_params). Meanwhile, the information of param_groups is backed up in self.param_groups_backup. If the current parameter is added to the fake_param_list of the current process, a mapping from its ID to its fake_param object is added to self.id_to_fake_params accordingly.
In this way, we can conveniently check whether a parameter is managed by the current process with `param_id in self.id_to_fake_params`, and fetch the fake parameter or real parameter object corresponding to an integer ID with `self.id_to_fake_params[param_id]` or `self.id_to_real_params[param_id]`. Thus any necessary information about a parameter can be obtained from its parameter ID.
Method of Collecting Optimizer States
As mentioned in the second issue above, the optimizer states of the same parameter can be distributed among different devices. To obtain the complete optimizer states before saving them to a checkpoint, we designate the device with rank 0 as the manager that gathers the state shards and writes them to disk. To implement this idea, we design a method for collecting shards of optimizer states, described below.
In this method, we first use the variable is_collector to check whether the current rank needs to collect the complete states (by default only the master rank does). Then the state shards on the local device are packed into a compacted tensor and exchanged among ranks with the torch.distributed.all_gather_object API. After all ranks have received the complete information of the optimizer states, the ranks whose is_collector is True update the collected states and return.
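A minimal sketch of this gathering step, assuming the shards are exchanged as Python dictionaries of flattened CPU tensors via torch.distributed.all_gather_object rather than packed into one compacted tensor (the helper name and shard layout are assumptions, not ColossalAI's actual method):

```python
import torch
import torch.distributed as dist

def collect_states_sketch(local_shard: dict, master_rank: int = 0) -> dict:
    """Sketch: gather the per-rank shards of one parameter's optimizer states
    and assemble the complete states on the collector rank.

    local_shard is assumed to map state names (e.g. 'exp_avg') to the flattened
    CPU slice of that state held on the local device.
    """
    is_collector = dist.get_rank() == master_rank

    # Every rank contributes its shard; all_gather_object fills `gathered`
    # with one entry per rank, ordered by rank.
    gathered = [None] * dist.get_world_size()
    dist.all_gather_object(gathered, local_shard)

    collected = {}
    if is_collector:
        for name in local_shard:
            # Concatenate the slices in rank order to rebuild the full state tensor.
            collected[name] = torch.cat([shard[name] for shard in gathered])
    return collected
```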
Sharding
The sharding feature demands that the checkpoint of optimizer states be distributed across different files (usually of limited size) under the same folder. As mentioned in the third issue above, the checkpoint should be agnostic to the device configuration, so we shouldn't assign a separate checkpointing folder to each device.
Since Gemini has already implemented the sharding feature for model checkpointing, we can imitate its design, as sketched below.
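One way such a design could look for optimizer states: buffer the collected per-parameter states and flush them to a new file whenever a size limit is reached. The function, file naming, and size threshold below are illustrative assumptions, not ColossalAI's actual API:

```python
import os
import torch

def save_sharded_optimizer_states_sketch(id_to_states: dict, folder: str,
                                          max_shard_size: int = 1 << 30) -> None:
    """Sketch: write optimizer states into several size-limited files under one
    folder, independent of how many devices produced the states."""
    os.makedirs(folder, exist_ok=True)
    shard, shard_size, shard_index = {}, 0, 0
    for param_id, states in id_to_states.items():
        shard[param_id] = states
        shard_size += sum(t.numel() * t.element_size()
                          for t in states.values() if torch.is_tensor(t))
        if shard_size >= max_shard_size:
            torch.save(shard, os.path.join(folder, f"optim_shard_{shard_index}.bin"))
            shard, shard_size, shard_index = {}, 0, shard_index + 1
    if shard:  # flush whatever is left over
        torch.save(shard, os.path.join(folder, f"optim_shard_{shard_index}.bin"))
```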