AbstractPhil committed on
Commit
91721ea
·
verified ·
1 Parent(s): a240dad

Create cell_7_conduit_sweep_external_svd_correctly_applied.py

Browse files
cell_7_conduit_sweep_external_svd_correctly_applied.py ADDED
@@ -0,0 +1,380 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Cell 7 — Fresnel Conduit Sweep (CORRECT)
3
+ ==========================================
4
+ Uses solver='conduit' to capture telemetry from THE REAL
5
+ decomposition inside the forward pass. No circular reconstruction.
6
+
7
+ The friction, settle, and extraction order come from the actual
8
+ Gram matrices the encoder produces, decomposed by FLEighConduit
9
+ as the model runs.
10
+
11
+ 9 configurations through conv on 16×16 spatial grids.
12
+ No pooling. No flattening. Respects geometric structure.
13
+
14
+ Channels:
15
+ S_orig: 4ch — raw eigenvalues (pre-cross-attention)
16
+ S_coord: 4ch — coordinated eigenvalues (post-cross-attention)
17
+ Friction: 4ch — log1p(friction) from real decomposition
18
+ Settle: 4ch — convergence iterations from real decomposition
19
+ Error: 1ch — per-patch reconstruction MSE
20
+ """
21
+
22
+ import torch
23
+ import torch.nn as nn
24
+ import torch.nn.functional as F
25
+ import numpy as np
26
+ import time
27
+ from tqdm import tqdm
28
+
29
device = torch.device('cuda')

# ═══════════════════════════════════════════════════════════════
# LOAD FRESNEL WITH CONDUIT SOLVER
# ═══════════════════════════════════════════════════════════════

FRESNEL_VERSION = 'v50_fresnel_64'
IMG_SIZE = 64

print(f"Loading Fresnel ({FRESNEL_VERSION}) with conduit solver...")
from geolip_svae import load_model
from geolip_svae.model import extract_patches
import torchvision
import torchvision.transforms as T

fresnel, cfg = load_model(hf_version=FRESNEL_VERSION, device=device)
fresnel.eval()
fresnel.solver = 'conduit'  # Enable conduit — telemetry from real decomposition

# Patch grid geometry: the image is tiled into gh x gw patches of side ps.
ps = fresnel.patch_size
gh, gw = IMG_SIZE // ps, IMG_SIZE // ps
D = fresnel.D
N = gh * gw

print(f" Patch size: {ps}, Grid: {gh}x{gw}, D={D}, Patches/image: {N}")
print(f" Solver: {fresnel.solver}")

# Verify conduit works: run a dummy batch of 2 images and check that the
# model exposes a telemetry packet afterwards.
with torch.no_grad():
    dummy = torch.randn(2, 3, IMG_SIZE, IMG_SIZE, device=device)
    out = fresnel(dummy)
    packet = fresnel.last_conduit_packet

# Explicit raise instead of `assert`, so the check survives `python -O`.
if packet is None:
    raise RuntimeError("Conduit packet is None — solver not active")
print(f" Conduit packet shape: friction={packet.friction.shape}")
print(f" Expected: ({2 * N}, {D}) = ({2*N}, {D})")
print(f" CONDUIT ACTIVE ✓")
65
+
66
# Load CIFAR-10, resized to the model's input resolution.
transform = T.Compose([T.Resize(IMG_SIZE), T.ToTensor()])
cifar_train = torchvision.datasets.CIFAR10(
    root='/content/data', train=True, download=True, transform=transform)
cifar_test = torchvision.datasets.CIFAR10(
    root='/content/data', train=False, download=True, transform=transform)

# shuffle=False on both loaders: these are only used for the one-off
# feature extraction pass, so sample order is kept as stored.
train_loader = torch.utils.data.DataLoader(
    cifar_train, batch_size=128, shuffle=False, num_workers=4)
test_loader = torch.utils.data.DataLoader(
    cifar_test, batch_size=128, shuffle=False, num_workers=4)

# Display names indexed by integer class label.
CLASSES = ['airplane', 'auto', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']
80
+
81
+
82
+ # ═══════════════════════════════════════════════════════════════
83
+ # PRECOMPUTE — Extract real conduit telemetry from forward pass
84
+ # ═══════════════════════════════════════════════════════════════
85
+
86
def extract_all(loader, desc="Extracting"):
    """Run Fresnel with the conduit solver over ``loader`` and collect per-patch maps.

    For every batch the model is run once; the SVD outputs, the conduit
    telemetry packet left on ``fresnel.last_conduit_packet`` and the
    per-patch reconstruction MSE are reshaped to ``(B, gh, gw, C)`` grids
    and moved to the CPU.

    Args:
        loader: Iterable yielding ``(images, labels)`` batches.
        desc: Label shown on the tqdm progress bar.

    Returns:
        Dict with keys ``'S_orig'``, ``'S_coord'``, ``'friction'``,
        ``'settle'`` (each ``(num_images, gh, gw, D)``), ``'error'``
        (``(num_images, gh, gw, 1)``) and ``'labels'`` (``(num_images,)``),
        concatenated over all batches.

    Raises:
        RuntimeError: If the model produced no conduit packet for a batch.
    """
    collected = {
        'S_orig': [],
        'S_coord': [],
        'friction': [],
        'settle': [],
        'error': [],
        'labels': [],
    }

    with torch.no_grad():
        for images, labels in tqdm(loader, desc=desc):
            images_gpu = images.to(device)

            # Forward pass — conduit captures from real decomposition
            out = fresnel(images_gpu)
            packet = fresnel.last_conduit_packet
            if packet is None:
                raise RuntimeError(
                    "Conduit packet is None — solver not active")

            B = images_gpu.shape[0]

            # SVD outputs
            S_orig = out['svd']['S_orig']   # raw eigenvalues
            S_coord = out['svd']['S']       # cross-attention coordinated

            # Conduit telemetry from the REAL decomposition; the packet is
            # flat over (batch * patches), so unflatten to (B, N, D).
            friction = packet.friction.reshape(B, N, D)
            settle = packet.settle.reshape(B, N, D)

            # Per-patch reconstruction error: MSE over each patch's pixels.
            recon = out['recon']
            inp_p, _, _ = extract_patches(images_gpu, ps)
            rec_p, _, _ = extract_patches(recon, ps)
            patch_mse = (inp_p - rec_p).pow(2).mean(dim=-1)

            # Reshape to spatial grids and move to CPU
            collected['S_orig'].append(S_orig.reshape(B, gh, gw, D).cpu())
            collected['S_coord'].append(S_coord.reshape(B, gh, gw, D).cpu())
            collected['friction'].append(friction.reshape(B, gh, gw, D).cpu())
            collected['settle'].append(settle.reshape(B, gh, gw, D).cpu())
            collected['error'].append(patch_mse.reshape(B, gh, gw, 1).cpu())
            collected['labels'].append(labels)

    return {key: torch.cat(chunks) for key, chunks in collected.items()}
135
+
136
+
137
# One-off extraction pass: every later experiment reads these cached maps
# instead of re-running the model.
print("\nPrecomputing train set (real conduit telemetry)...")
train_data = extract_all(train_loader, "Train")
print(f" Train: {len(train_data['labels'])} images")

print("Precomputing test set...")
test_data = extract_all(test_loader, "Test")
print(f" Test: {len(test_data['labels'])} images")
144
+
145
+
146
+ # ═══════════════════════════════════════════════════════════════
147
+ # SIGNAL PROFILE — What does the real conduit data look like?
148
+ # ═══════════════════════════════════════════════════════════════
149
+
150
print(f"\n{'=' * 70}")
print(" SIGNAL PROFILE — Real conduit telemetry from Fresnel")
print("=" * 70)

# Global summary statistics of each cached signal over the train set.
for key in ['S_orig', 'S_coord', 'friction', 'settle', 'error']:
    t = train_data[key]
    flat = t.reshape(t.shape[0], -1)
    print(f" {key:10s}: mean={flat.mean():12.4f} std={flat.std():12.4f} "
          f"min={flat.min():12.4f} max={flat.max():12.2f}")

# Log-friction profile (computed once and reused for the per-class table).
log_fric = torch.log1p(train_data['friction'])
flat_lf = log_fric.reshape(log_fric.shape[0], -1)
print(f" {'log1p_fric':10s}: mean={flat_lf.mean():12.4f} std={flat_lf.std():12.4f} "
      f"min={flat_lf.min():12.4f} max={flat_lf.max():12.2f}")

# Per-class friction means
print(f"\n Per-class mean friction (raw):")
for c, class_name in enumerate(CLASSES):
    mask = train_data['labels'] == c
    fm = train_data['friction'][mask].mean().item()
    lm = log_fric[mask].mean().item()
    print(f" {class_name:<10s}: raw={fm:10.2f} log1p={lm:6.3f}")

# Coefficient of variation of each signal: std / |mean| over all grid cells
# and channels of one image.
print(f"\n Spatial CV (per-image std/mean across {gh}x{gw} grid):")
for key in ['S_orig', 'friction', 'settle', 'error']:
    t = train_data[key]
    per_img = t.reshape(t.shape[0], -1)
    cvs = per_img.std(dim=1) / (per_img.mean(dim=1).abs() + 1e-8)
    print(f" {key:10s}: mean_CV={cvs.mean():.4f} median_CV={cvs.median():.4f}")
181
+
182
+
183
+ # ═══════════════════════════════════════════════════════════════
184
+ # CONV CLASSIFIER
185
+ # ═══════════════════════════════════════════════════════════════
186
+
187
class SpatialConvClassifier(nn.Module):
    """Small conv net that classifies a multi-channel spatial feature grid.

    Three 3x3 convolutions (two with stride 2) followed by global average
    pooling and a two-layer MLP head.

    Args:
        in_channels: Number of channels in the input grid.
        n_classes: Number of output logits.
    """

    def __init__(self, in_channels, n_classes=10):
        super().__init__()
        # Spatial sizes in the trailing comments are for a 16x16 input grid.
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 64, 3, stride=2, padding=1),  # 16→8
            nn.GELU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),          # 8→4
            nn.GELU(),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),         # 4→4
            nn.GELU(),
            nn.AdaptiveAvgPool2d(1),                             # 4→1
        )
        self.head = nn.Sequential(
            nn.Linear(128, 64),
            nn.GELU(),
            nn.Linear(64, n_classes),
        )

    def forward(self, x):
        """Map a ``(B, in_channels, H, W)`` grid to ``(B, n_classes)`` logits."""
        # The pooled map is (B, 128, 1, 1); flatten everything after batch.
        h = self.conv(x).flatten(1)
        return self.head(h)
208
+
209
+
210
class ConduitDataset(torch.utils.data.Dataset):
    """In-memory dataset of channel-stacked conduit feature grids.

    Selected signals are converted from ``(num, gh, gw, C)`` to
    ``(num, C, gh, gw)`` and concatenated along the channel axis, always in
    the fixed order O, C, F, T, E regardless of the order in ``channels``.

    Channel codes:
        O = S_orig (raw eigenvalues, 4ch)
        C = S_coord (cross-attention coordinated, 4ch)
        F = friction (log1p transformed, 4ch)
        T = settle (raw, 4ch)
        E = error (per-patch recon MSE, 1ch)

    Args:
        data: Dict holding ``'labels'`` plus the tensors for every selected
            code (``'S_orig'``, ``'S_coord'``, ``'friction'``, ``'settle'``,
            ``'error'``).
        channels: Collection of channel codes to include, e.g. ``'OFE'``.
        augment: If True, each fetched sample is flipped along its last
            axis with probability 0.5.

    Raises:
        ValueError: If ``channels`` selects none of the known codes.
    """

    # (channel code, key in the data dict), in concatenation order.
    _SOURCES = (
        ('O', 'S_orig'),
        ('C', 'S_coord'),
        ('F', 'friction'),
        ('T', 'settle'),
        ('E', 'error'),
    )

    def __init__(self, data, channels='O', augment=False):
        self.labels = data['labels']
        self.augment = augment

        parts = []
        for code, key in self._SOURCES:
            if code not in channels:
                continue
            grid = data[key]
            if code == 'F':
                # Log-compress friction: [4, 25M] → [1.7, 17]
                grid = torch.log1p(grid)
            parts.append(grid.permute(0, 3, 1, 2))

        if not parts:
            raise ValueError(
                f"channels={channels!r} selects no known channel code "
                f"(expected some of 'O', 'C', 'F', 'T', 'E')")

        self.maps = torch.cat(parts, dim=1)
        self.n_channels = self.maps.shape[1]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        x = self.maps[idx]
        # Random flip along the last (width) axis.
        if self.augment and torch.rand(1).item() > 0.5:
            x = x.flip(-1)
        return x, self.labels[idx]
245
+
246
+
247
+ # ═══════════════════════════════════════════════════════════════
248
+ # TRAINING
249
+ # ═══════════════════════════════════════════════════════════════
250
+
251
def train_and_eval(channels, name, epochs=30, batch_size=128, lr=3e-4):
    """Train a SpatialConvClassifier on the selected channels and report accuracy.

    Builds train/test ``ConduitDataset`` views over the cached
    ``train_data`` / ``test_data``, trains with Adam and a cosine schedule,
    evaluates on the test set after every epoch, and prints a summary.

    The reported "best" accuracy is the maximum test accuracy over epochs,
    and the per-class table is taken from that same epoch.

    Args:
        channels: Channel codes passed to ``ConduitDataset``.
        name: Human-readable configuration name for the summary.
        epochs: Number of training epochs.
        batch_size: Batch size for both loaders.
        lr: Adam learning rate.

    Returns:
        Tuple ``(best_acc, n_params)``: best test accuracy as a fraction and
        the classifier's parameter count.
    """
    n_classes = len(CLASSES)

    train_ds = ConduitDataset(train_data, channels, augment=True)
    test_ds = ConduitDataset(test_data, channels, augment=False)
    n_ch = train_ds.n_channels

    tr_loader = torch.utils.data.DataLoader(
        train_ds, batch_size=batch_size, shuffle=True,
        num_workers=4, pin_memory=True, drop_last=True)
    te_loader = torch.utils.data.DataLoader(
        test_ds, batch_size=batch_size, shuffle=False,
        num_workers=4, pin_memory=True)

    model = SpatialConvClassifier(n_ch, n_classes).to(device)
    n_params = sum(p.numel() for p in model.parameters())
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)

    best_acc = 0
    best_pca = torch.zeros(n_classes)  # per-class accuracy at the best epoch
    t0 = time.time()

    for epoch in range(1, epochs + 1):
        model.train()
        correct, total = 0, 0
        for x, y in tr_loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = F.cross_entropy(logits, y)
            opt.zero_grad()
            loss.backward()
            opt.step()
            correct += (logits.argmax(-1) == y).sum().item()
            total += len(y)
        sched.step()
        # drop_last=True can leave zero batches for a tiny train set.
        train_acc = correct / max(total, 1)

        model.eval()
        # Per-class counters stay on the device so the inner loop does not
        # force a host sync for every class of every batch.
        hits = torch.zeros(n_classes, dtype=torch.long, device=device)
        seen = torch.zeros(n_classes, dtype=torch.long, device=device)
        with torch.no_grad():
            for x, y in te_loader:
                x, y = x.to(device), y.to(device)
                preds = model(x).argmax(-1)
                seen += torch.bincount(y, minlength=n_classes)
                hits += torch.bincount(y[preds == y], minlength=n_classes)

        test_acc = int(hits.sum()) / max(int(seen.sum()), 1)
        if test_acc > best_acc:
            best_acc = test_acc
            # Snapshot the breakdown so it matches the reported best.
            best_pca = (hits.float() / seen.clamp(min=1).float()).cpu()

        if epoch % 5 == 0 or epoch == epochs:
            print(f" ep{epoch:3d} train={train_acc:.1%} test={test_acc:.1%}")

    elapsed = time.time() - t0

    print(f"\n {name}")
    print(f" Channels: {n_ch}, Params: {n_params:,}, Time: {elapsed:.0f}s")
    print(f" Best test: {best_acc:.1%}")
    print(f"\n {'Class':<10s} {'Acc':>6s}")
    print(f" {'-' * 22}")
    for c in range(n_classes):
        bar = '█' * int(best_pca[c] * 20)
        print(f" {CLASSES[c]:<10s} {best_pca[c]:5.1%} {bar}")
    print()
    return best_acc, n_params
321
+
322
+
323
+ # ═══════════════════════════════════════════════════════════════
324
+ # RUN ALL CONFIGURATIONS
325
+ # ═══════════════════════════════════════════════════════════════
326
+
327
print(f"\n{'=' * 70}")
print(" FRESNEL CONDUIT — Spatial Conv Readout (Real Decomposition)")
print("=" * 70)

# channel codes -> (best test accuracy, parameter count, display name)
results = {}

# (channel codes for ConduitDataset, display name)
configs = [
    ('O', "S_orig (raw eigenvalues) — 4ch"),
    ('C', "S_coord (cross-attn coordinated) — 4ch"),
    ('F', "Friction (log1p, real decomp) — 4ch"),
    ('E', "Release error only — 1ch"),
    ('T', "Settle only — 4ch"),
    ('OF', "S_orig + Friction — 8ch"),
    ('OE', "S_orig + Release — 5ch"),
    ('OFE', "S_orig + Friction + Release — 9ch"),
    ('OFET', "FULL CONDUIT — 13ch"),
]

for channels, name in configs:
    print(f"\n{'─' * 70}")
    print(f" Training: {name}")
    print(f"{'─' * 70}")
    acc, params = train_and_eval(channels, name)
    results[channels] = (acc, params, name)
351
+
352
+
353
+ # ═══════════════════════════════════════════════════════════════
354
+ # SCOREBOARD
355
+ # ═══════════════════════════════════════════════════════════════
356
+
357
print(f"\n{'=' * 70}")
print(f" SCOREBOARD — Fresnel Conduit (Real Decomposition Telemetry)")
print("=" * 70)

print(f"\n {'Configuration':<40s} {'Ch':>4s} {'Params':>10s} {'Test Acc':>9s}")
print(f" {'-' * 66}")
print(f" {'Chance':<40s} {'—':>4s} {'—':>10s} {'10.0%':>9s}")

# Rows in ascending order of test accuracy.
for channels, (acc, params, name) in sorted(results.items(), key=lambda x: x[1][0]):
    # Channel count from the codes: O/C/F/T carry 4 channels, anything else 1.
    n_ch = sum(4 if c in 'OCFT' else 1 for c in channels)
    print(f" {name:<40s} {n_ch:>4d} {params:>10,d} {acc:>8.1%}")

# Hard-coded reference numbers from an earlier run, kept for comparison.
print(f"\n {'--- PREVIOUS (circular Gram, INVALID) ---'}")
print(f" {'Freckles friction conv (circular)':40s} {'4':>4s} {'232K':>10s} {'45.8%':>9s}")
print(f" {'Freckles S conv (circular)':40s} {'4':>4s} {'232K':>10s} {'20.9%':>9s}")

# Lift of the best configuration over the S_orig-only baseline.
o_acc = results.get('O', (0, 0, ''))[0]
best_ch, (best_acc, _, best_name) = max(results.items(), key=lambda x: x[1][0])
print(f"\n S_orig only: {o_acc:.1%}")
print(f" Best conduit: {best_acc:.1%} ({best_name})")
print(f" Conduit lift: {(best_acc - o_acc) * 100:+.1f}pp")
print(f"\n THIS IS THE REAL TEST.")
print(f" Friction from the actual decomposition, not reconstructed Gram matrices.")