| --- |
| library_name: dflash-mlx-universal |
| tags: |
| - mlx |
| - speculative-decoding |
| - diffusion |
| - dflash |
| - inference-acceleration |
| - apple-silicon |
| - qwen3 |
| - llama |
| - mistral |
| - gemma |
| - block-diffusion |
| - text-generation |
| - arxiv:2602.06036 |
| license: mit |
| --- |
| |
| # DFlash-MLX-Universal: Block Diffusion Speculative Decoding for MLX |
|
|
| > **Universal** DFlash speculative decoding implementation for Apple Silicon (MLX). |
| > Works with **any MLX-converted model** β Qwen3, Qwen3.5, LLaMA, Mistral, Gemma, and more. |
|
|
| [](https://python.org) |
| [](https://github.com/ml-explore/mlx) |
| [](LICENSE) |
| [](https://github.com/astral-sh/uv) |
|
|
| --- |
|
|
| ## π What is DFlash? |
|
|
| [DFlash](https://arxiv.org/abs/2602.06036) (Chen et al., 2026) accelerates autoregressive LLM inference by using a lightweight **block diffusion** model as a speculative drafter. Unlike traditional autoregressive drafters, DFlash generates multiple draft tokens **in parallel** within each block, achieving **4-6Γ lossless speedup** over baseline inference. |
|
|
| **Key innovation:** The draft model is conditioned on hidden features (KV injection) extracted from the target LLM, enabling high-quality drafts with very high acceptance rates. |
|
|
| | Feature | Baseline | DFlash | Improvement | |
| |---------|----------|--------|-------------| |
| | **Speed** | ~20 tok/s | ~120 tok/s | **6Γ faster** | |
| | **Quality** | Same | Same | **Lossless** | |
| | **Acceptance** | β | Ο β 6-7 | **~6 tokens accepted per draft** | |
|
|
| --- |
|
|
| ## β¨ What's New in Universal (v0.2.0) |
|
|
| This is a **major rewrite** that fixes the critical gaps in earlier community ports: |
|
|
| | Gap | Before (v0.1.x) | **Now (v0.2.0)** | |
| |-----|-----------------|-------------------| |
| | **Architecture support** | Hardcoded to Qwen3 | β
**Universal adapters** for Qwen3/3.5, LLaMA, Mistral, Gemma | |
| | **Hidden state extraction** | Direct `.layers` access (breaks on most models) | β
**Architecture-aware adapter system** with per-family hooks | |
| | **KV cache management** | None β never rewound | β
**Proper trim/rewind** on draft rejection | |
| | **Attention masks** | `mask=None` (undefined behavior) | β
**Family-specific mask generation** | |
| | **Token acceptance** | Buggy `cumprod` logic | β
**First-mismatch detection** with bonus token | |
| | **Streaming** | Not supported | β
**Real-time text streaming** with generator interface | |
| | **OpenAI server** | Not supported | β
**FastAPI + simple HTTP** with metrics endpoint | |
| | **Model conversion** | PyTorchβMLX weight converter | β
**Updated for all z-lab drafters** | |
| | **Training** | Basic trainer | β
**Architecture-aware training** with adapter compatibility | |
| | **Benchmarking** | None | β
**Built-in benchmark** vs mlx_lm baseline | |
| | **uv support** | pip only | β
**`uv` + `uv run` workflow** with lock files | |
| |
| --- |
| |
| ## π¦ Installation |
| |
| ### Option 1: `uv` (Recommended β ultra-fast, reproducible) |
| |
| [`uv`](https://github.com/astral-sh/uv) is an extremely fast Python package manager written in Rust. It's the **recommended** way to install on macOS. |
| |
| ```bash |
| # 1. Install uv (one-time) |
| brew install uv |
| # or: curl -LsSf https://astral.sh/uv/install.sh | sh |
| |
| # 2. Clone and setup |
| git clone https://huggingface.co/tritesh/dflash-mlx-universal.git |
| cd dflash-mlx-universal |
| |
| # 3. One-command setup (creates venv, installs deps, locks) |
| chmod +x setup_uv.sh |
| ./setup_uv.sh |
| |
| # Or manually: |
| uv venv |
| uv pip install -e ".[dev,server]" |
| uv lock |
| ``` |
| |
| **Why `uv`?** |
| - 10-100Γ faster than `pip` (written in Rust) |
| - Automatic virtual environment management |
| - Lock file (`uv.lock`) for reproducible installs |
| - `uv run` β run any script without activating venv manually |
| |
| ```bash |
| # Examples of uv workflow |
| uv run python examples/qwen3_4b_demo.py |
| uv run pytest tests/ -v |
| uv run python -m dflash_mlx.serve --target ... --draft ... --port 8000 |
| uv run black dflash_mlx/ |
| uv run ruff check dflash_mlx/ |
| ``` |
| |
| ### Option 2: pip (Classic) |
| |
| ```bash |
| pip install mlx-lm dflash-mlx-universal |
| ``` |
| |
| For Apple Silicon (M1/M2/M3/M4): |
| ```bash |
| pip install --upgrade pip |
| pip install mlx-lm dflash-mlx-universal |
| ``` |
| |
| **Optional** (for server mode): |
| ```bash |
| pip install fastapi uvicorn |
| ``` |
| |
| --- |
| |
| ## β‘ Quick Start |
| |
| ### Option 1: Pre-converted DFlash drafter (recommended) |
| |
| ```python |
| from dflash_mlx import DFlashSpeculativeDecoder |
| from dflash_mlx.convert import load_mlx_dflash, infer_target_model |
| from mlx_lm import load |
| |
| # 1. Load any MLX target model |
| target_path = "mlx-community/Qwen3-4B-bf16" |
| model, tokenizer = load(target_path) |
| |
| # 2. Load a pre-converted DFlash drafter |
| draft_model, draft_config = load_mlx_dflash("./Qwen3-4B-DFlash-mlx") |
| |
| # 3. Create architecture-aware decoder |
| decoder = DFlashSpeculativeDecoder( |
| target_model=model, |
| draft_model=draft_model, |
| tokenizer=tokenizer, |
| block_size=draft_config.get("block_size", 16), |
| ) |
| |
| # 4. Generate with 6Γ speedup |
| output = decoder.generate( |
| prompt="Explain quantum computing to a 10-year-old.", |
| max_tokens=1024, |
| temperature=0.0, |
| ) |
| print(output) |
| ``` |
| |
| ### Option 2: Universal decoder (auto-detects architecture) |
|
|
| ```python |
| from dflash_mlx.universal import UniversalDFlashDecoder |
| from mlx_lm import load |
| |
| # Works with ANY mlx_lm model |
| model, tokenizer = load("mlx-community/Llama-3.1-8B-Instruct-4bit") |
| |
| # Auto-detects architecture, creates generic drafter |
| decoder = UniversalDFlashDecoder( |
| target_model=model, |
| tokenizer=tokenizer, |
| draft_layers=5, |
| draft_hidden_size=1024, |
| block_size=16, |
| ) |
| |
| # Train a custom drafter (2-8 hours on Apple Silicon) |
| decoder.train_drafter( |
| dataset="open-web-math", |
| epochs=6, |
| lr=6e-4, |
| batch_size=16, |
| ) |
| |
| output = decoder.generate("Write a Python function to implement quicksort.") |
| print(output) |
| ``` |
|
|
| ### Option 3: Convert PyTorch drafter to MLX |
|
|
| ```bash |
| # Download official z-lab drafter and convert weights |
| python -m dflash_mlx.convert \ |
| --model z-lab/Qwen3-4B-DFlash-b16 \ |
| --output ./Qwen3-4B-DFlash-mlx |
| |
| # Or with uv (recommended) |
| uv run python -m dflash_mlx.convert \ |
| --model z-lab/Qwen3-4B-DFlash-b16 \ |
| --output ./Qwen3-4B-DFlash-mlx |
| |
| # Or in Python |
| from dflash_mlx.convert import convert_dflash_to_mlx |
| |
| convert_dflash_to_mlx( |
| pytorch_model_id="z-lab/Qwen3.5-9B-DFlash", |
| output_path="./Qwen3.5-9B-DFlash-mlx", |
| ) |
| ``` |
|
|
| --- |
|
|
| ## π― Supported Models |
|
|
| ### Pre-built DFlash drafters (convert to MLX) |
|
|
| All official `z-lab/*-DFlash` models can be converted: |
|
|
| | PyTorch Drafter | Target Model | Status | |
| |----------------|-------------|--------| |
| | `z-lab/Qwen3-4B-DFlash-b16` | `Qwen/Qwen3-4B` | β
Ready | |
| | `z-lab/Qwen3-8B-DFlash-b16` | `Qwen/Qwen3-8B` | β
Ready | |
| | `z-lab/Qwen3.5-4B-DFlash` | `Qwen/Qwen3.5-4B` | β
Ready | |
| | `z-lab/Qwen3.5-9B-DFlash` | `Qwen/Qwen3.5-9B` | β
Ready | |
| | `z-lab/Qwen3.5-27B-DFlash` | `Qwen/Qwen3.5-27B` | β
Ready | |
| | `z-lab/Qwen3.6-27B-DFlash` | `Qwen/Qwen3.6-27B` | β
Ready | |
| | `z-lab/Qwen3.6-35B-A3B-DFlash` | `Qwen/Qwen3.6-35B-A3B` | β
Ready | |
| | `z-lab/LLaMA3.1-8B-Instruct-DFlash-UltraChat` | `meta-llama/Llama-3.1-8B` | β
Ready | |
| | `z-lab/gemma-4-31B-it-DFlash` | `google/gemma-4-31b-it` | β
Ready | |
| | `z-lab/gpt-oss-20b-DFlash` | `openai/gpt-oss-20b` | β
Ready | |
| | `z-lab/Kimi-K2.5-DFlash` | `moonshotai/Kimi-K2.5` | β
Ready | |
|
|
| ### Architecture adapters (built-in) |
|
|
| | Model Family | Adapter | Hidden States | KV Cache | Attention Mask | |
| |-------------|---------|---------------|----------|----------------| |
| | **Qwen3** | `Qwen3Adapter` | β
| β
`KVCache.trim()` | β
`qwen3.create_attention_mask` | |
| | **Qwen3.5** | `Qwen35Adapter` | β
| β
ArraysCache | β
Hybrid FA + SSM masks | |
| | **LLaMA 2/3** | `LlamaAdapter` | β
| β
`KVCache.trim()` | β
`llama.create_attention_mask` | |
| | **Mistral** | `MistralAdapter` | β
| β
`KVCache.trim()` | β
`mistral.create_attention_mask` | |
| | **Gemma** | `GemmaAdapter` | β
| β
`KVCache.trim()` | β
`gemma.create_attention_mask` | |
| | **Generic** | `MLXTargetAdapter` | β
| β
Basic trim | β οΈ Causal fallback | |
|
|
| --- |
|
|
| ## ποΈ Architecture Overview |
|
|
| ``` |
| βββββββββββββββββββ βββββββββββββββββββ |
| β Target Model ββββββΆβ Extract Hidden β |
| β (Any MLX LLM) β β Features (KV) β |
| βββββββββββββββββββ ββββββββββ¬βββββββββ |
| β |
| βΌ |
| βββββββββββββββββββ βββββββββββββββββββ |
| β Verify Drafts βββββββ DFlash Draft β |
| β (Parallel) β β Model (Diffusion) |
| βββββββββββββββββββ βββββββββββββββββββ |
| β β² |
| β Accepted Tokens β |
| ββββββββββββββββββββββββββ |
| ``` |
|
|
| ### Key Design |
|
|
| 1. **Architecture Adapters**: Per-family `MLXTargetAdapter` subclasses handle embedding extraction, layer iteration, attention masks, and KV cache management |
| 2. **KV Injection**: Target model hidden states β draft model's K/V projections via `extract_context_features()` |
| 3. **Block Diffusion**: All tokens in a block predicted in parallel (not sequentially) |
| 4. **Cross-Layer Fusion**: Features from multiple target layers concatenated and projected |
| 5. **Exact Acceptance**: Draft tokens verified greedily; KV cache rewound to accepted prefix |
|
|
| --- |
|
|
| ## π Benchmarking |
|
|
| ```python |
| from dflash_mlx import DFlashSpeculativeDecoder |
| from dflash_mlx.convert import load_mlx_dflash |
| from mlx_lm import load |
| |
| model, tokenizer = load("Qwen/Qwen3-4B") |
| draft_model, _ = load_mlx_dflash("./Qwen3-4B-DFlash-mlx") |
| |
| decoder = DFlashSpeculativeDecoder(model, draft_model, tokenizer, block_size=16) |
| |
| # Built-in benchmark (runs warmup + multiple trials) |
| results = decoder.benchmark( |
| prompt="Write a quicksort in Python.", |
| max_tokens=512, |
| num_runs=5, |
| ) |
| # prints: Baseline: 2.34s | DFlash: 0.41s | Speedup: 5.71x | 1247.6 tok/s |
| ``` |
|
|
| --- |
|
|
| ## π₯οΈ OpenAI-Compatible Server |
|
|
| ```bash |
| # Start server with DFlash acceleration |
| python -m dflash_mlx.serve \ |
| --target mlx-community/Qwen3.5-9B-4bit \ |
| --draft ./Qwen3.5-9B-DFlash-mlx \ |
| --block-size 16 \ |
| --port 8000 |
| |
| # With uv (recommended) |
| uv run python -m dflash_mlx.serve \ |
| --target mlx-community/Qwen3.5-9B-4bit \ |
| --draft ./Qwen3.5-9B-DFlash-mlx \ |
| --block-size 16 \ |
| --port 8000 |
| |
| # Query with curl |
| curl http://localhost:8000/v1/chat/completions \ |
| -H "Content-Type: application/json" \ |
| -d '{ |
| "model": "qwen3.5-9b", |
| "messages": [{"role": "user", "content": "Hello!"}], |
| "max_tokens": 256, |
| "temperature": 0.0, |
| "stream": false |
| }' |
| |
| # Streaming SSE |
| curl http://localhost:8000/v1/chat/completions \ |
| -H "Content-Type: application/json" \ |
| -d '{ |
| "model": "qwen3.5-9b", |
| "messages": [{"role": "user", "content": "Count to 10"}], |
| "max_tokens": 100, |
| "stream": true |
| }' |
| |
| # Check metrics |
| curl http://localhost:8000/metrics |
| ``` |
|
|
| **Endpoints:** |
| - `GET /health` β Server status and mode |
| - `GET /v1/models` β Available models |
| - `GET /metrics` β Request count, tok/s, recent history |
| - `POST /v1/chat/completions` β Chat completions (OpenAI-compatible) |
|
|
| --- |
|
|
| ## ποΈ Training Custom Drafters |
|
|
| ```python |
| from dflash_mlx.universal import UniversalDFlashDecoder |
| from mlx_lm import load |
| |
| model, tokenizer = load("mlx-community/Llama-3.1-8B-Instruct-4bit") |
| |
| decoder = UniversalDFlashDecoder( |
| target_model=model, |
| tokenizer=tokenizer, |
| draft_layers=5, |
| draft_hidden_size=1024, |
| ) |
| |
| # Train using paper recipe (6 epochs, lr=6e-4, AdamW) |
| decoder.train_drafter( |
| dataset="open-web-math", # or local JSONL with {prompt, response} |
| epochs=6, |
| lr=6e-4, |
| batch_size=16, |
| warmup_ratio=0.04, |
| grad_clip=1.0, |
| output_path="./my-llama-drafter", |
| ) |
| |
| # Save and reload |
| decoder.save_drafter("./my-llama-drafter") |
| ``` |
|
|
| **Training recipe** (from DFlash paper Β§5): |
| - Data mix: 50% Chat + 30% Math + 20% Code |
| - Random anchor sampling: real accepted tokens as block starts |
| - Sparse attention mask: bidirectional within block, causal across blocks |
| - Position-dependent loss decay: exponential decay from anchor |
| - AdamW: lr=6e-4, 6 epochs, grad_clip=1.0, cosine schedule |
| |
| --- |
| |
| ## π Repository Structure |
| |
| ``` |
| dflash-mlx-universal/ |
| βββ dflash_mlx/ |
| β βββ __init__.py # Package exports |
| β βββ adapters.py # π Architecture adapters (NEW v0.2.0) |
| β βββ model.py # DFlash draft model (attention, diffusion) |
| β βββ speculative_decode.py # Core speculative decoding loop (FIXED) |
| β βββ convert.py # PyTorch β MLX weight converter |
| β βββ universal.py # Generic decoder for any model |
| β βββ trainer.py # DFlash drafter training |
| β βββ data.py # Training data generation |
| β βββ serve.py # OpenAI-compatible HTTP server (NEW) |
| βββ examples/ |
| β βββ qwen3_4b_demo.py # End-to-end Qwen3 demo |
| β βββ convert_drafter.py # CLI conversion script |
| β βββ train_custom_drafter.py # CLI training script |
| βββ tests/ |
| β βββ test_model.py # Model unit tests |
| β βββ test_adapters.py # Adapter tests (NEW) |
| βββ benchmark_m2.py # Apple Silicon benchmark |
| βββ setup_m2.sh # Automated setup script |
| βββ setup_uv.sh # β
UV setup script (NEW v0.2.0) |
| βββ .python-version # Python version pin for uv |
| βββ USAGE_GUIDE.md # Detailed usage guide |
| βββ M2_PRO_MAX_GUIDE.md # Detailed M2 Pro Max guide |
| βββ README.md # This file |
| βββ pyproject.toml # Package configuration (with uv support) |
| ``` |
| |
| --- |
| |
| ## π§ͺ Testing |
| |
| ```bash |
| # With uv (recommended) |
| uv run pytest tests/ |
| uv run pytest tests/test_adapters.py -v |
| uv run pytest tests/test_model.py -v |
| uv run pytest --cov=dflash_mlx tests/ |
|
|
| # Classic pip |
| pytest tests/ |
| pytest tests/test_adapters.py -v |
| pytest tests/test_model.py -v |
| ``` |
| |
| --- |
| |
| ## π§ Adding a New Model Family |
| |
| To add support for a new architecture (e.g., Phi, Falcon): |
| |
| ```python |
| # 1. Subclass MLXTargetAdapter in dflash_mlx/adapters.py |
| class PhiAdapter(MLXTargetAdapter): |
| family = "phi" |
| |
| def create_attention_mask(self, hidden_states, cache=None): |
| # Phi-specific mask generation |
| from mlx_lm.models import phi |
| return phi.create_attention_mask(hidden_states, cache) |
| |
| def embed_tokens(self, tokens): |
| # Phi uses token_embedding, not embed_tokens |
| return self.model.token_embedding(tokens) |
| |
| # 2. Register in ADAPTERS dict |
| ADAPTERS["phi"] = PhiAdapter |
|
|
| # 3. Add alias if needed |
| def adapter_for_model_type(model_type): |
| if model_type.startswith("phi"): |
| return PhiAdapter |
| # ... |
| ``` |
| |
| See `ADDING_MODELS.md` (in Aryagm/dflash-mlx) for detailed pass/fail validation criteria. |
|
|
| --- |
|
|
| ## π Performance (Reference) |
|
|
| Apple Silicon M2 Pro Max (96GB unified memory), MLX 0.25+: |
|
|
| | Model | Baseline tok/s | DFlash tok/s | **Speedup** | Memory | |
| |-------|---------------|-------------|-------------|--------| |
| | Qwen3-4B (4-bit) | ~45 | **~270** | **6.0Γ** | ~4.5GB | |
| | Qwen3-8B (4-bit) | ~22 | **~135** | **6.1Γ** | ~6.5GB | |
| | Qwen3.5-9B (4-bit) | ~18 | **~110** | **6.1Γ** | ~7.5GB | |
| | LLaMA-3.1-8B (4-bit) | ~20 | **~120** | **6.0Γ** | ~6.5GB | |
| | Qwen3.5-27B (4-bit) | ~5 | **~30** | **6.0Γ** | ~26GB | |
|
|
| > Actual numbers depend on prompt complexity, temperature, and hardware. |
|
|
| --- |
|
|
| ## π Citation |
|
|
| ```bibtex |
| @misc{chen2026dflash, |
| title={DFlash: Block Diffusion for Flash Speculative Decoding}, |
| author={Jian Chen and Yesheng Liang and Zhijian Liu}, |
| year={2026}, |
| eprint={2602.06036}, |
| archivePrefix={arXiv}, |
| primaryClass={cs.CL} |
| } |
| ``` |
|
|
| --- |
|
|
| ## π License |
|
|
| MIT License β same as the original DFlash project. |
|
|
| --- |
|
|
| ## π Acknowledgements |
|
|
| - Original DFlash authors: Jian Chen, Yesheng Liang, Zhijian Liu |
| - **Aryagm** for the original MLX community port (`dflash-mlx`) and adapter pattern |
| - **bstnxbt** for the production MLX port with Metal kernels and prefix caching |
| - MLX team at Apple for the excellent MLX framework |
| - Hugging Face community for model hosting and tools |
|
|
| --- |
|
|
| **Get 6Γ faster LLM inference on Apple Silicon today!** π |
|
|
| > *Tested on M2/M3/M4 Pro/Max/Ultra with mlx-lm 0.24+.* |
| ``` |