tritesh commited on
Commit
1e567d1
Β·
verified Β·
1 Parent(s): 5044c56

Upload USAGE_GUIDE.md

Browse files
Files changed (1) hide show
  1. USAGE_GUIDE.md +429 -0
USAGE_GUIDE.md ADDED
@@ -0,0 +1,429 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # DFlash-MLX-Universal: System Usage Guide
2
+
3
+ > How to use `dflash-mlx-universal` on your Apple Silicon Mac (M1/M2/M3/M4)
4
+
5
+ ---
6
+
7
+ ## πŸ“‹ Prerequisites
8
+
9
+ | Requirement | Version | Notes |
10
+ |------------|---------|-------|
11
+ | macOS | 14+ (Sonoma/Sequoia) | MLX requires Apple Silicon |
12
+ | Python | 3.9 - 3.12 | Recommend 3.11 or 3.12 |
13
+ | Chip | M1/M2/M3/M4 (Pro/Max/Ultra) | Unified memory required for large models |
14
+ | Memory | 16GB+ minimum, 32GB+ recommended | 96GB for 70B+ models |
15
+
16
+ ---
17
+
18
+ ## 1️⃣ Installation
19
+
20
+ ```bash
21
+ # 1. Create a virtual environment (recommended)
22
+ python3 -m venv .venv-dflash
23
+ source .venv-dflash/bin/activate # On zsh/bash
24
+
25
+ # 2. Upgrade pip
26
+ pip install --upgrade pip
27
+
28
+ # 3. Install core dependencies
29
+ pip install mlx-lm>=0.24.0 transformers>=4.57.0 huggingface-hub>=0.25.0
30
+
31
+ # 4. Install dflash-mlx-universal from your repo
32
+ pip install git+https://huggingface.co/tritesh/dflash-mlx-universal.git
33
+
34
+ # Optional: server mode
35
+ pip install fastapi uvicorn
36
+ ```
37
+
38
+ ### Alternative: Install from local clone
39
+
40
+ ```bash
41
+ git clone https://huggingface.co/tritesh/dflash-mlx-universal.git
42
+ cd dflash-mlx-universal
43
+ pip install -e .
44
+ ```
45
+
46
+ ---
47
+
48
+ ## 2️⃣ Quick Start β€” Using a Pre-converted Drafter
49
+
50
+ ### Step A: Convert an Official DFlash Drafter to MLX
51
+
52
+ Official drafters are PyTorch models. You need to convert them to MLX format once:
53
+
54
+ ```bash
55
+ # Convert Qwen3-4B drafter (~2-4 minutes on M2 Pro Max)
56
+ python -m dflash_mlx.convert \
57
+ --model z-lab/Qwen3-4B-DFlash-b16 \
58
+ --output ~/models/dflash/Qwen3-4B-DFlash-mlx
59
+
60
+ # Convert Qwen3.5-9B drafter
61
+ python -m dflash_mlx.convert \
62
+ --model z-lab/Qwen3.5-9B-DFlash \
63
+ --output ~/models/dflash/Qwen3.5-9B-DFlash-mlx
64
+
65
+ # Convert LLaMA-3.1-8B drafter
66
+ python -m dflash_mlx.convert \
67
+ --model z-lab/LLaMA3.1-8B-Instruct-DFlash-UltraChat \
68
+ --output ~/models/dflash/LLaMA3.1-8B-DFlash-mlx
69
+ ```
70
+
71
+ **What this does:**
72
+ - Downloads PyTorch weights from HF Hub
73
+ - Transposes linear layers (PyTorch β†’ MLX format)
74
+ - Saves as `weights.npz` + `config.json`
75
+ - Creates `model_info.json` with target model mapping
76
+
77
+ ---
78
+
79
+ ### Step B: Generate with DFlash Speculative Decoding
80
+
81
+ ```python
82
+ from mlx_lm import load
83
+ from dflash_mlx import DFlashSpeculativeDecoder
84
+ from dflash_mlx.convert import load_mlx_dflash
85
+
86
+ # 1. Load target model (any MLX-converted model)
87
+ model, tokenizer = load("mlx-community/Qwen3-4B-bf16")
88
+
89
+ # 2. Load converted DFlash drafter
90
+ draft_model, draft_config = load_mlx_dflash("~/models/dflash/Qwen3-4B-DFlash-mlx")
91
+
92
+ # 3. Create decoder (auto-detects architecture via adapters)
93
+ decoder = DFlashSpeculativeDecoder(
94
+ target_model=model,
95
+ draft_model=draft_model,
96
+ tokenizer=tokenizer,
97
+ block_size=draft_config.get("block_size", 16),
98
+ )
99
+
100
+ # 4. Generate with 6Γ— speedup
101
+ output = decoder.generate(
102
+ prompt="Write a Python function to implement quicksort.",
103
+ max_tokens=1024,
104
+ temperature=0.0, # Greedy for exact reproduction
105
+ )
106
+ print(output)
107
+ ```
108
+
109
+ **Expected output:**
110
+ ```
111
+ [DFlash] Prefill: processing 12 prompt tokens...
112
+ [DFlash] Starting speculative decoding (block_size=16)...
113
+ [DFlash] Done. Generated 1024 tokens, avg acceptance: 6.23, effective speedup: ~5.8x
114
+ ```
115
+
116
+ ---
117
+
118
+ ## 3️⃣ Streaming Generation
119
+
120
+ For real-time output (chat UI, etc.):
121
+
122
+ ```python
123
+ from mlx_lm import load
124
+ from dflash_mlx import DFlashSpeculativeDecoder
125
+ from dflash_mlx.convert import load_mlx_dflash
126
+
127
+ model, tokenizer = load("mlx-community/Qwen3-4B-bf16")
128
+ draft_model, _ = load_mlx_dflash("~/models/dflash/Qwen3-4B-DFlash-mlx")
129
+ decoder = DFlashSpeculativeDecoder(model, draft_model, tokenizer, block_size=16)
130
+
131
+ # Generator-based streaming
132
+ for chunk in decoder.generate(
133
+ prompt="Tell me a story about a robot.",
134
+ max_tokens=512,
135
+ stream=True, # ← Returns generator
136
+ ):
137
+ print(chunk, end="", flush=True)
138
+ ```
139
+
140
+ ---
141
+
142
+ ## 4️⃣ Benchmark Mode
143
+
144
+ Compare DFlash vs baseline speed:
145
+
146
+ ```python
147
+ from mlx_lm import load
148
+ from dflash_mlx import DFlashSpeculativeDecoder
149
+ from dflash_mlx.convert import load_mlx_dflash
150
+
151
+ model, tokenizer = load("mlx-community/Qwen3-4B-bf16")
152
+ draft_model, _ = load_mlx_dflash("~/models/dflash/Qwen3-4B-DFlash-mlx")
153
+ decoder = DFlashSpeculativeDecoder(model, draft_model, tokenizer, block_size=16)
154
+
155
+ # Run benchmark
156
+ results = decoder.benchmark(
157
+ prompt="Write a quicksort in Python.",
158
+ max_tokens=512,
159
+ num_runs=5,
160
+ )
161
+
162
+ print(f"Speedup: {results['speedup']:.2f}x")
163
+ print(f"Tokens/sec: {results['tokens_per_sec']:.1f}")
164
+ ```
165
+
166
+ **Sample results (M2 Pro Max 96GB):**
167
+ ```
168
+ [Benchmark] Baseline: 2.34s | DFlash: 0.41s | Speedup: 5.71x | 1247.6 tok/s
169
+ ```
170
+
171
+ ---
172
+
173
+ ## 5️⃣ Universal Decoder (Any Model Without Pre-built Drafter)
174
+
175
+ If your model doesn't have a DFlash drafter yet:
176
+
177
+ ```python
178
+ from mlx_lm import load
179
+ from dflash_mlx.universal import UniversalDFlashDecoder
180
+
181
+ # Load ANY mlx_lm model
182
+ model, tokenizer = load("mlx-community/Llama-3.1-8B-Instruct-4bit")
183
+
184
+ # UniversalDFlashDecoder:
185
+ # 1. Auto-detects architecture (LLaMA in this case)
186
+ # 2. Creates a generic 5-layer drafter (~500MB)
187
+ # 3. Sets up proper adapter for hidden state extraction
188
+
189
+ decoder = UniversalDFlashDecoder(
190
+ target_model=model,
191
+ tokenizer=tokenizer,
192
+ draft_layers=5,
193
+ draft_hidden_size=1024,
194
+ block_size=16,
195
+ )
196
+
197
+ # Option A: Train a custom drafter (2-8 hours)
198
+ decoder.train_drafter(
199
+ dataset="open-web-math", # or local JSONL
200
+ epochs=6,
201
+ lr=6e-4,
202
+ batch_size=16,
203
+ output_path="~/models/dflash/my-llama-drafter",
204
+ )
205
+
206
+ # Option B: Use untrained (low quality, for testing only)
207
+ output = decoder.generate(
208
+ prompt="Hello world!",
209
+ max_tokens=100,
210
+ )
211
+ ```
212
+
213
+ ---
214
+
215
+ ## 6️⃣ OpenAI-Compatible Server
216
+
217
+ Run a local server compatible with OpenAI clients:
218
+
219
+ ```bash
220
+ # Start server
221
+ python -m dflash_mlx.serve \
222
+ --target mlx-community/Qwen3-4B-bf16 \
223
+ --draft ~/models/dflash/Qwen3-4B-DFlash-mlx \
224
+ --block-size 16 \
225
+ --port 8000
226
+
227
+ # Or in background
228
+ nohup python -m dflash_mlx.serve \
229
+ --target mlx-community/Qwen3-4B-bf16 \
230
+ --draft ~/models/dflash/Qwen3-4B-DFlash-mlx \
231
+ --port 8000 > dflash.log 2>&1 &
232
+ ```
233
+
234
+ ### Query the server
235
+
236
+ ```bash
237
+ # Chat completion
238
+ curl http://localhost:8000/v1/chat/completions \
239
+ -H "Content-Type: application/json" \
240
+ -d '{
241
+ "model": "qwen3-4b",
242
+ "messages": [{"role": "user", "content": "Explain quantum computing"}],
243
+ "max_tokens": 512,
244
+ "temperature": 0.0
245
+ }'
246
+
247
+ # Streaming
248
+ curl http://localhost:8000/v1/chat/completions \
249
+ -H "Content-Type: application/json" \
250
+ -d '{
251
+ "model": "qwen3-4b",
252
+ "messages": [{"role": "user", "content": "Count to 10"}],
253
+ "max_tokens": 100,
254
+ "stream": true
255
+ }'
256
+
257
+ # Check metrics
258
+ curl http://localhost:8000/metrics
259
+ ```
260
+
261
+ ### Python client
262
+
263
+ ```python
264
+ from openai import OpenAI
265
+
266
+ client = OpenAI(
267
+ base_url="http://localhost:8000/v1",
268
+ api_key="not-needed", # Local server, no auth
269
+ )
270
+
271
+ response = client.chat.completions.create(
272
+ model="qwen3-4b",
273
+ messages=[{"role": "user", "content": "Write a haiku about ML"}],
274
+ max_tokens=100,
275
+ )
276
+ print(response.choices[0].message.content)
277
+ ```
278
+
279
+ ---
280
+
281
+ ## 7️⃣ Using with Ollama, aider, Continue, etc.
282
+
283
+ Any OpenAI-compatible client works:
284
+
285
+ ### aider (AI coding assistant)
286
+ ```bash
287
+ aider --model openai/qwen3-4b --openai-api-base http://localhost:8000/v1 --openai-api-key not-needed
288
+ ```
289
+
290
+ ### Continue.dev (VS Code extension)
291
+ ```json
292
+ // .continue/config.json
293
+ {
294
+ "models": [{
295
+ "title": "DFlash Qwen3-4B",
296
+ "provider": "openai",
297
+ "model": "qwen3-4b",
298
+ "apiBase": "http://localhost:8000/v1",
299
+ "apiKey": "not-needed"
300
+ }]
301
+ }
302
+ ```
303
+
304
+ ### Ollama (as custom endpoint)
305
+ Configure any OpenAI-compatible frontend to point at `http://localhost:8000/v1`
306
+
307
+ ---
308
+
309
+ ## 8️⃣ Supported Model Families
310
+
311
+ | Family | Target Model Example | Drafter Status |
312
+ |--------|---------------------|---------------|
313
+ | **Qwen3** | `mlx-community/Qwen3-4B-bf16` | βœ… Pre-built |
314
+ | **Qwen3.5** | `mlx-community/Qwen3.5-9B-4bit` | βœ… Pre-built |
315
+ | **Qwen3.6** | `mlx-community/Qwen3.6-27B-4bit` | βœ… Pre-built |
316
+ | **LLaMA 3.1** | `mlx-community/Llama-3.1-8B-Instruct-4bit` | βœ… Pre-built |
317
+ | **LLaMA 3.3** | `mlx-community/Llama-3.3-70B-Instruct-4bit` | βœ… Pre-built |
318
+ | **Mistral** | `mlx-community/Mistral-7B-Instruct-v0.3-4bit` | ⚠️ Train custom |
319
+ | **Gemma** | `mlx-community/gemma-4-31b-it-4bit` | βœ… Pre-built |
320
+ | **Phi** | `mlx-community/Phi-3-mini-4k-instruct-4bit` | ⚠️ Generic adapter |
321
+
322
+ ---
323
+
324
+ ## 9️⃣ Troubleshooting
325
+
326
+ ### "Unsupported model_type: phi"
327
+ ```python
328
+ # Add a custom adapter for your model
329
+ from dflash_mlx.adapters import MLXTargetAdapter, ADAPTERS
330
+
331
+ class PhiAdapter(MLXTargetAdapter):
332
+ family = "phi"
333
+ # Override methods as needed...
334
+
335
+ ADAPTERS["phi"] = PhiAdapter
336
+ ```
337
+
338
+ ### "Vocab size mismatch"
339
+ Ensure target model and draft model share the same tokenizer vocabulary. Drafters are trained for specific target families.
340
+
341
+ ### Slow first run
342
+ MLX compiles Metal kernels lazily. First generation is slow; subsequent runs are fast. The benchmark method includes warmup.
343
+
344
+ ### Out of memory
345
+ - Reduce `--block-size` (default 16 β†’ 8)
346
+ - Use 4-bit quantized target models (`-4bit` suffix)
347
+ - Reduce `max_tokens`
348
+
349
+ ### Draft tokens all rejected
350
+ - Drafter may not match target model (wrong family)
351
+ - Use trained drafter for your specific model
352
+ - Check `target_layer_ids` alignment in config
353
+
354
+ ---
355
+
356
+ ## πŸ”Ÿ Full Example Script
357
+
358
+ Save as `run_dflash.py`:
359
+
360
+ ```python
361
+ #!/usr/bin/env python3
362
+ """Complete DFlash example with error handling."""
363
+
364
+ import sys
365
+ from mlx_lm import load
366
+ from dflash_mlx import DFlashSpeculativeDecoder
367
+ from dflash_mlx.convert import load_mlx_dflash
368
+
369
+ def main():
370
+ # Configuration
371
+ TARGET_MODEL = "mlx-community/Qwen3-4B-bf16"
372
+ DRAFT_MODEL = "~/models/dflash/Qwen3-4B-DFlash-mlx"
373
+ PROMPT = "Explain how speculative decoding works."
374
+ MAX_TOKENS = 512
375
+
376
+ print(f"Loading target model: {TARGET_MODEL}")
377
+ model, tokenizer = load(TARGET_MODEL)
378
+
379
+ print(f"Loading DFlash drafter: {DRAFT_MODEL}")
380
+ try:
381
+ draft_model, draft_config = load_mlx_dflash(DRAFT_MODEL)
382
+ except FileNotFoundError:
383
+ print(f"Error: Drafter not found at {DRAFT_MODEL}")
384
+ print("Convert first: python -m dflash_mlx.convert --model z-lab/Qwen3-4B-DFlash-b16 --output ~/models/dflash/Qwen3-4B-DFlash-mlx")
385
+ sys.exit(1)
386
+
387
+ print("Creating DFlash decoder...")
388
+ decoder = DFlashSpeculativeDecoder(
389
+ target_model=model,
390
+ draft_model=draft_model,
391
+ tokenizer=tokenizer,
392
+ block_size=draft_config.get("block_size", 16),
393
+ )
394
+
395
+ print(f"\nPrompt: {PROMPT}")
396
+ print("-" * 60)
397
+
398
+ output = decoder.generate(
399
+ prompt=PROMPT,
400
+ max_tokens=MAX_TOKENS,
401
+ temperature=0.0,
402
+ )
403
+
404
+ print(output)
405
+ print("-" * 60)
406
+ print("Done!")
407
+
408
+ if __name__ == "__main__":
409
+ main()
410
+ ```
411
+
412
+ Run:
413
+ ```bash
414
+ python run_dflash.py
415
+ ```
416
+
417
+ ---
418
+
419
+ ## πŸ“š Next Steps
420
+
421
+ 1. **Convert your first drafter** β†’ `python -m dflash_mlx.convert --model z-lab/Qwen3-4B-DFlash-b16 --output ./drafter`
422
+ 2. **Benchmark it** β†’ Use `decoder.benchmark(...)` to verify speedup
423
+ 3. **Start the server** β†’ `python -m dflash_mlx.serve --target ... --draft ...`
424
+ 4. **Connect your tools** β†’ aider, Continue, custom clients
425
+ 5. **Train custom drafters** β†’ For unsupported models using `UniversalDFlashDecoder`
426
+
427
+ ---
428
+
429
+ For questions/issues: https://huggingface.co/tritesh/dflash-mlx-universal/discussions