Skip to content

Commit

Permalink
fix saved adapter.safetensors
Browse files Browse the repository at this point in the history
This `_save` follows the original transformers `Trainer` implementation.
The previous implementation causes `RuntimeError: Error(s) in loading state_dict for PeftModel` when loading the saved safetensors.
This update fixes that.
  • Loading branch information
ArvinZhuang authored May 16, 2024
1 parent 668ad0e commit 2143a85
Showing 1 changed file with 20 additions and 21 deletions.
41 changes: 20 additions & 21 deletions src/tevatron/retriever/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@

import torch

from transformers.trainer import Trainer
from transformers.trainer import Trainer, TRAINING_ARGS_NAME
import torch.distributed as dist
from transformers.deepspeed import is_deepspeed_zero3_enabled
from peft import get_peft_model_state_dict, PeftModel
from modeling import EncoderModel


import logging
Expand All @@ -20,26 +19,26 @@ def __init__(self, *args, **kwargs):
self._dist_loss_scale_factor = dist.get_world_size() if self.is_ddp else 1

def _save(self, output_dir: Optional[str] = None, state_dict=None):
# If we are executing this function, we are the process zero, so we don't check for that.
output_dir = output_dir if output_dir is not None else self.args.output_dir
os.makedirs(output_dir, exist_ok=True)
logger.info("Saving model checkpoint to %s", output_dir)
self.model.save(output_dir)

if is_deepspeed_zero3_enabled():
if state_dict is None:
state_dict = self.model.state_dict()
prefix = 'encoder.'
assert all(k.startswith(prefix) for k in state_dict.keys()), list(state_dict.keys())
state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}
if isinstance(self.model.encoder, PeftModel):
lora_state_dict = get_peft_model_state_dict(self.model.encoder, state_dict)
if self.args.process_index <= 0:
torch.save(lora_state_dict, os.path.join(output_dir, "adapter_model.bin"))
print(f"Save adapter model at {output_dir}")
else:
if self.args.process_index <= 0:
torch.save(state_dict, os.path.join(output_dir, "pytorch_model.bin"))
print(f"Save model at {output_dir}")
logger.info(f"Saving model checkpoint to {output_dir}")

supported_classes = (EncoderModel,)
# Save a trained model and configuration using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
if not isinstance(self.model, supported_classes):
raise ValueError(f"Unsupported model class {self.model}")
else:
self.model.encoder.save_pretrained(
output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
)

if self.tokenizer is not None:
self.tokenizer.save_pretrained(output_dir)

# Good practice: save your training arguments together with the trained model
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

def compute_loss(self, model, inputs):
query, passage = inputs
Expand Down

0 comments on commit 2143a85

Please sign in to comment.