Skip to content

Commit

Permalink
Add CPOTrainer (huggingface#1382)
Browse files Browse the repository at this point in the history
* add CPOTrainer

* add docs

* fix formatting

* removed precompute_ref_log_probs arg

* remove precompute_ref_log_probs

* typos

* finish cpo trainer doc

* remove redundant lines

* typo

* formatting

* compute chosen nll loss also for enc-dec models

* fix gradient error of inplace operation for enc-dec models

* formatting

* use CPOConfig

* formatting

* use model_init_kwargs from CPOConfig

* comments in example

* fix doc string

* fix typo in docstring

* update year

* fixed typo

* use preference dataset

* fix learning rate

* move dataset_num_proc to configs

* Update cpo paper link from HF: cpo_trainer.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* update description for CPO: cpo_trainer.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* remove _prepare_deepspeed for cpo

Because CPO does not need init for reference model

* Add explanation to CPO loss

* format

* fix bug when lengths are given

* add CPOTrainer to README

* fix grammer

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
  • Loading branch information
3 people authored Mar 22, 2024
1 parent d10f766 commit d1df79f
Show file tree
Hide file tree
Showing 12 changed files with 1,433 additions and 12 deletions.
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ The library is built on top of the [`transformers`](https://github.com/huggingfa
- [`PEFT`](https://github.com/huggingface/peft) is fully integrated and allows to train even the largest models on modest hardware with quantisation and methods such as LoRA or QLoRA.
- [`unsloth`](https://github.com/unslothai/unsloth) is also integrated and allows to significantly speed up training with dedicated kernels.
- **`CLI`**: With the [CLI](https://huggingface.co/docs/trl/clis) you can fine-tune and chat with LLMs without writing any code using a single command and a flexible config system.
- **`Trainers`**: The Trainer classes are an abstraction to apply many fine-tuning methods with ease such as the [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer), [`DPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.DPOTrainer), [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer), and [`PPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.PPOTrainer).
- **`Trainers`**: The Trainer classes are an abstraction to apply many fine-tuning methods with ease such as the [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer), [`DPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.DPOTrainer), [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer), [`PPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.PPOTrainer), and [`CPOTrainer`]((https://huggingface.co/docs/trl/trainer#trl.CPOTrainer).
- **`AutoModels`**: The [`AutoModelForCausalLMWithValueHead`](https://huggingface.co/docs/trl/models#trl.AutoModelForCausalLMWithValueHead) & [`AutoModelForSeq2SeqLMWithValueHead`](https://huggingface.co/docs/trl/models#trl.AutoModelForSeq2SeqLMWithValueHead) classes add an additional value head to the model which allows to train them with RL algorithms such as PPO.
- **`Examples`**: Train GPT2 to generate positive movie reviews with a BERT sentiment classifier, full RLHF using adapters only, train GPT-j to be less toxic, [StackLlama example](https://huggingface.co/blog/stackllama), etc. following the [examples](https://github.com/huggingface/trl/tree/main/examples).

Expand Down Expand Up @@ -82,11 +82,11 @@ Read more about CLI in the [relevant documentation section](https://huggingface.

## How to use

For more flexibility and control over the training you can use the dedicated trainer classes to fine-tune the model in Python.
For more flexibility and control over the training, you can use the dedicated trainer classes to fine-tune the model in Python.

### `SFTTrainer`

This is a basic example on how to use the `SFTTrainer` from the library. The `SFTTrainer` is a light wrapper around the `transformers` Trainer to easily fine-tune language models or adapters on a custom dataset.
This is a basic example of how to use the `SFTTrainer` from the library. The `SFTTrainer` is a light wrapper around the `transformers` Trainer to easily fine-tune language models or adapters on a custom dataset.

```python
# imports
Expand All @@ -110,7 +110,7 @@ trainer.train()

### `RewardTrainer`

This is a basic example on how to use the `RewardTrainer` from the library. The `RewardTrainer` is a wrapper around the `transformers` Trainer to easily fine-tune reward models or adapters on a custom preference dataset.
This is a basic example of how to use the `RewardTrainer` from the library. The `RewardTrainer` is a wrapper around the `transformers` Trainer to easily fine-tune reward models or adapters on a custom preference dataset.

```python
# imports
Expand All @@ -136,7 +136,7 @@ trainer.train()

### `PPOTrainer`

This is a basic example on how to use the `PPOTrainer` from the library. Based on a query the language model creates a response which is then evaluated. The evaluation could be a human in the loop or another model's output.
This is a basic example of how to use the `PPOTrainer` from the library. Based on a query the language model creates a response which is then evaluated. The evaluation could be a human in the loop or another model's output.

```python
# imports
Expand Down Expand Up @@ -175,7 +175,7 @@ train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)

### `DPOTrainer`

`DPOTrainer` is a trainer that uses [Direct Preference Optimization algorithm](https://arxiv.org/abs/2305.18290). This is a basic example on how to use the `DPOTrainer` from the library. The `DPOTrainer` is a wrapper around the `transformers` Trainer to easily fine-tune reward models or adapters on a custom preference dataset.
`DPOTrainer` is a trainer that uses [Direct Preference Optimization algorithm](https://arxiv.org/abs/2305.18290). This is a basic example of how to use the `DPOTrainer` from the library. The `DPOTrainer` is a wrapper around the `transformers` Trainer to easily fine-tune reward models or adapters on a custom preference dataset.

```python
# imports
Expand Down
2 changes: 2 additions & 0 deletions docs/source/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
title: DPO Trainer
- local: kto_trainer
title: KTO Trainer
- local: cpo_trainer
title: CPO Trainer
- local: ddpo_trainer
title: Denoising Diffusion Policy Optimization
- local: iterative_sft_trainer
Expand Down
102 changes: 102 additions & 0 deletions docs/source/cpo_trainer.mdx
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# CPO Trainer

Contrastive Preference Optimization (CPO) as introduced in the paper [Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation](https://huggingface.co/papers/2401.08417) by Haoran Xu, Amr Sharaf, Yunmo Chen, Weiting Tan, Lingfeng Shen, Benjamin Van Durme, Kenton Murray, and Young Jin Kim. At a high-level, CPO trains models to
avoid generating adequate, but not perfect translations in Machine Translation (MT) tasks. However, CPO is a general approximation to the DPO loss and can be applied to other domains like chat.

CPO aims to mitigate two fundamental shortcomings of SFT. First, SFT’s methodology of minimizing the discrepancy between predicted outputs and gold-standard references inherently caps model performance at the quality level of the training data. Secondly, SFT lacks a mechanism to prevent the model from rejecting mistakes in translations. The CPO objective is derived from the DPO objective.

## Expected dataset format

The CPO trainer expects a format identical to the DPO trainer, which should include three entries. These entries should be named as follows:

- `prompt`
- `chosen`
- `rejected`

for example:

```py
cpo_dataset_dict = {
"prompt": [
"hello",
"how are you",
"What is your name?",
"What is your name?",
"Which is the best programming language?",
"Which is the best programming language?",
"Which is the best programming language?",
],
"chosen": [
"hi nice to meet you",
"I am fine",
"My name is Mary",
"My name is Mary",
"Python",
"Python",
"Java",
],
"rejected": [
"leave me alone",
"I am not fine",
"Whats it to you?",
"I dont have a name",
"Javascript",
"C++",
"C++",
],
}
```
where the `prompt` contains the context inputs, `chosen` contains the corresponding chosen responses and `rejected` contains the corresponding negative (rejected) responses. As can be seen a prompt can have multiple responses and this is reflected in the entries being repeated in the dictionary's value arrays.


## Expected model format
The CPO trainer expects a model of `AutoModelForCausalLM`, compared to PPO that expects `AutoModelForCausalLMWithValueHead` for the value function.

## Using the `CPOTrainer`
For a detailed example have a look at the `examples/scripts/cpo.py` script. At a high level we need to initialize the `CPOTrainer` with a `model` we wish to train. **Note that CPOTrainer eliminates the need to use the reference model, simplifying the optimization process.** The `beta` refers to the hyperparameter of the implicit reward, and the dataset contains the 3 entries listed above.

```py
cpo_config = CPOConfig(
beta=0.1,
)

cpo_trainer = CPOTrainer(
model,
args=cpo_config,
train_dataset=train_dataset,
tokenizer=tokenizer,
)
```
After this one can then call:

```py
cpo_trainer.train()
```

## Loss functions

Given the preference data, the `CPOTrainer` uses the sigmoid loss on the normalized likelihood via the `logsigmoid` to fit a logistic regression.

The [RSO](https://arxiv.org/abs/2309.06657) authors propose to use a hinge loss on the normalized likelihood from the [SLiC](https://arxiv.org/abs/2305.10425) paper. The `CPOTrainer` can be switched to this loss via the `loss_type="hinge"` argument and the `beta` in this case is the reciprocal of the margin.

The [IPO](https://arxiv.org/abs/2310.12036) authors provide a deeper theoretical understanding of the CPO algorithms and identify an issue with overfitting and propose an alternative loss which can be used via the `loss_type="ipo"` argument to the trainer. Note that the `beta` parameter is the reciprocal of the gap between the log-likelihood ratios of the chosen vs the rejected completion pair and thus the smaller the `beta` the larger this gaps is. As per the paper the loss is averaged over log-likelihoods of the completion (unlike CPO which is summed only).


## Logging

While training and evaluating we record the following reward metrics:

* `rewards/chosen`: the mean log probabilities of the policy model for the chosen responses scaled by beta
* `rewards/rejected`: the mean log probabilities of the policy model for the rejected responses scaled by beta
* `rewards/accuracies`: mean of how often the chosen rewards are > than the corresponding rejected rewards
* `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards
* `nll_loss`: the mean negative log likelihood loss of the policy model for the chosen responses

## CPOTrainer

[[autodoc]] CPOTrainer


## CPOConfig

[[autodoc]] CPOConfig
121 changes: 121 additions & 0 deletions examples/scripts/cpo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Run the CPO training script with the following command with some example arguments.
In general, the optimal configuration for CPO will be similar to that of DPO:
# regular:
python examples/scripts/cpo.py \
--model_name_or_path=gpt2 \
--per_device_train_batch_size 4 \
--max_steps 1000 \
--learning_rate 8e-6 \
--gradient_accumulation_steps 1 \
--logging_steps 10 \
--eval_steps 500 \
--output_dir="gpt2-aligned-cpo" \
--warmup_steps 150 \
--report_to wandb \
--bf16 \
--logging_first_step \
--no_remove_unused_columns
# peft:
python examples/scripts/cpo.py \
--model_name_or_path=gpt2 \
--per_device_train_batch_size 4 \
--max_steps 1000 \
--learning_rate 8e-5 \
--gradient_accumulation_steps 1 \
--logging_steps 10 \
--eval_steps 500 \
--output_dir="gpt2-lora-aligned-cpo" \
--optim rmsprop \
--warmup_steps 150 \
--report_to wandb \
--bf16 \
--logging_first_step \
--no_remove_unused_columns \
--use_peft \
--lora_r=16 \
--lora_alpha=16
"""

import multiprocessing
from dataclasses import dataclass, field

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser

from trl import CPOConfig, CPOTrainer, ModelConfig, get_peft_config


@dataclass
class ScriptArguments:
dataset: str = field(
default="trl-internal-testing/hh-rlhf-trl-style", metadata={"help": "The name of the dataset to use."}
)


if __name__ == "__main__":
parser = HfArgumentParser((ScriptArguments, CPOConfig, ModelConfig))
args, cpo_args, model_config = parser.parse_args_into_dataclasses()

################
# Model & Tokenizer
################
model = AutoModelForCausalLM.from_pretrained(model_config.model_name_or_path)
peft_config = get_peft_config(model_config)
tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token

################
# Dataset
################
ds = load_dataset(args.dataset)
if cpo_args.debug:
for key in ds:
ds[key] = ds[key].select(range(50))
if tokenizer.chat_template is None:
tokenizer.chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"

def process(row):
row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False)
row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)
return row

ds = ds.map(
process,
num_proc=1 if cpo_args.debug else multiprocessing.cpu_count(),
load_from_cache_file=False,
)
train_dataset = ds["train"]
eval_dataset = ds["test"]

################
# Training
################
trainer = CPOTrainer(
model,
args=cpo_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
peft_config=get_peft_config(model_config),
)

# train and save the model
trainer.train()
trainer.save_model(cpo_args.output_dir)
3 changes: 2 additions & 1 deletion examples/scripts/kto.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,5 +148,6 @@ def get_hh(split: str, sanity_check: bool = False, silent: bool = False, cache_d
peft_config=get_peft_config(model_args),
)

# 5. train
# 5. train and save the model
kto_trainer.train()
kto_trainer.save_model(kto_args.output_dir)
Loading

0 comments on commit d1df79f

Please sign in to comment.