Skip to content

Commit

Permalink
add ptr loss for cif streaming
Browse files Browse the repository at this point in the history
  • Loading branch information
Mddct authored May 31, 2023
1 parent 80fe3a7 commit 7cf9825
Showing 1 changed file with 18 additions and 0 deletions.
18 changes: 18 additions & 0 deletions py/cif_loss.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,23 @@
import torch

class PTRLoss(torch.nn.Module):
"""PTR for cif
"""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.eps = 1e-5

def forward(self, alpha: torch.Tensor):
""" alpha: [B,T], padding value is zero
"""
prev_frames = alpha[:, :-1]
post_frames = alpha[:, 1:]

ptr = post_frames * torch.log(post_frames /
(prev_frames + self.eps)) # [B,T-1]
ptr = ptr.sum(-1)
return ptr.sum() / ptr.size(0)

class CtcBoundaryLossV3(torch.nn.Module):
""" https://arxiv.org/pdf/2104.04702.pdf
"""
Expand Down

0 comments on commit 7cf9825

Please sign in to comment.