summary | refs | log | tree | commit | diff
path: root/merge.py
diff options
context:
space:
mode:
Diffstat (limited to 'merge.py')
-rw-r--r-- merge.py 259
1 files changed, 259 insertions, 0 deletions
diff --git a/merge.py b/merge.py
new file mode 100644
index 0000000..f00ca1a
--- /dev/null
+++ b/merge.py
@@ -0,0 +1,259 @@
+from gaussian import Gaussian
+from aabb import process_aabb, aabb_merge, center, solve_scale, aabb_size
+from utils import quat2rot, rot2quat, sq2cov
+from scipy.optimize import minimize
+import numpy as np
+import random
+import time
+
def merge_4(gaussian: Gaussian, id1: int, id2: int):
    """Merge two Gaussians by weighted averaging of their moments.

    Each Gaussian is weighted by ``opacity * prod(scale)``. The merged
    position, covariance and opacity are the weighted means of the two
    inputs; the spherical-harmonics coefficients are a plain (unweighted)
    average. The merged covariance is decomposed back into a scale vector
    and a rotation quaternion.

    Args:
        gaussian: Container holding ``positions``, ``scales``, ``rotations``,
            ``opacity`` and ``sh`` arrays indexed by Gaussian id.
        id1: Index of the first Gaussian.
        id2: Index of the second Gaussian.

    Returns:
        Tuple ``(position, scale, rotation, features_dc, features_rest,
        opacity)`` describing the merged Gaussian.
    """
    scale_1 = gaussian.scales[id1]
    scale_2 = gaussian.scales[id2]
    weight_1 = gaussian.opacity[id1] * np.prod(scale_1)
    weight_2 = gaussian.opacity[id2] * np.prod(scale_2)
    weight = weight_1 + weight_2

    position = (weight_1 * gaussian.positions[id1] + weight_2 * gaussian.positions[id2]) / weight
    cov_1 = sq2cov(scale_1, gaussian.rotations[id1])
    cov_2 = sq2cov(scale_2, gaussian.rotations[id2])
    cov = (weight_1 * cov_1 + weight_2 * cov_2) / weight

    # The covariance is symmetric, so use eigh: it guarantees real
    # eigenvalues and an orthonormal eigenvector basis, whereas the general
    # eig solver may return complex values or a non-orthogonal basis for
    # (nearly) repeated eigenvalues.
    eigenvalues, eigenvectors = np.linalg.eigh(cov)

    # An orthonormal basis can still be a reflection (det = -1), which is not
    # a valid rotation. Flipping one axis keeps it an eigenbasis of the same
    # covariance while making it a proper rotation.
    if np.linalg.det(eigenvectors) < 0:
        eigenvectors[:, 0] = -eigenvectors[:, 0]

    # Round-off can push a tiny eigenvalue below zero; clamp before sqrt so
    # the scale never becomes NaN.
    scale = np.sqrt(np.maximum(eigenvalues, 0.0))
    rotation = rot2quat(eigenvectors.T)

    sh = (gaussian.sh[id1] + gaussian.sh[id2]) / 2.
    features_dc = sh[0, :]
    features_rest = sh[1:, :]
    opacity = (weight_1 * gaussian.opacity[id1] + weight_2 * gaussian.opacity[id2]) / weight

    return position, scale, rotation, features_dc, features_rest, opacity
+
def merge_3(gaussian: Gaussian, id1: int, id2: int):
    """Merge two Gaussians by minimising the squared L2 error analytically.

    With ``f`` and ``g`` the two input Gaussians (opacity-scaled, not
    normalised) and ``h`` the candidate merged Gaussian, this minimises
    ``integral (f + g - h)^2`` over the 11 parameters of ``h``
    (position, scale, rotation quaternion, opacity). All integrals have
    closed forms, so no sampling is needed. The optimisation starts from the
    bounding-box merge produced by ``merge_geo``.

    Args:
        gaussian: Container holding ``positions``, ``scales``, ``rotations``,
            ``opacity`` and ``sh`` arrays indexed by Gaussian id.
        id1: Index of the first Gaussian.
        id2: Index of the second Gaussian.

    Returns:
        Tuple ``(position, scale, rotation, features_dc, features_rest,
        opacity)`` describing the merged Gaussian. The spherical-harmonics
        coefficients are the plain average of the two inputs.
    """
    opacity_f = gaussian.opacity[id1]
    opacity_g = gaussian.opacity[id2]
    mu_f = gaussian.positions[id1]
    mu_g = gaussian.positions[id2]
    cov_f = sq2cov(gaussian.scales[id1], gaussian.rotations[id1])
    cov_g = sq2cov(gaussian.scales[id2], gaussian.rotations[id2])
    sqrt_det_cov_f = np.sqrt(np.linalg.det(cov_f))
    sqrt_det_cov_g = np.sqrt(np.linalg.det(cov_g))

    def c_without_2pi(mu1, mu2, cov1, cov2):
        """Overlap of two unit-mass Gaussians, without the (2*pi)^-1.5 factor.

        Evaluates ``exp(-0.5 * d^T (cov1 + cov2)^-1 d) / sqrt(det(cov1 + cov2))``
        with ``d = mu1 - mu2``.
        """
        delta = mu1 - mu2
        cov = cov1 + cov2
        # The quadratic form needs the INVERSE of the summed covariance;
        # solve() applies it without forming the inverse explicitly.
        mahalanobis = np.dot(delta, np.linalg.solve(cov, delta))
        return np.exp(-0.5 * mahalanobis) / np.sqrt(np.linalg.det(cov))

    # Terms that do not depend on h; constant offsets of the objective.
    f2integral = opacity_f * opacity_f * sqrt_det_cov_f
    g2integral = opacity_g * opacity_g * sqrt_det_cov_g
    fgintegral = (2 ** 1.5) * opacity_f * opacity_g * sqrt_det_cov_f * sqrt_det_cov_g * c_without_2pi(mu_f, mu_g, cov_f, cov_g)

    # Initial guess: the oriented-bounding-box merge. Only the first four
    # return values are needed here.
    mu_0, s_0, r_0, o_0, _, _ = merge_geo(gaussian, id1, id2)
    # atleast_1d keeps concatenate working even if the opacity is 0-d.
    x_0 = np.concatenate((mu_0, s_0, r_0, np.atleast_1d(o_0)), axis=0)

    def target_inner_analytical(features) -> float:
        """Closed-form value of ``integral (f + g - h)^2`` for parameters of h."""
        mu = features[:3]
        s = features[3:6]
        r = features[6:10]
        o = features[10]
        cov = sq2cov(s, r)
        sqrt_det_cov = np.sqrt(np.linalg.det(cov))

        h2integral = o * o * sqrt_det_cov
        fhintegral = (2 ** 1.5) * opacity_f * o * sqrt_det_cov_f * sqrt_det_cov * c_without_2pi(mu_f, mu, cov_f, cov)
        ghintegral = (2 ** 1.5) * opacity_g * o * sqrt_det_cov_g * sqrt_det_cov * c_without_2pi(mu_g, mu, cov_g, cov)

        return (f2integral + g2integral + h2integral + 2 * fgintegral - 2 * fhintegral - 2 * ghintegral) * (np.pi ** 1.5)

    res = minimize(target_inner_analytical, x_0, options={'gtol': 1e-4, 'disp': False})
    mu_h = res.x[:3]
    s_h = res.x[3:6]
    # NOTE(review): the optimiser is unconstrained, so this quaternion is not
    # re-normalised here and the scales are not forced positive -- confirm
    # that downstream consumers tolerate that.
    r_h = res.x[6:10]
    o_h = res.x[10]

    sh = (gaussian.sh[id1] + gaussian.sh[id2]) / 2.
    features_dc = sh[0, :]
    features_rest = sh[1:, :]

    return mu_h, s_h, r_h, features_dc, features_rest, o_h
+
+# def merge_neo(gaussian: Gaussian, id1: int, id2: int):
+# N_f = np.prod(gaussian.scales[id1]) * gaussian.opacity[id1]
+# N_g = np.prod(gaussian.scales[id2]) * gaussian.opacity[id2]
+# # N_f = gaussian.opacity[id1]
+# # N_g = gaussian.opacity[id2]
+
+# mu_f = gaussian.positions[id1]
+# mu_g = gaussian.positions[id2]
+# Lambda_f = np.diag(gaussian.scales[id1] ** 2)
+# Lambda_g = np.diag(gaussian.scales[id2] ** 2)
+# U_f = quat2rot(gaussian.rotations[id1])
+# U_g = quat2rot(gaussian.rotations[id2])
+
+# N_h = N_f + N_g
+# position = (N_f * mu_f + N_g * mu_g) / N_h
+
+# g = U_f.T @ (mu_f - mu_g)
+# G = U_f.T @ U_g
+
+# mat = N_f / N_h * Lambda_f + \
+# N_g / N_h * G @ Lambda_g @ G.T + \
+# N_f * N_g / (N_h * N_h) * g @ g.T
+# eigenvalues, eigenvectors = np.linalg.eig(mat)
+# U_h = U_f @ eigenvectors
+
+# rotation = rot2quat(U_h)
+# scale = np.sqrt(eigenvalues)
+# opacity = N_h / np.prod(scale)
+
+# sh = (gaussian.sh[id1] + gaussian.sh[id2]) / 2.
+# features_dc = sh[0, :]
+# features_rest = sh[1:, :]
+
+# return position, scale, rotation, features_dc, features_rest, opacity
+
def merge_geo(gaussian: Gaussian, id1: int, id2: int):
    """Merge two Gaussians geometrically via a shared oriented bounding box.

    The two rotation quaternions are averaged to define a common frame. Both
    Gaussians are expressed in that frame, their axis-aligned bounding boxes
    are computed and merged, and the merged box is mapped back to world
    space.

    Args:
        gaussian: Container holding ``positions``, ``scales``, ``rotations``
            and ``opacity`` arrays indexed by Gaussian id.
        id1: Index of the first Gaussian.
        id2: Index of the second Gaussian.

    Returns:
        Tuple ``(position, scale, rotation, opacity, point_min, vectors)``:
        the box centre in world space, the half extents of the box, the
        averaged quaternion, the mean opacity (as a NumPy array), the
        rotated minimum corner of the box, and a 3x3 matrix whose columns are
        the three box edge vectors in world space.
    """
    quat_1 = gaussian.rotations[id1]
    quat_2 = gaussian.rotations[id2]
    # q and -q encode the same rotation. Without aligning the signs, two
    # nearly identical rotations stored with opposite signs would sum to
    # (almost) zero and normalising would give NaN or an arbitrary frame.
    if np.dot(quat_1, quat_2) < 0:
        quat_2 = -quat_2
    rotation = quat_1 + quat_2
    rotation = rotation / np.linalg.norm(rotation)

    rotation_mat = quat2rot(rotation)
    inv_rotation_mat = np.linalg.inv(rotation_mat)

    # Express both Gaussians in the averaged frame.
    rotation1_mat = inv_rotation_mat @ quat2rot(gaussian.rotations[id1])
    rotation2_mat = inv_rotation_mat @ quat2rot(gaussian.rotations[id2])
    position1 = inv_rotation_mat @ gaussian.positions[id1]
    position2 = inv_rotation_mat @ gaussian.positions[id2]
    aabb1 = process_aabb(position1, gaussian.scales[id1], rotation1_mat)
    aabb2 = process_aabb(position2, gaussian.scales[id2], rotation2_mat)
    aabb = aabb_merge(aabb1, aabb2)

    # The box is indexed as (x_min, x_max, y_min, y_max, z_min, z_max).
    point_min = np.array([aabb[0], aabb[2], aabb[4]])
    point_max = np.array([aabb[1], aabb[3], aabb[5]])
    # Columns of `vectors` are the box edges rotated back to world space.
    vectors = rotation_mat @ np.diag(point_max - point_min)

    position = rotation_mat @ center(aabb)
    scale = np.array([aabb[1] - aabb[0], aabb[3] - aabb[2], aabb[5] - aabb[4]]) / 2.
    opacity = np.array((gaussian.opacity[id1] + gaussian.opacity[id2]) / 2.)
    return position, scale, rotation, opacity, rotation_mat @ point_min, vectors
+
+
def merge_2(gaussian: Gaussian, id1: int, id2: int):
    """Merge two Gaussians using a bounding box in their averaged frame.

    The two rotation quaternions are averaged to define a common frame. Both
    Gaussians are expressed in that frame, their axis-aligned bounding boxes
    are merged, and the merged box gives the new centre and half extents.
    Spherical-harmonics coefficients and opacity are plain averages.

    Args:
        gaussian: Container holding ``positions``, ``scales``, ``rotations``,
            ``opacity`` and ``sh`` arrays indexed by Gaussian id.
        id1: Index of the first Gaussian.
        id2: Index of the second Gaussian.

    Returns:
        Tuple ``(position, scale, rotation, features_dc, features_rest,
        opacity)`` describing the merged Gaussian.
    """
    quat_1 = gaussian.rotations[id1]
    quat_2 = gaussian.rotations[id2]
    # q and -q encode the same rotation. Without aligning the signs, two
    # nearly identical rotations stored with opposite signs would sum to
    # (almost) zero and normalising would give NaN or an arbitrary frame.
    if np.dot(quat_1, quat_2) < 0:
        quat_2 = -quat_2
    rotation = quat_1 + quat_2
    rotation = rotation / np.linalg.norm(rotation)

    rotation_mat = quat2rot(rotation)
    inv_rotation_mat = np.linalg.inv(rotation_mat)

    # Express both Gaussians in the averaged frame.
    rotation1_mat = inv_rotation_mat @ quat2rot(gaussian.rotations[id1])
    rotation2_mat = inv_rotation_mat @ quat2rot(gaussian.rotations[id2])
    position1 = inv_rotation_mat @ gaussian.positions[id1]
    position2 = inv_rotation_mat @ gaussian.positions[id2]
    aabb1 = process_aabb(position1, gaussian.scales[id1], rotation1_mat)
    aabb2 = process_aabb(position2, gaussian.scales[id2], rotation2_mat)
    aabb = aabb_merge(aabb1, aabb2)

    # Box centre goes back to world space; half extents become the scale.
    position = rotation_mat @ center(aabb)
    scale = np.array([aabb[1] - aabb[0], aabb[3] - aabb[2], aabb[5] - aabb[4]]) / 2.

    sh = (gaussian.sh[id1] + gaussian.sh[id2]) / 2.
    features_dc = sh[0, :]
    features_rest = sh[1:, :]
    opacity = (gaussian.opacity[id1] + gaussian.opacity[id2]) / 2.
    return position, scale, rotation, features_dc, features_rest, opacity
+
def merge(gaussian: Gaussian, id1: int, id2: int):
    """Merge two Gaussians by fitting one to their combined bounding box.

    Bounding boxes of both Gaussians are computed and merged. The merged
    Gaussian is centred on that box, takes the average of the two rotation
    quaternions, and gets its scale from ``solve_scale`` for that rotation
    and box. Spherical-harmonics coefficients and opacity are plain averages.

    Args:
        gaussian: Container holding ``positions``, ``scales``, ``rotations``,
            ``opacity`` and ``sh`` arrays indexed by Gaussian id.
        id1: Index of the first Gaussian.
        id2: Index of the second Gaussian.

    Returns:
        Tuple ``(position, scale, rotation, features_dc, features_rest,
        opacity)`` describing the merged Gaussian, or a tuple of six ``None``
        values when the solved scale contains NaN.
    """
    # NOTE(review): process_aabb receives quaternions here, while merge_2 and
    # merge_geo pass 3x3 rotation matrices -- confirm it accepts both.
    aabb1 = process_aabb(
        gaussian.positions[id1],
        gaussian.scales[id1],
        gaussian.rotations[id1])
    aabb2 = process_aabb(
        gaussian.positions[id2],
        gaussian.scales[id2],
        gaussian.rotations[id2])
    aabb = aabb_merge(aabb1, aabb2)
    print(f"merging:\n {aabb1}\n {aabb2}\n\n {aabb}")

    position = center(aabb)
    quat_1 = gaussian.rotations[id1]
    quat_2 = gaussian.rotations[id2]
    # q and -q encode the same rotation. Without aligning the signs, two
    # nearly identical rotations stored with opposite signs would sum to
    # (almost) zero and normalising would give NaN -- one way the NaN scale
    # handled below could arise.
    if np.dot(quat_1, quat_2) < 0:
        quat_2 = -quat_2
    rotation = quat_1 + quat_2
    rotation = rotation / np.linalg.norm(rotation)
    scale = solve_scale(rotation, aabb)
    if np.isnan(scale).any():
        # Why does NaN show up here? Still unexplained for the remaining
        # cases; signal failure so the caller can skip this pair.
        return None, None, None, None, None, None
    sh = (gaussian.sh[id1] + gaussian.sh[id2]) / 2.
    features_dc = sh[0, :]
    features_rest = sh[1:, :]
    opacity = (gaussian.opacity[id1] + gaussian.opacity[id2]) / 2.
    # Alternative considered: opacity = min(gaussian.opacity[id1], gaussian.opacity[id2])

    # Disabled sanity check: recomputing the box from the merged parameters
    # with process_aabb(position, scale, rotation, scale_factor=1) was
    # expected to reproduce `aabb` (np.allclose).

    return position, scale, rotation, features_dc, features_rest, opacity