Create cell_4_proper_experiment_3.py
Browse files- cell_4_proper_experiment_3.py +482 -0
cell_4_proper_experiment_3.py
ADDED
|
@@ -0,0 +1,482 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
Cell 4 — Theorem 3: Release Fidelity
=======================================
The light speeding back up after leaving the lens.

Full encode->SVD->decode round-trip reconstruction analysis.
NOT the SVD-only residual (which is ~1e-12).
The FULL decoder reconstruction — where the model chooses
what to preserve and what to lose.

Questions:
1. Does per-patch reconstruction error vary spatially?
2. Does it differ across classes?
3. Per-mode reconstruction: which modes matter for which patches?
4. Does the release residual map classify better than friction?
5. Combined release + friction + eigenvalues — full conduit test
6. Where does the model FAIL to reconstruct? Those are the boundaries.
"""

import torch
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm

# Fall back to CPU when CUDA is absent so the cell still runs (slowly)
# outside a GPU runtime; the original hard-coded 'cuda' and crashed on
# CPU-only hosts at the first .to(device).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# ───────────────────────────────────────────────────────────────
# LOAD — model checkpoint, CIFAR-10 test split, patch-grid geometry
# ───────────────────────────────────────────────────────────────

print("Loading Freckles v40 + CIFAR-10...")
# Project-local imports: checkpoint loader plus patch split/stitch helpers.
# NOTE(review): stitch_patches is imported but never called in this cell —
# presumably kept for interactive use; verify before removing.
from geolip_svae import load_model
from geolip_svae.model import extract_patches, stitch_patches
import torchvision
import torchvision.transforms as T

# Pretrained SVAE ("Freckles" v40, noise variant), inference mode only.
freckles, cfg = load_model(hf_version='v40_freckles_noise', device=device)
freckles.eval()

ps = freckles.patch_size  # 4
# CIFAR-10 test images upscaled to 64x64 so the ps-pixel patch grid is 16x16.
transform = T.Compose([T.Resize(64), T.ToTensor()])
cifar_test = torchvision.datasets.CIFAR10(
    root='/content/data', train=False, download=True, transform=transform)
# shuffle=False keeps image order deterministic across the two passes below.
loader = torch.utils.data.DataLoader(
    cifar_test, batch_size=64, shuffle=False, num_workers=4)

# Display names for the ten CIFAR-10 labels, in label order.
CLASSES = ['airplane', 'auto', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']

# Patch-grid height/width and total patches per image.
gh, gw = 64 // ps, 64 // ps  # 16, 16
n_patches = gh * gw  # 256
# ───────────────────────────────────────────────────────────────
# 1. FULL ROUND-TRIP RECONSTRUCTION — Per-patch error maps
# ───────────────────────────────────────────────────────────────

print("\n" + "=" * 70)
print(" 1. FULL ROUND-TRIP β Per-patch reconstruction error")
print("=" * 70)

print("\nCollecting per-patch reconstruction errors...\n")

# Per-class running sums of per-patch MSE (and its square, for variance
# later in 1b), accumulated over the full test set on the (gh, gw) grid.
class_error_sum = torch.zeros(10, gh, gw)
class_error_sq = torch.zeros(10, gh, gw)
class_counts = torch.zeros(10)

# Individual per-image maps kept for the linear probes (capped below).
all_error_maps = []  # (error_map, label)
all_s_maps = []      # (S_map, label)  NOTE(review): collected but unused in this cell
max_collect = 2000
n_collected = 0

for images, labels in tqdm(loader, desc="Reconstructing"):
    with torch.no_grad():
        images_gpu = images.to(device)
        out = freckles(images_gpu)
        recon = out['recon']

        B = images_gpu.shape[0]
        # NOTE(review): assumes out['svd']['S'] is (B, N, D) per the comment
        # below — confirm against the geolip_svae forward contract.
        S = out['svd']['S']  # (B, N, D)

        # Per-patch error: split input and recon into patches, compare
        # Input patches: (B, N, C*ps*ps)
        input_patches, _, _ = extract_patches(images_gpu, ps)
        recon_patches, _, _ = extract_patches(recon, ps)

        # Per-patch MSE: (B, N)
        patch_mse = (input_patches - recon_patches).pow(2).mean(dim=-1)

        # Reshape to spatial: (B, gh, gw)
        error_map = patch_mse.reshape(B, gh, gw)
        s_map = S.reshape(B, gh, gw, -1)

        # Move off-GPU once per batch; accumulation happens on CPU tensors.
        error_cpu = error_map.cpu()
        s_cpu = s_map.cpu()

        for i in range(B):
            c = labels[i].item()
            class_error_sum[c] += error_cpu[i]
            class_error_sq[c] += error_cpu[i].pow(2)
            class_counts[c] += 1

            # Keep only the first max_collect individual maps for probing.
            if n_collected < max_collect:
                all_error_maps.append((error_cpu[i], c))
                all_s_maps.append((s_cpu[i], c))
                n_collected += 1

print(f"Collected {int(class_counts.sum().item())} images, "
      f"{n_collected} individual maps\n")
# ───────────────────────────────────────────────────────────────
# 1a. SPATIAL STRUCTURE OF RECONSTRUCTION ERROR
# ───────────────────────────────────────────────────────────────

print("=" * 70)
print(" 1a. SPATIAL STRUCTURE β Does recon error vary across patches?")
print("=" * 70)

# Coefficient of variation (std / mean) of the per-patch error within a
# single image; 0 would mean the decoder loses the same amount everywhere.
def _spatial_cv(err_map):
    vals = err_map.reshape(-1)
    return (vals.std() / (vals.mean() + 1e-10)).item()

per_image_cv = [_spatial_cv(err_map) for err_map, _ in all_error_maps]

cv_arr = np.array(per_image_cv)
print(f"\n Per-image spatial CV of reconstruction error:")
print(f" Mean CV: {cv_arr.mean():.4f}")
print(f" Median CV: {np.median(cv_arr):.4f}")
print(f" Min CV: {cv_arr.min():.4f}")
print(f" Max CV: {cv_arr.max():.4f}")
print(f" VERDICT: {'HAS SPATIAL STRUCTURE' if cv_arr.mean() > 0.1 else 'SPATIALLY UNIFORM'}")
# ───────────────────────────────────────────────────────────────
# 1b. PER-CLASS RECONSTRUCTION ERROR
# ───────────────────────────────────────────────────────────────

print(f"\n{'=' * 70}")
print(" 1b. PER-CLASS RECONSTRUCTION ERROR")
print("=" * 70)

# Turn the running sums into per-class mean and variance maps; the clamp
# guards against a class that never appeared (division by zero).
_denom = class_counts[:, None, None].clamp(min=1)
class_means = class_error_sum / _denom
class_vars = class_error_sq / _denom - class_means.pow(2)

print(f"\n {'Class':<10s} {'Mean MSE':>10s} {'Std MSE':>10s} {'Max patch':>10s}")
print(f" {'-' * 42}")
for cls_idx, cls_map in enumerate(class_means):
    print(f" {CLASSES[cls_idx]:<10s} {cls_map.mean():10.6f} {cls_map.std():10.6f} {cls_map.max():10.6f}")

# Inter-class distance: cosine similarity between the flattened,
# L2-normalized per-class error maps.
class_flat = class_means.reshape(10, -1)
class_flat_norm = F.normalize(class_flat, dim=-1)
cos_sim = class_flat_norm @ class_flat_norm.T
inter_mask = ~torch.eye(10, dtype=torch.bool)
print(f"\n Mean inter-class cosine similarity: {cos_sim[inter_mask].mean():.6f}")
print(f" Min inter-class cosine similarity: {cos_sim[inter_mask].min():.6f}")
print(f" VERDICT: {'DISTINCT PATTERNS' if cos_sim[inter_mask].min() < 0.99 else 'SIMILAR PATTERNS'}")
# ───────────────────────────────────────────────────────────────
# 2. CENTER vs EDGE RECONSTRUCTION
# ───────────────────────────────────────────────────────────────

print(f"\n{'=' * 70}")
print(" 2. CENTER vs EDGE β Where does reconstruction fail?")
print("=" * 70)

# Region boundaries derived from the grid size instead of the original
# hard-coded 4/12 indices, so the split stays correct if gh/gw ever
# change. For gh = gw = 16 these are identical (4 and 12).
h_lo, h_hi = gh // 4, 3 * gh // 4
w_lo, w_hi = gw // 4, 3 * gw // 4

# Central half of the grid in each dimension; everything else is "edge".
center_mask = torch.zeros(gh, gw, dtype=torch.bool)
center_mask[h_lo:h_hi, w_lo:w_hi] = True
edge_mask = ~center_mask

# Corner masks for finer granularity: the four quarter-size corner tiles.
corner_mask = torch.zeros(gh, gw, dtype=torch.bool)
corner_mask[:h_lo, :w_lo] = True
corner_mask[:h_lo, w_hi:] = True
corner_mask[h_hi:, :w_lo] = True
corner_mask[h_hi:, w_hi:] = True

print(f"\n {'Class':<10s} {'Center':>8s} {'Edge':>8s} {'Corner':>8s} {'E/C ratio':>10s}")
print(f" {'-' * 48}")
for c in range(10):
    m = class_means[c]
    center = m[center_mask].mean().item()
    edge = m[edge_mask].mean().item()
    corner = m[corner_mask].mean().item()
    ratio = edge / (center + 1e-10)  # >1 means edges reconstruct worse
    print(f" {CLASSES[c]:<10s} {center:8.6f} {edge:8.6f} {corner:8.6f} {ratio:10.4f}")
# ───────────────────────────────────────────────────────────────
# 3. PER-MODE RECONSTRUCTION — Which modes carry class signal?
# ───────────────────────────────────────────────────────────────

print(f"\n{'=' * 70}")
print(" 3. PER-MODE RECONSTRUCTION β Ablating SVD modes")
print("=" * 70)

print("\nReconstructing with individual modes...")

# For a small fixed subset, measure each SVD mode's contribution.
n_ablate = 256
subset = torch.utils.data.Subset(cifar_test, range(n_ablate))
ablate_loader = torch.utils.data.DataLoader(subset, batch_size=64)

# Keyed lazily by mode index so the dict adapts to the model's actual
# mode count D. (The original pre-built keys with range(4), which would
# KeyError or silently leave empty lists if D != 4.)
mode_errors = {}
mode_labels = []
full_errors = []

for images, labels in ablate_loader:
    with torch.no_grad():
        images_gpu = images.to(device)
        out = freckles(images_gpu)

        S = out['svd']['S']    # (B, N, D)
        Vt = out['svd']['Vt']  # (B, N, D, D)
        B_img, N, D = S.shape

        # Full reconstruction error per patch, for the rightmost column.
        recon = out['recon']
        input_p, _, _ = extract_patches(images_gpu, ps)
        recon_p, _, _ = extract_patches(recon, ps)
        full_err = (input_p - recon_p).pow(2).mean(dim=-1)  # (B, N)
        full_errors.append(full_err.cpu())

        # Per-mode signal. A literal single-mode decode
        # (U @ diag(S_k) @ Vt) was abandoned because the SVAE decoder
        # uses cross-attention, so the rank-1 product is not what the
        # decoder actually consumes; the original computed it anyway and
        # discarded the result. We record only mode k's fraction of the
        # total spectral energy per patch:  s_k^2 / sum_j s_j^2.
        for k in range(D):
            mode_energy = S[:, :, k].pow(2) / (S.pow(2).sum(dim=-1) + 1e-10)
            mode_errors.setdefault(k, []).append(mode_energy.cpu())

    mode_labels.append(labels)

mode_labels = torch.cat(mode_labels)
full_errors = torch.cat(full_errors)  # (N_img, N_patches)

# Concatenate each mode's batches once, outside the per-class loop (the
# original re-ran torch.cat for every class x mode pair).
n_modes = len(mode_errors)
mode_energy_all = {k: torch.cat(mode_errors[k]) for k in range(n_modes)}

print(f"\n Per-mode energy fraction (how much each mode contributes):")
print(f"\n {'Class':<10s}", end="")
for k in range(n_modes):
    print(f" {'Mode'+str(k):>8s}", end="")
print(f" {'FullMSE':>10s}")
print(f" {'-' * 50}")

for c in range(10):
    mask = mode_labels == c
    if mask.sum() == 0:
        continue
    print(f" {CLASSES[c]:<10s}", end="")
    for k in range(n_modes):
        energy = mode_energy_all[k][mask].mean().item()
        print(f" {energy:8.4f}", end="")
    ferr = full_errors[mask].mean().item()
    print(f" {ferr:10.6f}")
# ───────────────────────────────────────────────────────────────
# 4. RECONSTRUCTION ERROR AS CLASSIFIER
# ───────────────────────────────────────────────────────────────

print(f"\n{'=' * 70}")
print(" 4. LINEAR PROBE β Reconstruction error maps as features")
print("=" * 70)

# Each image's flattened (gh*gw,) error map becomes one feature vector.
error_features = [err_map.reshape(-1) for err_map, _ in all_error_maps]  # (256,)
error_labels = [lbl for _, lbl in all_error_maps]

X_err = torch.stack(error_features)  # (N, 256)
y_err = torch.tensor(error_labels)

# One fixed random 80/20 split, shared by all probes on these features.
N = len(y_err)
perm = torch.randperm(N)
n_train = int(0.8 * N)
n_classes = 10
def ridge_probe(X, y, perm, n_train, name, lam=1.0, class_names=None, num_classes=None):
    """Fit a closed-form one-vs-all ridge classifier on standardized features.

    Trains ridge regression onto one-hot targets, prints overall
    train/test accuracy plus a per-class accuracy table, and returns the
    test accuracy.

    Args:
        X: (N, d) float feature matrix.
        y: (N,) integer class labels.
        perm: (N,) index permutation defining the train/test split.
        n_train: number of leading `perm` indices used for training.
        name: label printed next to the results.
        lam: ridge regularization strength.
        class_names: per-class display names; defaults to the
            module-level CLASSES (backward compatible).
        num_classes: number of classes; defaults to the module-level
            n_classes (backward compatible).

    Returns:
        float: test-set accuracy in [0, 1].
    """
    if class_names is None:
        class_names = CLASSES
    if num_classes is None:
        num_classes = n_classes

    X_tr, y_tr = X[perm[:n_train]], y[perm[:n_train]]
    X_te, y_te = X[perm[n_train:]], y[perm[n_train:]]

    # Standardize using train statistics only (no test leakage).
    m = X_tr.mean(0)
    s = X_tr.std(0).clamp(min=1e-8)
    X_tr_n = (X_tr - m) / s
    X_te_n = (X_te - m) / s

    # Closed-form ridge onto one-hot targets: (X'X + lam*I) W = X'Y.
    Y_oh = torch.zeros(len(y_tr), num_classes)
    Y_oh.scatter_(1, y_tr.unsqueeze(1), 1.0)
    W = torch.linalg.solve(
        X_tr_n.T @ X_tr_n + lam * torch.eye(X_tr_n.shape[1]), X_tr_n.T @ Y_oh)
    train_acc = ((X_tr_n @ W).argmax(1) == y_tr).float().mean().item()
    test_acc = ((X_te_n @ W).argmax(1) == y_te).float().mean().item()
    print(f" {name:<40s} dims={X.shape[1]:>5d} "
          f"train={train_acc:.1%} test={test_acc:.1%}")

    # Per-class accuracy table. The header is emitted before the first
    # populated class; the original printed it inside the `c == 0`
    # branch, so it silently vanished whenever class 0 was absent from
    # the test split.
    preds = (X_te_n @ W).argmax(1)
    header_printed = False
    for c in range(num_classes):
        cm = y_te == c
        if cm.sum() > 0:
            acc = (preds[cm] == y_te[cm]).float().mean().item()
            if not header_printed:
                print(f" {'Class':<10s} {'Acc':>6s}")
                print(f" {'-' * 18}")
                header_printed = True
            bar = 'β' * int(acc * 20)
            print(f" {class_names[c]:<10s} {acc:5.1%} {bar}")
    return test_acc
# Probe the raw release-error maps alone; acc_err feeds the final summary.
print(f"\n Ridge probe comparison:\n")
acc_err = ridge_probe(X_err, y_err, perm, n_train, "Recon error spatial map")
# ───────────────────────────────────────────────────────────────
# 5. COMBINED: RELEASE + EIGENVALUES + FRICTION
# ───────────────────────────────────────────────────────────────

print(f"\n{'=' * 70}")
print(" 5. FULL CONDUIT β Release error + eigenvalues + friction")
print("=" * 70)

# Rebuild combined features with release error included.
from geolip_core.linalg.conduit import FLEighConduit
conduit = FLEighConduit().to(device)

combined_features = []
combined_labels = []
n_collected2 = 0

for images, labels in tqdm(loader, desc="Full conduit"):
    if n_collected2 >= max_collect:
        break
    with torch.no_grad():
        images_gpu = images.to(device)
        out = freckles(images_gpu)
        recon = out['recon']
        S = out['svd']['S']
        Vt = out['svd']['Vt']
        B_img, N, D = S.shape

        # Per-patch recon error (same definition as section 1).
        input_p, _, _ = extract_patches(images_gpu, ps)
        recon_p, _, _ = extract_patches(recon, ps)
        patch_mse = (input_p - recon_p).pow(2).mean(dim=-1)  # (B, N)

        # Friction from conduit: build the per-patch Gram matrix
        # G = V diag(S^2) V^T and feed it through the eigen-conduit.
        # NOTE(review): assumes packet.friction reshapes to (B*N, D) —
        # confirm against FLEighConduit's output contract.
        S2 = S.pow(2)
        G = torch.einsum('bnij,bnj,bnjk->bnik',
                         Vt.transpose(-2, -1), S2, Vt)
        G_flat = G.reshape(B_img * N, D, D)
        packet = conduit(G_flat)
        fric = packet.friction.reshape(B_img, N, D)

        # Combine per image: error (n_patches) + S (n_patches*D)
        # + friction (n_patches*D).
        err_flat = patch_mse.reshape(B_img, gh * gw)
        s_flat = S.reshape(B_img, gh * gw * D)
        f_flat = fric.reshape(B_img, gh * gw * D)

        for i in range(B_img):
            if n_collected2 >= max_collect:
                break
            feat = torch.cat([
                err_flat[i].cpu(),
                s_flat[i].cpu(),
                f_flat[i].cpu(),
            ])
            combined_features.append(feat)
            combined_labels.append(labels[i].item())
            n_collected2 += 1

X_full = torch.stack(combined_features)
y_full = torch.tensor(combined_labels)

perm2 = torch.randperm(len(y_full))
n_train2 = int(0.8 * len(y_full))

print(f"\n Comparative linear probes:\n")

# Feature-group slices. Widths derived from n_patches and D rather than
# the original hard-coded 256 / 256*4, so the slicing stays correct if
# the patch grid or mode count ever changes (identical at 16x16, D=4).
_w_err = n_patches
_w_s = n_patches * D
X_err_only = X_full[:, :_w_err]
X_s_only = X_full[:, _w_err:_w_err + _w_s]
X_f_only = X_full[:, _w_err + _w_s:]

acc_err2 = ridge_probe(X_err_only, y_full, perm2, n_train2,
                       "Release error only")
print()
acc_s2 = ridge_probe(X_s_only, y_full, perm2, n_train2,
                     "Eigenvalues (S) only")
print()
acc_f2 = ridge_probe(X_f_only, y_full, perm2, n_train2,
                     "Friction only")

# Combinations
print(f"\n Combinations:\n")
X_err_s = torch.cat([X_err_only, X_s_only], dim=-1)
acc_err_s = ridge_probe(X_err_s, y_full, perm2, n_train2,
                        "Release + Eigenvalues")

X_err_f = torch.cat([X_err_only, X_f_only], dim=-1)
acc_err_f = ridge_probe(X_err_f, y_full, perm2, n_train2,
                        "Release + Friction")

acc_all = ridge_probe(X_full, y_full, perm2, n_train2,
                      "Release + Eigenvalues + Friction")
# ───────────────────────────────────────────────────────────────
# 6. HIGH-ERROR PATCH ANALYSIS — Where does the model fail?
# ───────────────────────────────────────────────────────────────

print(f"\n{'=' * 70}")
print(" 6. HIGH-ERROR PATCHES β Where does reconstruction fail?")
print("=" * 70)

# For each class, report the three grid cells with the largest mean
# reconstruction error and how far the worst sits above the class mean.
print(f"\n Top error positions per class (patch coordinates):")
print(f" {'Class':<10s} {'Top 3 positions (row, col)':>40s} {'Error ratio':>12s}")
print(f" {'-' * 64}")

for c in range(10):
    cm = class_means[c]  # (gh, gw)
    flat = cm.reshape(-1)
    top3 = flat.argsort(descending=True)[:3]
    coords = [divmod(int(idx), gw) for idx in top3]
    peak_err = flat[top3[0]].item()
    mean_err = cm.mean().item()
    ratio = peak_err / (mean_err + 1e-10)
    pos_str = ", ".join(f"({row},{col})" for row, col in coords)
    print(f" {CLASSES[c]:<10s} {pos_str:>40s} {ratio:12.2f}x")

# Pooled error across all classes; cells above mean + 2*std are "hot".
overall_error = class_error_sum.sum(0) / class_counts.sum()
hot_threshold = overall_error.mean() + 2 * overall_error.std()
hot_patches = (overall_error > hot_threshold).sum().item()
print(f"\n Overall error map:")
print(f" Mean: {overall_error.mean():.6f}")
print(f" Std: {overall_error.std():.6f}")
print(f" Hot patches (>2Ο): {hot_patches}/{gh * gw}")
# ───────────────────────────────────────────────────────────────
# SUMMARY
# ───────────────────────────────────────────────────────────────

print(f"\n{'=' * 70}")
print(" THEOREM 3: RELEASE FIDELITY β SUMMARY")
print("=" * 70)

# Single f-string report collecting the headline numbers from above.
# NOTE(review): the baselines (friction CV 0.0137, friction 24.3%,
# eigenvalue 21.0%) are hard-coded from Cell 3 — update by hand if
# Cell 3 is rerun.
print(f"""
SPATIAL STRUCTURE:
Recon error spatial CV: {cv_arr.mean():.4f}
(Friction spatial CV was: 0.0137)

CLASSIFICATION (ridge probe, test accuracy):
Chance: 10.0%
Friction maps: 24.3% (from Cell 3)
Eigenvalue (S) maps: 21.0% (from Cell 3)
Release error maps: {acc_err:.1%}
Release + Eigenvalues: {acc_err_s:.1%}
Release + Friction: {acc_err_f:.1%}
FULL CONDUIT (all three): {acc_all:.1%}

THE QUESTION ANSWERED:
Does the release signal carry class-discriminative information
that eigenvalues and friction do not?
Lift from release over eigenvalues: {(acc_err2 - acc_s2) * 100:+.1f}pp
Lift from full conduit over eigenvalues: {(acc_all - acc_s2) * 100:+.1f}pp
""")