Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
kwea123 committed May 27, 2020
1 parent 47a869b commit e60e99a
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 13 deletions.
2 changes: 1 addition & 1 deletion datasets/llff.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def read_meta(self):
far*torch.ones_like(rays_o[:, :1])],
1)] # (h*w, 8)

self.all_rays = torch.cat(self.all_rays, 0) # ((N_images-1)*h*w, 3)
self.all_rays = torch.cat(self.all_rays, 0) # ((N_images-1)*h*w, 8)
self.all_rgbs = torch.cat(self.all_rgbs, 0) # ((N_images-1)*h*w, 3)

elif self.split == 'val':
Expand Down
20 changes: 8 additions & 12 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,10 @@ def training_step(self, batch, batch_nb):
rays, rgbs = self.decode_batch(batch)
results = self(rays)
log['train/loss'] = loss = self.loss(results, rgbs)
typ = 'fine' if 'rgb_fine' in results else 'coarse'

with torch.no_grad():
if 'rgb_fine' in results:
psnr_ = psnr(results['rgb_fine'], rgbs)
else:
psnr_ = psnr(results['rgb_coarse'], rgbs)
psnr_ = psnr(results[f'rgb_{typ}'], rgbs)
log['train/psnr'] = psnr_

return {'loss': loss,
Expand All @@ -124,21 +122,19 @@ def validation_step(self, batch, batch_nb):
rgbs = rgbs.squeeze() # (H*W, 3)
results = self(rays)
log = {'val_loss': self.loss(results, rgbs)}
typ = 'fine' if 'rgb_fine' in results else 'coarse'

if batch_nb == 0:
W, H = self.hparams.img_wh
img_fine = results['rgb_fine'].view(H, W, 3).cpu()
img_fine = img_fine.permute(2, 0, 1) # (3, H, W)
img = results[f'rgb_{typ}'].view(H, W, 3).cpu()
img = img.permute(2, 0, 1) # (3, H, W)
img_gt = rgbs.view(H, W, 3).permute(2, 0, 1).cpu() # (3, H, W)
depth = visualize_depth(results['depth_fine'].view(H, W)) # (3, H, W)
stack = torch.stack([img_gt, img_fine, depth]) # (3, 3, H, W)
depth = visualize_depth(results[f'depth_{typ}'].view(H, W)) # (3, H, W)
stack = torch.stack([img_gt, img, depth]) # (3, 3, H, W)
self.logger.experiment.add_images('val/GT_pred_depth',
stack, self.global_step)

if 'rgb_fine' in results:
psnr_ = psnr(results['rgb_fine'], rgbs)
else:
psnr_ = psnr(results['rgb_coarse'], rgbs)
psnr_ = psnr(results[f'rgb_{typ}'], rgbs)
log['val_psnr'] = psnr_

return log
Expand Down

0 comments on commit e60e99a

Please sign in to comment.