# main.py
from dataloader import load_ply, save_ply
from gaussian import Gaussian
from aabb import process_aabb
from similarity import similarity
from scipy.spatial import KDTree
from merge import merge, merge_2, merge_3, merge_4
from tqdm import tqdm
import argparse
import numpy as np
import pickle as pkl

NUMBER = 1000  # maximum number of Gaussian pairs to merge in this pass

parser = argparse.ArgumentParser("main.py")
parser.add_argument('-m', '--model_path', type=str, default="/home/chuyan/Documents/code/gaussian/models/train/point_cloud/iteration_50000/point_cloud.ply")
parser.add_argument('-o', '--output_path', type=str, default="/home/chuyan/Documents/code/gaussian/models/train/point_cloud/iteration_60000/point_cloud.ply")
parser.add_argument('-d', '--sh_degrees', type=int, default=3)

args = parser.parse_args()
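# Example invocation (paths below are placeholders, not the defaults above):
#   python main.py -m /path/to/point_cloud.ply -o /path/to/merged_point_cloud.ply -d 3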

gaussian_model = load_ply(args.model_path, args.sh_degrees)
gaussian_extended = Gaussian.empty_with_cap(NUMBER)  # pre-allocated container for the merged Gaussians created below

# gaussian_model.clip_to_box(
#     min=np.array([-0.2, -0.7, -0.3]),
#     max=np.array([0, -0.3, 0])
# )


# kd_tree = KDTree(gaussian_model.positions)
# similarities = []
# for i in tqdm(range(gaussian_model.num_gaussians)):
#     # print(f"currently running {i}")

#     scale = gaussian_model.scales[i]
#     radius = max(max(scale[0], scale[1]), scale[2]) * 1.5

#     points = kd_tree.query_ball_point(gaussian_model.positions[i], r=radius, workers=18)
#     for j in points:
#         if j < i:
#             geo_score, color_score = similarity(gaussian_model, j, i)
#             if geo_score * color_score != 0:
#                 similarities.append((j, i, geo_score, color_score, ))

# sim = sorted(similarities, key=lambda pair: pair[2] * pair[3], reverse=True)
# print(f"pair number is {len(sim)}")

# with open('similarities.pkl', 'wb') as f:
#     pkl.dump(sim, f)
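# The commented-out block above only needs to run once: it builds a KD-tree over
# the Gaussian centres, scores each nearby pair with similarity(), keeps pairs
# with non-zero scores as (index, index, geo_score, color_score) tuples sorted by
# score product, and dumps the list to similarities.pkl, which is loaded below.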

with open('similarities.pkl', 'rb') as f:
    sim = pkl.load(f)

def metric_with_opacity(pair):
    """Score a candidate pair for merging from its precomputed geometric and
    color similarities plus the two Gaussians' rotations and opacities."""
    i, j, gs, cs = pair
    o1, o2 = gaussian_model.opacity[i], gaussian_model.opacity[j]
    return gs * (cs ** 0.3) \
        * np.exp(np.dot(gaussian_model.rotations[i], gaussian_model.rotations[j])) \
        * min(o1 / o2, o2 / o1)
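# Reading of the heuristic (assuming rotations are unit quaternions, as is usual
# for 3DGS models): the exp(dot) factor is largest when the two rotations are
# aligned, and min(o1/o2, o2/o1) equals 1 for equal opacities and shrinks as they
# diverge, so well-aligned pairs with similar opacity are ranked first.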

sim = sorted(sim, key=metric_with_opacity, reverse=True)
print(f"There are {len(sim)} pairs")

cnt = 0
keep_mask = np.ones(gaussian_model.num_gaussians, dtype=bool)  # True = keep the original Gaussian
used = set()
for i, j, gs, cs in sim:
    if i in used or j in used:
        continue
    position, scale, rotation, features_dc, features_rest, opacity = merge_3(gaussian_model, i, j)
    if position is not None:
        cnt += 1
        # if cnt != NUMBER:
        #     continue
        gaussian_extended.add(position, scale, rotation, features_dc, features_rest, opacity)
        # print(f"opacity: ({gaussian_model.opacity[i]},{gaussian_model.opacity[j]}) -> {opacity}")
        keep_mask[i] = False
        keep_mask[j] = False
        used.add(i)
        used.add(j)

        if cnt == NUMBER:
            break

deleted_gaussian = gaussian_model.copy()
deleted_gaussian.apply_filter(np.logical_not(keep_mask))  # only the Gaussians that were merged away
# deleted_gaussian.opacity = np.ones_like(deleted_gaussian.opacity)
# gaussian_extended.opacity = np.ones_like(gaussian_extended.opacity)

gaussian_model.apply_filter(keep_mask)     # drop the merged-away Gaussians...
gaussian_model.concat(gaussian_extended)   # ...and append their merged replacements

# Derive sibling output paths for inspection; this assumes the iteration number
# '60000' appears in the output path.
deleted_path = args.output_path.replace('60000', '60001')
added_path = args.output_path.replace('60000', '60002')
save_ply(args.output_path, gaussian_model)    # merged model
save_ply(deleted_path, deleted_gaussian)      # only the removed Gaussians
save_ply(added_path, gaussian_extended)       # only the newly created merged Gaussians
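# Optional sanity check (sketch, left commented out; assumes the saved PLY can be
# reloaded with load_ply and exposes num_gaussians like the model above): each
# accepted pair replaces two Gaussians with one, so the merged model should hold
# exactly `cnt` fewer Gaussians than the input model.
# reloaded = load_ply(args.output_path, args.sh_degrees)
# original = load_ply(args.model_path, args.sh_degrees)
# assert reloaded.num_gaussians == original.num_gaussians - cnt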