""" Tinman-SmolOmni-MLA Benchmark Suite Measures: VRAM usage, KV cache size, throughput, generation quality Compares against SmolVLM baseline. """ import time import json import argparse import os import torch import torch.nn.functional as F from transformers import AutoModelForImageTextToText, AutoTokenizer from smolomni.config import SmolOmniConfig from smolomni.model import SmolOmniModel from smolomni.svd_init import initialize_mla_from_pretrained def benchmark_kv_cache(config: SmolOmniConfig): """Report KV cache sizes for different configurations.""" info = config.kv_cache_size_per_token() return info def benchmark_model_loading(model_variant: str, device: str = "cuda"): """Time model initialization and SVD init.""" config = SmolOmniConfig.from_pretrained(f"mla-hybrid-ar-flow-{model_variant}") # Baseline: SmolVLM start = time.time() baseline = AutoModelForImageTextToText.from_pretrained( config.base_model, torch_dtype=torch.bfloat16 ).to(device) baseline_time = time.time() - start baseline_params = sum(p.numel() for p in baseline.parameters()) # SmolOmni start = time.time() model = SmolOmniModel(config) model = initialize_mla_from_pretrained(model, config.base_model, config) model = model.to(device, dtype=torch.bfloat16) smol_time = time.time() - start smol_params = model.num_parameters() del baseline torch.cuda.empty_cache() return { "baseline": { "load_time_s": baseline_time, "params_M": baseline_params / 1e6, }, "smolomni": { "load_time_s": smol_time, "params_M": smol_params / 1e6, }, } def benchmark_throughput(model, tokenizer, config, batch_size: int = 1, seq_len: int = 512): """Measure tokens/second for understanding and generation.""" device = next(model.parameters()).device # Understanding (AR) input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len), device=device) # Warm-up for _ in range(3): with torch.no_grad(): _ = model.forward_understanding(input_ids) torch.cuda.synchronize() start = time.time() n_iters = 10 for _ in range(n_iters): with torch.no_grad(): _ = model.forward_understanding(input_ids) torch.cuda.synchronize() ar_time = (time.time() - start) / n_iters ar_tps = (batch_size * seq_len) / ar_time # Generation (flow-matching, 50 steps) latent_shape = (batch_size, 4, 32, 32) torch.cuda.synchronize() start = time.time() with torch.no_grad(): _ = model.generate_image(input_ids[:batch_size], num_steps=50, latent_shape=latent_shape) torch.cuda.synchronize() gen_time = time.time() - start return { "ar_time_ms": ar_time * 1000, "ar_tokens_per_sec": ar_tps, "gen_time_s": gen_time, "gen_steps_per_sec": 50 / gen_time, } def benchmark_vram(model, batch_size: int = 1, seq_len: int = 512): """Measure peak VRAM usage.""" torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() device = next(model.parameters()).device input_ids = torch.randint(0, model.config.vocab_size, (batch_size, seq_len), device=device) with torch.no_grad(): _ = model.forward_understanding(input_ids) peak_mb = torch.cuda.max_memory_allocated() / 1024 / 1024 return {"peak_vram_mb": peak_mb} def run_all_benchmarks(model_variant: str = "256M"): """Run full benchmark suite.""" print(f"="*70) print(f"Tinman-SmolOmni-MLA Benchmark: {model_variant}") print(f"="*70) device = "cuda" if torch.cuda.is_available() else "cpu" # KV Cache comparison print("\n--- KV Cache Analysis ---") config = SmolOmniConfig.from_pretrained(f"mla-hybrid-ar-flow-{model_variant}") cache_info = benchmark_kv_cache(config) for k, v in cache_info.items(): print(f" {k}: {v}") # Model loading print("\n--- Model Loading ---") load_info = benchmark_model_loading(model_variant, device) print(f" Baseline (SmolVLM): {load_info['baseline']['load_time_s']:.1f}s, {load_info['baseline']['params_M']:.1f}M params") print(f" SmolOmni-MLA: {load_info['smolomni']['load_time_s']:.1f}s, {load_info['smolomni']['params_M']:.1f}M params") # Load model for throughput tests model = SmolOmniModel(config) model = initialize_mla_from_pretrained(model, config.base_model, config) model = model.to(device, dtype=torch.bfloat16) model.eval() # VRAM print("\n--- VRAM Usage ---") vram = benchmark_vram(model) print(f" Peak VRAM: {vram['peak_vram_mb']:.0f} MB") # Throughput print("\n--- Throughput ---") throughput = benchmark_throughput(model, None, config) print(f" AR forward: {throughput['ar_time_ms']:.1f}ms ({throughput['ar_tokens_per_sec']:.0f} tok/s)") print(f" Image gen (50 steps): {throughput['gen_time_s']:.1f}s ({throughput['gen_steps_per_sec']:.1f} step/s)") results = { "model_variant": model_variant, "kv_cache": cache_info, "loading": load_info, "vram": vram, "throughput": throughput, } # Save results out_path = f"/app/benchmark_{model_variant}.json" with open(out_path, 'w') as f: json.dump(results, f, indent=2) print(f"\nResults saved to {out_path}") return results if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--model_variant", default="256M", choices=["256M", "500M"]) args = parser.parse_args() run_all_benchmarks(args.model_variant)