---
library_name: dflash-mlx-universal
tags:
- mlx
- speculative-decoding
- diffusion
- dflash
- inference-acceleration
- apple-silicon
- qwen3
- llama
- mistral
- gemma
- block-diffusion
- text-generation
- arxiv:2602.06036
license: mit
---
# DFlash-MLX-Universal: Block Diffusion Speculative Decoding for MLX
> **Universal** DFlash speculative decoding implementation for Apple Silicon (MLX).
> Works with **any MLX-converted model**: Qwen3, Qwen3.5, LLaMA, Mistral, Gemma, and more.
[![Python](https://img.shields.io/badge/python-3.9%2B-blue)](https://python.org)
[![MLX](https://img.shields.io/badge/MLX-latest-red)](https://github.com/ml-explore/mlx)
[![License](https://img.shields.io/badge/license-MIT-green)](LICENSE)
[![uv](https://img.shields.io/badge/uv-astral-purple)](https://github.com/astral-sh/uv)
---
## 🚀 What is DFlash?
[DFlash](https://arxiv.org/abs/2602.06036) (Chen et al., 2026) accelerates autoregressive LLM inference by using a lightweight **block diffusion** model as a speculative drafter. Unlike traditional autoregressive drafters, DFlash generates multiple draft tokens **in parallel** within each block, achieving a **4-6× lossless speedup** over baseline inference.
**Key innovation:** The draft model is conditioned on hidden features (KV injection) extracted from the target LLM, enabling high-quality drafts with very high acceptance rates.
| Feature | Baseline | DFlash | Improvement |
|---------|----------|--------|-------------|
| **Speed** | ~20 tok/s | ~120 tok/s | **6× faster** |
| **Quality** | Same | Same | **Lossless** |
| **Acceptance** | N/A | τ ≈ 6-7 | **~6 tokens accepted per draft** |
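At a high level, the decode loop alternates one parallel draft pass with one target verification pass. A toy sketch of that outer loop (plain callables stand in for the drafter and the target; this is not the package API):

```python
def speculative_decode(prompt_ids, draft_block, verify, block_size=16, max_new=256):
    """Toy sketch of the DFlash outer loop (not the shipped API).

    draft_block(ids, k) -> list[int]         : drafter proposes k tokens in one parallel pass
    verify(ids, draft) -> (accepted, bonus)  : target checks the draft in one forward pass
    """
    ids = list(prompt_ids)
    stop_at = len(ids) + max_new
    while len(ids) < stop_at:
        draft = draft_block(ids, block_size)   # one block-diffusion pass
        accepted, bonus = verify(ids, draft)   # one target forward pass over the block
        ids += accepted + [bonus]              # always advances by at least the bonus token
    return ids
```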
---
## ✨ What's New in Universal (v0.2.0)
This is a **major rewrite** that fixes the critical gaps in earlier community ports:
| Gap | Before (v0.1.x) | **Now (v0.2.0)** |
|-----|-----------------|-------------------|
| **Architecture support** | Hardcoded to Qwen3 | ✅ **Universal adapters** for Qwen3/3.5, LLaMA, Mistral, Gemma |
| **Hidden state extraction** | Direct `.layers` access (breaks on most models) | ✅ **Architecture-aware adapter system** with per-family hooks |
| **KV cache management** | None (never rewound) | ✅ **Proper trim/rewind** on draft rejection |
| **Attention masks** | `mask=None` (undefined behavior) | ✅ **Family-specific mask generation** |
| **Token acceptance** | Buggy `cumprod` logic | ✅ **First-mismatch detection** with bonus token (see the sketch below) |
| **Streaming** | Not supported | ✅ **Real-time text streaming** with generator interface |
| **OpenAI server** | Not supported | ✅ **FastAPI + simple HTTP** with metrics endpoint |
| **Model conversion** | PyTorch→MLX weight converter | ✅ **Updated for all z-lab drafters** |
| **Training** | Basic trainer | ✅ **Architecture-aware training** with adapter compatibility |
| **Benchmarking** | None | ✅ **Built-in benchmark** vs mlx_lm baseline |
| **uv support** | pip only | ✅ **`uv` + `uv run` workflow** with lock files |
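The corrected acceptance rule from the table above reduces to a first-mismatch scan plus one guaranteed token from the target. An illustrative sketch under greedy decoding (variable names are ours, not the package's):

```python
def accept_block(draft_tokens, target_tokens, next_token):
    """First-mismatch acceptance with a bonus token (greedy verification sketch).

    draft_tokens  : tokens the drafter proposed for one block
    target_tokens : the target model's greedy pick at each of those positions
    next_token    : the target's prediction for the position right after the block
    """
    n = 0
    for d, t in zip(draft_tokens, target_tokens):
        if d != t:
            break
        n += 1
    # If the draft diverged, the target's own token at the mismatch becomes the
    # bonus; otherwise the bonus is the target's prediction after the block.
    # Either way the verification pass emits at least one new token.
    bonus = target_tokens[n] if n < len(draft_tokens) else next_token
    return draft_tokens[:n] + [bonus]
```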
---
## 📦 Installation
### Option 1: `uv` (Recommended: ultra-fast, reproducible)
[`uv`](https://github.com/astral-sh/uv) is an extremely fast Python package manager written in Rust. It's the **recommended** way to install on macOS.
```bash
# 1. Install uv (one-time)
brew install uv
# or: curl -LsSf https://astral.sh/uv/install.sh | sh
# 2. Clone and setup
git clone https://huggingface.co/tritesh/dflash-mlx-universal.git
cd dflash-mlx-universal
# 3. One-command setup (creates venv, installs deps, locks)
chmod +x setup_uv.sh
./setup_uv.sh
# Or manually:
uv venv
uv pip install -e ".[dev,server]"
uv lock
```
**Why `uv`?**
- 10-100× faster than `pip` (written in Rust)
- Automatic virtual environment management
- Lock file (`uv.lock`) for reproducible installs
- `uv run` runs any script without manually activating the venv
```bash
# Examples of uv workflow
uv run python examples/qwen3_4b_demo.py
uv run pytest tests/ -v
uv run python -m dflash_mlx.serve --target ... --draft ... --port 8000
uv run black dflash_mlx/
uv run ruff check dflash_mlx/
```
### Option 2: pip (Classic)
```bash
pip install mlx-lm dflash-mlx-universal
```
For Apple Silicon (M1/M2/M3/M4):
```bash
pip install --upgrade pip
pip install mlx-lm dflash-mlx-universal
```
**Optional** (for server mode):
```bash
pip install fastapi uvicorn
```
---
## ⚡ Quick Start
### Option 1: Pre-converted DFlash drafter (recommended)
```python
from dflash_mlx import DFlashSpeculativeDecoder
from dflash_mlx.convert import load_mlx_dflash, infer_target_model
from mlx_lm import load
# 1. Load any MLX target model
target_path = "mlx-community/Qwen3-4B-bf16"
model, tokenizer = load(target_path)
# 2. Load a pre-converted DFlash drafter
draft_model, draft_config = load_mlx_dflash("./Qwen3-4B-DFlash-mlx")
# 3. Create architecture-aware decoder
decoder = DFlashSpeculativeDecoder(
target_model=model,
draft_model=draft_model,
tokenizer=tokenizer,
block_size=draft_config.get("block_size", 16),
)
# 4. Generate with 6× speedup
output = decoder.generate(
prompt="Explain quantum computing to a 10-year-old.",
max_tokens=1024,
temperature=0.0,
)
print(output)
```
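Streaming is listed among the v0.2.0 features; the snippet below is a sketch of a generator-style call, and the `stream=True` flag is an assumption on our part. Check `dflash_mlx/speculative_decode.py` for the exact streaming entry point in your version.

```python
# Assumed streaming interface: yields text chunks as blocks are accepted.
# The stream flag is illustrative; verify the name in your installed version.
for chunk in decoder.generate(
    prompt="Summarise the plot of Hamlet in three sentences.",
    max_tokens=256,
    temperature=0.0,
    stream=True,  # assumption: the README only states a generator interface exists
):
    print(chunk, end="", flush=True)
```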
### Option 2: Universal decoder (auto-detects architecture)
```python
from dflash_mlx.universal import UniversalDFlashDecoder
from mlx_lm import load
# Works with ANY mlx_lm model
model, tokenizer = load("mlx-community/Llama-3.1-8B-Instruct-4bit")
# Auto-detects architecture, creates generic drafter
decoder = UniversalDFlashDecoder(
target_model=model,
tokenizer=tokenizer,
draft_layers=5,
draft_hidden_size=1024,
block_size=16,
)
# Train a custom drafter (2-8 hours on Apple Silicon)
decoder.train_drafter(
dataset="open-web-math",
epochs=6,
lr=6e-4,
batch_size=16,
)
output = decoder.generate("Write a Python function to implement quicksort.")
print(output)
```
### Option 3: Convert PyTorch drafter to MLX
```bash
# Download official z-lab drafter and convert weights
python -m dflash_mlx.convert \
--model z-lab/Qwen3-4B-DFlash-b16 \
--output ./Qwen3-4B-DFlash-mlx
# Or with uv (recommended)
uv run python -m dflash_mlx.convert \
--model z-lab/Qwen3-4B-DFlash-b16 \
--output ./Qwen3-4B-DFlash-mlx
```

Or in Python:

```python
from dflash_mlx.convert import convert_dflash_to_mlx

convert_dflash_to_mlx(
    pytorch_model_id="z-lab/Qwen3.5-9B-DFlash",
    output_path="./Qwen3.5-9B-DFlash-mlx",
)
```
---
## 🎯 Supported Models
### Pre-built DFlash drafters (convert to MLX)
All official `z-lab/*-DFlash` models can be converted:
| PyTorch Drafter | Target Model | Status |
|----------------|-------------|--------|
| `z-lab/Qwen3-4B-DFlash-b16` | `Qwen/Qwen3-4B` | ✅ Ready |
| `z-lab/Qwen3-8B-DFlash-b16` | `Qwen/Qwen3-8B` | ✅ Ready |
| `z-lab/Qwen3.5-4B-DFlash` | `Qwen/Qwen3.5-4B` | ✅ Ready |
| `z-lab/Qwen3.5-9B-DFlash` | `Qwen/Qwen3.5-9B` | ✅ Ready |
| `z-lab/Qwen3.5-27B-DFlash` | `Qwen/Qwen3.5-27B` | ✅ Ready |
| `z-lab/Qwen3.6-27B-DFlash` | `Qwen/Qwen3.6-27B` | ✅ Ready |
| `z-lab/Qwen3.6-35B-A3B-DFlash` | `Qwen/Qwen3.6-35B-A3B` | ✅ Ready |
| `z-lab/LLaMA3.1-8B-Instruct-DFlash-UltraChat` | `meta-llama/Llama-3.1-8B` | ✅ Ready |
| `z-lab/gemma-4-31B-it-DFlash` | `google/gemma-4-31b-it` | ✅ Ready |
| `z-lab/gpt-oss-20b-DFlash` | `openai/gpt-oss-20b` | ✅ Ready |
| `z-lab/Kimi-K2.5-DFlash` | `moonshotai/Kimi-K2.5` | ✅ Ready |
### Architecture adapters (built-in)
| Model Family | Adapter | Hidden States | KV Cache | Attention Mask |
|-------------|---------|---------------|----------|----------------|
| **Qwen3** | `Qwen3Adapter` | ✅ | ✅ `KVCache.trim()` | ✅ `qwen3.create_attention_mask` |
| **Qwen3.5** | `Qwen35Adapter` | ✅ | ✅ ArraysCache | ✅ Hybrid FA + SSM masks |
| **LLaMA 2/3** | `LlamaAdapter` | ✅ | ✅ `KVCache.trim()` | ✅ `llama.create_attention_mask` |
| **Mistral** | `MistralAdapter` | ✅ | ✅ `KVCache.trim()` | ✅ `mistral.create_attention_mask` |
| **Gemma** | `GemmaAdapter` | ✅ | ✅ `KVCache.trim()` | ✅ `gemma.create_attention_mask` |
| **Generic** | `MLXTargetAdapter` | ✅ | ✅ Basic trim | ⚠️ Causal fallback |
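A minimal sketch of resolving one of the adapters above for a loaded model. The `adapter_for_model_type` helper is the one shown in "Adding a New Model Family" below; the single-argument constructor and the `model.args.model_type` attribute are assumptions based on common mlx_lm conventions.

```python
from mlx_lm import load
from dflash_mlx.adapters import adapter_for_model_type  # module path per the repo layout below

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

# mlx_lm models usually expose their family via model.args.model_type (e.g. "mistral").
adapter_cls = adapter_for_model_type(model.args.model_type)
adapter = adapter_cls(model)  # assumed constructor: wraps the target model

# The adapter now handles hidden-state extraction, attention masks, and KV trimming
# for this family, so the decoder never reaches into model.layers directly.
```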
---
## 🏗️ Architecture Overview
```
┌───────────────────┐     ┌───────────────────┐
│   Target Model    │────▶│  Extract Hidden   │
│   (Any MLX LLM)   │     │   Features (KV)   │
└───────────────────┘     └─────────┬─────────┘
                                    │
                                    ▼
┌───────────────────┐     ┌───────────────────┐
│   Verify Drafts   │◀────│   DFlash Draft    │
│    (Parallel)     │     │ Model (Diffusion) │
└───────────────────┘     └───────────────────┘
          │                         ▲
          │     Accepted Tokens     │
          └─────────────────────────┘
```
### Key Design
1. **Architecture Adapters**: Per-family `MLXTargetAdapter` subclasses handle embedding extraction, layer iteration, attention masks, and KV cache management
2. **KV Injection**: Target model hidden states → draft model's K/V projections via `extract_context_features()`
3. **Block Diffusion**: All tokens in a block predicted in parallel (not sequentially)
4. **Cross-Layer Fusion**: Features from multiple target layers concatenated and projected (see the sketch after this list)
5. **Exact Acceptance**: Draft tokens verified greedily; KV cache rewound to accepted prefix
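Item 4 above is easy to picture as a single projection over concatenated layer outputs. A minimal sketch (layer count and widths are illustrative, not the shipped defaults):

```python
import mlx.core as mx
import mlx.nn as nn

class CrossLayerFusion(nn.Module):
    """Concatenate hidden states tapped from several target layers and project
    them to the drafter's width (illustrative dimensions, not the defaults)."""

    def __init__(self, target_hidden: int, draft_hidden: int, num_layers: int):
        super().__init__()
        self.proj = nn.Linear(target_hidden * num_layers, draft_hidden)

    def __call__(self, layer_hiddens):
        # layer_hiddens: list of (batch, seq_len, target_hidden) arrays, one per tapped layer
        fused = mx.concatenate(layer_hiddens, axis=-1)
        return self.proj(fused)  # (batch, seq_len, draft_hidden), consumed by the drafter's K/V

# Shapes only: three tapped layers of width 2560 fused down to a 1024-wide drafter
fusion = CrossLayerFusion(target_hidden=2560, draft_hidden=1024, num_layers=3)
features = fusion([mx.zeros((1, 8, 2560)) for _ in range(3)])
```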
---
## 📊 Benchmarking
```python
from dflash_mlx import DFlashSpeculativeDecoder
from dflash_mlx.convert import load_mlx_dflash
from mlx_lm import load
model, tokenizer = load("Qwen/Qwen3-4B")
draft_model, _ = load_mlx_dflash("./Qwen3-4B-DFlash-mlx")
decoder = DFlashSpeculativeDecoder(model, draft_model, tokenizer, block_size=16)
# Built-in benchmark (runs warmup + multiple trials)
results = decoder.benchmark(
prompt="Write a quicksort in Python.",
max_tokens=512,
num_runs=5,
)
# prints: Baseline: 2.34s | DFlash: 0.41s | Speedup: 5.71x | 1247.6 tok/s
```
---
## 🖥️ OpenAI-Compatible Server
```bash
# Start server with DFlash acceleration
python -m dflash_mlx.serve \
--target mlx-community/Qwen3.5-9B-4bit \
--draft ./Qwen3.5-9B-DFlash-mlx \
--block-size 16 \
--port 8000
# With uv (recommended)
uv run python -m dflash_mlx.serve \
--target mlx-community/Qwen3.5-9B-4bit \
--draft ./Qwen3.5-9B-DFlash-mlx \
--block-size 16 \
--port 8000
# Query with curl
curl http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "qwen3.5-9b",
"messages": [{"role": "user", "content": "Hello!"}],
"max_tokens": 256,
"temperature": 0.0,
"stream": false
}'
# Streaming SSE
curl http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "qwen3.5-9b",
"messages": [{"role": "user", "content": "Count to 10"}],
"max_tokens": 100,
"stream": true
}'
# Check metrics
curl http://localhost:8000/metrics
```
**Endpoints:**
- `GET /health`: Server status and mode
- `GET /v1/models`: Available models
- `GET /metrics`: Request count, tok/s, recent history
- `POST /v1/chat/completions`: Chat completions (OpenAI-compatible; Python client example below)
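The same endpoint can be driven from the official OpenAI Python client; the model name should match whatever the server reports under `/v1/models`, and local servers typically ignore the API key:

```python
from openai import OpenAI

# Point the client at the local DFlash server instead of api.openai.com.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")

resp = client.chat.completions.create(
    model="qwen3.5-9b",
    messages=[{"role": "user", "content": "Hello!"}],
    max_tokens=256,
    temperature=0.0,
)
print(resp.choices[0].message.content)
```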
---
## 🏋️ Training Custom Drafters
```python
from dflash_mlx.universal import UniversalDFlashDecoder
from mlx_lm import load
model, tokenizer = load("mlx-community/Llama-3.1-8B-Instruct-4bit")
decoder = UniversalDFlashDecoder(
target_model=model,
tokenizer=tokenizer,
draft_layers=5,
draft_hidden_size=1024,
)
# Train using paper recipe (6 epochs, lr=6e-4, AdamW)
decoder.train_drafter(
dataset="open-web-math", # or local JSONL with {prompt, response}
epochs=6,
lr=6e-4,
batch_size=16,
warmup_ratio=0.04,
grad_clip=1.0,
output_path="./my-llama-drafter",
)
# Save the trained drafter for later use
decoder.save_drafter("./my-llama-drafter")
```
**Training recipe** (from DFlash paper §5):
- Data mix: 50% Chat + 30% Math + 20% Code
- Random anchor sampling: real accepted tokens as block starts
- Sparse attention mask: bidirectional within block, causal across blocks (sketched below)
- Position-dependent loss decay: exponential decay from anchor (sketched below)
- AdamW: lr=6e-4, 6 epochs, grad_clip=1.0, cosine schedule
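A minimal sketch of the two block-specific pieces above, assuming an additive attention mask and an illustrative decay constant (the paper's exact value may differ):

```python
import mlx.core as mx

def block_diffusion_mask(seq_len: int, block_size: int) -> mx.array:
    """Additive training mask: bidirectional inside a block, causal across blocks.
    0 where attention is allowed, -inf where it is blocked."""
    block_id = mx.arange(seq_len) // block_size
    allowed = block_id[:, None] >= block_id[None, :]  # same block or any earlier block
    zeros = mx.zeros((seq_len, seq_len))
    neg_inf = mx.full((seq_len, seq_len), float("-inf"))
    return mx.where(allowed, zeros, neg_inf)

def position_loss_weights(block_size: int, decay: float = 0.8) -> mx.array:
    """Exponential loss decay with distance from the block anchor
    (0.8 is illustrative, not the paper's constant)."""
    return mx.array([decay ** i for i in range(block_size)])
```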
---
## 📁 Repository Structure
```
dflash-mlx-universal/
├── dflash_mlx/
│   ├── __init__.py              # Package exports
│   ├── adapters.py              # 🔑 Architecture adapters (NEW v0.2.0)
│   ├── model.py                 # DFlash draft model (attention, diffusion)
│   ├── speculative_decode.py    # Core speculative decoding loop (FIXED)
│   ├── convert.py               # PyTorch → MLX weight converter
│   ├── universal.py             # Generic decoder for any model
│   ├── trainer.py               # DFlash drafter training
│   ├── data.py                  # Training data generation
│   └── serve.py                 # OpenAI-compatible HTTP server (NEW)
├── examples/
│   ├── qwen3_4b_demo.py         # End-to-end Qwen3 demo
│   ├── convert_drafter.py       # CLI conversion script
│   └── train_custom_drafter.py  # CLI training script
├── tests/
│   ├── test_model.py            # Model unit tests
│   └── test_adapters.py         # Adapter tests (NEW)
├── benchmark_m2.py              # Apple Silicon benchmark
├── setup_m2.sh                  # Automated setup script
├── setup_uv.sh                  # ✅ uv setup script (NEW v0.2.0)
├── .python-version              # Python version pin for uv
├── USAGE_GUIDE.md               # Detailed usage guide
├── M2_PRO_MAX_GUIDE.md          # Detailed M2 Pro Max guide
├── README.md                    # This file
└── pyproject.toml               # Package configuration (with uv support)
```
---
## 🧪 Testing
```bash
# With uv (recommended)
uv run pytest tests/
uv run pytest tests/test_adapters.py -v
uv run pytest tests/test_model.py -v
uv run pytest --cov=dflash_mlx tests/
# Classic pip
pytest tests/
pytest tests/test_adapters.py -v
pytest tests/test_model.py -v
```
---
## 🔧 Adding a New Model Family
To add support for a new architecture (e.g., Phi, Falcon):
```python
# 1. Subclass MLXTargetAdapter in dflash_mlx/adapters.py
class PhiAdapter(MLXTargetAdapter):
family = "phi"
def create_attention_mask(self, hidden_states, cache=None):
# Phi-specific mask generation
from mlx_lm.models import phi
return phi.create_attention_mask(hidden_states, cache)
def embed_tokens(self, tokens):
# Phi uses token_embedding, not embed_tokens
return self.model.token_embedding(tokens)
# 2. Register in ADAPTERS dict
ADAPTERS["phi"] = PhiAdapter
# 3. Add alias if needed
def adapter_for_model_type(model_type):
if model_type.startswith("phi"):
return PhiAdapter
# ...
```
See `ADDING_MODELS.md` (in Aryagm/dflash-mlx) for detailed pass/fail validation criteria.
---
## 📊 Performance (Reference)
Apple Silicon M2 Pro Max (96GB unified memory), MLX 0.25+:
| Model | Baseline tok/s | DFlash tok/s | **Speedup** | Memory |
|-------|---------------|-------------|-------------|--------|
| Qwen3-4B (4-bit) | ~45 | **~270** | **6.0×** | ~4.5GB |
| Qwen3-8B (4-bit) | ~22 | **~135** | **6.1×** | ~6.5GB |
| Qwen3.5-9B (4-bit) | ~18 | **~110** | **6.1×** | ~7.5GB |
| LLaMA-3.1-8B (4-bit) | ~20 | **~120** | **6.0×** | ~6.5GB |
| Qwen3.5-27B (4-bit) | ~5 | **~30** | **6.0×** | ~26GB |
> Actual numbers depend on prompt complexity, temperature, and hardware.
---
## 📝 Citation
```bibtex
@misc{chen2026dflash,
title={DFlash: Block Diffusion for Flash Speculative Decoding},
author={Jian Chen and Yesheng Liang and Zhijian Liu},
year={2026},
eprint={2602.06036},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
```
---
## 📄 License
MIT License, same as the original DFlash project.
---
## 🙏 Acknowledgements
- Original DFlash authors: Jian Chen, Yesheng Liang, Zhijian Liu
- **Aryagm** for the original MLX community port (`dflash-mlx`) and adapter pattern
- **bstnxbt** for the production MLX port with Metal kernels and prefix caching
- MLX team at Apple for the excellent MLX framework
- Hugging Face community for model hosting and tools
---
**Get 6× faster LLM inference on Apple Silicon today!** 🚀
> *Tested on M2/M3/M4 Pro/Max/Ultra with mlx-lm 0.24+.*