From 394e19b012cb9264feaec582948fa7ac8bff901c Mon Sep 17 00:00:00 2001 From: Chuyan Zhang Date: Mon, 15 Jan 2024 15:49:41 -0800 Subject: init commit --- main.py | 99 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 99 insertions(+) create mode 100644 main.py (limited to 'main.py') diff --git a/main.py b/main.py new file mode 100644 index 0000000..521f341 --- /dev/null +++ b/main.py @@ -0,0 +1,99 @@ +from dataloader import load_ply, save_ply +from gaussian import Gaussian +from aabb import process_aabb +from similarity import similarity +from scipy.spatial import KDTree +from merge import merge, merge_2, merge_3, merge_4 +from tqdm import tqdm +import argparse +import numpy as np +import pickle as pkl + +NUMBER=1000 + +parser = argparse.ArgumentParser("main.py") +parser.add_argument('-m', '--model_path', type=str, default="/home/chuyan/Documents/code/gaussian/models/train/point_cloud/iteration_50000/point_cloud.ply") +parser.add_argument('-o', '--output_path', type=str, default="/home/chuyan/Documents/code/gaussian/models/train/point_cloud/iteration_60000/point_cloud.ply") +parser.add_argument('-d', '--sh_degrees', type=int, default=3) + +args = parser.parse_args() + +gaussian_model = load_ply(args.model_path, args.sh_degrees) +gaussian_extended = Gaussian.empty_with_cap(NUMBER) + +# gaussian_model.clip_to_box( +# min=np.array([-0.2, -0.7, -0.3]), +# max=np.array([0, -0.3, 0]) +# ) + + +# kd_tree = KDTree(gaussian_model.positions) +# similarities = [] +# for i in tqdm(range(gaussian_model.num_gaussians)): +# # print(f"currently running {i}") + +# scale = gaussian_model.scales[i] +# radius = max(max(scale[0], scale[1]), scale[2]) * 1.5 + +# points = kd_tree.query_ball_point(gaussian_model.positions[i], r=radius, workers=18) +# for j in points: +# if j < i: +# geo_score, color_score = similarity(gaussian_model, j, i) +# if geo_score * color_score != 0: +# similarities.append((j, i, geo_score, color_score, )) + +# sim = sorted(similarities, key=lambda pair: pair[2] * pair[3], reverse=True) +# print(f"pair number is {len(sim)}") + +# with open('similarities.pkl', 'wb') as f: +# pkl.dump(sim, f) + +with open('similarities.pkl', 'rb') as f: + sim = pkl.load(f) + +def metric_with_opacity(pair): + i, j, gs, cs = pair[0], pair[1], pair[2], pair[3] + o1, o2 = gaussian_model.opacity[i], gaussian_model.opacity[j] + return gs * (cs ** 0.3) \ + * np.exp(np.dot(gaussian_model.rotations[i], gaussian_model.rotations[j])) \ + * min(o1/o2, o2/o1) + +sim = sorted(sim, key=metric_with_opacity, reverse=True) +print(f"There are {len(sim)} pairs") + +cnt = 0 +filter = np.ones(gaussian_model.num_gaussians, dtype=bool) +removed = [] +used = set() +for idx, (i, j, gs, cs) in enumerate(sim): + if i in used or j in used: + continue + position, scale, rotation, features_dc, features_rest, opacity = merge_3(gaussian_model, i, j) + if position is not None: + + cnt += 1 + # if cnt != NUMBER: + # continue + id = gaussian_extended.add(position, scale, rotation, features_dc, features_rest, opacity) + # print(f"opacity: ({gaussian_model.opacity[i]},{gaussian_model.opacity[j]}) -> {opacity}") + filter[i] = False + filter[j] = False + used.add(i) + used.add(j) + + if cnt == NUMBER: + break + +deleted_gaussian = gaussian_model.copy() +deleted_gaussian.apply_filter(np.logical_not(filter)) +# deleted_gaussian.opacity = np.ones_like(deleted_gaussian.opacity) +# gaussian_extended.opacity = np.ones_like(gaussian_extended.opacity) + +gaussian_model.apply_filter(filter) +gaussian_model.concat(gaussian_extended) + +deleted_path = args.output_path.replace('60000', '60001') +added_path = args.output_path.replace('60000', '60002') +save_ply(args.output_path, gaussian_model) +save_ply(deleted_path, deleted_gaussian) +save_ply(added_path, gaussian_extended) -- cgit v1.2.3-70-g09d2