"""
Profiling utilities: torch.profiler wrapper and analysis tools.
Following D-103: profile first, optimize only hot paths.
Uses torch.profiler to identify training loop bottlenecks.
"""
import sys
import os
import json
import math
import torch
sys.path.insert(0, os.path.dirname(__file__))
from .main import ARBModel
from .config import VOCAB, CTX
def _sample_and_forward(model, train_data, device, batch_size, ctx):
    """Draw one random batch from ``train_data`` and run one forward pass.

    The forward pass runs under ``torch.no_grad()``; no backward pass or
    optimizer step is performed.
    """
    starts = torch.randint(0, len(train_data) - ctx - 1, (batch_size,))
    x = torch.stack([train_data[j: j + ctx] for j in starts])
    # Targets are the sampled windows with their first three positions dropped.
    # NOTE(review): confirm this offset is what the model expects.
    targets = x[:, 3:]
    x = x.to(device)
    targets = targets.to(device)
    with torch.no_grad():
        model(x, targets=targets)


def _first_set(obj, *names):
    """Return the first attribute in ``names`` that exists and is not None, else 0."""
    for attr in names:
        value = getattr(obj, attr, None)
        if value is not None:
            return value
    return 0


def profile_training(model, train_data, device, n_steps=20, warmup_steps=5,
                     top_k=10, batch_size=64, ctx=CTX,
                     trace_path="/tmp/profiler_trace.json"):
    """
    Profile N model steps using torch.profiler.

    Runs ``warmup_steps`` untraced steps, then ``n_steps`` traced steps.
    Each step samples a random batch and runs a forward pass under
    ``torch.no_grad()`` (no backward pass is executed, so only forward cost
    is measured). Prints the profiler's summary table and exports a Chrome
    trace to ``trace_path``.

    Args:
        model: Model to profile; it is called as ``model(x, targets=targets)``
            and put into train mode via ``model.train()``.
        train_data: Indexable tensor of training data (expected to be a 1D
            byte tensor) longer than ``ctx + 1``.
        device: Target device, e.g. ``'cuda'``, ``'cuda:0'``, ``'cpu'`` or a
            ``torch.device``. Any CUDA device enables CUDA activity tracing.
        n_steps: Number of profiled steps.
        warmup_steps: Steps before profiling begins (no tracing).
        top_k: Number of top operations to return and print.
        batch_size: Batch size for each step.
        ctx: Context window length.
        trace_path: Where the Chrome trace JSON is written.

    Returns:
        List of at most ``top_k`` dicts with keys: op_name, cuda_time_us,
        cpu_time_us, calls. Entries are ordered by total device time on CUDA
        (total CPU time otherwise), highest first, matching the printed table.
        The time fields are the profiler's ``device_time`` / ``cpu_time``
        values. NOTE(review): in current PyTorch these are per-call averages,
        not totals -- confirm for the installed version.
    """
    model.train()
    # Normalising through torch.device makes 'cuda:0' and torch.device objects
    # take the CUDA path too, not just the literal string 'cuda'.
    use_cuda = torch.device(device).type == "cuda"

    activities = [torch.profiler.ProfilerActivity.CPU]
    if use_cuda:
        activities.append(torch.profiler.ProfilerActivity.CUDA)
    prof = torch.profiler.profile(
        activities=activities,
        record_shapes=True,
        with_stack=use_cuda,
        with_flops=use_cuda,
    )

    # Warmup steps (no profiling)
    for _ in range(warmup_steps):
        _sample_and_forward(model, train_data, device, batch_size, ctx)
    if use_cuda:
        # Drain asynchronous warmup kernels so they do not leak into the trace.
        torch.cuda.synchronize()

    # Profiled steps; the context manager stops the profiler even if a step
    # raises.
    with prof:
        for _ in range(n_steps):
            _sample_and_forward(model, train_data, device, batch_size, ctx)
            if use_cuda:
                torch.cuda.synchronize()

    # Process profiler output
    key_avg = prof.key_averages()
    table = key_avg.table(
        sort_by="cuda_time_total" if use_cuda else "cpu_time_total",
        row_limit=top_k,
    )

    # key_averages() is not ordered by time, so rank explicitly before slicing;
    # otherwise the returned list would not be the hot paths shown in the table.
    if use_cuda:
        # device_time_total replaces cuda_time_total in recent PyTorch.
        def total_time(evt):
            return _first_set(evt, "device_time_total", "cuda_time_total")
    else:
        def total_time(evt):
            return _first_set(evt, "cpu_time_total")

    candidates = key_avg.events() if hasattr(key_avg, "events") else key_avg
    ranked = sorted(candidates, key=total_time, reverse=True)

    top_results = [
        {
            "op_name": evt.key if hasattr(evt, "key") else str(evt),
            # device_time replaces deprecated cuda_time in recent PyTorch
            "cuda_time_us": _first_set(evt, "device_time", "cuda_time"),
            "cpu_time_us": getattr(evt, "cpu_time", 0),
            "calls": getattr(evt, "count", 1),
        }
        for evt in ranked[:top_k]
    ]

    # Print summary
    print("\n=== Profiling Results (Top-{} Hot Paths) ===".format(top_k))
    print(table)
    print("============================================\n")

    # Save profiler output as a Chrome trace (JSON)
    prof.export_chrome_trace(trace_path)
    return top_results
def _hot_path_label(op_name):
    """Return the optimisation hint suffix for an operation name ('' if none)."""
    lowered = op_name.lower()
    if "vq" in lowered:  # also covers "flash_vq"
        return " → VQ candidate for Triton kernel"
    if "moe" in lowered or "scatter" in lowered:
        return " → MoE dispatch candidate"
    if "embed" in lowered or "gather" in lowered:
        return " → Embedding gather (existing Triton kernel)"
    if "mm" in lowered or "linear" in lowered:
        return " → General matmul (torch.compile candidate)"
    return ""


def analyze_profiler_output(prof_path):
    """
    Load saved profiler JSON output and extract key insights.

    Events are aggregated by their ``name``. An event's ``dur`` is added to
    the CUDA total when its category contains "gpu" or "kernel", and to the
    CPU total when its category contains "cpu" or is empty; events in any
    other category only increment the call count. A formatted summary of the
    top 20 operations (and, when any CUDA time was recorded, the share of the
    top 5) is printed.

    Args:
        prof_path: Path to saved profiler JSON file. The file may hold either
            a dict with a 'traceEvents' list or a flat list of event dicts;
            any other layout yields an empty result.

    Returns:
        List of dicts with op_name, cuda_time_us, cpu_time_us, calls, sorted
        by CUDA time descending, with CPU time as the tie-breaker (so
        CPU-only traces are still ranked meaningfully).
    """
    with open(prof_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Profiler JSON can be a dict with 'traceEvents' or a flat list
    if isinstance(data, dict) and "traceEvents" in data:
        events = data["traceEvents"]
    elif isinstance(data, list):
        events = data
    else:
        events = []
    if not isinstance(events, list):
        # e.g. "traceEvents": null -- treat as no events rather than crashing.
        events = []

    # Aggregate events by name
    op_stats = {}
    for evt in events:
        if not isinstance(evt, dict):
            continue
        name = evt.get("name", "unknown")
        dur = evt.get("dur", 0)  # microseconds
        # Tolerate a missing or null category.
        cat = str(evt.get("cat") or "").lower()
        stats = op_stats.setdefault(
            name, {"cuda_time_us": 0, "cpu_time_us": 0, "calls": 0}
        )
        # GPU kernels are emitted under the "kernel" category, which does not
        # contain "gpu"; without this check their time would be dropped.
        if "gpu" in cat or "kernel" in cat:
            stats["cuda_time_us"] += dur
        elif "cpu" in cat or cat == "":
            stats["cpu_time_us"] += dur
        stats["calls"] += 1

    # Sort by CUDA time descending; CPU time breaks ties.
    sorted_ops = sorted(
        op_stats.items(),
        key=lambda item: (item[1]["cuda_time_us"], item[1]["cpu_time_us"]),
        reverse=True,
    )
    results = [
        {
            "op_name": name,
            "cuda_time_us": stats["cuda_time_us"],
            "cpu_time_us": stats["cpu_time_us"],
            "calls": stats["calls"],
        }
        for name, stats in sorted_ops
    ]

    # Print formatted summary
    print("\n=== Profiler Analysis ===")
    print(f"{'Operation':<40} {'CUDA Time (us)':>15} {'CPU Time (us)':>15} {'Calls':>8}")
    print("-" * 80)
    for r in results[:20]:
        print(f"{r['op_name']:<40} {r['cuda_time_us']:>15.0f} {r['cpu_time_us']:>15.0f} {r['calls']:>8}")

    # Identify dominating patterns
    total_cuda = sum(r["cuda_time_us"] for r in results)
    if total_cuda > 0:
        print("\n=== Hot Path Analysis ===")
        for r in results[:5]:
            pct = (r["cuda_time_us"] / total_cuda) * 100
            label = _hot_path_label(r["op_name"])
            print(f"  {r['op_name']:<40} {pct:>5.1f}%{label}")
    print("============================================\n")
    return results
|