""" segment.py ========== Command-line script to run learned region growing on a PLY / PCD point cloud. Example: python segment.py --input scene.ply --ckpt checkpoints/best_model.pt \ --output segmented_scene.ply --device cuda """ import argparse import torch import numpy as np from pathlib import Path from learn_region_grow.io import load_point_cloud, save_ply from learn_region_grow.preprocess import voxel_equalize, compute_normals_and_curvature, build_feature_vector from learn_region_grow.lrg_net import LrgNet from learn_region_grow.growing import RegionGrower def main(): parser = argparse.ArgumentParser(description="LRGNet: Learned Region Growing on a Point Cloud") parser.add_argument("--input", required=True, help="Input .ply or .pcd file") parser.add_argument("--ckpt", required=True, help="PyTorch checkpoint (.pt)") parser.add_argument("--output", default="output.ply", help="Output segmented PLY") parser.add_argument("--device", default="cuda", help="cuda or cpu") parser.add_argument("--resolution", type=float, default=0.1, help="Voxel resolution in meters") parser.add_argument("--lite", type=int, default=0, choices=[0,1,2], help="Lite model variant") parser.add_argument("--stochastic", action="store_true", default=True, help="Use stochastic growing (default)") parser.add_argument("--deterministic", dest="stochastic", action="store_false", help="Use deterministic thresholding") parser.add_argument("--add_threshold", type=float, default=0.5) parser.add_argument("--remove_threshold", type=float, default=0.5) parser.add_argument("--cluster_threshold", type=int, default=10) args = parser.parse_args() device = torch.device(args.device if torch.cuda.is_available() else 'cpu') print(f"Loading point cloud from {args.input} ...") xyz, rgb, normals_input = load_point_cloud(args.input) print(f" {len(xyz)} points loaded") # Voxel equalization print(f"Voxel equalization (resolution={args.resolution}m) ...") eq_xyz, eq_idx, voxel_map = voxel_equalize(xyz, args.resolution) eq_rgb = rgb[eq_idx] if rgb is not None else None print(f" {len(eq_xyz)} points after equalization") # Normals / curvature if normals_input is not None: normals = np.abs(normals_input[eq_idx]) curvature = np.zeros(len(eq_xyz), dtype=np.float32) # could estimate from normals # Better: still compute curvature from PCA _, curvature = compute_normals_and_curvature(eq_xyz, args.resolution) else: print("Computing normals & curvature via PCA ...") normals, curvature = compute_normals_and_curvature(eq_xyz, args.resolution) # Build feature vector features = build_feature_vector(eq_xyz, eq_rgb, normals, curvature) print(f"Feature vector shape: {features.shape}") # Load model print(f"Loading checkpoint {args.ckpt} ...") model = LrgNet(in_channels=13, lite=args.lite) model.load_state_dict(torch.load(args.ckpt, map_location=device)) model.to(device) print("Model loaded.") # Run region growing grower = RegionGrower( model, device, add_threshold=args.add_threshold, remove_threshold=args.remove_threshold, cluster_threshold=args.cluster_threshold, stochastic=args.stochastic, ) print("Running learned region growing ...") labels = grower.grow(eq_xyz, features, voxel_map, args.resolution) n_instances = len(np.unique(labels[labels >= 0])) print(f"Segmented into {n_instances} instances") # Expand labels back to original dense cloud via nearest equalized neighbor from scipy.spatial import cKDTree tree = cKDTree(eq_xyz) _, nn = tree.query(xyz) full_labels = labels[nn] # Save print(f"Saving output to {args.output} ...") save_ply(args.output, xyz, labels=full_labels) print("Done.") if __name__ == "__main__": main()