# nksr_wrapper/reconstructor.py
# (from the Hugging Face repo "nksr-wrapper", upload commit 50a800c)
"""
Core NKSR wrapper: high-level mesh reconstruction API.
"""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Union, Callable
import warnings
import numpy as np
import torch
try:
import nksr
except ImportError as exc:
raise ImportError(
"The `nksr` package is required but not installed. "
"Please install it from https://github.com/nv-tlabs/NKSR:\n"
" git clone https://github.com/nv-tlabs/NKSR.git\n"
" cd NKSR && pip install --no-build-isolation package/\n"
"See the README for environment setup details."
) from exc
@dataclass
class MeshResult:
    """Container for a reconstructed triangle mesh."""

    # (V, 3) float array of mesh vertex positions.
    vertices: np.ndarray
    # (F, 3) int array of triangle face indices.
    faces: np.ndarray
    # Optional (V, 3) float array of per-vertex colors (set when texture
    # reconstruction was requested).
    vertex_colors: Optional[np.ndarray] = None

    def save(self, path: Union[str, Path]) -> None:
        """Save the mesh to a file using Trimesh."""
        import trimesh

        trimesh.Trimesh(
            vertices=self.vertices,
            faces=self.faces,
            vertex_colors=self.vertex_colors,
        ).export(str(path))
class NKSRMeshReconstructor:
    """
    High-level wrapper around the NKSR reconstructor.

    This class hides the internal complexity of NKSR and exposes a single
    ``reconstruct()`` call that takes a point cloud (with optional normals)
    and returns a watertight triangle mesh.

    Parameters
    ----------
    device : str or torch.device, optional
        PyTorch device to run inference on. Default ``"cuda:0"``.
    config : str, optional
        NKSR model configuration to load. Default ``"ks"`` (kitchen-sink,
        general-purpose pretrained model). Other options include ``"snet"``
        (ShapeNet objects with normals) and ``"snet-wonormal"`` (ShapeNet
        without normals).
    chunk_tmp_device : str or torch.device, optional
        Temporary offload device for finished chunks when reconstructing very
        large scenes. Default ``"cpu"``. Set to ``None`` to disable
        off-loading (keeps everything on *device*).
    """

    def __init__(
        self,
        device: Union[str, torch.device] = "cuda:0",
        config: str = "ks",
        chunk_tmp_device: Optional[Union[str, torch.device]] = "cpu",
    ):
        self.device = torch.device(device)
        self.reconstructor = nksr.Reconstructor(self.device, config=config)
        if chunk_tmp_device is not None:
            self.reconstructor.chunk_tmp_device = torch.device(chunk_tmp_device)
        self._config_name = config

    # ------------------------------------------------------------------ #
    #                              Public API                             #
    # ------------------------------------------------------------------ #
    def reconstruct(
        self,
        points: np.ndarray,
        normals: Optional[np.ndarray] = None,
        sensor_positions: Optional[np.ndarray] = None,
        colors: Optional[np.ndarray] = None,
        *,
        detail_level: float = 1.0,
        voxel_size: Optional[float] = None,
        chunk_size: float = -1.0,
        overlap_ratio: float = 0.05,
        approx_kernel_grad: bool = False,
        solver_max_iter: int = 2000,
        solver_tol: float = 1e-5,
        nystrom_min_depth: int = 100,
        fused_mode: bool = True,
        mise_iter: int = 1,
        estimate_normals_if_missing: bool = True,
        normal_knn: int = 64,
        normal_drop_threshold_deg: float = 85.0,
    ) -> MeshResult:
        """
        Reconstruct a watertight mesh from a point cloud.

        Parameters
        ----------
        points : np.ndarray
            (N, 3) array of point positions.
        normals : np.ndarray, optional
            (N, 3) array of **oriented** point normals. If ``None`` and
            *sensor_positions* are also ``None``, normals are estimated on
            the fly (requires *estimate_normals_if_missing* = ``True``).
        sensor_positions : np.ndarray, optional
            (N, 3) array of per-point sensor/camera positions. When normals
            are missing, NKSR can infer orientation from the point-to-sensor
            vector using the internal ``get_estimate_normal_preprocess_fn``.
        colors : np.ndarray, optional
            (N, 3) array of RGB colors in ``[0, 255]`` or ``[0, 1]``. If
            provided, the returned mesh will contain per-vertex colors.
        detail_level : float, default 1.0
            Trade-off between smoothness and detail. ``0.0`` = very smooth,
            ``1.0`` = maximum detail (may over-fit noise). Ignored when
            *chunk_size* > 0 or *voxel_size* is set.
        voxel_size : float, optional
            Explicit voxel size controlling the reconstruction resolution.
            Overrides *detail_level*.
        chunk_size : float, default -1.0
            Spatial extent of each chunk for out-of-core reconstruction.
            ``-1.0`` disables chunking (process everything at once). Positive
            values are required for very large point clouds (> few million
            points) to avoid out-of-memory errors.
        overlap_ratio : float, default 0.05
            Overlap between adjacent chunks (as a fraction of *chunk_size*).
        approx_kernel_grad : bool, default False
            Whether to approximate kernel gradients — slightly faster but a
            bit less accurate.
        solver_max_iter : int, default 2000
            Maximum iterations for the sparse PCG linear solver.
        solver_tol : float, default 1e-5
            Convergence tolerance for the PCG solver.
        nystrom_min_depth : int, default 100
            Minimum depth for the Nyström low-rank approximation used by the
            kernel field.
        fused_mode : bool, default True
            Memory-efficient fusion mode when chunking is enabled.
        mise_iter : int, default 1
            Number of MISE (Multi-resolution IsoSurface Extraction) iterations.
            ``0`` = base grid resolution, each additional iteration doubles
            the effective resolution in subdivided cells.
        estimate_normals_if_missing : bool, default True
            If ``True`` and no normals are provided, estimate them from the
            local geometry. This only works well when the surface is
            sufficiently sampled.
        normal_knn : int, default 64
            k-NN neighborhood size for on-the-fly normal estimation.
        normal_drop_threshold_deg : float, default 85.0
            Maximum angle (in degrees) between the estimated normal and the
            point-to-sensor vector. Points exceeding this are dropped.

        Returns
        -------
        MeshResult
            Container with ``vertices``, ``faces``, and optionally
            ``vertex_colors``.

        Raises
        ------
        ValueError
            If an input array does not have shape (N, 3), or if no normals
            are available (``normals`` and ``sensor_positions`` both missing
            and *estimate_normals_if_missing* is ``False``).

        Notes
        -----
        1. **Normals matter.** NKSR is designed for oriented normals. If
           your input lacks them, the wrapper will try to estimate them, but
           orientation may be arbitrary (leading to inside-out meshes).
           Providing *sensor_positions* gives the best auto-orientation.
        2. **Scale.** The default ``voxel_size`` in the ``"ks"`` config is
           ``0.1``. If your point cloud is in millimetres and represents a
           room-scale scene, ``0.1`` = 10 cm, which is reasonable. Adjust
           *voxel_size* or scale your data accordingly.
        3. **Chunking.** When ``chunk_size > 0``, *detail_level* and
           *voxel_size* are ignored by the underlying NKSR code. To control
           detail in chunked mode, pre-scale the point cloud by
           ``0.1 / desired_voxel_size``.
        """
        points = self._to_tensor(points, "points")

        # ---- handle normals ------------------------------------------------
        preprocess_fn: Optional[Callable] = None
        if normals is not None:
            normals = self._to_tensor(normals, "normals")
        elif sensor_positions is not None:
            sensor_positions = self._to_tensor(sensor_positions, "sensor_positions")
            # Let NKSR estimate and orient normals from point-to-sensor rays.
            preprocess_fn = nksr.get_estimate_normal_preprocess_fn(
                knn=normal_knn,
                drop_threshold_degrees=normal_drop_threshold_deg,
            )
        elif estimate_normals_if_missing:
            warnings.warn(
                "No normals or sensor positions provided. "
                "Estimating normals from geometry — orientation may be arbitrary. "
                "Consider providing sensor_positions for best results.",
                UserWarning,
            )
            normals = self._estimate_normals_from_points(points, normal_knn)
        else:
            # Fail loudly here instead of letting NKSR fail deep inside the
            # solver with a much less actionable error.
            raise ValueError(
                "No normals available: provide `normals` or `sensor_positions`, "
                "or set estimate_normals_if_missing=True."
            )

        # ---- colors ---------------------------------------------------------
        color_tensor: Optional[torch.Tensor] = None
        if colors is not None:
            colors = np.asarray(colors)
            # Heuristic: any channel value above 1 means 8-bit [0, 255] input.
            # The `size` guard avoids `max()` raising on an empty array (the
            # shape check in _to_tensor reports the real problem instead).
            if colors.size and colors.max() > 1.0:
                colors = colors / 255.0
            color_tensor = self._to_tensor(colors, "colors")

        # ---- reconstruct ----------------------------------------------------
        # NKSR treats `detail_level` and `voxel_size` as mutually exclusive;
        # per the docstring an explicit voxel_size overrides detail_level, so
        # drop the (defaulted) detail_level rather than passing both down.
        if voxel_size is not None:
            detail_level = None
        field = self.reconstructor.reconstruct(
            xyz=points,
            normal=normals,
            sensor=sensor_positions,
            detail_level=detail_level,
            voxel_size=voxel_size,
            chunk_size=chunk_size,
            overlap_ratio=overlap_ratio,
            approx_kernel_grad=approx_kernel_grad,
            solver_max_iter=solver_max_iter,
            solver_tol=solver_tol,
            nystrom_min_depth=nystrom_min_depth,
            fused_mode=fused_mode,
            preprocess_fn=preprocess_fn,
        )

        # ---- optional texture ------------------------------------------------
        if color_tensor is not None:
            field.set_texture_field(nksr.fields.PCNNField(points, color_tensor))
            if mise_iter < 2:
                warnings.warn(
                    "Color reconstruction requested but mise_iter < 2. "
                    "Increasing to 2 for better color resolution.",
                    UserWarning,
                )
                mise_iter = 2

        # ---- extract mesh ---------------------------------------------------
        mesh = field.extract_dual_mesh(mise_iter=mise_iter)
        vertices = self._as_numpy(mesh.v)
        faces = self._as_numpy(mesh.f)
        vertex_colors = None
        if getattr(mesh, "c", None) is not None:
            vertex_colors = self._as_numpy(mesh.c)
        return MeshResult(
            vertices=vertices,
            faces=faces,
            vertex_colors=vertex_colors,
        )

    # ------------------------------------------------------------------ #
    #                               Helpers                               #
    # ------------------------------------------------------------------ #
    @staticmethod
    def _as_numpy(value) -> np.ndarray:
        """Bring a tensor-like onto the host as a numpy array."""
        return value.cpu().numpy() if hasattr(value, "cpu") else np.asarray(value)

    def _to_tensor(
        self, arr: Union[np.ndarray, torch.Tensor], name: str
    ) -> torch.Tensor:
        """
        Convert an (N, 3) array-like to a float32 tensor on the target device.

        Accepts numpy arrays, nested sequences, and torch tensors.

        Raises
        ------
        ValueError
            If the input does not have shape (N, 3).
        """
        if isinstance(arr, torch.Tensor):
            tensor = arr
        else:
            # ascontiguousarray normalizes dtype and strides (including the
            # negative strides of reversed views) and materializes sequence
            # inputs — all cases that torch.from_numpy alone would reject.
            tensor = torch.from_numpy(np.ascontiguousarray(arr, dtype=np.float32))
        if tensor.ndim != 2 or tensor.shape[1] != 3:
            raise ValueError(
                f"{name} must have shape (N, 3), got {tuple(tensor.shape)}"
            )
        return tensor.to(device=self.device, dtype=torch.float32)

    def _estimate_normals_from_points(
        self, points: torch.Tensor, k: int = 64
    ) -> torch.Tensor:
        """
        Fast PCA-based normal estimation using PyTorch (no Open3D dependency).

        Each normal is the eigenvector of the local covariance matrix with
        the smallest eigenvalue, flipped to point roughly away from the
        cloud centroid. This yields **unoriented** normals: orientation is
        arbitrary, so the resulting mesh may be inside-out.
        """
        # Simple k-NN with brute force — acceptable for moderate N (< 100k).
        # For larger clouds the user should pre-compute normals externally.
        N = points.shape[0]
        if N > 100_000:
            warnings.warn(
                f"Point cloud has {N} points; on-the-fly normal estimation "
                f"may be slow. Consider pre-computing normals with Open3D.",
                UserWarning,
            )
        # Chunked brute force keeps the distance matrix at (batch, N) instead
        # of (N, N), bounding peak memory.
        batch_size = 4096
        normals_list = []
        for i in range(0, N, batch_size):
            batch = points[i : i + batch_size]  # (B, 3)
            dists = torch.cdist(batch, points)  # (B, N) pairwise distances
            _, idx = torch.topk(dists, k=min(k, N), dim=-1, largest=False)  # (B, k)
            neighbors = points[idx]  # (B, k, 3)
            centered = neighbors - neighbors.mean(dim=1, keepdim=True)  # (B, k, 3)
            cov = centered.transpose(1, 2) @ centered  # (B, 3, 3)
            # eigh returns eigenvalues in ascending order, so column 0 is the
            # direction of least variance — the surface normal.
            _, eigvecs = torch.linalg.eigh(cov)
            normals_list.append(eigvecs[:, :, 0])  # (B, 3)
        normals = torch.cat(normals_list, dim=0)
        # Arbitrary orientation — flip to point roughly outward from centroid.
        centroid = points.mean(dim=0, keepdim=True)
        outward = points - centroid
        flip = (normals * outward).sum(dim=-1, keepdim=True) < 0
        return torch.where(flip, -normals, normals)