[BUG]: Size Mismatch Issue When Loading Model Checkpoints Trained with Tensor Parallel if vocab_size % tp_size != 0
#6167
Description
Is there an existing issue for this bug?
- [x] I have searched the existing issues
🐛 Describe the bug
Describe the bug
A size mismatch error occurs when loading model checkpoints trained with tensor parallel enabled, if the `vocab_size` is not divisible by `tp_size`.
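The 32768-row shard reported in the error below is consistent with the vocab being padded up to the next multiple of `tp_size` and then split evenly across ranks; a minimal arithmetic sketch, assuming that padding behaviour (this is my guess, not taken from the ColossalAI source):

```python
# Minimal sketch of the suspected padding arithmetic (assumption, not verified
# against the ColossalAI sharding code).
import math

vocab_size = 65535
tp_size = 2

padded_vocab = math.ceil(vocab_size / tp_size) * tp_size  # 65536
rows_per_rank = padded_vocab // tp_size                   # 32768, matches the checkpoint shard

print(padded_vocab, rows_per_rank)
```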
To Reproduce
Let's modify the official Llama benchmark to reproduce this with minimal work.
`benchmark.py` (modify the Llama model `vocab_size`):
```python
MODEL_CONFIGS = {
    "100m": LlamaConfig(
        max_position_embeddings=4096,
        num_hidden_layers=4,
        num_attention_heads=32,
        intermediate_size=2048,
        hidden_size=1024,
        vocab_size=65535,  # Note that vocab_size % tp_size != 0
    ),
}
```
`benchmark.py` (add to the end of the `main` function):
```python
# save the checkpoint and load it again
output_dir = './scripts/save'
booster.save_model(model, output_dir, shard=True, size_per_shard=10240)
print('wait 10 secs to ensure ckpts are saved.')
from time import sleep; sleep(10)
model = AutoModelForCausalLM.from_pretrained(  # Note that this will fail
    output_dir,
    trust_remote_code=True,
    **init_kwargs,
    torch_dtype=torch.bfloat16,
)
```
entrypoint:
```bash
export OMP_NUM_THREADS=8
colossalai run --nproc_per_node 8 --hostfile $HOSTFILE benchmark.py \
    --plugin 3d --config 100m --xformers \
    --batch_size 1 --num_steps 5 \
    --grad_checkpoint --zero 1 \
    --tp 2 --pp 1 --mbs 1
```
The script then fails with a RuntimeError after executing `model = AutoModelForCausalLM.from_pretrained()`:
```
[rank0]: RuntimeError: Error(s) in loading state_dict for LlamaForCausalLM:
[rank0]: size mismatch for model.embed_tokens.weight: copying a param with shape torch.Size([32768, 1024]) from checkpoint, the shape in current model is torch.Size([65535, 1024]).
[rank0]: size mismatch for lm_head.weight: copying a param with shape torch.Size([32768, 1024]) from checkpoint, the shape in current model is torch.Size([65535, 1024]).
```
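For what it's worth, the saved shards can be inspected directly to confirm that only TP-sliced rows were written; a hedged sketch, assuming `booster.save_model(shard=True)` produced HF-style `*.safetensors` shards under `./scripts/save` (the file pattern is a guess, adjust if the checkpoint is `*.bin`):

```python
# Hedged inspection sketch: print the stored shapes of the embedding / lm_head
# tensors in the saved shards.
import glob
from safetensors import safe_open

for shard in sorted(glob.glob("./scripts/save/*.safetensors")):
    with safe_open(shard, framework="pt") as f:
        for name in f.keys():
            if "embed_tokens" in name or "lm_head" in name:
                print(shard, name, f.get_slice(name).get_shape())
```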
Others
- No error is reported if we set `vocab_size=65536` (see the padding sketch below).
- No error is reported if we set `--tp 1 --pp 2`.
- A similar error is reported if we set `--tp 2 --pp 2`.
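A possible user-side workaround, consistent with the first observation above (just a sketch that sidesteps the bug rather than fixing the checkpoint logic; `pad_vocab_size` is a hypothetical helper, not part of the benchmark script):

```python
from transformers import LlamaConfig

def pad_vocab_size(vocab_size: int, tp_size: int) -> int:
    # Round up to the next multiple of tp_size so every TP rank gets an equal slice.
    return ((vocab_size + tp_size - 1) // tp_size) * tp_size

config = LlamaConfig(
    max_position_embeddings=4096,
    num_hidden_layers=4,
    num_attention_heads=32,
    intermediate_size=2048,
    hidden_size=1024,
    vocab_size=pad_vocab_size(65535, tp_size=2),  # 65536, divisible by tp_size
)
```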
Environment
colossalai: latest (8b0ed61)
cluster: single node with 8 × H20 GPUs.
Feel free to ask for further environment information (but I think it is probably not crucial to this issue ^_^).