"""
Cell 5 — Spatial Conv Readout on Conduit Maps
=============================================

No pooling or flattening of the maps before the conv: the readout reads the
spatial grid of per-patch conduit signals directly. This is the CORRECT way
to evaluate whether conduit signals carry class-discriminative information.
The earlier linear probe was wrong — flattening destroyed the spatial
structure that IS the signal.

Channels on the spatial grid (16×16 for 64-px inputs with 4-px patches):
    S values:      4 channels (eigenvalues per patch)
    Friction:      4 channels (solver struggle per mode)
    Release error: 1 channel  (reconstruction fidelity per patch)
    Settle:        4 channels (convergence speed per mode)

Each signal is tested alone and in combination, all through the same conv
readout.
"""

import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T
from tqdm import tqdm

from geolip_svae import load_model
from geolip_svae.model import extract_patches
from geolip_core.linalg.conduit import FLEighConduit

# Prefer the GPU; fall back to CPU so the cell still runs (slowly) without one.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ═══════════════════════════════════════════════════════════════
# LOAD
# ═══════════════════════════════════════════════════════════════

print("Loading Freckles v40 + CIFAR-10...")

# Input resolution. Used both by the resize transform and by the patch-grid
# size below, so the two cannot drift apart.
IMG_SIZE = 64

freckles, cfg = load_model(hf_version='v40_freckles_noise', device=device)
freckles.eval()
ps = freckles.patch_size
gh, gw = IMG_SIZE // ps, IMG_SIZE // ps  # patch grid (rows, cols)
# Number of singular values per patch; assumed to equal the model's
# out['svd']['S'].shape[-1] — TODO confirm if the model config changes.
D = 4

transform = T.Compose([T.Resize(IMG_SIZE), T.ToTensor()])
cifar_train = torchvision.datasets.CIFAR10(
    root='/content/data', train=True, download=True, transform=transform)
cifar_test = torchvision.datasets.CIFAR10(
    root='/content/data', train=False, download=True, transform=transform)

# shuffle=False: these loaders only feed the one-off precompute pass.
train_loader = torch.utils.data.DataLoader(
    cifar_train, batch_size=128, shuffle=False, num_workers=4)
test_loader = torch.utils.data.DataLoader(
    cifar_test, batch_size=128, shuffle=False, num_workers=4)

conduit = FLEighConduit().to(device)

CLASSES = ['airplane', 'auto', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']
# ═══════════════════════════════════════════════════════════════
# PRECOMPUTE ALL CONDUIT MAPS
# ═══════════════════════════════════════════════════════════════

def extract_conduit_maps(loader, desc="Extracting"):
    """Run Freckles and the conduit over ``loader`` and stack the spatial maps.

    For every batch, the model's per-patch SVD factors (``S``, ``Vt``) are
    turned into per-patch Gram matrices ``G = V · diag(S²) · Vᵀ`` and fed
    through the conduit. The per-patch reconstruction MSE is computed
    alongside.

    Args:
        loader: iterable of ``(images, labels)`` batches.
        desc: progress-bar label.

    Returns:
        dict of CPU tensors stacked over all images, where ``n`` is the number
        of images and ``d`` the number of singular values per patch:
            'S':        (n, gh, gw, d) singular values from ``out['svd']['S']``
            'friction': (n, gh, gw, d) conduit ``packet.friction``
            'settle':   (n, gh, gw, d) conduit ``packet.settle``
            'error':    (n, gh, gw, 1) per-patch reconstruction MSE
            'labels':   (n,)           labels as yielded by ``loader``

    Raises:
        ValueError: if the model's patch count per image is not ``gh * gw``.
    """
    all_S, all_fric, all_settle, all_error, all_labels = [], [], [], [], []

    with torch.no_grad():
        for images, labels in tqdm(loader, desc=desc):
            images_gpu = images.to(device)
            out = freckles(images_gpu)
            recon = out['recon']
            S = out['svd']['S']    # (B, N, d)
            Vt = out['svd']['Vt']  # (B, N, d, d)
            # Take d from the data rather than the global D so the maps follow
            # whatever rank the model actually produces.
            B_img, N, d = S.shape
            if N != gh * gw:
                raise ValueError(
                    f"model produced {N} patches per image, expected "
                    f"gh * gw = {gh} * {gw} = {gh * gw}")

            # Per-patch reconstruction error: mean over the last axis of the
            # extracted patches (assumed (B, N, patch_dim) — see reshape below).
            inp_p, _, _ = extract_patches(images_gpu, ps)
            rec_p, _, _ = extract_patches(recon, ps)
            patch_mse = (inp_p - rec_p).pow(2).mean(dim=-1)  # (B, N)

            # Gram matrices G = V · diag(S²) · Vᵀ with V = Vtᵀ, one d×d matrix
            # per patch, flattened into a single batch for the conduit.
            S2 = S.pow(2)
            G = torch.einsum('bnij,bnj,bnjk->bnik', Vt.transpose(-2, -1), S2, Vt)
            packet = conduit(G.reshape(B_img * N, d, d))

            # Back onto the spatial grid; move to CPU so GPU memory stays flat
            # across the whole dataset.
            all_S.append(S.reshape(B_img, gh, gw, d).cpu())
            all_fric.append(packet.friction.reshape(B_img, gh, gw, d).cpu())
            all_settle.append(packet.settle.reshape(B_img, gh, gw, d).cpu())
            all_error.append(patch_mse.reshape(B_img, gh, gw, 1).cpu())
            all_labels.append(labels)

    return {
        'S': torch.cat(all_S),              # (n, gh, gw, d)
        'friction': torch.cat(all_fric),    # (n, gh, gw, d)
        'settle': torch.cat(all_settle),    # (n, gh, gw, d)
        'error': torch.cat(all_error),      # (n, gh, gw, 1)
        'labels': torch.cat(all_labels),    # (n,)
    }


print("\nPrecomputing train set...")
train_data = extract_conduit_maps(train_loader, "Train")
print(f"  Train: {len(train_data['labels'])} images")

print("Precomputing test set...")
test_data = extract_conduit_maps(test_loader, "Test")
print(f"  Test:  {len(test_data['labels'])} images")
# ═══════════════════════════════════════════════════════════════
# CONV CLASSIFIER — reads spatial maps directly
# ═══════════════════════════════════════════════════════════════

class SpatialConvClassifier(nn.Module):
    """Small conv readout over a (C, gh, gw) stack of conduit maps.

    Downsampling is done by strided convolutions; the only pooling is the
    final global average before the MLP head. Spatial sizes in the comments
    assume a 16×16 grid.
    """

    def __init__(self, in_channels, n_classes=10):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 64, 3, stride=2, padding=1),   # 16→8
            nn.GELU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),           # 8→4
            nn.GELU(),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),          # 4→4
            nn.GELU(),
            nn.AdaptiveAvgPool2d(1),                              # 4→1
        )
        self.head = nn.Sequential(
            nn.Linear(128, 64),
            nn.GELU(),
            nn.Linear(64, n_classes),
        )

    def forward(self, x):
        """Map ``x`` of shape (B, C, H, W) to logits of shape (B, n_classes)."""
        h = self.conv(x).flatten(1)  # (B, 128, 1, 1) -> (B, 128)
        return self.head(h)


class ConduitDataset(torch.utils.data.Dataset):
    """Serve a selected stack of precomputed conduit maps as (C, gh, gw) tensors.

    ``channels`` is a string of signal codes. Each code present contributes
    its maps, always stacked in the order S, F, T, E regardless of the order
    in the string:
        'S' -> data['S'], 'F' -> data['friction'],
        'T' -> data['settle'], 'E' -> data['error'].

    The signals live on very different numeric scales (e.g. eigenvalues vs a
    per-patch MSE). Without rescaling, small-scale channels barely move the
    first conv layer, so by default every channel is standardized to zero mean
    and unit variance. Pass ``stats=(mean, std)`` taken from the training
    dataset when building an evaluation dataset, so both splits share one
    transform.

    Raises:
        ValueError: if ``channels`` selects no known signal.
    """

    # (code, key in the precomputed data dict), in stacking order.
    CHANNEL_SOURCES = (('S', 'S'), ('F', 'friction'), ('T', 'settle'), ('E', 'error'))

    def __init__(self, data, channels='S', augment=False, normalize=True, stats=None):
        self.labels = data['labels']
        self.augment = augment

        # (N, gh, gw, c) -> (N, c, gh, gw) for each selected signal.
        parts = [data[key].permute(0, 3, 1, 2)
                 for code, key in self.CHANNEL_SOURCES if code in channels]
        if not parts:
            raise ValueError(
                f"channels={channels!r} selects no signals; use any of 'S', 'F', 'T', 'E'")
        # torch.cat always allocates a fresh tensor, so the in-place
        # standardization below cannot alias the caller's data dict.
        maps = torch.cat(parts, dim=1)  # (N, total_C, gh, gw)

        self.mean = self.std = None
        if normalize:
            if stats is None:
                std, mean = torch.std_mean(maps, dim=(0, 2, 3), keepdim=True)
                stats = (mean, std.clamp_min(1e-8))  # guard constant channels
            self.mean, self.std = stats
            maps.sub_(self.mean).div_(self.std)

        self.maps = maps
        self.n_channels = maps.shape[1]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        x = self.maps[idx]
        label = self.labels[idx]
        # NOTE(review): mirroring the patch grid assumes each per-patch feature
        # is unchanged when the patch content itself is mirrored — confirm for
        # friction/settle.
        if self.augment and torch.rand(1).item() > 0.5:
            x = x.flip(-1)  # horizontal flip
        return x, label


# ═══════════════════════════════════════════════════════════════
# TRAINING LOOP
# ═══════════════════════════════════════════════════════════════

def train_and_eval(channels, name, epochs=30, batch_size=128, lr=3e-4):
    """Train a SpatialConvClassifier on the selected conduit channels.

    Channels are standardized with statistics from the training split only,
    and the same transform is applied to the test split.

    Args:
        channels: signal codes understood by ConduitDataset (e.g. 'SFET').
        name: human-readable label for the printed report.
        epochs, batch_size, lr: Adam + cosine-annealing schedule settings.

    Returns:
        (best_acc, n_params): the highest test accuracy over all epochs and the
        model's parameter count. NOTE: best_acc picks the epoch on the test set
        itself, so it is optimistic; the final-epoch accuracy is printed too.
    """
    n_classes = len(CLASSES)
    train_ds = ConduitDataset(train_data, channels, augment=True)
    test_ds = ConduitDataset(test_data, channels, augment=False,
                             stats=(train_ds.mean, train_ds.std))
    n_ch = train_ds.n_channels

    pin = device.type == 'cuda'  # pinned memory only helps host->GPU copies
    tr_loader = torch.utils.data.DataLoader(
        train_ds, batch_size=batch_size, shuffle=True,
        num_workers=4, pin_memory=pin, drop_last=True)
    te_loader = torch.utils.data.DataLoader(
        test_ds, batch_size=batch_size, shuffle=False,
        num_workers=4, pin_memory=pin)

    model = SpatialConvClassifier(n_ch, n_classes).to(device)
    n_params = sum(p.numel() for p in model.parameters())
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)

    best_acc = 0.0
    best_epoch = 0
    best_pca = torch.zeros(n_classes)
    test_acc = 0.0
    t0 = time.time()

    for epoch in range(1, epochs + 1):
        model.train()
        correct, total = 0, 0
        for x, y in tr_loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = F.cross_entropy(logits, y)
            opt.zero_grad()
            loss.backward()
            opt.step()
            correct += (logits.argmax(-1) == y).sum().item()
            total += len(y)
        sched.step()
        train_acc = correct / max(total, 1)

        # Test: accumulate per-class hit/total counts on-device.
        model.eval()
        class_correct = torch.zeros(n_classes, dtype=torch.long, device=device)
        class_total = torch.zeros_like(class_correct)
        with torch.no_grad():
            for x, y in te_loader:
                x, y = x.to(device), y.to(device)
                preds = model(x).argmax(-1)
                class_total += torch.bincount(y, minlength=n_classes)
                class_correct += torch.bincount(y[preds == y], minlength=n_classes)

        test_acc = class_correct.sum().item() / max(class_total.sum().item(), 1)
        pca = (class_correct.float() / class_total.clamp_min(1).float()).cpu()
        if test_acc > best_acc:
            # Keep the per-class breakdown of the same epoch as best_acc.
            best_acc, best_epoch, best_pca = test_acc, epoch, pca

        if epoch % 5 == 0 or epoch == epochs:
            print(f"  ep{epoch:3d}  train={train_acc:.1%}  test={test_acc:.1%}")

    elapsed = time.time() - t0
    print(f"\n  {name}")
    print(f"  Channels: {n_ch}, Params: {n_params:,}, Time: {elapsed:.0f}s")
    print(f"  Best test: {best_acc:.1%}")
    print(f"  Final test: {test_acc:.1%}")
    print(f"\n  Per-class accuracy at best epoch ({best_epoch}):")
    print(f"  {'Class':<10s} {'Acc':>6s}")
    print(f"  {'-' * 22}")
    for cls_name, cls_acc in zip(CLASSES, best_pca.tolist()):
        bar = '█' * int(cls_acc * 20)
        print(f"  {cls_name:<10s} {cls_acc:5.1%}  {bar}")
    print()

    return best_acc, n_params


# ═══════════════════════════════════════════════════════════════
# RUN ALL CONFIGURATIONS
# ═══════════════════════════════════════════════════════════════

print("\n" + "=" * 70)
print("  SPATIAL CONV READOUT — All conduit configurations")
print("=" * 70)

results = {}
configs = [
    ('S',    "Eigenvalues (S) only — 4ch"),
    ('F',    "Friction only — 4ch"),
    ('E',    "Release error only — 1ch"),
    ('T',    "Settle only — 4ch"),
    ('SF',   "S + Friction — 8ch"),
    ('SE',   "S + Release error — 5ch"),
    ('SFE',  "S + Friction + Release — 9ch"),
    ('SFET', "FULL CONDUIT — 13ch"),
]

for channels, name in configs:
    print(f"\n{'─' * 70}")
    print(f"  Training: {name}")
    print(f"{'─' * 70}")
    acc, params = train_and_eval(channels, name)
    results[channels] = (acc, params, name)


# ═══════════════════════════════════════════════════════════════
# SCOREBOARD
# ═══════════════════════════════════════════════════════════════

print(f"\n{'=' * 70}")
print("  SCOREBOARD — Spatial Conv Readout")
print("=" * 70)
print(f"\n  {'Configuration':<35s} {'Channels':>8s} {'Params':>10s} {'Test Acc':>9s}")
print(f"  {'-' * 65}")
print(f"  {'Chance':<35s} {'—':>8s} {'—':>10s} {'10.0%':>9s}")

# Ascending by best test accuracy; width 9 matches the 'Test Acc' header.
for channels, (acc, params, name) in sorted(results.items(), key=lambda kv: kv[1][0]):
    print(f"  {name:<35s} {channels:>8s} {params:>10,d} {acc:>9.1%}")

# Reference results from earlier experiments
print(f"\n  {'--- REFERENCE (from earlier) ---':<35s}")
print(f"  {'Linear probe (friction flat)':<35s} {'—':>8s} {'—':>10s} {'24.3%':>9s}")
print(f"  {'Linear probe (S flat)':<35s} {'—':>8s} {'—':>10s} {'21.0%':>9s}")
print(f"  {'Patchwork + calibrated embeds':<35s} {'—':>8s} {'530K':>10s} {'48.0%':>9s}")
print(f"  {'Scatter + conv (raw S)':<35s} {'—':>8s} {'2.9M':>10s} {'70.5%':>9s}")
print(f"  {'CNN condensed (SGD)':<35s} {'—':>8s} {'730K':>10s} {'74.7%':>9s}")

# Lift analysis
s_acc = results.get('S', (0, 0, ''))[0]
_, (best_acc, _, best_name) = max(results.items(), key=lambda kv: kv[1][0])
print(f"\n  S-only conv:     {s_acc:.1%}")
print(f"  Best conduit:    {best_acc:.1%} ({best_name})")
print(f"  Conduit lift:    {(best_acc - s_acc) * 100:+.1f}pp")
print(f"  vs scatter+conv reference: {(best_acc - 0.705) * 100:+.1f}pp")