Skip to content

Commit

Permalink
Update: VITS dynamic_loading
Browse files Browse the repository at this point in the history
  • Loading branch information
Artrajz committed Jan 2, 2025
1 parent 96e4787 commit 00431b2
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 4 deletions.
1 change: 1 addition & 0 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ class Config:
class VITSModelConfig(BaseModelConfig):
vits_path: str = None
config_path: str = None
dynamic_loading: Optional[bool] = False


class W2V2VITSModelConfig(BaseModelConfig):
Expand Down
4 changes: 3 additions & 1 deletion manager/ModelManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,8 @@ def _load_model_from_path(self, tts_model):
"config": hps,
"device": self.device
}
if model_type == ModelType.VITS:
model_args["dynamic_loading"] = tts_model["dynamic_loading"]

model_class = self.model_class_map[model_type]
model = model_class(**model_args)
Expand All @@ -255,7 +257,7 @@ def _load_model_from_path(self, tts_model):
if bert_embedding and self.tts_front is None:
self.load_VITS_PinYin_model(
os.path.join(BASE_DIR, config.system.data_path, config.resource_paths_config.vits_chinese_bert))
if not config.vits_config.dynamic_loading:
if not model.dynamic_loading:
model.load_model()
self.available_tts_model.add(ModelType.VITS)

Expand Down
6 changes: 3 additions & 3 deletions vits/vits.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


class VITS:
def __init__(self, vits_path, config, device="cpu", **kwargs):
def __init__(self, vits_path, config, device="cpu", dynamic_loading=False, **kwargs):
self.hps_ms = get_hparams_from_file(config) if isinstance(config, str) else config
self.n_speakers = getattr(self.hps_ms.data, 'n_speakers', 0)
self.n_symbols = len(getattr(self.hps_ms, 'symbols', []))
Expand All @@ -23,6 +23,7 @@ def __init__(self, vits_path, config, device="cpu", **kwargs):
self.sampling_rate = self.hps_ms.data.sampling_rate
self.device = torch.device(device)
self.vits_path = vits_path
self.dynamic_loading = dynamic_loading

# load checkpoint
# self.load_model()
Expand All @@ -39,10 +40,9 @@ def load_model(self):
_ = self.net_g_ms.eval()
utils.load_checkpoint(self.vits_path, self.net_g_ms)
self.net_g_ms.to(self.device)

def release_model(self):
del self.net_g_ms


def get_cleaned_text(self, text, hps, cleaned=False):
if cleaned:
Expand Down

0 comments on commit 00431b2

Please sign in to comment.