Skip to content

Commit

Permalink
fix bugs in sampling.py
Browse files Browse the repository at this point in the history
  • Loading branch information
cuttle-fish-my committed May 22, 2023
1 parent 5ce5620 commit 87e4601
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 2 deletions.
4 changes: 3 additions & 1 deletion diffusion/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def main(args):
model = torch.nn.DataParallel(model)
model.to(utils.dev())

ckpt = torch.load(args.model)
ckpt = torch.load(args.model, map_location=torch.device(utils.dev()))
model.load_state_dict(ckpt['model'])

w = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
Expand All @@ -54,6 +54,8 @@ def main(args):
data_Rs[0],
data_Ts[0]]]

data_K = data_K.repeat(b, 1, 1)

result_dir = os.path.join(args.result_dir, os.path.basename(args.target))

os.makedirs(os.path.join(result_dir, '0'), exist_ok=True)
Expand Down
2 changes: 1 addition & 1 deletion diffusion/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def train(model, optimizer, loader, loader_val, writer, now, step, args):
print(f'starting epoch {e}')

for img, R, T, K, hue_delta in tqdm(loader):
# validation(model, loader_val, writer, step, args.timesteps, args.batch_size)
validation(model, loader_val, writer, step, args.timesteps, args.batch_size)

warmup(optimizer, step, args.warmup_step / args.batch_size, args.lr)

Expand Down
9 changes: 9 additions & 0 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,15 @@ def p_losses(denoise_model, img, R, T, K, logsnr, hue_delta, noise=None, loss_ty
loss = F.smooth_l1_loss(noise.to(dev()), predicted_noise)
else:
raise NotImplementedError()
# rec_img = reconstruct_z_start(z_noisy.to(dev()), predicted_noise, logsnr.to(dev()))
# plt.subplot(1, 2, 1)
# plt.imshow(rec_img[0].permute(1, 2, 0).detach().cpu().numpy())
# plt.axis('off')
# plt.subplot(1, 2, 2)
# plt.imshow(z[0].permute(1, 2, 0).detach().cpu().numpy())
# plt.axis('off')
# plt.savefig("test.png")
# plt.show()
if use_color_loss:
rec_img = reconstruct_z_start(z_noisy.to(dev()), predicted_noise, logsnr.to(dev()))
img_color_mean = torch.mean(z, dim=(2, 3))
Expand Down

0 comments on commit 87e4601

Please sign in to comment.