diff --git a/VAE/CVAE.py b/VAE/CVAE.py
index d6a5256..da97aba 100644
--- a/VAE/CVAE.py
+++ b/VAE/CVAE.py
@@ -33,6 +33,39 @@ def posenc_nerf(x, min_deg=0, max_deg=15):
     return torch.concat([x, emb], dim=-1)
 
+def rt_to_quaternion(R: torch.Tensor, t: torch.Tensor):
+    """Converts a rotation matrix to a quaternion (t is currently unused)."""
+    # R: [3, 3] rotation matrix
+    # t: [3] translation, not needed for the conversion
+    # q: [4], ordered (qw, qx, qy, qz)
+    trace = torch.trace(R)
+    trace = torch.clamp(trace, min=-1, max=3)
+    if trace > 0:
+        qw = 0.5 * torch.sqrt(1 + trace)
+        qx = (R[2, 1] - R[1, 2]) / (4 * qw)
+        qy = (R[0, 2] - R[2, 0]) / (4 * qw)
+        qz = (R[1, 0] - R[0, 1]) / (4 * qw)
+    else:
+        max_diag = torch.argmax(torch.diag(R))
+        if max_diag == 0:
+            qx = 0.5 * torch.sqrt(1 + R[0, 0] - R[1, 1] - R[2, 2])
+            qy = (R[0, 1] + R[1, 0]) / (4 * qx)
+            qz = (R[0, 2] + R[2, 0]) / (4 * qx)
+            qw = (R[2, 1] - R[1, 2]) / (4 * qx)
+        elif max_diag == 1:
+            qy = 0.5 * torch.sqrt(1 + R[1, 1] - R[0, 0] - R[2, 2])
+            qx = (R[0, 1] + R[1, 0]) / (4 * qy)
+            qz = (R[1, 2] + R[2, 1]) / (4 * qy)
+            qw = (R[0, 2] - R[2, 0]) / (4 * qy)
+        else:  # max_diag == 2 is the only remaining case for a 3-vector argmax
+            qz = 0.5 * torch.sqrt(1 + R[2, 2] - R[0, 0] - R[1, 1])
+            qx = (R[0, 2] + R[2, 0]) / (4 * qz)
+            qy = (R[1, 2] + R[2, 1]) / (4 * qz)
+            qw = (R[1, 0] - R[0, 1]) / (4 * qz)
+    # torch.stack keeps device and gradients; torch.tensor([...]) would not
+    q = torch.stack([qw, qx, qy, qz])
+    return q
+
 
 class PoseConditionProcessor(torch.nn.Module):
 
     def __init__(self, emb_ch, H, W,
@@ -116,27 +149,43 @@ def forward(self, batch, cond_mask):
 class EncoderBlock(nn.Module):
     def __init__(self, in_channel:int, out_channel:int, input_h:int, input_w:int) -> None:
         super().__init__()
-        self.conv = nn.Conv2d(in_channel, out_channel, 5, 2, 2)
-        self.norm = nn.LayerNorm([out_channel, input_h // 2, input_w // 2])
+        self.conv1 = nn.Conv2d(in_channel, out_channel, 5, 1, 2)
+        self.conv2 = nn.Conv2d(out_channel, out_channel, 5, 2, 2)
+        self.norm1 = nn.LayerNorm([out_channel, input_h, input_w])
+        self.norm2 = nn.LayerNorm([out_channel, input_h // 2, input_w // 2])
         self.relu = nn.ReLU()
 
     def forward(self, x):
-        x = self.conv(x)
-        x = self.norm(x)
+        x = self.conv1(x)
+        x = self.norm1(x)
+        x = self.relu(x)
+        x = self.conv2(x)
+        x = self.norm2(x)
         x = self.relu(x)
         return x
 
 class DecoderBlock(nn.Module):
-    def __init__(self, in_channel:int, out_channel:int, input_h:int, input_w:int) -> None:
+    def __init__(self, in_channel:int, out_channel:int, input_h:int, input_w:int, end:bool=False) -> None:
         super().__init__()
-        self.conv = nn.ConvTranspose2d(in_channel, out_channel, 5, 2, 2, output_padding=1)
-        self.layer_norm = nn.LayerNorm([out_channel, input_h * 2, input_w * 2])
-        self.relu = nn.ReLU()
+        self.conv1 = nn.Conv2d(in_channel, out_channel, 5, 1, 2)
+        self.conv2 = nn.ConvTranspose2d(out_channel, out_channel, 5, 2, 2, output_padding=1)
+        self.norm1 = nn.LayerNorm([out_channel, input_h, input_w])
+        self.norm2 = nn.LayerNorm([out_channel, input_h * 2, input_w * 2])
+        self.relu1 = nn.ReLU()
+        # the final block skips the last ReLU so outputs are not clipped to [0, inf)
+        if not end:
+            self.relu2 = nn.ReLU()
+        else:
+            self.relu2 = None
 
     def forward(self, x):
-        x = self.conv(x)
-        x = self.layer_norm(x)
-        x = self.relu(x)
+        x = self.conv1(x)
+        x = self.norm1(x)
+        x = self.relu1(x)
+        x = self.conv2(x)
+        x = self.norm2(x)
+        if self.relu2 is not None:
+            x = self.relu2(x)
         return x
 
 class ConditionalVAE(nn.Module):
@@ -167,7 +216,7 @@ def __init__(self, H: int = 128, W: int = 128, z_dim: int = 128, n_resolution: i
         self.dc1 = DecoderBlock(256, 128, H // 16, W // 16)
         self.dc2 = DecoderBlock(128 + emb_ch, 64, H // 8, W // 8)
         self.dc3 = DecoderBlock(64 + emb_ch, 32, H // 4, W // 4)
-        self.dc4 = DecoderBlock(32 + emb_ch, 3, H // 2, W // 2)
+        self.dc4 = DecoderBlock(32 + emb_ch, 3, H // 2, W // 2, end=True)
 
     def bottle_neck(self, x):
         assert len(x.shape) == 2
@@ -181,11 +230,11 @@ def bottle_neck(self, x):
 
     def encode(self, x, pose_embeds):
         out1 = self.ec1(x)
-        # input2 = torch.concat([out1, pose_embeds[0][:,0,:]], dim=1)
+        # input2 = torch.concat([out1, pose_embeds[0][:,0]], dim=1)
         out2 = self.ec2(out1)
-        # input3 = torch.concat([out2, pose_embeds[1][:,0,:]], dim=1)
+        # input3 = torch.concat([out2, pose_embeds[1][:,0]], dim=1)
         out3 = self.ec3(out2)
-        # input4 = torch.concat([out3, pose_embeds[2][:,0,:]], dim=1)
+        # input4 = torch.concat([out3, pose_embeds[2][:,0]], dim=1)
         out4 = self.ec4(out3)
         z_out = self.fc1(self.flatten(out4))
         return z_out[:,:self.z_dim], z_out[:,self.z_dim:]
@@ -232,8 +281,9 @@ def loss(self, z_mu, z_logvar, img_gen, img_gt, img_recon, img_input):
         kld = torch.mean(
             -0.5 * torch.sum(1 + z_logvar - z_mu.pow(2) - z_logvar.exp(), dim=1), dim=0
         )
-        img_loss = ((img_gen - img_gt)**2).sum(dim=(1,2,3)).mean()
-        img_loss += ((img_recon - img_input)**2).sum(dim=(1,2,3)).mean()
+        # per-pixel mean (F.mse_loss) instead of a per-image sum rebalances
+        # the reconstruction term against the KLD
+        img_loss = F.mse_loss(img_gen, img_gt) + F.mse_loss(img_recon, img_input)
         return self.beta * kld , img_loss
 
     def eval_img(self, batch, cond_mask=None):
@@ -243,17 +293,34 @@ def eval_img(self, batch, cond_mask=None):
         pred_img, recon_img = self.decode(z_mu, pose_embeds)
         return pred_img, recon_img
 
-# class LabelMapping(nn.Module):
-#     def __init__(self, channel:int=8, category:int=0, f_num:int=1) -> None:
-#         super().__init__()
-#         self.channel = channel
-#         self.category = category
-#         self.f_num = f_num
-#         self.fc = nn.Linear(channel, f_num * category)
-#         self.softmax = nn.Softmax(dim=1)
+    def freeze_encoder(self):
+        self.ec1.requires_grad_(False)
+        self.ec2.requires_grad_(False)
+        self.ec3.requires_grad_(False)
+        self.ec4.requires_grad_(False)
+        self.fc1.requires_grad_(False)
 
-# class ConditionalDeformableVAE(nn.Module):
-#     def __init__(self, *args, **kwargs) -> None:
-#         super().__init__(*args, **kwargs)
+class PoseMapping(nn.Module):
+    """Map a pose quaternion to two embedding vectors."""
+    def __init__(self, embed:int = 64) -> None:
+        super().__init__()
+        self.fc_x = nn.Linear(4, embed)
+        self.fc_y = nn.Linear(4, embed)
+    def forward(self, x):
+        assert len(x.shape) == 2
+        assert x.shape[1] == 4
+        return self.fc_x(x), self.fc_y(x)
+
+class CDVAEEncBlock(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()  # placeholder: deformable encoder block, to be filled in
+
+
+class ConditionalDeformableVAE(nn.Module):
+    def __init__(self, pose_embed:int=64) -> None:
+        super().__init__()
+        self.pose_embed = pose_embed
+        self.pose_mapping_enc = PoseMapping(pose_embed)
+        self.pose_mapping_dec = PoseMapping(pose_embed)
 
 
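Note: a quick way to sanity-check rt_to_quaternion is to round-trip a known
quaternion through its rotation matrix. This is a standalone sketch, not part
of the patch; quaternion_to_matrix is a test-only helper written here, and the
import assumes the repo root is on PYTHONPATH:

    import torch
    from VAE.CVAE import rt_to_quaternion

    def quaternion_to_matrix(q):
        # standard unit-quaternion -> rotation-matrix formula
        qw, qx, qy, qz = q.tolist()
        return torch.tensor([
            [1 - 2*(qy*qy + qz*qz), 2*(qx*qy - qz*qw),     2*(qx*qz + qy*qw)],
            [2*(qx*qy + qz*qw),     1 - 2*(qx*qx + qz*qz), 2*(qy*qz - qx*qw)],
            [2*(qx*qz - qy*qw),     2*(qy*qz + qx*qw),     1 - 2*(qx*qx + qy*qy)],
        ])

    q = torch.tensor([0.7071, 0.7071, 0.0, 0.0])  # 90 degrees about x
    R = quaternion_to_matrix(q)
    q_back = rt_to_quaternion(R, torch.zeros(3))
    # q and -q encode the same rotation, so compare up to sign if needed
    assert torch.allclose(q, q_back, atol=1e-3)
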
diff --git a/VAE/train.py b/VAE/train.py
index 90c17da..643c62d 100644
--- a/VAE/train.py
+++ b/VAE/train.py
@@ -34,7 +34,10 @@ def main(args):
     optimizer = Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.99))
 
     if args.transfer == "":
-        now = './results/shapenet_SRN_car/' + str(int(time.time()))
+        if args.exp_name is None:
+            now = './results/shapenet_SRN_car/' + str(int(time.time()))
+        else:
+            now = './results/shapenet_SRN_car/' + args.exp_name
         writer = SummaryWriter(now)
         step = 0
     else:
@@ -46,8 +49,9 @@ def main(args):
         optimizer.load_state_dict(ckpt['optim'])
 
         now = args.transfer
         writer = SummaryWriter(now)
         step = ckpt['step']
+        # validation(model, loader, writer, step, args.batch_size)  # uncomment to validate a checkpoint right after loading
 
     train(model, optimizer, loader, loader_val, writer, now, step, args)
 
@@ -61,10 +65,15 @@ def warmup(optimizer, step, last_step, last_lr):
 
 def train(model, optimizer, loader, loader_val, writer, now, step, args):
     a = 1
+    frozen = True  # set to False to re-enable encoder freezing after args.freeze_step
    
     for e in range(args.num_epochs):
         print(f'starting epoch {e}')
 
         for img, R, T, K in tqdm(loader):
+            if not frozen and step > args.freeze_step:
+                print('freezing encoder')
+                model.module.freeze_encoder()
+                frozen = True
             warmup(optimizer, step, args.warmup_step / args.batch_size, args.lr)
 
             B = img.shape[0]
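Note: the body of warmup() is outside this hunk; the call above suggests a
linear ramp of the following shape (a sketch under that assumption, not the
repo's actual code):

    def warmup(optimizer, step, last_step, last_lr):
        # linearly ramp the learning rate toward last_lr over the first
        # last_step optimizer steps, then leave it alone
        if step < last_step:
            lr = last_lr * (step + 1) / last_step
            for group in optimizer.param_groups:
                group['lr'] = lr

With --warmup_step now defaulting to 0 (see below), the ramp is effectively
disabled.
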
@@ -115,6 +124,13 @@ def validation(model, loader_val, writer, step, batch_size=8, device='cuda'):
         pred_img, recon_img = model.module.eval_img(batch, None)
         pred_img = pred_img.detach().cpu().numpy()
         recon_img = recon_img.detach().cpu().numpy()
+
+        # log raw output ranges to catch saturation before the [-1, 1] clip
+        writer.add_scalar("val/recon_min", recon_img.min(), global_step=step)
+        writer.add_scalar("val/recon_max", recon_img.max(), global_step=step)
+        writer.add_scalar("val/gen_min", pred_img.min(), global_step=step)
+        writer.add_scalar("val/gen_max", pred_img.max(), global_step=step)
+
         pred_img = ((pred_img.clip(-1, 1)+1)*127.5).astype(np.uint8)
         recon_img = ((recon_img.clip(-1, 1)+1)*127.5).astype(np.uint8)
 
@@ -132,16 +148,20 @@
     parser = argparse.ArgumentParser()
     parser.add_argument('--data_path', type=str, default="../data/SRN/cars_train")
     parser.add_argument('--pickle_path', type=str, default="../data/cars.pickle")
+    parser.add_argument('--exp_name', type=str, default=None)
     parser.add_argument('--batch_size', type=int, default=1)
     parser.add_argument('--num_workers', type=int, default=0)
     parser.add_argument('--image_size', type=int, default=128)
    parser.add_argument('--transfer', type=str, default="")
     parser.add_argument('--lr', type=float, default=1e-4)
     parser.add_argument('--num_epochs', type=int, default=100000)
-    parser.add_argument('--warmup_step', type=int, default=10000000)
+
+    parser.add_argument('--warmup_step', type=int, default=0)
+    parser.add_argument('--freeze_step', type=int, default=5000)
+
     parser.add_argument('--verbose_interval', type=int, default=500)
     parser.add_argument('--validation_interval', type=int, default=1000)
-    parser.add_argument('--save_interval', type=int, default=20)
+    parser.add_argument('--save_interval', type=int, default=10)
     parser.add_argument('--save_path', type=str, default="./results")
     opts = parser.parse_args()
     main(opts)
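Note: with the flags added above, a fresh run and a resumed run look like this
(paths and the experiment name are illustrative):

    python VAE/train.py --data_path ../data/SRN/cars_train --exp_name baseline --batch_size 8
    python VAE/train.py --transfer ./results/shapenet_SRN_car/baseline --batch_size 8
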
diff --git a/cal_result.py b/cal_result.py
new file mode 100644
index 0000000..e68de70
--- /dev/null
+++ b/cal_result.py
@@ -0,0 +1,52 @@
+import os
+import numpy as np
+import math
+import cv2
+
+
+def psnr(img1, img2):
+    # compute in float64 so the uint8 difference does not overflow
+    mse = np.mean((img1.astype(np.float64) - img2.astype(np.float64)) ** 2)
+    if mse < 1.0e-10:
+        return 100.0
+    return 10 * math.log10(255.0 * 255.0 / mse)
+
+def mse(img1, img2):
+    mse = np.mean((img1.astype(np.float64) - img2.astype(np.float64)) ** 2)
+    return mse
+
+def ssim(y_true, y_pred):
+    # global-statistics approximation of SSIM (no sliding window);
+    # c1, c2 use the standard (k*L)^2 with dynamic range L = 255
+    u_true = np.mean(y_true)
+    u_pred = np.mean(y_pred)
+    var_true = np.var(y_true)
+    var_pred = np.var(y_pred)
+    std_true = np.sqrt(var_true)
+    std_pred = np.sqrt(var_pred)
+    c1 = np.square(0.01 * 255)
+    c2 = np.square(0.03 * 255)
+    ssim = (2 * u_true * u_pred + c1) * (2 * std_pred * std_true + c2)
+    denom = (u_true ** 2 + u_pred ** 2 + c1) * (var_pred + var_true + c2)
+    return ssim / denom
+
+list_psnr = []
+list_ssim = []
+list_mse = []
+for j in range(1, 11):
+    path1 = f"/public/home/chenzheng/CV_final_proj/CV2-Final/diffusion/sampling/{j}/"  # folder with generated samples
+    path2 = f"/public/home/chenzheng/CV_final_proj/CV2-Final/diffusion/sampling/{j}/"  # folder holding the ground-truth image
+    f_nums = len(os.listdir(path1))
+    # change the range if you sample more images per scene
+    for i in range(0, 4):
+        img_a = cv2.imread(path1 + str(i) + '.png')
+        img_b = cv2.imread(path2 + 'gt.png')
+        psnr_num = psnr(img_a, img_b)
+        ssim_num = ssim(img_a, img_b)
+        mse_num = mse(img_a, img_b)
+        list_ssim.append(ssim_num)
+        list_psnr.append(psnr_num)
+        list_mse.append(mse_num)
+print("Average PSNR:", np.mean(list_psnr))
+print("Average SSIM:", np.mean(list_ssim))
+print("Average MSE:", np.mean(list_mse))
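Note: the ssim() above uses image-global statistics; the reference SSIM slides
an 11x11 window over the image. scikit-image's implementation can serve as a
cross-check (assumes scikit-image >= 0.19 is installed; older versions take
multichannel=True instead of channel_axis):

    from skimage.metrics import structural_similarity

    # img_a, img_b as loaded by cv2.imread in the loop above (HxWx3, uint8)
    ssim_ref = structural_similarity(img_a, img_b, channel_axis=-1, data_range=255)
    print("windowed SSIM:", ssim_ref)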