# geolip-conduit-experiments / cell_5_conduit_sweep.py
# Author: AbstractPhil
# Commit 69769dc (verified): "Create cell_5_conduit_sweep.py"
"""
Cell 5 — Spatial Conv Readout on Conduit Maps
=============================================
No pooling. No flattening. A conv net reads the 16×16 spatial grid directly.

This is the CORRECT way to evaluate whether conduit signals carry
class-discriminative information. The earlier linear probe was the wrong
tool — flattening destroyed the spatial structure that IS the signal.

Channels on the 16×16 grid:
    S values:      4 channels (eigenvalues per patch)
    Friction:      4 channels (solver struggle per mode)
    Release error: 1 channel  (reconstruction fidelity per patch)
    Settle:        4 channels (convergence speed per mode)

Each signal is tested alone and in combination, all through the conv readout.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import time
from tqdm import tqdm
# Compute device for the frozen SVAE, the conduit and the readout classifiers.
# NOTE(review): hard-coded to CUDA — this script will not run without a GPU.
device = torch.device('cuda')
# ═══════════════════════════════════════════════════════════════
# LOAD
# ═══════════════════════════════════════════════════════════════
print("Loading Freckles v40 + CIFAR-10...")
from geolip_svae import load_model
from geolip_svae.model import extract_patches
import torchvision
import torchvision.transforms as T
from geolip_core.linalg.conduit import FLEighConduit
# Pretrained SVAE; put in eval mode because it is only run for inference
# (under torch.no_grad) when the conduit maps are extracted.
freckles, cfg = load_model(hf_version='v40_freckles_noise', device=device)
freckles.eval()
# Patch-grid size for a 64×64 input (CIFAR images are square, so Resize(64)
# below yields 64×64).
ps = freckles.patch_size
gh, gw = 64 // ps, 64 // ps
# Size of the per-patch spectrum / Gram matrices. Must equal the last dim of
# the model's out['svd']['S'], because extract_conduit_maps reshapes with it.
D = 4
transform = T.Compose([T.Resize(64), T.ToTensor()])
cifar_train = torchvision.datasets.CIFAR10(
    root='/content/data', train=True, download=True, transform=transform)
cifar_test = torchvision.datasets.CIFAR10(
    root='/content/data', train=False, download=True, transform=transform)
# shuffle=False: these loaders only feed the one-off feature precompute;
# the training loaders in train_and_eval do their own shuffling.
train_loader = torch.utils.data.DataLoader(
    cifar_train, batch_size=128, shuffle=False, num_workers=4)
test_loader = torch.utils.data.DataLoader(
    cifar_test, batch_size=128, shuffle=False, num_workers=4)
# Applied to batches of D×D Gram matrices; the returned packet exposes
# .friction and .settle, which become the 'F' and 'T' channels.
conduit = FLEighConduit().to(device)
# CIFAR-10 class names in label-index order (used for the per-class report).
CLASSES = ['airplane', 'auto', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']
# ═══════════════════════════════════════════════════════════════
# PRECOMPUTE ALL CONDUIT MAPS
# ═══════════════════════════════════════════════════════════════
def extract_conduit_maps(loader, desc="Extracting"):
    """Run every batch through Freckles and the conduit, collecting spatial maps.

    Args:
        loader: DataLoader yielding (images, labels) batches.
        desc: label for the tqdm progress bar.

    Returns:
        dict of CPU tensors concatenated over all images (M = number of images):
            'S':        (M, gh, gw, D) singular values per patch
            'friction': (M, gh, gw, D) conduit friction per mode
            'settle':   (M, gh, gw, D) conduit settle values per mode
            'error':    (M, gh, gw, 1) per-patch reconstruction MSE
            'labels':   (M,) class labels
    """
    collected = {key: [] for key in ('S', 'friction', 'settle', 'error', 'labels')}

    for batch_images, batch_labels in tqdm(loader, desc=desc):
        with torch.no_grad():
            x = batch_images.to(device)
            result = freckles(x)
            svd = result['svd']
            sing_vals = svd['S']      # (B, N, D), N = patches per image
            v_t = svd['Vt']           # (B, N, D, D)
            n_img, n_patch, _ = sing_vals.shape

            # Mean squared reconstruction error of every patch -> (B, N).
            orig_patches, _, _ = extract_patches(x, ps)
            recon_patches, _, _ = extract_patches(result['recon'], ps)
            per_patch_mse = (orig_patches - recon_patches).pow(2).mean(dim=-1)

            # Per-patch Gram matrix V diag(S^2) V^T, sent to the conduit as one flat batch.
            gram = torch.einsum('bnij,bnj,bnjk->bnik',
                                v_t.transpose(-2, -1), sing_vals.pow(2), v_t)
            packet = conduit(gram.reshape(n_img * n_patch, D, D))

            # Fold the patch axis back into the (gh, gw) grid and move to CPU.
            grid = (n_img, gh, gw)
            collected['S'].append(sing_vals.reshape(*grid, D).cpu())
            collected['friction'].append(packet.friction.reshape(*grid, D).cpu())
            collected['settle'].append(packet.settle.reshape(*grid, D).cpu())
            collected['error'].append(per_patch_mse.reshape(*grid, 1).cpu())
        collected['labels'].append(batch_labels)

    return {key: torch.cat(chunks) for key, chunks in collected.items()}
# Precompute the conduit maps once per split. Every readout configuration
# trained later reuses these cached CPU tensors instead of re-running Freckles.
print("\nPrecomputing train set...")
train_data = extract_conduit_maps(train_loader, "Train")
print(f" Train: {len(train_data['labels'])} images")
print("Precomputing test set...")
test_data = extract_conduit_maps(test_loader, "Test")
print(f" Test: {len(test_data['labels'])} images")
# ═══════════════════════════════════════════════════════════════
# CONV CLASSIFIER — reads spatial maps directly
# ═══════════════════════════════════════════════════════════════
class SpatialConvClassifier(nn.Module):
    """Small conv net that classifies a (C, 16, 16) conduit map.

    Three 3×3 convs (stride 2, 2, 1: 16→8→4→4), each followed by GELU, then a
    single adaptive average pool to 1×1 and a two-layer MLP head. No pooling
    happens before that final adaptive pool.

    Args:
        in_channels: number of stacked conduit channels C.
        n_classes: number of output logits.
    """

    def __init__(self, in_channels, n_classes=10):
        super().__init__()
        widths = (in_channels, 64, 128, 128)
        strides = (2, 2, 1)
        layers = []
        for c_in, c_out, step in zip(widths[:-1], widths[1:], strides):
            layers += [nn.Conv2d(c_in, c_out, 3, stride=step, padding=1), nn.GELU()]
        layers.append(nn.AdaptiveAvgPool2d(1))
        self.conv = nn.Sequential(*layers)

        hidden = 64
        self.head = nn.Sequential(
            nn.Linear(widths[-1], hidden),
            nn.GELU(),
            nn.Linear(hidden, n_classes),
        )

    def forward(self, x):
        """Map a (B, C, H, W) batch to (B, n_classes) logits."""
        pooled = self.conv(x)            # (B, 128, 1, 1)
        return self.head(pooled[..., 0, 0])
class ConduitDataset(torch.utils.data.Dataset):
    """Serve a channel-selected stack of precomputed conduit maps.

    Args:
        data: dict as returned by extract_conduit_maps: channels-last maps of
            shape (N, H, W, c) under 'S', 'friction', 'settle', 'error', and
            labels under 'labels'.
        channels: letters selecting which maps to stack —
            'S' = S values, 'F' = friction, 'T' = settle, 'E' = release error.
            Maps are always stacked in the fixed order S, F, T, E, regardless
            of the letter order in `channels`.
        augment: if True, each item's grid is flipped left-right with
            probability 1/2.

    Raises:
        ValueError: if `channels` contains an unknown letter or selects no map.
    """

    # (selector letter, key in `data`), in stacking order.
    _CHANNEL_KEYS = (('S', 'S'), ('F', 'friction'), ('T', 'settle'), ('E', 'error'))

    def __init__(self, data, channels='S', augment=False):
        valid = {letter for letter, _ in self._CHANNEL_KEYS}
        unknown = set(channels) - valid
        if unknown:
            # A typo used to be silently ignored and ran the wrong experiment.
            raise ValueError(
                f"unknown channel selector(s) {sorted(unknown)}; expected letters from 'SFTE'")
        self.labels = data['labels']
        self.augment = augment
        # (N, H, W, c) -> (N, c, H, W), then concatenate along channels.
        parts = [data[key].permute(0, 3, 1, 2)
                 for letter, key in self._CHANNEL_KEYS if letter in channels]
        if not parts:
            raise ValueError("channels must select at least one of 'S', 'F', 'T', 'E'")
        self.maps = torch.cat(parts, dim=1)  # (N, total_C, H, W)
        self.n_channels = self.maps.shape[1]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return (map of shape (C, H, W), label) for sample `idx`."""
        x = self.maps[idx]
        label = self.labels[idx]
        # NOTE(review): this flips the patch grid only; the per-patch features
        # were computed on unflipped patches, so it is not equivalent to
        # flipping the source image — confirm this augmentation is intended.
        if self.augment and torch.rand(1).item() > 0.5:
            x = x.flip(-1)
        return x, label
# ═══════════════════════════════════════════════════════════════
# TRAINING LOOP
# ═══════════════════════════════════════════════════════════════
def _train_one_epoch(model, loader, opt):
    """Run one optimisation pass over `loader`; return accuracy of the train-mode logits."""
    model.train()
    correct, total = 0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        loss = F.cross_entropy(logits, y)
        opt.zero_grad()
        loss.backward()
        opt.step()
        correct += (logits.argmax(-1) == y).sum().item()
        total += len(y)
    # max(..., 1): an empty epoch (drop_last on a tiny set) reports 0 instead of crashing.
    return correct / max(total, 1)


def _evaluate(model, loader, n_classes):
    """Return (overall accuracy, per-class accuracy tensor of length n_classes) on `loader`."""
    model.eval()
    class_correct = torch.zeros(n_classes)
    class_total = torch.zeros(n_classes)
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            preds = model(x).argmax(-1)
            # One bincount per batch instead of a per-class Python loop with syncs.
            class_correct += torch.bincount(y[preds == y], minlength=n_classes).float().cpu()
            class_total += torch.bincount(y, minlength=n_classes).float().cpu()
    overall = class_correct.sum().item() / max(class_total.sum().item(), 1)
    # 1e-8 avoids 0/0 for classes absent from the loader.
    return overall, class_correct / (class_total + 1e-8)


def train_and_eval(channels, name, epochs=30, batch_size=128, lr=3e-4):
    """Train a SpatialConvClassifier on the selected conduit channels and report results.

    Args:
        channels: channel selector for ConduitDataset (letters from 'SFTE').
        name: human-readable label used in the printed summary.
        epochs: number of training epochs (also the cosine schedule length).
        batch_size: mini-batch size for the train and test loaders.
        lr: Adam learning rate.

    Returns:
        (best_acc, n_params): the highest test accuracy seen at any epoch, and
        the classifier's parameter count.

    Note: best_acc picks the best epoch by looking at the test set, so it is an
    optimistic estimate; the final-epoch accuracy is printed next to it. The
    per-class table is taken from the same (best) epoch as best_acc.
    """
    n_classes = len(CLASSES)
    train_ds = ConduitDataset(train_data, channels, augment=True)
    test_ds = ConduitDataset(test_data, channels, augment=False)
    n_ch = train_ds.n_channels
    tr_loader = torch.utils.data.DataLoader(
        train_ds, batch_size=batch_size, shuffle=True,
        num_workers=4, pin_memory=True, drop_last=True)
    te_loader = torch.utils.data.DataLoader(
        test_ds, batch_size=batch_size, shuffle=False,
        num_workers=4, pin_memory=True)
    model = SpatialConvClassifier(n_ch, n_classes).to(device)
    n_params = sum(p.numel() for p in model.parameters())
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)

    best_acc = 0.0
    best_per_class = torch.zeros(n_classes)
    test_acc = 0.0
    t0 = time.time()
    for epoch in range(1, epochs + 1):
        train_acc = _train_one_epoch(model, tr_loader, opt)
        sched.step()
        test_acc, per_class = _evaluate(model, te_loader, n_classes)
        if test_acc > best_acc:
            best_acc = test_acc
            best_per_class = per_class  # keep the table consistent with best_acc
        if epoch % 5 == 0 or epoch == epochs:
            print(f" ep{epoch:3d} train={train_acc:.1%} test={test_acc:.1%}")
    elapsed = time.time() - t0

    print(f"\n {name}")
    print(f" Channels: {n_ch}, Params: {n_params:,}, Time: {elapsed:.0f}s")
    print(f" Best test: {best_acc:.1%} (final epoch: {test_acc:.1%})")
    print(f"\n {'Class':<10s} {'Acc':>6s} (at best epoch)")
    print(f" {'-' * 22}")
    for cls_name, cls_acc in zip(CLASSES, best_per_class.tolist()):
        bar = '█' * int(cls_acc * 20)
        print(f" {cls_name:<10s} {cls_acc:5.1%} {bar}")
    print()
    return best_acc, n_params
# ═══════════════════════════════════════════════════════════════
# RUN ALL CONFIGURATIONS
# ═══════════════════════════════════════════════════════════════
# Train one readout per channel configuration. The display strings previously
# contained mis-encoded em dashes ("β€”"), which garbled the printed output.
print("\n" + "=" * 70)
print(" SPATIAL CONV READOUT — All conduit configurations")
print("=" * 70)
# channels -> (best test accuracy, parameter count, display name)
results = {}
# (channel selector passed to ConduitDataset, display name)
configs = [
    ('S', "Eigenvalues (S) only — 4ch"),
    ('F', "Friction only — 4ch"),
    ('E', "Release error only — 1ch"),
    ('T', "Settle only — 4ch"),
    ('SF', "S + Friction — 8ch"),
    ('SE', "S + Release error — 5ch"),
    ('SFE', "S + Friction + Release — 9ch"),
    ('SFET', "FULL CONDUIT — 13ch"),
]
for channels, name in configs:
    print(f"\n{'─' * 70}")
    print(f" Training: {name}")
    print(f"{'─' * 70}")
    acc, params = train_and_eval(channels, name)
    results[channels] = (acc, params, name)
# ═══════════════════════════════════════════════════════════════
# SCOREBOARD
# ═══════════════════════════════════════════════════════════════
# Summary table of all runs, then reference numbers and lift analysis.
print(f"\n{'=' * 70}")
print(" SCOREBOARD — Spatial Conv Readout")
print("=" * 70)
print(f"\n {'Configuration':<35s} {'Channels':>8s} {'Params':>10s} {'Test Acc':>9s}")
print(f" {'-' * 65}")
print(f" {'Chance':<35s} {'—':>8s} {'—':>10s} {'10.0%':>9s}")
# Rows sorted by accuracy, ascending (best run printed last).
for channels, (acc, params, name) in sorted(results.items(), key=lambda x: x[1][0]):
    # Width 9 so the accuracy column lines up under the 'Test Acc' header.
    print(f" {name:<35s} {channels:>8s} {params:>10,d} {acc:>9.1%}")

# Reference results from earlier experiments: (label, params, test accuracy).
SCATTER_CONV_REF_ACC = 0.705
reference_rows = [
    ('Linear probe (friction flat)', '—', 0.243),
    ('Linear probe (S flat)', '—', 0.210),
    ('Patchwork + calibrated embeds', '530K', 0.480),
    ('Scatter + conv (raw S)', '2.9M', SCATTER_CONV_REF_ACC),
    ('CNN condensed (SGD)', '730K', 0.747),
]
print(f"\n {'--- REFERENCE (from earlier) ---':<35s}")
for ref_label, ref_params, ref_acc in reference_rows:
    print(f" {ref_label:<35s} {'—':>8s} {ref_params:>10s} {ref_acc:>9.1%}")

# Lift analysis (guarded: max() on an empty results dict would raise).
if results:
    s_acc = results.get('S', (0, 0, ''))[0]
    best_channels, (best_acc, _, best_name) = max(results.items(), key=lambda x: x[1][0])
    print(f"\n S-only conv: {s_acc:.1%}")
    print(f" Best conduit: {best_acc:.1%} ({best_name})")
    print(f" Conduit lift: {(best_acc - s_acc) * 100:+.1f}pp")
    print(f" vs scatter+conv reference: {(best_acc - SCATTER_CONV_REF_ACC) * 100:+.1f}pp")
else:
    print("\n No results to analyse.")