| """ |
| Tinman-SmolOmni-MLA Benchmark Suite |
Measures: VRAM usage, KV cache size, throughput.
Compares against the SmolVLM baseline.
| """ |
| import time |
| import json |
| import argparse |
|
|
| import torch |
from transformers import AutoModelForImageTextToText
|
|
| from smolomni.config import SmolOmniConfig |
| from smolomni.model import SmolOmniModel |
| from smolomni.svd_init import initialize_mla_from_pretrained |
|
|
|
|
| def benchmark_kv_cache(config: SmolOmniConfig): |
| """Report KV cache sizes for different configurations.""" |
    return config.kv_cache_size_per_token()
|
|
|
|
| def benchmark_model_loading(model_variant: str, device: str = "cuda"): |
| """Time model initialization and SVD init.""" |
| config = SmolOmniConfig.from_pretrained(f"mla-hybrid-ar-flow-{model_variant}") |
| |
| |
    # Baseline: pretrained SmolVLM loaded straight from the checkpoint.
    start = time.time()
    baseline = AutoModelForImageTextToText.from_pretrained(
        config.base_model, torch_dtype=torch.bfloat16
    ).to(device)
| baseline_time = time.time() - start |
| baseline_params = sum(p.numel() for p in baseline.parameters()) |
| |
| |
    # SmolOmni: fresh MLA model, initialized from the baseline weights via SVD.
    start = time.time()
    model = SmolOmniModel(config)
    model = initialize_mla_from_pretrained(model, config.base_model, config)
    model = model.to(device, dtype=torch.bfloat16)
| smol_time = time.time() - start |
| smol_params = model.num_parameters() |
| |
    # Free both models so subsequent benchmarks start from a clean allocator.
    del baseline, model
    torch.cuda.empty_cache()
| |
| return { |
| "baseline": { |
| "load_time_s": baseline_time, |
| "params_M": baseline_params / 1e6, |
| }, |
| "smolomni": { |
| "load_time_s": smol_time, |
| "params_M": smol_params / 1e6, |
| }, |
| } |
|
|
|
|
def benchmark_throughput(model, config, batch_size: int = 1, seq_len: int = 512):
| """Measure tokens/second for understanding and generation.""" |
    device = next(model.parameters()).device
    sync = torch.cuda.synchronize if torch.cuda.is_available() else (lambda: None)
| |
| |
    # Synthetic token IDs suffice; throughput does not depend on input content.
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len), device=device)
| |
| |
    # Warm-up passes so one-time CUDA initialization does not skew the timing.
    for _ in range(3):
        with torch.no_grad():
            _ = model.forward_understanding(input_ids)

    # Timed autoregressive (understanding) forward passes.
    sync()
    start = time.time()
    n_iters = 10
    for _ in range(n_iters):
        with torch.no_grad():
            _ = model.forward_understanding(input_ids)
    sync()
    ar_time = (time.time() - start) / n_iters
    ar_tps = (batch_size * seq_len) / ar_time
| |
| |
    # Timed flow-based image generation: one full sample of `num_steps` steps.
    num_steps = 50
    latent_shape = (batch_size, 4, 32, 32)
    sync()
    start = time.time()
    with torch.no_grad():
        _ = model.generate_image(input_ids, num_steps=num_steps, latent_shape=latent_shape)
    sync()
    gen_time = time.time() - start
| |
| return { |
| "ar_time_ms": ar_time * 1000, |
| "ar_tokens_per_sec": ar_tps, |
| "gen_time_s": gen_time, |
| "gen_steps_per_sec": 50 / gen_time, |
| } |
|
|
|
|
| def benchmark_vram(model, batch_size: int = 1, seq_len: int = 512): |
| """Measure peak VRAM usage.""" |
    if not torch.cuda.is_available():
        return {"peak_vram_mb": None}  # nothing to measure without a CUDA device
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
| |
| device = next(model.parameters()).device |
| input_ids = torch.randint(0, model.config.vocab_size, (batch_size, seq_len), device=device) |
| |
| with torch.no_grad(): |
| _ = model.forward_understanding(input_ids) |
| |
| peak_mb = torch.cuda.max_memory_allocated() / 1024 / 1024 |
| return {"peak_vram_mb": peak_mb} |
|
|
|
|
| def run_all_benchmarks(model_variant: str = "256M"): |
| """Run full benchmark suite.""" |
| print(f"="*70) |
| print(f"Tinman-SmolOmni-MLA Benchmark: {model_variant}") |
| print(f"="*70) |
| |
| device = "cuda" if torch.cuda.is_available() else "cpu" |
| |
| |
| print("\n--- KV Cache Analysis ---") |
| config = SmolOmniConfig.from_pretrained(f"mla-hybrid-ar-flow-{model_variant}") |
| cache_info = benchmark_kv_cache(config) |
| for k, v in cache_info.items(): |
| print(f" {k}: {v}") |
| |
| |
| print("\n--- Model Loading ---") |
| load_info = benchmark_model_loading(model_variant, device) |
| print(f" Baseline (SmolVLM): {load_info['baseline']['load_time_s']:.1f}s, {load_info['baseline']['params_M']:.1f}M params") |
| print(f" SmolOmni-MLA: {load_info['smolomni']['load_time_s']:.1f}s, {load_info['smolomni']['params_M']:.1f}M params") |
| |
| |
    # Rebuild the model for the remaining benchmarks (benchmark_model_loading
    # freed its copy after timing).
    model = SmolOmniModel(config)
    model = initialize_mla_from_pretrained(model, config.base_model, config)
    model = model.to(device, dtype=torch.bfloat16)
    model.eval()
| |
| |
| print("\n--- VRAM Usage ---") |
| vram = benchmark_vram(model) |
| print(f" Peak VRAM: {vram['peak_vram_mb']:.0f} MB") |
| |
| |
| print("\n--- Throughput ---") |
    throughput = benchmark_throughput(model, config)
| print(f" AR forward: {throughput['ar_time_ms']:.1f}ms ({throughput['ar_tokens_per_sec']:.0f} tok/s)") |
| print(f" Image gen (50 steps): {throughput['gen_time_s']:.1f}s ({throughput['gen_steps_per_sec']:.1f} step/s)") |
| |
| results = { |
| "model_variant": model_variant, |
| "kv_cache": cache_info, |
| "loading": load_info, |
| "vram": vram, |
| "throughput": throughput, |
| } |
| |
| |
| out_path = f"/app/benchmark_{model_variant}.json" |
| with open(out_path, 'w') as f: |
| json.dump(results, f, indent=2) |
| print(f"\nResults saved to {out_path}") |
| |
| return results |
|
|
|
|
| if __name__ == "__main__": |
| parser = argparse.ArgumentParser() |
| parser.add_argument("--model_variant", default="256M", choices=["256M", "500M"]) |
| args = parser.parse_args() |
| run_all_benchmarks(args.model_variant) |
|
|