Skip to content

Commit

Permalink
add test_only, fix prev_sample
Browse files Browse the repository at this point in the history
  • Loading branch information
Shakiba Kheradmand committed Mar 7, 2023
1 parent 4f99ff9 commit 50e9e31
Showing 1 changed file with 94 additions and 22 deletions.
116 changes: 94 additions & 22 deletions run_nerf.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ def config_parser():
help='whether to do image level sampling or not')
parser.add_argument("--sampling_type", type=str, default="multinomial",
help='options = none / multinomial / rejection / metropolis-hastings')
parser.add_argument("--sigma", type=float, default=200.0,
parser.add_argument("--sigma", type=float, default=2.0,
help='value of sigma in case metropolis-hastings is selected')
parser.add_argument("--weight_exponential", type=float, default=1.0,
help='weight of exponential')
Expand Down Expand Up @@ -495,7 +495,8 @@ def config_parser():
help='log2 of max freq for positional encoding (2D direction)')
parser.add_argument("--raw_noise_std", type=float, default=0.,
help='std dev of noise added to regularize sigma_a output, 1e0 recommended')

parser.add_argument("--test_only", action='store_true',
help='do not optimize, reload weights and write final psnr')
parser.add_argument("--render_only", action='store_true',
help='do not optimize, reload weights and render out render_poses path')
parser.add_argument("--render_test", action='store_true',
Expand Down Expand Up @@ -694,6 +695,77 @@ def train():

return

# Short circuit if only rendering out from trained model
if args.test_only:
print('TEST ONLY')
with torch.no_grad():
import xlsxwriter
if args.ft_path is not None and args.ft_path!='None':
ckpts = [args.ft_path]
else:
ckpts = [os.path.join(basedir, expname, f) for f in sorted(os.listdir(os.path.join(basedir, expname))) if 'tar' in f]

result_path = os.path.join(basedir, expname, 'results.xlsx')
workbook = xlsxwriter.Workbook(os.path.join(basedir, expname, 'results.xlsx'))

if os.path.exists(result_path):
print("shakiba")
# check if the iteration is the same
import pandas as pd
df = pd.read_excel(result_path, usecols=[0], header=1, nrows=0, index_col=None)
itr = df.columns.values[0]
print(int(itr), int(ckpts[-1].split("/")[-1][:-4]))
if int(itr) >= int(ckpts[-1].split("/")[-1][:-4]):
print("The test results is already available for iteration", int(itr))
return

val_psnrs = 0
for num_i in i_val:
print("val", num_i)
target_val = torch.from_numpy(images[num_i]).cuda()
pose_val = torch.from_numpy(poses[num_i, :3,:4]).cuda()
rgb, disp, acc, extras = render(H, W, K, chunk=args.chunk, c2w=pose_val,
**render_kwargs_test)

psnr = mse2psnr(img2mse(rgb, target_val))
val_psnrs += psnr

train_psnrs = 0
for num_i in i_train:
print("train", num_i)
target_val = torch.from_numpy(images[num_i]).cuda()
pose_val = torch.from_numpy(poses[num_i, :3,:4]).cuda()
rgb, disp, acc, extras = render(H, W, K, chunk=args.chunk, c2w=pose_val,
**render_kwargs_test)

psnr = mse2psnr(img2mse(rgb, target_val))
train_psnrs += psnr

test_psnrs = 0
for num_i in i_test:
print("test", num_i)
target_val = torch.from_numpy(images[num_i]).cuda()
pose_val = torch.from_numpy(poses[num_i, :3,:4]).cuda()
rgb, disp, acc, extras = render(H, W, K, chunk=args.chunk, c2w=pose_val,
**render_kwargs_test)

psnr = mse2psnr(img2mse(rgb, target_val))
test_psnrs += psnr

worksheet1 = workbook.add_worksheet()
worksheet1.write(0, 0, 'iteration')
worksheet1.write(0, 1, 'val psnr')
worksheet1.write(0, 2, 'train psnr')
worksheet1.write(0, 3, 'test psnr')
worksheet1.write(0, 4, 'all psnr')
worksheet1.write(1, 0, ckpts[-1].split("/")[-1][:-4])
worksheet1.write(1, 1, val_psnrs / len(i_val))
worksheet1.write(1, 2, train_psnrs / len(i_train))
worksheet1.write(1, 3, test_psnrs / len(i_test))
worksheet1.write(1, 4, (val_psnrs + train_psnrs + test_psnrs)/(len(i_train)+len(i_val)+len(i_test)))
workbook.close()
return

# Prepare raybatch tensor if batching random rays
N_rand = args.N_rand
use_batching = not args.no_batching
Expand Down Expand Up @@ -791,7 +863,7 @@ def train():
heatmaps_all = []
heatnums_all = []
prob_all = []
prev_sample = torch.empty(0)
prev_sample = torch.zeros((images.shape[0], N_rand, 2)).long()


start = start + 1
Expand Down Expand Up @@ -870,13 +942,13 @@ def train():
coords = torch.reshape(coords, [-1,2]) # (H * W, 2)
select_inds = np.random.choice(coords.shape[0], size=[N_rand], replace=False) # (N_rand,)
select_coords = coords[select_inds].long() # (N_rand, 2)
prev_sample = select_coords
prev_sample[img_i] = select_coords
elif not args.image_sampling or args.sampling_type == "none":
coords = torch.stack(torch.meshgrid(torch.linspace(0, H-1, H), torch.linspace(0, W-1, W)), -1) # (H, W, 2)
coords = torch.reshape(coords, [-1,2]) # (H * W, 2)
select_inds = np.random.choice(coords.shape[0], size=[N_rand], replace=False) # (N_rand,)
select_coords = coords[select_inds].long() # (N_rand, 2)
prev_sample = select_coords
prev_sample[img_i] = select_coords
elif args.image_sampling and args.sampling_type == "multinomial":
# m = torch.distributions.categorical.Categorical(prob_map[img_i].flatten())
# samples = m.sample(sample_shape=(N_rand,))
Expand Down Expand Up @@ -910,27 +982,27 @@ def train():
select_coords = coords[select_inds].long() # (N_rand, 2)
num += select_coords.shape[0]
elif args.image_sampling and args.sampling_type == "metropolis-hastings":
if prev_sample.nelement() == 0:
if prev_sample[img_i].sum() == 0:
coords = torch.stack(torch.meshgrid(torch.linspace(0, H-1, H), torch.linspace(0, W-1, W)), -1) # (H, W, 2)
coords = torch.reshape(coords, [-1,2]) # (H * W, 2)
select_inds = np.random.choice(coords.shape[0], size=[N_rand], replace=False) # (N_rand,)
select_coords = coords[select_inds].long() # (N_rand, 2)
prev_sample = select_coords
prev_sample[img_i] = select_coords
else:
next_sample = prev_sample + torch.normal(mean=0, std=args.sigma, size=(prev_sample.shape))
next_sample = prev_sample[img_i] + torch.normal(mean=0, std=args.sigma, size=(prev_sample[img_i].shape))
next_sample[:, 0] = next_sample[:, 0] % (H-1)
next_sample[:, 1] = next_sample[:, 1] % (W-1)
next_sample = torch.round(next_sample).long()

prev_heat = prob_map[img_i, prev_sample[:, 0], prev_sample[:, 1]]
prev_heat = prob_map[img_i, prev_sample[img_i][:, 0], prev_sample[img_i][:, 1]]
next_heat = prob_map[img_i, next_sample[:, 0], next_sample[:, 1]]

accept_prob = next_heat / (prev_heat + 1e-7)
rand_image = torch.rand(accept_prob.shape)
accept = rand_image <= accept_prob

select_coords = torch.where(accept.unsqueeze(-1).repeat(1, 2), next_sample, prev_sample)
prev_sample = select_coords
select_coords = torch.where(accept.unsqueeze(-1).repeat(1, 2), next_sample, prev_sample[img_i])
prev_sample[img_i] = select_coords.long()


if img_i == i_train[0]:
Expand Down Expand Up @@ -1085,18 +1157,18 @@ def train():
val_psnrs = val_psnrs / len(i_val)
writer.add_scalar("val_psnr", val_psnrs, i)

train_psnrs = 0
for num_i in i_train:
target_val = torch.from_numpy(images[num_i]).cuda()
pose_val = poses[num_i, :3,:4]
with torch.no_grad():
rgb, disp, acc, extras = render(H, W, K, chunk=args.chunk, c2w=pose_val,
**render_kwargs_test)
# train_psnrs = 0
# for num_i in i_train:
# target_val = torch.from_numpy(images[num_i]).cuda()
# pose_val = poses[num_i, :3,:4]
# with torch.no_grad():
# rgb, disp, acc, extras = render(H, W, K, chunk=args.chunk, c2w=pose_val,
# **render_kwargs_test)

psnr = mse2psnr(img2mse(rgb, target_val))
train_psnrs += psnr
train_psnrs = train_psnrs / len(i_val)
writer.add_scalar("train_psnr", train_psnrs, i)
# psnr = mse2psnr(img2mse(rgb, target_val))
# train_psnrs += psnr
# train_psnrs = train_psnrs / len(i_train)
# writer.add_scalar("train_psnr", train_psnrs, i)
"""
print(expname, i, psnr.numpy(), loss.numpy(), global_step.numpy())
print('iter time {:.05f}'.format(dt))
Expand Down

0 comments on commit 50e9e31

Please sign in to comment.