initial RPO loss (#1686)
* initial RPO loss

* fix sign

* clean up
kashif authored Jun 3, 2024
1 parent 151a452 commit f18253b
Showing 4 changed files with 29 additions and 12 deletions.
2 changes: 2 additions & 0 deletions docs/source/dpo_trainer.mdx
@@ -119,6 +119,8 @@ The [NCA](https://arxiv.org/abs/2402.05369) authors show that NCA optimizes the

The [TR-DPO](https://arxiv.org/pdf/2404.09656) paper suggests syncing the reference model weights, with mixing weight `ref_model_mixup_alpha`, after every `ref_model_sync_steps` steps of SGD during DPO training. To enable this callback, set the `sync_ref_model` flag in the `DPOConfig`.
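
A minimal sketch of enabling this callback (the values shown are the `DPOConfig` defaults plus an illustrative output path, not tuned recommendations):

```python
from trl import DPOConfig

training_args = DPOConfig(
    output_dir="./dpo-tr-dpo",   # illustrative output path
    sync_ref_model=True,         # enable the TR-DPO reference-model sync callback
    ref_model_mixup_alpha=0.9,   # mixing weight applied when syncing the reference model
    ref_model_sync_steps=64,     # sync the reference model every 64 steps
)
```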

The [RPO](https://arxiv.org/abs/2404.19733) paper implements an iterative preference-tuning algorithm with a loss related to the RPO loss of this [paper](https://arxiv.org/abs/2405.16436), which essentially combines a weighted DPO loss with the SFT loss on the chosen preferences. To use this loss, set `rpo_alpha` in the `DPOConfig` to an appropriate value.
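
As implemented in this commit, the combined loss scales the DPO loss by `rpo_alpha` and adds the average negative log-likelihood of the chosen completion; leaving `rpo_alpha` as `None` (the default) keeps the plain DPO loss. A minimal, illustrative sketch:

```python
from trl import DPOConfig

training_args = DPOConfig(
    output_dir="./dpo-rpo",  # illustrative output path
    beta=0.1,
    rpo_alpha=0.5,           # illustrative value (also used in the test below); None keeps plain DPO
)
```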

## Logging

While training and evaluating we record the following reward metrics:
1 change: 1 addition & 0 deletions tests/test_dpo_trainer.py
@@ -160,6 +160,7 @@ def test_dpo_trainer_without_providing_ref_model(self):
eval_strategy="steps",
beta=0.1,
precompute_ref_log_probs=True,
rpo_alpha=0.5,
)

dummy_dataset = self._init_dummy_dataset()
3 changes: 3 additions & 0 deletions trl/trainer/dpo_config.py
@@ -71,6 +71,8 @@ class DPOConfig(TrainingArguments):
The alpha parameter from the [TR-DPO](https://arxiv.org/pdf/2404.09656) paper.
ref_model_sync_steps ('int', defaults to 64):
The tau parameter from the [TR-DPO](https://arxiv.org/pdf/2404.09656) paper.
rpo_alpha ('float', defaults to `None`):
The alpha parameter from the [RPO](https://arxiv.org/pdf/2404.19733) paper. If None, no weighting is applied and the loss is the same as the DPO loss.
"""

beta: float = 0.1
@@ -98,3 +100,4 @@ class DPOConfig(TrainingArguments):
sync_ref_model: bool = False
ref_model_mixup_alpha: float = 0.9
ref_model_sync_steps: int = 64
rpo_alpha: Optional[float] = None
35 changes: 23 additions & 12 deletions trl/trainer/dpo_trainer.py
@@ -901,13 +901,15 @@ def compute_reference_log_probs(self, padded_batch: Dict) -> Dict:
reference_rejected_logps,
_,
_,
_,
) = self.concatenated_forward(self.model, padded_batch)
else:
(
reference_chosen_logps,
reference_rejected_logps,
_,
_,
_,
) = self.concatenated_forward(self.ref_model, padded_batch)

return reference_chosen_logps, reference_rejected_logps
@@ -1089,21 +1091,19 @@ def dpo_loss(
def get_batch_logps(
logits: torch.FloatTensor,
labels: torch.LongTensor,
average_log_prob: bool = False,
label_pad_token_id: int = -100,
is_encoder_decoder: bool = False,
) -> torch.FloatTensor:
) -> Tuple[torch.FloatTensor, torch.LongTensor]:
"""Compute the log probabilities of the given labels under the given logits.
Args:
logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
label_pad_token_id: The label pad token id.
is_encoder_decoder: Whether the model is an encoder-decoder model.
Returns:
A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
A tuple of two tensors of shape (batch_size,), containing the sum of the log probabilities of the given labels under the given logits in the first tensor and the number of non-masked tokens in the second tensor.
"""
if logits.shape[:-1] != labels.shape:
raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
@@ -1118,10 +1118,7 @@ def get_batch_logps(

per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)

if average_log_prob:
return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
else:
return (per_token_logps * loss_mask).sum(-1)
return (per_token_logps * loss_mask).sum(-1), loss_mask.sum(-1)
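
# Illustrative, standalone sketch (not part of this diff): how the pair returned by
# get_batch_logps behaves on a toy batch. The label shift for decoder-only models is
# omitted, and -100 is assumed as label_pad_token_id.
import torch

toy_logits = torch.randn(2, 5, 8)  # (batch_size, sequence_length, vocab_size)
toy_labels = torch.tensor([[3, 1, 4, -100, -100], [2, 7, -100, -100, -100]])
loss_mask = toy_labels != -100
toy_labels = toy_labels.masked_fill(~loss_mask, 0)  # any valid index; masked out below
per_token_logps = torch.gather(toy_logits.log_softmax(-1), dim=2, index=toy_labels.unsqueeze(2)).squeeze(2)
sum_logps = (per_token_logps * loss_mask).sum(-1)  # first return value, shape (batch_size,)
num_tokens = loss_mask.sum(-1)  # second return value: 3 and 2 non-masked tokens
avg_logps = sum_logps / num_tokens  # per-token average, as used for IPO and the RPO NLL term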

def concatenated_forward(
self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
@@ -1154,21 +1151,25 @@ def concatenated_forward(
**model_kwargs,
).logits

all_logps = self.get_batch_logps(
all_logps, size_completion = self.get_batch_logps(
all_logits,
concatenated_batch["concatenated_labels"],
average_log_prob=self.loss_type == "ipo",
# average_log_prob=self.loss_type == "ipo",
is_encoder_decoder=self.is_encoder_decoder,
label_pad_token_id=self.label_pad_token_id,
)
# Per-token average log probability of the chosen completions, used for the RPO NLL term.
chosen_logps_avg = all_logps[:len_chosen] / size_completion[:len_chosen]

if self.loss_type == "ipo":
# IPO uses the per-token average log probability.
all_logps = all_logps / size_completion

chosen_logps = all_logps[:len_chosen]
rejected_logps = all_logps[len_chosen:]

chosen_logits = all_logits[:len_chosen]
rejected_logits = all_logits[len_chosen:]

return (chosen_logps, rejected_logps, chosen_logits, rejected_logits)
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps_avg)

def get_batch_loss_metrics(
self,
@@ -1184,10 +1185,15 @@ def get_batch_loss_metrics(
policy_rejected_logps,
policy_chosen_logits,
policy_rejected_logits,
policy_chosen_logps_avg,
) = self.concatenated_forward(model, batch)

# If reference_chosen_logps and reference_rejected_logps are in the batch, use them; otherwise, use the reference model
if "reference_chosen_logps" in batch and "reference_rejected_logps" in batch:
if (
"reference_chosen_logps" in batch
and "reference_rejected_logps" in batch
and self.args.rpo_alpha is not None
):
reference_chosen_logps = batch["reference_chosen_logps"]
reference_rejected_logps = batch["reference_rejected_logps"]
else:
@@ -1199,13 +1205,15 @@ def get_batch_loss_metrics(
reference_rejected_logps,
_,
_,
_,
) = self.concatenated_forward(self.model, batch)
else:
(
reference_chosen_logps,
reference_rejected_logps,
_,
_,
_,
) = self.concatenated_forward(self.ref_model, batch)

losses, chosen_rewards, rejected_rewards = self.dpo_loss(
@@ -1216,6 +1224,9 @@
)
reward_accuracies = (chosen_rewards > rejected_rewards).float()

if self.args.rpo_alpha is not None:
# RPO: scale the DPO loss by rpo_alpha and add the NLL (SFT) loss on the chosen completions.
losses = losses * self.args.rpo_alpha - policy_chosen_logps_avg

prefix = "eval_" if train_eval == "eval" else ""
metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu()
