Skip to content

Commit

Permalink
Add load_LINEMOD.py
Browse files Browse the repository at this point in the history
  • Loading branch information
yenchenlin committed Jun 9, 2021
1 parent 223fe62 commit ec26d1c
Showing 1 changed file with 95 additions and 0 deletions.
95 changes: 95 additions & 0 deletions load_LINEMOD.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import os
import torch
import numpy as np
import imageio
import json
import torch.nn.functional as F
import cv2


trans_t = lambda t : torch.Tensor([
[1,0,0,0],
[0,1,0,0],
[0,0,1,t],
[0,0,0,1]]).float()

rot_phi = lambda phi : torch.Tensor([
[1,0,0,0],
[0,np.cos(phi),-np.sin(phi),0],
[0,np.sin(phi), np.cos(phi),0],
[0,0,0,1]]).float()

rot_theta = lambda th : torch.Tensor([
[np.cos(th),0,-np.sin(th),0],
[0,1,0,0],
[np.sin(th),0, np.cos(th),0],
[0,0,0,1]]).float()


def pose_spherical(theta, phi, radius):
c2w = trans_t(radius)
c2w = rot_phi(phi/180.*np.pi) @ c2w
c2w = rot_theta(theta/180.*np.pi) @ c2w
c2w = torch.Tensor(np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]])) @ c2w
return c2w


def load_LINEMOD_data(basedir, half_res=False, testskip=1):
splits = ['train', 'val', 'test']
metas = {}
for s in splits:
with open(os.path.join(basedir, 'transforms_{}.json'.format(s)), 'r') as fp:
metas[s] = json.load(fp)

all_imgs = []
all_poses = []
counts = [0]
for s in splits:
meta = metas[s]
imgs = []
poses = []
if s=='train' or testskip==0:
skip = 1
else:
skip = testskip

for idx_test, frame in enumerate(meta['frames'][::skip]):
fname = frame['file_path']
if s == 'test':
print(f"{idx_test}th test frame: {fname}")
imgs.append(imageio.imread(fname))
poses.append(np.array(frame['transform_matrix']))
imgs = (np.array(imgs) / 255.).astype(np.float32) # keep all 4 channels (RGBA)
poses = np.array(poses).astype(np.float32)
counts.append(counts[-1] + imgs.shape[0])
all_imgs.append(imgs)
all_poses.append(poses)

i_split = [np.arange(counts[i], counts[i+1]) for i in range(3)]

imgs = np.concatenate(all_imgs, 0)
poses = np.concatenate(all_poses, 0)

H, W = imgs[0].shape[:2]
focal = float(meta['frames'][0]['intrinsic_matrix'][0][0])
K = meta['frames'][0]['intrinsic_matrix']
print(f"Focal: {focal}")

render_poses = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180,180,40+1)[:-1]], 0)

if half_res:
H = H//2
W = W//2
focal = focal/2.

imgs_half_res = np.zeros((imgs.shape[0], H, W, 3))
for i, img in enumerate(imgs):
imgs_half_res[i] = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA)
imgs = imgs_half_res
# imgs = tf.image.resize_area(imgs, [400, 400]).numpy()

near = np.floor(min(metas['train']['near'], metas['test']['near']))
far = np.ceil(max(metas['train']['far'], metas['test']['far']))
return imgs, poses, render_poses, [H, W, focal], K, i_split, near, far


0 comments on commit ec26d1c

Please sign in to comment.