AbstractPhil commited on
Commit
a986628
Β·
verified Β·
1 Parent(s): 1d7327c

Create notebook_cell_3_theorem_2.py

Browse files
Files changed (1) hide show
  1. notebook_cell_3_theorem_2.py +525 -0
notebook_cell_3_theorem_2.py ADDED
@@ -0,0 +1,525 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Cell 3 β€” Spatial Friction Map Analysis
3
+ ========================================
4
+ The mean friction is uniform across classes (12.19 Β± 0.08).
5
+ But the SPATIAL PATTERN of friction within images might differ.
6
+
7
+ Questions:
8
+ 1. Do friction maps have spatial structure? (or uniform per image)
9
+ 2. Does the spatial pattern differ across classes?
10
+ 3. Do edge/boundary patches have higher friction than interior?
11
+ 4. Is per-patch friction discriminative even if per-class mean is not?
12
+ 5. What does the friction map look like for individual images?
13
+ """
14
+
15
+ import torch
16
+ import torch.nn.functional as F
17
+ import numpy as np
18
+ from tqdm import tqdm
19
+
20
+ from geolip_core.linalg.conduit import FLEighConduit
21
+
22
+ device = torch.device('cuda')
23
+
24
+ # ═══════════════════════════════════════════════════════════════
25
+ # LOAD DATA
26
+ # ═══════════════════════════════════════════════════════════════
27
+
28
+ print("Loading Freckles v40 + CIFAR-10...")
29
+ from geolip_svae import load_model
30
+ import torchvision
31
+ import torchvision.transforms as T
32
+
33
+ freckles, cfg = load_model(hf_version='v40_freckles_noise', device=device)
34
+ freckles.eval()
35
+
36
+ transform = T.Compose([T.Resize(64), T.ToTensor()])
37
+ cifar_test = torchvision.datasets.CIFAR10(
38
+ root='/content/data', train=False, download=True, transform=transform)
39
+ loader = torch.utils.data.DataLoader(
40
+ cifar_test, batch_size=64, shuffle=False, num_workers=4)
41
+
42
+ CLASSES = ['airplane', 'auto', 'bird', 'cat', 'deer',
43
+ 'dog', 'frog', 'horse', 'ship', 'truck']
44
+
45
+ conduit = FLEighConduit().to(device)
46
+ gh, gw = 16, 16 # patch grid
47
+
48
+
49
+ # ═══════════════════════════════════════════════════════════════
50
+ # COLLECT SPATIAL FRICTION MAPS
51
+ # ═══════════════════════════════════════════════════════════════
52
+
53
+ print("Collecting spatial friction maps (full test set)...\n")
54
+
55
+ # Per-class friction maps: (10, gh, gw, D=4)
56
+ class_friction_sum = torch.zeros(10, gh, gw, 4)
57
+ class_friction_sq = torch.zeros(10, gh, gw, 4)
58
+ class_settle_sum = torch.zeros(10, gh, gw, 4)
59
+ class_counts = torch.zeros(10)
60
+
61
+ # Also collect per-image statistics for discriminability analysis
62
+ all_friction_maps = [] # list of (friction_map, label)
63
+ all_settle_maps = []
64
+
65
+ n_images_collected = 0
66
+ max_collect = 2000 # collect individual maps for first 2000 images
67
+
68
+ for images, labels in tqdm(loader, desc="Processing"):
69
+ with torch.no_grad():
70
+ out = freckles(images.to(device))
71
+ S = out['svd']['S'] # (B, N, D)
72
+ Vt = out['svd']['Vt'] # (B, N, D, D)
73
+ B_img, N, D = S.shape
74
+
75
+ # Build Gram matrices
76
+ S2 = S.pow(2)
77
+ G = torch.einsum('bnij,bnj,bnjk->bnik',
78
+ Vt.transpose(-2, -1), S2, Vt)
79
+ G_flat = G.reshape(B_img * N, D, D)
80
+
81
+ packet = conduit(G_flat)
82
+
83
+ # Reshape to spatial: (B, gh, gw, D)
84
+ fric_map = packet.friction.reshape(B_img, gh, gw, D)
85
+ sett_map = packet.settle.reshape(B_img, gh, gw, D)
86
+
87
+ fric_cpu = fric_map.cpu()
88
+ sett_cpu = sett_map.cpu()
89
+
90
+ for i in range(B_img):
91
+ c = labels[i].item()
92
+ class_friction_sum[c] += fric_cpu[i]
93
+ class_friction_sq[c] += fric_cpu[i].pow(2)
94
+ class_settle_sum[c] += sett_cpu[i]
95
+ class_counts[c] += 1
96
+
97
+ if n_images_collected < max_collect:
98
+ all_friction_maps.append((fric_cpu[i], c))
99
+ all_settle_maps.append((sett_cpu[i], c))
100
+ n_images_collected += 1
101
+
102
+ print(f"\nCollected {int(class_counts.sum().item())} images, "
103
+ f"{n_images_collected} individual maps\n")
104
+
105
+
106
+ # ═══════════════════════════════════════════════════════════════
107
+ # 1. SPATIAL STRUCTURE WITHIN IMAGES
108
+ # ═══════════════════════════════════════════════════════════════
109
+
110
+ print("=" * 70)
111
+ print(" 1. SPATIAL STRUCTURE β€” Do friction maps have spatial variance?")
112
+ print("=" * 70)
113
+
114
+ # Per-image spatial variance: does friction vary across patches within ONE image?
115
+ per_image_spatial_var = []
116
+ for fric_map, label in all_friction_maps:
117
+ # fric_map: (gh, gw, 4)
118
+ # Spatial variance: how much does friction vary across the 16x16 grid?
119
+ per_mode_var = fric_map.reshape(-1, 4).var(dim=0) # var across 256 patches
120
+ per_image_spatial_var.append((per_mode_var, label))
121
+
122
+ spatial_vars = torch.stack([v for v, _ in per_image_spatial_var]) # (N, 4)
123
+
124
+ print(f"\n Per-image spatial friction variance (across 256 patches):")
125
+ print(f" Mode 0 (Sβ‚€): mean={spatial_vars[:, 0].mean():.4f} std={spatial_vars[:, 0].std():.4f}")
126
+ print(f" Mode 1 (S₁): mean={spatial_vars[:, 1].mean():.4f} std={spatial_vars[:, 1].std():.4f}")
127
+ print(f" Mode 2 (Sβ‚‚): mean={spatial_vars[:, 2].mean():.4f} std={spatial_vars[:, 2].std():.4f}")
128
+ print(f" Mode 3 (S₃): mean={spatial_vars[:, 3].mean():.4f} std={spatial_vars[:, 3].std():.4f}")
129
+
130
+ # Coefficient of variation: spatial_std / spatial_mean per image
131
+ spatial_means = torch.stack([f.reshape(-1, 4).mean(0) for f, _ in all_friction_maps])
132
+ spatial_stds = torch.stack([f.reshape(-1, 4).std(0) for f, _ in all_friction_maps])
133
+ spatial_cv = spatial_stds / (spatial_means + 1e-8)
134
+
135
+ print(f"\n Per-image spatial CV (std/mean):")
136
+ for d in range(4):
137
+ print(f" Mode {d}: CV mean={spatial_cv[:, d].mean():.4f} "
138
+ f"median={spatial_cv[:, d].median():.4f} max={spatial_cv[:, d].max():.4f}")
139
+
140
+ has_spatial_structure = spatial_cv.mean() > 0.1
141
+ print(f"\n VERDICT: {'HAS SPATIAL STRUCTURE' if has_spatial_structure else 'SPATIALLY UNIFORM'} "
142
+ f"(mean CV = {spatial_cv.mean():.4f})")
143
+
144
+
145
+ # ═══════════════════════════════════════════════════════════════
146
+ # 2. PER-CLASS SPATIAL FRICTION PATTERNS
147
+ # ═══════════════════════════════════════════════════════════════
148
+
149
+ print(f"\n{'=' * 70}")
150
+ print(" 2. PER-CLASS SPATIAL PATTERNS β€” Do classes have different friction maps?")
151
+ print("=" * 70)
152
+
153
+ # Average friction map per class
154
+ class_means = class_friction_sum / class_counts[:, None, None, None].clamp(min=1)
155
+ class_vars = class_friction_sq / class_counts[:, None, None, None].clamp(min=1) - class_means.pow(2)
156
+
157
+ # Flatten spatial maps and compare between classes
158
+ class_flat = class_means.reshape(10, -1) # (10, gh*gw*4)
159
+
160
+ # Inter-class distance matrix
161
+ dists = torch.cdist(class_flat, class_flat)
162
+
163
+ print(f"\n Inter-class friction map L2 distances:")
164
+ print(f" {'':>10s}", end="")
165
+ for c in range(10):
166
+ print(f" {CLASSES[c][:5]:>6s}", end="")
167
+ print()
168
+ for c1 in range(10):
169
+ print(f" {CLASSES[c1][:10]:>10s}", end="")
170
+ for c2 in range(10):
171
+ print(f" {dists[c1, c2]:6.3f}", end="")
172
+ print()
173
+
174
+ # Mean inter-class vs intra-class distance
175
+ inter_mask = ~torch.eye(10, dtype=torch.bool)
176
+ inter_dist = dists[inter_mask].mean().item()
177
+ print(f"\n Mean inter-class distance: {inter_dist:.4f}")
178
+
179
+ # Cosine similarity between class friction maps
180
+ class_flat_norm = F.normalize(class_flat, dim=-1)
181
+ cos_sim = class_flat_norm @ class_flat_norm.T
182
+ cos_off_diag = cos_sim[inter_mask].mean().item()
183
+ cos_min = cos_sim[inter_mask].min().item()
184
+ print(f" Mean cosine similarity: {cos_off_diag:.6f}")
185
+ print(f" Min cosine similarity: {cos_min:.6f}")
186
+ print(f" VERDICT: {'DISTINCT PATTERNS' if cos_min < 0.99 else 'NEARLY IDENTICAL PATTERNS'}")
187
+
188
+
189
+ # ═══════════════════════════════════════════════════════════════
190
+ # 3. CENTER vs EDGE FRICTION
191
+ # ═══════════════════════════════════════════════════════════════
192
+
193
+ print(f"\n{'=' * 70}")
194
+ print(" 3. CENTER vs EDGE β€” Do boundary patches have higher friction?")
195
+ print("=" * 70)
196
+
197
+ # Define center and edge regions
198
+ center_mask = torch.zeros(gh, gw, dtype=torch.bool)
199
+ center_mask[4:12, 4:12] = True # center 8Γ—8
200
+ edge_mask = ~center_mask # border ring
201
+
202
+ for c in range(10):
203
+ fric_c = class_means[c] # (gh, gw, 4)
204
+ center_fric = fric_c[center_mask].mean().item()
205
+ edge_fric = fric_c[edge_mask].mean().item()
206
+ ratio = edge_fric / (center_fric + 1e-8)
207
+ if c == 0:
208
+ print(f"\n {'Class':<10s} {'Center':>8s} {'Edge':>8s} {'Edge/Center':>12s}")
209
+ print(f" {'-' * 40}")
210
+ print(f" {CLASSES[c]:<10s} {center_fric:8.3f} {edge_fric:8.3f} {ratio:12.4f}")
211
+
212
+
213
+ # ═══════════════════════════════════════════════════════════════
214
+ # 4. PER-PATCH-POSITION DISCRIMINABILITY
215
+ # ═══════════════════════════════════════════════════════════════
216
+
217
+ print(f"\n{'=' * 70}")
218
+ print(" 4. PER-PATCH-POSITION DISCRIMINABILITY")
219
+ print("=" * 70)
220
+
221
+ # For each patch position (i,j), is friction discriminative across classes?
222
+ # Use inter-class variance / intra-class variance ratio (F-statistic proxy)
223
+
224
+ position_f_stat = torch.zeros(gh, gw, 4)
225
+
226
+ for pi in range(gh):
227
+ for pj in range(gw):
228
+ for d in range(4):
229
+ # Class means at this position
230
+ c_means = class_means[:, pi, pj, d] # (10,)
231
+ # Inter-class variance
232
+ inter_var = c_means.var().item()
233
+ # Intra-class variance (averaged)
234
+ intra_var = class_vars[:, pi, pj, d].mean().item()
235
+ position_f_stat[pi, pj, d] = inter_var / (intra_var + 1e-10)
236
+
237
+ # Summary
238
+ print(f"\n F-statistic (inter-class var / intra-class var) per mode:")
239
+ for d in range(4):
240
+ fs = position_f_stat[:, :, d]
241
+ print(f" Mode {d}: mean={fs.mean():.6f} max={fs.max():.6f} "
242
+ f"top 5% threshold={fs.quantile(0.95):.6f}")
243
+
244
+ # Best discriminative positions
245
+ for d in range(4):
246
+ fs = position_f_stat[:, :, d]
247
+ best_idx = fs.argmax()
248
+ bi, bj = best_idx // gw, best_idx % gw
249
+ print(f" Mode {d} best position: ({bi.item()}, {bj.item()}) F={fs.max():.6f}")
250
+
251
+ overall_f = position_f_stat.mean(dim=-1) # avg across modes
252
+ print(f"\n Overall best discriminative patch position: "
253
+ f"{(overall_f.argmax() // gw).item()}, {(overall_f.argmax() % gw).item()} "
254
+ f"F={overall_f.max():.6f}")
255
+ print(f" Overall mean F-statistic: {overall_f.mean():.6f}")
256
+ print(f" VERDICT: {'POSITIONALLY DISCRIMINATIVE' if overall_f.max() > 0.01 else 'NOT DISCRIMINATIVE'}")
257
+
258
+
259
+ # ═══════════════════════════════════════════════════════════════
260
+ # 5. PER-MODE ANALYSIS β€” Which SVD mode carries most spatial variance?
261
+ # ═══════════════════════════════════════════════════════════════
262
+
263
+ print(f"\n{'=' * 70}")
264
+ print(" 5. PER-MODE SPATIAL VARIANCE β€” Which mode has the most structure?")
265
+ print("=" * 70)
266
+
267
+ for d in range(4):
268
+ # Spatial variance of mean friction map (across all images)
269
+ overall_mean_map = class_friction_sum.sum(0) / class_counts.sum() # (gh, gw, 4)
270
+ mode_map = overall_mean_map[:, :, d]
271
+ sv = mode_map.var().item()
272
+ sm = mode_map.mean().item()
273
+ print(f" Mode {d}: map_mean={sm:.4f} map_var={sv:.6f} map_cv={sv**0.5/(sm+1e-8):.4f}")
274
+
275
+
276
+ # ═══════════════════════════════════════════════════════════════
277
+ # 6. INDIVIDUAL IMAGE FRICTION MAPS
278
+ # ═══════════════════════════════════════════════════════════════
279
+
280
+ print(f"\n{'=' * 70}")
281
+ print(" 6. SAMPLE FRICTION MAPS β€” Individual images")
282
+ print("=" * 70)
283
+
284
+ # Show friction statistics for 2 images per class
285
+ for c in range(10):
286
+ maps_c = [(f, l) for f, l in all_friction_maps if l == c][:2]
287
+ for idx, (fric_map, _) in enumerate(maps_c):
288
+ # fric_map: (gh, gw, 4)
289
+ flat = fric_map.reshape(-1, 4)
290
+ fmean = flat.mean(0)
291
+ fstd = flat.std(0)
292
+ fmin = flat.min(0).values
293
+ fmax = flat.max(0).values
294
+
295
+ # Spatial entropy: how concentrated is the friction?
296
+ fric_total = flat.sum(dim=-1) # per-patch total friction
297
+ fric_prob = fric_total / (fric_total.sum() + 1e-8)
298
+ entropy = -(fric_prob * (fric_prob + 1e-10).log()).sum().item()
299
+ max_entropy = np.log(256) # uniform = max entropy
300
+
301
+ # Hot spots: patches with friction > 2Γ— mean
302
+ hot = (fric_total > 2 * fric_total.mean()).sum().item()
303
+
304
+ if idx == 0 and c == 0:
305
+ print(f"\n {'Class':<10s} {'Img':>3s} {'Mean':>8s} {'Std':>8s} "
306
+ f"{'Max':>8s} {'Entropy':>8s} {'HotSpots':>9s}")
307
+ print(f" {'-' * 55}")
308
+
309
+ print(f" {CLASSES[c]:<10s} {idx:3d} {fmean.mean():8.2f} {fstd.mean():8.2f} "
310
+ f"{fmax.max():8.2f} {entropy/max_entropy:8.3f} {hot:9d}")
311
+
312
+
313
+ # ═══════════════════════════════════════════════════════════════
314
+ # 7. FRICTION MAP AS CLASSIFIER β€” Linear probe on spatial friction
315
+ # ═══════════════════════════════════════════════════════════════
316
+
317
+ print(f"\n{'=' * 70}")
318
+ print(" 7. LINEAR PROBE β€” Can flattened friction maps classify?")
319
+ print("=" * 70)
320
+
321
+ # Collect features and labels
322
+ features = []
323
+ labels_all = []
324
+ for fric_map, label in all_friction_maps:
325
+ features.append(fric_map.reshape(-1)) # (gh*gw*4,) = 1024
326
+ labels_all.append(label)
327
+
328
+ X = torch.stack(features) # (N, 1024)
329
+ y = torch.tensor(labels_all) # (N,)
330
+
331
+ # Train/test split
332
+ N = len(y)
333
+ perm = torch.randperm(N)
334
+ n_train = int(0.8 * N)
335
+ X_train, y_train = X[perm[:n_train]], y[perm[:n_train]]
336
+ X_test, y_test = X[perm[n_train:]], y[perm[n_train:]]
337
+
338
+ # Standardize
339
+ mean = X_train.mean(0)
340
+ std = X_train.std(0).clamp(min=1e-6)
341
+ X_train_n = (X_train - mean) / std
342
+ X_test_n = (X_test - mean) / std
343
+
344
+ # Ridge regression (closed form, no training loop)
345
+ lam = 1.0
346
+ n_classes = 10
347
+ Y_onehot = torch.zeros(n_train, n_classes)
348
+ Y_onehot.scatter_(1, y_train.unsqueeze(1), 1.0)
349
+
350
+ XtX = X_train_n.T @ X_train_n + lam * torch.eye(X_train_n.shape[1])
351
+ XtY = X_train_n.T @ Y_onehot
352
+ W = torch.linalg.solve(XtX, XtY)
353
+
354
+ train_pred = (X_train_n @ W).argmax(1)
355
+ test_pred = (X_test_n @ W).argmax(1)
356
+ train_acc = (train_pred == y_train).float().mean().item()
357
+ test_acc = (test_pred == y_test).float().mean().item()
358
+
359
+ print(f"\n Features: flattened friction map ({X.shape[1]} dims)")
360
+ print(f" Train: {n_train}, Test: {N - n_train}")
361
+ print(f" Train accuracy: {train_acc:.1%}")
362
+ print(f" Test accuracy: {test_acc:.1%}")
363
+ print(f" Chance: 10.0%")
364
+
365
+ # Per-class accuracy
366
+ print(f"\n {'Class':<10s} {'Acc':>6s}")
367
+ print(f" {'-' * 18}")
368
+ for c in range(n_classes):
369
+ mask = y_test == c
370
+ if mask.sum() > 0:
371
+ acc = (test_pred[mask] == y_test[mask]).float().mean().item()
372
+ bar = 'β–ˆ' * int(acc * 20)
373
+ print(f" {CLASSES[c]:<10s} {acc:5.1%} {bar}")
374
+
375
+ print(f"\n VERDICT: {'DISCRIMINATIVE' if test_acc > 0.15 else 'NOT DISCRIMINATIVE'} "
376
+ f"spatial friction signal")
377
+
378
+
379
+ # ═══════════════════════════════════════════════════════════════
380
+ # 8. SETTLE MAP ANALYSIS β€” Same treatment for settle times
381
+ # ═══════════════════════════════════════════════════════════════
382
+
383
+ print(f"\n{'=' * 70}")
384
+ print(" 8. SETTLE MAP β€” Spatial convergence patterns")
385
+ print("=" * 70)
386
+
387
+ settle_features = []
388
+ settle_labels = []
389
+ for sett_map, label in all_settle_maps:
390
+ settle_features.append(sett_map.reshape(-1))
391
+ settle_labels.append(label)
392
+
393
+ X_s = torch.stack(settle_features)
394
+ y_s = torch.tensor(settle_labels)
395
+
396
+ perm_s = torch.randperm(len(y_s))
397
+ n_train_s = int(0.8 * len(y_s))
398
+ X_train_s, y_train_s = X_s[perm_s[:n_train_s]], y_s[perm_s[:n_train_s]]
399
+ X_test_s, y_test_s = X_s[perm_s[n_train_s:]], y_s[perm_s[n_train_s:]]
400
+
401
+ mean_s = X_train_s.mean(0)
402
+ std_s = X_train_s.std(0).clamp(min=1e-6)
403
+ X_train_sn = (X_train_s - mean_s) / std_s
404
+ X_test_sn = (X_test_s - mean_s) / std_s
405
+
406
+ Y_onehot_s = torch.zeros(n_train_s, n_classes)
407
+ Y_onehot_s.scatter_(1, y_train_s.unsqueeze(1), 1.0)
408
+ XtX_s = X_train_sn.T @ X_train_sn + lam * torch.eye(X_train_sn.shape[1])
409
+ XtY_s = X_train_sn.T @ Y_onehot_s
410
+ W_s = torch.linalg.solve(XtX_s, XtY_s)
411
+
412
+ test_pred_s = (X_test_sn @ W_s).argmax(1)
413
+ test_acc_s = (test_pred_s == y_test_s).float().mean().item()
414
+
415
+ print(f" Settle map linear probe:")
416
+ print(f" Test accuracy: {test_acc_s:.1%}")
417
+ print(f" VERDICT: {'DISCRIMINATIVE' if test_acc_s > 0.15 else 'NOT DISCRIMINATIVE'}")
418
+
419
+
420
+ # ═══════════════════════════════════════════════════════════════
421
+ # 9. COMBINED CONDUIT β€” friction + settle + eigenvalues
422
+ # ═══════════════════════════════════════════════════════════════
423
+
424
+ print(f"\n{'=' * 70}")
425
+ print(" 9. COMBINED CONDUIT β€” All evidence stacked")
426
+ print("=" * 70)
427
+
428
+ # Also test: raw eigenvalues (S values) as spatial maps for comparison
429
+ print("\n Collecting eigenvalue spatial maps...")
430
+ all_eval_maps = []
431
+ all_combined = []
432
+
433
+ for fric_map, label in all_friction_maps:
434
+ pass # Already collected
435
+
436
+ # Re-collect with eigenvalues
437
+ eval_features = []
438
+ combined_features = []
439
+ combined_labels = []
440
+
441
+ idx = 0
442
+ for images, labels_batch in loader:
443
+ if idx >= max_collect:
444
+ break
445
+ with torch.no_grad():
446
+ out = freckles(images.to(device))
447
+ S = out['svd']['S']
448
+ Vt = out['svd']['Vt']
449
+ B_img, N, D = S.shape
450
+
451
+ S2 = S.pow(2)
452
+ G = torch.einsum('bnij,bnj,bnjk->bnik',
453
+ Vt.transpose(-2, -1), S2, Vt)
454
+ G_flat = G.reshape(B_img * N, D, D)
455
+ packet = conduit(G_flat)
456
+
457
+ fric = packet.friction.reshape(B_img, gh, gw, D)
458
+ sett = packet.settle.reshape(B_img, gh, gw, D)
459
+ evals = S.reshape(B_img, gh, gw, D) # S values as spatial map
460
+
461
+ for i in range(B_img):
462
+ if idx >= max_collect:
463
+ break
464
+ # Eigenvalue spatial map
465
+ eval_features.append(evals[i].cpu().reshape(-1))
466
+ # Combined: friction + settle + eigenvalues
467
+ combined = torch.cat([
468
+ fric[i].cpu().reshape(-1),
469
+ sett[i].cpu().reshape(-1),
470
+ evals[i].cpu().reshape(-1),
471
+ ])
472
+ combined_features.append(combined)
473
+ combined_labels.append(labels_batch[i].item())
474
+ idx += 1
475
+
476
+ # Eigenvalue-only probe
477
+ X_e = torch.stack(eval_features)
478
+ y_e = torch.tensor(combined_labels)
479
+
480
+ perm_e = torch.randperm(len(y_e))
481
+ n_train_e = int(0.8 * len(y_e))
482
+
483
+ def ridge_probe(X, y, perm, n_train, name):
484
+ X_tr, y_tr = X[perm[:n_train]], y[perm[:n_train]]
485
+ X_te, y_te = X[perm[n_train:]], y[perm[n_train:]]
486
+ m = X_tr.mean(0)
487
+ s = X_tr.std(0).clamp(min=1e-6)
488
+ X_tr_n = (X_tr - m) / s
489
+ X_te_n = (X_te - m) / s
490
+ Y_oh = torch.zeros(n_train, n_classes)
491
+ Y_oh.scatter_(1, y_tr.unsqueeze(1), 1.0)
492
+ W = torch.linalg.solve(X_tr_n.T @ X_tr_n + torch.eye(X_tr_n.shape[1]), X_tr_n.T @ Y_oh)
493
+ acc = ((X_te_n @ W).argmax(1) == y_te).float().mean().item()
494
+ print(f" {name:<30s} dims={X.shape[1]:>5d} test_acc={acc:.1%}")
495
+ return acc
496
+
497
+ print(f"\n Linear probe comparison (all use same train/test split):\n")
498
+ acc_evals = ridge_probe(X_e, y_e, perm_e, n_train_e, "Eigenvalues (S) spatial")
499
+ acc_fric = ridge_probe(X, y, perm, n_train, "Friction spatial")
500
+ acc_sett = ridge_probe(X_s, y_s, perm_s, n_train_s, "Settle spatial")
501
+
502
+ X_c = torch.stack(combined_features)
503
+ acc_comb = ridge_probe(X_c, y_e, perm_e, n_train_e, "Combined (S+fric+settle)")
504
+
505
+ print(f"\n Chance: 10.0%")
506
+ print(f" VERDICT: Combined vs eigenvalues-only lift = "
507
+ f"{(acc_comb - acc_evals) * 100:+.1f} percentage points")
508
+
509
+
510
+ # ═══════════════════════════════════════════════════════════════
511
+ # SUMMARY
512
+ # ═══════════════════════════════════════════════════════════════
513
+
514
+ print(f"\n{'=' * 70}")
515
+ print(" SPATIAL FRICTION ANALYSIS β€” SUMMARY")
516
+ print("=" * 70)
517
+ print(f" 1. Spatial structure within images: CV = {spatial_cv.mean():.4f}")
518
+ print(f" 2. Inter-class pattern distance: cos_min = {cos_min:.6f}")
519
+ print(f" 3. Center vs edge asymmetry: (see table above)")
520
+ print(f" 4. Per-position F-statistic: max = {overall_f.max():.6f}")
521
+ print(f" 5. Friction map linear probe: {test_acc:.1%}")
522
+ print(f" 6. Settle map linear probe: {test_acc_s:.1%}")
523
+ print(f" 7. Eigenvalue map linear probe: {acc_evals:.1%}")
524
+ print(f" 8. Combined conduit linear probe: {acc_comb:.1%}")
525
+ print(f" 9. Conduit lift over eigenvalues: {(acc_comb - acc_evals)*100:+.1f}pp")