""" Cell 4 — Theorem 3: Release Fidelity ======================================= The light speeding back up after leaving the lens. Full encode→SVD→decode round-trip reconstruction analysis. NOT the SVD-only residual (which is ~1e-12). The FULL decoder reconstruction — where the model chooses what to preserve and what to lose. Questions: 1. Does per-patch reconstruction error vary spatially? 2. Does it differ across classes? 3. Per-mode reconstruction: which modes matter for which patches? 4. Does the release residual map classify better than friction? 5. Combined release + friction + eigenvalues — full conduit test 6. Where does the model FAIL to reconstruct? Those are the boundaries. """ import torch import torch.nn.functional as F import numpy as np from tqdm import tqdm device = torch.device('cuda') # ═══════════════════════════════════════════════════════════════ # LOAD # ═══════════════════════════════════════════════════════════════ print("Loading Freckles v40 + CIFAR-10...") from geolip_svae import load_model from geolip_svae.model import extract_patches, stitch_patches import torchvision import torchvision.transforms as T freckles, cfg = load_model(hf_version='v40_freckles_noise', device=device) freckles.eval() ps = freckles.patch_size # 4 transform = T.Compose([T.Resize(64), T.ToTensor()]) cifar_test = torchvision.datasets.CIFAR10( root='/content/data', train=False, download=True, transform=transform) loader = torch.utils.data.DataLoader( cifar_test, batch_size=64, shuffle=False, num_workers=4) CLASSES = ['airplane', 'auto', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'] gh, gw = 64 // ps, 64 // ps # 16, 16 n_patches = gh * gw # 256 # ═══════════════════════════════════════════════════════════════ # 1. FULL ROUND-TRIP RECONSTRUCTION — Per-patch error maps # ═══════════════════════════════════════════════════════════════ print("\n" + "=" * 70) print(" 1. FULL ROUND-TRIP — Per-patch reconstruction error") print("=" * 70) print("\nCollecting per-patch reconstruction errors...\n") # Per-class spatial error maps class_error_sum = torch.zeros(10, gh, gw) class_error_sq = torch.zeros(10, gh, gw) class_counts = torch.zeros(10) # Individual maps for probing all_error_maps = [] # (error_map, label) all_s_maps = [] # (S_map, label) max_collect = 2000 n_collected = 0 for images, labels in tqdm(loader, desc="Reconstructing"): with torch.no_grad(): images_gpu = images.to(device) out = freckles(images_gpu) recon = out['recon'] B = images_gpu.shape[0] S = out['svd']['S'] # (B, N, D) # Per-patch error: split input and recon into patches, compare # Input patches: (B, N, C*ps*ps) input_patches, _, _ = extract_patches(images_gpu, ps) recon_patches, _, _ = extract_patches(recon, ps) # Per-patch MSE: (B, N) patch_mse = (input_patches - recon_patches).pow(2).mean(dim=-1) # Reshape to spatial: (B, gh, gw) error_map = patch_mse.reshape(B, gh, gw) s_map = S.reshape(B, gh, gw, -1) error_cpu = error_map.cpu() s_cpu = s_map.cpu() for i in range(B): c = labels[i].item() class_error_sum[c] += error_cpu[i] class_error_sq[c] += error_cpu[i].pow(2) class_counts[c] += 1 if n_collected < max_collect: all_error_maps.append((error_cpu[i], c)) all_s_maps.append((s_cpu[i], c)) n_collected += 1 print(f"Collected {int(class_counts.sum().item())} images, " f"{n_collected} individual maps\n") # ═══════════════════════════════════════════════════════════════ # 1a. SPATIAL STRUCTURE OF RECONSTRUCTION ERROR # ═══════════════════════════════════════════════════════════════ print("=" * 70) print(" 1a. SPATIAL STRUCTURE — Does recon error vary across patches?") print("=" * 70) per_image_cv = [] for error_map, label in all_error_maps: flat = error_map.reshape(-1) cv = flat.std() / (flat.mean() + 1e-10) per_image_cv.append(cv.item()) cv_arr = np.array(per_image_cv) print(f"\n Per-image spatial CV of reconstruction error:") print(f" Mean CV: {cv_arr.mean():.4f}") print(f" Median CV: {np.median(cv_arr):.4f}") print(f" Min CV: {cv_arr.min():.4f}") print(f" Max CV: {cv_arr.max():.4f}") print(f" VERDICT: {'HAS SPATIAL STRUCTURE' if cv_arr.mean() > 0.1 else 'SPATIALLY UNIFORM'}") # ═══════════════════════════════════════════════════════════════ # 1b. PER-CLASS RECONSTRUCTION ERROR # ═══════════════════════════════════════════════════════════════ print(f"\n{'=' * 70}") print(" 1b. PER-CLASS RECONSTRUCTION ERROR") print("=" * 70) class_means = class_error_sum / class_counts[:, None, None].clamp(min=1) class_vars = class_error_sq / class_counts[:, None, None].clamp(min=1) - class_means.pow(2) print(f"\n {'Class':<10s} {'Mean MSE':>10s} {'Std MSE':>10s} {'Max patch':>10s}") print(f" {'-' * 42}") for c in range(10): m = class_means[c] print(f" {CLASSES[c]:<10s} {m.mean():10.6f} {m.std():10.6f} {m.max():10.6f}") # Inter-class distance class_flat = class_means.reshape(10, -1) class_flat_norm = F.normalize(class_flat, dim=-1) cos_sim = class_flat_norm @ class_flat_norm.T inter_mask = ~torch.eye(10, dtype=torch.bool) print(f"\n Mean inter-class cosine similarity: {cos_sim[inter_mask].mean():.6f}") print(f" Min inter-class cosine similarity: {cos_sim[inter_mask].min():.6f}") print(f" VERDICT: {'DISTINCT PATTERNS' if cos_sim[inter_mask].min() < 0.99 else 'SIMILAR PATTERNS'}") # ═══════════════════════════════════════════════════════════════ # 2. CENTER vs EDGE RECONSTRUCTION # ═══════════════════════════════════════════════════════════════ print(f"\n{'=' * 70}") print(" 2. CENTER vs EDGE — Where does reconstruction fail?") print("=" * 70) center_mask = torch.zeros(gh, gw, dtype=torch.bool) center_mask[4:12, 4:12] = True edge_mask = ~center_mask # Corner masks for finer granularity corner_mask = torch.zeros(gh, gw, dtype=torch.bool) corner_mask[:4, :4] = True corner_mask[:4, 12:] = True corner_mask[12:, :4] = True corner_mask[12:, 12:] = True print(f"\n {'Class':<10s} {'Center':>8s} {'Edge':>8s} {'Corner':>8s} {'E/C ratio':>10s}") print(f" {'-' * 48}") for c in range(10): m = class_means[c] center = m[center_mask].mean().item() edge = m[edge_mask].mean().item() corner = m[corner_mask].mean().item() ratio = edge / (center + 1e-10) print(f" {CLASSES[c]:<10s} {center:8.6f} {edge:8.6f} {corner:8.6f} {ratio:10.4f}") # ═══════════════════════════════════════════════════════════════ # 3. PER-MODE RECONSTRUCTION — Which modes carry class signal? # ═══════════════════════════════════════════════════════════════ print(f"\n{'=' * 70}") print(" 3. PER-MODE RECONSTRUCTION — Ablating SVD modes") print("=" * 70) print("\nReconstructing with individual modes...") # For a subset, reconstruct using only mode k n_ablate = 256 subset = torch.utils.data.Subset(cifar_test, range(n_ablate)) ablate_loader = torch.utils.data.DataLoader(subset, batch_size=64) mode_errors = {k: [] for k in range(4)} mode_labels = [] full_errors = [] for images, labels in ablate_loader: with torch.no_grad(): images_gpu = images.to(device) out = freckles(images_gpu) S = out['svd']['S'] # (B, N, D) U = out['svd']['U'] # (B, N, V, D) Vt = out['svd']['Vt'] # (B, N, D, D) B_img, N, D = S.shape # Full reconstruction error per patch recon = out['recon'] input_p, _, _ = extract_patches(images_gpu, ps) recon_p, _, _ = extract_patches(recon, ps) full_err = (input_p - recon_p).pow(2).mean(dim=-1) # (B, N) full_errors.append(full_err.cpu()) # Per-mode ablation: reconstruct using only mode k for k in range(D): # Zero out all modes except k S_ablated = torch.zeros_like(S) S_ablated[:, :, k] = S[:, :, k] # Reconstruct: decoded_patches = U @ diag(S) @ Vt decoded = torch.einsum('bnvd,bnd,bndk->bnvk', U, S_ablated, Vt) # decoded: (B, N, V, D) but we need (B, N, V*D) = (B, N, patch_dim) # Actually the SVAE decoder is more complex — it uses cross-attention. # For a clean per-mode test, compare S_ablated contribution to full S. # Mode k's contribution to the enc_out matrix M: # M_k = U[:,:,:,k] * S[:,:,k] @ Vt[:,:,k,:] # Fraction of total energy in mode k: mode_energy = S[:, :, k].pow(2) / (S.pow(2).sum(dim=-1) + 1e-10) mode_errors[k].append(mode_energy.cpu()) mode_labels.append(labels) mode_labels = torch.cat(mode_labels) full_errors = torch.cat(full_errors) # (N_img, N_patches) print(f"\n Per-mode energy fraction (how much each mode contributes):") print(f"\n {'Class':<10s}", end="") for k in range(4): print(f" {'Mode'+str(k):>8s}", end="") print(f" {'FullMSE':>10s}") print(f" {'-' * 50}") for c in range(10): mask = mode_labels == c if mask.sum() == 0: continue print(f" {CLASSES[c]:<10s}", end="") for k in range(4): me = torch.cat(mode_errors[k]) energy = me[mask].mean().item() print(f" {energy:8.4f}", end="") ferr = full_errors[mask].mean().item() print(f" {ferr:10.6f}") # ═══════════════════════════════════════════════════════════════ # 4. RECONSTRUCTION ERROR AS CLASSIFIER # ═══════════════════════════════════════════════════════════════ print(f"\n{'=' * 70}") print(" 4. LINEAR PROBE — Reconstruction error maps as features") print("=" * 70) # Flatten per-patch error map as feature error_features = [] error_labels = [] for error_map, label in all_error_maps: error_features.append(error_map.reshape(-1)) # (256,) error_labels.append(label) X_err = torch.stack(error_features) # (N, 256) y_err = torch.tensor(error_labels) N = len(y_err) perm = torch.randperm(N) n_train = int(0.8 * N) n_classes = 10 def ridge_probe(X, y, perm, n_train, name, lam=1.0): X_tr, y_tr = X[perm[:n_train]], y[perm[:n_train]] X_te, y_te = X[perm[n_train:]], y[perm[n_train:]] m = X_tr.mean(0) s = X_tr.std(0).clamp(min=1e-8) X_tr_n = (X_tr - m) / s X_te_n = (X_te - m) / s Y_oh = torch.zeros(len(y_tr), n_classes) Y_oh.scatter_(1, y_tr.unsqueeze(1), 1.0) W = torch.linalg.solve( X_tr_n.T @ X_tr_n + lam * torch.eye(X_tr_n.shape[1]), X_tr_n.T @ Y_oh) train_acc = ((X_tr_n @ W).argmax(1) == y_tr).float().mean().item() test_acc = ((X_te_n @ W).argmax(1) == y_te).float().mean().item() print(f" {name:<40s} dims={X.shape[1]:>5d} " f"train={train_acc:.1%} test={test_acc:.1%}") # Per-class preds = (X_te_n @ W).argmax(1) for c in range(n_classes): cm = y_te == c if cm.sum() > 0: acc = (preds[cm] == y_te[cm]).float().mean().item() if c == 0: print(f" {'Class':<10s} {'Acc':>6s}") print(f" {'-' * 18}") bar = '█' * int(acc * 20) print(f" {CLASSES[c]:<10s} {acc:5.1%} {bar}") return test_acc print(f"\n Ridge probe comparison:\n") acc_err = ridge_probe(X_err, y_err, perm, n_train, "Recon error spatial map") # ═══════════════════════════════════════════════════════════════ # 5. COMBINED: RELEASE + EIGENVALUES + FRICTION # ═══════════════════════════════════════════════════════════════ print(f"\n{'=' * 70}") print(" 5. FULL CONDUIT — Release error + eigenvalues + friction") print("=" * 70) # Rebuild combined features with release error included from geolip_core.linalg.conduit import FLEighConduit conduit = FLEighConduit().to(device) combined_features = [] combined_labels = [] n_collected2 = 0 for images, labels in tqdm(loader, desc="Full conduit"): if n_collected2 >= max_collect: break with torch.no_grad(): images_gpu = images.to(device) out = freckles(images_gpu) recon = out['recon'] S = out['svd']['S'] Vt = out['svd']['Vt'] B_img, N, D = S.shape # Per-patch recon error input_p, _, _ = extract_patches(images_gpu, ps) recon_p, _, _ = extract_patches(recon, ps) patch_mse = (input_p - recon_p).pow(2).mean(dim=-1) # (B, N) # Friction from conduit S2 = S.pow(2) G = torch.einsum('bnij,bnj,bnjk->bnik', Vt.transpose(-2, -1), S2, Vt) G_flat = G.reshape(B_img * N, D, D) packet = conduit(G_flat) fric = packet.friction.reshape(B_img, N, D) # Combine: error_map(256) + S_map(256×4) + friction_map(256×4) err_flat = patch_mse.reshape(B_img, gh * gw) s_flat = S.reshape(B_img, gh * gw * D) f_flat = fric.reshape(B_img, gh * gw * D) for i in range(B_img): if n_collected2 >= max_collect: break feat = torch.cat([ err_flat[i].cpu(), s_flat[i].cpu(), f_flat[i].cpu(), ]) combined_features.append(feat) combined_labels.append(labels[i].item()) n_collected2 += 1 X_full = torch.stack(combined_features) y_full = torch.tensor(combined_labels) perm2 = torch.randperm(len(y_full)) n_train2 = int(0.8 * len(y_full)) print(f"\n Comparative linear probes:\n") # Individual features X_err_only = X_full[:, :256] X_s_only = X_full[:, 256:256 + 256 * 4] X_f_only = X_full[:, 256 + 256 * 4:] acc_err2 = ridge_probe(X_err_only, y_full, perm2, n_train2, "Release error only") print() acc_s2 = ridge_probe(X_s_only, y_full, perm2, n_train2, "Eigenvalues (S) only") print() acc_f2 = ridge_probe(X_f_only, y_full, perm2, n_train2, "Friction only") # Combinations print(f"\n Combinations:\n") X_err_s = torch.cat([X_err_only, X_s_only], dim=-1) acc_err_s = ridge_probe(X_err_s, y_full, perm2, n_train2, "Release + Eigenvalues") X_err_f = torch.cat([X_err_only, X_f_only], dim=-1) acc_err_f = ridge_probe(X_err_f, y_full, perm2, n_train2, "Release + Friction") acc_all = ridge_probe(X_full, y_full, perm2, n_train2, "Release + Eigenvalues + Friction") # ═══════════════════════════════════════════════════════════════ # 6. HIGH-ERROR PATCH ANALYSIS — Where does the model fail? # ═══════════════════════════════════════════════════════════════ print(f"\n{'=' * 70}") print(" 6. HIGH-ERROR PATCHES — Where does reconstruction fail?") print("=" * 70) # For each class, find patches with highest error print(f"\n Top error positions per class (patch coordinates):") print(f" {'Class':<10s} {'Top 3 positions (row, col)':>40s} {'Error ratio':>12s}") print(f" {'-' * 64}") for c in range(10): cm = class_means[c] # (gh, gw) flat = cm.reshape(-1) top3 = flat.argsort(descending=True)[:3] positions = [(idx.item() // gw, idx.item() % gw) for idx in top3] errs = [flat[idx].item() for idx in top3] mean_err = cm.mean().item() ratio = errs[0] / (mean_err + 1e-10) pos_str = ", ".join(f"({r},{c_})" for r, c_ in positions) print(f" {CLASSES[c]:<10s} {pos_str:>40s} {ratio:12.2f}x") # Overall hot spots across all classes overall_error = class_error_sum.sum(0) / class_counts.sum() hot_threshold = overall_error.mean() + 2 * overall_error.std() hot_patches = (overall_error > hot_threshold).sum().item() print(f"\n Overall error map:") print(f" Mean: {overall_error.mean():.6f}") print(f" Std: {overall_error.std():.6f}") print(f" Hot patches (>2σ): {hot_patches}/{gh * gw}") # ═══════════════════════════════════════════════════════════════ # SUMMARY # ═══════════════════════════════════════════════════════════════ print(f"\n{'=' * 70}") print(" THEOREM 3: RELEASE FIDELITY — SUMMARY") print("=" * 70) print(f""" SPATIAL STRUCTURE: Recon error spatial CV: {cv_arr.mean():.4f} (Friction spatial CV was: 0.0137) CLASSIFICATION (ridge probe, test accuracy): Chance: 10.0% Friction maps: 24.3% (from Cell 3) Eigenvalue (S) maps: 21.0% (from Cell 3) Release error maps: {acc_err:.1%} Release + Eigenvalues: {acc_err_s:.1%} Release + Friction: {acc_err_f:.1%} FULL CONDUIT (all three): {acc_all:.1%} THE QUESTION ANSWERED: Does the release signal carry class-discriminative information that eigenvalues and friction do not? Lift from release over eigenvalues: {(acc_err2 - acc_s2) * 100:+.1f}pp Lift from full conduit over eigenvalues: {(acc_all - acc_s2) * 100:+.1f}pp """)