forked from huggingface/trl
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add CPOTrainer * add docs * fix formatting * removed precompute_ref_log_probs arg * remove precompute_ref_log_probs * typos * finish cpo trainer doc * remove redundant lines * typo * formatting * compute chosen nll loss also for enc-dec models * fix gradient error of inplace operation for enc-dec models * formatting * use CPOConfig * formatting * use model_init_kwargs from CPOConfig * comments in example * fix doc string * fix typo in docstring * update year * fixed typo * use preference dataset * fix learning rate * move dataset_num_proc to configs * Update cpo paper link from HF: cpo_trainer.mdx Co-authored-by: lewtun <lewis.c.tunstall@gmail.com> * update description for CPO: cpo_trainer.mdx Co-authored-by: lewtun <lewis.c.tunstall@gmail.com> * remove _prepare_deepspeed for cpo Because CPO does not need init for reference model * Add explanation to CPO loss * format * fix bug when lengths are given * add CPOTrainer to README * fix grammer --------- Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com> Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
- Loading branch information
1 parent
d10f766
commit d1df79f
Showing
12 changed files
with
1,433 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
# CPO Trainer | ||
|
||
Contrastive Preference Optimization (CPO) as introduced in the paper [Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation](https://huggingface.co/papers/2401.08417) by Haoran Xu, Amr Sharaf, Yunmo Chen, Weiting Tan, Lingfeng Shen, Benjamin Van Durme, Kenton Murray, and Young Jin Kim. At a high-level, CPO trains models to | ||
avoid generating adequate, but not perfect translations in Machine Translation (MT) tasks. However, CPO is a general approximation to the DPO loss and can be applied to other domains like chat. | ||
|
||
CPO aims to mitigate two fundamental shortcomings of SFT. First, SFT’s methodology of minimizing the discrepancy between predicted outputs and gold-standard references inherently caps model performance at the quality level of the training data. Secondly, SFT lacks a mechanism to prevent the model from rejecting mistakes in translations. The CPO objective is derived from the DPO objective. | ||
|
||
## Expected dataset format | ||
|
||
The CPO trainer expects a format identical to the DPO trainer, which should include three entries. These entries should be named as follows: | ||
|
||
- `prompt` | ||
- `chosen` | ||
- `rejected` | ||
|
||
for example: | ||
|
||
```py | ||
cpo_dataset_dict = { | ||
"prompt": [ | ||
"hello", | ||
"how are you", | ||
"What is your name?", | ||
"What is your name?", | ||
"Which is the best programming language?", | ||
"Which is the best programming language?", | ||
"Which is the best programming language?", | ||
], | ||
"chosen": [ | ||
"hi nice to meet you", | ||
"I am fine", | ||
"My name is Mary", | ||
"My name is Mary", | ||
"Python", | ||
"Python", | ||
"Java", | ||
], | ||
"rejected": [ | ||
"leave me alone", | ||
"I am not fine", | ||
"Whats it to you?", | ||
"I dont have a name", | ||
"Javascript", | ||
"C++", | ||
"C++", | ||
], | ||
} | ||
``` | ||
where the `prompt` contains the context inputs, `chosen` contains the corresponding chosen responses and `rejected` contains the corresponding negative (rejected) responses. As can be seen a prompt can have multiple responses and this is reflected in the entries being repeated in the dictionary's value arrays. | ||
|
||
|
||
## Expected model format | ||
The CPO trainer expects a model of `AutoModelForCausalLM`, compared to PPO that expects `AutoModelForCausalLMWithValueHead` for the value function. | ||
|
||
## Using the `CPOTrainer` | ||
For a detailed example have a look at the `examples/scripts/cpo.py` script. At a high level we need to initialize the `CPOTrainer` with a `model` we wish to train. **Note that CPOTrainer eliminates the need to use the reference model, simplifying the optimization process.** The `beta` refers to the hyperparameter of the implicit reward, and the dataset contains the 3 entries listed above. | ||
|
||
```py | ||
cpo_config = CPOConfig( | ||
beta=0.1, | ||
) | ||
|
||
cpo_trainer = CPOTrainer( | ||
model, | ||
args=cpo_config, | ||
train_dataset=train_dataset, | ||
tokenizer=tokenizer, | ||
) | ||
``` | ||
After this one can then call: | ||
|
||
```py | ||
cpo_trainer.train() | ||
``` | ||
|
||
## Loss functions | ||
|
||
Given the preference data, the `CPOTrainer` uses the sigmoid loss on the normalized likelihood via the `logsigmoid` to fit a logistic regression. | ||
|
||
The [RSO](https://arxiv.org/abs/2309.06657) authors propose to use a hinge loss on the normalized likelihood from the [SLiC](https://arxiv.org/abs/2305.10425) paper. The `CPOTrainer` can be switched to this loss via the `loss_type="hinge"` argument and the `beta` in this case is the reciprocal of the margin. | ||
|
||
The [IPO](https://arxiv.org/abs/2310.12036) authors provide a deeper theoretical understanding of the CPO algorithms and identify an issue with overfitting and propose an alternative loss which can be used via the `loss_type="ipo"` argument to the trainer. Note that the `beta` parameter is the reciprocal of the gap between the log-likelihood ratios of the chosen vs the rejected completion pair and thus the smaller the `beta` the larger this gaps is. As per the paper the loss is averaged over log-likelihoods of the completion (unlike CPO which is summed only). | ||
|
||
|
||
## Logging | ||
|
||
While training and evaluating we record the following reward metrics: | ||
|
||
* `rewards/chosen`: the mean log probabilities of the policy model for the chosen responses scaled by beta | ||
* `rewards/rejected`: the mean log probabilities of the policy model for the rejected responses scaled by beta | ||
* `rewards/accuracies`: mean of how often the chosen rewards are > than the corresponding rejected rewards | ||
* `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards | ||
* `nll_loss`: the mean negative log likelihood loss of the policy model for the chosen responses | ||
|
||
## CPOTrainer | ||
|
||
[[autodoc]] CPOTrainer | ||
|
||
|
||
## CPOConfig | ||
|
||
[[autodoc]] CPOConfig |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
""" | ||
Run the CPO training script with the following command with some example arguments. | ||
In general, the optimal configuration for CPO will be similar to that of DPO: | ||
# regular: | ||
python examples/scripts/cpo.py \ | ||
--model_name_or_path=gpt2 \ | ||
--per_device_train_batch_size 4 \ | ||
--max_steps 1000 \ | ||
--learning_rate 8e-6 \ | ||
--gradient_accumulation_steps 1 \ | ||
--logging_steps 10 \ | ||
--eval_steps 500 \ | ||
--output_dir="gpt2-aligned-cpo" \ | ||
--warmup_steps 150 \ | ||
--report_to wandb \ | ||
--bf16 \ | ||
--logging_first_step \ | ||
--no_remove_unused_columns | ||
# peft: | ||
python examples/scripts/cpo.py \ | ||
--model_name_or_path=gpt2 \ | ||
--per_device_train_batch_size 4 \ | ||
--max_steps 1000 \ | ||
--learning_rate 8e-5 \ | ||
--gradient_accumulation_steps 1 \ | ||
--logging_steps 10 \ | ||
--eval_steps 500 \ | ||
--output_dir="gpt2-lora-aligned-cpo" \ | ||
--optim rmsprop \ | ||
--warmup_steps 150 \ | ||
--report_to wandb \ | ||
--bf16 \ | ||
--logging_first_step \ | ||
--no_remove_unused_columns \ | ||
--use_peft \ | ||
--lora_r=16 \ | ||
--lora_alpha=16 | ||
""" | ||
|
||
import multiprocessing | ||
from dataclasses import dataclass, field | ||
|
||
from datasets import load_dataset | ||
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser | ||
|
||
from trl import CPOConfig, CPOTrainer, ModelConfig, get_peft_config | ||
|
||
|
||
@dataclass | ||
class ScriptArguments: | ||
dataset: str = field( | ||
default="trl-internal-testing/hh-rlhf-trl-style", metadata={"help": "The name of the dataset to use."} | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = HfArgumentParser((ScriptArguments, CPOConfig, ModelConfig)) | ||
args, cpo_args, model_config = parser.parse_args_into_dataclasses() | ||
|
||
################ | ||
# Model & Tokenizer | ||
################ | ||
model = AutoModelForCausalLM.from_pretrained(model_config.model_name_or_path) | ||
peft_config = get_peft_config(model_config) | ||
tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path) | ||
if tokenizer.pad_token is None: | ||
tokenizer.pad_token = tokenizer.eos_token | ||
|
||
################ | ||
# Dataset | ||
################ | ||
ds = load_dataset(args.dataset) | ||
if cpo_args.debug: | ||
for key in ds: | ||
ds[key] = ds[key].select(range(50)) | ||
if tokenizer.chat_template is None: | ||
tokenizer.chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" | ||
|
||
def process(row): | ||
row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False) | ||
row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False) | ||
return row | ||
|
||
ds = ds.map( | ||
process, | ||
num_proc=1 if cpo_args.debug else multiprocessing.cpu_count(), | ||
load_from_cache_file=False, | ||
) | ||
train_dataset = ds["train"] | ||
eval_dataset = ds["test"] | ||
|
||
################ | ||
# Training | ||
################ | ||
trainer = CPOTrainer( | ||
model, | ||
args=cpo_args, | ||
train_dataset=train_dataset, | ||
eval_dataset=eval_dataset, | ||
tokenizer=tokenizer, | ||
peft_config=get_peft_config(model_config), | ||
) | ||
|
||
# train and save the model | ||
trainer.train() | ||
trainer.save_model(cpo_args.output_dir) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.