File size: 2,996 Bytes
728e958
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
"""Example: Segment a PLY point cloud with Point-SAM.

Usage:
    python segment_ply.py scene.ply 0.5 0.1 -0.2 --checkpoint model.safetensors
"""

import argparse
from pathlib import Path

import numpy as np
from point_sam import PointSAM, load_pointcloud


def main():
    """Command-line entry point: segment one object from a point cloud.

    Parses the CLI arguments, loads the point cloud and a Point-SAM
    checkpoint, and runs a single foreground point prompt with multimask
    output. It keeps the mask with the highest predicted IoU score and writes
    it to disk twice: as a ``.npy`` mask and as an ASCII PLY of the selected
    points.
    """
    parser = argparse.ArgumentParser(description="Segment a point cloud with Point-SAM")
    parser.add_argument("ply_file", type=str, help="Path to .ply or .pcd file")
    parser.add_argument("px", type=float, help="Prompt point x")
    parser.add_argument("py", type=float, help="Prompt point y")
    parser.add_argument("pz", type=float, help="Prompt point z")
    parser.add_argument("--checkpoint", type=str, required=True, help="Path to .safetensors checkpoint")
    parser.add_argument("--variant", type=str, default="large", choices=["large", "giant"])
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--output", type=str, default="mask.npy", help="Where to save the boolean mask")
    args = parser.parse_args()

    # np.save appends ".npy" to any file name that does not already end with
    # it. Resolve the real destination up front so the log messages and the
    # derived PLY name refer to the file that is actually written.
    mask_path = Path(args.output)
    if not mask_path.name.endswith(".npy"):
        mask_path = mask_path.with_name(mask_path.name + ".npy")
    # Swap only the trailing ".npy" of the file name. A str.replace over the
    # whole path would also rewrite ".npy" occurring in directory names.
    ply_path = mask_path.with_name(mask_path.name[: -len(".npy")] + "_segmented.ply")

    # 1. Load point cloud
    coords, rgb, original_coords = load_pointcloud(args.ply_file, normalize=True)
    print(f"Loaded {len(coords)} points from {args.ply_file}")

    # 2. Load model
    model = PointSAM.from_pretrained(
        checkpoint_path=args.checkpoint,
        variant=args.variant,
        device=args.device,
    )

    # 3. Cache the point cloud so predict() can be called without re-passing it
    model.set_pointcloud(coords, rgb)

    # 4. Run segmentation
    # NOTE(review): the cached coords were normalized (normalize=True) while the
    # prompt comes straight from the CLI. Confirm which coordinate frame
    # predict() expects the prompt point in.
    prompt = [args.px, args.py, args.pz]
    masks, iou_scores = model.predict(
        coords=None,  # use cached
        rgb=None,     # use cached
        prompt_point=prompt,
        prompt_label=1,  # foreground
        multimask_output=True,
    )

    # 5. Pick best mask by IoU score
    best_idx = int(np.argmax(iou_scores))
    best_mask = masks[best_idx]
    print(f"IoU scores: {iou_scores}")
    print(f"Selected mask {best_idx} with {best_mask.sum()} points")

    # 6. Save
    np.save(mask_path, best_mask)
    print(f"Saved mask to {mask_path}")

    # 7. Optional: save segmented point cloud as PLY.
    # Index original_coords (presumably the un-normalized input coordinates)
    # rather than coords. Boolean indexing assumes best_mask is a per-point
    # boolean array -- TODO confirm predict()'s mask dtype.
    segmented = original_coords[best_mask]
    segmented_rgb = rgb[best_mask]
    save_ascii_ply(ply_path, segmented, segmented_rgb)
    print(f"Saved segmented point cloud to {ply_path}")


def save_ascii_ply(path, coords, rgb):
    """Save a point cloud as an ASCII PLY file.

    Args:
        path: Destination file path (str or path-like). An existing file is
            overwritten.
        coords: Sequence of ``(x, y, z)`` triples, one per vertex. Each value
            is written with its default string form under a PLY ``float``
            property.
        rgb: Sequence of ``(r, g, b)`` triples aligned with ``coords``. Each
            value is truncated with ``int()`` and declared as a PLY ``uchar``.
            Values are not range-checked, so they should already lie in 0-255.

    Raises:
        ValueError: If ``coords`` and ``rgb`` have different lengths. The check
            runs before the file is opened, so nothing is written.
    """
    n_vertices = len(coords)
    if len(rgb) != n_vertices:
        # zip() would silently truncate to the shorter input. The file would
        # then contain fewer vertex lines than the header's "element vertex"
        # count, i.e. a malformed PLY.
        raise ValueError(
            f"coords and rgb must have the same length, got {n_vertices} and {len(rgb)}"
        )

    header = (
        "ply\n"
        "format ascii 1.0\n"
        f"element vertex {n_vertices}\n"
        "property float x\n"
        "property float y\n"
        "property float z\n"
        "property uchar red\n"
        "property uchar green\n"
        "property uchar blue\n"
        "end_header\n"
    )
    # newline="\n" keeps LF line endings on every platform; text mode would
    # otherwise emit CRLF on Windows. The content is plain ASCII.
    with open(path, "w", encoding="ascii", newline="\n") as fh:
        fh.write(header)
        fh.writelines(
            f"{x} {y} {z} {int(r)} {int(g)} {int(b)}\n"
            for (x, y, z), (r, g, b) in zip(coords, rgb)
        )


if __name__ == "__main__":
    main()