Integrate f-divergence to DPO (Follow up) (huggingface#1610)
* Step 1: update ppo_trainer and hello_world example

* Step 2: Refine comments and add parameter type

* Step 2: Add missing parameter comments

* Step 1: Organize ptx loss into a function and add ptx_loss to train_stats

* Step 1 updates: add comment to ptx_loss function, fix a bug and add warning message

* Step 2: 1) Add a ppo_ptx training example alongside the ppo example; 2) separate pretrain data fetching and iteration

* Step 2: Remove loss from columns_to_log in ppo_ptx example

* Remove dataset revision when loading the imdb dataset

* Run pre-commit and fix format issues

* Initial draft of f-divergence fn

* Update f-divergence to avoid overflow

* Fix test errors and comments

* Add unit tests for DPO loss with alpha- and JS-divergence f-functions

* Adjust format

* Fix test error

* Revert this update

* Add test cases

* Revert unneeded updates

* Update code style

* Try to fix code formatting error

* Remove extra blank line at end of file

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
1485840691 and kashif authored Jun 19, 2024
1 parent ae23d40 commit a57e759
Showing 6 changed files with 176 additions and 12 deletions.
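
For orientation (not part of the diff): a minimal usage sketch of the options this commit adds, mirroring the new tests below. The toy dataset and output path are illustrative assumptions; the DPOConfig fields, FDivergenceType enum, and trainer call follow the code in this diff.

from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from trl import DPOConfig, DPOTrainer, FDivergenceType

# Tiny model used by the tests in this diff; the dataset below is an illustrative
# stand-in for the dummy dataset the tests build (prompt/chosen/rejected columns).
model_id = "trl-internal-testing/tiny-random-LlamaForCausalLM"
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

dummy_dataset = Dataset.from_dict(
    {
        "prompt": ["Hello", "How are you?"],
        "chosen": [" nice to meet you", " I am fine"],
        "rejected": [" leave me alone", " not good"],
    }
)

training_args = DPOConfig(
    output_dir="dpo-f-div-sketch",                              # placeholder output path
    per_device_train_batch_size=2,
    max_steps=3,
    remove_unused_columns=False,
    f_divergence_type=FDivergenceType.ALPHA_DIVERGENCE.value,   # or JS_DIVERGENCE / REVERSE_KL (default)
    f_alpha_divergence_coef=0.5,                                # only consulted for alpha-divergence
)

trainer = DPOTrainer(
    model=model,
    ref_model=None,                  # an implicit reference copy is created, as in the tests below
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=dummy_dataset,
    eval_dataset=dummy_dataset,
)
trainer.train()
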
86 changes: 85 additions & 1 deletion tests/test_dpo_trainer.py
@@ -20,7 +20,7 @@
from pytest import mark
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer

from trl import DPOConfig, DPOTrainer
from trl import DPOConfig, DPOTrainer, FDivergenceType

from .testing_utils import require_bitsandbytes, require_no_wandb, require_peft

@@ -725,3 +725,87 @@ def test_dpo_lora_force_use_ref(self):

# train the model
trainer.train()

def test_dpo_loss_alpha_div_f(self):
model_id = "trl-internal-testing/tiny-random-LlamaForCausalLM"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# plain causal LM (no LoRA adapters are used in this test)
model = AutoModelForCausalLM.from_pretrained(model_id)

with tempfile.TemporaryDirectory() as tmp_dir:
training_args = DPOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
remove_unused_columns=False,
gradient_accumulation_steps=4,
learning_rate=9e-1,
evaluation_strategy="steps",
f_divergence_type=FDivergenceType.ALPHA_DIVERGENCE.value,
f_alpha_divergence_coef=0.5,
)

dummy_dataset = self._init_dummy_dataset()

# build the DPO trainer with the alpha-divergence config
trainer = DPOTrainer(
model=model,
ref_model=None,
args=training_args,
tokenizer=tokenizer,
train_dataset=dummy_dataset,
eval_dataset=dummy_dataset,
)

# Fake chosen and rejected log probs
policy_chosen_logps = torch.FloatTensor([410.0, 0.1])
policy_rejected_logps = torch.FloatTensor([810.5, 0.2])
reference_chosen_logps = torch.FloatTensor([-610.0, -0.1])
reference_rejected_logps = torch.FloatTensor([110.6, 0.5])
losses, _, _ = trainer.dpo_loss(
policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps
)
assert torch.isfinite(losses).cpu().numpy().all()

def test_dpo_loss_js_div_f(self):
model_id = "trl-internal-testing/tiny-random-LlamaForCausalLM"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# plain causal LM (no LoRA adapters are used in this test)
model = AutoModelForCausalLM.from_pretrained(model_id)

with tempfile.TemporaryDirectory() as tmp_dir:
training_args = DPOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
remove_unused_columns=False,
gradient_accumulation_steps=4,
learning_rate=9e-1,
evaluation_strategy="steps",
f_divergence_type=FDivergenceType.JS_DIVERGENCE.value,
f_alpha_divergence_coef=0.5,
)

dummy_dataset = self._init_dummy_dataset()

# build the DPO trainer with the JS-divergence config
trainer = DPOTrainer(
model=model,
ref_model=None,
args=training_args,
tokenizer=tokenizer,
train_dataset=dummy_dataset,
eval_dataset=dummy_dataset,
)

# Fake chosen and rejected log probs
policy_chosen_logps = torch.FloatTensor([410.0, 0.1])
policy_rejected_logps = torch.FloatTensor([95.5, 0.2])
reference_chosen_logps = torch.FloatTensor([-610.0, -0.1])
reference_rejected_logps = torch.FloatTensor([5.5, 0.5])
losses, _, _ = trainer.dpo_loss(
policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps
)
assert torch.isfinite(losses).cpu().numpy().all()
4 changes: 4 additions & 0 deletions trl/__init__.py
@@ -51,6 +51,8 @@
"RewardTrainer",
"SFTConfig",
"SFTTrainer",
"FDivergenceConstants",
"FDivergenceType",
],
"commands": [],
"commands.cli_utils": ["init_zero_verbose", "SFTScriptArguments", "DPOScriptArguments", "TrlParser"],
@@ -117,6 +119,8 @@
RewardTrainer,
SFTConfig,
SFTTrainer,
FDivergenceConstants,
FDivergenceType,
)
from .trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config, RichProgressCallback
from .commands.cli_utils import init_zero_verbose, SFTScriptArguments, DPOScriptArguments, TrlParser
4 changes: 2 additions & 2 deletions trl/trainer/__init__.py
@@ -29,7 +29,7 @@
"peft_module_casting_to_bf16",
"RichProgressCallback",
],
"dpo_config": ["DPOConfig"],
"dpo_config": ["DPOConfig", "FDivergenceConstants", "FDivergenceType"],
"dpo_trainer": ["DPOTrainer"],
"cpo_config": ["CPOConfig"],
"cpo_trainer": ["CPOTrainer"],
@@ -76,7 +76,7 @@
from .base import BaseTrainer
from .ddpo_config import DDPOConfig

from .dpo_config import DPOConfig
from .dpo_config import DPOConfig, FDivergenceConstants, FDivergenceType
from .dpo_trainer import DPOTrainer
from .iterative_sft_trainer import IterativeSFTTrainer
from .cpo_config import CPOConfig
18 changes: 18 additions & 0 deletions trl/trainer/dpo_config.py
@@ -12,11 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from enum import Enum
from typing import Dict, Literal, Optional

from transformers import TrainingArguments


class FDivergenceType(Enum):
REVERSE_KL = "reverse_kl"
JS_DIVERGENCE = "js_divergence"
ALPHA_DIVERGENCE = "alpha_divergence"


class FDivergenceConstants:
ALPHA_DIVERGENCE_COEF_KEY = "alpha_divergence_coef"
ALPHA_DIVERGENCE_COEF_DEFAULT = 1.0


@dataclass
class DPOConfig(TrainingArguments):
r"""
@@ -66,6 +78,10 @@ class DPOConfig(TrainingArguments):
If True, we ignore the _provided_ reference model and implicitly use a reference model that assigns equal probability to all responses.
force_use_ref_model (`bool`, defaults to `False`):
In case one passes a PEFT model for the active model and you want to use a different model for the ref_model, set this flag to `True`.
f_divergence_type (`FDivergenceType`, *optional*, defaults to `FDivergenceType.REVERSE_KL`):
The type of f-divergence regularization function used to compute the divergence between the policy and the reference model.
f_alpha_divergence_coef (`float`, *optional*, defaults to `1.0`):
The alpha coefficient in the alpha-divergence u^-alpha regularization function for the DPO loss.
sync_ref_model ('bool', defaults to `False`):
The flag for syncing reference model during training from the [TR-DPO](https://arxiv.org/pdf/2404.09656) paper.
ref_model_mixup_alpha ('float', defaults to 1.0):
@@ -98,6 +114,8 @@ class DPOConfig(TrainingArguments):
ref_adapter_name: Optional[str] = None
reference_free: bool = False
force_use_ref_model: bool = False
f_divergence_type: Optional[FDivergenceType] = FDivergenceType.REVERSE_KL
f_alpha_divergence_coef: Optional[float] = 1.0
sync_ref_model: bool = False
ref_model_mixup_alpha: float = 0.9
ref_model_sync_steps: int = 64
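
For reference (not part of the diff), the per-sample ratio term each FDivergenceType value plugs into the DPO logits, with u = pi_theta(y|x) / pi_ref(y|x). The alpha and JS forms are quoted from the dpo_loss comments further down this diff; the reverse-KL line is the standard DPO log-ratio and is stated here as an assumption:

    REVERSE_KL (default):  log(u)
    ALPHA_DIVERGENCE:      (1 - u^(-alpha)) / alpha,  alpha = f_alpha_divergence_coef
    JS_DIVERGENCE:         log(2u / (1 + u))

In each case the loss uses the difference of this term between the chosen and rejected responses.
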
50 changes: 41 additions & 9 deletions trl/trainer/dpo_trainer.py
@@ -42,11 +42,12 @@

from ..import_utils import is_peft_available, is_wandb_available
from ..models import PreTrainedModelWrapper, create_reference_model
from .dpo_config import DPOConfig
from .dpo_config import DPOConfig, FDivergenceConstants, FDivergenceType
from .utils import (
DPODataCollatorWithPadding,
RunningMoments,
SyncRefModelCallback,
cap_exp,
disable_dropout_in_model,
pad_to_length,
peft_module_casting_to_bf16,
@@ -481,6 +482,9 @@ def make_inputs_require_grad(module, input, output):

self._stored_metrics = defaultdict(lambda: defaultdict(list))

self.f_divergence_type = args.f_divergence_type
self.f_divergence_params = {FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY: args.f_alpha_divergence_coef}

if dataset_num_proc is not None:
warnings.warn(
"You passed `dataset_num_proc` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
@@ -998,15 +1002,43 @@ def dpo_loss(
The losses tensor contains the DPO loss for each example in the batch.
The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
"""
pi_logratios = policy_chosen_logps - policy_rejected_logps
if self.reference_free:
ref_logratios = torch.tensor([0], dtype=pi_logratios.dtype, device=pi_logratios.device)
chosen_logratios = policy_chosen_logps.to(self.accelerator.device) - (
not self.reference_free
) * reference_chosen_logps.to(self.accelerator.device)
rejected_logratios = policy_rejected_logps.to(self.accelerator.device) - (
not self.reference_free
) * reference_rejected_logps.to(self.accelerator.device)

if self.f_divergence_type == FDivergenceType.ALPHA_DIVERGENCE.value:
# The alpha-divergence formula: (1 - u^-alpha) / alpha
# The divergence difference between the chosen and rejected sample is:
# (1 - u[w]^-alpha) / alpha - (1 - u[l]^-alpha) / alpha
# = (u[l]^-alpha - u[w]^-alpha) / alpha
# where u[w] and u[l] are the policy/reference probability ratios
# for the chosen and rejected samples, respectively.
alpha_coef = FDivergenceConstants.ALPHA_DIVERGENCE_COEF_DEFAULT
if self.f_divergence_params and FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY in self.f_divergence_params:
alpha_coef = float(self.f_divergence_params[FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY])
logits = (cap_exp(rejected_logratios * -alpha_coef) - cap_exp(chosen_logratios * -alpha_coef)) / alpha_coef
else:
ref_logratios = reference_chosen_logps - reference_rejected_logps

pi_logratios = pi_logratios.to(self.accelerator.device)
ref_logratios = ref_logratios.to(self.accelerator.device)
logits = pi_logratios - ref_logratios
pi_logratios = policy_chosen_logps - policy_rejected_logps
if self.reference_free:
ref_logratios = torch.tensor([0], dtype=pi_logratios.dtype, device=pi_logratios.device)
else:
ref_logratios = reference_chosen_logps - reference_rejected_logps

pi_logratios = pi_logratios.to(self.accelerator.device)
ref_logratios = ref_logratios.to(self.accelerator.device)
logits = pi_logratios - ref_logratios

if self.f_divergence_type == FDivergenceType.JS_DIVERGENCE.value:
# The js-divergence formula: log(2 * u / (1 + u))
# The divergence difference between the chosen and rejected sample is:
# log(2 * u[w] / (1 + u[w])) - log(2 * u[l] / (1 + u[l]))
# = log(u[w]) - log(u[l]) - (log(1 + u[w]) - log(1 + u[l]))
# where u[w] and u[l] are the policy/reference probability ratios
# for the chosen and rejected samples, respectively.
logits -= F.softplus(chosen_logratios) - F.softplus(rejected_logratios)

# The beta is a temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5.
# We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the labels and
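
A standalone sketch (not part of the diff) checking the identity the JS branch relies on: the difference of log(2u/(1+u)) between the chosen and rejected samples equals the difference of log-ratios minus the difference of softplus(log-ratio), since log(1 + u) = softplus(log u) and the log 2 terms cancel. The log-ratio values are made up for illustration.

import torch
import torch.nn.functional as F

# Made-up log-ratios log(u) = log pi_theta(y|x) - log pi_ref(y|x) for chosen (w) and rejected (l).
chosen_logratios = torch.tensor([0.7, -1.2, 3.0])
rejected_logratios = torch.tensor([-0.4, 0.9, 2.5])

def js_term(logratio):
    # Direct evaluation of log(2u / (1 + u)) with u = exp(logratio).
    u = torch.exp(logratio)
    return torch.log(2 * u / (1 + u))

direct = js_term(chosen_logratios) - js_term(rejected_logratios)

# Form used in dpo_loss: (log u_w - log u_l) - (softplus(log u_w) - softplus(log u_l)).
via_softplus = (chosen_logratios - rejected_logratios) - (
    F.softplus(chosen_logratios) - F.softplus(rejected_logratios)
)

assert torch.allclose(direct, via_softplus, atol=1e-6)
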
26 changes: 26 additions & 0 deletions trl/trainer/utils.py
@@ -859,6 +859,32 @@ def on_train_end(self, args, state, control, **kwargs):
self.current_step = None


def get_exp_cap(value, decimal=4):
"""
Get the exponent cap of a value. This is used to cap the exponent of a value to avoid overflow.
The formula is: log(value.dtype.max).
E.g. for the float32 data type, the maximum exponent value is 88.7228 (to 4 decimal points).
Calling exp(log(torch.float32.max)) directly results in inf, so the cap is floored to the requested
number of decimal points (e.g. 88.7228) before use.

Args:
value (`torch.Tensor`):
The input tensor from which the data type (and hence the cap) is derived.
decimal (`int`):
The number of decimal points to which the exponent cap is floored.
"""
vdtype_max = torch.zeros([1]).to(value.dtype) + torch.finfo(value.dtype).max
vdtype_log_max = torch.log(vdtype_max).to(value.device)
return torch.floor(vdtype_log_max * 10**decimal) / 10**decimal if decimal > 0 else vdtype_log_max


def cap_exp(value, cap=-1):
# Cap the exponent value below the upper-bound to avoid overflow, before calling torch.exp
cap = get_exp_cap(value) if cap < 0 else cap
return torch.exp(torch.clamp(value, max=cap))


def print_rich_table(df: pd.DataFrame) -> Table:
console = Console()
table = Table(show_lines=True)
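
A small sketch (not part of the diff) of the overflow problem these helpers guard against: exponentiating a large float32 value yields inf, while clamping the exponent to roughly log(float32 max), floored as get_exp_cap does, keeps the result finite. The inputs are illustrative.

import torch

x = torch.tensor([10.0, 100.0], dtype=torch.float32)
print(torch.exp(x))  # second entry overflows to inf (exp(100) > float32 max)

# Mirror get_exp_cap: log of the dtype max, floored to 4 decimal places (~88.7228 for float32).
log_max = torch.log(torch.tensor(torch.finfo(x.dtype).max))
cap = torch.floor(log_max * 10**4) / 10**4
print(torch.exp(torch.clamp(x, max=cap)))  # both entries finite (second is ~3.4e38)
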
