| """Contains basic data structures and functionality for 3D Gaussians. |
| |
| For licensing see accompanying LICENSE file. |
| Copyright (C) 2025 Apple Inc. All Rights Reserved. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import logging |
| from pathlib import Path |
| from typing import Any, Literal, NamedTuple |
|
|
| import numpy as np |
| import torch |
| from plyfile import PlyData, PlyElement |
|
|
| from sharp.utils import color_space as cs_utils |
| from sharp.utils import linalg |
|
|
# Module-level logger; used below to warn when SVD returns reflection matrices.
LOGGER = logging.getLogger(__name__)
|
|
|
|
# Allowed background color modes. Not referenced in the code visible here; presumably consumed
# by rendering code elsewhere — TODO(review): confirm the exact semantics of the random_* modes.
BackgroundColor = Literal["black", "white", "random_color", "random_pixel"]
|
|
|
|
class Gaussians3D(NamedTuple):
    """A collection of 3D Gaussians stored as parallel tensors.

    ``load_ply`` produces tensors with a leading batch dimension of 1, e.g.
    ``mean_vectors`` of shape (1, num_gaussians, 3).
    """

    # Gaussian centers, (..., 3).
    mean_vectors: torch.Tensor
    # Scales along the principal axes (squared to form the covariance), (..., 3).
    singular_values: torch.Tensor
    # Orientation of the principal axes as quaternions, (..., 4).
    quaternions: torch.Tensor
    # Colors; load_ply stores linear RGB here, (..., 3).
    colors: torch.Tensor
    # Opacities (sigmoid of the logits stored in ply files).
    opacities: torch.Tensor

    def to(self, device: torch.device) -> Gaussians3D:
        """Return a Gaussians3D whose tensors have all been moved to ``device``."""
        # Iterating a NamedTuple yields its fields in declaration order.
        return Gaussians3D._make(tensor.to(device) for tensor in self)
|
|
|
|
class SceneMetaData(NamedTuple):
    """Camera and color information accompanying a Gaussian scene."""

    # Focal length in pixels (load_ply keeps only the horizontal focal length).
    focal_length_px: float
    # Image resolution as (width, height) in pixels.
    resolution_px: tuple[int, int]
    # Color space the colors were stored in by the source file.
    color_space: cs_utils.ColorSpace
|
|
|
|
def get_unprojection_matrix(
    extrinsics: torch.Tensor,
    intrinsics: torch.Tensor,
    image_shape: tuple[int, int],
) -> torch.Tensor:
    """Compute unprojection matrix to transform Gaussians to Euclidean space.

    The forward chain is ``ndc_matrix @ intrinsics @ extrinsics``, where ``ndc_matrix``
    rescales the first two homogeneous coordinates from pixel units to [-1, 1]
    (``x -> 2 x / width - 1``). The returned matrix is the inverse of that chain.

    Args:
        extrinsics: The 4x4 extrinsics matrix of the camera view.
        intrinsics: The 4x4 intrinsics matrix of the camera view.
        image_shape: The (width, height) of the input image.

    Returns:
        A 4x4 matrix to transform Gaussians from NDC space to Euclidean space. It is
        computed in the dtype of ``intrinsics`` when that is a floating-point dtype.
    """
    image_width, image_height = image_shape

    # torch.matmul does not promote dtypes, so a float32 constant would fail against
    # float64 camera matrices. Build the NDC matrix in the intrinsics' dtype instead. Fall
    # back to the default float dtype for integer input to avoid truncating 2 / width to 0.
    dtype = intrinsics.dtype if intrinsics.is_floating_point() else torch.get_default_dtype()
    ndc_matrix = torch.tensor(
        [
            [2.0 / image_width, 0.0, -1.0, 0.0],
            [0.0, 2.0 / image_height, -1.0, 0.0],
            [0.0, 0.0, 1.0, 0.0],
            [0.0, 0.0, 0.0, 1.0],
        ],
        dtype=dtype,
        device=intrinsics.device,
    )
    return torch.linalg.inv(ndc_matrix @ intrinsics @ extrinsics)
|
|
|
|
def unproject_gaussians(
    gaussians_ndc: Gaussians3D,
    extrinsics: torch.Tensor,
    intrinsics: torch.Tensor,
    image_shape: tuple[int, int],
) -> Gaussians3D:
    """Unproject Gaussians from NDC space to world coordinates.

    Args:
        gaussians_ndc: Gaussians expressed in NDC space.
        extrinsics: The 4x4 extrinsics matrix of the camera view.
        intrinsics: The 4x4 intrinsics matrix of the camera view.
        image_shape: The (width, height) of the input image, as expected by
            ``get_unprojection_matrix``.

    Returns:
        The Gaussians transformed by the inverse of the projection chain.
    """
    unprojection_matrix = get_unprojection_matrix(extrinsics, intrinsics, image_shape)
    # Keep the top three rows (the affine part) of the 4x4 matrix. Index the row axis
    # explicitly: a plain `[:3]` would slice a leading batch dimension instead when the
    # camera matrices are batched.
    return apply_transform(gaussians_ndc, unprojection_matrix[..., :3, :])
|
|
|
|
def apply_transform(gaussians: Gaussians3D, transform: torch.Tensor) -> Gaussians3D:
    """Apply an affine transformation to 3D Gaussians.

    Means are mapped as ``x -> A x + b`` and covariances as ``C -> A C A^T``. The
    transformed covariances are then decomposed back into quaternions and singular values.

    Args:
        gaussians: The Gaussians to transform; ``mean_vectors`` has shape (..., N, 3).
        transform: An affine transform with shape 3x4. Only the top three rows are read,
            so a 4x4 matrix works as well. Leading batch dimensions, e.g. (B, 3, 4), are
            supported as long as they broadcast against the Gaussians' batch dimensions.

    Returns:
        The transformed Gaussians. Colors and opacities are passed through unchanged.

    Note: This operation is not differentiable.
    """
    transform_linear = transform[..., :3, :3]
    transform_offset = transform[..., :3, 3]

    # Use transpose(-1, -2) rather than `.T`: `.T` reverses *every* dimension (and is
    # deprecated for ndim != 2), which breaks batched transforms. The inserted axis lets the
    # offset broadcast over the N Gaussians of each batch entry.
    mean_vectors = (
        gaussians.mean_vectors @ transform_linear.transpose(-1, -2)
        + transform_offset[..., None, :]
    )

    covariance_matrices = compose_covariance_matrices(
        gaussians.quaternions, gaussians.singular_values
    )
    # One linear map per batch entry, broadcast over the Gaussian axis of the
    # (..., N, 3, 3) covariance stack.
    linear_per_gaussian = transform_linear[..., None, :, :]
    covariance_matrices = (
        linear_per_gaussian @ covariance_matrices @ linear_per_gaussian.transpose(-1, -2)
    )
    quaternions, singular_values = decompose_covariance_matrices(covariance_matrices)

    return Gaussians3D(
        mean_vectors=mean_vectors,
        singular_values=singular_values,
        quaternions=quaternions,
        colors=gaussians.colors,
        opacities=gaussians.opacities,
    )
|
|
|
|
def decompose_covariance_matrices(
    covariance_matrices: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Decompose 3D covariance matrices into quaternions and singular values.

    Args:
        covariance_matrices: The covariance matrices to decompose, shape (..., 3, 3).
            Any number of leading batch dimensions is accepted.

    Returns:
        Quaternion and singular values corresponding to the orientation and scales of
        the diagonalized matrix, returned in the dtype and on the device of the input.

    Note: This operation is not differentiable.
    """
    device = covariance_matrices.device
    dtype = covariance_matrices.dtype

    # The SVD runs detached, in float64 on the CPU; the detach is why the decomposition
    # is not differentiable.
    covariance_matrices = covariance_matrices.detach().cpu().to(torch.float64)
    rotations, singular_values_2, _ = torch.linalg.svd(covariance_matrices)

    # SVD may return orthogonal matrices with determinant -1 (reflections), which have no
    # unit-quaternion representation. Negating the last column turns them into rotations.
    # Because U diag(S) U^T = sum_i s_i u_i u_i^T, flipping one column's sign leaves the
    # covariance unchanged. A boolean mask works for any number of leading dimensions,
    # unlike unpacking torch.where into exactly two index tensors.
    reflection_mask = torch.linalg.det(rotations) < 0
    num_reflections = int(reflection_mask.sum())
    if num_reflections > 0:
        LOGGER.warning(
            "Received %d reflection matrices from SVD. Flipping them to rotations.",
            num_reflections,
        )
        last_column = rotations[..., :, -1]
        rotations[..., :, -1] = torch.where(
            reflection_mask[..., None], -last_column, last_column
        )

    quaternions = linalg.quaternions_from_rotation_matrices(rotations)
    quaternions = quaternions.to(dtype=dtype, device=device)
    # Singular values of the covariance are variances; their square roots are the scales.
    singular_values = singular_values_2.sqrt().to(dtype=dtype, device=device)
    return quaternions, singular_values
|
|
|
|
def compose_covariance_matrices(
    quaternions: torch.Tensor, singular_values: torch.Tensor
) -> torch.Tensor:
    """Compose 3D covariance matrices from quaternions and singular values.

    Computes ``R diag(s)^2 R^T``, where ``R`` is the rotation matrix of each quaternion and
    ``s`` holds the corresponding singular values (scales).

    Args:
        quaternions: The quaternions describing the principal basis, shape (..., 4).
        singular_values: The scales of the diagonalized matrix, shape (..., 3).

    Returns:
        The 3x3 covariance matrices, shape (..., 3, 3).
    """
    rotations = linalg.rotation_matrices_from_quaternions(quaternions)
    # diag_embed keeps the dtype and device of `singular_values`. A float32 identity
    # would promote half-precision scales and make the matmul fail on mixed dtypes.
    variances = torch.diag_embed(singular_values.square())
    return rotations @ variances @ rotations.transpose(-1, -2)
|
|
|
|
| def convert_spherical_harmonics_to_rgb(sh0: torch.Tensor) -> torch.Tensor: |
| """Convert degree-0 spherical harmonics to RGB. |
| |
| Reference: |
| https://en.wikipedia.org/wiki/Table_of_spherical_harmonics |
| """ |
| coeff_degree0 = np.sqrt(1.0 / (4.0 * np.pi)) |
| return sh0 * coeff_degree0 + 0.5 |
|
|
|
|
def convert_rgb_to_spherical_harmonics(rgb: torch.Tensor) -> torch.Tensor:
    """Convert RGB values to degree-0 spherical harmonics coefficients.

    Inverse of ``convert_spherical_harmonics_to_rgb``: it removes the 0.5 offset and divides
    by the constant SH basis value ``Y_0^0 = sqrt(1 / (4 pi))``.

    Reference:
        https://en.wikipedia.org/wiki/Table_of_spherical_harmonics
    """
    y00 = np.sqrt(1.0 / (4.0 * np.pi))
    centered = rgb - 0.5
    return centered / y00
|
|
|
|
def load_ply(path: Path) -> tuple[Gaussians3D, SceneMetaData]:
    """Load Gaussians and scene metadata from a ply file.

    The ``vertex`` element must provide ``x``/``y``/``z``, ``f_dc_0..2`` (degree-0 SH),
    ``scale_0..2`` (log scales), ``rot_0..3`` (quaternions) and ``opacity`` (logits).
    Optional single-property elements ``intrinsic``, ``extrinsic``, ``image_size`` and
    ``color_space`` carry metadata; defaults are used when they are absent.

    Args:
        path: Path of the ply file to read.

    Returns:
        The Gaussians (float32 tensors with a leading batch dimension of 1) and the scene
        metadata.

    Raises:
        KeyError: If the vertex element or a required vertex property is missing.
        ValueError: If the intrinsics or extrinsics have an unexpected number of entries.
    """
    plydata = PlyData.read(path)

    # Use a default with next() so a missing element gives a descriptive KeyError instead
    # of leaking StopIteration.
    vertices = next((element for element in plydata.elements if element.name == "vertex"), None)
    if vertices is None:
        raise KeyError("Incompatible ply file: element vertex not found.")

    # Validate every property that is read below, including rot_3 and opacity.
    properties = ["x", "y", "z"]
    properties.extend(f"f_dc_{i}" for i in range(3))
    properties.extend(f"scale_{i}" for i in range(3))
    properties.extend(f"rot_{i}" for i in range(4))
    properties.append("opacity")
    for prop in properties:
        if prop not in vertices:
            raise KeyError(f"Incompatible ply file: property {prop} not found in ply elements.")

    mean_vectors = _stack_vertex_properties(vertices, ["x", "y", "z"])
    scale_logits = _stack_vertex_properties(vertices, [f"scale_{i}" for i in range(3)])
    quaternions = _stack_vertex_properties(vertices, [f"rot_{i}" for i in range(4)])
    spherical_harmonics_deg0 = _stack_vertex_properties(
        vertices, [f"f_dc_{i}" for i in range(3)]
    )
    colors = convert_spherical_harmonics_to_rgb(spherical_harmonics_deg0)
    opacity_logits = np.asarray(vertices["opacity"])[..., None]

    supplement_data = _collect_supplement_data(plydata)
    focal_length_px, width, height = _parse_intrinsics(supplement_data)
    # Parsed only to validate its layout; SceneMetaData has no extrinsics field.
    _parse_extrinsics(supplement_data)

    color_space_index = supplement_data.get("color_space", 1)
    color_space = cs_utils.decode_color_space(color_space_index)
    if color_space == "sRGB":
        colors = cs_utils.sRGB2linearRGB(colors)

    gaussians = Gaussians3D(
        mean_vectors=torch.from_numpy(mean_vectors).view(1, -1, 3).float(),
        quaternions=torch.from_numpy(quaternions).view(1, -1, 4).float(),
        singular_values=torch.exp(torch.from_numpy(scale_logits).view(1, -1, 3)).float(),
        opacities=torch.sigmoid(torch.from_numpy(opacity_logits).view(1, -1)).float(),
        colors=torch.from_numpy(colors).view(1, -1, 3).float(),
    )
    metadata = SceneMetaData(focal_length_px, (width, height), color_space)
    return gaussians, metadata


def _stack_vertex_properties(vertices: PlyElement, names: list[str]) -> np.ndarray:
    """Stack the named vertex properties column-wise into a (num_vertices, len(names)) array."""
    return np.stack([np.asarray(vertices[name]) for name in names], axis=1)


def _collect_supplement_data(plydata: PlyData) -> dict[str, Any]:
    """Gather optional metadata properties from all non-vertex elements.

    The first element containing a given key wins.
    """
    supplement_keys = ["extrinsic", "intrinsic", "color_space", "image_size"]
    supplement_data: dict[str, Any] = {}
    for element in plydata.elements:
        if element.name == "vertex":
            continue
        for key in supplement_keys:
            if key not in supplement_data and key in element:
                supplement_data[key] = np.asarray(element[key])
    return supplement_data


def _parse_intrinsics(supplement_data: dict[str, Any]) -> tuple[float, int, int]:
    """Return the horizontal focal length in pixels and the image width and height.

    Two layouts are supported. Without an ``image_size`` element, a legacy 4-entry
    ``[fx, fy, width, height]`` layout is expected. With one, the intrinsics are a 3x3
    matrix. Without intrinsics, the defaults are a 512 px focal length and a 640x480 image.
    Values are converted to plain Python float/int for both layouts.
    """
    if "intrinsic" not in supplement_data:
        return 512.0, 640, 480

    intrinsics_data = supplement_data["intrinsic"]
    if "image_size" not in supplement_data:
        if len(intrinsics_data) != 4:
            raise ValueError(
                "Expect legacy intrinsics with len=4 containing image size, "
                f"but received len={len(intrinsics_data)}"
            )
        return float(intrinsics_data[0]), int(intrinsics_data[2]), int(intrinsics_data[3])

    if len(intrinsics_data) != 9:
        raise ValueError(f"Expect 9 elements in intrinsics, but received {len(intrinsics_data)}.")
    intrinsics_matrix = intrinsics_data.reshape((3, 3))
    image_size_data = supplement_data["image_size"]
    return float(intrinsics_matrix[0, 0]), int(image_size_data[0]), int(image_size_data[1])


def _parse_extrinsics(supplement_data: dict[str, Any]) -> np.ndarray:
    """Return the 4x4 extrinsics matrix stored in the file, or identity if it is absent.

    A 12-entry (3x4) layout has its 3x3 block stored transposed; a 16-entry layout is a
    row-major 4x4 matrix.
    """
    extrinsics_data = supplement_data.get("extrinsic", np.eye(4).flatten())
    extrinsics_matrix = np.eye(4)
    if len(extrinsics_data) == 12:
        extrinsics_matrix[:3] = extrinsics_data.reshape((3, 4))
        extrinsics_matrix[:3, :3] = extrinsics_matrix[:3, :3].copy().T
    elif len(extrinsics_data) == 16:
        extrinsics_matrix[:] = extrinsics_data.reshape((4, 4))
    else:
        raise ValueError(f"Unrecognized extrinsics matrix shape {len(extrinsics_data)}")
    return extrinsics_matrix
|
|
|
|
@torch.no_grad()
def save_ply(
    gaussians: Gaussians3D, f_px: float, image_shape: tuple[int, int], path: Path
) -> PlyData:
    """Save predicted Gaussians to a ply file.

    All batch entries are flattened into one ``vertex`` element holding positions,
    degree-0 SH coefficients of sRGB colors, opacity logits, log scales and quaternions.
    It is followed by single-property metadata elements (extrinsic, intrinsic,
    image_size, frame, disparity, color_space, version).

    Args:
        gaussians: The Gaussians to save.
        f_px: Focal length in pixels, written for both axes of the intrinsic matrix.
        image_shape: The (height, width) of the image. This is the reverse of the
            (width, height) order used by ``get_unprojection_matrix``.
        path: Destination file.

    Returns:
        The ``PlyData`` that was written.
    """

    def _inverse_sigmoid(tensor: torch.Tensor) -> torch.Tensor:
        return torch.log(tensor / (1.0 - tensor))

    xyz = gaussians.mean_vectors.flatten(0, 1)
    scale_logits = torch.log(gaussians.singular_values).flatten(0, 1)
    quaternions = gaussians.quaternions.flatten(0, 1)
    # Colors are written as SH coefficients of sRGB values, and the color_space element
    # records this so that load_ply converts them back to linear RGB.
    colors = convert_rgb_to_spherical_harmonics(
        cs_utils.linearRGB2sRGB(gaussians.colors.flatten(0, 1))
    )
    color_space_index = cs_utils.encode_color_space("sRGB")
    opacity_logits = _inverse_sigmoid(gaussians.opacities).flatten(0, 1).unsqueeze(-1)

    attributes = torch.cat(
        (xyz, colors, opacity_logits, scale_logits, quaternions),
        dim=1,
    )
    attribute_names = (
        ["x", "y", "z"]
        + [f"f_dc_{i}" for i in range(3)]
        + ["opacity"]
        + [f"scale_{i}" for i in range(3)]
        + [f"rot_{i}" for i in range(4)]
    )

    # Fill the structured array column by column. Converting every row into a Python tuple
    # allocates objects per Gaussian and is very slow for millions of Gaussians.
    num_gaussians = len(xyz)
    attributes_np = attributes.detach().cpu().numpy()
    elements = np.empty(num_gaussians, dtype=[(name, "f4") for name in attribute_names])
    for column, name in enumerate(attribute_names):
        elements[name] = attributes_np[:, column]
    vertex_elements = PlyElement.describe(elements, "vertex")

    image_height, image_width = image_shape
    intrinsic = np.array(
        [f_px, 0, image_width * 0.5, 0, f_px, image_height * 0.5, 0, 0, 1]
    )

    # Disparity quantiles of the first batch entry. torch.quantile requires `q` to have the
    # input's dtype and supports only float32/float64 inputs.
    depth = gaussians.mean_vectors[0, ..., -1]
    if depth.dtype not in (torch.float32, torch.float64):
        depth = depth.float()
    disparity = 1.0 / depth
    quantile_levels = torch.tensor([0.1, 0.9], dtype=disparity.dtype, device=disparity.device)
    quantiles = torch.quantile(disparity, q=quantile_levels).float().cpu().numpy()

    plydata = PlyData(
        [
            vertex_elements,
            _describe_single_property_element("extrinsic", "f4", np.eye(4).flatten()),
            _describe_single_property_element("intrinsic", "f4", intrinsic),
            _describe_single_property_element(
                "image_size", "u4", np.array([image_width, image_height])
            ),
            _describe_single_property_element(
                "frame", "i4", np.array([1, num_gaussians], dtype=np.int32)
            ),
            _describe_single_property_element("disparity", "f4", quantiles),
            _describe_single_property_element(
                "color_space", "u1", np.array([color_space_index])
            ),
            _describe_single_property_element(
                "version", "u1", np.array([1, 5, 0], dtype=np.uint8)
            ),
        ]
    )

    plydata.write(path)
    return plydata


def _describe_single_property_element(name: str, dtype: str, values: np.ndarray) -> PlyElement:
    """Wrap a flat array as a ply element called ``name`` with one property of the same name."""
    values = np.asarray(values).reshape(-1)
    array = np.empty(len(values), dtype=[(name, dtype)])
    array[name] = values
    return PlyElement.describe(array, name)
|
|