""" Loss Spectrum Profiler — Standalone ===================================== Builds its own model + noise data. Profiles every loss computation in the GeoLIP pipeline with CUDA-synced microsecond timing. Zero external dependencies beyond torch. Single cell. """ import torch import torch.nn as nn import torch.nn.functional as F import time import math DEVICE = "cuda" if torch.cuda.is_available() else "cpu" torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True # ═══════════════════════════════════════════════════════════════ # Config — matches our architecture # ═══════════════════════════════════════════════════════════════ DIM = 256 N_ANCHORS = 256 N_COMP = 8 D_COMP = 64 BATCH = 256 NUM_CLASSES = 100 # ═══════════════════════════════════════════════════════════════ # Minimal model components (self-contained, no imports) # ═══════════════════════════════════════════════════════════════ class ProfileEncoder(nn.Module): def __init__(self, dim=256): super().__init__() self.features = nn.Sequential( nn.Conv2d(3, 64, 3, padding=1), nn.BatchNorm2d(64), nn.GELU(), nn.Conv2d(64, 64, 3, padding=1), nn.BatchNorm2d(64), nn.GELU(), nn.MaxPool2d(2), nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.GELU(), nn.Conv2d(128, 128, 3, padding=1), nn.BatchNorm2d(128), nn.GELU(), nn.MaxPool2d(2), nn.Conv2d(128, 256, 3, padding=1), nn.BatchNorm2d(256), nn.GELU(), nn.Conv2d(256, 256, 3, padding=1), nn.BatchNorm2d(256), nn.GELU(), nn.MaxPool2d(2), nn.Conv2d(256, 384, 3, padding=1), nn.BatchNorm2d(384), nn.GELU(), nn.Conv2d(384, 384, 3, padding=1), nn.BatchNorm2d(384), nn.GELU(), nn.MaxPool2d(2), nn.AdaptiveAvgPool2d(1), nn.Flatten(), ) self.proj = nn.Sequential(nn.Linear(384, dim), nn.LayerNorm(dim)) def forward(self, x): feat = self.features(x) return F.normalize(self.proj(feat), dim=-1), feat[:, :1] # emb, fake raw_mag class ProfilePatchwork(nn.Module): def __init__(self, n_anchors=256, n_comp=8, d_comp=64): super().__init__() apc = n_anchors // n_comp self.n_comp = n_comp self.comps = nn.ModuleList([ nn.Sequential(nn.Linear(apc, d_comp * 2), nn.GELU(), nn.Linear(d_comp * 2, d_comp)) for _ in range(n_comp) ]) def forward(self, tri): apc = tri.shape[1] // self.n_comp parts = [] for k in range(self.n_comp): parts.append(self.comps[k](tri[:, k*apc:(k+1)*apc])) return torch.cat(parts, dim=-1) # Build all components print("Building profile model...") encoder = ProfileEncoder(DIM).to(DEVICE) anchors = nn.Parameter(F.normalize(torch.randn(N_ANCHORS, DIM, device=DEVICE), dim=-1)) patchwork = ProfilePatchwork(N_ANCHORS, N_COMP, D_COMP).to(DEVICE) bridge = nn.Linear(N_COMP * D_COMP, N_ANCHORS).to(DEVICE) task_head = nn.Sequential( nn.Linear(N_ANCHORS + N_COMP * D_COMP + DIM, N_COMP * D_COMP), nn.GELU(), nn.LayerNorm(N_COMP * D_COMP), nn.Dropout(0.1), nn.Linear(N_COMP * D_COMP, NUM_CLASSES), ).to(DEVICE) # Fake batch — random images + labels v1 = torch.randn(BATCH, 3, 32, 32, device=DEVICE) v2 = torch.randn(BATCH, 3, 32, 32, device=DEVICE) targets = torch.randint(0, NUM_CLASSES, (BATCH,), device=DEVICE) labels_nce = torch.arange(BATCH, device=DEVICE) # Pre-compute intermediates with torch.no_grad(): emb1, raw_mag1 = encoder(v1) emb2, raw_mag2 = encoder(v2) anchors_n = F.normalize(anchors, dim=-1) cos1 = emb1 @ anchors_n.T cos2 = emb2 @ anchors_n.T tri1 = 1.0 - cos1 tri2 = 1.0 - cos2 assign1 = F.softmax(cos1 / 0.1, dim=-1) assign2 = F.softmax(cos2 / 0.1, dim=-1) pw1 = patchwork(tri1) pw2 = patchwork(tri2) bridge1 = bridge(pw1) feat1 = torch.cat([assign1, pw1, emb1], dim=-1) logits1 = task_head(feat1) all_params = (list(encoder.parameters()) + [anchors] + list(patchwork.parameters()) + list(bridge.parameters()) + list(task_head.parameters())) print(f" Device: {DEVICE}") print(f" Batch: {BATCH}, Dim: {DIM}, Anchors: {N_ANCHORS}, Comp: {N_COMP}×{D_COMP}") n_params = sum(p.numel() for p in all_params) print(f" Parameters: {n_params:,}") # ═══════════════════════════════════════════════════════════════ # Timer # ═══════════════════════════════════════════════════════════════ def timed(name, fn, n_runs=30, warmup=5): """CUDA-synced timing. Returns (result, avg_ms).""" for _ in range(warmup): r = fn() torch.cuda.synchronize() times = [] for _ in range(n_runs): torch.cuda.synchronize() t0 = time.perf_counter() r = fn() torch.cuda.synchronize() times.append((time.perf_counter() - t0) * 1000) avg = sum(times) / len(times) return r, avg results = [] def record(name, fn, **kw): _, ms = timed(name, fn, **kw) results.append((name, ms)) return ms # ═══════════════════════════════════════════════════════════════ # SECTION 1: Forward Components # ═══════════════════════════════════════════════════════════════ print(f"\n{'='*80}") print("SECTION 1: FORWARD PASS COMPONENTS") print(f"{'='*80}\n") record("encoder(v1)", lambda: encoder(v1)) record("triangulation (emb@A.T)", lambda: emb1 @ anchors_n.T) record("soft_assign (softmax)", lambda: F.softmax(cos1 / 0.1, dim=-1)) record("patchwork(tri)", lambda: patchwork(tri1)) record("bridge(pw)", lambda: bridge(pw1)) record("task_head(feat)", lambda: task_head(feat1)) def _full_fwd(): e1, _ = encoder(v1) e2, _ = encoder(v2) an = F.normalize(anchors, dim=-1) c1 = e1 @ an.T; c2 = e2 @ an.T t1 = 1 - c1; t2 = 1 - c2 a1 = F.softmax(c1/0.1, dim=-1); a2 = F.softmax(c2/0.1, dim=-1) p1 = patchwork(t1); p2 = patchwork(t2) b1 = bridge(p1) f1 = torch.cat([a1, p1, e1], -1) return task_head(f1) record("FULL forward (both views)", _full_fwd) # ═══════════════════════════════════════════════════════════════ # SECTION 2: Individual Loss Terms (forward only) # ═══════════════════════════════════════════════════════════════ print(f"\n{'='*80}") print("SECTION 2: INDIVIDUAL LOSS TERMS (forward only)") print(f"{'='*80}\n") record("CE (cross_entropy)", lambda: F.cross_entropy(logits1, targets)) record("NCE_emb (B×B + CE)", lambda: F.cross_entropy( emb1 @ emb2.T / 0.07, labels_nce)) record("NCE_pw (norm + B×B + CE)", lambda: F.cross_entropy( F.normalize(pw1, dim=-1) @ F.normalize(pw2, dim=-1).T / 0.1, labels_nce)) record("NCE_tri (norm + B×B + CE)", lambda: F.cross_entropy( F.normalize(tri1, dim=-1) @ F.normalize(tri2, dim=-1).T / 0.1, labels_nce)) record("NCE_assign (B×B + CE)", lambda: F.cross_entropy( assign1 @ assign2.T / 0.1, labels_nce)) def _bridge_loss(): at = assign1.detach() return -(at * F.log_softmax(bridge1, dim=-1)).sum(-1).mean() record("Bridge (soft CE)", _bridge_loss) def _assign_bce(): nearest = cos1.argmax(dim=-1) hard = torch.zeros_like(assign1) hard.scatter_(1, nearest.unsqueeze(1), 1.0) return F.binary_cross_entropy(assign1.float().clamp(1e-7, 1-1e-7), hard.float()) record("Assign BCE", _assign_bce) record("Attraction (max + mean)", lambda: (1.0 - cos1.max(dim=1).values).mean()) def _spread(): a = F.normalize(anchors, dim=-1) sim = a @ a.T mask = ~torch.eye(N_ANCHORS, dtype=torch.bool, device=DEVICE) return F.relu(sim[mask]).mean() record("Spread (A×A + relu)", _spread) record("kNN (B×B + argmax)", lambda: ( targets[(emb1 @ emb1.T).fill_diagonal_(-1).argmax(1)] == targets).float().mean()) # ═══════════════════════════════════════════════════════════════ # SECTION 3: CV Loss — Old vs Batched # ═══════════════════════════════════════════════════════════════ print(f"\n{'='*80}") print("SECTION 3: CV LOSS — OLD SEQUENTIAL vs BATCHED") print(f"{'='*80}\n") # Old sequential def _cv_old(n_samples=64): vols = [] for _ in range(n_samples): idx = torch.randperm(min(BATCH, 256), device=DEVICE)[:5] pts = emb1[idx].unsqueeze(0) gram = torch.bmm(pts, pts.transpose(1, 2)) norms = torch.diagonal(gram, dim1=1, dim2=2) d2 = F.relu(norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram) cm = torch.zeros(1, 6, 6, device=DEVICE, dtype=pts.dtype) cm[:, 0, 1:] = 1; cm[:, 1:, 0] = 1; cm[:, 1:, 1:] = d2 pf = ((-1)**5) / ((2**4) * (math.factorial(4)**2)) v2 = pf * torch.linalg.det(cm.float()) if v2[0].item() > 1e-20: vols.append(v2[0].sqrt()) if len(vols) < 5: return torch.tensor(0.0, device=DEVICE) vt = torch.stack(vols) return ((vt.std() / (vt.mean() + 1e-8)) - 0.22).pow(2) # Batched def _cv_batched(n_samples=64): pool = min(BATCH, 256) rand_keys = torch.rand(n_samples, pool, device=DEVICE) indices = rand_keys.argsort(dim=1)[:, :5] pts = emb1[:pool][indices] gram = torch.bmm(pts, pts.transpose(1, 2)) norms = torch.diagonal(gram, dim1=1, dim2=2) d2 = F.relu(norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram) cm = torch.zeros(n_samples, 6, 6, device=DEVICE, dtype=pts.dtype) cm[:, 0, 1:] = 1.0; cm[:, 1:, 0] = 1.0; cm[:, 1:, 1:] = d2 pf = ((-1)**5) / ((2**4) * (math.factorial(4)**2)) dets = pf * torch.linalg.det(cm.float()) valid = dets > 1e-20 vols = dets[valid].sqrt() if vols.shape[0] < 5: return torch.tensor(0.0, device=DEVICE) return ((vols.std() / (vols.mean() + 1e-8)) - 0.22).pow(2) for ns in [32, 64, 128, 200]: record(f"CV OLD n={ns}", lambda ns=ns: _cv_old(ns), n_runs=10) record(f"CV BATCH n={ns}", lambda ns=ns: _cv_batched(ns), n_runs=10) # Non-differentiable metric versions def _cv_metric_old(n_samples=200): with torch.no_grad(): return _cv_old(n_samples) def _cv_metric_batch(n_samples=200): with torch.no_grad(): return _cv_batched(n_samples) record("CV metric OLD n=200", _cv_metric_old, n_runs=10) record("CV metric BATCH n=200", _cv_metric_batch, n_runs=10) # ═══════════════════════════════════════════════════════════════ # SECTION 4: Backward costs # ═══════════════════════════════════════════════════════════════ print(f"\n{'='*80}") print("SECTION 4: BACKWARD COSTS (forward + backward)") print(f"{'='*80}\n") def _bwd(loss_fn): for p in all_params: if p.grad is not None: p.grad.zero_() loss = loss_fn() if torch.is_tensor(loss) and loss.requires_grad: loss.backward() return loss # Need fresh forward for each backward def _fwd_bwd_ce(): e, _ = encoder(v1) an = F.normalize(anchors, dim=-1) c = e @ an.T; t = 1 - c a = F.softmax(c/0.1, dim=-1) p = patchwork(t) f = torch.cat([a, p, e], -1) return _bwd(lambda: F.cross_entropy(task_head(f), targets)) def _fwd_bwd_nce_emb(): e1, _ = encoder(v1); e2, _ = encoder(v2) return _bwd(lambda: F.cross_entropy(e1 @ e2.T / 0.07, labels_nce)) def _fwd_bwd_nce_pw(): e1, _ = encoder(v1); e2, _ = encoder(v2) an = F.normalize(anchors, dim=-1) t1 = 1 - e1 @ an.T; t2 = 1 - e2 @ an.T p1 = patchwork(t1); p2 = patchwork(t2) return _bwd(lambda: F.cross_entropy( F.normalize(p1, dim=-1) @ F.normalize(p2, dim=-1).T / 0.1, labels_nce)) def _fwd_bwd_cv_old(): e, _ = encoder(v1) return _bwd(lambda: _cv_old(64)) def _fwd_bwd_cv_batch(): e, _ = encoder(v1) return _bwd(lambda: _cv_batched(64)) def _fwd_bwd_bridge(): e, _ = encoder(v1) an = F.normalize(anchors, dim=-1) c = e @ an.T; t = 1 - c a = F.softmax(c/0.1, dim=-1) p = patchwork(t); b = bridge(p) at = a.detach() return _bwd(lambda: -(at * F.log_softmax(b, dim=-1)).sum(-1).mean()) record("fwd+bwd CE", _fwd_bwd_ce, n_runs=10, warmup=3) record("fwd+bwd NCE_emb", _fwd_bwd_nce_emb, n_runs=10, warmup=3) record("fwd+bwd NCE_pw", _fwd_bwd_nce_pw, n_runs=10, warmup=3) record("fwd+bwd CV old", _fwd_bwd_cv_old, n_runs=10, warmup=3) record("fwd+bwd CV batch", _fwd_bwd_cv_batch, n_runs=10, warmup=3) record("fwd+bwd Bridge", _fwd_bwd_bridge, n_runs=10, warmup=3) # ═══════════════════════════════════════════════════════════════ # REPORT # ═══════════════════════════════════════════════════════════════ print(f"\n\n{'='*80}") print("FULL TIMING REPORT (sorted by cost)") print(f"{'='*80}\n") total = sum(ms for _, ms in results) for name, ms in sorted(results, key=lambda x: -x[1]): pct = 100 * ms / total if total > 0 else 0 bar_len = int(pct / 2) bar = "█" * bar_len + "░" * (40 - bar_len) print(f" {name:35s} {ms:>9.3f}ms {bar} {pct:>5.1f}%") print(f" {'─'*90}") print(f" {'SUM':35s} {total:>9.3f}ms") # CV speedup summary print(f"\n{'='*80}") print("CV SPEEDUP SUMMARY") print(f"{'='*80}") cv_pairs = {} for name, ms in results: if name.startswith("CV "): key = name.split("n=")[1] if "n=" in name else "?" tag = "old" if "OLD" in name else "batch" cv_pairs.setdefault(key, {})[tag] = ms for k in sorted(cv_pairs.keys()): p = cv_pairs[k] if 'old' in p and 'batch' in p: speedup = p['old'] / p['batch'] if p['batch'] > 0 else 0 print(f" n={k:>4s}: {p['old']:>8.2f}ms → {p['batch']:>8.2f}ms ({speedup:.1f}x speedup)") # Per-step estimate print(f"\n{'='*80}") print("PER-STEP ESTIMATE") print(f"{'='*80}") fwd_time = next((ms for n, ms in results if n == "FULL forward (both views)"), 0) bwd_ce = next((ms for n, ms in results if n == "fwd+bwd CE"), 0) bwd_cv_old = next((ms for n, ms in results if n == "fwd+bwd CV old"), 0) bwd_cv_new = next((ms for n, ms in results if n == "fwd+bwd CV batch"), 0) print(f" Forward (both views): {fwd_time:.2f}ms") print(f" fwd+bwd CE: {bwd_ce:.2f}ms") print(f" fwd+bwd CV (old): {bwd_cv_old:.2f}ms") print(f" fwd+bwd CV (batched): {bwd_cv_new:.2f}ms") if bwd_cv_old > 0 and bwd_cv_new > 0: saved = bwd_cv_old - bwd_cv_new print(f" CV savings per step: {saved:.2f}ms ({saved/bwd_cv_old*100:.0f}%)") print(f"\n{'='*80}") print("PROFILING COMPLETE") print(f"{'='*80}")