support loss function for Self-play Preference Optimization (huggingface#1612)

* support loss function for Self-play Preference Optimization

* update docs

* update value error msg

* update typehint

* Update docs/source/dpo_trainer.mdx

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* include sppo in tests

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
winglian and kashif authored May 2, 2024
1 parent 0d40e18 commit adf17a5
Showing 4 changed files with 11 additions and 2 deletions.
2 changes: 2 additions & 0 deletions docs/source/dpo_trainer.mdx
@@ -111,6 +111,8 @@ The [KTO](https://arxiv.org/abs/2402.01306) authors directly maximize the utility

The [BCO](https://arxiv.org/abs/2404.04656) authors train a binary classifier whose logit serves as a reward so that the classifier maps {prompt, chosen completion} pairs to 1 and {prompt, rejected completion} pairs to 0. The `DPOTrainer` can be switched to this loss via the `loss_type="bco_pair"` argument.

The [SPPO](https://arxiv.org/abs/2405.00675) authors claim that SPPO is capable of solving for the Nash equilibrium iteratively by pushing the chosen rewards to be as large as 1/2 and the rejected rewards to be as small as -1/2, which can also alleviate data sparsity issues. The `DPOTrainer` can be switched to this loss via the `loss_type="sppo"` argument.
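
As a rough illustration of that objective, here is a minimal sketch of a pairwise SPPO-style loss on the implicit rewards (the helper name and inputs are illustrative, not the trainer's internal API):

```python
import torch

def sppo_pair_loss(chosen_logratio, rejected_logratio, beta=0.1):
    """Squared-error loss that pushes beta * chosen log-ratio toward 1/2
    and beta * rejected log-ratio toward -1/2."""
    a = beta * chosen_logratio    # implicit reward of the chosen completion
    b = beta * rejected_logratio  # implicit reward of the rejected completion
    return (a - 0.5) ** 2 + (b + 0.5) ** 2

# Minimized exactly when the scaled log-ratios reach +1/2 and -1/2.
print(sppo_pair_loss(torch.tensor(5.0), torch.tensor(-5.0)))  # tensor(0.)
```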

## Logging

While training and evaluating we record the following reward metrics:
2 changes: 2 additions & 0 deletions tests/test_dpo_trainer.py
@@ -90,6 +90,8 @@ def _init_dummy_dataset(self):
["t5", "kto_pair", False],
["gpt2", "bco_pair", False],
["t5", "bco_pair", True],
["gpt2", "sppo", False],
["t5", "sppo", True],
]
)
def test_dpo_trainer(self, name, loss_type, pre_compute):
2 changes: 1 addition & 1 deletion trl/trainer/dpo_config.py
@@ -69,7 +69,7 @@ class DPOConfig(TrainingArguments):

beta: float = 0.1
label_smoothing: float = 0
loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair", "bco_pair"] = "sigmoid"
loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair", "bco_pair", "sppo"] = "sigmoid"
label_pad_token_id: int = -100
padding_value: int = 0
truncation_mode: str = "keep_end"
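
For context, a hedged usage sketch of selecting the new `loss_type` value (the output directory and hyperparameters are placeholders; this assumes `DPOConfig` is exported from `trl` as in this version of the library):

```python
from trl import DPOConfig

# Enable the SPPO pairwise loss; beta scales the policy/reference log-ratios.
training_args = DPOConfig(
    output_dir="./dpo-sppo",  # placeholder output path
    loss_type="sppo",
    beta=0.1,
)
```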
7 changes: 6 additions & 1 deletion trl/trainer/dpo_trainer.py
@@ -1028,9 +1028,14 @@ def dpo_loss(
losses = -F.logsigmoid((self.beta * chosen_logratios) - delta) - F.logsigmoid(
-(self.beta * rejected_logratios - delta)
)
elif self.loss_type == "sppo":
a = self.beta * (policy_chosen_logps - reference_chosen_logps)
b = self.beta * (policy_rejected_logps - reference_rejected_logps)

losses = (a - 0.5) ** 2 + (b + 0.5) ** 2
else:
raise ValueError(
f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'kto_pair', 'bco_pair']"
f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'kto_pair', 'bco_pair', 'sppo']"
)

chosen_rewards = (
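
A quick standalone check of the arithmetic in the `sppo` branch above, using made-up log-probabilities (no TRL import required):

```python
import torch

beta = 0.1
policy_chosen_logps = torch.tensor([-1.0])
reference_chosen_logps = torch.tensor([-1.5])
policy_rejected_logps = torch.tensor([-3.0])
reference_rejected_logps = torch.tensor([-2.0])

a = beta * (policy_chosen_logps - reference_chosen_logps)      # 0.05
b = beta * (policy_rejected_logps - reference_rejected_logps)  # -0.10

losses = (a - 0.5) ** 2 + (b + 0.5) ** 2
print(losses)  # tensor([0.3625])
```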
