test_glow_tts.py (forked from mozilla/TTS)
# Unit test for the GlowTTS model: runs a few optimization steps on dummy data
# and checks that every trainable parameter gets updated.
import copy
import os
import unittest

import torch
from torch import optim

from tests import get_tests_input_path
from TTS.tts.layers.losses import GlowTTSLoss
from TTS.tts.models.glow_tts import GlowTts
from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_config

# pylint: disable=unused-variable

torch.manual_seed(1)
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")

c = load_config(os.path.join(get_tests_input_path(), 'test_config.json'))
ap = AudioProcessor(**c.audio)
WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav")


def count_parameters(model):
    r"""Count number of trainable parameters in a network."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
class GlowTTSTrainTest(unittest.TestCase):
    @staticmethod
    def test_train_step():
        input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
        input_lengths = torch.randint(100, 129, (8, )).long().to(device)
        input_lengths[-1] = 128
        mel_spec = torch.rand(8, c.audio['num_mels'], 30).to(device)
        linear_spec = torch.rand(8, 30, c.audio['fft_size']).to(device)
        mel_lengths = torch.randint(20, 30, (8, )).long().to(device)
        speaker_ids = torch.randint(0, 5, (8, )).long().to(device)

        criterion = GlowTTSLoss()
        # model to train
        model = GlowTts(num_chars=32,
                        hidden_channels=128,
                        filter_channels=32,
                        filter_channels_dp=32,
                        out_channels=80,
                        kernel_size=3,
                        num_heads=2,
                        num_layers_enc=6,
                        dropout_p=0.1,
                        num_flow_blocks_dec=12,
                        kernel_size_dec=5,
                        dilation_rate=5,
                        num_block_layers=4,
                        dropout_p_dec=0.,
                        num_speakers=0,
                        c_in_channels=0,
                        num_splits=4,
                        num_sqz=1,
                        sigmoid_scale=False,
                        rel_attn_window_size=None,
                        input_length=None,
                        mean_only=False,
                        hidden_channels_enc=None,
                        hidden_channels_dec=None,
                        use_encoder_prenet=False,
                        encoder_type="transformer").to(device)
        # reference model to compare model weights
        model_ref = GlowTts(num_chars=32,
                            hidden_channels=128,
                            filter_channels=32,
                            filter_channels_dp=32,
                            out_channels=80,
                            kernel_size=3,
                            num_heads=2,
                            num_layers_enc=6,
                            dropout_p=0.1,
                            num_flow_blocks_dec=12,
                            kernel_size_dec=5,
                            dilation_rate=5,
                            num_block_layers=4,
                            dropout_p_dec=0.,
                            num_speakers=0,
                            c_in_channels=0,
                            num_splits=4,
                            num_sqz=1,
                            sigmoid_scale=False,
                            rel_attn_window_size=None,
                            input_length=None,
                            mean_only=False,
                            hidden_channels_enc=None,
                            hidden_channels_dec=None,
                            use_encoder_prenet=False,
                            encoder_type="transformer").to(device)
        model.train()
        print(" > Num parameters for GlowTTS model:%s" %
              (count_parameters(model)))

        # pass the state to ref model
        model_ref.load_state_dict(copy.deepcopy(model.state_dict()))

        # both models must start from identical weights
        count = 0
        for param, param_ref in zip(model.parameters(),
                                    model_ref.parameters()):
            assert (param - param_ref).sum() == 0, param
            count += 1

        optimizer = optim.Adam(model.parameters(), lr=c.lr)
        # run a few training steps; every trainable parameter should change
        for _ in range(5):
            z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward(
                input_dummy, input_lengths, mel_spec, mel_lengths, None)
            optimizer.zero_grad()
            loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths,
                                  o_dur_log, o_total_dur, input_lengths)
            loss = loss_dict['loss']
            loss.backward()
            optimizer.step()

        # check parameter changes
        count = 0
        for param, param_ref in zip(model.parameters(),
                                    model_ref.parameters()):
            assert (param != param_ref).any(), \
                "param {} with shape {} not updated!! \n{}\n{}".format(
                    count, param.shape, param, param_ref)
            count += 1
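

# Not in the original file: a minimal, optional entry point so the test module can
# also be run directly with `python test_glow_tts.py`, in addition to the usual
# pytest/unittest test discovery.
if __name__ == "__main__":
    unittest.main()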