Upload scripts/segment.py
Browse files- scripts/segment.py +96 -0
scripts/segment.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
segment.py
|
| 3 |
+
==========
|
| 4 |
+
|
| 5 |
+
Command-line script to run learned region growing on a PLY / PCD point cloud.
|
| 6 |
+
|
| 7 |
+
Example:
|
| 8 |
+
python segment.py --input scene.ply --ckpt checkpoints/best_model.pt \
|
| 9 |
+
--output segmented_scene.ply --device cuda
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import argparse
|
| 13 |
+
import torch
|
| 14 |
+
import numpy as np
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
|
| 17 |
+
from learn_region_grow.io import load_point_cloud, save_ply
|
| 18 |
+
from learn_region_grow.preprocess import voxel_equalize, compute_normals_and_curvature, build_feature_vector
|
| 19 |
+
from learn_region_grow.lrg_net import LrgNet
|
| 20 |
+
from learn_region_grow.growing import RegionGrower
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def main():
|
| 24 |
+
parser = argparse.ArgumentParser(description="LRGNet: Learned Region Growing on a Point Cloud")
|
| 25 |
+
parser.add_argument("--input", required=True, help="Input .ply or .pcd file")
|
| 26 |
+
parser.add_argument("--ckpt", required=True, help="PyTorch checkpoint (.pt)")
|
| 27 |
+
parser.add_argument("--output", default="output.ply", help="Output segmented PLY")
|
| 28 |
+
parser.add_argument("--device", default="cuda", help="cuda or cpu")
|
| 29 |
+
parser.add_argument("--resolution", type=float, default=0.1, help="Voxel resolution in meters")
|
| 30 |
+
parser.add_argument("--lite", type=int, default=0, choices=[0,1,2], help="Lite model variant")
|
| 31 |
+
parser.add_argument("--stochastic", action="store_true", default=True, help="Use stochastic growing (default)")
|
| 32 |
+
parser.add_argument("--deterministic", dest="stochastic", action="store_false", help="Use deterministic thresholding")
|
| 33 |
+
parser.add_argument("--add_threshold", type=float, default=0.5)
|
| 34 |
+
parser.add_argument("--remove_threshold", type=float, default=0.5)
|
| 35 |
+
parser.add_argument("--cluster_threshold", type=int, default=10)
|
| 36 |
+
args = parser.parse_args()
|
| 37 |
+
|
| 38 |
+
device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
|
| 39 |
+
print(f"Loading point cloud from {args.input} ...")
|
| 40 |
+
xyz, rgb, normals_input = load_point_cloud(args.input)
|
| 41 |
+
print(f" {len(xyz)} points loaded")
|
| 42 |
+
|
| 43 |
+
# Voxel equalization
|
| 44 |
+
print(f"Voxel equalization (resolution={args.resolution}m) ...")
|
| 45 |
+
eq_xyz, eq_idx, voxel_map = voxel_equalize(xyz, args.resolution)
|
| 46 |
+
eq_rgb = rgb[eq_idx] if rgb is not None else None
|
| 47 |
+
print(f" {len(eq_xyz)} points after equalization")
|
| 48 |
+
|
| 49 |
+
# Normals / curvature
|
| 50 |
+
if normals_input is not None:
|
| 51 |
+
normals = np.abs(normals_input[eq_idx])
|
| 52 |
+
curvature = np.zeros(len(eq_xyz), dtype=np.float32) # could estimate from normals
|
| 53 |
+
# Better: still compute curvature from PCA
|
| 54 |
+
_, curvature = compute_normals_and_curvature(eq_xyz, args.resolution)
|
| 55 |
+
else:
|
| 56 |
+
print("Computing normals & curvature via PCA ...")
|
| 57 |
+
normals, curvature = compute_normals_and_curvature(eq_xyz, args.resolution)
|
| 58 |
+
|
| 59 |
+
# Build feature vector
|
| 60 |
+
features = build_feature_vector(eq_xyz, eq_rgb, normals, curvature)
|
| 61 |
+
print(f"Feature vector shape: {features.shape}")
|
| 62 |
+
|
| 63 |
+
# Load model
|
| 64 |
+
print(f"Loading checkpoint {args.ckpt} ...")
|
| 65 |
+
model = LrgNet(in_channels=13, lite=args.lite)
|
| 66 |
+
model.load_state_dict(torch.load(args.ckpt, map_location=device))
|
| 67 |
+
model.to(device)
|
| 68 |
+
print("Model loaded.")
|
| 69 |
+
|
| 70 |
+
# Run region growing
|
| 71 |
+
grower = RegionGrower(
|
| 72 |
+
model, device,
|
| 73 |
+
add_threshold=args.add_threshold,
|
| 74 |
+
remove_threshold=args.remove_threshold,
|
| 75 |
+
cluster_threshold=args.cluster_threshold,
|
| 76 |
+
stochastic=args.stochastic,
|
| 77 |
+
)
|
| 78 |
+
print("Running learned region growing ...")
|
| 79 |
+
labels = grower.grow(eq_xyz, features, voxel_map, args.resolution)
|
| 80 |
+
n_instances = len(np.unique(labels[labels >= 0]))
|
| 81 |
+
print(f"Segmented into {n_instances} instances")
|
| 82 |
+
|
| 83 |
+
# Expand labels back to original dense cloud via nearest equalized neighbor
|
| 84 |
+
from scipy.spatial import cKDTree
|
| 85 |
+
tree = cKDTree(eq_xyz)
|
| 86 |
+
_, nn = tree.query(xyz)
|
| 87 |
+
full_labels = labels[nn]
|
| 88 |
+
|
| 89 |
+
# Save
|
| 90 |
+
print(f"Saving output to {args.output} ...")
|
| 91 |
+
save_ply(args.output, xyz, labels=full_labels)
|
| 92 |
+
print("Done.")
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
if __name__ == "__main__":
|
| 96 |
+
main()
|