Skip to content

Commit

Permalink
revert init_model.py and add init_model in export_jit
Browse files Browse the repository at this point in the history
  • Loading branch information
Mddct committed Oct 30, 2023
1 parent 9e810d1 commit b765c12
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 30 deletions.
28 changes: 27 additions & 1 deletion wenet/paraformer/ali_paraformer/export_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,12 @@
import argparse
import torch
import yaml
from wenet.cif.predictor import Predictor
from wenet.paraformer.ali_paraformer.model import (AliParaformer, SanmDecoer,
SanmEncoder)
from wenet.transformer.cmvn import GlobalCMVN
from wenet.utils.checkpoint import load_checkpoint
from wenet.utils.init_model import init_model
from wenet.utils.cmvn import load_cmvn


def get_args():
Expand All @@ -23,6 +27,28 @@ def get_args():
return args


def init_model(configs):
mean, istd = load_cmvn(configs['cmvn_file'], configs['is_json_cmvn'])
global_cmvn = GlobalCMVN(
torch.from_numpy(mean).float(),
torch.from_numpy(istd).float())
input_dim = configs['input_dim']
vocab_size = configs['output_dim']
encoder = SanmEncoder(global_cmvn=global_cmvn,
input_size=configs['lfr_conf']['lfr_m'] * input_dim,
**configs['encoder_conf'])
decoder = decoder = SanmDecoer(vocab_size=vocab_size,
encoder_output_size=encoder.output_size(),
**configs['decoder_conf'])
predictor = Predictor(**configs['cif_predictor_conf'])
model = AliParaformer(
encoder=encoder,
decoder=decoder,
predictor=predictor,
)
return model


def main():

args = get_args()
Expand Down
35 changes: 6 additions & 29 deletions wenet/utils/init_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@

import torch

from wenet.paraformer.ali_paraformer.model import SanmDecoer, SanmEncoder
from wenet.k2.model import K2Model

from wenet.transducer.joint import TransducerJoint
from wenet.transducer.predictor import (ConvPredictor, EmbeddingPredictor,
RNNPredictor)
Expand All @@ -31,7 +29,6 @@
from wenet.squeezeformer.encoder import SqueezeformerEncoder
from wenet.efficient_conformer.encoder import EfficientConformerEncoder
from wenet.paraformer.paraformer import Paraformer
from wenet.paraformer.ali_paraformer.model import AliParaformer
from wenet.cif.predictor import Predictor
from wenet.utils.cmvn import load_cmvn

Expand Down Expand Up @@ -74,24 +71,13 @@ def init_model(configs):
encoder = EBranchformerEncoder(input_dim,
global_cmvn=global_cmvn,
**configs['encoder_conf'])
elif encoder_type == 'SanmEncoder':
assert 'lfr_conf' in configs
encoder = SanmEncoder(global_cmvn=global_cmvn,
input_size=configs['lfr_conf']['lfr_m'] *
input_dim,
**configs['encoder_conf'])
else:
encoder = TransformerEncoder(input_dim,
global_cmvn=global_cmvn,
**configs['encoder_conf'])
if decoder_type == 'transformer':
decoder = TransformerDecoder(vocab_size, encoder.output_size(),
**configs['decoder_conf'])
elif decoder_type == 'SanmDecoder':
assert isinstance(encoder, SanmEncoder)
decoder = SanmDecoer(vocab_size=vocab_size,
encoder_output_size=encoder.output_size(),
**configs['decoder_conf'])
else:
assert 0.0 < configs['model_conf']['reverse_weight'] < 1.0
assert configs['decoder_conf']['r_num_blocks'] > 0
Expand Down Expand Up @@ -131,21 +117,12 @@ def init_model(configs):
**configs['model_conf'])
elif 'paraformer' in configs:
predictor = Predictor(**configs['cif_predictor_conf'])
if isinstance(encoder, SanmEncoder):
assert isinstance(decoder, SanmDecoer)
# NOTE(Mddct): only support inference for now
model = AliParaformer(
encoder=encoder,
decoder=decoder,
predictor=predictor,
)
else:
model = Paraformer(vocab_size=vocab_size,
encoder=encoder,
decoder=decoder,
ctc=ctc,
predictor=predictor,
**configs['model_conf'])
model = Paraformer(vocab_size=vocab_size,
encoder=encoder,
decoder=decoder,
ctc=ctc,
predictor=predictor,
**configs['model_conf'])
else:
print(configs)
if configs.get('lfmmi_dir', '') != '':
Expand Down

0 comments on commit b765c12

Please sign in to comment.