
Remove graph breaks for torch.compile() in padding free branch in DataCollatorForCompletionOnlyLM #2158
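The title describes the fix at a high level: in the padding-free branch, the collator hands the model packed sequences, and any sequence-length metadata that arrives as a tensor forces a `.item()` call inside the compiled forward pass, which `torch.compile()` answers with a graph break. Below is a minimal sketch of the idea, not the PR's actual code: the field names (`cu_seq_lens_q`, `cu_seq_lens_k`, `max_length_q`, `max_length_k`) are assumed from the kwargs `flash_attention_2` accepts in recent transformers releases, and `add_seq_len_info` is a hypothetical helper.

```python
import torch
import torch.nn.functional as F

def add_seq_len_info(batch: dict) -> dict:
    """Attach flash-attention metadata to a padding-free batch (illustrative only)."""
    position_ids = batch["position_ids"]  # shape (1, total_tokens), packed sequences
    total = position_ids.shape[1]
    # Each packed sequence restarts its position_ids at 0.
    starts = (position_ids[0] == 0).nonzero().flatten()
    seq_lens = torch.diff(starts, append=torch.tensor([total]))
    # Cumulative sequence lengths prefixed with 0, as int32 (what flash-attn expects).
    cu_seq_lens = F.pad(seq_lens.cumsum(0), (1, 0)).to(torch.int32)
    batch["cu_seq_lens_q"] = batch["cu_seq_lens_k"] = cu_seq_lens
    # Plain Python ints, not 0-dim tensors: a tensor here would require .item()
    # inside the model's forward pass and break the torch.compile() graph
    # (compare the "fix: max_length_k to int" commit in the list below).
    batch["max_length_q"] = batch["max_length_k"] = int(seq_lens.max())
    return batch
```

Doing this conversion in the collator works because the collator runs eagerly, outside the compiled region, so the `int()` call there costs nothing at compile time.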

Status: Merged (29 commits, merged Jan 6, 2025)

Changes from 1 commit.

Commits (29):
4472501  feat: Add info to batch in DataCollatorForCompletionOnlyLM  (Abhishek-TAMU, Oct 2, 2024)
6cfa171  fix: formatting  (Abhishek-TAMU, Oct 2, 2024)
a821ce0  feat: Add info to batch in DataCollatorForCompletionOnlyLM  (Abhishek-TAMU, Oct 2, 2024)
fb669b6  fix: formatting  (Abhishek-TAMU, Oct 2, 2024)
f4b1955  Merge branch 'huggingface:main' into collator_batch  (Abhishek-TAMU, Oct 14, 2024)
1b7c060  Merge branch 'collator_batch' of github.com:Abhishek-TAMU/trl into co…  (Abhishek-TAMU, Oct 21, 2024)
c3578f8  Merge branch 'main' into collator_batch  (Abhishek-TAMU, Oct 21, 2024)
e83fc8a  fix: max_length_k to int  (Abhishek-TAMU, Oct 21, 2024)
68554b1  fix:Added comments  (Abhishek-TAMU, Oct 21, 2024)
2a7dd47  Merge remote-tracking branch 'trl/main' into collator_batch  (Abhishek-TAMU, Oct 30, 2024)
b0a52e2  test cases  (Abhishek-TAMU, Oct 30, 2024)
054a6ef  test cases  (Abhishek-TAMU, Oct 30, 2024)
376ad21  test cases  (Abhishek-TAMU, Oct 30, 2024)
9a08ea3  Merge remote-tracking branch 'trl/main' into collator_batch  (Abhishek-TAMU, Nov 12, 2024)
a97045b  feat: Add info to batch in DataCollatorForCompletionOnlyLM  (Abhishek-TAMU, Oct 2, 2024)
f31a780  fix: formatting  (Abhishek-TAMU, Oct 2, 2024)
29ba8a3  feat: Add info to batch in DataCollatorForCompletionOnlyLM  (Abhishek-TAMU, Oct 2, 2024)
d1441e1  test cases  (Abhishek-TAMU, Oct 30, 2024)
d55a6e2  test cases  (Abhishek-TAMU, Oct 30, 2024)
7dccc2d  test cases  (Abhishek-TAMU, Oct 30, 2024)
5e5224e  unit test changes  (Abhishek-TAMU, Nov 12, 2024)
1b434b0  unit test changes  (Abhishek-TAMU, Nov 12, 2024)
ef1e304  Merge remote-tracking branch 'trl/main' into collator_batch  (Abhishek-TAMU, Nov 18, 2024)
77894b1  style  (qgallouedec, Nov 19, 2024)
911f60c  Merge branch 'main' into collator_batch  (qgallouedec, Nov 19, 2024)
979f9f0  Merge branch 'main' into collator_batch  (qgallouedec, Dec 18, 2024)
cebf936  Merge branch 'main' into collator_batch  (qgallouedec, Jan 6, 2025)
ca8e153  add test  (qgallouedec, Jan 6, 2025)
8c27e16  remove test  (qgallouedec, Jan 6, 2025)
Commit d55a6e280711f5adc05a198ac193bd660f32e91d ("test cases")
Committed by Abhishek-TAMU on Nov 12, 2024
Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
tests/test_sft_trainer.py (2 additions, 0 deletions)

@@ -655,6 +655,7 @@ def test_data_collator_completion_lm_with_multiple_text(self):
         self.assertEqual(result_text, "I have not been masked correctly.")

     def test_data_collator_completion_lm_with_padding(self):
+        os.environ["CUDA_VISIBLE_DEVICES"]="0"
         model_id = "trl-internal-testing/tiny-random-LlamaForCausalLM"
         torch_dtype = getattr(torch, "bfloat16", None)
         model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch_dtype, attn_implementation="flash_attention_2")
@@ -696,6 +697,7 @@ def test_data_collator_completion_lm_with_padding(self):
         trainer.train()

     def test_data_collator_completion_lm_without_padding(self):
+        os.environ["CUDA_VISIBLE_DEVICES"]="0"
Review comment (Member), on the added line:

Does the issue only occur with a CUDA device? In other words, can we reproduce it on CPU?

Reply (Contributor, PR author):

Due to the usage of flash_attention_2, it would work only on GPU.
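Since these tests require a GPU, a common way to keep them from failing on CPU-only runners is a skip guard. A sketch, assuming the `require_torch_gpu` and `require_flash_attn` decorators from `transformers.testing_utils`; the PR itself may handle this differently:

```python
# Sketch only: skip the test unless a GPU and flash-attn are available.
from transformers.testing_utils import require_flash_attn, require_torch_gpu

@require_flash_attn
@require_torch_gpu
def test_data_collator_completion_lm_with_padding(self):
    ...
```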

         model_id = "trl-internal-testing/tiny-random-LlamaForCausalLM"
         torch_dtype = getattr(torch, "bfloat16", None)
         model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch_dtype, attn_implementation="flash_attention_2")
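For anyone verifying the fix locally, `torch.compile()` can be asked to fail loudly on the first graph break. A minimal sketch, assuming `model` and `batch` are prepared as in the tests above:

```python
import torch

# fullgraph=True makes torch.compile() raise at the first graph break instead
# of silently splitting the graph, so a clean forward pass is direct evidence
# that the padding-free branch compiles without breaks.
compiled_model = torch.compile(model, fullgraph=True)
with torch.no_grad():
    outputs = compiled_model(**batch)
```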