-
Notifications
You must be signed in to change notification settings - Fork 44
/
Copy pathrender.py
168 lines (132 loc) · 6.93 KB
/
render.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
# Copyright (C) 2023, Gaussian-Grouping
# Gaussian-Grouping research group, https://github.com/lkeab/gaussian-grouping
# All rights reserved.
#
# ------------------------------------------------------------------------
# Modified from codes in Gaussian-Splatting
# GRAPHDECO research group, https://team.inria.fr/graphdeco
import torch
from scene import Scene
import os
from tqdm import tqdm
from os import makedirs
from gaussian_renderer import render
import torchvision
from utils.general_utils import safe_state
from argparse import ArgumentParser
from arguments import ModelParams, PipelineParams, get_combined_args
from gaussian_renderer import GaussianModel
import numpy as np
from PIL import Image
import colorsys
import cv2
from sklearn.decomposition import PCA
def feature_to_rgb(features):
    """Project a (C, H, W) feature map down to an (H, W, 3) uint8 RGB image.

    Fits a 3-component PCA over the per-pixel feature vectors and rescales the
    projected values to [0, 255] so the feature map can be saved as an image.

    Args:
        features: tensor of shape (C, H, W), e.g. (16, H, W) rendered
            object features. May live on GPU; it is moved to CPU here.

    Returns:
        np.ndarray of shape (H, W, 3), dtype uint8.
    """
    H, W = features.shape[1], features.shape[2]
    # (C, H*W) -> (H*W, C): one PCA sample per pixel.
    # reshape (not view): the rendered tensor is not guaranteed contiguous,
    # and .view raises on non-contiguous inputs.
    features_reshaped = features.reshape(features.shape[0], -1).T

    # Keep the first 3 principal components as RGB channels.
    pca = PCA(n_components=3)
    pca_result = pca.fit_transform(features_reshaped.cpu().numpy())
    pca_result = pca_result.reshape(H, W, 3)

    # Min-max normalize to [0, 255]; guard the degenerate constant-feature
    # case, where max == min would divide by zero and yield NaNs.
    vmin = pca_result.min()
    vrange = pca_result.max() - vmin
    if vrange == 0:
        return np.zeros((H, W, 3), dtype=np.uint8)
    pca_normalized = 255 * (pca_result - vmin) / vrange
    return pca_normalized.astype('uint8')
def id2rgb(id, max_num_obj=256):
    """Map an integer object ID to a deterministic, visually distinct RGB color.

    Args:
        id: object ID; must lie in range(0, max_num_obj). ID 0 denotes the
            invalid/background region and maps to black.
        max_num_obj: exclusive upper bound on valid IDs.

    Returns:
        np.ndarray of shape (3,), dtype uint8.

    Raises:
        ValueError: if id is outside range(0, max_num_obj).
    """
    # Bug fix: the original check accepted id == max_num_obj even though the
    # error message documents range(0, max_num_obj); use an exclusive bound.
    if not 0 <= id < max_num_obj:
        raise ValueError("ID should be in range(0, max_num_obj)")

    rgb = np.zeros((3, ), dtype=np.uint8)
    if id == 0:  # invalid region stays black
        return rgb

    # Golden-ratio hue stepping spreads consecutive IDs far apart on the
    # color wheel; saturation alternates 0.5 / 1.0 with ID parity so
    # neighboring hues still look different.
    golden_ratio = 1.6180339887
    h = (id * golden_ratio) % 1  # hue in [0, 1)
    s = 0.5 + (id % 2) * 0.5
    l = 0.5

    r, g, b = colorsys.hls_to_rgb(h, l, s)
    rgb[0], rgb[1], rgb[2] = int(r * 255), int(g * 255), int(b * 255)
    return rgb
def visualize_obj(objects):
    """Colorize an integer object-ID map into an (H, W, 3) uint8 RGB mask.

    Each distinct ID in `objects` is painted with its id2rgb() color; ID 0
    (the invalid region) comes out black.
    """
    height_width = objects.shape[-2:]
    rgb_mask = np.zeros((*height_width, 3), dtype=np.uint8)
    # Color every distinct ID present in the map.
    for obj_id in np.unique(objects):
        rgb_mask[objects == obj_id] = id2rgb(obj_id)
    return rgb_mask
def render_set(model_path, name, iteration, views, gaussians, pipeline, background, classifier):
    """Render one camera set and dump per-view images plus a concatenated video.

    For each view writes, under <model_path>/<name>/ours_<iteration>/:
      - renders/<idx>.png            rendered RGB image
      - gt/<idx>.png                 ground-truth RGB image
      - objects_feature16/<idx>.png  PCA visualization of rendered object features
      - gt_objects_color/<idx>.png   colorized ground-truth object IDs
      - objects_pred/<idx>.png       colorized argmax predictions of `classifier`
    then stitches [gt | render | gt objects | pred objects | feature PCA]
    side by side into concat/<idx>.png and concat/result.mp4.
    """
    render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders")
    gts_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt")
    colormask_path = os.path.join(model_path, name, "ours_{}".format(iteration), "objects_feature16")
    gt_colormask_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt_objects_color")
    pred_obj_path = os.path.join(model_path, name, "ours_{}".format(iteration), "objects_pred")
    makedirs(render_path, exist_ok=True)
    makedirs(gts_path, exist_ok=True)
    makedirs(colormask_path, exist_ok=True)
    makedirs(gt_colormask_path, exist_ok=True)
    makedirs(pred_obj_path, exist_ok=True)

    # Bug fix: with an empty camera set the original code fell through to the
    # video-writer setup and crashed on the then-undefined `gt`.
    if len(views) == 0:
        return

    for idx, view in enumerate(tqdm(views, desc="Rendering progress")):
        results = render(view, gaussians, pipeline, background)
        rendering = results["render"]
        rendering_obj = results["render_object"]

        # Per-pixel class logits from the 1x1-conv classifier over the
        # rendered object feature map; argmax over channels -> object IDs.
        logits = classifier(rendering_obj)
        pred_obj = torch.argmax(logits, dim=0)
        pred_obj_mask = visualize_obj(pred_obj.cpu().numpy().astype(np.uint8))

        gt_objects = view.objects
        gt_rgb_mask = visualize_obj(gt_objects.cpu().numpy().astype(np.uint8))

        rgb_mask = feature_to_rgb(rendering_obj)
        Image.fromarray(rgb_mask).save(os.path.join(colormask_path, '{0:05d}'.format(idx) + ".png"))
        Image.fromarray(gt_rgb_mask).save(os.path.join(gt_colormask_path, '{0:05d}'.format(idx) + ".png"))
        Image.fromarray(pred_obj_mask).save(os.path.join(pred_obj_path, '{0:05d}'.format(idx) + ".png"))
        gt = view.original_image[0:3, :, :]
        torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png"))
        torchvision.utils.save_image(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png"))

    # .../ours_<iteration>/concat (the original used a fragile
    # render_path[:-8] string slice to reach the parent directory).
    out_path = os.path.join(os.path.dirname(render_path), 'concat')
    makedirs(out_path, exist_ok=True)
    # Bug fix: 'mp4v' matches the .mp4 container; the original 'DIVX' tag
    # made OpenCV warn and fall back to mp4v anyway.
    fourcc = cv2.VideoWriter.fourcc(*'mp4v')
    # Five images stacked horizontally; `gt` is the last rendered frame.
    size = (gt.shape[-1] * 5, gt.shape[-2])
    fps = float(5) if 'train' in out_path else float(1)
    writer = cv2.VideoWriter(os.path.join(out_path, 'result.mp4'), fourcc, fps, size)

    for file_name in sorted(os.listdir(gts_path)):
        gt = np.array(Image.open(os.path.join(gts_path, file_name)))
        rgb = np.array(Image.open(os.path.join(render_path, file_name)))
        gt_obj = np.array(Image.open(os.path.join(gt_colormask_path, file_name)))
        render_obj = np.array(Image.open(os.path.join(colormask_path, file_name)))
        pred_obj = np.array(Image.open(os.path.join(pred_obj_path, file_name)))

        result = np.hstack([gt, rgb, gt_obj, pred_obj, render_obj])
        result = result.astype('uint8')
        Image.fromarray(result).save(os.path.join(out_path, file_name))
        writer.write(result[:, :, ::-1])  # PIL is RGB; OpenCV expects BGR
    writer.release()
def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool):
    """Load a trained scene plus its object classifier and render the
    requested camera sets (train and/or test) to disk via render_set()."""
    # Inference only: no gradients needed for rendering.
    with torch.no_grad():
        gaussians = GaussianModel(dataset.sh_degree)
        scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)

        num_classes = dataset.num_classes
        print("Num classes: ", num_classes)

        # The object classifier is a 1x1 conv mapping rendered per-pixel
        # object features to class logits; its weights were saved alongside
        # the point cloud at the loaded iteration.
        checkpoint = os.path.join(
            dataset.model_path,
            "point_cloud",
            "iteration_" + str(scene.loaded_iter),
            "classifier.pth",
        )
        classifier = torch.nn.Conv2d(gaussians.num_objects, num_classes, kernel_size=1)
        classifier.cuda()
        classifier.load_state_dict(torch.load(checkpoint))

        bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
        background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")

        if not skip_train:
            render_set(dataset.model_path, "train", scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, background, classifier)
        if (not skip_test) and (len(scene.getTestCameras()) > 0):
            render_set(dataset.model_path, "test", scene.loaded_iter, scene.getTestCameras(), gaussians, pipeline, background, classifier)
if __name__ == "__main__":
    # Command-line entry point for offline rendering of a trained model.
    parser = ArgumentParser(description="Testing script parameters")
    model = ModelParams(parser, sentinel=True)
    pipeline = PipelineParams(parser)
    parser.add_argument("--iteration", type=int, default=-1)
    parser.add_argument("--skip_train", action="store_true")
    parser.add_argument("--skip_test", action="store_true")
    parser.add_argument("--quiet", action="store_true")

    # Merge CLI flags with the arguments stored at training time.
    args = get_combined_args(parser)
    print("Rendering " + args.model_path)

    # Initialize system state (RNG)
    safe_state(args.quiet)

    dataset_args = model.extract(args)
    pipeline_args = pipeline.extract(args)
    render_sets(dataset_args, args.iteration, pipeline_args, args.skip_train, args.skip_test)