bdck commited on
Commit
50a800c
·
verified ·
1 Parent(s): 5a90b18

Upload nksr_wrapper/reconstructor.py

Browse files
Files changed (1) hide show
  1. nksr_wrapper/reconstructor.py +326 -0
nksr_wrapper/reconstructor.py ADDED
@@ -0,0 +1,326 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Core NKSR wrapper: high-level mesh reconstruction API.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ from dataclasses import dataclass
8
+ from pathlib import Path
9
+ from typing import Optional, Union, Callable
10
+ import warnings
11
+
12
+ import numpy as np
13
+ import torch
14
+
15
+
16
+ try:
17
+ import nksr
18
+ except ImportError as exc:
19
+ raise ImportError(
20
+ "The `nksr` package is required but not installed. "
21
+ "Please install it from https://github.com/nv-tlabs/NKSR:\n"
22
+ " git clone https://github.com/nv-tlabs/NKSR.git\n"
23
+ " cd NKSR && pip install --no-build-isolation package/\n"
24
+ "See the README for environment setup details."
25
+ ) from exc
26
+
27
+
28
+ @dataclass
29
+ class MeshResult:
30
+ """Result container for a reconstructed mesh."""
31
+
32
+ vertices: np.ndarray
33
+ """(V, 3) float array of mesh vertex positions."""
34
+
35
+ faces: np.ndarray
36
+ """(F, 3) int array of triangle face indices."""
37
+
38
+ vertex_colors: Optional[np.ndarray] = None
39
+ """(V, 3) float array of per-vertex colors, if texture was reconstructed."""
40
+
41
+ def save(self, path: Union[str, Path]) -> None:
42
+ """Save the mesh to a file using Trimesh."""
43
+ import trimesh
44
+
45
+ mesh = trimesh.Trimesh(
46
+ vertices=self.vertices,
47
+ faces=self.faces,
48
+ vertex_colors=self.vertex_colors,
49
+ )
50
+ mesh.export(str(path))
51
+
52
+
53
+ class NKSRMeshReconstructor:
54
+ """
55
+ High-level wrapper around the NKSR reconstructor.
56
+
57
+ This class hides the internal complexity of NKSR and exposes a single
58
+ ``reconstruct()`` call that takes a point cloud (with optional normals)
59
+ and returns a watertight triangle mesh.
60
+
61
+ Parameters
62
+ ----------
63
+ device : str or torch.device, optional
64
+ PyTorch device to run inference on. Default ``"cuda:0"``.
65
+ config : str, optional
66
+ NKSR model configuration to load. Default ``"ks"`` (kitchen-sink,
67
+ general-purpose pretrained model). Other options include ``"snet"``
68
+ (ShapeNet objects with normals) and ``"snet-wonormal"`` (ShapeNet
69
+ without normals).
70
+ chunk_tmp_device : str or torch.device, optional
71
+ Temporary offload device for finished chunks when reconstructing very
72
+ large scenes. Default ``"cpu"``. Set to ``None`` to disable
73
+ off-loading (keeps everything on *device*).
74
+ """
75
+
76
+ def __init__(
77
+ self,
78
+ device: Union[str, torch.device] = "cuda:0",
79
+ config: str = "ks",
80
+ chunk_tmp_device: Optional[Union[str, torch.device]] = "cpu",
81
+ ):
82
+ self.device = torch.device(device)
83
+ self.reconstructor = nksr.Reconstructor(self.device, config=config)
84
+
85
+ if chunk_tmp_device is not None:
86
+ self.reconstructor.chunk_tmp_device = torch.device(chunk_tmp_device)
87
+
88
+ self._config_name = config
89
+
90
+ # ------------------------------------------------------------------ #
91
+ # Public API #
92
+ # ------------------------------------------------------------------ #
93
+
94
+ def reconstruct(
95
+ self,
96
+ points: np.ndarray,
97
+ normals: Optional[np.ndarray] = None,
98
+ sensor_positions: Optional[np.ndarray] = None,
99
+ colors: Optional[np.ndarray] = None,
100
+ *,
101
+ detail_level: float = 1.0,
102
+ voxel_size: Optional[float] = None,
103
+ chunk_size: float = -1.0,
104
+ overlap_ratio: float = 0.05,
105
+ approx_kernel_grad: bool = False,
106
+ solver_max_iter: int = 2000,
107
+ solver_tol: float = 1e-5,
108
+ nystrom_min_depth: int = 100,
109
+ fused_mode: bool = True,
110
+ mise_iter: int = 1,
111
+ estimate_normals_if_missing: bool = True,
112
+ normal_knn: int = 64,
113
+ normal_drop_threshold_deg: float = 85.0,
114
+ ) -> MeshResult:
115
+ """
116
+ Reconstruct a watertight mesh from a point cloud.
117
+
118
+ Parameters
119
+ ----------
120
+ points : np.ndarray
121
+ (N, 3) array of point positions.
122
+ normals : np.ndarray, optional
123
+ (N, 3) array of **oriented** point normals. If ``None`` and
124
+ *sensor_positions* are also ``None``, normals are estimated on
125
+ the fly (requires *estimate_normals_if_missing* = ``True``).
126
+ sensor_positions : np.ndarray, optional
127
+ (N, 3) array of per-point sensor/camera positions. When normals
128
+ are missing, NKSR can infer orientation from the point-to-sensor
129
+ vector using the internal ``get_estimate_normal_preprocess_fn``.
130
+ colors : np.ndarray, optional
131
+ (N, 3) array of RGB colors in ``[0, 255]`` or ``[0, 1]``. If
132
+ provided, the returned mesh will contain per-vertex colors.
133
+ detail_level : float, default 1.0
134
+ Trade-off between smoothness and detail. ``0.0`` = very smooth,
135
+ ``1.0`` = maximum detail (may over-fit noise). Ignored when
136
+ *chunk_size* > 0 or *voxel_size* is set.
137
+ voxel_size : float, optional
138
+ Explicit voxel size controlling the reconstruction resolution.
139
+ Overrides *detail_level*.
140
+ chunk_size : float, default -1.0
141
+ Spatial extent of each chunk for out-of-core reconstruction.
142
+ ``-1.0`` disables chunking (process everything at once). Positive
143
+ values are required for very large point clouds (> few million
144
+ points) to avoid out-of-memory errors.
145
+ overlap_ratio : float, default 0.05
146
+ Overlap between adjacent chunks (as a fraction of *chunk_size*).
147
+ approx_kernel_grad : bool, default False
148
+ Whether to approximate kernel gradients — slightly faster but a
149
+ bit less accurate.
150
+ solver_max_iter : int, default 2000
151
+ Maximum iterations for the sparse PCG linear solver.
152
+ solver_tol : float, default 1e-5
153
+ Convergence tolerance for the PCG solver.
154
+ nystrom_min_depth : int, default 100
155
+ Minimum depth for the Nyström low-rank approximation used by the
156
+ kernel field.
157
+ fused_mode : bool, default True
158
+ Memory-efficient fusion mode when chunking is enabled.
159
+ mise_iter : int, default 1
160
+ Number of MISE (Multi-resolution IsoSurface Extraction) iterations.
161
+ ``0`` = base grid resolution, each additional iteration doubles
162
+ the effective resolution in subdivided cells.
163
+ estimate_normals_if_missing : bool, default True
164
+ If ``True`` and no normals are provided, estimate them from the
165
+ local geometry. This only works well when the surface is
166
+ sufficiently sampled.
167
+ normal_knn : int, default 64
168
+ k-NN neighborhood size for on-the-fly normal estimation.
169
+ normal_drop_threshold_deg : float, default 85.0
170
+ Maximum angle (in degrees) between the estimated normal and the
171
+ point-to-sensor vector. Points exceeding this are dropped.
172
+
173
+ Returns
174
+ -------
175
+ MeshResult
176
+ Container with ``vertices``, ``faces``, and optionally
177
+ ``vertex_colors``.
178
+
179
+ Notes
180
+ -----
181
+ 1. **Normals matter.** NKSR is designed for oriented normals. If
182
+ your input lacks them, the wrapper will try to estimate them, but
183
+ orientation may be arbitrary (leading to inside-out meshes).
184
+ Providing *sensor_positions* gives the best auto-orientation.
185
+ 2. **Scale.** The default ``voxel_size`` in the ``"ks"`` config is
186
+ ``0.1``. If your point cloud is in millimetres and represents a
187
+ room-scale scene, ``0.1`` = 10 cm, which is reasonable. Adjust
188
+ *voxel_size* or scale your data accordingly.
189
+ 3. **Chunking.** When ``chunk_size > 0``, *detail_level* and
190
+ *voxel_size* are ignored by the underlying NKSR code. To control
191
+ detail in chunked mode, pre-scale the point cloud by
192
+ ``0.1 / desired_voxel_size``.
193
+ """
194
+ points = self._to_tensor(points, "points")
195
+
196
+ # ---- handle normals ------------------------------------------------
197
+ preprocess_fn: Optional[Callable] = None
198
+
199
+ if normals is not None:
200
+ normals = self._to_tensor(normals, "normals")
201
+ elif sensor_positions is not None:
202
+ sensor_positions = self._to_tensor(sensor_positions, "sensor_positions")
203
+ preprocess_fn = nksr.get_estimate_normal_preprocess_fn(
204
+ knn=normal_knn,
205
+ drop_threshold_degrees=normal_drop_threshold_deg,
206
+ )
207
+ elif estimate_normals_if_missing:
208
+ warnings.warn(
209
+ "No normals or sensor positions provided. "
210
+ "Estimating normals from geometry — orientation may be arbitrary. "
211
+ "Consider providing sensor_positions for best results.",
212
+ UserWarning,
213
+ )
214
+ normals = self._estimate_normals_from_points(points, normal_knn)
215
+
216
+ # ---- colors ---------------------------------------------------------
217
+ color_tensor: Optional[torch.Tensor] = None
218
+ if colors is not None:
219
+ colors = np.asarray(colors)
220
+ if colors.max() > 1.0:
221
+ colors = colors / 255.0
222
+ color_tensor = self._to_tensor(colors, "colors")
223
+
224
+ # ---- reconstruct ----------------------------------------------------
225
+ field = self.reconstructor.reconstruct(
226
+ xyz=points,
227
+ normal=normals,
228
+ sensor=sensor_positions,
229
+ detail_level=detail_level,
230
+ voxel_size=voxel_size,
231
+ chunk_size=chunk_size,
232
+ overlap_ratio=overlap_ratio,
233
+ approx_kernel_grad=approx_kernel_grad,
234
+ solver_max_iter=solver_max_iter,
235
+ solver_tol=solver_tol,
236
+ nystrom_min_depth=nystrom_min_depth,
237
+ fused_mode=fused_mode,
238
+ preprocess_fn=preprocess_fn,
239
+ )
240
+
241
+ # ---- optional texture ------------------------------------------------
242
+ if color_tensor is not None:
243
+ field.set_texture_field(nksr.fields.PCNNField(points, color_tensor))
244
+ if mise_iter < 2:
245
+ warnings.warn(
246
+ "Color reconstruction requested but mise_iter < 2. "
247
+ "Increasing to 2 for better color resolution.",
248
+ UserWarning,
249
+ )
250
+ mise_iter = 2
251
+
252
+ # ---- extract mesh ---------------------------------------------------
253
+ mesh = field.extract_dual_mesh(mise_iter=mise_iter)
254
+
255
+ vertices = mesh.v.cpu().numpy() if hasattr(mesh.v, "cpu") else np.asarray(mesh.v)
256
+ faces = mesh.f.cpu().numpy() if hasattr(mesh.f, "cpu") else np.asarray(mesh.f)
257
+
258
+ vertex_colors = None
259
+ if hasattr(mesh, "c") and mesh.c is not None:
260
+ vertex_colors = (
261
+ mesh.c.cpu().numpy() if hasattr(mesh.c, "cpu") else np.asarray(mesh.c)
262
+ )
263
+
264
+ return MeshResult(
265
+ vertices=vertices,
266
+ faces=faces,
267
+ vertex_colors=vertex_colors,
268
+ )
269
+
270
+ # ------------------------------------------------------------------ #
271
+ # Helpers #
272
+ # ------------------------------------------------------------------ #
273
+
274
+ def _to_tensor(self, arr: np.ndarray, name: str) -> torch.Tensor:
275
+ """Convert a numpy array to a float tensor on the target device."""
276
+ arr = np.asarray(arr)
277
+ if arr.ndim != 2 or arr.shape[1] != 3:
278
+ raise ValueError(
279
+ f"{name} must have shape (N, 3), got {arr.shape}"
280
+ )
281
+ return torch.from_numpy(arr).float().to(self.device)
282
+
283
+ def _estimate_normals_from_points(
284
+ self, points: torch.Tensor, k: int = 64
285
+ ) -> torch.Tensor:
286
+ """
287
+ Fast PCA-based normal estimation using PyTorch (no Open3D dependency).
288
+
289
+ This estimates **unoriented** normals. Orientation is arbitrary,
290
+ so the resulting mesh may be inside-out.
291
+ """
292
+ # Simple k-NN with brute force — acceptable for moderate N (< 100k).
293
+ # For larger clouds the user should pre-compute normals externally.
294
+ N = points.shape[0]
295
+ if N > 100_000:
296
+ warnings.warn(
297
+ f"Point cloud has {N} points; on-the-fly normal estimation "
298
+ f"may be slow. Consider pre-computing normals with Open3D.",
299
+ UserWarning,
300
+ )
301
+
302
+ # Build a KD-tree or use brute force — we use a chunked brute-force
303
+ # approach to keep memory reasonable.
304
+ batch_size = 4096
305
+ normals_list = []
306
+
307
+ for i in range(0, N, batch_size):
308
+ batch = points[i : i + batch_size] # (B, 3)
309
+ # pairwise distances to all points
310
+ dists = torch.cdist(batch, points) # (B, N)
311
+ _, idx = torch.topk(dists, k=min(k, N), dim=-1, largest=False) # (B, k)
312
+ neighbors = points[idx] # (B, k, 3)
313
+ centered = neighbors - neighbors.mean(dim=1, keepdim=True) # (B, k, 3)
314
+ cov = centered.transpose(1, 2) @ centered # (B, 3, 3)
315
+ # smallest eigenvector = normal
316
+ eigvals, eigvecs = torch.linalg.eigh(cov)
317
+ normal = eigvecs[:, :, 0] # (B, 3)
318
+ normals_list.append(normal)
319
+
320
+ normals = torch.cat(normals_list, dim=0)
321
+ # arbitrary orientation — flip to point roughly outward from centroid
322
+ centroid = points.mean(dim=0, keepdim=True)
323
+ outward = points - centroid
324
+ flip = (normals * outward).sum(dim=-1, keepdim=True) < 0
325
+ normals = torch.where(flip, -normals, normals)
326
+ return normals