Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main'
Browse files Browse the repository at this point in the history
  • Loading branch information
cuttle-fish-my committed May 24, 2023
2 parents 01356eb + 3e11c16 commit 5004ce5
Show file tree
Hide file tree
Showing 3 changed files with 173 additions and 31 deletions.
127 changes: 99 additions & 28 deletions VAE/CVAE.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,43 @@ def posenc_nerf(x, min_deg=0, max_deg=15):

return torch.concat([x, emb], dim=-1)

def rt_to_quaternion(R:torch.tensor, t:torch.tensor):
"""Converts rotation matrix and translation vector to quaternion."""
# R: [B, 3, 3]
# t: [B, 3]
# q: [B, 4]
trace = torch.trace(R)
trace = torch.clamp(trace, min=-1, max=3)
if trace > 0:
qw = 0.5 * torch.sqrt(1 + trace)
qx = R[2, 1] - R[1, 2] / (4 * qw)
qy = R[0, 2] - R[2, 0] / (4 * qw)
qz = R[1, 0] - R[0, 1] / (4 * qw)
else:
max_diag = torch.argmax(torch.diag(R))
if max_diag == 0:
qx = 0.5 * torch.sqrt(1 + R[0, 0] - R[1, 1] - R[2, 2])
qy = R[0, 1] + R[1, 0] / (4 * qx)
qz = R[0, 2] + R[2, 0] / (4 * qx)
qw = R[2, 1] - R[1, 2] / (4 * qx)
elif max_diag == 1:
qy = 0.5 * torch.sqrt(1 + R[1, 1] - R[0, 0] - R[2, 2])
qx = R[0, 1] + R[1, 0] / (4 * qy)
qz = R[1, 2] + R[2, 1] / (4 * qy)
qw = R[0, 2] - R[2, 0] / (4 * qy)
elif max_diag == 2:
qz = 0.5 * torch.sqrt(1 + R[2, 2] - R[0, 0] - R[1, 1])
qx = R[0, 2] + R[2, 0] / (4 * qz)
qy = R[1, 2] + R[2, 1] / (4 * qz)
qw = R[1, 0] - R[0, 1] / (4 * qz)
else:
qw = 0
qx = 0
qy = 0
qz = 0
q = torch.tensor([qw, qx, qy, qz])
return q

class PoseConditionProcessor(torch.nn.Module):

def __init__(self, emb_ch, H, W,
Expand Down Expand Up @@ -116,27 +153,43 @@ def forward(self, batch, cond_mask):
class EncoderBlock(nn.Module):
def __init__(self, in_channel:int, out_channel:int, input_h:int, input_w:int) -> None:
super().__init__()
self.conv = nn.Conv2d(in_channel, out_channel, 5, 2, 2)
self.norm = nn.LayerNorm([out_channel, input_h // 2, input_w // 2])
self.conv1 = nn.Conv2d(in_channel, out_channel, 5, 1, 2)
self.conv2 = nn.Conv2d(out_channel, out_channel, 5, 2, 2)
self.norm1 = nn.LayerNorm([out_channel, input_h, input_w])
self.norm2 = nn.LayerNorm([out_channel, input_h // 2, input_w // 2])
self.relu = nn.ReLU()

def forward(self, x):
x = self.conv(x)
x = self.norm(x)
x = self.conv1(x)
x = self.norm1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.norm2(x)
x = self.relu(x)
return x

class DecoderBlock(nn.Module):
def __init__(self, in_channel:int, out_channel:int, input_h:int, input_w:int) -> None:
def __init__(self, in_channel:int, out_channel:int, input_h:int, input_w:int, end:bool=False) -> None:
super().__init__()
self.conv = nn.ConvTranspose2d(in_channel, out_channel, 5, 2, 2, output_padding=1)
self.layer_norm = nn.LayerNorm([out_channel, input_h * 2, input_w * 2])
self.relu = nn.ReLU()
# self.end = end
self.conv1 = nn.Conv2d(in_channel, out_channel, 5, 1, 2)
self.conv2 = nn.ConvTranspose2d(out_channel, out_channel, 5, 2, 2, output_padding=1)
self.norm1 = nn.LayerNorm([out_channel, input_h, input_w])
self.norm2 = nn.LayerNorm([out_channel, input_h * 2, input_w * 2])
self.used_relu = nn.ReLU()
if not end:
self.relu = nn.ReLU()
else:
self.relu = None

def forward(self, x):
x = self.conv(x)
x = self.layer_norm(x)
x = self.relu(x)
x = self.conv1(x)
x = self.norm1(x)
x = self.used_relu(x)
x = self.conv2(x)
x = self.norm2(x)
if self.relu:
x = self.relu(x)
return x

class ConditionalVAE(nn.Module):
Expand Down Expand Up @@ -167,7 +220,7 @@ def __init__(self, H: int = 128, W: int = 128, z_dim: int = 128, n_resolution: i
self.dc1 = DecoderBlock(256, 128, H // 16, W // 16)
self.dc2 = DecoderBlock(128 + emb_ch, 64, H // 8, W // 8)
self.dc3 = DecoderBlock(64 + emb_ch, 32, H // 4, W // 4)
self.dc4 = DecoderBlock(32 + emb_ch, 3, H // 2, W // 2)
self.dc4 = DecoderBlock(32 + emb_ch, 3, H // 2, W // 2, True)

def bottle_neck(self, x):
assert len(x.shape) == 2
Expand All @@ -181,11 +234,11 @@ def bottle_neck(self, x):

def encode(self, x, pose_embeds):
out1 = self.ec1(x)
# input2 = torch.concat([out1, pose_embeds[0][:,0,:]], dim=1)
# input2 = torch.concat([out1, pose_embeds[0][:,0]], dim=1)
out2 = self.ec2(out1)
# input3 = torch.concat([out2, pose_embeds[1][:,0,:]], dim=1)
# input3 = torch.concat([out2, pose_embeds[1][:,0]], dim=1)
out3 = self.ec3(out2)
# input4 = torch.concat([out3, pose_embeds[2][:,0,:]], dim=1)
# input4 = torch.concat([out3, pose_embeds[2][:,0]], dim=1)
out4 = self.ec4(out3)
z_out = self.fc1(self.flatten(out4))
return z_out[:,:self.z_dim], z_out[:,self.z_dim:]
Expand Down Expand Up @@ -232,8 +285,9 @@ def loss(self, z_mu, z_logvar, img_gen, img_gt, img_recon, img_input):
kld = torch.mean(
-0.5 * torch.sum(1 + z_logvar - z_mu.pow(2) - z_logvar.exp(), dim=1), dim=0
)
img_loss = ((img_gen - img_gt)**2).sum(dim=(1,2,3)).mean()
img_loss += ((img_recon - img_input)**2).sum(dim=(1,2,3)).mean()
# img_loss = ((img_gen - img_gt)**2).sum(dim=(1,2,3)).mean()
# img_loss += ((img_recon - img_input)**2).sum(dim=(1,2,3)).mean()
img_loss = F.mse_loss(img_gen, img_gt) + F.mse_loss(img_recon, img_input)
return self.beta * kld , img_loss

def eval_img(self, batch, cond_mask=None):
Expand All @@ -243,17 +297,34 @@ def eval_img(self, batch, cond_mask=None):
pred_img, recon_img = self.decode(z_mu, pose_embeds)
return pred_img, recon_img

# class LabelMapping(nn.Module):
# def __init__(self, channel:int=8, category:int=0, f_num:int=1) -> None:
# super().__init__()
# self.channel = channel
# self.category = category
# self.f_num = f_num
# self.fc = nn.Linear(channel, f_num * category)
# self.softmax = nn.Softmax(dim=1)
def freeze_encoder(self):
self.ec1.requires_grad_(False)
self.ec2.requires_grad_(False)
self.ec3.requires_grad_(False)
self.ec4.requires_grad_(False)
self.fc1.requires_grad_(False)

# class ConditionalDeformableVAE(nn.Module):
# def __init__(self, *args, **kwargs) -> None:
# super().__init__(*args, **kwargs)
class PoseMapping(nn.Module):
'''Map the pose(quaternion) to two vectors'''
def __init__(self, embed:int = 64) -> None:
super().__init__()
self.fc_x = nn.Linear(4, embed)
self.fc_y = nn.Linear(4, embed)

def forward(self, x):
assert len(x.shape) == 2
assert x.shape[1] == 4
return self.fc_x(x), self.fc_y(x)

class CDVAEEncBlock(nn.Module):
def __init__(self) -> None:
super().__init__()


class ConditionalDeformableVAE(nn.Module):
def __init__(self, pose_embed:int=64) -> None:
super().__init__()
self.pose_embed = pose_embed

self.pose_mapping_enc = PoseMapping(pose_embed)
self.pose_mapping_dec = PoseMapping(pose_embed)
27 changes: 24 additions & 3 deletions VAE/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@ def main(args):
optimizer = Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.99))

if args.transfer == "":
now = './results/shapenet_SRN_car/' + str(int(time.time()))
if args.exp_name is None:
now = './results/shapenet_SRN_car/' + str(int(time.time()))
else:
now = './results/shapenet_SRN_car/' + args.exp_name
# now = './results/shapenet_SRN_car/' + "no_nonlinear"
writer = SummaryWriter(now)
step = 0
else:
Expand All @@ -46,8 +50,10 @@ def main(args):
optimizer.load_state_dict(ckpt['optim'])

now = args.transfer
# now = args.transfer + "_val"
writer = SummaryWriter(now)
step = ckpt['step']
# validation(model, loader, writer, step, args.batch_size)
train(model, optimizer, loader, loader_val, writer, now, step, args)


Expand All @@ -61,10 +67,15 @@ def warmup(optimizer, step, last_step, last_lr):

def train(model, optimizer, loader, loader_val, writer, now, step, args):
a = 1
freezed = True ## No freezing
for e in range(args.num_epochs):
print(f'starting epoch {e}')

for img, R, T, K in tqdm(loader):
if not freezed and step > args.freeze_step:
print('freezing encoder')
model.module.freeze_encoder()
freezed = True
warmup(optimizer, step, args.warmup_step / args.batch_size, args.lr)

B = img.shape[0]
Expand Down Expand Up @@ -115,6 +126,12 @@ def validation(model, loader_val, writer, step, batch_size=8, device='cuda'):
pred_img, recon_img = model.module.eval_img(batch, None)
pred_img = pred_img.detach().cpu().numpy()
recon_img = recon_img.detach().cpu().numpy()

writer.add_scalar("val/recon_min", recon_img.min(), global_step=step)
writer.add_scalar("val/recon_max", recon_img.max(), global_step=step)
writer.add_scalar("val/gen_min", pred_img.min(), global_step=step)
writer.add_scalar("val/gen_max", pred_img.max(), global_step=step)

pred_img = ((pred_img.clip(-1, 1)+1)*127.5).astype(np.uint8)
recon_img = ((recon_img.clip(-1, 1)+1)*127.5).astype(np.uint8)

Expand All @@ -132,16 +149,20 @@ def validation(model, loader_val, writer, step, batch_size=8, device='cuda'):
parser = argparse.ArgumentParser()
parser.add_argument('--data_path', type=str, default="../data/SRN/cars_train")
parser.add_argument('--pickle_path', type=str, default="../data/cars.pickle")
parser.add_argument('--exp_name', type=str, default=None)
parser.add_argument('--batch_size', type=int, default=1)
parser.add_argument('--num_workers', type=int, default=0)
parser.add_argument('--image_size', type=int, default=128)
parser.add_argument('--transfer', type=str, default="")
parser.add_argument('--lr', type=float, default=1e-4)
parser.add_argument('--num_epochs', type=int, default=100000)
parser.add_argument('--warmup_step', type=int, default=10000000)

parser.add_argument('--warmup_step', type=int, default=0)
parser.add_argument('--freeze_step', type=float, default=5000)

parser.add_argument('--verbose_interval', type=int, default=500)
parser.add_argument('--validation_interval', type=int, default=1000)
parser.add_argument('--save_interval', type=int, default=20)
parser.add_argument('--save_interval', type=int, default=10)
parser.add_argument('--save_path', type=str, default="./results")
opts = parser.parse_args()
main(opts)
50 changes: 50 additions & 0 deletions cal_result.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import os
import numpy as np
import math
import cv2



def psnr(img1, img2):
mse = np.mean((img1/1. - img2/1.) ** 2 )
if mse < 1.0e-10:
return 100*1.0
return 10 * math.log10(255.0*255.0/mse)

def mse(img1,img2):
mse = np.mean((img1/1. - img2/1.) ** 2 )
return mse

def ssim(y_true , y_pred):
u_true = np.mean(y_true)
u_pred = np.mean(y_pred)
var_true = np.var(y_true)
var_pred = np.var(y_pred)
std_true = np.sqrt(var_true)
std_pred = np.sqrt(var_pred)
c1 = np.square(0.01*7)
c2 = np.square(0.03*7)
ssim = (2 * u_true * u_pred + c1) * (2 * std_pred * std_true + c2)
denom = (u_true ** 2 + u_pred ** 2 + c1) * (var_pred + var_true + c2)
return ssim / denom

list_psnr = []
list_ssim = []
list_mse = []
for j in range(1,11):
path1 = f"/public/home/chenzheng/CV_final_proj/CV2-Final/diffusion/sampling/{j}/" #指定输出结果文件夹
path2 = f"/public/home/chenzheng/CV_final_proj/CV2-Final/diffusion/sampling/{j}/"#指定原图文件夹
f_nums = len(os.listdir(path1))
#change if you sample more.
for i in range(0,4):
img_a = cv2.imread(path1+str(i)+'.png')
img_b = cv2.imread(path2+'gt.png')
psnr_num = psnr(img_a, img_b)
ssim_num = ssim(img_a, img_b)
mse_num = mse(img_a,img_b)
list_ssim.append(ssim_num)
list_psnr.append(psnr_num)
list_mse.append(mse_num)
print("平均PSNR:",np.mean(list_psnr))#,list_psnr)
print("平均SSIM:",np.mean(list_ssim))#,list_ssim)
print("平均MSE:",np.mean(list_mse))#,list_mse)

0 comments on commit 5004ce5

Please sign in to comment.