diff --git a/trl/trainer/kto_trainer.py b/trl/trainer/kto_trainer.py index c57650939d..a34ad62d44 100644 --- a/trl/trainer/kto_trainer.py +++ b/trl/trainer/kto_trainer.py @@ -578,6 +578,7 @@ def make_inputs_require_grad(module, input, output): "truncation_mode": self.truncation_mode, "label_pad_token_id": self.label_pad_token_id, "max_prompt_length": self.max_prompt_length, + "max_completion_length": self.max_completion_length, } train_dataset = train_dataset.map( _process_tokens, @@ -618,6 +619,7 @@ def make_inputs_require_grad(module, input, output): "truncation_mode": self.truncation_mode, "label_pad_token_id": self.label_pad_token_id, "max_prompt_length": self.max_prompt_length, + "max_completion_length": self.max_completion_length, } eval_dataset = eval_dataset.map( _process_tokens,