add sampling
Shakiba Kheradmand committed Feb 18, 2023
1 parent 5e013dc commit 03dd3d4
Showing 2 changed files with 290 additions and 23 deletions.
266 changes: 243 additions & 23 deletions run_nerf.py
@@ -17,6 +17,7 @@
from load_deepvoxels import load_dv_data
from load_blender import load_blender_data
from load_LINEMOD import load_LINEMOD_data
from torch.utils.tensorboard import SummaryWriter


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -422,7 +423,7 @@ def config_parser():

import configargparse
parser = configargparse.ArgumentParser()
parser.add_argument('--config', is_config_file=True,
parser.add_argument('--config', is_config_file=True, default="./configs/hotdog.txt",
help='config file path')
parser.add_argument("--expname", type=str,
help='experiment name')
@@ -452,6 +453,18 @@ def config_parser():
help='number of pts sent through network in parallel, decrease if running out of memory')
parser.add_argument("--no_batching", action='store_true',
help='only take random rays from 1 image at a time')
parser.add_argument("--image_sampling", action='store_true',
help='whether to do image level sampling or not')
parser.add_argument("--sampling_type", type=str, default="none",
help='options = none / multinomial / rejection / metropolis-hastings')
parser.add_argument("--sigma", type=float, default=2.0,
help='value of sigma in case metropolis-hastings is selected')
parser.add_argument("--weight_exponential", type=float, default=1.0,
help='weight of exponential')
parser.add_argument("--initialize", action='store_true',
help='initialize probability map ')
parser.add_argument("--global_sampling", action='store_true',
help='global sampling at each iteration - slow ')
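# Illustrative example (values are placeholders, not prescribed by this commit):
# the new sampling flags might be combined as
#   python run_nerf.py --config ./configs/hotdog.txt --image_sampling \
#       --sampling_type metropolis-hastings --sigma 2.0 --weight_exponential 1.0 --initialize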
parser.add_argument("--no_reload", action='store_true',
help='do not reload weights from saved ckpt')
parser.add_argument("--ft_path", type=str, default=None,
@@ -517,15 +530,15 @@ def config_parser():
help='will take every 1/N images as LLFF test set, paper uses 8')

# logging/saving options
parser.add_argument("--i_print", type=int, default=100,
parser.add_argument("--i_print", type=int, default=500,
help='frequency of console printout and metric logging')
parser.add_argument("--i_img", type=int, default=500,
help='frequency of tensorboard image logging')
parser.add_argument("--i_weights", type=int, default=10000,
help='frequency of weight ckpt saving')
parser.add_argument("--i_testset", type=int, default=50000,
parser.add_argument("--i_testset", type=int, default=10000,
help='frequency of testset saving')
parser.add_argument("--i_video", type=int, default=50000,
parser.add_argument("--i_video", type=int, default=5000,
help='frequency of render_poses video saving')

return parser
@@ -677,7 +690,43 @@ def train():

print("samples are taking from all samples: ", use_batching)

if use_batching:
if args.image_sampling:
heat_map = torch.zeros((images.shape[0], H, W), dtype=torch.float, device=device)
prob_map = torch.ones((images.shape[0], H, W), dtype=torch.float, device=device)
heat_num = torch.zeros((images.shape[0], H, W), dtype=torch.float, device=device)

if args.initialize:
for image_num in i_train:
print(image_num, len(i_train))
L = 4
pose_train = torch.from_numpy(poses[image_num, :3,:4]).float().cuda()
target_train = torch.from_numpy(images[image_num]).float().cuda()
with torch.no_grad():
rgb, disp, acc, extras = render(H, W, K, chunk=args.chunk, c2w=pose_train,
**render_kwargs_test)
coords_train = torch.stack(torch.meshgrid(torch.linspace(0, H-1, H), torch.linspace(0, W-1, W)), -1).reshape(-1, 2).long()
heat_map, heat_num, prob_map = update_heat_map(rgb.reshape(-1, 3), target_train.reshape(-1, 3), image_num, coords_train, heat_map, heat_num, prob_map, L, 0)
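# update_heat_map is defined elsewhere in this commit (not shown in this file).
# A minimal sketch of the interface implied by the call sites, assuming it
# accumulates per-ray error into heat_map/heat_num and renormalizes prob_map
# with an exponential weighting (the role of L, e.g. a neighborhood or level
# parameter, is a guess and is ignored here):
#
#   def update_heat_map(rgb, target, image_num, coords, heat_map, heat_num,
#                       prob_map, L, weight_exponential):
#       err = ((rgb - target) ** 2).mean(-1)           # per-ray MSE, shape (B,)
#       h, w = coords[:, 0], coords[:, 1]
#       heat_map[image_num, h, w] += err               # accumulated error
#       heat_num[image_num, h, w] += 1                 # visit counts
#       avg = heat_map[image_num] / heat_num[image_num].clamp(min=1)
#       prob = torch.exp(weight_exponential * avg)     # favor high-error pixels
#       prob_map[image_num] = prob / prob.sum()        # normalized sampling map
#       return heat_map, heat_num, prob_map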

if use_batching and args.image_sampling:
# For random ray batching
rays = np.stack([get_rays_np(H, W, K, p) for p in poses[:,:3,:4]], 0) # [N, ro+rd, H, W, 3]
print('done, concats')
hwind = np.arange(0, len(i_train)).repeat(H*W)[:, None, None].repeat(3, 2)
w, h = np.meshgrid(np.linspace(0, W-1, W), np.linspace(0, H-1, H))
h = np.tile(h.flatten(), len(i_train))[:, None, None].repeat(3, 2)
w = np.tile(w.flatten(), len(i_train))[:, None, None].repeat(3, 2)
rays_rgb_main = np.concatenate([rays, images[:,None]], 1) # [N, ro+rd+rgb, H, W, 3]
rays_rgb_main = np.transpose(rays_rgb_main, [0,2,3,1,4]) # [N, H, W, ro+rd+rgb, 3]
rays_rgb_main = np.stack([rays_rgb_main[i] for i in i_train], 0) # train images only
rays_rgb_main = np.reshape(rays_rgb_main, [-1,3,3]) # [(N-1)*H*W, ro+rd+rgb, 3]
rays_rgb_main = np.concatenate([rays_rgb_main, hwind, h, w], 1) # [(N-1)*H*W, ro+rd+rgb+ind+h+w, 3]
rays_rgb = rays_rgb_main.astype(np.float32)
print('shuffle rays')
np.random.shuffle(rays_rgb)
rays_rgb = torch.Tensor(rays_rgb)
print('done')
i_batch = 0
elif use_batching:
# For random ray batching
print('get rays')
rays = np.stack([get_rays_np(H, W, K, p) for p in poses[:,:3,:4]], 0) # [N, ro+rd, H, W, 3]
@@ -689,10 +738,10 @@ def train():
rays_rgb = rays_rgb.astype(np.float32)
print('shuffle rays')
np.random.shuffle(rays_rgb)

print('done')
i_batch = 0


# Move training data to GPU
if use_batching:
images = torch.Tensor(images).to(device)
@@ -702,31 +751,70 @@ def train():


N_iters = 200000 + 1
epoch_num = 0
print('Begin')
print('TRAIN views are', i_train)
print('TEST views are', i_test)
print('VAL views are', i_val)

# Summary writers
# writer = SummaryWriter(os.path.join(basedir, 'summaries', expname))
writer = SummaryWriter(os.path.join(basedir, 'summaries', expname))
# write video
selected_points_all = []
heatmaps_all = []
heatnums_all = []
prob_all = []
prev_sample = torch.empty(0)


start = start + 1
for i in trange(start, N_iters):
time0 = time.time()

# Sample random ray batch
if use_batching:
# Random over all images
batch = rays_rgb[i_batch:i_batch+N_rand] # [B, 2+1, 3*?]
batch = torch.transpose(batch, 0, 1)
batch_rays, target_s = batch[:2], batch[2]

i_batch += N_rand
if i_batch >= rays_rgb.shape[0]:
print("Shuffle data after an epoch!")
rand_idx = torch.randperm(rays_rgb.shape[0])
rays_rgb = rays_rgb[rand_idx]
i_batch = 0
if args.image_sampling:
# Random over all images
num_sample_points = N_rand
# if epoch_num >= 1:
# num_sample_points = N_rand * 2
batch = rays_rgb[i_batch:i_batch+num_sample_points] # [B, 2+1, 3*?]
batch = torch.transpose(batch, 0, 1)
batch_rays, target_s = batch[:2], batch[2]
hwindi = batch[3]
hi = batch[4]
wi = batch[5]

i_batch += num_sample_points
if i_batch >= rays_rgb.shape[0]:
print("Shuffle data after an epoch!")
rand_idx = torch.randperm(rays_rgb.shape[0])
rays_rgb = rays_rgb[rand_idx]
i_batch = 0
epoch_num += 1

# if epoch_num >= 1:
# # reject half of the samples
# # (1) extract the heatmap value per selected point
# ind = hwindi[:, 0].cpu().int().numpy()
# hval = heat_map[ind, hi[:, 0].int().cpu().numpy(), wi[:, 0].int().cpu().numpy()]
# ten = torch.cat((hval[:, None], torch.arange(num_sample_points)[:, None]), dim=-1)
# sortvals = ten[ten[:, 0].sort()[1]]
# selected = sortvals[:N_rand]
# batch_rays = batch_rays[selected[:, 1]]

else:
# Random over all images
batch = rays_rgb[i_batch:i_batch+N_rand] # [B, 2+1, 3*?]
batch = torch.transpose(batch, 0, 1)
batch_rays, target_s = batch[:2], batch[2]

i_batch += N_rand
if i_batch >= rays_rgb.shape[0]:
print("Shuffle data after an epoch!")
rand_idx = torch.randperm(rays_rgb.shape[0])
rays_rgb = rays_rgb[rand_idx]
i_batch = 0

else:
# Random from one image
Expand All @@ -747,13 +835,80 @@ def train():
torch.linspace(W//2 - dW, W//2 + dW - 1, 2*dW)
), -1)
if i == start:
print(f"[Config] Center cropping of size {2*dH} x {2*dW} is enabled until iter {args.precrop_iters}")
else:
print(f"[Config] Center cropping of size {2*dH} x {2*dW} is enabled until iter {args.precrop_iters}")

coords = torch.reshape(coords, [-1,2]) # (H * W, 2)
select_inds = np.random.choice(coords.shape[0], size=[N_rand], replace=False) # (N_rand,)
select_coords = coords[select_inds].long() # (N_rand, 2)
prev_sample = select_coords
elif not args.image_sampling or args.sampling_type == "none":
coords = torch.stack(torch.meshgrid(torch.linspace(0, H-1, H), torch.linspace(0, W-1, W)), -1) # (H, W, 2)
coords = torch.reshape(coords, [-1,2]) # (H * W, 2)
select_inds = np.random.choice(coords.shape[0], size=[N_rand], replace=False) # (N_rand,)
select_coords = coords[select_inds].long() # (N_rand, 2)
prev_sample = select_coords
elif args.image_sampling and args.sampling_type == "multinomial":
# m = torch.distributions.categorical.Categorical(prob_map[img_i].flatten())
# samples = m.sample(sample_shape=(N_rand,))
# inds_w = samples % W
# inds_h = (samples / W).long()
# select_coords = torch.cat((inds_h[..., None], inds_w[..., None]), dim=-1)

samples = torch.multinomial(prob_map[img_i].flatten(), N_rand, False)
inds_w = samples % W
inds_h = (samples / W).long()
select_coords = torch.cat((inds_h[..., None], inds_w[..., None]), dim=-1)
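# torch.multinomial draws N_rand flat pixel indices with probability proportional
# to prob_map[img_i]; samples % W recovers the column and the integer part of
# samples / W the row. Floor division (samples // W) would express the same row
# index more idiomatically.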
elif args.image_sampling and args.sampling_type == "rejection":
num = 0
counter = 0
select_coords = torch.empty(0)
while num < N_rand:
counter += 1
rand_image = torch.rand(H, W) * prob_map[img_i].sum()
pinds = torch.where(rand_image < prob_map[img_i])
pinds = torch.stack(list(pinds), dim=-1)
if pinds.shape[0] > 0:
if pinds.shape[0] > (N_rand - num):
subpinds = np.random.choice(pinds.shape[0], size=[N_rand-num], replace=False)
pinds = pinds[subpinds]
select_coords = torch.concat((select_coords, pinds)).long()

num += pinds.shape[0]
if counter > 1000:
coords = torch.reshape(coords, [-1,2]) # (H * W, 2)
select_inds = np.random.choice(coords.shape[0], size=[N_rand], replace=False) # (N_rand,)
select_coords = coords[select_inds].long() # (N_rand, 2)
num += select_coords.shape[0]
elif args.image_sampling and args.sampling_type == "metropolis-hastings":
if prev_sample.nelement() == 0:
coords = torch.stack(torch.meshgrid(torch.linspace(0, H-1, H), torch.linspace(0, W-1, W)), -1) # (H, W, 2)
coords = torch.reshape(coords, [-1,2]) # (H * W, 2)
select_inds = np.random.choice(coords.shape[0], size=[N_rand], replace=False) # (N_rand,)
select_coords = coords[select_inds].long() # (N_rand, 2)
prev_sample = select_coords
else:
next_sample = prev_sample + torch.normal(mean=0, std=args.sigma, size=(prev_sample.shape))
next_sample[:, 0] = next_sample[:, 0] % (H-1)
next_sample[:, 1] = next_sample[:, 1] % (W-1)
next_sample = torch.round(next_sample).long()

prev_heat = prob_map[img_i, prev_sample[:, 0], prev_sample[:, 1]]
next_heat = prob_map[img_i, next_sample[:, 0], next_sample[:, 1]]

accept_prob = next_heat / (prev_heat + 1e-7)
rand_image = torch.rand(accept_prob.shape)
accept = rand_image <= accept_prob

select_coords = torch.where(accept.unsqueeze(-1).repeat(1, 2), next_sample, prev_sample)
prev_sample = select_coords
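# The Gaussian perturbation above (std = args.sigma) is a symmetric proposal,
# so each candidate pixel is accepted with probability
# min(1, prob_map[next] / prob_map[prev]); rejected rays keep their previous
# pixel, which is why prev_sample is carried across iterations.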


if img_i == i_train[0]:
selected_points = torch.zeros((H, W))
selected_points[select_coords[:, 0], select_coords[:, 1]] = 1
selected_points_all.append(selected_points.cpu())
# writer.add_image("sampled", selected_points, global_step=i, dataformats='HW')

coords = torch.reshape(coords, [-1,2]) # (H * W, 2)
select_inds = np.random.choice(coords.shape[0], size=[N_rand], replace=False) # (N_rand,)
select_coords = coords[select_inds].long() # (N_rand, 2)
rays_o = rays_o[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3)
rays_d = rays_d[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3)
batch_rays = torch.stack([rays_o, rays_d], 0)
@@ -778,6 +933,29 @@ def train():
loss.backward()
optimizer.step()

# update the heatmap
if args.image_sampling:
L = 4
if args.global_sampling:
pose_train = poses[img_i, :3,:4]
target_train = torch.from_numpy(images[img_i]).float().cuda()
with torch.no_grad():
rgb, disp, acc, extras = render(H, W, K, chunk=args.chunk, c2w=pose_train,
**render_kwargs_test)
coords_train = torch.stack(torch.meshgrid(torch.linspace(0, H-1, H), torch.linspace(0, W-1, W)), -1).reshape(-1, 2).long()
heat_map, heat_num, prob_map = update_heat_map(rgb.reshape(-1, 3), target_train.reshape(-1, 3), img_i, coords_train, heat_map, heat_num, prob_map, L, args.weight_exponential)
else:
heat_map, heat_num, prob_map = update_heat_map(rgb, target_s, img_i, select_coords, heat_map, heat_num, prob_map, L, args.weight_exponential)
# heat_map, heat_num, prob_map = update_heat_map(rgb, target_s, hi, wi, hwindi, heat_map, heat_num, prob_map, L, i)
# if args.visualize:
if img_i == i_train[0]:
heatmaps_all.append(heat_map[img_i].cpu().detach().numpy())
heatnums_all.append((heat_num[img_i]/heat_num[img_i].max()).cpu().detach().numpy())
# writer.add_image("heat_map_"+str(img_i), heat_map[img_i].cpu(), global_step=i, dataformats='HW')
# writer.add_image("heat_num_"+str(img_i), (heat_num[img_i]/heat_num[img_i].max()).cpu(), global_step=i, dataformats='HW')
# writer.add_image("prob_map_"+str(ti), (prob_map[ti]).cpu(), global_step=i, dataformats='HW')


# NOTE: IMPORTANT!
### update learning rate ###
decay_rate = 0.1
@@ -802,6 +980,13 @@ def train():
}, path)
print('Saved checkpoints at', path)

if i%args.i_video==0 and i > 0:
moviebase = os.path.join(basedir, expname, '{:06d}_'.format(i))
imageio.mimwrite(moviebase + 'heatmap.mp4', to8b(heatmaps_all), fps=10, quality=8)
imageio.mimwrite(moviebase + 'prob.mp4', to8b(prob_all), fps=10, quality=8)
imageio.mimwrite(moviebase + 'heatnum.mp4', to8b(heatnums_all), fps=10, quality=8)
imageio.mimwrite(moviebase + 'selected.mp4', selected_points_all, fps=10, quality=8)

if i%args.i_video==0 and i > 0:
# Turn on testing mode
with torch.no_grad():
@@ -830,6 +1015,27 @@ def train():

if i%args.i_print==0:
tqdm.write(f"[TRAIN] Iter: {i} Loss: {loss.item()} PSNR: {psnr.item()}")

writer.add_scalar("loss", loss, i)
writer.add_scalar("psnr", psnr, i)

if args.image_sampling and args.sampling_type == "metropolis-hastings":
writer.add_scalar("accept_rate", accept.cpu().sum() / accept.numel(), i)

# also report validation psnr
# Log a rendered validation view to Tensorboard
val_psnrs = 0
for num_i in i_val:
target_val = torch.from_numpy(images[num_i]).cuda()
pose_val = poses[num_i, :3,:4]
with torch.no_grad():
rgb, disp, acc, extras = render(H, W, K, chunk=args.chunk, c2w=pose_val,
**render_kwargs_test)

psnr = mse2psnr(img2mse(rgb, target_val))
val_psnrs += psnr
val_psnrs = val_psnrs / len(i_val)
writer.add_scalar("val_psnr", val_psnrs, i)
"""
print(expname, i, psnr.numpy(), loss.numpy(), global_step.numpy())
print('iter time {:.05f}'.format(dt))
@@ -874,6 +1080,20 @@ def train():

global_step += 1

# write test PSNR
test_psnrs = 0
for num_i in i_test:
target_test = torch.from_numpy(images[num_i]).cuda()
pose_test = poses[num_i, :3,:4]
with torch.no_grad():
rgb, disp, acc, extras = render(H, W, K, chunk=args.chunk, c2w=pose_test,
**render_kwargs_test)

psnr = mse2psnr(img2mse(rgb, target_test))
test_psnrs += psnr
test_psnrs = test_psnrs / len(i_test)
print("Final Test set PSNR = ", test_psnrs)


if __name__=='__main__':
torch.set_default_tensor_type('torch.cuda.FloatTensor')