From 4257c9804698aaa26d3c1468d8b6dc3f7fb30dea Mon Sep 17 00:00:00 2001 From: libn Date: Fri, 19 May 2023 15:12:38 +0800 Subject: [PATCH] fix bug in MSE color mean loss --- utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils.py b/utils.py index d27d7cd..28af36a 100644 --- a/utils.py +++ b/utils.py @@ -69,7 +69,7 @@ def p_losses(denoise_model, img, R, T, K, logsnr, noise=None, loss_type="l2", co rec_color_mean = torch.mean(rec_img, dim=(2, 3)) color_loss = torch.nn.MSELoss() - color_loss = (logsnr.to(dev()) + 20) / 20 * color_loss(img_color_mean.to(dev()), rec_color_mean) + color_loss = (logsnr.to(dev()) + 20) / 20 * color_loss(img_color_mean.to(dev()), rec_color_mean).mean() return loss + color_loss