forked from mikh3x4/nerf-navigation
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path visual_helpers.py
115 lines (76 loc) · 3.31 KB
/
visual_helpers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import os, sys
import torch
from torchtyping import TensorDetail, TensorType
from typeguard import typechecked
from nerf_core import *
from run_nerf_helpers import *
import matplotlib.pyplot as plt
import numpy as np
def draw_pyramid(points):
    """Draw a camera frustum ("pyramid") onto the current matplotlib axes.

    Each side face (two adjacent base corners plus the apex) and the base
    polygon are filled with a translucent salmon colour, and a thin line is
    drawn from the apex to every base corner.

    Parameters
    ----------
    points : array-like of shape (N + 1, 2)
        2-D plot coordinates. Rows ``0 .. N-1`` are the base corners in
        drawing order; the last row is the apex. The frustums built in this
        module have N == 4.
    """
    points = np.asarray(points)
    base_pts = points[:-1, :]
    apex = points[-1, :]
    n_corners = base_pts.shape[0]

    for tt in range(n_corners):
        nxt = (tt + 1) % n_corners
        # Side face: two adjacent base corners closed through the apex.
        # np.append returns a new array (it never modifies its argument),
        # so its result has to be kept.
        trianglex = np.append(base_pts[[tt, nxt], 0], apex[0])
        triangley = np.append(base_pts[[tt, nxt], 1], apex[1])
        plt.fill(trianglex, triangley, facecolor='lightsalmon', edgecolor='orangered', alpha=0.1)
        # Headless arrow from the apex to this corner. Without
        # angles/scale_units/scale='xy'/1, quiver autoscales the vector and
        # measures its angle in screen space, which is wrong on an imshow
        # axes whose y axis points down.
        plt.quiver(
            apex[0], apex[1],
            base_pts[tt, 0] - apex[0], base_pts[tt, 1] - apex[1],
            angles='xy', scale_units='xy', scale=1,
            headaxislength=0, headlength=0,
        )

    # Base of the pyramid (the projected image plane).
    plt.fill(base_pts[:, 0], base_pts[:, 1], facecolor='lightsalmon', edgecolor='orangered', alpha=0.1)
def scale(points, scale_factor):
    """Scale a pyramid's base corners about its apex, in place.

    Every row except the last (the base corners) is replaced by
    ``apex + scale_factor * (corner - apex)``, where ``apex`` is the last
    row. The apex row itself is not changed.

    Parameters
    ----------
    points : array of shape (N + 1, D)
        Base corners followed by the apex. Modified in place.
    scale_factor : float
        Multiplier applied to each corner's offset from the apex
        (< 1 shrinks the base towards the apex).

    Returns
    -------
    The same ``points`` object, so callers may also use the return value.
    """
    apex = points[-1, :]
    # The right-hand side is fully evaluated before assignment, and the
    # apex row is not part of the target slice, so the view stays valid.
    points[:-1, :] = scale_factor * (points[:-1, :] - apex) + apex
    return points
def _intrinsics(K):
    """Return ``(fx, fy, cx, cy)`` read from a 3x3 pinhole intrinsics matrix."""
    K = np.asarray(K, dtype=float)
    return K[0, 0], K[1, 1], K[0, 2], K[1, 2]


def _pix2world(vertices, rot, dt, origin, K):
    """Back-project pixel coordinates to world points at depth ``dt``.

    The camera-frame ray for pixel (u, v) is
    ``((u - cx) / fx, -(v - cy) / fy, -1)``: the camera looks along its
    local -z axis and image v grows downwards. The point at depth ``dt`` on
    that ray is rotated by ``rot`` and translated by ``origin``.

    Returns an (N, 3) array of world points.
    """
    fx, fy, cx, cy = _intrinsics(K)
    vertices = np.asarray(vertices, dtype=float)
    cam = np.stack(
        [
            dt * (vertices[:, 0] - cx) / fx,
            -dt * (vertices[:, 1] - cy) / fy,
            np.full(vertices.shape[0], -dt),
        ],
        axis=1,
    )
    # Row-vector form of rot @ p + origin for every point.
    return cam @ rot.T + origin


def _world2pix(pts, rot, origin, K):
    """Project world points into a pinhole camera with pose ``(rot, origin)``.

    Inverse of the convention used by :func:`_pix2world`: points are moved
    into the camera frame, divided by their depth along -z, and mapped
    through the intrinsics. Points at or behind the camera plane cannot be
    projected and come back as NaN.

    Returns an (N, 2) array of (u, v) pixel coordinates.
    """
    fx, fy, cx, cy = _intrinsics(K)
    # Row-vector form of rot.T @ (p - origin).
    cam = (np.asarray(pts, dtype=float) - origin) @ rot
    depth = -cam[:, 2]
    with np.errstate(divide='ignore', invalid='ignore'):
        u = fx * cam[:, 0] / depth + cx
        v = -fy * cam[:, 1] / depth + cy
    pixs = np.stack([u, v], axis=1)
    pixs[depth <= 0] = np.nan
    return pixs


def _pose2pixel(pose, bg_rot, bg_origin, K, H, W, dt=1.0):
    """Project one camera's frustum into the background camera's image.

    The four image corners of the camera at ``pose`` are back-projected to
    depth ``dt``, and together with the camera centre they are projected
    into the background camera.

    Returns a (5, 2) array: four base corners followed by the apex (centre).
    """
    pose = pose.cpu().detach().numpy()
    origin = pose[:3, 3]
    rot = pose[:3, :3]
    corners = np.array([[0, 0], [0, H - 1], [W - 1, H - 1], [W - 1, 0]], dtype=float)
    base = _world2pix(_pix2world(corners, rot, dt, origin, K), bg_rot, bg_origin, K)
    apex = _world2pix(origin[None, :], bg_rot, bg_origin, K)
    return np.vstack([base, apex])


def visualize(background_pose: TensorType[4, 4], true_poses: TensorType["T", 3, 5], pose_estimates: TensorType["T", 4, 4], savedir: str, render_args, render_kwargs_train) -> None:
    """Render a background view and overlay the camera frustum of each pose.

    The scene is rendered from ``background_pose`` and shown with
    ``imshow``. The frustum of each pose in ``true_poses`` is then projected
    into that image, shrunk about its apex by a fixed ``scale_factor``, and
    drawn. After each pose the accumulated figure is saved to
    ``savedir/{t:03d}.png``.

    Parameters
    ----------
    background_pose : camera-to-world matrix of the background view; its
        ``[:3, :4]`` part is passed as ``c2w`` to ``render``.
    true_poses : per-step poses; only ``[:3, :3]`` (rotation) and
        ``[:3, 3]`` (position) of each are used.
    pose_estimates : currently unused.
    savedir : directory the PNG frames are written to.
    render_args : dict providing ``'hwf'`` (H, W, focal), ``'K'`` (3x3
        intrinsics) and ``'chunk'``.
    render_kwargs_train : extra keyword arguments forwarded to ``render``.
    """
    # Fraction of the projected frustum size that is actually drawn.
    scale_factor = 0.1
    H, W, _ = render_args['hwf']
    K = render_args['K']
    chunk = render_args['chunk']

    # Plot background image.
    rgb, _, _, _ = render(H, W, K, chunk=chunk, c2w=background_pose[:3, :4], **render_kwargs_train)
    # NOTE(review): figure 1 is reused and never cleared, so drawings from an
    # earlier call would persist in later frames — confirm this is intended.
    plt.figure(1)
    plt.imshow(rgb.cpu().detach().numpy())

    background_pose = background_pose.cpu().detach().numpy()
    bg_rot = background_pose[:3, :3]
    bg_origin = background_pose[:3, 3]

    # NOTE(review): `trange` is not imported in the visible imports;
    # presumably it arrives via one of the star imports — verify.
    for t in trange(true_poses.shape[0]):
        pixels = _pose2pixel(true_poses[t], bg_rot, bg_origin, K, H, W)
        pixels = scale(pixels, scale_factor)
        draw_pyramid(pixels)
        filename = os.path.join(savedir, '{:03d}.png'.format(t))
        plt.savefig(filename)

    print('Done Plotting')