"""Example: Segment a PLY point cloud with Point-SAM.

Usage:
    python segment_ply.py scene.ply 0.5 0.1 -0.2 --checkpoint model.safetensors
"""

import argparse
from pathlib import Path

import numpy as np


def _output_paths(output):
    """Return the ``(mask_path, ply_path)`` pair derived from ``--output``.

    ``np.save`` appends ``.npy`` when the given name does not already end with
    it, so the mask path is normalized with the same rule; this keeps the
    reported path and the derived PLY name in sync with the file actually
    written. The segmented PLY goes next to the mask as
    ``<mask name without .npy>_segmented.ply``. Only the final suffix is
    touched, so ``.npy`` appearing elsewhere in the path is left alone.
    """
    mask_name = output if output.endswith(".npy") else output + ".npy"
    mask_path = Path(mask_name)
    ply_path = mask_path.with_name(mask_path.name[: -len(".npy")] + "_segmented.ply")
    return mask_path, ply_path


def main():
    """Segment a point cloud around a single foreground prompt point.

    Parses the CLI arguments, loads the point cloud and the Point-SAM model,
    runs a multi-mask prediction, keeps the mask with the highest predicted
    IoU, and writes it as ``.npy`` plus the selected points as an ASCII PLY.
    """
    parser = argparse.ArgumentParser(description="Segment a point cloud with Point-SAM")
    parser.add_argument("ply_file", type=str, help="Path to .ply or .pcd file")
    parser.add_argument("px", type=float, help="Prompt point x")
    parser.add_argument("py", type=float, help="Prompt point y")
    parser.add_argument("pz", type=float, help="Prompt point z")
    parser.add_argument("--checkpoint", type=str, required=True, help="Path to .safetensors checkpoint")
    parser.add_argument("--variant", type=str, default="large", choices=["large", "giant"])
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--output", type=str, default="mask.npy", help="Where to save the boolean mask")
    args = parser.parse_args()

    # Imported only after argument parsing so `--help` and usage errors return
    # without importing the model package, and so `save_ascii_ply` can be
    # reused without it.
    from point_sam import PointSAM, load_pointcloud

    mask_path, ply_path = _output_paths(args.output)

    # 1. Load point cloud
    coords, rgb, original_coords = load_pointcloud(args.ply_file, normalize=True)
    print(f"Loaded {len(coords)} points from {args.ply_file}")

    # 2. Load model
    model = PointSAM.from_pretrained(
        checkpoint_path=args.checkpoint,
        variant=args.variant,
        device=args.device,
    )

    # 3. Cache the point cloud for fast repeated queries
    model.set_pointcloud(coords, rgb)

    # 4. Run segmentation
    prompt = [args.px, args.py, args.pz]
    masks, iou_scores = model.predict(
        coords=None,  # use cached
        rgb=None,  # use cached
        prompt_point=prompt,
        prompt_label=1,  # foreground
        multimask_output=True,
    )

    # 5. Pick best mask by IoU score
    best_idx = int(np.argmax(iou_scores))
    best_mask = masks[best_idx]
    print(f"IoU scores: {iou_scores}")
    print(f"Selected mask {best_idx} with {best_mask.sum()} points")

    # 6. Save
    np.save(mask_path, best_mask)
    print(f"Saved mask to {mask_path}")

    # 7. Optional: save segmented point cloud as PLY
    # NOTE(review): assumes best_mask is a boolean numpy array that can index
    # original_coords/rgb directly (not e.g. a device tensor) — confirm
    # against PointSAM.predict.
    segmented = original_coords[best_mask]
    # Colors are optional so a cloud loaded without them still exports.
    segmented_rgb = None if rgb is None else rgb[best_mask]
    save_ascii_ply(ply_path, segmented, segmented_rgb)
    print(f"Saved segmented point cloud to {ply_path}")


def save_ascii_ply(path, coords, rgb=None):
    """Save a point cloud as an ASCII PLY file.

    Args:
        path: Destination file path (str or path-like).
        coords: Sized iterable of ``(x, y, z)`` triples, one per point.
        rgb: Optional sized iterable of ``(r, g, b)`` triples, one per point.
            Each channel is written via ``int()`` as a PLY ``uchar``, so values
            must already be in 0-255 (float colors normalized to [0, 1] would
            need scaling first). When ``None``, color properties are omitted.

    Raises:
        ValueError: If ``rgb`` is given and its length differs from
            ``coords``. This is checked before the file is opened, so a
            mismatched call leaves no partial file behind.
    """
    if rgb is not None and len(rgb) != len(coords):
        # zip() would silently truncate, leaving the header's vertex count
        # inconsistent with the body.
        raise ValueError(
            f"coords and rgb must have the same length, got {len(coords)} and {len(rgb)}"
        )

    header = [
        "ply",
        "format ascii 1.0",
        f"element vertex {len(coords)}",
        "property float x",
        "property float y",
        "property float z",
    ]
    if rgb is not None:
        header += ["property uchar red", "property uchar green", "property uchar blue"]
    header.append("end_header")

    with open(path, "w", encoding="ascii") as f:
        f.write("\n".join(header) + "\n")
        if rgb is None:
            f.writelines(f"{x} {y} {z}\n" for x, y, z in coords)
        else:
            f.writelines(
                f"{x} {y} {z} {int(r)} {int(g)} {int(b)}\n"
                for (x, y, z), (r, g, b) in zip(coords, rgb)
            )


if __name__ == "__main__":
    main()