Skip to content

Commit

Permalink
fix: forward method
Browse files Browse the repository at this point in the history
  • Loading branch information
sigridjineth committed Aug 24, 2024
1 parent ca4b04b commit 27e90d8
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions src/tevatron/reranker/modeling.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from dataclasses import dataclass
from typing import Dict, Optional
from typing import Dict, Optional, Union

import torch
from torch import nn, Tensor
Expand Down Expand Up @@ -43,8 +43,16 @@ def __init__(self, hf_model: PreTrainedModel, train_batch_size: int = None):
param.data = nn.Linear(self.hf_model.config.hidden_size, 1).weight.data
logger.warning('{} data: {}'.format(name, param.data.cpu().numpy()))

def forward(self, pair: Dict[str, Tensor] = None):
ranker_logits = self.hf_model(**pair, return_dict=True).logits
def forward(self, input_ids: Tensor = None, attention_mask: Tensor = None, token_type_ids: Tensor = None, **kwargs):
model_inputs = {
'input_ids': input_ids,
'attention_mask': attention_mask,
}
if token_type_ids is not None:
model_inputs['token_type_ids'] = token_type_ids

ranker_logits = self.hf_model(**model_inputs, return_dict=True).logits

if self.train_batch_size:
grouped_logits = ranker_logits.view(self.train_batch_size, -1)
loss = self.cross_entropy(grouped_logits, self.target_label)
Expand All @@ -60,7 +68,6 @@ def forward(self, pair: Dict[str, Tensor] = None):

def gradient_checkpointing_enable(self, **kwargs):
    """Report that gradient checkpointing is not enabled.

    Always returns ``False`` regardless of *kwargs*. The delegating call
    to the wrapped HF model is intentionally left disabled here:
    ``self.hf_model.base_model.model.gradient_checkpointing_enable(**kwargs)``.
    """
    return False

@classmethod
def build(
Expand Down

0 comments on commit 27e90d8

Please sign in to comment.