"""Greedily merge the most similar pairs of Gaussians in a trained model.

Pipeline:
1. Load the model from ``--model_path``.
2. Load precomputed candidate pairs ``(i, j, geo_score, color_score)`` from
   ``--similarities_path`` (the commented-out block below shows how that file
   was generated).
3. Re-rank the pairs with :func:`metric_with_opacity`.
4. Walk the ranked pairs and merge them with ``merge_3``. Each Gaussian takes
   part in at most one merge. Stop after ``--num_merges`` successful merges or
   when the pairs run out.
5. Save three .ply files: the merged model, the removed originals, and the
   newly created merged Gaussians.
"""
from dataloader import load_ply, save_ply
from gaussian import Gaussian
from aabb import process_aabb
from similarity import similarity
from scipy.spatial import KDTree
from merge import merge, merge_2, merge_3, merge_4
from tqdm import tqdm
import argparse
import os
import numpy as np
import pickle as pkl

# Default maximum number of successful merges (overridable with --num_merges).
NUMBER = 1000

parser = argparse.ArgumentParser("main.py")
parser.add_argument('-m', '--model_path', type=str, default="/home/chuyan/Documents/code/gaussian/models/train/point_cloud/iteration_50000/point_cloud.ply")
parser.add_argument('-o', '--output_path', type=str, default="/home/chuyan/Documents/code/gaussian/models/train/point_cloud/iteration_60000/point_cloud.ply")
parser.add_argument('-d', '--sh_degrees', type=int, default=3)
parser.add_argument('--similarities_path', type=str, default='similarities.pkl',
                    help="pickle of candidate pairs (i, j, geo_score, color_score)")
parser.add_argument('--num_merges', type=int, default=NUMBER,
                    help="maximum number of successful pair merges")
args = parser.parse_args()
if args.num_merges < 1:
    parser.error("--num_merges must be a positive integer")

gaussian_model = load_ply(args.model_path, args.sh_degrees)
# Holds the Gaussians produced by merging; at most one per successful merge.
gaussian_extended = Gaussian.empty_with_cap(args.num_merges)

# --- How the similarities pickle was produced (kept for reference) ---------
# NOTE(review): the indices stored in the pickle must refer to the model
# exactly as loaded above. If it was generated after clip_to_box, the same
# clipping has to be applied here before the pickle is used.
# gaussian_model.clip_to_box(
#     min=np.array([-0.2, -0.7, -0.3]),
#     max=np.array([0, -0.3, 0])
# )
# kd_tree = KDTree(gaussian_model.positions)
# similarities = []
# for i in tqdm(range(gaussian_model.num_gaussians)):
#     scale = gaussian_model.scales[i]
#     radius = max(max(scale[0], scale[1]), scale[2]) * 1.5
#     points = kd_tree.query_ball_point(gaussian_model.positions[i], r=radius, workers=18)
#     for j in points:
#         if j < i:
#             geo_score, color_score = similarity(gaussian_model, j, i)
#             if geo_score * color_score != 0:
#                 similarities.append((j, i, geo_score, color_score, ))
# sim = sorted(similarities, key=lambda pair: pair[2] * pair[3], reverse=True)
# print(f"pair number is {len(sim)}")
# with open('similarities.pkl', 'wb') as f:
#     pkl.dump(sim, f)
# ----------------------------------------------------------------------------

# NOTE(review): pickle.load can execute arbitrary code; only load pickles
# this pipeline produced itself.
with open(args.similarities_path, 'rb') as f:
    sim = pkl.load(f)


def metric_with_opacity(pair):
    """Return the ranking score for a candidate pair (higher ranks first).

    ``pair`` is ``(i, j, geo_score, color_score)``. The score is the product of:

    - ``geo_score``;
    - ``color_score ** 0.3``, which damps the color term;
    - ``exp`` of the dot product of the two Gaussians' rotation vectors;
    - ``min(o1/o2, o2/o1)`` of their opacities, which for positive opacities is
      the smaller-over-larger ratio.

    Reads the module-level ``gaussian_model``.
    """
    i, j, geo_score, color_score = pair[0], pair[1], pair[2], pair[3]
    o1, o2 = gaussian_model.opacity[i], gaussian_model.opacity[j]
    # NOTE(review): a zero opacity makes this ratio divide by zero. If
    # ``opacity`` stores pre-activation (logit) values, the ratio is not a
    # similarity at all. Confirm the representation.
    opacity_ratio = min(o1 / o2, o2 / o1)
    # NOTE(review): if rotations are quaternions, q and -q describe the same
    # rotation, so the sign of this dot product can penalize equivalent
    # orientations. Consider abs() after confirming the representation.
    rotation_term = np.exp(np.dot(gaussian_model.rotations[i], gaussian_model.rotations[j]))
    return geo_score * (color_score ** 0.3) * rotation_term * opacity_ratio


sim = sorted(sim, key=metric_with_opacity, reverse=True)
print(f"There are {len(sim)} pairs")

# keep_mask[k] is False once Gaussian k has been consumed by a merge.
keep_mask = np.ones(gaussian_model.num_gaussians, dtype=bool)
used = set()
merged_count = 0
for i, j, _geo_score, _color_score in sim:
    if merged_count >= args.num_merges:
        break
    # Each Gaussian may take part in at most one merge.
    if i in used or j in used:
        continue
    position, scale, rotation, features_dc, features_rest, opacity = merge_3(gaussian_model, i, j)
    # merge_3 returns position=None when it declines to merge this pair.
    if position is None:
        continue
    gaussian_extended.add(position, scale, rotation, features_dc, features_rest, opacity)
    keep_mask[i] = False
    keep_mask[j] = False
    used.update((i, j))
    merged_count += 1
print(f"Merged {merged_count} pairs")

# Snapshot the removed originals before the model is filtered in place.
deleted_gaussian = gaussian_model.copy()
deleted_gaussian.apply_filter(np.logical_not(keep_mask))
gaussian_model.apply_filter(keep_mask)
gaussian_model.concat(gaussian_extended)


def _sibling_path(output_path, iteration_tag, label):
    """Return the path of an auxiliary .ply saved alongside ``output_path``.

    If ``output_path`` contains ``'60000'``, it is replaced with
    ``iteration_tag`` (the original naming convention). Otherwise ``_<label>``
    is inserted before the extension, so the auxiliary file can never
    overwrite ``output_path`` itself.
    """
    if '60000' in output_path:
        return output_path.replace('60000', iteration_tag)
    root, ext = os.path.splitext(output_path)
    return f"{root}_{label}{ext}"


deleted_path = _sibling_path(args.output_path, '60001', 'deleted')
added_path = _sibling_path(args.output_path, '60002', 'added')
save_ply(args.output_path, gaussian_model)
save_ply(deleted_path, deleted_gaussian)
save_ply(added_path, gaussian_extended)