How to use from
MLX LM
# Install MLX LM
uv tool install mlx-lm
# Generate some text
mlx_lm.generate --model "tritesh/dflash-mlx-universal" --prompt "Once upon a time"
Quick Links

DFlash-MLX-Universal: Block Diffusion Speculative Decoding for MLX

Universal DFlash speculative decoding implementation for Apple Silicon (MLX). Works with any MLX-converted model β€” Qwen3, Qwen3.5, LLaMA, Mistral, Gemma, and more.

Python MLX License uv


πŸš€ What is DFlash?

DFlash (Chen et al., 2026) accelerates autoregressive LLM inference by using a lightweight block diffusion model as a speculative drafter. Unlike traditional autoregressive drafters, DFlash generates multiple draft tokens in parallel within each block, achieving 4-6Γ— lossless speedup over baseline inference.

Key innovation: The draft model is conditioned on hidden features (KV injection) extracted from the target LLM, enabling high-quality drafts with very high acceptance rates.

Feature Baseline DFlash Improvement
Speed ~20 tok/s ~120 tok/s 6Γ— faster
Quality Same Same Lossless
Acceptance β€” Ο„ β‰ˆ 6-7 ~6 tokens accepted per draft

✨ What's New in Universal (v0.2.0)

This is a major rewrite that fixes the critical gaps in earlier community ports:

Gap Before (v0.1.x) Now (v0.2.0)
Architecture support Hardcoded to Qwen3 βœ… Universal adapters for Qwen3/3.5, LLaMA, Mistral, Gemma
Hidden state extraction Direct .layers access (breaks on most models) βœ… Architecture-aware adapter system with per-family hooks
KV cache management None β€” never rewound βœ… Proper trim/rewind on draft rejection
Attention masks mask=None (undefined behavior) βœ… Family-specific mask generation
Token acceptance Buggy cumprod logic βœ… First-mismatch detection with bonus token
Streaming Not supported βœ… Real-time text streaming with generator interface
OpenAI server Not supported βœ… FastAPI + simple HTTP with metrics endpoint
Model conversion PyTorchβ†’MLX weight converter βœ… Updated for all z-lab drafters
Training Basic trainer βœ… Architecture-aware training with adapter compatibility
Benchmarking None βœ… Built-in benchmark vs mlx_lm baseline
uv support pip only βœ… uv + uv run workflow with lock files

πŸ“¦ Installation

Option 1: uv (Recommended β€” ultra-fast, reproducible)

uv is an extremely fast Python package manager written in Rust. It's the recommended way to install on macOS.

# 1. Install uv (one-time)
brew install uv
# or: curl -LsSf https://astral.sh/uv/install.sh | sh

# 2. Clone and setup
git clone https://huggingface.co/tritesh/dflash-mlx-universal.git
cd dflash-mlx-universal

# 3. One-command setup (creates venv, installs deps, locks)
chmod +x setup_uv.sh
./setup_uv.sh

# Or manually:
uv venv
uv pip install -e ".[dev,server]"
uv lock

Why uv?

  • 10-100Γ— faster than pip (written in Rust)
  • Automatic virtual environment management
  • Lock file (uv.lock) for reproducible installs
  • uv run β€” run any script without activating venv manually
# Examples of uv workflow
uv run python examples/qwen3_4b_demo.py
uv run pytest tests/ -v
uv run python -m dflash_mlx.serve --target ... --draft ... --port 8000
uv run black dflash_mlx/
uv run ruff check dflash_mlx/

Option 2: pip (Classic)

pip install mlx-lm dflash-mlx-universal

For Apple Silicon (M1/M2/M3/M4):

pip install --upgrade pip
pip install mlx-lm dflash-mlx-universal

Optional (for server mode):

pip install fastapi uvicorn

⚑ Quick Start

Option 1: Pre-converted DFlash drafter (recommended)

from dflash_mlx import DFlashSpeculativeDecoder
from dflash_mlx.convert import load_mlx_dflash, infer_target_model
from mlx_lm import load

# 1. Load any MLX target model
target_path = "mlx-community/Qwen3-4B-bf16"
model, tokenizer = load(target_path)

# 2. Load a pre-converted DFlash drafter
draft_model, draft_config = load_mlx_dflash("./Qwen3-4B-DFlash-mlx")

# 3. Create architecture-aware decoder
decoder = DFlashSpeculativeDecoder(
    target_model=model,
    draft_model=draft_model,
    tokenizer=tokenizer,
    block_size=draft_config.get("block_size", 16),
)

# 4. Generate with 6Γ— speedup
output = decoder.generate(
    prompt="Explain quantum computing to a 10-year-old.",
    max_tokens=1024,
    temperature=0.0,
)
print(output)

Option 2: Universal decoder (auto-detects architecture)

from dflash_mlx.universal import UniversalDFlashDecoder
from mlx_lm import load

# Works with ANY mlx_lm model
model, tokenizer = load("mlx-community/Llama-3.1-8B-Instruct-4bit")

# Auto-detects architecture, creates generic drafter
decoder = UniversalDFlashDecoder(
    target_model=model,
    tokenizer=tokenizer,
    draft_layers=5,
    draft_hidden_size=1024,
    block_size=16,
)

# Train a custom drafter (2-8 hours on Apple Silicon)
decoder.train_drafter(
    dataset="open-web-math",
    epochs=6,
    lr=6e-4,
    batch_size=16,
)

output = decoder.generate("Write a Python function to implement quicksort.")
print(output)

Option 3: Convert PyTorch drafter to MLX

# Download official z-lab drafter and convert weights
python -m dflash_mlx.convert \
    --model z-lab/Qwen3-4B-DFlash-b16 \
    --output ./Qwen3-4B-DFlash-mlx

# Or with uv (recommended)
uv run python -m dflash_mlx.convert \
    --model z-lab/Qwen3-4B-DFlash-b16 \
    --output ./Qwen3-4B-DFlash-mlx

# Or in Python
from dflash_mlx.convert import convert_dflash_to_mlx

convert_dflash_to_mlx(
    pytorch_model_id="z-lab/Qwen3.5-9B-DFlash",
    output_path="./Qwen3.5-9B-DFlash-mlx",
)

🎯 Supported Models

Pre-built DFlash drafters (convert to MLX)

All official z-lab/*-DFlash models can be converted:

PyTorch Drafter Target Model Status
z-lab/Qwen3-4B-DFlash-b16 Qwen/Qwen3-4B βœ… Ready
z-lab/Qwen3-8B-DFlash-b16 Qwen/Qwen3-8B βœ… Ready
z-lab/Qwen3.5-4B-DFlash Qwen/Qwen3.5-4B βœ… Ready
z-lab/Qwen3.5-9B-DFlash Qwen/Qwen3.5-9B βœ… Ready
z-lab/Qwen3.5-27B-DFlash Qwen/Qwen3.5-27B βœ… Ready
z-lab/Qwen3.6-27B-DFlash Qwen/Qwen3.6-27B βœ… Ready
z-lab/Qwen3.6-35B-A3B-DFlash Qwen/Qwen3.6-35B-A3B βœ… Ready
z-lab/LLaMA3.1-8B-Instruct-DFlash-UltraChat meta-llama/Llama-3.1-8B βœ… Ready
z-lab/gemma-4-31B-it-DFlash google/gemma-4-31b-it βœ… Ready
z-lab/gpt-oss-20b-DFlash openai/gpt-oss-20b βœ… Ready
z-lab/Kimi-K2.5-DFlash moonshotai/Kimi-K2.5 βœ… Ready

Architecture adapters (built-in)

Model Family Adapter Hidden States KV Cache Attention Mask
Qwen3 Qwen3Adapter βœ… βœ… KVCache.trim() βœ… qwen3.create_attention_mask
Qwen3.5 Qwen35Adapter βœ… βœ… ArraysCache βœ… Hybrid FA + SSM masks
LLaMA 2/3 LlamaAdapter βœ… βœ… KVCache.trim() βœ… llama.create_attention_mask
Mistral MistralAdapter βœ… βœ… KVCache.trim() βœ… mistral.create_attention_mask
Gemma GemmaAdapter βœ… βœ… KVCache.trim() βœ… gemma.create_attention_mask
Generic MLXTargetAdapter βœ… βœ… Basic trim ⚠️ Causal fallback

πŸ—οΈ Architecture Overview

β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”     β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚   Target Model  │────▢│ Extract Hidden  β”‚
β”‚  (Any MLX LLM)  β”‚     β”‚  Features (KV)  β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜     β””β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”˜
                                 β”‚
                                 β–Ό
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”     β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚  Verify Drafts  │◀────│  DFlash Draft   β”‚
β”‚  (Parallel)     β”‚     β”‚  Model (Diffusion)
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜     β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
         β”‚                        β–²
         β”‚    Accepted Tokens     β”‚
         β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

Key Design

  1. Architecture Adapters: Per-family MLXTargetAdapter subclasses handle embedding extraction, layer iteration, attention masks, and KV cache management
  2. KV Injection: Target model hidden states β†’ draft model's K/V projections via extract_context_features()
  3. Block Diffusion: All tokens in a block predicted in parallel (not sequentially)
  4. Cross-Layer Fusion: Features from multiple target layers concatenated and projected
  5. Exact Acceptance: Draft tokens verified greedily; KV cache rewound to accepted prefix

πŸ“Š Benchmarking

from dflash_mlx import DFlashSpeculativeDecoder
from dflash_mlx.convert import load_mlx_dflash
from mlx_lm import load

model, tokenizer = load("Qwen/Qwen3-4B")
draft_model, _ = load_mlx_dflash("./Qwen3-4B-DFlash-mlx")

decoder = DFlashSpeculativeDecoder(model, draft_model, tokenizer, block_size=16)

# Built-in benchmark (runs warmup + multiple trials)
results = decoder.benchmark(
    prompt="Write a quicksort in Python.",
    max_tokens=512,
    num_runs=5,
)
# prints: Baseline: 2.34s | DFlash: 0.41s | Speedup: 5.71x | 1247.6 tok/s

πŸ–₯️ OpenAI-Compatible Server

# Start server with DFlash acceleration
python -m dflash_mlx.serve \
    --target mlx-community/Qwen3.5-9B-4bit \
    --draft ./Qwen3.5-9B-DFlash-mlx \
    --block-size 16 \
    --port 8000

# With uv (recommended)
uv run python -m dflash_mlx.serve \
    --target mlx-community/Qwen3.5-9B-4bit \
    --draft ./Qwen3.5-9B-DFlash-mlx \
    --block-size 16 \
    --port 8000

# Query with curl
curl http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "qwen3.5-9b",
    "messages": [{"role": "user", "content": "Hello!"}],
    "max_tokens": 256,
    "temperature": 0.0,
    "stream": false
  }'

# Streaming SSE
curl http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "qwen3.5-9b",
    "messages": [{"role": "user", "content": "Count to 10"}],
    "max_tokens": 100,
    "stream": true
  }'

# Check metrics
curl http://localhost:8000/metrics

Endpoints:

  • GET /health β€” Server status and mode
  • GET /v1/models β€” Available models
  • GET /metrics β€” Request count, tok/s, recent history
  • POST /v1/chat/completions β€” Chat completions (OpenAI-compatible)

πŸ‹οΈ Training Custom Drafters

from dflash_mlx.universal import UniversalDFlashDecoder
from mlx_lm import load

model, tokenizer = load("mlx-community/Llama-3.1-8B-Instruct-4bit")

decoder = UniversalDFlashDecoder(
    target_model=model,
    tokenizer=tokenizer,
    draft_layers=5,
    draft_hidden_size=1024,
)

# Train using paper recipe (6 epochs, lr=6e-4, AdamW)
decoder.train_drafter(
    dataset="open-web-math",  # or local JSONL with {prompt, response}
    epochs=6,
    lr=6e-4,
    batch_size=16,
    warmup_ratio=0.04,
    grad_clip=1.0,
    output_path="./my-llama-drafter",
)

# Save and reload
decoder.save_drafter("./my-llama-drafter")

Training recipe (from DFlash paper Β§5):

  • Data mix: 50% Chat + 30% Math + 20% Code
  • Random anchor sampling: real accepted tokens as block starts
  • Sparse attention mask: bidirectional within block, causal across blocks
  • Position-dependent loss decay: exponential decay from anchor
  • AdamW: lr=6e-4, 6 epochs, grad_clip=1.0, cosine schedule

πŸ“ Repository Structure

dflash-mlx-universal/
β”œβ”€β”€ dflash_mlx/
β”‚   β”œβ”€β”€ __init__.py              # Package exports
β”‚   β”œβ”€β”€ adapters.py              # πŸ”‘ Architecture adapters (NEW v0.2.0)
β”‚   β”œβ”€β”€ model.py                 # DFlash draft model (attention, diffusion)
β”‚   β”œβ”€β”€ speculative_decode.py    # Core speculative decoding loop (FIXED)
β”‚   β”œβ”€β”€ convert.py               # PyTorch β†’ MLX weight converter
β”‚   β”œβ”€β”€ universal.py             # Generic decoder for any model
β”‚   β”œβ”€β”€ trainer.py               # DFlash drafter training
β”‚   β”œβ”€β”€ data.py                  # Training data generation
β”‚   └── serve.py                 # OpenAI-compatible HTTP server (NEW)
β”œβ”€β”€ examples/
β”‚   β”œβ”€β”€ qwen3_4b_demo.py         # End-to-end Qwen3 demo
β”‚   β”œβ”€β”€ convert_drafter.py       # CLI conversion script
β”‚   └── train_custom_drafter.py  # CLI training script
β”œβ”€β”€ tests/
β”‚   β”œβ”€β”€ test_model.py            # Model unit tests
β”‚   └── test_adapters.py         # Adapter tests (NEW)
β”œβ”€β”€ benchmark_m2.py              # Apple Silicon benchmark
β”œβ”€β”€ setup_m2.sh                  # Automated setup script
β”œβ”€β”€ setup_uv.sh                  # βœ… UV setup script (NEW v0.2.0)
β”œβ”€β”€ .python-version              # Python version pin for uv
β”œβ”€β”€ USAGE_GUIDE.md               # Detailed usage guide
β”œβ”€β”€ M2_PRO_MAX_GUIDE.md          # Detailed M2 Pro Max guide
β”œβ”€β”€ README.md                    # This file
└── pyproject.toml               # Package configuration (with uv support)

πŸ§ͺ Testing

# With uv (recommended)
uv run pytest tests/
uv run pytest tests/test_adapters.py -v
uv run pytest tests/test_model.py -v
uv run pytest --cov=dflash_mlx tests/

# Classic pip
pytest tests/
pytest tests/test_adapters.py -v
pytest tests/test_model.py -v

πŸ”§ Adding a New Model Family

To add support for a new architecture (e.g., Phi, Falcon):

# 1. Subclass MLXTargetAdapter in dflash_mlx/adapters.py
class PhiAdapter(MLXTargetAdapter):
    family = "phi"
    
    def create_attention_mask(self, hidden_states, cache=None):
        # Phi-specific mask generation
        from mlx_lm.models import phi
        return phi.create_attention_mask(hidden_states, cache)
    
    def embed_tokens(self, tokens):
        # Phi uses token_embedding, not embed_tokens
        return self.model.token_embedding(tokens)

# 2. Register in ADAPTERS dict
ADAPTERS["phi"] = PhiAdapter

# 3. Add alias if needed
def adapter_for_model_type(model_type):
    if model_type.startswith("phi"):
        return PhiAdapter
    # ...

See ADDING_MODELS.md (in Aryagm/dflash-mlx) for detailed pass/fail validation criteria.


πŸ“Š Performance (Reference)

Apple Silicon M2 Pro Max (96GB unified memory), MLX 0.25+:

Model Baseline tok/s DFlash tok/s Speedup Memory
Qwen3-4B (4-bit) ~45 ~270 6.0Γ— ~4.5GB
Qwen3-8B (4-bit) ~22 ~135 6.1Γ— ~6.5GB
Qwen3.5-9B (4-bit) ~18 ~110 6.1Γ— ~7.5GB
LLaMA-3.1-8B (4-bit) ~20 ~120 6.0Γ— ~6.5GB
Qwen3.5-27B (4-bit) ~5 ~30 6.0Γ— ~26GB

Actual numbers depend on prompt complexity, temperature, and hardware.


πŸ“ Citation

@misc{chen2026dflash,
  title={DFlash: Block Diffusion for Flash Speculative Decoding},
  author={Jian Chen and Yesheng Liang and Zhijian Liu},
  year={2026},
  eprint={2602.06036},
  archivePrefix={arXiv},
  primaryClass={cs.CL}
}

πŸ“„ License

MIT License β€” same as the original DFlash project.


πŸ™ Acknowledgements

  • Original DFlash authors: Jian Chen, Yesheng Liang, Zhijian Liu
  • Aryagm for the original MLX community port (dflash-mlx) and adapter pattern
  • bstnxbt for the production MLX port with Metal kernels and prefix caching
  • MLX team at Apple for the excellent MLX framework
  • Hugging Face community for model hosting and tools

Get 6Γ— faster LLM inference on Apple Silicon today! πŸš€

Tested on M2/M3/M4 Pro/Max/Ultra with mlx-lm 0.24+. ```

Downloads last month

-

Downloads are not tracked for this model. How to track
Inference Providers NEW
This model isn't deployed by any Inference Provider. πŸ™‹ Ask for provider support

Space using tritesh/dflash-mlx-universal 1

Paper for tritesh/dflash-mlx-universal