Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[cli/paraformer] ali-paraformer inference #2067

Merged
merged 21 commits into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix attention mask && batch inference
  • Loading branch information
Mddct committed Oct 30, 2023
commit b1b44dfc48d55178aeef91ac23673b48dd922a5a
3 changes: 3 additions & 0 deletions wenet/paraformer/ali_paraformer/assets/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,6 @@ cif_predictor_conf:
tail_threshold: 0.45
cnn_groups: 1
residual: false

model_conf:
ctc_weight: 0.0
4 changes: 2 additions & 2 deletions wenet/paraformer/ali_paraformer/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,9 @@ def forward(
cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
mask_shfit_chunk: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:

if mask is not None:
query = query * mask.transpose(1, 2)
inputs = query

x = inputs.transpose(1, 2)
x = self.pad_fn(x)
# TODO(Mddct): cache here for future streaming
Expand Down
6 changes: 1 addition & 5 deletions wenet/paraformer/ali_paraformer/export_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import torch
import yaml
from wenet.utils.checkpoint import load_checkpoint
from wenet.utils.file_utils import read_symbol_table
from wenet.utils.init_model import init_model


Expand All @@ -18,7 +17,7 @@ def get_args():
parser.add_argument('--cmvn',
required=True,
help='cmvn file of paraformer in wenet style')
parser.add_argument('--dict', required=True, help='dict file')
# parser.add_argument('--dict', required=True, help='dict file')
parser.add_argument('--output_file', default=None, help='output file')
args = parser.parse_args()
return args
Expand All @@ -27,9 +26,6 @@ def get_args():
def main():

args = get_args()

symbol_table = read_symbol_table(args.dict)
char_dict = {v: k for k, v in symbol_table.items()}
with open(args.config, 'r') as fin:
configs = yaml.load(fin, Loader=yaml.FullLoader)
configs['cmvn_file'] = args.cmvn
Expand Down
3 changes: 1 addition & 2 deletions wenet/paraformer/ali_paraformer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(self, m: int = 7, n: int = 6) -> None:
def forward(self, input: torch.Tensor,
input_lens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
B, _, D = input.size()
n_lfr = torch.ceil(input_lens / self.n)
n_lfr = torch.ceil(input_lens / self.n).to(input_lens.dtype)
# print(n_lfr)
# right_padding_nums >= 0
prepad_nums = input_lens + self.left_padding_nums
Expand Down Expand Up @@ -289,7 +289,6 @@ def forward(
if self.global_cmvn is not None:
xs = self.global_cmvn(xs)
masks_pad = make_non_pad_mask(xs_lens).unsqueeze(1) # [B,1,T]
# masks = masks_pad * masks_pad.transpose(1, 2) #[B,T,T]
xs = xs * self.output_size()**0.5
xs = self.embed(xs)
for layer in self.encoders0:
Expand Down
5 changes: 4 additions & 1 deletion wenet/paraformer/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ def paraformer_greedy_search(
topk_prob, topk_index = decoder_out.topk(1, dim=2)
topk_index = topk_index.view(batch_size, maxlen) # (B, maxlen)
results = []
topk_index = topk_index.cpu()
decoder_out_lens = decoder_out_lens.cpu()
# TODO(Mddct): scores, times etc
for (i, hyp) in enumerate(topk_index.tolist()):
r = DecodeResult(hyp[:decoder_out_lens.numpy()[i]])
Expand All @@ -31,7 +33,8 @@ def paraformer_beam_search(decoder_out: torch.Tensor,
beam_size=beam_size,
eos=eos)

best_hyps = indices[:, 0, :]
best_hyps = indices[:, 0, :].cpu()
decoder_out_lens = decoder_out_lens.cpu()
results = []
# TODO(Mddct): scores, times etc
for (i, hyp) in enumerate(best_hyps.tolist()):
Expand Down