Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add "_prepare_fsdp" for DPOTrainer #2539

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

faaany
Copy link
Contributor

@faaany faaany commented Jan 3, 2025

What does this PR do?

While training with DPOTrainer using FSDP and accelerate, I got the same error as mentioned in #1147. Similar to "_prepare_deepspeed", I fixed the issue by adding a new method called "_prepare_fsdp".

@faaany faaany marked this pull request as ready for review January 3, 2025 02:50
Copy link
Member

@lewtun lewtun left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for the fix @faaany - overall it looks great!

Would you mind confirming that the following demo command works with your PR (once activation checkpointing is removed):

accelerate launch --config_file=examples/accelerate_configs/fsdp_qlora.yaml --num_processes=NUM_GPUS trl/scripts/dpo.py trl/scripts/dpo.py \
    --dataset_name trl-lib/ultrafeedback_binarized \
    --model_name_or_path Qwen/Qwen2-0.5B-Instruct \
    --learning_rate 5.0e-7 \
    --num_train_epochs 1 \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 8 \
    --gradient_checkpointing \
    --logging_steps 25 \
    --eval_strategy steps \
    --eval_steps 50 \
    --output_dir Qwen2-0.5B-DPO \
    --no_remove_unused_columns

If it runs without error, can you please rename fsdp_qlora.yaml to fsdp.yaml so it runs for both modes?

A question for @qgallouedec: should this helper function live in a utils module somewhere so we don't have to copy it around to all other trainers?

trl/trainer/dpo_trainer.py Outdated Show resolved Hide resolved
"device_id": self.accelerator.device,
}
model = FSDP(model, **kwargs)
if fsdp_plugin.activation_checkpointing:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I'm not mistaken, we don't need activation checkpointing since we never do a backward pass with the reference model. I think this block can thus be removed

@gp1702
Copy link

gp1702 commented Jan 7, 2025

I tried running the demo command without qlora, and got the following error:
Traceback (most recent call last): File "/home/gandharvp_google_com/dpo/example.py", line 159, in <module> main(script_args, training_args, model_args) File "/home/gandharvp_google_com//dpo/example.py", line 134, in main trainer.train() File "/opt/conda/envs/trl/lib/python3.10/site-packages/transformers/trainer.py", line 2164, in train return inner_training_loop( File "/opt/conda/envs/trl/lib/python3.10/site-packages/transformers/trainer.py", line 2524, in _inner_training_loop tr_loss_step = self.training_step(model, inputs, num_items_in_batch) File "/opt/conda/envs/trl/lib/python3.10/site-packages/transformers/trainer.py", line 3687, in training_step self.accelerator.backward(loss, **kwargs) File "/opt/conda/envs/trl/lib/python3.10/site-packages/accelerate/accelerator.py", line 2248, in backward loss.backward(**kwargs) File "/home/gandharvp_google_com/.local/lib/python3.10/site-packages/torch/_tensor.py", line 492, in backward torch.autograd.backward( File "/home/gandharvp_google_com/.local/lib/python3.10/site-packages/torch/autograd/__init__.py", line 251, in backward Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass File "/home/gandharvp_google_com/.local/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 1071, in unpack_hook frame.recompute_fn(*args) File "/home/gandharvp_google_com/.local/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 1194, in recompute_fn fn(*args, **kwargs) File "/home/gandharvp_google_com/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/home/gandharvp_google_com/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl return forward_call(*args, **kwargs) File "/opt/conda/envs/trl/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 623, in forward hidden_states, self_attn_weights, present_key_value = self.self_attn( File "/home/gandharvp_google_com/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/home/gandharvp_google_com/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl return forward_call(*args, **kwargs) File "/opt/conda/envs/trl/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 544, in forward attn_output = torch.nn.functional.scaled_dot_product_attention( RuntimeError: The expanded size of the tensor (1460) must match the existing size (730) at non-singleton dimension 3. Target sizes: [4, 14, 730, 1460]. Tensor sizes: [4, 1, 730, 730]

@faaany, I am wondering if you were able to replicate or fix this.

I am attaching the trainer code for reference.
fsdp_dpo_trainer.txt

faaany and others added 2 commits January 8, 2025 10:19
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants