[DPO] add 'bco_pair' loss_type (huggingface#1524)
* add 'bco_pair' loss_type

* add BCO description to DPO doc

---------

Co-authored-by: sean.jung <sean.jung@seanjungui-MacBookPro.local>
seanexp and sean.jung authored Apr 22, 2024
1 parent abc0584 commit c050ebc
Showing 3 changed files with 30 additions and 4 deletions.
docs/source/dpo_trainer.mdx (2 additions, 0 deletions)
@@ -107,6 +107,8 @@ The [cDPO](https://ericmitchell.ai/cdpo.pdf) is a tweak on the DPO loss where we

The [KTO](https://arxiv.org/abs/2402.01306) authors directly maximize the utility of LLM generations instead of the log-likelihood of preferences. To use preference data with KTO, we recommend breaking up the n preferences into 2n examples and using [`KTOTrainer`](kto_trainer) (i.e., treating the data like an unpaired feedback dataset). Although it is possible to pass in `loss_type="kto_pair"` into DPOTrainer, this is a highly simplified version of KTO that we *do not recommend* in most cases. Please use [`KTOTrainer`](kto_trainer) when possible.

The [BCO](https://arxiv.org/abs/2404.04656) authors train a binary classifier whose logit serves as a reward so that the classifier maps {prompt, chosen completion} pairs to 1 and {prompt, rejected completion} pairs to 0. The `DPOTrainer` can be switched to this loss via the `loss_type="bco_pair"` argument.

## Logging

While training and evaluating we record the following reward metrics:
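As context for the documentation change above, a minimal usage sketch of switching the trainer to the new loss. It is not part of the commit; the model name, dataset name, and hyperparameters are placeholders, while the keyword arguments follow the DPOTrainer signature shown in the diff further down.

    # Hypothetical sketch: enable the BCO pairwise loss in DPOTrainer.
    # "your-org/your-preference-dataset" is a placeholder; it must provide
    # "prompt", "chosen" and "rejected" text columns.
    from datasets import load_dataset
    from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
    from trl import DPOTrainer

    model = AutoModelForCausalLM.from_pretrained("gpt2")
    ref_model = AutoModelForCausalLM.from_pretrained("gpt2")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token

    train_dataset = load_dataset("your-org/your-preference-dataset", split="train")

    trainer = DPOTrainer(
        model,
        ref_model,
        args=TrainingArguments(output_dir="dpo-bco-pair", per_device_train_batch_size=2),
        beta=0.1,
        loss_type="bco_pair",  # the option added by this commit
        train_dataset=train_dataset,
        tokenizer=tokenizer,
        max_length=512,
        max_prompt_length=128,
    )
    trainer.train()

Apart from loss_type, nothing changes relative to a standard DPO run; the BCO threshold is maintained internally (see the RunningMoments wiring below).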
tests/test_dpo_trainer.py (6 additions, 0 deletions)
@@ -88,6 +88,8 @@ def _init_dummy_dataset(self):
["t5", "ipo", True],
["gpt2", "kto_pair", True],
["t5", "kto_pair", False],
["gpt2", "bco_pair", False],
["t5", "bco_pair", True],
]
)
def test_dpo_trainer(self, name, loss_type, pre_compute):
@@ -453,6 +455,10 @@ def test_dpo_lora_bf16_autocast_llama(self):
["gpt2", "kto_pair", False, True],
["gpt2", "kto_pair", True, False],
["gpt2", "kto_pair", True, True],
["gpt2", "bco_pair", False, False],
["gpt2", "bco_pair", False, True],
["gpt2", "bco_pair", True, False],
["gpt2", "bco_pair", True, True],
]
)
@require_bitsandbytes
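The parametrizations added above fold `bco_pair` into the existing DPO test matrix; assuming a development install of the repository, those cases can be selected with pytest's keyword filter, for example `pytest tests/test_dpo_trainer.py -k bco_pair`.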
trl/trainer/dpo_trainer.py (22 additions, 4 deletions)
@@ -44,6 +44,7 @@
from ..models import PreTrainedModelWrapper, create_reference_model
from .utils import (
DPODataCollatorWithPadding,
RunningMoments,
disable_dropout_in_model,
pad_to_length,
peft_module_casting_to_bf16,
@@ -77,7 +78,8 @@ class DPOTrainer(Trainer):
label_smoothing (`float`, defaults to 0):
The robust DPO label smoothing parameter from the [cDPO](https://ericmitchell.ai/cdpo.pdf) report that should be between 0 and 0.5.
loss_type (`str`, defaults to `"sigmoid"`):
The type of DPO loss to use. Either `"sigmoid"` the default DPO loss,`"hinge"` loss from [SLiC](https://arxiv.org/abs/2305.10425) paper, `"ipo"` from [IPO](https://arxiv.org/abs/2310.12036) paper, or `"kto"` from the HALOs [report](https://github.com/ContextualAI/HALOs/blob/main/assets/report.pdf).
The type of DPO loss to use. Either `"sigmoid"` the default DPO loss,`"hinge"` loss from [SLiC](https://arxiv.org/abs/2305.10425) paper, `"ipo"` from [IPO](https://arxiv.org/abs/2310.12036) paper,
`"kto_pair"` from the HALOs [report](https://github.com/ContextualAI/HALOs/blob/main/assets/report.pdf), or `"bco_pair"` from [BCO](https://arxiv.org/abs/2404.04656) paper.
args (`transformers.TrainingArguments`):
The arguments to use for training.
data_collator (`transformers.DataCollator`):
@@ -147,7 +149,7 @@ def __init__(
ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
beta: float = 0.1,
label_smoothing: float = 0,
loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair"] = "sigmoid",
loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair", "bco_pair"] = "sigmoid",
args: Optional[TrainingArguments] = None,
data_collator: Optional[DataCollator] = None,
label_pad_token_id: int = -100,
@@ -359,7 +361,7 @@ def make_inputs_require_grad(module, input, output):
self._precomputed_train_ref_log_probs = False
self._precomputed_eval_ref_log_probs = False

if loss_type in ["hinge", "ipo", "kto_pair"] and label_smoothing > 0:
if loss_type in ["hinge", "ipo", "kto_pair", "bco_pair"] and label_smoothing > 0:
warnings.warn(
"You are using a loss type that does not support label smoothing. Ignoring label_smoothing parameter."
)
@@ -421,6 +423,9 @@ def make_inputs_require_grad(module, input, output):
else:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

if self.loss_type == "bco_pair":
self.running = RunningMoments(self.accelerator)

def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
# Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
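The hunk above instantiates a RunningMoments tracker only when `loss_type="bco_pair"`; its running mean of the batch rewards becomes the classification threshold (the `delta` used in the loss further down). A simplified stand-in, assuming only the mean is needed (trl's actual RunningMoments also tracks second moments and takes the accelerator so statistics can be aggregated across processes):

    import torch

    class SimpleRunningMean:
        # Toy stand-in for trl's RunningMoments: keeps a running mean of the
        # scalar batch rewards passed to update().
        def __init__(self):
            self.mean = 0.0
            self.count = 0

        def update(self, x: torch.Tensor) -> None:
            self.count += 1
            self.mean += (float(x) - self.mean) / self.count

    running = SimpleRunningMean()
    for batch_reward in (torch.tensor(0.20), torch.tensor(-0.10), torch.tensor(0.40)):
        running.update(batch_reward)
    print(running.mean)  # ~0.1667, the delta that shifts the BCO logits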
@@ -899,9 +904,22 @@ def dpo_loss(
),
0,
)
elif self.loss_type == "bco_pair":
chosen_logratios = policy_chosen_logps - reference_chosen_logps
rejected_logratios = policy_rejected_logps - reference_rejected_logps

chosen_rewards = self.beta * chosen_logratios
rejected_rewards = self.beta * rejected_logratios
rewards = torch.cat((chosen_rewards, rejected_rewards), 0).mean().detach()
self.running.update(rewards)
delta = self.running.mean

losses = -F.logsigmoid((self.beta * chosen_logratios) - delta) - F.logsigmoid(
-(self.beta * rejected_logratios - delta)
)
else:
raise ValueError(
f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'kto_pair']"
f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'kto_pair', 'bco_pair']"
)

chosen_rewards = (
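To make the new branch concrete, here is a standalone re-statement of the `bco_pair` computation above on toy numbers. The log-probabilities are made up, and `delta` is taken from a single batch rather than the running mean the trainer maintains:

    import torch
    import torch.nn.functional as F

    beta = 0.1

    # Made-up sequence log-probabilities under the policy and the frozen reference.
    policy_chosen_logps = torch.tensor([-12.0, -15.0])
    policy_rejected_logps = torch.tensor([-18.0, -16.0])
    reference_chosen_logps = torch.tensor([-13.0, -15.5])
    reference_rejected_logps = torch.tensor([-17.0, -15.0])

    chosen_logratios = policy_chosen_logps - reference_chosen_logps
    rejected_logratios = policy_rejected_logps - reference_rejected_logps

    # delta: mean implicit reward over chosen and rejected completions
    # (in the trainer this is the running mean kept by RunningMoments).
    rewards = beta * torch.cat((chosen_logratios, rejected_logratios), 0)
    delta = rewards.mean().detach()

    # Chosen pairs are pushed above the threshold (label 1), rejected pairs
    # below it (label 0), which is the binary-classifier view of BCO.
    losses = (
        -F.logsigmoid(beta * chosen_logratios - delta)
        - F.logsigmoid(-(beta * rejected_logratios - delta))
    )
    print(losses)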
