[cli/paraformer] ali-paraformer inference #2067
Merged
Commits (21, all by Mddct):
000c7af  [cli/paraformer] ali-paraformer load and infer work
6967ef5  fix lint
c8cccdc  export jit and load work
852165c  reuse init_model.py
e8fa013  mv the intermediate files to the assets directory
0a98847  merge main
e987b00  model.decodde work && recognize.py work
3d25e2e  rm positionwise_feed_forward.py/lfr.py
aedd43b  Merge branch 'main' into Mddct-cli-paraformer
d264297  refactor search
3f45af7  merge main
30b5677  merge main
20146f4  cli work
e3ec8e7  fix lint
b1b44df  fix att mask && batch infer
cd9c659  search confidence works
0b0eea7  merge main
c11aefe  merge main
daff617  fix linux dtype
9e810d1  fix label type
b765c12  revert init_model.py and add init_model in export_jit
File 1 — new Python file, +66 lines: the CLI Paraformer wrapper (the diff header gives no path; presumably wenet/cli/paraformer.py):
import os

import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi

from wenet.cli.hub import Hub  # used by load_model below; missing from the original diff
from wenet.paraformer.search import paraformer_greedy_search
from wenet.utils.file_utils import read_symbol_table


class Paraformer:

    def __init__(self, model_dir: str) -> None:
        # model_dir must contain the exported TorchScript model and the unit table.
        model_path = os.path.join(model_dir, 'final.zip')
        units_path = os.path.join(model_dir, 'units.txt')
        self.model = torch.jit.load(model_path)
        symbol_table = read_symbol_table(units_path)
        self.char_dict = {v: k for k, v in symbol_table.items()}
        self.eos = 2

    def transcribe(self, audio_file: str, tokens_info: bool = False) -> dict:
        waveform, sample_rate = torchaudio.load(audio_file, normalize=False)
        waveform = waveform.to(torch.float)
        # 80-dim fbank features: 25 ms frames with a 10 ms shift, 16 kHz audio.
        feats = kaldi.fbank(waveform,
                            num_mel_bins=80,
                            frame_length=25,
                            frame_shift=10,
                            energy_floor=0.0,
                            sample_frequency=16000)
        feats = feats.unsqueeze(0)
        feats_lens = torch.tensor([feats.size(1)], dtype=torch.int64)

        decoder_out, token_num = self.model.forward_paraformer(
            feats, feats_lens)

        res = paraformer_greedy_search(decoder_out, token_num)[0]

        result = {}
        result['confidence'] = res.confidence
        # TODO(Mddct): deal with '@@' and 'eos'
        result['rec'] = "".join([self.char_dict[x] for x in res.tokens])

        if tokens_info:
            tokens_info = []
            for i, x in enumerate(res.tokens):
                tokens_info.append({
                    'token': self.char_dict[x],
                    # TODO(Mddct): support times
                    # 'start': 0,
                    # 'end': 0,
                    'confidence': res.tokens_confidence[i]
                })
            result['tokens'] = tokens_info

        return result

    def align(self, audio_file: str, label: str) -> dict:
        raise NotImplementedError


def load_model(language: str = None, model_dir: str = None) -> Paraformer:
    if model_dir is None:
        model_dir = Hub.get_model_by_lang(language)
    return Paraformer(model_dir)
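
A minimal usage sketch of the wrapper above (paths are placeholders; assumes a local model directory that already contains the exported final.zip and units.txt, so the Hub download is skipped):

model = load_model(model_dir='/path/to/paraformer_model')
result = model.transcribe('zh_demo.wav', tokens_info=True)
print(result['rec'], result['confidence'])
for tok in result.get('tokens', []):
    print(tok['token'], tok['confidence'])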
File 2 — new YAML config, +49 lines: the Paraformer model configuration:
# network architecture
# encoder related
encoder: SanmEncoder
encoder_conf:
    output_size: 512    # dimension of attention
    attention_heads: 4
    linear_units: 2048  # the number of units of position-wise feed forward
    num_blocks: 50      # the number of encoder blocks
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    attention_dropout_rate: 0.1
    input_layer: 'conv2d' # encoder input type, you can choose conv2d, conv2d6 and conv2d8
    normalize_before: true
    kernel_size: 11
    sanm_shfit: 0  # note: the 'shfit' spelling matches the parameter name in the code

input_dim: 80
output_dim: 8404
paraformer: true
is_json_cmvn: true

# decoder related
decoder: SanmDecoder
decoder_conf:
    attention_heads: 4
    linear_units: 2048
    num_blocks: 16
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    self_attention_dropout_rate: 0.1
    src_attention_dropout_rate: 0.1
    att_layer_num: 16
    kernel_size: 11
    sanm_shfit: 0

lfr_conf:
    lfr_m: 7
    lfr_n: 6

cif_predictor_conf:
    idim: 512
    threshold: 1.0
    l_order: 1
    r_order: 1
    tail_threshold: 0.45
    cnn_groups: 1
    residual: false

model_conf:
    ctc_weight: 0.0
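
The lfr_conf block enables low frame rate (LFR) input: each output frame stacks lfr_m consecutive fbank frames and the window hops lfr_n frames, so the 80-dim features above become 7 * 80 = 560-dim at roughly one frame per 60 ms. A minimal sketch of this stacking, following the usual FunASR recipe (the PR's actual implementation may differ in padding details):

import torch

def apply_lfr(feats: torch.Tensor, lfr_m: int = 7, lfr_n: int = 6) -> torch.Tensor:
    # feats: (T, D) fbank features, e.g. D = 80 (input_dim above).
    T, D = feats.shape
    T_out = (T + lfr_n - 1) // lfr_n  # ceil(T / lfr_n) output frames
    # Replicate the first frame so early frames see a centered context.
    left_pad = (lfr_m - 1) // 2
    feats = torch.cat([feats[:1].expand(left_pad, D), feats], dim=0)
    out = []
    for i in range(T_out):
        start = i * lfr_n
        window = feats[start:start + lfr_m]
        if window.size(0) < lfr_m:
            # Pad the tail by repeating the last available frame.
            pad = window[-1:].expand(lfr_m - window.size(0), D)
            window = torch.cat([window, pad], dim=0)
        out.append(window.reshape(-1))
    return torch.stack(out)  # (T_out, lfr_m * D), e.g. (T_out, 560)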
Review comment:
The default window in the FunASR frontend is 'hamming' (see the FunASR frontend source), while the default window in kaldi.fbank is 'povey' (per the torchaudio documentation). This window difference may cause a slight mismatch, as noted around line 44 of the referenced frontend file.

Reply: pr welcome
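
For reference, torchaudio's kaldi.fbank does expose a window_type argument, so matching FunASR's hamming window would be a one-line change to the front end in File 1. A hedged sketch (the PR itself keeps the povey default; 'test.wav' is a placeholder):

import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi

waveform, _ = torchaudio.load('test.wav', normalize=False)
waveform = waveform.to(torch.float)
# Same front end as transcribe() above, but with FunASR's hamming window
# instead of torchaudio's default window_type='povey'.
feats = kaldi.fbank(waveform,
                    num_mel_bins=80,
                    frame_length=25,
                    frame_shift=10,
                    energy_floor=0.0,
                    sample_frequency=16000,
                    window_type='hamming')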