AbstractPhil committed on
Commit
386700f
·
verified ·
1 Parent(s): e50ad6c

Create constellation_a_test_trainer_cifar10.py

Browse files
constellation_a_test_trainer_cifar10.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ GeoLIP Core — Back to Basics
4
+ ==============================
5
+ Conv encoder → sphere → ConstellationCore → classifier.
6
+
7
+ Two augmented views → InfoNCE + CE + attract + CV + spread.
8
+ Anchor push every N batches (self-distillation across time).
9
+
10
+ Uses constellation.py for all geometric components.
11
+ """
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ import os, time
17
+ from tqdm import tqdm
18
+ from torchvision import datasets, transforms
19
+ from torch.utils.tensorboard import SummaryWriter
20
+
21
# Run on the GPU when one is visible to PyTorch, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Allow TF32 kernels for CUDA matmuls and cuDNN convolutions. These flags are
# plain global switches; they have no effect when running on CPU.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
24
+
25
+
26
+ # ══════════════════════════════════════════════════════════════════
27
+ # CONV ENCODER
28
+ # ══════════════════════════════════════════════════════════════════
29
+
30
class ConvEncoder(nn.Module):
    """Six-layer convolutional backbone followed by a linear projection.

    The backbone is three stages of two ``Conv2d -> BatchNorm2d -> GELU``
    units (64, 128 and 256 channels). Every stage ends with a 2x2 max-pool.
    The final feature map is global-average-pooled and flattened to 256
    features, then projected to ``output_dim`` and layer-normalized.

    Args:
        output_dim: Size of the feature vector returned by ``forward``.
        in_channels: Number of channels of the input images. Defaults to 3,
            which is what the original hard-coded.
    """

    def __init__(self, output_dim=192, in_channels=3):
        super().__init__()
        unit = self._unit
        # The flat layer order is kept exactly as before so the indices used
        # in ``state_dict`` keys (``features.0``, ``features.1``, ...) do not
        # change and existing checkpoints keep loading.
        self.features = nn.Sequential(
            *unit(in_channels, 64),
            *unit(64, 64),
            nn.MaxPool2d(2),

            *unit(64, 128),
            *unit(128, 128),
            nn.MaxPool2d(2),

            *unit(128, 256),
            *unit(256, 256),
            nn.MaxPool2d(2),

            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
        self.proj = nn.Sequential(
            nn.Linear(256, output_dim),
            nn.LayerNorm(output_dim),
        )

    @staticmethod
    def _unit(c_in, c_out):
        """Return the layers of one 3x3 ``Conv2d -> BatchNorm2d -> GELU`` unit."""
        return [
            nn.Conv2d(c_in, c_out, 3, padding=1),
            nn.BatchNorm2d(c_out),
            nn.GELU(),
        ]

    def forward(self, x):
        """Encode a batch of images into ``output_dim``-sized feature vectors."""
        return self.proj(self.features(x))
57
+
58
+
59
+ # ══════════════════════════════════════════════════════════════════
60
+ # GEOLIP CORE — encoder + constellation pipeline
61
+ # ══════════════════════════════════════════════════════════════════
62
+
63
class GeoLIPCore(nn.Module):
    """Conv encoder -> L2 normalize -> ConstellationCore.

    The encoder is the only component that sees pixels.
    Everything after normalization is geometric.

    All constructor arguments except ``output_dim`` (which also sizes the
    encoder) are forwarded unchanged to ``ConstellationCore``; their meaning
    is defined there. The arguments are also stored in ``self.config`` so a
    checkpoint can record how the model was built.

    NOTE(review): ``ConstellationCore`` is not imported anywhere in the
    visible part of this file. The module docstring says it comes from
    ``constellation.py``; it must be in scope before this class is
    instantiated.
    """

    def __init__(self, num_classes=10, output_dim=192,
                 n_anchors=64, n_comp=8, d_comp=64,
                 anchor_drop=0.15, activation='squared_relu',
                 cv_target=0.22, infonce_temp=0.07):
        super().__init__()
        self.output_dim = output_dim

        # Spelled out explicitly instead of filtering ``locals()`` inside a
        # comprehension: the captured keys no longer depend on how the
        # interpreter scopes comprehensions or on which locals happen to
        # exist at this point. Keys and order match the signature.
        self.config = {
            'num_classes': num_classes,
            'output_dim': output_dim,
            'n_anchors': n_anchors,
            'n_comp': n_comp,
            'd_comp': d_comp,
            'anchor_drop': anchor_drop,
            'activation': activation,
            'cv_target': cv_target,
            'infonce_temp': infonce_temp,
        }

        self.encoder = ConvEncoder(output_dim)
        self.core = ConstellationCore(
            num_classes=num_classes,
            dim=output_dim,
            n_anchors=n_anchors,
            n_comp=n_comp,
            d_comp=d_comp,
            anchor_drop=anchor_drop,
            activation=activation,
            cv_target=cv_target,
            infonce_temp=infonce_temp,
        )

        self._init_encoder_weights()

    def _init_encoder_weights(self):
        """Re-initialize the encoder's parameters in place.

        Linear weights get a truncated normal (std 0.02), conv weights get
        Kaiming-normal with ``fan_out``, biases are zeroed, and BatchNorm /
        LayerNorm are reset to weight 1 and bias 0. Only ``self.encoder`` is
        touched; ``self.core`` keeps whatever initialization it gave itself.
        """
        for m in self.encoder.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.LayerNorm)):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        """Encode ``x``, project it onto the unit sphere and run the core.

        Returns whatever ``self.core`` returns. The training loop in this
        file reads the ``'embedding'`` and ``'logits'`` entries from it.
        """
        feat = self.encoder(x)
        emb = F.normalize(feat, dim=-1)
        return self.core(emb)

    def compute_loss(self, output, targets, output_aug=None):
        """Delegate to ``self.core.compute_loss`` with the same arguments."""
        return self.core.compute_loss(output, targets, output_aug)

    def push_anchors_to_centroids(self, emb_buffer, label_buffer, lr=0.1):
        """Delegate to ``self.core.push_anchors_to_centroids``."""
        return self.core.push_anchors_to_centroids(emb_buffer, label_buffer, lr)
118
+
119
+
120
+ # ══════════════════════════════════════════════════════════════════
121
+ # DATA
122
+ # ══════════════════════════════════════════════════════════════════
123
+
124
# Per-channel mean / std passed to ``transforms.Normalize`` for both the
# training and the validation pipeline.
CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR_STD = (0.2470, 0.2435, 0.2616)
126
+
127
+
128
class TwoViewDataset(torch.utils.data.Dataset):
    """Wrap a labelled dataset so every item yields two transformed views.

    Args:
        base_ds: Indexable dataset whose items are ``(img, label)`` pairs.
        transform: Callable applied to ``img``. It is called twice per item,
            so a random transform produces two different views.
    """

    def __init__(self, base_ds, transform):
        self.base = base_ds
        self.transform = transform

    def __len__(self):
        return len(self.base)

    def __getitem__(self, i):
        """Return ``(view_1, view_2, label)`` for item ``i``."""
        img, label = self.base[i]
        # Two separate calls on the same source image, first view first.
        view_1 = self.transform(img)
        view_2 = self.transform(img)
        return view_1, view_2, label
137
+
138
+
139
+ # ══════════════════════════════════════════════════════════════════
140
+ # TRAINING
141
+ # ══════════════════════════════════════════════════════════════════
142
+
143
# Config
NUM_CLASSES = 10
OUTPUT_DIM = 256
N_ANCHORS = 64
N_COMP = 8
D_COMP = 64
BATCH = 256
EPOCHS = 100
LR = 3e-4
ACTIVATION = 'squared_relu'

# Anchor push config
PUSH_INTERVAL = 100        # optimizer steps between two anchor pushes
PUSH_LR = 0.1              # ``lr`` passed to ``push_anchors_to_centroids``
PUSH_BUFFER_SIZE = 5000    # most recent embeddings kept for the push


# The run itself is guarded: the DataLoaders below use worker processes, and
# under the "spawn" start method each worker re-imports this module. Without
# the guard every worker would start the whole training script again.
# Everything defined inside the guard is still a module-level global when the
# file is executed as a script or in a notebook.
#
# NOTE(review): ``GeometricOps`` (used for the per-epoch diagnostics) is not
# imported in the visible part of this file; like ``ConstellationCore`` it is
# expected to come from ``constellation.py`` and must be in scope.
if __name__ == "__main__":
    print("=" * 60)
    print("GeoLIP Core — Conv + ConstellationCore")
    print(f"  Encoder: 6-layer conv → {OUTPUT_DIM}-d sphere")
    print(f"  Constellation: {N_ANCHORS} anchors, {N_COMP}×{D_COMP} patchwork")
    print(f"  Activation: {ACTIVATION}")
    print("  Loss: CE + InfoNCE + attract + CV(0.22) + spread")
    print(f"  Batch: {BATCH}, LR: {LR}, Epochs: {EPOCHS}")
    print(f"  Push: every {PUSH_INTERVAL} batches, lr={PUSH_LR}")
    print(f"  Device: {DEVICE}")
    print("=" * 60)

    # ── Data ──────────────────────────────────────────────────────
    aug_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(0.2, 0.2, 0.2, 0.05),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
    ])
    val_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
    ])

    # The raw training set carries no transform; TwoViewDataset applies the
    # augmentation twice per image to build the two views.
    raw_train = datasets.CIFAR10(root='./data', train=True, download=True)
    train_ds = TwoViewDataset(raw_train, aug_transform)
    val_ds = datasets.CIFAR10(root='./data', train=False,
                              download=True, transform=val_transform)

    train_loader = torch.utils.data.DataLoader(
        train_ds, batch_size=BATCH, shuffle=True,
        num_workers=2, pin_memory=True, drop_last=True)
    val_loader = torch.utils.data.DataLoader(
        val_ds, batch_size=BATCH, shuffle=False,
        num_workers=2, pin_memory=True)

    print(f"  Train: {len(train_ds):,}  Val: {len(val_ds):,}")

    # ── Build ─────────────────────────────────────────────────────
    model = GeoLIPCore(
        num_classes=NUM_CLASSES, output_dim=OUTPUT_DIM,
        n_anchors=N_ANCHORS, n_comp=N_COMP, d_comp=D_COMP,
        activation=ACTIVATION,
    ).to(DEVICE)

    n_params = sum(p.numel() for p in model.parameters())
    n_enc = sum(p.numel() for p in model.encoder.parameters())
    n_core = sum(p.numel() for p in model.core.parameters())
    print(f"  Parameters: {n_params:,} (encoder: {n_enc:,}, core: {n_core:,})")

    # The scheduler is stepped once per batch: 3 epochs of linear warm-up,
    # then cosine decay over the remaining steps.
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    total_steps = len(train_loader) * EPOCHS
    warmup_steps = len(train_loader) * 3
    scheduler = torch.optim.lr_scheduler.SequentialLR(
        optimizer,
        [torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=0.01, total_iters=warmup_steps),
         torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=max(total_steps - warmup_steps, 1), eta_min=1e-6)],
        milestones=[warmup_steps])

    # Mixed precision is CUDA-only here. Passing ``enabled`` explicitly keeps
    # the CPU fallback in plain fp32 without the "CUDA is not available"
    # warnings that the hard-coded "cuda" arguments otherwise trigger.
    use_amp = DEVICE == "cuda"
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp)
    os.makedirs("checkpoints", exist_ok=True)
    writer = SummaryWriter("runs/geolip_core")
    best_acc = 0.0
    gs = 0  # global optimizer-step counter across all epochs

    # Rolling window of recent training embeddings / labels for anchor pushes.
    emb_buffer = None
    lbl_buffer = None
    push_count = 0

    print(f"\n{'='*60}")
    print(f"TRAINING — {EPOCHS} epochs")
    print(f"{'='*60}")

    for epoch in range(EPOCHS):
        model.train()
        t0 = time.time()
        tot_loss, tot_nce_acc, tot_nearest_cos, n = 0, 0, 0, 0
        correct, total = 0, 0

        pbar = tqdm(train_loader, desc=f"E{epoch+1:3d}/{EPOCHS}", unit="b")
        for v1, v2, targets in pbar:
            v1 = v1.to(DEVICE, non_blocking=True)
            v2 = v2.to(DEVICE, non_blocking=True)
            targets = targets.to(DEVICE, non_blocking=True)

            with torch.amp.autocast("cuda", dtype=torch.bfloat16, enabled=use_amp):
                out1 = model(v1)
                out2 = model(v2)
                loss, ld = model.compute_loss(out1, targets, output_aug=out2)

            optimizer.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()
            # Unscale before clipping so the norm threshold applies to the
            # true gradients.
            scaler.unscale_(optimizer)
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            gs += 1

            # Accumulate embeddings for anchor push
            with torch.no_grad():
                batch_emb = out1['embedding'].detach().float()
                if emb_buffer is None:
                    emb_buffer = batch_emb
                    lbl_buffer = targets.detach()
                else:
                    # Keep only the newest PUSH_BUFFER_SIZE rows.
                    emb_buffer = torch.cat([emb_buffer, batch_emb])[-PUSH_BUFFER_SIZE:]
                    lbl_buffer = torch.cat([lbl_buffer, targets.detach()])[-PUSH_BUFFER_SIZE:]

            # Periodic anchor push
            if gs % PUSH_INTERVAL == 0 and emb_buffer is not None and emb_buffer.shape[0] > 500:
                # The original nesting was lost in extraction; running the
                # push under no_grad is safe whether or not the callee
                # disables gradients itself.
                with torch.no_grad():
                    moved = model.push_anchors_to_centroids(
                        emb_buffer, lbl_buffer, lr=PUSH_LR)
                push_count += 1
                writer.add_scalar("step/anchors_moved", moved, gs)

            preds = out1['logits'].argmax(-1)
            correct += (preds == targets).sum().item()
            total += targets.shape[0]
            tot_loss += loss.item()
            tot_nce_acc += ld.get('nce_acc', 0)
            tot_nearest_cos += ld.get('nearest_cos', 0)
            n += 1

            if n % 10 == 0:
                # No ``ordered=True`` here: tqdm has no such option, it would
                # be taken as one more stat and printed as "ordered=True".
                pbar.set_postfix(
                    loss=f"{tot_loss/n:.4f}",
                    acc=f"{100*correct/total:.0f}%",
                    nce=f"{tot_nce_acc/n:.2f}",
                    cos=f"{ld.get('nearest_cos', 0):.3f}",
                    push=push_count)

        elapsed = time.time() - t0
        train_acc = 100 * correct / total

        # Val
        model.eval()
        vc, vt_n = 0, 0
        all_embs = []
        with torch.no_grad(), torch.amp.autocast(
                "cuda", dtype=torch.bfloat16, enabled=use_amp):
            for imgs, lbls in val_loader:
                imgs = imgs.to(DEVICE)
                lbls = lbls.to(DEVICE)
                out = model(imgs)
                vc += (out['logits'].argmax(-1) == lbls).sum().item()
                vt_n += lbls.shape[0]
                all_embs.append(out['embedding'].float().cpu())

        val_acc = 100 * vc / vt_n

        # CV metric (computed on the first 2000 validation embeddings)
        embs = torch.cat(all_embs)[:2000].to(DEVICE)
        with torch.no_grad():
            v_cv = GeometricOps.cv_metric(embs, n_samples=200)
            diag = GeometricOps.diagnostics(model.core.constellation, embs)

        writer.add_scalar("epoch/train_acc", train_acc, epoch + 1)
        writer.add_scalar("epoch/val_acc", val_acc, epoch + 1)
        writer.add_scalar("epoch/val_cv", v_cv, epoch + 1)
        writer.add_scalar("epoch/anchors", diag['n_active'], epoch + 1)
        writer.add_scalar("epoch/nearest_cos", tot_nearest_cos / n, epoch + 1)
        writer.add_scalar("epoch/push_count", push_count, epoch + 1)

        # Checkpoint whenever validation accuracy improves; ``mk`` marks the
        # epoch line with a star in that case.
        mk = ""
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save({
                "state_dict": model.state_dict(),
                "config": model.config,
                "epoch": epoch + 1,
                "val_acc": val_acc,
            }, "checkpoints/geolip_core_best.pt")
            mk = " ★"

        nce_m = tot_nce_acc / n
        cos_m = tot_nearest_cos / n
        cv_band = "✓" if 0.18 <= v_cv <= 0.25 else "✗"
        print(f"  E{epoch+1:3d}: train={train_acc:.1f}% val={val_acc:.1f}% "
              f"loss={tot_loss/n:.4f} nce={nce_m:.2f} cos={cos_m:.3f} "
              f"cv={v_cv:.4f}({cv_band}) anch={diag['n_active']}/{N_ANCHORS} "
              f"push={push_count} ({elapsed:.0f}s){mk}")

    writer.close()
    print(f"\n  Best val accuracy: {best_acc:.1f}%")
    print(f"  Parameters: {n_params:,}")
    print(f"\n{'='*60}")
    print("DONE")
    print(f"{'='*60}")