# dflash-mlx-universal / benchmark_m2.py
"""
Benchmark DFlash speculative decoding on Apple Silicon.
Usage:
python benchmark_m2.py --target Qwen/Qwen3-8B-MLX-4bit --draft ~/models/dflash/Qwen3-8B-DFlash-mlx
python benchmark_m2.py --target Qwen/Qwen3-4B-MLX-4bit --draft ~/models/dflash/Qwen3-4B-DFlash-mlx --tokens 1024
"""
import argparse
import json
import time
from datetime import datetime

import mlx.core as mx
from mlx_lm import generate, load

from dflash_mlx import DFlashSpeculativeDecoder
from dflash_mlx.convert import load_mlx_dflash

def benchmark(
target_model_path: str,
draft_model_path: str,
prompt: str = "Write a Python function to implement merge sort with detailed comments.",
max_tokens: int = 512,
num_runs: int = 5,
block_size: int = 16,
temperature: float = 0.0,
):
"""Run comprehensive benchmark of DFlash vs baseline on MLX."""
print("=" * 70)
print(" DFlash Speculative Decoding Benchmark")
print("=" * 70)
print(f"Device: {mx.default_device()}")
print(f"Target Model: {target_model_path}")
print(f"Draft Model: {draft_model_path}")
print(f"Block Size: {block_size}")
print(f"Max Tokens: {max_tokens}")
print(f"Temperature: {temperature}")
print(f"Runs: {num_runs}")
print("=" * 70)
# Load models
print("\n[1/4] Loading target model...")
t0 = time.time()
model, tokenizer = load(target_model_path)
print(f" Loaded in {time.time() - t0:.2f}s")
print("\n[2/4] Loading draft model...")
t0 = time.time()
draft_model, draft_config = load_mlx_dflash(draft_model_path)
print(f" Loaded in {time.time() - t0:.2f}s")
print(f" Drafter: {draft_config.get('num_hidden_layers', '?')} layers, "
f"{draft_config.get('hidden_size', '?')} hidden dim")
# Create decoder
print("\n[3/4] Initializing DFlash decoder...")
decoder = DFlashSpeculativeDecoder(
target_model=model,
draft_model=draft_model,
tokenizer=tokenizer,
block_size=block_size,
)
print(" Ready")
# Warmup
print("\n[4/4] Warmup run (compiles Metal kernels)...")
t0 = time.time()
decoder.generate(prompt, max_tokens=50, temperature=temperature)
print(f" Warmup complete in {time.time() - t0:.2f}s")
# Benchmark DFlash
print(f"\n{'='*70}")
print(" Running DFlash Speculative Decoding")
print(f"{'='*70}")
dflash_times = []
dflash_outputs = []
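    # NOTE: tok/s below divides by the full max_tokens budget; if generation
    # stops early at an EOS token, the reported throughput is overstated.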
for i in range(num_runs):
start = time.time()
output = decoder.generate(
prompt=prompt,
max_tokens=max_tokens,
temperature=temperature,
)
elapsed = time.time() - start
dflash_times.append(elapsed)
dflash_outputs.append(output)
print(f" Run {i+1}: {elapsed:.3f}s ({max_tokens/elapsed:.1f} tok/s)")
avg_dflash = sum(dflash_times) / len(dflash_times)
dflash_tok_s = max_tokens / avg_dflash
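    # Run-to-run spread; a large stdev means the averages below are noisy.
    if len(dflash_times) > 1:
        import statistics
        print(f" Stdev: {statistics.stdev(dflash_times):.3f}s")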
    # Baseline benchmark
print(f"\n{'='*70}")
print(" Running Baseline (No Speculative Decoding)")
print(f"{'='*70}")
baseline_times = []
    for i in range(num_runs):
        start = time.time()
        # Plain autoregressive decode via mlx_lm (imported at top). NOTE: the
        # `temp` kwarg matches older mlx_lm releases; newer releases take a
        # sampler instead, e.g. sampler=make_sampler(temp=temperature) from
        # mlx_lm.sample_utils.
        generate(
            model,
            tokenizer,
            prompt=prompt,
            max_tokens=max_tokens,
            temp=temperature,
        )
elapsed = time.time() - start
baseline_times.append(elapsed)
print(f" Run {i+1}: {elapsed:.3f}s ({max_tokens/elapsed:.1f} tok/s)")
avg_baseline = sum(baseline_times) / len(baseline_times)
baseline_tok_s = max_tokens / avg_baseline
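    # Wall-clock speedup: >1.0 means DFlash decoded the same budget faster.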
speedup = avg_baseline / avg_dflash
# Summary
print(f"\n{'='*70}")
print(" RESULTS SUMMARY")
print(f"{'='*70}")
print(f" Model: {target_model_path}")
print(f" Baseline: {avg_baseline:.3f}s avg ({baseline_tok_s:.1f} tok/s)")
print(f" DFlash: {avg_dflash:.3f}s avg ({dflash_tok_s:.1f} tok/s)")
print(f" Speedup: {speedup:.2f}x")
print(f" Tokens saved: {max_tokens * (1 - 1/speedup):.0f} per generation")
print(f" Time saved: {avg_baseline - avg_dflash:.3f}s per generation")
print(f"{'='*70}")
# Memory usage
    try:
        import psutil
        mem = psutil.virtual_memory()
        print("\n Memory:")
        print(f" Total: {mem.total / 1e9:.1f} GB")
        print(f" Used: {mem.used / 1e9:.1f} GB")
        print(f" Available: {mem.available / 1e9:.1f} GB")
        # Newer MLX also exposes this as mx.get_peak_memory().
        print(f" MLX Peak: {mx.metal.get_peak_memory() / 1e9:.2f} GB")
    except ImportError:
        print("\n (install psutil to include host memory usage)")
# Show sample output
print(f"\n{'='*70}")
print(" Sample Output (first 500 chars)")
print(f"{'='*70}")
    sample = dflash_outputs[0] if dflash_outputs else "N/A"
    print(sample[:500] + ("..." if len(sample) > 500 else ""))
print(f"{'='*70}")
return {
"target_model": target_model_path,
"draft_model": draft_model_path,
"speedup": speedup,
"baseline_tok_s": baseline_tok_s,
"dflash_tok_s": dflash_tok_s,
"baseline_time": avg_baseline,
"dflash_time": avg_dflash,
}
def main():
parser = argparse.ArgumentParser(
description="Benchmark DFlash speculative decoding on Apple Silicon",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
# Qwen3-4B (fastest)
python benchmark_m2.py --target Qwen/Qwen3-4B-MLX-4bit --draft ./Qwen3-4B-DFlash-mlx
# Qwen3-8B (best balance)
python benchmark_m2.py --target Qwen/Qwen3-8B-MLX-4bit --draft ./Qwen3-8B-DFlash-mlx
# Custom model with temperature
python benchmark_m2.py --target mlx-community/Llama-3.1-8B-Instruct-4bit \\
--draft ./llama3.1-dflash --temperature 0.7 --tokens 1024
""",
)
parser.add_argument(
"--target",
type=str,
required=True,
help="MLX target model ID or path (e.g., Qwen/Qwen3-8B-MLX-4bit)",
)
parser.add_argument(
"--draft",
type=str,
required=True,
help="Path to converted DFlash drafter",
)
parser.add_argument(
"--tokens",
type=int,
default=512,
help="Number of tokens to generate per run (default: 512)",
)
parser.add_argument(
"--runs",
type=int,
default=5,
help="Number of benchmark runs (default: 5)",
)
parser.add_argument(
"--block-size",
type=int,
default=16,
help="DFlash block size (default: 16)",
)
parser.add_argument(
"--temperature",
type=float,
default=0.0,
help="Sampling temperature (default: 0.0 = greedy)",
)
parser.add_argument(
"--prompt",
type=str,
default="Write a Python function to implement merge sort with detailed comments.",
help="Benchmark prompt",
)
args = parser.parse_args()
results = benchmark(
target_model_path=args.target,
draft_model_path=args.draft,
prompt=args.prompt,
max_tokens=args.tokens,
num_runs=args.runs,
block_size=args.block_size,
temperature=args.temperature,
)
    # Save results to JSON
results["timestamp"] = datetime.now().isoformat()
results["device"] = str(mx.default_device())
output_file = f"benchmark_results_{results['target_model'].replace('/', '_')}.json"
with open(output_file, "w") as f:
json.dump(results, f, indent=2)
print(f"\nResults saved to: {output_file}")
if __name__ == "__main__":
main()