"""
Benchmark harness for throughput (tokens/sec) and peak GPU memory.

run_benchmark() measures tokens/sec and peak memory in MB and returns the
results as a plain dict suitable for saving as JSON; compare_benchmarks()
loads two such JSON files and reports the deltas between them.
Follows the benchmark_phase2.py pattern for CUDA synchronization and memory
tracking.
"""
import sys
import os
import json
import time
import torch
sys.path.insert(0, os.path.dirname(__file__))
from arbitor.main import MORPHTernaryModel, VOCAB, CTX
def run_benchmark(model, train_data, device, n_steps=100, warmup_steps=10,
                  batch_size=64, ctx=CTX, step_fn=None):
    """
    Measure per-step throughput (tokens/sec) and peak GPU memory (MB).

    Resets CUDA peak-memory stats, runs warmup steps (not timed), then the
    timed steps. On CUDA, torch.cuda.synchronize() is called before the first
    and after the last timed step so the wall-clock time covers every queued
    kernel.

    NOTE(review): by default each step is a single forward pass under
    torch.no_grad() with the model in train() mode -- there is no backward
    pass and no optimizer update. Numbers from the default path therefore
    understate the time and memory of a real training step. Pass ``step_fn``
    to benchmark a full training step.

    Args:
        model: MORPHTernaryModel instance, called as model(x, targets=targets).
        train_data: 1D tensor that windows of ``ctx`` elements are sliced
            from (presumably byte tokens -- confirm against the caller).
        device: 'cpu', 'cuda', 'cuda:N', or an equivalent torch.device.
        n_steps: Number of timed steps.
        warmup_steps: Number of untimed steps run before timing begins.
        batch_size: Batch size for each step.
        ctx: Context window length.
        step_fn: Optional callable ``step_fn(model, x, targets)`` run once per
            step instead of the default no-grad forward pass (e.g. forward +
            backward + optimizer.step()). Its return value is ignored.

    Returns:
        dict with keys: tokens_per_sec, peak_memory_mb, n_steps, batch_size,
        ctx, device. peak_memory_mb is 0.0 on non-CUDA devices; device is
        returned as a string so the dict stays JSON-serializable.

    Raises:
        ValueError: if train_data has too few elements to sample a window.
    """
    # torch.randint below needs a positive exclusive upper bound.
    if len(train_data) <= ctx + 1:
        raise ValueError(
            f"train_data has {len(train_data)} elements; need more than "
            f"ctx + 1 = {ctx + 1} to sample a window"
        )
    # Normalize so 'cuda:0' and torch.device('cuda') are also treated as CUDA.
    is_cuda = torch.device(device).type == "cuda"

    def _forward_only(m, x, targets):
        # Default step: forward pass with gradients disabled.
        with torch.no_grad():
            m(x, targets=targets)

    step = step_fn if step_fn is not None else _forward_only

    model.train()
    # Reset memory tracking so peak_memory_mb reflects only this benchmark.
    if is_cuda:
        torch.cuda.reset_peak_memory_stats(device)
        torch.cuda.empty_cache()
        torch.cuda.synchronize(device)

    # Pre-sample every batch up front so the timed loop does no data prep.
    # starts has shape (total_steps, batch_size); adding arange(ctx) turns it
    # into window indices, so one indexing op yields all_x with shape
    # (total_steps, batch_size, ctx).
    total_steps = warmup_steps + n_steps
    starts = torch.randint(0, len(train_data) - ctx - 1, (total_steps, batch_size))
    all_x = train_data[starts.unsqueeze(-1) + torch.arange(ctx)]
    # NOTE(review): targets drop the first 3 positions of each input window
    # instead of shifting by one; presumably this matches the model's loss
    # alignment -- confirm against MORPHTernaryModel's forward.
    all_targets = all_x[:, :, 3:]

    def _run_step(idx):
        # Move one pre-sampled batch to the device and execute a step.
        x = all_x[idx].to(device, non_blocking=True)
        targets = all_targets[idx].to(device, non_blocking=True)
        step(model, x, targets)

    # Warmup steps (not timed)
    for idx in range(warmup_steps):
        _run_step(idx)
    if is_cuda:
        torch.cuda.synchronize(device)

    # Timed steps
    t_start = time.perf_counter()
    for idx in range(warmup_steps, total_steps):
        _run_step(idx)
    if is_cuda:
        # Wait for all queued kernels so elapsed covers the real work.
        torch.cuda.synchronize(device)
    elapsed = time.perf_counter() - t_start

    tokens_total = n_steps * batch_size * ctx
    tokens_per_sec = tokens_total / elapsed if elapsed > 0 else 0.0

    peak_memory_mb = 0.0
    if is_cuda:
        peak_memory_mb = torch.cuda.max_memory_allocated(device) / (1024 * 1024)

    return {
        "tokens_per_sec": round(tokens_per_sec, 2),
        "peak_memory_mb": round(peak_memory_mb, 2),
        "n_steps": n_steps,
        "batch_size": batch_size,
        "ctx": ctx,
        "device": str(device),
    }
def compare_benchmarks(before_path, after_path):
    """
    Compare two benchmark result JSON files, print a table, and return deltas.

    Compared metrics are tokens_per_sec and peak_memory_mb; a metric missing
    from a file is treated as 0.0. The comparison table is printed to stdout.

    Args:
        before_path: Path to the baseline benchmark JSON (UTF-8).
        after_path: Path to the optimized benchmark JSON (UTF-8).

    Returns:
        dict with keys: before, after, delta, pct_change
            before / after: the loaded JSON objects, unchanged.
            delta[metric] = round(after - before, 2)
            pct_change[metric] = round((after - before) / abs(before) * 100, 2),
                or 0.0 when the baseline value is 0 (percent change undefined).
    """
    # JSON is specified as UTF-8; don't depend on the platform locale.
    with open(before_path, "r", encoding="utf-8") as f:
        before = json.load(f)
    with open(after_path, "r", encoding="utf-8") as f:
        after = json.load(f)

    metrics = ["tokens_per_sec", "peak_memory_mb"]
    delta = {}
    pct_change = {}
    rows = []
    for key in metrics:
        b = before.get(key, 0.0)
        a = after.get(key, 0.0)
        delta[key] = round(a - b, 2)
        # abs() keeps the sign of pct_change equal to the sign of delta; a zero
        # baseline has no defined percent change, so report 0.0.
        pct_change[key] = round(((a - b) / abs(b)) * 100.0, 2) if b != 0 else 0.0
        rows.append((key, b, a))

    header = f"{'Metric':<20} {'Before':>12} {'After':>12} {'Delta':>12} {'Change':>10}"
    print("\n=== Benchmark Comparison ===")
    print(header)
    # Size the rule lines from the header so they span the full table width.
    print("-" * len(header))
    for key, b, a in rows:
        print(f"{key:<20} {b:>12.2f} {a:>12.2f} {delta[key]:>+12.2f} {pct_change[key]:>+9.2f}%")
    print("=" * len(header))

    return {
        "before": before,
        "after": after,
        "delta": delta,
        "pct_change": pct_change,
    }
|