Commit 83bd2ad

enc dec enlarged

nick-yf committed May 23, 2023
1 parent 1e80eed commit 83bd2ad

Showing 2 changed files with 106 additions and 29 deletions.
120 changes: 92 additions & 28 deletions VAE/CVAE.py
@@ -33,6 +33,43 @@ def posenc_nerf(x, min_deg=0, max_deg=15):

    return torch.concat([x, emb], dim=-1)

+ def rt_to_quaternion(R: torch.Tensor, t: torch.Tensor):
+     """Converts a rotation matrix and translation vector to a quaternion.
+
+     Note: t is currently unused; the quaternion depends only on R.
+     """
+     # R: [3, 3] rotation matrix (torch.trace expects a single 2-D matrix)
+     # t: [3] translation vector (unused)
+     # q: [4] quaternion, ordered (w, x, y, z)
+     trace = torch.trace(R)
+     trace = torch.clamp(trace, min=-1, max=3)
+     if trace > 0:
+         qw = 0.5 * torch.sqrt(1 + trace)
+         qx = (R[2, 1] - R[1, 2]) / (4 * qw)
+         qy = (R[0, 2] - R[2, 0]) / (4 * qw)
+         qz = (R[1, 0] - R[0, 1]) / (4 * qw)
+     else:
+         max_diag = torch.argmax(torch.diag(R))
+         if max_diag == 0:
+             qx = 0.5 * torch.sqrt(1 + R[0, 0] - R[1, 1] - R[2, 2])
+             qy = (R[0, 1] + R[1, 0]) / (4 * qx)
+             qz = (R[0, 2] + R[2, 0]) / (4 * qx)
+             qw = (R[2, 1] - R[1, 2]) / (4 * qx)
+         elif max_diag == 1:
+             qy = 0.5 * torch.sqrt(1 + R[1, 1] - R[0, 0] - R[2, 2])
+             qx = (R[0, 1] + R[1, 0]) / (4 * qy)
+             qz = (R[1, 2] + R[2, 1]) / (4 * qy)
+             qw = (R[0, 2] - R[2, 0]) / (4 * qy)
+         else:  # max_diag == 2; argmax over a 3-vector cannot return anything else
+             qz = 0.5 * torch.sqrt(1 + R[2, 2] - R[0, 0] - R[1, 1])
+             qx = (R[0, 2] + R[2, 0]) / (4 * qz)
+             qy = (R[1, 2] + R[2, 1]) / (4 * qz)
+             qw = (R[1, 0] - R[0, 1]) / (4 * qz)
+     q = torch.tensor([qw, qx, qy, qz])
+     return q
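
A quick sanity check of the conversion (illustrative only, not part of the commit): the identity rotation should map to the identity quaternion (w, x, y, z) = (1, 0, 0, 0).

    R = torch.eye(3)            # identity rotation
    t = torch.zeros(3)          # translation is ignored by the current implementation
    q = rt_to_quaternion(R, t)  # tensor([1., 0., 0., 0.])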

class PoseConditionProcessor(torch.nn.Module):

    def __init__(self, emb_ch, H, W,

@@ -116,27 +153,43 @@ def forward(self, batch, cond_mask):
class EncoderBlock(nn.Module):
    def __init__(self, in_channel:int, out_channel:int, input_h:int, input_w:int) -> None:
        super().__init__()
-         self.conv = nn.Conv2d(in_channel, out_channel, 5, 2, 2)
-         self.norm = nn.LayerNorm([out_channel, input_h // 2, input_w // 2])
+         self.conv1 = nn.Conv2d(in_channel, out_channel, 5, 1, 2)
+         self.conv2 = nn.Conv2d(out_channel, out_channel, 5, 2, 2)
+         self.norm1 = nn.LayerNorm([out_channel, input_h, input_w])
+         self.norm2 = nn.LayerNorm([out_channel, input_h // 2, input_w // 2])
        self.relu = nn.ReLU()

    def forward(self, x):
-         x = self.conv(x)
-         x = self.norm(x)
+         x = self.conv1(x)
+         x = self.norm1(x)
        x = self.relu(x)
+         x = self.conv2(x)
+         x = self.norm2(x)
+         x = self.relu(x)
        return x
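
Illustrative shape check (assuming a 128x128 RGB input; not part of the commit): each EncoderBlock now applies a stride-1 conv before the stride-2 conv, so it still halves the spatial resolution but with an extra layer of capacity.

    block = EncoderBlock(3, 32, 128, 128)
    x = torch.randn(1, 3, 128, 128)
    y = block(x)                # y.shape == torch.Size([1, 32, 64, 64])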

class DecoderBlock(nn.Module):
-     def __init__(self, in_channel:int, out_channel:int, input_h:int, input_w:int) -> None:
+     def __init__(self, in_channel:int, out_channel:int, input_h:int, input_w:int, end:bool=False) -> None:
        super().__init__()
-         self.conv = nn.ConvTranspose2d(in_channel, out_channel, 5, 2, 2, output_padding=1)
-         self.layer_norm = nn.LayerNorm([out_channel, input_h * 2, input_w * 2])
-         self.relu = nn.ReLU()
+         # self.end = end
+         self.conv1 = nn.Conv2d(in_channel, out_channel, 5, 1, 2)
+         self.conv2 = nn.ConvTranspose2d(out_channel, out_channel, 5, 2, 2, output_padding=1)
+         self.norm1 = nn.LayerNorm([out_channel, input_h, input_w])
+         self.norm2 = nn.LayerNorm([out_channel, input_h * 2, input_w * 2])
+         self.used_relu = nn.ReLU()
+         if not end:
+             self.relu = nn.ReLU()
+         else:
+             self.relu = None

    def forward(self, x):
-         x = self.conv(x)
-         x = self.layer_norm(x)
-         x = self.relu(x)
+         x = self.conv1(x)
+         x = self.norm1(x)
+         x = self.used_relu(x)
+         x = self.conv2(x)
+         x = self.norm2(x)
+         if self.relu:
+             x = self.relu(x)
        return x
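
Correspondingly (illustrative, not part of the commit), each DecoderBlock doubles the spatial resolution, and the new end flag disables the final ReLU on the last block, so the output is not clamped to non-negative values and can cover the [-1, 1] image range used in train.py.

    block = DecoderBlock(32, 3, 64, 64, end=True)
    z = torch.randn(1, 32, 64, 64)
    y = block(z)                # y.shape == torch.Size([1, 3, 128, 128]); may contain negatives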

class ConditionalVAE(nn.Module):

@@ -167,7 +220,7 @@ def __init__(self, H: int = 128, W: int = 128, z_dim: int = 128, n_resolution: i
        self.dc1 = DecoderBlock(256, 128, H // 16, W // 16)
        self.dc2 = DecoderBlock(128 + emb_ch, 64, H // 8, W // 8)
        self.dc3 = DecoderBlock(64 + emb_ch, 32, H // 4, W // 4)
-         self.dc4 = DecoderBlock(32 + emb_ch, 3, H // 2, W // 2)
+         self.dc4 = DecoderBlock(32 + emb_ch, 3, H // 2, W // 2, True)

    def bottle_neck(self, x):
        assert len(x.shape) == 2
@@ -181,11 +234,11 @@ def bottle_neck(self, x):

    def encode(self, x, pose_embeds):
        out1 = self.ec1(x)
-         # input2 = torch.concat([out1, pose_embeds[0][:,0,:]], dim=1)
+         # input2 = torch.concat([out1, pose_embeds[0][:,0]], dim=1)
        out2 = self.ec2(out1)
-         # input3 = torch.concat([out2, pose_embeds[1][:,0,:]], dim=1)
+         # input3 = torch.concat([out2, pose_embeds[1][:,0]], dim=1)
        out3 = self.ec3(out2)
-         # input4 = torch.concat([out3, pose_embeds[2][:,0,:]], dim=1)
+         # input4 = torch.concat([out3, pose_embeds[2][:,0]], dim=1)
        out4 = self.ec4(out3)
        z_out = self.fc1(self.flatten(out4))
        return z_out[:,:self.z_dim], z_out[:,self.z_dim:]
Expand Down Expand Up @@ -232,8 +285,9 @@ def loss(self, z_mu, z_logvar, img_gen, img_gt, img_recon, img_input):
        kld = torch.mean(
            -0.5 * torch.sum(1 + z_logvar - z_mu.pow(2) - z_logvar.exp(), dim=1), dim=0
        )
-         img_loss = ((img_gen - img_gt)**2).sum(dim=(1,2,3)).mean()
-         img_loss += ((img_recon - img_input)**2).sum(dim=(1,2,3)).mean()
+         # img_loss = ((img_gen - img_gt)**2).sum(dim=(1,2,3)).mean()
+         # img_loss += ((img_recon - img_input)**2).sum(dim=(1,2,3)).mean()
+         img_loss = F.mse_loss(img_gen, img_gt) + F.mse_loss(img_recon, img_input)
        return self.beta * kld, img_loss
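
One consequence of this change (an observation, not part of the commit): F.mse_loss defaults to mean reduction over every element, while the commented-out version summed over channels and pixels before averaging over the batch, so the image term shrinks by a constant factor, which also changes its weight relative to the beta-scaled KLD.

    # For x, y of shape [B, C, H, W]:
    #   ((x - y)**2).sum(dim=(1, 2, 3)).mean() == F.mse_loss(x, y) * (C * H * W)
    # With the default 3x128x128 images the factor is 49152.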

    def eval_img(self, batch, cond_mask=None):

@@ -243,17 +297,27 @@

        pred_img, recon_img = self.decode(z_mu, pose_embeds)
        return pred_img, recon_img

- # class LabelMapping(nn.Module):
- #     def __init__(self, channel:int=8, category:int=0, f_num:int=1) -> None:
- #         super().__init__()
- #         self.channel = channel
- #         self.category = category
- #         self.f_num = f_num
- #         self.fc = nn.Linear(channel, f_num * category)
- #         self.softmax = nn.Softmax(dim=1)
+ class PoseMapping(nn.Module):
+     '''Map the pose (quaternion) to two vectors'''
+     def __init__(self, embed:int = 64) -> None:
+         super().__init__()
+         self.fc_x = nn.Linear(4, embed)
+         self.fc_y = nn.Linear(4, embed)
+
+     def forward(self, x):
+         assert len(x.shape) == 2
+         assert x.shape[1] == 4
+         return self.fc_x(x), self.fc_y(x)
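
Minimal usage sketch for the new class (illustrative, not part of the commit):

    mapping = PoseMapping(embed=64)
    q = torch.randn(8, 4)        # a batch of 8 quaternions
    vec_x, vec_y = mapping(q)    # two [8, 64] embeddings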

+ class CDVAEEncBlock(nn.Module):
+     def __init__(self) -> None:
+         super().__init__()

- # class ConditionalDeformableVAE(nn.Module):
- #     def __init__(self, *args, **kwargs) -> None:
- #         super().__init__(*args, **kwargs)

+ class ConditionalDeformableVAE(nn.Module):
+     def __init__(self, pose_embed:int=64) -> None:
+         super().__init__()
+         self.pose_embed = pose_embed
+
+         self.pose_mapping_enc = PoseMapping(pose_embed)
+         self.pose_mapping_dec = PoseMapping(pose_embed)
15 changes: 14 additions & 1 deletion VAE/train.py
@@ -34,7 +34,11 @@ def main(args):
    optimizer = Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.99))

    if args.transfer == "":
-         now = './results/shapenet_SRN_car/' + str(int(time.time()))
+         if args.exp_name is None:
+             now = './results/shapenet_SRN_car/' + str(int(time.time()))
+         else:
+             now = './results/shapenet_SRN_car/' + args.exp_name
+         # now = './results/shapenet_SRN_car/' + "no_nonlinear"
        writer = SummaryWriter(now)
        step = 0
    else:
@@ -46,8 +50,10 @@
        optimizer.load_state_dict(ckpt['optim'])

        now = args.transfer
+         # now = args.transfer + "_val"
        writer = SummaryWriter(now)
        step = ckpt['step']
+         # validation(model, loader, writer, step, args.batch_size)
    train(model, optimizer, loader, loader_val, writer, now, step, args)


Expand Down Expand Up @@ -115,6 +121,12 @@ def validation(model, loader_val, writer, step, batch_size=8, device='cuda'):
        pred_img, recon_img = model.module.eval_img(batch, None)
        pred_img = pred_img.detach().cpu().numpy()
        recon_img = recon_img.detach().cpu().numpy()
+
+         writer.add_scalar("val/recon_min", recon_img.min(), global_step=step)
+         writer.add_scalar("val/recon_max", recon_img.max(), global_step=step)
+         writer.add_scalar("val/gen_min", pred_img.min(), global_step=step)
+         writer.add_scalar("val/gen_max", pred_img.max(), global_step=step)
+
        pred_img = ((pred_img.clip(-1, 1)+1)*127.5).astype(np.uint8)
        recon_img = ((recon_img.clip(-1, 1)+1)*127.5).astype(np.uint8)

@@ -132,6 +144,7 @@ def validation(model, loader_val, writer, step, batch_size=8, device='cuda'):
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path', type=str, default="../data/SRN/cars_train")
    parser.add_argument('--pickle_path', type=str, default="../data/cars.pickle")
+     parser.add_argument('--exp_name', type=str, default=None)
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--num_workers', type=int, default=0)
    parser.add_argument('--image_size', type=int, default=128)
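
With the new flag, a run can be written to a readable results directory instead of a timestamped one (hypothetical invocation from the VAE directory; remaining flags take their defaults):

    python train.py --data_path ../data/SRN/cars_train --exp_name enc_dec_enlarged
    # logs and checkpoints go to ./results/shapenet_SRN_car/enc_dec_enlarged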
