TinmanLabSL's picture
Upload benchmark.py
116d857 verified
"""
Tinman-SmolOmni-MLA Benchmark Suite
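Measures: KV cache size, model load time and parameter count, peak VRAM usage,
and throughput (AR understanding forward pass and flow-matching image generation).
Load time and parameter count are compared against the SmolVLM baseline.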
"""
import time
import json
import argparse
import os
import torch
import torch.nn.functional as F
from transformers import AutoModelForImageTextToText, AutoTokenizer
from smolomni.config import SmolOmniConfig
from smolomni.model import SmolOmniModel
from smolomni.svd_init import initialize_mla_from_pretrained
def benchmark_kv_cache(config: SmolOmniConfig):
    """Return the per-token KV-cache size report for ``config``.

    This is a thin wrapper: it forwards to
    ``config.kv_cache_size_per_token()`` and hands back that result
    unchanged, so it can be listed alongside the other benchmarks.
    """
    return config.kv_cache_size_per_token()
def benchmark_model_loading(model_variant: str, device: str = "cuda"):
    """Time loading of the SmolVLM baseline and of SmolOmni-MLA (incl. SVD init).

    Args:
        model_variant: Variant suffix; the config is loaded from
            ``"mla-hybrid-ar-flow-<model_variant>"``.
        device: Device both models are moved to (e.g. ``"cuda"`` or ``"cpu"``).

    Returns:
        ``{"baseline": {...}, "smolomni": {...}}`` where each entry holds
        ``load_time_s`` (wall-clock seconds) and ``params_M`` (parameter
        count in millions).
    """
    config = SmolOmniConfig.from_pretrained(f"mla-hybrid-ar-flow-{model_variant}")
    on_cuda = torch.device(device).type == "cuda"

    def _elapsed(start: float) -> float:
        # Wait for any queued CUDA work (e.g. casts issued during `.to`) so it
        # is attributed to the load being timed rather than to later code.
        if on_cuda:
            torch.cuda.synchronize(device)
        return time.perf_counter() - start

    # Baseline: SmolVLM
    start = time.perf_counter()
    baseline = AutoModelForImageTextToText.from_pretrained(
        config.base_model, torch_dtype=torch.bfloat16
    ).to(device)
    baseline_time = _elapsed(start)
    baseline_params = sum(p.numel() for p in baseline.parameters())
    # Free the baseline before building SmolOmni so both models are never
    # resident on the device at the same time.
    del baseline
    if on_cuda:
        torch.cuda.empty_cache()

    # SmolOmni (construction + SVD-based MLA initialization + device move)
    start = time.perf_counter()
    model = SmolOmniModel(config)
    model = initialize_mla_from_pretrained(model, config.base_model, config)
    model = model.to(device, dtype=torch.bfloat16)
    smol_time = _elapsed(start)
    smol_params = model.num_parameters()
    # Only the statistics are returned; release the weights now.
    del model
    if on_cuda:
        torch.cuda.empty_cache()

    return {
        "baseline": {
            "load_time_s": baseline_time,
            "params_M": baseline_params / 1e6,
        },
        "smolomni": {
            "load_time_s": smol_time,
            "params_M": smol_params / 1e6,
        },
    }
def benchmark_throughput(
    model,
    tokenizer,
    config,
    batch_size: int = 1,
    seq_len: int = 512,
    *,
    n_warmup: int = 3,
    n_iters: int = 10,
    num_gen_steps: int = 50,
    latent_shape=None,
):
    """Measure tokens/second for understanding and time for image generation.

    Random token ids of shape ``(batch_size, seq_len)`` drawn from
    ``[0, config.vocab_size)`` are fed to ``model.forward_understanding``
    (timed over ``n_iters`` runs after ``n_warmup`` untimed runs). The same ids
    are then passed once to ``model.generate_image``.

    Args:
        model: Model exposing ``parameters()``, ``forward_understanding`` and
            ``generate_image``; the device is taken from its first parameter.
        tokenizer: Unused; kept for interface compatibility.
        config: Object providing ``vocab_size``.
        batch_size: Number of sequences per forward pass.
        seq_len: Tokens per sequence.
        n_warmup: Untimed understanding passes before timing.
        n_iters: Timed understanding passes (must be >= 1).
        num_gen_steps: ``num_steps`` passed to ``generate_image``.
        latent_shape: Latent shape for generation; defaults to
            ``(batch_size, 4, 32, 32)``.

    Returns:
        Dict with ``ar_time_ms`` (mean per pass), ``ar_tokens_per_sec``,
        ``gen_time_s`` and ``gen_steps_per_sec``.

    Raises:
        ValueError: If ``n_iters`` < 1.
    """
    if n_iters < 1:
        raise ValueError(f"n_iters must be >= 1, got {n_iters}")

    device = next(model.parameters()).device
    if latent_shape is None:
        latent_shape = (batch_size, 4, 32, 32)

    def _sync() -> None:
        # CUDA kernels run asynchronously; block until they finish so the
        # wall-clock timings cover the real work. Nothing to wait for on CPU.
        if device.type == "cuda":
            torch.cuda.synchronize(device)

    input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len), device=device)

    with torch.no_grad():
        # Understanding (AR)
        for _ in range(n_warmup):
            model.forward_understanding(input_ids)
        _sync()
        start = time.perf_counter()
        for _ in range(n_iters):
            model.forward_understanding(input_ids)
        _sync()
        ar_time = (time.perf_counter() - start) / n_iters

        # Generation (flow-matching); single run, no separate warm-up.
        _sync()
        start = time.perf_counter()
        model.generate_image(input_ids, num_steps=num_gen_steps, latent_shape=latent_shape)
        _sync()
        gen_time = time.perf_counter() - start

    return {
        "ar_time_ms": ar_time * 1000,
        "ar_tokens_per_sec": (batch_size * seq_len) / ar_time if ar_time > 0 else float("inf"),
        "gen_time_s": gen_time,
        "gen_steps_per_sec": num_gen_steps / gen_time if gen_time > 0 else float("inf"),
    }
def benchmark_vram(model, batch_size: int = 1, seq_len: int = 512):
    """Measure peak CUDA memory during one understanding forward pass.

    The peak counter is reset immediately before the pass, so the reported
    value includes memory already allocated on the device at that point
    (e.g. the model weights) plus activations of the pass.

    Args:
        model: Model exposing ``parameters()``, ``config.vocab_size`` and
            ``forward_understanding``; the device is taken from its first
            parameter.
        batch_size: Number of sequences in the random input.
        seq_len: Tokens per sequence.

    Returns:
        ``{"peak_vram_mb": float}`` in MiB. For a model that is not on a CUDA
        device there is no VRAM to measure and the value is NaN.
    """
    device = next(model.parameters()).device
    if device.type != "cuda":
        # torch.cuda memory-stat calls would raise or report an unrelated GPU.
        return {"peak_vram_mb": float("nan")}

    torch.cuda.empty_cache()
    # Query the model's own device rather than whichever one is "current".
    torch.cuda.reset_peak_memory_stats(device)
    input_ids = torch.randint(0, model.config.vocab_size, (batch_size, seq_len), device=device)
    with torch.no_grad():
        model.forward_understanding(input_ids)
    peak_mb = torch.cuda.max_memory_allocated(device) / (1024 * 1024)
    return {"peak_vram_mb": peak_mb}
def run_all_benchmarks(model_variant: str = "256M", out_dir: str = "/app"):
    """Run the full benchmark suite for one variant, print a report, save JSON.

    Runs, in order: KV-cache size analysis, model-loading timing (vs. the
    SmolVLM baseline), peak VRAM, and throughput. Results are written to
    ``<out_dir>/benchmark_<model_variant>.json``.

    Args:
        model_variant: Variant suffix (e.g. ``"256M"`` or ``"500M"``).
        out_dir: Directory for the JSON results; created if missing.
            Defaults to ``"/app"`` (the original hard-coded location).

    Returns:
        The results dict that was written to disk.
    """
    print("=" * 70)
    print(f"Tinman-SmolOmni-MLA Benchmark: {model_variant}")
    print("=" * 70)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # KV Cache comparison
    print("\n--- KV Cache Analysis ---")
    config = SmolOmniConfig.from_pretrained(f"mla-hybrid-ar-flow-{model_variant}")
    cache_info = benchmark_kv_cache(config)
    for k, v in cache_info.items():
        print(f" {k}: {v}")

    # Model loading
    print("\n--- Model Loading ---")
    load_info = benchmark_model_loading(model_variant, device)
    print(f" Baseline (SmolVLM): {load_info['baseline']['load_time_s']:.1f}s, {load_info['baseline']['params_M']:.1f}M params")
    print(f" SmolOmni-MLA: {load_info['smolomni']['load_time_s']:.1f}s, {load_info['smolomni']['params_M']:.1f}M params")

    # Load model for throughput tests
    model = SmolOmniModel(config)
    model = initialize_mla_from_pretrained(model, config.base_model, config)
    model = model.to(device, dtype=torch.bfloat16)
    model.eval()

    # VRAM
    print("\n--- VRAM Usage ---")
    vram = benchmark_vram(model)
    print(f" Peak VRAM: {vram['peak_vram_mb']:.0f} MB")

    # Throughput (no tokenizer needed: inputs are random token ids)
    print("\n--- Throughput ---")
    throughput = benchmark_throughput(model, None, config)
    print(f" AR forward: {throughput['ar_time_ms']:.1f}ms ({throughput['ar_tokens_per_sec']:.0f} tok/s)")
    print(f" Image gen (50 steps): {throughput['gen_time_s']:.1f}s ({throughput['gen_steps_per_sec']:.1f} step/s)")

    results = {
        "model_variant": model_variant,
        "kv_cache": cache_info,
        "loading": load_info,
        "vram": vram,
        "throughput": throughput,
    }

    # Save results; create the directory so a missing /app does not discard
    # a completed (and expensive) benchmark run.
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, f"benchmark_{model_variant}.json")
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)
    print(f"\nResults saved to {out_path}")
    return results
if __name__ == "__main__":
    # CLI entry point: pick the model size and run every benchmark for it.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--model_variant",
        default="256M",
        choices=["256M", "500M"],
    )
    cli_args = cli.parse_args()
    run_all_benchmarks(cli_args.model_variant)