Skip to content

Commit

Permalink
VITS: lang_list is not defined & optional dynamic_loading (#187)
Browse files Browse the repository at this point in the history
* Update ModelManager.py

Give dynamic_loading a default value (False) when the model entry omits it

* Update views.py

Define lang_list before it is used

* Update TTSManager.py

'dynamic_loading' is optional

* Update views.py
  • Loading branch information
const-volatile authored Jan 5, 2025
1 parent 7b4cb9c commit 66ac048
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 4 deletions.
2 changes: 1 addition & 1 deletion manager/ModelManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def _load_model_from_path(self, tts_model):
"device": self.device
}
if model_type == ModelType.VITS:
- model_args["dynamic_loading"] = tts_model["dynamic_loading"]
+ model_args["dynamic_loading"] = tts_model.get("dynamic_loading", False)

model_class = self.model_class_map[model_type]
model = model_class(**model_args)
Expand Down
4 changes: 2 additions & 2 deletions manager/TTSManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def process_ssml_infer_task(self, tasks, format):

def vits_infer(self, state, encode=True):
model = self.get_model(ModelType.VITS, state["id"])
- if config.vits_config.dynamic_loading:
+ if hasattr(config.vits_config, 'dynamic_loading') and config.vits_config.dynamic_loading:
model.load_model()
state["id"] = self.get_real_id(ModelType.VITS, state["id"]) # Change to real id
# 去除所有多余的空白字符
Expand Down Expand Up @@ -299,7 +299,7 @@ def vits_infer(self, state, encode=True):
audios.append(brk)

audio = np.concatenate(audios, axis=0)
- if config.vits_config.dynamic_loading:
+ if hasattr(config.vits_config, 'dynamic_loading') and config.vits_config.dynamic_loading:
model.release_model()
return self.encode(sampling_rate, audio, state["format"]) if encode else audio

Expand Down
4 changes: 3 additions & 1 deletion tts_app/voice_api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,8 @@ def voice_vits_api():
if (lang_detect := config.language_identification.language_automatic_detect) and isinstance(lang_detect, list):
speaker_lang = lang_detect

+ lang_list = get_lang_list(lang, speaker_lang)

if use_streaming and format.upper() != "MP3":
format = "mp3"
logger.warning("Streaming response only supports MP3 format.")
Expand All @@ -186,7 +188,7 @@ def voice_vits_api():
"noise": noise,
"noisew": noisew,
"segment_size": segment_size,
- "lang": lang_list,
+ "lang": lang_list[0],
"speaker_lang": speaker_lang,
}

Expand Down

0 comments on commit 66ac048

Please sign in to comment.