### Is there an existing issue for this bug?

- [x] I have searched the existing issues

### 🐛 Describe the bug
When fine-tuning a base model (Qwen2.5 3B Base) that uses `tie_word_embeddings=True` with GeminiPlugin and saving the checkpoint, I noticed an additional set of weights being saved:

```
"lm_head.weight": "pytorch_model-00004.bin"
```

This leads to the following error when reloading the saved checkpoint:

```
[rank2]: RuntimeError: Error(s) in loading state_dict for GeminiDDP:
[rank2]: Unexpected key(s) in state_dict: "lm_head.weight".
```

Could you provide some advice on how to solve this issue and avoid saving the unneeded weights?
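As a stopgap until the saving behavior is fixed, one possible workaround (a sketch, not an official ColossalAI API) is to drop the tied `lm_head.weight` key from the checkpoint's state dict before loading, since with `tie_word_embeddings=True` that tensor is shared with the input embedding anyway. The helper name below is hypothetical:

```python
def strip_tied_keys(state_dict, tied_keys=("lm_head.weight",)):
    """Return a copy of state_dict without the listed tied-weight keys.

    With tie_word_embeddings=True, lm_head.weight duplicates the input
    embedding weight, so dropping it loses no information.
    """
    return {k: v for k, v in state_dict.items() if k not in tied_keys}


# Usage sketch, assuming `checkpoint` was loaded with torch.load and
# `model` is the GeminiDDP-wrapped model:
#   model.load_state_dict(strip_tied_keys(checkpoint))
```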
### Environment

- Python: 3.11
- colossalai: 0.4.6