---
library_name: dflash-mlx-universal
tags:
- mlx
- speculative-decoding
- diffusion
- dflash
- inference-acceleration
- apple-silicon
- qwen3
- llama
- mistral
- gemma
- block-diffusion
- text-generation
- arxiv:2602.06036
license: mit
---

# DFlash-MLX-Universal: Block Diffusion Speculative Decoding for MLX

> **Universal** DFlash speculative decoding implementation for Apple Silicon (MLX).
> Works with **any MLX-converted model** β€” Qwen3, Qwen3.5, LLaMA, Mistral, Gemma, and more.

[![Python](https://img.shields.io/badge/python-3.9%2B-blue)](https://python.org)
[![MLX](https://img.shields.io/badge/MLX-latest-red)](https://github.com/ml-explore/mlx)
[![License](https://img.shields.io/badge/license-MIT-green)](LICENSE)
[![uv](https://img.shields.io/badge/uv-astral-purple)](https://github.com/astral-sh/uv)

---

## πŸš€ What is DFlash?

[DFlash](https://arxiv.org/abs/2602.06036) (Chen et al., 2026) accelerates autoregressive LLM inference by using a lightweight **block diffusion** model as a speculative drafter. Unlike traditional autoregressive drafters, DFlash generates multiple draft tokens **in parallel** within each block, achieving **4-6Γ— lossless speedup** over baseline inference.

**Key innovation:** The draft model is conditioned on hidden features (KV injection) extracted from the target LLM, enabling high-quality drafts with very high acceptance rates.

| Feature | Baseline | DFlash | Improvement |
|---------|----------|--------|-------------|
| **Speed** | ~20 tok/s | ~120 tok/s | **6Γ— faster** |
| **Quality** | Same | Same | **Lossless** |
| **Acceptance** | β€” | Ο„ β‰ˆ 6-7 | **~6 tokens accepted per draft** |

---

## ✨ What's New in Universal (v0.2.0)

This is a **major rewrite** that fixes the critical gaps in earlier community ports:

| Gap | Before (v0.1.x) | **Now (v0.2.0)** |
|-----|-----------------|-------------------|
| **Architecture support** | Hardcoded to Qwen3 | βœ… **Universal adapters** for Qwen3/3.5, LLaMA, Mistral, Gemma |
| **Hidden state extraction** | Direct `.layers` access (breaks on most models) | βœ… **Architecture-aware adapter system** with per-family hooks |
| **KV cache management** | None β€” never rewound | βœ… **Proper trim/rewind** on draft rejection |
| **Attention masks** | `mask=None` (undefined behavior) | βœ… **Family-specific mask generation** |
| **Token acceptance** | Buggy `cumprod` logic | βœ… **First-mismatch detection** with bonus token |
| **Streaming** | Not supported | βœ… **Real-time text streaming** with generator interface |
| **OpenAI server** | Not supported | βœ… **FastAPI + simple HTTP** with metrics endpoint |
| **Model conversion** | PyTorchβ†’MLX weight converter | βœ… **Updated for all z-lab drafters** |
| **Training** | Basic trainer | βœ… **Architecture-aware training** with adapter compatibility |
| **Benchmarking** | None | βœ… **Built-in benchmark** vs mlx_lm baseline |
| **uv support** | pip only | βœ… **`uv` + `uv run` workflow** with lock files |

---

## πŸ“¦ Installation

### Option 1: `uv` (Recommended β€” ultra-fast, reproducible)

[`uv`](https://github.com/astral-sh/uv) is an extremely fast Python package manager written in Rust. It's the **recommended** way to install on macOS.

```bash
# 1. Install uv (one-time)
brew install uv
# or: curl -LsSf https://astral.sh/uv/install.sh | sh

# 2. Clone and setup
git clone https://huggingface.co/tritesh/dflash-mlx-universal.git
cd dflash-mlx-universal

# 3. One-command setup (creates venv, installs deps, locks)
chmod +x setup_uv.sh
./setup_uv.sh

# Or manually:
uv venv
uv pip install -e ".[dev,server]"
uv lock
```

**Why `uv`?**

- 10-100Γ— faster than `pip` (written in Rust)
- Automatic virtual environment management
- Lock file (`uv.lock`) for reproducible installs
- `uv run` β€” run any script without activating venv manually

```bash
# Examples of uv workflow
uv run python examples/qwen3_4b_demo.py
uv run pytest tests/ -v
uv run python -m dflash_mlx.serve --target ... --draft ... --port 8000
uv run black dflash_mlx/
uv run ruff check dflash_mlx/
```

### Option 2: pip (Classic)

```bash
pip install mlx-lm dflash-mlx-universal
```

For Apple Silicon (M1/M2/M3/M4):

```bash
pip install --upgrade pip
pip install mlx-lm dflash-mlx-universal
```

**Optional** (for server mode):

```bash
pip install fastapi uvicorn
```

---

## ⚑ Quick Start

### Option 1: Pre-converted DFlash drafter (recommended)

```python
from dflash_mlx import DFlashSpeculativeDecoder
from dflash_mlx.convert import load_mlx_dflash
from mlx_lm import load

# 1. Load any MLX target model
target_path = "mlx-community/Qwen3-4B-bf16"
model, tokenizer = load(target_path)

# 2. Load a pre-converted DFlash drafter
draft_model, draft_config = load_mlx_dflash("./Qwen3-4B-DFlash-mlx")

# 3. Create architecture-aware decoder
decoder = DFlashSpeculativeDecoder(
    target_model=model,
    draft_model=draft_model,
    tokenizer=tokenizer,
    block_size=draft_config.get("block_size", 16),
)

# 4. Generate with 6Γ— speedup
output = decoder.generate(
    prompt="Explain quantum computing to a 10-year-old.",
    max_tokens=1024,
    temperature=0.0,
)
print(output)
```

### Option 2: Universal decoder (auto-detects architecture)

```python
from dflash_mlx.universal import UniversalDFlashDecoder
from mlx_lm import load

# Works with ANY mlx_lm model
model, tokenizer = load("mlx-community/Llama-3.1-8B-Instruct-4bit")

# Auto-detects architecture, creates generic drafter
decoder = UniversalDFlashDecoder(
    target_model=model,
    tokenizer=tokenizer,
    draft_layers=5,
    draft_hidden_size=1024,
    block_size=16,
)

# Train a custom drafter (2-8 hours on Apple Silicon)
decoder.train_drafter(
    dataset="open-web-math",
    epochs=6,
    lr=6e-4,
    batch_size=16,
)

output = decoder.generate("Write a Python function to implement quicksort.")
print(output)
```

### Option 3: Convert PyTorch drafter to MLX

```bash
# Download official z-lab drafter and convert weights
python -m dflash_mlx.convert \
    --model z-lab/Qwen3-4B-DFlash-b16 \
    --output ./Qwen3-4B-DFlash-mlx

# Or with uv (recommended)
uv run python -m dflash_mlx.convert \
    --model z-lab/Qwen3-4B-DFlash-b16 \
    --output ./Qwen3-4B-DFlash-mlx
```

Or in Python:

```python
from dflash_mlx.convert import convert_dflash_to_mlx

convert_dflash_to_mlx(
    pytorch_model_id="z-lab/Qwen3.5-9B-DFlash",
    output_path="./Qwen3.5-9B-DFlash-mlx",
)
```
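**Streaming output** (illustrative sketch). The decoder also exposes a real-time streaming generator (see the What's New table). The method name `stream_generate` below is an assumption rather than a documented API, so check the generator exposed in `dflash_mlx/speculative_decode.py` for the actual name:

```python
from dflash_mlx import DFlashSpeculativeDecoder
from dflash_mlx.convert import load_mlx_dflash
from mlx_lm import load

model, tokenizer = load("mlx-community/Qwen3-4B-bf16")
draft_model, draft_config = load_mlx_dflash("./Qwen3-4B-DFlash-mlx")
decoder = DFlashSpeculativeDecoder(
    target_model=model,
    draft_model=draft_model,
    tokenizer=tokenizer,
    block_size=draft_config.get("block_size", 16),
)

# `stream_generate` is assumed to yield decoded text chunks as blocks are
# accepted; print them as they arrive instead of waiting for the full answer.
for chunk in decoder.stream_generate(
    prompt="Explain quantum computing to a 10-year-old.",
    max_tokens=512,
    temperature=0.0,
):
    print(chunk, end="", flush=True)
print()
```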
---

## 🎯 Supported Models

### Pre-built DFlash drafters (convert to MLX)

All official `z-lab/*-DFlash` models can be converted:

| PyTorch Drafter | Target Model | Status |
|----------------|-------------|--------|
| `z-lab/Qwen3-4B-DFlash-b16` | `Qwen/Qwen3-4B` | βœ… Ready |
| `z-lab/Qwen3-8B-DFlash-b16` | `Qwen/Qwen3-8B` | βœ… Ready |
| `z-lab/Qwen3.5-4B-DFlash` | `Qwen/Qwen3.5-4B` | βœ… Ready |
| `z-lab/Qwen3.5-9B-DFlash` | `Qwen/Qwen3.5-9B` | βœ… Ready |
| `z-lab/Qwen3.5-27B-DFlash` | `Qwen/Qwen3.5-27B` | βœ… Ready |
| `z-lab/Qwen3.6-27B-DFlash` | `Qwen/Qwen3.6-27B` | βœ… Ready |
| `z-lab/Qwen3.6-35B-A3B-DFlash` | `Qwen/Qwen3.6-35B-A3B` | βœ… Ready |
| `z-lab/LLaMA3.1-8B-Instruct-DFlash-UltraChat` | `meta-llama/Llama-3.1-8B` | βœ… Ready |
| `z-lab/gemma-4-31B-it-DFlash` | `google/gemma-4-31b-it` | βœ… Ready |
| `z-lab/gpt-oss-20b-DFlash` | `openai/gpt-oss-20b` | βœ… Ready |
| `z-lab/Kimi-K2.5-DFlash` | `moonshotai/Kimi-K2.5` | βœ… Ready |

### Architecture adapters (built-in)

| Model Family | Adapter | Hidden States | KV Cache | Attention Mask |
|-------------|---------|---------------|----------|----------------|
| **Qwen3** | `Qwen3Adapter` | βœ… | βœ… `KVCache.trim()` | βœ… `qwen3.create_attention_mask` |
| **Qwen3.5** | `Qwen35Adapter` | βœ… | βœ… ArraysCache | βœ… Hybrid FA + SSM masks |
| **LLaMA 2/3** | `LlamaAdapter` | βœ… | βœ… `KVCache.trim()` | βœ… `llama.create_attention_mask` |
| **Mistral** | `MistralAdapter` | βœ… | βœ… `KVCache.trim()` | βœ… `mistral.create_attention_mask` |
| **Gemma** | `GemmaAdapter` | βœ… | βœ… `KVCache.trim()` | βœ… `gemma.create_attention_mask` |
| **Generic** | `MLXTargetAdapter` | βœ… | βœ… Basic trim | ⚠️ Causal fallback |

---

## πŸ—οΈ Architecture Overview

```
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”        β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚   Target Model    │───────▢│  Extract Hidden   β”‚
β”‚   (Any MLX LLM)   β”‚        β”‚   Features (KV)   β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜        β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
                                       β”‚
                                       β–Ό
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”        β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚   Verify Drafts   β”‚β—€β”€β”€β”€β”€β”€β”€β”€β”‚   DFlash Draft    β”‚
β”‚    (Parallel)     β”‚        β”‚ Model (Diffusion) β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜        β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
         β”‚                             β–²
         β”‚       Accepted Tokens       β”‚
         β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
```

### Key Design

1. **Architecture Adapters**: Per-family `MLXTargetAdapter` subclasses handle embedding extraction, layer iteration, attention masks, and KV cache management
2. **KV Injection**: Target model hidden states β†’ draft model's K/V projections via `extract_context_features()`
3. **Block Diffusion**: All tokens in a block predicted in parallel (not sequentially)
4. **Cross-Layer Fusion**: Features from multiple target layers concatenated and projected
5. **Exact Acceptance**: Draft tokens verified greedily; KV cache rewound to accepted prefix (see the sketch below)
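For a concrete picture of items 3 and 5, here is a framework-free sketch of the greedy accept rule, not the package's actual implementation: the target verifies the whole drafted block in a single forward pass, the longest matching prefix is kept, and the target's own token at the first mismatch (or at the extra position after a fully accepted block) is committed as a bonus token. The KV caches are then trimmed back to the committed length (e.g. via `KVCache.trim()`, as listed in the adapter table) before the next block is drafted.

```python
def accept_block(draft_block, target_pred):
    """First-mismatch acceptance with a bonus token (greedy verification).

    draft_block: B tokens proposed in parallel by the DFlash drafter.
    target_pred: B + 1 greedy tokens from the target's single verification
        pass; target_pred[i] is the target's choice given the context plus
        draft_block[:i].
    Returns the tokens committed this step (accepted prefix + bonus token).
    """
    committed = []
    for i, tok in enumerate(draft_block):
        if tok == target_pred[i]:
            committed.append(tok)             # draft agrees with the target: keep it
        else:
            committed.append(target_pred[i])  # first mismatch: take the target's token
            break
    else:
        committed.append(target_pred[len(draft_block)])  # whole block accepted: free bonus token
    return committed


# Drafter proposed 6 tokens; the target disagrees at index 4, so 4 draft
# tokens are kept plus the target's correction (5 tokens committed).
print(accept_block([11, 12, 13, 14, 99, 16], [11, 12, 13, 14, 15, 16, 17]))
# -> [11, 12, 13, 14, 15]
```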
---

## πŸ“Š Benchmarking

```python
from dflash_mlx import DFlashSpeculativeDecoder
from dflash_mlx.convert import load_mlx_dflash
from mlx_lm import load

model, tokenizer = load("Qwen/Qwen3-4B")
draft_model, _ = load_mlx_dflash("./Qwen3-4B-DFlash-mlx")
decoder = DFlashSpeculativeDecoder(model, draft_model, tokenizer, block_size=16)

# Built-in benchmark (runs warmup + multiple trials)
results = decoder.benchmark(
    prompt="Write a quicksort in Python.",
    max_tokens=512,
    num_runs=5,
)
# prints: Baseline: 2.34s | DFlash: 0.41s | Speedup: 5.71x | 1247.6 tok/s
```

---

## πŸ–₯️ OpenAI-Compatible Server

```bash
# Start server with DFlash acceleration
python -m dflash_mlx.serve \
    --target mlx-community/Qwen3.5-9B-4bit \
    --draft ./Qwen3.5-9B-DFlash-mlx \
    --block-size 16 \
    --port 8000

# With uv (recommended)
uv run python -m dflash_mlx.serve \
    --target mlx-community/Qwen3.5-9B-4bit \
    --draft ./Qwen3.5-9B-DFlash-mlx \
    --block-size 16 \
    --port 8000

# Query with curl
curl http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "qwen3.5-9b",
    "messages": [{"role": "user", "content": "Hello!"}],
    "max_tokens": 256,
    "temperature": 0.0,
    "stream": false
  }'

# Streaming SSE
curl http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "qwen3.5-9b",
    "messages": [{"role": "user", "content": "Count to 10"}],
    "max_tokens": 100,
    "stream": true
  }'

# Check metrics
curl http://localhost:8000/metrics
```

**Endpoints:**

- `GET /health` β€” Server status and mode
- `GET /v1/models` β€” Available models
- `GET /metrics` β€” Request count, tok/s, recent history
- `POST /v1/chat/completions` β€” Chat completions (OpenAI-compatible)
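**Python client.** Because the endpoint is OpenAI-compatible, the official `openai` package (install it separately; it is not a dependency of this repo) can talk to the server directly. The API key below is a placeholder, assuming the local server does not check it:

```python
from openai import OpenAI

# Point the client at the local DFlash server started above.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")

# Non-streaming request, mirroring the first curl example.
resp = client.chat.completions.create(
    model="qwen3.5-9b",
    messages=[{"role": "user", "content": "Hello!"}],
    max_tokens=256,
    temperature=0.0,
)
print(resp.choices[0].message.content)

# Streaming request (SSE under the hood), mirroring the second curl example.
for event in client.chat.completions.create(
    model="qwen3.5-9b",
    messages=[{"role": "user", "content": "Count to 10"}],
    max_tokens=100,
    stream=True,
):
    delta = event.choices[0].delta.content
    if delta:
        print(delta, end="", flush=True)
print()
```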
---

## πŸ‹οΈ Training Custom Drafters

```python
from dflash_mlx.universal import UniversalDFlashDecoder
from mlx_lm import load

model, tokenizer = load("mlx-community/Llama-3.1-8B-Instruct-4bit")

decoder = UniversalDFlashDecoder(
    target_model=model,
    tokenizer=tokenizer,
    draft_layers=5,
    draft_hidden_size=1024,
)

# Train using paper recipe (6 epochs, lr=6e-4, AdamW)
decoder.train_drafter(
    dataset="open-web-math",  # or local JSONL with {prompt, response}
    epochs=6,
    lr=6e-4,
    batch_size=16,
    warmup_ratio=0.04,
    grad_clip=1.0,
    output_path="./my-llama-drafter",
)

# Save and reload
decoder.save_drafter("./my-llama-drafter")
```

**Training recipe** (from DFlash paper Β§5):

- Data mix: 50% Chat + 30% Math + 20% Code
- Random anchor sampling: real accepted tokens as block starts
- Sparse attention mask: bidirectional within block, causal across blocks
- Position-dependent loss decay: exponential decay from anchor
- AdamW: lr=6e-4, 6 epochs, grad_clip=1.0, cosine schedule
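**Local dataset format** (illustrative). A minimal sketch of the `{prompt, response}` JSONL mentioned in the snippet above; the file name is arbitrary, and passing its path as `dataset=` follows that comment rather than a separately documented API:

```python
import json

# `decoder` is the UniversalDFlashDecoder created in the snippet above.
examples = [
    {"prompt": "What is 17 * 23?", "response": "17 * 23 = 391."},
    {"prompt": "Write a one-line Python quicksort.", "response": "Use sorted(), or recurse on partitions."},
]

# Write one JSON object per line in the {prompt, response} format.
with open("my_dataset.jsonl", "w") as f:
    for ex in examples:
        f.write(json.dumps(ex) + "\n")

decoder.train_drafter(
    dataset="my_dataset.jsonl",  # local JSONL with {prompt, response} records
    epochs=6,
    lr=6e-4,
    batch_size=16,
    output_path="./my-llama-drafter",
)
```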
---

## πŸ“ Repository Structure

```
dflash-mlx-universal/
β”œβ”€β”€ dflash_mlx/
β”‚   β”œβ”€β”€ __init__.py             # Package exports
β”‚   β”œβ”€β”€ adapters.py             # πŸ”‘ Architecture adapters (NEW v0.2.0)
β”‚   β”œβ”€β”€ model.py                # DFlash draft model (attention, diffusion)
β”‚   β”œβ”€β”€ speculative_decode.py   # Core speculative decoding loop (FIXED)
β”‚   β”œβ”€β”€ convert.py              # PyTorch β†’ MLX weight converter
β”‚   β”œβ”€β”€ universal.py            # Generic decoder for any model
β”‚   β”œβ”€β”€ trainer.py              # DFlash drafter training
β”‚   β”œβ”€β”€ data.py                 # Training data generation
β”‚   └── serve.py                # OpenAI-compatible HTTP server (NEW)
β”œβ”€β”€ examples/
β”‚   β”œβ”€β”€ qwen3_4b_demo.py        # End-to-end Qwen3 demo
β”‚   β”œβ”€β”€ convert_drafter.py      # CLI conversion script
β”‚   └── train_custom_drafter.py # CLI training script
β”œβ”€β”€ tests/
β”‚   β”œβ”€β”€ test_model.py           # Model unit tests
β”‚   └── test_adapters.py        # Adapter tests (NEW)
β”œβ”€β”€ benchmark_m2.py             # Apple Silicon benchmark
β”œβ”€β”€ setup_m2.sh                 # Automated setup script
β”œβ”€β”€ setup_uv.sh                 # βœ… UV setup script (NEW v0.2.0)
β”œβ”€β”€ .python-version             # Python version pin for uv
β”œβ”€β”€ USAGE_GUIDE.md              # Detailed usage guide
β”œβ”€β”€ M2_PRO_MAX_GUIDE.md         # Detailed M2 Pro Max guide
β”œβ”€β”€ README.md                   # This file
└── pyproject.toml              # Package configuration (with uv support)
```

---

## πŸ§ͺ Testing

```bash
# With uv (recommended)
uv run pytest tests/
uv run pytest tests/test_adapters.py -v
uv run pytest tests/test_model.py -v
uv run pytest --cov=dflash_mlx tests/

# Classic pip
pytest tests/
pytest tests/test_adapters.py -v
pytest tests/test_model.py -v
```

---

## πŸ”§ Adding a New Model Family

To add support for a new architecture (e.g., Phi, Falcon):

```python
# 1. Subclass MLXTargetAdapter in dflash_mlx/adapters.py
class PhiAdapter(MLXTargetAdapter):
    family = "phi"

    def create_attention_mask(self, hidden_states, cache=None):
        # Phi-specific mask generation
        from mlx_lm.models import phi
        return phi.create_attention_mask(hidden_states, cache)

    def embed_tokens(self, tokens):
        # Phi uses token_embedding, not embed_tokens
        return self.model.token_embedding(tokens)

# 2. Register in ADAPTERS dict
ADAPTERS["phi"] = PhiAdapter

# 3. Add alias if needed
def adapter_for_model_type(model_type):
    if model_type.startswith("phi"):
        return PhiAdapter
    # ...
```

See `ADDING_MODELS.md` (in Aryagm/dflash-mlx) for detailed pass/fail validation criteria.

---

## πŸ“Š Performance (Reference)

Apple Silicon M2 Max (96GB unified memory), MLX 0.25+:

| Model | Baseline tok/s | DFlash tok/s | **Speedup** | Memory |
|-------|---------------|-------------|-------------|--------|
| Qwen3-4B (4-bit) | ~45 | **~270** | **6.0Γ—** | ~4.5GB |
| Qwen3-8B (4-bit) | ~22 | **~135** | **6.1Γ—** | ~6.5GB |
| Qwen3.5-9B (4-bit) | ~18 | **~110** | **6.1Γ—** | ~7.5GB |
| LLaMA-3.1-8B (4-bit) | ~20 | **~120** | **6.0Γ—** | ~6.5GB |
| Qwen3.5-27B (4-bit) | ~5 | **~30** | **6.0Γ—** | ~26GB |

> Actual numbers depend on prompt complexity, temperature, and hardware.

---

## πŸ“ Citation

```bibtex
@misc{chen2026dflash,
      title={DFlash: Block Diffusion for Flash Speculative Decoding},
      author={Jian Chen and Yesheng Liang and Zhijian Liu},
      year={2026},
      eprint={2602.06036},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
```

---

## πŸ“„ License

MIT License β€” same as the original DFlash project.

---

## πŸ™ Acknowledgements

- Original DFlash authors: Jian Chen, Yesheng Liang, Zhijian Liu
- **Aryagm** for the original MLX community port (`dflash-mlx`) and adapter pattern
- **bstnxbt** for the production MLX port with Metal kernels and prefix caching
- MLX team at Apple for the excellent MLX framework
- Hugging Face community for model hosting and tools

---

**Get 6Γ— faster LLM inference on Apple Silicon today!** πŸš€

> *Tested on M2/M3/M4 Pro/Max/Ultra with mlx-lm 0.24+.*