AbstractPhil committed on
Commit
69769dc
·
verified ·
1 Parent(s): 1b9a2c2

Create cell_5_conduit_sweep.py

Browse files
Files changed (1) hide show
  1. cell_5_conduit_sweep.py +331 -0
cell_5_conduit_sweep.py ADDED
@@ -0,0 +1,331 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Cell 5 — Spatial Conv Readout on Conduit Maps
3
+ ===============================================
4
+ No pooling. No flattening. Conv reads the 16×16 spatial grid directly.
5
+
6
+ This is the CORRECT way to evaluate whether conduit signals carry
7
+ class-discriminative information. The linear probe was wrong —
8
+ it destroyed the spatial structure that IS the signal.
9
+
10
+ Channels on the 16×16 grid:
11
+ S values: 4 channels (eigenvalues per patch)
12
+ Friction: 4 channels (solver struggle per mode)
13
+ Release error: 1 channel (reconstruction fidelity per patch)
14
+ Settle: 4 channels (convergence speed per mode)
15
+
16
+ Test each signal alone and combined, all through conv readout.
17
+ """
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+ import numpy as np
23
+ import time
24
+ from tqdm import tqdm
25
+
26
# All model / conduit work in this script runs on the GPU.
device = torch.device('cuda')

# ═══════════════════════════════════════════════════════════════
# LOAD
# ═══════════════════════════════════════════════════════════════

print("Loading Freckles v40 + CIFAR-10...")
from geolip_svae import load_model
from geolip_svae.model import extract_patches
import torchvision
import torchvision.transforms as T
from geolip_core.linalg.conduit import FLEighConduit

freckles, cfg = load_model(hf_version='v40_freckles_noise', device=device)
freckles.eval()

# Side length every image is resized to; the patch grid below is derived from
# it so the transform and the grid cannot drift apart.
IMG_SIZE = 64

ps = freckles.patch_size
gh, gw = IMG_SIZE // ps, IMG_SIZE // ps   # patch-grid height / width
D = 4                                     # size of the last axis of S (modes per patch)

transform = T.Compose([T.Resize(IMG_SIZE), T.ToTensor()])
cifar_train = torchvision.datasets.CIFAR10(
    root='/content/data', train=True, download=True, transform=transform)
cifar_test = torchvision.datasets.CIFAR10(
    root='/content/data', train=False, download=True, transform=transform)

# These loaders only feed the one-off map extraction, so neither is shuffled;
# the training loop builds its own shuffled loader over the extracted maps.
train_loader = torch.utils.data.DataLoader(
    cifar_train, batch_size=128, shuffle=False, num_workers=4)
test_loader = torch.utils.data.DataLoader(
    cifar_test, batch_size=128, shuffle=False, num_workers=4)

conduit = FLEighConduit().to(device)

# Short display names, indexed by CIFAR-10 label id.
CLASSES = ['airplane', 'auto', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']
61
+
62
+
63
+ # ═══════════════════════════════════════════════════════════════
64
+ # PRECOMPUTE ALL CONDUIT MAPS
65
+ # ═══════════════════════════════════════════════════════════════
66
+
67
def extract_conduit_maps(loader, desc="Extracting"):
    """Run every batch of ``loader`` through Freckles and the conduit solver.

    Uses the module-level ``freckles``, ``conduit``, ``device``, ``ps``,
    ``gh``, ``gw`` and ``D``.

    Args:
        loader: iterable yielding ``(images, labels)`` batches.
        desc: label shown on the tqdm progress bar.

    Returns:
        dict of CPU tensors concatenated over all batches (``N`` = number of
        images seen):
            'S':        (N, gh, gw, D)  ``out['svd']['S']`` laid out on the grid
            'friction': (N, gh, gw, D)  ``packet.friction`` from the conduit
            'settle':   (N, gh, gw, D)  ``packet.settle`` from the conduit
            'error':    (N, gh, gw, 1)  per-patch reconstruction MSE
            'labels':   (N,)            labels exactly as yielded by the loader
    """
    # One list of per-batch chunks per output key; insertion order is the
    # order of the keys in the returned dict.
    chunks = {'S': [], 'friction': [], 'settle': [], 'error': [], 'labels': []}

    # Nothing here is trained, so keep autograd off for the whole pass; this
    # also guarantees the stored CPU tensors never hold on to a graph.
    with torch.no_grad():
        for images, labels in tqdm(loader, desc=desc):
            images_gpu = images.to(device)
            out = freckles(images_gpu)
            S = out['svd']['S']      # (B, N, D)
            Vt = out['svd']['Vt']    # (B, N, D, D)
            n_img, n_patch, _ = S.shape

            # Per-patch reconstruction error: squared difference averaged
            # over the last (within-patch) axis.
            inp_p, _, _ = extract_patches(images_gpu, ps)
            rec_p, _, _ = extract_patches(out['recon'], ps)
            patch_mse = (inp_p - rec_p).pow(2).mean(dim=-1)   # (B, N)

            # Per-patch Gram matrix G = V diag(S^2) V^T, flattened over
            # (image, patch) before being handed to the conduit.
            G = torch.einsum('bnij,bnj,bnjk->bnik',
                             Vt.transpose(-2, -1), S.pow(2), Vt)
            packet = conduit(G.reshape(n_img * n_patch, D, D))

            # Lay the flat patch axis back out on the (gh, gw) grid.
            chunks['S'].append(S.reshape(n_img, gh, gw, D).cpu())
            chunks['friction'].append(
                packet.friction.reshape(n_img, gh, gw, D).cpu())
            chunks['settle'].append(
                packet.settle.reshape(n_img, gh, gw, D).cpu())
            chunks['error'].append(patch_mse.reshape(n_img, gh, gw, 1).cpu())
            chunks['labels'].append(labels)

    return {key: torch.cat(parts) for key, parts in chunks.items()}
118
+
119
# Extract the conduit maps once up front; every configuration trained later
# reuses these cached CPU tensors instead of re-running the encoder.
print("\nPrecomputing train set...")
train_data = extract_conduit_maps(train_loader, "Train")
print(f" Train: {len(train_data['labels'])} images")

print("Precomputing test set...")
test_data = extract_conduit_maps(test_loader, "Test")
print(f" Test: {len(test_data['labels'])} images")
126
+
127
+
128
+ # ═══════════════════════════════════════════════════════════════
129
+ # CONV CLASSIFIER — reads spatial maps directly
130
+ # ═══════════════════════════════════════════════════════════════
131
+
132
class SpatialConvClassifier(nn.Module):
    """Small conv readout over spatial maps (written for 16×16 grids).

    Three 3×3 convolutions (the first two with stride 2) followed by a global
    average pool and a two-layer MLP head. The only pooling is the final
    adaptive average pool.

    Args:
        in_channels: number of channels in the input map.
        n_classes: number of output logits.
    """

    def __init__(self, in_channels, n_classes=10):
        super().__init__()
        self.conv = nn.Sequential(
            # stride 2 halves each spatial side (16→8 for a 16×16 input)
            nn.Conv2d(in_channels, 64, 3, stride=2, padding=1),
            nn.GELU(),
            # stride 2 again (8→4)
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.GELU(),
            # stride 1 keeps the spatial size (4→4)
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.GELU(),
            # collapse whatever spatial extent is left to 1×1
            nn.AdaptiveAvgPool2d(1),
        )
        self.head = nn.Sequential(
            nn.Linear(128, 64),
            nn.GELU(),
            nn.Linear(64, n_classes),
        )

    def forward(self, x):
        """Map ``x`` of shape (B, C, H, W) to logits of shape (B, n_classes)."""
        pooled = self.conv(x)                          # (B, 128, 1, 1)
        # Drop only the two trailing singleton axes; the batch axis is kept
        # even when B == 1.
        features = pooled.squeeze(-1).squeeze(-1)      # (B, 128)
        return self.head(features)
156
+
157
+
158
class ConduitDataset(torch.utils.data.Dataset):
    """Serves selected channels from precomputed conduit maps.

    Args:
        data: dict with channels-last tensors under 'S', 'friction', 'settle'
            and 'error' (each (N, gh, gw, C)) plus 'labels' (N,). Only the
            keys selected by ``channels`` (and 'labels') are read.
        channels: selector containing any of the letters 'S' (S values),
            'F' (friction), 'T' (settle), 'E' (error). Membership is tested
            with ``in``, so a string such as 'SFE' selects three groups.
        augment: when true, each fetched map is mirrored along its last
            (width) axis with probability 0.5.

    Attributes:
        maps: (N, total_C, gh, gw) tensor of the selected channels.
        n_channels: ``total_C``.

    Raises:
        ValueError: if ``channels`` selects no group at all.
    """

    # (selector letter, key in ``data``). This is also the order in which the
    # groups are concatenated: S, friction, settle, error, regardless of the
    # letter order inside ``channels``.
    _CHANNEL_KEYS = (
        ('S', 'S'),
        ('F', 'friction'),
        ('T', 'settle'),
        ('E', 'error'),
    )

    def __init__(self, data, channels='S', augment=False):
        self.labels = data['labels']
        self.augment = augment

        # (N, gh, gw, C) -> (N, C, gh, gw) for each selected group
        parts = [
            data[key].permute(0, 3, 1, 2)
            for letter, key in self._CHANNEL_KEYS
            if letter in channels
        ]
        if not parts:
            # Without this guard torch.cat fails on the empty list with a
            # message that says nothing about the bad selector.
            raise ValueError(
                f"channels={channels!r} selects nothing; "
                "use any of 'S', 'F', 'T', 'E'")

        self.maps = torch.cat(parts, dim=1)   # (N, total_C, gh, gw)
        self.n_channels = self.maps.shape[1]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        x = self.maps[idx]
        label = self.labels[idx]

        # Random horizontal flip; the random draw only happens when
        # augmentation is enabled.
        if self.augment and torch.rand(1).item() > 0.5:
            x = x.flip(-1)

        return x, label
191
+
192
+
193
+ # ═══════════════════════════════════════════════════════════════
194
+ # TRAINING LOOP
195
+ # ═══════════════════════════════════════════════════════════════
196
+
197
def train_and_eval(channels, name, epochs=30, batch_size=128, lr=3e-4):
    """Train a conv classifier on the selected conduit channels and report.

    Uses the module-level ``train_data``, ``test_data``, ``device`` and
    ``CLASSES``. Prints progress every 5 epochs and a per-class table at the
    end.

    Args:
        channels: channel selector passed to ``ConduitDataset``.
        name: human-readable configuration name used in the printed report.
        epochs: number of training epochs (must be >= 1).
        batch_size: batch size for both loaders.
        lr: initial Adam learning rate (cosine-annealed over ``epochs``).

    Returns:
        ``(best_acc, n_params)``: best test accuracy seen over all epochs
        and the classifier's parameter count.

    Raises:
        ValueError: if ``epochs`` is smaller than 1.
    """
    if epochs < 1:
        # With zero epochs the final report would reference counters that
        # were never created.
        raise ValueError(f"epochs must be >= 1, got {epochs}")

    n_classes = len(CLASSES)

    train_ds = ConduitDataset(train_data, channels, augment=True)
    test_ds = ConduitDataset(test_data, channels, augment=False)
    n_ch = train_ds.n_channels

    tr_loader = torch.utils.data.DataLoader(
        train_ds, batch_size=batch_size, shuffle=True,
        num_workers=4, pin_memory=True, drop_last=True)
    te_loader = torch.utils.data.DataLoader(
        test_ds, batch_size=batch_size, shuffle=False,
        num_workers=4, pin_memory=True)

    model = SpatialConvClassifier(n_ch, n_classes).to(device)
    n_params = sum(p.numel() for p in model.parameters())
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)

    best_acc = 0
    t0 = time.time()

    for epoch in range(1, epochs + 1):
        # ---- train ----
        model.train()
        correct, total = 0, 0
        for x, y in tr_loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = F.cross_entropy(logits, y)
            opt.zero_grad()
            loss.backward()
            opt.step()
            correct += (logits.argmax(-1) == y).sum().item()
            total += len(y)
        sched.step()
        train_acc = correct / total

        # ---- test ----
        # Per-class hit / sample counters, rebuilt every epoch, so after the
        # loop they describe the FINAL epoch (not necessarily the best one).
        model.eval()
        pcc = torch.zeros(n_classes, dtype=torch.long)   # correct per class
        pct = torch.zeros(n_classes, dtype=torch.long)   # samples per class
        with torch.no_grad():
            for x, y in te_loader:
                x, y = x.to(device), y.to(device)
                preds = model(x).argmax(-1)
                # bincount tallies all classes in one op instead of looping
                # over classes and syncing with the host for each of them.
                pct += torch.bincount(y, minlength=n_classes).cpu()
                pcc += torch.bincount(y[preds == y],
                                      minlength=n_classes).cpu()

        test_acc = pcc.sum().item() / pct.sum().item()
        if test_acc > best_acc:
            best_acc = test_acc

        if epoch % 5 == 0 or epoch == epochs:
            print(f" ep{epoch:3d} train={train_acc:.1%} test={test_acc:.1%}")

    elapsed = time.time() - t0
    # Epsilon keeps a class with no test samples from dividing by zero.
    pca = pcc.float() / (pct.float() + 1e-8)

    print(f"\n {name}")
    print(f" Channels: {n_ch}, Params: {n_params:,}, Time: {elapsed:.0f}s")
    print(f" Best test: {best_acc:.1%}")
    print(f"\n {'Class':<10s} {'Acc':>6s}")
    print(f" {'-' * 22}")
    for c in range(n_classes):
        bar = '█' * int(pca[c] * 20)   # 20 blocks == 100%
        print(f" {CLASSES[c]:<10s} {pca[c]:5.1%} {bar}")
    print()

    return best_acc, n_params
271
+
272
+
273
+ # ═══════════════════════════════════════════════════════════════
274
+ # RUN ALL CONFIGURATIONS
275
+ # ═══════════════════════════════════════════════════════════════
276
+
277
print("\n" + "=" * 70)
print(" SPATIAL CONV READOUT — All conduit configurations")
print("=" * 70)

# channel selector -> (best test accuracy, parameter count, display name)
results = {}

# (channel selector, display name). Selector letters: S = S values,
# F = friction, E = release error, T = settle.
configs = [
    ('S', "Eigenvalues (S) only — 4ch"),
    ('F', "Friction only — 4ch"),
    ('E', "Release error only — 1ch"),
    ('T', "Settle only — 4ch"),
    ('SF', "S + Friction — 8ch"),
    ('SE', "S + Release error — 5ch"),
    ('SFE', "S + Friction + Release — 9ch"),
    ('SFET', "FULL CONDUIT — 13ch"),
]

# Train one fresh classifier per configuration with the default
# hyper-parameters of train_and_eval.
for channels, name in configs:
    print(f"\n{'─' * 70}")
    print(f" Training: {name}")
    print(f"{'─' * 70}")
    acc, params = train_and_eval(channels, name)
    results[channels] = (acc, params, name)
300
+
301
+
302
+ # ═══════════════════════════════════════════════════════════════
303
+ # SCOREBOARD
304
+ # ═══════════════════════════════════════════════════════════════
305
+
306
print(f"\n{'=' * 70}")
print(" SCOREBOARD — Spatial Conv Readout")
print("=" * 70)

print(f"\n {'Configuration':<35s} {'Channels':>8s} {'Params':>10s} {'Test Acc':>9s}")
print(f" {'-' * 65}")
print(f" {'Chance':<35s} {'—':>8s} {'—':>10s} {'10.0%':>9s}")

# One row per trained configuration, worst accuracy first.
for channels, (acc, params, name) in sorted(results.items(), key=lambda x: x[1][0]):
    print(f" {name:<35s} {channels:>8s} {params:>10,d} {acc:>8.1%}")

# Reference results from earlier experiments (hard-coded, not recomputed here)
print(f"\n {'--- REFERENCE (from earlier) ---':<35s}")
print(f" {'Linear probe (friction flat)':<35s} {'—':>8s} {'—':>10s} {'24.3%':>9s}")
print(f" {'Linear probe (S flat)':<35s} {'—':>8s} {'—':>10s} {'21.0%':>9s}")
print(f" {'Patchwork + calibrated embeds':<35s} {'—':>8s} {'530K':>10s} {'48.0%':>9s}")
print(f" {'Scatter + conv (raw S)':<35s} {'—':>8s} {'2.9M':>10s} {'70.5%':>9s}")
print(f" {'CNN condensed (SGD)':<35s} {'—':>8s} {'730K':>10s} {'74.7%':>9s}")

# Lift analysis: best configuration versus the S-only baseline and versus the
# 70.5% "Scatter + conv (raw S)" reference row above.
s_acc = results.get('S', (0, 0, ''))[0]
best_channels, (best_acc, _, best_name) = max(results.items(), key=lambda x: x[1][0])
print(f"\n S-only conv: {s_acc:.1%}")
print(f" Best conduit: {best_acc:.1%} ({best_name})")
print(f" Conduit lift: {(best_acc - s_acc) * 100:+.1f}pp")
print(f" vs scatter+conv reference: {(best_acc - 0.705) * 100:+.1f}pp")