Skip to content

Commit

Permalink
update GPT-SoVITS presets
Browse files Browse the repository at this point in the history
  • Loading branch information
Artrajz committed Feb 8, 2024
1 parent c5370bb commit 9d4916e
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 18 deletions.
60 changes: 55 additions & 5 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import string
import sys
from dataclasses import dataclass, field, asdict, fields, is_dataclass
from typing import List, Union, Optional
from typing import List, Union, Optional, Dict

import torch
import yaml
Expand Down Expand Up @@ -58,6 +58,10 @@ def asdict(self):
data[attr] = []
for item in value:
data[attr].append(item.asdict())
elif isinstance(value, dict):
data[attr] = {}
for k, v in value.items():
data[attr].update({k: v.asdict()})
else:
data[attr] = value
return data
Expand All @@ -73,7 +77,6 @@ def update_config(self, new_config_dict):

if field_name in new_config_dict:
new_value = new_config_dict[field_name]

if is_dataclass(field_type):
if isinstance(new_value, list):
# If the field type is a dataclass and the new value is a list
Expand All @@ -84,6 +87,7 @@ def update_config(self, new_config_dict):
# If the field type is a dataclass but not a list, recursively update the dataclass
nested_config = getattr(self, field_name)
nested_config.update_config(new_value)
setattr(self, field_name, nested_config)
else:
if field_type == bool:
new_value = str(new_value).lower() == "true"
Expand Down Expand Up @@ -156,6 +160,13 @@ class BertVits2Config(AsDictMixin):
torch_data_type: str = ""


@dataclass
class GPTSoVitsPreset(AsDictMixin):
    """One named reference preset stored in ``GPTSoVitsConfig.presets``.

    The field order matters: ``GPTSoVitsConfig.update_config`` builds
    instances positionally as ``(refer_wav_path, prompt_text, prompt_lang)``.
    """
    # NOTE(review): the two fields below are annotated ``str`` but default to
    # ``None``; the annotation is left as-is because ``update_config`` compares
    # field types against ``str`` — confirm before switching to Optional[str].

    # Path handed to ``load_audio`` by the voice API when no audio is uploaded.
    refer_wav_path: str = None
    # Prompt text paired with the reference audio (presumably its transcript —
    # TODO confirm).
    prompt_text: str = None
    # Language of the prompt text; defaults to "auto".
    prompt_lang: str = "auto"


@dataclass
class GPTSoVitsConfig(AsDictMixin):
hz: int = 50
Expand All @@ -164,9 +175,48 @@ class GPTSoVitsConfig(AsDictMixin):
lang: str = "auto"
format: str = "wav"
segment_size: int = 50
refer_wav_path: str = ""
prompt_text: str = ""
prompt_lang: str = "auto"
presets: Dict[str, GPTSoVitsPreset] = field(default_factory=lambda: {"default": GPTSoVitsPreset()})

def update_config(self, new_config_dict):
    """Overwrite this config's fields with the values in ``new_config_dict``.

    Only keys that match a dataclass field name are used; other keys are
    ignored. Values are coerced according to the field's declared type:

    * dataclass field: a list is converted element-wise into instances of
      that dataclass; any other value is passed to the nested object's own
      ``update_config``.
    * ``Dict[str, <dataclass>]`` field (the ``presets`` mapping): each entry
      is rebuilt as an instance of the annotated dataclass. Keys that are
      missing or ``None`` in an entry fall back to the dataclass defaults
      (e.g. ``prompt_lang`` stays ``"auto"``), and an entry that is ``None``
      becomes an all-default preset. A ``None`` mapping leaves the current
      presets untouched.
    * ``bool``: true only when the string form of the value is "true"
      (case-insensitive).
    * ``int`` / ``float`` / ``str`` / ``torch.device``: passed through the
      corresponding constructor.

    Args:
        new_config_dict: Mapping of field name to raw (e.g. YAML-loaded)
            value.

    Raises:
        ValueError, TypeError: If a value cannot be coerced to the field's
            declared type.
    """
    # Local import keeps the file's top-level import block untouched.
    from typing import get_args, get_origin

    # Named ``spec`` rather than ``field`` so the module-level
    # ``dataclasses.field`` import is not shadowed inside this method.
    for spec in fields(self):
        field_name = spec.name
        field_type = spec.type

        if field_name not in new_config_dict:
            continue
        new_value = new_config_dict[field_name]

        if is_dataclass(field_type):
            if isinstance(new_value, list):
                # Convert each element of the list to the dataclass type.
                new_value = [field_type(**item) for item in new_value]
                setattr(self, field_name, new_value)
            else:
                # Recursively update the nested dataclass, then store it
                # back (mirrors AsDictMixin.update_config).
                nested_config = getattr(self, field_name)
                nested_config.update_config(new_value)
                setattr(self, field_name, nested_config)
            continue

        type_args = get_args(field_type)
        is_preset_map = (
            get_origin(field_type) is dict
            and len(type_args) == 2
            and is_dataclass(type_args[1])
        )

        if is_preset_map:
            if new_value is None:
                # ``presets:`` left blank in the config: keep what we have.
                continue
            preset_cls = type_args[1]
            preset_keys = [f.name for f in fields(preset_cls)]
            presets = {}
            for preset_name, raw in new_value.items():
                if isinstance(raw, preset_cls):
                    presets[preset_name] = raw
                    continue
                raw = raw or {}
                # Pass only the keys that carry a value so the dataclass
                # defaults are not overwritten with None.
                kwargs = {
                    key: raw[key]
                    for key in preset_keys
                    if raw.get(key) is not None
                }
                presets[preset_name] = preset_cls(**kwargs)
            new_value = presets
        elif field_type == bool:
            new_value = str(new_value).lower() == "true"
        elif field_type == int:
            new_value = int(new_value)
        elif field_type == float:
            new_value = float(new_value)
        elif field_type == str:
            new_value = str(new_value)
        elif field_type == torch.device:
            new_value = torch.device(new_value)

        setattr(self, field_name, new_value)


@dataclass
Expand Down
38 changes: 25 additions & 13 deletions tts_app/voice_api/views.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
import os
import time
import uuid
Expand Down Expand Up @@ -536,9 +537,13 @@ def voice_gpt_sovits_api():
format = get_param(request_data, "format", config.gpt_sovits_config.format, str)
segment_size = get_param(request_data, "segment_size", config.gpt_sovits_config.segment_size, int)
reference_audio = request.files.get("reference_audio", None)
refer_wav_path = get_param(request_data, "refer_wav_path", config.gpt_sovits_config.refer_wav_path, str)
prompt_text = get_param(request_data, "prompt_text", config.gpt_sovits_config.prompt_text, str)
prompt_lang = get_param(request_data, "prompt_lang", config.gpt_sovits_config.prompt_lang, str)
preset = get_param(request_data, "preset", "default", str)
refer_wav_path = get_param(request_data, "refer_wav_path",
config.gpt_sovits_config.presets.get("default").refer_wav_path, str)
prompt_text = get_param(request_data, "prompt_text",
config.gpt_sovits_config.presets.get("default").prompt_text, str)
prompt_lang = get_param(request_data, "prompt_lang",
config.gpt_sovits_config.presets.get("default").prompt_lang, str)
# use_streaming = get_param(request_data, 'streaming', config.gpt_sovits_config.use_streaming, bool)
except Exception as e:
logger.error(f"[{ModelType.GPT_SOVITS.value}] {e}")
Expand Down Expand Up @@ -572,20 +577,27 @@ def voice_gpt_sovits_api():

# 检查参考音频
if check_is_none(reference_audio):
if not check_is_none(refer_wav_path):
reference_audio = load_audio(config.gpt_sovits_config.refer_wav_path)
prompt_text, prompt_lang = (
config.gpt_sovits_config.prompt_text,
config.gpt_sovits_config.prompt_lang,
)
else:
reference_audio, reference_audio_sr = load_audio(config.gpt_sovits_config.refer_wav_path)
if preset != "default":
refer_preset = config.gpt_sovits_config.presets.get(preset)

if check_is_none(refer_wav_path):
refer_wav_path = refer_preset.refer_wav_path

prompt_text, prompt_lang = refer_preset.prompt_text, refer_preset.prompt_lang

try:
reference_audio, reference_audio_sr = load_audio(refer_wav_path)
except Exception as e:
logging.error(e)
return make_response(jsonify({"status": "error", "message": "Loading refer_wav_path error."}), 400)


reference_audio, reference_audio_sr = librosa.load(reference_audio, sr=None, dtype=np.float32)
reference_audio = reference_audio.flatten()

if check_is_none(reference_audio, prompt_text, prompt_lang):

if check_is_none(reference_audio, prompt_text):
# 未指定参考音频且配置文件无预设
logging.error("No reference audio specified, and no default setting in the config.")
return make_response(jsonify(
{"status": "error", "message": "No reference audio specified, and no default setting in the config."}),
400)
Expand Down

0 comments on commit 9d4916e

Please sign in to comment.