From f9ab5b4c7129066b299225f7d75c75975ce7365b Mon Sep 17 00:00:00 2001 From: nick-yf Date: Sun, 21 May 2023 23:30:42 +0800 Subject: [PATCH] add input image in tensorboard --- VAE/train.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/VAE/train.py b/VAE/train.py index 27f33d3..16455d6 100644 --- a/VAE/train.py +++ b/VAE/train.py @@ -107,11 +107,13 @@ def validation(model, loader_val, writer, step, batch_size=8, device='cuda'): ori_img, R, T, K = next(iter(loader_val)) batch = {'x':ori_img[:,0].to(device), 'z':ori_img[:,1].to(device), 'R': R.to(device), 't': T.to(device), 'K': K.to(device),} + input_img = ori_img[:, 0].detach().cpu().numpy() gt_img = ori_img[:, 1].detach().cpu().numpy() gt_img = ((gt_img.clip(-1, 1)+1)*127.5).astype(np.uint8) pred_img = model.module.eval_img(batch, None).detach().cpu().numpy() pred_img = ((pred_img.clip(-1, 1)+1)*127.5).astype(np.uint8) + writer.add_images(f"train/input", input_img, step) writer.add_images(f"train/gt", gt_img, step) writer.add_images(f"train/pred",pred_img, step)