diff --git a/utils.py b/utils.py index 28af36a..4a636cc 100644 --- a/utils.py +++ b/utils.py @@ -69,7 +69,7 @@ def p_losses(denoise_model, img, R, T, K, logsnr, noise=None, loss_type="l2", co rec_color_mean = torch.mean(rec_img, dim=(2, 3)) color_loss = torch.nn.MSELoss() - color_loss = (logsnr.to(dev()) + 20) / 20 * color_loss(img_color_mean.to(dev()), rec_color_mean).mean() + color_loss = ((logsnr.to(dev()) + 20) / 20 * color_loss(img_color_mean.to(dev()), rec_color_mean)).mean() return loss + color_loss