summaryrefslogtreecommitdiff
path: root/dataloader.py
diff options
context:
space:
mode:
Diffstat (limited to 'dataloader.py')
-rw-r--r--dataloader.py116
1 files changed, 116 insertions, 0 deletions
diff --git a/dataloader.py b/dataloader.py
new file mode 100644
index 0000000..1ec880a
--- /dev/null
+++ b/dataloader.py
@@ -0,0 +1,116 @@
+from plyfile import PlyData, PlyElement
+from gaussian import Gaussian
+import numpy as np
+import os
+
+def load_ply(path: str, max_sh_degree: int) -> Gaussian:
+ ply_data = PlyData.read(path)
+
+ # position
+ xyz = np.stack(
+ (
+ np.asarray(ply_data.elements[0]["x"]),
+ np.asarray(ply_data.elements[0]["y"]),
+ np.asarray(ply_data.elements[0]["z"]),
+ ),
+ axis=1,
+ )
+
+ # opacity
+ def sigmoid(z):
+ return 1 / (1 + np.exp(-z))
+
+ opacities = np.asarray(ply_data.elements[0]["opacity"])[..., np.newaxis]
+ ## 过激活函数
+ opacities = sigmoid(opacities)
+
+ # sh
+ features_dc = np.zeros((xyz.shape[0], 3, 1), dtype=np.float32)
+ features_dc[:, 0, 0] = np.asarray(ply_data.elements[0]["f_dc_0"])
+ features_dc[:, 1, 0] = np.asarray(ply_data.elements[0]["f_dc_1"])
+ features_dc[:, 2, 0] = np.asarray(ply_data.elements[0]["f_dc_2"])
+
+ extra_f_names = [
+ p.name for p in ply_data.elements[0].properties if p.name.startswith("f_rest_")
+ ]
+ extra_f_names = sorted(extra_f_names, key=lambda x: int(x.split("_")[-1]))
+ assert len(extra_f_names) == 3 * (max_sh_degree + 1) ** 2 - 3
+ features_extra = np.zeros((xyz.shape[0], len(extra_f_names)), dtype=np.float32)
+ for idx, attr_name in enumerate(extra_f_names):
+ features_extra[:, idx] = np.asarray(ply_data.elements[0][attr_name])
+ # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC)
+ features_extra = features_extra.reshape(
+ (features_extra.shape[0], 3, (max_sh_degree + 1) ** 2 - 1)
+ )
+
+ features_dc = np.transpose(features_dc, (0, 2, 1))
+ features_extra = np.transpose(features_extra, (0, 2, 1))
+ ## 拼接得到完整的 sh
+ ## sh = np.transpose(np.concatenate((features_dc, features_extra), axis=2), (0, 2, 1))
+
+ # scale
+ scale_names = [
+ p.name for p in ply_data.elements[0].properties if p.name.startswith("scale_")
+ ]
+ scale_names = sorted(scale_names, key=lambda x: int(x.split("_")[-1]))
+ scales = np.zeros((xyz.shape[0], len(scale_names)), dtype=np.float32)
+ for idx, attr_name in enumerate(scale_names):
+ scales[:, idx] = np.asarray(ply_data.elements[0][attr_name])
+ ## 过激活函数
+ scales = np.exp(scales)
+
+ # rotation
+ rot_names = [
+ p.name for p in ply_data.elements[0].properties if p.name.startswith("rot")
+ ]
+ rot_names = sorted(rot_names, key=lambda x: int(x.split("_")[-1]))
+ rots = np.zeros((xyz.shape[0], len(rot_names)), dtype=np.float32)
+ for idx, attr_name in enumerate(rot_names):
+ rots[:, idx] = np.asarray(ply_data.elements[0][attr_name])
+ ## 过激活函数
+ rot_length = np.linalg.norm(rots, axis=1)
+ rots = rots / np.expand_dims(rot_length, axis=1)
+
+ return Gaussian(xyz, scales, rots, features_dc, features_extra, opacities)
+ ## return (num_gaussians, xyz, scales, rots, sh, opacities)
+
+def save_ply(path: str, gaussian: Gaussian):
+ print(f"saving {path}")
+ os.makedirs(os.path.dirname(path), exist_ok=True)
+
+ def construct_list_of_attributes(gaussian):
+ l = ['x', 'y', 'z', 'nx', 'ny', 'nz']
+ # All channels except the 3 DC
+ for i in range(gaussian.features_dc.shape[1] * gaussian.features_dc.shape[2]):
+ l.append('f_dc_{}'.format(i))
+ for i in range(gaussian.features_rest.shape[1] * gaussian.features_rest.shape[2]):
+ l.append('f_rest_{}'.format(i))
+ l.append('opacity')
+ for i in range(gaussian.scales.shape[1]):
+ l.append('scale_{}'.format(i))
+ for i in range(gaussian.rotations.shape[1]):
+ l.append('rot_{}'.format(i))
+ return l
+
+ def inverse_sigmoid(x):
+ return np.log(x/(1-x))
+
+ xyz = gaussian.positions
+ normals = np.zeros_like(xyz)
+ f_dc = np.transpose(gaussian.features_dc, (0, 2, 1))
+ f_dc = np.reshape(f_dc, (f_dc.shape[0], f_dc.shape[1] * f_dc.shape[2]))
+ f_rest = np.transpose(gaussian.features_rest, (0, 2, 1))
+ f_rest = np.reshape(f_rest, (f_rest.shape[0], f_rest.shape[1] * f_rest.shape[2]))
+ opacities = inverse_sigmoid(gaussian.opacity)
+ scale = np.log(gaussian.scales)
+ rotation = gaussian.rotations
+
+ dtype_full = [(attribute, 'f4') for attribute in construct_list_of_attributes(gaussian)]
+
+ attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1)[:gaussian.num_gaussians]
+
+ elements = np.empty(attributes.shape[0], dtype=dtype_full)
+ elements[:] = list(map(tuple, attributes))
+ el = PlyElement.describe(elements, 'vertex')
+ PlyData([el]).write(path)
+