Skip to content

Commit

Permalink
Update cif_loss.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Mddct authored Mar 11, 2023
1 parent 55c9323 commit df0e2ae
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions py/cif_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def __init__(self,
spike_threshold: float = 0.0,
blank_id: int = 0) -> None:
super().__init__()
self.spike_threshold = math.log(spike_threshold)
self.spike_threshold = math.log(1-spike_threshold)
self.blank = blank_id

def forward(self, alpha: torch.Tensor, ctc_log_probs: torch.Tensor,
Expand All @@ -19,7 +19,7 @@ def forward(self, alpha: torch.Tensor, ctc_log_probs: torch.Tensor,
text_mask = make_non_pad_mask(text_length)
batch_size = alpha.size(0)
ctc_blank_probs = ctc_log_probs[:, :, self.blank]
triggerd = (math.log(1) - ctc_blank_probs) > self.spike_threshold
triggerd = ctc_blank_probs < self.spike_threshold
spikes = triggerd * mask
begin = torch.ones(batch_size,
1,
Expand Down

0 comments on commit df0e2ae

Please sign in to comment.