[DPO] DPOConfig class (huggingface#1554)
* initial DPOConfig

* fix doc string

* use DPOConfig

* fix missing import

* fix DpoScriptArguments

* override args config when given in init

* use DPOConfig

* fix output dir name

* override with deprecated arguments if given

* use DPOConfig in tests

* fix comment

* add custom_message

* use dataset_train_name and dataset_test_name

* beta is also in the training_args

* fix loss_type docs

* Update trl/commands/cli_utils.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/commands/cli_utils.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/commands/cli_utils.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* use DPOScriptArguments

---------

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
kashif and lewtun authored Apr 23, 2024
1 parent c050ebc commit 24fd8dd
Showing 11 changed files with 390 additions and 178 deletions.
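
In short: DPO-specific hyperparameters that were previously passed directly to `DPOTrainer` (such as `beta`, `loss_type`, `max_length`, and `precompute_ref_log_probs`) now live on the new `DPOConfig`, which replaces the plain `TrainingArguments` in the examples and tests below. A minimal sketch of the new usage (the values are illustrative, not library defaults):

```python
from trl import DPOConfig

# Hyperparameters that used to be DPOTrainer keyword arguments are now
# DPOConfig fields (illustrative values, not library defaults):
training_args = DPOConfig(
    output_dir="./output",
    beta=0.1,
    loss_type="sigmoid",
    max_length=512,
    precompute_ref_log_probs=False,
)

# The trainer then only receives the config (model and dataset objects omitted here):
# DPOTrainer(model, ref_model, args=training_args,
#            train_dataset=train_dataset, tokenizer=tokenizer)
```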
24 changes: 17 additions & 7 deletions docs/source/dpo_trainer.mdx
@@ -78,11 +78,13 @@ The DPO trainer expects a model of `AutoModelForCausalLM`, compared to PPO that
For a detailed example, have a look at the `examples/scripts/dpo.py` script. At a high level, we need to initialize the `DPOTrainer` with the `model` we wish to train and a reference `ref_model`, which is used to calculate the implicit rewards of the preferred and rejected responses; `beta` is the hyperparameter of the implicit reward, and the dataset must contain the 3 entries listed above. Note that the `model` and `ref_model` need to have the same architecture (i.e. decoder-only or encoder-decoder).

```py
training_args = DPOConfig(
beta=0.1,
)
dpo_trainer = DPOTrainer(
model,
model_ref,
args=training_args,
beta=0.1,
train_dataset=train_dataset,
tokenizer=tokenizer,
)
@@ -131,8 +133,7 @@ First install `unsloth` according to the [official documentation](https://github

```python
import torch
from transformers import TrainingArguments
from trl import DPOTrainer
from trl import DPOConfig, DPOTrainer
from unsloth import FastLanguageModel

max_seq_length = 2048 # Supports automatic RoPE Scaling, so choose any number.
@@ -159,13 +160,15 @@ model = FastLanguageModel.get_peft_model(
random_state = 3407,
)

training_args = TrainingArguments(output_dir="./output")
training_args = DPOConfig(
output_dir="./output",
beta=0.1,
)

dpo_trainer = DPOTrainer(
model,
ref_model=None,
args=training_args,
beta=0.1,
train_dataset=train_dataset,
tokenizer=tokenizer,
)
@@ -224,14 +227,21 @@ model = PeftModel.from_pretrained(
model.load_adapter("/path/to/peft", adapter_name="reference")

# Initialize the trainer, without a ref_model param.
training_args = DPOConfig(
model_adapter_name="train",
ref_adapter_name="reference",
)
dpo_trainer = DPOTrainer(
model,
args=training_args,
...
model_adapter_name="train",
ref_adapter_name="reference",
)
```

## DPOTrainer

[[autodoc]] DPOTrainer

## DPOConfig

[[autodoc]] DPOConfig
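
The diffs for the new `DPOConfig` class and for `DPOTrainer` itself are not rendered on this page, but the commit message ("override args config when given in init", "override with deprecated arguments if given") indicates that the trainer still accepts the old keyword arguments for backward compatibility and copies them onto the config with a warning. A minimal sketch of that pattern (the class names, fields, and warning text below are assumptions for illustration, not TRL's actual implementation):

```python
import warnings
from dataclasses import dataclass
from typing import Optional


@dataclass
class DPOConfigSketch:
    """Stand-in for trl.DPOConfig with only the fields this sketch needs."""
    beta: float = 0.1
    max_length: int = 512


class DPOTrainerSketch:
    """Stand-in for trl.DPOTrainer: deprecated init kwargs override the config."""

    def __init__(
        self,
        args: DPOConfigSketch,
        beta: Optional[float] = None,
        max_length: Optional[int] = None,
    ):
        if beta is not None:
            warnings.warn("Pass `beta` via DPOConfig instead.", DeprecationWarning)
            args.beta = beta
        if max_length is not None:
            warnings.warn("Pass `max_length` via DPOConfig instead.", DeprecationWarning)
            args.max_length = max_length
        self.args = args


trainer = DPOTrainerSketch(DPOConfigSketch(), beta=0.2)
assert trainer.args.beta == 0.2  # the deprecated kwarg wins over the config default
```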
17 changes: 7 additions & 10 deletions examples/scripts/dpo.py
@@ -49,14 +49,15 @@
--lora_r=16 \
--lora_alpha=16
"""

import logging
import multiprocessing
import os
from contextlib import nullcontext

TRL_USE_RICH = os.environ.get("TRL_USE_RICH", False)

from trl.commands.cli_utils import DpoScriptArguments, init_zero_verbose, TrlParser
from trl.commands.cli_utils import DPOScriptArguments, init_zero_verbose, TrlParser

if TRL_USE_RICH:
init_zero_verbose()
@@ -67,9 +68,10 @@

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from transformers import AutoModelForCausalLM, AutoTokenizer

from trl import (
DPOConfig,
DPOTrainer,
ModelConfig,
RichProgressCallback,
@@ -84,7 +86,7 @@


if __name__ == "__main__":
parser = TrlParser((DpoScriptArguments, TrainingArguments, ModelConfig))
parser = TrlParser((DPOScriptArguments, DPOConfig, ModelConfig))
args, training_args, model_config = parser.parse_args_and_config()

# Force use our print callback
@@ -155,8 +157,8 @@ def process(row):
num_proc=multiprocessing.cpu_count(),
load_from_cache_file=False,
)
train_dataset = ds["train"]
eval_dataset = ds["test"]
train_dataset = ds[args.dataset_train_split]
eval_dataset = ds[args.dataset_test_split]

################
# Training
@@ -166,14 +168,9 @@ def process(row):
model,
model_ref,
args=training_args,
beta=args.beta,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
max_length=args.max_length,
max_target_length=args.max_target_length,
max_prompt_length=args.max_prompt_length,
generate_during_eval=args.generate_during_eval,
peft_config=get_peft_config(model_config),
callbacks=[RichProgressCallback] if TRL_USE_RICH else None,
)
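
Because `DPOConfig` is now one of the dataclasses parsed by `TrlParser`, its fields (for example `beta`, `loss_type`, `max_length`) can be supplied on the command line or via a config file rather than being hard-coded in the script, and the dataset split names are configurable. A minimal sketch of the parsing step, mirroring the updated script above (run it with at least `--output_dir`; the prints are only for illustration):

```python
from trl import DPOConfig, ModelConfig
from trl.commands.cli_utils import DPOScriptArguments, TrlParser

if __name__ == "__main__":
    # Same three dataclasses as examples/scripts/dpo.py after this commit.
    parser = TrlParser((DPOScriptArguments, DPOConfig, ModelConfig))
    args, training_args, model_config = parser.parse_args_and_config()
    # model_config carries the model options, e.g. --model_name_or_path and LoRA settings.

    # Dataset split names are no longer hard-coded to "train" / "test".
    print(args.dataset_train_split, args.dataset_test_split)

    # DPO hyperparameters now live on the config rather than on the trainer call.
    print(training_args.beta, training_args.loss_type)
```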
38 changes: 19 additions & 19 deletions tests/slow/test_dpo_slow.py
@@ -20,9 +20,9 @@
from accelerate.utils.memory import release_memory
from datasets import load_dataset
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

from trl import DPOTrainer, is_peft_available
from trl import DPOConfig, DPOTrainer, is_peft_available

from ..testing_utils import require_bitsandbytes, require_peft, require_torch_gpu
from .testing_constants import DPO_LOSS_TYPES, DPO_PRECOMPUTE_LOGITS, GRADIENT_CHECKPOINTING_KWARGS, MODELS_TO_TEST
@@ -60,7 +60,7 @@ def test_dpo_bare_model(self, model_id, loss_type, pre_compute_logits):
tokenizer = AutoTokenizer.from_pretrained(model_id)

with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
training_args = DPOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=2,
@@ -71,20 +71,20 @@ def test_dpo_bare_model(self, model_id, loss_type, pre_compute_logits):
fp16=True,
logging_strategy="no",
report_to="none",
beta=0.1,
loss_type=loss_type,
precompute_ref_log_probs=pre_compute_logits,
max_length=self.max_length,
)

# dpo train lora model
trainer = DPOTrainer(
model=model,
ref_model=None,
beta=0.1,
args=training_args,
tokenizer=tokenizer,
train_dataset=self.dataset,
eval_dataset=self.dataset,
loss_type=loss_type,
precompute_ref_log_probs=pre_compute_logits,
max_length=self.max_length,
)

# train the model
@@ -114,7 +114,7 @@ def test_dpo_peft_model(self, model_id, loss_type, pre_compute_logits, gradient_
tokenizer = AutoTokenizer.from_pretrained(model_id)

with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
training_args = DPOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=2,
@@ -127,22 +127,22 @@ def test_dpo_peft_model(self, model_id, loss_type, pre_compute_logits, gradient_
report_to="none",
gradient_checkpointing=True,
gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
generate_during_eval=False,
loss_type=loss_type,
precompute_ref_log_probs=pre_compute_logits,
beta=0.1,
max_length=self.max_length,
)

# dpo train lora model
trainer = DPOTrainer(
model=model,
ref_model=None,
beta=0.1,
args=training_args,
tokenizer=tokenizer,
train_dataset=self.dataset,
eval_dataset=self.dataset,
generate_during_eval=False,
loss_type=loss_type,
precompute_ref_log_probs=pre_compute_logits,
peft_config=self.peft_config,
max_length=self.max_length,
)

assert isinstance(trainer.model, PeftModel)
@@ -178,7 +178,7 @@ def test_dpo_peft_model_qlora(self, model_id, loss_type, pre_compute_logits, gra
tokenizer = AutoTokenizer.from_pretrained(model_id)

with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
training_args = DPOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=2,
@@ -191,22 +191,22 @@ def test_dpo_peft_model_qlora(self, model_id, loss_type, pre_compute_logits, gra
report_to="none",
gradient_checkpointing=True,
gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
beta=0.1,
generate_during_eval=False,
loss_type=loss_type,
precompute_ref_log_probs=pre_compute_logits,
max_length=self.max_length,
)

# dpo train lora model
trainer = DPOTrainer(
model=model,
ref_model=None,
beta=0.1,
args=training_args,
tokenizer=tokenizer,
train_dataset=self.dataset,
eval_dataset=self.dataset,
generate_during_eval=False,
loss_type=loss_type,
precompute_ref_log_probs=pre_compute_logits,
peft_config=self.peft_config,
max_length=self.max_length,
)

assert isinstance(trainer.model, PeftModel)
2 changes: 1 addition & 1 deletion tests/test_cli.py
@@ -32,7 +32,7 @@ def test_sft_cli():
def test_dpo_cli():
try:
subprocess.run(
"trl dpo --max_steps 1 --output_dir tmp-sft --model_name_or_path HuggingFaceM4/tiny-random-LlamaForCausalLM --dataset_name trl-internal-testing/hh-rlhf-trl-style --learning_rate 1e-4 --lr_scheduler_type cosine --sanity_check",
"trl dpo --max_steps 1 --output_dir tmp-dpo --model_name_or_path HuggingFaceM4/tiny-random-LlamaForCausalLM --dataset_name trl-internal-testing/hh-rlhf-trl-style --learning_rate 1e-4 --lr_scheduler_type cosine --sanity_check",
shell=True,
check=True,
)
(The remaining changed files are not rendered on this page.)
