AbstractPhil committed on
Commit
c674ac6
Β·
verified Β·
1 Parent(s): 31694be

Create deep_analysis.py

Browse files
Files changed (1) hide show
  1. deep_analysis.py +619 -0
deep_analysis.py ADDED
@@ -0,0 +1,619 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Flow Match Relay — Full Analysis Toolkit
4
+ ==========================================
5
+ Run after training. Analyzes:
6
+
7
+ 1. Relay diagnostics: drift, gates, anchor geometry
8
+ 2. CV measurement through the network at each layer
9
+ 3. Anchor utilization: which anchors are active per class?
10
+ 4. Generation quality: FID prep, per-class diversity
11
+ 5. The 0.29154 hunt: does drift converge to the binding constant?
12
+ 6. Feature map geometry: CV of bottleneck features
13
+ 7. Velocity field analysis: how does the relay affect v_pred?
14
+ 8. Gate dynamics: measure gate values at different timesteps
15
+ 9. Anchor constellation visualization
16
+ 10. Ablation: relay ON vs OFF generation comparison
17
+ """
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+ import numpy as np
23
+ import math
24
+ import os
25
+ import json
26
+ import time
27
+ from torchvision import datasets, transforms
28
+ from torchvision.utils import save_image, make_grid
29
+
30
# Use the GPU when PyTorch can see one; otherwise run everything on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Seed PyTorch's global RNG so the random draws made further down are repeatable.
torch.manual_seed(42)

# Directory that receives the image grids saved later in the script.
os.makedirs("analysis", exist_ok=True)
34
+
35
+
36
def compute_cv(points, n_samples=2000, n_points=5):
    """Estimate the coefficient of variation of random simplex volumes.

    Rows of ``points`` are L2-normalised, then ``n_samples`` random subsets of
    ``n_points`` rows are drawn. The squared volume of each subset's simplex
    comes from its Cayley-Menger determinant, and the function returns
    ``std / mean`` over the volumes that are not (numerically) degenerate.

    Args:
        points: 2-D tensor with one point per row.
        n_samples: number of random subsets to draw.
        n_points: vertices per simplex. The default of 5 gives a 4-simplex and
            reproduces the original fixed 6x6 determinant with divisor 9216.

    Returns:
        The CV as a Python float, or NaN when ``points`` has fewer than
        ``n_points`` rows or fewer than 50 samples had a positive volume.

    Raises:
        ValueError: if ``n_points`` is smaller than 2.
    """
    if n_points < 2:
        raise ValueError(f"n_points must be at least 2, got {n_points}")
    N = points.shape[0]
    if N < n_points:
        return float('nan')
    points = F.normalize(points.to(DEVICE).float(), dim=-1)

    # For a k-simplex: V^2 = (-1)^(k+1) / (2^k * (k!)^2) * det(CM).
    k = n_points - 1
    sign = -1.0 if k % 2 == 0 else 1.0
    denom = float(2 ** k * math.factorial(k) ** 2)

    # Only the first 10000 rows are ever candidates for sampling.
    pool = min(N, 10000)

    # The border of ones is constant, so build the matrix once and only
    # overwrite the squared-distance block on each draw.
    cm = torch.zeros(1, n_points + 1, n_points + 1, device=DEVICE, dtype=torch.float32)
    cm[:, 0, 1:] = 1
    cm[:, 1:, 0] = 1

    vols = []
    for _ in range(n_samples):
        idx = torch.randperm(pool, device=DEVICE)[:n_points]
        pts = points[idx].unsqueeze(0)
        gram = torch.bmm(pts, pts.transpose(1, 2))
        norms = torch.diagonal(gram, dim1=1, dim2=2)
        # Squared pairwise distances; relu clamps tiny negative round-off.
        d2 = F.relu(norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram)
        cm[:, 1:, 1:] = d2
        v2 = sign * torch.linalg.det(cm) / denom
        # Skip degenerate (zero or negative squared-volume) draws.
        if v2[0].item() > 1e-20:
            vols.append(v2[0].sqrt().cpu())
    if len(vols) < 50:
        return float('nan')
    vt = torch.stack(vols)
    return (vt.std() / (vt.mean() + 1e-8)).item()
56
+
57
+
58
def eff_dim(x):
    """Return the participation-ratio effective dimension of the rows of ``x``.

    The rows are centred with the column means of the full input, at most the
    first 512 centred rows are decomposed, and the singular values ``S`` are
    turned into weights ``p = S / sum(S)``. The result is ``1 / sum(p ** 2)``
    as a Python float.
    """
    centered = x - x.mean(0, keepdim=True)
    row_cap = min(512, x.shape[0])
    singular_values = torch.linalg.svd(
        centered[:row_cap].float(), full_matrices=False
    )[1]
    weights = singular_values / singular_values.sum()
    inverse_participation = weights.pow(2).sum()
    return inverse_participation.reciprocal().item()
64
+
65
+
66
# Display names for the ten class indices used throughout the script.
CLASS_NAMES = ['plane', 'auto', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']

print("=" * 80)
print("FLOW MATCH RELAY — FULL ANALYSIS TOOLKIT")
print(f" Device: {DEVICE}")
print("=" * 80)

# ── Load model ──
from transformers import AutoModel

# NOTE(security): trust_remote_code=True executes Python shipped in the hub
# repository. Only run this against a repository you trust, and consider
# pinning a revision.
model = AutoModel.from_pretrained(
    "AbstractPhil/geolip-diffusion-proto", trust_remote_code=True
).to(DEVICE)
model.eval()

n_params = sum(p.numel() for p in model.parameters())
# A parameter counts as "relay" when its qualified name contains 'relay'.
n_relay = sum(p.numel() for n, p in model.named_parameters() if 'relay' in n)
print(f" Params: {n_params:,} (relay: {n_relay:,}, {100*n_relay/n_params:.1f}%)")

# Relay modules are recognised structurally: anything exposing both a
# `drift` and an `anchors` attribute, keyed by its qualified module name.
relays = {
    name: module
    for name, module in model.named_modules()
    if hasattr(module, 'drift') and hasattr(module, 'anchors')
}
print(f" Relay modules: {len(relays)}")
92
+
93
+
94
# ══════════════════════════════════════════════════════════════════
# TEST 1: RELAY DIAGNOSTICS
# ══════════════════════════════════════════════════════════════════

print(f"\n{'━'*80}")
print("TEST 1: Relay Diagnostics — Drift, Gates, Anchor Geometry")
print(f"{'━'*80}")

for name, relay in relays.items():
    drift = relay.drift().detach().cpu()  # (P, A)
    gates = relay.gates.sigmoid().detach().cpu()  # (P,)
    home = F.normalize(relay.home, dim=-1).detach().cpu()
    anchors = F.normalize(relay.anchors, dim=-1).detach().cpu()

    P, A, d = home.shape

    print(f"\n {name}:")
    print(f" Patches: {P}, Anchors/patch: {A}, Patch dim: {d}")
    print(f" Drift (rad): mean={drift.mean():.6f} std={drift.std():.6f} "
          f"min={drift.min():.6f} max={drift.max():.6f}")
    print(f" Drift (deg): mean={math.degrees(drift.mean().item()):.2f}° "
          f"max={math.degrees(drift.max().item()):.2f}°")
    print(f" Gates: mean={gates.mean():.4f} std={gates.std():.4f} "
          f"min={gates.min():.4f} max={gates.max():.4f}")

    # Anchor pairwise similarity within each patch (first four patches).
    # Statistics are taken over off-diagonal entries only: zeroing the
    # diagonal and then averaging the full matrix would pull the mean toward
    # zero and could make 0 the reported max or min.
    n_anchor = anchors.shape[1]
    if n_anchor > 1:
        off_diag = ~torch.eye(n_anchor, dtype=torch.bool)
        for p in range(min(4, P)):
            sim = (anchors[p] @ anchors[p].T)[off_diag]
            print(f" Patch {p}: anchor_cos mean={sim.mean():.4f} max={sim.max():.4f} "
                  f"min={sim.min():.4f}")

    # Near 0.29154?
    near_029 = (drift - 0.29154).abs() < 0.05
    pct_near = near_029.float().mean().item()
    print(f" Near 0.29154: {pct_near:.1%} of anchors within ±0.05")

    # Per-patch drift
    print(f" Per-patch mean drift:")
    for p in range(P):
        d_p = drift[p].mean().item()
        marker = " ◄ 0.29" if abs(d_p - 0.29154) < 0.05 else ""
        print(f" Patch {p:2d}: {d_p:.6f} rad ({math.degrees(d_p):.2f}°){marker}")
137
+
138
+
139
# ══════════════════════════════════════════════════════════════════
# TEST 2: BOTTLENECK FEATURE GEOMETRY
# ══════════════════════════════════════════════════════════════════

print(f"\n{'━'*80}")
print("TEST 2: Bottleneck Feature Geometry — CV at the relay point")
print(f"{'━'*80}")

# Real CIFAR-10 test images, scaled from [0, 1] to [-1, 1] per channel.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
test_ds = datasets.CIFAR10('./data', train=False, download=True, transform=transform)
# shuffle=False keeps the batch order fixed from run to run.
test_loader = torch.utils.data.DataLoader(test_ds, batch_size=256, shuffle=False)

# Filled by the forward hooks: module name -> most recent detached output.
bottleneck_features = {}
157
+
158
def hook_fn(name):
    """Build a forward hook that stores tensor outputs under ``name``.

    The returned callable takes ``(module, input, output)``. When ``output``
    is a ``torch.Tensor`` its detached version is written to the module-level
    ``bottleneck_features`` dict under ``name``; any other output type is
    ignored. The hook itself returns ``None``.
    """
    def capture(module, input, output):
        if not isinstance(output, torch.Tensor):
            return
        bottleneck_features[name] = output.detach()
    return capture
163
+
164
# Hook only the relay modules and the three named mid-block modules
# themselves; their submodules are not hooked.
target_names = {'unet.mid_block1', 'unet.mid_block2', 'unet.mid_attn'}.union(relays.keys())
hooks = [
    module.register_forward_hook(hook_fn(name))
    for name, module in model.named_modules()
    if name in target_names
]

# A single batch is reused for every timestep below.
images, labels = next(iter(test_loader))
images = images.to(DEVICE)
labels_dev = labels.to(DEVICE)

print(f"\n CV of bottleneck features at different timesteps:")
print(f" {'t':>6} {'module':>40} {'CV':>8} {'eff_d':>8} {'norm':>8}")

for t_val in (0.0, 0.25, 0.5, 0.75, 1.0):
    t = torch.full((images.shape[0],), t_val, device=DEVICE)
    noise = torch.randn_like(images)
    blend = t.view(-1, 1, 1, 1)
    # Linear interpolation between the images (t=0) and pure noise (t=1).
    x_t = (1 - blend) * images + blend * noise

    bottleneck_features.clear()
    with torch.no_grad():
        _ = model(x_t, t, labels_dev)

    for feat_name, feat in bottleneck_features.items():
        rank = feat.dim()
        if rank == 4:
            # Average the last two axes of the feature map -> 2-D tensor.
            pooled = feat.mean(dim=(-2, -1))
        elif rank == 2:
            pooled = feat
        else:
            # Any other rank is not analysed.
            continue
        if pooled.dim() != 2 or min(pooled.shape[0], pooled.shape[1]) < 5:
            continue
        cv = compute_cv(pooled, n_samples=1000)
        ed = eff_dim(pooled)
        norm_mean = pooled.norm(dim=-1).mean().item()
        print(f" {t_val:>6.2f} {feat_name:>40} {cv:>8.4f} {ed:>8.1f} {norm_mean:>8.2f}")

# Detach every hook registered above.
for handle in hooks:
    handle.remove()
207
+
208
+
209
# ══════════════════════════════════════════════════════════════════
# TEST 3: PER-CLASS ANCHOR UTILIZATION
# ══════════════════════════════════════════════════════════════════

print(f"\n{'━'*80}")
print("TEST 3: Per-Class Anchor Utilization")
print(f" Which anchors activate for each class?")
print(f"{'━'*80}")

# Collect pooled relay features per class label.
class_features = {c: [] for c in range(10)}

# Only the first discovered relay is analysed. A forward hook records the
# module's OUTPUT, so these are the features after the relay, not before it.
relay_name = list(relays.keys())[0]
relay_mod = relays[relay_name]

# Register the hook once for the whole pass and always detach it, even if a
# forward pass raises.
hook = relay_mod.register_forward_hook(hook_fn(relay_name))
try:
    for images_batch, labels_batch in test_loader:
        images_batch = images_batch.to(DEVICE)
        labels_batch = labels_batch.to(DEVICE)
        B = images_batch.shape[0]

        t = torch.full((B,), 0.0, device=DEVICE)  # clean images (t=0)

        bottleneck_features.clear()
        with torch.no_grad():
            _ = model(images_batch, t, labels_batch)

        if relay_name in bottleneck_features:
            feat = bottleneck_features[relay_name]
            # 4-D outputs are averaged over their last two axes.
            pooled = feat.mean(dim=(-2, -1)) if feat.dim() == 4 else feat
            # One device-to-host transfer per batch instead of one per sample.
            for c, row in zip(labels_batch.tolist(), pooled.cpu()):
                class_features[c].append(row)

        # Stop once more than 5000 samples have been collected.
        if sum(len(v) for v in class_features.values()) > 5000:
            break
finally:
    hook.remove()

# For each class, triangulate against the first relay's anchors
relay_mod = list(relays.values())[0]
anchors = F.normalize(relay_mod.anchors.detach(), dim=-1)  # (P, A, d)
P, A, d = anchors.shape

print(f"\n Nearest anchor distribution per class (Patch 0):")
print(f" {'class':>10}", end="")
for a in range(A):
    print(f" {a:>5}", end="")
print()

for c in range(10):
    if not class_features[c]:
        continue
    feats = torch.stack(class_features[c]).to(DEVICE)
    # Split each feature vector into P chunks of width d.
    # NOTE(review): assumes the pooled width equals P * d — confirm.
    patches = feats.reshape(-1, P, d)
    patch0 = F.normalize(patches[:, 0], dim=-1)
    # Nearest anchor = highest cosine similarity against patch 0's anchors.
    cos = patch0 @ anchors[0].T
    nearest = cos.argmax(dim=-1)
    counts = torch.bincount(nearest, minlength=A).float()
    counts = counts / counts.sum()
    row = f" {CLASS_NAMES[c]:>10}"
    for a in range(A):
        pct = counts[a].item()
        marker = "█" if pct > 0.15 else "▓" if pct > 0.10 else "░" if pct > 0.05 else " "
        row += f" {pct:>4.0%}{marker}"
    print(row)
281
+
282
+
283
# ══════════════════════════════════════════════════════════════════
# TEST 4: GATE DYNAMICS ACROSS TIMESTEPS
# ══════════════════════════════════════════════════════════════════

print(f"\n{'━'*80}")
print("TEST 4: Gate Dynamics — do relay gates respond to timestep?")
print(f"{'━'*80}")

# The gates are parameters (not input-dependent), so they're constant.
# But we can measure the relay's EFFECTIVE contribution at each t.
print(f" Note: gates are learned parameters, not t-dependent.")
print(f" Measuring relay output magnitude at different t instead.\n")

# Only the first discovered relay is measured.
relay_name = list(relays.keys())[0]
relay_mod = relays[relay_name]

# Written by the forward hook: the relay's latest input and output under 'x'.
relay_in = {}
relay_out = {}
301
+
302
def hook_in(module, input, output):
    """Forward hook that records a module's first input and its output.

    Detached tensors are written to the module-level dicts ``relay_in['x']``
    and ``relay_out['x']``. When ``input`` is a tuple only its first element
    is kept; otherwise ``input`` itself is stored.
    """
    first_arg = input[0] if isinstance(input, tuple) else input
    relay_in['x'] = first_arg.detach()
    relay_out['x'] = output.detach()
308
+
309
hook = relay_mod.register_forward_hook(hook_in)

images_small = images[:64]
labels_small = labels_dev[:64]
# Size the timestep vector from the actual slice so a batch shorter than 64
# still works.
n_small = images_small.shape[0]

print(f" {'t':>6} {'relay_Δ_norm':>14} {'relay_Δ_cos':>14} {'input_norm':>12} {'output_norm':>12}")

# try/finally guarantees the hook is detached even if a forward pass raises.
try:
    for t_val in [0.0, 0.1, 0.25, 0.5, 0.75, 0.9, 1.0]:
        t = torch.full((n_small,), t_val, device=DEVICE)
        eps = torch.randn_like(images_small)
        t_b = t.view(-1, 1, 1, 1)
        x_t = (1 - t_b) * images_small + t_b * eps

        relay_in.clear()
        relay_out.clear()
        with torch.no_grad():
            _ = model(x_t, t, labels_small)

        if 'x' in relay_in and 'x' in relay_out:
            x_in = relay_in['x']
            x_out = relay_out['x']
            delta = x_out - x_in
            # Flatten everything beyond the batch dim for per-sample norms.
            delta_flat = delta.reshape(delta.shape[0], -1)
            in_flat = x_in.reshape(x_in.shape[0], -1)
            out_flat = x_out.reshape(x_out.shape[0], -1)
            delta_norm = delta_flat.norm(dim=-1).mean().item()
            in_norm = in_flat.norm(dim=-1).mean().item()
            out_norm = out_flat.norm(dim=-1).mean().item()

            # 1 - cosine(input, output): 0 means the direction is unchanged.
            cos_change = 1 - F.cosine_similarity(in_flat, out_flat).mean().item()
            print(f" {t_val:>6.2f} {delta_norm:>14.4f} {cos_change:>14.8f} "
                  f"{in_norm:>12.2f} {out_norm:>12.2f}")
finally:
    hook.remove()
343
+
344
+
345
# ══════════════════════════════════════════════════════════════════
# TEST 5: GENERATION QUALITY — PER-CLASS DIVERSITY
# ══════════════════════════════════════════════════════════════════

print(f"\n{'━'*80}")
print("TEST 5: Generation Quality — Per-Class Diversity")
print(f"{'━'*80}")

print(f" {'class':>10} {'intra_cos':>10} {'intra_std':>10} {'CV':>8} {'norm':>8}")

# One batch of generated images per class; reused later by TEST 10.
all_generated = []
for c in range(10):
    with torch.no_grad():
        # NOTE(review): the original author's comment gives the result as
        # (64, 3, 32, 32) in [0, 1] — not verifiable from this file.
        imgs = model.sample(n_samples=64, class_label=c)
    all_generated.append(imgs)

    n_gen = imgs.shape[0]
    flat = imgs.reshape(n_gen, -1)
    flat_n = F.normalize(flat, dim=-1)

    # Intra-class cosine similarity, excluding each image's self-similarity.
    sim = flat_n @ flat_n.T
    mask = ~torch.eye(n_gen, device=sim.device, dtype=torch.bool)
    intra_cos = sim[mask].mean().item()
    intra_std = sim[mask].std().item()

    cv = compute_cv(flat, n_samples=500)
    norm_mean = flat.norm(dim=-1).mean().item()

    print(f" {CLASS_NAMES[c]:>10} {intra_cos:>10.4f} {intra_std:>10.4f} "
          f"{cv:>8.4f} {norm_mean:>8.2f}")

# Save per-class grid (first 16 samples, 4 per row).
for c in range(10):
    grid = make_grid(all_generated[c][:16], nrow=4)
    save_image(grid, f"analysis/class_{CLASS_NAMES[c]}.png")

# All-classes grid: 4 rows x 10 columns with one class per column.
# Stacking on dim=1 gives (4, 10, ...); flattening the first two dims orders
# the images sample-major, so each row of 10 holds one image of every class.
# A class-major concatenation would spread each class across columns.
all_grid = torch.stack([imgs[:4] for imgs in all_generated], dim=1).flatten(0, 1)
save_image(make_grid(all_grid, nrow=10), "analysis/all_classes.png")
print(f"\n ✓ Saved per-class grids to analysis/")
385
+
386
+
387
# ══════════════════════════════════════════════════════════════════
# TEST 6: VELOCITY FIELD ANALYSIS
# ══════════════════════════════════════════════════════════════════

print(f"\n{'━'*80}")
print("TEST 6: Velocity Field — how does v_pred behave across t?")
print(f"{'━'*80}")

images_v = images[:128]
labels_v = labels_dev[:128]
# Use the real slice length rather than assuming 128 samples are available.
n_v = images_v.shape[0]

# Column labels match the values printed below: the last two columns are the
# cosine between predicted and target velocity, and their MSE.
print(f" {'t':>6} {'v_norm':>10} {'v_std':>10} {'v_cos':>10} {'mse':>10}")

for t_val in [0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95]:
    t = torch.full((n_v,), t_val, device=DEVICE)
    eps = torch.randn_like(images_v)
    t_b = t.view(-1, 1, 1, 1)
    x_t = (1 - t_b) * images_v + t_b * eps
    # Target velocity for the linear path x_t = (1 - t) * x0 + t * eps.
    v_target = eps - images_v

    with torch.no_grad():
        v_pred = model(x_t, t, labels_v)

    v_norm = v_pred.reshape(n_v, -1).norm(dim=-1).mean().item()
    v_std = v_pred.std().item()
    # Cosine between predicted and target velocity
    v_cos = F.cosine_similarity(
        v_pred.reshape(n_v, -1), v_target.reshape(n_v, -1)).mean().item()
    # MSE
    mse = F.mse_loss(v_pred, v_target).item()

    print(f" {t_val:>6.2f} {v_norm:>10.2f} {v_std:>10.4f} "
          f"{v_cos:>10.4f} {mse:>10.4f}")
420
+
421
+
422
# ══════════════════════════════════════════════════════════════════
# TEST 7: ABLATION — RELAY ON vs OFF
# ══════════════════════════════════════════════════════════════════

print(f"\n{'━'*80}")
print("TEST 7: Ablation — Relay ON vs OFF during generation")
print(f" Disable relay gates, measure generation difference")
print(f"{'━'*80}")

# Snapshot the gate values so they can be restored after the ablation.
original_gates = {name: relay.gates.data.clone() for name, relay in relays.items()}

# Generate with relay ON
torch.manual_seed(123)
with torch.no_grad():
    imgs_on = model.sample(n_samples=32, class_label=3)

# try/finally guarantees the learned gates are restored even if sampling
# fails, so a failure here cannot leave the model silently ablated.
try:
    # Disable relays (set gates to -100 → sigmoid ≈ 0)
    for relay in relays.values():
        relay.gates.data.fill_(-100.0)

    # Generate with relay OFF (same seed, so the only difference is the gates)
    torch.manual_seed(123)
    with torch.no_grad():
        imgs_off = model.sample(n_samples=32, class_label=3)
finally:
    # Restore gates
    for name, relay in relays.items():
        relay.gates.data.copy_(original_gates[name])

# Compare
delta = imgs_on - imgs_off
n_cmp = imgs_on.shape[0]
pixel_diff = delta.abs().mean().item()
cos_diff = F.cosine_similarity(
    imgs_on.reshape(n_cmp, -1), imgs_off.reshape(n_cmp, -1)).mean().item()

print(f" Relay ON — mean pixel: {imgs_on.mean():.4f} std: {imgs_on.std():.4f}")
print(f" Relay OFF — mean pixel: {imgs_off.mean():.4f} std: {imgs_off.std():.4f}")
print(f" Pixel diff: {pixel_diff:.6f}")
print(f" Cosine sim: {cos_diff:.6f}")
print(f" Max pixel Δ: {delta.abs().max():.6f}")

# Save comparison: first 8 ON images on the top row, first 8 OFF images below.
comparison = torch.cat([imgs_on[:8], imgs_off[:8]], dim=0)
save_image(make_grid(comparison, nrow=8), "analysis/relay_ablation.png")
print(f" ✓ Saved analysis/relay_ablation.png (top=ON, bottom=OFF)")
470
+
471
+
472
# ══════════════════════════════════════════════════════════════════
# TEST 8: ANCHOR CONSTELLATION STRUCTURE
# ══════════════════════════════════════════════════════════════════

print(f"\n{'━'*80}")
print("TEST 8: Anchor Constellation Structure")
print(f"{'━'*80}")

for name, relay in relays.items():
    home = F.normalize(relay.home.detach().cpu(), dim=-1)
    curr = F.normalize(relay.anchors.detach().cpu(), dim=-1)
    P, A, d = home.shape

    print(f"\n {name}:")

    # Home vs current: cosine between each anchor and its `home` counterpart.
    home_curr_cos = (home * curr).sum(dim=-1)  # (P, A)
    print(f" Home↔Current cos: mean={home_curr_cos.mean():.6f} "
          f"min={home_curr_cos.min():.6f}")

    # Anchor spread for the first four patches. Statistics use off-diagonal
    # entries only; averaging a matrix whose diagonal was zeroed would bias
    # the mean toward zero and could make 0 the reported max or min.
    n_anchor = curr.shape[1]
    if n_anchor > 1:
        off_diag = ~torch.eye(n_anchor, dtype=torch.bool)
        for p in range(min(4, P)):
            pair_cos = (curr[p] @ curr[p].T)[off_diag]
            print(f" Patch {p} anchor spread: "
                  f"mean_cos={pair_cos.mean():.4f} "
                  f"max_cos={pair_cos.max():.4f} "
                  f"min_cos={pair_cos.min():.4f}")

    # Effective anchor dimensionality: participation ratio of the singular
    # values of each patch's anchor matrix.
    for p in range(min(4, P)):
        _, S, _ = torch.linalg.svd(curr[p].float(), full_matrices=False)
        pr = S / S.sum()
        anchor_eff_dim = pr.pow(2).sum().reciprocal().item()
        print(f" Patch {p} anchor eff_dim: {anchor_eff_dim:.1f} / {A}")
507
+
508
+
509
# ══════════════════════════════════════════════════════════════════
# TEST 9: SAMPLING TRAJECTORY — TRACK CV THROUGH ODE
# ══════════════════════════════════════════════════════════════════

print(f"\n{'━'*80}")
print("TEST 9: Sampling Trajectory — CV through ODE steps")
print(f"{'━'*80}")

n_steps = 50
B_traj = 256

# Start from Gaussian noise with uniformly random class labels.
x = torch.randn(B_traj, 3, 32, 32, device=DEVICE)
labels_traj = torch.randint(0, 10, (B_traj,), device=DEVICE)
dt = 1.0 / n_steps

print(f" {'step':>6} {'t':>6} {'x_norm':>10} {'x_std':>10} {'CV_pixel':>10}")

checkpoints = [0, 1, 5, 10, 20, 30, 40, 49]
# bf16 autocast is only enabled on CUDA. Passing DEVICE (rather than a fixed
# "cuda") avoids the "CUDA is not available" warning on CPU-only machines
# while keeping the CPU path in full precision.
use_amp = DEVICE == "cuda"
for step in range(n_steps):
    # t runs from 1.0 down toward 0 in steps of dt.
    t_val = 1.0 - step * dt
    t = torch.full((B_traj,), t_val, device=DEVICE)

    with torch.no_grad(), torch.amp.autocast(DEVICE, dtype=torch.bfloat16, enabled=use_amp):
        v = model(x, t, labels_traj)
    # Explicit Euler step against the predicted velocity.
    x = x - v.float() * dt

    if step in checkpoints:
        # Statistics describe x AFTER this step; t_val is the time the step
        # started from.
        x_flat = x.reshape(B_traj, -1)
        norm = x_flat.norm(dim=-1).mean().item()
        std = x.std().item()
        cv = compute_cv(x_flat, n_samples=500)
        print(f" {step:>6} {t_val:>6.2f} {norm:>10.2f} {std:>10.4f} {cv:>10.4f}")
541
+
542
+
543
# ══════════════════════════════════════════════════════════════════
# TEST 10: INTER-CLASS vs INTRA-CLASS GEOMETRY
# ══════════════════════════════════════════════════════════════════

print(f"\n{'━'*80}")
print("TEST 10: Inter-Class vs Intra-Class Separation")
print(f"{'━'*80}")

# Use generated images (the per-class batches collected in `all_generated`).
n_classes = len(all_generated)
# Flatten each class batch once; the batch size is read from the tensor
# instead of assuming 64 samples.
class_flat = [imgs.reshape(imgs.shape[0], -1) for imgs in all_generated]

# Unit-normalised mean image per class.
class_means = torch.cat(
    [F.normalize(flat.mean(dim=0, keepdim=True), dim=-1) for flat in class_flat],
    dim=0)
inter_sim = class_means @ class_means.T

print(f" Inter-class cosine similarity matrix:")
print(f" {'':>8}", end="")
for c in range(n_classes):
    print(f" {CLASS_NAMES[c][:4]:>5}", end="")
print()

for i in range(n_classes):
    print(f" {CLASS_NAMES[i]:>8}", end="")
    for j in range(n_classes):
        if i == j:
            # Same cell width as the off-diagonal entries so columns line up.
            print(f" {'1.0':>5}", end="")
        else:
            print(f" {inter_sim[i, j].item():>5.2f}", end="")
    print()

# Intra vs inter: normalise every class batch once, up front, rather than
# re-normalising both operands inside the pair loop.
class_unit = [F.normalize(flat, dim=-1) for flat in class_flat]

intra_sims = []
for unit in class_unit:
    sim = unit @ unit.T
    # Exclude each image's similarity with itself.
    mask = ~torch.eye(unit.shape[0], device=unit.device, dtype=torch.bool)
    intra_sims.append(sim[mask].mean().item())

inter_sims = []
for i in range(n_classes):
    for j in range(i + 1, n_classes):
        inter_sims.append((class_unit[i] @ class_unit[j].T).mean().item())

print(f"\n Intra-class cos: {np.mean(intra_sims):.4f} ± {np.std(intra_sims):.4f}")
print(f" Inter-class cos: {np.mean(inter_sims):.4f} ± {np.std(inter_sims):.4f}")
print(f" Separation ratio: {np.mean(intra_sims) / (np.mean(inter_sims) + 1e-8):.2f}×")
595
+
596
+
597
# ══════════════════════════════════════════════════════════════════
# SUMMARY
# ══════════════════════════════════════════════════════════════════

# Closing banner: lists the image files written above and the metrics worth
# reading in the printed output.
print(f"\n{'='*80}")
print("ANALYSIS COMPLETE")
print(f"{'='*80}")
print("""
Files saved to analysis/:
- class_*.png: per-class generated samples
- all_classes.png: 4 samples per class, 10 columns
- relay_ablation.png: relay ON (top) vs OFF (bottom)

Key metrics to look for:
1. Anchor drift → did any converge near 0.29154?
2. Gate values → did they learn to open from init (0.047)?
3. Per-class anchor utilization → class-specific routing?
4. Relay ablation → does turning off the relay change generation?
5. Intra/inter-class ratio → > 1.0 means classes are separable
6. Velocity cosine → higher = better flow matching
7. CV through ODE → how does geometry evolve during generation?
""")
print(f"{'='*80}")