diff --git a/examples/scripts/dpo_visual.py b/examples/scripts/dpo_visual.py
index f731a6d00d..6be22bdd39 100644
--- a/examples/scripts/dpo_visual.py
+++ b/examples/scripts/dpo_visual.py
@@ -16,12 +16,13 @@
 accelerate launch examples/scripts/dpo_visual.py \
     --dataset_name HuggingFaceH4/rlaif-v_formatted \
     --model_name_or_path HuggingFaceM4/idefics2-8b \
-    --per_device_train_batch_size 1 \
-    --gradient_accumulation_steps 16 \
+    --per_device_train_batch_size 2 \
+    --gradient_accumulation_steps 32 \
     --dataset_num_proc 32 \
     --output_dir dpo_idefics_rlaif-v \
     --bf16 \
     --torch_dtype bfloat16 \
+    --gradient_checkpointing \
     --use_peft \
     --lora_target_modules=all-linear
 """
@@ -82,21 +83,40 @@
     model_kwargs = dict(
         revision=model_config.model_revision,
-        trust_remote_code=model_config.trust_remote_code,
         attn_implementation=model_config.attn_implementation,
         torch_dtype=torch_dtype,
-        use_cache=False if training_args.gradient_checkpointing else True,
         device_map=get_kbit_device_map() if quantization_config is not None else None,
         quantization_config=quantization_config,
     )
-    model = AutoModelForVision2Seq.from_pretrained(model_config.model_name_or_path, **model_kwargs)
+    model = AutoModelForVision2Seq.from_pretrained(
+        model_config.model_name_or_path,
+        trust_remote_code=model_config.trust_remote_code,
+        **model_kwargs,
+    )
     peft_config = get_peft_config(model_config)
     if peft_config is None:
-        model_ref = AutoModelForVision2Seq.from_pretrained(model_config.model_name_or_path, **model_kwargs)
+        model_ref = AutoModelForVision2Seq.from_pretrained(
+            model_config.model_name_or_path,
+            trust_remote_code=model_config.trust_remote_code,
+            **model_kwargs,
+        )
     else:
         model_ref = None
 
-    processor = AutoProcessor.from_pretrained(model_config.model_name_or_path, do_image_splitting=False)
+    processor = AutoProcessor.from_pretrained(
+        model_config.model_name_or_path,
+        trust_remote_code=model_config.trust_remote_code,
+        do_image_splitting=False,
+    )
     tokenizer = processor.tokenizer
+
+    # Set up the chat template
+    if model.config.model_type == "idefics2":
+        pass  # the processor already has a valid chat template
+    elif model.config.model_type == "paligemma":
+        processor.chat_template = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}{% for item in message['content'] if item['type'] == 'text' %}{{ item['text'] }}{% endfor %}{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}{% if add_generation_prompt %}ASSISTANT: {% endif %}"""
+    elif model.config.model_type == "llava":
+        processor.chat_template = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}<image>{% endif %}{% endfor %}{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}{% if add_generation_prompt %}ASSISTANT: {% endif %}"""
+
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
     if args.ignore_bias_buffers:
@@ -124,27 +144,9 @@
             ds[key] = ds[key].select(range(50))
 
     def process(row):
-        # The prompt can be either a string or a list. In some datasets, the prompt is just a common string
-        # for both rejected and chosen (already included in chosen and rejected) and is not meant to be used
-        # separately. In other datasets, the prompt is intended to be used as a prefix for rejected and chosen,
-        # and in such cases, it is properly formatted as a list with keys "role" and "content".
-        # Example 1:
-        #     row = {"prompt": "What does detox mean?",
-        #            "chosen": [{"content": "What does detox mean?", "role": "user"}, {"content": "It means to get rid of the toxins.", "role": "assistant"}],
-        #            "rejected": [{"content": "What does detox mean?", "role": "user"}, {"content": "I don't know.", "role": "assistant"}]}
-        # Example 2:
-        #     row = {"prompt": [{"content": "What does detox mean?", "role": "user"}],
-        #            "chosen": [{"content": "It means to get rid of the toxins.", "role": "assistant"}],
-        #            "rejected": [{"content": "I don't know.", "role": "assistant"}]}
-        if "prompt" in row and isinstance(row["prompt"], list):
-            row["prompt"] = processor.apply_chat_template(row["prompt"], tokenize=False)
-
+        row["prompt"] = processor.apply_chat_template(row["prompt"], tokenize=False)
         row["chosen"] = processor.apply_chat_template(row["chosen"], tokenize=False)
         row["rejected"] = processor.apply_chat_template(row["rejected"], tokenize=False)
-
-        if "images" in row:
-            for img in row["images"]:
-                img.thumbnail((640, 640))  # Resize the image to at most 640x640 pixels
         return row
 
     with PartialState().local_main_process_first():
@@ -168,6 +170,6 @@ def process(row):
     )
 
     trainer.train()
 
-    trainer.push_to_hub()
+    with save_context:
         trainer.save_model(training_args.output_dir)
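A quick way to sanity-check the LLaVA chat template introduced above, without downloading any model, is to render the template string directly with Jinja2 (the engine `apply_chat_template` uses under the hood). The conversation below is made up for illustration; only the template string comes from the diff.

```python
# Render the new LLaVA chat template with Jinja2 directly.
# The sample messages are invented; the template string is the one added above.
from jinja2 import Template

LLAVA_TEMPLATE = (
    "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}"
    "{% for message in messages %}{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}"
    "{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}"
    "{% elif item['type'] == 'image' %}<image>{% endif %}{% endfor %}"
    "{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}"
    "{% if add_generation_prompt %}ASSISTANT: {% endif %}"
)

messages = [
    {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What is in the photo?"}]},
    {"role": "assistant", "content": [{"type": "text", "text": "A cat sleeping on a sofa."}]},
]

print(Template(LLAVA_TEMPLATE).render(messages=messages, eos_token="</s>"))
# USER: <image>What is in the photo? ASSISTANT: A cat sleeping on a sofa.</s>
```

The `<image>` marker matters here: it is the placeholder the LLaVA processor expands into image patch positions, so a template that drops it silently breaks image conditioning.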
diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index c76122fb0a..640212392e 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -723,9 +723,18 @@ def build_tokenized_answer(self, prompt, answer, images=None):
         if self.is_vision_model:
             if answer.count("<image>") > 0:
                 raise NotImplementedError("Answer contains <image> token, which is not supported yet.")
-            full_tokenized = self.processor(prompt + answer, images=images, add_special_tokens=False)
+            if "add_special_tokens" in inspect.signature(self.processor).parameters:
+                processor_kwargs = {"add_special_tokens": False}
+            else:
+                processor_kwargs = {}
+            full_tokenized = self.processor(prompt + answer, images=images, **processor_kwargs)
             full_tokenized = {k: v[0] for k, v in full_tokenized.items()}  # Unbatch, not done when using idefics
-            prompt_input_ids = self.processor(prompt, images=images, add_special_tokens=False)["input_ids"][0]
+            if not isinstance(full_tokenized["input_ids"], list):  # llava processor returns tensors
+                full_tokenized["input_ids"] = full_tokenized["input_ids"].tolist()
+                full_tokenized["attention_mask"] = full_tokenized["attention_mask"].tolist()
+            prompt_input_ids = self.processor(prompt, images=images, **processor_kwargs)["input_ids"][0]
+            if not isinstance(prompt_input_ids, list):  # llava processor returns tensors
+                prompt_input_ids = prompt_input_ids.tolist()
         else:
             full_tokenized = self.tokenizer(prompt + answer, add_special_tokens=False)
             prompt_input_ids = self.tokenizer(prompt, add_special_tokens=False)["input_ids"]
@@ -762,22 +771,18 @@
         answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:]
         answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:]
 
+        return_dict = dict(
+            prompt_input_ids=prompt_input_ids,
+            prompt_attention_mask=prompt_attention_mask,
+            input_ids=answer_input_ids,
+            attention_mask=answer_attention_mask,
+        )
         if "pixel_values" in full_tokenized:
-            return dict(
-                prompt_input_ids=prompt_input_ids,
-                prompt_attention_mask=prompt_attention_mask,
-                prompt_pixel_values=full_tokenized["pixel_values"],
-                prompt_pixel_attention_mask=full_tokenized["pixel_attention_mask"],
-                input_ids=answer_input_ids,
-                attention_mask=answer_attention_mask,
-            )
-        else:
-            return dict(
-                prompt_input_ids=prompt_input_ids,
-                prompt_attention_mask=prompt_attention_mask,
-                input_ids=answer_input_ids,
-                attention_mask=answer_attention_mask,
-            )
+            return_dict["prompt_pixel_values"] = full_tokenized["pixel_values"]
+        if "pixel_attention_mask" in full_tokenized:
+            return_dict["prompt_pixel_attention_mask"] = full_tokenized["pixel_attention_mask"]
+
+        return return_dict
 
     def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> Dict:
         """Tokenize a single row from a DPO specific dataset.
@@ -805,8 +810,15 @@
         if not isinstance(prompt, str):
             raise ValueError(f"prompt should be an str but got {type(prompt)}")
         if self.is_vision_model:
-            prompt_tokens = self.processor(prompt, images=images, add_special_tokens=False)
+            if "add_special_tokens" in inspect.signature(self.processor).parameters:
+                processor_kwargs = {"add_special_tokens": False}
+            else:
+                processor_kwargs = {}
+            prompt_tokens = self.processor(prompt, images=images, **processor_kwargs)
             prompt_tokens = {k: v[0] for k, v in prompt_tokens.items()}  # Unbatch, not done when using idefics
+            if not isinstance(prompt_tokens["input_ids"], list):  # llava processor returns tensors
+                prompt_tokens["input_ids"] = prompt_tokens["input_ids"].tolist()
+                prompt_tokens["attention_mask"] = prompt_tokens["attention_mask"].tolist()
         else:
             prompt_tokens = self.tokenizer(prompt, add_special_tokens=False)
 
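The hunks above lean twice on the same feature-detection pattern: `add_special_tokens` is only forwarded when the processor's `__call__` actually declares it (the idefics2 processor accepts it; the llava processor does not), and tensor outputs are normalized to plain lists. A minimal self-contained sketch of that pattern, using stand-in callables rather than real Hugging Face processors:

```python
# Pass a kwarg only if the callable declares it, then normalize the output,
# mirroring build_tokenized_answer/tokenize_row above. Both "processors" here
# are illustrative stand-ins, not real Hugging Face classes.
import inspect

import torch


def idefics_like(text, images=None, add_special_tokens=True):
    return {"input_ids": [[101, 7592]], "attention_mask": [[1, 1]]}  # lists


def llava_like(text, images=None):  # no add_special_tokens parameter
    return {"input_ids": torch.tensor([[101, 7592]]), "attention_mask": torch.tensor([[1, 1]])}


for processor in (idefics_like, llava_like):
    if "add_special_tokens" in inspect.signature(processor).parameters:
        processor_kwargs = {"add_special_tokens": False}
    else:
        processor_kwargs = {}
    prompt_input_ids = processor("USER: hi", **processor_kwargs)["input_ids"][0]  # unbatch
    if not isinstance(prompt_input_ids, list):  # llava-style processors return tensors
        prompt_input_ids = prompt_input_ids.tolist()
    print(processor.__name__, processor_kwargs, prompt_input_ids)
```

On a real processor instance, `inspect.signature` resolves to its `__call__`, so the same check applies unchanged. The remaining hunks below continue in `concatenated_inputs` and `concatenated_forward`.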
@@ -1037,10 +1049,13 @@ def concatenated_inputs(
             )
 
         if is_vision_model:
-            concatenated_batch["pixel_values"] = batch["prompt_pixel_values"].repeat(2, 1, 1, 1, 1).to(device=device)
-            concatenated_batch["pixel_attention_mask"] = (
-                batch["prompt_pixel_attention_mask"].repeat(2, 1, 1, 1).to(device=device)
+            concatenated_batch["pixel_values"] = torch.cat(
+                [batch["prompt_pixel_values"], batch["prompt_pixel_values"]], dim=0
             )
+            if "prompt_pixel_attention_mask" in batch:
+                concatenated_batch["pixel_attention_mask"] = torch.cat(
+                    [batch["prompt_pixel_attention_mask"], batch["prompt_pixel_attention_mask"]], dim=0
+                )
         return concatenated_batch
 
     def dpo_loss(
@@ -1262,7 +1277,8 @@ def concatenated_forward(
         if self.is_vision_model:
             model_kwargs["pixel_values"] = concatenated_batch["pixel_values"]
-            model_kwargs["pixel_attention_mask"] = concatenated_batch["pixel_attention_mask"]
+            if "pixel_attention_mask" in concatenated_batch:
+                model_kwargs["pixel_attention_mask"] = concatenated_batch["pixel_attention_mask"]
 
         if self.aux_loss_enabled:
             model_kwargs["output_router_logits"] = True
 
@@ -1275,6 +1291,11 @@
         )
         all_logits = outputs.logits
 
+        if all_logits.shape[:2] != concatenated_batch["concatenated_labels"].shape[:2]:
+            # for llava, the model returns logits for the entire sequence, including the image tokens (placed before the text tokens)
+            seq_len = concatenated_batch["concatenated_labels"].shape[1]
+            all_logits = all_logits[:, -seq_len:]
+
         all_logps, size_completion = self.get_batch_logps(
             all_logits,
             concatenated_batch["concatenated_labels"],
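The logit slice in the last hunk can be checked in isolation. For llava-style models the `<image>` placeholder is expanded into one position per visual patch (576 for the 336px ViT-L/14 tower used by llava-1.5), placed before the text, so the logits run longer than the labels and only the trailing positions line up. The shapes below are made up for illustration:

```python
# Toy reproduction of the right-aligned logit slice in concatenated_forward.
# 576 illustrative patch positions are prepended to 16 text positions.
import torch

batch_size, vocab_size = 2, 32
num_patch_positions, text_len = 576, 16

all_logits = torch.randn(batch_size, num_patch_positions + text_len, vocab_size)
concatenated_labels = torch.randint(0, vocab_size, (batch_size, text_len))

if all_logits.shape[:2] != concatenated_labels.shape[:2]:
    seq_len = concatenated_labels.shape[1]
    all_logits = all_logits[:, -seq_len:]  # keep only the positions aligned with text labels

assert all_logits.shape[:2] == concatenated_labels.shape[:2]
print(tuple(all_logits.shape))  # (2, 16, 32)
```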