avbiswas commited on
Commit
281b886
·
verified ·
1 Parent(s): 44eb08d

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. README.md +214 -0
  2. model.safetensors +3 -0
README.md ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # sam-mlx
2
+
3
+ MLX inference port of Meta's SAM 2.1, currently targeting
4
+ `facebook/sam2.1-hiera-small`.
5
+
6
+ The runtime package is Python 3.14 + MLX and does not install PyTorch. PyTorch is
7
+ only used through the optional `torch-parity` extra for checkpoint conversion and
8
+ parity fixtures.
9
+
10
+ ## Current Checkpoint
11
+
12
+ Expected local source checkpoint:
13
+
14
+ ```text
15
+ checkpoints/sam2.1_hiera_small.pt
16
+ ```
17
+
18
+ Converted MLX checkpoint:
19
+
20
+ ```text
21
+ checkpoints/sam2.1_hiera_small_image_segmenter.safetensors
22
+ ```
23
+
24
+ This converted checkpoint includes:
25
+
26
+ - Hiera image encoder
27
+ - FPN neck
28
+ - prompt encoder
29
+ - mask decoder
30
+ - object pointer projection
31
+ - memory encoder
32
+ - memory attention
33
+
34
+ The older image-encoder-only conversion may also exist locally:
35
+
36
+ ```text
37
+ checkpoints/sam2.1_hiera_small_image_encoder.safetensors
38
+ ```
39
+
40
+ Generated checkpoints are ignored by git.
41
+
42
+ ## Setup
43
+
44
+ ```bash
45
+ uv sync --python 3.14
46
+ ```
47
+
48
+ For Torch parity and conversion scripts:
49
+
50
+ ```bash
51
+ uv sync --python 3.14 --extra torch-parity
52
+ ```
53
+
54
+ Reference repositories are expected locally but are not runtime dependencies:
55
+
56
+ ```text
57
+ third_party/sam2
58
+ references/mlx-vlm
59
+ ```
60
+
61
+ ## Convert Weights
62
+
63
+ ```bash
64
+ uv run --extra torch-parity python scripts/convert_image_encoder_weights.py
65
+ ```
66
+
67
+ This writes:
68
+
69
+ ```text
70
+ checkpoints/sam2.1_hiera_small_image_segmenter.safetensors
71
+ ```
72
+
73
+ ## Parity Fixtures
74
+
75
+ Generate Torch image-embedding fixtures:
76
+
77
+ ```bash
78
+ uv run --extra torch-parity python scripts/export_torch_image_embeddings.py --frames 2
79
+ uv run python scripts/compare_image_embeddings.py
80
+ ```
81
+
82
+ Generate Torch prompted-mask fixtures:
83
+
84
+ ```bash
85
+ uv run --extra torch-parity python scripts/export_torch_prompt_mask.py
86
+ uv run python scripts/compare_prompt_mask.py
87
+ ```
88
+
89
+ Current parity results:
90
+
91
+ - Image `vision_features` max abs error: about `1.63e-05`
92
+ - Prompted low-res masks max abs error: about `4.67e-05`
93
+ - Prompted IoU max abs error: about `4.77e-07`
94
+
95
+ Reports are written under:
96
+
97
+ ```text
98
+ outputs/parity/
99
+ ```
100
+
101
+ ## Image Segmentation
102
+
103
+ Run one prompted frame and write an overlay:
104
+
105
+ ```bash
106
+ uv run python scripts/predict_image_mask.py \
107
+ --point 500 610 \
108
+ --output-video outputs/image_prompt_overlay.mp4 \
109
+ --output-mask outputs/image_prompt_mask.npy
110
+ ```
111
+
112
+ Coordinates are in the resized `1024x1024` SAM input space.
113
+
114
+ ## Video Tracking
115
+
116
+ Mask-prompt feedback baseline:
117
+
118
+ ```bash
119
+ uv run python scripts/propagate_video_masks.py --frames 30
120
+ ```
121
+
122
+ SAM2 memory tracker:
123
+
124
+ ```bash
125
+ uv run python scripts/track_video_memory.py --frames 150 \
126
+ --point 500 610 \
127
+ --output-video outputs/dog_memory_overlay_150f_v2.mp4 \
128
+ --output-mask outputs/dog_memory_masks_150f_v2.npy \
129
+ --report outputs/benchmarks/dog_memory_latency_150f_v2.json
130
+ ```
131
+
132
+ The current memory tracker uses:
133
+
134
+ - first-frame point prompt
135
+ - SAM2 memory encoder
136
+ - SAM2 memory attention
137
+ - object pointers
138
+ - up to the last six memory frames
139
+
140
+ It is not yet a drop-in clone of Facebook's full `SAM2VideoPredictor` state
141
+ machine. Missing higher-level behavior includes correction clicks,
142
+ bidirectional propagation, multi-object consolidation, official conditioning
143
+ frame selection, and exact full-video parity tests.
144
+
145
+ ## Overlay Utility
146
+
147
+ Render masks onto a video:
148
+
149
+ ```bash
150
+ uv run python scripts/overlay_masks.py \
151
+ --masks outputs/dog_memory_masks_150f_v2.npy \
152
+ --output outputs/dog_memory_overlay_from_masks.mp4
153
+ ```
154
+
155
+ The overlay script accepts `.npy` or `.npz` masks shaped `T,H,W` or `T,1,H,W`.
156
+ Synthetic overlays are only for writer smoke tests and require:
157
+
158
+ ```bash
159
+ uv run python scripts/overlay_masks.py --synthetic-smoke-test
160
+ ```
161
+
162
+ ## Benchmarks
163
+
164
+ Image encoder:
165
+
166
+ ```bash
167
+ uv run --extra torch-parity python scripts/benchmark_image_encoder.py --warmup 3 --runs 10
168
+ ```
169
+
170
+ Prompt segmentation:
171
+
172
+ ```bash
173
+ uv run python scripts/benchmark_prompt_segmenter.py --warmup 3 --runs 20
174
+ ```
175
+
176
+ Video memory tracking:
177
+
178
+ ```bash
179
+ uv run python scripts/track_video_memory.py --frames 150 \
180
+ --report outputs/benchmarks/video_memory_latency_150f.json
181
+ ```
182
+
183
+ Current indicative numbers on this machine:
184
+
185
+ - Image encoder MLX: about `81 ms/frame`
186
+ - Image encoder Torch/MPS: about `104 ms/frame`
187
+ - MLX image encoder speedup: about `1.28x`
188
+ - Cached prompt decode: about `4 ms`
189
+ - Full image + prompt: about `85 ms`
190
+ - Last-six-frame memory tracker: about `235 ms/frame` on the 150-frame run
191
+
192
+ Benchmark reports are written under:
193
+
194
+ ```text
195
+ outputs/benchmarks/
196
+ ```
197
+
198
+ ## Runtime Dependency Boundary
199
+
200
+ Default runtime should not include Torch:
201
+
202
+ ```bash
203
+ uv sync --python 3.14
204
+ uv run python - <<'PY'
205
+ import importlib.util as u
206
+ print({m: bool(u.find_spec(m)) for m in ["torch", "torchvision", "hydra", "iopath", "mlx", "cv2"]})
207
+ PY
208
+ ```
209
+
210
+ Expected:
211
+
212
+ ```text
213
+ torch=False, torchvision=False, hydra=False, iopath=False, mlx=True, cv2=True
214
+ ```
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:913d910396de4fb221b2baad3272fae29f35c7ae691fcfa1523e86d923d16a09
3
+ size 209356497