🐛 add config and update code for training
rishikksh20 committed Jun 30, 2022
1 parent ae17ac6 commit 3f155b1
Showing 3 changed files with 48 additions and 25 deletions.
18 changes: 18 additions & 0 deletions config_v1.json
@@ -14,6 +14,24 @@
     "resblock_kernel_sizes": [3,7,11],
     "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
 
+    "combd_channels": [16, 64, 256, 1024, 1024, 1024],
+    "combd_kernels" : [[7, 11, 11, 11, 11, 5], [11, 21, 21, 21, 21, 5], [15, 41, 41, 41, 41, 5]],
+    "combd_groups" : [1, 4, 16, 64, 256, 1],
+    "combd_strides": [1, 1, 4, 4, 4, 1],
+
+    "tkernels" : [7, 5, 3],
+    "fkernel" : 5,
+    "tchannels" : [64, 128, 256, 256, 256],
+    "fchannels" : [32, 64, 128, 128, 128],
+    "tstrides" : [[1, 1, 3, 3, 1], [1, 1, 3, 3, 1], [1, 1, 3, 3, 1]],
+    "fstride" : [1, 1, 3, 3, 1],
+    "tdilations" : [[[5, 7, 11], [5, 7, 11], [5, 7, 11], [5, 7, 11], [5, 7, 11], [5, 7, 11]], [[3, 5, 7], [3, 5, 7], [3, 5, 7], [3, 5, 7], [3, 5, 7]], [[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]],
+    "fdilations" : [[1, 2, 3], [1, 2, 3], [1, 2, 3], [2, 3, 5], [2, 3, 5]],
+    "pqmf_n" : 16,
+    "pqmf_m" : 64,
+    "freq_init_ch" : 128,
+    "tsubband" : [6, 11, 16],
+
     "segment_size": 8192,
     "num_mels": 80,
     "num_freq": 1025,
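The new keys parameterize the two Avocodo-style discriminators that this commit wires into train.py: the "combd_*" entries configure the collaborative multi-band discriminator (CoMBD), and the remaining keys configure the sub-band discriminator (SBD). A minimal sketch of how the keys flow from the config into the constructors, mirroring the train.py changes below (the config path and AttrDict loading follow this repo's existing HiFi-GAN-style setup; this sketch is not part of the commit):

import json
import torch
from env import AttrDict
from models import MultiCoMBDiscriminator, MultiSubBandDiscriminator

# Load hyperparameters the same way train.py does.
with open('config_v1.json') as f:
    h = AttrDict(json.load(f))

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# CoMBD: kernels/channels/groups/strides come straight from the "combd_*" keys.
mcmbd = MultiCoMBDiscriminator(h.combd_kernels, h.combd_channels,
                               h.combd_groups, h.combd_strides).to(device)

# SBD: time-axis (t*) and frequency-axis (f*) settings, plus the PQMF analysis
# settings (pqmf_n, pqmf_m), the tsubband band counts, and freq_init_ch.
msbd = MultiSubBandDiscriminator(h.tkernels, h.fkernel, h.tchannels, h.fchannels,
                                 h.tstrides, h.fstride, h.tdilations, h.fdilations,
                                 h.tsubband, h.pqmf_n, h.pqmf_m, h.freq_init_ch).to(device)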
18 changes: 11 additions & 7 deletions models.py
@@ -274,7 +274,7 @@ def __init__(self, kernels, channels, groups, strides):
         self.pqmf_2 = PQMF(N=2, taps=256, cutoff=0.25, beta=10.0)
         self.pqmf_4 = PQMF(N=4, taps=192, cutoff=0.13, beta=10.0)
 
-    def forward(self, x, x2, x1, x_hat, x2_hat, x1_hat):
+    def forward(self, x, x_hat, x2_hat, x1_hat):
         y = []
         y_hat = []
         fmap = []
@@ -288,27 +288,31 @@ def forward(self, x, x2, x1, x_hat, x2_hat, x1_hat):
         y_hat.append(p3_hat)
         fmap_hat.append(p3_fmap_hat)
 
-        p2_, p2_fmap_ = self.combd_2(x2)
+        x2_ = self.pqmf_2(x)[:, :1, :]  # Select first band
+        x1_ = self.pqmf_4(x)[:, :1, :]  # Select first band
+
+        x2_hat_ = self.pqmf_2(x_hat)[:, :1, :]
+        x1_hat_ = self.pqmf_4(x_hat)[:, :1, :]
+
+        p2_, p2_fmap_ = self.combd_2(x2_)
         y.append(p2_)
         fmap.append(p2_fmap_)
 
         p2_hat_, p2_fmap_hat_ = self.combd_2(x2_hat)
         y_hat.append(p2_hat_)
         fmap_hat.append(p2_fmap_hat_)
 
-        p1_, p1_fmap_ = self.combd_1(x1)
+        p1_, p1_fmap_ = self.combd_1(x1_)
         y.append(p1_)
         fmap.append(p1_fmap_)
 
         p1_hat_, p1_fmap_hat_ = self.combd_1(x1_hat)
         y_hat.append(p1_hat_)
         fmap_hat.append(p1_fmap_hat_)
 
-        x2_ = self.pqmf_2(x)[:, :1, :]  # Select first band
-        x1_ = self.pqmf_4(x)[:, :1, :]  # Select first band
 
-        x2_hat_ = self.pqmf_2(x_hat)[:, :1, :]
-        x1_hat_ = self.pqmf_4(x_hat)[:, :1, :]
 
+
+
         p2, p2_fmap = self.combd_2(x2_)
         y.append(p2)
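The forward() rewrite changes the call contract: callers now pass only the full-band real and generated waveforms plus the generator's two intermediate outputs, and the PQMF analysis of both full-band signals happens inside the discriminator. A sketch of the new contract (the shapes are assumptions inferred from the PQMF settings above; the four return values mirror how train.py unpacks them; not part of the commit):

import torch

B, T = 16, 8192                     # batch size, segment_size from config_v1.json
x = torch.randn(B, 1, T)            # real full-band waveform
x_hat = torch.randn(B, 1, T)        # generated full-band waveform
x2_hat = torch.randn(B, 1, T // 2)  # generator intermediate, half rate (assumed shape)
x1_hat = torch.randn(B, 1, T // 4)  # generator intermediate, quarter rate (assumed shape)

# Old: mcmbd(x, x2, x1, x_hat, x2_hat, x1_hat) with precomputed sub-bands.
# New: the x2_/x1_ sub-bands are derived internally via pqmf_2/pqmf_4.
y, y_hat, fmap, fmap_hat = mcmbd(x, x_hat, x2_hat, x1_hat)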
37 changes: 19 additions & 18 deletions train.py
@@ -14,7 +14,7 @@
 from torch.nn.parallel import DistributedDataParallel
 from env import AttrDict, build_env
 from meldataset import MelDataset, mel_spectrogram, get_dataset_filelist
-from models import Generator, MultiPeriodDiscriminator, MultiScaleDiscriminator, feature_loss, generator_loss,\
+from models import Generator, MultiCoMBDiscriminator, MultiSubBandDiscriminator, feature_loss, generator_loss,\
     discriminator_loss
 from utils import plot_spectrogram, scan_checkpoint, load_checkpoint, save_checkpoint
 
@@ -30,8 +30,9 @@ def train(rank, a, h):
     device = torch.device('cuda:{:d}'.format(rank))
 
     generator = Generator(h).to(device)
-    mpd = MultiPeriodDiscriminator().to(device)
-    msd = MultiScaleDiscriminator().to(device)
+    mcmbd = MultiCoMBDiscriminator(h.combd_kernels, h.combd_channels, h.combd_groups, h.combd_strides).to(device)
+    msbd = MultiSubBandDiscriminator(h.tkernels, h.fkernel, h.tchannels, h.fchannels, h.tstrides, h.fstride,
+                                     h.tdilations, h.fdilations, h.tsubband, h.pqmf_n, h.pqmf_m, h.freq_init_ch).to(device)
 
     if rank == 0:
         print(generator)
@@ -50,18 +51,18 @@ def train(rank, a, h):
         state_dict_g = load_checkpoint(cp_g, device)
         state_dict_do = load_checkpoint(cp_do, device)
         generator.load_state_dict(state_dict_g['generator'])
-        mpd.load_state_dict(state_dict_do['mpd'])
-        msd.load_state_dict(state_dict_do['msd'])
+        mcmbd.load_state_dict(state_dict_do['mcmbd'])
+        msbd.load_state_dict(state_dict_do['msbd'])
         steps = state_dict_do['steps'] + 1
         last_epoch = state_dict_do['epoch']
 
     if h.num_gpus > 1:
         generator = DistributedDataParallel(generator, device_ids=[rank]).to(device)
-        mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device)
-        msd = DistributedDataParallel(msd, device_ids=[rank]).to(device)
+        mcmbd = DistributedDataParallel(mcmbd, device_ids=[rank]).to(device)
+        msbd = DistributedDataParallel(msbd, device_ids=[rank]).to(device)
 
     optim_g = torch.optim.AdamW(generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2])
-    optim_d = torch.optim.AdamW(itertools.chain(msd.parameters(), mpd.parameters()),
+    optim_d = torch.optim.AdamW(itertools.chain(msbd.parameters(), mcmbd.parameters()),
                                 h.learning_rate, betas=[h.adam_b1, h.adam_b2])
 
     if state_dict_do is not None:
@@ -100,8 +101,8 @@ def train(rank, a, h):
         sw = SummaryWriter(os.path.join(a.checkpoint_path, 'logs'))
 
     generator.train()
-    mpd.train()
-    msd.train()
+    mcmbd.train()
+    msbd.train()
     for epoch in range(max(0, last_epoch), a.training_epochs):
         if rank == 0:
             start = time.time()
@@ -126,11 +127,11 @@ def train(rank, a, h):
             optim_d.zero_grad()
 
             # CoMBD
-            y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())
+            y_df_hat_r, y_df_hat_g, _, _ = mcmbd(y, y_g_hat.detach(), x2.detach(), x1.detach())
             loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)
 
             # SBD
-            y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach())
+            y_ds_hat_r, y_ds_hat_g, _, _ = msbd(y, y_g_hat.detach())
             loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)
 
             loss_disc_all = loss_disc_s + loss_disc_f
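Both discriminator calls feed the least-squares discriminator_loss inherited from the HiFi-GAN recipe this file is based on. For reference, a sketch of that standard implementation (not code touched by this commit):

def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    # LSGAN objective: push real outputs toward 1 and generated outputs toward 0.
    loss = 0
    r_losses, g_losses = [], []
    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
        r_loss = torch.mean((1 - dr) ** 2)
        g_loss = torch.mean(dg ** 2)
        loss += (r_loss + g_loss)
        r_losses.append(r_loss.item())
        g_losses.append(g_loss.item())
    return loss, r_losses, g_losses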
@@ -144,8 +145,8 @@ def train(rank, a, h):
             # L1 Mel-Spectrogram Loss
             loss_mel = F.l1_loss(y_mel, y_g_hat_mel) * 45
 
-            y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat)
-            y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat)
+            y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mcmbd(y, y_g_hat, x2, x1)
+            y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msbd(y, y_g_hat)
             loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
             loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
             loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
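feature_loss and generator_loss likewise follow the HiFi-GAN recipe: an L1 feature-matching term over every intermediate feature map, and the LSGAN generator term. Reference sketches (again not part of this commit):

def feature_loss(fmap_r, fmap_g):
    # L1 distance between real and generated feature maps, layer by layer.
    loss = 0
    for dr, dg in zip(fmap_r, fmap_g):
        for rl, gl in zip(dr, dg):
            loss += torch.mean(torch.abs(rl - gl))
    return loss * 2

def generator_loss(disc_outputs):
    # LSGAN generator objective: push generated outputs toward 1.
    loss = 0
    gen_losses = []
    for dg in disc_outputs:
        l = torch.mean((1 - dg) ** 2)
        gen_losses.append(l)
        loss += l
    return loss, gen_losses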
@@ -171,10 +172,10 @@ def train(rank, a, h):
                                     {'generator': (generator.module if h.num_gpus > 1 else generator).state_dict()})
                     checkpoint_path = "{}/do_{:08d}".format(a.checkpoint_path, steps)
                     save_checkpoint(checkpoint_path,
-                                    {'mpd': (mpd.module if h.num_gpus > 1
-                                             else mpd).state_dict(),
-                                     'msd': (msd.module if h.num_gpus > 1
-                                             else msd).state_dict(),
+                                    {'mcmbd': (mcmbd.module if h.num_gpus > 1
+                                               else mcmbd).state_dict(),
+                                     'msbd': (msbd.module if h.num_gpus > 1
+                                              else msbd).state_dict(),
                                     'optim_g': optim_g.state_dict(), 'optim_d': optim_d.state_dict(), 'steps': steps,
                                     'epoch': epoch})
 
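Note that the do_* checkpoints now store the discriminator weights under 'mcmbd' and 'msbd' rather than 'mpd' and 'msd', so runs started before this commit cannot resume their discriminators. A hypothetical guard when resuming (the checkpoint path is illustrative; old MPD/MSD weights do not fit the new architectures in any case):

state_dict_do = load_checkpoint('cp_avocodo/do_00100000', device)  # illustrative path

if 'mcmbd' in state_dict_do:
    # Post-commit checkpoint: resume both new discriminators.
    mcmbd.load_state_dict(state_dict_do['mcmbd'])
    msbd.load_state_dict(state_dict_do['msbd'])
else:
    # Pre-commit checkpoint keyed 'mpd'/'msd': start the new discriminators
    # (and their optimizer state) from scratch; only the generator carries over.
    print('old-style checkpoint; reinitializing CoMBD/SBD')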
