# bdck's picture
# Upload nksr_wrapper/io.py
# 5a90b18 verified
"""
I/O helpers for point clouds (PLY/PCD) and meshes (OBJ/PLY/GLB).
"""
from __future__ import annotations
from pathlib import Path
from typing import Optional, Tuple, Union
import numpy as np
def load_point_cloud(
    path: Union[str, Path],
    *,
    estimate_normals: bool = False,
    normal_knn: int = 30,
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """
    Load a point cloud from a ``.ply`` or ``.pcd`` file.

    Parameters
    ----------
    path : str or Path
        Path to the input file. The extension is matched case-insensitively.
    estimate_normals : bool, default False
        If the file does not contain normals, estimate them with Open3D.
    normal_knn : int, default 30
        k-NN neighbourhood size for Open3D normal estimation.

    Returns
    -------
    points : np.ndarray
        (N, 3) float array of point positions.
    normals : np.ndarray or None
        (N, 3) float array of point normals, if available or estimated.

    Raises
    ------
    ValueError
        If the file extension is neither ``.ply`` nor ``.pcd``.
    FileNotFoundError
        If ``path`` does not refer to an existing file.
    ImportError
        If Open3D is not installed.
    """
    path = Path(path)
    suffix = path.suffix.lower()

    if suffix not in (".ply", ".pcd"):
        raise ValueError(f"Unsupported point-cloud format: {suffix}")

    # Fail loudly on a bad path: Open3D's reader is understood to only emit a
    # warning and hand back an empty cloud when the file cannot be opened,
    # which would otherwise surface much later as a confusing empty array.
    if not path.is_file():
        raise FileNotFoundError(f"Point-cloud file not found: {path}")

    return _load_with_open3d(path, estimate_normals, normal_knn)
def _load_with_open3d(
    path: Path, estimate_normals: bool, normal_knn: int
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """
    Read a point cloud through Open3D and return its points and normals.

    Normals stored in the file are returned as-is. When the file has none,
    ``estimate_normals`` is true and the cloud is non-empty, normals are
    estimated from ``normal_knn`` nearest neighbours and then oriented with
    Open3D's consistent-tangent-plane method. Otherwise the normals slot of
    the returned tuple is ``None``.

    Raises
    ------
    ImportError
        If the ``open3d`` package cannot be imported.
    """
    # Imported lazily so the module stays importable without Open3D.
    try:
        import open3d as o3d
    except ImportError as exc:
        raise ImportError(
            "open3d is required to load PLY/PCD files. "
            "Install it with: pip install open3d"
        ) from exc

    cloud = o3d.io.read_point_cloud(str(path))
    positions = np.asarray(cloud.points)

    # Prefer normals that ship with the file over estimated ones.
    if cloud.has_normals():
        return positions, np.asarray(cloud.normals)

    # Nothing to estimate from, or estimation was not requested.
    if not estimate_normals or len(positions) == 0:
        return positions, None

    knn_search = o3d.geometry.KDTreeSearchParamKNN(knn=normal_knn)
    cloud.estimate_normals(search_param=knn_search)
    cloud.orient_normals_consistent_tangent_plane(k=normal_knn)
    return positions, np.asarray(cloud.normals)
def save_mesh(
    path: Union[str, Path],
    vertices: np.ndarray,
    faces: np.ndarray,
    vertex_colors: Optional[np.ndarray] = None,
) -> None:
    """
    Save a triangle mesh to a file.

    Supported formats: ``.ply``, ``.obj``, ``.glb``, ``.stl``, ``.off``
    (anything Trimesh supports).

    Parameters
    ----------
    path : str or Path
        Destination file; it is passed to Trimesh's exporter as a string.
    vertices : np.ndarray
        Vertex positions; converted to ``float32`` before export.
    faces : np.ndarray
        Face vertex indices; converted to ``int32`` before export.
    vertex_colors : np.ndarray, optional
        Per-vertex colours, forwarded unchanged to ``trimesh.Trimesh``.

    Raises
    ------
    ImportError
        If the ``trimesh`` package cannot be imported.
    """
    # Imported lazily, with the same install hint style used for Open3D, so
    # a missing optional dependency produces an actionable message.
    try:
        import trimesh
    except ImportError as exc:
        raise ImportError(
            "trimesh is required to save meshes. "
            "Install it with: pip install trimesh"
        ) from exc

    path = Path(path)

    # Ensure correct dtypes
    vertices = np.asarray(vertices, dtype=np.float32)
    faces = np.asarray(faces, dtype=np.int32)

    # NOTE(review): Trimesh is built with its default processing options,
    # which may merge duplicate vertices — confirm this is intended.
    mesh = trimesh.Trimesh(
        vertices=vertices,
        faces=faces,
        vertex_colors=vertex_colors,
    )
    mesh.export(str(path))