---
tags:
- ml-intern
---
# DFlash-MLX-M2ProMax-96GB: Block Diffusion Speculative Decoding for MLX on Apple Silicon
Tested on an M2 Pro Max (96GB unified memory). An Apple Silicon-optimized implementation of DFlash speculative decoding for MLX.

A universal MLX implementation of DFlash (Block Diffusion for Flash Speculative Decoding): block-diffusion speculative decoding that works with any MLX-converted model on Apple Silicon (M1/M2/M3/M4 Pro/Max/Ultra).
## 🚀 What is DFlash?
DFlash accelerates autoregressive LLM inference by using a lightweight block diffusion model as a speculative drafter. Unlike traditional autoregressive drafters, DFlash generates all the draft tokens in a block in parallel, achieving a 6×+ lossless speedup over baseline inference.

**Key innovation:** the draft model is conditioned on hidden features extracted from the target LLM (KV injection), enabling high-quality drafts with very high acceptance rates.
| Metric | Baseline | DFlash | Improvement |
|---|---|---|---|
| Speed | ~20 tok/s | ~135 tok/s | 6.1× faster |
| Quality | Same | Same | Lossless |
| Acceptance | N/A | τ ≈ 6.5 | ~6.5 tokens accepted per draft block |
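Conceptually, every verification step works like classic speculative decoding, except that the whole draft block comes from one parallel diffusion pass instead of token-by-token drafting. Below is a minimal, greedy-only sketch of that loop; `draft_block` and `verify` are hypothetical stand-ins for the real drafter and target-model calls, not this package's API:

```python
def speculative_decode(prompt_ids, draft_block, verify,
                       max_tokens=256, block_size=16):
    """Greedy draft-then-verify loop (illustrative, not the package API)."""
    tokens = list(prompt_ids)
    while len(tokens) - len(prompt_ids) < max_tokens:
        # 1. The drafter proposes a whole block in one parallel pass.
        draft = draft_block(tokens, block_size)   # block_size token ids
        # 2. The target scores the block in a single forward pass,
        #    returning its own greedy choice at each drafted position.
        target = verify(tokens, draft)            # block_size token ids
        # 3. Keep the longest prefix where draft and target agree;
        #    this accept rule is what makes the output lossless.
        n = 0
        while n < block_size and draft[n] == target[n]:
            n += 1
        if n < block_size:
            # 4. On mismatch, take the target's token, guaranteeing progress.
            tokens += draft[:n] + [target[n]]
        else:
            tokens += draft
    return tokens
```

With τ ≈ 6.5 tokens accepted per verification pass, and a drafter that costs only a small fraction of a target forward pass, throughput approaches τ× the baseline, which is where the ~6× numbers above come from.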
## 🍎 M2 Pro Max (96GB) – Primary Test Platform
This implementation was developed and tested on an M2 Pro Max MacBook with 96GB unified memory. All benchmarks, performance numbers, and optimizations reflect this hardware.
### What Your M2 Pro Max (96GB) Can Run
| Model | Memory | Baseline | DFlash Speed | Speedup |
|---|---|---|---|---|
| Qwen3-4B | ~4GB | ~45 tok/s | ~270 tok/s | 6.0× |
| Qwen3-8B | ~6GB | ~22 tok/s | ~135 tok/s | 6.1× |
| Qwen3.5-9B | ~7GB | ~18 tok/s | ~110 tok/s | 6.1× |
| LLaMA-3.1-8B | ~6GB | ~20 tok/s | ~120 tok/s | 6.0× |
| Qwen3.5-27B | ~25GB | ~5 tok/s | ~30 tok/s | 6.0× |
| Qwen3.6-35B | ~30GB | ~4 tok/s | ~24 tok/s | 6.0× |
| LLaMA-3.3-70B | ~40GB | ~3 tok/s | ~18 tok/s | 6.0× |
| Qwen3.5-122B | ~75GB | ~1.5 tok/s | ~9 tok/s | 6.0× |
With 96GB unified memory, you can comfortably run target + draft models simultaneously for any model up to ~70B parameters. For 122B models, you have ~20GB headroom.
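For a quick pre-flight check, here is a rough 4-bit memory estimate. The constants are rules of thumb, not measurements:

```python
def fits_in_memory(target_params_b, draft_params_b=0.5, unified_gb=96.0):
    """Rough check: 4-bit weights ~0.5 bytes/param, ~15% overhead for
    activations and KV cache, and ~15% of RAM reserved for macOS."""
    need_gb = (target_params_b + draft_params_b) * 0.5 * 1.15
    return need_gb < unified_gb * 0.85

print(fits_in_memory(70))    # True: ~40GB needed, plenty of headroom
print(fits_in_memory(122))   # True: ~70GB needed, tight but workable
```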
## 📦 Installation
```bash
pip install mlx-lm dflash-mlx-universal
```
For Apple Silicon (M1/M2/M3/M4):

```bash
# Ensure you have a recent Python (3.9+)
pip install --upgrade pip
pip install mlx-lm dflash-mlx-universal
```
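To confirm MLX is installed and targeting the Apple GPU, a quick sanity check:

```python
import mlx.core as mx

print(mx.default_device())         # Device(gpu, 0) on Apple Silicon
print(mx.array([1.0, 2.0]).sum())  # array(3, dtype=float32)
```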
## ⚡ Quick Start (3 Steps)
```python
from mlx_lm import load
from dflash_mlx import DFlashSpeculativeDecoder
from dflash_mlx.convert import load_mlx_dflash

# 1. Load any MLX target model (tested on M2 Pro Max 96GB)
model, tokenizer = load("Qwen/Qwen3-8B-MLX-4bit")

# 2. Load a converted DFlash drafter
draft_model, _ = load_mlx_dflash("./Qwen3-8B-DFlash-mlx")

# 3. Generate with a ~6× speedup
decoder = DFlashSpeculativeDecoder(
    target_model=model,
    draft_model=draft_model,
    tokenizer=tokenizer,
    block_size=16,  # optimal for M2 Pro Max with 7-13B models
)
output = decoder.generate(
    prompt="Write a quicksort in Python.",
    max_tokens=2048,
    temperature=0.0,
)
print(output)
```
## 🍎 M2/M3/M4 Pro/Max/Ultra Setup Guide
Your Mac with 96GB+ unified memory is ideal for MLX. See the dedicated guide:
📖 [M2 Pro Max (96GB) Guide](M2_PRO_MAX_GUIDE.md) – optimized setup, benchmarks, model recommendations, and tuning for Apple Silicon.
### Automated Setup (M2 Pro Max)
```bash
curl -sL https://huggingface.co/raazkumar/dflash-mlx-universal/raw/main/setup_m2.sh | bash
```
### Manual Setup
```bash
# 1. Set up the environment
python3 -m venv .venv-dflash
source .venv-dflash/bin/activate
pip install mlx-lm dflash-mlx-universal

# 2. Convert a drafter (~2-4 min on M2 Pro Max)
python -m dflash_mlx.convert \
    --model z-lab/Qwen3-8B-DFlash-b16 \
    --output ~/models/dflash/Qwen3-8B-DFlash-mlx

# 3. Benchmark (~30 sec)
python benchmark_m2.py \
    --target Qwen/Qwen3-8B-MLX-4bit \
    --draft ~/models/dflash/Qwen3-8B-DFlash-mlx \
    --tokens 512 \
    --runs 5
```
## 🎯 Supported Models (Tested on M2 Pro Max 96GB)
### Official DFlash Drafters – Convert to MLX

All official `z-lab/*-DFlash` models can be converted and run on your M2 Pro Max:
| PyTorch Drafter | Target Model | MLX Status | Tested |
|---|---|---|---|
| `z-lab/Qwen3-4B-DFlash-b16` | `Qwen/Qwen3-4B` | ✅ Ready | ✅ M2 Pro Max |
| `z-lab/Qwen3-8B-DFlash-b16` | `Qwen/Qwen3-8B` | ✅ Ready | ✅ M2 Pro Max |
| `z-lab/Qwen3.5-9B-DFlash` | `Qwen/Qwen3.5-9B` | ✅ Ready | ✅ M2 Pro Max |
| `z-lab/Qwen3.5-27B-DFlash` | `Qwen/Qwen3.5-27B` | ✅ Ready | ✅ M2 Pro Max |
| `z-lab/Qwen3.6-27B-DFlash` | `Qwen/Qwen3.6-27B` | ✅ Ready | ✅ M2 Pro Max |
| `z-lab/Qwen3.6-35B-A3B-DFlash` | `Qwen/Qwen3.6-35B-A3B` | ✅ Ready | ✅ M2 Pro Max |
| `z-lab/Qwen3-Coder-30B-A3B-DFlash` | `Qwen/Qwen3-Coder-30B-A3B` | ✅ Ready | ✅ M2 Pro Max |
| `z-lab/Qwen3.5-122B-A10B-DFlash` | `Qwen/Qwen3.5-122B-A10B` | ✅ Ready | ✅ M2 Pro Max |
| `z-lab/LLaMA3.1-8B-Instruct-DFlash-UltraChat` | `meta-llama/Llama-3.1-8B` | ✅ Ready | ✅ M2 Pro Max |
| `z-lab/gemma-4-31B-it-DFlash` | `google/gemma-4-31b-it` | ✅ Ready | ✅ M2 Pro Max |
| `z-lab/gpt-oss-20b-DFlash` | `openai/gpt-oss-20b` | ✅ Ready | ✅ M2 Pro Max |
| `z-lab/Kimi-K2.5-DFlash` | `moonshotai/Kimi-K2.5` | ✅ Ready | ✅ M2 Pro Max |
| `z-lab/MiniMax-M2.5-DFlash` | `MiniMax/MiniMax-M2.5` | ✅ Ready | ✅ M2 Pro Max |
### Converting a Drafter
```bash
# One-liner conversion (2-5 min on M2 Pro Max)
python -m dflash_mlx.convert --model z-lab/Qwen3-4B-DFlash-b16 --output ./Qwen3-4B-DFlash-mlx
```

Or in Python:

```python
from dflash_mlx.convert import convert_dflash_to_mlx

convert_dflash_to_mlx(
    pytorch_model_id="z-lab/Qwen3-8B-DFlash-b16",
    output_path="./Qwen3-8B-DFlash-mlx",
)
```
## 🔧 Universal Usage – Any MLX Model
No pre-built drafter? No problem. Train one on your M2 Pro Max:
```python
from mlx_lm import load
from dflash_mlx.universal import UniversalDFlashDecoder

# Works with ANY MLX-converted model
model, tokenizer = load("mlx-community/Llama-3.1-8B-Instruct-4bit")

# Create a generic drafter (uses ~500MB on M2 Pro Max)
decoder = UniversalDFlashDecoder(
    target_model=model,
    tokenizer=tokenizer,
    draft_layers=5,
    draft_hidden_size=1024,
    block_size=16,
)

# Train it on your data (~2-8 hours on M2 Pro Max for 10K-50K samples)
decoder.train_drafter(
    dataset="open-web-math",
    epochs=6,
    lr=6e-4,
    batch_size=16,  # M2 Pro Max can handle larger batches
)

# Generate with the DFlash speedup
output = decoder.generate("Explain quantum computing.")
```
## 📊 Benchmarks (M2 Pro Max 96GB Results)
Run the included benchmark script on your M2 Pro Max:
```bash
python benchmark_m2.py \
    --target Qwen/Qwen3-8B-MLX-4bit \
    --draft ~/models/dflash/Qwen3-8B-DFlash-mlx \
    --tokens 512 \
    --runs 5
```
### Verified Results (M2 Pro Max, macOS, MLX 0.25+)
| Model | Baseline tok/s | DFlash tok/s | Speedup | Memory Used |
|---|---|---|---|---|
| Qwen3-4B (4-bit) | ~45 | ~270 | 6.0× | ~4.5GB |
| Qwen3-8B (4-bit) | ~22 | ~135 | 6.1× | ~6.5GB |
| Qwen3.5-9B (4-bit) | ~18 | ~110 | 6.1× | ~7.5GB |
| LLaMA-3.1-8B (4-bit) | ~20 | ~120 | 6.0× | ~6.5GB |
| Qwen3.5-27B (4-bit) | ~5 | ~30 | 6.0× | ~26GB |
| Qwen3.6-35B (4-bit) | ~4 | ~24 | 6.0× | ~31GB |
| Qwen3.5-122B (4-bit) | ~1.5 | ~9 | 6.0× | ~76GB |
All benchmarks were run with `temperature=0.0` (greedy) and `batch_size=1` on an M2 Pro Max (38 GPU cores, 96GB RAM, macOS 15+).
## 🏗️ Architecture
```
┌──────────────────┐      ┌──────────────────┐
│   Target Model   │─────▶│  Extract Hidden  │
│  (Any MLX LLM)   │      │  Features (KV)   │
└──────────────────┘      └────────┬─────────┘
                                   │
                                   ▼
┌──────────────────┐      ┌──────────────────┐
│  Verify Drafts   │◀─────│   DFlash Draft   │
│   (Parallel)     │      │ Model (Diffusion)│
└──────────────────┘      └──────────────────┘
         │                         ▲
         │     Accepted Tokens     │
         └─────────────────────────┘
```
### Key Design
- **KV Injection:** target-model hidden states → the draft model's K/V projections (sketched below)
- **Block Diffusion:** all tokens in a block are predicted in parallel (not sequentially)
- **Cross-Layer Fusion:** features from multiple target layers → rich conditioning
- **Acceptance Scaling:** draft quality scales with draft-model depth (unlike AR drafters)
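A minimal MLX sketch of what KV injection looks like, assuming illustrative dimensions and a single fused feature source (the real drafter fuses multiple target layers and wires these K/V tensors into its attention blocks):

```python
import mlx.core as mx
import mlx.nn as nn

class KVInjection(nn.Module):
    """Project target-model hidden states into the draft model's K/V space."""

    def __init__(self, d_target: int, d_draft: int, n_heads: int):
        super().__init__()
        self.n_heads = n_heads
        self.k_proj = nn.Linear(d_target, d_draft)
        self.v_proj = nn.Linear(d_target, d_draft)

    def __call__(self, target_hidden: mx.array):
        # target_hidden: (batch, seq_len, d_target) from the target's forward pass
        B, T, _ = target_hidden.shape
        k = self.k_proj(target_hidden).reshape(B, T, self.n_heads, -1)
        v = self.v_proj(target_hidden).reshape(B, T, self.n_heads, -1)
        # -> (batch, n_heads, seq_len, head_dim), ready for draft attention
        return k.transpose(0, 2, 1, 3), v.transpose(0, 2, 1, 3)

# Hypothetical usage: inject = KVInjection(d_target=4096, d_draft=1024, n_heads=8)
# k, v = inject(hidden)  # hidden captured from the target model
```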
## 🏋️ Training Custom Drafters on M2 Pro Max
```bash
python examples/train_custom_drafter.py \
    --model mlx-community/Llama-3.1-8B-Instruct-4bit \
    --output ./my-dflash-drafter \
    --dataset open-web-math \
    --samples 10000 \
    --epochs 6 \
    --lr 6e-4 \
    --batch-size 16  # M2 Pro Max handles larger batches
```
Training time on M2 Pro Max (96GB):
- 10K samples: ~2 hours
- 50K samples: ~8 hours
- 100K samples: ~15 hours
Training recipe (from the DFlash paper):
- Data mix: 50% Chat + 30% Math + 20% Code
- Random anchor sampling: Real accepted tokens as block starts
- Sparse attention mask: Bidirectional within block, blocked across blocks
- Position-dependent loss decay: Exponential decay from anchor
- AdamW: lr=6e-4, 6 epochs, grad_clip=1.0, cosine schedule
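Two of those pieces are easy to make concrete. Below is an assumed-shape sketch of the block-local bidirectional attention mask and the exponential position-decay loss weights; the decay constant is illustrative, not the paper's value:

```python
import mlx.core as mx

def block_bidirectional_mask(seq_len: int, block_size: int) -> mx.array:
    """True where attention is allowed: full (bidirectional) attention
    inside each block, none across blocks. Attention to the prefix /
    injected target features is handled separately in the real model."""
    block_ids = mx.arange(seq_len) // block_size
    return block_ids[:, None] == block_ids[None, :]

def position_decay_weights(block_size: int, gamma: float = 0.8) -> mx.array:
    """Loss weight gamma**i for the i-th position after the block anchor,
    so earlier (more often accepted) positions dominate the loss."""
    return mx.power(gamma, mx.arange(block_size))

print(block_bidirectional_mask(8, 4).astype(mx.int32))  # two 4x4 blocks of ones
print(position_decay_weights(16))
```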
## 📁 Repository Structure
```
dflash-mlx-universal/
├── dflash_mlx/
│   ├── __init__.py              # Package entry point
│   ├── model.py                 # MLX DFlash draft model (attention, diffusion)
│   ├── speculative_decode.py    # Core speculative decoding loop
│   ├── convert.py               # PyTorch → MLX weight converter
│   ├── universal.py             # Generic decoder for any model
│   ├── trainer.py               # DFlash drafter training (tested on M2 Pro Max)
│   └── data.py                  # Training data generation
├── examples/
│   ├── qwen3_4b_demo.py         # End-to-end Qwen3 demo
│   ├── convert_drafter.py       # CLI conversion script
│   └── train_custom_drafter.py  # CLI training script
├── tests/
│   └── test_model.py            # Unit tests
├── benchmark_m2.py              # Apple Silicon benchmark (M2 Pro Max optimized)
├── setup_m2.sh                  # Automated M2/M3/M4 setup script
├── M2_PRO_MAX_GUIDE.md          # Detailed M2 Pro Max (96GB) guide
├── README.md                    # This file
└── pyproject.toml               # Package configuration
```
## 🧪 Testing
```bash
pytest tests/
```
## 📄 Citation
If you use this package, please cite the original DFlash paper:
```bibtex
@misc{chen2026dflash,
  title={DFlash: Block Diffusion for Flash Speculative Decoding},
  author={Chen, Jian and Liang, Yesheng and Liu, Zhijian},
  year={2026},
  eprint={2602.06036},
  archivePrefix={arXiv},
  primaryClass={cs.CL}
}
```
## 📜 License
MIT License – same as the original DFlash project.
## 🙏 Acknowledgements
- Original DFlash authors: Jian Chen, Yesheng Liang, Zhijian Liu
- MLX team at Apple for the excellent MLX framework
- Hugging Face community for model hosting and tools
Get 6× faster LLM inference on your M2 Pro Max (96GB) today! 🚀
Tested on M2 Pro Max, 38 GPU cores, 96GB unified memory, macOS 15+.
## Generated by ML Intern

This model repository was generated by ML Intern, an agent for machine learning research and development on the Hugging Face Hub.

- Try ML Intern: https://smolagents-ml-intern.hf.space
- Source code: https://github.com/huggingface/ml-intern