Skip to content

Commit

Permalink
update trainer._save
Browse files Browse the repository at this point in the history
  • Loading branch information
ArvinZhuang authored May 22, 2024
1 parent 53944b0 commit 005bbed
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/tevatron/retriever/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None):
if not isinstance(self.model, supported_classes):
raise ValueError(f"Unsupported model class {self.model}")
else:
prefix = 'encoder.'
assert all(k.startswith(prefix) for k in state_dict.keys()), list(state_dict.keys())
state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}
self.model.encoder.save_pretrained(
output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
)
Expand Down

0 comments on commit 005bbed

Please sign in to comment.