summaryrefslogtreecommitdiff
path: root/new_main.py
diff options
context:
space:
mode:
authorChuyan Zhang <me@zcy.moe>2024-01-15 15:49:41 -0800
committerChuyan Zhang <me@zcy.moe>2024-01-15 15:49:41 -0800
commit394e19b012cb9264feaec582948fa7ac8bff901c (patch)
tree50c0e3b49821b3ef5b2e727cd02e5e3dd0ecab82 /new_main.py
downloadgaussian-lod-master.tar.gz
gaussian-lod-master.zip
init commitHEADmaster
Diffstat (limited to 'new_main.py')
-rw-r--r--new_main.py129
1 files changed, 129 insertions, 0 deletions
diff --git a/new_main.py b/new_main.py
new file mode 100644
index 0000000..99a6767
--- /dev/null
+++ b/new_main.py
@@ -0,0 +1,129 @@
+from dataloader import load_ply, save_ply
+from gaussian import Gaussian
+from scipy.spatial import KDTree
+from scipy.optimize import minimize
+from merge import merge_geo
+from utils import quat2rot
+from tqdm import tqdm
+import argparse
+import numpy as np
+import pickle as pkl
+
+NUMBER=8000
+
+parser = argparse.ArgumentParser("main.py")
+parser.add_argument('-m', '--model_path', type=str, default="/home/chuyan/Documents/code/models/train/point_cloud/iteration_50000/point_cloud.ply")
+parser.add_argument('-o', '--output_path', type=str, default="/home/chuyan/Documents/code/models/train/point_cloud/iteration_60000/point_cloud.ply")
+parser.add_argument('-d', '--sh_degrees', type=int, default=3)
+
+args = parser.parse_args()
+
+gaussian_model = load_ply(args.model_path, args.sh_degrees)
+gaussian_extended = Gaussian.empty_with_cap(NUMBER)
+
+kd_tree = KDTree(gaussian_model.positions)
+pairs = []
+for i in tqdm(range(gaussian_model.num_gaussians)):
+ scale = gaussian_model.scales[i]
+ radius = max(max(scale[0], scale[1]), scale[2]) * 1.5
+ points = kd_tree.query_ball_point(gaussian_model.positions[i], r=radius, workers=18)
+ for j in points:
+ if j < i:
+ pairs.append((i, j))
+
+# with open('pairs.pkl', 'wb') as f:
+# pkl.dump(pairs, f)
+
+def merge_3(gaussian: Gaussian, id1: int, id2: int) -> float:
+ mu_f = gaussian.positions[id1]
+ mu_g = gaussian.positions[id2]
+ s_f = gaussian.scales[id1]
+ s_g = gaussian.scales[id2]
+ r_f = gaussian.rotations[id1]
+ r_g = gaussian.rotations[id2]
+ o_f = gaussian.opacity[id1]
+ o_g = gaussian.opacity[id2]
+
+ mu_0, s_0, r_0, o_0, point_min, vectors = merge_geo(gaussian_model, id1, id2)
+ vector_x = vectors[:, 0]
+ vector_y = vectors[:, 1]
+ vector_z = vectors[:, 2]
+ x_0 = np.concatenate((mu_0, s_0, r_0, o_0), axis=0)
+
+ def make_gaussian(mu, s, r, o):
+ def calc_inner(pos):
+ pos = pos - mu
+ S = np.diag(s)
+ R = quat2rot(r)
+ M = S @ R
+ Sigma = np.linalg.inv(M.T @ M)
+ return o * np.exp(-0.5 * np.dot(pos, Sigma @ pos))
+ return calc_inner
+ gaussian_f = make_gaussian(mu_f, s_f, r_f, o_f)
+ gaussian_g = make_gaussian(mu_g, s_g, r_g, o_g)
+
+ def target_inner(features) -> float:
+ mu = features[:3]
+ s = features[3:6]
+ r = features[6:10]
+ o = features[10:]
+ gaussian_h = make_gaussian(mu, s, r, o)
+
+ N = 4
+
+ sum = 0.
+ for i in range(N):
+ for j in range(N):
+ for k in range(N):
+ pos = point_min + (i * vector_x + j * vector_y + k * vector_z) / N
+ sum += (gaussian_f(pos) + gaussian_g(pos) - gaussian_h(pos)) ** 2
+ return sum
+
+ res = minimize(target_inner, x_0)
+ mu_h = res.x[:3]
+ s_h = res.x[3:6]
+ r_h = res.x[6:10]
+ o_h = res.x[10]
+
+ sh = (gaussian.sh[id1] + gaussian.sh[id2]) / 2.
+ features_dc = sh[0, :]
+ features_rest = sh[1:, :]
+ # print(mu_h, s_h, r_h, features_dc, features_rest, o_h)
+
+ return mu_h, s_h, r_h, features_dc, features_rest, o_h
+
+cnt = 0
+filter = np.ones(gaussian_model.num_gaussians, dtype=bool)
+used = set()
+for i, j in tqdm(pairs):
+ if i in used or j in used:
+ continue
+ mu, s, r, dc, rest, o = merge_3(gaussian_model, i, j)
+ if mu is not None:
+ # scale_product = np.prod(scale)
+ # scale_product_1 = np.prod(gaussian_model.scales[i])
+ # scale_product_2 = np.prod(gaussian_model.scales[j])
+ # if scale_product / max(scale_product_1, scale_product_2) > 10.:
+ # continue
+
+ cnt += 1
+ id = gaussian_extended.add(mu, s, r, dc, rest, o)
+ filter[i] = False
+ filter[j] = False
+ used.add(i)
+ used.add(j)
+
+ if cnt == NUMBER:
+ break
+
+deleted_gaussian = gaussian_model.copy()
+deleted_gaussian.apply_filter(np.logical_not(filter))
+
+gaussian_model.apply_filter(filter)
+gaussian_model.concat(gaussian_extended)
+
+deleted_path = args.output_path.replace('60000', '60001')
+added_path = args.output_path.replace('60000', '60002')
+save_ply(args.output_path, gaussian_model)
+save_ply(deleted_path, deleted_gaussian)
+save_ply(added_path, gaussian_extended) \ No newline at end of file