"""
Benchmark DFlash speculative decoding on Apple Silicon.

Usage:
    python benchmark_m2.py --target Qwen/Qwen3-8B-MLX-4bit --draft ~/models/dflash/Qwen3-8B-DFlash-mlx
    python benchmark_m2.py --target Qwen/Qwen3-4B-MLX-4bit --draft ~/models/dflash/Qwen3-4B-DFlash-mlx --tokens 1024
"""

import argparse
import json
import re
import time
from datetime import datetime

import mlx.core as mx
from mlx_lm import generate, load

from dflash_mlx import DFlashSpeculativeDecoder
from dflash_mlx.convert import load_mlx_dflash

DEFAULT_PROMPT = "Write a Python function to implement merge sort with detailed comments."

# Number of tokens generated by each warmup call (not timed).
_WARMUP_TOKENS = 50


def _print_header(title):
    """Print a 70-column '=' banner around ``title``."""
    print(f"\n{'='*70}")
    print(f" {title}")
    print(f"{'='*70}")


def _baseline_sampling_kwargs(temperature):
    """Return the sampling keyword arguments for ``mlx_lm.generate``.

    Recent mlx_lm releases dropped the ``temp=`` keyword in favour of a
    ``sampler`` callable built by ``mlx_lm.sample_utils.make_sampler``.
    Prefer the sampler when it exists and fall back to ``temp=`` for older
    installs.
    """
    try:
        from mlx_lm.sample_utils import make_sampler
    except ImportError:
        return {"temp": temperature}
    return {"sampler": make_sampler(temp=temperature)}


def _timed_runs(run_once, num_runs, max_tokens):
    """Call ``run_once()`` ``num_runs`` times and print per-run throughput.

    Returns ``(times, outputs)``: the elapsed seconds and the return value of
    each run, in order. The printed tok/s is nominal: it assumes every run
    produced exactly ``max_tokens`` tokens. That is not verified here, since a
    run that stops early (e.g. on EOS) is not detected.
    """
    times = []
    outputs = []
    for i in range(num_runs):
        start = time.perf_counter()
        output = run_once()
        elapsed = time.perf_counter() - start
        times.append(elapsed)
        outputs.append(output)
        print(f" Run {i+1}: {elapsed:.3f}s ({max_tokens/elapsed:.1f} tok/s)")
    return times, outputs


def _mlx_peak_memory_bytes():
    """Return MLX peak memory in bytes, or None if no known API is available.

    Newer MLX exposes ``mx.get_peak_memory``; older releases only have
    ``mx.metal.get_peak_memory``.
    """
    getter = getattr(mx, "get_peak_memory", None)
    if getter is None:
        getter = getattr(getattr(mx, "metal", None), "get_peak_memory", None)
    return None if getter is None else getter()


def _print_memory_usage():
    """Print system memory (if psutil is installed) and MLX peak memory."""
    lines = []
    try:
        import psutil
    except ImportError:
        psutil = None
    if psutil is not None:
        mem = psutil.virtual_memory()
        lines.append(f" Total: {mem.total / 1e9:.1f} GB")
        lines.append(f" Used: {mem.used / 1e9:.1f} GB")
        lines.append(f" Available: {mem.available / 1e9:.1f} GB")
    # Reported independently of psutil: it only needs MLX itself.
    peak = _mlx_peak_memory_bytes()
    if peak is not None:
        lines.append(f" MLX Peak: {peak / 1e9:.2f} GB")
    if lines:
        print("\n Memory:")
        for line in lines:
            print(line)


def _results_filename(target_model):
    """Build a safe JSON filename from a model ID or local path.

    HF-style IDs map exactly as before (``Qwen/X`` -> ``Qwen_X``). Leading
    dots and underscores left over from local paths such as ``./X`` or ``~/X``
    are dropped, so the file is not hidden.
    """
    slug = re.sub(r"[^A-Za-z0-9._-]+", "_", target_model).lstrip("._") or "model"
    return f"benchmark_results_{slug}.json"


def benchmark(
    target_model_path: str,
    draft_model_path: str,
    prompt: str = DEFAULT_PROMPT,
    max_tokens: int = 512,
    num_runs: int = 5,
    block_size: int = 16,
    temperature: float = 0.0,
):
    """Run comprehensive benchmark of DFlash vs baseline on MLX.

    Loads the target model (``mlx_lm.load``) and the DFlash drafter, warms up
    both generation paths, then times ``num_runs`` DFlash generations followed
    by ``num_runs`` plain ``mlx_lm.generate`` generations of the same prompt.

    Args:
        target_model_path: MLX target model ID or local path.
        draft_model_path: Path to the converted DFlash drafter.
        prompt: Prompt used for every run.
        max_tokens: Token budget per run. It is also used as the nominal token
            count for tok/s figures.
        num_runs: Timed runs per method; must be >= 1.
        block_size: DFlash draft block size.
        temperature: Sampling temperature (0.0 = greedy).

    Returns:
        A dict with model paths, speedup, tok/s, and average seconds per run
        for both methods.

    Raises:
        ValueError: If ``num_runs`` is less than 1.
    """
    if num_runs < 1:
        raise ValueError(f"num_runs must be >= 1, got {num_runs}")

    print("=" * 70)
    print(" DFlash Speculative Decoding Benchmark")
    print("=" * 70)
    print(f"Device: {mx.default_device()}")
    print(f"Target Model: {target_model_path}")
    print(f"Draft Model: {draft_model_path}")
    print(f"Block Size: {block_size}")
    print(f"Max Tokens: {max_tokens}")
    print(f"Temperature: {temperature}")
    print(f"Runs: {num_runs}")
    print("=" * 70)

    # Load models
    print("\n[1/4] Loading target model...")
    t0 = time.perf_counter()
    model, tokenizer = load(target_model_path)
    print(f" Loaded in {time.perf_counter() - t0:.2f}s")

    print("\n[2/4] Loading draft model...")
    t0 = time.perf_counter()
    draft_model, draft_config = load_mlx_dflash(draft_model_path)
    print(f" Loaded in {time.perf_counter() - t0:.2f}s")
    print(f" Drafter: {draft_config.get('num_hidden_layers', '?')} layers, "
          f"{draft_config.get('hidden_size', '?')} hidden dim")

    # Create decoder
    print("\n[3/4] Initializing DFlash decoder...")
    decoder = DFlashSpeculativeDecoder(
        target_model=model,
        draft_model=draft_model,
        tokenizer=tokenizer,
        block_size=block_size,
    )
    print(" Ready")

    # Resolved once, outside any timed region.
    baseline_kwargs = _baseline_sampling_kwargs(temperature)

    def run_dflash(tokens):
        return decoder.generate(prompt=prompt, max_tokens=tokens, temperature=temperature)

    def run_baseline(tokens):
        # Native MLX generate without speculative decoding.
        return generate(model, tokenizer, prompt=prompt, max_tokens=tokens, **baseline_kwargs)

    # Warm up BOTH paths, so neither method's first timed run pays one-off
    # compilation/allocation costs (otherwise the speedup is biased).
    print("\n[4/4] Warmup run (compiles Metal kernels)...")
    t0 = time.perf_counter()
    run_dflash(_WARMUP_TOKENS)
    print(f" DFlash warmup complete in {time.perf_counter() - t0:.2f}s")
    t0 = time.perf_counter()
    run_baseline(_WARMUP_TOKENS)
    print(f" Baseline warmup complete in {time.perf_counter() - t0:.2f}s")

    # Benchmark DFlash
    _print_header("Running DFlash Speculative Decoding")
    dflash_times, dflash_outputs = _timed_runs(
        lambda: run_dflash(max_tokens), num_runs, max_tokens
    )
    avg_dflash = sum(dflash_times) / len(dflash_times)
    dflash_tok_s = max_tokens / avg_dflash

    # Baseline benchmark (always run; it is the reference for the speedup).
    _print_header("Running Baseline (No Speculative Decoding)")
    baseline_times, _ = _timed_runs(
        lambda: run_baseline(max_tokens), num_runs, max_tokens
    )
    avg_baseline = sum(baseline_times) / len(baseline_times)
    baseline_tok_s = max_tokens / avg_baseline

    speedup = avg_baseline / avg_dflash

    # Summary
    _print_header("RESULTS SUMMARY")
    print(f" Model: {target_model_path}")
    print(f" Baseline: {avg_baseline:.3f}s avg ({baseline_tok_s:.1f} tok/s)")
    print(f" DFlash: {avg_dflash:.3f}s avg ({dflash_tok_s:.1f} tok/s)")
    print(f" Speedup: {speedup:.2f}x")
    print(f" Tokens saved: {max_tokens * (1 - 1/speedup):.0f} per generation")
    print(f" Time saved: {avg_baseline - avg_dflash:.3f}s per generation")
    print(f"{'='*70}")

    # Memory usage
    _print_memory_usage()

    # Show sample output
    _print_header("Sample Output (first 500 chars)")
    print(dflash_outputs[0][:500] if dflash_outputs else "N/A")
    print("...")
    print(f"{'='*70}")

    return {
        "target_model": target_model_path,
        "draft_model": draft_model_path,
        "speedup": speedup,
        "baseline_tok_s": baseline_tok_s,
        "dflash_tok_s": dflash_tok_s,
        "baseline_time": avg_baseline,
        "dflash_time": avg_dflash,
    }


def main():
    """Parse CLI arguments, run the benchmark, and save results as JSON."""
    parser = argparse.ArgumentParser(
        description="Benchmark DFlash speculative decoding on Apple Silicon",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Qwen3-4B (fastest)
  python benchmark_m2.py --target Qwen/Qwen3-4B-MLX-4bit --draft ./Qwen3-4B-DFlash-mlx

  # Qwen3-8B (best balance)
  python benchmark_m2.py --target Qwen/Qwen3-8B-MLX-4bit --draft ./Qwen3-8B-DFlash-mlx

  # Custom model with temperature
  python benchmark_m2.py --target mlx-community/Llama-3.1-8B-Instruct-4bit \\
      --draft ./llama3.1-dflash --temperature 0.7 --tokens 1024
""",
    )
    parser.add_argument(
        "--target", type=str, required=True,
        help="MLX target model ID or path (e.g., Qwen/Qwen3-8B-MLX-4bit)",
    )
    parser.add_argument(
        "--draft", type=str, required=True,
        help="Path to converted DFlash drafter",
    )
    parser.add_argument(
        "--tokens", type=int, default=512,
        help="Number of tokens to generate per run (default: 512)",
    )
    parser.add_argument(
        "--runs", type=int, default=5,
        help="Number of benchmark runs (default: 5)",
    )
    parser.add_argument(
        "--block-size", type=int, default=16,
        help="DFlash block size (default: 16)",
    )
    parser.add_argument(
        "--temperature", type=float, default=0.0,
        help="Sampling temperature (default: 0.0 = greedy)",
    )
    parser.add_argument(
        "--prompt", type=str, default=DEFAULT_PROMPT,
        help="Benchmark prompt",
    )
    args = parser.parse_args()
    if args.runs < 1:
        # Averages divide by the run count; reject early with a CLI error.
        parser.error("--runs must be at least 1")

    results = benchmark(
        target_model_path=args.target,
        draft_model_path=args.draft,
        prompt=args.prompt,
        max_tokens=args.tokens,
        num_runs=args.runs,
        block_size=args.block_size,
        temperature=args.temperature,
    )

    # Save results to JSON
    results["timestamp"] = datetime.now().isoformat()
    results["device"] = str(mx.default_device())
    output_file = _results_filename(results["target_model"])
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)
    print(f"\nResults saved to: {output_file}")


if __name__ == "__main__":
    main()