| """ |
Cell 6 — Fresnel Spatial Conv Readout
=======================================
Fresnel v50 — trained on CLEAN ImageNet-64, with no input noise.
Its SVD learned a real structural decomposition of image patches.

The same 8 conduit configurations are read out by a conv net on the 16×16
patch grid. The maps are not pooled or flattened before the readout: the conv
stack consumes the spatial grid directly (it only global-average-pools after
its last conv), so the readout respects the geometric structure.

CRITICAL DIFFERENCE FROM FRECKLES:
Freckles learned noise-reconstruction features.
Fresnel learned a clean-image structural decomposition.
The SVD elements from Fresnel actually encode learned relational behavior.
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import numpy as np |
| import time |
| from tqdm import tqdm |
|
|
# Prefer the GPU, but fall back to CPU so the cell can still run (slowly) on a
# machine without CUDA instead of failing at the first `.to(device)` call.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
| |
| |
| |
|
|
# Checkpoint tag passed as `hf_version` to geolip_svae.load_model.
FRESNEL_VERSION: str = 'v50_fresnel_64'
# Square input resolution: CIFAR-10 images are resized to this before Fresnel,
# and the patch grid is IMG_SIZE // patch_size on each side.
IMG_SIZE: int = 64
|
|
| |
| |
| |
|
|
print(f"Loading Fresnel ({FRESNEL_VERSION}) + CIFAR-10...")
from geolip_svae import load_model
from geolip_svae.model import extract_patches
import torchvision
import torchvision.transforms as T
from geolip_core.linalg.conduit import FLEighConduit


fresnel, cfg = load_model(hf_version=FRESNEL_VERSION, device=device)
# Fresnel is only used for inference (under no_grad) in this cell; never trained.
fresnel.eval()

# Patch grid dimensions. Assumes square ps x ps patches tiling the square
# IMG_SIZE input — TODO confirm against extract_patches.
ps = fresnel.patch_size
if IMG_SIZE % ps:
    raise ValueError(
        f"IMG_SIZE={IMG_SIZE} is not a multiple of Fresnel's patch size {ps}; "
        "the per-patch maps cannot be laid out on a regular grid")
gh, gw = IMG_SIZE // ps, IMG_SIZE // ps

# Number of S values per patch. The config may be a plain dict or an
# attribute-style object; a non-dict config is read via getattr instead of
# being ignored. 4 is used only when the config does not provide D.
if isinstance(cfg, dict):
    D = cfg.get('D', 4)
else:
    D = getattr(cfg, 'D', 4)


print(f" Patch size: {ps}, Grid: {gh}x{gw}, D={D}")
print(f" Params: {sum(p.numel() for p in fresnel.parameters()):,}")
|
|
# Resize(int) scales the shorter side to IMG_SIZE, then convert to a [0, 1] tensor.
transform = T.Compose([
    T.Resize(IMG_SIZE),
    T.ToTensor(),
])

# Train split first, then test split (same download/transform settings).
cifar_train, cifar_test = (
    torchvision.datasets.CIFAR10(
        root='/content/data', train=is_train_split, download=True,
        transform=transform)
    for is_train_split in (True, False)
)

# Both loaders are unshuffled: in this cell they only feed the one-off feature
# precomputation; shuffling for the readout happens later over the cached maps.
train_loader, test_loader = (
    torch.utils.data.DataLoader(
        split_ds, batch_size=128, shuffle=False, num_workers=4)
    for split_ds in (cifar_train, cifar_test)
)
|
|
# Conduit module applied per patch to the Gram matrices built in
# extract_conduit_maps; it lives on the same device as Fresnel.
conduit = FLEighConduit()
conduit = conduit.to(device)


# Display names for label indices 0-9, used for per-class report rows.
CLASSES = [
    'airplane',
    'auto',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
]
|
|
| |
print("\nFresnel S-value profile on CIFAR-10 sample:")
with torch.no_grad():
    # First 16 images of the first (unshuffled) test batch.
    batch_images, _ = next(iter(test_loader))
    sample = batch_images[:16].to(device)
    out = fresnel(sample)
    S = out['svd']['S']
    # S is indexed (image, patch, slot); reducing dims 0 and 1 leaves one
    # statistic per S slot.
    slot_mean = S.mean(dim=(0, 1))
    slot_std = S.std(dim=(0, 1))
    print(f" S mean: {slot_mean.tolist()}")
    print(f" S std: {slot_std.tolist()}")
    print(f" MSE: {F.mse_loss(out['recon'], sample):.6f}")
|
|
|
|
| |
| |
| |
|
|
def extract_conduit_maps(loader, desc="Extracting"):
    """Run Fresnel and the conduit over ``loader`` and cache per-patch spatial maps.

    For every batch this computes, laid out on the ``gh x gw`` patch grid:

    * ``S``        -- the model's SVD ``S`` values, shape (B, gh, gw, d)
    * ``friction`` -- ``conduit(G).friction`` reshaped to (B, gh, gw, d)
    * ``settle``   -- ``conduit(G).settle`` reshaped to (B, gh, gw, d)
    * ``error``    -- per-patch reconstruction MSE, shape (B, gh, gw, 1)

    where ``G = V diag(S**2) V^T`` is rebuilt per patch from the SVD factors
    and ``d`` is the size of the last dimension of ``S``.

    Args:
        loader: iterable of ``(images, labels)`` batches.
        desc: label for the tqdm progress bar.

    Returns:
        dict with keys 'S', 'friction', 'settle', 'error' and 'labels', each
        concatenated over all batches and stored on the CPU.

    Raises:
        ValueError: if the number of patches per image is not ``gh * gw``.
    """
    all_S, all_fric, all_settle, all_error, all_labels = [], [], [], [], []

    for images, labels in tqdm(loader, desc=desc):
        with torch.no_grad():
            images_gpu = images.to(device)
            out = fresnel(images_gpu)
            recon = out['recon']
            S = out['svd']['S']
            Vt = out['svd']['Vt']
            # Take the number of S values from the tensor itself rather than
            # the module-level D, which is a config lookup defaulting to 4.
            B_img, N, d = S.shape
            if N != gh * gw:
                raise ValueError(
                    f"Fresnel returned {N} patches per image, expected "
                    f"{gh}x{gw}={gh * gw}")

            # Per-patch reconstruction error. Assumes extract_patches returns
            # (B, N, patch_dim) as its first output — TODO confirm.
            inp_p, _, _ = extract_patches(images_gpu, ps)
            rec_p, _, _ = extract_patches(recon, ps)
            patch_mse = (inp_p - rec_p).pow(2).mean(dim=-1)

            # G = V diag(S^2) V^T per patch (transposing Vt gives V).
            S2 = S.pow(2)
            G = torch.einsum('bnij,bnj,bnjk->bnik',
                             Vt.transpose(-2, -1), S2, Vt)
            packet = conduit(G.reshape(B_img * N, d, d))

        all_S.append(S.reshape(B_img, gh, gw, d).cpu())
        all_fric.append(packet.friction.reshape(B_img, gh, gw, d).cpu())
        all_settle.append(packet.settle.reshape(B_img, gh, gw, d).cpu())
        all_error.append(patch_mse.reshape(B_img, gh, gw, 1).cpu())
        all_labels.append(labels)

    return {
        'S': torch.cat(all_S),
        'friction': torch.cat(all_fric),
        'settle': torch.cat(all_settle),
        'error': torch.cat(all_error),
        'labels': torch.cat(all_labels),
    }
|
|
# Cache the maps for both splits once; every readout configuration below
# trains and evaluates on these CPU tensors.
print("\nPrecomputing train set...")
train_data = extract_conduit_maps(train_loader, "Train")
print(f" Train: {len(train_data['labels'])} images")


print("Precomputing test set...")
test_data = extract_conduit_maps(test_loader, "Test")
print(f" Test: {len(test_data['labels'])} images")


# Global summary statistics (all images, patches and slots) per map type.
print(f"\n Fresnel signal profile:")
for signal_name in ('S', 'friction', 'settle', 'error'):
    values = train_data[signal_name].flatten()
    print(f" {signal_name:10s}: mean={values.mean():.4f} std={values.std():.4f} "
          f"min={values.min():.4f} max={values.max():.4f}")
|
|
|
|
| |
| |
| |
|
|
class SpatialConvClassifier(nn.Module):
    """Small conv readout over a (C, H, W) stack of per-patch conduit maps.

    Two stride-2 3x3 convs downsample the grid, a stride-1 conv mixes
    features, then global average pooling and a two-layer MLP give logits.

    Args:
        in_channels: number of stacked map channels C.
        n_classes: number of output logits.
    """

    def __init__(self, in_channels, n_classes=10):
        super().__init__()
        width, hidden = 128, 64
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, stride=2, padding=1),
            nn.GELU(),
            nn.Conv2d(64, width, kernel_size=3, stride=2, padding=1),
            nn.GELU(),
            nn.Conv2d(width, width, kernel_size=3, stride=1, padding=1),
            nn.GELU(),
            nn.AdaptiveAvgPool2d(1),
        )
        self.head = nn.Sequential(
            nn.Linear(width, hidden),
            nn.GELU(),
            nn.Linear(hidden, n_classes),
        )

    def forward(self, x):
        # conv output is (B, 128, 1, 1); drop the two singleton spatial dims.
        pooled = self.conv(x)
        return self.head(pooled.flatten(start_dim=1))
|
|
|
|
class ConduitDataset(torch.utils.data.Dataset):
    """Channel-stacked view of precomputed conduit maps for one split.

    Args:
        data: dict as returned by extract_conduit_maps: 'S', 'friction',
            'settle' and 'error' maps laid out as (N, H, W, C), plus 'labels'.
        channels: string of selector letters — S = SVD S values,
            F = friction, T = settle, E = release (reconstruction) error.
            Letter order and repeats are ignored; selected groups are always
            stacked in the order S, F, T, E.
        augment: if True, each sample is flipped along its last (width) axis
            with probability ~0.5.

    Raises:
        ValueError: if ``channels`` selects none of S/F/T/E.
    """

    # (selector letter, key in ``data``) in stacking order.
    _CHANNEL_KEYS = (('S', 'S'), ('F', 'friction'), ('T', 'settle'), ('E', 'error'))

    def __init__(self, data, channels='S', augment=False):
        self.labels = data['labels']
        self.augment = augment
        parts = [
            data[key].permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W) for Conv2d
            for letter, key in self._CHANNEL_KEYS
            if letter in channels
        ]
        if not parts:
            # torch.cat([]) would fail with an opaque error; say what is wrong.
            raise ValueError(
                f"channels={channels!r} selects no maps; use letters from 'SFTE'")
        self.maps = torch.cat(parts, dim=1)
        self.n_channels = self.maps.shape[1]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        x = self.maps[idx]
        if self.augment and torch.rand(1).item() > 0.5:
            x = x.flip(-1)
        return x, self.labels[idx]
|
|
|
|
| |
| |
| |
|
|
def _train_one_epoch(model, loader, opt):
    """Run one optimisation pass over ``loader``; return the running train accuracy."""
    model.train()
    correct, total = 0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        loss = F.cross_entropy(logits, y)
        opt.zero_grad()
        loss.backward()
        opt.step()
        # Scored on the pre-step logits, so this is a running estimate.
        correct += (logits.argmax(-1) == y).sum().item()
        total += len(y)
    return correct / max(total, 1)


def _evaluate(model, loader, n_classes):
    """Return (accuracy, per-class correct counts, per-class totals) on ``loader``."""
    model.eval()
    correct_pc = torch.zeros(n_classes, device=device)
    total_pc = torch.zeros(n_classes, device=device)
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            preds = model(x).argmax(-1)
            # bincount avoids a Python loop with one .item() sync per class.
            correct_pc += torch.bincount(y[preds == y], minlength=n_classes)
            total_pc += torch.bincount(y, minlength=n_classes)
    correct_pc, total_pc = correct_pc.cpu(), total_pc.cpu()
    acc = correct_pc.sum().item() / max(total_pc.sum().item(), 1)
    return acc, correct_pc, total_pc


def _print_report(name, n_ch, n_params, elapsed, best_acc, best_epoch, per_class_acc):
    """Print the summary and per-class accuracy bars for one configuration."""
    print(f"\n {name}")
    print(f" Channels: {n_ch}, Params: {n_params:,}, Time: {elapsed:.0f}s")
    print(f" Best test: {best_acc:.1%} (epoch {best_epoch})")
    print(f"\n {'Class':<10s} {'Acc':>6s}")
    print(f" {'-' * 22}")
    for cls_name, cls_acc in zip(CLASSES, per_class_acc.tolist()):
        bar = '█' * int(cls_acc * 20)
        print(f" {cls_name:<10s} {cls_acc:5.1%} {bar}")
    print()


def train_and_eval(channels, name, epochs=30, batch_size=128, lr=3e-4):
    """Train a SpatialConvClassifier on one channel configuration and report it.

    Args:
        channels: selector letters passed to ConduitDataset (e.g. 'SFE').
        name: display name used in the printed report.
        epochs: number of training epochs; the cosine LR schedule spans them.
        batch_size: mini-batch size for both loaders.
        lr: Adam learning rate.

    Returns:
        (best_acc, n_params): the highest end-of-epoch test accuracy and the
        classifier's parameter count.

    Note:
        The best epoch is selected on the test set itself (there is no
        validation split), so ``best_acc`` is an optimistic estimate. The
        per-class table is printed for that same epoch so it matches.

    Raises:
        ValueError: if ``epochs`` < 1.
    """
    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")
    n_classes = len(CLASSES)

    train_ds = ConduitDataset(train_data, channels, augment=True)
    test_ds = ConduitDataset(test_data, channels, augment=False)
    n_ch = train_ds.n_channels

    pin = device.type == 'cuda'  # pinned host memory only helps GPU transfers
    tr_loader = torch.utils.data.DataLoader(
        train_ds, batch_size=batch_size, shuffle=True,
        num_workers=4, pin_memory=pin, drop_last=True)
    te_loader = torch.utils.data.DataLoader(
        test_ds, batch_size=batch_size, shuffle=False,
        num_workers=4, pin_memory=pin)

    model = SpatialConvClassifier(n_ch, n_classes).to(device)
    n_params = sum(p.numel() for p in model.parameters())
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)

    best_acc, best_epoch = 0.0, 0
    best_correct = best_total = None
    t0 = time.time()

    for epoch in range(1, epochs + 1):
        train_acc = _train_one_epoch(model, tr_loader, opt)
        sched.step()
        test_acc, correct_pc, total_pc = _evaluate(model, te_loader, n_classes)

        # Snapshot per-class counts at the best epoch so the report agrees
        # with best_acc instead of showing the last epoch's breakdown.
        if best_correct is None or test_acc > best_acc:
            best_acc, best_epoch = test_acc, epoch
            best_correct, best_total = correct_pc, total_pc

        if epoch % 5 == 0 or epoch == epochs:
            print(f" ep{epoch:3d} train={train_acc:.1%} test={test_acc:.1%}")

    elapsed = time.time() - t0
    per_class_acc = best_correct / (best_total + 1e-8)
    _print_report(name, n_ch, n_params, elapsed, best_acc, best_epoch, per_class_acc)
    return best_acc, n_params
|
|
|
|
| |
| |
| |
|
|
print("\n" + "=" * 70)
print(" FRESNEL — Spatial Conv Readout — All Conduit Configurations")
print("=" * 70)


results = {}


# (selector letters, display name). Letters pick map groups in ConduitDataset:
# S = SVD S values, F = friction, T = settle, E = release (reconstruction) error.
configs = [
    ('S', "Eigenvalues (S) only — 4ch"),
    ('F', "Friction only — 4ch"),
    ('E', "Release error only — 1ch"),
    ('T', "Settle only — 4ch"),
    ('SF', "S + Friction — 8ch"),
    ('SE', "S + Release error — 5ch"),
    ('SFE', "S + Friction + Release — 9ch"),
    ('SFET', "FULL CONDUIT — 13ch"),
]


for channels, name in configs:
    print(f"\n{'─' * 70}")
    print(f" Training: {name}")
    print(f"{'─' * 70}")
    acc, params = train_and_eval(channels, name)
    # Keyed by selector letters; the scoreboard reads (acc, params, name).
    results[channels] = (acc, params, name)
|
|
|
|
| |
| |
| |
|
|
print(f"\n{'=' * 70}")
print(f" SCOREBOARD — Fresnel ({FRESNEL_VERSION}) Spatial Conv Readout")
print("=" * 70)


# Channel width contributed by each selector letter, read from the cached maps
# rather than hard-coded, so the "Ch" column matches what ConduitDataset stacks.
channel_widths = {
    letter: train_data[key].shape[-1]
    for letter, key in (('S', 'S'), ('F', 'friction'), ('T', 'settle'), ('E', 'error'))
}
chance = f"{1 / len(CLASSES):.1%}"


print(f"\n {'Configuration':<35s} {'Ch':>4s} {'Params':>10s} {'Test Acc':>9s}")
print(f" {'-' * 62}")
print(f" {'Chance':<35s} {'—':>4s} {'—':>10s} {chance:>9s}")


# Ascending by accuracy, so the strongest configuration is listed last.
for channels, (acc, params, name) in sorted(results.items(), key=lambda item: item[1][0]):
    n_ch = sum(width for letter, width in channel_widths.items() if letter in channels)
    # Width 9 matches the "Test Acc" header and the chance/reference rows.
    print(f" {name:<35s} {n_ch:>4d} {params:>10,d} {acc:>9.1%}")


print(f"\n {'--- FRECKLES REFERENCE ---':<35s}")
print(f" {'Scatter + conv (Freckles S)':<35s} {'4':>4s} {'2.9M':>10s} {'70.5%':>9s}")


s_acc = results.get('S', (0, 0, ''))[0]
best_ch, (best_acc, _, best_name) = max(results.items(), key=lambda item: item[1][0])
print(f"\n Fresnel S-only: {s_acc:.1%}")
print(f" Best conduit: {best_acc:.1%} ({best_name})")
print(f" Conduit lift: {(best_acc - s_acc) * 100:+.1f}pp")
print("\n KEY QUESTION: Does Fresnel's clean training produce")
print(" conduit signals that Freckles' noise training could not?")