Skip to content

Commit

Permalink
Update presets for GPT-SoVITS
Browse files Browse the repository at this point in the history
Update logger formatter

Update yaml saving encoding to utf-8

Update checkpoint loading info in gpt_sovits.py

Increase the number of default presets to two

Fix loaded_path
  • Loading branch information
Artrajz committed Feb 9, 2024
1 parent 9d4916e commit 96ad4de
Show file tree
Hide file tree
Showing 9 changed files with 152 additions and 57 deletions.
22 changes: 12 additions & 10 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import shutil
import string
import sys
import traceback
from dataclasses import dataclass, field, asdict, fields, is_dataclass
from typing import List, Union, Optional, Dict

Expand Down Expand Up @@ -175,16 +176,15 @@ class GPTSoVitsConfig(AsDictMixin):
lang: str = "auto"
format: str = "wav"
segment_size: int = 50
presets: Dict[str, GPTSoVitsPreset] = field(default_factory=lambda: {"default": GPTSoVitsPreset()})
presets: Dict[str, GPTSoVitsPreset] = field(default_factory=lambda: {"default": GPTSoVitsPreset(),
"default2": GPTSoVitsPreset()})

def update_config(self, new_config_dict):
for field in fields(self):
field_name = field.name
field_type = field.type

if field_name in new_config_dict:
new_value = new_config_dict[field_name]

if is_dataclass(field_type):
if isinstance(new_value, list):
# If the field type is a dataclass and the new value is a list
Expand All @@ -198,6 +198,7 @@ def update_config(self, new_config_dict):
else:
if field_type == Dict[str, GPTSoVitsPreset]:
new_dict = {}

for k, v in new_value.items():
refer_wav_path = v.get("refer_wav_path")
prompt_text = v.get("prompt_text")
Expand Down Expand Up @@ -394,17 +395,17 @@ def get_id(self):
class Config(AsDictMixin):
abs_path: str = ABS_PATH
http_service: HttpService = HttpService()
log_config: LogConfig = LogConfig()
model_config: ModelConfig = ModelConfig()
tts_config: TTSConfig = TTSConfig()
admin: User = User()
system: System = System()
log_config: LogConfig = LogConfig()
language_identification: LanguageIdentification = LanguageIdentification()
vits_config: VitsConfig = VitsConfig()
w2v2_vits_config: W2V2VitsConfig = W2V2VitsConfig()
hubert_vits_config: HuBertVitsConfig = HuBertVitsConfig()
bert_vits2_config: BertVits2Config = BertVits2Config()
gpt_sovits_config: GPTSoVitsConfig = GPTSoVitsConfig()
model_config: ModelConfig = ModelConfig()
tts_config: TTSConfig = TTSConfig()
admin: User = User()

def asdict(self):
data = {}
Expand Down Expand Up @@ -437,7 +438,7 @@ def load_config():
else:
try:
logging.info("Loading config...")
with open(config_path, 'r') as f:
with open(config_path, 'r', encoding='utf-8') as f:
loaded_config = yaml.safe_load(f)
config = Config()

Expand All @@ -455,12 +456,13 @@ def load_config():

return config
except Exception as e:
logging.error(traceback.print_exc())
ValueError(e)

@staticmethod
def save_config(config):
temp_filename = os.path.join(Config.abs_path, "config.yaml.tmp")
with open(temp_filename, 'w') as f:
yaml.safe_dump(config.asdict(), f, default_style=None)
with open(temp_filename, 'w', encoding='utf-8') as f:
yaml.dump(config.asdict(), f, allow_unicode=True, default_style='', sort_keys=False)
shutil.move(temp_filename, os.path.join(Config.abs_path, "config.yaml"))
logging.info(f"Config is saved.")
10 changes: 6 additions & 4 deletions gpt_sovits/gpt_sovits.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import os.path
import re

import librosa
Expand Down Expand Up @@ -61,12 +62,13 @@ def load_weight(self, saved_state_dict, model):

def load_sovits(self, sovits_path):
# self.n_semantic = 1024

logging.info(f"Loaded checkpoint '{sovits_path}'")
dict_s2 = torch.load(sovits_path, map_location=self.device)
self.hps = dict_s2["config"]
self.hps = DictToAttrRecursive(self.hps)
self.hps.model.semantic_frame_rate = "25hz"
self.speakers = [self.hps.get("name")]
# self.speakers = [self.hps.get("name")] # 从模型配置中获取名字
self.speakers = [os.path.basename(os.path.dirname(self.sovits_path))] # 用模型文件夹作为名字

self.vq_model = SynthesizerTrn(
self.hps.data.filter_length // 2 + 1,
Expand All @@ -83,6 +85,7 @@ def load_sovits(self, sovits_path):
self.load_weight(dict_s2['weight'], self.vq_model)

def load_gpt(self, gpt_path):
logging.info(f"Loaded checkpoint '{gpt_path}'")
dict_s1 = torch.load(gpt_path, map_location=self.device)

self.gpt_config = dict_s1["config"]
Expand All @@ -98,7 +101,7 @@ def load_gpt(self, gpt_path):
self.t2s_model.eval()

total = sum([param.nelement() for param in self.t2s_model.parameters()])
logging.info("Number of parameter: %.2fM" % (total / 1e6))
logging.info(f"Number of parameter: {total / 1e6:.2f}M")

def get_speakers(self):
# Return the speaker list stored on the instance as-is (no copy is made).
return self.speakers
Expand Down Expand Up @@ -150,7 +153,6 @@ def get_first(self, text):
text = re.split(pattern, text)[0].strip()
return text


def infer(self, text, lang, reference_audio, reference_audio_sr, prompt_text, prompt_lang):
# t0 = ttime()

Expand Down
14 changes: 7 additions & 7 deletions gpt_sovits/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,13 @@ def __delattr__(self, item):
raise AttributeError(f"Attribute {item} not found")


def load_audio(file, sr=16000):
try:
y, sr = librosa.load(file, sr=sr, dtype=np.float32)
except Exception as e:
raise RuntimeError(f"Failed to load audio: {e}")

return y.flatten(), sr
# def load_audio(file, sr=16000):
# try:
# y, sr = librosa.load(file, sr=sr, dtype=np.float32)
# except Exception as e:
# raise RuntimeError(f"Failed to load audio: {e}")
#
# return y.flatten(), sr

# import ffmpeg
# import numpy as np
Expand Down
21 changes: 18 additions & 3 deletions logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def __init__(self, warning_messages):
def filter(self, record):
return all(msg not in record.getMessage() for msg in self.warning_messages)


# 过滤警告
ignore_warning_messages = ["stft with return_complex=False is deprecated",
"1Torch was not compiled with flash attention",
Expand All @@ -27,20 +28,34 @@ def filter(self, record):
for message in ignore_warning_messages:
warnings.filterwarnings(action="ignore", message=message)


# Logging filter that drops records whose level is exactly WARNING;
# records at every other level (DEBUG, INFO, ERROR, CRITICAL, ...) pass through.
class WarningFilter(logging.Filter):
def filter(self, record):
# A falsy return value tells the logging machinery to discard the record.
return record.levelno != logging.WARNING


logzero.loglevel(logging.WARNING)
logger = logging.getLogger("vits-simple-api")
level = getattr(config, "LOGGING_LEVEL", "DEBUG")
level = config.log_config.logging_level.upper()
level_dict = {'DEBUG': logging.DEBUG, 'INFO': logging.INFO, 'WARNING': logging.WARNING, 'ERROR': logging.ERROR,
'CRITICAL': logging.CRITICAL}
logging.getLogger().setLevel(level_dict[level])

# formatter = logging.Formatter('%(levelname)s:%(name)s %(message)s')
formatter = logging.Formatter('%(asctime)s [%(levelname)s] [%(module)s.%(funcName)s:%(lineno)d] %(message)s',
datefmt='%Y-%m-%d %H:%M:%S')
# formatter = logging.Formatter('%(asctime)s [%(levelname)s] [%(module)s.%(funcName)s:%(lineno)d] %(message)s',
# datefmt='%Y-%m-%d %H:%M:%S')

# 根据日志级别选择日志格式
if level == "DEBUG":
formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(message)s [in %(module)s.%(funcName)s:%(lineno)d]',
datefmt='%Y-%m-%d %H:%M:%S')
elif level == "INFO":
formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(message)s',
datefmt='%Y-%m-%d %H:%M:%S')
else:
# 如果日志级别既不是DEBUG也不是INFO,则使用默认的日志格式
formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(message)s [in %(module)s.%(funcName)s:%(lineno)d]',
datefmt='%Y-%m-%d %H:%M:%S')

logs_path = os.path.join(config.abs_path, config.log_config.logs_path)
os.makedirs(logs_path, exist_ok=True)
Expand Down
33 changes: 16 additions & 17 deletions manager/ModelManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,8 @@ def unload_model(self, model_type_value: str, model_id: str):
self.notify("model_unloaded", model_manager=self)
self.logger.info(f"Unloading success.")
except Exception as e:
self.logger.info(f"Unloading failed. {e}")
logging.error(traceback.print_exc())
logging.error(f"Unloading failed. {e}")
state = False

return state
Expand Down Expand Up @@ -591,9 +592,13 @@ def scan_unload_path(self):
for model in self.get_models_path():
# 只取已加载的模型路径
if model.get("model_type") == ModelType.GPT_SOVITS:
loaded_paths_2.append((model.get("sovits_path"), model.get("gpt_path")))
sovits_path, gpt_path = self.absolute_to_relative_path(model.get("sovits_path"),
model.get("gpt_path"))
sovits_path, gpt_path = sovits_path.replace("\\", "/"), gpt_path.replace("\\", "/")
loaded_paths_2.append((sovits_path, gpt_path))
else:
loaded_paths.append(model.get("model_path"))
model_path = self.absolute_to_relative_path(model.get("model_path"))[0].replace("\\", "/")
loaded_paths.append(model_path)

for info in all_paths:
# 将绝对路径修改为相对路径,并将分隔符格式化为'/'
Expand All @@ -609,29 +614,23 @@ def scan_unload_path(self):
model_path, config_path = self.absolute_to_relative_path(info.get("model_path"),
info.get("config_path"))
model_path, config_path = model_path.replace("\\", "/"), config_path.replace("\\", "/")

if not self.is_path_loaded(info.get("model_path"), loaded_paths):
if not self.is_path_loaded(model_path, loaded_paths):
info.update({"model_path": model_path, "config_path": config_path})
unload_paths.append(info)

return unload_paths

def is_path_loaded(self, paths, loaded_paths):
if len(paths) == 1:
path = paths
normalized_path = os.path.normpath(path)

if len(paths) == 2:
sovits_path, gpt_path = paths
for loaded_path in loaded_paths:
normalized_loaded_path = os.path.normpath(loaded_path)
if normalized_path == normalized_loaded_path:
if sovits_path == loaded_path[0] and gpt_path == loaded_path[1]:
return True
elif len(paths) == 2:
sovits_path, gpt_path = paths
sovits_path, gpt_path = os.path.normpath(sovits_path), os.path.normpath(gpt_path)
else:
path = paths

for loaded_path in loaded_paths:
normalized_sovits_path = os.path.normpath(self.absolute_to_relative_path(loaded_path[0])[0])
normalized_gpt_path = os.path.normpath(self.absolute_to_relative_path(loaded_path[1])[0])
if sovits_path == normalized_sovits_path and gpt_path == normalized_gpt_path:
if path == loaded_path:
return True

return False
1 change: 1 addition & 0 deletions manager/model_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,7 @@ def load_ssl(self, max_retries=3):
self.ssl_model["model"] = self.ssl_model["model"].half()

self.ssl_model["model"] = self.ssl_model["model"].to(self.device)
self.ssl_model["reference_count"] = 1
logging.info(f"Success loading: {model_path}")
break
except Exception as e:
Expand Down
58 changes: 53 additions & 5 deletions tts_app/static/js/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,9 @@ function getLink() {
let text_prompt = "";
let style_text = "";
let style_weight = "";
let prompt_text = ""
let prompt_text = null;
let prompt_lang = null;
let preset = null;

if (currentModelPage == 1 || currentModelPage == 2 || currentModelPage == 3) {
length = document.getElementById("input_length" + currentModelPage).value;
Expand All @@ -98,6 +100,8 @@ function getLink() {
url += "/voice/bert-vits2?id=" + id;
} else if (currentModelPage == 4) {
prompt_text = document.getElementById('input_prompt_text4').value;
prompt_lang = document.getElementById('input_prompt_lang4').value;
preset = document.getElementById('input_preset4').value;
url += "/voice/gpt-sovits?id=" + id;

} else {
Expand Down Expand Up @@ -148,6 +152,15 @@ function getLink() {
url += "&style_text=" + style_text;
if (style_weight !== null && style_weight !== "")
url += "&style_weight=" + style_weight;
} else if (currentModelPage == 4) {
if (prompt_lang !== null && prompt_lang !== "")
url += "&prompt_lang=" + prompt_lang;
if (prompt_text !== null && prompt_text !== "")
url += "&prompt_text=" + prompt_text;
if (preset !== null && preset !== "")
url += "&preset=" + preset;


}

if (api_key != "") {
Expand Down Expand Up @@ -233,6 +246,7 @@ function setAudioSourceByPost() {
let style_weight = "";
let prompt_text = null;
let prompt_lang = null;
let preset = null;

if (currentModelPage == 1 || currentModelPage == 2 || currentModelPage == 3) {
length = $("#input_length" + currentModelPage).val();
Expand Down Expand Up @@ -264,6 +278,7 @@ function setAudioSourceByPost() {
url = baseUrl + "/voice/gpt-sovits";
prompt_text = $("#input_prompt_text4").val()
prompt_lang = $("#input_prompt_lang4").val()
preset = $("#input_preset4").val()
}


Expand Down Expand Up @@ -308,6 +323,9 @@ function setAudioSourceByPost() {
if (currentModelPage == 4 && prompt_lang) {
formData.append('prompt_lang', prompt_lang);
}
if (currentModelPage == 4 && preset) {
formData.append('preset', preset);
}

let downloadButton = document.getElementById("downloadButton" + currentModelPage);

Expand All @@ -334,8 +352,10 @@ function setAudioSourceByPost() {
downloadButton.disabled = false;
},
error: function (error) {
console.error('Error:', error);
alert("无法获取音频数据,请查看日志!");
// console.error('Error:', error);
let message = "无法获取音频数据,请查看日志!";
console.log(message)
alert(message);
downloadButton.disabled = true;
}
});
Expand Down Expand Up @@ -396,11 +416,38 @@ function showModelContentBasedOnStatus() {
}

function updatePlaceholders(config, page) {
for (var key in config) {
$("#input_" + key + page).attr("placeholder", config[key]);
for (let key in config) {
if (key == "presets") {
let data = config[key];
let selectElement = $("#input_preset" + page);
selectElement.empty(); // 清除现有的选项
for (let name in data) {
let preset_value = data[name];
let preset = `[${name}] audio: ${preset_value["refer_wav_path"]}`;
// 创建preset
let option = $("<option>", {
value: name,
text: preset,
'data-prompt-lang': preset_value["prompt_lang"],
'data-prompt-text': preset_value["prompt_text"]
});
selectElement.append(option);
}
// 当选择改变时更新输入预设的值
selectElement.change(function () {
let selectedOption = $(this).find(":selected");
let promptLang = selectedOption.data("prompt-lang");
let promptText = selectedOption.data("prompt-text");
$("#input_prompt_lang" + page).val(promptLang);
$("#input_prompt_text" + page).val(promptText);
});
} else {
$("#input_" + key + page).attr("placeholder", config[key]);
}
}
}


function setDefaultParameter() {
$.ajax({
url: "/voice/default_parameter",
Expand All @@ -411,6 +458,7 @@ function setDefaultParameter() {
updatePlaceholders(default_parameter.vits_config, 1);
updatePlaceholders(default_parameter.w2v2_vits_config, 2);
updatePlaceholders(default_parameter.bert_vits2_config, 3);
updatePlaceholders(default_parameter.gpt_sovits_config, 4);
},
error: function (error) {
}
Expand Down
Loading

0 comments on commit 96ad4de

Please sign in to comment.