Skip to content

Commit

Permalink
SFTTrainer: Fix backward Compatibility issue with TrainingArguments (
Browse files Browse the repository at this point in the history
…huggingface#1707)

* fix BC

* fixup
  • Loading branch information
younesbelkada authored Jun 6, 2024
1 parent 0bdc638 commit 39a7d1c
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 0 deletions.
26 changes: 26 additions & 0 deletions tests/test_sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
AutoProcessor,
AutoTokenizer,
LlavaForConditionalGeneration,
TrainingArguments,
)

from trl import SFTConfig, SFTTrainer
Expand Down Expand Up @@ -213,6 +214,31 @@ def test_constant_length_dataset(self):
decoded_text = self.tokenizer.decode(example["input_ids"])
assert ("Question" in decoded_text) and ("Answer" in decoded_text)

def test_sft_trainer_backward_compatibility(self):
    """Check that `SFTTrainer` still works when given plain `TrainingArguments`.

    The trainer is built with a `transformers.TrainingArguments` instance
    instead of an `SFTConfig`, trained for a few steps, and then checked
    for logged train/eval losses and a saved checkpoint.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Deliberately NOT an SFTConfig: this is the legacy call style
        # whose support this test guards.
        training_args = TrainingArguments(
            output_dir=tmp_dir,
            eval_strategy="steps",
            max_steps=4,
            eval_steps=2,
            save_steps=2,
            per_device_train_batch_size=2,
        )

        trainer = SFTTrainer(
            model=self.model_id,
            args=training_args,
            train_dataset=self.train_dataset,
            eval_dataset=self.eval_dataset,
        )

        trainer.train()

        log_history = trainer.state.log_history
        # Last entry is expected to carry the training summary, and the
        # first one an evaluation result (eval_steps=2).
        assert log_history[-1]["train_loss"] is not None
        assert log_history[0]["eval_loss"] is not None

        # save_steps=2 should have produced a checkpoint at step 2; build
        # the path portably instead of concatenating with "/".
        checkpoint_dir = os.path.join(tmp_dir, "checkpoint-2")
        assert "model.safetensors" in os.listdir(checkpoint_dir)

def test_sft_trainer(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = SFTConfig(
Expand Down
2 changes: 2 additions & 0 deletions trl/trainer/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ def __init__(
output_dir = "tmp_trainer"
warnings.warn(f"No `SFTConfig` passed, using `output_dir={output_dir}`.")
args = SFTConfig(output_dir=output_dir)
elif args is not None and args.__class__.__name__ == "TrainingArguments":
args = SFTConfig(**args.to_dict())

if model_init_kwargs is not None:
warnings.warn(
Expand Down

0 comments on commit 39a7d1c

Please sign in to comment.