r"""
Example: DFlash speculative decoding with Qwen3-4B on MLX.

This demonstrates using a pre-converted DFlash drafter with the Qwen3-4B
model on Apple Silicon, then times it against plain (non-speculative)
greedy decoding of the same target model.

Prerequisites:
    pip install mlx-lm dflash-mlx-universal

    # Convert the drafter (one-time)
    python -m dflash_mlx.convert \
        --model z-lab/Qwen3-4B-DFlash-b16 \
        --output ./Qwen3-4B-DFlash-mlx
"""

import time

from mlx_lm import generate, load

from dflash_mlx import DFlashSpeculativeDecoder
from dflash_mlx.convert import load_mlx_dflash

# Hub id of the MLX-converted target model.
TARGET_MODEL = "Qwen/Qwen3-4B-MLX-4bit"
# Output directory of the one-time `dflash_mlx.convert` step (see docstring).
DRAFTER_PATH = "./Qwen3-4B-DFlash-mlx"
PROMPT = "Write a Python function to implement quicksort."
# Token budget for the printed demo generation (step 4).
DEMO_MAX_TOKENS = 1024
# Token budget for the timed baseline-vs-DFlash comparison (step 5); both
# sides use the same budget so the timings are comparable.
BENCH_MAX_TOKENS = 512


def _timed(fn, *args, **kwargs):
    """Call ``fn(*args, **kwargs)`` and return ``(result, elapsed_seconds)``.

    Uses ``time.perf_counter`` (monotonic, high resolution) rather than
    ``time.time`` so wall-clock adjustments cannot skew the measurement.
    """
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    return result, time.perf_counter() - start


def main():
    """Load target + drafter, run a DFlash demo, and time it against baseline."""
    print("=" * 60)
    print("DFlash Speculative Decoding Demo - Qwen3-4B")
    print("=" * 60)

    # 1. Load target model (MLX-converted)
    print("\n[1] Loading Qwen3-4B target model...")
    model, tokenizer = load(TARGET_MODEL)
    print("  ✓ Target model loaded")

    # 2. Load converted DFlash drafter
    print("\n[2] Loading DFlash drafter...")
    draft_model, draft_config = load_mlx_dflash(DRAFTER_PATH)
    # Use .get like the block_size lookup below so a config without this key
    # does not abort the demo with a KeyError just for a status message.
    num_layers = draft_config.get("num_hidden_layers", "?")
    print(f"  ✓ Drafter loaded ({num_layers} layers)")

    # 3. Create decoder
    print("\n[3] Creating DFlash speculative decoder...")
    decoder = DFlashSpeculativeDecoder(
        target_model=model,
        draft_model=draft_model,
        tokenizer=tokenizer,
        block_size=draft_config.get("block_size", 16),
    )

    # 4. Generate
    print("\n[4] Generating with DFlash speculative decoding...")
    prompt = PROMPT
    print(f"\nPrompt: {prompt}")
    print("-" * 60)

    output = decoder.generate(
        prompt=prompt,
        max_tokens=DEMO_MAX_TOKENS,
        temperature=0.0,
    )
    print(output)
    print("-" * 60)

    # 5. Compare with baseline
    print("\n[5] Running baseline (no speculative decoding)...")

    # Baseline: mlx_lm model objects have no ``generate`` method; generation
    # is the module-level ``mlx_lm.generate``. With no sampler supplied it
    # decodes greedily, matching temperature=0.0 on the DFlash side.
    _, baseline_time = _timed(
        generate,
        model,
        tokenizer,
        prompt=prompt,
        max_tokens=BENCH_MAX_TOKENS,
    )

    # DFlash
    _, dflash_time = _timed(
        decoder.generate,
        prompt=prompt,
        max_tokens=BENCH_MAX_TOKENS,
        temperature=0.0,
    )

    speedup = baseline_time / dflash_time

    print(f"\nBaseline: {baseline_time:.2f}s")
    print(f"DFlash:   {dflash_time:.2f}s")
    print(f"Speedup:  {speedup:.2f}x")

    print("\n" + "=" * 60)
    print("Demo complete!")
    print("=" * 60)


if __name__ == "__main__":
    main()