From 2ea23c9e65cf0e8ad2af6a0128f778bfd0b7b391 Mon Sep 17 00:00:00 2001 From: libn Date: Fri, 19 May 2023 15:10:18 +0800 Subject: [PATCH] fix bug in MSE color mean loss --- utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/utils.py b/utils.py index 4a12f04..d27d7cd 100644 --- a/utils.py +++ b/utils.py @@ -66,10 +66,10 @@ def p_losses(denoise_model, img, R, T, K, logsnr, noise=None, loss_type="l2", co rec_img = reconstruct_z_start(z_noisy.to(dev()), predicted_noise, logsnr.to(dev())) img_color_mean = torch.mean(z, dim=(2, 3)) - rec_color_mean = torch.mean(rec_img, dim=(2, 3)).to(dev()) + rec_color_mean = torch.mean(rec_img, dim=(2, 3)) color_loss = torch.nn.MSELoss() - color_loss = (logsnr.to(dev()) + 20) / 20 * color_loss(img_color_mean, rec_color_mean) + color_loss = (logsnr.to(dev()) + 20) / 20 * color_loss(img_color_mean.to(dev()), rec_color_mean) return loss + color_loss