# geolip-conduit-experiments / notebook_cell_3_theorem_2.py
# Uploaded by AbstractPhil — commit a986628 (verified)
"""
Cell 3 β€” Spatial Friction Map Analysis
========================================
The mean friction is uniform across classes (12.19 Β± 0.08).
But the SPATIAL PATTERN of friction within images might differ.
Questions:
1. Do friction maps have spatial structure? (or uniform per image)
2. Does the spatial pattern differ across classes?
3. Do edge/boundary patches have higher friction than interior?
4. Is per-patch friction discriminative even if per-class mean is not?
5. What does the friction map look like for individual images?
"""
import torch
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
from geolip_core.linalg.conduit import FLEighConduit
device = torch.device('cuda')  # NOTE(review): assumes a CUDA GPU is present
# ═══════════════════════════════════════════════════════════════
# LOAD DATA
# ═══════════════════════════════════════════════════════════════
print("Loading Freckles v40 + CIFAR-10...")
from geolip_svae import load_model
import torchvision
import torchvision.transforms as T

# Load the pretrained "Freckles" model (v40 noise variant) onto the GPU.
freckles, cfg = load_model(hf_version='v40_freckles_noise', device=device)
freckles.eval()
# CIFAR-10 test split, resized to 64×64.  The analysis below treats the
# model output as a 16×16 patch grid — presumably 4×4-pixel patches;
# TODO confirm against the model config.
transform = T.Compose([T.Resize(64), T.ToTensor()])
cifar_test = torchvision.datasets.CIFAR10(
    root='/content/data', train=False, download=True, transform=transform)
# shuffle=False matters: section 9 re-iterates this loader and relies on
# image order matching the maps collected in the first pass.
loader = torch.utils.data.DataLoader(
    cifar_test, batch_size=64, shuffle=False, num_workers=4)
CLASSES = ['airplane', 'auto', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']
conduit = FLEighConduit().to(device)
gh, gw = 16, 16  # patch grid
# ═══════════════════════════════════════════════════════════════
# COLLECT SPATIAL FRICTION MAPS
# ═══════════════════════════════════════════════════════════════
print("Collecting spatial friction maps (full test set)...\n")
# Running sums / sums-of-squares for per-class mean and variance of the
# spatial friction maps, plus settle-time sums and per-class counts.
# Per-class friction maps: (10, gh, gw, D=4)
class_friction_sum = torch.zeros(10, gh, gw, 4)
class_friction_sq = torch.zeros(10, gh, gw, 4)
class_settle_sum = torch.zeros(10, gh, gw, 4)
class_counts = torch.zeros(10)
# Also collect per-image statistics for discriminability analysis
all_friction_maps = []  # list of (friction_map, label)
all_settle_maps = []
n_images_collected = 0
max_collect = 2000  # collect individual maps for first 2000 images
for images, labels in tqdm(loader, desc="Processing"):
    with torch.no_grad():
        out = freckles(images.to(device))
        # NOTE(review): assumes the model exposes per-patch SVD factors under
        # out['svd']; shapes below are from the original author's comments.
        S = out['svd']['S']  # (B, N, D)
        Vt = out['svd']['Vt']  # (B, N, D, D)
        B_img, N, D = S.shape
        # Build per-patch Gram matrices G = V diag(S²) Vᵀ
        S2 = S.pow(2)
        G = torch.einsum('bnij,bnj,bnjk->bnik',
                         Vt.transpose(-2, -1), S2, Vt)
        G_flat = G.reshape(B_img * N, D, D)
        packet = conduit(G_flat)
        # Reshape to spatial: (B, gh, gw, D)
        # (assumes N == gh * gw == 256 — TODO confirm against model config)
        fric_map = packet.friction.reshape(B_img, gh, gw, D)
        sett_map = packet.settle.reshape(B_img, gh, gw, D)
        fric_cpu = fric_map.cpu()
        sett_cpu = sett_map.cpu()
    # Accumulate per-class statistics; keep individual maps for the first
    # max_collect images (loader is unshuffled, so this order is stable).
    for i in range(B_img):
        c = labels[i].item()
        class_friction_sum[c] += fric_cpu[i]
        class_friction_sq[c] += fric_cpu[i].pow(2)
        class_settle_sum[c] += sett_cpu[i]
        class_counts[c] += 1
        if n_images_collected < max_collect:
            all_friction_maps.append((fric_cpu[i], c))
            all_settle_maps.append((sett_cpu[i], c))
            n_images_collected += 1
print(f"\nCollected {int(class_counts.sum().item())} images, "
      f"{n_images_collected} individual maps\n")
# ═══════════════════════════════════════════════════════════════
# 1. SPATIAL STRUCTURE WITHIN IMAGES
# ═══════════════════════════════════════════════════════════════
print("=" * 70)
print(" 1. SPATIAL STRUCTURE — Do friction maps have spatial variance?")
print("=" * 70)
# Per-mode variance of friction across the gh*gw patch grid, per image:
# zero variance would mean friction is spatially uniform inside the image.
per_image_spatial_var = [
    (fmap.reshape(-1, 4).var(dim=0), lbl) for fmap, lbl in all_friction_maps
]
spatial_vars = torch.stack([v for v, _ in per_image_spatial_var])  # (N, 4)
print(f"\n Per-image spatial friction variance (across 256 patches):")
print(f" Mode 0 (S₀): mean={spatial_vars[:, 0].mean():.4f} std={spatial_vars[:, 0].std():.4f}")
print(f" Mode 1 (S₁): mean={spatial_vars[:, 1].mean():.4f} std={spatial_vars[:, 1].std():.4f}")
print(f" Mode 2 (S₂): mean={spatial_vars[:, 2].mean():.4f} std={spatial_vars[:, 2].std():.4f}")
print(f" Mode 3 (S₃): mean={spatial_vars[:, 3].mean():.4f} std={spatial_vars[:, 3].std():.4f}")
# Coefficient of variation (spatial std / spatial mean) per image and mode.
flat_maps = [fmap.reshape(-1, 4) for fmap, _ in all_friction_maps]
spatial_means = torch.stack([m.mean(0) for m in flat_maps])
spatial_stds = torch.stack([m.std(0) for m in flat_maps])
spatial_cv = spatial_stds / (spatial_means + 1e-8)
print(f"\n Per-image spatial CV (std/mean):")
for d in range(4):
    print(f" Mode {d}: CV mean={spatial_cv[:, d].mean():.4f} "
          f"median={spatial_cv[:, d].median():.4f} max={spatial_cv[:, d].max():.4f}")
has_spatial_structure = spatial_cv.mean() > 0.1
print(f"\n VERDICT: {'HAS SPATIAL STRUCTURE' if has_spatial_structure else 'SPATIALLY UNIFORM'} "
      f"(mean CV = {spatial_cv.mean():.4f})")
# ═══════════════════════════════════════════════════════════════
# 2. PER-CLASS SPATIAL FRICTION PATTERNS
# ═══════════════════════════════════════════════════════════════
print(f"\n{'=' * 70}")
print(" 2. PER-CLASS SPATIAL PATTERNS — Do classes have different friction maps?")
print("=" * 70)
# Per-class mean and variance maps from the running sums (E[x²] − E[x]²).
denom = class_counts[:, None, None, None].clamp(min=1)
class_means = class_friction_sum / denom
class_vars = class_friction_sq / denom - class_means.pow(2)
# Compare class-mean maps as flat vectors.
class_flat = class_means.reshape(10, -1)  # (10, gh*gw*4)
dists = torch.cdist(class_flat, class_flat)
print(f"\n Inter-class friction map L2 distances:")
print("".join([f" {'':>10s}"] + [f" {CLASSES[c][:5]:>6s}" for c in range(10)]))
for c1 in range(10):
    row = f" {CLASSES[c1][:10]:>10s}"
    row += "".join(f" {dists[c1, c2]:6.3f}" for c2 in range(10))
    print(row)
# Off-diagonal (inter-class) summary statistics.
inter_mask = ~torch.eye(10, dtype=torch.bool)
inter_dist = dists[inter_mask].mean().item()
print(f"\n Mean inter-class distance: {inter_dist:.4f}")
# Cosine similarity between the class-mean maps.
class_flat_norm = F.normalize(class_flat, dim=-1)
cos_sim = class_flat_norm @ class_flat_norm.T
cos_off_diag = cos_sim[inter_mask].mean().item()
cos_min = cos_sim[inter_mask].min().item()
print(f" Mean cosine similarity: {cos_off_diag:.6f}")
print(f" Min cosine similarity: {cos_min:.6f}")
pattern_verdict = 'DISTINCT PATTERNS' if cos_min < 0.99 else 'NEARLY IDENTICAL PATTERNS'
print(f" VERDICT: {pattern_verdict}")
# ═══════════════════════════════════════════════════════════════
# 3. CENTER vs EDGE FRICTION
# ═══════════════════════════════════════════════════════════════
print(f"\n{'=' * 70}")
print(" 3. CENTER vs EDGE — Do boundary patches have higher friction?")
print("=" * 70)
# Center = inner 8×8 block of the patch grid; edge = the border ring.
center_mask = torch.zeros(gh, gw, dtype=torch.bool)
center_mask[4:12, 4:12] = True
edge_mask = ~center_mask
print(f"\n {'Class':<10s} {'Center':>8s} {'Edge':>8s} {'Edge/Center':>12s}")
print(f" {'-' * 40}")
for c in range(10):
    cls_map = class_means[c]  # (gh, gw, 4)
    center_fric = cls_map[center_mask].mean().item()
    edge_fric = cls_map[edge_mask].mean().item()
    ratio = edge_fric / (center_fric + 1e-8)
    print(f" {CLASSES[c]:<10s} {center_fric:8.3f} {edge_fric:8.3f} {ratio:12.4f}")
# ═══════════════════════════════════════════════════════════════
# 4. PER-PATCH-POSITION DISCRIMINABILITY
# ═══════════════════════════════════════════════════════════════
print(f"\n{'=' * 70}")
print(" 4. PER-PATCH-POSITION DISCRIMINABILITY")
print("=" * 70)
# For each patch position (i, j) and mode d, an F-statistic proxy:
# inter-class variance of the class means / mean intra-class variance.
# Vectorized over the whole (gh, gw, 4) grid — replaces the original
# 16×16×4 Python triple loop with per-element .item() calls;
# var(dim=0) over the 10 classes is the same unbiased estimator.
position_f_stat = class_means.var(dim=0) / (class_vars.mean(dim=0) + 1e-10)
print(f"\n F-statistic (inter-class var / intra-class var) per mode:")
for d in range(4):
    fs = position_f_stat[:, :, d]
    print(f" Mode {d}: mean={fs.mean():.6f} max={fs.max():.6f} "
          f"top 5% threshold={fs.quantile(0.95):.6f}")
# Best discriminative position per mode (argmax over the flattened grid).
for d in range(4):
    fs = position_f_stat[:, :, d]
    best_idx = fs.argmax()
    bi, bj = best_idx // gw, best_idx % gw
    print(f" Mode {d} best position: ({bi.item()}, {bj.item()}) F={fs.max():.6f}")
overall_f = position_f_stat.mean(dim=-1)  # avg across modes
print(f"\n Overall best discriminative patch position: "
      f"{(overall_f.argmax() // gw).item()}, {(overall_f.argmax() % gw).item()} "
      f"F={overall_f.max():.6f}")
print(f" Overall mean F-statistic: {overall_f.mean():.6f}")
print(f" VERDICT: {'POSITIONALLY DISCRIMINATIVE' if overall_f.max() > 0.01 else 'NOT DISCRIMINATIVE'}")
# ═══════════════════════════════════════════════════════════════
# 5. PER-MODE ANALYSIS — Which SVD mode carries most spatial variance?
# ═══════════════════════════════════════════════════════════════
print(f"\n{'=' * 70}")
print(" 5. PER-MODE SPATIAL VARIANCE — Which mode has the most structure?")
print("=" * 70)
# Grand-mean friction map over all images.  Loop-invariant: the original
# recomputed this identical tensor on every iteration of the mode loop.
overall_mean_map = class_friction_sum.sum(0) / class_counts.sum()  # (gh, gw, 4)
for d in range(4):
    mode_map = overall_mean_map[:, :, d]
    sv = mode_map.var().item()
    sm = mode_map.mean().item()
    print(f" Mode {d}: map_mean={sm:.4f} map_var={sv:.6f} map_cv={sv**0.5/(sm+1e-8):.4f}")
# ═══════════════════════════════════════════════════════════════
# 6. INDIVIDUAL IMAGE FRICTION MAPS
# ═══════════════════════════════════════════════════════════════
print(f"\n{'=' * 70}")
print(" 6. SAMPLE FRICTION MAPS — Individual images")
print("=" * 70)
# Table header (hoisted out of the loop; the original printed it inside the
# loop under an `idx == 0 and c == 0` guard — same output).
print(f"\n {'Class':<10s} {'Img':>3s} {'Mean':>8s} {'Std':>8s} "
      f"{'Max':>8s} {'Entropy':>8s} {'HotSpots':>9s}")
print(f" {'-' * 55}")
# Per-image friction statistics for the first 2 images of each class.
for c in range(10):
    maps_c = [(f, l) for f, l in all_friction_maps if l == c][:2]
    for idx, (fric_map, _) in enumerate(maps_c):
        flat = fric_map.reshape(-1, 4)  # (gh*gw, 4)
        fmean = flat.mean(0)
        fstd = flat.std(0)
        fmax = flat.max(0).values
        # Spatial entropy of per-patch total friction: low entropy means
        # friction is concentrated in a few patches.
        fric_total = flat.sum(dim=-1)
        fric_prob = fric_total / (fric_total.sum() + 1e-8)
        entropy = -(fric_prob * (fric_prob + 1e-10).log()).sum().item()
        # Uniform distribution over all patches = max entropy.  Generalized
        # from the original hard-coded np.log(256) (identical for 16×16).
        max_entropy = np.log(flat.shape[0])
        # Hot spots: patches with more than 2× the mean total friction.
        hot = (fric_total > 2 * fric_total.mean()).sum().item()
        print(f" {CLASSES[c]:<10s} {idx:3d} {fmean.mean():8.2f} {fstd.mean():8.2f} "
              f"{fmax.max():8.2f} {entropy/max_entropy:8.3f} {hot:9d}")
# ═══════════════════════════════════════════════════════════════
# 7. FRICTION MAP AS CLASSIFIER — Linear probe on spatial friction
# ═══════════════════════════════════════════════════════════════
print(f"\n{'=' * 70}")
print(" 7. LINEAR PROBE — Can flattened friction maps classify?")
print("=" * 70)
# Each friction map flattened into a gh*gw*4 = 1024-dim feature vector.
X = torch.stack([fmap.reshape(-1) for fmap, _ in all_friction_maps])  # (N, 1024)
y = torch.tensor([lbl for _, lbl in all_friction_maps])  # (N,)
# Random 80/20 train/test split.
N = len(y)
perm = torch.randperm(N)
n_train = int(0.8 * N)
train_idx, test_idx = perm[:n_train], perm[n_train:]
X_train, y_train = X[train_idx], y[train_idx]
X_test, y_test = X[test_idx], y[test_idx]
# Standardize using train-split statistics only.
mean = X_train.mean(0)
std = X_train.std(0).clamp(min=1e-6)
X_train_n = (X_train - mean) / std
X_test_n = (X_test - mean) / std
# Closed-form ridge regression onto one-hot class targets.
lam = 1.0
n_classes = 10
Y_onehot = torch.zeros(n_train, n_classes)
Y_onehot.scatter_(1, y_train.unsqueeze(1), 1.0)
W = torch.linalg.solve(
    X_train_n.T @ X_train_n + lam * torch.eye(X_train_n.shape[1]),
    X_train_n.T @ Y_onehot)
train_pred = (X_train_n @ W).argmax(1)
test_pred = (X_test_n @ W).argmax(1)
train_acc = (train_pred == y_train).float().mean().item()
test_acc = (test_pred == y_test).float().mean().item()
print(f"\n Features: flattened friction map ({X.shape[1]} dims)")
print(f" Train: {n_train}, Test: {N - n_train}")
print(f" Train accuracy: {train_acc:.1%}")
print(f" Test accuracy: {test_acc:.1%}")
print(f" Chance: 10.0%")
# Per-class test accuracy with a bar chart.
print(f"\n {'Class':<10s} {'Acc':>6s}")
print(f" {'-' * 18}")
for c in range(n_classes):
    cls_mask = y_test == c
    if cls_mask.sum() > 0:
        acc = (test_pred[cls_mask] == y_test[cls_mask]).float().mean().item()
        bar = '█' * int(acc * 20)
        print(f" {CLASSES[c]:<10s} {acc:5.1%} {bar}")
print(f"\n VERDICT: {'DISCRIMINATIVE' if test_acc > 0.15 else 'NOT DISCRIMINATIVE'} "
      f"spatial friction signal")
# ═══════════════════════════════════════════════════════════════
# 8. SETTLE MAP ANALYSIS — Same treatment for settle times
# ═══════════════════════════════════════════════════════════════
print(f"\n{'=' * 70}")
print(" 8. SETTLE MAP — Spatial convergence patterns")
print("=" * 70)
# Identical ridge-probe recipe as section 7, applied to the settle maps.
X_s = torch.stack([smap.reshape(-1) for smap, _ in all_settle_maps])
y_s = torch.tensor([lbl for _, lbl in all_settle_maps])
perm_s = torch.randperm(len(y_s))
n_train_s = int(0.8 * len(y_s))
tr_s, te_s = perm_s[:n_train_s], perm_s[n_train_s:]
X_train_s, y_train_s = X_s[tr_s], y_s[tr_s]
X_test_s, y_test_s = X_s[te_s], y_s[te_s]
mean_s = X_train_s.mean(0)
std_s = X_train_s.std(0).clamp(min=1e-6)
X_train_sn = (X_train_s - mean_s) / std_s
X_test_sn = (X_test_s - mean_s) / std_s
Y_onehot_s = torch.zeros(n_train_s, n_classes)
Y_onehot_s.scatter_(1, y_train_s.unsqueeze(1), 1.0)
W_s = torch.linalg.solve(
    X_train_sn.T @ X_train_sn + lam * torch.eye(X_train_sn.shape[1]),
    X_train_sn.T @ Y_onehot_s)
test_pred_s = (X_test_sn @ W_s).argmax(1)
test_acc_s = (test_pred_s == y_test_s).float().mean().item()
print(f" Settle map linear probe:")
print(f" Test accuracy: {test_acc_s:.1%}")
print(f" VERDICT: {'DISCRIMINATIVE' if test_acc_s > 0.15 else 'NOT DISCRIMINATIVE'}")
# ═══════════════════════════════════════════════════════════════
# 9. COMBINED CONDUIT — friction + settle + eigenvalues
# ═══════════════════════════════════════════════════════════════
print(f"\n{'=' * 70}")
print(" 9. COMBINED CONDUIT — All evidence stacked")
print("=" * 70)
# Re-run the model over the first max_collect images to also capture raw
# eigenvalue (S) spatial maps alongside friction and settle.  The loader
# has shuffle=False, so image order matches all_friction_maps above.
# (Dead code removed: the original allocated two never-used lists,
# all_eval_maps / all_combined, and ran a no-op `pass` loop over
# all_friction_maps before collecting.)
print("\n Collecting eigenvalue spatial maps...")
eval_features = []
combined_features = []
combined_labels = []
idx = 0
for images, labels_batch in loader:
    if idx >= max_collect:
        break
    with torch.no_grad():
        out = freckles(images.to(device))
        S = out['svd']['S']
        Vt = out['svd']['Vt']
        B_img, N, D = S.shape
        # Same per-patch Gram construction as the first collection pass.
        S2 = S.pow(2)
        G = torch.einsum('bnij,bnj,bnjk->bnik',
                         Vt.transpose(-2, -1), S2, Vt)
        G_flat = G.reshape(B_img * N, D, D)
        packet = conduit(G_flat)
        fric = packet.friction.reshape(B_img, gh, gw, D)
        sett = packet.settle.reshape(B_img, gh, gw, D)
        evals = S.reshape(B_img, gh, gw, D)  # raw S values as spatial map
    for i in range(B_img):
        if idx >= max_collect:
            break
        # Eigenvalue-only feature vector.
        eval_features.append(evals[i].cpu().reshape(-1))
        # Combined feature vector: friction + settle + eigenvalues.
        combined_features.append(torch.cat([
            fric[i].cpu().reshape(-1),
            sett[i].cpu().reshape(-1),
            evals[i].cpu().reshape(-1),
        ]))
        combined_labels.append(labels_batch[i].item())
        idx += 1
# Shared 80/20 split for the eigenvalue / combined probes below.
X_e = torch.stack(eval_features)
y_e = torch.tensor(combined_labels)
perm_e = torch.randperm(len(y_e))
n_train_e = int(0.8 * len(y_e))
def ridge_probe(X, y, perm, n_train, name, lam=1.0, n_classes=10):
    """Closed-form ridge-regression linear probe; returns test accuracy.

    Standardizes features with train-split statistics, regresses one-hot
    class targets with an L2 penalty, and prints a one-line summary.

    Args:
        X: (N, F) float feature matrix.
        y: (N,) integer class labels in [0, n_classes).
        perm: (N,) index permutation defining the train/test split.
        n_train: number of leading `perm` entries used for training.
        name: label printed in the summary line.
        lam: ridge penalty.  The original added an UNSCALED identity
            (implicit lam=1, inconsistent with the `lam`-scaled probes in
            sections 7/8); now explicit, with a default preserving behavior.
        n_classes: number of target classes (was read from a module global).

    Returns:
        Test accuracy as a float in [0, 1].
    """
    X_tr, y_tr = X[perm[:n_train]], y[perm[:n_train]]
    X_te, y_te = X[perm[n_train:]], y[perm[n_train:]]
    # Standardize with train statistics only (no test leakage).
    m = X_tr.mean(0)
    s = X_tr.std(0).clamp(min=1e-6)
    X_tr_n = (X_tr - m) / s
    X_te_n = (X_te - m) / s
    # One-hot targets, closed-form ridge solution.
    Y_oh = torch.zeros(n_train, n_classes)
    Y_oh.scatter_(1, y_tr.unsqueeze(1), 1.0)
    W = torch.linalg.solve(
        X_tr_n.T @ X_tr_n + lam * torch.eye(X_tr_n.shape[1]),
        X_tr_n.T @ Y_oh)
    acc = ((X_te_n @ W).argmax(1) == y_te).float().mean().item()
    print(f" {name:<30s} dims={X.shape[1]:>5d} test_acc={acc:.1%}")
    return acc
print(f"\n Linear probe comparison (all use same train/test split):\n")
# BUG FIX: the original passed each probe a DIFFERENT random permutation
# (perm, perm_s, perm_e), contradicting the "same train/test split" claim
# printed above.  All four feature matrices index the same first
# max_collect images in the same order (loader is unshuffled), so one
# shared split (perm_e / n_train_e) is valid for every probe.
acc_evals = ridge_probe(X_e, y_e, perm_e, n_train_e, "Eigenvalues (S) spatial")
acc_fric = ridge_probe(X, y, perm_e, n_train_e, "Friction spatial")
acc_sett = ridge_probe(X_s, y_s, perm_e, n_train_e, "Settle spatial")
X_c = torch.stack(combined_features)
acc_comb = ridge_probe(X_c, y_e, perm_e, n_train_e, "Combined (S+fric+settle)")
print(f"\n Chance: 10.0%")
print(f" VERDICT: Combined vs eigenvalues-only lift = "
      f"{(acc_comb - acc_evals) * 100:+.1f} percentage points")
# ═══════════════════════════════════════════════════════════════
# SUMMARY
# ═══════════════════════════════════════════════════════════════
print(f"\n{'=' * 70}")
print(" SPATIAL FRICTION ANALYSIS — SUMMARY")
print("=" * 70)
# One line per finding, pulling the headline numbers computed above.
summary_lines = [
    f" 1. Spatial structure within images: CV = {spatial_cv.mean():.4f}",
    f" 2. Inter-class pattern distance: cos_min = {cos_min:.6f}",
    f" 3. Center vs edge asymmetry: (see table above)",
    f" 4. Per-position F-statistic: max = {overall_f.max():.6f}",
    f" 5. Friction map linear probe: {test_acc:.1%}",
    f" 6. Settle map linear probe: {test_acc_s:.1%}",
    f" 7. Eigenvalue map linear probe: {acc_evals:.1%}",
    f" 8. Combined conduit linear probe: {acc_comb:.1%}",
    f" 9. Conduit lift over eigenvalues: {(acc_comb - acc_evals)*100:+.1f}pp",
]
for line in summary_lines:
    print(line)