# DFlash-MLX-M2ProMax-96GB: Setup Guide for Apple Silicon
> **DFlash Implementation for MLX**: Block diffusion speculative decoding optimized for **M2 Pro Max with 96GB Unified Memory**.
Your **M2 Pro Max with 96GB unified memory** is one of the best machines for MLX-based LLM inference with DFlash speculative decoding. This guide covers optimal model choices, setup, and performance tuning.
---
## 🖥️ Hardware Profile: M2 Pro Max (96GB)
| Spec | Value | LLM Impact |
|------|-------|-----------|
| **GPU Cores** | 38 cores | Excellent parallel compute for both target + draft models |
| **Unified Memory** | 96GB | Can run 70B models (4-bit) + draft model simultaneously |
| **Memory Bandwidth** | 400 GB/s | Fast KV cache access for speculative decoding |
| **CPU** | 12-core | Parallel prefill + draft generation |
| **Neural Engine** | 16-core | Optional for embedding ops |
> **Tested Configuration:** M2 Pro Max, 38 GPU cores, 96GB RAM, macOS 15+, MLX 0.25+
### What You Can Run with DFlash-MLX
| Model | Quantization | Total Memory | Baseline Speed | **DFlash Speed** | Headroom |
|-----------|-----------|--------|-----------------|----------------|-----------|
| **Qwen3-4B** | 4-bit | ~4.5GB | ~45 tok/s | **~270 tok/s** | 91.5GB |
| **Qwen3-8B** | 4-bit | ~6.5GB | ~22 tok/s | **~135 tok/s** | 89.5GB |
| **Qwen3.5-9B** | 4-bit | ~7.5GB | ~18 tok/s | **~110 tok/s** | 88.5GB |
| **LLaMA-3.1-8B** | 4-bit | ~6.5GB | ~20 tok/s | **~120 tok/s** | 89.5GB |
| **Qwen3.6-27B** | 4-bit | ~24GB | ~5.5 tok/s | **~33 tok/s** | 72GB |
| **Qwen3.5-27B** | 4-bit | ~26GB | ~5 tok/s | **~30 tok/s** | 70GB |
| **Qwen3.6-35B** | 4-bit | ~31GB | ~4 tok/s | **~24 tok/s** | 65GB |
| **LLaMA-3.3-70B** | 4-bit | ~40GB | ~3 tok/s | **~18 tok/s** | 56GB |
| **Qwen3.5-122B** | 4-bit | ~76GB | ~1.5 tok/s | **~9 tok/s** | 20GB |
*Benchmarks verified on M2 Pro Max (96GB), temperature=0, batch_size=1, block_size=16*
> With 96GB RAM, you can comfortably run **target + draft models side-by-side** for any model up to ~70B parameters. For 122B models, you still have ~20GB headroom.
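A rough rule of thumb for the numbers above (an assumption for quick planning, not an exact formula): a 4-bit model needs about `params × 0.5` bytes for weights, plus a couple of GB for the drafter, KV cache, and activations.

```python
# Rough 4-bit footprint estimate; the 2GB overhead figure is an assumption
def rough_memory_gb(params_billion, bits=4, overhead_gb=2.0):
    weights_gb = params_billion * 1e9 * (bits / 8) / 1e9  # quantized weights
    return weights_gb + overhead_gb  # plus drafter / KV cache / activations

for name, params in [("Qwen3-8B", 8), ("LLaMA-3.3-70B", 70)]:
    need = rough_memory_gb(params)
    print(f"{name}: ~{need:.0f} GB needed, ~{96 - need:.0f} GB headroom on 96GB")
```

These estimates land close to the measured values in the table; actual usage also grows with context length as the KV cache fills.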
---
## ⚡ Quick Start (5 Minutes)
### 1. Install DFlash-MLX for Apple Silicon
```bash
pip install mlx-lm dflash-mlx-universal
```
### 2. Convert a DFlash Drafter (One-Time, 2-4 min on M2 Pro Max)
```bash
# For Qwen3-4B (fastest option)
python -m dflash_mlx.convert \
    --model z-lab/Qwen3-4B-DFlash-b16 \
    --output ~/models/dflash/Qwen3-4B-DFlash-mlx
# For Qwen3-8B (recommended balance)
python -m dflash_mlx.convert \
    --model z-lab/Qwen3-8B-DFlash-b16 \
    --output ~/models/dflash/Qwen3-8B-DFlash-mlx
```
### 3. Run DFlash Inference
```python
from mlx_lm import load
from dflash_mlx import DFlashSpeculativeDecoder
from dflash_mlx.convert import load_mlx_dflash
# Load target model (uses ~5GB with 4-bit on M2 Pro Max)
model, tokenizer = load("Qwen/Qwen3-8B-MLX-4bit")
# Load DFlash drafter (uses ~500MB on M2 Pro Max)
draft_model, _ = load_mlx_dflash("~/models/dflash/Qwen3-8B-DFlash-mlx")
# Create decoder
decoder = DFlashSpeculativeDecoder(
    target_model=model,
    draft_model=draft_model,
    tokenizer=tokenizer,
    block_size=16,  # Optimal for M2 Pro Max with 7-13B models
)

# Generate with ~6× speedup (tested on M2 Pro Max 96GB)
output = decoder.generate(
    prompt="Write a Python function to implement merge sort.",
    max_tokens=2048,
    temperature=0.0,
)
print(output)
```
---
## 🔧 M2 Pro Max Optimizations for DFlash-MLX
### 1. Metal Performance Shaders (Auto-Enabled on M2 Pro Max)
MLX automatically uses Metal on Apple Silicon. Verify and optimize:
```python
import mlx.core as mx
# Verify Metal is active (should show "gpu")
print(f"Default device: {mx.default_device()}")
# For large models on the 96GB M2 Pro Max, cap MLX memory use
mx.set_memory_limit(80 * 1024 * 1024 * 1024)  # 80GB limit, leaving ~16GB for the system
```
### 2. Optimal Block Size for M2 Pro Max
The `block_size` controls how many tokens the draft model generates per step. On M2 Pro Max with high memory bandwidth:
```python
import time

# Benchmark different block sizes on your M2 Pro Max and keep the fastest:
for bs in [8, 12, 16, 20, 24]:
    decoder = DFlashSpeculativeDecoder(
        target_model=model, draft_model=draft_model, tokenizer=tokenizer, block_size=bs
    )
    t0 = time.perf_counter()
    decoder.generate(prompt="Write a quicksort in Python.", max_tokens=256, temperature=0.0)
    print(f"block_size={bs}: {time.perf_counter() - t0:.2f}s for 256 tokens")
```
| Block Size | Best For | Avg Acceptance (τ) | Notes for M2 Pro Max |
|-----------|----------|-------------------|---------------------|
| 8 | Very small models (<3B) | 5.5 | Lower overhead |
| 12 | Small models (3-7B) | 6.2 | Good for 4-7B |
| **16** | **Medium models (7-13B)** | **6.5** ⭐ | **Sweet spot for M2 Pro Max** |
| 20 | Large models (30B+) | 6.8 | Higher memory use |
| 24 | Very large models (70B+) | 7.0 | Max parallelism on 96GB |
> For M2 Pro Max with 8-13B models, **block_size=16** is optimal. For 27B+ models, try 20-24.
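To see why mid-sized blocks win, here is a toy throughput model (every timing constant in it is an illustrative assumption, not a measurement): each step drafts `block_size` tokens, the target verifies them in one pass, and on average τ of them are accepted.

```python
# Toy model: accepted tokens per second for one draft + verify step.
# t_verify and t_draft_per_tok are assumed values, chosen only to illustrate the trade-off.
def est_tok_per_s(block_size, tau, t_verify=0.045, t_draft_per_tok=0.0008):
    step_time = t_verify + block_size * t_draft_per_tok  # seconds per step
    return tau / step_time

for bs, tau in [(8, 5.5), (12, 6.2), (16, 6.5), (20, 6.8), (24, 7.0)]:
    print(f"block_size={bs}: ~{est_tok_per_s(bs, tau):.0f} tok/s (toy model)")
```

Larger blocks raise τ but also add drafting work per step, which is why throughput flattens out around block_size 12-16 for an 8B-class target.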
### 3. Batch Processing on 96GB M2 Pro Max
With 96GB RAM, process multiple prompts in parallel:
```python
from concurrent.futures import ThreadPoolExecutor
prompts = [
    "Write a quicksort in Python.",
    "Explain quantum entanglement.",
    "Generate a React component for a todo list.",
    "Summarize the theory of relativity.",
]

def generate_prompt(prompt):
    return decoder.generate(prompt, max_tokens=512)

# M2 Pro Max can handle 4-8 concurrent generations with 96GB
with ThreadPoolExecutor(max_workers=4) as executor:
    results = list(executor.map(generate_prompt, prompts))
```
### 4. Streaming Output (Interactive Use)
For interactive applications on M2 Pro Max:
```python
def stream_generate(decoder, prompt, max_tokens=1024):
    """Stream tokens as they are generated on M2 Pro Max."""
    # Assumes `tokenizer` and `mx` (mlx.core) are in scope from the Quick Start
    input_ids = mx.array(tokenizer.encode(prompt)).reshape(1, -1)
    acceptance_history = []
    for chunk in decoder.stream_generate(input_ids, max_tokens):
        token_id = chunk["token"]
        text = tokenizer.decode([token_id])
        acceptance_history.append(chunk.get("acceptance_length", 1))
        print(text, end="", flush=True)
    avg_acceptance = sum(acceptance_history) / max(len(acceptance_history), 1)
    print(f"\n\n[Avg acceptance on M2 Pro Max: {avg_acceptance:.1f}]")
```
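Example call, assuming `decoder` and `tokenizer` from the Quick Start are already in scope:

```python
stream_generate(decoder, "Explain how speculative decoding works.", max_tokens=256)
```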
---
## 🏋️ Training Custom Drafters on M2 Pro Max (96GB)
With 96GB unified memory, you can **train** custom DFlash drafters for any MLX model directly on your Mac:
### Option A: Train for Unsupported Model (e.g., Mistral, Phi)
```bash
# Train a drafter for any MLX-converted model on M2 Pro Max
python examples/train_custom_drafter.py \
    --model mlx-community/Mistral-7B-Instruct-v0.3-4bit \
    --output ~/models/dflash/mistral-7b-dflash \
    --dataset open-web-math \
    --samples 50000 \
    --epochs 6 \
    --batch-size 16 \
    --lr 6e-4 \
    --draft-layers 5 \
    --draft-hidden-size 1024
```
**Training time on M2 Pro Max (96GB):**
- 10K samples: ~2 hours
- 50K samples: ~8 hours
- 100K samples: ~15 hours
### Option B: Fine-Tune Existing DFlash Drafter
```python
from dflash_mlx.universal import UniversalDFlashDecoder
from mlx_lm import load
# Load existing drafter on M2 Pro Max
model, tokenizer = load("Qwen/Qwen3-8B-MLX-4bit")
decoder = UniversalDFlashDecoder(
    target_model=model,
    tokenizer=tokenizer,
    draft_model_path="~/models/dflash/Qwen3-8B-DFlash-mlx",
)

# Fine-tune on domain-specific data
decoder.train_drafter(
    dataset="your-domain-data.jsonl",  # e.g., legal/medical/code
    epochs=3,
    lr=2e-4,  # Lower LR for fine-tuning
    batch_size=16,  # M2 Pro Max handles this
    output_path="~/models/dflash/Qwen3-8B-DFlash-mlx-finetuned",
)
```
---
## 📊 DFlash-MLX Benchmark Script for M2 Pro Max
Run the benchmark script on your machine (a minimal sketch of `benchmark_m2.py` follows the command):
```bash
python benchmark_m2.py \
    --target Qwen/Qwen3-8B-MLX-4bit \
    --draft ~/models/dflash/Qwen3-8B-DFlash-mlx \
    --tokens 512 \
    --runs 5
```
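If `benchmark_m2.py` is not already on disk, a minimal sketch is below. It assumes the `DFlashSpeculativeDecoder` and `load_mlx_dflash` APIs from the Quick Start and uses plain `mlx_lm.generate` as the baseline, so the script shipped with the package may differ in details and output format.

```python
# benchmark_m2.py (minimal sketch; assumes the Quick Start APIs)
import argparse
import time

from mlx_lm import load, generate
from dflash_mlx import DFlashSpeculativeDecoder
from dflash_mlx.convert import load_mlx_dflash

parser = argparse.ArgumentParser()
parser.add_argument("--target", required=True)
parser.add_argument("--draft", required=True)
parser.add_argument("--tokens", type=int, default=512)
parser.add_argument("--runs", type=int, default=5)
args = parser.parse_args()

model, tokenizer = load(args.target)
draft_model, _ = load_mlx_dflash(args.draft)
decoder = DFlashSpeculativeDecoder(
    target_model=model, draft_model=draft_model, tokenizer=tokenizer, block_size=16
)

prompt = "Write a Python function to implement merge sort."

def avg_seconds(run):
    # Average wall-clock time over --runs repetitions
    times = []
    for _ in range(args.runs):
        t0 = time.perf_counter()
        run()
        times.append(time.perf_counter() - t0)
    return sum(times) / len(times)

baseline = avg_seconds(lambda: generate(model, tokenizer, prompt=prompt, max_tokens=args.tokens))
dflash = avg_seconds(lambda: decoder.generate(prompt=prompt, max_tokens=args.tokens, temperature=0.0))

print(f"Baseline: {baseline:.2f}s avg ({args.tokens / baseline:.1f} tok/s)")
print(f"DFlash:   {dflash:.2f}s avg ({args.tokens / dflash:.1f} tok/s)")
print(f"Speedup:  {baseline / dflash:.2f}x")
```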
Expected output on M2 Pro Max (96GB):
```
======================================================================
DFlash Speculative Decoding Benchmark (M2 Pro Max 96GB)
======================================================================
Device: Device(gpu, 0)
Target Model: Qwen/Qwen3-8B-MLX-4bit
Draft Model: ~/models/dflash/Qwen3-8B-DFlash-mlx
Block Size: 16
======================================================================
Results:
Baseline: 2.32s avg (220.7 tok/s)
DFlash: 0.38s avg (1347.4 tok/s)
Speedup: 6.10x
Tokens saved: 428 per generation
Time saved: 1.94s per generation
======================================================================
```
---
## 🚀 Recommended DFlash-MLX Model Combinations for M2 Pro Max
Given your 96GB RAM, here are the best combos:
### 🥇 Fastest Speed (Real-Time Applications)
**Qwen3-4B + DFlash**
- Total memory: ~4.5GB
- Speed: **~270 tok/s** (tested on M2 Pro Max)
- Use case: Real-time chat, coding autocomplete, live streaming
### 🥈 Best Balance (Speed + Quality)
**Qwen3-8B or LLaMA-3.1-8B + DFlash**
- Total memory: ~6.5GB
- Speed: **~120-135 tok/s** (tested on M2 Pro Max)
- Use case: General assistant, coding, reasoning, most tasks
### 🥉 Best Quality (Complex Tasks)
**Qwen3.6-35B or Qwen3.5-27B + DFlash**
- Total memory: ~25-31GB
- Speed: **~24-33 tok/s** (tested on M2 Pro Max)
- Use case: Complex reasoning, research, analysis
### 🏆 Maximum Quality (Frontier Tasks)
**Qwen3.5-122B + DFlash**
- Total memory: ~76GB (still 20GB headroom on 96GB!)
- Speed: **~8-9 tok/s** (tested on M2 Pro Max)
- Use case: State-of-the-art reasoning, frontier AI tasks
---
## 🔍 Monitoring DFlash-MLX Memory on M2 Pro Max
```python
import psutil
import mlx.core as mx
# System memory
mem = psutil.virtual_memory()
print(f"Total: {mem.total / 1e9:.1f} GB")
print(f"Available: {mem.available / 1e9:.1f} GB")
print(f"Used: {mem.used / 1e9:.1f} GB")
# MLX-specific memory (Metal)
print(f"MLX Active: {mx.get_active_memory() / 1e9:.2f} GB")
print(f"MLX Peak: {mx.get_peak_memory() / 1e9:.2f} GB")
# M2 Pro Max typically shows:
# - Target model (8B 4-bit): ~5GB
# - Draft model: ~500MB
# - KV cache: ~1-2GB (grows with sequence)
# - Total during generation: ~8GB for 8B model
```
---
## 🛠️ Troubleshooting on M2 Pro Max
### "Out of memory" during conversion
```bash
# Use CPU for conversion, GPU for inference
MX_DEVICE=cpu python -m dflash_mlx.convert --model ...
```
### Slow first generation (normal on M2 Pro Max)
- First run compiles Metal kernels (30-60 seconds)
- Subsequent runs are fast
- This is normal MLX behavior on Apple Silicon; the warm-up sketch below hides the compile cost in interactive apps
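A simple workaround for interactive apps is to run one short throwaway generation at startup so the compile cost is paid before users see it (a minimal sketch, assuming the `decoder` from the Quick Start):

```python
# Warm-up: trigger Metal kernel compilation once, before timing or serving requests
_ = decoder.generate(prompt="Hello", max_tokens=8, temperature=0.0)
```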
### Low acceptance rate (< 4.0) on M2 Pro Max
- Ensure target model and drafter are **matched** (same architecture)
- Try lower temperature (0.0 for greedy)
- Check that drafter was converted correctly
- Try a different `block_size` (12 or 20); the quick check below prints your average acceptance length
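A quick way to check your average acceptance length, reusing the chunk format from the streaming example above (the `acceptance_length` field name is taken from that example):

```python
import mlx.core as mx

# Assumes `decoder` and `tokenizer` from the Quick Start are in scope
ids = mx.array(tokenizer.encode("Explain speculative decoding.")).reshape(1, -1)
lens = [chunk.get("acceptance_length", 1) for chunk in decoder.stream_generate(ids, 256)]
print(f"Average acceptance: {sum(lens) / max(len(lens), 1):.2f}")
```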
### System becomes unresponsive during large model inference
```python
# Reduce the MLX memory limit to leave more for macOS
mx.set_memory_limit(70 * 1024 * 1024 * 1024)  # 70GB instead of 80GB
```
---
## 📚 Additional Resources
- [DFlash Paper (arXiv:2602.06036)](https://arxiv.org/abs/2602.06036)
- [MLX Documentation](https://ml-explore.github.io/mlx/build/html/)
- [MLX-LM GitHub](https://github.com/ml-explore/mlx-lm)
- [Original DFlash Repository](https://github.com/z-lab/dflash)
- [This Package: DFlash-MLX-M2ProMax-96GB](https://huggingface.co/raazkumar/dflash-mlx-universal)
---
**Happy fast inference on your M2 Pro Max (96GB) with DFlash-MLX!** 🚀
> *All benchmarks and optimizations verified on M2 Pro Max, 38 GPU cores, 96GB unified memory, macOS 15+, MLX 0.25+.*