| """ |
| Profiling utilities: torch.profiler wrapper and analysis tools. |
| |
| Following D-103: profile first, optimize only hot paths. |
| Uses torch.profiler to identify training loop bottlenecks. |
| """ |
|
|
| import sys |
| import os |
| import json |
| import math |
| import torch |
|
|
| sys.path.insert(0, os.path.dirname(__file__)) |
|
|
| from .main import ARBModel |
| from .config import VOCAB, CTX |
|
|
|
|
def _profile_forward_step(model, train_data, device, batch_size, ctx):
    """Sample one random batch of ``ctx``-length windows and run a no-grad forward pass."""
    # high = len - ctx - 1 keeps every window train_data[j:j+ctx] in bounds.
    ix = torch.randint(0, len(train_data) - ctx - 1, (batch_size,))
    x = torch.stack([train_data[j: j + ctx] for j in ix.tolist()])
    # NOTE(review): targets are the inputs shifted by 3 positions (length ctx - 3);
    # confirm this matches what ARBModel's loss expects.
    targets = x[:, 3:]
    x = x.to(device)
    targets = targets.to(device)
    with torch.no_grad():
        model(x, targets=targets)


def _profile_event_total(evt, use_device):
    """Return an averaged profiler event's *total* time (us) on the device or CPU side.

    Newer torch exposes ``device_time_total``; older releases only have
    ``cuda_time_total``, so fall back to it when the former is missing.
    """
    if not use_device:
        return getattr(evt, "cpu_time_total", 0) or 0
    for attr in ("device_time_total", "cuda_time_total"):
        value = getattr(evt, attr, None)
        if value is not None:
            return value
    return 0


def _profile_event_entry(evt):
    """Convert one averaged profiler event into the public result dict."""
    cuda_t = getattr(evt, "device_time", None)
    if cuda_t is None:
        cuda_t = getattr(evt, "cuda_time", 0)
    return {
        "op_name": evt.key if hasattr(evt, "key") else str(evt),
        "cuda_time_us": cuda_t,
        "cpu_time_us": getattr(evt, "cpu_time", 0),
        "calls": getattr(evt, "count", 1),
    }


def profile_training(model, train_data, device, n_steps=20, warmup_steps=5,
                     top_k=10, batch_size=64, ctx=CTX,
                     trace_path="/tmp/profiler_trace.json"):
    """
    Profile forward passes of ``model`` using torch.profiler.

    Runs ``warmup_steps`` untraced forward passes, then traces ``n_steps``
    forward passes. CPU activity is always traced; CUDA activity (plus stack
    and FLOP collection) is added when ``device`` is a CUDA device. Prints the
    top-``top_k`` table, exports a Chrome trace, and returns the top ops.

    NOTE: despite the name, each step is a forward pass under
    ``torch.no_grad()`` with the model in train mode; no backward pass or
    optimizer step is profiled.

    Args:
        model: ARBModel instance
        train_data: 1D byte tensor of training data
        device: 'cuda', 'cuda:N' or 'cpu' (string or torch.device)
        n_steps: Number of profiled forward steps
        warmup_steps: Steps before profiling begins (no tracing)
        top_k: Number of top operations to print and return
        batch_size: Batch size for each step
        ctx: Context window length
        trace_path: Where to write the Chrome trace JSON; ``None`` skips it.

    Returns:
        List of at most ``top_k`` dicts with keys op_name, cuda_time_us,
        cpu_time_us, calls. They are ordered by total CUDA time (CUDA runs)
        or total CPU time (CPU runs), the same ordering as the printed table.
        The time values are torch's per-call averaged event times.

    Raises:
        ValueError: if ``train_data`` is too short to sample a ``ctx`` window.
    """
    if len(train_data) < ctx + 2:
        raise ValueError(
            f"train_data has {len(train_data)} elements; need at least "
            f"ctx + 2 = {ctx + 2} to sample a context window"
        )

    # Accept 'cuda:N' and torch.device objects, not just the literal 'cuda'.
    is_cuda = torch.device(device).type == "cuda"
    model.train()

    activities = [torch.profiler.ProfilerActivity.CPU]
    if is_cuda:
        activities.append(torch.profiler.ProfilerActivity.CUDA)
    # Stack traces and FLOP estimates are only requested for CUDA runs.
    prof = torch.profiler.profile(
        activities=activities,
        record_shapes=True,
        with_stack=is_cuda,
        with_flops=is_cuda,
    )

    for _ in range(warmup_steps):
        _profile_forward_step(model, train_data, device, batch_size, ctx)
    if is_cuda:
        # Drain queued warmup kernels so they are not attributed to the traced window.
        torch.cuda.synchronize(device)

    # Context manager guarantees the profiler is stopped even if a step raises.
    with prof:
        for _ in range(n_steps):
            _profile_forward_step(model, train_data, device, batch_size, ctx)
            if is_cuda:
                torch.cuda.synchronize(device)

    sort_by = "cuda_time_total" if is_cuda else "cpu_time_total"
    key_avg = prof.key_averages()
    table = key_avg.table(sort_by=sort_by, row_limit=top_k)

    # key_averages() is not sorted; rank by the same total metric as the table
    # so the returned list really is the top-k hot paths.
    ranked = sorted(key_avg, key=lambda evt: _profile_event_total(evt, is_cuda), reverse=True)
    top_results = [_profile_event_entry(evt) for evt in ranked[:top_k]]

    print("\n=== Profiling Results (Top-{} Hot Paths) ===".format(top_k))
    print(table)
    print("============================================\n")

    if trace_path is not None:
        prof.export_chrome_trace(trace_path)

    return top_results
|
|
|
|
def analyze_profiler_output(prof_path):
    """
    Load a saved Chrome-trace profiler JSON file and summarize per-op time.

    Accepts either a ``{"traceEvents": [...]}`` object or a bare list of event
    dicts. Only events carrying a ``dur`` field are aggregated; events without
    a duration (e.g. metadata or flow events) are not timed operations.

    Each event is classified by its ``cat`` field (case-insensitive):
    ``kernel`` or any category containing ``gpu`` counts as CUDA time; any
    category containing ``cpu``, or an empty category, counts as CPU time.
    Events in other categories increment ``calls`` only.

    Prints a table of the top 20 ops and, when any time was recorded, a
    hot-path summary of the top 5. It uses CUDA time, or CPU time for
    CPU-only traces.

    Args:
        prof_path: Path to saved profiler JSON file

    Returns:
        List of dicts with op_name, cuda_time_us, cpu_time_us, calls, sorted
        by CUDA time and then CPU time (both descending).
    """
    with open(prof_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    if isinstance(data, dict):
        events = data.get("traceEvents") or []
    elif isinstance(data, list):
        events = data
    else:
        events = []

    op_stats = {}
    for evt in events:
        # Skip non-dict entries and duration-less events; they are not ops.
        if not isinstance(evt, dict) or "dur" not in evt:
            continue
        name = str(evt.get("name", "unknown"))
        dur = evt.get("dur") or 0
        cat = str(evt.get("cat") or "").lower()
        stats = op_stats.setdefault(
            name, {"cuda_time_us": 0, "cpu_time_us": 0, "calls": 0}
        )
        # CUDA kernels use category "kernel", which contains neither "gpu"
        # nor "cpu", so it must be matched explicitly.
        if cat == "kernel" or "gpu" in cat:
            stats["cuda_time_us"] += dur
        elif "cpu" in cat or not cat:
            stats["cpu_time_us"] += dur
        stats["calls"] += 1

    # CPU time breaks ties so CPU-only traces (all CUDA times 0) are still ranked.
    results = [
        {"op_name": name, **stats}
        for name, stats in sorted(
            op_stats.items(),
            key=lambda item: (item[1]["cuda_time_us"], item[1]["cpu_time_us"]),
            reverse=True,
        )
    ]

    print("\n=== Profiler Analysis ===")
    print(f"{'Operation':<40} {'CUDA Time (us)':>15} {'CPU Time (us)':>15} {'Calls':>8}")
    print("-" * 80)
    for r in results[:20]:
        # Truncate long (e.g. templated kernel) names so columns stay aligned.
        print(f"{r['op_name'][:40]:<40} {r['cuda_time_us']:>15.0f} "
              f"{r['cpu_time_us']:>15.0f} {r['calls']:>8}")

    total_cuda = sum(r["cuda_time_us"] for r in results)
    if total_cuda > 0:
        time_key, total = "cuda_time_us", total_cuda
        header = "\n=== Hot Path Analysis ==="
    else:
        # No CUDA time at all: results are already ordered by CPU time.
        time_key = "cpu_time_us"
        total = sum(r["cpu_time_us"] for r in results)
        header = "\n=== Hot Path Analysis (CPU time) ==="

    if total > 0:
        print(header)
        for r in results[:5]:
            pct = r[time_key] / total * 100
            lowered = r["op_name"].lower()
            label = ""
            if "vq" in lowered:
                label = " → VQ candidate for Triton kernel"
            elif "moe" in lowered or "scatter" in lowered:
                label = " → MoE dispatch candidate"
            elif "embed" in lowered or "gather" in lowered:
                label = " → Embedding gather (existing Triton kernel)"
            elif "mm" in lowered or "linear" in lowered:
                label = " → General matmul (torch.compile candidate)"
            print(f" {r['op_name'][:40]:<40} {pct:>5.1f}%{label}")

    print("============================================\n")
    return results
|
|