Skip to content

Commit

Permalink
fix act
Browse files Browse the repository at this point in the history
  • Loading branch information
KimythAnly committed Nov 18, 2020
1 parent 951fe98 commit f657a4c
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions model/again.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,6 @@ def calc_mean_std(self, x, mask=None):


def forward(self, x, return_mean_std=False):
"""
:param x: has either shape (b, c, x, y) or shape (b, c, x, y, z)
:return:
"""
mean, std = self.calc_mean_std(x)
x = (x - mean) / std
if return_mean_std:
Expand Down Expand Up @@ -349,7 +345,6 @@ class Activation(nn.Module):
}
def __init__(self, act, params=None):
super().__init__()

self.act = Activation.dct[act](**params)

def forward(self, x):
Expand All @@ -372,8 +367,11 @@ def forward(self, x, x_cond=None):

x, x_cond = x[:,None,:,:], x_cond[:,None,:,:]

enc = self.encoder(x) # , mask=x_mask)
cond = self.encoder(x_cond) #, mask=cond_mask)
enc, mns_enc, sds_enc = self.encoder(x) # , mask=x_mask)
cond, mns_cond, sds_cond = self.encoder(x_cond) #, mask=cond_mask)

enc = (self.act(enc), mns_enc, sds_enc)
cond = (self.act(cond), mns_cond, sds_cond)

y = self.decoder(enc, cond)
return y
Expand Down Expand Up @@ -401,8 +399,11 @@ def inference(self, source, target):
x = x[:,None,:,:]
x_cond = x_cond[:,None,:,:]

enc = self.encoder(x)
cond = self.encoder(x_cond)
enc, mns_enc, sds_enc = self.encoder(x) # , mask=x_mask)
cond, mns_cond, sds_cond = self.encoder(x_cond) #, mask=cond_mask)

enc = (self.act(enc), mns_enc, sds_enc)
cond = (self.act(cond), mns_cond, sds_cond)

y = self.decoder(enc, cond)

Expand Down

0 comments on commit f657a4c

Please sign in to comment.