bdck commited on
Commit
728e958
·
verified ·
1 Parent(s): 0232680

Upload examples/segment_ply.py

Browse files
Files changed (1) hide show
  1. examples/segment_ply.py +85 -0
examples/segment_ply.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Example: Segment a PLY point cloud with Point-SAM.
2
+
3
+ Usage:
4
+ python segment_ply.py scene.ply 0.5 0.1 -0.2 --checkpoint model.safetensors
5
+ """
6
+
7
+ import argparse
8
+
9
+ import numpy as np
10
+ from point_sam import PointSAM, load_pointcloud
11
+
12
+
13
+ def main():
14
+ parser = argparse.ArgumentParser(description="Segment a point cloud with Point-SAM")
15
+ parser.add_argument("ply_file", type=str, help="Path to .ply or .pcd file")
16
+ parser.add_argument("px", type=float, help="Prompt point x")
17
+ parser.add_argument("py", type=float, help="Prompt point y")
18
+ parser.add_argument("pz", type=float, help="Prompt point z")
19
+ parser.add_argument("--checkpoint", type=str, required=True, help="Path to .safetensors checkpoint")
20
+ parser.add_argument("--variant", type=str, default="large", choices=["large", "giant"])
21
+ parser.add_argument("--device", type=str, default="cuda")
22
+ parser.add_argument("--output", type=str, default="mask.npy", help="Where to save the boolean mask")
23
+ args = parser.parse_args()
24
+
25
+ # 1. Load point cloud
26
+ coords, rgb, original_coords = load_pointcloud(args.ply_file, normalize=True)
27
+ print(f"Loaded {len(coords)} points from {args.ply_file}")
28
+
29
+ # 2. Load model
30
+ model = PointSAM.from_pretrained(
31
+ checkpoint_path=args.checkpoint,
32
+ variant=args.variant,
33
+ device=args.device,
34
+ )
35
+
36
+ # 3. Cache the point cloud for fast repeated queries
37
+ model.set_pointcloud(coords, rgb)
38
+
39
+ # 4. Run segmentation
40
+ prompt = [args.px, args.py, args.pz]
41
+ masks, iou_scores = model.predict(
42
+ coords=None, # use cached
43
+ rgb=None, # use cached
44
+ prompt_point=prompt,
45
+ prompt_label=1, # foreground
46
+ multimask_output=True,
47
+ )
48
+
49
+ # 5. Pick best mask by IoU score
50
+ best_idx = int(np.argmax(iou_scores))
51
+ best_mask = masks[best_idx]
52
+ print(f"IoU scores: {iou_scores}")
53
+ print(f"Selected mask {best_idx} with {best_mask.sum()} points")
54
+
55
+ # 6. Save
56
+ np.save(args.output, best_mask)
57
+ print(f"Saved mask to {args.output}")
58
+
59
+ # 7. Optional: save segmented point cloud as PLY
60
+ segmented = original_coords[best_mask]
61
+ segmented_rgb = rgb[best_mask]
62
+ out_ply = args.output.replace(".npy", "_segmented.ply")
63
+ save_ascii_ply(out_ply, segmented, segmented_rgb)
64
+ print(f"Saved segmented point cloud to {out_ply}")
65
+
66
+
67
+ def save_ascii_ply(path, coords, rgb):
68
+ """Save a point cloud as an ASCII PLY file."""
69
+ with open(path, "w") as f:
70
+ f.write("ply\n")
71
+ f.write("format ascii 1.0\n")
72
+ f.write(f"element vertex {len(coords)}\n")
73
+ f.write("property float x\n")
74
+ f.write("property float y\n")
75
+ f.write("property float z\n")
76
+ f.write("property uchar red\n")
77
+ f.write("property uchar green\n")
78
+ f.write("property uchar blue\n")
79
+ f.write("end_header\n")
80
+ for (x, y, z), (r, g, b) in zip(coords, rgb):
81
+ f.write(f"{x} {y} {z} {int(r)} {int(g)} {int(b)}\n")
82
+
83
+
84
+ if __name__ == "__main__":
85
+ main()