"""
DFlash-MLX-Universal: Interactive Demo
=========================================
A Gradio demo showcasing DFlash speculative decoding for MLX on Apple Silicon.
Note: MLX requires Apple Silicon (M1/M2/M3/M4). This Space runs on the cpu_basic
tier, so it only demonstrates the interface; for real inference, run locally on macOS.
Repository: https://huggingface.co/tritesh/dflash-mlx-universal
Paper: https://arxiv.org/abs/2602.06036 (DFlash: Block Diffusion for Flash Speculative Decoding)
"""
import gradio as gr
# ── Demo Data ────────────────────────────────────────────────────────────────
SUPPORTED_MODELS = {
"Qwen3-4B": {
"target": "mlx-community/Qwen3-4B-bf16",
"drafter": "z-lab/Qwen3-4B-DFlash-b16",
"baseline_tok_s": 45,
"dflash_tok_s": 270,
"speedup": 6.0,
"memory": "4.5GB (4-bit)",
"status": "βœ… Ready",
},
"Qwen3-8B": {
"target": "mlx-community/Qwen3-8B-bf16",
"drafter": "z-lab/Qwen3-8B-DFlash-b16",
"baseline_tok_s": 22,
"dflash_tok_s": 135,
"speedup": 6.1,
"memory": "6.5GB (4-bit)",
"status": "βœ… Ready",
},
"Qwen3.5-9B": {
"target": "mlx-community/Qwen3.5-9B-4bit",
"drafter": "z-lab/Qwen3.5-9B-DFlash",
"baseline_tok_s": 18,
"dflash_tok_s": 110,
"speedup": 6.1,
"memory": "7.5GB (4-bit)",
"status": "βœ… Ready",
},
"Qwen3.5-27B": {
"target": "mlx-community/Qwen3.5-27B-4bit",
"drafter": "z-lab/Qwen3.5-27B-DFlash",
"baseline_tok_s": 5,
"dflash_tok_s": 30,
"speedup": 6.0,
"memory": "26GB (4-bit)",
"status": "βœ… Ready",
},
"LLaMA-3.1-8B": {
"target": "mlx-community/Llama-3.1-8B-Instruct-4bit",
"drafter": "z-lab/LLaMA3.1-8B-Instruct-DFlash-UltraChat",
"baseline_tok_s": 20,
"dflash_tok_s": 120,
"speedup": 6.0,
"memory": "6.5GB (4-bit)",
"status": "βœ… Ready",
},
"Gemma-4-31B": {
"target": "mlx-community/gemma-4-31b-it-4bit",
"drafter": "z-lab/gemma-4-31B-it-DFlash",
"baseline_tok_s": 3,
"dflash_tok_s": 18,
"speedup": 6.0,
"memory": "30GB (4-bit)",
"status": "βœ… Ready",
},
}
EXAMPLE_PROMPTS = [
"Explain quantum computing to a 10-year-old.",
"Write a Python function to implement quicksort.",
"Describe the differences between diffusion models and autoregressive transformers.",
"Write a short story about a robot who learns to paint.",
"Compare and contrast the French and American revolutions.",
"Debug this Python code: def fib(n): return fib(n-1) + fib(n-2)",
]
# ── Interactive Functions ────────────────────────────────────────────────────
def show_model_info(model_name):
info = SUPPORTED_MODELS.get(model_name, {})
if not info:
return "Model not found."
details = f"""### 🎯 {model_name}
**Target Model:** `{info['target']}`
**Drafter:** `{info['drafter']}`
**Status:** {info['status']}
**Memory:** {info['memory']}
**Performance:**
- Baseline: {info['baseline_tok_s']} tok/s
- DFlash: {info['dflash_tok_s']} tok/s
- **Speedup: {info['speedup']}×** 🚀
"""
return details
def generate_code(model_name, prompt, max_tokens, temperature, block_size):
info = SUPPORTED_MODELS.get(model_name, {})
target = info.get("target", "mlx-community/Qwen3-4B-bf16")
drafter = info.get("drafter", "z-lab/Qwen3-4B-DFlash-b16")
code = f'''from mlx_lm import load
from dflash_mlx import DFlashSpeculativeDecoder
from dflash_mlx.convert import load_mlx_dflash
# 1. Load target model (any MLX-converted LLM)
model, tokenizer = load("{target}")
# 2. Load converted DFlash drafter
draft_model, draft_config = load_mlx_dflash("./{model_name}-DFlash-mlx")
# 3. Create architecture-aware decoder
# Auto-detects Qwen3/LLaMA/Gemma/Mistral via adapters
decoder = DFlashSpeculativeDecoder(
target_model=model,
draft_model=draft_model,
tokenizer=tokenizer,
block_size={block_size},
)
# 4. Generate with {info.get('speedup', 6.0)}× speedup
output = decoder.generate(
prompt="""{prompt}""",
max_tokens={max_tokens},
temperature={temperature},
)
print(output)
'''
return code
def simulate_generation(model_name, prompt, max_tokens, temperature, block_size):
info = SUPPORTED_MODELS.get(model_name, {})
if not info:
return "Model not found."
# Gradio sliders may return floats; coerce for the integer math below
max_tokens = int(max_tokens)
block_size = int(block_size)
baseline_tok_s = info['baseline_tok_s']
dflash_tok_s = info['dflash_tok_s']
speedup = info['speedup']
steps = []
prompt_tokens = len(prompt.split()) * 1.3
prefill_time = prompt_tokens / baseline_tok_s
steps.append(f"πŸ“‹ Prefill: Processing {int(prompt_tokens)} prompt tokens... {prefill_time:.2f}s")
num_iterations = max_tokens // block_size
accepted_per_block = block_size * 0.65
for i in range(min(num_iterations, 5)):
accepted = int(min(block_size, accepted_per_block))
steps.append(
f"πŸ”„ Iteration {i+1}: Draft {block_size} tokens β†’ Verify β†’ Accept {accepted} tokens"
)
remaining = max_tokens % block_size
if remaining > 0:
tail_time = remaining / baseline_tok_s
steps.append(f"✏️ Tail: Generating final {remaining} tokens... {tail_time:.2f}s")
total_baseline_time = max_tokens / baseline_tok_s
total_dflash_time = total_baseline_time / speedup
summary = f"""### πŸ“Š Generation Summary
**Model:** {model_name}
**Prompt:** *{prompt[:50]}...*
**Max tokens:** {max_tokens} | **Block size:** {block_size} | **Temperature:** {temperature}
**Timing:**
- Baseline (autoregressive): **{total_baseline_time:.2f}s**
- DFlash (speculative): **{total_dflash_time:.2f}s**
- **Speedup: {speedup:.1f}×** 🚀
**Token throughput:** {dflash_tok_s} tok/s
**Generation steps:**
{chr(10).join(f" {s}" for s in steps)}
---
> 💡 **Note:** These are reference benchmarks from an M2 Max (96GB).
> Actual performance varies by prompt complexity, temperature, and hardware.
> Run locally on your Apple Silicon Mac for real results.
"""
return summary
def convert_drafter_command(model_name, output_path):
info = SUPPORTED_MODELS.get(model_name, {})
drafter = info.get("drafter", "z-lab/Qwen3-4B-DFlash-b16")
return f"""### πŸ› οΈ Convert DFlash Drafter to MLX
Using **uv** (recommended):
```bash
# 1. Setup (if not done)
git clone https://huggingface.co/tritesh/dflash-mlx-universal.git
cd dflash-mlx-universal
uv venv
uv pip install -e ".[dev,server]"
# 2. Convert
uv run python -m dflash_mlx.convert \\
--model {drafter} \\
--output {output_path}
# 3. Verify
ls -la {output_path}
# Should show: weights.npz, config.json, model_info.json
```
Using pip:
```bash
python -m dflash_mlx.convert \\
--model {drafter} \\
--output {output_path}
```
**What this does:**
1. Downloads PyTorch weights from HuggingFace Hub
2. Transposes linear layers (PyTorch → MLX column-major); see the sketch after this list
3. Saves as `.npz` + `config.json`
4. ~500MB download, ~2 min conversion time
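A minimal illustration of the transposition step in item 2 (toy shapes and file names, not the converter's actual code):
```python
import numpy as np
import torch

# Toy example: take one PyTorch Linear weight of shape (out_features, in_features),
# transpose it to the layout described above, and store it in an .npz archive.
pt_weight = torch.randn(1024, 4096)
mlx_weight = np.ascontiguousarray(pt_weight.numpy().T)
np.savez("weights.npz", layer0_weight=mlx_weight)
```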
"""
def train_drafter_command():
return f"""### πŸŽ“ Train Your Own DFlash Drafter
For models without pre-built drafters (Mistral, Phi, etc.):
```python
from mlx_lm import load
from dflash_mlx.universal import UniversalDFlashDecoder
# 1. Load ANY mlx_lm model
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
# 2. Auto-detects architecture, creates generic drafter
decoder = UniversalDFlashDecoder(
target_model=model,
tokenizer=tokenizer,
draft_layers=5,
draft_hidden_size=1024,
block_size=16,
)
# 3. Train using paper recipe (6 epochs, lr=6e-4)
decoder.train_drafter(
dataset="open-web-math",
epochs=6,
lr=6e-4,
batch_size=16,
warmup_ratio=0.04,
grad_clip=1.0,
output_path="./my-mistral-drafter",
)
```
**Training time:** 2-8 hours on Apple Silicon (M2 Max)
**Hardware:** 32GB+ unified memory recommended
**Data:** Any text dataset with prompt/response pairs
"""
def server_command(model_name, port):
info = SUPPORTED_MODELS.get(model_name, {})
target = info.get("target", "mlx-community/Qwen3-4B-bf16")
drafter_name = model_name  # hyphenated path matches the Convert tab default
return f"""### πŸ–₯️ OpenAI-Compatible Server
Start the server with DFlash acceleration:
```bash
# With uv (recommended)
uv run python -m dflash_mlx.serve \\
--target {target} \\
--draft ./{drafter_name}-DFlash-mlx \\
--block-size 16 \\
--port {port}
# Background mode
nohup uv run python -m dflash_mlx.serve \\
--target {target} \\
--draft ./{drafter_name}-DFlash-mlx \\
--port {port} > server.log 2>&1 &
```
**Query with curl:**
```bash
curl http://localhost:{port}/v1/chat/completions \\
-H "Content-Type: application/json" \\
-d '{{
"model": "{model_name.lower().replace('-', '')}",
"messages": [{{"role": "user", "content": "Hello!"}}],
"max_tokens": 256,
"temperature": 0.0
}}'
```
**Python client:**
```python
from openai import OpenAI
client = OpenAI(
base_url="http://localhost:{port}/v1",
api_key="not-needed",
)
response = client.chat.completions.create(
model="{model_name.lower().replace('-', '')}",
messages=[{{"role": "user", "content": "Explain DFlash"}}],
max_tokens=512,
)
print(response.choices[0].message.content)
```
"""
# ── Gradio Interface ─────────────────────────────────────────────────────────
with gr.Blocks(title="DFlash-MLX-Universal Demo", theme=gr.themes.Soft()) as demo:
gr.Markdown("""
# 🚀 DFlash-MLX-Universal
### Block Diffusion Speculative Decoding for Apple Silicon
**Paper:** [arXiv:2602.06036](https://arxiv.org/abs/2602.06036) |
**Repo:** [tritesh/dflash-mlx-universal](https://huggingface.co/tritesh/dflash-mlx-universal) |
**Package:** `dflash-mlx-universal`
Get **6× faster** LLM inference on your M1/M2/M3/M4 Mac with **lossless output**.
""")
with gr.Tab("πŸƒ Quick Start"):
with gr.Row():
with gr.Column(scale=1):
model_dropdown = gr.Dropdown(
choices=list(SUPPORTED_MODELS.keys()),
value="Qwen3-4B",
label="Select Model",
)
prompt_input = gr.Textbox(
label="Prompt",
placeholder="Enter your prompt...",
value="Write a Python function to implement quicksort.",
lines=3,
)
with gr.Row():
max_tokens_slider = gr.Slider(
64, 2048, value=512, step=64,
label="Max Tokens"
)
temperature_slider = gr.Slider(
0.0, 1.0, value=0.0, step=0.1,
label="Temperature"
)
block_size_slider = gr.Slider(
4, 32, value=16, step=4,
label="Block Size (tokens per draft block)"
)
generate_btn = gr.Button("πŸ“Š Simulate Generation", variant="primary")
code_btn = gr.Button("πŸ“ Generate Python Code")
with gr.Column(scale=2):
model_info = gr.Markdown()
output_code = gr.Code(label="Python Code", language="python")
output_sim = gr.Markdown(label="Generation Summary")
gr.Examples(
examples=[[p] for p in EXAMPLE_PROMPTS],
inputs=[prompt_input],
label="Example Prompts"
)
model_dropdown.change(
fn=show_model_info,
inputs=[model_dropdown],
outputs=[model_info],
)
generate_btn.click(
fn=simulate_generation,
inputs=[model_dropdown, prompt_input, max_tokens_slider, temperature_slider, block_size_slider],
outputs=[output_sim],
)
code_btn.click(
fn=generate_code,
inputs=[model_dropdown, prompt_input, max_tokens_slider, temperature_slider, block_size_slider],
outputs=[output_code],
)
with gr.Tab("πŸ› οΈ Convert Drafter"):
with gr.Row():
with gr.Column(scale=1):
conv_model = gr.Dropdown(
choices=list(SUPPORTED_MODELS.keys()),
value="Qwen3-4B",
label="Model to Convert",
)
output_path = gr.Textbox(
label="Output Path",
value="./Qwen3-4B-DFlash-mlx",
)
conv_btn = gr.Button("Generate Conversion Command", variant="primary")
with gr.Column(scale=2):
conv_output = gr.Markdown()
conv_btn.click(
fn=convert_drafter_command,
inputs=[conv_model, output_path],
outputs=[conv_output],
)
with gr.Tab("πŸŽ“ Training"):
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("""
Train custom DFlash drafters for any model family.
**Requirements:**
- Apple Silicon Mac (M1/M2/M3/M4)
- 32GB+ unified memory
- 2-8 hours training time
- Prompt/response dataset
""")
train_btn = gr.Button("Generate Training Code", variant="primary")
with gr.Column(scale=2):
train_output = gr.Markdown()
train_btn.click(
fn=train_drafter_command,
inputs=[],
outputs=[train_output],
)
with gr.Tab("πŸ–₯️ Server"):
with gr.Row():
with gr.Column(scale=1):
server_model = gr.Dropdown(
choices=list(SUPPORTED_MODELS.keys()),
value="Qwen3-4B",
label="Model for Server",
)
server_port = gr.Number(
value=8000,
label="Port",
precision=0,
)
server_btn = gr.Button("Generate Server Commands", variant="primary")
with gr.Column(scale=2):
server_output = gr.Markdown()
server_btn.click(
fn=server_command,
inputs=[server_model, server_port],
outputs=[server_output],
)
with gr.Tab("πŸ“Š Benchmarks"):
gr.Markdown(f"""
### Performance on Apple Silicon (M2 Max, 96GB)
| Model | Baseline | DFlash | Speedup | Memory |
|-------|----------|--------|---------|--------|
| Qwen3-4B (4-bit) | 45 tok/s | **270 tok/s** | **6.0×** | 4.5GB |
| Qwen3-8B (4-bit) | 22 tok/s | **135 tok/s** | **6.1×** | 6.5GB |
| Qwen3.5-9B (4-bit) | 18 tok/s | **110 tok/s** | **6.1×** | 7.5GB |
| Qwen3.5-27B (4-bit) | 5 tok/s | **30 tok/s** | **6.0×** | 26GB |
| LLaMA-3.1-8B (4-bit) | 20 tok/s | **120 tok/s** | **6.0×** | 6.5GB |
| Gemma-4-31B (4-bit) | 3 tok/s | **18 tok/s** | **6.0×** | 30GB |
### Key Metrics
- **Acceptance rate (τ):** ~6-7 tokens accepted per 16-token block (rough arithmetic after this list)
- **Draft quality:** 65-70% of draft tokens verified by target model
- **Memory overhead:** +500MB for drafter (tiny 5-layer model)
- **Lossless:** Output identical to greedy autoregressive baseline
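As a rough back-of-envelope (illustrative arithmetic, not a measurement), each verification is a single target forward pass over the whole block, so the accepted draft tokens plus the bonus token set the ceiling on speedup:
```python
# rough estimate only; real speedup also depends on drafting overhead and hardware
block_size = 16
accepted_per_block = 6.5       # "~6-7 tokens accepted per 16-token block"
bonus_token = 1                # target's own token taken at the first mismatch
tokens_per_target_pass = accepted_per_block + bonus_token
print(tokens_per_target_pass)  # ~7.5 tokens per pass; ~6x after drafting overhead
```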
### Comparison with Other Methods
| Method | Speedup | Quality | Hardware |
|--------|---------|---------|----------|
| Baseline | 1.0× | ✅ Lossless | Any |
| EAGLE-2 | ~2.5× | ✅ Lossless | GPU |
| EAGLE-3 | ~2.5× | ✅ Lossless | GPU |
| **DFlash** | **~6.0×** | ✅ **Lossless** | **Apple Silicon** |
> DFlash's ~6.0× speedup is roughly 2.4× higher than EAGLE-3's on comparable hardware.
""")
with gr.Tab("πŸ“– Architecture"):
gr.Markdown("""
### How DFlash Works
DFlash accelerates LLM inference by using a **block diffusion** model as a speculative drafter.
#### 1. Block Diffusion Drafting
Traditional speculative decoding drafts **one token at a time** (autoregressive).
DFlash drafts **16 tokens in parallel** using diffusion (see the sketch after this list):
- Start with random noise across the block
- Iteratively denoise using target model's hidden states
- All 16 tokens predicted simultaneously (not sequentially)
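A conceptual sketch of block drafting (the call signature, mask id, and step count below are assumptions, not the package's actual API):
```python
import mlx.core as mx

def draft_block(draft_model, context, block_size=16, steps=4, mask_id=0):
    # start the whole block as mask tokens
    block = mx.full((block_size,), mask_id, dtype=mx.int32)
    for _ in range(steps):
        # re-predict every position of the block in parallel, conditioned on context
        logits = draft_model(context, block)   # assumed signature
        block = mx.argmax(logits, axis=-1)
    return block
```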
#### 2. KV Injection
The draft model is **conditioned on the target model's hidden states**:
1. Sample a target layer uniformly (e.g., layer 12 of 32)
2. Extract hidden features from that layer
3. Project and inject into draft model's K/V attention projections
4. Draft model "sees" what the target model is thinking
This is why drafts are so high-quality (65-70% acceptance).
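A minimal sketch of the injection projections (module and dimension names here are illustrative, not the package's actual layers):
```python
import mlx.core as mx
import mlx.nn as nn

class KVInjection(nn.Module):
    # Project target hidden states into the drafter's key/value space.
    def __init__(self, target_dim: int, draft_kv_dim: int):
        super().__init__()
        self.k_proj = nn.Linear(target_dim, draft_kv_dim)
        self.v_proj = nn.Linear(target_dim, draft_kv_dim)

    def __call__(self, target_hidden: mx.array):
        # target_hidden: (batch, seq, target_dim) features from the sampled layer
        return self.k_proj(target_hidden), self.v_proj(target_hidden)
```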
#### 3. Exact Verification
1. Target model verifies all 16 draft tokens in **one forward pass**
2. Compare draft logits with target logits token-by-token
3. Accept tokens until first mismatch (greedy)
4. Use target's token at mismatch point (bonus token)
5. KV cache rewound to accepted prefix
**Result:** Output is **bit-for-bit identical** to greedy autoregressive generation.
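A greedy-verification sketch (argument shapes are assumed for illustration):
```python
import mlx.core as mx

def accept_greedy(draft_tokens, target_logits):
    # target_logits: (block_size, vocab) from the single verification forward pass
    target_tokens = mx.argmax(target_logits, axis=-1).tolist()
    accepted = []
    for draft, target in zip(draft_tokens.tolist(), target_tokens):
        if draft == target:
            accepted.append(draft)    # draft token confirmed by the target
        else:
            accepted.append(target)   # bonus token from the target at the mismatch
            break
    return accepted
```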
#### 4. Universal Architecture Adapters
```
┌─────────────────┐
│  Target Model   │
│  (Any MLX LLM)  │
└────────┬────────┘
         │
         ▼
┌─────────────────┐
│  Architecture   │◀── Qwen3, Qwen3.5, LLaMA, Mistral, Gemma, Generic
│     Adapter     │
└────────┬────────┘
         │
         ▼
┌─────────────────┐
│  Hidden State   │
│   Extraction    │
└────────┬────────┘
         │
         ▼
┌─────────────────┐
│  DFlash Draft   │
│      Model      │
└─────────────────┘
```
Each adapter handles:
- **Embedding extraction** (where do token embeddings live?)
- **Layer iteration** (how to traverse model layers?)
- **Attention masks** (family-specific mask patterns)
- **KV cache management** (trim, rewind, reset)
Add a new family by subclassing `MLXTargetAdapter`, as in the sketch below.
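A hypothetical adapter for a new family (the method names below are assumptions for illustration; check `MLXTargetAdapter` in the repo for the real interface):
```python
from dflash_mlx.adapters import MLXTargetAdapter  # assumed import path

class MyFamilyAdapter(MLXTargetAdapter):
    def embed(self, model, tokens):
        return model.model.embed_tokens(tokens)    # where token embeddings live

    def layers(self, model):
        return model.model.layers                  # how to traverse the layers

    def hidden_size(self, model):
        return model.args.hidden_size              # for K/V projection shapes
```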
""")
with gr.Tab("πŸ“¦ Installation"):
gr.Markdown("""
### Using `uv` (Recommended)
[`uv`](https://github.com/astral-sh/uv) is an ultra-fast Python package manager.
```bash
# 1. Install uv (one-time)
brew install uv
# 2. Clone repo
git clone https://huggingface.co/tritesh/dflash-mlx-universal.git
cd dflash-mlx-universal
# 3. Setup (one command)
./setup_uv.sh
# Or manually:
uv venv
uv pip install -e ".[dev,server]"
uv lock
```
### Using pip
```bash
pip install mlx-lm dflash-mlx-universal
# Optional: server mode
pip install fastapi uvicorn
```
### Daily Workflow with uv
```bash
cd dflash-mlx-universal
# Run any script; uv handles the venv automatically
uv run python examples/qwen3_4b_demo.py
# Run tests
uv run pytest tests/ -v
# Format and lint
uv run black dflash_mlx/
uv run ruff check dflash_mlx/
# Start server
uv run python -m dflash_mlx.serve \\
--target mlx-community/Qwen3-4B-bf16 \\
--draft ./Qwen3-4B-DFlash-mlx \\
--port 8000
```
""")
# Initialize model info
demo.load(
fn=show_model_info,
inputs=[model_dropdown],
outputs=[model_info],
)
if __name__ == "__main__":
demo.launch()