"""Greedily merge pairs of nearby Gaussians in a point-cloud model.

Pipeline:
1. Load a model from ``--model_path``.
2. Find candidate pairs with a KD-tree radius search.
3. Fit one Gaussian to each disjoint pair by least squares (``merge_3``).
4. Write three files: the merged model, the removed Gaussians, and the added
   Gaussians.
"""
import argparse
import os
import pickle as pkl

import numpy as np
from scipy.optimize import minimize
from scipy.spatial import KDTree
from tqdm import tqdm

from dataloader import load_ply, save_ply
from gaussian import Gaussian
from merge import merge_geo
from utils import quat2rot

# Default cap on the number of merged Gaussians produced. It is also the
# capacity of the buffer that collects them.
NUMBER = 8000

# Iteration tag in the default output path; the auxiliary outputs swap it out.
_ITER_TAG = '60000'


def _make_gaussian(mu, s, r, o):
    """Return a callable evaluating ``o * exp(-0.5 * d^T A d)`` with ``d = pos - mu``.

    ``A = inv(M^T M)`` with ``M = diag(s) @ quat2rot(r)``. ``A`` is computed
    once here rather than once per evaluated point.

    NOTE(review): the implied covariance is ``M^T M = R^T diag(s)^2 R``. That
    matches the usual 3DGS form ``R S S^T R^T`` only if ``quat2rot`` returns
    the transposed rotation. Confirm against ``utils.quat2rot``.
    """
    M = np.diag(s) @ quat2rot(r)
    precision = np.linalg.inv(M.T @ M)

    def evaluate(pos):
        d = pos - mu
        return o * np.exp(-0.5 * np.dot(d, precision @ d))

    return evaluate


def merge_3(gaussian: Gaussian, id1: int, id2: int, grid_samples: int = 4) -> tuple:
    """Fit a single Gaussian h approximating f + g, where f = id1 and g = id2.

    ``merge_geo`` supplies the initial guess ``(mu_0, s_0, r_0, o_0)``, a grid
    origin ``point_min``, and a matrix ``vectors``. The columns of ``vectors``
    are the grid edge vectors. The objective is the sum of
    ``(f(p) + g(p) - h(p))**2`` over a ``grid_samples**3`` lattice, and it is
    minimised with ``scipy.optimize.minimize`` (default method).

    Assumes scales and opacities are already in the space ``_make_gaussian``
    expects (presumably activated by ``load_ply``). TODO confirm.

    Args:
        gaussian: model holding both Gaussians.
        id1, id2: indices of the pair to merge.
        grid_samples: lattice points per axis. The default of 4 reproduces
            the original 4x4x4 grid.

    Returns:
        ``(mu_h, s_h, r_h, features_dc, features_rest, o_h)``. The SH
        coefficients are the plain average of the two inputs. Row 0 is
        returned as ``features_dc`` and the remaining rows as
        ``features_rest``.
    """
    # Use the model that was passed in, not a module-level global.
    mu_0, s_0, r_0, o_0, point_min, vectors = merge_geo(gaussian, id1, id2)
    x_0 = np.concatenate((mu_0, s_0, r_0, o_0), axis=0)

    gaussian_f = _make_gaussian(gaussian.positions[id1], gaussian.scales[id1],
                                gaussian.rotations[id1], gaussian.opacity[id1])
    gaussian_g = _make_gaussian(gaussian.positions[id2], gaussian.scales[id2],
                                gaussian.rotations[id2], gaussian.opacity[id2])

    edge_x, edge_y, edge_z = vectors[:, 0], vectors[:, 1], vectors[:, 2]
    n = grid_samples
    # The sample grid and f + g do not depend on the optimised parameters,
    # so compute them once instead of on every objective evaluation.
    samples = [point_min + (i * edge_x + j * edge_y + k * edge_z) / n
               for i in range(n) for j in range(n) for k in range(n)]
    targets = [gaussian_f(p) + gaussian_g(p) for p in samples]

    def target_inner(features):
        # Parameter layout: [mu(3), scale(3), quaternion(4), opacity(1)].
        gaussian_h = _make_gaussian(features[:3], features[3:6],
                                    features[6:10], features[10:])
        total = 0.
        for pos, target in zip(samples, targets):
            total += (target - gaussian_h(pos)) ** 2
        return total

    res = minimize(target_inner, x_0)
    mu_h = res.x[:3]
    s_h = res.x[3:6]
    r_h = res.x[6:10]
    o_h = res.x[10]
    sh = (gaussian.sh[id1] + gaussian.sh[id2]) / 2.
    return mu_h, s_h, r_h, sh[0, :], sh[1:, :], o_h


def find_candidate_pairs(gaussian: Gaussian, radius_factor: float = 1.5,
                         workers: int = 18) -> list:
    """Return ``(i, j)`` pairs with ``j < i``.

    A pair is included when ``j`` lies within ``radius_factor`` times the
    largest scale component of Gaussian ``i``, measured from ``i``'s position.
    """
    kd_tree = KDTree(gaussian.positions)
    pairs = []
    for i in tqdm(range(gaussian.num_gaussians)):
        radius = np.max(gaussian.scales[i]) * radius_factor
        neighbours = kd_tree.query_ball_point(gaussian.positions[i], r=radius,
                                              workers=workers)
        pairs.extend((i, j) for j in neighbours if j < i)
    return pairs


def _derived_path(path: str, tag: str, suffix: str) -> str:
    """Derive an auxiliary output path from ``path``.

    If ``path`` contains ``_ITER_TAG``, replace it with ``tag`` (the original
    behaviour). Otherwise append ``_<suffix>`` before the extension. The
    original fallback returned ``path`` unchanged and overwrote the main output.
    """
    if _ITER_TAG in path:
        return path.replace(_ITER_TAG, tag)
    root, ext = os.path.splitext(path)
    return f"{root}_{suffix}{ext}"


def main() -> None:
    """Parse CLI arguments, greedily merge disjoint pairs, and save the results."""
    parser = argparse.ArgumentParser("main.py")
    parser.add_argument('-m', '--model_path', type=str, default="/home/chuyan/Documents/code/models/train/point_cloud/iteration_50000/point_cloud.ply")
    parser.add_argument('-o', '--output_path', type=str, default="/home/chuyan/Documents/code/models/train/point_cloud/iteration_60000/point_cloud.ply")
    parser.add_argument('-d', '--sh_degrees', type=int, default=3)
    parser.add_argument('-n', '--number', type=int, default=NUMBER,
                        help="maximum number of merged Gaussians to create")
    args = parser.parse_args()

    gaussian_model = load_ply(args.model_path, args.sh_degrees)
    gaussian_extended = Gaussian.empty_with_cap(args.number)
    pairs = find_candidate_pairs(gaussian_model)

    merged = 0
    keep = np.ones(gaussian_model.num_gaussians, dtype=bool)
    used = set()
    for i, j in tqdm(pairs):
        # Check the cap before merging so the capacity of gaussian_extended
        # is never exceeded, including when --number is 0.
        if merged >= args.number:
            break
        # Each Gaussian takes part in at most one merge.
        if i in used or j in used:
            continue
        mu, s, r, dc, rest, o = merge_3(gaussian_model, i, j)
        gaussian_extended.add(mu, s, r, dc, rest, o)
        merged += 1
        keep[i] = False
        keep[j] = False
        used.update((i, j))

    deleted_gaussian = gaussian_model.copy()
    deleted_gaussian.apply_filter(np.logical_not(keep))
    gaussian_model.apply_filter(keep)
    gaussian_model.concat(gaussian_extended)

    save_ply(args.output_path, gaussian_model)
    save_ply(_derived_path(args.output_path, '60001', 'deleted'), deleted_gaussian)
    save_ply(_derived_path(args.output_path, '60002', 'added'), gaussian_extended)


if __name__ == "__main__":
    main()