From 96ad4de0651984dde6031c9610338f6211702fa8 Mon Sep 17 00:00:00 2001 From: Artrajz <969242373@qq.com> Date: Fri, 9 Feb 2024 14:51:34 +0800 Subject: [PATCH] Update presets for GPT-SoVITS Update logger formatter Update yaml saving enoding to Update checkpoint loading info in gpt_sovits.py Add default presets to two Fix loaded_path --- config.py | 22 ++++++------ gpt_sovits/gpt_sovits.py | 10 +++--- gpt_sovits/utils.py | 14 ++++---- logger.py | 21 +++++++++-- manager/ModelManager.py | 33 +++++++++-------- manager/model_handler.py | 1 + tts_app/static/js/index.js | 58 +++++++++++++++++++++++++++--- tts_app/templates/pages/index.html | 11 +++++- tts_app/voice_api/views.py | 39 ++++++++++++++------ 9 files changed, 152 insertions(+), 57 deletions(-) diff --git a/config.py b/config.py index 2086a98..483c8e7 100644 --- a/config.py +++ b/config.py @@ -13,6 +13,7 @@ import shutil import string import sys +import traceback from dataclasses import dataclass, field, asdict, fields, is_dataclass from typing import List, Union, Optional, Dict @@ -175,16 +176,15 @@ class GPTSoVitsConfig(AsDictMixin): lang: str = "auto" format: str = "wav" segment_size: int = 50 - presets: Dict[str, GPTSoVitsPreset] = field(default_factory=lambda: {"default": GPTSoVitsPreset()}) + presets: Dict[str, GPTSoVitsPreset] = field(default_factory=lambda: {"default": GPTSoVitsPreset(), + "default2": GPTSoVitsPreset()}) def update_config(self, new_config_dict): for field in fields(self): field_name = field.name field_type = field.type - if field_name in new_config_dict: new_value = new_config_dict[field_name] - if is_dataclass(field_type): if isinstance(new_value, list): # If the field type is a dataclass and the new value is a list @@ -198,6 +198,7 @@ def update_config(self, new_config_dict): else: if field_type == Dict[str, GPTSoVitsPreset]: new_dict = {} + for k, v in new_value.items(): refer_wav_path = v.get("refer_wav_path") prompt_text = v.get("prompt_text") @@ -394,17 +395,17 @@ def get_id(self): class Config(AsDictMixin): abs_path: str = ABS_PATH http_service: HttpService = HttpService() - log_config: LogConfig = LogConfig() + model_config: ModelConfig = ModelConfig() + tts_config: TTSConfig = TTSConfig() + admin: User = User() system: System = System() + log_config: LogConfig = LogConfig() language_identification: LanguageIdentification = LanguageIdentification() vits_config: VitsConfig = VitsConfig() w2v2_vits_config: W2V2VitsConfig = W2V2VitsConfig() hubert_vits_config: HuBertVitsConfig = HuBertVitsConfig() bert_vits2_config: BertVits2Config = BertVits2Config() gpt_sovits_config: GPTSoVitsConfig = GPTSoVitsConfig() - model_config: ModelConfig = ModelConfig() - tts_config: TTSConfig = TTSConfig() - admin: User = User() def asdict(self): data = {} @@ -437,7 +438,7 @@ def load_config(): else: try: logging.info("Loading config...") - with open(config_path, 'r') as f: + with open(config_path, 'r', encoding='utf-8') as f: loaded_config = yaml.safe_load(f) config = Config() @@ -455,12 +456,13 @@ def load_config(): return config except Exception as e: + logging.error(traceback.print_exc()) ValueError(e) @staticmethod def save_config(config): temp_filename = os.path.join(Config.abs_path, "config.yaml.tmp") - with open(temp_filename, 'w') as f: - yaml.safe_dump(config.asdict(), f, default_style=None) + with open(temp_filename, 'w', encoding='utf-8') as f: + yaml.dump(config.asdict(), f, allow_unicode=True, default_style='', sort_keys=False) shutil.move(temp_filename, os.path.join(Config.abs_path, "config.yaml")) logging.info(f"Config is saved.") diff --git a/gpt_sovits/gpt_sovits.py b/gpt_sovits/gpt_sovits.py index 0a660e1..08e90f0 100644 --- a/gpt_sovits/gpt_sovits.py +++ b/gpt_sovits/gpt_sovits.py @@ -1,4 +1,5 @@ import logging +import os.path import re import librosa @@ -61,12 +62,13 @@ def load_weight(self, saved_state_dict, model): def load_sovits(self, sovits_path): # self.n_semantic = 1024 - + logging.info(f"Loaded checkpoint '{sovits_path}'") dict_s2 = torch.load(sovits_path, map_location=self.device) self.hps = dict_s2["config"] self.hps = DictToAttrRecursive(self.hps) self.hps.model.semantic_frame_rate = "25hz" - self.speakers = [self.hps.get("name")] + # self.speakers = [self.hps.get("name")] # 从模型配置中获取名字 + self.speakers = [os.path.basename(os.path.dirname(self.sovits_path))] # 用模型文件夹作为名字 self.vq_model = SynthesizerTrn( self.hps.data.filter_length // 2 + 1, @@ -83,6 +85,7 @@ def load_sovits(self, sovits_path): self.load_weight(dict_s2['weight'], self.vq_model) def load_gpt(self, gpt_path): + logging.info(f"Loaded checkpoint '{gpt_path}'") dict_s1 = torch.load(gpt_path, map_location=self.device) self.gpt_config = dict_s1["config"] @@ -98,7 +101,7 @@ def load_gpt(self, gpt_path): self.t2s_model.eval() total = sum([param.nelement() for param in self.t2s_model.parameters()]) - logging.info("Number of parameter: %.2fM" % (total / 1e6)) + logging.info(f"Number of parameter: {total / 1e6:.2f}M") def get_speakers(self): return self.speakers @@ -150,7 +153,6 @@ def get_first(self, text): text = re.split(pattern, text)[0].strip() return text - def infer(self, text, lang, reference_audio, reference_audio_sr, prompt_text, prompt_lang): # t0 = ttime() diff --git a/gpt_sovits/utils.py b/gpt_sovits/utils.py index b5b283d..6f2a841 100644 --- a/gpt_sovits/utils.py +++ b/gpt_sovits/utils.py @@ -36,13 +36,13 @@ def __delattr__(self, item): raise AttributeError(f"Attribute {item} not found") -def load_audio(file, sr=16000): - try: - y, sr = librosa.load(file, sr=sr, dtype=np.float32) - except Exception as e: - raise RuntimeError(f"Failed to load audio: {e}") - - return y.flatten(), sr +# def load_audio(file, sr=16000): +# try: +# y, sr = librosa.load(file, sr=sr, dtype=np.float32) +# except Exception as e: +# raise RuntimeError(f"Failed to load audio: {e}") +# +# return y.flatten(), sr # import ffmpeg # import numpy as np diff --git a/logger.py b/logger.py index 9480c8c..c9d93dc 100644 --- a/logger.py +++ b/logger.py @@ -17,6 +17,7 @@ def __init__(self, warning_messages): def filter(self, record): return all(msg not in record.getMessage() for msg in self.warning_messages) + # 过滤警告 ignore_warning_messages = ["stft with return_complex=False is deprecated", "1Torch was not compiled with flash attention", @@ -27,20 +28,34 @@ def filter(self, record): for message in ignore_warning_messages: warnings.filterwarnings(action="ignore", message=message) + class WarningFilter(logging.Filter): def filter(self, record): return record.levelno != logging.WARNING + logzero.loglevel(logging.WARNING) logger = logging.getLogger("vits-simple-api") -level = getattr(config, "LOGGING_LEVEL", "DEBUG") +level = config.log_config.logging_level.upper() level_dict = {'DEBUG': logging.DEBUG, 'INFO': logging.INFO, 'WARNING': logging.WARNING, 'ERROR': logging.ERROR, 'CRITICAL': logging.CRITICAL} logging.getLogger().setLevel(level_dict[level]) # formatter = logging.Formatter('%(levelname)s:%(name)s %(message)s') -formatter = logging.Formatter('%(asctime)s [%(levelname)s] [%(module)s.%(funcName)s:%(lineno)d] %(message)s', - datefmt='%Y-%m-%d %H:%M:%S') +# formatter = logging.Formatter('%(asctime)s [%(levelname)s] [%(module)s.%(funcName)s:%(lineno)d] %(message)s', +# datefmt='%Y-%m-%d %H:%M:%S') + +# 根据日志级别选择日志格式 +if level == "DEBUG": + formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(message)s [in %(module)s.%(funcName)s:%(lineno)d]', + datefmt='%Y-%m-%d %H:%M:%S') +elif level == "INFO": + formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(message)s', + datefmt='%Y-%m-%d %H:%M:%S') +else: + # 如果日志级别既不是DEBUG也不是INFO,则使用默认的日志格式 + formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(message)s [in %(module)s.%(funcName)s:%(lineno)d]', + datefmt='%Y-%m-%d %H:%M:%S') logs_path = os.path.join(config.abs_path, config.log_config.logs_path) os.makedirs(logs_path, exist_ok=True) diff --git a/manager/ModelManager.py b/manager/ModelManager.py index 4e5a8d3..3ef92b4 100644 --- a/manager/ModelManager.py +++ b/manager/ModelManager.py @@ -366,7 +366,8 @@ def unload_model(self, model_type_value: str, model_id: str): self.notify("model_unloaded", model_manager=self) self.logger.info(f"Unloading success.") except Exception as e: - self.logger.info(f"Unloading failed. {e}") + logging.error(traceback.print_exc()) + logging.error(f"Unloading failed. {e}") state = False return state @@ -591,9 +592,13 @@ def scan_unload_path(self): for model in self.get_models_path(): # 只取已加载的模型路径 if model.get("model_type") == ModelType.GPT_SOVITS: - loaded_paths_2.append((model.get("sovits_path"), model.get("gpt_path"))) + sovits_path, gpt_path = self.absolute_to_relative_path(model.get("sovits_path"), + model.get("gpt_path")) + sovits_path, gpt_path = sovits_path.replace("\\", "/"), gpt_path.replace("\\", "/") + loaded_paths_2.append((sovits_path, gpt_path)) else: - loaded_paths.append(model.get("model_path")) + model_path = self.absolute_to_relative_path(model.get("model_path"))[0].replace("\\", "/") + loaded_paths.append(model_path) for info in all_paths: # 将绝对路径修改为相对路径,并将分隔符格式化为'/' @@ -609,29 +614,23 @@ def scan_unload_path(self): model_path, config_path = self.absolute_to_relative_path(info.get("model_path"), info.get("config_path")) model_path, config_path = model_path.replace("\\", "/"), config_path.replace("\\", "/") - - if not self.is_path_loaded(info.get("model_path"), loaded_paths): + if not self.is_path_loaded(model_path, loaded_paths): info.update({"model_path": model_path, "config_path": config_path}) unload_paths.append(info) return unload_paths def is_path_loaded(self, paths, loaded_paths): - if len(paths) == 1: - path = paths - normalized_path = os.path.normpath(path) - + if len(paths) == 2: + sovits_path, gpt_path = paths for loaded_path in loaded_paths: - normalized_loaded_path = os.path.normpath(loaded_path) - if normalized_path == normalized_loaded_path: + if sovits_path == loaded_path[0] and gpt_path == loaded_path[1]: return True - elif len(paths) == 2: - sovits_path, gpt_path = paths - sovits_path, gpt_path = os.path.normpath(sovits_path), os.path.normpath(gpt_path) + else: + path = paths + for loaded_path in loaded_paths: - normalized_sovits_path = os.path.normpath(self.absolute_to_relative_path(loaded_path[0])[0]) - normalized_gpt_path = os.path.normpath(self.absolute_to_relative_path(loaded_path[1])[0]) - if sovits_path == normalized_sovits_path and gpt_path == normalized_gpt_path: + if path == loaded_path: return True return False diff --git a/manager/model_handler.py b/manager/model_handler.py index c3de163..abfb24a 100644 --- a/manager/model_handler.py +++ b/manager/model_handler.py @@ -301,6 +301,7 @@ def load_ssl(self, max_retries=3): self.ssl_model["model"] = self.ssl_model["model"].half() self.ssl_model["model"] = self.ssl_model["model"].to(self.device) + self.ssl_model["reference_count"] = 1 logging.info(f"Success loading: {model_path}") break except Exception as e: diff --git a/tts_app/static/js/index.js b/tts_app/static/js/index.js index a2f8f3b..8c55b2d 100644 --- a/tts_app/static/js/index.js +++ b/tts_app/static/js/index.js @@ -74,7 +74,9 @@ function getLink() { let text_prompt = ""; let style_text = ""; let style_weight = ""; - let prompt_text = "" + let prompt_text = null; + let prompt_lang = null; + let preset = null; if (currentModelPage == 1 || currentModelPage == 2 || currentModelPage == 3) { length = document.getElementById("input_length" + currentModelPage).value; @@ -98,6 +100,8 @@ function getLink() { url += "/voice/bert-vits2?id=" + id; } else if (currentModelPage == 4) { prompt_text = document.getElementById('input_prompt_text4').value; + prompt_lang = document.getElementById('input_prompt_lang4').value; + preset = document.getElementById('input_preset4').value; url += "/voice/gpt-sovits?id=" + id; } else { @@ -148,6 +152,15 @@ function getLink() { url += "&style_text=" + style_text; if (style_weight !== null && style_weight !== "") url += "&style_weight=" + style_weight; + } else if (currentModelPage == 4) { + if (prompt_lang !== null && prompt_lang !== "") + url += "&prompt_lang=" + prompt_lang; + if (prompt_text !== null && prompt_text !== "") + url += "&prompt_text=" + prompt_text; + if (preset !== null && preset !== "") + url += "&preset=" + preset; + + } if (api_key != "") { @@ -233,6 +246,7 @@ function setAudioSourceByPost() { let style_weight = ""; let prompt_text = null; let prompt_lang = null; + let preset = null; if (currentModelPage == 1 || currentModelPage == 2 || currentModelPage == 3) { length = $("#input_length" + currentModelPage).val(); @@ -264,6 +278,7 @@ function setAudioSourceByPost() { url = baseUrl + "/voice/gpt-sovits"; prompt_text = $("#input_prompt_text4").val() prompt_lang = $("#input_prompt_lang4").val() + preset = $("#input_preset4").val() } @@ -308,6 +323,9 @@ function setAudioSourceByPost() { if (currentModelPage == 4 && prompt_lang) { formData.append('prompt_lang', prompt_lang); } + if (currentModelPage == 4 && preset) { + formData.append('preset', preset); + } let downloadButton = document.getElementById("downloadButton" + currentModelPage); @@ -334,8 +352,10 @@ function setAudioSourceByPost() { downloadButton.disabled = false; }, error: function (error) { - console.error('Error:', error); - alert("无法获取音频数据,请查看日志!"); + // console.error('Error:', error); + let message = "无法获取音频数据,请查看日志!"; + console.log(message) + alert(message); downloadButton.disabled = true; } }); @@ -396,11 +416,38 @@ function showModelContentBasedOnStatus() { } function updatePlaceholders(config, page) { - for (var key in config) { - $("#input_" + key + page).attr("placeholder", config[key]); + for (let key in config) { + if (key == "presets") { + let data = config[key]; + let selectElement = $("#input_preset" + page); + selectElement.empty(); // 清除现有的选项 + for (let name in data) { + let preset_value = data[name]; + let preset = `[${name}] audio: ${preset_value["refer_wav_path"]}`; + // 创建preset + let option = $("