Skip to content

Commit

Permalink
fix bugs in hue_loss
Browse files Browse the repository at this point in the history
  • Loading branch information
cuttle-fish-my committed May 20, 2023
1 parent e23e6e2 commit 3b068cd
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
2 changes: 1 addition & 1 deletion diffusion/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def train(model, optimizer, loader, loader_val, writer, now, step, args):
print(f'starting epoch {e}')

for img, R, T, K, hue_delta in tqdm(loader):
# validation(model, loader_val, writer, step, args.timesteps, args.batch_size)
validation(model, loader_val, writer, step, args.timesteps, args.batch_size)

warmup(optimizer, step, args.warmup_step / args.batch_size, args.lr)

Expand Down
13 changes: 10 additions & 3 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ def p_losses(denoise_model, img, R, T, K, logsnr, hue_delta, noise=None, loss_ty
return loss + color_loss
if use_hue_loss:
x = x * 0.5 + 0.5
recovered_img = torch.stack([torchvision.transforms.functional.adjust_hue(x[i], hue_delta[i]) for i in range(B)])
recovered_img = torch.stack(
[torchvision.transforms.functional.adjust_hue(x[i], hue_delta[i]) for i in range(B)])
hue_loss_weight = F.mse_loss(recovered_img.to(dev()) * 255, x.to(dev()) * 255)
hue_loss = 0.1 * F.mse_loss(hue_pred.squeeze(), hue_delta.to(hue_pred))
hue_loss = hue_loss_weight * hue_loss
Expand Down Expand Up @@ -145,9 +146,15 @@ def p_mean_variance(model, x, z, R, T, K, logsnr, logsnr_next, w=2.0):

batch = xt2batch(x, logsnr.repeat(b), z, R, T, K)

pred_noise = model(batch, cond_mask=torch.tensor([True] * b)).detach().cpu()
if model.module.use_hue_decoder:
pred_noise, _ = model(batch, cond_mask=torch.tensor([True] * b))
pred_noise = pred_noise.detach().cpu()
pred_noise_unconditioned, _ = model(batch, cond_mask=torch.tensor([False] * b))
pred_noise_unconditioned = pred_noise_unconditioned.detach().cpu()
else:
pred_noise = model(batch, cond_mask=torch.tensor([True] * b)).detach().cpu()
pred_noise_unconditioned = model(batch, cond_mask=torch.tensor([False] * b)).detach().cpu()
batch['x'] = torch.randn_like(x)
pred_noise_unconditioned = model(batch, cond_mask=torch.tensor([False] * b)).detach().cpu()

pred_noise_final = (1 + w) * pred_noise - w * pred_noise_unconditioned

Expand Down

0 comments on commit 3b068cd

Please sign in to comment.