Pairwise Noise Contrastive Alignment #1632

Merged · 7 commits · May 14, 2024
2 changes: 2 additions & 0 deletions docs/source/dpo_trainer.mdx
@@ -113,6 +113,8 @@ The [BCO](https://arxiv.org/abs/2404.04656) authors train a binary classifier wh

The [SPPO](https://arxiv.org/abs/2405.00675) authors claim that SPPO is capable of solving the Nash equilibrium iteratively by pushing the chosen rewards to be as large as 1/2 and the rejected rewards to be as small as -1/2, and can alleviate data sparsity issues. The implementation using `loss_type="sppo_hard"` approximates this algorithm by employing hard label probabilities, assigning 1 to the winner and 0 to the loser.
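For reference, the hard-label objective reduces to a pair of squared penalties on the implicit rewards; a minimal PyTorch sketch mirroring the `sppo_hard` branch of `dpo_loss` in this PR, where `a` and `b` are the chosen and rejected log-ratios against the reference model:

```python
import torch

def sppo_hard_loss(a: torch.Tensor, b: torch.Tensor, beta: float) -> torch.Tensor:
    # a = policy_chosen_logps - reference_chosen_logps
    # b = policy_rejected_logps - reference_rejected_logps
    # Hard labels push the chosen log-ratio toward 1/(2*beta)
    # and the rejected log-ratio toward -1/(2*beta).
    return (a - 0.5 / beta) ** 2 + (b + 0.5 / beta) ** 2
```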

The [NCA](https://arxiv.org/abs/2402.05369) authors show that NCA optimizes the absolute likelihood for each response rather than the relative likelihood.
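Concretely, `loss_type="nca_pair"` pushes the chosen reward up while regularizing both rewards toward zero, instead of only maximizing the margin between them; a minimal sketch of the loss added in this PR:

```python
import torch
import torch.nn.functional as F

def nca_pair_loss(
    policy_chosen_logps, policy_rejected_logps,
    reference_chosen_logps, reference_rejected_logps, beta,
):
    # Implicit rewards are the beta-scaled log-ratios against the reference model.
    chosen_rewards = (policy_chosen_logps - reference_chosen_logps) * beta
    rejected_rewards = (policy_rejected_logps - reference_rejected_logps) * beta
    # Each response's likelihood is optimized absolutely: -logsigmoid(r_chosen)
    # raises the chosen reward, while the -logsigmoid(-r) terms pull both
    # rewards toward zero.
    return (
        -F.logsigmoid(chosen_rewards)
        - 0.5 * F.logsigmoid(-chosen_rewards)
        - 0.5 * F.logsigmoid(-rejected_rewards)
    )
```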

## Logging

While training and evaluating we record the following reward metrics:
4 changes: 3 additions & 1 deletion tests/test_dpo_trainer.py
@@ -92,6 +92,8 @@ def _init_dummy_dataset(self):
["t5", "bco_pair", True],
["gpt2", "sppo_hard", False],
["t5", "sppo_hard", True],
["gpt2", "nca_pair", False],
["t5", "nca_pair", True],
]
)
def test_dpo_trainer(self, name, loss_type, pre_compute):
@@ -140,7 +142,7 @@ def test_dpo_trainer(self, name, loss_type, pre_compute):
new_param = trainer.model.get_parameter(n)
# check the params have changed - ignore 0 biases
if param.sum() != 0:
-    assert not torch.equal(param, new_param)
+    assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)
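The switch from `torch.equal` to `torch.allclose` keeps the "parameters changed" check while tolerating differences below the (near-zero) tolerances; a small float64 illustration with made-up values:

```python
import torch

p = torch.tensor([1.0, 2.0], dtype=torch.float64)
q = p + 1e-14  # a perturbation far below the tolerance

assert not torch.equal(p, q)  # bitwise comparison: any difference fails
assert torch.allclose(p, q, rtol=1e-12, atol=1e-12)  # treated as unchanged
```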

def test_dpo_trainer_without_providing_ref_model(self):
with tempfile.TemporaryDirectory() as tmp_dir:
2 changes: 1 addition & 1 deletion trl/trainer/dpo_config.py
@@ -69,7 +69,7 @@ class DPOConfig(TrainingArguments):

beta: float = 0.1
label_smoothing: float = 0
-loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair", "bco_pair", "sppo_hard"] = "sigmoid"
+loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair", "bco_pair", "sppo_hard", "nca_pair"] = "sigmoid"
label_pad_token_id: int = -100
padding_value: int = 0
truncation_mode: str = "keep_end"
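With the extended `Literal`, the new loss is selected through the config; a minimal sketch, assuming the `DPOConfig`-based setup on this branch (the `output_dir` value is made up):

```python
from trl import DPOConfig

config = DPOConfig(
    output_dir="nca-dpo-model",  # hypothetical path
    beta=0.1,                    # scales the implicit rewards in dpo_loss
    loss_type="nca_pair",        # the option added by this PR
)
```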
10 changes: 9 additions & 1 deletion trl/trainer/dpo_trainer.py
@@ -1039,9 +1039,17 @@ def dpo_loss(
b = policy_rejected_logps - reference_rejected_logps

losses = (a - 0.5 / self.beta) ** 2 + (b + 0.5 / self.beta) ** 2
elif self.loss_type == "nca_pair":
chosen_rewards = (policy_chosen_logps - reference_chosen_logps) * self.beta
rejected_rewards = (policy_rejected_logps - reference_rejected_logps) * self.beta
losses = (
-F.logsigmoid(chosen_rewards)
- 0.5 * F.logsigmoid(-chosen_rewards)
- 0.5 * F.logsigmoid(-rejected_rewards)
)
else:
raise ValueError(
f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'kto_pair', 'bco_pair', 'sppo_hard']"
f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'kto_pair', 'bco_pair', 'sppo_hard', 'nca_pair']"
)

chosen_rewards = (
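As a quick sanity check of the new branch, the loss can be evaluated on toy log-probabilities; a minimal sketch reusing the `nca_pair` math from the diff above (all values illustrative):

```python
import torch
import torch.nn.functional as F

beta = 0.1
# Hypothetical summed log-probs for one chosen/rejected pair.
policy_chosen_logps = torch.tensor([-10.0])
policy_rejected_logps = torch.tensor([-14.0])
reference_chosen_logps = torch.tensor([-12.0])
reference_rejected_logps = torch.tensor([-12.0])

chosen_rewards = (policy_chosen_logps - reference_chosen_logps) * beta        # 0.2
rejected_rewards = (policy_rejected_logps - reference_rejected_logps) * beta  # -0.2

losses = (
    -F.logsigmoid(chosen_rewards)
    - 0.5 * F.logsigmoid(-chosen_rewards)
    - 0.5 * F.logsigmoid(-rejected_rewards)
)
print(losses)  # tensor with one positive entry per pair
```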