Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[cli/paraformer] ali-paraformer inference #2067

Merged
merged 21 commits into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
export jit and load work
  • Loading branch information
Mddct committed Oct 23, 2023
commit c8cccdc8fa1fe458fa98b9d924864c8506814354
86 changes: 55 additions & 31 deletions wenet/cif/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,38 @@

import torch
from torch import nn
from torchaudio.compliance.kaldi import Tuple
from wenet.utils.mask import make_pad_mask


class Predictor(nn.Module):
def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1,
smooth_factor=1.0, noise_threshold=0, tail_threshold=0.45):

def __init__(self,
idim,
l_order,
r_order,
threshold=1.0,
dropout=0.1,
smooth_factor=1.0,
noise_threshold=0,
tail_threshold=0.45,
residual=True,
cnn_groups=0):
super().__init__()

self.pad = nn.ConstantPad1d((l_order, r_order), 0.0)
self.cif_conv1d = nn.Conv1d(idim, idim, l_order + r_order + 1,
groups=idim)
self.cif_conv1d = nn.Conv1d(
idim,
idim,
l_order + r_order + 1,
groups=idim if cnn_groups == 0 else cnn_groups)
self.cif_output = nn.Linear(idim, 1)
self.dropout = torch.nn.Dropout(p=dropout)
self.threshold = threshold
self.smooth_factor = smooth_factor
self.noise_threshold = noise_threshold
self.tail_threshold = tail_threshold
self.residual = residual

def forward(self,
hidden,
Expand All @@ -46,7 +61,10 @@ def forward(self,
context = h.transpose(1, 2)
queries = self.pad(context)
memory = self.cif_conv1d(queries)
output = memory + context
if self.residual:
output = memory + context
else:
output = memory
output = self.dropout(output)
output = output.transpose(1, 2)
output = torch.relu(output)
Expand All @@ -55,7 +73,7 @@ def forward(self,
alphas = torch.nn.functional.relu(alphas * self.smooth_factor -
self.noise_threshold)
if mask is not None:
mask = mask.transpose(-1, -2).float()
mask = mask.transpose(-1, -2)
alphas = alphas * mask
if mask_chunk_predictor is not None:
alphas = alphas * mask_chunk_predictor
Expand All @@ -72,10 +90,10 @@ def forward(self,
alphas *= (target_length / token_num)[:, None] \
.repeat(1, alphas.size(1))
elif self.tail_threshold > 0.0:
hidden, alphas, token_num = self.tail_process_fn(hidden, alphas,
hidden, alphas, token_num = self.tail_process_fn(hidden,
alphas,
token_num,
mask=mask)

acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)

if target_length is None and self.tail_threshold > 0.0:
Expand All @@ -84,26 +102,32 @@ def forward(self,

return acoustic_embeds, token_num, alphas, cif_peak

def tail_process_fn(self, hidden, alphas,
token_num: Optional[torch.Tensor] = None,
mask: Optional[torch.Tensor] = None):
b, t, d = hidden.size()
tail_threshold = self.tail_threshold
def tail_process_fn(
self,
hidden: torch.Tensor,
alphas: torch.Tensor,
token_num: Optional[torch.Tensor] = None,
mask: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
b, _, d = hidden.size()
if mask is not None:
zeros_t = torch.zeros((b, 1), dtype=torch.float32,
zeros_t = torch.zeros((b, 1),
dtype=torch.float32,
device=alphas.device)
mask = mask.to(zeros_t.dtype)
ones_t = torch.ones_like(zeros_t)
mask_1 = torch.cat([mask, zeros_t], dim=1)
mask_2 = torch.cat([ones_t, mask], dim=1)
mask = mask_2 - mask_1
tail_threshold = mask * tail_threshold
tail_threshold = mask * self.tail_threshold
alphas = torch.cat([alphas, zeros_t], dim=1)
alphas = torch.add(alphas, tail_threshold)
else:
tail_threshold_tensor = torch.tensor([tail_threshold],
tail_threshold_tensor = torch.tensor([self.tail_threshold],
dtype=alphas.dtype).to(
alphas.device)
tail_threshold_tensor = torch.reshape(tail_threshold_tensor, (1, 1))
alphas.device)
tail_threshold_tensor = torch.reshape(tail_threshold_tensor,
(1, 1))
alphas = torch.cat([alphas, tail_threshold_tensor], dim=1)
zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
hidden = torch.cat([hidden, zeros], dim=1)
Expand Down Expand Up @@ -132,13 +156,15 @@ def gen_frame_alignments(self,

index = torch.ones([batch_size, max_token_num], dtype=int_type)
index = torch.cumsum(index, dim=1)
index = index[:, :, None].repeat(1, 1, maximum_length).to(
alphas_cumsum.device)
index = index[:, :,
None].repeat(1, 1,
maximum_length).to(alphas_cumsum.device)

index_div = torch.floor(torch.true_divide(alphas_cumsum, index)).type(
int_type)
index_div = torch.floor(torch.true_divide(alphas_cumsum,
index)).type(int_type)
index_div_bool_zeros = index_div.eq(0)
index_div_bool_zeros_count = torch.sum(index_div_bool_zeros, dim=-1) + 1
index_div_bool_zeros_count = torch.sum(index_div_bool_zeros,
dim=-1) + 1
index_div_bool_zeros_count = torch.clamp(index_div_bool_zeros_count, 0,
encoder_sequence_length.max())
token_num_mask = (~make_pad_mask(token_num, max_len=max_token_num)).to(
Expand Down Expand Up @@ -210,19 +236,17 @@ def cif(hidden: torch.Tensor, alphas: torch.Tensor, threshold: float):
list_fires.append(integrate)

fire_place = integrate >= threshold
integrate = torch.where(fire_place, integrate -
torch.ones([batch_size], device=hidden.device),
integrate)
cur = torch.where(fire_place,
distribution_completion,
alpha)
integrate = torch.where(
fire_place,
integrate - torch.ones([batch_size], device=hidden.device),
integrate)
cur = torch.where(fire_place, distribution_completion, alpha)
remainds = alpha - cur

frame += cur[:, None] * hidden[:, t, :]
list_frames.append(frame)
frame = torch.where(fire_place[:, None].repeat(1, hidden_size),
remainds[:, None] * hidden[:, t, :],
frame)
remainds[:, None] * hidden[:, t, :], frame)

fires = torch.stack(list_fires, 1)
frames = torch.stack(list_frames, 1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def forward_qkv(
def forward_fsmn(self,
inputs: torch.Tensor,
mask: torch.Tensor,
mask_shfit_chunk=None):
mask_shfit_chunk: Optional[torch.Tensor] = None):
b, _, t, _ = inputs.size()
inputs = inputs.transpose(1, 2).view(b, t, -1)
if mask.size(2) > 0: # time2 > 0
Expand Down Expand Up @@ -136,14 +136,14 @@ def forward(
pos_emb: torch.Tensor = torch.empty(0),
cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
mask_shfit_chunk: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:

inputs = query

x = inputs.transpose(1, 2)
x = self.pad_fn(x)
# TODO(Mddct): cache here for future streaming
cache: torch.Tensor
cache: Optional[torch.Tensor] = None
x = self.fsmn_block(x)
x = x.transpose(1, 2)
if x.size(1) != inputs.size(1):
Expand Down Expand Up @@ -204,11 +204,11 @@ def forward(
pos_emb: torch.Tensor = torch.empty(0),
cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
mask_shfit_chunk: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
q, k, v = self.forward_qkv(query, key, key)
q = q * self.d_k**(-0.5)
scores = torch.matmul(q, k.transpose(-2, -1))

# TODO(Mddct): support future streaming paraformer
cache: torch.Tensor
cache: Optional[torch.Tensor] = None
return self.forward_attention(v, scores, mask), cache
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ predictor_conf:
r_order: 1
tail_threshold: 0.45
cnn_groups: 1
residual: false

# smooth_factor2: 0.25
# noise_threshold2: 0.01
# upsample_times: 3
Expand Down
91 changes: 91 additions & 0 deletions wenet/paraformer/ali_paraformer/export_jit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
""" NOTE(Mddct): This file is experimental and is used to export paraformer
"""

import argparse
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
import yaml
from wenet.cif.predictor import Predictor
from wenet.paraformer.ali_paraformer.model import (
AliParaformer,
SanmDecoer,
SanmEncoder,
)
from wenet.transformer.cmvn import GlobalCMVN
from wenet.utils.checkpoint import load_checkpoint
from wenet.utils.cmvn import load_cmvn
from wenet.utils.file_utils import read_symbol_table


def get_args():
parser = argparse.ArgumentParser(description='load ali-paraformer')
parser.add_argument('--ali_paraformer',
required=True,
help='ali released Paraformer model path')
parser.add_argument('--config', required=True, help='config of paraformer')
parser.add_argument('--cmvn',
required=True,
help='cmvn file of paraformer in wenet style')
parser.add_argument('--dict', required=True, help='dict file')
parser.add_argument('--wav', required=True, help='wav file')
parser.add_argument('--output_file', default=None, help='output file')
args = parser.parse_args()
return args


def main():

args = get_args()

symbol_table = read_symbol_table(args.dict)
char_dict = {v: k for k, v in symbol_table.items()}
with open(args.config, 'r') as fin:
configs = yaml.load(fin, Loader=yaml.FullLoader)

mean, istd = load_cmvn(args.cmvn, is_json=True)
global_cmvn = GlobalCMVN(
torch.from_numpy(mean).float(),
torch.from_numpy(istd).float())
configs['encoder_conf']['input_size'] = 80 * 7
encoder = SanmEncoder(global_cmvn=global_cmvn, **configs['encoder_conf'])
configs['decoder_conf']['vocab_size'] = len(char_dict)
configs['decoder_conf']['encoder_output_size'] = encoder.output_size()
decoder = SanmDecoer(**configs['decoder_conf'])

# predictor = PredictorV3(**configs['predictor_conf'])
predictor = Predictor(**configs['predictor_conf'])
model = AliParaformer(encoder, decoder, predictor)
load_checkpoint(model, args.ali_paraformer)
model.eval()

waveform, sample_rate = torchaudio.load(args.wav)
assert sample_rate == 16000
waveform = waveform * (1 << 15)
waveform = waveform.to(torch.float)
feats = kaldi.fbank(waveform,
num_mel_bins=80,
frame_length=25,
frame_shift=10,
energy_floor=0.0,
sample_frequency=sample_rate)
feats = feats.unsqueeze(0)
feats_lens = torch.tensor([feats.size(1)], dtype=torch.int64)

out, token_nums = model(feats, feats_lens)
print("".join([char_dict[id] for id in out.argmax(-1)[0].numpy()]))
print(token_nums)

if args.output_file:
script_model = torch.jit.script(model)
script_model.save(args.output_file)

model = torch.jit.load(args.output_file)
out, token_nums = model.forward(feats, feats_lens)
print("".join([char_dict[id] for id in out.argmax(-1)[0].numpy()]))
print(token_nums)


if __name__ == "__main__":

main()
Loading
Loading