From e23e6e2ace600bb2cb435db5f3d4ff5932c41ec7 Mon Sep 17 00:00:00 2001 From: libn Date: Sat, 20 May 2023 02:23:47 +0800 Subject: [PATCH] fix bugs in hue_loss --- diffusion/train.py | 5 +++-- utils.py | 23 ++++++++++++++--------- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/diffusion/train.py b/diffusion/train.py index 0578e1f..48fad31 100644 --- a/diffusion/train.py +++ b/diffusion/train.py @@ -29,7 +29,7 @@ def main(args): shuffle=True, drop_last=True, num_workers=args.num_workers) - model = XUNet(H=args.image_size, W=args.image_size, ch=128) + model = XUNet(use_hue_decoder=args.use_hue_decoder, H=args.image_size, W=args.image_size, ch=128) model = torch.nn.DataParallel(model) model.to(utils.dev()) @@ -77,7 +77,7 @@ def train(model, optimizer, loader, loader_val, writer, now, step, args): logsnr = utils.logsnr_schedule_cosine(torch.rand((B, ))) loss = utils.p_losses(model, img=img, R=R, T=T, K=K, logsnr=logsnr, hue_delta=hue_delta, - loss_type="l2", cond_prob=0.1) + loss_type="l2", cond_prob=0.1, use_hue_loss=args.use_hue_decoder) loss.backward() optimizer.step() @@ -141,5 +141,6 @@ def validation(model, loader_val, writer, step, timesteps, batch_size=8): parser.add_argument('--save_interval', type=int, default=20) parser.add_argument('--timesteps', type=int, default=256) parser.add_argument('--save_path', type=str, default="./results") + parser.add_argument('--use_hue_decoder', action='store_true') opts = parser.parse_args() main(opts) diff --git a/utils.py b/utils.py index b3f2234..ecc0c80 100644 --- a/utils.py +++ b/utils.py @@ -41,7 +41,7 @@ def q_sample(z, logsnr, noise): def p_losses(denoise_model, img, R, T, K, logsnr, hue_delta, noise=None, loss_type="l2", cond_prob=0.1, - use_color_loss=False): + use_color_loss=False, use_hue_loss=False): B, N, C, H, W = img.shape x = img[:, 0] z = img[:, 1] @@ -55,10 +55,10 @@ def p_losses(denoise_model, img, R, T, K, logsnr, hue_delta, noise=None, loss_ty x_condition = torch.where(cond_mask[:, None, None, None], x, torch.randn_like(x)) batch = xt2batch(x=x_condition, logsnr=logsnr, z=z_noisy, R=R, T=T, K=K) - - predicted_noise, hue_pred = denoise_model(batch, cond_mask=cond_mask) - x = x * 0.5 + 0.5 - recovered_img = torch.stack([torchvision.transforms.functional.adjust_hue(x[i], hue_delta[i]) for i in range(B)]) + if use_hue_loss: + predicted_noise, hue_pred = denoise_model(batch, cond_mask=cond_mask) + else: + predicted_noise = denoise_model(batch, cond_mask=cond_mask) if loss_type == 'l1': loss = F.l1_loss(noise.to(dev()), predicted_noise) elif loss_type == 'l2': @@ -75,10 +75,15 @@ def p_losses(denoise_model, img, R, T, K, logsnr, hue_delta, noise=None, loss_ty color_loss = torch.nn.MSELoss() color_loss = ((logsnr.to(dev()) + 20) / 20 * color_loss(img_color_mean.to(dev()), rec_color_mean)).mean() return loss + color_loss - hue_loss_weight = F.mse_loss(recovered_img.to(dev()) * 255, x.to(dev()) * 255) - hue_loss = 0.1 * F.mse_loss(hue_pred.squeeze(), hue_delta.to(hue_pred)) - hue_loss = hue_loss_weight * hue_loss - return loss + hue_loss + if use_hue_loss: + x = x * 0.5 + 0.5 + recovered_img = torch.stack([torchvision.transforms.functional.adjust_hue(x[i], hue_delta[i]) for i in range(B)]) + hue_loss_weight = F.mse_loss(recovered_img.to(dev()) * 255, x.to(dev()) * 255) + hue_loss = 0.1 * F.mse_loss(hue_pred.squeeze(), hue_delta.to(hue_pred)) + hue_loss = hue_loss_weight * hue_loss + return loss + hue_loss + else: + return loss @torch.no_grad()