"""
Example: DFlash speculative decoding with Qwen3-4B on MLX.

This demonstrates using a pre-converted DFlash drafter with the Qwen3-4B
model on Apple Silicon.

Prerequisites:
    pip install mlx-lm dflash-mlx-universal

    # Convert the drafter (one-time)
    python -m dflash_mlx.convert \
        --model z-lab/Qwen3-4B-DFlash-b16 \
        --output ./Qwen3-4B-DFlash-mlx
"""
|
|
import time

from dflash_mlx import DFlashSpeculativeDecoder
from dflash_mlx.convert import load_mlx_dflash
from mlx_lm import generate, load
|
|
|
|
def _timed(func, *args, **kwargs):
    """Call ``func(*args, **kwargs)`` and return the elapsed seconds.

    The call's return value is discarded; only the duration is reported.
    ``time.perf_counter`` is used because it is monotonic and has the
    highest available resolution for measuring short intervals.
    """
    start = time.perf_counter()
    func(*args, **kwargs)
    return time.perf_counter() - start


def main():
    """Run the DFlash demo: load both models, generate once, then benchmark.

    Steps, mirroring the numbered progress messages printed to stdout:
      1. Load the 4-bit Qwen3-4B target model and its tokenizer.
      2. Load the pre-converted DFlash drafter from ``./Qwen3-4B-DFlash-mlx``.
      3. Wrap both in a ``DFlashSpeculativeDecoder``.
      4. Generate a completion for a fixed prompt and print it.
      5. Time a 512-token generation with and without the drafter and print
         the ratio of the two durations as the speedup.
    """
    banner = "=" * 60
    rule = "-" * 60

    print(banner)
    print("DFlash Speculative Decoding Demo - Qwen3-4B")
    print(banner)

    print("\n[1] Loading Qwen3-4B target model...")
    model, tokenizer = load("Qwen/Qwen3-4B-MLX-4bit")
    print(" ✓ Target model loaded")

    print("\n[2] Loading DFlash drafter...")
    draft_model, draft_config = load_mlx_dflash("./Qwen3-4B-DFlash-mlx")
    # Read the layer count with .get(), like block_size below, so a config
    # without that key cannot abort the demo over a status message.
    num_layers = draft_config.get("num_hidden_layers", "unknown")
    print(f" ✓ Drafter loaded ({num_layers} layers)")

    print("\n[3] Creating DFlash speculative decoder...")
    decoder = DFlashSpeculativeDecoder(
        target_model=model,
        draft_model=draft_model,
        tokenizer=tokenizer,
        block_size=draft_config.get("block_size", 16),
    )

    print("\n[4] Generating with DFlash speculative decoding...")
    prompt = "Write a Python function to implement quicksort."

    print(f"\nPrompt: {prompt}")
    print(rule)

    output = decoder.generate(
        prompt=prompt,
        max_tokens=1024,
        temperature=0.0,
    )

    print(output)
    print(rule)

    print("\n[5] Running baseline (no speculative decoding)...")

    # The object returned by load() is the bare network, so plain decoding
    # goes through mlx_lm.generate(). No sampler is passed: mlx-lm then
    # decodes greedily, which matches temperature=0.0 on the DFlash side.
    baseline_time = _timed(
        generate,
        model,
        tokenizer,
        prompt=prompt,
        max_tokens=512,
    )

    dflash_time = _timed(
        decoder.generate,
        prompt=prompt,
        max_tokens=512,
        temperature=0.0,
    )

    speedup = baseline_time / dflash_time
    print(f"\nBaseline: {baseline_time:.2f}s")
    print(f"DFlash: {dflash_time:.2f}s")
    print(f"Speedup: {speedup:.2f}x")

    print("\n" + banner)
    print("Demo complete!")
    print(banner)


if __name__ == "__main__":
    main()
|
|