Skip to content

Commit

Permalink
add hue_loss
Browse files Browse the repository at this point in the history
  • Loading branch information
cuttle-fish-my committed May 19, 2023
1 parent 461a2e8 commit 2f6f4f9
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 18 deletions.
28 changes: 18 additions & 10 deletions diffusion/SRNdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import pickle
import random
import numpy as np
import torch
import torchvision.transforms.functional
from PIL import Image
from torch.utils.data import Dataset, DataLoader

Expand Down Expand Up @@ -49,6 +51,7 @@ def __init__(self, split, path='./data/SRN/cars_train', pickle_file='./data/cars

random.seed(0)
random.shuffle(all_the_vid)
self.split = split
if split == 'train':
self.ids = all_the_vid[:int(len(all_the_vid) * 0.9)]
else:
Expand Down Expand Up @@ -83,19 +86,24 @@ def __getitem__(self, idx):
poses.append(pose)

imgs = np.stack(imgs, 0)
hue_delta = 0
if self.split == 'train':
hue_delta = random.random() - 0.5
adjust_img = torchvision.transforms.functional.adjust_hue(torch.Tensor(imgs[1]), hue_delta)
imgs[1] = adjust_img.numpy()
poses = np.stack(poses, 0)
R = poses[:, :3, :3]
T = poses[:, :3, 3]

return imgs, R, T, K
return imgs, R, T, K, hue_delta


if __name__ == "__main__":

from torch.utils.data import DataLoader

d = SRNDataset('train')
dd = d[0]

for ddd in dd:
print(ddd.shape)
# if __name__ == "__main__":
#
# from torch.utils.data import DataLoader
#
# d = SRNDataset('train')
# dd = d[0]
#
# for ddd in dd:
# print(ddd.shape)
8 changes: 5 additions & 3 deletions diffusion/train.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import torchvision.transforms.functional

from xunet import XUNet

import torch
Expand Down Expand Up @@ -63,7 +65,7 @@ def train(model, optimizer, loader, loader_val, writer, now, step, args):
for e in range(args.num_epochs):
print(f'starting epoch {e}')

for img, R, T, K in tqdm(loader):
for img, R, T, K, hue_delta in tqdm(loader):
# validation(model, loader_val, writer, step, args.timesteps, args.batch_size)

warmup(optimizer, step, args.warmup_step / args.batch_size, args.lr)
Expand All @@ -74,7 +76,7 @@ def train(model, optimizer, loader, loader_val, writer, now, step, args):

logsnr = utils.logsnr_schedule_cosine(torch.rand((B, )))

loss = utils.p_losses(model, img=img, R=R, T=T, K=K, logsnr=logsnr,
loss = utils.p_losses(model, img=img, R=R, T=T, K=K, logsnr=logsnr, hue_delta=hue_delta,
loss_type="l2", cond_prob=0.1)
loss.backward()
optimizer.step()
Expand Down Expand Up @@ -102,7 +104,7 @@ def train(model, optimizer, loader, loader_val, writer, now, step, args):
def validation(model, loader_val, writer, step, timesteps, batch_size=8):
model.eval()
with torch.no_grad():
ori_img, R, T, K = next(iter(loader_val))
ori_img, R, T, K, hue_delta = next(iter(loader_val))
w = torch.tensor([3.0] * batch_size)
img = utils.sample(model, img=ori_img, R=R, T=T, K=K, w=w, timesteps=timesteps)

Expand Down
21 changes: 19 additions & 2 deletions diffusion/xunet.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import utils
from typing import Tuple


lazy = numpy_utils.lazy
etils.enp.linalg._tf_or_xnp = lambda x: lazy.get_xnp(x)

Expand Down Expand Up @@ -426,6 +425,20 @@ def __init__(self, **kwargs):
attn_heads=self.attn_heads,
use_attn=self.num_resolutions in self.attn_resolutions)

# hue_delta prediction
self.hue_decoder = torch.nn.Sequential(
torch.nn.Conv2d(self.dim_out[-1] * 2, self.dim_out[-1] * 4, kernel_size=3, padding=1), # 1024 -> 2048
torch.nn.AvgPool2d(kernel_size=2, stride=2), # 8x8 -> 4x4
torch.nn.ReLU(),
torch.nn.Conv2d(self.dim_out[-1] * 4, self.dim_out[-1] // 4, kernel_size=1, stride=1), # 2048 -> 128
torch.nn.ReLU(),
torch.nn.Flatten(start_dim=1), # 128x4x4 -> 2048
torch.nn.Linear(2048, 256), # 4098 -> 256
torch.nn.ReLU(),
torch.nn.Linear(256, 1), # 256 -> 1
torch.nn.Tanh()
)

# Downsampling
self.upsample = torch.nn.ModuleDict()
for i_level in reversed(range(self.num_resolutions)):
Expand Down Expand Up @@ -509,6 +522,8 @@ def forward(self, batch, *, cond_mask):

h = self.middle(h, emb)

hue_delta = self.hue_decoder(rearrange(h, 'b f c h w -> b (f c) h w'))
# hue_delta = rearrange(hue_delta, '(b f) d -> b f d', b=B, f=2)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
emb = logsnr_emb[..., None, None] + pose_embs[i_level]
Expand All @@ -524,7 +539,9 @@ def forward(self, batch, *, cond_mask):
assert not hs # check hs is empty

h = torch.nn.functional.silu(self.lastgn(h)) # [B, F, self.ch, 128, 128]
return rearrange(self.lastconv(rearrange(h, 'b f c h w -> (b f) c h w')), '(b f) c h w -> b f c h w', b=B)[:, 1]
pred_noise = self.lastconv(rearrange(h, 'b f c h w -> (b f) c h w'))
pred_noise = rearrange(pred_noise, '(b f) c h w -> b f c h w', b=B)[:, 1]
return pred_noise, hue_delta

# if __name__ == "__main__":
# h, w = 56, 56
Expand Down
15 changes: 12 additions & 3 deletions utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import torch
import numpy as np
import torch.nn.functional as F
import torchvision.transforms.functional
from tqdm import tqdm
from matplotlib import pyplot as plt


def dev():
return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Expand Down Expand Up @@ -38,7 +40,8 @@ def q_sample(z, logsnr, noise):
return alpha * z + sigma * noise


def p_losses(denoise_model, img, R, T, K, logsnr, noise=None, loss_type="l2", cond_prob=0.1, use_color_loss=False):
def p_losses(denoise_model, img, R, T, K, logsnr, hue_delta, noise=None, loss_type="l2", cond_prob=0.1,
use_color_loss=False):
B, N, C, H, W = img.shape
x = img[:, 0]
z = img[:, 1]
Expand All @@ -53,8 +56,11 @@ def p_losses(denoise_model, img, R, T, K, logsnr, noise=None, loss_type="l2", co

batch = xt2batch(x=x_condition, logsnr=logsnr, z=z_noisy, R=R, T=T, K=K)

predicted_noise = denoise_model(batch, cond_mask=cond_mask)
predicted_noise, hue_pred = denoise_model(batch, cond_mask=cond_mask)

recovered_img = torch.stack([torchvision.transforms.functional.adjust_hue(x[i], hue_delta[i]) for i in range(B)])
pred_recovered_img = torch.stack([torchvision.transforms.functional.adjust_hue(predicted_noise[i], hue_pred[i])
for i in range(B)])
if loss_type == 'l1':
loss = F.l1_loss(noise.to(dev()), predicted_noise)
elif loss_type == 'l2':
Expand All @@ -71,7 +77,10 @@ def p_losses(denoise_model, img, R, T, K, logsnr, noise=None, loss_type="l2", co
color_loss = torch.nn.MSELoss()
color_loss = ((logsnr.to(dev()) + 20) / 20 * color_loss(img_color_mean.to(dev()), rec_color_mean)).mean()
return loss + color_loss
return loss
hue_loss_weight = F.mse_loss(recovered_img.to(dev()), pred_recovered_img)
hue_loss = F.mse_loss(hue_pred.squeeze(), hue_delta.to(hue_pred))
hue_loss = hue_loss_weight * hue_loss
return loss + hue_loss


@torch.no_grad()
Expand Down

0 comments on commit 2f6f4f9

Please sign in to comment.