Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove graph breaks for torch.compile() in padding free branch in DataCollatorForCompletionOnlyLM #2158

Merged
merged 29 commits into from
Jan 6, 2025

Conversation

Abhishek-TAMU
Copy link
Contributor

@Abhishek-TAMU Abhishek-TAMU commented Oct 3, 2024

What does this PR do?

This PR adds cu_seq_lens_q, cu_seq_lens_k, max_length_k, max_length_q to the batch in DataCollatorForCompletionOnlyLM. This, together with a PR in transformers (link to be added), removes graph breaks in padding-free tuning, allowing for maximum performance to be obtained.
Specifically, these parameters should be generated here (this PR change), outside of the transformers loop, as they incur a cpu-gpu sync that is unavoidable. Otherwise, this cpu-gpu sync happens here, inside the attention call which causes graph breaks and hence the transformers PR removes this call to remove all graph breaks when torch_compile flag is turned on in Training arguments to use in SFTTrainer.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
@Abhishek-TAMU Abhishek-TAMU changed the title Add Sequence Lengths to Batch in DataCollatorForCompletionOnlyLM Remove graph breaks for torch.compile() in padding free branch in DataCollatorForCompletionOnlyLM Oct 3, 2024
@Abhishek-TAMU Abhishek-TAMU marked this pull request as ready for review October 3, 2024 15:42
@Abhishek-TAMU
Copy link
Contributor Author

CC: @kashif @qgallouedec

@kashif kashif added ✨ enhancement New feature or request 🏋 SFT Related to SFT labels Oct 6, 2024
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
@qgallouedec
Copy link
Member

qgallouedec commented Oct 10, 2024

Hi, thanks for the PR.
Can you provide the link of the PR in transformers? Is it huggingface/transformers#33932?

@qgallouedec
Copy link
Member

Could you provide a simple test to:

  1. Confirm that it is a case of non-functioning.
  2. Verify that this addition resolves it.

It might also be helpful to add a few comments, as these lines are unclear without context.

@qgallouedec qgallouedec added 🐛 bug Something isn't working and removed ✨ enhancement New feature or request labels Oct 10, 2024
Abhishek-TAMU and others added 9 commits October 14, 2024 13:00
Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
@Abhishek-TAMU
Copy link
Contributor Author

Thank you @qgallouedec for the review. This is the related transformers PR which is approved and merged.

I added 2 test cases. One where Tuning fails with Padding and another where doesn't fail without padding.

@Abhishek-TAMU
Copy link
Contributor Author

@kashif @qgallouedec Could you possibly review this PR ? Thank you!

@Abhishek-TAMU
Copy link
Contributor Author

Abhishek-TAMU commented Nov 12, 2024

Hi @kashif @qgallouedec, could you please take another look at this PR when you get the chance? The changes in this PR are urgent for making torch_compile flag in SFTTrainer work for Llama models (LlamaForCausalLM). This is important for users who need to compile the Llama model using SFTTrainer (in padding_free mode) without any graph breaks. Thank you!

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
Comment on lines 664 to 666
formatted_dataset = lambda example: {
"output": f"### prompt:\n{example['prompt'].strip()}\n\n### completion:\n{example['completion'].strip()}{tokenizer.eos_token}"
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is dataset formatting required here, or can we drop it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dataset formatting is required because the SFTTrainer and DataCollatorForCompletionOnlyLM expect the dataset to have a specific format—a single text field that combines both the prompt and the completion in a way the model can understand. This function includes both the prompt and completion, ensuring the data collator can correctly identify where the completion starts using the response_template.

@@ -654,6 +654,50 @@ def test_data_collator_completion_lm_with_multiple_text(self):
result_text = tokenizer.decode(batch["input_ids"][i, last_pad_idx + 1 :])
self.assertEqual(result_text, "I have not been masked correctly.")

def test_data_collator_completion_lm_without_padding(self):
os.environ["CUDA_VISIBLE_DEVICES"]="0"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the issue only occur with cuda device? In other words can we reproduce on cpu?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Due to usage of flash_attention_2 it would work only on GPU.

@qgallouedec
Copy link
Member

Hey @Abhishek-TAMU, to keep you posted with the current status of the PR, I am struggling reproducing the initial error. Do you have a MRE by any chance? The code from the unittest gives

...
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 134, in __init__
    assert isinstance(
AssertionError: expected FunctionType found _lru_cache_wrapper <functools._lru_cache_wrapper object at 0x7f000e67fb60>

from user code:
   File "/fsx/qgallouedec/transformers/src/transformers/models/llama/modeling_llama.py", line 1224, in torch_dynamo_resume_in_forward_at_1199
    loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

  0%|          | 0/2 [00:07<?, ?it/s]  

and it doesn't seem related

@Abhishek-TAMU
Copy link
Contributor Author

Thank you @qgallouedec for looking into this. Sharing you the code which would produce graph break.
Using latest release version of transformers (which doesn't have huggingface/transformers#33932 changes) and latest trl including changes from this PR.

If this change huggingface/transformers#33932 is used in transformers then Graph break could be avoided.

import os, tempfile, torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer)
from trl import SFTConfig, SFTTrainer
from trl.trainer import DataCollatorForCompletionOnlyLM
from datasets import load_dataset

standard_prompt_completion_dataset = load_dataset(
    "trl-internal-testing/zen", "standard_prompt_completion"
)

os.environ["CUDA_VISIBLE_DEVICES"]="0"
os.environ["CUDA_HOME"]="/home/tuning/.local/cuda-12.1"
model_id = "trl-internal-testing/tiny-random-LlamaForCausalLM"
torch_dtype = getattr(torch, "bfloat16", None)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch_dtype, attn_implementation="flash_attention_2")
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)

formatted_dataset = lambda example: {
"output": f"### prompt:\n{example['prompt'].strip()}\n\n### completion:\n{example['completion'].strip()}{tokenizer.eos_token}"
}

train_dataset = standard_prompt_completion_dataset["train"].map(formatted_dataset)

response_template = "### completion:\n"
data_collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer, padding_free=True)

with tempfile.TemporaryDirectory() as tmp_dir:
    training_args = SFTConfig(
        output_dir=tmp_dir,
        dataloader_drop_last=True,
        max_steps=2,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=1,
        save_steps=2,
        learning_rate=1e-5,
        dataset_text_field="output",
        torch_compile=True,
        torch_compile_backend="inductor",
        torch_compile_mode="default"
    )

    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        data_collator=data_collator,
        args=training_args,
    )

    # with assertRaises(Exception):
    trainer.train()
    del os.environ["CUDA_VISIBLE_DEVICES"]

@Abhishek-TAMU
Copy link
Contributor Author

Hi @qgallouedec, were you able to reproduce the initial error with this MRE ?

@ArthurZucker
Copy link

The loss error was fixed in transformers

@ArthurZucker
Copy link

LGTM!

@@ -114,7 +114,7 @@ def test_padding_free(self):
inst1 = "### System: You are a helpful assistant.\n\n### User: How much is 2+2?\n\n### Assistant: 2+2 equals 4"
inst2 = "### System: You are a honest and helpful assistant.\n\n### User: What is the answer of 22x22?\n\n### Assistant: 22x22 equals 484"

response_template = "\n### Assistant:"
response_template = "\n\n### Assistant:"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

otherwise the template isn't found (\n\n is jointly tokenized)

@qgallouedec
Copy link
Member

Thanks and sorry for the delay

@qgallouedec qgallouedec merged commit d9ee2fd into huggingface:main Jan 6, 2025
13 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
🐛 bug Something isn't working 🏋 SFT Related to SFT
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants