AbstractPhil committed on
Commit
651f747
·
verified ·
1 Parent(s): 69769dc

Create cell_6_fresnel_freckles.py

Browse files
Files changed (1) hide show
  1. cell_6_fresnel_freckles.py +326 -0
cell_6_fresnel_freckles.py ADDED
@@ -0,0 +1,326 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Cell 6 — Fresnel Spatial Conv Readout
3
+ =======================================
4
+ Fresnel v50 — trained on CLEAN ImageNet-64. No noise.
5
+ The SVD learned real structural decomposition.
6
+
7
+ Same 8 conduit configurations through conv on the 16×16 grid.
8
+ No pooling or flattening before the conv stack (the classifier only global-average-pools after its last conv). Spatial readout respects geometric structure.
9
+
10
+ CRITICAL DIFFERENCE FROM FRECKLES:
11
+ Freckles learned noise reconstruction features.
12
+ Fresnel learned clean image structural decomposition.
13
+ The SVD elements from Fresnel actually encode learned relational behavior.
14
+ """
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+ import numpy as np
20
+ import time
21
+ from tqdm import tqdm
22
+
23
# Prefer the GPU, but fall back to CPU so the script still starts on a host
# without CUDA instead of failing at the first `.to(device)` call.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type != 'cuda':
    print(" CUDA not available; running on CPU")

# ═══════════════════════════════════════════════════════════════
# CONFIG — Set the Fresnel version here
# ═══════════════════════════════════════════════════════════════

FRESNEL_VERSION = 'v50_fresnel_64'  # adjust if different checkpoint
IMG_SIZE = 64

# ═══════════════════════════════════════════════════════════════
# LOAD FRESNEL + CIFAR-10
# ═══════════════════════════════════════════════════════════════

print(f"Loading Fresnel ({FRESNEL_VERSION}) + CIFAR-10...")
from geolip_svae import load_model
from geolip_svae.model import extract_patches
import torchvision
import torchvision.transforms as T
from geolip_core.linalg.conduit import FLEighConduit

fresnel, cfg = load_model(hf_version=FRESNEL_VERSION, device=device)
fresnel.eval()

# Patch grid implied by the model's patch size at the working resolution.
ps = fresnel.patch_size
gh, gw = IMG_SIZE // ps, IMG_SIZE // ps
# D is read from the config when it is a dict; otherwise it defaults to 4.
D = cfg.get('D', 4) if isinstance(cfg, dict) else 4

print(f" Patch size: {ps}, Grid: {gh}x{gw}, D={D}")
print(f" Params: {sum(p.numel() for p in fresnel.parameters()):,}")

# CIFAR-10 images are resized to IMG_SIZE before being fed to the model.
transform = T.Compose([T.Resize(IMG_SIZE), T.ToTensor()])
cifar_train = torchvision.datasets.CIFAR10(
    root='/content/data', train=True, download=True, transform=transform)
cifar_test = torchvision.datasets.CIFAR10(
    root='/content/data', train=False, download=True, transform=transform)

# shuffle=False on both loaders: these are only used for feature extraction,
# so the extracted maps keep the dataset order.
train_loader = torch.utils.data.DataLoader(
    cifar_train, batch_size=128, shuffle=False, num_workers=4)
test_loader = torch.utils.data.DataLoader(
    cifar_test, batch_size=128, shuffle=False, num_workers=4)

conduit = FLEighConduit().to(device)

CLASSES = ['airplane', 'auto', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']

# Quick S statistics from Fresnel
print("\nFresnel S-value profile on CIFAR-10 sample:")
with torch.no_grad():
    sample = next(iter(test_loader))[0][:16].to(device)
    out = fresnel(sample)
    S = out['svd']['S']
    # Reduce over the first two axes, leaving one value per trailing component.
    print(f" S mean: {S.mean(dim=(0,1)).tolist()}")
    print(f" S std: {S.std(dim=(0,1)).tolist()}")
    print(f" MSE: {F.mse_loss(out['recon'], sample):.6f}")
78
+
79
+
80
+ # ═══════════════════════════════════════════════════════════════
81
+ # PRECOMPUTE ALL CONDUIT MAPS
82
+ # ═══════════════════════════════════════════════════════════════
83
+
84
def extract_conduit_maps(loader, desc="Extracting"):
    """Run Fresnel over ``loader`` and collect per-patch maps on the CPU.

    For every batch the model output is turned into four grids of shape
    ``(batch, gh, gw, C)``: the ``S`` values, the conduit packet's
    ``friction`` and ``settle`` attributes (``C == D``), and the per-patch
    reconstruction MSE (``C == 1``).  Relies on the module-level ``fresnel``,
    ``conduit``, ``device``, ``ps``, ``gh``, ``gw`` and ``D``.

    Args:
        loader: iterable yielding ``(images, labels)`` batches.
        desc: label shown on the tqdm progress bar.

    Returns:
        dict with keys ``'S'``, ``'friction'``, ``'settle'``, ``'error'`` and
        ``'labels'``, each the concatenation of the per-batch tensors.
    """
    collected = {'S': [], 'friction': [], 'settle': [], 'error': [], 'labels': []}

    with torch.no_grad():
        for images, labels in tqdm(loader, desc=desc):
            batch = images.to(device)
            result = fresnel(batch)
            svd = result['svd']
            sing, vt = svd['S'], svd['Vt']
            n_img, n_patch, _ = sing.shape

            # Per-patch recon error
            src_patches, _, _ = extract_patches(batch, ps)
            rec_patches, _, _ = extract_patches(result['recon'], ps)
            err = (src_patches - rec_patches).pow(2).mean(dim=-1)

            # Gram matrices for conduit: V · diag(S²) · Vᵀ for every patch.
            gram = torch.einsum('bnij,bnj,bnjk->bnik',
                                vt.transpose(-2, -1), sing.pow(2), vt)
            # NOTE(review): semantics of packet.friction / packet.settle are
            # defined by FLEighConduit, not visible here.
            packet = conduit(gram.reshape(n_img * n_patch, D, D))

            grid = (n_img, gh, gw)
            collected['S'].append(sing.reshape(*grid, D).cpu())
            collected['friction'].append(packet.friction.reshape(*grid, D).cpu())
            collected['settle'].append(packet.settle.reshape(*grid, D).cpu())
            collected['error'].append(err.reshape(*grid, 1).cpu())
            collected['labels'].append(labels)

    return {key: torch.cat(chunks) for key, chunks in collected.items()}
125
+
126
# Extract the conduit maps once; every classifier configuration below reuses them.
print("\nPrecomputing train set...")
train_data = extract_conduit_maps(train_loader, "Train")
n_train_images = len(train_data['labels'])
print(f" Train: {n_train_images} images")

print("Precomputing test set...")
test_data = extract_conduit_maps(test_loader, "Test")
n_test_images = len(test_data['labels'])
print(f" Test: {n_test_images} images")

# Signal profile
print("\n Fresnel signal profile:")
for key in ('S', 'friction', 'settle', 'error'):
    signal = train_data[key]
    # One row per image; the statistics below are taken over all elements.
    flat = signal.reshape(signal.shape[0], -1)
    avg, spread = flat.mean(), flat.std()
    lo, hi = flat.min(), flat.max()
    print(f" {key:10s}: mean={avg:.4f} std={spread:.4f} "
          f"min={lo:.4f} max={hi:.4f}")
141
+
142
+
143
+ # ═══════════════════════════════════════════════════════════════
144
+ # CONV CLASSIFIER
145
+ # ═══════════════════════════════════════════════════════════════
146
+
147
class SpatialConvClassifier(nn.Module):
    """Small conv net that classifies a multi-channel spatial map.

    Three 3x3 convolutions (strides 2, 2, 1) with GELU activations are
    followed by a global average pool, then a two-layer MLP head.

    Args:
        in_channels: number of channels in the input map.
        n_classes: number of output logits.
    """

    def __init__(self, in_channels, n_classes=10):
        super().__init__()
        feature_layers = [
            nn.Conv2d(in_channels, 64, kernel_size=3, stride=2, padding=1),
            nn.GELU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.GELU(),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.GELU(),
            nn.AdaptiveAvgPool2d(1),
        ]
        self.conv = nn.Sequential(*feature_layers)

        head_layers = [
            nn.Linear(128, 64),
            nn.GELU(),
            nn.Linear(64, n_classes),
        ]
        self.head = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return class logits for the input map ``x``."""
        pooled = self.conv(x)
        # The pool leaves two trailing singleton axes; index them away.
        features = pooled[..., 0, 0]
        return self.head(features)
168
+
169
+
170
class ConduitDataset(torch.utils.data.Dataset):
    """Dataset over precomputed conduit maps, stacked channel-first.

    ``channels`` is a string of selector letters: ``S`` (``data['S']``),
    ``F`` (``data['friction']``), ``T`` (``data['settle']``) and ``E``
    (``data['error']``).  Selected maps are always concatenated in the order
    S, F, T, E, regardless of the letter order in ``channels``.

    Args:
        data: dict holding the map tensors (channels-last, 4-D) and ``'labels'``.
        channels: selector letters choosing which maps to include.
        augment: when true, each item is flipped along its last axis with
            probability one half.
    """

    def __init__(self, data, channels='S', augment=False):
        self.labels = data['labels']
        self.augment = augment

        # (selector letter, key in `data`) in stacking order.
        sources = (('S', 'S'), ('F', 'friction'), ('T', 'settle'), ('E', 'error'))
        stacked = [
            data[key].permute(0, 3, 1, 2)  # channels-last -> channels-first
            for letter, key in sources
            if letter in channels
        ]
        self.maps = torch.cat(stacked, dim=1)
        self.n_channels = self.maps.shape[1]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        sample = self.maps[idx]
        # The random draw only happens when augmentation is enabled.
        do_flip = self.augment and torch.rand(1).item() > 0.5
        if do_flip:
            sample = sample.flip(-1)
        return sample, self.labels[idx]
194
+
195
+
196
+ # ═══════════════════════════════════════════════════════════════
197
+ # TRAINING
198
+ # ═══════════════════════════════════════════════════════════════
199
+
200
def train_and_eval(channels, name, epochs=30, batch_size=128, lr=3e-4):
    """Train a SpatialConvClassifier on the chosen conduit channels and report it.

    Builds train/test ``ConduitDataset`` views over the module-level
    ``train_data`` / ``test_data``, trains with Adam and a cosine LR schedule,
    evaluates on the test set after every epoch, and prints a summary.

    Note that "best" accuracy is the maximum test accuracy over all epochs
    (i.e. selected on the test set itself), while the per-class table that is
    printed reflects the final epoch only, because the per-class counters are
    reset at the start of each evaluation pass.

    Args:
        channels: selector letters passed to ``ConduitDataset``.
        name: human-readable configuration name used in the printed report.
        epochs: number of training epochs.
        batch_size: batch size for both loaders.
        lr: Adam learning rate.

    Returns:
        ``(best_acc, n_params)``: best test accuracy seen and the classifier's
        parameter count.
    """
    train_ds = ConduitDataset(train_data, channels, augment=True)
    test_ds = ConduitDataset(test_data, channels, augment=False)
    n_ch = train_ds.n_channels
    n_classes = len(CLASSES)

    tr_loader = torch.utils.data.DataLoader(
        train_ds, batch_size=batch_size, shuffle=True,
        num_workers=4, pin_memory=True, drop_last=True)
    te_loader = torch.utils.data.DataLoader(
        test_ds, batch_size=batch_size, shuffle=False,
        num_workers=4, pin_memory=True)

    model = SpatialConvClassifier(n_ch, n_classes).to(device)
    n_params = sum(p.numel() for p in model.parameters())
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)

    best_acc = 0
    # Defined up front so the report below still works when epochs == 0.
    pcc = torch.zeros(n_classes)  # per-class correct predictions
    pct = torch.zeros(n_classes)  # per-class sample counts
    t0 = time.time()

    for epoch in range(1, epochs + 1):
        model.train()
        correct, total = 0, 0
        for x, y in tr_loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = F.cross_entropy(logits, y)
            opt.zero_grad()
            loss.backward()
            opt.step()
            correct += (logits.argmax(-1) == y).sum().item()
            total += len(y)
        sched.step()
        # drop_last=True can leave zero batches for a tiny train set.
        train_acc = correct / max(total, 1)

        model.eval()
        tc, tt = 0, 0
        pcc = torch.zeros(n_classes)
        pct = torch.zeros(n_classes)
        with torch.no_grad():
            for x, y in te_loader:
                x, y = x.to(device), y.to(device)
                preds = model(x).argmax(-1)
                hit = preds == y
                tc += hit.sum().item()
                tt += len(y)
                # One bincount per tally instead of a Python loop over classes
                # with a device sync on every class.
                seen = torch.bincount(y, minlength=n_classes)[:n_classes]
                right = torch.bincount(y[hit], minlength=n_classes)[:n_classes]
                pct += seen.cpu().float()
                pcc += right.cpu().float()

        test_acc = tc / max(tt, 1)
        if test_acc > best_acc:
            best_acc = test_acc

        if epoch % 5 == 0 or epoch == epochs:
            print(f" ep{epoch:3d} train={train_acc:.1%} test={test_acc:.1%}")

    elapsed = time.time() - t0
    # Epsilon keeps classes with no test samples at 0 instead of NaN.
    pca = pcc / (pct + 1e-8)

    print(f"\n {name}")
    print(f" Channels: {n_ch}, Params: {n_params:,}, Time: {elapsed:.0f}s")
    print(f" Best test: {best_acc:.1%}")
    print(f"\n {'Class':<10s} {'Acc':>6s}")
    print(f" {'-' * 22}")
    for c in range(n_classes):
        bar = '█' * int(pca[c] * 20)
        print(f" {CLASSES[c]:<10s} {pca[c]:5.1%} {bar}")
    print()
    return best_acc, n_params
270
+
271
+
272
+ # ═══════════════════════════════════════════════════════════════
273
+ # RUN ALL CONFIGURATIONS
274
+ # ═══════════════════════════════════════════════════════════════
275
+
276
# Train one classifier per channel combination and keep its score.
print("\n" + "=" * 70)
print(" FRESNEL — Spatial Conv Readout — All Conduit Configurations")
print("=" * 70)

# channels string -> (best test accuracy, parameter count, display name)
results = {}

# (channel selector letters, display name); letters are interpreted by
# ConduitDataset: S, F (friction), T (settle), E (recon error).
configs = [
    ('S', "Eigenvalues (S) only — 4ch"),
    ('F', "Friction only — 4ch"),
    ('E', "Release error only — 1ch"),
    ('T', "Settle only — 4ch"),
    ('SF', "S + Friction — 8ch"),
    ('SE', "S + Release error — 5ch"),
    ('SFE', "S + Friction + Release — 9ch"),
    ('SFET', "FULL CONDUIT — 13ch"),
]

for channels, name in configs:
    print(f"\n{'─' * 70}")
    print(f" Training: {name}")
    print(f"{'─' * 70}")
    acc, params = train_and_eval(channels, name)
    results[channels] = (acc, params, name)
299
+
300
+
301
+ # ═══════════════════════════════════════════════════════════════
302
+ # SCOREBOARD
303
+ # ═══════════════════════════════════════════════════════════════
304
+
305
# Summary table of every configuration, worst to best.
print(f"\n{'=' * 70}")
print(f" SCOREBOARD — Fresnel ({FRESNEL_VERSION}) Spatial Conv Readout")
print("=" * 70)

print(f"\n {'Configuration':<35s} {'Ch':>4s} {'Params':>10s} {'Test Acc':>9s}")
print(f" {'-' * 62}")
print(f" {'Chance':<35s} {'—':>4s} {'—':>10s} {'10.0%':>9s}")

# Ascending by accuracy, so the best configuration is printed last.
for channels, (acc, params, name) in sorted(results.items(), key=lambda x: x[1][0]):
    # S, F and T maps carry D channels each; the error map carries one.
    n_ch = sum(D if c in 'SFT' else 1 for c in channels)
    print(f" {name:<35s} {n_ch:>4d} {params:>10,d} {acc:>8.1%}")

# Hard-coded reference numbers from the earlier Freckles experiment.
print(f"\n {'--- FRECKLES REFERENCE ---':<35s}")
print(f" {'Scatter + conv (Freckles S)':<35s} {'4':>4s} {'2.9M':>10s} {'70.5%':>9s}")

s_acc = results.get('S', (0, 0, ''))[0]
_, (best_acc, _, best_name) = max(results.items(), key=lambda x: x[1][0])
print(f"\n Fresnel S-only: {s_acc:.1%}")
print(f" Best conduit: {best_acc:.1%} ({best_name})")
# Lift of the best configuration over S-only, in percentage points.
print(f" Conduit lift: {(best_acc - s_acc) * 100:+.1f}pp")
print("\n KEY QUESTION: Does Fresnel's clean training produce")
print(" conduit signals that Freckles' noise training could not?")