Skip to content

Commit

Permalink
fix bug in MSE color mean loss
Browse files Browse the repository at this point in the history
  • Loading branch information
cuttle-fish-my committed May 19, 2023
1 parent 45f2bca commit b2934c8
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def q_sample(z, logsnr, noise):
return alpha * z + sigma * noise


def p_losses(denoise_model, img, R, T, K, logsnr, noise=None, loss_type="l2", cond_prob=0.1):
def p_losses(denoise_model, img, R, T, K, logsnr, noise=None, loss_type="l2", cond_prob=0.1, use_color_loss=False):
B, N, C, H, W = img.shape
x = img[:, 0]
z = img[:, 1]
Expand All @@ -63,15 +63,15 @@ def p_losses(denoise_model, img, R, T, K, logsnr, noise=None, loss_type="l2", co
loss = F.smooth_l1_loss(noise.to(dev()), predicted_noise)
else:
raise NotImplementedError()

rec_img = reconstruct_z_start(z_noisy.to(dev()), predicted_noise, logsnr.to(dev()))
img_color_mean = torch.mean(z, dim=(2, 3))
rec_color_mean = torch.mean(rec_img, dim=(2, 3))

color_loss = torch.nn.MSELoss()
color_loss = ((logsnr.to(dev()) + 20) / 20 * color_loss(img_color_mean.to(dev()), rec_color_mean)).mean()

return loss + color_loss
if use_color_loss:
rec_img = reconstruct_z_start(z_noisy.to(dev()), predicted_noise, logsnr.to(dev()))
img_color_mean = torch.mean(z, dim=(2, 3))
rec_color_mean = torch.mean(rec_img, dim=(2, 3))

color_loss = torch.nn.MSELoss()
color_loss = ((logsnr.to(dev()) + 20) / 20 * color_loss(img_color_mean.to(dev()), rec_color_mean)).mean()
return loss + color_loss
return loss


@torch.no_grad()
Expand Down

0 comments on commit b2934c8

Please sign in to comment.