📚 Update Generator to suit Avocodo
rishikksh20 committed Jun 29, 2022
1 parent c546b8b commit f3e9f55
Showing 4 changed files with 32 additions and 19 deletions.
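
Context for this change set: Avocodo's discriminators consume not only the final waveform but also waveforms projected from intermediate upsampling stages of the generator (the collaborative multi-band discriminator described in the Avocodo paper, Bak et al., 2022, which this repository follows). Accordingly, models.py below adds two 1-channel projection convolutions (out_proj_x1, out_proj_x2) and changes Generator.forward to return (x, x2, x1) instead of a single tensor; inference.py, inference_e2e.py, and train.py are updated for the new return signature.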
2 changes: 1 addition & 1 deletion inference.py
@@ -52,7 +52,7 @@ def inference(a):
             wav = wav / MAX_WAV_VALUE
             wav = torch.FloatTensor(wav).to(device)
             x = get_mel(wav.unsqueeze(0))
-            y_g_hat = generator(x)
+            y_g_hat, _, _ = generator(x)
             audio = y_g_hat.squeeze()
             audio = audio * MAX_WAV_VALUE
             audio = audio.cpu().numpy().astype('int16')
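
Callers that only need audio unpack the new tuple and discard the intermediate projections, as in the hunk above (inference_e2e.py below gets the same treatment). A minimal sketch of the new call contract (the stand-in module, shapes, and upsampling factors are illustrative assumptions, not code from this repository):

import torch

# Stand-in that mimics the updated Generator's return signature (shapes assumed).
class FakeGenerator(torch.nn.Module):
    def forward(self, mel):
        b, _, frames = mel.shape
        wav = torch.tanh(torch.randn(b, 1, frames * 256))  # full-rate waveform
        x2 = torch.randn(b, 1, frames * 128)  # projection after the 3rd upsample stage
        x1 = torch.randn(b, 1, frames * 64)   # projection after the 2nd upsample stage
        return wav, x2, x1

generator = FakeGenerator()
y_g_hat, _, _ = generator(torch.randn(1, 80, 100))  # intermediates unused at inference
audio = (y_g_hat.squeeze() * 32768.0).numpy().astype('int16')  # 32768.0 == MAX_WAV_VALUE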
2 changes: 1 addition & 1 deletion inference_e2e.py
@@ -47,7 +47,7 @@ def inference(a):
         for i, filname in enumerate(filelist):
             x = np.load(os.path.join(a.input_mels_dir, filname))
             x = torch.FloatTensor(x).to(device)
-            y_g_hat = generator(x)
+            y_g_hat, _, _ = generator(x)
             audio = y_g_hat.squeeze()
             audio = audio * MAX_WAV_VALUE
             audio = audio.cpu().numpy().astype('int16')
43 changes: 28 additions & 15 deletions models.py
@@ -78,42 +78,56 @@ def __init__(self, h):
         self.h = h
         self.num_kernels = len(h.resblock_kernel_sizes)
         self.num_upsamples = len(h.upsample_rates)
-        self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3))
+        self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3))
         resblock = ResBlock1 if h.resblock == '1' else ResBlock2
 
         self.ups = nn.ModuleList()
         for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
             self.ups.append(weight_norm(
-                ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)),
-                                k, u, padding=(k-u)//2)))
+                ConvTranspose1d(h.upsample_initial_channel // (2 ** i), h.upsample_initial_channel // (2 ** (i + 1)),
+                                k, u, padding=(k - u) // 2)))
 
         self.resblocks = nn.ModuleList()
         for i in range(len(self.ups)):
-            ch = h.upsample_initial_channel//(2**(i+1))
+            ch = h.upsample_initial_channel // (2 ** (i + 1))
             for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
                 self.resblocks.append(resblock(h, ch, k, d))
 
         self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
+        print(self.conv_post)
         self.ups.apply(init_weights)
         self.conv_post.apply(init_weights)
 
+        self.out_proj_x1 = weight_norm(Conv1d(h.upsample_initial_channel // 4, 1, 7, 1, padding=3))
+        self.out_proj_x2 = weight_norm(Conv1d(h.upsample_initial_channel // 8, 1, 7, 1, padding=3))
+
     def forward(self, x):
+
+        x1 = None
+        x2 = None
         x = self.conv_pre(x)
+
         for i in range(self.num_upsamples):
             x = F.leaky_relu(x, LRELU_SLOPE)
             x = self.ups[i](x)
             xs = None
             for j in range(self.num_kernels):
                 if xs is None:
-                    xs = self.resblocks[i*self.num_kernels+j](x)
+                    xs = self.resblocks[i * self.num_kernels + j](x)
                 else:
-                    xs += self.resblocks[i*self.num_kernels+j](x)
+                    xs += self.resblocks[i * self.num_kernels + j](x)
             x = xs / self.num_kernels
+
+            if i == 1:
+                x1 = self.out_proj_x1(x)
+            elif i == 2:
+                x2 = self.out_proj_x2(x)
+
         x = F.leaky_relu(x)
         x = self.conv_post(x)
         x = torch.tanh(x)
 
-        return x
+        return x, x2, x1
 
     def remove_weight_norm(self):
         print('Removing weight norm...')
@@ -144,7 +158,7 @@ def forward(self, x):
 
         # 1d to 2d
         b, c, t = x.shape
-        if t % self.period != 0: # pad first
+        if t % self.period != 0:  # pad first
             n_pad = self.period - (t % self.period)
             x = F.pad(x, (0, n_pad), "reflect")
             t = t + n_pad
@@ -236,8 +250,8 @@ def forward(self, y, y_hat):
         fmap_gs = []
         for i, d in enumerate(self.discriminators):
             if i != 0:
-                y = self.meanpools[i-1](y)
-                y_hat = self.meanpools[i-1](y_hat)
+                y = self.meanpools[i - 1](y)
+                y_hat = self.meanpools[i - 1](y_hat)
             y_d_r, fmap_r = d(y)
             y_d_g, fmap_g = d(y_hat)
             y_d_rs.append(y_d_r)
@@ -254,16 +268,16 @@ def feature_loss(fmap_r, fmap_g):
         for rl, gl in zip(dr, dg):
             loss += torch.mean(torch.abs(rl - gl))
 
-    return loss*2
+    return loss * 2
 
 
 def discriminator_loss(disc_real_outputs, disc_generated_outputs):
     loss = 0
     r_losses = []
     g_losses = []
     for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
-        r_loss = torch.mean((1-dr)**2)
-        g_loss = torch.mean(dg**2)
+        r_loss = torch.mean((1 - dr) ** 2)
+        g_loss = torch.mean(dg ** 2)
         loss += (r_loss + g_loss)
         r_losses.append(r_loss.item())
         g_losses.append(g_loss.item())
@@ -275,9 +289,8 @@ def generator_loss(disc_outputs):
     loss = 0
     gen_losses = []
     for dg in disc_outputs:
-        l = torch.mean((1-dg)**2)
+        l = torch.mean((1 - dg) ** 2)
         gen_losses.append(l)
         loss += l
 
     return loss, gen_losses
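
The core of the models.py change is the tap-and-project pattern in Generator.forward: after the second and third upsampling stages (i == 1 and i == 2) the feature map is projected to a 1-channel waveform and returned alongside the full-resolution output. A runnable toy version of just that pattern (channel count, upsample rates, and the leaky-ReLU slope here are assumptions, not values from this repository's config):

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyUpsampler(nn.Module):
    def __init__(self, ch=256, rates=(8, 8, 2, 2)):
        super().__init__()
        # Each stage halves the channels and upsamples time by its rate.
        self.ups = nn.ModuleList(
            nn.ConvTranspose1d(ch // (2 ** i), ch // (2 ** (i + 1)),
                               r * 2, r, padding=r // 2)
            for i, r in enumerate(rates))
        self.out_proj_x1 = nn.Conv1d(ch // 4, 1, 7, 1, padding=3)  # tap after stage 2
        self.out_proj_x2 = nn.Conv1d(ch // 8, 1, 7, 1, padding=3)  # tap after stage 3
        self.conv_post = nn.Conv1d(ch // (2 ** len(rates)), 1, 7, 1, padding=3)

    def forward(self, x):
        x1 = x2 = None
        for i, up in enumerate(self.ups):
            x = up(F.leaky_relu(x, 0.1))
            if i == 1:
                x1 = self.out_proj_x1(x)  # lowest-rate intermediate waveform
            elif i == 2:
                x2 = self.out_proj_x2(x)  # mid-rate intermediate waveform
        return torch.tanh(self.conv_post(x)), x2, x1

m = TinyUpsampler()
wav, x2, x1 = m(torch.randn(1, 256, 10))
print(wav.shape, x2.shape, x1.shape)  # temporal resolution: x1 < x2 < wav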

4 changes: 2 additions & 2 deletions train.py
@@ -119,7 +119,7 @@ def train(rank, a, h):
             y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True))
             y = y.unsqueeze(1)
 
-            y_g_hat = generator(x)
+            y_g_hat, x2, x1 = generator(x)
             y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size,
                                           h.fmin, h.fmax_for_loss)

@@ -191,7 +191,7 @@ def train(rank, a, h):
                 with torch.no_grad():
                     for j, batch in enumerate(validation_loader):
                         x, y, _, y_mel = batch
-                        y_g_hat = generator(x.to(device))
+                        y_g_hat, _, _ = generator(x.to(device))
                         y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True))
                         y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate,
                                                       h.hop_size, h.win_size,
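
Note that this hunk only unpacks x2 and x1; the wiring of those intermediate outputs into the discriminator losses is not part of the diff shown here. In the Avocodo paper they are judged by the collaborative multi-band discriminator against downsampled real audio (obtained there via PQMF analysis). Purely as a hypothetical illustration of that pairing, with average pooling standing in for PQMF and all shapes assumed:

import torch
import torch.nn.functional as F

def downsample(y, factor):
    # Crude stand-in for PQMF analysis: average-pool to the projection's rate.
    return F.avg_pool1d(y, kernel_size=factor, stride=factor)

y = torch.randn(4, 1, 8192)        # reference audio (B, 1, T)
x2 = torch.randn(4, 1, 8192 // 2)  # hypothetical half-rate projection
x1 = torch.randn(4, 1, 8192 // 4)  # hypothetical quarter-rate projection
pairs = [(x2, downsample(y, 2)), (x1, downsample(y, 4))]  # real/fake pairs per rate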
