From 394e19b012cb9264feaec582948fa7ac8bff901c Mon Sep 17 00:00:00 2001 From: Chuyan Zhang Date: Mon, 15 Jan 2024 15:49:41 -0800 Subject: init commit --- dataloader.py | 116 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 116 insertions(+) create mode 100644 dataloader.py (limited to 'dataloader.py') diff --git a/dataloader.py b/dataloader.py new file mode 100644 index 0000000..1ec880a --- /dev/null +++ b/dataloader.py @@ -0,0 +1,116 @@ +from plyfile import PlyData, PlyElement +from gaussian import Gaussian +import numpy as np +import os + +def load_ply(path: str, max_sh_degree: int) -> Gaussian: + ply_data = PlyData.read(path) + + # position + xyz = np.stack( + ( + np.asarray(ply_data.elements[0]["x"]), + np.asarray(ply_data.elements[0]["y"]), + np.asarray(ply_data.elements[0]["z"]), + ), + axis=1, + ) + + # opacity + def sigmoid(z): + return 1 / (1 + np.exp(-z)) + + opacities = np.asarray(ply_data.elements[0]["opacity"])[..., np.newaxis] + ## 过激活函数 + opacities = sigmoid(opacities) + + # sh + features_dc = np.zeros((xyz.shape[0], 3, 1), dtype=np.float32) + features_dc[:, 0, 0] = np.asarray(ply_data.elements[0]["f_dc_0"]) + features_dc[:, 1, 0] = np.asarray(ply_data.elements[0]["f_dc_1"]) + features_dc[:, 2, 0] = np.asarray(ply_data.elements[0]["f_dc_2"]) + + extra_f_names = [ + p.name for p in ply_data.elements[0].properties if p.name.startswith("f_rest_") + ] + extra_f_names = sorted(extra_f_names, key=lambda x: int(x.split("_")[-1])) + assert len(extra_f_names) == 3 * (max_sh_degree + 1) ** 2 - 3 + features_extra = np.zeros((xyz.shape[0], len(extra_f_names)), dtype=np.float32) + for idx, attr_name in enumerate(extra_f_names): + features_extra[:, idx] = np.asarray(ply_data.elements[0][attr_name]) + # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC) + features_extra = features_extra.reshape( + (features_extra.shape[0], 3, (max_sh_degree + 1) ** 2 - 1) + ) + + features_dc = np.transpose(features_dc, (0, 2, 1)) + features_extra = np.transpose(features_extra, (0, 2, 1)) + ## 拼接得到完整的 sh + ## sh = np.transpose(np.concatenate((features_dc, features_extra), axis=2), (0, 2, 1)) + + # scale + scale_names = [ + p.name for p in ply_data.elements[0].properties if p.name.startswith("scale_") + ] + scale_names = sorted(scale_names, key=lambda x: int(x.split("_")[-1])) + scales = np.zeros((xyz.shape[0], len(scale_names)), dtype=np.float32) + for idx, attr_name in enumerate(scale_names): + scales[:, idx] = np.asarray(ply_data.elements[0][attr_name]) + ## 过激活函数 + scales = np.exp(scales) + + # rotation + rot_names = [ + p.name for p in ply_data.elements[0].properties if p.name.startswith("rot") + ] + rot_names = sorted(rot_names, key=lambda x: int(x.split("_")[-1])) + rots = np.zeros((xyz.shape[0], len(rot_names)), dtype=np.float32) + for idx, attr_name in enumerate(rot_names): + rots[:, idx] = np.asarray(ply_data.elements[0][attr_name]) + ## 过激活函数 + rot_length = np.linalg.norm(rots, axis=1) + rots = rots / np.expand_dims(rot_length, axis=1) + + return Gaussian(xyz, scales, rots, features_dc, features_extra, opacities) + ## return (num_gaussians, xyz, scales, rots, sh, opacities) + +def save_ply(path: str, gaussian: Gaussian): + print(f"saving {path}") + os.makedirs(os.path.dirname(path), exist_ok=True) + + def construct_list_of_attributes(gaussian): + l = ['x', 'y', 'z', 'nx', 'ny', 'nz'] + # All channels except the 3 DC + for i in range(gaussian.features_dc.shape[1] * gaussian.features_dc.shape[2]): + l.append('f_dc_{}'.format(i)) + for i in range(gaussian.features_rest.shape[1] * gaussian.features_rest.shape[2]): + l.append('f_rest_{}'.format(i)) + l.append('opacity') + for i in range(gaussian.scales.shape[1]): + l.append('scale_{}'.format(i)) + for i in range(gaussian.rotations.shape[1]): + l.append('rot_{}'.format(i)) + return l + + def inverse_sigmoid(x): + return np.log(x/(1-x)) + + xyz = gaussian.positions + normals = np.zeros_like(xyz) + f_dc = np.transpose(gaussian.features_dc, (0, 2, 1)) + f_dc = np.reshape(f_dc, (f_dc.shape[0], f_dc.shape[1] * f_dc.shape[2])) + f_rest = np.transpose(gaussian.features_rest, (0, 2, 1)) + f_rest = np.reshape(f_rest, (f_rest.shape[0], f_rest.shape[1] * f_rest.shape[2])) + opacities = inverse_sigmoid(gaussian.opacity) + scale = np.log(gaussian.scales) + rotation = gaussian.rotations + + dtype_full = [(attribute, 'f4') for attribute in construct_list_of_attributes(gaussian)] + + attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1)[:gaussian.num_gaussians] + + elements = np.empty(attributes.shape[0], dtype=dtype_full) + elements[:] = list(map(tuple, attributes)) + el = PlyElement.describe(elements, 'vertex') + PlyData([el]).write(path) + -- cgit v1.2.3-70-g09d2