forked from hustvl/4DGaussians
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgraphics_utils.py
132 lines (104 loc) · 3.82 KB
/
graphics_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact george.drettakis@inria.fr
#
import torch
import math
import numpy as np
from typing import NamedTuple
class BasicPointCloud(NamedTuple):
    """Immutable record grouping the arrays that describe a point cloud.

    Fields are positional and keep the order ``points``, ``colors``,
    ``normals``. The annotations use ``np.ndarray`` (the array class) rather
    than ``np.array`` (a factory function, which is not a type).

    NOTE(review): the three arrays are presumably aligned row-for-row (one
    row per point); nothing in this class enforces that -- confirm against
    the code that constructs it.
    """

    # Point positions.
    points: np.ndarray
    # Per-point colors.
    colors: np.ndarray
    # Per-point normals.
    normals: np.ndarray
def geom_transform_points(points, transf_matrix):
    """Transform a set of points by a homogeneous matrix and dehomogenize.

    Each point is extended with a trailing 1, treated as a row vector and
    right-multiplied by ``transf_matrix``. The first three output
    coordinates are then divided by the fourth, offset by ``1e-7``.

    Args:
        points: 2-D tensor with one point per row. Assumed to be ``(P, 3)``
            so that the homogeneous rows have four entries -- the code only
            enforces that the tensor is 2-D.
        transf_matrix: Tensor multiplied on the right of the homogeneous
            points; presumably ``(4, 4)`` -- confirm against callers.

    Returns:
        Tensor holding the transformed, perspective-divided coordinates;
        the leading broadcast dimension added for the matmul is squeezed
        away again.
    """
    num_points, _ = points.shape

    # Column of ones sharing the dtype and device of the input points.
    w_column = points.new_ones((num_points, 1))
    homogeneous = torch.cat((points, w_column), dim=1)

    projected = torch.matmul(homogeneous, transf_matrix.unsqueeze(0))

    xyz = projected[..., :3]
    # Small offset keeps the division finite when w is exactly zero.
    w = projected[..., 3:] + 1e-7
    return (xyz / w).squeeze(dim=0)
def getWorld2View(R, t):
    """Assemble a 4x4 float32 matrix from a rotation and a translation.

    The upper-left 3x3 block is the transpose of ``R``, the first three
    entries of the last column are ``t``, and the bottom row is
    ``[0, 0, 0, 1]``.

    Args:
        R: Array exposing ``transpose()``; assumed to be a 3x3 NumPy
            rotation matrix -- TODO confirm the storage convention that
            makes the transpose necessary.
        t: Translation assignable to three entries of the last column.

    Returns:
        A ``(4, 4)`` NumPy array of dtype ``float32``.
    """
    world_to_view = np.eye(4)
    world_to_view[:3, 3] = t
    world_to_view[:3, :3] = R.transpose()
    return world_to_view.astype(np.float32)
def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0):
    """Build a 4x4 float32 matrix after shifting and scaling its inverse's origin.

    The matrix ``[R^T | t]`` is assembled and inverted. The translation
    column of the inverse is replaced by ``(column + translate) * scale``,
    and the result is inverted back.

    Args:
        R: Array exposing ``transpose()``; assumed to be a 3x3 NumPy
            rotation matrix -- TODO confirm.
        t: Translation assignable to three entries of the last column.
        translate: Offset added to the translation column of the inverse
            before scaling. The default array is never modified here.
        scale: Factor applied to the offset translation column.

    Returns:
        A ``(4, 4)`` NumPy array of dtype ``float32``.
    """
    world_to_view = np.eye(4)
    world_to_view[:3, :3] = R.transpose()
    world_to_view[:3, 3] = t

    # The translation column of the inverse is what gets recentered/rescaled
    # (presumably the camera center in world space -- confirm with callers).
    view_to_world = np.linalg.inv(world_to_view)
    view_to_world[:3, 3] = (view_to_world[:3, 3] + translate) * scale

    return np.linalg.inv(view_to_world).astype(np.float32)
def getProjectionMatrix(znear, zfar, fovX, fovY):
    """Build a 4x4 perspective projection matrix for a symmetric frustum.

    The frustum extents on the near plane are derived from the two field of
    view angles. Entry ``[3, 2]`` is ``1.0``, and rows 2/3 use ``znear`` and
    ``zfar`` as written below.

    Args:
        znear: Near plane distance.
        zfar: Far plane distance.
        fovX: Horizontal field of view, in radians (it is passed to
            ``math.tan``).
        fovY: Vertical field of view, in radians.

    Returns:
        A ``(4, 4)`` tensor created with ``torch.zeros`` (default dtype and
        device) holding the projection matrix.
    """
    half_width = math.tan(fovX / 2) * znear
    half_height = math.tan(fovY / 2) * znear

    # Symmetric frustum: opposite planes mirror each other around the axis.
    left, right = -half_width, half_width
    bottom, top = -half_height, half_height

    width = right - left
    height = top - bottom
    depth = zfar - znear
    forward = 1.0  # sign placed in the [3, 2] entry and on the depth scale

    proj = torch.zeros(4, 4)
    proj[0, 0] = 2.0 * znear / width
    proj[0, 2] = (right + left) / width
    proj[1, 1] = 2.0 * znear / height
    proj[1, 2] = (top + bottom) / height
    proj[2, 2] = forward * zfar / depth
    proj[2, 3] = -(zfar * znear) / depth
    proj[3, 2] = forward
    return proj
def fov2focal(fov, pixels):
    """Convert a field of view angle to a focal length in pixels.

    Args:
        fov: Field of view in radians (it is passed to ``math.tan``).
        pixels: Image extent, in pixels, along the same axis as ``fov``.

    Returns:
        ``pixels / (2 * tan(fov / 2))``.
    """
    half_angle_tan = math.tan(fov / 2)
    return pixels / (2 * half_angle_tan)
def focal2fov(focal, pixels):
    """Convert a focal length in pixels to a field of view angle.

    Args:
        focal: Focal length, in pixels.
        pixels: Image extent, in pixels, along the axis of interest.

    Returns:
        ``2 * atan(pixels / (2 * focal))``, in radians.
    """
    half_extent_ratio = pixels / (2 * focal)
    return 2 * math.atan(half_extent_ratio)
def apply_rotation(q1, q2):
    """Compose two quaternions and return the normalized product.

    Computes the Hamilton product ``q1 * q2`` (``q1`` on the left) with
    components ordered ``(w, x, y, z)`` and scales the result to unit norm.

    When the inputs are tensors the result is assembled with ``torch.stack``
    so it keeps their dtype, device and autograd history. Rebuilding it with
    ``torch.tensor([...])`` would copy the scalar values into a fresh tensor
    and silently drop the gradient and the device.

    Parameters:
        q1: The original quaternion; any length-4 iterable (tensor, list,
            tuple) in ``(w, x, y, z)`` order.
        q2: The rotation quaternion to be applied, same layout as ``q1``.

    Returns:
        Tensor: The normalized product, with the four components along
        dimension 0. A zero product yields NaNs because of the division by
        its norm.
    """
    # Extract components for readability
    w1, x1, y1, z1 = q1
    w2, x2, y2, z2 = q2

    # Hamilton product q1 * q2
    w3 = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2
    x3 = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2
    y3 = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2
    z3 = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2
    components = [w3, x3, y3, z3]

    if any(isinstance(component, torch.Tensor) for component in components):
        # Every component mixes entries of q1 and q2, so if one is a tensor
        # they all are; stacking preserves dtype, device and gradients.
        q3 = torch.stack(components)
    else:
        # Plain Python / NumPy scalars: build a new tensor from the values.
        q3 = torch.tensor(components)

    # Normalize over the component axis; for a single quaternion this is the
    # same as dividing by the norm of the whole tensor.
    return q3 / torch.norm(q3, dim=0, keepdim=True)
def batch_quaternion_multiply(q1, q2):
    """Multiply two batches of quaternions row by row and normalize.

    Row ``i`` of the result is the Hamilton product ``q1[i] * q2[i]`` with
    components ordered ``(w, x, y, z)``, divided by its own norm.

    Args:
        q1 (torch.Tensor): Tensor of shape [N, 4] holding the left operands.
        q2 (torch.Tensor): Tensor of shape [N, 4] holding the right operands.

    Returns:
        torch.Tensor: Tensor of shape [N, 4] with the normalized products.
        Rows whose product is zero become NaN because of the division by
        their norm.
    """
    # Pull out the component columns once instead of indexing repeatedly.
    a_w, a_x, a_y, a_z = (q1[:, i] for i in range(4))
    b_w, b_x, b_y, b_z = (q2[:, i] for i in range(4))

    w = a_w * b_w - a_x * b_x - a_y * b_y - a_z * b_z
    x = a_w * b_x + a_x * b_w + a_y * b_z - a_z * b_y
    y = a_w * b_y - a_x * b_z + a_y * b_w + a_z * b_x
    z = a_w * b_z + a_x * b_y - a_y * b_x + a_z * b_w

    product = torch.stack((w, x, y, z), dim=1)

    # One norm per row, kept as a column so it broadcasts over components.
    lengths = torch.norm(product, dim=1, keepdim=True)
    return product / lengths