diff --git a/docs/source/dpo_trainer.mdx b/docs/source/dpo_trainer.mdx
index 4f0ad68a36..2022743b4e 100644
--- a/docs/source/dpo_trainer.mdx
+++ b/docs/source/dpo_trainer.mdx
@@ -119,6 +119,8 @@ The [NCA](https://arxiv.org/abs/2402.05369) authors shows that NCA optimizes the
 
 The [TR-DPO](https://arxiv.org/pdf/2404.09656) paper suggests syncing the reference model weights after every `ref_model_sync_steps` steps of SGD with weight `ref_model_mixup_alpha` during DPO training. To toggle this callback use the `sync_ref_model` flag in the `DPOConfig`.
 
+The [RPO](https://arxiv.org/abs/2404.19733) paper implements an iterative preference tuning algorithm using a loss related to the RPO loss in this [paper](https://arxiv.org/abs/2405.16436), which essentially consists of the SFT loss on the chosen preferences together with a weighted DPO loss. To use this loss, set the `rpo_alpha` in the `DPOConfig` to an appropriate value.
+
 ## Logging
 
 While training and evaluating we record the following reward metrics:
diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py
index fddc79166e..3b5bfcae2b 100644
--- a/tests/test_dpo_trainer.py
+++ b/tests/test_dpo_trainer.py
@@ -160,6 +160,7 @@ def test_dpo_trainer_without_providing_ref_model(self):
                 eval_strategy="steps",
                 beta=0.1,
                 precompute_ref_log_probs=True,
+                rpo_alpha=0.5,
             )
 
             dummy_dataset = self._init_dummy_dataset()
diff --git a/trl/trainer/dpo_config.py b/trl/trainer/dpo_config.py
index 7a94cc74a4..b4b259fa46 100644
--- a/trl/trainer/dpo_config.py
+++ b/trl/trainer/dpo_config.py
@@ -71,6 +71,8 @@ class DPOConfig(TrainingArguments):
             The alpha parameter from the [TR-DPO](https://arxiv.org/pdf/2404.09656) paper.
         ref_model_sync_steps ('int', defaults to 2):
             The tau parameter from the [TR-DPO](https://arxiv.org/pdf/2404.09656) paper.
+        rpo_alpha ('float', defaults to `None`):
+            The alpha parameter from the [RPO](https://arxiv.org/pdf/2404.19733) paper. If `None`, no weighting is applied and the loss is the same as the DPO loss.
     """
 
     beta: float = 0.1
@@ -98,3 +100,4 @@ class DPOConfig(TrainingArguments):
     sync_ref_model: bool = False
     ref_model_mixup_alpha: float = 0.9
     ref_model_sync_steps: int = 64
+    rpo_alpha: Optional[float] = None
diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index fb79ef0e39..8bf29afa6b 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -901,6 +901,7 @@ def compute_reference_log_probs(self, padded_batch: Dict) -> Dict:
                         reference_rejected_logps,
                         _,
                         _,
+                        _,
                     ) = self.concatenated_forward(self.model, padded_batch)
             else:
                 (
@@ -908,6 +909,7 @@ def compute_reference_log_probs(self, padded_batch: Dict) -> Dict:
                     reference_rejected_logps,
                     _,
                     _,
+                    _,
                 ) = self.concatenated_forward(self.ref_model, padded_batch)
 
         return reference_chosen_logps, reference_rejected_logps
@@ -1089,21 +1091,19 @@ def dpo_loss(
     def get_batch_logps(
         logits: torch.FloatTensor,
         labels: torch.LongTensor,
-        average_log_prob: bool = False,
         label_pad_token_id: int = -100,
         is_encoder_decoder: bool = False,
-    ) -> torch.FloatTensor:
+    ) -> Tuple[torch.FloatTensor, torch.LongTensor]:
         """Compute the log probabilities of the given labels under the given logits.
 
         Args:
             logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
             labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
-            average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
             label_pad_token_id: The label pad token id.
             is_encoder_decoder: Whether the model is an encoder-decoder model.
 
         Returns:
-            A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
+            A Tuple of two tensors of shape ((batch_size,), (batch_size,)) containing the sum of the log probabilities of the given labels under the given logits in the first tensor and the number of non-masked tokens in the second tensor.
         """
         if logits.shape[:-1] != labels.shape:
             raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
@@ -1118,10 +1118,7 @@ get_batch_logps(
 
         per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
 
-        if average_log_prob:
-            return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
-        else:
-            return (per_token_logps * loss_mask).sum(-1)
+        return (per_token_logps * loss_mask).sum(-1), loss_mask.sum(-1)
 
     def concatenated_forward(
         self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
@@ -1154,13 +1151,17 @@ def concatenated_forward(
             **model_kwargs,
         ).logits
 
-        all_logps = self.get_batch_logps(
+        all_logps, size_completion = self.get_batch_logps(
            all_logits,
             concatenated_batch["concatenated_labels"],
-            average_log_prob=self.loss_type == "ipo",
+            # average_log_prob=self.loss_type == "ipo",
             is_encoder_decoder=self.is_encoder_decoder,
             label_pad_token_id=self.label_pad_token_id,
         )
+        chosen_logps_avg = all_logps[:len_chosen] / size_completion[:len_chosen]
+
+        if self.loss_type == "ipo":
+            all_logps = all_logps / size_completion
 
         chosen_logps = all_logps[:len_chosen]
         rejected_logps = all_logps[len_chosen:]
@@ -1168,7 +1169,7 @@ def concatenated_forward(
         chosen_logits = all_logits[:len_chosen]
         rejected_logits = all_logits[len_chosen:]
 
-        return (chosen_logps, rejected_logps, chosen_logits, rejected_logits)
+        return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps_avg)
 
     def get_batch_loss_metrics(
         self,
@@ -1184,10 +1185,15 @@ def get_batch_loss_metrics(
             policy_rejected_logps,
             policy_chosen_logits,
             policy_rejected_logits,
+            policy_chosen_logps_avg,
         ) = self.concatenated_forward(model, batch)
 
         # if reference_chosen_logps and reference_rejected_logps in batch use them, otherwise use the reference model
-        if "reference_chosen_logps" in batch and "reference_rejected_logps" in batch:
+        if (
+            "reference_chosen_logps" in batch
+            and "reference_rejected_logps" in batch
+            and self.args.rpo_alpha is not None
+        ):
             reference_chosen_logps = batch["reference_chosen_logps"]
             reference_rejected_logps = batch["reference_rejected_logps"]
         else:
@@ -1199,6 +1205,7 @@ def get_batch_loss_metrics(
                             reference_rejected_logps,
                             _,
                             _,
+                            _,
                         ) = self.concatenated_forward(self.model, batch)
                 else:
                     (
@@ -1206,6 +1213,7 @@ def get_batch_loss_metrics(
                         reference_rejected_logps,
                         _,
                         _,
+                        _,
                     ) = self.concatenated_forward(self.ref_model, batch)
 
         losses, chosen_rewards, rejected_rewards = self.dpo_loss(
@@ -1216,6 +1224,9 @@ def get_batch_loss_metrics(
         )
         reward_accuracies = (chosen_rewards > rejected_rewards).float()
 
+        if self.args.rpo_alpha is not None:
+            losses = losses * self.args.rpo_alpha - policy_chosen_logps_avg
+
         prefix = "eval_" if train_eval == "eval" else ""
         metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
         metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu()
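Usage note: a minimal sketch of how the new `rpo_alpha` option is meant to be enabled from user code, mirroring the documentation change above. The inline toy dataset, the `gpt2` checkpoint, and all hyperparameter values are illustrative placeholders, not part of this PR.

```python
# Sketch: turning on the RPO-weighted loss via DPOConfig.rpo_alpha.
# Model, dataset, and hyperparameters below are toy placeholders.
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model = AutoModelForCausalLM.from_pretrained("gpt2")
ref_model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# Minimal preference dataset with the prompt/chosen/rejected columns DPOTrainer expects.
train_dataset = Dataset.from_dict(
    {
        "prompt": ["Hello", "How are you?"],
        "chosen": [" hi, nice to meet you.", " I am fine, thanks."],
        "rejected": [" leave me alone.", " not telling you."],
    }
)

training_args = DPOConfig(
    output_dir="./dpo_rpo_sketch",
    per_device_train_batch_size=2,
    max_steps=3,
    beta=0.1,
    rpo_alpha=0.5,  # None (the default) keeps the plain DPO loss
)

trainer = DPOTrainer(
    model,
    ref_model,
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
)
trainer.train()
```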
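For reviewers, a self-contained sketch of the arithmetic this patch performs: `get_batch_logps` now returns the summed log probabilities together with the number of non-masked tokens, `concatenated_forward` derives `chosen_logps_avg` from them, and `get_batch_loss_metrics` combines that with the scaled DPO loss as `losses * rpo_alpha - policy_chosen_logps_avg`. The tensors below are dummy values, not trainer internals.

```python
import torch

# Dummy per-token log-probs for two chosen completions; -100 marks padded label
# positions (label_pad_token_id), the same masking get_batch_logps applies.
per_token_logps = torch.tensor([[-0.5, -1.0, -0.25], [-2.0, -0.75, -0.1]])
labels = torch.tensor([[5, 8, -100], [3, 9, 11]])
loss_mask = labels != -100

# What the patched get_batch_logps returns: (summed log-probs, non-masked token count).
sum_logps = (per_token_logps * loss_mask).sum(-1)
size_completion = loss_mask.sum(-1)

# chosen_logps_avg as derived in concatenated_forward for the chosen half of the batch.
chosen_logps_avg = sum_logps / size_completion

# Loss combination from get_batch_loss_metrics when rpo_alpha is set:
# scaled DPO loss minus the average chosen log-prob (an SFT/NLL term).
rpo_alpha = 0.5
dpo_losses = torch.tensor([0.6931, 0.6931])  # stand-in per-example DPO losses
losses = dpo_losses * rpo_alpha - chosen_logps_avg
print(losses)  # ≈ tensor([1.10, 1.30])
```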