[Unified Checkpoint] Fix tie_weights save and load (PaddlePaddle#8137)
* fix unified checkpoint tie_weight

* unify static2struct_name_mappings

* fix chatglm model

* fix trainer

* fix attr

* simplify code

* add single card config save
DesmonDay authored Mar 19, 2024
1 parent 99802de commit b6dcb4e
Showing 3 changed files with 16 additions and 1 deletion.
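For context on the fix: models such as ChatGLM tie their output projection to the input embedding, so a single tensor can appear in state_dict() under two structured names. The sketch below is illustrative only, not PaddleNLP's model code; it shows how the duplicate key arises and why dropping it before saving loses no information.

```python
# Minimal weight-tying sketch (illustrative; not PaddleNLP's actual model code).
import paddle
from paddle import nn

class TinyTiedLM(nn.Layer):
    # Structured name of the duplicated parameter, in the spirit of _tied_weights_keys.
    _tied_weights_keys = ["decoder_weight"]

    def __init__(self, vocab_size=100, hidden=16):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden)
        # Tie the output projection to the input embedding: one tensor, two names.
        self.decoder_weight = self.embedding.weight

    def forward(self, hidden_states):
        # Project hidden states back to the vocabulary with the tied matrix.
        return paddle.matmul(hidden_states, self.decoder_weight, transpose_y=True)

model = TinyTiedLM()
state_dict = model.state_dict()
# The shared tensor can show up under both "embedding.weight" and "decoder_weight";
# dropping the tied key keeps the checkpoint from storing the same data twice.
for key in model._tied_weights_keys:
    state_dict.pop(key, None)
```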
14 changes: 14 additions & 0 deletions paddlenlp/trainer/plugins/unified_checkpoint.py
@@ -902,6 +902,8 @@ def save_single_card_checkpoint(args, model_to_save, output_dir):
         sharded_index_json["type"] = "lora"
     elif isinstance(model_to_save, PrefixModelForCausalLM):
         sharded_index_json["type"] = "ptuning"
+
+    os.makedirs(output_dir, exist_ok=True)
     path = os.path.join(output_dir, index_filename)
     with open(path, "w") as f:
         json.dump(sharded_index_json, f, indent=4)
@@ -911,6 +913,13 @@ def save_single_card_checkpoint(args, model_to_save, output_dir):
 
     if isinstance(model_to_save, PrefixModelForCausalLM):
         save_prefix_past_key_value(model_to_save, output_dir)
+        model_to_save.prefix_config.save_pretrained(output_dir)
+    if isinstance(model_to_save, LoRAModel):
+        model_to_save.lora_config.save_pretrained(output_dir)
+
+    config_to_save = save_config(model_to_save)
+    config_to_save.architectures = [model_to_save.__class__.__name__]
+    config_to_save.save_pretrained(output_dir)
 
 
 def save_single_card_optimizer(args, model, optimizer, output_dir):
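This hunk is the "add single card config save" item from the commit message: besides the weights and the sharded index, a single-card checkpoint now also carries the adapter config (prefix tuning or LoRA) and the model config with architectures set to the concrete class, so the directory can be reloaded on its own. A rough sketch of the architectures part follows; the file name and config shape are indicative, and the real work is done by the save_pretrained calls above.

```python
# Rough sketch: record which class produced the checkpoint (illustrative only).
import json
import os

def sketch_save_config(model, output_dir, config: dict) -> None:
    os.makedirs(output_dir, exist_ok=True)
    # A generic loader can later dispatch on this class name when rebuilding the model.
    config = dict(config, architectures=[model.__class__.__name__])
    with open(os.path.join(output_dir, "config.json"), "w") as f:
        json.dump(config, f, indent=4)
```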
@@ -1003,6 +1012,11 @@ def get_expected_state_dict(model_to_save):
         state_dict = model_to_save.get_trainable_state_dict()
     elif isinstance(model_to_save, PrefixModelForCausalLM):
         state_dict = model_to_save.prefix_encoder.state_dict()
+
+    if hasattr(model_to_save, "_tied_weights_keys") and model_to_save._tied_weights_keys is not None:
+        for key in model_to_save._tied_weights_keys:
+            if key in state_dict:
+                state_dict.pop(key)
     return state_dict


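This is the save-side half of the tie_weights fix: keys listed in the model's _tied_weights_keys are removed from the expected state dict before the checkpoint is written, and the hasattr / is-not-None guard leaves models without tied weights untouched. A self-contained toy round trip of the idea, using plain paddle.save and paddle.load rather than the unified-checkpoint I/O:

```python
# Toy round trip showing why the tied key can be dropped (illustrative only).
import paddle

embedding = paddle.nn.Embedding(100, 16)
tied_head_weight = embedding.weight  # "tied": the very same tensor object

# Save only the embedding entry, as get_expected_state_dict now does for tied keys.
paddle.save({"embedding.weight": embedding.weight}, "toy_checkpoint.pdparams")

# Restoring the embedding also restores the tied head, because both names refer to
# one tensor; the checkpoint never needed a second entry for it.
loaded = paddle.load("toy_checkpoint.pdparams")
embedding.weight.set_value(loaded["embedding.weight"])
assert tied_head_weight is embedding.weight
```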
1 change: 1 addition & 0 deletions paddlenlp/trainer/trainer.py
@@ -2457,6 +2457,7 @@ def _load_optimizer_and_scheduler(self, checkpoint):
                 # Load in optimizer and scheduler states
                 self.optimizer.set_state_dict(opt_state_dict)
             else:
+                optimizer_name = _add_variant(OPTIMIZER_NAME, self.args.optimizer_name_suffix)
                 raise ValueError(f"optimizer-state-dict not found, opt: {os.path.join(checkpoint, optimizer_name)}.")
 
         if not self.args.ignore_load_lr_and_optim:
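The single added trainer line defines optimizer_name in the failure branch before it is interpolated into the error message. Read from the diff alone (not verified against the full function), the name was previously not guaranteed to exist on this path, which would surface as a NameError while formatting the message instead of the intended ValueError. The pitfall in isolation:

```python
# Toy illustration (not the trainer code): any name interpolated into an f-string
# must be defined on every branch that can reach it, or the raise itself fails.
import os

def report_missing_optimizer(checkpoint: str) -> None:
    optimizer_name = "optimizer.pdopt"  # hypothetical file name, defined before use
    raise ValueError(f"optimizer-state-dict not found, opt: {os.path.join(checkpoint, optimizer_name)}.")
```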
2 changes: 1 addition & 1 deletion paddlenlp/transformers/chatglm/modeling.py
@@ -789,7 +789,7 @@ def forward(self, hidden_states):
 
 class ChatGLMForCausalLM(ChatGLMPretrainedModel):
     _keys_to_ignore_on_save = [r"lm_head.decoder_weight"]
-    _tied_weights_keys = ["lm_head.weight"]
+    _tied_weights_keys = ["lm_head.decoder_weight"]
 
     def __init__(self, config: ChatGLMConfig):
         super(ChatGLMForCausalLM, self).__init__(config)
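The ChatGLM change points _tied_weights_keys at the parameter's actual structured name, lm_head.decoder_weight, matching the name already listed in _keys_to_ignore_on_save. Because the new filter in get_expected_state_dict only pops keys that are present, a mismatched name such as lm_head.weight is a silent no-op and the tied tensor would still be written. In miniature:

```python
# Why the exact key matters (toy dict; the embedding key name here is hypothetical).
state_dict = {"transformer.word_embeddings.weight": ..., "lm_head.decoder_weight": ...}

if "lm_head.weight" in state_dict:          # wrong key: condition is False, nothing removed
    state_dict.pop("lm_head.weight")
if "lm_head.decoder_weight" in state_dict:  # matching key: the tied duplicate is dropped
    state_dict.pop("lm_head.decoder_weight")
```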