-
Notifications
You must be signed in to change notification settings - Fork 36
/
model.py
94 lines (76 loc) · 2.77 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import os
import json
import torch
import numpy as np
import hifigan
from model import PortaSpeech, ScheduledOptim
def get_model(args, configs, device, train=False):
    """Build a PortaSpeech model, optionally restoring weights from a checkpoint.

    Args:
        args: Namespace-like object; only ``args.restore_step`` is read. When
            truthy, ``<train_config["path"]["ckpt_path"]>/<restore_step>.pth.tar``
            is loaded and its ``"model"`` entry restored into the model.
        configs: Tuple ``(preprocess_config, model_config, train_config)``.
        device: Device the model is moved to and the checkpoint is mapped onto.
        train: If True, also build a ``ScheduledOptim`` (restoring its state
            from the checkpoint's ``"optimizer"`` entry when resuming) and
            return the model in training mode.

    Returns:
        ``(model, scheduled_optim)`` when ``train`` is True; otherwise the model
        in eval mode with gradient tracking disabled on its parameters.
    """
    preprocess_config, model_config, train_config = configs

    model = PortaSpeech(preprocess_config, model_config, train_config).to(device)

    ckpt = None
    if args.restore_step:
        ckpt_path = os.path.join(
            train_config["path"]["ckpt_path"],
            "{}.pth.tar".format(args.restore_step),
        )
        ckpt = torch.load(ckpt_path, map_location=device)
        model.load_state_dict(ckpt["model"])

    if train:
        scheduled_optim = ScheduledOptim(
            model, train_config, model_config, args.restore_step
        )
        if ckpt is not None:
            scheduled_optim.load_state_dict(ckpt["optimizer"])
        model.train()
        return model, scheduled_optim

    model.eval()
    # ``requires_grad_`` is a method and must be called. Assigning
    # ``model.requires_grad_ = False`` would only shadow the method with a bool
    # and leave every parameter still tracking gradients.
    model.requires_grad_(False)
    return model
def get_param_num(model):
    """Count the scalar elements across every parameter tensor of ``model``.

    Args:
        model: Any object exposing ``parameters()`` (e.g. an ``nn.Module``).

    Returns:
        Integer total of ``numel()`` over all parameters; 0 if there are none.
    """
    total = 0
    for tensor in model.parameters():
        total += tensor.numel()
    return total
def get_vocoder(config, device):
    """Load the pretrained neural vocoder selected by ``config["vocoder"]``.

    Args:
        config: Config dict where ``config["vocoder"]["model"]`` is
            ``"MelGAN"`` or ``"HiFi-GAN"`` and ``config["vocoder"]["speaker"]``
            is ``"LJSpeech"`` or ``"universal"``.
        device: Device the vocoder is moved to (and, for HiFi-GAN, the device
            the checkpoint is mapped onto).

    Returns:
        The vocoder in eval mode. For MelGAN this is the object returned by
        ``torch.hub.load`` (its ``mel2wav`` submodule is put in eval mode and
        moved to ``device``). For HiFi-GAN it is a ``hifigan.Generator`` loaded
        from ``hifigan/generator_<speaker>.pth.tar`` with weight norm removed.

    Raises:
        ValueError: If the vocoder model name or speaker is not supported.
    """
    name = config["vocoder"]["model"]
    speaker = config["vocoder"]["speaker"]

    if name == "MelGAN":
        # Speaker -> pretrained model tag published in descriptinc/melgan-neurips.
        melgan_tags = {"LJSpeech": "linda_johnson", "universal": "multi_speaker"}
        if speaker not in melgan_tags:
            raise ValueError(
                "Unsupported MelGAN speaker: {!r}".format(speaker)
            )
        vocoder = torch.hub.load(
            "descriptinc/melgan-neurips", "load_melgan", melgan_tags[speaker]
        )
        vocoder.mel2wav.eval()
        vocoder.mel2wav.to(device)
    elif name == "HiFi-GAN":
        hifigan_ckpts = {
            "LJSpeech": "hifigan/generator_LJSpeech.pth.tar",
            "universal": "hifigan/generator_universal.pth.tar",
        }
        # Validate before any file I/O so a bad speaker fails fast and clearly.
        if speaker not in hifigan_ckpts:
            raise ValueError(
                "Unsupported HiFi-GAN speaker: {!r}".format(speaker)
            )
        # Use a separate name so the caller's ``config`` argument is not shadowed.
        with open("hifigan/config.json", "r", encoding="utf-8") as f:
            hifigan_config = hifigan.AttrDict(json.load(f))
        vocoder = hifigan.Generator(hifigan_config)
        ckpt = torch.load(hifigan_ckpts[speaker], map_location=device)
        vocoder.load_state_dict(ckpt["generator"])
        vocoder.eval()
        vocoder.remove_weight_norm()
        vocoder.to(device)
    else:
        raise ValueError("Unsupported vocoder model: {!r}".format(name))

    return vocoder
def vocoder_infer(mels, vocoder, model_config, preprocess_config, lengths=None):
    """Convert a batch of mel-spectrograms into int16 waveforms.

    Args:
        mels: Batch of mel-spectrogram tensors, batch first. Presumably shaped
            (batch, n_mels, frames); TODO confirm against callers.
        vocoder: Vocoder returned by ``get_vocoder`` for the same config.
        model_config: Config dict; ``model_config["vocoder"]["model"]`` selects
            ``"MelGAN"`` or ``"HiFi-GAN"``.
        preprocess_config: Config dict providing
            ``["preprocessing"]["audio"]["max_wav_value"]``, the scale applied
            to the vocoder's output before casting to int16.
        lengths: Optional per-item sample counts. When given, waveform ``i`` is
            truncated to ``lengths[i]`` samples.

    Returns:
        List of 1-D ``int16`` numpy arrays, one per batch item.

    Raises:
        ValueError: If the configured vocoder model is not supported.
    """
    name = model_config["vocoder"]["model"]
    with torch.no_grad():
        if name == "MelGAN":
            # Dividing by ln(10) presumably rescales natural-log mels to the
            # log10 scale MelGAN expects. NOTE(review): confirm against the
            # feature extractor.
            wavs = vocoder.inverse(mels / np.log(10))
        elif name == "HiFi-GAN":
            wavs = vocoder(mels).squeeze(1)
        else:
            raise ValueError("Unsupported vocoder model: {!r}".format(name))

    max_wav_value = preprocess_config["preprocessing"]["audio"]["max_wav_value"]
    scaled = wavs.cpu().numpy() * max_wav_value
    # Clip before casting. For example, a sample of 1.0 scaled by a
    # max_wav_value of 32768 exceeds the int16 maximum (32767) and would
    # otherwise wrap around to a large negative value, producing a click.
    int16_info = np.iinfo(np.int16)
    wavs = np.clip(scaled, int16_info.min, int16_info.max).astype("int16")

    wavs = list(wavs)
    if lengths is not None:
        wavs = [wav[: lengths[i]] for i, wav in enumerate(wavs)]
    return wavs