nksr-wrapper / scripts /reconstruct.py
bdck's picture
Upload scripts/reconstruct.py
30857ce verified
#!/usr/bin/env python3
"""
CLI script: reconstruct a mesh from a point-cloud file using NKSR.
Usage
-----
python reconstruct.py input.ply output.ply --detail 1.0 --mise-iter 1
python reconstruct.py input.ply output.ply --chunk-size 50.0 --no-normals
"""
import argparse
import sys
from pathlib import Path
import torch
# Allow running from repo root without installing
sys.path.insert(0, str(Path(__file__).parent.parent))
from nksr_wrapper import NKSRMeshReconstructor, load_point_cloud, save_mesh
def _build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the NKSR reconstruction CLI."""
    parser = argparse.ArgumentParser(
        description="NKSR point-cloud → mesh reconstruction"
    )
    parser.add_argument("input", type=Path, help="Input PLY or PCD file")
    parser.add_argument("output", type=Path, help="Output mesh file (PLY/OBJ/GLB)")
    parser.add_argument("--device", default="cuda:0", help="PyTorch device")
    parser.add_argument("--config", default="ks", help="NKSR model config (ks/snet/snet-wonormal)")
    parser.add_argument("--detail", type=float, default=1.0, help="Detail level 0.0-1.0")
    parser.add_argument("--voxel-size", type=float, default=None, help="Override voxel size")
    parser.add_argument("--chunk-size", type=float, default=-1.0, help="Chunk size for large scenes")
    parser.add_argument("--mise-iter", type=int, default=1, help="MISE iterations")
    parser.add_argument("--no-normals", action="store_true", help="Ignore normals in file; estimate them")
    parser.add_argument("--estimate-normals", action="store_true", help="Estimate normals if file lacks them")
    parser.add_argument("--sensor", type=Path, default=None, help="Optional NPY file with sensor positions")
    parser.add_argument("--colors", type=Path, default=None, help="Optional NPY file with per-point RGB colors")
    parser.add_argument("--solver-iter", type=int, default=2000, help="PCG solver max iterations")
    parser.add_argument("--solver-tol", type=float, default=1e-5, help="PCG solver tolerance")
    parser.add_argument("--verbose", action="store_true", help="Print extra progress info")
    return parser


def _load_npy(parser: argparse.ArgumentParser, path: Path, what: str):
    """Load an optional ``.npy`` side file.

    A missing file is reported through ``parser.error`` (usage message and
    exit status 2), matching how a missing input file is handled, instead
    of surfacing a raw traceback from ``np.load``.
    """
    import numpy as np  # local import: numpy is only needed for side files

    if not path.exists():
        parser.error(f"{what} file not found: {path}")
    return np.load(path)


def main() -> None:
    """Parse CLI arguments, reconstruct a mesh with NKSR and save it."""
    parser = _build_parser()
    args = parser.parse_args()
    if not args.input.exists():
        parser.error(f"Input file not found: {args.input}")

    # ---- load point cloud -----------------------------------------------
    print(f"Loading point cloud from {args.input} ...")
    points, normals = load_point_cloud(
        args.input,
        # With --no-normals the loaded normals are discarded just below, so
        # asking the loader to estimate them would only be wasted work.
        estimate_normals=args.estimate_normals and not args.no_normals,
    )
    print(f" Loaded {len(points)} points")
    if args.no_normals:
        normals = None
        print(" --no-normals set: normals will be estimated")
    elif normals is not None:
        print(f" Normals present: {normals.shape}")
    else:
        print(" No normals found in file — will estimate on-the-fly")

    # ---- optional extras ------------------------------------------------
    sensor = None
    if args.sensor is not None:
        sensor = _load_npy(parser, args.sensor, "Sensor")
        print(f" Sensor positions loaded: {sensor.shape}")
    colors = None
    if args.colors is not None:
        colors = _load_npy(parser, args.colors, "Colors")
        # Colors are documented as per-point; catch a mismatch before the
        # (expensive) reconstruction rather than deep inside it.
        if len(colors) != len(points):
            parser.error(
                f"--colors has {len(colors)} rows but the point cloud "
                f"has {len(points)} points"
            )
        print(f" Per-point colors loaded: {colors.shape}")

    # ---- reconstruct ----------------------------------------------------
    print("\nInitialising NKSR reconstructor ...")
    if not torch.cuda.is_available() and args.device.startswith("cuda"):
        print("WARNING: CUDA not available, falling back to CPU (very slow)")
        args.device = "cpu"
    if args.verbose:
        print(
            f" Device: {args.device} | config: {args.config} | "
            f"detail: {args.detail} | voxel size: {args.voxel_size} | "
            f"chunk size: {args.chunk_size} | MISE iterations: {args.mise_iter} | "
            f"solver: max_iter={args.solver_iter}, tol={args.solver_tol}"
        )
    recon = NKSRMeshReconstructor(
        device=args.device,
        config=args.config,
        # Chunked reconstruction keeps intermediate chunks on the CPU.
        chunk_tmp_device="cpu" if args.chunk_size > 0 else None,
    )
    print("Reconstructing mesh ...")
    mesh = recon.reconstruct(
        points=points,
        normals=normals,
        sensor_positions=sensor,
        colors=colors,
        detail_level=args.detail,
        voxel_size=args.voxel_size,
        chunk_size=args.chunk_size,
        mise_iter=args.mise_iter,
        solver_max_iter=args.solver_iter,
        solver_tol=args.solver_tol,
    )

    # ---- save -----------------------------------------------------------
    print(f"\nSaving mesh to {args.output} ...")
    # Make sure the destination directory exists so a long reconstruction
    # is not lost to a missing output folder.
    args.output.parent.mkdir(parents=True, exist_ok=True)
    save_mesh(args.output, mesh.vertices, mesh.faces, mesh.vertex_colors)
    print(f" Vertices: {len(mesh.vertices):,} | Faces: {len(mesh.faces):,}")
    if mesh.vertex_colors is not None:
        print(" Vertex colors included")
    print("Done.")


if __name__ == "__main__":
    main()