Skip to content

Commit

Permalink
Fix intrinsics problem
Browse files: browse the repository at this point in the history
  • Loading branch information
yenchenlin committed Jun 8, 2021
1 parent a1e1d27 commit 223fe62
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 19 deletions.
50 changes: 35 additions & 15 deletions run_nerf.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm, trange

import matplotlib.pyplot as plt
Expand All @@ -17,6 +16,7 @@
from load_llff import load_llff_data
from load_deepvoxels import load_dv_data
from load_blender import load_blender_data
from load_LINEMOD import load_LINEMOD_data


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Expand Down Expand Up @@ -66,7 +66,7 @@ def batchify_rays(rays_flat, chunk=1024*32, **kwargs):
return all_ret


def render(H, W, focal, chunk=1024*32, rays=None, c2w=None, ndc=True,
def render(H, W, K, chunk=1024*32, rays=None, c2w=None, ndc=True,
near=0., far=1.,
use_viewdirs=False, c2w_staticcam=None,
**kwargs):
Expand Down Expand Up @@ -94,7 +94,7 @@ def render(H, W, focal, chunk=1024*32, rays=None, c2w=None, ndc=True,
"""
if c2w is not None:
# special case to render full image
rays_o, rays_d = get_rays(H, W, focal, c2w)
rays_o, rays_d = get_rays(H, W, K, c2w)
else:
# use provided ray batch
rays_o, rays_d = rays
Expand All @@ -104,14 +104,14 @@ def render(H, W, focal, chunk=1024*32, rays=None, c2w=None, ndc=True,
viewdirs = rays_d
if c2w_staticcam is not None:
# special case to visualize effect of viewdirs
rays_o, rays_d = get_rays(H, W, focal, c2w_staticcam)
rays_o, rays_d = get_rays(H, W, K, c2w_staticcam)
viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True)
viewdirs = torch.reshape(viewdirs, [-1,3]).float()

sh = rays_d.shape # [..., 3]
if ndc:
# for forward facing scenes
rays_o, rays_d = ndc_rays(H, W, focal, 1., rays_o, rays_d)
rays_o, rays_d = ndc_rays(H, W, K[0][0], 1., rays_o, rays_d)

# Create ray batch
rays_o = torch.reshape(rays_o, [-1,3]).float()
Expand All @@ -134,7 +134,7 @@ def render(H, W, focal, chunk=1024*32, rays=None, c2w=None, ndc=True,
return ret_list + [ret_dict]


def render_path(render_poses, hwf, chunk, render_kwargs, gt_imgs=None, savedir=None, render_factor=0):
def render_path(render_poses, hwf, K, chunk, render_kwargs, gt_imgs=None, savedir=None, render_factor=0):

H, W, focal = hwf

Expand All @@ -151,7 +151,7 @@ def render_path(render_poses, hwf, chunk, render_kwargs, gt_imgs=None, savedir=N
for i, c2w in enumerate(tqdm(render_poses)):
print(i, time.time() - t)
t = time.time()
rgb, disp, acc, _ = render(H, W, focal, chunk=chunk, c2w=c2w[:3,:4], **render_kwargs)
rgb, disp, acc, _ = render(H, W, K, chunk=chunk, c2w=c2w[:3,:4], **render_kwargs)
rgbs.append(rgb.cpu().numpy())
disps.append(disp.cpu().numpy())
if i==0:
Expand Down Expand Up @@ -537,7 +537,7 @@ def train():
args = parser.parse_args()

# Load data

K = None
if args.dataset_type == 'llff':
images, poses, bds, render_poses, i_test = load_llff_data(args.datadir, args.factor,
recenter=True, bd_factor=.75,
Expand Down Expand Up @@ -579,6 +579,17 @@ def train():
else:
images = images[...,:3]

elif args.dataset_type == 'LINEMOD':
images, poses, render_poses, hwf, K, i_split, near, far = load_LINEMOD_data(args.datadir, args.half_res, args.testskip)
print(f'Loaded LINEMOD, images shape: {images.shape}, hwf: {hwf}, K: {K}')
print(f'[CHECK HERE] near: {near}, far: {far}.')
i_train, i_val, i_test = i_split

if args.white_bkgd:
images = images[...,:3]*images[...,-1:] + (1.-images[...,-1:])
else:
images = images[...,:3]

elif args.dataset_type == 'deepvoxels':

images, poses, render_poses, hwf, i_split = load_dv_data(scene=args.shape,
Expand All @@ -601,6 +612,13 @@ def train():
H, W = int(H), int(W)
hwf = [H, W, focal]

if K is None:
K = np.array([
[focal, 0, 0.5*W],
[0, focal, 0.5*H],
[0, 0, 1]
])

if args.render_test:
render_poses = np.array(poses[i_test])

Expand Down Expand Up @@ -647,7 +665,7 @@ def train():
os.makedirs(testsavedir, exist_ok=True)
print('test poses shape', render_poses.shape)

rgbs, _ = render_path(render_poses, hwf, args.chunk, render_kwargs_test, gt_imgs=images, savedir=testsavedir, render_factor=args.render_factor)
rgbs, _ = render_path(render_poses, hwf, K, args.chunk, render_kwargs_test, gt_imgs=images, savedir=testsavedir, render_factor=args.render_factor)
print('Done rendering', testsavedir)
imageio.mimwrite(os.path.join(testsavedir, 'video.mp4'), to8b(rgbs), fps=30, quality=8)

Expand All @@ -659,7 +677,7 @@ def train():
if use_batching:
# For random ray batching
print('get rays')
rays = np.stack([get_rays_np(H, W, focal, p) for p in poses[:,:3,:4]], 0) # [N, ro+rd, H, W, 3]
rays = np.stack([get_rays_np(H, W, K, p) for p in poses[:,:3,:4]], 0) # [N, ro+rd, H, W, 3]
print('done, concats')
rays_rgb = np.concatenate([rays, images[:,None]], 1) # [N, ro+rd+rgb, H, W, 3]
rays_rgb = np.transpose(rays_rgb, [0,2,3,1,4]) # [N, H, W, ro+rd+rgb, 3]
Expand All @@ -673,7 +691,8 @@ def train():
i_batch = 0

# Move training data to GPU
images = torch.Tensor(images).to(device)
if use_batching:
images = torch.Tensor(images).to(device)
poses = torch.Tensor(poses).to(device)
if use_batching:
rays_rgb = torch.Tensor(rays_rgb).to(device)
Expand Down Expand Up @@ -710,10 +729,11 @@ def train():
# Random from one image
img_i = np.random.choice(i_train)
target = images[img_i]
target = torch.Tensor(target).to(device)
pose = poses[img_i, :3,:4]

if N_rand is not None:
rays_o, rays_d = get_rays(H, W, focal, torch.Tensor(pose)) # (H, W, 3), (H, W, 3)
rays_o, rays_d = get_rays(H, W, K, torch.Tensor(pose)) # (H, W, 3), (H, W, 3)

if i < args.precrop_iters:
dH = int(H//2 * args.precrop_frac)
Expand All @@ -737,7 +757,7 @@ def train():
target_s = target[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3)

##### Core optimization loop #####
rgb, disp, acc, extras = render(H, W, focal, chunk=args.chunk, rays=batch_rays,
rgb, disp, acc, extras = render(H, W, K, chunk=args.chunk, rays=batch_rays,
verbose=i < 10, retraw=True,
**render_kwargs_train)

Expand Down Expand Up @@ -782,7 +802,7 @@ def train():
if i%args.i_video==0 and i > 0:
# Turn on testing mode
with torch.no_grad():
rgbs, disps = render_path(render_poses, hwf, args.chunk, render_kwargs_test)
rgbs, disps = render_path(render_poses, hwf, K, args.chunk, render_kwargs_test)
print('Done, saving', rgbs.shape, disps.shape)
moviebase = os.path.join(basedir, expname, '{}_spiral_{:06d}_'.format(expname, i))
imageio.mimwrite(moviebase + 'rgb.mp4', to8b(rgbs), fps=30, quality=8)
Expand All @@ -800,7 +820,7 @@ def train():
os.makedirs(testsavedir, exist_ok=True)
print('test poses shape', poses[i_test].shape)
with torch.no_grad():
render_path(torch.Tensor(poses[i_test]).to(device), hwf, args.chunk, render_kwargs_test, gt_imgs=images[i_test], savedir=testsavedir)
render_path(torch.Tensor(poses[i_test]).to(device), hwf, K, args.chunk, render_kwargs_test, gt_imgs=images[i_test], savedir=testsavedir)
print('Saved test set')


Expand Down
8 changes: 4 additions & 4 deletions run_nerf_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,21 +153,21 @@ def load_weights_from_keras(self, weights):


# Ray helpers
def get_rays(H, W, focal, c2w):
def get_rays(H, W, K, c2w):
i, j = torch.meshgrid(torch.linspace(0, W-1, W), torch.linspace(0, H-1, H)) # pytorch's meshgrid has indexing='ij'
i = i.t()
j = j.t()
dirs = torch.stack([(i-W*.5)/focal, -(j-H*.5)/focal, -torch.ones_like(i)], -1)
dirs = torch.stack([(i-K[0][2])/K[0][0], -(j-K[1][2])/K[1][1], -torch.ones_like(i)], -1)
# Rotate ray directions from camera frame to the world frame
rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs]
# Translate camera frame's origin to the world frame. It is the origin of all rays.
rays_o = c2w[:3,-1].expand(rays_d.shape)
return rays_o, rays_d


def get_rays_np(H, W, focal, c2w):
def get_rays_np(H, W, K, c2w):
i, j = np.meshgrid(np.arange(W, dtype=np.float32), np.arange(H, dtype=np.float32), indexing='xy')
dirs = np.stack([(i-W*.5)/focal, -(j-H*.5)/focal, -np.ones_like(i)], -1)
dirs = np.stack([(i-K[0][2])/K[0][0], -(j-K[1][2])/K[1][1], -np.ones_like(i)], -1)
# Rotate ray directions from camera frame to the world frame
rays_d = np.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs]
# Translate camera frame's origin to the world frame. It is the origin of all rays.
Expand Down

0 comments on commit 223fe62

Please sign in to comment.