diff --git a/model/again.py b/model/again.py index 2621ccc..9df6ef7 100644 --- a/model/again.py +++ b/model/again.py @@ -134,10 +134,6 @@ def calc_mean_std(self, x, mask=None): def forward(self, x, return_mean_std=False): - """ - :param x: has either shape (b, c, x, y) or shape (b, c, x, y, z) - :return: - """ mean, std = self.calc_mean_std(x) x = (x - mean) / std if return_mean_std: @@ -349,7 +345,6 @@ class Activation(nn.Module): } def __init__(self, act, params=None): super().__init__() - self.act = Activation.dct[act](**params) def forward(self, x): @@ -372,8 +367,11 @@ def forward(self, x, x_cond=None): x, x_cond = x[:,None,:,:], x_cond[:,None,:,:] - enc = self.encoder(x) # , mask=x_mask) - cond = self.encoder(x_cond) #, mask=cond_mask) + enc, mns_enc, sds_enc = self.encoder(x) # , mask=x_mask) + cond, mns_cond, sds_cond = self.encoder(x_cond) #, mask=cond_mask) + + enc = (self.act(enc), mns_enc, sds_enc) + cond = (self.act(cond), mns_cond, sds_cond) y = self.decoder(enc, cond) return y @@ -401,8 +399,11 @@ def inference(self, source, target): x = x[:,None,:,:] x_cond = x_cond[:,None,:,:] - enc = self.encoder(x) - cond = self.encoder(x_cond) + enc, mns_enc, sds_enc = self.encoder(x) # , mask=x_mask) + cond, mns_cond, sds_cond = self.encoder(x_cond) #, mask=cond_mask) + + enc = (self.act(enc), mns_enc, sds_enc) + cond = (self.act(cond), mns_cond, sds_cond) y = self.decoder(enc, cond)