Create notebook_cell_3_theorem_2.py
Browse files- notebook_cell_3_theorem_2.py +525 -0
notebook_cell_3_theorem_2.py
ADDED
|
@@ -0,0 +1,525 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Cell 3 β Spatial Friction Map Analysis
|
| 3 |
+
========================================
|
| 4 |
+
The mean friction is uniform across classes (12.19 Β± 0.08).
|
| 5 |
+
But the SPATIAL PATTERN of friction within images might differ.
|
| 6 |
+
|
| 7 |
+
Questions:
|
| 8 |
+
1. Do friction maps have spatial structure? (or uniform per image)
|
| 9 |
+
2. Does the spatial pattern differ across classes?
|
| 10 |
+
3. Do edge/boundary patches have higher friction than interior?
|
| 11 |
+
4. Is per-patch friction discriminative even if per-class mean is not?
|
| 12 |
+
5. What does the friction map look like for individual images?
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
import numpy as np
|
| 18 |
+
from tqdm import tqdm
|
| 19 |
+
|
| 20 |
+
from geolip_core.linalg.conduit import FLEighConduit
|
| 21 |
+
|
| 22 |
+
device = torch.device('cuda')
|
| 23 |
+
|
| 24 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 25 |
+
# LOAD DATA
|
| 26 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 27 |
+
|
| 28 |
+
print("Loading Freckles v40 + CIFAR-10...")
|
| 29 |
+
from geolip_svae import load_model
|
| 30 |
+
import torchvision
|
| 31 |
+
import torchvision.transforms as T
|
| 32 |
+
|
| 33 |
+
freckles, cfg = load_model(hf_version='v40_freckles_noise', device=device)
|
| 34 |
+
freckles.eval()
|
| 35 |
+
|
| 36 |
+
transform = T.Compose([T.Resize(64), T.ToTensor()])
|
| 37 |
+
cifar_test = torchvision.datasets.CIFAR10(
|
| 38 |
+
root='/content/data', train=False, download=True, transform=transform)
|
| 39 |
+
loader = torch.utils.data.DataLoader(
|
| 40 |
+
cifar_test, batch_size=64, shuffle=False, num_workers=4)
|
| 41 |
+
|
| 42 |
+
CLASSES = ['airplane', 'auto', 'bird', 'cat', 'deer',
|
| 43 |
+
'dog', 'frog', 'horse', 'ship', 'truck']
|
| 44 |
+
|
| 45 |
+
conduit = FLEighConduit().to(device)
|
| 46 |
+
gh, gw = 16, 16 # patch grid
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 50 |
+
# COLLECT SPATIAL FRICTION MAPS
|
| 51 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 52 |
+
|
| 53 |
+
print("Collecting spatial friction maps (full test set)...\n")
|
| 54 |
+
|
| 55 |
+
# Per-class friction maps: (10, gh, gw, D=4)
|
| 56 |
+
class_friction_sum = torch.zeros(10, gh, gw, 4)
|
| 57 |
+
class_friction_sq = torch.zeros(10, gh, gw, 4)
|
| 58 |
+
class_settle_sum = torch.zeros(10, gh, gw, 4)
|
| 59 |
+
class_counts = torch.zeros(10)
|
| 60 |
+
|
| 61 |
+
# Also collect per-image statistics for discriminability analysis
|
| 62 |
+
all_friction_maps = [] # list of (friction_map, label)
|
| 63 |
+
all_settle_maps = []
|
| 64 |
+
|
| 65 |
+
n_images_collected = 0
|
| 66 |
+
max_collect = 2000 # collect individual maps for first 2000 images
|
| 67 |
+
|
| 68 |
+
for images, labels in tqdm(loader, desc="Processing"):
|
| 69 |
+
with torch.no_grad():
|
| 70 |
+
out = freckles(images.to(device))
|
| 71 |
+
S = out['svd']['S'] # (B, N, D)
|
| 72 |
+
Vt = out['svd']['Vt'] # (B, N, D, D)
|
| 73 |
+
B_img, N, D = S.shape
|
| 74 |
+
|
| 75 |
+
# Build Gram matrices
|
| 76 |
+
S2 = S.pow(2)
|
| 77 |
+
G = torch.einsum('bnij,bnj,bnjk->bnik',
|
| 78 |
+
Vt.transpose(-2, -1), S2, Vt)
|
| 79 |
+
G_flat = G.reshape(B_img * N, D, D)
|
| 80 |
+
|
| 81 |
+
packet = conduit(G_flat)
|
| 82 |
+
|
| 83 |
+
# Reshape to spatial: (B, gh, gw, D)
|
| 84 |
+
fric_map = packet.friction.reshape(B_img, gh, gw, D)
|
| 85 |
+
sett_map = packet.settle.reshape(B_img, gh, gw, D)
|
| 86 |
+
|
| 87 |
+
fric_cpu = fric_map.cpu()
|
| 88 |
+
sett_cpu = sett_map.cpu()
|
| 89 |
+
|
| 90 |
+
for i in range(B_img):
|
| 91 |
+
c = labels[i].item()
|
| 92 |
+
class_friction_sum[c] += fric_cpu[i]
|
| 93 |
+
class_friction_sq[c] += fric_cpu[i].pow(2)
|
| 94 |
+
class_settle_sum[c] += sett_cpu[i]
|
| 95 |
+
class_counts[c] += 1
|
| 96 |
+
|
| 97 |
+
if n_images_collected < max_collect:
|
| 98 |
+
all_friction_maps.append((fric_cpu[i], c))
|
| 99 |
+
all_settle_maps.append((sett_cpu[i], c))
|
| 100 |
+
n_images_collected += 1
|
| 101 |
+
|
| 102 |
+
print(f"\nCollected {int(class_counts.sum().item())} images, "
|
| 103 |
+
f"{n_images_collected} individual maps\n")
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 107 |
+
# 1. SPATIAL STRUCTURE WITHIN IMAGES
|
| 108 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 109 |
+
|
| 110 |
+
print("=" * 70)
|
| 111 |
+
print(" 1. SPATIAL STRUCTURE β Do friction maps have spatial variance?")
|
| 112 |
+
print("=" * 70)
|
| 113 |
+
|
| 114 |
+
# Per-image spatial variance: does friction vary across patches within ONE image?
|
| 115 |
+
per_image_spatial_var = []
|
| 116 |
+
for fric_map, label in all_friction_maps:
|
| 117 |
+
# fric_map: (gh, gw, 4)
|
| 118 |
+
# Spatial variance: how much does friction vary across the 16x16 grid?
|
| 119 |
+
per_mode_var = fric_map.reshape(-1, 4).var(dim=0) # var across 256 patches
|
| 120 |
+
per_image_spatial_var.append((per_mode_var, label))
|
| 121 |
+
|
| 122 |
+
spatial_vars = torch.stack([v for v, _ in per_image_spatial_var]) # (N, 4)
|
| 123 |
+
|
| 124 |
+
print(f"\n Per-image spatial friction variance (across 256 patches):")
|
| 125 |
+
print(f" Mode 0 (Sβ): mean={spatial_vars[:, 0].mean():.4f} std={spatial_vars[:, 0].std():.4f}")
|
| 126 |
+
print(f" Mode 1 (Sβ): mean={spatial_vars[:, 1].mean():.4f} std={spatial_vars[:, 1].std():.4f}")
|
| 127 |
+
print(f" Mode 2 (Sβ): mean={spatial_vars[:, 2].mean():.4f} std={spatial_vars[:, 2].std():.4f}")
|
| 128 |
+
print(f" Mode 3 (Sβ): mean={spatial_vars[:, 3].mean():.4f} std={spatial_vars[:, 3].std():.4f}")
|
| 129 |
+
|
| 130 |
+
# Coefficient of variation: spatial_std / spatial_mean per image
|
| 131 |
+
spatial_means = torch.stack([f.reshape(-1, 4).mean(0) for f, _ in all_friction_maps])
|
| 132 |
+
spatial_stds = torch.stack([f.reshape(-1, 4).std(0) for f, _ in all_friction_maps])
|
| 133 |
+
spatial_cv = spatial_stds / (spatial_means + 1e-8)
|
| 134 |
+
|
| 135 |
+
print(f"\n Per-image spatial CV (std/mean):")
|
| 136 |
+
for d in range(4):
|
| 137 |
+
print(f" Mode {d}: CV mean={spatial_cv[:, d].mean():.4f} "
|
| 138 |
+
f"median={spatial_cv[:, d].median():.4f} max={spatial_cv[:, d].max():.4f}")
|
| 139 |
+
|
| 140 |
+
has_spatial_structure = spatial_cv.mean() > 0.1
|
| 141 |
+
print(f"\n VERDICT: {'HAS SPATIAL STRUCTURE' if has_spatial_structure else 'SPATIALLY UNIFORM'} "
|
| 142 |
+
f"(mean CV = {spatial_cv.mean():.4f})")
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 146 |
+
# 2. PER-CLASS SPATIAL FRICTION PATTERNS
|
| 147 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 148 |
+
|
| 149 |
+
print(f"\n{'=' * 70}")
|
| 150 |
+
print(" 2. PER-CLASS SPATIAL PATTERNS β Do classes have different friction maps?")
|
| 151 |
+
print("=" * 70)
|
| 152 |
+
|
| 153 |
+
# Average friction map per class
|
| 154 |
+
class_means = class_friction_sum / class_counts[:, None, None, None].clamp(min=1)
|
| 155 |
+
class_vars = class_friction_sq / class_counts[:, None, None, None].clamp(min=1) - class_means.pow(2)
|
| 156 |
+
|
| 157 |
+
# Flatten spatial maps and compare between classes
|
| 158 |
+
class_flat = class_means.reshape(10, -1) # (10, gh*gw*4)
|
| 159 |
+
|
| 160 |
+
# Inter-class distance matrix
|
| 161 |
+
dists = torch.cdist(class_flat, class_flat)
|
| 162 |
+
|
| 163 |
+
print(f"\n Inter-class friction map L2 distances:")
|
| 164 |
+
print(f" {'':>10s}", end="")
|
| 165 |
+
for c in range(10):
|
| 166 |
+
print(f" {CLASSES[c][:5]:>6s}", end="")
|
| 167 |
+
print()
|
| 168 |
+
for c1 in range(10):
|
| 169 |
+
print(f" {CLASSES[c1][:10]:>10s}", end="")
|
| 170 |
+
for c2 in range(10):
|
| 171 |
+
print(f" {dists[c1, c2]:6.3f}", end="")
|
| 172 |
+
print()
|
| 173 |
+
|
| 174 |
+
# Mean inter-class vs intra-class distance
|
| 175 |
+
inter_mask = ~torch.eye(10, dtype=torch.bool)
|
| 176 |
+
inter_dist = dists[inter_mask].mean().item()
|
| 177 |
+
print(f"\n Mean inter-class distance: {inter_dist:.4f}")
|
| 178 |
+
|
| 179 |
+
# Cosine similarity between class friction maps
|
| 180 |
+
class_flat_norm = F.normalize(class_flat, dim=-1)
|
| 181 |
+
cos_sim = class_flat_norm @ class_flat_norm.T
|
| 182 |
+
cos_off_diag = cos_sim[inter_mask].mean().item()
|
| 183 |
+
cos_min = cos_sim[inter_mask].min().item()
|
| 184 |
+
print(f" Mean cosine similarity: {cos_off_diag:.6f}")
|
| 185 |
+
print(f" Min cosine similarity: {cos_min:.6f}")
|
| 186 |
+
print(f" VERDICT: {'DISTINCT PATTERNS' if cos_min < 0.99 else 'NEARLY IDENTICAL PATTERNS'}")
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 190 |
+
# 3. CENTER vs EDGE FRICTION
|
| 191 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 192 |
+
|
| 193 |
+
print(f"\n{'=' * 70}")
|
| 194 |
+
print(" 3. CENTER vs EDGE β Do boundary patches have higher friction?")
|
| 195 |
+
print("=" * 70)
|
| 196 |
+
|
| 197 |
+
# Define center and edge regions
|
| 198 |
+
center_mask = torch.zeros(gh, gw, dtype=torch.bool)
|
| 199 |
+
center_mask[4:12, 4:12] = True # center 8Γ8
|
| 200 |
+
edge_mask = ~center_mask # border ring
|
| 201 |
+
|
| 202 |
+
for c in range(10):
|
| 203 |
+
fric_c = class_means[c] # (gh, gw, 4)
|
| 204 |
+
center_fric = fric_c[center_mask].mean().item()
|
| 205 |
+
edge_fric = fric_c[edge_mask].mean().item()
|
| 206 |
+
ratio = edge_fric / (center_fric + 1e-8)
|
| 207 |
+
if c == 0:
|
| 208 |
+
print(f"\n {'Class':<10s} {'Center':>8s} {'Edge':>8s} {'Edge/Center':>12s}")
|
| 209 |
+
print(f" {'-' * 40}")
|
| 210 |
+
print(f" {CLASSES[c]:<10s} {center_fric:8.3f} {edge_fric:8.3f} {ratio:12.4f}")
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 214 |
+
# 4. PER-PATCH-POSITION DISCRIMINABILITY
|
| 215 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 216 |
+
|
| 217 |
+
print(f"\n{'=' * 70}")
|
| 218 |
+
print(" 4. PER-PATCH-POSITION DISCRIMINABILITY")
|
| 219 |
+
print("=" * 70)
|
| 220 |
+
|
| 221 |
+
# For each patch position (i,j), is friction discriminative across classes?
|
| 222 |
+
# Use inter-class variance / intra-class variance ratio (F-statistic proxy)
|
| 223 |
+
|
| 224 |
+
position_f_stat = torch.zeros(gh, gw, 4)
|
| 225 |
+
|
| 226 |
+
for pi in range(gh):
|
| 227 |
+
for pj in range(gw):
|
| 228 |
+
for d in range(4):
|
| 229 |
+
# Class means at this position
|
| 230 |
+
c_means = class_means[:, pi, pj, d] # (10,)
|
| 231 |
+
# Inter-class variance
|
| 232 |
+
inter_var = c_means.var().item()
|
| 233 |
+
# Intra-class variance (averaged)
|
| 234 |
+
intra_var = class_vars[:, pi, pj, d].mean().item()
|
| 235 |
+
position_f_stat[pi, pj, d] = inter_var / (intra_var + 1e-10)
|
| 236 |
+
|
| 237 |
+
# Summary
|
| 238 |
+
print(f"\n F-statistic (inter-class var / intra-class var) per mode:")
|
| 239 |
+
for d in range(4):
|
| 240 |
+
fs = position_f_stat[:, :, d]
|
| 241 |
+
print(f" Mode {d}: mean={fs.mean():.6f} max={fs.max():.6f} "
|
| 242 |
+
f"top 5% threshold={fs.quantile(0.95):.6f}")
|
| 243 |
+
|
| 244 |
+
# Best discriminative positions
|
| 245 |
+
for d in range(4):
|
| 246 |
+
fs = position_f_stat[:, :, d]
|
| 247 |
+
best_idx = fs.argmax()
|
| 248 |
+
bi, bj = best_idx // gw, best_idx % gw
|
| 249 |
+
print(f" Mode {d} best position: ({bi.item()}, {bj.item()}) F={fs.max():.6f}")
|
| 250 |
+
|
| 251 |
+
overall_f = position_f_stat.mean(dim=-1) # avg across modes
|
| 252 |
+
print(f"\n Overall best discriminative patch position: "
|
| 253 |
+
f"{(overall_f.argmax() // gw).item()}, {(overall_f.argmax() % gw).item()} "
|
| 254 |
+
f"F={overall_f.max():.6f}")
|
| 255 |
+
print(f" Overall mean F-statistic: {overall_f.mean():.6f}")
|
| 256 |
+
print(f" VERDICT: {'POSITIONALLY DISCRIMINATIVE' if overall_f.max() > 0.01 else 'NOT DISCRIMINATIVE'}")
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 260 |
+
# 5. PER-MODE ANALYSIS β Which SVD mode carries most spatial variance?
|
| 261 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 262 |
+
|
| 263 |
+
print(f"\n{'=' * 70}")
|
| 264 |
+
print(" 5. PER-MODE SPATIAL VARIANCE β Which mode has the most structure?")
|
| 265 |
+
print("=" * 70)
|
| 266 |
+
|
| 267 |
+
for d in range(4):
|
| 268 |
+
# Spatial variance of mean friction map (across all images)
|
| 269 |
+
overall_mean_map = class_friction_sum.sum(0) / class_counts.sum() # (gh, gw, 4)
|
| 270 |
+
mode_map = overall_mean_map[:, :, d]
|
| 271 |
+
sv = mode_map.var().item()
|
| 272 |
+
sm = mode_map.mean().item()
|
| 273 |
+
print(f" Mode {d}: map_mean={sm:.4f} map_var={sv:.6f} map_cv={sv**0.5/(sm+1e-8):.4f}")
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 277 |
+
# 6. INDIVIDUAL IMAGE FRICTION MAPS
|
| 278 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 279 |
+
|
| 280 |
+
print(f"\n{'=' * 70}")
|
| 281 |
+
print(" 6. SAMPLE FRICTION MAPS β Individual images")
|
| 282 |
+
print("=" * 70)
|
| 283 |
+
|
| 284 |
+
# Show friction statistics for 2 images per class
|
| 285 |
+
for c in range(10):
|
| 286 |
+
maps_c = [(f, l) for f, l in all_friction_maps if l == c][:2]
|
| 287 |
+
for idx, (fric_map, _) in enumerate(maps_c):
|
| 288 |
+
# fric_map: (gh, gw, 4)
|
| 289 |
+
flat = fric_map.reshape(-1, 4)
|
| 290 |
+
fmean = flat.mean(0)
|
| 291 |
+
fstd = flat.std(0)
|
| 292 |
+
fmin = flat.min(0).values
|
| 293 |
+
fmax = flat.max(0).values
|
| 294 |
+
|
| 295 |
+
# Spatial entropy: how concentrated is the friction?
|
| 296 |
+
fric_total = flat.sum(dim=-1) # per-patch total friction
|
| 297 |
+
fric_prob = fric_total / (fric_total.sum() + 1e-8)
|
| 298 |
+
entropy = -(fric_prob * (fric_prob + 1e-10).log()).sum().item()
|
| 299 |
+
max_entropy = np.log(256) # uniform = max entropy
|
| 300 |
+
|
| 301 |
+
# Hot spots: patches with friction > 2Γ mean
|
| 302 |
+
hot = (fric_total > 2 * fric_total.mean()).sum().item()
|
| 303 |
+
|
| 304 |
+
if idx == 0 and c == 0:
|
| 305 |
+
print(f"\n {'Class':<10s} {'Img':>3s} {'Mean':>8s} {'Std':>8s} "
|
| 306 |
+
f"{'Max':>8s} {'Entropy':>8s} {'HotSpots':>9s}")
|
| 307 |
+
print(f" {'-' * 55}")
|
| 308 |
+
|
| 309 |
+
print(f" {CLASSES[c]:<10s} {idx:3d} {fmean.mean():8.2f} {fstd.mean():8.2f} "
|
| 310 |
+
f"{fmax.max():8.2f} {entropy/max_entropy:8.3f} {hot:9d}")
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 314 |
+
# 7. FRICTION MAP AS CLASSIFIER β Linear probe on spatial friction
|
| 315 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 316 |
+
|
| 317 |
+
print(f"\n{'=' * 70}")
|
| 318 |
+
print(" 7. LINEAR PROBE β Can flattened friction maps classify?")
|
| 319 |
+
print("=" * 70)
|
| 320 |
+
|
| 321 |
+
# Collect features and labels
|
| 322 |
+
features = []
|
| 323 |
+
labels_all = []
|
| 324 |
+
for fric_map, label in all_friction_maps:
|
| 325 |
+
features.append(fric_map.reshape(-1)) # (gh*gw*4,) = 1024
|
| 326 |
+
labels_all.append(label)
|
| 327 |
+
|
| 328 |
+
X = torch.stack(features) # (N, 1024)
|
| 329 |
+
y = torch.tensor(labels_all) # (N,)
|
| 330 |
+
|
| 331 |
+
# Train/test split
|
| 332 |
+
N = len(y)
|
| 333 |
+
perm = torch.randperm(N)
|
| 334 |
+
n_train = int(0.8 * N)
|
| 335 |
+
X_train, y_train = X[perm[:n_train]], y[perm[:n_train]]
|
| 336 |
+
X_test, y_test = X[perm[n_train:]], y[perm[n_train:]]
|
| 337 |
+
|
| 338 |
+
# Standardize
|
| 339 |
+
mean = X_train.mean(0)
|
| 340 |
+
std = X_train.std(0).clamp(min=1e-6)
|
| 341 |
+
X_train_n = (X_train - mean) / std
|
| 342 |
+
X_test_n = (X_test - mean) / std
|
| 343 |
+
|
| 344 |
+
# Ridge regression (closed form, no training loop)
|
| 345 |
+
lam = 1.0
|
| 346 |
+
n_classes = 10
|
| 347 |
+
Y_onehot = torch.zeros(n_train, n_classes)
|
| 348 |
+
Y_onehot.scatter_(1, y_train.unsqueeze(1), 1.0)
|
| 349 |
+
|
| 350 |
+
XtX = X_train_n.T @ X_train_n + lam * torch.eye(X_train_n.shape[1])
|
| 351 |
+
XtY = X_train_n.T @ Y_onehot
|
| 352 |
+
W = torch.linalg.solve(XtX, XtY)
|
| 353 |
+
|
| 354 |
+
train_pred = (X_train_n @ W).argmax(1)
|
| 355 |
+
test_pred = (X_test_n @ W).argmax(1)
|
| 356 |
+
train_acc = (train_pred == y_train).float().mean().item()
|
| 357 |
+
test_acc = (test_pred == y_test).float().mean().item()
|
| 358 |
+
|
| 359 |
+
print(f"\n Features: flattened friction map ({X.shape[1]} dims)")
|
| 360 |
+
print(f" Train: {n_train}, Test: {N - n_train}")
|
| 361 |
+
print(f" Train accuracy: {train_acc:.1%}")
|
| 362 |
+
print(f" Test accuracy: {test_acc:.1%}")
|
| 363 |
+
print(f" Chance: 10.0%")
|
| 364 |
+
|
| 365 |
+
# Per-class accuracy
|
| 366 |
+
print(f"\n {'Class':<10s} {'Acc':>6s}")
|
| 367 |
+
print(f" {'-' * 18}")
|
| 368 |
+
for c in range(n_classes):
|
| 369 |
+
mask = y_test == c
|
| 370 |
+
if mask.sum() > 0:
|
| 371 |
+
acc = (test_pred[mask] == y_test[mask]).float().mean().item()
|
| 372 |
+
bar = 'β' * int(acc * 20)
|
| 373 |
+
print(f" {CLASSES[c]:<10s} {acc:5.1%} {bar}")
|
| 374 |
+
|
| 375 |
+
print(f"\n VERDICT: {'DISCRIMINATIVE' if test_acc > 0.15 else 'NOT DISCRIMINATIVE'} "
|
| 376 |
+
f"spatial friction signal")
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 380 |
+
# 8. SETTLE MAP ANALYSIS β Same treatment for settle times
|
| 381 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 382 |
+
|
| 383 |
+
print(f"\n{'=' * 70}")
|
| 384 |
+
print(" 8. SETTLE MAP β Spatial convergence patterns")
|
| 385 |
+
print("=" * 70)
|
| 386 |
+
|
| 387 |
+
settle_features = []
|
| 388 |
+
settle_labels = []
|
| 389 |
+
for sett_map, label in all_settle_maps:
|
| 390 |
+
settle_features.append(sett_map.reshape(-1))
|
| 391 |
+
settle_labels.append(label)
|
| 392 |
+
|
| 393 |
+
X_s = torch.stack(settle_features)
|
| 394 |
+
y_s = torch.tensor(settle_labels)
|
| 395 |
+
|
| 396 |
+
perm_s = torch.randperm(len(y_s))
|
| 397 |
+
n_train_s = int(0.8 * len(y_s))
|
| 398 |
+
X_train_s, y_train_s = X_s[perm_s[:n_train_s]], y_s[perm_s[:n_train_s]]
|
| 399 |
+
X_test_s, y_test_s = X_s[perm_s[n_train_s:]], y_s[perm_s[n_train_s:]]
|
| 400 |
+
|
| 401 |
+
mean_s = X_train_s.mean(0)
|
| 402 |
+
std_s = X_train_s.std(0).clamp(min=1e-6)
|
| 403 |
+
X_train_sn = (X_train_s - mean_s) / std_s
|
| 404 |
+
X_test_sn = (X_test_s - mean_s) / std_s
|
| 405 |
+
|
| 406 |
+
Y_onehot_s = torch.zeros(n_train_s, n_classes)
|
| 407 |
+
Y_onehot_s.scatter_(1, y_train_s.unsqueeze(1), 1.0)
|
| 408 |
+
XtX_s = X_train_sn.T @ X_train_sn + lam * torch.eye(X_train_sn.shape[1])
|
| 409 |
+
XtY_s = X_train_sn.T @ Y_onehot_s
|
| 410 |
+
W_s = torch.linalg.solve(XtX_s, XtY_s)
|
| 411 |
+
|
| 412 |
+
test_pred_s = (X_test_sn @ W_s).argmax(1)
|
| 413 |
+
test_acc_s = (test_pred_s == y_test_s).float().mean().item()
|
| 414 |
+
|
| 415 |
+
print(f" Settle map linear probe:")
|
| 416 |
+
print(f" Test accuracy: {test_acc_s:.1%}")
|
| 417 |
+
print(f" VERDICT: {'DISCRIMINATIVE' if test_acc_s > 0.15 else 'NOT DISCRIMINATIVE'}")
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 421 |
+
# 9. COMBINED CONDUIT β friction + settle + eigenvalues
|
| 422 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 423 |
+
|
| 424 |
+
print(f"\n{'=' * 70}")
|
| 425 |
+
print(" 9. COMBINED CONDUIT β All evidence stacked")
|
| 426 |
+
print("=" * 70)
|
| 427 |
+
|
| 428 |
+
# Also test: raw eigenvalues (S values) as spatial maps for comparison
|
| 429 |
+
print("\n Collecting eigenvalue spatial maps...")
|
| 430 |
+
all_eval_maps = []
|
| 431 |
+
all_combined = []
|
| 432 |
+
|
| 433 |
+
for fric_map, label in all_friction_maps:
|
| 434 |
+
pass # Already collected
|
| 435 |
+
|
| 436 |
+
# Re-collect with eigenvalues
|
| 437 |
+
eval_features = []
|
| 438 |
+
combined_features = []
|
| 439 |
+
combined_labels = []
|
| 440 |
+
|
| 441 |
+
idx = 0
|
| 442 |
+
for images, labels_batch in loader:
|
| 443 |
+
if idx >= max_collect:
|
| 444 |
+
break
|
| 445 |
+
with torch.no_grad():
|
| 446 |
+
out = freckles(images.to(device))
|
| 447 |
+
S = out['svd']['S']
|
| 448 |
+
Vt = out['svd']['Vt']
|
| 449 |
+
B_img, N, D = S.shape
|
| 450 |
+
|
| 451 |
+
S2 = S.pow(2)
|
| 452 |
+
G = torch.einsum('bnij,bnj,bnjk->bnik',
|
| 453 |
+
Vt.transpose(-2, -1), S2, Vt)
|
| 454 |
+
G_flat = G.reshape(B_img * N, D, D)
|
| 455 |
+
packet = conduit(G_flat)
|
| 456 |
+
|
| 457 |
+
fric = packet.friction.reshape(B_img, gh, gw, D)
|
| 458 |
+
sett = packet.settle.reshape(B_img, gh, gw, D)
|
| 459 |
+
evals = S.reshape(B_img, gh, gw, D) # S values as spatial map
|
| 460 |
+
|
| 461 |
+
for i in range(B_img):
|
| 462 |
+
if idx >= max_collect:
|
| 463 |
+
break
|
| 464 |
+
# Eigenvalue spatial map
|
| 465 |
+
eval_features.append(evals[i].cpu().reshape(-1))
|
| 466 |
+
# Combined: friction + settle + eigenvalues
|
| 467 |
+
combined = torch.cat([
|
| 468 |
+
fric[i].cpu().reshape(-1),
|
| 469 |
+
sett[i].cpu().reshape(-1),
|
| 470 |
+
evals[i].cpu().reshape(-1),
|
| 471 |
+
])
|
| 472 |
+
combined_features.append(combined)
|
| 473 |
+
combined_labels.append(labels_batch[i].item())
|
| 474 |
+
idx += 1
|
| 475 |
+
|
| 476 |
+
# Eigenvalue-only probe
|
| 477 |
+
X_e = torch.stack(eval_features)
|
| 478 |
+
y_e = torch.tensor(combined_labels)
|
| 479 |
+
|
| 480 |
+
perm_e = torch.randperm(len(y_e))
|
| 481 |
+
n_train_e = int(0.8 * len(y_e))
|
| 482 |
+
|
| 483 |
+
def ridge_probe(X, y, perm, n_train, name):
|
| 484 |
+
X_tr, y_tr = X[perm[:n_train]], y[perm[:n_train]]
|
| 485 |
+
X_te, y_te = X[perm[n_train:]], y[perm[n_train:]]
|
| 486 |
+
m = X_tr.mean(0)
|
| 487 |
+
s = X_tr.std(0).clamp(min=1e-6)
|
| 488 |
+
X_tr_n = (X_tr - m) / s
|
| 489 |
+
X_te_n = (X_te - m) / s
|
| 490 |
+
Y_oh = torch.zeros(n_train, n_classes)
|
| 491 |
+
Y_oh.scatter_(1, y_tr.unsqueeze(1), 1.0)
|
| 492 |
+
W = torch.linalg.solve(X_tr_n.T @ X_tr_n + torch.eye(X_tr_n.shape[1]), X_tr_n.T @ Y_oh)
|
| 493 |
+
acc = ((X_te_n @ W).argmax(1) == y_te).float().mean().item()
|
| 494 |
+
print(f" {name:<30s} dims={X.shape[1]:>5d} test_acc={acc:.1%}")
|
| 495 |
+
return acc
|
| 496 |
+
|
| 497 |
+
print(f"\n Linear probe comparison (all use same train/test split):\n")
|
| 498 |
+
acc_evals = ridge_probe(X_e, y_e, perm_e, n_train_e, "Eigenvalues (S) spatial")
|
| 499 |
+
acc_fric = ridge_probe(X, y, perm, n_train, "Friction spatial")
|
| 500 |
+
acc_sett = ridge_probe(X_s, y_s, perm_s, n_train_s, "Settle spatial")
|
| 501 |
+
|
| 502 |
+
X_c = torch.stack(combined_features)
|
| 503 |
+
acc_comb = ridge_probe(X_c, y_e, perm_e, n_train_e, "Combined (S+fric+settle)")
|
| 504 |
+
|
| 505 |
+
print(f"\n Chance: 10.0%")
|
| 506 |
+
print(f" VERDICT: Combined vs eigenvalues-only lift = "
|
| 507 |
+
f"{(acc_comb - acc_evals) * 100:+.1f} percentage points")
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 511 |
+
# SUMMARY
|
| 512 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 513 |
+
|
| 514 |
+
print(f"\n{'=' * 70}")
|
| 515 |
+
print(" SPATIAL FRICTION ANALYSIS β SUMMARY")
|
| 516 |
+
print("=" * 70)
|
| 517 |
+
print(f" 1. Spatial structure within images: CV = {spatial_cv.mean():.4f}")
|
| 518 |
+
print(f" 2. Inter-class pattern distance: cos_min = {cos_min:.6f}")
|
| 519 |
+
print(f" 3. Center vs edge asymmetry: (see table above)")
|
| 520 |
+
print(f" 4. Per-position F-statistic: max = {overall_f.max():.6f}")
|
| 521 |
+
print(f" 5. Friction map linear probe: {test_acc:.1%}")
|
| 522 |
+
print(f" 6. Settle map linear probe: {test_acc_s:.1%}")
|
| 523 |
+
print(f" 7. Eigenvalue map linear probe: {acc_evals:.1%}")
|
| 524 |
+
print(f" 8. Combined conduit linear probe: {acc_comb:.1%}")
|
| 525 |
+
print(f" 9. Conduit lift over eigenvalues: {(acc_comb - acc_evals)*100:+.1f}pp")
|