DFlash-MLX-Universal: Block Diffusion Speculative Decoding for MLX
Universal DFlash speculative decoding implementation for Apple Silicon (MLX). Works with any MLX-converted model: Qwen3, Qwen3.5, LLaMA, Mistral, Gemma, and more.
What is DFlash?
DFlash (Chen et al., 2026) accelerates autoregressive LLM inference by using a lightweight block diffusion model as a speculative drafter. Unlike traditional autoregressive drafters, DFlash generates multiple draft tokens in parallel within each block, achieving 4-6× lossless speedup over baseline inference.
Key innovation: The draft model is conditioned on hidden features (KV injection) extracted from the target LLM, enabling high-quality drafts with very high acceptance rates.
| Feature | Baseline | DFlash | Improvement |
|---|---|---|---|
| Speed | ~20 tok/s | ~120 tok/s | 6× faster |
| Quality | Same | Same | Lossless |
| Acceptance | N/A | τ ≈ 6-7 | ~6 tokens accepted per draft |
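Conceptually, one decoding step proceeds as in the sketch below. This is an illustration only: `target` and `drafter` are hypothetical objects, not this package's API (see Quick Start for the real entry points).

```python
# Conceptual sketch of one DFlash decoding step (hypothetical objects,
# not this package's API).
def dflash_step(target, drafter, context_tokens, block_size=16):
    # 1. Run the target over the current context and keep its hidden
    #    features; these condition the drafter ("KV injection").
    hidden = target.hidden_features(context_tokens)

    # 2. The block-diffusion drafter proposes a whole block of tokens
    #    in parallel rather than one at a time.
    draft_block = drafter.propose(hidden, block_size)  # list of ints

    # 3. One target forward pass scores every draft position.
    #    target_preds[i] is the target's own greedy prediction for draft
    #    position i; it has block_size + 1 entries (one past the last draft).
    target_preds = target.verify(context_tokens, draft_block)

    # 4. Accept the longest matching prefix, then append one "bonus" token
    #    from the target so at least one token is produced per step.
    accepted = []
    for drafted, predicted in zip(draft_block, target_preds):
        if drafted != predicted:
            break
        accepted.append(drafted)
    accepted.append(target_preds[len(accepted)])
    return accepted
```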
What's New in Universal (v0.2.0)
This is a major rewrite that fixes the critical gaps in earlier community ports:
| Gap | Before (v0.1.x) | Now (v0.2.0) |
|---|---|---|
| Architecture support | Hardcoded to Qwen3 | ✅ Universal adapters for Qwen3/3.5, LLaMA, Mistral, Gemma |
| Hidden state extraction | Direct `.layers` access (breaks on most models) | ✅ Architecture-aware adapter system with per-family hooks |
| KV cache management | None (never rewound) | ✅ Proper trim/rewind on draft rejection |
| Attention masks | `mask=None` (undefined behavior) | ✅ Family-specific mask generation |
| Token acceptance | Buggy cumprod logic | ✅ First-mismatch detection with bonus token |
| Streaming | Not supported | ✅ Real-time text streaming with generator interface |
| OpenAI server | Not supported | ✅ FastAPI + simple HTTP with metrics endpoint |
| Model conversion | PyTorch → MLX weight converter | ✅ Updated for all z-lab drafters |
| Training | Basic trainer | ✅ Architecture-aware training with adapter compatibility |
| Benchmarking | None | ✅ Built-in benchmark vs `mlx_lm` baseline |
| uv support | pip only | ✅ `uv` + `uv run` workflow with lock files |
Installation
Option 1: uv (recommended: ultra-fast, reproducible)
uv is an extremely fast Python package manager written in Rust. It's the recommended way to install on macOS.
# 1. Install uv (one-time)
brew install uv
# or: curl -LsSf https://astral.sh/uv/install.sh | sh
# 2. Clone and setup
git clone https://huggingface.co/tritesh/dflash-mlx-universal.git
cd dflash-mlx-universal
# 3. One-command setup (creates venv, installs deps, locks)
chmod +x setup_uv.sh
./setup_uv.sh
# Or manually:
uv venv
uv pip install -e ".[dev,server]"
uv lock
Why uv?
- 10-100× faster than `pip` (written in Rust)
- Automatic virtual environment management
- Lock file (`uv.lock`) for reproducible installs
- `uv run` → run any script without activating the venv manually
# Examples of uv workflow
uv run python examples/qwen3_4b_demo.py
uv run pytest tests/ -v
uv run python -m dflash_mlx.serve --target ... --draft ... --port 8000
uv run black dflash_mlx/
uv run ruff check dflash_mlx/
Option 2: pip (Classic)
pip install mlx-lm dflash-mlx-universal
For Apple Silicon (M1/M2/M3/M4):
pip install --upgrade pip
pip install mlx-lm dflash-mlx-universal
Optional (for server mode):
pip install fastapi uvicorn
Quick Start
Option 1: Pre-converted DFlash drafter (recommended)
from dflash_mlx import DFlashSpeculativeDecoder
from dflash_mlx.convert import load_mlx_dflash, infer_target_model
from mlx_lm import load
# 1. Load any MLX target model
target_path = "mlx-community/Qwen3-4B-bf16"
model, tokenizer = load(target_path)
# 2. Load a pre-converted DFlash drafter
draft_model, draft_config = load_mlx_dflash("./Qwen3-4B-DFlash-mlx")
# 3. Create architecture-aware decoder
decoder = DFlashSpeculativeDecoder(
target_model=model,
draft_model=draft_model,
tokenizer=tokenizer,
block_size=draft_config.get("block_size", 16),
)
# 4. Generate with 6× speedup
output = decoder.generate(
prompt="Explain quantum computing to a 10-year-old.",
max_tokens=1024,
temperature=0.0,
)
print(output)
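Streaming is also supported via a generator interface (see "What's New" above). The method name below is a guess, so check the package's docstrings for the exact signature:

```python
# Hedged sketch: assumes a generator-style streaming method on the decoder
# (method name hypothetical; consult the package for the exact API).
for chunk in decoder.stream_generate(
    prompt="Summarize the plot of Hamlet in three sentences.",
    max_tokens=256,
    temperature=0.0,
):
    print(chunk, end="", flush=True)
print()
```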
Option 2: Universal decoder (auto-detects architecture)
from dflash_mlx.universal import UniversalDFlashDecoder
from mlx_lm import load
# Works with ANY mlx_lm model
model, tokenizer = load("mlx-community/Llama-3.1-8B-Instruct-4bit")
# Auto-detects architecture, creates generic drafter
decoder = UniversalDFlashDecoder(
target_model=model,
tokenizer=tokenizer,
draft_layers=5,
draft_hidden_size=1024,
block_size=16,
)
# Train a custom drafter (2-8 hours on Apple Silicon)
decoder.train_drafter(
dataset="open-web-math",
epochs=6,
lr=6e-4,
batch_size=16,
)
output = decoder.generate("Write a Python function to implement quicksort.")
print(output)
Option 3: Convert PyTorch drafter to MLX
# Download official z-lab drafter and convert weights
python -m dflash_mlx.convert \
--model z-lab/Qwen3-4B-DFlash-b16 \
--output ./Qwen3-4B-DFlash-mlx
# Or with uv (recommended)
uv run python -m dflash_mlx.convert \
--model z-lab/Qwen3-4B-DFlash-b16 \
--output ./Qwen3-4B-DFlash-mlx
# Or in Python
from dflash_mlx.convert import convert_dflash_to_mlx
convert_dflash_to_mlx(
pytorch_model_id="z-lab/Qwen3.5-9B-DFlash",
output_path="./Qwen3.5-9B-DFlash-mlx",
)
Supported Models
Pre-built DFlash drafters (convert to MLX)
All official z-lab/*-DFlash models can be converted:
| PyTorch Drafter | Target Model | Status |
|---|---|---|
| `z-lab/Qwen3-4B-DFlash-b16` | `Qwen/Qwen3-4B` | ✅ Ready |
| `z-lab/Qwen3-8B-DFlash-b16` | `Qwen/Qwen3-8B` | ✅ Ready |
| `z-lab/Qwen3.5-4B-DFlash` | `Qwen/Qwen3.5-4B` | ✅ Ready |
| `z-lab/Qwen3.5-9B-DFlash` | `Qwen/Qwen3.5-9B` | ✅ Ready |
| `z-lab/Qwen3.5-27B-DFlash` | `Qwen/Qwen3.5-27B` | ✅ Ready |
| `z-lab/Qwen3.6-27B-DFlash` | `Qwen/Qwen3.6-27B` | ✅ Ready |
| `z-lab/Qwen3.6-35B-A3B-DFlash` | `Qwen/Qwen3.6-35B-A3B` | ✅ Ready |
| `z-lab/LLaMA3.1-8B-Instruct-DFlash-UltraChat` | `meta-llama/Llama-3.1-8B` | ✅ Ready |
| `z-lab/gemma-4-31B-it-DFlash` | `google/gemma-4-31b-it` | ✅ Ready |
| `z-lab/gpt-oss-20b-DFlash` | `openai/gpt-oss-20b` | ✅ Ready |
| `z-lab/Kimi-K2.5-DFlash` | `moonshotai/Kimi-K2.5` | ✅ Ready |
Architecture adapters (built-in)
| Model Family | Adapter | Hidden States | KV Cache | Attention Mask |
|---|---|---|---|---|
| Qwen3 | `Qwen3Adapter` | ✅ | ✅ `KVCache.trim()` | ✅ `qwen3.create_attention_mask` |
| Qwen3.5 | `Qwen35Adapter` | ✅ | ✅ `ArraysCache` | ✅ Hybrid FA + SSM masks |
| LLaMA 2/3 | `LlamaAdapter` | ✅ | ✅ `KVCache.trim()` | ✅ `llama.create_attention_mask` |
| Mistral | `MistralAdapter` | ✅ | ✅ `KVCache.trim()` | ✅ `mistral.create_attention_mask` |
| Gemma | `GemmaAdapter` | ✅ | ✅ `KVCache.trim()` | ✅ `gemma.create_attention_mask` |
| Generic | `MLXTargetAdapter` | ✅ | ✅ Basic trim | ⚠️ Causal fallback (sketch below) |
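For the generic adapter, the causal fallback is just a standard additive causal mask. A minimal sketch using MLX's built-in helper (illustrative only, not the adapter's actual code):

```python
import mlx.core as mx
import mlx.nn as nn

def causal_fallback_mask(hidden_states: mx.array) -> mx.array:
    # Additive causal mask: 0 on and below the diagonal, -inf above it,
    # so each position attends only to itself and earlier positions.
    seq_len = hidden_states.shape[1]
    mask = nn.MultiHeadAttention.create_additive_causal_mask(seq_len)
    return mask.astype(hidden_states.dtype)
```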
Architecture Overview
┌──────────────────┐      ┌──────────────────┐
│   Target Model   │─────▶│  Extract Hidden  │
│  (Any MLX LLM)   │      │   Features (KV)  │
└──────────────────┘      └────────┬─────────┘
                                   │
                                   ▼
┌──────────────────┐      ┌──────────────────┐
│  Verify Drafts   │◀─────│   DFlash Draft   │
│    (Parallel)    │      │ Model (Diffusion)│
└──────────────────┘      └──────────────────┘
        │                          ▲
        │ Accepted Tokens          │
        └──────────────────────────┘
Key Design
- Architecture Adapters: per-family `MLXTargetAdapter` subclasses handle embedding extraction, layer iteration, attention masks, and KV cache management
- KV Injection: target model hidden states → draft model's K/V projections via `extract_context_features()`
- Block Diffusion: all tokens in a block predicted in parallel (not sequentially)
- Cross-Layer Fusion: features from multiple target layers concatenated and projected
- Exact Acceptance: draft tokens verified greedily; KV cache rewound to accepted prefix (see the sketch below)
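The last point in practice: during verification the target's cache advances over every drafted position, so the rejected tail has to be trimmed away before the next step. A minimal sketch using mlx_lm's prompt-cache utilities (the surrounding logic is illustrative, not this package's exact code):

```python
from mlx_lm.models.cache import trim_prompt_cache

def rewind_to_accepted(cache, block_size, num_accepted):
    # The cache was advanced by block_size positions while verifying the
    # draft block, but only the first num_accepted tokens were confirmed.
    num_rejected = block_size - num_accepted
    if num_rejected > 0:
        # Drop the cache entries for the rejected tail of the block.
        trim_prompt_cache(cache, num_rejected)
    return cache
```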
Benchmarking
from dflash_mlx import DFlashSpeculativeDecoder
from dflash_mlx.convert import load_mlx_dflash
from mlx_lm import load
model, tokenizer = load("Qwen/Qwen3-4B")
draft_model, _ = load_mlx_dflash("./Qwen3-4B-DFlash-mlx")
decoder = DFlashSpeculativeDecoder(model, draft_model, tokenizer, block_size=16)
# Built-in benchmark (runs warmup + multiple trials)
results = decoder.benchmark(
prompt="Write a quicksort in Python.",
max_tokens=512,
num_runs=5,
)
# prints: Baseline: 2.34s | DFlash: 0.41s | Speedup: 5.71x | 1247.6 tok/s
OpenAI-Compatible Server
# Start server with DFlash acceleration
python -m dflash_mlx.serve \
--target mlx-community/Qwen3.5-9B-4bit \
--draft ./Qwen3.5-9B-DFlash-mlx \
--block-size 16 \
--port 8000
# With uv (recommended)
uv run python -m dflash_mlx.serve \
--target mlx-community/Qwen3.5-9B-4bit \
--draft ./Qwen3.5-9B-DFlash-mlx \
--block-size 16 \
--port 8000
# Query with curl
curl http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "qwen3.5-9b",
"messages": [{"role": "user", "content": "Hello!"}],
"max_tokens": 256,
"temperature": 0.0,
"stream": false
}'
# Streaming SSE
curl http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "qwen3.5-9b",
"messages": [{"role": "user", "content": "Count to 10"}],
"max_tokens": 100,
"stream": true
}'
# Check metrics
curl http://localhost:8000/metrics
Endpoints:
- `GET /health` → Server status and mode
- `GET /v1/models` → Available models
- `GET /metrics` → Request count, tok/s, recent history
- `POST /v1/chat/completions` → Chat completions (OpenAI-compatible)
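Because the endpoints follow the OpenAI schema, any OpenAI-compatible client can talk to the server. For example, with the official `openai` Python package pointed at the local address (the model name should match whatever `GET /v1/models` reports):

```python
from openai import OpenAI

# Point the standard OpenAI client at the local DFlash server.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")

response = client.chat.completions.create(
    model="qwen3.5-9b",  # use a name listed by GET /v1/models
    messages=[{"role": "user", "content": "Hello!"}],
    max_tokens=256,
    temperature=0.0,
)
print(response.choices[0].message.content)
```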
Training Custom Drafters
from dflash_mlx.universal import UniversalDFlashDecoder
from mlx_lm import load
model, tokenizer = load("mlx-community/Llama-3.1-8B-Instruct-4bit")
decoder = UniversalDFlashDecoder(
target_model=model,
tokenizer=tokenizer,
draft_layers=5,
draft_hidden_size=1024,
)
# Train using paper recipe (6 epochs, lr=6e-4, AdamW)
decoder.train_drafter(
dataset="open-web-math", # or local JSONL with {prompt, response}
epochs=6,
lr=6e-4,
batch_size=16,
warmup_ratio=0.04,
grad_clip=1.0,
output_path="./my-llama-drafter",
)
# Save and reload
decoder.save_drafter("./my-llama-drafter")
Training recipe (from DFlash paper §5):
- Data mix: 50% Chat + 30% Math + 20% Code
- Random anchor sampling: real accepted tokens as block starts
- Sparse attention mask: bidirectional within block, causal across blocks
- Position-dependent loss decay: exponential decay from anchor
- AdamW: lr=6e-4, 6 epochs, grad_clip=1.0, cosine schedule
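A rough sketch of two of the ingredients above: the block-sparse attention mask (bidirectional inside a block, causal across blocks) and an exponential position-dependent loss decay. The decay rate here is illustrative, not the paper's exact schedule:

```python
import mlx.core as mx

def block_sparse_mask(seq_len: int, block_size: int) -> mx.array:
    # Position i may attend to position j iff j's block does not come after
    # i's block: bidirectional within a block, causal across blocks.
    blocks = mx.arange(seq_len) // block_size
    allowed = blocks[None, :] <= blocks[:, None]
    return mx.where(allowed, mx.array(0.0), mx.array(float("-inf")))

def loss_decay_weights(block_size: int, decay: float = 0.8) -> mx.array:
    # Down-weight positions far from the block anchor (decay is illustrative).
    return mx.power(decay, mx.arange(block_size, dtype=mx.float32))
```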
Repository Structure
dflash-mlx-universal/
├── dflash_mlx/
│   ├── __init__.py              # Package exports
│   ├── adapters.py              # Architecture adapters (NEW v0.2.0)
│   ├── model.py                 # DFlash draft model (attention, diffusion)
│   ├── speculative_decode.py    # Core speculative decoding loop (FIXED)
│   ├── convert.py               # PyTorch → MLX weight converter
│   ├── universal.py             # Generic decoder for any model
│   ├── trainer.py               # DFlash drafter training
│   ├── data.py                  # Training data generation
│   └── serve.py                 # OpenAI-compatible HTTP server (NEW)
├── examples/
│   ├── qwen3_4b_demo.py         # End-to-end Qwen3 demo
│   ├── convert_drafter.py       # CLI conversion script
│   └── train_custom_drafter.py  # CLI training script
├── tests/
│   ├── test_model.py            # Model unit tests
│   └── test_adapters.py         # Adapter tests (NEW)
├── benchmark_m2.py              # Apple Silicon benchmark
├── setup_m2.sh                  # Automated setup script
├── setup_uv.sh                  # uv setup script (NEW v0.2.0)
├── .python-version              # Python version pin for uv
├── USAGE_GUIDE.md               # Detailed usage guide
├── M2_PRO_MAX_GUIDE.md          # Detailed M2 Pro Max guide
├── README.md                    # This file
└── pyproject.toml               # Package configuration (with uv support)
Testing
# With uv (recommended)
uv run pytest tests/
uv run pytest tests/test_adapters.py -v
uv run pytest tests/test_model.py -v
uv run pytest --cov=dflash_mlx tests/
# Classic pip
pytest tests/
pytest tests/test_adapters.py -v
pytest tests/test_model.py -v
Adding a New Model Family
To add support for a new architecture (e.g., Phi, Falcon):
# 1. Subclass MLXTargetAdapter in dflash_mlx/adapters.py
class PhiAdapter(MLXTargetAdapter):
    family = "phi"

    def create_attention_mask(self, hidden_states, cache=None):
        # Phi-specific mask generation
        from mlx_lm.models import phi
        return phi.create_attention_mask(hidden_states, cache)

    def embed_tokens(self, tokens):
        # Phi uses token_embedding, not embed_tokens
        return self.model.token_embedding(tokens)

# 2. Register in ADAPTERS dict
ADAPTERS["phi"] = PhiAdapter

# 3. Add alias if needed
def adapter_for_model_type(model_type):
    if model_type.startswith("phi"):
        return PhiAdapter
    # ...
See ADDING_MODELS.md (in Aryagm/dflash-mlx) for detailed pass/fail validation criteria.
Performance (Reference)
Apple Silicon M2 Pro Max (96GB unified memory), MLX 0.25+:
| Model | Baseline tok/s | DFlash tok/s | Speedup | Memory |
|---|---|---|---|---|
| Qwen3-4B (4-bit) | ~45 | ~270 | 6.0× | ~4.5GB |
| Qwen3-8B (4-bit) | ~22 | ~135 | 6.1× | ~6.5GB |
| Qwen3.5-9B (4-bit) | ~18 | ~110 | 6.1× | ~7.5GB |
| LLaMA-3.1-8B (4-bit) | ~20 | ~120 | 6.0× | ~6.5GB |
| Qwen3.5-27B (4-bit) | ~5 | ~30 | 6.0× | ~26GB |
Actual numbers depend on prompt complexity, temperature, and hardware.
Citation
@misc{chen2026dflash,
title={DFlash: Block Diffusion for Flash Speculative Decoding},
author={Jian Chen and Yesheng Liang and Zhijian Liu},
year={2026},
eprint={2602.06036},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
License
MIT License, the same as the original DFlash project.
Acknowledgements
- Original DFlash authors: Jian Chen, Yesheng Liang, Zhijian Liu
- Aryagm for the original MLX community port (`dflash-mlx`) and the adapter pattern
- bstnxbt for the production MLX port with Metal kernels and prefix caching
- MLX team at Apple for the excellent MLX framework
- Hugging Face community for model hosting and tools
Get 6× faster LLM inference on Apple Silicon today!
Tested on M2/M3/M4 Pro/Max/Ultra with mlx-lm 0.24+.

# Install MLX LM
uv tool install mlx-lm

# Generate some text
mlx_lm.generate --model "tritesh/dflash-mlx-universal" --prompt "Once upon a time"