summaryrefslogtreecommitdiff
path: root/new_main.py
blob: 99a67675ba77603b580d8b2039e756a838a53203 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
from dataloader import load_ply, save_ply
from gaussian import Gaussian
from scipy.spatial import KDTree
from scipy.optimize import minimize
from merge import merge_geo
from utils import quat2rot
from tqdm import tqdm
import argparse
import numpy as np
import pickle as pkl

parser = argparse.ArgumentParser("main.py")
# NOTE(review): these defaults are absolute paths on one developer's machine;
# pass -m / -o explicitly anywhere else.
parser.add_argument('-m', '--model_path', type=str, default="/home/chuyan/Documents/code/models/train/point_cloud/iteration_50000/point_cloud.ply")
parser.add_argument('-o', '--output_path', type=str, default="/home/chuyan/Documents/code/models/train/point_cloud/iteration_60000/point_cloud.ply")
parser.add_argument('-d', '--sh_degrees', type=int, default=3)
parser.add_argument('-n', '--max_merges', type=int, default=8000,
                    help="maximum number of Gaussian pairs to merge; also the "
                         "capacity of the buffer that receives merged Gaussians")

args = parser.parse_args()

# Upper bound on the number of merges (was the hard-coded constant 8000).
NUMBER = args.max_merges

gaussian_model = load_ply(args.model_path, args.sh_degrees)
gaussian_extended = Gaussian.empty_with_cap(NUMBER)

# Candidate merge pairs: (i, j) with j < i and j inside the ball centred on
# Gaussian i whose radius is 1.5x the largest scale component of i.  Only i's
# radius is consulted, so the neighbour relation is asymmetric.
positions = np.asarray(gaussian_model.positions)
radii = np.max(np.asarray(gaussian_model.scales), axis=1) * 1.5
kd_tree = KDTree(positions)
# One batched query: `workers` parallelises across query points, so it has no
# effect on a single-point query.  `return_sorted=False` keeps each neighbour
# list unsorted, as single-point queries are by default, so pair order (which
# the greedy merge loop depends on) is preserved.
neighbour_lists = kd_tree.query_ball_point(
    positions, r=radii, workers=-1, return_sorted=False
)
pairs = [
    (i, j)
    for i, neighbours in enumerate(tqdm(neighbour_lists))
    for j in neighbours
    if j < i
]

# with open('pairs.pkl', 'wb') as f:
#     pkl.dump(pairs, f)

def merge_3(gaussian: Gaussian, id1: int, id2: int, samples_per_axis: int = 4) -> tuple:
    """Fit one Gaussian ``h`` approximating the sum of Gaussians ``id1`` and ``id2``.

    ``merge_geo`` supplies the initial guess ``(mu_0, s_0, r_0, o_0)`` and a
    sampling box: ``point_min`` plus three edge vectors stored as the columns of
    ``vectors``.  ``scipy.optimize.minimize`` then adjusts the 11 parameters of
    ``h`` (position 3, scale 3, rotation 4, opacity 1) to minimise
    ``sum((f(p) + g(p) - h(p)) ** 2)``.  The sum runs over a regular grid of
    ``samples_per_axis ** 3`` points in that box.  Spherical-harmonic colour is
    not optimised; it is the mean of the two inputs' ``sh`` arrays.

    Args:
        gaussian: model holding the two Gaussians to merge.
        id1: index of the first Gaussian.
        id2: index of the second Gaussian.
        samples_per_axis: grid resolution along each box edge.  The default of 4
            is the previously hard-coded value.

    Returns:
        ``(mu_h, s_h, r_h, features_dc, features_rest, o_h)``.  Here
        ``features_dc`` is row 0 and ``features_rest`` is rows ``1:`` of the
        averaged ``sh`` array.  Returns six ``None`` values when ``merge_geo``
        yields ``mu_0 is None``.
    """
    # Use the `gaussian` argument, not the module-level model, so the function
    # works for any model passed in.
    mu_0, s_0, r_0, o_0, point_min, vectors = merge_geo(gaussian, id1, id2)
    # NOTE(review): the caller skips results with `mu is None`; this guard assumes
    # merge_geo reports an unusable pair via mu_0 = None -- confirm.
    if mu_0 is None:
        return None, None, None, None, None, None

    def make_gaussian(mu, s, r, o):
        """Return p -> o * exp(-0.5 * (p - mu)^T P (p - mu)), with P built below."""
        # P = inv(M^T M) with M = diag(s) @ quat2rot(r), i.e. the inverse of the
        # covariance.  It depends only on (s, r), so it is computed once per
        # Gaussian instead of once per evaluated point.
        # Assumes s are linear (not log) scales and o a linear opacity -- TODO
        # confirm against load_ply.
        m = np.diag(s) @ quat2rot(r)
        precision = np.linalg.inv(m.T @ m)

        def evaluate(pos):
            offset = pos - mu
            return o * np.exp(-0.5 * np.dot(offset, precision @ offset))
        return evaluate

    gaussian_f = make_gaussian(gaussian.positions[id1], gaussian.scales[id1],
                               gaussian.rotations[id1], gaussian.opacity[id1])
    gaussian_g = make_gaussian(gaussian.positions[id2], gaussian.scales[id2],
                               gaussian.rotations[id2], gaussian.opacity[id2])

    # Sample grid inside the box, in the same (i, j, k) order as before.
    # NOTE(review): offsets run over 0 .. (n-1)/n, so the far faces of the box
    # are never sampled.  Cell centres ((i + 0.5) / n) may be intended.
    vector_x, vector_y, vector_z = vectors[:, 0], vectors[:, 1], vectors[:, 2]
    n = samples_per_axis
    sample_points = [
        point_min + (i * vector_x + j * vector_y + k * vector_z) / n
        for i in range(n)
        for j in range(n)
        for k in range(n)
    ]
    # f + g does not depend on the optimised parameters; evaluate it once.
    targets = [gaussian_f(pos) + gaussian_g(pos) for pos in sample_points]

    def objective(features):
        """Squared error between f + g and the candidate h on the sample grid."""
        gaussian_h = make_gaussian(features[:3], features[3:6],
                                   features[6:10], features[10:])
        total = 0.
        for pos, target in zip(sample_points, targets):
            total += (target - gaussian_h(pos)) ** 2
        return total

    x_0 = np.concatenate((mu_0, s_0, r_0, o_0), axis=0)
    res = minimize(objective, x_0)
    mu_h = res.x[:3]
    # The objective only sees s through M^T M = R^T diag(s**2) R, so the optimiser
    # is free to return negative scales.  Their absolute values describe the same
    # Gaussian and are valid scale values.
    s_h = np.abs(res.x[3:6])
    # NOTE(review): r_h is not re-normalised and o_h is not clamped to [0, 1];
    # confirm quat2rot / save_ply tolerate that.
    r_h = res.x[6:10]
    o_h = res.x[10]

    sh = (gaussian.sh[id1] + gaussian.sh[id2]) / 2.
    features_dc = sh[0, :]
    features_rest = sh[1:, :]

    return mu_h, s_h, r_h, features_dc, features_rest, o_h

def _sibling_output_path(path: str, iteration_tag: str, suffix: str) -> str:
    """Derive an extra output path next to ``path``.

    If ``path`` contains ``'60000'``, that text is replaced by ``iteration_tag``
    (the original naming scheme).  Otherwise ``suffix`` is inserted before the
    file extension.  This way the derived path never equals ``path`` and never
    overwrites the main output.
    """
    import os

    if '60000' in path:
        return path.replace('60000', iteration_tag)
    root, ext = os.path.splitext(path)
    return f"{root}{suffix}{ext}"


# Greedily merge candidate pairs in order.  Each Gaussian takes part in at most
# one merge, and at most NUMBER merges are performed.
merged_count = 0
# keep[k] is False once Gaussian k has been consumed by a merge.
keep = np.ones(gaussian_model.num_gaussians, dtype=bool)
for i, j in tqdm(pairs):
    if not (keep[i] and keep[j]):
        continue
    mu, s, r, dc, rest, o = merge_3(gaussian_model, i, j)
    if mu is None:
        continue

    gaussian_extended.add(mu, s, r, dc, rest, o)
    keep[i] = False
    keep[j] = False
    merged_count += 1

    if merged_count == NUMBER:
        break

# Snapshot of the Gaussians that were merged away, taken before the main model
# is filtered.
deleted_gaussian = gaussian_model.copy()
deleted_gaussian.apply_filter(np.logical_not(keep))

# Main output: untouched Gaussians plus the newly merged ones.
gaussian_model.apply_filter(keep)
gaussian_model.concat(gaussian_extended)

# The previous plain str.replace made all three paths identical whenever
# output_path lacked '60000', so the last save clobbered the main result.
deleted_path = _sibling_output_path(args.output_path, '60001', '_deleted')
added_path = _sibling_output_path(args.output_path, '60002', '_added')
save_ply(args.output_path, gaussian_model)
save_ply(deleted_path, deleted_gaussian)
save_ply(added_path, gaussian_extended)