Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/ZERONE182/CV2-Final into main
Browse files Browse the repository at this point in the history
  • Loading branch information
FangYuhai committed May 20, 2023
1 parent 8faae01 commit 5d0f44c
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 27 deletions.
28 changes: 18 additions & 10 deletions diffusion/SRNdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import pickle
import random
import numpy as np
import torch
import torchvision.transforms.functional
from PIL import Image
from torch.utils.data import Dataset, DataLoader

Expand Down Expand Up @@ -49,6 +51,7 @@ def __init__(self, split, path='./data/SRN/cars_train', pickle_file='./data/cars

random.seed(0)
random.shuffle(all_the_vid)
self.split = split
if split == 'train':
self.ids = all_the_vid[:int(len(all_the_vid) * 0.9)]
else:
Expand Down Expand Up @@ -83,19 +86,24 @@ def __getitem__(self, idx):
poses.append(pose)

imgs = np.stack(imgs, 0)
hue_delta = 0
if self.split == 'train':
hue_delta = random.random() - 0.5
adjust_img = torchvision.transforms.functional.adjust_hue(torch.Tensor(imgs[1]), hue_delta)
imgs[1] = adjust_img.numpy()
poses = np.stack(poses, 0)
R = poses[:, :3, :3]
T = poses[:, :3, 3]

return imgs, R, T, K
return imgs, R, T, K, hue_delta


if __name__ == "__main__":

from torch.utils.data import DataLoader

d = SRNDataset('train')
dd = d[0]

for ddd in dd:
print(ddd.shape)
# if __name__ == "__main__":
#
# from torch.utils.data import DataLoader
#
# d = SRNDataset('train')
# dd = d[0]
#
# for ddd in dd:
# print(ddd.shape)
13 changes: 8 additions & 5 deletions diffusion/train.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import torchvision.transforms.functional

from xunet import XUNet

import torch
Expand Down Expand Up @@ -27,7 +29,7 @@ def main(args):
shuffle=True, drop_last=True,
num_workers=args.num_workers)

model = XUNet(H=args.image_size, W=args.image_size, ch=128)
model = XUNet(use_hue_decoder=args.use_hue_decoder, H=args.image_size, W=args.image_size, ch=128)
model = torch.nn.DataParallel(model)
model.to(utils.dev())

Expand Down Expand Up @@ -63,7 +65,7 @@ def train(model, optimizer, loader, loader_val, writer, now, step, args):
for e in range(args.num_epochs):
print(f'starting epoch {e}')

for img, R, T, K in tqdm(loader):
for img, R, T, K, hue_delta in tqdm(loader):
# validation(model, loader_val, writer, step, args.timesteps, args.batch_size)

warmup(optimizer, step, args.warmup_step / args.batch_size, args.lr)
Expand All @@ -74,8 +76,8 @@ def train(model, optimizer, loader, loader_val, writer, now, step, args):

logsnr = utils.logsnr_schedule_cosine(torch.rand((B, )))

loss = utils.p_losses(model, img=img, R=R, T=T, K=K, logsnr=logsnr,
loss_type="l2", cond_prob=0.1)
loss = utils.p_losses(model, img=img, R=R, T=T, K=K, logsnr=logsnr, hue_delta=hue_delta,
loss_type="l2", cond_prob=0.1, use_hue_loss=args.use_hue_decoder)
loss.backward()
optimizer.step()

Expand All @@ -102,7 +104,7 @@ def train(model, optimizer, loader, loader_val, writer, now, step, args):
def validation(model, loader_val, writer, step, timesteps, batch_size=8):
model.eval()
with torch.no_grad():
ori_img, R, T, K = next(iter(loader_val))
ori_img, R, T, K, hue_delta = next(iter(loader_val))
w = torch.tensor([3.0] * batch_size)
img = utils.sample(model, img=ori_img, R=R, T=T, K=K, w=w, timesteps=timesteps)

Expand Down Expand Up @@ -139,5 +141,6 @@ def validation(model, loader_val, writer, step, timesteps, batch_size=8):
parser.add_argument('--save_interval', type=int, default=20)
parser.add_argument('--timesteps', type=int, default=256)
parser.add_argument('--save_path', type=str, default="./results")
parser.add_argument('--use_hue_decoder', action='store_true')
opts = parser.parse_args()
main(opts)
31 changes: 26 additions & 5 deletions diffusion/xunet.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import utils
from typing import Tuple


lazy = numpy_utils.lazy
etils.enp.linalg._tf_or_xnp = lambda x: lazy.get_xnp(x)

Expand Down Expand Up @@ -367,10 +366,10 @@ class XUNet(torch.nn.Module):
use_pos_emb: bool = True
use_ref_pose_emb: bool = True

def __init__(self, **kwargs):
def __init__(self, use_hue_decoder=False, **kwargs):
self.__dict__.update(kwargs)
super().__init__()

self.use_hue_decoder = use_hue_decoder
assert self.H % (2 ** (
len(self.ch_mult) - 1)) == 0, f"Size of the image must me multiple of {2 ** (len(self.ch_mult) - 1)}"
assert self.W % (2 ** (
Expand Down Expand Up @@ -426,6 +425,21 @@ def __init__(self, **kwargs):
attn_heads=self.attn_heads,
use_attn=self.num_resolutions in self.attn_resolutions)

# hue_delta prediction
if self.use_hue_decoder:
self.hue_decoder = torch.nn.Sequential(
torch.nn.Conv2d(self.dim_out[-1] * 2, self.dim_out[-1] * 4, kernel_size=3, padding=1), # 1024 -> 2048
torch.nn.AvgPool2d(kernel_size=2, stride=2), # 8x8 -> 4x4
torch.nn.ReLU(),
torch.nn.Conv2d(self.dim_out[-1] * 4, self.dim_out[-1] // 4, kernel_size=1, stride=1), # 2048 -> 128
torch.nn.ReLU(),
torch.nn.Flatten(start_dim=1), # 128x4x4 -> 2048
                torch.nn.Linear(2048, 256), # 2048 -> 256
torch.nn.ReLU(),
torch.nn.Linear(256, 1), # 256 -> 1
torch.nn.Tanh()
)

# Downsampling
self.upsample = torch.nn.ModuleDict()
for i_level in reversed(range(self.num_resolutions)):
Expand Down Expand Up @@ -508,7 +522,9 @@ def forward(self, batch, *, cond_mask):
emb = logsnr_emb[..., None, None] + pose_embs[-1]

h = self.middle(h, emb)

if self.use_hue_decoder:
hue_delta = self.hue_decoder(rearrange(h, 'b f c h w -> b (f c) h w'))
# hue_delta = rearrange(hue_delta, '(b f) d -> b f d', b=B, f=2)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
emb = logsnr_emb[..., None, None] + pose_embs[i_level]
Expand All @@ -524,7 +540,12 @@ def forward(self, batch, *, cond_mask):
assert not hs # check hs is empty

h = torch.nn.functional.silu(self.lastgn(h)) # [B, F, self.ch, 128, 128]
return rearrange(self.lastconv(rearrange(h, 'b f c h w -> (b f) c h w')), '(b f) c h w -> b f c h w', b=B)[:, 1]
pred_noise = self.lastconv(rearrange(h, 'b f c h w -> (b f) c h w'))
pred_noise = rearrange(pred_noise, '(b f) c h w -> b f c h w', b=B)[:, 1]
if self.use_hue_decoder:
return pred_noise, hue_delta
else:
return pred_noise

# if __name__ == "__main__":
# h, w = 56, 56
Expand Down
33 changes: 26 additions & 7 deletions utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import torch
import numpy as np
import torch.nn.functional as F
import torchvision.transforms.functional
from tqdm import tqdm
from matplotlib import pyplot as plt


def dev():
return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Expand Down Expand Up @@ -38,7 +40,8 @@ def q_sample(z, logsnr, noise):
return alpha * z + sigma * noise


def p_losses(denoise_model, img, R, T, K, logsnr, noise=None, loss_type="l2", cond_prob=0.1, use_color_loss=False):
def p_losses(denoise_model, img, R, T, K, logsnr, hue_delta, noise=None, loss_type="l2", cond_prob=0.1,
use_color_loss=False, use_hue_loss=False):
B, N, C, H, W = img.shape
x = img[:, 0]
z = img[:, 1]
Expand All @@ -52,9 +55,10 @@ def p_losses(denoise_model, img, R, T, K, logsnr, noise=None, loss_type="l2", co
x_condition = torch.where(cond_mask[:, None, None, None], x, torch.randn_like(x))

batch = xt2batch(x=x_condition, logsnr=logsnr, z=z_noisy, R=R, T=T, K=K)

predicted_noise = denoise_model(batch, cond_mask=cond_mask)

if use_hue_loss:
predicted_noise, hue_pred = denoise_model(batch, cond_mask=cond_mask)
else:
predicted_noise = denoise_model(batch, cond_mask=cond_mask)
if loss_type == 'l1':
loss = F.l1_loss(noise.to(dev()), predicted_noise)
elif loss_type == 'l2':
Expand All @@ -71,7 +75,16 @@ def p_losses(denoise_model, img, R, T, K, logsnr, noise=None, loss_type="l2", co
color_loss = torch.nn.MSELoss()
color_loss = ((logsnr.to(dev()) + 20) / 20 * color_loss(img_color_mean.to(dev()), rec_color_mean)).mean()
return loss + color_loss
return loss
if use_hue_loss:
x = x * 0.5 + 0.5
recovered_img = torch.stack(
[torchvision.transforms.functional.adjust_hue(x[i], hue_delta[i]) for i in range(B)])
hue_loss_weight = F.mse_loss(recovered_img.to(dev()) * 255, x.to(dev()) * 255)
hue_loss = 0.01 * F.mse_loss(hue_pred.squeeze(), hue_delta.to(hue_pred))
hue_loss = hue_loss_weight * hue_loss
return loss + hue_loss
else:
return loss


@torch.no_grad()
Expand Down Expand Up @@ -133,9 +146,15 @@ def p_mean_variance(model, x, z, R, T, K, logsnr, logsnr_next, w=2.0):

batch = xt2batch(x, logsnr.repeat(b), z, R, T, K)

pred_noise = model(batch, cond_mask=torch.tensor([True] * b)).detach().cpu()
if model.module.use_hue_decoder:
pred_noise, _ = model(batch, cond_mask=torch.tensor([True] * b))
pred_noise = pred_noise.detach().cpu()
pred_noise_unconditioned, _ = model(batch, cond_mask=torch.tensor([False] * b))
pred_noise_unconditioned = pred_noise_unconditioned.detach().cpu()
else:
pred_noise = model(batch, cond_mask=torch.tensor([True] * b)).detach().cpu()
pred_noise_unconditioned = model(batch, cond_mask=torch.tensor([False] * b)).detach().cpu()
batch['x'] = torch.randn_like(x)
pred_noise_unconditioned = model(batch, cond_mask=torch.tensor([False] * b)).detach().cpu()

pred_noise_final = (1 + w) * pred_noise - w * pred_noise_unconditioned

Expand Down

0 comments on commit 5d0f44c

Please sign in to comment.