# path: root/dataloader.py
# blob: 1ec880a510338f9685deddae44c074c49cb41efd
from plyfile import PlyData, PlyElement
from gaussian import Gaussian
import numpy as np
import os

def _read_indexed_properties(vertex, prefix: str, num_points: int) -> np.ndarray:
    """Read every property named ``<prefix><i>`` from ``vertex`` into one array.

    Columns are ordered by the integer after the last underscore in each
    property name, so ``f_rest_10`` sorts after ``f_rest_9`` rather than
    lexicographically after ``f_rest_1``.

    Args:
        vertex: PLY element to read from; must expose ``.properties`` (objects
            with a ``.name``) and column access via ``vertex[name]``.
        prefix: Property-name prefix to select, e.g. ``"scale_"``.
        num_points: Number of rows (points) stored in ``vertex``.

    Returns:
        float32 array of shape ``(num_points, K)``, where K is the number of
        matching properties (K may be 0).
    """
    names = sorted(
        (prop.name for prop in vertex.properties if prop.name.startswith(prefix)),
        key=lambda name: int(name.split("_")[-1]),
    )
    values = np.zeros((num_points, len(names)), dtype=np.float32)
    for column, name in enumerate(names):
        values[:, column] = np.asarray(vertex[name])
    return values


def load_ply(path: str, max_sh_degree: int) -> Gaussian:
    """Load a Gaussian-splat point cloud from a PLY file.

    The raw stored parameters are passed through their activation functions.
    Opacity goes through a sigmoid, scales through exp, and rotations are
    normalized to unit length. Positions and spherical-harmonic (SH)
    coefficients are returned as stored.

    Args:
        path: PLY file to read. Only the first element (``elements[0]``) is used.
        max_sh_degree: SH degree the file is expected to contain. It fixes the
            required number of ``f_rest_*`` properties at
            ``3 * ((max_sh_degree + 1) ** 2 - 1)``.

    Returns:
        ``Gaussian(xyz, scales, rots, features_dc, features_extra, opacities)``.
        For P points the shapes are:
        xyz ``(P, 3)``; scales ``(P, S)``; rots ``(P, R)``;
        features_dc ``(P, 1, 3)``;
        features_extra ``(P, (max_sh_degree + 1) ** 2 - 1, 3)``;
        opacities ``(P, 1)``.
        S and R are the numbers of ``scale_*`` and ``rot_*`` properties.

    Raises:
        ValueError: If the number of ``f_rest_*`` properties does not match
            ``max_sh_degree``.
    """
    vertex = PlyData.read(path).elements[0]

    # Position.
    xyz = np.stack([np.asarray(vertex[axis]) for axis in ("x", "y", "z")], axis=1)
    num_points = xyz.shape[0]

    # Opacity: stored as a logit; apply the sigmoid activation.
    raw_opacity = np.asarray(vertex["opacity"])[..., np.newaxis]
    opacities = 1 / (1 + np.exp(-raw_opacity))

    # SH DC term: (P, 3) -> (P, 1, 3), i.e. one coefficient per RGB channel.
    features_dc = np.stack(
        [np.asarray(vertex[f"f_dc_{channel}"]) for channel in range(3)], axis=1
    ).astype(np.float32)[:, np.newaxis, :]

    # Higher-order SH terms. They are stored channel-major: all coefficients
    # for channel 0, then channel 1, then channel 2.
    num_rest_coeffs = (max_sh_degree + 1) ** 2 - 1
    features_extra = _read_indexed_properties(vertex, "f_rest_", num_points)
    if features_extra.shape[1] != 3 * num_rest_coeffs:
        # Raise instead of assert so the check survives `python -O`.
        raise ValueError(
            f"{path}: expected {3 * num_rest_coeffs} f_rest_* properties for "
            f"max_sh_degree={max_sh_degree}, found {features_extra.shape[1]}"
        )
    # (P, 3 * K) -> (P, 3, K) -> (P, K, 3)
    features_extra = features_extra.reshape((num_points, 3, num_rest_coeffs))
    features_extra = np.transpose(features_extra, (0, 2, 1))

    # Scale: stored in log space; apply the exp activation.
    scales = np.exp(_read_indexed_properties(vertex, "scale_", num_points))

    # Rotation: normalize each row to unit length. The norm is clamped away
    # from zero so an all-zero rotation yields zeros instead of NaNs.
    rots = _read_indexed_properties(vertex, "rot_", num_points)
    rot_norm = np.linalg.norm(rots, axis=1, keepdims=True)
    rots = rots / np.maximum(rot_norm, 1e-12)

    return Gaussian(xyz, scales, rots, features_dc, features_extra, opacities)

def save_ply(path: str, gaussian: Gaussian) -> None:
    """Write ``gaussian`` to ``path`` as a PLY file with one ``vertex`` element.

    This reverses the activations applied by ``load_ply``. Opacity is written
    as its logit (inverse sigmoid), scales are written as their log, and
    rotations are written unchanged. Normals ``nx``/``ny``/``nz`` are written
    as zeros. Only the first ``gaussian.num_gaussians`` rows are saved. Every
    attribute is stored as float32 (``f4``).

    The parent directory of ``path`` is created if it does not exist.

    Args:
        path: Destination file path. It may be a bare filename.
        gaussian: Splat parameters to save.
    """
    print(f"saving {path}")
    directory = os.path.dirname(path)
    if directory:
        # os.makedirs("") raises FileNotFoundError, so skip it for bare filenames.
        os.makedirs(directory, exist_ok=True)

    def inverse_sigmoid(x):
        return np.log(x / (1 - x))

    xyz = gaussian.positions
    normals = np.zeros_like(xyz)
    # The SH features are flattened channel-major: (P, C, 3) -> (P, 3, C)
    # -> (P, 3 * C). This assumes features_* are (P, coeffs, 3), as load_ply
    # produces them.
    f_dc = np.transpose(gaussian.features_dc, (0, 2, 1))
    f_dc = np.reshape(f_dc, (f_dc.shape[0], f_dc.shape[1] * f_dc.shape[2]))
    f_rest = np.transpose(gaussian.features_rest, (0, 2, 1))
    f_rest = np.reshape(f_rest, (f_rest.shape[0], f_rest.shape[1] * f_rest.shape[2]))
    opacities = inverse_sigmoid(gaussian.opacity)
    scale = np.log(gaussian.scales)
    rotation = gaussian.rotations

    # Property names, in the same order as the columns concatenated below.
    attribute_names = ["x", "y", "z", "nx", "ny", "nz"]
    attribute_names += [f"f_dc_{i}" for i in range(f_dc.shape[1])]
    attribute_names += [f"f_rest_{i}" for i in range(f_rest.shape[1])]
    attribute_names.append("opacity")
    attribute_names += [f"scale_{i}" for i in range(scale.shape[1])]
    attribute_names += [f"rot_{i}" for i in range(rotation.shape[1])]

    attributes = np.concatenate(
        (xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1
    )[: gaussian.num_gaussians]

    elements = np.empty(
        attributes.shape[0], dtype=[(name, "f4") for name in attribute_names]
    )
    # Fill the structured array column by column. This avoids building a
    # Python tuple for every row.
    for column, name in enumerate(attribute_names):
        elements[name] = attributes[:, column]

    el = PlyElement.describe(elements, "vertex")
    PlyData([el]).write(path)