"""
Cell 3 — Spatial Friction Map Analysis
========================================
The mean friction is uniform across classes (12.19 ± 0.08).
But the SPATIAL PATTERN of friction within images might differ.

Questions:
1. Do friction maps have spatial structure? (or uniform per image)
2. Does the spatial pattern differ across classes?
3. Do edge/boundary patches have higher friction than interior?
4. Is per-patch friction discriminative even if per-class mean is not?
5. What does the friction map look like for individual images?
"""
|
|
import torch
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm


from geolip_core.linalg.conduit import FLEighConduit


# Fall back to CPU when CUDA is unavailable so the script still runs
# (the original hard-coded 'cuda' and crashed on CPU-only machines).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
| |
| |
| |
|
|
print("Loading Freckles v40 + CIFAR-10...")
from geolip_svae import load_model
import torchvision
import torchvision.transforms as T


# Load the pretrained model (project API) and freeze it for inference.
freckles, cfg = load_model(hf_version='v40_freckles_noise', device=device)
freckles.eval()


# CIFAR-10 test split, upscaled to 64x64 to match the model's input size.
# shuffle=False keeps a deterministic image order — later cells rely on this
# to align features collected in separate passes over the loader.
transform = T.Compose([T.Resize(64), T.ToTensor()])
cifar_test = torchvision.datasets.CIFAR10(
    root='/content/data', train=False, download=True, transform=transform)
loader = torch.utils.data.DataLoader(
    cifar_test, batch_size=64, shuffle=False, num_workers=4)


CLASSES = ['airplane', 'auto', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']


# Conduit producing friction/settle diagnostics from per-patch Gram matrices
# (project API).  gh x gw is the patch grid — assumes the model emits
# gh*gw = 256 patches per image; TODO confirm against the model config.
conduit = FLEighConduit().to(device)
gh, gw = 16, 16
|
|
|
|
| |
| |
| |
|
|
print("Collecting spatial friction maps (full test set)...\n")


# Per-class accumulators over the (gh, gw) patch grid and the 4 spectral
# modes: running sum, sum of squares (for variance later), settle sum,
# and per-class image counts.
class_friction_sum = torch.zeros(10, gh, gw, 4)
class_friction_sq = torch.zeros(10, gh, gw, 4)
class_settle_sum = torch.zeros(10, gh, gw, 4)
class_counts = torch.zeros(10)


# Individual (map, label) pairs retained for the per-image analyses below.
all_friction_maps = []
all_settle_maps = []


n_images_collected = 0
max_collect = 2000  # cap on how many individual maps are kept in memory


for images, labels in tqdm(loader, desc="Processing"):
    with torch.no_grad():
        out = freckles(images.to(device))
        # Per-patch SVD factors; assumed shapes: S is (B, N, D) and
        # Vt is (B, N, D, D) with N == gh * gw patches — TODO confirm.
        S = out['svd']['S']
        Vt = out['svd']['Vt']
        B_img, N, D = S.shape

        # Gram matrix G = Vt^T · diag(S²) · Vt for every patch.
        S2 = S.pow(2)
        G = torch.einsum('bnij,bnj,bnjk->bnik',
                         Vt.transpose(-2, -1), S2, Vt)
        G_flat = G.reshape(B_img * N, D, D)

        # Conduit emits per-patch friction/settle diagnostics (project API).
        packet = conduit(G_flat)

        # Fold the flat patch axis back onto the (gh, gw) grid.
        fric_map = packet.friction.reshape(B_img, gh, gw, D)
        sett_map = packet.settle.reshape(B_img, gh, gw, D)

        fric_cpu = fric_map.cpu()
        sett_cpu = sett_map.cpu()

    for i in range(B_img):
        c = labels[i].item()
        class_friction_sum[c] += fric_cpu[i]
        class_friction_sq[c] += fric_cpu[i].pow(2)
        class_settle_sum[c] += sett_cpu[i]
        class_counts[c] += 1

        # Keep the first max_collect individual maps; the loader is
        # unshuffled, so this is a deterministic prefix of the test set.
        if n_images_collected < max_collect:
            all_friction_maps.append((fric_cpu[i], c))
            all_settle_maps.append((sett_cpu[i], c))
            n_images_collected += 1


print(f"\nCollected {int(class_counts.sum().item())} images, "
      f"{n_images_collected} individual maps\n")
|
|
|
|
| |
| |
| |
|
|
print("=" * 70)
print(" 1. SPATIAL STRUCTURE — Do friction maps have spatial variance?")
print("=" * 70)


# Per-image variance of friction across the gh*gw patch positions,
# computed separately for each of the 4 spectral modes.
per_image_spatial_var = []
for fric_map, label in all_friction_maps:
    per_mode_var = fric_map.reshape(-1, 4).var(dim=0)
    per_image_spatial_var.append((per_mode_var, label))


spatial_vars = torch.stack([v for v, _ in per_image_spatial_var])


print(f"\n  Per-image spatial friction variance (across {gh * gw} patches):")
for d in range(4):
    # Subscripts restored from mojibake; loop replaces four duplicated prints.
    print(f"    Mode {d} (S{'₀₁₂₃'[d]}): mean={spatial_vars[:, d].mean():.4f} "
          f"std={spatial_vars[:, d].std():.4f}")


# Coefficient of variation (std/mean) — dimensionless spatial-spread measure.
spatial_means = torch.stack([f.reshape(-1, 4).mean(0) for f, _ in all_friction_maps])
spatial_stds = torch.stack([f.reshape(-1, 4).std(0) for f, _ in all_friction_maps])
spatial_cv = spatial_stds / (spatial_means + 1e-8)


print(f"\n  Per-image spatial CV (std/mean):")
for d in range(4):
    print(f"    Mode {d}: CV mean={spatial_cv[:, d].mean():.4f} "
          f"median={spatial_cv[:, d].median():.4f} max={spatial_cv[:, d].max():.4f}")


# bool() so the verdict is a plain Python bool, not a 0-dim tensor.
has_spatial_structure = bool(spatial_cv.mean() > 0.1)
print(f"\n  VERDICT: {'HAS SPATIAL STRUCTURE' if has_spatial_structure else 'SPATIALLY UNIFORM'} "
      f"(mean CV = {spatial_cv.mean():.4f})")
|
|
|
|
| |
| |
| |
|
|
print(f"\n{'=' * 70}")
print(" 2. PER-CLASS SPATIAL PATTERNS — Do classes have different friction maps?")
print("=" * 70)


# Per-class mean and variance of the friction maps.  clamp(min=0) guards
# against tiny negative variances from E[X²] − E[X]² floating-point error.
class_means = class_friction_sum / class_counts[:, None, None, None].clamp(min=1)
class_vars = (class_friction_sq / class_counts[:, None, None, None].clamp(min=1)
              - class_means.pow(2)).clamp(min=0)


# Flatten each class's (gh, gw, 4) mean map into one vector.
class_flat = class_means.reshape(10, -1)


# Pairwise L2 distances between class-mean friction maps.
dists = torch.cdist(class_flat, class_flat)


print(f"\n  Inter-class friction map L2 distances:")
print(f"  {'':>10s}", end="")
for c in range(10):
    print(f" {CLASSES[c][:5]:>6s}", end="")
print()
for c1 in range(10):
    print(f"  {CLASSES[c1][:10]:>10s}", end="")
    for c2 in range(10):
        print(f" {dists[c1, c2]:6.3f}", end="")
    print()


# Mean distance over off-diagonal (distinct-class) pairs.
inter_mask = ~torch.eye(10, dtype=torch.bool)
inter_dist = dists[inter_mask].mean().item()
print(f"\n  Mean inter-class distance: {inter_dist:.4f}")


# Cosine similarity between the normalized class maps; near-1.0 everywhere
# means the spatial pattern is essentially shared across classes.
class_flat_norm = F.normalize(class_flat, dim=-1)
cos_sim = class_flat_norm @ class_flat_norm.T
cos_off_diag = cos_sim[inter_mask].mean().item()
cos_min = cos_sim[inter_mask].min().item()
print(f"  Mean cosine similarity: {cos_off_diag:.6f}")
print(f"  Min cosine similarity:  {cos_min:.6f}")
print(f"  VERDICT: {'DISTINCT PATTERNS' if cos_min < 0.99 else 'NEARLY IDENTICAL PATTERNS'}")
|
|
|
|
| |
| |
| |
|
|
print(f"\n{'=' * 70}")
print(" 3. CENTER vs EDGE — Do boundary patches have higher friction?")
print("=" * 70)


# Central half of the grid (rows gh//4..3*gh//4, cols gw//4..3*gw//4) vs the
# border ring.  Generalized from the hard-coded 4:12 so it tracks gh, gw.
center_mask = torch.zeros(gh, gw, dtype=torch.bool)
center_mask[gh // 4:3 * gh // 4, gw // 4:3 * gw // 4] = True
edge_mask = ~center_mask


# Header hoisted out of the loop (the original printed it via an
# `if c == 0` check inside the loop — same output, clearer code).
print(f"\n  {'Class':<10s} {'Center':>8s} {'Edge':>8s} {'Edge/Center':>12s}")
print(f"  {'-' * 40}")
for c in range(10):
    fric_c = class_means[c]
    center_fric = fric_c[center_mask].mean().item()
    edge_fric = fric_c[edge_mask].mean().item()
    ratio = edge_fric / (center_fric + 1e-8)
    print(f"  {CLASSES[c]:<10s} {center_fric:8.3f} {edge_fric:8.3f} {ratio:12.4f}")
|
|
|
|
| |
| |
| |
|
|
print(f"\n{'=' * 70}")
print(" 4. PER-PATCH-POSITION DISCRIMINABILITY")
print("=" * 70)


# F-statistic per (row, col, mode): variance of the 10 class means divided by
# the mean intra-class variance.  Vectorized over the whole grid — the
# original triple Python loop computed the same values element-wise.
# .var(dim=0) uses the same unbiased estimator as the scalar .var() it replaces.
inter_var = class_means.var(dim=0)    # (gh, gw, 4)
intra_var = class_vars.mean(dim=0)    # (gh, gw, 4)
position_f_stat = inter_var / (intra_var + 1e-10)


print(f"\n  F-statistic (inter-class var / intra-class var) per mode:")
for d in range(4):
    fs = position_f_stat[:, :, d]
    print(f"    Mode {d}: mean={fs.mean():.6f} max={fs.max():.6f} "
          f"top 5% threshold={fs.quantile(0.95):.6f}")


# Most discriminative grid position per mode (flat argmax → row, col).
for d in range(4):
    fs = position_f_stat[:, :, d]
    best_idx = fs.argmax()
    bi, bj = best_idx // gw, best_idx % gw
    print(f"    Mode {d} best position: ({bi.item()}, {bj.item()}) F={fs.max():.6f}")


overall_f = position_f_stat.mean(dim=-1)
print(f"\n  Overall best discriminative patch position: "
      f"{(overall_f.argmax() // gw).item()}, {(overall_f.argmax() % gw).item()} "
      f"F={overall_f.max():.6f}")
print(f"  Overall mean F-statistic: {overall_f.mean():.6f}")
print(f"  VERDICT: {'POSITIONALLY DISCRIMINATIVE' if overall_f.max() > 0.01 else 'NOT DISCRIMINATIVE'}")
|
|
|
|
| |
| |
| |
|
|
print(f"\n{'=' * 70}")
print(" 5. PER-MODE SPATIAL VARIANCE — Which mode has the most structure?")
print("=" * 70)


# Grand-mean friction map over all classes; hoisted out of the loop below
# (it does not depend on the mode index, but was recomputed 4 times).
overall_mean_map = class_friction_sum.sum(0) / class_counts.sum()
for d in range(4):
    mode_map = overall_mean_map[:, :, d]
    sv = mode_map.var().item()
    sm = mode_map.mean().item()
    print(f"  Mode {d}: map_mean={sm:.4f} map_var={sv:.6f} map_cv={sv**0.5/(sm+1e-8):.4f}")
|
|
|
|
| |
| |
| |
|
|
print(f"\n{'=' * 70}")
print(" 6. SAMPLE FRICTION MAPS — Individual images")
print("=" * 70)


# Up to two sample images per class: summary stats of each friction map,
# plus the entropy of the friction mass over patches and a hot-spot count.
# Header hoisted out of the loop (was printed via `if idx == 0 and c == 0`).
print(f"\n  {'Class':<10s} {'Img':>3s} {'Mean':>8s} {'Std':>8s} "
      f"{'Max':>8s} {'Entropy':>8s} {'HotSpots':>9s}")
print(f"  {'-' * 55}")
# Entropy of a uniform distribution over all patches — generalized from the
# hard-coded np.log(256) (== gh * gw here).
max_entropy = np.log(gh * gw)
for c in range(10):
    maps_c = [(f, l) for f, l in all_friction_maps if l == c][:2]
    for idx, (fric_map, _) in enumerate(maps_c):
        flat = fric_map.reshape(-1, 4)
        fmean = flat.mean(0)
        fstd = flat.std(0)
        fmax = flat.max(0).values

        # Treat total friction per patch as a probability distribution and
        # measure how evenly it is spread (normalized entropy in [0, 1]).
        fric_total = flat.sum(dim=-1)
        fric_prob = fric_total / (fric_total.sum() + 1e-8)
        entropy = -(fric_prob * (fric_prob + 1e-10).log()).sum().item()

        # Patches carrying more than twice the mean total friction.
        hot = (fric_total > 2 * fric_total.mean()).sum().item()

        print(f"  {CLASSES[c]:<10s} {idx:3d} {fmean.mean():8.2f} {fstd.mean():8.2f} "
              f"{fmax.max():8.2f} {entropy/max_entropy:8.3f} {hot:9d}")
|
|
|
|
| |
| |
| |
|
|
print(f"\n{'=' * 70}")
print(" 7. LINEAR PROBE — Can flattened friction maps classify?")
print("=" * 70)


# Flatten each (gh, gw, 4) friction map into a single feature vector.
features = []
labels_all = []
for fric_map, label in all_friction_maps:
    features.append(fric_map.reshape(-1))
    labels_all.append(label)


X = torch.stack(features)
y = torch.tensor(labels_all)


# Random 80/20 train/test split.
N = len(y)
perm = torch.randperm(N)
n_train = int(0.8 * N)
X_train, y_train = X[perm[:n_train]], y[perm[:n_train]]
X_test, y_test = X[perm[n_train:]], y[perm[n_train:]]


# Standardize using train statistics only (no test leakage).
mean = X_train.mean(0)
std = X_train.std(0).clamp(min=1e-6)
X_train_n = (X_train - mean) / std
X_test_n = (X_test - mean) / std


# Closed-form ridge regression onto one-hot class targets.
lam = 1.0
n_classes = 10
Y_onehot = torch.zeros(n_train, n_classes)
Y_onehot.scatter_(1, y_train.unsqueeze(1), 1.0)


XtX = X_train_n.T @ X_train_n + lam * torch.eye(X_train_n.shape[1])
XtY = X_train_n.T @ Y_onehot
W = torch.linalg.solve(XtX, XtY)


train_pred = (X_train_n @ W).argmax(1)
test_pred = (X_test_n @ W).argmax(1)
train_acc = (train_pred == y_train).float().mean().item()
test_acc = (test_pred == y_test).float().mean().item()


print(f"\n  Features: flattened friction map ({X.shape[1]} dims)")
print(f"  Train: {n_train}, Test: {N - n_train}")
print(f"  Train accuracy: {train_acc:.1%}")
print(f"  Test accuracy:  {test_acc:.1%}")
print(f"  Chance: 10.0%")


# Per-class test accuracy with a simple bar chart.
print(f"\n  {'Class':<10s} {'Acc':>6s}")
print(f"  {'-' * 18}")
for c in range(n_classes):
    mask = y_test == c
    if mask.sum() > 0:
        acc = (test_pred[mask] == y_test[mask]).float().mean().item()
        bar = '█' * int(acc * 20)  # block char restored from mojibake
        print(f"  {CLASSES[c]:<10s} {acc:5.1%} {bar}")


print(f"\n  VERDICT: {'DISCRIMINATIVE' if test_acc > 0.15 else 'NOT DISCRIMINATIVE'} "
      f"spatial friction signal")
|
|
|
|
| |
| |
| |
|
|
print(f"\n{'=' * 70}")
print(" 8. SETTLE MAP — Spatial convergence patterns")
print("=" * 70)


# Same ridge-probe recipe as section 7, applied to flattened settle maps.
settle_features = []
settle_labels = []
for sett_map, label in all_settle_maps:
    settle_features.append(sett_map.reshape(-1))
    settle_labels.append(label)


X_s = torch.stack(settle_features)
y_s = torch.tensor(settle_labels)


perm_s = torch.randperm(len(y_s))
n_train_s = int(0.8 * len(y_s))
X_train_s, y_train_s = X_s[perm_s[:n_train_s]], y_s[perm_s[:n_train_s]]
X_test_s, y_test_s = X_s[perm_s[n_train_s:]], y_s[perm_s[n_train_s:]]


# Standardize using train statistics only.
mean_s = X_train_s.mean(0)
std_s = X_train_s.std(0).clamp(min=1e-6)
X_train_sn = (X_train_s - mean_s) / std_s
X_test_sn = (X_test_s - mean_s) / std_s


# Closed-form ridge regression onto one-hot targets (lam, n_classes from §7).
Y_onehot_s = torch.zeros(n_train_s, n_classes)
Y_onehot_s.scatter_(1, y_train_s.unsqueeze(1), 1.0)
XtX_s = X_train_sn.T @ X_train_sn + lam * torch.eye(X_train_sn.shape[1])
XtY_s = X_train_sn.T @ Y_onehot_s
W_s = torch.linalg.solve(XtX_s, XtY_s)


test_pred_s = (X_test_sn @ W_s).argmax(1)
test_acc_s = (test_pred_s == y_test_s).float().mean().item()


print(f"  Settle map linear probe:")
print(f"  Test accuracy: {test_acc_s:.1%}")
print(f"  VERDICT: {'DISCRIMINATIVE' if test_acc_s > 0.15 else 'NOT DISCRIMINATIVE'}")
|
|
|
|
| |
| |
| |
|
|
print(f"\n{'=' * 70}")
print(" 9. COMBINED CONDUIT — All evidence stacked")
print("=" * 70)


print("\n  Collecting eigenvalue spatial maps...")
# Second pass over the loader: collect per-patch singular-value maps plus the
# concatenated (friction, settle, S) feature for the same first max_collect
# images (shuffle=False keeps the order identical to the first pass).
# Dead code removed: the original also built never-used all_eval_maps /
# all_combined lists and ran an empty `for ...: pass` loop here.
eval_features = []
combined_features = []
combined_labels = []


idx = 0
for images, labels_batch in loader:
    if idx >= max_collect:
        break
    with torch.no_grad():
        out = freckles(images.to(device))
        S = out['svd']['S']
        Vt = out['svd']['Vt']
        B_img, N, D = S.shape

        # Same Gram-matrix construction as the first collection pass.
        S2 = S.pow(2)
        G = torch.einsum('bnij,bnj,bnjk->bnik',
                         Vt.transpose(-2, -1), S2, Vt)
        G_flat = G.reshape(B_img * N, D, D)
        packet = conduit(G_flat)

        fric = packet.friction.reshape(B_img, gh, gw, D)
        sett = packet.settle.reshape(B_img, gh, gw, D)
        # NOTE(review): S holds singular values; the eigenvalues of G are
        # S**2 — the "eigenvalue" label below follows the original's naming.
        evals = S.reshape(B_img, gh, gw, D)

    for i in range(B_img):
        if idx >= max_collect:
            break
        eval_features.append(evals[i].cpu().reshape(-1))
        combined = torch.cat([
            fric[i].cpu().reshape(-1),
            sett[i].cpu().reshape(-1),
            evals[i].cpu().reshape(-1),
        ])
        combined_features.append(combined)
        combined_labels.append(labels_batch[i].item())
        idx += 1


X_e = torch.stack(eval_features)
y_e = torch.tensor(combined_labels)


perm_e = torch.randperm(len(y_e))
n_train_e = int(0.8 * len(y_e))
|
def ridge_probe(X, y, perm, n_train, name, n_classes=10, lam=1.0):
    """Fit a closed-form ridge-regression probe and print/return test accuracy.

    Args:
        X: (N, d) float feature matrix.
        y: (N,) integer class labels in [0, n_classes).
        perm: permutation of range(N) defining the train/test split.
        n_train: number of training rows; the remainder is the test set.
        name: label printed alongside the result.
        n_classes: number of classes (default 10; previously read from a
            module-level global, which made the function non-self-contained).
        lam: ridge strength (default 1.0 — numerically identical to the
            original's bare `+ torch.eye(...)`, now consistent with §7).

    Returns:
        Test accuracy as a float in [0, 1].
    """
    X_tr, y_tr = X[perm[:n_train]], y[perm[:n_train]]
    X_te, y_te = X[perm[n_train:]], y[perm[n_train:]]
    # Standardize using train statistics only (no test leakage).
    m = X_tr.mean(0)
    s = X_tr.std(0).clamp(min=1e-6)
    X_tr_n = (X_tr - m) / s
    X_te_n = (X_te - m) / s
    # One-hot targets; solve the ridge normal equations.
    Y_oh = torch.zeros(n_train, n_classes)
    Y_oh.scatter_(1, y_tr.unsqueeze(1), 1.0)
    W = torch.linalg.solve(X_tr_n.T @ X_tr_n + lam * torch.eye(X_tr_n.shape[1]),
                           X_tr_n.T @ Y_oh)
    acc = ((X_te_n @ W).argmax(1) == y_te).float().mean().item()
    print(f"    {name:<30s} dims={X.shape[1]:>5d} test_acc={acc:.1%}")
    return acc
|
|
print(f"\n  Linear probe comparison (all use same train/test split):\n")
# Use ONE permutation (perm_e) for every probe so the comparison really is on
# the same train/test split — the original passed three independent
# permutations (perm, perm_s, perm_e) while claiming a shared split.  All
# feature sets cover the same first max_collect images in unshuffled loader
# order, so rows and labels line up across X, X_s, X_e and X_c.
acc_evals = ridge_probe(X_e, y_e, perm_e, n_train_e, "Eigenvalues (S) spatial")
acc_fric = ridge_probe(X, y, perm_e, n_train_e, "Friction spatial")
acc_sett = ridge_probe(X_s, y_s, perm_e, n_train_e, "Settle spatial")


X_c = torch.stack(combined_features)
acc_comb = ridge_probe(X_c, y_e, perm_e, n_train_e, "Combined (S+fric+settle)")


print(f"\n  Chance: 10.0%")
print(f"  VERDICT: Combined vs eigenvalues-only lift = "
      f"{(acc_comb - acc_evals) * 100:+.1f} percentage points")
|
|
|
|
| |
| |
| |
|
|
print(f"\n{'=' * 70}")
print(" SPATIAL FRICTION ANALYSIS — SUMMARY")
print("=" * 70)
# One line per numbered experiment above; values are the globals each
# section left behind.
print(f"  1. Spatial structure within images:  CV = {spatial_cv.mean():.4f}")
print(f"  2. Inter-class pattern distance:     cos_min = {cos_min:.6f}")
print(f"  3. Center vs edge asymmetry:         (see table above)")
print(f"  4. Per-position F-statistic:         max = {overall_f.max():.6f}")
print(f"  5. Friction map linear probe:        {test_acc:.1%}")
print(f"  6. Settle map linear probe:          {test_acc_s:.1%}")
print(f"  7. Eigenvalue map linear probe:      {acc_evals:.1%}")
print(f"  8. Combined conduit linear probe:    {acc_comb:.1%}")
print(f"  9. Conduit lift over eigenvalues:    {(acc_comb - acc_evals)*100:+.1f}pp")