Skip to content

Commit

Permalink
Update xbert.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Junnan Li authored Mar 1, 2022
1 parent f224b67 commit b726f8e
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions models/xbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -1297,7 +1297,7 @@ def forward(
lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1)

if soft_labels is not None:
- loss_distill = -torch.sum(F.log_softmax(shifted_prediction_scores, dim=1)*soft_labels,dim=-1)
+ loss_distill = -torch.sum(F.log_softmax(shifted_prediction_scores, dim=-1)*soft_labels,dim=-1)
loss_distill = (loss_distill * (labels!=-100)).sum(1)
lm_loss = (1-alpha)*lm_loss + alpha*loss_distill

Expand Down Expand Up @@ -1426,7 +1426,7 @@ def forward(
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

if soft_labels is not None:
- loss_distill = -torch.sum(F.log_softmax(prediction_scores, dim=1)*soft_labels,dim=-1)
+ loss_distill = -torch.sum(F.log_softmax(prediction_scores, dim=-1)*soft_labels,dim=-1)
loss_distill = loss_distill[labels!=-100].mean()
masked_lm_loss = (1-alpha)*masked_lm_loss + alpha*loss_distill

Expand Down

0 comments on commit b726f8e

Please sign in to comment.