Skip to content

Commit

Permalink
model.decodde work && recognize.py work
Browse files Browse the repository at this point in the history
  • Loading branch information
Mddct committed Oct 23, 2023
1 parent 0a98847 commit 0658057
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
2 changes: 1 addition & 1 deletion wenet/paraformer/ali_paraformer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ def decode(self, methods: List[str], speech: torch.Tensor,
results = []
out, out_lens = self.forward(speech, speech_lens)
for (i, value) in enumerate(out.argmax(-1).numpy()):
results.append(DecodeResult(value.numpy()[:out_lens[i]]))
results.append(DecodeResult(value[:out_lens[i]]))

results_dict['paraformer_greedy_search'] = results
return results_dict
9 changes: 6 additions & 3 deletions wenet/paraformer/ali_paraformer/test_infer_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,12 @@ def main():
feats = feats.unsqueeze(0)
feats_lens = torch.tensor([feats.size(1)], dtype=torch.int64)

out, token_nums = model(feats, feats_lens)
print("".join([char_dict[id] for id in out.argmax(-1)[0].numpy()]))
print(token_nums)
decode_results = model.decode(['paraformer_greedy_search'], feats,
feats_lens)
print("".join([
char_dict[id]
for id in decode_results['paraformer_greedy_search'][0].tokens
]))

if args.output_file:
script_model = torch.jit.script(model)
Expand Down

0 comments on commit 0658057

Please sign in to comment.