---
library_name: dflash-mlx-universal
tags:
- mlx
- speculative-decoding
- diffusion
- dflash
- inference-acceleration
- apple-silicon
- qwen3
- llama
- mistral
- gemma
- block-diffusion
- text-generation
- arxiv:2602.06036
license: mit
---

# DFlash-MLX-Universal: Block Diffusion Speculative Decoding for MLX

> **Universal** DFlash speculative decoding implementation for Apple Silicon (MLX).
> Works with **any MLX-converted model** — Qwen3, Qwen3.5, LLaMA, Mistral, Gemma, and more.

[Python](https://python.org) · [MLX](https://github.com/ml-explore/mlx) · [MIT License](LICENSE)

---

## 🚀 What is DFlash?

[DFlash](https://arxiv.org/abs/2602.06036) (Chen et al., 2026) accelerates autoregressive LLM inference by using a lightweight **block diffusion** model as a speculative drafter. Unlike traditional autoregressive drafters, DFlash generates multiple draft tokens **in parallel** within each block, achieving a **4-6× lossless speedup** over baseline inference.

**Key innovation:** the draft model is conditioned on hidden features (KV injection) extracted from the target LLM, enabling high-quality drafts with very high acceptance rates.

| Feature | Baseline | DFlash | Improvement |
|---------|----------|--------|-------------|
| **Speed** | ~20 tok/s | ~120 tok/s | **6× faster** |
| **Quality** | Same | Same | **Lossless** |
| **Acceptance** | — | τ ≈ 6-7 | **~6 tokens accepted per draft** |
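
Back-of-envelope arithmetic behind that table: with τ ≈ 6 accepted tokens plus one bonus token per target forward pass, each pass yields about seven tokens instead of one. A minimal sketch (the drafter-overhead figure is an assumption, not a measured number from this repo):

```python
# Rough speedup estimate; draft_overhead is an assumed cost ratio.
tau = 6                # draft tokens accepted per block (from the table)
bonus = 1              # the target adds one token at the first mismatch
draft_overhead = 0.15  # assumed drafter cost relative to one target pass
speedup = (tau + bonus) / (1 + draft_overhead)
print(f"~{speedup:.1f}x speedup")  # ~6.1x, in line with the benchmarks below
```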
---

## ✨ What's New in Universal (v0.2.0)

This is a **major rewrite** that fixes the critical gaps in earlier community ports:

| Gap | Before (v0.1.x) | **Now (v0.2.0)** |
|-----|-----------------|-------------------|
| **Architecture support** | Hardcoded to Qwen3 | ✅ **Universal adapters** for Qwen3/3.5, LLaMA, Mistral, Gemma |
| **Hidden state extraction** | Direct `.layers` access (breaks on most models) | ✅ **Architecture-aware adapter system** with per-family hooks |
| **KV cache management** | None — never rewound | ✅ **Proper trim/rewind** on draft rejection |
| **Attention masks** | `mask=None` (undefined behavior) | ✅ **Family-specific mask generation** |
| **Token acceptance** | Buggy `cumprod` logic | ✅ **First-mismatch detection** with bonus token |
| **Streaming** | Not supported | ✅ **Real-time text streaming** with generator interface (sketch below) |
| **OpenAI server** | Not supported | ✅ **FastAPI + simple HTTP** with metrics endpoint |
| **Model conversion** | PyTorch→MLX weight converter | ✅ **Updated for all z-lab drafters** |
| **Training** | Basic trainer | ✅ **Architecture-aware training** with adapter compatibility |
| **Benchmarking** | None | ✅ **Built-in benchmark** vs mlx_lm baseline |
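
The streaming row above refers to a generator interface. A minimal consumption sketch, assuming a `stream=True` flag on `generate` (the flag name is an assumption; check the package for the actual signature):

```python
# `decoder` is a DFlashSpeculativeDecoder built as in Quick Start below.
# Chunks arrive as draft blocks are verified, so output appears in bursts.
for chunk in decoder.generate(
    prompt="Explain quantum computing to a 10-year-old.",
    max_tokens=256,
    stream=True,  # assumed flag name; see note above
):
    print(chunk, end="", flush=True)
```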
---

## 📦 Installation

For Apple Silicon (M1/M2/M3/M4):
```bash
pip install --upgrade pip
pip install mlx-lm dflash-mlx-universal
```

**Optional** (for server mode):
```bash
pip install fastapi uvicorn
```

---

## ⚡ Quick Start

### Option 1: Pre-converted DFlash drafter (recommended)

```python
from dflash_mlx import DFlashSpeculativeDecoder
from dflash_mlx.convert import load_mlx_dflash, infer_target_model
from mlx_lm import load

# 1. Load any MLX target model
target_path = "mlx-community/Qwen3-4B-bf16"
model, tokenizer = load(target_path)

# 2. Load a pre-converted DFlash drafter
draft_model, draft_config = load_mlx_dflash("./Qwen3-4B-DFlash-mlx")

# 3. Create an architecture-aware decoder
decoder = DFlashSpeculativeDecoder(
    target_model=model,
    draft_model=draft_model,
    tokenizer=tokenizer,
    block_size=draft_config.get("block_size", 16),
)

# 4. Generate with ~6x speedup
output = decoder.generate(
    prompt="Explain quantum computing to a 10-year-old.",
    max_tokens=1024,
    temperature=0.0,
)
print(output)
```

### Option 2: Universal decoder (auto-detects architecture)

```python
from dflash_mlx.universal import UniversalDFlashDecoder
from mlx_lm import load

# Works with ANY mlx_lm model
model, tokenizer = load("mlx-community/Llama-3.1-8B-Instruct-4bit")

# Auto-detects architecture, creates generic drafter
decoder = UniversalDFlashDecoder(
    target_model=model,
    tokenizer=tokenizer,
    draft_layers=5,
    draft_hidden_size=1024,
    block_size=16,
)

# Train a custom drafter (2-8 hours on Apple Silicon)
decoder.train_drafter(
    dataset="open-web-math",
    epochs=6,
    lr=6e-4,
    batch_size=16,
)

output = decoder.generate("Write a Python function to implement quicksort.")
print(output)
```

### Option 3: Convert a PyTorch drafter to MLX

```bash
# Download an official z-lab drafter and convert the weights
python -m dflash_mlx.convert \
    --model z-lab/Qwen3-4B-DFlash-b16 \
    --output ./Qwen3-4B-DFlash-mlx
```

Or in Python:

```python
from dflash_mlx.convert import convert_dflash_to_mlx

convert_dflash_to_mlx(
    pytorch_model_id="z-lab/Qwen3.5-9B-DFlash",
    output_path="./Qwen3.5-9B-DFlash-mlx",
)
```
---

## 🎯 Supported Models

### Pre-built DFlash drafters (convert to MLX)

All official `z-lab/*-DFlash` models can be converted:

| PyTorch Drafter | Target Model | Status |
|----------------|-------------|--------|
| `z-lab/Qwen3-4B-DFlash-b16` | `Qwen/Qwen3-4B` | ✅ Ready |
| `z-lab/Qwen3-8B-DFlash-b16` | `Qwen/Qwen3-8B` | ✅ Ready |
| `z-lab/Qwen3.5-4B-DFlash` | `Qwen/Qwen3.5-4B` | ✅ Ready |
| `z-lab/Qwen3.5-9B-DFlash` | `Qwen/Qwen3.5-9B` | ✅ Ready |
| `z-lab/Qwen3.5-27B-DFlash` | `Qwen/Qwen3.5-27B` | ✅ Ready |
| `z-lab/Qwen3.6-27B-DFlash` | `Qwen/Qwen3.6-27B` | ✅ Ready |
| `z-lab/Qwen3.6-35B-A3B-DFlash` | `Qwen/Qwen3.6-35B-A3B` | ✅ Ready |
| `z-lab/LLaMA3.1-8B-Instruct-DFlash-UltraChat` | `meta-llama/Llama-3.1-8B` | ✅ Ready |
| `z-lab/gemma-4-31B-it-DFlash` | `google/gemma-4-31b-it` | ✅ Ready |
| `z-lab/gpt-oss-20b-DFlash` | `openai/gpt-oss-20b` | ✅ Ready |
| `z-lab/Kimi-K2.5-DFlash` | `moonshotai/Kimi-K2.5` | ✅ Ready |

### Architecture adapters (built-in)

| Model Family | Adapter | Hidden States | KV Cache | Attention Mask |
|-------------|---------|---------------|----------|----------------|
| **Qwen3** | `Qwen3Adapter` | ✅ | ✅ `KVCache.trim()` | ✅ `qwen3.create_attention_mask` |
| **Qwen3.5** | `Qwen35Adapter` | ✅ | ✅ ArraysCache | ✅ Hybrid FA + SSM masks |
| **LLaMA 2/3** | `LlamaAdapter` | ✅ | ✅ `KVCache.trim()` | ✅ `llama.create_attention_mask` |
| **Mistral** | `MistralAdapter` | ✅ | ✅ `KVCache.trim()` | ✅ `mistral.create_attention_mask` |
| **Gemma** | `GemmaAdapter` | ✅ | ✅ `KVCache.trim()` | ✅ `gemma.create_attention_mask` |
| **Generic** | `MLXTargetAdapter` | ✅ | ✅ Basic trim | ⚠️ Causal fallback |
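
A sketch of selecting an adapter by hand (the decoders normally do this automatically; the import path and the `model.args.model_type` attribute are assumptions based on the registry shown in "Adding a New Model Family" below):

```python
from dflash_mlx.adapters import adapter_for_model_type  # assumed import path
from mlx_lm import load

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

# mlx_lm models carry their family in the config; attribute name may vary.
AdapterCls = adapter_for_model_type(model.args.model_type)  # -> MistralAdapter
adapter = AdapterCls(model)  # provides masks, hidden states, and KV trim hooks
```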
---

## 🏗️ Architecture Overview

### Key Design

1. **Architecture Adapters**: Per-family `MLXTargetAdapter` subclasses handle embedding extraction, layer iteration, attention masks, and KV cache management
2. **KV Injection**: Target model hidden states → draft model's K/V projections via `extract_context_features()`
3. **Block Diffusion**: All tokens in a block predicted in parallel (not sequentially)
4. **Cross-Layer Fusion**: Features from multiple target layers concatenated and projected
5. **Exact Acceptance**: Draft tokens verified greedily; KV cache rewound to accepted prefix (see the sketch below)
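
To make points 3 and 5 concrete, here is a minimal sketch of the verification step; the function name and tensor layout are illustrative, not the package's internal API:

```python
import mlx.core as mx

def accept_draft(draft_tokens: mx.array, target_tokens: mx.array) -> mx.array:
    """First-mismatch acceptance for one block (greedy verification).

    draft_tokens:  (block_size,) tokens proposed in parallel by the drafter.
    target_tokens: (block_size,) the target model's greedy picks for the same
                   positions, obtained from a single batched forward pass.
    """
    n = 0
    for d, t in zip(draft_tokens.tolist(), target_tokens.tolist()):
        if d != t:
            break  # first mismatch: stop accepting
        n += 1
    # Keep the matching prefix and append the target's own token at the
    # mismatch position (the "bonus" token), so every pass emits at least
    # one token. The KV cache is then trimmed back to this accepted prefix.
    return mx.concatenate([draft_tokens[:n], target_tokens[n:n + 1]])
```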
---

## 📊 Benchmarking

```python
from dflash_mlx import DFlashSpeculativeDecoder
from dflash_mlx.convert import load_mlx_dflash
from mlx_lm import load

model, tokenizer = load("Qwen/Qwen3-4B")
draft_model, _ = load_mlx_dflash("./Qwen3-4B-DFlash-mlx")

decoder = DFlashSpeculativeDecoder(model, draft_model, tokenizer, block_size=16)

# Built-in benchmark (runs warmup + multiple trials)
results = decoder.benchmark(
    prompt="Write a quicksort in Python.",
    max_tokens=512,
    num_runs=5,
)
# prints: Baseline: 2.34s | DFlash: 0.41s | Speedup: 5.71x | 1247.6 tok/s
```

---

## 🖥️ OpenAI-Compatible Server

```bash
# Start server with DFlash acceleration
python -m dflash_mlx.serve \
    --target mlx-community/Qwen3.5-9B-4bit \
    --draft ./Qwen3.5-9B-DFlash-mlx \
    --block-size 16 \
    --port 8000

# Query with curl
curl http://localhost:8000/v1/chat/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "qwen3.5-9b",
        "messages": [{"role": "user", "content": "Hello!"}],
        "max_tokens": 256,
        "temperature": 0.0,
        "stream": false
    }'

# Streaming SSE
curl http://localhost:8000/v1/chat/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "qwen3.5-9b",
        "messages": [{"role": "user", "content": "Count to 10"}],
        "max_tokens": 100,
        "stream": true
    }'

# Check metrics
curl http://localhost:8000/metrics
```

**Endpoints:**
- `GET /health` — Server status and mode
- `GET /v1/models` — Available models
- `GET /metrics` — Request count, tok/s, recent history
- `POST /v1/chat/completions` — Chat completions (OpenAI-compatible)
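
Because the server speaks the standard chat-completions protocol, the official `openai` Python client works against it too (assumes `pip install openai`; the model name must match what `/v1/models` reports):

```python
from openai import OpenAI

# Point the stock client at the local DFlash server; no real API key needed.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")

resp = client.chat.completions.create(
    model="qwen3.5-9b",
    messages=[{"role": "user", "content": "Hello!"}],
    max_tokens=256,
    temperature=0.0,
)
print(resp.choices[0].message.content)
```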
---

## 🏋️ Training Custom Drafters

```python
from dflash_mlx.universal import UniversalDFlashDecoder
from mlx_lm import load

model, tokenizer = load("mlx-community/Llama-3.1-8B-Instruct-4bit")

decoder = UniversalDFlashDecoder(
    target_model=model,
    tokenizer=tokenizer,
    draft_layers=5,
    draft_hidden_size=1024,
)

# Train using paper recipe (6 epochs, lr=6e-4, AdamW)
decoder.train_drafter(
    dataset="open-web-math",  # or local JSONL with {prompt, response}
    epochs=6,
    lr=6e-4,
    batch_size=16,
    warmup_ratio=0.04,
    grad_clip=1.0,
    output_path="./my-llama-drafter",
)

# Save and reload
decoder.save_drafter("./my-llama-drafter")
```

**Training recipe** (from DFlash paper §5):
- Data mix: 50% Chat + 30% Math + 20% Code
- Random anchor sampling: real accepted tokens as block starts
- Sparse attention mask: bidirectional within block, causal across blocks
- Position-dependent loss decay: exponential decay from anchor (see the sketch below)
- AdamW: lr=6e-4, 6 epochs, grad_clip=1.0, cosine schedule
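
A minimal sketch of the position-dependent loss decay named above (the decay rate `gamma` is an assumption; the paper specifies the exact schedule):

```python
import mlx.core as mx

def loss_weights(block_size: int, gamma: float = 0.8) -> mx.array:
    # weight_i = gamma ** i, where i is the token's offset from the block's
    # anchor: near positions dominate the loss, far ones contribute less.
    return mx.power(gamma, mx.arange(block_size).astype(mx.float32))

# Example: block_size=4 -> [1.0, 0.8, 0.64, 0.512]. Multiply the per-token
# cross-entropy by these weights before averaging over the block.
```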
---

## 📁 Project Structure

```
dflash-mlx-universal/
├── dflash_mlx/
│   ├── __init__.py              # Package exports
│   ├── adapters.py              # Architecture adapters (NEW v0.2.0)
│   ├── model.py                 # DFlash draft model (attention, diffusion)
│   ├── speculative_decode.py    # Core speculative decoding loop (FIXED)
│   ├── convert.py               # PyTorch → MLX weight converter
│   ├── universal.py             # Generic decoder for any model
│   ├── trainer.py               # DFlash drafter training
│   ├── data.py                  # Training data generation
│   └── serve.py                 # OpenAI-compatible HTTP server (NEW)
├── examples/
│   ├── qwen3_4b_demo.py         # End-to-end Qwen3 demo
│   ├── convert_drafter.py       # CLI conversion script
│   └── train_custom_drafter.py  # CLI training script
├── tests/
│   ├── test_model.py            # Model unit tests
│   └── test_adapters.py         # Adapter tests (NEW)
├── benchmark_m2.py              # Apple Silicon benchmark
├── setup_m2.sh                  # Automated setup script
├── M2_PRO_MAX_GUIDE.md          # Detailed M2 Pro Max guide
├── README.md                    # This file
└── pyproject.toml               # Package configuration
```

---

## 🧪 Testing

```bash
# Run all tests
pytest tests/

# Run specific test modules
pytest tests/test_adapters.py -v
pytest tests/test_model.py -v

# Run with coverage
pytest --cov=dflash_mlx tests/
```

---

## 🔧 Adding a New Model Family

To add support for a new architecture (e.g., Phi, Falcon):

```python
# 1. Subclass MLXTargetAdapter in dflash_mlx/adapters.py
class PhiAdapter(MLXTargetAdapter):
    family = "phi"

    def create_attention_mask(self, hidden_states, cache=None):
        # Phi-specific mask generation
        from mlx_lm.models import phi
        return phi.create_attention_mask(hidden_states, cache)

    def embed_tokens(self, tokens):
        # Phi uses token_embedding, not embed_tokens
        return self.model.token_embedding(tokens)

# 2. Register in ADAPTERS dict
ADAPTERS["phi"] = PhiAdapter

# 3. Add alias if needed
def adapter_for_model_type(model_type):
    if model_type.startswith("phi"):
        return PhiAdapter
    # ...
```
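
A quick smoke test for the new adapter, reusing the names registered above (the model path is an example; the adapter constructor and methods follow the interface shown in the adapter table earlier, so adjust to the real API):

```python
import mlx.core as mx
from mlx_lm import load

model, tokenizer = load("mlx-community/Phi-3-mini-4k-instruct-4bit")  # example
adapter = ADAPTERS["phi"](model)

hidden = adapter.embed_tokens(mx.array([[1, 2, 3]]))      # embeddings resolve
mask = adapter.create_attention_mask(hidden, cache=None)  # mask builds cleanly
print(hidden.shape, getattr(mask, "shape", mask))
```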

See `ADDING_MODELS.md` (in Aryagm/dflash-mlx) for detailed pass/fail validation criteria.

---

## 📈 Performance (Reference)

Apple Silicon M2 Pro Max (96GB unified memory), MLX 0.25+:

| Model | Baseline tok/s | DFlash tok/s | **Speedup** | Memory |
|-------|---------------|-------------|-------------|--------|
| Qwen3-4B (4-bit) | ~45 | **~270** | **6.0×** | ~4.5GB |
| Qwen3-8B (4-bit) | ~22 | **~135** | **6.1×** | ~6.5GB |
| Qwen3.5-9B (4-bit) | ~18 | **~110** | **6.1×** | ~7.5GB |
| LLaMA-3.1-8B (4-bit) | ~20 | **~120** | **6.0×** | ~6.5GB |
| Qwen3.5-27B (4-bit) | ~5 | **~30** | **6.0×** | ~26GB |

> Actual numbers depend on prompt complexity, temperature, and hardware.

|

## 📄 Citation

```bibtex
@misc{chen2026dflash,
  title={DFlash: Block Diffusion for Flash Speculative Decoding},
  author={Jian Chen and Yesheng Liang and Zhijian Liu},
  year={2026},
  eprint={2602.06036},
  archivePrefix={arXiv},
}
```

---

## 📜 License

MIT License — same as the original DFlash project.

---

## 🙏 Acknowledgements

- Original DFlash authors: Jian Chen, Yesheng Liang, Zhijian Liu
- **Aryagm** for the original MLX community port (`dflash-mlx`) and adapter pattern
- **bstnxbt** for the production MLX port with Metal kernels and prefix caching
- MLX team at Apple for the excellent MLX framework
- Hugging Face community for model hosting and tools

---

**Get 6× faster LLM inference on Apple Silicon today!** 🚀

> *Tested on M2/M3/M4 Pro/Max/Ultra with mlx-lm 0.24+.*