
Commit

fix bugs in hue_loss
cuttle-fish-my committed May 19, 2023
1 parent fcf031b commit e23e6e2
Showing 2 changed files with 17 additions and 11 deletions.
5 changes: 3 additions & 2 deletions diffusion/train.py
@@ -29,7 +29,7 @@ def main(args):
                                         shuffle=True, drop_last=True,
                                         num_workers=args.num_workers)

-    model = XUNet(H=args.image_size, W=args.image_size, ch=128)
+    model = XUNet(use_hue_decoder=args.use_hue_decoder, H=args.image_size, W=args.image_size, ch=128)
    model = torch.nn.DataParallel(model)
    model.to(utils.dev())

@@ -77,7 +77,7 @@ def train(model, optimizer, loader, loader_val, writer, now, step, args):
        logsnr = utils.logsnr_schedule_cosine(torch.rand((B, )))

        loss = utils.p_losses(model, img=img, R=R, T=T, K=K, logsnr=logsnr, hue_delta=hue_delta,
-                              loss_type="l2", cond_prob=0.1)
+                              loss_type="l2", cond_prob=0.1, use_hue_loss=args.use_hue_decoder)
        loss.backward()
        optimizer.step()

@@ -141,5 +141,6 @@ def validation(model, loader_val, writer, step, timesteps, batch_size=8):
    parser.add_argument('--save_interval', type=int, default=20)
    parser.add_argument('--timesteps', type=int, default=256)
    parser.add_argument('--save_path', type=str, default="./results")
+    parser.add_argument('--use_hue_decoder', action='store_true')
    opts = parser.parse_args()
    main(opts)
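
Note on the flag wiring: --use_hue_decoder is forwarded to the XUNet constructor and, as use_hue_loss, to utils.p_losses, which in turn expects the model to return a (predicted_noise, hue_pred) pair when the decoder is enabled and a single noise tensor otherwise. Below is a minimal sketch of that return-shape contract only; DummyXUNet is a hypothetical stand-in, not the repository's real XUNet.

import torch


class DummyXUNet(torch.nn.Module):
    """Stand-in for XUNet; only the return shapes matter for this sketch."""

    def __init__(self, use_hue_decoder=False):
        super().__init__()
        self.use_hue_decoder = use_hue_decoder

    def forward(self, batch, cond_mask=None):
        noise_pred = torch.zeros(2, 3, 64, 64)   # placeholder denoising output
        if self.use_hue_decoder:
            hue_pred = torch.zeros(2, 1)         # placeholder hue prediction
            return noise_pred, hue_pred          # two outputs when the decoder is on
        return noise_pred                        # single tensor otherwise


model = DummyXUNet(use_hue_decoder=True)
out = model(batch=None)
predicted_noise, hue_pred = out if isinstance(out, tuple) else (out, None)
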
23 changes: 14 additions & 9 deletions utils.py
@@ -41,7 +41,7 @@ def q_sample(z, logsnr, noise):


def p_losses(denoise_model, img, R, T, K, logsnr, hue_delta, noise=None, loss_type="l2", cond_prob=0.1,
-             use_color_loss=False):
+             use_color_loss=False, use_hue_loss=False):
    B, N, C, H, W = img.shape
    x = img[:, 0]
    z = img[:, 1]
@@ -55,10 +55,10 @@ def p_losses(denoise_model, img, R, T, K, logsnr, hue_delta, noise=None, loss_ty
    x_condition = torch.where(cond_mask[:, None, None, None], x, torch.randn_like(x))

    batch = xt2batch(x=x_condition, logsnr=logsnr, z=z_noisy, R=R, T=T, K=K)
-
-    predicted_noise, hue_pred = denoise_model(batch, cond_mask=cond_mask)
-    x = x * 0.5 + 0.5
-    recovered_img = torch.stack([torchvision.transforms.functional.adjust_hue(x[i], hue_delta[i]) for i in range(B)])
+    if use_hue_loss:
+        predicted_noise, hue_pred = denoise_model(batch, cond_mask=cond_mask)
+    else:
+        predicted_noise = denoise_model(batch, cond_mask=cond_mask)
    if loss_type == 'l1':
        loss = F.l1_loss(noise.to(dev()), predicted_noise)
    elif loss_type == 'l2':
@@ -75,10 +75,15 @@ def p_losses(denoise_model, img, R, T, K, logsnr, hue_delta, noise=None, loss_ty
        color_loss = torch.nn.MSELoss()
        color_loss = ((logsnr.to(dev()) + 20) / 20 * color_loss(img_color_mean.to(dev()), rec_color_mean)).mean()
        return loss + color_loss
-    hue_loss_weight = F.mse_loss(recovered_img.to(dev()) * 255, x.to(dev()) * 255)
-    hue_loss = 0.1 * F.mse_loss(hue_pred.squeeze(), hue_delta.to(hue_pred))
-    hue_loss = hue_loss_weight * hue_loss
-    return loss + hue_loss
+    if use_hue_loss:
+        x = x * 0.5 + 0.5
+        recovered_img = torch.stack([torchvision.transforms.functional.adjust_hue(x[i], hue_delta[i]) for i in range(B)])
+        hue_loss_weight = F.mse_loss(recovered_img.to(dev()) * 255, x.to(dev()) * 255)
+        hue_loss = 0.1 * F.mse_loss(hue_pred.squeeze(), hue_delta.to(hue_pred))
+        hue_loss = hue_loss_weight * hue_loss
+        return loss + hue_loss
+    else:
+        return loss


@torch.no_grad()
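To make the weighting concrete: the hue branch scales the base hue regression loss (0.1 x MSE between hue_pred and hue_delta) by hue_loss_weight, the pixel-space MSE between the original view and its hue-shifted copy, so larger hue shifts carry a larger penalty. The sketch below recomputes that branch in isolation; names mirror p_losses, but the tensors are random placeholders and x is assumed to already lie in [0, 1] (the real code first rescales it with x * 0.5 + 0.5).

import torch
import torch.nn.functional as F
import torchvision

B = 4
x = torch.rand(B, 3, 64, 64)          # reference views, assumed already in [0, 1]
hue_delta = torch.rand(B) - 0.5       # hue shifts in [-0.5, 0.5)
hue_pred = torch.rand(B, 1) - 0.5     # stand-in for the hue decoder's output

# Pixel-space weight: how much the hue shift actually changes the image (0-255 scale).
recovered_img = torch.stack(
    [torchvision.transforms.functional.adjust_hue(x[i], hue_delta[i].item()) for i in range(B)]
)
hue_loss_weight = F.mse_loss(recovered_img * 255, x * 255)

# Hue regression loss, scaled by 0.1 and then by the weight above.
hue_loss = hue_loss_weight * (0.1 * F.mse_loss(hue_pred.squeeze(), hue_delta))
print(float(hue_loss))
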
