Skip to content

Commit

Permalink
add validation and vis in VAE
Browse files Browse the repository at this point in the history
  • Loading branch information
nick-yf committed May 24, 2023
1 parent 3e8c204 commit bf60345
Showing 1 changed file with 29 additions and 5 deletions.
34 changes: 29 additions & 5 deletions VAE/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from tensorboardX import SummaryWriter
import os
import argparse
import cv2


def main(args):
Expand Down Expand Up @@ -49,12 +50,17 @@ def main(args):
model.load_state_dict(ckpt['model'])
optimizer.load_state_dict(ckpt['optim'])

now = args.transfer
# now = args.transfer + "_val"
if not args.val:
now = args.transfer
else:
now = args.transfer + "val"
writer = SummaryWriter(now)
step = ckpt['step']
# validation(model, loader, writer, step, args.batch_size)
train(model, optimizer, loader, loader_val, writer, now, step, args)
if args.val:
for i in range(120):
validation(model, loader, writer, step+i, args.batch_size, save_path=now)
else:
train(model, optimizer, loader, loader_val, writer, now, step, args)


def warmup(optimizer, step, last_step, last_lr):
Expand Down Expand Up @@ -112,7 +118,7 @@ def train(model, optimizer, loader, loader_val, writer, now, step, args):
now + f"/latest.pt")


def validation(model, loader_val, writer, step, batch_size=8, device='cuda'):
def validation(model, loader_val, writer, step, batch_size=8, device='cuda', save_path=None):
# TODO: Add image writer
model.eval()
with torch.no_grad():
Expand Down Expand Up @@ -140,6 +146,23 @@ def validation(model, loader_val, writer, step, batch_size=8, device='cuda'):
writer.add_images(f"train/pred",pred_img, step)
writer.add_images(f"train/recon", recon_img, step)

# save image locally
if save_path is not None:
save_input = np.transpose(np.copy(input_img[0]), (1, 2, 0))
save_recon = np.transpose(np.copy(recon_img[0]), (1, 2, 0))
save_gt = np.transpose(np.copy(gt_img[0]), (1, 2, 0))
save_pred = np.transpose(np.copy(pred_img[0]), (1, 2, 0))

save_input = cv2.cvtColor(save_input, cv2.COLOR_RGB2BGR)
save_recon = cv2.cvtColor(save_recon, cv2.COLOR_RGB2BGR)
save_gt = cv2.cvtColor(save_gt, cv2.COLOR_RGB2BGR)
save_pred = cv2.cvtColor(save_pred, cv2.COLOR_RGB2BGR)

cv2.imwrite(os.path.join(save_path, str(step) + "input.png"), save_input)
cv2.imwrite(os.path.join(save_path, str(step) + "recon.png"), save_recon)
cv2.imwrite(os.path.join(save_path, str(step) + "gt.png"), save_gt)
cv2.imwrite(os.path.join(save_path, str(step) + "pred.png"), save_pred)

# print('image sampled!')
writer.flush()
model.train()
Expand All @@ -156,6 +179,7 @@ def validation(model, loader_val, writer, step, batch_size=8, device='cuda'):
parser.add_argument('--transfer', type=str, default="")
parser.add_argument('--lr', type=float, default=1e-4)
parser.add_argument('--num_epochs', type=int, default=100000)
parser.add_argument('--val', action='store_true', default=False)

parser.add_argument('--warmup_step', type=int, default=0)
parser.add_argument('--freeze_step', type=float, default=5000)
Expand Down

0 comments on commit bf60345

Please sign in to comment.