---
license: apache-2.0
library_name: mlx
pipeline_tag: image-segmentation
tags:
  - mlx
  - sam2
  - segment-anything
  - image-segmentation
  - video-segmentation
  - video-object-tracking
  - apple-silicon
base_model:
  - facebook/sam2.1-hiera-tiny
  - facebook/sam2.1-hiera-small
  - facebook/sam2.1-hiera-base-plus
  - facebook/sam2.1-hiera-large
---

# SAM 2.1 MLX

MLX-native ports of Meta/Facebook SAM 2.1 models for Apple Silicon.

This model is converted from Meta's SAM 2.1 checkpoints and the official
`facebookresearch/sam2` implementation. It is intended for local image
segmentation and video object tracking with MLX, without requiring PyTorch at
runtime.

- Project repo: https://github.com/avbiswas/sam2-mlx
- Model collection: https://huggingface.co/collections/avbiswas/sam2-mlx-6a0a0dcfbbbcb089d13d23cd
- Original SAM2 repo: https://github.com/facebookresearch/sam2
- Original models: https://huggingface.co/facebook

## Install

```bash
pip install mlx-sam
```

or with uv:

```bash
uv pip install mlx-sam
```

## Usage

```python
import numpy as np
from mlx_sam import SAM2VideoPredictor

predictor = SAM2VideoPredictor.from_pretrained(
    "avbiswas/sam2.1-hiera-small-mlx"  # replace with this model repo id
)

state = predictor.init_state("path/to/video_or_frames")

predictor.add_new_points_or_box(
    state,
    frame_idx=0,
    obj_id=1,
    points=np.array([[625.0, 429.0]], dtype=np.float32),
    labels=np.array([1], dtype=np.int32),
)

for frame_idx, obj_ids, masks in predictor.propagate_in_video(state):
    # masks: NumPy float32 array shaped [objects, 1, height, width]
    pass
```
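
The loop above leaves the masks unprocessed. As a minimal sketch of one possible post-processing step, the helpers below turn a `[objects, 1, height, width]` float array into boolean masks and count foreground pixels per object. The names `masks_to_binary` and `mask_areas` are hypothetical and not part of `mlx_sam`. The `0.0` threshold assumes the masks are logits, as in the original SAM2 video predictor; if this port returns probabilities, `0.5` would be the analogous cut-off.

```python
import numpy as np


def masks_to_binary(masks: np.ndarray, threshold: float = 0.0) -> np.ndarray:
    """Convert [objects, 1, H, W] float masks to [objects, H, W] booleans.

    The default threshold of 0.0 is an assumption (logit outputs).
    """
    if masks.ndim != 4 or masks.shape[1] != 1:
        raise ValueError(f"expected [objects, 1, H, W], got {masks.shape}")
    # Drop the singleton channel axis and threshold.
    return masks[:, 0] > threshold


def mask_areas(binary: np.ndarray) -> np.ndarray:
    """Count foreground pixels for each object in a [objects, H, W] array."""
    return binary.reshape(binary.shape[0], -1).sum(axis=1)
```

Inside the propagation loop, `mask_areas(masks_to_binary(masks))` would give one pixel count per entry in `obj_ids`, which can serve as a cheap check that an object has not been lost.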

## Benchmarks

Benchmarks were run on an Apple M2 Max with 32 GB unified memory. Video tests
use the SAM2 dog demo clip: `1280x720`, 289 frames, 29.97 FPS, `9.64 s`.

### FP32 MLX vs Torch/MPS

Prompted first-frame fixture at `1024x1024` internal resolution.

| Model | Size | Torch/MPS | MLX | Speedup | Parity vs Torch |
| --- | ---: | ---: | ---: | ---: | --- |
| `sam2.1-hiera-tiny-mlx` | `172.6 MiB` | `96.6 ms` | `71.3 ms` | `1.36x` | mask mean abs `1.17e-05` |
| `sam2.1-hiera-small-mlx` | `199.7 MiB` | `112.5 ms` | `84.5 ms` | `1.33x` | mask mean abs `8.14e-06` |
| `sam2.1-hiera-base-plus-mlx` | `336.4 MiB` | `203.5 ms` | `144.7 ms` | `1.41x` | mask mean abs `5.04e-06` |
| `sam2.1-hiera-large-mlx` | `892.2 MiB` | `433.0 ms` | `341.1 ms` | `1.27x` | mask mean abs `7.84e-06` |
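
The parity column can be read as a mean absolute difference between the mask outputs of the two backends. The sketch below shows that metric under this assumption; the exact tensors compared in the benchmark (logits or probabilities, which resolution) are not stated here, and `mask_mean_abs` is a hypothetical helper name.

```python
import numpy as np


def mask_mean_abs(reference: np.ndarray, candidate: np.ndarray) -> float:
    """Mean absolute element-wise difference between two mask arrays."""
    reference = np.asarray(reference, dtype=np.float64)
    candidate = np.asarray(candidate, dtype=np.float64)
    if reference.shape != candidate.shape:
        raise ValueError(
            f"shape mismatch: {reference.shape} vs {candidate.shape}"
        )
    return float(np.mean(np.abs(reference - candidate)))
```

A value of `0.0` means the two outputs are identical; values around `1e-05` are consistent with floating-point differences between backends rather than a behavioral change.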

### Video Tracking

For `sam2.1-hiera-small-mlx` on the 9.64-second dog clip:

| Workload | Torch/MPS | MLX | Result |
| --- | ---: | ---: | --- |
| Full video, post-prompt propagation | `331 ms/frame` | `189 ms/frame` | MLX `1.75x` faster |
| Full video, total run | `100.5 s` | `94.8 s` | MLX `1.06x` faster end to end |
| Raw propagation, no save/overlay/final resize | `407 ms/frame` | `287 ms/frame` | MLX `1.42x` faster |

Experimental preview mode at `768x768` internal resolution:

| Setting | Propagation | Quality vs 1024 |
| --- | ---: | --- |
| `1024x1024` baseline | `268.5 ms/frame` | reference |
| `768x768`, fp16 memory attention | `52.9 ms/frame` | mean IoU `0.949`, presence `80 / 80` on 80-frame dog clip |
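
The mean IoU in the table compares preview-mode masks against the `1024x1024` baseline frame by frame. A minimal sketch of such a comparison is shown below. It assumes both mask sequences are already binary and resized to the same shape, and `mask_iou` and `mean_iou` are hypothetical helper names, not part of `mlx_sam`.

```python
import numpy as np


def mask_iou(a: np.ndarray, b: np.ndarray) -> float:
    """Intersection over union of two binary masks of equal shape."""
    a = np.asarray(a, dtype=bool)
    b = np.asarray(b, dtype=bool)
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")
    union = np.logical_or(a, b).sum()
    if union == 0:
        # Both masks are empty; treat them as a perfect match.
        return 1.0
    return float(np.logical_and(a, b).sum() / union)


def mean_iou(reference_frames, candidate_frames) -> float:
    """Average per-frame IoU over two equally long mask sequences."""
    if len(reference_frames) != len(candidate_frames):
        raise ValueError("sequences must have the same number of frames")
    scores = [mask_iou(r, c) for r, c in zip(reference_frames, candidate_frames)]
    return float(np.mean(scores))
```

Treating two empty masks as IoU `1.0` is a design choice made for this sketch; a benchmark could equally skip such frames.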

### Quantized Variants

Quantized models reduce download size and memory footprint. On current MLX
kernels, quantization should not be assumed to speed up video tracking; it
primarily helps memory and distribution size.

| Variant | Typical Size Reduction | Notes |
| --- | ---: | --- |
| `*-mlx-16bit` | about `2x` smaller | fp16 weights, closest quantized parity |
| `*-mlx-8bit` | about `2.5x-3x` smaller | int8 linear quantization |
| `*-mlx-4bit` | about `3.5x` smaller | mixed recipe: int8 trunk/mask decoder, int4 memory/object-pointer layers |

Example small model parity vs fp32 MLX:

| Model | Size | Parity vs fp32 MLX |
| --- | ---: | --- |
| `sam2.1-hiera-small-mlx-16bit` | `99.9 MiB` | mask mean abs `8.24e-03` |
| `sam2.1-hiera-small-mlx-8bit` | `76.7 MiB` | mask mean abs `2.99e-02` |
| `sam2.1-hiera-small-mlx-4bit` | `56.4 MiB` | mask mean abs `2.87e-02` |
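
The "typical size reduction" figures can be checked against the small-model sizes listed in this card. The sketch below does that arithmetic; the sizes are copied from the tables above, and `size_reduction` is a hypothetical helper name.

```python
# Sizes in MiB for the small model, taken from the tables in this card.
FP32_SIZE_MIB = 199.7
VARIANT_SIZES_MIB = {
    "16bit": 99.9,
    "8bit": 76.7,
    "4bit": 56.4,
}


def size_reduction(fp32_mib: float, variant_mib: float) -> float:
    """How many times smaller a variant is than the fp32 model."""
    if variant_mib <= 0:
        raise ValueError("variant size must be positive")
    return fp32_mib / variant_mib


for name, size in VARIANT_SIZES_MIB.items():
    print(f"{name}: {size_reduction(FP32_SIZE_MIB, size):.2f}x smaller")
```

For the small model, the resulting ratios fall within the ranges given in the variants table.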

## License

This MLX port is released under the Apache 2.0 license.

The original SAM 2 repository and source models are from Meta/Facebook and are
also Apache 2.0 licensed.

- Original SAM2 license: https://github.com/facebookresearch/sam2/blob/main/LICENSE
- Original SAM2 repo: https://github.com/facebookresearch/sam2