"""Example: Segment a PLY point cloud with Point-SAM.

Usage:
    python segment_ply.py scene.ply 0.5 0.1 -0.2 --checkpoint model.safetensors
"""
import argparse
import numpy as np
from point_sam import PointSAM, load_pointcloud
def main():
    """Segment a point cloud from one prompt point and save the results.

    Parses the command line, loads the point cloud and the Point-SAM
    checkpoint, runs a single foreground-prompt prediction, picks the mask
    with the highest predicted IoU, and writes two files: the mask as
    ``.npy`` and the selected points as an ASCII PLY named
    ``<output stem>_segmented.ply``.
    """
    parser = argparse.ArgumentParser(description="Segment a point cloud with Point-SAM")
    parser.add_argument("ply_file", type=str, help="Path to .ply or .pcd file")
    parser.add_argument("px", type=float, help="Prompt point x")
    parser.add_argument("py", type=float, help="Prompt point y")
    parser.add_argument("pz", type=float, help="Prompt point z")
    parser.add_argument("--checkpoint", type=str, required=True, help="Path to .safetensors checkpoint")
    parser.add_argument("--variant", type=str, default="large", choices=["large", "giant"])
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--output", type=str, default="mask.npy", help="Where to save the boolean mask")
    args = parser.parse_args()

    # 1. Load point cloud
    # NOTE(review): coords are loaded with normalize=True, so the prompt point
    # presumably has to be given in that normalized frame — confirm against
    # load_pointcloud.
    coords, rgb, original_coords = load_pointcloud(args.ply_file, normalize=True)
    print(f"Loaded {len(coords)} points from {args.ply_file}")

    # 2. Load model
    model = PointSAM.from_pretrained(
        checkpoint_path=args.checkpoint,
        variant=args.variant,
        device=args.device,
    )

    # 3. Cache the point cloud for fast repeated queries
    model.set_pointcloud(coords, rgb)

    # 4. Run segmentation
    prompt = [args.px, args.py, args.pz]
    masks, iou_scores = model.predict(
        coords=None,  # use cached
        rgb=None,  # use cached
        prompt_point=prompt,
        prompt_label=1,  # foreground
        multimask_output=True,
    )

    # 5. Pick best mask by IoU score
    best_idx = int(np.argmax(iou_scores))
    best_mask = masks[best_idx]
    print(f"IoU scores: {iou_scores}")
    print(f"Selected mask {best_idx} with {best_mask.sum()} points")

    # 6. Save
    # np.save appends ".npy" when the file name lacks it, so resolve the real
    # file name up front; the log line and the PLY name both derive from it.
    npy_suffix = ".npy"
    if args.output.endswith(npy_suffix):
        mask_path = args.output
    else:
        mask_path = args.output + npy_suffix
    np.save(mask_path, best_mask)
    print(f"Saved mask to {mask_path}")

    # 7. Optional: save segmented point cloud as PLY
    # The mask indexes the un-normalized coordinates, so the exported points
    # are in the input file's own frame.
    segmented = original_coords[best_mask]
    # NOTE(review): the PLY writer stores colours through int() as uchar; this
    # assumes rgb holds 0-255 values rather than 0-1 floats — confirm.
    segmented_rgb = rgb[best_mask]
    # Strip only the trailing extension: str.replace would also rewrite a
    # ".npy" occurring earlier in the path (e.g. inside a directory name).
    out_ply = mask_path[: -len(npy_suffix)] + "_segmented.ply"
    save_ascii_ply(out_ply, segmented, segmented_rgb)
    print(f"Saved segmented point cloud to {out_ply}")
def save_ascii_ply(path, coords, rgb):
    """Save a coloured point cloud as an ASCII PLY file.

    Args:
        path: Destination file; it is created or overwritten.
        coords: Sized sequence of ``(x, y, z)`` triples, written with their
            default string formatting.
        rgb: Sized sequence of ``(r, g, b)`` triples, one per point. Each
            channel is passed through ``int()`` before being written.

    Raises:
        ValueError: If ``coords`` and ``rgb`` differ in length. The header
            declares ``len(coords)`` vertices, so a shorter body would make
            the file inconsistent.
    """
    num_points = len(coords)
    # Validate before opening the file so a bad call never leaves a truncated
    # or inconsistent PLY on disk.
    if len(rgb) != num_points:
        raise ValueError(
            f"coords and rgb must have the same length, got {num_points} and {len(rgb)}"
        )

    header_lines = (
        "ply",
        "format ascii 1.0",
        f"element vertex {num_points}",
        "property float x",
        "property float y",
        "property float z",
        "property uchar red",
        "property uchar green",
        "property uchar blue",
        "end_header",
    )

    with open(path, "w") as f:
        f.write("\n".join(header_lines) + "\n")
        # One vertex per line: position followed by integer colour channels.
        f.writelines(
            f"{x} {y} {z} {int(r)} {int(g)} {int(b)}\n"
            for (x, y, z), (r, g, b) in zip(coords, rgb)
        )
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()