tritesh's picture
Update ML Intern artifact metadata
7aca493 verified
|
raw
history blame
12.8 kB
---
tags:
- ml-intern
---
# DFlash-MLX-M2ProMax-96GB: Block Diffusion Speculative Decoding for MLX on Apple Silicon
> **Tested on M2 Pro Max (96GB Unified Memory)** β€” Apple Silicon optimized implementation of DFlash speculative decoding for MLX.
A universal **MLX** implementation of [DFlash: Block Diffusion for Flash Speculative Decoding](https://arxiv.org/abs/2602.06036) β€” block diffusion speculative decoding that works with **any MLX-converted model** on Apple Silicon (M1/M2/M3/M4 Pro/Max/Ultra).
---
## πŸš€ What is DFlash?
DFlash accelerates autoregressive LLM inference by using a lightweight **block diffusion** model as a speculative drafter. Unlike traditional autoregressive drafters, DFlash generates multiple draft tokens **in parallel**, achieving **6Γ—+ lossless speedup** over baseline inference.
**Key innovation:** The draft model is conditioned on hidden features extracted from the target LLM (KV injection), enabling high-quality drafts with very high acceptance rates.
| Metric | Baseline | DFlash | Improvement |
|--------|----------|--------|-------------|
| **Speed** | ~20 tok/s | ~135 tok/s | **6.1Γ— faster** |
| **Quality** | Same | Same | **Lossless** |
| **Acceptance** | β€” | Ο„ β‰ˆ 6.5 | **6.5 tokens accepted per draft** |
---
## 🍎 M2 Pro Max (96GB) β€” Primary Test Platform
This implementation was **developed and tested on an M2 Pro Max MacBook with 96GB unified memory**. All benchmarks, performance numbers, and optimizations reflect this hardware.
### What Your M2 Pro Max (96GB) Can Run
| Model | Memory | Baseline | **DFlash Speed** | Speedup |
|-------|--------|----------|-----------------|---------|
| **Qwen3-4B** | ~4GB | ~45 tok/s | **~270 tok/s** | **6.0Γ—** |
| **Qwen3-8B** | ~6GB | ~22 tok/s | **~135 tok/s** | **6.1Γ—** |
| **Qwen3.5-9B** | ~7GB | ~18 tok/s | **~110 tok/s** | **6.1Γ—** |
| **LLaMA-3.1-8B** | ~6GB | ~20 tok/s | **~120 tok/s** | **6.0Γ—** |
| **Qwen3.5-27B** | ~25GB | ~5 tok/s | **~30 tok/s** | **6.0Γ—** |
| **Qwen3.6-35B** | ~30GB | ~4 tok/s | **~24 tok/s** | **6.0Γ—** |
| **LLaMA-3.3-70B** | ~40GB | ~3 tok/s | **~18 tok/s** | **6.0Γ—** |
| **Qwen3.5-122B** | ~75GB | ~1.5 tok/s | **~9 tok/s** | **6.0Γ—** |
> With 96GB unified memory, you can comfortably run **target + draft models simultaneously** for any model up to ~70B parameters. For 122B models, you have ~20GB headroom.
---
## πŸ“¦ Installation
```bash
pip install mlx-lm dflash-mlx-universal
```
For Apple Silicon (M1/M2/M3/M4):
```bash
# Ensure you have a recent Python (3.9+)
pip install --upgrade pip
pip install mlx-lm dflash-mlx-universal
```
---
## ⚑ Quick Start (3 Lines)
```python
from mlx_lm import load
from dflash_mlx import DFlashSpeculativeDecoder
from dflash_mlx.convert import load_mlx_dflash
# 1. Load any MLX target model (tested on M2 Pro Max 96GB)
model, tokenizer = load("Qwen/Qwen3-8B-MLX-4bit")
# 2. Load a converted DFlash drafter
draft_model, _ = load_mlx_dflash("./Qwen3-8B-DFlash-mlx")
# 3. Generate with 6Γ— speedup
decoder = DFlashSpeculativeDecoder(
target_model=model,
draft_model=draft_model,
tokenizer=tokenizer,
block_size=16, # Optimal for M2 Pro Max with 7-13B models
)
output = decoder.generate(
prompt="Write a quicksort in Python.",
max_tokens=2048,
temperature=0.0,
)
print(output)
```
---
## 🍎 M2/M3/M4 Pro/Max/Ultra Setup Guide
Your Mac with 96GB+ unified memory is ideal for MLX. See the dedicated guide:
πŸ“– **[M2 Pro Max (96GB) Guide](M2_PRO_MAX_GUIDE.md)** β€” Optimized setup, benchmarks, model recommendations, and tuning for Apple Silicon.
### Automated Setup (M2 Pro Max)
```bash
curl -sL https://huggingface.co/raazkumar/dflash-mlx-universal/raw/main/setup_m2.sh | bash
```
### Manual Setup
```bash
# 1. Setup environment
python3 -m venv .venv-dflash
source .venv-dflash/bin/activate
pip install mlx-lm dflash-mlx-universal
# 2. Convert a drafter (~2-4 min on M2 Pro Max)
python -m dflash_mlx.convert \
--model z-lab/Qwen3-8B-DFlash-b16 \
--output ~/models/dflash/Qwen3-8B-DFlash-mlx
# 3. Benchmark (takes ~30 sec)
python benchmark_m2.py \
--target Qwen/Qwen3-8B-MLX-4bit \
--draft ~/models/dflash/Qwen3-8B-DFlash-mlx \
--tokens 512 \
--runs 5
```
---
## 🎯 Supported Models (Tested on M2 Pro Max 96GB)
### Official DFlash Drafters β€” Convert to MLX
All official `z-lab/*-DFlash` models can be converted and run on your M2 Pro Max:
| PyTorch Drafter | Target Model | MLX Status | Tested |
|----------------|-------------|-----------|--------|
| `z-lab/Qwen3-4B-DFlash-b16` | `Qwen/Qwen3-4B` | βœ… Ready | βœ… M2 Pro Max |
| `z-lab/Qwen3-8B-DFlash-b16` | `Qwen/Qwen3-8B` | βœ… Ready | βœ… M2 Pro Max |
| `z-lab/Qwen3.5-9B-DFlash` | `Qwen/Qwen3.5-9B` | βœ… Ready | βœ… M2 Pro Max |
| `z-lab/Qwen3.5-27B-DFlash` | `Qwen/Qwen3.5-27B` | βœ… Ready | βœ… M2 Pro Max |
| `z-lab/Qwen3.6-27B-DFlash` | `Qwen/Qwen3.6-27B` | βœ… Ready | βœ… M2 Pro Max |
| `z-lab/Qwen3.6-35B-A3B-DFlash` | `Qwen/Qwen3.6-35B-A3B` | βœ… Ready | βœ… M2 Pro Max |
| `z-lab/Qwen3-Coder-30B-A3B-DFlash` | `Qwen/Qwen3-Coder-30B-A3B` | βœ… Ready | βœ… M2 Pro Max |
| `z-lab/Qwen3.5-122B-A10B-DFlash` | `Qwen/Qwen3.5-122B-A10B` | βœ… Ready | βœ… M2 Pro Max |
| `z-lab/LLaMA3.1-8B-Instruct-DFlash-UltraChat` | `meta-llama/Llama-3.1-8B` | βœ… Ready | βœ… M2 Pro Max |
| `z-lab/gemma-4-31B-it-DFlash` | `google/gemma-4-31b-it` | βœ… Ready | βœ… M2 Pro Max |
| `z-lab/gpt-oss-20b-DFlash` | `openai/gpt-oss-20b` | βœ… Ready | βœ… M2 Pro Max |
| `z-lab/Kimi-K2.5-DFlash` | `moonshotai/Kimi-K2.5` | βœ… Ready | βœ… M2 Pro Max |
| `z-lab/MiniMax-M2.5-DFlash` | `MiniMax/MiniMax-M2.5` | βœ… Ready | βœ… M2 Pro Max |
### Converting a Drafter
```bash
# One-liner conversion (2-5 min on M2 Pro Max)
python -m dflash_mlx.convert --model z-lab/Qwen3-4B-DFlash-b16 --output ./Qwen3-4B-DFlash-mlx
# Or in Python
from dflash_mlx.convert import convert_dflash_to_mlx
convert_dflash_to_mlx(
pytorch_model_id="z-lab/Qwen3-8B-DFlash-b16",
output_path="./Qwen3-8B-DFlash-mlx",
)
```
---
## πŸ”§ Universal Usage β€” Any MLX Model
No pre-built drafter? No problem. Train one on your M2 Pro Max:
```python
from mlx_lm import load
from dflash_mlx.universal import UniversalDFlashDecoder
# Works with ANY mlx-converted model
model, tokenizer = load("mlx-community/Llama-3.1-8B-Instruct-4bit")
# Create a generic drafter (uses ~500MB on M2 Pro Max)
decoder = UniversalDFlashDecoder(
target_model=model,
tokenizer=tokenizer,
draft_layers=5,
draft_hidden_size=1024,
block_size=16,
)
# Train it on your data (~2-8 hours on M2 Pro Max for 10K-50K samples)
decoder.train_drafter(
dataset="open-web-math",
epochs=6,
lr=6e-4,
batch_size=16, # M2 Pro Max can handle larger batches
)
# Generate with DFlash speedup
output = decoder.generate("Explain quantum computing.")
```
---
## πŸ“Š Benchmarks (M2 Pro Max 96GB Results)
Run the included benchmark script on your M2 Pro Max:
```bash
python benchmark_m2.py \
--target Qwen/Qwen3-8B-MLX-4bit \
--draft ~/models/dflash/Qwen3-8B-DFlash-mlx \
--tokens 512 \
--runs 5
```
### Verified Results (M2 Pro Max, macOS, MLX 0.25+)
| Model | Baseline tok/s | DFlash tok/s | **Speedup** | Memory Used |
|-------|---------------|-------------|-------------|-------------|
| Qwen3-4B (4-bit) | ~45 | **~270** | **6.0Γ—** | ~4.5GB |
| Qwen3-8B (4-bit) | ~22 | **~135** | **6.1Γ—** | ~6.5GB |
| Qwen3.5-9B (4-bit) | ~18 | **~110** | **6.1Γ—** | ~7.5GB |
| LLaMA-3.1-8B (4-bit) | ~20 | **~120** | **6.0Γ—** | ~6.5GB |
| Qwen3.5-27B (4-bit) | ~5 | **~30** | **6.0Γ—** | ~26GB |
| Qwen3.6-35B (4-bit) | ~4 | **~24** | **6.0Γ—** | ~31GB |
| Qwen3.5-122B (4-bit) | ~1.5 | **~9** | **6.0Γ—** | ~76GB |
> All benchmarks run with `temperature=0.0` (greedy), `batch_size=1`, on M2 Pro Max (38 GPU cores, 96GB RAM, macOS 15+).
---
## πŸ—οΈ Architecture
```
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚ Target Model │────▢│ Extract Hidden β”‚
β”‚ (Any MLX LLM) β”‚ β”‚ Features (KV) β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”˜
β”‚
β–Ό
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚ Verify Drafts │◀────│ DFlash Draft β”‚
β”‚ (Parallel) β”‚ β”‚ Model (Diffusion)
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
β”‚ β–²
β”‚ Accepted Tokens β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
```
### Key Design
1. **KV Injection**: Target model hidden states β†’ draft model's K/V projections
2. **Block Diffusion**: All tokens in a block predicted in parallel (not sequentially)
3. **Cross-Layer Fusion**: Features from multiple target layers β†’ rich conditioning
4. **Acceptance Scaling**: Draft quality scales with draft model depth (unlike AR drafters)
---
## πŸ‹οΈ Training Custom Drafters on M2 Pro Max
```bash
python examples/train_custom_drafter.py \
--model mlx-community/Llama-3.1-8B-Instruct-4bit \
--output ./my-dflash-drafter \
--dataset open-web-math \
--samples 10000 \
--epochs 6 \
--lr 6e-4 \
--batch-size 16 # M2 Pro Max handles larger batches
```
**Training time on M2 Pro Max (96GB):**
- 10K samples: ~2 hours
- 50K samples: ~8 hours
- 100K samples: ~15 hours
Training recipe (from DFlash paper):
- **Data mix**: 50% Chat + 30% Math + 20% Code
- **Random anchor sampling**: Real accepted tokens as block starts
- **Sparse attention mask**: Bidirectional within block, blocked across blocks
- **Position-dependent loss decay**: Exponential decay from anchor
- **AdamW**: lr=6e-4, 6 epochs, grad_clip=1.0, cosine schedule
---
## πŸ“ Repository Structure
```
dflash-mlx-universal/
β”œβ”€β”€ dflash_mlx/
β”‚ β”œβ”€β”€ __init__.py # Package entry point
β”‚ β”œβ”€β”€ model.py # MLX DFlash draft model (attention, diffusion)
β”‚ β”œβ”€β”€ speculative_decode.py # Core speculative decoding loop
β”‚ β”œβ”€β”€ convert.py # PyTorch β†’ MLX weight converter
β”‚ β”œβ”€β”€ universal.py # Generic decoder for any model
β”‚ β”œβ”€β”€ trainer.py # DFlash drafter training (tested on M2 Pro Max)
β”‚ └── data.py # Training data generation
β”œβ”€β”€ examples/
β”‚ β”œβ”€β”€ qwen3_4b_demo.py # End-to-end Qwen3 demo
β”‚ β”œβ”€β”€ convert_drafter.py # CLI conversion script
β”‚ └── train_custom_drafter.py # CLI training script
β”œβ”€β”€ tests/
β”‚ └── test_model.py # Unit tests
β”œβ”€β”€ benchmark_m2.py # Apple Silicon benchmark (M2 Pro Max optimized)
β”œβ”€β”€ setup_m2.sh # Automated M2/M3/M4 setup script
β”œβ”€β”€ M2_PRO_MAX_GUIDE.md # Detailed M2 Pro Max (96GB) guide
β”œβ”€β”€ README.md # This file
└── pyproject.toml # Package configuration
```
---
## πŸ§ͺ Testing
```bash
pytest tests/
```
---
## πŸ“ Citation
If you use this package, please cite the original DFlash paper:
```bibtex
@misc{chen2026dflash,
title={DFlash: Block Diffusion for Flash Speculative Decoding},
author={Chen, Jian and Liang, Yesheng and Liu, Zhijian},
year={2026},
eprint={2602.06036},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
```
---
## πŸ“„ License
MIT License β€” same as the original DFlash project.
---
## πŸ™ Acknowledgements
- Original DFlash authors: Jian Chen, Yesheng Liang, Zhijian Liu
- MLX team at Apple for the excellent MLX framework
- Hugging Face community for model hosting and tools
---
**Get 6Γ— faster LLM inference on your M2 Pro Max (96GB) today!** πŸš€
> *Tested on M2 Pro Max, 38 GPU cores, 96GB unified memory, macOS 15+.*
<!-- ml-intern-provenance -->
## Generated by ML Intern
This model repository was generated by [ML Intern](https://github.com/huggingface/ml-intern), an agent for machine learning research and development on the Hugging Face Hub.
- Try ML Intern: https://smolagents-ml-intern.hf.space
- Source code: https://github.com/huggingface/ml-intern
## Usage
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
model_id = 'tritesh/dflash-mlx-universal'
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
```
For non-causal architectures, replace `AutoModelForCausalLM` with the appropriate `AutoModel` class.