From a0fb810478c002c9da74a32fb12c656bdb813160 Mon Sep 17 00:00:00 2001
From: ZERONE182 <1104865009@qq.com>
Date: Thu, 18 May 2023 16:10:43 +0800
Subject: [PATCH 01/13] Main architecture and training procedure of CVAE

---
 VAE/CVAE.py       | 192 +++++++++++++++++++++++++++++++++++++++++-----
 VAE/SRNdataset.py | 105 +++++++++++++++++++++++++
 VAE/train.py      | 146 +++++++++++++++++++++++++++++++++++
 3 files changed, 422 insertions(+), 21 deletions(-)
 create mode 100644 VAE/SRNdataset.py
 create mode 100644 VAE/train.py

diff --git a/VAE/CVAE.py b/VAE/CVAE.py
index a7dff9a..f14d674 100644
--- a/VAE/CVAE.py
+++ b/VAE/CVAE.py
@@ -1,7 +1,8 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import visu3d
+from einops import rearrange
+import visu3d as v3d
 import numpy as np
 
 ### TODO: For Prof. Zeng: you can always reconstruct the whole model as you like. All the code below is not tested, error may occur.
@@ -16,17 +17,106 @@ def forward(self, x):
         H = W = int(np.sqrt(D / 256))
         return x.view(B, 256, H, W)
 
-class PoseCondition(nn.Module):
-    def __init__(self) -> None:
-        # TODO: Decide to imitate the method done in XUnet.
-        # Suppose we have pose data R, t, K. (I don't know whether it is paired -> (R1, R2, t1, t1, K1, K2) or just (R, t, K))
-        # By using the camera model in visu3d, we can get the ray info for all the pixels in the image.
-        # The ray origin is 3-dim vector and direction is also 3-dim vector.
-        # If we just concate the ray info for all the pixel in the image, we can get a tensor in shape (H, W, 6), 6 can be seen as channel?
-        # NeRF PE is applied on ray origin and ray direction.
-        # CNN is also used to change the spatial size to be the same as the downsampled image during VAE processing.
+def posenc_nerf(x, min_deg=0, max_deg=15):
+    """Concatenate x and its positional encodings, following NeRF."""
+    if min_deg == max_deg:
+        return x
+    scales = torch.tensor([2 ** i for i in range(min_deg, max_deg)]).float().to(x)
+
+    xb = rearrange(
+        (x[..., None, :] * scales[:, None]), "b f h w c d -> b f h w (c d)")
+    emb = torch.sin(torch.concat([xb, xb + torch.pi / 2.], dim=-1))
+
+    return torch.concat([x, emb], dim=-1)
+
+
+# class PoseCondition(nn.Module):
+#     def __init__(self) -> None:
+#         # TODO: Decide to imitate the method done in XUnet.
+#         # By using the camera model in visu3d, we can get the ray info for all the pixels in the image.
+#         # The ray origin is a 3-dim vector and the ray direction is also a 3-dim vector.
+#         # If we just concatenate the ray info for all the pixels in the image, we get a tensor of shape (H, W, 6); 6 can be seen as channels.
+#         # NeRF PE is applied on ray origin and ray direction.
+#         # CNN is also used to change the spatial size to be the same as the downsampled image during VAE processing.
+#         super().__init__()
+
+class PoseConditionProcessor(torch.nn.Module):
+
+    def __init__(self, emb_ch, H, W,
+                 num_resolutions,
+                 use_pos_emb=False,
+                 use_ref_pose_emb=False):
+        super().__init__()
+        self.emb_ch = emb_ch
+        self.num_resolutions = num_resolutions
+        self.use_pos_emb = use_pos_emb
+        self.use_ref_pose_emb = use_ref_pose_emb
+
+        D = 144
+        # D is determined by the max_deg and min_deg of posenc_nerf together with x.shape[-1],
+        # so as long as all the values above stay fixed, there is no need to change D.
+        if use_pos_emb:
+            self.pos_emb = torch.nn.Parameter(torch.zeros(D, H, W), requires_grad=True)
+            torch.nn.init.normal_(self.pos_emb, std=(1 / np.sqrt(D)))
+
+        # if use_ref_pose_emb:
+        #     self.first_emb = torch.nn.Parameter(torch.zeros(1, 1, D, 1, 1), requires_grad=True)
+        #     torch.nn.init.normal_(self.first_emb, std=(1 / np.sqrt(D)))
+
+        #     self.other_emb = torch.nn.Parameter(torch.zeros(1, 1, D, 1, 1), requires_grad=True)
+        #     torch.nn.init.normal_(self.other_emb, std=(1 / np.sqrt(D)))
+
+        convs = []
+        for i_level in range(self.num_resolutions):
+            convs.append(torch.nn.Conv2d(in_channels=D,
+                                         out_channels=self.emb_ch,
+                                         kernel_size=3,
+                                         stride=2 ** (i_level + 1), padding=1))
+
+        self.convs = torch.nn.ModuleList(convs)
+
+    def forward(self, batch, cond_mask):
+
+        B, C, H, W = batch['x'].shape
+
+        world_from_cam = v3d.Transform(R=batch['R'].cpu().numpy(), t=batch['t'].cpu().numpy())
+        cam_spec = v3d.PinholeCamera(resolution=(H, W), K=batch['K'].unsqueeze(1).cpu().numpy())
+        rays = v3d.Camera(
+            spec=cam_spec, world_from_cam=world_from_cam).rays()
+
+        pose_emb_pos = posenc_nerf(torch.tensor(rays.pos).float().to(batch['x']), min_deg=0, max_deg=15)
+        pose_emb_dir = posenc_nerf(torch.tensor(rays.dir).float().to(batch['x']), min_deg=0, max_deg=8)
+
+        pose_emb = torch.concat([pose_emb_pos, pose_emb_dir], dim=-1)  # [B, F, H, W, 144]
+
+        if cond_mask is not None:
+            assert cond_mask.shape == (B,), (cond_mask.shape, B)
+            cond_mask = cond_mask[:, None, None, None, None]
+            pose_emb = torch.where(cond_mask, pose_emb, torch.zeros_like(pose_emb))  # [B, F, H, W, 144]
+
+        pose_emb = rearrange(pose_emb, "b f h w c -> b f c h w")
+        # pose_emb = torch.tensor(pose_emb).float().to(device)
+
+        # now [B, F, C=144, H, W]
+
+        if self.use_pos_emb:
+            pose_emb += self.pos_emb[None, None]
+        if self.use_ref_pose_emb:
+            pose_emb = torch.concat([self.first_emb, self.other_emb], dim=1) + pose_emb
+            # now [B, 2, C=144, H, W]
+
+        pose_embs = []
+        for i_level in range(self.num_resolutions):
+            B, F = pose_emb.shape[:2]
+            pose_embs.append(
+                rearrange(self.convs[i_level](
+                    rearrange(pose_emb, 'b f c h w -> (b f) c h w')
+                ),
+                    '(b f) c h w -> b f c h w', b=B, f=F
+                )
+            )
+
+        return pose_embs
 
 class EncoderBlock(nn.Module):
     def __init__(self, in_channel:int, out_channel:int, input_h:int, input_w:int) -> None:
@@ -55,27 +145,31 @@ def forward(self, x):
         return x
 
 class ConditionalVAE(nn.Module):
-    def __init__(self, H: int = 128, W: int = 128, z_dim: int = 128, n_resolution: int = 3) -> None:
+    def __init__(self, H: int = 128, W: int = 128, z_dim: int = 128, n_resolution: int = 3, emb_ch: int = 128) -> None:
         super().__init__()
         self.H = H
         self.W = W
         self.z_dim = z_dim
         self.n_resolution = n_resolution
+        self.emb_ch = emb_ch
+        self.beta = 1 / z_dim
 
-        # TODO: The in channel for all the blocks below are wrong, as the pose info needs to be injected.
+        self.condition_processor = PoseConditionProcessor(emb_ch, H, W, n_resolution)
+        # TODO: Now hardcoded per layer; change to a list
         self.ec1 = EncoderBlock(3, 32, H, W)
-        self.ec2 = EncoderBlock(32, 64, H // 2, W // 2)
-        self.ec3 = EncoderBlock(64, 128, H // 4, W // 4)
-        self.ec4 = EncoderBlock(128, 256, H // 8, W // 8)
+        self.ec2 = EncoderBlock(32 + emb_ch, 64, H // 2, W // 2)
+        self.ec3 = EncoderBlock(64 + emb_ch, 128, H // 4, W // 4)
+        self.ec4 = EncoderBlock(128 + emb_ch, 256, H // 8, W // 8)
 
-        self.fc1 = nn.Linear(256 * (H // 16) * (W // 16), z_dim)  # for mu
-        self.fc2 = nn.Linear(256 * (H // 16) * (W // 16), z_dim)  # for logvar
-        self.fc3 = nn.Linear(z_dim, 256 * (H // 16) * (W // 16))  # for decoder
+        self.flatten = Flatten()
+        self.fc1 = nn.Linear(256 * (H // 16) * (W // 16), 2 * z_dim)  # for mu, logvar
+        self.fc2 = nn.Linear(z_dim, 256 * (H // 16) * (W // 16))  # for decoder
+        self.unflatten = UnFlatten()
 
         self.dc1 = DecoderBlock(256, 128, H // 16, W // 16)
-        self.dc2 = DecoderBlock(128, 64, H // 8, W // 8)
-        self.dc3 = DecoderBlock(64, 32, H // 4, W // 4)
-        self.dc4 = DecoderBlock(32, 3, H // 2, W // 2)
+        self.dc2 = DecoderBlock(128 + emb_ch, 64, H // 8, W // 8)
+        self.dc3 = DecoderBlock(64 + emb_ch, 32, H // 4, W // 4)
+        self.dc4 = DecoderBlock(32 + emb_ch, 3, H // 2, W // 2)
 
     def bottle_neck(self, x):
         assert len(x.shape) == 2
@@ -87,3 +181,59 @@ def bottle_neck(self, x):
         z_sampled = self.fc3(z_sampled)
 
         return z_sampled, mu, logvar
+
+    def encode(self, x, pose_embeds):
+        out1 = self.ec1(x)
+        input2 = torch.concat([out1, pose_embeds[0][:, 0, :]], dim=1)
+        out2 = self.ec2(input2)
+        input3 = torch.concat([out2, pose_embeds[1][:, 0, :]], dim=1)
+        out3 = self.ec3(input3)
+        input4 = torch.concat([out3, pose_embeds[2][:, 0, :]], dim=1)
+        out4 = self.ec4(input4)
+        z_out = self.fc1(self.flatten(out4))
+        return z_out[:, :self.z_dim], z_out[:, self.z_dim:]
+
+    def reparameterize(self, mu, logvar):
+        std = torch.exp(0.5 * logvar)
+        eps = torch.randn_like(std)
+        return eps * std + mu
+
+    def decode(self, z, pose_embeds):
+        input1 = self.fc2(z)
+        out1 = self.dc1(self.unflatten(input1))
+        input2 = torch.concat([out1, pose_embeds[2][:, 1, :]], dim=1)
+        out2 = self.dc2(input2)
+        input3 = torch.concat([out2, pose_embeds[1][:, 1, :]], dim=1)
+        out3 = self.dc3(input3)
+        input4 = torch.concat([out3, pose_embeds[0][:, 1, :]], dim=1)
+        out4 = self.dc4(input4)
+        return out4
+
+
+    def forward(self, batch, cond_mask=None):
+        pose_embeds = self.condition_processor(batch, cond_mask)
+        # print([pose_embeds[i].shape for i in range(3)])
+        x = batch['x']
+        z_mu, z_logvar = self.encode(x, pose_embeds)
+        z = self.reparameterize(z_mu, z_logvar)
+        img_recon = self.decode(z, pose_embeds)
+        return self.loss(z_mu, z_logvar, img_recon, x)
+
+    def loss(self, z_mu, z_logvar, img_recon, img_gt):
+        kld = torch.mean(
+            -0.5 * torch.sum(1 + z_logvar - z_mu.pow(2) - z_logvar.exp(), dim=1), dim=0
+        )
+        if torch.isnan(kld):
+            print(z_mu[0])
+            print(z_logvar[0])
+            raise RuntimeError("KLD is nan")
+        img_loss = (img_gt**2 - img_recon**2).mean()
+        return self.beta * kld , img_loss
+
+    def eval(self, batch, cond_mask=None):
+        pose_embeds = self.condition_processor(batch, cond_mask)
+        x = batch['x']
+        z_mu, z_logvar = self.encode(x, pose_embeds)
+        img_recon = self.decode(z_mu, pose_embeds)
+        return img_recon
+
+
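Note: on the hard-coded D = 144 in PoseConditionProcessor above. posenc_nerf maps a c-channel input to c + 2*c*(max_deg - min_deg) channels (the input itself plus sin and cos of each scaled copy), so the ray origins (3 channels, degrees 0..15) give 93 channels and the ray directions (3 channels, degrees 0..8) give 51, and 93 + 51 = 144. A standalone arithmetic check, not part of the patch:

def posenc_channels(c, min_deg, max_deg):
    # posenc_nerf output width: x itself (c channels) plus sin and cos
    # of c * (max_deg - min_deg) frequency-scaled copies
    return c + 2 * c * (max_deg - min_deg)

pos = posenc_channels(3, 0, 15)   # ray origins    -> 93
dirs = posenc_channels(3, 0, 8)   # ray directions -> 51
assert pos + dirs == 144          # matches D above
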
diff --git a/VAE/SRNdataset.py b/VAE/SRNdataset.py
new file mode 100644
index 0000000..91a14d4
--- /dev/null
+++ b/VAE/SRNdataset.py
@@ -0,0 +1,105 @@
+from torch.utils.data import Dataset
+import glob
+import os
+import pickle
+import torch
+from PIL import Image
+import numpy as np
+import csv
+import torch
+import random
+
+
+class MultiEpochsDataLoader(torch.utils.data.DataLoader):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._DataLoader__initialized = False
+        self.batch_sampler = _RepeatSampler(self.batch_sampler)
+        self._DataLoader__initialized = True
+        self.iterator = super().__iter__()
+
+    def __len__(self):
+        return len(self.batch_sampler.sampler)
+
+    def __iter__(self):
+        for i in range(len(self)):
+            yield next(self.iterator)
+
+
+class _RepeatSampler(object):
+    """ Sampler that repeats forever.
+    Args:
+        sampler (Sampler)
+    """
+
+    def __init__(self, sampler):
+        self.sampler = sampler
+
+    def __iter__(self):
+        while True:
+            yield from iter(self.sampler)
+
+
+class dataset(Dataset):
+
+    def __init__(self, split, path='./data/SRN/cars_train', picklefile='./data/cars.pickle', imgsize=128):
+        self.imgsize = imgsize
+        self.path = path
+        super().__init__()
+        self.picklefile = pickle.load(open(picklefile, 'rb'))
+
+        allthevid = sorted(list(self.picklefile.keys()))
+
+        random.seed(0)
+        random.shuffle(allthevid)
+        if split == 'train':
+            self.ids = allthevid[:int(len(allthevid) * 0.9)]
+        else:
+            self.ids = allthevid[int(len(allthevid) * 0.9):]
+
+    def __len__(self):
+        return len(self.ids)
+
+    def __getitem__(self, idx):
+
+        item = self.ids[idx]
+
+        intrinsics_filename = os.path.join(self.path, item, 'intrinsics', self.picklefile[item][0][:-4] + ".txt")
+        K = np.array(open(intrinsics_filename).read().strip().split()).astype(float).reshape((3, 3))
+
+        indices = random.sample(self.picklefile[item], k=2)
+
+        imgs = []
+        poses = []
+        for i in indices:
+            img_filename = os.path.join(self.path, item, 'rgb', i)
+            img = Image.open(img_filename)
+            if self.imgsize != 128:
+                img = img.resize((self.imgsize, self.imgsize))
+            img = np.array(img) / 255 * 2 - 1
+
+            img = img.transpose(2, 0, 1)[:3].astype(np.float32)
+            imgs.append(img)
+
+            pose_filename = os.path.join(self.path, item, 'pose', i[:-4] + ".txt")
+            pose = np.array(open(pose_filename).read().strip().split()).astype(float).reshape((4, 4))
+            poses.append(pose)
+
+        imgs = np.stack(imgs, 0)
+        poses = np.stack(poses, 0)
+        R = poses[:, :3, :3]
+        T = poses[:, :3, 3]
+
+        return imgs, R, T, K
+
+
+if __name__ == "__main__":
+
+    from torch.utils.data import DataLoader
+
+    d = dataset('train')
+    dd = d[0]
+
+    for ddd in dd:
+        print(ddd.shape)
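Note: each dataset item pairs two random views of one object, so downstream code unpacks, per sample: imgs of shape (2, 3, imgsize, imgsize) scaled to [-1, 1], R of shape (2, 3, 3), T of shape (2, 3), and one shared intrinsics matrix K of shape (3, 3). A minimal shape check, assuming the SRN cars data is laid out under the default paths above:

from SRNdataset import dataset

d = dataset('train', path='./data/SRN/cars_train', picklefile='./data/cars.pickle')
imgs, R, T, K = d[0]
print(imgs.shape, R.shape, T.shape, K.shape)
# expected: (2, 3, 128, 128) (2, 3, 3) (2, 3) (3, 3)
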
diff --git a/VAE/train.py b/VAE/train.py
new file mode 100644
index 0000000..283b9d9
--- /dev/null
+++ b/VAE/train.py
@@ -0,0 +1,146 @@
+from CVAE import ConditionalVAE
+
+import torch
+from torch.utils.data import DataLoader
+from torch.optim import Adam
+import numpy as np
+
+from tqdm import tqdm
+from einops import rearrange
+import time
+
+from SRNdataset import dataset, MultiEpochsDataLoader
+from tensorboardX import SummaryWriter
+import os
+import argparse
+
+
+def main(args):
+    d = dataset('train', path=args.data_path, picklefile=args.pickle_path, imgsize=args.image_size)
+    d_val = dataset('val', path=args.data_path, picklefile=args.pickle_path, imgsize=args.image_size)
+
+    loader = MultiEpochsDataLoader(d, batch_size=args.batch_size,
+                                   shuffle=True, drop_last=True,
+                                   num_workers=args.num_workers)
+    loader_val = DataLoader(d_val, batch_size=args.batch_size,
+                            shuffle=True, drop_last=True,
+                            num_workers=args.num_workers)
+
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    model = ConditionalVAE(H=args.image_size, W=args.image_size, z_dim=128, n_resolution=3)
+    model = torch.nn.DataParallel(model)
+    model.to(device)
+
+    optimizer = Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.99))
+
+    if args.transfer == "":
+        now = './results/shapenet_SRN_car/' + str(int(time.time()))
+        writer = SummaryWriter(now)
+        step = 0
+    else:
+        print('transferring from: ', args.transfer)
+
+        ckpt = torch.load(os.path.join(args.transfer, 'latest.pt'))
+
+        model.load_state_dict(ckpt['model'])
+        optimizer.load_state_dict(ckpt['optim'])
+
+        now = args.transfer
+        writer = SummaryWriter(now)
+        step = ckpt['step']
+    train(model, optimizer, loader, loader_val, writer, now, step, args)
+
+
+def warmup(optimizer, step, last_step, last_lr):
+    if step < last_step:
+        optimizer.param_groups[0]['lr'] = step / last_step * last_lr
+    else:
+        optimizer.param_groups[0]['lr'] = last_lr
+
+
+def train(model, optimizer, loader, loader_val, writer, now, step, args):
+    a = 1
+    for e in range(args.num_epochs):
+        print(f'starting epoch {e}')
+
+        for img, R, T, K in tqdm(loader):
+            warmup(optimizer, step, args.warmup_step / args.batch_size, args.lr)
+
+            B = img.shape[0]
+
+            optimizer.zero_grad()
+
+            batch = {'x': img[:, 0], 'z': img[:, 1], 'R': R, 't': T, 'K': K}
+
+            kld_loss, img_loss = model(batch, None)
+            loss = kld_loss + img_loss
+            loss.backward()
+            optimizer.step()
+
+            writer.add_scalar("train/kld_loss", kld_loss.item(), global_step=step)
+            writer.add_scalar("train/img_loss", img_loss.item(), global_step=step)
+            writer.add_scalar("train/loss", loss.item(), global_step=step)
+            writer.add_scalar("train/lr", optimizer.param_groups[0]['lr'], global_step=step)
+
+            if step % args.verbose_interval == 0:
+                print(f"loss: {loss.item()}, kld loss: {kld_loss.item()}, img loss: {img_loss.item()}")
+
+            if step % args.validation_interval == 900:
+                validation(model, loader_val, writer, step, args.batch_size)
+
+            if step == int(args.warmup_step / args.batch_size):
+                torch.save({'optim': optimizer.state_dict(), 'model': model.state_dict(), 'step': step},
+                           now + f"/after_warmup.pt")
+
+            step += 1
+
+        if e % args.save_interval == 0:
+            torch.save({'optim': optimizer.state_dict(), 'model': model.state_dict(), 'step': step, 'epoch': e},
+                       now + f"/latest.pt")
+
+
+def validation(model, loader_val, writer, step, batch_size=8):
+    # TODO: Add image writer
+    model.eval()
+    with torch.no_grad():
+        ori_img, R, T, K = next(iter(loader_val))
+        # w = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]).repeat(16)
+        w = torch.tensor([3.0] * batch_size)
+        # img = utils.sample(model, img=ori_img, R=R, T=T, K=K, w=w)
+
+        img = rearrange(((img[-1].clip(-1, 1) + 1) * 127.5).astype(np.uint8),
+                        "(b a) c h w -> a c h (b w)",
+                        a=8, b=16)
+
+        gt = rearrange(((ori_img[:, 1] + 1) * 127.5).detach().cpu().numpy().astype(np.uint8),
+                       "(b a) c h w -> a c h (b w)", a=8, b=16)
+        cd = rearrange(((ori_img[:, 0] + 1) * 127.5).detach().cpu().numpy().astype(np.uint8),
+                       "(b a) c h w -> a c h (b w)", a=8, b=16)
+
+        fi = np.concatenate([cd, gt, img], axis=2)
+        for i, ww in enumerate(range(8)):
+            writer.add_image(f"train/{ww}", fi[i], step)
+
+        print('image sampled!')
+        writer.flush()
+    model.train()
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--data_path', type=str, default="../data/SRN/cars_train")
+    parser.add_argument('--pickle_path', type=str, default="../data/cars.pickle")
+    parser.add_argument('--batch_size', type=int, default=1)
+    parser.add_argument('--num_workers', type=int, default=0)
+    parser.add_argument('--image_size', type=int, default=128)
+    parser.add_argument('--transfer', type=str, default="")
+    parser.add_argument('--lr', type=float, default=1e-4)
+    parser.add_argument('--num_epochs', type=int, default=100000)
+    parser.add_argument('--warmup_step', type=int, default=10000000)
+    parser.add_argument('--verbose_interval', type=int, default=500)
+    parser.add_argument('--validation_interval', type=int, default=1000)
+    parser.add_argument('--save_interval', type=int, default=20)
+    parser.add_argument('--save_path', type=str, default="./results")
+    opts = parser.parse_args()
+    main(opts)
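Note: the warmup above is a plain linear ramp. The learning rate scales with the optimizer step until step reaches warmup_step / batch_size, then stays at --lr. A standalone restatement with worked numbers, using the warmup_step=30000 default introduced later in this series for the diffusion trainer and an assumed batch size of 8:

def warmup_lr(step, warmup_step=30000, batch_size=8, last_lr=1e-4):
    last_step = warmup_step / batch_size  # 3750 optimizer steps of warmup
    return min(step / last_step, 1.0) * last_lr

print(warmup_lr(0))      # 0.0
print(warmup_lr(1875))   # 5e-05, halfway through warmup
print(warmup_lr(10000))  # 1e-04, capped after warmup
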
From da1a7bce1240f12f270e8c70433a238175fd9d50 Mon Sep 17 00:00:00 2001
From: ZERONE182 <1104865009@qq.com>
Date: Thu, 18 May 2023 16:18:05 +0800
Subject: [PATCH 02/13] Fix L2 img loss

---
 VAE/CVAE.py | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/VAE/CVAE.py b/VAE/CVAE.py
index f14d674..a2145ac 100644
--- a/VAE/CVAE.py
+++ b/VAE/CVAE.py
@@ -222,11 +222,7 @@ def loss(self, z_mu, z_logvar, img_recon, img_gt):
         kld = torch.mean(
             -0.5 * torch.sum(1 + z_logvar - z_mu.pow(2) - z_logvar.exp(), dim=1), dim=0
         )
-        if torch.isnan(kld):
-            print(z_mu[0])
-            print(z_logvar[0])
-            raise RuntimeError("KLD is nan")
-        img_loss = (img_gt**2 - img_recon**2).mean()
+        img_loss = ((img_gt - img_recon)**2).mean()
         return self.beta * kld , img_loss
 
     def eval(self, batch, cond_mask=None):
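Note: the fix above matters because gt**2 - recon**2 is not a distance; it can be zero or negative for a reconstruction that is completely wrong, so the optimizer can lower it without matching the image at all. A tiny standalone demonstration with made-up values:

import torch

gt = torch.tensor([1.0, -1.0])
recon = torch.tensor([-1.0, 1.0])   # sign-flipped, clearly wrong

print((gt**2 - recon**2).mean())    # tensor(0.) -- old loss sees no error
print(((gt - recon)**2).mean())     # tensor(4.) -- fixed L2 sees it
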
From 7d2a2fab979923333e2e3ea34ce6c7ea9904edd9 Mon Sep 17 00:00:00 2001
From: ZERONE182 <1104865009@qq.com>
Date: Thu, 18 May 2023 16:53:20 +0800
Subject: [PATCH 03/13] Fix img writer in validation

---
 VAE/CVAE.py  |  2 +-
 VAE/train.py | 29 +++++++++++------------------
 2 files changed, 12 insertions(+), 19 deletions(-)

diff --git a/VAE/CVAE.py b/VAE/CVAE.py
index a2145ac..14abc60 100644
--- a/VAE/CVAE.py
+++ b/VAE/CVAE.py
@@ -225,7 +225,7 @@ def loss(self, z_mu, z_logvar, img_recon, img_gt):
         img_loss = ((img_gt - img_recon)**2).mean()
         return self.beta * kld , img_loss
 
-    def eval(self, batch, cond_mask=None):
+    def eval_img(self, batch, cond_mask=None):
         pose_embeds = self.condition_processor(batch, cond_mask)
         x = batch['x']
         z_mu, z_logvar = self.encode(x, pose_embeds)
diff --git a/VAE/train.py b/VAE/train.py
index 283b9d9..27f33d3 100644
--- a/VAE/train.py
+++ b/VAE/train.py
@@ -86,7 +86,7 @@ def train(model, optimizer, loader, loader_val, writer, now, step, args):
             if step % args.verbose_interval == 0:
                 print(f"loss: {loss.item()}, kld loss: {kld_loss.item()}, img loss: {img_loss.item()}")
 
-            if step % args.validation_interval == 900:
+            if step % args.validation_interval == 0:
                 validation(model, loader_val, writer, step, args.batch_size)
 
             if step == int(args.warmup_step / args.batch_size):
@@ -100,29 +100,22 @@ def train(model, optimizer, loader, loader_val, writer, now, step, args):
                        now + f"/latest.pt")
 
 
-def validation(model, loader_val, writer, step, batch_size=8):
+def validation(model, loader_val, writer, step, batch_size=8, device='cuda'):
     # TODO: Add image writer
     model.eval()
    with torch.no_grad():
         ori_img, R, T, K = next(iter(loader_val))
-        # w = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]).repeat(16)
-        w = torch.tensor([3.0] * batch_size)
-        # img = utils.sample(model, img=ori_img, R=R, T=T, K=K, w=w)
+
+        batch = {'x': ori_img[:, 0].to(device), 'z': ori_img[:, 1].to(device), 'R': R.to(device), 't': T.to(device), 'K': K.to(device)}
+        gt_img = ori_img[:, 1].detach().cpu().numpy()
+        gt_img = ((gt_img.clip(-1, 1) + 1) * 127.5).astype(np.uint8)
+        pred_img = model.module.eval_img(batch, None).detach().cpu().numpy()
+        pred_img = ((pred_img.clip(-1, 1) + 1) * 127.5).astype(np.uint8)
 
-        img = rearrange(((img[-1].clip(-1, 1) + 1) * 127.5).astype(np.uint8),
-                        "(b a) c h w -> a c h (b w)",
-                        a=8, b=16)
+        writer.add_images(f"train/gt", gt_img, step)
+        writer.add_images(f"train/pred", pred_img, step)
 
-        gt = rearrange(((ori_img[:, 1] + 1) * 127.5).detach().cpu().numpy().astype(np.uint8),
-                       "(b a) c h w -> a c h (b w)", a=8, b=16)
-        cd = rearrange(((ori_img[:, 0] + 1) * 127.5).detach().cpu().numpy().astype(np.uint8),
-                       "(b a) c h w -> a c h (b w)", a=8, b=16)
-
-        fi = np.concatenate([cd, gt, img], axis=2)
-        for i, ww in enumerate(range(8)):
-            writer.add_image(f"train/{ww}", fi[i], step)
-
-        print('image sampled!')
+        # print('image sampled!')
         writer.flush()
     model.train()
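Note: two details the rewrite above relies on. First, model.module is needed because the model is wrapped in torch.nn.DataParallel; an unwrapped model would call model.eval_img(batch, None) directly. Second, SummaryWriter.add_images takes a whole uint8 batch in NCHW layout by default, which is exactly what eval_img yields after the numpy conversion, so the old per-image rearrange is no longer needed. A minimal usage sketch with a hypothetical random batch:

import numpy as np
from tensorboardX import SummaryWriter

writer = SummaryWriter('./results/demo')
fake_batch = np.random.randint(0, 256, size=(8, 3, 128, 128), dtype=np.uint8)
writer.add_images("train/pred", fake_batch, global_step=0)  # dataformats='NCHW' is the default
writer.flush()
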
From 24eec2b055c0dcc91ee49a62a1dbef086855c2a4 Mon Sep 17 00:00:00 2001
From: libn
Date: Thu, 18 May 2023 17:04:57 +0800
Subject: [PATCH 04/13] change the default value of warmup_step from 10000000 to 30000

---
 diffusion/train.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/diffusion/train.py b/diffusion/train.py
index 7d583d7..edeafb7 100644
--- a/diffusion/train.py
+++ b/diffusion/train.py
@@ -133,7 +133,7 @@ def validation(model, loader_val, writer, step, batch_size=8):
     parser.add_argument('--transfer', type=str, default="")
     parser.add_argument('--lr', type=float, default=1e-4)
     parser.add_argument('--num_epochs', type=int, default=100000)
-    parser.add_argument('--warmup_step', type=int, default=10000000)
+    parser.add_argument('--warmup_step', type=int, default=30000)
     parser.add_argument('--verbose_interval', type=int, default=500)
     parser.add_argument('--validation_interval', type=int, default=1000)
     parser.add_argument('--save_interval', type=int, default=20)

From cdbc0cabd77fe1b5f1cf29890db8c51eb1de2d5b Mon Sep 17 00:00:00 2001
From: libn
Date: Thu, 18 May 2023 17:30:32 +0800
Subject: [PATCH 05/13] fixed bug in validation

---
 diffusion/train.py | 10 ++++------
 utils.py           |  2 +-
 2 files changed, 5 insertions(+), 7 deletions(-)

diff --git a/diffusion/train.py b/diffusion/train.py
index edeafb7..e032047 100644
--- a/diffusion/train.py
+++ b/diffusion/train.py
@@ -101,22 +101,20 @@ def validation(model, loader_val, writer, step, batch_size=8):
     model.eval()
     with torch.no_grad():
         ori_img, R, T, K = next(iter(loader_val))
-        # w = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]).repeat(16)
         w = torch.tensor([3.0] * batch_size)
         img = utils.sample(model, img=ori_img, R=R, T=T, K=K, w=w)
 
         img = rearrange(((img[-1].clip(-1, 1) + 1) * 127.5).astype(np.uint8),
                         "(b a) c h w -> a c h (b w)",
-                        a=8, b=16)
+                        a=1, b=batch_size)
 
         gt = rearrange(((ori_img[:, 1] + 1) * 127.5).detach().cpu().numpy().astype(np.uint8),
-                       "(b a) c h w -> a c h (b w)", a=8, b=16)
+                       "(b a) c h w -> a c h (b w)", a=1, b=batch_size)
         cd = rearrange(((ori_img[:, 0] + 1) * 127.5).detach().cpu().numpy().astype(np.uint8),
-                       "(b a) c h w -> a c h (b w)", a=8, b=16)
+                       "(b a) c h w -> a c h (b w)", a=1, b=batch_size)
 
         fi = np.concatenate([cd, gt, img], axis=2)
-        for i, ww in enumerate(range(8)):
-            writer.add_image(f"train/{ww}", fi[i], step)
+        writer.add_image(f"train/val_{step}", fi[0], step)
 
         print('image sampled!')
         writer.flush()
diff --git a/utils.py b/utils.py
index 6997221..3cf4e01 100644
--- a/utils.py
+++ b/utils.py
@@ -77,7 +77,7 @@ def sample(model, img, R, T, K, w, timesteps=256):
     logsnr_nexts = logsnr_schedule_cosine(torch.linspace(1., 0., timesteps + 1)[1:])
 
     for logsnr, logsnr_next in tqdm(zip(logsnrs, logsnr_nexts)):  # [1, ..., 0] = size is 257
-        img = p_sample(model, x=x, z=img, R=R, T=T, K=K, logsnr=logsnr, logsnr_next=logsnr_next, w=w)
+        img = p_sample(model, x=x, z=img, R=R, T=T, K=K, logsnr=logsnr, logsnr_next=logsnr_next, w=w)  # [B, C, H, W]
         imgs.append(img.cpu().numpy())
     return imgs
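Note: the sampling loop above walks logsnr_schedule_cosine over t from 1 down to 0, but that function is never shown in this series. In comparable public variational-diffusion codebases it is the cosine log-SNR schedule, something like the sketch below; this is an assumption about this repo's utils.py, including the usual logsnr_min/logsnr_max of -20/20:

import numpy as np
import torch

def logsnr_schedule_cosine(t, logsnr_min=-20., logsnr_max=20.):
    # maps t in [0, 1] to log-SNR, decreasing from logsnr_max at t=0 to logsnr_min at t=1
    b = np.arctan(np.exp(-0.5 * logsnr_max))
    a = np.arctan(np.exp(-0.5 * logsnr_min)) - b
    return -2. * torch.log(torch.tan(a * t + b))

t = torch.linspace(1., 0., 5)
print(logsnr_schedule_cosine(t))  # roughly [-20, ..., 20]
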
From 9d018d062119c4ba6d1f2041613b5dbd17ba3a74 Mon Sep 17 00:00:00 2001
From: libn
Date: Fri, 19 May 2023 14:27:49 +0800
Subject: [PATCH 06/13] add MSE color mean loss

---
 diffusion/train.py | 17 ++++++++++-------
 utils.py           | 32 +++++++++++++++++++++++++++++---
 2 files changed, 39 insertions(+), 10 deletions(-)

diff --git a/diffusion/train.py b/diffusion/train.py
index e032047..c36fda8 100644
--- a/diffusion/train.py
+++ b/diffusion/train.py
@@ -40,7 +40,7 @@ def main(args):
     else:
         print('transferring from: ', args.transfer)
 
-        ckpt = torch.load(os.path.join(args.transfer, 'latest.pt'))
+        ckpt = torch.load(os.path.join(args.transfer, 'latest.pt'), map_location=torch.device(utils.dev()))
 
         model.load_state_dict(ckpt['model'])
         optimizer.load_state_dict(ckpt['optim'])
@@ -64,16 +64,18 @@ def train(model, optimizer, loader, loader_val, writer, now, step, args):
         print(f'starting epoch {e}')
 
         for img, R, T, K in tqdm(loader):
+            # validation(model, loader_val, writer, step, args.timesteps, args.batch_size)
+
             warmup(optimizer, step, args.warmup_step / args.batch_size, args.lr)
 
             B = img.shape[0]
 
            optimizer.zero_grad()
 
-            logsnr = utils.logsnr_schedule_cosine(torch.rand((B,)))
+            logsnr = utils.logsnr_schedule_cosine(torch.rand((B, )))
 
-            loss = utils.p_losses(model, img=img, R=R, T=T, K=K, logsnr=logsnr, loss_type="l2",
-                                  cond_prob=0.1)
+            loss = utils.p_losses(model, img=img, R=R, T=T, K=K, logsnr=logsnr,
+                                  loss_type="l2", cond_prob=0.1)
             loss.backward()
             optimizer.step()
 
@@ -84,7 +86,7 @@ def train(model, optimizer, loader, loader_val, writer, now, step, args):
                 print("Loss:", loss.item())
 
             if step % args.validation_interval == 900:
-                validation(model, loader_val, writer, step, args.batch_size)
+                validation(model, loader_val, writer, step, args.timesteps, args.batch_size)
 
             if step == int(args.warmup_step / args.batch_size):
                 torch.save({'optim': optimizer.state_dict(), 'model': model.state_dict(), 'step': step},
@@ -97,12 +99,12 @@ def train(model, optimizer, loader, loader_val, writer, now, step, args):
                        now + f"/latest.pt")
 
 
-def validation(model, loader_val, writer, step, batch_size=8):
+def validation(model, loader_val, writer, step, timesteps, batch_size=8):
     model.eval()
     with torch.no_grad():
         ori_img, R, T, K = next(iter(loader_val))
         w = torch.tensor([3.0] * batch_size)
-        img = utils.sample(model, img=ori_img, R=R, T=T, K=K, w=w)
+        img = utils.sample(model, img=ori_img, R=R, T=T, K=K, w=w, timesteps=timesteps)
 
         img = rearrange(((img[-1].clip(-1, 1) + 1) * 127.5).astype(np.uint8),
                         "(b a) c h w -> a c h (b w)",
@@ -135,6 +137,7 @@ def validation(model, loader_val, writer, step, batch_size=8):
     parser.add_argument('--verbose_interval', type=int, default=500)
     parser.add_argument('--validation_interval', type=int, default=1000)
     parser.add_argument('--save_interval', type=int, default=20)
+    parser.add_argument('--timesteps', type=int, default=256)
     parser.add_argument('--save_path', type=str, default="./results")
     opts = parser.parse_args()
     main(opts)
diff --git a/utils.py b/utils.py
index 3cf4e01..1bb187d 100644
--- a/utils.py
+++ b/utils.py
@@ -2,7 +2,7 @@
 import numpy as np
 import torch.nn.functional as F
 from tqdm import tqdm
-
+from matplotlib import pyplot as plt
 
 def dev():
     return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -39,7 +39,7 @@ def q_sample(z, logsnr, noise):
 
 def p_losses(denoise_model, img, R, T, K, logsnr, noise=None, loss_type="l2", cond_prob=0.1):
-    B = img.shape[0]
+    B, N, C, H, W = img.shape
     x = img[:, 0]
     z = img[:, 1]
     if noise is None:
@@ -64,7 +64,15 @@
     else:
         raise NotImplementedError()
 
-    return loss
+    rec_img = reconstruct_z_start(z_noisy, predicted_noise, logsnr)
+
+    img_color_mean = torch.mean(z, dim=(2, 3))
+    rec_color_mean = torch.mean(rec_img, dim=(2, 3))
+
+    color_loss = torch.nn.MSELoss()
+    color_loss = (logsnr + 20) / 20 * color_loss(img_color_mean, rec_color_mean)
+
+    return loss + color_loss
 
 
 @torch.no_grad()
@@ -82,6 +90,24 @@ def sample(model, img, R, T, K, w, timesteps=256):
     return imgs
 
 
+def reconstruct_z_start(z_noisy, pred_noise, logsnr):
+    B = z_noisy.shape[0]
+    logsnr_next = torch.tensor([20.0] * B)
+    c = - torch.special.expm1(logsnr - logsnr_next)[:, None, None, None]
+    squared_alpha, squared_alpha_next = logsnr.sigmoid(), logsnr_next.sigmoid()
+    squared_sigma, squared_sigma_next = (-logsnr).sigmoid(), (-logsnr_next).sigmoid()
+    alpha, sigma, alpha_next = map(lambda a: a.sqrt(), (squared_alpha, squared_sigma, squared_alpha_next))
+    alpha = alpha[:, None, None, None]
+    sigma = sigma[:, None, None, None]
+    alpha_next = alpha_next[:, None, None, None]
+
+    z_start = (z_noisy - sigma * pred_noise) / alpha
+    z_start.clamp_(-1., 1.)
+
+    z_start = alpha_next * (z_noisy * (1 - c) / alpha + c * z_start)
+    return z_start
+
+
 @torch.no_grad()
 def p_sample(model, x, z, R, T, K, logsnr, logsnr_next, w):
     model_mean, model_variance = p_mean_variance(model, x=x, z=z, R=R, T=T, K=K, logsnr=logsnr, logsnr_next=logsnr_next,

From 15c785bcdf0e2f1d41842bae686741bd82780117 Mon Sep 17 00:00:00 2001
From: libn
Date: Fri, 19 May 2023 14:45:23 +0800
Subject: [PATCH 07/13] fix bug in MSE color mean loss

---
 utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/utils.py b/utils.py
index 1bb187d..8735ebf 100644
--- a/utils.py
+++ b/utils.py
@@ -92,7 +92,7 @@ def sample(model, img, R, T, K, w, timesteps=256):
 
 def reconstruct_z_start(z_noisy, pred_noise, logsnr):
     B = z_noisy.shape[0]
-    logsnr_next = torch.tensor([20.0] * B)
+    logsnr_next = torch.tensor([20.0] * B).to(dev())
     c = - torch.special.expm1(logsnr - logsnr_next)[:, None, None, None]
     squared_alpha, squared_alpha_next = logsnr.sigmoid(), logsnr_next.sigmoid()
     squared_sigma, squared_sigma_next = (-logsnr).sigmoid(), (-logsnr_next).sigmoid()
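Note: what reconstruct_z_start is doing. It inverts q_sample to get an x0 estimate, z0 ≈ (z_noisy - sigma * pred_noise) / alpha with alpha^2 = sigmoid(logsnr) and sigma^2 = sigmoid(-logsnr), then takes one DDIM-style step toward the nearly noise-free level logsnr_next = 20. With the true noise plugged in, the x0 term recovers z exactly, which gives an easy property to test. A standalone check re-deriving only the alpha/sigma algebra; the q_sample here is assumed to match the repo's:

import torch

def alpha_sigma(logsnr):
    return logsnr.sigmoid().sqrt(), (-logsnr).sigmoid().sqrt()

z = torch.randn(4, 3, 8, 8)
noise = torch.randn_like(z)
logsnr = torch.zeros(4)  # SNR = 1, so alpha = sigma

alpha, sigma = alpha_sigma(logsnr)
a = alpha[:, None, None, None]
s = sigma[:, None, None, None]
z_noisy = a * z + s * noise           # q_sample
z0_hat = (z_noisy - s * noise) / a    # the x0 estimate inside reconstruct_z_start
print(torch.allclose(z0_hat, z, atol=1e-6))  # True
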
From c94187d8072762164e92ec4a2d8519c319ee6ba2 Mon Sep 17 00:00:00 2001
From: libn
Date: Fri, 19 May 2023 14:57:01 +0800
Subject: [PATCH 08/13] fix bug in MSE color mean loss

---
 utils.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/utils.py b/utils.py
index 8735ebf..4377edd 100644
--- a/utils.py
+++ b/utils.py
@@ -64,13 +64,12 @@ def p_losses(denoise_model, img, R, T, K, logsnr, noise=None, loss_type="l2", co
     else:
         raise NotImplementedError()
 
-    rec_img = reconstruct_z_start(z_noisy, predicted_noise, logsnr)
-
+    rec_img = reconstruct_z_start(z_noisy, predicted_noise, logsnr.to(dev()))
     img_color_mean = torch.mean(z, dim=(2, 3))
-    rec_color_mean = torch.mean(rec_img, dim=(2, 3))
+    rec_color_mean = torch.mean(rec_img, dim=(2, 3)).to(dev())
 
     color_loss = torch.nn.MSELoss()
-    color_loss = (logsnr + 20) / 20 * color_loss(img_color_mean, rec_color_mean)
+    color_loss = (logsnr.to(dev()) + 20) / 20 * color_loss(img_color_mean, rec_color_mean)
 
     return loss + color_loss

From 9c4f8a7c0241cf3e4608dbf5dff715fc1212799f Mon Sep 17 00:00:00 2001
From: libn
Date: Fri, 19 May 2023 15:00:31 +0800
Subject: [PATCH 09/13] fix bug in MSE color mean loss

---
 utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/utils.py b/utils.py
index 4377edd..4a12f04 100644
--- a/utils.py
+++ b/utils.py
@@ -64,7 +64,7 @@ def p_losses(denoise_model, img, R, T, K, logsnr, noise=None, loss_type="l2", co
     else:
         raise NotImplementedError()
 
-    rec_img = reconstruct_z_start(z_noisy, predicted_noise, logsnr.to(dev()))
+    rec_img = reconstruct_z_start(z_noisy.to(dev()), predicted_noise, logsnr.to(dev()))
     img_color_mean = torch.mean(z, dim=(2, 3))
     rec_color_mean = torch.mean(rec_img, dim=(2, 3)).to(dev())
 

From 2ea23c9e65cf0e8ad2af6a0128f778bfd0b7b391 Mon Sep 17 00:00:00 2001
From: libn
Date: Fri, 19 May 2023 15:10:18 +0800
Subject: [PATCH 10/13] fix bug in MSE color mean loss

---
 utils.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/utils.py b/utils.py
index 4a12f04..d27d7cd 100644
--- a/utils.py
+++ b/utils.py
@@ -66,10 +66,10 @@ def p_losses(denoise_model, img, R, T, K, logsnr, noise=None, loss_type="l2", co
 
     rec_img = reconstruct_z_start(z_noisy.to(dev()), predicted_noise, logsnr.to(dev()))
     img_color_mean = torch.mean(z, dim=(2, 3))
-    rec_color_mean = torch.mean(rec_img, dim=(2, 3)).to(dev())
+    rec_color_mean = torch.mean(rec_img, dim=(2, 3))
 
     color_loss = torch.nn.MSELoss()
-    color_loss = (logsnr.to(dev()) + 20) / 20 * color_loss(img_color_mean, rec_color_mean)
+    color_loss = (logsnr.to(dev()) + 20) / 20 * color_loss(img_color_mean.to(dev()), rec_color_mean)
 
     return loss + color_loss

From 4257c9804698aaa26d3c1468d8b6dc3f7fb30dea Mon Sep 17 00:00:00 2001
From: libn
Date: Fri, 19 May 2023 15:12:38 +0800
Subject: [PATCH 11/13] fix bug in MSE color mean loss

---
 utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/utils.py b/utils.py
index d27d7cd..28af36a 100644
--- a/utils.py
+++ b/utils.py
@@ -69,7 +69,7 @@ def p_losses(denoise_model, img, R, T, K, logsnr, noise=None, loss_type="l2", co
     rec_color_mean = torch.mean(rec_img, dim=(2, 3))
 
     color_loss = torch.nn.MSELoss()
-    color_loss = (logsnr.to(dev()) + 20) / 20 * color_loss(img_color_mean.to(dev()), rec_color_mean)
+    color_loss = (logsnr.to(dev()) + 20) / 20 * color_loss(img_color_mean.to(dev()), rec_color_mean).mean()
 
     return loss + color_loss

From 45f2bcad9a013782307a0d5472dacb7aaaab3dc3 Mon Sep 17 00:00:00 2001
From: libn
Date: Fri, 19 May 2023 15:15:28 +0800
Subject: [PATCH 12/13] fix bug in MSE color mean loss

---
 utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/utils.py b/utils.py
index 28af36a..4a636cc 100644
--- a/utils.py
+++ b/utils.py
@@ -69,7 +69,7 @@ def p_losses(denoise_model, img, R, T, K, logsnr, noise=None, loss_type="l2", co
     rec_color_mean = torch.mean(rec_img, dim=(2, 3))
 
     color_loss = torch.nn.MSELoss()
-    color_loss = (logsnr.to(dev()) + 20) / 20 * color_loss(img_color_mean.to(dev()), rec_color_mean).mean()
+    color_loss = ((logsnr.to(dev()) + 20) / 20 * color_loss(img_color_mean.to(dev()), rec_color_mean)).mean()
 
     return loss + color_loss
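Note: on the final form reached in PATCH 11/12. nn.MSELoss() with the default reduction='mean' already returns a scalar, so multiplying it by the per-sample vector (logsnr + 20) / 20 broadcasts back to a vector, and the outer .mean() ends up scaling one pooled MSE by the average weight rather than weighting each sample's own error. If true per-sample weighting is the intent, the usual pattern is reduction='none'; a hedged alternative sketch, not what the patches do:

import torch
import torch.nn.functional as F

def weighted_color_loss(img_color_mean, rec_color_mean, logsnr):
    # per-sample MSE over the channel means: [B, C] -> [B]
    per_sample = F.mse_loss(img_color_mean, rec_color_mean, reduction='none').mean(dim=1)
    weight = (logsnr + 20) / 20  # [B], larger weight at high log-SNR (low noise)
    return (weight * per_sample).mean()
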
From b2934c85d4b7b7b65a66cb7092cc3cf3a31bd697 Mon Sep 17 00:00:00 2001
From: libn
Date: Fri, 19 May 2023 18:16:44 +0800
Subject: [PATCH 13/13] fix bug in MSE color mean loss

---
 utils.py | 20 ++++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/utils.py b/utils.py
index 4a636cc..4c51d5f 100644
--- a/utils.py
+++ b/utils.py
@@ -38,7 +38,7 @@ def q_sample(z, logsnr, noise):
     return alpha * z + sigma * noise
 
 
-def p_losses(denoise_model, img, R, T, K, logsnr, noise=None, loss_type="l2", cond_prob=0.1):
+def p_losses(denoise_model, img, R, T, K, logsnr, noise=None, loss_type="l2", cond_prob=0.1, use_color_loss=False):
     B, N, C, H, W = img.shape
     x = img[:, 0]
     z = img[:, 1]
@@ -63,15 +63,15 @@ def p_losses(denoise_model, img, R, T, K, logsnr, noise=None, loss_type="l2", co
         loss = F.smooth_l1_loss(noise.to(dev()), predicted_noise)
     else:
         raise NotImplementedError()
-
-    rec_img = reconstruct_z_start(z_noisy.to(dev()), predicted_noise, logsnr.to(dev()))
-    img_color_mean = torch.mean(z, dim=(2, 3))
-    rec_color_mean = torch.mean(rec_img, dim=(2, 3))
-
-    color_loss = torch.nn.MSELoss()
-    color_loss = ((logsnr.to(dev()) + 20) / 20 * color_loss(img_color_mean.to(dev()), rec_color_mean)).mean()
-
-    return loss + color_loss
+    if use_color_loss:
+        rec_img = reconstruct_z_start(z_noisy.to(dev()), predicted_noise, logsnr.to(dev()))
+        img_color_mean = torch.mean(z, dim=(2, 3))
+        rec_color_mean = torch.mean(rec_img, dim=(2, 3))
+
+        color_loss = torch.nn.MSELoss()
+        color_loss = ((logsnr.to(dev()) + 20) / 20 * color_loss(img_color_mean.to(dev()), rec_color_mean)).mean()
+        return loss + color_loss
+    return loss
 
 
 @torch.no_grad()
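Note: after PATCH 13/13 the color-mean term is opt-in, so the training step has to request it explicitly. A call-site sketch under stated assumptions: the variable names mirror diffusion/train.py, and model, R, T, K stand in for the objects that script already constructs; the tensor shapes are illustrative only:

import torch
import utils

B = 8
img = torch.randn(B, 2, 3, 128, 128)  # paired (condition view, target view) from the SRN loader

logsnr = utils.logsnr_schedule_cosine(torch.rand((B,)))
loss = utils.p_losses(model, img=img, R=R, T=T, K=K, logsnr=logsnr,
                      loss_type="l2", cond_prob=0.1, use_color_loss=True)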