diff --git a/diffusion/SRNdataset.py b/diffusion/SRNdataset.py
index b40b3cb..30928f7 100644
--- a/diffusion/SRNdataset.py
+++ b/diffusion/SRNdataset.py
@@ -2,6 +2,8 @@ import pickle
 import random

 import numpy as np
+import torch
+import torchvision.transforms.functional
 from PIL import Image
 from torch.utils.data import Dataset, DataLoader

@@ -49,6 +51,7 @@ def __init__(self, split, path='./data/SRN/cars_train', pickle_file='./data/cars
         random.seed(0)
         random.shuffle(all_the_vid)

+        self.split = split
         if split == 'train':
             self.ids = all_the_vid[:int(len(all_the_vid) * 0.9)]
         else:
@@ -83,19 +86,24 @@ def __getitem__(self, idx):
             poses.append(pose)

         imgs = np.stack(imgs, 0)
+        hue_delta = 0
+        if self.split == 'train':
+            hue_delta = random.random() - 0.5
+            adjust_img = torchvision.transforms.functional.adjust_hue(torch.Tensor(imgs[1]), hue_delta)
+            imgs[1] = adjust_img.numpy()
         poses = np.stack(poses, 0)

         R = poses[:, :3, :3]
         T = poses[:, :3, 3]
-        return imgs, R, T, K
+        return imgs, R, T, K, hue_delta


-if __name__ == "__main__":
-
-    from torch.utils.data import DataLoader
-
-    d = SRNDataset('train')
-    dd = d[0]
-
-    for ddd in dd:
-        print(ddd.shape)
+# if __name__ == "__main__":
+#
+#     from torch.utils.data import DataLoader
+#
+#     d = SRNDataset('train')
+#     dd = d[0]
+#
+#     for ddd in dd:
+#         print(ddd.shape)
diff --git a/diffusion/train.py b/diffusion/train.py
index c36fda8..0578e1f 100644
--- a/diffusion/train.py
+++ b/diffusion/train.py
@@ -1,3 +1,5 @@
+import torchvision.transforms.functional
+
 from xunet import XUNet

 import torch
@@ -63,7 +65,7 @@ def train(model, optimizer, loader, loader_val, writer, now, step, args):
     for e in range(args.num_epochs):
         print(f'starting epoch {e}')

-        for img, R, T, K in tqdm(loader):
+        for img, R, T, K, hue_delta in tqdm(loader):
             # validation(model, loader_val, writer, step, args.timesteps, args.batch_size)
             warmup(optimizer, step, args.warmup_step / args.batch_size, args.lr)

@@ -74,7 +76,7 @@ def train(model, optimizer, loader, loader_val, writer, now, step, args):

             logsnr = utils.logsnr_schedule_cosine(torch.rand((B, )))

-            loss = utils.p_losses(model, img=img, R=R, T=T, K=K, logsnr=logsnr,
+            loss = utils.p_losses(model, img=img, R=R, T=T, K=K, logsnr=logsnr, hue_delta=hue_delta,
                                   loss_type="l2", cond_prob=0.1)
             loss.backward()
             optimizer.step()
@@ -102,7 +104,7 @@ def train(model, optimizer, loader, loader_val, writer, now, step, args):
 def validation(model, loader_val, writer, step, timesteps, batch_size=8):
     model.eval()
     with torch.no_grad():
-        ori_img, R, T, K = next(iter(loader_val))
+        ori_img, R, T, K, hue_delta = next(iter(loader_val))

         w = torch.tensor([3.0] * batch_size)
         img = utils.sample(model, img=ori_img, R=R, T=T, K=K, w=w, timesteps=timesteps)
diff --git a/diffusion/xunet.py b/diffusion/xunet.py
index 7f79280..3a64c90 100644
--- a/diffusion/xunet.py
+++ b/diffusion/xunet.py
@@ -7,7 +7,6 @@ import utils

 from typing import Tuple

-
 lazy = numpy_utils.lazy
 etils.enp.linalg._tf_or_xnp = lambda x: lazy.get_xnp(x)

@@ -426,6 +425,20 @@ def __init__(self, **kwargs):
             attn_heads=self.attn_heads,
             use_attn=self.num_resolutions in self.attn_resolutions)

+        # hue_delta prediction
+        self.hue_decoder = torch.nn.Sequential(
+            torch.nn.Conv2d(self.dim_out[-1] * 2, self.dim_out[-1] * 4, kernel_size=3, padding=1),  # 1024 -> 2048
+            torch.nn.AvgPool2d(kernel_size=2, stride=2),  # 8x8 -> 4x4
+            torch.nn.ReLU(),
+            torch.nn.Conv2d(self.dim_out[-1] * 4, self.dim_out[-1] // 4, kernel_size=1, stride=1),  # 2048 -> 128
+            torch.nn.ReLU(),
+            torch.nn.Flatten(start_dim=1),  # 128x4x4 -> 2048
+            torch.nn.Linear(2048, 256),  # 2048 -> 256
+            torch.nn.ReLU(),
+            torch.nn.Linear(256, 1),  # 256 -> 1
+            torch.nn.Tanh()
+        )
+
         # Downsampling
         self.upsample = torch.nn.ModuleDict()
         for i_level in reversed(range(self.num_resolutions)):
@@ -509,6 +522,8 @@ def forward(self, batch, *, cond_mask):

         h = self.middle(h, emb)

+        hue_delta = self.hue_decoder(rearrange(h, 'b f c h w -> b (f c) h w'))
+        # hue_delta = rearrange(hue_delta, '(b f) d -> b f d', b=B, f=2)
         # upsampling
         for i_level in reversed(range(self.num_resolutions)):
             emb = logsnr_emb[..., None, None] + pose_embs[i_level]
@@ -524,7 +539,9 @@ def forward(self, batch, *, cond_mask):
         assert not hs  # check hs is empty

         h = torch.nn.functional.silu(self.lastgn(h))  # [B, F, self.ch, 128, 128]
-        return rearrange(self.lastconv(rearrange(h, 'b f c h w -> (b f) c h w')), '(b f) c h w -> b f c h w', b=B)[:, 1]
+        pred_noise = self.lastconv(rearrange(h, 'b f c h w -> (b f) c h w'))
+        pred_noise = rearrange(pred_noise, '(b f) c h w -> b f c h w', b=B)[:, 1]
+        return pred_noise, hue_delta

 # if __name__ == "__main__":
 #     h, w = 56, 56
diff --git a/utils.py b/utils.py
index 4c51d5f..49e861f 100644
--- a/utils.py
+++ b/utils.py
@@ -1,9 +1,11 @@
 import torch
 import numpy as np
 import torch.nn.functional as F
+import torchvision.transforms.functional
 from tqdm import tqdm
 from matplotlib import pyplot as plt

+
 def dev():
     return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

@@ -38,7 +40,8 @@ def q_sample(z, logsnr, noise):
     return alpha * z + sigma * noise


-def p_losses(denoise_model, img, R, T, K, logsnr, noise=None, loss_type="l2", cond_prob=0.1, use_color_loss=False):
+def p_losses(denoise_model, img, R, T, K, logsnr, hue_delta, noise=None, loss_type="l2", cond_prob=0.1,
+             use_color_loss=False):
     B, N, C, H, W = img.shape
     x = img[:, 0]
     z = img[:, 1]
@@ -53,8 +56,11 @@ def p_losses(denoise_model, img, R, T, K, logsnr, noise=None, loss_type="l2", co

     batch = xt2batch(x=x_condition, logsnr=logsnr, z=z_noisy, R=R, T=T, K=K)

-    predicted_noise = denoise_model(batch, cond_mask=cond_mask)
+    predicted_noise, hue_pred = denoise_model(batch, cond_mask=cond_mask)
+    recovered_img = torch.stack([torchvision.transforms.functional.adjust_hue(x[i], hue_delta[i]) for i in range(B)])
+    pred_recovered_img = torch.stack([torchvision.transforms.functional.adjust_hue(predicted_noise[i], hue_pred[i])
+                                      for i in range(B)])

     if loss_type == 'l1':
         loss = F.l1_loss(noise.to(dev()), predicted_noise)
     elif loss_type == 'l2':
@@ -71,7 +77,10 @@ def p_losses(denoise_model, img, R, T, K, logsnr, noise=None, loss_type="l2", co
         color_loss = torch.nn.MSELoss()
         color_loss = ((logsnr.to(dev()) + 20) / 20 * color_loss(img_color_mean.to(dev()), rec_color_mean)).mean()
         return loss + color_loss
-    return loss
+    hue_loss_weight = F.mse_loss(recovered_img.to(dev()), pred_recovered_img)
+    hue_loss = F.mse_loss(hue_pred.squeeze(), hue_delta.to(hue_pred))
+    hue_loss = hue_loss_weight * hue_loss
+    return loss + hue_loss


 @torch.no_grad()