""" Cell 6 — Fresnel Spatial Conv Readout ======================================= Fresnel v50 — trained on CLEAN ImageNet-64. No noise. The SVD learned real structural decomposition. Same 8 conduit configurations through conv on the 16×16 grid. No pooling. No flattening. Spatial readout respects geometric structure. CRITICAL DIFFERENCE FROM FRECKLES: Freckles learned noise reconstruction features. Fresnel learned clean image structural decomposition. The SVD elements from Fresnel actually encode learned relational behavior. """ import torch import torch.nn as nn import torch.nn.functional as F import numpy as np import time from tqdm import tqdm device = torch.device('cuda') # ═══════════════════════════════════════════════════════════════ # CONFIG — Set the Fresnel version here # ═══════════════════════════════════════════════════════════════ FRESNEL_VERSION = 'v50_fresnel_64' # adjust if different checkpoint IMG_SIZE = 64 # ═══════════════════════════════════════════════════════════════ # LOAD FRESNEL + CIFAR-10 # ═══════════════════════════════════════════════════════════════ print(f"Loading Fresnel ({FRESNEL_VERSION}) + CIFAR-10...") from geolip_svae import load_model from geolip_svae.model import extract_patches import torchvision import torchvision.transforms as T from geolip_core.linalg.conduit import FLEighConduit fresnel, cfg = load_model(hf_version=FRESNEL_VERSION, device=device) fresnel.eval() ps = fresnel.patch_size gh, gw = IMG_SIZE // ps, IMG_SIZE // ps D = cfg.get('D', 4) if isinstance(cfg, dict) else 4 print(f" Patch size: {ps}, Grid: {gh}x{gw}, D={D}") print(f" Params: {sum(p.numel() for p in fresnel.parameters()):,}") transform = T.Compose([T.Resize(IMG_SIZE), T.ToTensor()]) cifar_train = torchvision.datasets.CIFAR10( root='/content/data', train=True, download=True, transform=transform) cifar_test = torchvision.datasets.CIFAR10( root='/content/data', train=False, download=True, transform=transform) train_loader = torch.utils.data.DataLoader( 
cifar_train, batch_size=128, shuffle=False, num_workers=4) test_loader = torch.utils.data.DataLoader( cifar_test, batch_size=128, shuffle=False, num_workers=4) conduit = FLEighConduit().to(device) CLASSES = ['airplane', 'auto', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'] # Quick S statistics from Fresnel print("\nFresnel S-value profile on CIFAR-10 sample:") with torch.no_grad(): sample = next(iter(test_loader))[0][:16].to(device) out = fresnel(sample) S = out['svd']['S'] print(f" S mean: {S.mean(dim=(0,1)).tolist()}") print(f" S std: {S.std(dim=(0,1)).tolist()}") print(f" MSE: {F.mse_loss(out['recon'], sample):.6f}") # ═══════════════════════════════════════════════════════════════ # PRECOMPUTE ALL CONDUIT MAPS # ═══════════════════════════════════════════════════════════════ def extract_conduit_maps(loader, desc="Extracting"): all_S = [] all_fric = [] all_settle = [] all_error = [] all_labels = [] for images, labels in tqdm(loader, desc=desc): with torch.no_grad(): images_gpu = images.to(device) out = fresnel(images_gpu) recon = out['recon'] S = out['svd']['S'] Vt = out['svd']['Vt'] B_img, N, _ = S.shape # Per-patch recon error inp_p, _, _ = extract_patches(images_gpu, ps) rec_p, _, _ = extract_patches(recon, ps) patch_mse = (inp_p - rec_p).pow(2).mean(dim=-1) # Gram matrices for conduit S2 = S.pow(2) G = torch.einsum('bnij,bnj,bnjk->bnik', Vt.transpose(-2, -1), S2, Vt) G_flat = G.reshape(B_img * N, D, D) packet = conduit(G_flat) all_S.append(S.reshape(B_img, gh, gw, D).cpu()) all_fric.append(packet.friction.reshape(B_img, gh, gw, D).cpu()) all_settle.append(packet.settle.reshape(B_img, gh, gw, D).cpu()) all_error.append(patch_mse.reshape(B_img, gh, gw, 1).cpu()) all_labels.append(labels) return { 'S': torch.cat(all_S), 'friction': torch.cat(all_fric), 'settle': torch.cat(all_settle), 'error': torch.cat(all_error), 'labels': torch.cat(all_labels), } print("\nPrecomputing train set...") train_data = extract_conduit_maps(train_loader, "Train") 
print(f" Train: {len(train_data['labels'])} images")
print("Precomputing test set...")
test_data = extract_conduit_maps(test_loader, "Test")
print(f" Test: {len(test_data['labels'])} images")

# Signal profile: whole-tensor statistics of each precomputed map type.
print("\n Fresnel signal profile:")
for key in ['S', 'friction', 'settle', 'error']:
    t = train_data[key]
    print(f" {key:10s}: mean={t.mean():.4f} std={t.std():.4f} "
          f"min={t.min():.4f} max={t.max():.4f}")

# ═══════════════════════════════════════════════════════════════
# CONV CLASSIFIER
# ═══════════════════════════════════════════════════════════════


class SpatialConvClassifier(nn.Module):
    """Small conv readout over a ``(B, in_channels, H, W)`` conduit map.

    Two stride-2 3x3 convs downsample the grid (16x16 -> 8x8 -> 4x4), a
    stride-1 conv mixes features, then global average pooling feeds a
    two-layer MLP head producing ``n_classes`` logits.
    """

    def __init__(self, in_channels, n_classes=10):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 64, 3, stride=2, padding=1), nn.GELU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1), nn.GELU(),
            nn.Conv2d(128, 128, 3, stride=1, padding=1), nn.GELU(),
            nn.AdaptiveAvgPool2d(1),
        )
        self.head = nn.Sequential(
            nn.Linear(128, 64), nn.GELU(),
            nn.Linear(64, n_classes),
        )

    def forward(self, x):
        # (B, 128, 1, 1) -> (B, 128)
        h = self.conv(x).flatten(1)
        return self.head(h)


class ConduitDataset(torch.utils.data.Dataset):
    """Channel-stacked conduit maps as ``(C, gh, gw)`` tensors plus labels.

    ``channels`` is a string of codes selecting which maps to stack:
    'S' (S values), 'F' (friction), 'T' (settle), 'E' (reconstruction error).
    Maps are always concatenated in S, F, T, E order regardless of the order
    of the codes. With ``augment=True`` each sample is flipped along its last
    (width) axis with probability 0.5.
    """

    # (channel code, key in the precomputed data dict), in concatenation order.
    _CHANNEL_KEYS = (('S', 'S'), ('F', 'friction'),
                     ('T', 'settle'), ('E', 'error'))

    def __init__(self, data, channels='S', augment=False):
        self.labels = data['labels']
        self.augment = augment
        # Stored maps are (N, gh, gw, C); the conv expects (N, C, gh, gw).
        parts = [data[key].permute(0, 3, 1, 2)
                 for code, key in self._CHANNEL_KEYS if code in channels]
        if not parts:
            raise ValueError(
                f"channels={channels!r} selects no maps; "
                "use any combination of 'S', 'F', 'T', 'E'")
        self.maps = torch.cat(parts, dim=1)
        self.n_channels = self.maps.shape[1]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        x = self.maps[idx]
        if self.augment and torch.rand(1).item() > 0.5:
            x = x.flip(-1)
        return x, self.labels[idx]


# ═══════════════════════════════════════════════════════════════
# TRAINING
# ═══════════════════════════════════════════════════════════════


def _train_epoch(model, loader, opt):
    """Run one optimisation pass over *loader*; return its running train accuracy."""
    model.train()
    correct, total = 0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        loss = F.cross_entropy(logits, y)
        opt.zero_grad()
        loss.backward()
        opt.step()
        correct += (logits.argmax(-1) == y).sum().item()
        total += len(y)
    return correct / total if total else 0.0


def _evaluate(model, loader, n_classes):
    """Return ``(accuracy, per-class correct counts, per-class totals)`` on *loader*."""
    model.eval()
    class_correct = torch.zeros(n_classes)
    class_total = torch.zeros(n_classes)
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            preds = model(x).argmax(-1)
            # One bincount per batch instead of a per-class loop with
            # two device->host syncs per class.
            class_correct += torch.bincount(y[preds == y], minlength=n_classes).cpu()
            class_total += torch.bincount(y, minlength=n_classes).cpu()
    n_total = int(class_total.sum())
    acc = int(class_correct.sum()) / n_total if n_total else 0.0
    return acc, class_correct, class_total


def train_and_eval(channels, name, epochs=30,
                   batch_size=128, lr=3e-4):
    """Train a SpatialConvClassifier on the selected conduit maps and report.

    Args:
        channels: any combination of 'S', 'F', 'T', 'E' (see ConduitDataset).
        name: label printed in the report.
        epochs: number of training epochs (cosine-annealed Adam).
        batch_size: mini-batch size for both loaders.
        lr: initial Adam learning rate.

    Returns:
        ``(best_acc, n_params)``: the highest test accuracy seen across epochs
        and the classifier's parameter count. ``best_acc`` is selected on the
        test set itself, so it is optimistic; the final-epoch accuracy is
        printed alongside it.
    """
    n_classes = len(CLASSES)
    train_ds = ConduitDataset(train_data, channels, augment=True)
    test_ds = ConduitDataset(test_data, channels, augment=False)
    n_ch = train_ds.n_channels

    tr_loader = torch.utils.data.DataLoader(
        train_ds, batch_size=batch_size, shuffle=True,
        num_workers=4, pin_memory=True, drop_last=True)
    te_loader = torch.utils.data.DataLoader(
        test_ds, batch_size=batch_size, shuffle=False,
        num_workers=4, pin_memory=True)

    model = SpatialConvClassifier(n_ch, n_classes).to(device)
    n_params = sum(p.numel() for p in model.parameters())
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)

    best_acc, best_epoch, test_acc = 0.0, 0, 0.0
    best_correct = torch.zeros(n_classes)
    best_total = torch.zeros(n_classes)
    t0 = time.time()

    for epoch in range(1, epochs + 1):
        train_acc = _train_epoch(model, tr_loader, opt)
        sched.step()

        test_acc, class_correct, class_total = _evaluate(model, te_loader, n_classes)
        if test_acc > best_acc:
            best_acc, best_epoch = test_acc, epoch
            best_correct, best_total = class_correct, class_total

        if epoch % 5 == 0 or epoch == epochs:
            print(f" ep{epoch:3d} train={train_acc:.1%} test={test_acc:.1%}")

    elapsed = time.time() - t0
    # Per-class accuracy is taken at the best epoch so the table matches the
    # headline "Best test" number.
    per_class_acc = best_correct / (best_total + 1e-8)

    print(f"\n {name}")
    print(f" Channels: {n_ch}, Params: {n_params:,}, Time: {elapsed:.0f}s")
    print(f" Best test: {best_acc:.1%} (epoch {best_epoch})")
    print(f" Final test: {test_acc:.1%}")
    print(f"\n {'Class':<10s} {'Acc':>6s}")
    print(f" {'-' * 22}")
    for cls_name, acc in zip(CLASSES, per_class_acc.tolist()):
        bar = '█' * int(acc * 20)
        print(f" {cls_name:<10s} {acc:5.1%} {bar}")
    print()

    return best_acc, n_params
# ═══════════════════════════════════════════════════════════════
# RUN ALL CONFIGURATIONS
# ═══════════════════════════════════════════════════════════════

print("\n" + "=" * 70)
print(" FRESNEL — Spatial Conv Readout — All Conduit Configurations")
print("=" * 70)

results = {}

# (channel codes, display name); codes follow ConduitDataset:
# S = S values, F = friction, T = settle, E = reconstruction error.
configs = [
    ('S', "Eigenvalues (S) only — 4ch"),
    ('F', "Friction only — 4ch"),
    ('E', "Release error only — 1ch"),
    ('T', "Settle only — 4ch"),
    ('SF', "S + Friction — 8ch"),
    ('SE', "S + Release error — 5ch"),
    ('SFE', "S + Friction + Release — 9ch"),
    ('SFET', "FULL CONDUIT — 13ch"),
]

for channels, name in configs:
    print(f"\n{'─' * 70}")
    print(f" Training: {name}")
    print(f"{'─' * 70}")
    acc, params = train_and_eval(channels, name)
    results[channels] = (acc, params, name)

# ═══════════════════════════════════════════════════════════════
# SCOREBOARD
# ═══════════════════════════════════════════════════════════════

print(f"\n{'=' * 70}")
print(f" SCOREBOARD — Fresnel ({FRESNEL_VERSION}) Spatial Conv Readout")
print("=" * 70)
print(f"\n {'Configuration':<35s} {'Ch':>4s} {'Params':>10s} {'Test Acc':>9s}")
print(f" {'-' * 62}")
print(f" {'Chance':<35s} {'—':>4s} {'—':>10s} {'10.0%':>9s}")

# Ascending by accuracy, so the best configuration prints last.
for channels, (acc, params, name) in sorted(results.items(), key=lambda kv: kv[1][0]):
    # S/F/T maps contribute D channels each and E contributes 1, matching
    # how ConduitDataset stacks them (D is the model's per-patch rank).
    n_ch = sum(D if code in 'SFT' else 1 for code in channels)
    print(f" {name:<35s} {n_ch:>4d} {params:>10,d} {acc:>8.1%}")

print(f"\n {'--- FRECKLES REFERENCE ---':<35s}")
print(f" {'Scatter + conv (Freckles S)':<35s} {'4':>4s} {'2.9M':>10s} {'70.5%':>9s}")

s_acc = results.get('S', (0, 0, ''))[0]
_, (best_acc, _, best_name) = max(results.items(), key=lambda kv: kv[1][0])
print(f"\n Fresnel S-only: {s_acc:.1%}")
print(f" Best conduit: {best_acc:.1%} ({best_name})")
print(f" Conduit lift: {(best_acc - s_acc) * 100:+.1f}pp")
print("\n KEY QUESTION: Does Fresnel's clean training produce")
print(" conduit signals that Freckles' noise training could not?")