# ARBS / arbitor / profiling.py
# Uploaded with huggingface_hub by CLIWorks (commit d8bc908, verified).
"""
Profiling utilities: torch.profiler wrapper and analysis tools.
Following D-103: profile first, optimize only hot paths.
Uses torch.profiler to identify training loop bottlenecks.
"""
import sys
import os
import json
import math
import torch
sys.path.insert(0, os.path.dirname(__file__))
from .main import ARBModel
from .config import VOCAB, CTX
def profile_training(model, train_data, device, n_steps=20, warmup_steps=5,
                     top_k=10, batch_size=64, ctx=CTX,
                     trace_path="/tmp/profiler_trace.json"):
    """
    Profile N model steps with torch.profiler and report the hottest ops.

    Runs ``warmup_steps`` untraced steps, then ``n_steps`` traced steps.
    Each step samples a random batch of windows from ``train_data`` and runs
    a forward pass under ``torch.no_grad()`` (no backward/optimizer work is
    traced). The profiler table is printed and a Chrome trace is exported.

    Args:
        model: ARBModel instance; called as ``model(x, targets=targets)``.
        train_data: 1D byte tensor of training data.
        device: Device to run on, e.g. ``'cuda'``, ``'cuda:0'``, ``'cpu'`` or
            a ``torch.device``. Any CUDA device enables CUDA activity tracing.
        n_steps: Number of profiled steps.
        warmup_steps: Steps run before profiling begins (no tracing).
        top_k: Number of top operations to print and return.
        batch_size: Batch size for each step.
        ctx: Context window length.
        trace_path: Destination of the exported Chrome trace JSON.

    Returns:
        List of at most ``top_k`` dicts with keys: op_name, cuda_time_us,
        cpu_time_us, calls. Entries are ordered by total CUDA time on a CUDA
        device and by total CPU time otherwise (same order as the printed
        table). The time values are read from the profiler's per-op
        ``device_time`` / ``cpu_time`` attributes.
    """
    model.train()

    # Accept "cuda", "cuda:N" and torch.device objects alike.
    on_cuda = torch.device(device).type == "cuda"

    activities = [torch.profiler.ProfilerActivity.CPU]
    if on_cuda:
        activities.append(torch.profiler.ProfilerActivity.CUDA)

    # Warmup steps (no profiling)
    for _ in range(warmup_steps):
        _run_profiling_step(model, train_data, device, batch_size, ctx)
    if on_cuda:
        # Drain warmup kernels so they do not leak into the traced region.
        torch.cuda.synchronize()

    # Stack and FLOP recording are only enabled for CUDA runs, as before.
    prof = torch.profiler.profile(
        activities=activities,
        record_shapes=True,
        with_stack=on_cuda,
        with_flops=on_cuda,
    )

    # Profiled steps; the context manager stops the profiler even if a
    # step raises.
    with prof:
        for _ in range(n_steps):
            _run_profiling_step(model, train_data, device, batch_size, ctx)
        if on_cuda:
            torch.cuda.synchronize()

    # Process profiler output
    averages = prof.key_averages()
    sort_key = "cuda_time_total" if on_cuda else "cpu_time_total"
    table = averages.table(sort_by=sort_key, row_limit=top_k)

    # key_averages() is not ordered by time, so rank explicitly before
    # truncating to the top-K entries.
    ranked = sorted(
        averages,
        key=lambda evt: _event_total_time(evt, on_cuda),
        reverse=True,
    )

    top_results = []
    for evt in ranked[:top_k]:
        # device_time replaces deprecated cuda_time in recent PyTorch
        device_t = getattr(evt, "device_time", None)
        if device_t is None:
            device_t = getattr(evt, "cuda_time", 0)
        top_results.append({
            "op_name": evt.key if hasattr(evt, "key") else str(evt),
            "cuda_time_us": device_t,
            "cpu_time_us": getattr(evt, "cpu_time", 0),
            "calls": getattr(evt, "count", 1),
        })

    # Print summary
    print("\n=== Profiling Results (Top-{} Hot Paths) ===".format(top_k))
    print(table)
    print("============================================\n")

    # Save profiler output as a Chrome trace (JSON)
    prof.export_chrome_trace(trace_path)
    return top_results


def _run_profiling_step(model, train_data, device, batch_size, ctx):
    """Sample one random batch from train_data and run a no-grad forward pass."""
    starts = torch.randint(0, len(train_data) - ctx - 1, (batch_size,))
    x = torch.stack([train_data[j: j + ctx] for j in starts])
    # Targets are the input with the first three positions dropped.
    # NOTE(review): confirm this offset against ARBModel's target convention.
    targets = x[:, 3:]
    x = x.to(device)
    targets = targets.to(device)
    with torch.no_grad():
        model(x, targets=targets)


def _event_total_time(evt, on_cuda):
    """Return the total time used to rank a profiler event (device or CPU)."""
    if not on_cuda:
        return getattr(evt, "cpu_time_total", 0) or 0
    # device_time_total supersedes cuda_time_total in recent PyTorch.
    total = getattr(evt, "device_time_total", None)
    if total is None:
        total = getattr(evt, "cuda_time_total", 0)
    return total or 0
def analyze_profiler_output(prof_path):
    """
    Load a saved profiler trace (JSON) and aggregate time per operation.

    Events are grouped by their ``name``. An event's ``dur`` (microseconds)
    is added to CUDA time when its category marks device work (contains
    "gpu", or is "kernel"/"memcpy"/"memset", case-insensitive), and to CPU
    time when the category contains "cpu" or is empty. Events in any other
    category still count towards ``calls`` but add no time.

    A summary table of the top 20 operations is printed, followed by a
    percentage breakdown of the top 5 when any CUDA time was recorded.

    Args:
        prof_path: Path to saved profiler JSON file; either a dict with a
            'traceEvents' list or a flat list of event dicts.

    Returns:
        List of dicts with op_name, cuda_time_us, cpu_time_us, calls, sorted
        by CUDA time descending (ties broken by CPU time descending).
    """
    with open(prof_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Profiler JSON can be a dict with 'traceEvents' or a flat list
    if isinstance(data, dict) and "traceEvents" in data:
        events = data["traceEvents"]
    elif isinstance(data, list):
        events = data
    else:
        events = []

    # Categories that denote device-side work but do not contain "gpu".
    device_cats = ("kernel", "memcpy", "memset")

    # Aggregate events by name
    op_stats = {}
    for evt in events:
        if not isinstance(evt, dict):
            continue
        name = str(evt.get("name", "unknown"))
        dur = evt.get("dur", 0) or 0  # microseconds; tolerate null
        cat = str(evt.get("cat") or "").lower()

        stats = op_stats.setdefault(
            name, {"cuda_time_us": 0, "cpu_time_us": 0, "calls": 0}
        )
        if "gpu" in cat or cat in device_cats:
            stats["cuda_time_us"] += dur
        elif "cpu" in cat or cat == "":
            stats["cpu_time_us"] += dur
        stats["calls"] += 1

    # Sort by CUDA time descending; CPU time breaks ties so that CPU-only
    # traces are still ordered by cost.
    sorted_ops = sorted(
        op_stats.items(),
        key=lambda item: (item[1]["cuda_time_us"], item[1]["cpu_time_us"]),
        reverse=True,
    )
    results = [
        {
            "op_name": name,
            "cuda_time_us": stats["cuda_time_us"],
            "cpu_time_us": stats["cpu_time_us"],
            "calls": stats["calls"],
        }
        for name, stats in sorted_ops
    ]

    # Print formatted summary
    print("\n=== Profiler Analysis ===")
    print(f"{'Operation':<40} {'CUDA Time (us)':>15} {'CPU Time (us)':>15} {'Calls':>8}")
    print("-" * 80)
    for r in results[:20]:
        print(f"{r['op_name']:<40} {r['cuda_time_us']:>15.0f} {r['cpu_time_us']:>15.0f} {r['calls']:>8}")

    # Identify dominating patterns; first matching keyword group wins.
    hints = (
        (("vq",), " → VQ candidate for Triton kernel"),
        (("moe", "scatter"), " → MoE dispatch candidate"),
        (("embed", "gather"), " → Embedding gather (existing Triton kernel)"),
        (("mm", "linear"), " → General matmul (torch.compile candidate)"),
    )
    total_cuda = sum(r["cuda_time_us"] for r in results)
    if total_cuda > 0:
        print("\n=== Hot Path Analysis ===")
        for r in results[:5]:
            pct = (r["cuda_time_us"] / total_cuda) * 100
            lowered = r["op_name"].lower()
            label = ""
            for keywords, hint in hints:
                if any(word in lowered for word in keywords):
                    label = hint
                    break
            print(f"  {r['op_name']:<40} {pct:>5.1f}%{label}")
    print("============================================\n")
    return results