# learn_region_grow/scripts/segment.py
# Uploaded by bdck ("Upload scripts/segment.py"), commit 72dfa8b.
"""
segment.py
==========
Command-line script that runs learned region growing (LRGNet) on a PLY / PCD
point cloud and writes the per-point instance labels to a PLY file.

Example::

    python segment.py --input scene.ply --ckpt checkpoints/best_model.pt \\
        --output segmented_scene.ply --device cuda
"""
import argparse
import torch
import numpy as np
from pathlib import Path
from learn_region_grow.io import load_point_cloud, save_ply
from learn_region_grow.preprocess import voxel_equalize, compute_normals_and_curvature, build_feature_vector
from learn_region_grow.lrg_net import LrgNet
from learn_region_grow.growing import RegionGrower
def _parse_args():
    """Build the command-line parser, parse ``sys.argv`` and validate the result."""
    parser = argparse.ArgumentParser(description="LRGNet: Learned Region Growing on a Point Cloud")
    parser.add_argument("--input", required=True, help="Input .ply or .pcd file")
    parser.add_argument("--ckpt", required=True, help="PyTorch checkpoint (.pt)")
    parser.add_argument("--output", default="output.ply", help="Output segmented PLY")
    parser.add_argument("--device", default="cuda", help="cuda or cpu")
    parser.add_argument("--resolution", type=float, default=0.1, help="Voxel resolution in meters")
    parser.add_argument("--lite", type=int, default=0, choices=[0, 1, 2], help="Lite model variant")
    # Both flags write the same destination; the last one given on the command
    # line wins. The default lives in set_defaults() so that --stochastic is not
    # a confusing `store_true` whose default is already True.
    parser.add_argument("--stochastic", dest="stochastic", action="store_true",
                        help="Use stochastic growing (default)")
    parser.add_argument("--deterministic", dest="stochastic", action="store_false",
                        help="Use deterministic thresholding")
    parser.set_defaults(stochastic=True)
    parser.add_argument("--add_threshold", type=float, default=0.5)
    parser.add_argument("--remove_threshold", type=float, default=0.5)
    parser.add_argument("--cluster_threshold", type=int, default=10)
    args = parser.parse_args()
    if args.resolution <= 0:
        # Voxelization and the PCA neighbourhood both use this length; a
        # non-positive value cannot produce a meaningful result.
        parser.error(f"--resolution must be positive, got {args.resolution}")
    return args


def _select_device(requested):
    """Return ``torch.device(requested)`` if CUDA is available, else the CPU device.

    Same selection rule as before (any request falls back to CPU when CUDA is
    unavailable); the fallback is now announced instead of silent.
    """
    if torch.cuda.is_available():
        return torch.device(requested)
    if requested != "cpu":
        print(f"CUDA is not available; falling back to CPU (requested '{requested}').")
    return torch.device("cpu")


def _normals_and_curvature(eq_xyz, eq_idx, normals_input, resolution):
    """Return ``(normals, curvature)`` for the equalized points.

    Curvature always comes from ``compute_normals_and_curvature``. Normals come
    from the input file when it supplied them (taken in absolute value and
    subsampled with ``eq_idx``); otherwise the PCA normals are used as well.
    """
    if normals_input is None:
        print("Computing normals & curvature via PCA ...")
        return compute_normals_and_curvature(eq_xyz, resolution)
    normals = np.abs(normals_input[eq_idx])
    # The file provides normals but no curvature, so curvature is still
    # estimated via PCA; the PCA normals are discarded.
    _, curvature = compute_normals_and_curvature(eq_xyz, resolution)
    return normals, curvature


def _load_model(ckpt_path, lite, device):
    """Build an LrgNet, load its weights from *ckpt_path* and put it in eval mode."""
    # in_channels=13 is assumed to match build_feature_vector / the training
    # configuration -- TODO confirm if the feature layout ever changes.
    model = LrgNet(in_channels=13, lite=lite)
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints from trusted sources.
    state_dict = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    # Inference: disable dropout and use running batch-norm statistics.
    model.eval()
    return model


def _propagate_labels(eq_xyz, labels, xyz):
    """Give every original point the label of its nearest equalized point."""
    from scipy.spatial import cKDTree

    _, nearest = cKDTree(eq_xyz).query(xyz)
    return labels[nearest]


def main():
    """Segment a point cloud file with learned region growing.

    Steps:
    1. Load the cloud.
    2. Voxel-equalize it.
    3. Estimate normals and curvature.
    4. Build per-point features.
    5. Run the LrgNet-driven RegionGrower.
    6. Propagate labels back to every input point.
    7. Write the labelled PLY given by ``--output``.

    Exits with an error message if the input cloud contains no points.
    """
    args = _parse_args()
    device = _select_device(args.device)

    print(f"Loading point cloud from {args.input} ...")
    xyz, rgb, normals_input = load_point_cloud(args.input)
    print(f" {len(xyz)} points loaded")
    if len(xyz) == 0:
        raise SystemExit(f"No points found in {args.input}; nothing to segment.")

    # Voxel equalization
    print(f"Voxel equalization (resolution={args.resolution}m) ...")
    eq_xyz, eq_idx, voxel_map = voxel_equalize(xyz, args.resolution)
    eq_rgb = rgb[eq_idx] if rgb is not None else None
    print(f" {len(eq_xyz)} points after equalization")

    # Normals / curvature
    normals, curvature = _normals_and_curvature(eq_xyz, eq_idx, normals_input, args.resolution)

    # Build feature vector
    features = build_feature_vector(eq_xyz, eq_rgb, normals, curvature)
    print(f"Feature vector shape: {features.shape}")

    # Load model
    print(f"Loading checkpoint {args.ckpt} ...")
    model = _load_model(args.ckpt, args.lite, device)
    print("Model loaded.")

    # Run region growing
    grower = RegionGrower(
        model, device,
        add_threshold=args.add_threshold,
        remove_threshold=args.remove_threshold,
        cluster_threshold=args.cluster_threshold,
        stochastic=args.stochastic,
    )
    print("Running learned region growing ...")
    # Pure inference: no autograd graph is needed, which saves memory.
    with torch.no_grad():
        labels = grower.grow(eq_xyz, features, voxel_map, args.resolution)
    # Negative labels are excluded from the instance count.
    n_instances = len(np.unique(labels[labels >= 0]))
    print(f"Segmented into {n_instances} instances")

    # Expand labels back to the original dense cloud via nearest equalized neighbour.
    full_labels = _propagate_labels(eq_xyz, labels, xyz)

    # Save
    print(f"Saving output to {args.output} ...")
    save_ply(args.output, xyz, labels=full_labels)
    print("Done.")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()