| """Example: Segment a PLY point cloud with Point-SAM. |
| |
| Usage: |
| python segment_ply.py scene.ply 0.5 0.1 -0.2 --checkpoint model.safetensors |
| """ |
|
|
| import argparse |
|
|
| import numpy as np |
| from point_sam import PointSAM, load_pointcloud |
|
|
|
|
def _build_parser():
    """Return the argument parser for this script's command-line interface."""
    parser = argparse.ArgumentParser(description="Segment a point cloud with Point-SAM")
    parser.add_argument("ply_file", type=str, help="Path to .ply or .pcd file")
    parser.add_argument("px", type=float, help="Prompt point x")
    parser.add_argument("py", type=float, help="Prompt point y")
    parser.add_argument("pz", type=float, help="Prompt point z")
    parser.add_argument(
        "--checkpoint", type=str, required=True, help="Path to .safetensors checkpoint"
    )
    parser.add_argument("--variant", type=str, default="large", choices=["large", "giant"])
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument(
        "--output", type=str, default="mask.npy", help="Where to save the boolean mask"
    )
    return parser


def _output_paths(output: str):
    """Return ``(mask_path, ply_path)``, the files this script actually writes.

    ``np.save`` appends ``.npy`` to a filename that does not already end with
    it, so the mask path is normalised the same way up front. This keeps the
    printed path truthful. The segmented PLY path replaces only that trailing
    ``.npy`` with ``_segmented.ply``. A plain ``str.replace`` would also
    rewrite ``.npy`` elsewhere in the path, and it would give an
    extensionless ``--output`` a PLY path without a ``.ply`` suffix.
    """
    suffix = ".npy"
    mask_path = output if output.endswith(suffix) else output + suffix
    ply_path = mask_path[: -len(suffix)] + "_segmented.ply"
    return mask_path, ply_path


def main():
    """Segment a point cloud around one prompt point and save the result.

    Steps:
    1. Load the cloud given on the command line.
    2. Run Point-SAM with the prompt ``(px, py, pz)`` and label 1, requesting
       multiple candidate masks.
    3. Keep the mask with the highest predicted IoU score.
    4. Save that mask as ``.npy``, plus an ASCII PLY of the selected points
       in their original (un-normalised) coordinates.
    """
    args = _build_parser().parse_args()
    mask_path, ply_path = _output_paths(args.output)

    coords, rgb, original_coords = load_pointcloud(args.ply_file, normalize=True)
    print(f"Loaded {len(coords)} points from {args.ply_file}")

    model = PointSAM.from_pretrained(
        checkpoint_path=args.checkpoint,
        variant=args.variant,
        device=args.device,
    )
    model.set_pointcloud(coords, rgb)

    # NOTE(review): the model was given coords loaded with normalize=True.
    # Confirm whether predict() expects the prompt in that normalised frame
    # or in the original coordinates of the input file.
    prompt = [args.px, args.py, args.pz]
    # coords/rgb are None here; presumably predict() then reuses the cloud
    # registered via set_pointcloud() above. TODO confirm.
    masks, iou_scores = model.predict(
        coords=None,
        rgb=None,
        prompt_point=prompt,
        prompt_label=1,
        multimask_output=True,
    )

    # NOTE(review): assumes masks/iou_scores are NumPy-compatible (indexable,
    # np.argmax-able, np.save-able). Verify if predict() returns device tensors.
    best_idx = int(np.argmax(iou_scores))
    best_mask = masks[best_idx]
    print(f"IoU scores: {iou_scores}")
    print(f"Selected mask {best_idx} with {best_mask.sum()} points")

    np.save(mask_path, best_mask)
    print(f"Saved mask to {mask_path}")

    # Index the un-normalised coordinates so the exported PLY lines up with
    # the input file.
    segmented = original_coords[best_mask]
    segmented_rgb = rgb[best_mask]
    save_ascii_ply(ply_path, segmented, segmented_rgb)
    print(f"Saved segmented point cloud to {ply_path}")
|
|
|
|
def save_ascii_ply(path, coords, rgb):
    """Save a point cloud as an ASCII PLY file.

    Args:
        path: Destination file path. An existing file is overwritten.
        coords: Sequence of ``(x, y, z)`` rows, written as ``float`` vertex
            properties.
        rgb: Sequence of ``(r, g, b)`` rows, one per row of ``coords``. Each
            value is truncated with ``int()`` and declared as ``uchar``.
            Values are presumably expected in 0-255; confirm against the
            loader.

    Raises:
        ValueError: If ``coords`` and ``rgb`` have different lengths.
    """
    n_points = len(coords)
    if len(rgb) != n_points:
        # The header declares n_points vertices, and zip() would silently drop
        # unmatched rows. The resulting file's body would then contradict its
        # header.
        raise ValueError(
            f"coords and rgb must have the same length, got {n_points} and {len(rgb)}"
        )

    header = (
        "ply\n"
        "format ascii 1.0\n"
        f"element vertex {n_points}\n"
        "property float x\n"
        "property float y\n"
        "property float z\n"
        "property uchar red\n"
        "property uchar green\n"
        "property uchar blue\n"
        "end_header\n"
    )
    # newline="\n" keeps LF line endings on every platform. Plain text mode
    # would emit CRLF on Windows, which strict PLY readers reject.
    with open(path, "w", encoding="ascii", newline="\n") as f:
        f.write(header)
        f.writelines(
            f"{x} {y} {z} {int(r)} {int(g)} {int(b)}\n"
            for (x, y, z), (r, g, b) in zip(coords, rgb)
        )
|
|
|
|
if __name__ == "__main__":
    # Run the command-line entry point only when executed as a script, not on import.
    main()
|
|