Skip to content

Commit

Permalink
Update: recognition_model_type
Browse files Browse the repository at this point in the history
  • Loading branch information
Artrajz committed Jan 2, 2025
1 parent 01c1a5e commit 96e4787
Showing 1 changed file with 48 additions and 9 deletions.
57 changes: 48 additions & 9 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
import secrets
import string
import sys
from json import loads

import torch
import yaml
from typing import List, Union, Optional, Dict, Type
from pydantic import BaseModel, Field, ValidationError
from pydantic import BaseModel, Field, ValidationError, field_validator

from contants import ModelType

Expand Down Expand Up @@ -134,40 +135,35 @@ class ResourcePathsConfig(BaseModel):


class BaseModelConfig(BaseModel):
model_type: str
model_type: Optional[str]

class Config:
protected_namespaces = ()


class VITSModelConfig(BaseModelConfig):
model_type: str = ModelType.VITS
vits_path: str = None
config_path: str = None


class W2V2VITSModelConfig(BaseModelConfig):
model_type: str = ModelType.W2V2_VITS
vits_path: str = None
config_path: str = None


class HuBertVITSModelConfig(BaseModelConfig):
model_type: str = ModelType.HUBERT_VITS
vits_path: str = None
config_path: str = None


class BertVITS2ModelConfig(BaseModelConfig):
model_type: str = ModelType.BERT_VITS2
vits_path: str = None
config_path: str = None


class GPTSoVITSModelConfig(BaseModelConfig):
model_type: str = ModelType.GPT_SOVITS
gpt_path: str = None
sovits_path: str = None
vits_path: str = None
t2s_path: str = None


MODEL_TYPE_MAP: Dict[str, Type[BaseModelConfig]] = {
Expand All @@ -190,6 +186,41 @@ class TTSModelConfig(BaseModel):
GPTSoVITSModelConfig,
]] = Field(default_factory=list)

@classmethod
def recognition_model_type_by_config(self, config: dict) -> str:
symbols = config.get("symbols", None)
emotion_embedding = config["data"].get("emotion_embedding", False)

if "use_spk_conditioned_encoder" in config["model"]:
model_type = ModelType.BERT_VITS2
return model_type

if symbols != None:
if not emotion_embedding:
mode_type = ModelType.VITS
else:
mode_type = ModelType.W2V2_VITS
else:
mode_type = ModelType.HUBERT_VITS

return mode_type

@field_validator('tts_models', mode="before")
def infer_model_type(cls, v):
result = []
for model in v:
if 'model_type' not in model:
if 'vits_path' in model and 'config_path' in model:
with open(model["config_path"], 'r', encoding='utf-8') as f:
data = loads(f.read())
model['model_type'] = cls.recognition_model_type_by_config(data)
elif 'vits_path' in model and 't2s_path' in model:
model['model_type'] = ModelType.GPT_SOVITS

model_class = MODEL_TYPE_MAP[model['model_type']]
result.append(model_class(**model))
return result

def add_model(self, model_config: BaseModelConfig):
if not isinstance(model_config, BaseModelConfig):
raise TypeError("model_config must be an instance of BaseModelConfig")
Expand All @@ -205,8 +236,16 @@ def update_tts_models(self, tts_models: list):
for item in tts_models:
tts_model = item["tts_model"]
model_type = tts_model.get("model_type")

if model_type:
model_type = model_type.upper().replace("_", "-")
else:
if tts_model.get("t2s_path"):
model_type = ModelType.GPT_SOVITS
else:
with open(tts_model["config_path"], 'r', encoding='utf-8') as f:
data = f.read()
model_type = self.recognition_model_type(loads(data))
model_class = MODEL_TYPE_MAP.get(ModelType(model_type))
if model_class is not None:
try:
Expand Down

0 comments on commit 96e4787

Please sign in to comment.