prince-canuma committed · verified
Commit 5be8bc3 · 1 Parent(s): 733f466

Upload MLX bf16 drafter from google/gemma-4-E4B-it-assistant
.gitattributes CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ ocdbt.process_0/d/34e15ebf6483f34716b5e7e56e8eb731 filter=lfs diff=lfs merge=lfs -text
+ ocdbt.process_0/d/55650261a9ee8547ae9c29ce7e2e2f7e filter=lfs diff=lfs merge=lfs -text
+ ocdbt.process_0/d/82cf1abb2269c124294bfd46f896cacf filter=lfs diff=lfs merge=lfs -text
+ ocdbt.process_0/d/f03474ce9d620e5ab887fd1d181487a4 filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,180 @@
+ ---
+ base_model: google/gemma-4-E4B-it-assistant
+ library_name: mlx
+ license: gemma
+ pipeline_tag: text-generation
+ tags:
+ - mlx
+ - speculative-decoding
+ - mtp
+ - gemma
+ - drafter
+ ---
+ # mlx-community/gemma-4-E4B-it-assistant-bf16
+
+ This model was converted to MLX format from [`google/gemma-4-E4B-it-assistant`](https://huggingface.co/google/gemma-4-E4B-it-assistant) using mlx-vlm version **0.4.5**.
+ Refer to the [original model card](https://huggingface.co/google/gemma-4-E4B-it-assistant) for more details on the model.
+
+ ## Use with mlx
+
+ ```bash
+ pip install -U mlx-vlm
+ ```
+
+ ```bash
+ python -m mlx_vlm.generate --model mlx-community/gemma-4-E4B-it-assistant-bf16 --max-tokens 100 --temperature 0.0 --prompt "Describe this image." --image <path_to_image>
+ ```
+
+ ---
+
+ # Gemma 4 Assistant Drafter (MTP)
+
+ MLX port of Google's Gemma 4 **Multi-Token Prediction (MTP)** drafter for
+ speculative decoding. Reference:
+ [ai.google.dev/gemma/docs/mtp](https://ai.google.dev/gemma/docs/mtp/mtp).
+
+ ## What it is
+
+ A small, 4-layer "assistant" model trained to draft several candidate tokens
+ per round; the full Gemma 4 target then verifies them in a single forward
+ pass. Accepted tokens advance; the first rejected token and everything after
+ it are discarded. Quality matches the target at temperature 0 (byte-identical
+ greedy output).
+
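+ For intuition, here is a minimal sketch of the greedy accept/reject rule for
+ one round (illustrative only; the names and shapes are not the mlx-vlm API):
+
+ ```python
+ def verify_greedy(draft_tokens: list[int], target_pred: list[int]) -> list[int]:
+     """Greedy verification of one speculative round.
+
+     target_pred[i] is the target's argmax at draft position i; it carries one
+     extra trailing entry, which supplies the "bonus" token when every draft
+     is accepted.
+     """
+     out = []
+     for i, tok in enumerate(draft_tokens):
+         if target_pred[i] != tok:
+             # First mismatch: keep the target's own token, drop the rest.
+             return out + [target_pred[i]]
+         out.append(tok)
+     # Every draft accepted: the verification pass still yields a bonus token.
+     return out + [target_pred[-1]]
+ ```
+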
+ The drafter is tightly coupled to the target's internals:
+
+ - **KV-cache sharing** — every drafter layer is `is_kv_shared_layer=True` and
+   reads K/V from the target's last full-attention and last sliding-attention
+   layers. The drafter has **no KV cache of its own**; its only recurrent
+   state is the target's last hidden state, projected through `post_projection`.
+ - **Cross-attention from a constant position** — the drafter's queries are
+   RoPE-rotated at the bonus token's absolute position and held constant
+   across all draft steps within a block.
+ - **Hidden+token concatenation** — the drafter input at each step is
+   `concat([target_embed(last_token), last_hidden_state], dim=-1)` of shape
+   `[B, 1, 2 * backbone_hidden_size]`, projected to the drafter hidden size by
+   `pre_projection` (see the sketch after this list).
+
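+ A minimal sketch of that input construction, with stand-in tensors and sizes
+ taken from this repo's `config.json` (`backbone_hidden_size=2560`, drafter
+ `hidden_size=256`); the `pre_projection` here is a plain `nn.Linear`, not the
+ repo's implementation:
+
+ ```python
+ import mlx.core as mx
+ import mlx.nn as nn
+
+ B, H_BACKBONE, H_DRAFT = 1, 2560, 256          # per this repo's config.json
+ pre_projection = nn.Linear(2 * H_BACKBONE, H_DRAFT)
+
+ # Stand-ins for the target's outputs at the bonus token:
+ tok_embed = mx.zeros((B, 1, H_BACKBONE))       # target_embed(last_token)
+ last_hidden_state = mx.zeros((B, 1, H_BACKBONE))
+
+ x = mx.concatenate([tok_embed, last_hidden_state], axis=-1)  # [B, 1, 2 * H_BACKBONE]
+ x = pre_projection(x)                          # [B, 1, H_DRAFT]
+ print(x.shape)                                 # (1, 1, 256)
+ ```
+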
+ ## Supported pairings
+
+ | Target | Drafter | LM head |
+ | ------------------------------------- | ---------------------------------------------------- | ---------------------- |
+ | `mlx-community/gemma-4-E2B-it-bf16` | `mlx-community/gemma-4-E2B-it-assistant-bf16` | centroid (sparse) |
+ | `mlx-community/gemma-4-E4B-it-bf16` | `mlx-community/gemma-4-E4B-it-assistant-bf16` | centroid (sparse) |
+ | `mlx-community/gemma-4-26B-A4B-it-bf16` | `mlx-community/gemma-4-26B-A4B-it-assistant-bf16` | tied dense |
+ | `mlx-community/gemma-4-31B-it-bf16` | `mlx-community/gemma-4-31B-it-assistant-bf16` | tied dense |
+
+ For the E2B / E4B drafters, `use_ordered_embeddings=True` and the LM head is a
+ **centroid-routed sparse softmax** (`MaskedEmbedder`): the drafter scores
+ 2048 token clusters, materialises the tokens of the top-K (default 32)
+ clusters (~4096 of 262144), and scatters those logits back into a full-vocab
+ tensor — non-selected positions are filled with `min(selected) - 1` so they
+ lose any argmax / sampling competition.
+
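+ A hypothetical sketch of the routing idea (not the `MaskedEmbedder` source;
+ it assumes ordered embeddings make each cluster a contiguous id range):
+
+ ```python
+ import mlx.core as mx
+
+ V, C, K = 262144, 2048, 32         # vocab, centroids, top-K clusters (per config)
+ PER_CLUSTER = V // C               # 128 tokens per cluster -> K * 128 = 4096
+
+ def sparse_logits(h, centroid_w, embed_w):
+     # h: [D] drafter hidden; centroid_w: [C, D]; embed_w: [V, D], cluster-ordered.
+     cluster_scores = centroid_w @ h                          # [C]
+     top = mx.argpartition(-cluster_scores, K)[:K]            # top-K cluster ids
+     token_ids = (top[:, None] * PER_CLUSTER
+                  + mx.arange(PER_CLUSTER)[None, :]).reshape(-1)  # [K * 128]
+     selected = embed_w[token_ids] @ h                        # sparse logits
+     # Scatter back; everything unselected sits just below the selected
+     # minimum, so it can never win argmax or sampling.
+     full = mx.full((V,), mx.min(selected) - 1)
+     full[token_ids] = selected
+     return full
+ ```
+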
+ ## Files
+
+ - `config.py` — `Gemma4AssistantConfig` (HF-compatible, flattened).
+ - `gemma4_assistant.py` — `Gemma4AssistantDraftModel` (forward, `bind`,
+   `set_shared_kv`, `draft_block`, `sanitize`).
+ - `masked_embedder.py` — centroid-routed sparse LM head for E2B / E4B.
+ - `masks.py` — bidirectional full / SWA masks for the drafter forward.
+ - `parity_check.py` — fake-target smoke test.
+
+ ## Usage
+
+ The drafter is auto-discovered via its HF `model_type == "gemma4_assistant"`;
+ just pass `--draft-model` and `--draft-kind mtp` to `mlx_vlm.generate`:
+
+ ```bash
+ uv run python -m mlx_vlm.generate \
+   --model mlx-community/gemma-4-31B-it-bf16 \
+   --draft-model mlx-community/gemma-4-31B-it-assistant-bf16 \
+   --draft-kind mtp \
+   --draft-block-size 4 \
+   --prompt "Explain speculative decoding in 3 sentences." \
+   --max-tokens 256 --temp 0
+ ```
+
+ `--draft-block-size` is the number of speculatively drafted tokens per
+ round (Google calls this `num_assistant_tokens`). The first token of the
+ block is the most recently accepted bonus token, so the drafter actually
+ generates `block_size - 1` candidates each round; `--draft-block-size 4`
+ means one bonus token plus three drafted candidates.
+
+ Programmatic use:
+
+ ```python
+ from mlx_vlm.utils import load
+ from mlx_vlm.speculative.drafters import load_drafter
+ from mlx_vlm.generate import generate_step
+
+ model, processor = load("mlx-community/gemma-4-31B-it-bf16")
+ drafter = load_drafter("mlx-community/gemma-4-31B-it-assistant-bf16", kind="mtp")
+
+ # input_ids: token ids for the prompt, e.g. from the processor's tokenizer.
+ for tok, _ in generate_step(
+     input_ids, model, None, None,
+     max_tokens=256,
+     draft_model=drafter,
+     draft_kind="mtp",
+     draft_block_size=4,
+ ):
+     ...
+ ```
+
+ ## Performance
+
+ Measured on Apple Silicon (M3 Max, 96 GB RAM) with a 17-token prompt,
+ 64–96 max tokens, greedy decoding (`temp=0`); output is byte-identical to
+ the no-drafter baseline.
+
+ Best `block_size` per (target, batch):
+
+ | Target | Batch | Best block size | Total tok/s | Speedup vs. no drafter |
+ |---------|-------|-----------------|-------------|------------------------|
+ | 26B-A4B | 4 | 3 | 85.5 | **3.94×** |
+ | 26B-A4B | 8 | 3 | 165.1 | **1.55×** |
+ | 31B | 4 | 3 | 17.1 | **2.29×** |
+ | 31B | 8 | 2 | 21.4 | **1.41×** |
+ | E4B | 4 | 4 | 62.1 | **1.56×** |
+ | E4B | 8 | 2 | 115.9 | 1.07× |
+ | E4B | 16 | — | — | drafter slower (≤1.0×) |
+
+ The drafter is most attractive on large, slow targets (26B-A4B, 31B), where
+ target forward time dominates. On the small E4B target, the target forward
+ is already cheap, and at high batch sizes the drafter's per-step overhead
+ exceeds the speedup it buys.
+
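+ One way to read these numbers is a back-of-envelope round model (illustrative,
+ not code from this repo): each round costs `block_size - 1` drafter steps plus
+ one target verification pass, and yields one token per accepted draft plus the
+ bonus token:
+
+ ```python
+ def expected_speedup(t_target, t_draft, block_size, accept_rate):
+     """Rough per-round cost model for MTP speculative decoding.
+     t_target: target forward time, t_draft: drafter step time,
+     accept_rate: per-token chance a drafted token matches the target."""
+     drafts = block_size - 1
+     # A draft is only kept if every draft before it was kept too.
+     expected_accepted = sum(accept_rate ** (i + 1) for i in range(drafts))
+     tokens_per_round = expected_accepted + 1           # accepted + bonus
+     round_time = drafts * t_draft + t_target
+     return (tokens_per_round * t_target) / round_time  # baseline: 1 token / t_target
+
+ # E.g. a slow target (60 ms), cheap drafter (2 ms), block 4, 80% acceptance:
+ # expected_speedup(0.060, 0.002, 4, 0.8) ≈ 2.7x, drafter overhead barely matters.
+ # With a fast target (8 ms) the same drafter only gets ≈ 1.7x.
+ ```
+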
+ Reproduce with `scripts/mtp_batch_sweep.py`:
+
+ ```bash
+ uv run python scripts/mtp_batch_sweep.py \
+   --model mlx-community/gemma-4-26B-A4B-it-bf16 \
+   --drafter mlx-community/gemma-4-26B-A4B-it-assistant-bf16 \
+   --batch-sizes 4 8 --block-sizes 2 3 --max-tokens 64
+ ```
+
+ ## Smoke test
+
+ ```bash
+ uv run python -m mlx_vlm.speculative.drafters.gemma4_assistant.parity_check \
+   --drafter mlx-community/gemma-4-E4B-it-assistant-bf16
+ ```
+
+ Expect `forward OK: logits shape=(1, 1, 262144) ...` and
+ `draft_block OK: tokens shape=(1, 3) ...`. For drafters with the centroid
+ LM head, the parity check exercises `MaskedEmbedder` end-to-end (most
+ positions land at the sparse `min - 1` floor).
+
+ ## Caveats
+
+ - **Sampling.** Greedy (temp 0) is verified byte-identical. Stochastic
+   sampling works, but acceptance rates drop because drafter and target
+   draws diverge.
+ - **Multimodal prompts.** Image / audio prefill runs through the target
+   unchanged; speculative decoding only kicks in on the text-decode tail,
+   so multimodal works, but the drafter only ever sees text tokens.
+ - **Sliding-window masks.** The bidirectional SWA mask in `masks.py`
+   short-circuits to `None` when `kv_len <= sliding_window`, which is the
+   only regime `RotatingKVCache` ever produces (see the sketch after this
+   list). Long-prompt mask paths are effectively dead code today.
+ - **Batched generation.** Continuous-batching support is in
+   `_mtp_rounds_batch` (`mlx_vlm/generate.py`). For targets whose KV caches
+   don't implement `.filter()`, finished rows are kept in the batch and
+   simply stop emitting, so throughput doesn't shrink as rows retire.
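+
+ A plausible shape of the short-circuit described in the sliding-window caveat
+ (a sketch; not the actual `masks.py` source):
+
+ ```python
+ import mlx.core as mx
+
+ def bidirectional_swa_mask(q_len: int, kv_len: int, sliding_window: int):
+     # Short-circuit: if the whole KV fits inside the window, no mask is
+     # needed. RotatingKVCache keeps kv_len <= sliding_window, so in
+     # practice this branch is always taken.
+     if kv_len <= sliding_window:
+         return None
+     # Effectively dead code today: a bidirectional band mask.
+     q_pos = mx.arange(kv_len - q_len, kv_len)[:, None]
+     kv_pos = mx.arange(kv_len)[None, :]
+     return mx.abs(q_pos - kv_pos) < sliding_window
+ ```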
config.json ADDED
@@ -0,0 +1,88 @@
+ {
+   "architectures": [
+     "Gemma4AssistantForCausalLM"
+   ],
+   "audio_token_id": 258881,
+   "backbone_hidden_size": 2560,
+   "boa_token_id": 256000,
+   "boi_token_id": 255999,
+   "centroid_intermediate_top_k": 32,
+   "dtype": "bfloat16",
+   "eoa_token_id": 258883,
+   "eoi_token_id": 258882,
+   "image_token_id": 258880,
+   "model_type": "gemma4_assistant",
+   "num_centroids": 2048,
+   "text_config": {
+     "_name_or_path": "",
+     "architectures": null,
+     "attention_bias": false,
+     "attention_dropout": 0.0,
+     "attention_k_eq_v": false,
+     "bos_token_id": 2,
+     "chunk_size_feed_forward": 0,
+     "dtype": "bfloat16",
+     "enable_moe_block": false,
+     "eos_token_id": 1,
+     "final_logit_softcapping": null,
+     "global_head_dim": 512,
+     "head_dim": 256,
+     "hidden_activation": "gelu_pytorch_tanh",
+     "hidden_size": 256,
+     "hidden_size_per_layer_input": 0,
+     "id2label": {
+       "0": "LABEL_0",
+       "1": "LABEL_1"
+     },
+     "initializer_range": 0.02,
+     "intermediate_size": 2048,
+     "is_encoder_decoder": false,
+     "label2id": {
+       "LABEL_0": 0,
+       "LABEL_1": 1
+     },
+     "layer_types": [
+       "sliding_attention",
+       "sliding_attention",
+       "sliding_attention",
+       "full_attention"
+     ],
+     "max_position_embeddings": 131072,
+     "model_type": "gemma4_text",
+     "moe_intermediate_size": null,
+     "num_attention_heads": 4,
+     "num_experts": null,
+     "num_global_key_value_heads": null,
+     "num_hidden_layers": 4,
+     "num_key_value_heads": 2,
+     "num_kv_shared_layers": 4,
+     "output_attentions": false,
+     "output_hidden_states": false,
+     "pad_token_id": 0,
+     "problem_type": null,
+     "return_dict": true,
+     "rms_norm_eps": 1e-06,
+     "rope_parameters": {
+       "full_attention": {
+         "partial_rotary_factor": 0.25,
+         "rope_theta": 1000000.0,
+         "rope_type": "proportional"
+       },
+       "sliding_attention": {
+         "rope_theta": 10000.0,
+         "rope_type": "default"
+       }
+     },
+     "sliding_window": 512,
+     "tie_word_embeddings": true,
+     "top_k_experts": null,
+     "use_bidirectional_attention": null,
+     "use_cache": true,
+     "use_double_wide_mlp": false,
+     "vocab_size": 262144,
+     "vocab_size_per_layer_input": 0
+   },
+   "tie_word_embeddings": true,
+   "transformers_version": "5.7.0.dev0",
+   "use_ordered_embeddings": true
+ }
generation_config.json ADDED
@@ -0,0 +1,17 @@
+ {
+   "bos_token_id": 2,
+   "do_sample": true,
+   "eos_token_id": [
+     1,
+     106,
+     50
+   ],
+   "is_assistant": true,
+   "num_assistant_tokens": 6,
+   "num_assistant_tokens_schedule": "constant",
+   "pad_token_id": 0,
+   "temperature": 1.0,
+   "top_k": 64,
+   "top_p": 0.95,
+   "transformers_version": "5.7.0.dev0"
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:12875062fc25c51e8fa9b62abd2de7ad48b7d63f8559d5d604fbd5a3d6bcff16
+ size 159138208
tokenizer.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:75a6583c1a418e2bbd79c60d95d28e0f5bf549ad3f2990b5bdb5238c6c2bf70c
+ size 32169440
tokenizer_config.json ADDED
@@ -0,0 +1,30 @@
+ {
+   "audio_token": "<|audio|>",
+   "backend": "tokenizers",
+   "boa_token": "<|audio>",
+   "boi_token": "<|image>",
+   "bos_token": "<bos>",
+   "eoa_token": "<audio|>",
+   "eoc_token": "<channel|>",
+   "eoi_token": "<image|>",
+   "eos_token": "<eos>",
+   "eot_token": "<turn|>",
+   "escape_token": "<|\"|>",
+   "etc_token": "<tool_call|>",
+   "etd_token": "<tool|>",
+   "etr_token": "<tool_response|>",
+   "extra_special_tokens": [],
+   "image_token": "<|image|>",
+   "mask_token": "<mask>",
+   "model_max_length": 1000000000000000019884624838656,
+   "pad_token": "<pad>",
+   "padding_side": "left",
+   "soc_token": "<|channel>",
+   "sot_token": "<|turn>",
+   "stc_token": "<|tool_call>",
+   "std_token": "<|tool>",
+   "str_token": "<|tool_response>",
+   "think_token": "<|think|>",
+   "tokenizer_class": "GemmaTokenizer",
+   "unk_token": "<unk>"
+ }