TinmanLabSL commited on
Commit
f77e6fe
·
verified ·
1 Parent(s): 2cc0940

Add benchmark suite

Browse files
Files changed (1) hide show
  1. benchmark.py +174 -0
benchmark.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SmolOmni-MLA Benchmark Suite
3
+ Measures: VRAM usage, KV cache size, throughput, generation quality
4
+ Compares against SmolVLM baseline.
5
+ """
6
+ import time
7
+ import json
8
+ import argparse
9
+ import os
10
+
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from transformers import AutoModelForImageTextToText, AutoTokenizer
14
+
15
+ from smolomni.config import SmolOmniConfig
16
+ from smolomni.model import SmolOmniModel
17
+ from smolomni.svd_init import initialize_mla_from_pretrained
18
+
19
+
20
+ def benchmark_kv_cache(config: SmolOmniConfig):
21
+ """Report KV cache sizes for different configurations."""
22
+ info = config.kv_cache_size_per_token()
23
+ return info
24
+
25
+
26
+ def benchmark_model_loading(model_variant: str, device: str = "cuda"):
27
+ """Time model initialization and SVD init."""
28
+ config = SmolOmniConfig.from_pretrained(f"mla-hybrid-ar-flow-{model_variant}")
29
+
30
+ # Baseline: SmolVLM
31
+ start = time.time()
32
+ baseline = AutoModelForImageTextToText.from_pretrained(
33
+ config.base_model, torch_dtype=torch.bfloat16
34
+ ).to(device)
35
+ baseline_time = time.time() - start
36
+ baseline_params = sum(p.numel() for p in baseline.parameters())
37
+
38
+ # SmolOmni
39
+ start = time.time()
40
+ model = SmolOmniModel(config)
41
+ model = initialize_mla_from_pretrained(model, config.base_model, config)
42
+ model = model.to(device, dtype=torch.bfloat16)
43
+ smol_time = time.time() - start
44
+ smol_params = model.num_parameters()
45
+
46
+ del baseline
47
+ torch.cuda.empty_cache()
48
+
49
+ return {
50
+ "baseline": {
51
+ "load_time_s": baseline_time,
52
+ "params_M": baseline_params / 1e6,
53
+ },
54
+ "smolomni": {
55
+ "load_time_s": smol_time,
56
+ "params_M": smol_params / 1e6,
57
+ },
58
+ }
59
+
60
+
61
+ def benchmark_throughput(model, tokenizer, config, batch_size: int = 1, seq_len: int = 512):
62
+ """Measure tokens/second for understanding and generation."""
63
+ device = next(model.parameters()).device
64
+
65
+ # Understanding (AR)
66
+ input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len), device=device)
67
+
68
+ # Warm-up
69
+ for _ in range(3):
70
+ with torch.no_grad():
71
+ _ = model.forward_understanding(input_ids)
72
+
73
+ torch.cuda.synchronize()
74
+ start = time.time()
75
+ n_iters = 10
76
+ for _ in range(n_iters):
77
+ with torch.no_grad():
78
+ _ = model.forward_understanding(input_ids)
79
+ torch.cuda.synchronize()
80
+ ar_time = (time.time() - start) / n_iters
81
+ ar_tps = (batch_size * seq_len) / ar_time
82
+
83
+ # Generation (flow-matching, 50 steps)
84
+ latent_shape = (batch_size, 4, 32, 32)
85
+ torch.cuda.synchronize()
86
+ start = time.time()
87
+ with torch.no_grad():
88
+ _ = model.generate_image(input_ids[:batch_size], num_steps=50, latent_shape=latent_shape)
89
+ torch.cuda.synchronize()
90
+ gen_time = time.time() - start
91
+
92
+ return {
93
+ "ar_time_ms": ar_time * 1000,
94
+ "ar_tokens_per_sec": ar_tps,
95
+ "gen_time_s": gen_time,
96
+ "gen_steps_per_sec": 50 / gen_time,
97
+ }
98
+
99
+
100
+ def benchmark_vram(model, batch_size: int = 1, seq_len: int = 512):
101
+ """Measure peak VRAM usage."""
102
+ torch.cuda.empty_cache()
103
+ torch.cuda.reset_peak_memory_stats()
104
+
105
+ device = next(model.parameters()).device
106
+ input_ids = torch.randint(0, model.config.vocab_size, (batch_size, seq_len), device=device)
107
+
108
+ with torch.no_grad():
109
+ _ = model.forward_understanding(input_ids)
110
+
111
+ peak_mb = torch.cuda.max_memory_allocated() / 1024 / 1024
112
+ return {"peak_vram_mb": peak_mb}
113
+
114
+
115
+ def run_all_benchmarks(model_variant: str = "256M"):
116
+ """Run full benchmark suite."""
117
+ print(f"="*70)
118
+ print(f"SmolOmni-MLA Benchmark: {model_variant}")
119
+ print(f"="*70)
120
+
121
+ device = "cuda" if torch.cuda.is_available() else "cpu"
122
+
123
+ # KV Cache comparison
124
+ print("\n--- KV Cache Analysis ---")
125
+ config = SmolOmniConfig.from_pretrained(f"mla-hybrid-ar-flow-{model_variant}")
126
+ cache_info = benchmark_kv_cache(config)
127
+ for k, v in cache_info.items():
128
+ print(f" {k}: {v}")
129
+
130
+ # Model loading
131
+ print("\n--- Model Loading ---")
132
+ load_info = benchmark_model_loading(model_variant, device)
133
+ print(f" Baseline (SmolVLM): {load_info['baseline']['load_time_s']:.1f}s, {load_info['baseline']['params_M']:.1f}M params")
134
+ print(f" SmolOmni-MLA: {load_info['smolomni']['load_time_s']:.1f}s, {load_info['smolomni']['params_M']:.1f}M params")
135
+
136
+ # Load model for throughput tests
137
+ model = SmolOmniModel(config)
138
+ model = initialize_mla_from_pretrained(model, config.base_model, config)
139
+ model = model.to(device, dtype=torch.bfloat16)
140
+ model.eval()
141
+
142
+ # VRAM
143
+ print("\n--- VRAM Usage ---")
144
+ vram = benchmark_vram(model)
145
+ print(f" Peak VRAM: {vram['peak_vram_mb']:.0f} MB")
146
+
147
+ # Throughput
148
+ print("\n--- Throughput ---")
149
+ throughput = benchmark_throughput(model, None, config)
150
+ print(f" AR forward: {throughput['ar_time_ms']:.1f}ms ({throughput['ar_tokens_per_sec']:.0f} tok/s)")
151
+ print(f" Image gen (50 steps): {throughput['gen_time_s']:.1f}s ({throughput['gen_steps_per_sec']:.1f} step/s)")
152
+
153
+ results = {
154
+ "model_variant": model_variant,
155
+ "kv_cache": cache_info,
156
+ "loading": load_info,
157
+ "vram": vram,
158
+ "throughput": throughput,
159
+ }
160
+
161
+ # Save results
162
+ out_path = f"/app/benchmark_{model_variant}.json"
163
+ with open(out_path, 'w') as f:
164
+ json.dump(results, f, indent=2)
165
+ print(f"\nResults saved to {out_path}")
166
+
167
+ return results
168
+
169
+
170
+ if __name__ == "__main__":
171
+ parser = argparse.ArgumentParser()
172
+ parser.add_argument("--model_variant", default="256M", choices=["256M", "500M"])
173
+ args = parser.parse_args()
174
+ run_all_benchmarks(args.model_variant)