Skip to content

Commit

Permalink
export jit and load work
Browse files: browse the repository at this point in the history
  • Loading branch information
Mddct committed Oct 23, 2023
1 parent 6967ef5 commit c3ca2e8
Show file tree
Hide file tree
Showing 8 changed files with 84 additions and 194 deletions.
88 changes: 56 additions & 32 deletions wenet/cif/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,38 @@

import torch
from torch import nn
from torchaudio.compliance.kaldi import Tuple
from wenet.utils.mask import make_pad_mask


class Predictor(nn.Module):
def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1,
smooth_factor=1.0, noise_threshold=0, tail_threshold=0.45):

def __init__(self,
idim,
l_order,
r_order,
threshold=1.0,
dropout=0.1,
smooth_factor=1.0,
noise_threshold=0,
tail_threshold=0.45,
residual=True,
cnn_groups=0):
super().__init__()

self.pad = nn.ConstantPad1d((l_order, r_order), 0.0)
self.cif_conv1d = nn.Conv1d(idim, idim, l_order + r_order + 1,
groups=idim)
self.cif_conv1d = nn.Conv1d(
idim,
idim,
l_order + r_order + 1,
groups=idim if cnn_groups == 0 else cnn_groups)
self.cif_output = nn.Linear(idim, 1)
self.dropout = torch.nn.Dropout(p=dropout)
self.threshold = threshold
self.smooth_factor = smooth_factor
self.noise_threshold = noise_threshold
self.tail_threshold = tail_threshold
self.residual = residual

def forward(self,
hidden,
Expand All @@ -46,7 +61,10 @@ def forward(self,
context = h.transpose(1, 2)
queries = self.pad(context)
memory = self.cif_conv1d(queries)
output = memory + context
if self.residual:
output = memory + context
else:
output = memory
output = self.dropout(output)
output = output.transpose(1, 2)
output = torch.relu(output)
Expand All @@ -55,11 +73,11 @@ def forward(self,
alphas = torch.nn.functional.relu(alphas * self.smooth_factor -
self.noise_threshold)
if mask is not None:
mask = mask.transpose(-1, -2).float()
mask = mask.transpose(-1, -2)
alphas = alphas * mask
if mask_chunk_predictor is not None:
alphas = alphas * mask_chunk_predictor
alphas = alphas.squeeze(-1)
alphas = alphas.squeeze(2)
mask = mask.squeeze(-1)
if target_label_length is not None:
target_length = target_label_length
Expand All @@ -72,10 +90,10 @@ def forward(self,
alphas *= (target_length / token_num)[:, None] \
.repeat(1, alphas.size(1))
elif self.tail_threshold > 0.0:
hidden, alphas, token_num = self.tail_process_fn(hidden, alphas,
hidden, alphas, token_num = self.tail_process_fn(hidden,
alphas,
token_num,
mask=mask)

acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)

if target_length is None and self.tail_threshold > 0.0:
Expand All @@ -84,26 +102,32 @@ def forward(self,

return acoustic_embeds, token_num, alphas, cif_peak

def tail_process_fn(self, hidden, alphas,
token_num: Optional[torch.Tensor] = None,
mask: Optional[torch.Tensor] = None):
b, t, d = hidden.size()
tail_threshold = self.tail_threshold
def tail_process_fn(
self,
hidden: torch.Tensor,
alphas: torch.Tensor,
token_num: Optional[torch.Tensor] = None,
mask: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
b, _, d = hidden.size()
if mask is not None:
zeros_t = torch.zeros((b, 1), dtype=torch.float32,
zeros_t = torch.zeros((b, 1),
dtype=torch.float32,
device=alphas.device)
mask = mask.to(zeros_t.dtype)
ones_t = torch.ones_like(zeros_t)
mask_1 = torch.cat([mask, zeros_t], dim=1)
mask_2 = torch.cat([ones_t, mask], dim=1)
mask = mask_2 - mask_1
tail_threshold = mask * tail_threshold
tail_threshold = mask * self.tail_threshold
alphas = torch.cat([alphas, zeros_t], dim=1)
alphas = torch.add(alphas, tail_threshold)
else:
tail_threshold_tensor = torch.tensor([tail_threshold],
tail_threshold_tensor = torch.tensor([self.tail_threshold],
dtype=alphas.dtype).to(
alphas.device)
tail_threshold_tensor = torch.reshape(tail_threshold_tensor, (1, 1))
alphas.device)
tail_threshold_tensor = torch.reshape(tail_threshold_tensor,
(1, 1))
alphas = torch.cat([alphas, tail_threshold_tensor], dim=1)
zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
hidden = torch.cat([hidden, zeros], dim=1)
Expand Down Expand Up @@ -132,13 +156,15 @@ def gen_frame_alignments(self,

index = torch.ones([batch_size, max_token_num], dtype=int_type)
index = torch.cumsum(index, dim=1)
index = index[:, :, None].repeat(1, 1, maximum_length).to(
alphas_cumsum.device)
index = index[:, :,
None].repeat(1, 1,
maximum_length).to(alphas_cumsum.device)

index_div = torch.floor(torch.true_divide(alphas_cumsum, index)).type(
int_type)
index_div = torch.floor(torch.true_divide(alphas_cumsum,
index)).type(int_type)
index_div_bool_zeros = index_div.eq(0)
index_div_bool_zeros_count = torch.sum(index_div_bool_zeros, dim=-1) + 1
index_div_bool_zeros_count = torch.sum(index_div_bool_zeros,
dim=-1) + 1
index_div_bool_zeros_count = torch.clamp(index_div_bool_zeros_count, 0,
encoder_sequence_length.max())
token_num_mask = (~make_pad_mask(token_num, max_len=max_token_num)).to(
Expand Down Expand Up @@ -210,19 +236,17 @@ def cif(hidden: torch.Tensor, alphas: torch.Tensor, threshold: float):
list_fires.append(integrate)

fire_place = integrate >= threshold
integrate = torch.where(fire_place, integrate -
torch.ones([batch_size], device=hidden.device),
integrate)
cur = torch.where(fire_place,
distribution_completion,
alpha)
integrate = torch.where(
fire_place,
integrate - torch.ones([batch_size], device=hidden.device),
integrate)
cur = torch.where(fire_place, distribution_completion, alpha)
remainds = alpha - cur

frame += cur[:, None] * hidden[:, t, :]
list_frames.append(frame)
frame = torch.where(fire_place[:, None].repeat(1, hidden_size),
remainds[:, None] * hidden[:, t, :],
frame)
remainds[:, None] * hidden[:, t, :], frame)

fires = torch.stack(list_fires, 1)
frames = torch.stack(list_frames, 1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def forward_qkv(
def forward_fsmn(self,
inputs: torch.Tensor,
mask: torch.Tensor,
mask_shfit_chunk=None):
mask_shfit_chunk: Optional[torch.Tensor] = None):
b, _, t, _ = inputs.size()
inputs = inputs.transpose(1, 2).view(b, t, -1)
if mask.size(2) > 0: # time2 > 0
Expand Down Expand Up @@ -136,14 +136,14 @@ def forward(
pos_emb: torch.Tensor = torch.empty(0),
cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
mask_shfit_chunk: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:

inputs = query

x = inputs.transpose(1, 2)
x = self.pad_fn(x)
# TODO(Mddct): cache here for future streaming
cache: torch.Tensor
cache: Optional[torch.Tensor] = None
x = self.fsmn_block(x)
x = x.transpose(1, 2)
if x.size(1) != inputs.size(1):
Expand Down Expand Up @@ -204,11 +204,11 @@ def forward(
pos_emb: torch.Tensor = torch.empty(0),
cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
mask_shfit_chunk: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
q, k, v = self.forward_qkv(query, key, key)
q = q * self.d_k**(-0.5)
scores = torch.matmul(q, k.transpose(-2, -1))

# TODO(Mddct): support future streaming paraformer
cache: torch.Tensor
cache: Optional[torch.Tensor] = None
return self.forward_attention(v, scores, mask), cache
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ predictor_conf:
r_order: 1
tail_threshold: 0.45
cnn_groups: 1
residual: false

# smooth_factor2: 0.25
# noise_threshold2: 0.01
# upsample_times: 3
Expand Down
File renamed without changes.
File renamed without changes.
Loading

0 comments on commit c3ca2e8

Please sign in to comment.