---
tags:
  - ml-intern
---

DFlash-MLX-M2ProMax-96GB: Block Diffusion Speculative Decoding for MLX on Apple Silicon

Tested on an M2 Pro Max (96GB unified memory): an Apple Silicon-optimized implementation of DFlash speculative decoding for MLX.

A universal MLX implementation of DFlash (Block Diffusion for Flash Speculative Decoding): block diffusion speculative decoding that works with any MLX-converted model on Apple Silicon (M1/M2/M3/M4 Pro/Max/Ultra).


🚀 What is DFlash?

DFlash accelerates autoregressive LLM inference by using a lightweight block diffusion model as a speculative drafter. Unlike traditional autoregressive drafters, DFlash generates multiple draft tokens in parallel, achieving 6×+ lossless speedup over baseline inference.

Key innovation: The draft model is conditioned on hidden features extracted from the target LLM (KV injection), enabling high-quality drafts with very high acceptance rates.

| Metric     | Baseline  | DFlash     | Improvement                   |
|------------|-----------|------------|-------------------------------|
| Speed      | ~20 tok/s | ~135 tok/s | 6.1× faster                   |
| Quality    | Same      | Same       | Lossless                      |
| Acceptance | n/a       | τ ≈ 6.5    | 6.5 tokens accepted per draft |
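To make the accept/reject mechanics concrete, here is a minimal greedy-verification sketch in plain Python. The toy token lists stand in for real model outputs; this is an illustration of the idea, not the package's actual decoding loop:

```python
def accept_prefix(draft_block, target_preds):
    """Greedy verification: accept draft tokens until the first
    position where the drafter disagrees with the target model."""
    accepted = []
    for d, t in zip(draft_block, target_preds):
        if d != t:
            break
        accepted.append(d)
    return accepted

# Toy example: the drafter matches the target on 6 of 8 positions,
# in line with the tau ~ 6.5 acceptance figure above.
draft_block  = [11, 12, 13, 14, 15, 16, 99, 98]
target_preds = [11, 12, 13, 14, 15, 16, 17, 18]
accepted = accept_prefix(draft_block, target_preds)
print(len(accepted))  # -> 6
```

In the real loop the target's own token at the first mismatch is also emitted, so every verification step yields at least one token even when the whole draft is rejected; with τ ≈ 6.5 accepted tokens per block, each target forward pass is amortized over several output tokens.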

🍎 M2 Pro Max (96GB): Primary Test Platform

This implementation was developed and tested on an M2 Pro Max MacBook with 96GB unified memory. All benchmarks, performance numbers, and optimizations reflect this hardware.

What Your M2 Pro Max (96GB) Can Run

| Model         | Memory | Baseline   | DFlash Speed | Speedup |
|---------------|--------|------------|--------------|---------|
| Qwen3-4B      | ~4GB   | ~45 tok/s  | ~270 tok/s   | 6.0×    |
| Qwen3-8B      | ~6GB   | ~22 tok/s  | ~135 tok/s   | 6.1×    |
| Qwen3.5-9B    | ~7GB   | ~18 tok/s  | ~110 tok/s   | 6.1×    |
| LLaMA-3.1-8B  | ~6GB   | ~20 tok/s  | ~120 tok/s   | 6.0×    |
| Qwen3.5-27B   | ~25GB  | ~5 tok/s   | ~30 tok/s    | 6.0×    |
| Qwen3.6-35B   | ~30GB  | ~4 tok/s   | ~24 tok/s    | 6.0×    |
| LLaMA-3.3-70B | ~40GB  | ~3 tok/s   | ~18 tok/s    | 6.0×    |
| Qwen3.5-122B  | ~75GB  | ~1.5 tok/s | ~9 tok/s     | 6.0×    |

With 96GB unified memory, you can comfortably run target + draft models simultaneously for any model up to ~70B parameters. For 122B models, you have ~20GB headroom.
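As a rough sanity check, the fit claim reduces to back-of-envelope arithmetic. In this sketch the 8GB reserve for macOS and KV-cache growth is an assumption, and the footprints are the table's approximate 4-bit figures, not measurements:

```python
# Rough memory-fit check for target + drafter on a 96GB machine.
UNIFIED_MEMORY_GB = 96
RESERVE_GB = 8  # assumed headroom for macOS + KV cache growth

def fits(target_gb, draft_gb=0.5):
    """Return (fits?, remaining GB) for a target model plus a small drafter."""
    headroom = UNIFIED_MEMORY_GB - RESERVE_GB - target_gb - draft_gb
    return headroom > 0, round(headroom, 1)

print(fits(40))   # LLaMA-3.3-70B  -> (True, 47.5)
print(fits(75))   # Qwen3.5-122B   -> (True, 12.5)
```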


📦 Installation

pip install mlx-lm dflash-mlx-universal

For Apple Silicon (M1/M2/M3/M4):

# Ensure you have a recent Python (3.9+)
pip install --upgrade pip
pip install mlx-lm dflash-mlx-universal

⚡ Quick Start (3 Steps)

from mlx_lm import load
from dflash_mlx import DFlashSpeculativeDecoder
from dflash_mlx.convert import load_mlx_dflash

# 1. Load any MLX target model (tested on M2 Pro Max 96GB)
model, tokenizer = load("Qwen/Qwen3-8B-MLX-4bit")

# 2. Load a converted DFlash drafter
draft_model, _ = load_mlx_dflash("./Qwen3-8B-DFlash-mlx")

# 3. Generate with 6× speedup
decoder = DFlashSpeculativeDecoder(
    target_model=model,
    draft_model=draft_model,
    tokenizer=tokenizer,
    block_size=16,  # Optimal for M2 Pro Max with 7-13B models
)

output = decoder.generate(
    prompt="Write a quicksort in Python.",
    max_tokens=2048,
    temperature=0.0,
)
print(output)

🍎 M2/M3/M4 Pro/Max/Ultra Setup Guide

Your Mac with 96GB+ unified memory is ideal for MLX. See the dedicated guide:

📖 M2 Pro Max (96GB) Guide: Optimized setup, benchmarks, model recommendations, and tuning for Apple Silicon.

Automated Setup (M2 Pro Max)

curl -sL https://huggingface.co/raazkumar/dflash-mlx-universal/raw/main/setup_m2.sh | bash

Manual Setup

# 1. Setup environment
python3 -m venv .venv-dflash
source .venv-dflash/bin/activate
pip install mlx-lm dflash-mlx-universal

# 2. Convert a drafter (~2-4 min on M2 Pro Max)
python -m dflash_mlx.convert \
    --model z-lab/Qwen3-8B-DFlash-b16 \
    --output ~/models/dflash/Qwen3-8B-DFlash-mlx

# 3. Benchmark (takes ~30 sec)
python benchmark_m2.py \
    --target Qwen/Qwen3-8B-MLX-4bit \
    --draft ~/models/dflash/Qwen3-8B-DFlash-mlx \
    --tokens 512 \
    --runs 5

🎯 Supported Models (Tested on M2 Pro Max 96GB)

Official DFlash Drafters: Convert to MLX

All official z-lab/*-DFlash models can be converted and run on your M2 Pro Max:

| PyTorch Drafter | Target Model | MLX Status | Tested |
|---|---|---|---|
| z-lab/Qwen3-4B-DFlash-b16 | Qwen/Qwen3-4B | ✅ Ready | ✅ M2 Pro Max |
| z-lab/Qwen3-8B-DFlash-b16 | Qwen/Qwen3-8B | ✅ Ready | ✅ M2 Pro Max |
| z-lab/Qwen3.5-9B-DFlash | Qwen/Qwen3.5-9B | ✅ Ready | ✅ M2 Pro Max |
| z-lab/Qwen3.5-27B-DFlash | Qwen/Qwen3.5-27B | ✅ Ready | ✅ M2 Pro Max |
| z-lab/Qwen3.6-27B-DFlash | Qwen/Qwen3.6-27B | ✅ Ready | ✅ M2 Pro Max |
| z-lab/Qwen3.6-35B-A3B-DFlash | Qwen/Qwen3.6-35B-A3B | ✅ Ready | ✅ M2 Pro Max |
| z-lab/Qwen3-Coder-30B-A3B-DFlash | Qwen/Qwen3-Coder-30B-A3B | ✅ Ready | ✅ M2 Pro Max |
| z-lab/Qwen3.5-122B-A10B-DFlash | Qwen/Qwen3.5-122B-A10B | ✅ Ready | ✅ M2 Pro Max |
| z-lab/LLaMA3.1-8B-Instruct-DFlash-UltraChat | meta-llama/Llama-3.1-8B | ✅ Ready | ✅ M2 Pro Max |
| z-lab/gemma-4-31B-it-DFlash | google/gemma-4-31b-it | ✅ Ready | ✅ M2 Pro Max |
| z-lab/gpt-oss-20b-DFlash | openai/gpt-oss-20b | ✅ Ready | ✅ M2 Pro Max |
| z-lab/Kimi-K2.5-DFlash | moonshotai/Kimi-K2.5 | ✅ Ready | ✅ M2 Pro Max |
| z-lab/MiniMax-M2.5-DFlash | MiniMax/MiniMax-M2.5 | ✅ Ready | ✅ M2 Pro Max |

Converting a Drafter

# One-liner conversion (2-5 min on M2 Pro Max)
python -m dflash_mlx.convert --model z-lab/Qwen3-4B-DFlash-b16 --output ./Qwen3-4B-DFlash-mlx

# Or in Python
from dflash_mlx.convert import convert_dflash_to_mlx

convert_dflash_to_mlx(
    pytorch_model_id="z-lab/Qwen3-8B-DFlash-b16",
    output_path="./Qwen3-8B-DFlash-mlx",
)
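Under the hood, converting a checkpoint is largely a matter of renaming parameters and changing tensor layout. Here is a hypothetical sketch of the renaming step; the key patterns are illustrative only, and the real converter in dflash_mlx.convert may use a different mapping entirely:

```python
import re

# Hypothetical PyTorch -> MLX parameter-name remapping rules
# (illustrative patterns, not the package's actual mapping).
RENAMES = [
    (r"^model\.layers\.(\d+)\.self_attn\.", r"layers.\1.attention."),
    (r"^model\.embed_tokens\.", r"embed_tokens."),
]

def remap_key(key):
    """Apply each rename rule in order; keys with no match pass through."""
    for pattern, repl in RENAMES:
        key = re.sub(pattern, repl, key)
    return key

print(remap_key("model.layers.3.self_attn.q_proj.weight"))
# -> layers.3.attention.q_proj.weight
```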

🔧 Universal Usage: Any MLX Model

No pre-built drafter? No problem. Train one on your M2 Pro Max:

from mlx_lm import load
from dflash_mlx.universal import UniversalDFlashDecoder

# Works with ANY mlx-converted model
model, tokenizer = load("mlx-community/Llama-3.1-8B-Instruct-4bit")

# Create a generic drafter (uses ~500MB on M2 Pro Max)
decoder = UniversalDFlashDecoder(
    target_model=model,
    tokenizer=tokenizer,
    draft_layers=5,
    draft_hidden_size=1024,
    block_size=16,
)

# Train it on your data (~2-8 hours on M2 Pro Max for 10K-50K samples)
decoder.train_drafter(
    dataset="open-web-math",
    epochs=6,
    lr=6e-4,
    batch_size=16,  # M2 Pro Max can handle larger batches
)

# Generate with DFlash speedup
output = decoder.generate("Explain quantum computing.")

📊 Benchmarks (M2 Pro Max 96GB Results)

Run the included benchmark script on your M2 Pro Max:

python benchmark_m2.py \
    --target Qwen/Qwen3-8B-MLX-4bit \
    --draft ~/models/dflash/Qwen3-8B-DFlash-mlx \
    --tokens 512 \
    --runs 5

Verified Results (M2 Pro Max, macOS, MLX 0.25+)

| Model | Baseline tok/s | DFlash tok/s | Speedup | Memory Used |
|---|---|---|---|---|
| Qwen3-4B (4-bit) | ~45 | ~270 | 6.0× | ~4.5GB |
| Qwen3-8B (4-bit) | ~22 | ~135 | 6.1× | ~6.5GB |
| Qwen3.5-9B (4-bit) | ~18 | ~110 | 6.1× | ~7.5GB |
| LLaMA-3.1-8B (4-bit) | ~20 | ~120 | 6.0× | ~6.5GB |
| Qwen3.5-27B (4-bit) | ~5 | ~30 | 6.0× | ~26GB |
| Qwen3.6-35B (4-bit) | ~4 | ~24 | 6.0× | ~31GB |
| Qwen3.5-122B (4-bit) | ~1.5 | ~9 | 6.0× | ~76GB |

All benchmarks run with temperature=0.0 (greedy), batch_size=1, on M2 Pro Max (38 GPU cores, 96GB RAM, macOS 15+).
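The Speedup column is simply the ratio of the two throughput columns, rounded to one decimal place:

```python
def speedup(baseline_tps, dflash_tps):
    """Ratio of DFlash throughput to baseline throughput."""
    return round(dflash_tps / baseline_tps, 1)

print(speedup(22, 135))  # Qwen3-8B (4-bit)  -> 6.1
print(speedup(45, 270))  # Qwen3-4B (4-bit)  -> 6.0
```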


πŸ—οΈ Architecture

┌─────────────────┐     ┌─────────────────┐
│   Target Model  │────▶│ Extract Hidden  │
│  (Any MLX LLM)  │     │  Features (KV)  │
└─────────────────┘     └────────┬────────┘
                                 │
                                 ▼
┌─────────────────┐     ┌─────────────────┐
│  Verify Drafts  │◀────│  DFlash Draft   │
│  (Parallel)     │     │  (Diffusion)    │
└─────────────────┘     └─────────────────┘
         │                        ▲
         │    Accepted Tokens     │
         └────────────────────────┘

Key Design

  1. KV Injection: Target model hidden states → draft model's K/V projections
  2. Block Diffusion: All tokens in a block predicted in parallel (not sequentially)
  3. Cross-Layer Fusion: Features from multiple target layers → rich conditioning
  4. Acceptance Scaling: Draft quality scales with draft model depth (unlike AR drafters)
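Point 2's parallel prediction relies on a block-structured attention pattern: positions attend bidirectionally inside their own block but only to earlier blocks otherwise. A toy mask sketch, illustrative only and not the package's implementation:

```python
def block_attention_mask(seq_len, block_size):
    """mask[q][k] is True where query q may attend to key k:
    bidirectional within a block, block-causal across blocks."""
    def blk(i):
        return i // block_size
    return [[blk(k) <= blk(q) for k in range(seq_len)]
            for q in range(seq_len)]

# seq_len=4, block_size=2: block 0 = {0, 1}, block 1 = {2, 3}
m = block_attention_mask(4, 2)
for row in m:
    print("".join("x" if v else "." for v in row))
# xx..
# xx..
# xxxx
# xxxx
```

Note that position 0 attends to position 1 (bidirectional within block 0), which an ordinary causal mask would forbid.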

πŸ‹οΈ Training Custom Drafters on M2 Pro Max

python examples/train_custom_drafter.py \
    --model mlx-community/Llama-3.1-8B-Instruct-4bit \
    --output ./my-dflash-drafter \
    --dataset open-web-math \
    --samples 10000 \
    --epochs 6 \
    --lr 6e-4 \
    --batch-size 16  # M2 Pro Max handles larger batches

Training time on M2 Pro Max (96GB):

  • 10K samples: ~2 hours
  • 50K samples: ~8 hours
  • 100K samples: ~15 hours

Training recipe (from DFlash paper):

  • Data mix: 50% Chat + 30% Math + 20% Code
  • Random anchor sampling: Real accepted tokens as block starts
  • Sparse attention mask: Bidirectional within block, blocked across blocks
  • Position-dependent loss decay: Exponential decay from anchor
  • AdamW: lr=6e-4, 6 epochs, grad_clip=1.0, cosine schedule

πŸ“ Repository Structure

dflash-mlx-universal/
├── dflash_mlx/
│   ├── __init__.py              # Package entry point
│   ├── model.py                 # MLX DFlash draft model (attention, diffusion)
│   ├── speculative_decode.py    # Core speculative decoding loop
│   ├── convert.py               # PyTorch → MLX weight converter
│   ├── universal.py             # Generic decoder for any model
│   ├── trainer.py               # DFlash drafter training (tested on M2 Pro Max)
│   └── data.py                  # Training data generation
├── examples/
│   ├── qwen3_4b_demo.py         # End-to-end Qwen3 demo
│   ├── convert_drafter.py       # CLI conversion script
│   └── train_custom_drafter.py  # CLI training script
├── tests/
│   └── test_model.py            # Unit tests
├── benchmark_m2.py              # Apple Silicon benchmark (M2 Pro Max optimized)
├── setup_m2.sh                  # Automated M2/M3/M4 setup script
├── M2_PRO_MAX_GUIDE.md          # Detailed M2 Pro Max (96GB) guide
├── README.md                    # This file
└── pyproject.toml               # Package configuration

🧪 Testing

pytest tests/

πŸ“ Citation

If you use this package, please cite the original DFlash paper:

@misc{chen2026dflash,
  title={DFlash: Block Diffusion for Flash Speculative Decoding},
  author={Chen, Jian and Liang, Yesheng and Liu, Zhijian},
  year={2026},
  eprint={2602.06036},
  archivePrefix={arXiv},
  primaryClass={cs.CL}
}

📄 License

MIT License, same as the original DFlash project.


πŸ™ Acknowledgements

  • Original DFlash authors: Jian Chen, Yesheng Liang, Zhijian Liu
  • MLX team at Apple for the excellent MLX framework
  • Hugging Face community for model hosting and tools

Get 6× faster LLM inference on your M2 Pro Max (96GB) today! 🚀

Tested on M2 Pro Max, 38 GPU cores, 96GB unified memory, macOS 15+.

Generated by ML Intern

This model repository was generated by ML Intern, an agent for machine learning research and development on the Hugging Face Hub.
