| """ | |
| DFlash-MLX-Universal: Interactive Demo | |
| ========================================= | |
| A Gradio demo showcasing DFlash speculative decoding for MLX on Apple Silicon. | |
| Note: MLX requires Apple Silicon hardware (M1/M2/M3/M4). This demo runs on | |
| cpu_basic but shows the interface. For actual inference, run locally on macOS. | |
| Repository: https://huggingface.co/tritesh/dflash-mlx-universal | |
| Paper: https://arxiv.org/abs/2602.06036 (DFlash: Block Diffusion for Flash Speculative Decoding) | |
| """ | |

import gradio as gr

# ── Demo Data ────────────────────────────────────────────────────────────────

SUPPORTED_MODELS = {
    "Qwen3-4B": {
        "target": "mlx-community/Qwen3-4B-bf16",
        "drafter": "z-lab/Qwen3-4B-DFlash-b16",
        "baseline_tok_s": 45,
        "dflash_tok_s": 270,
        "speedup": 6.0,
        "memory": "4.5GB (4-bit)",
        "status": "✅ Ready",
    },
    "Qwen3-8B": {
        "target": "mlx-community/Qwen3-8B-bf16",
        "drafter": "z-lab/Qwen3-8B-DFlash-b16",
        "baseline_tok_s": 22,
        "dflash_tok_s": 135,
        "speedup": 6.1,
        "memory": "6.5GB (4-bit)",
        "status": "✅ Ready",
    },
    "Qwen3.5-9B": {
        "target": "mlx-community/Qwen3.5-9B-4bit",
        "drafter": "z-lab/Qwen3.5-9B-DFlash",
        "baseline_tok_s": 18,
        "dflash_tok_s": 110,
        "speedup": 6.1,
        "memory": "7.5GB (4-bit)",
        "status": "✅ Ready",
    },
    "Qwen3.5-27B": {
        "target": "mlx-community/Qwen3.5-27B-4bit",
        "drafter": "z-lab/Qwen3.5-27B-DFlash",
        "baseline_tok_s": 5,
        "dflash_tok_s": 30,
        "speedup": 6.0,
        "memory": "26GB (4-bit)",
        "status": "✅ Ready",
    },
    "LLaMA-3.1-8B": {
        "target": "mlx-community/Llama-3.1-8B-Instruct-4bit",
        "drafter": "z-lab/LLaMA3.1-8B-Instruct-DFlash-UltraChat",
        "baseline_tok_s": 20,
        "dflash_tok_s": 120,
        "speedup": 6.0,
        "memory": "6.5GB (4-bit)",
        "status": "✅ Ready",
    },
    "Gemma-4-31B": {
        "target": "mlx-community/gemma-4-31b-it-4bit",
        "drafter": "z-lab/gemma-4-31B-it-DFlash",
        "baseline_tok_s": 3,
        "dflash_tok_s": 18,
        "speedup": 6.0,
        "memory": "30GB (4-bit)",
        "status": "✅ Ready",
    },
}

EXAMPLE_PROMPTS = [
    "Explain quantum computing to a 10-year-old.",
    "Write a Python function to implement quicksort.",
    "Describe the differences between diffusion models and autoregressive transformers.",
    "Write a short story about a robot who learns to paint.",
    "Compare and contrast the French and American revolutions.",
    "Debug this Python code: def fib(n): return fib(n-1) + fib(n-2)",
]

# ── Interactive Functions ────────────────────────────────────────────────────


def show_model_info(model_name):
    info = SUPPORTED_MODELS.get(model_name, {})
    if not info:
        return "Model not found."
    details = f"""### 🎯 {model_name}

**Target Model:** `{info['target']}`

**Drafter:** `{info['drafter']}`

**Status:** {info['status']}

**Memory:** {info['memory']}

**Performance:**
- Baseline: {info['baseline_tok_s']} tok/s
- DFlash: {info['dflash_tok_s']} tok/s
- **Speedup: {info['speedup']}×** 🚀
"""
    return details


def generate_code(model_name, prompt, max_tokens, temperature, block_size):
    info = SUPPORTED_MODELS.get(model_name, {})
    target = info.get("target", "mlx-community/Qwen3-4B-bf16")
    drafter = info.get("drafter", "z-lab/Qwen3-4B-DFlash-b16")
    # Sliders deliver floats; emit clean integers in the generated snippet.
    max_tokens = int(max_tokens)
    block_size = int(block_size)
    code = f'''from mlx_lm import load
from dflash_mlx import DFlashSpeculativeDecoder
from dflash_mlx.convert import load_mlx_dflash

# 1. Load target model (any MLX-converted LLM)
model, tokenizer = load("{target}")

# 2. Load the DFlash drafter (converted from {drafter})
draft_model, draft_config = load_mlx_dflash("./{model_name.replace('-', '_')}-DFlash-mlx")

# 3. Create architecture-aware decoder
#    Auto-detects Qwen3/LLaMA/Gemma/Mistral via adapters
decoder = DFlashSpeculativeDecoder(
    target_model=model,
    draft_model=draft_model,
    tokenizer=tokenizer,
    block_size={block_size},
)

# 4. Generate with {info.get('speedup', 6.0)}× speedup
output = decoder.generate(
    prompt="""{prompt}""",
    max_tokens={max_tokens},
    temperature={temperature},
)
print(output)
'''
    return code


def simulate_generation(model_name, prompt, max_tokens, temperature, block_size):
    info = SUPPORTED_MODELS.get(model_name, {})
    if not info:
        return "Model not found."
    # Sliders deliver floats; the block arithmetic below needs integers.
    max_tokens = int(max_tokens)
    block_size = int(block_size)
    baseline_tok_s = info["baseline_tok_s"]
    dflash_tok_s = info["dflash_tok_s"]
    speedup = info["speedup"]

    steps = []
    prompt_tokens = len(prompt.split()) * 1.3  # rough words-to-tokens estimate
    prefill_time = prompt_tokens / baseline_tok_s
    steps.append(f"📝 Prefill: Processing {int(prompt_tokens)} prompt tokens... {prefill_time:.2f}s")

    num_iterations = max_tokens // block_size
    accepted_per_block = block_size * 0.65  # ~65% draft acceptance
    for i in range(min(num_iterations, 5)):
        accepted = int(min(block_size, accepted_per_block))
        steps.append(
            f"🔁 Iteration {i + 1}: Draft {block_size} tokens → Verify → Accept {accepted} tokens"
        )
    if num_iterations > 5:
        steps.append(f"... ({num_iterations - 5} more iterations)")
    remaining = max_tokens % block_size
    if remaining > 0:
        tail_time = remaining / baseline_tok_s
        steps.append(f"✏️ Tail: Generating final {remaining} tokens... {tail_time:.2f}s")

    total_baseline_time = max_tokens / baseline_tok_s
    total_dflash_time = total_baseline_time / speedup
    summary = f"""### 📊 Generation Summary

**Model:** {model_name}

**Prompt:** *{prompt[:50]}...*

**Max tokens:** {max_tokens} | **Block size:** {block_size} | **Temperature:** {temperature}

**Timing:**
- Baseline (autoregressive): **{total_baseline_time:.2f}s**
- DFlash (speculative): **{total_dflash_time:.2f}s**
- **Speedup: {speedup:.1f}×** 🚀

**Token throughput:** {dflash_tok_s} tok/s

**Generation steps:**

{chr(10).join(f"- {s}" for s in steps)}

---

> 💡 **Note:** These are reference benchmarks from an M2 Max (96GB).
> Actual performance varies by prompt complexity, temperature, and hardware.
> Run locally on your Apple Silicon Mac for real results.
"""
    return summary


def convert_drafter_command(model_name, output_path):
    info = SUPPORTED_MODELS.get(model_name, {})
    drafter = info.get("drafter", "z-lab/Qwen3-4B-DFlash-b16")
    return f"""### 🛠️ Convert DFlash Drafter to MLX

Using **uv** (recommended):

```bash
# 1. Setup (if not done)
git clone https://huggingface.co/tritesh/dflash-mlx-universal.git
cd dflash-mlx-universal
uv venv
uv pip install -e ".[dev,server]"

# 2. Convert
uv run python -m dflash_mlx.convert \\
    --model {drafter} \\
    --output {output_path}

# 3. Verify
ls -la {output_path}
# Should show: weights.npz, config.json, model_info.json
```

Using pip:

```bash
python -m dflash_mlx.convert \\
    --model {drafter} \\
    --output {output_path}
```

**What this does:**
1. Downloads PyTorch weights from the HuggingFace Hub
2. Transposes linear layers (PyTorch → MLX column-major; sketched below)
3. Saves as `.npz` + `config.json`
4. ~500MB download, ~2 min conversion time
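
A rough illustration of the layout change in step 2 (illustrative shapes and key
name only; the real converter walks a full checkpoint):

```python
import numpy as np

# PyTorch nn.Linear stores a weight as (out_features, in_features)
w_torch = np.random.randn(1024, 4096).astype(np.float16)

# transpose to the layout the MLX-side loader expects, then save
w_mlx = w_torch.T  # (in_features, out_features)
np.savez("weights.npz", w_proj=w_mlx)
```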
| """ | |


def train_drafter_command():
    return """### 🎓 Train Your Own DFlash Drafter

For models without pre-built drafters (Mistral, Phi, etc.):

```python
from mlx_lm import load
from dflash_mlx.universal import UniversalDFlashDecoder

# 1. Load ANY mlx_lm model
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

# 2. Auto-detects architecture, creates a generic drafter
decoder = UniversalDFlashDecoder(
    target_model=model,
    tokenizer=tokenizer,
    draft_layers=5,
    draft_hidden_size=1024,
    block_size=16,
)

# 3. Train using the paper recipe (6 epochs, lr=6e-4)
decoder.train_drafter(
    dataset="open-web-math",
    epochs=6,
    lr=6e-4,
    batch_size=16,
    warmup_ratio=0.04,
    grad_clip=1.0,
    output_path="./my-mistral-drafter",
)
```

**Training time:** 2-8 hours on Apple Silicon (M2 Max)

**Hardware:** 32GB+ unified memory recommended

**Data:** Any text dataset with prompt/response pairs
"""


def server_command(model_name, port):
    info = SUPPORTED_MODELS.get(model_name, {})
    target = info.get("target", "mlx-community/Qwen3-4B-bf16")
    drafter_name = model_name.replace("-", "_")
    port = int(port)  # gr.Number may deliver a float
    return f"""### 🖥️ OpenAI-Compatible Server

Start the server with DFlash acceleration:

```bash
# With uv (recommended)
uv run python -m dflash_mlx.serve \\
    --target {target} \\
    --draft ./{drafter_name}-DFlash-mlx \\
    --block-size 16 \\
    --port {port}

# Background mode
nohup uv run python -m dflash_mlx.serve \\
    --target {target} \\
    --draft ./{drafter_name}-DFlash-mlx \\
    --port {port} > server.log 2>&1 &
```

**Query with curl:**

```bash
curl http://localhost:{port}/v1/chat/completions \\
    -H "Content-Type: application/json" \\
    -d '{{
        "model": "{model_name.lower().replace('-', '')}",
        "messages": [{{"role": "user", "content": "Hello!"}}],
        "max_tokens": 256,
        "temperature": 0.0
    }}'
```

**Python client:**

```python
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:{port}/v1",
    api_key="not-needed",
)
response = client.chat.completions.create(
    model="{model_name.lower().replace('-', '')}",
    messages=[{{"role": "user", "content": "Explain DFlash"}}],
    max_tokens=512,
)
print(response.choices[0].message.content)
```
"""


# ── Gradio Interface ─────────────────────────────────────────────────────────

with gr.Blocks(title="DFlash-MLX-Universal Demo", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🚀 DFlash-MLX-Universal
    ### Block Diffusion Speculative Decoding for Apple Silicon

    **Paper:** [arXiv:2602.06036](https://arxiv.org/abs/2602.06036) |
    **Repo:** [tritesh/dflash-mlx-universal](https://huggingface.co/tritesh/dflash-mlx-universal) |
    **Package:** `dflash-mlx-universal`

    Get **6× faster** LLM inference on your M1/M2/M3/M4 Mac with **lossless output**.
    """)
| with gr.Tab("π Quick Start"): | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| model_dropdown = gr.Dropdown( | |
| choices=list(SUPPORTED_MODELS.keys()), | |
| value="Qwen3-4B", | |
| label="Select Model", | |
| ) | |
| prompt_input = gr.Textbox( | |
| label="Prompt", | |
| placeholder="Enter your prompt...", | |
| value="Write a Python function to implement quicksort.", | |
| lines=3, | |
| ) | |
| with gr.Row(): | |
| max_tokens_slider = gr.Slider( | |
| 64, 2048, value=512, step=64, | |
| label="Max Tokens" | |
| ) | |
| temperature_slider = gr.Slider( | |
| 0.0, 1.0, value=0.0, step=0.1, | |
| label="Temperature" | |
| ) | |
| block_size_slider = gr.Slider( | |
| 4, 32, value=16, step=4, | |
| label="Block Size (tokens per draft block)" | |
| ) | |
| generate_btn = gr.Button("π Simulate Generation", variant="primary") | |
| code_btn = gr.Button("π Generate Python Code") | |
| with gr.Column(scale=2): | |
| model_info = gr.Markdown() | |
| output_code = gr.Code(label="Python Code", language="python") | |
| output_sim = gr.Markdown(label="Generation Summary") | |
| gr.Examples( | |
| examples=[[p] for p in EXAMPLE_PROMPTS], | |
| inputs=[prompt_input], | |
| label="Example Prompts" | |
| ) | |
| model_dropdown.change( | |
| fn=show_model_info, | |
| inputs=[model_dropdown], | |
| outputs=[model_info], | |
| ) | |
| generate_btn.click( | |
| fn=simulate_generation, | |
| inputs=[model_dropdown, prompt_input, max_tokens_slider, temperature_slider, block_size_slider], | |
| outputs=[output_sim], | |
| ) | |
| code_btn.click( | |
| fn=generate_code, | |
| inputs=[model_dropdown, prompt_input, max_tokens_slider, temperature_slider, block_size_slider], | |
| outputs=[output_code], | |
| ) | |
| with gr.Tab("π οΈ Convert Drafter"): | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| conv_model = gr.Dropdown( | |
| choices=list(SUPPORTED_MODELS.keys()), | |
| value="Qwen3-4B", | |
| label="Model to Convert", | |
| ) | |
| output_path = gr.Textbox( | |
| label="Output Path", | |
| value="./Qwen3-4B-DFlash-mlx", | |
| ) | |
| conv_btn = gr.Button("Generate Conversion Command", variant="primary") | |
| with gr.Column(scale=2): | |
| conv_output = gr.Markdown() | |
| conv_btn.click( | |
| fn=convert_drafter_command, | |
| inputs=[conv_model, output_path], | |
| outputs=[conv_output], | |
| ) | |
| with gr.Tab("π Training"): | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| gr.Markdown(""" | |
| Train custom DFlash drafters for any model family. | |
| **Requirements:** | |
| - Apple Silicon Mac (M1/M2/M3/M4) | |
| - 32GB+ unified memory | |
| - 2-8 hours training time | |
| - Prompt/response dataset | |
| """) | |
| train_btn = gr.Button("Generate Training Code", variant="primary") | |
| with gr.Column(scale=2): | |
| train_output = gr.Markdown() | |
| train_btn.click( | |
| fn=train_drafter_command, | |
| inputs=[], | |
| outputs=[train_output], | |
| ) | |
| with gr.Tab("π₯οΈ Server"): | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| server_model = gr.Dropdown( | |
| choices=list(SUPPORTED_MODELS.keys()), | |
| value="Qwen3-4B", | |
| label="Model for Server", | |
| ) | |
| server_port = gr.Number( | |
| value=8000, | |
| label="Port", | |
| precision=0, | |
| ) | |
| server_btn = gr.Button("Generate Server Commands", variant="primary") | |
| with gr.Column(scale=2): | |
| server_output = gr.Markdown() | |
| server_btn.click( | |
| fn=server_command, | |
| inputs=[server_model, server_port], | |
| outputs=[server_output], | |
| ) | |
| with gr.Tab("π Benchmarks"): | |
| gr.Markdown(f""" | |
| ### Performance on Apple Silicon (M2 Pro Max, 96GB) | |
| | Model | Baseline | DFlash | Speedup | Memory | | |
| |-------|----------|--------|---------|--------| | |
| | Qwen3-4B (4-bit) | 45 tok/s | **270 tok/s** | **6.0Γ** | 4.5GB | | |
| | Qwen3-8B (4-bit) | 22 tok/s | **135 tok/s** | **6.1Γ** | 6.5GB | | |
| | Qwen3.5-9B (4-bit) | 18 tok/s | **110 tok/s** | **6.1Γ** | 7.5GB | | |
| | Qwen3.5-27B (4-bit) | 5 tok/s | **30 tok/s** | **6.0Γ** | 26GB | | |
| | LLaMA-3.1-8B (4-bit) | 20 tok/s | **120 tok/s** | **6.0Γ** | 6.5GB | | |
| | Gemma-4-31B (4-bit) | 3 tok/s | **18 tok/s** | **6.0Γ** | 30GB | | |
| ### Key Metrics | |
| - **Acceptance rate (Ο):** ~6-7 tokens accepted per 16-token block | |
| - **Draft quality:** 65-70% of draft tokens verified by target model | |
| - **Memory overhead:** +500MB for drafter (tiny 5-layer model) | |
| - **Lossless:** Output identical to greedy autoregressive baseline | |
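
        Back-of-envelope on those numbers: each verify step costs one target forward
        pass and emits every accepted draft token (including the bonus token at the
        first mismatch), so ~6-7 accepted tokens per 16-token block means ~6-7 output
        tokens per target pass versus exactly 1 for plain autoregressive decoding.
        Ignoring the drafter's small overhead, that predicts the ~6× speedups above.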

        ### Comparison with Other Methods

        | Method | Speedup | Quality | Hardware |
        |--------|---------|---------|----------|
        | Baseline | 1.0× | ✅ Lossless | Any |
        | EAGLE-2 | ~2.5× | ✅ Lossless | GPU |
        | EAGLE-3 | ~2.5× | ✅ Lossless | GPU |
        | **DFlash** | **~6.0×** | ✅ **Lossless** | **Apple Silicon** |

        > DFlash achieves **2.4× faster** generation than EAGLE-3 on comparable hardware.
        """)
| with gr.Tab("π Architecture"): | |
| gr.Markdown(""" | |
| ### How DFlash Works | |
| DFlash accelerates LLM inference by using a **block diffusion** model as a speculative drafter. | |
| #### 1. Block Diffusion Drafting | |
| Traditional speculative decoding drafts **one token at a time** (autoregressive). | |
| DFlash drafts **16 tokens in parallel** using diffusion: | |
| - Start with random noise across the block | |
| - Iteratively denoise using target model's hidden states | |
| - All 16 tokens predicted simultaneously (not sequentially) | |
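
        A self-contained toy of that parallel-drafting loop (plain numpy, invented
        shapes; not the library's actual drafter):

        ```python
        import numpy as np

        rng = np.random.default_rng(0)
        block_size, vocab, steps = 16, 100, 4
        tokens = rng.integers(vocab, size=block_size)  # start the block as noise
        frozen = np.zeros(block_size, dtype=bool)

        for step in range(steps):
            # stand-in for the drafter's per-position logits (all 16 at once)
            logits = rng.normal(size=(block_size, vocab))
            probs = np.exp(logits) / np.exp(logits).sum(-1, keepdims=True)
            conf, pred = probs.max(-1), probs.argmax(-1)
            tokens[~frozen] = pred[~frozen]  # re-denoise the unfrozen positions
            # keep the most confident remaining positions fixed from now on
            k = int(np.ceil((~frozen).sum() / (steps - step)))
            order = np.argsort(-(conf * ~frozen))
            frozen[order[:k]] = True

        print(tokens)  # a fully drafted 16-token block
        ```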

        #### 2. KV Injection

        The draft model is **conditioned on the target model's hidden states**:

        1. Sample a target layer uniformly (e.g., layer 12 of 32)
        2. Extract hidden features from that layer
        3. Project and inject into draft model's K/V attention projections
        4. Draft model "sees" what the target model is thinking

        This is why drafts are so high-quality (65-70% acceptance).
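
        At a shape level the injection is a learned projection from the target's
        hidden width into the drafter's attention inputs. A minimal sketch (made-up
        dimensions, random stand-ins for the learned weights):

        ```python
        import numpy as np

        rng = np.random.default_rng(0)
        seq, d_target, d_draft = 32, 4096, 1024

        hidden = rng.normal(size=(seq, d_target))        # features from the sampled target layer
        W_inject = rng.normal(size=(d_target, d_draft))  # learned projection

        injected = hidden @ W_inject                     # (seq, d_draft)
        # the drafter's attention derives K and V from the injected features,
        # so the draft "sees" what the target model is computing
        K = injected @ rng.normal(size=(d_draft, d_draft))
        V = injected @ rng.normal(size=(d_draft, d_draft))
        ```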

        #### 3. Exact Verification

        1. Target model verifies all 16 draft tokens in **one forward pass**
        2. Compare draft logits with target logits token-by-token
        3. Accept tokens until first mismatch (greedy)
        4. Use target's token at mismatch point (bonus token)
        5. KV cache rewound to accepted prefix

        **Result:** Output is **bit-for-bit identical** to greedy autoregressive generation.
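
        The acceptance rule itself is tiny. A greedy-mode sketch over token IDs
        (the real decoder also rewinds the KV cache afterwards):

        ```python
        def verify_block(draft_tokens, target_argmax):
            # accept draft tokens until the first disagreement with the target,
            # then take the target's own token there as a free "bonus" token
            accepted = []
            for d, t in zip(draft_tokens, target_argmax):
                if d == t:
                    accepted.append(d)
                else:
                    accepted.append(t)
                    break
            return accepted

        print(verify_block([5, 9, 2, 7], [5, 9, 4, 7]))  # -> [5, 9, 4]
        ```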

        #### 4. Universal Architecture Adapters

        ```
        ┌─────────────────┐
        │  Target Model   │
        │  (Any MLX LLM)  │
        └────────┬────────┘
                 │
                 ▼
        ┌─────────────────┐
        │  Architecture   │ ◄── Qwen3, Qwen3.5, LLaMA, Mistral, Gemma, Generic
        │     Adapter     │
        └────────┬────────┘
                 │
                 ▼
        ┌─────────────────┐
        │  Hidden State   │
        │   Extraction    │
        └────────┬────────┘
                 │
                 ▼
        ┌─────────────────┐
        │  DFlash Draft   │
        │     Model       │
        └─────────────────┘
        ```

        Each adapter handles:

        - **Embedding extraction** (where do token embeddings live?)
        - **Layer iteration** (how to traverse model layers?)
        - **Attention masks** (family-specific mask patterns)
        - **KV cache management** (trim, rewind, reset)

        Add a new family by subclassing `MLXTargetAdapter`.
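
        A skeleton of what such a subclass might look like (the hook names below are
        illustrative guesses, not the actual `MLXTargetAdapter` interface; check the
        source for the real method names):

        ```python
        from dflash_mlx.adapters import MLXTargetAdapter  # assumed import path


        class MyFamilyAdapter(MLXTargetAdapter):
            # hypothetical adapter for a new model family

            def embeddings(self, model):
                # embedding extraction: where do token embeddings live?
                return model.model.embed_tokens

            def layers(self, model):
                # layer iteration: how to traverse the decoder stack
                return model.model.layers

            def attention_mask(self, seq_len):
                # family-specific attention mask pattern
                ...

            def rewind_cache(self, cache, n_accepted):
                # trim the KV cache back to the accepted prefix
                ...
        ```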
| """) | |
| with gr.Tab("π¦ Installation"): | |
| gr.Markdown(""" | |
| ### Using `uv` (Recommended) | |
| [`uv`](https://github.com/astral-sh/uv) is an ultra-fast Python package manager. | |
| ```bash | |
| # 1. Install uv (one-time) | |
| brew install uv | |
| # 2. Clone repo | |
| git clone https://huggingface.co/tritesh/dflash-mlx-universal.git | |
| cd dflash-mlx-universal | |
| # 3. Setup (one command) | |
| ./setup_uv.sh | |
| # Or manually: | |
| uv venv | |
| uv pip install -e ".[dev,server]" | |
| uv lock | |
| ``` | |
| ### Using pip | |
| ```bash | |
| pip install mlx-lm dflash-mlx-universal | |
| # Optional: server mode | |
| pip install fastapi uvicorn | |
| ``` | |
| ### Daily Workflow with uv | |
| ```bash | |
| cd dflash-mlx-universal | |
| # Run any script β uv handles the venv automatically | |
| uv run python examples/qwen3_4b_demo.py | |
| # Run tests | |
| uv run pytest tests/ -v | |
| # Format and lint | |
| uv run black dflash_mlx/ | |
| uv run ruff check dflash_mlx/ | |
| # Start server | |
| uv run python -m dflash_mlx.serve \\ | |
| --target mlx-community/Qwen3-4B-bf16 \\ | |
| --draft ./Qwen3-4B-DFlash-mlx \\ | |
| --port 8000 | |
| ``` | |
| """) | |

    # Initialize model info
    demo.load(
        fn=show_model_info,
        inputs=[model_dropdown],
        outputs=[model_info],
    )

if __name__ == "__main__":
    demo.launch()