Use the standard dataset for DPO CLI (huggingface#1456)
* Use the standard dataset

* update docs

* update dpo examples

* fix cli error

* fix CI

* use trl-internal-testing/hh-rlhf-trl-style
vwxyzjn authored Mar 20, 2024
1 parent 988d4c4 commit 423991c
Showing 6 changed files with 44 additions and 20 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -69,7 +69,7 @@ trl sft --model_name_or_path facebook/opt-125m --dataset_name imdb --output_dir
**DPO:**

```bash
trl dpo --model_name_or_path facebook/opt-125m --dataset_name trl-internal-testing/Anthropic-hh-rlhf-processed --output_dir opt-sft-hh-rlhf
trl dpo --model_name_or_path facebook/opt-125m --dataset_name trl-internal-testing/hh-rlhf-trl-style --output_dir opt-sft-hh-rlhf
```

**Chat:**
2 changes: 1 addition & 1 deletion commands/run_dpo.sh
@@ -3,7 +3,7 @@
# but defaults to QLoRA + PEFT
OUTPUT_DIR="test_dpo/"
MODEL_NAME="HuggingFaceM4/tiny-random-LlamaForCausalLM"
DATASET_NAME="trl-internal-testing/Anthropic-hh-rlhf-processed"
DATASET_NAME="trl-internal-testing/hh-rlhf-trl-style"
MAX_STEPS=5
BATCH_SIZE=2
SEQ_LEN=128
30 changes: 18 additions & 12 deletions docs/source/clis.mdx
@@ -63,30 +63,36 @@ The SFT CLI is based on the `examples/scripts/sft.py` script.

### Direct Preference Optimization (DPO)

First, follow the basic instructions above and run `trl dpo --output_dir <output_dir> <*args>`. Make sure to process your DPO dataset in the TRL format as follows:
To use the DPO CLI, you need a dataset in the TRL format, such as:

1- Make sure to pre-tokenize the dataset using chat templates:
* TRL's Anthropic HH dataset: https://huggingface.co/datasets/trl-internal-testing/hh-rlhf-trl-style
* TRL's OpenAI TL;DR summarization dataset: https://huggingface.co/datasets/trl-internal-testing/tldr-preference-trl-style

```bash
python examples/datasets/tokenize_ds.py --model gpt2 --dataset yourdataset
```
These datasets always have at least three columns, `prompt`, `chosen`, and `rejected`:

* `prompt` is a list of strings.
* `chosen` is the chosen response in [chat format](https://huggingface.co/docs/transformers/main/en/chat_templating)
* `rejected` is the rejected response in [chat format](https://huggingface.co/docs/transformers/main/en/chat_templating)
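
As a quick way to check that a dataset follows this layout, here is a minimal inspection sketch (assuming the `datasets` library is installed and that the dataset keeps the `train`/`test` splits used elsewhere in this change):

```python
from datasets import load_dataset

# Load the TRL-style Anthropic HH preference dataset and look at one row.
ds = load_dataset("trl-internal-testing/hh-rlhf-trl-style")
row = ds["train"][0]

print(row["prompt"])        # the prompt for this example
print(row["chosen"][-1])    # last message of the preferred conversation, in chat format
print(row["rejected"][-1])  # last message of the rejected conversation, in chat format
```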

You might need to adapt `examples/datasets/tokenize_ds.py` to use your chat template.

2- Format the dataset into TRL format (you can adapt the `examples/datasets/anthropic_hh.py`):
For a quick start, you can run the following command:

```bash
python examples/datasets/anthropic_hh.py --push_to_hub --hf_entity your-hf-org
trl dpo --model_name_or_path facebook/opt-125m --output_dir trl-hh-rlhf --dataset_name trl-internal-testing/hh-rlhf-trl-style
```

Once your dataset has been pushed, run the DPO CLI as follows:

The DPO CLI is based on the `examples/scripts/dpo.py` script.


#### Custom preference dataset

Format the dataset into TRL format (you can adapt the `examples/datasets/anthropic_hh.py`):

```bash
trl dpo --model_name_or_path facebook/opt-125m --dataset_name trl-internal-testing/Anthropic-hh-rlhf-processed --output_dir opt-sft-hh-rlhf
python examples/datasets/anthropic_hh.py --push_to_hub --hf_entity your-hf-org
```
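
If you would rather write the conversion yourself than adapt that script, the transformation is roughly the following. This is a hedged sketch assuming a source dataset with a `train` split and hypothetical `question`, `good_answer`, and `bad_answer` string columns; substitute your own column names and repository ids:

```python
from datasets import load_dataset

def to_trl_style(row):
    # Map one pairwise-comparison row onto the prompt/chosen/rejected layout,
    # with chosen/rejected stored as conversations in chat format.
    return {
        "prompt": row["question"],
        "chosen": [
            {"role": "user", "content": row["question"]},
            {"role": "assistant", "content": row["good_answer"]},
        ],
        "rejected": [
            {"role": "user", "content": row["question"]},
            {"role": "assistant", "content": row["bad_answer"]},
        ],
    }

ds = load_dataset("your-org/your-pairwise-dataset")  # hypothetical source dataset
ds = ds.map(to_trl_style, remove_columns=ds["train"].column_names)
ds.push_to_hub("your-hf-org/my-preference-trl-style")  # hypothetical target repo
```

Once pushed, pass the new repository id to `--dataset_name` just like in the quick-start command above.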

The SFT CLI is based on the `examples/scripts/dpo.py` script.

## Chat interface

The chat CLI lets you quickly load the model and talk to it. Simply run the following:
26 changes: 22 additions & 4 deletions examples/scripts/dpo.py
@@ -15,9 +15,9 @@
"""
# regular:
python examples/scripts/dpo.py \
--dataset_name=trl-internal-testing/hh-rlhf-trl-style \
--model_name_or_path=gpt2 \
--per_device_train_batch_size 4 \
--max_steps 1000 \
--learning_rate 1e-3 \
--gradient_accumulation_steps 1 \
--logging_steps 10 \
@@ -31,9 +31,9 @@
# peft:
python examples/scripts/dpo.py \
--dataset_name=trl-internal-testing/hh-rlhf-trl-style \
--model_name_or_path=gpt2 \
--per_device_train_batch_size 4 \
--max_steps 1000 \
--learning_rate 1e-3 \
--gradient_accumulation_steps 1 \
--logging_steps 10 \
@@ -50,6 +50,7 @@
--lora_alpha=16
"""
import logging
import multiprocessing
import os
from contextlib import nullcontext

@@ -118,6 +119,8 @@
    tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    if tokenizer.chat_template is None:
        tokenizer.chat_template = "{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\n\n'}}{% endfor %}{{ eos_token }}"
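    # With this fallback template, apply_chat_template(..., tokenize=False) renders each
    # message as "<role>: <content>" followed by a blank line and then appends the
    # tokenizer's eos_token, e.g. roughly "user: Hi\n\nassistant: Hello!\n\n</s>"
    # (the exact end-of-sequence token depends on the tokenizer).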
    if args.ignore_bias_buffers:
        # torch distributed hack
        model._ddp_params_and_buffers_to_ignore = [
@@ -137,8 +140,23 @@
    ################
    # Dataset
    ################
    train_dataset = load_dataset(args.dataset_name, split="train")
    eval_dataset = load_dataset(args.dataset_name, split="test")
    ds = load_dataset(args.dataset_name)
    if args.sanity_check:
        for key in ds:
            ds[key] = ds[key].select(range(50))

    def process(row):
        row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False)
        row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)
        return row

    ds = ds.map(
        process,
        num_proc=multiprocessing.cpu_count(),
        load_from_cache_file=False,
    )
    train_dataset = ds["train"]
    eval_dataset = ds["test"]
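    # At this point `chosen` and `rejected` have been flattened from chat-format message
    # lists into plain strings via the chat template, which is the string form the trainer
    # below expects; the `prompt` column is left untouched.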

    ################
    # Training
2 changes: 1 addition & 1 deletion tests/test_cli.py
@@ -32,7 +32,7 @@ def test_sft_cli():
def test_dpo_cli():
    try:
        subprocess.run(
            "trl dpo --max_steps 1 --output_dir tmp-sft --model_name_or_path HuggingFaceM4/tiny-random-LlamaForCausalLM --dataset_name trl-internal-testing/Anthropic-hh-rlhf-processed --learning_rate 1e-4 --lr_scheduler_type cosine",
            "trl dpo --max_steps 1 --output_dir tmp-sft --model_name_or_path HuggingFaceM4/tiny-random-LlamaForCausalLM --dataset_name trl-internal-testing/hh-rlhf-trl-style --learning_rate 1e-4 --lr_scheduler_type cosine --sanity_check",
            shell=True,
            check=True,
        )
2 changes: 1 addition & 1 deletion trl/commands/cli_utils.py
@@ -155,7 +155,7 @@ class DpoScriptArguments:
    max_target_length: int = field(
        default=128, metadata={"help": "Only used for encoder decoder model. Max target of each sample's prompt"}
    )
    sanity_check: bool = field(default=True, metadata={"help": "only train on 1000 samples"})
    sanity_check: bool = field(default=False, metadata={"help": "only train on 1000 samples"})
    ignore_bias_buffers: bool = field(
        default=False,
        metadata={
