bdck commited on
Commit
72dfa8b
·
verified ·
1 Parent(s): 859d9ba

Upload scripts/segment.py

Browse files
Files changed (1) hide show
  1. scripts/segment.py +96 -0
scripts/segment.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ segment.py
3
+ ==========
4
+
5
+ Command-line script to run learned region growing on a PLY / PCD point cloud.
6
+
7
+ Example:
8
+ python segment.py --input scene.ply --ckpt checkpoints/best_model.pt \
9
+ --output segmented_scene.ply --device cuda
10
+ """
11
+
12
+ import argparse
13
+ import torch
14
+ import numpy as np
15
+ from pathlib import Path
16
+
17
+ from learn_region_grow.io import load_point_cloud, save_ply
18
+ from learn_region_grow.preprocess import voxel_equalize, compute_normals_and_curvature, build_feature_vector
19
+ from learn_region_grow.lrg_net import LrgNet
20
+ from learn_region_grow.growing import RegionGrower
21
+
22
+
23
+ def main():
24
+ parser = argparse.ArgumentParser(description="LRGNet: Learned Region Growing on a Point Cloud")
25
+ parser.add_argument("--input", required=True, help="Input .ply or .pcd file")
26
+ parser.add_argument("--ckpt", required=True, help="PyTorch checkpoint (.pt)")
27
+ parser.add_argument("--output", default="output.ply", help="Output segmented PLY")
28
+ parser.add_argument("--device", default="cuda", help="cuda or cpu")
29
+ parser.add_argument("--resolution", type=float, default=0.1, help="Voxel resolution in meters")
30
+ parser.add_argument("--lite", type=int, default=0, choices=[0,1,2], help="Lite model variant")
31
+ parser.add_argument("--stochastic", action="store_true", default=True, help="Use stochastic growing (default)")
32
+ parser.add_argument("--deterministic", dest="stochastic", action="store_false", help="Use deterministic thresholding")
33
+ parser.add_argument("--add_threshold", type=float, default=0.5)
34
+ parser.add_argument("--remove_threshold", type=float, default=0.5)
35
+ parser.add_argument("--cluster_threshold", type=int, default=10)
36
+ args = parser.parse_args()
37
+
38
+ device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
39
+ print(f"Loading point cloud from {args.input} ...")
40
+ xyz, rgb, normals_input = load_point_cloud(args.input)
41
+ print(f" {len(xyz)} points loaded")
42
+
43
+ # Voxel equalization
44
+ print(f"Voxel equalization (resolution={args.resolution}m) ...")
45
+ eq_xyz, eq_idx, voxel_map = voxel_equalize(xyz, args.resolution)
46
+ eq_rgb = rgb[eq_idx] if rgb is not None else None
47
+ print(f" {len(eq_xyz)} points after equalization")
48
+
49
+ # Normals / curvature
50
+ if normals_input is not None:
51
+ normals = np.abs(normals_input[eq_idx])
52
+ curvature = np.zeros(len(eq_xyz), dtype=np.float32) # could estimate from normals
53
+ # Better: still compute curvature from PCA
54
+ _, curvature = compute_normals_and_curvature(eq_xyz, args.resolution)
55
+ else:
56
+ print("Computing normals & curvature via PCA ...")
57
+ normals, curvature = compute_normals_and_curvature(eq_xyz, args.resolution)
58
+
59
+ # Build feature vector
60
+ features = build_feature_vector(eq_xyz, eq_rgb, normals, curvature)
61
+ print(f"Feature vector shape: {features.shape}")
62
+
63
+ # Load model
64
+ print(f"Loading checkpoint {args.ckpt} ...")
65
+ model = LrgNet(in_channels=13, lite=args.lite)
66
+ model.load_state_dict(torch.load(args.ckpt, map_location=device))
67
+ model.to(device)
68
+ print("Model loaded.")
69
+
70
+ # Run region growing
71
+ grower = RegionGrower(
72
+ model, device,
73
+ add_threshold=args.add_threshold,
74
+ remove_threshold=args.remove_threshold,
75
+ cluster_threshold=args.cluster_threshold,
76
+ stochastic=args.stochastic,
77
+ )
78
+ print("Running learned region growing ...")
79
+ labels = grower.grow(eq_xyz, features, voxel_map, args.resolution)
80
+ n_instances = len(np.unique(labels[labels >= 0]))
81
+ print(f"Segmented into {n_instances} instances")
82
+
83
+ # Expand labels back to original dense cloud via nearest equalized neighbor
84
+ from scipy.spatial import cKDTree
85
+ tree = cKDTree(eq_xyz)
86
+ _, nn = tree.query(xyz)
87
+ full_labels = labels[nn]
88
+
89
+ # Save
90
+ print(f"Saving output to {args.output} ...")
91
+ save_ply(args.output, xyz, labels=full_labels)
92
+ print("Done.")
93
+
94
+
95
+ if __name__ == "__main__":
96
+ main()