| """ |
| Benchmark DFlash speculative decoding on Apple Silicon. |
| |
| Usage: |
| python benchmark_m2.py --target Qwen/Qwen3-8B-MLX-4bit --draft ~/models/dflash/Qwen3-8B-DFlash-mlx |
| python benchmark_m2.py --target Qwen/Qwen3-4B-MLX-4bit --draft ~/models/dflash/Qwen3-4B-DFlash-mlx --tokens 1024 |
| """ |
|
|
import argparse
import json
import time
from datetime import datetime

import mlx.core as mx
from mlx_lm import generate, load
from mlx_lm.sample_utils import make_sampler

from dflash_mlx import DFlashSpeculativeDecoder
from dflash_mlx.convert import load_mlx_dflash
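# Requires mlx, mlx-lm, and the local dflash_mlx package; psutil is
# optional and only used for the memory report in benchmark().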
|
|
|
|
def benchmark(
    target_model_path: str,
    draft_model_path: str,
    prompt: str = "Write a Python function to implement merge sort with detailed comments.",
    max_tokens: int = 512,
    num_runs: int = 5,
    block_size: int = 16,
    temperature: float = 0.0,
):
    """Run a comprehensive benchmark of DFlash vs. baseline decoding on MLX."""
    print("=" * 70)
    print(" DFlash Speculative Decoding Benchmark")
    print("=" * 70)
    print(f"Device: {mx.default_device()}")
    print(f"Target Model: {target_model_path}")
    print(f"Draft Model: {draft_model_path}")
    print(f"Block Size: {block_size}")
    print(f"Max Tokens: {max_tokens}")
    print(f"Temperature: {temperature}")
    print(f"Runs: {num_runs}")
    print("=" * 70)
|
|
    print("\n[1/4] Loading target model...")
    t0 = time.time()
    model, tokenizer = load(target_model_path)
    print(f" Loaded in {time.time() - t0:.2f}s")

    print("\n[2/4] Loading draft model...")
    t0 = time.time()
    draft_model, draft_config = load_mlx_dflash(draft_model_path)
    print(f" Loaded in {time.time() - t0:.2f}s")
    print(
        f" Drafter: {draft_config.get('num_hidden_layers', '?')} layers, "
        f"{draft_config.get('hidden_size', '?')} hidden dim"
    )
|
|
    print("\n[3/4] Initializing DFlash decoder...")
    # block_size is the number of tokens the drafter proposes per
    # verification pass of the target model.
    decoder = DFlashSpeculativeDecoder(
        target_model=model,
        draft_model=draft_model,
        tokenizer=tokenizer,
        block_size=block_size,
    )
    print(" Ready")
|
|
    print("\n[4/4] Warmup run (compiles Metal kernels)...")
    t0 = time.time()
    decoder.generate(prompt, max_tokens=50, temperature=temperature)
    print(f" Warmup complete in {time.time() - t0:.2f}s")
|
|
    print(f"\n{'='*70}")
    print(" Running DFlash Speculative Decoding")
    print(f"{'='*70}")

    dflash_times = []
    dflash_outputs = []
    for i in range(num_runs):
        start = time.time()
        output = decoder.generate(
            prompt=prompt,
            max_tokens=max_tokens,
            temperature=temperature,
        )
        elapsed = time.time() - start
        dflash_times.append(elapsed)
        dflash_outputs.append(output)
        print(f" Run {i+1}: {elapsed:.3f}s ({max_tokens/elapsed:.1f} tok/s)")

    avg_dflash = sum(dflash_times) / len(dflash_times)
    dflash_tok_s = max_tokens / avg_dflash
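    # Note: the tok/s figures assume each run emits exactly max_tokens. If
    # the decoder can stop early at an EOS token, count what was actually
    # produced instead, e.g. (assuming the tokenizer exposes the usual
    # HF-style encode()):
    #     n_generated = len(tokenizer.encode(output))
    #     dflash_tok_s = n_generated / elapsed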
|
|
    print(f"\n{'='*70}")
    print(" Running Baseline (No Speculative Decoding)")
    print(f"{'='*70}")

    # Recent mlx_lm releases removed the `temp` kwarg from generate() in
    # favor of a sampler object; make_sampler(temp=0.0) is greedy decoding.
    sampler = make_sampler(temp=temperature)

    baseline_times = []
    for i in range(num_runs):
        start = time.time()
        generate(
            model,
            tokenizer,
            prompt=prompt,
            max_tokens=max_tokens,
            sampler=sampler,
        )
        elapsed = time.time() - start
        baseline_times.append(elapsed)
        print(f" Run {i+1}: {elapsed:.3f}s ({max_tokens/elapsed:.1f} tok/s)")

    avg_baseline = sum(baseline_times) / len(baseline_times)
    baseline_tok_s = max_tokens / avg_baseline
    speedup = avg_baseline / avg_dflash
|
|
    print(f"\n{'='*70}")
    print(" RESULTS SUMMARY")
    print(f"{'='*70}")
    print(f" Model: {target_model_path}")
    print(f" Baseline: {avg_baseline:.3f}s avg ({baseline_tok_s:.1f} tok/s)")
    print(f" DFlash: {avg_dflash:.3f}s avg ({dflash_tok_s:.1f} tok/s)")
    print(f" Speedup: {speedup:.2f}x")
    # Extra tokens DFlash emits, relative to what the baseline would emit,
    # in the time one DFlash run takes: max_tokens * (1 - 1/speedup).
    print(f" Extra tokens in equal time: {max_tokens * (1 - 1/speedup):.0f}")
    print(f" Time saved: {avg_baseline - avg_dflash:.3f}s per generation")
    print(f"{'='*70}")
|
|
    try:
        import psutil

        mem = psutil.virtual_memory()
        print("\n Memory:")
        print(f"   Total: {mem.total / 1e9:.1f} GB")
        print(f"   Used: {mem.used / 1e9:.1f} GB")
        print(f"   Available: {mem.available / 1e9:.1f} GB")
        # mx.get_peak_memory() superseded mx.metal.get_peak_memory() in
        # newer MLX releases; fall back for older versions.
        peak_fn = getattr(mx, "get_peak_memory", None) or mx.metal.get_peak_memory
        print(f"   MLX Peak: {peak_fn() / 1e9:.2f} GB")
    except ImportError:
        print("\n psutil not installed; skipping memory report.")
|
|
    print(f"\n{'='*70}")
    print(" Sample Output (first 500 chars)")
    print(f"{'='*70}")
    if dflash_outputs:
        sample = dflash_outputs[0]
        print(sample[:500] + ("..." if len(sample) > 500 else ""))
    else:
        print("N/A")
    print(f"{'='*70}")
|
|
    return {
        "target_model": target_model_path,
        "draft_model": draft_model_path,
        "speedup": speedup,
        "baseline_tok_s": baseline_tok_s,
        "dflash_tok_s": dflash_tok_s,
        "baseline_time": avg_baseline,
        "dflash_time": avg_dflash,
    }
|
|
|
|
def main():
    parser = argparse.ArgumentParser(
        description="Benchmark DFlash speculative decoding on Apple Silicon",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Qwen3-4B (fastest)
  python benchmark_m2.py --target Qwen/Qwen3-4B-MLX-4bit --draft ./Qwen3-4B-DFlash-mlx

  # Qwen3-8B (best balance)
  python benchmark_m2.py --target Qwen/Qwen3-8B-MLX-4bit --draft ./Qwen3-8B-DFlash-mlx

  # Custom model with temperature
  python benchmark_m2.py --target mlx-community/Llama-3.1-8B-Instruct-4bit \\
      --draft ./llama3.1-dflash --temperature 0.7 --tokens 1024
""",
    )
    parser.add_argument(
        "--target",
        type=str,
        required=True,
        help="MLX target model ID or path (e.g., Qwen/Qwen3-8B-MLX-4bit)",
    )
    parser.add_argument(
        "--draft",
        type=str,
        required=True,
        help="Path to converted DFlash drafter",
    )
    parser.add_argument(
        "--tokens",
        type=int,
        default=512,
        help="Number of tokens to generate per run (default: 512)",
    )
    parser.add_argument(
        "--runs",
        type=int,
        default=5,
        help="Number of benchmark runs (default: 5)",
    )
    parser.add_argument(
        "--block-size",
        type=int,
        default=16,
        help="DFlash draft block size in tokens (default: 16)",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.0,
        help="Sampling temperature (default: 0.0 = greedy)",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="Write a Python function to implement merge sort with detailed comments.",
        help="Benchmark prompt",
    )
|
|
    args = parser.parse_args()

    results = benchmark(
        target_model_path=args.target,
        draft_model_path=args.draft,
        prompt=args.prompt,
        max_tokens=args.tokens,
        num_runs=args.runs,
        block_size=args.block_size,
        temperature=args.temperature,
    )

    results["timestamp"] = datetime.now().isoformat()
    results["device"] = str(mx.default_device())

    output_file = f"benchmark_results_{results['target_model'].replace('/', '_')}.json"
    with open(output_file, "w") as f:
        json.dump(results, f, indent=2)

    print(f"\nResults saved to: {output_file}")
|
|
|
|
if __name__ == "__main__":
    main()
|
|