"""
Benchmark harness for training throughput and peak GPU memory.

Measures throughput (tokens/sec) and peak allocated GPU memory (in MB), and
saves the results as JSON. Follows the benchmark_phase2.py pattern for CUDA
synchronization and memory tracking.
"""
|
|
| import sys |
| import os |
| import json |
| import time |
| import torch |
|
|
| sys.path.insert(0, os.path.dirname(__file__)) |
|
|
| from arbitor.main import MORPHTernaryModel, VOCAB, CTX |
|
|
|
|
def run_benchmark(model, train_data, device, n_steps=100, warmup_steps=10,
                  batch_size=64, ctx=CTX):
    """
    Measure step throughput (tokens/sec) and peak GPU memory (MB).

    Puts the model in train mode, resets CUDA peak-memory stats, runs
    ``warmup_steps`` untimed steps, then ``n_steps`` timed steps. Each step
    calls ``model(x, targets=targets)`` under ``torch.no_grad()``. No
    backward pass or optimizer step is run by this function. On CUDA
    devices, ``torch.cuda.synchronize`` is called before the first and after
    the last timed step, so the wall-clock time covers all queued kernels.

    NOTE(review): this was described as "training throughput", but the loop
    runs without autograd. Confirm whether the model updates itself
    internally when given ``targets``. Otherwise this measures forward-only
    throughput.

    Args:
        model: MORPHTernaryModel instance, called as ``model(x, targets=...)``.
        train_data: 1D tensor of training data. Must contain more than
            ``ctx + 1`` elements.
        device: Target device, e.g. 'cuda', 'cuda:0', 'cpu' or a torch.device.
        n_steps: Number of timed steps.
        warmup_steps: Untimed steps run before timing begins.
        batch_size: Batch size for each step.
        ctx: Context window length.

    Returns:
        dict with keys: tokens_per_sec, peak_memory_mb, n_steps, batch_size,
        ctx, device. ``peak_memory_mb`` is 0.0 on non-CUDA devices.

    Raises:
        ValueError: If ``train_data`` is too short to sample a window.
    """
    # Accept 'cuda', 'cuda:N' and torch.device objects, not only the literal
    # string "cuda". Otherwise sync and memory tracking are silently skipped.
    is_cuda = torch.device(device).type == "cuda"
    n_total = warmup_steps + n_steps

    # Exclusive upper bound for window start offsets. It keeps one element of
    # slack past the window end, as the original sampling did.
    max_start = len(train_data) - ctx - 1
    if max_start <= 0:
        raise ValueError(
            f"train_data has {len(train_data)} elements; need more than "
            f"ctx + 1 = {ctx + 1} to sample benchmark windows"
        )

    model.train()

    if is_cuda:
        torch.cuda.reset_peak_memory_stats(device)
        torch.cuda.empty_cache()
        torch.cuda.synchronize(device)

    # Pre-sample every batch up front so data preparation is never timed.
    # A single advanced-indexing gather builds all windows at once:
    # starts (S, B, 1) + offsets (ctx,) -> indices (S, B, ctx).
    all_ix = torch.randint(0, max_start, (n_total, batch_size))
    offsets = torch.arange(ctx)
    all_x = train_data[all_ix.unsqueeze(-1) + offsets]
    # Targets drop the first 3 positions of each input window.
    # NOTE(review): confirm this offset matches the model's target convention.
    all_targets = all_x[:, :, 3:]

    def _run_steps(first, last):
        # Run steps [first, last) with the pre-sampled batches.
        with torch.no_grad():
            for step_idx in range(first, last):
                x = all_x[step_idx].to(device, non_blocking=True)
                targets = all_targets[step_idx].to(device, non_blocking=True)
                model(x, targets=targets)

    # Warmup: untimed, lets allocator/caches/kernels settle.
    _run_steps(0, warmup_steps)
    if is_cuda:
        torch.cuda.synchronize(device)

    # Timed region: synchronize before reading the clock so that queued
    # asynchronous GPU work is included in the measurement.
    t_start = time.perf_counter()
    _run_steps(warmup_steps, n_total)
    if is_cuda:
        torch.cuda.synchronize(device)
    elapsed = time.perf_counter() - t_start

    tokens_total = n_steps * batch_size * ctx
    tokens_per_sec = tokens_total / elapsed if elapsed > 0 else 0.0

    peak_memory_mb = (
        torch.cuda.max_memory_allocated(device) / (1024 * 1024)
        if is_cuda
        else 0.0
    )

    return {
        "tokens_per_sec": round(tokens_per_sec, 2),
        "peak_memory_mb": round(peak_memory_mb, 2),
        "n_steps": n_steps,
        "batch_size": batch_size,
        "ctx": ctx,
        "device": device,
    }
|
|
|
|
def compare_benchmarks(before_path, after_path):
    """
    Compare two benchmark result JSON files, print a table, and return deltas.

    Args:
        before_path: Path to the baseline benchmark JSON.
        after_path: Path to the optimized benchmark JSON.

    Returns:
        dict with keys: before, after, delta, pct_change. For each metric in
        ("tokens_per_sec", "peak_memory_mb"):
            delta[metric]      = round(after - before, 2)
            pct_change[metric] = round((after - before) / abs(before) * 100, 2),
                                 or 0.0 when before == 0.
        A metric missing from either file is treated as 0.0.
    """
    with open(before_path, "r", encoding="utf-8") as f:
        before = json.load(f)
    with open(after_path, "r", encoding="utf-8") as f:
        after = json.load(f)

    metrics = ["tokens_per_sec", "peak_memory_mb"]
    delta = {}
    pct_change = {}
    for key in metrics:
        b = before.get(key, 0.0)
        a = after.get(key, 0.0)
        delta[key] = round(a - b, 2)
        # A zero baseline (e.g. peak_memory_mb from a CPU run) would divide by
        # zero, so report 0% change in that case.
        pct_change[key] = round((a - b) / abs(b) * 100.0, 2) if b != 0 else 0.0

    header = (
        f"{'Metric':<20} {'Before':>12} {'After':>12} "
        f"{'Delta':>12} {'Change':>10}"
    )
    # Size the horizontal rules from the header so they span the full table
    # width. Data rows use the same column widths as the header.
    width = len(header)

    print("\n=== Benchmark Comparison ===")
    print(header)
    print("-" * width)
    for key in metrics:
        b = before.get(key, 0.0)
        a = after.get(key, 0.0)
        print(
            f"{key:<20} {b:>12.2f} {a:>12.2f} "
            f"{delta[key]:>+12.2f} {pct_change[key]:>+9.2f}%"
        )
    print("=" * width)

    return {
        "before": before,
        "after": after,
        "delta": delta,
        "pct_change": pct_change,
    }
|
|