geolip-conduit-experiments / cell_6_fresnel_freckles.py
AbstractPhil's picture
Create cell_6_fresnel_freckles.py
651f747 verified
"""
Cell 6 — Fresnel Spatial Conv Readout
=======================================
Fresnel v50 — trained on CLEAN ImageNet-64. No noise.
The SVD learned real structural decomposition.
Same 8 conduit configurations through conv on the 16×16 grid.
No pooling. No flattening. Spatial readout respects geometric structure.
CRITICAL DIFFERENCE FROM FRECKLES:
Freckles learned noise reconstruction features.
Fresnel learned clean image structural decomposition.
The SVD elements from Fresnel actually encode learned relational behavior.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import time
from tqdm import tqdm
# Prefer the GPU, but fall back to CPU so the script still runs (slowly)
# on a machine without CUDA instead of failing at the first `.to(device)`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# ═══════════════════════════════════════════════════════════════
# CONFIG — Set the Fresnel version here
# ═══════════════════════════════════════════════════════════════
FRESNEL_VERSION = 'v50_fresnel_64'  # adjust if different checkpoint
# Side length used both for the Resize transform and to derive the patch grid.
IMG_SIZE = 64
# ═══════════════════════════════════════════════════════════════
# LOAD FRESNEL + CIFAR-10
# ═══════════════════════════════════════════════════════════════
print(f"Loading Fresnel ({FRESNEL_VERSION}) + CIFAR-10...")

# Project / torchvision pieces needed from here on (imported after the banner
# print, as in the original cell).
from geolip_svae import load_model
from geolip_svae.model import extract_patches
import torchvision
import torchvision.transforms as T
from geolip_core.linalg.conduit import FLEighConduit

# Pretrained model, switched to eval mode; it is only used for inference here.
fresnel, cfg = load_model(hf_version=FRESNEL_VERSION, device=device)
fresnel.eval()

# Patch grid implied by the model's patch size on an IMG_SIZE x IMG_SIZE input.
ps = fresnel.patch_size
gh = gw = IMG_SIZE // ps

# Per-patch width: read from the config when it is a dict, otherwise assume 4.
D = 4
if isinstance(cfg, dict):
    D = cfg.get('D', 4)

print(f" Patch size: {ps}, Grid: {gh}x{gw}, D={D}")
print(f" Params: {sum(p.numel() for p in fresnel.parameters()):,}")

# CIFAR-10 train/test splits, resized to IMG_SIZE and converted to tensors.
transform = T.Compose([T.Resize(IMG_SIZE), T.ToTensor()])
cifar_train, cifar_test = (
    torchvision.datasets.CIFAR10(
        root='/content/data', train=is_train, download=True,
        transform=transform)
    for is_train in (True, False)
)

# Unshuffled loaders: these only feed the one-off feature extraction pass.
train_loader, test_loader = (
    torch.utils.data.DataLoader(
        split, batch_size=128, shuffle=False, num_workers=4)
    for split in (cifar_train, cifar_test)
)

# Conduit module; it is later called on the per-patch Gram matrices.
conduit = FLEighConduit().to(device)

CLASSES = [
    'airplane', 'auto', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]
# Sanity probe: singular-value statistics and reconstruction error on the
# first 16 test images (not an evaluation).
print("\nFresnel S-value profile on CIFAR-10 sample:")
with torch.no_grad():
    sample = next(iter(test_loader))[0]
    sample = sample[:16].to(device)
    out = fresnel(sample)
    S = out['svd']['S']
    # Reduce over the first two axes, leaving one value per last-axis entry.
    s_mean = S.mean(dim=(0, 1)).tolist()
    s_std = S.std(dim=(0, 1)).tolist()
    recon_mse = F.mse_loss(out['recon'], sample)
print(f" S mean: {s_mean}")
print(f" S std: {s_std}")
print(f" MSE: {recon_mse:.6f}")
# ═══════════════════════════════════════════════════════════════
# PRECOMPUTE ALL CONDUIT MAPS
# ═══════════════════════════════════════════════════════════════
def extract_conduit_maps(loader, desc="Extracting"):
    """Run every batch of ``loader`` through Fresnel and collect spatial maps.

    For each batch this gathers, per patch:
      * ``S``        - the model's ``out['svd']['S']`` values,
      * ``friction`` - ``packet.friction`` from the conduit,
      * ``settle``   - ``packet.settle`` from the conduit,
      * ``error``    - mean squared reconstruction error of the patch.

    Args:
        loader: iterable yielding ``(images, labels)`` batches.
        desc: label shown on the tqdm progress bar.

    Returns:
        dict with keys ``'S'``, ``'friction'``, ``'settle'``, ``'error'`` and
        ``'labels'``. The four maps are CPU tensors reshaped to
        ``(num_images, gh, gw, width)`` (``width`` is 1 for ``'error'``);
        ``'labels'`` is the concatenation of the loader's label batches.
    """
    collected = {
        'S': [], 'friction': [], 'settle': [], 'error': [], 'labels': [],
    }
    with torch.no_grad():
        for images, labels in tqdm(loader, desc=desc):
            images_gpu = images.to(device)
            out = fresnel(images_gpu)
            recon = out['recon']
            S = out['svd']['S']
            Vt = out['svd']['Vt']

            # Read the per-patch width from the tensor itself rather than the
            # module-level D, which silently defaults to 4 when the config is
            # not a dict and would then make the reshapes below fail.
            B_img, N, width = S.shape

            # Per-patch recon error
            inp_p, _, _ = extract_patches(images_gpu, ps)
            rec_p, _, _ = extract_patches(recon, ps)
            patch_mse = (inp_p - rec_p).pow(2).mean(dim=-1)

            # Gram matrices for conduit: G = V diag(S^2) V^T for every patch.
            G = torch.einsum('bnij,bnj,bnjk->bnik',
                             Vt.transpose(-2, -1), S.pow(2), Vt)
            # Merge (batch, patch) into one leading axis; the trailing matrix
            # dims are kept exactly as the einsum produced them.
            packet = conduit(G.flatten(0, 1))

            # The explicit (gh, gw, width) target makes a grid/patch-count
            # mismatch raise instead of being silently absorbed.
            collected['S'].append(S.reshape(B_img, gh, gw, width).cpu())
            collected['friction'].append(
                packet.friction.reshape(B_img, gh, gw, width).cpu())
            collected['settle'].append(
                packet.settle.reshape(B_img, gh, gw, width).cpu())
            collected['error'].append(
                patch_mse.reshape(B_img, gh, gw, 1).cpu())
            collected['labels'].append(labels)

    return {key: torch.cat(chunks) for key, chunks in collected.items()}
# One extraction pass per split; everything downstream trains on these
# cached tensors instead of re-running the model.
print("\nPrecomputing train set...")
train_data = extract_conduit_maps(train_loader, "Train")
n_train_images = len(train_data['labels'])
print(f" Train: {n_train_images} images")

print("Precomputing test set...")
test_data = extract_conduit_maps(test_loader, "Test")
n_test_images = len(test_data['labels'])
print(f" Test: {n_test_images} images")

# Signal profile: global summary statistics of each extracted training map.
print(f"\n Fresnel signal profile:")
for key in ('S', 'friction', 'settle', 'error'):
    signal = train_data[key]
    per_image = signal.reshape(signal.shape[0], -1)
    avg, spread = per_image.mean(), per_image.std()
    lowest, highest = per_image.min(), per_image.max()
    print(f" {key:10s}: mean={avg:.4f} std={spread:.4f} "
          f"min={lowest:.4f} max={highest:.4f}")
# ═══════════════════════════════════════════════════════════════
# CONV CLASSIFIER
# ═══════════════════════════════════════════════════════════════
class SpatialConvClassifier(nn.Module):
    """Small conv classifier over a multi-channel spatial map.

    Three 3x3 convolutions (the first two with stride 2) with GELU
    activations, followed by global average pooling and a two-layer MLP head.

    Args:
        in_channels: number of channels in the input map.
        n_classes: number of output logits.
    """

    def __init__(self, in_channels, n_classes=10):
        super().__init__()
        # Layers are created conv-stack first, head second, so parameter
        # registration (and random init order) follows that sequence.
        feature_layers = [
            nn.Conv2d(in_channels, 64, kernel_size=3, stride=2, padding=1),
            nn.GELU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.GELU(),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.GELU(),
            nn.AdaptiveAvgPool2d(1),
        ]
        self.conv = nn.Sequential(*feature_layers)

        head_layers = [
            nn.Linear(128, 64),
            nn.GELU(),
            nn.Linear(64, n_classes),
        ]
        self.head = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return class logits for the input map ``x``."""
        pooled = self.conv(x)
        # The pooled map ends in (128, 1, 1); collapse those three trailing
        # dims into a single 128-long feature axis.
        features = pooled.flatten(start_dim=-3)
        return self.head(features)
class ConduitDataset(torch.utils.data.Dataset):
    """Dataset view over precomputed maps, stacking the requested signals.

    ``channels`` selects which signals are included: ``'S'`` -> ``data['S']``,
    ``'F'`` -> ``data['friction']``, ``'T'`` -> ``data['settle']``,
    ``'E'`` -> ``data['error']``. Selected maps are always concatenated in
    that S, F, T, E order, regardless of the order of letters in ``channels``.

    Args:
        data: dict holding ``'labels'`` plus the map tensors that the chosen
            channels refer to; each map is permuted from channels-last to
            channels-first.
        channels: container of channel letters (typically a string).
        augment: when true, each fetched item is flipped along its last
            axis with probability one half.
    """

    # (channel letter, key in ``data``) in concatenation order.
    _SIGNALS = (
        ('S', 'S'),
        ('F', 'friction'),
        ('T', 'settle'),
        ('E', 'error'),
    )

    def __init__(self, data, channels='S', augment=False):
        self.labels = data['labels']
        self.augment = augment
        # Only look up the keys that were actually requested, so ``data``
        # does not need to contain the unused signals.
        selected = [
            data[key].permute(0, 3, 1, 2)
            for letter, key in self._SIGNALS
            if letter in channels
        ]
        self.maps = torch.cat(selected, dim=1)
        self.n_channels = self.maps.shape[1]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        item = self.maps[idx]
        # Short-circuit: no random number is drawn when augmentation is off.
        do_flip = self.augment and torch.rand(1).item() > 0.5
        if do_flip:
            item = item.flip(-1)
        return item, self.labels[idx]
# ═══════════════════════════════════════════════════════════════
# TRAINING
# ═══════════════════════════════════════════════════════════════
def train_and_eval(channels, name, epochs=30, batch_size=128, lr=3e-4):
    """Train a SpatialConvClassifier on the chosen channels and report results.

    Builds train/test ``ConduitDataset`` views over the precomputed
    ``train_data`` / ``test_data``, trains with Adam and a cosine learning-rate
    schedule, evaluates on the test view after every epoch, and prints a
    summary with a per-class accuracy table.

    Args:
        channels: channel letters forwarded to ``ConduitDataset``.
        name: human-readable configuration name used in the printed summary.
        epochs: number of training epochs.
        batch_size: batch size for both loaders.
        lr: Adam learning rate.

    Returns:
        ``(best_acc, n_params)``: the highest test accuracy seen over all
        epochs and the classifier's parameter count.
    """
    n_classes = len(CLASSES)

    train_ds = ConduitDataset(train_data, channels, augment=True)
    test_ds = ConduitDataset(test_data, channels, augment=False)
    n_ch = train_ds.n_channels

    tr_loader = torch.utils.data.DataLoader(
        train_ds, batch_size=batch_size, shuffle=True,
        num_workers=4, pin_memory=True, drop_last=True)
    te_loader = torch.utils.data.DataLoader(
        test_ds, batch_size=batch_size, shuffle=False,
        num_workers=4, pin_memory=True)

    model = SpatialConvClassifier(n_ch, n_classes).to(device)
    n_params = sum(p.numel() for p in model.parameters())
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)

    best_acc = 0
    # Per-class accuracy of the epoch that produced best_acc, so the table
    # printed below describes the same model state as the headline number.
    best_per_class = torch.zeros(n_classes)
    t0 = time.time()

    for epoch in range(1, epochs + 1):
        model.train()
        correct, total = 0, 0
        for x, y in tr_loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = F.cross_entropy(logits, y)
            opt.zero_grad()
            loss.backward()
            opt.step()
            correct += (logits.argmax(-1) == y).sum().item()
            total += len(y)
        sched.step()
        train_acc = correct / total

        model.eval()
        # Per-class counters stay on the device and are accumulated with
        # bincount; they are pulled to the host once per epoch instead of
        # synchronising twice per class on every batch.
        class_hits = torch.zeros(n_classes, dtype=torch.long, device=device)
        class_seen = torch.zeros(n_classes, dtype=torch.long, device=device)
        with torch.no_grad():
            for x, y in te_loader:
                x, y = x.to(device), y.to(device)
                preds = model(x).argmax(-1)
                class_seen += torch.bincount(y, minlength=n_classes)
                class_hits += torch.bincount(y[preds == y],
                                             minlength=n_classes)
        class_hits, class_seen = class_hits.cpu(), class_seen.cpu()
        test_acc = class_hits.sum().item() / class_seen.sum().item()

        if test_acc > best_acc:
            best_acc = test_acc
            # clamp guards classes that never occur in the test view.
            best_per_class = (class_hits.float()
                              / class_seen.clamp(min=1).float())

        if epoch % 5 == 0 or epoch == epochs:
            print(f" ep{epoch:3d} train={train_acc:.1%} test={test_acc:.1%}")

    elapsed = time.time() - t0

    print(f"\n {name}")
    print(f" Channels: {n_ch}, Params: {n_params:,}, Time: {elapsed:.0f}s")
    print(f" Best test: {best_acc:.1%}")
    print(f"\n {'Class':<10s} {'Acc':>6s}")
    print(f" {'-' * 22}")
    for class_name, class_acc in zip(CLASSES, best_per_class.tolist()):
        # One full-block glyph per 5% of accuracy (20 glyphs at 100%).
        bar = '█' * int(class_acc * 20)
        print(f" {class_name:<10s} {class_acc:5.1%} {bar}")
    print()
    return best_acc, n_params
# ═══════════════════════════════════════════════════════════════
# RUN ALL CONFIGURATIONS
# ═══════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print(" FRESNEL — Spatial Conv Readout — All Conduit Configurations")
print("=" * 70)

# channels string -> (best test accuracy, classifier param count, label)
results = {}
# (channel letters for ConduitDataset, label shown in the logs/scoreboard)
configs = [
    ('S', "Eigenvalues (S) only — 4ch"),
    ('F', "Friction only — 4ch"),
    ('E', "Release error only — 1ch"),
    ('T', "Settle only — 4ch"),
    ('SF', "S + Friction — 8ch"),
    ('SE', "S + Release error — 5ch"),
    ('SFE', "S + Friction + Release — 9ch"),
    ('SFET', "FULL CONDUIT — 13ch"),
]
for channels, name in configs:
    print(f"\n{'─' * 70}")
    print(f" Training: {name}")
    print(f"{'─' * 70}")
    acc, params = train_and_eval(channels, name)
    results[channels] = (acc, params, name)

# ═══════════════════════════════════════════════════════════════
# SCOREBOARD
# ═══════════════════════════════════════════════════════════════
print(f"\n{'=' * 70}")
print(f" SCOREBOARD — Fresnel ({FRESNEL_VERSION}) Spatial Conv Readout")
print("=" * 70)
print(f"\n {'Configuration':<35s} {'Ch':>4s} {'Params':>10s} {'Test Acc':>9s}")
print(f" {'-' * 62}")
print(f" {'Chance':<35s} {'—':>4s} {'—':>10s} {'10.0%':>9s}")

# Channel width of the S / friction / settle maps, read from the extracted
# data rather than assumed to be 4 (the error map always contributes 1).
map_width = train_data['S'].shape[-1]
# Rows are listed from lowest to highest test accuracy.
for channels, (acc, params, name) in sorted(results.items(), key=lambda x: x[1][0]):
    n_ch = sum(map_width if c in 'SFT' else 1 for c in channels)
    print(f" {name:<35s} {n_ch:>4d} {params:>10,d} {acc:>8.1%}")

print(f"\n {'--- FRECKLES REFERENCE ---':<35s}")
print(f" {'Scatter + conv (Freckles S)':<35s} {'4':>4s} {'2.9M':>10s} {'70.5%':>9s}")

# Lift of the best configuration over the S-only baseline, in percentage points.
s_acc = results.get('S', (0, 0, ''))[0]
best_ch, (best_acc, _, best_name) = max(results.items(), key=lambda x: x[1][0])
print(f"\n Fresnel S-only: {s_acc:.1%}")
print(f" Best conduit: {best_acc:.1%} ({best_name})")
print(f" Conduit lift: {(best_acc - s_acc) * 100:+.1f}pp")
print(f"\n KEY QUESTION: Does Fresnel's clean training produce")
print(f" conduit signals that Freckles' noise training could not?")