Skip to content

Commit

Permalink
Better training interface
Browse files Browse the repository at this point in the history
  • Loading branch information
yenchenlin committed Apr 18, 2020
1 parent c3ccc0b commit 7158181
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions run_nerf.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from tqdm import tqdm, trange

import matplotlib.pyplot as plt

Expand Down Expand Up @@ -673,7 +673,7 @@ def train():
rays_rgb = torch.Tensor(rays_rgb).to(device)


N_iters = 1000000
N_iters = 200000 + 1
print('Begin')
print('TRAIN views are', i_train)
print('TEST views are', i_test)
Expand All @@ -682,7 +682,7 @@ def train():
# Summary writers
# writer = SummaryWriter(os.path.join(basedir, 'summaries', expname))

for i in range(start, N_iters):
for i in trange(start, N_iters):
time0 = time.time()

# Sample random ray batch
Expand Down Expand Up @@ -745,7 +745,7 @@ def train():
################################

dt = time.time()-time0
print(f"Step: {global_step}, Loss: {loss}, Time: {dt}")
# print(f"Step: {global_step}, Loss: {loss}, Time: {dt}")
##### end #####

# Rest is logging
Expand Down Expand Up @@ -784,11 +784,13 @@ def train():
print('Saved test set')


"""

if i%args.i_print==0 or i < 10:
tqdm.write(f"[TRAIN] Iter: {i} Loss: {loss.item()} PSNR: {psnr.item()}")
"""
print(expname, i, psnr.numpy(), loss.numpy(), global_step.numpy())
print('iter time {:.05f}'.format(dt))
with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_print):
tf.contrib.summary.scalar('loss', loss)
tf.contrib.summary.scalar('psnr', psnr)
Expand Down

0 comments on commit 7158181

Please sign in to comment.