From adf17a5a269a0bc59162597f81e3d489a8c144e5 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Thu, 2 May 2024 10:06:58 -0400
Subject: [PATCH] support loss function for Self-play Preference Optimization (#1612)

* support loss function for Self-play Preference Optimization

* update docs

* update value error msg

* update typehint

* Update docs/source/dpo_trainer.mdx

Co-authored-by: Kashif Rasul

* include sppo in tests

---------

Co-authored-by: Kashif Rasul
---
 docs/source/dpo_trainer.mdx | 2 ++
 tests/test_dpo_trainer.py   | 2 ++
 trl/trainer/dpo_config.py   | 2 +-
 trl/trainer/dpo_trainer.py  | 7 ++++++-
 4 files changed, 11 insertions(+), 2 deletions(-)

diff --git a/docs/source/dpo_trainer.mdx b/docs/source/dpo_trainer.mdx
index 9d9294814c..425e312d95 100644
--- a/docs/source/dpo_trainer.mdx
+++ b/docs/source/dpo_trainer.mdx
@@ -111,6 +111,8 @@ The [KTO](https://arxiv.org/abs/2402.01306) authors directly maximize the utilit
 
 The [BCO](https://arxiv.org/abs/2404.04656) authors train a binary classifier whose logit serves as a reward so that the classifier maps {prompt, chosen completion} pairs to 1 and {prompt, rejected completion} pairs to 0. The `DPOTrainer` can be switched to this loss via the `loss_type="bco_pair"` argument.
 
+The [SPPO](https://arxiv.org/abs/2405.00675) authors claim that SPPO is capable of solving the Nash equilibrium iteratively by pushing the chosen rewards to be as large as 1/2 and the rejected rewards to be as small as -1/2, which can also alleviate data sparsity issues.
+
 ## Logging
 
 While training and evaluating we record the following reward metrics:
diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py
index c60ee6dcfe..946c7e7043 100644
--- a/tests/test_dpo_trainer.py
+++ b/tests/test_dpo_trainer.py
@@ -90,6 +90,8 @@ def _init_dummy_dataset(self):
             ["t5", "kto_pair", False],
             ["gpt2", "bco_pair", False],
             ["t5", "bco_pair", True],
+            ["gpt2", "sppo", False],
+            ["t5", "sppo", True],
         ]
     )
     def test_dpo_trainer(self, name, loss_type, pre_compute):
diff --git a/trl/trainer/dpo_config.py b/trl/trainer/dpo_config.py
index 84f19ec03a..87c1298aab 100644
--- a/trl/trainer/dpo_config.py
+++ b/trl/trainer/dpo_config.py
@@ -69,7 +69,7 @@ class DPOConfig(TrainingArguments):
 
     beta: float = 0.1
     label_smoothing: float = 0
-    loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair", "bco_pair"] = "sigmoid"
+    loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair", "bco_pair", "sppo"] = "sigmoid"
     label_pad_token_id: int = -100
     padding_value: int = 0
     truncation_mode: str = "keep_end"
diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index 801c5e0020..d1706d3a35 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -1028,9 +1028,14 @@ def dpo_loss(
             losses = -F.logsigmoid((self.beta * chosen_logratios) - delta) - F.logsigmoid(
                 -(self.beta * rejected_logratios - delta)
             )
+        elif self.loss_type == "sppo":
+            a = self.beta * (policy_chosen_logps - reference_chosen_logps)
+            b = self.beta * (policy_rejected_logps - reference_rejected_logps)
+
+            losses = (a - 0.5) ** 2 + (b + 0.5) ** 2
         else:
             raise ValueError(
-                f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'kto_pair', 'bco_pair']"
+                f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'kto_pair', 'bco_pair', 'sppo']"
             )
 
         chosen_rewards = (
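
For readers who want to see the added loss in isolation, below is a minimal standalone sketch of the SPPO objective as it appears in the `dpo_trainer.py` hunk above. The `sppo_loss` helper name and the dummy log-probability tensors are illustrative assumptions, not part of the patch; only the `(a - 0.5) ** 2 + (b + 0.5) ** 2` formula comes from the diff.

```python
import torch


def sppo_loss(
    policy_chosen_logps: torch.Tensor,
    policy_rejected_logps: torch.Tensor,
    reference_chosen_logps: torch.Tensor,
    reference_rejected_logps: torch.Tensor,
    beta: float = 0.1,
) -> torch.Tensor:
    # Scaled policy-vs-reference log-ratios, matching `a` and `b` in the patch.
    a = beta * (policy_chosen_logps - reference_chosen_logps)
    b = beta * (policy_rejected_logps - reference_rejected_logps)
    # Push the chosen log-ratio toward +1/2 and the rejected log-ratio toward -1/2.
    return (a - 0.5) ** 2 + (b + 0.5) ** 2


# Dummy per-sequence summed log-probabilities (batch of 2); values are made up.
losses = sppo_loss(
    policy_chosen_logps=torch.tensor([-12.3, -9.8]),
    policy_rejected_logps=torch.tensor([-14.7, -13.2]),
    reference_chosen_logps=torch.tensor([-12.9, -10.1]),
    reference_rejected_logps=torch.tensor([-14.2, -12.8]),
)
print(losses.mean())
```

In TRL itself, the loss is selected via the `loss_type="sppo"` value that this patch adds to `DPOConfig`, alongside the existing `sigmoid`, `hinge`, `ipo`, `kto_pair`, and `bco_pair` options.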