Skip to content

Commit

Permalink
🚚 Finished up Discriminator implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
rishikksh20 committed Jun 30, 2022
1 parent b09716e commit f20ac70
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 18 deletions.
34 changes: 18 additions & 16 deletions models.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,8 +304,8 @@ def forward(self, x, x2, x1, x_hat, x2_hat, x1_hat):
y_hat.append(p1_hat_)
fmap_hat.append(p1_fmap_hat_)

x2_ = self.pqmf_2(x) # Select first band
x1_ = self.pqmf_4(x) # Select first band after figuring out shape of ouput PQMF
x2_ = self.pqmf_2(x)[:, :1, :] # Select first band
x1_ = self.pqmf_4(x)[:, :1, :] # Select first band

x2_hat_ = self.pqmf_2(x_hat)
x1_hat_ = self.pqmf_4(x_hat)
Expand All @@ -331,23 +331,25 @@ def forward(self, x, x2, x1, x_hat, x2_hat, x1_hat):
class MultiSubBandDiscriminator(torch.nn.Module):

def __init__(self, tkernels, fkernel, tchannels, fchannels, tstrides, fstride, tdilations, fdilations, tsubband,
n, m):
n, m, freq_init_ch):

super(MultiSubBandDiscriminator, self).__init__()

self.fsbd = SubBandDiscriminator(channels=fchannels, kernel=fkernel, strides=fstride, dilations=fdilations)
self.fsbd = SubBandDiscriminator(init_channel=freq_init_ch, channels=fchannels, kernel=fkernel,
strides=fstride, dilations=fdilations)

self.tsbd1 = SubBandDiscriminator(channels=tchannels, kernel=tkernels[0], strides=tstrides[0],
dilations=tdilations[0])
self.tsubband1 = tsubband[0]
self.tsbd1 = SubBandDiscriminator(init_channel=self.tsubband1, channels=tchannels, kernel=tkernels[0],
strides=tstrides[0], dilations=tdilations[0])

self.tsbd2 = SubBandDiscriminator(channels=tchannels, kernel=tkernels[1], strides=tstrides[1],
dilations=tdilations[1])
self.tsubband2 = tsubband[1]
self.tsbd2 = SubBandDiscriminator(init_channel=self.tsubband2, channels=tchannels, kernel=tkernels[1],
strides=tstrides[1], dilations=tdilations[1])

self.tsbd3 = SubBandDiscriminator(channels=tchannels, kernel=tkernels[2], strides=tstrides[2],
dilations=tdilations[2])
self.tsubband3 = tsubband[2]
self.tsbd3 = SubBandDiscriminator(init_channel=self.tsubband3, channels=tchannels, kernel=tkernels[2],
strides=tstrides[2], dilations=tdilations[2])


self.pqmf_n = PQMF(N=n, taps=256, cutoff=0.03, beta=10.0)
self.pqmf_m = PQMF(N=m, taps=256, cutoff=0.1, beta=9.0)
Expand All @@ -362,22 +364,22 @@ def forward(self, x, x_hat):
xn = self.pqmf_n(x)
xn_hat = self.pqmf_n(x_hat)

q3, feat_q3 = self.tsbd3(xn)
q3_hat, feat_q3_hat = self.tsbd3(xn_hat)
q3, feat_q3 = self.tsbd3(xn[:, :self.tsubband3, :])
q3_hat, feat_q3_hat = self.tsbd3(xn_hat[:, :self.tsubband3, :])
y.append(q3)
y_hat.append(q3_hat)
fmap.append(feat_q3)
fmap_hat.append(feat_q3_hat)

q2, feat_q2 = self.tsbd2(xn)
q2_hat, feat_q2_hat = self.tsbd2(xn_hat)
q2, feat_q2 = self.tsbd2(xn[:, :self.tsubband2, :])
q2_hat, feat_q2_hat = self.tsbd2(xn_hat[:, :self.tsubband2, :])
y.append(q2)
y_hat.append(q2_hat)
fmap.append(feat_q2)
fmap_hat.append(feat_q2_hat)

q1, feat_q1 = self.tsbd1(xn)
q1_hat, feat_q1_hat = self.tsbd1(xn_hat)
q1, feat_q1 = self.tsbd1(xn[:, :self.tsubband1, :])
q1_hat, feat_q1_hat = self.tsbd1(xn_hat[:, :self.tsubband1, :])
y.append(q1)
y_hat.append(q1_hat)
fmap.append(feat_q1)
Expand Down
4 changes: 2 additions & 2 deletions modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,12 @@ def forward(self, x):

class SubBandDiscriminator(torch.nn.Module):

def __init__(self, channels, kernel, strides, dilations, use_spectral_norm=False):
def __init__(self, init_channel, channels, kernel, strides, dilations, use_spectral_norm=False):
super(SubBandDiscriminator, self).__init__()
norm_f = weight_norm if use_spectral_norm == False else spectral_norm

self.mdcs = torch.nn.ModuleList()
init_channel = 1

for c, s, d in zip(channels, strides, dilations):
self.mdcs.append(MDC(init_channel, c, kernel, s, d))
init_channel = c
Expand Down

0 comments on commit f20ac70

Please sign in to comment.