---
tags:
- ml-intern
---
# Point-SAM: Promptable 3D Segmentation

A clean, self-contained Python inference package for **Point-SAM** (ICLR 2025), extending SAM's promptable segmentation to 3D point clouds.

> **Paper**: [Point-SAM: Promptable 3D Segmentation Model for Point Clouds](https://arxiv.org/abs/2406.17741)  
> **Original Code**: [github.com/zyc00/Point-SAM](https://github.com/zyc00/Point-SAM)  
> **Pretrained Weights**: [`yuchen0187/Point-SAM`](https://huggingface.co/yuchen0187/Point-SAM)

---

## Quick Start

```bash
pip install torch timm safetensors huggingface_hub numpy
```

```python
from point_sam import PointSAM, load_pointcloud

# 1. Load a point cloud (PLY or PCD)
coords, rgb, original = load_pointcloud("scene.ply")
# coords: [N, 3] normalized to [-1, 1]
# rgb:    [N, 3] in [0, 255]

# 2. Load the pretrained model (downloads weights from HF Hub)
model = PointSAM.from_pretrained(checkpoint_path="model.safetensors", device="cuda")

# 3. Cache the cloud for fast repeated queries
model.set_pointcloud(coords, rgb)

# 4. Segment with a prompt point (in normalized [-1, 1] space)
masks, iou_scores = model.predict(
    coords=None,           # use cached cloud
    rgb=None,
    prompt_point=[0.5, 0.1, -0.2],
    prompt_label=1,        # 1 = foreground, 0 = background
    multimask_output=True,
)

# 5. Pick the best mask by IoU score
best_mask = masks[iou_scores.argmax()]   # [N] boolean
```

Command-line example:

```bash
python examples/segment_ply.py scene.ply 0.5 0.1 -0.2 --checkpoint model.safetensors
```

---

## How It Works Internally

Point-SAM is a direct 3D adaptation of [SAM](https://github.com/facebookresearch/segment-anything). It has the same three-part architecture, but replaces the 2D image backbone with a **point cloud encoder**.

### 1. Point-Cloud Encoder

The encoder turns an unstructured point cloud into a compact set of **patch embeddings**, the 3D equivalent of image patches.

**Voronoi Tokenizer** (the key speed trick)
- Sample `G` center points from the cloud via **Farthest Point Sampling** (FPS). This spreads centers evenly across the shape.
- For each center, gather its **K nearest neighbors** to form one local group.
- Run a small **PointNet-style MLP** on each group:
  - Input: relative XYZ positions + RGB colors
  - Max-pool over the K neighbors → one vector per group
- Result: `G` patch embeddings, each summarizing a local neighborhood.
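
A minimal sketch of this tokenization step in plain PyTorch; the function names and the `[G, K, 6]` feature layout (relative XYZ + RGB) are illustrative assumptions, not the package's exact API:

```python
import torch

def farthest_point_sampling(xyz: torch.Tensor, num_groups: int) -> torch.Tensor:
    """Pick `num_groups` well-spread center indices from an [N, 3] cloud."""
    n = xyz.shape[0]
    centers = torch.zeros(num_groups, dtype=torch.long, device=xyz.device)
    dist = torch.full((n,), float("inf"), device=xyz.device)
    farthest = int(torch.randint(0, n, (1,)))
    for i in range(num_groups):
        centers[i] = farthest
        # Each point keeps its distance to the nearest center chosen so far.
        dist = torch.minimum(dist, ((xyz - xyz[farthest]) ** 2).sum(-1))
        farthest = int(dist.argmax())
    return centers

def group_patches(xyz, rgb, num_groups=1024, group_size=256):
    """FPS centers + KNN grouping around each center."""
    center_idx = farthest_point_sampling(xyz, num_groups)
    centers = xyz[center_idx]                                                     # [G, 3]
    knn_idx = torch.cdist(centers, xyz).topk(group_size, largest=False).indices   # [G, K]
    rel_xyz = xyz[knn_idx] - centers[:, None, :]              # [G, K, 3] positions relative to center
    feats = torch.cat([rel_xyz, rgb[knn_idx]], dim=-1)        # [G, K, 6] input to the PointNet-style MLP
    return feats, centers
```

A small shared MLP then embeds each `[K, 6]` group and max-pools over `K`, producing the `G` patch embeddings that feed the ViT.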

**Vision Transformer (ViT) backbone**
- The patch embeddings are fed into a standard ViT: `eva02_large_patch14_448` for the *large* variant, or `eva_giant_patch14_560` for *giant*.
- The ViT adds learned positional embeddings based on the 3D center coordinates and runs self-attention to build a global scene representation.
- Output: `[B, num_patches, D]` embedding tensor (default `D = 256`).

### 2. Prompt Encoder

- **Point prompts**: A user clicks (or specifies) a 3D coordinate. The coordinate is mapped through a random Fourier positional encoding (same Gaussian-frequency trick SAM uses) and then a learned embedding is added depending on whether the label is **positive** (foreground) or **negative** (background).
- **Mask prompts** (optional): If you already have a rough mask from a previous iteration, it is grouped into patches (same KNN grouping as the encoder) and encoded into dense embeddings. On the first call this is `None`, so a learned "no mask" embedding is used instead.
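
For intuition, here is a minimal sketch of a SAM-style random Fourier encoding of a 3D click; the class name, dimension, and scale are illustrative assumptions:

```python
import torch

class RandomFourierPE(torch.nn.Module):
    """Gaussian random Fourier features for coordinates in [-1, 1] (illustrative sketch)."""
    def __init__(self, dim: int = 256, scale: float = 1.0):
        super().__init__()
        # Frequencies are drawn once and kept fixed, as in SAM's positional encoding.
        self.register_buffer("freqs", torch.randn(3, dim // 2) * scale)

    def forward(self, coords: torch.Tensor) -> torch.Tensor:
        proj = 2 * torch.pi * coords @ self.freqs             # [..., dim // 2]
        return torch.cat([proj.sin(), proj.cos()], dim=-1)    # [..., dim]

pe = RandomFourierPE()
click = torch.tensor([[0.5, 0.1, -0.2]])   # one prompt point in normalized space
prompt_token = pe(click)                   # a learned foreground/background embedding is added on top
```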

### 3. Mask Decoder

The decoder is a **two-way transformer**, identical in spirit to SAM's decoder:

1. **Cross-attention layers** alternate between:
   - *Prompt tokens → point cloud patches* (the prompts "look at" the scene)
   - *Point cloud patches → prompt tokens* (the scene "looks back" at the prompts)
2. After 2 layers, a **final attention** from prompts to patches refines the token representation.
3. **Upsampling**: The decoder works at patch resolution. To get back to per-point logits, features are interpolated to every original point using **inverse-distance weighted KNN** (3 nearest patch centers).
4. **Hypernetwork MLPs**: Each candidate mask has its own tiny MLP that produces a dynamic weight vector. This vector is dot-producted with the upsampled per-point features to produce the final mask logits.
5. **IoU head**: A small MLP on the IoU token predicts the quality of each mask candidate. At inference time you simply pick the one with the highest predicted IoU.

The decoder always outputs **4 candidates** (1 default + 3 multimask). The first candidate is a "safe" single mask; the other three are alternatives at different granularities.
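
To make steps 3-5 concrete, here is a sketch of the upsampling and mask-selection logic; the shapes and names are assumptions about how such decoder outputs could be consumed, not the package's internals:

```python
import torch

def interpolate_to_points(patch_feats, patch_centers, points, k=3, eps=1e-8):
    """Inverse-distance weighted interpolation from patch centers to every original point.

    patch_feats:   [G, D]  decoder features at patch resolution
    patch_centers: [G, 3]  FPS center coordinates
    points:        [N, 3]  original point cloud
    """
    knn_d, knn_idx = torch.cdist(points, patch_centers).topk(k, largest=False)  # 3 nearest centers
    w = 1.0 / (knn_d + eps)
    w = w / w.sum(dim=-1, keepdim=True)                       # normalized inverse-distance weights
    return (patch_feats[knn_idx] * w[..., None]).sum(dim=1)   # [N, D] per-point features

# point_feats: [N, D], hyper_w: [4, D] (one weight vector per candidate), iou_pred: [4]
# logits    = point_feats @ hyper_w.T          # [N, 4] per-point mask logits
# best_mask = logits[:, iou_pred.argmax()] > 0
```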

### 4. Iterative Prompt Refinement (training only)

During training, Point-SAM simulates a user iteratively adding prompts:
- Iteration 0: no prompt → random positive point from the target object.
- Iteration 1: previous mask is fed back as a mask prompt; a new point prompt is sampled from the **error region** (false positives / false negatives).
- ... repeated for 5 iterations (large model) or 10 (giant).

At **inference time** you only do a single forward pass with whatever prompt you provide; the model was trained to produce a good mask even from one point.
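
Purely for intuition, a sketch of how such an error-region click could be drawn during training; this helper is hypothetical and not part of the inference package:

```python
import torch

def sample_error_prompt(gt_mask: torch.Tensor, pred_mask: torch.Tensor):
    """Pick the next simulated click from the error region of a boolean [N] prediction."""
    false_neg = gt_mask & ~pred_mask                  # foreground the model missed
    false_pos = ~gt_mask & pred_mask                  # background the model grabbed
    errors = (false_neg | false_pos).nonzero().flatten()
    if errors.numel() == 0:
        return None, None                             # prediction already matches the target
    pick = errors[torch.randint(errors.numel(), (1,))].item()
    return pick, int(gt_mask[pick])                   # label 1 = positive prompt, 0 = negative
```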

---

## Supported File Formats

| Format | Notes |
|--------|-------|
| **PLY** | ASCII `.ply` with `x y z r g b` columns |
| **PCD** | ASCII `.pcd` with `x y z r g b` columns (Point Cloud Library format) |

Both loaders normalize coordinates to a **unit sphere in [-1, 1]** and scale colors to **[0, 255]**. This normalization is **required**: the positional encoding will raise a `ValueError` if coordinates fall outside [-1, 1].
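
If you load clouds with your own code instead of `load_pointcloud`, the same convention can be reproduced in a few lines of NumPy (a sketch of the expected normalization, not the package's exact loader):

```python
import numpy as np

def normalize_cloud(coords: np.ndarray, colors: np.ndarray):
    """Center the cloud, scale it into the unit sphere, and keep colors in [0, 255]."""
    centered = coords - coords.mean(axis=0)
    coords_norm = centered / np.linalg.norm(centered, axis=1).max()  # now inside [-1, 1]
    if colors.max() <= 1.0:                        # handle clouds stored with 0-1 colors
        colors = colors * 255.0
    return coords_norm.astype(np.float32), colors.astype(np.float32)
```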

---

## Handling Large Point Clouds

If your cloud has more than ~100k points, increase the number of patch groups to avoid running out of memory:

```python
model.adjust_patch_params(num_groups=2048, group_size=256)
```

The default is `num_groups=1024, group_size=256` for the large model.

---

## What Changed From the Original Repo?

| Original | This Package |
|----------|-------------|
| Requires `hydra` + `omegaconf` for config | Pure Python, no YAML configs needed |
| Requires compiling `torkit3d` (CUDA ops) | Pure-PyTorch FPS, KNN, and index operations |
| Requires compiling `apex` for FusedLayerNorm | Standard `nn.LayerNorm` by default; apex optional |
| Scattered evaluation scripts | One clean `PointSAM` class with `predict()` |
| Heavy training codebase | Only inference + minimal model code |

---

## Citation

```bibtex
@inproceedings{
  zhou2025pointsam,
  title={Point-{SAM}: Promptable 3D Segmentation Model for Point Clouds},
  author={Yuchen Zhou and Jiayuan Gu and Tung Yen Chiang and Fanbo Xiang and Hao Su},
  booktitle={The Thirteenth International Conference on Learning Representations},
  year={2025},
  url={https://openreview.net/forum?id=yXCTDhZDh6}
}
```

## License

MIT (same as the original repository).

<!-- ml-intern-provenance -->
## Generated by ML Intern

This model repository was generated by [ML Intern](https://github.com/huggingface/ml-intern), an agent for machine learning research and development on the Hugging Face Hub.

- Try ML Intern: https://smolagents-ml-intern.hf.space
- Source code: https://github.com/huggingface/ml-intern

## Usage

Point-SAM is not a standard `transformers` architecture, so the generic `AutoModel` classes will not load this checkpoint. Use the `PointSAM` class and the workflow shown in the Quick Start section above.