Commit 6e593c1

change encoder and decoder
nick-yf committed May 21, 2023
1 parent 44febe0 commit 6e593c1
Showing 1 changed file with 63 additions and 39 deletions.
VAE/CVAE.py: 102 changes (63 additions, 39 deletions)
@@ -12,10 +12,14 @@ def forward(self, x):
         return x.view(x.size(0), -1)

 class UnFlatten(nn.Module):
+    def __init__(self, channel:int = 256) -> None:
+        super().__init__()
+        self.channel = channel
+
     def forward(self, x):
         B, D = x.shape
-        H = W = int(np.sqrt(D / 256))
-        return x.view(B, 256, H, W)
+        H = W = int(np.sqrt(D / self.channel))
+        return x.view(B, self.channel, H, W)
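The change above turns the hardcoded channel count of 256 into a constructor argument. A minimal round-trip sketch of the new behavior (shapes here are illustrative, not taken from the repository):

```python
import torch

# With channel=256 and a (B, 256*16*16) input, D / channel = 256,
# so H = W = 16 and the view restores a (B, 256, 16, 16) feature map.
x = torch.randn(4, 256 * 16 * 16)
unflatten = UnFlatten(channel=256)  # the class defined above
assert unflatten(x).shape == (4, 256, 16, 16)
```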

 def posenc_nerf(x, min_deg=0, max_deg=15):
     """Concatenate x and its positional encodings, following NeRF."""
@@ -28,22 +32,13 @@ def posenc_nerf(x, min_deg=0, max_deg=15):
     emb = torch.sin(torch.concat([xb, xb + torch.pi / 2.], dim=-1))

     return torch.concat([x, emb], dim=-1)
-# class PoseCondition(nn.Module):
-#     def __init__(self) -> None:
-#         # TODO: Decide to imitate the method done in XUnet.
-#         # By using the camera model in visu3d, we can get the ray info for all the pixels in the image.
-#         # The ray origin is 3-dim vector and direction is also 3-dim vector.
-#         # If we just concate the ray info for all the pixel in the image, we can get a tensor in shape (H, W, 6), 6 can be seen as channel?
-#         # NeRF PE is applied on ray origin and ray direction.
-#         # CNN is also used to change the spatial size to be the same as the downsampled image during VAE processing.
-#         super().__init__()

 class PoseConditionProcessor(torch.nn.Module):

     def __init__(self, emb_ch, H, W,
                  num_resolutions,
-                 use_pos_emb=False,
-                 use_ref_pose_emb=False):
+                 use_pos_emb=True,
+                 use_ref_pose_emb=True):

         super().__init__()
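Only the tail of posenc_nerf is visible in this hunk. Assuming xb stacks x scaled by 2^k for k in [min_deg, max_deg) as in NeRF, the concat of sin(xb) and sin(xb + pi/2) = cos(xb) gives an output of dimension d + 2 * d * (max_deg - min_deg):

```python
# Hedged dimension check for posenc_nerf (the construction of xb is
# outside this hunk; this assumes the standard NeRF frequency stack).
d, min_deg, max_deg = 3, 0, 15             # e.g. a 3-dim ray origin or direction
out_dim = d + 2 * d * (max_deg - min_deg)  # x itself, plus sin and cos bands
assert out_dim == 93
```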

@@ -59,12 +54,12 @@ def __init__(self, emb_ch, H, W,
             self.pos_emb = torch.nn.Parameter(torch.zeros(D, H, W), requires_grad=True)
             torch.nn.init.normal_(self.pos_emb, std=(1 / np.sqrt(D)))

-        # if use_ref_pose_emb:
-        #     self.first_emb = torch.nn.Parameter(torch.zeros(1, 1, D, 1, 1), requires_grad=True)
-        #     torch.nn.init.normal_(self.first_emb, std=(1 / np.sqrt(D)))
+        if use_ref_pose_emb:
+            self.first_emb = torch.nn.Parameter(torch.zeros(1, 1, D, 1, 1), requires_grad=True)
+            torch.nn.init.normal_(self.first_emb, std=(1 / np.sqrt(D)))

-        #     self.other_emb = torch.nn.Parameter(torch.zeros(1, 1, D, 1, 1), requires_grad=True)
-        #     torch.nn.init.normal_(self.other_emb, std=(1 / np.sqrt(D)))
+            self.other_emb = torch.nn.Parameter(torch.zeros(1, 1, D, 1, 1), requires_grad=True)
+            torch.nn.init.normal_(self.other_emb, std=(1 / np.sqrt(D)))

         convs = []
         for i_level in range(self.num_resolutions):
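The newly activated first_emb and other_emb are learned (1, 1, D, 1, 1) parameters, so they broadcast over batch, view, and spatial axes; their defaults also flip to enabled in the hunk above. How they are consumed is outside this diff; a plausible XUnet-style use (purely an assumption) is to add a view-role offset to the pose embeddings:

```python
import torch

def add_view_role_emb(pose_emb: torch.Tensor,
                      first_emb: torch.Tensor,
                      other_emb: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: pose_emb is (B, n_views, D, H, W); the
    (1, 1, D, 1, 1) parameters broadcast, tagging view 0 as the
    reference view and the rest as target views."""
    return torch.cat(
        [pose_emb[:, :1] + first_emb,   # reference (first) view
         pose_emb[:, 1:] + other_emb],  # all other views
        dim=1,
    )
```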
@@ -157,9 +152,12 @@ def __init__(self, H: int = 128, W: int = 128, z_dim: int = 128, n_resolution: i
         self.condition_processor = PoseConditionProcessor(emb_ch, H, W, n_resolution)
         # TODO: Now hardcode for layers, change to list
         self.ec1 = EncoderBlock(3, 32, H, W)
-        self.ec2 = EncoderBlock(32 + emb_ch, 64, H // 2, W // 2)
-        self.ec3 = EncoderBlock(64 + emb_ch, 128, H // 4, W // 4)
-        self.ec4 = EncoderBlock(128 + emb_ch, 256, H // 8, W // 8)
+        # self.ec2 = EncoderBlock(32 + emb_ch, 64, H // 2, W // 2)
+        # self.ec3 = EncoderBlock(64 + emb_ch, 128, H // 4, W // 4)
+        # self.ec4 = EncoderBlock(128 + emb_ch, 256, H // 8, W // 8)
+        self.ec2 = EncoderBlock(32, 64, H // 2, W // 2)
+        self.ec3 = EncoderBlock(64, 128, H // 4, W // 4)
+        self.ec4 = EncoderBlock(128, 256, H // 8, W // 8)

         self.flatten = Flatten()
         self.fc1 = nn.Linear(256 * (H // 16) * (W // 16), 2*z_dim) # for mu, logvar
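With the pose embeddings dropped from the encoder inputs, the channel progression is a plain 3 -> 32 -> 64 -> 128 -> 256. Assuming each EncoderBlock halves the spatial size (consistent with the H // 2, H // 4, H // 8 arguments), the flattened feature entering fc1 checks out for the default H = W = 128:

```python
# Shape trace under the halving assumption (EncoderBlock internals are
# not shown in this diff):
#   x    (B,   3, 128, 128)
#   ec1  (B,  32,  64,  64)
#   ec2  (B,  64,  32,  32)
#   ec3  (B, 128,  16,  16)
#   ec4  (B, 256,   8,   8)
H = W = 128
fc1_in = 256 * (H // 16) * (W // 16)   # matches the nn.Linear above
assert fc1_in == 256 * 8 * 8 == 16384
```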
@@ -183,12 +181,12 @@ def bottle_neck(self, x):

     def encode(self, x, pose_embeds):
         out1 = self.ec1(x)
-        input2 = torch.concat([out1, pose_embeds[0][:,0,:]], dim=1)
-        out2 = self.ec2(input2)
-        input3 = torch.concat([out2, pose_embeds[1][:,0,:]], dim=1)
-        out3 = self.ec3(input3)
-        input4 = torch.concat([out3, pose_embeds[2][:,0,:]], dim=1)
-        out4 = self.ec4(input4)
+        # input2 = torch.concat([out1, pose_embeds[0][:,0,:]], dim=1)
+        out2 = self.ec2(out1)
+        # input3 = torch.concat([out2, pose_embeds[1][:,0,:]], dim=1)
+        out3 = self.ec3(out2)
+        # input4 = torch.concat([out3, pose_embeds[2][:,0,:]], dim=1)
+        out4 = self.ec4(out3)
         z_out = self.fc1(self.flatten(out4))
         return z_out[:,:self.z_dim], z_out[:,self.z_dim:]
@@ -200,36 +198,62 @@ def reparaterize(self, mu, logvar):
     def decode(self, z, pose_embeds):
         input1 = self.fc2(z)
         out1 = self.dc1(self.unflatten(input1))
-        input2 = torch.concat([out1, pose_embeds[2][:,1,:]], dim=1)
-        out2 = self.dc2(input2)
-        input3 = torch.concat([out2, pose_embeds[1][:,1,:]], dim=1)
-        out3 = self.dc3(input3)
-        input4 = torch.concat([out3, pose_embeds[0][:,1,:]], dim=1)
-        out4 = self.dc4(input4)
-        return out4

+        # generate new image
+        input2_z = torch.concat([out1, pose_embeds[2][:,1]], dim=1)
+        out2_z = self.dc2(input2_z)
+        input3_z = torch.concat([out2_z, pose_embeds[1][:,1]], dim=1)
+        out3_z = self.dc3(input3_z)
+        input4_z = torch.concat([out3_z, pose_embeds[0][:,1]], dim=1)
+        out4_z = self.dc4(input4_z)
+
+        # reconstruct input image
+        input2_x = torch.concat([out1, pose_embeds[2][:,0]], dim=1)
+        out2_x = self.dc2(input2_x)
+        input3_x = torch.concat([out2_x, pose_embeds[1][:,0]], dim=1)
+        out3_x = self.dc3(input3_x)
+        input4_x = torch.concat([out3_x, pose_embeds[0][:,0]], dim=1)
+        out4_x = self.dc4(input4_x)
+
+        return out4_z, out4_x
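The rewritten decode reuses the same dc2 to dc4 stack for two passes from the shared out1: index 1 on the per-view axis of each pose embedding drives novel-view generation, index 0 drives input reconstruction. This reading is inferred from how forward pairs the outputs with batch['z'] and batch['x'] below:

```python
# Hedged reading of the new return contract (model: this VAE instance):
img_gen, img_recon = model.decode(z, pose_embeds)
# img_gen   is decoded with target-view embeddings (pose_embeds[i][:, 1])
# img_recon is decoded with input-view embeddings  (pose_embeds[i][:, 0])
```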


     def forward(self, batch, cond_mask=None):
         pose_embeds = self.condition_processor(batch, cond_mask)
         # print([pose_embeds[i].shape for i in range(3)])
         x = batch['x']
+        gt = batch['z']
         z_mu, z_logvar = self.encode(x, pose_embeds)
         z = self.reparaterize(z_mu, z_logvar)
-        img_recon = self.decode(z, pose_embeds)
-        return self.loss(z_mu, z_logvar, img_recon, x)
+        img_gen, img_recon = self.decode(z, pose_embeds)
+        return self.loss(z_mu, z_logvar, img_gen, gt, img_recon, x)

-    def loss(self, z_mu, z_logvar, img_recon, img_gt):
+    def loss(self, z_mu, z_logvar, img_gen, img_gt, img_recon, img_input):
         kld = torch.mean(
             -0.5 * torch.sum(1 + z_logvar - z_mu.pow(2) - z_logvar.exp(), dim=1), dim=0
         )
-        img_loss = ((img_gt - img_recon)**2).mean()
+        img_loss = ((img_gen - img_gt)**2).mean()
+        img_loss += ((img_recon - img_input)**2).mean()
         return self.beta * kld , img_loss
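In equation form, loss now returns the weighted KL term and the sum of two mean-squared pixel terms separately (the caller combines them). The KL is the standard closed form for a diagonal Gaussian against N(0, I), averaged over the batch:

```latex
\mathrm{KL} = \frac{1}{B} \sum_b -\frac{1}{2} \sum_j
    \big(1 + \log\sigma_{bj}^2 - \mu_{bj}^2 - \sigma_{bj}^2\big),
\qquad
\mathcal{L}_{\mathrm{img}} =
    \operatorname{mean}\!\big((\hat{x}_{\mathrm{gen}} - x_{\mathrm{gt}})^2\big)
  + \operatorname{mean}\!\big((\hat{x}_{\mathrm{recon}} - x_{\mathrm{in}})^2\big)
```

The method returns the pair (beta * KL, L_img), matching the code above.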

     def eval_img(self, batch, cond_mask=None):
         pose_embeds = self.condition_processor(batch, cond_mask)
         x = batch['x']
         z_mu, z_logvar = self.encode(x, pose_embeds)
-        img_recon = self.decode(z_mu, pose_embeds)
+        img_recon, _ = self.decode(z_mu, pose_embeds)
         return img_recon

+# class LabelMapping(nn.Module):
+#     def __init__(self, channel:int=8, category:int=0, f_num:int=1) -> None:
+#         super().__init__()
+#         self.channel = channel
+#         self.category = category
+#         self.f_num = f_num
+#         self.fc = nn.Linear(channel, f_num * category)
+#         self.softmax = nn.Softmax(dim=1)
+
+# class ConditionalDeformableVAE(nn.Module):
+#     def __init__(self, *args, **kwargs) -> None:
+#         super().__init__(*args, **kwargs)
