Skip to content

Commit

Permalink
fix bug in MSE color mean loss
Browse files Browse the repository at this point in the history
  • Loading branch information
cuttle-fish-my committed May 19, 2023
1 parent 9c4f8a7 commit 2ea23c9
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,10 @@ def p_losses(denoise_model, img, R, T, K, logsnr, noise=None, loss_type="l2", co

rec_img = reconstruct_z_start(z_noisy.to(dev()), predicted_noise, logsnr.to(dev()))
img_color_mean = torch.mean(z, dim=(2, 3))
rec_color_mean = torch.mean(rec_img, dim=(2, 3)).to(dev())
rec_color_mean = torch.mean(rec_img, dim=(2, 3))

color_loss = torch.nn.MSELoss()
color_loss = (logsnr.to(dev()) + 20) / 20 * color_loss(img_color_mean, rec_color_mean)
color_loss = (logsnr.to(dev()) + 20) / 20 * color_loss(img_color_mean.to(dev()), rec_color_mean)

return loss + color_loss

Expand Down

0 comments on commit 2ea23c9

Please sign in to comment.