From 394e19b012cb9264feaec582948fa7ac8bff901c Mon Sep 17 00:00:00 2001 From: Chuyan Zhang Date: Mon, 15 Jan 2024 15:49:41 -0800 Subject: init commit --- new_main.py | 129 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 129 insertions(+) create mode 100644 new_main.py (limited to 'new_main.py') diff --git a/new_main.py b/new_main.py new file mode 100644 index 0000000..99a6767 --- /dev/null +++ b/new_main.py @@ -0,0 +1,129 @@ +from dataloader import load_ply, save_ply +from gaussian import Gaussian +from scipy.spatial import KDTree +from scipy.optimize import minimize +from merge import merge_geo +from utils import quat2rot +from tqdm import tqdm +import argparse +import numpy as np +import pickle as pkl + +NUMBER=8000 + +parser = argparse.ArgumentParser("main.py") +parser.add_argument('-m', '--model_path', type=str, default="/home/chuyan/Documents/code/models/train/point_cloud/iteration_50000/point_cloud.ply") +parser.add_argument('-o', '--output_path', type=str, default="/home/chuyan/Documents/code/models/train/point_cloud/iteration_60000/point_cloud.ply") +parser.add_argument('-d', '--sh_degrees', type=int, default=3) + +args = parser.parse_args() + +gaussian_model = load_ply(args.model_path, args.sh_degrees) +gaussian_extended = Gaussian.empty_with_cap(NUMBER) + +kd_tree = KDTree(gaussian_model.positions) +pairs = [] +for i in tqdm(range(gaussian_model.num_gaussians)): + scale = gaussian_model.scales[i] + radius = max(max(scale[0], scale[1]), scale[2]) * 1.5 + points = kd_tree.query_ball_point(gaussian_model.positions[i], r=radius, workers=18) + for j in points: + if j < i: + pairs.append((i, j)) + +# with open('pairs.pkl', 'wb') as f: +# pkl.dump(pairs, f) + +def merge_3(gaussian: Gaussian, id1: int, id2: int) -> float: + mu_f = gaussian.positions[id1] + mu_g = gaussian.positions[id2] + s_f = gaussian.scales[id1] + s_g = gaussian.scales[id2] + r_f = gaussian.rotations[id1] + r_g = gaussian.rotations[id2] + o_f = gaussian.opacity[id1] + o_g = gaussian.opacity[id2] + + mu_0, s_0, r_0, o_0, point_min, vectors = merge_geo(gaussian_model, id1, id2) + vector_x = vectors[:, 0] + vector_y = vectors[:, 1] + vector_z = vectors[:, 2] + x_0 = np.concatenate((mu_0, s_0, r_0, o_0), axis=0) + + def make_gaussian(mu, s, r, o): + def calc_inner(pos): + pos = pos - mu + S = np.diag(s) + R = quat2rot(r) + M = S @ R + Sigma = np.linalg.inv(M.T @ M) + return o * np.exp(-0.5 * np.dot(pos, Sigma @ pos)) + return calc_inner + gaussian_f = make_gaussian(mu_f, s_f, r_f, o_f) + gaussian_g = make_gaussian(mu_g, s_g, r_g, o_g) + + def target_inner(features) -> float: + mu = features[:3] + s = features[3:6] + r = features[6:10] + o = features[10:] + gaussian_h = make_gaussian(mu, s, r, o) + + N = 4 + + sum = 0. + for i in range(N): + for j in range(N): + for k in range(N): + pos = point_min + (i * vector_x + j * vector_y + k * vector_z) / N + sum += (gaussian_f(pos) + gaussian_g(pos) - gaussian_h(pos)) ** 2 + return sum + + res = minimize(target_inner, x_0) + mu_h = res.x[:3] + s_h = res.x[3:6] + r_h = res.x[6:10] + o_h = res.x[10] + + sh = (gaussian.sh[id1] + gaussian.sh[id2]) / 2. + features_dc = sh[0, :] + features_rest = sh[1:, :] + # print(mu_h, s_h, r_h, features_dc, features_rest, o_h) + + return mu_h, s_h, r_h, features_dc, features_rest, o_h + +cnt = 0 +filter = np.ones(gaussian_model.num_gaussians, dtype=bool) +used = set() +for i, j in tqdm(pairs): + if i in used or j in used: + continue + mu, s, r, dc, rest, o = merge_3(gaussian_model, i, j) + if mu is not None: + # scale_product = np.prod(scale) + # scale_product_1 = np.prod(gaussian_model.scales[i]) + # scale_product_2 = np.prod(gaussian_model.scales[j]) + # if scale_product / max(scale_product_1, scale_product_2) > 10.: + # continue + + cnt += 1 + id = gaussian_extended.add(mu, s, r, dc, rest, o) + filter[i] = False + filter[j] = False + used.add(i) + used.add(j) + + if cnt == NUMBER: + break + +deleted_gaussian = gaussian_model.copy() +deleted_gaussian.apply_filter(np.logical_not(filter)) + +gaussian_model.apply_filter(filter) +gaussian_model.concat(gaussian_extended) + +deleted_path = args.output_path.replace('60000', '60001') +added_path = args.output_path.replace('60000', '60002') +save_ply(args.output_path, gaussian_model) +save_ply(deleted_path, deleted_gaussian) +save_ply(added_path, gaussian_extended) \ No newline at end of file -- cgit v1.2.3-70-g09d2