diff --git a/examples/industrial_data_pretraining/seaco_paraformer/demo.py b/examples/industrial_data_pretraining/seaco_paraformer/demo.py index 85d989e44..804acddb5 100644 --- a/examples/industrial_data_pretraining/seaco_paraformer/demo.py +++ b/examples/industrial_data_pretraining/seaco_paraformer/demo.py @@ -11,18 +11,22 @@ vad_model_revision="v2.0.4", punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch", punc_model_revision="v2.0.4", - # spk_model="damo/speech_campplus_sv_zh-cn_16k-common", - # spk_model_revision="v2.0.2", + spk_model="damo/speech_campplus_sv_zh-cn_16k-common", + spk_model_revision="v2.0.2", ) # example1 res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav", hotword='达摩院 魔搭', + # preset_spk_num=2, # sentence_timestamp=True, # return sentence level information when spk_model is not given ) print(res) + +''' +# tensor or numpy as input # example2 import torchaudio import os @@ -38,4 +42,4 @@ wav_file = os.path.join(model.model_path, "example/asr_example.wav") speech, sample_rate = soundfile.read(wav_file) res = model.generate(input=[speech], batch_size_s=300, is_final=True) - +''' \ No newline at end of file diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py index d99fc5613..8007d6e4e 100644 --- a/funasr/auto/auto_model.py +++ b/funasr/auto/auto_model.py @@ -121,9 +121,6 @@ def __init__(self, **kwargs): if spk_mode not in ["default", "vad_segment", "punc_segment"]: logging.error("spk_mode should be one of default, vad_segment and punc_segment.") self.spk_mode = spk_mode - self.preset_spk_num = kwargs.get("preset_spk_num", None) - if self.preset_spk_num: - logging.warning("Using preset speaker number: {}".format(self.preset_spk_num)) self.kwargs = kwargs self.model = model @@ -391,7 +388,7 @@ def inference_with_vad(self, input, input_len=None, **cfg): if self.spk_model is not None: all_segments = sorted(all_segments, key=lambda x: x[0]) spk_embedding 
= result['spk_embedding'] - labels = self.cb_model(spk_embedding.cpu(), oracle_num=self.preset_spk_num) + labels = self.cb_model(spk_embedding.cpu(), oracle_num=kwargs.get('preset_spk_num', None)) del result['spk_embedding'] sv_output = postprocess(all_segments, None, labels, spk_embedding.cpu()) if self.spk_mode == 'vad_segment': # recover sentence_list