1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
|
import argparse
import os
import pickle as pkl

import numpy as np
from scipy.optimize import minimize
from scipy.spatial import KDTree
from tqdm import tqdm

from dataloader import load_ply, save_ply
from gaussian import Gaussian
from merge import merge_geo
from utils import quat2rot
def _collect_candidate_pairs(tree, positions, scales, count, radius_factor, workers):
    """Return ``(i, j)`` index pairs with ``j < i`` and ``j`` inside Gaussian ``i``'s ball.

    The ball around Gaussian ``i`` is centred at ``positions[i]``. Its radius is
    ``radius_factor`` times the largest of the first three components of
    ``scales[i]``. All ``count`` balls are queried in ONE batched
    ``query_ball_point`` call, with ``r`` broadcast per point. Previously there
    was one call per Gaussian, each spinning up ``workers`` threads for a single
    point.

    Only ``j < i`` is kept, so each unordered pair is reported once. Radii
    differ per Gaussian, so a pair is found only when the lower index lies in
    the ball of the higher index; this is the same rule as before.

    NOTE(review): batched queries return each neighbour list sorted (SciPy's
    ``return_sorted=None`` default). The order of pairs within one ``i``
    therefore may differ from the old unsorted single-point queries, and
    that order feeds the greedy merge loop.
    """
    # Only the first ``count`` Gaussians are queried, as in the original
    # ``range(num_gaussians)`` loop; the tree may still index all positions.
    query_points = np.asarray(positions)[:count]
    # assumes scales is (N, >=3) -- TODO confirm
    radii = np.asarray(scales)[:count, :3].max(axis=1) * radius_factor
    neighbours = tree.query_ball_point(query_points, r=radii, workers=workers)
    pairs = []
    for i, candidates in enumerate(tqdm(neighbours)):
        # j == i (the point itself, distance 0) is excluded by the strict '<'.
        pairs.extend((i, j) for j in candidates if j < i)
    return pairs


parser = argparse.ArgumentParser("main.py")
parser.add_argument('-m', '--model_path', type=str, default="/home/chuyan/Documents/code/models/train/point_cloud/iteration_50000/point_cloud.ply")
parser.add_argument('-o', '--output_path', type=str, default="/home/chuyan/Documents/code/models/train/point_cloud/iteration_60000/point_cloud.ply")
parser.add_argument('-d', '--sh_degrees', type=int, default=3)
parser.add_argument('-n', '--number', type=int, default=8000,
                    help="maximum number of pair merges to perform")
parser.add_argument('--radius_factor', type=float, default=1.5,
                    help="neighbour search radius as a multiple of a Gaussian's largest scale")
parser.add_argument('--workers', type=int, default=-1,
                    help="KD-tree query threads (-1 = all processors)")
args = parser.parse_args()

# Maximum number of merges; also the capacity of the container for merged Gaussians.
NUMBER = args.number

gaussian_model = load_ply(args.model_path, args.sh_degrees)
gaussian_extended = Gaussian.empty_with_cap(NUMBER)
kd_tree = KDTree(gaussian_model.positions)
pairs = _collect_candidate_pairs(
    kd_tree,
    gaussian_model.positions,
    gaussian_model.scales,
    gaussian_model.num_gaussians,
    args.radius_factor,
    args.workers,
)
# Optional: cache the candidate pairs for reuse.
# with open('pairs.pkl', 'wb') as f:
#     pkl.dump(pairs, f)
def _gaussian_density(points, mu, s, r, o):
    """Evaluate one unnormalised 3-D Gaussian at every row of ``points``.

    Computes ``o * exp(-0.5 * d^T P d)``, where:
      * ``d = p - mu``
      * ``M = diag(s) @ quat2rot(r)``
      * ``P = inv(M^T M)``

    This is the formula of the original per-point closure. Here it is
    vectorised over all sample points, and the matrix inverse is computed once
    per call instead of once per point.

    NOTE(review): assumes ``points`` is (n, 3) and that ``s``/``o`` are the
    activated values (not log-scale / logit) -- TODO confirm against
    ``load_ply``.

    Raises:
        numpy.linalg.LinAlgError: if ``M^T M`` is singular (e.g. a zero scale).
    """
    m = np.diag(s) @ quat2rot(r)
    precision = np.linalg.inv(m.T @ m)
    d = points - mu
    # Row-wise quadratic form d_n^T P d_n for every sample point n.
    quad = np.einsum('ni,ij,nj->n', d, precision, d)
    return o * np.exp(-0.5 * quad)


def merge_3(gaussian: Gaussian, id1: int, id2: int, grid_size: int = 4) -> tuple:
    """Fit one Gaussian ``h`` that approximates the sum of Gaussians ``id1`` and ``id2``.

    How the fit works:
      * ``merge_geo`` supplies the initial guess ``(mu_0, s_0, r_0, o_0)`` and
        a sampling box: ``point_min`` plus three edge vectors, the columns of
        ``vectors``.
      * The densities are compared on the ``grid_size**3`` lattice
        ``point_min + (i*vx + j*vy + k*vz) / grid_size``, with ``i, j, k`` in
        ``range(grid_size)``.
      * ``h``'s parameters (position ``[:3]``, scale ``[3:6]``, rotation
        ``[6:10]``, opacity ``[10]``) minimise the summed squared residual
        ``f + g - h`` via ``scipy.optimize.minimize``.
      * SH colour coefficients are the plain average of the two inputs.

    NOTE(review): ``i / grid_size`` never reaches 1, so the far faces of the
    box are never sampled -- confirm this is intended.

    Args:
        gaussian: model holding both source Gaussians.
        id1: index of the first Gaussian.
        id2: index of the second Gaussian.
        grid_size: lattice samples per box edge (default 4, as before).

    Returns:
        ``(mu_h, s_h, r_h, features_dc, features_rest, o_h)``. ``res.success``
        is not checked, so this never returns ``None``.
    """
    # Use the ``gaussian`` argument; the original wrongly read the module-level
    # ``gaussian_model`` here.
    mu_0, s_0, r_0, o_0, point_min, vectors = merge_geo(gaussian, id1, id2)
    x_0 = np.concatenate((mu_0, s_0, r_0, o_0), axis=0)

    # Integer lattice coordinates (i, j, k), one row per sample point.
    steps = np.arange(grid_size)
    ijk = np.stack(np.meshgrid(steps, steps, steps, indexing='ij'), axis=-1).reshape(-1, 3)
    # Row m is point_min + (i*vectors[:, 0] + j*vectors[:, 1] + k*vectors[:, 2]) / grid_size.
    samples = point_min + (ijk @ vectors.T) / grid_size

    # f and g are fixed during the fit, so evaluate them once, not per objective call.
    target = (
        _gaussian_density(samples, gaussian.positions[id1], gaussian.scales[id1],
                          gaussian.rotations[id1], gaussian.opacity[id1])
        + _gaussian_density(samples, gaussian.positions[id2], gaussian.scales[id2],
                            gaussian.rotations[id2], gaussian.opacity[id2])
    )

    def objective(features) -> float:
        """Sum of squared residuals between f + g and the candidate h on the lattice."""
        density_h = _gaussian_density(samples, features[:3], features[3:6],
                                      features[6:10], features[10])
        residual = target - density_h
        # Reduce to a true Python float; minimize requires a scalar objective.
        return float(np.sum(residual ** 2))

    res = minimize(objective, x_0)
    x = res.x
    # Averaging both SH arrays, then splitting row 0 (DC) from the higher-order rows.
    sh = (gaussian.sh[id1] + gaussian.sh[id2]) / 2.
    return x[:3], x[3:6], x[6:10], sh[0, :], sh[1:, :], x[10]
def _sibling_output_path(output_path: str, iteration_tag: str, fallback_suffix: str) -> str:
    """Return the path of an auxiliary output file written next to ``output_path``.

    If ``output_path`` contains ``'60000'``, the original naming is kept:
    ``'60000'`` is replaced by ``iteration_tag``. Otherwise the original
    ``str.replace`` was a no-op. All three saves then went to the same file,
    and the last one (the added-only model) overwrote the merged model. In that
    case ``_<fallback_suffix>`` is appended to the file stem instead.
    """
    if '60000' in output_path:
        return output_path.replace('60000', iteration_tag)
    root, ext = os.path.splitext(output_path)
    return f"{root}_{fallback_suffix}{ext}"


# Greedily merge candidate pairs in collection order. Each Gaussian takes part
# in at most one merge, and at most NUMBER merges are made.
cnt = 0
keep_mask = np.ones(gaussian_model.num_gaussians, dtype=bool)  # False = consumed by a merge
for i, j in tqdm(pairs):
    if not (keep_mask[i] and keep_mask[j]):
        continue
    mu, s, r, dc, rest, o = merge_3(gaussian_model, i, j)
    if mu is None:
        # NOTE(review): merge_3 as written never returns None; the guard lets a
        # future failure path signal "skip this pair".
        continue
    # A volume-growth filter (skip if prod(s) exceeds 10x the larger input's
    # prod(scale)) was tried here and left disabled.
    cnt += 1
    gaussian_extended.add(mu, s, r, dc, rest, o)
    keep_mask[i] = False
    keep_mask[j] = False
    if cnt == NUMBER:
        break

# Snapshot the Gaussians that were merged away before filtering the model in place.
deleted_gaussian = gaussian_model.copy()
deleted_gaussian.apply_filter(np.logical_not(keep_mask))
gaussian_model.apply_filter(keep_mask)
gaussian_model.concat(gaussian_extended)

deleted_path = _sibling_output_path(args.output_path, '60001', 'deleted')
added_path = _sibling_output_path(args.output_path, '60002', 'added')
for path, model in (
    (args.output_path, gaussian_model),
    (deleted_path, deleted_gaussian),
    (added_path, gaussian_extended),
):
    # The iteration_6000x directories may not exist yet.
    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    save_ply(path, model)
|