tritesh commited on
Commit
d728bf2
Β·
verified Β·
1 Parent(s): 40a4dd8

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +649 -0
app.py ADDED
@@ -0,0 +1,649 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ DFlash-MLX-Universal: Interactive Demo
3
+ =========================================
4
+ A Gradio demo showcasing DFlash speculative decoding for MLX on Apple Silicon.
5
+
6
+ Note: MLX requires Apple Silicon hardware (M1/M2/M3/M4). This demo runs on
7
+ cpu_basic but shows the interface. For actual inference, run locally on macOS.
8
+
9
+ Repository: https://huggingface.co/tritesh/dflash-mlx-universal
10
+ Paper: https://arxiv.org/abs/2602.06036 (DFlash: Block Diffusion for Flash Speculative Decoding)
11
+ """
12
+
13
+ import gradio as gr
14
+ import json
15
+ import time
16
+
17
+ # ── Demo Data ────────────────────────────────────────────────────────────────
18
+
19
+ SUPPORTED_MODELS = {
20
+ "Qwen3-4B": {
21
+ "target": "mlx-community/Qwen3-4B-bf16",
22
+ "drafter": "z-lab/Qwen3-4B-DFlash-b16",
23
+ "baseline_tok_s": 45,
24
+ "dflash_tok_s": 270,
25
+ "speedup": 6.0,
26
+ "memory": "4.5GB (4-bit)",
27
+ "status": "βœ… Ready",
28
+ },
29
+ "Qwen3-8B": {
30
+ "target": "mlx-community/Qwen3-8B-bf16",
31
+ "drafter": "z-lab/Qwen3-8B-DFlash-b16",
32
+ "baseline_tok_s": 22,
33
+ "dflash_tok_s": 135,
34
+ "speedup": 6.1,
35
+ "memory": "6.5GB (4-bit)",
36
+ "status": "βœ… Ready",
37
+ },
38
+ "Qwen3.5-9B": {
39
+ "target": "mlx-community/Qwen3.5-9B-4bit",
40
+ "drafter": "z-lab/Qwen3.5-9B-DFlash",
41
+ "baseline_tok_s": 18,
42
+ "dflash_tok_s": 110,
43
+ "speedup": 6.1,
44
+ "memory": "7.5GB (4-bit)",
45
+ "status": "βœ… Ready",
46
+ },
47
+ "Qwen3.5-27B": {
48
+ "target": "mlx-community/Qwen3.5-27B-4bit",
49
+ "drafter": "z-lab/Qwen3.5-27B-DFlash",
50
+ "baseline_tok_s": 5,
51
+ "dflash_tok_s": 30,
52
+ "speedup": 6.0,
53
+ "memory": "26GB (4-bit)",
54
+ "status": "βœ… Ready",
55
+ },
56
+ "LLaMA-3.1-8B": {
57
+ "target": "mlx-community/Llama-3.1-8B-Instruct-4bit",
58
+ "drafter": "z-lab/LLaMA3.1-8B-Instruct-DFlash-UltraChat",
59
+ "baseline_tok_s": 20,
60
+ "dflash_tok_s": 120,
61
+ "speedup": 6.0,
62
+ "memory": "6.5GB (4-bit)",
63
+ "status": "βœ… Ready",
64
+ },
65
+ "Gemma-4-31B": {
66
+ "target": "mlx-community/gemma-4-31b-it-4bit",
67
+ "drafter": "z-lab/gemma-4-31B-it-DFlash",
68
+ "baseline_tok_s": 3,
69
+ "dflash_tok_s": 18,
70
+ "speedup": 6.0,
71
+ "memory": "30GB (4-bit)",
72
+ "status": "βœ… Ready",
73
+ },
74
+ }
75
+
76
+ EXAMPLE_PROMPTS = [
77
+ "Explain quantum computing to a 10-year-old.",
78
+ "Write a Python function to implement quicksort.",
79
+ "Describe the differences between diffusion models and autoregressive transformers.",
80
+ "Write a short story about a robot who learns to paint.",
81
+ "Compare and contrast the French and American revolutions.",
82
+ "Debug this Python code: def fib(n): return fib(n-1) + fib(n-2)",
83
+ ]
84
+
85
+ # ── Interactive Functions ────────────────────────────────────────────────────
86
+
87
+
88
+ def show_model_info(model_name):
89
+ info = SUPPORTED_MODELS.get(model_name, {})
90
+ if not info:
91
+ return "Model not found."
92
+
93
+ details = f"""### 🎯 {model_name}
94
+
95
+ **Target Model:** `{info['target']}`
96
+ **Drafter:** `{info['drafter']}`
97
+ **Status:** {info['status']}
98
+ **Memory:** {info['memory']}
99
+
100
+ **Performance:**
101
+ - Baseline: {info['baseline_tok_s']} tok/s
102
+ - DFlash: {info['dflash_tok_s']} tok/s
103
+ - **Speedup: {info['speedup']}Γ—** πŸš€
104
+ """
105
+ return details
106
+
107
+
108
+ def generate_code(model_name, prompt, max_tokens, temperature, block_size):
109
+ info = SUPPORTED_MODELS.get(model_name, {})
110
+ target = info.get("target", "mlx-community/Qwen3-4B-bf16")
111
+ drafter = info.get("drafter", "z-lab/Qwen3-4B-DFlash-b16")
112
+
113
+ code = f'''from mlx_lm import load
114
+ from dflash_mlx import DFlashSpeculativeDecoder
115
+ from dflash_mlx.convert import load_mlx_dflash
116
+
117
+ # 1. Load target model (any MLX-converted LLM)
118
+ model, tokenizer = load("{target}")
119
+
120
+ # 2. Load converted DFlash drafter
121
+ draft_model, draft_config = load_mlx_dflash("./{model_name.replace('-', '_')}-DFlash-mlx")
122
+
123
+ # 3. Create architecture-aware decoder
124
+ # Auto-detects Qwen3/LLaMA/Gemma/Mistral via adapters
125
+ decoder = DFlashSpeculativeDecoder(
126
+ target_model=model,
127
+ draft_model=draft_model,
128
+ tokenizer=tokenizer,
129
+ block_size={block_size},
130
+ )
131
+
132
+ # 4. Generate with {info.get('speedup', 6.0)}Γ— speedup
133
+ output = decoder.generate(
134
+ prompt="""{prompt}""",
135
+ max_tokens={max_tokens},
136
+ temperature={temperature},
137
+ )
138
+
139
+ print(output)
140
+ '''
141
+ return code
142
+
143
+
144
+ def simulate_generation(model_name, prompt, max_tokens, temperature, block_size):
145
+ info = SUPPORTED_MODELS.get(model_name, {})
146
+ if not info:
147
+ return "Model not found."
148
+
149
+ baseline_tok_s = info['baseline_tok_s']
150
+ dflash_tok_s = info['dflash_tok_s']
151
+ speedup = info['speedup']
152
+
153
+ steps = []
154
+ prompt_tokens = len(prompt.split()) * 1.3
155
+ prefill_time = prompt_tokens / baseline_tok_s
156
+ steps.append(f"πŸ“‹ Prefill: Processing {int(prompt_tokens)} prompt tokens... {prefill_time:.2f}s")
157
+
158
+ num_iterations = max_tokens // block_size
159
+ accepted_per_block = block_size * 0.65
160
+
161
+ for i in range(min(num_iterations, 5)):
162
+ accepted = int(min(block_size, accepted_per_block))
163
+ steps.append(
164
+ f"πŸ”„ Iteration {i+1}: Draft {block_size} tokens β†’ Verify β†’ Accept {accepted} tokens"
165
+ )
166
+
167
+ remaining = max_tokens % block_size
168
+ if remaining > 0:
169
+ tail_time = remaining / baseline_tok_s
170
+ steps.append(f"✏️ Tail: Generating final {remaining} tokens... {tail_time:.2f}s")
171
+
172
+ total_baseline_time = max_tokens / baseline_tok_s
173
+ total_dflash_time = total_baseline_time / speedup
174
+
175
+ summary = f"""### πŸ“Š Generation Summary
176
+
177
+ **Model:** {model_name}
178
+ **Prompt:** *{prompt[:50]}...*
179
+ **Max tokens:** {max_tokens} | **Block size:** {block_size} | **Temperature:** {temperature}
180
+
181
+ **Timing:**
182
+ - Baseline (autoregressive): **{total_baseline_time:.2f}s**
183
+ - DFlash (speculative): **{total_dflash_time:.2f}s**
184
+ - **Speedup: {speedup:.1f}Γ—** πŸš€
185
+
186
+ **Token throughput:** {dflash_tok_s} tok/s
187
+
188
+ **Generation steps:**
189
+ {chr(10).join(f" {s}" for s in steps)}
190
+
191
+ ---
192
+
193
+ > πŸ’‘ **Note:** These are reference benchmarks from an M2 Pro Max (96GB).
194
+ > Actual performance varies by prompt complexity, temperature, and hardware.
195
+ > Run locally on your Apple Silicon Mac for real results.
196
+ """
197
+ return summary
198
+
199
+
200
+ def convert_drafter_command(model_name, output_path):
201
+ info = SUPPORTED_MODELS.get(model_name, {})
202
+ drafter = info.get("drafter", "z-lab/Qwen3-4B-DFlash-b16")
203
+
204
+ return f"""### πŸ› οΈ Convert DFlash Drafter to MLX
205
+
206
+ Using **uv** (recommended):
207
+
208
+ ```bash
209
+ # 1. Setup (if not done)
210
+ git clone https://huggingface.co/tritesh/dflash-mlx-universal.git
211
+ cd dflash-mlx-universal
212
+ uv venv
213
+ uv pip install -e ".[dev,server]"
214
+
215
+ # 2. Convert
216
+ cd dflash-mlx-universal
217
+ uv run python -m dflash_mlx.convert \\
218
+ --model {drafter} \\
219
+ --output {output_path}
220
+
221
+ # 3. Verify
222
+ ls -la {output_path}
223
+ # Should show: weights.npz, config.json, model_info.json
224
+ ```
225
+
226
+ Using pip:
227
+ ```bash
228
+ python -m dflash_mlx.convert \\
229
+ --model {drafter} \\
230
+ --output {output_path}
231
+ ```
232
+
233
+ **What this does:**
234
+ 1. Downloads PyTorch weights from HuggingFace Hub
235
+ 2. Transposes linear layers (PyTorch β†’ MLX column-major)
236
+ 3. Saves as `.npz` + `config.json`
237
+ 4. ~500MB download, ~2 min conversion time
238
+ """
239
+
240
+
241
+ def train_drafter_command():
242
+ return f"""### πŸŽ“ Train Your Own DFlash Drafter
243
+
244
+ For models without pre-built drafters (Mistral, Phi, etc.):
245
+
246
+ ```python
247
+ from mlx_lm import load
248
+ from dflash_mlx.universal import UniversalDFlashDecoder
249
+
250
+ # 1. Load ANY mlx_lm model
251
+ model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
252
+
253
+ # 2. Auto-detects architecture, creates generic drafter
254
+ decoder = UniversalDFlashDecoder(
255
+ target_model=model,
256
+ tokenizer=tokenizer,
257
+ draft_layers=5,
258
+ draft_hidden_size=1024,
259
+ block_size=16,
260
+ )
261
+
262
+ # 3. Train using paper recipe (6 epochs, lr=6e-4)
263
+ decoder.train_drafter(
264
+ dataset="open-web-math",
265
+ epochs=6,
266
+ lr=6e-4,
267
+ batch_size=16,
268
+ warmup_ratio=0.04,
269
+ grad_clip=1.0,
270
+ output_path="./my-mistral-drafter",
271
+ )
272
+ ```
273
+
274
+ **Training time:** 2-8 hours on Apple Silicon (M2 Pro Max)
275
+ **Hardware:** 32GB+ unified memory recommended
276
+ **Data:** Any text dataset with prompt/response pairs
277
+ """
278
+
279
+
280
+ def server_command(model_name, port):
281
+ info = SUPPORTED_MODELS.get(model_name, {})
282
+ target = info.get("target", "mlx-community/Qwen3-4B-bf16")
283
+ drafter_name = model_name.replace("-", "_")
284
+
285
+ return f"""### πŸ–₯️ OpenAI-Compatible Server
286
+
287
+ Start the server with DFlash acceleration:
288
+
289
+ ```bash
290
+ # With uv (recommended)
291
+ uv run python -m dflash_mlx.serve \\
292
+ --target {target} \\
293
+ --draft ./{drafter_name}-DFlash-mlx \\
294
+ --block-size 16 \\
295
+ --port {port}
296
+
297
+ # Background mode
298
+ nohup uv run python -m dflash_mlx.serve \\
299
+ --target {target} \\
300
+ --draft ./{drafter_name}-DFlash-mlx \\
301
+ --port {port} > server.log 2>&1 &
302
+ ```
303
+
304
+ **Query with curl:**
305
+ ```bash
306
+ curl http://localhost:{port}/v1/chat/completions \\
307
+ -H "Content-Type: application/json" \\
308
+ -d '{{
309
+ "model": "{model_name.lower().replace('-', '')}",
310
+ "messages": [{{"role": "user", "content": "Hello!"}}],
311
+ "max_tokens": 256,
312
+ "temperature": 0.0
313
+ }}'
314
+ ```
315
+
316
+ **Python client:**
317
+ ```python
318
+ from openai import OpenAI
319
+
320
+ client = OpenAI(
321
+ base_url="http://localhost:{port}/v1",
322
+ api_key="not-needed",
323
+ )
324
+
325
+ response = client.chat.completions.create(
326
+ model="{model_name.lower().replace('-', '')}",
327
+ messages=[{{"role": "user", "content": "Explain DFlash"}}],
328
+ max_tokens=512,
329
+ )
330
+ print(response.choices[0].message.content)
331
+ ```
332
+ """
333
+
334
+
335
+ # ── Gradio Interface ───────────────────────────────────────────────��─────────
336
+
337
+ with gr.Blocks(title="DFlash-MLX-Universal Demo", theme=gr.themes.Soft()) as demo:
338
+ gr.Markdown("""
339
+ # πŸš€ DFlash-MLX-Universal
340
+ ### Block Diffusion Speculative Decoding for Apple Silicon
341
+
342
+ **Paper:** [arXiv:2602.06036](https://arxiv.org/abs/2602.06036) |
343
+ **Repo:** [tritesh/dflash-mlx-universal](https://huggingface.co/tritesh/dflash-mlx-universal) |
344
+ **Package:** `dflash-mlx-universal`
345
+
346
+ Get **6Γ— faster** LLM inference on your M1/M2/M3/M4 Mac with **lossless output**.
347
+ """)
348
+
349
+ with gr.Tab("πŸƒ Quick Start"):
350
+ with gr.Row():
351
+ with gr.Column(scale=1):
352
+ model_dropdown = gr.Dropdown(
353
+ choices=list(SUPPORTED_MODELS.keys()),
354
+ value="Qwen3-4B",
355
+ label="Select Model",
356
+ )
357
+
358
+ prompt_input = gr.Textbox(
359
+ label="Prompt",
360
+ placeholder="Enter your prompt...",
361
+ value="Write a Python function to implement quicksort.",
362
+ lines=3,
363
+ )
364
+
365
+ with gr.Row():
366
+ max_tokens_slider = gr.Slider(
367
+ 64, 2048, value=512, step=64,
368
+ label="Max Tokens"
369
+ )
370
+ temperature_slider = gr.Slider(
371
+ 0.0, 1.0, value=0.0, step=0.1,
372
+ label="Temperature"
373
+ )
374
+
375
+ block_size_slider = gr.Slider(
376
+ 4, 32, value=16, step=4,
377
+ label="Block Size (tokens per draft block)"
378
+ )
379
+
380
+ generate_btn = gr.Button("πŸ“Š Simulate Generation", variant="primary")
381
+ code_btn = gr.Button("πŸ“ Generate Python Code")
382
+
383
+ with gr.Column(scale=2):
384
+ model_info = gr.Markdown()
385
+ output_code = gr.Code(label="Python Code", language="python")
386
+ output_sim = gr.Markdown(label="Generation Summary")
387
+
388
+ gr.Examples(
389
+ examples=[[p] for p in EXAMPLE_PROMPTS],
390
+ inputs=[prompt_input],
391
+ label="Example Prompts"
392
+ )
393
+
394
+ model_dropdown.change(
395
+ fn=show_model_info,
396
+ inputs=[model_dropdown],
397
+ outputs=[model_info],
398
+ )
399
+
400
+ generate_btn.click(
401
+ fn=simulate_generation,
402
+ inputs=[model_dropdown, prompt_input, max_tokens_slider, temperature_slider, block_size_slider],
403
+ outputs=[output_sim],
404
+ )
405
+
406
+ code_btn.click(
407
+ fn=generate_code,
408
+ inputs=[model_dropdown, prompt_input, max_tokens_slider, temperature_slider, block_size_slider],
409
+ outputs=[output_code],
410
+ )
411
+
412
+ with gr.Tab("πŸ› οΈ Convert Drafter"):
413
+ with gr.Row():
414
+ with gr.Column(scale=1):
415
+ conv_model = gr.Dropdown(
416
+ choices=list(SUPPORTED_MODELS.keys()),
417
+ value="Qwen3-4B",
418
+ label="Model to Convert",
419
+ )
420
+ output_path = gr.Textbox(
421
+ label="Output Path",
422
+ value="./Qwen3-4B-DFlash-mlx",
423
+ )
424
+ conv_btn = gr.Button("Generate Conversion Command", variant="primary")
425
+
426
+ with gr.Column(scale=2):
427
+ conv_output = gr.Markdown()
428
+
429
+ conv_btn.click(
430
+ fn=convert_drafter_command,
431
+ inputs=[conv_model, output_path],
432
+ outputs=[conv_output],
433
+ )
434
+
435
+ with gr.Tab("πŸŽ“ Training"):
436
+ with gr.Row():
437
+ with gr.Column(scale=1):
438
+ gr.Markdown("""
439
+ Train custom DFlash drafters for any model family.
440
+
441
+ **Requirements:**
442
+ - Apple Silicon Mac (M1/M2/M3/M4)
443
+ - 32GB+ unified memory
444
+ - 2-8 hours training time
445
+ - Prompt/response dataset
446
+ """)
447
+ train_btn = gr.Button("Generate Training Code", variant="primary")
448
+
449
+ with gr.Column(scale=2):
450
+ train_output = gr.Markdown()
451
+
452
+ train_btn.click(
453
+ fn=train_drafter_command,
454
+ inputs=[],
455
+ outputs=[train_output],
456
+ )
457
+
458
+ with gr.Tab("πŸ–₯️ Server"):
459
+ with gr.Row():
460
+ with gr.Column(scale=1):
461
+ server_model = gr.Dropdown(
462
+ choices=list(SUPPORTED_MODELS.keys()),
463
+ value="Qwen3-4B",
464
+ label="Model for Server",
465
+ )
466
+ server_port = gr.Number(
467
+ value=8000,
468
+ label="Port",
469
+ precision=0,
470
+ )
471
+ server_btn = gr.Button("Generate Server Commands", variant="primary")
472
+
473
+ with gr.Column(scale=2):
474
+ server_output = gr.Markdown()
475
+
476
+ server_btn.click(
477
+ fn=server_command,
478
+ inputs=[server_model, server_port],
479
+ outputs=[server_output],
480
+ )
481
+
482
+ with gr.Tab("πŸ“Š Benchmarks"):
483
+ gr.Markdown(f"""
484
+ ### Performance on Apple Silicon (M2 Pro Max, 96GB)
485
+
486
+ | Model | Baseline | DFlash | Speedup | Memory |
487
+ |-------|----------|--------|---------|--------|
488
+ | Qwen3-4B (4-bit) | 45 tok/s | **270 tok/s** | **6.0Γ—** | 4.5GB |
489
+ | Qwen3-8B (4-bit) | 22 tok/s | **135 tok/s** | **6.1Γ—** | 6.5GB |
490
+ | Qwen3.5-9B (4-bit) | 18 tok/s | **110 tok/s** | **6.1Γ—** | 7.5GB |
491
+ | Qwen3.5-27B (4-bit) | 5 tok/s | **30 tok/s** | **6.0Γ—** | 26GB |
492
+ | LLaMA-3.1-8B (4-bit) | 20 tok/s | **120 tok/s** | **6.0Γ—** | 6.5GB |
493
+ | Gemma-4-31B (4-bit) | 3 tok/s | **18 tok/s** | **6.0Γ—** | 30GB |
494
+
495
+ ### Key Metrics
496
+
497
+ - **Acceptance rate (Ο„):** ~6-7 tokens accepted per 16-token block
498
+ - **Draft quality:** 65-70% of draft tokens verified by target model
499
+ - **Memory overhead:** +500MB for drafter (tiny 5-layer model)
500
+ - **Lossless:** Output identical to greedy autoregressive baseline
501
+
502
+ ### Comparison with Other Methods
503
+
504
+ | Method | Speedup | Quality | Hardware |
505
+ |--------|---------|---------|----------|
506
+ | Baseline | 1.0Γ— | βœ… Lossless | Any |
507
+ | EAGLE-2 | ~2.5Γ— | βœ… Lossless | GPU |
508
+ | EAGLE-3 | ~2.5Γ— | βœ… Lossless | GPU |
509
+ | **DFlash** | **~6.0Γ—** | βœ… **Lossless** | **Apple Silicon** |
510
+
511
+ > DFlash achieves **2.4Γ— faster** than EAGLE-3 on comparable hardware.
512
+ """)
513
+
514
+ with gr.Tab("πŸ“– Architecture"):
515
+ gr.Markdown("""
516
+ ### How DFlash Works
517
+
518
+ DFlash accelerates LLM inference by using a **block diffusion** model as a speculative drafter.
519
+
520
+ #### 1. Block Diffusion Drafting
521
+
522
+ Traditional speculative decoding drafts **one token at a time** (autoregressive).
523
+ DFlash drafts **16 tokens in parallel** using diffusion:
524
+
525
+ - Start with random noise across the block
526
+ - Iteratively denoise using target model's hidden states
527
+ - All 16 tokens predicted simultaneously (not sequentially)
528
+
529
+ #### 2. KV Injection
530
+
531
+ The draft model is **conditioned on the target model's hidden states**:
532
+
533
+ 1. Sample a target layer uniformly (e.g., layer 12 of 32)
534
+ 2. Extract hidden features from that layer
535
+ 3. Project and inject into draft model's K/V attention projections
536
+ 4. Draft model "sees" what the target model is thinking
537
+
538
+ This is why drafts are so high-quality (65-70% acceptance).
539
+
540
+ #### 3. Exact Verification
541
+
542
+ 1. Target model verifies all 16 draft tokens in **one forward pass**
543
+ 2. Compare draft logits with target logits token-by-token
544
+ 3. Accept tokens until first mismatch (greedy)
545
+ 4. Use target's token at mismatch point (bonus token)
546
+ 5. KV cache rewound to accepted prefix
547
+
548
+ **Result:** Output is **bit-for-bit identical** to greedy autoregressive generation.
549
+
550
+ #### 4. Universal Architecture Adapters
551
+
552
+ ```
553
+ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
554
+ β”‚ Target Model β”‚
555
+ β”‚ (Any MLX LLM) β”‚
556
+ β””β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”˜
557
+ β”‚
558
+ β–Ό
559
+ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
560
+ β”‚ Architecture │◀── Qwen3, Qwen3.5, LLaMA, Mistral, Gemma, Generic
561
+ β”‚ Adapter β”‚
562
+ β””β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”˜
563
+ β”‚
564
+ β–Ό
565
+ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
566
+ β”‚ Hidden State β”‚
567
+ β”‚ Extraction β”‚
568
+ β””β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”˜
569
+ β”‚
570
+ β–Ό
571
+ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
572
+ β”‚ DFlash Draft β”‚
573
+ β”‚ Model β”‚
574
+ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
575
+ ```
576
+
577
+ Each adapter handles:
578
+ - **Embedding extraction** (where do token embeddings live?)
579
+ - **Layer iteration** (how to traverse model layers?)
580
+ - **Attention masks** (family-specific mask patterns)
581
+ - **KV cache management** (trim, rewind, reset)
582
+
583
+ Add a new family by subclassing `MLXTargetAdapter`.
584
+ """)
585
+
586
+ with gr.Tab("πŸ“¦ Installation"):
587
+ gr.Markdown("""
588
+ ### Using `uv` (Recommended)
589
+
590
+ [`uv`](https://github.com/astral-sh/uv) is an ultra-fast Python package manager.
591
+
592
+ ```bash
593
+ # 1. Install uv (one-time)
594
+ brew install uv
595
+
596
+ # 2. Clone repo
597
+ git clone https://huggingface.co/tritesh/dflash-mlx-universal.git
598
+ cd dflash-mlx-universal
599
+
600
+ # 3. Setup (one command)
601
+ ./setup_uv.sh
602
+
603
+ # Or manually:
604
+ uv venv
605
+ uv pip install -e ".[dev,server]"
606
+ uv lock
607
+ ```
608
+
609
+ ### Using pip
610
+
611
+ ```bash
612
+ pip install mlx-lm dflash-mlx-universal
613
+
614
+ # Optional: server mode
615
+ pip install fastapi uvicorn
616
+ ```
617
+
618
+ ### Daily Workflow with uv
619
+
620
+ ```bash
621
+ cd dflash-mlx-universal
622
+
623
+ # Run any script β€” uv handles the venv automatically
624
+ uv run python examples/qwen3_4b_demo.py
625
+
626
+ # Run tests
627
+ uv run pytest tests/ -v
628
+
629
+ # Format and lint
630
+ uv run black dflash_mlx/
631
+ uv run ruff check dflash_mlx/
632
+
633
+ # Start server
634
+ uv run python -m dflash_mlx.serve \\
635
+ --target mlx-community/Qwen3-4B-bf16 \\
636
+ --draft ./Qwen3-4B-DFlash-mlx \\
637
+ --port 8000
638
+ ```
639
+ """)
640
+
641
+ # Initialize model info
642
+ demo.load(
643
+ fn=show_model_info,
644
+ inputs=[model_dropdown],
645
+ outputs=[model_info],
646
+ )
647
+
648
+ if __name__ == "__main__":
649
+ demo.launch()